storage
1. replay_buffer.py
Class: ReplayBuffer
Purpose
- Stores experience data (here: states and next states) for reinforcement learning (RL).
- Implements a circular buffer (ring buffer): once full, new data overwrites the oldest data.
- Provides a sampling function `feed_forward_generator` to draw mini-batches for training.
Constructor: `__init__`

```python
def __init__(self, obs_dim, buffer_size, device):
    self.states = torch.zeros(buffer_size, obs_dim).to(device)
    self.next_states = torch.zeros(buffer_size, obs_dim).to(device)
    self.buffer_size = buffer_size
    self.device = device
    self.step = 0
    self.num_samples = 0
```
Arguments
- obs_dim: Dimension of each state vector.
- buffer_size: Maximum number of entries the buffer can hold.
- device: Device to store the tensors on (CPU or GPU).
Internal variables
- `self.states`: Stores current states (`[buffer_size, obs_dim]`).
- `self.next_states`: Stores next states.
- `self.step`: The current write index in the buffer.
- `self.num_samples`: Number of valid samples currently in the buffer (≤ `buffer_size`).
Method: `insert`

```python
def insert(self, states, next_states):
    num_states = states.shape[0]
    start_idx = self.step
    end_idx = self.step + num_states
    ...
```
Purpose
Insert a batch of states and their corresponding next states into the buffer.
Logic
- Determine the write range: from `self.step` to `end_idx`.
- Check whether it exceeds the buffer size:
  - Not exceeded → write directly.
  - Exceeded → split the write into two parts:
    - Fill from `self.step : buffer_size`.
    - Wrap around and fill from the start: `[0 : (end_idx - buffer_size)]`.
- Update tracking:
  - `self.num_samples`: updated to the number of valid samples (capped at `buffer_size`).
  - `self.step`: advanced to the new write position (wrapped with modulo `% buffer_size`).
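The steps above can be sketched end-to-end. This is a minimal reconstruction of the described circular-buffer behavior, not the exact implementation (error handling and batches larger than the buffer are not covered):

```python
import torch

class ReplayBuffer:
    # Minimal sketch following the description above; the device default
    # of "cpu" is an assumption for easy local testing.
    def __init__(self, obs_dim, buffer_size, device="cpu"):
        self.states = torch.zeros(buffer_size, obs_dim).to(device)
        self.next_states = torch.zeros(buffer_size, obs_dim).to(device)
        self.buffer_size = buffer_size
        self.device = device
        self.step = 0
        self.num_samples = 0

    def insert(self, states, next_states):
        num_states = states.shape[0]
        start_idx = self.step
        end_idx = self.step + num_states
        if end_idx <= self.buffer_size:
            # Fits without wrapping: write directly.
            self.states[start_idx:end_idx] = states
            self.next_states[start_idx:end_idx] = next_states
        else:
            # Wraps: fill the tail, then continue from the start.
            tail = self.buffer_size - start_idx
            self.states[start_idx:] = states[:tail]
            self.next_states[start_idx:] = next_states[:tail]
            self.states[: end_idx - self.buffer_size] = states[tail:]
            self.next_states[: end_idx - self.buffer_size] = next_states[tail:]
        # Cap valid-sample count at capacity; wrap the write cursor.
        self.num_samples = min(self.num_samples + num_states, self.buffer_size)
        self.step = end_idx % self.buffer_size
```

Inserting two 3-row batches into a 4-slot buffer leaves the cursor at index 2 and the oldest rows overwritten, which is exactly the ring-buffer behavior described above.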
Method: `feed_forward_generator`

```python
def feed_forward_generator(self, num_mini_batch, mini_batch_size):
    for _ in range(num_mini_batch):
        sample_idxs = np.random.choice(self.num_samples, size=mini_batch_size)
        yield (self.states[sample_idxs].to(self.device),
               self.next_states[sample_idxs].to(self.device))
```
Purpose
Randomly sample mini-batches of data for training.
Arguments
- num_mini_batch: Number of mini-batches to generate.
- mini_batch_size: Number of samples per mini-batch.
Logic
- Randomly pick `mini_batch_size` indices from the valid samples (`self.num_samples`).
- Collect the corresponding states and next states.
- Yield them one mini-batch at a time.
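A standalone sketch of this sampling step, with toy tensors standing in for the real `self.states` / `self.next_states`; note that `np.random.choice` samples with replacement by default:

```python
import numpy as np
import torch

# Toy stand-ins for the buffer contents; every row is a valid sample.
states = torch.arange(10, dtype=torch.float32).reshape(5, 2)
next_states = states + 1.0
num_samples = 5

def feed_forward_generator(num_mini_batch, mini_batch_size):
    for _ in range(num_mini_batch):
        # With replacement by default, so a transition may repeat in a batch.
        sample_idxs = torch.from_numpy(
            np.random.choice(num_samples, size=mini_batch_size))
        yield states[sample_idxs], next_states[sample_idxs]

batches = list(feed_forward_generator(num_mini_batch=3, mini_batch_size=2))
```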
2. rollout_storage.py
1. RolloutStorage
This is the base class for storing rollouts (trajectories) in reinforcement learning. It keeps all the data you need for PPO (or other policy gradient methods):
- Core storage: observations, critic observations (privileged info), actions, rewards, dones.
- PPO-specific: action log-probs, value predictions, returns, advantages, policy distribution parameters (μ, σ).
- Optional: hidden states (for RNN policies).
Key features:
- `add_transitions()` → add one step of data.
- `compute_returns()` → calculate GAE (generalized advantage estimation).
- `mini_batch_generator()` → shuffle and sample minibatches.
- `reccurent_mini_batch_generator()` → specialized batching for RNNs.
- `get_statistics()` → average episode length + average reward.
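A minimal sketch of the GAE recursion that `compute_returns()` is described as implementing, assuming `[num_steps, num_envs]`-shaped tensors (the function name and signature here are illustrative; real implementations often also normalize advantages):

```python
import torch

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    # Backward recursion: delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
    #                     A_t     = delta_t + gamma * lam * (1 - d_t) * A_{t+1}
    num_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    gae = torch.zeros_like(last_value)
    for t in reversed(range(num_steps)):
        next_value = last_value if t == num_steps - 1 else values[t + 1]
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    returns = advantages + values  # bootstrapped return targets for the critic
    return returns, advantages
```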
2. QueueRolloutStorage (extends RolloutStorage)
Adds support for a rolling buffer (like a queue), useful when the rollout length isn’t fixed.
- Can expand the buffer size dynamically.
- Can loop the buffer (new data overwrites old).
- `untie_buffer_loop()` → reorders the buffer so the latest data is contiguous.
- Designed for training with buffered rollouts instead of strict episode cuts.
3. ActionLabelRollout (extends QueueRolloutStorage)
A variant that also stores action labels (e.g., for imitation learning).
- Adds an extra `action_labels` tensor.
- MiniBatch now includes `action_labels`.
- Everything else works the same as `QueueRolloutStorage`.
4. SarsaRolloutStorage (extends RolloutStorage)
Specialized for algorithms like SARSA, where you need both the current state and the next state.
- Stores `next_observations` and `next_critic_observations`.
- Uses an extended buffer (`all_observations`) so you can easily shift data by 1 timestep.
- Ensures that each transition has `(s, a, r, s')` aligned.
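The extended-buffer trick can be illustrated directly: with `num_steps + 1` observation slots, `s` and `s'` become two views of the same tensor shifted by one timestep (shapes here are illustrative, not taken from the codebase):

```python
import torch

# Allocate one extra timestep so current and next observations are
# simply offset views of the same storage.
num_steps, num_envs, obs_dim = 4, 2, 3
all_observations = torch.arange(
    (num_steps + 1) * num_envs * obs_dim, dtype=torch.float32
).reshape(num_steps + 1, num_envs, obs_dim)

observations = all_observations[:-1]      # s_t     for t = 0..num_steps-1
next_observations = all_observations[1:]  # s_{t+1}, aligned with (a_t, r_t)
```

Because both are views, writing observation `t+1` once automatically provides `next_observations[t]`, which keeps every `(s, a, r, s')` tuple aligned without copying.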
3. rollout_files
1. base.py
Class: RolloutFileBase
This is an abstract base class for datasets that load and serve rollouts (trajectories / sequences) from files. It inherits from torch.utils.data.IterableDataset, so you can iterate over it like a PyTorch dataset.
It’s designed as a template — real implementations (subclasses) must implement the abstract methods (reset_all, refresh_handlers, get_buffer, fill_transition).
Key Attributes
- `data_dir` → where the rollout data is stored (file directory).
- `num_envs` → how many environments (parallel envs / agents) to manage.
- `device` → usually `"cuda"` or `"cpu"`.
- `__initialized` → ensures lazy initialization (reset happens on first use).
- `all_env_ids` → tensor with IDs `[0, 1, ..., num_envs-1]` representing environments.
Main Methods
reset(env_ids=None)
- Resets rollout handlers.
- If no env_ids given → reset all environments.
- If env_ids provided → only refresh handlers for those envs.
- (Useful when some envs terminate early but others keep running.)
get_batch(num_transitions_per_env=None)
- Fills a buffer with rollout data.
- If `num_transitions_per_env=None` → returns a single transition per env.
- Else → returns a sequence of transitions of length `num_transitions_per_env`.
- Calls `fill_transition()` internally to populate the buffer.
- The first time it is called, it will automatically call `reset()`.
get_transition_batch()
- Convenience method to simulate environment stepping.
- Returns `(s, a, r, d, info)` transitions like a gym env.
- If a `"timeout"` field exists in the buffer, wraps it in `{"time_outs": buffer.timeout}` for compatibility.
Dataset Interface (`__iter__`, `__next__`)
- Allows iteration in PyTorch's `DataLoader`.
- `__iter__()` → resets the dataset.
- `__next__()` → returns the next batch (via `get_batch()`).
Abstract Methods (must be implemented in subclasses)
`reset_all()`
- Rebuild all handlers (e.g., file readers, trajectory pointers).
- Reset envs to initial states.
- Example: start reading from the first trajectory in each env.
`refresh_handlers(env_ids)`
- Reset only specific envs (e.g., when they hit the end of a trajectory).
- Useful for multi-env training where envs finish episodes at different times.
`get_buffer(num_transitions_per_env=None)`
- Allocate an empty buffer (PyTorch tensor/dict) for transitions.
- Shape depends on whether `num_transitions_per_env` is set.
`fill_transition(buffer, env_ids=None)`
- Actually load transitions from file into the buffer.
- Advance the trajectory cursor (like stepping forward in a video).
- Must include both current and next observation.
- Data format per step should be `(s, a, r, d, ...)`.
High-level role
- Provides a unified interface for trajectory loading.
- Can be used with:
- offline RL (load dataset of rollouts from disk).
- imitation learning (playback expert demonstrations).
- hybrid methods (mix real env + replay buffer + offline data).
2. rollout_dataset.py
Class Overview: RolloutDataset
- Inherits from `RolloutFileBase` (abstract base class for trajectory loaders).
- Purpose: Load, manage, and feed rollout (trajectory) data from files into training (e.g., imitation learning, RL).
- Handles multiple environments (`num_envs`), dataset looping, shuffling, and on-demand loading.
- Maintains transitions in a named tuple (`Transition`) containing: `observation`, `privileged_observation`, `action`, `reward`, `done`, `timeout`, `next_observation`, `next_privileged_observation`.
Constructor
```python
def __init__(self, data_dir, num_envs, dataset_loops=1, random_shuffle_traj_order=False, keep_latest_n_trajs=0, starting_frame_range=[0, 1], device="cuda"):
```
- Args:
  - `data_dir`: directory containing trajectories.
  - `num_envs`: number of parallel environments.
  - `dataset_loops`: how many times to loop the dataset before stopping.
  - `random_shuffle_traj_order`: whether to randomize trajectory order.
  - `keep_latest_n_trajs`: only keep the most recent N trajectories.
  - `starting_frame_range`: where to start inside a trajectory (random within range).
  - `device`: `"cuda"` or `"cpu"`.
- Initializes counters (`num_dataset_looped`) and configs.
Data Reading & Preparation
get_frame_range(filename)
- Extracts the frame index range `(start, end)` from the filename (e.g., `"traj_100_200.pkl"` → `(100, 200)`).
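Assuming the `traj_100_200.pkl` naming convention from the example above (the real scheme may differ), a regex-based sketch of this parsing step:

```python
import re

def get_frame_range(filename):
    # Hypothetical "<prefix>_<start>_<end>.pkl" convention, inferred from
    # the example filename; adjust the pattern to the real naming scheme.
    match = re.match(r".*_(\d+)_(\d+)\.pkl$", filename)
    if match is None:
        raise ValueError(f"unrecognized trajectory filename: {filename}")
    return int(match.group(1)), int(match.group(2))
```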
read_dataset_directory()
- Scans `data_dir` for trajectories (`trajectory_*` folders).
- Loads and sorts trajectories by modification time.
- Loads metadata (`metadata.json`) if present.
- Keeps track of unused trajectories and supports random shuffling.
- Returns `True` if enough data exists; otherwise waits.
assemble_obs_components(traj_data)
- Reconstructs observations from compressed components using metadata.
- Concatenates different observation parts into a full observation tensor.
Handler Management
reset_all()
- Clears all handlers.
- Ensures dataset directory is valid and trajectories exist.
- Initializes tracking structures for each environment: identifiers, file names, lengths, cursors, etc.
- Calls `refresh_handlers()` to assign initial trajectories.
_refresh_traj_data(env_idx)
- Loads a specific trajectory file for a given environment.
- Converts numpy arrays → PyTorch tensors (on `device`).
- Optionally reconstructs observations from compressed components.
_refresh_traj_handler(env_idx)
- Assigns a trajectory to an environment.
- Randomizes the starting frame (within `starting_frame_range`).
- Ensures the cursor is within a valid file.
- Marks the first frame as `done=True`.
refresh_handlers(env_ids)
- Refreshes trajectory handlers for selected envs.
- Assigns unused trajectory IDs to them.
_maintain_handler(env_idx)
- Maintains trajectory progress.
- If one trajectory finishes → loads next one.
- Handles looping if the dataset ends and `dataset_loops > 1`.
Buffer & Transition Filling
get_buffer(num_transitions_per_env=None)
- Builds an output transition buffer (a `Transition` tuple) with the required shape.
- Pre-allocates tensors for efficiency (observations, actions, rewards, dones, etc.).
- Supports both single-step and multi-step (time-major) format.
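A sketch of this pre-allocation, assuming the `Transition` fields listed earlier; the dtypes, zero-initialization, and time-major layout are assumptions based on the description:

```python
from collections import namedtuple

import torch

Transition = namedtuple("Transition", [
    "observation", "privileged_observation", "action", "reward",
    "done", "timeout", "next_observation", "next_privileged_observation",
])

def get_buffer(num_envs, obs_dim, privileged_obs_dim, action_dim,
               num_transitions_per_env=None, device="cpu"):
    # Leading time dimension only in multi-step (time-major) mode.
    lead = () if num_transitions_per_env is None else (num_transitions_per_env,)

    def zeros(dim, dtype=torch.float32):
        return torch.zeros(*lead, num_envs, dim, dtype=dtype, device=device)

    return Transition(
        observation=zeros(obs_dim),
        privileged_observation=zeros(privileged_obs_dim),
        action=zeros(action_dim),
        reward=zeros(1),
        done=zeros(1, dtype=torch.bool),
        timeout=zeros(1, dtype=torch.bool),
        next_observation=zeros(obs_dim),
        next_privileged_observation=zeros(privileged_obs_dim),
    )
```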
_fill_transition_per_env(buffer, env_idx)
- Writes a single environment’s transition into buffer.
- Handles:
- Copying observation, privileged observation, action, reward, done, timeout.
- Advancing trajectory cursor.
- Loading next trajectory when current is exhausted.
- Ensures next_observation is also filled.
fill_transition(buffer, env_ids=None)
- Iterates over environments and fills each env’s transition into the buffer.
- If `env_ids` is `None`, processes all environments.
How It Works in Training
- On reset: scans directories, loads available trajectories, sets handlers.
- On `get_batch`: requests a batch of transitions.
- On `fill_transition`: loads the actual `(s, a, r, d, next_s)` from trajectories.
- Iteratively feeds these batches into RL training.