In [ ]:
import math
import json

import numpy as np

import htmresearchviz0.IPython_support
from htmresearchviz0.IPython_support import (printSpikeRatesSnapshot,
                                             printSpikeRatesTimeline,
                                             printInputWeights,
                                             printOutputWeights)
htmresearchviz0.IPython_support.init_notebook_mode()

In [ ]:
def w_0(x, a=1.0, lambda_net=13.0):
    """
    Difference-of-Gaussians weight kernel, evaluated at a single offset.

    Computes a*exp(-gamma*|x|^2) - exp(-beta*|x|^2), where
    beta = 3 / lambda_net^2 and gamma = 1.05 * beta.

    With the default a = 1.0 the result is never positive (gamma > beta), and
    it is exactly 0.0 at the origin.

    @param x (numpy array or other 2-element sequence)
    A point. Only x[0] and x[1] are read.

    @param a (float)
    Scale of the first (positive) Gaussian term.

    @param lambda_net (float)
    Sets the width of both Gaussians via beta = 3 / lambda_net^2.

    @return (float)
    The kernel value at x.
    """
    beta = 3.0 / lambda_net**2
    gamma = 1.05 * beta

    x_length_squared = x[0]**2 + x[1]**2

    return a*math.exp(-gamma*x_length_squared) - math.exp(-beta*x_length_squared)


class ContinuousAttractorModule(object):
    """
    Implementation of the Burak/Fiete 2009 attractor model. (With wrap-around topology)

    The module is a square sheet of cell groups. Each group contains one cell
    per preferred direction ("n", "e", "s", "w"), so every per-direction array
    has one entry per cell group.

    Attributes:
    - dt: step size used by step().
    - preferredDirections: dict mapping a direction key to its 2D unit vector.
    - dimensions: numpy int array holding the sheet shape.
    - numCellGroups: dimensions[0] * dimensions[1].
    - firingRates: dict mapping a direction key to a flat array of rates.
    - recurrentWeights: dict mapping a direction key to a
      (numCellGroups, numCellGroups) matrix. Entry [i, j] is the weight from
      cell j of that direction onto the cells of group i.
    """
    def __init__(self, dimensions, dt=0.0005):
        """
        @param dimensions (sequence of two ints)
        Shape of the cell sheet. The two entries must be equal.

        @param dt (float)
        Step size applied to the rate derivative on every call to step().
        """
        self.dt = dt

        self.preferredDirections = {
            "n": np.array([-1.0, 0.0]),
            "e": np.array([0.0, 1.0]),
            "s": np.array([1.0, 0.0]),
            "w": np.array([0.0, -1.0])
        }

        self.dimensions = np.array(dimensions, dtype="int")

        # The wrap-around arithmetic below uses dimensions[0] for both axes,
        # so it is only valid for a square sheet. Check once, up front.
        assert self.dimensions[0] == self.dimensions[1], \
            "ContinuousAttractorModule requires a square sheet"

        self.numCellGroups = self.dimensions[0] * self.dimensions[1]

        self.firingRates = dict((k, np.zeros(self.numCellGroups, dtype="float"))
                                for k in self.preferredDirections)
        self.recurrentWeights = dict((k, np.zeros((self.numCellGroups, self.numCellGroups), dtype="float"))
                                     for k in self.preferredDirections)

        # Plain-Python views of the sizes, computed once for the loops below.
        shape = tuple(int(d) for d in self.dimensions)
        numCellGroups = int(self.numCellGroups)
        sideLength = float(self.dimensions[0])

        # Coordinate of cell 0. Its outgoing weights are computed once per
        # direction and then shifted to obtain every other cell's weights.
        jCoord0 = np.unravel_index(0, shape)

        for k, preferredDirection in self.preferredDirections.items():

            # The weight profile of cell 0 is centered one step away from it,
            # along the preferred direction (wrapping around the edges).
            jTargetCoord = np.mod(jCoord0 + preferredDirection, self.dimensions[0])

            weights = np.zeros(shape, dtype="float")

            for i in range(numCellGroups):
                iCoord = np.unravel_index(i, shape)

                distanceComponents1 = np.abs(iCoord - jTargetCoord)

                # The two points might actually be closer by wrapping around one/two of the edges.
                # For each dimension, consider what the alternate distance would have been,
                # and choose the lowest.
                distanceComponents2 = sideLength - distanceComponents1
                distanceComponents = np.minimum(distanceComponents1,
                                                distanceComponents2)

                weights[iCoord] = w_0(distanceComponents)

            # Column j holds the outgoing weights of cell j: the profile of
            # cell 0, shifted (with wrap-around) by cell j's coordinate.
            for j in range(numCellGroups):
                jCoord = np.unravel_index(j, shape)

                self.recurrentWeights[k][:,j] = np.roll(np.roll(weights, jCoord[0], axis=0),
                                                        jCoord[1], axis=1).flatten()


    def step(self, v):
        """
        Advance every firing rate by one Euler step of size self.dt.

        For each direction k the update is
            rates[k] += dt * (max(0, recurrent + feedforward_k) - rates[k]) / tau
        with feedforward_k = 1 + alpha * dot(preferredDirection_k, v).

        @param v (numpy array)
        Velocity input; it is projected onto each preferred direction.
        """
        # Gain applied to the velocity projection.
        alpha = 0.10315
        # Time constant of the rate dynamics (same units as self.dt).
        tau = 0.010

        # Recurrent input is the same for all preferred directions.
        # So we only need to calculate it once.
        recurrentInput = np.zeros(self.numCellGroups, dtype="float")

        for k in self.preferredDirections:
            recurrentInput += np.dot(self.recurrentWeights[k],
                                     self.firingRates[k])

        for k, preferredDirection in self.preferredDirections.items():
            feedforwardInput = 1.0 + alpha*np.dot(preferredDirection, v)

            # The sum is a fresh array, so clamping it in place does not
            # touch recurrentInput.
            totalInput = recurrentInput + feedforwardInput
            totalInput[totalInput < 0.0] = 0.0

            dsdt = (totalInput - self.firingRates[k]) / tau

            ds = dsdt * self.dt

            self.firingRates[k] += ds

Recurrent connections


In [ ]:
# Shape of the cell sheet, shared by the visualization cells below.
dimensions = (32, 32)

# Building the module computes all of its recurrent weight matrices.
can = ContinuousAttractorModule(dimensions=dimensions)

Here are the outputs of each cell. Hover over a cell to see its inhibitory output. Red means strong inhibition.

(It might take a second to load.)


In [ ]:
# Hand every direction's recurrent weight matrix to the visualization as
# nested lists. dict.items() is used because iteritems() only exists on
# Python 2.
printOutputWeights(json.dumps({
    "dimensions": dimensions,
    "inputMatrices": dict((k, weights.tolist())
                          for k, weights in can.recurrentWeights.items())
}))

Here are the inputs to each cell. Hover over a cell to see its inhibitory input.


In [ ]:
# Same payload as for the output-weight view: one weight matrix per
# direction, serialized as nested lists. dict.items() is used because
# iteritems() only exists on Python 2.
printInputWeights(json.dumps({
    "dimensions": dimensions,
    "inputMatrices": dict((k, weights.tolist())
                          for k, weights in can.recurrentWeights.items())
}))

Lattice orientation


In [ ]:
def orientationExperiment(dimensions=(32, 32), numSteps=500, recordInterval=10,
                          seed=None):
    """
    Start a module from random firing rates, run it with zero velocity input,
    and record its firing rates along the way.

    @param dimensions (tuple of two ints)
    Shape of the cell sheet. It is stored in the recording as given, so it
    must be JSON-serializable if the recording is dumped to JSON.

    @param numSteps (int)
    Number of calls to step().

    @param recordInterval (int)
    A snapshot is taken before every step whose index is a multiple of this.

    @param seed (int or None)
    If given, the initial firing rates are drawn from a private
    np.random.RandomState(seed), making the run reproducible. If None, the
    global np.random state is used.

    @return (dict)
    {"dimensions": dimensions, "timesteps": [snapshot, ...]}, where every
    snapshot maps a direction key to a list of firing rates. The last
    snapshot is the state after the final step.
    """
    def snapshot(module):
        # Copy the current rates into plain lists so they can be serialized.
        return dict((k, rates.tolist())
                    for k, rates in module.firingRates.items())

    rng = np.random if seed is None else np.random.RandomState(seed)

    recording = {
        "dimensions": dimensions,
        "timesteps": [],
    }

    can = ContinuousAttractorModule(dimensions)

    for rates in can.firingRates.values():
        rates[:] = rng.rand(dimensions[0]*dimensions[1])

    # step() only reads its argument, so one array can be reused.
    noMovement = np.array([0.0, 0.0])

    for t in range(numSteps):
        if t % recordInterval == 0:
            recording["timesteps"].append(snapshot(can))

        can.step(noMovement)

    recording["timesteps"].append(snapshot(can))

    return recording

Sometimes the lattice is aligned with the x axis, and sometimes with the y axis. (It's random: the firing rates start from random values, so the outcome differs from run to run.)


In [ ]:
# Run the zero-velocity experiment; the result is rendered in the next cell.
recording1 = orientationExperiment()

In [ ]:
# Serialize the recording to JSON and pass it to the timeline visualization.
printSpikeRatesTimeline(json.dumps(recording1))

In [ ]:
def orientationExperiment2(dimensions=(16, 16), settleSteps=500, numSteps=1500,
                           recordInterval=100, velocity=(1.5, 0.0), seed=None):
    """
    Start a module from random firing rates, let it run unrecorded with zero
    velocity input, then drive it with a constant velocity and record its
    firing rates along the way.

    @param dimensions (tuple of two ints)
    Shape of the cell sheet. It is stored in the recording as given, so it
    must be JSON-serializable if the recording is dumped to JSON.

    @param settleSteps (int)
    Number of unrecorded calls to step() with zero velocity.

    @param numSteps (int)
    Number of recorded calls to step() with the given velocity.

    @param recordInterval (int)
    During the recorded phase, a snapshot is taken before every step whose
    index is a multiple of this.

    @param velocity (sequence of two floats)
    Velocity passed to step() during the recorded phase.

    @param seed (int or None)
    If given, the initial firing rates are drawn from a private
    np.random.RandomState(seed), making the run reproducible. If None, the
    global np.random state is used.

    @return (dict)
    {"dimensions": dimensions, "timesteps": [snapshot, ...]}, where every
    snapshot maps a direction key to a list of firing rates. The last
    snapshot is the state after the final step.
    """
    def snapshot(module):
        # Copy the current rates into plain lists so they can be serialized.
        return dict((k, rates.tolist())
                    for k, rates in module.firingRates.items())

    rng = np.random if seed is None else np.random.RandomState(seed)

    recording = {
        "dimensions": dimensions,
        "timesteps": [],
    }

    can = ContinuousAttractorModule(dimensions)

    for rates in can.firingRates.values():
        rates[:] = rng.rand(dimensions[0]*dimensions[1])

    # step() only reads its argument, so each array can be reused.
    noMovement = np.array([0.0, 0.0])
    movement = np.array(velocity, dtype="float")

    for t in range(settleSteps):
        can.step(noMovement)

    for t in range(numSteps):
        if t % recordInterval == 0:
            recording["timesteps"].append(snapshot(can))

        can.step(movement)

    recording["timesteps"].append(snapshot(can))

    return recording

In [ ]:
# Run the constant-velocity experiment; the result is rendered in the next cell.
recording2 = orientationExperiment2()

In [ ]:
# Serialize the recording to JSON and pass it to the timeline visualization.
printSpikeRatesTimeline(json.dumps(recording2))