circular_replay_buffer.WrappedReplayBuffer

Class WrappedReplayBuffer

Wrapper of OutOfGraphReplayBuffer with an in-graph sampling mechanism.

Usage: To add a transition, call the add method.

To sample a batch: Construct operations that depend on any of the tensors in the
transition dictionary. Every sess.run that requires any of these tensors will
sample a new batch of transitions.
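The add-then-sample usage pattern can be illustrated without TensorFlow. The sketch below is a hypothetical, pure-Python stand-in (`ToyReplay` is not part of Dopamine): `add` stores a transition, and every call to `sample` draws a fresh batch, loosely mirroring how each sess.run that touches the transition tensors triggers a new sample.

```python
import random


class ToyReplay:
    """Hypothetical in-memory stand-in for the add/sample usage pattern."""

    def __init__(self, batch_size, seed=0):
        self.batch_size = batch_size
        self.transitions = []
        self.rng = random.Random(seed)

    def add(self, observation, action, reward, terminal):
        # Store one transition; next_observation is not passed.
        self.transitions.append((observation, action, reward, terminal))

    def sample(self):
        # Each call draws a fresh batch (with replacement), analogous to each
        # sess.run on the sampling tensors producing a new batch.
        return [self.rng.choice(self.transitions) for _ in range(self.batch_size)]


replay = ToyReplay(batch_size=4)
for t in range(10):
    replay.add(observation=t, action=t % 2, reward=1.0, terminal=0)
first_batch = replay.sample()
second_batch = replay.sample()
```

In the real class, sampling happens inside the graph, so there is no explicit `sample` call; the analogy is only to the "one run, one new batch" behavior.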

Methods

__init__

  __init__(
      *args,
      **kwargs
  )

Initializes WrappedReplayBuffer.

Args:

  • observation_shape: tuple or int. If int, the observation is assumed
    to be a 2D square.
  • stack_size: int, number of frames to use in state stack.
  • use_staging: bool, when True, uses a staging area to prefetch the
    next sampling batch.
  • replay_capacity: int, number of transitions to keep in memory.
  • batch_size: int, number of transitions in each sampled batch.
  • update_horizon: int, length of update (‘n’ in n-step update).
  • gamma: float, the discount factor.
  • wrapped_memory: The ‘inner’ memory data structure. If None, it
    creates the standard DQN replay memory.
  • max_sample_attempts: int, the maximum number of attempts allowed to
    get a sample.
  • extra_storage_types: list of ReplayElements defining the type of
    the extra contents that will be stored and returned by
    sample_transition_batch.
  • observation_dtype: np.dtype, type of the observations. Defaults to
    np.uint8 for Atari 2600.
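To make update_horizon and gamma concrete: in an n-step update, the reward attached to a sampled transition is the discounted sum of up to n consecutive rewards, cut off after a terminal step. The function below is a hypothetical sketch of that computation (`n_step_return` is not a Dopamine function) and assumes rewards and terminal flags are given as plain lists.

```python
def n_step_return(rewards, terminals, start, update_horizon, gamma):
    """Discounted sum of up to `update_horizon` rewards beginning at `start`.

    Stops after the first terminal step; that step's reward is included.
    """
    total = 0.0
    for k in range(update_horizon):
        i = start + k
        if i >= len(rewards):
            break
        total += (gamma ** k) * rewards[i]
        if terminals[i]:
            break
    return total
```

With update_horizon=1 this reduces to the single-step reward at `start`; for example, three rewards of 1 with gamma=0.5 give 1 + 0.5 + 0.25 = 1.75.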

Raises:

  • ValueError: If update_horizon is not positive.
  • ValueError: If discount factor is not in [0, 1].
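The documented argument rules can be expressed as a small validation helper. This is a hypothetical sketch (`check_replay_args` is illustrative, not part of the constructor's API): it raises the two ValueErrors listed above and expands an int observation_shape into a square 2D shape.

```python
def check_replay_args(observation_shape, update_horizon, gamma):
    """Hypothetical sketch of the documented argument checks.

    Returns observation_shape as a tuple; an int n becomes (n, n).
    """
    if update_horizon < 1:
        raise ValueError("update_horizon must be positive")
    if not 0.0 <= gamma <= 1.0:
        raise ValueError("gamma (discount factor) must be in [0, 1]")
    if isinstance(observation_shape, int):
        # An int observation_shape describes a 2D square observation.
        observation_shape = (observation_shape, observation_shape)
    return tuple(observation_shape)
```

For example, `check_replay_args(84, 1, 0.99)` returns `(84, 84)`, while an update_horizon of 0 or a gamma of 1.5 raises ValueError.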

add

  add(
      observation,
      action,
      reward,
      terminal,
      *args
  )

Adds a transition to the replay memory.

Since the next_observation in the transition will be the observation added next,
there is no need to pass it.

If the replay memory is at capacity the oldest transition will be discarded.

Args:

  • observation: np.array with shape observation_shape.
  • action: int, the action in the transition.
  • reward: float, the reward received in the transition.
  • terminal: A uint8 acting as a boolean indicating whether the
    transition was terminal (1) or not (0).
  • *args: extra contents with shapes and dtypes according to
    extra_storage_types.
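Both properties of add, discarding the oldest transition at capacity and not needing next_observation, follow from storing transitions in a fixed-size ring indexed by `add_count % capacity`. The class below is a simplified, hypothetical sketch (`RingBuffer` is not Dopamine's OutOfGraphReplayBuffer; frame stacking, extra storage types, and the invalid-index bookkeeping a real buffer needs are omitted).

```python
class RingBuffer:
    """Hypothetical fixed-capacity circular store of transitions."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.observations = [None] * capacity
        self.actions = [None] * capacity
        self.rewards = [None] * capacity
        self.terminals = [None] * capacity
        self.add_count = 0

    def add(self, observation, action, reward, terminal):
        # Write into the next slot; once full, this overwrites the oldest entry.
        i = self.add_count % self.capacity
        self.observations[i] = observation
        self.actions[i] = action
        self.rewards[i] = reward
        self.terminals[i] = terminal
        self.add_count += 1

    def transition(self, i):
        # next_observation is the observation in the following slot.
        # Only meaningful if slot i is not the newest one (it has no successor yet).
        next_obs = self.observations[(i + 1) % self.capacity]
        return (self.observations[i], self.actions[i], self.rewards[i],
                next_obs, self.terminals[i])
```

With capacity 3, adding observations 0, 1, 2, 3 leaves the slots holding `[3, 1, 2]`: observation 0 was overwritten, and `transition(1)` pairs observation 1 with next observation 2.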

create_sampling_ops

  create_sampling_ops(use_staging)

Creates the ops necessary to sample from the replay buffer.

Creates the transition dictionary containing the sampling tensors.

Args:

  • use_staging: bool, when True, uses a staging area to prefetch the
    next sampling batch.
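With staging enabled, each fetch hands back a batch that was prepared one step earlier and immediately stages the next one. The class below is a hypothetical pure-Python illustration of that one-batch-ahead prefetch (`PrefetchingSampler` is not a Dopamine or TensorFlow API). It runs sequentially, so unlike a graph-level staging area it does not actually overlap sampling with training; it only shows the ordering.

```python
class PrefetchingSampler:
    """Hypothetical one-batch-ahead prefetcher wrapping a sampling function."""

    def __init__(self, sample_fn):
        self.sample_fn = sample_fn
        self.staged = None  # batch prepared ahead of time

    def next_batch(self):
        # On the first call nothing is staged yet, so stage an initial batch.
        if self.staged is None:
            self.staged = self.sample_fn()
        batch = self.staged
        # Stage the following batch before returning the current one.
        self.staged = self.sample_fn()
        return batch
```

Wrapping a counter as the sampling function, the first call returns batch 0 while batch 1 is already staged.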

load

  load(
      checkpoint_dir,
      suffix
  )

Loads the replay buffer’s state from a saved file.

Args:

  • checkpoint_dir: str, the directory from which to read the numpy
    checkpoint files.
  • suffix: str, the suffix to use in numpy checkpoint files.
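The role of suffix is to select which set of checkpoint files in checkpoint_dir to read. The sketch below is hypothetical: it assumes a `<name>_ckpt.<suffix>.gz` naming scheme and stores plain lists as gzipped JSON instead of NumPy arrays, so `checkpoint_path` and `load_buffer_state` are illustrative helpers, not the library's loader.

```python
import gzip
import json
import os
import tempfile


def checkpoint_path(checkpoint_dir, name, suffix):
    # Hypothetical naming scheme: '<name>_ckpt.<suffix>.gz'.
    return os.path.join(checkpoint_dir, "{}_ckpt.{}.gz".format(name, suffix))


def load_buffer_state(checkpoint_dir, suffix, names):
    """Hypothetical loader: reads one gzipped JSON file per stored array."""
    state = {}
    for name in names:
        with gzip.open(checkpoint_path(checkpoint_dir, name, suffix), "rt") as f:
            state[name] = json.load(f)
    return state


# Usage: write one checkpoint file by hand, then load it back with suffix 7.
demo_dir = tempfile.mkdtemp()
with gzip.open(checkpoint_path(demo_dir, "reward", 7), "wt") as f:
    json.dump([0.0, 1.0, 0.5], f)
loaded = load_buffer_state(demo_dir, 7, ["reward"])
```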

save

  save(
      checkpoint_dir,
      iteration_number
  )

Saves the underlying replay buffer’s contents to checkpoint files.

Args:

  • checkpoint_dir: str, the directory to which the numpy checkpoint
    files are written.
  • iteration_number: int, the iteration_number to use as a suffix in
    naming numpy checkpoint files.
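Because iteration_number becomes the filename suffix, passing the same value as suffix to load finds the files written here. The sketch below is hypothetical: `save_buffer_state` writes each stored list to its own gzipped JSON file named `<name>_ckpt.<iteration_number>.gz`, an assumed layout standing in for the buffer's NumPy checkpoint files.

```python
import gzip
import json
import os
import tempfile


def save_buffer_state(checkpoint_dir, iteration_number, state):
    """Hypothetical saver: one gzipped JSON file per stored array.

    The iteration number is used as the filename suffix.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    paths = []
    for name, values in state.items():
        path = os.path.join(
            checkpoint_dir, "{}_ckpt.{}.gz".format(name, iteration_number))
        with gzip.open(path, "wt") as f:
            json.dump(values, f)
        paths.append(path)
    return paths


# Usage: save two arrays for iteration 3 into a temporary directory.
demo_dir = tempfile.mkdtemp()
written = save_buffer_state(demo_dir, 3, {"action": [0, 2, 1], "terminal": [0, 0, 1]})
```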

unpack_transition

  unpack_transition(
      transition_tensors,
      transition_type
  )

Unpacks the given transition into member variables.

Args:

  • transition_tensors: tuple of tf.Tensors.
  • transition_type: tuple of ReplayElements matching
    transition_tensors.
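Unpacking pairs each tensor with the ReplayElement that describes it and exposes the results by name. The sketch below is hypothetical: `ReplayElement` here is a stand-in namedtuple with name, shape, and type fields, `TransitionHolder` is illustrative, plain strings stand in for tf.Tensors, and only a few member variables are shown.

```python
import collections

# Hypothetical stand-in for a ReplayElement: a name plus shape/type metadata.
ReplayElement = collections.namedtuple("ReplayElement", ["name", "shape", "type"])


class TransitionHolder:
    """Hypothetical holder showing how unpacking maps tensors to names."""

    def unpack_transition(self, transition_tensors, transition_type):
        if len(transition_tensors) != len(transition_type):
            raise ValueError("transition_tensors and transition_type differ in length")
        # Key each tensor by the name of the element that describes it.
        self.transition = collections.OrderedDict(
            (element_type.name, tensor)
            for tensor, element_type in zip(transition_tensors, transition_type))
        # Expose commonly used entries as member variables.
        self.states = self.transition["state"]
        self.actions = self.transition["action"]
        self.rewards = self.transition["reward"]


types = (ReplayElement("state", (84, 84, 4), "uint8"),
         ReplayElement("action", (), "int32"),
         ReplayElement("reward", (), "float32"))
holder = TransitionHolder()
holder.unpack_transition(("S", "A", "R"), types)
```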