In [1]:
%matplotlib inline
import sys
import numpy as np
from htmresearch.frameworks.pytorch.continual_learning_experiment import \
BaselineContinualLearningExperiment
import matplotlib.pyplot as plt
In [2]:
# Experiment configuration file handed to the suite as its `-c` option.
CONFIG_FILE = 'experiments.cfg'

# parse_opt() appears to read its options from sys.argv (which is why a fake
# command line is installed here). Swap it in only for the duration of
# construction + parsing, then restore the kernel's real argv so the fake
# value does not leak into later cells or re-runs.
# NOTE(review): assumes no later suite method re-reads sys.argv — confirm.
_saved_argv = sys.argv
sys.argv = ['prog', '-c', CONFIG_FILE]
try:
    suite = BaselineContinualLearningExperiment()
    suite.parse_opt()
    suite.parse_cfg()
finally:
    sys.argv = _saved_argv
In [4]:
# Plot the recorded 'accuracy' history (repetition 0) of every experiment in
# the suite on one shared axis, one line per experiment.
fig, ax = plt.subplots(1, 1)
n_tasks = 0  # longest history seen; drives the x tick positions
for exp in suite.get_exps():
    # Each iteration represents a single task.
    # Use get_history to get the accuracy for all tasks.
    accuracy = suite.get_history(exp, 0, 'accuracy')
    ax.plot(accuracy, label=exp)
    n_tasks = max(n_tasks, len(accuracy))
ax.set_ylabel('Accuracy')
ax.set_xlabel('Tasks')
ax.set_title('Accuracy per task')
# One tick per task, derived from the data instead of a hard-coded count.
ax.set_xticks(range(n_tasks))
ax.legend();
Out[4]: