In [ ]:
# Test setup. Ignore warnings during production runs.
%run ./setup_tests.py
In [ ]:
# Workflow input configuration: where the data lives and what it is called.
data_dir = ""
data = "data.tif"
data_basename = "data"
# HDF5 dataset name used when the input is an .h5 file.
dataset = "images"
import os
data_ext = os.path.splitext(data)[1].lower()
data_dir = os.path.abspath(data_dir)
# Zarr subgroup names, one per pipeline stage (raw -> trim -> denoise ->
# register -> subtract -> normalize -> dictionary -> code -> postprocess ->
# ROIs/traces/projections).
subgroup_raw = "raw"
subgroup_trim = "trim"
subgroup_dn = "dn"
subgroup_reg = "reg"
subgroup_reg_images = "reg/images"
subgroup_reg_shifts = "reg/shifts"
subgroup_sub = "sub"
subgroup_norm = "norm"
subgroup_dict_init_data = "dict_init_data"
subgroup_dict_init_dict = "dict_init_dict"
subgroup_dict = "dict"
subgroup_code = "code"
subgroup_post = "post"
subgroup_post_mask = "post/mask"
subgroup_rois = "rois"
subgroup_rois_masks = "rois/masks"
subgroup_rois_masks_j = "rois/masks_j"
subgroup_rois_labels = "rois/labels"
subgroup_rois_labels_j = "rois/labels_j"
subgroup_traces = "traces"
subgroup_proj = "proj"
subgroup_proj_hmean = "proj/hmean"
subgroup_proj_max = "proj/max"
subgroup_proj_mean = "proj/mean"
subgroup_proj_std = "proj/std"
# Filename postfixes and extensions for exported artifacts.
# NOTE(review): postfix_html is "_proj", not "_html" -- presumably the HTML
# export is a projection page; confirm against the export cell.
postfix_rois = "_rois"
postfix_traces = "_traces"
postfix_html = "_proj"
h5_ext = os.path.extsep + "h5"
tiff_ext = os.path.extsep + "tif"
zarr_ext = os.path.extsep + "zarr"
html_ext = os.path.extsep + "html"
In [ ]:
import os
from psutil import cpu_count
# Distributed cluster configuration; empty "ip" lets the scheduler pick.
cluster_kwargs = {
    "ip": ""
}
client_kwargs = {}
# Adaptive scaling: up to CORES-1 workers (one core reserved for the
# scheduler/client); CORES env var overrides the detected count.
adaptive_kwargs = {
    "minimum": 0,
    "maximum": int(os.environ.get("CORES", cpu_count())) - 1
}
In [ ]:
import zarr
from nanshe_workflow.data import DistributedDirectoryStore
# Open (or create) the on-disk zarr group that backs every pipeline stage.
zarr_store = zarr.open_group(DistributedDirectoryStore(data_basename + zarr_ext), "a")
In [ ]:
from nanshe_workflow.par import startup_distributed
from nanshe_workflow.data import DistributedArrayStore
# Start the distributed client and wrap the zarr group so dask arrays can be
# stored to / loaded from it by subgroup name.
client = startup_distributed(0, cluster_kwargs, client_kwargs, adaptive_kwargs)
dask_store = DistributedArrayStore(zarr_store, client=client)
client
In [ ]:
client.cluster
In [ ]:
%matplotlib notebook
import matplotlib
import matplotlib.cm
import matplotlib.pyplot
import matplotlib as mpl
import matplotlib.pyplot as plt
from mplview.core import MatplotlibViewer as MPLViewer
In [ ]:
import ctypes
import logging
import os
import importlib
import sys
from builtins import (
map as imap,
range as irange
)
from past.builtins import basestring
try:
from contextlib import suppress
except ImportError:
from contextlib2 import suppress
import numpy
import h5py
import numpy as np
import h5py as hp
import dask
import dask.array
import dask.array.fft
import dask.distributed
import dask.array as da
try:
from dask.highlevelgraph import HighLevelGraph
except ImportError:
import dask.sharedict as HighLevelGraph
try:
from dask.array import blockwise as da_blockwise
except ImportError:
from dask.array import atop as da_blockwise
import dask_image
import dask_image.imread
import dask_image.ndfilters
import dask_image.ndfourier
import zarr
import nanshe
from nanshe.imp.segment import generate_dictionary
import nanshe_workflow
import nanshe_workflow._reg_joblib
from nanshe_workflow.data import io_remove, dask_io_remove, dask_load_hdf5, dask_store_zarr, zip_zarr, open_zarr
# Force blosc compression to run single-threaded, locally and on every
# worker, so it does not oversubscribe cores already used by dask workers.
zarr.blosc.set_nthreads(1)
zarr.blosc.use_threads = False
client.run(zarr.blosc.set_nthreads, 1)
client.run(setattr, zarr.blosc, "use_threads", False)
logging.getLogger("nanshe").setLevel(logging.INFO)
In [ ]:
from nanshe_workflow.data import DistributedDirectoryStore
from nanshe_workflow.data import hdf5_to_zarr, zarr_to_hdf5
from nanshe_workflow.data import save_tiff
In [ ]:
try:
import pyfftw.interfaces.dask_fft as dask_fft
except ImportError:
import dask.array.fft as dask_fft
In [ ]:
from nanshe_workflow.reg import fourier_shift_wrap, compute_offset, roll_frames
In [ ]:
from nanshe_workflow.imp2 import extract_f0, wavelet_transform, renormalized_images, normalize_data
from nanshe_workflow.par import halo_block_generate_dictionary_parallel
from nanshe_workflow.imp import block_postprocess_data_parallel
In [ ]:
from nanshe_workflow.proj import compute_traces
from nanshe_workflow.proj import compute_adj_harmonic_mean_projection
from nanshe_workflow.proj import norm_layer
In [ ]:
# Chunking for ingest: 100 frames per chunk, each frame kept whole.
block_chunks = (100, -1, -1)

# Clear any stale copy of the raw data before re-ingesting.
for k in [subgroup_raw]:
    with suppress(KeyError):
        del dask_store[k]

# Pick a lazy loader based on the input file extension.
if data_ext == tiff_ext:
    dask_ingest_func = lambda data: dask_image.imread.imread(data, nframes=block_chunks[0])
elif data_ext == h5_ext:
    dask_ingest_func = lambda data: dask_load_hdf5(data, dataset=dataset, chunks=block_chunks)
else:
    # Fail fast with a clear message instead of the opaque
    # "TypeError: 'NoneType' object is not callable" the old None default gave.
    raise ValueError("Unsupported input extension: %r" % data_ext)

# ``data`` may be one filename or a sequence of filenames to concatenate
# along the frame axis.
if isinstance(data, basestring):
    dask_store[subgroup_raw] = dask_ingest_func(data)
else:
    dask_store[subgroup_raw] = da.concatenate(list(imap(dask_ingest_func, data)))
dask.distributed.progress(dask_store[subgroup_raw], notebook=False)
In [ ]:
# View results: compute the global intensity range on the cluster, then show
# the raw frames scaled to it.
imgs_min, imgs_max = 0, 100
da_imgs = dask_store[subgroup_raw]
da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()
status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")
imgs_min, imgs_max = [s.result() for s in status]
mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)
In [ ]:
# Number of frames to drop from the start and end of the recording.
front = 1
back = 1
# Clear any stale trimmed dataset.
for k in [subgroup_trim]:
    with suppress(KeyError):
        del dask_store[k]
# Load and prep data for computation.
da_imgs = dask_store[subgroup_raw]
# Trim frames from front and back
# NOTE(review): assumes front + back < len(da_imgs); an over-large trim
# silently yields an empty array -- confirm upstream.
da_imgs_trim = da_imgs[front:len(da_imgs)-back]
# Store trimmed data
dask_store[subgroup_trim] = da_imgs_trim
# Check progress of store step
dask.distributed.progress(dask_store[subgroup_trim], notebook=False)
print("")
# View results
imgs_min, imgs_max = 0, 100
da_imgs = dask_store[subgroup_trim]
da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()
status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")
imgs_min, imgs_max = [s.result() for s in status]
mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)
In [ ]:
# Spatial median filter size (per frame; the time axis is not filtered).
med_filt_size = 3
# Clear any stale denoised dataset.
for k in [subgroup_dn]:
    with suppress(KeyError):
        del dask_store[k]
# Load and prep data for computation.
da_imgs = dask_store[subgroup_trim]
# Median filter frames
# Footprint is (1, med_filt_size, med_filt_size, ...): size 1 on the frame
# axis so each frame is filtered independently.
da_imgs_filt = dask_image.ndfilters.median_filter(
    da_imgs, (1,) + (da_imgs.ndim - 1) * (med_filt_size,)
)
# Reset minimum to original value.
da_imgs_min = da_imgs.min()
da_imgs_filt_min = da_imgs_filt.min()
da_imgs_filt += da_imgs_min - da_imgs_filt_min
# Store denoised data
da_imgs_min, da_imgs_filt, da_imgs_filt_min = dask.persist(da_imgs_min, da_imgs_filt, da_imgs_filt_min)
dask_store[subgroup_dn] = da_imgs_filt
del da_imgs_min, da_imgs_filt, da_imgs_filt_min
# Check progress of store step
dask.distributed.progress(dask_store[subgroup_dn], notebook=False)
print("")
# View results
imgs_min, imgs_max = 0, 100
da_imgs = dask_store[subgroup_dn]
da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()
status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")
imgs_min, imgs_max = [s.result() for s in status]
mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)
In [ ]:
# FFT-based rigid registration of the denoised frames.
#   num_reps:       max template-refinement iterations.
#   tmpl_hist_wght: EMA weight given to the previous template.
#   thld_rel_dist:  stop when mean relative shift change falls to this value.
num_reps = 5
tmpl_hist_wght = 0.25
thld_rel_dist = 0.0
# Clear any stale registration results.
for k in [subgroup_reg_images, subgroup_reg_shifts]:
    with suppress(KeyError):
        del dask_store[k]
with suppress(KeyError):
    del zarr_store[subgroup_reg]
zarr_store.require_group(subgroup_reg)
# Load and prep data for computation.
da_imgs = dask_store[subgroup_dn]
da_imgs_flt = da_imgs
# Promote to at least float32 so FFT math is done in floating point.
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)
# Create frame shape arrays
frame_shape = np.array(da_imgs_flt.shape[1:], dtype=int)
half_frame_shape = frame_shape // 2
frame_shape = da.asarray(frame_shape)
half_frame_shape = da.asarray(half_frame_shape)
# Find the inverse of each frame
# Shifting by (min - 1) keeps every value >= 1 so the reciprocal is finite.
da_imgs_flt_min = da_imgs_flt.min()
da_imgs_inv = dask.array.reciprocal(da_imgs_flt - (da_imgs_flt_min - 1))
# Compute the FFT of inverse frames and template (spatial axes only).
da_imgs_fft = dask_fft.rfftn(da_imgs_inv, axes=tuple(irange(1, da_imgs_flt.ndim)))
da_imgs_fft_tmplt = da_imgs_fft.mean(axis=0, keepdims=True)
# Initialize
i = 0
avg_rel_dist = 1.0
tmpl_hist_wght = da_imgs_flt.dtype.type(tmpl_hist_wght)
# One (ndim-1)-vector of integer shifts per frame, initially zero.
da_shifts = da.zeros(
    (len(da_imgs_fft), da_imgs_fft.ndim - 1),
    dtype=int,
    chunks=(da_imgs_fft.chunks[0], (da_imgs_fft.ndim - 1,))
)
# Persist FFT of frames and template
da_imgs_flt_min, da_imgs_fft, da_imgs_fft_tmplt = client.persist([
    da_imgs_flt_min, da_imgs_fft, da_imgs_fft_tmplt
])
del da_imgs_flt_min
# Alternate between shifting frames to the template and re-estimating the
# template until the shifts stop changing (or num_reps is hit).
while avg_rel_dist > thld_rel_dist and i < num_reps:
    # Compute the shifted frames (phase shift applied in Fourier space).
    da_shifted_frames = da_blockwise(
        fourier_shift_wrap,
        (0,) + tuple(irange(1, da_imgs_fft.ndim)),
        da_imgs_fft,
        (0,) + tuple(irange(1, da_imgs_fft.ndim)),
        da_shifts,
        (0, da_imgs_fft.ndim),
        dtype=da_imgs_fft.dtype
    )
    # Compute the template FFT as an EMA of old template and new mean frame.
    da_imgs_fft_tmplt = (
        tmpl_hist_wght * da_imgs_fft_tmplt +
        (1 - tmpl_hist_wght) * da_shifted_frames.mean(axis=0, keepdims=True)
    )
    # Persist the updated FFT template
    da_shifted_frames, da_imgs_fft_tmplt = client.persist([
        da_shifted_frames, da_imgs_fft_tmplt
    ])
    del da_shifted_frames
    # Find the best overlap with the template.
    # NOTE(review): the product uses the plain (non-conjugated) template --
    # presumably fourier_shift_wrap/compute_offset account for the sign
    # convention; confirm against nanshe_workflow.reg.
    da_overlap = dask_fft.irfftn(
        da_imgs_fft * da_imgs_fft_tmplt,
        s=da_imgs_flt.shape[1:],
        axes=tuple(irange(1, da_imgs_flt.ndim))
    )
    da_overlap_max = da_overlap.max(axis=tuple(irange(1, da_imgs_flt.ndim)), keepdims=True)
    da_overlap_max_match = (da_overlap == da_overlap_max)
    # Clear FFT overlap intermediates
    del da_overlap_max
    # Compute the shift for each frame.
    old_da_shifts = da_shifts
    da_raw_shifts = da_blockwise(
        compute_offset,
        (0, da_overlap_max_match.ndim),
        da_overlap_max_match.rechunk(dict(enumerate(da_overlap_max_match.shape[1:], 1))),
        tuple(irange(0, da_overlap_max_match.ndim)),
        dtype=int,
        new_axes={da_overlap_max_match.ndim: da_overlap_max_match.ndim - 1}
    )
    # Free connected persisted values
    del da_overlap_max_match
    # Remove any collective frame drift (mean shift over all frames).
    da_drift = da_raw_shifts.mean(axis=0, keepdims=True).round().astype(da_shifts.dtype)
    da_shifts = da_raw_shifts - da_drift
    # Clear drift corrected shifts
    del da_drift
    # Find shift change, normalized by frame size and dimensionality so the
    # result is a dimensionless per-frame relative distance.
    diff_da_shifts = da_shifts - old_da_shifts
    rel_diff_da_shifts = (
        diff_da_shifts.astype(da_imgs_flt.dtype) /
        frame_shape.astype(da_imgs_flt.dtype) /
        np.sqrt(da_imgs_flt.dtype.type(len(frame_shape)))
    )
    rel_dist_da_shifts = da.sqrt(da.square(rel_diff_da_shifts).sum(axis=1))
    avg_rel_dist = rel_dist_da_shifts.sum() / da_imgs_flt.dtype.type(len(da_shifts))
    # Free old shifts
    del old_da_shifts
    # Persist statistics related to shift change
    da_overlap, da_raw_shifts, diff_da_shifts, rel_diff_da_shifts, rel_dist_da_shifts, avg_rel_dist = client.persist([
        da_overlap, da_raw_shifts, diff_da_shifts, rel_diff_da_shifts, rel_dist_da_shifts, avg_rel_dist
    ])
    del da_overlap
    del da_raw_shifts
    del diff_da_shifts
    del rel_diff_da_shifts
    del rel_dist_da_shifts
    # Compute change
    dask.distributed.progress(avg_rel_dist, notebook=False)
    print("")
    avg_rel_dist = avg_rel_dist.compute()
    i += 1
    # Show change
    print((i, avg_rel_dist))
# Drop unneeded items
del frame_shape
del half_frame_shape
del da_imgs_flt
del da_imgs_inv
del da_imgs_fft
del da_imgs_fft_tmplt
# Roll all parts to clip to one side
# Keep origin static
da_imgs_shifted = roll_frames(
    da_imgs,
    da.clip(da_shifts, None, 0)
)
# Truncate all frames to smallest one (largest |shift| per axis is cut off).
da_imgs_trunc_shape = da.asarray(da_imgs.shape[1:]) - abs(da_shifts).max(axis=0)
da_imgs_trunc_shape = da_imgs_trunc_shape.compute()
da_imgs_trunc_cut = tuple(map(
    lambda s: slice(None, s), da_imgs_trunc_shape
))
da_imgs_trunc = da_imgs_shifted[(slice(None),) + da_imgs_trunc_cut]
# Free raw data
del da_imgs
# Store registered data
dask_store.update({
    subgroup_reg_images: da_imgs_trunc,
    subgroup_reg_shifts: da_shifts,
})
dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store[subgroup_reg_images],
        dask_store[subgroup_reg_shifts]
    ]),
    notebook=False
)
print("")
# Free truncated frames and shifts
del da_imgs_trunc
del da_shifts
# View results
imgs_min, imgs_max = 0, 100
da_imgs = dask_store[subgroup_reg_images]
da_shifts = dask_store[subgroup_reg_shifts]
da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()
status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")
imgs_min, imgs_max = [s.result() for s in status]
# Plot each shift component over time.
# NOTE(review): plt.subplots returns a bare Axes (not an array) when
# nrows == 1, so axs[i] assumes at least 2 spatial dimensions -- confirm.
fig, axs = plt.subplots(nrows=da_shifts.shape[1], sharex=True)
fig.subplots_adjust(hspace=0.0)
for i in range(da_shifts.shape[1]):
    axs[i].plot(np.asarray(da_shifts[:, i]))
    axs[i].set_ylabel("%s (px)" % chr(ord("X") + da_shifts.shape[1] - i - 1))
    axs[i].yaxis.set_tick_params(width=1.5)
    [v.set_linewidth(2) for v in axs[i].spines.values()]
axs[-1].set_xlabel("Frame (#)")
axs[-1].set_xlim((0, da_shifts.shape[0] - 1))
axs[-1].xaxis.set_tick_params(width=1.5)
mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)
In [ ]:
# Compute and store per-pixel projections (adjusted harmonic mean, max,
# mean, std) of the registered movie.
for k in [subgroup_proj_hmean, subgroup_proj_max, subgroup_proj_mean, subgroup_proj_std]:
    with suppress(KeyError):
        del dask_store[k]
with suppress(KeyError):
    del zarr_store[subgroup_proj]
zarr_store.require_group(subgroup_proj)
# Load and prep data for computation.
da_imgs = dask_store[subgroup_reg_images]
da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)
da_imgs_proj_hmean = compute_adj_harmonic_mean_projection(da_imgs_flt)
da_imgs_proj_max = da_imgs_flt.max(axis=0)
da_imgs_proj_mean, da_imgs_proj_std = da_imgs_flt.mean(axis=0), da_imgs_flt.std(axis=0)
# Store projections (keys and values are in matching order).
dask_store.update(dict(zip(
    [subgroup_proj_hmean, subgroup_proj_max, subgroup_proj_mean, subgroup_proj_std],
    [da_imgs_proj_hmean, da_imgs_proj_max, da_imgs_proj_mean, da_imgs_proj_std]
)))
dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store[subgroup_proj_hmean],
        dask_store[subgroup_proj_max],
        dask_store[subgroup_proj_mean],
        dask_store[subgroup_proj_std]
    ]),
    notebook=False
)
print("")
In [ ]:
# Background subtraction: remove the adjusted-harmonic-mean projection from
# every frame, then shift so the minimum is zero.
for k in [subgroup_sub]:
    with suppress(KeyError):
        del dask_store[k]
# Load and prep data for computation.
da_imgs = dask_store[subgroup_reg_images]
da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)
da_imgs_sub = da_imgs_flt - compute_adj_harmonic_mean_projection(da_imgs_flt)
da_imgs_sub -= da_imgs_sub.min()
# Store background removed data
dask_store[subgroup_sub] = da_imgs_sub
dask.distributed.progress(dask_store[subgroup_sub], notebook=False)
print("")
# View results
imgs_min, imgs_max = 0, 100
da_imgs = dask_store[subgroup_sub]
da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()
status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")
imgs_min, imgs_max = [s.result() for s in status]
mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)
In [ ]:
# Normalization: zero each frame's minimum, then renormalize frame-by-frame.
for k in [subgroup_norm]:
    with suppress(KeyError):
        del dask_store[k]
# Load and prep data for computation.
da_imgs = dask_store[subgroup_sub]
da_imgs_flt = da_imgs
if not (issubclass(da_imgs_flt.dtype.type, np.floating) and
        da_imgs_flt.dtype.itemsize >= 4):
    da_imgs_flt = da_imgs_flt.astype(np.float32)
# Per-frame minima (keepdims so the subtraction broadcasts over each frame).
da_imgs_flt_mins = da_imgs_flt.min(
    axis=tuple(irange(1, da_imgs_flt.ndim)),
    keepdims=True
)
da_imgs_flt_shift = da_imgs_flt - da_imgs_flt_mins
da_result = renormalized_images(da_imgs_flt_shift)
# Store normalized data
dask_store[subgroup_norm] = da_result
dask.distributed.progress(dask_store[subgroup_norm], notebook=False)
print("")
# View results
imgs_min, imgs_max = 0, 100
da_imgs = dask_store[subgroup_norm]
da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()
status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")
imgs_min, imgs_max = [s.result() for s in status]
mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)
n_components (int): number of basis images in the dictionary.
batchsize (int): minibatch size to use.
iters (int): number of iterations to run before getting dictionary.
lambda1 (float): weight for L1 sparsity enforcement on sparse code.
lambda2 (float): weight for L2 sparsity enforcement on sparse code.
block_frames (int): number of frames to work with in each full frame block (run in parallel).
In [ ]:
import functools
import itertools
import logging
import operator
import sys
from builtins import filter as ifilter
import toolz
import toolz.itertoolz
import sklearn
from sklearn.base import BaseEstimator
from sklearn.decomposition.dict_learning import SparseCodingMixin
try:
from contextlib import ExitStack, suppress, redirect_stdout, redirect_stderr
except ImportError:
from contextlib2 import ExitStack, suppress, redirect_stdout, redirect_stderr
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
def to_tuple(*args):
    """Collect all positional arguments and return them as a tuple."""
    # ``*args`` already arrives packed as a tuple, so hand it straight back.
    return args
def func_log_stdoe(func):
    """Wrap ``func`` so output it prints is forwarded to the worker loggers.

    Everything written to stdout/stderr during the call is captured and,
    once the call finishes (even if it raises), logged at INFO level via the
    ``distributed.worker.stdout`` / ``distributed.worker.stderr`` loggers so
    it is visible from the distributed scheduler.
    """
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        with ExitStack() as stack:
            out, err = StringIO(), StringIO()
            stack.enter_context(redirect_stdout(out))
            # BUG FIX: this previously redirected stdout a second time
            # (redirect_stdout(err)), leaving stderr uncaptured and the
            # ``out`` buffer dead; stderr must go through redirect_stderr.
            stack.enter_context(redirect_stderr(err))
            try:
                return func(*args, **kwargs)
            finally:
                logging.getLogger("distributed.worker.stdout").info(out.getvalue())
                logging.getLogger("distributed.worker.stderr").info(err.getvalue())
    return wrapped
def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
                         return_code=True, dict_init=None, callback=None,
                         batch_size=3, verbose=False, shuffle=True, n_jobs=1,
                         method='lars', iter_offset=0, random_state=None,
                         return_inner_stats=False, inner_stats_A=None,
                         inner_stats_B=None, return_n_iter=False,
                         positive_dict=False, positive_code=False):
    """Run ``sklearn.decomposition.dict_learning_online`` on one data chunk.

    Thin wrapper that (a) takes the inner statistics as two separate arrays
    (A, B) so they can be passed as individual dask task keys, (b) forces a
    sequential joblib backend, and (c) unpacks the result into
    ``(components_, inner_stats_A_, inner_stats_B_)``.

    NOTE(review): the unpacking assumes callers use ``return_code=False`` and
    ``return_inner_stats=True`` (as ``MiniBatchDictionaryLearning.fit`` does);
    other flag combinations change sklearn's return shape -- confirm.
    NOTE(review): ``sklearn.externals.joblib`` is the pre-0.21 location of
    joblib; this pins the workflow to an old scikit-learn.
    """
    # Reassemble the (A, B) pair sklearn expects, or None to start fresh.
    if inner_stats_A is None or inner_stats_B is None:
        inner_stats = None
    else:
        inner_stats = (inner_stats_A, inner_stats_B)
    # Sequential backend: parallelism is handled by dask, not joblib.
    with sklearn.externals.joblib.parallel_backend("sequential"):
        # dict_init must be writeable+owned; dask may hand us a read-only view.
        result = sklearn.decomposition.dict_learning_online(X, n_components=n_components, alpha=alpha,
                                                            n_iter=n_iter, return_code=return_code,
                                                            dict_init=np.require(dict_init, requirements="OW"),
                                                            callback=callback, batch_size=batch_size, verbose=verbose,
                                                            shuffle=shuffle, n_jobs=n_jobs, method=method,
                                                            iter_offset=iter_offset, random_state=random_state,
                                                            return_inner_stats=return_inner_stats,
                                                            inner_stats=inner_stats,
                                                            return_n_iter=return_n_iter,
                                                            positive_dict=positive_dict,
                                                            positive_code=positive_code)
    components_, inner_stats_ = result
    inner_stats_A_, inner_stats_B_ = inner_stats_
    return components_, inner_stats_A_, inner_stats_B_
class MiniBatchDictionaryLearning(BaseEstimator, SparseCodingMixin):
    """Dask-backed mini-batch dictionary learning.

    Mirrors ``sklearn.decomposition.MiniBatchDictionaryLearning``'s interface
    but keeps the data, dictionary (``components_``), and inner statistics as
    dask arrays, feeding one data chunk at a time to ``dict_learning_online``
    by splicing task keys directly into a hand-built dask graph.
    """
    def __init__(self,
                 n_components=None,
                 alpha=1,
                 n_iter=1000,
                 fit_algorithm="lars",
                 batch_size=3,
                 shuffle=True,
                 dict_init=None,
                 transform_algorithm="omp",
                 transform_n_nonzero_coefs=None,
                 transform_alpha=None,
                 positive_code=False,
                 positive_dict=False,
                 verbose=False):
        # Hyperparameters (same meanings as in sklearn's estimator).
        self.n_components = n_components
        self.alpha = alpha
        self.n_iter = n_iter
        self.fit_algorithm = fit_algorithm
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.dict_init = dict_init
        self.transform_algorithm = transform_algorithm
        self.transform_n_nonzero_coefs = transform_n_nonzero_coefs
        self.transform_alpha = transform_alpha
        self.positive_code = positive_code
        self.positive_dict = positive_dict
        self.verbose = verbose
        # Seed the dictionary from dict_init (if given) as one single-chunk
        # dask array so its graph exposes exactly one task key.
        self.components_ = self.dict_init
        if self.components_ is not None:
            self.components_ = da.asarray(self.components_)
            self.components_ = self.components_.rechunk(self.components_.shape)
        # Fitting state: cumulative iteration offset and (A, B) statistics.
        self.iter_offset_ = 0
        self.inner_stats_ = None

    def fit(self, X, y=None):
        """Fit the dictionary to ``X``, one ``dict_learning_online`` call per chunk.

        ``X`` is (n_samples, n_features); ``y`` is ignored (sklearn API).
        Returns ``self``.
        """
        X = da.asarray(X)
        n_components = self.n_components
        n_features = X.shape[1]
        components_ = self.components_
        inner_stats_ = self.inner_stats_
        # Row getter against a one-row-per-chunk view (cheap single-row pulls).
        X_get = toolz.curry(
            operator.getitem,
            X.rechunk((1, X.shape[1]))
        )
        if components_ is None:
            # No initial dictionary: sample n_components random rows of X.
            idx = np.random.permutation(X.shape[0])[:n_components]
            components_ = da.stack([X_get(i) for i in idx.flat]).rechunk({0: n_components})
        if self.shuffle:
            # Shuffle rows, then restore the original chunking.
            idx = np.random.permutation(X.shape[0])
            X = da.stack([X_get(i) for i in idx.flat]).rechunk({0: X.chunks[0]})
        if inner_stats_ is None:
            # A: (n_components, n_components) code covariance accumulator;
            # B: (n_features, n_components) data/code product accumulator.
            inner_stats_ = (
                da.zeros((n_components, n_components), chunks=(n_components, n_components)),
                da.zeros((n_features, n_components), chunks=(n_features, n_components))
            )
        func = dict_learning_online
        if self.verbose:
            func = func_log_stdoe(dict_learning_online)
        for X_chk_key in dask.core.flatten(X.__dask_keys__()):
            # The dictionary and stats are single-chunk arrays, so each has
            # exactly one task key to splice into the graph.
            components_key = next(dask.core.flatten(components_.__dask_keys__()))
            if inner_stats_:
                inner_stats_A_key = next(dask.core.flatten(inner_stats_[0].__dask_keys__()))
                inner_stats_B_key = next(dask.core.flatten(inner_stats_[1].__dask_keys__()))
            else:
                inner_stats_ = tuple()
                inner_stats_A_key = None
                inner_stats_B_key = None
            kwargs = dict(
                n_components=n_components,
                alpha=self.alpha,
                n_iter=self.n_iter,
                return_code=False,
                dict_init=components_key,
                callback=None,
                batch_size=self.batch_size,
                verbose=self.verbose,
                shuffle=False,
                n_jobs=1,
                method=self.fit_algorithm,
                iter_offset=self.iter_offset_,
                random_state=None,
                return_inner_stats=True,
                inner_stats_A=inner_stats_A_key,
                inner_stats_B=inner_stats_B_key,
                return_n_iter=False,
                positive_dict=self.positive_dict,
                positive_code=self.positive_code
            )
            # Deterministic keys for this chunk's update and its 3 outputs.
            tok = dask.base.tokenize(X_chk_key, **kwargs)
            components_res_key = ("components_dict_learning_online-%s" % tok, 0, 0)
            inner_stat_A_res_key = ("inner_stat_A_dict_learning_online-%s" % tok, 0, 0)
            inner_stat_B_res_key = ("inner_stat_B_dict_learning_online-%s" % tok, 0, 0)
            dict_learn_key = "dict_learning_online-%s" % tok
            # One task runs dict_learning_online on the chunk; three getitem
            # tasks unpack its (components, A, B) result tuple.
            dct = HighLevelGraph.merge(
                {components_res_key: (
                    operator.getitem, dict_learn_key, 0
                )},
                {inner_stat_A_res_key: (
                    operator.getitem, dict_learn_key, 1
                )},
                {inner_stat_B_res_key: (
                    operator.getitem, dict_learn_key, 2
                )},
                {dict_learn_key:
                    (dask.compatibility.apply, func, [X_chk_key], kwargs),
                },
                X.__dask_graph__(),
                components_.__dask_graph__(),
                *(e.__dask_graph__() for e in inner_stats_)
            )
            dct = da.Array.__dask_optimize__(dct, [
                components_res_key,
                inner_stat_A_res_key,
                inner_stat_B_res_key
            ])
            # Re-wrap the result keys as single-chunk dask arrays for the
            # next loop iteration.
            components_ = da.Array(
                dct, components_res_key[0], ((n_components,), (n_features,)), X.dtype
            )
            inner_stats_ = (
                da.Array(
                    dct, inner_stat_A_res_key[0], ((n_components,), (n_components,)), X.dtype
                ),
                da.Array(
                    dct, inner_stat_B_res_key[0], ((n_features,), (n_components,)), X.dtype
                ),
            )
            # Persist everything after an iteration
            result = dask.persist(components_, *inner_stats_)
            components_, inner_stats_ = result[0], result[1:]
            # NOTE(review): reconstructed indentation places this inside the
            # chunk loop (each chunk advances sklearn's iter_offset) -- confirm.
            self.iter_offset_ += self.n_iter
        self.components_ = components_
        self.inner_stats_ = inner_stats_
        return self

    @staticmethod
    def _sparse_encode_wrapper(*args, **kwargs):
        """Unwrap blockwise's single-element list arguments, then sparse-encode."""
        args = tuple(e[0] if isinstance(e, list) else e for e in args)
        return sklearn.decomposition.sparse_encode(*args, **kwargs)

    def transform(self, X):
        """Sparse-code ``X`` against the learned dictionary.

        Returns the code as a dask array of shape (n_samples, n_components).
        """
        sparse_encode_wrapper = MiniBatchDictionaryLearning._sparse_encode_wrapper
        if self.verbose:
            sparse_encode_wrapper = func_log_stdoe(sparse_encode_wrapper)
        # Precompute the Gram matrix (D D^T) and covariance (D X^T) once.
        gram = da.tensordot(self.components_, self.components_, axes=[[1], [1]])
        cov = da.tensordot(self.components_, X, axes=[[1], [1]])
        # Blockwise index roles: 0 = components, 1 = samples, 2 = features
        # (contracted); output is (samples, components).
        code = da_blockwise(
            sparse_encode_wrapper,
            (1, 0),
            X,
            (1, 2),
            self.components_,
            (0, 2),
            gram,
            (0, 0),
            cov,
            (0, 1),
            dtype=self.components_.dtype,
            algorithm=self.transform_algorithm,
            n_nonzero_coefs=self.transform_n_nonzero_coefs,
            alpha=self.transform_alpha,
            copy_cov=True,
            init=None,
            max_iter=self.n_iter,
            n_jobs=1,
            check_input=False,
            verbose=self.verbose
        )
        gram, cov, code = dask.persist(gram, cov, code)
        return code
In [ ]:
import toolz
# Number of dictionary atoms (basis images) to learn.
n_components = 50
# Clear any stale initialization data.
for k in [subgroup_dict_init_data, subgroup_dict_init_dict]:
    with suppress(KeyError):
        del dask_store[k]
da_imgs = dask_store[subgroup_norm]
# Reshape to matrix and provide frame selector
da_imgs_mtx = da_imgs.reshape(
    da_imgs.shape[0],
    int(np.prod(da_imgs.shape[1:]))
)
da_imgs_mtx_get = toolz.curry(
    operator.getitem,
    da_imgs_mtx.rechunk((1, da_imgs_mtx.shape[1]))
)
# Create shuffled data (shuffled up front so the learner can use shuffle=False)
idx = np.random.permutation(da_imgs_mtx.shape[0])
dict_init_data = da.stack([da_imgs_mtx_get(i) for i in idx.flat])
dict_init_data = dict_init_data.rechunk({0: da_imgs_mtx.chunks[0]})
# Create dictionary subsample (n_components random frames as the seed atoms)
idx = np.random.permutation(da_imgs_mtx.shape[0])[:n_components]
dict_init_dict = da.stack([da_imgs_mtx_get(i) for i in idx.flat])
dict_init_dict = dict_init_dict.rechunk({0: n_components})
# Store shuffled data
dask_store.update(dict(zip(
    [subgroup_dict_init_data, subgroup_dict_init_dict],
    [dict_init_data, dict_init_dict]
)))
dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store[subgroup_dict_init_data],
        dask_store[subgroup_dict_init_dict]
    ]),
    notebook=False
)
print("")
In [ ]:
from nanshe_workflow.par import startup_distributed, shutdown_distributed
from nanshe_workflow.data import DistributedArrayStore
from psutil import cpu_count
ncores = int(os.environ.get("CORES", cpu_count())) - 1
# Dedicated single-worker client for dictionary learning: the one worker
# gets all cores as threads so BLAS can parallelize inside the fit.
dict_client = startup_distributed(
    1,
    cluster_kwargs={
        "threads_per_worker": ncores
    },
    client_kwargs={}
)
# Re-register the joblib backend and cap OpenBLAS threads on the worker.
# NOTE(review): the lambda references ctypes.util, which normally needs an
# explicit ``import ctypes.util`` -- confirm it is loaded on the worker.
dict_client.run(importlib.import_module, "nanshe_workflow._reg_joblib")
dict_client.run(
    lambda n: ctypes.CDLL(ctypes.util.find_library("openblas")).openblas_set_num_threads(n),
    ncores
)
display(dict_client)
dict_dask_store = DistributedArrayStore(zarr_store, client=dict_client)
# Clear any stale dictionary.
for k in [subgroup_dict]:
    with suppress(KeyError):
        del dict_dask_store[k]
learner = MiniBatchDictionaryLearning(
    n_components=len(dict_dask_store[subgroup_dict_init_dict]),
    alpha=0.2,
    n_iter=100,
    fit_algorithm="lars",
    batch_size=50,
    shuffle=False,
    dict_init=dict_dask_store[subgroup_dict_init_dict],
    transform_algorithm="omp",
    transform_n_nonzero_coefs=None,
    transform_alpha=0.01,
    verbose=False
)
learner.fit(dict_dask_store[subgroup_dict_init_data])
# Reshape the learned (atoms x pixels) matrix back to image stacks and store.
dictionary = learner.components_
dictionary = dictionary.reshape((dictionary.shape[0],) + dict_dask_store[subgroup_norm].shape[1:])
dictionary = dictionary.persist()
dict_dask_store[subgroup_dict] = dictionary
del dictionary
# Re-point the learner at the stored dictionary so later transform() reads
# from disk rather than from this client's memory.
learner.components_ = dict_dask_store[subgroup_dict].reshape(learner.components_.shape)
dask.distributed.progress(
    dict_dask_store[subgroup_dict],
    notebook=False
)
print("")
# Tear down the dedicated client.
del dict_dask_store
shutdown_distributed(dict_client)
del dict_client
In [ ]:
# Sparse-code the full normalized movie against the learned dictionary.
for k in [subgroup_code]:
    with suppress(KeyError):
        del dask_store[k]
da_imgs = dask_store[subgroup_norm]
# Flatten frames to a (frames x pixels) matrix.
da_imgs_mtx = da_imgs.reshape(
    da_imgs.shape[0],
    int(np.prod(da_imgs.shape[1:]))
)
# Load the dictionary through this client's store, flattened the same way.
learner.components_ = dask_store[subgroup_dict].reshape(
    (dask_store[subgroup_dict].shape[0], int(np.prod(da_imgs.shape[1:])))
)
code = learner.transform(da_imgs_mtx)
# Transpose to (atoms x frames): one time series per dictionary atom.
code = code.T
code = code.persist()
dask_store[subgroup_code] = code
del learner
del code
dask.distributed.progress(
    dask_store[subgroup_code],
    notebook=False
)
print("")
In [ ]:
import ipywidgets
# Interactive viewer: dictionary atom image (left) and its code time series
# (right), selected by an integer slider.
fig = plt.figure()
fig.add_subplot(1,2,1)
im = plt.imshow(dask_store[subgroup_dict][0])
fig.add_subplot(1,2,2)
line, = plt.plot(dask_store[subgroup_code][0], lw=2)
plt.show()
@ipywidgets.interact(i=ipywidgets.IntSlider(min=0, max=len(dask_store[subgroup_dict])-1, step=1, value=0))
def show_basis_code_plts(i):
    # Swap in atom i's image and code trace, then redraw.
    im.set_array(dask_store[subgroup_dict][i])
    line.set_data(np.arange(len(dask_store[subgroup_code][i])), dask_store[subgroup_code][i])
    fig.canvas.draw_idle()
significance_threshold (float): number of standard deviations below which to include in "noise" estimate.
wavelet_scale (int): scale of wavelet transform to apply (should be the same as the one used above).
noise_threshold (float): number of units of "noise" above which something needs to be to be significant.
accepted_region_shape_constraints (dict): if ROIs don't match this, reduce the wavelet_scale once.
percentage_pixels_below_max (float): upper bound on ratio of ROI pixels not at max intensity vs. all ROI pixels.
min_local_max_distance (float): minimum allowable euclidean distance between two ROIs' maximum intensities.
accepted_neuron_shape_constraints (dict): shape constraints for ROI to be kept.
alignment_min_threshold (float): similarity measure of the intensity of two ROIs images used for merging.
overlap_min_threshold (float): similarity measure of the masks of two ROIs used for merging.
In [ ]:
# Postprocessing parameters (see the markdown cell above for meanings).
significance_threshold = 3.0
wavelet_scale = 3
noise_threshold = 3.0
percentage_pixels_below_max = 0.8
min_local_max_distance = 16.0
alignment_min_threshold = 0.6
overlap_min_threshold = 0.6
# Clear any stale postprocessing results.
for k in zarr_store.get(subgroup_post, {}).keys():
    with suppress(KeyError):
        del dask_store[subgroup_post + "/" + k]
with suppress(KeyError):
    del zarr_store[subgroup_post]
zarr_store.require_group(subgroup_post)
# Read the dictionary straight from disk, one atom per chunk.
imgs = dask_store._diskstore[subgroup_dict]
da_imgs = da.from_array(imgs, chunks=((1,) + imgs.shape[1:]))
# Segment each dictionary atom into candidate ROIs and merge similar ones.
result = block_postprocess_data_parallel(client)(da_imgs,
    **{
        "wavelet_denoising" : {
            "estimate_noise" : {
                "significance_threshold" : significance_threshold
            },
            "wavelet.transform" : {
                "scale" : wavelet_scale
            },
            "significant_mask" : {
                "noise_threshold" : noise_threshold
            },
            "accepted_region_shape_constraints" : {
                "major_axis_length" : {
                    "min" : 0.0,
                    "max" : 25.0
                }
            },
            "remove_low_intensity_local_maxima" : {
                "percentage_pixels_below_max" : percentage_pixels_below_max
            },
            "remove_too_close_local_maxima" : {
                "min_local_max_distance" : min_local_max_distance
            },
            "accepted_neuron_shape_constraints" : {
                "area" : {
                    "min" : 25,
                    "max" : 600
                },
                "eccentricity" : {
                    "min" : 0.0,
                    "max" : 0.9
                }
            }
        },
        "merge_neuron_sets" : {
            "alignment_min_threshold" : alignment_min_threshold,
            "overlap_min_threshold" : overlap_min_threshold,
            "fuse_neurons" : {
                "fraction_mean_neuron_max_threshold" : 0.01
            }
        }
    }
)
# Store each field of the structured result under post/<field>.
dask_store.update(dict(zip(
    ["%s/%s" % (subgroup_post, e) for e in result.dtype.names],
    [result[e] for e in result.dtype.names]
)))
dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store["%s/%s" % (subgroup_post, e)]
        for e in result.dtype.names
    ]),
    notebook=False
)
print("")
In [ ]:
# Export ROIs (masks + label images) and per-ROI traces; mirror both to HDF5.
dask_io_remove(data_basename + postfix_rois + h5_ext, client)
for k in [subgroup_rois_masks, subgroup_rois_masks_j, subgroup_rois_labels, subgroup_rois_labels_j, subgroup_rois]:
    with suppress(KeyError):
        del dask_store[k]
with suppress(KeyError):
    del zarr_store[subgroup_rois]
zarr_store.require_group(subgroup_rois)
da_roi_masks = dask_store[subgroup_post_mask]
# Labels 1..N, one per ROI mask, chunked to match the masks.
da_lbls = da.arange(
    1,
    len(da_roi_masks) + 1,
    chunks=da_roi_masks.chunks[0],
    dtype=np.uint64
)
# Label image: each pixel gets the highest label of any ROI covering it.
da_lblimg = (
    da_lbls[(slice(None),) + (da_roi_masks.ndim - 1) * (None,)] *
    da_roi_masks.astype(np.uint64)
).max(axis=0)
# Store masks/labels plus uint8 ("_j") variants for image viewers.
dask_store.update(dict(zip(
    [subgroup_rois_masks, subgroup_rois_masks_j, subgroup_rois_labels, subgroup_rois_labels_j],
    [da_roi_masks, da_roi_masks.astype(numpy.uint8), da_lblimg, da_lblimg.astype(numpy.uint8)]
)))
dask.distributed.progress(
    dask.distributed.futures_of([
        dask_store[e] for e in
        [subgroup_rois_masks, subgroup_rois_masks_j, subgroup_rois_labels, subgroup_rois_labels_j]
    ]),
    notebook=False
)
print("")
# Mirror the ROI arrays into an HDF5 file for external tools.
with h5py.File(data_basename + postfix_rois + h5_ext, "w") as f2:
    for k in [subgroup_rois_masks, subgroup_rois_masks_j, subgroup_rois_labels, subgroup_rois_labels_j]:
        zarr.copy(dask_store._diskstore[k], f2)
dask_io_remove(data_basename + postfix_traces + h5_ext, client)
for k in [subgroup_traces]:
    with suppress(KeyError):
        del dask_store[k]
# Load and prep data for computation.
da_images = dask_store[subgroup_sub]
da_masks = dask_store[subgroup_rois_masks]
da_result = compute_traces(da_images, da_masks)
# Store traces
dask_store[subgroup_traces] = da_result
dask.distributed.progress(dask_store[subgroup_traces], notebook=False)
print("")
with h5py.File(data_basename + postfix_traces + h5_ext, "w") as f2:
    zarr.copy(dask_store._diskstore[subgroup_traces], f2)
# View results
imgs_min, imgs_max = 0, 100
da_imgs = dask_store[subgroup_sub]
da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()
status = client.compute([da_imgs_min, da_imgs_max])
dask.distributed.progress(status, notebook=False)
print("")
imgs_min, imgs_max = [s.result() for s in status]
mplsv = plt.figure(FigureClass=MPLViewer)
mplsv.set_images(
    da_imgs,
    vmin=imgs_min,
    vmax=imgs_max
)
# Overlay the ROI labels (transparent outside ROIs) on the movie viewer.
lblimg = dask_store[subgroup_rois_labels].compute()
lblimg_msk = numpy.ma.masked_array(lblimg, mask=(lblimg==0))
mplsv.viewer.matshow(lblimg_msk, alpha=0.3, cmap=mpl.cm.jet)
# Release references that may linger from interactive re-runs.
# NOTE(review): assigning None then del is redundant -- del alone suffices.
mskimg = None
mskimg_j = None
lblimg = None
traces = None
traces_j = None
del mskimg
del mskimg_j
del lblimg
del traces
del traces_j
In [ ]:
import dask.distributed
from nanshe_workflow.par import shutdown_distributed
# Release the Zarr-backed Dask store before tearing the cluster down;
# NameError is tolerated in case the store cell was never run.
try:
    del dask_store
except NameError:
    pass
# Look up whichever client is currently registered as the default and shut
# down its cluster cleanly.
client = dask.distributed.client.default_client()
shutdown_distributed(client)
client
In [ ]:
import numpy
import numpy as np
import bokeh.plotting
import bokeh.plotting as bp
import bokeh.io
import bokeh.io as bio
import bokeh.embed
import bokeh.embed as be
from bokeh.models.mappers import LinearColorMapper
import webcolors
from bokeh.models import CustomJS, ColumnDataSource, HoverTool
from bokeh.models.layouts import Row
from builtins import (
map as imap,
range as irange
)
from past.builtins import basestring
import nanshe_workflow
from nanshe_workflow.data import io_remove, open_zarr
from nanshe_workflow.vis import (
get_rgb_array, get_rgba_array, get_all_greys, masks_to_contours_2d,
generate_cdn, write_html,
)
from nanshe_workflow.util import gzip_compress, hash_file, indent
In [ ]:
# Pull the ROI masks, traces, and image projections fully into memory for
# the interactive Bokeh visualization below.
mskimg = zarr_store[subgroup_rois_masks][...]
traces = zarr_store[subgroup_traces][...]
# BUG FIX: the mean and max projections were loaded from each other's
# subgroups (imgproj_mean <- proj/max, imgproj_max <- proj/mean), so the
# "Max Projection" plot displayed the mean data and vice versa.  Load each
# variable from its matching subgroup.
imgproj_mean = zarr_store[subgroup_proj_mean][...]
imgproj_max = zarr_store[subgroup_proj_max][...]
imgproj_std = zarr_store[subgroup_proj_std][...]
Parameters:

- `proj_img` (str or list of str): which projection or projections to plot (e.g. "max", "mean", "std").
- `block_size` (int): size of each point on any dimension in the image in terms of pixels.
- `roi_alpha` (float): transparency of the ROIs in a range of [0.0, 1.0].
- `roi_border_width` (int): width of the line border on each ROI.
- `trace_plot_width` (int): width of the trace plot.
In [ ]:
# Plotting parameters (documented in the markdown cell above).
proj_img = "std"
block_size = 1
roi_alpha = 0.3
roi_border_width = 3
trace_plot_width = 500
# Start from a clean Bokeh document.
bio.curdoc().clear()
# Greyscale color mapper for the projection images.
grey_range = get_all_greys()
grey_cm = LinearColorMapper(grey_range)
# One distinct hex color per ROI.
colors_rgb = get_rgb_array(len(mskimg))
colors_rgb = colors_rgb.tolist()
colors_rgb = list(imap(webcolors.rgb_to_hex, colors_rgb))
# Trace each ROI mask as a 2-D contour.  Coordinates are stored in the
# smallest integer dtype able to index the image, keeping the embedded
# HTML payload small.  (mskimg is indexed as (roi, y, x) here -- presumably
# one 2-D mask per ROI; confirm against the extraction step.)
mskctr_pts_y, mskctr_pts_x = masks_to_contours_2d(mskimg)
mskctr_pts_dtype = np.min_scalar_type(max(mskimg.shape[1:]) - 1)
mskctr_pts_y = [np.array(_, dtype=mskctr_pts_dtype) for _ in mskctr_pts_y]
mskctr_pts_x = [np.array(_, dtype=mskctr_pts_dtype) for _ in mskctr_pts_x]
mskctr_srcs = ColumnDataSource(data=dict(x=mskctr_pts_x, y=mskctr_pts_y, color=colors_rgb))
# Normalize proj_img to a list of projection names.
if isinstance(proj_img, basestring):
    proj_img = [proj_img]
else:
    proj_img = list(proj_img)
# Figure size in screen pixels: block_size pixels per image pixel.
proj_plot_width = block_size*mskimg.shape[2]
proj_plot_height = block_size*mskimg.shape[1]
plot_projs = []


def _make_proj_plot(title, img):
    """Build one greyscale projection figure with the ROI contours overlaid.

    The figure construction was previously triplicated (max/mean/std) with a
    small inconsistency: the max plot sized ``dw``/``dh`` from the projection
    image while the others used the mask shape.  All three now use the
    projection image's own shape (these match whenever the projections and
    masks share the image's height/width, as they are produced from the same
    stack).

    Parameters
    ----------
    title : str
        Figure title shown above the plot.
    img : numpy.ndarray
        2-D projection image to display.

    Returns
    -------
    bokeh.plotting.Figure
    """
    fig = bp.Figure(plot_width=proj_plot_width, plot_height=proj_plot_height,
                    x_range=[0, mskimg.shape[2]], y_range=[mskimg.shape[1], 0],
                    tools=["tap", "pan", "box_zoom", "wheel_zoom", "save", "reset"],
                    title=title, border_fill_color="black")
    # Bokeh's image origin is bottom-left; flip vertically so row 0 renders
    # at the top, matching the inverted y_range above.
    fig.image(image=[numpy.flipud(img)], x=[0], y=[mskimg.shape[1]],
              dw=[img.shape[1]], dh=[img.shape[0]], color_mapper=grey_cm)
    # ROI outlines; selectable via the "tap" tool (drives the trace plot).
    fig.patches('x', 'y', source=mskctr_srcs, alpha=roi_alpha, line_width=roi_border_width, color="color")
    fig.outline_line_color = "white"
    for ax in fig.axis:
        ax.axis_line_color = "white"
    return fig


if "max" in proj_img:
    plot_max = _make_proj_plot("Max Projection with ROIs", imgproj_max)
    plot_projs.append(plot_max)
if "mean" in proj_img:
    plot_mean = _make_proj_plot("Mean Projection with ROIs", imgproj_mean)
    plot_projs.append(plot_mean)
if "std" in proj_img:
    plot_std = _make_proj_plot("Std Dev Projection with ROIs", imgproj_std)
    plot_projs.append(plot_std)
# Ship the full trace matrix to the browser gzip-compressed, to be inflated
# lazily in JavaScript on the first ROI selection.  Three sources carry:
#   * a one-element array of the traces dtype whose value 0 marks
#     "not yet decoded" (the JS callback flips it to 1 after inflating),
#   * the array shape (n_rois, n_timepoints),
#   * the compressed bytes as uint8.
all_tr_dtype_srcs = ColumnDataSource(data=dict(traces_dtype=traces.dtype.type(0)[None]))
all_tr_shape_srcs = ColumnDataSource(data=dict(traces_shape=traces.shape))
all_tr_srcs = ColumnDataSource(data=dict(
    traces=numpy.frombuffer(
        gzip_compress(traces.tobytes()),
        dtype=np.uint8
    )
))
# Holds only the currently selected traces; populated by the JS callback.
tr_srcs = ColumnDataSource(data=dict(times_sel=[], traces_sel=[], colors_sel=[]))
# Trace figure: x spans the full time axis, y spans the traces' value range.
plot_tr = bp.Figure(plot_width=trace_plot_width, plot_height=proj_plot_height,
                    x_range=(0.0, float(traces.shape[1])), y_range=(float(traces.min()), float(traces.max())),
                    tools=["pan", "box_zoom", "wheel_zoom", "save", "reset"], title="ROI traces",
                    background_fill_color="black", border_fill_color="black")
plot_tr.multi_line("times_sel", "traces_sel", source=tr_srcs, color="colors_sel")
plot_tr.outline_line_color = "white"
for i in irange(len(plot_tr.axis)):
    plot_tr.axis[i].axis_line_color = "white"
plot_projs.append(plot_tr)
# When the ROI selection changes on any projection plot, this client-side
# callback inflates the gzip-compressed trace matrix (once -- the decoded
# typed array is cached back into all_tr_srcs and the dtype marker flipped
# to 1) and copies each selected ROI's trace, time axis, and color into
# tr_srcs, which drives the multi_line trace plot.
# NOTE(review): `cb_obj['1d'].indices` is the legacy pre-1.0 Bokeh selection
# API, and the loop counter `i` is an implicit JS global (missing `var`) --
# confirm both against the pinned Bokeh version before upgrading it.
mskctr_srcs.selected.js_on_change("indices", CustomJS(
    args=dict(
        mskctr_srcs=mskctr_srcs,
        all_tr_dtype_srcs=all_tr_dtype_srcs,
        all_tr_shape_srcs=all_tr_shape_srcs,
        all_tr_srcs=all_tr_srcs,
        tr_srcs=tr_srcs
    ), code="""
var range = function(n){ return Array.from(Array(n).keys()); };
var traces_not_decoded = (all_tr_dtype_srcs.data['traces_dtype'] == 0);
var traces_dtype = all_tr_dtype_srcs.data['traces_dtype'].constructor;
var traces_shape = all_tr_shape_srcs.data['traces_shape'];
var trace_len = traces_shape[1];
var traces = all_tr_srcs.data['traces'];
if (traces_not_decoded) {
traces = window.pako.inflate(traces);
traces = new traces_dtype(traces.buffer);
all_tr_srcs.data['traces'] = traces;
all_tr_dtype_srcs.data['traces_dtype'] = 1;
}
var inds = cb_obj['1d'].indices;
var colors = mskctr_srcs.data['color'];
var selected = tr_srcs.data;
var times = range(trace_len);
selected['times_sel'] = [];
selected['traces_sel'] = [];
selected['colors_sel'] = [];
for (i = 0; i < inds.length; i++) {
var inds_i = inds[i];
var trace_i = traces.slice(trace_len*inds_i, trace_len*(inds_i+1));
var color_i = colors[inds_i];
selected['times_sel'].push(times);
selected['traces_sel'].push(trace_i);
selected['colors_sel'].push(color_i);
}
tr_srcs.change.emit();
"""))
# Lay the projection plot(s) and the trace plot out in a single row.
plot_group = Row(*plot_projs)
# Clear out the old HTML file before writing a new one.
io_remove(data_basename + postfix_html + html_ext)
# Split the Bokeh document into its <script> and <div> parts for embedding.
script, div = be.components(plot_group)
# CDN header: Bokeh's own resources plus pako (used by the JS callback to
# inflate the gzip-compressed traces), each with a subresource-integrity
# hash.  The template is formatted twice: the first .format fills in the
# URL/integrity skeletons, the second fills in the version and SHA values.
cdn = "\n" + generate_cdn("sha384") + "\n"
cdn += """
<script type="text/javascript" src="{url}" integrity="{integrity}" crossorigin="anonymous"></script>
""".format(
    url="https://cdnjs.cloudflare.com/ajax/libs/pako/{ver}/pako_inflate.min.js",
    integrity="{sha_type}-{sha_value}",
).format(
    ver="1.0.6",
    sha_type="sha384",
    sha_value="vfctOCT+kAyhRRvZr0t63Ktb6zOZrCbLW9CIyQr9G4UMhKAabPpM3iDOI2lnXsX4"
)
cdn += "\n"
# Write the standalone HTML file and show it inline in the notebook.
write_html(data_basename + postfix_html + html_ext, data_basename + postfix_html, div, script, cdn)
from IPython.display import display, IFrame
display(IFrame(data_basename + postfix_html + html_ext, "100%", 1.05*proj_plot_height))
In [ ]:
# Test teardown. Ignore warnings during production runs.
%run ./teardown_tests.py
In [ ]: