PyTorch DQN Implementation



In [1]:
%matplotlib inline

import torch
import torch.nn as nn
import gym
import random
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.autograd import Variable
from collections import deque, namedtuple

In [2]:
env = gym.envs.make("CartPole-v0")


[2017-03-09 21:31:48,174] Making new env: CartPole-v0

In [3]:
class Net(nn.Module):
    """Two-layer MLP mapping the 4-dimensional CartPole observation to
    Q-values for the 2 discrete actions."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4, 128)
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(128, 2)
        self.init_weights()
    
    def init_weights(self):
        self.fc1.weight.data.uniform_(-0.1, 0.1)
        self.fc2.weight.data.uniform_(-0.1, 0.1)
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.tanh(out)
        out = self.fc2(out)
        return out

In [4]:
def make_epsilon_greedy_policy(network, epsilon, nA):
    """Returns a policy that takes the greedy action w.r.t. the network's
    Q-values with probability 1 - epsilon + epsilon/nA, and a uniformly
    random action otherwise."""
    def policy(state):
        sample = random.random()
        if sample < (1 - epsilon) + (epsilon / nA):
            q_values = network(state.view(1, -1))
            action = q_values.data.max(1)[1][0, 0]
        else:
            action = random.randrange(nA)
        return action
    return policy
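
A quick sanity check (this cell is not part of the original run, and the probe_* names are ad hoc): the policy can be probed on the initial CartPole observation, using the Net and env defined in the cells above.

In [ ]:
# Illustrative only: build a throwaway network and policy and query one action.
probe_net = Net()
probe_policy = make_epsilon_greedy_policy(probe_net, epsilon=0.1, nA=env.action_space.n)
probe_state = Variable(torch.from_numpy(env.reset())).float()
probe_policy(probe_state)  # returns 0 or 1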

In [5]:
class ReplayMemory(object):
    
    def __init__(self, capacity):
        self.memory = deque()
        self.capacity = capacity
        
    def push(self, transition):
        # Evict the oldest transition first (FIFO) once capacity is reached.
        if len(self.memory) >= self.capacity:
            self.memory.popleft()
        self.memory.append(transition)
    
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
     
    def __len__(self):
        return len(self.memory)
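
A brief usage check (this cell is not part of the original notebook; demo_memory is an ad hoc name): once the capacity is reached, the oldest transitions are evicted first.

In [ ]:
# Illustrative only: a tiny memory that holds at most 3 transitions.
demo_memory = ReplayMemory(capacity=3)
for i in range(5):
    demo_memory.push([i, 0, 1.0, i + 1])  # [state, action, reward, next_state]
# len is 3: the two oldest transitions were dropped.
len(demo_memory), demo_memory.sample(2)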

In [6]:
def to_tensor(ndarray, volatile=False):
    # Wrap a numpy array in a (pre-0.4) autograd Variable; volatile=True
    # disables graph construction for anything computed from it.
    return Variable(torch.from_numpy(ndarray), volatile=volatile).float()
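
In the pre-0.4 autograd API used here, the volatile flag propagates from inputs to outputs, which is what lets the training loop below run the bootstrap pass without building a graph. A small illustration (not executed in the original; the name is ad hoc):

In [ ]:
# Illustrative only: outputs computed from a volatile Variable are volatile too.
dummy_input = to_tensor(np.random.randn(1, 4), volatile=True)
Net()(dummy_input).volatile  # True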

In [7]:
def deep_q_learning(num_episodes=10, batch_size=100, 
                    discount_factor=0.95, epsilon=0.1, epsilon_decay=0.95):

    # Q-Network and memory 
    net = Net()
    memory = ReplayMemory(10000)
    
    # Loss and Optimizer
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    
    for i_episode in range(num_episodes):
        
        # Set policy (TODO: decaying epsilon)
        #if (i_episode+1) % 100 == 0:
        #    epsilon *= 0.9
            
        policy = make_epsilon_greedy_policy(
            net, epsilon, env.action_space.n)
        
        # Start an episode
        state = env.reset()
        
        for t in range(10000):
            
            # Sample an action from the epsilon-greedy policy
            action = policy(to_tensor(state))
            next_state, reward, done, _ = env.step(action)
            
            
            # Store the transition in replay memory
            memory.push([state, action, reward, next_state])
            
            
            if len(memory) >= batch_size:
                # Sample mini-batch transitions from memory
                batch = memory.sample(batch_size)
                state_batch = np.vstack([trans[0] for trans in batch])
                action_batch = np.vstack([trans[1] for trans in batch])
                reward_batch = np.vstack([trans[2] for trans in batch])
                next_state_batch = np.vstack([trans[3] for trans in batch])
                
                # Forward + backward + optimize
                net.zero_grad()
                q_values = net(to_tensor(state_batch))
                next_q_values = net(to_tensor(next_state_batch, volatile=True))
                next_q_values.volatile = False
                
                td_target = to_tensor(reward_batch) + discount_factor * (next_q_values).max(1)[0]
                loss = criterion(q_values.gather(1, 
                            to_tensor(action_batch).long().view(-1, 1)), td_target)
                loss.backward()
                optimizer.step()
            
            if done:
                break
        
            state = next_state
            
        if len(memory) >= batch_size and (i_episode+1) % 10 == 0:
            print('episode: %d, time: %d, loss: %.4f' % (i_episode, t, loss.data[0]))
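
The loss above regresses Q(s, a) towards the one-step Bellman target r + discount_factor * max_a' Q(s', a'). One detail worth noting: the target bootstraps from the next state even for terminal transitions, since the done flag is not stored in the replay memory. A minimal sketch of a masked target, under the assumption that transitions also carried a done field (they do not in the code above):

In [ ]:
# Hypothetical variant, not used in this notebook: zero the bootstrap term on
# terminal transitions. Assumes transitions were stored as
# [state, action, reward, next_state, done], one field more than above.
def masked_td_target(net, reward_batch, next_state_batch, done_batch, discount_factor=0.95):
    next_q_values = net(to_tensor(next_state_batch, volatile=True))
    next_q_values.volatile = False
    not_done = to_tensor(1.0 - done_batch.astype(np.float32))
    return to_tensor(reward_batch) + discount_factor * not_done * next_q_values.max(1)[0]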

In [8]:
deep_q_learning(1000)


episode: 9, time: 9, loss: 0.9945
episode: 19, time: 9, loss: 1.8221
episode: 29, time: 9, loss: 4.3124
episode: 39, time: 8, loss: 6.9764
episode: 49, time: 9, loss: 6.8300
episode: 59, time: 8, loss: 5.5186
episode: 69, time: 9, loss: 4.1160
episode: 79, time: 9, loss: 2.4802
episode: 89, time: 13, loss: 0.7890
episode: 99, time: 10, loss: 0.2805
episode: 109, time: 12, loss: 0.1323
episode: 119, time: 13, loss: 0.0519
episode: 129, time: 18, loss: 0.0176
episode: 139, time: 22, loss: 0.0067
episode: 149, time: 17, loss: 0.0114
episode: 159, time: 26, loss: 0.0017
episode: 169, time: 23, loss: 0.0018
episode: 179, time: 21, loss: 0.0023
episode: 189, time: 11, loss: 0.0024
episode: 199, time: 7, loss: 0.0040
episode: 209, time: 8, loss: 0.0030
episode: 219, time: 7, loss: 0.0070
episode: 229, time: 9, loss: 0.0031
episode: 239, time: 9, loss: 0.0029
episode: 249, time: 8, loss: 0.0046
episode: 259, time: 8, loss: 0.0009
episode: 269, time: 10, loss: 0.0020
episode: 279, time: 9, loss: 0.0025
episode: 289, time: 8, loss: 0.0015
episode: 299, time: 10, loss: 0.0009
episode: 309, time: 8, loss: 0.0012
episode: 319, time: 8, loss: 0.0034
episode: 329, time: 8, loss: 0.0008
episode: 339, time: 9, loss: 0.0021
episode: 349, time: 8, loss: 0.0018
episode: 359, time: 9, loss: 0.0017
episode: 369, time: 9, loss: 0.0006
episode: 379, time: 9, loss: 0.0023
episode: 389, time: 10, loss: 0.0017
episode: 399, time: 8, loss: 0.0018
episode: 409, time: 8, loss: 0.0023
episode: 419, time: 9, loss: 0.0020
episode: 429, time: 9, loss: 0.0006
episode: 439, time: 10, loss: 0.0006
episode: 449, time: 10, loss: 0.0025
episode: 459, time: 9, loss: 0.0013
episode: 469, time: 8, loss: 0.0011
episode: 479, time: 8, loss: 0.0005
episode: 489, time: 8, loss: 0.0004
episode: 499, time: 7, loss: 0.0017
episode: 509, time: 7, loss: 0.0004
episode: 519, time: 10, loss: 0.0008
episode: 529, time: 11, loss: 0.0006
episode: 539, time: 9, loss: 0.0010
episode: 549, time: 8, loss: 0.0006
episode: 559, time: 8, loss: 0.0012
episode: 569, time: 9, loss: 0.0011
episode: 579, time: 8, loss: 0.0010
episode: 589, time: 8, loss: 0.0008
episode: 599, time: 10, loss: 0.0010
episode: 609, time: 8, loss: 0.0005
episode: 619, time: 9, loss: 0.0004
episode: 629, time: 8, loss: 0.0007
episode: 639, time: 10, loss: 0.0014
episode: 649, time: 10, loss: 0.0004
episode: 659, time: 9, loss: 0.0008
episode: 669, time: 8, loss: 0.0005
episode: 679, time: 8, loss: 0.0003
episode: 689, time: 9, loss: 0.0009
episode: 699, time: 8, loss: 0.0004
episode: 709, time: 8, loss: 0.0013
episode: 719, time: 8, loss: 0.0006
episode: 729, time: 7, loss: 0.0021
episode: 739, time: 9, loss: 0.0023
episode: 749, time: 9, loss: 0.0039
episode: 759, time: 8, loss: 0.0030
episode: 769, time: 9, loss: 0.0016
episode: 779, time: 7, loss: 0.0041
episode: 789, time: 8, loss: 0.0050
episode: 799, time: 8, loss: 0.0041
episode: 809, time: 11, loss: 0.0053
episode: 819, time: 7, loss: 0.0018
episode: 829, time: 9, loss: 0.0019
episode: 839, time: 11, loss: 0.0017
episode: 849, time: 8, loss: 0.0029
episode: 859, time: 9, loss: 0.0012
episode: 869, time: 9, loss: 0.0036
episode: 879, time: 7, loss: 0.0017
episode: 889, time: 9, loss: 0.0016
episode: 899, time: 10, loss: 0.0023
episode: 909, time: 8, loss: 0.0032
episode: 919, time: 8, loss: 0.0015
episode: 929, time: 9, loss: 0.0021
episode: 939, time: 9, loss: 0.0015
episode: 949, time: 9, loss: 0.0016
episode: 959, time: 9, loss: 0.0013
episode: 969, time: 12, loss: 0.0029
episode: 979, time: 7, loss: 0.0016
episode: 989, time: 7, loss: 0.0012
episode: 999, time: 9, loss: 0.0013

In [ ]: