Temporal Vegetation Modelling using LSTM Neural Networks

Evaluation of Experiments

In this notebook the experiments are evaluated for the CVPRWS paper "Temporal Vegetation Modelling using Long Short-Term Memory Networks for Crop Identification from Medium-Resolution Multi-Spectral Satellite Images"

Set up envornmental variables:


In [4]:
import tensorflow as tf
import os
import re

savedir = "models"

networks = ['lstm', 'rnn', 'cnn']

r = [2,4,6,8]
l = [2,3,4]
f = [0,1,2,3,4,5,6,7,9]

image_filepath = os.path.join("doc","2017_CVPRWS","tmp","images")
table_filepath = os.path.join("doc","2017_CVPRWS","tmp","tables")

if not os.path.exists(image_filepath):
    os.makedirs(image_filepath)
    
if not os.path.exists(table_filepath):
    os.makedirs(table_filepath)

dataset = "test"
rolling_mean_window = 10

def params2name(l,r,f):
    return "{}l{}r50d{}f".format(l,r,f)

def name2param(name):
    l,r,d,f = re.findall('\d+', name)
    return l,r,d,f

runs = [] # expected runs
for r_ in r:
    for l_ in l:
        for f_ in f:
            runs.append(params2name(l_,r_,f_))

def fix_typos(classes):
    classes[3] = classes[3].replace("_"," ") # cloud_shadow -> cloud shadow
    classes[22] = "sugar beet" # org. sugar beets
    return classes

extract data from Tensorflow event files:


In [5]:
def extract_from_eventfile(eventfile_path, tag='train/cross_entropy'):
    steps = []
    values = []

    for e in tf.train.summary_iterator(eventfile_path):
        for v in e.summary.value:
            if v.tag == tag:
                steps.append(e.step)
                values.append(v.simple_value)

    return steps, values

def extract_from_all_eventfiles(path, tag='train/cross_entropy'):
    """ 
    appends values from all event files in one folder
    if path does not exist: returns empty list
    """
    steps = []
    values = []
    
    if os.path.exists(path):
        eventfiles = os.listdir(path)
    else:
        return steps, values # empty

    for eventfile in eventfiles:
        steps_,values_ = extract_from_eventfile(os.path.join(path,eventfile), tag=tag)
        steps.extend(steps_)
        values.extend(values_)
        
    return steps, values

def gather_data_from_multiple_runs(folder_path, runs, dataset="test", tag='train/cross_entropy'):
    """
    look through all save folders defined by runs, extract eventfiles from runs
    and append everything to pandas dataframe
    """

    series = []
    for run in runs:
        path = os.path.join(folder_path,run,dataset)
        steps,values = extract_from_all_eventfiles(path, tag)
        print("run: {} extracted {} values".format(run,len(values)))   

        if len(values) > 0:

            s = pd.Series(data=values, index=steps,name=run).sort_index()
            # drop duplicates
            s = s[~s.index.duplicated(keep='last')]

            #.drop_duplicates(keep='last')
            series.append(s)
            
    return pd.concat(series,axis=1,join="outer")

walk through all event files and extract cross entropy:


In [6]:
import pandas as pd
import numpy as np

def smooth_and_interpolate(data, rolling_mean_window = 10):
    data.interpolate(axis=1,inplace=True)
    return data.rolling(window=rolling_mean_window,axis=0).mean()

def get_best_run(data,max_is_better=False):
    scores = []
    for col in data.columns: 
        s = data[col]
        scores.append(s.loc[s.last_valid_index()])
    
    if max_is_better:
        return data.columns[np.array(scores).argmax()]
    else:
        return data.columns[np.array(scores).argmin()]

datasets = []
best_runs = []

# cross entropy
max_is_better = False
tag='train/cross_entropy'
# for lstm/rnn/cnn
for network in networks: 
    print
    print network
    d = gather_data_from_multiple_runs(os.path.join(savedir,network),runs,dataset=dataset,tag=tag)
    best_runs.append(get_best_run(d, max_is_better))
    d = smooth_and_interpolate(d,rolling_mean_window)
    datasets.append(d)


lstm
run: 2l2r50d0f extracted 742 values
run: 2l2r50d1f extracted 720 values
run: 2l2r50d2f extracted 738 values
run: 2l2r50d3f extracted 363 values
run: 2l2r50d4f extracted 342 values
run: 2l2r50d5f extracted 352 values
run: 2l2r50d6f extracted 370 values
run: 2l2r50d7f extracted 348 values
run: 2l2r50d9f extracted 352 values
run: 3l2r50d0f extracted 718 values
run: 3l2r50d1f extracted 720 values
run: 3l2r50d2f extracted 738 values
run: 3l2r50d3f extracted 363 values
run: 3l2r50d4f extracted 342 values
run: 3l2r50d5f extracted 352 values
run: 3l2r50d6f extracted 370 values
run: 3l2r50d7f extracted 348 values
run: 3l2r50d9f extracted 352 values
run: 4l2r50d0f extracted 718 values
run: 4l2r50d1f extracted 720 values
run: 4l2r50d2f extracted 738 values
run: 4l2r50d3f extracted 363 values
run: 4l2r50d4f extracted 342 values
run: 4l2r50d5f extracted 352 values
run: 4l2r50d6f extracted 370 values
run: 4l2r50d7f extracted 348 values
run: 4l2r50d9f extracted 352 values
run: 2l4r50d0f extracted 718 values
run: 2l4r50d1f extracted 720 values
run: 2l4r50d2f extracted 762 values
run: 2l4r50d3f extracted 363 values
run: 2l4r50d4f extracted 342 values
run: 2l4r50d5f extracted 352 values
run: 2l4r50d6f extracted 370 values
run: 2l4r50d7f extracted 348 values
run: 2l4r50d9f extracted 352 values
run: 3l4r50d0f extracted 718 values
run: 3l4r50d1f extracted 720 values
run: 3l4r50d2f extracted 738 values
run: 3l4r50d3f extracted 363 values
run: 3l4r50d4f extracted 342 values
run: 3l4r50d5f extracted 352 values
run: 3l4r50d6f extracted 370 values
run: 3l4r50d7f extracted 348 values
run: 3l4r50d9f extracted 352 values
run: 4l4r50d0f extracted 718 values
run: 4l4r50d1f extracted 720 values
run: 4l4r50d2f extracted 738 values
run: 4l4r50d3f extracted 363 values
run: 4l4r50d4f extracted 342 values
run: 4l4r50d5f extracted 352 values
run: 4l4r50d6f extracted 370 values
run: 4l4r50d7f extracted 348 values
run: 4l4r50d9f extracted 352 values
run: 2l6r50d0f extracted 718 values
run: 2l6r50d1f extracted 720 values
run: 2l6r50d2f extracted 738 values
run: 2l6r50d3f extracted 363 values
run: 2l6r50d4f extracted 342 values
run: 2l6r50d5f extracted 352 values
run: 2l6r50d6f extracted 370 values
run: 2l6r50d7f extracted 348 values
run: 2l6r50d9f extracted 352 values
run: 3l6r50d0f extracted 718 values
run: 3l6r50d1f extracted 720 values
run: 3l6r50d2f extracted 738 values
run: 3l6r50d3f extracted 363 values
run: 3l6r50d4f extracted 342 values
run: 3l6r50d5f extracted 352 values
run: 3l6r50d6f extracted 370 values
run: 3l6r50d7f extracted 348 values
run: 3l6r50d9f extracted 352 values
run: 4l6r50d0f extracted 718 values
run: 4l6r50d1f extracted 720 values
run: 4l6r50d2f extracted 738 values
run: 4l6r50d3f extracted 363 values
run: 4l6r50d4f extracted 342 values
run: 4l6r50d5f extracted 352 values
run: 4l6r50d6f extracted 370 values
run: 4l6r50d7f extracted 348 values
run: 4l6r50d9f extracted 352 values
run: 2l8r50d0f extracted 718 values
run: 2l8r50d1f extracted 720 values
run: 2l8r50d2f extracted 738 values
run: 2l8r50d3f extracted 363 values
run: 2l8r50d4f extracted 342 values
run: 2l8r50d5f extracted 352 values
run: 2l8r50d6f extracted 370 values
run: 2l8r50d7f extracted 348 values
run: 2l8r50d9f extracted 352 values
run: 3l8r50d0f extracted 718 values
run: 3l8r50d1f extracted 720 values
run: 3l8r50d2f extracted 738 values
run: 3l8r50d3f extracted 363 values
run: 3l8r50d4f extracted 342 values
run: 3l8r50d5f extracted 352 values
run: 3l8r50d6f extracted 370 values
run: 3l8r50d7f extracted 348 values
run: 3l8r50d9f extracted 352 values
run: 4l8r50d0f extracted 718 values
run: 4l8r50d1f extracted 720 values
run: 4l8r50d2f extracted 738 values
run: 4l8r50d3f extracted 363 values
run: 4l8r50d4f extracted 342 values
run: 4l8r50d5f extracted 352 values
run: 4l8r50d6f extracted 370 values
run: 4l8r50d7f extracted 348 values
run: 4l8r50d9f extracted 352 values

rnn
run: 2l2r50d0f extracted 718 values
run: 2l2r50d1f extracted 754 values
run: 2l2r50d2f extracted 738 values
run: 2l2r50d3f extracted 363 values
run: 2l2r50d4f extracted 342 values
run: 2l2r50d5f extracted 352 values
run: 2l2r50d6f extracted 370 values
run: 2l2r50d7f extracted 348 values
run: 2l2r50d9f extracted 352 values
run: 3l2r50d0f extracted 718 values
run: 3l2r50d1f extracted 743 values
run: 3l2r50d2f extracted 768 values
run: 3l2r50d3f extracted 363 values
run: 3l2r50d4f extracted 342 values
run: 3l2r50d5f extracted 352 values
run: 3l2r50d6f extracted 370 values
run: 3l2r50d7f extracted 348 values
run: 3l2r50d9f extracted 352 values
run: 4l2r50d0f extracted 718 values
run: 4l2r50d1f extracted 720 values
run: 4l2r50d2f extracted 738 values
run: 4l2r50d3f extracted 363 values
run: 4l2r50d4f extracted 342 values
run: 4l2r50d5f extracted 352 values
run: 4l2r50d6f extracted 370 values
run: 4l2r50d7f extracted 348 values
run: 4l2r50d9f extracted 352 values
run: 2l4r50d0f extracted 718 values
run: 2l4r50d1f extracted 720 values
run: 2l4r50d2f extracted 738 values
run: 2l4r50d3f extracted 363 values
run: 2l4r50d4f extracted 342 values
run: 2l4r50d5f extracted 352 values
run: 2l4r50d6f extracted 370 values
run: 2l4r50d7f extracted 348 values
run: 2l4r50d9f extracted 352 values
run: 3l4r50d0f extracted 718 values
run: 3l4r50d1f extracted 720 values
run: 3l4r50d2f extracted 753 values
run: 3l4r50d3f extracted 363 values
run: 3l4r50d4f extracted 342 values
run: 3l4r50d5f extracted 352 values
run: 3l4r50d6f extracted 370 values
run: 3l4r50d7f extracted 348 values
run: 3l4r50d9f extracted 352 values
run: 4l4r50d0f extracted 718 values
run: 4l4r50d1f extracted 720 values
run: 4l4r50d2f extracted 738 values
run: 4l4r50d3f extracted 363 values
run: 4l4r50d4f extracted 342 values
run: 4l4r50d5f extracted 352 values
run: 4l4r50d6f extracted 370 values
run: 4l4r50d7f extracted 348 values
run: 4l4r50d9f extracted 352 values
run: 2l6r50d0f extracted 737 values
run: 2l6r50d1f extracted 720 values
run: 2l6r50d2f extracted 756 values
run: 2l6r50d3f extracted 363 values
run: 2l6r50d4f extracted 342 values
run: 2l6r50d5f extracted 352 values
run: 2l6r50d6f extracted 370 values
run: 2l6r50d7f extracted 348 values
run: 2l6r50d9f extracted 352 values
run: 3l6r50d0f extracted 718 values
run: 3l6r50d1f extracted 720 values
run: 3l6r50d2f extracted 738 values
run: 3l6r50d3f extracted 363 values
run: 3l6r50d4f extracted 342 values
run: 3l6r50d5f extracted 352 values
run: 3l6r50d6f extracted 370 values
run: 3l6r50d7f extracted 348 values
run: 3l6r50d9f extracted 352 values
run: 4l6r50d0f extracted 718 values
run: 4l6r50d1f extracted 720 values
run: 4l6r50d2f extracted 738 values
run: 4l6r50d3f extracted 363 values
run: 4l6r50d4f extracted 342 values
run: 4l6r50d5f extracted 352 values
run: 4l6r50d6f extracted 370 values
run: 4l6r50d7f extracted 348 values
run: 4l6r50d9f extracted 352 values
run: 2l8r50d0f extracted 718 values
run: 2l8r50d1f extracted 720 values
run: 2l8r50d2f extracted 738 values
run: 2l8r50d3f extracted 363 values
run: 2l8r50d4f extracted 342 values
run: 2l8r50d5f extracted 352 values
run: 2l8r50d6f extracted 370 values
run: 2l8r50d7f extracted 348 values
run: 2l8r50d9f extracted 352 values
run: 3l8r50d0f extracted 718 values
run: 3l8r50d1f extracted 720 values
run: 3l8r50d2f extracted 738 values
run: 3l8r50d3f extracted 363 values
run: 3l8r50d4f extracted 342 values
run: 3l8r50d5f extracted 352 values
run: 3l8r50d6f extracted 370 values
run: 3l8r50d7f extracted 348 values
run: 3l8r50d9f extracted 352 values
run: 4l8r50d0f extracted 718 values
run: 4l8r50d1f extracted 720 values
run: 4l8r50d2f extracted 738 values
run: 4l8r50d3f extracted 363 values
run: 4l8r50d4f extracted 342 values
run: 4l8r50d5f extracted 352 values
run: 4l8r50d6f extracted 370 values
run: 4l8r50d7f extracted 348 values
run: 4l8r50d9f extracted 352 values

cnn
run: 2l2r50d0f extracted 718 values
run: 2l2r50d1f extracted 720 values
run: 2l2r50d2f extracted 738 values
run: 2l2r50d3f extracted 363 values
run: 2l2r50d4f extracted 342 values
run: 2l2r50d5f extracted 352 values
run: 2l2r50d6f extracted 370 values
run: 2l2r50d7f extracted 348 values
run: 2l2r50d9f extracted 352 values
run: 3l2r50d0f extracted 742 values
run: 3l2r50d1f extracted 737 values
run: 3l2r50d2f extracted 759 values
run: 3l2r50d3f extracted 363 values
run: 3l2r50d4f extracted 342 values
run: 3l2r50d5f extracted 352 values
run: 3l2r50d6f extracted 370 values
run: 3l2r50d7f extracted 348 values
run: 3l2r50d9f extracted 352 values
run: 4l2r50d0f extracted 718 values
run: 4l2r50d1f extracted 720 values
run: 4l2r50d2f extracted 738 values
run: 4l2r50d3f extracted 363 values
run: 4l2r50d4f extracted 342 values
run: 4l2r50d5f extracted 352 values
run: 4l2r50d6f extracted 370 values
run: 4l2r50d7f extracted 348 values
run: 4l2r50d9f extracted 352 values
run: 2l4r50d0f extracted 718 values
run: 2l4r50d1f extracted 720 values
run: 2l4r50d2f extracted 738 values
run: 2l4r50d3f extracted 363 values
run: 2l4r50d4f extracted 342 values
run: 2l4r50d5f extracted 352 values
run: 2l4r50d6f extracted 370 values
run: 2l4r50d7f extracted 348 values
run: 2l4r50d9f extracted 352 values
run: 3l4r50d0f extracted 718 values
run: 3l4r50d1f extracted 720 values
run: 3l4r50d2f extracted 738 values
run: 3l4r50d3f extracted 363 values
run: 3l4r50d4f extracted 342 values
run: 3l4r50d5f extracted 352 values
run: 3l4r50d6f extracted 370 values
run: 3l4r50d7f extracted 348 values
run: 3l4r50d9f extracted 352 values
run: 4l4r50d0f extracted 718 values
run: 4l4r50d1f extracted 720 values
run: 4l4r50d2f extracted 738 values
run: 4l4r50d3f extracted 363 values
run: 4l4r50d4f extracted 342 values
run: 4l4r50d5f extracted 352 values
run: 4l4r50d6f extracted 370 values
run: 4l4r50d7f extracted 348 values
run: 4l4r50d9f extracted 352 values
run: 2l6r50d0f extracted 737 values
run: 2l6r50d1f extracted 720 values
run: 2l6r50d2f extracted 738 values
run: 2l6r50d3f extracted 363 values
run: 2l6r50d4f extracted 342 values
run: 2l6r50d5f extracted 352 values
run: 2l6r50d6f extracted 370 values
run: 2l6r50d7f extracted 348 values
run: 2l6r50d9f extracted 352 values
run: 3l6r50d0f extracted 718 values
run: 3l6r50d1f extracted 720 values
run: 3l6r50d2f extracted 738 values
run: 3l6r50d3f extracted 363 values
run: 3l6r50d4f extracted 342 values
run: 3l6r50d5f extracted 352 values
run: 3l6r50d6f extracted 370 values
run: 3l6r50d7f extracted 348 values
run: 3l6r50d9f extracted 352 values
run: 4l6r50d0f extracted 718 values
run: 4l6r50d1f extracted 720 values
run: 4l6r50d2f extracted 738 values
run: 4l6r50d3f extracted 363 values
run: 4l6r50d4f extracted 342 values
run: 4l6r50d5f extracted 352 values
run: 4l6r50d6f extracted 370 values
run: 4l6r50d7f extracted 348 values
run: 4l6r50d9f extracted 352 values
run: 2l8r50d0f extracted 740 values
run: 2l8r50d1f extracted 720 values
run: 2l8r50d2f extracted 738 values
run: 2l8r50d3f extracted 363 values
run: 2l8r50d4f extracted 342 values
run: 2l8r50d5f extracted 352 values
run: 2l8r50d6f extracted 370 values
run: 2l8r50d7f extracted 348 values
run: 2l8r50d9f extracted 352 values
run: 3l8r50d0f extracted 718 values
run: 3l8r50d1f extracted 720 values
run: 3l8r50d2f extracted 738 values
run: 3l8r50d3f extracted 363 values
run: 3l8r50d4f extracted 342 values
run: 3l8r50d5f extracted 352 values
run: 3l8r50d6f extracted 370 values
run: 3l8r50d7f extracted 348 values
run: 3l8r50d9f extracted 352 values
run: 4l8r50d0f extracted 718 values
run: 4l8r50d1f extracted 720 values
run: 4l8r50d2f extracted 738 values
run: 4l8r50d3f extracted 363 values
run: 4l8r50d4f extracted 342 values
run: 4l8r50d5f extracted 352 values
run: 4l8r50d6f extracted 370 values
run: 4l8r50d7f extracted 348 values
run: 4l8r50d9f extracted 352 values

best runs


In [7]:
for best_run, network in zip(best_runs, networks):
    print "Network {}: best run {}".format(network, best_run)


Network lstm: best run 4l4r50d1f
Network rnn: best run 4l8r50d6f
Network cnn: best run 3l8r50d0f

Plot cross entropy over time for all networks


In [10]:
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
from matplotlib.colors import rgb2hex
%matplotlib inline
import seaborn as sns
import numpy as np

sns.set(context='notebook', style='whitegrid', palette='deep', font='Times', font_scale=1, color_codes=False, rc=None)

tumivory = rgb2hex((218./255, 215./255, 203./255))

tumblue = rgb2hex((0., 101./255, 189./255))
tumgreen = rgb2hex((162./255, 173./255, 0))
tumorange = rgb2hex((227./255, 114./255, 34./255))

tumbluelight=rgb2hex((152./255, 198./255, 234./255))
tumbluemedium=rgb2hex((100./255, 160./255, 200./255))
tumbluedark=rgb2hex((0, 82./255, 147./255))

figsize=(11,9)
xlim = (0,7.5e6)

pdf_filepath = os.path.join(image_filepath,"shadedplot.pdf")
tikz_filepath = os.path.join(image_filepath,"shadedplot.tikz")


def plot_network_runs(data, ax=None, best_run = None, col="#cccccc", std_alpha=0.5, label="LSTM"):
    if ax is None:
        f, ax = plt.subplots()
    
    std = data.std(axis=1)
    
    patch = mpatches.Patch(color=col, label=label)
    
    mean = data.mean(axis=1)
    p_mean, = ax.plot(mean, color=col, linestyle="dashed", label="mean")
    
    #median = data.median(axis=1)
    #ax.plot(median,label="median", color=col, linestyle="dotted")
    
    p_best, = ax.plot(data[best_run], color=col, label="best")

    p_std = ax.fill_between(mean.index, mean-std, mean+std, where=mean+std >= mean-std, 
                            interpolate=True, color=col, alpha=std_alpha, label="std")
    
    
    
    # xlabels
    ax.get_xaxis().set_major_formatter(
    matplotlib.ticker.FuncFormatter(lambda x, p: format(int(x), ',')))
    
    ax.set_xlim(*xlim)
    
    return ax, [patch,p_best,p_mean,p_std], [mean, std, data[best_run]]
    
f, ax = plt.subplots(figsize=figsize)
ax, handles1, dat_cnn = plot_network_runs(datasets[1], 
                       ax=ax, 
                       best_run=best_runs[1], 
                       col=tumgreen, 
                       label="CNN")

ax, handles2, dat_rnn = plot_network_runs(datasets[2], 
                       ax=ax, 
                       best_run=best_runs[2], 
                       col=tumorange, 
                       label="RNN")

ax, handles0, dat_lstm = plot_network_runs(datasets[0], 
                       ax=ax, 
                       best_run=best_runs[0], 
                       col=tumblue, 
                       label="LSTM")


ax.legend(handles=handles0+handles1+handles2,ncol=len(networks))
ax.set_xlabel("features trained")
ax.set_ylabel("cross entropy")

# uncomment to save
if False:
    f.savefig(pdf_filepath,transparent=True)

dat_filepath_lstm = os.path.join(image_filepath,"shadedplot_lstm.dat")
dat_filepath_cnn = os.path.join(image_filepath,"shadedplot_cnn.dat")
dat_filepath_rnn = os.path.join(image_filepath,"shadedplot_rnn.dat")
lstm = pd.DataFrame(dat_lstm, index=["mean", "std", "best"]).transpose()
rnn = pd.DataFrame(dat_rnn, index=["mean", "std", "best"]).transpose()
cnn = pd.DataFrame(dat_cnn, index=["mean", "std", "best"]).transpose()

lstm.dropna().to_csv(dat_filepath_lstm, sep=' ', header=True)
rnn.dropna().to_csv(dat_filepath_cnn, sep=' ', header=True)
cnn.dropna().to_csv(dat_filepath_rnn, sep=' ', header=True)


/home/russwurm/anaconda2/lib/python2.7/site-packages/matplotlib/font_manager.py:1288: UserWarning: findfont: Font family [u'Times'] not found. Falling back to Bitstream Vera Sans
  (prop.get_family(), self.defaultFamily[fontext]))
/home/russwurm/anaconda2/lib/python2.7/site-packages/matplotlib/font_manager.py:1298: UserWarning: findfont: Could not match :family=Bitstream Vera Sans:style=normal:variant=normal:weight=400:stretch=normal:size=10.0. Returning /usr/share/matplotlib/mpl-data/fonts/ttf/cmb10.ttf
  UserWarning)
/home/russwurm/anaconda2/lib/python2.7/site-packages/matplotlib/font_manager.py:1298: UserWarning: findfont: Could not match :family=Bitstream Vera Sans:style=normal:variant=normal:weight=400:stretch=normal:size=11.0. Returning /usr/share/matplotlib/mpl-data/fonts/ttf/cmb10.ttf
  UserWarning)

all networks

plot data from all networks


In [11]:
def plot_all_runs(data, ax=None, col="#cccccc", std_alpha=0.5, label_std=r"1$\sigma$ std", label="best run", title="title"):
    if ax is None:
        f, ax = plt.subplots()
    
    std = data.std(axis=1)
    mean = data.mean(axis=1)

    runs = data.columns
    for run in runs:
        ax.plot(data[run],label=run)

    ax.fill_between(mean.index, mean-std, mean+std, where=mean+std >= mean-std, interpolate=True, facecolor=col, alpha=std_alpha, label=label_std)
    
    # xlabels
    ax.get_xaxis().set_major_formatter(
    matplotlib.ticker.FuncFormatter(lambda x, p: format(int(x), ',')))
    
    ax.set_xlim(0,1e7)
    ax.set_title(title)
    
    return ax

figsize=(15,10)
for data, network in zip(datasets, networks):
    f, ax = plt.subplots(figsize=figsize)
    plot_all_runs(data, ax=ax, col="#cccccc", std_alpha=0.5, label_std=r"1$\sigma$ std", label="best run", title=network)
    ax.legend()


/home/russwurm/anaconda2/lib/python2.7/site-packages/matplotlib/font_manager.py:1298: UserWarning: findfont: Could not match :family=Bitstream Vera Sans:style=normal:variant=normal:weight=400:stretch=normal:size=12.0. Returning /usr/share/matplotlib/mpl-data/fonts/ttf/cmb10.ttf
  UserWarning)

Check completeness of training

Extract indeces of first event and last event for all networks to test if the networks have all trained completely


In [12]:
out = []
for run in runs:
    for network, data in zip(networks, datasets):
        if run in data.columns:
            col = data[run]
            first_valid = col.first_valid_index()
            last_valid = col.last_valid_index()
        else:
            first_valid = None
            last_valid = None
            
        out.append({"network":network, "run":run, "first_valid":first_valid, "last_valid":last_valid})


d = pd.DataFrame(out).set_index("run")
first_valid_by_network = d.pivot(columns='network', values='first_valid')
last_valid_by_network = d.pivot(columns='network', values='last_valid')

indices of first events for each network


In [13]:
first_valid_by_network


Out[13]:
network cnn lstm rnn
run
2l2r50d0f 100000 100000 100000
2l2r50d1f 100000 100000 100000
2l2r50d2f 100000 100000 100000
2l2r50d3f 100000 100000 100000
2l2r50d4f 100000 100000 100000
2l2r50d5f 100000 100000 100000
2l2r50d6f 100000 100000 100000
2l2r50d7f 100000 100000 100000
2l2r50d9f 100000 100000 100000
2l4r50d0f 100000 100000 100000
2l4r50d1f 100000 100000 100000
2l4r50d2f 100000 100000 100000
2l4r50d3f 100000 100000 100000
2l4r50d4f 100000 100000 100000
2l4r50d5f 100000 100000 100000
2l4r50d6f 100000 100000 100000
2l4r50d7f 100000 100000 100000
2l4r50d9f 100000 100000 100000
2l6r50d0f 100000 100000 100000
2l6r50d1f 100000 100000 100000
2l6r50d2f 100000 100000 100000
2l6r50d3f 100000 100000 100000
2l6r50d4f 100000 100000 100000
2l6r50d5f 100000 100000 100000
2l6r50d6f 100000 100000 100000
2l6r50d7f 100000 100000 100000
2l6r50d9f 100000 100000 100000
2l8r50d0f 100000 100000 100000
2l8r50d1f 100000 100000 100000
2l8r50d2f 100000 100000 100000
... ... ... ...
4l2r50d6f 100000 100000 100000
4l2r50d7f 100000 100000 100000
4l2r50d9f 100000 100000 100000
4l4r50d0f 100000 100000 100000
4l4r50d1f 100000 100000 100000
4l4r50d2f 100000 100000 100000
4l4r50d3f 100000 100000 100000
4l4r50d4f 100000 100000 100000
4l4r50d5f 100000 100000 100000
4l4r50d6f 100000 100000 100000
4l4r50d7f 100000 100000 100000
4l4r50d9f 100000 100000 100000
4l6r50d0f 100000 100000 100000
4l6r50d1f 100000 100000 100000
4l6r50d2f 100000 100000 100000
4l6r50d3f 100000 100000 100000
4l6r50d4f 100000 100000 100000
4l6r50d5f 100000 100000 100000
4l6r50d6f 100000 100000 100000
4l6r50d7f 100000 100000 100000
4l6r50d9f 100000 100000 100000
4l8r50d0f 100000 100000 100000
4l8r50d1f 100000 100000 100000
4l8r50d2f 100000 100000 100000
4l8r50d3f 100000 100000 100000
4l8r50d4f 100000 100000 100000
4l8r50d5f 100000 100000 100000
4l8r50d6f 100000 100000 100000
4l8r50d7f 100000 100000 100000
4l8r50d9f 100000 100000 100000

108 rows × 3 columns

indices of last event


In [14]:
last_valid_by_network


Out[14]:
network cnn lstm rnn
run
2l2r50d0f 7180000 7390000 7180000
2l2r50d1f 7200000 7390000 7490000
2l2r50d2f 7380000 7390000 7490000
2l2r50d3f 7380000 7390000 7490000
2l2r50d4f 7380000 7390000 7490000
2l2r50d5f 7380000 7390000 7490000
2l2r50d6f 7380000 7400000 7490000
2l2r50d7f 7380000 7400000 7490000
2l2r50d9f 7380000 7400000 7490000
2l4r50d0f 7580000 7400000 7640000
2l4r50d1f 7580000 7400000 7640000
2l4r50d2f 7580000 7580000 7640000
2l4r50d3f 7580000 7580000 7640000
2l4r50d4f 7580000 7580000 7640000
2l4r50d5f 7580000 7580000 7640000
2l4r50d6f 7580000 7580000 7640000
2l4r50d7f 7580000 7580000 7640000
2l4r50d9f 7580000 7580000 7640000
2l6r50d0f 7580000 7580000 7640000
2l6r50d1f 7580000 7580000 7640000
2l6r50d2f 7580000 7580000 7640000
2l6r50d3f 7580000 7580000 7640000
2l6r50d4f 7580000 7580000 7640000
2l6r50d5f 7580000 7580000 7640000
2l6r50d6f 7580000 7580000 7640000
2l6r50d7f 7580000 7580000 7640000
2l6r50d9f 7580000 7580000 7640000
2l8r50d0f 7580000 7580000 7640000
2l8r50d1f 7580000 7580000 7640000
2l8r50d2f 7580000 7580000 7640000
... ... ... ...
4l2r50d6f 7580000 7400000 7640000
4l2r50d7f 7580000 7400000 7640000
4l2r50d9f 7580000 7400000 7640000
4l4r50d0f 7580000 7580000 7640000
4l4r50d1f 7580000 7580000 7640000
4l4r50d2f 7580000 7580000 7640000
4l4r50d3f 7580000 7580000 7640000
4l4r50d4f 7580000 7580000 7640000
4l4r50d5f 7580000 7580000 7640000
4l4r50d6f 7580000 7580000 7640000
4l4r50d7f 7580000 7580000 7640000
4l4r50d9f 7580000 7580000 7640000
4l6r50d0f 7580000 7580000 7640000
4l6r50d1f 7580000 7580000 7640000
4l6r50d2f 7580000 7580000 7640000
4l6r50d3f 7580000 7580000 7640000
4l6r50d4f 7580000 7580000 7640000
4l6r50d5f 7580000 7580000 7640000
4l6r50d6f 7580000 7580000 7640000
4l6r50d7f 7580000 7580000 7640000
4l6r50d9f 7580000 7580000 7640000
4l8r50d0f 7580000 7580000 7640000
4l8r50d1f 7580000 7580000 7640000
4l8r50d2f 7580000 7580000 7640000
4l8r50d3f 7580000 7580000 7640000
4l8r50d4f 7580000 7580000 7640000
4l8r50d5f 7580000 7580000 7640000
4l8r50d6f 7580000 7580000 7640000
4l8r50d7f 7580000 7580000 7640000
4l8r50d9f 7580000 7580000 7640000

108 rows × 3 columns

Plots of LSTM layers

by number of layers l and number if cells per layer r


In [15]:
data = datasets[0]

colormapping = {  2:"blue",
                  3:"green",
                  4:"red"}


def plot_by_l(data, colormapping, ax=None, cols=["blue","green","red"]):
    if ax is None:
        f, ax = plt.subplots()
    
    runs = data.columns
    for run in runs:
        l,r,d,f = name2param(run.split("/")[-1])
        col = colormapping[int(l)]
        ax.plot(data[run], color=col)
    
    # xlabels
    ax.get_xaxis().set_major_formatter(
    matplotlib.ticker.FuncFormatter(lambda x, p: format(int(x), ',')))
    
    ax.set_xlim(0,1e7)
    
    return ax

f, ax = plt.subplots()
plot_by_l(data, colormapping,ax=ax)
ax.legend(["2 layer", "3 layer", "4 layer"])


Out[15]:
<matplotlib.legend.Legend at 0x7f1b38497f10>

In [17]:
data = datasets[0]


cols = sns.color_palette("RdBu_r", n_colors=len(r))

colormapping = dict(zip(r,cols))


def plot_by_r(data, colormapping, ax=None, cols=["blue","green","red"]):
    """
        Plot multilpe networks sorted by number of cells r
    """
    if ax is None:
        f, ax = plt.subplots()
    
    labels = ["{} cells per layer".format(r_*55) for r_ in colormapping.keys()]
    labels_map = dict(zip(colormapping.keys(),labels))
    
    
    runs = data.columns
    for run in runs:
        l,r,d,f = name2param(run.split("/")[-1])
        #print r
        col = colormapping[int(r)]
        ax.plot(data[run], color=col, label=labels_map[int(r)], alpha=0.5)
        labels_map[int(r)] = ""
    
    # xlabels
    ax.get_xaxis().set_major_formatter(
    matplotlib.ticker.FuncFormatter(lambda x, p: format(int(x), ',')))
    
    ax.set_ylabel("cross entropy")
    
    ax.set_xlim(0,10e6)
    
    return ax

f, ax = plt.subplots(figsize=(6,3))
plot_by_r(data, colormapping,ax=ax)
ax.legend()


Out[17]:
<matplotlib.legend.Legend at 0x7f1b383e6e90>

Accuracy Metric Evaluation

this part utilizes evaluation files

  • eval_scores.npy.
  • eval_probabilites.npy,
  • eval_observations.npy in the run directors if no files are available create

e.g.

python evaluate.py save/folds/lstm/4l4r50d1f --model lstm
  • accuracy
  • precision
  • recall
  • kappa
  • fscore
  • AOI

In [18]:
from sklearn.metrics import accuracy_score, cohen_kappa_score, classification_report, precision_score, recall_score, f1_score, precision_recall_fscore_support

def gather_accuracy_values_per_class(classes,targets,scores):
    """
    Gather per class a variety of accuracy metrics from targets and scores
    """
    y_pred = np.argmax(scores,axis=1)
    y_true = np.argmax(targets,axis=1)
    
    precision_, recall_, fscore_, support_ = precision_recall_fscore_support(y_true, y_pred, beta=0.5, average=None)

    fscore = pd.Series(index=classes, data=fscore_, name="f-score")
    precision = pd.Series(index=classes, data=precision_, name="precision")
    recall = pd.Series(index=classes, data=recall_, name="recall")
    support = pd.Series(index=classes, data=support_, name="support")
    
    s = [fscore,precision,recall, support]
    names = [el.name for el in s]
    return pd.DataFrame(zip(*s), columns=names, index=recall.index).T

definition of Sklearn accuracy:

In this implementation of accuracy weights are not 1, but inverse frequence of classes. controlled via sample_weights parameter


In [19]:
import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score, accuracy_score, cohen_kappa_score, classification_report, precision_score, recall_score, f1_score, precision_recall_fscore_support
 
def gather_mean_accuracies(classes, scores, targets, average='weighted', label="label", b=4):
    """
    calculate a series for mean accuracy values for all, covered (class id < b) and fields (class id > b)
    """
    metrics = []
    
    y_pred = np.argmax(scores,axis=1)
    y_true = np.argmax(targets,axis=1)
    
    all_mask = np.ones(y_true.shape)
    covered_mask = y_true<b
    field_mask = y_true>=b
    
    # class weighted average accuracy
    w_all = np.ones(y_true.shape[0])
    for idx, i in enumerate(np.bincount(y_true)):
        w_all[y_true == idx] *= (i/float(y_true.shape[0]))
        
    w_cov = np.ones(y_true[covered_mask].shape[0])
    for idx, i in enumerate(np.bincount(y_true[covered_mask])):
        w_cov[y_true[covered_mask] == idx] *= (i/float(y_true[covered_mask].shape[0]))
        
    w_field = np.ones(y_true[field_mask].shape[0])
    for idx, i in enumerate(np.bincount(y_true[field_mask])):
        w_field[y_true[field_mask] == idx] *= (i/float(y_true[field_mask].shape[0]))
        
    w_acc = accuracy_score(y_true, y_pred, sample_weight=w_all)
    w_acc_cov = accuracy_score(y_true[covered_mask], y_pred[covered_mask], sample_weight=w_cov)
    w_acc_field = accuracy_score(y_true[field_mask], y_pred[field_mask], sample_weight=w_field)
    
    metrics.append(pd.Series(data=[w_acc, w_acc_cov, w_acc_field], dtype=float, name="accuracy"))
    
    # AUC
    try:
        # if AUC not possible skip
        auc = roc_auc_score(targets, scores, average=average)
        auc_cov = roc_auc_score(targets[covered_mask,:b], scores[covered_mask,:b], average=average)
        auc_field = roc_auc_score(targets[field_mask,b:], scores[field_mask,b:], average=average)

        metrics.append(pd.Series(data=[auc, auc_cov, auc_field], dtype=float, name="AUC"))
    except:
        print "no AUC calculated"
        pass
    
    # Kappa
    kappa = cohen_kappa_score(y_true, y_pred)
    kappa_cov = cohen_kappa_score(y_true[covered_mask], y_pred[covered_mask])
    kappa_field = cohen_kappa_score(y_true[field_mask], y_pred[field_mask])
    
    metrics.append(pd.Series(data=[kappa, kappa_cov, kappa_field], dtype=float, name="kappa"))
    
    # Precision, Recall, F1, support
    prec, rec, f1, support = precision_recall_fscore_support(y_true, y_pred, beta=1, average=average)
    prec_cov, rec_cov, f1_cov, support_cov = precision_recall_fscore_support(y_true[covered_mask], y_pred[covered_mask], beta=1, average=average)
    prec_field, rec_field, f1_field, support_field = precision_recall_fscore_support(y_true[field_mask], y_pred[field_mask], beta=1, average=average)
    
    metrics.append(pd.Series(data=[prec, prec_cov, prec_field], dtype=float, name="precision"))
    metrics.append(pd.Series(data=[rec, rec_cov, rec_field], dtype=float, name="recall"))
    metrics.append(pd.Series(data=[f1, f1_cov, f1_field], dtype=float, name="fscore"))
    #sup_ = pd.Series(data=[support, support_cov, support_field], dtype=int, name="support")
        
    df_ = pd.DataFrame(metrics).T
    if label is not None:
        df_.index = [[label,label,label],["all","cov","fields"]]
    else:
        df_.index = ["all","cov","fields"]
    
    return df_

In [20]:
from sklearn.metrics import confusion_matrix

# border in the classes between field classes and coverage
b = 4

obs_file = "eval_observations.npy"
probs_file = "eval_probabilities.npy"
targets_file = "eval_targets.npy"
conf_mat_file = "eval_confusion_matrix.npy"
class_file = "classes.npy"

# drop fc for now:
#networks = [networks[0], networks[2]]
#best_runs = [best_runs[0], best_runs[2]]

networklabels = ["LSTM","RNN","CNN"]

#over_accuracy_label = "ov. accuracy2"

# ignore <obs_limit> first observations
obs_limit = 0

acc=[]
mean = []
for best_run, network, label_ in zip(best_runs,networks,networklabels):
    print network
   
    path = os.path.join(savedir,network,best_run)
    #obs = np.load(os.path.join(path,obs_file))
    scores = np.load(os.path.join(path,probs_file))
    targets = np.load(os.path.join(path,targets_file))
    #if os.path.exists(os.path.join(path,conf_mat_file)):
    #    cm = np.load(os.path.join(path,conf_mat_file))
    #else:
    y_pred = np.argmax(scores,axis=1)
    y_true = np.argmax(targets,axis=1)
    cm = confusion_matrix(y_true,y_pred)


    classes = fix_typos(
                    list(np.load(os.path.join(path,class_file)))
            )
    
    #df_, a_ = acc_mean_accuracies(cm, classes, label_, b, scores,targets)
    df_ = gather_mean_accuracies(classes, scores, targets, b=b, label=label_)
    
    mean.append(df_)
    
mean_df = pd.concat(mean)


lstm
/home/russwurm/anaconda2/lib/python2.7/site-packages/sklearn/metrics/classification.py:1076: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples.
  'recall', 'true', average, warn_for)
rnn
cnn
/home/russwurm/anaconda2/lib/python2.7/site-packages/sklearn/metrics/classification.py:1074: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.
  'precision', 'predicted', average, warn_for)

In [21]:
# CNN
gather_accuracy_values_per_class(classes,targets,scores)


Out[21]:
cloud water snow cloud shadow other sugar beet corn meadow asparagus rape ... fallow winter wheat winter barley winter rye beans winter triticale summer barley peas potatoe sugar beet
f-score 0.953189 0.279844 0.580252 0.859940 0.667834 0.567854 0.464987 0.0 0.564285 0.008255 ... 0.470068 0.405574 0.0 0.213499 0.0 0.094097 0.005236 0.209685 0.012937 0.551577
precision 0.957144 0.334073 0.619585 0.853733 0.670908 0.563712 0.437367 0.0 0.623093 0.210526 ... 0.433381 0.424064 0.0 0.489474 0.0 0.315615 0.333333 0.488778 0.363636 0.628046
recall 0.937692 0.169673 0.462745 0.885700 0.655813 0.585045 0.622144 0.0 0.409639 0.001704 ... 0.710721 0.345342 0.0 0.065585 0.0 0.024714 0.001060 0.063851 0.002663 0.370925
support 693246.000000 1774.000000 10200.000000 35092.000000 256811.000000 103821.000000 72951.000000 946.0 14254.000000 9387.000000 ... 71972.000000 25065.000000 2853.0 1418.000000 4971.0 7688.000000 943.000000 9209.000000 1502.000000 3405.000000

4 rows × 23 columns


In [22]:
best_runs, networks


Out[22]:
(['4l4r50d1f', '4l8r50d6f', '3l8r50d0f'], ['lstm', 'rnn', 'cnn'])

SVM

evaluation of SVM baseline. SVM baseline is calculated at SVM.py. Results are saved as .npy


In [23]:
from sklearn.metrics import roc_curve
# load from SVM Evaluation.ipynb

cm = cm_SVM = np.load("svm/confusion_matrix.npy")
scores = np.load("svm/scores.npy")
targets = np.load("svm/targets.npy")
pred = np.load("svm/predicted.npy")

df = gather_mean_accuracies(classes, scores, targets, b=b, label="SVM")

In [24]:
# drop previous SVM if exists
table_df = pd.concat([mean_df,df])
table = table_df.round(3) * 100

In [25]:
# put overall accuracy first
#overall_acc = table[over_accuracy_label]
#new = table.drop([over_accuracy_label],axis=1)
#new.insert(0,over_accuracy_label,overall_acc)
#table = new.transpose()
table = table.transpose()
table


Out[25]:
LSTM RNN CNN SVM
all cov fields all cov fields all cov fields all cov fields
accuracy 90.6 93.6 74.3 89.8 92.9 72.9 89.2 93.7 64.3 40.9 87.4 31.1
AUC 98.1 97.5 94.9 97.7 97.0 94.1 95.1 97.0 84.7 87.1 97.6 81.6
kappa 77.6 55.6 67.4 76.1 53.0 65.6 66.2 56.3 44.0 38.2 83.4 27.3
precision 85.6 98.4 78.4 84.8 98.2 77.3 76.7 98.2 59.2 40.2 91.2 31.4
recall 84.4 92.5 74.5 83.4 91.8 73.0 76.8 92.7 57.2 40.9 87.4 31.1
fscore 84.6 95.3 75.3 83.6 94.9 74.0 76.1 95.3 56.7 40.3 88.9 31.1

write to Tex


In [26]:
# not nessecary anymore write data and input in formatted table
if False:
    filename = os.path.join(table_filepath, "jupyter_generated_meanAccuracyPerNetwork.tex")
    

    # put oa first
    oa_row = table[-1:]
    table = table.drop(over_accuracy_label)
    table = pd.concat([oa_row,table])

    with open(filename, 'w+') as f:  
        f.write(table.to_latex())

write table out data


In [27]:
if False:
    filename = os.path.join(table_filepath, "meanAccuracyPerNetwork.data")
    table.to_csv(filename, sep='&', line_terminator= r'\\', header=False)

In [28]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

def plot_confusion_matrix(confusion_matrix, classes, normalize_axis=None, figsize=(7, 7), colormap=None):
    """
    Plots a confusion matrix using seaborn heatmap functionality

    @param confusion_matrix: np array [n_classes, n_classes] with rows reference and cols predicted
    @param classes: list of class labels
    @param normalize_axis: 0 sum of rows, 1: sum of cols, None no normalization

    @return matplotlib figure
    """
    # Set up the matplotlib figure
    plt.figure()
    f, ax = plt.subplots(figsize=figsize)

    # normalize
    normalized_str = ""  # add on at the title
    if normalize_axis is not None:

        with np.errstate(divide='ignore'):  # ignore divide by zero and replace with 0
            confusion_matrix = np.nan_to_num(
                confusion_matrix.astype(float) / np.sum(confusion_matrix, axis=normalize_axis))

        
    # Draw the heatmap with the mask and correct aspect ratio
    g = sns.heatmap(confusion_matrix,
                    square=True,
                    linewidths=1, 
                    cbar=False,
                    ax=ax,
                    cmap=colormap)
    
    divider = make_axes_locatable(g)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    
    cbar = g.figure.colorbar(ax.collections[0],cax=cax)
    if normalize_axis == 0:
        cbar.set_label("precision")
    if normalize_axis == 1:
        cbar.set_label("recall")

    n_classes = len(classes)
    # if n_classes < threshold plot values in plot

    cols = np.arange(0, n_classes)
    rows = np.arange(n_classes - 1, -1, -1)

    #g.set_title("Confusion Matrix")
    g.set_xticklabels([])
    g.set_yticklabels(classes[::-1], rotation=0)
    g.set_xlabel("predicted")
    g.set_ylabel("reference")

    return f, g

Grid of parameters


In [29]:
from sklearn.metrics import confusion_matrix

if False:

    # border in the classes between field classes and coverage
    b = 4

    obs_file = "eval_observations.npy"
    probs_file = "eval_probabilities.npy"
    targets_file = "eval_targets.npy"
    conf_mat_file = "eval_confusion_matrix.npy"
    class_file = "classes.npy"

    # drop fc for now:
    #networks = [networks[0], networks[2]]
    #best_runs = [best_runs[0], best_runs[2]]

    networklabels = ["LSTM","RNN","CNN"]

    over_accuracy_label = "ov. accuracy"

    # ignore <obs_limit> first observations
    obs_limit = 0


    n_data=[]
    for network in networks:
        mean = []
        acc=[]
        for run in runs:
            print network

            path = os.path.join(savedir, network, best_run)
            #obs = np.load(os.path.join(path,obs_file))

            try:
                scores = np.load(os.path.join(path,probs_file))
                targets = np.load(os.path.join(path,targets_file))
            except:
                mean.append(None)
                continue
            #if os.path.exists(os.path.join(path,conf_mat_file)):
            #    cm = np.load(os.path.join(path,conf_mat_file))
            #else:
            y_pred = np.argmax(scores,axis=1)
            y_true = np.argmax(targets,axis=1)
            cm = confusion_matrix(y_true,y_pred)

            # add Overall accuracy to table
            cm_cov = cm[0:b, 0:b]
            cm_field = cm[b:, b:]

            df_, a_ = acc_mean_accuracies(cm, classes, None, b, scores,targets)

            oa,_ = calculate_accuracy_metrics(cm)
            oa_cov, _ = calculate_accuracy_metrics(cm_cov)
            oa_field, _ = calculate_accuracy_metrics(cm_field)
            df_[over_accuracy_label] = [oa, oa_cov,oa_field]

            #mean.append(df_)

            n_data.append({"avg. acc":df_.loc["all"]["avg. accuracy"],"network": network, "run": run})

In [30]:
if False:

    df_.loc["all"]["avg. accuracy"]

    n_data = []
    for i,dataset in zip(range(len(networks)), datasets):
        data = []
        x_entropy = dataset.dropna().iloc[-1].values
        run_names = dataset.dropna().iloc[-1].index

        for run_name, loss in zip(run_names, x_entropy):
            l,r,d,f = name2param(run_name)
            data.append({"l":int(l),"r":int(r)*55,"xentropy":loss})

        df = pd.DataFrame(data).pivot(index='l', columns='r', values='xentropy')

        n_data.append(pd.DataFrame(data))

    if False:
        for d,network in zip(n_data,networklabels):
            filename = os.path.join(image_filepath, "gridplot_{}.dat".format(network))
            d.to_csv(filename, sep=' ', header=False, index=False)
    #n_data[0]

In [31]:
def unroll_cms(cms):
    rows,cols = cms[0].shape
    line = []
    dat = []
    for row in range(rows):
        for col in range(cols):
            line = [col+1, row+1]
            for cm in cms:
                line.append(cm[row, col])
            dat.append(line)
    return dat

Confusion Matrix


In [32]:
# Foerster confusion matrix

cm_foe = np.array([[76.6,1.6,0.4,0,0,6.7,1.6,0,0.4,4.4,3.2,5.1],
[0,0,0,0,0,0,0,0,0,0,0,0],
[10.1,0,71.1,8.8,6.1,1.6,0,0,0,0.4,1.9,0],
[1.6,0,2.9,91.7,0,1.5,0.7,0,0,0,0,1.6],
[5.3,0,10.8,0,76.9,3.9,0,0,0,2,1.1,0],
[1.2,0,0,1.5,0,97.3,0,0,0,0,0,0],
[0,0,0,52.4,0,47.6,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,100,0,0],
[0.4,0,6.2,1.4,0,1.3,0.6,0,14.5,1.8,68.5,5.2],
[2.1,0,0,0,0,1,0,0,3.1,92.1,1.7,0],
[0,0,0,0,0,0,0,0,0,0,0,0],
[0,100,0,0,0,0,0,0,0,0,0,0]],dtype=float)/100

classes_foe = ["fallow", "perennial field grass", "winter rye", "winter wheat", 
               "winter barley", "oilseed rape", "summer grain", "sugar beets", 
               "maize", "oilseed crops and legumes", "potatoes", "first year field grass"]

dat = unroll_cms([cm_foe])
pd.DataFrame(unroll_cms([cm_foe])).to_csv(os.path.join(image_filepath,"confmatfoerster.dat"), sep=' ', header=False, index=False)


#plt.imshow(cm_foe)
#f,ax = plot_confusion_matrix(cm_foe, classes_foe, figsize=figsize, normalize_axis=1, colormap = "Blues")
for cl in classes_foe:
    print "\classname{%s}" % cl


\classname{fallow}
\classname{perennial field grass}
\classname{winter rye}
\classname{winter wheat}
\classname{winter barley}
\classname{oilseed rape}
\classname{summer grain}
\classname{sugar beets}
\classname{maize}
\classname{oilseed crops and legumes}
\classname{potatoes}
\classname{first year field grass}

In [33]:
from sklearn.metrics import confusion_matrix
%matplotlib inline

def calc_confusion_matrix(path):
    probs_file = "eval_probabilities.npy"
    targets_file = "eval_targets.npy"

    scores = np.load(os.path.join(path,probs_file))
    targets = np.load(os.path.join(path,targets_file))

    y_pred = np.argmax(scores,axis=1)
    y_true = np.argmax(targets,axis=1)
    cm = confusion_matrix(y_true,y_pred)

    return cm

cms_prec = []
cms_rec = []
for best_run, network in zip(best_runs, networks):
    path = os.path.join(savedir, network, best_run)
    cm = calc_confusion_matrix(path)
    cms_prec.append(cm.astype(float)/np.sum(cm,axis=0))
    cms_rec.append(cm.astype(float)/np.sum(cm,axis=1))
    
# add SVM
cms_prec.append(cm_SVM.astype(float)/np.sum(cm_SVM,axis=0))
cms_rec.append(cm_SVM.astype(float)/np.sum(cm_SVM,axis=1))
    
pd.DataFrame(unroll_cms(cms_prec)).to_csv(os.path.join(image_filepath,"confmatprecision.dat"), sep=' ', header=False, index=False)
pd.DataFrame(unroll_cms(cms_rec)).to_csv(os.path.join(image_filepath,"confmatrecall.dat"), sep=' ', header=False, index=False)

cm = cms_prec[0]
    
# Generate a custom diverging colormap
cmap = sns.color_palette("Blues")
figsize=(6,6)
f,ax = plot_confusion_matrix(cm, classes, figsize=figsize, normalize_axis=1, colormap = "Blues")

dat_filepath = os.path.join(image_filepath,"confusion_matrix.dat")
pdf_filepath = os.path.join(image_filepath,"confusion_matrix.pdf")
tikz_filepath = os.path.join(image_filepath,"confusion_matrix.tikz")

# double checked at http://stackoverflow.com/questions/20927368/python-how-to-normalize-a-confusion-matrix
precision = cm/np.sum(cm,axis=0)
recall = cm/np.sum(cm,axis=1)

# uncomment to save
if False:
    #f.savefig(os.path.join(image_filepath,"confusion_matrix.pdf"),transparent=True)
    #tikz_save(tikz_filepath)
    pd.DataFrame(cm).to_csv(os.path.join(image_filepath,"confusion_matrix.dat"), sep=' ', header=False, index=False)
    pd.DataFrame(precision).to_csv(os.path.join(image_filepath,"confusion_matrix_prec.dat"), sep=' ', header=False, index=False)
    pd.DataFrame(recall).to_csv(os.path.join(image_filepath,"confusion_matrix_recall.dat"), sep=' ', header=False, index=False)
    
    pd.DataFrame(classes).to_csv(os.path.join(image_filepath,"confusion_matrix.labels"), sep=' ', header=False, index=False)


<matplotlib.figure.Figure at 0x7f1b3806ef50>

Barplots per class

Influence of element in sequence


In [34]:
lstm_network, _, _ = networks
lstm_best, _, _ = best_runs
path = os.path.join(savedir, lstm_network, lstm_best)

obs = np.load(os.path.join(path,obs_file))
scores = np.load(os.path.join(path,probs_file))
targets = np.load(os.path.join(path,targets_file))

In [35]:
#from sklearn.metrics import confusion_matrix

b = 4 #filter all non cloud classes
def get_obs_subset(targets,scores,obs,obs_idx, b, classes):
    """
    This function calls the gather_mean_accuracies, which is used for the calculation of accuracy tables
    on a subset of targets and scores filtered by obs_idx
    """
    # select by observation
    sc = scores[obs==obs_idx]
    ta = targets[obs==obs_idx]
    
    return gather_mean_accuracies(classes, sc, ta, average='weighted', label="label", b=4)

In [36]:
#a = get_obs_subset(targets,scores,3, b, classes)
sc = scores[obs==3]
ta = targets[obs==3]
gather_mean_accuracies(classes, sc, ta, average='weighted', label="label", b=4)


no AUC calculated
Out[36]:
accuracy kappa precision recall fscore
label all 0.949167 0.571642 0.882885 0.892407 0.885235
cov 0.949167 0.571682 0.884211 0.893442 0.886500
fields 0.406237 0.050855 0.353666 0.303030 0.282303

In [37]:
# gather accuracy values for fields
#from util.db import conn
import sklearn

#t = pd.read_sql("select distinct doa, doy from products order by doa",conn)["doy"]
#t.to_pickle(os.path.join("loc","t.pkl"))

t = pd.read_pickle(os.path.join("loc","t.pkl"))

def collect_data_per_obs(targets, scores, obs, classes, metric="accuracy", classcategory="fields"):
    """
    this function calculates `metric` based on scores and targets for each available observations `t` 0..25
    
    This function takes a 
    - target matrix resembling ground thruth, 
    - scores as calculated probablities for each observation
    - obs as indices of observation
    """
    
    #oa = []
    outlist=[]
    for i in range(len(t)):
        try:
            per_class_ = get_obs_subset(targets,scores,obs,i, b, classes)
        #per_class.append(per_class_.mean(axis=0))
        # append the average <classcategory> <metric> at each time i
            outlist.append(per_class_.loc["label"].loc[classcategory][metric])
        except:
            print "t{} could not calculate accuracy metrics".format(i)
            outlist.append(None)
            pass
        #oa.append(oa_)
        
        print "Collecting doy {} ({}/{})".format(t[i],i+1,len(t))

    #oa_s = pd.Series(data=oa, name=over_accuracy_label, index=t)
    return pd.DataFrame(outlist, index=t)


def collect_data_for_each_network(networks, best_runs, metric="kappa", classcategory="fields"):
    """
    This function calls collect_data_per_obs for each network.
    First targets, scores and obs are loaded from file at the respective network's best run model
    Then collect_data_per_obs is called.
    """
    
    acc_dfs = []
    for network, best in zip(networks, best_runs):
        path = os.path.join(savedir, network, best)

        obs = np.load(os.path.join(path,obs_file))
        scores = np.load(os.path.join(path,probs_file))
        targets = np.load(os.path.join(path,targets_file))
        print 
        print network
        # for every network append a dataframe of observations
        observations_df_ = collect_data_per_obs(targets, scores, obs, classes, metric=metric, classcategory=classcategory)
        acc_dfs.append(observations_df_.values.reshape(-1))

    # create final DataFrame with proper column and indexes of all three networks
    return pd.DataFrame(acc_dfs, index=networks,columns=t).T

In [38]:
acc_df = collect_data_for_each_network(networks, best_runs, metric="accuracy", classcategory="fields")
rec_df = collect_data_for_each_network(networks, best_runs, metric="recall", classcategory="fields")
kappa_df = collect_data_for_each_network(networks, best_runs, metric="kappa", classcategory="fields")
prec_df = collect_data_for_each_network(networks, best_runs, metric="precision", classcategory="fields")


lstm
t0 could not calculate accuracy metrics
Collecting doy 365 (1/25)
t1 could not calculate accuracy metrics
Collecting doy 3 (2/25)
no AUC calculated
Collecting doy 10 (3/25)
no AUC calculated
Collecting doy 13 (4/25)
no AUC calculated
Collecting doy 20 (5/25)
Collecting doy 43 (6/25)
no AUC calculated
Collecting doy 70 (7/25)
Collecting doy 80 (8/25)
Collecting doy 103 (9/25)
Collecting doy 110 (10/25)
no AUC calculated
Collecting doy 113 (11/25)
no AUC calculated
Collecting doy 120 (12/25)
no AUC calculated
Collecting doy 123 (13/25)
Collecting doy 130 (14/25)
Collecting doy 143 (15/25)
Collecting doy 180 (16/25)
Collecting doy 183 (17/25)
Collecting doy 200 (18/25)
Collecting doy 203 (19/25)
Collecting doy 220 (20/25)
Collecting doy 230 (21/25)
Collecting doy 240 (22/25)
Collecting doy 243 (23/25)
Collecting doy 250 (24/25)
no AUC calculated
Collecting doy 273 (25/25)

rnn
t0 could not calculate accuracy metrics
Collecting doy 365 (1/25)
t1 could not calculate accuracy metrics
Collecting doy 3 (2/25)
no AUC calculated
Collecting doy 10 (3/25)
no AUC calculated
Collecting doy 13 (4/25)
no AUC calculated
Collecting doy 20 (5/25)
Collecting doy 43 (6/25)
no AUC calculated
Collecting doy 70 (7/25)
Collecting doy 80 (8/25)
Collecting doy 103 (9/25)
Collecting doy 110 (10/25)
no AUC calculated
Collecting doy 113 (11/25)
no AUC calculated
Collecting doy 120 (12/25)
no AUC calculated
Collecting doy 123 (13/25)
Collecting doy 130 (14/25)
Collecting doy 143 (15/25)
Collecting doy 180 (16/25)
Collecting doy 183 (17/25)
Collecting doy 200 (18/25)
Collecting doy 203 (19/25)
Collecting doy 220 (20/25)
Collecting doy 230 (21/25)
Collecting doy 240 (22/25)
Collecting doy 243 (23/25)
Collecting doy 250 (24/25)
no AUC calculated
Collecting doy 273 (25/25)

cnn
t0 could not calculate accuracy metrics
Collecting doy 365 (1/25)
t1 could not calculate accuracy metrics
Collecting doy 3 (2/25)
t2 could not calculate accuracy metrics
Collecting doy 10 (3/25)
t3 could not calculate accuracy metrics
Collecting doy 13 (4/25)
t4 could not calculate accuracy metrics
Collecting doy 20 (5/25)
t5 could not calculate accuracy metrics
Collecting doy 43 (6/25)
t6 could not calculate accuracy metrics
Collecting doy 70 (7/25)
t7 could not calculate accuracy metrics
Collecting doy 80 (8/25)
t8 could not calculate accuracy metrics
Collecting doy 103 (9/25)
t9 could not calculate accuracy metrics
Collecting doy 110 (10/25)
t10 could not calculate accuracy metrics
Collecting doy 113 (11/25)
t11 could not calculate accuracy metrics
Collecting doy 120 (12/25)
t12 could not calculate accuracy metrics
Collecting doy 123 (13/25)
t13 could not calculate accuracy metrics
Collecting doy 130 (14/25)
t14 could not calculate accuracy metrics
Collecting doy 143 (15/25)
t15 could not calculate accuracy metrics
Collecting doy 180 (16/25)
t16 could not calculate accuracy metrics
Collecting doy 183 (17/25)
t17 could not calculate accuracy metrics
Collecting doy 200 (18/25)
t18 could not calculate accuracy metrics
Collecting doy 203 (19/25)
t19 could not calculate accuracy metrics
Collecting doy 220 (20/25)
t20 could not calculate accuracy metrics
Collecting doy 230 (21/25)
t21 could not calculate accuracy metrics
Collecting doy 240 (22/25)
t22 could not calculate accuracy metrics
Collecting doy 243 (23/25)
t23 could not calculate accuracy metrics
Collecting doy 250 (24/25)
t24 could not calculate accuracy metrics
Collecting doy 273 (25/25)

lstm
t0 could not calculate accuracy metrics
Collecting doy 365 (1/25)
t1 could not calculate accuracy metrics
Collecting doy 3 (2/25)
no AUC calculated
Collecting doy 10 (3/25)
no AUC calculated
Collecting doy 13 (4/25)
no AUC calculated
Collecting doy 20 (5/25)
Collecting doy 43 (6/25)
no AUC calculated
Collecting doy 70 (7/25)
Collecting doy 80 (8/25)
Collecting doy 103 (9/25)
Collecting doy 110 (10/25)
no AUC calculated
Collecting doy 113 (11/25)
no AUC calculated
Collecting doy 120 (12/25)
no AUC calculated
Collecting doy 123 (13/25)
Collecting doy 130 (14/25)
Collecting doy 143 (15/25)
Collecting doy 180 (16/25)
Collecting doy 183 (17/25)
Collecting doy 200 (18/25)
Collecting doy 203 (19/25)
Collecting doy 220 (20/25)
Collecting doy 230 (21/25)
Collecting doy 240 (22/25)
Collecting doy 243 (23/25)
Collecting doy 250 (24/25)
no AUC calculated
Collecting doy 273 (25/25)

rnn
t0 could not calculate accuracy metrics
Collecting doy 365 (1/25)
t1 could not calculate accuracy metrics
Collecting doy 3 (2/25)
no AUC calculated
Collecting doy 10 (3/25)
no AUC calculated
Collecting doy 13 (4/25)
no AUC calculated
Collecting doy 20 (5/25)
Collecting doy 43 (6/25)
no AUC calculated
Collecting doy 70 (7/25)
Collecting doy 80 (8/25)
Collecting doy 103 (9/25)
Collecting doy 110 (10/25)
no AUC calculated
Collecting doy 113 (11/25)
no AUC calculated
Collecting doy 120 (12/25)
no AUC calculated
Collecting doy 123 (13/25)
Collecting doy 130 (14/25)
Collecting doy 143 (15/25)
Collecting doy 180 (16/25)
Collecting doy 183 (17/25)
Collecting doy 200 (18/25)
Collecting doy 203 (19/25)
Collecting doy 220 (20/25)
Collecting doy 230 (21/25)
Collecting doy 240 (22/25)
Collecting doy 243 (23/25)
Collecting doy 250 (24/25)
no AUC calculated
Collecting doy 273 (25/25)

cnn
t0 could not calculate accuracy metrics
Collecting doy 365 (1/25)
t1 could not calculate accuracy metrics
Collecting doy 3 (2/25)
t2 could not calculate accuracy metrics
Collecting doy 10 (3/25)
t3 could not calculate accuracy metrics
Collecting doy 13 (4/25)
t4 could not calculate accuracy metrics
Collecting doy 20 (5/25)
t5 could not calculate accuracy metrics
Collecting doy 43 (6/25)
t6 could not calculate accuracy metrics
Collecting doy 70 (7/25)
t7 could not calculate accuracy metrics
Collecting doy 80 (8/25)
t8 could not calculate accuracy metrics
Collecting doy 103 (9/25)
t9 could not calculate accuracy metrics
Collecting doy 110 (10/25)
t10 could not calculate accuracy metrics
Collecting doy 113 (11/25)
t11 could not calculate accuracy metrics
Collecting doy 120 (12/25)
t12 could not calculate accuracy metrics
Collecting doy 123 (13/25)
t13 could not calculate accuracy metrics
Collecting doy 130 (14/25)
t14 could not calculate accuracy metrics
Collecting doy 143 (15/25)
t15 could not calculate accuracy metrics
Collecting doy 180 (16/25)
t16 could not calculate accuracy metrics
Collecting doy 183 (17/25)
t17 could not calculate accuracy metrics
Collecting doy 200 (18/25)
t18 could not calculate accuracy metrics
Collecting doy 203 (19/25)
t19 could not calculate accuracy metrics
Collecting doy 220 (20/25)
t20 could not calculate accuracy metrics
Collecting doy 230 (21/25)
t21 could not calculate accuracy metrics
Collecting doy 240 (22/25)
t22 could not calculate accuracy metrics
Collecting doy 243 (23/25)
t23 could not calculate accuracy metrics
Collecting doy 250 (24/25)
t24 could not calculate accuracy metrics
Collecting doy 273 (25/25)

lstm
t0 could not calculate accuracy metrics
Collecting doy 365 (1/25)
t1 could not calculate accuracy metrics
Collecting doy 3 (2/25)
no AUC calculated
Collecting doy 10 (3/25)
no AUC calculated
Collecting doy 13 (4/25)
no AUC calculated
Collecting doy 20 (5/25)
Collecting doy 43 (6/25)
no AUC calculated
Collecting doy 70 (7/25)
Collecting doy 80 (8/25)
Collecting doy 103 (9/25)
Collecting doy 110 (10/25)
no AUC calculated
Collecting doy 113 (11/25)
no AUC calculated
Collecting doy 120 (12/25)
no AUC calculated
Collecting doy 123 (13/25)
Collecting doy 130 (14/25)
Collecting doy 143 (15/25)
Collecting doy 180 (16/25)
Collecting doy 183 (17/25)
Collecting doy 200 (18/25)
Collecting doy 203 (19/25)
Collecting doy 220 (20/25)
Collecting doy 230 (21/25)
Collecting doy 240 (22/25)
Collecting doy 243 (23/25)
Collecting doy 250 (24/25)
no AUC calculated
Collecting doy 273 (25/25)

rnn
t0 could not calculate accuracy metrics
Collecting doy 365 (1/25)
t1 could not calculate accuracy metrics
Collecting doy 3 (2/25)
no AUC calculated
Collecting doy 10 (3/25)
no AUC calculated
Collecting doy 13 (4/25)
no AUC calculated
Collecting doy 20 (5/25)
Collecting doy 43 (6/25)
no AUC calculated
Collecting doy 70 (7/25)
Collecting doy 80 (8/25)
Collecting doy 103 (9/25)
Collecting doy 110 (10/25)
no AUC calculated
Collecting doy 113 (11/25)
no AUC calculated
Collecting doy 120 (12/25)
no AUC calculated
Collecting doy 123 (13/25)
Collecting doy 130 (14/25)
Collecting doy 143 (15/25)
Collecting doy 180 (16/25)
Collecting doy 183 (17/25)
Collecting doy 200 (18/25)
Collecting doy 203 (19/25)
Collecting doy 220 (20/25)
Collecting doy 230 (21/25)
Collecting doy 240 (22/25)
Collecting doy 243 (23/25)
Collecting doy 250 (24/25)
no AUC calculated
Collecting doy 273 (25/25)

cnn
t0 could not calculate accuracy metrics
Collecting doy 365 (1/25)
t1 could not calculate accuracy metrics
Collecting doy 3 (2/25)
t2 could not calculate accuracy metrics
Collecting doy 10 (3/25)
t3 could not calculate accuracy metrics
Collecting doy 13 (4/25)
t4 could not calculate accuracy metrics
Collecting doy 20 (5/25)
t5 could not calculate accuracy metrics
Collecting doy 43 (6/25)
t6 could not calculate accuracy metrics
Collecting doy 70 (7/25)
t7 could not calculate accuracy metrics
Collecting doy 80 (8/25)
t8 could not calculate accuracy metrics
Collecting doy 103 (9/25)
t9 could not calculate accuracy metrics
Collecting doy 110 (10/25)
t10 could not calculate accuracy metrics
Collecting doy 113 (11/25)
t11 could not calculate accuracy metrics
Collecting doy 120 (12/25)
t12 could not calculate accuracy metrics
Collecting doy 123 (13/25)
t13 could not calculate accuracy metrics
Collecting doy 130 (14/25)
t14 could not calculate accuracy metrics
Collecting doy 143 (15/25)
t15 could not calculate accuracy metrics
Collecting doy 180 (16/25)
t16 could not calculate accuracy metrics
Collecting doy 183 (17/25)
t17 could not calculate accuracy metrics
Collecting doy 200 (18/25)
t18 could not calculate accuracy metrics
Collecting doy 203 (19/25)
t19 could not calculate accuracy metrics
Collecting doy 220 (20/25)
t20 could not calculate accuracy metrics
Collecting doy 230 (21/25)
t21 could not calculate accuracy metrics
Collecting doy 240 (22/25)
t22 could not calculate accuracy metrics
Collecting doy 243 (23/25)
t23 could not calculate accuracy metrics
Collecting doy 250 (24/25)
t24 could not calculate accuracy metrics
Collecting doy 273 (25/25)

lstm
t0 could not calculate accuracy metrics
Collecting doy 365 (1/25)
t1 could not calculate accuracy metrics
Collecting doy 3 (2/25)
no AUC calculated
Collecting doy 10 (3/25)
no AUC calculated
Collecting doy 13 (4/25)
no AUC calculated
Collecting doy 20 (5/25)
Collecting doy 43 (6/25)
no AUC calculated
Collecting doy 70 (7/25)
Collecting doy 80 (8/25)
Collecting doy 103 (9/25)
Collecting doy 110 (10/25)
no AUC calculated
Collecting doy 113 (11/25)
no AUC calculated
Collecting doy 120 (12/25)
no AUC calculated
Collecting doy 123 (13/25)
Collecting doy 130 (14/25)
Collecting doy 143 (15/25)
Collecting doy 180 (16/25)
Collecting doy 183 (17/25)
Collecting doy 200 (18/25)
Collecting doy 203 (19/25)
Collecting doy 220 (20/25)
Collecting doy 230 (21/25)
Collecting doy 240 (22/25)
Collecting doy 243 (23/25)
Collecting doy 250 (24/25)
no AUC calculated
Collecting doy 273 (25/25)

rnn
t0 could not calculate accuracy metrics
Collecting doy 365 (1/25)
t1 could not calculate accuracy metrics
Collecting doy 3 (2/25)
no AUC calculated
Collecting doy 10 (3/25)
no AUC calculated
Collecting doy 13 (4/25)
no AUC calculated
Collecting doy 20 (5/25)
Collecting doy 43 (6/25)
no AUC calculated
Collecting doy 70 (7/25)
Collecting doy 80 (8/25)
Collecting doy 103 (9/25)
Collecting doy 110 (10/25)
no AUC calculated
Collecting doy 113 (11/25)
no AUC calculated
Collecting doy 120 (12/25)
no AUC calculated
Collecting doy 123 (13/25)
Collecting doy 130 (14/25)
Collecting doy 143 (15/25)
Collecting doy 180 (16/25)
Collecting doy 183 (17/25)
Collecting doy 200 (18/25)
Collecting doy 203 (19/25)
Collecting doy 220 (20/25)
Collecting doy 230 (21/25)
Collecting doy 240 (22/25)
Collecting doy 243 (23/25)
Collecting doy 250 (24/25)
no AUC calculated
Collecting doy 273 (25/25)

cnn
t0 could not calculate accuracy metrics
Collecting doy 365 (1/25)
t1 could not calculate accuracy metrics
Collecting doy 3 (2/25)
t2 could not calculate accuracy metrics
Collecting doy 10 (3/25)
t3 could not calculate accuracy metrics
Collecting doy 13 (4/25)
t4 could not calculate accuracy metrics
Collecting doy 20 (5/25)
t5 could not calculate accuracy metrics
Collecting doy 43 (6/25)
t6 could not calculate accuracy metrics
Collecting doy 70 (7/25)
t7 could not calculate accuracy metrics
Collecting doy 80 (8/25)
t8 could not calculate accuracy metrics
Collecting doy 103 (9/25)
t9 could not calculate accuracy metrics
Collecting doy 110 (10/25)
t10 could not calculate accuracy metrics
Collecting doy 113 (11/25)
t11 could not calculate accuracy metrics
Collecting doy 120 (12/25)
t12 could not calculate accuracy metrics
Collecting doy 123 (13/25)
t13 could not calculate accuracy metrics
Collecting doy 130 (14/25)
t14 could not calculate accuracy metrics
Collecting doy 143 (15/25)
t15 could not calculate accuracy metrics
Collecting doy 180 (16/25)
t16 could not calculate accuracy metrics
Collecting doy 183 (17/25)
t17 could not calculate accuracy metrics
Collecting doy 200 (18/25)
t18 could not calculate accuracy metrics
Collecting doy 203 (19/25)
t19 could not calculate accuracy metrics
Collecting doy 220 (20/25)
t20 could not calculate accuracy metrics
Collecting doy 230 (21/25)
t21 could not calculate accuracy metrics
Collecting doy 240 (22/25)
t22 could not calculate accuracy metrics
Collecting doy 243 (23/25)
t23 could not calculate accuracy metrics
Collecting doy 250 (24/25)
t24 could not calculate accuracy metrics
Collecting doy 273 (25/25)

In [39]:
prec_df.T


Out[39]:
doy 365 3 10 13 20 43 70 80 103 110 ... 180 183 200 203 220 230 240 243 250 273
lstm NaN NaN 0.570888 0.353666 0.476102 0.653447 0.667034 0.694907 0.672883 0.707204 ... 0.799102 0.803849 0.814197 0.820373 0.818037 0.821337 0.823234 0.826440 0.818529 0.820762
rnn NaN NaN 0.532238 0.379987 0.433551 0.631502 0.638370 0.663115 0.655155 0.682512 ... 0.788752 0.797388 0.807323 0.815302 0.811250 0.814474 0.815652 0.817766 0.809460 0.814352
cnn NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

3 rows × 25 columns


In [40]:
x = range(len(t))

def plot_acctime(x,acc_df,metric="measure"):
    f,ax = plt.subplots()
    #ax.plot(x,oa_s.values, label="overall accuracy")
    for col in acc_df.columns:
        ax.plot(x,acc_df[col].values, label=col)
    plt.xticks(x,t, rotation='vertical')
    ax.set_xlabel("day of year")
    ax.set_ylabel(metric)
    plt.legend()
    
# 0 lstm, 1 rnn, 2 cnn
#plot_acctime(x,prec_df)

In [41]:
plot_acctime(x,acc_df,metric="accuracy")



In [42]:
plot_acctime(x,prec_df, metric="precision")



In [43]:
plot_acctime(x,rec_df,metric="recall")



In [44]:
plot_acctime(x,kappa_df,metric="kappa")



In [45]:
# Write oen metric to file
prec_df.to_csv(os.path.join(image_filepath,"seqtimeacc.dat"), sep=' ', header=False, index=True)

Plot of evauated scores by observation index

A score is the calculated probability.

x axis: observation y axis: feature


In [46]:
plt.figure(figsize=(9,11))

obs_idx = 23
max_n = 30

sc = scores[obs==obs_idx][:max_n,:]

sc_ = np.sum(sc,axis=0)

sc_ = sc_/np.sum(sc_)

f, ax = plt.subplots(1,2, figsize=(13,10))
ax[0].imshow(np.log(sc+1e-10))
#ax[1].imshow(np.log(targets[obs==obs_idx][:max_n,:]+1e-10))
#ax[1].colorbar()


Out[46]:
<matplotlib.image.AxesImage at 0x7f1b386d2bd0>
<matplotlib.figure.Figure at 0x7f1b386d2d10>

Ground Truth vs evaluated scores

targets are ground truth. Can be interpreted as one hot vectors with 100% probability at one class

scores are evaluated probabilities.


In [47]:
plt.imshow(targets[obs==obs_idx][:max_n,:])


Out[47]:
<matplotlib.image.AxesImage at 0x7f1b397dee50>

In [48]:
plt.imshow(scores[obs==obs_idx][:max_n,:])


Out[48]:
<matplotlib.image.AxesImage at 0x7f1b3b219990>

Class Distribution Histograms

class distributions of training+validation (traintest) and evaluation datasets

ignores covered classes


In [49]:
#import psycopg2
import cPickle as pickle
import StringIO
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

reload(matplotlib.pyplot)
data = None

def query(conn, column="is_train", boolkey="True", b=4):

    sql = """
    select d.y_data 
    from raster_label_fields d, point_grid p 
    where d.field_id=p.id AND d.{0}={1}""".format(column,boolkey)

    y_data = pd.read_sql(sql,conn).values[:,0]

    y = []

    # extract predominant field class in measurement
    for y_data_ in y_data:
        y_data_ = StringIO.StringIO(y_data_)
        all_classes = pickle.load(y_data_)
        field_classes = all_classes[b:] # exclude cloud-other
        fields = np.sum(field_classes,axis=1) # sum up all observations
        y_ = np.argmax(fields) # take index of maximum class
        y.append(y_+b)
        
    n,b = np.histogram(y,bins=23, range=(0,23))
        
    return pd.Series(data=n)

In [51]:
#conn = psycopg2.connect('postgres://russwurm:dbfieldpassword@localhost/dbBayField')

#classes_df = pd.read_sql("select * from label_lookup",conn)["name"]
#classes_df.to_pickle(os.path.join("loc","classes.pkl"))

classes_df = pd.read_pickle(os.path.join("loc","classes.pkl"))

classes = classes_df.values

# very long query
if False:
    y_traintest = query(conn, column="is_evaluate", boolkey="False", b=4)
    y_evaluate = query(conn, column="is_evaluate", boolkey="True", b=4)
    
    y_traintest.to_pickle(os.path.join("loc","y_traintest.pkl"))
    y_evaluate.to_pickle(os.path.join("loc","y_evaluate.pkl"))
else:
    y_traintest = pd.read_pickle(os.path.join("loc","y_traintest.pkl"))
    y_evaluate = pd.read_pickle(os.path.join("loc","y_evaluate.pkl"))

In [52]:
data = pd.DataFrame([y_traintest/y_traintest.sum(), y_evaluate/y_evaluate.sum()])

data.T.to_csv(os.path.join(image_filepath,"classHistograms.dat"), index=True)

In [53]:
data


Out[53]:
0 1 2 3 4 5 6 7 8 9 ... 13 14 15 16 17 18 19 20 21 22
0 0.0 0.0 0.0 0.0 0.364238 0.170128 0.140738 0.003524 0.025900 0.013106 ... 0.139887 0.050676 0.005954 0.00311 0.011474 0.015648 0.002370 0.020461 0.002985 0.007512
1 0.0 0.0 0.0 0.0 0.341125 0.167322 0.142725 0.002306 0.024841 0.018343 ... 0.149800 0.054504 0.006638 0.00297 0.011373 0.018535 0.002271 0.021557 0.003389 0.007459

2 rows × 23 columns


In [54]:
y_traintest


Out[54]:
0          0
1          0
2          0
3          0
4     108217
5      50546
6      41814
7       1047
8       7695
9       3894
10      1775
11      1670
12      3177
13     41561
14     15056
15      1769
16       924
17      3409
18      4649
19       704
20      6079
21       887
22      2232
dtype: int64

In [55]:
y_evaluate


Out[55]:
0         0
1         0
2         0
3         0
4     19527
5      9578
6      8170
7       132
8      1422
9      1050
10      339
11      426
12      657
13     8575
14     3120
15      380
16      170
17      651
18     1061
19      130
20     1234
21      194
22      427
dtype: int64

In [ ]: