prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer

Class OutOfGraphPrioritizedReplayBuffer

Inherits From:
OutOfGraphReplayBuffer

An out-of-graph Replay Buffer for Prioritized Experience Replay.

See circular_replay_buffer.py for details.

Methods

__init__

  __init__(
      observation_shape,
      stack_size,
      replay_capacity,
      batch_size,
      update_horizon=1,
      gamma=0.99,
      max_sample_attempts=circular_replay_buffer.MAX_SAMPLE_ATTEMPTS,
      extra_storage_types=None,
      observation_dtype=np.uint8
  )

Initializes OutOfGraphPrioritizedReplayBuffer.

Args:

  • observation_shape: tuple or int. If int, the observation is assumed
    to be a 2D square with sides equal to observation_shape.
  • stack_size: int, number of frames to use in state stack.
  • replay_capacity: int, number of transitions to keep in memory.
  • batch_size: int, the default number of transitions in each sampled
    batch.
  • update_horizon: int, length of update (‘n’ in n-step update).
  • gamma: float, the discount factor.
  • max_sample_attempts: int, the maximum number of attempts allowed to
    get a sample.
  • extra_storage_types: list of ReplayElements defining the type of
    the extra contents that will be stored and returned by
    sample_transition_batch.
  • observation_dtype: np.dtype, type of the observations. Defaults to
    np.uint8 for Atari 2600.
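The update_horizon and gamma arguments combine into an n-step return. The following is a minimal, stdlib-only sketch of that quantity and of the int-to-square rule for observation_shape. normalize_observation_shape and n_step_return are hypothetical helpers, and stopping the sum at the first terminal step is the usual n-step convention rather than something this page spells out.

```python
def normalize_observation_shape(observation_shape):
    """Apply the documented rule: an int n stands for an n x n square."""
    if isinstance(observation_shape, int):
        return (observation_shape, observation_shape)
    return tuple(observation_shape)


def n_step_return(rewards, terminals, update_horizon=1, gamma=0.99):
    """Discounted sum of up to update_horizon rewards.

    rewards[k] and terminals[k] describe the k-th step starting at the sampled
    transition. Summation stops after the first terminal step so that rewards
    from the next episode do not leak into this return.
    """
    total = 0.0
    for k in range(min(update_horizon, len(rewards))):
        total += (gamma ** k) * rewards[k]
        if terminals[k]:
            break
    return total


# Example: a 3-step return with gamma = 0.5 and no terminal steps.
print(n_step_return([1.0, 1.0, 1.0], [0, 0, 0], update_horizon=3, gamma=0.5))  # prints 1.75
```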

add

  add(
      observation,
      action,
      reward,
      terminal,
      *args
  )

Adds a transition to the replay memory.

This function checks the types and handles the padding at the beginning of an
episode, then calls the _add function.

Since the next_observation in the transition will be the observation added next,
there is no need to pass it.

If the replay memory is at capacity, the oldest transition will be discarded.

Args:

  • observation: np.array with shape observation_shape.
  • action: int, the action in the transition.
  • reward: float, the reward received in the transition.
  • terminal: A uint8 acting as a boolean indicating whether the
    transition was terminal (1) or not (0).
  • *args: extra contents with shapes and dtypes according to
    extra_storage_types.
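A minimal sketch of the circular storage behind add() and cursor(), assuming a plain Python list in place of per-element numpy arrays. RingBufferSketch and its add_count counter are hypothetical; episode-start padding and type checks are omitted. For this class, get_add_args_signature adds a priority to the parent's arguments, so a priority would presumably travel through *args as well.

```python
class RingBufferSketch:
    """Hypothetical minimal ring buffer illustrating add(), cursor() and overwrite."""

    def __init__(self, replay_capacity):
        self._capacity = replay_capacity
        self._storage = [None] * replay_capacity
        self.add_count = 0  # total number of transitions ever added

    def cursor(self):
        # Index where the next transition will be written.
        return self.add_count % self._capacity

    def is_full(self):
        return self.add_count >= self._capacity

    def add(self, observation, action, reward, terminal, *args):
        # Only the current observation is stored; the next observation is
        # whatever gets added next, so it is never passed explicitly.
        # Once full, writing at cursor() overwrites the oldest transition.
        self._storage[self.cursor()] = (observation, action, reward, terminal) + args
        self.add_count += 1
```

With a capacity of 3, a fourth add() lands in slot 0 and replaces the very first transition, and cursor() moves on to slot 1.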

cursor

  cursor()

Index to the location where the next transition will be written.

get_add_args_signature

  get_add_args_signature()

The signature of the add function.

The signature is the same as the one for OutOfGraphReplayBuffer, with an added
priority.

Returns:

list of ReplayElements defining the type of the argument signature needed by the
add function.
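To make the signature concrete, here is a sketch that builds both lists. ReplayElement is modeled as a (name, shape, type) namedtuple, and the element names, the dtypes written as strings, and the trailing position of priority are assumptions for illustration, not a copy of the library's definitions.

```python
from collections import namedtuple

# Hypothetical stand-in for the ReplayElement type named in these docs.
ReplayElement = namedtuple("ReplayElement", ["name", "shape", "type"])


def parent_add_args_signature(observation_shape, extra_storage_types=()):
    # Assumed parent signature: the four documented add() arguments,
    # followed by any extra storage elements.
    return [
        ReplayElement("observation", observation_shape, "uint8"),
        ReplayElement("action", (), "int32"),
        ReplayElement("reward", (), "float32"),
        ReplayElement("terminal", (), "uint8"),
    ] + list(extra_storage_types)


def prioritized_add_args_signature(observation_shape, extra_storage_types=()):
    # Same as the parent signature, with a scalar priority appended.
    return parent_add_args_signature(observation_shape, extra_storage_types) + [
        ReplayElement("priority", (), "float32")
    ]
```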

get_observation_stack

  get_observation_stack(index)

get_priority

  get_priority(indices)

Fetches the priorities corresponding to a batch of memory indices.

For any memory location not yet used, the corresponding priority is 0.

Args:

  • indices: np.array of int32 indices in the range [0,
    replay_capacity).

Returns:

  • priorities: np.array of floats, the corresponding priorities.

get_range

  get_range(
      array,
      start_index,
      end_index
  )

Returns the slice of array from start_index to end_index, wrapping around the
buffer if necessary.

Args:

  • array: np.array, the array to get the stack from.
  • start_index: int, index of the start of the range to be returned.
    The range wraps around if start_index is smaller than 0.
  • end_index: int, exclusive end index. The range wraps around if
    end_index exceeds replay_capacity.

Returns:

np.array, with shape [end_index - start_index, array.shape[1:]].
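The wraparound rule can be sketched with modular indexing. This version takes a plain list and an explicit replay_capacity argument, whereas the documented method works on np.arrays and reads the capacity from the buffer; both differences are assumptions of the sketch.

```python
def get_range(array, start_index, end_index, replay_capacity):
    """Return the elements from start_index (inclusive) to end_index (exclusive),
    wrapping around replay_capacity in either direction."""
    assert end_index > start_index, "end_index must be greater than start_index"
    if start_index >= 0 and end_index <= replay_capacity:
        return array[start_index:end_index]  # fast path: contiguous slice
    # Python's % maps negative and overflowing positions back into
    # [0, replay_capacity), e.g. -2 % 5 == 3.
    return [array[i % replay_capacity] for i in range(start_index, end_index)]
```

The result always has end_index - start_index elements, matching the documented leading dimension.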

get_storage_signature

  get_storage_signature()

Returns a default list of elements to be stored in this replay memory.

Note: derived classes may return a different signature.

Returns:

list of ReplayElements defining the type of the contents stored.

get_terminal_stack

  get_terminal_stack(index)

get_transition_elements

  get_transition_elements(batch_size=None)

Returns a ‘type signature’ for sample_transition_batch.

Args:

  • batch_size: int, number of transitions returned. If None, the
    default batch_size will be used.

Returns:

  • signature: A namedtuple describing the method’s return type
    signature.

is_empty

  is_empty()

Is the Replay Buffer empty?

is_full

  is_full()

Is the Replay Buffer full?

is_valid_transition

  is_valid_transition(index)

Checks if the index contains a valid transition.

Checks for collisions with the end of episodes and the current position of the
cursor.

Args:

  • index: int, the index of the state in the transition.

Returns:

bool, whether the index contains a valid transition.
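The checks described above can be sketched as follows for the simple case of a buffer that has not yet wrapped around, where the collision with the cursor reduces to an upper bound on index. The function name and its explicit parameters are hypothetical; a full implementation would also have to reject transitions that straddle the cursor once the buffer is full.

```python
def is_valid_transition_sketch(index, terminals, cursor, stack_size, update_horizon):
    """Simplified validity check for a buffer that has not wrapped around yet.

    terminals[i] is 1 if the transition stored at slot i ended an episode.
    """
    # The n-step target needs update_horizon stored steps after index, and
    # all of them must already have been written (i.e. lie before the cursor).
    if index < 0 or index >= cursor - update_horizon:
        return False
    # The first stack_size - 1 slots cannot hold a complete frame stack.
    if index < stack_size - 1:
        return False
    # A terminal flag in any frame of the stack except the last one means the
    # stack would mix frames from two different episodes.
    if any(terminals[index - stack_size + 1:index]):
        return False
    return True
```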

load

  load(
      checkpoint_dir,
      suffix
  )

Restores the object from bundle_dictionary and numpy checkpoints.

Args:

  • checkpoint_dir: str, the directory from which to read the numpy
    checkpoint files.
  • suffix: str, the suffix to use in numpy checkpoint files.

Raises:

  • NotFoundError: If any of the expected files is missing from the
    directory.

sample_index_batch

  sample_index_batch(batch_size)

Returns a batch of valid indices sampled as in Schaul et al. (2015).

Args:

  • batch_size: int, number of indices returned.

Returns:

list of ints, a batch of valid indices sampled in proportion to their
priorities.

Raises:

  • Exception: If a valid batch could not be constructed within the
    maximum number of attempts.
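Schaul et al. (2015) sample proportionally to priority by dividing the total priority mass into batch_size equal segments and drawing one value uniformly from each. The following stdlib sketch does that over a flat list of priorities using prefix sums and bisection. It illustrates the technique and is not this class's implementation; in particular it omits the validity checks and the retry loop bounded by max_sample_attempts.

```python
import bisect
import itertools
import random


def sample_index_batch_sketch(priorities, batch_size, rng=random):
    """Stratified proportional sampling over a flat list of priorities."""
    prefix = list(itertools.accumulate(priorities))  # running totals
    total = prefix[-1] if prefix else 0.0
    if total <= 0:
        raise ValueError("at least one priority must be positive")
    segment = total / batch_size
    indices = []
    for i in range(batch_size):
        # One uniform draw inside the i-th equal slice of the priority mass.
        value = rng.uniform(i * segment, (i + 1) * segment)
        # First index whose running total exceeds value. Zero-priority slots
        # add nothing to the running total, so they are never chosen.
        idx = bisect.bisect_right(prefix, value)
        if idx == len(prefix):
            # value landed exactly on the total: use the last slot with
            # positive priority instead of running off the end.
            idx = bisect.bisect_left(prefix, total)
        indices.append(idx)
    return indices
```

Stratifying the draws spreads each batch across the whole priority range, which tends to give lower-variance batches than batch_size independent draws.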

sample_transition_batch

  sample_transition_batch(
      batch_size=None,
      indices=None
  )

Returns a batch of transitions with extra storage and the priorities.

The extra storage elements are defined through the extra_storage_types
constructor argument.

When the transition is terminal, next_state_batch has undefined contents.

Args:

  • batch_size: int, number of transitions returned. If None, the
    default batch_size will be used.
  • indices: None or list of ints, the indices of every transition in
    the batch. If None, the indices are drawn with sample_index_batch.

Returns:

  • transition_batch: tuple of np.arrays with the shape and type as in
    get_transition_elements().
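Conceptually, the sampled indices are used to gather each stored element into a batch, and the priorities at those indices are appended. The sketch below is hypothetical: it holds columns in a dict of lists, whereas the real method returns np.arrays and also assembles n-step rewards and next states, which are omitted here.

```python
def gather_batch(columns, priorities, indices):
    """Hypothetical gather step: one list per stored element, plus priorities.

    columns maps element names (e.g. "action", "reward") to per-slot lists;
    dict insertion order fixes the order of the returned tuple.
    """
    batch = tuple([columns[name][i] for i in indices] for name in columns)
    priority_batch = [priorities[i] for i in indices]
    return batch + (priority_batch,)
```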

save

  save(
      checkpoint_dir,
      iteration_number
  )

Saves the OutOfGraphReplayBuffer attributes to numpy checkpoint files.

This method saves all of the replay buffer’s state to checkpoint_dir, using
iteration_number as the file suffix.

Args:

  • checkpoint_dir: str, the directory where numpy checkpoint files
    should be saved.
  • iteration_number: int, iteration_number to use as a suffix in
    naming numpy checkpoint files.

set_priority

  set_priority(
      indices,
      priorities
  )

Sets the priority of the given elements according to Schaul et al.

Args:

  • indices: np.array of int32 indices in the range [0,
    replay_capacity).
  • priorities: np.array of floats, the new priority for each index.
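Updating priorities one at a time stays cheap if they live in a sum tree, the structure Schaul et al. (2015) propose for proportional prioritization. This page does not say how the class stores priorities, so SumTreeSketch and the module-level set_priority and get_priority helpers below are hypothetical stdlib illustrations. Unused leaves start at 0, matching the documented behavior of get_priority.

```python
class SumTreeSketch:
    """Array-backed binary sum tree over `capacity` leaf priorities."""

    def __init__(self, capacity):
        self.capacity = capacity
        size = 1
        while size < capacity:
            size *= 2
        self._leaf_start = size            # leaves occupy [size, 2 * size)
        self._nodes = [0.0] * (2 * size)   # node 1 is the root; unused leaves stay 0

    def set(self, index, priority):
        if not 0 <= index < self.capacity:
            raise IndexError(index)
        if priority < 0:
            raise ValueError("priorities must be non-negative")
        node = self._leaf_start + index
        delta = priority - self._nodes[node]
        # Propagate the change from the leaf up to the root.
        while node >= 1:
            self._nodes[node] += delta
            node //= 2

    def get(self, index):
        return self._nodes[self._leaf_start + index]

    def total(self):
        return self._nodes[1]


def set_priority(tree, indices, priorities):
    # Pairs each index with its new priority, as in the documented method.
    for index, priority in zip(indices, priorities):
        tree.set(index, priority)


def get_priority(tree, indices):
    return [tree.get(index) for index in indices]
```

Because each internal node caches the sum of its subtree, total() is O(1) and each set() touches O(log replay_capacity) nodes, whereas recomputing a flat sum would cost O(replay_capacity) per update.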