circular_replay_buffer.OutOfGraphReplayBuffer

Class OutOfGraphReplayBuffer

A simple out-of-graph Replay Buffer.

Stores transitions (state, action, reward, next_state, terminal, and any extra
contents specified) in a circular buffer and provides a uniform transition
sampling function.

When the states consist of stacks of observations, storing every stacked state
is inefficient because consecutive states share most of their observations.
This class stores individual observations and constructs the stacked states at
sample time.
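
To illustrate the idea, here is a minimal sketch (not the class's own code) that stores each observation once and builds a stack only when one is requested. The helper name `build_stack` and the use of plain lists in place of np.arrays are assumptions made for illustration.

```python
def build_stack(observations, index, stack_size):
    """Return the `stack_size` observations ending at `index` (inclusive).

    `observations` is treated as a circular store, so indices wrap around
    modulo its length.
    """
    capacity = len(observations)
    return [observations[(index - offset) % capacity]
            for offset in range(stack_size - 1, -1, -1)]


# Six single frames stored once each, rather than six overlapping stacks.
frames = ["f0", "f1", "f2", "f3", "f4", "f5"]
stack = build_stack(frames, index=4, stack_size=4)
```

With a stack size of 4, storing full states would keep roughly four copies of every frame; storing single frames keeps one.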

Attributes:

  • add_count: int, counter of how many transitions have been added
    (including the blank ones at the beginning of an episode).

Methods

__init__

  __init__(
      observation_shape,
      stack_size,
      replay_capacity,
      batch_size,
      update_horizon=1,
      gamma=0.99,
      max_sample_attempts=MAX_SAMPLE_ATTEMPTS,
      extra_storage_types=None,
      observation_dtype=np.uint8
  )

Initializes OutOfGraphReplayBuffer.

Args:

  • observation_shape: tuple or int. If int, the observation is assumed
    to be a 2D square.
  • stack_size: int, number of frames to use in state stack.
  • replay_capacity: int, number of transitions to keep in memory.
  • batch_size: int, default number of transitions returned when sampling.
  • update_horizon: int, length of update (‘n’ in n-step update).
  • gamma: float, the discount factor.
  • max_sample_attempts: int, the maximum number of attempts allowed to
    get a sample.
  • extra_storage_types: list of ReplayElements defining the type of
    the extra contents that will be stored and returned by
    sample_transition_batch.
  • observation_dtype: np.dtype, type of the observations. Defaults to
    np.uint8 for Atari 2600.

Raises:

  • ValueError: If replay_capacity is too small to hold at least one
    transition.
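
The roles of update_horizon and gamma can be made concrete with a small worked example. The sketch below assumes the usual n-step return, in which the rewards of the next update_horizon steps are summed with discount gamma. The helper name `n_step_return` is hypothetical, and how the class itself accumulates rewards is not described in this reference.

```python
def n_step_return(rewards, gamma):
    """Discounted sum: r_0 + gamma * r_1 + gamma**2 * r_2 + ..."""
    return sum((gamma ** k) * r for k, r in enumerate(rewards))


update_horizon = 3               # the 'n' in n-step update
rewards = [1.0, 0.0, 2.0, 5.0]   # rewards observed after the sampled state
ret = n_step_return(rewards[:update_horizon], gamma=0.5)
```

Only the first update_horizon rewards contribute, so the fourth reward above is ignored.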

add

  add(
      observation,
      action,
      reward,
      terminal,
      *args
  )

Adds a transition to the replay memory.

This function checks the types and handles the padding at the beginning of an
episode. Then it calls the _add function.

Since the next_observation in the transition will be the observation added
next, there is no need to pass it.

If the replay memory is at capacity, the oldest transition will be discarded.

Args:

  • observation: np.array with shape observation_shape.
  • action: int, the action in the transition.
  • reward: float, the reward received in the transition.
  • terminal: A uint8 acting as a boolean indicating whether the
    transition was terminal (1) or not (0).
  • *args: extra contents with shapes and dtypes according to
    extra_storage_types.
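
The overwrite behaviour at capacity can be sketched with a toy circular store. This is an illustration only: the class `TinyCircularBuffer` is hypothetical, it assumes the write position is the add count modulo the capacity, and it omits the type checks and episode-start padding that add performs.

```python
class TinyCircularBuffer:
    """Hypothetical stand-in that keeps only the newest `capacity` items."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.store = [None] * capacity
        self.add_count = 0  # total number of items ever added

    def cursor(self):
        # Slot where the next item will be written (assumed: count mod capacity).
        return self.add_count % self.capacity

    def add(self, item):
        # Once every slot is in use, this overwrites the oldest item.
        self.store[self.cursor()] = item
        self.add_count += 1


buf = TinyCircularBuffer(capacity=3)
for step in range(5):
    buf.add(step)
```

After five additions to a buffer of capacity three, items 0 and 1 have been overwritten by items 3 and 4.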

cursor

  cursor()

Index to the location where the next transition will be written.

get_add_args_signature

  get_add_args_signature()

The signature of the add function.

Note - Derived classes may return a different signature.

Returns:

list of ReplayElements defining the type of the argument signature needed by the
add function.

get_observation_stack

  get_observation_stack(index)

get_range

  get_range(
      array,
      start_index,
      end_index
  )

Returns the range of array from start_index to end_index, handling wraparound
if necessary.

Args:

  • array: np.array, the array to get the range from.
  • start_index: int, index to the start of the range to be returned.
    The range will wrap around if start_index is smaller than 0.
  • end_index: int, exclusive end index. The range will wrap around if
    end_index exceeds replay_capacity.

Returns:

np.array, with shape [end_index - start_index, array.shape[1:]].
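
The wraparound rule can be shown with a short sketch. It is not the method's actual code: the name `get_range_sketch` is hypothetical, a plain list stands in for the np.array, and the list length is assumed to equal replay_capacity.

```python
def get_range_sketch(array, start_index, end_index):
    """Return array[start_index:end_index], wrapping around either end."""
    capacity = len(array)
    return [array[i % capacity] for i in range(start_index, end_index)]


data = [10, 11, 12, 13, 14]
plain = get_range_sketch(data, 1, 3)         # no wraparound needed
wrapped_low = get_range_sketch(data, -2, 1)  # start_index below 0
wrapped_high = get_range_sketch(data, 3, 6)  # end_index beyond the capacity
```

In every case the result has end_index - start_index entries along its first dimension.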

get_storage_signature

  get_storage_signature()

Returns a default list of elements to be stored in this replay memory.

Note - Derived classes may return a different signature.

Returns:

list of ReplayElements defining the type of the contents stored.

get_terminal_stack

  get_terminal_stack(index)

get_transition_elements

  get_transition_elements(batch_size=None)

Returns a ‘type signature’ for sample_transition_batch.

Args:

  • batch_size: int, number of transitions returned. If None, the
    default batch_size will be used.

Returns:

  • signature: A namedtuple describing the method’s return type
    signature.

is_empty

  is_empty()

Returns whether the replay buffer is empty.

is_full

  is_full()

Returns whether the replay buffer is full.

is_valid_transition

  is_valid_transition(index)

Checks if the index contains a valid transition.

Checks for collisions with the end of episodes and the current position of the
cursor.

Args:

  • index: int, the index to the state in the transition.

Returns:

bool, whether the index is valid.
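
The two kinds of collision can be sketched as follows. This is a simplified illustration under stated assumptions, not the method's actual logic: the function `is_valid_sketch` and all of its parameters are hypothetical, the buffer is assumed to be full, and plain lists stand in for np.arrays.

```python
def is_valid_sketch(index, terminals, cursor, stack_size, update_horizon):
    """Simplified validity check on a full circular buffer.

    `terminals[i]` is 1 if the transition stored in slot i ended an episode.
    """
    capacity = len(terminals)
    if not 0 <= index < capacity:
        return False
    # Slots read to build the state stack (oldest first, ending at `index`).
    stack_slots = [(index - k) % capacity
                   for k in range(stack_size - 1, -1, -1)]
    # Slots read when looking ahead `update_horizon` steps.
    ahead_slots = [(index + k) % capacity
                   for k in range(1, update_horizon + 1)]
    # Collision with the cursor: the slots would mix newest and oldest data.
    if cursor in stack_slots or cursor in ahead_slots:
        return False
    # Collision with an episode end: an earlier frame in the stack is terminal.
    if any(terminals[slot] for slot in stack_slots[:-1]):
        return False
    return True


terminals = [0, 0, 1, 0, 0, 0, 0, 0]
checks = {i: is_valid_sketch(i, terminals, cursor=7, stack_size=2,
                             update_horizon=1)
          for i in range(8)}
```

Here index 3 is rejected because its stack would span the episode that ended in slot 2, while indices 6, 7 and 0 are rejected because they touch the cursor.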

load

  load(
      checkpoint_dir,
      suffix
  )

Restores the object from bundle_dictionary and numpy checkpoints.

Args:

  • checkpoint_dir: str, the directory where to read the numpy
    checkpointed files from.
  • suffix: str, the suffix to use in numpy checkpoint files.

Raises:

  • NotFoundError: If not all expected files are found in directory.

sample_index_batch

  sample_index_batch(batch_size)

Returns a batch of valid indices sampled uniformly.

Args:

  • batch_size: int, number of indices returned.

Returns:

list of ints, a batch of valid indices sampled uniformly.

Raises:

  • RuntimeError: If the batch was not constructed after maximum number
    of tries.
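
One way to read this contract is as rejection sampling with a bounded number of failed attempts. The sketch below is an assumption about the general approach, not the method's code: `sample_valid_indices`, `is_valid`, and `max_attempts` are hypothetical names, and counting only the rejected candidates is a choice made for this example.

```python
import random


def sample_valid_indices(batch_size, capacity, is_valid, max_attempts,
                         rng=random):
    """Draw `batch_size` uniformly sampled indices that pass `is_valid`."""
    indices = []
    failed_attempts = 0
    while len(indices) < batch_size and failed_attempts < max_attempts:
        candidate = rng.randrange(capacity)  # uniform over [0, capacity)
        if is_valid(candidate):
            indices.append(candidate)
        else:
            failed_attempts += 1
    if len(indices) != batch_size:
        raise RuntimeError(
            "Could not sample {} valid indices after {} failed attempts."
            .format(batch_size, max_attempts))
    return indices


# Example: pretend that only even slots hold valid transitions.
batch = sample_valid_indices(4, capacity=10,
                             is_valid=lambda i: i % 2 == 0,
                             max_attempts=100,
                             rng=random.Random(0))
```

Bounding the number of attempts means that a buffer with very few valid transitions fails loudly with an error instead of looping indefinitely.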

sample_transition_batch

  sample_transition_batch(
      batch_size=None,
      indices=None
  )

Returns a batch of transitions (including any extra contents).

If get_transition_elements has been overridden and defines elements not stored
in self._store, an empty array will be returned and it will be left to the child
class to fill it. For example, for the child class
OutOfGraphPrioritizedReplayBuffer, the contents of the sampling_probabilities
are stored separately in a sum tree.

When the transition is terminal, next_state_batch has undefined contents.

NOTE: This transition contains the indices of the sampled elements. These are
only valid during the call to sample_transition_batch, i.e. they may be used by
subclasses of this replay buffer but may point to different data as soon as
sampling is done.

Args:

  • batch_size: int, number of transitions returned. If None, the
    default batch_size will be used.
  • indices: None or list of ints, the indices of every transition in
    the batch. If None, sample the indices uniformly.

Returns:

  • transition_batch: tuple of np.arrays with the shape and type as in
    get_transition_elements().

Raises:

  • ValueError: If an element to be sampled is missing from the replay
    buffer.

save

  save(
      checkpoint_dir,
      iteration_number
  )

Saves the OutOfGraphReplayBuffer attributes into a file.

This method will save all the replay buffer’s state in a single file.

Args:

  • checkpoint_dir: str, the directory where numpy checkpoint files
    should be saved.
  • iteration_number: int, iteration_number to use as a suffix in
    naming numpy checkpoint files.