Reinforcement Learning: Deep Q-Networks

If you aren't familiar with reinforcement learning, check out the previous guide on reinforcement learning for an introduction.

In the previous guide we implemented the Q function as a lookup table. That worked well enough for that scenario because it had a fairly small state space. However, consider something like DeepMind's Atari player. A state in that task is a unique configuration of pixels. All those Atari games are color, so each pixel has three values (R,G,B), and there are quite a few pixels. So there is a massive state space for all possible configurations of pixels, and we simply can't implement a lookup table encompassing all of these states - it would take up too much memory.

Instead, we can learn a Q function that approximately maps a set of pixel values and an action to some value. We could implement this Q function as a neural network and have it learn how to predict rewards for each action given an input state. This is the general idea behind deep Q-learning (i.e. deep Q networks, or DQNs).

Here we'll put together a simple DQN agent that learns how to play a simple game of catch. The agent controls a paddle at the bottom of the screen that it can move left, right, or not at all (so there are three possible action). An object falls from the top of the screen, and the agent wins if it catches it (a reward of +1). Otherwise, it loses (a reward of -1).

We'll implement the game in black-and-white so that the pixels in the game can be represented as 1 or 0.

Using DQNs are quite like using neural networks in ways you may be more familiar with. Here we'll take a vector that represents the screen, feed it through the network, and the network will output a distribution of values over possible actions. You can kind of think of it as a classification problem: given this input state, label it with the best action to take.

For example, this is the architecture of the Atari player:

The scenario we're dealing with is simple enough that we don't need convolutional neural networks, but we could easily extend it in that way if we wanted (just replace our vanilla neural network with a convolutional one).

Here's what our catch game will look like:

To start I'll present the code for the catch game itself. It's not important that you understand this code - the part we care about is the agent itself.

Note that this needs to be run in the terminal in order to visualize the game.


In [1]:
import numpy as np
from blessings import Terminal

class Game():
    def __init__(self, shape=(10,10)):
        self.shape = shape
        self.height, self.width = shape
        self.last_row = self.height - 1
        self.paddle_padding = 1
        self.n_actions = 3 # left, stay, right
        self.term = Terminal()
        self.reset()

    def reset(self):
        # reset grid
        self.grid = np.zeros(self.shape)

        # can only move left or right (or stay)
        # so position is only its horizontal position (col)
        self.pos = np.random.randint(self.paddle_padding, self.width - 1 - self.paddle_padding)
        self.set_paddle(1)

        # item to catch
        self.target = (0, np.random.randint(self.width - 1))
        self.set_position(self.target, 1)

    def move(self, action):
        # clear previous paddle position
        self.set_paddle(0)

        # action is either -1, 0, 1,
        # but comes in as 0, 1, 2, so subtract 1
        action -= 1
        self.pos = min(max(self.pos + action, self.paddle_padding), self.width - 1 - self.paddle_padding)

        # set new paddle position
        self.set_paddle(1)

    def set_paddle(self, val):
        for i in range(1 + self.paddle_padding*2):
            pos = self.pos - self.paddle_padding + i
            self.set_position((self.last_row, pos), val)

    @property
    def state(self):
        return self.grid.reshape((1,-1)).copy()

    def set_position(self, pos, val):
        r, c = pos
        self.grid[r,c] = val

    def update(self):
        r, c = self.target

        self.set_position(self.target, 0)
        self.set_paddle(1) # in case the target is on the paddle
        self.target = (r+1, c)
        self.set_position(self.target, 1)

        # off the map, it's gone
        if r + 1 == self.last_row:
            # reward of 1 if collided with paddle, else -1
            if abs(c - self.pos) <= self.paddle_padding:
                return 1
            else:
                return -1

        return 0

    def render(self):
        print(self.term.clear())
        for r, row in enumerate(self.grid):
            for c, on in enumerate(row):
                if on:
                    color = 235
                else:
                    color = 229

                print(self.term.move(r, c) + self.term.on_color(color) + ' ' + self.term.normal)

        # move cursor to end
        print(self.term.move(self.height, 0))

Ok, on to the agent itself. I'll present the code in full here, then explain parts in more detail.


In [2]:
import os

#if using Theano with GPU
#os.environ["THEANO_FLAGS"] = "mode=FAST_RUN,device=gpu,floatX=float32"

import random
from keras.models import Sequential
from keras.layers.core import Dense
from collections import deque

class Agent():
    def __init__(self, env, explore=0.1, discount=0.9, hidden_size=100, memory_limit=5000):
        self.env = env
        model = Sequential()
        model.add(Dense(hidden_size, input_shape=(env.height * env.width,), activation='relu'))
        model.add(Dense(hidden_size, activation='relu'))
        model.add(Dense(env.n_actions))
        model.compile(loss='mse', optimizer='sgd')
        self.Q = model

        # experience replay:
        # remember states to "reflect" on later
        self.memory = deque([], maxlen=memory_limit)

        self.explore = explore
        self.discount = discount

    def choose_action(self):
        if np.random.rand() <= self.explore:
            return np.random.randint(0, self.env.n_actions)
        state = self.env.state
        q = self.Q.predict(state)
        return np.argmax(q[0])

    def remember(self, state, action, next_state, reward):
        # the deque object will automatically keep a fixed length
        self.memory.append((state, action, next_state, reward))

    def _prep_batch(self, batch_size):
        if batch_size > self.memory.maxlen:
            Warning('batch size should not be larger than max memory size. Setting batch size to memory size')
            batch_size = self.memory.maxlen

        batch_size = min(batch_size, len(self.memory))

        inputs = []
        targets = []

        # prep the batch
        # inputs are states, outputs are values over actions
        batch = random.sample(list(self.memory), batch_size)
        random.shuffle(batch)
        for state, action, next_state, reward in batch:
            inputs.append(state)
            target = self.Q.predict(state)[0]

            # debug, "this should never happen"
            assert not np.array_equal(state, next_state)

            # non-zero reward indicates terminal state
            if reward:
                target[action] = reward
            else:
                # reward + gamma * max_a' Q(s', a')
                Q_sa = np.max(self.Q.predict(next_state)[0])
                target[action] = reward + self.discount * Q_sa
            targets.append(target)

        # to numpy matrices
        return np.vstack(inputs), np.vstack(targets)

    def replay(self, batch_size):
        inputs, targets = self._prep_batch(batch_size)
        loss = self.Q.train_on_batch(inputs, targets)
        return loss

    def save(self, fname):
        self.Q.save_weights(fname)

    def load(self, fname):
        self.Q.load_weights(fname)
        print(self.Q.get_weights())


Using Theano backend.
WARNING (theano.sandbox.cuda): The cuda backend is deprecated and will be removed in the next release (v0.10).  Please switch to the gpuarray backend. You can get more information about how to switch at this URL:
 https://github.com/Theano/Theano/wiki/Converting-to-the-new-gpu-back-end%28gpuarray%29

Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN 5103)

You'll see that this is quite similar to the previous Q-learning agent we implemented. There are explore and discount values, for example. But the Q function is now a neural network.

The biggest difference are these remember and replay methods.

A challenge with DQNs is that they can be unstable - in particular, they exhibit a problem known as catastrophic forgetting in which later experiences overwrite earlier ones. When this happens, the agent is unable to take full advantage of everything it's learned, only what it's learned most recently.

A method to deal with this is called experience replay. We just store experienced states and their resulting rewards (as "memories"), then between actions we sample a batch of these memories (this is what the _prep_batch method does) and use them to train the neural network (i.e. "replay" these remembered experiences). This will become clearer in the code below, where we actually train the agent.


In [3]:
import os
import sys
from time import sleep
game = Game()
agent = Agent(game)

print('training...')
epochs = 6500
batch_size = 256
fname = 'game_weights.h5'

# keep track of past record_len results
record_len = 100
record = deque([], record_len)

for i in range(epochs):
    game.reset()
    reward = 0
    loss = 0
    # rewards only given at end of game
    while reward == 0:
        prev_state = game.state
        action = agent.choose_action()
        game.move(action)
        reward = game.update()
        new_state = game.state
        
        # debug, "this should never happen"
        assert not np.array_equal(new_state, prev_state)

        agent.remember(prev_state, action, new_state, reward)
        loss += agent.replay(batch_size)

    # if running in a terminal, use these instead of print:
    #sys.stdout.flush()
    #sys.stdout.write('epoch: {:04d}/{} | loss: {:.3f} | win rate: {:.3f}\r'.format(i+1, epochs, loss, sum(record)/len(record) if record else 0))
    if i % 100 == 0:
        print('epoch: {:04d}/{} | loss: {:.3f} | win rate: {:.3f}\r'.format(i+1, epochs, loss, sum(record)/len(record) if record else 0))
    
    record.append(reward if reward == 1 else 0)

agent.save(fname)


training...
epoch: 0001/6500 | loss: 0.047 | win rate: 0.000
epoch: 0101/6500 | loss: 0.293 | win rate: 0.000
epoch: 0201/6500 | loss: 0.292 | win rate: 0.000
epoch: 0301/6500 | loss: 0.285 | win rate: 0.000
epoch: 0401/6500 | loss: 0.254 | win rate: 0.000
epoch: 0501/6500 | loss: 0.249 | win rate: 0.000
epoch: 0601/6500 | loss: 0.244 | win rate: 0.000
epoch: 0701/6500 | loss: 0.219 | win rate: 0.000
epoch: 0801/6500 | loss: 0.182 | win rate: 0.000
epoch: 0901/6500 | loss: 0.159 | win rate: 0.000
epoch: 1001/6500 | loss: 0.145 | win rate: 0.000
epoch: 1101/6500 | loss: 0.163 | win rate: 0.000
epoch: 1201/6500 | loss: 0.170 | win rate: 0.000
epoch: 1301/6500 | loss: 0.151 | win rate: 0.000
epoch: 1401/6500 | loss: 0.179 | win rate: 0.000
epoch: 1501/6500 | loss: 0.143 | win rate: 0.000
epoch: 1601/6500 | loss: 0.149 | win rate: 0.000
epoch: 1701/6500 | loss: 0.162 | win rate: 0.000
epoch: 1801/6500 | loss: 0.173 | win rate: 0.000
epoch: 1901/6500 | loss: 0.154 | win rate: 0.000
epoch: 2001/6500 | loss: 0.153 | win rate: 0.000
epoch: 2101/6500 | loss: 0.141 | win rate: 0.000
epoch: 2201/6500 | loss: 0.160 | win rate: 0.000
epoch: 2301/6500 | loss: 0.171 | win rate: 0.000
epoch: 2401/6500 | loss: 0.160 | win rate: 0.000
epoch: 2501/6500 | loss: 0.167 | win rate: 0.000
epoch: 2601/6500 | loss: 0.170 | win rate: 0.000
epoch: 2701/6500 | loss: 0.166 | win rate: 0.000
epoch: 2801/6500 | loss: 0.131 | win rate: 0.000
epoch: 2901/6500 | loss: 0.115 | win rate: 0.000
epoch: 3001/6500 | loss: 0.111 | win rate: 0.000
epoch: 3101/6500 | loss: 0.092 | win rate: 0.000
epoch: 3201/6500 | loss: 0.110 | win rate: 0.000
epoch: 3301/6500 | loss: 0.092 | win rate: 0.000
epoch: 3401/6500 | loss: 0.097 | win rate: 0.000
epoch: 3501/6500 | loss: 0.086 | win rate: 0.000
epoch: 3601/6500 | loss: 0.106 | win rate: 0.000
epoch: 3701/6500 | loss: 0.082 | win rate: 0.000
epoch: 3801/6500 | loss: 0.088 | win rate: 0.000
epoch: 3901/6500 | loss: 0.092 | win rate: 0.000
epoch: 4001/6500 | loss: 0.082 | win rate: 0.000
epoch: 4101/6500 | loss: 0.094 | win rate: 0.000
epoch: 4201/6500 | loss: 0.089 | win rate: 0.000
epoch: 4301/6500 | loss: 0.091 | win rate: 0.000
epoch: 4401/6500 | loss: 0.079 | win rate: 0.000
epoch: 4501/6500 | loss: 0.073 | win rate: 0.000
epoch: 4601/6500 | loss: 0.062 | win rate: 0.000
epoch: 4701/6500 | loss: 0.077 | win rate: 0.000
epoch: 4801/6500 | loss: 0.047 | win rate: 0.000
epoch: 4901/6500 | loss: 0.052 | win rate: 0.000
epoch: 5001/6500 | loss: 0.068 | win rate: 0.000
epoch: 5101/6500 | loss: 0.045 | win rate: 0.000
epoch: 5201/6500 | loss: 0.053 | win rate: 0.000
epoch: 5301/6500 | loss: 0.056 | win rate: 0.000
epoch: 5401/6500 | loss: 0.040 | win rate: 0.000
epoch: 5501/6500 | loss: 0.036 | win rate: 0.000
epoch: 5601/6500 | loss: 0.050 | win rate: 0.000
epoch: 5701/6500 | loss: 0.034 | win rate: 0.000
epoch: 5801/6500 | loss: 0.052 | win rate: 0.000
epoch: 5901/6500 | loss: 0.045 | win rate: 0.000
epoch: 6001/6500 | loss: 0.049 | win rate: 0.000
epoch: 6101/6500 | loss: 0.034 | win rate: 0.000
epoch: 6201/6500 | loss: 0.043 | win rate: 0.000
epoch: 6301/6500 | loss: 0.035 | win rate: 0.000
epoch: 6401/6500 | loss: 0.042 | win rate: 0.000

Here we train the agent for 6500 epochs (that is, 6500 games). We also keep a trailing record of its wins to see if its win rate is improving.

A game goes on until the reward is non-zero, which means the agent has either lost (reward of -1) or won (reward of +1). Note that between each action the agent "remembers" the states and reward it just saw, as well as the action it took. Then it "replays" past experiences to update its neural network.

Once the agent is trained, we can play a round and see if it wins.

Depicted below are the results from my training:


In [4]:
# play a round
game.reset()
#game.render() # rendering won't work inside a notebook, only from terminal. uncomment
reward = 0
while reward == 0:
    action = agent.choose_action()
    game.move(action)
    reward = game.update()
    #game.render()
    sleep(0.1)
print('winner!' if reward == 1 else 'loser!')


winner!

After 6500 epochs, the agent I trained won about 90% of the time. Not bad from the 30% or so it started at!