In [1]:
# Import here
%matplotlib inline
import os

import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
import numpy as np
import theano
import theano.tensor as T

# Default figure size for every plot in this notebook. pyplot exposes the same
# rcParams object as pylab, without relying on the discouraged pylab interface.
plt.rcParams['figure.figsize'] = (9.0, 6.0)


Using gpu device 0: Quadro K6000

In [2]:
def cpc(x, N, width=1.0):
    """Encode positions with a population of N Gaussian tuning curves.

    Unit ``n`` (``n = 0 .. N-1``) responds maximally at ``x = n / (N - 1)``,
    so for ``N > 1`` the preferred positions tile ``[0, 1]`` evenly. After
    the Gaussian responses are computed, each column is rescaled to unit
    Euclidean (L2) norm.

    Parameters
    ----------
    x : array_like
        Positions to encode. Broadcast against a ``(N, 1)`` column of unit
        indices, so a ``(1, batch)`` row yields an ``(N, batch)`` result.
    N : int
        Number of units in the population.
    width : float, optional
        Standard deviation of each Gaussian tuning curve, measured in units
        of the spacing between neighbouring preferred positions. The default
        of 1.0 gives the curve ``exp(-z**2 / 2)``.

    Returns
    -------
    numpy.ndarray
        float32 responses; every column has unit L2 norm.

    Raises
    ------
    ValueError
        If ``width`` is not positive.
    """
    if width <= 0:
        raise ValueError('width must be positive, got %r' % (width,))
    n = np.arange(N, dtype=np.float32).reshape((N, 1))
    # Signed distance of each sample from each unit's preferred position,
    # expressed in tuning-curve widths. Casting width to float32 keeps the
    # whole computation in float32.
    z_matrix = (np.float32(x * (N - 1)) - n) / np.float32(width)
    r = np.exp(-z_matrix ** 2 / 2)
    # Normalize each column (one encoded sample) to unit L2 norm.
    r /= np.sqrt(np.sum(r ** 2, axis=0, keepdims=True))
    return r

In [14]:
# Symbolic graph for a linear decoder  x_hat = A . (y + nu) + 0.5  and the
# gradient of its mean squared error with respect to the weights A.
#   t_x  : target positions
#   t_y  : population-code responses
#   t_nu : additive noise on the responses
#   t_A  : decoding weights
# All symbolic names carry the ``t_`` prefix so that numeric results returned
# by ``f`` cannot shadow them.
t_x, t_y, t_nu = T.fmatrices('x', 'y', 'nu')
t_A = T.fmatrix(name='A')
# Fixed (not learned) bias: 0.5 is the midpoint of [0, 1), the range that the
# training cell draws x from with np.random.uniform.
t_xhat = T.dot(t_A, t_y + t_nu) + 0.5
t_cost = T.mean((t_xhat - t_x) ** 2)
t_dA = T.grad(t_cost, t_A)
# f(x, y, nu, A) -> [mean squared error, decoded positions, d(cost)/dA]
f = theano.function(inputs=[t_x, t_y, t_nu, t_A],
                    outputs=[t_cost, t_xhat, t_dA])

In [17]:
# Population size: number of units (rows) produced by cpc(x, N).
N = 64
# Initial decoding weights: a float32 (1, N) row ramping linearly from
# -(N - 1)/N for unit 0 up to +(N - 1)/N for unit N - 1.
A = (np.arange(N, dtype=np.float32)[np.newaxis, :] * 2 - (N - 1)) / N

In [21]:
# Train the decoding weights A by gradient descent, drawing a fresh random
# batch of positions at every step and periodically printing the RMS error.
# NOTE: A is updated in place, so re-running this cell resumes from the
# current A. Re-run the initialization cell first to start from scratch.
SEED = 0               # fixed seed so the run is reproducible
sigma = 0.000          # std of additive Gaussian noise on the responses
dim_batch = 2 ** 14    # samples per gradient step
eta = np.float32(0.1)  # learning rate
n_steps = 1000
report_every = 100
rng = np.random.RandomState(SEED)
for t in range(n_steps):
    x = rng.uniform(size=(1, dim_batch)).astype(np.float32)
    y = cpc(x, N)
    nu = sigma * rng.randn(N, dim_batch).astype(np.float32)
    mse, xhat, dA = f(x, y, nu, A)
    if (t + 1) % report_every == 0:
        # Single-argument print: identical output under Python 2 and 3.
        print('{} : {}'.format(t + 1, np.sqrt(mse)))
    A -= eta * dA


100 : 0.0010572
200 : 0.00104923
300 : 0.00102008
400 : 0.00104077
500 : 0.00101465
600 : 0.00102516
700 : 0.000995324
800 : 0.00103729
900 : 0.00101906
1000 : 0.000968446

In [22]:
# Learned decoding weights, one value per population unit.
fig, ax = plt.subplots()
ax.plot(A.ravel())
ax.set_title('Learned linear decoding vector A')
ax.set_xlabel('unit index n')
ax.set_ylabel('decoding weight')
plt.show()


Out[22]:
[<matplotlib.lines.Line2D at 0x12107ee10>]

In [23]:
# Persist the trained decoding vector. The path is resolved from the user's
# home directory rather than hard-coding a /Users/<name> prefix.
decoder_path = os.path.expanduser(
    '~/Dropbox/data/grid_cells/cpc_decoding_vector_linear.npy')
decoder_dir = os.path.dirname(decoder_path)
if not os.path.isdir(decoder_dir):
    os.makedirs(decoder_dir)
np.save(decoder_path, A)

In [24]:
# Reload the saved decoding vector and confirm it round-trips exactly to the
# in-memory A (np.save/np.load are lossless).
A_loaded = np.load(os.path.expanduser(
    '~/Dropbox/data/grid_cells/cpc_decoding_vector_linear.npy'))
np.testing.assert_array_equal(A_loaded, A)
A_loaded


Out[24]:
array([[-0.41830656, -0.19770402, -0.25956628, -0.24929924, -0.22035065,
        -0.22938499, -0.21490771, -0.20411801, -0.19991897, -0.18943271,
        -0.18073294, -0.17321098, -0.16439231, -0.15584508, -0.14754376,
        -0.13920587, -0.13057926, -0.12226155, -0.11386252, -0.10530018,
        -0.09703528, -0.08846128, -0.08010851, -0.07167848, -0.06320387,
        -0.05479869, -0.04640026, -0.03789746, -0.02952373, -0.02109368,
        -0.01261051, -0.00424446,  0.00422625,  0.01264855,  0.0210708 ,
         0.02951011,  0.03793552,  0.04637125,  0.05479787,  0.06322136,
         0.07168344,  0.08004296,  0.08859339,  0.09687582,  0.10542083,
         0.11383771,  0.1221714 ,  0.13075542,  0.13901037,  0.14767955,
         0.15583062,  0.16427414,  0.17340688,  0.18056643,  0.18947446,
         0.19999462,  0.2040641 ,  0.21475542,  0.22977965,  0.21995682,
         0.24926494,  0.26027727,  0.19648869,  0.41929153]], dtype=float32)

In [ ]: