In [25]:
import numpy as np
from scipy import stats
import scipy.io
from matplotlib import pyplot as plt
import time
import pstats
In [26]:
def Gibbs_jit(G, Ytrue, mask, T,
              prior={'ap': 2, 'bp': 5, 'aa': 2, 'ba': 5, 'ab': 2, 'bb': 5,
                     'ar': 2, 'br': 5, 'a1': 5, 'b1': 2, 'a0': 2, 'b0': 3}):
    # Gibbs sampling for GCHMMs with inference and parameter estimation
    # G: non-symmetric social networks (N x N x D)
    # Ytrue: evidence, i.e. observed data (N x S x D)
    # T: num. of iterations
    N, _, D = G.shape
    _, S, _ = Ytrue.shape
    ## Initialization
    X = np.zeros((N, D + 1), dtype='int32')
    R = np.zeros((N, D))
    # hyperparameters
    ap = prior['ap']; bp = prior['bp']
    aa = prior['aa']; ba = prior['ba']
    ab = prior['ab']; bb = prior['bb']
    ar = prior['ar']; br = prior['br']
    a1 = prior['a1']; b1 = prior['b1']
    a0 = prior['a0']; b0 = prior['b0']
    xi = stats.beta.rvs(ap, bp, size=1)
    alpha = stats.beta.rvs(aa, ba, size=1)
    beta = stats.beta.rvs(ab, bb, size=1)
    gamma = stats.beta.rvs(ar, br, size=1)
    theta1 = stats.beta.rvs(a1, b1, size=(1, S))
    theta0 = stats.beta.rvs(a0, b0, size=(1, S))
    ## Iterative sampling
    B = T // 2                          # discard the first half of the chain as burn-in
    Xbi = np.zeros(X.shape)             # accumulator for X over post-burn-in samples
    Ybi = np.zeros(Ytrue.shape)         # accumulator for Y over post-burn-in samples
    parabi = np.zeros((1, 2 * S + 4))   # accumulator for all parameters
    NPI = np.zeros((N, D))              # num. of previously infected neighbors
    for t in range(T):
        # sample missing Y
        Ym, Y = sampleMissingY(mask, Ytrue, X[:, 1:], theta1, theta0, N, D, S)
        # update hidden X at the initial time stamp
        NPI[:, 0] = NumPreInf(X[:, 0], G[:, :, 0])
        X[:, 0] = updateInitialX(X[:, 0], X[:, 1], NPI[:, 0], xi, gamma, alpha, beta, N)
        # update intermediate X
        for i in range(1, D):
            NPI[:, i - 1] = NumPreInf(X[:, i - 1], G[:, :, i - 1])
            NPI[:, i] = NumPreInf(X[:, i], G[:, :, i])
            X[:, i] = updateIntermediaX(Y[:, :, i - 1], X[:, i - 1], X[:, i + 1],
                                        NPI[:, i - 1], NPI[:, i],
                                        theta1, theta0, gamma, alpha, beta, N)
        # update hidden X at the last time stamp
        NPI[:, D - 1] = NumPreInf(X[:, D - 1], G[:, :, D - 1])  # neighbor counts at time D-1, matching X_prev in updateLastX
        X[:, D] = updateLastX(Y[:, :, D - 1], X[:, D - 1], NPI[:, D - 1],
                              theta1, theta0, gamma, alpha, beta, N)
        # update auxiliary variable R: prob p has various approximations
        R = updateAuxR(X, NPI, alpha, beta, N, D)
        # update parameters from their beta-posterior conditionals
        xi = stats.beta.rvs(ap + np.sum(X[:, 0]), bp + N - np.sum(X[:, 0]), size=1)
        gamma = stats.beta.rvs(ar + np.sum(X[:, 0:D] * (X[:, 1:] == 0)),
                               br + np.sum(X[:, 0:D] * X[:, 1:]))
        alpha = stats.beta.rvs(aa + np.sum(R == 1),
                               ba + np.sum((X[:, 0:D] == 0) * (X[:, 1:] == 0)) + np.sum(R == 2))
        beta = stats.beta.rvs(ab + np.sum(R > 1),
                              bb + np.sum(NPI * ((X[:, 0:D] == 0) ^ (R > 1))))
        temp = np.transpose(np.repeat(np.expand_dims(X[:, 1:], axis=2), S, axis=2), axes=[0, 2, 1])
        theta1 = stats.beta.rvs(a1 + np.sum(Y * temp, axis=(0, 2)),
                                b1 + np.sum((1 - Y) * temp, axis=(0, 2)), size=(1, S))
        theta0 = stats.beta.rvs(a0 + np.sum(Y * (temp == 0), axis=(0, 2)),
                                b0 + np.sum((1 - Y) * (temp == 0), axis=(0, 2)), size=(1, S))
        # print(theta1, theta0)
        # accumulate posterior samples after burn-in
        if t >= B:
            Xbi = Xbi + X
            Ybi = Ybi + Ym
            parabi = parabi + np.c_[xi, alpha, beta, gamma, theta1, theta0]
    # prediction: posterior means over the T - B retained samples
    Xpred = Xbi / (T - B)
    Ympred = Ybi / (T - B)
    parapred = parabi / (T - B)
    return [Xpred, Ympred, parapred]
def sampleMissingY(mask, Ytrue, X, theta1, theta0, N, D, S):
    # broadcast the emission probabilities to N x S x D
    th1 = np.repeat(np.repeat(theta1.reshape((1, S, 1)), N, axis=0), D, axis=2)
    th0 = np.repeat(np.repeat(theta0.reshape((1, S, 1)), N, axis=0), D, axis=2)
    # impute masked entries from the emission model of the current X
    Ym = mask * (stats.bernoulli.rvs(th1, size=Ytrue.shape) * (X == 1).reshape(N, 1, D) +
                 stats.bernoulli.rvs(th0, size=Ytrue.shape) * (X == 0).reshape(N, 1, D))
    return Ym, Ym + (1 - mask) * Ytrue

def updateInitialX(X0, X1, NPI0, xi, gamma, alpha, beta, N):
    # unnormalized probabilities of X0 = 1 and X0 = 0 given X1, the neighbor counts, and the prior xi
    p1 = xi * (gamma ** np.array(X1 == 0) * (1 - gamma) ** np.array(X1))
    p0 = (1 - xi) * (1 - (1 - alpha) * (1 - beta) ** NPI0) ** X1 * ((1 - alpha) * (1 - beta) ** NPI0) ** (X1 == 0)
    p = p1 / (p0 + p1)
    return 0 + (np.random.rand(N,) <= p)

def updateIntermediaX(Y_cur, X_prev, X_next, NPI_prev, NPI_cur, theta1, theta0, gamma, alpha, beta, N):
    # emission likelihoods of the observations under X = 1 and X = 0
    tmp1 = np.exp(Y_cur @ np.log(theta1.T)) * np.exp((1 - Y_cur) @ np.log(1 - theta1.T))
    p1 = gamma ** (X_next == 0) * (1 - gamma) ** (X_prev + X_next) * (1 - (1 - alpha) * (1 - beta) ** NPI_prev) ** (X_prev == 0) * tmp1.reshape((N,))
    tmp0 = np.exp(Y_cur @ np.log(theta0.T)) * np.exp((1 - Y_cur) @ np.log(1 - theta0.T))
    p0 = gamma ** X_prev * (1 - (1 - alpha) * (1 - beta) ** NPI_cur) ** X_next * (1 - alpha) ** ((X_prev == 0) + (X_next == 0)) * (1 - beta) ** (NPI_prev * (X_prev == 0) + NPI_cur * (X_next == 0)) * tmp0.reshape((N,))
    p = p1 / (p0 + p1)
    return 0 + (np.random.rand(N,) <= p)

def updateLastX(Y_cur, X_prev, NPI_prev, theta1, theta0, gamma, alpha, beta, N):
    tmp1 = np.exp(Y_cur @ np.log(theta1.T)) * np.exp((1 - Y_cur) @ np.log(1 - theta1.T))
    p1 = (1 - gamma) ** X_prev * (1 - (1 - alpha) * (1 - beta) ** NPI_prev) ** (X_prev == 0) * tmp1.reshape((N,))
    tmp0 = np.exp(Y_cur @ np.log(theta0.T)) * np.exp((1 - Y_cur) @ np.log(1 - theta0.T))
    p0 = gamma ** X_prev * ((1 - alpha) * (1 - beta) ** NPI_prev) ** (X_prev == 0) * tmp0.reshape((N,))
    p = p1 / (p0 + p1)
    return 0 + (np.random.rand(N,) <= p)

def updateAuxR(X, NPI, alpha, beta, N, D):
    # R in {0, 1, 2}: for each new infection, 1 attributes it to the base rate alpha, 2 to a neighbor (beta)
    p = alpha / (alpha + beta * NPI)
    tmp = 2 - (np.random.rand(N, D) <= p)
    return (X[:, 0:D] == 0) * X[:, 1:] * tmp
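As a quick sanity check, the conditional updates can be exercised on toy inputs before running the full sampler. A minimal sketch with made-up values (all numbers below are illustrative, not from the dataset); note that updateInitialX conditions only on X1 and the neighbor counts NPI0, not on X0 itself:

# Illustrative smoke test of updateInitialX on toy inputs.
X0 = np.array([0, 1, 0, 1])
X1 = np.array([1, 1, 0, 0])
NPI0 = np.array([2.0, 0.0, 1.0, 3.0])
print(updateInitialX(X0, X1, NPI0, xi=0.3, gamma=0.2, alpha=0.1, beta=0.05, N=4))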
In [27]:
import numba
from numba import jit

@jit(nopython=True)
def NumPreInf(Xt, Gt):
    # count each node's infected neighbors in the symmetrized contact graph
    m = Gt.shape[0]
    gt = np.zeros((m, m))
    res = np.zeros((m,))
    for i in range(m):
        for j in range(m):
            gt[i, j] = (Gt[i, j] + Gt[j, i]) > 0
    for i in range(m):
        res[i] = 0
        for k in range(m):
            res[i] += gt[i, k] * Xt[k]
    return res
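The jitted loop can be cross-checked against a vectorized NumPy equivalent: symmetrize the graph, then a single matrix-vector product gives each node's count of infected neighbors. A sketch (num_pre_inf_np is a name introduced here for illustration):

# Vectorized cross-check for NumPreInf.
def num_pre_inf_np(Xt, Gt):
    gt = ((Gt + Gt.T) > 0).astype(np.float64)  # symmetrized adjacency
    return gt @ Xt

Xt = np.random.randint(0, 2, size=10).astype(np.float64)
Gt = np.random.randint(0, 2, size=(10, 10)).astype(np.float64)
assert np.allclose(NumPreInf(Xt, Gt), num_pre_inf_np(Xt, Gt))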
In [28]:
# load the contact graph G, observations Y, and ground-truth states X from MATLAB files
Y = scipy.io.loadmat('Y.mat')['Y']
X = scipy.io.loadmat('X.mat')['X']
G = scipy.io.loadmat('G.mat')['G']
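A quick look at the loaded arrays; the layouts stated in the comment are assumptions implied by how Gibbs_jit unpacks its inputs:

# Expected layouts (implied by Gibbs_jit): G is N x N x D, Y is N x S x D, X is N x (D+1).
print('G:', G.shape, 'Y:', Y.shape, 'X:', X.shape)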
In [29]:
missing_rate = 0  # fraction of observations to hold out (mask == 1 marks a missing entry)
mask = stats.bernoulli.rvs(missing_rate, size=Y.shape)
Ymask = Y * mask        # the held-out entries
Ytrue = Y * (1 - mask)  # the entries the sampler observes
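With missing_rate = 0 the mask is all zeros, so nothing is imputed. To actually exercise the imputation path in sampleMissingY, a nonzero rate can be used; for example:

# Illustrative variant: hold out roughly 20% of the observations,
# which sampleMissingY then imputes on every sweep.
missing_rate = 0.2
mask = stats.bernoulli.rvs(missing_rate, size=Y.shape)
Ytrue = Y * (1 - mask)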
In [30]:
def work_jit(G, Ytrue, mask, T=500):
    Gibbs_jit(G, Ytrue, mask, T=T)  # forward T to the sampler
In [31]:
%prun -q -D work_jit.prof work_jit(G, Ytrue, mask, T = 500)
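Outside IPython, the same profile file can be produced with the standard-library cProfile module; a sketch:

# Equivalent profiling without the %prun magic.
import cProfile
cProfile.run('work_jit(G, Ytrue, mask, T=500)', 'work_jit.prof')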
In [32]:
p = pstats.Stats('work_jit.prof')
p.sort_stats('ncalls').print_stats(10)
pass  # suppress the notebook's echo of the Stats object returned by print_stats
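Sorting by call count highlights chatty helpers; sorting by cumulative time instead shows where the wall-clock time actually goes:

# Sort by cumulative time to see which top-level calls dominate the runtime.
p.sort_stats('cumulative').print_stats(10)
pass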