In [1]:
# Imports and global plotting configuration used by every cell below.
%matplotlib inline
import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
import numpy as np
import theano
import theano.tensor as T
# Default figure size (width, height) in inches for all figures in this notebook.
pylab.rcParams['figure.figsize'] = (9.0, 6.0)
In [2]:
def create_decoder(A):
    """Compile a Theano function computing sigmoid(A . x) for a float32 matrix x.

    A -- decoding weights, applied as the left operand of the dot product.
         (Shape presumably (n_out, dim_unit) -- TODO confirm against the saved array.)

    Returns a compiled callable that takes one float32 matrix and returns
    sigmoid(A . x).
    """
    code_matrix = T.fmatrix(name='x')
    decoded = T.nnet.sigmoid(T.dot(A, code_matrix))
    return theano.function(inputs=[code_matrix], outputs=decoded)
# Load the pre-computed decoding weights and compile the decoder once.
# NOTE(review): absolute, user-specific path -- will not resolve on other machines.
A = np.load(
    '/Users/chayut/Dropbox/data/grid_cells/cpc_decoding_vector.npy'
)
decoder = create_decoder(A=A)
In [3]:
# Sizes shared by the symbolic graph and the sampled batch.
dim_unit = 64      # number of rows of x, y and nu
n_steps = 100      # recurrent iterations unrolled by theano.scan
n_batch = 1 << 14  # number of samples (columns) per batch: 16384
# Standard deviation of the additive code noise nu, kept as float32 so that
# sigma * <float32 array> stays float32.
sigma = np.float32(1. / np.sqrt(dim_unit))
class VariableHolder(object):
    """Empty attribute bag used as a namespace for symbolic graph variables.

    Instances accept arbitrary attribute assignment, e.g. ``t.u = T.frow()``.
    """
def rectify(x):
    """Half-wave rectification: symbolic elementwise max(x, 0)."""
    return T.maximum(0, x)
def rail(x):
    """Clamp a symbolic expression elementwise into the interval [0, 1]."""
    capped_above = T.minimum(x, 1)
    return T.maximum(capped_above, 0)
def one_step(z, S, b_eff):
    """One recurrent update for theano.scan: rectify(S . z + b_eff).

    z     -- previous state (the scan output carried between steps)
    S     -- recurrent weight matrix (scan non-sequence)
    b_eff -- effective bias / constant drive (scan non-sequence)
    """
    pre_activation = T.dot(S, z) + b_eff
    return rectify(pre_activation)
# Basis of positions 0, 1, ..., dim_unit - 1 as a (dim_unit, 1) float32 column.
sparse_basis = np.arange(dim_unit, dtype=np.float32).reshape((dim_unit, 1))
# Namespace holding every symbolic variable of the model graph.
t = VariableHolder()
# Symbolic inputs: u is a float32 row with one entry per sample (presumably
# positions in [0, 1], matching how it is sampled later); nu is additive noise on y.
t.u = T.frow(name='u')
t.nu = T.fmatrix(name='nu')
# Parameters: recurrent matrix S, input matrix W, encoding matrix V, column biases b and c.
# NOTE(review): T.cols presumably uses theano.config.floatX, unlike the explicit
# float32 (f*) variables above. If floatX is float64 the graph upcasts and the
# float32 decoder may reject x_hat -- confirm floatX is float32.
t.S, t.W, t.V = T.fmatrices('S', 'W', 'V')
t.b, t.c = T.cols('b', 'c')
# Offset of each sample's scaled position u * (dim_unit - 1) from every basis
# point; the row/column broadcast gives a (dim_unit, batch) matrix.
t.u_dist = t.u * np.float32(dim_unit - 1) - sparse_basis
# Gaussian bump (std 1 in basis-index units) centred on the scaled position...
t.x0 = T.exp(-(t.u_dist ** 2 / 2))
# ...normalised so that every column of x has unit L2 norm.
t.x = t.x0 / T.sqrt(T.sum(t.x0 ** 2, axis=0, keepdims=True))
# Encoded representation y = clamp(V x + c, 0, 1).
t.y = rail(T.dot(t.V, t.x) + t.c)
# Constant drive of the recurrent network from the noisy code y + nu.
t.b_eff = t.b + T.dot(t.W, t.y + t.nu)
# Iterate z <- rectify(S z + b_eff) for n_steps steps starting from z = 0.
t.z_seq, t.updates = theano.scan(fn=one_step,
outputs_info=T.zeros_like(t.b_eff),
non_sequences=[t.S, t.b_eff],
n_steps=n_steps)
# The final recurrent state is the reconstruction of x.
t.x_hat = t.z_seq[-1]
t.x_error = t.x_hat - t.x
# Cost: squared reconstruction error summed over rows, averaged over samples (columns).
t.cost = T.mean(T.sum(t.x_error ** 2, axis=0))
# Compiled forward pass: (u, nu, S, W, V, b, c) -> (cost, x, y, x_hat).
rnn_compute = theano.function(inputs=[t.u, t.nu, t.S, t.W, t.V, t.b, t.c],
outputs=[t.cost, t.x, t.y, t.x_hat],
updates=t.updates)
In [14]:
# Load the saved network parameters; the with-block closes the .npz file.
# NOTE(review): absolute, user-specific path -- will not resolve on other machines.
with np.load('/Users/chayut/Dropbox/data/grid_cells/opt_coding_1d_fourier_init_tmp_save.npz') as result:
    S, W, V, b, c = (result['S'], result['W'], result['V'],
                     result['b'], result['c'])
In [15]:
# Draw a batch of positions u ~ U[0, 1) (shape (1, n_batch)) and code noise
# nu ~ N(0, sigma^2) (shape (dim_unit, n_batch)), then run the compiled network.
# A dedicated, seeded RandomState makes this batch -- and every number derived
# from it below -- reproducible on Restart & Run All, without mutating NumPy's
# global random state.
SEED = 0
rng = np.random.RandomState(SEED)
u = rng.uniform(size=(1, n_batch)).astype(np.float32)
nu = sigma * rng.randn(dim_unit, n_batch).astype(np.float32)  # float32 * float32 stays float32
cost, x, y, x_hat = rnn_compute(u, nu, S, W, V, b, c)
In [16]:
# RMS error between the true positions u and those decoded from the reconstruction x_hat.
np.sqrt(np.square(u - decoder(x_hat)).mean())
Out[16]:
In [17]:
# Heat maps of the learned parameter matrices, one figure each.
# interpolation='none' shows the raw matrix entries without smoothing.
for title, matrix in (('V (x -> y encoding)', V),
                      ('W (y -> recurrent drive)', W),
                      ('S (recurrent)', S)):
    fig, ax = plt.subplots()
    image = ax.imshow(matrix, interpolation='none')
    fig.colorbar(image, ax=ax)
    ax.set(title=title, xlabel='column', ylabel='row')
    plt.show()
In [18]:
# Inspect one row of V: its weights, their magnitude spectrum, and then the
# magnitude spectra of all rows of V at once.
i = 60  # row of V to inspect; must be < V.shape[0]
v_row = V[i, :]  # 1-D array, so no transpose is needed for plotting

fig, ax = plt.subplots()
ax.plot(v_row)
ax.set(title='V[%d, :]' % i, xlabel='column index', ylabel='weight')
plt.show()

fig, ax = plt.subplots()
ax.plot(np.abs(np.fft.rfft(v_row)))
ax.set(title='|rfft(V[%d, :])|' % i, xlabel='frequency index', ylabel='magnitude')
ax.grid(True)
plt.show()

# np.fft.fft transforms along the last axis, i.e. each row of V separately.
fig, ax = plt.subplots()
image = ax.imshow(np.abs(np.fft.fft(V)), interpolation='none')
fig.colorbar(image, ax=ax)
ax.set(title='|fft| of each row of V', xlabel='frequency index', ylabel='row of V')
plt.show()
In [19]:
# Batch mean of the squared L2 norm of each column (sample) of y.
np.square(y).sum(axis=0).mean()
Out[19]:
In [20]:
# Distribution of all entries of y (rail() clamps them to [0, 1]) and its maximum.
# ravel() gives the same single dataset as the (size, 1) reshape, and avoids
# overwriting IPython's `_` output history.
fig, ax = plt.subplots()
ax.hist(y.ravel(), bins=100)
ax.set(title='Histogram of y entries', xlabel='y value', ylabel='count')
plt.show()
# Function-call form prints identically under Python 2 and also runs on Python 3.
print(np.amax(y))
In [21]:
# Squared L2 norm of each row of V.
np.square(V).sum(axis=1)
Out[21]:
In [22]:
# Compare the target code x with the network's reconstruction x_hat on the
# first few samples. A shared colour scale keeps the two images comparable;
# independent autoscaling would hide amplitude differences.
n_show = 100  # number of batch columns (samples) to display
panels = [('x (target)', x[:, :n_show]),
          ('x_hat (reconstruction)', x_hat[:, :n_show])]
vmin = min(panel.min() for _, panel in panels)
vmax = max(panel.max() for _, panel in panels)
for title, panel in panels:
    fig, ax = plt.subplots()
    image = ax.imshow(panel, interpolation='none', vmin=vmin, vmax=vmax)
    fig.colorbar(image, ax=ax)
    ax.set(title=title, xlabel='sample (column)', ylabel='row')
    plt.show()
In [23]:
# Square root of the batch-averaged squared reconstruction error
# (mathematically sqrt of the `cost` output of rnn_compute).
np.sqrt(np.square(x - x_hat).sum(axis=0).mean())
Out[23]:
In [13]: