In [ ]:
from rl.callbacks import TrainEpisodeLogger
import matplotlib.pyplot as plt
%matplotlib notebook
import warnings
import timeit
import json

import numpy as np

In [ ]:
class TestOutputLogger(TrainEpisodeLogger):
    """Callback that plots per-step rewards against the 'Close' price series.

    The per-step ``logs`` dicts handed to ``on_step_end`` are buffered, and at
    the end of every episode the buffered rewards are drawn on one y-axis with
    the ``Close`` column of ``hist_data.data()`` on a second, twinned y-axis.

    The train/episode hooks inherited from ``TrainEpisodeLogger`` are
    overridden so that this callback does not run the parent's versions of
    them.
    """

    def __init__(self, hist_data):
        """Create the logger.

        Args:
            hist_data: object whose ``data()`` method returns a table that
                supports ``.loc[:, ['Close']].values`` (presumably a pandas
                DataFrame -- confirm against the caller).
        """
        # Step logs of the episode currently in progress.
        self.logs = []
        self.hist_data = hist_data
        super(TestOutputLogger, self).__init__()

    def on_train_begin(self, logs):
        """Do nothing; overrides the parent's hook."""
        pass

    def on_train_end(self, logs):
        """Do nothing; overrides the parent's hook."""
        pass

    def on_episode_begin(self, episode, logs):
        """Start a fresh step buffer for the new episode.

        Without this reset the rewards of earlier episodes would stay in the
        buffer and be plotted again, concatenated, at the end of every later
        episode.  Resetting here (rather than after plotting) keeps the last
        episode's logs available on ``self.logs`` once the run has finished.
        """
        self.logs = []

    def on_step_end(self, step, logs):
        """Buffer the logs dict of the step that just finished."""
        self.logs.append(logs)

    def on_episode_end(self, episode, logs):
        """Plot this episode's rewards and the 'Close' prices on twin y-axes."""
        y_reward = [step_logs['reward'] for step_logs in self.logs]
        # Shift the rewards right by one sample with a leading NaN
        # (presumably because the first price sample has no reward yet --
        # confirm against the environment).
        y_reward.insert(0, np.nan)

        y_price = self.hist_data.data().loc[:, ['Close']].values

        fig, ax1 = plt.subplots()
        ax1.plot(y_reward)
        # Second y-axis sharing the same x-axis, so rewards and prices can
        # have independent scales.
        ax2 = ax1.twinx()
        ax2.plot(y_price)
        plt.show()