In [10]:
import numpy as np
from scipy import stats
import scipy.io
from matplotlib import pyplot as plt
import pstats
import time

In [11]:
def Gibbs_numpy(G, Ytrue, mask, T, 
          prior={'ap':2, 'bp':5, 'aa':2, 'ba':5, 'ab':2, 'bb':5, 'ar':2, 'br':5, 'a1':5, 'b1':2, 'a0':2, 'b0':3}):
    # Gibbs sampling for GCHMMs: joint hidden-state inference and parameter estimation
    # G: non-symmetric social networks, shape (N, N, D)
    # Ytrue: observed evidence (symptom reports), shape (N, S, D)
    # mask: indicator of missing entries in Ytrue (1 = missing)
    # T: num. of Gibbs iterations
    N, _, D = G.shape
    _, S, _ = Ytrue.shape
    
    ## Initialization
    X = np.zeros((N,D+1), dtype='int32')
    R = np.zeros((N,D))
    #hyperparameters
    ap = prior['ap']; bp = prior['bp']
    aa = prior['aa']; ba = prior['ba']
    ab = prior['ab']; bb = prior['bb']
    ar = prior['ar']; br = prior['br']
    a1 = prior['a1']; b1 = prior['b1']
    a0 = prior['a0']; b0 = prior['b0']

    xi = stats.beta.rvs(ap, bp, size=1)
    alpha = stats.beta.rvs(aa,ba,size=1)
    beta = stats.beta.rvs(ab,bb,size=1)
    gamma = stats.beta.rvs(ar,br,size=1)
    theta1 = stats.beta.rvs(a1,b1,size=(1,S)) 
    theta0 = stats.beta.rvs(a0,b0,size=(1,S)) 
    
    ## Iterative sampling
    B = T // 2 # discard the first B iterations as burn-in
    Xbi = np.zeros(X.shape) # post-burn-in accumulator for X
    Ybi = np.zeros(Ytrue.shape) # post-burn-in accumulator for Y
    parabi = np.zeros((1,2*S+4)) # post-burn-in accumulator for all parameters
    NPI = np.zeros((N,D)) # num. of infected neighbors, per individual and day

    for t in range(T):
        # sample missing Y
        Ym, Y = sampleMissingY(mask, Ytrue, X[:,1:], theta1, theta0, N, D, S)
        
        # Update hidden X of initial time stamp
        NPI[:, 0] = NumPreInf(X[:, 0], G[:, :, 0])
        X[:,0] = updateInitialX(X[:, 0], X[:, 1], NPI[:, 0], xi, gamma, alpha, beta, N)
    
        # Update intermediate X
        for i in range(1,D):
            NPI[:,i-1] = NumPreInf(X[:,i-1],G[:,:,i-1])
            NPI[:,i] = NumPreInf(X[:,i],G[:,:,i])
            X[:,i] = updateIntermediaX(Y[:,:,i-1], X[:,i-1], X[:,i+1], NPI[:,i-1], NPI[:,i], theta1, theta0, gamma, alpha, beta, N)
        
        # Update hidden X of last time stamp
        NPI[:,D-1] = NumPreInf(X[:,D-1],G[:,:,D-1])
        X[:,D] = updateLastX(Y[:,:,D-1], X[:,D-1], NPI[:,D-1], theta1, theta0, gamma, alpha, beta, N)
        
        # Update auxiliary variable R, attributing each new infection to an
        # outside source (R = 1) or to an infected neighbor (R = 2)
        R = updateAuxR(X, NPI, alpha, beta, N, D)
        
        # Update parameters from their Beta posteriors given the sampled states
        xi = stats.beta.rvs(ap + np.sum(X[:,0]), bp + N - np.sum(X[:,0]), size = 1)
        gamma = stats.beta.rvs(ar + np.sum(X[:,0:D]*(X[:,1:]==0)), br + np.sum(X[:,0:D] * X[:,1:]))
        alpha = stats.beta.rvs(aa + np.sum(R == 1), ba + np.sum((X[:,0:D] == 0) * (X[:,1:] == 0)) + np.sum(R == 2))
        beta = stats.beta.rvs(ab + np.sum(R > 1), bb + np.sum(NPI*((X[:,0:D] == 0) ^ (R > 1))))
    
        # Replicate hidden states across the S symptom channels: temp has shape (N, S, D)
        temp = np.transpose(np.repeat(np.expand_dims(X[:,1:], axis=2), S, axis = 2), axes = [0, 2, 1])
        theta1 = stats.beta.rvs(a1 + np.sum(Y * temp, axis = (0,2)), b1 + np.sum((1-Y) * temp, axis = (0,2)), size = (1,S))
        theta0 = stats.beta.rvs(a0 + np.sum(Y * (temp==0), axis = (0,2)), b0 + np.sum((1-Y) * (temp == 0), axis = (0,2)), size = (1,S))
        #print(theta1, theta0)
        
        # Accumulate samples after burn-in
        if t >= B:
            Xbi = Xbi + X
            Ybi = Ybi + Ym
            parabi = parabi + np.c_[xi,alpha,beta,gamma,theta1,theta0]
    # Posterior means over the T - B retained samples
    Xpred = Xbi/(T-B)
    Ympred = Ybi/(T-B)
    parapred = parabi/(T-B)
    return [Xpred, Ympred, parapred]

def sampleMissingY(mask, Ytrue, X, theta1, theta0, N, D, S):
    # Broadcast the per-symptom emission probabilities to shape (N, S, D)
    th1 = np.repeat(np.repeat(theta1.reshape((1,S,1)), N, axis=0), D, axis=2)
    th0 = np.repeat(np.repeat(theta0.reshape((1,S,1)), N, axis=0), D, axis=2)
    # Impute the masked entries from the emission model of the current hidden states
    Ym = mask * (stats.bernoulli.rvs(th1, size=Ytrue.shape) * (X == 1).reshape(N, 1 ,D) + 
                stats.bernoulli.rvs(th0, size=Ytrue.shape) * (X == 0).reshape(N, 1, D))
    return Ym, Ym + (1 - mask) * Ytrue

def updateInitialX(X0, X1, NPI0, xi, gamma, alpha, beta, N):
    # Posterior odds of X = 1 vs. X = 0 at t = 0, given the next state X1
    p1 = xi * (gamma**np.array(X1==0) * (1-gamma)**np.array(X1))
    p0 = (1-xi)*(1-(1-alpha)*(1-beta)**NPI0)**X1 * ((1-alpha)*(1-beta)**NPI0)**(X1==0)
    p = p1 / (p0+p1)
    return 0+(np.random.rand(N,)<=p)

def updateIntermediaX(Y_cur, X_prev, X_next, NPI_prev, NPI_cur, theta1, theta0, gamma, alpha, beta, N):
    # Emission likelihood of the day's observations under X = 1
    tmp1 = np.exp(Y_cur @ np.log(theta1.T))*np.exp((1-Y_cur) @ np.log(1-theta1.T))
    p1 = gamma**(X_next==0)*(1-gamma)**(X_prev+X_next)*(1-(1-alpha)*(1-beta)**NPI_prev)**(X_prev==0) * tmp1.reshape((N,))
    tmp0 = np.exp(Y_cur @ np.log(theta0.T))*np.exp((1-Y_cur) @ np.log(1-theta0.T))
    p0 = gamma**X_prev*(1-(1-alpha)*(1-beta)**NPI_cur)**X_next*(1-alpha)**((X_prev==0)+(X_next==0))*(1-beta)**(NPI_prev*(X_prev==0)+NPI_cur*(X_next==0))*tmp0.reshape((N,))
    p = p1 / (p0 + p1)
    return 0 + (np.random.rand(N,)<=p)

def updateLastX(Y_cur, X_prev, NPI_prev, theta1, theta0, gamma, alpha, beta, N):
    # Same as the intermediate update, but with no successor state to condition on
    tmp1 = np.exp(Y_cur @ np.log(theta1.T))* np.exp((1-Y_cur) @ np.log(1-theta1.T))
    p1 = (1-gamma) ** X_prev * (1-(1-alpha) * (1-beta) ** NPI_prev)**(X_prev==0)*tmp1.reshape((N,))
    tmp0 = np.exp(Y_cur @ np.log(theta0.T))*np.exp((1-Y_cur) @ np.log(1-theta0.T))
    p0 = gamma ** X_prev*((1-alpha)*(1-beta)**NPI_prev)**(X_prev==0)*tmp0.reshape((N,))
    p = p1 / (p0 + p1)
    return 0 + (np.random.rand(N,) <= p)

def updateAuxR(X, NPI, alpha, beta, N, D):
    # Where a new infection occurred (X_prev == 0, X_next == 1), draw its source:
    # the outside world (R = 1) with prob. alpha/(alpha + beta*NPI), else a neighbor (R = 2)
    p = alpha / (alpha + beta * NPI)
    tmp = 2 - (np.random.rand(N, D) <= p)
    return (X[:,0:D]==0)*X[:,1:]*tmp
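
The transition model these conditionals encode can be read off the `p0`/`p1` expressions above: a susceptible individual with $n_t$ infected neighbors (`NPI`) escapes infection with probability $(1-\alpha)(1-\beta)^{n_t}$, and an infected individual recovers with probability $\gamma$:

$$P(X_{t+1}=1 \mid X_t=0) = 1-(1-\alpha)(1-\beta)^{n_t}, \qquad P(X_{t+1}=0 \mid X_t=1) = \gamma.$$

The auxiliary variable `R` then attributes each new infection either to the outside world (`R == 1`, probability proportional to $\alpha$) or to a contagious neighbor (`R == 2`, probability proportional to $\beta n_t$), which is the split used in `updateAuxR` and in the Beta posterior updates for $\alpha$ and $\beta$.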

In [12]:
def NumPreInf(Xt, Gt):
    # Count each individual's infected neighbors in the symmetrized contact network
    return ((Gt + Gt.T) > 0) @ Xt

In [13]:
Y = scipy.io.loadmat('Y.mat')['Y']
X = scipy.io.loadmat('X.mat')['X']
G = scipy.io.loadmat('G.mat')['G']
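
A quick sanity check of the array layouts `Gibbs_numpy` expects (`G`: N x N x D daily contact networks, `Y`: N x S x D symptom reports; `X` is presumably the ground-truth hidden states, kept here for evaluation):

print(G.shape, Y.shape, X.shape)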

In [14]:
missing_rate = 0
mask = stats.bernoulli.rvs(missing_rate, size = Y.shape) # 1 marks a held-out entry
Ymask = Y * mask # the held-out observations
Ytrue = Y * (1 - mask) # the observed entries
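
With `missing_rate = 0` the mask is all zeros, so no observations are held out and `Ytrue` equals `Y`. To actually exercise the imputation path in `sampleMissingY`, one would raise the rate; a minimal sketch (not run here):

missing_rate = 0.1 # hide 10% of the observations at random
mask = stats.bernoulli.rvs(missing_rate, size=Y.shape)
Xpred, Ympred, parapred = Gibbs_numpy(G, Y * (1 - mask), mask, T=500)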

In [15]:
def work_numpy(G, Ytrue, mask, T = 500):
    # Wrapper so %prun has a single top-level call to profile
    Gibbs_numpy(G, Ytrue, mask, T = T)

In [16]:
%prun -q -D work_numpy.prof work_numpy(G, Ytrue, mask, T = 500)


 
*** Profile stats marshalled to file 'work_numpy.prof'. 

In [17]:
p = pstats.Stats('work_numpy.prof')
p.sort_stats('ncalls').print_stats(10)
pass


Mon May  1 01:57:58 2017    work_numpy.prof

         678718 function calls in 15.517 seconds

   Ordered by: call count
   List reduced from 74 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   110004    0.134    0.000    0.134    0.000 {method 'reshape' of 'numpy.ndarray' objects}
   107000    3.369    0.000    3.369    0.000 <ipython-input-12-228c30ad9db4>:1(NumPreInf)
    54500    0.302    0.000    0.302    0.000 {method 'rand' of 'mtrand.RandomState' objects}
    53000    6.512    0.000    6.875    0.000 <ipython-input-11-f495c0f15b83>:91(updateIntermediaX)
    47295    0.064    0.000    0.064    0.000 {built-in method numpy.core.multiarray.array}
    20743    0.005    0.000    0.005    0.000 {built-in method builtins.len}
    20000    0.005    0.000    0.005    0.000 /opt/conda/lib/python3.5/site-packages/numpy/lib/stride_tricks.py:62(<genexpr>)
    18512    0.253    0.000    0.253    0.000 {method 'reduce' of 'numpy.ufunc' objects}
    17030    0.009    0.000    0.009    0.000 /opt/conda/lib/python3.5/site-packages/numpy/lib/stride_tricks.py:193(<genexpr>)
    15518    0.017    0.000    0.046    0.000 /opt/conda/lib/python3.5/site-packages/numpy/core/numeric.py:484(asanyarray)
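
By call count, `NumPreInf` (107,000 calls, 3.4 s) and `updateIntermediaX` (53,000 calls, 6.9 s cumulative) dominate the 15.5 s run. `NumPreInf` re-symmetrizes the day's network on every call even though `G` never changes across Gibbs iterations, so an obvious candidate saving is to hoist that work out of the sampler. A minimal sketch (untested here; `Gsym` and `NumPreInf_pre` are not part of the code above):

# Symmetrize every daily network once, before the T iterations
Gsym = (G + G.transpose(1, 0, 2)) > 0 # shape (N, N, D), boolean adjacency

def NumPreInf_pre(Xt, Gsym_t):
    # Gsym_t is the precomputed slice Gsym[:, :, t]; the matmul counts
    # infected neighbors exactly as before
    return Gsym_t @ Xt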