Workflows orchestrate the interaction between agents, environments, and models to execute complex multi-step tasks.
Workflow
Abstract base class for all workflows.
from rllm.workflows import Workflow
Constructor
def __init__(
rollout_engine: RolloutEngine,
executor: ThreadPoolExecutor,
timeout: float = 1e6,
gamma: float = 0.0,
reward_bonus_coeff: float = 0.0,
**kwargs
)
- `rollout_engine`: The rollout engine used for model inference.
- `executor`: Thread pool executor for async operations.
- `timeout`: Timeout for workflow execution, in seconds.
- `gamma`: Discount factor for reward computation. When `gamma > 0`, Monte Carlo returns are computed.
- `reward_bonus_coeff`: Coefficient for reward shaping based on reward deltas.
Methods
run
Execute the workflow on a single task. Must be implemented by subclasses.
async def run(task: dict, uid: str, **kwargs) -> Episode | None
- `uid`: Unique identifier for the task.
run_with_termination_handling
Wrapper around run() that handles termination events, errors, and timeouts.
episode = await workflow.run_with_termination_handling(task, uid)
commit
Commit a trajectory for training.
workflow.commit(
name="solver",
agent=my_agent,
reset=True
)
- `agent`: Agent whose trajectory should be committed.
- A trajectory to commit directly, as an alternative to passing `agent`.
- `reset`: Whether to reset the agent after committing.
collect_trajectories
Collect all trajectories, both those already committed and those held by agent instances.
episode = workflow.collect_trajectories()
Returns: an `Episode` containing all trajectories.
reset
Reset the workflow for a new task.
workflow.reset(task=new_task, uid="task_123:0")
- `uid`: Unique identifier for the task.
postprocess_episode
Post-process episode after completion (compute rewards, metrics, etc.).
episode = workflow.postprocess_episode(
episode,
termination_reason=TerminationReason.ENV_DONE
)
SimpleWorkflow
Simplified workflow for single-agent, single-turn tasks.
from rllm.workflows import SimpleWorkflow
from rllm.rewards import math_reward_fn
workflow = SimpleWorkflow(
rollout_engine=engine,
reward_function=math_reward_fn,
executor=executor
)
Constructor
def __init__(
rollout_engine: RolloutEngine,
reward_function: RewardFunction,
**kwargs
)
- `rollout_engine`: Engine used for model inference.
- `reward_function`: Function that computes rewards from the task and the action.
Methods
run
Execute the workflow:
episode = await workflow.run(
task={"question": "What is 2+2?", "answer": "4"},
uid="task_0:0"
)
The workflow automatically:
- Extracts messages from the task (supports the `question`, `prompt`, `problem`, or `messages` keys)
- Gets the model response
- Computes the reward
- Creates a trajectory containing the step
- Returns the episode
MultiTurnWorkflow
Workflow for multi-step agent-environment interactions.
from rllm.workflows import MultiTurnWorkflow
from rllm.agents import ToolAgent
from rllm.environments import ToolEnvironment
workflow = MultiTurnWorkflow(
agent_cls=ToolAgent,
env_cls=ToolEnvironment,
agent_args={"tool_map": tools},
env_args={"tool_map": tools, "reward_fn": reward_fn},
max_steps=5,
rollout_engine=engine,
executor=executor
)
Constructor
def __init__(
agent_cls: type | str,
env_cls: type | str,
agent_args: dict | None = None,
env_args: dict | None = None,
max_steps: int = 5,
**kwargs
)
- `agent_cls`: Agent class or string identifier (e.g., `"ToolAgent"`).
- `env_cls`: Environment class or string identifier.
- `agent_args`: Arguments to pass to the agent constructor.
- `env_args`: Arguments to pass to the environment constructor.
- `max_steps`: Maximum number of steps before termination.
Methods
run
Execute the multi-step workflow:
episode = await workflow.run(task=task_data, uid="task_0:0")
The workflow:
- Resets the environment with the task
- Updates the agent with the initial observation
- For each step:
  - Gets the model response
  - Updates the agent with the response
  - Steps the environment with the action
  - Updates the agent with the new observation and reward
- Terminates when `done=True` or when `max_steps` is reached
TerminationReason
Enum for workflow termination reasons.
from rllm.workflows import TerminationReason
class TerminationReason(Enum):
    """Reason a workflow run terminated."""

    MAX_PROMPT_LENGTH_EXCEEDED = "max_prompt_length_exceeded"
    MAX_RESPONSE_LENGTH_EXCEEDED = "max_response_length_exceeded"
    # Used by the custom-workflow example to signal normal completion;
    # presumably also used when the environment reports done=True — confirm.
    ENV_DONE = "env_done"
    # NOTE(review): presumably set when MultiTurnWorkflow hits max_steps — confirm.
    MAX_TURNS_EXCEEDED = "max_turns_exceeded"
    # NOTE(review): presumably set when the Workflow `timeout` elapses — confirm.
    TIMEOUT = "timeout"
    UNKNOWN = "unknown"
    ERROR = "error"
Example: Custom Workflow
from rllm.workflows import Workflow
from rllm.agents import Episode, Trajectory, Step, Action
from rllm.workflows import TerminationEvent, TerminationReason
class SolverJudgeWorkflow(Workflow):
    """Example workflow: a solver answers the task, then a judge grades it.

    The solver's single step receives reward 1.0 when the judge's verdict is
    CORRECT and 0.0 otherwise. Only the solver trajectory is committed; the
    judge agent is created but, in this minimal example, its trajectory is
    neither recorded nor committed.

    ``run`` finishes by raising ``TerminationEvent`` rather than returning an
    Episode, so invoke it through ``run_with_termination_handling``.
    """

    def __init__(self, rollout_engine, **kwargs):
        super().__init__(rollout_engine, **kwargs)
        # MyAgent is a placeholder for your own agent class.
        self.solver = MyAgent()
        self.judge = MyAgent()

    @staticmethod
    def _is_correct_verdict(text):
        """Return True if the judge's reply is a CORRECT verdict.

        A plain substring test for "correct" also matches "incorrect", which
        would reward wrong answers, so "incorrect" is excluded explicitly.
        """
        verdict = text.strip().lower()
        return "correct" in verdict and "incorrect" not in verdict

    async def run(self, task, uid, **kwargs):
        """Solve ``task`` with the solver, grade it with the judge, commit, and terminate.

        Args:
            task: Task dict; this example reads ``task["question"]``.
            uid: Unique identifier for the task, forwarded as ``application_id``.

        Raises:
            TerminationEvent: Always, with ``TerminationReason.ENV_DONE``, to
                signal normal completion.
        """
        # Reset for new task
        self.reset(task, uid)

        # Solver generates a solution
        solver_messages = [{"role": "user", "content": task["question"]}]
        solver_output = await self.rollout_engine.get_model_response(
            solver_messages,
            application_id=uid,
        )

        # Record the solver's step on its trajectory
        solver_step = Step.from_model_output(
            solver_output,
            messages=solver_messages,
            action=Action(solver_output.content),
        )
        self.solver.trajectory.steps.append(solver_step)

        # Judge evaluates the solution. Give it the question for context and
        # ask for a fixed one-word verdict so the reply can be parsed reliably.
        judge_messages = [
            {
                "role": "user",
                "content": (
                    f"Question: {task['question']}\n"
                    f"Proposed solution: {solver_output.content}\n"
                    "Is the proposed solution correct? "
                    "Answer with exactly one word: CORRECT or INCORRECT."
                ),
            }
        ]
        judge_output = await self.rollout_engine.get_model_response(
            judge_messages,
            application_id=uid,
        )

        # Reward the solver based on the judge's verdict
        solver_step.reward = 1.0 if self._is_correct_verdict(judge_output.content) else 0.0

        # Commit the solver trajectory for training
        self.commit(name="solver", agent=self.solver)

        # run() does not return the Episode itself: raising TerminationEvent
        # signals normal completion, and run_with_termination_handling()
        # (documented as handling termination events, errors, and timeouts)
        # takes it from here.
        raise TerminationEvent(TerminationReason.ENV_DONE)