Module: rllm.experimental.unified_trainer

The UnifiedTrainer is the central orchestrator for backend-agnostic training in rLLM.
It manages the full training loop — episode generation, data transformation, advantage
computation, policy updates, validation, and logging — while delegating all
backend-specific operations to a pluggable BackendProtocol implementation. This page
covers the technical details of the unified trainer's training loop.
The trainer itself is backend-agnostic: it knows nothing about model weights,
optimizers, or inference servers. All of that is encapsulated behind the
BackendProtocol interface (see backend protocol).
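The shape of that interface can be sketched from the pipeline stages described on this page. The following is an illustrative sketch only, assuming the method names listed in the stage table; the real `BackendProtocol` definition and its exact signatures live in the backend protocol module and may differ:

```python
# Hypothetical sketch of the BackendProtocol surface, inferred from the
# pipeline stages on this page. Signatures are assumptions, not the real API.
from typing import Any, Protocol


class BackendProtocol(Protocol):
    async def generate_episodes(self, batch: Any, is_validation: bool = False) -> list:
        """Run workflows and return Episode objects (stage 1)."""
        ...

    def transform_to_backend_batch(self, trajectory_groups: list) -> Any:
        """Convert trajectory groups to the backend-native format (stage 4)."""
        ...

    async def process_backend_batch(self, batch: Any) -> None:
        """Forward/backward pass, logprob computation, etc. (stage 5)."""
        ...

    async def compute_advantages(self, batch: Any) -> None:
        """Compute per-step advantages (stage 6)."""
        ...

    async def update_policy(self, batch: Any) -> None:
        """Apply the optimizer step (stage 7)."""
        ...
```

Because it is a `typing.Protocol`, a backend conforms structurally: it only needs to implement these methods, with no inheritance required.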
TrainerState is a mutable dataclass that serves as the shared context between the
trainer and the backend throughout a training step. It is reset at the start of each
batch via reset_batch().
```python
@dataclass
class TrainerState:
    rs_state: RejectionSamplingState  # rejection-sampling accumulator/state

    # Progress
    global_step: int = 0
    epoch: int = 0
    total_steps: int = 0
    is_training: bool = True

    # Timing and metrics (reset per batch)
    timing_dict: dict  # populated by simple_timer context managers
    metrics: dict      # logged to wandb/tracking after each batch
    extra_info: dict   # backend-private scratchpad (e.g. logprobs, LR)

    # Pipeline data (reset per batch)
    episodes: list[Episode] | None
    trajectory_groups: list[TrajectoryGroup] | None
    backend_batch: Any | None  # backend-specific format
```
Convenience properties: has_episodes, has_trajectory_groups, has_backend_batch
— used by the trainer to detect early-return conditions (e.g. all episodes filtered).
Each call to _train_batch_async executes the following stages. Stages 2, 3, and 8 are
framework-managed; the remaining stages call into the backend.
| Stage | Method / Owner | Sync/Async | Description |
|-------|----------------|------------|-------------|
| 1 | `backend.generate_episodes()` | async | Run workflows to produce `Episode` objects |
| 2 | `transform_episodes_to_trajectory_groups` | sync | Group episodes into `TrajectoryGroup`s |
| 3 | `apply_rejection_sampling_and_filtering` | sync | Filter groups (solve-all / solve-none / etc.) |
| 4 | `backend.transform_to_backend_batch()` | sync | Convert to backend-native format |
| 5 | `backend.process_backend_batch()` | async | Forward/backward pass, compute logprobs, etc. |
| 6 | `backend.compute_advantages()` | async | Compute per-step advantages |
| 7 | `backend.update_policy()` | async | Optimizer step |
| 8 | (framework) visualization + metrics | sync | Print trajectories, collect workflow metrics |
Early returns: The pipeline returns early (skipping stages 4-8) if no episodes
are generated in stage 1, or if all trajectory groups are filtered out in stage 3.
The lifecycle hooks (on_batch_end, logger.log) still execute even after an early
return.
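The stage sequence and its two early-return points can be sketched as follows. This is an illustrative outline, not the real `_train_batch_async`: the stage-2/3 helpers are stubbed out as placeholders, and the actual signatures may differ:

```python
# Illustrative sketch of the stage sequence in _train_batch_async.
# Placeholders stand in for the real framework helpers.
import asyncio


def transform_episodes_to_trajectory_groups(episodes):
    # Placeholder for stage 2: the real function groups episodes by task.
    return [episodes]


def apply_rejection_sampling_and_filtering(groups):
    # Placeholder for stage 3: the real function drops e.g. solve-all /
    # solve-none groups.
    return groups


async def train_batch(state, backend, batch):
    # Stage 1: episode generation (backend, async)
    state.episodes = await backend.generate_episodes(batch)
    if not state.episodes:
        return  # early return: no episodes generated

    # Stages 2-3: grouping and filtering (framework, sync)
    state.trajectory_groups = transform_episodes_to_trajectory_groups(state.episodes)
    state.trajectory_groups = apply_rejection_sampling_and_filtering(state.trajectory_groups)
    if not state.trajectory_groups:
        return  # early return: all trajectory groups filtered out

    # Stages 4-7: delegated to the backend
    state.backend_batch = backend.transform_to_backend_batch(state.trajectory_groups)
    await backend.process_backend_batch(state.backend_batch)
    await backend.compute_advantages(state.backend_batch)
    await backend.update_policy(state.backend_batch)
    # Stage 8 (visualization + metrics) runs here in the framework.
```

Note that both early returns fall through to the same lifecycle hooks (on_batch_end, logger.log) in the real trainer, so logging stays consistent even when a batch produces no update.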
Validation runs:
- Before training (if config.rllm.trainer.val_before_train is true)
- Periodically during training (every config.rllm.trainer.test_freq steps)
- After training completes (only when test_freq > 0)
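The schedule above can be made concrete with a small helper. This is a hedged sketch: the function name is invented for illustration, and only the two config keys (val_before_train, test_freq) come from the source:

```python
# Hypothetical helper (not part of rLLM) that enumerates the global steps at
# which validation would fire, per the schedule described above.
def validation_steps(total_steps: int, val_before_train: bool, test_freq: int) -> list[int]:
    """Return global steps at which validation runs; 0 means before training."""
    steps: list[int] = []
    if val_before_train:
        steps.append(0)  # validation before any training step
    if test_freq > 0:
        # Periodic validation every test_freq steps.
        steps.extend(range(test_freq, total_steps + 1, test_freq))
        # Final validation after training completes (only when test_freq > 0).
        if total_steps not in steps:
            steps.append(total_steps)
    return steps
```

For example, with 10 total steps, val_before_train enabled, and test_freq of 4, validation would run at steps 0, 4, 8, and 10.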
The validation loop calls backend.generate_episodes(..., is_validation=True),
transforms the results, and computes reward metrics (no advantage computation or
policy updates). Pass@1 and pass@K metrics are computed per data source and logged.

The backend can control validation via hooks:
- on_validation_start returns a bool — return False to skip validation entirely
- on_validation_end is called when validation actually runs (i.e. not skipped)
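The hook contract can be illustrated with a minimal sketch. The hook names come from the source; the surrounding driver function and the example backend are assumptions for illustration only:

```python
# Hypothetical backend that skips every second validation request, plus a
# minimal driver showing the hook contract: on_validation_start gates the run,
# on_validation_end fires only when validation actually executed.
class SkipEveryOtherValidation:
    def __init__(self):
        self._requests = 0
        self.completed = 0

    def on_validation_start(self) -> bool:
        self._requests += 1
        # Return False to skip validation entirely on even-numbered requests.
        return self._requests % 2 == 1

    def on_validation_end(self) -> None:
        # Only called when validation was not skipped.
        self.completed += 1


def maybe_validate(backend, run_validation):
    """Illustrative trainer-side logic (not the real trainer code)."""
    if not backend.on_validation_start():
        return None  # backend declined; skip validation entirely
    metrics = run_validation()
    backend.on_validation_end()
    return metrics
```

This gives backends a cheap way to implement custom validation policies (e.g. skipping validation while rejection-sampling statistics are still warming up) without changes to the trainer itself.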