Maxwell Pollack
An agent accelerates randomly within a square environment of side length $L$ with inelastic boundaries. With $x^{0,i} \sim \mathcal{U}(0,L)$, $v^{0,i} = 0$, and $\Delta t=1$,
$$ \begin{split} a^{t,i}_\text{self} & \sim \mathcal{N}(0,\mu^2) \\ v^{t,i} & = v^{t-1,i} + a^{t,i}_\text{self} \\ x^{t,i} & = \max\big( 0, \min(L, x^{t-1,i} + v^{t,i}) \big) \\ a^{t,i}_{\text{env}} & = -v^{t,i} \big( H(x^{t-1,i} + v^{t,i} - L) + H(-x^{t-1,i}-v^{t,i}) \big)\\ v^{t,i} & = v^{t,i} + a^{t,i}_\text{env} \end{split}$$are computed for each time step $t \in \{1,...,T-1\}$ and dimension $i \in \{0,1\}$, where $H$ is the Heaviside step function.
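For intuition, here is a minimal single-step sketch of the boundary rule in one dimension, with arbitrary illustrative values (separate from the simulation code below): crossing a wall clamps the position to the wall, and the environmental acceleration cancels the velocity.
In [ ]:
# One update step in a single dimension, with illustrative values (L = 16, x = 15.5, v = 0 before the step).
L, x_prev, v_prev = 16.0, 15.5, 0.0
a_self = 1.2                                # would be sampled from N(0, mu^2) in the simulation
v = v_prev + a_self                         # v = 1.2
x = max(0.0, min(L, x_prev + v))            # 16.7 is clamped to the wall at x = 16
hit = (x_prev + v > L) or (x_prev + v < 0)  # the Heaviside terms: True iff a wall was crossed
a_env = -v if hit else 0.0                  # inelastic collision: a_env = -1.2
v = v + a_env                               # velocity is cancelled, v = 0
print(x, v, a_env)                          # 16.0 0.0 -1.2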
In [82]:
import torch

def simulate(trials=1, T=256, L=16, mu=0.1, n_dim=2):
    input = torch.zeros(trials, T, 2*n_dim)
    x = torch.rand(trials, T, n_dim)*(L-1)   # only x[:,0] (the initial position) survives the loop below
    v = torch.zeros(trials, T, n_dim)
    for t in range(1, T):
        a_self = mu * torch.randn(trials, n_dim)                 # random self-acceleration
        v[:,t] = v[:,t-1] + a_self                               # integrate acceleration
        x[:,t] = torch.clamp(x[:,t-1] + v[:,t], 0, L-1)          # integrate velocity, clamp to the walls
        a_env = torch.where((x[:,t-1]+v[:,t] > L-1) | (x[:,t-1]+v[:,t] < 0),
                            -v[:,t], torch.zeros(1))             # inelastic collision cancels the velocity
        v[:,t] += a_env
        input[:,t] = torch.cat((a_self, a_env), 1)
    # Unit-spaced L x L grid of place-cell centers, and the triangular "bump" of target activity around x.
    places = torch.flatten(torch.stack(torch.meshgrid(n_dim*[torch.arange(float(L))]), dim=-1), end_dim=-2)
    target = torch.prod(torch.relu(1 - abs(x.unsqueeze(2) - places.expand(1,T,-1,-1))), dim=-1)
    return (input, target, x, v)
The components $a^{t,i}_\text{self}$ and $a^{t,i}_\text{env}$ serve as input to a simple RNN, which is trained with Adam to minimize the mean squared error between its output and the set of target activations
$$z^{t,k}_\text{target} = \max\big(0,\, 1-|x^{t,0}-\lfloor k/L \rfloor|\big) \cdot \max\big(0,\, 1-|x^{t,1}-(k \bmod L)|\big) \text{,}$$where $k \in \{0,...,L^2-1\}$ is the output neuron index. This mimics a unit-spaced $L \times L$ grid of place cells with a single "bump" of activity tracking the agent's location.
The RNN is thus tasked with inferring the agent's location in the environment from its self-imposed and environmentally imposed accelerations.
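As a quick illustration of this encoding (a sketch with an arbitrarily chosen position, separate from the training code below), the target bump for a single location spreads over at most four neighboring place cells:
In [ ]:
import torch
L = 16
x = torch.tensor([3.4, 7.0])                          # hypothetical position
rows = torch.arange(float(L)).repeat_interleave(L)    # k // L  (dimension-0 center of cell k)
cols = torch.arange(float(L)).repeat(L)               # k %  L  (dimension-1 center of cell k)
z = torch.relu(1 - (x[0] - rows).abs()) * torch.relu(1 - (x[1] - cols).abs())
print(z.sum())                                        # sums to 1 for positions inside the grid
print(z.view(L, L)[2:5, 6:9])                         # nonzero only around row 3.4, column 7.0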
In [ ]:
class Network(torch.nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(Network, self).__init__()
        self.rnn = torch.nn.RNN(n_input, n_hidden, batch_first=True, nonlinearity='relu')
        self.linear = torch.nn.Linear(n_hidden, n_output)

    def forward(self, input):
        hidden = self.rnn(input)[0]
        return (torch.nn.functional.leaky_relu(self.linear(hidden)), hidden)

T = 256
L = 16
network = Network(n_input=4, n_hidden=L**2, n_output=L**2)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(network.parameters(), lr=1e-4)

# Train on freshly simulated trajectories, one trial per parameter update.
for trial in range(5*10**5):
    input, target, x, v = simulate(1, T, L)
    output, hidden = network(input)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
The activity of the network is visualized and animated. Hidden neurons are shown sorted in 2D by their input weights from $\vec a_\text{self}$ to illustrate how the network activity tracks velocity by integrating acceleration across time steps.
The network learns to update its location estimate (the spatial distribution of its output activity) in a Bayesian manner. Each collision with a boundary collapses the normal component of the output activity to the boundary coordinate, which then slowly widens again as path integration errors reaccumulate. The network also narrows its output activity by inferring where it could not be during path integration (e.g. if the agent moves 5 steps to the right and doesn't encounter a boundary, it could not have begun less than 5 steps from the right boundary). Visually, this looks like the output activity (moving with the agent) being squash-deformed against the boundary.
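One rough way to check this quantitatively (a sketch that assumes `simulate`, `network`, and `L` are available as above; not part of the original analysis) is to compare the spread of the output bump along one dimension on collision versus non-collision time steps:
In [ ]:
with torch.no_grad():
    input, target, x, v = simulate(1, 2048, L)
    output, hidden = network(input)
p = torch.relu(output[0]).view(-1, L, L)                   # (time, dim 0, dim 1)
p0 = p.sum(dim=2)                                          # marginal over dimension 1
p0 = p0 / p0.sum(dim=1, keepdim=True).clamp(min=1e-8)      # normalize to a distribution over dimension 0
coords = torch.arange(float(L))
mean0 = (p0 * coords).sum(dim=1)
std0 = ((p0 * (coords - mean0.unsqueeze(1))**2).sum(dim=1)).sqrt()
collide0 = input[0, :, 2].abs() > 0                        # a_env along dimension 0 is nonzero
# If collisions collapse the bump's normal component, the spread should be smaller on collision steps.
print(std0[collide0].mean(), std0[~collide0].mean())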
In [ ]:
%matplotlib inline
from matplotlib import animation, pyplot as plt
from matplotlib.patches import ConnectionPatch, Circle

input, target, x, v = simulate(1, 10**6, L)
output, hidden = network(input)

# Iteratively sort hidden cells row-wise by a_self^0 weights and column-wise by a_self^1 weights.
Win = next(network.parameters())
h_sorted = torch.argsort(Win[:,1]).view(L,L)
for _ in range(128):
    h_sorted = h_sorted.gather(0, torch.argsort(Win[h_sorted,0], dim=0))
    h_sorted = h_sorted.gather(1, torch.argsort(Win[h_sorted,1]))
hidden = hidden[...,h_sorted.flatten()]

fig, ax = plt.subplots(1, 3, figsize=(13,7), gridspec_kw={'width_ratios':[1,L,L]})
plt.close()
plt.rcParams['font.size'] = 15

frames = [[ax[0].imshow(input[0,t].view(4,1).data.numpy(), vmin=0, vmax=0.5),
           ax[0].annotate(r'$\vec a_{self}$', (-0.25,0.75), xytext=(-3,0.75),
                          xycoords='axes fraction', va='center', annotation_clip=False,
                          arrowprops=dict(arrowstyle='-[')),
           ax[0].annotate(r'$\vec a_{env}$', (-0.25,0.25), xytext=(-3,0.25),
                          xycoords='axes fraction', va='center', annotation_clip=False,
                          arrowprops=dict(arrowstyle='-[')),
           ax[0].add_patch(ConnectionPatch((1,0.5), (0,0.5), 'axes fraction',
                                           mutation_scale=20, axesA=ax[0],
                                           axesB=ax[1], arrowstyle='->')),
           ax[0].set_title('Input'),
           ax[1].imshow(hidden[0,t].view(L,L).data.numpy().T, vmin=0, vmax=0.5),
           ax[1].add_patch(ConnectionPatch((0.55,0), (0.45,0), 'axes fraction',
                                           mutation_scale=20, arrowstyle='->',
                                           connectionstyle='arc, angleA=-90, angleB=-90,'+
                                                           'rad=15, armA=20, armB=20')),
           ax[1].add_patch(ConnectionPatch((1,0.5), (0,0.5), 'axes fraction', mutation_scale=20,
                                           axesA=ax[1], axesB=ax[2], arrowstyle='->')),
           ax[1].set_title('Hidden'),
           ax[2].imshow(output[0,t].view(L,L).data.numpy().T, vmin=0, vmax=0.5),
           ax[2].add_patch(Circle(x[0,t], radius=0.5, facecolor='r', edgecolor=None)),
           ax[2].set_title('Output')]
          for t in range(512)]

for axis in ax:
    axis.axis('off')

plt.rcParams['animation.html'] = 'html5'
animation.ArtistAnimation(fig, frames, interval=100)
In [121]:
from scipy.stats import binned_statistic_2d

# Spatial ratemaps: mean activity of each hidden cell as a function of the agent's position.
fig2, ax2 = plt.subplots(L, L, figsize=(16,16), sharex=True, sharey=True)
plt.subplots_adjust(wspace=0.02, hspace=0.02)
fig2.text(0.5, 0.1, r'$W_{in} (a_{self}^0) \longrightarrow$', ha='center')
fig2.text(0.1, 0.5, r'$W_{in} (a_{self}^1) \longrightarrow$', va='center', rotation='vertical')

for cell in range(hidden.shape[-1]):
    zAvg = binned_statistic_2d(x[0,:,0].data.numpy(),
                               x[0,:,1].data.numpy(),
                               hidden[0,:,cell].data.numpy(),
                               bins=range(L+1))[0]
    ax2[cell//L, cell%L].imshow(zAvg, vmin=0)
    ax2[cell//L, cell%L].axis('off')