In [1]:
# Imports, notebook magics, and compute-device selection
%load_ext autoreload
%autoreload 2
%matplotlib notebook
%load_ext line_profiler
import torch
from torch.nn.functional import conv2d, relu_
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import numpy as np
from scipy import stats
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
from sklearn.linear_model import LinearRegression
from utils import *
import warnings
warnings.filterwarnings('ignore')
In [26]:
# STDP kernel time constant in seconds. Used for the default kernel.
STDP_TIME_CONSTANT = 0.012
def w_0(x):
    """
    Difference of two decaying exponentials of ``x``.

    Computes ``a*exp(-gamma*x) - exp(-beta*x)`` with ``a = 1``,
    ``beta = 3 / 13**2`` and ``gamma = 1.05*beta``. Because gamma > beta the
    result is 0 at ``x = 0`` and negative for ``x > 0``.

    @param x (numpy array)
           A distance
    """
    lambda_net = 13.0
    beta = 3.0 / lambda_net ** 2
    gamma = beta * 1.05
    a = 1.00
    slowTerm, fastTerm = np.exp(-beta * x), np.exp(-gamma * x)
    return a * fastTerm - slowTerm
def w_1(x):
    """
    Purely inhibitory kernel ``-exp(-beta*x)`` with ``beta = 3.15 / 13**2``.

    Equals -1 at ``x = 0`` and rises toward 0 as ``x`` grows.

    @param x (numpy array)
           A distance
    """
    beta = 3.15 / 13.0 ** 2
    decay = np.exp(-beta * x)
    return -decay
# Random walk builder
def buildTrajectory(length, stepSize, width=1., directionStability=0.95, wrap=False):
    """
    Build a random-walk trajectory with directional momentum in a square arena.

    :param length: number of positions to generate (cast with ``int``).
    :param stepSize: distance travelled on every step.
    :param width: side length of the square arena. The start point is drawn
        uniformly in [0, width)^2 and, unless ``wrap`` is set, every accepted
        step must land strictly inside (0, width)^2. The default of 1 gives
        the unit square.
    :param directionStability: each attempted step turns by a uniform random
        angle in (-pi, pi) scaled by ``1 - directionStability``.
    :param wrap: if True every step is accepted and positions are left
        unbounded (callers wrap them themselves, e.g. with ``np.mod``).
    :return: numpy array of shape (int(length), 2) holding the (x, y) position
        after each step.
    :raises ValueError: if ``wrap`` is False and ``stepSize`` is at least the
        arena diagonal, in which case no step could ever be accepted and the
        walk would loop forever.
    """
    numSteps = int(length)
    if not wrap and stepSize >= width * np.sqrt(2):
        raise ValueError("stepSize must be smaller than the arena diagonal when wrap=False")
    trajectory = np.zeros((numSteps, 2))
    x = np.random.rand() * width
    y = np.random.rand() * width
    direction = np.random.rand() * 2 * np.pi
    twopi = 2 * np.pi
    for i in range(numSteps):
        while True:
            # Random turn in (-pi, pi) scaled by (1 - directionStability). The
            # turn is kept even when the step is rejected, so retries keep
            # rotating until a direction that stays inside the arena is found.
            dirChange = ((np.random.rand() * twopi) - np.pi) * (1.0 - directionStability)
            direction = (direction + dirChange) % twopi
            movement = stepSize * np.asarray([np.cos(direction), np.sin(direction)])
            if wrap or (0 < movement[0] + x < width and 0 < movement[1] + y < width):
                x += movement[0]
                y += movement[1]
                trajectory[i] = (x, y)
                break
    return trajectory
In [212]:
class GCN2D(object):
    """
    Rate-based cell sheet of size numX x numY driven by recurrent inhibition.

    The recurrent input is a 2-D convolution of the activity with a fixed
    filter built from ``w_1``. A slowly varying "boost" term
    (activationHistory * boostEffect) and a multiplicative border ``envelope``
    shape the input. An optional feed-forward drive comes from ``numPlaces``
    place cells through ``placeWeights``, which ``learn`` trains with an STDP
    rule. All tensors are allocated on the module-level ``device``.
    """

    def __init__(self,
                 numX,
                 numY,
                 inhibitionWindow,
                 inhibitionRadius,
                 inhibitionStrength,
                 boostEffect=10,
                 boostDecay=3.,
                 dt=0.001,
                 numPlaces=200,
                 globalTonic=20,
                 decayConstant=0.03,
                 envelopeWidth=0.25,
                 envelopeFactor=10,
                 stdpWindow=10,
                 sigmaLoc=0.05,
                 learningRate=0.015,
                 negativeLearnFactor=.9,
                 initialWeightFactor=.2,
                 weightDecay=60,
                 boostGradientX=1,
                 wideningFactor=0,
                 ):
        """
        :param numX, numY: sheet dimensions in cells.
        :param inhibitionWindow: half-width in cells of the recurrent filter,
            which is (1 + 2*inhibitionWindow) cells on a side.
        :param inhibitionRadius: divisor applied to the (widened) squared
            distance before it is passed to ``w_1``.
        :param inhibitionStrength: gain applied to ``w_1`` in the filter.
        :param boostEffect: gain on activationHistory in the input.
        :param boostDecay: time constant of activationHistory's decay.
        :param dt: integration step.
        :param numPlaces: number of place cells; their centres are uniform in [0, 1)^2.
        :param globalTonic: level compared against activity; activationHistory
            integrates (globalTonic - activity).
        :param decayConstant: leak time constant of the activity.
        :param envelopeWidth, envelopeFactor: width (in cells) and steepness of
            the border envelope (see ``computeEnvelope``).
        :param stdpWindow: number of past steps used by ``stdpUpdate``.
        :param sigmaLoc: place-field width scale used in ``learn``.
        :param learningRate: STDP amplitude.
        :param negativeLearnFactor: relative strength of the depression term.
        :param initialWeightFactor: place weights start uniform in [0, initialWeightFactor).
        :param weightDecay: time constant used by ``decayWeights``.
        :param boostGradientX: boost gain of the last x row relative to the
            first (linear ramp along x).
        :param wideningFactor: subtracted from the squared distance before the
            filter is evaluated; offsets with a non-positive result get weight 0.
        """
        # Tensor sizes must be ints: torch.zeros rejects float dimensions.
        self.activity = torch.zeros([1, 1, numX, numY], device=device, dtype=torch.float)
        self.filter = self._buildFilter(inhibitionWindow, inhibitionRadius,
                                        inhibitionStrength, wideningFactor)
        self.numX = numX
        self.numY = numY
        self.numPlaces = numPlaces
        self.activationHistory = torch.zeros([1, 1, numX, numY], device=device, dtype=torch.float)
        self.instantaneous = torch.zeros([1, 1, numX, numY], device=device, dtype=torch.float)
        # Boost gain ramps linearly along x from boostEffect to boostGradientX*boostEffect.
        boostProfile = np.repeat(np.linspace(1, boostGradientX, self.numX)[:, np.newaxis],
                                 self.numY, axis=-1) * boostEffect
        self.boostEffect = torch.tensor(boostProfile, device=device, dtype=torch.float)
        self.boostDecay = boostDecay
        self.dt = dt
        self.globalTonic = torch.tensor([globalTonic], device=device, dtype=torch.float)
        self.decay = decayConstant
        self.inhibitionWindow = inhibitionWindow
        self.envelopeWidth = envelopeWidth
        self.envelopeFactor = envelopeFactor
        self.sigmaLoc = sigmaLoc
        self.learningRate = learningRate
        self.negativeLearnFactor = negativeLearnFactor
        self.weightDecay = weightDecay
        self.zero = torch.zeros([1], device=device, dtype=torch.float)
        self.places = torch.tensor(np.random.rand(numPlaces, 2), device=device, dtype=torch.float)
        self.placeWeights = torch.tensor(np.random.rand(numX, numY, numPlaces) * initialWeightFactor,
                                         device=device, dtype=torch.float)
        self.placeActivity = torch.zeros([numPlaces, ], device=device, dtype=torch.float)
        self.envelope = torch.tensor(self.computeEnvelope(), device=device, dtype=torch.float)
        self.stdpWindow = stdpWindow

    @staticmethod
    def _buildFilter(inhibitionWindow, inhibitionRadius, inhibitionStrength, wideningFactor):
        """Return the (1, 1, K, K) recurrent filter, K = 1 + 2*inhibitionWindow."""
        size = 1 + 2 * inhibitionWindow
        kernel = torch.zeros([1, 1, size, size], dtype=torch.float, device=device)
        for i in range(size):
            for j in range(size):
                # The filter is a function of the *squared* offset from the centre.
                dist = (i - inhibitionWindow) ** 2 + (j - inhibitionWindow) ** 2
                dist = max(dist - wideningFactor, 0)
                if dist > 0:
                    kernel[0, 0, i, j] = w_1(dist / inhibitionRadius) * inhibitionStrength
        return kernel

    def computeEnvelope(self):
        """
        Compute an envelope for use in suppressing border cells.

        The envelope is the outer product of two 1-D profiles. Each profile is
        1 in the interior. For a cell whose offset from n/2 is p, it equals
        exp(-envelopeFactor * ((|p| - (n/2 - envelopeWidth)) / envelopeWidth)**2)
        once n/2 - |p| <= envelopeWidth.

        :return: A (numX, numY) numpy array that can be elementwise-multiplied
                 with activations to apply the envelope.
        """
        def profile(n):
            offsets = np.arange(n) - n / 2.
            return [1 if n / 2. - np.abs(p) > self.envelopeWidth else
                    np.exp(-1. * self.envelopeFactor *
                           ((-n / 2. + np.abs(p) + self.envelopeWidth) / self.envelopeWidth) ** 2)
                    for p in offsets]
        return np.outer(profile(self.numX), profile(self.numY))

    def randomLesions(self, numLesions, lesionRadius, lesionInnerCutoff):
        """Add ``numLesions`` lesions of identical shape at uniformly random cell positions."""
        lesions = []
        for _ in range(numLesions):
            x = int(np.random.rand() * self.numX)
            y = int(np.random.rand() * self.numY)
            lesions.append((x, y))
        radii = [lesionRadius] * numLesions
        cutoffs = [lesionInnerCutoff] * numLesions
        self.addLesions(lesions, radii, cutoffs)

    def addLesions(self, lesionCenters, lesionRadii, lesionInnerCutoffs):
        """
        Carve lesions into the envelope.

        For each (center, radius, cutoff), the envelope is modified as follows:
        - Cells closer than ``cutoff`` to the centre are set to 0.
        - Cells between ``cutoff`` and ``radius`` are capped at the linear ramp
          (distance - cutoff) / (radius - cutoff).
        - Cells further away are left unchanged.
        """
        for center, radius, cutoff in zip(lesionCenters, lesionRadii, lesionInnerCutoffs):
            for x in range(self.numX):
                for y in range(self.numY):
                    distance = np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2)
                    if distance < cutoff:
                        self.envelope[x, y] = 0.
                    elif distance < radius:
                        # Only reachable when cutoff < radius, so no division by zero.
                        value = (distance - cutoff) / (radius - cutoff)
                        self.envelope[x, y] = min(value, self.envelope[x, y])

    def step(self, speed=1, place=True):
        """
        Advance the sheet by one Euler step of length ``dt``.

        The input is built as follows:
        1. Start from the place drive (``placeWeights @ placeActivity``) if
           ``place`` is True, otherwise from 0.
        2. Add the recurrent convolution and multiply by the envelope.
        3. Add the boost term (activationHistory * boostEffect).
        4. Scale by min(speed, 1) and rectify.

        Activity then integrates this input with a leak of activity/decay.
        activationHistory integrates (globalTonic - activity) masked by the
        envelope, and decays with time constant boostDecay.

        :param speed: gain on the total input, capped at 1.
        :param place: whether to include the feed-forward place-cell input.
        """
        if place:
            self.instantaneous = torch.matmul(self.placeWeights, self.placeActivity).view(
                1, 1, self.numX, self.numY)
        else:
            self.instantaneous.fill_(0.)
        self.instantaneous += conv2d(self.activity, self.filter, padding=self.inhibitionWindow)
        self.instantaneous *= self.envelope
        self.instantaneous += self.activationHistory * self.boostEffect
        self.instantaneous *= min(speed, 1)
        relu_(self.instantaneous)
        self.activity += (self.instantaneous - self.activity / self.decay) * self.dt
        self.activationHistory += (self.globalTonic - self.activity) * self.dt * self.envelope
        self.activationHistory -= self.dt * self.activationHistory / self.boostDecay

    def simulate(self, time, logFreq=10, startFrom=0):
        """
        Run the sheet without place input from a small random initial activity.

        :param time: duration; int(time / dt) steps are taken.
        :param logFreq: record a snapshot every ``logFreq`` steps...
        :param startFrom: ...once the simulated time t*dt reaches this value.
        :return: numpy array (numSnapshots, numX, numY) of recorded activities.
        """
        self.activity = torch.tensor(np.random.rand(1, 1, self.numX, self.numY) * 0.1, device=device,
                                     dtype=torch.float)
        self.activationHistory.fill_(self.globalTonic[0])
        numSteps = int(time / self.dt)

        def isLogged(t):
            return t % logFreq == 0 and t * self.dt >= startFrom

        # Count the logged steps exactly instead of using a float closed form,
        # which could be one short (e.g. startFrom=0.1) and overflow `output`.
        numLogs = sum(1 for t in range(0, numSteps, logFreq) if isLogged(t))
        output = torch.zeros([numLogs, self.numX, self.numY], device=device, dtype=torch.float)
        s = 0
        for t in range(numSteps):
            self.step(place=False)
            if isLogged(t):
                print("At {}".format(t * self.dt))
                output[s].copy_(self.activity.view(self.numX, self.numY))
                s += 1
        return output.cpu().numpy()

    def decayWeights(self):
        """
        Only decay place weights
        """
        self.placeWeights -= self.dt * self.placeWeights / self.weightDecay

    def _plotLearningState(self, ax1, ax2, ax3):
        """Plot the current state and return the ax3 scatter.

        - ax1: place activity at the place centres.
        - ax2: sheet activity.
        - ax3: the centre cell's place weights.
        """
        placeX = self.places[:, 0].cpu().numpy()
        placeY = self.places[:, 1].cpu().numpy()
        ax1.scatter(placeX, placeY,
                    c=self.placeActivity.cpu().numpy(),
                    cmap=plt.get_cmap("coolwarm"))
        ax2.matshow(self.activity.view((self.numX, self.numY)).cpu().numpy())
        return ax3.scatter(placeX, placeY,
                           c=self.placeWeights[self.numX // 2, self.numY // 2, :].cpu().numpy(),
                           cmap=plt.get_cmap("coolwarm"))

    def learn(self, time, plotting=True, plotInterval=100, runLength=10, oneD=False):
        """
        Train the place->sheet weights with STDP while following a random walk.

        A wrapped random walk with step ``dt`` per time step is generated with
        ``buildTrajectory`` and folded into the unit torus. Place cells respond
        with exp(-|d| / (2*sigmaLoc)) to their toroidal distance d from the
        current position, scaled by the speed normalised to the run's mean
        speed (capped at 1). Every step runs ``step``, then ``stdpUpdate``, and
        finally clamps the weights to [0, 2].

        :param time: duration; ``np.arange(0, time, dt)`` time points are used.
        :param plotting: draw the place activity, sheet activity and the centre
            cell's weights.
        :param plotInterval: redraw every this many steps when plotting.
        :param runLength: unused; kept for backward compatibility.
        :param oneD: restrict the trajectory and place distances to the x axis.
        """
        if plotting:
            fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
            im = self._plotLearningState(ax1, ax2, ax3)
            fig.colorbar(im, ax=ax3)
            plt.show()

        # STDP ring buffers and per-lag kernel (exponential decay with lag).
        self.activityBuffer = torch.zeros([self.stdpWindow, self.numX * self.numY],
                                          device=device, dtype=torch.float)
        self.placeBuffer = torch.zeros([self.stdpWindow, self.numPlaces],
                                       device=device, dtype=torch.float)
        self.stdpValues = torch.tensor(
            np.exp(-self.dt * np.arange(0, self.stdpWindow) / STDP_TIME_CONSTANT),
            device=device, dtype=torch.float) * self.learningRate * self.dt
        self.bufferIndex = 0

        times = np.arange(0, time, self.dt)
        self.activity = torch.tensor(np.random.rand(1, 1, self.numX, self.numY) * 0.1, device=device,
                                     dtype=torch.float)
        self.activationHistory.fill_(self.globalTonic[0])
        self.activationHistory *= self.envelope

        # Velocity is taken before folding into [0, 1) so that wrap-around
        # jumps do not appear as huge speeds.
        trajectory = buildTrajectory(len(times), 1 * self.dt, wrap=True, directionStability=0.95)
        if oneD:
            trajectory[:, 1] = 0.
        velocity = np.diff(trajectory, axis=0) / self.dt
        trajectory = np.mod(trajectory, 1)
        trajectory = torch.tensor(trajectory, device=device, dtype=torch.float)
        velocity = torch.tensor(velocity, device=device, dtype=torch.float)
        speed = torch.norm(velocity, 2, dim=-1)
        meanSpeed = torch.mean(speed)  # loop-invariant, so computed once
        distances = torch.zeros((self.numPlaces, 2), device=device, dtype=torch.float)
        for i, t in enumerate(times[:-1]):
            pos = trajectory[i, :]
            s = min(speed[i] / meanSpeed, 1.)
            # Toroidal (wrap-around) distance from each place centre to the position.
            dx = torch.abs(self.places[:, 0] - pos[0])
            distances[:, 0] = torch.min(dx, 1 - dx)
            if not oneD:
                dy = torch.abs(self.places[:, 1] - pos[1])
                distances[:, 1] = torch.min(dy, 1 - dy)
            else:
                distances[:, 1] = 0.
            # NOTE(review): exp(-d / (2*sigmaLoc)) falls off exponentially, not as a
            # Gaussian, in distance -- confirm this is intended.
            torch.exp(-1. * torch.norm(distances, 2, dim=-1) / (2 * self.sigmaLoc),
                      out=self.placeActivity)
            self.placeActivity *= s
            self.step(speed=s)
            self.stdpUpdate(i)
            self.placeWeights.clamp_(0., 2.)
            if plotting and i % plotInterval == 0:
                self._plotLearningState(ax1, ax2, ax3)
                ax1.set_title(str(t))
                fig.canvas.draw()

    def stdpUpdate(self, time, clearBuffer=False):
        """
        Apply one step of the pair-based STDP rule to ``placeWeights``.

        ``self.bufferIndex`` always points at the slot holding the most recent
        stored snapshot. For the first ``stdpWindow`` calls (``time`` below the
        buffer length), snapshots are only stored. Afterwards, each lag uses
        ``stdpValues[lag]`` and the snapshot ``lag`` slots before the most
        recent one:
        - Potentiation: past place activity paired with current sheet activity.
        - Depression (scaled by ``negativeLearnFactor``): past sheet activity
          paired with current place activity.
        The current snapshot is then written over the oldest slot.

        :param time: index of the current step (0 on the first call of a run).
        :param clearBuffer: accepted for compatibility; currently has no effect.
        """
        flatActivity = self.activity.view(self.numX * self.numY)
        if time < self.activityBuffer.shape[0]:
            # Warm-up: slot `time` becomes the most recent snapshot. Keeping
            # bufferIndex on the newest slot matches the steady-state branch.
            self.bufferIndex = time
            self.activityBuffer[self.bufferIndex].copy_(flatActivity)
            self.placeBuffer[self.bufferIndex].copy_(self.placeActivity)
        else:
            for lag in range(self.stdpWindow):
                i = (self.bufferIndex - lag) % self.stdpWindow
                self.placeWeights += torch.ger(flatActivity, self.placeBuffer[i] *
                                               self.stdpValues[lag]).view(self.numX, self.numY,
                                                                          self.numPlaces)
                self.placeWeights -= (torch.ger(self.activityBuffer[i], self.placeActivity) *
                                      self.stdpValues[lag]).view(self.numX, self.numY,
                                                                 self.numPlaces) * \
                    self.negativeLearnFactor
            self.bufferIndex = (self.bufferIndex + 1) % self.stdpWindow
            self.activityBuffer[self.bufferIndex].copy_(flatActivity)
            self.placeBuffer[self.bufferIndex].copy_(self.placeActivity)
In [218]:
# Build the network and inspect its recurrent filter and border envelope.
plt.rcParams['figure.figsize'] = [5, 5]
GCN = GCN2D(32,   # numX
            32,   # numY
            9,    # inhibitionWindow
            .3,   # inhibitionRadius
            25.,  # inhibitionStrength
            globalTonic=.25,
            stdpWindow=1,
            dt=0.01,
            boostEffect=50,
            boostDecay=100.,
            numPlaces=1000,
            learningRate=1.,
            initialWeightFactor=.1,
            boostGradientX=1,
            weightDecay=500,
            wideningFactor=2,
            negativeLearnFactor=1.,
            envelopeWidth=12,
            envelopeFactor=1.2,
            # GCN2D.__init__ as originally written ignored this argument and
            # always used 0.01, so 0.01 is the place-field scale these runs
            # actually used (the .0005 previously passed here never took effect).
            sigmaLoc=.01)
# Optional: GCN.randomLesions(10, 10, 5) knocks out random patches of the sheet.
# matplotlib cannot read CUDA tensors, so move them to host memory first.
plt.matshow(GCN.filter[0, 0].cpu().numpy())
plt.show()
plt.matshow(GCN.envelope.cpu().numpy())
plt.show()
In [214]:
# Free-running dynamics (no place input) from a small random initial state;
# show the second-to-last logged snapshot.
plt.rcParams.update({'figure.figsize': [5, 5]})
results = GCN.simulate(100, startFrom=0, logFreq=10)
plt.matshow(results[len(results) - 2])
plt.show()
In [219]:
# Anneal the learning rate by 1% per learning run and checkpoint the place
# weights every 50 runs.
plt.rcParams['figure.figsize'] = [5, 15]
for i in range(1000):
    print(i, GCN.learningRate)
    GCN.learningRate /= 1.01
    GCN.learn(25, plotting=False, plotInterval=1000, oneD=True)
    if i % 50 == 0:
        # Copy the (numX, numY, numPlaces) weights off the device only when saving.
        weights = GCN.placeWeights.cpu().numpy()
        with open("PlaceWeights{}.npz".format(i), "wb") as f:
            np.savez(f, weights)
In [ ]:
# Summary of the network state after learning.
# torch.max without `dim` returns a 0-d tensor, which cannot be indexed with [0].
print(torch.max(GCN.activity).item())
# With `dim`, torch.max/min return (values, indices); [0] selects the values.
print(torch.max(GCN.placeWeights, dim=-1)[0])
print(torch.min(GCN.placeWeights, dim=-1)[0])
print(torch.mean(GCN.placeWeights, dim=-1))
In [222]:
# Longer 1-D learning run with live plots refreshed every 10 steps.
plt.rcParams.update({'figure.figsize': [5, 15]})
GCN.learn(time=5000, oneD=True, plotting=True, plotInterval=10)
In [220]:
# Place-weight maps for the patch of sheet cells [start, end) x [start, end):
# subplot (row, col) shows cell (start + row, start + col).
plt.rcParams['figure.figsize'] = [15, 15]
start = 12
end = 20
n = end - start
# Device-to-host copies hoisted out of the per-subplot loop.
placeX = GCN.places[:, 0].cpu().numpy()
placeY = GCN.places[:, 1].cpu().numpy()
patchWeights = GCN.placeWeights[start:end, start:end, :].cpu().numpy()
fig, axes = plt.subplots(n, n, squeeze=False)
for row in range(n):
    for col in range(n):
        axes[row, col].scatter(placeX, placeY,
                               c=patchWeights[row, col],
                               cmap=plt.get_cmap("coolwarm"))
plt.draw()