prioritized_replay_buffer.WrappedPrioritizedReplayBuffer
Class WrappedPrioritizedReplayBuffer
Inherits From: WrappedReplayBuffer
Wrapper of OutOfGraphPrioritizedReplayBuffer with in-graph sampling.
Usage:
To add a transition: call the add method.
To sample a batch: query any of the tensors in the transition dictionary.
Every sess.run that requires any of these tensors samples a new batch of
transitions.
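As a rough illustration of the sampling semantics described under Usage, here is a pure-Python analogue with no TensorFlow dependency. The class LazyBatchSampler and its run() method are hypothetical stand-ins for the graph and sess.run, not part of this library: each call draws one fresh batch, and every field requested in that call is read from the same batch, so the fields stay aligned.

```python
import random


class LazyBatchSampler:
    """Hypothetical analogue of in-graph sampling (not part of Dopamine).

    Each run() call plays the role of one sess.run: it draws exactly one
    batch of indices, and every requested field is read from that batch.
    """

    def __init__(self, transitions, batch_size, seed=0):
        self._transitions = transitions  # list of dicts sharing the same keys
        self._batch_size = batch_size
        self._rng = random.Random(seed)
        self.num_batches_sampled = 0

    def run(self, *keys):
        # Sample indices once per call, so all requested fields stay aligned.
        indices = [self._rng.randrange(len(self._transitions))
                   for _ in range(self._batch_size)]
        self.num_batches_sampled += 1
        return tuple([self._transitions[i][key] for i in indices]
                     for key in keys)


data = [{"state": i, "action": i % 3, "reward": float(i)} for i in range(10)]
sampler = LazyBatchSampler(data, batch_size=4)
states, rewards = sampler.run("state", "reward")  # one batch, two fields
actions, = sampler.run("action")                  # a second, new batch
print(sampler.num_batches_sampled)  # 2
```

Because state i was stored with reward float(i), every (state, reward) pair returned by a single run() call matches, which is the property that sharing one sampling op per sess.run provides.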
Methods
__init__
__init__(
*args,
**kwargs
)
Initializes WrappedPrioritizedReplayBuffer.
Args:
observation_shape: tuple or int. If int, the observation is assumed to be a 2D square with sides equal to observation_shape.
stack_size: int, number of frames to use in the state stack.
use_staging: bool, when True, use a staging area to prefetch the next sampling batch.
replay_capacity: int, number of transitions to keep in memory.
batch_size: int, number of transitions in each sampled batch.
update_horizon: int, length of update ('n' in n-step update).
gamma: float, the discount factor.
max_sample_attempts: int, the maximum number of attempts allowed to get a sample.
extra_storage_types: list of ReplayElements defining the type of the extra contents that will be stored and returned by sample_transition_batch.
observation_dtype: np.dtype, type of the observations. Defaults to np.uint8 for Atari 2600.
Raises:
ValueError: If update_horizon is not positive.
ValueError: If the discount factor is not in [0, 1].
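Two of the constructor arguments are easier to follow with numbers: an int observation_shape expands to a square, and update_horizon together with gamma defines an n-step discounted return, G = sum over k from 0 to n-1 of gamma^k * r_{t+k}. The sketch below is a minimal pure-Python illustration under the assumption that the return is cut off after the first terminal transition; the function names are hypothetical and are not part of this library. It also reproduces the two documented ValueError conditions.

```python
def normalize_observation_shape(observation_shape):
    # An int is treated as the side length of a square 2D observation.
    if isinstance(observation_shape, int):
        return (observation_shape, observation_shape)
    return tuple(observation_shape)


def check_replay_args(update_horizon, gamma):
    # Mirrors the two ValueError conditions listed under Raises.
    if update_horizon < 1:
        raise ValueError("update_horizon must be positive, got {}".format(
            update_horizon))
    if not 0.0 <= gamma <= 1.0:
        raise ValueError("gamma must be in [0, 1], got {}".format(gamma))


def n_step_return(rewards, terminals, gamma, update_horizon):
    """Discounted sum of up to update_horizon rewards (assumed semantics).

    rewards[k] and terminals[k] describe the transition at time t + k.
    Accumulation stops after the first terminal transition.
    """
    check_replay_args(update_horizon, gamma)
    total = 0.0
    for k in range(min(update_horizon, len(rewards))):
        total += (gamma ** k) * rewards[k]
        if terminals[k]:
            break  # do not add rewards from past the episode boundary
    return total


print(normalize_observation_shape(84))  # (84, 84)
print(n_step_return([1.0, 1.0, 1.0], [0, 1, 0], gamma=0.5, update_horizon=3))  # 1.5
```

With no terminal in the window the same call returns 1 + 0.5 + 0.25 = 1.75; the terminal at k = 1 drops the third reward.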
add
add(
observation,
action,
reward,
terminal,
*args
)
Adds a transition to the replay memory.
Since the next_observation in the transition will be the observation added next,
there is no need to pass it.
If the replay memory is at capacity, the oldest transition is discarded.
Args:
observation: np.array with shape observation_shape.
action: int, the action in the transition.
reward: float, the reward received in the transition.
terminal: a uint8 acting as a boolean indicating whether the transition was terminal (1) or not (0).
*args: extra contents with shapes and dtypes according to extra_storage_types.
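The add behavior can be seen in a minimal circular buffer. This is a hypothetical pure-Python sketch, not this library's implementation; CircularReplay and get_transition are illustrative names. It shows why next_observation never has to be passed (it is read from the following slot) and how the oldest transition is overwritten once replay_capacity is reached.

```python
class CircularReplay:
    """Hypothetical minimal circular buffer illustrating add() semantics."""

    def __init__(self, replay_capacity):
        self._capacity = replay_capacity
        self._storage = [None] * replay_capacity
        self.add_count = 0  # total number of add() calls so far

    def add(self, observation, action, reward, terminal):
        # Once full, this overwrites the slot holding the oldest transition.
        self._storage[self.add_count % self._capacity] = (
            observation, action, reward, terminal)
        self.add_count += 1

    def __len__(self):
        return min(self.add_count, self._capacity)

    def _slot(self, i):
        # Logical index 0 is the oldest transition still in memory.
        oldest = max(0, self.add_count - self._capacity)
        return (oldest + i) % self._capacity

    def get_transition(self, i):
        # next_obs is not stored separately: it is the observation written by
        # the following add() call. If terminal is 1, that observation starts
        # a new episode and a learner should not bootstrap from it.
        if not 0 <= i < len(self) - 1:
            raise IndexError("transition {} has no successor in memory".format(i))
        obs, action, reward, terminal = self._storage[self._slot(i)]
        next_obs = self._storage[self._slot(i + 1)][0]
        return obs, action, reward, terminal, next_obs


buf = CircularReplay(replay_capacity=3)
for t in range(5):
    buf.add(observation=t, action=0, reward=1.0, terminal=0)
print(len(buf))                # 3
print(buf.get_transition(0))   # (2, 0, 1.0, 0, 3)
```

After five adds into three slots, observations 0 and 1 have been overwritten, and the newest transition (observation 4) cannot be returned yet because its successor has not been added.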
create_sampling_ops
create_sampling_ops(use_staging)
Creates the ops necessary to sample from the replay buffer.
Creates the transition dictionary containing the sampling tensors.
Args:
use_staging: bool, when True, use a staging area to prefetch the next sampling batch.
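The idea behind use_staging is a one-slot prefetch: the next batch is prepared while the current one is consumed. The sketch below is a hypothetical pure-Python analogue of that pattern; StagedSampler and next_batch are illustrative names, and the real option relies on TensorFlow ops rather than this class.

```python
class StagedSampler:
    """Hypothetical one-slot prefetch, analogous to a staging area.

    With use_staging=True each call returns the batch staged by the previous
    call and immediately stages the next one. Without staging, each call
    samples on demand.
    """

    def __init__(self, sample_fn, use_staging):
        self._sample_fn = sample_fn
        self._use_staging = use_staging
        # Prime the staging slot so the first call has a batch ready.
        self._staged = sample_fn() if use_staging else None

    def next_batch(self):
        if not self._use_staging:
            return self._sample_fn()
        # Hand out the staged batch and prefetch its replacement.
        batch, self._staged = self._staged, self._sample_fn()
        return batch


counter = iter(range(100))  # stands in for successive sampled batches
staged = StagedSampler(lambda: next(counter), use_staging=True)
print(staged.next_batch(), staged.next_batch())  # 0 1
```

One consequence of any prefetch of this kind is that the returned batch was sampled one step earlier, so anything that changes between the two steps (for example, updated priorities) does not affect it.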
load
load(
checkpoint_dir,
suffix
)
Loads the replay buffer’s state from a saved file.
Args:
checkpoint_dir: str, the directory from which to read the numpy checkpoint files.
suffix: str, the suffix of the numpy checkpoint files to read.
save
save(
checkpoint_dir,
iteration_number
)
Saves the underlying replay buffer’s contents to numpy checkpoint files.
Args:
checkpoint_dir: str, the directory in which to write the numpy checkpoint files.
iteration_number: int, the iteration number to use as a suffix when naming numpy checkpoint files.
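save and load pair up through the file suffix: the iteration_number used when saving is the suffix passed when loading. The round trip below is a minimal sketch using numpy.save and numpy.load; the helper names and the <name>_ckpt.<suffix>.npy naming pattern are hypothetical and are not the library's actual file layout.

```python
import os
import tempfile

import numpy as np


def save_arrays(checkpoint_dir, iteration_number, arrays):
    """Writes each array to <name>_ckpt.<iteration_number>.npy (hypothetical naming)."""
    for name, array in arrays.items():
        path = os.path.join(
            checkpoint_dir, "{}_ckpt.{}.npy".format(name, iteration_number))
        np.save(path, array)


def load_arrays(checkpoint_dir, suffix, names):
    """Reads back the arrays that were saved with the given suffix."""
    return {
        name: np.load(os.path.join(
            checkpoint_dir, "{}_ckpt.{}.npy".format(name, suffix)))
        for name in names
    }


with tempfile.TemporaryDirectory() as ckpt_dir:
    contents = {"reward": np.array([1.0, 0.0, 2.5]), "action": np.array([0, 2, 1])}
    save_arrays(ckpt_dir, iteration_number=7, arrays=contents)
    restored = load_arrays(ckpt_dir, suffix=7, names=["reward", "action"])
    print(restored["reward"].tolist())  # [1.0, 0.0, 2.5]
```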
tf_get_priority
tf_get_priority(indices)
Gets the priorities for the given indices.
Args:
indices: tf.Tensor with dtype int32 and shape [n].
Returns:
priorities: tf.Tensor with dtype float and shape [n], the priorities at the indices.
tf_set_priority
tf_set_priority(
indices,
priorities
)
Sets the priorities for the given indices.
Args:
indices: tf.Tensor with dtype int32 and shape [n].
priorities: tf.Tensor with dtype float and shape [n].
Returns:
A tf op setting the priorities for prioritized sampling.
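Prioritized sampling draws index i with probability proportional to its priority. A common structure for this is a sum tree, where getting, setting, and sampling all cost O(log capacity). The sketch below is a hypothetical pure-Python illustration of the proportional scheme that tf_get_priority and tf_set_priority feed into; it does not reproduce this library's classes or TensorFlow wrappers, and the names SumTree, get_priority, set_priority, and sample are assumptions.

```python
import random


class SumTree:
    """Hypothetical sketch of proportional priority storage.

    Leaves hold priorities; each internal node holds the sum of its two
    children, so the root (node 1) holds the total priority.
    """

    def __init__(self, capacity):
        self._capacity = capacity
        self._tree = [0.0] * (2 * capacity)  # leaves live at [capacity, 2 * capacity)

    def get_priority(self, indices):
        return [self._tree[self._capacity + i] for i in indices]

    def set_priority(self, indices, priorities):
        for i, p in zip(indices, priorities):
            if p < 0:
                raise ValueError("priorities must be non-negative")
            node = self._capacity + i
            delta = p - self._tree[node]
            while node >= 1:  # propagate the change up to the root
                self._tree[node] += delta
                node //= 2

    def total(self):
        return self._tree[1]

    def sample(self, rng):
        # Pick a point in [0, total) and descend to the leaf whose
        # cumulative-priority interval contains it.
        target = rng.random() * self.total()
        node = 1
        while node < self._capacity:
            left = 2 * node
            if target < self._tree[left]:
                node = left
            else:
                target -= self._tree[left]
                node = left + 1
        return node - self._capacity


tree = SumTree(capacity=4)
tree.set_priority([0, 1, 2, 3], [1.0, 0.0, 3.0, 0.0])
print(tree.get_priority([2, 0]))  # [3.0, 1.0]
rng = random.Random(0)
draws = [tree.sample(rng) for _ in range(1000)]  # index 2 about 3x as often as 0
```

Indices with priority 0 are never drawn, which is why new transitions are typically given a nonzero priority when added.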
unpack_transition
unpack_transition(
transition_tensors,
transition_type
)
Unpacks the given transition into member variables.
Args:
transition_tensors: tuple of tf.Tensors.
transition_type: tuple of ReplayElements matching transition_tensors.
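Unpacking pairs each tensor with the ReplayElement that describes it, keyed by the element's name. The sketch below illustrates this with plain Python lists in place of tf.Tensors. The ReplayElement namedtuple here is a hypothetical stand-in with name, shape, and type fields, and TransitionHolder is an illustrative class, not this library's code; exposing each entry under its own name as an attribute is one possible reading of "member variables".

```python
import collections

# Hypothetical stand-in for a ReplayElement: a name plus shape and dtype.
ReplayElement = collections.namedtuple("ReplayElement", ["name", "shape", "type"])


class TransitionHolder:
    """Hypothetical holder illustrating unpack_transition (not library code)."""

    def unpack_transition(self, transition_tensors, transition_type):
        if len(transition_tensors) != len(transition_type):
            raise ValueError("got {} tensors for {} replay elements".format(
                len(transition_tensors), len(transition_type)))
        # Key each tensor by the name of the element that describes it.
        self.transition = collections.OrderedDict(
            (element.name, tensor)
            for element, tensor in zip(transition_type, transition_tensors))
        # Also expose each entry as a member variable, e.g. self.reward.
        for name, tensor in self.transition.items():
            setattr(self, name, tensor)


transition_type = (
    ReplayElement("state", (84, 84, 4), "uint8"),
    ReplayElement("action", (), "int32"),
    ReplayElement("reward", (), "float32"),
)
holder = TransitionHolder()
holder.unpack_transition(([[0], [1]], [1, 0], [0.5, 1.0]), transition_type)
print(list(holder.transition))  # ['state', 'action', 'reward']
print(holder.reward)            # [0.5, 1.0]
```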