In [141]:
# %load gnmf_solvebynewton.py
from __future__ import division
import numpy as np
import scipy as sp
from scipy import special
import numpy.matlib as M

def gnmf_solvebynewton(c, a0 = None):

    if a0 is None:
        a0 = 0.1 * np.ones(np.shape(c))

    M, N = np.shape(a0)
    if len(np.shape(c)) == 0:
        Mc , Nc = 1,1
    else:
        Mc, Nc = np.shape(c)



    a = None
    cond = 0

    if (M == Mc and N == Nc):
        a = a0
        cond = 1

    elif (Mc == 1 and Nc >1):
        cond = 2
        a = a0[0,:]
    elif (Mc > 1 and Nc == 1):
        cond = 3
        a = a0[:,0]
    elif (Mc == 1 and Nc == 1):
        cond = 4
        a = a0[0,0]

    a2 = None
    for index in range(10):
        a2 = a - (np.log(a) - special.polygamma(0,a) + 1 - c) / (1/a - special.polygamma(1,a))
        idx = np.where(a2<0)
        if len(idx[0]) > 0:
            if isinstance(a, float):
                a2 = a / 2
            else:
                a2[idx] = a[idx] / 2
        a = a2

    if(cond == 2):
        a = M.repmat(a,M,1)
    elif(cond == 3):
        a = M.repmat(a,1,N)
    elif(cond == 4):
        a = a * np.ones([M,N])

    return a

In [168]:
# %load gnmf_vb_poisson_mult_fast.py
from __future__ import division
import numpy as np
import scipy as sp
import math
from scipy import special
import numpy.matlib as M

def gnmf_vb_poisson_mult_fast(x,
                            a_tm,
                            b_tm,
                            a_ve,
                            b_ve,
                            EPOCH =1000,
                            Method = 'vb',
                            Update = np.inf,
                            tie_a_ve = 'clamp',
                            tie_b_ve = 'clamp',
                            tie_a_tm = 'clamp',
                            tie_b_tm = 'clamp',
                            print_period = 500
                            ):

    # Result initialiation
    g = dict()
    g['E_T'] = None
    g['E_logT'] = None
    g['E_V'] = None
    g['E_logV'] = None
    g['Bound'] = None
    g['a_ve'] = None
    g['b_ve'] = None
    g['a_tm'] = None
    g['b_tm'] = None

    logm = np.vectorize(math.log)
    W = x.shape[0]
    K = x.shape[1]
    I = b_tm.shape[1]

    M = ~np.isnan(x)
    X = np.zeros(x.shape)
    X[M] = x[M]

    t_init = np.random.gamma(a_tm, b_tm/a_tm)
    v_init = np.random.gamma(a_ve, b_ve/a_ve)
    L_t = t_init
    L_v = v_init
    E_t = t_init
    E_v = v_init
    Sig_t = t_init
    Sig_v = v_init

    B = np.zeros([1,EPOCH])
    gammalnX = special.gammaln(X+1)

    for e in range(1,EPOCH+1):

        LtLv = L_t.dot(L_v)
        tmp = X / (LtLv)
        #check Tranpose
        Sig_t = L_t * (tmp.dot(L_v.T))
        Sig_v = L_v * (L_t.T.dot(tmp))

        alpha_tm = a_tm + Sig_t
        beta_tm = 1/((a_tm/b_tm) + M.dot(E_v.T))
        E_t = alpha_tm * (beta_tm)

        alpha_ve = a_ve + Sig_v
        beta_ve = 1/((a_ve/b_ve) + E_t.T.dot(M))

        E_v = alpha_ve * (beta_ve)
        # Compute the bound
        if(e%10 == 1):
            print("*", end='')
        if(e%print_period == 1 or e == EPOCH):
            g['E_T'] = E_t
            g['E_logT'] = logm(L_t)
            g['E_V'] = E_v
            g['E_logV'] = logm(L_v)

            g['Bound'] = -np.sum(np.sum(M * (g['E_T'].dot(g['E_V'])) + gammalnX))\
                        + np.sum(np.sum(-X * ( ((L_t * g['E_logT']).dot(L_v) + L_t.dot(L_v * g['E_logV']))/(LtLv) - logm(LtLv) ) ))\
                        + np.sum(np.sum((-a_tm/b_tm)* g['E_T'] - special.gammaln(a_tm) + a_tm * logm(a_tm /b_tm)))\
                        + np.sum(np.sum((-a_ve/b_ve)* g['E_V'] - special.gammaln(a_ve) + a_ve * logm(a_ve /b_ve)))\
                        + np.sum(np.sum( special.gammaln(alpha_tm) + alpha_tm * logm(beta_tm) + 1))\
                        + np.sum(np.sum(special.gammaln(alpha_ve) + alpha_ve * logm(beta_ve) + 1 ))

            g['a_ve'] = a_ve
            g['b_ve'] = b_ve
            g['a_tm'] = a_tm
            g['b_tm'] = b_tm

            print()
            print( g['Bound'], a_ve.flatten()[0], b_ve.flatten()[0], a_tm.flatten()[0], b_tm.flatten()[0])
        if (e == EPOCH):
            break;
        L_t = np.exp(special.psi(alpha_tm)) * beta_tm
        L_v = np.exp(special.psi(alpha_ve)) * beta_ve

        Z = None
        if( e> Update):
            if(not tie_a_tm == 'clamp' ):
                Z = (E_t / b_tm) - (logm(L_t) - logm(b_tm))
                if(tie_a_tm == 'clamp'):
                    a_tm = gnmf_solvebynewton(Z,a0=a_tm)
                elif(tie_a_tm == 'rows'):
                    a_tm = gnmf_solvebynewton(np.sum(Z,0)/W, a0=a_tm)
                elif(tie_a_tm == 'cols'):
                    a_tm = gnmf_solvebynewton(np.sum(Z,1)/I, a0=a_tm)
                elif(tie_a_tm == 'tie_all'):
                    #print(np.sum(Z)/(W * I))
                    #print(a_tm)
                    a_tm = gnmf_solvebynewton(np.sum(Z)/(W * I), a0=a_tm)

            if(tie_b_tm == 'free'):
                b_tm = E_t
            elif(tie_b_tm == 'rows'):
                b_tm = M.repmat(np.sum(a_tm * E_t,0)/np.sum(a_tm,0),W,1)
            elif(tie_b_tm == 'cols'):
                b_tm = M.repmat(np.sum(a_tm * E_t,1)/np.sum(a_tm,1),1,I)
            elif(tie_b_tm == 'tie_all'):
                b_tm = (np.sum(a_tm*E_t)/ np.sum(a_tm)) * np.ones([W,I])

            if(not tie_a_ve == 'clamp' ):
                Z = (E_v / b_ve) - (logm(L_v) - logm(b_ve))
                if(tie_a_ve == 'clamp'):
                    a_ve = gnmf_solvebynewton(Z,a_ve)
                elif(tie_a_ve == 'rows'):
                    a_ve = gnmf_solvebynewton(np.sum(Z,0)/I, a0=a_ve)
                elif(tie_a_ve == 'cols'):
                    a_ve = gnmf_solvebynewton(np.sum(Z,1)/K, a0=a_ve)
                elif(tie_a_ve == 'tie_all'):
                    a_ve = gnmf_solvebynewton(np.sum(Z)/(I * K), a0=a_ve)

            if(tie_b_ve == 'free'):
                b_ve = E_v
            elif(tie_b_ve == 'rows'):
                b_ve = M.repmat(np.sum(a_ve * E_v,0)/np.sum(a_ve,0),I,1)
            elif(tie_b_tm == 'cols'):
                b_ve = M.repmat(np.sum(a_ve * E_v,1)/np.sum(a_ve,1),1,K)
            elif(tie_b_tm == 'tie_all'):
                b_ve = (np.sum(a_ve*E_v)/ np.sum(a_ve)) * np.ones([I,K])
    return g

In [170]:
# %load gnmf_vb_demo.py
import numpy as np
import scipy as sp
from scipy import special
import numpy.matlib as M

W = 40
K = 5
I = 3

a_tm = 10 * np.ones([W,I])
b_tm = np.ones([W,I])
a_ve = np.ones([I,K])
b_ve = 100 * np.ones([I,K])

T = np.random.gamma(a_tm,b_tm)
V = np.random.gamma(a_ve,b_ve)

x = np.random.poisson(T.dot(V))

hoho = gnmf_vb_poisson_mult_fast(x,a_tm,b_tm,a_ve,b_ve,
                                EPOCH=2000,
                                Update =10,
                                tie_a_ve='tie_all',
                                tie_b_ve='tie_all',
                                tie_a_tm='tie_all',
                                tie_b_tm='tie_all')


*
-1576713.90752 1.0 100.0 10.0 1.0
**************************************************
-1544975.67074 0.93549779041 98.9495890457 5.14009699444 14.2437747219
**************************************************
-1545138.3518 0.927928620051 96.8752240211 6.6271056808 13.9483194313
**************************************************
-1545301.05102 0.905908268804 95.5659607918 8.08614514641 13.7648845526
*************************************************
-1545433.16632 0.872370022471 94.8741521897 9.25827683333 13.6776155021