In [25]:
import numpy as np
from scipy import stats
import scipy.io
from matplotlib import pyplot as plt
import time
import pstats

In [26]:
def Gibbs_jit(G, Ytrue, mask, T, 
          prior={'ap':2, 'bp':5, 'aa':2, 'ba':5, 'ab':2, 'bb':5, 'ar':2, 'br':5, 'a1':5, 'b1':2, 'a0':2, 'b0':3}):
    # Gibbs sampling for GCHMMs with joint inference and parameter estimation
    # G: non-symmetric social networks, shape (N, N, D)
    # Ytrue: observed evidence (symptoms), shape (N, S, D); masked entries are zeroed
    # mask: 1 marks a missing entry of Ytrue, 0 an observed one
    # T: number of Gibbs iterations
    N, _, D = G.shape
    _, S, _ = Ytrue.shape
    
    ## Initialization
    X = np.zeros((N, D+1), dtype='int32')   # hidden infection states over D+1 time stamps
    R = np.zeros((N, D))                    # auxiliary infection-source indicators
    #hyperparameters
    ap = prior['ap']; bp = prior['bp']
    aa = prior['aa']; ba = prior['ba']
    ab = prior['ab']; bb = prior['bb']
    ar = prior['ar']; br = prior['br']
    a1 = prior['a1']; b1 = prior['b1']
    a0 = prior['a0']; b0 = prior['b0']

    xi = stats.beta.rvs(ap, bp, size=1)
    alpha = stats.beta.rvs(aa,ba,size=1)
    beta = stats.beta.rvs(ab,bb,size=1)
    gamma = stats.beta.rvs(ar,br,size=1)
    theta1 = stats.beta.rvs(a1,b1,size=(1,S)) 
    theta0 = stats.beta.rvs(a0,b0,size=(1,S)) 
    
    ## Iterative Sampling
    B = T // 2                      # discard the first half as burn-in
    Xbi = np.zeros((N, D+1))        # running sum of post-burn-in X (must not alias X)
    Ybi = np.zeros(Ytrue.shape)     # running sum of post-burn-in Y
    parabi = np.zeros((1, 2*S+4))   # running sum of post-burn-in parameters
    NPI = np.zeros((N, D))          # number of previously infected neighbors

    for t in range(T):
        # sample missing Y
        Ym, Y = sampleMissingY(mask, Ytrue, X[:,1:], theta1, theta0, N, D, S)
        
        # Update hidden X of initial time stamp
        NPI[:, 0] = NumPreInf(X[:, 0], G[:, :, 0])
        X[:,0] = updateInitialX(X[:, 0], X[:, 1], NPI[:, 0], xi, gamma, alpha, beta, N)
    
        # Update intermediate X
        for i in range(1,D):
            NPI[:,i-1] = NumPreInf(X[:,i-1],G[:,:,i-1])
            NPI[:,i] = NumPreInf(X[:,i],G[:,:,i])
            X[:,i] = updateIntermediaX(Y[:,:,i-1], X[:,i-1], X[:,i+1], NPI[:,i-1], NPI[:,i], theta1, theta0, gamma, alpha, beta, N)
        
        # Update hidden X of last time stamp
        NPI[:,D-1] = NumPreInf(X[:,D-1], G[:,:,D-1])
        X[:,D] = updateLastX(Y[:,:,D-1], X[:,D-1], NPI[:,D-1], theta1, theta0, gamma, alpha, beta, N)
        
        # Update auxiliary R (infection source; the probability p admits several approximations)
        R = updateAuxR(X, NPI, alpha, beta, N, D)
        
        # Update parameters from their conjugate Beta posteriors
        xi = stats.beta.rvs(ap + np.sum(X[:,0]), bp + N - np.sum(X[:,0]), size=1)
        gamma = stats.beta.rvs(ar + np.sum(X[:,0:D]*(X[:,1:]==0)), br + np.sum(X[:,0:D]*X[:,1:]))
        alpha = stats.beta.rvs(aa + np.sum(R == 1), ba + np.sum((X[:,0:D]==0)*(X[:,1:]==0)) + np.sum(R == 2))
        # R > 1 implies X[:,0:D] == 0, so the XOR picks susceptible entries not infected via a neighbor
        beta = stats.beta.rvs(ab + np.sum(R > 1), bb + np.sum(NPI*((X[:,0:D]==0) ^ (R > 1))))
    
        # Broadcast the hidden states X[:,1:] across the S symptoms: temp has shape (N, S, D)
        temp = np.transpose(np.repeat(np.expand_dims(X[:,1:], axis=2), S, axis=2), axes=[0, 2, 1])
        theta1 = stats.beta.rvs(a1 + np.sum(Y * temp, axis = (0,2)), b1 + np.sum((1-Y) * temp, axis = (0,2)), size = (1,S))
        theta0 = stats.beta.rvs(a0 + np.sum(Y * (temp==0), axis = (0,2)), b0 + np.sum((1-Y) * (temp == 0), axis = (0,2)), size = (1,S))
        
        # Accumulate samples after burn-in (iterations B .. T-1, i.e. T-B samples)
        if t >= B:
            Xbi = Xbi + X
            Ybi = Ybi + Ym
            parabi = parabi + np.c_[xi,alpha,beta,gamma,theta1,theta0]
    # Prediction: posterior means over the T-B retained samples
    Xpred = Xbi/(T-B)
    Ympred = Ybi/(T-B)
    parapred = parabi/(T-B)
    return [Xpred, Ympred, parapred]

def sampleMissingY(mask, Ytrue, X, theta1, theta0, N, D, S):
    # Tile the per-symptom emission probabilities to shape (N, S, D)
    th1 = np.repeat(np.repeat(theta1.reshape((1,S,1)), N, axis=0), D, axis=2)
    th0 = np.repeat(np.repeat(theta0.reshape((1,S,1)), N, axis=0), D, axis=2)
    # Impute the masked entries from the emission model, given the current X
    Ym = mask * (stats.bernoulli.rvs(th1, size=Ytrue.shape) * (X == 1).reshape(N, 1, D) +
                 stats.bernoulli.rvs(th0, size=Ytrue.shape) * (X == 0).reshape(N, 1, D))
    return Ym, Ym + (1 - mask) * Ytrue

def updateInitialX(X0, X1, NPI0, xi, gamma, alpha, beta, N):
    # P(X0 = 1, X1): initial infection xi times the transition into X1
    p1 = xi * gamma**(X1 == 0) * (1-gamma)**X1
    # P(X0 = 0, X1): infection or escape via (1-alpha)(1-beta)^NPI0
    p0 = (1-xi) * (1-(1-alpha)*(1-beta)**NPI0)**X1 * ((1-alpha)*(1-beta)**NPI0)**(X1 == 0)
    p = p1 / (p0 + p1)
    return 0 + (np.random.rand(N,) <= p)

def updateIntermediaX(Y_cur, X_prev, X_next, NPI_prev, NPI_cur, theta1, theta0, gamma, alpha, beta, N):
    # Emission likelihood P(Y_cur | X = 1) = exp(Y log(theta1) + (1-Y) log(1-theta1))
    tmp1 = np.exp(Y_cur @ np.log(theta1.T)) * np.exp((1-Y_cur) @ np.log(1-theta1.T))
    # Transition terms for X = 1: into X from X_prev, and from X into X_next
    p1 = gamma**(X_next==0) * (1-gamma)**(X_prev+X_next) * (1-(1-alpha)*(1-beta)**NPI_prev)**(X_prev==0) * tmp1.reshape((N,))
    # Emission likelihood P(Y_cur | X = 0), with the corresponding escape terms
    tmp0 = np.exp(Y_cur @ np.log(theta0.T)) * np.exp((1-Y_cur) @ np.log(1-theta0.T))
    p0 = gamma**X_prev * (1-(1-alpha)*(1-beta)**NPI_cur)**X_next * (1-alpha)**((X_prev==0)+(X_next==0)) * (1-beta)**(NPI_prev*(X_prev==0)+NPI_cur*(X_next==0)) * tmp0.reshape((N,))
    p = p1 / (p0 + p1)
    return 0 + (np.random.rand(N,) <= p)

def updateLastX(Y_cur, X_prev, NPI_prev, theta1, theta0, gamma, alpha, beta, N):
    # Emission likelihoods P(Y_cur | X = 1) and P(Y_cur | X = 0)
    tmp1 = np.exp(Y_cur @ np.log(theta1.T)) * np.exp((1-Y_cur) @ np.log(1-theta1.T))
    p1 = (1-gamma)**X_prev * (1-(1-alpha)*(1-beta)**NPI_prev)**(X_prev==0) * tmp1.reshape((N,))
    tmp0 = np.exp(Y_cur @ np.log(theta0.T)) * np.exp((1-Y_cur) @ np.log(1-theta0.T))
    p0 = gamma**X_prev * ((1-alpha)*(1-beta)**NPI_prev)**(X_prev==0) * tmp0.reshape((N,))
    p = p1 / (p0 + p1)
    return 0 + (np.random.rand(N,) <= p)

def updateAuxR(X, NPI, alpha, beta, N, D):
    # Probability that a new infection came from the background rate alpha
    # rather than from one of the NPI infected neighbors
    p = alpha / (alpha + beta * NPI)
    tmp = 2 - (np.random.rand(N, D) <= p)   # 1 = background source, 2 = neighbor
    return (X[:,0:D]==0) * X[:,1:] * tmp    # nonzero only for new infections
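
A minimal usage sketch (not from the original notebook; it assumes G, Ytrue and mask as loaded in the cells below, and a 0.5 cut-off for calling someone infected). The parameter unpacking follows the np.c_[xi, alpha, beta, gamma, theta1, theta0] layout above:

Xpred, Ympred, parapred = Gibbs_jit(G, Ytrue, mask, T=500)
Xhat = (Xpred >= 0.5).astype('int32')          # assumed threshold on the posterior mean
_, S, _ = Ytrue.shape
xi_, alpha_, beta_, gamma_ = parapred[0, :4]   # scalar parameters
theta1_, theta0_ = parapred[0, 4:4+S], parapred[0, 4+S:]   # per-symptom emission rates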

In [27]:
import numba
from numba import jit

@jit(nopython=True)
def NumPreInf(Xt, Gt):
    # Number of previously infected neighbors per individual:
    # symmetrize the directed contact graph, then count infected contacts.
    m = Gt.shape[0]
    gt = np.zeros((m, m))
    res = np.zeros((m,))
    for i in range(m):
        for j in range(m):
            gt[i, j] = (Gt[i, j] + Gt[j, i]) > 0

    for i in range(m):
        res[i] = 0.0
        for k in range(m):
            res[i] += gt[i, k] * Xt[k]
    return res
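
For reference, the double loop above is just a symmetrized matrix-vector product; a vectorized NumPy equivalent (a sanity check, not part of the original notebook) is:

def NumPreInf_np(Xt, Gt):
    # Symmetrize the directed contacts, then count infected neighbors
    gt = ((Gt + Gt.T) > 0).astype(np.float64)
    return gt @ Xt

np.allclose(NumPreInf(Xt, Gt), NumPreInf_np(Xt, Gt)) should then hold for any state vector Xt and contact slice Gt.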

In [28]:
Y = scipy.io.loadmat('Y.mat')['Y']
X = scipy.io.loadmat('X.mat')['X']
G = scipy.io.loadmat('G.mat')['G']
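
A quick shape check, making the layout assumed by Gibbs_jit explicit (G is (N, N, D), Y is (N, S, D)):

N, _, D = G.shape                     # individuals x individuals x time stamps
_, S, _ = Y.shape                     # individuals x symptoms x time stamps
assert G.shape[0] == G.shape[1] == N  # the contact graph is square at every time stamp
assert Y.shape == (N, S, D)           # evidence aligned with the graph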

In [29]:
missing_rate = 0  # fraction of Y entries hidden from the sampler (0 = fully observed)
mask = stats.bernoulli.rvs(missing_rate, size=Y.shape)  # 1 marks a missing entry
Ymask = Y * mask          # held-out entries (ground truth for evaluation)
Ytrue = Y * (1 - mask)    # observed entries fed to the sampler
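
With missing_rate = 0 the mask is all zeros, so nothing is held out and Ytrue equals Y. A sketch of how imputation could be evaluated at a nonzero rate (the 20% rate and the thresholded-accuracy metric are assumptions, not from the original notebook):

missing_rate = 0.2
mask = stats.bernoulli.rvs(missing_rate, size=Y.shape)
Xpred, Ympred, parapred = Gibbs_jit(G, Y * (1 - mask), mask, T=500)
acc = np.mean((Ympred[mask == 1] >= 0.5) == Y[mask == 1])  # imputed vs. held-out truth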

In [30]:
def work_jit(G, Ytrue, mask, T=500):
    Gibbs_jit(G, Ytrue, mask, T=T)  # forward T instead of hard-coding 500

In [31]:
%prun -q -D work_jit.prof work_jit(G, Ytrue, mask, T = 500)

*** Profile stats marshalled to file 'work_jit.prof'. 

In [32]:
p = pstats.Stats('work_jit.prof')
p.sort_stats('ncalls').print_stats(10)
pass  # keeps the cell from echoing the returned Stats object


Mon May  1 01:58:26 2017    work_jit.prof

         790390 function calls (768602 primitive calls) in 16.619 seconds

   Ordered by: call count
   List reduced from 1331 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   110004    0.111    0.000    0.111    0.000 {method 'reshape' of 'numpy.ndarray' objects}
    54500    0.302    0.000    0.302    0.000 {method 'rand' of 'mtrand.RandomState' objects}
    53000    6.175    0.000    6.518    0.000 <ipython-input-26-765b98198d20>:91(updateIntermediaX)
    47295    0.061    0.000    0.061    0.000 {built-in method numpy.core.multiarray.array}
    37758    0.012    0.000    0.015    0.000 {built-in method builtins.isinstance}
    23132    0.006    0.000    0.006    0.000 {built-in method builtins.len}
    20000    0.006    0.000    0.006    0.000 /opt/conda/lib/python3.5/site-packages/numpy/lib/stride_tricks.py:62(<genexpr>)
    18512    0.247    0.000    0.247    0.000 {method 'reduce' of 'numpy.ufunc' objects}
    17030    0.009    0.000    0.009    0.000 /opt/conda/lib/python3.5/site-packages/numpy/lib/stride_tricks.py:193(<genexpr>)
    15518    0.016    0.000    0.043    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/numeric.py:484(asanyarray)
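
The profile shows updateIntermediaX dominating the runtime (about 6.2 s of the 16.6 s total, across 53,000 calls). One candidate micro-optimization, sketched under the assumption that the emission term is the bottleneck: the two exp/matmul pairs per state fold into a single matrix product in log space, since exp(Y log t + (1-Y) log(1-t)) = exp(Y (log t - log(1-t)) + sum log(1-t)).

def emission_likelihood(Y_cur, theta):
    # One matmul instead of two: Y @ (log th - log(1-th)) + sum(log(1-th))
    th = theta.ravel()
    return np.exp(Y_cur @ (np.log(th) - np.log1p(-th)) + np.sum(np.log1p(-th)))

tmp1 and tmp0 in updateIntermediaX and updateLastX would then become emission_likelihood(Y_cur, theta1) and emission_likelihood(Y_cur, theta0), already shaped (N,) with no reshape needed.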