In [1]:
# Import here
%matplotlib inline
import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
import numpy as np
import theano
import theano.tensor as T

pylab.rcParams['figure.figsize'] = (9.0, 6.0)

In [2]:
n_steps = 10
x = T.fmatrix(name='x')
A = T.fmatrix(name='A')
def one_step(x, A):
    return T.dot(A, x)
y_seq, updates = theano.scan(fn=one_step,
                             outputs_info=x,
                             non_sequences=[A],
                             n_steps=n_steps)
y=y_seq[-1]
cost = T.mean(T.sum((x - y) ** 2, axis=0, keepdims=True))
dA = T.grad(cost, A)
f = theano.function(inputs=[x, A],
                    outputs=[cost, dA, y],
                    updates=updates)


/Users/chayut/anaconda/lib/python2.7/site-packages/theano/scan_module/scan_perform_ext.py:133: RuntimeWarning: numpy.ndarray size changed, may indicate binary incompatibility
  from scan_perform.scan_perform import *

In [3]:
dim_x = 5
dim_batch = 256
val_A = np.random.randn(dim_x, dim_x).astype(np.float32) / np.sqrt(dim_x)

In [4]:
eta = 0.5 / dim_batch
mu = 0.5
val_dA = np.zeros_like(val_A)
val_vA = np.zeros_like(val_A)

for t in range(10000):
    val_vA *= mu
    val_A += val_vA
    val_x = np.random.randn(dim_x, dim_batch).astype(np.float32)
    val_cost, val_dA, val_y = f(val_x, val_A)
    if (t + 1) % 1000 == 0:
        print (t + 1), val_cost
    val_A -= eta * val_dA
    val_vA -= eta * val_dA


1000 1.99653935432
2000 1.01200270653
3000 0.974596500397
4000 1.0012421608
5000 1.09729588032
6000 1.07824110985
7000 0.96113473177
8000 0.999555230141
9000 1.17887532711
10000 0.950882732868

In [5]:
val_A_10 = np.linalg.matrix_power(val_A, n_steps)
print val_A_10
plt.imshow(val_A_10, interpolation='None')


[[ 0.95057124  0.0940406   0.11677547 -0.11102918  0.10848867]
 [ 0.09496891  0.81898624 -0.22467604  0.21356429 -0.20869146]
 [ 0.11795858 -0.2247728   0.7209534   0.26528206 -0.25920445]
 [-0.11179806  0.21296276  0.26444793  0.74857152  0.24562483]
 [ 0.10887601 -0.20742393 -0.25748801  0.24477866  0.76072252]]
Out[5]:
<matplotlib.image.AxesImage at 0x1111d8dd0>

In [5]: