In [113]:
# TensorFlow Model !
# Imports and environment setup for the ConvLSTM video model.
import os
import shutil
import sys

import numpy as np
import tensorflow as tf

# Start from an empty default graph so re-running the notebook does not stack
# duplicate ops / variables on top of the previous run's graph.
tf.reset_default_graph()

# Make the repository root importable so `datasets.batch_generator` resolves.
# Set the DEEPVIDEOS_ROOT environment variable to point at your checkout; the
# fallback is the original author's hard-coded location, kept so existing
# setups keep working. sys.path.append adds it at lowest priority, so modules
# in the working directory (e.g. `cell`) still win.
module_path = os.path.normpath(
    os.environ.get("DEEPVIDEOS_ROOT", "/home/pratik/work/dl/deepvideos/model/../"))
if module_path not in sys.path:
    sys.path.append(module_path)

from cell import ConvLSTMCell
from datasets.batch_generator import datasets

In [114]:
# ---- Hyper-parameters -------------------------------------------------------
batch_size = 4
timesteps = 32
shape = [64, 64]      # frame size: H, W
kernel = [3, 3]       # kernel size handed to every ConvLSTM cell
channels = 3
filters = [128, 128]  # one stacked ConvLSTM layer per entry

# ---- Graph inputs -------------------------------------------------------------
# Both placeholders hold video clips laid out as (batch_size, timestep, H, W, C).
video_dims = [batch_size, timesteps] + shape + [channels]
inputs = tf.placeholder(tf.float32, video_dims, name="conv_lstm_inputs")
outputs_exp = tf.placeholder(tf.float32, video_dims, name="conv_lstm_outputs_exp")

# ---- Filled in later in the notebook -----------------------------------------
# model_output is set by the ConvLSTM cell; l2_loss / optimizer presumably by
# cells not yet written — TODO confirm.
model_output = l2_loss = optimizer = None

In [115]:
inputs  # sanity check: placeholder should be (batch_size, timestep, H, W, C)


Out[115]:
<tf.Tensor 'conv_lstm_inputs:0' shape=(4, 32, 64, 64, 3) dtype=float32>

In [116]:
# Merge the batch and time axes so every frame goes through the 2-D conv
# encoder independently: (batch_size * timesteps, H, W, C).
conv_inp_reshape_size = [batch_size * timesteps]
conv_inp_reshape_size.extend(shape)
conv_inp_reshape_size.append(channels)
conv_input = tf.reshape(inputs, conv_inp_reshape_size)

In [117]:
# Layer helpers shared by the conv encoder and the transposed-conv decoder.
slim = tf.contrib.slim
from tensorflow.python.ops import init_ops
from tensorflow.contrib.layers.python.layers import regularizers

l2_val = 5e-5  # scale passed to l2_regularizer for the conv / deconv weights


def trunc_normal(stddev):
    """Return a zero-mean truncated-normal initializer with the given stddev."""
    return init_ops.truncated_normal_initializer(0.0, stddev)

In [118]:
# Layer APIs used by the encoder / decoder cells (inspect interactively with
# `obj?` while developing; lookups are not left in the notebook because `?` is
# IPython-only syntax and opens a pager on every Run All):
#   slim.conv2d, slim.max_pool2d, slim.conv2d_transpose

In [119]:
with tf.variable_scope('conv_before_lstm'):
    # All four conv layers share one initializer and one L2 penalty; set them
    # once through arg_scope instead of repeating the keyword arguments.
    with slim.arg_scope([slim.conv2d],
                        weights_initializer=trunc_normal(0.01),
                        weights_regularizer=regularizers.l2_regularizer(l2_val)):
        net = slim.conv2d(conv_input, 32, [3, 3], scope='conv_1')
        print(net)
        net = slim.conv2d(net, 64, [3, 3], scope='conv_2')
        print(net)
        net = slim.max_pool2d(net, [2, 2], scope='pool_1')  # halves H and W
        print(net)
        net = slim.conv2d(net, 32, [3, 3], scope='conv_3')
        print(net)
        net = slim.max_pool2d(net, [2, 2], scope='pool_2')  # halves H and W again
        print(net)
        net = slim.conv2d(net, 32, [3, 3], scope='conv_4')
        print(net)


Tensor("conv_before_lstm/conv_1/Relu:0", shape=(128, 64, 64, 32), dtype=float32)
Tensor("conv_before_lstm/conv_2/Relu:0", shape=(128, 64, 64, 64), dtype=float32)
Tensor("conv_before_lstm/pool_1/MaxPool:0", shape=(128, 32, 32, 64), dtype=float32)
Tensor("conv_before_lstm/conv_3/Relu:0", shape=(128, 32, 32, 32), dtype=float32)
Tensor("conv_before_lstm/pool_2/MaxPool:0", shape=(128, 16, 16, 32), dtype=float32)
Tensor("conv_before_lstm/conv_4/Relu:0", shape=(128, 16, 16, 32), dtype=float32)

In [120]:
# Split the merged frame axis back into (batch_size, timesteps) so the
# ConvLSTM can step over time: (batch, time, h, w, c).
net_output_shape = net.get_shape().as_list()
per_frame_dims = net_output_shape[1:]
lstm_reshape_size = [batch_size, timesteps] + per_frame_dims
lstm_reshape = tf.reshape(net, lstm_reshape_size)
print(lstm_reshape)


Tensor("Reshape_1:0", shape=(4, 32, 16, 16, 32), dtype=float32)

In [121]:
# Static shape of the encoder features: (batch, time, H, W, C).
batch_size, time_step, H, W, C = lstm_reshape.get_shape().as_list()
with tf.variable_scope('conv_lstm_model'):
    # One ConvLSTM layer per entry in `filters`, all on the H x W feature grid.
    # NOTE(review): ConvLSTMCell args presumably (shape, filters, kernel) — confirm.
    cells = [ConvLSTMCell([H, W], n_filters, kernel) for n_filters in filters]
    cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

    # dynamic_rnn returns (per-timestep outputs of the top layer, final state);
    # only the outputs are used downstream.
    states_series, current_state = tf.nn.dynamic_rnn(
        cell, lstm_reshape, dtype=lstm_reshape.dtype)
    model_output = states_series

In [122]:
model_output  # sanity check: stacked ConvLSTM output, laid out (batch, time, H, W, C)


Out[122]:
<tf.Tensor 'conv_lstm_model/rnn/transpose:0' shape=(4, 32, 16, 16, 128) dtype=float32>

In [123]:
# Fold time back into the batch axis so the 2-D transposed convs run per frame.
lstm_out_dims = model_output.get_shape().as_list()
batch_size, time_step, H, W, C = lstm_out_dims
deconv_reshape = tf.reshape(model_output, [batch_size * time_step, H, W, C])
deconv_reshape


Out[123]:
<tf.Tensor 'Reshape_2:0' shape=(128, 16, 16, 128) dtype=float32>

In [124]:
with tf.variable_scope('deconv_after_lstm'):
    # Every transposed conv shares one initializer and one L2 penalty.
    with slim.arg_scope([slim.conv2d_transpose],
                        weights_initializer=trunc_normal(0.01),
                        weights_regularizer=regularizers.l2_regularizer(l2_val)):
        net = slim.conv2d_transpose(deconv_reshape, 64, [3, 3], scope='deconv_1')
        print(net)
        # stride=2 doubles H and W.
        net = slim.conv2d_transpose(net, 32, [3, 3], stride=2, scope='deconv_2')
        print(net)
        # Final layer: 3 output channels; tanh bounds predictions to [-1, 1].
        # NOTE(review): targets fed to outputs_exp presumably need the same range — confirm.
        net = slim.conv2d_transpose(net, 3, [3, 3], stride=2,
                                    activation_fn=tf.tanh, scope='deconv_3')
        print(net)


Tensor("deconv_after_lstm/deconv_1/Relu:0", shape=(128, 16, 16, 64), dtype=float32)
Tensor("deconv_after_lstm/deconv_2/Relu:0", shape=(128, 32, 32, 32), dtype=float32)
Tensor("deconv_after_lstm/deconv_3/Tanh:0", shape=(128, 64, 64, 3), dtype=float32)

In [125]:
# Restore the (batch, time) axes on the decoder output so it has the same
# (batch_size, timestep, H, W, C) layout as the outputs_exp placeholder.
net_pred_shape = net.get_shape().as_list()
frame_dims = net_pred_shape[1:]
out_pred_shape = [batch_size, timesteps] + frame_dims
output_pred = tf.reshape(net, out_pred_shape)
output_pred


Out[125]:
<tf.Tensor 'Reshape_3:0' shape=(4, 32, 64, 64, 3) dtype=float32>

In [ ]: