In [1]:
# Import here
%matplotlib inline
import os

import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
import numpy as np
import theano
import theano.tensor as T

# Default size (width, height in inches) of every inline figure in this notebook.
plt.rcParams['figure.figsize'] = (9.0, 6.0)

In [26]:
# architecture params
dim_unit = 64        # number of units (rows of the input code x and of the noise nu)
n_steps = 100        # number of recurrent steps unrolled by theano.scan
n_batch = 1024       # samples of u drawn per training iteration
n_iter = 1000000     # total number of training iterations
disp_every = 10      # print the RMSE every `disp_every` iterations
acce_every = 250     # refresh learning rates / momentum every `acce_every` iterations
save_every = 1000    # write a checkpoint every `save_every` iterations
sigma = np.float32(1. / np.sqrt(dim_unit))  # std of the Gaussian noise nu added to y

# learning rate params (initial values; the training loop decays them over time)
eta_A0 = np.float32(0.001)
eta_V0 = np.float32(0.001)
# NOTE(review): the W and S rates are scaled down by n_steps and the bias rates
# by dim_unit, presumably to balance gradient magnitudes -- confirm.
eta_W0 = np.float32(eta_V0 / n_steps)
eta_S0 = np.float32(eta_W0 / n_steps)
eta_c0 = np.float32(eta_V0 / dim_unit)
eta_b0 = np.float32(eta_W0 / dim_unit)
# The training loop multiplies every rate by 1 / (1 + eta_ratio * i / n_iter),
# so by the last iteration the rates are 1 / (1 + eta_ratio) of their start values.
eta_ratio = np.float32(100)
# Momentum bounds. `np.float` was merely an alias of the builtin `float` and was
# removed in NumPy 1.24, so plain Python floats are used (same value and type).
mu_max = 0.999       # upper bound on the momentum coefficient
mu_ending = 0.9      # momentum used during the last 10% of iterations

# Data location: every input and checkpoint lives under one directory.
# Set the GRID_CELLS_DATA_DIR environment variable to run on another machine;
# the default is the original author's directory.
DATA_DIR = os.environ.get('GRID_CELLS_DATA_DIR',
                          '/Users/chayut/Dropbox/data/grid_cells')

# save file path: periodic checkpoint during training, and the final result
tmp_save = os.path.join(DATA_DIR, 'opt_coding_1d_fine_tuning_tmp_save.npz')
fin_save = os.path.join(DATA_DIR, 'opt_coding_1d_fine_tuning_fin_save.npz')

# initialization: start from parameters saved by an earlier run
# A maps the final recurrent state x_hat to the decoded position u_hat.
A = np.load(os.path.join(DATA_DIR, 'cpc_decoding_vector.npy'))
with np.load(os.path.join(DATA_DIR, 'opt_coding_1d_fin_save.npz')) as result:
    S = result['S']  # recurrent weights
    W = result['W']  # weights from (noisy) y into the recurrent drive
    V = result['V']  # weights from the input code x onto y
    b = result['b']  # bias of the recurrent drive
    c = result['c']  # bias of y

# Momentum (velocity) buffers, one per parameter, starting at rest.
vA = np.zeros_like(A)
vS = np.zeros_like(S)
vV = np.zeros_like(V)
vW = np.zeros_like(W)
vb = np.zeros_like(b)
vc = np.zeros_like(c)

# Current learning rates start at their initial values; the training loop
# rescales them every `acce_every` iterations.
eta_A = eta_A0
eta_V = eta_V0
eta_W = eta_W0
eta_S = eta_S0
eta_c = eta_c0
eta_b = eta_b0
# Initial momentum, used until the first schedule update. `np.float` (an alias
# of the builtin float) was removed in NumPy 1.24; a plain float is equivalent.
mu = 0.5

In [27]:
class VariableHolder(object):
    """Empty attribute bag that groups the model's symbolic Theano variables."""
    
def rectify(x):
    """Rectified linear unit: elementwise max(x, 0), built as a Theano expression."""
    floor = 0
    return T.maximum(x, floor)

def one_step(z, S, b_eff):
    """Scan step: return the next state rectify(S . z + b_eff).

    theano.scan passes the previous state first, then the non_sequences
    (S, b_eff) in the order they were given.
    """
    drive = T.dot(S, z) + b_eff
    return rectify(drive)

# Centres of the Gaussian input tuning curves: a column vector 0, 1, ..., dim_unit - 1.
sparse_basis = np.arange(dim_unit, dtype=np.float32).reshape((dim_unit, 1))

t = VariableHolder()
# Inputs: u is a row of positions, one per sample (the training loop draws them
# from [0, 1)); nu is additive noise on y.
t.u = T.frow(name='u')
t.nu = T.fmatrix(name='nu')
# Parameters are plain function inputs; they are updated in NumPy, not via shared variables.
t.A, t.S, t.W, t.V = T.fmatrices('A', 'S', 'W', 'V')
# NOTE(review): T.cols uses theano.config.floatX (float64 unless configured
# otherwise), unlike the float32 matrices above; with float64 this upcasts y,
# b_eff and the whole scan to float64. T.fcols would be consistent but rejects
# float64 arrays -- confirm the dtype of the saved b, c before switching.
t.b, t.c = T.cols('b', 'c')

# Input code x: Gaussian bump (std 1 in unit-index space) centred at
# u * (dim_unit - 1), normalised to unit L2 norm per sample (column).
t.u_dist = t.u * np.float32(dim_unit - 1) - sparse_basis
t.x0 = T.exp(-(t.u_dist ** 2 / 2))
t.x = t.x0 / T.sqrt(T.sum(t.x0 ** 2, axis=0, keepdims=True))
# Rectified layer y, then a constant per-sample drive into the recurrence
# (noise nu is injected into y here).
t.y = rectify(T.dot(t.V, t.x) + t.c)
t.b_eff = t.b + T.dot(t.W, t.y + t.nu)
# Iterate z <- rectify(S z + b_eff) for n_steps, starting from z = 0.
t.z_seq, t.updates = theano.scan(fn=one_step,
                                 outputs_info=T.zeros_like(t.b_eff),
                                 non_sequences=[t.S, t.b_eff],
                                 n_steps=n_steps)
t.x_hat = t.z_seq[-1]
# Decode the final state to a value in (0, 1); mean squared error vs. u.
t.u_hat = T.nnet.sigmoid(T.dot(t.A, t.x_hat))
t.cost = T.mean((t.u_hat - t.u) ** 2)

# One T.grad call over all parameters builds a single backward pass through
# the scan instead of six separate ones.
t.dA, t.dS, t.dW, t.dV, t.db, t.dc = T.grad(
    t.cost, wrt=[t.A, t.S, t.W, t.V, t.b, t.c])

# Returns the cost, intermediate activations and the six gradients, in this order.
rnn_compute = theano.function(inputs=[t.u, t.nu, t.A, t.S, t.W, t.V, t.b, t.c],
                              outputs=[t.cost, t.x, t.y, t.x_hat, t.u_hat,
                                       t.dA, t.dS, t.dW, t.dV, t.db, t.dc],
                              updates=t.updates)

In [28]:
def _snapshot():
    """Collect the current training state as keyword arguments for np.savez."""
    return dict(cost=cost, u=u, u_hat=u_hat, x=x, x_hat=x_hat, y=y,
                A=A, S=S, W=W, V=V, b=b, c=c)


def _normalize_rows(M):
    """Rescale each row of M to unit L2 norm, in place."""
    M /= np.sqrt(np.sum(M ** 2, axis=1, keepdims=True))


# Parameters and their momentum buffers, paired by position. Every update below
# is in place, so these lists keep referring to the same arrays as A, S, ..., vc.
params = [A, S, W, V, b, c]
velocities = [vA, vS, vW, vV, vb, vc]

for i in range(n_iter):
    i1 = i + 1
    # Every acce_every iterations: decay the learning rates as
    # 1 / (1 + eta_ratio * i1 / n_iter) and ramp the momentum towards mu_max,
    # switching to mu_ending for the last 10% of iterations.
    if i1 % acce_every == 0:
        eta_factor = 1.0 / (1.0 + (eta_ratio * i1) / n_iter)
        eta_A = eta_A0 * eta_factor
        eta_S = eta_S0 * eta_factor
        eta_W = eta_W0 * eta_factor
        eta_V = eta_V0 * eta_factor
        eta_b = eta_b0 * eta_factor
        eta_c = eta_c0 * eta_factor
        if i1 < 0.9 * n_iter:
            # i1 is a multiple of acce_every here, so // is exact
            mu_unbounded = 1 - 0.5 / (1 + (i1 // acce_every))
            mu = np.minimum(mu_max, mu_unbounded)
        else:
            mu = mu_ending
    etas = [eta_A, eta_S, eta_W, eta_V, eta_b, eta_c]

    # Nesterov-style momentum: first move the weights along the decayed velocity ...
    for p, v in zip(params, velocities):
        v *= mu
        p += v
    _normalize_rows(V)

    # ... then evaluate the gradient at that look-ahead point on a fresh batch.
    u = np.random.uniform(size=(1, n_batch)).astype(np.float32)
    nu = sigma * np.random.randn(dim_unit, n_batch).astype(np.float32)
    cost, x, y, x_hat, u_hat, dA, dS, dW, dV, db, dc = rnn_compute(u, nu, A, S, W, V, b, c)
    if i1 % disp_every == 0:
        # single-argument print: same output under Python 2 and Python 3
        print('%d : %s' % (i1, np.sqrt(cost)))

    # Apply the gradient step to both the velocity and the weights.
    for p, v, eta, g in zip(params, velocities, etas, [dA, dS, dW, dV, db, dc]):
        step = eta * g
        v -= step
        p -= step
    _normalize_rows(V)

    if i1 % save_every == 0:
        np.savez(tmp_save, **_snapshot())

# Training finished: write the final state to its dedicated file
# (fin_save was previously defined but never written).
np.savez(fin_save, **_snapshot())


10 : 0.00986188
20 : 0.010486
30 : 0.011162
40 : 0.0096735
50 : 0.0103571
60 : 0.0105638
70 : 0.00994699
80 : 0.0099207
90 : 0.0100439
100 : 0.0108662
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-28-a6eb478a3d61> in <module>()
     33     u = np.random.uniform(size=(1, n_batch)).astype(np.float32)
     34     nu = sigma * np.random.randn(dim_unit, n_batch).astype(np.float32)
---> 35     cost, x, y, x_hat, u_hat, dA, dS, dW, dV, db, dc = rnn_compute(u, nu, A, S, W, V, b, c)
     36     if i1 % disp_every == 0:
     37         print i1, ':', np.sqrt(cost)

/Users/chayut/anaconda/lib/python2.7/site-packages/theano/compile/function_module.pyc in __call__(self, *args, **kwargs)
    593         t0_fn = time.time()
    594         try:
--> 595             outputs = self.fn()
    596         except Exception:
    597             if hasattr(self.fn, 'position_of_error'):

/Users/chayut/anaconda/lib/python2.7/site-packages/theano/scan_module/scan_op.pyc in rval(p, i, o, n, allow_gc)
    635         def rval(p=p, i=node_input_storage, o=node_output_storage, n=node,
    636                  allow_gc=allow_gc):
--> 637             r = p(n, [x[0] for x in i], o)
    638             for o in node.outputs:
    639                 compute_map[o][0] = True

/Users/chayut/anaconda/lib/python2.7/site-packages/theano/scan_module/scan_op.pyc in <lambda>(node, args, outs)
    624                         args,
    625                         outs,
--> 626                         self, node)
    627         except (ImportError, theano.gof.cmodule.MissingGXX):
    628             p = self.execute

/Users/chayut/.theano/compiledir_Darwin-14.1.0-x86_64-i386-64bit-i386-2.7.9-64/scan_perform/scan_perform.so in theano.scan_module.scan_perform.perform (/Users/chayut/.theano/compiledir_Darwin-14.1.0-x86_64-i386-64bit-i386-2.7.9-64/scan_perform/mod.cpp:4578)()

/Users/chayut/anaconda/lib/python2.7/site-packages/theano/tensor/type.pyc in value_zeros(self, shape)
    581             return ()
    582 
--> 583     def value_zeros(self, shape):
    584         """
    585         Create an numpy ndarray full of 0 values.

KeyboardInterrupt: 

In [17]:
# Weights V: entry (i, j) multiplies input unit x_j in the pre-activation of y_i
# (y = rectify(V x + c)); the training loop keeps each row at unit L2 norm.
fig, ax = plt.subplots()
# 'none' (lowercase string) draws raw pixels; Python None would mean "rc default".
im = ax.imshow(V, interpolation='none')
ax.set_title('Weights V (x -> y)')
ax.set_xlabel('input unit j (x)')
ax.set_ylabel('unit i (y)')
fig.colorbar(im, ax=ax);


Out[17]:
<matplotlib.image.AxesImage at 0x11d676f90>

In [ ]: