In [2]:
import os
import numpy as np
import scipy.io
from yass.evaluate.visualization import ChristmasPlot
from yass.evaluate.util import main_channels
In the constructor, give a title, the total number of datasets that you want to plot side by side, and a list of the methods whose results you are plotting. `logit_y` logit-transforms the y axis to emphasize the low and high ends of the metric. `eval_type` is used only for naming and will appear in the y-axis titles.
In the following block we create fake SNR and metric values purely for demonstration purposes.
If you want to compute the SNR of your templates (an `np.ndarray` of shape `(# time samples, # channels, # units)`), just call `main_channels(templates)`.
In [11]:
# Build a ChristmasPlot comparing three sorting methods across three datasets,
# then fill it with FAKE data for demonstration only: `x` stands in for the
# SNR values and `y` for the metric (here labelled "Accuracy").
# A fixed seed on a local RandomState makes the figures below identical on
# every Restart & Run All, without touching NumPy's global random state.
rng = np.random.RandomState(0)

plot = ChristmasPlot('Fake', n_dataset=3,
                     methods=['yass', 'kilosort', 'spyking circus'],
                     logit_y=True, eval_type="Accuracy")
for method in plot.methods:
    for i in range(plot.n_dataset):
        # 30 fake SNR values drawn uniformly from [-5, 5).
        x = (rng.rand(30) - 0.5) * 10
        # Logistic curve shifted by a random offset per (method, dataset), so
        # the curves differ. Outputs stay strictly inside (0, 1), which is the
        # domain of the logit transform enabled by logit_y=True.
        y = 1 / (1 + np.exp(-x + rng.rand()))
        plot.add_metric(x, y, dataset_number=i, method_name=method)
In [12]:
# Plot the metric added above against SNR for every method/dataset.
# NOTE(review): save_to=None presumably shows the figure inline instead of
# writing it to disk — TODO confirm against ChristmasPlot.
plot.generate_snr_metric_plot(save_to=None)
In [14]:
# Draw ChristmasPlot's curve plots for the same metrics; the exact content of
# these curves is defined inside ChristmasPlot — TODO document once confirmed.
# NOTE(review): save_to=None presumably means "display, don't save" — confirm.
plot.generate_curve_plots(save_to=None)
In [ ]: