In [1]:
%matplotlib inline
DEFAULT_FIGSIZE = (12, 8)
import itertools
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import seaborn as sns
sns.set_style('darkgrid')
import pandas as pd
sys.path.append('..')
from antlia.plotdf import plotjoint
%load_ext autoreload
%autoreload 2
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = DEFAULT_FIGSIZE
In [2]:
# Load the two datasets used by every plot below. The objects are indexed
# by column name later on ('rider id', 'starting velocity', ...), so they
# are presumably pandas DataFrames -- confirm against whatever wrote the files.
# NOTE(review): unpickling can execute arbitrary code; only load pickle
# files that come from a trusted source.
braking_df = pd.read_pickle('braking_ttc.p.gz')
steering_df = pd.read_pickle('steering_ttc.p.gz')
In [3]:
#%matplotlib notebook
#mpl.rcParams['figure.figsize'] = DEFAULT_FIGSIZE
def set_swarmplot_labels(ax, event_types):
    """Rebuild the legend of ``ax`` with event-type-prefixed labels.

    The legend entries currently on ``ax`` are split into
    ``len(event_types)`` equally sized, consecutive groups (one group per
    plot drawn onto the axes, in drawing order). Each entry in a group is
    relabelled as ``'<event type> <original label>'``.

    Parameters
    ----------
    ax : axes-like object
        Must provide ``get_legend_handles_labels()`` and
        ``legend(handles, labels)``.
    event_types : sequence of str
        One name per group of legend entries, in drawing order.

    Raises
    ------
    AssertionError
        If the legend entries cannot be divided evenly between the event
        types.
    """
    handles, labels = ax.get_legend_handles_labels()
    event_types = list(event_types)
    n = len(labels)
    n_groups = len(event_types)
    assert n_groups > 0 and n % n_groups == 0, (
        'cannot split {} legend entries into {} groups'.format(n, n_groups))

    per_group = n // n_groups
    # Use each group's *own* labels rather than reusing the first group's:
    # the handles stay correctly labelled even when the plots do not share
    # exactly the same hue levels.
    new_labels = ['{} {}'.format(event_type, label)
                  for group, event_type in enumerate(event_types)
                  for label in labels[group*per_group:(group + 1)*per_group]]
    ax.legend(handles, new_labels)
def shift_swarmplot(ax, shift_value=0.05):
    """Nudge overlaid swarms apart horizontally so they do not overlap.

    Non-empty collections in the first quarter of ``ax.collections`` are
    moved left by ``shift_value``; every other non-empty collection is
    moved right by ``shift_value``. Only the x coordinate of the offsets
    changes.

    The intended layout (two swarmplots with a hue legend drawn on the same
    axes) is::

        [c0, c1, c2, c3, x, x, x, x,
         c4, c5, c6, c7, x, x, x, x]

    where ``c*`` hold the plotted points and ``x`` are the empty
    collections that only back the legend entries.
    NOTE(review): this ordering depends on how the plotting library adds
    its artists -- confirm for the seaborn version in use.

    Parameters
    ----------
    ax : axes-like object
        Must expose ``collections``; each collection must provide
        ``get_offsets()`` and ``set_offsets()``.
    shift_value : float, optional
        Size of the horizontal shift in data units of the x axis.
    """
    collections = ax.collections
    n = len(collections)
    for i, c in enumerate(collections):
        offsets = c.get_offsets()
        if len(offsets) == 0:
            # legend placeholder without points: nothing to move
            continue
        sign = -1 if i < n//4 else 1
        # Broadcasting adds the (dx, 0) row to every point; no need for
        # np.matlib, which a plain ``import numpy`` does not load.
        c.set_offsets(offsets + np.array([sign*shift_value, 0]))
def plot_combined_swarm(y, df_filter=None):
    """Draw braking and steering swarms of column ``y`` per rider on one axes.

    Parameters
    ----------
    y : str
        Name of the dataframe column plotted on the y axis.
    df_filter : indexable of (callable, str), optional
        ``df_filter[0]`` receives a dataframe and returns the row selector
        used to index it; ``df_filter[1]`` is a description that is
        appended to the title. When omitted, all rows are plotted.

    Returns
    -------
    (fig, ax)
        The figure and axes created with ``plt.subplots()``.
    """
    paired = sns.color_palette('Paired', 10)
    fig, ax = plt.subplots()

    def select_rows(frame):
        # no filter given: plot the frame unchanged
        return frame if df_filter is None else frame[df_filter[0](frame)]

    # braking takes the even palette entries, steering the odd ones
    for frame, colors in ((braking_df, paired[::2]),
                          (steering_df, paired[1::2])):
        sns.swarmplot(x='rider id', y=y, data=select_rows(frame), ax=ax,
                      hue='rider id', palette=colors)

    set_swarmplot_labels(ax, ['braking', 'steering'])
    shift_swarmplot(ax)

    heading = 'rider vs {}'.format(y)
    if df_filter is not None:
        heading = '{} ({})'.format(heading, df_filter[1])
    ax.set_title(heading)
    return fig, ax
# dataframe filter: a (row-selector function, description) pair; the
# description is what the plotting helpers append to their titles
valid_velocities = (lambda x: x['starting velocity'] > 1,
'starting velocity > 1')
# close any figures that are still open before drawing new ones
plt.close('all')
fig, ax = plot_combined_swarm('starting velocity', valid_velocities)
# this plot is drawn without a filter, i.e. on all rows
plot_combined_swarm('distance-to-collision')
plot_combined_swarm('time-to-collision', valid_velocities)
plt.show()
In [4]:
#%matplotlib notebook
#mpl.rcParams['figure.figsize'] = DEFAULT_FIGSIZE
def plot_combined_joint(x, y, df_filter=None, key=None, figsize=None):
    """Overlay braking and steering data in a single joint plot.

    Parameters
    ----------
    x : str
        Dataframe column for the x axis.
    y : str
        Dataframe column for the y axis.
    df_filter : indexable of (callable, str), optional
        ``df_filter[0]`` receives a dataframe and returns the row selector
        used to index it; ``df_filter[1]`` is a description that is
        appended to the title. When omitted, all rows are plotted.
    key : str, optional
        Dataframe column passed to ``plotjoint`` together with a palette.
        When given, the legend gets one entry per event type and per
        unique value of ``steering_df[key]``; those values are also used
        to index the palette.
    figsize : optional
        Figure size in inches; falls back to
        ``mpl.rcParams['figure.figsize']`` when omitted.

    Returns
    -------
    The object returned by ``plotjoint`` (it exposes ``ax_joint`` and
    ``fig``; presumably a seaborn JointGrid -- confirm in antlia.plotdf).
    """
    def select_rows(frame):
        # no filter given: plot the frame unchanged
        return frame if df_filter is None else frame[df_filter[0](frame)]

    event_labels = ['braking', 'steering']
    if key is not None:
        symbols = ['o', '+']
        paired = sns.color_palette('Paired', 10)
        g = plotjoint(x, y, select_rows(braking_df),
                      (key, paired[::2]), marker=symbols[0])
        g = plotjoint(x, y, select_rows(steering_df),
                      (key, paired[1::2]), g=g, marker=symbols[1])

        # braking takes the even palette entries, steering the odd ones
        palettes = [paired[::2], paired[1::2]]
        key_values = steering_df[key].unique()
        # proxy artists: one legend entry per (event type, key value)
        legend_handles = [
            mlines.Line2D(
                [], [], color=colors[value], marker=symbol,
                markersize=6, linestyle=' ',
                label='{} {}'.format(event_type, value))
            for event_type, colors, symbol in zip(
                event_labels, palettes, symbols)
            for value in key_values]
        g.ax_joint.legend(handles=legend_handles)
    else:
        g = plotjoint(x, y, select_rows(braking_df))
        g = plotjoint(x, y, select_rows(steering_df), g=g)
        g.ax_joint.legend(event_labels)

    size = mpl.rcParams['figure.figsize'] if figsize is None else figsize
    if size is not None:
        g.fig.set_size_inches(size)

    # leave headroom above the axes for the figure title
    g.fig.subplots_adjust(top=0.9)
    heading = '{} vs {}'.format(x, y)
    if df_filter is not None:
        heading = '{} ({})'.format(heading, df_filter[1])
    g.fig.suptitle(heading)
    return g
In [5]:
# Joint plots of starting velocity against each collision metric, with the
# valid_velocities filter applied; no key is passed, so the legend only
# distinguishes braking from steering.
plt.close('all')
g = plot_combined_joint('starting velocity', 'distance-to-collision', valid_velocities)
g = plot_combined_joint('starting velocity', 'time-to-collision', valid_velocities)
plt.show()
In [6]:
# Same joint plots as before, but keyed on 'rider id' so that the legend
# gets one entry per event type and rider.
plt.close('all')
g = plot_combined_joint('starting velocity', 'distance-to-collision', valid_velocities, 'rider id')
g = plot_combined_joint('starting velocity', 'time-to-collision', valid_velocities, 'rider id')
plt.show()