Image registration

A common problem when working with collections of images is registering, or aligning, them relative to a reference. The thunder-registration package implements a set of registration algorithms, all exposed through a common API. These algorithms support parallelization through Spark, but can also be run locally on numpy arrays. Here, we generate example data for performing registration, apply a registration algorithm, and validate the results.
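For instance, the same workflow can run either locally or on a cluster depending on how the data are loaded. A minimal sketch, assuming a live SparkContext named sc (the engine keyword for selecting a backend is an assumption about thunder's loading functions; everything in this tutorial runs in local mode):

import thunder as td

# pass a SparkContext to load the images as a distributed collection;
# omit `engine` (or pass None) to work with local numpy arrays instead
data_spark = td.images.fromexample('mouse', engine=sc)
data_local = td.images.fromexample('mouse')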

Setup imports


In [1]:
%matplotlib inline

In [3]:
import seaborn as sns
import matplotlib.pyplot as plt
from showit import image, tile
sns.set_style('darkgrid')
sns.set_context('notebook')

In [4]:
import thunder as td

Generating data

We will use a toy example dataset to test the registration algorithms. These data do not actually have any motion, so to test the algorithms, we will induce fake motion. First we'll load and inspect the data.


In [5]:
data = td.images.fromexample('mouse')
data


Out[5]:
Images
mode: local
dtype: int16
shape: (20, 64, 64)

There are 20 images (corresponding to 20 time points), and the data are two-dimensional, so we'll want to generate 20 random shifts in x and y. We'll use a smoothing function from scipy to make sure the drift varies slowly over time, which will make the results easier to look at.


In [38]:
from numpy import random
from scipy.ndimage import gaussian_filter
t = 20
dx = gaussian_filter(random.randn(t), 1.5) * 10
dy = gaussian_filter(random.randn(t), 1.5) * 10

In [39]:
plt.plot(dx);
plt.plot(dy);


Now let's use these drifts to shift the data. We'll use the map method on our data, which applies an arbitrary function to each record; in this case, the function shifts each image by the amount given by the corresponding entry in our list of shifts.


In [40]:
from scipy.ndimage import shift
# with_keys=True passes each record as a (key, value) pair; the tuple key
# indexes our 1-d arrays of shifts just like an integer would
shifted = data.map(lambda kv: shift(kv[1], (dx[kv[0]], dy[kv[0]]), mode='nearest', order=0), with_keys=True)

Let's look at the first entry of both the original and the shifted images, and their difference.


In [41]:
im1 = data[0].toarray()
im2 = shifted[0].toarray()
tile([im1, im2, im1-im2], clim=[(0,300), (0,300), (-300,300)], grid=(1,3), size=14);


It's also useful to look at the mean of the raw images and the mean of the shifted images; the latter should be much blurrier!


In [42]:
tile([data.mean(), shifted.mean()], size=14);


Registration

To run registration, first we create a registration method by importing the algorithm CrossCorr


In [43]:
from registration import CrossCorr
algorithm = CrossCorr()

This method computes a cross-correlation between every image and a reference.
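The core idea is to find the displacement at which an image best lines up with the reference. As a rough illustration (a minimal numpy sketch of FFT-based cross-correlation, not the package's actual implementation):

import numpy as np
from numpy.fft import fft2, ifft2

def estimate_displacement(image, reference):
    # cross-correlate via the FFT and locate the correlation peak
    corr = np.real(ifft2(fft2(image) * np.conj(fft2(reference))))
    peak = np.unravel_index(np.argmax(corr), corr.shape)
    # peaks past the midpoint wrap around to negative displacements
    return [int(p) if p <= s // 2 else int(p) - s for p, s in zip(peak, corr.shape)]

The peak of the correlation surface sits at the displacement of the image relative to the reference, so registering means shifting the image back by that amount. First, we'll compute a reference using the mean of the images.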


In [44]:
reference = shifted.mean().toarray()

In [45]:
image(reference);


Now we fit our registration method to the shifted data, returning a fitted RegistrationModel


In [46]:
model = algorithm.fit(shifted, reference=reference)

Inspect the model


In [47]:
model


Out[47]:
RegistrationModel
length: 20
algorithm: CrossCorr

The model is a dictionary mapping tuple indices to transformations. You can inspect them:


In [48]:
model.transformations[(0,)]


Out[48]:
Displacement(delta=[-4, -2])

You can also convert the full collection of transformations into an array, which is useful for plotting. Here we'll plot the estimated transformations against the ground truth (as dashed lines); they should be fairly similar.


In [49]:
clrs = sns.color_palette('deep')
plt.plot(model.toarray()[:,0], color=clrs[0])
plt.plot(dx, '--', color=clrs[0])
plt.plot(model.toarray()[:,1], color=clrs[1])
plt.plot(dy, '--', color=clrs[1]);


Note that, while following a similar pattern to the ground truth, the estimates are not exact. That's because we didn't use the true reference to estimate the displacements, but rather the mean of the displaced data, which biases the estimated displacements.
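A common remedy, which we won't pursue here, is to iterate: fit against the mean, transform, recompute the mean from the corrected images, and refit. A minimal sketch using only the API calls from this tutorial (the number of passes is an arbitrary choice):

# start from the biased reference and refine it over a few passes
ref = shifted.mean().toarray()
for _ in range(3):
    model = algorithm.fit(shifted, reference=ref)
    # recompute the reference from the newly corrected images
    ref = model.transform(shifted).mean().toarray()

Instead, to see that we get the exact displacements back, let's compute a reference from the original, unshifted data.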


In [50]:
reference = data.mean().toarray()
model = algorithm.fit(shifted, reference=reference)

Now the estimates should be exact (up to rounding error)! But note that this is sort of cheating, because in general we don't know the reference exactly.


In [51]:
plt.plot(model.toarray()[:,0], color=clrs[0])
plt.plot(dx, '--', color=clrs[0])
plt.plot(model.toarray()[:,1], color=clrs[1])
plt.plot(dy, '--', color=clrs[1]);


We can now use our model to transform a set of images, which applies the estimated transformations. The API design makes it easy to apply the transformations either to the dataset we used to estimate them or to a different one. We'll use the model we just estimated, which used the true reference, because it will be easy to see that it did the right thing.


In [53]:
corrected = model.transform(shifted)

Let's again look at the first image from the original and the corrected data, and their difference. Whereas before they were different, now they should be the same, except for minor differences near the boundaries (where the image has been filled in with its nearest neighbors).


In [54]:
im1 = data[0].toarray()
im2 = corrected[0].toarray()
tile([im1, im2, im1-im2], clim=[(0,300), (0,300), (-300,300)], grid=(1,3), size=14);


As a final check on the registration, we can compare the mean of the shifted data and the mean of the registered data. The latter should be much sharper.


In [55]:
tile([shifted.mean(), corrected.mean()], size=14);
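
As one last quantitative check (a sketch that assumes the sign convention of model.toarray() matches the induced shifts, as the plots above suggest), we can compare the estimated displacements against the ground truth numerically:

import numpy as np

# estimated displacements, one row per image, columns matching (dx, dy)
estimated = model.toarray()
truth = np.column_stack([dx, dy])
# should print 0: with order=0 shifting, the induced motion is effectively
# rounded to whole pixels, which CrossCorr recovers exactly
print(np.abs(estimated - np.round(truth)).max())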