In [ ]:
%matplotlib inline

import os.path
import json

import pandas as pd
import matplotlib
matplotlib.style.use('ggplot')
import matplotlib.pyplot as plt

from atntools import settings, searchprocess

In [ ]:
# Sequence numbers to analyse; each one is passed to get_extinction_data()
# and get_iteration_data() in the cells below, and the plots keep this order.
sequence_nums = [5, 1, 2]

In [ ]:
def get_final_iteration_num(sequence_dir):
    """Return the final iteration number recorded for a sequence.

    Reads ``sequence-state.json`` from `sequence_dir` and returns the length
    of its ``'sets'`` entry minus two.

    Parameters
    ----------
    sequence_dir : str
        Directory containing ``sequence-state.json``.

    Returns
    -------
    int
        ``len(state['sets']) - 2``.
    """
    with open(os.path.join(sequence_dir, 'sequence-state.json'), 'r') as state_file:
        state = json.load(state_file)
    # NOTE(review): the -2 offset presumably skips sets that do not correspond
    # to a finished iteration -- confirm against the code that writes the file.
    return len(state['sets']) - 2

def get_iteration_data(sequence_num):
    """Load the final iteration's CSV for a sequence into a DataFrame.

    The sequence directory comes from ``searchprocess.get_sequence_dir`` and
    the file read is ``iteration-<N>.csv`` in that directory, where ``N`` is
    the value returned by ``get_final_iteration_num``. The first CSV column
    is used as the index.

    Parameters
    ----------
    sequence_num : int
        Sequence identifier passed to ``searchprocess.get_sequence_dir``.

    Returns
    -------
    pandas.DataFrame
    """
    seq_dir = searchprocess.get_sequence_dir(sequence_num)
    csv_name = 'iteration-{}.csv'.format(get_final_iteration_num(seq_dir))
    csv_path = os.path.join(seq_dir, csv_name)
    return pd.read_csv(csv_path, index_col=0)

def get_extinction_data(sequence_num):
    """Load the final iteration's extinction CSV for a sequence.

    Reads ``extinctions-iteration-<N>.csv`` from the directory returned by
    ``searchprocess.get_sequence_dir``, where ``N`` is the value returned by
    ``get_final_iteration_num``. The first CSV column is used as the index
    and every missing value is replaced with 0.

    Parameters
    ----------
    sequence_num : int
        Sequence identifier passed to ``searchprocess.get_sequence_dir``.

    Returns
    -------
    pandas.DataFrame
        The table with NaN entries filled with 0.
    """
    seq_dir = searchprocess.get_sequence_dir(sequence_num)
    csv_name = 'extinctions-iteration-{}.csv'.format(get_final_iteration_num(seq_dir))
    extinctions = pd.read_csv(os.path.join(seq_dir, csv_name), index_col=0)
    # Missing cells are treated as zero.
    return extinctions.fillna(0)

In [ ]:
def plot_extinction_distributions(dataframes, iterations):
    """Draw a grid of bar charts: one column per dataframe, one row per table row.

    A new 10x20 inch figure is created with `iterations` subplot rows and
    ``len(dataframes)`` subplot columns. For every dataframe, the row with
    label ``i`` (``i`` running over ``range(len(df))``) is drawn as a bar
    chart in grid row ``i`` with bar width 1 and the y-axis fixed to [0, 1].
    Axis labels are set on the first subplot only; each subplot in the top
    row is titled ``"<k> species"`` where ``k`` is the column count minus one.

    Parameters
    ----------
    dataframes : sequence of pandas.DataFrame
        One table per grid column.
    iterations : int
        Number of subplot rows in the grid.

    NOTE(review): rows are looked up by label with ``.loc[i]``, so this
    assumes a 0-based consecutive integer index and at most `iterations`
    rows per dataframe -- confirm against the CSV layout.
    """
    plt.figure(figsize=(10, 20))
    n_cols = len(dataframes)
    for col_offset, frame in enumerate(dataframes):
        for row_label in range(len(frame)):
            # Subplot numbers are 1-based and advance row by row.
            cell = row_label * n_cols + col_offset + 1
            axes = plt.subplot(iterations, n_cols, cell)
            frame.loc[row_label].plot.bar(ax=axes, width=1, ylim=(0, 1))
            if cell == 1:
                # Only the top-left subplot carries axis labels.
                axes.set_ylabel("relative freq.")
                axes.set_xlabel("extinctions")
            if cell <= n_cols:
                # Top row: title each column with its species count.
                axes.set_title("{} species".format(len(frame.columns) - 1))
            
# Load one extinction table per sequence (kept in `sequence_nums` order),
# draw the grid with 10 subplot rows, then tighten the subplot spacing.
extinction_dfs = list(map(get_extinction_data, sequence_nums))
plot_extinction_distributions(dataframes=extinction_dfs, iterations=10)
plt.tight_layout()

In [ ]:
def plot_mean_extinctions(dataframes, iterations):
    """Plot one line per dataframe showing its row-wise weighted column sum.

    Every column label is converted to ``int`` and used as the weight of that
    column; the weighted values are summed across each row and plotted
    against the dataframe index on a new figure, with the y-axis fixed to
    [0, 7]. Each line is labelled ``"<k> species"`` where ``k`` is the column
    count minus one.

    Assuming each row holds relative frequencies of extinction counts (to be
    confirmed against the CSV producer), the weighted sum is the mean
    extinction count for that row.

    Parameters
    ----------
    dataframes : iterable of pandas.DataFrame
        Tables whose column labels are integers or integer strings.
    iterations : int
        Unused; retained so existing callers keep working.
    """
    plt.figure()
    for df in dataframes:
        # Column labels double as the numeric weights of their columns.
        weights = [int(label) for label in df.columns]
        mean = (df * weights).sum(axis=1)
        species_count = len(df.columns) - 1
        mean.plot(ylim=(0, 7), label="{} species".format(species_count))
    plt.xlabel('iteration')
    # Fixed spelling of the axis label (was 'mean exinction count').
    plt.ylabel('mean extinction count')
    plt.legend(prop={'size': 11})
        
# Draw the mean-extinction curves for the tables loaded above.
plot_mean_extinctions(dataframes=extinction_dfs, iterations=10)

In [ ]:
def plot_f1_scores(dataframes, species_counts, iterations):
    """Plot the row-wise average of the two test F1 columns, one line per table.

    For each ``(species_count, dataframe)`` pair, the mean of the
    ``'f1_test_0'`` and ``'f1_test_1'`` columns is taken across each row and
    plotted against the dataframe index on a new figure, with the y-axis
    fixed to [0, 1] and the line labelled ``"<species_count> species"``.

    Parameters
    ----------
    dataframes : iterable of pandas.DataFrame
        Tables containing ``'f1_test_0'`` and ``'f1_test_1'`` columns.
    species_counts : iterable
        Legend values, paired with `dataframes` by position.
    iterations : int
        Not used by this function.
    """
    plt.figure()
    score_columns = ['f1_test_0', 'f1_test_1']
    for n_species, frame in zip(species_counts, dataframes):
        class_average = frame[score_columns].mean(axis=1)
        class_average.plot(ylim=(0, 1), label="{} species".format(n_species))
    plt.xlabel('iteration')
    plt.ylabel('F1 score class average')
    plt.legend(loc='lower right', prop={'size': 11})
        
# Load one iteration table per sequence; the species count for each legend
# entry is the column count of the matching extinction table minus one.
iteration_dfs = list(map(get_iteration_data, sequence_nums))
species_counts = [frame.shape[1] - 1 for frame in extinction_dfs]
plot_f1_scores(dataframes=iteration_dfs, species_counts=species_counts, iterations=10)

In [ ]: