In [1]:
import tensorflow as tf
import numpy as np
In [2]:
# Graph inputs holding the two initial states of the recurrence;
# each is a float placeholder declared with shape 1.
h1 = tf.placeholder(dtype="float", shape=1)
h2 = tf.placeholder(dtype="float", shape=1)
In [3]:
def step(h_tm1, h_tm2):
    """Advance the two-term recurrence by one step.

    Args:
        h_tm1: State from the previous step (t-1).
        h_tm2: State from two steps back (t-2).

    Returns:
        A pair ``(h_t, h_tm1)``: the new state ``h_tm1 + h_tm2`` followed by
        the previous state, so the result can be fed straight back in as the
        next ``(h_tm1, h_tm2)``.
    """
    return h_tm1 + h_tm2, h_tm1
In [4]:
# Unroll the recurrence by hand: apply `step` ten times, feeding each result
# back in as the next state, and keep the newest state from every iteration.
outputs = []
h_tm1, h_tm2 = h1, h2
for i in range(10):
    h_tm1, h_tm2 = step(h_tm1, h_tm2)
    outputs.append(h_tm1)
In [4]:
# Create the session that later cells use (via sess.run) to evaluate graph tensors.
sess = tf.Session()
In [6]:
# Evaluate every collected output, seeding the recurrence with h1 = 1 and h2 = 0.
r = sess.run(
    outputs,
    feed_dict={h1: np.ones(1), h2: np.zeros(1)},
)
In [7]:
# Display the values fetched by the preceding sess.run call.
r
Out[7]:
In [ ]:
"""
We can write our own "scan" for TensorFlow so that we have a unified API.
For LSTM and GRU, TensorFlow already provides fast implementations.
"""
In [21]:
def scan(step, sequences=None, outputs_info=None, non_sequences=None, n_steps=None):
    """Unroll ``step`` for a fixed number of iterations and collect its outputs.

    On iteration ``t`` the step function is called as
    ``step(*(seq_slices + previous_outputs + non_sequences))`` where
    ``seq_slices`` is ``[s[t] for s in sequences]`` and ``previous_outputs``
    starts as ``outputs_info`` and is replaced by the values ``step`` returned
    on the previous iteration.

    Args:
        step: Callable returning an iterable with one value per recurrent
            output (same order as ``outputs_info``).
        sequences: Indexable inputs; element ``t`` of each one is passed to
            ``step`` on iteration ``t``. ``None`` means no sequences.
        outputs_info: Initial values of the recurrent outputs. ``None`` means
            no recurrent outputs.
        non_sequences: Extra arguments handed to ``step`` unchanged on every
            iteration. ``None`` means no extra arguments.
        n_steps: Number of iterations. When omitted, the length of the
            shortest sequence is used (each sequence must then support
            ``len()``).

    Returns:
        A list with one entry per recurrent output: the per-step values of
        that output joined with ``tf.concat`` along axis 0.

    Raises:
        ValueError: If ``n_steps`` is ``None`` and no sequences were given, so
            the iteration count cannot be determined.
    """
    # Normalise the optional arguments; copying also lets callers pass tuples.
    sequences = [] if sequences is None else list(sequences)
    non_sequences = [] if non_sequences is None else list(non_sequences)
    outputs = [] if outputs_info is None else list(outputs_info)

    if n_steps is None:
        if not sequences:
            raise ValueError("n_steps must be given when there are no sequences")
        n_steps = min(len(s) for s in sequences)

    results = [[] for _ in outputs]
    for t in range(n_steps):
        seq_slices = [s[t] for s in sequences]
        # Non-sequences are constant across time, so they are passed as-is
        # rather than being indexed by the step counter.
        outputs = list(step(*(seq_slices + outputs + non_sequences)))
        for k, value in enumerate(outputs):
            results[k].append(value)

    # NOTE(review): the (axis, values) argument order matches pre-1.0
    # TensorFlow; TF >= 1.0 expects tf.concat(values, axis) -- confirm against
    # the installed version.
    return [tf.concat(0, values) for values in results]
In [22]:
# Build the same ten-step recurrence through `scan`: no per-step inputs,
# h1/h2 as the initial states, and no constant extra arguments.
outputs = scan(
    step,
    sequences=[],
    outputs_info=[h1, h2],
    non_sequences=[],
    n_steps=10,
)
In [23]:
# Evaluate the scan outputs, seeding the recurrence with h1 = 1 and h2 = 0.
r = sess.run(
    outputs,
    feed_dict={h1: np.ones(1), h2: np.zeros(1)},
)
In [24]:
# Display the values fetched by the preceding sess.run call.
r
Out[24]: