Copyright 2020 Google LLC.
Licensed under the Apache License, Version 2.0 (the "License");
In [0]:
#@title Default title text
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
In [0]:
import gym
import tensorflow.compat.v2 as tf
import numpy as np
import pickle
import imp
import getpass
import os
import random
import string
from action_gap_rl import replay
from action_gap_rl import value as value_lib
from action_gap_rl.policies import layers_lib
replay = imp.reload(replay)
value_lib = imp.reload(value_lib)
layers_lib = imp.reload(layers_lib)
In [0]:
tf.enable_v2_behavior()
In [0]:
class AttrDict(dict):
  """Dict whose entries can also be accessed as attributes."""

  def __init__(self, *args, **kwargs):
    super(AttrDict, self).__init__(*args, **kwargs)
    self.__dict__ = self

def to_dict(d):
  """Recursively converts an AttrDict back into a plain dict."""
  if isinstance(d, AttrDict):
    return {k: to_dict(v) for k, v in d.items()}
  return d

def filter_bool(lst, mask):
  return [lst[i] for i in range(len(lst)) if mask[i]]

def rand_str(N):
  return ''.join(random.choice(string.ascii_uppercase + string.digits)
                 for _ in range(N))
In [0]:
class BehaviorPolicy(tf.keras.Model):
  """A randomly initialized MLP policy over a discrete action set.

  Maps an observation to the argmax action of a small feed-forward network,
  optionally preceded by a soft-one-hot observation embedding.
  """

  def __init__(self, config, name=None):
    super(BehaviorPolicy, self).__init__(
        name=name or self.__class__.__name__)
    self._config = config
    self.num_actions = config.num_actions
    if 'initializer' in config:
      init = config.initializer
    else:
      init = tf.keras.initializers.glorot_uniform()
    hidden_widths = config.hidden_widths
    if config.embed:
      transformation_layers = [layers_lib.soft_hot_layer(**config.embed)]
    else:
      transformation_layers = []
    self._body = tf.keras.Sequential(
        transformation_layers
        + [tf.keras.layers.Dense(w, activation='relu', kernel_initializer=init)
           for w in hidden_widths]
        + [tf.keras.layers.Dense(self.num_actions, activation=None,
                                 kernel_initializer=init)]
    )

  def call(self, states):
    # Greedy (argmax) action for a single observation.
    return tf.argmax(
        self._body(tf.expand_dims(states, axis=0)), axis=-1).numpy()[0]
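A minimal usage sketch (not part of the original notebook): `demo_policy`, its config values, and the dummy observation below are illustrative only, assuming the BehaviorPolicy class defined above.
In [0]:
# Sketch: instantiate a randomly initialized policy and query it on a
# dummy 3-dimensional observation (Pendulum-v0 observations have 3 entries).
demo_policy = BehaviorPolicy(
    AttrDict(num_actions=2, embed=None, hidden_widths=[32]),
    name='policy_demo')
dummy_obs = np.zeros(3, dtype=np.float32)
print(demo_policy(dummy_obs))  # prints the argmax action, 0 or 1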
In [0]:
class ActionAdaptor(object):
  """Exposes a continuous-torque env behind a small discrete action set."""

  def __init__(self, env, actions={0: -2., 1: 2.}, t_res=1):
    self.env = env
    self.actions = actions  # discrete action index -> continuous torque
    self.t_res = t_res  # number of times each chosen action is repeated
    assert t_res >= 1

  def step(self, a):
    # Repeat the mapped continuous action t_res times; return the last transition.
    for _ in range(self.t_res):
      result = self.env.step([self.actions[a]])
    return result

  def reset(self):
    return self.env.reset()

  @property
  def unwrapped(self):
    return self.env.unwrapped

  @property
  def action_space(self):
    return gym.spaces.Discrete(2)
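A quick, hypothetical usage sketch of ActionAdaptor, mirroring the env construction used later in this notebook: action 0 applies torque -2.0 and action 1 applies torque +2.0 to Pendulum-v0.
In [0]:
# Sketch: drive the wrapped pendulum with discrete actions.
demo_env = ActionAdaptor(gym.make('Pendulum-v0'))
obs = demo_env.reset()
obs, reward, done, _ = demo_env.step(0)  # action 0 -> torque -2.0
print(obs, reward, done)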
In [0]:
import copy


def policy_returns_with_horizon(env, state, policy, horizon, irresolution=1,
                                forced_actions=()):
  """Sums rewards from rolling out `policy` for `horizon` steps on a copy of `env`."""
  env = copy.deepcopy(env)  # simulate without disturbing the caller's env state
  R = 0.
  for t in range(horizon):
    if t < len(forced_actions):
      a = forced_actions[t]
    else:
      a = policy(state)
    for _ in range(irresolution):
      state, reward, term, _ = env.step(a)
      R += reward
      if term:
        return R  # stop accumulating once the episode terminates
  return R
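A hedged example of how the data-collection loop below uses policy_returns_with_horizon: it compares the horizon-h return when the first action is forced to 0 versus 1. `demo_env` and `demo_policy` are the illustrative objects from the sketches above, not part of the original notebook.
In [0]:
# Sketch: horizon-5 returns from the current state when forcing the first
# action to 0 vs. 1; the env is deep-copied inside the function, so demo_env
# itself is not advanced.
obs = demo_env.reset()
r0, r1 = [
    policy_returns_with_horizon(
        demo_env, obs, demo_policy, horizon=5, irresolution=1,
        forced_actions=(a,))
    for a in (0, 1)]
print(r0, r1)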
In [0]:
# TODO: Horizon H returns under optimal and behavior policies.
# TODO: more efficient episode sampling using an ensemble of behavior policies
WRITE_OUT = True #@param
FILTER = True #@param
compute_behavior_policy_returns = True #@param
## compute_optimal_policy_returns = False #@param
num_episodes = 30 #@param
num_datasets = 2 #@param
episode_length = 200 #@param
temporal_resolution = 10 #@param
horizons = [1, 5, 10] #@param
# file_name = "v3/pendulum_a2_t10_nnp_eval" #@param
file_name = "v3/pendulum_test" #@param
RENDER = False #@param
env = ActionAdaptor(gym.make('Pendulum-v0'))
embed = layers_lib.obs_embedding_kwargs(
    20,
    bounds=((-1, 1), (-1, 1), (0, 2 * np.pi)),
    variance=[1.] * 3,
    spillover=0.05,
)
# embed=None
data_keys = []
if compute_behavior_policy_returns:
  for h in horizons:
    data_keys.extend(['pi0_h={}/R0'.format(h), 'pi0_h={}/R1'.format(h)])
memory = replay.Memory(data_keys)
for dataset_index in range(num_datasets):
  print('dataset index =', dataset_index)
  for _ in range(num_episodes):
    # A fresh, randomly initialized behavior policy for each episode.
    behavior_policy = BehaviorPolicy(
        AttrDict(
            num_actions=2,
            initializer=tf.keras.initializers.glorot_normal(),
            embed=embed,
            hidden_widths=[64]),
        name='policy_' + rand_str(10))

    # Collect a trajectory.
    obs = env.reset()
    memory.log_init(obs)
    for _ in range(episode_length // temporal_resolution):
      if RENDER: env.render()
      act = behavior_policy(obs)
      for rep in range(temporal_resolution):
        if compute_behavior_policy_returns:
          data = {}
          for h in horizons:
            if rep % h == 0:
              # Horizon-h returns when the first action is forced to 0 or 1.
              r0, r1 = [
                  policy_returns_with_horizon(
                      env,
                      obs,
                      behavior_policy,
                      horizon=h,
                      irresolution=temporal_resolution,
                      forced_actions=(a,))
                  for a in (0, 1)]
              data.update({'pi0_h={}/R0'.format(h): r0,
                           'pi0_h={}/R1'.format(h): r1})
            else:
              data.update({'pi0_h={}/R0'.format(h): 0.,
                           'pi0_h={}/R1'.format(h): 0.})
        else:
          data = {}
        next_obs, reward, term, _ = env.step(act)
        memory.log_experience(obs, act, reward, next_obs, data=data)
        obs = next_obs
        if term:
          break
    if RENDER: env.render()
  print('done simulating')

  if FILTER:
    # Keep episodes whose empirical action distribution is roughly balanced.
    ma = np.mean(memory.actions, axis=1)
    mask = np.logical_and(ma >= 0.33, ma <= .66)
    print('Num episodes retained:', np.count_nonzero(mask))
    print('Returns:', np.sum(memory.rewards, axis=1)[mask].tolist())
    memory.observations = filter_bool(memory.observations, mask)
    memory.actions = filter_bool(memory.actions, mask)
    memory.rewards = filter_bool(memory.rewards, mask)
    print('done filtering')

  if WRITE_OUT:
    s = memory.serialize()
    # Make the output directory (file_name may itself contain a subdirectory).
    user = getpass.getuser()
    path = '/tmp/{}/action_gap_rl/datasets'.format(user)
    full_path = os.path.join(path, '{}.{}.pickle'.format(file_name, dataset_index))
    os.makedirs(os.path.dirname(full_path), exist_ok=True)
    # Save pickle file.
    with open(full_path, 'wb') as f:
      f.write(s)

    # Sanity check serialization.
    m2 = replay.Memory()
    m2.unserialize(s)
    print(np.array_equal(m2.entered_states(), memory.entered_states()))
    print(np.array_equal(m2.exited_states(), memory.exited_states()))
    print(np.array_equal(m2.attempted_actions(), memory.attempted_actions()))
    print(np.array_equal(m2.observed_rewards(), memory.observed_rewards()))

  print('\n\n')
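A hedged follow-up sketch (not from the original notebook): reading one of the written datasets back into a replay.Memory, reusing the `path` and `file_name` variables from the cell above.
In [0]:
# Sketch: reload dataset 0 from disk and inspect how many episodes it holds.
load_path = os.path.join(path, '{}.0.pickle'.format(file_name))
with open(load_path, 'rb') as f:
  m3 = replay.Memory()
  m3.unserialize(f.read())
print(len(m3.observations), 'episodes loaded')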
In [0]: