This notebook just runs the script "sketch_rnn_train.py", which is included in Magenta's sketch_rnn repository.

To use it, first place the notebook file inside the sketch_rnn directory.

Make new folders "datasets" and "models" inside the sketch_rnn directory.

Make a subfolder inside each of these two directories for every dataset you want to train on.
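
If you'd rather not create the folders by hand, something like the sketch below should do it. The helper name make_training_dirs and the dataset name "aaron_sheep" are just placeholders for illustration, and it assumes you run it from inside the sketch_rnn directory.

```python
import os


def make_training_dirs(name, root="."):
    """Create datasets/<name> and models/<name> under root and return both paths."""
    paths = [os.path.join(root, top, name) for top in ("datasets", "models")]
    for path in paths:
        # exist_ok=True means re-running this does not fail if the folder is already there
        os.makedirs(path, exist_ok=True)
    return paths


if __name__ == "__main__":
    # "aaron_sheep" is only an example name; use whatever you call your dataset
    print(make_training_dirs("aaron_sheep"))
```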

Change the settings below to point to your dataset source ("data_set") and to where the model should be written ("log_root").

See Magenta's GitHub repo for sketch_rnn for the full list of settings you can specify when training, along with their defaults.
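
The --hparams value is a JSON string, so quoting it by hand gets fiddly once you override more than one setting. Here is a minimal sketch that builds the argument string from a Python dict. The helper build_train_command is hypothetical (it is not part of sketch_rnn), and the paths are the same example values used in the cell below.

```python
import json
import shlex


def build_train_command(log_root, data_dir, hparams):
    """Return the argument string for %run, with hparams encoded as JSON."""
    hparams_json = json.dumps(hparams)
    # shlex.quote wraps the JSON in single quotes so it survives as one argument
    return "sketch_rnn_train.py --log_root={} --data_dir={} --hparams={}".format(
        shlex.quote(log_root), shlex.quote(data_dir), shlex.quote(hparams_json)
    )


command = build_train_command(
    log_root="models/aaron_sheep",
    data_dir="datasets/",
    hparams={"data_set": "aaron_sheep.npz"},
)
print(command)
```

In a notebook you could then pass the string on with get_ipython().run_line_magic("run", "-i " + command), or just paste the printed line after %run -i.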


In [ ]:
%run -i sketch_rnn_train.py --log_root=models/aaron_sheep --data_dir=datasets/ --hparams='{"data_set":"aaron_sheep.npz"}'


INFO:tensorflow:sketch-rnn
INFO:tensorflow:Hyperparams:
INFO:tensorflow:grad_clip = 1.0
INFO:tensorflow:conditional = True
INFO:tensorflow:min_learning_rate = 1e-05
INFO:tensorflow:num_mixture = 20
INFO:tensorflow:is_training = True
INFO:tensorflow:input_dropout_prob = 0.9
INFO:tensorflow:kl_decay_rate = 0.99995
INFO:tensorflow:kl_tolerance = 0.2
INFO:tensorflow:random_scale_factor = 0.15
INFO:tensorflow:max_seq_len = 250
INFO:tensorflow:use_recurrent_dropout = True
INFO:tensorflow:num_steps = 10000000
INFO:tensorflow:use_output_dropout = False
INFO:tensorflow:decay_rate = 0.9999
INFO:tensorflow:z_size = 128
INFO:tensorflow:augment_stroke_prob = 0.1
INFO:tensorflow:learning_rate = 0.001
INFO:tensorflow:batch_size = 100
INFO:tensorflow:enc_model = lstm
INFO:tensorflow:use_input_dropout = False
INFO:tensorflow:dec_model = lstm
INFO:tensorflow:enc_rnn_size = 256
INFO:tensorflow:output_dropout_prob = 0.9
INFO:tensorflow:save_every = 500
INFO:tensorflow:kl_weight = 0.5
INFO:tensorflow:data_set = aaron_sheep.npz
INFO:tensorflow:kl_weight_start = 0.01
INFO:tensorflow:dec_rnn_size = 512
INFO:tensorflow:recurrent_dropout_prob = 0.9
INFO:tensorflow:Loading data files.
INFO:tensorflow:Loaded 7400/300/300 from aaron_sheep.npz
INFO:tensorflow:Dataset combined: 8000 (7400/300/300), avg len 125
INFO:tensorflow:model_params.max_seq_len 250.
total images <= max_seq_len is 7400
total images <= max_seq_len is 300
total images <= max_seq_len is 300
INFO:tensorflow:normalizing_scale_factor 18.5198.
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = True.
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = False.
INFO:tensorflow:vector_rnn/ENC_RNN/fw/LSTMCell/W_xh:0 (5, 1024) 5120
INFO:tensorflow:vector_rnn/ENC_RNN/fw/LSTMCell/W_hh:0 (256, 1024) 262144
INFO:tensorflow:vector_rnn/ENC_RNN/fw/LSTMCell/bias:0 (1024,) 1024
INFO:tensorflow:vector_rnn/ENC_RNN/bw/LSTMCell/W_xh:0 (5, 1024) 5120
INFO:tensorflow:vector_rnn/ENC_RNN/bw/LSTMCell/W_hh:0 (256, 1024) 262144
INFO:tensorflow:vector_rnn/ENC_RNN/bw/LSTMCell/bias:0 (1024,) 1024
INFO:tensorflow:vector_rnn/ENC_RNN_mu/super_linear_w:0 (512, 128) 65536
INFO:tensorflow:vector_rnn/ENC_RNN_mu/super_linear_b:0 (128,) 128
INFO:tensorflow:vector_rnn/ENC_RNN_sigma/super_linear_w:0 (512, 128) 65536
INFO:tensorflow:vector_rnn/ENC_RNN_sigma/super_linear_b:0 (128,) 128
INFO:tensorflow:vector_rnn/linear/super_linear_w:0 (128, 1024) 131072
INFO:tensorflow:vector_rnn/linear/super_linear_b:0 (1024,) 1024
INFO:tensorflow:vector_rnn/RNN/output_w:0 (512, 123) 62976
INFO:tensorflow:vector_rnn/RNN/output_b:0 (123,) 123
INFO:tensorflow:vector_rnn/RNN/LSTMCell/W_xh:0 (133, 2048) 272384
INFO:tensorflow:vector_rnn/RNN/LSTMCell/W_hh:0 (512, 2048) 1048576
INFO:tensorflow:vector_rnn/RNN/LSTMCell/bias:0 (2048,) 2048
INFO:tensorflow:Total trainable variables 2186107.
INFO:tensorflow:starting now
INFO:tensorflow:step 0
INFO:tensorflow:and... 0
INFO:tensorflow:step 1
INFO:tensorflow:and... 1
INFO:tensorflow:step 2
INFO:tensorflow:and... 2
INFO:tensorflow:step 3
INFO:tensorflow:and... 3
INFO:tensorflow:step 4
INFO:tensorflow:and... 4
INFO:tensorflow:step 5
INFO:tensorflow:and... 5
INFO:tensorflow:step 6
INFO:tensorflow:and... 6
INFO:tensorflow:step 7
INFO:tensorflow:and... 7
INFO:tensorflow:step 8
INFO:tensorflow:and... 8
INFO:tensorflow:step 9
INFO:tensorflow:and... 9
INFO:tensorflow:step 10
INFO:tensorflow:and... 10
INFO:tensorflow:step 11
INFO:tensorflow:and... 11
INFO:tensorflow:step 12
INFO:tensorflow:and... 12
INFO:tensorflow:step 13
INFO:tensorflow:and... 13
INFO:tensorflow:step 14
INFO:tensorflow:and... 14
INFO:tensorflow:step 15
INFO:tensorflow:and... 15
INFO:tensorflow:step 16
INFO:tensorflow:and... 16
INFO:tensorflow:step 17
INFO:tensorflow:and... 17
INFO:tensorflow:step 18
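
As a quick sanity check on the log above, you can add up the variable shapes yourself and compare against the "Total trainable variables" line. The shapes below are copied from the log; the comments relating them to the hyperparameters (for example 1024 = 4 * enc_rnn_size for the four LSTM gates) are my reading of the numbers, not something the log states.

```python
from math import prod

# Variable shapes as printed in the training log (names shortened)
shapes = {
    "ENC_RNN/fw/LSTMCell/W_xh": (5, 1024),    # 1024 = 4 * enc_rnn_size (assumed: 4 LSTM gates)
    "ENC_RNN/fw/LSTMCell/W_hh": (256, 1024),
    "ENC_RNN/fw/LSTMCell/bias": (1024,),
    "ENC_RNN/bw/LSTMCell/W_xh": (5, 1024),
    "ENC_RNN/bw/LSTMCell/W_hh": (256, 1024),
    "ENC_RNN/bw/LSTMCell/bias": (1024,),
    "ENC_RNN_mu/super_linear_w": (512, 128),  # 128 = z_size
    "ENC_RNN_mu/super_linear_b": (128,),
    "ENC_RNN_sigma/super_linear_w": (512, 128),
    "ENC_RNN_sigma/super_linear_b": (128,),
    "linear/super_linear_w": (128, 1024),
    "linear/super_linear_b": (1024,),
    "RNN/output_w": (512, 123),
    "RNN/output_b": (123,),
    "RNN/LSTMCell/W_xh": (133, 2048),         # 2048 = 4 * dec_rnn_size (assumed: 4 LSTM gates)
    "RNN/LSTMCell/W_hh": (512, 2048),
    "RNN/LSTMCell/bias": (2048,),
}

# Each variable contributes the product of its dimensions
total = sum(prod(shape) for shape in shapes.values())
print(total)  # 2186107, matching the total reported in the log
```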

In [ ]: