In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
In Reinforcement Learning terminology, policies map an observation from the environment to an action or a distribution over actions. In TF-Agents, observations from the environment are contained in a named tuple TimeStep('step_type', 'discount', 'reward', 'observation')
, and policies map timesteps to actions or distributions over actions. Most policies use timestep.observation
, some policies use timestep.step_type
(e.g. to reset the state at the beginning of an episode in stateful policies), but timestep.discount
and timestep.reward
are usually ignored.
Policies are related to other components in TF-Agents in the following way. Most policies have a neural network to compute actions and/or distributions over actions from TimeSteps. Agents can contain one or more policies for different purposes, e.g. a main policy that is being trained for deployment, and a noisy policy for data collection. Policies can be saved/restored, and can be used independently of the agent for data collection, evaluation etc.
Some policies are easier to write in TensorFlow (e.g. those with a neural network), whereas others are easier to write in Python (e.g. following a script of actions). So in TF-Agents, we allow both Python and TensorFlow policies. Moreover, policies written in TensorFlow might have to be used in a Python environment, or vice versa, e.g. a TensorFlow policy is used for training but later deployed in a production Python environment. To make this easier, we provide wrappers for converting between Python and TensorFlow policies.
Another interesting class of policies are policy wrappers, which modify a given policy in a certain way, e.g. add a particular type of noise, make a greedy or epsilon-greedy version of a stochastic policy, randomly mix multiple policies etc.
If you haven't installed tf-agents yet, run:
In [0]:
!pip install --pre tf-agents[reverb]
In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
from tf_agents.specs import array_spec
from tf_agents.specs import tensor_spec
from tf_agents.networks import network
from tf_agents.policies import py_policy
from tf_agents.policies import random_py_policy
from tf_agents.policies import scripted_py_policy
from tf_agents.policies import tf_policy
from tf_agents.policies import random_tf_policy
from tf_agents.policies import actor_policy
from tf_agents.policies import q_policy
from tf_agents.policies import greedy_policy
from tf_agents.trajectories import time_step as ts
# Switch TensorFlow to its 2.x behavior before any of the specs, networks or
# policies further down are created.
tf.compat.v1.enable_v2_behavior()
The interface for Python policies is defined in policies/py_policy.PyPolicy
. The main methods are:
In [0]:
class Base(object):
  """Outline of the Python policy interface.

  The constructor stores the three specs it is given and the class exposes
  them through read-only properties. `reset`, `action`, `distribution` and
  `update` are placeholders: each has an empty body and returns None.

  NOTE(review): `Base` derives from plain `object`, so the
  `abc.abstractmethod` markers below only tag the methods; Python enforces
  them at instantiation time only for classes whose metaclass is
  `abc.ABCMeta`.
  """

  @abc.abstractmethod
  def __init__(self, time_step_spec, action_spec, policy_state_spec=()):
    """Stores the specs describing time steps, actions and policy state.

    Args:
      time_step_spec: Spec of the time steps the policy consumes.
      action_spec: Spec of the actions the policy produces.
      policy_state_spec: Spec of the policy state; defaults to an empty tuple.
    """
    self._policy_state_spec = policy_state_spec
    self._action_spec = action_spec
    self._time_step_spec = time_step_spec

  @abc.abstractmethod
  def reset(self, policy_state=()):
    """Placeholder: meant to return the initial policy state."""

  @abc.abstractmethod
  def action(self, time_step, policy_state=()):
    """Placeholder: meant to return a PolicyStep(action, state, info) tuple."""

  @abc.abstractmethod
  def distribution(self, time_step, policy_state=()):
    """Placeholder: not implemented in Python, only for TF policies."""

  @abc.abstractmethod
  def update(self, policy):
    """Placeholder: meant to update self to be similar to `policy`."""

  @property
  def time_step_spec(self):
    """The time step spec given to the constructor."""
    return self._time_step_spec

  @property
  def action_spec(self):
    """The action spec given to the constructor."""
    return self._action_spec

  @property
  def policy_state_spec(self):
    """The policy state spec given to the constructor."""
    return self._policy_state_spec
The most important method is action(time_step)
which maps a time_step
containing an observation from the environment to a PolicyStep named tuple containing the following attributes:
- `action`: The action to be applied to the environment.
- `state`: The state of the policy (e.g. RNN state) to be fed into the next call to action.
- `info`: Optional side information such as action log probabilities.

The `time_step_spec` and `action_spec` are specifications for the input time step and the output action. Policies also have a `reset` function which is typically used for resetting the state in stateful policies. The `update(new_policy)` function updates `self` towards `new_policy`.
Now, let us look at a couple of examples of python policies.
A simple example of a `PyPolicy` is the `RandomPyPolicy`, which generates random actions for the given discrete/continuous `action_spec`. The input `time_step` is ignored.
In [0]:
# Integer actions of shape (2,) with bounds -10 and 10.
action_spec = array_spec.BoundedArraySpec(
    shape=(2,), dtype=np.int32, minimum=-10, maximum=10)
my_random_py_policy = random_py_policy.RandomPyPolicy(
    time_step_spec=None, action_spec=action_spec)

# No time step is supplied; query the policy twice and show each result.
time_step = None
for draw in range(2):
  action_step = my_random_py_policy.action(time_step)
  print(action_step)
A scripted policy plays back a script of actions represented as a list of (num_repeats, action)
tuples. Every time the action
function is called, it returns the next action from the list until the specified number of repeats is done, and then moves on to the next action in the list. The reset
method can be called to start executing from the beginning of the list.
In [0]:
action_spec = array_spec.BoundedArraySpec(
    shape=(2,), dtype=np.int32, minimum=-10, maximum=10)

# Each script entry is a (num_repeats, action) pair.
action_script = [
    (1, np.array([5, 2], dtype=np.int32)),
    (0, np.array([0, 0], dtype=np.int32)),  # Setting `num_repeats` to 0 will skip this action.
    (2, np.array([1, 2], dtype=np.int32)),
    (1, np.array([3, 4], dtype=np.int32)),
]

my_scripted_py_policy = scripted_py_policy.ScriptedPyPolicy(
    time_step_spec=None,
    action_spec=action_spec,
    action_script=action_script)

# Take three steps, feeding the state returned by each call into the next one.
policy_state = my_scripted_py_policy.get_initial_state()
time_step = None
print('Executing scripted policy...')
for step in range(3):
  action_step = my_scripted_py_policy.action(time_step, policy_state)
  print(action_step)
  policy_state = action_step.state

# Start over from a fresh initial state and take one more step.
print('Resetting my_scripted_py_policy...')
policy_state = my_scripted_py_policy.get_initial_state()
action_step = my_scripted_py_policy.action(time_step, policy_state)
print(action_step)
TensorFlow policies follow the same interface as Python policies. Let us look at a few examples.
In [0]:
# Float actions of shape (2,) with bounds -1 and 3; observations are float
# vectors of shape (2,).
action_spec = tensor_spec.BoundedTensorSpec(
    shape=(2,), dtype=tf.float32, minimum=-1, maximum=3)
input_tensor_spec = tensor_spec.TensorSpec(shape=(2,), dtype=tf.float32)
time_step_spec = ts.time_step_spec(input_tensor_spec)

my_random_tf_policy = random_tf_policy.RandomTFPolicy(
    time_step_spec=time_step_spec, action_spec=action_spec)

# Build an all-ones observation, wrap it in a restart time step and query
# the policy once.
observation = tf.ones(time_step_spec.observation.shape)
time_step = ts.restart(observation)
action_step = my_random_tf_policy.action(time_step)

print('Action:')
print(action_step.action)
Let us define a network as follows:
In [0]:
class ActionNet(network.Network):
  """Stateless network mapping observations to actions via one Dense layer.

  The layer uses a tanh activation, so every produced value lies in [-1, 1].
  """

  def __init__(self, input_tensor_spec, output_tensor_spec):
    """Builds the network.

    Args:
      input_tensor_spec: Spec of the observations fed to the network.
      output_tensor_spec: Spec of the actions to produce; its shape sets the
        number of Dense units and the shape the output is reshaped to.
    """
    super(ActionNet, self).__init__(
        input_tensor_spec=input_tensor_spec,
        state_spec=(),
        name='ActionNet')
    self._output_tensor_spec = output_tensor_spec
    # Size the layer from the spec passed to the constructor rather than from
    # a global variable, so the class does not depend on notebook state.
    self._sub_layers = [
        tf.keras.layers.Dense(
            output_tensor_spec.shape.num_elements(), activation=tf.nn.tanh),
    ]

  def call(self, observations, step_type, network_state):
    """Computes actions for `observations`.

    Args:
      observations: Observations; cast to float32 before the layers run.
      step_type: Ignored.
      network_state: Passed through unchanged.

    Returns:
      A tuple `(actions, network_state)` where `actions` has a leading
      dimension followed by the shape of the output spec.
    """
    del step_type  # Unused.

    output = tf.cast(observations, dtype=tf.float32)
    for layer in self._sub_layers:
      output = layer(output)
    actions = tf.reshape(output, [-1] + self._output_tensor_spec.shape.as_list())

    # Scale and shift actions to the correct range if necessary. Nothing is
    # applied here: the tanh output is returned as is.
    return actions, network_state
In TensorFlow most network layers are designed for batch operations, so we expect the input time_steps to be batched, and the output of the network will be batched as well. Also the network is responsible for producing actions in the correct range of the given action_spec. This is conventionally done using e.g. a tanh activation for the final layer to produce actions in [-1, 1] and then scaling and shifting this to the correct range as the input action_spec (e.g. see tf_agents/agents/ddpg/networks.actor_network()
).
Now, we can create an actor policy using the above network.
In [0]:
# Observations are float vectors of shape (4,); actions are float vectors of
# shape (3,) with bounds -1 and 1.
input_tensor_spec = tensor_spec.TensorSpec(shape=(4,), dtype=tf.float32)
time_step_spec = ts.time_step_spec(input_tensor_spec)
action_spec = tensor_spec.BoundedTensorSpec(
    shape=(3,), dtype=tf.float32, minimum=-1, maximum=1)

# The network turns observations into actions; the policy wraps the network.
action_net = ActionNet(
    input_tensor_spec=input_tensor_spec, output_tensor_spec=action_spec)
my_actor_policy = actor_policy.ActorPolicy(
    time_step_spec=time_step_spec,
    action_spec=action_spec,
    actor_network=action_net)
We can apply it to any batch of time_steps that follow time_step_spec:
In [0]:
batch_size = 2
# Build a batch of all-ones observations. The leading dimension comes from
# `batch_size` so that it agrees with the value handed to `ts.restart`.
observations = tf.ones(
    [batch_size] + time_step_spec.observation.shape.as_list())
time_step = ts.restart(observations, batch_size)

# Query the policy for an action on the batched time step.
action_step = my_actor_policy.action(time_step)
print('Action:')
print(action_step.action)

# Query the policy for its action distribution on the same time step.
distribution_step = my_actor_policy.distribution(time_step)
print('Action distribution:')
print(distribution_step.action)
In the above example, we created the policy using an action network that produces an action tensor. In this case, policy.distribution(time_step)
is a deterministic (delta) distribution around the output of policy.action(time_step)
. One way to produce a stochastic policy is to wrap the actor policy in a policy wrapper that adds noise to the actions. Another way is to create the actor policy using an action distribution network instead of an action network as shown below.
In [0]:
class ActionDistributionNet(ActionNet):
  """ActionNet variant whose `call` returns a Normal distribution.

  The parent network's output is used as the mean and the standard deviation
  is fixed to one for every element.
  """

  def call(self, observations, step_type, network_state):
    """Returns `(Normal(mean, 1), network_state)` for the given observations."""
    means, next_state = super(ActionDistributionNet, self).call(
        observations, step_type, network_state)
    unit_scale = tf.ones_like(means)
    action_distribution = tfp.distributions.Normal(loc=means, scale=unit_scale)
    return action_distribution, next_state
# Rebuild the actor policy on top of the distribution-producing network.
action_distribution_net = ActionDistributionNet(
    input_tensor_spec=input_tensor_spec, output_tensor_spec=action_spec)
my_actor_policy = actor_policy.ActorPolicy(
    time_step_spec=time_step_spec,
    action_spec=action_spec,
    actor_network=action_distribution_net)

# Query the policy for an action first, then for the action distribution.
action_step = my_actor_policy.action(time_step)
print('Action:')
print(action_step.action)

distribution_step = my_actor_policy.distribution(time_step)
print('Action distribution:')
print(distribution_step.action)
Note that in the above, actions are clipped to the range of the given action spec [-1, 1]. This is because the ActorPolicy constructor argument `clip` is True by default. Setting it to False will return the unclipped actions produced by the network.
Stochastic policies can be converted to deterministic policies using, for example, a GreedyPolicy wrapper which chooses stochastic_policy.distribution().mode()
as its action, and a deterministic/delta distribution around this greedy action as its distribution()
.
A Q policy is used in agents like DQN and is based on a Q network that predicts a Q value for each discrete action. For a given time step, the action distribution in the Q Policy is a categorical distribution created using the q values as logits.
In [0]:
# Observations are float vectors of shape (4,); the action is a single int32
# scalar with bounds 0 and 2.
input_tensor_spec = tensor_spec.TensorSpec(shape=(4,), dtype=tf.float32)
time_step_spec = ts.time_step_spec(input_tensor_spec)
action_spec = tensor_spec.BoundedTensorSpec(
    shape=(), dtype=tf.int32, minimum=0, maximum=2)

# Number of integers between the two bounds, both ends included.
num_actions = action_spec.maximum - action_spec.minimum + 1
class QNetwork(network.Network):
  """Stateless Q network: a single Dense layer with one output per action."""

  def __init__(self, input_tensor_spec, action_spec, num_actions=None, name=None):
    """Builds the network.

    Args:
      input_tensor_spec: Spec of the observations fed to the network.
      action_spec: Bounded spec of the discrete action; used to derive the
        number of outputs when `num_actions` is not given.
      num_actions: Optional number of outputs. When None it is computed as
        `action_spec.maximum - action_spec.minimum + 1`, so the network no
        longer depends on a global variable captured at definition time.
      name: Optional network name.
    """
    super(QNetwork, self).__init__(
        input_tensor_spec=input_tensor_spec,
        state_spec=(),
        name=name)
    if num_actions is None:
      num_actions = int(action_spec.maximum - action_spec.minimum + 1)
    self._sub_layers = [
        tf.keras.layers.Dense(num_actions),
    ]

  def call(self, inputs, step_type=None, network_state=()):
    """Returns `(q_values, network_state)` for the given observations.

    Args:
      inputs: Observations; cast to float32 before the layers run.
      step_type: Ignored.
      network_state: Passed through unchanged.
    """
    del step_type  # Unused.
    inputs = tf.cast(inputs, tf.float32)
    for layer in self._sub_layers:
      inputs = layer(inputs)
    return inputs, network_state
# Build a batch of all-ones observations and the matching restart time steps.
batch_size = 2
observation = tf.ones(
    [batch_size] + time_step_spec.observation.shape.as_list())
time_steps = ts.restart(observation, batch_size=batch_size)

# Create the Q network and the policy that wraps it.
my_q_network = QNetwork(
    input_tensor_spec=input_tensor_spec, action_spec=action_spec)
my_q_policy = q_policy.QPolicy(
    time_step_spec, action_spec, q_network=my_q_network)

# Query both the action and the action distribution before printing them.
action_step = my_q_policy.action(time_steps)
distribution_step = my_q_policy.distribution(time_steps)

print('Action:')
print(action_step.action)

print('Action distribution:')
print(distribution_step.action)
A policy wrapper can be used to wrap and modify a given policy, e.g. add noise. Policy wrappers are a subclass of Policy (Python/TensorFlow) and can therefore be used just like any other policy.
In [0]:
# Wrap the Q policy in a GreedyPolicy and query it on the same time steps.
my_greedy_policy = greedy_policy.GreedyPolicy(my_q_policy)

action_step = my_greedy_policy.action(time_steps)
print('Action:')
print(action_step.action)

distribution_step = my_greedy_policy.distribution(time_steps)
print('Action distribution:')
print(distribution_step.action)