Copyright 2020 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

RL Unplugged: Offline D4PG - DM control

Guide to training an Acme D4PG agent on DM control data.

Installation


In [ ]:
!pip install dm-acme
!pip install dm-acme[reverb]
!pip install dm-acme[tf]
!pip install dm-sonnet
!git clone https://github.com/deepmind/deepmind-research.git
%cd deepmind-research
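
A quick optional check that the core packages installed and import cleanly (this assumes each package exposes a `__version__` attribute):

In [ ]:
# Optional: confirm the core packages import and report their versions.
import acme
import sonnet as snt
import tensorflow as tf

print('acme:', acme.__version__)
print('sonnet:', snt.__version__)
print('tensorflow:', tf.__version__)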

dm_control

More detailed instructions are available in the dm_control repository: https://github.com/deepmind/dm_control.

Institutional MuJoCo license.


In [ ]:
#@title Edit and run
mjkey = """

REPLACE THIS LINE WITH YOUR MUJOCO LICENSE KEY

""".strip()

mujoco_dir = "$HOME/.mujoco"

# Install OpenGL deps
!apt-get update && apt-get install -y --no-install-recommends \
  libgl1-mesa-glx libosmesa6 libglew2.0

# Fetch MuJoCo binaries from Roboti
!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip
!unzip -o -q mujoco.zip -d "$mujoco_dir"

# Copy over MuJoCo license
!echo "$mjkey" > "$mujoco_dir/mjkey.txt"


# Configure dm_control to use the OSMesa rendering backend
%env MUJOCO_GL=osmesa

# Install dm_control, including extra dependencies needed for the locomotion
# mazes.
!pip install dm_control[locomotion_mazes]

Machine-locked MuJoCo license.


In [ ]:
#@title Add your MuJoCo License and run
mjkey = """
""".strip()

mujoco_dir = "$HOME/.mujoco"

# Install OpenGL dependencies
!apt-get update && apt-get install -y --no-install-recommends \
  libgl1-mesa-glx libosmesa6 libglew2.0

# Get MuJoCo binaries
!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip
!unzip -o -q mujoco.zip -d "$mujoco_dir"

# Copy over MuJoCo license
!echo "$mjkey" > "$mujoco_dir/mjkey.txt"

# Install dm_control
!pip install dm_control[locomotion_mazes]

# Configure dm_control to use the OSMesa rendering backend
%env MUJOCO_GL=osmesa
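
A quick optional check that MuJoCo, the license, and the OSMesa backend are working: load a control suite task and render a single frame (a minimal sketch; any suite task would do).

In [ ]:
# Optional: load a suite task and render one frame through OSMesa.
# A non-empty RGB array indicates the MuJoCo setup is working.
from dm_control import suite

test_env = suite.load(domain_name='cartpole', task_name='swingup')
test_timestep = test_env.reset()
frame = test_env.physics.render(height=240, width=320, camera_id=0)
print('Observation keys:', list(test_timestep.observation.keys()))
print('Rendered frame shape:', frame.shape)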

Imports


In [ ]:
import collections
import copy
from typing import Mapping, Sequence

import acme
from acme import specs
from acme.agents.tf import actors
from acme.agents.tf import d4pg
from acme.tf import networks
from acme.tf import utils as tf2_utils
from acme.utils import loggers
from acme.wrappers import single_precision
import numpy as np
from rl_unplugged import dm_control_suite
import sonnet as snt
import tensorflow as tf

Data


In [ ]:
task_name = 'cartpole_swingup' #@param
tmp_path = '/tmp/dm_control_suite'
gs_path = 'gs://rl_unplugged/dm_control_suite'

!mkdir -p {tmp_path}/{task_name}
!gsutil cp {gs_path}/{task_name}/* {tmp_path}/{task_name}

num_shards_str, = !ls {tmp_path}/{task_name}/* | wc -l
num_shards = int(num_shards_str)

Dataset and environment


In [ ]:
batch_size = 10  #@param

task = dm_control_suite.ControlSuite(task_name)

environment = task.environment
environment_spec = specs.make_environment_spec(environment)

dataset = dm_control_suite.dataset(
    '/tmp',
    data_path=task.data_path,
    shapes=task.shapes,
    uint8_features=task.uint8_features,
    num_threads=1,
    batch_size=batch_size,
    num_shards=num_shards)

def discard_extras(sample):
  # Keep only the first five fields (o_t, a_t, r_t, d_t, o_tp1) and drop any
  # extra per-step data stored in the shards.
  return sample._replace(data=sample.data[:5])

dataset = dataset.map(discard_extras).batch(batch_size)
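
Pulling a single batch is a useful sanity check on shapes before building the learner. The sketch below assumes the five remaining fields follow the usual (o_t, a_t, r_t, d_t, o_tp1) ordering.

In [ ]:
# Optional: inspect the shapes of one batch from the dataset.
example_batch = next(iter(dataset))
field_names = ['o_t', 'a_t', 'r_t', 'd_t', 'o_tp1']
for name, field in zip(field_names, example_batch.data):
  shapes = tf.nest.map_structure(lambda t: t.shape.as_list(), field)
  print(name, ':', shapes)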

D4PG learner


In [ ]:
# Create the networks to optimize.
action_spec = environment_spec.actions
action_size = np.prod(action_spec.shape, dtype=int)

policy_network = snt.Sequential([
    tf2_utils.batch_concat,
    networks.LayerNormMLP(layer_sizes=(300, 200, action_size)),
    networks.TanhToSpec(spec=environment_spec.actions)])

critic_network = snt.Sequential([
    networks.CriticMultiplexer(
        observation_network=tf2_utils.batch_concat,
        action_network=tf.identity,
        critic_network=networks.LayerNormMLP(
            layer_sizes=(400, 300),
            activate_final=True)),
    # The value head outputs a discrete distribution over state-action values,
    # supported on 51 evenly spaced atoms in [-150, 150].
    networks.DiscreteValuedHead(vmin=-150., vmax=150., num_atoms=51)])

# Create the target networks
target_policy_network = copy.deepcopy(policy_network)
target_critic_network = copy.deepcopy(critic_network)

# Create variables.
tf2_utils.create_variables(network=policy_network,
                           input_spec=[environment_spec.observations])
tf2_utils.create_variables(network=critic_network,
                           input_spec=[environment_spec.observations,
                                       environment_spec.actions])
tf2_utils.create_variables(network=target_policy_network,
                           input_spec=[environment_spec.observations])
tf2_utils.create_variables(network=target_critic_network,
                           input_spec=[environment_spec.observations,
                                       environment_spec.actions])

# The learner updates the parameters (and initializes them).
learner = d4pg.D4PGLearner(
    policy_network=policy_network,
    critic_network=critic_network,
    target_policy_network=target_policy_network,
    target_critic_network=target_critic_network,
    dataset=dataset,
    discount=0.99,
    target_update_period=100)
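
Before training, a single forward pass through the untrained policy confirms that the observation and action shapes line up. This sketch assumes `tf2_utils.add_batch_dim` adds a leading batch axis to a nest of observation arrays.

In [ ]:
# Optional: run one environment observation through the untrained policy.
first_timestep = environment.reset()
batched_obs = tf2_utils.add_batch_dim(first_timestep.observation)
test_action = policy_network(batched_obs)
print('Policy output shape:', test_action.shape,
      '| action spec shape:', action_spec.shape)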

Training loop


In [ ]:
for _ in range(100):
  learner.step()


[Learner] Critic Loss = 3.919 | Policy Loss = 0.326 | Steps = 1 | Walltime = 0

Evaluation


In [ ]:
# Create a logger.
logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)

# Create an environment loop.
loop = acme.EnvironmentLoop(
    environment=environment,
    actor=actors.FeedForwardActor(policy_network),
    logger=logger)

loop.run(5)


[Evaluation] Episode Length = 1000 | Episode Return = 129.717 | Episodes = 2 | Steps = 2000 | Steps Per Second = 1480.399
[Evaluation] Episode Length = 1000 | Episode Return = 34.790 | Episodes = 4 | Steps = 4000 | Steps Per Second = 1449.009
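
For longer experiments, the learner steps and the evaluation loop above can be interleaved: take a block of offline gradient steps, then evaluate the current policy for a few episodes. A minimal sketch using the components already defined (the step and episode counts are arbitrary):

In [ ]:
# Sketch: interleave offline learning with periodic evaluation.
num_iterations = 10             # arbitrary, for illustration
learner_steps_per_iteration = 100
eval_episodes_per_iteration = 2

for _ in range(num_iterations):
  for _ in range(learner_steps_per_iteration):
    learner.step()
  loop.run(eval_episodes_per_iteration)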