LoRA & QLoRA Demo
In this tutorial, we use a v6e-1 TPU to demonstrate how to fine-tune the Gemma
model for English-to-French translation using Low-Rank Adaptation (LoRA), a
parameter-efficient way of fine-tuning LLMs, and its quantized variant, QLoRA.
LoRA works by freezing the original weights of the pre-trained model and injecting trainable low-rank matrices into selected layers of the Transformer (in this notebook, the attention and MLP projection layers). During fine-tuning, only these newly introduced low-rank matrices are updated, greatly decreasing the computational and memory resources required compared to traditional full fine-tuning. This approach is based on the observation that the changes in model weights needed for adaptation often have a low rank. The benefits of using LoRA include reduced HBM memory usage, faster training times, and the advantage that, after training, the LoRA adapters can be merged with the original model weights, resulting in no additional inference latency.
Install necessary libraries (on Colab, restart the runtime after installing)
import importlib
if importlib.util.find_spec("dotenv") is None:
print("Required packages not found. Running full installation...")
%pip install -q dotenv
%pip install -q kagglehub
%pip install -q safetensors
%pip install -q tensorflow
%pip install -q tensorflow_datasets
%pip install -q tensorboardX
%pip install -q transformers
%pip install -q grain
%pip install -q datasets
%pip install -q wandb
%pip install -q git+https://github.com/jax-ml/jax
%pip install -q git+https://github.com/google/tunix
%pip install -q git+https://github.com/google/qwix
%pip uninstall -q flax -y
%pip install -q git+https://github.com/google/flax
%pip install -q 'numpy>2'
import os
import sys
import kagglehub
try:
from google.colab import userdata
USE_COLAB = True
%pip uninstall -q wandb -y # wandb is glitchy with tunix in colab
os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
except:
USE_COLAB = False
from dotenv import load_dotenv
load_dotenv()
print("Using env vars to login")
import nest_asyncio
nest_asyncio.apply()
print("nest_asyncio applied")
import wandb
if "WANDB_API_KEY" in os.environ and os.environ["WANDB_API_KEY"]:
wandb.login(key=os.environ["WANDB_API_KEY"])
else:
print("WANDB_API_KEY not found. Skipping wandb login.")
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
kagglehub.login()
if "HF_TOKEN" in os.environ and os.environ["HF_TOKEN"]:
hf_token = os.environ["HF_TOKEN"]
!hf auth login --token "$hf_token"
else:
print("HF_TOKEN not found. Skipping Hugging Face login.")
If wandb is available (it is uninstalled when running on Colab), you’ll see a message like the one below when you start the experiment:
Tracking run with wandb version 0.21.0
Run data is saved locally in /content/wandb/run-20250717_224322-kmvoi0ho
Syncing run 2025-07-17_22-43-22 to Weights & Biases (docs)
View project at https://wandb.ai/<wandb_username>/tunix?apiKey=<api_key>
View run at https://wandb.ai/<wandb_username>/tunix/runs/kmvoi0ho?apiKey=<api_key>
Do NOT share these links with anyone. They can be used to claim your runs.
Clicking the run link takes you to the Weights & Biases metrics page, which contains training metrics, evaluation metrics, system metrics, and any custom metrics you choose to report.
Imports
import gc
import json
import logging
import shutil
import dotenv
from flax import nnx
from huggingface_hub import snapshot_download
import jax
import jax.numpy as jnp
import numpy as np
import optax
from orbax import checkpoint as ocp
import qwix
from tunix.examples.data import translation_dataset as data_lib
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma3 import model as gemma3_model_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
from tunix.models.gemma3 import params as gemma_params
from tunix.sft import metrics_logger
from tunix.sft import peft_trainer
from tunix.sft import utils
from tunix.sft.utils import show_hbm_usage
# `logging.root` is the object `logging.getLogger()` returns; lower its
# threshold to INFO so module-level `logging.info(...)` calls are not filtered.
logger = logging.root
logger.setLevel("INFO")
Hyperparameters
# Hugging Face checkpoint and the Gemma 3 tokenizer file.
model_id = "google/gemma-3-270m-it"
GEMMA_TOKENIZER_PATH = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"

# --- Data ---
# Colab gets a tiny batch; tune both values to your TPU memory & model size.
BATCH_SIZE = 64 if not USE_COLAB else 2
MAX_TARGET_LENGTH = 256

# --- Model setup ---
# Mesh shape per supported device count; tune it to your TPU memory & model size.
# NOTE(review): with 8 devices the (1, 4) mesh presumably places only 4 of them
# in the mesh (check jax.make_mesh) — possibly chosen so "tp" divides the head
# count; confirm before relying on all 8 chips.
NUM_TPUS = len(jax.devices())
MESH_COUNTS = {1: (1, 1), 8: (1, 4)}.get(NUM_TPUS)
if MESH_COUNTS is None:
    raise ValueError(f"Unsupported number of TPUs: {NUM_TPUS}")
MESH = [MESH_COUNTS, ("fsdp", "tp")]

# --- LoRA / QLoRA ---
USE_QUANTIZATION = True  # True -> QLoRA (NF4 base weights), False -> plain LoRA
RANK = 16
ALPHA = float(RANK * 2)  # alpha is twice the rank, stored as a float

# --- Training ---
MAX_STEPS = 100
EVAL_EVERY_N_STEPS = 20
NUM_EPOCHS = 3

# --- Output directories ---
FULL_CKPT_DIR = "/tmp/content/full_ckpts/"
LORA_CKPT_DIR = "/tmp/content/lora_ckpts/"
PROFILING_DIR = "/tmp/content/profiling/"
def create_dir(path):
    """Create directory *path*, including missing parents, if it is absent.

    An already-existing directory is not an error. Any ``OSError`` raised while
    creating it is logged at ERROR level and swallowed, so callers keep going.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        logging.error(f"Error creating directory '{path}': {e}")
    else:
        logging.info(f"Created dir: {path}")
# Create the checkpoint and profiling output directories up front.
for _dir_to_create in (FULL_CKPT_DIR, LORA_CKPT_DIR, PROFILING_DIR):
    create_dir(_dir_to_create)
Load model from HF
To download the model, you need a Hugging Face account (authenticated via HF_TOKEN) that has accepted the Gemma license on the model's Hugging Face page.
# Download the checkpoint from Hugging Face, skipping PyTorch `.pth` weight files.
ignore_patterns = [
    "*.pth",  # Ignore PyTorch .pth weight files
]
print(f"Downloading {model_id} from Hugging Face...")
local_model_path = snapshot_download(
    repo_id=model_id, ignore_patterns=ignore_patterns
)
print(f"Model successfully downloaded to: {local_model_path}")

# Collect the EOS token ids from the model's generation config, if present.
# `eos_token_id` may be a single int, a list of ints, or missing/null. Normalize
# it to a list, because the tokenizer cell uses `in` and `.append` on EOS_TOKENS,
# and both fail on a bare int.
EOS_TOKENS = []
generation_config_path = os.path.join(local_model_path, "generation_config.json")
if os.path.exists(generation_config_path):
    with open(generation_config_path, "r", encoding="utf-8") as f:
        generation_configs = json.load(f)
    _eos_ids = generation_configs.get("eos_token_id")
    if _eos_ids is None:
        EOS_TOKENS = []
    elif isinstance(_eos_ids, int):
        EOS_TOKENS = [_eos_ids]
    else:
        EOS_TOKENS = list(_eos_ids)
print(f"Using EOS token IDs: {EOS_TOKENS}")
print("\n--- HBM Usage BEFORE Model Load ---")
show_hbm_usage()

MODEL_CP_PATH = local_model_path

# Choose the tunix Gemma 3 config whose id fragment appears in model_id; the
# first match wins. The lambdas defer the factory lookup until it is selected.
for _id_fragment, _make_config in (
    ("gemma-3-270m", lambda: gemma3_model_lib.ModelConfig.gemma3_270m()),
    ("gemma-3-1b", lambda: gemma3_model_lib.ModelConfig.gemma3_1b_it()),
):
    if _id_fragment in model_id:
        model_config = _make_config()
        break
else:
    raise ValueError(f"Unsupported model: {model_id}")

# MESH is [axis_sizes, axis_names]; every mesh axis is typed Auto. The
# safetensors weights are then loaded onto that mesh.
_mesh_axis_sizes, _mesh_axis_names = MESH
mesh = jax.make_mesh(
    _mesh_axis_sizes,
    _mesh_axis_names,
    axis_types=tuple(jax.sharding.AxisType.Auto for _ in _mesh_axis_sizes),
)
with mesh:
    base_model = params_safetensors_lib.create_model_from_safe_tensors(
        MODEL_CP_PATH, model_config, mesh
    )
nnx.display(base_model)
Initialize Tokenizer
# Load the Gemma 3 tokenizer and make sure its own EOS id is also a stop token.
tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)
_tokenizer_eos = tokenizer.eos_id()
if _tokenizer_eos not in EOS_TOKENS:
    EOS_TOKENS.append(_tokenizer_eos)
print(f"Using EOS token IDs: {EOS_TOKENS}")
Prompt the model
Let’s see how the original model performs on the English-to-French translation task.
# Baseline: sample translations from the untuned base model.
_cache_config = sampler_lib.CacheConfig(
    cache_size=256,
    num_layers=model_config.num_layers,
    num_kv_heads=model_config.num_kv_heads,
    head_dim=model_config.head_dim,
)
sampler = sampler_lib.Sampler(
    transformer=base_model,
    tokenizer=tokenizer if "gemma" in model_id else tokenizer.tokenizer,
    cache_config=_cache_config,
)

input_batch = [
    "Translate this into French:\nHello, my name is Morgane.\n",
    "Translate this into French:\nThis dish is delicious!\n",
    "Translate this into French:\nI am a student.\n",
    "Translate this into French:\nHow's the weather today?\n",
]

out_data = sampler(
    input_strings=input_batch,
    max_generation_steps=10,  # Number of decoding steps per response.
    eos_tokens=EOS_TOKENS,
)

# One separator / prompt / output group per example.
for prompt_text, completion_text in zip(input_batch, out_data.text):
    print(
        "----------------------",
        f"Prompt:\n{prompt_text}",
        f"Output:\n{completion_text}",
        sep="\n",
    )
Apply LoRA/QLoRA to the base model
The choice between LoRA and QLoRA is controlled by the USE_QUANTIZATION hyperparameter:
- USE_QUANTIZATION = True: use QLoRA (LoRA on top of NF4-quantized base weights)
- USE_QUANTIZATION = False: use regular LoRA
def get_lora_model(base_model, mesh, quantize=False):
    """Attach qwix LoRA adapters to *base_model* and reshard it on *mesh*.

    Adapters use the module-level RANK and ALPHA and are applied to modules
    whose path matches the q/kv einsum and gate/up/down projection regex below.

    Args:
      base_model: model to adapt; its ``get_model_input()`` kwargs are
        forwarded to ``qwix.apply_lora_to_model``.
      mesh: device mesh used as the context while constraining the state.
      quantize: if True, also request ``weight_qtype="nf4"`` with
        ``tile_size=128`` (QLoRA); otherwise plain LoRA.

    Returns:
      The LoRA-wrapped model, with its state passed through
      ``jax.lax.with_sharding_constraint`` using ``nnx.get_partition_spec``.
    """
    provider_kwargs = dict(
        module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj",
        rank=RANK,
        alpha=ALPHA,
    )
    if quantize:
        # QLoRA: quantize the frozen base weights to NF4.
        provider_kwargs.update(weight_qtype="nf4", tile_size=128)
    lora_provider = qwix.LoraProvider(**provider_kwargs)

    lora_model = qwix.apply_lora_to_model(
        base_model, lora_provider, **base_model.get_model_input()
    )

    with mesh:
        # Constrain the new model state to the partition specs nnx reports for
        # it, then write the constrained state back into the model.
        lora_state = nnx.state(lora_model)
        lora_state = jax.lax.with_sharding_constraint(
            lora_state, nnx.get_partition_spec(lora_state)
        )
        nnx.update(lora_model, lora_state)
    return lora_model
# Build the adapted model; USE_QUANTIZATION selects QLoRA vs. plain LoRA.
_adapter_name = "QLoRA" if USE_QUANTIZATION else "LoRA"
lora_model = get_lora_model(base_model, mesh=mesh, quantize=USE_QUANTIZATION)
nnx.display(lora_model)
print(f"Using {_adapter_name} model")
Load Datasets for SFT Training
# Build the English->French training and validation datasets.
# To use a Hugging Face dataset instead, pass dataset_name='Helsinki-NLP/opus-100'.
# That requires upgrading the 'datasets' package and, on Colab, restarting the
# runtime.
train_ds, validation_ds = data_lib.create_datasets(
    dataset_name='mtnt/en-fr',
    tokenizer=tokenizer,
    global_batch_size=BATCH_SIZE,
    max_target_length=MAX_TARGET_LENGTH,
    num_train_epochs=NUM_EPOCHS,
)
def gen_model_input_fn(x: peft_trainer.TrainingInput):
    """Convert a training batch into the keyword inputs passed to the model.

    Tokens equal to ``tokenizer.pad_id()`` are treated as padding when deriving
    the position ids and the causal attention mask.

    Returns:
      dict with keys ``input_tokens``, ``input_mask``, ``positions`` and
      ``attention_mask``.
    """
    non_pad = x.input_tokens != tokenizer.pad_id()
    return dict(
        input_tokens=x.input_tokens,
        input_mask=x.input_mask,
        positions=utils.build_positions_from_mask(non_pad),
        attention_mask=utils.make_causal_attn_mask(non_pad),
    )
SFT Training
Training with full weights (skipped on Colab because it runs out of RAM)
%load_ext tensorboard
%tensorboard --logdir /tmp/tensorboard/full
if USE_COLAB:
print("Not enough RAM to run full training on Colab")
else:
full_logging_options = metrics_logger.MetricsLoggerOptions(
log_dir="/tmp/tensorboard/full", flush_every_n_steps=20
)
training_config = peft_trainer.TrainingConfig(
eval_every_n_steps=EVAL_EVERY_N_STEPS,
max_steps=MAX_STEPS,
metrics_logging_options=full_logging_options,
checkpoint_root_directory=FULL_CKPT_DIR,
)
trainer = peft_trainer.PeftTrainer(
base_model, optax.adamw(1e-5), training_config
).with_gen_model_input_fn(gen_model_input_fn)
# The first couple of training step might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. >10 minutes per step, please open a bug. Really appreciated!
with mesh:
trainer.train(train_ds, validation_ds)
wandb.init()
Training with LoRA/QLoRA
The model is trained with the method selected by the USE_QUANTIZATION hyperparameter.
%load_ext tensorboard
%tensorboard --logdir /tmp/tensorboard/lora
lora_logging_options = metrics_logger.MetricsLoggerOptions(
log_dir="/tmp/tensorboard/lora", flush_every_n_steps=20
)
training_config = peft_trainer.TrainingConfig(
eval_every_n_steps=EVAL_EVERY_N_STEPS,
max_steps=MAX_STEPS,
metrics_logging_options=lora_logging_options,
checkpoint_root_directory=LORA_CKPT_DIR,
)
trainer = peft_trainer.PeftTrainer(
lora_model, optax.adamw(1e-3), training_config
).with_gen_model_input_fn(gen_model_input_fn)
# The first couple of training step might take up to 5 minutes to finish. Please be patient. If you experience long training steps, e.g. >10 minutes per step, please open a bug. Really appreciated!
method_name = "QLoRA" if USE_QUANTIZATION else "LoRA"
with mesh:
trainer.train(train_ds, validation_ds)
if not USE_COLAB: wandb.init()
Compare profiling results of the different training setups

| Setup | Train Step Time | Peak Memory Usage |
|---|---|---|
| Full weights | ~1.22 s | 43.26 GiB |
| QLoRA | ~1.19 s | 28.14 GiB |
Generate with the LoRA/QLoRA model
The model may still not translate English to French perfectly, since we trained it for only 100 steps (MAX_STEPS). Training for longer should give better results.
# Sample from the fine-tuned LoRA/QLoRA model with the same prompts and
# sampling settings as the baseline.
_cache_config = sampler_lib.CacheConfig(
    cache_size=256,
    num_layers=model_config.num_layers,
    num_kv_heads=model_config.num_kv_heads,
    head_dim=model_config.head_dim,
)
sampler = sampler_lib.Sampler(
    transformer=lora_model,
    tokenizer=tokenizer if "gemma" in model_id else tokenizer.tokenizer,
    cache_config=_cache_config,
)

input_batch = [
    "Translate this into French:\nHello, my name is Morgane.\n",
    "Translate this into French:\nThis dish is delicious!\n",
    "Translate this into French:\nI am a student.\n",
    "Translate this into French:\nHow's the weather today?\n",
]

out_data = sampler(
    input_strings=input_batch,
    max_generation_steps=10,  # Number of decoding steps per response.
    eos_tokens=EOS_TOKENS,
)

# One separator / prompt / output group per example.
for prompt_text, completion_text in zip(input_batch, out_data.text):
    print(
        "----------------------",
        f"Prompt:\n{prompt_text}",
        f"Output:\n{completion_text}",
        sep="\n",
    )
Export Merged LoRA Weights (Hugging Face Format)
# Export directory: under /tmp/content on Colab, relative to the CWD otherwise.
if USE_COLAB:
    output_dir = f"/tmp/content/{model_id}-lora"
else:
    output_dir = f"./{model_id}-lora"

# Always start from an empty directory.
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)

print(f"Saving merged LoRA model to {output_dir}")
# Merge the LoRA adapters (RANK/ALPHA) into the original checkpoint and save it
# as safetensors.
gemma_params.save_lora_merged_model_as_safetensors(
    local_model_path=local_model_path,
    output_dir=output_dir,
    lora_model=lora_model,
    rank=RANK,
    alpha=ALPHA,
)

# Summary banner followed by each exported entry and its size in MiB.
_rule = "=" * 60
print("\n" + _rule)
print("Model saved successfully!")
print(f"Output directory: {output_dir}")
print(_rule)
print("\nSaved files:")
for entry_name in os.listdir(output_dir):
    entry_mib = os.path.getsize(os.path.join(output_dir, entry_name)) / (1024 * 1024)
    print(f" {entry_name:<30} {entry_mib:>10.2f} MB")

# For Colab: zip the export and download it through the browser.
if USE_COLAB:
    from google.colab import files
    shutil.make_archive(output_dir, 'zip', output_dir)
    files.download(f"{output_dir}.zip")