Copyright 2019 The Google Research Authors.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
In [0]:
import tensorflow as tf
tf.enable_v2_behavior()
tf.compat.v1.enable_resource_variables()
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers as L
from tensorflow.keras import Model
import tensorflow_datasets as tfds
from tensorflow_probability.python.distributions import categorical
Class `ContextEmbedder` outputs a predicted representation for a target word by forming a linear combination of context word feature vectors. In particular, for a context $h=w_1, \dots, w_n$ with context embeddings $r_{w_i}$ and learned position-dependent weight vectors $c_i$, the predicted representation is:
$$ \hat{q}(h) = \sum_{i=1}^n c_i \odot r_{w_i} $$
In [0]:
class ContextEmbedder(Model):
  """Reweights and sums a set of vectors in R^d representing context word
  vectors.

  Each context word id is looked up in a word-embedding table, multiplied
  elementwise by a learned per-position weight vector, and the results are
  summed over the context positions.

  @param:
    `vocab_size` is how many words in the vocabulary
    `window_size` is how many words make up one window. Note that we assume
      the window is centered here, and the window includes the target word.
      Hence, a window size of 5 means [w_0, w_1, w_2, w_3, w_4] where w_2 is
      the target word.
    `embed_dim` is the size of the embeddings dimension
    `exclude_target` if True, the centre (target) column is removed from the
      window before the context words are embedded
    `embeddings_intializer` initializer shared by the position-weight table
      and the word-embedding table
  """

  def __init__(self, vocab_size, window_size, embed_dim, exclude_target=True,
               embeddings_intializer=keras.initializers.RandomUniform(
                   minval=-0.001, maxval=0.001, seed=None),
               ):
    super(ContextEmbedder, self).__init__()
    self._embed_dim = embed_dim
    self._vocab_size = vocab_size
    self._window_size = window_size
    self._exclude_target = exclude_target
    # Column of the target word inside a centered window.
    self._target_idx = self._window_size // 2
    # One weight vector per window position (the c_i in the formula).
    self._get_context_weights = L.Embedding(
        self._window_size, self._embed_dim,
        embeddings_initializer=embeddings_intializer)
    # One embedding vector per vocabulary entry (the r_w in the formula).
    self._get_context_embeddings = L.Embedding(
        self._vocab_size, self._embed_dim,
        embeddings_initializer=embeddings_intializer)
    print(f"INFO: exclude target is {self._exclude_target}")

  def _exclude_target_idx(self, lst):
    """Given a (? x `self._window_size`) tensor `lst`, returns the
    (? x `self._window_size` - 1) tensor with the target column removed.
    """
    # Concatenate the columns on either side of the target; `+` on tensor
    # slices would add them elementwise instead of joining them.
    return tf.concat(
        [lst[:, :self._target_idx], lst[:, self._target_idx + 1:]], axis=-1)

  def embed(self, x):
    """Given (? x `self._window_size`) sequence of integers `x`, returns a
    (? x n_context x `self._embed_dim`) tensor of embeddings, where
    n_context is `self._window_size` - 1 when the target word is excluded
    and `self._window_size` otherwise.
    """
    if self._exclude_target:
      x = self._exclude_target_idx(x)
    return self._get_context_embeddings(x)

  def _make_index_tensor(self, x):
    """Given a (? x n_context x `self._embed_dim`) tensor `x`, returns a
    (? x n_context) integer tensor whose every row is
    [0, 1, ..., n_context - 1].
    """
    # create set of index tensors to retrieve position weights
    window_len = x.shape[1]
    to_tile = tf.reshape(tf.range(window_len), shape=(1, -1))
    # Read the batch size dynamically so this also works when the static
    # batch dimension is unknown (e.g. while tracing a graph).
    batch_size = tf.shape(x)[0]
    return tf.tile(to_tile, multiples=tf.stack([batch_size, 1]))

  def _weight_and_sum(self, x):
    """Reweights each context embedding in `x` by its position weights and
    sums over the context axis, giving a (? x `self._embed_dim`) tensor.
    """
    context_positions = self._make_index_tensor(x)
    position_weights = self._get_context_weights(context_positions)
    # reweight embeddings vector dims according to `position_weights`, then
    # sum embedding vectors
    return tf.reduce_sum(tf.multiply(position_weights, x), axis=1)

  def call(self, h):
    """Given (? x `self._window_size`) sequence of integers `h`, returns
    (? x `self._embed_dim`) tensor of predicted embeddings
    """
    context_word_embeddings = self.embed(h)  # (? x n_context x embed_dim)
    return self._weight_and_sum(context_word_embeddings)
Class `LogBilinearScore` computes a score as the similarity between a predicted vector $\hat{q}$ (`qhat`) and the embedding $q$ (`q`). We include a term $b_w$ (`b`) to capture the context-independent frequency of a word $w$. Given some $h$ and $\hat{q}(h)$, calling an instance of `LogBilinearScore` returns
$$s_{\theta}(w,h) = \hat{q}(h)^\top q_w + b_w$$
In [0]:
class LogBilinearScore(Model):
  """Scores candidate words against predicted representations.

  For a word id `w` and a predicted vector `qhat`, the score is the dot
  product of `qhat` with the word's target embedding plus a per-word bias.

  @param:
    `vocab_size` is how many words in the vocabulary
    `embed_dim` is the size of the embeddings dimension
    `embeddings_intializer` initializer used for both the embedding table
      and the bias table
  """

  def __init__(self, vocab_size, embed_dim, embeddings_intializer='uniform'):
    super(LogBilinearScore, self).__init__()
    self.embed_dim = embed_dim
    self.vocab_size = vocab_size
    # Target-word embeddings q_w.
    self.q = L.Embedding(self.vocab_size, self.embed_dim,
                         embeddings_initializer=embeddings_intializer)
    # Per-word bias b_w, stored as a width-1 embedding.
    self.b = L.Embedding(self.vocab_size, 1,
                         embeddings_initializer=embeddings_intializer)

  def call(self, w, qhat):
    """ Returns a score tensor with the same shape as `w`.
    `w` is an integer tensor of word ids
    `qhat` is a (`w.shape` x `self.embed_dim`) tensor of predicted vectors
    """
    q = self.q(w)  # (`w.shape` x `self.embed_dim`)
    # Drop only the trailing width-1 bias axis. An unrestricted squeeze would
    # also collapse any other size-1 axis of `w` (e.g. a batch of one) and
    # make the addition below broadcast incorrectly.
    b = tf.squeeze(self.b(w), axis=-1)  # `w.shape`
    return tf.reduce_sum(tf.multiply(q, qhat), axis=-1, keepdims=False) + b
In [0]:
import random
class WindowRebatcher(object):
  """Take a batch of sequences, some padded with `pad_token`. Return a new batch
  of sequences `window_size` over the original sequences

  @param:
    `window_size` is the number of consecutive tokens in each window
    `pad_token` is the integer used to pad sequences; a sequence is treated
      as ending at the first occurrence of this value
    `max_windows` is the largest number of (randomly chosen) windows kept per
      sequence; `None` keeps every window
  """

  def __init__(self, window_size, pad_token=-1, max_windows=20):
    self._ws = window_size
    self._pt = pad_token
    self._max_windows = max_windows

  def rebatch(self, batch):
    """Given an iterable of 1-D integer tensors, returns a
    (total_windows x `window_size`) tensor of windows sampled from them.

    Sequences with fewer than `window_size` unpadded tokens contribute no
    windows. If no sequence contributes any, `tf.concat` receives an empty
    list.
    """
    rebatches = []
    for item in batch:
      # The unpadded part of the sequence ends at the first pad token.
      pad_positions = np.where(item.numpy() == self._pt)[0]
      end_idx = pad_positions[0] if len(pad_positions) > 0 else len(item)
      n_windows = end_idx - (self._ws - 1)
      if n_windows <= 0:
        continue  # sequence too short to hold a single window
      idxs = [list(range(i, i + self._ws)) for i in range(n_windows)]
      # Keep a random subset of the windows to bound the batch size.
      random.shuffle(idxs)
      idxs = tf.constant(idxs[:self._max_windows], dtype=tf.int64)
      # Gathering with a (k x window_size) index tensor yields the k windows.
      rebatches.append(tf.gather(item, idxs))
    return tf.concat(rebatches, axis=0)
Given a context $h$, an NPLM defines the distribution for the word to be predicted using the scoring function $s_{\theta}(w,h)$, which quantifies the compatibility between the context and the candidate target word. Here $\theta$ are the model parameters, which include the word embeddings. The scores are converted to probabilities by exponentiating and normalizing: $$P_{\theta}^h(w) = \frac{\exp (s_{\theta}(w,h))}{\sum_{w'}\exp(s_{\theta}(w',h))}$$ Unfortunately, both evaluating $P^h_{\theta}(w)$ and computing the corresponding likelihood gradient require normalizing over the entire vocabulary, which means that maximum likelihood training of such models takes time linear in the vocabulary size and is thus prohibitively expensive for all but the smallest vocabularies.
In [0]:
class NPLMEstimator(Model):
  """Neural probabilistic language model trained with a noise-contrastive
  style objective.

  A `ContextEmbedder` predicts a representation from each window's context
  words, a `LogBilinearScore` scores the true centre word and a set of
  sampled noise words against it, and the loss is the sum of the sigmoid
  cross-entropies for "true word is data" and "noise word is noise".

  @param:
    `vocab_size` is how many words in the vocabulary
    `embed_dim` is the size of the embeddings dimension
    `window_size` is the (odd) number of tokens per window
    `mc_noise_samples` is how many noise words are drawn per window
    `noise_pmf` is a length-`vocab_size` probability vector for the noise
      distribution; uniform when None
  """

  def __init__(self, vocab_size, embed_dim, window_size, mc_noise_samples, noise_pmf=None):
    super(NPLMEstimator, self).__init__()
    assert(window_size % 2 == 1), f"window_size must be odd but saw {window_size}!"
    self.mc_noise_samples = mc_noise_samples
    self.batcher = WindowRebatcher(window_size)
    self.context_embedder = ContextEmbedder(vocab_size, window_size, embed_dim)
    self.word_idx_to_score = LogBilinearScore(vocab_size, embed_dim)
    self.target_idx = window_size // 2
    if noise_pmf is None:
      noise_pmf = tf.ones(shape=(vocab_size,), dtype=tf.float64) / vocab_size
    # Cast (rather than re-wrap in tf.constant) so a pmf supplied in another
    # float dtype is converted instead of rejected.
    self.noise_dist = categorical.Categorical(
        logits=tf.math.log(tf.cast(noise_pmf, dtype=tf.float64)))
    self.logk = tf.math.log(tf.constant(mc_noise_samples, dtype=tf.float32))

  def sample_and_log_prob_noise(self, n):
    """Draws `self.mc_noise_samples` noise words for each of `n` rows.

    Returns a pair of (n x `self.mc_noise_samples`) tensors: the sampled
    word ids and their log-probabilities under the noise distribution.
    """
    samples = self.noise_dist.sample((n, self.mc_noise_samples))
    return samples, self.noise_dist.log_prob(samples)

  def del_score(self, score, log_p_noise, simple=True):
    """Returns `score` unchanged when `simple` is True; otherwise returns
    `score - (log_p_noise + log(mc_noise_samples))`.
    """
    if simple:
      return score
    correction = log_p_noise + tf.ones_like(log_p_noise) * tf.cast(self.logk, dtype=tf.float64)
    return score - tf.cast(correction, dtype=tf.float32)

  def call(self, rebatched):
    """Given a (W x window_size) integer tensor of windows, returns the
    scalar loss averaged over the W windows.
    """
    n_rows = rebatched.shape[0]
    targets = rebatched[:, self.target_idx]
    # --> embed the context words
    qhat = self.context_embedder(rebatched)  # (W, embed_dim)
    # --> compute scores for target word in each windowed batch
    scores = self.word_idx_to_score(targets, qhat)  # (W,)
    log_p_true_under_noise = self.noise_dist.log_prob(targets)
    # --> Monte Carlo estimate of loss under noise distribution
    noise_samples, log_p_noise = self.sample_and_log_prob_noise(n_rows)
    noise_words = tf.reshape(noise_samples, shape=(-1,))  # (W * n_noise_samples,)
    noise_log_probs = tf.reshape(log_p_noise, shape=(-1,))  # (W * n_noise_samples,)
    # Repeat every predicted vector once per noise sample so that row i of
    # `qhat` lines up with the i-th group of noise words.
    qhat_tiled = tf.reshape(
        tf.keras.backend.repeat(qhat, n=self.mc_noise_samples),
        shape=(-1, qhat.shape[-1]))  # (W * n_noise_samples, embed_dim)
    noise_scores_repeated = self.word_idx_to_score(noise_words, qhat_tiled)  # (W * n_noise_samples,)
    # del_score is called with its default `simple=True`, so the raw scores
    # are used as logits below.
    true_del_score = self.del_score(scores, log_p_true_under_noise)
    noise_del_score_repeated = self.del_score(noise_scores_repeated,
                                              noise_log_probs)
    expected_true = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(true_del_score), logits=true_del_score)
    # Average the noise term over the samples drawn for each window.
    expected_false = tf.reduce_mean(
        tf.reshape(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(noise_del_score_repeated),
                logits=noise_del_score_repeated),
            shape=(n_rows, -1)),
        axis=-1)
    return tf.reduce_mean(expected_true + expected_false)
In [0]:
# Load the train and test splits of the subword-encoded LM1B dataset, along
# with the dataset metadata.
(train_data, test_data), info = tfds.load(
    'lm1b/subwords8k',
    split=(tfds.Split.TRAIN, tfds.Split.TEST),
    with_info=True,
    as_supervised=True)

pad_int = -1  # integer flag that never appears in dataset
batch_size = 10  # number of sequences per padded batch
# Both elements of each example are padded along their single axis.
padded_shapes = ([None], (None,))
padding_values = (tf.constant(pad_int, dtype=tf.int64),
                  tf.constant(0, dtype=tf.int64))

# Shuffle each split with a 1024-element buffer, then pad within a batch.
train_batches, test_batches = (
    split_data.shuffle(1024).padded_batch(
        batch_size,
        padded_shapes,
        padding_values=padding_values)
    for split_data in (train_data, test_data))
In [0]:
# Hyperparameters for the model and for the (commented-out) training cell.
# Tokens per window; NPLMEstimator asserts that this value is odd.
window_size = 5 # @param
# Noise words sampled per window for the Monte Carlo noise term.
mc_noise_samples = 25 # @param
# Dimensionality of the word embeddings.
embed_dim = 2 # @param
# Vocabulary size as reported by the dataset's 'text' feature encoder.
vocab_size = info.features['text'].encoder.vocab_size
# Batches per epoch before the commented-out training loop breaks.
train_iter = 50000 # @param
# The commented-out training loop logs the loss every `log_interval` batches.
log_interval = 10000 # @param
# Number of passes of the commented-out training loop over the data.
epochs = 1
In [0]:
# !!! Only run this cell to retrain from scratch
# %%time
# save_interval = 5
# save_path = "./"
# @tf.function
# def update(rebatched):
# with tf.GradientTape() as tape:
# loss = nce_scorer(rebatched)
# gradients = tape.gradient(loss, nce_scorer.trainable_variables)
# optimizer.apply_gradients(zip(gradients, nce_scorer.trainable_variables))
# return loss
# losses = []
# for e in range(epochs):
# for idx, batch in enumerate(iter(train_batches)):
# seq, _ = batch # (1, ?)
# rebatched = nce_scorer.batcher.rebatch(seq) # (W, window_size)
# loss = update(rebatched)
# print(f"iter: {idx}")
# if idx % save_interval == 0:
# print("saving>...")
# nce_scorer.save_weights(save_path)
# # nce_scorer.save_weights
# losses.append(loss)
# # with tf.GradientTape() as tape:
# # loss = nce_scorer(seq)
# # gradients = tape.gradient(loss, nce_scorer.trainable_variables)
# # optimizer.apply_gradients(zip(gradients, nce_scorer.trainable_variables))
# if idx % log_interval == 0:
# losses.append(loss)
# print(f"Epoch: {e} | iter {idx} | Loss: {loss}")
# if (idx + 1) % train_iter == 0:
# break
# losses.append(loss)
# print(f"Epoch End: {e} Loss: {loss}")
# HERE: save the model output weights to a directory of your choosing
In [0]:
import os  # needed for os.path.join below

from sklearn.cross_decomposition import CCA
from numpy import linalg as LA

model_ids = [0, 2, 3, 4]
# Model checkpoints have been provided with this script
# Choose the appropriate root directory for the data, e.g., Google Drive
USER_ROOT_DIR = "./"
# For every model id there are two independently trained replicas.
model_paths = [
    [os.path.join(USER_ROOT_DIR,
                  "2_0.003000_%d/model_checkpoints/%s" % (x, rep))
     for rep in ("rep1", "rep2")]
    for x in model_ids]
print(model_paths)

# NOTE(review): `noise_pmf` is not defined in any cell visible here; it must
# be set before this cell runs (passing None makes NPLMEstimator fall back to
# a uniform noise distribution) -- confirm which was intended.
estimators = [
    [NPLMEstimator(vocab_size, embed_dim, window_size, mc_noise_samples, noise_pmf=noise_pmf),
     NPLMEstimator(vocab_size, embed_dim, window_size, mc_noise_samples, noise_pmf=noise_pmf)]
    for _ in range(len(model_paths))]

# Restore the two replicas of each model from their checkpoints.
for idx, est in enumerate(estimators):
  for replica, checkpoint in zip(est, model_paths[idx]):
    replica.load_weights(checkpoint)
In [0]:
# Tableau Color Blind 10, given as 8-bit RGB triples.
tableau20blind = [(0, 107, 164), (255, 128, 14), (171, 171, 171), (89, 89, 89),
                  (95, 158, 209), (200, 82, 0), (137, 137, 137), (163, 200, 236),
                  (255, 188, 121), (207, 207, 207)]
# Rescale every channel from [0, 255] to [0, 1].
tableau20blind = [
    (red / 255., green / 255., blue / 255.)
    for red, green, blue in tableau20blind]
In [0]:
import jax.numpy as jnp
import jax
try:
  # Newer JAX releases ship the optimizers under `jax.example_libraries`.
  from jax.example_libraries import optimizers as _jax_optimizers
except ImportError:
  # Older JAX releases only provide them under `jax.experimental`.
  import jax.experimental.optimizers as _jax_optimizers


# learn the best possible linear transformation by regressing
def align(act1, act2, num_steps=1000, step_size=0.01, mass=0.9, log_every=50):
  """Fits a matrix `w` so that `act1 @ w` approximates `act2`.

  The mean squared error between `act2` and `act1 @ w` is minimised with
  momentum gradient descent starting from `w = 0`.

  Args:
    act1: (n, d1) array of inputs to the linear map.
    act2: (n, d2) array of regression targets.
    num_steps: number of optimiser updates.
    step_size: optimiser step size.
    mass: optimiser momentum coefficient.
    log_every: print the loss every this many steps; a falsy value disables
      printing.

  Returns:
    A pair `(w, b)`: `w` is the fitted (d1, d2) matrix. `b` is a (d2,) bias
    that is carried through the optimiser but not used in the prediction,
    so it stays at zero.
  """
  def align_loss(p):
    w, _ = p  # the bias term is intentionally left out of the prediction
    pred = jnp.matmul(act1, w)
    return jnp.mean(jnp.square(act2 - pred))

  # Size the parameters from the inputs instead of assuming 2-D embeddings.
  w = jnp.zeros((act1.shape[-1], act2.shape[-1]))
  b = jnp.zeros((act2.shape[-1],))
  init_fn, update_fn, get_p = _jax_optimizers.momentum(step_size, mass)
  opt_state = init_fn((w, b))
  v_grad_fn = jax.jit(jax.value_and_grad(align_loss))

  for i in range(num_steps):
    v, grad = v_grad_fn(get_p(opt_state))
    if log_every and i % log_every == 0:
      print(v)
    opt_state = update_fn(i, grad, opt_state)

  w, b = get_p(opt_state)
  return w, b
In [0]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import collections as mc
from matplotlib import cm
sns.set_style("whitegrid")
print(f"est: {len(estimators)}")
# A 2 x 3 grid of panels: the plotting loop fills one row per estimator pair.
fig, ax = plt.subplots(2, 3, figsize=(12, 7))
X1 = X2 = None

# Set data offsets for a nice visualization if text is desired
offset = 70  # picked to avoid subwords that aren't full words
ss = 90  # picked to make plot easier to read. see also `offset` comment

# NOTE(review): `tt` and `N` are not defined in the cells visible here --
# presumably a subword text encoder and the number of word ids; confirm.
texts = [tt.ints2str([word_id]) for word_id in range(N)][offset:ss]
idxs = np.asarray(list(range(N)))

# set colours: one discrete colour per plotted word, from three colour maps
plasma1, plasma2, plasma3 = (
    cm.get_cmap(cmap_name, len(texts)).colors
    for cmap_name in ('tab20c', 'tab20b', 'tab20'))

# Number of words to annotate with text in each panel (0 disables labels).
num_to_text = 0


def lines_fun(V):
  """Returns one segment from the origin to each row of the (n x 2) array V."""
  return [[(0, 0), (V[row, 0], V[row, 1])] for row in range(len(V))]
def set_lims(t_ax):
  """Fixes the view of `t_ax` to x in [-5, 5] and y in [-3, 3]."""
  x_bounds, y_bounds = [-5, 5], [-3, 3]
  t_ax.set_xlim(x_bounds)
  t_ax.set_ylim(y_bounds)
def draw_arrow(vec, num, color="lightgrey", alpha=1.0):
  """Draws the rows of `vec` as arrows from the origin in panel `num` of the
  current row (`idx`) of the global axes grid `ax`.

  `color` is accepted but not forwarded; the styling keywords that used it
  (scale, zorder, angles='xy', color, linestyle) are currently disabled.
  """
  x_components = vec[:, 0]
  y_components = vec[:, 1]
  ax[idx][num].quiver(0, 0, x_components, y_components, alpha=alpha)
def draw_arrow_drop_shadow(vec, num):
  """Draws the rows of `vec` as arrows from the origin in panel `num` of the
  current row (`idx`) of the global axes grid `ax`.

  The keywords that made this a white "shadow" (width, scale, zorder,
  angles, color) are currently disabled, so default arrow styling is used.
  """
  x_components = vec[:, 0]
  y_components = vec[:, 1]
  ax[idx][num].quiver(0, 0, x_components, y_components)
def number_text(curr_idx, num, A, idxs_to_label, zorder=100):
  """Writes the label `texts[i]` at point `A[i]` in panel
  `ax[curr_idx][num]` for every `i` in `idxs_to_label`.

  The global `texts` list is modified in place: a label "han" is replaced
  by "hand".
  """
  jitter_scale = 0  # vertical jitter is disabled (scale of zero)
  for label_idx in idxs_to_label:
    if texts[label_idx] == "han":
      texts[label_idx] = "hand"  # labels are subwords and arbitrary
    # A normal draw is still taken per label even though it is scaled to 0.
    # (Hand-tuned jitter overrides for individual words were removed here.)
    y_jitter = jitter_scale * np.random.normal()
    x_pos = A[label_idx, 0]
    y_pos = A[label_idx, 1] + y_jitter
    ax[curr_idx][num].text(x_pos, y_pos, texts[label_idx], zorder=zorder)
def stdize(vec):  # seems to have no effect
  """Prints the shape of `vec`, then returns it standardized column-wise:
  each column has its mean subtracted and is divided by its standard
  deviation (both taken over axis 0).
  """
  print(vec.shape)
  column_mean = np.mean(vec, axis=0, keepdims=True)
  column_std = np.std(vec, axis=0, keepdims=True)
  centered = vec - column_mean
  return centered / column_std
def draw_lines(lines, this_ax, c, cmap=None, alpha=1.0, zorder=10, autoscale=False):
  """Adds `lines` to `this_ax` as a LineCollection of width 4.

  The segment colours come from `cmap`; `c` is accepted but not used.
  When `autoscale` is True the axes are autoscaled afterwards.
  """
  collection_kwargs = dict(colors=cmap, linewidths=4, alpha=alpha, zorder=zorder)
  this_ax.add_collection(mc.LineCollection(lines, **collection_kwargs))
  if not autoscale:
    return
  this_ax.autoscale()
def dist(p1, p2):
  """Returns the Euclidean distance between the 2-D points `p1` and `p2`."""
  dx = p1[0] - p2[0]
  dy = p1[1] - p2[1]
  return np.sqrt((dx ** 2) + (dy ** 2))
for idx, est in enumerate([estimators[i] for i in [1, 2]]):
#for idx, est in enumerate(estimators):
# use different colours
if idx == 0:
plasma = plasma1
elif idx == 1:
plasma = plasma2
else:
plasma = plasma3
print(f"INDEX: {idx}")
# get embeddings
embeddings1 = est[0].word_idx_to_score.q(idxs.reshape([-1, 1]))
embeddings2 = est[1].word_idx_to_score.q(idxs.reshape([-1, 1]))
scale = 1 # don't change the scale
X1_full = stdize(scale*embeddings1.numpy().squeeze())
X2_full = stdize(scale*embeddings2.numpy().squeeze())
X1 = X1_full[offset:ss]
X2 = X2_full[offset:ss]
# print(f"len texts: {len(texts)}")
c2 = "red" # tableau20blind[8]
c1 = "blue" # tableau20blind[4]
alph = 1.0
lt=" "
# compute lengths of all the arrows
dists_X1 = [dist(x[0],x[1]) for x in lines_fun(X1)]
dists_X2 = [dist(x[0],x[1]) for x in lines_fun(X2)]
# find the longest arrows on X1 and X2 (not necessarily the same)
idxs_to_label_X1 = np.argsort(dists_X1)[-num_to_text:]
idxs_to_label_X2 = np.argsort(dists_X2)[-num_to_text:]
# place vector lines on plot
ax[idx][0].plot(X2[:, 0], X2[:, 1], lt, alpha = alph, c=c1)
draw_lines(lines_fun(X2), ax[idx][0], c1, cmap=plasma, alpha=1.0)
ax[idx][1].plot(X1[:, 0], X1[:, 1], lt, alpha = alph, c=c2)
draw_lines(lines_fun(X1), ax[idx][1], c2, cmap=plasma, alpha=1.0)
w_trans, b_trans = align(np.copy(X1_full), np.copy(X2_full))
X1_c = np.copy(X1_full)
X2_c = np.copy(np.matmul(X2_full, w_trans))# + b_trans)
#X2_c = X2_full
X1_c = X1_c[offset:ss]
X2_c = X2_c[offset:ss]
draw_lines(lines_fun(X1_c), ax[idx][2], c2, cmap=plasma, alpha=1.0)
draw_lines(lines_fun(X2_c), ax[idx][2], c1, cmap=plasma, alpha=1.0)
# place text on plot
if num_to_text > 0:
number_text(idx, 0, X2, idxs_to_label_X2)
number_text(idx, 1, X1, idxs_to_label_X2)
number_text(idx, 2, X2_c, idxs_to_label_X2)
sns.despine(left=True, right=True, bottom=True, top=True)
for t_ax in ax[idx]:
t_ax.set_xlim(-1, 1)
t_ax.set_ylim(-1, 1)
for t_ax in ax[idx]:
t_ax.set_xticks([])
t_ax.set_yticks([])
fig.tight_layout()
fig.savefig("fig1.pdf", bbox_inches="tight")
%download_file test.pdf