In [1]:
import os
import pstats
import sys

import numpy as np
import scipy.io
from matplotlib import pyplot as plt
from scipy import stats

# Project-local code; the path is relative to the notebook's working directory.
sys.path.insert(0, '../Source Code')
# NOTE(review): star import hides which names come from the project module
# (this notebook uses AEMBP_GCHMM); prefer an explicit import once confirmed.
from AEMBP_GCHMM import *

In [2]:
## Load synthesized data
# The data directory is resolved relative to the PARENT of the current working
# directory, so the notebook must be started from a folder that sits next to `data/`.
parpath = os.path.dirname(os.getcwd())


def _load_mat(relpath, key):
    """Return variable `key` from the MAT-file at `parpath`/`relpath`."""
    return scipy.io.loadmat(os.path.join(parpath, relpath))[key]


Y = _load_mat('data/Y.mat', 'Y')  # Graph information of contacts each timestamp, dim=(84,84,107)
X = _load_mat('data/X.mat', 'X')  # Hidden(infection) states, dim=(84,108), including X(0)
G = _load_mat('data/G.mat', 'G')  # Observed(symptom) states, dim=(84,6,107)
C = _load_mat('data/C.mat', 'C')  # NOTE(review): contents of C are not documented here — confirm

# Sanity-check the shape relationships stated in the comments above: all arrays
# cover the same individuals and timestamps, and X has one extra column for X(0).
n_people, n_steps = Y.shape[0], Y.shape[2]
assert Y.shape[:2] == (n_people, n_people), Y.shape
assert X.shape == (n_people, n_steps + 1), X.shape
assert G.shape[0] == n_people and G.shape[2] == n_steps, G.shape

# For each row of C, sum its first entry along axis 1.
# NOTE(review): assumes C loads as a MATLAB cell array (object ndarray) — confirm.
sC = [np.sum(c[0], 1) for c in C]

In [3]:
# Values passed to AEMBP_GCHMM as `prior`.
# NOTE(review): the meaning of each key (e.g. whether the a*/b* pairs are
# Beta hyper-parameters) is defined inside AEMBP_GCHMM — confirm there.
PRIOR = {
    'xi': 0.25, 'alpha': 0.1, 'beta': 0.1, 'gamma': 0.5,
    'theta_1': 0.75, 'theta_0': 0.25,
    'ax': 2, 'bx': 5,
    'aa': 2, 'ba': 5,
    'ab': 2, 'bb': 5,
    'ar': 2, 'br': 5,
    'al': 2, 'bl': 2,
    'a0': 2, 'b0': 5,
}
INITIAL = 'f'   # `initial` argument; what 'f' selects is defined in AEMBP_GCHMM
MAX_ITER = 10   # `MaxIter` argument
TOL = 0.0001    # `tol` argument

# Positional order after sC is (initial, MaxIter, tol, X), the same order the
# `work` wrapper forwards its arguments in.
XBELEM, paraAEM = AEMBP_GCHMM(G, Y, C, sC, INITIAL, MAX_ITER, TOL, X, prior=PRIOR)

In [4]:
# Estimated states from the AEM-BP run.
# NOTE(review): assumes XBELEM is laid out as (individual, state, time) and that
# index 1 on axis 1 is the "infected" belief in [0, 1] — confirm in AEMBP_GCHMM.
# vmin/vmax are pinned to [0, 1] so a colour means the same value in every
# figure, instead of auto-scaling to this array's own min/max.
fig, ax = plt.subplots()
im = ax.imshow(XBELEM[:, 1, :], cmap='seismic', vmin=0, vmax=1)
ax.set_title('Predicted infection states (AEM-BP)')
ax.set_xlabel('time step')
ax.set_ylabel('individual')
fig.colorbar(im, ax=ax)
fig.savefig("predicted_AEM.png", dpi=300, bbox_inches="tight")



In [30]:
# True hidden states loaded from data/X.mat (presumably rows = individuals,
# columns = time steps including X(0)).
# NOTE(review): assumes X holds 0/1 states — confirm.
# vmin/vmax are pinned to [0, 1] so a colour means the same value in every
# figure, instead of auto-scaling to this array's own min/max.
fig, ax = plt.subplots()
im = ax.imshow(X, cmap='seismic', vmin=0, vmax=1)
ax.set_title('True infection states')
ax.set_xlabel('time step')
ax.set_ylabel('individual')
fig.colorbar(im, ax=ax)
fig.savefig("true.png", dpi=300, bbox_inches="tight")



In [10]:
def work(G, Y, C, sC, initial, MaxIter, tol, X, **kwargs):
    """Run AEMBP_GCHMM with the given arguments and return its result.

    Thin wrapper, e.g. for profiling. The positional arguments are forwarded
    unchanged and in the same order; any extra keyword arguments (such as
    ``prior=...``) are forwarded as well.

    Returns:
        Whatever AEMBP_GCHMM returns (elsewhere in this notebook that result
        is unpacked as ``XBELEM, paraAEM``).
    """
    # The original wrapper dropped this value and always returned None.
    return AEMBP_GCHMM(G, Y, C, sC, initial, MaxIter, tol, X, **kwargs)

In [11]:
# Profile a shorter run than the main fit (5 iterations instead of 10, and no
# `prior` argument, so whatever default AEMBP_GCHMM uses for it applies).
# -q suppresses the pager printout; -D dumps the raw stats to AEMBP_GCHMM.prof
# so they can be loaded with pstats.
%prun -q -D AEMBP_GCHMM.prof AEMBP_GCHMM(G, Y, C, sC,'f', 5, 0.0001, X)


 
*** Profile stats marshalled to file 'AEMBP_GCHMM.prof'. 

In [13]:
# Load the raw profile dumped by `%prun -D`, order entries by internal time
# (cumulative time breaks ties) and print the 15 most expensive functions.
p = pstats.Stats('AEMBP_GCHMM.prof')
p.sort_stats('time', 'cumulative')
p.print_stats(15);  # trailing ';' keeps IPython from echoing the Stats object


Mon May  1 13:39:50 2017    AEMBP_GCHMM.prof

         24109445 function calls in 52.915 seconds

   Ordered by: internal time, cumulative time
   List reduced from 66 to 15 due to restriction <15>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   576640   14.303    0.000   31.654    0.000 <ipython-input-6-347b45838ca2>:1(sp)
  2658791    7.131    0.000    7.131    0.000 {method 'reduce' of 'numpy.ufunc' objects}
        1    5.619    5.619   52.914   52.914 <ipython-input-2-d8846caa74c5>:7(AEMBP_GCHMM)
  1500600    4.340    0.000   11.701    0.000 /opt/conda/lib/python3.5/site-packages/numpy/matlib.py:310(repmat)
     2140    4.190    0.002   10.294    0.005 <ipython-input-5-49c0020f5080>:1(SumProd)
  3001205    3.610    0.000    3.610    0.000 {method 'repeat' of 'numpy.ndarray' objects}
  1906336    3.290    0.000    8.949    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/fromnumeric.py:1743(sum)
  4503420    2.051    0.000    2.051    0.000 {method 'reshape' of 'numpy.ndarray' objects}
   752440    1.585    0.000    5.001    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/fromnumeric.py:2433(prod)
   375150    1.470    0.000    1.470    0.000 <ipython-input-2-d8846caa74c5>:61(<lambda>)
  1510247    1.321    0.000    1.724    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/numeric.py:484(asanyarray)
  1906336    0.918    0.000    5.040    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/_methods.py:31(_sum)
   375150    0.754    0.000    0.754    0.000 <ipython-input-2-d8846caa74c5>:60(<lambda>)
  1908582    0.619    0.000    0.619    0.000 {built-in method builtins.isinstance}
   578780    0.606    0.000    0.606    0.000 {method 'copy' of 'numpy.ndarray' objects}