In [0]:
environment_library = 'gym' # @param ['dm_control', 'gym']
In [0]:
mjkey = """
""".strip()
if not mjkey and environment_library == 'dm_control':
  raise ValueError(
      'A MuJoCo license is required for `dm_control`; if you do not have one, '
      'consider selecting `gym` from the dropdown menu in the cell above.')
In [0]:
!pip install dm-acme
!pip install dm-acme[reverb]
!pip install dm-acme[tf]
Without a valid license you won't be able to use the dm_control environments, but you can still follow this colab using the gym environments.
If you have a personal MuJoCo license (not an institutional one), you may need to follow the instructions at https://research.google.com/colaboratory/local-runtimes.html to run a Jupyter kernel on your local machine. You can then install dm_control by following the instructions at https://github.com/deepmind/dm_control.
In [0]:
#@test {"skip": true}
if environment_library == 'dm_control':
  mujoco_dir = "$HOME/.mujoco"
  # Install OpenGL dependencies.
  !apt-get update && apt-get install -y --no-install-recommends \
    libgl1-mesa-glx libosmesa6 libglew2.0
  # Get MuJoCo binaries.
  !wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip
  !unzip -o -q mujoco.zip -d "$mujoco_dir"
  # Copy over the MuJoCo license.
  !echo "$mjkey" > "$mujoco_dir/mjkey.txt"
  # Install dm_control.
  !pip install dm_control
  # Configure dm_control to use the OSMesa rendering backend.
  %env MUJOCO_GL=osmesa
  # Check that the installation succeeded.
  try:
    from dm_control import suite
    env = suite.load('cartpole', 'swingup')
    pixels = env.physics.render()
  except Exception as e:
    raise RuntimeError(
        'Something went wrong during installation. Check the shell output '
        'above for more information. If you do not have a valid MuJoCo '
        'license, consider selecting `gym` in the dropdown menu at the top '
        'of this Colab.') from e
  else:
    del suite, env, pixels
elif environment_library == 'gym':
  !pip install gym
In [0]:
!sudo apt-get install -y xvfb ffmpeg
!pip install imageio
!pip install PILLOW
!pip install pyvirtualdisplay
In [0]:
import IPython
from acme import environment_loop
from acme import specs
from acme import wrappers
from acme.agents.tf import d4pg
from acme.tf import networks
from acme.tf import utils as tf2_utils
from acme.utils import loggers
import numpy as np
import sonnet as snt
# Import the selected environment lib
if environment_library == 'dm_control':
  from dm_control import suite
elif environment_library == 'gym':
  import gym
# Imports required for visualization
import pyvirtualdisplay
import imageio
import base64
# Set up a virtual display for rendering.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
In [0]:
if environment_library == 'dm_control':
  environment = suite.load('cartpole', 'balance')
elif environment_library == 'gym':
  environment = gym.make('MountainCarContinuous-v0')
  environment = wrappers.GymWrapper(environment)  # To dm_env interface.
else:
  raise ValueError(
      "Unknown environment library: {}; ".format(environment_library) +
      "choose among ['dm_control', 'gym'].")
# Make sure the environment outputs single-precision floats.
environment = wrappers.SinglePrecisionWrapper(environment)
# Grab the spec of the environment.
environment_spec = specs.make_environment_spec(environment)
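The environment spec describes the shapes and dtypes of the observations, actions, rewards and discounts the agent will see. Printing it is a quick sanity check before building the networks.
In [0]:
print('Observation spec:', environment_spec.observations)
print('Action spec:', environment_spec.actions)
print('Reward spec:', environment_spec.rewards)
print('Discount spec:', environment_spec.discounts)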
In [0]:
#@title Build agent networks
# Get total number of action dimensions from action spec.
num_dimensions = np.prod(environment_spec.actions.shape, dtype=int)
# Create the shared observation network; here simply a state-less operation.
observation_network = tf2_utils.batch_concat
# Create the deterministic policy network.
policy_network = snt.Sequential([
    networks.LayerNormMLP((256, 256, 256), activate_final=True),
    networks.NearZeroInitializedLinear(num_dimensions),
    networks.TanhToSpec(environment_spec.actions),
])
# Create the distributional critic network.
critic_network = snt.Sequential([
    # The multiplexer concatenates the observations/actions.
    networks.CriticMultiplexer(),
    networks.LayerNormMLP((512, 512, 256), activate_final=True),
    networks.DiscreteValuedHead(vmin=-150., vmax=150., num_atoms=51),
])
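As an optional sanity check, you can push a batch of zero observations through the observation and policy networks and confirm that the output matches the action dimensionality. The D4PG agent creates these variables itself, so this cell can be skipped; it relies on the `zeros_like` and `add_batch_dim` helpers from `acme.tf.utils`.
In [0]:
# Optional: feed a zero observation through the policy to check output shapes.
dummy_obs = tf2_utils.add_batch_dim(
    tf2_utils.zeros_like(environment_spec.observations))
dummy_action = policy_network(observation_network(dummy_obs))
print('Policy output shape:', dummy_action.shape)  # Expect (1, num_dimensions).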
In [0]:
# Create a logger for the agent and environment loop.
agent_logger = loggers.TerminalLogger(label='agent', time_delta=10.)
env_loop_logger = loggers.TerminalLogger(label='env_loop', time_delta=10.)
# Create the D4PG agent.
agent = d4pg.D4PG(
    environment_spec=environment_spec,
    policy_network=policy_network,
    critic_network=critic_network,
    observation_network=observation_network,
    sigma=1.0,
    logger=agent_logger,
    checkpoint=False,
)
# Create a loop connecting this agent to the environment created above.
env_loop = environment_loop.EnvironmentLoop(
    environment, agent, logger=env_loop_logger)
In [0]:
# Run `num_episodes` training episodes.
# Rerun this cell until the agent has learned the given task.
env_loop.run(num_episodes=100)
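As a rough check of progress, you can run a few episodes with the trained agent and report the average undiscounted return. Note that `select_action` still uses the behaviour policy, including the exploration noise controlled by `sigma`, so treat this only as an approximate measure.
In [0]:
# Roughly evaluate the agent by averaging returns over a few episodes.
returns = []
for _ in range(5):
  timestep = environment.reset()
  episode_return = 0.
  while not timestep.last():
    action = agent.select_action(timestep.observation)
    timestep = environment.step(action)
    episode_return += timestep.reward
  returns.append(episode_return)
print('Average return over 5 episodes:', np.mean(returns))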
In [0]:
# Create a simple helper function to render a frame from the current state of
# the environment.
if environment_library == 'dm_control':
  def render(env):
    return env.physics.render(camera_id=0)
elif environment_library == 'gym':
  def render(env):
    return env.environment.render(mode='rgb_array')
else:
  raise ValueError(
      "Unknown environment library: {}; ".format(environment_library) +
      "choose among ['dm_control', 'gym'].")
def display_video(frames, filename='temp.mp4'):
  """Save and display video."""
  # Write the frames to a video file.
  with imageio.get_writer(filename, fps=60) as video:
    for frame in frames:
      video.append_data(frame)
  # Read the video back in and embed it in an HTML video tag.
  video = open(filename, 'rb').read()
  b64_video = base64.b64encode(video)
  video_tag = ('<video width="320" height="240" controls alt="test" '
               'src="data:video/mp4;base64,{0}">').format(b64_video.decode())
  return IPython.display.HTML(video_tag)
In [0]:
# Run a simple environment loop with the trained agent.
timestep = environment.reset()
frames = [render(environment)]
while not timestep.last():
  action = agent.select_action(timestep.observation)
  timestep = environment.step(action)
  # Render the scene and add it to the frame stack.
  frames.append(render(environment))
# Save and display a video of the behaviour.
display_video(np.array(frames))
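If the embedded MP4 does not play in your browser, a simple alternative is to write the frames out as a GIF with imageio. Depending on your imageio version, the GIF writer may expect `duration` (seconds per frame) rather than `fps`.
In [0]:
# Alternative: save the rendered frames as a GIF.
imageio.mimsave('episode.gif', frames, fps=30)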