The trainer module provides high-level APIs for training agents with PPO-based reinforcement learning.

## AgentTrainer

Wrapper class for training agents with custom environments using various backends.

```python
from rllm.trainer import AgentTrainer
```

### Constructor
```python
def __init__(
    self,
    workflow_class: type | None = None,
    workflow_args: dict[str, Any] | None = None,
    agent_class: type | None = None,
    env_class: type | None = None,
    agent_args: dict[str, Any] | None = None,
    env_args: dict[str, Any] | None = None,
    config: dict[str, Any] | list[str] | None = None,
    train_dataset: Dataset | None = None,
    val_dataset: Dataset | None = None,
    backend: Literal["verl", "fireworks", "tinker"] = "verl",
    agent_run_func: Callable | None = None,
)
```
**Parameters:**

- `workflow_class` (`type | None`): Workflow class to use for training (e.g., `SimpleWorkflow`, `MultiTurnWorkflow`).
- `workflow_args` (`dict[str, Any] | None`): Arguments to pass to the workflow class.
- `agent_class` (`type | None`): Custom agent class (not used with the `fireworks` backend).
- `env_class` (`type | None`): Custom environment class (not used with the `fireworks` backend).
- `agent_args` (`dict[str, Any] | None`): Arguments to pass to the agent class.
- `env_args` (`dict[str, Any] | None`): Arguments to pass to the environment class.
- `config` (`dict[str, Any] | list[str] | None`): Configuration overrides. Either a dictionary with dot-notation keys, e.g. `{"data.train_batch_size": 8}`, or a list of strings, e.g. `["data.train_batch_size=8", "trainer.total_epochs=3"]`.
- `train_dataset` (`Dataset | None`): Training dataset.
- `val_dataset` (`Dataset | None`): Validation dataset.
- `backend` (`Literal["verl", "fireworks", "tinker"]`, default `"verl"`): Training backend:
  - `"verl"`: standard backend supporting workflows and agent/env classes
  - `"fireworks"`: pipeline-based backend optimized for workflows
  - `"tinker"`: legacy backend
- `agent_run_func` (`Callable | None`): Optional custom function for agent execution (advanced usage).
### Methods

#### train

Start the training process.
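For reference, a minimal invocation looks like this. A sketch only: it reuses the datasets from the first example below, and passes overrides directly as a dict rather than a full Hydra config. `train()` itself takes no arguments:

```python
from rllm.data import DatasetRegistry
from rllm.trainer import AgentTrainer
from rllm.workflows import SimpleWorkflow

trainer = AgentTrainer(
    workflow_class=SimpleWorkflow,
    config={"trainer.total_epochs": 1},  # dict or list-of-strings overrides
    train_dataset=DatasetRegistry.load_dataset("hendrycks_math", "train"),
    val_dataset=DatasetRegistry.load_dataset("math500", "test"),
)
trainer.train()  # runs the configured training loop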
## Configuration

The trainer uses Hydra for configuration management. The default config lives at `rllm/trainer/config/agent_ppo_trainer.yaml`.
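If you prefer to build the config outside a `@hydra.main` entry point, Hydra's Compose API can load the same packaged defaults. A minimal sketch (the override shown is only an illustration):

```python
from hydra import compose, initialize_config_module

# Load rllm's packaged defaults without a @hydra.main decorator.
with initialize_config_module(config_module="rllm.trainer.config", version_base=None):
    config = compose(
        config_name="agent_ppo_trainer",
        overrides=["data.train_batch_size=8"],  # same syntax as CLI overrides
    )
```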
### Common Config Overrides
```python
config = {
    # Data settings
    "data.train_batch_size": 512,
    "data.val_batch_size": 1024,
    # Training settings
    "trainer.total_epochs": 3,
    "trainer.total_training_steps": 1000,
    # PPO hyperparameters
    "algorithm.gamma": 1.0,
    "algorithm.lam": 0.95,
    "algorithm.kl_penalty": 0.001,
    # Model settings
    "actor_rollout_ref.model.path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    # GRPO settings
    "algorithm.adv_estimator": "grpo",
    "algorithm.num_samples_per_prompt": 4,
}
```
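The same overrides can also be passed as a list of `key=value` strings, mirroring Hydra's CLI syntax:

```python
# Equivalent list-of-strings form
config = [
    "data.train_batch_size=512",
    "trainer.total_epochs=3",
    "algorithm.adv_estimator=grpo",
]
```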
## Example: Training with SimpleWorkflow

```python
import hydra

from rllm.trainer import AgentTrainer
from rllm.workflows import SimpleWorkflow
from rllm.rewards import math_reward_fn
from rllm.data import DatasetRegistry


@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None,
)
def main(config):
    # Load datasets
    train_dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
    val_dataset = DatasetRegistry.load_dataset("math500", "test")

    # Create trainer
    trainer = AgentTrainer(
        workflow_class=SimpleWorkflow,
        workflow_args={
            "reward_function": math_reward_fn,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        backend="verl",
    )

    # Start training
    trainer.train()


if __name__ == "__main__":
    main()
```
## Example: Training with Agent/Environment Classes

```python
import hydra

from rllm.trainer import AgentTrainer
from rllm.agents import ToolAgent
from rllm.environments import ToolEnvironment
from rllm.rewards import search_reward_fn
from rllm.data import DatasetRegistry


@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None,
)
def main(config):
    train_dataset = DatasetRegistry.load_dataset("hotpotqa", "train")
    val_dataset = DatasetRegistry.load_dataset("hotpotqa", "test")

    # MySearchTool is a user-defined tool class; define or import it before use.
    tool_map = {"search": MySearchTool}

    trainer = AgentTrainer(
        agent_class=ToolAgent,
        env_class=ToolEnvironment,
        agent_args={
            "tool_map": tool_map,
            "system_prompt": "You are a helpful search assistant.",
        },
        env_args={
            "tool_map": tool_map,
            "reward_fn": search_reward_fn,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        backend="verl",
    )
    trainer.train()


if __name__ == "__main__":
    main()
```
## Example: Config Overrides

```python
import hydra

from rllm.trainer import AgentTrainer
from rllm.workflows import MultiTurnWorkflow
from rllm.data import DatasetRegistry


@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None,
)
def main(config):
    # Override config values
    config_overrides = {
        "data.train_batch_size": 256,
        "data.val_batch_size": 512,
        "trainer.total_epochs": 5,
        "algorithm.gamma": 1.0,
        "algorithm.num_samples_per_prompt": 8,
        "actor_rollout_ref.model.path": "Qwen/Qwen3-4B",
    }

    # Apply overrides
    for key, value in config_overrides.items():
        keys = key.split(".")
        cfg = config
        for k in keys[:-1]:
            cfg = getattr(cfg, k)
        setattr(cfg, keys[-1], value)

    trainer = AgentTrainer(
        workflow_class=MultiTurnWorkflow,
        workflow_args={
            "agent_cls": "ToolAgent",
            "env_cls": "ToolEnvironment",
            "max_steps": 5,
        },
        config=config,
        train_dataset=DatasetRegistry.load_dataset("mydata", "train"),
        val_dataset=DatasetRegistry.load_dataset("mydata", "val"),
        backend="verl",
    )
    trainer.train()


if __name__ == "__main__":
    main()
```
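The manual override loop above can also be written with OmegaConf's `update` helper, which resolves dotted keys for you. An equivalent sketch, assuming `config` is the usual Hydra `DictConfig`:

```python
from omegaconf import OmegaConf

# Apply the same dot-notation overrides, one call per key.
for key, value in config_overrides.items():
    OmegaConf.update(config, key, value)
```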
## Example: Using agent_run_func (Advanced)

```python
import hydra

from rllm.trainer import AgentTrainer
from rllm.rewards import math_reward_fn
from rllm.sdk import get_chat_client
from rllm.data import DatasetRegistry


@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None,
)
def main(config):
    # Define custom run function
    def rollout(**kwargs):
        question = kwargs["question"]
        ground_truth = kwargs["ground_truth"]

        # Create client inside function for serialization
        client = get_chat_client(
            base_url="http://localhost:4000/v1",
            api_key="EMPTY",
        )
        response = client.chat.completions.create(
            model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
            messages=[{"role": "user", "content": question}],
        )
        response_text = response.choices[0].message.content

        reward = math_reward_fn(
            {"response": response_text, "ground_truth": ground_truth},
            response_text,
        ).reward
        return reward * 1.0  # ensure a float is returned

    trainer = AgentTrainer(
        config=config,
        train_dataset=DatasetRegistry.load_dataset("hendrycks_math", "train"),
        val_dataset=DatasetRegistry.load_dataset("math500", "test"),
        agent_run_func=rollout,
    )
    trainer.train()


if __name__ == "__main__":
    main()
```
## Running Training

Run training scripts with Hydra CLI overrides:

```bash
# Basic training
python train.py

# Override config from the command line
python train.py data.train_batch_size=512 trainer.total_epochs=5

# Use a different model
python train.py actor_rollout_ref.model.path=Qwen/Qwen3-4B

# Adjust PPO hyperparameters
python train.py algorithm.gamma=1.0 algorithm.num_samples_per_prompt=8
```
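Hydra also supports parameter sweeps via multirun, launching one run per combination of comma-separated values:

```bash
# 2 batch sizes x 2 sample counts = 4 runs
python train.py --multirun data.train_batch_size=256,512 algorithm.num_samples_per_prompt=4,8
```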