This notebook uses pre-trained models to explore what can be done with sketch-RNN.

It should be placed inside the "sketch_rnn" directory that is bundled with the magenta git repository, so you'll need to git clone that repository into your working environment if you haven't already.

Note! For this to work properly, you'll need a dataset formatted as an .npz file and a pre-trained model (which you can create from the .npz file using the sketch_rnn_train.ipynb notebook or by running sketch_rnn_train.py on its own).

Note! You will need to create a datasets/ folder and a models/ folder in the same directory as this notebook (i.e. sketch_rnn/). Place the .npz file into datasets/ and the model files (checkpoint, model_config.json, and the various vector files) into a subdirectory of the models/ folder.

Note that this script can only read two levels into the models/ directory: i.e. models/sheep will work as a model directory, but models/sheep/layer_norm throws an error. Go figure!
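If you want a quick sanity check that the layout is in place before running anything, a sketch like the following can help. The "sheep" directory name and the .npz filename are hypothetical - substitute whatever your own dataset and model directory are called:

import os

# hypothetical layout check - adjust 'sheep' and the .npz name to your own files
expected = [
  'datasets/sheep.npz',              # your raw dataset
  'models/sheep/checkpoint',         # written during training
  'models/sheep/model_config.json',  # hyperparameters used for training
]
for path in expected:
  print('%s: %s' % (path, 'found' if os.path.exists(path) else 'MISSING'))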


In [1]:
# import the required libraries
import numpy as np
import time
import random
import cPickle
import codecs
import collections
import os
import math
import json
import tensorflow as tf
from six.moves import xrange

# libraries required for visualisation:
from IPython.display import SVG, display
import svgwrite # conda install -c omnia svgwrite=1.1.6
import PIL
from PIL import Image
import matplotlib.pyplot as plt

# set numpy output to something sensible
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

# tells which version of tensorflow is being used
tf.logging.info("TensorFlow Version: %s", tf.__version__)

In [3]:
# import command line tools
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *

In [4]:
# this function displays vector images and saves them to .svg
# you can invoke the "draw_strokes" function anytime you want to render an image -
# specify the source strokes, destination filename, and scale factor (defaults below)

def draw_strokes(data, svg_filename = 'sample.svg', factor=0.2):
  tf.gfile.MakeDirs(os.path.dirname(svg_filename))
  min_x, max_x, min_y, max_y = get_bounds(data, factor)
  dims = (50 + max_x - min_x, 50 + max_y - min_y)
  dwg = svgwrite.Drawing(svg_filename, size=dims)
  dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))
  lift_pen = 1
  abs_x = 25 - min_x 
  abs_y = 25 - min_y
  p = "M%s,%s " % (abs_x, abs_y)
  command = "m"
  # each row of data is (dx, dy, pen_lift); a pen lift starts a new SVG sub-path ("m"),
  # the next point begins a line ("l"), and later points continue it implicitly ("")
  for i in xrange(len(data)):
    if (lift_pen == 1):
      command = "m"
    elif (command != "l"):
      command = "l"
    else:
      command = ""
    x = float(data[i, 0]) / factor
    y = float(data[i, 1]) / factor
    lift_pen = data[i, 2]
    p += command + str(x) + "," + str(y) + " "
  the_color = "black"
  stroke_width = 1
  dwg.add(dwg.path(p).stroke(the_color,stroke_width).fill("none"))
  dwg.save()
  display(SVG(dwg.tostring()))

# generate a 2D grid of many vector drawings
# expects s_list as a list of [stroke_array, [row, col]] pairs
def make_grid_svg(s_list, grid_space=10.0, grid_space_x=16.0):
  def get_start_and_end(x):
    x = np.array(x)
    x = x[:, 0:2]
    x_start = x[0]
    x_end = x.sum(axis=0)
    x = x.cumsum(axis=0)
    x_max = x.max(axis=0)
    x_min = x.min(axis=0)
    center_loc = (x_max+x_min)*0.5
    return x_start-center_loc, x_end
  x_pos = 0.0
  y_pos = 0.0
  result = [[x_pos, y_pos, 1]]
  for sample in s_list:
    s = sample[0]
    grid_loc = sample[1]
    grid_y = grid_loc[0]*grid_space+grid_space*0.5
    grid_x = grid_loc[1]*grid_space_x+grid_space_x*0.5
    start_loc, delta_pos = get_start_and_end(s)

    loc_x = start_loc[0]
    loc_y = start_loc[1]
    new_x_pos = grid_x+loc_x
    new_y_pos = grid_y+loc_y
    result.append([new_x_pos-x_pos, new_y_pos-y_pos, 0])

    result += s.tolist()
    result[-1][2] = 1
    x_pos = new_x_pos+delta_pos[0]
    y_pos = new_y_pos+delta_pos[1]
  return np.array(result)
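For reference, make_grid_svg expects s_list to be a list of [stroke_array, [row, column]] pairs. Here is a minimal, self-contained illustration using a hypothetical dummy triangle, just to show the expected format (the coordinates are made up and not from any dataset):

# hypothetical demo: place the same dummy triangle at three grid positions
# each stroke row is (dx, dy, pen_lift), matching what draw_strokes expects
triangle = np.array([[2.0, 0.0, 0], [-1.0, 1.5, 0], [-1.0, -1.5, 1]])
demo_grid = make_grid_svg([[triangle, [0, 0]], [triangle, [0, 1]], [triangle, [1, 0]]])
draw_strokes(demo_grid, 'grid-demo.svg')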

In [5]:
# these global variables define the relative path to the pre-trained model 
# and original dataset 

data_dir = 'datasets/' # this is where your .npz file lives
models_root_dir = 'models/' # this is where trained models live
model_dir = 'models/sheep' # change "sheep" to whatever name you like

# note! you must create the "datasets/" and "models/" folders 
# in the sketch_rnn directory before running this.

# you will also need to place model files (generated in training process)
# into the "models/sheep" (or whatever you've named it) folder -
# i.e. a checkpoint file, a model_config.json file, and vector data, index, 
# and meta files too.

# note! model_dir value only handles two levels of recursion (i.e. models/sheep)
# subfolders break the next step (i.e. you can't do models/sheep/layer_norm)

In [6]:
# loads the dataset (train/valid/test splits) and the model hyperparameters from the paths defined above
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env(data_dir, model_dir)


INFO:tensorflow:model_params.max_seq_len 250.
total images <= max_seq_len is 7400
total images <= max_seq_len is 300
total images <= max_seq_len is 300
INFO:tensorflow:normalizing_scale_factor 18.5198.

In [7]:
#construct the sketch-rnn model:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)


INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = True.

In [9]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [10]:
# loads the weights from checkpoint into our model
load_checkpoint(sess, model_dir)


INFO:tensorflow:Loading model /home/lfischbeck/magenta/magenta/models/sketch_rnn/models/sheep/vector-20.
INFO:tensorflow:Restoring parameters from /home/lfischbeck/magenta/magenta/models/sketch_rnn/models/sheep/vector-20

Encode and Decode Sample Drawings

First, define two convenience functions: one to encode a drawing (a sequence of strokes) into a latent vector $z$, and one to decode a latent vector back into strokes:


In [13]:
def encode(input_strokes):
  # convert from stroke-3 to stroke-5 format (dx, dy, pen_down, pen_up, end_of_drawing)
  # and prepend the start-of-sequence token used by sketch-rnn
  strokes = to_big_strokes(input_strokes).tolist()
  strokes.insert(0, [0, 0, 1, 0, 0])
  seq_len = [len(input_strokes)]
  draw_strokes(to_normal_strokes(np.array(strokes)))  # render what is being encoded
  # run the encoder and return the latent vector z for this drawing
  return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]

In [14]:
def decode(z_input=None, temperature=0.1, factor=0.2):
  # if no latent vector is supplied, the decoder samples unconditionally
  # (note: the factor argument is accepted but not used in this version)
  z = None
  if z_input is not None:
    z = [z_input]
  sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
  # convert the sampled stroke-5 sequence back to stroke-3 format for drawing
  strokes = to_normal_strokes(sample_strokes)
  return strokes

In [15]:
# get a sample drawing from the test set, and render it to .svg
example_drawing = test_set.random_sample()
draw_strokes(example_drawing)



In [16]:
#encode the sample drawing into latent vector z
z = encode(example_drawing)



In [15]:
# convert z back to drawing, using a "temperature" of 0.1
decoded_drawing = decode(z, temperature=0.1) 
draw_strokes(decoded_drawing, 'sample3.svg', 0.2) 
# specify the input strokes, the filename to save to (in the same directory as this notebook), and the scale factor (default is 0.2)


Temperature Interpolation


In [16]:
# create a series of drawings, stepping the sampling "temperature" from 0.1 to 1.0
# (lower temperatures give more conservative, deterministic strokes; higher temperatures give more variety)
stroke_list = []
for i in range(10):
  stroke_list.append([decode(z, temperature=0.1*i+0.1), [0, i]])
stroke_grid = make_grid_svg(stroke_list)
draw_strokes(stroke_grid, 'sample-interp-temp.svg') # if two arguments are given to draw_strokes, they are the input strokes and the output filename


Latent Space Interpolation

Stepping through the latent space between two sample drawings, encoded as $z_0$ and $z_1$.


In [17]:
#z0 is the first sample
z0 = z  # use the random sample we'd already selected
decoded_drawing = decode(z0)
# each time the latent vector is decoded, the result comes out slightly different (sampling is stochastic)
draw_strokes(decoded_drawing)
# uses the default file destination of 'sample.svg' and the default scale factor of 0.2



In [18]:
#z1 is the second sample
z1 = encode(test_set.random_sample()) #grab a new random sample and encode it
decoded_drawing2 = decode(z1) #then decode it 
draw_strokes(decoded_drawing2)
# the top drawing is the original sample (rendered inside encode), the bottom is its decoded reconstruction



In [19]:
z_list = [] # interpolate spherically between z0 and z1
N = 10 # change this number to add more steps
for t in np.linspace(0, 1, N):
  z_list.append(slerp(z0, z1, t))
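The slerp function used above comes from magenta's utils via the wildcard imports earlier in the notebook. For reference, standard spherical linear interpolation looks roughly like this sketch (an illustrative reimplementation only, under the assumption that magenta's version follows the usual formula):

# illustrative sketch of standard spherical linear interpolation (slerp)
def slerp_sketch(p0, p1, t):
  # angle between the two (normalised) vectors
  omega = np.arccos(np.dot(p0 / np.linalg.norm(p0), p1 / np.linalg.norm(p1)))
  so = np.sin(omega)
  # interpolate along the great circle between p0 and p1
  return np.sin((1.0 - t) * omega) / so * p0 + np.sin(t * omega) / so * p1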

In [20]:
# for every latent vector in z_list, sample a vector image
reconstructions = []
for i in range(N):
  reconstructions.append([decode(z_list[i]), [0, i]])

In [21]:
#draw the interpolation steps
stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid, 'sample-interp1.svg')


Unconditional (Decoder-Only) Generation


In [22]:
# switch to a decoder-only (unconditional) model
# note: this path assumes a pre-trained flamingo model has already been placed under /tmp/sketch_rnn/models
model_dir = '/tmp/sketch_rnn/models/flamingo/lstm_uncond'

In [23]:
[hps_model, eval_hps_model, sample_hps_model] = load_model(model_dir)

In [24]:
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)


INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = False.
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = False.
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = False.

In [25]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [26]:
# loads the weights from checkpoint into our model
load_checkpoint(sess, model_dir)


INFO:tensorflow:Loading model /tmp/sketch_rnn/models/flamingo/lstm_uncond/vector.
INFO:tensorflow:Restoring parameters from /tmp/sketch_rnn/models/flamingo/lstm_uncond/vector
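Since this load-rebuild-restore pattern is repeated for each pre-trained model below, you could wrap it in a small helper like the hypothetical one sketched here (load_pretrained is not part of magenta - it simply mirrors the steps already shown above):

# hypothetical convenience wrapper mirroring the load_model / reset_graph /
# load_checkpoint steps above; returns everything the decode() helper needs
def load_pretrained(model_dir):
  [hps, eval_hps, sample_hps] = load_model(model_dir)
  reset_graph()
  model = Model(hps)
  eval_model = Model(eval_hps, reuse=True)
  sample_model = Model(sample_hps, reuse=True)
  sess = tf.InteractiveSession()
  sess.run(tf.global_variables_initializer())
  load_checkpoint(sess, model_dir)
  return sess, model, eval_model, sample_model

# usage (reassign the notebook globals so decode() picks up the new model):
# sess, model, eval_model, sample_model = load_pretrained('/tmp/sketch_rnn/models/owl/lstm')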

In [28]:
# unconditionally generate 10 random examples (decode is called without a latent vector)
N = 10
reconstructions = []
for i in range(N):
  reconstructions.append([decode(temperature=0.1), [0, i]])
#experiment with different temperature values to get more variety

In [29]:
#draw 10 examples
stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid)


Generate Sketches Using Random IID Gaussian Latent Vectors


In [30]:
#other models available:
#model_dir = '/tmp/sketch_rnn/models/owl/lstm'
#model_dir = '/tmp/sketch_rnn/models/catbus/lstm'
model_dir = '/tmp/sketch_rnn/models/elephantpig/lstm'

In [31]:
[hps_model, eval_hps_model, sample_hps_model] = load_model(model_dir)
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# loads the weights from checkpoint into our model
load_checkpoint(sess, model_dir)


INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = False.
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = False.
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = False.
INFO:tensorflow:Loading model /tmp/sketch_rnn/models/elephantpig/lstm/vector.
INFO:tensorflow:Restoring parameters from /tmp/sketch_rnn/models/elephantpig/lstm/vector

In [32]:
# sample a random latent vector z_0 from the unit Gaussian prior
z_0 = np.random.randn(eval_model.hps.z_size)
drawing_0 = decode(z_0)
draw_strokes(drawing_0)



In [33]:
# sample a second random latent vector z_1
z_1 = np.random.randn(eval_model.hps.z_size)
drawing_1 = decode(z_1)
draw_strokes(drawing_1)



In [34]:
z_list = [] # interpolate spherically between z_0 and z_1
N = 10
for t in np.linspace(0, 1, N):
  z_list.append(slerp(z_0, z_1, t))
# for every latent vector in z_list, sample a vector image
reconstructions = []
for i in range(N):
  reconstructions.append([decode(z_list[i], temperature=0.1), [0, i]])

In [35]:
#draw the interpolation
stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid, 'sample-interp2.svg')