In [ ]:
include("../../spikenet/spikenet.jl")
In [2]:
using SpikeNet
using Parameters
using ProgressMeter
using Images
using Interact
Load the natural images dataset from Olshausen & Field:
In [3]:
include("../natimages.jl")
# Natural-image stack used for training (loaded by `nat_images` from natimages.jl).
# NOTE(review): presumably a 3-D array (row, col, image index), as implied by how the
# training code slices patches out of it — confirm against natimages.jl.
const images = nat_images()
const P = 16 # patch size: side length, in pixels, of the square P×P training patches
Out[3]:
Build SAILnet:
In [4]:
include("sailnet.jl")
Out[4]:
In [56]:
# Nx × Ny neurons; the grid shape is only used to lay out receptive fields for display.
Nx = 16
Ny = 16
N = Nx*Ny
# Input layer holding one flattened P×P image patch.
input = InputBuffer(zeros(Float32, P*P))
lifs = LIF(N; γ=1e-3)
# Recurrent (neuron → neuron) weights start at zero and learn with the Foldiak rule.
# Use Float32 so the recurrent weights share the precision of the input buffer and the
# feed-forward weights Q instead of silently mixing in Float64.
W = Synapses(zeros(Float32, N, N), Foldiak(α=1e-2))
# Feed-forward (input → neuron) weights, randomly initialised, learning with HebbOja.
Q = Synapses(randn(Float32, (N,P*P)), HebbOja(β=1e-4))
ff_path = DensePathway(input, Q, lifs)
rc_path = DensePathway(lifs, W, lifs)
net = SAILnet(input, lifs, ff_path, rc_path);
In [57]:
"""
    _load_random_patch!(X, images, P)

Copy a uniformly chosen `P`×`P` patch (from a uniformly chosen image along the third
axis of `images`) into the vector `X`, then normalise `X` to zero mean and unit
standard deviation. A perfectly flat patch (zero standard deviation) is only
mean-centred, so no NaNs are produced. Returns `X`.
"""
function _load_random_patch!(X, images, P)
    i = rand(1:size(images, 3))
    # A patch starting at x covers x:x+P-1, so the last valid origin is size-P+1.
    x = rand(1:size(images, 1) - P + 1)
    y = rand(1:size(images, 2) - P + 1)
    patch = @view images[x:x+P-1, y:y+P-1, i]
    X .= vec(patch)
    X .-= mean(X)
    σ = std(X)
    if σ > 0
        X ./= σ
    end
    return X
end

"""
    train(duration, substeps, net::SAILnet, spike_rec, state_recs...)

Train `net` on `duration` random natural-image patches.

All recorders are reset first. Then, for each stimulus index `stim` in `1:duration`:
a random normalised `P`×`P` patch from the global `images` stack is written into
`net.input.X`; the network is advanced with `train_one(t, net, substeps, spike_rec)`
(presumably simulating `substeps` time steps and logging spikes into `spike_rec` —
TODO confirm against sailnet.jl); and `record!(rec, stim)` is called on every recorder
in `state_recs`.

Reads the global constants `images` and `P` defined earlier in the notebook.
Returns `nothing`.
"""
function train(duration, substeps, net::SAILnet, spike_rec, state_recs...)
    reset!(spike_rec)
    foreach(reset!, state_recs)
    t = 0
    @showprogress 1 "Training..." for stim in 1:duration
        # Present the network's own input buffer (not a global) with a fresh patch.
        _load_random_patch!(net.input.X, images, P)
        t = train_one(t, net, substeps, spike_rec)
        # Record the current state:
        foreach(rec -> record!(rec, stim), state_recs)
    end
    return nothing
end
Out[57]:
In [58]:
# Training budget: simulation steps per patch, and total number of patches shown.
substeps = 50
n_stims = 2_000_000
# NOTE(review): the ranges below presumably select which steps each recorder keeps —
# every simulation step for spikes, every 20000th stimulus for the feed-forward
# weights (field :w). Confirm against the SpikeNet recorder definitions.
spike_rec = RecordedSpikes(net.lifs, 1:(substeps * n_stims))
ff_rec = RecordedState(net.ff_path.syns, 1:20_000:n_stims, :w)
@time train(n_stims, substeps, net, spike_rec, ff_rec)
println("Recorded $(length(spike_rec.ts)) spikes")
Let's look at how the neurons' feed-forward weights (their receptive fields) evolve over training:
In [60]:
@manipulate for step in slider(ff_rec.steps, value=ff_rec.steps[end], label="step")
    # Index of the weight snapshot taken at the selected stimulus number.
    idx = searchsortedfirst(ff_rec.steps, step)
    # Snapshot is N × (P*P): row n is neuron n's flattened receptive field, which the
    # reshape turns into an N × P × P stack of images for the grid display.
    snapshot = view(ff_rec.arrays[:w], :, :, idx)
    imgrid(normed(reshape(snapshot, (N, P, P)), true), Ny; pad=10, padval=0.5)
end
Out[60]:
Now let's look at the spiking activity:
In [71]:
img = raster_spikes(spike_rec, 100_000, 1)
# Map raster intensity to grey levels: the square root compresses the dynamic range
# and the `1 .-` inverts it, so larger values are darker.
# Guard against an all-zero raster (no spikes), which would otherwise give 0/0 = NaN.
peak = maximum(img)
scaled = img ./ (peak == 0 ? one(peak) : peak)
# `.-` (not `-`) is required for scalar-minus-array on Julia ≥ 0.7 and fuses the loop.
display(colorview(Gray, 1 .- sqrt.(scaled)))
And here is a detail view of the last few hundred spikes, covering the final few stimuli (alternate stimulus presentations are shaded):
In [62]:
using PyPlot
# Presumably switches PyPlot's inline notebook figure output to SVG (see the PyPlot.jl
# README); the trailing `;` suppresses the cell's printed result.
PyPlot.svg(true);
In [65]:
# Take (at most) the last 301 recorded spikes; clamp the start index so a short
# recording does not raise a BoundsError.
from = max(1, length(spike_rec.ts) - 300)
ts = spike_rec.ts[from:end]
ids = spike_rec.id[from:end]
fig = figure(figsize=(6,3))
# `extrema` throws on an empty collection, so only draw when there is something to show.
if !isempty(ts)
    # Convert time steps to stimulus units; shade every other stimulus interval so the
    # boundaries between consecutive stimuli are visible.
    s1, s2 = extrema(ts) .÷ substeps
    for x in s1:2:s2
        axvspan(x, x+1, color="gray", alpha=0.1)
    end
    scatter(ts ./ substeps, ids, s=1.0, c="k")
    autoscale(tight=true)
end
display(fig)
close(fig)
In [ ]: