The trainer module provides high-level APIs for training agents using reinforcement learning with PPO.

AgentTrainer

Backend-agnostic wrapper for training agents with custom workflows or AgentFlows. As of the unified-trainer graduation, rllm.trainer.AgentTrainer resolves to rllm.trainer.unified_trainer.AgentTrainer.
from rllm.trainer import AgentTrainer
For the full constructor reference (including sandbox auto-wiring, agent_flow / evaluator / hooks plumbing, and the verl / tinker backends), see the unified trainer page. A condensed view:

Constructor

def __init__(
    self,
    config: DictConfig,
    workflow_class: type[Workflow] | None = None,
    train_dataset: Dataset | None = None,
    val_dataset: Dataset | None = None,
    workflow_args: dict | None = None,
    backend: Literal["verl", "tinker"] = "verl",
    agent_flow: Any = None,
    evaluator: Any = None,
    hooks: Any = None,
    sandbox_backend: str | None = None,
    sandbox_concurrency: int | None = None,
    store: Store | None = None,
)
Two ways to plug in your agent: pass either workflow_class (a Workflow subclass) or agent_flow (an AgentFlow built with the AgentSdk).
workflow_class (type[Workflow] | None)
Workflow class to use for training (e.g., SimpleWorkflow, MultiTurnWorkflow).
workflow_args (dict | None)
Arguments to pass to the workflow class.
config (dict | list[str] | None)
Training configuration. The examples on this page pass the Hydra config produced by @hydra.main. Overrides can also be given as:
  • Dictionary with dot-notation keys: {"data.train_batch_size": 8}
  • List of strings: ["data.train_batch_size=8", "trainer.total_epochs=3"]
train_dataset (Dataset | None)
Training dataset.
val_dataset (Dataset | None)
Validation dataset.
backend (Literal["verl", "tinker"], default: "verl")
Training backend:
  • "verl": distributed PPO via the verl framework
  • "tinker": LoRA training via tinker
agent_flow (Any | None)
AgentFlow object (from the AgentSdk path). Use this or workflow_class, not both.
The verl-specific legacy trainer lives at rllm.trainer.agent_trainer.AgentTrainer. New code should use the unified trainer, imported with from rllm.trainer import AgentTrainer.
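The two override forms accepted by config carry the same information: each dotted key names a path into the nested config tree. The stdlib-only sketch below illustrates that mapping. It is not rLLM's actual override parser; the helper names parse_value, normalize_overrides, and to_nested are hypothetical, and it assumes simple literal values (numbers, booleans, bare strings).

```python
from __future__ import annotations

import ast
from typing import Any


def parse_value(text: str) -> Any:
    """Interpret a CLI-style value: numbers/booleans become Python values, other text stays a string."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    try:
        # literal_eval safely handles ints, floats and quoted strings
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        # e.g. "grpo" or "Qwen/Qwen3-4B" are not Python literals
        return text


def normalize_overrides(overrides: dict[str, Any] | list[str]) -> dict[str, Any]:
    """Return a flat {dotted_key: value} dict from either override form."""
    if isinstance(overrides, dict):
        return dict(overrides)
    flat: dict[str, Any] = {}
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        flat[key.strip()] = parse_value(raw.strip())
    return flat


def to_nested(flat: dict[str, Any]) -> dict[str, Any]:
    """Expand dotted keys into nested dicts: {"a.b": 1} -> {"a": {"b": 1}}."""
    nested: dict[str, Any] = {}
    for dotted, value in flat.items():
        *parents, leaf = dotted.split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


if __name__ == "__main__":
    as_dict = {"data.train_batch_size": 8, "trainer.total_epochs": 3}
    as_list = ["data.train_batch_size=8", "trainer.total_epochs=3"]
    # Both forms describe the same nested config
    assert to_nested(normalize_overrides(as_dict)) == to_nested(normalize_overrides(as_list))
    print(to_nested(normalize_overrides(as_list)))
    # {'data': {'train_batch_size': 8}, 'trainer': {'total_epochs': 3}}
```

A dotted key addresses exactly one leaf, which is why the dictionary and list forms are interchangeable.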

Methods

train

Start the training process.
trainer.train()

Configuration

The trainer uses Hydra for configuration management. The default config is rllm/trainer/config/agent_ppo_trainer.yaml.

Common Config Overrides

config = {
    # Data settings
    "data.train_batch_size": 512,
    "data.val_batch_size": 1024,
    
    # Training settings
    "trainer.total_epochs": 3,
    "trainer.total_training_steps": 1000,
    
    # PPO hyperparameters
    "algorithm.gamma": 1.0,
    "algorithm.lam": 0.95,
    "algorithm.kl_penalty": 0.001,
    
    # Model settings
    "actor_rollout_ref.model.path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    
    # GRPO settings
    "algorithm.adv_estimator": "grpo",
    "algorithm.num_samples_per_prompt": 4,
}
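The algorithm.* keys above control how advantages are computed. With PPO's GAE estimator, algorithm.gamma (discount) and algorithm.lam (GAE lambda) drive the backward recursion delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), A_t = delta_t + gamma * lam * A_{t+1}. With adv_estimator set to "grpo" there is no critic; instead each response's reward is compared against the other num_samples_per_prompt responses sampled for the same prompt. The pure-Python sketch below shows both ideas on plain lists. It is an illustration only, not the verl or rLLM implementation (which works on batched token-level tensors); the function names are hypothetical, and normalizing by the group's sample standard deviation plus a small epsilon is one common GRPO variant.

```python
from __future__ import annotations

from statistics import mean, stdev


def gae_advantages(rewards: list[float], values: list[float], gamma: float, lam: float) -> list[float]:
    """Generalized Advantage Estimation for one finished trajectory.

    values[t] is the critic's estimate V(s_t); the value after the final step is 0 because the episode ends.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD residual
        running = delta + gamma * lam * running  # A_t = delta_t + gamma * lam * A_{t+1}
        advantages[t] = running
    return advantages


def grpo_advantages(group_rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Group-relative advantages for the N responses sampled from one prompt (no critic)."""
    mu = mean(group_rewards)
    # stdev needs at least two samples; a single response gets zero advantage
    sigma = stdev(group_rewards) if len(group_rewards) > 1 else 0.0
    return [(r - mu) / (sigma + eps) for r in group_rewards]


if __name__ == "__main__":
    # Sparse reward at the last step, as in a single-answer math task
    print(gae_advantages([0.0, 0.0, 1.0], [0.5, 0.5, 0.5], gamma=1.0, lam=0.95))
    # num_samples_per_prompt = 4: two correct and two incorrect answers
    print(grpo_advantages([1.0, 0.0, 0.0, 1.0]))  # roughly [0.87, -0.87, -0.87, 0.87]
```

If every response in a group receives the same reward, all of its GRPO advantages are zero and that prompt contributes no learning signal; with binary rewards, a larger num_samples_per_prompt makes such all-equal groups less likely.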

Example: Training with SimpleWorkflow

import hydra
from rllm.trainer import AgentTrainer
from rllm.workflows import SimpleWorkflow
from rllm.rewards import math_reward_fn
from rllm.data import DatasetRegistry

@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None
)
def main(config):
    # Load datasets
    train_dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
    val_dataset = DatasetRegistry.load_dataset("math500", "test")
    
    # Create trainer
    trainer = AgentTrainer(
        workflow_class=SimpleWorkflow,
        workflow_args={
            "reward_function": math_reward_fn,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        backend="verl"
    )
    
    # Start training
    trainer.train()

if __name__ == "__main__":
    main()

Example: Config Overrides

import hydra
from omegaconf import OmegaConf
from rllm.trainer import AgentTrainer
from rllm.workflows import MultiTurnWorkflow
from rllm.data import DatasetRegistry

# Hypothetical module: substitute your own BaseAgent / BaseEnv subclasses.
from my_project.agents import MyAgent, MyEnv

@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None
)
def main(config):
    # Override config values using dot-notation keys
    config_overrides = {
        "data.train_batch_size": 256,
        "data.val_batch_size": 512,
        "trainer.total_epochs": 5,
        "algorithm.gamma": 1.0,
        "algorithm.num_samples_per_prompt": 8,
        "actor_rollout_ref.model.path": "Qwen/Qwen3-4B",
    }
    
    # Apply overrides: OmegaConf.update resolves the dotted path.
    # Hydra configs are in struct mode, so each key must already exist.
    for key, value in config_overrides.items():
        OmegaConf.update(config, key, value)
    
    trainer = AgentTrainer(
        workflow_class=MultiTurnWorkflow,
        workflow_args={
            "agent_cls": MyAgent,
            "env_cls": MyEnv,
            "max_steps": 5,
        },
        config=config,
        train_dataset=DatasetRegistry.load_dataset("mydata", "train"),
        val_dataset=DatasetRegistry.load_dataset("mydata", "val"),
        backend="verl",
    )
    
    trainer.train()

if __name__ == "__main__":
    main()

Running Training

Run training scripts with Hydra CLI overrides:
# Basic training
python train.py

# Override config from command line
python train.py data.train_batch_size=512 trainer.total_epochs=5

# Use different model
python train.py actor_rollout_ref.model.path=Qwen/Qwen3-4B

# Adjust PPO hyperparameters
python train.py algorithm.gamma=1.0 algorithm.num_samples_per_prompt=8
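Because each key in an override dictionary maps one-to-one to a key=value argument, launching runs from Python (for example, a small sweep) only requires formatting the dictionary as Hydra CLI arguments. A minimal sketch, assuming train.py is your own entry point decorated with @hydra.main as in the examples above and that values are simple scalars; strings containing spaces or Hydra-special characters would need extra quoting that is not handled here. build_train_command and format_value are hypothetical helpers.

```python
from __future__ import annotations

import shlex
from typing import Any


def format_value(value: Any) -> str:
    """Render a Python scalar the way it would be typed in a Hydra override."""
    if isinstance(value, bool):  # check bool before int: bool is a subclass of int
        return "true" if value else "false"
    return str(value)


def build_train_command(script: str, overrides: dict[str, Any]) -> list[str]:
    """Turn a dot-notation override dict into argv for `python <script> key=value ...`."""
    return ["python", script, *(f"{key}={format_value(val)}" for key, val in overrides.items())]


if __name__ == "__main__":
    argv = build_train_command("train.py", {"data.train_batch_size": 512, "trainer.total_epochs": 5})
    print(shlex.join(argv))
    # python train.py data.train_batch_size=512 trainer.total_epochs=5
    # To actually launch it: subprocess.run(argv, check=True)
```

Keys that are not already present in the loaded config must be prefixed with + on the Hydra command line (for example +my_group.new_key=1); otherwise Hydra rejects the override.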