
Bert Get Sentence Level Embedding After Fine Tuning

I came across this page. 1) I would like to get the sentence-level embedding (the embedding given by the [CLS] token) after fine-tuning is done. How could I do it? 2) I also noticed that th

Solution 1:

1) From the BERT documentation:

The output dictionary contains:

pooled_output: pooled output of the entire sequence with shape [batch_size, hidden_size].

sequence_output: representations of every token in the input sequence with shape [batch_size, max_sequence_length, hidden_size].

I've added the pooled_output vector, which corresponds to the [CLS] vector, to the values returned by the model.
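As a rough illustration of how the two outputs relate: in the standard BERT architecture, pooled_output is the hidden state of the first token ([CLS]) taken from sequence_output and passed through a dense layer with a tanh activation, which is why the embedding values printed later in this answer all lie in [-1, 1]. The sketch below mimics that step in plain Python on a toy tensor. The function pool_cls, the toy sizes and the identity weight matrix are assumptions made for illustration and are not part of the BERT module.

```python
import math


def pool_cls(sequence_output, weights, bias):
    """Dense layer + tanh applied to the first ([CLS]) token of each sequence.

    sequence_output: nested lists, [batch_size, max_sequence_length, hidden_size]
    weights:         nested lists, [hidden_size, hidden_size] (hypothetical values)
    bias:            list, [hidden_size]
    returns:         nested lists, [batch_size, hidden_size]
    """
    pooled = []
    for sequence in sequence_output:
        cls_vector = sequence[0]  # hidden state of the first token, i.e. [CLS]
        row = []
        for j in range(len(bias)):
            activation = bias[j] + sum(
                cls_vector[i] * weights[i][j] for i in range(len(cls_vector))
            )
            row.append(math.tanh(activation))
        pooled.append(row)
    return pooled


# Toy tensor: batch_size=2, max_sequence_length=3, hidden_size=2.
sequence_output = [
    [[0.5, -1.0], [0.1, 0.2], [0.0, 0.0]],
    [[2.0, 0.3], [0.4, -0.4], [0.0, 0.0]],
]
identity_weights = [[1.0, 0.0], [0.0, 1.0]]  # stand-in for the trained pooler weights
pooled_output = pool_cls(sequence_output, identity_weights, [0.0, 0.0])

# One vector per input sequence: [batch_size, hidden_size].
print(len(pooled_output), len(pooled_output[0]))
```

With the real model the hidden size is 768 and the pooler weights are learned, so the numbers differ, but the shapes follow the same pattern.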

3) You receive log probabilities. Apply softmax to the logits (or, equivalently, exponentiate the log probabilities) to get normal probabilities.
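To make the relationship concrete, here is a small plain-Python sketch (no TensorFlow) showing that softmax of the logits and the exponential of the log-softmax give the same probabilities. The helper functions and the example logits are made up for illustration; the last check reuses the log probabilities printed for the first sentence in this answer.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]


logits = [2.0, -3.5]  # hypothetical logits for a two-class problem
log_probs = log_softmax(logits)
probs_from_logits = softmax(logits)
probs_from_log_probs = [math.exp(lp) for lp in log_probs]

# Log probabilities reported for 'That movie was absolutely awful'.
reported_log_probs = [-4.0148855e-03, -5.5197663e+00]
reported_probs = [math.exp(lp) for lp in reported_log_probs]
print(reported_probs)
```

Exponentiating the reported log probabilities recovers the reported probabilities up to float32 rounding.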

Now all that is left is for the model to report it. I have kept the log probs, but they are no longer necessary.

See the code changes:

def create_model(is_predicting, input_ids, input_mask, segment_ids, labels,
                 num_labels):
  """Creates a classification model."""

  bert_module = hub.Module(
      BERT_MODEL_HUB,
      trainable=True)
  bert_inputs = dict(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids)
  bert_outputs = bert_module(
      inputs=bert_inputs,
      signature="tokens",
      as_dict=True)

  # Use "pooled_output" for classification tasks on an entire sentence.
  # Use "sequence_output" for token-level output.
  output_layer = bert_outputs["pooled_output"]

  pooled_output = output_layer

  hidden_size = output_layer.shape[-1].value

  # Create our own layer to tune for politeness data.
  output_weights = tf.get_variable(
      "output_weights", [num_labels, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "output_bias", [num_labels], initializer=tf.zeros_initializer())

  with tf.variable_scope("loss"):

    # Dropout helps prevent overfitting
    output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    probs = tf.nn.softmax(logits, axis=-1)

    # Convert labels into one-hot encoding
    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

    predicted_labels = tf.squeeze(tf.argmax(log_probs, axis=-1, output_type=tf.int32))
    # If we're predicting, we want predicted labels and the probabilities.
    if is_predicting:
      return (predicted_labels, log_probs, probs, pooled_output)

    # If we're train/eval, compute loss between predicted and actual label
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    return (loss, predicted_labels, log_probs, probs, pooled_output)
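The loss at the end of create_model() is the usual cross-entropy: the one-hot label picks out the log probability of the true class, which is negated and then averaged over the batch. A plain-Python sketch of the same arithmetic follows; the helper names and the example log probabilities and labels are invented for illustration.

```python
import math


def one_hot(label, depth):
    """Return a one-hot list of length `depth` with a 1.0 at index `label`."""
    return [1.0 if i == label else 0.0 for i in range(depth)]


def cross_entropy(log_probs, labels, num_labels):
    """Mean negative log probability of the true class, plus per-example values."""
    per_example_loss = []
    for row, label in zip(log_probs, labels):
        target = one_hot(label, num_labels)
        per_example_loss.append(-sum(t * lp for t, lp in zip(target, row)))
    mean_loss = sum(per_example_loss) / len(per_example_loss)
    return mean_loss, per_example_loss


# Two hypothetical examples with two classes each.
log_probs = [
    [math.log(0.9), math.log(0.1)],
    [math.log(0.2), math.log(0.8)],
]
labels = [0, 1]
loss, per_example_loss = cross_entropy(log_probs, labels, num_labels=2)
print(loss, per_example_loss)
```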

Now in the model_fn_builder() add support for those values:

# this should be changed in both places
  (predicted_labels, log_probs, probs, pooled_output) = create_model(
    is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

  # return dictionary of all the values you wanted
  predictions = {
      'log_probabilities': log_probs,
      'probabilities': probs,
      'labels': predicted_labels,
      'pooled_output': pooled_output
  }

Adjust getPrediction() accordingly and in the end your predictions will look like this:

('That movie was absolutely awful',
  array([0.99599314, 0.00400678], dtype=float32),  <= Probability
  array([-4.0148855e-03, -5.5197663e+00], dtype=float32), <= Log probability, same as previously
  'Negative', <= Label
  array([ 0.9181199 ,  0.7763732 ,  0.9999883 , -0.93533266, -0.9841384 ,
          0.78126144, -0.9918988 , -0.18764131,  0.9981035 ,  0.99999994,
          0.900716  , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
          0.9501321 ,  0.75836045,  0.49151263, -0.7886792 ,  0.97505844,
         -0.8931161 , -1.        ,  0.9318583 , -0.60531116, -0.8644371 ,
        ...
        and this is the 768-d [CLS] vector (sentence embedding).
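Once getPrediction() returns tuples of this shape, the embedding is simply the last element of each tuple. The sketch below shows one way to collect the embeddings per sentence and compare two of them with cosine similarity. It is only an illustration: the predictions list is a hand-written stand-in whose embeddings are cut to the first four values printed in this thread, and embeddings_by_sentence and cosine_similarity are hypothetical helpers, not part of the BERT code.

```python
import math

# Stand-in for the list returned by getPrediction(). Each tuple is
# (sentence, probabilities, log probabilities, label, pooled_output).
# The embeddings are truncated to 4 values; the real vectors have 768 entries.
predictions = [
    ("That movie was absolutely awful",
     [0.99599314, 0.00400678],
     [-4.0148855e-03, -5.5197663e+00],
     "Negative",
     [0.9181199, 0.7763732, 0.9999883, -0.93533266]),
    ("The acting was a bit lacking",
     [0.9921152, 0.00788479],
     [-0.00791603, -4.842819],
     "Negative",
     [0.67417824, 0.8235167, 0.99999565, -0.8565971]),
]


def embeddings_by_sentence(predictions):
    """Map each sentence to its sentence embedding (the last tuple element)."""
    return {
        sentence: list(embedding)
        for sentence, _probs, _log_probs, _label, embedding in predictions
    }


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


embeddings = embeddings_by_sentence(predictions)
first = embeddings["That movie was absolutely awful"]
second = embeddings["The acting was a bit lacking"]
similarity = cosine_similarity(first, second)
print(similarity)
```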

Regarding 2): At my end training took about 5 minutes and test about 40 seconds. Very reasonable.

UPDATE

For 20k samples it took 12:48 to train and 2:07 minutes to test.

For 10k samples timings are 8:40 and 1:07 respectively.

Solution 2:

Sure, here is the rest of the changes:

# model_fn_builder actually creates our model function
# using the passed parameters for num_labels, learning_rate, etc.
def model_fn_builder(num_labels, learning_rate, num_train_steps,
                     num_warmup_steps):
  """Returns `model_fn` closure for TPUEstimator."""
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]

    is_predicting = (mode == tf.estimator.ModeKeys.PREDICT)

    # TRAIN and EVAL
    if not is_predicting:

      (loss, predicted_labels, log_probs, probs, pooled_output) = create_model(
        is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

      train_op = bert.optimization.create_optimizer(
          loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)

      # Calculate evaluation metrics.
      def metric_fn(label_ids, predicted_labels):
        accuracy = tf.metrics.accuracy(label_ids, predicted_labels)
        f1_score = tf.contrib.metrics.f1_score(
            label_ids,
            predicted_labels)
        auc = tf.metrics.auc(
            label_ids,
            predicted_labels)
        recall = tf.metrics.recall(
            label_ids,
            predicted_labels)
        precision = tf.metrics.precision(
            label_ids,
            predicted_labels) 
        true_pos = tf.metrics.true_positives(
            label_ids,
            predicted_labels)
        true_neg = tf.metrics.true_negatives(
            label_ids,
            predicted_labels)   
        false_pos = tf.metrics.false_positives(
            label_ids,
            predicted_labels)  
        false_neg = tf.metrics.false_negatives(
            label_ids,
            predicted_labels)
        return {
            "eval_accuracy": accuracy,
            "f1_score": f1_score,
            "auc": auc,
            "precision": precision,
            "recall": recall,
            "true_positives": true_pos,
            "true_negatives": true_neg,
            "false_positives": false_pos,
            "false_negatives": false_neg
        }

      eval_metrics = metric_fn(label_ids, predicted_labels)

      if mode == tf.estimator.ModeKeys.TRAIN:
        return tf.estimator.EstimatorSpec(mode=mode,
          loss=loss,
          train_op=train_op)
      else:
        return tf.estimator.EstimatorSpec(mode=mode,
          loss=loss,
          eval_metric_ops=eval_metrics)
    else:
      (predicted_labels, log_probs, probs, pooled_output) = create_model(
        is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

      predictions = {
          'log_probabilities': log_probs,
          'probabilities': probs,
          'labels': predicted_labels,
          'pooled_output': pooled_output
      }
      return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  # Return the actual model function in the closure
  return model_fn
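For reference, most of the scalar metrics collected in metric_fn follow from the four confusion-matrix counts. Below is a minimal plain-Python sketch, assuming binary 0/1 labels. The function binary_metrics and the sample labels are hypothetical, and the sketch does not reproduce TensorFlow's streaming metric ops or the AUC computation.

```python
def binary_metrics(label_ids, predicted_labels):
    """Confusion counts and derived metrics for binary 0/1 labels."""
    pairs = list(zip(label_ids, predicted_labels))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)

    # Guard against empty denominators so the sketch never divides by zero.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    accuracy = (tp + tn) / len(pairs) if pairs else 0.0

    return {
        "eval_accuracy": accuracy,
        "f1_score": f1,
        "precision": precision,
        "recall": recall,
        "true_positives": tp,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn,
    }


# Made-up labels and predictions for six examples.
metrics = binary_metrics([1, 0, 1, 1, 0, 0], [1, 0, 0, 1, 1, 0])
print(metrics)
```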


def getPrediction(in_sentences):
  labels = ["Negative", "Positive"]
  # guid="" and label=0 are just dummy values; they are not used for prediction
  input_examples = [run_classifier.InputExample(guid="", text_a=x, text_b=None, label=0) for x in in_sentences]
  input_features = run_classifier.convert_examples_to_features(input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
  predict_input_fn = run_classifier.input_fn_builder(features=input_features, seq_length=MAX_SEQ_LENGTH, is_training=False, drop_remainder=False)
  predictions = estimator.predict(predict_input_fn)
  # Each result: (sentence, probabilities, log probabilities, label name, [CLS] embedding)
  return [(sentence, prediction['probabilities'], prediction['log_probabilities'], labels[prediction['labels']], prediction['pooled_output']) for sentence, prediction in zip(in_sentences, predictions)]

and the first output (the others are cut off because of the 30K-character limit on the answer):

[('That movie was absolutely awful',
  array([0.99599314, 0.00400678], dtype=float32),
  array([-4.0148855e-03, -5.5197663e+00], dtype=float32),
  'Negative',
  array([ 0.9181199 ,  0.7763732 ,  0.9999883 , -0.93533266, -0.9841384 ,
          0.78126144, -0.9918988 , -0.18764131,  0.9981035 ,  0.99999994,
          0.900716  , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
          0.9501321 ,  0.75836045,  0.49151263, -0.7886792 ,  0.97505844,
         -0.8931161 , -1.        ,  0.9318583 , -0.60531116, -0.8644371 ,
         -0.9999866 ,  0.5820049 ,  0.3257555 , -0.81900954, -0.8326617 ,
          0.87788117, -0.7791749 ,  0.11098853,  0.67873836,  0.9999771 ,
          0.9833652 , -0.8420576 ,  0.83076835,  0.37272754,  0.8667175 ,
          0.792386  , -0.82003427, -0.9999999 , -0.9382297 , -0.9713775 ,
          0.55752313,  1.        , -0.72632766, -0.4752956 , -0.9999852 ,
         -0.99974227, -0.9998661 , -0.3094257 , -0.93023825, -0.72663504,
          0.92974335, -0.8601105 , -0.8113003 ,  0.7660112 ,  0.9313508 ,
          0.21427669, -0.45660907,  0.99970686,  0.56852764, -0.9997675 ,
         -0.9999096 ,  0.8247045 ,  0.7205424 ,  0.47192624, -0.7523966 ,
         -0.9588541 , -0.48866934,  0.9809366 , -0.07110611, -0.99886   ,
         -0.63922834, -0.68144   , -1.        ,  0.8531816 ,  0.26078308,
         -0.99898577, -0.99968046,  0.6711601 ,  0.99857473, -0.99990964,
          1.        , -0.97127694, -0.10644457,  0.46306637, -0.32486317,
         -0.68167734,  0.43291137, -0.996574  ,  0.05164305,  0.9897354 ,
          0.93853104,  0.94800174,  0.9995697 ,  0.6532897 ,  0.93846226,
         -0.6281378 ,  0.5574107 ,  0.725278  ,  0.74160355, -0.6486919 ,
          0.88869256,  0.9439776 , -0.9654787 , -0.95139974, -0.9366148 ,
          0.17409436,  0.83473635, -0.87414986, -0.35965624, -0.8395183 ,
          0.5546853 ,  0.7452196 , -0.6152899 , -0.82187194, -0.65487677,
          0.94367695,  0.6834396 , -0.72266734,  0.99376386, -0.76821744,
          0.4485644 ,  0.99982166,  1.        ,  0.9260674 ,  0.9759094 ,
          0.9397613 ,  0.8128903 , -0.7918152 ,  0.30299878, -0.95160294,
          0.25385544, -0.57780135, -0.9999994 ,  0.9168113 , -0.36585295,
          0.9798102 ,  0.95976156, -0.99428   ,  0.6471789 , -0.9948078 ,
         -0.9686591 ,  0.93615085, -0.11481134,  0.87566274, -0.91601896,
          0.9952683 ,  0.26532048,  0.99861896,  0.79298306,  0.5872364 ,
         -0.56314534,  0.96794534,  0.9999797 ,  0.9879324 ,  0.5003342 ,
          0.9516269 , -0.8878316 , -0.9665091 , -0.88037425,  0.8356687 ,
         -0.71543014, -0.99985015, -0.9414574 ,  0.8681497 ,  0.950698  ,
         -0.8007153 ,  0.78748596,  0.9999305 ,  0.40210736,  0.4856055 ,
         -0.9390776 ,  0.63564163, -0.85989815, -0.8421344 , -0.99436   ,
          0.78081733, -0.97038007,  0.39290914,  0.7834218 ,  0.88715357,
         -0.03653741,  0.99126273, -0.96559966,  0.11924513, -0.99363935,
         -0.9901692 ,  0.963858  ,  0.5713922 ,  0.5676979 ,  0.69982123,
          0.858003  ,  0.9983819 , -0.87965024,  0.46213093, -0.3256273 ,
          0.77337253,  0.7246244 , -0.99894017, -0.9170495 , -0.98803675,
         -0.93148243,  0.09674019,  0.09448949, -0.7453027 , -0.78955775,
         -0.6304773 , -0.5597632 ,  0.992308  ,  0.7769483 ,  0.04146893,
         -0.15876745, -0.7682887 , -0.5231416 ,  0.7871302 ,  0.9503481 ,
         -0.9607153 ,  0.99047405, -0.9948017 , -0.82257754,  0.9990552 ,
          0.79346406, -0.78624016,  0.8760266 , -0.7855991 ,  0.13444276,
         -0.7183107 , -0.9999819 ,  0.7019429 , -0.918913  , -0.6569654 ,
          0.9998794 , -0.33805153, -0.9427715 ,  0.10419375, -0.94257164,
          0.9187495 , -0.9994855 , -0.99979955, -0.9277688 ,  0.6353426 ,
          0.9994905 ,  0.90688777,  0.9992008 ,  0.7817533 , -0.9996674 ,
         -0.999962  , -0.13310781, -0.82505953,  0.9997485 ,  0.82616794,
         -0.999998  ,  0.45386457,  0.6069964 ,  0.52272975,  0.8811922 ,
          0.52668494, -0.9994814 , -0.21601789, -0.99882716,  0.90246916,
          0.94196504,  0.30058604, -0.9876776 , -0.7699927 , -0.9980288 ,
          0.7727592 ,  0.9936947 ,  0.98021245, -0.77723926, -0.785372  ,
          0.5150317 ,  0.9983137 , -0.7461883 ,  0.3311537 , -0.63709795,
         -0.6487831 , -0.9173727 ,  0.9997706 , -0.9999893 , -1.        ,
          0.60389155, -0.6516268 , -0.95422006,  1.        ,  0.09109057,
         -0.99999994,  0.99998957,  1.        , -0.19451752,  0.94624877,
         -0.2761865 ,  1.        ,  0.52399474,  0.70230734,  0.5218801 ,
         -0.99716544, -0.70075685, -0.99992603,  1.        , -0.9785006 ,
          0.22457084, -0.5356722 , -0.9991887 ,  0.7062409 ,  0.66816545,
         -0.90308225, -0.8084922 ,  0.50301254, -0.7062079 ,  0.9998321 ,
          0.9823206 ,  0.9984027 ,  0.9948857 , -1.        , -0.7067878 ,
          0.975454  ,  0.87161005, -0.9882297 ,  0.8296374 , -0.88615334,
          0.4316883 ,  0.86287475, -0.9893329 , -0.9022001 , -0.68322754,
         -0.84212875,  0.78632677, -0.5131366 , -0.996949  , -0.75479275,
         -0.06342169,  0.92238575,  0.66769385,  0.9926053 , -0.78391105,
          0.9976865 ,  0.07086544,  0.34079495,  0.69730175, -0.99970955,
         -1.        , -0.9860551 ,  0.89584446, -0.96889114, -0.90435815,
          0.944296  , -1.        , -0.9931756 , -0.7014334 , -0.6742562 ,
         -0.96786517,  0.848328  ,  0.8903087 , -0.9998633 ,  0.73993397,
          0.99345684,  0.9691821 ,  0.87563246, -0.6073146 , -0.9999999 ,
          0.90763575,  0.30225936, -0.47824544,  0.7179979 ,  0.9450465 ,
          0.9715953 , -0.5422173 ,  0.99995065, -0.5920663 ,  0.92390317,
         -0.9670669 , -0.3623574 ,  0.74825   , -0.7817521 ,  0.9888685 ,
         -0.7653631 , -0.8933355 ,  0.9481424 ,  0.97803396, -0.9999731 ,
         -0.89597356,  0.35502487, -0.7190486 ,  0.30777818,  0.55025375,
          0.6365793 , -0.99094397, -1.        ,  0.93482614, -0.99970514,
          0.98721176,  0.14699097, -0.86038756, -0.68365514, -0.8104672 ,
          0.57238674,  0.97475344, -0.9963499 ,  0.98476464,  0.40495875,
         -0.7001948 , -0.40898973,  0.61900675, -1.        , -0.9371812 ,
         -0.62749994, -0.8841316 , -0.9999847 , -0.39386114, -0.925245  ,
         -0.99991447, -0.5872595 ,  0.5835767 ,  0.7003338 , -0.9761974 ,
          0.99995846,  0.33676207,  0.9079994 , -0.76412004, -0.7648706 ,
          0.68863285,  0.43983305,  0.74911463, -0.99995685, -0.6692586 ,
         -0.45761266, -0.9980771 , -1.        ,  0.31244457, -0.8834693 ,
          0.9388263 , -0.987405  ,  1.        ,  0.9512058 ,  0.23448633,
          0.37940192,  0.99989796,  0.8402514 , -0.84526414,  0.7378776 ,
         -0.9996204 , -0.99434114,  0.9987527 ,  0.5569713 ,  0.99648696,
         -0.9933159 , -0.13116199,  0.9999992 ,  0.9642579 , -0.48285434,
         -0.97517425,  0.7185596 ,  0.5286405 ,  0.9902838 ,  0.7796022 ,
         -0.80703837,  0.2376029 ,  0.534117  , -0.9999413 ,  0.99828076,
          0.9998345 ,  0.93249476,  0.3620626 ,  0.7567034 , -0.9222681 ,
          0.97832036,  0.9999682 ,  0.6433209 , -1.        ,  0.9268615 ,
         -0.9999511 , -0.9145363 , -0.9213852 ,  0.7606066 , -0.5501025 ,
         -0.99999434, -0.7783993 ,  0.9999771 ,  0.99980384,  0.987094  ,
          0.7531475 , -0.8551696 , -0.9973968 , -0.9999853 , -0.08913276,
         -0.9919206 , -0.49190572,  0.70230234, -0.31277484, -0.99999964,
          0.828591  ,  0.6363776 ,  0.86796165,  0.81575817,  0.7782955 ,
          0.9436437 , -1.        , -0.7509046 , -0.9946139 , -0.6647415 ,
          0.999543  ,  0.9312092 , -1.        ,  0.5639159 ,  0.9482462 ,
         -0.9289936 , -0.9678435 ,  0.60937124, -0.987818  ,  0.5511619 ,
          0.75886583, -0.48466644, -0.71833754,  0.8042149 ,  0.9154103 ,
         -0.8177468 ,  0.7195895 , -0.82283056,  0.24990956, -1.        ,
          0.7729634 ,  0.84048635,  0.7989596 ,  0.9469012 , -0.9898951 ,
         -0.92565274,  0.74726975,  0.78213847, -0.672894  , -0.58831286,
         -0.8039038 , -0.72197783,  0.5289216 , -0.9998796 , -0.9904479 ,
          0.9996592 , -0.28984115,  0.23964961, -0.7427149 , -0.662416  ,
         -1.        , -0.5538268 , -0.9945287 , -0.63471127,  0.5896127 ,
         -0.48429146,  0.9976076 , -0.94329506, -0.49143887,  0.7695602 ,
          0.8638134 , -0.82130384,  0.50105464,  0.9336961 , -0.24716294,
         -0.6922282 , -0.02228704,  0.75649065,  0.82303154, -0.30867255,
         -0.9602714 ,  0.64568967,  0.314201  , -0.4811752 ,  0.27952817,
          0.9227022 ,  0.88095886,  0.89470226,  1.        , -0.19237158,
          1.        , -0.991253  , -0.9991121 ,  0.5637482 , -0.75780976,
         -0.3904836 , -0.9881965 , -0.2912058 ,  0.9998215 ,  0.9869475 ,
         -0.12784953,  0.81566185,  0.9787118 , -0.17835459, -0.7027824 ,
          0.72269535, -0.18194303,  0.9968796 ,  0.03490257,  0.7751488 ,
         -1.        , -0.7761089 ,  0.85105944,  0.9968074 , -0.8156342 ,
          0.5300792 , -1.        ,  0.99626255, -0.7515625 , -0.6672005 ,
          0.9792111 ,  0.8660997 , -0.69161206,  0.32184905,  0.9071073 ,
          0.9999385 , -0.82744277, -0.99044186, -0.71309817, -0.5004305 ,
          0.70707524,  0.89751345, -0.6819585 , -0.9999414 , -0.45255637,
         -0.94375473, -0.91838425,  0.64272994,  0.9375524 ,  0.6609169 ,
         -0.88743365, -0.9534722 , -0.47888806, -1.        , -0.5251781 ,
          0.8274516 ,  0.9326824 ,  0.8961964 ,  0.5295862 ,  0.43714878,
         -0.7488347 , -0.75295556, -0.5187054 ,  0.75924635, -0.7862662 ,
          0.99981725, -0.80290836,  0.97651815,  0.99763787, -0.29619345,
         -0.1252967 ,  0.33606276, -0.65137684, -0.9680231 ,  0.77586985,
          0.22347753,  0.27245504, -0.07826214, -0.8383849 , -0.85373163,
          1.        , -0.4563588 , -0.91339815, -0.9999861 ,  0.66063935,
         -0.985843  , -0.7818757 , -0.7000497 , -0.6840764 ,  0.9995542 ,
          0.60819125,  0.80064404, -0.9776968 , -0.90925264, -0.6644932 ,
         -0.8771755 ,  0.71411085,  0.8113569 ,  0.9974196 , -0.75211936,
          0.63400257, -0.8272833 ,  0.99780786,  0.9965285 ,  0.59551436,
         -0.9876875 , -0.04439292,  0.9939223 ,  0.9993717 , -0.9965501 ,
         -0.9630328 , -0.9027949 , -0.48490363, -0.60193753, -0.6870232 ,
         -0.95355797, -0.67561924,  0.9997761 , -0.85473967,  0.998495  ,
         -0.95756954,  0.633171  ,  0.4570475 , -0.5316367 , -0.9663824 ,
          0.9567106 , -0.45497724,  0.12964879,  0.9964744 , -0.9711668 ,
          0.69636106, -0.9178346 ,  0.8313186 ,  0.69686604,  0.8141587 ,
         -0.33600506,  0.94798595,  0.8800869 ,  0.15029034, -0.91185665,
          0.6322724 , -0.9971475 ,  0.71948224,  0.9695236 ,  0.84242374,
          0.99995124,  0.5982563 , -0.98341423,  0.61301434,  0.9997318 ,
         -0.9981808 , -0.65651804, -0.8484874 , -0.9961815 ,  0.9030814 ,
          0.87141925,  0.8897381 , -0.92870414,  0.07134341,  0.8739935 ,
          0.91630197, -0.9465984 , -0.59741104, -1.        ,  0.9989559 ,
          0.99991184,  0.67439264,  0.92025673, -0.60730827,  0.8362061 ,
          1.        , -0.70801497,  0.9883806 , -0.9984141 ,  0.9919259 ,
         -0.998869  ,  0.9976203 ,  0.9888036 ,  0.8556838 , -0.9722744 ,
         -0.99810714,  0.8182833 ,  0.98808485,  0.6643728 ,  0.99212515,
         -0.99988   ,  0.26405996,  0.93139845,  0.99021816,  0.6846886 ,
          0.9986462 ,  0.92254627, -0.6406982 ], dtype=float32)),
 ('The acting was a bit lacking',
  array([0.9921152 , 0.00788479], dtype=float32),
  array([-0.00791603, -4.842819  ], dtype=float32),
  'Negative',
  array([ 0.67417824,  0.8235167 ,  0.99999565, -0.8565971 , -0.99499583,
          0.8219966 , -0.9185583 , -0.5234593 ,  0.99962074,  0.99999714,
          0.9507927 , -0.9996754 ,  0.22211392, -0.99826247,  0.7562492 ,
          0.93803996,  0.82738185,  0.4773049 , -0.73478544,  0.85207295,
