The trainer module provides high-level APIs for training agents using reinforcement learning with PPO.
AgentTrainer
Backend-agnostic wrapper for training agents with custom workflows or AgentFlows. As of the unified-trainer graduation, rllm.trainer.AgentTrainer resolves to rllm.trainer.unified_trainer.AgentTrainer.
from rllm.trainer import AgentTrainer
For the full constructor reference (including sandbox auto-wiring, agent_flow / evaluator / hooks plumbing, and the verl / tinker backends), see the unified trainer page. A condensed view:
Constructor
def __init__(
    self,
    config: DictConfig,
    workflow_class: type[Workflow] | None = None,
    train_dataset: Dataset | None = None,
    val_dataset: Dataset | None = None,
    workflow_args: dict | None = None,
    backend: Literal["verl", "tinker"] = "verl",
    agent_flow: Any = None,
    evaluator: Any = None,
    hooks: Any = None,
    sandbox_backend: str | None = None,
    sandbox_concurrency: int | None = None,
    store: Store | None = None,
)
Two ways to plug in your agent: pass either workflow_class (a Workflow subclass) or agent_flow (an AgentFlow built with the AgentSdk).
Parameters
- workflow_class (type[Workflow] | None): Workflow class to use for training (e.g., SimpleWorkflow, MultiTurnWorkflow).
- workflow_args (dict | None): Arguments to pass to the workflow class.
- config: Configuration overrides. Can be:
  - A dictionary with dot-notation keys: {"data.train_batch_size": 8}
  - A list of strings: ["data.train_batch_size=8", "trainer.total_epochs=3"]
- backend (Literal['verl', 'tinker'], default "verl"): Training backend:
  - "verl": distributed PPO via the verl framework
  - "tinker": LoRA training via tinker
- agent_flow (Any): AgentFlow object (from the AgentSdk path). Use this or workflow_class, not both.
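The constraints on these parameters (exactly one of workflow_class or agent_flow, and a backend of "verl" or "tinker") can be written as a small pre-flight check. This is a hypothetical helper, not part of rllm, and it assumes that passing neither argument is an error, which the text implies but does not state outright; the real constructor may validate these arguments differently.

```python
from typing import Any, Literal, get_args

# Mirrors the documented backend type; the helper below is hypothetical.
Backend = Literal["verl", "tinker"]


def check_trainer_args(
    workflow_class: Any = None,
    agent_flow: Any = None,
    backend: str = "verl",
) -> str:
    """Hypothetical check; returns which agent source is used."""
    allowed = get_args(Backend)  # ("verl", "tinker")
    if backend not in allowed:
        raise ValueError(f"backend must be one of {allowed}, got {backend!r}")
    # Exactly one way of plugging in the agent.
    if workflow_class is not None and agent_flow is not None:
        raise ValueError("pass workflow_class or agent_flow, not both")
    if workflow_class is None and agent_flow is None:
        raise ValueError("pass one of workflow_class or agent_flow")
    return "workflow_class" if workflow_class is not None else "agent_flow"


print(check_trainer_args(workflow_class=object, backend="tinker"))  # workflow_class
```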
The verl-specific legacy trainer lives at rllm.trainer.agent_trainer.AgentTrainer. New code should prefer the unified trainer, imported with from rllm.trainer import AgentTrainer.
Methods
train
Start the training process.
Configuration
The trainer uses Hydra for configuration management. The default config is at rllm/trainer/config/agent_ppo_trainer.yaml.
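Each dot-notation key names a path into the nested YAML config: data.train_batch_size refers to train_batch_size under the data: section. Hydra and OmegaConf resolve these paths for you; the stdlib sketch below only illustrates the mapping, and expand_dotted is a hypothetical helper name.

```python
from typing import Any


def expand_dotted(flat: dict[str, Any]) -> dict[str, Any]:
    """Turn {"a.b": 1} into {"a": {"b": 1}} (hypothetical helper)."""
    nested: dict[str, Any] = {}
    for dotted_key, value in flat.items():
        *parents, leaf = dotted_key.split(".")
        node = nested
        # Walk (or create) one dict level per path segment.
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


print(expand_dotted({"data.train_batch_size": 8, "trainer.total_epochs": 3}))
# {'data': {'train_batch_size': 8}, 'trainer': {'total_epochs': 3}}
```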
Common Config Overrides
config = {
    # Data settings
    "data.train_batch_size": 512,
    "data.val_batch_size": 1024,
    # Training settings
    "trainer.total_epochs": 3,
    "trainer.total_training_steps": 1000,
    # PPO hyperparameters
    "algorithm.gamma": 1.0,
    "algorithm.lam": 0.95,
    "algorithm.kl_ctrl.kl_coef": 0.001,  # KL coefficient; algorithm.kl_penalty selects the KL estimator (e.g. "kl"), not its weight
    # Model settings
    "actor_rollout_ref.model.path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    # GRPO settings
    "algorithm.adv_estimator": "grpo",
    "algorithm.num_samples_per_prompt": 4,
}
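The PPO settings above include the discount gamma and the GAE parameter lam, while adv_estimator "grpo" switches to group-relative advantages computed over the samples drawn for each prompt. The following is a textbook sketch of both estimators in plain Python, not the verl or rllm implementation: gae_advantages assumes one finished trajectory with per-step rewards and value estimates, and grpo_advantages assumes the rewards of all samples for a single prompt form one group. In this sketch GRPO uses neither gamma nor lam.

```python
def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    """Generalized Advantage Estimation for one finished trajectory.

    delta_t = r_t + gamma * V_{t+1} - V_t
    A_t     = delta_t + gamma * lam * A_{t+1}
    """
    advantages = [0.0] * len(rewards)
    next_value = 0.0  # value after the last step is 0: the episode has ended
    running = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]
        running = delta + gamma * lam * running
        advantages[t] = running
        next_value = values[t]
    return advantages


def grpo_advantages(group_rewards, eps=1e-6):
    """Group-relative advantage: standardize rewards within one prompt's samples."""
    n = len(group_rewards)
    mean = sum(group_rewards) / n
    std = (sum((r - mean) ** 2 for r in group_rewards) / n) ** 0.5
    return [(r - mean) / (std + eps) for r in group_rewards]


# Four samples for one prompt (num_samples_per_prompt=4), two of them correct.
print(grpo_advantages([1.0, 0.0, 0.0, 1.0]))
print(gae_advantages([0.0, 0.0, 1.0], [0.0, 0.0, 0.0], gamma=1.0, lam=0.95))
```

The population standard deviation and the eps value are choices of this sketch; real implementations may use a different estimator or epsilon.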
Example: Training with SimpleWorkflow
import hydra
from rllm.trainer import AgentTrainer
from rllm.workflows import SimpleWorkflow
from rllm.rewards import math_reward_fn
from rllm.data import DatasetRegistry


@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None,
)
def main(config):
    # Load datasets
    train_dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
    val_dataset = DatasetRegistry.load_dataset("math500", "test")

    # Create trainer
    trainer = AgentTrainer(
        workflow_class=SimpleWorkflow,
        workflow_args={
            "reward_function": math_reward_fn,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        backend="verl",
    )

    # Start training
    trainer.train()


if __name__ == "__main__":
    main()
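The example passes math_reward_fn to the workflow through workflow_args. This page does not document that function's signature, so the following is only a hypothetical, self-contained illustration of an outcome reward of that kind: toy_math_reward and last_boxed are invented names, the (response, ground_truth) -> float shape is an assumption rather than rllm's interface, and it assumes the model writes its final answer inside \boxed{...}.

```python
from typing import Optional

BOXED = "\\boxed{"


def last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, handling nested braces."""
    start = text.rfind(BOXED)
    if start == -1:
        return None
    begin = start + len(BOXED)
    depth = 1
    for i in range(begin, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[begin:i]
    return None  # unbalanced braces: no complete answer


def _normalize(s: str) -> str:
    # Drop all whitespace so "1 / 2" and "1/2" compare equal.
    return "".join(s.split())


def toy_math_reward(response: str, ground_truth: str) -> float:
    """Hypothetical reward: 1.0 if the last boxed answer matches ground_truth, else 0.0."""
    answer = last_boxed(response)
    if answer is None:
        return 0.0
    return 1.0 if _normalize(answer) == _normalize(ground_truth) else 0.0


print(toy_math_reward(r"So the answer is \boxed{\frac{1}{2}}.", r"\frac{1}{2}"))  # 1.0
```

Exact string matching is deliberately simplistic: it scores 0.5 and \frac{1}{2} as different answers.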
Example: Config Overrides
import hydra
from rllm.trainer import AgentTrainer
from rllm.workflows import MultiTurnWorkflow
from rllm.data import DatasetRegistry

# Hypothetical module: substitute your own BaseAgent / BaseEnv subclasses.
from my_project.agents import MyAgent, MyEnv


@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None,
)
def main(config):
    # Override config values
    config_overrides = {
        "data.train_batch_size": 256,
        "data.val_batch_size": 512,
        "trainer.total_epochs": 5,
        "algorithm.gamma": 1.0,
        "algorithm.num_samples_per_prompt": 8,
        "actor_rollout_ref.model.path": "Qwen/Qwen3-4B",
    }

    # Apply overrides: walk to the parent node of each dotted key, then set the leaf
    for key, value in config_overrides.items():
        keys = key.split(".")
        cfg = config
        for k in keys[:-1]:
            cfg = getattr(cfg, k)
        setattr(cfg, keys[-1], value)

    trainer = AgentTrainer(
        workflow_class=MultiTurnWorkflow,
        workflow_args={
            "agent_cls": MyAgent,
            "env_cls": MyEnv,
            "max_steps": 5,
        },
        config=config,
        train_dataset=DatasetRegistry.load_dataset("mydata", "train"),
        val_dataset=DatasetRegistry.load_dataset("mydata", "val"),
        backend="verl",
    )
    trainer.train()


if __name__ == "__main__":
    main()
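The override step in this example sets each dot-notation key on the loaded config by walking to the parent node and assigning the leaf. The same idea on plain nested dicts can be sketched with the standard library. apply_overrides is a hypothetical helper, and raising on unknown keys imitates (rather than reproduces) the struct mode Hydra enables on its configs, where adding a new key from the command line requires a + prefix.

```python
def apply_overrides(cfg: dict, overrides: dict) -> dict:
    """Set dot-notation keys on a nested dict in place; unknown keys raise KeyError."""
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for part in parents:
            node = node[part]  # KeyError if the section does not exist
        if leaf not in node:
            raise KeyError(f"unknown config key: {dotted_key}")
        node[leaf] = value
    return cfg


defaults = {"data": {"train_batch_size": 512}, "trainer": {"total_epochs": 3}}
apply_overrides(defaults, {"data.train_batch_size": 256, "trainer.total_epochs": 5})
print(defaults)  # {'data': {'train_batch_size': 256}, 'trainer': {'total_epochs': 5}}
```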
Running Training
Run training scripts with Hydra CLI overrides:
# Basic training
python train.py
# Override config from command line
python train.py data.train_batch_size=512 trainer.total_epochs=5
# Use different model
python train.py actor_rollout_ref.model.path=Qwen/Qwen3-4B
# Adjust PPO hyperparameters
python train.py algorithm.gamma=1.0 algorithm.num_samples_per_prompt=8
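When runs are launched from Python (for example, a small sweep), a dot-notation dict can be turned into the key=value argument form shown above. The stdlib sketch below uses hypothetical helper names (to_override_args, build_command) and assumes simple scalar values; values containing spaces, commas, or brackets need Hydra's quoting rules, which it does not handle.

```python
import shlex
import sys


def to_override_args(overrides: dict) -> list[str]:
    """{"a.b": 1} -> ["a.b=1"]: the list-of-strings form of a dot-notation dict."""
    args = []
    for key, value in overrides.items():
        # Write booleans as lowercase true/false, as in YAML.
        text = str(value).lower() if isinstance(value, bool) else str(value)
        args.append(f"{key}={text}")
    return args


def build_command(script: str, overrides: dict) -> list[str]:
    """Argument vector for subprocess.run: [python, script, key=value, ...]."""
    return [sys.executable, script, *to_override_args(overrides)]


cmd = build_command("train.py", {"data.train_batch_size": 512, "trainer.total_epochs": 5})
print(shlex.join(cmd))
# To actually launch: subprocess.run(cmd, check=True)
```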