Fine-tuning a Vision Language Model (VLM) using DPO#
This notebook demonstrates how to fine-tune a Vision Language Model (VLM), specifically the Gemma 3 4B IT model, using Direct Preference Optimization (DPO). DPO is a method for training language models to align with human preferences without requiring a separate reward model.
This example is split into the following sections:
Installing necessary libraries and dependencies;
Loading and setting up a Gemma3-4B instruction-tuned model;
Applying LoRA (Low-Rank Adaptation) for efficient fine-tuning;
Processing DPO training data with prompt/chosen/rejected response pairs and images;
Training the model using the DPO trainer from Tunix; and
Logging metrics and visualizing them.
Setup#
Installing libraries (restart the Colab runtime after running the cell below)#
# Install dependencies quietly (-q); restart the runtime after this cell.
!pip install -q kagglehub
!pip install -q tensorflow
!pip install -q tensorboardX
!pip install -q grain
# Tunix (Gemma 3 model code, DPO trainer) and Qwix (LoRA) from GitHub.
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix
# Replace any preinstalled Flax with the build from the GitHub repository.
!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git
!pip install -q huggingface_hub
# `datasets` is used to download the RLAIF-V preference data.
!pip install -q datasets
Imports#
import dataclasses
import json
import os
import types
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import grain
import optax
import qwix
import datasets
from flax import nnx
from huggingface_hub import snapshot_download
from PIL import Image
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import model as gemma3_model_lib
from tunix.models.gemma3 import params as params_lib
from tunix.processors import image_processor as image_processor_lib
from tunix.sft import metrics_logger
from tunix.sft.dpo.dpo_trainer import DPOTrainer, DPOTrainingConfig, TrainingInput
Configuration and hyperparameters#
# ====== Model paths =====
GEMMA_TOKENIZER_PATH: str = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"
MODEL_CKPT_PATH: str = "gs://gemma-data/checkpoints/gemma3-4b-it"
# ====== LoRA ======
RANK: int = 32  # LoRA rank passed to qwix.LoraProvider.
ALPHA: float = 16.0  # LoRA alpha passed to qwix.LoraProvider.
# ====== Sharding ======
# Adjust mesh based on your TPU memory and model size.
NUM_TPUS = len(jax.devices())
# (fsdp, tp) mesh shape for each supported device count.
# NOTE(review): with 8 devices the (1, 4) shape covers only 4 of them —
# confirm that leaving half the devices idle is intended.
MESH_COUNTS = {8: (1, 4), 1: (1, 1)}.get(NUM_TPUS)
if MESH_COUNTS is None:
    raise ValueError(f"Unsupported number of TPUs: {NUM_TPUS}")
# Positional arguments for jax.make_mesh: (axis_shapes, axis_names).
MESH = [MESH_COUNTS, ("fsdp", "tp")]
# ====== Sequence lengths, sampling, DPO ======
MAX_PROMPT_LENGTH = 512
MAX_RESPONSE_LENGTH = 512
# Sampling parameters (not referenced by the training cells in this notebook).
TEMPERATURE = 0.7
TOP_P = 1.0
TOP_K = 50
BETA = 0.1
# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 5e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1
# == Cosine decay with warmup scheduler ==
# Linearly increase the learning rate from 0 to LEARNING_RATE over the first
# 10% of training steps, then decay it to 0 with a cosine schedule.
BATCH_SIZE = 2
NUM_SAMPLES = 5000
NUM_BATCHES = NUM_SAMPLES // BATCH_SIZE
EVAL_EVERY_N_STEPS = 1000
NUM_EPOCHS = 10
MAX_STEPS = int(NUM_BATCHES * NUM_EPOCHS)
# Warmup is a step count, so keep it an exact integer (10% of MAX_STEPS)
# rather than the float produced by `0.1 * MAX_STEPS`.
WARMUP_STEPS = MAX_STEPS // 10
# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1
# Checkpoint saving: directory, save cadence (steps), and retention count.
CKPT_DIR = "/tmp/ckpts/"
SAVE_INTERVAL_STEPS = 1000
MAX_TO_KEEP = 1
# ====== Inference ======
# Named sampling presets, from deterministic to most exploratory.
GENERATION_CONFIGS = dict(
    greedy=dict(temperature=1e-4, top_k=1, top_p=1.0),  # greedy search
    standard=dict(temperature=0.7, top_k=50, top_p=0.95),  # some randomness
    liberal=dict(temperature=0.85, top_k=2000, top_p=1.0),  # liberal
)
Load reference model and LoRA model#
This section loads the models the same way as the original (text-only) DPO example,
except that we pass `text_only=False` when creating the model config, so the vision components are included.
# Multimodal Gemma 3 4B IT config (text_only=False keeps the vision config).
model_config = gemma3_model_lib.ModelConfig.gemma3_4b_it(text_only=False)
# MESH holds (axis_shapes, axis_names) for the device mesh.
mesh = jax.make_mesh(axis_shapes=MESH[0], axis_names=MESH[1])
# Load the bf16 checkpoint under the mesh so parameters are placed on it.
with mesh:
    gemma3 = params_lib.create_model_from_checkpoint(
        MODEL_CKPT_PATH,
        model_config,
        mesh,
        dtype=jnp.bfloat16,
    )
nnx.display(gemma3)
gemma_tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)
def get_lora_model(base_model, mesh):
    """Return ``base_model`` with LoRA adapters applied and its state re-sharded.

    Adapters (rank ``RANK``, alpha ``ALPHA``) are requested for the modules
    selected by the ``module_path`` regex below. It covers the attention
    einsums, MLP projections, and ``*_proj``/``fc1``/``fc2`` layers. The
    keyword arguments returned by ``base_model.get_model_input()`` are
    forwarded to ``qwix.apply_lora_to_model``.

    Args:
      base_model: The NNX model to wrap.
      mesh: Device mesh used as the context when applying sharding
        constraints to the LoRA model's state.

    Returns:
      The LoRA-wrapped model, with its state updated in place to the
      sharding-constrained values.
    """
    target_modules = (
        r".*q_einsum|.*kv_einsum|.*attn_vec_einsum|.*gate_proj|.*down_proj|"
        r".*up_proj|.*query_proj|.*key_proj|.*value_proj|.*out_proj|.*fc1|.*fc2"
    )
    provider = qwix.LoraProvider(
        module_path=target_modules, rank=RANK, alpha=ALPHA
    )
    model_inputs = base_model.get_model_input()
    adapted = qwix.apply_lora_to_model(base_model, provider, **model_inputs)

    # Constrain every state leaf to its declared partition spec, then write
    # the constrained values back into the model.
    with mesh:
        adapted_state = nnx.state(adapted)
        sharded_state = jax.lax.with_sharding_constraint(
            adapted_state, nnx.get_partition_spec(adapted_state)
        )
        nnx.update(adapted, sharded_state)
    return adapted
# Policy model: `gemma3` with LoRA adapters applied; it is the model that the
# DPO trainer below optimizes.
lora_gemma = get_lora_model(base_model=gemma3, mesh=mesh)
nnx.display(lora_gemma)
Load dataset#
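We use the RL from AI Feedback for Vision (RLAIF-V) dataset (`openbmb/RLAIF-V-Dataset` on Hugging Face). The dataset contains ~80K high-quality preference pairs, each pair having an associated image. This example uses only the first `NUM_SAMPLES` examples of the training split.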
# Gemma 3 chat-style prompt: one user turn containing an image marker followed
# by the question, then an open model turn for the response.
TEMPLATE = (
    "<start_of_turn>user\n<start_of_image>{question}<end_of_turn>\n"
    "<start_of_turn>model\n"
)


def load_dataset(split, image_processor):
    """Load RLAIF-V preference pairs as a shuffled Grain ``MapDataset``.

    Args:
      split: Hugging Face split spec, e.g. ``"train[:5000]"``.
      image_processor: Callable applied to each example's RGB PIL image; the
        first element of its result is used as the image array.

    Returns:
      A ``grain.MapDataset`` (shuffled with seed 42) whose elements are dicts
      with keys ``prompts``, ``chosen_responses``, ``rejected_responses`` and
      ``images``.
    """
    ds = datasets.load_dataset("openbmb/RLAIF-V-Dataset", split=split)
    cols = ["image", "question", "chosen", "rejected"]
    ds = ds.remove_columns([c for c in ds.column_names if c not in cols])

    # Expand the image marker into placeholder tokens once, on the template
    # itself, rather than on every formatted prompt. This avoids redoing the
    # same string work per example and leaves any literal "<start_of_image>"
    # text inside a question untouched.
    # NOTE(review): 256 "<img>" placeholders per image presumably must match
    # the vision encoder's soft-token count — confirm against the model config.
    prompt_template = TEMPLATE.replace(
        "<start_of_image>",
        "\n\n<start_of_image>" + "<img>" * 256 + "<end_of_image>\n\n",
    )

    def _to_example(x):
        # Map one raw dataset row to the trainer's field names.
        return {
            "prompts": prompt_template.format(question=x["question"]),
            "chosen_responses": x["chosen"],
            "rejected_responses": x["rejected"],
            "images": np.array(image_processor(x["image"].convert("RGB"))[0]),
        }

    return grain.MapDataset.source(ds).shuffle(seed=42).map(_to_example)
# Training data: the first NUM_SAMPLES examples of the train split.
train_ds = load_dataset(
    split=f"train[:{NUM_SAMPLES}]",
    image_processor=image_processor_lib.ImageProcessor(
        config=model_config.vision_config
    ),
)
# Batch, then repeat for NUM_EPOCHS. A single pass holds only NUM_BATCHES
# batches (one epoch), while MAX_STEPS and the cosine LR schedule are sized
# for NUM_BATCHES * NUM_EPOCHS steps. Repeating makes the two agree (exactly,
# when BATCH_SIZE divides NUM_SAMPLES).
train_ds = train_ds.batch(BATCH_SIZE).repeat(NUM_EPOCHS)
Define optimizer and initialize trainer#
We define the checkpointing and metrics-logging options, launch the TensorBoard visualizer, define the optimizer, and create Tunix’s `DPOTrainer`.
# `ocp` (Orbax checkpointing) is not imported in the imports cell, so
# referencing it here raised a NameError; import it explicitly.
import orbax.checkpoint as ocp

# Ckpt saving: save every SAVE_INTERVAL_STEPS steps and keep at most
# MAX_TO_KEEP checkpoints.
checkpointing_options = ocp.CheckpointManagerOptions(
    save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP
)
# Metrics logger: writes under /tmp/tensorboard/dpo (the same directory that
# TensorBoard is pointed at below), flushing every 20 steps.
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/tensorboard/dpo", flush_every_n_steps=20
)
# Logs
# Start TensorBoard inline, reading the metrics logger's log_dir; --port=0
# lets TensorBoard pick an unused port.
%load_ext tensorboard
%tensorboard --logdir /tmp/tensorboard/dpo --port=0
# Optimizer: AdamW driven by a linear-warmup + cosine-decay learning-rate
# schedule (0 -> LEARNING_RATE over WARMUP_STEPS, then down to 0 by MAX_STEPS).
optimizer = optax.adamw(
    learning_rate=optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    ),
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)
# When a max norm is configured, clip gradients by global norm before the
# AdamW update; otherwise use AdamW on its own.
optimizer = (
    optimizer
    if MAX_GRAD_NORM is None
    else optax.chain(
        optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
        optimizer,
    )
)
# `DPOTrainer` configuration, grouped by concern.
dpo_config = DPOTrainingConfig(
    # DPO beta.
    beta=BETA,
    # Step budget and evaluation cadence.
    max_steps=MAX_STEPS,
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    # Maximum prompt / response lengths.
    max_prompt_length=MAX_PROMPT_LENGTH,
    max_response_length=MAX_RESPONSE_LENGTH,
    # Metrics logging and checkpointing.
    metrics_logging_options=metrics_logging_options,
    checkpoint_root_directory=CKPT_DIR,
    checkpointing_options=checkpointing_options,
)
# The LoRA-wrapped model is the policy being trained; the base `gemma3` is
# passed as the reference model.
dpo_trainer = DPOTrainer(
    model=lora_gemma,
    ref_model=gemma3,
    optimizer=optimizer,
    training_config=dpo_config,
    tokenizer=gemma_tokenizer,
    image_processor=image_processor_lib.ImageProcessor(
        config=model_config.vision_config
    ),
)
Train!#
# Run DPO training under the device mesh. The trainer instance created above
# is `dpo_trainer`; the name `trainer` was never defined (NameError).
with mesh:
    dpo_trainer.train(train_ds)