Copyright 2020 DeepMind Technologies Limited.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
In [ ]:
# Install Acme (base package plus its `reverb` and `tf` extras) and Sonnet
# with pip. Each line is an IPython shell escape.
!pip install dm-acme
!pip install dm-acme[reverb]
!pip install dm-acme[tf]
!pip install dm-sonnet
In [ ]:
#@title Edit and run
# MuJoCo license key text; the placeholder line must be replaced by the user.
# The stripped string is written to mjkey.txt further down.
mjkey = """
REPLACE THIS LINE WITH YOUR MUJOCO LICENSE KEY
""".strip()
# Directory that receives both the unzipped MuJoCo binaries and the license
# file.
mujoco_dir = "$HOME/.mujoco"
# Install OpenGL deps
!apt-get update && apt-get install -y --no-install-recommends \
libgl1-mesa-glx libosmesa6 libglew2.0
# Fetch MuJoCo binaries from Roboti
!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip
!unzip -o -q mujoco.zip -d "$mujoco_dir"
# Copy over MuJoCo license
!echo "$mjkey" > "$mujoco_dir/mjkey.txt"
# Configure dm_control to use the OSMesa rendering backend
%env MUJOCO_GL=osmesa
# Install dm_control, including extra dependencies needed for the locomotion
# mazes.
!pip install dm_control[locomotion_mazes]
In [ ]:
#@title Add your MuJoCo License and run
# NOTE(review): this cell performs the same steps as the preceding MuJoCo
# setup cell; the only code difference is that `mjkey` starts out empty here.
# Presumably only one of the two cells needs to be run -- confirm.
# Paste the MuJoCo license key between the triple quotes.
mjkey = """
""".strip()
# Directory that receives both the unzipped MuJoCo binaries and the license
# file.
mujoco_dir = "$HOME/.mujoco"
# Install OpenGL dependencies
!apt-get update && apt-get install -y --no-install-recommends \
libgl1-mesa-glx libosmesa6 libglew2.0
# Get MuJoCo binaries
!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip
!unzip -o -q mujoco.zip -d "$mujoco_dir"
# Copy over MuJoCo license
!echo "$mjkey" > "$mujoco_dir/mjkey.txt"
# Configure dm_control to use the OSMesa rendering backend
%env MUJOCO_GL=osmesa
# Install dm_control, including extra dependencies needed for the locomotion
# mazes.
!pip install dm_control[locomotion_mazes]
In [ ]:
# Clone the Real-World RL suite repository and pip-install it from the local
# checkout.
!git clone https://github.com/google-research/realworldrl_suite.git
!pip install realworldrl_suite/
In [ ]:
# Clone the deepmind-research repository and make it the working directory.
# NOTE(review): the later `from rl_unplugged import rwrl` import presumably
# resolves relative to this directory -- confirm.
!git clone https://github.com/deepmind/deepmind-research.git
%cd deepmind-research
In [ ]:
import collections
import copy
from typing import Mapping, Sequence
import acme
from acme import specs
from acme.agents.tf import actors
from acme.agents.tf import d4pg
from acme.tf import networks
from acme.tf import utils as tf2_utils
from acme.utils import loggers
from acme.wrappers import single_precision
from acme.tf import utils as tf2_utils
import numpy as np
import realworldrl_suite.environments as rwrl_envs
from reverb import replay_sample
import six
from rl_unplugged import rwrl
import sonnet as snt
import tensorflow as tf
In [ ]:
# Colab form fields selecting which environment / dataset variant to use.
domain_name = 'cartpole' #@param
task_name = 'swingup' #@param
difficulty = 'easy' #@param
combined_challenge = 'easy' #@param
# Lower-cased string form of the challenge level, used in the dataset path.
combined_challenge_str = str(combined_challenge).lower()
# Local destination directory and the Cloud Storage prefix to copy from.
tmp_path = '/tmp/rwrl'
gs_path = f'gs://rl_unplugged/rwrl'
# Relative dataset path; it is appended to both `gs_path` and `tmp_path`.
data_path = (f'combined_challenge_{combined_challenge_str}/{domain_name}/'
f'{task_name}/offline_rl_challenge_{difficulty}')
# Create the local directory and copy the dataset files into it.
!mkdir -p {tmp_path}/{data_path}
!gsutil cp -r {gs_path}/{data_path}/* {tmp_path}/{data_path}
# Count the lines printed by `ls` for the local dataset path; the result is
# later passed to `rwrl.dataset` as `num_shards`.
num_shards_str, = !ls {tmp_path}/{data_path}/* | wc -l
num_shards = int(num_shards_str)
In [ ]:
#@title Auxiliary functions
def flatten_observation(observation):
  """Flattens multiple observation arrays into a single tensor.

  Each value of the mapping is reshaped to rank 1 and the pieces are
  concatenated along axis 0. An `OrderedDict` is traversed in its own order;
  any other mapping is traversed in sorted key order so the layout of the
  result is reproducible.

  Args:
    observation: A mutable mapping from observation names to tensors.

  Returns:
    A flattened and concatenated observation array.

  Raises:
    ValueError: If `observation` is not a `collections.abc.MutableMapping`.
  """
  # The `collections.MutableMapping` alias was removed in Python 3.10; the ABC
  # lives in `collections.abc`. It is imported here so the check does not
  # depend on the submodule having been loaded by some other import.
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  if not isinstance(observation, collections_abc.MutableMapping):
    raise ValueError('Can only flatten dict-like observations.')

  if isinstance(observation, collections.OrderedDict):
    keys = observation.keys()
  else:
    # Keep a consistent ordering for other mappings.
    keys = sorted(observation.keys())

  observation_arrays = [tf.reshape(observation[key], [-1]) for key in keys]
  return tf.concat(observation_arrays, 0)
def preprocess_fn(sample):
  """Flattens both observations of a sampled transition.

  Args:
    sample: An object with `info` and `data` attributes. The first five
      entries of `data` are unpacked as a transition (presumably observation,
      action, reward, discount and next observation, going by the original
      variable names -- confirm against the dataset).

  Returns:
    A `replay_sample.ReplaySample` with the original `info` and a 5-tuple in
    which the first and last entries have been passed through
    `flatten_observation`; the middle three entries are unchanged.
  """
  obs, action, reward, discount, next_obs = sample.data[:5]
  flat_transition = (
      flatten_observation(obs),
      action,
      reward,
      discount,
      flatten_observation(next_obs),
  )
  return replay_sample.ReplaySample(info=sample.info, data=flat_transition)
In [ ]:
batch_size = 10 #@param

# Load the environment for the selected domain/task and cast it to single
# precision.
environment = single_precision.SinglePrecisionWrapper(
    rwrl_envs.load(
        domain_name=domain_name,
        task_name=f'realworld_{task_name}',
        environment_kwargs={
            'log_safety_vars': False,
            'flat_observation': True,
        },
        combined_challenge=combined_challenge))

# Specs are reused below when creating network variables.
environment_spec = specs.make_environment_spec(environment)
act_spec = environment_spec.actions
obs_spec = environment_spec.observations

# Build the dataset from the files copied under `tmp_path`, flatten the
# observations of every sample, then batch.
dataset = rwrl.dataset(
    tmp_path,
    combined_challenge=combined_challenge_str,
    domain=domain_name,
    task=task_name,
    difficulty=difficulty,
    num_shards=num_shards,
    shuffle_buffer_size=10)
dataset = dataset.map(preprocess_fn)
dataset = dataset.batch(batch_size)
In [ ]:
#@title Auxiliary functions
def make_networks(
    action_spec: specs.BoundedArray,
    hidden_size: int = 1024,
    num_blocks: int = 4,
    num_mixtures: int = 5,
    vmin: float = -150.,
    vmax: float = 150.,
    num_atoms: int = 51,
):
  """Builds the policy and critic networks used by the agent.

  Args:
    action_spec: Spec of the actions; its shape determines the size of the
      policy's linear output layer and it is passed to `TanhToSpec`.
    hidden_size: `hidden_size` of both `LayerNormAndResidualMLP` torsos.
    num_blocks: `num_blocks` of both `LayerNormAndResidualMLP` torsos.
    num_mixtures: Not used by this function; kept in the signature.
    vmin: First argument of the critic's `DiscreteValuedHead`.
    vmax: Second argument of the critic's `DiscreteValuedHead`.
    num_atoms: Third argument of the critic's `DiscreteValuedHead`.

  Returns:
    A dict with the keys 'policy' and 'critic', each mapping to an
    `snt.Sequential` network.
  """
  # Total number of scalar action components.
  action_size = np.prod(action_spec.shape, dtype=int)

  # Policy: residual MLP torso -> linear layer sized to the action spec ->
  # TanhToSpec (which applies tanh to its input).
  policy_layers = [
      networks.LayerNormAndResidualMLP(
          hidden_size=hidden_size, num_blocks=num_blocks),
      snt.Linear(action_size),
      networks.TanhToSpec(action_spec),
  ]
  policy_network = snt.Sequential(policy_layers)

  # Critic: the multiplexer concatenates the (maybe transformed)
  # observations/actions and feeds them to a residual MLP torso, followed by
  # a discrete-valued head.
  critic_torso = networks.LayerNormAndResidualMLP(
      hidden_size=hidden_size, num_blocks=num_blocks)
  critic_layers = [
      networks.CriticMultiplexer(
          critic_network=critic_torso,
          observation_network=tf2_utils.batch_concat),
      networks.DiscreteValuedHead(vmin, vmax, num_atoms),
  ]
  critic_network = snt.Sequential(critic_layers)

  return dict(policy=policy_network, critic=critic_network)
In [ ]:
# Networks to optimize, plus a deep copy of them to serve as target networks.
online_networks = make_networks(act_spec)
target_networks = copy.deepcopy(online_networks)

# Create variables: online policy and critic first, then the target pair.
for network_set in (online_networks, target_networks):
  tf2_utils.create_variables(network_set['policy'], [obs_spec])
  tf2_utils.create_variables(network_set['critic'], [obs_spec, act_spec])

# The learner updates the parameters (and initializes them).
learner = d4pg.D4PGLearner(
    dataset=dataset,
    policy_network=online_networks['policy'],
    critic_network=online_networks['critic'],
    target_policy_network=target_networks['policy'],
    target_critic_network=target_networks['critic'],
    discount=0.99,
    target_update_period=100,
)
In [ ]:
# Run a fixed number of learner steps.
num_learner_steps = 100
for _ in range(num_learner_steps):
  learner.step()
In [ ]:
# Terminal logger labelled 'evaluation'.
logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)

# Feed-forward actor that acts with the online policy network.
evaluation_actor = actors.FeedForwardActor(online_networks['policy'])

# Environment loop tying together the environment, the actor and the logger.
loop = acme.EnvironmentLoop(
    environment=environment,
    actor=evaluation_actor,
    logger=logger,
)
loop.run(5)