The data module provides dataset loading, processing, and registry management for training and evaluation.

Dataset

PyTorch-compatible dataset class for rLLM.
from rllm.data import Dataset

dataset = Dataset(
    data=[{"question": "...", "answer": "..."}],
    name="my_dataset",
    split="train"
)

Constructor

data (list[dict[str, Any]]): List of dictionaries containing dataset examples.
name (str | None): Optional name for the dataset.
split (str | None): Optional split name (e.g., "train", "test", "val").

Methods

__len__

Get the number of examples.
num_examples = len(dataset)

__getitem__

Get an item by index.
example = dataset[0]

get_data

Get the raw dataset list.
data = dataset.get_data()

repeat

Repeat the dataset n times.
repeated = dataset.repeat(n=4)
n (int): Number of times to repeat the dataset.

shuffle

Shuffle the dataset.
shuffled = dataset.shuffle(seed=42)
seed (int | None): Random seed for reproducibility.
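
Since the seed exists for reproducibility, shuffling twice with the same seed should produce the same ordering. A small sanity check, assuming shuffle returns a new Dataset and get_data returns the examples in order:

a = dataset.shuffle(seed=42)
b = dataset.shuffle(seed=42)
assert a.get_data() == b.get_data()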

select

Select a subset of the dataset.
subset = dataset.select(range(100))
indices (list[int] | range): Indices to select.
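
Per the parameter type, indices can also be an explicit list of integers rather than a range:

# Select specific examples by position
subset = dataset.select([0, 5, 10])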

load_data

Load a dataset from a file.
dataset = Dataset.load_data("path/to/data.parquet")
path (str): Path to the dataset file (.json, .jsonl, .parquet).

DatasetRegistry

Global registry for managing datasets.
from rllm.data import DatasetRegistry

Methods

register_dataset

Register a dataset.
DatasetRegistry.register_dataset(
    name="hendrycks_math",
    split="train",
    path="/path/to/hendrycks_math_train.parquet"
)
name (str): Dataset name.
split (str): Split name ("train", "test", "val").
path (str): Absolute path to the dataset file.

load_dataset

Load a registered dataset.
dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
name (str): Dataset name.
split (str): Split name.

Returns: Dataset | None. The loaded dataset, or None if not found.
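
Because load_dataset returns None for datasets that are not registered, callers typically guard the result:

dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
if dataset is None:
    raise ValueError("hendrycks_math/train is not registered")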

list_datasets

List all registered datasets.
datasets = DatasetRegistry.list_datasets()
Returns: dict. Dictionary mapping dataset names to splits and paths.
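
A quick way to inspect the registry contents; the exact structure of each entry depends on the registry implementation, so this only prints entries generically:

for name, splits in DatasetRegistry.list_datasets().items():
    print(name, splits)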

Example: Creating a Dataset

from rllm.data import Dataset

# Create from list
data = [
    {"question": "What is 2+2?", "answer": "4"},
    {"question": "What is 3*3?", "answer": "9"},
    {"question": "What is 10-5?", "answer": "5"}
]

dataset = Dataset(data=data, name="simple_math", split="train")

print(f"Size: {len(dataset)}")
print(f"First item: {dataset[0]}")

Example: Loading from File

from rllm.data import Dataset

# Load from parquet
dataset = Dataset.load_data("data/hendrycks_math_train.parquet")

print(f"Loaded {len(dataset)} examples")
print(f"First example: {dataset[0]}")

Example: Dataset Operations

from rllm.data import Dataset

dataset = Dataset.load_data("data/my_data.parquet")

# Shuffle
shuffled = dataset.shuffle(seed=42)

# Select subset
subset = shuffled.select(range(100))

# Repeat for oversampling
repeated = subset.repeat(n=4)

print(f"Original: {len(dataset)}")
print(f"Subset: {len(subset)}")
print(f"Repeated: {len(repeated)}")

Example: Registering Datasets

from rllm.data import Dataset, DatasetRegistry
import os

# Register datasets
base_path = "/data/math"

DatasetRegistry.register_dataset(
    name="hendrycks_math",
    split="train",
    path=os.path.join(base_path, "train.parquet")
)

DatasetRegistry.register_dataset(
    name="hendrycks_math",
    split="test",
    path=os.path.join(base_path, "test.parquet")
)

DatasetRegistry.register_dataset(
    name="math500",
    split="test",
    path=os.path.join(base_path, "math500.parquet")
)

# List registered datasets
print(DatasetRegistry.list_datasets())

Example: Loading Registered Datasets

from rllm.data import DatasetRegistry

# Load datasets
train_dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
test_dataset = DatasetRegistry.load_dataset("hendrycks_math", "test")

if train_dataset:
    print(f"Train size: {len(train_dataset)}")
    print(f"Train example: {train_dataset[0]}")

if test_dataset:
    print(f"Test size: {len(test_dataset)}")

Example: Preparing Data for Training

from rllm.data import Dataset, DatasetRegistry
import polars as pl
import os

def prepare_math_data(train_size=1000, test_size=100):
    """Prepare and register math datasets."""
    
    # Load raw data
    raw_data = pl.read_ndjson("raw_data/hendrycks_math.jsonl")
    
    # Process data
    train_data = raw_data.head(train_size).to_dicts()
    test_data = raw_data.tail(test_size).to_dicts()
    
    # Create datasets
    train_dataset = Dataset(data=train_data, name="hendrycks_math", split="train")
    test_dataset = Dataset(data=test_data, name="hendrycks_math", split="test")
    
    # Save to parquet
    os.makedirs("data", exist_ok=True)
    train_path = "data/hendrycks_math_train.parquet"
    test_path = "data/hendrycks_math_test.parquet"
    
    pl.DataFrame(train_data).write_parquet(train_path)
    pl.DataFrame(test_data).write_parquet(test_path)
    
    # Register
    DatasetRegistry.register_dataset("hendrycks_math", "train", os.path.abspath(train_path))
    DatasetRegistry.register_dataset("hendrycks_math", "test", os.path.abspath(test_path))
    
    return train_dataset, test_dataset

# Prepare data
train_ds, test_ds = prepare_math_data()

Example: Using with Trainer

import hydra
from rllm.trainer import AgentTrainer
from rllm.data import DatasetRegistry
from rllm.workflows import SimpleWorkflow
from rllm.rewards import math_reward_fn

@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None
)
def main(config):
    # Load registered datasets
    train_dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
    val_dataset = DatasetRegistry.load_dataset("math500", "test")
    
    assert train_dataset, "Train dataset not found"
    assert val_dataset, "Validation dataset not found"
    
    # Create trainer
    trainer = AgentTrainer(
        workflow_class=SimpleWorkflow,
        workflow_args={"reward_function": math_reward_fn},
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset
    )
    
    trainer.train()

if __name__ == "__main__":
    main()

Dataset File Formats

Supported formats: Parquet, JSON, and JSONL.

Parquet

from rllm.data import Dataset
import polars as pl

# Write
pl.DataFrame(data).write_parquet("data.parquet")

# Read
dataset = Dataset.load_data("data.parquet")

JSON

from rllm.data import Dataset
import json

# Write
with open("data.json", "w") as f:
    json.dump(data, f)

# Read
dataset = Dataset.load_data("data.json")

JSONL

from rllm.data import Dataset
import json

# Write
with open("data.jsonl", "w") as f:
    for item in data:
        f.write(json.dumps(item) + "\n")

# Read
dataset = Dataset.load_data("data.jsonl")