In [ ]:
import warnings
import itertools
import rl.callbacks
import numpy as np
from google_cloud_storage import GoogleCloudStorage


Using TensorFlow backend.

In [ ]:
class ModelSaver(rl.callbacks.TrainEpisodeLogger):
    def __init__(self, filepath, monitor='loss', verbose=1, 
                 save_best_only=True, mode='min', save_weights_only=False,
                 upload_to_gcs=True,
                 logger=None):
        if filepath is None:
            raise ValueError('Give value to filepath. (Given: %s)' % filepath)
        self.best_monitor_value = None
        self.filepath = filepath
        self.monitor = monitor
        self.verbose = verbose
        self.save_best_only = save_best_only
        self.mode = mode
        self.save_weights_only = save_weights_only
        if mode not in ('min', 'max'):
            raise ValueError("Give 'min' or 'max' to mode. (Given: %s)" % mode)
        self.mode = mode
        if upload_to_gcs:
            self.gcs = GoogleCloudStorage()
        else:
            self.gcs = None
        self._logger = logger
        
        super().__init__()

    def on_episode_end(self, episode, logs):
        self._logger.warn('========== Model Saver output ==============')
        try:
            monitor_value = float(self._formatted_metrics(episode)[self.monitor])
        except:
            monitor_value = 0.0
        self._logger.warn('%s value: %e' % (self.monitor, monitor_value))
        values = {'episode': episode, self.monitor: monitor_value}
        if not self.save_best_only:
            values['previous_monitor'] = monitor_value
            self._save_model(values)            
        elif self.best_monitor_value is None or self._is_this_episode_improved(monitor_value):
            previous_value = self.best_monitor_value
            self.best_monitor_value = monitor_value
            values['previous_monitor'] = previous_value
            self._save_model(values)
            self._logger.warn('%s %s value: %e' % (self.mode, self.monitor, self.best_monitor_value))
        #except:
        #    self._logger.warn('Not a float value given.')
        self._logger.warn('========== /Model Saver output =============')
        super().on_episode_end(episode, logs)

    def _is_this_episode_improved(self, monitor_value):
        if self.mode == 'min':
            return monitor_value < self.best_monitor_value
        else:
            return monitor_value > self.best_monitor_value
        
    def _save_model(self, kwargs):
        previous_monitor = kwargs['previous_monitor']
        filepath = self.filepath.format_map(kwargs)
        if self.verbose > 0:
            self._logger.warn("Step %08d: model improved\n  from %e\n    to %e,"
                  ' saving model to %s'
                  % (self.step, previous_monitor or 0.0,
                     self.best_monitor_value or 0.0, filepath))
        if self.save_weights_only:
            saved_file_path = filepath + '.hdf5'
            self.model.save_weights(saved_file_path, overwrite=True)
            self._logger.warn('Save weights to %s has done.' % filepath)
        else:
            saved_file_path = filepath + '.h5'
            self.model.model.save(saved_file_path, overwrite=True)
            self._logger.warn('Save model to %s has done.' % filepath)
        self._upload_model_to_gcs(saved_file_path)
        self._logger.warn('Save file %s to GCS has done.' % filepath)

    def _upload_model_to_gcs(self, model_file_path):
        if not self.gcs:
            return
        self.gcs.upload_model(model_file_path)

    def _formatted_metrics(self, episode):
        # Format all metrics.
        metrics = np.array(self.metrics[episode])
        metrics_variables = []
        with warnings.catch_warnings():
            warnings.filterwarnings('error')
            for idx, name in enumerate(self.metrics_names):
                try:
                    value = np.nanmean(metrics[:, idx])
                except Warning:
                    if name == 'loss':
                        value = float('inf')
                    else:
                        value = '--'
                metrics_variables += [name, value]
        return dict(itertools.zip_longest(*[iter(metrics_variables)] * 2, fillvalue=""))