In [11]:
%matplotlib inline

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np



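# Synthetic alternative (disabled): build X as a product of random M x K and
# K x N factors plus small Gaussian noise. M, N and K must be defined before
# these lines are re-enabled.
# W_true = np.random.randn(M,K)
# H_true = np.random.randn(K,N)

# X = W_true.dot(H_true)
# X = X+0.01*np.random.randn(M,N)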
# Data matrix to be factorised (5 rows x 12 columns, float entries).
X = np.array([
    [5., 5., 0., 4., 1., 0., 0., 0., 0., 1., 5., 0.],
    [4., 0., 4., 5., 0., 2., 0., 1., 1., 1., 4., 5.],
    [0., 0., 0., 2., 0., 1., 0., 0., 0., 0., 0., 0.],
    [0., 1., 0., 0., 0., 4., 5., 0., 4., 5., 0., 1.],
    [1., 0., 1., 0., 5., 0., 4., 4., 5., 0., 1., 1.],
])

# Problem dimensions: X is M x N; K is the inner dimension of the factors.
M, N = X.shape
K = 3

In [17]:


In [21]:
# Observation mask with the same shape and dtype as X: 1.0 where X is non-zero
# and 0.0 where X is zero. It is multiplied element-wise with the residual so
# that only the non-zero entries of X contribute to the fit.
#
# A random mask, (np.random.rand(M, N) < 0.4), used to be assigned first and
# was overwritten immediately; that dead assignment has been removed.
Mask = (X != 0).astype(X.dtype)

# W = np.random.randn(M,K)
# H = np.random.randn(K,N)
# Fixed starting point for the left factor W (5 x 3). Hard-coded rather than
# drawn at random so that every run starts from the same values.
W = np.array([
    [0.18011451, 0.60299404, 0.3177374],
    [0.65611746, 0.2221562, 0.11995832],
    [0.80116047, 0.82407141, 0.79099989],
    [0.66004068, 0.55214108, 0.82858835],
    [0.05662817, 0.99133845, 0.28844242],
])

# Fixed starting point for the right factor H (3 x 12), one row per line.
H = np.array([
    [0.12592786, 0.75672181, 0.30747259, 0.56041327, 0.54473678, 0.89090891,
     0.12236257, 0.3108339, 0.46744222, 0.89925064, 0.09402283, 0.9283946],
    [0.21140259, 0.67445654, 0.14928717, 0.84047072, 0.16396425, 0.03306015,
     0.92158905, 0.18943411, 0.40587314, 0.50836581, 0.78039105, 0.54416406],
    [0.44876138, 0.67787687, 0.18120697, 0.23119679, 0.30702529, 0.52461556,
     0.62692945, 0.51883236, 0.63077754, 0.93171459, 0.72034789, 0.63587984],
])

# Gradient-descent settings: number of passes over the data and the step size.
EPOCH = 1000
eta = 0.05

# Alternating gradient descent on the masked squared error
#     0.5 * sum((Mask * (X - W.dot(H)))**2)
# Entries where Mask is 0 are multiplied out of the residual and therefore do
# not influence either update.
for i in range(EPOCH):
    # Step on W with H held at its current value.
    residual = Mask * (X - np.dot(W, H))
    dW = -np.dot(residual, H.T)
    W = W - eta * dW

    # Step on H; the residual is recomputed so it reflects the updated W.
    residual = Mask * (X - np.dot(W, H))
    dH = -np.dot(W.T, residual)
    H = H - eta * dH

    # Report the masked loss on every 100th pass, starting with pass 0.
    if i % 100 == 0:
        residual = Mask * (X - np.dot(W, H))
        print(0.5 * np.sum(residual ** 2))


# Copy of X in which every entry with Mask == 0 is replaced by NaN, so those
# cells are drawn as missing in the 'Observed Data' panel.
MX = np.where(Mask == 0, np.nan, X)

# One figure per panel, shown in the order listed. Only the mask uses an
# explicit colormap; cmap=None leaves matplotlib's default in effect.
for panel_image, panel_title, panel_cmap in (
    (Mask, 'Mask', plt.cm.gray_r),
    (MX, 'Observed Data', None),
    (W.dot(H), 'Approximation', None),
    (X, 'True', None),
):
    plt.imshow(panel_image, interpolation='nearest', cmap=panel_cmap)
    plt.title(panel_title)
    plt.show()


66.6787797434
0.000376521655613
7.43500743418e-05
1.61868332186e-05
3.53655769926e-06
7.73277340903e-07
1.69107631575e-07
3.69834292453e-08
8.08825172207e-09
1.76889847362e-09

In [ ]: