In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time

In [2]:
mnist = input_data.read_data_sets('', one_hot=True)


Extracting train-images-idx3-ubyte.gz
Extracting train-labels-idx1-ubyte.gz
Extracting t10k-images-idx3-ubyte.gz
Extracting t10k-labels-idx1-ubyte.gz

In [3]:
def embed_seq(inputs, vocab_size=None, embed_dim=None, zero_pad=False, scale=False):
    # Trainable lookup table mapping integer ids to embed_dim-sized vectors.
    lookup_table = tf.get_variable('lookup_table', dtype=tf.float32, shape=[vocab_size, embed_dim])
    if zero_pad:
        # Force the embedding of id 0 (padding) to be the zero vector.
        lookup_table = tf.concat((tf.zeros([1, embed_dim]), lookup_table[1:, :]), axis=0)
    outputs = tf.nn.embedding_lookup(lookup_table, inputs)
    if scale:
        # Scale by sqrt(embed_dim), as in the Transformer paper.
        outputs = outputs * (embed_dim ** 0.5)
    return outputs

def learned_positional_encoding(inputs, embed_dim, zero_pad=False, scale=False):
    # Build position ids [0, 1, ..., T-1], tile them across the batch,
    # and embed each position with a learned lookup table.
    T = inputs.get_shape().as_list()[1]
    outputs = tf.range(T)
    outputs = tf.expand_dims(outputs, 0)
    outputs = tf.tile(outputs, [tf.shape(inputs)[0], 1])
    return embed_seq(outputs, T, embed_dim, zero_pad=zero_pad, scale=scale)

def layer_norm(inputs, epsilon=1e-8):
    # Normalize over the last dimension, then apply a learned scale (gamma)
    # and shift (beta).
    mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
    normalized = (inputs - mean) / tf.sqrt(variance + epsilon)
    params_shape = inputs.get_shape()[-1:]
    gamma = tf.get_variable('gamma', params_shape, tf.float32, tf.ones_initializer())
    beta = tf.get_variable('beta', params_shape, tf.float32, tf.zeros_initializer())
    return gamma * normalized + beta

def pointwise_feedforward(inputs, num_units=[None, None], activation=None):
    # Position-wise feed-forward block: two kernel-size-1 convolutions applied
    # to each timestep independently, with a residual connection and layer norm.
    outputs = tf.layers.conv1d(inputs, num_units[0], kernel_size=1, activation=activation)
    outputs = tf.layers.conv1d(outputs, num_units[1], kernel_size=1, activation=None)
    outputs += inputs
    return layer_norm(outputs)

class Model:
    def __init__(self, dimension_output=10, learning_rate=1e-4,
                 seq_len=28, dimension_input=28, num_heads=28, attn_windows=range(1, 6)):
        # Each 28x28 MNIST image is treated as a sequence of 28 rows,
        # each row being a 28-dimensional input vector.
        self.size_layer = dimension_input
        self.num_heads = num_heads
        self.seq_len = seq_len
        self.X = tf.placeholder(tf.float32, [None, seq_len, dimension_input])
        self.Y = tf.placeholder(tf.float32, [None, dimension_output])
        feed = self.X
        # A stack of locally masked self-attention blocks with growing
        # window half-widths (1 through 5).
        for win_size in attn_windows:
            with tf.variable_scope('attn_masked_window_%d' % win_size):
                feed = self.multihead_attn(feed, self.window_mask(win_size))
        # Add learned positional encodings before the global (unmasked) block.
        feed += learned_positional_encoding(feed, dimension_input)
        with tf.variable_scope('multihead'):
            feed = self.multihead_attn(feed, None)
        with tf.variable_scope('pointwise'):
            feed = pointwise_feedforward(feed, num_units=[4 * dimension_input, dimension_input],
                                         activation=tf.nn.relu)
        # Classify from the representation of the last timestep.
        self.logits = tf.layers.dense(feed, dimension_output)[:, -1]
        self.cost = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.Y))
        self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.cost)
        self.correct_pred = tf.equal(tf.argmax(self.logits, 1), tf.argmax(self.Y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32))
        
    def multihead_attn(self, inputs, masks):
        # Project the inputs to queries, keys, and values with a single dense
        # layer, then split the feature axis into num_heads heads and stack
        # the heads along the batch axis.
        Q_K_V = tf.layers.dense(inputs, 3 * self.size_layer, tf.nn.relu)
        Q, K, V = tf.split(Q_K_V, 3, -1)
        Q_ = tf.concat(tf.split(Q, self.num_heads, axis=2), axis=0)
        K_ = tf.concat(tf.split(K, self.num_heads, axis=2), axis=0)
        V_ = tf.concat(tf.split(V, self.num_heads, axis=2), axis=0)
        # Scaled dot-product attention.
        align = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))
        align = align / np.sqrt(K_.get_shape().as_list()[-1])
        if masks is not None:
            # Masked-out positions get -inf so softmax gives them zero weight.
            paddings = tf.fill(tf.shape(align), float('-inf'))
            align = tf.where(tf.equal(masks, 0), paddings, align)
        align = tf.nn.softmax(align)
        outputs = tf.matmul(align, V_)
        # Merge the heads back onto the feature axis, add the residual
        # connection, and normalize.
        outputs = tf.concat(tf.split(outputs, self.num_heads, axis=0), axis=2)
        outputs += inputs
        return layer_norm(outputs)
        
    def window_mask(self, h_w):
        # Binary band mask: position i may attend to positions within h_w
        # steps of i, clipped at the sequence boundaries.
        masks = np.zeros([self.seq_len, self.seq_len])
        for i in range(self.seq_len):
            if i < h_w:
                masks[i, :i + h_w + 1] = 1.
            elif i > self.seq_len - h_w - 1:
                masks[i, i - h_w:] = 1.
            else:
                masks[i, i - h_w:i + h_w + 1] = 1.
        masks = tf.convert_to_tensor(masks)
        # Tile one copy of the mask per head per example in the batch.
        return tf.tile(tf.expand_dims(masks, 0), [tf.shape(self.X)[0] * self.num_heads, 1, 1])
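
A quick, illustrative sanity check (not from the original run): the band pattern that `window_mask` builds can be reproduced with plain NumPy. For `seq_len=6` and half-width `h_w=1`, row `i` attends only to positions `i-1` through `i+1`; the names `seq_len`, `h_w`, and `m` are local to this sketch.

In [ ]:
# Standalone NumPy sketch of the band mask (slicing clips at the boundaries,
# matching the three branches in window_mask above).
seq_len, h_w = 6, 1
m = np.zeros([seq_len, seq_len])
for i in range(seq_len):
    m[i, max(0, i - h_w): i + h_w + 1] = 1.
print(m)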

In [4]:
sess = tf.InteractiveSession()
model = Model()
sess.run(tf.global_variables_initializer())

In [5]:
EPOCH = 10
BATCH_SIZE = 128

In [ ]:
for i in range(EPOCH):
    last = time.time()
    TOTAL_LOSS, ACCURACY = 0, 0
    # Iterate over the training set in full batches, dropping the remainder.
    for n in range(0, (mnist.train.images.shape[0] // BATCH_SIZE) * BATCH_SIZE, BATCH_SIZE):
        # Reshape each flat 784-pixel image into a 28x28 sequence of rows.
        batch_x = mnist.train.images[n: n + BATCH_SIZE, :].reshape((-1, 28, 28))
        acc, cost, _ = sess.run([model.accuracy, model.cost, model.optimizer],
                                feed_dict={model.X: batch_x,
                                           model.Y: mnist.train.labels[n: n + BATCH_SIZE, :]})
        ACCURACY += acc
        TOTAL_LOSS += cost
    TOTAL_LOSS /= (mnist.train.images.shape[0] // BATCH_SIZE)
    ACCURACY /= (mnist.train.images.shape[0] // BATCH_SIZE)
    print('epoch %d, avg loss %f, avg acc %f, time taken %f secs' % (i + 1, TOTAL_LOSS, ACCURACY, time.time() - last))


epoch 1, avg loss 1.731810, avg acc 0.382995, time taken 71.409638 secs
epoch 2, avg loss 1.012760, avg acc 0.664663, time taken 71.097286 secs
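
A minimal evaluation sketch, not part of the original notebook: reusing the trained `model` and the open `sess`, test-set accuracy can be averaged over full batches in the same way the training loop batches `mnist.train`.

In [ ]:
# Sketch: average model.accuracy over full batches of the MNIST test split.
n_batches = mnist.test.images.shape[0] // BATCH_SIZE
TEST_ACC = 0
for n in range(0, n_batches * BATCH_SIZE, BATCH_SIZE):
    batch_x = mnist.test.images[n: n + BATCH_SIZE, :].reshape((-1, 28, 28))
    TEST_ACC += sess.run(model.accuracy,
                         feed_dict={model.X: batch_x,
                                    model.Y: mnist.test.labels[n: n + BATCH_SIZE, :]})
print('test acc %f' % (TEST_ACC / n_batches))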
