In [1]:
# Import here
%matplotlib inline
import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
import numpy as np
import theano
import theano.tensor as T

# Default size (width, height in inches) for every figure drawn in this notebook.
default_figsize = (9.0, 6.0)
pylab.rcParams['figure.figsize'] = default_figsize

In [17]:
# Symbolic graph: iterate the linear recurrence y <- S.y + W.x for n_steps
# steps starting from zeros, then score the final state against the input x.
n_steps = 100

# Symbolic float32 matrices: the input x, the recurrent matrix S and the
# input matrix W.
x = T.fmatrix(name='x')
S = T.fmatrix(name='S')
W = T.fmatrix(name='W')

# Input drive W.x; it is the same at every iteration, so build it once.
b_eff = T.dot(W, x)


def one_step(y_prev, S_mat, drive):
    """Return the next state ``S_mat . y_prev + drive`` of the recurrence."""
    return T.dot(S_mat, y_prev) + drive


# scan passes the previous output first, then the non_sequences in order.
y_seq, updates = theano.scan(
    fn=one_step,
    outputs_info=T.zeros_like(b_eff),
    non_sequences=[S, b_eff],
    n_steps=n_steps,
)
y = y_seq[-1]  # state after the last iteration

# Squared error summed over axis 0, then averaged over the remaining axis.
cost = T.mean(T.sum((x - y) ** 2, axis=0, keepdims=True))
dS, dW = T.grad(cost, [S, W])

# Compiled function: (x, S, W) -> [cost, dcost/dS, dcost/dW, final state y].
f = theano.function(
    inputs=[x, S, W],
    outputs=[cost, dS, dW, y],
    updates=updates,
)

In [22]:
# Problem size: rows of each input matrix (dim_x) and number of columns fed
# per training batch (dim_batch).
dim_x = 64
dim_batch = 256

# S starts as the identity matrix.
val_S = np.eye(dim_x, dtype=np.float32)

# W starts as Gaussian noise scaled by 1/sqrt(dim_x) and by 1/n_steps.
# Both divisors are explicit float32 scalars: np.sqrt returns a float64 scalar,
# and under NumPy 2 promotion rules (NEP 50) dividing a float32 array by it
# would silently produce a float64 array, which no longer matches the float32
# matrix declared for W in the compiled graph. Under legacy promotion the
# values are unchanged.
val_W = (np.random.randn(dim_x, dim_x).astype(np.float32)
         / np.float32(np.sqrt(dim_x))
         / np.float32(n_steps))

In [24]:
# Plain gradient-descent step sizes, scaled by 1/dim_batch; the step for S is
# additionally divided by n_steps.
eta_W = 0.1 / dim_batch
eta_S = 0.1 / dim_batch / n_steps

n_iters = 1000     # number of gradient steps per run of this cell
report_every = 10  # print the cost once every this many steps

for t in range(n_iters):
    # Fresh Gaussian batch of shape (dim_x, dim_batch) on every step.
    val_x = np.random.randn(dim_x, dim_batch).astype(np.float32)
    val_cost, val_dS, val_dW, val_y = f(val_x, val_S, val_W)
    if (t + 1) % report_every == 0:
        # One pre-formatted argument prints "step cost" on both Python 2 and 3.
        # The old `print (t + 1), val_cost` drops the cost under Python 3.
        print('%d %s' % (t + 1, val_cost))
    # In-place updates: re-running this cell continues from the current
    # val_S / val_W rather than restarting from their initial values.
    val_S -= eta_S * val_dS
    val_W -= eta_W * val_dW


10 70.2283554077
20 58.7585639954
30 51.4939842224
40 43.9743728638
50 39.5721511841
60 34.2540283203
70 28.9766349792
80 26.5210094452
90 23.2932872772
100 21.7147140503
110 19.7919521332
120 18.0458621979
130 15.7576503754
140 14.2348108292
150 12.7208366394
160 11.2104034424
170 11.1228265762
180 10.1094770432
190 8.88823795319
200 8.09349822998
210 7.41362094879
220 6.91442012787
230 6.06851100922
240 5.5574092865
250 4.96092796326
260 4.79374551773
270 4.8021364212
280 4.51977014542
290 4.24153661728
300 3.95704865456
310 3.62325525284
320 3.59461092949
330 3.16009163857
340 2.90127444267
350 2.45963668823
360 2.49890470505
370 2.35923743248
380 2.15628027916
390 2.08569288254
400 2.00370645523
410 1.69234800339
420 1.95856952667
430 1.63814485073
440 1.58379006386
450 1.54001557827
460 1.41342389584
470 1.24092912674
480 1.19672417641
490 1.10880458355
500 1.07253813744
510 1.03193199635
520 1.08108055592
530 0.955024003983
540 0.866100728512
550 0.861440658569
560 0.696216404438
570 0.733282744884
580 0.767969310284
590 0.59978967905
600 0.645955085754
610 0.628410041332
620 0.528648495674
630 0.514638841152
640 0.477892398834
650 0.445438295603
660 0.516854822636
670 0.46590834856
680 0.402379214764
690 0.391501277685
700 0.379371553659
710 0.411156445742
720 0.332494288683
730 0.268220216036
740 0.288221120834
750 0.266344666481
760 0.277551949024
770 0.27564522624
780 0.227135062218
790 0.257936567068
800 0.224724963307
810 0.210002377629
820 0.200717478991
830 0.178926870227
840 0.180963680148
850 0.176275461912
860 0.152346223593
870 0.147061794996
880 0.170192927122
890 0.131817892194
900 0.151087611914
910 0.133023500443
920 0.137734174728
930 0.101545706391
940 0.106590211391
950 0.124645128846
960 0.0900602787733
970 0.0960389003158
980 0.0834612026811
990 0.0895466953516
1000 0.0854084566236

In [25]:
# Show the trained matrix S as text and as an image.
# print(...) with a single argument behaves the same on Python 2 and Python 3;
# the bare `print val_S` statement is a SyntaxError on Python 3.
print(val_S)
plt.imshow(val_S, interpolation='None')


[[  7.91551828e-01   1.09368796e-02   8.90446454e-03 ...,   9.46536602e-04
   -1.09681813e-02  -2.88536400e-02]
 [ -4.43507422e-04   7.99930394e-01  -9.05676279e-03 ...,  -4.71686348e-02
    2.56220158e-02  -1.41058501e-03]
 [  2.22858787e-03  -2.90666968e-02   8.28683913e-01 ...,  -1.49826519e-03
   -1.80445565e-03  -1.26732010e-02]
 ..., 
 [  3.53699573e-03  -3.98051888e-02   1.74171582e-03 ...,   8.09756994e-01
    9.91128106e-03  -9.01053566e-03]
 [ -1.42429871e-02   3.22850645e-02  -6.38222974e-03 ...,   6.68446906e-03
    8.07649314e-01  -2.28243377e-02]
 [ -2.03520227e-02   1.53767061e-03  -3.90229234e-03 ...,  -9.68658552e-03
   -2.98266788e-03   8.06548715e-01]]
Out[25]:
<matplotlib.image.AxesImage at 0x114877f50>

In [26]:
# Show the trained matrix W as text and as an image.
# print(...) with a single argument behaves the same on Python 2 and Python 3;
# the bare `print val_W` statement is a SyntaxError on Python 3.
print(val_W)
plt.imshow(val_W, interpolation='None')


[[  1.94576979e-01  -9.00759269e-03  -8.84541078e-04 ...,   1.66932924e-03
    1.23289805e-02   2.62517035e-02]
 [  4.65040193e-06   1.98400542e-01   7.54676294e-03 ...,   4.57942858e-02
   -2.46644337e-02   1.70087279e-03]
 [ -4.66752192e-03   2.87849307e-02   1.72156021e-01 ...,   1.45561306e-03
    2.42715818e-03   1.21286856e-02]
 ..., 
 [ -5.88082150e-03   3.92963067e-02  -1.02771446e-03 ...,   1.89997867e-01
   -9.23397858e-03   8.56697652e-03]
 [  8.21293332e-03  -3.07865757e-02   1.08245527e-02 ...,  -4.96891467e-03
    1.92571193e-01   2.13590879e-02]
 [  8.59491248e-03  -2.48973083e-04   1.06977308e-02 ...,   1.16584627e-02
    4.31708433e-03   1.91000864e-01]]
Out[26]:
<matplotlib.image.AxesImage at 0x11bc978d0>

In [ ]: