In [3]:
import random
import gym
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
from keras import backend as K

EPISODES = 500

In [4]:
class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        self.gamma = 0.95    # discount rate
        self.epsilon = 1.0  # exploration rate
        self.epsilon_min = 0.01     # floor on exploration
        self.epsilon_decay = 0.99   # multiplied in after each replay
        self.learning_rate = 0.001
        self.model = self._build_model()
        self.target_model = self._build_model()
        self.update_target_model()

    def _huber_loss(self, target, prediction):
        # pseudo-Huber loss: sqrt(1 + error^2) - 1
        # quadratic for small TD errors, roughly linear for large ones,
        # so noisy targets do not produce exploding gradients
        error = prediction - target
        return K.mean(K.sqrt(1 + K.square(error)) - 1, axis=-1)

    def _build_model(self):
        # Neural Net for Deep-Q learning Model
        model = Sequential()
        model.add(Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(Dense(24, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss=self._huber_loss,
                      optimizer=Adam(lr=self.learning_rate))
        return model

    def update_target_model(self):
        # copy weights from model to target_model
        self.target_model.set_weights(self.model.get_weights())

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        act_values = self.model.predict(state)
        return np.argmax(act_values[0])  # returns action

    def replay(self, batch_size):
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = self.model.predict(state)
            if done:
                target[0][action] = reward
            else:
                # Double DQN: the online model selects the next action,
                # the target model supplies its value estimate
                a = self.model.predict(next_state)[0]
                t = self.target_model.predict(next_state)[0]
                target[0][action] = reward + self.gamma * t[np.argmax(a)]
            self.model.fit(state, target, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        self.model.load_weights(name)

    def save(self, name):
        self.model.save_weights(name)
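
The per-transition loop in replay() above calls predict and fit once per sampled transition, which dominates training time. Purely as a sketch (not part of the original notebook; the helper name replay_batched is mine), the same Double DQN update can be built for the whole minibatch with three predict calls and a single fit, reusing the imports from the first cell:

In [ ]:
def replay_batched(agent, batch_size):
    # Hypothetical helper: same update rule as DQNAgent.replay, but vectorized
    # over the minibatch instead of looping one transition at a time.
    minibatch = random.sample(agent.memory, batch_size)
    states = np.vstack([m[0] for m in minibatch])             # (batch, state_size)
    actions = np.array([m[1] for m in minibatch])
    rewards = np.array([m[2] for m in minibatch])
    next_states = np.vstack([m[3] for m in minibatch])
    dones = np.array([m[4] for m in minibatch], dtype=bool)

    targets = agent.model.predict(states)                     # current Q estimates
    next_online = agent.model.predict(next_states)            # online net selects...
    next_target = agent.target_model.predict(next_states)     # ...target net evaluates
    best_actions = np.argmax(next_online, axis=1)

    q_next = next_target[np.arange(batch_size), best_actions]
    q_next[dones] = 0.0                                       # no bootstrapping past terminal states
    targets[np.arange(batch_size), actions] = rewards + agent.gamma * q_next

    agent.model.fit(states, targets, epochs=1, verbose=0)
    if agent.epsilon > agent.epsilon_min:
        agent.epsilon *= agent.epsilon_decay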

In [6]:
if __name__ == "__main__":
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    # agent.load("./save/cartpole-ddqn.h5")
    done = False
    batch_size = 32

    for e in range(EPISODES):
        state = env.reset()
        state = np.reshape(state, [1, state_size])
        for time in range(500):
            env.render()
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            # reward shaping: penalize the transition that ends the episode
            reward = reward if not done else -10
            next_state = np.reshape(next_state, [1, state_size])
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                # hard-copy the online weights into the target network once per episode
                agent.update_target_model()
                print("episode: {}/{}, score: {}, e: {:.2}"
                      .format(e, EPISODES, time, agent.epsilon))
                break
        if len(agent.memory) > batch_size:
            agent.replay(batch_size)
        # if e % 10 == 0:
        #     agent.save("./save/cartpole-ddqn.h5")


[2017-10-05 15:29:00,743] Making new env: CartPole-v1
episode: 0/500, score: 18, e: 1.0
episode: 1/500, score: 15, e: 1.0
episode: 2/500, score: 20, e: 0.99
episode: 3/500, score: 14, e: 0.98
episode: 4/500, score: 46, e: 0.97
episode: 5/500, score: 22, e: 0.96
episode: 6/500, score: 22, e: 0.95
episode: 7/500, score: 10, e: 0.94
episode: 8/500, score: 17, e: 0.93
episode: 9/500, score: 14, e: 0.92
episode: 10/500, score: 32, e: 0.91
episode: 11/500, score: 12, e: 0.9
episode: 12/500, score: 11, e: 0.9
episode: 13/500, score: 14, e: 0.89
episode: 14/500, score: 16, e: 0.88
episode: 15/500, score: 13, e: 0.87
episode: 16/500, score: 12, e: 0.86
episode: 17/500, score: 12, e: 0.85
episode: 18/500, score: 30, e: 0.84
episode: 19/500, score: 18, e: 0.83
episode: 20/500, score: 12, e: 0.83
episode: 21/500, score: 14, e: 0.82
episode: 22/500, score: 28, e: 0.81
episode: 23/500, score: 21, e: 0.8
episode: 24/500, score: 13, e: 0.79
episode: 25/500, score: 12, e: 0.79
episode: 26/500, score: 14, e: 0.78
episode: 27/500, score: 19, e: 0.77
episode: 28/500, score: 15, e: 0.76
episode: 29/500, score: 15, e: 0.75
episode: 30/500, score: 14, e: 0.75
episode: 31/500, score: 8, e: 0.74
episode: 32/500, score: 33, e: 0.73
episode: 33/500, score: 20, e: 0.72
episode: 34/500, score: 15, e: 0.72
episode: 35/500, score: 12, e: 0.71
episode: 36/500, score: 15, e: 0.7
episode: 37/500, score: 15, e: 0.7
episode: 38/500, score: 10, e: 0.69
episode: 39/500, score: 15, e: 0.68
episode: 40/500, score: 11, e: 0.68
episode: 41/500, score: 29, e: 0.67
episode: 42/500, score: 18, e: 0.66
episode: 43/500, score: 21, e: 0.66
episode: 44/500, score: 8, e: 0.65
episode: 45/500, score: 25, e: 0.64
episode: 46/500, score: 23, e: 0.64
episode: 47/500, score: 18, e: 0.63
episode: 48/500, score: 18, e: 0.62
episode: 49/500, score: 8, e: 0.62
episode: 50/500, score: 13, e: 0.61
episode: 51/500, score: 15, e: 0.61
episode: 52/500, score: 16, e: 0.6
episode: 53/500, score: 30, e: 0.59
episode: 54/500, score: 9, e: 0.59
episode: 55/500, score: 13, e: 0.58
episode: 56/500, score: 11, e: 0.58
episode: 57/500, score: 13, e: 0.57
episode: 58/500, score: 17, e: 0.56
episode: 59/500, score: 17, e: 0.56
episode: 60/500, score: 13, e: 0.55
episode: 61/500, score: 10, e: 0.55
episode: 62/500, score: 12, e: 0.54
episode: 63/500, score: 15, e: 0.54
episode: 64/500, score: 8, e: 0.53
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-6-fd16458995f2> in <module>()
     13         state = np.reshape(state, [1, state_size])
     14         for time in range(500):
---> 15             env.render()
     16             action = agent.act(state)
     17             next_state, reward, done, _ = env.step(action)

C:\Users\xuanc\Anaconda3\lib\site-packages\gym\core.py in render(self, mode, close)
    148             elif mode not in modes:
    149                 raise error.UnsupportedMode('Unsupported rendering mode: {}. (Supported modes for {}: {})'.format(mode, self, modes))
--> 150         return self._render(mode=mode, close=close)
    151 
    152     def close(self):

C:\Users\xuanc\Anaconda3\lib\site-packages\gym\core.py in _render(self, mode, close)
    284 
    285     def _render(self, mode='human', close=False):
--> 286         return self.env.render(mode, close)
    287 
    288     def _close(self):

C:\Users\xuanc\Anaconda3\lib\site-packages\gym\core.py in render(self, mode, close)
    148             elif mode not in modes:
    149                 raise error.UnsupportedMode('Unsupported rendering mode: {}. (Supported modes for {}: {})'.format(mode, self, modes))
--> 150         return self._render(mode=mode, close=close)
    151 
    152     def close(self):

C:\Users\xuanc\Anaconda3\lib\site-packages\gym\envs\classic_control\cartpole.py in _render(self, mode, close)
    144         self.poletrans.set_rotation(-x[2])
    145 
--> 146         return self.viewer.render(return_rgb_array = mode=='rgb_array')

C:\Users\xuanc\Anaconda3\lib\site-packages\gym\envs\classic_control\rendering.py in render(self, return_rgb_array)
    102             arr = arr.reshape(buffer.height, buffer.width, 4)
    103             arr = arr[::-1,:,0:3]
--> 104         self.window.flip()
    105         self.onetime_geoms = []
    106         return arr

C:\Users\xuanc\Anaconda3\lib\site-packages\pyglet\window\win32\__init__.py in flip(self)
    309     def flip(self):
    310         self.draw_mouse_cursor()
--> 311         self.context.flip()
    312 
    313     def set_location(self, x, y):

C:\Users\xuanc\Anaconda3\lib\site-packages\pyglet\gl\win32.py in flip(self)
    222 
    223     def flip(self):
--> 224         wgl.wglSwapLayerBuffers(self.canvas.hdc, wgl.WGL_SWAP_MAIN_PLANE)
    225 
    226     def get_vsync(self):

KeyboardInterrupt: 
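
The run above was stopped with a KeyboardInterrupt while env.render() was drawing a frame, so on Windows the pyglet window typically stays open afterwards. An optional cleanup (my addition, assuming env from the training cell is still in scope):

In [ ]:
env.close()  # close the render window left behind by the interrupted run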

In [ ]: