Refer to this notebook for the full problem description (reproduced in part below).

Generative model

Our model for low-dimensional linear encoding of the stimulus is $$ R^F = UX^0 + Z $$

where

  • $R^F$, the $N$-by-$P$ firing-rate matrix of the entire population of neurons,
  • $X^0$, the $K$-by-$P$ signal, or stimulus, matrix that is sampled from a $K$-dimensional subspace with i.i.d. elements from $\mathcal{N}(0, \frac{N}{K}\sigma_s^2)$,
  • $U$, the $N$-by-$K$ orthogonal embedding matrix of the stimulus subspace,
  • $Z$, the $N$-by-$P$ activities noise with $Z_{ij} \sim \mathcal{N}(0, \sigma_n^2)$.

The generative model's parameters are $(N, K, P, \sigma_s, \sigma_n)$. We will use

$$X = UX^0$$

to denote the embedded signal/stimulus. Note that for individual neurons and the population as a whole, we have the signal-to-noise ratio

$$ \text{SNR} = \frac{\sigma_s^2}{\sigma_n^2} $$

and for the signal subspace spanned by columns of $U$, we have the subspace signal-to-noise ratio

$$ \text{SNR}_s = \frac{N}{K}\frac{\sigma_s^2}{\sigma_n^2} = \frac{N}{K}\,\text{SNR} $$

We will use $k$, $m$, and $p$ to denote the normalized quantities $K / N$, $M / N$, and $P / N$, where $M$ is the number of observed neurons introduced in the observation model below.
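As a concrete illustration, here is a minimal sketch of sampling from this generative model in Julia; the sizes and noise levels below are arbitrary choices for this example, not values from the analysis.

using LinearAlgebra, Random

# Illustrative parameters (arbitrary choices for this sketch)
N, K, P = 1000, 50, 2000
σs, σn = 1.0, 0.5

# Orthogonal embedding U: thin Q factor of a random Gaussian matrix (UᵀU = I)
U  = Matrix(qr(randn(N, K)).Q)
X0 = sqrt(N / K) * σs * randn(K, P)   # signal X⁰ with i.i.d. N(0, (N/K)σs²) entries
Z  = σn * randn(N, P)                 # noise with i.i.d. N(0, σn²) entries

X  = U * X0                           # embedded signal X = U X⁰
RF = X + Z                            # population firing rates R^F = U X⁰ + Z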

Observation model

We model the observations of neural activities as $$ R^S = SR^F = S(X + Z) $$ where $S$ is an $M$-by-$N$ random sampling matrix. We also measure the behavioral output $y$.

The observation model's parameter is simply $M$.
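A matching sketch of the observation step: rather than forming the $M$-by-$N$ sampling matrix $S$ explicitly, one can index a random subset of $M$ rows (again, all sizes are arbitrary illustrative choices).

using LinearAlgebra, Random

# Illustrative parameters (arbitrary choices for this sketch)
N, K, P, M = 1000, 50, 2000, 200
σs, σn = 1.0, 0.5

U  = Matrix(qr(randn(N, K)).Q)
RF = U * (sqrt(N / K) * σs * randn(K, P)) + σn * randn(N, P)   # R^F = U X⁰ + Z

# S selects M of the N neurons uniformly at random;
# indexing the rows is equivalent to multiplying by S
obs = randperm(N)[1:M]
RS  = RF[obs, :]                      # observed activities R^S = S R^F, M × P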

Problem

We would like to find $\hat{W}$ such that $\left\|\hat{W}R^S - X^0\right\|_F$ is minimized.

Theoretical Treatment at the infinite data limit

In this notebook, we theoretically compute the performance of such a linear decoder in the infinite-data limit, $P \to \infty$. Using the shorthand $A = SU$, it is straightforward to derive that the optimal decoder is

$$ \hat{W} = A^T(AA^T + \frac{1}{\text{SNR}_s}I)^{-1} $$

with performance measured by $R$-squared as,

$$ \begin{align} R^2 &= 1 - \frac{\text{Tr}\left[ (\hat{W}R^S - X^0)(\hat{W}R^S - X^0)^T \right]}{\text{Tr}\left[X^0{X^0}^T\right]} \\ &= \frac{1}{K} \text{Tr}\left[AA^T \left(AA^T + \text{SNR}_s^{-1}I\right)^{-1} \right] \\ &= \frac{1}{K}\sum_{i = 1}^\infty (-1)^{i + 1} {\text{SNR}_s}^{i} \, \text{Tr}\left[ (AA^T)^i \right] \tag{Taylor series} \end{align} $$
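To make this concrete, the sketch below builds $\hat{W}$ from a simulated $A = SU$ and compares the definition-based $R^2$ on finite data with the trace formula; at finite $N$ and $P$ the two only agree approximately. All sizes and noise levels are arbitrary illustrative choices.

using LinearAlgebra, Random

# Illustrative parameters (arbitrary choices for this sketch)
N, K, M, P = 1000, 50, 200, 5000
σs, σn = 1.0, 1.0
SNRs = (N / K) * σs^2 / σn^2                 # subspace SNR

U = Matrix(qr(randn(N, K)).Q)
A = U[randperm(N)[1:M], :]                   # A = S U

# Optimal linear decoder W = Aᵀ (AAᵀ + SNR_s⁻¹ I)⁻¹
W = A' / (A * A' + I / SNRs)

# Empirical R² from simulated observations
X0  = sqrt(N / K) * σs * randn(K, P)
RS  = A * X0 + σn * randn(M, P)              # observed activities S(X + Z)
err = W * RS - X0
R2_empirical = 1 - tr(err * err') / tr(X0 * X0')

# Trace formula: (1/K) Tr[AAᵀ (AAᵀ + SNR_s⁻¹ I)⁻¹]
R2_trace = tr((A * A') / (A * A' + I / SNRs)) / K

@show R2_empirical R2_trace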

To compute the infinite series above, we note that the eigenvalue spectrum $\mu^{AA^T}(\lambda)$ of $AA^T$ has the Stieltjes transform,

$$ \begin{align} G^{AA^T}(z) &= \int \frac{1}{z - \lambda} \, d\mu^{AA^T}(\lambda) \\ &= \lim_{N \to \infty} \sum_{i = 0}^\infty \frac{\frac{1}{M}\text{Tr}\left[ (AA^T)^i \right]}{z^{i + 1}} \end{align} $$

This allows us to re-express the $R$-squared measure in terms of the Stieltjes transform of $\mu^{AA^T}(\lambda)$ (see this notebook) and compute a closed form for the decoding performance,

$$ \begin{align} R^2 &= \frac{1}{\text{SNR}_s}\frac{M}{K} \left( G^{AA^T}\left(-\text{SNR}_s^{-1}\right) + \text{SNR}_s \right) \\ &= \frac{k + (k + m)\text{SNR} - \sqrt{\left( \lambda_-\text{SNR} + k \right) \left( \lambda_+\text{SNR} + k \right)}}{2k(k+\text{SNR})}\\ &= \frac{1 + (k + m)\text{SNR}_s - \sqrt{\left( \lambda_-\text{SNR}_s + 1 \right) \left( \lambda_+\text{SNR}_s + 1 \right)}}{2k\left(1 + \text{SNR}_s\right)} \end{align} $$

where $\lambda_\pm$ denote the upper and lower edges of the support of $\mu^{AA^T}(\lambda)$,

$$ \lambda_\pm = \left(\sqrt{k(1 - m)} \pm \sqrt{m(1 - k)} \right)^2 $$
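As a rough numerical check (sizes and SNR are arbitrary illustrative choices), one can evaluate the empirical Stieltjes transform at $z = -\text{SNR}_s^{-1}$ directly from the eigenvalues of a simulated $AA^T$ and compare the resulting $R^2$ with the closed form above.

using LinearAlgebra, Random, Statistics

# Illustrative parameters (arbitrary choices for this sketch)
N, K, M = 2000, 100, 400
k, m = K / N, M / N
SNR  = 1.0
SNRs = SNR / k                                # subspace SNR

U  = Matrix(qr(randn(N, K)).Q)
A  = U[randperm(N)[1:M], :]
es = eigvals(Symmetric(A * A'))

# Empirical Stieltjes transform G(z) = (1/M) Σᵢ 1/(z - λᵢ), evaluated at z = -1/SNR_s
G(z) = mean(1 ./ (z .- es))
R2_stieltjes = (M / K) / SNRs * (G(-1 / SNRs) + SNRs)

# Closed form in terms of k, m, SNR and the spectrum edges λ±
λp = (sqrt(k * (1 - m)) + sqrt(m * (1 - k)))^2
λm = (sqrt(k * (1 - m)) - sqrt(m * (1 - k)))^2
R2_theory = (k + (k + m) * SNR - sqrt((λm * SNR + k) * (λp * SNR + k))) / (2k * (k + SNR))

@show R2_stieltjes R2_theory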

Limits and bounds of the $R$-squared performance measure

In the large and small signal-to-noise ratio limits:

$$ \begin{align} R^2 &\approx 1 - \frac{k - k^2}{m - k}\frac{1}{\text{SNR}} & (\text{SNR} \to \infty) \\ &\approx 1 - \frac{k}{m}\frac{1}{\text{SNR}} & (\text{SNR} \to \infty,\; k \ll m) \\ R^2 &\approx \frac{m}{k}\text{SNR} & (\text{SNR} \to 0) \end{align} $$

Furthermore, we have bounds on the performance measure,

$$ k\frac{\text{SNR}}{1 + \text{SNR}} \leq R^2 \leq \frac{\text{SNR}_s}{1 + \text{SNR}_s} $$
  • The lower bound on the left (a conjecture) corresponds to decoding by averaging single-neuron decoders. (Which limit do we take to get here?)
  • The upper bound on the right corresponds to decoding within the correct signal subspace, which is unknown to the experimenters. Performance approaches this bound as $m \to 1$, i.e., as we approach full observation. (A numerical spot-check of both bounds is sketched below.)
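A quick numerical spot-check of these bounds against the closed-form $R^2$, at a few arbitrary $(k, m, \text{SNR})$ values (a sketch, not part of the original analysis):

# Closed-form R² as a function of k, m, SNR (same expression as above)
function R2_closed(k, m, snr)
    λp = (sqrt(k * (1 - m)) + sqrt(m * (1 - k)))^2
    λm = (sqrt(k * (1 - m)) - sqrt(m * (1 - k)))^2
    (k + (k + m) * snr - sqrt((λm * snr + k) * (λp * snr + k))) / (2k * (k + snr))
end

# Check k·SNR/(1 + SNR) ≤ R² ≤ SNR_s/(1 + SNR_s) at a few illustrative points
for (k, m, snr) in [(0.05, 0.2, 0.5), (0.1, 0.5, 2.0), (0.2, 0.9, 10.0)]
    snrs  = snr / k
    lower = k * snr / (1 + snr)
    upper = snrs / (1 + snrs)
    r2    = R2_closed(k, m, snr)
    println("k=$k, m=$m, SNR=$snr:  ", round(lower, digits=3), " ≤ ", round(r2, digits=3), " ≤ ", round(upper, digits=3))
end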

Open question

What is the threshold $m(\text{SNR}, k)$ beyond which we have perfect recovery?

Verification


In [1]:
using PyPlot, Interact, LinearAlgebra, Random

println("Red dots: simulation using the matrix trace")
println("Black line: exact theory")
println("Blue/red dashes: low- and high-SNR limiting behavior")
println("Black dashes: upper bound SNR_s / (1 + SNR_s)")

f = figure(figsize=(8, 3))
N = 1000
@manipulate for k = 0.01:0.01:1, m = 0.01:0.01:1
    if k <= m
        K, M = round(Int, k * N), round(Int, m * N)
        # A = SU: random orthogonal embedding U, subsampled to M random rows
        U = Matrix(qr(randn(N, K)).Q)
        A = U[randperm(N)[1:M], :]
        es = eigvals(Symmetric(A * A'))

        snr = exp10.(range(-2, 2, length=21))
        snrs = snr / k
        # Finite-N R² from the eigenvalues: (1/K) Σᵢ λᵢ / (λᵢ + k/SNR)
        R2(s) = sum(es ./ (es .+ k / s)) / K
        R2s = map(R2, snr)

        snrp = exp10.(range(-2, 2, length=100))
        snrps = snrp / k
        # Edges of the support of the spectrum of AA'
        lp = (sqrt(k * (1 - m)) + sqrt(m * (1 - k)))^2
        lm = (sqrt(k * (1 - m)) - sqrt(m * (1 - k)))^2
        # Closed-form R² in the infinite-data, large-N limit
        R2p(s) = (k + (k + m) * s - sqrt((lm * s + k) * (lp * s + k))) / 2 / k / (k + s)
        R2ps = map(R2p, snrp)

        withfig(f) do
            subplot(121)
            plot(snr, R2s, "ro")
            plot(snrp, R2ps, "k", linewidth=2)
            plot(snrp, m / k * snrp, "b--")                      # low-SNR limit
            plot(snrp, 1 .- (k - k^2) / (m - k) ./ snrp, "r--")  # high-SNR limit
            plot(snrp, snrp ./ (k .+ snrp), "k--")               # upper bound
            xlabel("SNR"); ylim([0, 1]); ylabel("Performance"); xscale("log")

            subplot(122)
            plot(snrs, R2s, "ro")
            plot(snrps, R2ps, "k", linewidth=2)
            plot(snrps, m / k * snrp, "b--")
            plot(snrps, 1 .- (k - k^2) / (m - k) ./ snrp, "r--")
            xlabel("SNR_s"); ylim([0, 1]); xscale("log")
        end
    end
end


