In [2]:
import sys
import glob
import re
import fnmatch
import math
import os
from os import listdir
from os.path import join, isfile, basename
import itertools
import numpy as np
from numpy import float32, int32, uint8, dtype, genfromtxt
from scipy.stats import ttest_ind
import pandas as pd
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import colorsys
In [ ]:
In [11]:
# Curried: the outer call fixes the algorithm names, and the inner function is
# what DataFrame.apply(..., axis=1) receives for each row.
def mean_diff( alg_list ):
    """Build a row function returning MEAN_<alg_list[0]> - MEAN_<alg_list[1]>.

    alg_list is indexed each time the returned function is called.
    """
    def diff(row):
        first = row['MEAN_' + alg_list[0]]
        second = row['MEAN_' + alg_list[1]]
        return first - second
    return diff
def diff_by_names( pos_name, neg_name ):
    """Build a row function returning row[pos_name] - row[neg_name]."""
    def diff(row):
        minuend, subtrahend = row[pos_name], row[neg_name]
        return minuend - subtrahend
    return diff
def count_avg( alg_list ):
    """Build a row function returning the mean of COUNT_<alg_list[0]> and COUNT_<alg_list[1]>."""
    def avg(row):
        total = row['COUNT_' + alg_list[0]] + row['COUNT_' + alg_list[1]]
        return total / 2.
    return avg
def str_cat( colA, colB, sep=',' ):
    """Build a row function returning '<row[colA]><sep><row[colB]>'.

    Each part is rendered with format(), i.e. exactly like '{}'.format(...).
    """
    def cat(row):
        return ''.join(format(part) for part in (row[colA], sep, row[colB]))
    return cat
def log_count_avg( alg_list ):
    """Build a row function returning the natural log of the average of the two COUNT_ columns.

    Averages that are <= 0 map to 0.0 instead of raising; NaN is passed through
    to math.log unchanged.
    """
    def avg(row):
        mean_count = (row['COUNT_' + alg_list[0]] + row['COUNT_' + alg_list[1]]) / 2.
        return 0. if mean_count <= 0 else math.log(mean_count)
    return avg
def get_alg_names( df ):
    """Return the names behind df's 'COUNT...' columns, in column order.

    Every column whose name starts with 'COUNT' contributes one entry, with all
    occurrences of 'COUNT_' removed from it.
    """
    names = []
    for column in df.columns.values:
        if column.startswith('COUNT'):
            names.append(column.replace('COUNT_', ''))
    return names
def get_data_f( fixed, v1, v2, line, merged=False, v_template=False,
                base_dir='/nrs/saalfeld/john/projects/flyChemStainAtlas/all_evals' ):
    """Return the path of the per-label statistics CSV comparing v1 and v2.

    The file is looked up in '<base_dir>/<dirname>/<fixed>/', where dirname is
    'templateStatsByLabel' when v_template is true and 'algStatsByLabel'
    otherwise. Its name is '[mergeLabels_]<a>_vs_<b>_line<line>.csv'; the
    'mergeLabels_' prefix is used only when merged is true. The comparison may
    have been written in either order, so (v1, v2) is tried first, then (v2, v1).

    Parameters
    ----------
    fixed : str
        Subdirectory name under dirname.
    v1, v2 : str
        The two names being compared.
    line : int or str
        Inserted after 'line' in the file name.
    merged : bool
        Look for the 'mergeLabels_' variant of the file.
    v_template : bool
        Look under 'templateStatsByLabel' instead of 'algStatsByLabel'.
    base_dir : str
        Root directory of the evaluation outputs.

    Returns
    -------
    str
        Path of the first candidate that exists.

    Raises
    ------
    FileNotFoundError
        If neither ordering exists. This replaces the old behaviour of returning
        the string 'file does not exist', which callers then passed on to
        pd.read_csv as if it were a path.
    """
    dirname = 'templateStatsByLabel' if v_template else 'algStatsByLabel'
    name_prefix = 'mergeLabels_' if merged else ''
    candidates = [
        '{}/{}/{}/{}{}_vs_{}_line{}.csv'.format( base_dir, dirname, fixed, name_prefix, a, b, line )
        for a, b in ( (v1, v2), (v2, v1) )
    ]
    for path in candidates:
        if os.path.isfile( path ):
            return path
    raise FileNotFoundError(
        'no statistics file for {} vs {} (line {}); tried: {}'.format( v1, v2, line, ', '.join(candidates) ) )
In [3]:
def plot( dist_samples_df, merge, log, ax ):
    """Draw a bar chart of per-row mean differences between two algorithms.

    The two algorithm names are the first two returned by get_alg_names
    (i.e. taken from the frame's COUNT_<alg> columns). Each row becomes one
    bar of height MEAN_<alg0> - MEAN_<alg1>. Bars are ordered by descending
    average of the two COUNT_ columns. They are colored with viridis by that
    average, or by its log when ``log`` is true, scaled by the maximum.

    Parameters
    ----------
    dist_samples_df : pandas.DataFrame
        Must hold MEAN_<alg> and COUNT_<alg> columns for both algorithms.
        It is not modified.
    merge : object
        Unused; kept so existing calls keep working.
    log : bool
        Color by the log of the average count instead of the raw average.
    ax : matplotlib Axes
        Cleared, then drawn on.

    Returns
    -------
    matplotlib.cm.ScalarMappable or None
        Describes the bar coloring (usable with ``fig.colorbar``). None when
        the frame is empty and no bars are drawn.
    """
    alg_list = get_alg_names( dist_samples_df )
    print( alg_list )
    ax.clear()
    ax.set_title( 'MEAN '+alg_list[0] + ' - ' + alg_list[1] )
    ax.set_ylabel('mean distance difference (um)')
    if dist_samples_df.empty:
        # Nothing passed the caller's filters; row-wise apply and max() would fail.
        return None

    # Build a new frame instead of assigning columns onto the argument: callers
    # pass filtered slices, where .loc assignment triggers SettingWithCopyWarning
    # and mutates the caller's data.
    dfs = dist_samples_df.assign(
        md=dist_samples_df.apply( mean_diff(alg_list), axis=1 ),
        c=dist_samples_df.apply( count_avg(alg_list), axis=1 ),
        c_log=dist_samples_df.apply( log_count_avg(alg_list), axis=1 ),
    ).sort_values( 'c', ascending=False )

    color_values = dfs['c_log'] if log else dfs['c']
    # Scale by the maximum; fall back to 1 so an all-zero (or all-NaN) column
    # cannot divide by zero.
    peak = float( color_values.max() )
    if not peak > 0:
        peak = 1.0
    colors = matplotlib.cm.viridis( color_values / peak )

    ax.bar( range(dfs.shape[0]), dfs['md'], color=colors )

    # A ScalarMappable provides a colorbar handle without drawing anything.
    # A throw-away plt.scatter would land on pyplot's current axes, which is
    # not necessarily ``ax``.
    mappable = matplotlib.cm.ScalarMappable(
        norm=matplotlib.colors.Normalize( vmin=0.0, vmax=peak ), cmap=matplotlib.cm.viridis )
    mappable.set_array( color_values.values )
    return mappable
In [4]:
# Per-label mean-distance differences for every pair of algorithms, for one template.
template = 'F-antsFlip_lo'
line = 3
# A label is kept only if both algorithms have more than this many samples.
count_threshold = 1000
alg_list = ['cmtkCOG', 'cmtkCow', 'cmtkHideo', 'antsRegOwl', 'antsRegDog', 'antsRegYang']
alg_pairs = list( itertools.combinations( alg_list, 2 ) )
n_combs = len( alg_pairs )
# squeeze=False keeps the result an array even for a single pair, where the
# default would return a bare Axes and break axs[i]; ravel() restores the 1-D layout.
fig, axs = plt.subplots( n_combs, 1, squeeze=False, figsize=(16, n_combs*6) )
axs = axs.ravel()
print( n_combs )
for i, (alg1, alg2) in enumerate( alg_pairs ):
    print( alg1, ' ', alg2 )
    dist_samples_df = pd.read_csv( get_data_f(template, alg1, alg2, line), index_col=0 )
    valid_df = dist_samples_df[ (dist_samples_df['LABEL'] > 0)
                                & (dist_samples_df['COUNT_'+alg1] > count_threshold)
                                & (dist_samples_df['COUNT_'+alg2] > count_threshold) ]
    plot( valid_df, False, True, axs[i] )
In [63]:
# Pair of algorithms singled out for closer inspection.
alg0, alg1 = 'cmtkCOG', 'antsRegYang'
In [6]:
# Peek at the first rows of the most recently loaded per-label statistics.
dist_samples_df.head(5)
Out[6]:
In [51]:
def _pair_record( row, first, second, prefix, reverse ):
    """Build one output record for the ordered pair (first, second) from one CSV row.

    row     : one row of a per-label statistics CSV. It holds per-name columns
              (MEAN_, COUNT_, MNSTAT_, MNMIN_, MNMAX_, VRSTAT_, ... SDMAX_ + name)
              and pair-level columns PVAL, TSTAT, KRUSKAL, KRUSKALP, WILCOXON,
              WILCOXONP and COHEND.
    first   : name whose statistics fill the *_0 / *0 output columns.
    second  : name whose statistics fill the *_1 / *1 output columns.
    prefix  : 'ALG_' or 'TEMPLATE_', used for the two name columns.
    reverse : True when (first, second) is the reverse of the CSV's own order;
              direction-dependent statistics are then negated.
    """
    record = {
        prefix + '0': first,
        prefix + '1': second,
        'MEAN_0': row['MEAN_' + first],
        'MEAN_1': row['MEAN_' + second],
        'COUNT_0': row['COUNT_' + first],
        'COUNT_1': row['COUNT_' + second],
        'PVAL': row['PVAL'],
        # TSTAT is presumably a two-sample t statistic (cf. scipy's ttest_ind),
        # which changes sign when the samples are swapped. It is therefore
        # negated like COHEN. Previously it was copied unchanged for the reversed pair.
        'TSTAT': -row['TSTAT'] if reverse else row['TSTAT'],
        # NOTE(review): KRUSKAL / WILCOXON are copied unchanged for both orders;
        # confirm they are order-independent statistics.
        'KRUSKAL': row['KRUSKAL'],
        'KRUSKALP': row['KRUSKALP'],
        'WILCOXON': row['WILCOXON'],
        'WILCOXONP': row['WILCOXONP'],
        'COHEN': -row['COHEND'] if reverse else row['COHEND'],
    }
    for idx, name in ( ('0', first), ('1', second) ):
        for stat in ( 'MN', 'VR', 'SD' ):
            record[stat + idx] = row[stat + 'STAT_' + name]
            record[stat + idx + 'MIN'] = row[stat + 'MIN_' + name]
            record[stat + idx + 'MAX'] = row[stat + 'MAX_' + name]
    return record


def load_all_pairs( fixed, v_list, line, v_template=False, merged=False, path_fn=None ):
    """Collect the summary statistics of every pair in v_list into one DataFrame.

    For each unordered pair (a0, a1) from itertools.combinations(v_list, 2),
    the CSV located by path_fn is read and only its rows with LABEL == -1 are
    used. Each such row yields two output rows: one for (a0, a1) and a mirrored
    one for (a1, a0). As a result, pivoting on the two name columns gives a full
    square matrix.

    Parameters
    ----------
    fixed : str
        Passed to path_fn as its first argument.
    v_list : sequence of str
        Names to compare pairwise.
    line : int or str
        Passed to path_fn.
    v_template : bool
        Name columns are TEMPLATE_0/TEMPLATE_1 when true, ALG_0/ALG_1 otherwise.
        Also forwarded to path_fn.
    merged : bool
        Forwarded to path_fn.
    path_fn : callable, optional
        path_fn(fixed, a0, a1, line, merged=..., v_template=...) -> CSV path.
        Defaults to get_data_f.

    Returns
    -------
    pandas.DataFrame
        Columns: the two name columns, MEAN_0/1, COUNT_0/1, PVAL, TSTAT,
        KRUSKAL, KRUSKALP, WILCOXON, WILCOXONP, COHEN, then MN/VR/SD with
        '', MIN and MAX suffixes for index 0 and then index 1. The columns are
        present even when there are no rows.
    """
    prefix = 'TEMPLATE_' if v_template else 'ALG_'
    if path_fn is None:
        path_fn = get_data_f
    columns = [ prefix + '0', prefix + '1', 'MEAN_0', 'MEAN_1', 'COUNT_0', 'COUNT_1',
                'PVAL', 'TSTAT', 'KRUSKAL', 'KRUSKALP', 'WILCOXON', 'WILCOXONP', 'COHEN' ]
    columns += [ stat + idx + suffix
                 for idx in ( '0', '1' )
                 for stat in ( 'MN', 'VR', 'SD' )
                 for suffix in ( '', 'MIN', 'MAX' ) ]

    # Build a list of records and create the frame once, instead of keeping
    # ~30 parallel lists in sync by hand.
    records = []
    for a0, a1 in itertools.combinations( v_list, 2 ):
        f = path_fn( fixed, a0, a1, line, merged=merged, v_template=v_template )
        df_wlabels = pd.read_csv( f, index_col=0 )
        summary_rows = df_wlabels[ df_wlabels.LABEL == -1 ]
        for _, row in summary_rows.iterrows():
            records.append( _pair_record( row, a0, a1, prefix, reverse=False ) )
            records.append( _pair_record( row, a1, a0, prefix, reverse=True ) )
    return pd.DataFrame( records, columns=columns )
In [52]:
# Compare every pair of algorithms in alg_list against the same template.
template, line = 'F-antsFlip_lo', 3
merged_dist_df = load_all_pairs( template, alg_list, line, v_template=False )
In [53]:
# Signed difference of the mean distances for every ordered pair. This is
# vectorized and equivalent to applying mean_diff(['0', '1']) row by row.
merged_dist_df['MEANDIFF'] = merged_dist_df['MEAN_0'] - merged_dist_df['MEAN_1']
# Diverging palette, also reused by the heatmaps further down.
cmap = sns.diverging_palette(200, 15, sep=20, as_cmap=True)
piv_md = merged_dist_df.pivot( index='ALG_0', columns='ALG_1', values='MEANDIFF' )
# center=0 pins the neutral palette color to "no difference".
ax_md = sns.heatmap( piv_md, cmap=cmap, center=0,
                     square=True, linewidth=0.5 )
ax_md.set_title( 'MEAN_0 - MEAN_1 (row algorithm minus column algorithm)' );
Out[53]:
In [45]:
# PVAL for every ordered pair; load_all_pairs copies it to both orders, so the matrix is symmetric.
piv_pv = merged_dist_df.pivot( index='ALG_0', columns='ALG_1', values='PVAL' )
sns.heatmap( piv_pv,
             cmap=cmap, square=True, linewidth=0.5 )
Out[45]:
In [46]:
# Gaps between the [MIN, MAX] ranges of the two mean statistics (vectorized):
#   MN_MIN_MAX = MN0MIN - MN1MAX ; assuming MIN <= MAX, > 0 means ALG_0's range lies entirely above ALG_1's
#   MN_MAX_MIN = MN0MAX - MN1MIN
merged_dist_df['MN_MIN_MAX'] = merged_dist_df['MN0MIN'] - merged_dist_df['MN1MAX']
merged_dist_df['MN_MAX_MIN'] = merged_dist_df['MN0MAX'] - merged_dist_df['MN1MIN']
piv_mmm = merged_dist_df.pivot( index='ALG_0', columns='ALG_1', values='MN_MIN_MAX' )
# The sign is what matters here, so center the diverging palette on 0
# (without it, seaborn centers on the data's midrange).
ax_mmm = sns.heatmap( piv_mmm, cmap=cmap, center=0,
                      square=True, linewidth=0.5 )
ax_mmm.set_title( 'MN0MIN - MN1MAX (row algorithm vs column algorithm)' );
Out[46]:
In [47]:
# Effect size (COHEN) per ordered pair; load_all_pairs negates it for the reversed order.
piv_effsz = merged_dist_df.pivot( index='ALG_0', columns='ALG_1', values='COHEN' )
sns.heatmap( piv_effsz,
             cmap=cmap, square=True, linewidth=0.5 )
Out[47]:
In [54]:
# Same pairwise comparison, but across templates for a single algorithm.
alg, line = 'cmtkCOG', 3
template_list = [ 'JFRCtemplate2010', 'JFRC2013_lo', 'F-antsFlip_lo', 'TeforBrain_f' ]
merged_template_df = load_all_pairs( alg, template_list, line, v_template=True )
merged_template_df
Out[54]:
In [56]:
# Note: despite its name, this pivot holds the COHEN effect size, not a mean difference.
piv_template_md = merged_template_df.pivot( index='TEMPLATE_0', columns='TEMPLATE_1', values='COHEN' )
sns.heatmap( piv_template_md,
             cmap=cmap, square=True, linewidth=0.5 )
Out[56]:
In [61]:
# Runtime / memory measurements, one row per (TEMPLATE, ALG, THREADS) run.
TIME_MEM_CSV = '/nrs/saalfeld/john/projects/flyChemStainAtlas/all_evals/time_mem_data.csv'
time_mem_df = pd.read_csv( TIME_MEM_CSV,
                           dtype={'ALG':str, 'TEMPLATE':str, 'THREADS':np.int32,
                                  'TIME':np.float32, 'AVGMEM':np.float32, 'MAXMEM':np.float32})
# Experiment label '<TEMPLATE>,<ALG>'. This is vectorized and, unlike a row-wise
# apply, also works on an empty table. astype(str) renders missing values as
# 'nan', just like str_cat does.
time_mem_df['EXP'] = time_mem_df['TEMPLATE'].astype(str) + ',' + time_mem_df['ALG'].astype(str)
In [45]:
def mult_col( colA, colB ):
    """Build a row function returning row[colA] * row[colB]."""
    def mult(x):
        return x[colA]*x[colB]
    return mult

# CPUTIME = TIME * THREADS, computed column-wise instead of one row at a time
# (same float64 result as the row-wise apply).
time_mem_df['CPUTIME'] = time_mem_df['TIME'] * time_mem_df['THREADS']
In [57]:
# Mean TIME, CPUTIME, AVGMEM and MAXMEM per algorithm, one panel each.
# sns.barplot draws directly onto the given Axes. The deprecated
# factorplot(..., ax=...) also created a separate figure per call that had to be closed.
time_mem_metrics = ['TIME', 'CPUTIME', 'AVGMEM', 'MAXMEM']
fig, axes = plt.subplots( len(time_mem_metrics), 1, figsize=(12, 21), sharex=False )
for ax, metric in zip( axes, time_mem_metrics ):
    sns.barplot( x='ALG', y=metric, data=time_mem_df, ax=ax )
    ax.set_title( metric )
    # Rotating via tick_params works per-axes, so it also works with subplots.
    ax.tick_params( axis='x', labelrotation=30 )