
Class WrappedReplayBuffer

Wrapper of OutOfGraphReplayBuffer with an in-graph sampling mechanism.

Usage: To add a transition: call the add function.

To sample a batch: Construct operations that depend on any of the tensors in the
transition dictionary. Every sess.run call that requires any of these tensors
will sample a new transition.



  __init__(
      *args,
      **kwargs
  )

Initializes WrappedReplayBuffer.


  • observation_shape: tuple or int. If int, the observation is assumed
    to be a 2D square.
  • stack_size: int, number of frames to use in state stack.
  • use_staging: bool, when True, a staging area is used to prefetch
    the next sampling batch.
  • replay_capacity: int, number of transitions to keep in memory.
  • batch_size: int.
  • update_horizon: int, length of update (‘n’ in n-step update).
  • gamma: float, the discount factor.
  • wrapped_memory: The ‘inner’ memory data structure. If None, it
    creates the standard DQN replay memory.
  • max_sample_attempts: int, the maximum number of attempts allowed to
    get a sample.
  • extra_storage_types: list of ReplayElements defining the type of
    the extra contents that will be stored and returned by
  • observation_dtype: np.dtype, type of the observations. Defaults to
    np.uint8 for Atari 2600.


  • ValueError: If update_horizon is not positive.
  • ValueError: If discount factor is not in [0, 1].
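To make update_horizon and gamma concrete, here is a minimal sketch of the two documented checks and of the discounted n-step reward sum that these parameters describe. The helper names validate_hyperparameters and n_step_return are hypothetical and are not part of the library; the exact checks the library performs may differ.

```python
def validate_hyperparameters(update_horizon, gamma):
    """Mirror the two documented checks (hypothetical helper, not library code)."""
    if update_horizon < 1:
        raise ValueError(
            "update_horizon must be positive, got %r" % (update_horizon,))
    if not 0.0 <= gamma <= 1.0:
        raise ValueError("gamma must be in [0, 1], got %r" % (gamma,))


def n_step_return(rewards, gamma):
    """Discounted sum r_0 + gamma * r_1 + ... over len(rewards) steps."""
    total = 0.0
    for k, reward in enumerate(rewards):
        total += (gamma ** k) * reward
    return total


# With update_horizon = 3, three consecutive rewards are combined.
validate_hyperparameters(update_horizon=3, gamma=0.5)
three_step = n_step_return([1.0, 1.0, 1.0], gamma=0.5)
```

Here the rewards list has length update_horizon, so a larger horizon folds more future rewards into each stored update target.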


  add(
      observation,
      action,
      reward,
      terminal,
      *args
  )

Adds a transition to the replay memory.

Since the next_observation in the transition will be the observation added
next, there is no need to pass it.

If the replay memory is at capacity, the oldest transition will be discarded.


  • observation: np.array with shape observation_shape.
  • action: int, the action in the transition.
  • reward: float, the reward received in the transition.
  • terminal: A uint8 acting as a boolean indicating whether the
    transition was terminal (1) or not (0).
  • *args: extra contents with shapes and dtypes according to
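The two properties described for add (the next observation is implied by insertion order, and the oldest transition is dropped at capacity) can be illustrated with a toy first-in, first-out store. This is only a sketch under simplifying assumptions: TinyReplay and Transition are hypothetical names, terminal handling and frame stacking are ignored, and the real buffer's storage layout is not shown here.

```python
import collections

Transition = collections.namedtuple(
    "Transition", ["observation", "action", "reward", "terminal"])


class TinyReplay(object):
    """Toy FIFO store; a hypothetical stand-in, not the wrapped buffer."""

    def __init__(self, replay_capacity):
        # A deque with maxlen drops its oldest entry once capacity is reached.
        self._storage = collections.deque(maxlen=replay_capacity)

    def add(self, observation, action, reward, terminal):
        self._storage.append(
            Transition(observation, action, reward, terminal))

    def observation(self, index):
        return self._storage[index].observation

    def next_observation(self, index):
        # The successor is simply the observation that was added next.
        return self._storage[index + 1].observation

    def __len__(self):
        return len(self._storage)


replay = TinyReplay(replay_capacity=3)
for step in range(4):
    # Observations are plain ints here to keep the example small.
    replay.add(observation=step, action=0, reward=1.0, terminal=0)
```

After the fourth call, the store still holds three entries, and the entry with observation 0 has been discarded.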


  create_sampling_ops(use_staging)

Creates the ops necessary to sample from the replay buffer.

Creates the transition dictionary containing the sampling tensors.


  • use_staging: bool, when True, a staging area is used to prefetch
    the next sampling batch.


  load(
      checkpoint_dir,
      suffix
  )

Loads the replay buffer’s state from a saved file.


  • checkpoint_dir: str, the directory from which to read the numpy
    checkpoint files.
  • suffix: str, the suffix to use in numpy checkpoint files.


  save(
      checkpoint_dir,
      iteration_number
  )

Saves the underlying replay buffer’s contents to a file.


  • checkpoint_dir: str, the directory to which the numpy checkpoint
    files are written.
  • iteration_number: int, the iteration_number to use as a suffix in
    naming numpy checkpoint files.
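The relationship between save and load (the iteration_number used when saving is the suffix passed when loading) can be sketched as a round trip. Everything below is an assumption made for illustration: the file naming scheme, the use of pickle rather than numpy files, and the helper names checkpoint_path, save_members, and load_members are hypothetical and do not describe the library's on-disk format.

```python
import os
import pickle
import tempfile


def checkpoint_path(checkpoint_dir, name, suffix):
    # Hypothetical naming scheme: <name>_ckpt.<suffix>
    return os.path.join(checkpoint_dir, "{}_ckpt.{}".format(name, suffix))


def save_members(members, checkpoint_dir, iteration_number):
    """Write each named member to its own file, tagged with the iteration."""
    for name, value in members.items():
        path = checkpoint_path(checkpoint_dir, name, iteration_number)
        with open(path, "wb") as f:
            pickle.dump(value, f)


def load_members(names, checkpoint_dir, suffix):
    """Read back the members that were saved under the given suffix."""
    members = {}
    for name in names:
        path = checkpoint_path(checkpoint_dir, name, suffix)
        with open(path, "rb") as f:
            members[name] = pickle.load(f)
    return members


checkpoint_dir = tempfile.mkdtemp()
state = {"reward": [1.0, 0.0], "add_count": 2}
save_members(state, checkpoint_dir, iteration_number=7)
restored = load_members(["reward", "add_count"], checkpoint_dir, suffix=7)
```

Tagging files with the iteration number lets several checkpoints coexist in one directory, so a caller can restore whichever iteration it needs.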


  unpack_transition(
      transition_tensors,
      transition_type
  )

Unpacks the given transition into member variables.


  • transition_tensors: tuple of tf.Tensors.
  • transition_type: tuple of ReplayElements matching
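One way such unpacking could work is to pair each tensor with its matching element and key it by name, which yields a transition dictionary like the one mentioned in the usage notes. This is a sketch, not the library's implementation: the ReplayElement namedtuple below is a hypothetical stand-in with assumed fields (name, shape, type), and plain strings stand in for tf.Tensors.

```python
import collections

# Hypothetical stand-in for the library's ReplayElement type.
ReplayElement = collections.namedtuple(
    "ReplayElement", ["name", "shape", "type"])


def unpack_transition(transition_tensors, transition_type):
    """Pair each tensor with the name of its matching element."""
    if len(transition_tensors) != len(transition_type):
        raise ValueError("tensors and types must have the same length")
    transition = collections.OrderedDict()
    for tensor, element in zip(transition_tensors, transition_type):
        transition[element.name] = tensor
    return transition


# Shapes assume a batch of 32 stacked 84x84 frames, purely as an example.
types = (
    ReplayElement("state", (32, 84, 84, 4), "uint8"),
    ReplayElement("action", (32,), "int32"),
    ReplayElement("reward", (32,), "float32"),
)
transition = unpack_transition(("s", "a", "r"), types)
```

Keeping the result ordered preserves the correspondence between positions in the two tuples and keys in the dictionary.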