The purpose of this script is to provide modular code for generating errorbarjitter plots using Python and pandas. The framework for this code was written in collaboration with Jason Wittenbach (@jwittenbach) of Janelia Research Campus. The idea is based on the MATLAB errorbarjitter code written by David Stern, also of Janelia Research Campus. This version uses free and open source software to achieve the same plot.
As in the original errorbarjitter, this function plots the mean ± SD of one or more samples alongside the raw data. The raw data are "jittered" horizontally, and an alpha value adds transparency so that overlapping data points remain distinguishable. This form of data presentation invites active analysis of the raw data by the reader.
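To make the jittering concrete, here is a minimal sketch of how horizontal jitter can be generated with NumPy. The helper name `jitter_positions`, its `width` parameter, and the fixed seed are assumptions for illustration and are not part of the original script; the spread of ±0.1 matches the value used in the function defined later in this notebook.

```python
import numpy as np

# Hypothetical helper for illustration; not part of the original script.
def jitter_positions(center, n, width=0.2, rng=None):
    """Return n x-positions spread uniformly within center +/- width/2."""
    rng = np.random.default_rng() if rng is None else rng
    return center + width * rng.random(n) - width / 2

# A seeded generator makes the jittered layout reproducible between runs.
rng = np.random.default_rng(0)
x = jitter_positions(center=2, n=50, rng=rng)
```

Every point for the group at position 2 lands between 1.9 and 2.1, so neighbouring groups spaced one unit apart do not overlap.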
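The following arguments can be passed to errorbarjitter:
df: the pandas DataFrame holding the data
groupByCol: name of the column that defines the groups (one x-axis position per group)
statsCol: name of the column whose values are plotted and summarized as mean ± SD
fig: optional matplotlib figure (default None, in which case a new figure is created)
xlab: x-axis label (default 'group')
ylab: y-axis label (default 'units')
rotate: rotation of the x-axis tick labels, in degrees (default 0)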
In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
In [2]:
sns.set_style('darkgrid')
sns.set_context('talk')
In [3]:
def errorbarjitter(df, groupByCol, statsCol, fig=None, xlab='group', ylab='units', rotate=0):
    # Per-group mean and sample standard deviation of statsCol
    stats = df.groupby(groupByCol)[statsCol].agg(['mean', 'std'])
    # groupby sorts its keys, so take the group order from the result
    # rather than from unique(), which keeps order of first appearance
    groups = list(stats.index)
    means, devs = stats['mean'], stats['std']
    # Create a figure only if the caller did not supply one
    if fig is None:
        fig = plt.figure(figsize=(15, 7))
    else:
        plt.figure(fig.number)
    delta = 0.22  # horizontal offset of the mean ± SD marker
    for (i, (m, s)) in enumerate(zip(means, devs)):
        pts = np.array(df[df[groupByCol] == groups[i]][statsCol])
        # Jitter the raw points uniformly within ±0.1 of the group position
        x = i*np.ones(len(pts)) + 0.2*np.random.rand(len(pts)) - 0.1
        plt.scatter(x, pts, c='k', alpha=0.5)
        # Open circle at the mean, vertical bar spanning mean ± SD
        plt.scatter(i+delta, m, edgecolor='k', facecolor='none', linewidth=3, s=25)
        plt.plot([i+delta, i+delta], [m-s, m+s], '-', c=[0, 0, 0], lw=2.0)
    plt.xticks(range(len(groups)), groups, rotation=rotate)
    plt.xlabel(xlab)
    plt.ylabel(ylab)
In [4]:
path = "ex-data.csv"
exdata = pd.read_csv(path)
exdata.head()
Out[4]:
In [5]:
grouped = exdata.groupby('runner')
stats = grouped.aggregate({'time': ['std', 'mean']})
stats
Out[5]:
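The per-group summary computed in this cell can be reproduced on a small made-up table. The sketch below uses hypothetical timings for two runners (not the contents of ex-data.csv). It also illustrates two details worth keeping in mind: pandas' `'std'` is the sample standard deviation (ddof=1), and `groupby` returns its keys sorted, whereas `unique()` keeps the order of first appearance, so labels taken from `unique()` may not line up with the aggregated means when the two orders differ.

```python
import pandas as pd

# Made-up timings for illustration only; not the real ex-data.csv.
df = pd.DataFrame({
    'runner': ['b', 'b', 'a', 'a'],
    'time': [10.0, 12.0, 20.0, 24.0],
})

# Mean and sample standard deviation (ddof=1) of 'time' per runner
stats = df.groupby('runner')['time'].agg(['mean', 'std'])

sorted_groups = list(stats.index)               # groupby sorts keys: a, then b
appearance_order = list(df['runner'].unique())  # first-seen order: b, then a
```

Taking the group labels from `stats.index` keeps each label aligned with its own mean and standard deviation.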
In [6]:
errorbarjitter(exdata, 'runner', 'time', xlab = 'runner', ylab = 'time (s)')
plt.title('5 runners average course time')
Out[6]:
In [7]:
path = "iris.csv"
data = pd.read_csv(path)
data.head()
Out[7]:
In [8]:
errorbarjitter(data, 'Species', 'Sepal.Length', xlab = 'species', ylab = 'sepal length')
plt.title('Sepal length by species')
Out[8]:
In [9]:
errorbarjitter(data, 'Species', 'Sepal.Width', xlab = 'species', ylab = 'sepal width')
plt.title('Sepal width by species')
Out[9]:
In [10]:
errorbarjitter(data, 'Species', 'Petal.Length', xlab = 'species', ylab = 'petal length')
plt.title('Petal length by species')
Out[10]:
In [11]:
errorbarjitter(data, 'Species', 'Petal.Width', xlab = 'species', ylab = 'petal width')
plt.title('Petal width by species')
Out[11]:
References:
[1] errorbarjitter. David Stern. http://www.mathworks.com/matlabcentral/fileexchange/33658-errorbarjitter
[2] Anderson, E. (1935). The irises of the Gaspe Peninsula. Bulletin of the American Iris Society, 59, 2–5. (Original collection of the iris data.)
[3] Fisher, R. A. (1936). The use of multiple measurements in taxonomic problems. Annals of Eugenics, 7, Part II, 179–188.
[4] Becker, R. A., Chambers, J. M. and Wilks, A. R. (1988). The New S Language. Wadsworth & Brooks/Cole.
In [12]:
sns.set_style('darkgrid')
errorbarjitter(data, 'Species', 'Petal.Width', xlab = 'species', ylab = 'petal width')
plt.title('Petal width by species')
Out[12]:
In [13]:
sns.set_style('whitegrid')
errorbarjitter(data, 'Species', 'Petal.Width', xlab = 'species', ylab = 'petal width')
plt.title('Petal width by species')
Out[13]:
In [14]:
sns.set_style('dark')
errorbarjitter(data, 'Species', 'Petal.Width', xlab = 'species', ylab = 'petal width')
plt.title('Petal width by species')
Out[14]:
In [15]:
sns.set_style('white')
errorbarjitter(data, 'Species', 'Petal.Width', xlab = 'species', ylab = 'petal width')
plt.title('Petal width by species')
Out[15]:
In [16]:
sns.set_style('ticks')
errorbarjitter(data, 'Species', 'Petal.Width', xlab = 'species', ylab = 'petal width')
plt.title('Petal width by species')
Out[16]:
In [17]:
path = "learning-data.csv"
learning = pd.read_csv(path)
sns.set_style('darkgrid')
errorbarjitter(learning, 'animal', 'pi', xlab = 'animal', ylab = 'performance index')
plt.title('learning assay')
plt.ylim(-1,1)
Out[17]: