In [1]:
# Imports and notebook statements
%load_ext autoreload
%autoreload 2
%matplotlib notebook
%load_ext line_profiler

import torch
from torch.nn.functional import conv2d, conv1d, relu_, sigmoid, hardtanh,  relu, unfold, fold, softmax
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

localConv=torch.nn.backends.thnn.backend.SpatialConvolutionLocal

import numpy as np
from scipy import stats
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from matplotlib import cm
import seaborn as sns
import pickle as pkl
from sklearn.linear_model import LinearRegression

import warnings
warnings.filterwarnings('ignore')

In [2]:
def printgradnorm(self, grad_input, grad_output):
    print('Inside class:' + self.__class__.__name__)
    print('')
    print('grad_input: ', type(grad_input))
    print('grad_input[0]: ', type(grad_input[0]))
    print('grad_output: ', type(grad_output))
    print('grad_output[0]: ', type(grad_output[0]))
    print('')
    print('grad_input size:', grad_input[0].size())
    print('grad_output size:', grad_output[0].size())
    print('grad_input norm:', grad_input[0].norm())

In [3]:
def kWinnerTakeAll(tensor, k):
    values, indices = torch.topk(tensor, tensor.shape[-1] - k, largest=False, dim=-1)
    new_tensor = tensor + 0
    new_tensor[indices] = 0
    return new_tensor

In [4]:
def buildTrajectory(length, stepSize, width=1., directionStability=0.95, wrap=False, circular=False):
    trajectory = np.zeros((int(length), 2))
    turns = np.zeros((int(length)))
    if circular:
        r = np.sqrt(np.random.rand())*width
        angle = np.random.rand()*2.*np.pi
        x = np.cos(angle)*r
        y = np.sin(angle)*r
    else:
        x = np.random.rand()*width
        y = np.random.rand()*width
    direction = np.random.rand() * 2 * np.pi
    twopi = 2*np.pi
    for i in range(int(length)):
        oldDir = direction
        recenter = 0
        patience = 0
        while True:
            # This is a random value between (-180, +180) scaled by directionStability
            dirChange = ((recenter + (np.random.rand() * twopi) - np.pi) *
                       (1.0 - directionStability + patience))
            direction = (direction + dirChange) % twopi
            rotation = np.asarray([np.cos(direction), np.sin(direction)])
            movement = stepSize*rotation
            if circular:
                position = (movement[0] + x)**2 + (movement[1] + y)**2
                print(np.sqrt(position), width)
                inBounds = np.sqrt(position) < width
            else:
                inBounds = 0 < (movement[0] + x) < width and 0 < (movement[1] + y) < width
            if inBounds or wrap:
                x += movement[0]
                y += movement[1]
                trajectory[i] = (x, y)
                turns[i] = np.abs(oldDir - direction)
                oldDir = direction
                break
            else:
                patience += .5
                recenter = oldDir
                

    return(trajectory, turns)

In [8]:
class L6L4Network(torch.nn.Module):
    def __init__(self,
                 numL6=500,
                 minicols=100,
                 cellsPerMinicolumn=10,
                 dendrites=1000,
                 numGaussians=10,
                 placeSigma=.01,
                 envSize=1.,
                 boostingAlpha=.01,
                 circular=False,
                 BCMLearningRate=.01,
                 BCMAlpha=.1,
                 SGDLearningRate=.01,
                 L6Sparsity=.1,
                 dendriteWeightSparsity=0.1,
                 ):
        
        super(L6L4Network, self).__init__()
        self.minicols = minicols
        self.numL6 = numL6
        self.cellsPerMinicolumn = cellsPerMinicolumn
        self.numDendrites = dendrites
        self.numGaussians = numGaussians
        self.placeSigma=placeSigma
        self.envSize = envSize
        self.boostingAlpha = boostingAlpha
        self.BCMLearningRate = BCMLearningRate
        self.SGDLearningRate = SGDLearningRate
        self.BCMAlpha = BCMAlpha
        self.L6K = int(numL6*L6Sparsity)
        self.L6DendriteK = int(dendriteWeightSparsity*minicols*cellsPerMinicolumn)
        self.L4DendriteK = int(numL6*dendriteWeightSparsity)
        
        self.L6 = torch.nn.RNNCell(dendrites + 2, numL6)
        if device == torch.device('cuda'):
            self.L6 = self.L6.cuda()

        self.L4DendriteWeights = torch.zeros((minicols*cellsPerMinicolumn, dendrites), device=device, dtype=torch.float, 
                                            requires_grad=True)
        self.L4Dendrites = torch.zeros((dendrites, numL6), device=device, dtype=torch.float, 
                                      requires_grad=True)
        
        #self.L6DendriteWeights = torch.zeros((L6, dendrites), device=device, dtype=torch.float)
        self.L6Dendrites = torch.zeros((dendrites, minicols*cellsPerMinicolumn), device=device, dtype=torch.float, 
                                      requires_grad=True)
        
        
        torch.nn.init.kaiming_uniform_(self.L4Dendrites)
        torch.nn.init.kaiming_uniform_(self.L6Dendrites)
        torch.nn.init.kaiming_uniform_(self.L4DendriteWeights)
        
        self.zero = torch.zeros((1,), device=device, dtype=torch.float)
        
        
        self.normalization = torch.nn.LayerNorm(numL6, elementwise_affine=False)
        
        if circular:
            angles = np.random.rand(minicols, numGaussians)*2*np.pi
            radii = np.sqrt(np.random.rand(minicols, numGaussians))*self.envSize
            xComp = np.cos(angles)
            yComp = np.sin(angles)
            places = np.stack([xComp*radii, yComp*radii], axis=-1)
            self.places = torch.tensor(places,
                                       device=device,
                                       dtype=torch.float,
                                       requires_grad=False)
        else:
            self.places = torch.tensor(np.random.rand(minicols, numGaussians, 2)*self.envSize,
                                       device=device,
                                       dtype=torch.float,
                                       requires_grad=False)
        self.circular = circular
        
    def forward(self,
                velocities,
                feedforwards,
                hidden,
                L4,
                L4DendriteHistory,
                L6DendriteHistory,
                L4History,
                BCML4History):
        
        cost = torch.zeros((1,), device=device, dtype=torch.float)
        for i in range(velocities.shape[0]):
            vel = velocities[i]
            feedforward = feedforwards[i]
            L6DendriteActivations = kWinnerTakeAll(self.L6Dendrites, self.L6DendriteK)@(L4.view(L4.numel()))
            relu_(L6DendriteActivations)
            input = torch.cat((L6DendriteActivations, vel))
            hidden = self.L6(input.view(1, input.numel()), hidden)
            hidden = self.normalization(hidden)
            hidden = kWinnerTakeAll(hidden, self.L6K)

            L4DendriteActivations = kWinnerTakeAll(self.L4Dendrites, self.L4DendriteK)@(hidden.view(hidden.numel()))
            #L4DendriteActivations = (L4DendriteActivations**2) + 1.
            L4Predictions = (self.L4DendriteWeights@L4DendriteActivations).view(L4.shape)
            L4Predictions = L4Predictions*(1 - (L4History + 0.01))
            relu_(L4Predictions)
            #plt.matshow(L4Predictions.detach().cpu().numpy()); plt.show()
            L4Predictions = softmax(L4Predictions**2 + 1., dim = -1)
            #plt.matshow(L4Predictions.detach().cpu().numpy()); plt.show()
            L4 = feedforward[:, None]*L4Predictions

            cost = cost + torch.sum(L4 ** 0.5) 
            
            with torch.no_grad():
                L4History = L4Predictions*self.boostingAlpha + L4History*(1 - self.boostingAlpha)
                BCML4History = (L4Predictions **2)*self.BCMAlpha + BCML4History*(1 - self.BCMAlpha)
                L4DendriteHistory = (L4DendriteActivations**2)*self.BCMAlpha +\
                    L4DendriteHistory*(1 - self.BCMAlpha)
                L6DendriteHistory = (L6DendriteActivations **2)*self.BCMAlpha +\
                    L6DendriteHistory*(1 - self.BCMAlpha)
                
                
                if torch.isnan(hidden).any() or \
                    torch.isnan(L4).any() or \
                    torch.isnan(L4DendriteActivations).any(): 
                    import ipdb; ipdb.set_trace()
                self.L4Dendrites = self.L4Dendrites +\
                    self.BCMLearningRate*self.BCMLearn(hidden, L4DendriteActivations, L4DendriteHistory)
                self.L4DendriteWeights = self.L4DendriteWeights +\
                    self.BCMLearningRate*self.BCMLearn(L4DendriteActivations, L4.view(L4.numel()), BCML4History)
                self.L6Dendrites = self.L6Dendrites +\
                    self.BCMLearningRate*self.BCMLearn(L4, L6DendriteActivations, L6DendriteHistory)

                relu_(self.L4DendriteWeights)
                relu_(self.L4Dendrites)
        
        return (torch.sum(cost),
                hidden.detach(),
                L4.detach(),
                L4DendriteHistory.detach(),
                L6DendriteHistory.detach(),
                L4History.detach(),
                BCML4History.detach())

    
    def BCMLearn(self, presyn, postsyn, history):
        """
        Return: update
        """
        postsyn = postsyn.view(postsyn.numel())
        presyn = presyn.view(presyn.numel())
        history = history.view(history.numel())
        delta = torch.ger(postsyn*(postsyn - history), presyn)/(history[:, None] + 0.001)
        
        return (delta * self.BCMLearningRate)    
    

    def learn(self, runningTime, seqLen, speed, stability):
        L4 = torch.zeros((self.minicols, self.cellsPerMinicolumn), device=device, dtype=torch.float,)
        L4History = torch.zeros((self.minicols, self.cellsPerMinicolumn), device=device, dtype=torch.float,)
        BCML4History = torch.zeros((self.minicols, self.cellsPerMinicolumn), device=device, dtype=torch.float)
        L6History = torch.zeros((self.minicols, self.cellsPerMinicolumn), device=device, dtype=torch.float)
        L4DendriteHistory = torch.zeros((self.numDendrites), device=device, dtype=torch.float)
        L6DendriteHistory = torch.zeros((self.numDendrites), device=device, dtype=torch.float)
        hidden = torch.zeros((1, self.numL6,), device=device, dtype=torch.float, requires_grad=True)
        
        
        torch.nn.init.uniform_(L4)
        torch.nn.init.uniform_(L4History)
        torch.nn.init.uniform_(BCML4History)
        torch.nn.init.uniform_(L6History)
        torch.nn.init.uniform_(L4DendriteHistory)
        torch.nn.init.uniform_(L6DendriteHistory)
        torch.nn.init.uniform_(hidden)
        
        trajectory, turns = buildTrajectory(runningTime,
                                            speed,
                                            width=self.envSize,
                                            wrap=False,
                                            directionStability=stability,
                                            circular=self.circular)
        
        velocities = torch.tensor(np.diff(trajectory, axis=0), device=device, dtype=torch.float,
                                 requires_grad=False)
        trajectory = torch.tensor(trajectory, device=device, dtype=torch.float,
                                 requires_grad=False)
        
        cost = torch.zeros((1), device=device, dtype=torch.float, requires_grad=False)
        for run in np.arange((runningTime - 1)/seqLen):
            feedforwards = []
            vels = []
            
            for t in range(seqLen):
                i = int(t + run*seqLen)
                pos = trajectory[i]
                vel = velocities[i]

                distances = torch.zeros((self.minicols,
                                         self.numGaussians, 2),
                                         device=device,
                                         dtype=torch.float,
                                         requires_grad=False)

                distances[:, :, 0] = torch.abs(self.places[:, :, 0] - pos[0])
                distances[:, :, 1] = torch.abs(self.places[:, :, 1] - pos[1])


                activity = torch.exp(-1.*torch.norm(distances, 2, dim=-1)/(2*(self.placeSigma)))
                activity = torch.sum(activity, dim=-1)
                
                feedforwards.append(activity)
                vels.append(vel)
                
            vels = torch.stack((vels), dim=0)
            feedforwards = torch.stack((feedforwards), dim=0)
            
            (cost,
            hidden,
            L4,
            L4DendriteHistory,
            L6DendriteHistory,
            L4History,
            BCML4History) = self.forward(vels,
                            feedforwards,
                            hidden,
                            L4,
                            L4DendriteHistory,
                            L6DendriteHistory,
                            L4History,
                            BCML4History)  

            print(run, cost.detach().cpu().numpy())
            cost.backward()
            with torch.no_grad():
                for param in self.L6.parameters():
                    if torch.isnan(self.SGDLearningRate*param.grad).any():
                        print("NANs in gradient at {}!".format(run))
                        import ipdb; ipdb.set_trace()
                    else:
                        param -= self.SGDLearningRate*param.grad
                    param.grad.zero_()
            velocities=velocities.detach()
            trajectory=trajectory.detach()
            self.places=self.places.detach()
            hidden = hidden.detach()
            L4 = L4.detach()
            L4DendriteHistory = L4DendriteHistory.detach()
            L6DendriteHistory = L6DendriteHistory.detach()
            BCML4History = BCML4History.detach()
            cost = cost.detach()
            self.L4DendriteWeights = self.L4DendriteWeights.detach()
            self.L4Dendrites = self.L4Dendrites.detach()
            self.L6Dendrites = self.L6Dendrites.detach() 
        return(L4)

In [9]:
net = L6L4Network(
                 numL6=100,
                 minicols=100,
                 cellsPerMinicolumn=10,
                 dendrites=100,
                 numGaussians=10,
                 placeSigma=.05,
                 envSize=1.,
                 boostingAlpha=.05,
                 circular=False,
                 BCMLearningRate=.001,
                 BCMAlpha=.1,
                 SGDLearningRate=.0001,
                 L6Sparsity=.5,
                 dendriteWeightSparsity=.5,
                 )

places = net.places.view(net.minicols*net.numGaussians, 2)
plt.figure()
plt.scatter(places[:, 0], places[:, 1])
plt.show()


trajectory, turns = buildTrajectory(50,
                                    .01,
                                    width=net.envSize,
                                    wrap=False,
                                    directionStability=.95,
                                    circular=net.circular)

plt.figure()
plt.plot(trajectory[:, 0], trajectory[:, 1])
plt.show()


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-ec142001db30> in <module>()
     13                  SGDLearningRate=.0001,
     14                  L6Sparsity=.5,
---> 15                  dendriteWeightSparsity=.5,
     16                  )
     17 

<ipython-input-8-5a4315e5cdb1> in __init__(self, numL6, minicols, cellsPerMinicolumn, dendrites, numGaussians, placeSigma, envSize, boostingAlpha, circular, BCMLearningRate, BCMAlpha, SGDLearningRate, L6Sparsity, dendriteWeightSparsity)
     35         self.L6 = torch.nn.RNNCell(dendrites + 2, numL6)
     36         if device == torch.device('cuda'):
---> 37             self.L6 = self.L6.cuda()
     38 
     39         self.L4DendriteWeights = torch.zeros((minicols*cellsPerMinicolumn, dendrites), device=device, dtype=torch.float, 

~\AppData\Local\conda\conda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py in cuda(self, device)
    247             Module: self
    248         """
--> 249         return self._apply(lambda t: t.cuda(device))
    250 
    251     def cpu(self):

~\AppData\Local\conda\conda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py in _apply(self, fn)
    180                 # Tensors stored in modules are graph leaves, and we don't
    181                 # want to create copy nodes, so we have to unpack the data.
--> 182                 param.data = fn(param.data)
    183                 if param._grad is not None:
    184                     param._grad.data = fn(param._grad.data)

~\AppData\Local\conda\conda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py in <lambda>(t)
    247             Module: self
    248         """
--> 249         return self._apply(lambda t: t.cuda(device))
    250 
    251     def cpu(self):

RuntimeError: cuda runtime error (4) : unspecified launch failure at c:\programdata\miniconda3\conda-bld\pytorch_1524543037166\work\aten\src\thc\generic/THCTensorCopy.c:20

In [ ]:
for i in range(1):
#    print(i*5)
    result = net.learn(100, 5, .01, .95)
#    print(torch.mean(net.L4Dendrites).cpu().numpy())
#     print(torch.mean(net.L4DendriteWeights).cpu().numpy())
#     print(torch.mean(net.L6Dendrites).cpu().numpy())

In [ ]:
plt.matshow(result)

In [ ]:
plt.matshow(net.L4Dendrites)

In [ ]: