This example replicates the synthetic experiment from the paper:

"Capturing Shared and Individual Information in fMRI Data", J. Turek, C. Ellis, L. Skalaban, N. Turk-Browne, T. Willke

Import the libraries we will need.


In [ ]:
%matplotlib inline
import scipy.io
from scipy.stats import stats
import numpy as np
import brainiak.funcalign.srm as srm
import brainiak.funcalign.rsrm as rsrm
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

Experiment setup


In [ ]:
# Problem size: voxels per subject, time samples, number of subjects, and the
# number of shared features that the models are asked to recover.
voxels, samples, subjects, features = 100, 200, 10, 3

snr = 20  # target signal-to-noise ratio of the additive Gaussian noise, in dB
amplitude = 8.0  # non-zero entries of each S_i are drawn from [-amplitude/2, amplitude/2)
p = 0.05  # probability that any given entry of S_i is non-zero

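Now we create the synthetic data: a shared 3-D spiral response, a random orthonormal map for each subject, Gaussian noise at the target SNR, and sparse subject-specific outliers S_i.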


In [ ]:
# Create a shared response R with K = 3: a 3-D spiral whose radius grows
# quadratically with |z|.
theta = np.linspace(-4 * np.pi, 4 * np.pi, samples)
z = np.linspace(-2, 2, samples)
r = z ** 2 + 1
x = r * np.sin(theta)
y = r * np.cos(theta)
curve = np.vstack((x, y, z))
print('Curve max, min values:', np.max(curve), np.min(curve))

# Every subject shares the same response; only the maps W_i differ, so the
# assignment is hoisted out of the loop.
R = curve
# Number of rows of the shared response. Each W_i must have this many columns
# for W_i.dot(R) to be defined, instead of a separately hard-coded 3.
n_latent = R.shape[0]

# Create the subjects' data X_i = W_i R and accumulate the total signal energy.
data = [None] * subjects
W = [None] * subjects
noise_level = 0.0
for s in range(subjects):
    # The Q factor of a Gaussian (voxels x n_latent) matrix has orthonormal columns.
    W[s], _ = np.linalg.qr(np.random.randn(voxels, n_latent))
    data[s] = W[s].dot(R)
    noise_level += np.sum(np.abs(data[s]) ** 2)

# Convert the target SNR (dB) into a per-entry noise standard deviation:
# sigma = sqrt(signal_energy / 10**(snr/10) / (subjects * voxels * samples)).
noise_level = noise_level / (10 ** (snr / 10))
noise_level = np.sqrt(noise_level / subjects / voxels / samples)

for s in range(subjects):
    n = noise_level * np.random.randn(voxels, samples)
    data[s] += n

# Add sparse subject-specific terms S_i. Each entry is non-zero with
# probability p and is then uniform in [-amplitude/2, amplitude/2).
# The mask is drawn before the values, matching the original draw order.
S = [None] * subjects
for s in range(subjects):
    mask = np.random.rand(voxels, samples) < p
    values = np.random.rand(voxels, samples) * amplitude - amplitude / 2
    S[s] = mask * values
    data[s] += S[s]

Now we fit both algorithms, SRM and RSRM, to the synthetic data.


In [ ]:
# Fit both models with the number of features set in the experiment setup,
# rather than repeating the literal 3, so the two cannot drift apart.
algo_srm = srm.SRM(features=features, n_iter=20)
algo_srm.fit(data)

# gamma=0.35 is RSRM's regularization weight (presumably the penalty on the
# subject-specific sparse terms; confirm against the RSRM documentation).
algo_rsrm = rsrm.RSRM(features=features, gamma=0.35, n_iter=20)
algo_rsrm.fit(data)

print('Done')

The following function finds the orthogonal transform that best aligns the estimated shared response to the original curve (an orthogonal Procrustes problem), and returns the aligned response.


In [ ]:
def find_orthogonal_transform(shared_response, curve):
    """Rotate ``shared_response`` onto ``curve`` with the best orthogonal map.

    Solves the orthogonal Procrustes problem. It finds the orthogonal matrix
    ``Q`` that minimizes ``||Q.T @ shared_response - curve||_F``, using the
    SVD of the cross-product ``shared_response @ curve.T``.

    Parameters
    ----------
    shared_response : ndarray, shape (k, t)
        Estimated shared response (k features by t samples).
    curve : ndarray, shape (k, t)
        Reference response to align to. It must have the same shape.

    Returns
    -------
    ndarray, shape (k, t)
        ``Q.T @ shared_response``, the shared response expressed in the
        coordinates of ``curve``.
    """
    cross = shared_response @ curve.T
    left, _, right_t = np.linalg.svd(cross)
    # Q = U V^T maximizes trace(Q^T @ cross) over orthogonal Q.
    rotation = left @ right_t
    return rotation.T @ shared_response

Plot the original curve together with the aligned SRM and RSRM shared responses.


In [ ]:
fig = plt.figure()
# Figure.gca() no longer accepts keyword arguments (removed in Matplotlib 3.6).
# add_subplot(..., projection='3d') is the supported way to create 3-D axes.
ax = fig.add_subplot(111, projection='3d')
# Use the lowercase property name: camel-case 'lineWidth' relied on
# case-insensitive properties, which Matplotlib no longer accepts.
ax.plot(curve[0, :], curve[1, :], curve[2, :], '-g', label='original', linewidth=5)

# Align each estimated shared response to the ground-truth curve with an
# orthogonal transform before plotting, so the three can be compared directly.
proj = find_orthogonal_transform(algo_srm.s_, curve)
ax.plot(proj[0, :], proj[1, :], proj[2, :], '-b', label='SRM', linewidth=3)
proj = find_orthogonal_transform(algo_rsrm.r_, curve)
ax.plot(proj[0, :], proj[1, :], proj[2, :], '-r', label='RSRM', linewidth=3)
ax.legend()