The trainer module provides high-level APIs for training agents using reinforcement learning with PPO.

AgentTrainer

Wrapper class for training agents with custom environments using various backends.
from rllm.trainer import AgentTrainer

Constructor

def __init__(
    self,
    workflow_class: type | None = None,
    workflow_args: dict[str, Any] | None = None,
    agent_class: type | None = None,
    env_class: type | None = None,
    agent_args: dict[str, Any] | None = None,
    env_args: dict[str, Any] | None = None,
    config: dict[str, Any] | list[str] | None = None,
    train_dataset: Dataset | None = None,
    val_dataset: Dataset | None = None,
    backend: Literal["verl", "fireworks", "tinker"] = "verl",
    agent_run_func: Callable | None = None
)
workflow_class (type | None)
  Workflow class to use for training (e.g., SimpleWorkflow, MultiTurnWorkflow).

workflow_args (dict | None)
  Arguments to pass to the workflow class.

agent_class (type | None)
  Custom agent class (not used with the fireworks backend).

env_class (type | None)
  Custom environment class (not used with the fireworks backend).

agent_args (dict | None)
  Arguments to pass to the agent class.

env_args (dict | None)
  Arguments to pass to the environment class.

config (dict | list[str] | None)
  Configuration overrides. Can be:
  • Dictionary with dot-notation keys: {"data.train_batch_size": 8}
  • List of strings: ["data.train_batch_size=8", "trainer.total_epochs=3"]

train_dataset (Dataset | None)
  Training dataset.

val_dataset (Dataset | None)
  Validation dataset.

backend (Literal["verl", "fireworks", "tinker"], default "verl")
  Training backend:
  • "verl": Standard backend supporting workflows and agent/env classes
  • "fireworks": Pipeline-based backend optimized for workflows
  • "tinker": Legacy backend

agent_run_func (Callable | None)
  Optional custom function for agent execution (advanced usage).

Methods

train

Start the training process.
trainer.train()

Configuration

The trainer uses Hydra for configuration management. Default config is at rllm/trainer/config/agent_ppo_trainer.yaml.
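
If you need the resolved config outside a @hydra.main entry point (e.g., in a notebook or test), Hydra's compose API can load the same defaults. A minimal sketch, assuming the packaged config module is importable; the override shown is illustrative:

from hydra import compose, initialize_config_module

with initialize_config_module(config_module="rllm.trainer.config", version_base=None):
    config = compose(
        config_name="agent_ppo_trainer",
        overrides=["data.train_batch_size=256"],  # optional CLI-style overrides
    )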

Common Config Overrides

config = {
    # Data settings
    "data.train_batch_size": 512,
    "data.val_batch_size": 1024,
    
    # Training settings
    "trainer.total_epochs": 3,
    "trainer.total_training_steps": 1000,
    
    # PPO hyperparameters
    "algorithm.gamma": 1.0,
    "algorithm.lam": 0.95,
    "algorithm.kl_penalty": 0.001,
    
    # Model settings
    "actor_rollout_ref.model.path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    
    # GRPO settings
    "algorithm.adv_estimator": "grpo",
    "algorithm.num_samples_per_prompt": 4,
}
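
Equivalently, the config parameter accepts the same overrides as a list of Hydra-style strings:

config = [
    "data.train_batch_size=512",
    "trainer.total_epochs=3",
    "algorithm.adv_estimator=grpo",
    "actor_rollout_ref.model.path=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
]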

Example: Training with SimpleWorkflow

import hydra
from rllm.trainer import AgentTrainer
from rllm.workflows import SimpleWorkflow
from rllm.rewards import math_reward_fn
from rllm.data import DatasetRegistry

@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None
)
def main(config):
    # Load datasets
    train_dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
    val_dataset = DatasetRegistry.load_dataset("math500", "test")
    
    # Create trainer
    trainer = AgentTrainer(
        workflow_class=SimpleWorkflow,
        workflow_args={
            "reward_function": math_reward_fn,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        backend="verl"
    )
    
    # Start training
    trainer.train()

if __name__ == "__main__":
    main()

Example: Training with Agent/Environment Classes

import hydra
from rllm.trainer import AgentTrainer
from rllm.agents import ToolAgent
from rllm.environments import ToolEnvironment
from rllm.rewards import search_reward_fn
from rllm.data import DatasetRegistry

@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None
)
def main(config):
    train_dataset = DatasetRegistry.load_dataset("hotpotqa", "train")
    val_dataset = DatasetRegistry.load_dataset("hotpotqa", "test")
    
    # MySearchTool stands in for a user-defined tool implementation
    tool_map = {"search": MySearchTool}
    
    trainer = AgentTrainer(
        agent_class=ToolAgent,
        env_class=ToolEnvironment,
        agent_args={
            "tool_map": tool_map,
            "system_prompt": "You are a helpful search assistant."
        },
        env_args={
            "tool_map": tool_map,
            "reward_fn": search_reward_fn
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        backend="verl"
    )
    
    trainer.train()

if __name__ == "__main__":
    main()

Example: Config Overrides

import hydra
from rllm.trainer import AgentTrainer
from rllm.workflows import MultiTurnWorkflow
from rllm.data import DatasetRegistry

@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None
)
def main(config):
    # Override config values
    config_overrides = {
        "data.train_batch_size": 256,
        "data.val_batch_size": 512,
        "trainer.total_epochs": 5,
        "algorithm.gamma": 1.0,
        "algorithm.num_samples_per_prompt": 8,
        "actor_rollout_ref.model.path": "Qwen/Qwen3-4B",
    }
    
    # Apply overrides
    for key, value in config_overrides.items():
        keys = key.split(".")
        cfg = config
        for k in keys[:-1]:
            cfg = getattr(cfg, k)
        setattr(cfg, keys[-1], value)
    
    trainer = AgentTrainer(
        workflow_class=MultiTurnWorkflow,
        workflow_args={
            "agent_cls": "ToolAgent",
            "env_cls": "ToolEnvironment",
            "max_steps": 5
        },
        config=config,
        train_dataset=DatasetRegistry.load_dataset("mydata", "train"),
        val_dataset=DatasetRegistry.load_dataset("mydata", "val"),
        backend="verl"
    )
    
    trainer.train()

if __name__ == "__main__":
    main()

Example: Using agent_run_func (Advanced)

import hydra
from rllm.trainer import AgentTrainer
from rllm.rewards import math_reward_fn
from rllm.sdk import get_chat_client
from rllm.data import DatasetRegistry

@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None
)
def main(config):
    # Define custom run function
    def rollout(**kwargs):
        question = kwargs["question"]
        ground_truth = kwargs["ground_truth"]
        
        # Create client inside function for serialization
        client = get_chat_client(
            base_url="http://localhost:4000/v1",
            api_key="EMPTY"
        )
        
        response = client.chat.completions.create(
            model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
            messages=[{"role": "user", "content": question}]
        )
        
        response_text = response.choices[0].message.content
        reward = math_reward_fn(
            {"response": response_text, "ground_truth": ground_truth},
            response_text
        ).reward
        
        return float(reward)
    
    trainer = AgentTrainer(
        config=config,
        train_dataset=DatasetRegistry.load_dataset("hendrycks_math", "train"),
        val_dataset=DatasetRegistry.load_dataset("math500", "test"),
        agent_run_func=rollout
    )
    
    trainer.train()

if __name__ == "__main__":
    main()

Running Training

Run training scripts with Hydra CLI overrides:
# Basic training
python train.py

# Override config from command line
python train.py data.train_batch_size=512 trainer.total_epochs=5

# Use different model
python train.py actor_rollout_ref.model.path=Qwen/Qwen3-4B

# Adjust PPO hyperparameters
python train.py algorithm.gamma=1.0 algorithm.num_samples_per_prompt=8