In this tutorial, you’ll train a retrieval-augmented generation (RAG) agent built with LangGraph. This demonstrates that the rLLM SDK works seamlessly with popular agent frameworks: your LangGraph code runs unchanged.

Overview

By the end of this tutorial, you will have:
  1. Built a LangGraph agent with retrieval tool calling
  2. Injected rLLM SDK tracing into LangChain’s ChatOpenAI
  3. Trained the agent to search effectively using RL

Concepts

We will cover:
  • Client injection: Swap ChatOpenAI’s internal clients for traced SDK clients
  • LangGraph workflow: StateGraph, nodes, edges, and tools_condition
  • Multi-turn tracing: All LLM calls in an agentic loop are captured

Setup

Step 1: Install dependencies

Install LangChain and LangGraph:
pip install langchain-openai langgraph
Step 2: Download datasets

Download the HotpotQA dataset, the Wikipedia corpus, and the pre-built FAISS indices, then assemble the split index files and put the corpus in place:
cd examples/sdk/langgraph
python data/prepare_hotpotqa_data.py
python data/download_search_data.py --data_dir ./search_data
cat search_data/prebuilt_indices/part_aa search_data/prebuilt_indices/part_ab > search_data/prebuilt_indices/e5_Flat.index
mv search_data/wikipedia/wiki-18.jsonl search_data/prebuilt_indices/corpus.json
Step 3: Start retrieval server

Install the retrieval server’s dependencies (a fresh environment is recommended):
conda create -n rag-server python=3.10 pip -y
conda activate rag-server
pip install faiss-gpu==1.7.2 Flask numpy==1.26.4 sentence-transformers torch
Start the retrieval server on port 9002:
bash launch_server.sh ./search_data/prebuilt_indices 9002
Step 4: Start vLLM server

Start the vLLM server on port 4000:
vllm serve Qwen/Qwen3-4B \
    --host 0.0.0.0 \
    --port 4000 \
    --enable-auto-tool-choice \
    --tool-call-parser hermes

1. Client Injection

LangChain’s ChatOpenAI accepts custom client and async_client parameters. Injecting the traced clients routes every LLM call through the rLLM proxy automatically.

1.1 Normal LangChain (no tracing)

from langchain_openai import ChatOpenAI

# Standard usage - no tracing
llm = ChatOpenAI(
    model="Qwen/Qwen3-4B",
    api_key="token-abc123"
)

1.2 With rLLM SDK tracing

from langchain_openai import ChatOpenAI
from rllm.sdk import get_chat_client, get_chat_client_async

# Create traced clients
sync_client = get_chat_client(
    base_url="http://localhost:4000/v1",
    api_key="token-abc123"
)
async_client = get_chat_client_async(
    base_url="http://localhost:4000/v1",
    api_key="token-abc123"
)

# Inject into ChatOpenAI
llm = ChatOpenAI(
    model="Qwen/Qwen3-4B",
    client=sync_client,        # ← Traced!
    async_client=async_client,  # ← Traced!
)
That’s it! Your LangGraph agent now has full tracing with zero code changes to the workflow logic.
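Conceptually, a traced client is a wrapper that records each request before delegating to the real client, which is why the workflow code never has to change. A toy, stdlib-only illustration of the pattern (not the SDK’s actual implementation; `FakeInner` is a stand-in for a real OpenAI client):

```python
class TracedClient:
    """Toy wrapper: records each call's arguments, then delegates unchanged."""

    def __init__(self, inner):
        self.inner = inner
        self.trace = []  # one entry per LLM call

    def create(self, **kwargs):
        self.trace.append(kwargs)           # capture model, messages, params
        return self.inner.create(**kwargs)  # behavior is otherwise identical


class FakeInner:
    """Stand-in for a real client, so the sketch runs offline."""

    def create(self, **kwargs):
        return {"choices": [{"message": {"content": "ok"}}]}


client = TracedClient(FakeInner())
client.create(model="Qwen/Qwen3-4B", messages=[{"role": "user", "content": "hi"}])
print(len(client.trace))  # → 1
```

Because the wrapper preserves the `create(**kwargs)` call shape, any code written against the plain client works against the traced one.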

2. Build the LangGraph Agent

2.1 Import dependencies

import os
import re
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from rllm.sdk import get_chat_client, get_chat_client_async

2.2 Configure the model with traced clients

MODEL = "Qwen/Qwen3-4B"
MAX_RESPONSE_TOKENS = 2048

# Create traced clients
async_client = get_chat_client_async(
    base_url="http://localhost:4000/v1",
    api_key="token-abc123",
)

sync_client = get_chat_client(
    base_url="http://localhost:4000/v1",
    api_key="token-abc123",
)

# Inject into ChatOpenAI
response_model = ChatOpenAI(
    model=MODEL,
    temperature=1.0,
    max_tokens=MAX_RESPONSE_TOKENS,
    async_client=async_client,
    client=sync_client,
)

2.3 Define the retrieval tool

from local_retrieval_tool import to_langchain_tool

retriever_tool = to_langchain_tool(
    server_url="http://127.0.0.1:9002",
    max_results=5,
    timeout=30.0,
)

2.4 Create the agent node

SYSTEM_PROMPT = """You are a helpful AI assistant that can search for information.

When answering questions:
1. Use the search tool to find relevant information
2. Synthesize information from multiple sources
3. Put your final answer in \\boxed{} format

Example: \\boxed{Paris}"""

async def agent_step(state: MessagesState):
    """Agent decides: call tools or provide final answer."""
    response = await response_model.bind_tools([retriever_tool]).ainvoke(
        state["messages"]
    )
    return {"messages": [response]}

2.5 Assemble the graph

workflow = StateGraph(MessagesState)

# Add nodes
workflow.add_node("agent", agent_step)
workflow.add_node("tools", ToolNode([retriever_tool]))

# Add edges
workflow.add_edge(START, "agent")
workflow.add_conditional_edges(
    "agent",
    tools_condition,  # Routes to "tools" or END based on tool calls
    {
        "tools": "tools",
        END: END,
    },
)
workflow.add_edge("tools", "agent")

# Compile
graph = workflow.compile()
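tools_condition inspects the last message in the state: if it contains tool calls, control routes to the "tools" node; otherwise the run ends. A simplified stand-in for that routing logic (not LangGraph’s actual implementation; it uses plain dicts instead of LangChain message objects):

```python
END = "__end__"  # sentinel; LangGraph exports its own END constant


def route_after_agent(state: dict) -> str:
    """Route to 'tools' if the last message requested tool calls, else end."""
    last = state["messages"][-1]
    if last.get("tool_calls"):
        return "tools"
    return END


state = {"messages": [{"role": "assistant", "tool_calls": [{"name": "search"}]}]}
print(route_after_agent(state))  # → tools
```

This is why the loop terminates: once the model answers without requesting a tool, the conditional edge routes to END instead of back through "tools".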

2.6 Test the graph

async def test_agent():
    async for chunk in graph.astream(
        {
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": "What is the capital of France?"}
            ]
        },
        {"recursion_limit": 10},
    ):
        for node_name, update in chunk.items():
            print(f"Node: {node_name}")
            if "messages" in update:
                print(f"  → {update['messages'][-1].content[:100]}...")

# Run test (top-level await works in a notebook; in a script use asyncio.run)
import asyncio
asyncio.run(test_agent())

3. Create the Run Function

Wrap the graph execution with reward computation.

from rllm.rewards.search_reward import RewardConfig, RewardSearchFn, RewardInput

async def run_search_agent(question: str, ground_truth: str, max_turns: int = 5) -> dict:
    """Run agent and compute reward."""
    
    final_answer = None
    num_turns = 0
    timed_out = False

    async for chunk in graph.astream(
        {
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": question}
            ]
        },
        {"recursion_limit": max_turns * 2 + 5},  # each turn = agent node + tools node, plus slack
    ):
        for node_name, update in chunk.items():
            if node_name == "agent":
                num_turns += 1
                if num_turns > max_turns:
                    timed_out = True
                    break

            # Extract answer from \boxed{}
            if "messages" in update and update["messages"]:
                content = update["messages"][-1].content
                match = re.search(r"\\boxed\{([^}]+)\}", content)
                if match:
                    final_answer = match.group(1)

        if timed_out:
            break

    # Compute reward
    reward = 0.0
    if final_answer and not timed_out:
        reward_fn = RewardSearchFn(RewardConfig())
        reward = reward_fn(RewardInput(task_info={"ground_truth": ground_truth}, action=final_answer)).reward

    return {
        "final_answer": final_answer,
        "reward": reward,
        "num_turns": num_turns,
        "timed_out": timed_out,
    }
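RewardSearchFn encapsulates rLLM’s answer scoring. As a rough mental model only, a normalized exact-match check (a hypothetical stand-in, not rLLM’s actual scoring function) could look like:

```python
import re
import string


def exact_match_reward(prediction: str, ground_truth: str) -> float:
    """1.0 if the normalized prediction equals the normalized ground truth."""

    def normalize(text: str) -> str:
        text = text.lower()
        text = "".join(ch for ch in text if ch not in string.punctuation)
        text = re.sub(r"\b(a|an|the)\b", " ", text)  # drop articles
        return " ".join(text.split())                # collapse whitespace

    return 1.0 if normalize(prediction) == normalize(ground_truth) else 0.0


print(exact_match_reward("The Eiffel Tower!", "eiffel tower"))  # → 1.0
```

The real reward function may also handle aliases and partial credit; the point is that the agent is scored on the extracted \boxed{} answer, not the full transcript.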

4. Set Up Training

import hydra
from rllm.data import DatasetRegistry
from rllm.trainer.agent_trainer import AgentTrainer

async def run_agent(question, ground_truth, **kwargs):
    """Training wrapper - returns reward only."""
    try:
        result = await run_search_agent(question, ground_truth)
        return result["reward"]
    except Exception:
        return 0.0

@hydra.main(
    config_path="pkg://rllm.trainer.config", 
    config_name="agent_ppo_trainer", 
    version_base=None
)
def main(config):
    train_dataset = DatasetRegistry.load_dataset("hotpotqa", "train")
    val_dataset = DatasetRegistry.load_dataset("hotpotqa-small", "test")

    trainer = AgentTrainer(
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        agent_run_func=run_agent,
    )
    trainer.train()

if __name__ == "__main__":
    main()

5. Run Training

cd ~/rllm
bash examples/sdk/langgraph/train_rag_agent.sh
