In [ ]:
import numpy as np

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = [10,6]

#from copy import deepcopy

In [ ]:
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

print("TensorFlow version: {}".format(tf.VERSION))
print("Eager execution: {}".format(tf.executing_eagerly()))

In [ ]:
seed = 0
plot = True

## Inner Optimisations
inner_stepsize = 0.02 # stepsize in inner SGD
inner_epochs   = 1 # number of epochs of each inner SGD
inner_batchsize= 10 # Size of training minibatches

## Outer Optimisations
# initial stepsize of outer optimization, i.e., meta-optimization
outer_stepsize0 = 0.1 
# number of outer updates; each iteration we sample one task and update on it
outer_steps = 10000

In [ ]:
#torch.manual_seed(seed)
tf.set_random_seed(seed+0)

rng = np.random.RandomState(seed)
rng

In [ ]:
# Define task distribution
x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points

In [ ]:
# Generate regression problems (random sine functions) that we're going to learn about
def gen_task(): 
    phase = rng.uniform(low=0, high=2*np.pi)
    ampl = rng.uniform(2., 5.)
    f_randomsine = lambda x : np.sin(x+phase)*ampl
    return f_randomsine # i.e. return a random *function* to learn
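
As a quick sanity check (purely illustrative, not part of the original flow), we can sample one task and plot it - each task is just a sine curve with a random phase and amplitude. (Note that this extra draw from rng will shift the sequence of random tasks generated below.)

In [ ]:
# Illustrative only: sample one random task and plot it
f_example = gen_task()            # a random sine function
plt.plot(x_all, f_example(x_all))
plt.title("One randomly sampled task")
plt.show()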

In [ ]:
# Define model - this model is going to be easily *trainable* to 
#  solve each of the problems we generate

#model = nn.Sequential(
#    nn.Linear(1, 64),
#    nn.Tanh(), 
#    nn.Linear(64, 64),
#    nn.Tanh(),
#    nn.Linear(64, 1),
#)

#inits=dict()
inits=dict(kernel_initializer='lecun_uniform', bias_initializer='lecun_uniform')

model = tf.keras.Sequential([
  # Reptile paper uses ReLU, but Tanh gives slightly better results
  tf.keras.layers.Dense(64, activation="tanh", input_shape=(1,), **inits), # input shape required
  tf.keras.layers.Dense(64, activation="tanh", **inits),
  tf.keras.layers.Dense(1, **inits)  # NB: Keras 'knows' the previous layer sizes
])

# Alternatively (more PyTorch-y) : 
class SineModel(tf.keras.Model):
  def __init__(self):
    super(SineModel, self).__init__()
    # Input shape is not required, since it is set the first time the model is called
    #    (there is a problem with the shapes if we try to specify the inits here, though)
    self.dense1 = tf.keras.layers.Dense(units=64, activation="tanh") 
    self.dense2 = tf.keras.layers.Dense(units=64, activation="tanh")
    self.dense3 = tf.keras.layers.Dense(units=1)  

  def call(self, inputs):
    """Run the model."""
    x = self.dense1(inputs)
    x = self.dense2(x)  
    x = self.dense3(x) 
    return x
#model = SineModel()


#def totorch(x):
#    return torch.autograd.Variable(torch.Tensor(x))

#def train_current_model_on_batch(x, y):
#    x = totorch(x)
#    y = totorch(y)
#    model.zero_grad()
#    ypred = model(x)
#    loss = (ypred - y).pow(2).mean()
#    loss.backward()
#    for param in model.parameters():
#        param.data -= inner_stepsize * param.grad.data

def train_current_model_on_batch(x, y):
    x = tfe.Variable(x, dtype='float32')
    y = tfe.Variable(y, dtype='float32')
    with tfe.GradientTape() as tape:
        ypred = model(x)
        loss = tf.reduce_mean(tf.square(ypred - y))
    grads = tape.gradient(loss, model.variables)    
    
    for param, grad in zip(model.variables, grads):
        param.assign_sub(inner_stepsize * grad)
        
#def predict_using_current_model(x):
#    x = totorch(x)
#    return model(x).data.numpy()

def predict_using_current_model(x):
    x = tfe.Variable(x, dtype='float32')
    return model(x).numpy()

In [ ]:
model.summary()
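
As a quick check (illustrative only), the freshly initialised model can already be queried via predict_using_current_model - its outputs are essentially arbitrary at this point, since no training has happened yet.

In [ ]:
# Illustrative only: first few predictions of the untrained model on x_all
predict_using_current_model(x_all)[:5]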

In [ ]:
# Choose a fixed task and minibatch for visualization
f_plot = gen_task()  # This is one specific task
xtrain_plot = x_all[rng.choice(len(x_all), size=inner_batchsize)]  # One illustrative batch

In [ ]:
# Reptile training loop
for outer_step in range(outer_steps):
    #weights_before = deepcopy(model.state_dict())
    weights_before = [ param.read_value() for param in model.variables ]
    
    # Generate a task of the sort we want to learn about learning
    f = gen_task()  # This is a different task every time
    y_all = f(x_all) # Get the correct outputs for the x_all values (i.e. a dataset)
    
    # Do SGD on this task + dataset
    inds = rng.permutation(len(x_all))  # Shuffle data indices
    for _ in range(inner_epochs):
        for start in range(0, len(x_all), inner_batchsize):
            batch_indices = inds[start:start+inner_batchsize]
            train_current_model_on_batch(
                x_all[batch_indices], y_all[batch_indices]
            )
            
    # Interpolate between current weights and trained weights from this task
    #   i.e. (weights_before - weights_after) is the meta-gradient
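    #   Reptile meta-update (Nichol et al. 2018): theta <- theta + outer_stepsize*(W_task - theta),
    #   where W_task are the weights after the inner-loop SGD steps on the sampled task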
    
    #weights_after = model.state_dict()
    weights_after = [ param.read_value() for param in model.variables ]
    
    outer_stepsize = outer_stepsize0*(1-outer_step/outer_steps) # linear schedule
    
    # This updates the weights in the model -
    #   not a 'training' gradient update, but one reflecting
    #   the training that occurred during the inner loop above
    
    #model.load_state_dict({name : 
    #    weights_before[name] + (weights_after[name]-weights_before[name])*outer_stepsize 
    #    for name in weights_before})
    
    for param, w_before, w_after in zip(model.variables, weights_before, weights_after):
        param.assign( w_before + (w_after-w_before)*outer_stepsize )

    # Periodically plot the results on a particular task and minibatch
    if plot and (outer_step==0 or (outer_step+1) % 1000 == 0):
        fig = plt.figure(figsize=(10,6))
        ax = plt.subplot(111)
        
        f = f_plot
        
        #weights_before = deepcopy(model.state_dict()) # save snapshot before evaluation
        weights_saved = [ param.read_value() for param in model.variables ]
        
        # Plot the initial model (having seen no data to train on for this task)
        plt.plot(x_all, predict_using_current_model(x_all), 
                 label="pred after 0", color=(0.5,0.5,1))

        for inner_step in range(1, 32+1):
            train_current_model_on_batch(xtrain_plot, f(xtrain_plot))
            #if (inner_step) % 8 == 0:
            if inner_step in [1, 2, 4, 8, 16, 32]:
                frac = np.log(inner_step)/np.log(32.)
                plt.plot(x_all, predict_using_current_model(x_all), 
                         label="pred after %i" % (inner_step), 
                         color=(frac, 0, 1-frac))
        lossval = np.square(predict_using_current_model(x_all) - f(x_all)).mean()

        plt.plot(x_all, f(x_all), label="true", color=(0,1,0))
        
        plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")
        plt.ylim(-4.5,4.5)
        
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))        
        
        plt.show()
        
        print(f"outer_step             {outer_step+1:6d}")
        # would be better to average loss over a set of examples, 
        #   but this is optimized for brevity
        print(f"loss on plotted curve   {lossval:.3f}") 
        print(f"-----------------------------")

        #model.load_state_dict(weights_before) # restore from snapshot 
        for param, w_saved in zip(model.variables, weights_saved):
            param.assign( w_saved )
        
print("FINISHED")

In [ ]:
#weights_before

In [ ]:


In [ ]:

Test code


In [ ]:
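# Quick check of tfe.Variable semantics: do values captured via read_value() or
#   numpy() stay fixed after a later assign_add()?  (The Reptile loop above relies
#   on read_value() giving a snapshot of the weights, not a live reference.)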
a = tfe.Variable([[1, 2], [3, 4]])

b = a.read_value()
d = a.numpy()

a.assign_add([[2, 2], [1, 1]] )

In [ ]:
b

In [ ]:
d

In [ ]: