Understanding Embeddings for Text


In [0]:
# Based on
# https://github.com/fchollet/deep-learning-with-python-notebooks/blob/master/6.2-understanding-recurrent-neural-networks.ipynb

In [0]:
import warnings
warnings.filterwarnings('ignore')

In [32]:
%matplotlib inline
%pylab inline


Populating the interactive namespace from numpy and matplotlib

In [33]:
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)
print(tf.__version__)


1.12.0

In [0]:
from tensorflow import keras

# https://keras.io/datasets/#imdb-movie-reviews-sentiment-classification
max_features = 10000  # number of words to consider as features
maxlen = 50  # cut each text off after this many words

# each review is encoded as a sequence of word indexes
# indexed by overall frequency in the dataset
# output is 0 (negative) or 1 (positive) 
imdb = keras.datasets.imdb.load_data(num_words=max_features)
(raw_input_train, y_train), (raw_input_test, y_test) = imdb

In [0]:
# tf.keras.datasets.imdb.load_data?

In [36]:
y_train.min()


Out[36]:
0

In [37]:
y_train.max()


Out[37]:
1

In [38]:
# 25000 texts
len(raw_input_train)


Out[38]:
25000

In [39]:
# first text has 218 words
len(raw_input_train[0])


Out[39]:
218

In [40]:
raw_input_train[0]


Out[40]:
[1,
 14,
 22,
 16,
 43,
 530,
 973,
 1622,
 1385,
 65,
 458,
 4468,
 66,
 3941,
 4,
 173,
 36,
 256,
 5,
 25,
 100,
 43,
 838,
 112,
 50,
 670,
 2,
 9,
 35,
 480,
 284,
 5,
 150,
 4,
 172,
 112,
 167,
 2,
 336,
 385,
 39,
 4,
 172,
 4536,
 1111,
 17,
 546,
 38,
 13,
 447,
 4,
 192,
 50,
 16,
 6,
 147,
 2025,
 19,
 14,
 22,
 4,
 1920,
 4613,
 469,
 4,
 22,
 71,
 87,
 12,
 16,
 43,
 530,
 38,
 76,
 15,
 13,
 1247,
 4,
 22,
 17,
 515,
 17,
 12,
 16,
 626,
 18,
 2,
 5,
 62,
 386,
 12,
 8,
 316,
 8,
 106,
 5,
 4,
 2223,
 5244,
 16,
 480,
 66,
 3785,
 33,
 4,
 130,
 12,
 16,
 38,
 619,
 5,
 25,
 124,
 51,
 36,
 135,
 48,
 25,
 1415,
 33,
 6,
 22,
 12,
 215,
 28,
 77,
 52,
 5,
 14,
 407,
 16,
 82,
 2,
 8,
 4,
 107,
 117,
 5952,
 15,
 256,
 4,
 2,
 7,
 3766,
 5,
 723,
 36,
 71,
 43,
 530,
 476,
 26,
 400,
 317,
 46,
 7,
 4,
 2,
 1029,
 13,
 104,
 88,
 4,
 381,
 15,
 297,
 98,
 32,
 2071,
 56,
 26,
 141,
 6,
 194,
 7486,
 18,
 4,
 226,
 22,
 21,
 134,
 476,
 26,
 480,
 5,
 144,
 30,
 5535,
 18,
 51,
 36,
 28,
 224,
 92,
 25,
 104,
 4,
 226,
 65,
 16,
 38,
 1334,
 88,
 12,
 16,
 283,
 5,
 16,
 4472,
 113,
 103,
 32,
 15,
 16,
 5345,
 19,
 178,
 32]
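
The review above is just a list of word indexes. A minimal sketch (not part of the original notebook) of how to map it back to words, assuming load_data's default offsets (index_from=3, start_char=1, oov_char=2):

word_index = keras.datasets.imdb.get_word_index()
# shift by 3 because load_data reserves 0 (padding), 1 (start of sequence) and 2 (out-of-vocabulary)
id_to_word = {idx + 3: word for word, idx in word_index.items()}
id_to_word.update({0: '<pad>', 1: '<start>', 2: '<unk>'})

print(' '.join(id_to_word.get(idx, '<unk>') for idx in raw_input_train[0][:20]))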

In [0]:
# tf.keras.preprocessing.sequence.pad_sequences?

In [0]:
# https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/sequence/pad_sequences

input_train = keras.preprocessing.sequence.pad_sequences(raw_input_train, maxlen=maxlen)
input_test = keras.preprocessing.sequence.pad_sequences(raw_input_test, maxlen=maxlen)
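
A quick illustration (a sketch, relying on pad_sequences' defaults padding='pre', truncating='pre', value=0):

keras.preprocessing.sequence.pad_sequences([[1, 2, 3]], maxlen=5)
# -> array([[0, 0, 1, 2, 3]], dtype=int32): shorter sequences are left padded with zeros
keras.preprocessing.sequence.pad_sequences([[7, 8, 9, 10, 11, 12]], maxlen=5)
# -> array([[ 8,  9, 10, 11, 12]], dtype=int32): longer sequences are cut at the front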

In [43]:
input_train.shape, input_test.shape, y_train.shape, y_test.shape


Out[43]:
((25000, 50), (25000, 50), (25000,), (25000,))

In [44]:
# left padded with zeros
# by convention, 0 is reserved for padding and does not stand for a specific word
# (out-of-vocabulary words are encoded as 2 by load_data's defaults)
input_train[0]


Out[44]:
array([2071,   56,   26,  141,    6,  194, 7486,   18,    4,  226,   22,
         21,  134,  476,   26,  480,    5,  144,   30, 5535,   18,   51,
         36,   28,  224,   92,   25,  104,    4,  226,   65,   16,   38,
       1334,   88,   12,   16,  283,    5,   16, 4472,  113,  103,   32,
         15,   16, 5345,   19,  178,   32], dtype=int32)

Training the embedding jointly with the rest of the model is usually the most practical approach here.

Alternative: initialize the embedding layer with pre-trained vectors, e.g. from a word2vec skip-gram model.
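
A minimal sketch of that alternative, assuming an embedding_matrix of shape (max_features, embedding_dim) has already been filled with pre-trained vectors aligned to the same word indexes (embedding_matrix is hypothetical, not built in this notebook):

pretrained_embedding = keras.layers.Embedding(
    input_dim=max_features, output_dim=embedding_dim, input_length=maxlen,
    weights=[embedding_matrix],  # hypothetical matrix of pre-trained vectors
    trainable=False)             # freeze the vectors instead of learning them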


In [0]:
# tf.keras.layers.Embedding?

In [46]:
from tensorflow.keras.layers import Embedding, Flatten, GlobalAveragePooling1D, Dense, Dropout

embedding_dim = 2

model = keras.Sequential()
# Parameters: max_features * embedding_dim 
model.add(Embedding(name='embedding', input_dim=max_features, output_dim=embedding_dim, input_length=maxlen))

# Output: maxlen * embedding_dim (50 * 2 = 100)
model.add(Flatten(name='flatten'))

# ALTERNATIVE
# average of all embeddings (does not preserve sequence)
# model.add(GlobalAveragePooling1D(name='average_pooling'))

# binary classifier
# model.add(Dense(name='fc', units=32, activation='relu'))
# model.add(Dropout(0.4))
model.add(Dense(name='classifier', units=1, activation='sigmoid'))

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

model.summary()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, 50, 2)             20000     
_________________________________________________________________
flatten (Flatten)            (None, 100)               0         
_________________________________________________________________
classifier (Dense)           (None, 1)                 101       
=================================================================
Total params: 20,101
Trainable params: 20,101
Non-trainable params: 0
_________________________________________________________________
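
The parameter counts follow directly from the layer sizes:

# embedding:  max_features * embedding_dim = 10000 * 2 = 20,000 weights
# flatten:    reshapes (50, 2) into a vector of 50 * 2 = 100 values, no parameters
# classifier: 100 weights + 1 bias = 101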

In [47]:
batch_size = 96

%time history = model.fit(input_train, y_train, epochs=40, batch_size=batch_size, validation_data=(input_test, y_test))


Train on 25000 samples, validate on 25000 samples
Epoch 1/40
25000/25000 [==============================] - 1s 36us/step - loss: 0.6879 - acc: 0.5650 - val_loss: 0.6706 - val_acc: 0.6829
Epoch 2/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.5941 - acc: 0.7602 - val_loss: 0.5229 - val_acc: 0.7744
Epoch 3/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.4554 - acc: 0.8132 - val_loss: 0.4410 - val_acc: 0.8046
Epoch 4/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.3862 - acc: 0.8384 - val_loss: 0.4101 - val_acc: 0.8161
Epoch 5/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.3471 - acc: 0.8569 - val_loss: 0.3966 - val_acc: 0.8208
Epoch 6/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.3198 - acc: 0.8701 - val_loss: 0.3909 - val_acc: 0.8218
Epoch 7/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.2987 - acc: 0.8788 - val_loss: 0.3906 - val_acc: 0.8223
Epoch 8/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.2815 - acc: 0.8875 - val_loss: 0.3911 - val_acc: 0.8226
Epoch 9/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.2664 - acc: 0.8957 - val_loss: 0.3949 - val_acc: 0.8216
Epoch 10/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.2531 - acc: 0.9026 - val_loss: 0.3993 - val_acc: 0.8209
Epoch 11/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.2411 - acc: 0.9096 - val_loss: 0.4053 - val_acc: 0.8194
Epoch 12/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.2298 - acc: 0.9141 - val_loss: 0.4116 - val_acc: 0.8167
Epoch 13/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.2192 - acc: 0.9201 - val_loss: 0.4189 - val_acc: 0.8150
Epoch 14/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.2092 - acc: 0.9254 - val_loss: 0.4276 - val_acc: 0.8131
Epoch 15/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.1997 - acc: 0.9297 - val_loss: 0.4360 - val_acc: 0.8107
Epoch 16/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.1905 - acc: 0.9338 - val_loss: 0.4448 - val_acc: 0.8090
Epoch 17/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.1820 - acc: 0.9379 - val_loss: 0.4541 - val_acc: 0.8084
Epoch 18/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.1736 - acc: 0.9418 - val_loss: 0.4653 - val_acc: 0.8060
Epoch 19/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.1653 - acc: 0.9458 - val_loss: 0.4751 - val_acc: 0.8043
Epoch 20/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.1576 - acc: 0.9504 - val_loss: 0.4867 - val_acc: 0.8028
Epoch 21/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.1502 - acc: 0.9527 - val_loss: 0.4975 - val_acc: 0.8011
Epoch 22/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.1429 - acc: 0.9558 - val_loss: 0.5102 - val_acc: 0.7982
Epoch 23/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.1360 - acc: 0.9591 - val_loss: 0.5232 - val_acc: 0.7970
Epoch 24/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.1295 - acc: 0.9630 - val_loss: 0.5353 - val_acc: 0.7949
Epoch 25/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.1233 - acc: 0.9650 - val_loss: 0.5488 - val_acc: 0.7941
Epoch 26/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.1171 - acc: 0.9689 - val_loss: 0.5629 - val_acc: 0.7919
Epoch 27/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.1114 - acc: 0.9707 - val_loss: 0.5773 - val_acc: 0.7910
Epoch 28/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.1059 - acc: 0.9729 - val_loss: 0.5918 - val_acc: 0.7894
Epoch 29/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.1003 - acc: 0.9757 - val_loss: 0.6078 - val_acc: 0.7866
Epoch 30/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.0953 - acc: 0.9781 - val_loss: 0.6230 - val_acc: 0.7856
Epoch 31/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.0903 - acc: 0.9799 - val_loss: 0.6382 - val_acc: 0.7847
Epoch 32/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.0855 - acc: 0.9815 - val_loss: 0.6559 - val_acc: 0.7828
Epoch 33/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.0811 - acc: 0.9836 - val_loss: 0.6714 - val_acc: 0.7816
Epoch 34/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.0766 - acc: 0.9851 - val_loss: 0.6886 - val_acc: 0.7806
Epoch 35/40
25000/25000 [==============================] - 0s 20us/step - loss: 0.0724 - acc: 0.9869 - val_loss: 0.7069 - val_acc: 0.7789
Epoch 36/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.0685 - acc: 0.9878 - val_loss: 0.7243 - val_acc: 0.7773
Epoch 37/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.0646 - acc: 0.9888 - val_loss: 0.7435 - val_acc: 0.7753
Epoch 38/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.0609 - acc: 0.9902 - val_loss: 0.7614 - val_acc: 0.7752
Epoch 39/40
25000/25000 [==============================] - 1s 20us/step - loss: 0.0575 - acc: 0.9911 - val_loss: 0.7800 - val_acc: 0.7743
Epoch 40/40
25000/25000 [==============================] - 1s 21us/step - loss: 0.0542 - acc: 0.9923 - val_loss: 0.7999 - val_acc: 0.7733
CPU times: user 27.6 s, sys: 1.28 s, total: 28.9 s
Wall time: 21.2 s

In [48]:
import pandas as pd

def plot_history(history, samples=10):
    epochs = history.params['epochs']
    
    acc = history.history['acc']
    val_acc = history.history['val_acc']

    every_sample = int(epochs / samples)
    acc = pd.DataFrame(acc).iloc[::every_sample, :]
    val_acc = pd.DataFrame(val_acc).iloc[::every_sample, :]

    fig, ax = plt.subplots(figsize=(20,5))

    ax.plot(acc, 'bo', label='Training acc')
    ax.plot(val_acc, 'b', label='Validation acc')
    ax.set_title('Training and validation accuracy')
    ax.legend()

plot_history(history)



In [49]:
train_loss, train_accuracy = model.evaluate(input_train, y_train, batch_size=batch_size)
train_accuracy


25000/25000 [==============================] - 0s 8us/step
Out[49]:
0.9943199937438965

In [50]:
test_loss, test_accuracy = model.evaluate(input_test, y_test, batch_size=batch_size)
test_accuracy


25000/25000 [==============================] - 0s 8us/step
Out[50]:
0.7732800003242493

In [51]:
# prediction
model.predict(input_test[0:5])


Out[51]:
array([[0.80683035],
       [0.9945182 ],
       [0.4653315 ],
       [0.68047816],
       [0.99997866]], dtype=float32)

In [52]:
# ground truth
y_test[0:5]


Out[52]:
array([0, 1, 1, 0, 1])
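
To turn the sigmoid outputs above into hard class labels, one would typically threshold at 0.5 (a small sketch):

predicted_classes = (model.predict(input_test[0:5]) > 0.5).astype(int).reshape(-1)
predicted_classes == y_test[0:5]  # element-wise comparison with the ground truth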

What does the output of the trained embedding look like?


In [0]:
embedding_layer = model.get_layer('embedding')

In [0]:
model_stub = keras.Model(inputs=model.input, outputs=embedding_layer.output)
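
As a quick sanity check (a sketch), the stub maps a batch of token-id sequences of shape (batch, maxlen) to embedding vectors of shape (batch, maxlen, embedding_dim):

model_stub.predict(input_train[0:1]).shape  # expected: (1, 50, 2)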

In [0]:
word_to_id = keras.datasets.imdb.get_word_index()

def encode_text(text):
    input_words = text.split()
    # load_data shifts all word indexes by 3 (0=padding, 1=start, 2=out-of-vocabulary),
    # so shift the raw indexes from get_word_index() accordingly and map rare words to 2
    input_tokens = np.array([word_to_id[word] + 3 if word_to_id[word] + 3 < max_features else 2
                             for word in input_words])
    padded_input_tokens = keras.preprocessing.sequence.pad_sequences([input_tokens], maxlen=maxlen)
    return padded_input_tokens

def plot_text_embedding(model, text):
    input_words = text.split()
    input_sequence = encode_text(text)
    
    embeddings = model.predict(input_sequence)[0][-len(input_words):, :]
    x_coords = embeddings[:, 0] # First latent dim
    y_coords = embeddings[:, 1] # Second latent dim
    plt.figure(figsize=(20, 20))
    plt.scatter(x_coords, y_coords)
    for i, txt in enumerate(input_words):
        plt.annotate(txt, (x_coords[i], y_coords[i]))
    plt.show()

In [56]:
text = """good best brilliant amazing great lovely awesome 
          bad worst awful 
          art
          garbage gross horrible
          sad funny 
          beautiful ugly
          movie actor male female love"""
plot_text_embedding(model_stub, text)



In [57]:
from tensorflow.keras.layers import Embedding, Flatten, GlobalAveragePooling1D, Dense, Dropout

embedding_dim = 1

model = keras.Sequential()
# Parameters: max_features * embedding_dim 
model.add(Embedding(name='embedding', input_dim=max_features, output_dim=embedding_dim, input_length=maxlen))

# Output: maxlen * embedding_dim (50 * 1 = 50)
model.add(Flatten(name='flatten'))

# ALTERNATIVE
# average of all embeddings (does not preserve sequence)
# model.add(GlobalAveragePooling1D(name='average_pooling'))

# binary classifier
model.add(Dense(name='fc', units=32, activation='relu'))
# model.add(Dropout(0.4))
model.add(Dense(name='classifier', units=1, activation='sigmoid'))

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# model.summary()
batch_size = 96

%time history = model.fit(input_train, y_train, epochs=40, batch_size=batch_size, validation_data=(input_test, y_test))


Train on 25000 samples, validate on 25000 samples
Epoch 1/40
25000/25000 [==============================] - 1s 40us/step - loss: 0.6897 - acc: 0.5386 - val_loss: 0.6646 - val_acc: 0.6654
Epoch 2/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.4884 - acc: 0.7833 - val_loss: 0.4087 - val_acc: 0.8124
Epoch 3/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.3422 - acc: 0.8529 - val_loss: 0.3987 - val_acc: 0.8188
Epoch 4/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.2953 - acc: 0.8766 - val_loss: 0.4096 - val_acc: 0.8180
Epoch 5/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.2651 - acc: 0.8928 - val_loss: 0.4291 - val_acc: 0.8135
Epoch 6/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.2430 - acc: 0.9044 - val_loss: 0.4544 - val_acc: 0.8089
Epoch 7/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.2248 - acc: 0.9122 - val_loss: 0.4799 - val_acc: 0.8054
Epoch 8/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.2090 - acc: 0.9212 - val_loss: 0.5114 - val_acc: 0.7996
Epoch 9/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1966 - acc: 0.9263 - val_loss: 0.5413 - val_acc: 0.7978
Epoch 10/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1855 - acc: 0.9321 - val_loss: 0.5776 - val_acc: 0.7950
Epoch 11/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1747 - acc: 0.9383 - val_loss: 0.6090 - val_acc: 0.7905
Epoch 12/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1664 - acc: 0.9408 - val_loss: 0.6452 - val_acc: 0.7862
Epoch 13/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1565 - acc: 0.9452 - val_loss: 0.6852 - val_acc: 0.7846
Epoch 14/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1498 - acc: 0.9488 - val_loss: 0.7192 - val_acc: 0.7802
Epoch 15/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1422 - acc: 0.9526 - val_loss: 0.7573 - val_acc: 0.7787
Epoch 16/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1356 - acc: 0.9553 - val_loss: 0.7975 - val_acc: 0.7774
Epoch 17/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1297 - acc: 0.9579 - val_loss: 0.8419 - val_acc: 0.7740
Epoch 18/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1227 - acc: 0.9614 - val_loss: 0.8872 - val_acc: 0.7739
Epoch 19/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1183 - acc: 0.9624 - val_loss: 0.9255 - val_acc: 0.7701
Epoch 20/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1126 - acc: 0.9658 - val_loss: 0.9649 - val_acc: 0.7665
Epoch 21/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1069 - acc: 0.9689 - val_loss: 1.0222 - val_acc: 0.7674
Epoch 22/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.1026 - acc: 0.9703 - val_loss: 1.0784 - val_acc: 0.7676
Epoch 23/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0982 - acc: 0.9706 - val_loss: 1.1093 - val_acc: 0.7646
Epoch 24/40
25000/25000 [==============================] - 1s 23us/step - loss: 0.0926 - acc: 0.9741 - val_loss: 1.1541 - val_acc: 0.7634
Epoch 25/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0890 - acc: 0.9759 - val_loss: 1.2094 - val_acc: 0.7613
Epoch 26/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0852 - acc: 0.9754 - val_loss: 1.2589 - val_acc: 0.7612
Epoch 27/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0810 - acc: 0.9782 - val_loss: 1.2962 - val_acc: 0.7583
Epoch 28/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0773 - acc: 0.9788 - val_loss: 1.3492 - val_acc: 0.7584
Epoch 29/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0724 - acc: 0.9824 - val_loss: 1.4034 - val_acc: 0.7568
Epoch 30/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0706 - acc: 0.9814 - val_loss: 1.4516 - val_acc: 0.7562
Epoch 31/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0678 - acc: 0.9819 - val_loss: 1.4972 - val_acc: 0.7563
Epoch 32/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0639 - acc: 0.9838 - val_loss: 1.5725 - val_acc: 0.7552
Epoch 33/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0616 - acc: 0.9834 - val_loss: 1.6270 - val_acc: 0.7552
Epoch 34/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0581 - acc: 0.9850 - val_loss: 1.6406 - val_acc: 0.7522
Epoch 35/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0559 - acc: 0.9863 - val_loss: 1.6812 - val_acc: 0.7514
Epoch 36/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0526 - acc: 0.9876 - val_loss: 1.7367 - val_acc: 0.7511
Epoch 37/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0487 - acc: 0.9889 - val_loss: 1.7682 - val_acc: 0.7480
Epoch 38/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0472 - acc: 0.9891 - val_loss: 1.8227 - val_acc: 0.7497
Epoch 39/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0474 - acc: 0.9883 - val_loss: 1.8272 - val_acc: 0.7467
Epoch 40/40
25000/25000 [==============================] - 1s 22us/step - loss: 0.0444 - acc: 0.9891 - val_loss: 1.8871 - val_acc: 0.7501
CPU times: user 29.8 s, sys: 1.46 s, total: 31.3 s
Wall time: 22.9 s

What about 1d?


In [58]:
embedding_layer = model.get_layer('embedding')
model_stub = keras.Model(inputs=model.input, outputs=embedding_layer.output)

def plot_1d_text_embedding(model, text):
    input_words = text.split()
    input_sequence = encode_text(text)
    
    embeddings = model.predict(input_sequence)[0][-len(input_words):, :]
    plt.figure(figsize=(20, 5))
    plt.scatter(embeddings, np.zeros(len(embeddings)))
    for i, txt in enumerate(input_words):
        plt.annotate(txt, (embeddings[i], 0.004), rotation=80)
    plt.show()

text = """good best brilliant amazing great lovely awesome 
          bad worst awful 
          garbage gross horrible
          sad funny 
          beautiful ugly"""
plot_1d_text_embedding(model_stub, text)



In [0]: