Copyright 2019 The Google Research Authors.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


In [0]:
import tensorflow as tf
tf.enable_v2_behavior()
tf.compat.v1.enable_resource_variables()
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers as L
from tensorflow.keras import Model
import tensorflow_datasets as tfds
from tensorflow_probability.python.distributions import categorical

The `ContextEmbedder` class outputs a predicted representation for a target word by forming a position-weighted combination of the context words' feature vectors. In particular, for a context $h=w_1, \dots, w_n$ with context embeddings $r_{w_i}$ and per-position weight vectors $c_i$, the predicted representation is: $$ \hat{q}(h) = \sum_{i=1}^n c_i \odot r_{w_i} $$ where $\odot$ denotes elementwise multiplication.
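
As a quick illustration of this formula (my addition, with made-up numbers), the cell below computes $\hat{q}(h)$ for a toy two-word context in plain NumPy; the `ContextEmbedder` class that follows learns the weights $c_i$ and embeddings $r_{w_i}$ rather than fixing them by hand.

In [0]:
# Toy illustration of q_hat(h) = sum_i c_i * r_{w_i} (elementwise), with
# hand-picked numbers; nothing here is learned.
import numpy as np

r = np.array([[1.0, 0.0, 2.0],    # r_{w_1}: embedding of the first context word
              [0.5, 1.0, -1.0]])  # r_{w_2}: embedding of the second context word
c = np.array([[0.2, 0.2, 0.2],    # c_1: per-dimension weight for position 1
              [0.8, 0.1, 0.4]])   # c_2: per-dimension weight for position 2
q_hat = np.sum(c * r, axis=0)     # reweight each embedding, then sum over the context
print(q_hat)                      # -> [0.6 0.1 0. ]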


In [0]:
class ContextEmbedder(Model):
  """Reweights and sums a set of vectors in R^d representing context word
  vectors.
  @param:
    `vocab_size` is the number of words in the vocabulary
    `window_size` is the number of words in the (centered) window, including the
      target word. Hence, a window size of 5 means [w_0, w_1, w_2, w_3, w_4],
      where w_2 is the target word and the remaining words are its context.
    `embed_dim` is the dimensionality of the embeddings
    `exclude_target` drops the target word from the window before embedding, so
      only the surrounding context contributes to the prediction
  """
  def __init__(self, vocab_size, window_size, embed_dim, exclude_target=True,
               embeddings_initializer=keras.initializers.RandomUniform(
                   minval=-0.001, maxval=0.001, seed=None),
               ):

    super(ContextEmbedder, self).__init__()
    self._embed_dim = embed_dim
    self._vocab_size = vocab_size
    self._window_size = window_size
    self._exclude_target = exclude_target
    self._target_idx = self._window_size // 2
    # One weight vector c_i per window position ...
    self._get_context_weights = L.Embedding(self._window_size, self._embed_dim,
                                            embeddings_initializer=embeddings_initializer)
    # ... and one embedding r_w per vocabulary word.
    self._get_context_embeddings = L.Embedding(self._vocab_size, self._embed_dim,
                                               embeddings_initializer=embeddings_initializer)
    print(f"INFO: exclude target is {self._exclude_target}")
    
  def _exclude_target_idx(self, x):
    """Returns `x` with the target (center) column removed along the last axis."""
    return tf.concat([x[:, :self._target_idx], x[:, self._target_idx + 1:]],
                     axis=-1)

  def embed(self, x):
    """Given a (? x `self._window_size`) batch of integer word ids `x`, returns
    a (? x n_context x `self._embed_dim`) tensor of context embeddings, where
    n_context is `self._window_size - 1` if the target word is excluded and
    `self._window_size` otherwise.
    """
    if self._exclude_target:
      x = self._exclude_target_idx(x)
    return self._get_context_embeddings(x)

  def _make_index_tensor(self, x):
    """Returns a (batch x n_context) integer tensor whose every row is
    0, 1, ..., n_context - 1, used to look up the per-position weights.
    """
    _, window_len, _ = x.shape
    index_numbers = list(range(window_len))
    to_tile = tf.reshape(tf.constant(index_numbers), shape=(1, -1))
    context_positions = tf.tile(to_tile, multiples=(x.shape[0], 1))
    return context_positions

  def _weight_and_sum(self, x):
    """Given (? x n_context x `self._embed_dim`) context embeddings `x`,
    reweights each embedding dimension by its position weight and sums over
    the context."""
    context_positions = self._make_index_tensor(x)
    position_weights = self._get_context_weights(context_positions)
    # reweight embedding dimensions by `position_weights`, then sum over the context axis
    weighted_sum = tf.reduce_sum(tf.multiply(position_weights, x), axis=1)
    return weighted_sum

  def call(self, h):
    """Given a (? x `self._window_size`) batch of integer word ids `h`, returns
    a (? x `self._embed_dim`) tensor of predicted embeddings.
    """
    context_word_embeddings = self.embed(h)  # (? x `self._window_size` x `self._embed_dim`)
    predicted_embeddings = self._weight_and_sum(context_word_embeddings)  
    return predicted_embeddings
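
As a quick shape check (my addition, not from the original notebook), the cell below runs an untrained `ContextEmbedder` on a toy batch of integer word ids; the point is only to confirm that a batch of windows comes out as one predicted vector per window.

In [0]:
# Hedged usage sketch: toy vocabulary and window sizes, untrained weights.
toy_vocab, toy_window, toy_dim = 50, 5, 4
ctx = ContextEmbedder(toy_vocab, toy_window, toy_dim)
toy_windows = tf.constant([[3, 7, 11, 2, 9],
                           [1, 4, 6, 8, 10]])   # (2, window_size)
print(ctx(toy_windows).shape)                   # -> (2, 4) = (batch, embed_dim)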



The `LogBilinearScore` class scores a candidate target word by the similarity between the predicted vector $\hat{q}$ (`qhat` in the code) and the word's embedding $q_w$ (`q`). We include a bias term $b_w$ (`b`) to capture the context-independent frequency of word $w$. Given some $h$ and $\hat{q}(h)$, calling an instance of LogBilinearScore returns $$s_{\theta}(w,h) = \hat{q}(h)^\top q_w + b_w$$


In [0]:
class LogBilinearScore(Model):

  def __init__(self, vocab_size, embed_dim, embeddings_initializer='uniform'):
    super(LogBilinearScore, self).__init__()
    self.embed_dim = embed_dim
    self.vocab_size = vocab_size
    # Target-word embeddings q_w.
    self.q = L.Embedding(self.vocab_size, self.embed_dim,
                         embeddings_initializer=embeddings_initializer)
    # Per-word bias b_w capturing context-independent frequency.
    self.b = L.Embedding(self.vocab_size, 1,
                         embeddings_initializer=embeddings_initializer)

  def call(self, w, qhat):
    """Returns a (?,) tensor of scores s(w, h) = qhat . q_w + b_w.

    `w` is a (?,) tensor of integer word ids.
    `qhat` is a (? x `self.embed_dim`) tensor of predicted embeddings.
    """
    q = self.q(w)              # (? x `self.embed_dim`)
    b = tf.squeeze(self.b(w))  # (?,)
    return tf.reduce_sum(tf.multiply(q, qhat), axis=-1, keepdims=False) + b
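
A quick shape check (my addition): scoring a toy batch of word ids against toy predicted vectors should yield one scalar score per row.

In [0]:
# Hedged usage sketch: untrained scorer, random predicted embeddings.
scorer = LogBilinearScore(vocab_size=50, embed_dim=4)
w_toy = tf.constant([3, 7])             # (2,) target word ids
qhat_toy = tf.random.normal((2, 4))     # (2, embed_dim) predicted vectors
print(scorer(w_toy, qhat_toy).shape)    # -> (2,)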



In [0]:
import random
class WindowRebatcher(object):
  """Takes a batch of sequences, some padded with `pad_token`, and returns a
  new batch of length-`window_size` windows drawn from the original sequences.
  """

  def __init__(self, window_size, pad_token=-1):
    self._ws = window_size
    self._pt = pad_token

  def rebatch(self, batch):
    rebatches = []
    for item in batch:
      # Find where padding begins; if there is none, use the full sequence.
      pad_positions = np.where(item.numpy() == self._pt)[0]
      end_idx = pad_positions[0] if len(pad_positions) > 0 else len(item)
      n_windows = end_idx - (self._ws - 1)
      if n_windows <= 0:
        continue  # sequence too short to form a single window
      # Enumerate all window start positions, shuffle, and keep at most 20
      # windows per sequence.
      idxs = [list(range(i, i + self._ws)) for i in range(n_windows)]
      random.shuffle(idxs)
      idxs = tf.constant(idxs[0:20], dtype=tf.int64)
      rebatches.append(tf.map_fn(lambda x: tf.gather(item, x), idxs))

    return tf.concat(rebatches, axis=0)
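
As an example of the rebatching (my addition), the cell below turns a padded toy batch into overlapping length-5 windows; the first sequence yields three windows and the second yields one.

In [0]:
# Hedged usage sketch: -1 marks padding, matching `pad_int` used below.
rb = WindowRebatcher(window_size=5, pad_token=-1)
toy_batch = tf.constant([[1, 2, 3, 4, 5, 6, 7, -1, -1],
                         [8, 9, 10, 11, 12, -1, -1, -1, -1]], dtype=tf.int64)
windows = rb.rebatch(toy_batch)
print(windows.shape)   # -> (4, 5): 3 windows from the first sequence, 1 from the second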

Given a context $h$, an NPLM defines the distribution for the word to be predicted using the scoring function $s_{\theta}(w,h)$, which quantifies the compatibility between the context and the candidate target word. Here $\theta$ are the model parameters, which include the word embeddings. The scores are converted to probabilities by exponentiating and normalizing: $$P_{\theta}^h(w) = \frac{\exp (s_{\theta}(w,h))}{\sum_{w'}\exp(s_{\theta}(w',h))}$$ Unfortunately, both evaluating $P^h_{\theta}(w)$ and computing the corresponding likelihood gradient require normalizing over the entire vocabulary, which means that maximum-likelihood training of such models takes time linear in the vocabulary size, and is thus prohibitively expensive for all but the smallest vocabularies.
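
Noise-contrastive estimation (NCE) sidesteps this normalization by turning density estimation into binary classification: the observed word is treated as a positive example and $k$ words drawn from a noise distribution $P_n$ as negatives, classified with a logistic loss on $\Delta s_{\theta}(w,h) = s_{\theta}(w,h) - \log(k P_n(w))$. The cell below is a minimal sketch of this objective with made-up scores and a uniform noise probability (my addition). Note that the `NPLMEstimator` implemented next uses `del_score(..., simple=True)` by default, which skips the $\log(k P_n)$ correction, and averages rather than sums over the $k$ noise terms.

In [0]:
# Hedged NCE sketch with toy numbers: one observed word vs. k noise words.
k = 3
s_true = tf.constant([2.0])                      # score of the observed word
s_noise = tf.constant([0.1, -0.5, 0.3])          # scores of the k noise words
log_pn_true = tf.math.log(tf.constant([1e-4]))   # noise probability of the true word
log_pn_noise = tf.math.log(tf.fill([k], 1e-4))   # noise probabilities of the noise words

delta_true = s_true - (log_pn_true + tf.math.log(float(k)))
delta_noise = s_noise - (log_pn_noise + tf.math.log(float(k)))

nce_loss = (
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(delta_true), logits=delta_true)
    + tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(delta_noise), logits=delta_noise)))
print(nce_loss.numpy())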


In [0]:
class NPLMEstimator(Model):

  def __init__(self, vocab_size, embed_dim, window_size, mc_noise_samples, noise_pmf=None):
    super(NPLMEstimator, self).__init__()
    assert(window_size % 2 == 1), f"window_size must be odd but saw {window_size}!"
    self.mc_noise_samples = mc_noise_samples
    self.batcher = WindowRebatcher(window_size)
    self.context_embedder = ContextEmbedder(vocab_size, window_size, embed_dim)
    self.word_idx_to_score = LogBilinearScore(vocab_size, embed_dim)
    self.target_idx = window_size // 2
    if noise_pmf is None:
      noise_pmf = tf.ones(shape=(vocab_size,), dtype=tf.float64) / vocab_size
    self.noise_dist = categorical.Categorical(
        logits=tf.cast(tf.math.log(noise_pmf), dtype=tf.float64))
    self.logk = tf.math.log(tf.constant(mc_noise_samples, dtype=tf.float32))
    
  def sample_and_log_prob_noise(self, n):
    samples = []
    log_probs = []
    for _ in range(n):
      samples.append(self.noise_dist.sample(self.mc_noise_samples))
      log_probs.append(self.noise_dist.log_prob(samples[-1]))
    return tf.stack(samples), tf.stack(log_probs)
    
  def del_score(self, score, log_p_noise, simple=True):
    """Optionally subtracts log(k * P_n(w)) from the score; with `simple=True`
    the raw score is used directly as the classifier logit."""
    if simple:
      return score
    else:
      return score - tf.cast(
          log_p_noise + tf.ones_like(log_p_noise) * tf.cast(self.logk, dtype=tf.float64),
          dtype=tf.float32)

  def call(self, rebatched):
    """Computes the NCE-style loss for a batch of windows.

    `rebatched` is (W, window_size), where W is the number of windows produced
    by the WindowRebatcher from the current batch of sequences.
    """
    # --> embed the context words
    qhat = self.context_embedder(rebatched)                    # (W, embed_size)

    # --> compute scores for the target word in each windowed batch
    scores = self.word_idx_to_score(rebatched[:, self.target_idx], qhat)     # (W,)
    log_p_true_under_noise = self.noise_dist.log_prob(rebatched[:, self.target_idx])

    # --> Monte Carlo estimate of loss under noise distribution
    noise_samples, log_p_noise = self.sample_and_log_prob_noise(rebatched.shape[0])
    noise_words = tf.reshape(noise_samples,           # (? * n_noise_samples,)
                            shape=(tf.reduce_prod(noise_samples.shape),))
    noise_log_probs = tf.reshape(log_p_noise,         # (? * n_noise_samples,)
                            shape=(tf.reduce_prod(log_p_noise.shape),))

    # Repeat each predicted embedding once per noise sample so it can be
    # scored against every noise word drawn for that window.
    qhat_tiled = tf.reshape(tf.keras.backend.repeat(qhat, n=self.mc_noise_samples),
                            shape=(-1, qhat.shape[-1]))  # (? * n_noise_samples, embed_size)
    
    noise_scores_repeated = self.word_idx_to_score(noise_words, qhat_tiled)  # (? * n_noise_samples,)
    noise_scores_samples = tf.reshape(noise_scores_repeated,
                                      shape=(rebatched.shape[0], -1))

    noise_scores = tf.reduce_mean(noise_scores_samples, axis=-1)  # (?,) -- not used below

    true_del_score = self.del_score(scores, log_p_true_under_noise)
    noise_del_score_repeated = self.del_score(noise_scores_repeated, 
                                                      noise_log_probs)

    expected_true = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(true_del_score), 
                                                            logits=true_del_score)

    expected_false = tf.reduce_mean(tf.reshape(  
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(noise_del_score_repeated), 
            logits=noise_del_score_repeated), shape=(rebatched.shape[0], -1)), 
            axis=-1)

    return tf.reduce_mean(expected_true + expected_false)
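
To see the whole pipeline fit together (my addition, with toy sizes and an assumed Adam optimizer rather than the original training setup), the cell below builds a tiny estimator, rebatches one short padded sequence into windows, and takes a single gradient step on the loss.

In [0]:
# Hedged end-to-end sketch: tiny vocabulary, one toy sequence, one update.
toy_est = NPLMEstimator(vocab_size=50, embed_dim=2, window_size=5,
                        mc_noise_samples=4)
toy_opt = tf.keras.optimizers.Adam(1e-3)         # optimizer choice is an assumption
toy_seq = tf.constant([[5, 9, 13, 2, 7, 21, 4, 6, -1, -1]], dtype=tf.int64)
toy_windows = toy_est.batcher.rebatch(toy_seq)   # (n_windows, window_size)
with tf.GradientTape() as tape:
  toy_loss = toy_est(toy_windows)
grads = tape.gradient(toy_loss, toy_est.trainable_variables)
toy_opt.apply_gradients(zip(grads, toy_est.trainable_variables))
print(float(toy_loss))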

In [0]:
(train_data, test_data), info = tfds.load( 
  'lm1b/subwords8k',
  split = (tfds.Split.TRAIN, tfds.Split.TEST), 
  with_info=True, as_supervised=True)

padded_shapes = ([None],
                 (None,))

pad_int = -1     # integer flag that never appears in the dataset
batch_size = 10  # number of sequences per padded batch

padding_values = (tf.constant(pad_int, dtype=tf.int64),
                  tf.constant(0, dtype=tf.int64))

train_batches = train_data.shuffle(1024).padded_batch(
  batch_size,
  padded_shapes,
  padding_values=padding_values)

test_batches = test_data.shuffle(1024).padded_batch(
    batch_size,
    padded_shapes,
    padding_values=padding_values)
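
As an optional sanity check (my addition; it assumes, as the `padded_shapes` above imply, that `as_supervised=True` yields a pair of integer sequences), the cell below pulls one padded batch and inspects its shape.

In [0]:
# Hedged sanity check: one padded batch of token-id sequences.
seq, _ = next(iter(train_batches))
print(seq.shape, seq.dtype)   # (batch_size, longest_sequence_in_batch) int64, padded with -1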

In [0]:
window_size = 5           # @param
mc_noise_samples = 25     # @param
embed_dim = 2             # @param
vocab_size = info.features['text'].encoder.vocab_size
noise_pmf = None          # None -> NPLMEstimator falls back to a uniform noise distribution
train_iter = 50000        # @param
log_interval = 10000      # @param
epochs = 1
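
The commented-out training loop in the next cell refers to `nce_scorer` and `optimizer`, which are not constructed anywhere in this notebook. The cell below is a hedged sketch of how they might be built: the Adam optimizer and the 0.003 learning rate are assumptions (the rate mirrors the `2_0.003000_*` checkpoint directory names used later), not settings recovered from the original training run.

In [0]:
# Hedged construction of the objects the training cell below expects.
nce_scorer = NPLMEstimator(vocab_size, embed_dim, window_size,
                           mc_noise_samples, noise_pmf=noise_pmf)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.003)  # assumed optimizer + rate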

In [0]:
# !!! Only run this cell to retrain from scratch

# %%time
# save_interval = 5
# save_path = "./


# @tf.function
# def update(rebatched):
#   with tf.GradientTape() as tape:
#     loss = nce_scorer(rebatched)
  
#   gradients = tape.gradient(loss, nce_scorer.trainable_variables)
#   optimizer.apply_gradients(zip(gradients, nce_scorer.trainable_variables))
#   return loss

# losses = []
# for e in range(epochs):

#   for idx, batch in enumerate(iter(train_batches)):
#     seq, _ = batch  # (1, ?)
#     rebatched = nce_scorer.batcher.rebatch(seq)                       # (W, window_size)
#     loss = update(rebatched)
#     print(f"iter: {idx}")
#     if idx % save_interval == 0:
#       print("saving>...")
#       nce_scorer.save_weights(save_path)
#     # nce_scorer.save_weights
#     losses.append(loss)
  
#     # with tf.GradientTape() as tape:
#     #   loss = nce_scorer(seq)
    
#     # gradients = tape.gradient(loss, nce_scorer.trainable_variables)
#     # optimizer.apply_gradients(zip(gradients, nce_scorer.trainable_variables))

#     if idx % log_interval == 0:
#       losses.append(loss)
#       print(f"Epoch: {e} | iter {idx} | Loss: {loss}")
#     if (idx + 1) % train_iter == 0:
#       break
    
#   losses.append(loss)
#   print(f"Epoch End: {e} Loss: {loss}")

# HERE: save the model output weights to a directory of your choosing

In [0]:
import os
from sklearn.cross_decomposition import CCA
from numpy import linalg as LA

model_ids = [0, 2, 3, 4]

# Model checkpoints have been provided with this script

# Choose the appropriate root directory for the data, e.g., Google Drive
USER_ROOT_DIR = "./" 

model_paths = [[os.path.join(USER_ROOT_DIR, "2_0.003000_%d/model_checkpoints/rep1"%x),
                os.path.join(USER_ROOT_DIR, "2_0.003000_%d/model_checkpoints/rep2"%x)] for x in model_ids]

print(model_paths)

estimators = [[NPLMEstimator(vocab_size, embed_dim, window_size, mc_noise_samples, noise_pmf=noise_pmf),
              NPLMEstimator(vocab_size, embed_dim, window_size, mc_noise_samples, noise_pmf=noise_pmf)] for _ in range(len(model_paths))]

for idx, est in enumerate(estimators):
  est[0].load_weights(model_paths[idx][0])
  est[1].load_weights(model_paths[idx][1])

In [0]:
# Tableau Color Blind 10
tableau20blind = [(0, 107, 164), (255, 128, 14), (171, 171, 171), (89, 89, 89),
             (95, 158, 209), (200, 82, 0), (137, 137, 137), (163, 200, 236),
             (255, 188, 121), (207, 207, 207)]
  
for i in range(len(tableau20blind)):  
    r, g, b = tableau20blind[i]  
    tableau20blind[i] = (r / 255., g / 255., b / 255.)

In [0]:
import jax.numpy as jnp
import jax
import jax.experimental.optimizers

# Learn the best linear map from one set of embeddings to the other by
# least-squares regression.

def align(act1, act2):
  def align_loss(p):
    w, b = p
    pred = jnp.matmul(act1, w)  # + b  (bias term left disabled)
    return jnp.mean(jnp.square(act2 - pred))

  w = jnp.zeros((2, 2))
  b = jnp.zeros((2,))
  p = (w, b)
  init_fn, update_fn, get_p = jax.experimental.optimizers.momentum(0.01, 0.9)
  opt_state = init_fn(p)
  v_grad_fn = jax.jit(jax.value_and_grad(align_loss))
  
  def update(i, opt_state):
    p = get_p(opt_state)
    v, grad = v_grad_fn(p)
    if i % 50 == 0:
      print(v)
    return update_fn(i, grad, opt_state)
  
  for i in range(1000):
    opt_state = update(i, opt_state)
  w, b = get_p(opt_state)
  return w, b
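
As a quick check that `align` behaves as intended (my addition, synthetic data): given two point clouds related by a known 2x2 map, it should approximately recover that map.

In [0]:
# Hedged sanity check: recover a known rotation between synthetic 2-D clouds.
rng = np.random.RandomState(0)
A = rng.randn(200, 2)
true_w = np.array([[0.0, 1.0], [-1.0, 0.0]])   # 90-degree rotation
B = A @ true_w
w_hat, _ = align(jnp.asarray(A), jnp.asarray(B))
print(np.round(np.asarray(w_hat), 2))          # approximately true_w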

In [0]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import collections  as mc
from matplotlib import cm

sns.set_style("whitegrid")
print(f"est: {len(estimators)}")

fig, ax = plt.subplots(2, 3, figsize=(12, 7))
X1 = None
X2 = None

# NOTE: `tt` (a token-id-to-string decoder exposing `ints2str`) and `N` (the
# number of token ids to decode and standardize over) are assumed to be defined
# earlier, e.g. from the dataset's subword encoder; they are not set in this
# notebook as shown.
texts = [tt.ints2str([i]) for i in range(N)]
idxs = np.asarray([i for i in range(N)])

# Set data offsets for a nice visualization if text is desired
offset = 70    # picked to avoid subwords that aren't full words
ss = 90        # picked to make plot easier to read. see also `offset` comment
texts = texts[offset:ss]

# set colours:
plasma1 = cm.get_cmap('tab20c', len(texts)).colors
plasma2 = cm.get_cmap('tab20b', len(texts)).colors
plasma3 = cm.get_cmap('tab20', len(texts)).colors

num_to_text = 0
lines_fun = lambda V: [[(0,0),(V[i,0], V[i,1])] for i in range(len(V))]

def set_lims(t_ax):
  t_ax.set_xlim([-5,5])
  t_ax.set_ylim([-3,3])

def draw_arrow(vec, num, color="lightgrey", alpha=1.0):
  ax[idx][num].quiver(0, 0, vec[:, 0], vec[:, 1], alpha=alpha)


def draw_arrow_drop_shadow(vec, num):
  ax[idx][num].quiver(0, 0, vec[:, 0], vec[:, 1])

def number_text(curr_idx, num, A, idxs_to_label, zorder=100):
  for i in idxs_to_label:
    if texts[i] == "han":
      texts[i] = "hand"  # labels are subwords and arbitrary
    scale = 0
    y_jitter = scale * np.random.normal()
    ax[curr_idx][num].text(A[i, 0], A[i, 1] + y_jitter, texts[i], zorder=zorder)

def stdize(vec):  # seems to have no effect
  print(vec.shape)
  return (vec - np.mean(vec, axis=0, keepdims=True)) / np.std(vec, axis=0, keepdims=True)

def draw_lines(lines, this_ax, c, cmap=None, alpha=1.0, zorder=10, autoscale=False):
  lc = mc.LineCollection(lines, colors=cmap, linewidths=4, alpha=alpha, zorder=zorder,)
  this_ax.add_collection(lc)
  if autoscale:
    this_ax.autoscale()

dist = lambda p1, p2: np.sqrt( ((p1[0]-p2[0])**2)+((p1[1]-p2[1])**2) )

for idx, est in enumerate([estimators[i] for i in [1, 2]]):
#for idx, est in enumerate(estimators):
  # use different colours
  if idx == 0:
    plasma = plasma1
  elif idx == 1:
    plasma = plasma2
  else:
    plasma = plasma3

  print(f"INDEX: {idx}")
  # get embeddings
  embeddings1 = est[0].word_idx_to_score.q(idxs.reshape([-1, 1]))
  embeddings2 = est[1].word_idx_to_score.q(idxs.reshape([-1, 1]))

  scale = 1  # don't change the scale

  X1_full = stdize(scale*embeddings1.numpy().squeeze())
  X2_full = stdize(scale*embeddings2.numpy().squeeze())


  X1 = X1_full[offset:ss]
  X2 = X2_full[offset:ss]

  # print(f"len texts: {len(texts)}")

  c2 = "red"   # tableau20blind[8]
  c1 = "blue"  # tableau20blind[4]
  alph = 1.0
  lt=" "

  # compute lengths of all the arrows
  dists_X1 = [dist(x[0],x[1]) for x in lines_fun(X1)]
  dists_X2 = [dist(x[0],x[1]) for x in lines_fun(X2)]

  # find the longest arrows on X1 and X2 (not necessarily the same)
  idxs_to_label_X1 = np.argsort(dists_X1)[-num_to_text:]
  idxs_to_label_X2 = np.argsort(dists_X2)[-num_to_text:]

  # place vector lines on plot
  ax[idx][0].plot(X2[:, 0], X2[:, 1], lt, alpha = alph, c=c1)
  draw_lines(lines_fun(X2), ax[idx][0], c1, cmap=plasma, alpha=1.0)

  ax[idx][1].plot(X1[:, 0], X1[:, 1], lt, alpha = alph, c=c2)
  draw_lines(lines_fun(X1), ax[idx][1], c2, cmap=plasma, alpha=1.0)



  w_trans, b_trans = align(np.copy(X1_full), np.copy(X2_full))
  X1_c = np.copy(X1_full)
  X2_c = np.copy(np.matmul(X2_full, w_trans))# + b_trans)
  #X2_c = X2_full

  X1_c = X1_c[offset:ss]
  X2_c = X2_c[offset:ss]
  
  draw_lines(lines_fun(X1_c), ax[idx][2], c2, cmap=plasma, alpha=1.0)
  draw_lines(lines_fun(X2_c), ax[idx][2], c1, cmap=plasma, alpha=1.0)

  # place text on plot
  if num_to_text > 0:
    number_text(idx, 0, X2, idxs_to_label_X2)
    number_text(idx, 1, X1, idxs_to_label_X2)
    number_text(idx, 2, X2_c, idxs_to_label_X2)

  sns.despine(left=True, right=True, bottom=True, top=True)

  for t_ax in ax[idx]:
    t_ax.set_xlim(-1, 1)
    t_ax.set_ylim(-1, 1)

  for t_ax in ax[idx]:
    t_ax.set_xticks([])
    t_ax.set_yticks([])
fig.tight_layout()
fig.savefig("fig1.pdf", bbox_inches="tight")
%download_file test.pdf