In [3]:
import numpy as np
#from counting_grid import CountingGrid

In [41]:
class CountingGrid():
    def __init__(self, size, window_size, nr_of_features):
        '''
        size-- a two-dimensional numpy array, indicating the size of the
        Counting Grid
        window_size-- a two-dimensional numpy array, indicating the size
        of the window
        '''
        np.random.seed(7)
        self.size = size
        self.window_size = window_size
        self.nr_of_features = nr_of_features # No. of features
        #Initialize arrays here for pi, h
        #pi is a 3D array, storing the probability distributions
        #that collectively characterize the data. The first dimensionality
        #corresponds to features, e.g. it is Z. The second and third
        #dimension correspond to the size of the grid in x and y directions.
        rand_init_pi = 1 + np.random.rand(self.nr_of_features,self.size[0],self.size[1])        
        self.pi = rand_init_pi/sum(rand_init_pi,0)
        #Test pi    
        #self.pi = np.array([[[0.494, 0.524],[0.479,0.418]],[[0.506, 0.476],[0.521, 0.582]]])       
        self.h = np.zeros((self.nr_of_features,self.size[0],self.size[1]))    
        #self.compute_histograms()
        
      
    def normalize_data(self,X):
        X=X.transpose()
        normalized_X =  np.prod(self.window_size)*np.divide(X.astype('float'),np.sum(X,0)) 
        normalized_X=normalized_X.transpose()       
        return normalized_X  
        
    def compute_sum_in_a_window(self,grid,k,l):
        cumsum1=np.sum(np.sum(grid[:,:k+self.window_size[0],:l+self.window_size[1]],axis=1),axis=1)
        
        cumsum2=np.sum(np.sum(grid[:,:k,:l+self.window_size[1]],axis=1),axis=1)
        
        cumsum3=np.sum(np.sum(grid[:,:k+self.window_size[0],:l],axis=1),axis=1)
        cumsum4=np.sum(np.sum(grid[:,:k,:l],axis=1),axis=1)
        cumsum=cumsum1-cumsum2-cumsum3+cumsum4
        '''
        cumsum1 = grid[self.window_size[0]-1:,self.window_size[1]-1:]
        cumsum2  = grid[:grid.shape[0]-self.window_size[0]+1,self.window_size[1]-1:]
        cumsum3 = grid[self.window_size[0]-1:,:grid.shape[1]-self.window_size[1]+1]    
        cumsum4 = grid[:grid.shape[0]-self.window_size[0]+1,:grid.shape[1]-self.window_size[1]+1]    
        #print cumsum1.shape, cumsum2.shape,cumsum3.shape    
        cumsum =  cumsum1+cumsum2+cumsum3+cumsum4
        cumsum = cumsum[:cumsum.shape[1]-1,:cumsum.shape[1]-1]
        '''
        return cumsum
    
    def compute_histograms(self):
        '''
        Histograms at each point in the grid are computed
        by averaging the distributions pi in a pre-defined
        window. The left upmost corner of the window is placed
        on the grid position and the distributions are averaged.
        '''
        for k in range(0,self.size[0]):
            for l in range(0,self.size[1]):
                self.h[:,k,l]=self.compute_sum_in_a_window(self.pi,k,l)
        self.h=self.h/np.prod(self.window_size)
        print(self.h)
        
    def e_step(self,X):
        '''
        q is a 3D array with shape q.shape=(z_dimension=
        nr_of_samples,x and y=grid_size). 
        It stores the probabilities of a sample mapping to a 
        window in location k=[i1,i2]

        h is a 3D array with shape h.shape(z_dimension=
        nr_of_features, x and y=grid_size). 
        h describes the histograms (spanning along the first axis) from 
        which samples are drawn, in each location on the grid k=[i1,i2]
        '''
        nr_of_samples = X.shape[0]
        #Determine a minimal considered probability, for numerical purposes
        min_numeric_probability = float(1)/(10*self.size[0]*self.size[1])
        #Initialize q
        q_size=(nr_of_samples,self.size[0],self.size[1])
        self.q = np.zeros(q_size)
        self.q = np.exp(np.tensordot(X,np.log(self.h),axes=(1,0)))    
        self.q[self.q<min_numeric_probability]=min_numeric_probability   
        for t in range(0,nr_of_samples):
            normalizer=np.sum(self.q[t,:,:])              
            self.q[t,:,:]= self.q[t,:,:]/normalizer 
        print(self.q)
        
        
    def update_pi(self,X):  
        '''
        Updating the distributions pi on the grid involves
        calculations on data, distributions of mappings of 
        data on the grid log_q and the histograms on each
        grid point.
        '''
        #padded_q=np.lib.pad(self.q, ((0,0),(0,self.window_size[0]),(0,self.window_size[1])),'wrap')
        #padded_h=np.lib.pad(self.h, ((0,0),(0,self.window_size[0]),(0,self.window_size[1])),'wrap')   
        new_pi=np.zeros([self.nr_of_features,self.size[0],self.size[1]])     
    
        for i1 in range(0, self.size[0]):
            for i2 in range(0,self.size[1]):
                for z in range(0,X.shape[1]):
                    t_storage=[]
                    for t in range(0,X.shape[0]):
                        window_sum=self.compute_sum_in_a_window(np.divide(self.q[t,:,:],self.h[z,:,:].reshape(1,self.size[0],self.size[1])),i1,i2)
                        #print(np.divide(self.q[t,:,:],self.h).shape)
                        interm= X[t,z]*window_sum
                        t_storage.append(interm)
                    new_pi[z,i1,i2]=self.pi[z,i1,i2]*sum(t_storage)
        self.pi=new_pi
        normalizer=np.sum(self.pi,0)
        self.pi=np.divide(self.pi,normalizer)
    
    def m_step(self,X):
        self.update_pi(X)
        self.update_h()
    
    def fit(self,X,max_iteration,y=None):
        '''
        This is a function for fitting the counting
        grid using variational Expectation Maximization.

        The data dimensionality is nr_of_samples on first axis,
        and nr_of_features on second axis.

        X= [nr_of_samples, nr_of_features]    
        '''
        X=self.normalize_data(X)
        for i in range(0,max_iteration):
            self.e_step(X)
            self.m_step(X)

        return self.pi, self.q
    
        
    def cg_plot(self,labels):
        '''Currently supports 5 different symbols,
        the labels have to be numbers between 0-4
        for the code to work.
        '''
        lab = np.unique(labels)
        L = len(lab)
        for i in range(0,L):
            ids = np.where(labels==lab[i])[0]
            if i==0:
                marker='o'
            if i==1:
                marker='v'
            if i==2:
                marker='^'
            if i==3:
                marker='*'
            if i==4:
                marker='+'
            for t in range(0,len(ids)):
                temp = self.q[ids[t],:,:]
                x,y = np.unravel_index(temp.argmax(), temp.shape)
                noise = 0.2*np.random.rand(1)
                plt.scatter(x+noise,y+noise, marker=marker,s=60,color=cm.rainbow(i*100))
        plt.show()   

    
                
        
    
class CountingGridTest():
    def test_h_computation(self):
        self.pi
        return 0
    def test_e_step(self):
        return 0
    
    def test_m_step(self):
        return 0
    
    def test_arr_sum(self):
        return 0

In [ ]:
from scipy import io
def filter_by_variance(X,nr_of_features_to_keep):
    
    #Function for thresholding data by variance,
    #keeping only 'nr_of_features_to_keep' features
    #with the highest variance.
    #X=[nr_of_samples, nr_of_features]
      
    ordering = np.argsort(np.var(X,axis=0))[::-1]
    threshold = ordering[0:nr_of_features_to_keep]
    X=X[:,threshold]
    return X

data = io.loadmat('/home/maria/Documents/CountingGrids/lung_bhattacherjee.mat')
X= data['data']
Y_labels = data['sample_names'][0]
X = filter_by_variance(X,10)
print(X)
X=np.exp(X)
cg_obj=CountingGrid(np.array([15,15]),np.array([3,3]),10)
cg_obj.fit(X,10)


[[-3.0910e+01 -1.4320e+01 -2.5490e+01 ... -4.1300e+00 -1.6500e+01
  -2.0160e+01]
 [-2.7440e+01 -9.6500e+00 -1.8340e+01 ...  9.4000e+00 -4.7300e+00
   2.8800e+00]
 [-2.4670e+01 -1.6660e+01 -2.2630e+01 ...  6.1000e-01 -7.6700e+00
  -1.0000e-01]
 ...
 [ 4.1940e+01  5.0500e+01  4.5200e+00 ... -1.7650e+01 -1.7230e+01
  -8.2800e+00]
 [-9.9520e+01 -4.4700e+01 -8.9060e+01 ... -6.4280e+01 -1.0336e+02
  -6.3820e+01]
 [ 2.3700e+00  8.8910e+01  2.4540e+01 ...  8.9300e+01  4.0380e+01
   4.9180e+01]]
[[[0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  ...
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]]

 [[0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  ...
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]]

 [[0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  ...
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]]

 ...

 [[0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  ...
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]]

 [[0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  ...
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]]

 [[0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  ...
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]
  [0.00444444 0.00444444 0.00444444 ... 0.00444444 0.00444444 0.00444444]]]
/home/maria/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:28: RuntimeWarning: invalid value encountered in true_divide
/home/maria/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:82: RuntimeWarning: divide by zero encountered in log
/home/maria/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:83: RuntimeWarning: invalid value encountered in less
/home/maria/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:106: RuntimeWarning: divide by zero encountered in true_divide
/home/maria/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:39: RuntimeWarning: invalid value encountered in subtract

In [34]:
cg_obj=CountingGrid(np.array([5,5]),np.array([2,2]),3)
cg_obj.compute_histograms()
X=np.array([[1,2,3],[3,4,5]])
cg_obj.e_step(X)
cg_obj.update_pi(X)
print(cg_obj.pi)


[[[0.36276502 0.32504596 0.30099668 0.34584886 0.18393083]
  [0.35711285 0.33432482 0.29840199 0.3014378  0.15966794]
  [0.33329783 0.31668244 0.32518974 0.30349302 0.14223587]
  [0.35678682 0.31643476 0.32710895 0.33937656 0.15327864]
  [0.18600527 0.17766831 0.16153802 0.18086445 0.08851516]]

 [[0.33021305 0.34510168 0.34470857 0.32172688 0.15034454]
  [0.33978469 0.34563195 0.356336   0.33769106 0.15159748]
  [0.32174262 0.3284022  0.33574103 0.33506598 0.1689451 ]
  [0.3155256  0.34243307 0.32610964 0.30984831 0.17292576]
  [0.1618791  0.18581211 0.17761193 0.15112565 0.08072554]]

 [[0.30702193 0.32985236 0.35429475 0.33242426 0.16572464]
  [0.30310247 0.32004323 0.345262   0.36087114 0.18873458]
  [0.34495956 0.35491536 0.33906923 0.361441   0.18881903]
  [0.32768758 0.34113218 0.34678141 0.35077513 0.1737956 ]
  [0.15211563 0.13651958 0.16085005 0.1680099  0.08075931]]]
[[[0.04 0.04 0.04 0.04 0.04]
  [0.04 0.04 0.04 0.04 0.04]
  [0.04 0.04 0.04 0.04 0.04]
  [0.04 0.04 0.04 0.04 0.04]
  [0.04 0.04 0.04 0.04 0.04]]

 [[0.04 0.04 0.04 0.04 0.04]
  [0.04 0.04 0.04 0.04 0.04]
  [0.04 0.04 0.04 0.04 0.04]
  [0.04 0.04 0.04 0.04 0.04]
  [0.04 0.04 0.04 0.04 0.04]]]
[[[0.18220092 0.27378534 0.21290093 0.23499062 0.27426922]
  [0.28561189 0.26646327 0.18143128 0.21890626 0.24458476]
  [0.20382182 0.24068875 0.2614233  0.1972293  0.23263443]
  [0.27006485 0.16522504 0.18191032 0.24420723 0.1639039 ]
  [0.17861119 0.299241   0.17736454 0.22986222 0.22222222]]

 [[0.29115337 0.31720943 0.35527144 0.36819844 0.34567014]
  [0.33663233 0.37004335 0.31343625 0.35469611 0.31454709]
  [0.33798311 0.34150125 0.36959814 0.39496844 0.29842583]
  [0.32057589 0.26045354 0.32581154 0.27690647 0.35686416]
  [0.30230207 0.30302386 0.43081961 0.2968691  0.33333333]]

 [[0.52664571 0.40900523 0.43182763 0.39681094 0.38006064]
  [0.37775577 0.36349338 0.50513246 0.42639763 0.44086815]
  [0.45819507 0.41781    0.36897857 0.40780226 0.46893973]
  [0.40935927 0.57432143 0.49227814 0.4788863  0.47923194]
  [0.51908674 0.39773514 0.39181585 0.47326868 0.44444444]]]

In [9]:
cg_obj=CountingGrid(np.array([5,5]),np.array([2,2]),3)
#X has dimensions different from the counts in 
#MATLAB implementation-- one row is one data vector,
#that's how it is in sci-kit learn
X=np.array([[1,2,3],[3,4,5]])
print(cg_obj.compute_sum_in_a_window(cg_obj.pi,4,2))
print(cg_obj.pi)
#np.sum(cg_obj.pi[0,:2,:2])
print(np.sum(cg_obj.pi[2,1:3,1:3]))


[0.64615207 0.71044772 0.64340021]
[[[0.29635248 0.37632129 0.29822722 0.35062672 0.40694111]
  [0.41083835 0.36754795 0.25808737 0.29704539 0.3287822 ]
  [0.30684227 0.34322282 0.36844113 0.27003406 0.30988956]
  [0.41534963 0.26777659 0.28728922 0.37499452 0.25905392]
  [0.31010278 0.4339183  0.27675492 0.36939715 0.35406063]]

 [[0.31200873 0.32208919 0.36291471 0.33954886 0.30196772]
  [0.32174166 0.36501264 0.33039017 0.34598054 0.29941042]
  [0.335839   0.33654546 0.35057953 0.39839378 0.30697948]
  [0.31600325 0.29858275 0.32790107 0.26608976 0.3688009 ]
  [0.33311521 0.31440118 0.42884726 0.28160045 0.32290215]]

 [[0.39163879 0.30158952 0.33885807 0.30982442 0.29109117]
  [0.26742    0.26743942 0.41152245 0.35697407 0.37180737]
  [0.35731873 0.32023171 0.28097934 0.33157215 0.38313096]
  [0.26864712 0.43364066 0.3848097  0.35891572 0.37214518]
  [0.35678201 0.25168052 0.29439782 0.3490024  0.32303722]]]
1.2801729225906024