Analysis Automation (With Python and Jupyter Notebook)

Levels of Python Automation:

  1. Good - Document all analysis steps in enough detail that they can be reproduced successfully.
  2. Better - Script your analysis
  3. Best - Script your analysis and write tests to validate each step.

To get started, we will import the Python modules that we will use in this session. These modules are developed by programmers and made available as open-source packages for Python. We would normally have to install each of these ourselves, but they are included as part of the Anaconda Python Distribution.

The %matplotlib inline statement is an IPython/Jupyter "magic" command that makes plots generated by the matplotlib package display as output in the Jupyter Notebook instead of opening in a separate window.


In [ ]:
import numpy as np
import pandas as pd
import pylab as plt
import matplotlib

%matplotlib inline

We will continue where the data exploration module left off by importing the cleaned gapminder dataset and assigning it to a new variable named df, to denote that we have imported a pandas DataFrame.

As validation that we have imported the data we will also look at the top five rows of data using the head method of pandas.


In [ ]:
# Load the cleaned gapminder table into a DataFrame and preview its first rows.
df = pd.read_csv(filepath_or_buffer='./gapminder_cleaned.csv')
df.head(5)

In [ ]:
unique_years = df.loc[:, 'year'].unique()  # distinct years, in order of first appearance

In [ ]:
# Column to summarise and continent to focus on in the cells below.
category, continent = 'lifeexp', 'asia'

In [ ]:
# Boolean Series that is True for rows belonging to the chosen continent,
# then keep only those rows.
mask_continent = df['continent'].eq(continent)
df_continent = df.loc[mask_continent]

In [ ]:
# Average `category` across the continent's rows for each year (years are
# visited in order of first appearance) and collect one record per year.
years = df_continent['year'].unique()
summary = []
for year in years:
    mask_year = df_continent['year'].eq(year)
    df_year = df_continent.loc[mask_year]
    value = df_year[category].pipe(np.mean)
    summary.append((continent, year, value))

# A DataFrame of (continent, year, value) records can be plotted directly by pandas.
summary = pd.DataFrame(data=summary, columns=['continent', 'year', category])

In [ ]:
# Plot the column named by `category` (not a hard-coded name) so this plot stays
# in sync with the summary built above if a different category is chosen.
summary.plot.line('year', category)

In [ ]:
def calculate_statistic_over_time(data, category, continent, func=None):
    """Summarise one column per year for a single continent.

    Parameters
    ----------
    data : pandas.DataFrame
        Table with 'continent', 'year' and ``category`` columns.
    category : str
        Name of the column to summarise.
    continent : object
        Value of the 'continent' column whose rows are used.
    func : callable, optional
        Reduction applied to each year's ``category`` values; defaults to
        ``np.mean``.

    Returns
    -------
    pandas.DataFrame
        Columns ``['continent', 'year', category]``, one row per year present
        for that continent, in order of first appearance. Empty if the
        continent does not occur in ``data``.
    """
    reducer = np.mean if func is None else func
    rows_for_continent = data[data['continent'] == continent]
    records = [
        (continent, yr,
         reducer(rows_for_continent.loc[rows_for_continent['year'] == yr, category]))
        for yr in rows_for_continent['year'].unique()
    ]
    return pd.DataFrame(records, columns=['continent', 'year', category])

In [ ]:
# Plot the yearly mean of `category` for every continent on one set of axes.
category = 'lifeexp'
continents = df['continent'].unique()

fig, ax = plt.subplots()
for continent in continents:
    output = calculate_statistic_over_time(df, category, continent)
    # Label each line with its continent; without this pandas labels every line
    # with the column name, so the legend entries are indistinguishable.
    output.plot.line('year', category, ax=ax, label=continent)

In [ ]:
category = 'lifeexp'
# Rank continents by their overall mean of `category` (highest first) so the
# colours follow that ranking. Select the column *before* aggregating: calling
# .mean() on the whole grouped frame also tries to average non-numeric columns,
# which raises TypeError in pandas >= 2.0.
mean_values = df.groupby('continent')[category].mean()
mean_values = mean_values.sort_values(ascending=False)
continents = mean_values.index.values

n_continents = len(continents)
cmap = plt.cm.coolwarm_r

fig, ax = plt.subplots()
for ii, continent in enumerate(continents):
    # Sample colours evenly from the colormap in [0, 1).
    this_color = cmap(ii / n_continents)
    output = calculate_statistic_over_time(df, category, continent)
    output.plot.line('year', category, ax=ax, label=continent,
                     color=this_color)

# Legend and labels only need to be set once, after every line is drawn.
ax.legend(loc=(1.02, 0))
ax.set(ylabel=category, xlabel='Year',
       title='{} over time'.format(category))

plt.setp(ax.lines, lw=4, alpha=.4)

In [ ]:
def plot_statistic_over_time(data, category, func=None,
                             cmap=None, ax=None, legend=True,
                             sort=True):
    """Draw one line per continent showing ``category`` summarised by year.

    Parameters
    ----------
    data : pandas.DataFrame
        Table with 'continent', 'year' and ``category`` columns.
    category : str
        Name of the column to summarise.
    func : callable, optional
        Reduction applied to each year's values, forwarded to
        ``calculate_statistic_over_time`` (which defaults to ``np.mean``).
        A non-callable value is ignored with a warning, because older calls
        passed a continent name in this position.
    cmap : colormap, optional
        Callable colormap such as ``plt.cm.viridis`` (the default).
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure and Axes are created when omitted.
    legend : bool
        If true, place a legend to the right of ``ax``; otherwise hide any
        legend already on ``ax``.
    sort : bool
        If true, order continents (and hence colours) by their overall mean of
        ``category`` across all years, highest first; otherwise alphabetically.

    Returns
    -------
    The Axes that was drawn on.
    """
    import warnings

    if ax is None:
        _, ax = plt.subplots()
    if cmap is None:
        cmap = plt.cm.viridis
    if func is not None and not callable(func):
        warnings.warn(
            "plot_statistic_over_time: ignoring non-callable func={!r}; "
            "using the default statistic instead.".format(func),
            stacklevel=2,
        )
        func = None

    # Use the `data` argument (not a global) so the function works on any frame.
    if sort:
        # Select the column before aggregating: averaging every column fails
        # on non-numeric ones in pandas >= 2.0.
        mean_values = data.groupby('continent')[category].mean()
        continents = mean_values.sort_values(ascending=False).index.values
    else:
        continents = np.unique(data['continent'])
    n_continents = len(continents)

    # One line per continent, colours sampled evenly from the colormap in [0, 1).
    for ii, continent in enumerate(continents):
        this_color = cmap(ii / n_continents)
        output = calculate_statistic_over_time(data, category, continent,
                                               func=func)
        output.plot.line('year', category, ax=ax, label=continent,
                         color=this_color)

    # Configure the legend on `ax` itself: plt.legend() would act on the
    # pyplot "current" axes, which need not be `ax` in multi-panel figures.
    if legend:
        if n_continents:
            ax.legend(loc=(1.02, 0))
    else:
        existing_legend = ax.get_legend()
        if existing_legend is not None:
            existing_legend.set_visible(False)
    ax.set(ylabel=category, xlabel='Year',
           title='{} over time'.format(category))

    plt.setp(ax.lines, lw=4, alpha=.4)
    return ax

In [ ]:
# Plot every continent with the coolwarm colormap. No continent argument is
# passed: the third positional parameter is `func`, not a continent name.
plot_statistic_over_time(df, category, cmap=plt.cm.coolwarm)

In [ ]:
# Compare two categories side by side (continents listed alphabetically).
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
categories = ['pop', 'gdppercap']
for ax, i_category in zip(axs, categories):
    # Every continent is plotted; the third positional parameter is `func`,
    # so no continent name is passed here.
    plot_statistic_over_time(df, i_category, ax=ax, sort=False)
# One legend (on the right-hand panel) is enough for both plots.
plt.setp(axs[0].get_legend(), visible=False)
plt.tight_layout()

In [ ]:
# The same statistic drawn with two colormaps, legends hidden, y-axes shared.
fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharey=True)
cmaps = [plt.cm.viridis, plt.cm.coolwarm]
for ax, cmap in zip(axs, cmaps):
    plot_statistic_over_time(df, category, cmap=cmap, ax=ax, legend=False)

In [ ]:
# The same figure in one pandas expression: mean life expectancy per
# (continent, year), pivoted so each continent becomes a column/line.
# Selecting the column before .mean() avoids averaging non-numeric columns,
# which raises TypeError in pandas >= 2.0.
ax = (df.groupby(['continent', 'year'])['lifeexp'].mean()
        .unstack('continent')
        .plot(cmap=plt.cm.viridis, alpha=.4, lw=3))

In [ ]:
# Saving for publication
cmaps = [plt.cm.magma, plt.cm.rainbow]
for ii, cmap in enumerate(cmaps):
    fig, ax = plt.subplots(figsize=(10, 10), sharey=True)
    plot_statistic_over_time(df, category, continent,
                             cmap=cmap, ax=ax, legend=False)
    labels = [ax.get_xticklabels(), ax.get_yticklabels(),
              ax.yaxis.label, ax.xaxis.label, ax.title]
    _ = plt.setp(labels, fontsize=30)
#     ax.set_axis_off()
    fig.savefig('fig_{}.png'.format(ii), transparent=True, bbox_inches='tight')

In [ ]: