In [ ]:
%matplotlib inline
import os.path
import json
import pandas as pd
import matplotlib
matplotlib.style.use('ggplot')
import matplotlib.pyplot as plt
from atntools import settings, searchprocess
In [ ]:
# Sequence numbers whose results are loaded and plotted below; each one is
# resolved to a directory through searchprocess.get_sequence_dir().
sequence_nums = [5, 1, 2]
In [ ]:
def get_final_iteration_num(sequence_dir):
    """Return the number of the final iteration recorded for a sequence.

    Reads ``sequence-state.json`` from ``sequence_dir`` and derives the
    iteration number from the length of its ``'sets'`` entry.

    Args:
        sequence_dir: Path of the directory containing ``sequence-state.json``.

    Returns:
        ``len(state['sets']) - 2`` as an int.

    Raises:
        OSError: If the state file cannot be opened.
        KeyError: If the state file has no ``'sets'`` entry.
    """
    state_filename = os.path.join(sequence_dir, 'sequence-state.json')
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(state_filename, 'r', encoding='utf-8') as f:
        sequence_state = json.load(f)
    # NOTE(review): the offset of 2 presumably accounts for entries in 'sets'
    # that do not correspond to a finished iteration -- confirm against the
    # code that writes sequence-state.json.
    return len(sequence_state['sets']) - 2
def get_iteration_data(sequence_num):
    """Load the final iteration's CSV for the given sequence.

    Args:
        sequence_num: Sequence number passed to
            ``searchprocess.get_sequence_dir``.

    Returns:
        A DataFrame read from ``iteration-<N>.csv`` in the sequence
        directory, where N is the final iteration number; the first CSV
        column is used as the index.
    """
    directory = searchprocess.get_sequence_dir(sequence_num)
    last_iteration = get_final_iteration_num(directory)
    csv_name = 'iteration-{}.csv'.format(last_iteration)
    csv_path = os.path.join(directory, csv_name)
    return pd.read_csv(csv_path, index_col=0)
def get_extinction_data(sequence_num):
    """Load the final iteration's extinction CSV for the given sequence.

    Args:
        sequence_num: Sequence number passed to
            ``searchprocess.get_sequence_dir``.

    Returns:
        A DataFrame read from ``extinctions-iteration-<N>.csv`` in the
        sequence directory (N being the final iteration number), indexed by
        the first CSV column, with every missing value replaced by 0.
    """
    directory = searchprocess.get_sequence_dir(sequence_num)
    last_iteration = get_final_iteration_num(directory)
    csv_name = 'extinctions-iteration-{}.csv'.format(last_iteration)
    csv_path = os.path.join(directory, csv_name)
    table = pd.read_csv(csv_path, index_col=0)
    # Empty cells are treated as zero.
    return table.fillna(0)
In [ ]:
def plot_extinction_distributions(dataframes, iterations):
    """Draw a grid of bar charts, one per (row, DataFrame) pair.

    The grid has ``iterations`` rows and one column per DataFrame. Row ``i``
    of the DataFrame in grid column ``c`` (1-based) is drawn in subplot
    ``i * len(dataframes) + c``.

    Args:
        dataframes: Sequence of DataFrames. Each is indexed with ``.loc`` by
            the integers ``0 .. len(df) - 1``, so those labels must exist in
            its index.
        iterations: Number of subplot rows in the figure.
    """
    plt.figure(figsize=(10, 20))
    n_rows, n_cols = iterations, len(dataframes)
    for grid_col, frame in enumerate(dataframes, start=1):
        row_total = len(frame)
        for row in range(row_total):
            position = row * n_cols + grid_col
            axes = plt.subplot(n_rows, n_cols, position)
            distribution = frame.loc[row]
            distribution.plot.bar(ax=axes, width=1, ylim=(0, 1))
            # Axis labels appear only on the very first subplot.
            if position == 1:
                plt.ylabel("relative freq.")
                plt.xlabel("extinctions")
            # Every subplot in the top row is titled with its column's
            # species count (number of columns minus one).
            if position <= n_cols:
                species_count = frame.shape[1] - 1
                plt.title("{} species".format(species_count))
# Load one extinction table per sequence, then draw the distribution grid
# with 10 subplot rows.
extinction_dfs = list(map(get_extinction_data, sequence_nums))
plot_extinction_distributions(extinction_dfs, iterations=10)
plt.tight_layout()
In [ ]:
def plot_mean_extinctions(dataframes, iterations):
    """Plot, for each DataFrame, a per-row weighted sum as one line.

    Each column label is converted to ``int`` and used as the weight of that
    column; the weighted values are summed across each row and plotted
    against the DataFrame's index.

    Args:
        dataframes: Iterable of DataFrames whose column labels can be
            converted with ``int()``.
            NOTE(review): presumably each row is a relative-frequency
            distribution over extinction counts, making the weighted sum a
            mean -- confirm against the data producer.
        iterations: Not read by this function; kept so existing callers
            continue to work.
    """
    plt.figure()
    for df in dataframes:
        weights = [int(label) for label in df.columns]
        mean = (df * weights).sum(axis=1)
        # Legend label uses the number of columns minus one.
        species_count = len(df.columns) - 1
        mean.plot(ylim=(0, 7), label="{} species".format(species_count))
    plt.xlabel('iteration')
    # Fixed the misspelled axis label ("exinction").
    plt.ylabel('mean extinction count')
    plt.legend(prop={'size': 11})
# Draw the mean-extinction lines for the tables held in extinction_dfs.
# The second argument is accepted by plot_mean_extinctions but not read by it.
plot_mean_extinctions(extinction_dfs, 10)
In [ ]:
def plot_f1_scores(dataframes, species_counts, iterations):
    """Plot the row-wise average of the two F1 test columns per DataFrame.

    For every DataFrame the columns ``'f1_test_0'`` and ``'f1_test_1'`` are
    averaged row by row and drawn as one line, labelled with the matching
    entry of ``species_counts``.

    Args:
        dataframes: Iterable of DataFrames containing both F1 columns.
        species_counts: Iterable of legend values, paired positionally with
            ``dataframes``.
        iterations: Not read by this function.
    """
    plt.figure()
    score_columns = ['f1_test_0', 'f1_test_1']
    for n_species, frame in zip(species_counts, dataframes):
        class_average = frame[score_columns].mean(axis=1)
        legend_text = "{} species".format(n_species)
        class_average.plot(ylim=(0, 1), label=legend_text)
    plt.xlabel('iteration')
    plt.ylabel('F1 score class average')
    plt.legend(loc='lower right', prop={'size': 11})
# Load one iteration table per sequence. The species counts are taken from
# the extinction tables loaded earlier (column count minus one), in the same
# order as sequence_nums.
iteration_dfs = list(map(get_iteration_data, sequence_nums))
species_counts = [frame.shape[1] - 1 for frame in extinction_dfs]
plot_f1_scores(iteration_dfs, species_counts, iterations=10)
In [ ]: