In [ ]:
import os
import os.path as op
import pickle
import time
from glob import glob

import numpy as np
from scipy import stats
import mne
import matplotlib.pyplot as plt
import seaborn as sns
from mne.time_frequency import psd_welch, psd_multitaper
from jumeg.jumeg_utils import rescale_arr
from jumeg import get_jumeg_path
from jumeg.connectivity import plot_grouped_connectivity_circle
from mne_features.feature_extraction import extract_features
from mne_features.univariate import _freq_bands_helper
from rasmalai.network_utils import (make_count_based_stats,
                                    make_combined_average,
                                    get_con_metrics, get_coords_given_labels)
import yaml
# Wall-clock reference taken right after the imports (not read by any visible cell).
t_start = time.time()
# Seaborn figure style used by every plot below.
sns.set_style(style='white')
In [2]:
# Epochs file to analyse. The default is the original author's local path;
# set the EPOCHS_FNAME environment variable to run this notebook against a
# copy stored elsewhere without editing the cell.
epoch_fname = os.environ.get(
    'EPOCHS_FNAME',
    '/Users/psripad/kelsa/epochs_with_artifacts/207184_rest_EC_0.6-200bp_bcc,nr,ar,1-epo.fif')
epochs = mne.read_epochs(epoch_fname, preload=True)
# Raw epoch data, shape (n_epochs, n_channels, n_times); used by later cells.
data = epochs.get_data()
In [3]:
# Frequency-band configuration shared by the cells below.
bands = ['delta', 'theta', 'alpha', 'beta1', 'beta2', 'gamma1', 'gamma2']
# Band edges in Hz: band k spans freqs[k] .. freqs[k + 1] (7 bands, 8 edges).
freqs = np.array([1., 4., 8., 13., 18., 30., 80., 200.])
# One pastel colour per band for the horizontal band-level lines.
n_colors = len(bands)
color_list = plt.cm.Pastel1(np.linspace(0, 1, n_colors))
# mne-features univariate feature(s) to extract.
selected_funcs = {'pow_freq_bands'}
In [4]:
def get_psd_fingerprint(epochs, freq_bands=None, normalize=True, n_fft=512):
    '''Return band power averaged across epochs and channels.

    Band power is computed with the mne-features ``pow_freq_bands`` feature,
    using Welch's method for the PSD. With ``normalize=True`` mne-features
    normalises each band value (by total power, per the mne-features docs).

    Inputs
    ------
    epochs: mne.Epochs
        Instance of epochs (data are read with ``epochs.get_data()``).
    freq_bands: ndarray | list | None
        1-D array of band edges in Hz; band k spans
        freq_bands[k] .. freq_bands[k + 1]. If None or empty, the default
        [1., 4., 8., 13., 18., 30., 80., 200.] is used.
    normalize: bool
        Passed to mne-features as ``pow_freq_bands__normalize``.
    n_fft: int
        FFT length for Welch's method (default 512).

    Output
    ------
    X_avg: ndarray, shape (n_bands,)
        Mean (optionally normalised) band power across epochs and channels.
    '''
    from mne_features.feature_extraction import extract_features

    # `is None` / len check instead of `not freq_bands`: truth-testing a
    # multi-element ndarray raises ValueError.
    if freq_bands is None or len(freq_bands) == 0:
        freq_bands = [1., 4., 8., 13., 18., 30., 80., 200.]
    freq_bands = np.asarray(freq_bands, dtype=float)
    n_bands = len(freq_bands) - 1

    # Use this function's own inputs, not notebook globals.
    data = epochs.get_data()
    funcs_params = {'pow_freq_bands__freq_bands': freq_bands,
                    'pow_freq_bands__normalize': normalize,
                    'pow_freq_bands__ratios': None,
                    'pow_freq_bands__psd_method': 'welch',
                    'pow_freq_bands__psd_params': {'welch_n_fft': n_fft}}
    X_new = extract_features(data, epochs.info['sfreq'], {'pow_freq_bands'},
                             funcs_params=funcs_params)
    # Assumes mne-features lays features out channel-major with bands varying
    # fastest, so this gives (n_epochs, n_channels, n_bands).
    X_new = X_new.reshape((len(epochs), -1, n_bands))
    return X_new.mean(axis=(0, 1))
In [20]:
def get_normalised_psds_welch(epochs, norm_freq=(80., 200.),
                              freq_bands=None, grand_average=False,
                              n_fft=512, n_jobs=4):
    '''Estimate Welch PSDs and normalise them by the mean PSD in a reference band.

    PSDs are computed with ``mne.time_frequency.psd_welch`` between 1 and
    200 Hz on MEG channels (bad channels excluded). Each (epoch, channel)
    spectrum is divided by its own mean PSD inside ``norm_freq``.

    Input
    -----
    epochs: mne.Epochs
        Epochs instance.
    norm_freq: sequence of two floats
        (fmin, fmax) in Hz of the reference range used for normalisation
        (inclusive). Default (80., 200.).
    freq_bands: ndarray | list | None
        Band edges passed to mne-features' ``_freq_bands_helper``
        (e.g. [1., 4., 8., 13., 18., 30., 80., 200.]). If given, normalised
        PSDs are averaged within each band.
    grand_average: bool
        If True, additionally average across epochs and channels.
    n_fft: int
        FFT length for Welch's method (default 512).
    n_jobs: int
        Number of parallel jobs for ``psd_welch`` (default 4).

    Output
    ------
    If freq_bands is None, a tuple (norm_psds_welch, freqs_welch):
        norm_psds_welch: ndarray, shape (n_epochs, n_channels, n_freqs), or
            (n_freqs,) when grand_average is True.
        freqs_welch: ndarray, frequency points of the PSD estimate.
    Otherwise a single ndarray of shape (n_epochs, n_channels, n_bands),
    or (n_bands,) when grand_average is True.

    Raises
    ------
    ValueError
        If no PSD frequency bin falls inside ``norm_freq`` (the result
        would otherwise be all NaN).
    '''
    from mne.time_frequency import psd_welch
    from mne import pick_types
    from mne_features.univariate import _freq_bands_helper

    # Welch PSDs over good MEG channels.
    picks = pick_types(epochs.info, meg=True, exclude='bads')
    psds_welch, freqs_welch = psd_welch(epochs, fmin=1., fmax=200., n_fft=n_fft,
                                        picks=picks, n_jobs=n_jobs, average='mean')

    # Reference level: mean PSD inside norm_freq, per epoch and channel.
    norm_mask = (freqs_welch >= norm_freq[0]) & (freqs_welch <= norm_freq[1])
    if not norm_mask.any():
        raise ValueError('norm_freq %s contains no PSD frequency bins '
                         '(available range: %s - %s Hz)'
                         % (list(norm_freq), freqs_welch.min(), freqs_welch.max()))
    norm_mean = psds_welch[:, :, norm_mask].mean(axis=2)
    norm_psds_welch = psds_welch / norm_mean[:, :, np.newaxis]

    if freq_bands is None:
        if grand_average:
            return norm_psds_welch.mean(axis=(0, 1)), freqs_welch
        # the normalised PSDs and their frequencies
        return norm_psds_welch, freqs_welch

    # Average the normalised PSD within each band (edges inclusive).
    band_edges = np.asarray(_freq_bands_helper(epochs.info['sfreq'], freq_bands))
    psds_band_ = np.zeros(norm_psds_welch.shape[:2] + (len(band_edges),))
    for i, (fmin, fmax) in enumerate(band_edges):
        mask = (freqs_welch >= fmin) & (freqs_welch <= fmax)
        psds_band_[:, :, i] = norm_psds_welch[:, :, mask].mean(axis=2)
    if grand_average:
        # average across epochs, channels -> one value per band
        return psds_band_.mean(axis=(0, 1))
    return psds_band_
In [6]:
# Per-bin normalised PSDs (reference range 80-200 Hz), no band averaging.
mypsds1, myfreqs = get_normalised_psds_welch(epochs, freq_bands=None,
                                             norm_freq=[80., 200.])
print(mypsds1.shape)
In [21]:
# Normalised PSDs averaged within each frequency band. The band edges come
# from the configuration cell (`freqs`) instead of repeating the literal.
mypsds2 = get_normalised_psds_welch(epochs, norm_freq=[80., 200.],
                                    freq_bands=freqs)
print(mypsds2.shape)
In [22]:
# Grand average (across epochs and channels) of the band-averaged normalised
# PSDs: one value per band. Band edges come from the configuration cell's `freqs`.
mypsds3 = get_normalised_psds_welch(epochs, norm_freq=[80., 200.],
                                    freq_bands=freqs, grand_average=True)
print(mypsds3.shape)
In [ ]:
# Welch PSDs (n_fft=512) over good MEG channels, 1-200 Hz.
# NOTE(review): psds_welch / freqs_welch are recomputed with n_fft=256 in a
# later cell before anything reads them, so this result is superseded.
picks = mne.pick_types(epochs.info, meg=True, exclude='bads')
welch_kwargs = dict(fmin=1., fmax=200., n_fft=512, picks=picks,
                    n_jobs=4, average='mean')
psds_welch, freqs_welch = psd_welch(epochs, **welch_kwargs)
In [ ]:
# Absolute (non-normalised) band power averaged across epochs and channels.
myavg = get_psd_fingerprint(epochs, normalize=False, freq_bands=None)
print(myavg)
In [ ]:
np.shape(data)  # (n_epochs, n_channels, n_times)
So the data array has shape (n_epochs, n_channels, n_times) = (430, 247, 679): 430 epochs, 247 channels and 679 time points per epoch.
Averaging the band powers across epochs and channels therefore yields 7 values, one per frequency band.
In [ ]:
# compute PSDs and plot using Welch's algorithm
picks = mne.pick_types(epochs.info, meg=True, exclude='bads')
psds_welch, freqs_welch = psd_welch(epochs, fmin=1., fmax=200., n_fft=256,
picks=picks, n_jobs=4, average='mean')
psds_welch_avg = psds_welch.mean(axis=(0, 1))
In [ ]:
# Show the docstring of mne-features' band-edge helper (IPython `?` help
# syntax; interactive exploration only, produces no values used later).
_freq_bands_helper?
In [ ]:
# Band power from mne-features (Welch, n_fft=256, un-normalised), overlaid on
# the epoch- and channel-averaged Welch PSD from the previous cell. Each
# horizontal line is the band value divided by the number of PSD bins in that
# band, i.e. a mean PSD level (assuming mne-features sums PSD within a band).
# NOTE(review): `data` holds every channel from get_data(), while the PSD
# curve uses MEG picks without bads -- confirm the channel sets match.
n_epochs = len(epochs)  # replaces the hard-coded 430
X_new = extract_features(data, epochs.info['sfreq'], selected_funcs,
                         funcs_params={'pow_freq_bands__freq_bands': freqs,
                                       'pow_freq_bands__normalize': False,
                                       'pow_freq_bands__ratios': None,
                                       'pow_freq_bands__psd_method': 'welch',
                                       'pow_freq_bands__psd_params': {'welch_n_fft': 256}})
X_new = X_new.reshape((n_epochs, -1, len(freqs) - 1))
X_avg = X_new.mean(axis=(0, 1))
print('Welch:', X_avg)

fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(freqs_welch, psds_welch_avg)
for i, (fmin, fmax) in enumerate(_freq_bands_helper(epochs.info['sfreq'], freqs)):
    mask = (freqs_welch >= fmin) & (freqs_welch <= fmax)
    ax.axhline(X_avg[i] / mask.sum(), 0, 1, label=bands[i], color=color_list[i])
ax.legend()
ax.set(title='Welch', xlabel='freqs', ylabel='PSDs');
In [ ]:
# compute PSDs and plot using multitaper algorithm
picks = mne.pick_types(epochs.info, meg=True, exclude='bads')
psds_multitaper, freqs_multi = psd_multitaper(epochs, fmin=1., fmax=200.,
picks=picks, n_jobs=4,
normalization='length')
In [ ]:
# Band power from mne-features (multitaper, un-normalised), overlaid on the
# epoch- and channel-averaged multitaper PSD from the previous cell. Each
# horizontal line is the band value divided by the number of PSD bins in that
# band (a mean PSD level, assuming mne-features sums PSD within a band).
n_epochs = len(epochs)  # replaces the hard-coded 430
X_new = extract_features(data, epochs.info['sfreq'], selected_funcs,
                         funcs_params={'pow_freq_bands__freq_bands': freqs,
                                       'pow_freq_bands__normalize': False,
                                       'pow_freq_bands__ratios': None,
                                       'pow_freq_bands__psd_method': 'multitaper'})
X_new = X_new.reshape((n_epochs, -1, len(freqs) - 1))
X_avg = X_new.mean(axis=(0, 1))

fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(freqs_multi, psds_multitaper.mean(axis=(0, 1)))
for i, (fmin, fmax) in enumerate(_freq_bands_helper(epochs.info['sfreq'], freqs)):
    mask = (freqs_multi >= fmin) & (freqs_multi <= fmax)
    ax.axhline(X_avg[i] / mask.sum(), 0, 1, label=bands[i], color=color_list[i])
ax.legend()
ax.set(title='multitaper', xlabel='freqs', ylabel='PSDs');
In [ ]:
# Overlay the epoch- and channel-averaged multitaper and Welch spectra.
fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(freqs_multi, psds_multitaper.mean(axis=(0, 1)), color='b', label='multitaper')
ax.plot(freqs_welch, psds_welch.mean(axis=(0, 1)), color='g', label='Welch')
ax.legend()
ax.set_title('Welch vs Multitaper')
ax.set_xlabel('freqs')
ax.set_ylabel('PSDs');
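The Welch and multitaper PSD values differ substantially in magnitude because the two estimates are normalised and averaged differently. In particular, `psd_multitaper` was called with `normalization='length'`, which (per the MNE documentation) does not divide by the sampling rate, whereas Welch returns a spectral density. The two curves are therefore likely on scales that differ by roughly a factor of `sfreq`, and should not be compared directly without rescaling.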
End.