In [1]:
import numpy as np
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import fbeta_score
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
%matplotlib inline
In [2]:
LABELS = {'primary': 0,
          'clear': 1,
          'agriculture': 2,
          'road': 3,
          'water': 4,
          'partly_cloudy': 5,
          'cultivation': 6,
          'habitation': 7,
          'haze': 8,
          'cloudy': 9,
          'bare_ground': 10,
          'selective_logging': 11,
          'artisinal_mine': 12,
          'blooming': 13,
          'slash_burn': 14,
          'blow_down': 15,
          'conventional_mine': 16}
EVALUATION_PATH = os.path.join(r'../reports/planet_validation')
In [3]:
hashtable = {v: k for k, v in LABELS.items()}  # index -> tag name
label_names = list(hashtable.values())
In [4]:
total_size = 40483
train_size = int(np.floor(total_size * 0.8)) + 1
validation_size = total_size - train_size
In [5]:
labels = np.loadtxt(os.path.join(EVALUATION_PATH, 'out_labels_train.txt'))[0 : train_size, :]
preds = np.loadtxt(os.path.join(EVALUATION_PATH, 'out_predictions_train.txt'))[0 : train_size, :]
In [6]:
plt.figure(figsize=(10,10))
s = pd.Series(labels.sum(axis=0), index=label_names)
s.plot.pie(fontsize=15)
s
Out[6]:
In [7]:
# long format: one row per (image, tag) pair, so predictions and labels can be compared per tag
df = pd.DataFrame({'pred': preds.ravel(), 'label': labels.ravel()})
df['tag'] = label_names * train_size
In [8]:
g = sns.FacetGrid(df, col='tag', col_wrap=4)
g = g.map(plt.hist, 'pred', bins=np.arange(0,1.05,0.05), color='r', alpha=0.5)
g = g.map(plt.hist, 'label', bins=np.arange(0,1.05,0.05), color='c', alpha=0.5)
Conclusion: Because the data are imbalanced, the model tends to be conservative on the minority labels. It may help to oversample the minority labels or to add a penalty for false negatives.
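One way to make the false-negative penalty concrete is to weight each label's positive class by its negative-to-positive ratio. The sketch below only computes those ratios from the `labels` matrix loaded above; whether they are fed into a weighted binary cross-entropy loss (e.g. as a `pos_weight` argument) depends on the training code, which is outside this notebook.

# Sketch (assumption: the training loss accepts per-label positive-class weights).
# Rare labels get a large weight, i.e. a heavier penalty on their false negatives.
pos_counts = labels.sum(axis=0)
neg_counts = labels.shape[0] - pos_counts
pos_weight = neg_counts / np.maximum(pos_counts, 1)
for name, w in zip(label_names, pos_weight):
    print('{:20s} pos_weight = {:8.1f}'.format(name, w))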
In [9]:
thresh_list = np.arange(0, 1.01, 0.01)
fbeta_mat = np.empty((len(thresh_list), len(label_names)))
# sweep decision thresholds and record the per-label F2 score at each one
for idx, thresh in enumerate(thresh_list):
    score = fbeta_score(labels, preds > thresh, beta=2, average=None)
    fbeta_mat[idx, :] = score
In [10]:
plt.figure(figsize=(20,10))
cm = plt.get_cmap('rainbow')
cm = [cm(i) for i in np.linspace(0,1,17)]
for i in range(len(label_names)):
    plt.plot(thresh_list, fbeta_mat[:, i], color=cm[i])
plt.legend(label_names, fontsize=18)
plt.tick_params(axis='both', which='major', labelsize=15)
plt.xlabel('threshold', fontsize=18)
plt.ylabel('fbeta score', fontsize=18)
Out[10]:
Conclusion: The F2 scores for the minority labels are very poor. Also, the optimal threshold varies considerably from label to label.
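Since `fbeta_mat` is already computed, the per-label thresholds that maximise F2 on this set can be read off directly. This is only a rough sketch; the resulting thresholds would still need to be validated on held-out data before being used.

# per-label threshold with the highest F2 in the sweep above (sketch, not validated)
best_idx = fbeta_mat.argmax(axis=0)
best_thresh = thresh_list[best_idx]
best_f2 = fbeta_mat[best_idx, np.arange(len(label_names))]
pd.DataFrame({'threshold': best_thresh, 'f2': best_f2}, index=label_names)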
In [11]:
weather_labels = ['clear', 'partly_cloudy', 'haze', 'cloudy']
common_labels = ['primary', 'agriculture', 'road', 'water', 'habitation', 'cultivation', 'bare_ground']
rare_labels = ['selective_logging', 'artisinal_mine', 'blooming', 'slash_burn', 'blow_down', 'conventional_mine']
In [12]:
fpr = dict()
tpr = dict()
roc_auc = dict()
for idx, label in enumerate(weather_labels):
    fpr[label], tpr[label], _ = roc_curve(labels[:, LABELS[label]], preds[:, LABELS[label]])
    roc_auc[label] = auc(fpr[label], tpr[label])
plt.figure(figsize=(10,5))
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
for label in weather_labels:
    plt.plot(fpr[label], tpr[label], lw=2, label='{} (area = {:.3f})'.format(label, roc_auc[label]))
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=16)
plt.ylabel('True Positive Rate', fontsize=16)
plt.tick_params(axis='both', which='major', labelsize=15)
plt.legend(loc='lower right', fontsize=16)
Out[12]:
In [13]:
fpr = dict()
tpr = dict()
roc_auc = dict()
for idx, label in enumerate(common_labels):
    fpr[label], tpr[label], _ = roc_curve(labels[:, LABELS[label]], preds[:, LABELS[label]])
    roc_auc[label] = auc(fpr[label], tpr[label])
plt.figure(figsize=(10,5))
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
for label in common_labels:
    plt.plot(fpr[label], tpr[label], lw=2, label='{} (area = {:.3f})'.format(label, roc_auc[label]))
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=16)
plt.ylabel('True Positive Rate', fontsize=16)
plt.tick_params(axis='both', which='major', labelsize=15)
plt.legend(loc='lower right', fontsize=16)
Out[13]:
In [14]:
fpr = dict()
tpr = dict()
roc_auc = dict()
for idx, label in enumerate(rare_labels):
    fpr[label], tpr[label], _ = roc_curve(labels[:, LABELS[label]], preds[:, LABELS[label]])
    roc_auc[label] = auc(fpr[label], tpr[label])
plt.figure(figsize=(10,5))
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
for label in rare_labels:
    plt.plot(fpr[label], tpr[label], lw=2, label='{} (area = {:.3f})'.format(label, roc_auc[label]))
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=16)
plt.ylabel('True Positive Rate', fontsize=16)
plt.tick_params(axis='both', which='major', labelsize=15)
plt.legend(loc='lower right', fontsize=16)
Out[14]:
Conclusion: The ROC curves are unexpectedly good (every AUC is above 0.85), which is inconsistent with the F2 plot in Section 1.2. With labels this imbalanced, I think ROC AUC is not a good indicator of the network's performance.
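To make the mismatch concrete, the sketch below tabulates each label's positive rate, ROC AUC, and average precision side by side, reusing the metrics already imported at the top. For rare labels the ROC AUC can stay high while the average precision collapses.

# sketch: positive rate vs. ROC AUC vs. average precision for every label
rows = []
for name in label_names:
    y = labels[:, LABELS[name]]
    p = preds[:, LABELS[name]]
    fpr_, tpr_, _ = roc_curve(y, p)
    rows.append([y.mean(), auc(fpr_, tpr_), average_precision_score(y, p)])
pd.DataFrame(rows, index=label_names, columns=['pos_rate', 'roc_auc', 'avg_precision'])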
In [15]:
precision = dict()
recall = dict()
average_precision = dict()
for idx, label in enumerate(weather_labels):
    precision[label], recall[label], _ = precision_recall_curve(labels[:, LABELS[label]], preds[:, LABELS[label]])
    average_precision[label] = average_precision_score(labels[:, LABELS[label]], preds[:, LABELS[label]])
plt.figure(figsize=(10,5))
for label in weather_labels:
    plt.plot(recall[label], precision[label], lw=2, label='{} (area = {:.3f})'.format(label, average_precision[label]))
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall', fontsize=16)
plt.ylabel('Precision', fontsize=16)
plt.tick_params(axis='both', which='major', labelsize=15)
plt.legend(loc='lower left', fontsize=16)
Out[15]:
In [16]:
precision = dict()
recall = dict()
average_precision = dict()
for idx, label in enumerate(common_labels):
    precision[label], recall[label], _ = precision_recall_curve(labels[:, LABELS[label]], preds[:, LABELS[label]])
    average_precision[label] = average_precision_score(labels[:, LABELS[label]], preds[:, LABELS[label]])
plt.figure(figsize=(10,5))
for label in common_labels:
    plt.plot(recall[label], precision[label], lw=2, label='{} (area = {:.3f})'.format(label, average_precision[label]))
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall', fontsize=16)
plt.ylabel('Precision', fontsize=16)
plt.tick_params(axis='both', which='major', labelsize=15)
plt.legend(loc='lower left', fontsize=16)
Out[16]:
In [17]:
precision = dict()
recall = dict()
average_precision = dict()
for idx, label in enumerate(rare_labels):
    precision[label], recall[label], _ = precision_recall_curve(labels[:, LABELS[label]], preds[:, LABELS[label]])
    average_precision[label] = average_precision_score(labels[:, LABELS[label]], preds[:, LABELS[label]])
plt.figure(figsize=(10,5))
for label in rare_labels:
    plt.plot(recall[label], precision[label], lw=2, label='{} (area = {:.3f})'.format(label, average_precision[label]))
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall', fontsize=16)
plt.ylabel('Precision', fontsize=16)
plt.tick_params(axis='both', which='major', labelsize=15)
plt.legend(loc='lower left', fontsize=16)
Out[17]:
Conclusion: The precision-recall curves are consistent with the analysis in Sections 1.1 and 1.2: the rare label classes perform very poorly. I think this curve is a good indicator of the network's performance.
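If a single number is wanted for tracking runs, one option is the macro-averaged average precision, optionally alongside the sample-averaged F2 at some fixed threshold. The sketch below uses 0.2 purely as an example threshold, not a tuned value.

# sketch: macro-averaged AP plus sample-averaged F2 at an arbitrary example threshold (0.2)
macro_ap = np.mean([average_precision_score(labels[:, i], preds[:, i])
                    for i in range(labels.shape[1])])
overall_f2 = fbeta_score(labels, preds > 0.2, beta=2, average='samples')
print('macro AP = {:.3f}, sample-averaged F2 @ 0.2 = {:.3f}'.format(macro_ap, overall_f2))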