Import some libraries that we will need
In [ ]:
%matplotlib inline
import scipy.io
from scipy.stats import stats
import numpy as np
import brainiak.funcalign.srm as srm
import brainiak.funcalign.rsrm as rsrm
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
Experiment setup
In [ ]:
# Sizes used throughout the experiment.
voxels, samples, subjects, features = 100, 200, 10, 3

# Signal-to-noise ratio, in dB.
snr = 20

# Approximate amplitude used for the values in Si.
amplitude = 8.0

# Probability of being non-zero in Si.
p = 0.05
Now we create some synthetic data
In [ ]:
# Build the shared response R with K = 3 rows (x, y, z): the angle sweeps
# from -4*pi to 4*pi while the radius grows as z**2 + 1.
z = np.linspace(-2, 2, samples)
theta = np.linspace(-4 * np.pi, 4 * np.pi, samples)
r = 1 + z ** 2
x, y = r * np.sin(theta), r * np.cos(theta)

# Stack the three coordinates into a (3, samples) array.
curve = np.array([x, y, z])
print('Curve max, min values:', curve.max(), curve.min())
# Create the subjects' data
data = [None] * subjects
W = [None] * subjects
noise_level = 0.0  # first accumulates the total signal energy over all subjects
for s in range(subjects):
    R = curve
    # The Q factor of a reduced QR of a random Gaussian matrix is a
    # (voxels, K) mapping with orthonormal columns. K follows the shared
    # response instead of a hard-coded 3 so the shapes always agree.
    W[s], _ = np.linalg.qr(np.random.randn(voxels, R.shape[0]))
    data[s] = W[s].dot(R)
    noise_level += np.sum(np.abs(data[s]) ** 2)

# Compute noise_sigma from desired SNR: scale the total signal energy down by
# the SNR (given in dB), then take the root of the per-element average.
noise_level = noise_level / (10 ** (snr / 10))
noise_level = np.sqrt(noise_level / subjects / voxels / samples)
for s in range(subjects):
    n = noise_level * np.random.randn(voxels, samples)
    data[s] += n

# Add a sparse term S[s] to each subject: every entry is non-zero with
# probability p, and non-zero entries are uniform in
# [-amplitude / 2, amplitude / 2).
S = [None] * subjects
for s in range(subjects):
    spike_mask = np.random.rand(data[s].shape[0], samples) < p
    spike_values = (
        np.random.rand(data[s].shape[0], samples) * amplitude - amplitude / 2
    )
    S[s] = spike_mask * spike_values
    data[s] += S[s]
Now we fit the algorithms, SRM and RSRM, to the synthetic data
In [ ]:
# Fit both models with the number of features declared in the experiment
# setup, so the two fits and the setup cannot drift apart.
algo_srm = srm.SRM(features=features, n_iter=20)
algo_srm.fit(data)

algo_rsrm = rsrm.RSRM(features=features, gamma=0.35, n_iter=20)
algo_rsrm.fit(data)

print('Done')
The following function finds the orthogonal transform to align the shared response to the original curve.
In [ ]:
def find_orthogonal_transform(shared_response, curve):
    """Rotate ``shared_response`` so that it lines up with ``curve``.

    Takes the SVD of ``shared_response @ curve.T`` and combines the two
    singular-vector factors into an orthogonal matrix ``q = u @ vt``, which
    is then applied (transposed) to ``shared_response``.

    Parameters
    ----------
    shared_response : numpy.ndarray
        2-D array to be aligned; it must have the same number of columns as
        ``curve``.
    curve : numpy.ndarray
        2-D reference array to align to.

    Returns
    -------
    numpy.ndarray
        ``q.T @ shared_response``, i.e. the shared response expressed in the
        orientation of ``curve``.
    """
    u, _, vt = np.linalg.svd(shared_response.dot(curve.T))
    # The singular values are dropped: only the rotation/reflection is kept.
    q = u.dot(vt)
    aligned_response = q.T.dot(shared_response)
    return aligned_response
Plot the results
In [ ]:
fig = plt.figure()
# Figure.gca() no longer accepts a projection keyword in current Matplotlib;
# add_subplot creates the 3-D axes on both old and new releases.
ax = fig.add_subplot(111, projection='3d')

# Ground-truth curve. The keyword is spelled `linewidth`: the camelCase
# `lineWidth` relied on case-insensitive property names that Matplotlib has
# since deprecated.
ax.plot(curve[0, :], curve[1, :], curve[2, :], '-g', label='original', linewidth=5)

# Align each estimated shared response to the original curve, then overlay it.
for shared_response, line_style, label in (
    (algo_srm.s_, '-b', 'SRM'),
    (algo_rsrm.r_, '-r', 'RSRM'),
):
    proj = find_orthogonal_transform(shared_response, curve)
    ax.plot(proj[0, :], proj[1, :], proj[2, :], line_style, label=label, linewidth=3)

plt.legend()