In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Input geometry shared by every cell below:
#   sequence_length  - number of time steps per example
#   embedding_length - feature size of each time step
sequence_length, embedding_length = 96, 64

Placeholder


In [3]:
# The leading dimension is None so that any batch size can be fed at run time.
input_data = tf.placeholder(
    dtype=tf.float32,
    shape=[None, sequence_length, embedding_length],
)

Build RNN Cell


In [4]:
# Width of the LSTM hidden state (and therefore of every per-step output).
hidden_vector_size = 100

# A single LSTM cell; its weights are created lazily when the RNN is built.
rnn_cell = tf.contrib.rnn.LSTMCell(num_units=hidden_vector_size)


WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.

Get the batch size from the input data rather than hard-coding it


In [5]:
# Build a [batch, hidden_vector_size] tensor of zeros whose batch dimension is
# read from the input at run time. tf.shape() yields the dynamic shape, so this
# works for any fed batch size without materialising a zero copy of the whole
# input and pushing it through a reduce_mean + matmul just to obtain the shape.
initial_zero_h = tf.zeros(
    tf.stack([tf.shape(input_data)[0], hidden_vector_size]),
    dtype=input_data.dtype,
)

Initial state as an LSTMStateTuple of (c, h)


In [6]:
# Both the cell state (c) and the hidden state (h) start from the same zeros.
initial_state = tf.contrib.rnn.LSTMStateTuple(
    c=initial_zero_h,
    h=initial_zero_h,
)

Getting the outputs


In [7]:
# Unroll the LSTM over the time axis of the input, starting from the
# explicit zero state built above.
outputs, state = tf.nn.dynamic_rnn(
    cell=rnn_cell,
    inputs=input_data,
    initial_state=initial_state,
    dtype=tf.float32,
)

In [8]:
# Display the symbolic output tensor to inspect its static shape and dtype.
outputs


Out[8]:
<tf.Tensor 'rnn/transpose_1:0' shape=(?, 96, 100) dtype=float32>

In [9]:
# InteractiveSession installs itself as the default session, which is what
# lets later cells call `.eval()` / `.run()` without passing a session.
sess = tf.InteractiveSession()

# Initialise the LSTM weights through that default session.
tf.global_variables_initializer().run()

Use any batch size


In [10]:
# Feed a small random batch. The non-batch dimensions come from the
# configuration constants so the array always matches the placeholder shape.
batch_size = 14
fake_input = np.random.uniform(size=[batch_size, sequence_length, embedding_length])
# Show only a 10x10 corner of the first example's outputs, not the full array.
outputs.eval(feed_dict={input_data: fake_input})[0, :10, :10]


Out[10]:
array([[ 0.0133772 ,  0.1248076 ,  0.02860027, -0.00279607, -0.18480375,
         0.00869518,  0.02967772,  0.03201196, -0.03233923, -0.12331957],
       [ 0.01717433,  0.2263697 ,  0.06569117, -0.03991626, -0.21566994,
        -0.00274733,  0.01245386,  0.00521528, -0.03463598, -0.15695679],
       [ 0.04783954,  0.3069565 ,  0.08629331, -0.03657083, -0.20743844,
         0.01699171, -0.00409255,  0.02383869, -0.03702394, -0.18661733],
       [ 0.07255518,  0.27716413,  0.1319593 , -0.07308843, -0.27800056,
         0.0886552 , -0.04429418, -0.03370504, -0.0502329 , -0.22744851],
       [ 0.13757218,  0.34861842,  0.1028095 , -0.1123821 , -0.3602204 ,
         0.15615414, -0.06908301, -0.05146805, -0.01621163, -0.25484738],
       [ 0.15186062,  0.3146225 ,  0.08943865, -0.15840401, -0.34275725,
         0.14421731, -0.1636517 , -0.00490215,  0.0531736 , -0.20377146],
       [ 0.19039147,  0.27051648,  0.04969539, -0.13397585, -0.36124727,
         0.20126095, -0.11209239, -0.03270925,  0.01869536, -0.21338326],
       [ 0.21901123,  0.30190438,  0.00503324, -0.16593175, -0.36799055,
         0.23424742, -0.11466883, -0.02946316,  0.05092116, -0.159033  ],
       [ 0.17929147,  0.25147724,  0.01736976, -0.19413503, -0.37185222,
         0.25556013, -0.15847237,  0.0221071 ,  0.05562782, -0.18257754],
       [ 0.2271389 ,  0.27966276,  0.06107444, -0.14930013, -0.36397922,
         0.2083975 , -0.18683249,  0.06619405,  0.03500409, -0.20571777]],
      dtype=float32)

In [11]:
# Same graph, ten times the batch size: nothing has to be rebuilt because the
# batch dimension is resolved at run time. The non-batch dimensions come from
# the configuration constants so the array always matches the placeholder.
batch_size = 140
fake_input = np.random.uniform(size=[batch_size, sequence_length, embedding_length])
# Show only a 10x10 corner of the first example's outputs, not the full array.
outputs.eval(feed_dict={input_data: fake_input})[0, :10, :10]


Out[11]:
array([[ 0.03142734,  0.15519665,  0.02140723, -0.00588781, -0.10464185,
         0.01756072, -0.00956259,  0.04701442,  0.08324529, -0.09299037],
       [ 0.01613303,  0.19432577,  0.0701172 , -0.03475946, -0.17193635,
         0.03618715, -0.00548373,  0.02486533,  0.12223729, -0.09514315],
       [ 0.03262727,  0.2981133 ,  0.08287209, -0.06223779, -0.25430405,
         0.04866415,  0.01303787,  0.09578612,  0.03252373, -0.16682807],
       [ 0.10809346,  0.3251696 ,  0.07644037, -0.08040279, -0.3162625 ,
         0.09884892, -0.04423967,  0.15149339,  0.03924495, -0.22590007],
       [ 0.1504862 ,  0.3340804 ,  0.05860933, -0.09605774, -0.32826826,
         0.19024172, -0.05101546,  0.147878  ,  0.06868754, -0.13980682],
       [ 0.14726807,  0.31644508,  0.05159319, -0.1089751 , -0.32962084,
         0.1447724 , -0.03714456,  0.1136796 ,  0.10779426, -0.11198579],
       [ 0.22446983,  0.3208124 ,  0.05245988, -0.12058066, -0.31481117,
         0.10862615, -0.08089321,  0.10375389,  0.12114908, -0.1494493 ],
       [ 0.29125607,  0.33517447,  0.01976913, -0.09151037, -0.3436202 ,
         0.19459113, -0.05055251,  0.12222207,  0.09918524, -0.22896896],
       [ 0.23131868,  0.319851  ,  0.05228816, -0.11567234, -0.3553371 ,
         0.15113856, -0.11206171,  0.06238765,  0.07839007, -0.20897834],
       [ 0.27803063,  0.32792312,  0.06652988, -0.03546051, -0.35768583,
         0.09689046, -0.14316122,  0.05327737,  0.14563368, -0.22349374]],
      dtype=float32)

In [ ]: