In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Introduction

This colab is a demonstration of how to use Reverb through examples.

Setup

Installs the nightly build of Reverb (dm-reverb-nightly) together with a matching TensorFlow nightly build (tf-nightly).


In [0]:
!pip install tf-nightly==2.3.0.dev20200604
!pip install dm-tree
!pip install dm-reverb-nightly

In [0]:
import reverb
import tensorflow as tf

The code below defines a dummy RL environment for use in the examples below.


In [0]:
# Shapes of the dummy environment's observations and of the agent's actions.
# They are reused below when declaring the shapes of the replay datasets.
observations_shape = tf.TensorShape([10, 10])
actions_shape = tf.TensorShape([2])

def agent_step(unused_timestep) -> tf.Tensor:
  """Returns a random dummy action.

  Args:
    unused_timestep: Ignored; the dummy agent does not look at its input.

  Returns:
    A float32 tensor of shape `actions_shape` whose entries are 0.0 or 1.0.
  """
  coin_flips = tf.random.uniform(actions_shape) > .5
  return tf.cast(coin_flips, tf.float32)

def environment_step(unused_action) -> tf.Tensor:
  """Returns a random dummy observation.

  Args:
    unused_action: Ignored; the dummy environment does not react to actions.

  Returns:
    A uint8 tensor of shape `observations_shape` filled with random values.
  """
  random_values = tf.random.uniform(observations_shape, maxval=256)
  return tf.cast(random_values, tf.uint8)

Creating a Server and Client


In [0]:
# The single table served in the first example: a Prioritized sampler paired
# with a Fifo remover.
simple_table = reverb.Table(
    name='my_table',
    sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
    remover=reverb.selectors.Fifo(),
    max_size=int(1e6),
    # The rate limiter uses a low number for the examples.
    # Read the Rate Limiters section for usage info.
    rate_limiter=reverb.rate_limiters.MinSize(2))

# Starts the reverb server. Passing port=None makes the server pick a port
# automatically.
simple_server = reverb.Server(tables=[simple_table], port=None)

# Connects a reverb client to the port the server is listening on.
client = reverb.Client(f'localhost:{simple_server.port}')

For details on customizing the sampler, remover, and rate limiter, see below.

Example 1: Overlapping Trajectories

Inserting Overlapping Trajectories


In [0]:
# Uses a client writer to add overlapping trajectories of length 3 to
# 'my_table' while stepping through the dummy environment.

with client.writer(max_sequence_length=3) as writer:
  timestep = environment_step(None)
  for step in range(4):
    action = agent_step(timestep)
    writer.append((timestep, action))
    timestep = environment_step(action)
    if step < 2:
      # Fewer than 3 timesteps have been appended so far; no item yet.
      continue
    # The item covers the 3 timesteps most recently appended to the writer
    # and is given a priority of 1.5.
    writer.create_item(
        table='my_table', num_timesteps=3, priority=1.5)

The animation illustrates the state of the server at each step in the above code block. Although each item is being set to have the same priority value of 1.5, items do not need to have the same priority values. In real world scenarios, items would have differing and dynamically-calculated priority values.

Sampling Overlapping Trajectories in TensorFlow


In [0]:
# Every item inserted into the table above spans 3 timesteps, so the sampled
# timesteps are regrouped into sequences of that same length.
sequence_length = 3

# The dataset samples sequences of length 3 and streams their timesteps one by
# one, which allows streaming large sequences that do not necessarily fit in
# memory.
timestep_stream = reverb.ReplayDataset(
    server_address=f'localhost:{simple_server.port}',
    table='my_table',
    max_in_flight_samples_per_worker=10,
    dtypes=(tf.uint8, tf.float32),
    shapes=(observations_shape, actions_shape))

# Regroups the streamed timesteps into whole sequences.
# Shape of the observations in each element is now [3, 10, 10].
dataset = timestep_stream.batch(sequence_length)

In [0]:
# Batches 2 sequences together. The result is bound to a new name instead of
# overwriting `dataset`, so re-running this cell does not stack yet another
# batch dimension on top of the previous one.
# Shape of the observations is now [2, 3, 10, 10].
batched_dataset = dataset.batch(2)

for sample in batched_dataset.take(1):
  # Results in the following format.
  print(sample.info.key)          # ([2, 3], uint64)
  print(sample.info.probability)  # ([2, 3], float64)

  observation, action = sample.data
  print(observation)              # ([2, 3, 10, 10], uint8)
  print(action)                   # ([2, 3, 2], float32)

Example 2: Complete Episodes

Create a new server for this example to keep the elements of the priority table consistent.


In [0]:
# Table that stores the complete episodes of this example: a Prioritized
# sampler paired with a Fifo remover.
episode_table = reverb.Table(
    name='my_table',
    sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
    remover=reverb.selectors.Fifo(),
    max_size=int(1e6),
    # The rate limiter uses a low number for the examples.
    # Read the Rate Limiters section for usage info.
    rate_limiter=reverb.rate_limiters.MinSize(2))

# Passing port=None makes the server pick a port automatically.
complete_episode_server = reverb.Server(tables=[episode_table], port=None)

# Points the reverb client at the port of the new server.
client = reverb.Client(f'localhost:{complete_episode_server.port}')

Inserting Complete Episodes


In [0]:
# Adds episodes as a single entry to 'my_table' using the insert function.
# Number of observations collected in each episode.
episode_length = 100

# Number of episodes inserted into the table.
num_episodes = 200

def unroll_full_episode():
  """Plays one complete episode of the dummy environment.

  Returns:
    A tuple `(observations, actions)` of stacked tensors. It holds
    `episode_length` observations and `episode_length - 1` actions, because
    no action is taken after the final observation.
  """
  observations = [environment_step(None)]
  actions = []
  for _ in range(episode_length - 1):
    action = agent_step(observations[-1])
    actions.append(action)
    observations.append(environment_step(action))
  return tf.stack(observations), tf.stack(actions)

for _ in range(num_episodes):
  episode = unroll_full_episode()
  # client.insert is used because each entry is a full trajectory rather than
  # a sequence of individual timesteps.
  client.insert(episode, {'my_table': 1.5})

Sampling Complete Episodes in TensorFlow


In [0]:
# Each sample is an entire episode.
# Adjusts the expected shapes to account for the whole episode length. An
# episode holds one action fewer than observations, hence `episode_length - 1`.
dataset = reverb.ReplayDataset(
  server_address=f'localhost:{complete_episode_server.port}',
  table='my_table',
  max_in_flight_samples_per_worker=10,
  dtypes=(tf.uint8, tf.float32),
  shapes=([episode_length] + observations_shape, 
          [episode_length - 1] + actions_shape))

# Batches 128 episodes together.
# Each item is an episode of the format (observations, actions) as above.
# Shape of items is now ([128, 100, 10, 10], [128, 99, 2]).
dataset = dataset.batch(128)

# Sample has type reverb.ReplaySample.
for sample in dataset.take(1):
  # Results in the following format.
  print(sample.info.key)          # ([128], uint64)
  print(sample.info.probability)  # ([128], float64)
  
  observation, action = sample.data
  print(observation)              # ([128, 100, 10, 10], uint8)
  print(action)                   # ([128, 99, 2], float32)

Example 3: Multiple Priority Tables

Create a server that maintains multiple priority tables.


In [0]:
# Both tables share the same configuration and differ only in their name, so
# they are built from the list of names. Each table gets its own sampler,
# remover and rate limiter instance.
multitable_server = reverb.Server(
    tables=[
        reverb.Table(
            name=table_name,
            sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
            remover=reverb.selectors.Fifo(),
            max_size=int(1e6),
            # The rate limiter uses a low number for the examples.
            # Read the Rate Limiters section for usage info.
            rate_limiter=reverb.rate_limiters.MinSize(1))
        for table_name in ('my_table_a', 'my_table_b')
    ],
    port=None)

client = reverb.Client(f'localhost:{multitable_server.port}')

Inserting Sequences of Varying Length into Multiple Priority Tables


In [0]:
# (table name, number of timesteps per item) for each table that is written
# to, in the order in which the items are created.
item_specs = (('my_table_b', 2), ('my_table_a', 3))

with client.writer(max_sequence_length=3) as writer:
  timestep = environment_step(None)
  for step in range(4):
    writer.append(timestep)
    action = agent_step(timestep)
    timestep = environment_step(action)

    # step + 1 timesteps have been appended so far; an item is created for a
    # table as soon as enough timesteps exist to fill it.
    for table_name, item_length in item_specs:
      if step + 1 >= item_length:
        writer.create_item(
            table=table_name, num_timesteps=item_length, priority=4-step)

The above diagram shows the state of the server after executing the overlapping trajectories code.

To insert full trajectories into multiple tables use client.insert as illustrated below:

client.insert(episode, {'my_table_one': 1.5, 'my_table_two': 2.5})

Example 4: Samplers and Removers

Creating a Server with a Prioritized Sampler and a FIFO Remover


In [0]:
# Table combining a Prioritized sampler with a Fifo remover.
prioritized_fifo_table = reverb.Table(
    name='my_table',
    sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
    remover=reverb.selectors.Fifo(),
    max_size=int(1e6),
    rate_limiter=reverb.rate_limiters.MinSize(100))

# port=None, as in the earlier examples, lets the server pick its own port.
reverb.Server(tables=[prioritized_fifo_table], port=None)

Creating a Server with a MaxHeap Sampler and a MinHeap Remover

Setting max_times_sampled=1 causes each item to be removed after it is sampled once. The end result is a priority table that essentially functions as a max priority queue.


In [0]:
max_size = 1000
# The MinSize rate limiter threshold is set to 95% of the table's maximum size.
min_size_threshold = int(0.95 * max_size)

# Table combining a MaxHeap sampler with a MinHeap remover; every item may be
# sampled at most once.
priority_queue_table = reverb.Table(
    name='my_priority_queue',
    sampler=reverb.selectors.MaxHeap(),
    remover=reverb.selectors.MinHeap(),
    max_size=max_size,
    rate_limiter=reverb.rate_limiters.MinSize(min_size_threshold),
    max_times_sampled=1,
)

reverb.Server(tables=[priority_queue_table], port=None)

Creating a Server with One Queue and One Circular Buffer

The behavior of canonical data structures such as a circular buffer or a max priority queue can be implemented in Reverb by modifying the sampler and remover or by using the Table.queue initializer.


In [0]:
# Queue built with the dedicated `Table.queue` initializer.
queue_table = reverb.Table.queue(name='my_queue', max_size=10000)

# Circular buffer assembled by hand: Fifo is used for both the sampler and the
# remover, and every item may be sampled at most once.
circular_buffer_table = reverb.Table(
    name='my_circular_buffer',
    sampler=reverb.selectors.Fifo(),
    remover=reverb.selectors.Fifo(),
    max_size=10000,
    max_times_sampled=1,
    rate_limiter=reverb.rate_limiters.MinSize(1))

reverb.Server(tables=[queue_table, circular_buffer_table], port=None)

Example 5: Rate Limiters

Creating a Server with a SampleToInsertRatio Rate Limiter


In [0]:
# Rate limiter configured through a samples-per-insert ratio instead of a
# plain minimum size.
ratio_rate_limiter = reverb.rate_limiters.SampleToInsertRatio(
    samples_per_insert=3.0, min_size_to_sample=3,
    error_buffer=3.0)

ratio_limited_table = reverb.Table(
    name='my_table',
    sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
    remover=reverb.selectors.Fifo(),
    max_size=int(1e6),
    rate_limiter=ratio_rate_limiter)

reverb.Server(tables=[ratio_limited_table], port=None)

This example is intended to be used in a distributed or multi-threaded environment, where a blocked insertion is unblocked by sample calls made from an independent thread. If the system is single-threaded, the blocked insertion call will cause a deadlock.