In [ ]:
import sys
import glob
import re
import fnmatch
import math
import re
import os
from os import listdir
from os.path import join, isfile, basename

import itertools

import numpy as np
from numpy import float32, int32, uint8, dtype, genfromtxt

from scipy.stats import ttest_ind

import pandas as pd

import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, LogLocator, FormatStrFormatter

import seaborn as sns

import colorsys

import template_common as tc

In [ ]:
# Widen the notebook's main container to 95% of the browser window.
# `display`/`HTML` come from IPython.display; importing `display` from
# IPython.core.display is deprecated since IPython 7.14.
from IPython.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

In [ ]:
# The CPU-time table (table_cputime.csv) loaded in the next cell is generated
# from the raw timing data elsewhere; see Vis_Pairwise_alg-temp for how.

In [ ]:
# Paths: CPU-time table, per-template distance statistics, and output locations.
time_table_f='/groups/saalfeld/home/bogovicj/pubDrafts/grpDrosTemplate/grpDrosTemplate/table_cputime.csv'
time_table_raw = pd.read_csv( time_table_f )

base_dir='/nrs/saalfeld/john/projects/flyChemStainAtlas/all_evals/distanceStatsWarpNorm'
dest_dir = '/groups/saalfeld/home/bogovicj/pubDrafts/grpDrosTemplate/grpDrosTemplate/tables'
fig_dir = '/groups/saalfeld/home/bogovicj/pubDrafts/grpDrosTemplate/grpDrosTemplate/figs'

alg_list = ['antsRegDog', 'antsRegOwl', 'antsRegYang', 'cmtkCOG', 'cmtkCow', 'cmtkHideo']
# template_list = [ 'JFRC2013_lo', 'JFRCtemplate2010', 'TeforBrain_f', 'F-antsFlip_lo', 'F-cmtkFlip_lof', 'FCWB']
template_list = [ 'JFRC2013_lo', 'JFRCtemplate2010', 'TeforBrain_f', 'F-antsFlip_lo', 'FCWB']

# Load the distance statistics: every CSV in base_dir, each with a two-row
# column header. Sorting makes the row order reproducible (glob order is
# filesystem-dependent).
stat_files = sorted( glob.glob( join( base_dir, '*.csv' ) ) )
if not stat_files:
    raise FileNotFoundError( 'No distance-stat CSVs found in %s' % base_dir )

frames = []
for f in stat_files:
    print( f )
    frames.append( pd.read_csv( f, header=[0,1], index_col=0 ) )

# DataFrame.append was removed in pandas 2.0; concatenate once instead (also
# avoids quadratic copying). ignore_index gives every row a unique label.
df = pd.concat( frames, ignore_index=True )

# Collapse the two-level header: keep the top-level name where the second
# level is a pandas 'Unnamed...' placeholder, otherwise keep the second level.
# Same rule as clean_cols, done inline because clean_cols is defined in a
# later cell and is not available yet on a fresh kernel.
df.columns = [ c[0] if c[1].startswith('Unnamed') else c[1] for c in df.columns ]

# Standard deviations from the variance columns.
df['std'] = np.sqrt( df['var'] )
df['gam_std'] = np.sqrt( df['gam_var'] )
df['ray_std'] = np.sqrt( df['ray_var'] )
df.head()

In [ ]:
# Label ids, kept in this specific (non-sorted) order.
labels = [
    16, 64, 8, 32, 2, 4, 65, 66, 33, 67,
    34, 17, 69, 70, 35, 71, 9, 18, 72, 36,
    73, 74, 37, 75, 19, 76, 38, 77, 39, 78,
    79, 20, 5, 40, 80, 10, 81, 82, 83, 84,
    85, 86, 11, 22, 23, 24, 12, 3, 6, 49,
    50, 25, 51, 13, 52, 26, 53, 27, 54, 55,
    56, 28, 7, 14, 57, 58, 29, 59, 30, 60,
    15, 61, 31, 62, 63,
]
# Lookup table with the 'Stack id' and 'JFRCtempate2010.mask130819' columns
# used by get_label_name; tab-separated, first row is the header.
label_names_file = '/groups/saalfeld/home/bogovicj/vfb/DrosAdultBRAINdomains/refData/Original_Index.tsv'
label_names = pd.read_csv( label_names_file, sep='\t', header=0 )

def get_label_name( label_id, table=None ):
    """Return the name stored for the label whose 'Stack id' equals ``label_id``.

    Parameters
    ----------
    label_id : int
        Value to look up in the 'Stack id' column.
    table : pandas.DataFrame, optional
        Lookup table with 'Stack id' and 'JFRCtempate2010.mask130819' columns.
        Defaults to the module-level ``label_names``.

    Returns
    -------
    The 'JFRCtempate2010.mask130819' entry of the first matching row.

    Raises
    ------
    IndexError
        If no row has the given 'Stack id'.
    """
    if table is None:
        table = label_names
    # Column key kept byte-for-byte ('tempate' sic); it must match the file header.
    matches = table.loc[ table['Stack id'] == label_id, 'JFRCtempate2010.mask130819' ]
    if matches.empty:
        raise IndexError( 'No label with Stack id %r' % (label_id,) )
    return matches.iloc[0]

def clean_cols( df ):
    """Flatten a two-level column header of ``df`` in place.

    For each ``(top, sub, ...)`` column keep ``top`` when ``sub`` is a pandas
    'Unnamed...' placeholder, otherwise keep ``sub``. Columns that are already
    flat (not tuples) are left untouched, so calling this twice is harmless;
    without that guard a second call indexes into the names themselves and
    turns e.g. 'ALG' into 'L'.
    """
    df.columns = [
        ( c[0] if c[1].startswith('Unnamed') else c[1] ) if isinstance( c, tuple ) else c
        for c in df.columns
    ]
    
def flatten_heir_cols( df ):
    """Flatten hierarchical (tuple) column names of ``df`` in place.

    The levels of each tuple column are joined with '_', e.g.
    ('dist', 'mean') -> 'dist_mean'. Non-string levels are converted with
    ``str`` first; plain ``'_'.join`` would raise TypeError on them. Columns
    that are already flat are left untouched, so repeated calls do not split
    names into characters ('abc' -> 'a_b_c').
    """
    df.columns = [
        '_'.join( str( level ) for level in c ) if isinstance( c, tuple ) else c
        for c in df.columns
    ]
    
# Plot colour for each template display name (keys match the TEMPLATE values
# looked up when plotting).
template_color_map = dict( [
    ( 'JFRC2010',       'firebrick' ),
    ( 'JFRC2013',       'navy' ),
    ( 'FCWB',           'darkgreen' ),
    ( 'Tefor',          'darkorchid' ),
    ( 'JRC2018',        'black' ),
    ( 'CMTK groupwise', 'darkorange' ),   # formerly 'gray'
] )

In [ ]:


In [ ]:
# Keep only the templates listed in template_list.
tmask = df['TEMPLATE'].isin( template_list )
df = df.loc[tmask]

# Mean distance per (template, algorithm): rows with LABEL == -1, excluding
# the 'ALL' algorithm entry. .copy() makes mean_table an independent frame so
# the column assignments below don't write into a slice of df
# (SettingWithCopyWarning).
mean_table = df.loc[ (df['LABEL'] == -1) & (df['ALG'] != 'ALL'), ['ALG','TEMPLATE','mean'] ].copy()
mean_table['TEMPLATE'] = mean_table['TEMPLATE'].map( tc.template_name )
mean_table['ALG'] = mean_table['ALG'].map( tc.alg_name )
# Join key shared with the CPU-time table: "<template>:<algorithm>".
mean_table['TA'] = mean_table['TEMPLATE'] + ':' + mean_table['ALG']
mean_table

In [ ]:
# Filter the time table: drop the 'ANTs Wolf' runs, convert raw template and
# algorithm identifiers to display names, and drop the CMTK groupwise template.

# Raw string: a backslash-s in a plain literal is an invalid escape sequence
# (DeprecationWarning; SyntaxWarning from Python 3.12).
regexp = re.compile( r'\s+ANTs Wolf' )
keep = time_table_raw['ALG'].map( lambda alg: regexp.search( alg ) is None )

# .copy() so the column assignments below modify an independent frame rather
# than a slice of time_table_raw (SettingWithCopyWarning).
time_table = time_table_raw.loc[ keep ].copy()
time_table['TEMPLATE'] = time_table['TEMPLATE'].map( tc.template_name )
time_table['ALG'] = time_table['ALG'].map( tc.alg_name )

time_table = time_table.loc[ time_table['TEMPLATE'] != 'CMTK groupwise' ].copy()
# Join key shared with mean_table: "<template>:<algorithm>".
time_table['TA'] = time_table['TEMPLATE'] + ':' + time_table['ALG']
time_table

In [ ]:
# Combine the tables: attach each (template, algorithm) pair's CPU time to its
# mean distance via the shared "<template>:<algorithm>" key ('TA').

# Every pair needs a timing entry. With a left join an unmatched pair would
# get NaN ALG/TEMPLATE and crash the plotting loop further down, so fail here
# with the offending keys instead.
unmatched = mean_table.loc[ ~mean_table['TA'].isin( time_table['TA'] ), 'TA' ]
assert unmatched.empty, 'No CPU time for: %s' % ', '.join( unmatched )

# Take only CPUTIME from time_table: ALG/TEMPLATE come from mean_table and no
# other column names can collide.
mean_time_table = mean_table.merge( time_table[['TA', 'CPUTIME']], on='TA', how='left' )
mean_time_table = mean_time_table[['ALG','TEMPLATE','CPUTIME','mean']]
mean_time_table

In [ ]:
# Speed vs. quality for every template/algorithm pair: CPU time in hours
# (log scale) against mean distance, coloured by template.
fig, ax = plt.subplots( figsize=(16, 10) )

for _, row in mean_time_table.iterrows():
    hours = row['CPUTIME'] / 3600.   # CPUTIME is in seconds
    # NOTE(review): ALG/TEMPLATE already hold display names (mapped when the
    # tables were built) and go through tc.* again, as before -- presumably
    # those helpers are idempotent; confirm.
    label = "   " + tc.alg_name( row['ALG'].strip(' ') ) + " : " + tc.template_name( row['TEMPLATE'].lstrip(' ') )
    color = template_color_map[row['TEMPLATE']]
    ax.annotate( label, (hours, row['mean']), color=color, size=13 )
    ax.scatter( hours, row['mean'], color=color )

ax.set_xscale('log')
ax.set_xlabel('CPU hours', size=18)
ax.set_ylabel('Mean distance (um)', size=18)
ax.tick_params( axis='both', which='major', labelsize=16 )

ax.yaxis.set_minor_locator( MultipleLocator(0.2) )
ax.grid( which='minor', linestyle=':', dashes=(3,3) )

# x range: 10,000 s to 1,200,000 s of CPU time, expressed in hours.
ax.set_xlim( 10000. / 3600, 1200000. / 3600 )

fout_prefix = join( fig_dir, 'speed_quality_20180531' )
for ext in ( 'svg', 'pdf', 'png' ):
    fig.savefig( fout_prefix + '.' + ext )

In [ ]:
# Only the 'best' results, i.e. mean distance below 4.1 um, plotted over the
# same CPU-time range as the full figure.
goodtable = mean_time_table[ mean_time_table['mean'] < 4.1 ]

fig, ax = plt.subplots( figsize=(16, 5) )

for _, row in goodtable.iterrows():
    hours = row['CPUTIME'] / 3600.   # CPUTIME is in seconds
    # NOTE(review): ALG/TEMPLATE already hold display names (mapped when the
    # tables were built) and go through tc.* again, as before -- presumably
    # those helpers are idempotent; confirm.
    label = "   " + tc.alg_name( row['ALG'].strip(' ') ) + " : " + tc.template_name( row['TEMPLATE'].lstrip(' ') )
    color = template_color_map[row['TEMPLATE']]
    ax.annotate( label, (hours, row['mean']), color=color, size=13 )
    ax.scatter( hours, row['mean'], color=color )

ax.set_xscale('log')
ax.set_xlabel('CPU hours')
ax.set_ylabel('Mean distance (um)')
ax.tick_params( axis='both', which='major', labelsize=16 )

# Finer minor grid than the full plot: the best results span a narrow y range.
ax.yaxis.set_minor_locator( MultipleLocator(0.02) )

# x range: 10,000 s to 1,200,000 s of CPU time, expressed in hours.
ax.set_xlim( 10000 / 3600., 1200000 / 3600. )
ax.grid( which='minor', linestyle=':', dashes=(3,3) )

fout_prefix = join( fig_dir, 'speed_quality_best_20180531' )
for ext in ( 'svg', 'pdf', 'png' ):
    fig.savefig( fout_prefix + '.' + ext )

In [ ]: