Copyright 2020 DeepMind Technologies Limited.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
In [ ]:
# Install Acme (plus its `reverb` and `tf` extras), Sonnet, Dopamine (pinned to
# 3.0.1) and atari-py into the notebook environment.
!pip install dm-acme
!pip install dm-acme[reverb]
!pip install dm-acme[tf]
!pip install dm-sonnet
!pip install dopamine-rl==3.0.1
!pip install atari-py
# Clone the deepmind-research repository and make it the working directory.
# NOTE(review): presumably this is what makes the `rl_unplugged` package
# imported in the next cell resolvable -- confirm.
!git clone https://github.com/deepmind/deepmind-research.git
%cd deepmind-research
In [ ]:
import copy
import acme
from acme.agents.tf import actors
from acme.agents.tf.dqn import learning as dqn
from acme.tf import utils as acme_utils
from acme.utils import loggers
from rl_unplugged import atari
import sonnet as snt
import tensorflow as tf
In [ ]:
# Game name and run index selecting which dataset file gets downloaded
# (`#@param` exposes them as Colab form fields).
game = 'Pong' #@param
run = 1 #@param
# Local download directory and the Cloud Storage location of the dataset files.
tmp_path = '/tmp/atari'
gs_path = 'gs://rl_unplugged/atari'
# Create the per-game local directory, then copy the single shard file
# (`00000-of-00001`) of the selected run into it.
!mkdir -p {tmp_path}/{game}
!gsutil cp {gs_path}/{game}/run_{run}-00000-of-00001 {tmp_path}/{game}
In [ ]:
batch_size = 10 #@param


def discard_extras(sample):
    """Return `sample` with its `data` field cut down to the first five items.

    NOTE(review): the five retained items are presumably the transition fields
    the DQN learner consumes, with any trailing extras dropped -- confirm
    against the `atari.dataset` element structure.
    """
    return sample._replace(data=sample.data[:5])


# Load the dataset for the `game` / `run` chosen in the download cell, so the
# data read here always matches the file that was actually fetched (these
# values used to be hard-coded to Pong, run 1).
dataset = atari.dataset(path=tmp_path, game=game, run=run, num_shards=1)
# Small batch size, experiments in the paper were run with batch size 256.
dataset = dataset.map(discard_extras).batch(batch_size)
In [ ]:
# Build the evaluation environment for the same `game` the dataset was
# downloaded for, instead of a hard-coded 'Pong'.
environment = atari.environment(game=game)
In [ ]:
# The final layer of the Q-network has one output per discrete action, so read
# the action count from the environment's action spec.
num_actions = environment.action_spec().num_values


def _to_float_image(frames):
    """Convert image observations to float32 with `convert_image_dtype`."""
    return tf.image.convert_image_dtype(frames, tf.float32)


# Q-network layers: dtype conversion, three ReLU-activated convolutions, a
# flatten, and an MLP with a 512-unit hidden layer and `num_actions` outputs.
q_network_layers = [
    _to_float_image,
    snt.Conv2D(output_channels=32, kernel_shape=[8, 8], stride=[4, 4]),
    tf.nn.relu,
    snt.Conv2D(output_channels=64, kernel_shape=[4, 4], stride=[2, 2]),
    tf.nn.relu,
    snt.Conv2D(output_channels=64, kernel_shape=[3, 3], stride=[1, 1]),
    tf.nn.relu,
    snt.Flatten(),
    snt.nets.MLP(output_sizes=[512, num_actions]),
]
network = snt.Sequential(q_network_layers)

# Create the network's variables up front from the observation spec.
acme_utils.create_variables(network, [environment.observation_spec()])
Out[ ]:
In [ ]:
# Learner progress is reported on the terminal under the 'learner' label.
logger = loggers.TerminalLogger(label='learner', time_delta=1.)

# The target network starts out as an independent deep copy of the online
# network.
target_network = copy.deepcopy(network)

# Hyperparameters handed to the DQN learner.
learner_hparams = dict(
    discount=0.99,
    learning_rate=3e-4,
    importance_sampling_exponent=0.2,
    target_update_period=2500,
)

# Create the DQN learner; it is given `dataset` as its source of training data.
learner = dqn.DQNLearner(
    network=network,
    target_network=target_network,
    dataset=dataset,
    logger=logger,
    **learner_hparams,
)
In [ ]:
# Number of learner updates to perform; each iteration is one `learner.step()`.
num_learner_steps = 100 #@param
for _ in range(num_learner_steps):
    learner.step()
In [ ]:
# Evaluation results are reported on the terminal under the 'evaluation' label.
logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)


def _greedy_action(q_values):
    """Return the index of the largest Q-value along the last axis."""
    return tf.argmax(q_values, axis=-1)


# Greedy policy: the trained Q-network followed by an argmax over its outputs.
policy_network = snt.Sequential([network, _greedy_action])
actor = actors.FeedForwardActor(policy_network=policy_network)

# Create an environment loop and run the actor for five episodes.
loop = acme.EnvironmentLoop(
    environment=environment,
    actor=actor,
    logger=logger,
)
loop.run(num_episodes=5)