Copyright 2020 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");


In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [0]:
import tensorflow.compat.v2 as tf
import numpy as np
import imp
import os
import time
from collections import namedtuple

In [0]:
tf.enable_v2_behavior()

In [0]:
from action_gap_rl import replay
from action_gap_rl import value as value_lib
from action_gap_rl import layers_lib
replay = imp.reload(replay)
value_lib = imp.reload(value_lib)
layers_lib = imp.reload(layers_lib)

In [0]:
class AttrDict(dict):
  def __init__(self, *args, **kwargs):
    super(AttrDict, self).__init__(*args, **kwargs)
    self.__dict__ = self

def to_dict(d):
  if isinstance(d, AttrDict):
    return {k: to_dict(v) for k, v in d.items()}
  return d
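
A quick illustrative check (not part of the original notebook): AttrDict exposes dict keys as attributes, and to_dict converts an AttrDict (including nested AttrDict values) back to a plain dict.

In [0]:
# Illustrative example of AttrDict attribute-style access (assumed usage).
cfg = AttrDict(num_actions=2, hidden_widths=[512, 256])
assert cfg.num_actions == cfg['num_actions'] == 2
assert to_dict(cfg) == {'num_actions': 2, 'hidden_widths': [512, 256]}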

In [0]:
def down_sample(lst, irresolution, reduce_fn):
  return [reduce_fn(lst[i: i+irresolution]) for i in range(0, len(lst), irresolution)]

first = lambda s: s[0]


Metadata = namedtuple('Metadata', 'ds,y_scale,aux_data')
Dataset = namedtuple('Dataset', 'o,a,r')

def load_datasets(dataset_files, gamma=1.0, horizon=None, irresolution=1, scale=1.0):
  datasets = {}
  for df in dataset_files:
    with open(df, 'rb') as f:
      s = f.read()

    memory = replay.Memory()
    memory.unserialize(s)

    if irresolution > 1:
      memory.rewards = [down_sample(r, irresolution, sum) for r in memory.rewards]
      memory.actions = [down_sample(a, irresolution, first) for a in memory.actions]
      memory.observations = [down_sample(o, irresolution, first) for o in memory.observations]

    if horizon is None:
      qmax_Bax = np.reshape(value_lib.max_q_iteration(memory, gamma), (-1, 1))
    else:
      assert horizon == 1
      qmax_Bax = np.reshape(memory.rewards, (-1, 1))
    
    data = Dataset(
        o=memory.exited_states(), 
        a=memory.attempted_actions(),
        r=qmax_Bax * scale)
    
    y_scale = np.max(np.abs(qmax_Bax * scale))

    datasets[os.path.splitext(os.path.basename(df))[0]] = Metadata(
        ds=data, y_scale=y_scale, aux_data=memory.data)
  
  return datasets
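
A small illustrative check of down_sample (assumed example, not from the original run): it splits a per-step sequence into blocks of length irresolution and applies reduce_fn within each block, so rewards are summed while observations and actions keep the first entry of each block.

In [0]:
# Illustrative example of the temporal down-sampling used in load_datasets.
rewards = [1, 2, 3, 4, 5, 6]
print(down_sample(rewards, 3, sum))    # [6, 15]: per-block reward sums
print(down_sample(rewards, 3, first))  # [1, 4]: first element of each block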

In [0]:
# Load datasets.
DATASET_FILES = [
    # TRAIN
    '/tmp/action_gap_rl/datasets/v2/pendulum_a2_t10_nnp_train.pickle',

    # EVAL
    '/tmp/action_gap_rl/datasets/v2/pendulum_a2_t10_nnp_eval.pickle',
]

datasets = {
    '1IR': load_datasets(DATASET_FILES, horizon=1, gamma=1.0, irresolution=1, scale=1/16),
    '10IR': load_datasets(DATASET_FILES, horizon=1, gamma=1.0, irresolution=10, scale=1/109)
}

for category, subset in datasets.items():
  for ds, data in subset.items():
    print('{}/{} : {}, {}'.format(category, ds, data.ds.o.shape[0], data.y_scale))


1IR/pendulum_a2_t10_nnp_train : 22800, 1.0163658022882485
1IR/pendulum_a2_t10_nnp_eval : 21800, 1.0168128733032964
10IR/pendulum_a2_t10_nnp_train : 2280, 1.0021880235295113
10IR/pendulum_a2_t10_nnp_eval : 2180, 1.0006529785472604

In [0]:
datasets['1IR']['pendulum_a2_t10_nnp_eval'].aux_data.keys()


Out[0]:
dict_keys(['pi0_h=1/R0', 'pi0_h=1/R1', 'pi0_h=5/R0', 'pi0_h=5/R1', 'pi0_h=10/R0', 'pi0_h=10/R1'])

In [0]:
class CategoricalPolicy(tf.keras.Model):
  """A policy that takes an arbitrary function as the un-normalized log pdf."""

  def __init__(self, config, name=None):
    super(CategoricalPolicy, self).__init__(
        name=name or self.__class__.__name__)
    self._config = config
    self.num_actions = config.num_actions
    hidden_widths = config.hidden_widths
    if config.embed:
      transformation_layers = [layers_lib.soft_hot_layer(**config.embed)]
    else:
      transformation_layers = []
    self._body = tf.keras.Sequential(
        transformation_layers
        + [tf.keras.layers.Dense(w, activation='relu') for w in hidden_widths]
        + [tf.keras.layers.Dense(self.num_actions, activation=None)]
    )

  def call(self, states, actions):
    # Returns unnormalized log-pdf of the actions (value predictions)
    return index_rows(self._body(states), actions)

  def argmax(self, states):
    return tf.argmax(self._body(states), axis=1)
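
A minimal smoke test (illustrative; the toy config below is not from the original notebook). With embed=None the soft-hot embedding layer is skipped and the body is a plain MLP; argmax returns the greedy action index per state.

In [0]:
# Illustrative smoke test of CategoricalPolicy with a hypothetical toy config.
toy_policy = CategoricalPolicy(AttrDict(num_actions=2, embed=None, hidden_widths=[8]))
toy_states = tf.zeros([4, 3])  # batch of 4 three-dimensional states
print(toy_policy.argmax(toy_states))  # shape (4,) tensor of greedy action indices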

In [0]:
def l2_loss(model, states, actions, targets):
  return tf.reduce_mean(tf.square(model(states, actions) - targets))

def l1_loss(model, states, actions, targets):
  return tf.reduce_mean(tf.abs(model(states, actions) - targets))


def sample_batch(batch_size, *args):
  assert args
  idx = np.random.choice(args[0].shape[0], batch_size)
  return tuple([arg[idx] for arg in args])


def optimize(optimizer, model, loss_fn, data, eval_fn=None,
             batch_size=100, maxiter=10000, report_gap=100, eval_size=None):
  trace_loss = np.zeros(maxiter//report_gap + 1)
  trace_eval = eval_fn and np.zeros(maxiter//report_gap + 1)
  start = time.time()
  j = 0
  for i in range(maxiter+1):
    optimizer.minimize(lambda: loss_fn(model, *sample_batch(batch_size, *data)),
                       model.trainable_variables)
    if i % report_gap == 0:
      if eval_size:
        batch = sample_batch(eval_size, *data)
      else:
        batch = data
      trace_loss[j] = loss_fn(model, *batch).numpy()
      if trace_eval is not None:
        trace_eval[j] = eval_fn(model, *batch).numpy()
        print(i, time.time() - start, trace_loss[j], trace_eval[j])
      else:
        print(i, time.time() - start, trace_loss[j])
      j += 1
  return trace_loss, trace_eval, time.time() - start


def index_rows(a, idx):
  # https://stackoverflow.com/a/40723732
  idx_2 = tf.expand_dims(tf.cast(idx, tf.int32), 1)
  rng = tf.expand_dims(tf.range(tf.shape(idx)[0]), 1)
  ind = tf.concat([rng, idx_2], 1)
  return tf.expand_dims(tf.gather_nd(a, ind), 1)
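
An illustrative check of index_rows (assumed example, not from the original run): for each row of a it gathers the column named by idx and returns a column vector, which is how CategoricalPolicy.call selects the value prediction for each attempted action.

In [0]:
# Illustrative example: per-row gather with index_rows.
a = tf.constant([[10., 11.],
                 [20., 21.],
                 [30., 31.]])
idx = tf.constant([1, 0, 1])
print(index_rows(a, idx))  # [[11.], [20.], [31.]], shape (3, 1)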

Training


In [0]:
# Select the compute device: '' for CPU, '/device:GPU:0' for GPU, '/device:TPU:0' for TPU.
device_string = ''

BATCH_SIZE = 1000
ITERATIONS = 200  # 1000

policies = {}
train_datasets = ('1IR/pendulum_a2_t10_nnp_train', '10IR/pendulum_a2_t10_nnp_train')
for catds in train_datasets:
  cat, ds = catds.split('/')
  data = datasets[cat][ds].ds
  print('Training policy on "{}"'.format(catds))

  embed = layers_lib.obs_embedding_kwargs(
      20,
      batch=data[0],
  )
  # embed = None
  policy = CategoricalPolicy(
      AttrDict(num_actions=2, embed=embed, hidden_widths=[512, 256]),
      name='policy_' + catds + '_' + str(np.random.randint(0, 1000000000)))

  with tf.device(device_string):
    print(l2_loss(policy, *sample_batch(1000, *data)).numpy())

  optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
  optimize(optimizer, policy, l2_loss, data, eval_fn=l1_loss,
           batch_size=BATCH_SIZE,
           maxiter=ITERATIONS,
           report_gap=10)
  policies[catds] = policy

  print('\n'+'='*32+'\n')


Training policy on "1IR/pendulum_a2_t10_nnp_train"
0.1393086
0 0.6871130466461182 0.10662606358528137 0.2695086896419525
10 1.9298017024993896 0.024069122970104218 0.12149373441934586
20 3.1672990322113037 0.007872077636420727 0.0693831667304039
30 4.418404579162598 0.004413059912621975 0.04777473956346512
40 5.662010669708252 0.002177769085392356 0.03612833842635155
50 6.883200168609619 0.0012185540981590748 0.025759786367416382
60 8.164255857467651 0.0008546442841179669 0.021938396617770195
70 9.418537378311157 0.0006102664628997445 0.018048182129859924
80 10.643690347671509 0.0004671918286476284 0.015309818089008331
90 11.911546230316162 0.0003677069616969675 0.013742777518928051
100 13.181400299072266 0.0003068363294005394 0.012310230173170567
110 14.493411302566528 0.0002600151055958122 0.011396393179893494
120 15.726702690124512 0.00022517552133649588 0.01056018564850092
130 16.962790727615356 0.0001979361695703119 0.009940200485289097
140 18.18386483192444 0.0001773530530044809 0.009301239624619484
150 19.413865089416504 0.00015956826973706484 0.008901944383978844
160 20.68570876121521 0.00014456317876465619 0.008485150523483753
170 21.945855855941772 0.0001322800962952897 0.008065366186201572
180 23.239185571670532 0.00012128549860790372 0.0077376943081617355
190 24.49921178817749 0.00011158815323142335 0.007444859016686678
200 25.755825757980347 0.00010345455666538328 0.007128487341105938

================================

Training policy on "10IR/pendulum_a2_t10_nnp_train"
0.34569624
0 0.2002716064453125 0.2904421091079712 0.4757387340068817
10 1.056138038635254 0.045797199010849 0.16910883784294128
20 1.8789281845092773 0.03701580688357353 0.14168091118335724
30 2.6946518421173096 0.014611796475946903 0.08349346369504929
40 3.789968252182007 0.012578931637108326 0.07779631018638611
50 4.639315843582153 0.00848556961864233 0.05789828300476074
60 5.495704889297485 0.006276874803006649 0.04842951521277428
70 6.334512233734131 0.004708305932581425 0.041257236152887344
80 7.178510904312134 0.003511307528242469 0.035275086760520935
90 8.02871298789978 0.0026497822254896164 0.030110469087958336
100 8.873443365097046 0.0020326655358076096 0.026405006647109985
110 9.70754337310791 0.0015879484126344323 0.023577705025672913
120 10.541680812835693 0.001278302283026278 0.020893214270472527
130 11.37149453163147 0.0010631912155076861 0.018960315734148026
140 12.201504945755005 0.0008879590313881636 0.017184384167194366
150 13.037904024124146 0.0007555758347734809 0.015799270942807198
160 13.878431558609009 0.0006492193206213415 0.014735614880919456
170 14.714341163635254 0.000563266163226217 0.013822839595377445
180 15.557454586029053 0.00049587432295084 0.012855799868702888
190 16.405006408691406 0.0004416673327796161 0.012297489680349827
200 17.247861623764038 0.0003948996018152684 0.011713462881743908

================================

Evaluation on dataset


In [0]:
policies


Out[0]:
{'1IR/pendulum_a2_t10_nnp_train': <__main__.CategoricalPolicy at 0x7f093e86eba8>,
 '10IR/pendulum_a2_t10_nnp_train': <__main__.CategoricalPolicy at 0x7f093e61bc88>}

In [0]:
for catds, policy in policies.items():
  cat, ds = catds.split('/')
  assert ds.endswith('_train')
  ds = ds[:-6]+'_eval'
  print('evaluating policy for {}/{}'.format(cat, ds))
  eval_data = datasets[cat][ds].ds
  print('L1: ', l1_loss(policy, *eval_data).numpy())
  print('L2: ', l2_loss(policy, *eval_data).numpy())
  print('')


evaluating policy for 1IR/pendulum_a2_t10_nnp_eval
L1:  0.0072429325
L2:  0.0001110887

evaluating policy for 10IR/pendulum_a2_t10_nnp_eval
L1:  0.014524726
L2:  0.00071514945


In [0]:
# Sanity check: view prediction errors on eval sets 
for catds, policy in policies.items():
  cat, ds = catds.split('/')
  assert ds.endswith('_train')
  ds = ds[:-6]+'_eval'
  eval_data = datasets[cat][ds].ds
  states, actions, targets = eval_data
  print(ds)
  print(policy(states, actions) - targets)
  print('')


pendulum_a2_t10_nnp_eval
tf.Tensor(
[[-0.00222814]
 [-0.0005995 ]
 [ 0.00530088]
 ...
 [ 0.00438103]
 [ 0.00365618]
 [ 0.00426745]], shape=(19600, 1), dtype=float32)

pendulum_a2_t10_nnp_eval
tf.Tensor(
[[-0.00349784]
 [ 0.00062013]
 [ 0.01205844]
 ...
 [-0.00655425]
 [ 0.0416511 ]
 [ 0.0086602 ]], shape=(1960, 1), dtype=float32)


In [0]: