In [0]:
# Copyright 2019 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
import tensorflow as tf
import numpy as np
from collections import namedtuple
import matplotlib.pyplot as plt
import random
import time
import collections
import seaborn as sns
import pickle
import tensorflow_probability as tfp
from tensorflow_probability.python.layers import util as tfp_layers_util
tf.enable_v2_tensorshape()
NeuralProcessParams = collections.namedtuple(
    'NeuralProcessParams',
    ['dim_r', 'dim_z', 'dim_w', 'n_hidden_units_r', 'n_hidden_units_g'])
GaussianParams = collections.namedtuple('GaussianParams', ['mu', 'sigma'])
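# dim_r: per-point representation size; dim_z: aggregated context code size;
# dim_w: latent input-representation size used by the MR-CNP variants below.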
cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
In [0]:
# Choose the model here.
exp_type = 'CNP'  # one of 'CNP', 'MR-CNP-A', 'MR-CNP-W'
noise_scale = 0.1 #@param {type:"number"}
n_obs = 20 #@param {type:"number"}
n_context = 10 #@param {type:"number"}
K_amp = 20 #@param {type:"number"}
x_width = 5 #@param {type:"number"}
learning_rate = 0.001 #@param {type:"number"}
n_iter = 30000 #@param {type:"number"}
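# Each task is a sine curve y = a * sin(x) + noise; the amplitude a is taken
# from `amps`, and the one-hot encoding of its index is the task descriptor.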
amps = np.linspace(0.1, 4, K_amp)
params = NeuralProcessParams(dim_r=10, dim_z=5, dim_w=5,
                             n_hidden_units_r=[100, 100],
                             n_hidden_units_g=[100, 100])
In [0]:
# Auxiliary functions
def split_context_target(xs, ys, n_context,
                         context_xs, context_ys, target_xs, target_ys):
    """Randomly split samples into task training (context) and task
    validation (target) sets.
    """
    indices = set(range(ys.shape[0]))
    # random.sample needs a sequence (sampling directly from a set was
    # removed in Python 3.11), so sort the indices first.
    context_set_indices = set(random.sample(sorted(indices), n_context))
    target_set_indices = indices - context_set_indices
    return {
        context_xs: xs[list(context_set_indices), :],
        context_ys: ys[list(context_set_indices), :],
        target_xs: xs[list(target_set_indices), :],
        target_ys: ys[list(target_set_indices), :],
    }
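# Illustrative check (not part of the original notebook): with 20 samples and
# n_context=10 the split yields 10 context and 10 target rows; plain strings
# stand in for the placeholder keys here.
_fd = split_context_target(np.zeros([20, 1]), np.zeros([20, 1]), 10,
                           'cx', 'cy', 'tx', 'ty')
assert _fd['cx'].shape == (10, 1) and _fd['tx'].shape == (10, 1)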
def sampling(output):
    """Sample from a diagonal Gaussian via the reparameterization trick."""
    mu, sigma_raw = tf.split(output, num_or_size_splits=2, axis=-1)
    # Softplus keeps the scale positive; the raw network output is
    # unconstrained.
    sigma = tf.nn.softplus(sigma_raw)
    ws = mu + tf.random_normal(tf.shape(mu)) * sigma
    return ws, mu, sigma
def merge(A, B):
    """A is [n, k1], B is [m, k2]; returns [n, m, k1 + k2]."""
    A_repeat = tf.expand_dims(A, axis=1)
    A_repeat = tf.tile(A_repeat, [1, tf.shape(B)[0], 1])
    B_repeat = tf.expand_dims(B, axis=0)
    B_repeat = tf.tile(B_repeat, [tf.shape(A)[0], 1, 1])
    return tf.concat([A_repeat, B_repeat], axis=2)
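# Shape sanity check (illustrative, not in the original): pairing a [2, 3]
# batch with a [4, 1] batch should give a [2, 4, 4] tensor.
with tf.Session() as _sess:
    print(_sess.run(tf.shape(merge(tf.zeros([2, 3]),
                                   tf.zeros([4, 1])))))  # -> [2 4 4]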
def kl_qp_gaussian(mu_q, sigma_q, mu_p, sigma_p):
    """KL(N(mu_q, diag(sigma_q^2)) || N(mu_p, diag(sigma_p^2)))."""
    sigma2_q = tf.square(sigma_q) + 1e-16
    sigma2_p = tf.square(sigma_p) + 1e-16
    temp = tf.log(sigma2_p) - tf.log(sigma2_q) - 1.0 + \
        sigma2_q / sigma2_p + tf.square(mu_q - mu_p) / sigma2_p  # [N, D]
    kl = 0.5 * tf.reduce_sum(temp, axis=1)
    return tf.reduce_mean(kl)
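# Cross-check (illustrative, not in the original): the same closed-form KL in
# NumPy, handy for verifying kl_qp_gaussian on small arrays.
def kl_qp_gaussian_np(mu_q, sigma_q, mu_p, sigma_p):
    s2_q = np.square(sigma_q) + 1e-16
    s2_p = np.square(sigma_p) + 1e-16
    temp = np.log(s2_p) - np.log(s2_q) - 1.0 + s2_q / s2_p \
        + np.square(mu_q - mu_p) / s2_p
    return np.mean(0.5 * np.sum(temp, axis=1))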
In [0]:
# (MR-)CNP model
# The plain CNP has no separate input encoder; define the name anyway so the
# graph construction below can pass it around uniformly (otherwise the CNP
# branch raises a NameError).
encoder_w = None
if exp_type == 'MR-CNP-W':
    kernel_posterior_fn = tfp_layers_util.default_mean_field_normal_fn(
        untransformed_scale_initializer=tf.compat.v1.initializers.random_normal(
            mean=-3., stddev=0.1))
    encoder_w = tf.keras.Sequential([
        tfp.layers.DenseReparameterization(
            100, activation=tf.nn.relu,
            kernel_posterior_fn=kernel_posterior_fn),
        tfp.layers.DenseReparameterization(params.dim_w),
    ])
if exp_type == 'MR-CNP-A':
def encoder_w(xas, params):
hidden_layer = xas
hidden_layer = tf.layers.dense(hidden_layer, 100,
activation=tf.nn.relu,
name='encoder_w_{}'.format(0),
reuse=tf.AUTO_REUSE,
kernel_initializer='normal')
# Last layer is simple linear
mu_w = tf.layers.dense(hidden_layer, params.dim_w, name='w_params_mu',
reuse=tf.AUTO_REUSE, kernel_initializer='normal')
sigma_w = tf.layers.dense(hidden_layer, params.dim_w, name='w_params_sigma',
reuse=tf.AUTO_REUSE,
kernel_initializer='normal')
sigma_w = tf.nn.softplus(sigma_w)
return mu_w, sigma_w
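# For MR-CNP-A, encoder_w maps (x, a) to a mean-field Gaussian over the latent
# input representation w; xy_to_z_params below samples w with the
# reparameterization trick.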
def encoder_r(xyas, params):
"""
encode task training data.
"""
hidden_layer = xyas
for i, n_hidden_units in enumerate(params.n_hidden_units_r):
hidden_layer = tf.layers.dense(hidden_layer, n_hidden_units,
activation=tf.nn.relu,
name='encoder_layer_{}'.format(i),
reuse=tf.AUTO_REUSE,
kernel_initializer='normal')
# Last layer is linear
i = len(params.n_hidden_units_r)
r = tf.layers.dense(hidden_layer, params.dim_r,
name='encoder_layer_{}'.format(i),
reuse=tf.AUTO_REUSE,
kernel_initializer='normal')
return r
def xy_to_z_params(xs, ys, amplitude, params, encoder_w=None):
    """Aggregate the task training data:
    i) rs = T1(xs, ys, a); ii) r = mean(rs); iii) z = T2(r).
    """
    # i)
    amplitude = tf.tile(amplitude, [tf.shape(xs)[0], 1])
    if exp_type == 'CNP':
        xyas = tf.concat([xs, ys, amplitude], axis=1)
        rs = encoder_r(xyas, params)
    if exp_type == 'MR-CNP-W':
        xas = tf.concat([xs, amplitude], axis=1)
        ws = encoder_w(xas)
        wys = tf.concat([ws, ys], axis=1)
        rs = encoder_r(wys, params)
    if exp_type == 'MR-CNP-A':
        xas = tf.concat([xs, amplitude], axis=1)
        mu_w, sigma_w = encoder_w(xas, params)
        ws = mu_w + tf.random_normal(tf.shape(mu_w)) * sigma_w
        wys = tf.concat([ws, ys], axis=1)
        rs = encoder_r(wys, params)
    # ii)
    r = tf.reshape(tf.reduce_mean(rs, axis=0), [1, -1])
    # iii)
    z_sample = tf.layers.dense(r, params.dim_z, name='z_params_mu',
                               reuse=tf.AUTO_REUSE, kernel_initializer='normal')
    if exp_type == 'MR-CNP-A':
        return z_sample, mu_w, sigma_w
    return z_sample
def decoder_g(zws, params):
    """Decoder: y_hat = G(z, task testing inputs)."""
    hidden_layer = zws
    for i, n_hidden_units in enumerate(params.n_hidden_units_g):
        hidden_layer = tf.layers.dense(hidden_layer, n_hidden_units,
                                       activation=tf.nn.relu,
                                       name='decoder_layer_{}'.format(i),
                                       reuse=tf.AUTO_REUSE,
                                       kernel_initializer='normal')
    # Last layer is linear.
    i = len(params.n_hidden_units_g)
    y_hat = tf.layers.dense(hidden_layer, 1,
                            name='decoder_layer_{}'.format(i),
                            reuse=tf.AUTO_REUSE,
                            kernel_initializer='normal')
    return y_hat
def neg_loglikelihood(y_star, mu_star, noise_scale=noise_scale):
p_normal = tf.distributions.Normal(loc=mu_star, scale=noise_scale)
loglike = p_normal.log_prob(y_star) #n_target * n_z
loglike = tf.reduce_sum(loglike, axis=0)
loglike = tf.reduce_mean(loglike)
return -loglike
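# Note (illustrative): with fixed noise_scale s, the per-point negative
# log-likelihood is 0.5 * ((y - mu) / s)**2 + log(s * sqrt(2 * pi)), i.e. a
# scaled squared error plus a constant.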
In [0]:
tf.reset_default_graph()
# Placeholders for training inputs
context_xs = tf.placeholder(tf.float32, (None, 1))
context_ys = tf.placeholder(tf.float32, (None, 1))
target_xs = tf.placeholder(tf.float32, (None, 1))
target_ys = tf.placeholder(tf.float32, (None, 1))
amplitude = tf.placeholder(tf.float32, (1, K_amp))
if exp_type == 'MR-CNP-W':
Beta = tf.placeholder_with_default(0.15, ())
if exp_type == 'MR-CNP-A':
Beta = tf.placeholder_with_default(5.0, ())
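# Beta weights the regularizer: the weight-space KL for MR-CNP-W, and the
# KL(q(w|x, a) || N(0, I)) bottleneck term for MR-CNP-A (kl_ib below).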
if exp_type == 'MR-CNP-A':
z_samples, mu_w, sigma_w = xy_to_z_params(context_xs, context_ys, amplitude, params, encoder_w)
else:
z_samples = xy_to_z_params(context_xs, context_ys, amplitude, params, encoder_w)
target_xas = tf.concat([target_xs, tf.tile(amplitude, [tf.shape(target_xs)[0], 1])], axis=1)
if exp_type == 'CNP':
input_target = merge(z_samples, target_xas)
if exp_type == 'MR-CNP-W':
with tf.variable_scope('encoder_w'):
target_ws = encoder_w(target_xas)
input_target = merge(z_samples, target_ws) #n_z * n_target * (d_z + d_w)
if exp_type == 'MR-CNP-A':
target_mu_w, target_sigma_w = encoder_w(target_xas, params)
target_ws = target_mu_w + tf.random_normal(tf.shape(target_mu_w)) * target_sigma_w
input_target = merge(z_samples, target_ws) #n_z * n_target * (d_z + d_w)
# Predicted means y_hat = G(z, task testing inputs).
y_star_mu_test = decoder_g(input_target, params)
y_star_mu_test = tf.squeeze(y_star_mu_test, axis=2)
y_star_mu_test = tf.transpose(y_star_mu_test)  # n_target * n_z
# loss & optimizer
neg_loglike = neg_loglikelihood(target_ys, y_star_mu_test)
if exp_type == 'CNP':
loss = neg_loglike
if exp_type == 'MR-CNP-W':
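    # encoder_w.losses collects the per-layer KL(posterior || prior) terms
    # registered by tfp.layers.DenseReparameterization.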
kl_loss = Beta * sum(encoder_w.losses)
loss = neg_loglike + kl_loss
if exp_type == 'MR-CNP-A':
kl_ib = kl_qp_gaussian(mu_w, sigma_w,
tf.zeros(tf.shape(mu_w)), tf.ones(tf.shape(mu_w)))
loss = neg_loglike + Beta * kl_ib
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.minimize(loss)
def train_model(sess, n_iter=n_iter, verbose=False):
    cost = []
    cost_r = []
    for i in range(n_iter):
        # Generate one task: a sine curve with a randomly chosen amplitude.
        xs = np.random.uniform(-x_width, x_width, n_obs)
        amp_ind = random.randint(0, K_amp - 5)
        amp = np.zeros([1, K_amp])
        amp[0, amp_ind] = 1
        ys = amps[amp_ind] * np.sin(xs) + np.random.normal(scale=noise_scale,
                                                           size=xs.shape)
        feed_dict = split_context_target(xs.reshape(-1, 1), ys.reshape(-1, 1),
                                         n_context, context_xs, context_ys,
                                         target_xs, target_ys)
        feed_dict.update({amplitude: amp})
        # One gradient step.
        _, loss_value = sess.run((train_op, loss), feed_dict=feed_dict)
        cost.append(loss_value)
        if verbose and i % 5000 == 0:
            cost_r.append(np.mean(cost))
            print('iter=', i, 'Loss: {:.3f}'.format(np.mean(cost)))
            cost = []
    return sess
In [0]:
def plot_out(sess, amp_ind, n_context=n_context, amps=amps, width=x_width,
             ax=None, seed=None, K_test_sample=100, legend=False):
    if seed is not None:
        np.random.seed(seed**2)
    # The true amplitude A is drawn independently of the one-hot label below,
    # so the descriptor alone does not pin down the target function.
    A = np.random.uniform(low=amps[0], high=amps[-1])
    a_onehot = np.zeros([1, K_amp])
    a_onehot[0, amp_ind] = 1
    # Task training (context) points.
    xc_test = np.reshape(np.random.uniform(-width, width, n_context), [-1, 1])
    yc_test = A * np.sin(xc_test) + np.random.normal(scale=noise_scale,
                                                     size=xc_test.shape)
    # Task validation (target) points.
    xs_test = np.reshape(np.linspace(-5, 5, 100), [-1, 1])
    ys_test = A * np.sin(xs_test)
    mean_curve = 0
    for i in range(K_test_sample):
        sample_curves = sess.run(y_star_mu_test,
                                 feed_dict={context_xs: xc_test,
                                            context_ys: yc_test,
                                            target_xs: xs_test,
                                            target_ys: ys_test,
                                            amplitude: a_onehot})
        mean_curve += sample_curves[:, 0] / (K_test_sample * 1.0)
    error = np.mean(np.square(mean_curve - ys_test[:, 0]))
    if ax is not None:
        plt.sca(ax)
        plt.plot(xs_test, ys_test, '-', color='darkorange', linewidth=2.0,
                 label='True Function', zorder=2)
        plt.plot(xs_test, mean_curve, 'b-', linewidth=2.0, label='Prediction',
                 zorder=2)
        plt.ylim([-6, 6])
        plt.axvline(x=-width, linestyle='--')
        plt.axvline(x=width, linestyle='--')
        plt.tick_params(axis='x', which='both', bottom=False, top=False,
                        labelbottom=False)
        plt.plot(xc_test, yc_test, '^', label='Context Points', zorder=3)
        if legend:
            plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    return error, A
In [0]:
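# Train from scratch, then report the average test MSE over 100 random tasks
# with 5 and with 10 context points.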
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess = train_model(sess, n_iter=n_iter, verbose=False)
print("Model is", exp_type)
n_test_task = 100
n_context_test = 5
errors = []
for i in range(n_test_task):
error, A = plot_out(sess, amp_ind = random.randint(5,K_amp-5), n_context=n_context_test, seed = i)
errors.append(error)
print("n_context = ", n_context_test, "error =", np.mean(errors))
n_context_test = 10
errors = []
for i in range(n_test_task):
error, A = plot_out(sess, amp_ind = random.randint(5,K_amp-5), n_context=n_context_test, seed = i)
errors.append(error)
print("n_context = ", n_context_test, "error =", np.mean(errors))
In [0]:
fig, ax = plt.subplots(ncols=4, figsize=(16, 4))
amp_ind = 15
error, A = plot_out(sess, amp_ind, seed=1, ax=ax[0])
error, A = plot_out(sess, amp_ind, seed=3, ax=ax[1])
error, A = plot_out(sess, amp_ind, seed=5, ax=ax[2])
error, A = plot_out(sess, amp_ind, seed=15, ax=ax[3], legend=True)
for ax_i in ax:
    ax_i.tick_params(axis='both', which='major', length=0)
In [0]:
n_iter = 120000  # train longer, from a fresh initialization
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess = train_model(sess, n_iter=n_iter, verbose=False)
print("Model is", exp_type)
n_test_task = 100
n_context_test = 5
errors = []
for i in range(n_test_task):
error, A = plot_out(sess, amp_ind = random.randint(5,K_amp-5), n_context=n_context_test, seed = i)
errors.append(error)
print("n_context = ", n_context_test, "error =", np.mean(errors))
n_context_test = 10
errors = []
for i in range(n_test_task):
error, A = plot_out(sess, amp_ind = random.randint(5,K_amp-5), n_context=n_context_test, seed = i)
errors.append(error)
print("n_context = ", n_context_test, "error =", np.mean(errors))
In [0]:
fig, ax = plt.subplots(ncols=4, figsize=(16, 4))
amp_ind = 15
error, A = plot_out(sess, amp_ind, seed=1, ax=ax[0])
error, A = plot_out(sess, amp_ind, seed=3, ax=ax[1])
error, A = plot_out(sess, amp_ind, seed=5, ax=ax[2])
error, A = plot_out(sess, amp_ind, seed=15, ax=ax[3], legend=True)
for ax_i in ax:
    ax_i.tick_params(axis='both', which='major', length=0)
In [0]: