In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
In [0]:
!pip install tf-nightly==2.3.0.dev20200604
!pip install dm-tree
!pip install dm-reverb-nightly
In [0]:
import reverb
import tensorflow as tf
The code below defines a dummy RL environment for use in the following examples.
In [0]:
# Shapes of the dummy observations and actions used throughout this notebook.
observations_shape = tf.TensorShape([10, 10])
actions_shape = tf.TensorShape([2])


def agent_step(unused_timestep) -> tf.Tensor:
  """Returns a random binary action, ignoring the given timestep.

  Args:
    unused_timestep: Ignored; accepted so calls mirror a real agent.

  Returns:
    A float32 tensor of shape `actions_shape` whose entries are 0.0 or 1.0.
  """
  coin_flips = tf.random.uniform(actions_shape) > .5
  return tf.cast(coin_flips, tf.float32)


def environment_step(unused_action) -> tf.Tensor:
  """Returns a random observation, ignoring the given action.

  Args:
    unused_action: Ignored; accepted so calls mirror a real environment.

  Returns:
    A uint8 tensor of shape `observations_shape`, produced by casting uniform
    random floats drawn with `maxval=256`.
  """
  raw_values = tf.random.uniform(observations_shape, maxval=256)
  return tf.cast(raw_values, tf.uint8)
In [0]:
# Initialize the reverb server with a single prioritized table.
simple_server = reverb.Server(
    # Passing port=None lets the server pick a port automatically.
    port=None,
    tables=[
        reverb.Table(
            name='my_table',
            max_size=int(1e6),
            sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
            remover=reverb.selectors.Fifo(),
            # Uses a low rate-limiter threshold for the examples; see the
            # Rate Limiters section for usage info.
            rate_limiter=reverb.rate_limiters.MinSize(2)),
    ])
# Connect a client to whichever port the server picked.
client = reverb.Client(f'localhost:{simple_server.port}')
For details on customizing the sampler, remover, and rate limiter, see below.
In [0]:
# Dynamically adds trajectories of length 3 to 'my_table' using a client writer.
with client.writer(max_sequence_length=3) as writer:
  timestep = environment_step(None)
  for step in range(4):
    action = agent_step(timestep)
    writer.append((timestep, action))
    timestep = environment_step(action)
    if step < 2:
      # Fewer than 3 timesteps have been appended so far, so no item yet.
      continue
    # The item spans the 3 most recently appended timesteps and is given a
    # priority of 1.5.
    writer.create_item(table='my_table', num_timesteps=3, priority=1.5)
The animation illustrates the state of the server at each step in the above code block. Although every item here is assigned the same priority value of 1.5, items do not need to share a priority value. In real-world scenarios, items would have differing, dynamically calculated priority values.
In [0]:
# The sequence length must match the length of the prioritized items that were
# inserted into the table; the writer example above used 3.
sequence_length = 3
# The dataset samples length-3 sequences but streams their timesteps one by
# one, which allows streaming sequences that do not necessarily fit in memory.
dataset = reverb.ReplayDataset(
    server_address=f'localhost:{simple_server.port}',
    table='my_table',
    dtypes=(tf.uint8, tf.float32),
    shapes=(observations_shape, actions_shape),
    max_in_flight_samples_per_worker=10)
# Regroup the streamed timesteps into whole sequences: observations now have
# shape [3, 10, 10] and actions have shape [3, 2].
dataset = dataset.batch(sequence_length)
In [0]:
# Batches 2 sequences together. The result gets its own name so that
# re-running this cell does not batch `dataset` a second time.
# Observations now have shape [2, 3, 10, 10].
batched_dataset = dataset.batch(2)
for sample in batched_dataset.take(1):
  # Results in the following format.
  print(sample.info.key)          # ([2, 3], uint64)
  print(sample.info.probability)  # ([2, 3], float64)
  observation, action = sample.data
  print(observation)  # ([2, 3, 10, 10], uint8)
  print(action)       # ([2, 3, 2], float32)
Create a new server for this example to keep the elements of the priority table consistent.
In [0]:
# A fresh server so this example's priority table starts out empty.
complete_episode_server = reverb.Server(
    # Passing port=None lets the server pick a port automatically.
    port=None,
    tables=[
        reverb.Table(
            name='my_table',
            max_size=int(1e6),
            sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
            remover=reverb.selectors.Fifo(),
            # Uses a low rate-limiter threshold for the examples; see the
            # Rate Limiters section for usage info.
            rate_limiter=reverb.rate_limiters.MinSize(2)),
    ])
# Connect a client to whichever port this server picked.
client = reverb.Client(f'localhost:{complete_episode_server.port}')
In [0]:
# Adds whole episodes to 'my_table', one table entry per episode, via insert.
episode_length = 100
num_episodes = 200


def unroll_full_episode():
  """Plays one dummy episode and returns its stacked trajectory.

  Returns:
    A tuple `(observations, actions)`: the `episode_length` observations
    stacked into one tensor, and the `episode_length - 1` actions taken
    between consecutive observations, stacked into another.
  """
  observation = environment_step(None)
  observations, actions = [observation], []
  for _ in range(episode_length - 1):
    action = agent_step(observation)
    observation = environment_step(action)
    actions.append(action)
    observations.append(observation)
  return tf.stack(observations), tf.stack(actions)


for _ in range(num_episodes):
  # client.insert is used because each entry is a full trajectory rather than
  # individual timesteps.
  client.insert(unroll_full_episode(), {'my_table': 1.5})
In [0]:
# Each sample is an entire episode, so the expected shapes include the whole
# episode length. An episode holds `episode_length` observations but only
# `episode_length - 1` actions (see unroll_full_episode).
dataset = reverb.ReplayDataset(
    server_address=f'localhost:{complete_episode_server.port}',
    table='my_table',
    max_in_flight_samples_per_worker=10,
    dtypes=(tf.uint8, tf.float32),
    shapes=([episode_length] + observations_shape,
            [episode_length - 1] + actions_shape))
# Batches 128 episodes together.
# Each item is an episode of the format (observations, actions) as above.
# Shapes of items are now ([128, 100, 10, 10], [128, 99, 2]).
dataset = dataset.batch(128)
# Sample has type reverb.ReplaySample.
for sample in dataset.take(1):
  # Results in the following format.
  print(sample.info.key)          # ([128], uint64)
  print(sample.info.probability)  # ([128], float64)
  observation, action = sample.data
  print(observation)  # ([128, 100, 10, 10], uint8)
  print(action)       # ([128, 99, 2], float32)
Create a server that maintains multiple priority tables.
In [0]:
# A server with two independently configured prioritized tables.
multitable_server = reverb.Server(
    tables=[
        reverb.Table(
            name='my_table_a',
            sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
            remover=reverb.selectors.Fifo(),
            max_size=int(1e6),
            # Sets Rate Limiter to a low number for the examples.
            # Read the Rate Limiters section for usage info.
            rate_limiter=reverb.rate_limiters.MinSize(1)),
        reverb.Table(
            name='my_table_b',
            sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
            remover=reverb.selectors.Fifo(),
            max_size=int(1e6),
            # Sets Rate Limiter to a low number for the examples.
            # Read the Rate Limiters section for usage info.
            rate_limiter=reverb.rate_limiters.MinSize(1)),
    ],
    # Sets the port to None to make the server pick one automatically.
    port=None)
# Initializes the reverb client on the same port as the server, using the same
# f-string address format as the other cells.
client = reverb.Client(f'localhost:{multitable_server.port}')
In [0]:
# Writes overlapping trajectories of two different lengths into two tables.
with client.writer(max_sequence_length=3) as writer:
  timestep = environment_step(None)
  for step in range(4):
    writer.append(timestep)
    action = agent_step(timestep)
    timestep = environment_step(action)
    # Later steps get lower priorities: 3, 2 and 1 for steps 1, 2 and 3.
    priority = 4 - step
    if step >= 1:
      # Items in 'my_table_b' span the 2 most recent timesteps.
      writer.create_item(table='my_table_b', num_timesteps=2, priority=priority)
    if step >= 2:
      # Items in 'my_table_a' span the 3 most recent timesteps.
      writer.create_item(table='my_table_a', num_timesteps=3, priority=priority)
The above diagram shows the state of the server after executing the overlapping trajectories code.
To insert full trajectories into multiple tables, use `client.insert`
as illustrated below:
client.insert(episode, {'my_table_one': 1.5, 'my_table_two': 2.5})
In [0]:
# Same table configuration as the earlier examples, but with a MinSize(100)
# rate limiter instead of a small one.
reverb.Server(
    port=None,
    tables=[
        reverb.Table(
            name='my_table',
            max_size=int(1e6),
            sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
            remover=reverb.selectors.Fifo(),
            rate_limiter=reverb.rate_limiters.MinSize(100)),
    ])
Setting max_times_sampled=1
causes each item to be removed after it is
sampled once. The end result is a priority table that essentially functions
as a max priority queue.
In [0]:
max_size = 1000
# A MaxHeap sampler and MinHeap remover, combined with max_times_sampled=1
# (each item is removed after it is sampled once), make this table behave as
# a max priority queue. The rate limiter threshold is 95% of max_size.
reverb.Server(
    port=None,
    tables=[
        reverb.Table(
            name='my_priority_queue',
            max_size=max_size,
            max_times_sampled=1,
            sampler=reverb.selectors.MaxHeap(),
            remover=reverb.selectors.MinHeap(),
            rate_limiter=reverb.rate_limiters.MinSize(int(0.95 * max_size)),
        )
    ])
The behavior of canonical data structures such as a
circular buffer or a max
priority queue can
be implemented in Reverb by modifying the `sampler`
and `remover`,
or by using the `Table.queue`
initializer.
In [0]:
# One server hosting a queue (built with the Table.queue initializer) and a
# circular buffer (FIFO sampler and remover, each item sampled at most once).
reverb.Server(
    port=None,
    tables=[
        reverb.Table.queue(name='my_queue', max_size=10000),
        reverb.Table(
            name='my_circular_buffer',
            max_size=10000,
            max_times_sampled=1,
            sampler=reverb.selectors.Fifo(),
            remover=reverb.selectors.Fifo(),
            rate_limiter=reverb.rate_limiters.MinSize(1)),
    ])
In [0]:
# A prioritized table rate-limited by a SampleToInsertRatio limiter
# (samples_per_insert=3.0, min_size_to_sample=3, error_buffer=3.0).
reverb.Server(
    port=None,
    tables=[
        reverb.Table(
            name='my_table',
            max_size=int(1e6),
            sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
            remover=reverb.selectors.Fifo(),
            rate_limiter=reverb.rate_limiters.SampleToInsertRatio(
                samples_per_insert=3.0,
                min_size_to_sample=3,
                error_buffer=3.0)),
    ])
This example is intended to be used in a distributed or multi-threaded environment, where blocked insertions are unblocked by sample calls from an independent thread. If the system is single-threaded, a blocked insertion call will cause a deadlock.