In [ ]:
# Test setup: load shared test fixtures/configuration. Ignore warnings during production runs.

%run ./setup_tests.py

Specify input data

  • data_dir (str): Directory containing the input data. Change this if the data is not in the current directory; normally it is.
  • data (str): TIFF or HDF5 file (or a list of such files) to use as input data.
  • data_basename (str): Basename to use for intermediate and final result files.
  • dataset (str): HDF5 dataset to use as input data (ignored for TIFF input).

In [ ]:
# User-editable input parameters.
data_dir = ""
data = "data.tif"
data_basename = "data"
dataset = "images"


import os

# The extension selects the ingest path below; data_dir is made absolute so
# later cells do not depend on the working directory.
data_ext = os.path.splitext(data)[1].lower()
data_dir = os.path.abspath(data_dir)

# Zarr subgroup names for each stage of the workflow (raw -> trim -> denoise
# -> register -> subtract -> normalize -> dictionary/code -> postprocess ->
# ROIs/traces/projections).
subgroup_raw = "raw"
subgroup_trim = "trim"
subgroup_dn = "dn"
subgroup_reg = "reg"
subgroup_reg_images = "reg/images"
subgroup_reg_shifts = "reg/shifts"
subgroup_sub = "sub"
subgroup_norm = "norm"
subgroup_dict_init_data = "dict_init_data"
subgroup_dict_init_dict = "dict_init_dict"
subgroup_dict = "dict"
subgroup_code = "code"
subgroup_post = "post"
subgroup_post_mask = "post/mask"
subgroup_rois = "rois"
subgroup_rois_masks = "rois/masks"
subgroup_rois_masks_j = "rois/masks_j"
subgroup_rois_labels = "rois/labels"
subgroup_rois_labels_j = "rois/labels_j"
subgroup_traces = "traces"
subgroup_proj = "proj"
subgroup_proj_hmean = "proj/hmean"
subgroup_proj_max = "proj/max"
subgroup_proj_mean = "proj/mean"
subgroup_proj_std = "proj/std"

# Filename suffixes for exported result files.
postfix_rois = "_rois"
postfix_traces = "_traces"
postfix_html = "_proj"

# Recognized file extensions (with the platform's extension separator).
h5_ext = os.path.extsep + "h5"
tiff_ext = os.path.extsep + "tif"
zarr_ext = os.path.extsep + "zarr"
html_ext = os.path.extsep + "html"

In [ ]:
import os
from psutil import cpu_count

# Distributed cluster configuration.  Empty "ip" binds the scheduler locally.
cluster_kwargs = {
    "ip": ""
}
client_kwargs = {}
# Adaptive scaling: leave one core for the scheduler/notebook, but always
# allow at least one worker — the original `CORES - 1` yields a maximum of 0
# on single-core machines (or CORES=1), which would leave the cluster unable
# to scale up at all.
adaptive_kwargs = {
    "minimum": 0,
    "maximum": max(1, int(os.environ.get("CORES", cpu_count())) - 1)
}

In [ ]:
import zarr

from nanshe_workflow.data import DistributedDirectoryStore

# Root Zarr group ("<data_basename>.zarr") holding all intermediate and final
# results; opened in append mode so existing results survive re-runs.
zarr_store = zarr.open_group(DistributedDirectoryStore(data_basename + zarr_ext), "a")

Configure and startup Cluster


In [ ]:
from nanshe_workflow.par import startup_distributed
from nanshe_workflow.data import DistributedArrayStore

# Start the adaptive distributed cluster and wrap the Zarr group so dask
# arrays can be stored to / loaded from it through the client.
client = startup_distributed(0, cluster_kwargs, client_kwargs, adaptive_kwargs)

dask_store = DistributedArrayStore(zarr_store, client=client)

# Last expression: rich display of the client in the notebook.
client

In [ ]:
# Display the cluster widget/repr (workers, threads, memory).
client.cluster

Define functions for computation.


In [ ]:
# Interactive backend required by the MatplotlibViewer image browser below.
%matplotlib notebook

import matplotlib
import matplotlib.cm
import matplotlib.pyplot

import matplotlib as mpl
import matplotlib.pyplot as plt

from mplview.core import MatplotlibViewer as MPLViewer

In [ ]:
import ctypes
import logging
import os
import importlib
import sys

# Python 2/3 compatible lazy map/range.
from builtins import (
    map as imap,
    range as irange
)
from past.builtins import basestring

# contextlib2 backport for Python 2.
try:
    from contextlib import suppress
except ImportError:
    from contextlib2 import suppress

import numpy
import h5py

import numpy as np
import h5py as hp

import dask
import dask.array
import dask.array.fft
import dask.distributed

import dask.array as da

# Older dask exposed graph merging through `sharedict`; newer dask through
# `HighLevelGraph`. Both provide the `merge` used later when splicing graphs.
try:
    from dask.highlevelgraph import HighLevelGraph
except ImportError:
    import dask.sharedict as HighLevelGraph

# `atop` was renamed to `blockwise` in newer dask.
try:
    from dask.array import blockwise as da_blockwise
except ImportError:
    from dask.array import atop as da_blockwise

import dask_image
import dask_image.imread
import dask_image.ndfilters
import dask_image.ndfourier

import zarr

import nanshe
from nanshe.imp.segment import generate_dictionary

import nanshe_workflow
import nanshe_workflow._reg_joblib
from nanshe_workflow.data import io_remove, dask_io_remove, dask_load_hdf5, dask_store_zarr, zip_zarr, open_zarr

# Disable Blosc's internal threading locally and on every worker so it does
# not oversubscribe cores already managed by dask.
zarr.blosc.set_nthreads(1)
zarr.blosc.use_threads = False
client.run(zarr.blosc.set_nthreads, 1)
client.run(setattr, zarr.blosc, "use_threads", False)

logging.getLogger("nanshe").setLevel(logging.INFO)

In [ ]:
# I/O helpers for converting between storage formats (HDF5 <-> Zarr, TIFF export).
from nanshe_workflow.data import DistributedDirectoryStore
from nanshe_workflow.data import hdf5_to_zarr, zarr_to_hdf5
from nanshe_workflow.data import save_tiff

In [ ]:
# Prefer pyFFTW's drop-in dask FFT interface (faster) when it is installed.
try:
    import pyfftw.interfaces.dask_fft as dask_fft
except ImportError:
    import dask.array.fft as dask_fft

In [ ]:
# Registration helpers: Fourier-domain shifting, offset estimation, frame rolling.
from nanshe_workflow.reg import fourier_shift_wrap, compute_offset, roll_frames

In [ ]:
# Image processing and parallel segmentation helpers.
from nanshe_workflow.imp2 import extract_f0, wavelet_transform, renormalized_images, normalize_data

from nanshe_workflow.par import halo_block_generate_dictionary_parallel
from nanshe_workflow.imp import block_postprocess_data_parallel

In [ ]:
# Projection and trace-extraction helpers.
from nanshe_workflow.proj import compute_traces

from nanshe_workflow.proj import compute_adj_harmonic_mean_projection

from nanshe_workflow.proj import norm_layer

Begin workflow. Set parameters and run each cell.

Convert TIFF/HDF5 to Zarr

  • block_chunks (tuple of ints): chunk size for each block loaded into memory.

In [ ]:
# Chunking for ingest: 100 frames per block, full spatial extent.
block_chunks = (100, -1, -1)

for k in [subgroup_raw]:
    with suppress(KeyError):
        del dask_store[k]

# Pick the ingest routine from the input file's extension.
dask_ingest_func = None
if data_ext == tiff_ext:
    dask_ingest_func = lambda data: dask_image.imread.imread(data, nframes=block_chunks[0])
elif data_ext == h5_ext:
    dask_ingest_func = lambda data: dask_load_hdf5(data, dataset=dataset, chunks=block_chunks)
else:
    # Fail early with a clear message instead of the cryptic
    # "'NoneType' object is not callable" the call below would raise.
    raise ValueError(
        "Unsupported input extension %r; expected %r or %r." % (data_ext, tiff_ext, h5_ext)
    )


# A single filename ingests directly; a sequence of filenames is ingested
# per-file and concatenated along the frame axis.
if isinstance(data, basestring):
    dask_store[subgroup_raw] = dask_ingest_func(data)
else:
    dask_store[subgroup_raw] = da.concatenate(list(imap(dask_ingest_func, data)))

dask.distributed.progress(dask_store[subgroup_raw], notebook=False)

View Input Data


In [ ]:
# Display-scaling defaults; replaced below by the data's true range.
imgs_min, imgs_max = 0, 100

da_imgs = dask_store[subgroup_raw]

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

# Compute both extrema on the cluster in one submission.
status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = (s.result() for s in status)

# Browse the raw frames with a consistent intensity scale.
mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(da_imgs, vmin=imgs_min, vmax=imgs_max)

Trimming

  • front (int): amount to trim off the front
  • back (int): amount to trim off the back

In [ ]:
# Number of frames to drop from the start and end of the recording.
front = 1
back = 1


for k in [subgroup_trim]:
    with suppress(KeyError):
        del dask_store[k]


# Load and prep data for computation.
da_imgs = dask_store[subgroup_raw]

# Guard against silently producing an empty result when trimming away more
# frames than exist (da_imgs[5:3] would yield an empty array without error).
if front < 0 or back < 0 or front + back >= len(da_imgs):
    raise ValueError(
        "Invalid trim amounts: front=%d, back=%d for %d frames." % (front, back, len(da_imgs))
    )

# Trim frames from front and back
da_imgs_trim = da_imgs[front:len(da_imgs)-back]

# Store trimmed data
dask_store[subgroup_trim] = da_imgs_trim

# Check progress of store step
dask.distributed.progress(dask_store[subgroup_trim], notebook=False)
print("")


# View results
imgs_min, imgs_max = 0, 100

da_imgs = dask_store[subgroup_trim]

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

Denoising

  • med_filt_size (int): footprint size for median filter

In [ ]:
# Footprint size (per spatial axis) of the median filter.
med_filt_size = 3


for k in [subgroup_dn]:
    with suppress(KeyError):
        del dask_store[k]


# Load and prep data for computation.
da_imgs = dask_store[subgroup_trim]

# Median filter frames (size 1 on the time axis: frames filtered independently).
da_imgs_filt = dask_image.ndfilters.median_filter(
    da_imgs, (1,) + (da_imgs.ndim - 1) * (med_filt_size,)
)

# Reset minimum to original value.
da_imgs_min = da_imgs.min()
da_imgs_filt_min = da_imgs_filt.min()
da_imgs_filt += da_imgs_min - da_imgs_filt_min

# Store denoised data.  The three are persisted together so the two scalar
# minima are computed once and shared by the shifted result before storing.
da_imgs_min, da_imgs_filt, da_imgs_filt_min = dask.persist(da_imgs_min, da_imgs_filt, da_imgs_filt_min)
dask_store[subgroup_dn] = da_imgs_filt
# Drop the local references so worker memory can be reclaimed.
del da_imgs_min, da_imgs_filt, da_imgs_filt_min

# Check progress of store step
dask.distributed.progress(dask_store[subgroup_dn], notebook=False)
print("")


# View results
imgs_min, imgs_max = 0, 100

da_imgs = dask_store[subgroup_dn]

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

Registration


In [ ]:
# Iterative FFT-based translation registration:
#   num_reps (int): maximum number of refinement passes.
#   tmpl_hist_wght (float): weight of the previous template when blending in
#       the newly shifted mean (exponential template history).
#   thld_rel_dist (float): convergence threshold on the mean relative shift
#       change; 0.0 means run all num_reps passes unless shifts stop changing.
num_reps = 5
tmpl_hist_wght = 0.25
thld_rel_dist = 0.0


for k in [subgroup_reg_images, subgroup_reg_shifts]:
    with suppress(KeyError):
        del dask_store[k]
with suppress(KeyError):
    del zarr_store[subgroup_reg]
zarr_store.require_group(subgroup_reg)


# Load and prep data for computation.
da_imgs = dask_store[subgroup_dn]

# Promote to at least 32-bit float for the FFT work below.
da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)

# Create frame shape arrays
# NOTE(review): half_frame_shape is computed but never used before it is
# deleted at the end of this cell — confirm whether it can be dropped.
frame_shape = np.array(da_imgs_flt.shape[1:], dtype=int)
half_frame_shape = frame_shape // 2
frame_shape = da.asarray(frame_shape)
half_frame_shape = da.asarray(half_frame_shape)

# Find the inverse of each frame (shifted so the divisor is always >= 1).
da_imgs_flt_min = da_imgs_flt.min()
da_imgs_inv = dask.array.reciprocal(da_imgs_flt - (da_imgs_flt_min - 1))

# Compute the FFT of inverse frames and template (spatial axes only).
da_imgs_fft = dask_fft.rfftn(da_imgs_inv, axes=tuple(irange(1, da_imgs_flt.ndim)))
da_imgs_fft_tmplt = da_imgs_fft.mean(axis=0, keepdims=True)

# Initialize
i = 0
avg_rel_dist = 1.0
tmpl_hist_wght = da_imgs_flt.dtype.type(tmpl_hist_wght)
da_shifts = da.zeros(
    (len(da_imgs_fft), da_imgs_fft.ndim - 1),
    dtype=int,
    chunks=(da_imgs_fft.chunks[0], (da_imgs_fft.ndim - 1,))
)

# Persist FFT of frames and template
da_imgs_flt_min, da_imgs_fft, da_imgs_fft_tmplt = client.persist([
    da_imgs_flt_min, da_imgs_fft, da_imgs_fft_tmplt
])
del da_imgs_flt_min

# Refine shifts until the mean relative change drops below threshold or the
# pass budget is exhausted.
while avg_rel_dist > thld_rel_dist and i < num_reps:
    # Compute the shifted frames (phase shift applied in the Fourier domain).
    da_shifted_frames = da_blockwise(
        fourier_shift_wrap,
        (0,) + tuple(irange(1, da_imgs_fft.ndim)),
        da_imgs_fft,
        (0,) + tuple(irange(1, da_imgs_fft.ndim)),
        da_shifts,
        (0, da_imgs_fft.ndim),
        dtype=da_imgs_fft.dtype
    )

    # Compute the template FFT (blend of history and the new shifted mean).
    da_imgs_fft_tmplt = (
        tmpl_hist_wght * da_imgs_fft_tmplt +
        (1 - tmpl_hist_wght) * da_shifted_frames.mean(axis=0, keepdims=True)
    )

    # Persist the updated FFT template
    da_shifted_frames, da_imgs_fft_tmplt = client.persist([
        da_shifted_frames, da_imgs_fft_tmplt
    ])
    del da_shifted_frames

    # Find the best overlap with the template (cross-correlation via inverse FFT).
    da_overlap = dask_fft.irfftn(
        da_imgs_fft * da_imgs_fft_tmplt,
        s=da_imgs_flt.shape[1:],
        axes=tuple(irange(1, da_imgs_flt.ndim))
    )
    da_overlap_max = da_overlap.max(axis=tuple(irange(1, da_imgs_flt.ndim)), keepdims=True)
    da_overlap_max_match = (da_overlap == da_overlap_max)

    # Clear FFT overlap intermediates
    del da_overlap_max

    # Compute the shift for each frame.
    old_da_shifts = da_shifts
    da_raw_shifts = da_blockwise(
        compute_offset,
        (0, da_overlap_max_match.ndim),
        da_overlap_max_match.rechunk(dict(enumerate(da_overlap_max_match.shape[1:], 1))),
        tuple(irange(0, da_overlap_max_match.ndim)),
        dtype=int,
        new_axes={da_overlap_max_match.ndim: da_overlap_max_match.ndim - 1}
    )

    # Free connected persisted values
    del da_overlap_max_match

    # Remove any collective frame drift (mean shift over all frames).
    da_drift = da_raw_shifts.mean(axis=0, keepdims=True).round().astype(da_shifts.dtype)
    da_shifts = da_raw_shifts - da_drift

    # Clear drift corrected shifts
    del da_drift

    # Find shift change, normalized by frame size and dimensionality so it is
    # a dimensionless per-frame distance.
    diff_da_shifts = da_shifts - old_da_shifts
    rel_diff_da_shifts = (
        diff_da_shifts.astype(da_imgs_flt.dtype) /
        frame_shape.astype(da_imgs_flt.dtype) /
        np.sqrt(da_imgs_flt.dtype.type(len(frame_shape)))
    )
    rel_dist_da_shifts = da.sqrt(da.square(rel_diff_da_shifts).sum(axis=1))
    avg_rel_dist = rel_dist_da_shifts.sum() / da_imgs_flt.dtype.type(len(da_shifts))

    # Free old shifts
    del old_da_shifts

    # Persist statistics related to shift change
    da_overlap, da_raw_shifts, diff_da_shifts, rel_diff_da_shifts, rel_dist_da_shifts, avg_rel_dist = client.persist([
        da_overlap, da_raw_shifts, diff_da_shifts, rel_diff_da_shifts, rel_dist_da_shifts, avg_rel_dist
    ])
    del da_overlap
    del da_raw_shifts
    del diff_da_shifts
    del rel_diff_da_shifts
    del rel_dist_da_shifts

    # Compute change
    dask.distributed.progress(avg_rel_dist, notebook=False)
    print("")
    avg_rel_dist = avg_rel_dist.compute()
    i += 1

    # Show change
    print((i, avg_rel_dist))

# Drop unneeded items
del frame_shape
del half_frame_shape
del da_imgs_flt
del da_imgs_inv
del da_imgs_fft
del da_imgs_fft_tmplt

# Roll all parts to clip to one side
# Keep origin static
da_imgs_shifted = roll_frames(
    da_imgs,
    da.clip(da_shifts, None, 0)
)

# Truncate all frames to smallest one
da_imgs_trunc_shape = da.asarray(da_imgs.shape[1:]) - abs(da_shifts).max(axis=0)
da_imgs_trunc_shape = da_imgs_trunc_shape.compute()

da_imgs_trunc_cut = tuple(map(
    lambda s: slice(None, s), da_imgs_trunc_shape
))

da_imgs_trunc = da_imgs_shifted[(slice(None),) + da_imgs_trunc_cut]

# Free raw data
del da_imgs

# Store registered data
dask_store.update({
    subgroup_reg_images: da_imgs_trunc,
    subgroup_reg_shifts: da_shifts,
})

dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store[subgroup_reg_images],
        dask_store[subgroup_reg_shifts]
    ]),
    notebook=False
)
print("")

# Free truncated frames and shifts
del da_imgs_trunc
del da_shifts


# View results
imgs_min, imgs_max = 0, 100

da_imgs = dask_store[subgroup_reg_images]
da_shifts = dask_store[subgroup_reg_shifts]

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

# One subplot per shift axis, stacked and sharing the frame axis.
fig, axs = plt.subplots(nrows=da_shifts.shape[1], sharex=True)
fig.subplots_adjust(hspace=0.0)
for i in range(da_shifts.shape[1]):
    axs[i].plot(np.asarray(da_shifts[:, i]))
    axs[i].set_ylabel("%s (px)" % chr(ord("X") + da_shifts.shape[1] - i - 1))
    axs[i].yaxis.set_tick_params(width=1.5)
    [v.set_linewidth(2) for v in axs[i].spines.values()]
axs[-1].set_xlabel("Frame (#)")
axs[-1].set_xlim((0, da_shifts.shape[0] - 1))
axs[-1].xaxis.set_tick_params(width=1.5)

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

Projections


In [ ]:
# Clear any stale projection results before recomputing.
for key in [subgroup_proj_hmean, subgroup_proj_max, subgroup_proj_mean, subgroup_proj_std]:
    with suppress(KeyError):
        del dask_store[key]
with suppress(KeyError):
    del zarr_store[subgroup_proj]
zarr_store.require_group(subgroup_proj)


# Load the registered frames; promote to >= 32-bit float if needed.
da_imgs = dask_store[subgroup_reg_images]

da_imgs_flt = da_imgs
is_float32_plus = (
    issubclass(da_imgs_flt.dtype.type, np.floating)
    and da_imgs_flt.dtype.itemsize >= 4
)
if not is_float32_plus:
    da_imgs_flt = da_imgs_flt.astype(np.float32)

# Per-pixel projections across time (axis 0).
da_imgs_proj_hmean = compute_adj_harmonic_mean_projection(da_imgs_flt)
da_imgs_proj_max = da_imgs_flt.max(axis=0)
da_imgs_proj_mean = da_imgs_flt.mean(axis=0)
da_imgs_proj_std = da_imgs_flt.std(axis=0)

# Store all four projections.
dask_store.update({
    subgroup_proj_hmean: da_imgs_proj_hmean,
    subgroup_proj_max: da_imgs_proj_max,
    subgroup_proj_mean: da_imgs_proj_mean,
    subgroup_proj_std: da_imgs_proj_std,
})

dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store[subgroup_proj_hmean],
        dask_store[subgroup_proj_max],
        dask_store[subgroup_proj_mean],
        dask_store[subgroup_proj_std]
    ]),
    notebook=False
)
print("")

Subtract Projection


In [ ]:
# Drop any previous background-subtracted result.
for key in [subgroup_sub]:
    with suppress(KeyError):
        del dask_store[key]


# Load the registered frames; promote to >= 32-bit float if needed.
da_imgs = dask_store[subgroup_reg_images]

da_imgs_flt = da_imgs
needs_cast = not (
    issubclass(da_imgs_flt.dtype.type, np.floating)
    and da_imgs_flt.dtype.itemsize >= 4
)
if needs_cast:
    da_imgs_flt = da_imgs_flt.astype(np.float32)

# Subtract the adjusted-harmonic-mean background and shift to non-negative.
da_imgs_sub = da_imgs_flt - compute_adj_harmonic_mean_projection(da_imgs_flt)
da_imgs_sub -= da_imgs_sub.min()

# Store background removed data
dask_store[subgroup_sub] = da_imgs_sub

dask.distributed.progress(dask_store[subgroup_sub], notebook=False)
print("")


# View results
imgs_min, imgs_max = 0, 100

da_imgs = dask_store[subgroup_sub]

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = (s.result() for s in status)

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(da_imgs, vmin=imgs_min, vmax=imgs_max)

Normalize Data


In [ ]:
# Drop any previous normalized result.
for key in [subgroup_norm]:
    with suppress(KeyError):
        del dask_store[key]


# Load the background-subtracted frames; promote to >= 32-bit float if needed.
da_imgs = dask_store[subgroup_sub]

da_imgs_flt = da_imgs
needs_cast = not (
    issubclass(da_imgs_flt.dtype.type, np.floating)
    and da_imgs_flt.dtype.itemsize >= 4
)
if needs_cast:
    da_imgs_flt = da_imgs_flt.astype(np.float32)

# Shift each frame so its own minimum is zero before renormalizing.
da_imgs_flt_mins = da_imgs_flt.min(
    axis=tuple(irange(1, da_imgs_flt.ndim)),
    keepdims=True
)
da_imgs_flt_shift = da_imgs_flt - da_imgs_flt_mins

da_result = renormalized_images(da_imgs_flt_shift)

# Store normalized data
dask_store[subgroup_norm] = da_result

dask.distributed.progress(dask_store[subgroup_norm], notebook=False)
print("")


# View results
imgs_min, imgs_max = 0, 100

da_imgs = dask_store[subgroup_norm]

da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = (s.result() for s in status)

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(da_imgs, vmin=imgs_min, vmax=imgs_max)

Dictionary Learning

  • n_components (int): number of basis images in the dictionary.
  • batchsize (int): minibatch size to use.
  • iters (int): number of iterations to run before getting dictionary.
  • lambda1 (float): weight for L1 sparsity enforcement on the sparse code.
  • lambda2 (float): weight for L2 sparsity enforcement on the sparse code.


  • block_frames (int): number of frames to work with in each full frame block (run in parallel).

In [ ]:
import functools
import itertools
import logging
import operator
import sys

from builtins import filter as ifilter

import toolz
import toolz.itertoolz

import sklearn
from sklearn.base import BaseEstimator
# NOTE(review): `sklearn.decomposition.dict_learning` is a private module path
# that was removed in newer scikit-learn releases — confirm the pinned
# scikit-learn version supports it.
from sklearn.decomposition.dict_learning import SparseCodingMixin

# contextlib2 backports for Python 2 compatibility.
try:
    from contextlib import ExitStack, suppress, redirect_stdout, redirect_stderr
except ImportError:
    from contextlib2 import ExitStack, suppress, redirect_stdout, redirect_stderr

try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO


def to_tuple(*args):
    """Pack all positional arguments into a tuple."""
    # ``*args`` already arrives as a tuple, so it can be returned directly.
    return args


def func_log_stdoe(func):
    """Wrap ``func`` so anything it writes to stdout/stderr is forwarded to
    the ``distributed.worker.stdout`` / ``distributed.worker.stderr`` loggers.

    Used to surface worker-side prints (e.g. sklearn verbose output) in
    distributed runs. The wrapped function's return value is passed through
    unchanged; its output is logged even if it raises.
    """
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        with ExitStack() as stack:
            out, err = StringIO(), StringIO()

            stack.enter_context(redirect_stdout(out))
            # Bug fix: the original called redirect_stdout twice, so stderr
            # was never captured and stdout landed in the stderr buffer.
            stack.enter_context(redirect_stderr(err))

            try:
                return func(*args, **kwargs)
            finally:
                logging.getLogger("distributed.worker.stdout").info(out.getvalue())
                logging.getLogger("distributed.worker.stderr").info(err.getvalue())

    return wrapped


def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
                         return_code=True, dict_init=None, callback=None,
                         batch_size=3, verbose=False, shuffle=True, n_jobs=1,
                         method='lars', iter_offset=0, random_state=None,
                         return_inner_stats=False, inner_stats_A=None,
                         inner_stats_B=None, return_n_iter=False,
                         positive_dict=False, positive_code=False):
    """Single-chunk wrapper around ``sklearn.decomposition.dict_learning_online``.

    Takes the inner sufficient statistics as two separate arrays (A, B) so
    they can travel through a dask graph as individual chunks, and returns
    the updated ``(dictionary, A, B)`` triple for the next partial fit.

    NOTE(review): relies on the deprecated ``sklearn.externals.joblib``
    namespace (re-registered via ``nanshe_workflow._reg_joblib``) and on the
    pre-1.0 signature of ``dict_learning_online`` — confirm against the
    installed scikit-learn version.
    """
    # Recombine the split inner statistics; both must be present to be usable.
    if inner_stats_A is None or inner_stats_B is None:
        inner_stats = None
    else:
        inner_stats = (inner_stats_A, inner_stats_B)

    # Force the sequential joblib backend: parallelism is handled by dask.
    # np.require(..., "OW") gives sklearn an owned, writable dictionary copy.
    with sklearn.externals.joblib.parallel_backend("sequential"):
        result = sklearn.decomposition.dict_learning_online(X, n_components=n_components, alpha=alpha,
                                                            n_iter=n_iter, return_code=return_code,
                                                            dict_init=np.require(dict_init, requirements="OW"),
                                                            callback=callback, batch_size=batch_size, verbose=verbose,
                                                            shuffle=shuffle, n_jobs=n_jobs, method=method,
                                                            iter_offset=iter_offset, random_state=random_state,
                                                            return_inner_stats=return_inner_stats,
                                                            inner_stats=inner_stats,
                                                            return_n_iter=return_n_iter,
                                                            positive_dict=positive_dict,
                                                            positive_code=positive_code)

    # With return_code=False and return_inner_stats=True, the call returns
    # (dictionary, (A, B)); split the statistics for the dask graph.
    components_, inner_stats_ = result
    inner_stats_A_, inner_stats_B_ = inner_stats_
    return components_, inner_stats_A_, inner_stats_B_




class MiniBatchDictionaryLearning(BaseEstimator, SparseCodingMixin):
    """Dask-backed mini-batch dictionary learner.

    Mirrors the interface of sklearn's ``MiniBatchDictionaryLearning`` but
    streams the data chunk-by-chunk through ``dict_learning_online``,
    splicing each partial fit into a hand-built dask graph and carrying the
    dictionary and the inner sufficient statistics (A, B) between chunks.
    """

    def __init__(self,
                 n_components=None,
                 alpha=1,
                 n_iter=1000,
                 fit_algorithm="lars",
                 batch_size=3,
                 shuffle=True,
                 dict_init=None,
                 transform_algorithm="omp",
                 transform_n_nonzero_coefs=None,
                 transform_alpha=None,
                 positive_code=False,
                 positive_dict=False,
                 verbose=False):
        self.n_components = n_components
        self.alpha = alpha
        self.n_iter = n_iter
        self.fit_algorithm = fit_algorithm
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.dict_init = dict_init
        self.transform_algorithm = transform_algorithm
        self.transform_n_nonzero_coefs = transform_n_nonzero_coefs
        self.transform_alpha = transform_alpha
        self.positive_code = positive_code
        self.positive_dict = positive_dict
        self.verbose = verbose

        # Seed the learned dictionary from dict_init (if given), held as a
        # single-chunk dask array so it is one key in the graph.
        self.components_ = self.dict_init
        if self.components_ is not None:
            self.components_ = da.asarray(self.components_)
            self.components_ = self.components_.rechunk(self.components_.shape)

        # Running iteration offset and (A, B) statistics across partial fits.
        self.iter_offset_ = 0
        self.inner_stats_ = None


    def fit(self, X, y=None):
        """Fit the dictionary to ``X`` (samples x features), one X-chunk at a time.

        Each chunk is fed to ``dict_learning_online`` together with the current
        dictionary and inner statistics; the three outputs become the inputs of
        the next chunk's partial fit.  ``y`` is ignored (sklearn convention).
        """
        X = da.asarray(X)

        n_components = self.n_components
        n_features = X.shape[1]

        components_ = self.components_
        inner_stats_ = self.inner_stats_

        # Row getter over a one-row-per-chunk view of X.
        X_get = toolz.curry(
            operator.getitem,
            X.rechunk((1, X.shape[1]))
        )
        # No initial dictionary: sample n_components random rows of X.
        if components_ is None:
            idx = np.random.permutation(X.shape[0])[:n_components]
            components_ = da.stack([X_get(i) for i in idx.flat]).rechunk({0: n_components})
        # Optionally shuffle the sample order (unseeded, non-deterministic).
        if self.shuffle:
            idx = np.random.permutation(X.shape[0])
            X = da.stack([X_get(i) for i in idx.flat]).rechunk({0: X.chunks[0]})
        if inner_stats_ is None:
            inner_stats_ = (
                da.zeros((n_components, n_components), chunks=(n_components, n_components)),
                da.zeros((n_features, n_components), chunks=(n_features, n_components))
            )

        func = dict_learning_online
        if self.verbose:
            func = func_log_stdoe(dict_learning_online)

        # One partial fit per X chunk; the dictionary and stats are each a
        # single chunk, so `next(flatten(...))` yields their only key.
        for X_chk_key in dask.core.flatten(X.__dask_keys__()):
            components_key = next(dask.core.flatten(components_.__dask_keys__()))
            # NOTE(review): inner_stats_ was initialized above, so this branch
            # is always taken and the else below is effectively dead code.
            if inner_stats_:
                inner_stats_A_key = next(dask.core.flatten(inner_stats_[0].__dask_keys__()))
                inner_stats_B_key = next(dask.core.flatten(inner_stats_[1].__dask_keys__()))
            else:
                inner_stats_ = tuple()
                inner_stats_A_key = None
                inner_stats_B_key = None

            kwargs = dict(
                n_components=n_components,
                alpha=self.alpha,
                n_iter=self.n_iter,
                return_code=False,
                dict_init=components_key,
                callback=None,
                batch_size=self.batch_size,
                verbose=self.verbose,
                shuffle=False,
                n_jobs=1,
                method=self.fit_algorithm,
                iter_offset=self.iter_offset_,
                random_state=None,
                return_inner_stats=True,
                inner_stats_A=inner_stats_A_key,
                inner_stats_B=inner_stats_B_key,
                return_n_iter=False,
                positive_dict=self.positive_dict,
                positive_code=self.positive_code
            )

            # Deterministic keys for this partial fit and its three outputs.
            tok = dask.base.tokenize(X_chk_key, **kwargs)
            components_res_key = ("components_dict_learning_online-%s" % tok, 0, 0)
            inner_stat_A_res_key = ("inner_stat_A_dict_learning_online-%s" % tok, 0, 0)
            inner_stat_B_res_key = ("inner_stat_B_dict_learning_online-%s" % tok, 0, 0)
            dict_learn_key = "dict_learning_online-%s" % tok

            # Splice the partial-fit task and its three output getters onto
            # the graphs of X, the dictionary, and the statistics.
            dct = HighLevelGraph.merge(
                {components_res_key: (
                    operator.getitem, dict_learn_key, 0
                )},
                {inner_stat_A_res_key: (
                    operator.getitem, dict_learn_key, 1
                )},
                {inner_stat_B_res_key: (
                    operator.getitem, dict_learn_key, 2
                )},
                {dict_learn_key:
                    (dask.compatibility.apply, func, [X_chk_key], kwargs),
                },
                X.__dask_graph__(),
                components_.__dask_graph__(),
                *(e.__dask_graph__() for e in inner_stats_)
            )
            dct = da.Array.__dask_optimize__(dct, [
                components_res_key,
                inner_stat_A_res_key,
                inner_stat_B_res_key
            ])

            # Rewrap the three outputs as single-chunk dask arrays.
            components_ = da.Array(
                dct, components_res_key[0], ((n_components,), (n_features,)), X.dtype
            )
            inner_stats_ = (
                da.Array(
                    dct, inner_stat_A_res_key[0], ((n_components,), (n_components,)), X.dtype
                ),
                da.Array(
                    dct, inner_stat_B_res_key[0], ((n_features,), (n_components,)), X.dtype
                ),
            )

            # Persist everything after an iteration
            result = dask.persist(components_, *inner_stats_)
            components_, inner_stats_ = result[0], result[1:]

            self.iter_offset_ += self.n_iter

        self.components_ = components_
        self.inner_stats_ = inner_stats_

        return self


    @staticmethod
    def _sparse_encode_wrapper(*args, **kwargs):
        """Adapter for ``da_blockwise``: unwrap single-element chunk lists
        before delegating to ``sklearn.decomposition.sparse_encode``."""
        args = tuple(e[0] if isinstance(e, list) else e for e in args)
        return sklearn.decomposition.sparse_encode(*args, **kwargs)


    def transform(self, X):
        """Sparse-code ``X`` against the learned dictionary.

        Returns a dask array of codes with shape (n_samples, n_components);
        the Gram and covariance matrices are precomputed and shared across
        all sample chunks.
        """
        sparse_encode_wrapper = MiniBatchDictionaryLearning._sparse_encode_wrapper
        if self.verbose:
            sparse_encode_wrapper = func_log_stdoe(sparse_encode_wrapper)

        # gram = D D^T, cov = D X^T, both reused by every chunk's encode.
        gram = da.tensordot(self.components_, self.components_, axes=[[1], [1]])
        cov = da.tensordot(self.components_, X, axes=[[1], [1]])

        code = da_blockwise(
            sparse_encode_wrapper,
            (1, 0),
            X,
            (1, 2),
            self.components_,
            (0, 2),
            gram,
            (0, 0),
            cov,
            (0, 1),
            dtype=self.components_.dtype,
            algorithm=self.transform_algorithm,
            n_nonzero_coefs=self.transform_n_nonzero_coefs,
            alpha=self.transform_alpha,
            copy_cov=True,
            init=None,
            max_iter=self.n_iter,
            n_jobs=1,
            check_input=False,
            verbose=self.verbose
        )
        gram, cov, code = dask.persist(gram, cov, code)

        return code

In [ ]:
import toolz


n_components = 50


# Drop stale initialization data from the store.
for key in [subgroup_dict_init_data, subgroup_dict_init_dict]:
    with suppress(KeyError):
        del dask_store[key]

da_imgs = dask_store[subgroup_norm]

# Flatten each frame to a row vector and build a single-row getter.
num_frames = da_imgs.shape[0]
num_pixels = int(np.prod(da_imgs.shape[1:]))
da_imgs_mtx = da_imgs.reshape(num_frames, num_pixels)
da_imgs_mtx_get = toolz.curry(
    operator.getitem,
    da_imgs_mtx.rechunk((1, num_pixels))
)

# Shuffle all frames for training (unseeded, so non-deterministic).
perm_all = np.random.permutation(num_frames)
dict_init_data = da.stack(
    [da_imgs_mtx_get(i) for i in perm_all.flat]
).rechunk({0: da_imgs_mtx.chunks[0]})

# Randomly subsample frames to seed the initial dictionary.
perm_sub = np.random.permutation(num_frames)[:n_components]
dict_init_dict = da.stack(
    [da_imgs_mtx_get(i) for i in perm_sub.flat]
).rechunk({0: n_components})


# Store the shuffled data and the dictionary seed.
dask_store.update({
    subgroup_dict_init_data: dict_init_data,
    subgroup_dict_init_dict: dict_init_dict,
})

dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store[subgroup_dict_init_data],
        dask_store[subgroup_dict_init_dict]
    ]),
    notebook=False
)
print("")

In [ ]:
from nanshe_workflow.par import startup_distributed, shutdown_distributed
from nanshe_workflow.data import DistributedArrayStore
from psutil import cpu_count


# Leave one core for the scheduler/notebook, but never drop to zero threads.
ncores = max(1, int(os.environ.get("CORES", cpu_count())) - 1)


def _set_openblas_threads(n):
    """Pin OpenBLAS to ``n`` threads in the calling process (no-op if absent)."""
    # Imported here because this runs inside the worker process, and
    # `import ctypes` alone does not load the `ctypes.util` submodule —
    # the original lambda's `ctypes.util.find_library` would raise
    # AttributeError whenever `ctypes.util` had not been imported.
    import ctypes
    import ctypes.util

    libname = ctypes.util.find_library("openblas")
    if libname is not None:
        ctypes.CDLL(libname).openblas_set_num_threads(n)


# Single multi-threaded worker: dictionary learning runs sequentially in
# sklearn, so give it all the cores via BLAS threading instead.
dict_client = startup_distributed(
    1,
    cluster_kwargs={
        "threads_per_worker": ncores
    },
    client_kwargs={}
)
dict_client.run(importlib.import_module, "nanshe_workflow._reg_joblib")
dict_client.run(_set_openblas_threads, ncores)
display(dict_client)


dict_dask_store = DistributedArrayStore(zarr_store, client=dict_client)

for k in [subgroup_dict]:
    with suppress(KeyError):
        del dict_dask_store[k]

learner = MiniBatchDictionaryLearning(
    n_components=len(dict_dask_store[subgroup_dict_init_dict]),
    alpha=0.2,
    n_iter=100,
    fit_algorithm="lars",
    batch_size=50,
    shuffle=False,
    dict_init=dict_dask_store[subgroup_dict_init_dict],
    transform_algorithm="omp",
    transform_n_nonzero_coefs=None,
    transform_alpha=0.01,
    verbose=False
)

learner.fit(dict_dask_store[subgroup_dict_init_data])

# Reshape the flat dictionary back to image-shaped bases and store it.
dictionary = learner.components_
dictionary = dictionary.reshape((dictionary.shape[0],) + dict_dask_store[subgroup_norm].shape[1:])
dictionary = dictionary.persist()

dict_dask_store[subgroup_dict] = dictionary
del dictionary

# Rebind the learner's dictionary to the stored copy (frees the in-memory one).
learner.components_ = dict_dask_store[subgroup_dict].reshape(learner.components_.shape)

dask.distributed.progress(
    dict_dask_store[subgroup_dict],
    notebook=False
)
print("")

# Tear down the dedicated dictionary-learning cluster.
del dict_dask_store
shutdown_distributed(dict_client)
del dict_client

In [ ]:
# Clear any previous sparse code.
for key in [subgroup_code]:
    with suppress(KeyError):
        del dask_store[key]

da_imgs = dask_store[subgroup_norm]
num_pixels = int(np.prod(da_imgs.shape[1:]))
da_imgs_mtx = da_imgs.reshape(da_imgs.shape[0], num_pixels)

# Flatten the stored dictionary back to (components, pixels) for encoding.
learner.components_ = dask_store[subgroup_dict].reshape(
    (dask_store[subgroup_dict].shape[0], num_pixels)
)

# Sparse-code every frame, then transpose to (components, frames).
code = learner.transform(da_imgs_mtx).T.persist()

dask_store[subgroup_code] = code

del learner
del code

dask.distributed.progress(
    dask_store[subgroup_code],
    notebook=False
)
print("")

In [ ]:
import ipywidgets

# Interactive browser: dictionary basis image (left) and its sparse-code
# time series (right). Widget state is display-only; no downstream results
# depend on the slider position.
fig = plt.figure()
fig.add_subplot(1,2,1)
im = plt.imshow(dask_store[subgroup_dict][0])
fig.add_subplot(1,2,2)
line, = plt.plot(dask_store[subgroup_code][0], lw=2)
plt.show()

@ipywidgets.interact(i=ipywidgets.IntSlider(min=0, max=len(dask_store[subgroup_dict])-1, step=1, value=0))
def show_basis_code_plts(i):
    # Swap in the i-th basis image and its code trace, then redraw lazily.
    im.set_array(dask_store[subgroup_dict][i])
    line.set_data(np.arange(len(dask_store[subgroup_code][i])), dask_store[subgroup_code][i])
    fig.canvas.draw_idle()

Postprocessing

  • significance_threshold (float): number of standard deviations below which to include in "noise" estimate
  • wavelet_scale (int): scale of wavelet transform to apply (should be the same as the one used above)
  • noise_threshold (float): number of units of "noise" above which something needs to be to be significant
  • accepted_region_shape_constraints (dict): if ROIs don't match this, reduce the wavelet_scale once.
  • percentage_pixels_below_max (float): upper bound on ratio of ROI pixels not at max intensity vs. all ROI pixels
  • min_local_max_distance (float): minimum allowable Euclidean distance between the maximum-intensity points of two ROIs
  • accepted_neuron_shape_constraints (dict): shape constraints for ROI to be kept.

  • alignment_min_threshold (float): minimum similarity between the intensity images of two ROIs required to merge them.

  • overlap_min_threshold (float): minimum similarity between the masks of two ROIs required to merge them.

In [ ]:
# Tunable postprocessing parameters (documented in the markdown above).
significance_threshold = 3.0
wavelet_scale = 3
noise_threshold = 3.0
percentage_pixels_below_max = 0.8
min_local_max_distance = 16.0

alignment_min_threshold = 0.6
overlap_min_threshold = 0.6


# Clear postprocessing results from any previous run, then recreate the
# (empty) group so fresh results have somewhere to go.
for k in zarr_store.get(subgroup_post, {}).keys():
    with suppress(KeyError):
        del dask_store[subgroup_post + "/" + k]
with suppress(KeyError):
    del zarr_store[subgroup_post]
zarr_store.require_group(subgroup_post)


# Load the stored dictionary from disk and rechunk so each basis image is
# its own block -- postprocessing operates per-image.
imgs = dask_store._diskstore[subgroup_dict]
da_imgs = da.from_array(imgs, chunks=((1,) + imgs.shape[1:]))

# Segment ROIs out of each basis image.  Tunable values come from the
# constants above; the remaining entries are fixed shape/size constraints.
result = block_postprocess_data_parallel(client)(da_imgs,
                              **{
                                    "wavelet_denoising" : {
                                        "estimate_noise" : {
                                            "significance_threshold" : significance_threshold
                                        },
                                        "wavelet.transform" : {
                                            "scale" : wavelet_scale
                                        },
                                        "significant_mask" : {
                                            "noise_threshold" : noise_threshold
                                        },
                                        "accepted_region_shape_constraints" : {
                                            "major_axis_length" : {
                                                "min" : 0.0,
                                                "max" : 25.0
                                            }
                                        },
                                        "remove_low_intensity_local_maxima" : {
                                            "percentage_pixels_below_max" : percentage_pixels_below_max
                                        },
                                        "remove_too_close_local_maxima" : {
                                            "min_local_max_distance" : min_local_max_distance
                                        },
                                        "accepted_neuron_shape_constraints" : {
                                            "area" : {
                                                "min" : 25,
                                                "max" : 600
                                            },
                                            "eccentricity" : {
                                                "min" : 0.0,
                                                "max" : 0.9
                                            }
                                        }
                                    },
                                    "merge_neuron_sets" : {
                                        "alignment_min_threshold" : alignment_min_threshold,
                                        "overlap_min_threshold" : overlap_min_threshold,
                                        "fuse_neurons" : {
                                            "fraction_mean_neuron_max_threshold" : 0.01
                                        }
                                    }
                              }
)

# Store postprocessing results (one dataset per field of the structured
# result array).
dask_store.update(dict(zip(
    ["%s/%s" % (subgroup_post, e) for e in result.dtype.names],
    [result[e] for e in result.dtype.names]
)))

# Block until every stored postprocessing dataset has finished computing.
dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store["%s/%s" % (subgroup_post, e)]
        for e in result.dtype.names
    ]),
    notebook=False
)
print("")

ROI and trace extraction


In [ ]:
# Remove any stale exported ROI file and previously stored ROI datasets.
dask_io_remove(data_basename + postfix_rois + h5_ext, client)
for k in [subgroup_rois_masks, subgroup_rois_masks_j, subgroup_rois_labels, subgroup_rois_labels_j, subgroup_rois]:
    with suppress(KeyError):
        del dask_store[k]
with suppress(KeyError):
    del zarr_store[subgroup_rois]
zarr_store.require_group(subgroup_rois)


da_roi_masks = dask_store[subgroup_post_mask]

# Label values 1..N for the N ROI masks (0 is background).
da_lbls = da.arange(
    1,
    len(da_roi_masks) + 1,
    chunks=da_roi_masks.chunks[0],
    dtype=np.uint64
)
# Broadcast each label over its mask and collapse to one label image; where
# ROIs overlap, the larger label wins (max along the ROI axis).
da_lblimg = (
    da_lbls[(slice(None),) + (da_roi_masks.ndim - 1) * (None,)] * 
    da_roi_masks.astype(np.uint64)
).max(axis=0)

# Store the boolean masks and label image, plus uint8 copies (the "_j"
# variants -- presumably for tools needing 8-bit data; NOTE(review): uint8
# labels wrap past 255 ROIs, confirm that is acceptable downstream).
dask_store.update(dict(zip(
    [subgroup_rois_masks, subgroup_rois_masks_j, subgroup_rois_labels, subgroup_rois_labels_j],
    [da_roi_masks, da_roi_masks.astype(numpy.uint8), da_lblimg, da_lblimg.astype(numpy.uint8)]
)))

# Block until all four ROI datasets have finished computing.
dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store[e] for e in
    [subgroup_rois_masks, subgroup_rois_masks_j, subgroup_rois_labels, subgroup_rois_labels_j]
    ]),
    notebook=False
)
print("")


# Export the ROI datasets to a standalone HDF5 file.
with h5py.File(data_basename + postfix_rois + h5_ext, "w") as f2:
    for k in [subgroup_rois_masks, subgroup_rois_masks_j, subgroup_rois_labels, subgroup_rois_labels_j]:
        zarr.copy(dask_store._diskstore[k], f2)


# Clear the previous traces export and any stored trace dataset.
dask_io_remove(data_basename + postfix_traces + h5_ext, client)
with suppress(KeyError):
    del dask_store[subgroup_traces]


# Extract one trace per ROI from the `sub`-stage images using the stored
# ROI masks, and store the result.
da_sub_imgs = dask_store[subgroup_sub]
da_trace_masks = dask_store[subgroup_rois_masks]
dask_store[subgroup_traces] = compute_traces(da_sub_imgs, da_trace_masks)

# Wait for the stored traces to finish computing.
dask.distributed.progress(dask_store[subgroup_traces], notebook=False)
print("")


# Export the traces to a standalone HDF5 file.
with h5py.File(data_basename + postfix_traces + h5_ext, "w") as f2:
    zarr.copy(dask_store._diskstore[subgroup_traces], f2)


# View results: play the movie with the ROI label image overlaid.
# (Removed a dead `imgs_min, imgs_max = 0, 100` initialization that was
# unconditionally overwritten below before any use.)

da_imgs = dask_store[subgroup_sub]

# Compute the global display range on the cluster.
da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()

status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")

imgs_min, imgs_max = [s.result() for s in status]

mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)

# Overlay the ROI labels, masking background (label 0) so it stays
# transparent.
lblimg = dask_store[subgroup_rois_labels].compute()
lblimg_msk = numpy.ma.masked_array(lblimg, mask=(lblimg==0))

mplsv.viewer.matshow(lblimg_msk, alpha=0.3, cmap=mpl.cm.jet)


# Release any large in-memory results.  Assigning None first guarantees the
# names exist, so the dels below never raise NameError regardless of which
# earlier cells actually ran.
mskimg = None
mskimg_j = None
lblimg = None
traces = None
traces_j = None

del mskimg
del mskimg_j
del lblimg
del traces
del traces_j

End of workflow. Shutdown cluster.


In [ ]:
import dask.distributed
from nanshe_workflow.par import shutdown_distributed

# Drop our handle on the store before tearing down the cluster (it may not
# exist if earlier cells were skipped).
try:
    del dask_store
except NameError:
    pass

# Fetch the active client rather than relying on a possibly stale variable.
client = dask.distributed.client.default_client()

shutdown_distributed(client)

# Display the (now shut down) client as the cell output.
client

Prepare interactive projection graph


In [ ]:
import numpy
import numpy as np

import bokeh.plotting
import bokeh.plotting as bp

import bokeh.io
import bokeh.io as bio

import bokeh.embed
import bokeh.embed as be

from bokeh.models.mappers import LinearColorMapper

import webcolors

from bokeh.models import CustomJS, ColumnDataSource, HoverTool
from bokeh.models.layouts import Row

from builtins import (
    map as imap,
    range as irange
)

from past.builtins import basestring

import nanshe_workflow
from nanshe_workflow.data import io_remove, open_zarr
from nanshe_workflow.vis import (
    get_rgb_array, get_rgba_array, get_all_greys, masks_to_contours_2d,
    generate_cdn, write_html,
)
from nanshe_workflow.util import gzip_compress, hash_file, indent

In [ ]:
# Load the final results fully into memory for visualization.
mskimg = zarr_store[subgroup_rois_masks][...]

traces = zarr_store[subgroup_traces][...]

# Bug fix: the mean and max projections were previously loaded from each
# other's datasets (imgproj_mean read "proj/max" and imgproj_max read
# "proj/mean"); each projection now reads its own dataset.
imgproj_mean = zarr_store[subgroup_proj_mean][...]
imgproj_max = zarr_store[subgroup_proj_max][...]
imgproj_std = zarr_store[subgroup_proj_std][...]

Result visualization

  • proj_img (str or list of str): which projection or projections to plot (e.g. "max", "mean", "std").
  • block_size (int): size of each point on any dimension in the image in terms of pixels.
  • roi_alpha (float): transparency of the ROIs in a range of [0.0, 1.0].
  • roi_border_width (int): width of the line border on each ROI.


  • trace_plot_width (int): width of the trace plot.

In [ ]:
# Visualization parameters (see the markdown above for full descriptions).
proj_img = "std"  # which projection(s) to plot: "max", "mean", and/or "std"
block_size = 1  # pixels per image point along each dimension
roi_alpha = 0.3  # ROI transparency, in [0.0, 1.0]
roi_border_width = 3  # line width of each ROI's contour border
trace_plot_width = 500  # width of the trace plot


# Reset the current Bokeh document so re-running this cell starts clean.
bio.curdoc().clear()

# Greyscale color mapper for rendering the projection images.
grey_cm = LinearColorMapper(get_all_greys())

# One distinct hex color per ROI.
roi_hex_colors = [
    webcolors.rgb_to_hex(rgb) for rgb in get_rgb_array(len(mskimg)).tolist()
]

# Contour coordinates for every ROI mask, downcast to the smallest integer
# dtype that can index a frame, to keep the embedded plot data small.
contour_ys, contour_xs = masks_to_contours_2d(mskimg)
coord_dtype = np.min_scalar_type(max(mskimg.shape[1:]) - 1)
contour_ys = [np.array(pts, dtype=coord_dtype) for pts in contour_ys]
contour_xs = [np.array(pts, dtype=coord_dtype) for pts in contour_xs]

# Shared data source: the projection plots draw these patches, and the
# selection callback reads the per-ROI colors from it.
mskctr_srcs = ColumnDataSource(data=dict(x=contour_xs, y=contour_ys, color=roi_hex_colors))


# Normalize proj_img to a list of projection names.
proj_img = [proj_img] if isinstance(proj_img, basestring) else list(proj_img)


# All projection figures share the frame's aspect ratio, scaled by
# block_size; figures are collected into plot_projs for the final layout.
proj_plot_width = block_size*mskimg.shape[2]
proj_plot_height = block_size*mskimg.shape[1]
plot_projs = []

def _add_proj_plot(img, title):
    """Build one projection figure with ROI contours overlaid and register it.

    Creates a Bokeh figure showing ``img`` (flipped so row 0 renders at the
    top), draws the ROI contour patches over it, applies the dark theme, and
    appends the figure to ``plot_projs``.  Returns the figure.

    This replaces three copy-pasted blocks (max/mean/std) that differed only
    in image and title; the max block sized dw/dh from the projection while
    the others used mskimg's shape -- the helper now consistently uses the
    projection image's own shape (the values agree since projections share
    the frame shape).
    """
    fig = bp.Figure(plot_width=proj_plot_width, plot_height=proj_plot_height,
                    x_range=[0, mskimg.shape[2]], y_range=[mskimg.shape[1], 0],
                    tools=["tap", "pan", "box_zoom", "wheel_zoom", "save", "reset"],
                    title=title, border_fill_color="black")
    fig.image(image=[numpy.flipud(img)], x=[0], y=[mskimg.shape[1]],
              dw=[img.shape[1]], dh=[img.shape[0]], color_mapper=grey_cm)
    fig.patches('x', 'y', source=mskctr_srcs, alpha=roi_alpha, line_width=roi_border_width, color="color")

    fig.outline_line_color = "white"
    for ax in fig.axis:
        ax.axis_line_color = "white"

    plot_projs.append(fig)
    return fig


if "max" in proj_img:
    plot_max = _add_proj_plot(imgproj_max, "Max Projection with ROIs")


if "mean" in proj_img:
    plot_mean = _add_proj_plot(imgproj_mean, "Mean Projection with ROIs")


if "std" in proj_img:
    plot_std = _add_proj_plot(imgproj_std, "Std Dev Projection with ROIs")


# Pack the full trace matrix into ColumnDataSources so the exported HTML is
# self-contained.  The traces are gzip-compressed here; the JS selection
# callback inflates them in the browser on first use.
#
# `traces_dtype` holds a single zero of the traces' dtype: it doubles as a
# decoded-yet flag (0 means still compressed) and, on the JS side, exposes
# the matching typed-array constructor.
all_tr_dtype_srcs = ColumnDataSource(data=dict(traces_dtype=traces.dtype.type(0)[None]))
all_tr_shape_srcs = ColumnDataSource(data=dict(traces_shape=traces.shape))
all_tr_srcs = ColumnDataSource(data=dict(
    traces=numpy.frombuffer(
        gzip_compress(traces.tobytes()),
        dtype=np.uint8
    )
))
# Holds only the currently selected traces; populated by the JS callback.
tr_srcs = ColumnDataSource(data=dict(times_sel=[], traces_sel=[], colors_sel=[]))
# Trace figure sized to span the full trace length and value range.
plot_tr = bp.Figure(plot_width=trace_plot_width, plot_height=proj_plot_height,
                    x_range=(0.0, float(traces.shape[1])), y_range=(float(traces.min()), float(traces.max())),
                    tools=["pan", "box_zoom", "wheel_zoom", "save", "reset"], title="ROI traces",
                    background_fill_color="black", border_fill_color="black")
plot_tr.multi_line("times_sel", "traces_sel", source=tr_srcs, color="colors_sel")

plot_tr.outline_line_color = "white"
for i in irange(len(plot_tr.axis)):
    plot_tr.axis[i].axis_line_color = "white"

plot_projs.append(plot_tr)


# When the user taps ROIs in a projection plot: inflate the gzipped traces
# once (pako), then fill `tr_srcs` with the selected ROIs' time axes,
# traces, and colors so the trace plot updates.
# NOTE(review): the callback reads `cb_obj['1d'].indices`, an older Bokeh
# selection API -- confirm against the pinned Bokeh version before upgrading.
mskctr_srcs.selected.js_on_change("indices", CustomJS(
    args=dict(
        mskctr_srcs=mskctr_srcs,
        all_tr_dtype_srcs=all_tr_dtype_srcs,
        all_tr_shape_srcs=all_tr_shape_srcs,
        all_tr_srcs=all_tr_srcs,
        tr_srcs=tr_srcs
    ), code="""
    var range = function(n){ return Array.from(Array(n).keys()); };

    var traces_not_decoded = (all_tr_dtype_srcs.data['traces_dtype'] == 0);
    var traces_dtype = all_tr_dtype_srcs.data['traces_dtype'].constructor;
    var traces_shape = all_tr_shape_srcs.data['traces_shape'];
    var trace_len = traces_shape[1];
    var traces = all_tr_srcs.data['traces'];
    if (traces_not_decoded) {
        traces = window.pako.inflate(traces);
        traces = new traces_dtype(traces.buffer);
        all_tr_srcs.data['traces'] = traces;
        all_tr_dtype_srcs.data['traces_dtype'] = 1;
    }

    var inds = cb_obj['1d'].indices;
    var colors = mskctr_srcs.data['color'];
    var selected = tr_srcs.data;

    var times = range(trace_len);

    selected['times_sel'] = [];
    selected['traces_sel'] = [];
    selected['colors_sel'] = [];

    for (i = 0; i < inds.length; i++) {
        var inds_i = inds[i];
        var trace_i = traces.slice(trace_len*inds_i, trace_len*(inds_i+1));
        var color_i = colors[inds_i];

        selected['times_sel'].push(times);
        selected['traces_sel'].push(trace_i);
        selected['colors_sel'].push(color_i);
    }

    tr_srcs.change.emit();
"""))


# Lay the projection figure(s) and the trace plot out side by side.
plot_group = Row(*plot_projs)


# Clear out the old HTML file before writing a new one.
io_remove(data_basename + postfix_html + html_ext)


# Embed the plots and assemble the script includes: the Bokeh CDN plus pako
# (needed in the browser to inflate the gzipped traces), both pinned with
# subresource-integrity hashes.  The two-stage .format fills the src/
# integrity attributes first, then the version and hash values.
script, div = be.components(plot_group)
cdn = "\n" + generate_cdn("sha384") + "\n"
cdn += """
<script type="text/javascript" src="{url}" integrity="{integrity}" crossorigin="anonymous"></script>
""".format(
    url="https://cdnjs.cloudflare.com/ajax/libs/pako/{ver}/pako_inflate.min.js",
    integrity="{sha_type}-{sha_value}",
).format(
    ver="1.0.6",
    sha_type="sha384",
    sha_value="vfctOCT+kAyhRRvZr0t63Ktb6zOZrCbLW9CIyQr9G4UMhKAabPpM3iDOI2lnXsX4"
)
cdn += "\n"

# Write the standalone HTML report.
write_html(data_basename + postfix_html + html_ext, data_basename + postfix_html, div, script, cdn)


# Preview the exported HTML inline (slightly taller than the plots to leave
# room for the frame chrome).
from IPython.display import display, IFrame
display(IFrame(data_basename + postfix_html + html_ext, "100%", 1.05*proj_plot_height))

In [ ]:
# Test teardown. Ignore warnings during production runs.

%run ./teardown_tests.py

In [ ]: