In [2]:
from __future__ import division
import time
from scipy.special import logit, logsumexp
from sklearn.utils import check_random_state
from sklearn.utils import check_random_state
from sklearn.base import BaseEstimator
import numpy as np
import math
import random, collections
import scipy.sparse

EPS = np.finfo(float).eps

def log_product_of_bernoullis_mixture_likelihood(X, logit_odds, log_inv_mean_sums):
    """
    X:
    Here the X is the original patche data (with padded window)
    logit_odds:
    The logit parts with background probabilities as well.
    log(mean / 1 - mean)
    log_inv_mean_sums:
    (1 / (log mean)).sum()
    """
    
    print("inside calculating log")
    n, d = X.shape
    m = log_inv_mean_sums.size

    memory_limit = 200

    b = max(math.floor(memory_limit * 1024 * 1024 / (64 * d) - m), 1)
    
    num_b = math.ceil(n / b)

    loglike = np.empty((n,m))
    for i in range(num_b):
        loglike[i * b : (i + 1) * b] = np.dot(X[i * b:(i + 1) * b], logit_odds.transpose()) + log_inv_mean_sums
    print("outside calculatinglog")
    return loglike


def latentShiftEM(data, num_mixture_component, parts_shape, region_shape,
                  shifting_shape, max_num_iteration, loglike_tolerance,
                  mu_truncation=(1, 1), additional_mu=None, permutation=None,
                  numpy_rng=None, verbose=False):
    n, d = data.shape
    partDimension = parts_shape[0] * parts_shape[1] * parts_shape[2]
    numShifting = shifting_shape[0] * shifting_shape[1]
    bkgRegion = region_shape[0] * region_shape[1] * region_shape[2] - partDimension
    if isinstance(mu_truncation, float):
        use_epsilon = True
        epsilon = mu_truncation
    else:
        use_epsilon = False
        beta_prior = mu_truncation

    purity_level = 2
    log_w = np.empty((numShifting, num_mixture_component))
    log_w.fill(-np.log(numShifting * num_mixture_component))
    # Initialize means biased toward 1, then flip entries toward randomly
    # chosen training patches so each component starts near real data.
    mu = numpy_rng.uniform(size=(num_mixture_component, partDimension)) ** (1 / (purity_level + 1))
    print(mu.shape)
    centerXStart = (shifting_shape[0] - 1) // 2
    centerYStart = (shifting_shape[1] - 1) // 2
    center_windows = data.reshape((n,) + region_shape)[
        :, centerXStart:centerXStart + parts_shape[0],
        centerYStart:centerYStart + parts_shape[1], :].reshape(n, -1)
    is_flip = np.logical_not(center_windows[numpy_rng.choice(n, num_mixture_component, replace=False)])
    print(is_flip.shape)
    mu[is_flip] = 1 - mu[is_flip]
    bkg_probability = 0.2

    if use_epsilon:
        mu[mu < epsilon] = epsilon
        mu[mu > 1 - epsilon] = 1 - epsilon
    else:
        mu *= (n / num_mixture_component) / ((n / num_mixture_component) + np.sum(beta_prior))
        mu += beta_prior[0] / ((n / num_mixture_component) + np.sum(beta_prior))

    log_odd = np.empty((num_mixture_component,)+region_shape)
    sum_log_one_mu = np.empty(num_mixture_component)
    log_q = np.empty((numShifting, n, num_mixture_component))

    # DO EM.
    loglike = []
    t = 0
    while t < max_num_iteration:
        if verbose:
            clock_start = time.time()

        # E-step: compute the posterior q over (shift, component) pairs.
        sum_log_one_mu = np.log(1 - mu).sum(axis=1) + np.log(1 - bkg_probability) * bkgRegion
        for i in range(shifting_shape[0]):
            for j in range(shifting_shape[1]):
                # Embed each part template into the full region at shift (i, j),
                # with the background probability everywhere else.
                log_odd[:] = bkg_probability
                log_odd[:, i:i + parts_shape[0], j:j + parts_shape[1], :] = mu.reshape((num_mixture_component,) + parts_shape)
                log_odd = logit(log_odd)
                log_q[i * shifting_shape[1] + j] = log_product_of_bernoullis_mixture_likelihood(
                    data, log_odd.reshape((num_mixture_component, -1)), sum_log_one_mu)

        # Normalize over all (shift, component) pairs for each datum.
        norm_log_q = logsumexp(logsumexp(log_q, axis=2), axis=0)
        log_q -= norm_log_q.reshape((1, n, 1))

        # M-step: compute weights and model.
        log_w = logsumexp(log_q, axis=1)
        log_q_sum_r_n = logsumexp(log_w, axis=0)
        log_w -= logsumexp(log_w)

        q = np.exp(log_q)
        mu = np.zeros((num_mixture_component, partDimension))
        p_bkg = 0
        for i in range(shifting_shape[0]):
            for j in range(shifting_shape[1]):
                # Soft counts: the part window at shift (i, j) updates mu;
                # everything outside it feeds the background estimate.
                dotResult = np.dot(q[i * shifting_shape[1] + j].T, data).reshape((num_mixture_component,) + region_shape)
                mu += dotResult[:, i:i + parts_shape[0], j:j + parts_shape[1], :].reshape((num_mixture_component, -1))
                dotResult[:, i:i + parts_shape[0], j:j + parts_shape[1], :] = 0
                p_bkg += np.sum(dotResult)
        bkg_probability = p_bkg / (n * bkgRegion)

        if use_epsilon:
            mu[mu < EPS] = EPS
            mu = np.exp(np.log(mu) - log_q_sum_r_n.reshape((num_mixture_component, 1)))
            mu[mu < epsilon] = epsilon
            mu[mu > 1 - epsilon] = 1 - epsilon
        else:
            mu += beta_prior[0]
            mu /= (np.exp(log_q_sum_r_n) + np.sum(beta_prior)).reshape((num_mixture_component,1))

        loglike.append(norm_log_q.sum())
        if verbose:
            print('Iter {}: {:.3f} seconds. Log-likelihood: {:.1f}'.format(t + 1, time.time() - clock_start, loglike[-1]))
        # Stop once the relative improvement drops below the tolerance
        # (loglike is negative, so -loglike[-2] is its magnitude).
        if t >= 1 and loglike[-1] - loglike[-2] < loglike_tolerance * -loglike[-2]:
            break
        t += 1
    loglike = np.asarray(loglike, dtype=np.float64)
    log_weight = log_w.sum(axis=0)
    ordering = np.argsort(log_weight)[::-1]
    log_weight = log_weight[ordering]
    mu = mu[ordering]

    # Most likely component per datum, then the most likely shift given it.
    data_m = logsumexp(log_q, axis=0).argmax(axis=1)
    idx = np.ravel_multi_index((np.arange(n), data_m), (n, num_mixture_component))
    data_p = log_q.reshape((numShifting, n * num_mixture_component))[:, idx].argmax(axis=0)
    # Remap the per-datum component indices to the sorted ordering.
    inverse_ordering = np.argsort(ordering)
    data_m = inverse_ordering[data_m]
    data_label = np.hstack((data_m.reshape((-1, 1)), data_p.reshape((-1, 1))))
    if verbose:
        print('latentShiftEM finished')
    return log_weight, mu, loglike, data_label
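
The blocked likelihood above relies on rewriting the Bernoulli log-likelihood as log P(x | mu) = x · logit(mu) + Σ log(1 − mu). A minimal numeric check of that identity on made-up data (not part of the pipeline):

In [ ]:
rng_check = np.random.RandomState(0)
mu_check = rng_check.uniform(0.1, 0.9, size=5)
x_check = (rng_check.uniform(size=5) < mu_check).astype(float)
# Direct evaluation of sum_j [x_j log mu_j + (1 - x_j) log(1 - mu_j)] ...
direct = np.sum(x_check * np.log(mu_check) + (1 - x_check) * np.log(1 - mu_check))
# ... equals the rewritten form used in log_product_of_bernoullis_mixture_likelihood.
via_logit = x_check.dot(logit(mu_check)) + np.log(1 - mu_check).sum()
assert np.allclose(direct, via_logit)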

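The M-step smooths the soft counts with a Beta prior; with mu_truncation = (1, 1) this is add-one (Laplace) smoothing, which keeps every mean strictly inside (0, 1). A tiny worked example with hypothetical counts:

In [ ]:
alpha, beta = 1.0, 1.0                    # mu_truncation = (1, 1)
soft_counts = np.array([0.0, 3.0, 10.0])  # hypothetical per-pixel sums of q * x
n_eff = 10.0                              # hypothetical sum of q per component
print((soft_counts + alpha) / (n_eff + alpha + beta))  # [0.0833... 0.3333... 0.9166...]
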
In [4]:
import amitgroup as ag

In [5]:
import pnet

In [6]:
edgeLayer = pnet.EdgeLayer(k=5, radius=1, spread='orthogonal', minimum_contrast=0.05)
partsNet = [edgeLayer]

In [7]:
net = pnet.PartsNet(partsNet)

In [8]:
digits = range(10)
ims = ag.io.load_mnist('training', selection=slice(10000), return_labels=False)

In [9]:
result = edgeLayer.extract(ims)

In [10]:
result.shape


Out[10]:
(10000, 28, 28, 8)
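
The trailing axis presumably holds one binary map per edge orientation (8 of them here). A quick inspection sketch for checking how sparse these features are, assuming result is a binary array as the shapes suggest:

In [ ]:
print(result.dtype, result.min(), result.max())
# Average firing rate of each of the 8 orientation channels.
print(result.reshape(-1, 8).mean(axis=0))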

In [11]:
partsLayer = pnet.PartsLayer(100, (10, 10), settings=dict(outer_frame=0, threshold=40, samples_per_image=40, max_samples=1000000, min_prob=0.005))

In [12]:
allPatch = partsLayer._get_patches(result, result[:, :, :, 0])

In [13]:
allPatch[0].shape


Out[13]:
(400000, 10, 10, 8)

In [14]:
allPatch = allPatch[0]
allPatch = allPatch.reshape(allPatch.shape[0], -1)

In [15]:
allPatch.shape


Out[15]:
(400000, 800)
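
Each flattened row can be folded back into its (10, 10, 8) patch geometry for inspection; a minimal visualization sketch, assuming matplotlib is available:

In [ ]:
import matplotlib.pyplot as plt
patch = allPatch[0].reshape(10, 10, 8)
# Show the first orientation channel of the first sampled patch.
plt.imshow(patch[:, :, 0], interpolation='nearest', cmap='gray')
plt.show()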

In [ ]:
rng = np.random.RandomState()
em = latentShiftEM(allPatch, 50, (6, 6, 8), (10, 10, 8), (5, 5), 25, 1e-3,
                   mu_truncation=(1, 1), numpy_rng=rng, verbose=True)


(50, 288)
(50, 288)
inside calculating log
outside calculating log
[the two lines above repeat 25 times in total, once per shift in the E-step]
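
When the call completes, the returned tuple can be unpacked as below; mu holds the flattened (6, 6, 8) part templates sorted by mixing weight, and data_label pairs each patch with its (component, shift) assignment. A sketch:

In [ ]:
log_weight, mu, loglike, data_label = em
parts = mu.reshape((50, 6, 6, 8))   # one Bernoulli template per component
print(parts.shape, loglike[-1])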

In [39]:
is_flip = allPatch[rng.choice(400000, 800, replace=False)]
is_flip.shape


Out[39]:
(800, 800)

In [40]:
is_flip


Out[40]:
array([[1, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ..., 
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)
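
The (800, 800) shape above differs from the (50, 288) printed during the EM run because latentShiftEM computes is_flip on the center (6, 6, 8) window of each region before indexing. A sketch reproducing the in-function computation:

In [ ]:
region = allPatch.reshape((-1, 10, 10, 8))
# centerXStart = centerYStart = (5 - 1) // 2 = 2 for shifting_shape (5, 5).
center = region[:, 2:8, 2:8, :].reshape((region.shape[0], -1))
print(np.logical_not(center[rng.choice(400000, 50, replace=False)]).shape)  # (50, 288)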

In [36]:
a = rng.choice(400000, 800, replace=False)

In [37]:
a.shape


Out[37]:
(800,)

In [38]:
a


Out[38]:
array([277650, 336464, 130993, 302097, 358685,  42860, 213120,  99881,
        79077,  72797,  89541,  97412, 121763, 124844, 286665,  63963,
       ...,
       149525, 198643, 135381,  22409, 235277,  11557, 221858, 260803])

In [ ]: