prioritized_replay_buffer.WrappedPrioritizedReplayBuffer

Class WrappedPrioritizedReplayBuffer

Inherits From:
WrappedReplayBuffer

Wrapper of OutOfGraphPrioritizedReplayBuffer with in-graph sampling.

Usage:

  • To add a transition: Call the add function.

  • To sample a batch: Query any of the tensors in the transition dictionary.
    Every sess.run that requires any of these tensors will sample a new
    transition (see the sketch below).
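
For example, a minimal sketch of both steps (TF1-style; the constructor keywords follow the Args list under __init__ below, and the trailing 1.0 passed to add is assumed to be the per-transition priority that the prioritized buffer stores as an extra element):

  import numpy as np
  import tensorflow.compat.v1 as tf

  from dopamine.replay_memory import prioritized_replay_buffer

  tf.disable_v2_behavior()  # the wrapper builds TF1-style graph ops

  replay = prioritized_replay_buffer.WrappedPrioritizedReplayBuffer(
      observation_shape=(84, 84), stack_size=4, use_staging=False,
      replay_capacity=100000, batch_size=32)

  # Fill the buffer with a few dummy transitions; the final 1.0 is the
  # transition's priority.
  for _ in range(100):
    replay.add(np.zeros((84, 84), dtype=np.uint8), 0, 1.0, 0, 1.0)

  with tf.Session() as sess:
    # Each sess.run over the transition tensors samples a fresh batch.
    states, actions = sess.run(
        [replay.transition['state'], replay.transition['action']])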

Methods

init

  __init__(
      *args,
      **kwargs
  )

Initializes WrappedPrioritizedReplayBuffer.

Args:

  • observation_shape: tuple or int. If int, the observation is assumed
    to be a 2D square with sides equal to observation_shape.
  • stack_size: int, number of frames to use in state stack.
  • use_staging: bool, when True a staging area is used to prefetch the
    next sampling batch.
  • replay_capacity: int, number of transitions to keep in memory.
  • batch_size: int.
  • update_horizon: int, length of update (‘n’ in n-step update).
  • gamma: float, the discount factor.
  • max_sample_attempts: int, the maximum number of attempts allowed to
    get a sample.
  • extra_storage_types: list of ReplayElements defining the type of
    the extra contents that will be stored and returned by
    sample_transition_batch.
  • observation_dtype: np.dtype, type of the observations. Defaults to
    np.uint8 for Atari 2600.

Raises:

  • ValueError: If update_horizon is not positive.
  • ValueError: If discount factor is not in [0, 1].
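
A sketch of a fuller configuration (ReplayElement is assumed to be the namedtuple exposed by dopamine's circular_replay_buffer; the log_prob element is a hypothetical extra field):

  import numpy as np

  from dopamine.replay_memory import circular_replay_buffer
  from dopamine.replay_memory import prioritized_replay_buffer

  # Hypothetical extra element: store a behaviour-policy log-probability with
  # every transition; it is returned by sample_transition_batch.
  extra = [circular_replay_buffer.ReplayElement('log_prob', (), np.float32)]

  replay = prioritized_replay_buffer.WrappedPrioritizedReplayBuffer(
      observation_shape=(84, 84),  # or a single int for a square observation
      stack_size=4,
      use_staging=True,            # prefetch the next batch via a staging area
      replay_capacity=1000000,
      batch_size=32,
      update_horizon=3,            # 'n' in the n-step update
      gamma=0.99,
      extra_storage_types=extra,
      observation_dtype=np.uint8)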

add

  add(
      observation,
      action,
      reward,
      terminal,
      *args
  )

Adds a transition to the replay memory.

Since the next_observation in the transition will be the observation added next,
there is no need to pass it.

If the replay memory is at capacity, the oldest transition will be discarded.

Args:

  • observation: np.array with shape observation_shape.
  • action: int, the action in the transition.
  • reward: float, the reward received in the transition.
  • terminal: A uint8 acting as a boolean indicating whether the
    transition was terminal (1) or not (0).
  • *args: extra contents with shapes and dtypes according to
    extra_storage_types.
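
Continuing the Usage sketch (a buffer built without extra_storage_types), a single Atari-style transition might be stored as follows; the final value is the per-transition priority, which this prioritized buffer takes as the last extra element (some versions accept it as a keyword argument instead):

  observation = np.zeros((84, 84), dtype=np.uint8)  # shape == observation_shape
  replay.add(
      observation,  # next_observation is inferred from the following add call
      1,            # action
      0.5,          # reward
      0,            # terminal: 0 = not terminal, 1 = terminal
      1.0)          # per-transition priority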

create_sampling_ops

  create_sampling_ops(use_staging)

Creates the ops necessary to sample from the replay buffer.

Creates the transition dictionary containing the sampling tensors.

Args:

  • use_staging: bool, when True a staging area is used to prefetch the
    next sampling batch.
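
This method is called during construction, so it rarely needs to be invoked directly; the tensors it creates are read back through the transition dictionary. A brief sketch, given the buffer from the Usage section and assuming the key names used by Dopamine's wrapper:

  state_batch = replay.transition['state']
  index_batch = replay.transition['indices']

  with tf.Session() as sess:
    # Each run of these tensors triggers a new sample from the buffer.
    states, indices = sess.run([state_batch, index_batch])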

load

  load(
      checkpoint_dir,
      suffix
  )

Loads the replay buffer’s state from a saved file.

Args:

  • checkpoint_dir: str, the directory from which to read the numpy
    checkpoint files.
  • suffix: str, the suffix to use in numpy checkpoint files.
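
A brief sketch (the directory is hypothetical; the suffix should match the iteration_number used when the checkpoint was saved):

  replay.load('/tmp/replay_checkpoints', suffix='1000')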

save

  save(
      checkpoint_dir,
      iteration_number
  )

Saves the underlying replay buffer's contents to a file.

Args:

  • checkpoint_dir: str, the directory where the numpy checkpoint files
    will be written.
  • iteration_number: int, the iteration_number to use as a suffix in
    naming numpy checkpoint files.
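
A brief sketch (the directory is hypothetical):

  # Writes numpy checkpoint files suffixed with the iteration number; they can
  # later be restored with load(checkpoint_dir, suffix).
  replay.save('/tmp/replay_checkpoints', iteration_number=1000)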

tf_get_priority

  tf_get_priority(indices)

Gets the priorities for the given indices.

Args:

  • indices: tf.Tensor with dtype int32 and shape [n].

Returns:

  • priorities: tf.Tensor with dtype float and shape [n], the
    priorities at the indices.
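
For example, reading back the priorities of the most recently sampled batch (replay.indices is the tensor of sampled indices exposed by the wrapper; treat the attribute name as an assumption if your version differs):

  priorities = replay.tf_get_priority(replay.indices)

  with tf.Session() as sess:
    print(sess.run(priorities))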

tf_set_priority

  tf_set_priority(
      indices,
      priorities
  )

Sets the priorities for the given indices.

Args:

  • indices: tf.Tensor with dtype int32 and shape [n].
  • priorities: tf.Tensor with dtype float and shape [n].

Returns:

A tf op setting the priorities for prioritized sampling.
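
A typical prioritized-replay update recomputes priorities from the sampled batch's TD errors and runs the resulting op alongside the training step (td_errors is a hypothetical [batch_size] float tensor; replay.indices is the wrapper's tensor of sampled indices):

  new_priorities = tf.sqrt(tf.abs(td_errors) + 1e-10)
  update_priorities_op = replay.tf_set_priority(replay.indices, new_priorities)

  with tf.Session() as sess:
    sess.run(update_priorities_op)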

unpack_transition

  unpack_transition(
      transition_tensors,
      transition_type
  )

Unpacks the given transition into member variables.

Args:

  • transition_tensors: tuple of tf.Tensors.
  • transition_type: tuple of ReplayElements matching
    transition_tensors.
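
This is invoked internally when the sampling ops are created; afterwards the sampled tensors are reachable both through the transition dictionary and, in this wrapper, through convenience attributes. A brief sketch, assuming the names used by Dopamine's implementation:

  states = replay.transition['state']
  probs = replay.transition['sampling_probabilities']  # prioritized buffer only
  rewards = replay.rewards
  terminals = replay.terminals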