In [1]:
import numpy as np
import gym
from gym.wrappers import Monitor
from numpy.random import choice
import random
from phi.api import *
import tensorflow as tf
from tfinterface.reinforcement import DQN, ExpandedStateEnv
import os
from scipy.interpolate import interp1d
import numbers



def get_run():
    """Read the run counter persisted in run.txt, increment it, and write it back."""
    try:
        with open("run.txt") as f:
            run = int(f.read().split("\n")[0])
    except (IOError, ValueError):
        run = -1  # no counter yet (or unreadable): the next run is 0

    run += 1
    with open("run.txt", "w") as f:  # "w" already truncates; no seek/truncate dance
        f.write(str(run))

    return run
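
Each call bumps the counter persisted in run.txt, so every run of the notebook gets a fresh log directory; a missing or unreadable file restarts the count at zero:

get_run()   # -> 0 on the very first call (run.txt does not exist yet)
get_run()   # -> 1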

In [7]:
run = get_run()
env_logs = '/tmp/cartpole-{}'.format(run)
expansion = 3

env = gym.make('CartPole-v1')
# env = Monitor(env, env_logs)
env = ExpandedStateEnv(env, expansion)  # stack the last `expansion` observations

n_actions = env.action_space.n                          # 2 discrete actions for CartPole
n_states = env.observation_space.shape[0] * expansion   # 4 * 3 = 12
model_path = os.path.join(os.getcwd(), "Q-network-full.model")
logs_path = "logs/run{}".format(run)


[2017-03-03 09:45:16,462] Making new env: CartPole-v1
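
ExpandedStateEnv comes from tfinterface. Judging by n_states = observation_space.shape[0] * expansion (4 * 3 = 12 here, matching the (?, 12) input tensor below), it stacks the last `expansion` observations into a single state so the network can see short-term dynamics. A minimal sketch of that idea under that assumption (StackedStateEnv is a hypothetical stand-in, not the real class):

from collections import deque
import numpy as np

class StackedStateEnv(object):
    """Concatenate the last `expansion` raw observations into one state."""
    def __init__(self, env, expansion):
        self.env = env
        self.frames = deque(maxlen=expansion)

    def reset(self):
        s = self.env.reset()
        for _ in range(self.frames.maxlen):
            self.frames.append(s)  # pad the history with the initial observation
        return np.concatenate(list(self.frames))

    def step(self, a):
        s, r, done, info = self.env.step(a)
        self.frames.append(s)
        return np.concatenate(list(self.frames)), r, done, info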

In [4]:
model = DQN(
    n_actions, n_states,
    model_path=model_path,
    logs_path=logs_path,
    flush_secs=3.0,        # seconds between summary flushes to disk
    y=0.9999,              # presumably the discount factor gamma
    buffer_length=500000,  # replay buffer capacity, in transitions
    restore=True           # resume from an existing checkpoint if present
)

print("run: {},\n s: {},\n a: {},\n r: {},\n Qs: {},\n update: {}".format(
    run, model.inputs.s, model.inputs.a, model.inputs.r, model.network.Qs, model.update
))


run: 42,
 s: Tensor("inputs/s:0", shape=(?, 12), dtype=float32, device=/device:CPU:0),
 a: Tensor("inputs/a:0", shape=(?,), dtype=int32, device=/device:CPU:0),
 r: Tensor("inputs/r:0", shape=(?,), dtype=float32, device=/device:CPU:0),
 Qs: Tensor("network/linear_layer/MatMul:0", shape=(?, 2), dtype=float32, device=/device:CPU:0),
 update: name: "network/Adam"
op: "NoOp"
input: "^network/Adam/update_network/relu_layer/kernel/ApplyAdam"
input: "^network/Adam/update_network/linear_layer/kernel/ApplyAdam"
input: "^network/Adam/Assign"
input: "^network/Adam/Assign_1"
device: "/device:CPU:0"


In [5]:
model.fit(
    env, 
    episodes=50000,
    max_episode_length = 60000,
    k = 1000.
)


[MAX] Episode: 0, Length: 68, e: 0.937207122774, learning_rate: 0.937207122774, buffer_len: 68
[NOR] Episode: 0, Length: 68, e: 0.937207122774, learning_rate: 0.937207122774, buffer_len: 68
[NOR] Episode: 10, Length: 27, e: 0.782472613459, learning_rate: 0.782472613459, buffer_len: 279
[NOR] Episode: 20, Length: 14, e: 0.641436818473, learning_rate: 0.641436818473, buffer_len: 560
[NOR] Episode: 30, Length: 45, e: 0.550357732526, learning_rate: 0.550357732526, buffer_len: 818
[NOR] Episode: 40, Length: 41, e: 0.471920717319, learning_rate: 0.471920717319, buffer_len: 1120
[NOR] Episode: 50, Length: 14, e: 0.4095004095, learning_rate: 0.4095004095, buffer_len: 1443
[MAX] Episode: 60, Length: 93, e: 0.354484225452, learning_rate: 0.354484225452, buffer_len: 1822
[NOR] Episode: 60, Length: 93, e: 0.354484225452, learning_rate: 0.354484225452, buffer_len: 1822
[MAX] Episode: 63, Length: 126, e: 0.327546675401, learning_rate: 0.327546675401, buffer_len: 2054
[NOR] Episode: 70, Length: 51, e: 0.291120815138, learning_rate: 0.291120815138, buffer_len: 2436
[NOR] Episode: 80, Length: 28, e: 0.245218244237, learning_rate: 0.245218244237, buffer_len: 3079
[NOR] Episode: 90, Length: 82, e: 0.217296827466, learning_rate: 0.217296827466, buffer_len: 3603
[MAX] Episode: 98, Length: 221, e: 0.193760899051, learning_rate: 0.193760899051, buffer_len: 4162
[NOR] Episode: 100, Length: 33, e: 0.191131498471, learning_rate: 0.191131498471, buffer_len: 4233
[NOR] Episode: 110, Length: 161, e: 0.156936597615, learning_rate: 0.156936597615, buffer_len: 5373
[NOR] Episode: 120, Length: 70, e: 0.128766417718, learning_rate: 0.128766417718, buffer_len: 6767
[MAX] Episode: 121, Length: 413, e: 0.122264335493, learning_rate: 0.122264335493, buffer_len: 7180
[NOR] Episode: 130, Length: 56, e: 0.105351875263, learning_rate: 0.105351875263, buffer_len: 8493
[MAX] Episode: 132, Length: 536, e: 0.0996015936255, learning_rate: 0.0996015936255, buffer_len: 9041
[NOR] Episode: 140, Length: 215, e: 0.0845379998309, learning_rate: 0.0845379998309, buffer_len: 10830
[NOR] Episode: 150, Length: 212, e: 0.0778331257783, learning_rate: 0.0778331257783, buffer_len: 11849
[NOR] Episode: 160, Length: 303, e: 0.0722647781471, learning_rate: 0.0722647781471, buffer_len: 12839
[MAX] Episode: 161, Length: 644, e: 0.0690512360171, learning_rate: 0.0690512360171, buffer_len: 13483
[MAX] Episode: 165, Length: 742, e: 0.0614665929068, learning_rate: 0.0614665929068, buffer_len: 15270
[NOR] Episode: 170, Length: 725, e: 0.0522247754335, learning_rate: 0.0522247754335, buffer_len: 18149
[MAX] Episode: 174, Length: 861, e: 0.046539768232, learning_rate: 0.046539768232, buffer_len: 20488
[NOR] Episode: 180, Length: 184, e: 0.0407564395174, learning_rate: 0.0407564395174, buffer_len: 23537
[MAX] Episode: 181, Length: 1618, e: 0.0382350692055, learning_rate: 0.0382350692055, buffer_len: 25155
[MAX] Episode: 184, Length: 2384, e: 0.0349247371914, learning_rate: 0.0349247371914, buffer_len: 27634
[MAX] Episode: 187, Length: 6012, e: 0.0280119891313, learning_rate: 0.0280119891313, buffer_len: 34700
[NOR] Episode: 190, Length: 356, e: 0.0264683306424, learning_rate: 0.0264683306424, buffer_len: 36782
[MAX] Episode: 191, Length: 8733, e: 0.0214989035559, learning_rate: 0.0214989035559, buffer_len: 45515
[MAX] Episode: 192, Length: 60001, e: 0.01, learning_rate: 0.00938834905882, buffer_len: 105516
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-5-e0b7bc24d356> in <module>()
      3     episodes=50000,
      4     max_episode_length = 60000,
----> 5     k = 1000.
      6 )

/home/cristian/data/cristian/tfinterface/tfinterface/reinforcement/dnq.py in fit(self, env, k, learning_rate, print_step, episodes, max_episode_length, discount, batch_size)
    576                 MaxQs1 = self.sess.run(self.target_network.max_Qs, feed_dict={self.inputs.s: S1})
    577 
--> 578                 feed_dict = self.fit_feed(S, A, R, MaxQs1, Done, learning_rate)
    579                 _, summaries = self.sess.run([self.update, self.summaries], feed_dict=feed_dict)
    580                 self.writer.add_summary(summaries)

/home/cristian/data/cristian/tfinterface/tfinterface/reinforcement/dnq.py in fit_feed(self, S, A, R, Max_Qs1, Done, learning_rate)
    531 
    532     def fit_feed(self, S, A, R, Max_Qs1, Done, learning_rate):
--> 533         return {self.inputs.s: S, self.inputs.a: A, self.inputs.r: R, self.inputs.max_Qs1: Max_Qs1, self.inputs.done: Done, self.inputs.learning_rate: learning_rate}
    534 
    535     @_coconut_tco

KeyboardInterrupt: 
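
Notes on the run: [MAX] lines mark a new best episode length, [NOR] lines are the periodic prints; training was interrupted by hand after episode 192 hit the 60000-step cap. The printed e values match a decay of the form e = k / (k + t), where t is the total step count (one less than buffer_len), clamped at a floor of 0.01 while learning_rate keeps following the same curve. A quick check of that reading, assuming this schedule (the exact formula lives in tfinterface):

k = 1000.0
for t in (67, 278):     # step counts implied by the first two [NOR] lines
    print(k / (k + t))  # -> 0.937207122774..., 0.782472613459...

The traceback also shows the training step in passing: a separate target_network supplies max_Qs1, so the regression target is presumably the standard r + y * max_a' Q_target(s1), zeroed where done.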

In [8]:
import time

model_run = DQN(
    n_actions, n_states,
    model_path=model_path + ".max",  # presumably the checkpoint saved on [MAX] episodes
    flush_secs=3.0,
    restore=True
)



s = env.reset()
done = False

while not done:
    a = model_run.choose_action(s, e=0.2)
    s, r, done, info = env.step(a)
    env.render()
    time.sleep(0.01)


---------------------------------------------------------------------------
ArgumentError                             Traceback (most recent call last)
<ipython-input-8-38d4b16ce4eb> in <module>()
     16     a = model_run.choose_action(s, e=0.2)
     17     s, r, done, info = env.step(a)
---> 18     env.render()
     19     time.sleep(0.01)

/usr/local/lib/python2.7/dist-packages/gym/core.pyc in render(self, mode, close)
    172             raise error.UnsupportedMode('Unsupported rendering mode: {}. (Supported modes for {}: {})'.format(mode, self, modes))
    173 
--> 174         return self._render(mode=mode, close=close)
    175 
    176     def close(self):

/usr/local/lib/python2.7/dist-packages/gym/envs/classic_control/cartpole.pyc in _render(self, mode, close)
    140         self.poletrans.set_rotation(-x[2])
    141 
--> 142         return self.viewer.render(return_rgb_array = mode=='rgb_array')

/usr/local/lib/python2.7/dist-packages/gym/envs/classic_control/rendering.pyc in render(self, return_rgb_array)
     82         self.window.clear()
     83         self.window.switch_to()
---> 84         self.window.dispatch_events()
     85         self.transform.enable()
     86         for geom in self.geoms:

/usr/local/lib/python2.7/dist-packages/pyglet/window/xlib/__init__.pyc in dispatch_events(self)
    851         # Check for the events specific to this window
    852         while xlib.XCheckWindowEvent(_x_display, _window,
--> 853                                      0x1ffffff, byref(e)):
    854             # Key events are filtered by the xlib window event
    855             # handler so they get a shot at the prefiltered event.

ArgumentError: argument 2: <type 'exceptions.TypeError'>: wrong type
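
The crash is in pyglet's Xlib event dispatch, not in the agent: a known incompatibility between gym's classic-control renderer and newer pyglet builds on Python 2. Pinning an older pyglet was the usual workaround at the time (the exact version is an assumption; match whatever gym release is installed):

pip install pyglet==1.2.4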

In [ ]: