# AgentWorkflowEngine

The `AgentWorkflowEngine` manages workflow execution with built-in retry logic, episode logging, and parallel task processing.

```python
from rllm.engine import AgentWorkflowEngine
```
## Constructor

```python
def __init__(
    workflow_cls: type[Workflow],
    workflow_args: dict,
    rollout_engine: RolloutEngine,
    config=None,
    n_parallel_tasks: int = 128,
    retry_limit: int = 3,
    raise_on_error: bool = True,
    episode_logger=None,
    **kwargs
)
```

**Parameters:**

- `workflow_cls` (`type[Workflow]`): Workflow class to instantiate for each task.
- `workflow_args` (`dict`): Arguments to pass to workflow instances.
- `rollout_engine` (`RolloutEngine`): Engine for model inference and rollout.
- `config`: Optional configuration object for training. Defaults to `None`.
- `n_parallel_tasks` (`int`): Number of parallel workflow instances to maintain. Defaults to `128`.
- `retry_limit` (`int`): Maximum number of retry attempts for failed tasks. Defaults to `3`.
- `raise_on_error` (`bool`): Whether to raise exceptions on permanent failures. Defaults to `True`.
- `episode_logger`: Optional logger for saving episode data to files. Defaults to `None`.
## Methods

### initialize_pool

Initialize the workflow pool with parallel workflow instances.

```python
await engine.initialize_pool()
```
### set_training_step

Set the current training step for episode logging.

```python
engine.set_training_step(step=100, mode="train", epoch=0)
```

**Parameters:**

- `step` (`int`): Current training step number.
- `mode` (`str`): Mode identifier: `"train"` or `"val"`.
- `epoch` (`int`): Current training epoch.
### process_task_with_retry

Process a single task rollout with retry logic based on termination reasons.

```python
task_id, rollout_idx, episode = await engine.process_task_with_retry(
    task=task_data,
    task_id="task_123",
    rollout_idx=0
)
```

**Parameters:**

- `task` (`dict`): Task dictionary containing the task specification.
- `task_id` (`str`): Unique identifier for the task.
- `rollout_idx` (`int`): Index of this rollout attempt for the task.

**Returns:** A `(task_id, rollout_idx, episode)` tuple.
### execute_batch

Execute a batch of tasks with automatic retry and error handling.

```python
episodes = await engine.execute_batch(
    tasks=task_list,
    num_rollouts_per_task=4
)
```

**Parameters:**

- `tasks` (`list[dict]`): List of task dictionaries.
- `num_rollouts_per_task` (`int`): Number of rollouts to generate per task.

**Returns:** List of completed episodes.
## Retry Logic

The engine automatically retries tasks based on their termination reason:

- **Retryable:** `TIMEOUT`, `ERROR`, `MAX_PROMPT_LENGTH_EXCEEDED`, `MAX_RESPONSE_LENGTH_EXCEEDED`
- **Non-retryable:** `ENV_DONE`, `MAX_TURNS_EXCEEDED`, `UNKNOWN`

Tasks are retried up to `retry_limit` times before failing permanently.
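The retry policy can be pictured with a minimal standalone sketch. The names below (`TerminationReason`, `run_with_retry`, and the `run_task` callable) are illustrative only, not the actual rLLM internals:

```python
from enum import Enum, auto

class TerminationReason(Enum):
    TIMEOUT = auto()
    ERROR = auto()
    MAX_PROMPT_LENGTH_EXCEEDED = auto()
    MAX_RESPONSE_LENGTH_EXCEEDED = auto()
    ENV_DONE = auto()
    MAX_TURNS_EXCEEDED = auto()
    UNKNOWN = auto()

# Transient failures that warrant another attempt
RETRYABLE = {
    TerminationReason.TIMEOUT,
    TerminationReason.ERROR,
    TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED,
    TerminationReason.MAX_RESPONSE_LENGTH_EXCEEDED,
}

def run_with_retry(run_task, retry_limit=3):
    """Re-run `run_task` while it ends with a retryable reason.

    `run_task` returns an (episode, termination_reason) pair.
    The task runs once, then up to `retry_limit` additional times.
    """
    for attempt in range(retry_limit + 1):
        episode, reason = run_task()
        if reason not in RETRYABLE or attempt == retry_limit:
            return episode, reason
```

A non-retryable reason such as `ENV_DONE` returns immediately; a retryable one is re-attempted until it either succeeds or the retry budget is exhausted.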
## Example: Simple Workflow Execution

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

from rllm.engine import AgentWorkflowEngine
from rllm.engine.rollout import OpenAIEngine
from rllm.workflows import SimpleWorkflow
from rllm.rewards import math_reward_fn

# Create rollout engine
rollout_engine = OpenAIEngine(
    base_url="http://localhost:4000/v1",
    api_key="EMPTY",
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
)

# Create workflow engine
engine = AgentWorkflowEngine(
    workflow_cls=SimpleWorkflow,
    workflow_args={
        "reward_function": math_reward_fn,
        "executor": ThreadPoolExecutor(max_workers=128)
    },
    rollout_engine=rollout_engine,
    n_parallel_tasks=64,
    retry_limit=3
)

# Define tasks
tasks = [
    {"question": "What is 2+2?", "answer": "4"},
    {"question": "Solve x + 5 = 10", "answer": "x = 5"}
]

async def main():
    # Initialize pool
    await engine.initialize_pool()

    # Execute batch
    episodes = await engine.execute_batch(
        tasks=tasks,
        num_rollouts_per_task=2
    )

    # Print results
    for episode in episodes:
        print(f"Task: {episode.task}")
        print(f"Correct: {episode.is_correct}")
        print(f"Trajectories: {len(episode.trajectories)}")
        print(f"Metrics: {episode.metrics}")

asyncio.run(main())
```
## Example: Multi-Turn Workflow

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

from rllm.engine import AgentWorkflowEngine
from rllm.engine.rollout import OpenAIEngine
from rllm.workflows import MultiTurnWorkflow
from rllm.agents import ToolAgent
from rllm.environments import ToolEnvironment
from rllm.rewards import search_reward_fn

rollout_engine = OpenAIEngine(
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",
    model="Qwen/Qwen3-4B"
)

# MySearchTool is a user-defined tool class
tool_map = {"search": MySearchTool}

engine = AgentWorkflowEngine(
    workflow_cls=MultiTurnWorkflow,
    workflow_args={
        "agent_cls": ToolAgent,
        "env_cls": ToolEnvironment,
        "agent_args": {"tool_map": tool_map},
        "env_args": {"tool_map": tool_map, "reward_fn": search_reward_fn},
        "max_steps": 5,
        "executor": ThreadPoolExecutor(max_workers=128)
    },
    rollout_engine=rollout_engine,
    n_parallel_tasks=32,
    retry_limit=2
)

tasks = [
    {"question": "Find the population of Tokyo", "answer": "37.4 million"}
]

async def main():
    await engine.initialize_pool()
    episodes = await engine.execute_batch(tasks, num_rollouts_per_task=4)
    for episode in episodes:
        for traj in episode.trajectories:
            print(f"Agent: {traj.name}")
            print(f"Steps: {len(traj.steps)}")
            print(f"Reward: {traj.reward}")

asyncio.run(main())
```
## Episode Logging

The engine supports optional episode logging to save episodes during training:

```python
from rllm.utils import EpisodeLogger

logger = EpisodeLogger(output_dir="./episodes")

engine = AgentWorkflowEngine(
    workflow_cls=MyWorkflow,
    workflow_args=workflow_args,
    rollout_engine=rollout_engine,
    episode_logger=logger
)

# Set training context
engine.set_training_step(step=100, mode="train", epoch=0)

# Episodes will be saved to ./episodes/train/epoch_0/step_100/
```
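The output layout follows directly from the training context set above. A small illustrative helper (not part of rLLM) that composes the same path:

```python
from pathlib import Path

def episode_dir(output_dir: str, mode: str, epoch: int, step: int) -> Path:
    """Compose the <output_dir>/<mode>/epoch_<e>/step_<s> layout
    shown in the example above (hypothetical helper)."""
    return Path(output_dir) / mode / f"epoch_{epoch}" / f"step_{step}"
```

For instance, `episode_dir("./episodes", "train", 0, 100)` yields `episodes/train/epoch_0/step_100`, matching the path produced after `set_training_step(step=100, mode="train", epoch=0)`.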