DPO Demo with Math (GSM8K)
This notebook demonstrates how to fine-tune a Gemma3-1B-it model using Direct Preference Optimization (DPO). DPO is a method for training language models to align with human preferences without requiring a separate reward model.
What this example covers:
Loading and setting up a pre-trained Gemma3-1B instruction-tuned model
Applying LoRA (Low-Rank Adaptation) for efficient fine-tuning
Processing DPO training data with prompt/chosen/rejected response pairs
Training the model using the DPO trainer from Tunix
Evaluating model performance on GSM8K mathematical reasoning tasks
The training uses the Argilla DPO dataset containing preference pairs, focusing on GSM8K training examples to improve mathematical reasoning capabilities.
This notebook has been tested on a v6e-1 TPU instance with 32 GB of HBM.
Installing Libraries (on Colab, RESTART THE RUNTIME after installation)
# Install necessary libraries
import importlib
if importlib.util.find_spec("dotenv") is None:
print("Required packages not found. Running full installation...")
%pip install dotenv
%pip install kagglehub
%pip install safetensors
%pip install tensorflow
%pip install tensorflow_datasets
%pip install tensorboardX
%pip install transformers
%pip install grain
%pip install datasets
%pip install huggingface_hub
%pip install wandb
%pip install -q git+https://github.com/jax-ml/jax
%pip uninstall flax -y
%pip install git+https://github.com/google/flax
%pip install git+https://github.com/google/tunix
%pip install git+https://github.com/google/qwix
%pip install -q 'numpy>2'
import os
import sys
try:
from google.colab import userdata
USE_COLAB = True
%pip uninstall -y wandb -y # wandb is glitchy with tunix in colab
os.environ["WANDB_API_KEY"] = userdata.get('WANDB_API_KEY')
os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
except:
USE_COLAB = False
from dotenv import load_dotenv
load_dotenv()
print("Using env vars to login")
import nest_asyncio
nest_asyncio.apply()
print("nest_asyncio applied")
import wandb
if "WANDB_API_KEY" in os.environ and os.environ["WANDB_API_KEY"]:
wandb.login(key=os.environ["WANDB_API_KEY"])
else:
print("WANDB_API_KEY not found. Skipping wandb login.")
# Check if HF_TOKEN is set before logging in
if "HF_TOKEN" in os.environ and os.environ["HF_TOKEN"]:
hf_token = os.environ["HF_TOKEN"]
!hf auth login --token "$hf_token"
else:
print("HF_TOKEN not found. Skipping Hugging Face login.")
# Imports
import os
import sys
import json
import shutil
from datasets import concatenate_datasets
from datasets import load_dataset
from flax import nnx
import grain
from huggingface_hub import snapshot_download
import jax
import jax.numpy as jnp
import numpy as np
import optax
from orbax import checkpoint as ocp
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import model as gemma3_model_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
from tunix.models.gemma3 import params as gemma_params
from tunix.sft import metrics_logger
from tunix.sft.dpo.dpo_trainer import DPOTrainer
from tunix.sft.dpo.dpo_trainer import DPOTrainingConfig
from tunix.sft.utils import show_hbm_usage
# Hyperparameters / config
model_id = "google/gemma-3-1b-it"  # also supports "google/gemma-3-270m-it"
GEMMA_TOKENIZER_PATH = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"

# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

# ====== LoRA ======
RANK = 32
ALPHA = 16.0

# ====== Sharding ======
# Adjust mesh based on your TPU memory and model size.
NUM_TPUS = len(jax.devices())
if NUM_TPUS == 8:
    # NOTE(review): (1, 4) has only 4 mesh positions, so presumably only 4 of
    # the 8 devices are used; (2, 4) would cover all of them — confirm memory.
    MESH_COUNTS = (1, 4)
elif NUM_TPUS == 1:
    MESH_COUNTS = (1, 1)
else:
    raise ValueError(f"Unsupported number of TPUs: {NUM_TPUS}")
MESH = [
    MESH_COUNTS,
    ("fsdp", "tp"),
]

# ====== Sequence lengths / sampling ======
MAX_PROMPT_LENGTH = 192
MAX_RESPONSE_LENGTH = 192
TEMPERATURE = 0.7
TOP_P = 1.0
TOP_K = 50

# ====== DPO ======
BETA = 0.1

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-5
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1

# ====== Training loop ======
BATCH_SIZE = 2
NUM_BATCHES = 512
# Number of GSM8K test batches to evaluate. The original config assigned 100
# and immediately overrode it with 2; only the effective value is kept here.
# Increase it for a more reliable accuracy estimate.
NUM_TEST_BATCHES = 2
EVAL_EVERY_N_STEPS = 1024
NUM_EPOCHS = 2  # can potentially train for more epochs
MAX_STEPS = int(NUM_BATCHES * TRAIN_FRACTION * NUM_EPOCHS)

# == Cosine decay with warmup scheduler ==
# Linearly increase the learning rate from 0 to LEARNING_RATE over the first
# 10% of training steps, then decay it to 0 with a cosine schedule.
# Kept as an integer because it is a step count.
WARMUP_STEPS = int(0.1 * MAX_STEPS)

# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1

# Checkpoint saving
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/ckpts/"
SAVE_INTERVAL_STEPS = 500
MAX_TO_KEEP = 4

# ====== Inference ======
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": None, "top_k": 1, "top_p": None},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}
Load reference model and LoRA model
Reference Model and LoRA Model
Reference Model: This is the original pre-trained Gemma3-1B instruction-tuned model that serves as the base for fine-tuning. It’s loaded from the Hugging Face Hub.
LoRA Model: This is a Low-Rank Adaptation of the reference model. LoRA is a parameter-efficient fine-tuning technique that injects small, trainable matrices into specific layers of the pre-trained model, significantly reducing the number of parameters that need to be updated during training. This makes fine-tuning much faster and requires less memory compared to fine-tuning the entire model. The LoRA model is built on top of the reference model, inheriting its pre-trained weights and capabilities, while allowing for efficient adaptation to the DPO task.
# Download the checkpoint from the Hugging Face Hub.
ignore_patterns = [
    "*.pth",  # Ignore PyTorch .pth weight files
]
print(f"Downloading {model_id} from Hugging Face...")
local_model_path = snapshot_download(
    repo_id=model_id, ignore_patterns=ignore_patterns
)
print(f"Model successfully downloaded to: {local_model_path}")

# Seed the EOS token ids from the checkpoint's generation_config.json, if any.
# `eos_token_id` may be stored either as a single int or as a list of ints;
# normalize it to a list because it is searched and appended to below.
EOS_TOKENS = []
generation_config_path = os.path.join(local_model_path, "generation_config.json")
if os.path.exists(generation_config_path):
    with open(generation_config_path, "r", encoding="utf-8") as f:
        generation_configs = json.load(f)
    eos_token_id = generation_configs.get("eos_token_id", [])
    if isinstance(eos_token_id, int):
        EOS_TOKENS = [eos_token_id]
    elif eos_token_id is None:
        EOS_TOKENS = []
    else:
        EOS_TOKENS = list(eos_token_id)
    print(f"Using EOS token IDs: {EOS_TOKENS}")

print("\n--- HBM Usage BEFORE Model Load ---")
show_hbm_usage()

# Pick the Gemma3 architecture matching `model_id`.
MODEL_CP_PATH = local_model_path
if "gemma-3-270m" in model_id:
    model_config = gemma3_model_lib.ModelConfig.gemma3_270m()
elif "gemma-3-1b" in model_id:
    model_config = gemma3_model_lib.ModelConfig.gemma3_1b_it()
else:
    raise ValueError(f"Unsupported model: {model_id}")

# Build the device mesh and load the safetensors weights onto it.
mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]))
with mesh:
    gemma3 = params_safetensors_lib.create_model_from_safe_tensors(
        MODEL_CP_PATH, model_config, mesh
    )
    nnx.display(gemma3)

print("\n--- HBM Usage AFTER Model Load ---")
show_hbm_usage()

# Make sure the tokenizer's own EOS id is also treated as a stop token.
gemma_tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)
if gemma_tokenizer.eos_id() not in EOS_TOKENS:
    EOS_TOKENS.append(gemma_tokenizer.eos_id())
print(f"Using EOS token IDs: {EOS_TOKENS}")

# Sampler over the reference model; the KV cache is sized for the maximum
# prompt + response length plus 256 extra positions.
sampler = sampler_lib.Sampler(
    transformer=gemma3,
    tokenizer=gemma_tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH + 256,
        num_layers=model_config.num_layers,
        num_kv_heads=model_config.num_kv_heads,
        head_dim=model_config.head_dim,
    ),
)
def get_lora_model(base_model, mesh):
    """Return ``base_model`` wrapped with LoRA adapters, with state sharded on ``mesh``.

    Adapters of rank ``RANK`` and scale ``ALPHA`` are attached to every module
    whose path matches the attention (q/kv/attn_vec einsum) or MLP
    (gate/up/down projection) patterns below. Afterwards the model state is
    re-laid-out according to its partition specs inside ``mesh``.

    Args:
      base_model: the Gemma3 model to adapt (must provide ``get_model_input``).
      mesh: the JAX device mesh used for the sharding constraint.

    Returns:
      The LoRA-augmented model.
    """
    provider = qwix.LoraProvider(
        module_path=(
            ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
            ".*attn_vec_einsum"
        ),
        rank=RANK,
        alpha=ALPHA,
    )
    # qwix needs example inputs to trace the model while inserting adapters.
    adapted = qwix.apply_lora_to_model(
        base_model, provider, **base_model.get_model_input()
    )

    with mesh:
        current_state = nnx.state(adapted)
        nnx.update(
            adapted,
            jax.lax.with_sharding_constraint(
                current_state, nnx.get_partition_spec(current_state)
            ),
        )
    return adapted
# Policy model: the Gemma3 reference model with LoRA adapters (trained by DPO).
lora_gemma = get_lora_model(gemma3, mesh)
nnx.display(lora_gemma)
Load evaluation data and evaluate the reference model
# Gemma chat prompt: a user turn containing `{question}`, followed by an opened
# model turn so that generation continues as the model's reply.
TEMPLATE = (
    "<start_of_turn>user\n"
    "{question}<end_of_turn>\n"
    "<start_of_turn>model\n"
)
def generate(
    question, sampler, temperature=0.7, top_k=50, top_p=0.95, seed=None
):
    """Sample a reply for one question, or for each question in a batch.

    Every question is wrapped in ``TEMPLATE`` and sent to ``sampler`` with at
    most ``MAX_RESPONSE_LENGTH`` generation steps, the prompt not echoed, and
    ``EOS_TOKENS`` as stop tokens.

    Args:
      question: a single question string, or an iterable of question strings.
      sampler: a callable sampler (e.g. ``sampler_lib.Sampler``).
      temperature, top_k, top_p: sampling parameters passed through.
      seed: optional sampling seed passed through unchanged.

    Returns:
      For a ``str`` input, the first generated text; otherwise the sampler
      output's ``text`` field as-is (presumably one string per question).
    """
    is_single = isinstance(question, str)
    batch = [question] if is_single else question
    prompts = [TEMPLATE.format(question=q) for q in batch]

    sampled = sampler(
        input_strings=prompts,
        max_generation_steps=MAX_RESPONSE_LENGTH,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        echo=False,
        seed=seed,
        eos_tokens=EOS_TOKENS,
    )
    texts = sampled.text
    return texts[0] if is_single else texts
def _any_response_matches(answer, responses):
    """Return True if any response contains ``answer`` as a substring, ignoring commas.

    An answer that is ``None`` or blank cannot be scored (a blank needle would
    match every response), so it is reported and counted as incorrect.
    """
    if answer is None or not str(answer).strip():
        print("SKIPPED accuracy check")
        return False
    # Removing commas from both sides also covers the plain substring check:
    # if the raw answer occurs in the raw response, it still occurs after both
    # have their commas removed.
    needle = str(answer).replace(",", "").strip()
    return any(needle in str(response).replace(",", "") for response in responses)


def evaluate(
    dataset,
    sampler,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_passes=1,
    corr_lst=False,
    make_lst=False,
):
    """Compute pass@``num_passes`` accuracy of ``sampler`` on ``dataset``.

    Each question is answered ``num_passes`` times; pass ``p`` uses ``seed=p``.
    A question counts as correct if ANY of its responses contains the reference
    answer as a substring (commas ignored). This is a lenient check: for
    example, reference "4" also matches a response containing "14".

    Args:
      dataset: iterable of batches, each with parallel "question" and "answer"
        sequences.
      sampler: sampler forwarded to ``generate``.
      temperature, top_k, top_p: sampling parameters forwarded to ``generate``.
      num_passes: number of sampled responses per question.
      corr_lst: when ``make_lst`` is set, collect correctly answered questions
        (True) or incorrectly answered ones (False).
      make_lst: also return the collected ``(question, answer, responses)``.

    Returns:
      ``(correct, total, accuracy_percent)``, or
      ``((correct, total, accuracy_percent), collected)`` when ``make_lst``.
      The accuracy is 0.0 when the dataset yields no questions.
    """
    response_lst = []
    corr = 0
    total = 0
    for batch in tqdm(dataset):
        answers = batch["answer"]
        questions = batch["question"]
        multiple_call_responses = [[] for _ in range(len(questions))]
        for p in range(num_passes):
            responses = generate(
                questions, sampler, temperature, top_k, top_p, seed=p
            )
            for idx, response in enumerate(responses):
                multiple_call_responses[idx].append(response)
                print(f"Question:\t{questions[idx]}")
                print(f"Correct Answer:\t{answers[idx]}")
                print(f"Response:\t{response}")
                print("-" * 50)

        for question, multiple_call_response, answer in zip(
            questions, multiple_call_responses, answers
        ):
            is_correct = _any_response_matches(answer, multiple_call_response)
            if is_correct:
                corr += 1
            # Collect either the correct or the incorrect questions, per corr_lst.
            if make_lst and is_correct == bool(corr_lst):
                response_lst.append((question, answer, multiple_call_response))
            total += 1
            if total % 10 == 0:
                print(f"===> {corr=}, {total=}, {corr / total * 100=}")

    accuracy = corr / total * 100 if total else 0.0
    to_return = (corr, total, accuracy)
    if make_lst:
        return to_return, response_lst
    return to_return
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip()
def get_gsm8k_dataset(data_dir, split="train", source="tfds") -> grain.MapDataset:
    """Load a GSM8K split as a shuffled grain ``MapDataset`` of prompt records.

    Args:
      data_dir: directory passed to TFDS for download/storage (created if
        missing; not used by the Hugging Face source).
      split: dataset split name, e.g. "train" or "test".
      source: "tfds" or "huggingface".

    Returns:
      A ``MapDataset`` shuffled with seed 42 whose records contain:
      ``prompts`` (question wrapped in ``TEMPLATE``), ``question``, and
      ``answer`` (text after ``####``, or ``None`` if absent).

    Raises:
      ValueError: if ``source`` is not "tfds" or "huggingface".
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(data_dir, exist_ok=True)
    if source == "tfds":
        data = tfds.data_source(
            "gsm8k",
            split=split,
            data_dir=data_dir,
            builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
            download=True,
        )
    elif source == "huggingface":
        from datasets import load_dataset as hf_load_dataset
        hf_dataset = hf_load_dataset("openai/gsm8k", "main", split=split)
        data = [
            {"question": item["question"], "answer": item["answer"]}
            for item in hf_dataset
        ]
    else:
        raise ValueError(f"Unknown source: {source}. Choose 'tfds' or 'huggingface'")

    def _as_text(v):
        # Values that are not already str (presumably TFDS bytes) are decoded as UTF-8.
        return v if isinstance(v, str) else v.decode("utf-8")

    def _to_example(x):
        # Convert one raw record into the fields consumed by the model and scorers.
        question = _as_text(x["question"])
        return {
            # passed to model forward pass
            "prompts": TEMPLATE.format(question=question),
            # passed to reward functions
            "question": question,
            # passed to reward functions
            "answer": extract_hash_answer(_as_text(x["answer"])),
        }

    return grain.MapDataset.source(data).shuffle(seed=42).map(_to_example)
# Choose data source: "tfds" or "huggingface".
# input() cannot prompt in non-interactive runs (EOFError, or IPython's
# StdinNotImplementedError, a NotImplementedError subclass), so fall back to
# the default instead of crashing.
try:
    gsm8k_source = input("Choose data source [tfds/huggingface]: ").strip().lower()
except (EOFError, NotImplementedError):
    print("No interactive input available. Defaulting to 'tfds'.")
    gsm8k_source = "tfds"
if gsm8k_source not in ("tfds", "huggingface"):
    print("Invalid choice. Defaulting to 'tfds'.")
    gsm8k_source = "tfds"

# Batch the GSM8K test split and keep only the first NUM_TEST_BATCHES batches.
test_dataset = get_gsm8k_dataset(TEST_DATA_DIR, "test", source=gsm8k_source).batch(BATCH_SIZE)[
    :NUM_TEST_BATCHES
]
len(test_dataset)
# Evaluate the reference (pre-DPO) model: pass@5 with the "standard" sampling
# settings, also collecting the (question, answer, responses) tuples.
# The original run reported roughly 65% accuracy on the GSM8K test set.
# NOTE(review): only NUM_TEST_BATCHES * BATCH_SIZE questions are scored, so
# with small settings this number will be noisy.
(num_correct, total, accuracy), responses = evaluate(
    test_dataset,
    sampler,
    num_passes=5,
    make_lst=True,
    **GENERATION_CONFIGS["standard"],
)
print(f"{num_correct=}, {total=}, {accuracy=}%")
DPO Dataset Preparation
The DPO training dataset is loaded from the “argilla/distilabel-intel-orca-dpo-pairs” dataset on the Hugging Face Hub. This dataset contains preference pairs (chosen and rejected responses) for various prompts.
To improve the model’s performance on mathematical reasoning tasks, we prioritize samples from the GSM8K training set by filtering the dataset for records where in_gsm8k_train is True.
Since the number of GSM8K training samples might be less than the number of samples needed for training (NUM_BATCHES × BATCH_SIZE), we add enough random samples from the rest of the dataset to reach that target. This ensures we have enough data for training while giving more weight to the GSM8K examples, and it also helps the model’s performance on general use cases.
def get_dataset() -> grain.MapDataset:
    """Build the DPO training set from argilla/distilabel-intel-orca-dpo-pairs.

    All rows flagged ``in_gsm8k_train`` are kept. If there are fewer of them
    than ``NUM_BATCHES * BATCH_SIZE``, the shortfall is filled with a seeded
    random sample of the remaining (non-GSM8K) rows. GSM8K rows come first;
    the combined dataset is not shuffled.

    Returns:
      A grain ``MapDataset`` whose records hold ``prompts`` (the ``input``
      column wrapped in ``TEMPLATE``), ``chosen_responses`` and
      ``rejected_responses``.
    """
    dpo_dataset = load_dataset(
        "argilla/distilabel-intel-orca-dpo-pairs", split="train"
    )
    gsm8k_train_dpo_dataset = dpo_dataset.filter(lambda x: x["in_gsm8k_train"])
    num_gsm8k_train_samples = len(gsm8k_train_dpo_dataset)
    print(
        f"Number of samples with in_gsm8k_train=True: {num_gsm8k_train_samples}"
    )

    # How many more samples are needed to reach NUM_BATCHES full batches.
    total_samples_needed = NUM_BATCHES * BATCH_SIZE
    samples_to_add = total_samples_needed - num_gsm8k_train_samples
    print(f"Number of additional random samples needed: {samples_to_add}")

    if samples_to_add > 0:
        # Draw the filler only from the non-GSM8K rows; sampling from the full
        # dataset could re-select GSM8K rows and duplicate them.
        other_dpo_dataset = dpo_dataset.filter(lambda x: not x["in_gsm8k_train"])
        # Never request more rows than are available.
        random_samples = other_dpo_dataset.shuffle(seed=42).select(
            range(min(samples_to_add, len(other_dpo_dataset)))
        )
        print(f"Number of random samples selected: {len(random_samples)}")
        combined_dpo_dataset = concatenate_datasets(
            [gsm8k_train_dpo_dataset, random_samples]
        )
    else:
        combined_dpo_dataset = gsm8k_train_dpo_dataset
    print(f"Total samples in the combined dataset: {len(combined_dpo_dataset)}")

    def _to_dpo_example(x):
        # Map one raw row to the prompt/chosen/rejected fields used by DPO.
        return {
            "prompts": TEMPLATE.format(question=x["input"]),
            "chosen_responses": x["chosen"],
            "rejected_responses": x["rejected"],
        }

    return grain.MapDataset.source(combined_dpo_dataset).map(_to_dpo_example)
# Batch the DPO data and keep at most NUM_BATCHES batches.
dataset = get_dataset().batch(BATCH_SIZE)[:NUM_BATCHES]

if TRAIN_FRACTION != 1.0:
    # The leading TRAIN_FRACTION of batches is used for training, the rest for validation.
    split_index = int(len(dataset) * TRAIN_FRACTION)
    train_dataset = dataset[:split_index].repeat(NUM_EPOCHS)
    # NOTE(review): the validation split is also repeated NUM_EPOCHS times;
    # confirm whether the trainer needs this or a single pass would suffice.
    val_dataset = dataset[split_index:].repeat(NUM_EPOCHS)
else:
    # Train on all batches; no validation split.
    train_dataset = dataset.repeat(NUM_EPOCHS)
    val_dataset = None

len(train_dataset)
Define optimizer and DPO Trainer
# Ckpt saving
checkpointing_options = ocp.CheckpointManagerOptions(
save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)
# Metrics logger
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
log_dir="/tmp/content/tmp/tensorboard/dpo", flush_every_n_steps=20
)
# Logs
%load_ext tensorboard
%tensorboard --logdir /tmp/content/tmp/tensorboard/dpo --port=0
# Learning-rate schedule: linear warmup from 0 to LEARNING_RATE over
# WARMUP_STEPS, then cosine decay to 0, with MAX_STEPS as decay_steps.
lr_schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    decay_steps=MAX_STEPS,
    end_value=0.0,
)
optimizer = optax.adamw(
    learning_rate=lr_schedule,
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)

# Optionally clip the global gradient norm before the AdamW update.
if MAX_GRAD_NORM is not None:
    optimizer = optax.chain(
        optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM), optimizer
    )
# DPO hyperparameters, sequence limits, logging and checkpointing settings.
dpo_config = DPOTrainingConfig(
    beta=BETA,
    max_prompt_length=MAX_PROMPT_LENGTH,
    max_response_length=MAX_RESPONSE_LENGTH,
    max_steps=MAX_STEPS,
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    checkpoint_root_directory=CKPT_DIR,
    checkpointing_options=checkpointing_options,
    metrics_logging_options=metrics_logging_options,
)
dpo_config  # display the config as the cell output

# The LoRA-wrapped model is the policy being trained; the base gemma3 model
# loaded earlier is passed as the reference model.
dpo_trainer = DPOTrainer(
    model=lora_gemma,
    ref_model=gemma3,
    tokenizer=gemma_tokenizer,
    optimizer=optimizer,
    training_config=dpo_config,
)
Train and evaluate LoRA model
show_hbm_usage()

# The first couple of training steps may take up to ~5 minutes. If individual
# steps keep taking much longer (e.g. >10 minutes), please open a bug.
with mesh:
    dpo_trainer.train(train_dataset, val_dataset)

# NOTE(review): wandb.init() runs only after training has finished; if the goal
# is to log this training run to wandb it presumably belongs before train().
if not USE_COLAB:
    wandb.init()
# Sampler over the DPO-tuned LoRA model, with the same KV-cache sizing as the
# reference sampler (max prompt + max response + 256 positions).
lora_cache_config = sampler_lib.CacheConfig(
    cache_size=MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH + 256,
    num_layers=model_config.num_layers,
    num_kv_heads=model_config.num_kv_heads,
    head_dim=model_config.head_dim,
)
lora_sampler = sampler_lib.Sampler(
    transformer=lora_gemma,
    tokenizer=gemma_tokenizer,
    cache_config=lora_cache_config,
)
# Evaluate the DPO-tuned LoRA model with the same protocol as the reference
# model: pass@5 with the "standard" sampling settings.
# The original run reported roughly 70% accuracy on the GSM8K test set.
# NOTE(review): only NUM_TEST_BATCHES * BATCH_SIZE questions are scored, so
# with small settings this number will be noisy.
(num_correct, total, accuracy), responses = evaluate(
    test_dataset,
    lora_sampler,
    num_passes=5,
    make_lst=True,
    **GENERATION_CONFIGS["standard"],
)
print(f"{num_correct=}, {total=}, {accuracy=}%")
Export Merged LoRA Weights (Hugging Face Format)
# Target directory for the merged model; on Colab it lives under /tmp/content.
output_dir = f"/tmp/content/{model_id}-lora" if USE_COLAB else f"./{model_id}-lora"

# Start from an empty directory so files from a previous export are not mixed in.
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)

print(f"Saving merged LoRA model to {output_dir}")
# Merge the LoRA weights (rank RANK, scale ALPHA) into the downloaded base
# checkpoint and write the result to output_dir as safetensors.
gemma_params.save_lora_merged_model_as_safetensors(
    local_model_path=local_model_path,
    output_dir=output_dir,
    lora_model=lora_gemma,
    rank=RANK,
    alpha=ALPHA,
)

separator = "=" * 60
print("\n" + separator)
print("Model saved successfully!")
print(f"Output directory: {output_dir}")
print(separator)

# List every entry in the export directory with its size in MiB.
print("\nSaved files:")
for entry_name in os.listdir(output_dir):
    size_mb = os.path.getsize(os.path.join(output_dir, entry_name)) / (1024 * 1024)
    print(f" {entry_name:<30} {size_mb:>10.2f} MB")

# For Colab: zip the export directory and trigger a browser download.
if USE_COLAB:
    from google.colab import files
    shutil.make_archive(output_dir, "zip", output_dir)
    files.download(f"{output_dir}.zip")