In [2]:
using PyPlot, Interact

addprocs(12);

push!(LOAD_PATH, "../src/")
using HDStat


Vendor:  Continuum Analytics, Inc.
Package: mkl
Message: trial mode expires in 26 days
INFO: Loading help data...

Generative model

Our model for high-dimensional static decoding of neural activities: $$ y^T = w^T \left(UX_0 + Z \right) + \epsilon^T $$

We have

  • $X_0$, the $K$-by-$P$ signal matrix that is sampled from a $K$-dimensional subspace with all non-zero singular values equal to $r\sqrt{\frac{P}{K}}$,
  • $U$, the $N$-by-$K$ embedding matrix that is drawing i.i.d. from $\mathcal{N}(0, \frac{1}{N})$,
  • $Z$, the $N$-by-$P$ activities noise with $Z_{ij} \sim \mathcal{N}(0, \frac{1}{N})$,
  • $w$, a length-$N$ decoding unit-vector in the $K$-dimensional column space of $U$,
  • $\epsilon$, $P$ scalar noises for the behavior output draw i.i.d. from $\mathcal{N}(0, s^2)$,
  • $y$, $P$ scalar behavior outputs

The generative model's parameters are $(N, K, P, r, s)$, we will use

$$X = UX_0 + Z$$

to denote neural activities.

Observation model

We model the observations of neural activities as $$ \hat{X} = S\left(UX_0 + Z \right) $$ where $S$ is a $M$-by-$N$ random sampling matrix. Additionally, we also measure the behavior output $y$.

The observation model's parameter is simply $M$.


In [3]:
@everywhere immutable GenModel
    N::Integer
    K::Integer
    P::Integer
    r::Number
    s::Number
    w::Array{Float64, 1}
    U::Array{Float64, 2}
    
    function GenModel(N::Integer, K::Integer, P::Integer, r::Number, s::Number)
        U = randn(N, K) / sqrt(N)
        Up, _ = qr(U)
        w = Up * randn(K)
        w /= norm(w)
        return new(N, K, P, r, s, w, U)
    end
end

@everywhere immutable ObsModel
    gen::GenModel
    M::Integer
    S::Array{Float64, 2}
    
    function ObsModel(gen::GenModel, M::Integer)
        S = eye(gen.N)[randperm(gen.N)[1:M], :]
        return new(gen, M, S)
    end
end

Sampling functions for the generative and observation models


In [4]:
@everywhere function rand(model::GenModel)
    let N = model.N, K = model.K, P = model.P, r = model.r, s = model.s, w = model.w, U = model.U
        X0 = qr(randn(P, K))[1]' * r * sqrt(P / K)
        Z = randn(N, P) / sqrt(N)
        ϵ = randn(P) * s
        y = vec(w' * (U * X0 + Z) + ϵ')
        
        return {:X0 => X0, :X => model.U * X0 + Z, :Z => Z, :e => ϵ, :y => y}
    end
end

@everywhere function rand(model::ObsModel)
    rst = rand(model.gen)
    rst[:Xhat] = model.S * rst[:X]
    return rst
end

Problem 1: Inferring $K$

The inferred $K$ from the data is the number of $\hat{X}$'s singular values above the output noise floor

$$ \frac{\sqrt{P} + \sqrt{M}}{\sqrt{N}} $$

The correct $K$ is inferred when the minimum singular value of sample neural activities is greater than the input noise floor, or,

$$ r\sqrt{\frac{P}{K}}\left(\sqrt{\frac{M}{N}} - \sqrt{\frac{K}{N}}\right) \geq \frac{\left(MP\right)^{1/4}}{\sqrt{N}} \Rightarrow r^2\frac{\sqrt{MP}}{K}\left(1 - \sqrt{\frac{K}{M}}\right)^2 \geq 1 $$

Function to find the number of signal modes, i.e. infer $K$


In [5]:
@everywhere function inferK(model::ObsModel; Xhat = None)
    if Xhat == None; Xhat = rand(model)[:Xhat]; end
    let gen = model.gen
        _, S, _ = svd(Xhat)
        return sum(S .> (sqrt(gen.P) + sqrt(model.M)) / sqrt(gen.N))
    end
end

Sanity Check


In [6]:
@everywhere K, N, r, s = 5, 5000, 0.25, 0.0
@everywhere Ms, Ps = 10:10:400, 10:10:400

rst = [@spawn inferK(ObsModel(GenModel(N, K, P, r, s), M)) for M in Ms, P in Ps]
rst = map(fetch, rst);

theory = sqrt(Ms * Ps') / K * r^2 .* (1 - sqrt(K ./ repmat(Ms, 1, length(Ps)))).^2 .> 1;

In [7]:
K, N, r, s = 5, 5000, 0.1, 0.0
M, P = 50, 50

o = ObsModel(GenModel(N, K, P, r, s), M)
x = rand(o)

X0 = x[:X0]
SUX0 = o.S * o.gen.U * x[:X0]
SU = o.S * o.gen.U

println(string("Signal svs:\n", svd(X0)[2]))
println(string("theory:\n", r * sqrt(P)))

tmp = sqrt(eigs(SU * SU'; nev=K)[1])
println("min/max sv of sampling/projection:")
println((minimum(tmp), maximum(tmp)))
println(((sqrt(M) - sqrt(K)) / sqrt(N), (sqrt(M) + sqrt(K)) / sqrt(N)))

tmp = sqrt(eigs(SUX0 * SUX0'; nev=K)[1])
println("min/max sv of sampled/projected signal:")
println((minimum(tmp), maximum(tmp)))
println(((sqrt(M) - sqrt(K)) / sqrt(N) * r * sqrt(P), (sqrt(M) + sqrt(K)) / sqrt(N) * r * sqrt(P)))

println("Threshold:")
println((M * P)^(1/4) / sqrt(N))


Signal svs:
[0.31622776601683805,0.316227766016838,0.316227766016838,0.316227766016838,0.31622776601683794]
theory:
0.7071067811865476
min/max sv of sampling/projection:
(0.0729941214768753,0.11997900592915732)
(0.0683772233983162,0.13162277660168378)
min/max sv of sampled/projected signal:
(0.02308276796699397,0.03794069301389838)
(0.04834999834365686,0.09307135789365265)
Threshold:
0.1

Inferring dimensionality correctly, checking against theory


In [8]:
figure(figsize=(4, 3))
imshow(rst, aspect="auto", interpolation="nearest", origin="lower", cmap="RdBu_r", extent=[minimum(Ps), maximum(Ps), minimum(Ms), maximum(Ms)]);
colorbar();
contour(repmat(Ps', length(Ms), 1), repmat(Ms, 1, length(Ps)), theory, 1, linewidths=4, colors="k")


Out[8]:
PyObject <matplotlib.contour.QuadContourSet instance at 0x7fa637f13c68>

How well can we infer $w$ and decode $y$?

Problem: given $\hat{X}$ and $y$, find $\hat{w}$ such that $|\hat{w}\hat{X} - y^T|_2$ is minimized in a validation dataset

Analysis: what is the angle between the inferred $\hat{w}$ and the sampled ground truth $Sw$?

Algorithm #0: This is a cheating algorithm with $\hat{w} = \alpha Sw$. In other words, we simply find the best scaling of the sampled ground truth decoding vector.

Algorithm #1: Simple linear regression of $y$ against $\hat{X}$.

Algorithm #2: Infer the sampled signal subspace, $\hat{U}$, from $\hat{X}$ first using low-rank perturbation theory, then regress $y$ against $\hat{U}^T\hat{X}$

Algorithm #3: Recover the best sampled signal $\tilde{X}$ from $\hat{X}$ using Gavish and Donoho and the Frobenius error metric, regress $y$ against $\tilde{X}$.


In [9]:
# algorithm #0
@everywhere function cheat_w(Xhat, y, model::ObsModel)
    Sw = model.S * model.gen.w
    alpha = sum(y .* y) / sum(y .* vec(Sw' * Xhat))
    return alpha * Sw
end

# algorithm #1
@everywhere function simple_w(Xhat, y)
    return pinv(Xhat * Xhat') * (Xhat * y)
end

# algorithm #2
@everywhere function subspace_w(Xhat, y, model::ObsModel)
    let gen = model.gen
        thresh = sqrt(gen.P / gen.N) + sqrt(model.M / gen.N)
        U, S, V = svd(Xhat)
        K = sum(S .> thresh)
        if K < 1; return zeros(size(Xhat, 1)); end;
        Xtilde = U[:, 1:K]' * Xhat;
        return U[:, 1:K] * pinv(Xtilde * Xtilde') * (Xtilde * y)
#                 return  U[:, 1:K] * ((U[:, 1:K]' * Xhat)' \ y)
#         return  U[:, 1:K] * ((U[:, 1:K]' * Xhat)' \ y)
    end
end

# algorithm #3
@everywhere function signal_w(Xhat, y, model::ObsModel)
    let gen = model.gen
        U, S, V = svd(Xhat)
        S = S * sqrt(gen.N / gen.P)
        thresh = 1 + sqrt(model.M / gen.P)
        beta = model.M / gen.P
        mask = S .> thresh
        if sum(mask) < 1; return zeros(size(Xhat, 1)); end;
        S[mask] = sqrt(S[mask].^2 - beta - 1 + sqrt((S[mask].^2 - beta - 1).^2 - 4 * beta)) / sqrt(2)
        S[~mask] = 0
        Xtilde = U * diagm(S * sqrt(gen.P / gen.N)) * V'
        return pinv(Xtilde * Xtilde') * (Xtilde * y)
    end
end

In [10]:
f = figure(figsize=(18, 6))

println("Left four panels shows fitted coefficients against Sw, in the order:")
println("#0,   #1")
println("#2,   #3")

@manipulate for N in [2000, 5000], M in [50:50:1000], K in [2:4:42], P in [50:50:1000], r in 0.0:0.1:0.5, s in 0:0.1:1
    g = GenModel(N, K, P, r, s);
    o = ObsModel(g, M);
    wS = o.S * g.w

    train = rand(o)
        
    ytest = Float64[]
    Xtest = zeros(M, 0)
    while length(ytest) < 10000
        tmp = rand(o)
        ytest = [ytest; tmp[:y]]
        Xtest = [Xtest tmp[:Xhat]]
    end
        
    what_cheat = cheat_w(train[:Xhat], train[:y], o)
    what_simple = simple_w(train[:Xhat], train[:y])
    what_subspace = subspace_w(train[:Xhat], train[:y], o)
    what_signal = signal_w(train[:Xhat], train[:y], o)

    angle_cheat = abs(sum(wS .* what_cheat)) / norm(wS) / norm(what_cheat)
    angle_simple = abs(sum(wS .* what_simple)) / norm(wS) / norm(what_simple)
    angle_subspace = abs(sum(wS .* what_subspace)) / norm(wS) / norm(what_subspace)
    angle_signal = abs(sum(wS .* what_signal)) / norm(wS) / norm(what_signal)

    err_cheat = norm(vec(what_cheat' * Xtest) - ytest)^2 / norm(ytest)^2
    err_simple = norm(vec(what_simple' * Xtest) - ytest)^2 / norm(ytest)^2
    err_subspace = norm(vec(what_subspace' * Xtest) - ytest)^2 / norm(ytest)^2
    err_signal = norm(vec(what_signal' * Xtest) - ytest)^2 / norm(ytest)^2

    withfig(f) do
        subplot(261)
        plot(wS, what_cheat, ".")
        title(string("Inferred K: ", inferK(o; Xhat=train[:Xhat])))
        subplot(262)
        plot(wS, what_simple, ".")
        subplot(2, 6, 7)
        plot(wS, what_subspace, ".")
        subplot(2, 6, 8)
        plot(wS, what_signal, ".")
        subplot(132)
        bar(1:4, [angle_cheat, angle_simple, angle_subspace, angle_signal])
        xticks(1:4, ["#0", "#1", "#2", "#3"]); title("Overlap between Sw and w_hat")
        ylim([0, 1])
        subplot(133)
        bar(1:4, [err_cheat, err_simple, err_subspace, err_signal])
        xticks(1:4, ["#0", "#1", "#2", "#3"]); title("Normalized error on held out data");
        ylim([1e-2, 1e2]); yscale("log")
    end    
end


Left four panels shows fitted coefficients against Sw, in the order:
#0,   #1
#2,   #3
Out[10]:

Now lets explore some 2-dimensional parameters space $M$ and $P$ with fixed $K$, $N$, $r$, $s$


In [11]:
@everywhere overlap(v1, v2) = abs(sum(v1 .* v2)) / (norm(v1) + eps(Float64)) / (norm(v2) + eps(Float64))

@everywhere testerror(w, X, y) = norm(vec(w' * X) - y)^2 / norm(y)^2

@everywhere function trial(K, N, M, P, r, s)
    g = GenModel(N, K, P, r, s);
    o = ObsModel(g, M);
    wS = o.S * g.w

    train = rand(o)
        
    ytest = Float64[]
    Xtest = zeros(M, 0)
    while length(ytest) < 10000
        tmp = rand(o)
        ytest = [ytest; tmp[:y]]
        Xtest = [Xtest tmp[:Xhat]]
    end
    
    what_cheat = cheat_w(train[:Xhat], train[:y], o)
    what_simple = simple_w(train[:Xhat], train[:y])
    what_subspace = subspace_w(train[:Xhat], train[:y], o)
    what_signal = signal_w(train[:Xhat], train[:y], o)
    
    return map(w -> (overlap(wS, w), testerror(w, Xtest, ytest)), {what_cheat, what_simple, what_subspace, what_signal})
end

In [12]:
@everywhere K, N, r, s = 5, 5000, 0.25, 0.0
@everywhere Ms, Ps = 10:10:400, 10:10:400

In [13]:
rst = [@spawn trial(K, N, M, P, r, s) for M = Ms, P = Ps];
rst = map(fetch, rst);

In [14]:
angle_cheat = map(x -> x[1][1], rst);
angle_simple = map(x -> x[2][1], rst);
angle_subspace = map(x -> x[3][1], rst);
angle_signal = map(x -> x[4][1], rst);

err_cheat = map(x -> x[1][2], rst);
err_simple = map(x -> x[2][2], rst);
err_subspace = map(x -> x[3][2], rst);
err_signal = map(x -> x[4][2], rst);

In [15]:
valid = r^2 * sqrt(Ms * Ps') / K .* (1 - sqrt(K ./ repmat(Ms, 1, length(Ps)))).^2 .> 1;

Numerical experiments on $\hat{w}^TSw$


In [16]:
function plot_helper(angle)
    extent = [minimum(Ps), maximum(Ps), minimum(Ms), maximum(Ms)]
    imshow(angle, aspect="auto", interpolation="nearest", origin="lower", vmin=0, vmax=1, extent=extent);
    colorbar();
    contour(repmat(Ps', length(Ms), 1), repmat(Ms, 1, length(Ps)), valid, 1, linewidths=4, colors="k")
end

figure(figsize=(8, 6))

subplot(221)
plot_helper(angle_cheat)
title("#0"); ylabel("M")
subplot(222)
plot_helper(angle_simple)
title("#1");
subplot(223)
plot_helper(angle_subspace)
title("#2"); ylabel("M"); xlabel("P")
subplot(224)
plot_helper(angle_signal)
title("#3"); xlabel("P")


Out[16]:
PyObject <matplotlib.text.Text object at 0x7f08ee25bcd0>

Numerical experiment on $R^2$


In [19]:
function plot_helper(err)
    extent = [minimum(Ps), maximum(Ps), minimum(Ms), maximum(Ms)]
    imshow(1 - err, aspect="auto", interpolation="nearest", origin="lower", vmin=0, vmax=1, extent=extent);
    colorbar();
    contour(repmat(Ps', length(Ms), 1), repmat(Ms, 1, length(Ps)), valid, 1, linewidths=4, colors="k")
end

figure(figsize=(8, 6))

subplot(221)
plot_helper(err_cheat)
title("#0"); ylabel("M")
subplot(222)
plot_helper(err_simple)
title("#1")
subplot(223)
plot_helper(err_subspace)
title("#2"); ylabel("M"); xlabel("P")
subplot(224)
plot_helper(err_signal)
title("#3"); xlabel("P")


Out[19]:
PyObject <matplotlib.text.Text object at 0x7f08ed499750>

In [19]:
K = 10
f = figure(figsize=(4, 3))

@manipulate for M = 1000:1000:5000, N = 1000:1000:5000
    Uorth = qr(randn(N, K))[1][1:M, :]
    U = randn(N, K)[1:M, :] / sqrt(N)
    
    Sorth = sort(eig(Uorth' * Uorth)[1])
    S = sort(eig(U' * U)[1])
    
    withfig(f) do
        plot(S, Sorth, ".")
        plot([minimum(S), maximum(S)], [minimum(S), maximum(S)], "k-")
    end
end


Out[19]: