In [2]:
import pyrtools as pt
%matplotlib inline
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')
import sfp
import numpy as np
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from scipy import interpolate
import itertools
from scipy import stats
import time
from torch.utils import data as torchdata
# Run on the first GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# First-level summary csv for one subject/session (filename suggests V1, eccentricity 1-12).
df_path = ('/mnt/winawerlab/Projects/spatial_frequency_preferences/BIDS/derivatives/first_level_analysis/'
           'stim_class/posterior/sub-wlsubj045/ses-04/sub-wlsubj045_ses-04_task-sfprescaled_v1_e1-12_summary.csv')
In [3]:
# Load the full stimulus array; the first axis indexes stimuli (stim[j] is one image and
# stim.shape[1:] is used as the image size below).
stim = np.load('/mnt/winawerlab/Projects/spatial_frequency_preferences/BIDS/stimuli/task-sfp_stimuli.npy')
In [4]:
# Display stimuli 72-79 in a 4-column grid at quarter zoom, without titles.
fig = pt.imshow(stim[72:80], col_wrap=4, zoom=.25, title=None);
#fig.savefig('stim.svg')
We have a function that does this for all stimuli and saves the result; the cell below just shows how it's done for a single stimulus class.
In [5]:
# Demonstration, for a single stimulus class (stimuli 32-39), of what is done for every
# class: build a complex steerable pyramid for each stimulus, accumulate the energy
# (squared magnitude of the coefficients, upsampled to the image size) in each
# scale x orientation sub-band, and record each sub-band's filter (the product of its
# angle, high-pass and low-pass masks, zero-padded to the image size).
img_size = stim.shape[1:]
# this computation comes from the SteerablePyramidFreq code
max_ht = int(np.floor(np.log2(min(img_size))) - 2)
n_orientations = 6
class_stim_idx = range(32, 40)


def _pad_to_size(filt, size):
    """Zero-pad a 2D array to `size`, centering it.

    When the total padding along an axis is odd, the extra row/column goes on the
    leading (top/left) side.
    """
    pad_num = []
    for dim_len, target in zip(filt.shape, size):
        total = target - dim_len
        pad_num.append((total - total // 2, total // 2))
    return np.pad(filt, pad_num, 'constant', constant_values=0)


# Only one class is computed here, so the class axis has length 1 (energy[0] is that class).
energy = np.zeros((1, max_ht, n_orientations, *img_size), dtype=np.float32)
# Filters for this class, shape (max_ht, n_orientations, *img_size). They are stored
# directly rather than averaged over a multi-class array: with only one class filled,
# such a mean would just scale the filters down.
filters = np.zeros((max_ht, n_orientations, *img_size), dtype=np.float32)
for n, j in enumerate(class_stim_idx):
    pyr = pt.pyramids.SteerablePyramidFreq(stim[j], order=n_orientations-1, is_complex=True)
    for k, l in itertools.product(range(max_ht), range(n_orientations)):
        energy[0, k, l] += sfp.image_computable.upsample(np.abs(pyr.pyr_coeffs[(k, l)])**2, img_size)
        # we only want to build the filters once per stimulus class
        if n == 0:
            lomask = pyr._lomasks[k-1] if k > 0 else pyr._lo0mask
            filt = pyr._anglemasks[k][l] * pyr._himasks[k] * lomask
            filters[k, l] = _pad_to_size(filt, img_size)
Then we can examine some of the energy: scales 5–7 for each of the six orientations.
In [6]:
# Energy of the computed class (index 0) at scales 5-7 for all 6 orientations, one row per
# scale; vrange='indep1' scales each panel's intensity range independently.
fig = pt.imshow([energy[0,i,j,:,:] for i in range(5,8) for j in range(6)], col_wrap=6, zoom=.25, vrange='indep1');
#fig.savefig('energy.svg')
In [195]:
# Load the precomputed energy and filters for the full stimulus set ("n6" presumably means
# 6 orientations, matching n_orientations above — confirm). This overwrites the single-class
# demo arrays computed above.
energy = np.load('/mnt/winawerlab/Projects/spatial_frequency_preferences/BIDS/derivatives/stimuli_energy/task-sfprescaled_n6_energy.npy')
filters = np.load('/mnt/winawerlab/Projects/spatial_frequency_preferences/BIDS/derivatives/stimuli_energy/task-sfprescaled_n6_filters.npy')
To show that things work, we take the data for one subject and show that the model can produce predictions for a given voxel in both image-computable and tuning-curve modes.
In [70]:
# One subject's first-level data, formatted for the image-computable model mode.
ds = sfp.model.FirstLevelDataset(df_path, device, model_mode='image-computable',)
# batches of 3 voxels (second positional argument is batch_size)
dl = torchdata.DataLoader(ds, 3)
To run in image-computable mode, we need to pass the model the pre-computed energy and filters loaded above.
In [198]:
# Model with a fixed amplitude; sigma and the spatial-frequency/eccentricity parameters are
# set by hand for this demo rather than fit.
model = sfp.model.LogGaussianDonut(vary_amplitude=False, sigma=1, sf_ecc_intercept=.2, sf_ecc_slope=.3)
# hand the model the precomputed energy and filters so it can run in image-computable mode
model.prepare_image_computable(energy, filters)
We then pass the model a tensor that contains, for each of 3 different voxels, the stimulus class index, the voxel's pRF location (eccentricity and angle), and its pRF size (sigma).
In [7]:
# Grab the first batch (3 voxels) of inputs and targets.
inputs, targets = next(iter(dl))
In [11]:
# Peek at the first 5 entries along axis 1 of each voxel's inputs.
inputs[:,:5]
Out[11]:
The model produces a prediction for each of the 3 voxels for each of the 48 stimulus classes.
In [8]:
# Image-computable predictions for the batch of 3 voxels.
model.image_computable(inputs)
Out[8]:
We can also produce a prediction for a single voxel:
In [12]:
# Inputs and targets for a single voxel (dataset index 2), without going through the DataLoader.
inputs, targets = ds[2]
In [13]:
# Image-computable prediction for that single voxel.
model.image_computable(inputs)
Out[13]:
As a sanity check, the model still works in tuning-curve mode. The scale of the predictions is completely different, but since we normalize the predictions when computing the loss, this shouldn't matter.
In [14]:
# Same data, formatted for the tuning-curve model mode.
# NOTE(review): this mode string uses an underscore while the one above uses a hyphen
# ('image-computable'); confirm both spellings are what FirstLevelDataset expects.
ds = sfp.model.FirstLevelDataset(df_path, device, model_mode='tuning_curve',)
dl = torchdata.DataLoader(ds, 3)
In [15]:
# Calling the model directly (rather than model.image_computable) gives the tuning-curve prediction.
inputs, targets = next(iter(dl))
model(inputs)
Out[15]:
To step through what's happening here, let's look at an imaginary voxel with the following properties:
In [192]:
# An imaginary voxel: pRF eccentricity, polar angle, and pRF size (sigma).
vox_ecc, vox_ang, vox_sig = 4, 0, .5
First, based on its location in the visual field, we know the voxel's tuning, which we can display: the image shows the weight given to each spatial frequency.
In [196]:
# The voxel's spatial-frequency tuning, rendered as an image of weights.
vox_tuning = model.create_image(vox_ecc, vox_ang)
fig = pt.imshow(vox_tuning.detach(), zoom=.25, title=None);
#fig.savefig('vox_tuning.svg')
Based on this tuning and the filters we already passed to the model, we can get the relative weight on each scale and orientation sub-band.
In [197]:
weights = model.image_computable_weights(vox_ecc, vox_ang)
# one point per sub-band weight, plotted against its flattened position
plt.scatter(np.arange(weights.nelement()), weights.detach())
#plt.savefig('weights.svg')
With these weights, we take a weighted sum over the energy to get the feature map for this voxel.
In [138]:
# Weight the energy and sum over axes 2 and 3 (presumably scale and orientation — confirm
# against model.energy's layout) to get this voxel's feature map.
reweighted_energy = (weights * model.energy).sum((2, 3))
fig = pt.imshow(reweighted_energy[0,0].detach(), zoom=.25);
#fig.savefig('reweighted_energy.svg')
And now we just create the pRF and drop it on the feature map.
In [76]:
# Build this voxel's pRF from its eccentricity, angle and size, and display entry [0, 0].
prf = model.create_prfs(vox_ecc, vox_ang, vox_sig)
fig = pt.imshow(prf[0,0], zoom=.25);
#fig.savefig('prf.svg')
And we have a prediction!
In [26]:
# pRF multiplied by the feature map, shown for entry [0, 5] (presumably stimulus class 5 — confirm).
pt.imshow((prf * reweighted_energy).detach()[0,5], zoom=.25);
To get the actual prediction, we sum this pRF-weighted feature map over the whole image (its last two, spatial, axes), as shown below.
In [27]:
# Sum the pRF-weighted feature map over its last two (spatial) axes to get the prediction.
(prf * reweighted_energy[0,5]).sum((-1,-2))
Out[27]:
A couple of simple tests to make sure we're handling normalization correctly:
- Given a uniform spatial frequency preference, a set of uniform gratings should all produce the same prediction.
- If we're normalizing the pRFs correctly, voxels that are identical except for their pRF size should produce the same prediction.
In [13]:
# Small synthetic test set: 4 spatial frequencies x 3 orientations (one stimulus class
# each: (w_x, w_y) = (0, f), (f, 0) and a diagonal) x 8 phases, ordered frequency ->
# phase -> orientation. We keep these small (256 pixels) to make the computations faster.
# NOTE(review): size_deg is recorded as 3 while create_sin_cpd is given stim_rad_deg=3;
# confirm whether size_deg is meant to be the radius or the diameter.
freqs = [.5, 1, 2, 4]
phases = np.linspace(0, 2*np.pi, 8, endpoint=False)
stim = []
stim_records = []
for i, f in enumerate(freqs):
    # The diagonal components are f/sqrt(2) so that the grating's overall frequency,
    # sqrt(w_x**2 + w_y**2), equals f like the other two orientations.
    orientations = [(0, f), (f, 0), (f/np.sqrt(2), f/np.sqrt(2))]
    for p in phases:
        for o, (w_x, w_y) in enumerate(orientations):
            stim.append(sfp.utils.create_sin_cpd(256, w_x, w_y, p, stim_rad_deg=3))
            stim_records.append({'size_pix': 256, 'w_x': w_x, 'w_y': w_y, 'phi': p,
                                 'size_deg': 3, 'class_idx': 3*i+o})
stim = np.array(stim)
# build the frame once from records instead of concatenating one-row frames
stim_df = pd.DataFrame(stim_records)
In [35]:
# Energy and filters for the synthetic gratings, computed by the library function
# (replaces the arrays loaded earlier).
energy, filters = sfp.image_computable.calc_energy_and_filters(stim, stim_df)
In [107]:
# Model meant to have a (near-)uniform spatial-frequency preference for the normalization
# test: a presumably very broad tuning (sigma=100) and no dependence on eccentricity
# (sf_ecc_slope=0).
model = sfp.model.LogGaussianDonut(vary_amplitude=False, sigma=100, sf_ecc_intercept=1, sf_ecc_slope=0)
# the third argument (3) presumably matches stim_rad_deg=3 used to make the gratings — confirm
model.prepare_image_computable(energy, filters, 3)
In [155]:
# Two test voxels at eccentricities 2 and 1, both at polar angle 0.
vox_ecc, vox_ang = [2, 1], [0, 0]
First, based on their locations in the visual field, we know these voxels' tuning, which we can display: each image shows the weight given to each spatial frequency.
In [156]:
# Tuning image for each test voxel (first entry of each, as a numpy array), shown side by side.
vox_tuning = [model.create_image(ecc, ang).detach().numpy()[0]
              for ecc, ang in zip(vox_ecc, vox_ang)]
pt.imshow(vox_tuning, zoom=2);
In [180]:
# Build the model inputs: for each voxel and each stimulus class, a row of
# [class_idx, eccentricity, angle, pRF sigma], giving shape (n_voxels, n_classes, 4).
# Sizes are taken from the data rather than hard-coded (previously 12 classes, 2 voxels),
# so changing `freqs` or the voxel lists doesn't silently break this cell.
class_ids = stim_df.class_idx.unique()
n_classes, n_voxels = len(class_ids), len(vox_ecc)
voxels = np.array([[e, a, 1] for e, a in zip(vox_ecc, vox_ang)])  # every pRF sigma = 1
voxels = np.repeat(np.expand_dims(voxels, 1), n_classes, 1)
stim_idx = np.repeat(np.reshape(class_ids, (1, n_classes, 1)), n_voxels, 0)
inputs = np.dstack([stim_idx, voxels])
In [181]:
# Image-computable predictions as a numpy array (presumably shape (n_voxels, n_classes) — confirm).
predictions = model.image_computable(torch.tensor(inputs, dtype=torch.float32)).detach().numpy()
In [182]:
# Long-format results: one row per (voxel, stimulus class) with the prediction, the voxel's
# eccentricity and angle, and the stimulus properties of the class (taken from the first
# stimulus of each class).
predictions_long = pd.DataFrame(predictions.T,)
predictions_long['class_idx'] = stim_df.class_idx.unique()
predictions_long = pd.melt(predictions_long, ['class_idx'], var_name='voxel', value_name='prediction')
# Voxel labels are the column positions of predictions.T (0..n_voxels-1). Each voxel gets
# its own angle; previously every angle was written to voxel 0, leaving the others NaN.
predictions_long['ecc'] = predictions_long['voxel'].map(dict(enumerate(vox_ecc)))
predictions_long['angle'] = predictions_long['voxel'].map(dict(enumerate(vox_ang)))
results_df = pd.merge(predictions_long, stim_df.drop_duplicates('class_idx'))
In [183]:
# Sub-band weights for the test voxels: one scatter series per entry along the first axis.
weights = model.image_computable_weights(vox_ecc, vox_ang).squeeze()
for w in weights:
    plt.scatter(np.array(range(w.nelement())), w.detach())
In [188]:
# FIXME: I think this comes from the energy not being correctly normalized.
In [184]:
# One line per voxel across stimulus classes; with a uniform preference these should be flat.
plt.plot(predictions.T);
In [179]:
# Every 8th grating (12 images). Given the frequency -> phase -> orientation ordering, each
# row of 3 covers one frequency's three orientations (at differing phases).
pt.imshow(stim[::8,:,:], col_wrap=3);