Bert Get Sentence Level Embedding After Fine Tuning
Solution 1:
1) From the BERT documentation:

The output dictionary contains:

pooled_output: pooled output of the entire sequence with shape [batch_size, hidden_size].
sequence_output: representations of every token in the input sequence with shape [batch_size, max_sequence_length, hidden_size].
I've added the pooled_output vector, which corresponds to the [CLS] vector.
3) You receive log probabilities. Just apply softmax to get normal probabilities.
Now all that is left is for the model to report them. I have kept the log probs, but they are no longer necessary.
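As a quick sanity check (a minimal numpy sketch of my own, not part of the fine-tuning code): taking softmax of the logits and exponentiating the log-softmax give the same probabilities, so either output recovers them.

import numpy as np

# Hypothetical logits for one example with two classes.
logits = np.array([2.0, -3.5])

log_probs = logits - np.log(np.sum(np.exp(logits)))  # log_softmax
probs = np.exp(logits) / np.sum(np.exp(logits))      # softmax

# exp(log_softmax(x)) equals softmax(x).
assert np.allclose(np.exp(log_probs), probs)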
See the code changes:
def create_model(is_predicting, input_ids, input_mask, segment_ids, labels,
                 num_labels):
  """Creates a classification model."""

  bert_module = hub.Module(
      BERT_MODEL_HUB,
      trainable=True)
  bert_inputs = dict(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids)
  bert_outputs = bert_module(
      inputs=bert_inputs,
      signature="tokens",
      as_dict=True)

  # Use "pooled_output" for classification tasks on an entire sentence.
  # Use "sequence_output" for token-level output.
  output_layer = bert_outputs["pooled_output"]
  pooled_output = output_layer

  hidden_size = output_layer.shape[-1].value

  # Create our own layer to tune for politeness data.
  output_weights = tf.get_variable(
      "output_weights", [num_labels, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "output_bias", [num_labels], initializer=tf.zeros_initializer())

  with tf.variable_scope("loss"):

    # Dropout helps prevent overfitting; skip it when predicting so the
    # returned embeddings and probabilities are deterministic.
    keep_prob = 1.0 if is_predicting else 0.9
    output_layer = tf.nn.dropout(output_layer, keep_prob=keep_prob)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    probs = tf.nn.softmax(logits, axis=-1)

    # Convert labels into one-hot encoding.
    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

    predicted_labels = tf.squeeze(tf.argmax(log_probs, axis=-1, output_type=tf.int32))

    # If we're predicting, we want the predicted labels and the probabilities.
    if is_predicting:
      return (predicted_labels, log_probs, probs, pooled_output)

    # If we're training/evaluating, compute the loss between the predicted and actual label.
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    return (loss, predicted_labels, log_probs, probs, pooled_output)
Now, in model_fn_builder(), add support for those values:
# this should be changed in both places
(predicted_labels, log_probs, probs, pooled_output) = create_model(
    is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

# return a dictionary of all the values you wanted
predictions = {
    'log_probabilities': log_probs,
    'probabilities': probs,
    'labels': predicted_labels,
    'pooled_output': pooled_output
}
Adjust getPrediction() accordingly, and in the end your predictions will look like this:
('That movie was absolutely awful',
array([0.99599314, 0.00400678], dtype=float32), <= Probability
array([-4.0148855e-03, -5.5197663e+00], dtype=float32), <= Log probability, same as previously
'Negative', <= Label
array([ 0.9181199 , 0.7763732 , 0.9999883 , -0.93533266, -0.9841384 ,
0.78126144, -0.9918988 , -0.18764131, 0.9981035 , 0.99999994,
0.900716 , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
0.9501321 , 0.75836045, 0.49151263, -0.7886792 , 0.97505844,
-0.8931161 , -1. , 0.9318583 , -0.60531116, -0.8644371 ,
...
and this is the 768-d [CLS] vector (the sentence embedding).
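Since each prediction now carries the pooled_output vector, you can compare sentences directly. A minimal sketch (my own addition, assuming numpy and the tuple layout returned by getPrediction() in Solution 2, where pooled_output is element 4):

import numpy as np

def cosine_similarity(a, b):
  # Cosine similarity between two 768-d pooled_output vectors.
  return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# preds = getPrediction(["That movie was absolutely awful",
#                        "The acting was a bit lacking"])
# print(cosine_similarity(preds[0][4], preds[1][4]))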
Regarding 2): on my end, training took about 5 minutes and testing about 40 seconds. Very reasonable.
UPDATE
For 20k samples it took 12:48 to train and 2:07 to test.
For 10k samples the timings are 8:40 and 1:07, respectively.
Solution 2:
Sure, here is the rest of the changes:
# model_fn_builder actually creates our model function
# using the passed parameters for num_labels, learning_rate, etc.
def model_fn_builder(num_labels, learning_rate, num_train_steps,
                     num_warmup_steps):
  """Returns `model_fn` closure for TPUEstimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]

    is_predicting = (mode == tf.estimator.ModeKeys.PREDICT)

    # TRAIN and EVAL
    if not is_predicting:

      (loss, predicted_labels, log_probs, probs, pooled_output) = create_model(
          is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

      train_op = bert.optimization.create_optimizer(
          loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)

      # Calculate evaluation metrics.
      def metric_fn(label_ids, predicted_labels):
        accuracy = tf.metrics.accuracy(label_ids, predicted_labels)
        f1_score = tf.contrib.metrics.f1_score(
            label_ids,
            predicted_labels)
        auc = tf.metrics.auc(
            label_ids,
            predicted_labels)
        recall = tf.metrics.recall(
            label_ids,
            predicted_labels)
        precision = tf.metrics.precision(
            label_ids,
            predicted_labels)
        true_pos = tf.metrics.true_positives(
            label_ids,
            predicted_labels)
        true_neg = tf.metrics.true_negatives(
            label_ids,
            predicted_labels)
        false_pos = tf.metrics.false_positives(
            label_ids,
            predicted_labels)
        false_neg = tf.metrics.false_negatives(
            label_ids,
            predicted_labels)
        return {
            "eval_accuracy": accuracy,
            "f1_score": f1_score,
            "auc": auc,
            "precision": precision,
            "recall": recall,
            "true_positives": true_pos,
            "true_negatives": true_neg,
            "false_positives": false_pos,
            "false_negatives": false_neg
        }

      eval_metrics = metric_fn(label_ids, predicted_labels)

      if mode == tf.estimator.ModeKeys.TRAIN:
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op)
      else:
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          eval_metric_ops=eval_metrics)
    else:
      (predicted_labels, log_probs, probs, pooled_output) = create_model(
          is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

      predictions = {
          'log_probabilities': log_probs,
          'probabilities': probs,
          'labels': predicted_labels,
          'pooled_output': pooled_output
      }
      return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  # Return the actual model function in the closure
  return model_fn
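For completeness, the rebuilt model function is wired into the estimator the same way as before. A sketch, assuming run_config, label_list, LEARNING_RATE, BATCH_SIZE, num_train_steps, and num_warmup_steps are already defined as in the original notebook:

model_fn = model_fn_builder(
    num_labels=len(label_list),
    learning_rate=LEARNING_RATE,
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps)

# Recreate the estimator so it picks up the new model_fn.
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    config=run_config,
    params={"batch_size": BATCH_SIZE})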
def getPrediction(in_sentences):
  labels = ["Negative", "Positive"]
  # guid="" and label=0 are just dummy placeholders.
  input_examples = [run_classifier.InputExample(guid="", text_a=x, text_b=None, label=0)
                    for x in in_sentences]
  input_features = run_classifier.convert_examples_to_features(
      input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
  predict_input_fn = run_classifier.input_fn_builder(
      features=input_features, seq_length=MAX_SEQ_LENGTH,
      is_training=False, drop_remainder=False)
  predictions = estimator.predict(predict_input_fn)
  return [(sentence, prediction['probabilities'], prediction['log_probabilities'],
           labels[prediction['labels']], prediction['pooled_output'])
          for sentence, prediction in zip(in_sentences, predictions)]
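The output below was produced by a call along these lines (the exact call is my assumption; the two sentences are taken from the printed result):

pred_sentences = ["That movie was absolutely awful",
                  "The acting was a bit lacking"]
predictions = getPrediction(pred_sentences)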
Here is the first output (the rest is cut off because of the 30K-character limit on the answer):
[('That movie was absolutely awful',
array([0.99599314, 0.00400678], dtype=float32),
array([-4.0148855e-03, -5.5197663e+00], dtype=float32),
'Negative',
array([ 0.9181199 , 0.7763732 , 0.9999883 , -0.93533266, -0.9841384 ,
0.78126144, -0.9918988 , -0.18764131, 0.9981035 , 0.99999994,
0.900716 , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
0.9501321 , 0.75836045, 0.49151263, -0.7886792 , 0.97505844,
-0.8931161 , -1. , 0.9318583 , -0.60531116, -0.8644371 ,
-0.9999866 , 0.5820049 , 0.3257555 , -0.81900954, -0.8326617 ,
0.87788117, -0.7791749 , 0.11098853, 0.67873836, 0.9999771 ,
0.9833652 , -0.8420576 , 0.83076835, 0.37272754, 0.8667175 ,
0.792386 , -0.82003427, -0.9999999 , -0.9382297 , -0.9713775 ,
0.55752313, 1. , -0.72632766, -0.4752956 , -0.9999852 ,
-0.99974227, -0.9998661 , -0.3094257 , -0.93023825, -0.72663504,
0.92974335, -0.8601105 , -0.8113003 , 0.7660112 , 0.9313508 ,
0.21427669, -0.45660907, 0.99970686, 0.56852764, -0.9997675 ,
-0.9999096 , 0.8247045 , 0.7205424 , 0.47192624, -0.7523966 ,
-0.9588541 , -0.48866934, 0.9809366 , -0.07110611, -0.99886 ,
-0.63922834, -0.68144 , -1. , 0.8531816 , 0.26078308,
-0.99898577, -0.99968046, 0.6711601 , 0.99857473, -0.99990964,
1. , -0.97127694, -0.10644457, 0.46306637, -0.32486317,
-0.68167734, 0.43291137, -0.996574 , 0.05164305, 0.9897354 ,
0.93853104, 0.94800174, 0.9995697 , 0.6532897 , 0.93846226,
-0.6281378 , 0.5574107 , 0.725278 , 0.74160355, -0.6486919 ,
0.88869256, 0.9439776 , -0.9654787 , -0.95139974, -0.9366148 ,
0.17409436, 0.83473635, -0.87414986, -0.35965624, -0.8395183 ,
0.5546853 , 0.7452196 , -0.6152899 , -0.82187194, -0.65487677,
0.94367695, 0.6834396 , -0.72266734, 0.99376386, -0.76821744,
0.4485644 , 0.99982166, 1. , 0.9260674 , 0.9759094 ,
0.9397613 , 0.8128903 , -0.7918152 , 0.30299878, -0.95160294,
0.25385544, -0.57780135, -0.9999994 , 0.9168113 , -0.36585295,
0.9798102 , 0.95976156, -0.99428 , 0.6471789 , -0.9948078 ,
-0.9686591 , 0.93615085, -0.11481134, 0.87566274, -0.91601896,
0.9952683 , 0.26532048, 0.99861896, 0.79298306, 0.5872364 ,
-0.56314534, 0.96794534, 0.9999797 , 0.9879324 , 0.5003342 ,
0.9516269 , -0.8878316 , -0.9665091 , -0.88037425, 0.8356687 ,
-0.71543014, -0.99985015, -0.9414574 , 0.8681497 , 0.950698 ,
-0.8007153 , 0.78748596, 0.9999305 , 0.40210736, 0.4856055 ,
-0.9390776 , 0.63564163, -0.85989815, -0.8421344 , -0.99436 ,
0.78081733, -0.97038007, 0.39290914, 0.7834218 , 0.88715357,
-0.03653741, 0.99126273, -0.96559966, 0.11924513, -0.99363935,
-0.9901692 , 0.963858 , 0.5713922 , 0.5676979 , 0.69982123,
0.858003 , 0.9983819 , -0.87965024, 0.46213093, -0.3256273 ,
0.77337253, 0.7246244 , -0.99894017, -0.9170495 , -0.98803675,
-0.93148243, 0.09674019, 0.09448949, -0.7453027 , -0.78955775,
-0.6304773 , -0.5597632 , 0.992308 , 0.7769483 , 0.04146893,
-0.15876745, -0.7682887 , -0.5231416 , 0.7871302 , 0.9503481 ,
-0.9607153 , 0.99047405, -0.9948017 , -0.82257754, 0.9990552 ,
0.79346406, -0.78624016, 0.8760266 , -0.7855991 , 0.13444276,
-0.7183107 , -0.9999819 , 0.7019429 , -0.918913 , -0.6569654 ,
0.9998794 , -0.33805153, -0.9427715 , 0.10419375, -0.94257164,
0.9187495 , -0.9994855 , -0.99979955, -0.9277688 , 0.6353426 ,
0.9994905 , 0.90688777, 0.9992008 , 0.7817533 , -0.9996674 ,
-0.999962 , -0.13310781, -0.82505953, 0.9997485 , 0.82616794,
-0.999998 , 0.45386457, 0.6069964 , 0.52272975, 0.8811922 ,
0.52668494, -0.9994814 , -0.21601789, -0.99882716, 0.90246916,
0.94196504, 0.30058604, -0.9876776 , -0.7699927 , -0.9980288 ,
0.7727592 , 0.9936947 , 0.98021245, -0.77723926, -0.785372 ,
0.5150317 , 0.9983137 , -0.7461883 , 0.3311537 , -0.63709795,
-0.6487831 , -0.9173727 , 0.9997706 , -0.9999893 , -1. ,
0.60389155, -0.6516268 , -0.95422006, 1. , 0.09109057,
-0.99999994, 0.99998957, 1. , -0.19451752, 0.94624877,
-0.2761865 , 1. , 0.52399474, 0.70230734, 0.5218801 ,
-0.99716544, -0.70075685, -0.99992603, 1. , -0.9785006 ,
0.22457084, -0.5356722 , -0.9991887 , 0.7062409 , 0.66816545,
-0.90308225, -0.8084922 , 0.50301254, -0.7062079 , 0.9998321 ,
0.9823206 , 0.9984027 , 0.9948857 , -1. , -0.7067878 ,
0.975454 , 0.87161005, -0.9882297 , 0.8296374 , -0.88615334,
0.4316883 , 0.86287475, -0.9893329 , -0.9022001 , -0.68322754,
-0.84212875, 0.78632677, -0.5131366 , -0.996949 , -0.75479275,
-0.06342169, 0.92238575, 0.66769385, 0.9926053 , -0.78391105,
0.9976865 , 0.07086544, 0.34079495, 0.69730175, -0.99970955,
-1. , -0.9860551 , 0.89584446, -0.96889114, -0.90435815,
0.944296 , -1. , -0.9931756 , -0.7014334 , -0.6742562 ,
-0.96786517, 0.848328 , 0.8903087 , -0.9998633 , 0.73993397,
0.99345684, 0.9691821 , 0.87563246, -0.6073146 , -0.9999999 ,
0.90763575, 0.30225936, -0.47824544, 0.7179979 , 0.9450465 ,
0.9715953 , -0.5422173 , 0.99995065, -0.5920663 , 0.92390317,
-0.9670669 , -0.3623574 , 0.74825 , -0.7817521 , 0.9888685 ,
-0.7653631 , -0.8933355 , 0.9481424 , 0.97803396, -0.9999731 ,
-0.89597356, 0.35502487, -0.7190486 , 0.30777818, 0.55025375,
0.6365793 , -0.99094397, -1. , 0.93482614, -0.99970514,
0.98721176, 0.14699097, -0.86038756, -0.68365514, -0.8104672 ,
0.57238674, 0.97475344, -0.9963499 , 0.98476464, 0.40495875,
-0.7001948 , -0.40898973, 0.61900675, -1. , -0.9371812 ,
-0.62749994, -0.8841316 , -0.9999847 , -0.39386114, -0.925245 ,
-0.99991447, -0.5872595 , 0.5835767 , 0.7003338 , -0.9761974 ,
0.99995846, 0.33676207, 0.9079994 , -0.76412004, -0.7648706 ,
0.68863285, 0.43983305, 0.74911463, -0.99995685, -0.6692586 ,
-0.45761266, -0.9980771 , -1. , 0.31244457, -0.8834693 ,
0.9388263 , -0.987405 , 1. , 0.9512058 , 0.23448633,
0.37940192, 0.99989796, 0.8402514 , -0.84526414, 0.7378776 ,
-0.9996204 , -0.99434114, 0.9987527 , 0.5569713 , 0.99648696,
-0.9933159 , -0.13116199, 0.9999992 , 0.9642579 , -0.48285434,
-0.97517425, 0.7185596 , 0.5286405 , 0.9902838 , 0.7796022 ,
-0.80703837, 0.2376029 , 0.534117 , -0.9999413 , 0.99828076,
0.9998345 , 0.93249476, 0.3620626 , 0.7567034 , -0.9222681 ,
0.97832036, 0.9999682 , 0.6433209 , -1. , 0.9268615 ,
-0.9999511 , -0.9145363 , -0.9213852 , 0.7606066 , -0.5501025 ,
-0.99999434, -0.7783993 , 0.9999771 , 0.99980384, 0.987094 ,
0.7531475 , -0.8551696 , -0.9973968 , -0.9999853 , -0.08913276,
-0.9919206 , -0.49190572, 0.70230234, -0.31277484, -0.99999964,
0.828591 , 0.6363776 , 0.86796165, 0.81575817, 0.7782955 ,
0.9436437 , -1. , -0.7509046 , -0.9946139 , -0.6647415 ,
0.999543 , 0.9312092 , -1. , 0.5639159 , 0.9482462 ,
-0.9289936 , -0.9678435 , 0.60937124, -0.987818 , 0.5511619 ,
0.75886583, -0.48466644, -0.71833754, 0.8042149 , 0.9154103 ,
-0.8177468 , 0.7195895 , -0.82283056, 0.24990956, -1. ,
0.7729634 , 0.84048635, 0.7989596 , 0.9469012 , -0.9898951 ,
-0.92565274, 0.74726975, 0.78213847, -0.672894 , -0.58831286,
-0.8039038 , -0.72197783, 0.5289216 , -0.9998796 , -0.9904479 ,
0.9996592 , -0.28984115, 0.23964961, -0.7427149 , -0.662416 ,
-1. , -0.5538268 , -0.9945287 , -0.63471127, 0.5896127 ,
-0.48429146, 0.9976076 , -0.94329506, -0.49143887, 0.7695602 ,
0.8638134 , -0.82130384, 0.50105464, 0.9336961 , -0.24716294,
-0.6922282 , -0.02228704, 0.75649065, 0.82303154, -0.30867255,
-0.9602714 , 0.64568967, 0.314201 , -0.4811752 , 0.27952817,
0.9227022 , 0.88095886, 0.89470226, 1. , -0.19237158,
1. , -0.991253 , -0.9991121 , 0.5637482 , -0.75780976,
-0.3904836 , -0.9881965 , -0.2912058 , 0.9998215 , 0.9869475 ,
-0.12784953, 0.81566185, 0.9787118 , -0.17835459, -0.7027824 ,
0.72269535, -0.18194303, 0.9968796 , 0.03490257, 0.7751488 ,
-1. , -0.7761089 , 0.85105944, 0.9968074 , -0.8156342 ,
0.5300792 , -1. , 0.99626255, -0.7515625 , -0.6672005 ,
0.9792111 , 0.8660997 , -0.69161206, 0.32184905, 0.9071073 ,
0.9999385 , -0.82744277, -0.99044186, -0.71309817, -0.5004305 ,
0.70707524, 0.89751345, -0.6819585 , -0.9999414 , -0.45255637,
-0.94375473, -0.91838425, 0.64272994, 0.9375524 , 0.6609169 ,
-0.88743365, -0.9534722 , -0.47888806, -1. , -0.5251781 ,
0.8274516 , 0.9326824 , 0.8961964 , 0.5295862 , 0.43714878,
-0.7488347 , -0.75295556, -0.5187054 , 0.75924635, -0.7862662 ,
0.99981725, -0.80290836, 0.97651815, 0.99763787, -0.29619345,
-0.1252967 , 0.33606276, -0.65137684, -0.9680231 , 0.77586985,
0.22347753, 0.27245504, -0.07826214, -0.8383849 , -0.85373163,
1. , -0.4563588 , -0.91339815, -0.9999861 , 0.66063935,
-0.985843 , -0.7818757 , -0.7000497 , -0.6840764 , 0.9995542 ,
0.60819125, 0.80064404, -0.9776968 , -0.90925264, -0.6644932 ,
-0.8771755 , 0.71411085, 0.8113569 , 0.9974196 , -0.75211936,
0.63400257, -0.8272833 , 0.99780786, 0.9965285 , 0.59551436,
-0.9876875 , -0.04439292, 0.9939223 , 0.9993717 , -0.9965501 ,
-0.9630328 , -0.9027949 , -0.48490363, -0.60193753, -0.6870232 ,
-0.95355797, -0.67561924, 0.9997761 , -0.85473967, 0.998495 ,
-0.95756954, 0.633171 , 0.4570475 , -0.5316367 , -0.9663824 ,
0.9567106 , -0.45497724, 0.12964879, 0.9964744 , -0.9711668 ,
0.69636106, -0.9178346 , 0.8313186 , 0.69686604, 0.8141587 ,
-0.33600506, 0.94798595, 0.8800869 , 0.15029034, -0.91185665,
0.6322724 , -0.9971475 , 0.71948224, 0.9695236 , 0.84242374,
0.99995124, 0.5982563 , -0.98341423, 0.61301434, 0.9997318 ,
-0.9981808 , -0.65651804, -0.8484874 , -0.9961815 , 0.9030814 ,
0.87141925, 0.8897381 , -0.92870414, 0.07134341, 0.8739935 ,
0.91630197, -0.9465984 , -0.59741104, -1. , 0.9989559 ,
0.99991184, 0.67439264, 0.92025673, -0.60730827, 0.8362061 ,
1. , -0.70801497, 0.9883806 , -0.9984141 , 0.9919259 ,
-0.998869 , 0.9976203 , 0.9888036 , 0.8556838 , -0.9722744 ,
-0.99810714, 0.8182833 , 0.98808485, 0.6643728 , 0.99212515,
-0.99988 , 0.26405996, 0.93139845, 0.99021816, 0.6846886 ,
0.9986462 , 0.92254627, -0.6406982 ], dtype=float32)),
('The acting was a bit lacking',
array([0.9921152 , 0.00788479], dtype=float32),
array([-0.00791603, -4.842819 ], dtype=float32),
'Negative',
array([ 0.67417824, 0.8235167 , 0.99999565, -0.8565971 , -0.99499583,
0.8219966 , -0.9185583 , -0.5234593 , 0.99962074, 0.99999714,
0.9507927 , -0.9996754 , 0.22211392, -0.99826247, 0.7562492 ,
0.93803996, 0.82738185, 0.4773049 , -0.73478544, 0.85207295,