In [20]:
!pip install svgwrite==1.1.6;
!pip uninstall -y tensorflow;
!pip uninstall -y magenta;
!pip install tensorflow==1.2.1 magenta;

import numpy as np
import time
import random
import cPickle
import codecs
import collections
import os
import math
import json
import tensorflow as tf
from six.moves import xrange

# libraries required for visualisation:
from IPython.display import SVG, display
import svgwrite # conda install -c omnia svgwrite=1.1.6
import PIL
from PIL import Image
import matplotlib.pyplot as plt


from google.colab import files


# set numpy output to something sensible
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)


Requirement already satisfied: svgwrite==1.1.6 in /usr/local/lib/python2.7/dist-packages (1.1.6)
Requirement already satisfied: pyparsing>=2.0.1 in /usr/local/lib/python2.7/dist-packages (from svgwrite==1.1.6) (2.2.0)
magenta 0.3.9 has requirement tensorflow>=1.8.0, but you'll have tensorflow 1.2.1 which is incompatible.
Uninstalling tensorflow-1.2.1:
  Successfully uninstalled tensorflow-1.2.1
Uninstalling magenta-0.3.9:
  Successfully uninstalled magenta-0.3.9
Collecting tensorflow==1.2.1
  Using cached https://files.pythonhosted.org/packages/35/9c/1c353f584f90769e7fa41fe5c45ffc8a3032221945e142938e74f5498cb9/tensorflow-1.2.1-cp27-cp27mu-manylinux1_x86_64.whl
Collecting magenta
  Using cached https://files.pythonhosted.org/packages/ef/39/96c9aee9e10d29b339ea0330c1440e79742c05159cf1ac4c139ee9db0821/magenta-0.3.9-py2.py3-none-any.whl
Requirement already satisfied: numpy>=1.11.0 in /usr/local/lib/python2.7/dist-packages (from tensorflow==1.2.1) (1.14.5)
Requirement already satisfied: mock>=2.0.0 in /usr/local/lib/python2.7/dist-packages (from tensorflow==1.2.1) (2.0.0)
Requirement already satisfied: protobuf>=3.2.0 in /usr/local/lib/python2.7/dist-packages (from tensorflow==1.2.1) (3.6.0)
Requirement already satisfied: backports.weakref==1.0rc1 in /usr/local/lib/python2.7/dist-packages (from tensorflow==1.2.1) (1.0rc1)
Requirement already satisfied: bleach==1.5.0 in /usr/local/lib/python2.7/dist-packages (from tensorflow==1.2.1) (1.5.0)
Requirement already satisfied: wheel in /usr/local/lib/python2.7/dist-packages (from tensorflow==1.2.1) (0.31.1)
Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python2.7/dist-packages (from tensorflow==1.2.1) (1.11.0)
Requirement already satisfied: html5lib==0.9999999 in /usr/local/lib/python2.7/dist-packages (from tensorflow==1.2.1) (0.9999999)
Requirement already satisfied: werkzeug>=0.11.10 in /usr/local/lib/python2.7/dist-packages (from tensorflow==1.2.1) (0.14.1)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python2.7/dist-packages (from tensorflow==1.2.1) (2.6.11)
Requirement already satisfied: joblib<0.12 in /usr/local/lib/python2.7/dist-packages (from magenta) (0.11)
Requirement already satisfied: pandas>=0.18.1 in /usr/local/lib/python2.7/dist-packages (from magenta) (0.22.0)
Requirement already satisfied: scipy>=0.18.1 in /usr/local/lib/python2.7/dist-packages (from magenta) (0.19.1)
Requirement already satisfied: python-rtmidi in /usr/local/lib/python2.7/dist-packages (from magenta) (1.1.0)
Requirement already satisfied: pretty-midi>=0.2.6 in /usr/local/lib/python2.7/dist-packages (from magenta) (0.2.8)
Requirement already satisfied: matplotlib>=1.5.3 in /usr/local/lib/python2.7/dist-packages (from magenta) (2.1.2)
Requirement already satisfied: intervaltree>=2.1.0 in /usr/local/lib/python2.7/dist-packages (from magenta) (2.1.0)
Requirement already satisfied: futures; python_version == "2.7" in /usr/local/lib/python2.7/dist-packages (from magenta) (3.2.0)
Requirement already satisfied: IPython in /usr/local/lib/python2.7/dist-packages (from magenta) (5.5.0)
Requirement already satisfied: mido==1.2.6 in /usr/local/lib/python2.7/dist-packages (from magenta) (1.2.6)
Requirement already satisfied: bokeh>=0.12.0 in /usr/local/lib/python2.7/dist-packages (from magenta) (0.13.0)
Requirement already satisfied: mir-eval>=0.4 in /usr/local/lib/python2.7/dist-packages (from magenta) (0.4)
Requirement already satisfied: Pillow>=3.4.2 in /usr/local/lib/python2.7/dist-packages (from magenta) (4.0.0)
Requirement already satisfied: librosa>=0.6.0 in /usr/local/lib/python2.7/dist-packages (from magenta) (0.6.1)
Requirement already satisfied: funcsigs>=1; python_version < "3.3" in /usr/local/lib/python2.7/dist-packages (from mock>=2.0.0->tensorflow==1.2.1) (1.0.2)
Requirement already satisfied: pbr>=0.11 in /usr/local/lib/python2.7/dist-packages (from mock>=2.0.0->tensorflow==1.2.1) (4.1.0)
Requirement already satisfied: setuptools in /usr/local/lib/python2.7/dist-packages (from protobuf>=3.2.0->tensorflow==1.2.1) (39.1.0)
Requirement already satisfied: pytz>=2011k in /usr/local/lib/python2.7/dist-packages (from pandas>=0.18.1->magenta) (2018.5)
Requirement already satisfied: python-dateutil in /usr/local/lib/python2.7/dist-packages (from pandas>=0.18.1->magenta) (2.5.3)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python2.7/dist-packages (from matplotlib>=1.5.3->magenta) (0.10.0)
Requirement already satisfied: backports.functools-lru-cache in /usr/local/lib/python2.7/dist-packages (from matplotlib>=1.5.3->magenta) (1.5)
Requirement already satisfied: subprocess32 in /usr/local/lib/python2.7/dist-packages (from matplotlib>=1.5.3->magenta) (3.5.2)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python2.7/dist-packages (from matplotlib>=1.5.3->magenta) (2.2.0)
Requirement already satisfied: sortedcontainers in /usr/local/lib/python2.7/dist-packages (from intervaltree>=2.1.0->magenta) (2.0.4)
Requirement already satisfied: simplegeneric>0.8 in /usr/local/lib/python2.7/dist-packages (from IPython->magenta) (0.8.1)
Requirement already satisfied: pickleshare in /usr/local/lib/python2.7/dist-packages (from IPython->magenta) (0.7.4)
Requirement already satisfied: backports.shutil-get-terminal-size; python_version == "2.7" in /usr/local/lib/python2.7/dist-packages (from IPython->magenta) (1.0.0)
Requirement already satisfied: pathlib2; python_version == "2.7" or python_version == "3.3" in /usr/local/lib/python2.7/dist-packages (from IPython->magenta) (2.3.2)
Requirement already satisfied: pexpect; sys_platform != "win32" in /usr/local/lib/python2.7/dist-packages (from IPython->magenta) (4.6.0)
Requirement already satisfied: traitlets>=4.2 in /usr/local/lib/python2.7/dist-packages (from IPython->magenta) (4.3.2)
Requirement already satisfied: pygments in /usr/local/lib/python2.7/dist-packages (from IPython->magenta) (2.1.3)
Requirement already satisfied: decorator in /usr/local/lib/python2.7/dist-packages (from IPython->magenta) (4.3.0)
Requirement already satisfied: prompt-toolkit<2.0.0,>=1.0.4 in /usr/local/lib/python2.7/dist-packages (from IPython->magenta) (1.0.15)
Requirement already satisfied: packaging>=16.8 in /usr/local/lib/python2.7/dist-packages (from bokeh>=0.12.0->magenta) (17.1)
Requirement already satisfied: PyYAML>=3.10 in /usr/local/lib/python2.7/dist-packages (from bokeh>=0.12.0->magenta) (3.13)
Requirement already satisfied: Jinja2>=2.7 in /usr/local/lib/python2.7/dist-packages (from bokeh>=0.12.0->magenta) (2.10)
Requirement already satisfied: tornado>=4.3 in /usr/local/lib/python2.7/dist-packages (from bokeh>=0.12.0->magenta) (4.5.3)
Requirement already satisfied: future in /usr/local/lib/python2.7/dist-packages (from mir-eval>=0.4->magenta) (0.16.0)
Requirement already satisfied: olefile in /usr/local/lib/python2.7/dist-packages (from Pillow>=3.4.2->magenta) (0.45.1)
Requirement already satisfied: audioread>=2.0.0 in /usr/local/lib/python2.7/dist-packages (from librosa>=0.6.0->magenta) (2.1.6)
Requirement already satisfied: numba>=0.38.0 in /usr/local/lib/python2.7/dist-packages (from librosa>=0.6.0->magenta) (0.39.0)
Requirement already satisfied: scikit-learn!=0.19.0,>=0.14.0 in /usr/local/lib/python2.7/dist-packages (from librosa>=0.6.0->magenta) (0.19.2)
Requirement already satisfied: resampy>=0.2.0 in /usr/local/lib/python2.7/dist-packages (from librosa>=0.6.0->magenta) (0.2.1)
Requirement already satisfied: scandir; python_version < "3.5" in /usr/local/lib/python2.7/dist-packages (from pathlib2; python_version == "2.7" or python_version == "3.3"->IPython->magenta) (1.7)
Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python2.7/dist-packages (from pexpect; sys_platform != "win32"->IPython->magenta) (0.6.0)
Requirement already satisfied: enum34; python_version == "2.7" in /usr/local/lib/python2.7/dist-packages (from traitlets>=4.2->IPython->magenta) (1.1.6)
Requirement already satisfied: ipython-genutils in /usr/local/lib/python2.7/dist-packages (from traitlets>=4.2->IPython->magenta) (0.2.0)
Requirement already satisfied: wcwidth in /usr/local/lib/python2.7/dist-packages (from prompt-toolkit<2.0.0,>=1.0.4->IPython->magenta) (0.1.7)
Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python2.7/dist-packages (from Jinja2>=2.7->bokeh>=0.12.0->magenta) (1.0)
Requirement already satisfied: singledispatch in /usr/local/lib/python2.7/dist-packages (from tornado>=4.3->bokeh>=0.12.0->magenta) (3.4.0.3)
Requirement already satisfied: certifi in /usr/local/lib/python2.7/dist-packages (from tornado>=4.3->bokeh>=0.12.0->magenta) (2018.4.16)
Requirement already satisfied: backports_abc>=0.4 in /usr/local/lib/python2.7/dist-packages (from tornado>=4.3->bokeh>=0.12.0->magenta) (0.5)
Requirement already satisfied: llvmlite>=0.24.0dev0 in /usr/local/lib/python2.7/dist-packages (from numba>=0.38.0->librosa>=0.6.0->magenta) (0.24.0)
magenta 0.3.9 has requirement tensorflow>=1.8.0, but you'll have tensorflow 1.2.1 which is incompatible.
Installing collected packages: tensorflow, magenta
Successfully installed magenta-0.3.9 tensorflow-1.2.1

In [21]:
tf.logging.info("TensorFlow Version: %s", tf.__version__)


INFO:tensorflow:TensorFlow Version: 1.2.1

In [0]:
#!kill -9 -1

In [0]:
# import our command line tools
#!pip uninstall -y magenta;
#!pip install magenta;

from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *

In [24]:
tf.logging.info("TensorFlow Version: %s", tf.__version__)


INFO:tensorflow:TensorFlow Version: 1.2.1

In [0]:
# little function that displays vector images and saves them to .svg
count = 0
def draw_strokes(data, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):
  global count
  tf.gfile.MakeDirs(os.path.dirname(svg_filename))
  min_x, max_x, min_y, max_y = get_bounds(data, factor)
  dims = (50 + max_x - min_x, 50 + max_y - min_y)
  dwg = svgwrite.Drawing(svg_filename, size=dims)
  dwg.add(dwg.rect(insert=(0, 0), size=dims,fill='white'))
  lift_pen = 1
  abs_x = 25 - min_x 
  abs_y = 25 - min_y
  p = "M%s,%s " % (abs_x, abs_y)
  command = "m"
  for i in xrange(len(data)):
    if (lift_pen == 1):
      command = "m"
    elif (command != "l"):
      command = "l"
    else:
      command = ""
    x = float(data[i,0])/factor
    y = float(data[i,1])/factor
    lift_pen = data[i, 2]
    p += command+str(x)+","+str(y)+" "
  the_color = "black"
  stroke_width = 1
  dwg.add(dwg.path(p).stroke(the_color,stroke_width).fill("none"))
  dwg.save()
  #dwg.write(f)
  #f.write(SVG(dwg.tostring()))
  dwg.save()
  files.download(svg_filename)
  count += 1
  display(SVG(dwg.tostring()))

# generate a 2D grid of many vector drawings
def make_grid_svg(s_list, grid_space=10.0, grid_space_x=16.0):
  def get_start_and_end(x):
    x = np.array(x)
    x = x[:, 0:2]
    x_start = x[0]
    x_end = x.sum(axis=0)
    x = x.cumsum(axis=0)
    x_max = x.max(axis=0)
    x_min = x.min(axis=0)
    center_loc = (x_max+x_min)*0.5
    return x_start-center_loc, x_end
  x_pos = 0.0
  y_pos = 0.0
  result = [[x_pos, y_pos, 1]]
  for sample in s_list:
    s = sample[0]
    grid_loc = sample[1]
    grid_y = grid_loc[0]*grid_space+grid_space*0.5
    grid_x = grid_loc[1]*grid_space_x+grid_space_x*0.5
    start_loc, delta_pos = get_start_and_end(s)

    loc_x = start_loc[0]
    loc_y = start_loc[1]
    new_x_pos = grid_x+loc_x
    new_y_pos = grid_y+loc_y
    result.append([new_x_pos-x_pos, new_y_pos-y_pos, 0])

    result += s.tolist()
    result[-1][2] = 1
    x_pos = new_x_pos+delta_pos[0]
    y_pos = new_y_pos+delta_pos[1]
  return np.array(result)

In [41]:
!mkdir sketch_rnn;
!mkdir sketch_rnn/models;
!mkdir sketch_rnn/models/aaron_sheep;
!mkdir sketch_rnn/models/aaron_sheep/layer_norm;


data_dir = 'http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron_sheep/'
models_root_dir = './sketch_rnn/models'
model_dir = './sketch_rnn/models/aaron_sheep/layer_norm'


mkdir: cannot create directory ‘sketch_rnn’: File exists
mkdir: cannot create directory ‘sketch_rnn/models’: File exists
mkdir: cannot create directory ‘sketch_rnn/models/aaron_sheep’: File exists
mkdir: cannot create directory ‘sketch_rnn/models/aaron_sheep/layer_norm’: File exists

In [42]:
download_pretrained_models(models_root_dir=models_root_dir)


INFO:tensorflow:./sketch_rnn/models/sketch_rnn.zip already exists, using cached copy
INFO:tensorflow:Unzipping ./sketch_rnn/models/sketch_rnn.zip...
INFO:tensorflow:Unzipping complete.

In [43]:
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env(data_dir, model_dir)


INFO:tensorflow:Downloading http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron_sheep/aaron_sheep.npz
INFO:tensorflow:Loaded 7400/300/300 from aaron_sheep.npz
INFO:tensorflow:Dataset combined: 8000 (7400/300/300), avg len 125
INFO:tensorflow:model_params.max_seq_len 250.
total images <= max_seq_len is 7400
total images <= max_seq_len is 300
total images <= max_seq_len is 300
INFO:tensorflow:normalizing_scale_factor 18.5198.

In [44]:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)


INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = 0.
INFO:tensorflow:Output dropout mode = 0.
INFO:tensorflow:Recurrent dropout mode = 1.
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = 0.
INFO:tensorflow:Output dropout mode = 0.
INFO:tensorflow:Recurrent dropout mode = 0.
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = 0.
INFO:tensorflow:Output dropout mode = 0.
INFO:tensorflow:Recurrent dropout mode = 0.

In [0]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [46]:
load_checkpoint(sess, model_dir)


INFO:tensorflow:Loading model ./sketch_rnn/models/aaron_sheep/layer_norm/vector.
INFO:tensorflow:Restoring parameters from ./sketch_rnn/models/aaron_sheep/layer_norm/vector

In [0]:
def encode(input_strokes):
  strokes = to_big_strokes(input_strokes).tolist()
  strokes.insert(0, [0, 0, 1, 0, 0])
  seq_len = [len(input_strokes)]
  draw_strokes(to_normal_strokes(np.array(strokes)))
  return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]

In [0]:
def decode(z_input=None, draw_mode=True, temperature=0.1, factor=0.2):
  z = None
  if z_input is not None:
    z = [z_input]
  sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
  strokes = to_normal_strokes(sample_strokes)
  if draw_mode:
    draw_strokes(strokes, factor)
  return strokes

In [49]:
stroke = test_set.random_sample()
draw_strokes(stroke)



In [50]:
z = encode(stroke)



In [52]:
_ = decode(z, temperature=0.9)



In [54]:
stroke_list = []
for i in range(10):
  stroke_list.append([decode(z, draw_mode=False, temperature=0.05*i+0.1), [0, i]])
stroke_grid = make_grid_svg(stroke_list)
draw_strokes(stroke_grid)



In [0]: