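This notebook compares the running time of one training epoch for two variational autoencoder (VAE) implementations: the standard version (`vae_sta663`) and the parallel version (`vae_parallel_sta663`).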
In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import time
from tensorflow.python.client import timeline
import matplotlib.pyplot as plt
%matplotlib inline
In [2]:
import sys

# Make the project helper modules imported below (misc_sta663, vae_sta663,
# vae_parallel_sta663) importable; presumably they live in ../vae relative
# to this notebook -- confirm against the repository layout.
sys.path.extend(['../vae'])
In [3]:
from misc_sta663 import *

# Load the dataset once. `mnist` (whose .train split feeds vae_train) and
# `n_samples` (presumably the training-set size -- confirm in misc_sta663)
# are used as notebook globals by the later cells.
mnist, n_samples = mnist_loader()
In [4]:
from vae_sta663 import *
import tensorflow as tf
import numpy as np
# Layer widths for the standard VAE; this dict is passed to
# vae_init(config=...) in the timing loop. x_in=784 presumably matches a
# flattened 28x28 MNIST image, and 'z' the latent size -- confirm in vae_sta663.
config_normal = {
    'x_in': 784,
    'encoder_1': 500,
    'encoder_2': 500,
    'decoder_1': 500,
    'decoder_2': 500,
    'z': 20,
}
In [5]:
from vae_parallel_sta663 import *
In [6]:
import tensorflow as tf
import numpy as np
# Layer widths for the parallel VAE, kept identical to config_normal so the
# timing comparison uses the same architecture. Presumably consumed by the
# vae_parallel_sta663 initializer -- confirm; it is not used in this cell.
config_parallel = {
    'x_in': 784,
    'encoder_1': 500,
    'encoder_2': 500,
    'decoder_1': 500,
    'decoder_2': 500,
    'z': 20,
}
In [7]:
# Batch sizes to benchmark: 100, 200, ..., 1000.
batch_size = np.arange(1, 11) * 100
# One entry per batch size: the per-repeat timings from %timeit's all_runs.
run_time_normal = []
In [8]:
def vae_train(sess, optimizer, cost, x, n_samples, batch_size=100, learn_rate=0.001,
              train_epoch=10, verb=1, verb_step=5, data=None):
    """Train a model for ``train_epoch`` epochs by repeatedly running ``optimizer``.

    Parameters
    ----------
    sess : object with ``run(fetches, feed_dict=...)`` (a TensorFlow session
        in this notebook); each step runs ``(optimizer, cost)``.
    optimizer, cost : graph ops fetched together on every step.
    x : placeholder that receives each mini-batch of inputs.
    n_samples : number of samples treated as one epoch.
    batch_size : mini-batch size. Each epoch runs ``n_samples // batch_size``
        steps; any remainder samples are not visited.
    learn_rate : unused by this function. It is kept only for signature
        compatibility; presumably the rate is fixed when ``optimizer`` is
        built (e.g. in vae_init) -- confirm.
    train_epoch : number of epochs to run.
    verb : if truthy, print the epoch's average cost every ``verb_step``
        epochs (the first epoch is always printed).
    verb_step : printing interval in epochs.
    data : object with ``next_batch(batch_size) -> (inputs, labels)``.
        Defaults to the notebook-global ``mnist.train`` (the original
        behaviour), so any dataset object can be used.

    Returns
    -------
    None
    """
    if data is None:
        # Backward-compatible default: the MNIST training split loaded earlier.
        data = mnist.train
    # Floor division gives the step count directly and is hoisted out of the
    # epoch loop because it never changes.
    total_batch = int(n_samples // batch_size)
    for epoch in range(train_epoch):
        avg_cost = 0
        for _step in range(total_batch):
            batch_x, _ = data.next_batch(batch_size)
            _, c = sess.run((optimizer, cost), feed_dict={x: batch_x})
            # Weight each batch cost by batch_size / n_samples so the sum is
            # the per-sample average over the epoch.
            avg_cost += c / n_samples * batch_size
        if verb and epoch % verb_step == 0:
            print('Epoch:%04d' % (epoch + 1), 'cost=', '{:.9f}'.format(avg_cost))
In [9]:
for s in batch_size:
print('Evaluating at: %d' % s)
(sess_1, optimizer_1, cost_1, x_1, x_prime_1) = vae_init(batch_size=s, config=config_normal)
result_1 = %timeit -o -n1 -r5 vae_train(sess_1, optimizer_1, cost_1, x_1, n_samples, batch_size=s, train_epoch=1, verb=0)
sess_1.close()
run_time_normal.append(result_1.all_runs)
In [11]:
import pickle

# Save the raw timings: one entry per batch size, each holding the
# per-repeat times from %timeit's all_runs. The path must point at the
# existing ../data directory; the earlier '../]data' contained a stray ']'.
with open('../data/runtime_normal.pickle', 'wb') as f:
    pickle.dump(run_time_normal, f)
In [ ]: