In [ ]:
%matplotlib inline

from json import loads

import matplotlib.pyplot as plt

import numpy as np

def fetch_data(path):
    """Load a newline-delimited JSON log into a list of records.

    Each non-blank line of the file at ``path`` is decoded with ``json.loads``.
    Blank or whitespace-only lines (e.g. an extra empty line at the end of the
    log) are skipped instead of aborting the whole load.

    Args:
        path: Filesystem path of the log file.

    Returns:
        list: One decoded value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    # JSON text is UTF-8 by specification; don't depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as fp:
        for line in fp:
            if not line.strip():
                # loads('') / loads('\n') would raise; blank lines carry no record.
                continue
            records.append(loads(line))
    return records

def get_subset(data, name):
    """Extract one record type as (cumulative samples, losses) series.

    Args:
        data: Iterable of dicts with ``'type'``, ``'bsize'`` and ``'loss'`` keys.
        name: The ``'type'`` value to keep.

    Returns:
        tuple: ``(x_axis, y_axis)`` where ``x_axis`` is a numpy array holding
        the running total of ``'bsize'`` over the matching records only
        (records of other types do not advance it), and ``y_axis`` is the list
        of their ``'loss'`` values in the same order.
    """
    matching = list(filter(lambda rec: rec['type'] == name, data))
    batch_sizes = [rec['bsize'] for rec in matching]
    losses = [rec['loss'] for rec in matching]
    return np.cumsum(batch_sizes), losses

def split_epochs(data):
    """Group training and validation losses into per-epoch lists.

    Only records whose ``'type'`` is ``'train'`` or ``'val_batch'`` are
    considered; all other records are ignored entirely (they neither end a
    group nor have their ``'loss'`` read). Among the kept records, every
    maximal run of consecutive records with the same type becomes one epoch.

    Args:
        data: Iterable of dicts with a ``'type'`` key, plus ``'loss'`` on the
            kept records.

    Returns:
        tuple: ``(train_epochs, val_epochs)``, each a list of lists of losses.
    """
    runs = {'train': [], 'val_batch': []}
    last_kind = None  # type of the most recent kept record
    for record in data:
        kind = record['type']
        if kind in runs:
            if kind != last_kind:
                # Switching between train and validation starts a new epoch.
                runs[kind].append([])
                last_kind = kind
            # TODO: losses are not weighted by batch size ('bsize').
            runs[kind][-1].append(record['loss'])
    return runs['train'], runs['val_batch']

def plot_data(data):
    """Scatter-plot per-batch training and validation loss on a shared x-axis.

    Records whose ``'type'`` is neither ``'train'`` nor ``'val_batch'`` are
    skipped. Each kept record's x coordinate is the running total of
    ``'bsize'`` over all kept records in log order, training and validation
    alike. This lets the two series interleave chronologically.

    Args:
        data: Iterable of dicts with a ``'type'`` key, plus ``'bsize'`` and
            ``'loss'`` on the kept records.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    series = {'train': ([], []), 'val_batch': ([], [])}
    samples_seen = 0
    for datum in data:
        kind = datum['type']
        if kind not in series:
            continue
        # Validation batches advance the counter too, so the x-axis counts all
        # processed samples, not only training samples.
        samples_seen += datum['bsize']
        xs, ys = series[kind]
        xs.append(samples_seen)
        ys.append(datum['loss'])

    train_x, train_y = series['train']
    val_x, val_y = series['val_batch']

    # Draw on a fresh figure so nothing already on the current pyplot figure
    # leaks into this plot.
    _, ax = plt.subplots()
    ax.plot(train_x, train_y, label='Training', ls=' ', marker='+')
    ax.plot(val_x, val_y, label='Validation', ls=' ', marker='v')
    # Set limits after plotting: set_xlim/set_ylim switch off autoscaling.
    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)
    ax.set_xlabel('Samples seen')
    ax.set_ylabel('Loss')
    ax.set_title('Training and validation loss')
    ax.legend()
    plt.show()
    
def do_boxplot(epochs, title):
    """Draw one box per epoch of loss values and show the figure.

    Args:
        epochs: Sequence of per-epoch loss lists. With matplotlib's default
            positions, the boxes sit at x = 1, 2, ... in the given order.
        title: Title for the plot.
    """
    plt.boxplot(epochs)
    # Only pin the lower bound; must follow boxplot() so the upper limit
    # comes from the data.
    plt.ylim(bottom=0)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(title)
    plt.show()
    
def plot_by_epoch(data):
    """Show per-epoch loss box plots: training first, then validation.

    Args:
        data: Log records, split into epochs by ``split_epochs``.
    """
    train_epochs, val_epochs = split_epochs(data)
    for losses, heading in ((train_epochs, 'Training loss'),
                            (val_epochs, 'Validation loss')):
        do_boxplot(losses, heading)

In [ ]:
# Newline-delimited JSON log of the training run to inspect (relative to this
# notebook's working directory).
data_path = ('../cache/kcnn-flow-rgb-poselet-highdrop/logs/'
             'numlog-2016-03-21T22:04:55.577568.log')

In [ ]:
# Load every record from the log, then plot per-batch loss against samples
# seen, followed by per-epoch loss distributions (training, then validation).
all_data = fetch_data(data_path)
plot_data(all_data)
plot_by_epoch(all_data)

In [ ]: