The trainer module provides high-level APIs for training agents using reinforcement learning with PPO.
AgentTrainer
Wrapper class for training agents with custom environments using various backends.
from rllm.trainer import AgentTrainer
Constructor
def __init__(
    self,
    workflow_class: type | None = None,
    workflow_args: dict[str, Any] | None = None,
    config: dict[str, Any] | list[str] | None = None,
    train_dataset: Dataset | None = None,
    val_dataset: Dataset | None = None,
    backend: Literal["verl", "fireworks", "tinker"] = "verl",
    agent_run_func: Callable | None = None,
)
There are two ways to plug in your agent: pass either workflow_class (a Workflow subclass) or agent_run_func (a plain rollout function for the AgentSdk path); a sketch of both follows the parameter list below.
workflow_class
type | None
default: None
Workflow class to use for training (e.g., SimpleWorkflow, MultiTurnWorkflow).
workflow_args
dict[str, Any] | None
default: None
Arguments to pass to the workflow class.
config
dict[str, Any] | list[str] | None
default: None
Configuration overrides. Can be:
- Dictionary with dot-notation keys: {"data.train_batch_size": 8}
- List of strings: ["data.train_batch_size=8", "trainer.total_epochs=3"]
train_dataset
Dataset | None
default: None
Training dataset.
val_dataset
Dataset | None
default: None
Validation dataset used for evaluation.
backend
Literal['verl', 'fireworks', 'tinker']
default:"verl"
Training backend:
"verl": Standard distributed PPO via the verl framework
"fireworks": Pipeline-based variant (workflow-only) for the Fireworks workflow API
"tinker": Single-machine LoRA training via tinker (workflow-only)
agent_run_func
Callable | None
default: None
Plain rollout function that drives the AgentSdk path. Use this or workflow_class, not both.
The legacy agent_class + env_class parameters that drove the AgentExecutionEngine rollout have been removed. Port your agent to either a Workflow or an AgentFlow; see the cookbooks/ directory for examples.
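A minimal sketch of the two entry points, assuming config is an already-loaded Hydra config and my_rollout is a placeholder rollout function:

from rllm.trainer import AgentTrainer
from rllm.workflows import SimpleWorkflow

# Option 1: workflow-based training
trainer = AgentTrainer(workflow_class=SimpleWorkflow, config=config)

# Option 2: AgentSdk path driven by a plain rollout function
# (my_rollout is a placeholder; see the agent_run_func example below)
trainer = AgentTrainer(agent_run_func=my_rollout, config=config)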
Methods
train
Start the training process.
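Once the trainer is constructed, training is started with a single no-argument call, as in every example below:

trainer.train()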
Configuration
The trainer uses Hydra for configuration management. Default config is at rllm/trainer/config/agent_ppo_trainer.yaml.
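Outside of a @hydra.main entry point, the same default config can be loaded programmatically with Hydra's compose API; a minimal sketch (the override string is an arbitrary example):

from hydra import compose, initialize_config_module

# Load rllm's default trainer config without a @hydra.main decorator.
with initialize_config_module(config_module="rllm.trainer.config", version_base=None):
    config = compose(
        config_name="agent_ppo_trainer",
        overrides=["data.train_batch_size=8"],
    )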
Common Config Overrides
config = {
    # Data settings
    "data.train_batch_size": 512,
    "data.val_batch_size": 1024,

    # Training settings
    "trainer.total_epochs": 3,
    "trainer.total_training_steps": 1000,

    # PPO hyperparameters
    "algorithm.gamma": 1.0,
    "algorithm.lam": 0.95,
    "algorithm.kl_penalty": 0.001,

    # Model settings
    "actor_rollout_ref.model.path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",

    # GRPO settings
    "algorithm.adv_estimator": "grpo",
    "algorithm.num_samples_per_prompt": 4,
}
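The same dot-notation keys can be passed in the list-of-strings form that the config parameter also accepts; a minimal sketch reusing a few of the overrides above:

config = [
    "data.train_batch_size=512",
    "trainer.total_epochs=3",
    "algorithm.adv_estimator=grpo",
    "actor_rollout_ref.model.path=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
]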
Example: Training with SimpleWorkflow
import hydra
from rllm.trainer import AgentTrainer
from rllm.workflows import SimpleWorkflow
from rllm.rewards import math_reward_fn
from rllm.data import DatasetRegistry


@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None,
)
def main(config):
    # Load datasets
    train_dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
    val_dataset = DatasetRegistry.load_dataset("math500", "test")

    # Create trainer
    trainer = AgentTrainer(
        workflow_class=SimpleWorkflow,
        workflow_args={
            "reward_function": math_reward_fn,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        backend="verl",
    )

    # Start training
    trainer.train()


if __name__ == "__main__":
    main()
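The reward_function hook accepts your own callable. The call shape below (a task-info dict plus the model response, returning an object exposing .reward) is inferred from how math_reward_fn is invoked in the agent_run_func example further down this page, not from a documented contract; a minimal sketch with a hypothetical exact_match_reward:

from dataclasses import dataclass


@dataclass
class RewardOutput:
    # Hypothetical stand-in for the object math_reward_fn returns;
    # the examples on this page rely only on its .reward field.
    reward: float


def exact_match_reward(task_info: dict, response: str) -> RewardOutput:
    # Score 1.0 if the ground truth appears verbatim in the response.
    return RewardOutput(reward=float(task_info["ground_truth"] in response))


# Plug it in where math_reward_fn was used above:
# workflow_args={"reward_function": exact_match_reward}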
Example: Config Overrides
import hydra
from rllm.trainer import AgentTrainer
from rllm.workflows import MultiTurnWorkflow
from rllm.data import DatasetRegistry


@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None,
)
def main(config):
    # Override config values
    config_overrides = {
        "data.train_batch_size": 256,
        "data.val_batch_size": 512,
        "trainer.total_epochs": 5,
        "algorithm.gamma": 1.0,
        "algorithm.num_samples_per_prompt": 8,
        "actor_rollout_ref.model.path": "Qwen/Qwen3-4B",
    }

    # Apply overrides by walking each dotted path on the Hydra config
    for key, value in config_overrides.items():
        keys = key.split(".")
        cfg = config
        for k in keys[:-1]:
            cfg = getattr(cfg, k)
        setattr(cfg, keys[-1], value)

    trainer = AgentTrainer(
        workflow_class=MultiTurnWorkflow,
        workflow_args={
            # Substitute your own BaseAgent / BaseEnv subclasses.
            "agent_cls": MyAgent,
            "env_cls": MyEnv,
            "max_steps": 5,
        },
        config=config,
        train_dataset=DatasetRegistry.load_dataset("mydata", "train"),
        val_dataset=DatasetRegistry.load_dataset("mydata", "val"),
        backend="verl",
    )
    trainer.train()


if __name__ == "__main__":
    main()
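The manual getattr/setattr loop above can also be written with OmegaConf's dotted-path helper (OmegaConf ships with Hydra, so no extra dependency is assumed); a minimal equivalent sketch:

from omegaconf import OmegaConf

# Each dotted key is resolved and assigned in place on the DictConfig.
for key, value in config_overrides.items():
    OmegaConf.update(config, key, value)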
Example: Using agent_run_func (Advanced)
import hydra
from rllm.trainer import AgentTrainer
from rllm.rewards import math_reward_fn
from rllm.sdk import get_chat_client
from rllm.data import DatasetRegistry


@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None,
)
def main(config):
    # Define custom run function
    def rollout(**kwargs):
        question = kwargs["question"]
        ground_truth = kwargs["ground_truth"]

        # Create client inside the function for serialization
        client = get_chat_client(
            base_url="http://localhost:4000/v1",
            api_key="EMPTY",
        )
        response = client.chat.completions.create(
            model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
            messages=[{"role": "user", "content": question}],
        )
        response_text = response.choices[0].message.content

        reward = math_reward_fn(
            {"response": response_text, "ground_truth": ground_truth},
            response_text,
        ).reward
        return reward * 1.0  # return a plain float

    trainer = AgentTrainer(
        config=config,
        train_dataset=DatasetRegistry.load_dataset("hendrycks_math", "train"),
        val_dataset=DatasetRegistry.load_dataset("math500", "test"),
        agent_run_func=rollout,
    )
    trainer.train()


if __name__ == "__main__":
    main()
Running Training
Run training scripts with Hydra CLI overrides:
# Basic training
python train.py
# Override config from command line
python train.py data.train_batch_size=512 trainer.total_epochs=5
# Use different model
python train.py actor_rollout_ref.model.path=Qwen/Qwen3-4B
# Adjust PPO hyperparameters
python train.py algorithm.gamma=1.0 algorithm.num_samples_per_prompt=8