In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import nengo
In [16]:
# Problem size: N samples in the target signal (one residual neuron per
# sample) and J neurons in the "v1" layer.
N = 10
J = 2000

# Target "image": one period of a sine wave sampled at N evenly spaced points
# on [0, 1], mapped from [-1, 1] into [0, 1].
span = np.linspace(0, 1, N)
image = 0.5 * (1 + np.sin(2 * np.pi * span))
class AdaptiveWeights(object):
    """A weight matrix shared by a forward and a reverse nengo.Node, learned online.

    The forward node computes ``w @ x`` (pre -> post). The reverse node computes
    ``-w.T @ x`` (post -> pre). On every reverse-node call, the reverse node also
    applies the Hebbian update ``w += learning_rate * outer(post, pre)``, using
    the most recent inputs seen by both nodes.

    Parameters
    ----------
    w : ndarray, shape (n_post, n_pre)
        Initial weights. They are updated IN PLACE (``+=``), so the caller's
        array holds the learned weights after simulation. Assumes a float
        ndarray; confirm if other array types are passed.
    learning_rate : float
        Scale factor applied to the outer-product update.
    """

    def __init__(self, w, learning_rate):
        self.w = w
        self.learning_rate = learning_rate
        # Latest input seen by the forward node (pre side, length n_pre) and
        # by the reverse node (post side, length n_post).
        self.pre_value = np.zeros(self.w.shape[1])
        self.post_value = np.zeros(self.w.shape[0])

    def make_forward_node(self):
        """Return a Node mapping pre activity (n_pre,) to post input (n_post,)."""
        # Sizes come from self.w so the node does not depend on any global ``w``.
        n_post, n_pre = self.w.shape
        return nengo.Node(self.update_forward, size_in=n_pre, size_out=n_post)

    def make_reverse_node(self):
        """Return a Node mapping post activity (n_post,) to pre input (n_pre,)."""
        n_post, n_pre = self.w.shape
        return nengo.Node(self.update_reverse, size_in=n_post, size_out=n_pre)

    def update_forward(self, t, x):
        """Node callback: remember the pre activity and return ``w @ x``."""
        self.pre_value[:] = x
        return np.dot(self.w, x)

    def update_reverse(self, t, x):
        """Node callback: remember the post activity, apply the Hebbian update
        to ``w`` in place, and return ``-w.T @ x``."""
        self.post_value[:] = x
        self.w += self.learning_rate * np.outer(self.post_value, self.pre_value)
        return np.dot(-self.w.T, x)
# Build the network: N "residual" neurons driven by the image, and J "v1"
# neurons. The two layers are coupled through a learned weight matrix: residual
# -> v1 via ``w``, and v1 -> residual via ``-w.T``.
WEIGHT_SEED = 0            # fixed seed so the initial weights (and results) are reproducible
WEIGHT_INIT_SCALE = 0.0002  # initial weights drawn from U(-scale, scale)
LEARNING_RATE = 1e-4
SIM_TIME = 3.0             # seconds of simulated time

model = nengo.Network()
with model:
    stim = nengo.Node(image)

    residual = nengo.Ensemble(
        n_neurons=N, dimensions=1,
        neuron_type=nengo.RectifiedLinear(),
        gain=nengo.dists.Choice([1]),
        bias=nengo.dists.Choice([0]),
    )
    v1 = nengo.Ensemble(
        n_neurons=J, dimensions=1,
        neuron_type=nengo.RectifiedLinear(),
        gain=nengo.dists.Choice([1]),
        bias=nengo.dists.Choice([0]),
    )
    # Identity recurrent connection on the v1 neurons.
    # NOTE(review): presumably makes v1 integrate its input over time; confirm intent.
    nengo.Connection(v1.neurons, v1.neurons, synapse=0)

    weight_rng = np.random.RandomState(WEIGHT_SEED)
    w = weight_rng.uniform(-WEIGHT_INIT_SCALE, WEIGHT_INIT_SCALE, (J, N))
    adapt = AdaptiveWeights(w, learning_rate=LEARNING_RATE)
    fwd_node = adapt.make_forward_node()
    rev_node = adapt.make_reverse_node()

    tau = 0.1
    # Forward path: residual -> v1 through w. The fixed-weight (non-learning)
    # version of this path is:
    #   nengo.Connection(residual.neurons, v1.neurons, transform=w, synapse=tau)
    nengo.Connection(residual.neurons, fwd_node, synapse=tau)
    nengo.Connection(fwd_node, v1.neurons, synapse=None)

    nengo.Connection(stim, residual.neurons, synapse=0)

    # Reverse path: v1 -> residual through -w.T. The node also applies the
    # learning update. The fixed-weight version of this path is:
    #   nengo.Connection(v1.neurons, residual.neurons, transform=-w.T, synapse=tau)
    nengo.Connection(v1.neurons, rev_node, synapse=tau)
    nengo.Connection(rev_node, residual.neurons, synapse=None)

    p_v1 = nengo.Probe(v1.neurons)
    p_res = nengo.Probe(residual.neurons)

sim = nengo.Simulator(model)
sim.run(SIM_TIME)
In [17]:
# Output of every v1 neuron over the course of the simulation.
fig, ax = plt.subplots()
ax.plot(sim.trange(), sim.data[p_v1])
ax.set(title="v1 neuron activity", xlabel="time (s)", ylabel="neuron output")
plt.show()
In [18]:
# Reconstruct the input from v1 activity, i.e. w.T @ v1 at every time step.
# AdaptiveWeights updates ``w`` in place during sim.run, so this uses the
# weights as they stand at the END of the run, not the weights in effect at
# each time step.
recon = np.dot(sim.data[p_v1], w)

fig, ax = plt.subplots()
im = ax.imshow(recon, aspect='auto')
fig.colorbar(im, ax=ax)
ax.set(title="Reconstruction over time", xlabel="input index", ylabel="time step")

# Compare the target image with the reconstruction at the final time step.
fig, ax = plt.subplots()
ax.plot(image, label="target image")
ax.plot(recon[-1], label="reconstruction (final step)")
ax.set(title="Target vs. final reconstruction", xlabel="input index", ylabel="value")
ax.legend()
plt.show()
Out[18]:
In [19]:
# Residual neuron output over time (rows = time steps, columns = neurons).
fig, ax = plt.subplots()
im = ax.imshow(sim.data[p_res], aspect='auto')
fig.colorbar(im, ax=ax)
ax.set(title="Residual neuron activity", xlabel="residual neuron index", ylabel="time step")
plt.show()
Out[19]:
In [ ]: