In [1]:
import numpy as np
import gym
from numpy.random import choice
import random
from tensorbuilder.api import *
import tensorflow as tf
rnn_cell = tf.nn.rnn_cell

env = gym.make("CartPole-v1")


[2017-02-04 03:55:50,345] Making new env: CartPole-v1

In [2]:
tf.train.Saver?

In [2]:
def select_columns(tensor, indexes):
    # pick tensor[i, indexes[i]] for every row i
    idx = tf.stack((tf.range(tf.shape(indexes)[0]), indexes), 1)
    return tf.gather_nd(tensor, idx)

def discount(rewards, y):
    # turn per-step rewards into discounted returns: G_t = r_t + y * G_{t+1}
    r_accum = 0.0
    gains = []
    for r in reversed(list(rewards)):
        r_accum = r + y * r_accum
        gains.insert(0, r_accum)

    return gains
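
As a quick sanity check (an illustrative throwaway cell, worked out by hand): discount turns per-step rewards into discounted returns G_t = r_t + y * G_{t+1}, and select_columns picks one entry per row of a 2-D tensor. The names check_sess, probs and picked are just local names for the example.

In [ ]:
# discount([1, 1, 1], 0.5) -> [1 + 0.5 * 1.5, 1 + 0.5 * 1, 1] = [1.75, 1.5, 1.0]
print discount([1.0, 1.0, 1.0], 0.5)

# select_columns picks tensor[i, indexes[i]] for every row i
with tf.Session() as check_sess:
    probs = tf.constant([[0.1, 0.9], [0.7, 0.3]])
    picked = select_columns(probs, tf.constant([1, 0]))
    print check_sess.run(picked)  # -> [ 0.9  0.7]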

In [72]:
class Model(object):
    
    def __init__(self, **kwargs):
        restore = kwargs.get('restore', False)
        model_path = kwargs.get('model_path', "model")
        logs_path = kwargs.get('logs_path', "/logs/")
        
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        
        self.make_model(**kwargs)
        
        with self.graph.as_default():
            self.writer = tf.summary.FileWriter(logs_path, graph=self.graph, flush_secs=10.0)
            self.saver = tf.train.Saver()

            if restore:
                self.saver.restore(self.sess, model_path)
            else:
                self.sess.run(tf.global_variables_initializer())
        
    def make_model(self, **kwargs):
        y = kwargs['y']
        n_states = kwargs['n_states']
        n_actions = kwargs['n_actions']
        clip_value = kwargs.get('clip_value', 1.0)
        internal_state_size = kwargs.get('internal_state_size', 16)
        
        with self.graph.as_default():
            with tf.device("cpu:0"):
                s = tf.placeholder(tf.float32, [None, n_states], name='s')
                a = tf.placeholder(tf.int32, [None], name='a')
                r = tf.placeholder(tf.float32, [None], name='r')
                lr = tf.placeholder(tf.float32, [], name='lr')
                
                batch_size = 1
                

#                 s_rec = tf.placeholder(tf.float32, [None, n_states], name='s_rec')
                
                trainer = tf.train.GradientDescentOptimizer(lr)
                
                ops = dict(trainable=True, weights_initializer=tf.random_uniform_initializer(minval=0.0, maxval=0.01), biases_initializer=None)
                
                with tf.variable_scope("actor"):
                    lstm = rnn_cell.LSTMCell(internal_state_size, state_is_tuple=True, cell_clip=clip_value)
                    current_state = lstm.zero_state(batch_size, dtype=tf.float32)
                    lstm_state_placeholders = T.rnn_placeholders_from_state(current_state)
                    
                    zero_state = current_state = self.sess.run(current_state)
                    
                    with tf.name_scope("dynamic"):
                        Ps = Pipe(
                            s, 
                            T.relu_layer(internal_state_size, scope="relu_layer", **ops)
                            .expand_dims(1)
                            .dynamic_rnn(lstm, dtype=tf.float32, time_major=True)[0]
                            .squeeze(axis=1)
                            .softmax_layer(n_actions, scope='softmax_layer', **ops)
                        )
                    
                    with tf.name_scope("manual"):
                        Ps_step, state_step = Pipe(
                            s,
                            T.relu_layer(internal_state_size, scope="relu_layer", reuse=True, **ops)
                            .Then(lstm, lstm_state_placeholders),
                            [
                                T[0].softmax_layer(n_actions, scope='softmax_layer', reuse=True, **ops)
                            ,
                                T[1]
                            ]
                            
                        )
                    
                    
                # trainable variables of the actor network only
                Psws = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "actor")

                # probability the policy assigned to the action actually taken at each step
                Psa = select_columns(Ps, a)

                # scalar baseline subtracted from the returns to reduce gradient variance
                base = tf.Variable(0.0)

                error = r - base

                # REINFORCE loss: -sum_t log pi(a_t | s_t) * (G_t - baseline)
                loss = -tf.reduce_sum(tf.log(Psa) * error)
                gradients = trainer.compute_gradients(loss, var_list=Psws)
                gradients = [ (tf.clip_by_value(g, -clip_value, clip_value), w) for g, w in gradients if g is not None ]
                update = trainer.apply_gradients(gradients)

                # fit the baseline to the returns with an L2 loss
                loss_base = Pipe(error, tf.nn.l2_loss, tf.reduce_sum)
                gradients = trainer.compute_gradients(loss_base, var_list=[base])
                gradients = [ (tf.clip_by_value(g, -clip_value, clip_value), w) for g, w in gradients if g is not None ]
                update_base = trainer.apply_gradients(gradients)
                        
                
                
        self.s = s; self.a = a; self.r = r;
        self.Ps = Ps; self.Psa = Psa; self.update = update; self.update_base = update_base
        self.lr = lr
        self.lstm = lstm; self.lstm_state_placeholders = lstm_state_placeholders;
        self.Ps_step = Ps_step; self.state_step = state_step;
        self.current_state = current_state; self.zero_state = zero_state
                
    def next_action(self, state):
        
        feed_dict = T.rnn_state_feed_dict(self.lstm_state_placeholders, self.current_state)
        feed_dict[self.s] = [state]
        
        actions, self.current_state = self.sess.run([self.Ps_step, self.state_step], feed_dict=feed_dict)
        actions = actions[0]
        
#         print self.current_state[0][0, 0:3]
        
        n = len(actions)
        return choice(n, p=actions)

    def train(self, s, a, r, s1, lr):
        #train
        self.train_offline([s], [a], [r], [s1], lr)
        
    def train_offline(self, S, A, R, S1, lr):
        #train
        self.sess.run(self.update, feed_dict={
            self.s: S, self.a: A, self.r: R, 
            self.lr: lr
        })
        
        self.sess.run(self.update_base, feed_dict={
            self.s: S, self.a: A, self.r: R, 
            self.lr: lr
        })
        
    def reset(self):
        self.current_state = self.zero_state

    def save(self, model_path):
        self.saver.save(self.sess, model_path)

    def restore(self, model_path):
        self.sess.close()
        self.sess = tf.Session(graph=self.graph)
        self.saver.restore(self.sess, model_path)

    @staticmethod
    def learning_rate(t, b, k):
        return b * k / (k + t)
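
learning_rate is a simple hyperbolic decay, lr(t) = b * k / (k + t): training starts at lr = b and the rate halves after k episodes. A quick throwaway check with the values used in the next cell (b = 0.01, k = 2000.0):

In [ ]:
# expected: 0.01 at t=0, ~0.00976 at t=50, 0.005 at t=2000
for t in (0, 50, 2000, 20000):
    print t, Model.learning_rate(t, 0.01, 2000.0)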

In [95]:
y = 0.98
b = 0.01
k = 2000.0
model_name = "recurrent-policy-gradient-cartpole.model"
model_path = "/models/" + model_name

model = Model(
    y=y, 
    restore=False,
    model_name = model_name,
    model_path = model_path,
    n_actions = env.action_space.n,
    n_states = env.observation_space.shape[0],
    clip_value = 1.5,
    internal_state_size = 16
)


r_total = 0.0
max_r = 0.0

for t in range(200000):
    lr = model.learning_rate(t, b, k)
    
    s = env.reset()
    model.reset()
    
    r_ep = 0.0
    
    S = []; A = []; R = []; S1 = []
    
    
    for j in range(10000):
        #next action
        a = model.next_action(s)

        #take step
        s1, r, done, info = env.step(a)
        
        r_total += r
        r_ep += r
        
        #append values
        S.append(s); A.append(a); R.append(r); S1.append(s1)
        
        #update state
        s = s1
        
        if done: break
        
    R = discount(R, y)
        
    #train
    model.train_offline(S, A, R, S1, lr)

    save_period = 50
    if t % save_period == 0:
        print r_total / save_period, ", lr:", lr
        r_total = 0
        model.save(model_path)
#         model.reset()


0.32 , lr: 0.01
23.14 , lr: 0.00975609756098
19.64 , lr: 0.00952380952381
21.86 , lr: 0.0093023255814
26.36 , lr: 0.00909090909091
20.3 , lr: 0.00888888888889
20.4 , lr: 0.00869565217391
23.94 , lr: 0.00851063829787
23.0 , lr: 0.00833333333333
23.32 , lr: 0.00816326530612
18.04 , lr: 0.008
19.56 , lr: 0.0078431372549
19.1 , lr: 0.00769230769231
17.54 , lr: 0.00754716981132
14.98 , lr: 0.00740740740741
11.92 , lr: 0.00727272727273
9.28 , lr: 0.00714285714286
9.38 , lr: 0.00701754385965
9.22 , lr: 0.00689655172414
9.24 , lr: 0.00677966101695
9.48 , lr: 0.00666666666667
9.44 , lr: 0.00655737704918
9.48 , lr: 0.00645161290323
9.42 , lr: 0.00634920634921
9.44 , lr: 0.00625
9.08 , lr: 0.00615384615385
9.24 , lr: 0.00606060606061
9.42 , lr: 0.00597014925373
9.5 , lr: 0.00588235294118
9.2 , lr: 0.00579710144928
9.4 , lr: 0.00571428571429
9.46 , lr: 0.0056338028169
9.42 , lr: 0.00555555555556
9.42 , lr: 0.00547945205479
9.48 , lr: 0.00540540540541
9.4 , lr: 0.00533333333333
9.42 , lr: 0.00526315789474
9.32 , lr: 0.00519480519481
9.36 , lr: 0.00512820512821
9.22 , lr: 0.00506329113924
9.16 , lr: 0.005
9.32 , lr: 0.00493827160494
9.46 , lr: 0.00487804878049
9.34 , lr: 0.00481927710843
9.12 , lr: 0.0047619047619
9.52 , lr: 0.00470588235294
9.42 , lr: 0.0046511627907

KeyboardInterruptTraceback (most recent call last)
<ipython-input-95-384b4d506e74> in <module>()
     58         print r_total / save_period, ", lr:", lr
     59         r_total = 0
---> 60         model.save(model_path)
     61 #         model.reset()
     62 

<ipython-input-72-acfca640ce16> in save(self, model_path)
    136 
    137     def save(self, model_path):
--> 138         self.saver.save(self.sess, model_path)
    139 
    140     def restore(self, model_path):

/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.pyc in save(self, sess, save_path, global_step, latest_filename, meta_graph_suffix, write_meta_graph, write_state)
   1333           checkpoint_file, meta_graph_suffix=meta_graph_suffix)
   1334       with sess.graph.as_default():
-> 1335         self.export_meta_graph(meta_graph_filename)
   1336 
   1337     if self._is_empty:

KeyboardInterrupt: 

In [ ]:
rnn_cell.LSTMCell?

In [ ]:
s = env.reset()
model.reset()

for i in range(100):
    a = model.next_action(s)
    s1, r, done, info = env.step(a)
    s = s1
    env.render()

    if done:
        print(r)
        break
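
To pick up a saved run later, the same constructor can restore the checkpoint written by model.save above (a sketch; it assumes the checkpoint file at model_path exists):

In [ ]:
model = Model(
    y=y,
    restore=True,
    model_path=model_path,
    n_actions=env.action_space.n,
    n_states=env.observation_space.shape[0],
    clip_value=1.5,
    internal_state_size=16
)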