In [1]:
%matplotlib inline
import numpy as np
import nengo
import matplotlib.pyplot as plt

In [4]:
freq = 1.0  # oscillation frequency (cycles per unit of simulation time t)


def _phase(t):
    """Oscillator angle in radians at time ``t``."""
    # Same left-to-right product order as the original inline expression.
    return t*2*np.pi*freq


def func_x(t):
    """Input signal: the oscillator point ``(sin, cos)`` at time ``t``."""
    angle = _phase(t)
    return (np.sin(angle), np.cos(angle))


def func_y(t):
    """Target signal: ``(cos, sin, 1.0)`` of the oscillator at time ``t``.

    The trailing 1.0 is a constant third output dimension.
    """
    angle = _phase(t)
    return (np.cos(angle), np.sin(angle), 1.0)

class Memory(nengo.Node):
    """Node whose output is a sliding window of its most recent inputs.

    Parameters
    ----------
    D : int
        Input dimensionality.
    n_steps : int
        Window length (number of past inputs kept).

    The window ``self.mem`` has shape ``(n_steps, D)``: row 0 holds the latest
    input and row ``k`` the input received ``k`` updates earlier. The node
    emits the window flattened row by row, so output index ``k*D + d`` is input
    dimension ``d`` from ``k`` updates ago.

    NOTE(review): the window is a plain Python attribute rather than a nengo
    signal, so it presumably survives ``Simulator.reset()`` and is shared by
    every simulator built from this node.
    """

    def __init__(self, D, n_steps):
        # The buffer must exist before Node.__init__, which may invoke the
        # output callable while validating it.
        self.mem = np.zeros((n_steps, D))
        super().__init__(self.update, size_in=D, size_out=n_steps * D)

    def update(self, t, x):
        """Push ``x`` into the window (dropping the oldest row); return it flattened."""
        # Shift every row one slot older; the oldest row wraps round to index 0
        # and is overwritten by the newest input just below.
        window = np.roll(self.mem, 1, axis=0)
        self.mem = window
        window[0] = x
        return window.flat

M = 100  # prediction horizon: number of future sim steps predicted (also history length)

seed = 1  # fixes the ensemble's random parameters so rebuilding it reproduces the same neurons

model = nengo.Network()
with model:
    node_x = nengo.Node(func_x)
    node_y = nengo.Node(func_y)
    D = node_y.size_out

    state = nengo.Ensemble(n_neurons=500, dimensions=node_x.size_out, neuron_type=nengo.LIFRate(), seed=seed)

    predict_y = nengo.Node(None, size_in=D*M)

    def ideal_predict(x):
        """Ideal future of ``func_y`` given the oscillator state ``x = (sin, cos)``.

        Returns a flat vector of length ``M * D`` laid out step-major: element
        ``j*D + i`` is ``func_y(t0 + j*dt)[i]``, where ``t0`` is a time whose
        oscillator phase matches ``x``.
        """
        dt = 0.001  # NOTE(review): assumed to equal the simulator dt (nengo default)
        # arctan2(sin, cos) recovers the phase 2*pi*freq*t (wrapped to (-pi, pi]).
        # Divide by 2*pi*freq -- not just 2*pi -- so t0 is a time, in the same
        # units func_y expects; the two only coincide when freq == 1.
        t0 = np.arctan2(x[0], x[1]) / (2*np.pi*freq)
        t = np.arange(M)*dt + t0
        # Each func_y call yields D values; hstack concatenates them step-major.
        return np.hstack([func_y(tt) for tt in t])

    c = nengo.Connection(state, predict_y, function=ideal_predict, synapse=None)

# Build only (no run): the solved decoder for `c` is read from the build data.
sim = nengo.Simulator(model)
decoder = sim.data[c].weights


0%
 

In [131]:
class TemporalPES(nengo.Network):
    """Online-learned decoder that predicts ``n_steps`` future output vectors.

    ``decoder`` has shape ``(n_steps * dims, n_neurons)``; output row
    ``j*dims + i`` is the prediction of output dimension ``i`` for ``j`` steps
    ahead. Every timestep each row gets a delta-rule update using the error fed
    to ``self.error`` at the same index and the activity recorded ``j`` steps
    earlier, scaled down by ``j + 1``::

        w[j*dims + i] -= activity[j steps ago] * error[j*dims + i] * learning_rate / (j + 1)

    Sub-nodes: ``input`` (neural activity in), ``output`` (predictions out),
    ``error`` (per-row error in).

    Parameters
    ----------
    decoder : array_like, shape (n_steps * dims, n_neurons)
        Initial decoder. It is copied; the caller's array is not modified.
    n_steps : int
        Number of predicted steps (also the length of the activity history).
    learning_rate : float, optional
        Base learning rate.

    Raises
    ------
    ValueError
        If ``n_steps`` is not positive or does not divide ``decoder.shape[0]``.
    """

    def __init__(self, decoder, n_steps, learning_rate=1e-8):
        decoder = np.asarray(decoder)
        if n_steps <= 0 or decoder.shape[0] % n_steps != 0:
            raise ValueError(
                "n_steps must be positive and divide decoder.shape[0]; "
                "got n_steps=%r, decoder.shape[0]=%r" % (n_steps, decoder.shape[0]))
        self.w = np.array(decoder)
        self.n_steps = n_steps
        # Output dimensions per predicted step (previously read from the global D).
        self.dims = decoder.shape[0] // n_steps
        self.error_value = np.zeros(decoder.shape[0])
        self.learning_rate = learning_rate
        # The (j + 1) divisor of the learning rate for each step, hoisted out of do_decoder.
        self._step_divisor = np.arange(1, n_steps + 1)

        super().__init__()
        with self:
            self.input = nengo.Node(None, size_in=decoder.shape[1])
            self.output = nengo.Node(None, size_in=decoder.shape[0])
            self.error = nengo.Node(self.update_error, size_in=decoder.shape[0], size_out=decoder.shape[0])
            self.activity_mem = Memory(decoder.shape[1], n_steps)
            nengo.Connection(self.input, self.activity_mem, synapse=None)
            self.decoder = nengo.Node(self.do_decoder, size_in=decoder.shape[1], size_out=decoder.shape[0])
            nengo.Connection(self.input, self.decoder, synapse=None)
            nengo.Connection(self.decoder, self.output, synapse=None)

    def update_error(self, t, x):
        """Latch the incoming error vector for ``do_decoder`` and pass it through."""
        self.error_value = x
        return x

    def do_decoder(self, t, activity):
        """Apply one learning step to ``self.w`` in place, then decode ``activity``.

        Returns ``self.w @ activity`` computed with the just-updated weights.
        """
        # NOTE(review): activity_mem.mem and error_value are plain attributes
        # written by other Nodes; nothing ties this Node's execution order to
        # theirs, so they may hold this step's or the previous step's values.
        history = self.activity_mem.mem                            # row j: activity j steps ago
        err = self.error_value.reshape(self.n_steps, self.dims)   # err[j, i] == error[j*dims + i]
        # delta[j, i, :] = history[j] * err[j, i] * lr / (j + 1) -- same operation
        # order as the former per-row Python loop, just vectorized.
        delta = (history[:, None, :] * err[:, :, None]
                 * self.learning_rate / self._step_divisor[:, None, None])
        # reshape uses C index order, so row j*dims + i is delta[j, i] no matter
        # how self.w is laid out in memory.
        self.w -= delta.reshape(self.w.shape)
        return np.dot(self.w, activity)

model = nengo.Network()
with model:
    # Reuse the ensemble whose ideal decoder was solved earlier so the learned
    # decoder's columns refer to the same neurons.
    model.add(state)
    node_x = nengo.Node(func_x)
    node_y = nengo.Node(func_y)

    D = node_y.size_out

    nengo.Connection(node_x, state, synapse=None)

    predict_y = nengo.Node(None, size_in=D*M)

    # Start from an all-zero decoder and learn it online.
    temporal_pes = TemporalPES(decoder*0, M)
    nengo.Connection(state.neurons, temporal_pes.input, synapse=None)
    nengo.Connection(temporal_pes.output, predict_y, synapse=None)

    p_predict = nengo.Probe(predict_y)

    # History of prediction vectors: output index k*(M*D) + j*D + i is element
    # j*D + i of the prediction stored k updates ago.
    mem_predict = Memory(D*M, M)
    nengo.Connection(predict_y, mem_predict, synapse=None)

    p_predict_mem = nengo.Probe(mem_predict)

    # error[j*D + i] = (prediction made j steps ago, j steps ahead, dim i) - y_i(now)
    error = nengo.Node(None, size_in=M*D)
    # -y repeated once per prediction step: row j*D + i of T selects y_i.
    T = np.tile(np.eye(D), (M, 1))
    nengo.Connection(node_y, error, transform=-T, synapse=None)
    # Gather element j*D + i from memory slot j. An index view replaces the dense
    # (M*D, D*M*M) 0/1 transform, which was 72 MB at M=100, grew as M**3 and
    # cost a full matrix-vector product every timestep.
    diag_idx = [j*M*D + j*D + i for j in range(M) for i in range(D)]
    nengo.Connection(mem_predict[diag_idx], error, synapse=None)

    nengo.Connection(error, temporal_pes.error, synapse=None)

    p_error = nengo.Probe(error)

    mem_y = Memory(node_y.size_out, n_steps=M)
    nengo.Connection(node_y, mem_y, synapse=None)

    p_x = nengo.Probe(node_x)
    p_y = nengo.Probe(node_y)
    p_mem = nengo.Probe(mem_y)

sim = nengo.Simulator(model)
sim.run(10.0)


0%
 
0%
 

In [132]:
# Activity window held by TemporalPES after the run: row k = activity k steps ago.
fig, ax = plt.subplots()
im = ax.imshow(temporal_pes.activity_mem.mem, aspect='auto')
ax.set(title="Neural activity history at end of run", xlabel="neuron index", ylabel="steps ago")
fig.colorbar(im, ax=ax, label="activity")
plt.show()


Out[132]:
<matplotlib.image.AxesImage at 0x7f64f3b56908>

In [133]:
# Input (func_x) and target (func_y) signals over the run.
fig, (ax_x, ax_y) = plt.subplots(2, 1, sharex=True)
ax_x.plot(sim.trange(), sim.data[p_x])
ax_x.set(title="Input x (func_x)", ylabel="x")
ax_y.plot(sim.trange(), sim.data[p_y])
ax_y.set(title="Target y (func_y)", xlabel="time (s)", ylabel="y")
fig.tight_layout()
plt.show()



In [134]:
predictions = sim.data[p_predict]
n_dims = node_y.size_out

# One panel per output dimension; columns i, i+n_dims, ... are dimension i
# predicted 0, 1, ..., M-1 steps ahead.
fig, axes = plt.subplots(n_dims, 1, figsize=(12, 8), sharex=True, squeeze=False)
for i, ax in enumerate(axes[:, 0]):
    ax.plot(sim.trange(), predictions[:, i::n_dims])
    ax.set_ylabel("y[%d]" % i)
axes[0, 0].set_title("Learned predictions, all horizons")
axes[-1, 0].set_xlabel("time (s)")
plt.show()



In [135]:
# Prediction history with axes (time step, slot k = updates ago, horizon j, output dim i).
# Assigning data.shape in place would mutate the object sim.data returned (which
# may be cached and shared); reshape returns a separate view instead.
raw_predict_mem = sim.data[p_predict_mem]
data = raw_predict_mem.reshape(raw_predict_mem.shape[0], M, M, node_y.size_out)

In [136]:
step = 500  # simulation step at which to compare the stored predictions
lags = np.arange(M)
# data[step, k, k, 1]: prediction stored k updates before `step` for k steps
# ahead in output dim 1 -- i.e. (up to node update order) every horizon's
# estimate of the same time step.
fig, ax = plt.subplots()
ax.plot(lags, data[step, lags, lags, 1])
ax.set(title="Predictions of y[1] for step %d, by horizon" % step,
       xlabel="horizon k (steps ahead)", ylabel="predicted y[1]")
plt.show()


Out[136]:
[<matplotlib.lines.Line2D at 0x7f6537c8cfd0>]

In [139]:
errors = sim.data[p_error]
n_dims = node_y.size_out  # stride between horizons; was hard-coded as 3
fig, ax = plt.subplots()
# Columns 1, 1+n_dims, ...: error in output dim 1 for every horizon j.
ax.plot(sim.trange(), errors[:, 1::n_dims])
ax.set(title="Prediction error in y[1], all horizons", xlabel="time (s)", ylabel="error")
plt.show()



In [ ]:


In [ ]: