Workflows
Workflows orchestrate the interaction between agents, environments, and models to execute complex multi-step tasks.

Workflow

Abstract base class for all workflows.
from rllm.workflows import Workflow

Constructor

def __init__(
    rollout_engine: RolloutEngine,
    executor: ThreadPoolExecutor,
    timeout: float = 1e6,
    gamma: float = 0.0,
    reward_bonus_coeff: float = 0.0,
    **kwargs
)
rollout_engine
RolloutEngine
The rollout engine for model inference.
executor
ThreadPoolExecutor
Thread pool executor for async operations.
timeout
float
default:"1e6"
Timeout for workflow execution in seconds.
gamma
float
default:"0.0"
Discount factor for reward computation. When > 0, computes Monte Carlo returns.
reward_bonus_coeff
float
default:"0.0"
Coefficient for reward shaping based on reward deltas.

Methods

run

Execute the workflow on a single task. Must be implemented by subclasses.
async def run(task: dict, uid: str, **kwargs) -> Episode | None
task
dict
The task to execute.
uid
str
Unique identifier for the task.
episode
Episode | None
The generated episode.

run_with_termination_handling

Wrapper around run() that handles termination events, errors, and timeouts.
episode = await workflow.run_with_termination_handling(task, uid)

commit

Commit a trajectory for training.
workflow.commit(
    name="solver",
    agent=my_agent,
    reset=True
)
name
str | None
Name for the trajectory.
agent
BaseAgent | None
Agent whose trajectory to commit.
trajectory
Trajectory | None
Trajectory to commit directly (alternative to agent).
reset
bool
default:"False"
Whether to reset the agent after committing.

collect_trajectories

Collect all trajectories from committed and agent instances.
episode = workflow.collect_trajectories()
episode
Episode
Episode containing all trajectories.

reset

Reset the workflow for a new task.
workflow.reset(task=new_task, uid="task_123:0")
task
dict | None
The task to reset to.
uid
str | None
Unique identifier for the task.

postprocess_episode

Post-process episode after completion (compute rewards, metrics, etc.).
episode = workflow.postprocess_episode(
    episode,
    termination_reason=TerminationReason.ENV_DONE
)

SimpleWorkflow

Simplified workflow for single-agent, single-turn tasks.
from rllm.workflows import SimpleWorkflow
from rllm.rewards import math_reward_fn

workflow = SimpleWorkflow(
    rollout_engine=engine,
    reward_function=math_reward_fn,
    executor=executor
)

Constructor

def __init__(
    rollout_engine: RolloutEngine,
    reward_function: RewardFunction,
    **kwargs
)
rollout_engine
RolloutEngine
Engine for model inference.
reward_function
RewardFunction
Function to compute rewards from task and action.

Methods

run

Execute the workflow:
episode = await workflow.run(
    task={"question": "What is 2+2?", "answer": "4"},
    uid="task_0:0"
)
The workflow automatically:
  1. Extracts messages from task (supports question, prompt, problem, or messages keys)
  2. Gets model response
  3. Computes reward
  4. Creates trajectory with step
  5. Returns episode

MultiTurnWorkflow

Workflow for multi-step agent-environment interactions.
from rllm.workflows import MultiTurnWorkflow
from rllm.agents import ToolAgent
from rllm.environments import ToolEnvironment

workflow = MultiTurnWorkflow(
    agent_cls=ToolAgent,
    env_cls=ToolEnvironment,
    agent_args={"tool_map": tools},
    env_args={"tool_map": tools, "reward_fn": reward_fn},
    max_steps=5,
    rollout_engine=engine,
    executor=executor
)

Constructor

def __init__(
    agent_cls: type | str,
    env_cls: type | str,
    agent_args: dict | None = None,
    env_args: dict | None = None,
    max_steps: int = 5,
    **kwargs
)
agent_cls
type | str
Agent class or string identifier (e.g., "ToolAgent").
env_cls
type | str
Environment class or string identifier.
agent_args
dict | None
Arguments to pass to agent constructor.
env_args
dict | None
Arguments to pass to environment constructor.
max_steps
int
default:"5"
Maximum number of steps before termination.

Methods

run

Execute the multi-step workflow:
episode = await workflow.run(task=task_data, uid="task_0:0")
The workflow:
  1. Resets environment with task
  2. Updates agent with initial observation
  3. For each step:
    • Gets model response
    • Updates agent with response
    • Steps environment with action
    • Updates agent with new observation and reward
  4. Terminates on done=True or max steps reached

TerminationReason

Enum for workflow termination reasons.
from rllm.workflows import TerminationReason

class TerminationReason(Enum):
    """Why a workflow run ended.

    NOTE(review): member meanings are inferred from their names and the
    surrounding docs (e.g. ENV_DONE is used with postprocess_episode above);
    confirm exact semantics against the rllm.workflows source.
    """
    MAX_PROMPT_LENGTH_EXCEEDED = "max_prompt_length_exceeded"
    MAX_RESPONSE_LENGTH_EXCEEDED = "max_response_length_exceeded"
    ENV_DONE = "env_done"
    MAX_TURNS_EXCEEDED = "max_turns_exceeded"
    TIMEOUT = "timeout"
    UNKNOWN = "unknown"
    ERROR = "error"

Example: Custom Workflow

from rllm.workflows import Workflow
from rllm.agents import Episode, Trajectory, Step, Action
from rllm.workflows import TerminationEvent, TerminationReason

class SolverJudgeWorkflow(Workflow):
    """Example two-agent workflow: a solver proposes an answer, a judge scores it.

    Only the solver's trajectory is committed for training; the judge's
    response is used solely to derive the solver's reward.
    """

    def __init__(self, rollout_engine, **kwargs):
        super().__init__(rollout_engine, **kwargs)
        self.solver = MyAgent()
        self.judge = MyAgent()

    async def run(self, task, uid, **kwargs):
        """Run one solver/judge round on *task*.

        Args:
            task: Task dict; expects a "question" key.
            uid: Unique task identifier, forwarded as the inference
                application_id.

        Raises:
            TerminationEvent: always raised on normal completion;
                run_with_termination_handling() is expected to convert it
                into the final episode.
        """
        # Reset workflow state for the new task
        self.reset(task, uid)

        # Solver generates solution
        solver_messages = [{"role": "user", "content": task["question"]}]
        solver_output = await self.rollout_engine.get_model_response(
            solver_messages,
            application_id=uid
        )

        # Record the solver's turn as a step on its trajectory
        solver_step = Step.from_model_output(
            solver_output,
            messages=solver_messages,
            action=Action(solver_output.content)
        )
        self.solver.trajectory.steps.append(solver_step)

        # Judge evaluates solution
        judge_messages = [
            {"role": "user", "content": f"Evaluate: {solver_output.content}"}
        ]
        judge_output = await self.rollout_engine.get_model_response(
            judge_messages,
            application_id=uid
        )

        # Reward the solver only on an affirmative verdict. A bare
        # `"correct" in text` check would also match "incorrect", so that
        # case must be excluded explicitly.
        verdict = judge_output.content
        solver_reward = 1.0 if "correct" in verdict and "incorrect" not in verdict else 0.0
        solver_step.reward = solver_reward

        # Commit the solver trajectory for training
        self.commit(name="solver", agent=self.solver)

        # Signal normal completion; the wrapper collects and returns the episode
        raise TerminationEvent(TerminationReason.ENV_DONE)