# In [141]:
# %load gnmf_solvebynewton.py
from __future__ import division
import numpy as np
import scipy as sp
from scipy import special
import numpy.matlib as M
def gnmf_solvebynewton(c, a0=None):
    """Solve ``log(a) - psi(a) + 1 - c = 0`` for ``a`` with Newton's method.

    Ten Newton iterations are run element-wise.  Whenever a Newton step would
    make an element negative, that element is halved instead of stepped.

    Parameters
    ----------
    c : scalar or array_like
        Right-hand side of the equation.  Its shape relative to ``a0`` selects
        how the solution is tied:

        * same shape as ``a0``: every element is solved independently;
        * ``(1, N)``: one solution per column, shared by all rows;
        * ``(M, 1)``: one solution per row, shared by all columns;
        * scalar / single element: one solution shared by every element.

        A 1-D ``c`` is read as a row ``(1, N)`` when its length equals the
        number of columns of ``a0``, otherwise as a column ``(M, 1)`` when its
        length equals the number of rows.
    a0 : array_like, optional
        Starting point (at most 2-D).  Defaults to ``0.1`` everywhere with the
        shape of ``c``.

    Returns
    -------
    numpy.ndarray
        Array with the (at least 2-D) shape of ``a0`` holding the solution,
        replicated along the tied dimensions.  ``a0`` itself is not modified.

    Raises
    ------
    ValueError
        If ``a0`` has more than two dimensions or the shape of ``c`` is not
        compatible with ``a0``.
    """
    c = np.asarray(c, dtype=float)
    if a0 is None:
        a0 = 0.1 * np.ones(np.shape(c))
    a0 = np.atleast_2d(np.asarray(a0, dtype=float))
    if a0.ndim != 2:
        raise ValueError("a0 must be at most 2-dimensional")
    # Deliberately not called M/N: M is the module alias of numpy.matlib.
    n_rows, n_cols = a0.shape

    # Bring c to 2-D so the tying rule can be read off its shape.
    if c.size == 1:
        c = c.reshape(1, 1)
    elif c.ndim == 1:
        if c.size == n_cols:
            c = c.reshape(1, n_cols)
        elif c.size == n_rows:
            c = c.reshape(n_rows, 1)
        else:
            raise ValueError(
                "1-D c of length %d does not match a0 of shape %s"
                % (c.size, (n_rows, n_cols)))
    elif c.ndim != 2:
        raise ValueError("c must be at most 2-dimensional")
    c_rows, c_cols = c.shape

    # Pick the representative slice of a0 that is actually iterated on.  The
    # slices stay 2-D so they broadcast against c and back to a0's shape.
    if (c_rows, c_cols) == (n_rows, n_cols):
        a = a0
    elif c_rows == 1 and c_cols == n_cols:
        a = a0[:1, :]
    elif c_cols == 1 and c_rows == n_rows:
        a = a0[:, :1]
    elif c_rows == 1 and c_cols == 1:
        a = a0[:1, :1]
    else:
        raise ValueError(
            "c of shape %s is not compatible with a0 of shape %s"
            % ((c_rows, c_cols), (n_rows, n_cols)))

    for _ in range(10):
        residual = np.log(a) - special.polygamma(0, a) + 1 - c
        slope = 1 / a - special.polygamma(1, a)
        candidate = a - residual / slope
        # A negative iterate is outside the domain of log/psi: halve instead.
        a = np.where(candidate < 0, a / 2, candidate)

    # Replicate the tied solution over the full shape of a0.
    return a * np.ones((n_rows, n_cols))
# In [168]:
# %load gnmf_vb_poisson_mult_fast.py
from __future__ import division
import numpy as np
import scipy as sp
import math
from scipy import special
import numpy.matlib as M
def _gnmf_update_shape(a, e_mean, l_mean, b, tie):
    """Re-estimate a Gamma shape hyperparameter under the tying mode ``tie``."""
    if tie == 'clamp':
        return a
    z = (e_mean / b) - (np.log(l_mean) - np.log(b))
    if tie == 'free':
        target = z
    elif tie == 'rows':
        # Average over rows; keepdims yields the (1, n_cols) row the solver
        # recognises as "one value per column".
        target = np.sum(z, 0, keepdims=True) / z.shape[0]
    elif tie == 'cols':
        target = np.sum(z, 1, keepdims=True) / z.shape[1]
    elif tie == 'tie_all':
        target = np.sum(z) / z.size
    else:
        # Unrecognised mode: leave the hyperparameter untouched.
        return a
    return gnmf_solvebynewton(target, a0=a)


def _gnmf_update_scale(a, e_mean, b, tie):
    """Re-estimate a Gamma scale hyperparameter under the tying mode ``tie``."""
    if tie == 'free':
        return e_mean
    if tie == 'rows':
        tied = np.sum(a * e_mean, 0, keepdims=True) / np.sum(a, 0, keepdims=True)
    elif tie == 'cols':
        tied = np.sum(a * e_mean, 1, keepdims=True) / np.sum(a, 1, keepdims=True)
    elif tie == 'tie_all':
        tied = np.sum(a * e_mean) / np.sum(a)
    else:
        # 'clamp' or an unrecognised mode: keep the current value.
        return b
    # Broadcast the tied value(s) back to the full hyperparameter shape.
    return tied * np.ones(e_mean.shape)


def gnmf_vb_poisson_mult_fast(x,
                              a_tm,
                              b_tm,
                              a_ve,
                              b_ve,
                              EPOCH=1000,
                              Method='vb',
                              Update=np.inf,
                              tie_a_ve='clamp',
                              tie_b_ve='clamp',
                              tie_a_tm='clamp',
                              tie_b_tm='clamp',
                              print_period=500
                              ):
    """Variational Bayes for Poisson NMF ``x ~ Poisson(T V)`` with Gamma priors.

    The factors are initialised by sampling ``np.random.gamma(a, b / a)`` and
    then refined with multiplicative fixed-point updates for ``EPOCH``
    iterations.  Progress (``*`` every 10 epochs, the bound and the first
    element of each hyperparameter every ``print_period`` epochs and at the
    last epoch) is printed to stdout.

    Parameters
    ----------
    x : array_like, shape (W, K)
        Observed matrix.  ``NaN`` entries are treated as missing.
    a_tm, b_tm : numpy.ndarray, shape (W, I)
        Shape and mean-like scale hyperparameters of the template factor.
    a_ve, b_ve : numpy.ndarray, shape (I, K)
        Shape and mean-like scale hyperparameters of the excitation factor.
    EPOCH : int
        Number of iterations.
    Method : str
        Accepted for interface compatibility; not used by this function.
    Update : number
        Hyperparameters are re-estimated on epochs ``e > Update``
        (the default ``np.inf`` disables re-estimation).
    tie_a_ve, tie_b_ve, tie_a_tm, tie_b_tm : str
        Tying mode of each hyperparameter: ``'clamp'`` (keep fixed),
        ``'free'`` (per element), ``'rows'`` (shared across rows),
        ``'cols'`` (shared across columns) or ``'tie_all'`` (single value).
    print_period : int
        Period, in epochs, of result snapshots and bound printing.

    Returns
    -------
    dict
        Keys ``'E_T'``, ``'E_logT'``, ``'E_V'``, ``'E_logV'``, ``'Bound'``,
        ``'a_ve'``, ``'b_ve'``, ``'a_tm'``, ``'b_tm'`` holding the values of
        the most recent snapshot (all ``None`` if no epoch was run).
    """
    # Result initialisation
    g = dict.fromkeys(['E_T', 'E_logT', 'E_V', 'E_logV', 'Bound',
                       'a_ve', 'b_ve', 'a_tm', 'b_tm'])

    x = np.asarray(x, dtype=float)
    # Mask of observed entries.  Deliberately not called M, which is the
    # module alias of numpy.matlib.
    observed = ~np.isnan(x)
    X = np.zeros(x.shape)
    X[observed] = x[observed]

    t_init = np.random.gamma(a_tm, b_tm / a_tm)
    v_init = np.random.gamma(a_ve, b_ve / a_ve)
    L_t = t_init
    L_v = v_init
    E_v = v_init

    gammalnX = special.gammaln(X + 1)

    for e in range(1, EPOCH + 1):
        LtLv = L_t.dot(L_v)
        tmp = X / LtLv
        Sig_t = L_t * tmp.dot(L_v.T)
        Sig_v = L_v * L_t.T.dot(tmp)

        alpha_tm = a_tm + Sig_t
        beta_tm = 1 / ((a_tm / b_tm) + observed.dot(E_v.T))
        E_t = alpha_tm * beta_tm

        alpha_ve = a_ve + Sig_v
        beta_ve = 1 / ((a_ve / b_ve) + E_t.T.dot(observed))
        E_v = alpha_ve * beta_ve

        # Compute the bound
        if e % 10 == 1:
            print("*", end='')
        if e % print_period == 1 or e == EPOCH:
            g['E_T'] = E_t
            g['E_logT'] = np.log(L_t)
            g['E_V'] = E_v
            g['E_logV'] = np.log(L_v)
            source_term = ((L_t * g['E_logT']).dot(L_v)
                           + L_t.dot(L_v * g['E_logV'])) / LtLv - np.log(LtLv)
            # The Gamma entropy contributes alpha * (log(beta) + 1); the "+ 1"
            # belongs inside the product with alpha.
            g['Bound'] = (
                -np.sum(observed * g['E_T'].dot(g['E_V']) + gammalnX)
                + np.sum(-X * source_term)
                + np.sum((-a_tm / b_tm) * g['E_T'] - special.gammaln(a_tm)
                         + a_tm * np.log(a_tm / b_tm))
                + np.sum((-a_ve / b_ve) * g['E_V'] - special.gammaln(a_ve)
                         + a_ve * np.log(a_ve / b_ve))
                + np.sum(special.gammaln(alpha_tm)
                         + alpha_tm * (np.log(beta_tm) + 1))
                + np.sum(special.gammaln(alpha_ve)
                         + alpha_ve * (np.log(beta_ve) + 1))
            )
            g['a_ve'] = a_ve
            g['b_ve'] = b_ve
            g['a_tm'] = a_tm
            g['b_tm'] = b_tm
            print()
            print(g['Bound'], a_ve.flatten()[0], b_ve.flatten()[0], a_tm.flatten()[0], b_tm.flatten()[0])

        if e == EPOCH:
            break

        L_t = np.exp(special.psi(alpha_tm)) * beta_tm
        L_v = np.exp(special.psi(alpha_ve)) * beta_ve

        if e > Update:
            # Each shape is re-estimated with the current scale, then the
            # scale is re-estimated with the new shape.  Every hyperparameter
            # follows its own tie_* setting.
            a_tm = _gnmf_update_shape(a_tm, E_t, L_t, b_tm, tie_a_tm)
            b_tm = _gnmf_update_scale(a_tm, E_t, b_tm, tie_b_tm)
            a_ve = _gnmf_update_shape(a_ve, E_v, L_v, b_ve, tie_a_ve)
            b_ve = _gnmf_update_scale(a_ve, E_v, b_ve, tie_b_ve)

    return g
# In [170]:
# %load gnmf_vb_demo.py
import numpy as np
import scipy as sp
from scipy import special
import numpy.matlib as M
# Demo: draw a synthetic count matrix from the model and fit it.
# x is W-by-K and is factorised as T (W-by-I) times V (I-by-K).
W = 40
K = 5
I = 3

# Hyperparameters of the Gamma priors on T (a_tm, b_tm) and V (a_ve, b_ve).
a_tm = 10 * np.ones([W, I])
b_tm = np.ones([W, I])
a_ve = np.ones([I, K])
b_ve = 100 * np.ones([I, K])

# Sample the factors with scale b / a, the same parameterisation that
# gnmf_vb_poisson_mult_fast uses when it draws its initial factors, so the
# data come from the model that is being fitted.
T = np.random.gamma(a_tm, b_tm / a_tm)
V = np.random.gamma(a_ve, b_ve / a_ve)
x = np.random.poisson(T.dot(V))

# Fit with every hyperparameter tied to a single shared value; re-estimation
# of the hyperparameters starts after epoch 10.
hoho = gnmf_vb_poisson_mult_fast(x, a_tm, b_tm, a_ve, b_ve,
                                 EPOCH=2000,
                                 Update=10,
                                 tie_a_ve='tie_all',
                                 tie_b_ve='tie_all',
                                 tie_a_tm='tie_all',
                                 tie_b_tm='tie_all')