The AgentWorkflowEngine manages workflow execution with built-in retry logic, episode logging, and parallel task processing.

AgentWorkflowEngine

from rllm.engine import AgentWorkflowEngine

Constructor

def __init__(
    workflow_cls: type[Workflow],
    workflow_args: dict,
    rollout_engine: RolloutEngine,
    config = None,
    n_parallel_tasks: int = 128,
    retry_limit: int = 3,
    raise_on_error: bool = True,
    episode_logger = None,
    **kwargs
)
workflow_cls
type[Workflow]
Workflow class to instantiate for each task.
workflow_args
dict
Arguments to pass to workflow instances.
rollout_engine
RolloutEngine
Engine for model inference and rollout.
config
dict | None
Optional configuration object for training.
n_parallel_tasks
int
default:"128"
Number of parallel workflow instances to maintain.
retry_limit
int
default:"3"
Maximum number of retry attempts for failed tasks.
raise_on_error
bool
default:"True"
Whether to raise exceptions on permanent failures.
episode_logger
EpisodeLogger | None
Optional logger for saving episode data to files.

Methods

initialize_pool

Initialize the workflow pool with parallel workflow instances.
await engine.initialize_pool()

set_training_step

Set current training step for episode logging.
engine.set_training_step(step=100, mode="train", epoch=0)
step
int
Current training step number.
mode
str
default:"train"
Mode identifier: "train" or "val".
epoch
int
default:"0"
Current epoch number.

process_task_with_retry

Process a single task rollout with retry logic based on termination reasons.
task_id, rollout_idx, episode = await engine.process_task_with_retry(
    task=task_data,
    task_id="task_123",
    rollout_idx=0
)
task
dict
Task dictionary containing the task specification.
task_id
str
Unique identifier for the task.
rollout_idx
int
Index of this rollout attempt for the task.
Returns

task_id
str
The task ID.
rollout_idx
int
The rollout index.
episode
Episode
Completed episode.

execute_batch

Execute a batch of tasks with automatic retry and error handling.
episodes = await engine.execute_batch(
    tasks=task_list,
    num_rollouts_per_task=4
)
tasks
list[dict]
List of task dictionaries.
num_rollouts_per_task
int
default:"1"
Number of rollouts to generate per task.
Returns

episodes
list[Episode]
List of completed episodes.

Retry Logic

The engine automatically retries tasks based on termination reason:
  • Retryable: TIMEOUT, ERROR, MAX_PROMPT_LENGTH_EXCEEDED, MAX_RESPONSE_LENGTH_EXCEEDED
  • Non-retryable: ENV_DONE, MAX_TURNS_EXCEEDED, UNKNOWN
Tasks are retried up to retry_limit times before failing permanently.

Example: Simple Workflow Execution

import asyncio
from concurrent.futures import ThreadPoolExecutor
from rllm.engine import AgentWorkflowEngine
from rllm.engine.rollout import OpenAIEngine
from rllm.workflows import SimpleWorkflow
from rllm.rewards import math_reward_fn

# Inference backend: an OpenAI-compatible endpoint serving the model.
inference = OpenAIEngine(
    base_url="http://localhost:4000/v1",
    api_key="EMPTY",
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
)

# Workflow engine: runs SimpleWorkflow instances in parallel against the
# inference backend, retrying failed tasks up to 3 times.
engine = AgentWorkflowEngine(
    workflow_cls=SimpleWorkflow,
    workflow_args={
        "reward_function": math_reward_fn,
        "executor": ThreadPoolExecutor(max_workers=128),
    },
    rollout_engine=inference,
    n_parallel_tasks=64,
    retry_limit=3,
)

# Each task pairs a question with its reference answer.
tasks = [
    {"question": "What is 2+2?", "answer": "4"},
    {"question": "Solve x + 5 = 10", "answer": "x = 5"},
]


async def main():
    """Initialize the workflow pool, run the batch, and report results."""
    await engine.initialize_pool()

    episodes = await engine.execute_batch(
        tasks=tasks,
        num_rollouts_per_task=2,
    )

    for episode in episodes:
        print(f"Task: {episode.task}")
        print(f"Correct: {episode.is_correct}")
        print(f"Trajectories: {len(episode.trajectories)}")
        print(f"Metrics: {episode.metrics}")


asyncio.run(main())

Example: Multi-Turn Workflow

import asyncio
from concurrent.futures import ThreadPoolExecutor
from rllm.engine import AgentWorkflowEngine
from rllm.engine.rollout import OpenAIEngine
from rllm.workflows import MultiTurnWorkflow
from rllm.agents import ToolAgent
from rllm.environments import ToolEnvironment
from rllm.rewards import search_reward_fn

# OpenAI-compatible endpoint serving the agent model.
rollout_engine = OpenAIEngine(
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",
    model="Qwen/Qwen3-4B",
)

# Tool registry shared by the agent and the environment.
# NOTE(review): MySearchTool is user-defined and must be in scope here —
# it is not provided by rllm.
tool_map = {"search": MySearchTool}

# Multi-turn workflow: a ToolAgent interacts with a ToolEnvironment for up
# to 5 steps per episode; transient failures are retried up to 2 times.
engine = AgentWorkflowEngine(
    workflow_cls=MultiTurnWorkflow,
    workflow_args={
        "agent_cls": ToolAgent,
        "env_cls": ToolEnvironment,
        "agent_args": {"tool_map": tool_map},
        "env_args": {"tool_map": tool_map, "reward_fn": search_reward_fn},
        "max_steps": 5,
        "executor": ThreadPoolExecutor(max_workers=128),
    },
    rollout_engine=rollout_engine,
    n_parallel_tasks=32,
    retry_limit=2,
)

tasks = [
    {"question": "Find the population of Tokyo", "answer": "37.4 million"},
]


async def main():
    """Run the batch and print a summary of every trajectory produced."""
    await engine.initialize_pool()
    episodes = await engine.execute_batch(tasks, num_rollouts_per_task=4)

    for episode in episodes:
        for traj in episode.trajectories:
            print(f"Agent: {traj.name}")
            print(f"Steps: {len(traj.steps)}")
            print(f"Reward: {traj.reward}")


asyncio.run(main())

Episode Logging

The engine supports optional episode logging to save episodes during training:
from rllm.utils import EpisodeLogger

# Persist completed episodes to disk under ./episodes.
logger = EpisodeLogger(output_dir="./episodes")

engine = AgentWorkflowEngine(
    workflow_cls=MyWorkflow,
    workflow_args=workflow_args,
    rollout_engine=rollout_engine,
    episode_logger=logger,
)

# Tag subsequently logged episodes with the current training context.
engine.set_training_step(step=100, mode="train", epoch=0)

# With this context set, episodes are saved to ./episodes/train/epoch_0/step_100/