Module: rllm.experimental.unified_trainer
The UnifiedTrainer is the central orchestrator for backend-agnostic training in rLLM. It manages the full training loop — episode generation, data transformation, advantage computation, policy updates, validation, and logging — while delegating all backend-specific operations to a pluggable BackendProtocol implementation. This page documents that training loop in technical detail.

Architecture overview

                         UnifiedTrainer
                +-----------+-----------+
                |                       |
        BackendProtocol         UnifiedWorkflowEngine
        (e.g. TinkerBackend)    (manages Workflow pool)
                |                       |
        Backend-specific        Workflow instances
        infra (model, optim)    (rollout logic)
The trainer itself is backend-agnostic: it knows nothing about model weights, optimizers, or inference servers. All of that is encapsulated behind the BackendProtocol interface (see backend protocol).
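That split can be made concrete with a structural type. The sketch below is illustrative only: the method names are the ones referenced on this page, but the signatures are assumptions, not the actual BackendProtocol definition.

```python
# Illustrative structural type for the backend boundary. Method names follow
# this page; the signatures are assumptions, not the real BackendProtocol.
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class BackendProtocolSketch(Protocol):
    def init_rollout_engine(self) -> None: ...
    async def generate_episodes(self, batch: Any, is_validation: bool = False) -> list: ...
    def transform_to_backend_batch(self, trajectory_groups: list) -> Any: ...
    async def process_backend_batch(self, batch: Any) -> None: ...
    async def compute_advantages(self, batch: Any) -> None: ...
    async def update_policy(self, batch: Any) -> None: ...
    def shutdown(self) -> None: ...
```

Because the trainer only talks to this surface, swapping TinkerBackend for a custom class is a constructor argument, not a code change in the loop.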

Entry points

There are two ways to start training.

AgentTrainer

A convenience wrapper that selects the correct TrainerLauncher for the backend string ("verl" or "tinker") and handles environment setup:
from rllm.experimental.unified_trainer import AgentTrainer

trainer = AgentTrainer(
    config=config,
    workflow_class=MyWorkflow,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    backend="tinker",
)
trainer.train()

UnifiedTrainer (direct, for custom backends)

When using a custom backend class, instantiate the trainer directly:
from rllm.experimental.unified_trainer import UnifiedTrainer

trainer = UnifiedTrainer(
    backend_cls=MyCustomBackend,
    config=config,
    workflow_class=MyWorkflow,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    workflow_args={"my_param": 42},
    backend_args={"device": "cuda:0"},
)
trainer.fit()
Constructor parameters:
| Parameter | Type | Description |
|---|---|---|
| backend_cls | type[BackendProtocol] | The backend class to instantiate |
| config | DictConfig | Full Hydra config (must contain config.rllm) |
| workflow_class | type[Workflow] | The workflow class for episode generation |
| train_dataset | Dataset \| None | Training dataset |
| val_dataset | Dataset \| None | Validation dataset |
| workflow_args | dict \| None | Extra kwargs passed to each workflow instance |
| backend_args | dict \| None | Extra kwargs passed to the backend constructor |
| traj_grouping_hook | Callable \| None | Optional custom episode-to-trajectory-group hook |
| traj_group_adv_estimator_map | dict \| None | Optional per-role advantage estimator override map |
| **kwargs | Any | Forwarded to the selected launcher/backend |
When traj_group_adv_estimator_map is provided, rllm.algorithm.use_rllm must be true.

TrainerState

A mutable dataclass that serves as the shared context between the trainer and the backend throughout a training step. It is reset at the start of each batch via reset_batch().
@dataclass
class TrainerState:
    rs_state: RejectionSamplingState  # rejection-sampling accumulator/state

    # Progress
    global_step: int = 0
    epoch: int = 0
    total_steps: int = 0
    is_training: bool = True

    # Timing and metrics (reset per batch)
    timing_dict: dict = field(default_factory=dict)  # populated by simple_timer context managers
    metrics: dict = field(default_factory=dict)      # logged to wandb/tracking after each batch
    extra_info: dict = field(default_factory=dict)   # backend-private scratchpad (e.g. logprobs, LR)

    # Pipeline data (reset per batch)
    episodes: list[Episode] | None = None
    trajectory_groups: list[TrajectoryGroup] | None = None
    backend_batch: Any | None = None                 # backend-specific format
Convenience properties: has_episodes, has_trajectory_groups, has_backend_batch — used by the trainer to detect early-return conditions (e.g. all episodes filtered).

Initialization sequence

When the UnifiedTrainer is constructed, the following happens in order:
| Step | Method | Description |
|---|---|---|
| 1 | backend_cls(config, ...) | Instantiate the backend |
| 2 | _validate_and_setup_configs() | Build AlgorithmConfig, TransformConfig, etc. |
| 3 | _setup_logging() | Init tracking and optional episode logger |
| 4 | backend.init_rollout_engine() | Backend creates its RolloutEngine |
| 5 | UnifiedWorkflowEngine(...) | Create the workflow engine (workflow class + args + rollout engine) |
The workflow pool initialization (initialize_pool) happens when fit_async() starts, not in the constructor.
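The ordering can be illustrated with stubs; every component below is faked, and only the call sequence mirrors the table above:

```python
# Stubbed illustration of the constructor ordering; the private method names
# come from the table above, their bodies here only record call order.
class UnifiedTrainerInitSketch:
    def __init__(self, backend_cls, config):
        self.calls: list[str] = []
        self.backend = backend_cls(config)       # 1. instantiate the backend
        self.calls.append("backend_cls")
        self._validate_and_setup_configs()       # 2. build derived configs
        self._setup_logging()                    # 3. tracking + episode logger
        self.backend.init_rollout_engine()       # 4. backend builds its RolloutEngine
        self.calls.append("init_rollout_engine")
        self.workflow_engine = object()          # 5. UnifiedWorkflowEngine(...) stand-in;
        self.calls.append("workflow_engine")     #    pool init is deferred to fit_async()

    def _validate_and_setup_configs(self):
        self.calls.append("setup_configs")

    def _setup_logging(self):
        self.calls.append("setup_logging")
```

The key ordering constraint is that the rollout engine must exist (step 4) before the workflow engine is built around it (step 5).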

Training loop

Calling trainer.fit() runs fit_async() via asyncio.run(...). At a high level, fit_async() runs an optional initial validation pass and then hands off to _fit_async(), which iterates over epochs and batches.
If val_before_train=true and val_only=true, training returns after the initial validation and never enters _fit_async().
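A hedged sketch of that control flow; validate() is a stand-in name for the real validation entry point, which this page does not name:

```python
# Sketch of the top-level flow: optional initial validation, optional
# val-only early exit, then the training loop. `validate` is a stand-in name.
import asyncio


async def fit_async_sketch(trainer, val_before_train: bool, val_only: bool) -> str:
    if val_before_train:
        await trainer.validate()          # initial validation pass
        if val_only:
            return "validated-only"       # never enters _fit_async()
    await trainer._fit_async()            # epoch/batch training loop
    return "trained"
```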

The 8-stage batch pipeline

Each call to _train_batch_async executes the following stages. Stages 1-3 are framework-managed. Stages 4-7 are delegated to the backend.
| Stage | Method / Owner | Sync/Async | Description |
|---|---|---|---|
| 1 | backend.generate_episodes() | async | Run workflows to produce Episode objects |
| 2 | transform_episodes_to_trajectory_groups | sync | Group episodes into TrajectoryGroups |
| 3 | apply_rejection_sampling_and_filtering | sync | Filter groups (solve-all / solve-none / etc.) |
| 4 | backend.transform_to_backend_batch() | sync | Convert to backend-native format |
| 5 | backend.process_backend_batch() | async | Forward/backward pass, compute logprobs, etc. |
| 6 | backend.compute_advantages() | async | Compute per-step advantages |
| 7 | backend.update_policy() | async | Optimizer step |
| 8 | (framework) visualization + metrics | sync | Print trajectories, collect workflow metrics |
Early returns: The pipeline returns early (skipping stages 4-8) if no episodes are generated in stage 1, or if all trajectory groups are filtered out in stage 3. The lifecycle hooks (on_batch_end, logger.log) still execute even after an early return.
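The stage table and its early-return rules can be sketched as a single coroutine. The stages are stubbed here; only the control flow is meant to mirror the real _train_batch_async:

```python
# Control-flow sketch of the 8-stage pipeline. `transform` and
# `reject_filter` stand in for the framework-owned sync stages 2-3.
import asyncio


async def train_batch_sketch(backend, state: dict, transform, reject_filter) -> str:
    state["episodes"] = await backend.generate_episodes()                   # stage 1 (async)
    if not state["episodes"]:
        return "no-episodes"                                                # early return; hooks/logging still fire
    state["trajectory_groups"] = transform(state["episodes"])               # stage 2 (sync)
    state["trajectory_groups"] = reject_filter(state["trajectory_groups"])  # stage 3 (sync)
    if not state["trajectory_groups"]:
        return "all-filtered"                                               # early return
    batch = backend.transform_to_backend_batch(state["trajectory_groups"])  # stage 4 (sync)
    await backend.process_backend_batch(batch)                              # stage 5 (async)
    await backend.compute_advantages(batch)                                 # stage 6 (async)
    await backend.update_policy(batch)                                      # stage 7 (async)
    return "updated"                                                        # stage 8 then logs metrics/trajectories
```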

Validation loop

Validation is triggered:
  • Before training (if config.rllm.trainer.val_before_train is true)
  • Periodically during training (every config.rllm.trainer.test_freq steps)
  • After training completes (only when test_freq > 0)
The validation loop calls backend.generate_episodes(..., is_validation=True), transforms the results, and computes reward metrics (no advantage computation or policy updates). Pass@1 and pass@K metrics are computed per data source and logged. The backend can control validation via hooks:
  • on_validation_start returns a bool — return False to skip validation entirely
  • on_validation_end is called when validation actually runs (i.e. not skipped)
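This page does not spell out the pass@K formula. A common choice, shown here as an assumption rather than rLLM's documented behavior, is the unbiased estimator 1 - C(n-c, K)/C(n, K) over a group of n validation rollouts with c successes:

```python
# Standard unbiased pass@k estimator (an assumption here, not confirmed as
# rLLM's exact formula): probability that at least one of k samples drawn
# without replacement from n rollouts is correct.
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    if n - c < k:
        return 1.0            # too few failures to fill k slots: guaranteed hit
    return 1.0 - comb(n - c, k) / comb(n, k)
```

With K=1 this reduces to the plain success rate c/n, which matches the pass@1 metric.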

Configuration

The trainer reads configuration from config.rllm (the rLLM sub-config within the full Hydra config). Key config groups:
| Config path | Built config | Used for |
|---|---|---|
| rllm.compact_filtering | CompactFilteringConfig | Filter invalid episodes |
| rllm.stepwise_advantage | TransformConfig | Episode-to-group transform mode |
| rllm.rejection_sample | RejectionSamplingConfig | Rejection sampling settings |
| rllm.algorithm | AlgorithmConfig | Advantage estimator, loss fn, LR schedule |
| rllm.workflow | (direct) | n_parallel_tasks, retry_limit, raise_on_error |
| rllm.trainer | (direct) | total_epochs, total_batches, test_freq, save_freq, logger |
| rllm.rollout | (direct) | n (group size), n_val (val samples per task) |
For the full configuration reference, see configuration.
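Putting the table together, the rllm sub-config has roughly this shape. The keys come from the table above; the nested field names and values below are illustrative placeholders, not documented defaults:

```python
# Illustrative shape of the `rllm` sub-config. Top-level keys follow the
# table above; nested values are placeholder assumptions for illustration.
rllm_config_sketch = {
    "compact_filtering": {},                                 # -> CompactFilteringConfig
    "stepwise_advantage": {"mode": "broadcast"},             # -> TransformConfig
    "rejection_sample": {},                                  # -> RejectionSamplingConfig
    "algorithm": {"estimator": "GRPO", "use_rllm": False},   # -> AlgorithmConfig
    "workflow": {"n_parallel_tasks": 8, "retry_limit": 3, "raise_on_error": False},
    "trainer": {"total_epochs": 1, "total_batches": None, "test_freq": 10,
                "save_freq": 50, "logger": "console"},
    "rollout": {"n": 8, "n_val": 1},
}
```

Under Hydra, this dict corresponds to the config.rllm node of the composed DictConfig that the trainer requires.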

AlgorithmConfig fields

| Field | Type | Default | Description |
|---|---|---|---|
| estimator | rLLMAdvantageEstimator | GRPO | Advantage estimator (GRPO, REINFORCE, REINFORCE_PLUS_PLUS_BASELINE, RLOO) |
| estimator_map | dict[str, rLLMAdvantageEstimator \| str] | {} | Per-role estimator override map (set by traj_group_adv_estimator_map) |
| stepwise_advantage_mode | "broadcast" | "broadcast" | How advantages map to steps |
| norm_adv_by_std_in_grpo | bool | True | Normalize advantages by std in GRPO |
| use_rllm | bool | False | Whether to use the rLLM-native advantage path (relevant for the Verl backend) |
| use_precomputed_advantage | bool | False | Reuse pre-computed step.advantage from the workflow instead of recomputing |
| loss_fn | str \| None | None | Backend loss function (e.g. "importance_sampling") |
| lr_schedule | str | "constant" | LR schedule: "constant", "linear", "cosine" |
| warmup_steps_ratio | float | 0.0 | Fraction of total steps for LR warmup |
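One plausible reading of lr_schedule combined with warmup_steps_ratio, written as a sketch rather than rLLM's actual implementation: linear warmup over warmup_steps_ratio * total_steps, then the named decay.

```python
# Hypothetical LR multiplier combining warmup_steps_ratio with the three
# schedule names from the table; a sketch, not rLLM's implementation.
import math


def lr_multiplier(step: int, total_steps: int, schedule: str = "constant",
                  warmup_steps_ratio: float = 0.0) -> float:
    """Return the LR scale factor at `step` (0-indexed)."""
    warmup_steps = int(warmup_steps_ratio * total_steps)
    if warmup_steps > 0 and step < warmup_steps:
        return (step + 1) / warmup_steps               # linear warmup toward 1.0
    # progress through the post-warmup phase, in [0, 1]
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    if schedule == "constant":
        return 1.0
    if schedule == "linear":
        return max(0.0, 1.0 - progress)                # linear decay to 0
    if schedule == "cosine":
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    raise ValueError(f"unknown schedule: {schedule}")
```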

Async design

The trainer uses an async-prioritized design:
  • fit() is the sync entry point and runs fit_async() via asyncio.run(...)
  • fit_async() is available directly if you are already in an async context
  • The pipeline mixes async and sync steps:
    • async: generate_episodes, process_backend_batch, compute_advantages, update_policy
    • sync: transformation/rejection-sampling steps, dataloader access, logging, visualization
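The sync/async entry-point relationship can be sketched as:

```python
# Minimal illustration of the fit()/fit_async() relationship: a sync
# wrapper around an async training loop (the loop body is stubbed).
import asyncio


class TrainerEntrySketch:
    async def fit_async(self) -> str:
        await asyncio.sleep(0)            # stand-in for the real async training loop
        return "done"

    def fit(self) -> str:
        # Sync entry point; delegates to the async one.
        return asyncio.run(self.fit_async())
```

Note that asyncio.run(...) raises RuntimeError when called from inside a running event loop, which is why fit_async() is exposed directly for callers already in async code.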

Shutdown

Always call trainer.shutdown() when done (or use a try/finally block). This:
  1. Shuts down the workflow engine
  2. Calls backend.shutdown() for backend-specific cleanup
  3. Calls logger.finish() to flush and close the tracking backend (e.g. wandb)
trainer = UnifiedTrainer(...)
try:
    trainer.fit()
finally:
    trainer.shutdown()