In [1]:
# Import here
%matplotlib inline
import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
import numpy as np
import theano
import theano.tensor as T

# Default figure size (inches) for every plot in this notebook.
plt.rcParams['figure.figsize'] = (9.0, 6.0)

In [2]:
def cpc(x, N):
    """Encode `x` as N unit-variance Gaussian bumps, L2-normalised per column.

    `x` is scaled by (N - 1) and compared against the integer centres
    0, 1, ..., N-1 (held as an (N, 1) float32 column), so `x` must broadcast
    against that column; for an input of shape (1, B) the result has shape
    (N, B).

    Parameters
    ----------
    x : array_like
        Positions to encode (the training loop passes uniform samples
        from [0, 1)).
    N : int
        Number of bumps, i.e. rows of the result.

    Returns
    -------
    numpy.ndarray
        float32 array whose columns each have unit Euclidean norm.
    """
    centers = np.arange(N, dtype=np.float32).reshape((N, 1))
    offsets = np.float32(x * (N - 1)) - centers
    response = np.exp(-np.square(offsets) / 2)
    # Normalise along axis 0 so every column has unit L2 norm.
    column_norms = np.sqrt(np.sum(np.square(response), axis=0, keepdims=True))
    response /= column_norms
    return response

In [23]:
# Number of recurrent iterations run by theano.scan below; the same value is
# reused later to scale the initial W and the W/S/b learning rates.
n_steps = 100

def t_rectify(t_x):
    """Return the symbolic elementwise maximum of `t_x` and 0 (a rectifier)."""
    t_rectified = T.maximum(t_x, 0.0)
    return t_rectified

def t_one_step(t_z, t_S, t_b_eff):
    """One recurrent update of the state: rectify(S . z + b_eff).

    Used as the step function of theano.scan: `t_z` is the previous state,
    while `t_S` and `t_b_eff` are passed unchanged on every step.
    """
    t_pre_activation = T.dot(t_S, t_z) + t_b_eff
    return t_rectify(t_pre_activation)

# Symbolic float32 inputs: data `x`, additive noise `nu`, weight matrices
# S / W / V and the column biases b / c.
t_x, t_nu = T.fmatrices('x', 'nu')
t_S, t_W, t_V = T.fmatrices('S', 'W', 'V')
t_b, t_c = T.fcols('b', 'c')

# Feed-forward stage: y = rectify(V x + c); the noise is added to y before it
# is projected through W to give the constant drive of the recurrence.
t_y = t_rectify(T.dot(t_V, t_x) + t_c)
t_b_eff = T.dot(t_W, t_y + t_nu) + t_b

# Recurrent stage: start from zeros and apply t_one_step n_steps times.
t_z_seq, updates = theano.scan(
    t_one_step,
    outputs_info=T.zeros_like(t_b_eff),
    non_sequences=[t_S, t_b_eff],
    n_steps=n_steps,
)
t_x_hat = t_z_seq[-1]  # final state is taken as the reconstruction of x

# Squared error summed over axis 0 (rows), then averaged over what remains.
cost = T.mean(T.sum((t_x_hat - t_x) ** 2, axis=0))

# Gradients of the cost with respect to every parameter, in a single call.
t_dS, t_dW, t_dV, t_db, t_dc = T.grad(cost, [t_S, t_W, t_V, t_b, t_c])

# Compiled function: returns the cost, the reconstruction and all gradients.
f = theano.function(
    inputs=[t_x, t_nu, t_S, t_W, t_V, t_b, t_c],
    outputs=[cost, t_x_hat, t_dS, t_dW, t_dV, t_db, t_dc],
    updates=updates,
)

In [28]:
dim_unit = 64

# initialization
# V: random Gaussian matrix with every row scaled to unit L2 norm.
V = np.random.randn(dim_unit, dim_unit).astype(np.float32)
V /= np.sqrt(np.square(V).sum(axis=1, keepdims=True))
# W: transpose of V, scaled down by the number of recurrent steps.
W = V.T / n_steps
# S starts as the identity; both biases start at zero.
S = np.identity(dim_unit, dtype=np.float32)
b = np.zeros((dim_unit, 1), dtype=np.float32)
c = np.zeros_like(b)

In [34]:
# Noise scale and per-parameter learning rates; every rate is derived from
# eta_V by dividing by dim_unit and/or n_steps.
sigma = np.float32(1. / np.sqrt(dim_unit))
eta_V = np.float32(0.05)
eta_c = eta_V / dim_unit
eta_W = eta_V / n_steps
eta_S = eta_W / n_steps
eta_b = eta_W / dim_unit
dim_batch = 1024

# Plain gradient descent on freshly sampled batches.
for t in range(100000):
    # New batch: uniform samples, their cpc() encoding, and Gaussian noise.
    x = np.random.uniform(size=(1, dim_batch)).astype(np.float32)
    r = cpc(x, dim_unit)
    nu = sigma * np.random.randn(dim_unit, dim_batch).astype(np.float32)
    # NOTE: this rebinds the module-level names `cost` (until now the symbolic
    # expression used to build `f`) and `x_hat` to numeric results.
    cost, x_hat, dS, dW, dV, db, dc = f(r, nu, S, W, V, b, c)
    if (t + 1) % 10 == 0:
        # Single formatted string: prints "<iteration> : <sqrt(cost)>" the same
        # way under Python 2 and Python 3 (the old print statement would only
        # print the iteration number on Python 3).
        print('%d : %s' % (t + 1, np.sqrt(cost)))
    # In-place parameter updates.
    S -= eta_S * dS
    W -= eta_W * dW
    V -= eta_V * dV
    b -= eta_b * db
    c -= eta_c * dc
    # Re-project V so every row keeps unit L2 norm after the step.
    V /= np.sqrt(np.sum(V ** 2, axis=1, keepdims=True))


10 : 0.286571
20 : 0.278597
30 : 0.27697
40 : 0.280778
50 : 0.283951
60 : 0.275291
70 : 0.273751
80 : 0.278399
90 : 0.276142
100 : 0.272287
110 : 0.275641
120 : 0.271068
130 : 0.2755
140 : 0.273139
150 : 0.276744
160 : 0.275934
170 : 0.273977
180 : 0.27566
190 : 0.273721
200 : 0.266562
210 : 0.265552
220 : 0.269163
230 : 0.275493
240 : 0.270877
250 : 0.263074
260 : 0.268905
270 : 0.273318
280 : 0.274667
290 : 0.269911
300 : 0.267647
310 : 0.265164
320 : 0.271678
330 : 0.262505
340 : 0.269487
350 : 0.277756
360 : 0.270381
370 : 0.260855
380 : 0.270355
390 : 0.264649
400 : 0.269567
410 : 0.268137
420 : 0.263199
430 : 0.272687
440 : 0.26254
450 : 0.260367
460 : 0.266136
470 : 0.265777
480 : 0.268628
490 : 0.262194
500 : 0.263993
510 : 0.262497
520 : 0.258843
530 : 0.261001
540 : 0.262615
550 : 0.257322
560 : 0.258361
570 : 0.259855
580 : 0.259019
590 : 0.258621
600 : 0.25713
610 : 0.26026
620 : 0.254791
630 : 0.257685
640 : 0.257958
650 : 0.257385
660 : 0.255858
670 : 0.25336
680 : 0.255913
690 : 0.261236
700 : 0.254736
710 : 0.259723
720 : 0.25873
730 : 0.254876
740 : 0.256836
750 : 0.261133
760 : 0.252746
770 : 0.261524
780 : 0.252038
790 : 0.256676
800 : 0.251389
810 : 0.253071
820 : 0.258036
830 : 0.253677
840 : 0.256124
850 : 0.252319
860 : 0.249619
870 : 0.250136
880 : 0.254626
890 : 0.258207
900 : 0.252874
910 : 0.24648
920 : 0.25315
930 : 0.251707
940 : 0.254001
950 : 0.251566
960 : 0.247781
970 : 0.253993
980 : 0.250107
990 : 0.246598
1000 : 0.256752
1010 : 0.244312
1020 : 0.253776
1030 : 0.250083
1040 : 0.248485
1050 : 0.25219
1060 : 0.24769
1070 : 0.248035
1080 : 0.248566
1090 : 0.248257
1100 : 0.246937
1110 : 0.248074
1120 : 0.242897
1130 : 0.244621
1140 : 0.242703
1150 : 0.248194
1160 : 0.247071
1170 : 0.246309
1180 : 0.246202
1190 : 0.248741
1200 : 0.25229
1210 : 0.247475
1220 : 0.242532
1230 : 0.247505
1240 : 0.245457
1250 : 0.244231
1260 : 0.245212
1270 : 0.243171
1280 : 0.245756
1290 : 0.244453
1300 : 0.249497
1310 : 0.247503
1320 : 0.248681
1330 : 0.247488
1340 : 0.249092
1350 : 0.24955
1360 : 0.24315
1370 : 0.245771
1380 : 0.241264
1390 : 0.244
1400 : 0.247498
1410 : 0.243185
1420 : 0.238869
1430 : 0.236283
1440 : 0.243725
1450 : 0.243881
1460 : 0.241309
1470 : 0.239782
1480 : 0.245256
1490 : 0.242054
1500 : 0.24551
1510 : 0.240373
1520 : 0.241747
1530 : 0.235552
1540 : 0.241142
1550 : 0.237343
1560 : 0.240692
1570 : 0.239237
1580 : 0.238784
1590 : 0.238152
1600 : 0.232126
1610 : 0.240884
1620 : 0.238867
1630 : 0.242488
1640 : 0.236804
1650 : 0.240624
1660 : 0.241052
1670 : 0.241521
1680 : 0.240434
1690 : 0.237673
1700 : 0.241769
1710 : 0.239123
1720 : 0.240373
1730 : 0.238181
1740 : 0.244262
1750 : 0.238815
1760 : 0.234933
1770 : 0.240994
1780 : 0.235766
1790 : 0.236165
1800 : 0.235166
1810 : 0.235647
1820 : 0.242799
1830 : 0.237132
1840 : 0.238386
1850 : 0.237855
1860 : 0.234478
1870 : 0.238299
1880 : 0.238831
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-34-6739620eb1fe> in <module>()
     11     r = cpc(x, dim_unit)
     12     nu = sigma * np.random.randn(dim_unit, dim_batch).astype(np.float32)
---> 13     cost, x_hat, dS, dW, dV, db, dc = f(r, nu, S, W, V, b, c)
     14     if (t + 1) % 10 == 0:
     15         print (t + 1), ':', np.sqrt(cost)

/Users/chayut/anaconda/lib/python2.7/site-packages/theano/compile/function_module.pyc in __call__(self, *args, **kwargs)
    593         t0_fn = time.time()
    594         try:
--> 595             outputs = self.fn()
    596         except Exception:
    597             if hasattr(self.fn, 'position_of_error'):

/Users/chayut/anaconda/lib/python2.7/site-packages/theano/scan_module/scan_op.pyc in rval(p, i, o, n, allow_gc)
    633         allow_gc = config.allow_gc and not self.allow_gc
    634 
--> 635         def rval(p=p, i=node_input_storage, o=node_output_storage, n=node,
    636                  allow_gc=allow_gc):
    637             r = p(n, [x[0] for x in i], o)

KeyboardInterrupt: 

In [7]:
# Load the arrays stored in the saved .npz archive into result_* variables.
# NOTE(review): the code that wrote this file is not in the notebook - confirm
# it corresponds to the training run above.
with np.load('/Users/chayut/Dropbox/data/grid_cells/opt_coding_1d_tmp_save.npz') as result:
    (result_S, result_W, result_V,
     result_b, result_c, result_x_hat) = (
        result[key] for key in ('S', 'W', 'V', 'b', 'c', 'x_hat'))

In [73]:
# Heat map of the loaded V matrix, with a colour scale for that image.
v_image = plt.imshow(result_V, interpolation='None')
plt.colorbar(v_image)


Out[73]:
<matplotlib.colorbar.Colorbar instance at 0x12981d368>

In [74]:
# Profile of row 13 of the loaded V matrix.
v_row = result_V[13]
plt.plot(v_row)


Out[74]:
[<matplotlib.lines.Line2D at 0x12aaea110>]

In [75]:
# Magnitude of the real-input FFT of row 13 of the loaded V matrix.
row_spectrum = np.abs(np.fft.rfft(result_V[13]))
plt.plot(row_spectrum)
plt.grid()



In [20]:
# Show `r`; in this notebook that name is bound by the training loop to the
# cpc() encoding of the most recent batch.
plt.imshow(r, interpolation='None')


Out[20]:
<matplotlib.image.AxesImage at 0x12031f410>

In [21]:
# Show `x_hat`; in this notebook that name is bound by the training loop to
# the second output of `f` (the network's reconstruction) for the last batch.
plt.imshow(x_hat, interpolation='None')


Out[21]:
<matplotlib.image.AxesImage at 0x12088ef10>

In [2]:
def create_decoder():
    """Compile and return a Theano function computing sigmoid(A . y).

    The returned callable takes two float32 matrices, `A` and `y` (in that
    order), and returns the elementwise sigmoid of their matrix product.
    """
    t_A, t_y = T.fmatrices('A', 'y')
    t_decoded = T.nnet.sigmoid(T.dot(t_A, t_y))
    decoder_fn = theano.function(inputs=[t_A, t_y], outputs=t_decoded)
    return decoder_fn

In [6]:
# Compile the sigmoid decoder and load the array that is later passed to it
# as its first (`A`) argument.
decoder = create_decoder()
dec_mat = np.load('/Users/chayut/Dropbox/data/grid_cells/cpc_decoding_vector.npy')

In [11]:
# Apply the decoder to the saved reconstruction: sigmoid(dec_mat . result_x_hat).
decoder(dec_mat, result_x_hat)


Out[11]:
array([[ 0.27246067,  0.17158265,  0.59323031, ...,  0.48518836,
         0.68330532,  0.4247331 ]], dtype=float32)

In [ ]: