Batch correction

What is batch correction? A "batch" is a set of samples that were processed together; when experiments have been performed at different times, there's usually some systematic technical difference between them. Single-cell experiments are often inherently "batchy" because you can only perform so many single-cell captures at once, so you end up doing multiple captures, over different days, with different samples. How do you correct for the technical noise without deleting the true biological signal?

Avoiding batch effects

First things first: it's best to design your experiments to minimize batch effects. For example, if you can mix your samples so that each single-cell capture contains cells from multiple biological samples, then biological and technical variance are spread across batches rather than confounded with each other, and you can tell them apart.

Hicks et al, preprint

Bad: Technical variance is confounded with biological variance

Here, when you try to correct for batch effects between captures, it's impossible to know whether you're removing the technical noise of the different captures, or the biological signal of the data.

Good: Technical variance is separable from biological variance

The idea here is that you would, ahead of time, mix the cells from your samples in equal proportions and then perform cell capture on the mixed samples. You would still get different technical batches, but they wouldn't be confounded with the biological signals.

Here, when you correct for batch effects, the technical batches and biological signals are separate.

If it's completely impossible to do multiple biological samples in the same technical replicate...

For example, if you have to harvest your cells at particular timepoints, here are some ways that you can try to mitigate the batch effects:

  • Repeat the timepoint
  • Save an aliquot of cells from each timepoint and run another experiment with the mixed aliquots

Correcting batch effects

Okay so say your data are such that you couldn't have mixed your biological samples ahead of time. What do you do?

There are two main ways to approach batch correction: using groups of samples or groups of features (genes).

Sample-batchy

This is when you have groups of samples that may have some biological difference between them, but also have technical differences between them. Say you performed single-cell capture on several different days, from different mice of somewhat overlapping ages. You know that you have biological signal from the different mice and the different ages, but you also have technical signal from the different capture batches. BUT there's no getting around the fact that you had to sacrifice each mouse and collect its cells in one batch.

Feature-batchy

This is when you think particular groups of genes are contributing to the batch effects.

How to find these features:

  • Numerical feature (e.g. RIN) associated with each sample (see the sketch just after this list)
  • Cell cycle genes (Buettner et al., 2015)
  • RUVseq - use an external dataset (e.g. bulk samples) to find non-differentially expressed genes and use them to correct between groups
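For the first bullet, one quick screen is to correlate every gene with the numerical covariate and rank genes by the absolute correlation. Below is a minimal, self-contained sketch with made-up data; the names here (`expression`, `rin`) are only for illustration and aren't used elsewhere in this notebook.

In [ ]:
import numpy as np
import pandas as pd

np.random.seed(0)

# Tiny fake example: 6 samples x 4 genes, plus a per-sample RIN value
expression = pd.DataFrame(np.random.randn(6, 4),
                          index=['Sample_{}'.format(i + 1) for i in range(6)],
                          columns=['Gene_{}'.format(i + 1) for i in range(4)])
rin = pd.Series(np.linspace(2, 9, 6), index=expression.index, name='RIN')

# Make Gene_1 artificially track RIN so that something shows up
expression['Gene_1'] += rin / 2

# Correlate every gene with RIN and rank by the strength of the association
rin_correlation = expression.apply(lambda gene: gene.corr(rin))
rin_correlation.abs().sort_values(ascending=False)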

In [ ]:
from __future__ import print_function

# Interactive Python (IPython - now Jupyter) widgets for interactive exploration
import ipywidgets

# Numerical python library
import numpy as np

# Plotting library
import matplotlib.pyplot as plt

# Dataframes in python
import pandas as pd

# Linear model correction
import patsy 

# Even better plotting
import seaborn as sns

# Batch effect correction
# This import only works because there's a folder called "combat_py" in this directory; it's not an installed module
from combat_py.combat import combat


# Use the styles and colors that I like
sns.set(style='white', context='talk', palette='Set2')
%matplotlib inline

Sample-batchy


In [ ]:
np.random.seed(2016)

n_samples = 10
n_genes = 20

half_genes = int(n_genes/2)
half_samples = int(n_samples/2)
size = n_samples * n_genes

genes = ['Gene_{}'.format(str(i+1).zfill(2)) for i in range(n_genes)]
samples = ['Sample_{}'.format(str(i+1).zfill(2)) for i in range(n_samples)]

data = pd.DataFrame(np.random.randn(size).reshape(n_samples, n_genes), index=samples, columns=genes)

# Add biological variance
data.iloc[:half_samples, :half_genes] += 1
data.iloc[:half_samples, half_genes:] += -1
data.iloc[half_samples:, half_genes:] += 1
data.iloc[half_samples:, :half_genes] += -1

# Biological samples
mouse_groups = pd.Series(dict(zip(data.index, (['Mouse_01'] * int(n_samples/2)) + (['Mouse_02'] * int(n_samples/2)))), 
                         name="Mouse")
mouse_to_color = dict(zip(['Mouse_01', 'Mouse_02'], ['lightgrey', 'black']))
mouse_colors = [mouse_to_color[mouse_groups[x]] for x in samples]

# Gene colors
gene_colors = sns.color_palette('husl', n_colors=n_genes)

Plot original biological variance data


In [ ]:
g = sns.clustermap(data, row_colors=mouse_colors, col_cluster=False, row_cluster=False, linewidth=0.5, 
                   col_colors=gene_colors,
                   cbar_kws=dict(label='Normalized Expression'))
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);

In [ ]:
def make_tidy(data, sample_groups):
    tidy = data.unstack()
    tidy = tidy.reset_index()
    tidy = tidy.rename(columns={'level_0': 'Gene', 'level_1': "Sample", 0: "Normalized Expression"})
    tidy = tidy.join(sample_groups, on='Sample')
    return tidy

In [ ]:
tidy = make_tidy(data, mouse_groups)

In [ ]:
fig, ax = plt.subplots()
sns.boxplot(hue='Gene', y='Normalized Expression', data=tidy, x='Mouse')
ax.legend_.set_visible(False)

Add technical noise


In [ ]:
# Choose odd-numbered samples to be in batch1 and even numbered samples to be in batch 2
batch1_samples = samples[::2]
batch2_samples = data.index.difference(batch1_samples)
batches = pd.Series(dict((x, 'Batch_01') if x in batch1_samples else (x, "Batch_02") for x in samples), name="Batch")

# Add random noise for all genes except the last two in each batch
noisy_data = data.copy()
noisy_data.loc[batch1_samples, genes[:-2]] += np.random.normal(size=n_genes-2, scale=2)
noisy_data.loc[batch2_samples, genes[:-2]] += np.random.normal(size=n_genes-2, scale=2)


# Assign colors for batches
batch_to_color = dict(zip(["Batch_01", "Batch_02"], sns.color_palette()))
batch_colors = [batch_to_color[batches[x]] for x in samples]
row_colors = [mouse_colors, batch_colors]


g = sns.clustermap(noisy_data, row_colors=row_colors, col_cluster=False, row_cluster=False, linewidth=0.5, 
                   col_colors=gene_colors, cbar_kws=dict(label='Normalized Expression'))
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);

We can see that there's some batch effect - for batch 1, Gene_15 is in general lower and Gene_01 is in general higher, and for batch 2, Gene_16 is in general higher.

But, Gene_19 and Gene_20 are unaffected.


In [ ]:
tidy_noisy = make_tidy(noisy_data, mouse_groups)
tidy_noisy = tidy_noisy.join(batches, on='Sample')
tidy_noisy.head()

Let's plot boxplots of the data the same way, with the mouse on the x-axis, normalized expression on the y-axis, and the genes as the hue.


In [ ]:
fig, ax = plt.subplots()
sns.boxplot(hue='Gene', y='Normalized Expression', data=tidy_noisy, x='Mouse')
ax.legend_.set_visible(False)

We can see that, compared to before, where we had clear differences in gene expression between genes 1-10 and genes 11-20 in the two mice, we don't see them as much in the noisy data.

Now let's plot the data a different way, with the x-axis as the batch


In [ ]:
fig, ax = plt.subplots()
sns.boxplot(hue='Gene', y='Normalized Expression', data=tidy_noisy, x='Batch')
ax.legend_.set_visible(False)

How to quantify the batch effect?


In [ ]:
fig, ax = plt.subplots()
sns.pointplot(hue='Batch', x='Normalized Expression', data=tidy_noisy, y='Gene', orient='horizontal', 
              scale=0.5, palette=batch_to_color)

In [ ]:
fig, ax = plt.subplots()
sns.pointplot(hue='Batch', x='Normalized Expression', data=tidy_noisy, y='Gene', orient='horizontal', scale=0.5)
sns.pointplot(x='Normalized Expression', data=tidy_noisy, y='Gene', orient='horizontal', scale=0.75, color='k', 
              join=False)
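Besides eyeballing the pointplots above, one simple numeric summary (the same one used for the boxplot panel later in this notebook) is the absolute difference between the two batch means for each gene. The variable names below are just for this cell.

In [ ]:
# For each gene, how far apart are the mean expression values of the two batches?
mean_batch_expression = noisy_data.groupby(batches).mean()
batch_difference = (mean_batch_expression.loc['Batch_01'] - mean_batch_expression.loc['Batch_02']).abs()
batch_difference.sort_values(ascending=False)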

How to get rid of the batch effect?

COMBAT

We will use "COMBAT" to get rid of the batch effect. What combat does is basically what we just did with our eyes and intuition - find genes whose gene expression varies greatly between batches, and adjust the expression of the gene so it's closer to the mean total expression across batches.

(may need to whiteboard here)
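Before running the real thing, here's a rough sketch of just the "location" part of that idea, using the noisy_data and batches objects from above: subtract each gene's within-batch mean and add back its overall mean. Real ComBat also adjusts the variances and shrinks the batch parameters with empirical Bayes, so treat this as intuition, not a substitute; the variable names are only for this cell.

In [ ]:
# Rough, location-only batch adjustment for intuition (NOT full ComBat)
grand_means = noisy_data.mean()                   # per-gene mean across all samples
batch_means = noisy_data.groupby(batches).mean()  # per-gene mean within each batch

# Broadcast each sample's batch mean back onto the rows (samples)
per_sample_batch_means = batch_means.loc[batches.loc[noisy_data.index]]
per_sample_batch_means.index = noisy_data.index

# Remove the within-batch shift, keep the overall gene level
roughly_centered = noisy_data - per_sample_batch_means + grand_means

# After this adjustment, the per-batch means are equal for every gene
roughly_centered.groupby(batches).mean().head()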

Create metadata matrix


In [ ]:
metadata = pd.concat([batches, mouse_groups], axis=1)
metadata

In [ ]:
def remove_batch_effects_with_combat(batch, keep_constant=None, cluster_on_correlations=False):
    if keep_constant is not None:
        # We'll use patsy (statistical models in python) to create a "design matrix" which encodes the 
        # variables to keep as boolean (0 or 1) columns so the computer can understand them.
        model = patsy.dmatrix('~ {}'.format(keep_constant), metadata, return_type="dataframe")
    else:
        model = None
        
    # --- Correct for batch effects --- #
    corrected_data = combat(noisy_data.T, metadata[batch], model)
    
    # Transpose so samples are the rows and the features are the columns
    corrected_data = corrected_data.T

    # --- Plot the heatmap --- #
    if cluster_on_correlations:
        g = sns.clustermap(corrected_data.T.corr(), row_colors=row_colors, col_cluster=True, row_cluster=True, linewidth=0.5, 
                           vmin=-1, vmax=1, col_colors=row_colors, cbar_kws=dict(label='Pearson R'))
        plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);
    else:
        g = sns.clustermap(corrected_data, row_colors=row_colors, col_cluster=False, row_cluster=False, linewidth=0.5, 
                       col_colors=gene_colors, cbar_kws=dict(label='Normalized Expression'))
        plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);
        
    # Uncomment the line below to save the batch corrected heatmap
    # g.savefig('combat_batch_corrected_clustermap.pdf')
    
    # --- Quantification of the batch effect correction --- #
    # Create a "tidy" version of the dataframe for plotting
    tidy_corrected = make_tidy(corrected_data, mouse_groups)
    tidy_corrected = tidy_corrected.join(batches, on='Sample')
    tidy_corrected.head()

    # Set up the figure
    # 4 columns of figure panels
    figure_columns = 4
    width = 4.5 * figure_columns
    height = 4
    fig, axes = plt.subplots(ncols=figure_columns, figsize=(width, height))

    # Plot original data vs the corrected data
    ax = axes[0]
    ax.plot(data.values.flat, corrected_data.values.flat, 'o', 
            # Everything in the next line is my personal preference so it looks nice
            alpha=0.5, markeredgecolor='k', markeredgewidth=0.5)
    ax.set(xlabel='Original (Batchy) data', ylabel='COMBAT corrected data')

    # Plot the mean gene expression within each batch in colors, and the mean gene expression across both batches in black
    ax = axes[1]
    sns.pointplot(hue='Batch', x='Normalized Expression', data=tidy_corrected, y='Gene', orient='horizontal', scale=.5, ax=ax)
    sns.pointplot(x='Normalized Expression', data=tidy_corrected, y='Gene', orient='horizontal', 
                  scale=0.75, color='k', join=False, ax=ax)

    # Plot the gene expression distribution per mouse
    ax = axes[2]
    sns.boxplot(hue='Gene', y='Normalized Expression', data=tidy_corrected, x='Mouse', ax=ax, 
                # Adjusting linewidth for my personal preference
                linewidth=1)
    # Don't show legend because it's too big
    ax.legend_.set_visible(False)
    
    
    # --- Plot boxplots of average difference between gene expression in batches --- #
    # Get the mean gene expression within each batch for the original noisy data
    mean_batch_expression = noisy_data.groupby(batches).mean()
    noisy_batch_diff = (mean_batch_expression.loc['Batch_01'] - mean_batch_expression.loc['Batch_02']).abs()
    noisy_batch_diff.name = 'mean(|batch1 - batch2|)'
    noisy_batch_diff = noisy_batch_diff.reset_index()
    noisy_batch_diff['Data type'] = 'Noisy'

    # Get mean gene expression within batch for the corrected data
    mean_corrected_batch_expression = corrected_data.groupby(batches).mean()
    corrected_batch_diff = (mean_corrected_batch_expression.loc['Batch_01'] - mean_corrected_batch_expression.loc['Batch_02']).abs()
    corrected_batch_diff.name = 'mean(|batch1 - batch2|)'
    corrected_batch_diff = corrected_batch_diff.reset_index()
    corrected_batch_diff['Data type'] = 'Corrected'

    # Compile the two tables into one (concatenate)
    batch_differences = pd.concat([noisy_batch_diff, corrected_batch_diff])
    batch_differences.head()

    sns.boxplot(x='Data type', y='mean(|batch1 - batch2|)', data=batch_differences, ax=axes[3])

    # Remove right and top axes lines so it looks nicer
    sns.despine()

    # Magically adjust the figure panels (axes) so they fit nicely
    fig.tight_layout()

    # Uncomment the line below to save the figure of three panels
    # fig.savefig('combat_batch_corrected_panels.pdf')


ipywidgets.interact(
    remove_batch_effects_with_combat,
    batch=ipywidgets.Dropdown(options=['Mouse', 'Batch'], value="Batch", description='Batch to correct for'), 
    keep_constant=ipywidgets.Dropdown(value=None, options=[None, 'Mouse', 'Batch', 'Mouse + Batch'], 
                                      description='Variable of interest'),
    cluster_on_correlations=ipywidgets.Checkbox(value=False, description="Cluster on (Pearson) correlations between samples"));

Try doing these and see how they compare. Do you see similar trends to the original data? Do any of these create errors? Why would that be?

  1. Batch to correct for = Batch, Variable of interest = Mouse
  2. Batch to correct for = Mouse, Variable of interest = Batch
  3. Batch to correct for = Batch, Variable of interest = Mouse + Batch
  4. ... your own combinations!

With each of these, try turning "Cluster on (Pearson) correlations between samples" on and off. This is a nice way to visualize how much the batch-dependent signal has been reduced.

Feature-batchy

What if there are specific genes or features that are contributing to the batches?

This is the idea behind correcting for cell-cycle genes or some other feature that you know is associated with the data, e.g. the RNA Integrity Number (RIN).

Let's add some feature-batchy noise to our original data


In [ ]:
metadata['RIN'] = np.arange( len(samples)) + 0.5
metadata

Add noise and plot it. Use the first and last genes as controls that don't have any noise.


In [ ]:
# rin_noise = metadata['RIN'].apply(lambda x: pd.Series(np.random.normal(loc=x, size=n_genes), index=genes))
rin_noise = metadata['RIN'].apply(lambda x: pd.Series(np.ones(n_genes-2)*x, index=genes[1:-1]))
rin_noise = rin_noise.reindex(columns=genes)
rin_noise = rin_noise.fillna(0)

g = sns.clustermap(rin_noise, row_colors=mouse_colors, col_cluster=False, row_cluster=False, linewidth=0.5, 
                   col_colors=gene_colors, cbar_kws=dict(label='RIN Noise'))
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);

Add the noise to the data, then re-standardize so that each gene has mean zero and unit variance.


In [ ]:
rin_batchy_data = data + rin_noise
rin_batchy_data

# Re-standardize the data so each gene has mean 0 and standard deviation 1

rin_batchy_data = (rin_batchy_data - rin_batchy_data.mean())/rin_batchy_data.std()

g = sns.clustermap(rin_batchy_data, row_colors=mouse_colors, col_cluster=False, row_cluster=False, linewidth=0.5, 
                   col_colors=gene_colors, cbar_kws=dict(label='Normalized Expression'))
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);

If we plot the RIN vs the RIN-batchy gene expression, we'll see that expression increases linearly with this one variable! Of course, we could also have created a variable that linearly decreases expression.


In [ ]:
tidy_rin_batchy = make_tidy(rin_batchy_data, mouse_groups)
tidy_rin_batchy = tidy_rin_batchy.join(metadata['RIN'], on='Sample')


g = sns.FacetGrid(tidy_rin_batchy, hue='Gene')
g.map(plt.plot, 'RIN', 'Normalized Expression', alpha=0.5)

Use RIN to predict gene expression

We will use linear regression with RIN as the explanatory (independent) variable to predict gene expression. Then we'll create a new, corrected matrix with the influence of RIN removed.


In [ ]:
from __future__ import print_function
import six
from sklearn import linear_model

regressor = linear_model.LinearRegression()
regressor

# Use RIN as the "X" - the "dependent" variable, the one you expect your gene expression to vary with.

regressor.fit(metadata['RIN'].to_frame(), rin_batchy_data)

# Use RIN to predict gene expression
rin_dependent_data = pd.DataFrame(regressor.predict(metadata['RIN'].to_frame()), columns=genes, index=samples)
rin_dependent_data

from sklearn.metrics import r2_score

# explained_variance = r2_score(rin_batchy_data, rin_dependent_data)
# six.print_("Explained variance by RIN:", explained_variance)

rin_corrected_data = rin_batchy_data - rin_dependent_data
rin_corrected_data

# Somewhat contrived, but try to predict the newly corrected data with RIN

r2_score(rin_corrected_data, rin_dependent_data)

tidy_rin_corrected = make_tidy(rin_corrected_data, mouse_groups)
tidy_rin_corrected = tidy_rin_corrected.join(metadata['RIN'], on="Sample")
tidy_rin_corrected.head()

g = sns.FacetGrid(tidy_rin_corrected, hue='Gene')
g.map(plt.plot, 'RIN', 'Normalized Expression', alpha=0.5)

g = sns.clustermap(rin_corrected_data, row_colors=mouse_colors, col_cluster=False, row_cluster=False, linewidth=0.5, 
                   col_colors=gene_colors, cbar_kws=dict(label='Normalized Expression'))
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);

g = sns.clustermap(rin_corrected_data.T.corr(), row_colors=mouse_colors, linewidth=0.5, 
                   col_colors=mouse_colors, cbar_kws=dict(label='Pearson R'))
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);

Now the data doesn't vary by RIN! But... we over-corrected and removed the biological signal as well.
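One way to avoid throwing out the biology (the same idea as giving ComBat a design matrix above, and as SVA below) is to model expression as a function of both the variable you care about (Mouse) and the nuisance covariate (RIN), and then subtract only the RIN component. Here is a rough sketch of that idea using patsy and scikit-learn on the objects defined above; the names (`protected_regressor`, `protected_corrected_data`) are just for this sketch, and it's an illustration of the concept, not a replacement for the dedicated methods.

In [ ]:
# Sketch: protect the biological signal by modeling Mouse AND RIN, then removing only the RIN part
design = patsy.dmatrix('~ Mouse + RIN', metadata, return_type='dataframe')

# The design matrix already contains an intercept column, so don't fit another one
protected_regressor = linear_model.LinearRegression(fit_intercept=False)
protected_regressor.fit(design, rin_batchy_data)

# Coefficients: one row per design matrix column, one column per gene
coefficients = pd.DataFrame(protected_regressor.coef_.T, index=design.columns, columns=genes)

# Reconstruct only the RIN-driven portion of the expression and subtract it
rin_component = pd.DataFrame(np.outer(design['RIN'], coefficients.loc['RIN']),
                             index=samples, columns=genes)
protected_corrected_data = rin_batchy_data - rin_component

g = sns.clustermap(protected_corrected_data, row_colors=mouse_colors, col_cluster=False, row_cluster=False,
                   linewidth=0.5, col_colors=gene_colors, cbar_kws=dict(label='Normalized Expression'))
plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0);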

Other options to talk about

As you have seen, dealing with batch effects in single-cell data is supremely difficult and the best thing you can do for yourself is design your experiment nicely so you don't have to.

  • SVA
    • Can specify that you want to correct for something (like RIN) but not for what you're interested in. But often in single-cell data you're trying to find new populations, so you don't know a priori what you want to protect from correction
  • RUVseq
    • "RUV" = "Remove unwanted variation"
    • With the "RUVg" version can specify a set of control genes that you know aren't supposed to change between groups (maybe from a bulk experiment) but they say in their manual not to use the normalized counts for differential expression, only for exploration, because you may have corrected for something you actually DID want but didn't know
  • scLVM
    • This method claims to account for differences in cell cycle stage and help to put all cells onto the same scale, so you can then do pseudotime ordering and clustering and all that jazz.