# In [6]:
using Mocha
using PyPlot
using WaveletScattering
# Short alias for the WaveletScattering module. Declared `const` so the global
# binding has a fixed, inferable type wherever it is referenced.
const ws = WaveletScattering
# Out[6]:
# In [ ]:
# Build a single-order wavelet scattering pipeline with Mocha:
#   signal -> Fourier transform over :time -> wavelet filter bank
#          -> inverse Fourier transform over :time -> pointwise modulus
# The input is a unit impulse, so (presumably, given the layer semantics) the
# output holds the modulus of each wavelet's impulse response.
Q = 8           # max quality factor; the bank uses 2Q filters per octave
log2_size = 13  # signal length is 2^log2_size samples

# Unit impulse: one column of 2^log2_size Float32 samples.
data = zeros(Float32, 1 << log2_size, 1)
data[1] = 1.0f0

# Reuse `log2_size` so the filter-bank size and the input length above cannot
# drift apart (they were previously two separate hard-coded 13s).
spec = ws.Spec1D(log2_size=log2_size, max_qualityfactor=Q,
    n_filters_per_octave=2*Q)
# NOTE(review): max_log2_stride=0 presumably disables subsampling — confirm.
bank = ws.Bank1D(spec, max_log2_stride=0)
backend = Mocha.CPUBackend()

# `symbols` labels the axes of `data` (assumed dim 1 = :time, dim 2 = :chunk).
signal = ws.InputLayer(
    data = data,
    tops = [:signal],
    symbols = [:time, :chunk])
fourier = ws.FourierLayer(
    bottoms = [:signal],
    pathkeys = [ws.PathKey(:time)],
    tops = [:fourier])
wavelets = ws.WaveletLayer(
    bank = bank,
    bottoms = [:fourier],
    tops = [:wavelets])
invfourier = ws.InvFourierLayer(
    bottoms = [:wavelets],
    pathkeys = [ws.PathKey(:time)],
    tops = [:invfourier])
modulus = ws.PointwiseLayer(
    bottoms = [:invfourier],
    tops = [:modulus],
    ρ = ws.Modulus())

# Layers are wired by their `bottoms`/`tops` names.
layers = Mocha.Layer[signal, fourier, wavelets, invfourier, modulus]

Mocha.init(backend)
# On a fresh session this timing also includes JIT compilation.
@time net = Mocha.Net("network", backend, layers);
# In [5]:
# Collect the network's final-layer output for every octave j, stack the
# octaves, and display the result as a scalogram (filters on rows, time on
# columns).
paths = [ws.Path(ws.PathKey(:j, :time) => j) for j in bank.behavior.j_range]
octaves = [net.states[end].blobs[1].nodes[path].data for path in paths]
# Stack the per-octave arrays along a new 4th dimension.
U1 = cat(4, octaves...)
# Merge dimensions 3 and 4 into one filter axis. This drops dimension 2 and
# therefore assumes a single chunk (size(U1, 2) == 1); otherwise reshape throws.
U1 = reshape(U1, size(U1, 1), size(U1, 3) * size(U1, 4))
# fftshift along time moves index 1 (where the input impulse sits) to the
# centre. permutedims(…, (2, 1)) transposes so filters become rows; it replaces
# the `.'` operator, which is deprecated from Julia 0.7 onward.
# NOTE(review): the negation presumably just inverts the colour map.
X = -permutedims(fftshift(U1, 1), (2, 1))
# Pass X directly; `X[:, :]` only made an unnecessary copy.
imshow(X, aspect="auto", cmap=ColorMap("magma"))
# Out[5]:
# In [118]:
# Euclidean norm of the input signal. `data` is a single column, so this equals
# the matrix 2-norm of `data` in every Julia version.
norm(vec(data))
# Out[118]:
# In [119]:
# Frobenius (entrywise 2-) norm of the scalogram, comparable to the input's
# norm(data). `norm` of a matrix is the spectral (operator) norm in Julia < 0.7,
# so flatten first to measure the total energy consistently across versions.
norm(vec(X))
# Out[119]:
# In [ ]: