The data module provides dataset loading, processing, and registry management for training and evaluation.
Dataset
PyTorch-compatible dataset class for rLLM.
from rllm.data import Dataset
dataset = Dataset(
    data=[{"question": "...", "answer": "..."}],
    name="my_dataset",
    split="train"
)
Constructor
List of dictionaries containing dataset examples.
Optional name for the dataset.
Optional split name (e.g., "train", "test", "val").
Methods
len
Get the number of examples.
num_examples = len(dataset)
getitem
Get an item by index.
get_data
Get the raw dataset list.
data = dataset.get_data()
repeat
Repeat the dataset n times.
repeated = dataset.repeat(n=4)
Number of times to repeat the dataset.
shuffle
Shuffle the dataset.
shuffled = dataset.shuffle(seed=42)
Random seed for reproducibility.
select
Select a subset of the dataset.
subset = dataset.select(range(100))
load_data
Load dataset from a file.
dataset = Dataset.load_data("path/to/data.parquet")
Path to dataset file (.json, .jsonl, .parquet).
DatasetRegistry
Global registry for managing datasets.
from rllm.data import DatasetRegistry
Methods
register_dataset
Register a dataset.
DatasetRegistry.register_dataset(
    name="hendrycks_math",
    split="train",
    path="/path/to/hendrycks_math_train.parquet"
)
Split name ("train", "test", "val").
Absolute path to dataset file.
load_dataset
Load a registered dataset.
dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
Loaded dataset, or None if not found.
list_datasets
List all registered datasets.
datasets = DatasetRegistry.list_datasets()
Dictionary mapping dataset names to splits and paths.
Example: Creating a Dataset
from rllm.data import Dataset
# Create from list
data = [
    {"question": "What is 2+2?", "answer": "4"},
    {"question": "What is 3*3?", "answer": "9"},
    {"question": "What is 10-5?", "answer": "5"}
]
dataset = Dataset(data=data, name="simple_math", split="train")
print(f"Size: {len(dataset)}")
print(f"First item: {dataset[0]}")
Example: Loading from File
from rllm.data import Dataset
# Load from parquet
dataset = Dataset.load_data("data/hendrycks_math_train.parquet")
print(f"Loaded {len(dataset)} examples")
print(f"First example: {dataset[0]}")
Example: Dataset Operations
from rllm.data import Dataset
dataset = Dataset.load_data("data/my_data.parquet")
# Shuffle
shuffled = dataset.shuffle(seed=42)
# Select subset
subset = shuffled.select(range(100))
# Repeat for oversampling
repeated = subset.repeat(n=4)
print(f"Original: {len(dataset)}")
print(f"Subset: {len(subset)}")
print(f"Repeated: {len(repeated)}")
Example: Registering Datasets
from rllm.data import Dataset, DatasetRegistry
import os
# Register datasets
base_path = "/data/math"
DatasetRegistry.register_dataset(
    name="hendrycks_math",
    split="train",
    path=os.path.join(base_path, "train.parquet")
)
DatasetRegistry.register_dataset(
    name="hendrycks_math",
    split="test",
    path=os.path.join(base_path, "test.parquet")
)
DatasetRegistry.register_dataset(
    name="math500",
    split="test",
    path=os.path.join(base_path, "math500.parquet")
)
# List registered datasets
print(DatasetRegistry.list_datasets())
Example: Loading Registered Datasets
from rllm.data import DatasetRegistry
# Load datasets
train_dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
test_dataset = DatasetRegistry.load_dataset("hendrycks_math", "test")
if train_dataset:
    print(f"Train size: {len(train_dataset)}")
    print(f"Train example: {train_dataset[0]}")
if test_dataset:
    print(f"Test size: {len(test_dataset)}")
Example: Preparing Data for Training
from rllm.data import Dataset, DatasetRegistry
import polars as pl
import os
def prepare_math_data(train_size=1000, test_size=100):
    """Prepare and register math datasets."""
    # Load raw data (newline-delimited JSON, so use read_ndjson)
    raw_data = pl.read_ndjson("raw_data/hendrycks_math.jsonl")
    # Process data
    train_data = raw_data.head(train_size).to_dicts()
    test_data = raw_data.tail(test_size).to_dicts()
    # Create datasets
    train_dataset = Dataset(data=train_data, name="hendrycks_math", split="train")
    test_dataset = Dataset(data=test_data, name="hendrycks_math", split="test")
    # Save to parquet
    os.makedirs("data", exist_ok=True)
    train_path = "data/hendrycks_math_train.parquet"
    test_path = "data/hendrycks_math_test.parquet"
    pl.DataFrame(train_data).write_parquet(train_path)
    pl.DataFrame(test_data).write_parquet(test_path)
    # Register
    DatasetRegistry.register_dataset("hendrycks_math", "train", os.path.abspath(train_path))
    DatasetRegistry.register_dataset("hendrycks_math", "test", os.path.abspath(test_path))
    return train_dataset, test_dataset
# Prepare data
train_ds, test_ds = prepare_math_data()
Example: Using with Trainer
import hydra
from rllm.trainer import AgentTrainer
from rllm.data import DatasetRegistry
from rllm.workflows import SimpleWorkflow
from rllm.rewards import math_reward_fn
@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None
)
def main(config):
    # Load registered datasets
    train_dataset = DatasetRegistry.load_dataset("hendrycks_math", "train")
    val_dataset = DatasetRegistry.load_dataset("math500", "test")
    assert train_dataset, "Train dataset not found"
    assert val_dataset, "Validation dataset not found"
    # Create trainer
    trainer = AgentTrainer(
        workflow_class=SimpleWorkflow,
        workflow_args={"reward_function": math_reward_fn},
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset
    )
    trainer.train()

if __name__ == "__main__":
    main()
Supported formats:
Parquet (Recommended)
import polars as pl
# Write
pl.DataFrame(data).write_parquet("data.parquet")
# Read
dataset = Dataset.load_data("data.parquet")
JSON
import json
# Write
with open("data.json", "w") as f:
    json.dump(data, f)
# Read
dataset = Dataset.load_data("data.json")
JSONL
import json
# Write
with open("data.jsonl", "w") as f:
    for item in data:
        f.write(json.dumps(item) + "\n")
# Read
dataset = Dataset.load_data("data.jsonl")