Copyright 2020 Google LLC.
Licensed under the Apache License, Version 2.0 (the "License");
In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
In [0]:
import tensorflow.compat.v2 as tf
import numpy as np
import importlib
import os
import time
from collections import namedtuple
In [0]:
tf.enable_v2_behavior()
In [0]:
from action_gap_rl import replay
from action_gap_rl import value as value_lib
from action_gap_rl import layers_lib
replay = importlib.reload(replay)
value_lib = importlib.reload(value_lib)
layers_lib = importlib.reload(layers_lib)
In [0]:
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def to_dict(d):
if isinstance(d, AttrDict):
return {k: to_dict(v) for k, v in d.items()}
return d
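AttrDict lets config fields be read as attributes, and to_dict recovers a plain dict. A quick illustration (the field names mirror the policy config used later; the values are arbitrary):
In [0]:
cfg = AttrDict(num_actions=2, embed=None, hidden_widths=[512, 256])
print(cfg.num_actions)  # attribute access: 2
print(to_dict(cfg))     # plain-dict view of the same config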
In [0]:
def down_sample(lst, irresolution, reduce_fn):
  """Coarsens lst by applying reduce_fn to consecutive blocks of `irresolution` items."""
  return [reduce_fn(lst[i:i + irresolution]) for i in range(0, len(lst), irresolution)]
first = lambda s: s[0]
Metadata = namedtuple('Metadata', 'ds,y_scale,aux_data')
Dataset = namedtuple('Dataset', 'o,a,r')
def load_datasets(dataset_files, gamma=1.0, horizon=None, irresolution=1, scale=1.0):
datasets = {}
  for df in dataset_files:
with open(df, 'rb') as f:
s = f.read()
memory = replay.Memory()
memory.unserialize(s)
if irresolution > 1:
memory.rewards = [down_sample(r, irresolution, sum) for r in memory.rewards]
memory.actions = [down_sample(a, irresolution, first) for a in memory.actions]
memory.observations = [down_sample(o, irresolution, first) for o in memory.observations]
if horizon is None:
qmax_Bax = np.reshape(value_lib.max_q_iteration(memory, gamma), (-1, 1))
else:
assert horizon == 1
qmax_Bax = np.reshape(memory.rewards, (-1, 1))
data = Dataset(
o=memory.exited_states(),
a=memory.attempted_actions(),
r=qmax_Bax * scale)
y_scale = np.max(np.abs(qmax_Bax * scale))
    key = os.path.splitext(os.path.basename(df))[0]
    datasets[key] = Metadata(ds=data, y_scale=y_scale, aux_data=memory.data)
return datasets
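down_sample coarsens a trajectory by a factor of `irresolution`: rewards are summed within each block, while actions and observations keep only each block's first element. A tiny check with made-up values:
In [0]:
print(down_sample([1, 2, 3, 4, 5], 2, sum))    # [3, 7, 5]; a trailing partial block is kept
print(down_sample([1, 2, 3, 4, 5], 2, first))  # [1, 3, 5]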
In [0]:
# Load datasets.
DATASET_FILES = [
# TRAIN
'/tmp/action_gap_rl/datasets/v2/pendulum_train.pickle',
# EVAL
'/tmp/action_gap_rl/datasets/v2/pendulum_eval.pickle',
]
datasets = {
'1IR': load_datasets(DATASET_FILES, horizon=1, gamma=1.0, irresolution=1, scale=1/16),
'10IR': load_datasets(DATASET_FILES, horizon=1, gamma=1.0, irresolution=10, scale=1/109)
}
for category, subset in datasets.items():
for ds, data in subset.items():
print('{}/{} : {}, {}'.format(category, ds, data.ds.o.shape[0], data.y_scale))
In [0]:
datasets['1IR']['pendulum_eval'].aux_data.keys()
Out[0]:
In [0]:
class CategoricalPolicy(tf.keras.Model):
"""A policy that takes an arbitrary function as the un-normalized log pdf."""
def __init__(self, config, name=None):
super(CategoricalPolicy, self).__init__(
name=name or self.__class__.__name__)
self._config = config
self.num_actions = config.num_actions
hidden_widths = config.hidden_widths
if config.embed:
transformation_layers = [layers_lib.soft_hot_layer(**config.embed)]
else:
transformation_layers = []
self._body = tf.keras.Sequential(
transformation_layers
+ [tf.keras.layers.Dense(w, activation='relu') for w in hidden_widths]
+ [tf.keras.layers.Dense(self.num_actions, activation=None)]
)
def call(self, states, actions):
# Returns unnormalized log-pdf of the actions (value predictions)
return index_rows(self._body(states), actions)
def argmax(self, states):
return tf.argmax(self._body(states), axis=1)
In [0]:
def l2_loss(model, states, actions, targets):
return tf.reduce_mean(tf.square(model(states, actions) - targets))
def l1_loss(model, states, actions, targets):
return tf.reduce_mean(tf.abs(model(states, actions) - targets))
def sample_batch(batch_size, *args):
  """Samples (with replacement) corresponding rows from each array in args."""
  assert args
  idx = np.random.choice(args[0].shape[0], batch_size)
  return tuple(arg[idx] for arg in args)
def optimize(optimizer, model, loss_fn, data, eval_fn=None,
             batch_size=100, maxiter=10000, report_gap=100, eval_size=None):
  """Runs minibatch optimization, recording loss (and eval_fn) every report_gap steps."""
  trace_loss = np.zeros(maxiter // report_gap + 1)
  trace_eval = eval_fn and np.zeros(maxiter // report_gap + 1)
start = time.time()
j = 0
for i in range(maxiter+1):
optimizer.minimize(lambda: loss_fn(model, *sample_batch(batch_size, *data)),
model.trainable_variables)
if i % report_gap == 0:
if eval_size:
batch = sample_batch(eval_size, *data)
else:
batch = data
trace_loss[j] = loss_fn(model, *batch).numpy()
if trace_eval is not None:
trace_eval[j] = eval_fn(model, *batch).numpy()
print(i, time.time() - start, trace_loss[j], trace_eval[j])
else:
print(i, time.time() - start, trace_loss[j])
j += 1
return trace_loss, trace_eval, time.time() - start
def index_rows(a, idx):
  """Gathers a[i, idx[i]] for each row i; returns shape [batch_size, 1]."""
  # https://stackoverflow.com/a/40723732
  idx_2 = tf.expand_dims(tf.cast(idx, tf.int32), 1)
  rng = tf.expand_dims(tf.range(tf.shape(idx)[0]), 1)
  ind = tf.concat([rng, idx_2], 1)
  return tf.expand_dims(tf.gather_nd(a, ind), 1)
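index_rows selects one entry per row of a matrix, which is how the policy picks out the value prediction for the action actually taken. A small worked example (values are arbitrary):
In [0]:
a = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
index_rows(a, tf.constant([0, 1, 0]))  # -> [[1.], [4.], [5.]]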
In [0]:
# '' for CPU, '/device:GPU:0' for GPU, '/device:TPU:0' for TPU.
device_string = ''
BATCH_SIZE = 1000
ITERATIONS = 200 # 1000
policies = {}
train_datasets = ('1IR/pendulum_train', '10IR/pendulum_train')
for catds in train_datasets:
cat, ds = catds.split('/')
data = datasets[cat][ds].ds
print('Training policy on "{}"'.format(catds))
  embed = layers_lib.obs_embedding_kwargs(
      20,
      batch=data[0],
  )
  # embed = None
  policy = CategoricalPolicy(
      AttrDict(num_actions=2, embed=embed, hidden_widths=[512, 256]),
      name='policy_' + catds + '_' + str(np.random.randint(0, 1000000000)))
with tf.device(device_string):
print(l2_loss(policy, *sample_batch(1000, *data)).numpy())
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
optimize(optimizer, policy, l2_loss, data, eval_fn=l1_loss,
batch_size=BATCH_SIZE,
maxiter=ITERATIONS,
report_gap=10)
policies[catds] = policy
print('\n'+'='*32+'\n')
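If the trained policies need to outlive the session, standard Keras weight checkpointing applies. This is an optional sketch; the directory below is illustrative, not part of the original workflow:
In [0]:
ckpt_dir = '/tmp/action_gap_rl/policies'
os.makedirs(ckpt_dir, exist_ok=True)
for catds, policy in policies.items():
  policy.save_weights(os.path.join(ckpt_dir, catds.replace('/', '_')))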
In [0]:
policies
Out[0]:
In [0]:
for catds, policy in policies.items():
cat, ds = catds.split('/')
assert ds.endswith('_train')
ds = ds[:-6]+'_eval'
print('evaluating policy for {}/{}'.format(cat, ds))
eval_data = datasets[cat][ds].ds
print('L1: ', l1_loss(policy, *eval_data).numpy())
print('L2: ', l2_loss(policy, *eval_data).numpy())
print('')
In [0]:
# Sanity check: view prediction errors on eval sets
for catds, policy in policies.items():
cat, ds = catds.split('/')
assert ds.endswith('_train')
ds = ds[:-6]+'_eval'
eval_data = datasets[cat][ds].ds
states, actions, targets = eval_data
print(ds)
print(policy(states, actions) - targets)
print('')
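A further optional check, not in the original notebook: compare the greedy action implied by each policy against the action actually attempted in the eval set (this agreement metric is an assumption for illustration, not an established evaluation for this task):
In [0]:
for catds, policy in policies.items():
  cat, ds = catds.split('/')
  eval_data = datasets[cat][ds[:-len('_train')] + '_eval'].ds
  states, actions, _ = eval_data
  greedy = policy.argmax(states).numpy()
  agreement = np.mean(greedy == np.asarray(actions).ravel().astype(greedy.dtype))
  print('{}: argmax/attempted-action agreement = {:.3f}'.format(catds, agreement))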