What is batch correction? A "batch" arises when experiments are performed at different times and there is a systematic, non-biological difference between them. Single-cell experiments are often inherently "batchy" because you can only perform so many single-cell captures at once, so you end up doing multiple captures, over different days, with different samples. How do you correct for the technical noise without deleting the true biological signal?
First things first: it's best to design your experiments to minimize batch effects. For example, if you can mix your samples so that every single-cell capture contains cells from multiple samples, then biological and technical variance are represented separately across batches, rather than each batch carrying BOTH biological and technical variance at once.
In a confounded design, where each capture holds a single biological sample, it's impossible to know when you correct for batch effects between captures whether you're removing the technical noise of the different captures or the biological signal of the data.
The idea is that, ahead of time, you would mix the cells from your samples in equal proportions and then perform cell capture on the mixed samples. You would still get different technical batches, but they wouldn't be confounded with the biological signal.
With this mixed design, when you correct for batch effects, the technical batches and the biological signals are separate.
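To make the difference between the two designs concrete, here is a minimal sketch that cross-tabulates which mouse ended up in which capture. The annotation tables, the capture names, and the helper `is_confounded` are hypothetical and only for illustration; they are not part of the data used later on.

```python
import pandas as pd

# Hypothetical cell-level annotations: the mouse each cell came from and its capture
confounded = pd.DataFrame({
    "mouse": ["Mouse_01"] * 4 + ["Mouse_02"] * 4,
    "capture": ["Capture_A"] * 4 + ["Capture_B"] * 4,
})
mixed = pd.DataFrame({
    "mouse": ["Mouse_01", "Mouse_02"] * 4,
    "capture": ["Capture_A"] * 4 + ["Capture_B"] * 4,
})


def is_confounded(design):
    """Return True if any capture is missing cells from at least one mouse."""
    counts = pd.crosstab(design["capture"], design["mouse"])
    return bool((counts == 0).values.any())


print("One mouse per capture:", is_confounded(confounded))
print("Mice mixed within captures:", is_confounded(mixed))
```

A zero in the capture-by-mouse table means that capture cannot tell you anything about that mouse, which is exactly the situation where technical and biological differences cannot be pulled apart.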
For example, if you have to harvest your cells at particular timepoints, here are some ways that you can try to mitigate the batch effects:
Okay, so say your data are such that you couldn't have mixed your biological samples ahead of time. What do you do?
There are two main ways to approach batch correction: using groups of samples or groups of features (genes).
This is when you have groups of samples that may have some biological difference between them but also have technical differences between them. Say you performed single-cell capture on several different days from different mice of somewhat overlapping ages. You know that you have the biological signal from the different mice and the different ages, but you also have the technical signal from the different batches. BUT there's no getting around the fact that you had to sacrifice the mice and collect their cells in one batch.
This is when you think particular groups of genes are contributing to the batch effects.
How to find these features:
In [ ]:
from __future__ import print_function
# Interactive Python (IPython - now Jupyter) widgets for interactive exploration
import ipywidgets
# Numerical python library
import numpy as np
# Plotting library
import matplotlib.pyplot as plt
# Dataframes in python
import pandas as pd
# Linear model correction
import patsy
# Even better plotting
import seaborn as sns
# Batch effect correction
# This import statement only works because there's a folder called "combat_py" here, not because there's a module installed
from combat_py.combat import combat
# Use the styles and colors that I like
sns.set(style='white', context='talk', palette='Set2')
%matplotlib inline
In [ ]:
np.random.seed(2016)
n_samples = 10
n_genes = 20
half_genes = int(n_genes/2)
half_samples = int(n_samples/2)
size = n_samples * n_genes
genes = ['Gene_{}'.format(str(i+1).zfill(2)) for i in range(n_genes)]
samples = ['Sample_{}'.format(str(i+1).zfill(2)) for i in range(n_samples)]
data = pd.DataFrame(np.random.randn(size).reshape(n_samples, n_genes), index=samples, columns=genes)
# Add biological variance
data.iloc[:half_samples, :half_genes] += 1
data.iloc[:half_samples, half_genes:] += -1
data.iloc[half_samples:, half_genes:] += 1
data.iloc[half_samples:, :half_genes] += -1
# Biological samples
mouse_groups = pd.Series(dict(zip(data.index, (['Mouse_01'] * int(n_samples/2)) + (['Mouse_02'] * int(n_samples/2)))),
name="Mouse")
mouse_to_color = dict(zip(['Mouse_01', 'Mouse_02'], ['lightgrey', 'black']))
mouse_colors = [mouse_to_color[mouse_groups[x]] for x in samples]
# Gene colors
gene_colors = sns.color_palette('husl', n_colors=n_genes)
In [ ]:
g = sns.clustermap(data, row_colors=mouse_colors, col_cluster=False, row_cluster=False, linewidth=0.5,
col_colors=gene_colors,
cbar_kws=dict(label='Normalized Expression'))
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);
In [ ]:
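def make_tidy(data, sample_groups):
    # Reshape the samples x genes matrix into one row per (gene, sample) pair
    tidy = data.unstack()
    tidy = tidy.reset_index()
    tidy = tidy.rename(columns={'level_0': 'Gene', 'level_1': "Sample", 0: "Normalized Expression"})
    # Attach the group label (e.g. mouse) of each sample
    tidy = tidy.join(sample_groups, on='Sample')
    return tidy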
In [ ]:
tidy = make_tidy(data, mouse_groups)
In [ ]:
fig, ax = plt.subplots()
sns.boxplot(hue='Gene', y='Normalized Expression', data=tidy, x='Mouse')
ax.legend_.set_visible(False)
In [ ]:
# Choose odd-numbered samples to be in batch1 and even numbered samples to be in batch 2
batch1_samples = samples[::2]
batch2_samples = data.index.difference(batch1_samples)
batches = pd.Series(dict((x, 'Batch_01') if x in batch1_samples else (x, "Batch_02") for x in samples), name="Batch")
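# Add random noise for all genes except the last two in each batch
# (one random offset per gene, shared by every sample in the batch)
noisy_data = data.copy()
# .loc replaces the removed .ix indexer; genes[:-2] selects every gene column except the last two
noisy_data.loc[batch1_samples, genes[:-2]] += np.random.normal(size=n_genes-2, scale=2)
noisy_data.loc[batch2_samples, genes[:-2]] += np.random.normal(size=n_genes-2, scale=2)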
# Assign colors for batches
batch_to_color = dict(zip(["Batch_01", "Batch_02"], sns.color_palette()))
batch_colors = [batch_to_color[batches[x]] for x in samples]
row_colors = [mouse_colors, batch_colors]
g = sns.clustermap(noisy_data, row_colors=row_colors, col_cluster=False, row_cluster=False, linewidth=0.5,
col_colors=gene_colors, cbar_kws=dict(label='Normalized Expression'))
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);
We can see that there's some batch effect: for Batch_01, Gene_15 is in general lower and Gene_01 is in general higher, and for Batch_02, Gene_16 is in general higher. But Gene_19 and Gene_20 are unaffected. (In the row-color strips, light grey and black mark the mouse; the batch is shown in the other strip.)
In [ ]:
tidy_noisy = make_tidy(noisy_data, mouse_groups)
tidy_noisy = tidy_noisy.join(batches, on='Sample')
tidy_noisy.head()
Let's plot the boxplots of the data the same way, with the x-axis as the mouse they came from, the y-axis as the normalized expression, and one box per gene.
In [ ]:
fig, ax = plt.subplots()
sns.boxplot(hue='Gene', y='Normalized Expression', data=tidy_noisy, x='Mouse')
ax.legend_.set_visible(False)
We can see that, compared to before, where we had clear differences in gene expression between genes 1-10 and genes 11-20 in the two mice, the difference is much harder to see in the noisy data.
Now let's plot the data a different way, with the x-axis as the batch.
In [ ]:
fig, ax = plt.subplots()
sns.boxplot(hue='Gene', y='Normalized Expression', data=tidy_noisy, x='Batch')
ax.legend_.set_visible(False)
In [ ]:
fig, ax = plt.subplots()
sns.pointplot(hue='Batch', x='Normalized Expression', data=tidy_noisy, y='Gene', orient='horizontal',
              # Map each batch name to its color (batch_colors has one entry per sample, not per batch)
              scale=0.5, palette=batch_to_color)
In [ ]:
fig, ax = plt.subplots()
sns.pointplot(hue='Batch', x='Normalized Expression', data=tidy_noisy, y='Gene', orient='horizontal', scale=0.5)
sns.pointplot(x='Normalized Expression', data=tidy_noisy, y='Gene', orient='horizontal', scale=0.75, color='k',
linestyle=None)
We will use "COMBAT" (ComBat) to get rid of the batch effect. What ComBat does is basically what we just did with our eyes and intuition: it finds genes whose expression differs systematically between batches and adjusts each gene's expression so that every batch sits closer to that gene's overall mean across batches.
(may need to whiteboard here)
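As a rough stand-in for the whiteboard, the sketch below shows only the simplest part of the idea: shifting each batch so that its per-gene mean matches the overall per-gene mean. This is a simplification and not what `combat` itself computes; ComBat also adjusts the spread within each batch and shrinks the per-gene estimates using an empirical Bayes step. The toy sample names, gene names, and batch shifts are made up for illustration.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
samples = ["S{}".format(i) for i in range(6)]
genes = ["G1", "G2", "G3"]
batch = pd.Series(["Batch_01", "Batch_02"] * 3, index=samples, name="Batch")

# Toy expression matrix (samples x genes) with a gene-specific shift added to Batch_02
expression = pd.DataFrame(rng.normal(size=(6, 3)), index=samples, columns=genes)
expression.loc[batch == "Batch_02"] += np.array([2.0, -1.0, 0.0])

# Location-only adjustment: subtract each batch's per-gene mean, add back the overall mean
batch_means = expression.groupby(batch).transform("mean")
centered = expression - batch_means + expression.mean()

# After centering, both batches have the same mean for every gene
print(centered.groupby(batch).mean().round(3))
```

Note that this kind of centering removes any difference between batch means, including biological ones, which is why the confounding of batch with the variable of interest matters so much.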
Create metadata matrix
In [ ]:
metadata = pd.concat([batches, mouse_groups], axis=1)
metadata
In [ ]:
def remove_batch_effects_with_combat(batch, keep_constant=None, cluster_on_correlations=False):
    if keep_constant is None or keep_constant == 'null':
        # No variable of interest to protect
        model = None
    else:
        # We'll use patsy (statistical models in python) to create a "design matrix", which encodes the
        # variable of interest as 0/1 indicator columns so the computer can understand it.
        model = patsy.dmatrix('~ {}'.format(keep_constant), metadata, return_type="dataframe")

    # --- Correct for batch effects --- #
    # combat expects features (genes) as rows and samples as columns, hence the transpose
    corrected_data = combat(noisy_data.T, metadata[batch], model)
    # Transpose so samples are the rows and the features are the columns
    corrected_data = corrected_data.T

    # --- Plot the heatmap --- #
    if cluster_on_correlations:
        g = sns.clustermap(corrected_data.T.corr(), row_colors=row_colors, col_cluster=True, row_cluster=True,
                           linewidth=0.5, vmin=-1, vmax=1, col_colors=row_colors,
                           cbar_kws=dict(label='Pearson R'))
        plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);
    else:
        g = sns.clustermap(corrected_data, row_colors=row_colors, col_cluster=False, row_cluster=False,
                           linewidth=0.5, col_colors=gene_colors,
                           cbar_kws=dict(label='Normalized Expression'))
        plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);
    # Uncomment the line below to save the batch corrected heatmap
    # g.savefig('combat_batch_corrected_clustermap.pdf')

    # --- Quantification of the batch effect correction --- #
    # Create a "tidy" version of the dataframe for plotting
    tidy_corrected = make_tidy(corrected_data, mouse_groups)
    tidy_corrected = tidy_corrected.join(batches, on='Sample')

    # Set up the figure
    # 4 columns of figure panels
    figure_columns = 4
    width = 4.5 * figure_columns
    height = 4
    fig, axes = plt.subplots(ncols=figure_columns, figsize=(width, height))

    # Plot the batchy (noisy) data vs the corrected data
    ax = axes[0]
    ax.plot(noisy_data.values.flat, corrected_data.values.flat, 'o',
            # Everything in the next line is my personal preference so it looks nice
            alpha=0.5, markeredgecolor='k', markeredgewidth=0.5)
    ax.set(xlabel='Original (Batchy) data', ylabel='COMBAT corrected data')

    # Plot the mean gene expression within batch in colors, and the mean gene expression across both batches in black
    ax = axes[1]
    sns.pointplot(hue='Batch', x='Normalized Expression', data=tidy_corrected, y='Gene', orient='horizontal',
                  scale=.5, ax=ax)
    sns.pointplot(x='Normalized Expression', data=tidy_corrected, y='Gene', orient='horizontal',
                  scale=0.75, color='k', linestyle=None, ax=ax)

    # Plot the gene expression distribution per mouse
    ax = axes[2]
    sns.boxplot(hue='Gene', y='Normalized Expression', data=tidy_corrected, x='Mouse', ax=ax,
                # Adjusting linewidth for my personal preference
                linewidth=1)
    # Don't show legend because it's too big
    ax.legend_.set_visible(False)

    # --- Plot boxplots of average difference between gene expression in batches --- #
    # Get mean gene expression within batch for the original noisy data
    mean_batch_expression = noisy_data.groupby(batches).mean()
    noisy_batch_diff = (mean_batch_expression.loc['Batch_01'] - mean_batch_expression.loc['Batch_02']).abs()
    noisy_batch_diff.name = 'mean(|batch1 - batch2|)'
    noisy_batch_diff = noisy_batch_diff.reset_index()
    noisy_batch_diff['Data type'] = 'Noisy'

    # Get mean gene expression within batch for the corrected data
    mean_corrected_batch_expression = corrected_data.groupby(batches).mean()
    corrected_batch_diff = (mean_corrected_batch_expression.loc['Batch_01']
                            - mean_corrected_batch_expression.loc['Batch_02']).abs()
    corrected_batch_diff.name = 'mean(|batch1 - batch2|)'
    corrected_batch_diff = corrected_batch_diff.reset_index()
    corrected_batch_diff['Data type'] = 'Corrected'

    # Compile the two tables into one (concatenate)
    batch_differences = pd.concat([noisy_batch_diff, corrected_batch_diff])
    sns.boxplot(x='Data type', y='mean(|batch1 - batch2|)', data=batch_differences, ax=axes[3])

    # Remove right and top axes lines so it looks nicer
    sns.despine()
    # Magically adjust the figure panels (axes) so they fit nicely
    fig.tight_layout()
    # Uncomment the line below to save the figure of four panels
    # fig.savefig('combat_batch_corrected_panels.pdf')
ipywidgets.interact(
remove_batch_effects_with_combat,
batch=ipywidgets.Dropdown(options=['Mouse', 'Batch'], value="Batch", description='Batch to correct for'),
keep_constant=ipywidgets.Dropdown(value=None, options=[None, 'Mouse', 'Batch', 'Mouse + Batch'],
description='Variable of interest'),
cluster_on_correlations=ipywidgets.Checkbox(value=False, description="Cluster on (Pearson) correlations between samples"));
Try doing these and see how they compare. Do you see similar trends to the original data? Do any of these create errors? Why would that be?
With each of these try turning "Cluster on (Pearson) correlations between samples" on and off. This is a nice way that we can visualize the improvement in reducing the batch-dependent signal.
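One plausible reason some combinations fail is that the batch variable and the variable of interest carry the same information, so their indicator columns are linearly dependent and the model cannot be solved uniquely. The sketch below checks this with a rank test. It uses `pd.get_dummies` rather than patsy, and `metadata_sketch` and `design_is_full_rank` are hypothetical names; how `combat` builds its design matrix internally is not shown here, so treat this as an illustration of the idea rather than a reproduction of its behavior.

```python
import numpy as np
import pandas as pd

# Same layout as the metadata above: batches alternate, mice split in half
metadata_sketch = pd.DataFrame({
    "Batch": ["Batch_01", "Batch_02"] * 5,
    "Mouse": ["Mouse_01"] * 5 + ["Mouse_02"] * 5,
})


def design_is_full_rank(batch, keep_constant):
    """Check that batch indicators plus the variable of interest are linearly independent."""
    batch_part = pd.get_dummies(metadata_sketch[batch]).to_numpy(dtype=float)
    keep_part = pd.get_dummies(metadata_sketch[keep_constant], drop_first=True).to_numpy(dtype=float)
    design = np.column_stack([batch_part, keep_part])
    return bool(np.linalg.matrix_rank(design) == design.shape[1])


print("Correct for Batch, keep Mouse:", design_is_full_rank("Batch", "Mouse"))
print("Correct for Mouse, keep Mouse:", design_is_full_rank("Mouse", "Mouse"))
```

When the check fails, there is no way to decide how much of a difference belongs to the batch and how much to the variable you asked to keep.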
What if there are specific genes or features that are contributing to the batches?
This is the idea behind correcting for cell-cycle genes or some other feature that you know is associated with the data, e.g. the RNA Integrity Number (RIN).
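One possible way to look for such features, sketched here under assumptions, is to correlate every gene with the technical covariate and flag the genes that track it closely. The toy data, the gene names, and the 0.8 cutoff are all arbitrary choices for illustration, not a recommended threshold.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n_samples = 10
samples = ["Sample_{:02d}".format(i + 1) for i in range(n_samples)]
genes = ["Gene_A", "Gene_B", "Gene_C"]

# Hypothetical covariate (e.g. RIN) and a toy expression matrix (samples x genes)
rin = pd.Series(np.arange(n_samples) + 0.5, index=samples, name="RIN")
expression = pd.DataFrame(rng.normal(scale=0.1, size=(n_samples, 3)), index=samples, columns=genes)
expression["Gene_B"] += rin  # Gene_B is built to track the covariate

# Pearson correlation of each gene with the covariate
rin_correlation = expression.corrwith(rin)
flagged = rin_correlation[rin_correlation.abs() > 0.8].index.tolist()

print(rin_correlation.round(2))
print("Genes tracking RIN:", flagged)
```

A high correlation only says the gene moves with the covariate; if the covariate is itself tied to the biology, flagged genes may still be biologically meaningful.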
Let's add some feature-batchy noise to our original data
In [ ]:
metadata['RIN'] = np.arange( len(samples)) + 0.5
metadata
Add noise and plot it. Use the first and last genes as controls that don't have any noise.
In [ ]:
# rin_noise = metadata['RIN'].apply(lambda x: pd.Series(np.random.normal(loc=x, size=n_genes), index=genes))
rin_noise = metadata['RIN'].apply(lambda x: pd.Series(np.ones(n_genes-2)*x, index=genes[1:-1]))
rin_noise = rin_noise.reindex(columns=genes)
rin_noise = rin_noise.fillna(0)
g = sns.clustermap(rin_noise, row_colors=mouse_colors, col_cluster=False, row_cluster=False, linewidth=0.5,
col_colors=gene_colors, cbar_kws=dict(label='RIN Noise'))
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);
Add the noise to the data and re-center so that each gene's mean is approximately zero.
In [ ]:
rin_batchy_data = data + rin_noise
rin_batchy_data
# Renormalize the data so genes are 0-centered
rin_batchy_data = (rin_batchy_data - rin_batchy_data.mean())/rin_batchy_data.std()
g = sns.clustermap(rin_batchy_data, row_colors=mouse_colors, col_cluster=False, row_cluster=False, linewidth=0.5,
col_colors=gene_colors, cbar_kws=dict(label='Normalized Expression'))
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);
If we plot the RIN against the RIN-batchy gene expression, we'll see that this one variable alone produces a linear increase in expression! Of course, we could also have created a variable that linearly decreases expression.
In [ ]:
tidy_rin_batchy = make_tidy(rin_batchy_data, mouse_groups)
tidy_rin_batchy = tidy_rin_batchy.join(metadata['RIN'], on='Sample')
g = sns.FacetGrid(tidy_rin_batchy, hue='Gene')
g.map(plt.plot, 'RIN', 'Normalized Expression', alpha=0.5)
In [ ]:
from __future__ import print_function
import six
from sklearn import linear_model
regressor = linear_model.LinearRegression()
regressor
# Use RIN as the "X" - the "independent" (predictor) variable, the one you expect your gene expression to vary with.
regressor.fit(metadata['RIN'].to_frame(), rin_batchy_data)
# Use RIN to predict gene expression
rin_dependent_data = pd.DataFrame(regressor.predict(metadata['RIN'].to_frame()), columns=genes, index=samples)
rin_dependent_data
from sklearn.metrics import r2_score
# explained_variance = r2_score(rin_batchy_data, rin_dependent_data)
# six.print_("Explained variance by RIN:", explained_variance)
rin_corrected_data = rin_batchy_data - rin_dependent_data
rin_corrected_data
# Somewhat contrived, but try to predict the newly corrected data with RIN
r2_score(rin_corrected_data, rin_dependent_data)
tidy_rin_corrected = make_tidy(rin_corrected_data, mouse_groups)
tidy_rin_corrected = tidy_rin_corrected.join(metadata['RIN'], on="Sample")
tidy_rin_corrected.head()
g = sns.FacetGrid(tidy_rin_corrected, hue='Gene')
g.map(plt.plot, 'RIN', 'Normalized Expression', alpha=0.5)
g = sns.clustermap(rin_corrected_data, row_colors=mouse_colors, col_cluster=False, row_cluster=False, linewidth=0.5,
col_colors=gene_colors, cbar_kws=dict(label='Normalized Expression'))
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);
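# Assign the result to g so that the tick labels of THIS clustermap get rotated, not the previous one
g = sns.clustermap(rin_corrected_data.T.corr(), row_colors=mouse_colors, linewidth=0.5,
                   col_colors=mouse_colors, cbar_kws=dict(label='Pearson R'))
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);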
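Now the data doesn't vary by RIN! But... now we have over-corrected and removed the biological signal as well. This happens because RIN increases with sample order and the first five samples are all Mouse_01, so RIN is confounded with the mouse, and regressing out RIN also regresses out much of the mouse difference.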
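One way to reduce this kind of over-correction, sketched below under assumptions, is to fit the biological variable and RIN in the same linear model and subtract only the RIN term. The simulated data here are a simplified stand-in for the matrices above (one direction of mouse effect, small noise), and the helper `mouse_gap` is a hypothetical name. Whether this works on real data depends on the covariate not being perfectly confounded with the biology.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_genes = 10, 4
rin = np.arange(n_samples) + 0.5  # same RIN layout as above
is_mouse_02 = (np.arange(n_samples) >= n_samples // 2).astype(float)

# Simulated expression: mouse effect (+1 / -1), a linear RIN effect, and small noise
mouse_effect = np.where(is_mouse_02 == 1, -1.0, 1.0)[:, None]
expression = mouse_effect + rin[:, None] + rng.normal(scale=0.05, size=(n_samples, n_genes))


def mouse_gap(x):
    """Absolute difference between the two mice's mean expression, averaged over genes."""
    return np.abs(x[is_mouse_02 == 0].mean(axis=0) - x[is_mouse_02 == 1].mean(axis=0)).mean()


# Naive: regress on RIN alone and keep the residuals
X_naive = np.column_stack([np.ones(n_samples), rin])
beta_naive, *_ = np.linalg.lstsq(X_naive, expression, rcond=None)
naive_corrected = expression - X_naive @ beta_naive

# Joint: fit intercept, mouse, and RIN together, then subtract only the RIN term
X_joint = np.column_stack([np.ones(n_samples), is_mouse_02, rin])
beta_joint, *_ = np.linalg.lstsq(X_joint, expression, rcond=None)
joint_corrected = expression - np.outer(rin - rin.mean(), beta_joint[2])

naive_gap = mouse_gap(naive_corrected)
joint_gap = mouse_gap(joint_corrected)
print("Mouse difference after RIN-only regression:", round(naive_gap, 2))
print("Mouse difference after joint model:", round(joint_gap, 2))
```

The simulated mouse difference is 2, so the closer a printed value is to 2, the more of the biological signal survived the correction.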
As you have seen, dealing with batch effects in single-cell data is supremely difficult and the best thing you can do for yourself is design your experiment nicely so you don't have to.