In [8]:
import numpy as np
import gym
from numpy.random import choice
import random
from phi.api import *
import tensorflow as tf
from tfinterface.reinforcement import OnExperienceModel, OnExperienceTrainer
import time
import os


env = gym.make("FrozenLake-v0")

def select_columns(tensor, indexes):
    idx = tf.stack((tf.range(tf.shape(indexes)[0]), indexes), 1)
    return tf.gather_nd(tensor, idx)

t0 = time.time()


[2017-02-19 12:46:44,020] Making new env: FrozenLake-v0
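
A quick sanity check of select_columns (a minimal sketch, not part of the original run; it assumes the default graph and a throwaway session): for each row i it gathers tensor[i, indexes[i]], which is exactly how Q(s, a) is read out of the per-state Q rows in the model below.

# hypothetical standalone check of select_columns
q = tf.constant([[1., 2., 3.],
                 [4., 5., 6.]])
acts = tf.constant([2, 0])
with tf.Session() as check_sess:
    print(check_sess.run(select_columns(q, acts)))  # expected: [ 3.  4.]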

In [13]:
class Model(OnExperienceModel):
    
    def define_model(self, n_actions, n_states, y=0.98, b=0.5, k=2000, e=0.01):
        # y: discount factor, b/k: learning-rate decay parameters, e: exploration probability
        
        self.b = b
        self.k = k
        self.e = e
        
        with self.graph.as_default(), tf.device("cpu:0"):
            
            self.s = tf.placeholder(tf.int32, [None], name='s')
            self.a = tf.placeholder(tf.int32, [None], name='a')
            self.r = tf.placeholder(tf.float32, [None], name='r')
            
            self.max_Qs1 = tf.placeholder(tf.float32, [None], name='maxQs1')
            self.lr = tf.placeholder(tf.float32, [], name='lr')

            # weights for a single linear layer Q(s, ·): small random init, no bias
            ops = dict(
                trainable=True,
                kernel_initializer=tf.random_uniform_initializer(minval=0.0, maxval=0.01),
                use_bias=False,
                bias_initializer=None
            )


            # one-hot encode the discrete state and map it through a single
            # linear layer, so Qs[i, a] is the estimated Q-value of action a in state s[i]
            net = tf.one_hot(self.s, n_states)
            self.Qs = tf.layers.dense(net, n_actions, name='linear_layer', **ops)

            # Q(s, a) for the actions actually taken
            self.Qsa = select_columns(self.Qs, self.a)

            self.max_Qs = tf.reduce_max(self.Qs, 1)

            # TD error: r + y * max_a' Q(s', a') - Q(s, a)
            error = self.r + y * self.max_Qs1 - self.Qsa
            self.loss = Pipe(error, tf.nn.l2_loss, tf.reduce_sum)
            self.update = tf.train.GradientDescentOptimizer(self.lr).minimize(self.loss)


    def experience_feed(self, s, a, r, s1, max_Qs1, lr):
        return {self.s: [s], self.a: [a], self.r: [r], self.max_Qs1: [max_Qs1], self.lr: lr}
                
    def choose_action(self, state, e=None):
        # epsilon-greedy: with probability e take a random action, otherwise act greedily w.r.t. Q
        actions = self.sess.run(self.Qs, feed_dict={self.s: [state]})[0]
        n = len(actions)
        
        if e is None:
            e = self.e
        
        if random.random() < e:
            return random.randint(0, n-1)
        else:
            return np.argmax(actions)
    
    @property
    def default_trainer(self):
        return Trainer
    
    def reset(self):
        pass

    def learning_rate(self, t):
        # search-then-converge decay: starts at b and falls to b/2 once t reaches k
        return self.b * self.k / (self.k + t)
    

class Trainer(OnExperienceTrainer):
    
    def on_experience_start(self):
        self.lr = self.model.learning_rate(self.global_step)
        
    def get_experience(self, s, a, r, s1, done, info):
        # bootstrap value max_a' Q(s', a') used in the TD target
        max_Qs1 = self.model.sess.run(self.model.max_Qs, feed_dict={self.model.s: [s1]})[0]
        return s, a, r, s1, max_Qs1, self.lr
    
    def train_on_experience(self, *experience):
        feed_dict = self.model.experience_feed(*experience)
        self.model.sess.run(self.model.update, feed_dict=feed_dict)
        
    def after_episode(self, *args):
        # every 500 episodes: report the reward accumulated over them, reset the counter and checkpoint
        if self.episode % 500 == 0 and self.episode > 0:
            print(self.fit_reward, "of", 500, ", lr:", self.lr)
            self.fit_reward = 0
            self.model.save()
        
        
    
    
n_actions = env.action_space.n
n_states = env.observation_space.n

model_path = os.getcwd() + "/shallow.model"

model = Model(
    n_actions, n_states, 
    y=0.95, b=0.5, k=20000., e=0.05,
    model_path = model_path,
    logs_path = "/logs",
    restore = True
)
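
For reference, the learning-rate schedule passed to the model above is lr(t) = b * k / (k + t): it starts at b and has dropped to b/2 once the global step t reaches k. A minimal check with the values used here (b=0.5, k=20000; not part of the original run):

# hypothetical check of the decay schedule, matching the hyperparameters above
b_, k_ = 0.5, 20000.0
for t in [0, 5000, 20000, 100000]:
    print("t={}, lr={:.4f}".format(t, b_ * k_ / (k_ + t)))
# t=0, lr=0.5000
# t=5000, lr=0.4000
# t=20000, lr=0.2500
# t=100000, lr=0.0833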

In [14]:
model.fit(env)


(206.0, 'of', 500, ', lr:', 0.29171528588098017)
(183.0, 'of', 500, ', lr:', 0.20650916900710392)
(207.0, 'of', 500, ', lr:', 0.15776603297310088)
(221.0, 'of', 500, ', lr:', 0.12708418056120374)
(242.0, 'of', 500, ', lr:', 0.1060096892856007)
(251.0, 'of', 500, ', lr:', 0.09105892422986915)
(224.0, 'of', 500, ', lr:', 0.07998400319936012)
(269.0, 'of', 500, ', lr:', 0.07036802476954472)
(265.0, 'of', 500, ', lr:', 0.06298339757639886)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-14-e916f46b95f9> in <module>()
----> 1 model.fit(env)

/home/cristian/data/cristian/tfinterface/tfinterface/reinforcement/reinforcement_base.py in fit(self, *args, **kwargs)
    468 
    469     def fit(self, *args, **kwargs):
--> 470         self.trainer.fit(*args, **kwargs)
    471 
    472     @abstractmethod

/home/cristian/data/cristian/tfinterface/tfinterface/reinforcement/reinforcement_base.py in fit(self, env, episodes, max_episodes)
    531                 experience = self.get_experience(s, a, r, s1, done, info)
    532                 self.experience_buffer.append(experience)
--> 533                 self.train_on_experience(*experience)
    534                 self.after_experience(*experience)
    535 

<ipython-input-13-d3f257c92ec1> in train_on_experience(self, *experience)
     73     def train_on_experience(self, *experience):
     74         feed_dict = self.model.experience_feed(*experience)
---> 75         self.model.sess.run(self.model.update, feed_dict=feed_dict)
     76 
     77     def after_episode(self, *args):

/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    765     try:
    766       result = self._run(None, fetches, feed_dict, options_ptr,
--> 767                          run_metadata_ptr)
    768       if run_metadata:
    769         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
    936                 ' to a larger type (e.g. int64).')
    937 
--> 938           np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)
    939 
    940           if not subfeed_t.get_shape().is_compatible_with(np_val.shape):

/usr/local/lib/python2.7/dist-packages/numpy/core/numeric.pyc in asarray(a, dtype, order)
    529 
    530     """
--> 531     return array(a, dtype, copy=False, order=order)
    532 
    533 

KeyboardInterrupt: 

In [4]:
s = env.reset()

for i in range(1000):
    a = model.choose_action(s, e=0)
    s, r, done, info = env.step(a)
    env.render()
    print("")

    if done:
        print(r)
        break


SFFF
FHFH
FFFH
HFFG
  (Left)

[... many more renders of the same 4x4 board omitted; FrozenLake-v0 is slippery, so even the greedy (e=0) policy wanders for many steps before reaching the goal ...]

SFFF
FHFH
FFFH
HFFG
  (Down)

1.0
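
Rendering a single rollout only shows one stochastic trajectory. A more informative check is the greedy success rate over many episodes; below is a minimal sketch, assuming model and env are still in scope (not part of the original notebook). FrozenLake returns reward 1.0 only when the goal is reached, so the mean final reward is the success rate.

# hypothetical evaluation of the greedy policy (e=0) over many episodes
episodes = 1000
wins = 0.0
for _ in range(episodes):
    s = env.reset()
    done = False
    while not done:
        a = model.choose_action(s, e=0)
        s, r, done, info = env.step(a)
    wins += r  # 1.0 if the goal was reached, else 0.0
print("success rate: {}".format(wins / episodes))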