In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [0]:
!pip install tf-nightly==2.3.0.dev20200604
!pip install dm-tree
!pip install dm-reverb-nightly

Frame Stacking using Reverb

This notebook contains minimal examples of how frame stacking can be implemented using Reverb.

Setup


In [0]:
from collections import deque

import numpy as np
import reverb
import tensorflow as tf

In [0]:
FRAME_SHAPE = (16, 16)  # [width, height]
FRAME_DTYPE = np.uint8


def frame_generator(max_num_frames: int = 1000):
  for i in range(1, max_num_frames + 1):
    yield np.ones(FRAME_SHAPE, dtype=FRAME_DTYPE) * i

Stack Before Writing

The simplest approach is to stack the frames before writing them to Reverb. If there is no overlap between trajectories, or if the overlap never "breaks" stacks, then this approach might be the most efficient as it reduces the post-processing needed after trajectories have been sampled.
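
To build intuition for how `stack_size` and `stride` interact, the toy sketch below (pure Python, independent of Reverb; the helper `sketch_stacks` is purely illustrative) lists which frames end up in each stack under the same policy as the Reverb example that follows.


In [0]:
# Illustrative sketch only (not part of the Reverb example): a stack is emitted
# every `stride` frames once at least `stack_size` frames have been seen, and
# it holds the `stack_size` most recent frames.
def sketch_stacks(stack_size: int, stride: int, num_frames: int):
  frames = list(range(1, num_frames + 1))
  return [frames[end - stack_size:end]
          for end in range(stride, num_frames + 1, stride)
          if end >= stack_size]

print(sketch_stacks(stack_size=4, stride=4, num_frames=16))  # Adjacent stacks.
print(sketch_stacks(stack_size=4, stride=2, num_frames=12))  # Overlapping stacks.
print(sketch_stacks(stack_size=2, stride=3, num_frames=12))  # Frames dropped.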


In [0]:
def store_stacked(stack_size: int, stride: int, sequence_length: int):
  """Simple example where frames are stacked before sent to Reverb.

  If `stride` < `stack_size` then stacks will "overlap".
  If `stride` == `stack_size` then stacks will be adjecent.
  If `stride` > `stack_size` then frames between stacks will be dropped.

  Args:
    stack_size: The number of frames to stack.
    stride: The number of frames between each stack is created.
    sequence_length: The number of stacks in each sampleable item.
  """
  server = reverb.Server(
      tables=[
              reverb.Table(
                  name='stacked_frames',
                  sampler=reverb.selectors.Fifo(),
                  remover=reverb.selectors.Fifo(),
                  max_size=100,
                  rate_limiter=reverb.rate_limiters.MinSize(10),
              ),
      ],
  )
  client = reverb.Client(f'localhost:{server.port}')

  with client.writer(max_sequence_length=sequence_length) as writer:
    # Create a circular buffer of the `stack_size` most recent frames.
    buffer = deque(maxlen=stack_size)

    for i, frame in enumerate(frame_generator(5 * stride * sequence_length)):
      buffer.append(frame)

      # Wait until the buffer holds a full stack, and only emit a new stack
      # every `stride` frames.
      if len(buffer) < stack_size or (i + 1) % stride != 0:
        continue

      # Stack the frames in buffer and insert the data into Reverb. The shape of
      # the stack is [stack_size, width, height].
      writer.append(np.stack(buffer))

      # If `sequence_length` full stacks have been written then insert an item
      # that can be sampled. The second term of `stacks_written` discounts the
      # strides that elapse before the first full stack is available.
      stacks_written = (i + 1) // stride - (stack_size - 1) // stride
      if stacks_written >= sequence_length:
        writer.create_item(table='stacked_frames',
                           num_timesteps=sequence_length,
                           priority=1.0)

  # Create a dataset that samples sequences of stacked frames.
  dataset = reverb.ReplayDataset(
      server_address=client.server_address,
      table='stacked_frames',
      max_in_flight_samples_per_worker=10,
      dtypes=tf.as_dtype(FRAME_DTYPE),
      shapes=tf.TensorShape((sequence_length, stack_size) + FRAME_SHAPE),
      sequence_length=sequence_length,
      emit_timesteps=False)

  # Print the result.
  for sequence in dataset.take(2):
    print(sequence.data)

In [4]:
# Create trajectories with 4 frames stacked together, no frames shared
# between stacks and create sequences of 3 stacks. For example, the first 16
# steps will result in the following 2 samplable items:
#
#   [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
#
#     -> [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
#     -> [[5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
#

store_stacked(stack_size=4, stride=4, sequence_length=3)


tf.Tensor(
[[[[ 1  1  1 ...  1  1  1]
   [ 1  1  1 ...  1  1  1]
   [ 1  1  1 ...  1  1  1]
   ...
   [ 1  1  1 ...  1  1  1]
   [ 1  1  1 ...  1  1  1]
   [ 1  1  1 ...  1  1  1]]

  [[ 2  2  2 ...  2  2  2]
   [ 2  2  2 ...  2  2  2]
   [ 2  2  2 ...  2  2  2]
   ...
   [ 2  2  2 ...  2  2  2]
   [ 2  2  2 ...  2  2  2]
   [ 2  2  2 ...  2  2  2]]

  [[ 3  3  3 ...  3  3  3]
   [ 3  3  3 ...  3  3  3]
   [ 3  3  3 ...  3  3  3]
   ...
   [ 3  3  3 ...  3  3  3]
   [ 3  3  3 ...  3  3  3]
   [ 3  3  3 ...  3  3  3]]

  [[ 4  4  4 ...  4  4  4]
   [ 4  4  4 ...  4  4  4]
   [ 4  4  4 ...  4  4  4]
   ...
   [ 4  4  4 ...  4  4  4]
   [ 4  4  4 ...  4  4  4]
   [ 4  4  4 ...  4  4  4]]]


 [[[ 5  5  5 ...  5  5  5]
   [ 5  5  5 ...  5  5  5]
   [ 5  5  5 ...  5  5  5]
   ...
   [ 5  5  5 ...  5  5  5]
   [ 5  5  5 ...  5  5  5]
   [ 5  5  5 ...  5  5  5]]

  [[ 6  6  6 ...  6  6  6]
   [ 6  6  6 ...  6  6  6]
   [ 6  6  6 ...  6  6  6]
   ...
   [ 6  6  6 ...  6  6  6]
   [ 6  6  6 ...  6  6  6]
   [ 6  6  6 ...  6  6  6]]

  [[ 7  7  7 ...  7  7  7]
   [ 7  7  7 ...  7  7  7]
   [ 7  7  7 ...  7  7  7]
   ...
   [ 7  7  7 ...  7  7  7]
   [ 7  7  7 ...  7  7  7]
   [ 7  7  7 ...  7  7  7]]

  [[ 8  8  8 ...  8  8  8]
   [ 8  8  8 ...  8  8  8]
   [ 8  8  8 ...  8  8  8]
   ...
   [ 8  8  8 ...  8  8  8]
   [ 8  8  8 ...  8  8  8]
   [ 8  8  8 ...  8  8  8]]]


 [[[ 9  9  9 ...  9  9  9]
   [ 9  9  9 ...  9  9  9]
   [ 9  9  9 ...  9  9  9]
   ...
   [ 9  9  9 ...  9  9  9]
   [ 9  9  9 ...  9  9  9]
   [ 9  9  9 ...  9  9  9]]

  [[10 10 10 ... 10 10 10]
   [10 10 10 ... 10 10 10]
   [10 10 10 ... 10 10 10]
   ...
   [10 10 10 ... 10 10 10]
   [10 10 10 ... 10 10 10]
   [10 10 10 ... 10 10 10]]

  [[11 11 11 ... 11 11 11]
   [11 11 11 ... 11 11 11]
   [11 11 11 ... 11 11 11]
   ...
   [11 11 11 ... 11 11 11]
   [11 11 11 ... 11 11 11]
   [11 11 11 ... 11 11 11]]

  [[12 12 12 ... 12 12 12]
   [12 12 12 ... 12 12 12]
   [12 12 12 ... 12 12 12]
   ...
   [12 12 12 ... 12 12 12]
   [12 12 12 ... 12 12 12]
   [12 12 12 ... 12 12 12]]]], shape=(3, 4, 16, 16), dtype=uint8)
tf.Tensor(
[[[[ 1  1  1 ...  1  1  1]
   [ 1  1  1 ...  1  1  1]
   [ 1  1  1 ...  1  1  1]
   ...
   [ 1  1  1 ...  1  1  1]
   [ 1  1  1 ...  1  1  1]
   [ 1  1  1 ...  1  1  1]]

  [[ 2  2  2 ...  2  2  2]
   [ 2  2  2 ...  2  2  2]
   [ 2  2  2 ...  2  2  2]
   ...
   [ 2  2  2 ...  2  2  2]
   [ 2  2  2 ...  2  2  2]
   [ 2  2  2 ...  2  2  2]]

  [[ 3  3  3 ...  3  3  3]
   [ 3  3  3 ...  3  3  3]
   [ 3  3  3 ...  3  3  3]
   ...
   [ 3  3  3 ...  3  3  3]
   [ 3  3  3 ...  3  3  3]
   [ 3  3  3 ...  3  3  3]]

  [[ 4  4  4 ...  4  4  4]
   [ 4  4  4 ...  4  4  4]
   [ 4  4  4 ...  4  4  4]
   ...
   [ 4  4  4 ...  4  4  4]
   [ 4  4  4 ...  4  4  4]
   [ 4  4  4 ...  4  4  4]]]


 [[[ 5  5  5 ...  5  5  5]
   [ 5  5  5 ...  5  5  5]
   [ 5  5  5 ...  5  5  5]
   ...
   [ 5  5  5 ...  5  5  5]
   [ 5  5  5 ...  5  5  5]
   [ 5  5  5 ...  5  5  5]]

  [[ 6  6  6 ...  6  6  6]
   [ 6  6  6 ...  6  6  6]
   [ 6  6  6 ...  6  6  6]
   ...
   [ 6  6  6 ...  6  6  6]
   [ 6  6  6 ...  6  6  6]
   [ 6  6  6 ...  6  6  6]]

  [[ 7  7  7 ...  7  7  7]
   [ 7  7  7 ...  7  7  7]
   [ 7  7  7 ...  7  7  7]
   ...
   [ 7  7  7 ...  7  7  7]
   [ 7  7  7 ...  7  7  7]
   [ 7  7  7 ...  7  7  7]]

  [[ 8  8  8 ...  8  8  8]
   [ 8  8  8 ...  8  8  8]
   [ 8  8  8 ...  8  8  8]
   ...
   [ 8  8  8 ...  8  8  8]
   [ 8  8  8 ...  8  8  8]
   [ 8  8  8 ...  8  8  8]]]


 [[[ 9  9  9 ...  9  9  9]
   [ 9  9  9 ...  9  9  9]
   [ 9  9  9 ...  9  9  9]
   ...
   [ 9  9  9 ...  9  9  9]
   [ 9  9  9 ...  9  9  9]
   [ 9  9  9 ...  9  9  9]]

  [[10 10 10 ... 10 10 10]
   [10 10 10 ... 10 10 10]
   [10 10 10 ... 10 10 10]
   ...
   [10 10 10 ... 10 10 10]
   [10 10 10 ... 10 10 10]
   [10 10 10 ... 10 10 10]]

  [[11 11 11 ... 11 11 11]
   [11 11 11 ... 11 11 11]
   [11 11 11 ... 11 11 11]
   ...
   [11 11 11 ... 11 11 11]
   [11 11 11 ... 11 11 11]
   [11 11 11 ... 11 11 11]]

  [[12 12 12 ... 12 12 12]
   [12 12 12 ... 12 12 12]
   [12 12 12 ... 12 12 12]
   ...
   [12 12 12 ... 12 12 12]
   [12 12 12 ... 12 12 12]
   [12 12 12 ... 12 12 12]]]], shape=(3, 4, 16, 16), dtype=uint8)

In [5]:
# Create trajectories with 4 frames stacked together, 2 frames shared between
# stacks and create sequences of 3 stacks. Note that since we stack the frames
# BEFORE sending them to Reverb, most frames will be stored twice, resulting in
# roughly double the storage (before compression is applied).
#
# For example, the first 12 steps will result in the following 3 samplable
# items:
#
#   [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
#
#     -> [[1, 2, 3, 4], [3, 4, 5, 6], [5, 6, 7, 8]]
#     -> [[3, 4, 5, 6], [5, 6, 7, 8], [7, 8, 9, 10]]
#     -> [[5, 6, 7, 8], [7, 8, 9, 10], [9, 10, 11, 12]]
#

store_stacked(stack_size=4, stride=2, sequence_length=3)


tf.Tensor(
[[[[1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]
   ...
   [1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]]

  [[2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   ...
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]]

  [[3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   ...
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]]

  [[4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   ...
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]]]


 [[[3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   ...
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]]

  [[4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   ...
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]]

  [[5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   ...
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]]

  [[6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   ...
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]]]


 [[[5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   ...
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]]

  [[6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   ...
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]]

  [[7 7 7 ... 7 7 7]
   [7 7 7 ... 7 7 7]
   [7 7 7 ... 7 7 7]
   ...
   [7 7 7 ... 7 7 7]
   [7 7 7 ... 7 7 7]
   [7 7 7 ... 7 7 7]]

  [[8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]
   ...
   [8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]]]], shape=(3, 4, 16, 16), dtype=uint8)
tf.Tensor(
[[[[1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]
   ...
   [1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]]

  [[2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   ...
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]]

  [[3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   ...
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]]

  [[4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   ...
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]]]


 [[[3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   ...
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]]

  [[4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   ...
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]]

  [[5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   ...
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]]

  [[6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   ...
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]]]


 [[[5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   ...
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]]

  [[6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   ...
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]]

  [[7 7 7 ... 7 7 7]
   [7 7 7 ... 7 7 7]
   [7 7 7 ... 7 7 7]
   ...
   [7 7 7 ... 7 7 7]
   [7 7 7 ... 7 7 7]
   [7 7 7 ... 7 7 7]]

  [[8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]
   ...
   [8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]]]], shape=(3, 4, 16, 16), dtype=uint8)

In [6]:
# Create trajectories with 2 frames stacked together, a stride of 3 between
# stacks and sequences of 3 stacks. Note that this means that some frames will
# be dropped.
#
# For example, the first 12 steps will result in the following 2 samplable
# items:
#
#   [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
#
#     -> [[2, 3], [5, 6], [8, 9]]
#     -> [[5, 6], [8, 9], [11, 12]]
#

store_stacked(stack_size=2, stride=3, sequence_length=3)


tf.Tensor(
[[[[2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   ...
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]]

  [[3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   ...
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]]]


 [[[5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   ...
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]]

  [[6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   ...
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]]]


 [[[8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]
   ...
   [8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]]

  [[9 9 9 ... 9 9 9]
   [9 9 9 ... 9 9 9]
   [9 9 9 ... 9 9 9]
   ...
   [9 9 9 ... 9 9 9]
   [9 9 9 ... 9 9 9]
   [9 9 9 ... 9 9 9]]]], shape=(3, 2, 16, 16), dtype=uint8)
tf.Tensor(
[[[[2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   ...
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]]

  [[3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   ...
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]]]


 [[[5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   ...
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]
   [5 5 5 ... 5 5 5]]

  [[6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   ...
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]
   [6 6 6 ... 6 6 6]]]


 [[[8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]
   ...
   [8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]
   [8 8 8 ... 8 8 8]]

  [[9 9 9 ... 9 9 9]
   [9 9 9 ... 9 9 9]
   [9 9 9 ... 9 9 9]
   ...
   [9 9 9 ... 9 9 9]
   [9 9 9 ... 9 9 9]
   [9 9 9 ... 9 9 9]]]], shape=(3, 2, 16, 16), dtype=uint8)

Store flat and stack when sampled

If there is overlap between the stacks (or between sampled trajectories), it is probably more efficient to store flat sequences of frames and perform the frame stacking after the data has been sampled. Consider, for example, a trajectory with the following data:

[[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]]

If each frame has size B then the total size of the trajectory is 4 * 3 * B = 12 * B. This cost is paid both in memory and in network traffic every time the data is transmitted.

Even though the total size is 12 * B, the trajectory only contains 6 * B of distinct data. We could therefore send [1, 2, 3, 4, 5, 6] and, with some processing on the receiver side, achieve the same result.

In the general case, assuming maximum overlap, the length $L_{flat}$ of the flat sequence needed to construct a stacked sequence of length $L_{stacked}$ with $H$ frames in each stack is:

$L_{flat} = L_{stacked} + H - 1$

For the example this becomes 4 + 3 - 1 = 6.
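
As a quick sanity check, the sketch below (NumPy only, independent of Reverb) rebuilds the stacked trajectory from the flat sequence by slicing overlapping windows of $H$ frames:


In [0]:
# Illustrative sketch only (not part of the Reverb example): reconstruct the
# stacked trajectory from a flat sequence by taking overlapping windows of
# `stack_size` frames.
flat = np.arange(1, 7)  # L_flat = 6
stack_size = 3          # H = 3
num_stacks = len(flat) - stack_size + 1  # L_stacked = 4
print(np.stack([flat[i:i + stack_size] for i in range(num_stacks)]))
# Expected output:
# [[1 2 3]
#  [2 3 4]
#  [3 4 5]
#  [4 5 6]]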


In [0]:
def store_flat(stack_size: int, sequence_length: int):
  """Simple example where frames are sent to Reverb and stacked after sampled.

  Args:
    stack_size: The number of frames to stack.
    sequence_length: The number of stacks in each samplable item.
  """
  server = reverb.Server(
      tables=[
              reverb.Table(
                  name='flat_frames',
                  sampler=reverb.selectors.Fifo(),
                  remover=reverb.selectors.Fifo(),
                  max_size=100,
                  rate_limiter=reverb.rate_limiters.MinSize(10),
              ),
      ],
  )
  client = reverb.Client(f'localhost:{server.port}')

  # Insert flat sequences that can be stacked into the desired shape after
  # sampling.
  flat_sequence_length = sequence_length + stack_size - 1
  with client.writer(max_sequence_length=flat_sequence_length) as writer:
    for i, frame in enumerate(frame_generator(flat_sequence_length * 5)):
      writer.append(frame)

      if i + 1 >= flat_sequence_length:
        writer.create_item(table='flat_frames',
                           num_timesteps=flat_sequence_length,
                           priority=1.0)

  # Create a dataset that samples sequences of flat frames.
  flat_dataset = reverb.ReplayDataset(
      server_address=client.server_address,
      table='flat_frames',
      max_in_flight_samples_per_worker=10,
      dtypes=tf.as_dtype(FRAME_DTYPE),
      shapes=tf.TensorShape((flat_sequence_length,) + FRAME_SHAPE),
      sequence_length=flat_sequence_length,
      emit_timesteps=False)

  # Create a transformation that stacks the frames.
  def _stack(sample):
    stacks = []
    for i in range(sequence_length):
      stacks.append(sample.data[i:i+stack_size])
    return reverb.ReplaySample(
        info=sample.info,
        data=tf.stack(stacks))

  stacked_dataset = flat_dataset.map(_stack)

  # Print the result.
  for sequence in stacked_dataset.take(2):
    print(sequence.data)

In [8]:
# Create trajectories of 3 stacks each with 2 frames stacked together. The data
# is stored as a flat sequence and then stacked when sampled.
#
# For example, the first 6 steps will result in the following 3 sequences:
#
#   [1, 2, 3, 4, 5, 6]
#
#     -> [1, 2, 3, 4] -> [[1, 2], [2, 3], [3, 4]]
#     -> [2, 3, 4, 5] -> [[2, 3], [3, 4], [4, 5]]
#     -> [3, 4, 5, 6] -> [[3, 4], [4, 5], [5, 6]]
#

store_flat(stack_size=2, sequence_length=3)


tf.Tensor(
[[[[1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]
   ...
   [1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]]

  [[2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   ...
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]]]


 [[[2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   ...
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]]

  [[3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   ...
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]]]


 [[[3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   ...
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]]

  [[4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   ...
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]]]], shape=(3, 2, 16, 16), dtype=uint8)
tf.Tensor(
[[[[1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]
   ...
   [1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]
   [1 1 1 ... 1 1 1]]

  [[2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   ...
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]]]


 [[[2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   ...
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]
   [2 2 2 ... 2 2 2]]

  [[3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   ...
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]]]


 [[[3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   ...
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]
   [3 3 3 ... 3 3 3]]

  [[4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   ...
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]
   [4 4 4 ... 4 4 4]]]], shape=(3, 2, 16, 16), dtype=uint8)