In [1]:
%matplotlib inline

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from OD_setup import *

In [6]:
import random

M = 2  #dim of signal
K = 2  #dim of dictionary 

assert M >= K


x_set = []
for k in range(K):
    x = np.zeros((M,1),dtype=float)
    x[0:k] = 1.0
    x_set.append(x)

temp = []
temp_F = []
temp_L = []

T = 10000

for lamda in [0.01*i for i in range(10)]:
    np.random.seed(10)
    temp_lamda = 0
    temp_lamda_F = 0
    temp_lamda_L = 0
    
    D = np.zeros((M,K), dtype=float)
    for k in range(K):
        D[k,k] = 1.0
    
    theta = 0.25*np.pi
    D[0:2,0] = np.array([np.cos(theta),  np.sin(theta)])
    D[0:2,1] = np.array([np.sin(theta),  -np.cos(theta)])
    
    D_0 = np.array(D)
    A = np.zeros((K,K))
    B = np.zeros((M,K))
    print(D)
    for i in range(1,T+1):
        x = random.choice(x_set)
        #a = np.random.normal(0.0001,(M,1))
        #x = x + a
        
        #alpha = alpha_ADMM(lamda,x,D,100)
        D,alpha,A,B = ODL(lamda,x,D,A,B)
        #print(i,alpha)
        
        temp_lamda_i,temp_lamda_F_i,temp_lamda_L_i = cost(lamda,x,D,alpha)
        temp_lamda += temp_lamda_i
        temp_lamda_F += temp_lamda_F_i
        temp_lamda_L += temp_lamda_L_i
    temp.append(temp_lamda/T)
    temp_F.append(temp_lamda_F/T)
    temp_L.append(temp_lamda_L/T)
    print(D)

sys.exit(0)

plt.subplot(3,1,1)
plt.plot(lamda_D,temp)
plt.xlabel("lamda")
plt.ylabel("Total Cost")
plt.subplot(3,1,2)
plt.plot(lamda_D,temp_F)
plt.xlabel("lamda")
plt.ylabel("Fitting Cost")
plt.subplot(3,1,3)
plt.plot(lamda_D,temp_L)
plt.xlabel("lamda")
plt.ylabel("L1 Cost")


[[ 0.70710678  0.70710678]
 [ 0.70710678 -0.70710678]]
[[ 0.70710678  0.70710678]
 [ 0.70710678 -0.70710678]]
[[ 0.70710678  0.70710678]
 [ 0.70710678 -0.70710678]]
[[ 0.79706921  0.79729093]
 [ 0.60388766 -0.6035952 ]]
[[ 0.70710678  0.70710678]
 [ 0.70710678 -0.70710678]]
[[ 0.88554785  0.88592133]
 [ 0.46454817 -0.46383552]]
[[ 0.70710678  0.70710678]
 [ 0.70710678 -0.70710678]]
[[ 0.97358848  0.9737026 ]
 [ 0.22831005 -0.22782284]]
[[ 0.70710678  0.70710678]
 [ 0.70710678 -0.70710678]]
[[  1.00000000e+000   1.00000000e+000]
 [  1.34138823e-320  -1.25838520e-320]]
[[ 0.70710678  0.70710678]
 [ 0.70710678 -0.70710678]]
[[  1.00000000e+000   1.00000000e+000]
 [  2.47032823e-323  -2.47032823e-323]]
[[ 0.70710678  0.70710678]
 [ 0.70710678 -0.70710678]]
[[  1.00000000e+000   1.00000000e+000]
 [  1.97626258e-323  -1.97626258e-323]]
[[ 0.70710678  0.70710678]
 [ 0.70710678 -0.70710678]]
[[  1.00000000e+000   1.00000000e+000]
 [  1.48219694e-323  -1.48219694e-323]]
[[ 0.70710678  0.70710678]
 [ 0.70710678 -0.70710678]]
[[  1.00000000e+000   1.00000000e+000]
 [  1.48219694e-323  -1.48219694e-323]]
[[ 0.70710678  0.70710678]
 [ 0.70710678 -0.70710678]]
[[  1.00000000e+000   1.00000000e+000]
 [  1.48219694e-323  -1.48219694e-323]]
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-6-677bf5bfb0ed> in <module>()
     55     print(D)
     56 
---> 57 sys.exit(0)
     58 
     59 plt.subplot(3,1,1)

NameError: name 'sys' is not defined

In [ ]: