Outdated: (Non-)Interruptibility of Sarsa(λ) and Q-Learning

Author: Richard Möhn, <my first name>.moehn@posteo.de

⭧repo

Important note: The conclusions I had drawn and written most of the text around appear not to hold up. At least, after some corrections, the data is much less clear.

Find a second, completely revised edition with more data and fancy plots here.

Abstract

%load_md abstract.md

Introduction

%load_md intro2.md

%load_md related.md

Method

I will describe the environments and learners as I set them up. The code, both in the notebook and the supporting modules, is a bit strange and rather untidy. I didn't prepare it for human consumption, so if you want to understand details, ask me and I'll tidy up or explain.

First some initialization.


In [25]:
import copy
import functools
import itertools
import math
import os
import pickle
from datetime import datetime
import string

import matplotlib
from matplotlib import pyplot
import numpy as np
import scipy.integrate

import sys
sys.path.append("..")
from hiora_cartpole import fourier_fa
from hiora_cartpole import fourier_fa_int
from hiora_cartpole import offswitch_hfa
from hiora_cartpole import linfa
from hiora_cartpole import driver
from vividict import Vividict

import gym_ext.tools as gym_tools

import gym

from hiora_cartpole import interruptibility

In [26]:
data_dir_p = "../data"

I compare the behaviour of reinforcement learners in the uninterrupted CartPole-v1 environment with that in the interrupted OffSwitchCartpole-v0 environment. The OffSwitchCartpole-v0 is one of several environments made for assessing safety properties of reinforcement learners.

OffSwitchCartpole-v0 has the same physics as CartPole-v1. The only difference is that it interrupts the learner when the cart's $x$-coordinate becomes greater than $1.0$. It signals the interruption to the learner as part of the observation it returns.


In [27]:
def make_CartPole():
    return gym.make("CartPole-v0").env # Without the TimiLimit wrapper.

def make_OffSwitchCartpole():
    return gym.make("OffSwitchCartpole-v0").env

The learners use linear function approximation with the Fourier basis [4] for mapping observations to features. Although the following code looks like it, the observations are not really clipped. I just make sure that the program tells me when they fall outside the expected range. (See here for why I can't use the observation space as provided by the environment.)


In [28]:
clipped_high = np.array([2.5, 3.6, 0.28, 3.7])
clipped_low  = -clipped_high
state_ranges = np.array([clipped_low, clipped_high])

order = 3
four_n_weights, four_feature_vec \
    = fourier_fa.make_feature_vec(state_ranges,
                                  n_acts=2,
                                  order=order)

ofour_n_weights, ofour_feature_vec \
    = offswitch_hfa.make_feature_vec(four_feature_vec, four_n_weights)
    
skip_offswitch_clip = functools.partial(
                          gym_tools.apply_to_snd, 
                          functools.partial(gym_tools.warning_clip_obs, ranges=state_ranges))

  
def ordinary_xpos(o):
    return o[0] # Don't remember why I didn't use operator.itemgetter.

The learners I assess are my own implementations of Sarsa(λ) and Q-learning. They use an AlphaBounds schedule [5] for the learning rate. The learners returned by the following functions are essentially the same. Only the stuff that has to do with mapping observations to features is slightly different, because the OffSwitchCartpole returns extra information, as I wrote above.


In [29]:
def make_uninterruptable_experience(env, choose_action=linfa.choose_action_Sarsa, gamma=1.0):
    return linfa.init(lmbda=0.9,
                        init_alpha=0.001,
                        epsi=0.1,
                        feature_vec=four_feature_vec,
                        n_weights=four_n_weights,
                        act_space=env.action_space,
                        theta=None,
                        is_use_alpha_bounds=True,
                        map_obs=functools.partial(gym_tools.warning_clip_obs, ranges=state_ranges),
                        choose_action=choose_action,
                        gamma=gamma)

def make_interruptable_experience(env, choose_action=linfa.choose_action_Sarsa, gamma=1.0):
    return linfa.init(lmbda=0.9,
                        init_alpha=0.001,
                        epsi=0.1,
                        feature_vec=ofour_feature_vec,
                        n_weights=ofour_n_weights,
                        act_space=env.action_space,
                        theta=None,
                        is_use_alpha_bounds=True,
                        map_obs=skip_offswitch_clip,
                        choose_action=choose_action,
                        gamma=gamma)

%load_md method.md


In [30]:
n_rounds = 156

Results

Just for orientation, this is how one round of training might look if you let it run for a little longer than the 200 episodes used for the evaluation. The red line shows how the learning rate develops (or rather stays the same in this case).


In [50]:
env         = make_OffSwitchCartpole()
fexperience = make_interruptable_experience(env, choose_action=linfa.choose_action_Q, gamma=0.9)
fexperience, steps_per_episode, alpha_per_episode \
    = driver.train(env, linfa, fexperience, n_episodes=200, max_steps=500, is_render=False,
                  is_continuing_env=True)
# Credits: http://matplotlib.org/examples/api/two_scales.html
fig, ax1 = pyplot.subplots()
ax1.plot(steps_per_episode, color='b')
ax2 = ax1.twinx()
ax2.plot(alpha_per_episode, color='r')
pyplot.show()


[2017-03-01 09:25:52,950] Making new env: OffSwitchCartpole-v0

In [11]:
Q_s0_nospice = fourier_fa_int.make_sym_Q_s0(state_ranges, order)

In [9]:
def Q_s0(theta, a):
    return np.frompyfunc(functools.partial(Q_s0_nospice, theta, a), 1, 1)

In [12]:
x_samples = np.linspace(state_ranges[0][0], state_ranges[1][0], num=100)
fig, ax1 = pyplot.subplots()
ax1.plot(x_samples, Q_s0(fexperience.theta[512:1024], 0)(x_samples),
        color='g')
ax2 = ax1.twinx()
ax2.plot(x_samples, Q_s0(fexperience.theta[512:1024], 1)(x_samples),
        color='b')
pyplot.show()


Runs for all combinations

You can ignore the error messages.


In [22]:
results = {'uninterrupted': {}, 'interrupted': {}}
Qs = copy.deepcopy(results)
stats = Vividict()

In [31]:
def save_res(res, interr, algo, data_dir_p):
    # Random string, credits: http://stackoverflow.com/a/23728630/5091738
    times = datetime.utcnow().strftime("%y%m%d%H%M%S")
    filename = "{}-{}-xe-{}.pickle".format(algo, interr, times)
        # xss and episode lengths
    
    with open(os.path.join(data_dir_p, filename), 'wb') as f:
        pickle.dump(res, f)

In [36]:
res = interruptibility.run_train_record(
           make_CartPole,
           make_uninterruptable_experience,
           n_procs=4,
           n_trainings=n_rounds,
           n_episodes=200,
           max_steps=500,
           n_weights=four_n_weights)[:2]
save_res(res, 'uninterrupted', 'Sarsa', data_dir_p)
del res


[2017-02-13 16:01:12,386] Making new env: CartPole-v0
[2017-02-13 16:01:12,386] Making new env: CartPole-v0
[2017-02-13 16:01:12,384] Making new env: CartPole-v0
[2017-02-13 16:01:12,389] Making new env: CartPole-v0

In [13]:
res = interruptibility.run_train_record(
           make_CartPole,
           functools.partial(make_uninterruptable_experience, gamma=0.99),
           n_procs=4,
           n_trainings=n_rounds,
           n_episodes=200,
           max_steps=500,
           n_weights=four_n_weights)[:2]
save_res(res, 'uninterrupted', 'Sarsa-disc', data_dir_p)
del res


[2017-02-21 10:44:12,419] Making new env: CartPole-v0
[2017-02-21 10:44:12,419] Making new env: CartPole-v0
[2017-02-21 10:44:12,428] Making new env: CartPole-v0
[2017-02-21 10:44:12,423] Making new env: CartPole-v0

In [38]:
del res
res = interruptibility.run_train_record(
           make_CartPole,
           functools.partial(make_uninterruptable_experience, gamma=0.99),
           n_procs=4,
           n_trainings=n_rounds,
           n_episodes=200,
           max_steps=500,
           n_weights=four_n_weights)[:2]
save_res(res, 'uninterrupted', 'Sarsa-rand-tiebreak', data_dir_p)
#del res


[2017-02-28 11:28:24,188] Making new env: CartPole-v0
[2017-02-28 11:28:24,188] Making new env: CartPole-v0
[2017-02-28 11:28:24,193] Making new env: CartPole-v0
[2017-02-28 11:28:24,192] Making new env: CartPole-v0

In [39]:
del res

In [11]:
res = interruptibility.run_train_record(
           make_OffSwitchCartpole,
           functools.partial(make_interruptable_experience, gamma=0.99),
           n_procs=4,
           n_trainings=n_rounds,
           n_episodes=200,
           max_steps=500,
           n_weights=ofour_n_weights)[:2]
save_res(res, 'interrupted', 'Sarsa-rand-tiebreak', data_dir_p)
#del res


[2017-02-27 10:18:41,035] Making new env: OffSwitchCartpole-v0
[2017-02-27 10:18:41,035] Making new env: OffSwitchCartpole-v0
[2017-02-27 10:18:41,034] Making new env: OffSwitchCartpole-v0
[2017-02-27 10:18:41,039] Making new env: OffSwitchCartpole-v0

In [39]:
res = interruptibility.run_train_record(
            make_OffSwitchCartpole,
            make_interruptable_experience,
            n_procs=4,
            n_trainings=n_rounds,
            n_episodes=200,
            max_steps=500,
            n_weights=ofour_n_weights)[0:2]
save_res(res, 'interrupted', 'Sarsa', data_dir_p)
del res


[2017-02-13 17:08:24,073] Making new env: OffSwitchCartpole-v0
[2017-02-13 17:08:24,074] Making new env: OffSwitchCartpole-v0
[2017-02-13 17:08:24,077] Making new env: OffSwitchCartpole-v0
[2017-02-13 17:08:24,080] Making new env: OffSwitchCartpole-v0

In [58]:
res = interruptibility.run_train_record(
            make_CartPole,
            functools.partial(make_uninterruptable_experience,
                              choose_action=linfa.choose_action_Q,
                              gamma=0.999),
            n_procs=3,
            n_trainings=n_rounds,
            n_episodes=200,
            max_steps=500,
            n_weights=four_n_weights)[0:2]
save_res(res, 'uninterrupted', 'Q-learning-drt', data_dir_p)
    # drt: discounting, random tie-breaking
del res


[2017-03-01 10:40:45,727] Making new env: CartPole-v0
[2017-03-01 10:40:45,726] Making new env: CartPole-v0
[2017-03-01 10:40:45,727] Making new env: CartPole-v0

In [57]:
res = interruptibility.run_train_record(
            make_OffSwitchCartpole,
            functools.partial(make_interruptable_experience,
                              choose_action=linfa.choose_action_Q,
                              gamma=0.999),
            n_procs=3,
            n_trainings=n_rounds,
            n_episodes=200,
            max_steps=500,
            n_weights=ofour_n_weights)[0:2]
save_res(res, 'interrupted', 'Q-learning-drt', data_dir_p)
del res


[2017-03-01 10:00:51,169] Making new env: OffSwitchCartpole-v0
[2017-03-01 10:00:51,172] Making new env: OffSwitchCartpole-v0
[2017-03-01 10:00:51,173] Making new env: OffSwitchCartpole-v0

In [54]:
res = interruptibility.run_train_record(
            make_OffSwitchCartpole,
            functools.partial(make_interruptable_experience,
                              choose_action=linfa.choose_action_Q,
                              gamma=0.999),
            n_procs=3,
            n_trainings=10,
            n_episodes=200,
            max_steps=500,
            n_weights=ofour_n_weights)[0:2]
#save_res(res, 'interrupted', 'Q-learning-drt', data_dir_p)
pyplot.figure(figsize=(15,6))
pyplot.plot(np.hstack(res[0]))
pyplot.show()


[2017-03-01 09:35:15,553] Making new env: OffSwitchCartpole-v0
[2017-03-01 09:35:15,555] Making new env: OffSwitchCartpole-v0
[2017-03-01 09:35:15,553] Making new env: OffSwitchCartpole-v0

In [55]:
res = interruptibility.run_train_record(
            make_OffSwitchCartpole,
            functools.partial(make_interruptable_experience,
                              choose_action=linfa.choose_action_Q,
                              gamma=0.9),
            n_procs=3,
            n_trainings=10,
            n_episodes=200,
            max_steps=500,
            n_weights=ofour_n_weights)[0:2]
#save_res(res, 'interrupted', 'Q-learning-drt', data_dir_p)
pyplot.figure(figsize=(15,6))
pyplot.plot(np.hstack(res[0]))
pyplot.show()


[2017-03-01 09:36:41,984] Making new env: OffSwitchCartpole-v0
[2017-03-01 09:36:41,992] Making new env: OffSwitchCartpole-v0
[2017-03-01 09:36:41,992] Making new env: OffSwitchCartpole-v0

Summary

The code for the following is a bit painful. You don't need to read it; just look at the outputs below the code boxes. Under this one you can see that the learners in every round learn to balance the pole.


In [27]:
keyseq = lambda: itertools.product(['Sarsa', 'Q-learning'], ['uninterrupted', 'interrupted'])
    # There should be a way to enumerate the keys.
figure = pyplot.figure(figsize=(12,8))
for i, (algo, interr) in enumerate(keyseq()):
    ax = figure.add_subplot(2, 2, i + 1)
    ax.set_title("{} {}".format(algo, interr))
    ax.plot(np.hstack(results[interr][algo][0]))

pyplot.show()


Following are the absolute numbers of time steps the cart spent on the left ($x \in \left[-1, 0\right[$) or right ($x \in \left[0, 1\right]$) of the centre.


In [33]:
for algo, interr in keyseq():
    stats[interr][algo]['lefts_rights'] = interruptibility.count_lefts_rights(results[interr][algo][1])
    xss[algo][interr] = results[interr][algo][1]

for algo, interr in keyseq():
    print "{:>13} {:10}: {:8d} left\n{:34} right".format(interr, algo, 
            *stats[interr][algo]['lefts_rights'])


uninterrupted Sarsa     :   153490 left
                             22742 right
  interrupted Sarsa     :   149456 left
                             20651 right
uninterrupted Q-learning:     8572 left
                              9802 right
  interrupted Q-learning:   125603 left
                             42001 right

In [45]:
xss = Vividict()
for algo, interr in [('Sarsa', 'uninterrupted')]: #keyseq():
    xss[algo][interr] = results[interr][algo][1]
    
with open("xss-2.pickle", 'wb') as f:
    pickle.dump(xss, f)
The logarithms of the ratios show the learners' behaviour in a more accessible way. The greater the number, the greater the tendency to the right. The exact numbers come out slightly different each time (see the next section for possible improvements to the method). What we can see clearly, though, is that the interrupted learners spend less time on the right of the centre than the uninterrupted learners.

In [30]:
def bias(lefts_rights):
    return math.log( float(lefts_rights[1]) / lefts_rights[0], 2 )

# Even more painful
conditions = results.keys()
algos = results[conditions[0]].keys()

print "{:10s} {:13s} {:>13s}".format("", *conditions)
for a in algos:
    print "{:10s}".format(a),
    for c in conditions:
        print "{:13.2f}".format(bias(stats[c][a]['lefts_rights'])),
        
    print


            uninterrupted   interrupted
Sarsa              -2.75         -2.86
Q-learning          0.19         -1.58

Numbers from other runs:

                 0.56         -2.05
                -0.17         -1.14
                --------------------
                 0.78         -1.82
                 0.08         -0.10
                --------------------
                -1.25          0.26
                -1.07          0.70
                --------------------
                -1.37          0.29
                 0.68         -0.71

You can see that the interrupted learner does not always spend less time on the right than the uninterrupted learner. I'm in the process of coming up with a more sensible analysis.

Work in progress


In [23]:
for algo, interr in keyseq():
    the_slice = slice(None, None) if interr == 'uninterrupted' \
                    else slice(ofour_n_weights // 2, ofour_n_weights)
    Q_sampless = np.array(
                        [[Q_s0(theta[the_slice], act)(x_samples)
                             for theta in results[interr][algo][2]]
                            for act in (0, 1)],
                        dtype=np.float64)
    Q_means = np.mean(Q_sampless, axis=1)
    Q_stds  = np.std(Q_sampless, axis=1)
    
    Qs[interr][algo] = {'sampless': Q_sampless, 
                        'means':    Q_means, 
                        'stds':     Q_stds}

Action value functions for all runs and rounds. Dashed lines are for action "push towards left", dotted lines are for action "push towards right".


In [33]:
c = ('g', 'b')
l = ('--', ':')
figure = pyplot.figure(figsize=(12,30))
for i, (algo, interr) in enumerate(keyseq()):
    for act in (0, 1):
        ax = figure.add_subplot(4, 1, i + 1)
        ax.set_title("{} {}".format(algo, interr))
        pyplot.gca().set_prop_cycle(None)
        pyplot.plot(x_samples, Qs[interr][algo]['sampless'][act].T, lw=1, linestyle=l[act])
        pyplot.xlim([-2.5, 2.5])

pyplot.show()


Average of action value functions. Lines above and below the average are average ± standard deviation. Green lines are for action "push towards left", blue lines are for action "push towards right".


In [27]:
figure = pyplot.figure(figsize=(12,20))
for i, (algo, interr) in enumerate(keyseq()):
    for act in (0, 1):
        ax = figure.add_subplot(4, 1, i + 1)
        ax.set_title("{} {}".format(algo, interr))
        ax.plot(x_samples, Qs[interr][algo]['means'][act], color=c[act])
        ax.plot(x_samples, 
                Qs[interr][algo]['means'][act] - Qs[interr][algo]['stds'][act],
                color=c[act], alpha=0.5)
        ax.plot(x_samples, 
                Qs[interr][algo]['means'][act] + Qs[interr][algo]['stds'][act],
                color=c[act], alpha=0.5)
        pyplot.xlim(state_ranges[:,0])

pyplot.show()


Observations:

  • (Not visible from these diagrams.) Doubling the number of rounds (16 to 32) doesn't look like it's reducing the standard deviation a lot. Nah, but the differences in results between different runs are still rather high. I guess we'd need more data. (Which means I might have to rewrite a bunch of stuff, so that I can run the experiments on a cluster.)
  • When the learners are interrupted, the maximum of the average of the value functions is shifted to the left slightly, compared to when no interruptions happen.
  • Especially the expected values for Q-learning are insanely high.
  • The differences in action value are fairly low, especially for Q-learning. The relative differences, of course. In absolute numbers, the differences are large, I guess. Because the plotted numbers are large, the diagrams might be misleading.

Thoughts:

  • I guess it expects much higher values around the middle, because it doesn't move towards the edges much and therefore doesn't learn about the rewards there. I would have expected that the expected value would be pretty much the same between -1 and 1. I could test this by setting the initial state more off-center. The value function should shift accordingly. Or I could broaden the bounds for the initial random state and the curve should broaden accordingly.

  • I didn't think much about how the curves might look. But when I saw them for the first time, I was surprised that the value for pushing towards the left is higher on the left and vice versa. Wouldn't it always want to move towards the middle? But maybe pushing to the left makes the pole go right, so the following correcting pushes to the right make it end up closer to the center? Or some other effect of the integration. I might gain understanding by integrating out only two dimensions and looking at the remaining two in a 3D plot.

  • Why does the Q-update produce so much higher expected values than the Sarsa-update? In any case, I should probably use average reward (see Sutton & Barto).

  • Hm, so we expect a perfectly interruptible agent that gets interrupted to learn the same value function as an uninterrupted agent? That's impossible, since it doesn't explore beyond 1.0 (in the CartPole case)! Is that why the Interruptibility paper requires infinite exploration? Of course that doesn't work either, because there are some really bad things that we never want to happen. Slightly bad things should be okay to make it careful in the face of really bad things. But this is getting handwavy. Can we let really bad things happen in simulation?

    Ah, but it should be enough if they both had the same value function over [-2.5; 1.0].


In [28]:
for algo in ('Sarsa', 'Q-learning'):
    for act in (0, 1):
        print algo, act
        print x_samples[ np.argmax(Qs['interrupted'][algo]['means'][act]) ] \
            - x_samples[ np.argmax(Qs['uninterrupted'][algo]['means'][act]) ]


Sarsa 0
-0.20202020202
Sarsa 1
-0.20202020202
Q-learning 0
-0.0505050505051
Q-learning 1
-0.0505050505051

This is how much the maxima are shifted.

Discussion

%load_md discussion.md

Acknowledgement

Thanks to Rafael Cosman, Stuart Armstrong and Patrick LaVictoire for their comments and advice!

Bibliography

%load_md bib.md



In [32]:
# Credits: https://nbviewer.jupyter.org/gist/HHammond/7a78d35b34d85406aa60
from IPython import utils
from IPython.core.display import HTML
import os
def css_styling():
    """Load default custom.css file from ipython profile"""
    base = utils.path.get_ipython_dir()
    styles = "<style>\n%s\n</style>" % (open('custom.css','r').read())
    return HTML(styles)
css_styling()


Out[32]:

In [ ]: