In [1]:
import numpy as np
from scipy import stats
import scipy.io
from matplotlib import pyplot as plt
import time
import pstats

import sys
sys.path.insert(0, '../Source Code')
from Gibbs_GCHMM import *

In [2]:
import os
# The .mat inputs live in a `data` folder next to the notebook's working
# directory, i.e. under the parent of the current directory.
parpath = os.path.dirname(os.getcwd())


def _load_mat(rel_path, key):
    """Read `rel_path` (relative to `parpath`) and return the array stored under `key`."""
    contents = scipy.io.loadmat(os.path.join(parpath, rel_path))
    return contents[key]


Y = _load_mat('data/Y.mat', 'Y')
X = _load_mat('data/X.mat', 'X')
G = _load_mat('data/G.mat', 'G')

In [36]:
# For a grid of missing rates, hide part of Y, run the sampler on the rest,
# and record the share of hidden entries with Y == 1 that the thresholded
# output ans[1] marks as 0 (a false-negative rate on the held-out data).
mis_err = []
for missing_rate in np.arange(0.1, 0.9, 0.05):
    # One Bernoulli draw per (axis-0, axis-2) index pair, copied along axis 1
    # so that a hidden pair is hidden for every axis-1 index of Y.
    flags = stats.bernoulli.rvs(missing_rate, size=(Y.shape[0], Y.shape[2]))
    mask = np.repeat(flags[:, np.newaxis, :], Y.shape[1], axis=1)

    # Ymask keeps only the hidden entries; Ytrue keeps only the visible ones
    # (hidden entries are zeroed) and is what the sampler receives.
    Ymask, Ytrue = Y * mask, Y * (1 - mask)

    ans = Gibbs_GCHMM(
        G, Ytrue, mask, T=500,
        prior={
            'ap': 2, 'bp': 5,
            'aa': 2, 'ba': 5,
            'ab': 2, 'bb': 5,
            'ar': 2, 'br': 5,
            'a1': 5, 'b1': 2,
            'a0': 2, 'b0': 5,
        },
    )
    Ypred = ans[1] > 0.5

    held_out_pos = Ymask == 1
    missed = np.logical_and(Ypred == 0, held_out_pos)
    mis_err.append(np.sum(missed) / np.sum(held_out_pos))
print(mis_err)


[0.94752402069475239, 0.94585529468136076, 0.94311045437753971, 0.948483956432146, 0.94500122219506233, 0.94434021263289558, 0.95583710407239819, 0.96078745198463511, 0.94673959982726352, 0.96470276008492573, 0.96450786681302603, 0.97010114482605314, 0.96832343440658541, 0.97490516486723078, 0.96754472714195805, 0.9926730672056594]

In [37]:
# Plot the held-out false-negative rate collected in `mis_err` against the
# missing rate and save the figure. The x grid repeats the expression used
# when `mis_err` was filled, so both sequences have the same length.
fig, ax = plt.subplots()
ax.plot(np.arange(0.1, 0.9, 0.05), mis_err)
ax.set_xlabel("Missing rate")
ax.set_ylabel("False negative rate on held-out entries")
ax.set_title("Held-out false negative rate vs. missing rate")
# Unused alternative metric left by the author (not plotted):
#np.sum((Ypred == 1) * (Ymask == 0))/np.sum(Ymask == 0)
fig.savefig("FNR.png", dpi = 300, bbox_inches ="tight")



In [3]:
# Hide each entry of Y independently with probability 0.2, run the sampler on
# the remaining entries, and save a heat map of the first returned array.
missing_rate = 0.2
mask = stats.bernoulli.rvs(missing_rate, size=Y.shape)
# Entries where mask == 1 are zeroed in Ytrue (the sampler input) and are the
# only ones kept in Ymask.
Ymask, Ytrue = Y * mask, Y * (1 - mask)
ans = Gibbs_GCHMM(G, Ytrue, mask, T=500)
plt.imshow(ans[0], cmap='seismic')
plt.savefig("missing02.png", dpi=300, bbox_inches="tight")



In [6]:
# Same experiment with an entry-wise missing probability of 0.4.
missing_rate = 0.4
mask = stats.bernoulli.rvs(missing_rate, size=Y.shape)
# mask == 1 flags an entry that is withheld from the sampler input Ytrue.
Ymask, Ytrue = Y * mask, Y * (1 - mask)
ans = Gibbs_GCHMM(G, Ytrue, mask, T=500)
plt.imshow(ans[0], cmap='seismic')
plt.savefig("missing04.png", dpi=300, bbox_inches="tight")



In [21]:
# Same experiment with an entry-wise missing probability of 0.6.
missing_rate = 0.6
mask = stats.bernoulli.rvs(missing_rate, size=Y.shape)
# mask == 1 flags an entry that is withheld from the sampler input Ytrue.
Ymask, Ytrue = Y * mask, Y * (1 - mask)
ans = Gibbs_GCHMM(G, Ytrue, mask, T=500)
plt.imshow(ans[0], cmap='seismic')
plt.savefig("missing06.png", dpi=300, bbox_inches="tight")



In [22]:
# Same experiment with an entry-wise missing probability of 0.8.
missing_rate = 0.8
mask = stats.bernoulli.rvs(missing_rate, size=Y.shape)
# mask == 1 flags an entry that is withheld from the sampler input Ytrue.
Ymask, Ytrue = Y * mask, Y * (1 - mask)
ans = Gibbs_GCHMM(G, Ytrue, mask, T=500)
plt.imshow(ans[0], cmap='seismic')
plt.savefig("missing08.png", dpi=300, bbox_inches="tight")



In [10]:
# Baseline with nothing hidden: a Bernoulli draw with success probability 0,
# built the same way as the masks used in the missing-data runs.
missing_rate = 0
mask = stats.bernoulli.rvs(missing_rate, size=Y.shape)
# Ymask holds the entries selected by the mask, Ytrue the remaining ones.
Ymask, Ytrue = Y * mask, Y * (1 - mask)

In [12]:
# Run the sampler on the current Ytrue / mask pair; T is passed as 500.
ans = Gibbs_GCHMM(
    G,
    Ytrue,
    mask,
    T=500,
)

In [13]:
# Heat map of the sampler's first returned array, saved next to the notebook.
# NOTE(review): ans[0] is presumably the inferred counterpart of X - confirm.
plt.imshow(ans[0], cmap='seismic')
plt.savefig("predicted.png", dpi=300, bbox_inches="tight")



In [14]:
# Heat map of the X array loaded from X.mat, drawn with the same colour map
# as the sampler output so the two saved images can be compared.
plt.imshow(X, cmap='seismic')
plt.savefig("true.png", dpi=300, bbox_inches="tight")



In [17]:
def work_whole(G, Y, T = 500):
    """Run Gibbs_GCHMM on the full array Y, without a missing-data mask.

    Parameters
    ----------
    G : forwarded unchanged as the first positional argument.
    Y : forwarded unchanged as the second positional argument.
    T : forwarded as the keyword argument ``T`` (default 500).
        NOTE(review): presumably the number of sampler iterations - confirm.

    Returns
    -------
    Whatever Gibbs_GCHMM returns for these arguments.
    """
    # Forward the caller's T instead of a hard-coded 500, and hand the result
    # back so the wrapper is usable outside of profiling.
    return Gibbs_GCHMM(G, Y, T = T)

In [18]:
# Profile one sampler run on the full Y (no mask) with T = 500. -q suppresses
# the pager report and -D writes the raw stats to 'Gibbs_GCHMM.prof', which
# can then be loaded with pstats.
%prun -q -D Gibbs_GCHMM.prof Gibbs_GCHMM(G, Y, T = 500)


 
*** Profile stats marshalled to file 'Gibbs_GCHMM.prof'. 

In [19]:
# Load the saved profile, order it by internal time and then cumulative time,
# and print the ten most expensive entries.
p = pstats.Stats('Gibbs_GCHMM.prof')
p.sort_stats('time', 'cumulative')
# The trailing semicolon keeps the notebook from echoing the returned object.
p.print_stats(10);


Mon May  1 13:31:29 2017    Gibbs_GCHMM.prof

         524722 function calls in 10.767 seconds

   Ordered by: internal time, cumulative time
   List reduced from 63 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    6.268    6.268   10.767   10.767 <ipython-input-14-54eb1aa86164>:1(Gibbs_GCHMM)
   107000    3.104    0.000    3.104    0.000 <ipython-input-14-54eb1aa86164>:83(NumPreInf)
    13512    0.515    0.000    0.515    0.000 {method 'reduce' of 'numpy.ufunc' objects}
    54500    0.259    0.000    0.259    0.000 {method 'rand' of 'mtrand.RandomState' objects}
      500    0.116    0.000    0.116    0.000 {method 'repeat' of 'numpy.ndarray' objects}
   110004    0.095    0.000    0.095    0.000 {method 'reshape' of 'numpy.ndarray' objects}
     3006    0.045    0.000    0.363    0.000 /opt/conda/lib/python3.5/site-packages/scipy/stats/_distn_infrastructure.py:909(rvs)
    35295    0.035    0.000    0.035    0.000 {built-in method numpy.core.multiarray.array}
     4000    0.033    0.000    0.044    0.000 /opt/conda/lib/python3.5/site-packages/numpy/lib/stride_tricks.py:57(_broadcast_to)
     3006    0.030    0.000    0.181    0.000 /opt/conda/lib/python3.5/site-packages/scipy/stats/_distn_infrastructure.py:789(_argcheck_rvs)