In [113]:
# TensorFlow ConvLSTM video model: imports and graph setup.
import os
import shutil
import sys

import numpy as np
import tensorflow as tf

# Start from an empty default graph so re-running the notebook top-to-bottom
# does not pile up duplicate ops/variables from a previous run.
tf.reset_default_graph()

from cell import ConvLSTMCell

# Directory that contains the `datasets` package. Set the DEEPVIDEOS_ROOT
# environment variable to override the author's local default instead of
# editing this absolute path; normpath collapses the "model/../" segment.
module_path = os.path.normpath(
    os.environ.get("DEEPVIDEOS_ROOT", "/home/pratik/work/dl/deepvideos/model/../"))
if module_path not in sys.path:
    sys.path.append(module_path)

from datasets.batch_generator import datasets
In [114]:
# Model hyper-parameters.
batch_size = 4
timesteps = 32
shape = [64, 64]        # frame size [H, W]
kernel = [3, 3]         # ConvLSTM kernel size
channels = 3
filters = [128, 128]    # filter count for each of the 2 stacked ConvLSTM layers

# Input clip and target clip share one layout: (batch_size, timesteps, H, W, C).
video_shape = [batch_size, timesteps] + shape + [channels]
inputs = tf.placeholder(tf.float32, video_shape, name="conv_lstm_inputs")
outputs_exp = tf.placeholder(tf.float32, video_shape, name="conv_lstm_outputs_exp")

# Model output, loss and optimizer are assigned by later cells.
model_output, l2_loss, optimizer = None, None, None
In [115]:
# Display the input placeholder's repr as a quick sanity check of its shape.
inputs
Out[115]:
In [116]:
# Fold time into the batch axis so each frame goes through the 2-D conv stack on
# its own: (batch_size, timesteps, H, W, C) -> (batch_size * timesteps, H, W, C).
conv_inp_reshape_size = list(shape)
conv_inp_reshape_size.insert(0, batch_size * timesteps)
conv_inp_reshape_size.append(channels)
conv_input = tf.reshape(inputs, conv_inp_reshape_size)
In [117]:
slim = tf.contrib.slim
from tensorflow.python.ops import init_ops
from tensorflow.contrib.layers.python.layers import regularizers


def trunc_normal(stddev):
    """Return a truncated-normal weight initializer with mean 0.0 and the given stddev."""
    return init_ops.truncated_normal_initializer(0.0, stddev)


# Scale of the L2 weight regularizer applied to the conv / deconv layer weights.
l2_val = 0.00005
In [118]:
# Layer APIs used below: slim.conv2d, slim.max_pool2d, slim.conv2d_transpose.
# Look them up interactively with `slim.conv2d_transpose?` while developing;
# the `?` help syntax is IPython-only and is not kept in the saved notebook.
In [119]:
# Per-frame 2-D conv encoder applied before the ConvLSTM. Every conv layer uses
# the same initializer and L2 regularizer, so they are set once with arg_scope
# instead of being repeated on each call. Each layer's output tensor is printed
# so the spatial down-sampling can be checked.
with tf.variable_scope('conv_before_lstm'):
    with slim.arg_scope([slim.conv2d],
                        weights_initializer=trunc_normal(0.01),
                        weights_regularizer=regularizers.l2_regularizer(l2_val)):
        net = slim.conv2d(conv_input, 32, [3, 3], scope='conv_1')
        print(net)
        net = slim.conv2d(net, 64, [3, 3], scope='conv_2')
        print(net)
        net = slim.max_pool2d(net, [2, 2], scope='pool_1')
        print(net)
        net = slim.conv2d(net, 32, [3, 3], scope='conv_3')
        print(net)
        net = slim.max_pool2d(net, [2, 2], scope='pool_2')
        print(net)
        net = slim.conv2d(net, 32, [3, 3], scope='conv_4')
        print(net)
In [120]:
# Split the folded batch axis back into sequences for the ConvLSTM:
# (batch_size * timesteps, h, w, c) -> (batch_size, timesteps, h, w, c).
net_output_shape = net.get_shape().as_list()
lstm_reshape_size = [batch_size, timesteps] + net_output_shape[1:]
lstm_reshape = tf.reshape(net, lstm_reshape_size)
print(lstm_reshape)
In [121]:
# Stack one ConvLSTMCell per entry in `filters` and unroll the stack over time.
# Only the time/spatial/channel dims are unpacked: `batch_size` is deliberately
# not re-bound here, so the config value set at the top stays authoritative.
lstm_in_shape = lstm_reshape.get_shape().as_list()
time_step, H, W, C = lstm_in_shape[1:]
assert lstm_in_shape[0] == batch_size and time_step == timesteps, lstm_in_shape

with tf.variable_scope('conv_lstm_model'):
    cells = [ConvLSTMCell([H, W], n_filters, kernel) for n_filters in filters]
    cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
    # lstm_reshape is laid out (batch, time, ...), which dynamic_rnn's default
    # (time_major=False) expects.
    states_series, current_state = tf.nn.dynamic_rnn(cell, lstm_reshape, dtype=lstm_reshape.dtype)
    # current_state => Not used ...
    model_output = states_series
In [122]:
# Display the stacked-ConvLSTM output tensor (states_series from dynamic_rnn).
model_output
Out[122]:
In [123]:
# Fold time back into the batch axis so the decoder sees one frame per row:
# (batch, time, H, W, C) -> (batch * time, H, W, C). C here is the ConvLSTM
# output depth (presumably filters[-1] — depends on ConvLSTMCell; confirm).
# The leading dim is read from the tensor instead of re-binding `batch_size`.
lstm_out_shape = model_output.get_shape().as_list()
time_step, H, W, C = lstm_out_shape[1:]
deconv_reshape = tf.reshape(model_output, [lstm_out_shape[0] * time_step, H, W, C])
deconv_reshape
Out[123]:
In [124]:
# Per-frame transposed-conv decoder mapping ConvLSTM features back to frames.
# The shared initializer/regularizer is set once via arg_scope. The final layer
# emits `channels` maps (matching the input frames) through tanh, so predictions
# lie in (-1, 1).
# NOTE(review): targets fed to outputs_exp should be scaled to [-1, 1] — confirm
# against the batch generator.
with tf.variable_scope('deconv_after_lstm'):
    with slim.arg_scope([slim.conv2d_transpose],
                        weights_initializer=trunc_normal(0.01),
                        weights_regularizer=regularizers.l2_regularizer(l2_val)):
        net = slim.conv2d_transpose(deconv_reshape, 64, [3, 3], scope='deconv_1')
        print(net)
        net = slim.conv2d_transpose(net, 32, [3, 3], stride=2, scope='deconv_2')
        print(net)
        net = slim.conv2d_transpose(net, channels, [3, 3], stride=2,
                                    activation_fn=tf.tanh, scope='deconv_3')
        print(net)
In [125]:
# Un-fold the time axis of the decoded frames:
# (batch_size * timesteps, H, W, C) -> (batch_size, timesteps, H, W, C).
net_pred_shape = net.get_shape().as_list()
out_pred_shape = [batch_size, timesteps] + net_pred_shape[1:]
output_pred = tf.reshape(net, out_pred_shape)
# The prediction must line up with the target placeholder before a loss compares them.
assert output_pred.get_shape().as_list() == outputs_exp.get_shape().as_list(), out_pred_shape
output_pred
Out[125]:
In [ ]: