MPLP: MAML Sinusoidal 2 step, 10 shot

Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


In [ ]:
# @title Connect to internal TF kernel and run this.
import os
import io
import glob
import itertools
import random
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt  # visualization

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
tf.get_logger().setLevel('ERROR')

import tensorflow_datasets as tfds

import IPython.display as display
from IPython.display import clear_output

from PIL import Image

In [ ]:
!pip install --upgrade -e "git+https://github.com/google-research/self-organising-systems.git#egg=mplp&subdirectory=mplp"

In [ ]:
# symlink for saved models.
!ln -s src/mplp/mplp/savedmodels savedmodels

In [ ]:
from mplp.tf_layers import MPDense
from mplp.tf_layers import MPActivation
from mplp.tf_layers import MPSoftmax
from mplp.tf_layers import MPL2Loss
from mplp.tf_layers import MPNetwork
from mplp.sinusoidals import SinusoidalsDS
from mplp.util import SamplePool
from mplp.training import TrainingRegime

The task is to fit sinusoids starting from randomly initialized networks.

Therefore, there are:

  • Outer batch size = 4: the number of tasks at every step. Each task has a different network, a different amplitude and a different phase.
  • Inner batch size = 10: the number of examples for each forward/backward step (see the shape-check sketch after this list).
  • Num steps = 2: the number of inner steps the network gets to improve.
  • Train/eval split: the network only sees train instances during forward/backward. The meta-learning regime may choose to use the eval split as well, MAML-style.
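
A quick shape check on one batch. This is only a sketch: the [outer, steps, inner, 1] layout is inferred from the plotting loop in the next cell, not guaranteed by the SinusoidalsDS API.

In [ ]:
# Sanity-check sketch: draw one batch from a throwaway dataset and print its
# nesting. Assumes create_ds yields (x_train, y_train, x_eval, y_eval).
_ds_iter = iter(SinusoidalsDS().create_ds(4, 10, 2))
_xt, _yt, _xe, _ye = next(_ds_iter)
for _name, _t in zip(["x_train", "y_train", "x_eval", "y_eval"],
                     [_xt, _yt, _xe, _ye]):
  print(_name, np.shape(_t))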

In [ ]:
# @title create dataset and plot it
OUTER_BATCH_SIZE = 4
INNER_BATCH_SIZE = 10
NUM_STEPS = 2

ds_factory = SinusoidalsDS()

ds = ds_factory.create_ds(OUTER_BATCH_SIZE, INNER_BATCH_SIZE, NUM_STEPS)
ds_iter = iter(ds)

# Utility range
xrange_inputs = np.linspace(-5,5,100).reshape((100, 1)).astype(np.float32)

xtb, ytb, xeb, yeb = next(ds_iter)
plt.figure(figsize=(14, 10))
colors = itertools.cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])
for xts, yts, xes, yes in zip(xtb, ytb, xeb, yeb):
  c_t = next(colors)
  c_e = next(colors)
  markers = itertools.cycle((',', '+', '.', 'o', '*')) 
  for xtsib, ytsib, xesib, yesib in zip(xts, yts, xes, yes):
    marker = next(markers)
    plt.scatter(xtsib, ytsib, c=c_t, marker=marker)
    plt.scatter(xesib, yesib, c=c_e, marker=marker)

plt.show()

Create an MP network:


In [ ]:
# This is the size of the message passed.
message_size = 8
stateful = False
stateful_hidden_n = 15

# This network is Keras-style initialized.
# If you want to create a single layer on its own, you also need to pass it
# the in_dim and the message size.
network = MPNetwork(
    [
     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n), 
     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     MPDense(1, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     ],
     MPL2Loss(stateful=stateful, stateful_hidden_n=stateful_hidden_n))
network.setup(in_dim=1, message_size=message_size, inner_batch_size=INNER_BATCH_SIZE)

# see trainable weights:
tr_w = network.get_trainable_weights()
print("trainable weights:")
tot_w = 0
for w in tr_w:
  w = w.numpy()
  w_size = w.size
  tot_w += w_size

  print(w.shape, w_size)
print("tot n:", tot_w)

In [ ]:
# for MAML training, we need one and only one set of variables.
trained_pfw = [tf.Variable(t) for t in network.init()]

In [ ]:
num_steps = tf.constant(NUM_STEPS)

learning_schedule = 1e-4

training_regime = TrainingRegime(
    network, heldout_weight=1.0, hint_loss_ratio=None, remember_loss_ratio=None)

last_step = 0

# Minibatch init: initialize the network statistics by looking at more than
# just one step. This cell can also be run multiple times to improve the
# initialization.
for j in range(1):
  print("on", j)
  stats = []
  pfw = trained_pfw

  x_b, y_b, _, _ = next(ds_iter)
  x_b, y_b = x_b[0], y_b[0]
  for i in range(NUM_STEPS):
    pfw, stats_i = network.minibatch_init(x_b[i],  y_b[i], x_b[i].shape[0], pfw=pfw)
    stats.append(stats_i)
  # update
  network.update_statistics(stats, update_perc=1.)

  print("final mean:")
  for p in tf.nest.flatten(pfw):
    print(p.shape, tf.reduce_mean(p), tf.math.reduce_std(p))



# The outer loop here uses Adam. SGD/Momentum are more stable but way slower.
trainer = tf.keras.optimizers.Adam(learning_schedule)

loss_log = []
def smoothen(l, lookback=20):
  # Pad the front with the first value so the smoothed curve keeps the same
  # length as the raw loss log.
  kernel = [1./lookback] * lookback
  return np.convolve(l[0:1] * (lookback - 1) + l, kernel, "valid")

In [ ]:
print([p.shape for p in trained_pfw])

In [ ]:
training_steps = 200000
print("Stop this block whenever after 1-2k steps. It's good even very early.")

@tf.function
def step(pfw, xts, yts, xes, yes, num_steps):
  print("compiling")
  with tf.GradientTape() as g:
    pfw_serialized = network.serialize_pfw(pfw)
    l, _, _ = training_regime.batch_mp_loss(
        pfw_serialized, xts, yts, xes, yes, num_steps, same_pfw=True)
  all_weights = network.get_trainable_weights()
  all_weights += pfw
  grads = g.gradient(l, all_weights)
  # Normalize gradient magnitudes to avoid explosions.
  grads = [gr / (tf.norm(gr) + 1e-8) for gr in grads]
  trainer.apply_gradients(zip(grads, all_weights))
  return l


import time
start_time = time.time()

for i in range(last_step + 1, last_step +1 + training_steps):
  last_step = i

  tmp_t = time.time()
  xts, yts, xes, yes = next(ds_iter)

  l = step(trained_pfw, xts, yts, xes, yes, num_steps)
  loss_log.append(l)

  if i % 50 == 0:
    print(i)
    print("--- %s seconds ---" % (time.time() - start_time))
  if i % 500 == 0:
    plt.plot(smoothen(loss_log, 100), label='mp')
    plt.yscale('log')
    #plt.ylim(0.0, 1e-1)
    plt.legend()
    plt.show()
print("--- %s seconds ---" % (time.time() - start_time))

In [ ]:
print(loss_log[-1])
plt.plot(smoothen(loss_log, 100), label='mp')
plt.yscale('log')
plt.ylim(0.0, 1e-1)
plt.gca().yaxis.grid(True)
plt.legend()
plt.show()

Proper evaluation: run 100 different few-shot instances with totally new network params.

The train loss is computed only on points that the network has already observed.

The eval loss is computed on the entire range [-5, 5].
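
In code terms, the two quantities track roughly the following (a sketch only; the evaluation cell a few cells below is the authoritative version):

In [ ]:
# Sketch of the two evaluation losses (hypothetical helpers, not part of the
# mplp library):
# - train loss: only on the x/y points revealed to the network so far.
# - eval loss: on the dense grid xrange_inputs against the true sinusoid.
def _observed_loss(pfw, xt, yt, step_i):
  x_obs = tf.reshape(xt[:step_i + 1], (-1, 1))
  y_obs = tf.reshape(yt[:step_i + 1], (-1, 1))
  pred, _ = network.forward(pfw, x_obs)
  loss, _ = network.compute_loss(pred, y_obs)
  return tf.reduce_mean(loss)

def _full_range_loss(pfw, amplitude, phase):
  targets = amplitude * np.sin(xrange_inputs + phase)
  pred, _ = network.forward(pfw, xrange_inputs)
  loss, _ = network.compute_loss(pred, targets)
  return tf.reduce_mean(loss)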


In [ ]:
!mkdir -p tmp

!ls -R tmp

In [ ]:
!mkdir -p tmp
file_path = "tmp/maml_sin_net_weights"

network.save_weights(file_path, last_step)

with open("tmp/maml_sin_prior_weights_{:08d}.npy".format(
    last_step), "wb") as fout:
  prior_to_save = tf.concat([tf.reshape(e, [-1]) for e in trained_pfw], 0)
  np.save(fout, prior_to_save.numpy())

!ls -lh tmp

In [ ]:
# try to save and load
file_path = "savedmodels/maml_sin_net_weights"
network = MPNetwork(
    [
     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n), 
     MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     MPDense(1, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
     ],
     MPL2Loss(stateful=stateful, stateful_hidden_n=stateful_hidden_n))
network.setup(in_dim=1, message_size=message_size, inner_batch_size=INNER_BATCH_SIZE)
network.load_weights(file_path)

# Load prior weights too
gfile = tf.io.gfile
matcher = "savedmodels/maml_sin_prior_weights_*.npy"
filenames = sorted(gfile.glob(matcher), reverse=True)
assert len(filenames) > 0, "No files matching {}".format(matcher)
filename = filenames[0]
print(filename)

# load array
with gfile.GFile(filename, "rb") as fin:
  serialized_weights = np.load(fin)

trained_pfw_shapes = [t.shape for t in network.init()]
trained_pfw_flat_sizes = [int(tf.reshape(t, [-1]).shape[0]) for t in network.init()]

print(serialized_weights.shape, trained_pfw_flat_sizes)
all_weights_flat_split = tf.split(serialized_weights,
                                  trained_pfw_flat_sizes)
trained_pfw = [tf.reshape(t, s) for t, s in zip(
    all_weights_flat_split, trained_pfw_shapes)]

In [ ]:
eval_tot_steps = 100

tr_losses = np.zeros([eval_tot_steps, NUM_STEPS])
ev_losses = np.zeros([eval_tot_steps, NUM_STEPS + 1]) # also 0-step.

@tf.function
def get_loss(pfw, x, y):
  predictions, _= network.forward(pfw, x)
  loss, _ = network.compute_loss(predictions, y)
  return loss

start_time = time.time()

for r in range(eval_tot_steps):
  p_fw = trained_pfw

  A, ph = ds_factory._create_task()

  targets = A * np.sin(xrange_inputs + ph)

  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)

  # initial loss.
  loss = get_loss(p_fw, xrange_inputs, targets)
  ev_losses[r, 0] = tf.reduce_mean(loss)

  for i in range(NUM_STEPS):
    p_fw, _= network.inner_update(p_fw, xt[i], yt[i])

    # loss specific to only what we observe.
    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
    loss = get_loss(p_fw, x_observed_so_far, y_observed_so_far)
    tr_losses[r, i] = tf.reduce_mean(loss)

    # Plotting for the continuous input range
    loss = get_loss(p_fw, xrange_inputs, targets)
    ev_losses[r, i + 1] = tf.reduce_mean(loss)
print("--- %s seconds ---" % (time.time() - start_time))

tr_losses_m = np.mean(tr_losses, axis=0)
ev_losses_m = np.mean(ev_losses, axis=0)

tr_losses_sd = np.std(tr_losses, axis=0)
ev_losses_sd = np.std(ev_losses, axis=0)

print("tr_l, m:", tr_losses_m, " sd:", tr_losses_sd)
print("ev_l, m:", ev_losses_m, " sd:", ev_losses_sd)

ub = [m + sd for m, sd in zip(tr_losses_m, tr_losses_sd)]
lb = [m - sd for m, sd in zip(tr_losses_m, tr_losses_sd)]
plt.fill_between(range(1, len(tr_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(tr_losses_m) + 1), tr_losses_m, label='train loss')

ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(0, len(ev_losses_m)), ub, lb, alpha=.5)
plt.plot(range(0, len(ev_losses_m)), ev_losses_m, label='eval loss')
plt.ylim(0.0, 0.025)
plt.xlabel("num steps")
plt.ylabel("L2 loss")
plt.legend()

In [ ]:
print(tr_losses_m, ev_losses_m)

In [ ]:
# @title Show an example run:

fig, axs = plt.subplots(5, 2, figsize=(10,15))

for fig_n in range(5):
  p_fw = trained_pfw

  n_plot = 5
  plot_every = 1

  predictions, _= network.forward(p_fw, xrange_inputs)

  A, ph = ds_factory._create_task()

  targets = A * np.sin(xrange_inputs + ph)
  axs[fig_n][0].plot(xrange_inputs, targets, label='target')

  predictions, _= network.forward(p_fw, xrange_inputs)
  axs[fig_n][0].plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))

  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
  xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
  tr_losses = []
  ev_losses = []

  for i in range(NUM_STEPS):
    p_fw, _ = network.inner_update(p_fw, xt[i], yt[i])

    # loss specific to only what we observe.
    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
    predictions, _= network.forward(p_fw, x_observed_so_far)
    loss, _ = network.compute_loss(predictions, y_observed_so_far)
    tr_losses.append(tf.reduce_mean(loss))

    # Plotting for the continuous input range
    predictions, _= network.forward(p_fw, xrange_inputs)
    if (i+1) % plot_every == 0:
      axs[fig_n][0].plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))
    loss, _ = network.compute_loss(predictions, targets)
    ev_losses.append(tf.reduce_mean(loss))

  axs[fig_n][1].plot(np.arange(len(tr_losses)), tr_losses, label='tr_losses')
  axs[fig_n][1].plot(np.arange(len(ev_losses)), ev_losses, label='ev_losses')


axs[0][0].legend()
axs[0][1].legend()

In [ ]:
# @title Single run for drawing.

p_fw = trained_pfw

plot_every = 1

predictions, _ = network.forward(p_fw, xrange_inputs)

A, ph = ds_factory._create_task()

targets = A * np.sin(xrange_inputs + ph)
plt.plot(xrange_inputs, targets, label='target')

predictions, _= network.forward(p_fw, xrange_inputs)
plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))

xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
tr_losses = []
ev_losses = []

for i in range(NUM_STEPS):
  p_fw, _ = network.inner_update(p_fw, xt[i], yt[i])

  # loss specific to only what we observe.
  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
  predictions, _ = network.forward(p_fw, x_observed_so_far)
  loss, _ = network.compute_loss(predictions, y_observed_so_far)
  tr_losses.append(tf.reduce_mean(loss))

  # Plotting for the continuous input range
  predictions, _= network.forward(p_fw, xrange_inputs)
  if (i+1) % plot_every == 0:
    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))
  loss, _ = network.compute_loss(predictions, targets)
  ev_losses.append(tf.reduce_mean(loss))


plt.legend()


with open("tmp/mplp_example_run.png", "wb") as fout:
  plt.savefig(fout)

Compare it with a MAML run.
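
In the MAML baseline below, each inner step adapts the fast weights directly by gradient descent on the train split (inner learning rate α = 0.05, hard-coded in update_pfw), and the meta-objective is the eval-split loss of the adapted weights, averaged over the outer batch of tasks:

$$\theta \leftarrow \theta - \alpha \, \nabla_{\theta} \mathcal{L}_{\text{train}}(\theta), \qquad \mathcal{L}_{\text{meta}} = \frac{1}{B} \sum_{i=1}^{B} \mathcal{L}^{(i)}_{\text{eval}}\big(\theta'_i\big)$$

where $B$ is the outer batch size and $\theta'_i$ are the weights after NUM_STEPS inner updates on task $i$.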


In [ ]:
maml_pfw = [tf.Variable(t) for t in network.init()]
maml_last_step = 0
maml_loss_log = []

In [ ]:
training_steps = 200000
print("Stop this block whenever after 1-2k steps. It's good even very early.")

def update_pfw(pfw, xt, yt, num_steps):
  for i in tf.range(num_steps):
    with tf.GradientTape() as g:
      g.watch(pfw)
      prediction, _ = network.forward(pfw, xt[i])
      loss, _ = network.compute_loss(prediction, yt[i])
      loss = tf.reduce_mean(loss)
    grads = g.gradient(loss, pfw)
    
    pfw = [p - 0.05 * pg for p, pg in zip(pfw, grads)]
  return pfw

def single_loss(pfw, xt, yt, xe, ye, num_steps):
  new_pfw = update_pfw(pfw, xt, yt, num_steps)

  prediction, _ = network.forward(new_pfw, xe)
  cv_loss, _ = network.compute_loss(prediction, ye)
  cv_loss = tf.reduce_mean(cv_loss)
  return cv_loss

def batch_maml_loss(pfw, xts, yts, xes, yes, num_steps):
  task_losses = []
  for i in range(len(xts)):
    task_losses.append(
        single_loss(pfw, xts[i], yts[i], xes[i], yes[i], num_steps))
  return tf.reduce_mean(tf.stack(task_losses))

@tf.function
def maml_step(pfw, xts, yts, xes, yes, num_steps):
  print("compiling")
  with tf.GradientTape() as g:
    l = batch_maml_loss(pfw, xts, yts, xes, yes, num_steps)
  grads = g.gradient(l, pfw)
  # Normalize gradient magnitudes to avoid explosions.
  grads = [gr / (tf.norm(gr) + 1e-8) for gr in grads]
  trainer.apply_gradients(zip(grads, pfw))
  return l


import time
start_time = time.time()

for i in range(maml_last_step + 1, maml_last_step +1 + training_steps):
  maml_last_step = i

  tmp_t = time.time()
  xts, yts, xes, yes = next(ds_iter)

  l = maml_step(maml_pfw, xts, yts, xes, yes, num_steps)
  maml_loss_log.append(l)

  if i % 50 == 0:
    print(i)
    print("--- %s seconds ---" % (time.time() - start_time))
  if i % 500 == 0:
    plt.plot(smoothen(maml_loss_log, 100), label='maml')
    plt.yscale('log')
    plt.legend()
    plt.show()
print("--- %s seconds ---" % (time.time() - start_time))

In [ ]:
eval_tot_steps = 100

tr_losses = np.zeros([eval_tot_steps, NUM_STEPS])
ev_losses = np.zeros([eval_tot_steps, NUM_STEPS + 1]) # also 0 step.

for r in range(eval_tot_steps):
  # We need to transform these into variables.
  p_fw = maml_pfw

  A, ph = ds_factory._create_task()

  targets = A * np.sin(xrange_inputs + ph)

  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
  xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)

  # initial loss.
  predictions, _= network.forward(p_fw, xrange_inputs)
  loss, _ = network.compute_loss(predictions, targets)
  ev_losses[r, 0] = tf.reduce_mean(loss)

  for i in range(NUM_STEPS):
    p_fw = update_pfw(p_fw, xt[i:i+1], yt[i:i+1], num_steps=1)

    # loss specific to only what we observe.
    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
    predictions, _= network.forward(p_fw, x_observed_so_far)
    loss, _ = network.compute_loss(predictions, y_observed_so_far)
    tr_losses[r, i] = tf.reduce_mean(loss)

    # Plotting for the continuous input range
    predictions, _= network.forward(p_fw, xrange_inputs)
    loss, _ = network.compute_loss(predictions, targets)
    ev_losses[r, i + 1] = tf.reduce_mean(loss)

tr_losses_m = np.mean(tr_losses, axis=0)
ev_losses_m = np.mean(ev_losses, axis=0)
tr_losses_sd = np.std(tr_losses, axis=0)
ev_losses_sd = np.std(ev_losses, axis=0)
print("tr_l, m:", tr_losses_m, " sd:", tr_losses_sd)
print("ev_l, m:", ev_losses_m, " sd:", ev_losses_sd)

ub = [m + sd for m, sd in zip(tr_losses_m, tr_losses_sd)]
lb = [m - sd for m, sd in zip(tr_losses_m, tr_losses_sd)]
plt.fill_between(range(1, len(tr_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(tr_losses_m) + 1), tr_losses_m, label='train loss')

ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(0, len(ev_losses_m)), ub, lb, alpha=.5)
plt.plot(range(0, len(ev_losses_m)), ev_losses_m, label='eval loss')
plt.ylim(0.0, 0.04)
plt.xlabel("num steps")
plt.ylabel("L2 loss")
plt.legend()

In [ ]:
# Same task as MPLP, for the same drawing.

p_fw = maml_pfw

plot_every = 1

predictions, _ = network.forward(p_fw, xrange_inputs)

targets = A * np.sin(xrange_inputs + ph)
plt.plot(xrange_inputs, targets, label='target')

predictions, _= network.forward(p_fw, xrange_inputs)
plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))

xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
tr_losses = []
ev_losses = []

for i in range(NUM_STEPS):
  p_fw = update_pfw(p_fw, xt[i:i+1], yt[i:i+1], num_steps=1)

  # loss specific to only what we observe.
  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
  predictions, _= network.forward(p_fw, x_observed_so_far)
  loss, _ = network.compute_loss(predictions, y_observed_so_far)
  tr_losses.append(tf.reduce_mean(loss))

  # Plotting for the continuous input range
  predictions, _= network.forward(p_fw, xrange_inputs)
  if (i+1) % plot_every == 0:
    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))
  loss, _ = network.compute_loss(predictions, targets)
  ev_losses.append(tf.reduce_mean(loss))

plt.legend()

with open("tmp/maml_example_run.png", "wb") as fout:
  plt.savefig(fout)

In [ ]:
# @title Show an example run:
n_plot = 5
plot_every = max(1, NUM_STEPS // n_plot)

p_fw = maml_pfw
predictions, _ = network.forward(p_fw, xrange_inputs)

A, ph = ds_factory._create_task()

targets = A * np.sin(xrange_inputs + ph)
plt.plot(xrange_inputs, targets, label='target')

predictions, _= network.forward(p_fw, xrange_inputs)
plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))

xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
tr_losses = []
ev_losses = []

for i in range(NUM_STEPS):
  p_fw = update_pfw(p_fw, xt[i:i+1], yt[i:i+1], num_steps=1)

  # loss specific to only what we observe.
  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
  predictions, _= network.forward(p_fw, x_observed_so_far)
  loss, _ = network.compute_loss(predictions, y_observed_so_far)
  tr_losses.append(tf.reduce_mean(loss))

  # Plotting for the continuous input range
  predictions, _= network.forward(p_fw, xrange_inputs)
  if (i+1) % plot_every == 0:
    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))
  loss, _ = network.compute_loss(predictions, targets)
  ev_losses.append(tf.reduce_mean(loss))

plt.legend()
plt.show()

plt.plot(np.arange(len(tr_losses)), tr_losses, label='tr_losses')
plt.plot(np.arange(len(ev_losses)), ev_losses, label='ev_losses')
plt.legend()
plt.show()