In [6]:
using Mocha
using PyPlot
using WaveletScattering
ws = WaveletScattering


Out[6]:
WaveletScattering

In [ ]:
Q = 8

data = zeros(Float32, 1<<13, 1)
data[1] = 1.0f0
spec = ws.Spec1D(log2_size=13, max_qualityfactor=Q, n_filters_per_octave=2*Q)
bank = ws.Bank1D(spec, max_log2_stride=0)

backend = Mocha.CPUBackend()
signal = ws.InputLayer(
    data = data,
    tops = [:signal],
    symbols = [:time, :chunk])

fourier = ws.FourierLayer(
    bottoms = [:signal],
    pathkeys = [ws.PathKey(:time)],
    tops = [:fourier])

wavelets = ws.WaveletLayer(
    bank = bank,
    bottoms = [:fourier],
    tops = [:wavelets])

invfourier = ws.InvFourierLayer(
    bottoms = [:wavelets],
    pathkeys = [ws.PathKey(:time)],
    tops = [:invfourier])

modulus = ws.PointwiseLayer(
    bottoms = [:invfourier],
    tops = [:modulus],
    ρ = ws.Modulus())

layers = Mocha.Layer[
    signal,
    fourier,
    wavelets,
    invfourier,
    modulus]

Mocha.init(backend)
@time net = Mocha.Net("network", backend, layers);

In [5]:
paths = [ ws.Path(ws.PathKey(:j,:time) => j) for j in bank.behavior.j_range]
octaves = [ net.states[end].blobs[1].nodes[path].data for path in paths]

U1 = cat(4, octaves...)
U1 = reshape(U1, size(U1, 1), size(U1, 3) * size(U1, 4))

X = -fftshift(U1, 1).'

imshow(X[:, :], aspect="auto", cmap=ColorMap("magma"))


Out[5]:
PyObject <matplotlib.image.AxesImage object at 0x31da59310>

In [118]:
norm(data)


Out[118]:
1.0f0

In [119]:
norm(X)


Out[119]:
0.9164143f0

In [ ]: