In [1]:
# Import here
%matplotlib inline
import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
import numpy as np
import theano
import theano.tensor as T

pylab.rcParams['figure.figsize'] = (9.0, 6.0)


Using gpu device 0: Quadro K6000

In [2]:
x = T.fmatrix(name='x')
A = T.fmatrix(name='A')
y = T.dot(A, x)
cost = T.mean(T.sum((x - y) ** 2, axis=0, keepdims=True))
dA = T.grad(cost, A)
f = theano.function(inputs=[x, A],
                    outputs=[cost, dA])

In [3]:
dim_x = 5
dim_batch = 100
val_A = np.random.randn(dim_x, dim_x).astype(np.float32) / np.sqrt(5)

In [4]:
eta = 10./dim_batch
for t in range(100):
    val_x = np.random.randn(dim_x, dim_batch).astype(np.float32)
    val_cost, val_dA = f(val_x, val_A)
    print t, val_cost
    val_A -= eta * val_dA


0 8.32910060883
1 6.92756414413
2 4.27359676361
3 2.86882805824
4 1.97431242466
5 1.00909984112
6 0.819539368153
7 0.437326788902
8 0.324858993292
9 0.193760022521
10 0.142531827092
11 0.0822527185082
12 0.0540464706719
13 0.040549043566
14 0.022033052519
15 0.0134007418528
16 0.0100008938462
17 0.00461608963087
18 0.00421692058444
19 0.00183940376155
20 0.00132594408933
21 0.000980254495516
22 0.000485067983391
23 0.000381104386179
24 0.000309171824483
25 0.000199684334802
26 0.0001245985477
27 8.65358961164e-05
28 5.68947725696e-05
29 4.20140240749e-05
30 2.24167106353e-05
31 1.61136958923e-05
32 9.84169946605e-06
33 5.63978619539e-06
34 4.88622936246e-06
35 2.50275070357e-06
36 2.23581696446e-06
37 1.41204952797e-06
38 8.32138312035e-07
39 6.13968722973e-07
40 3.36979070426e-07
41 2.40306462729e-07
42 1.17419453716e-07
43 9.01721435298e-08
44 4.64724223548e-08
45 3.1648550447e-08
46 2.7525269175e-08
47 1.14150378039e-08
48 7.45308081918e-09
49 6.08810513114e-09
50 3.03588354456e-09
51 2.14647144503e-09
52 1.18109322322e-09
53 9.78434999688e-10
54 5.48379230914e-10
55 3.48715972764e-10
56 2.2321536286e-10
57 2.02274350047e-10
58 1.01764610627e-10
59 4.98017981021e-11
60 4.0773225074e-11
61 2.54800434224e-11
62 1.66343067792e-11
63 9.29458593424e-12
64 7.08583192122e-12
65 6.15729472617e-12
66 2.76153535803e-12
67 1.98529886107e-12
68 1.24541793978e-12
69 7.84922257399e-13
70 5.20341582322e-13
71 3.53174466668e-13
72 2.09241085287e-13
73 1.09282557775e-13
74 1.05723867432e-13
75 9.73627781045e-14
76 7.98859811852e-14
77 6.22137839293e-14
78 6.32424207404e-14
79 4.3569477453e-14
80 4.73710900692e-14
81 5.03374943182e-14
82 4.25538409035e-14
83 5.04752049348e-14
84 3.67735864275e-14
85 3.64825865644e-14
86 3.70116501195e-14
87 3.63463531733e-14
88 3.86005179101e-14
89 3.78067091251e-14
90 4.03179144801e-14
91 3.75250876108e-14
92 4.11568091348e-14
93 3.79947233343e-14
94 3.83544040169e-14
95 3.37250572521e-14
96 3.90745310999e-14
97 3.70023192045e-14
98 3.58983947169e-14
99 3.73895182935e-14

In [5]:
print val_A
plt.imshow(val_A, interpolation='None')


[[  9.99999940e-01   4.75354422e-09  -2.03632888e-10   2.70868883e-09
   -4.14564605e-09]
 [ -2.72381229e-10   9.99999940e-01   4.27727542e-09  -7.33849204e-10
   -2.05646056e-09]
 [ -2.34329400e-09   3.51361029e-09   9.99999940e-01   1.91649829e-09
    9.01530139e-09]
 [  5.74596370e-10   4.31621316e-09   2.70990896e-09   9.99999940e-01
   -4.09515266e-09]
 [  7.72587105e-10  -4.96816099e-09   4.32309211e-09   1.30315003e-09
    9.99999940e-01]]
Out[5]:
<matplotlib.image.AxesImage at 0x12018e050>

In [5]: