In [ ]:
# Reload edited project modules (e.g. the local diGP package) automatically
# before each cell executes, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [ ]:


In [ ]:
import os
import sys
# Make the parent of the current working directory importable, so the local
# `diGP` package (imported in the next cell) resolves when the notebook is run
# from its own folder.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

In [ ]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import GPy
from diGP.preprocessing_pipelines import (preprocess_SPARC, preprocess_HCP)
from diGP.dataManipulations import (DataHandler,
                                    log_q_squared,
                                    generateCoordinates)
from diGP.generateSyntheticData import combineCoordinatesAndqVecs
from diGP.model import Model

%matplotlib inline
sns.set_style('dark')

In [ ]:
# Root folder containing all datasets. Set the DIGP_DATA_ROOT environment
# variable to point at a local copy instead of editing this machine-specific
# default (kept so existing runs on the original machine are unchanged).
DATA_ROOT = os.environ.get('DIGP_DATA_ROOT',
                           'C:\\Users\\sesjojen\\Documents\\Data')

# Dataset name -> directory passed to the matching preprocessing pipeline.
dataPath = {'HCP_1007': os.path.join(DATA_ROOT, 'HumanConnectomeProject', 'mgh_1007', 'diff', 'preproc'),
            'SPARC_20': os.path.join(DATA_ROOT, 'SPARC', 'nifti', 'gradient_20_nifti'),
            'SPARC_30': os.path.join(DATA_ROOT, 'SPARC', 'nifti', 'gradient_30_nifti'),
            'SPARC_60': os.path.join(DATA_ROOT, 'SPARC', 'nifti', 'gradient_60_nifti'),
            'SPARC_GS': os.path.join(DATA_ROOT, 'SPARC', 'nifti', 'goldstandard_nifti')}

def preprocess(path_dict, data_source):
    """Run the preprocessing pipeline that matches a dataset name.

    Parameters
    ----------
    path_dict : dict
        Maps dataset names (e.g. 'SPARC_20') to data directories.
    data_source : str
        Key into `path_dict`. Its prefix selects the pipeline: names starting
        with 'HCP' use `preprocess_HCP`, names starting with 'SPARC' use
        `preprocess_SPARC`.

    Returns
    -------
    Whatever the selected pipeline returns (unpacked by the caller as
    ``gtab, data, voxelSize``).

    Raises
    ------
    ValueError
        If `data_source` has neither the 'HCP' nor the 'SPARC' prefix.
        (A subclass of Exception, so existing ``except Exception`` still works.)
    KeyError
        If a recognised `data_source` is missing from `path_dict`.
    """
    if data_source.startswith('HCP'):
        return preprocess_HCP(path_dict[data_source])
    if data_source.startswith('SPARC'):
        return preprocess_SPARC(path_dict[data_source])
    raise ValueError('Unknown data source: {!r}.'.format(data_source))

In [ ]:
source = 'SPARC_20'
gtab, data, voxelSize = preprocess(dataPath, source)

# Later cells index `data` with four subscripts (the last one being the
# measurement axis), so fail early if the pipeline returned something else.
assert data.ndim == 4, 'expected 4-D data, got shape {}'.format(data.shape)

# NOTE(review): assuming gtab is a dipy GradientTable, `info` is a property that
# prints the summary itself and returns None, so print(gtab.info) also printed a
# stray "None". Accessing the property alone shows just the summary.
gtab.info

Compare against spatial interpolation in a striped pattern: the model is trained on the even rows of the middle slice and predicts the held-out odd rows.


In [ ]:
mid_z = np.round(data.shape[2]/2).astype(int)


def _stripe_handler(row_offset):
    """Build a DataHandler for every second row of the middle z-slice.

    Uses rows ``row_offset, row_offset + 2, ...`` of ``data[:, :, mid_z, :]``.
    The voxel spacing along the first axis is therefore doubled, and the image
    origin is shifted by ``row_offset`` voxels along that axis so the stripe
    keeps its true physical position.
    """
    return DataHandler(gtab,
                       data[row_offset::2, :, mid_z, :],
                       voxelSize=(2*voxelSize[0], voxelSize[1]),
                       image_origin=voxelSize[0:2]*np.array([row_offset, 0]),
                       qMagnitudeTransform=log_q_squared)


# Even rows are the training stripes; odd rows are held out for prediction.
handler = _stripe_handler(0)
handlerPred = _stripe_handler(1)

In [ ]:
spatialLengthScale = 5
bValLengthScale = 3

# Separable product kernel over six input columns (presumably x, y, the
# transformed q-magnitude and a 3-D gradient direction -- TODO confirm the
# column layout against DataHandler):
#   columns 0, 1 -> RBF factors sharing the spatial length scale
#   column 2     -> Matern-5/2 factor with its own length scale
#   columns 3-5  -> LegendrePolynomial factor with orders 0, 2, 4
kernel = (
    GPy.kern.RBF(input_dim=1, active_dims=[0], variance=1, lengthscale=spatialLengthScale)
    * GPy.kern.RBF(input_dim=1, active_dims=[1], variance=1, lengthscale=spatialLengthScale)
    * GPy.kern.Matern52(input_dim=1, active_dims=[2], variance=1, lengthscale=bValLengthScale)
    * GPy.kern.LegendrePolynomial(input_dim=3,
                                  coefficients=np.array((2, 0.5, 0.05)),
                                  orders=(0, 2, 4),
                                  active_dims=(3, 4, 5))
)

# Pin the variance of the three stationary factors (the first three entries of
# kernel.parts) at 1 so they cannot trade amplitude with each other; the
# Legendre factor is left as constructed.
for stationary_part in kernel.parts[:3]:
    stationary_part.variance.fix(value=1)

In [ ]:
# Input-dimension groups handed to Model as `grid_dims`: columns 0 and 1 are
# their own groups, columns 2-5 form a single group (presumably the q-space
# inputs -- TODO confirm against diGP.model).
grid_dims = [[0], [1], list(range(2, 6))]

model = Model(handler,
              kernel,
              data_handler_pred=handlerPred,
              grid_dims=grid_dims,
              verbose=False)

In [ ]:
# Seed NumPy's global RNG so the random restarts are reproducible under
# Restart & Run All. NOTE(review): assumes Model.train(restarts=True) draws its
# restart initialisations from np.random (GPy's randomize() does by default)
# -- confirm in diGP.model.
np.random.seed(0)
model.train(restarts=True)

In [ ]:
# Optimised hyperparameters, followed by the fitted Legendre coefficients of
# the direction factor.
gp_model = model.GP_model
print(gp_model)
legendre_coefficients = gp_model.mul.LegendrePolynomial.coefficients
print("\nLegendre coefficients: \n{}".format(legendre_coefficients))

In [ ]:
# Predict mean and variance; compute_var=True is required because `var` is
# used further down (var_slice).
(mu, var) = model.predict(compute_var=True)

In [ ]:
# Residuals on the held-out rows, in whatever space handlerPred.y and mu share
# (presumably the handler's transformed space, since mu is untransformed
# separately in the next cell -- confirm).
fig, ax = plt.subplots()
ax.hist(handlerPred.y - mu, bins=500)
ax.set(title='Held-out prediction residuals', xlabel='y - mu', ylabel='count');

In [ ]:
# Side-by-side view of measurement index 1 (last axis): training data
# (presumably `handler`), predicted mean and predicted variance.
# NOTE(review): untransform() is applied to `var` exactly as to `mu`. If the
# handler's transform is nonlinear, this is not the variance of the
# untransformed signal -- confirm before reading the third panel quantitatively.
y_slice = model.data_handler.data[:, :, 1]
mu_slice = model.data_handler_pred.untransform(mu)[:, :, 1]
var_slice = model.data_handler_pred.untransform(var)[:, :, 1]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
panels = [(y_slice, 'Training data', {'vmin': 0, 'vmax': 1}),
          (mu_slice, 'Predicted mean', {'vmin': 0, 'vmax': 1}),
          (var_slice, 'Predicted variance', {})]
for ax, (image, title, color_limits) in zip(axes, panels):
    im = ax.imshow(image, **color_limits)
    ax.set_title(title)
    fig.colorbar(im, ax=ax)

In [ ]:
# Reassemble the full slice: observed values go back to the even rows
# (`handler`) and predictions to the odd rows (`handlerPred`).
# NOTE(review): handler.y and mu are presumably in the handler's transformed
# space, not the raw signal -- confirm before comparing with `data`.
sz = data.shape
combined = np.zeros((sz[0], sz[1], sz[3]))

# Convert physical coordinates back to voxel indices. Round before casting:
# astype(int) truncates, so a ratio like 2.9999999 caused by floating-point
# error in coordinate / voxelSize would otherwise land in the wrong row.
y_idx = np.rint(handler.X_coordinates / voxelSize[0:2]).astype(int)
mu_idx = np.rint(handlerPred.X_coordinates / voxelSize[0:2]).astype(int)

combined[y_idx[:, 0], y_idx[:, 1], :] = handler.y.reshape(np.prod(handler.originalShape[0:-1]), -1)
combined[mu_idx[:, 0], mu_idx[:, 1], :] = mu.reshape(np.prod(handlerPred.originalShape[0:-1]), -1)

In [ ]:
# Reassembled slice at measurement index 30: even rows observed, odd rows
# predicted. NOTE(review): vmin/vmax assume values in [0, 1]; `combined` may be
# in the handler's transformed space -- confirm.
fig, ax = plt.subplots()
im = ax.imshow(combined[:, :, 30], vmin=0, vmax=1)
ax.set_title('Observed (even rows) + predicted (odd rows), measurement 30')
fig.colorbar(im, ax=ax);