Parameter-Efficient Fine-Tuning of Llama 3.1-8B with LoRA/QLoRA on NVIDIA GPUs using JAX and Tunix
This tutorial walks you through parameter-efficient fine-tuning (PEFT) of Llama 3.1-8B using LoRA and QLoRA on NVIDIA GPUs with JAX, Tunix, and Qwix. Unlike full-parameter SFT, PEFT freezes the base model weights and trains only small adapter matrices. This dramatically reduces memory requirements and training time, while typically reaching quality close to full fine-tuning.
What you’ll do:
Set up the environment and authenticate with Hugging Face
Load the Llama 3.1-8B base model and apply LoRA/QLoRA adapters
Prepare the UltraChat 200k dataset for instruction fine-tuning
Configure and run parameter-efficient fine-tuning
Visualize training metrics with TensorBoard
Run a quick inference sanity check
Preliminaries
Make sure you have supported hardware
Hardware requirements. QLoRA with 4-bit quantization can fine-tune Llama 3.1-8B on a single GPU with 16 GB+ of VRAM. For LoRA without quantization, 24 GB+ is recommended. Multiple GPUs enable larger batch sizes and faster training through data parallelism. On multi-GPU systems, the model is automatically sharded across devices using FSDP, with tensor parallelism added on 8+ GPUs (see the mesh configuration below).
!nvidia-smi
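The sketch below is a rough, weights-only check on these figures. It estimates the memory used by the frozen base weights at 16-bit and at 4-bit precision. It rests on three simplifying assumptions: roughly 8.03 billion parameters, every weight quantized, and one 16-bit scale per block of 32 weights in the 4-bit case. Activations, LoRA adapters, and optimizer state all come on top of these numbers. weight_memory_gb is a hypothetical helper.

```python
def weight_memory_gb(num_params: float, bits_per_weight: float,
                     block_size: int = 0, scale_bits: int = 16) -> float:
    """Approximate weight memory in GB (1 GB = 1e9 bytes).

    If block_size > 0, adds one `scale_bits`-bit scale per block of
    `block_size` weights, as in block-wise quantization schemes.
    """
    bits = num_params * bits_per_weight
    if block_size > 0:
        bits += (num_params / block_size) * scale_bits
    return bits / 8 / 1e9


NUM_PARAMS = 8.03e9  # assumption: approximate parameter count of Llama 3.1-8B

bf16_gb = weight_memory_gb(NUM_PARAMS, 16)
nf4_gb = weight_memory_gb(NUM_PARAMS, 4, block_size=32, scale_bits=16)
print(f"bf16 weights: ~{bf16_gb:.1f} GB")              # bf16 weights: ~16.1 GB
print(f"4-bit weights (+1 scale/32): ~{nf4_gb:.1f} GB")  # 4-bit weights (+1 scale/32): ~4.5 GB
```

The roughly 3.5× gap between the two results is the main reason QLoRA fits on smaller GPUs than plain LoRA.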
Installing Libraries: RESTART RUNTIME AFTER INSTALLATION
The easiest way to get a working environment is the NVIDIA NGC JAX container, which ships with most of the required dependencies preinstalled. To install the dependencies manually:
pip install 'jax[cuda13]' flax optax transformers datasets qwix huggingface_hub wandb
On top of the installation (either container or manual), you will need Tunix:
pip install git+https://github.com/google/tunix
# Install necessary libraries (skipped if Tunix is already installed)
import importlib.util

if importlib.util.find_spec("tunix") is None:
    print("Required packages not found. Running full installation...")
    %pip install 'jax[cuda13]' flax optax transformers datasets qwix huggingface_hub wandb
    %pip install git+https://github.com/google/tunix
Set your Hugging Face token
Create a Hugging Face access token in your Hugging Face account settings, copy it, and paste it into the field below. This token is required to authenticate with the Hugging Face Hub and to download the Llama 3.1 model and related assets. Once saved, it is stored in the HF_TOKEN environment variable of the running kernel and reused for the rest of the tutorial.
import os
from ipywidgets import Password, Button, HBox, Output
from IPython.display import display
from huggingface_hub import whoami

def _verify_token(token: str) -> str:
    """Return the Hugging Face username for `token`; raises if the token is invalid."""
    return whoami(token=token).get("name", "unknown")

token_box = Password(description="HF Token:", placeholder="paste your token here", layout={"width": "400px"})
save_btn = Button(description="Save", button_style="success")
out = Output()

def save_token(_):
    out.clear_output()
    with out:
        entered = token_box.value.strip()
        # Fall back to an HF_TOKEN that is already set in the environment
        token = entered or os.environ.get("HF_TOKEN")
        if not token:
            print("No HF token entered.")
            return
        try:
            user = _verify_token(token)
        except Exception as e:
            print(f"Token verification failed: {e}")
            return
        # Store the token only after it has been verified
        os.environ["HF_TOKEN"] = token
        source = "Token saved" if entered else "Using existing HF_TOKEN"
        print(f"{source}. Logged in as: {user}")

save_btn.on_click(save_token)
display(HBox([token_box, save_btn]), out)
Authenticate with Hugging Face
Verify that your Hugging Face token is set and valid. If the token is missing or rejected, an error is raised immediately instead of surfacing later during the model download.
# Prefer environment variable if already set
import os
from huggingface_hub import whoami

HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN is not set. Please create a Hugging Face access token "
        "and export it as an environment variable (or save it with the widget above)."
    )
try:
    user = whoami(token=HF_TOKEN)["name"]
except Exception as e:
    raise RuntimeError(f"HF_TOKEN is set but authentication failed: {e}") from e
print(f"Authenticated with Hugging Face as: {user} (via HF_TOKEN env)")
Acquire permission to use the gated model
Llama 3.1-8B is a gated model, so you must explicitly request access before you can download it. Visit the model page on Hugging Face, log in with the same account linked to your access token, and click Request access. You’ll need to agree to Meta’s license terms. Approval is usually granted quickly but is not automatic. This tutorial also loads its tokenizer from meta-llama/Llama-3.1-8B-Instruct, so make sure your account has access to that repository as well. Once access is approved, your Hugging Face token authorizes downloads transparently. If you skip this step, model downloads will fail even with a valid token.
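You can confirm access before starting a long model download by fetching one small file from each repository. The sketch below assumes two parts of the huggingface_hub library: hf_hub_download and the GatedRepoError exception. It uses the repository's config.json as the probe file. check_model_access is a hypothetical helper name, not part of any library.

```python
from typing import Optional

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import GatedRepoError


def check_model_access(repo_id: str, token: Optional[str] = None) -> bool:
    """Return True if `token` can download from `repo_id`, False if access is gated.

    Hypothetical helper: probes access by downloading the small config.json file.
    Other failures (for example, no network connection) are re-raised unchanged.
    """
    try:
        hf_hub_download(repo_id, "config.json", token=token)
        return True
    except GatedRepoError:
        return False
```

For example, call check_model_access("meta-llama/Llama-3.1-8B", os.environ["HF_TOKEN"]), then repeat the call with "meta-llama/Llama-3.1-8B-Instruct". Both calls should return True before you continue.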
Set up the environment
Import dependencies
Import the core libraries needed for training:
JAX/Flax: High-performance ML framework with automatic differentiation and XLA compilation
Optax: Gradient processing and optimization library for JAX
Transformers: Hugging Face library for tokenizers and model configurations
Datasets: Hugging Face library used here to load the UltraChat 200k training data
Qwix: Quantization and LoRA utilities for JAX models
Tunix: Training utilities, including PeftTrainer and AutoModel, for streamlined fine-tuning
# Imports
import time
import shutil
import numpy as np
import jax
import jax.numpy as jnp
import optax
from flax import nnx
import transformers
from datasets import load_dataset
import qwix
from tunix.models.automodel import AutoModel
from tunix.sft import peft_trainer, metrics_logger
import wandb
print(f"JAX {jax.__version__} | Devices: {jax.devices()}")
Create the device mesh
JAX uses a device mesh to define how computation and data are distributed across GPUs. The mesh assigns logical axis names to physical device dimensions, enabling FSDP (Fully Sharded Data Parallel) and TP (Tensor Parallel) strategies. The configuration adapts automatically based on available GPUs:
| GPUs | Mesh shape | Strategy |
|---|---|---|
| 8+ | (1, 4, 2) over (data, fsdp, tp) | data + FSDP + TP |
| 2–7 | (NUM_DEVICES, 1) over (fsdp, tp) | FSDP only |
| 1 | (1, 1) over (fsdp, tp) | No sharding |
The fsdp axis shards model parameters across devices to reduce per-device memory, while tp enables tensor-parallel splitting of large weight matrices.
# Create mesh for sharding
NUM_DEVICES = jax.local_device_count()
if NUM_DEVICES >= 8:
    # 4-way FSDP x 2-way tensor parallelism (data axis of size 1)
    mesh = jax.make_mesh((1, 4, 2), ("data", "fsdp", "tp"),
                         axis_types=(jax.sharding.AxisType.Auto,) * 3)
elif NUM_DEVICES >= 2:
    # Shard model across GPUs using FSDP
    mesh = jax.make_mesh((NUM_DEVICES, 1), ("fsdp", "tp"),
                         axis_types=(jax.sharding.AxisType.Auto,) * 2)
else:
    # Single GPU - no sharding, but keep axis names for API consistency
    mesh = jax.make_mesh((1, 1), ("fsdp", "tp"),
                         axis_types=(jax.sharding.AxisType.Auto,) * 2)
print(f"Devices: {NUM_DEVICES} | Mesh: {mesh.shape}")
Define model and training parameters
All training hyperparameters are defined in one place for easy experimentation. The key parameters control model selection, LoRA configuration, quantization, batch size, sequence length, and training duration. Set CLEAN_START = True to remove existing checkpoints before training, or False to resume from a previous run.
# Configuration
# Note: TF_CPP_MIN_LOG_LEVEL is typically read when JAX first initializes its backend,
# so setting it here (after jax.devices() has run) may have no effect. Set it before
# importing JAX to reliably silence C++ log output.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

MODEL_ID = "meta-llama/Llama-3.1-8B"
TOKENIZER_ID = "meta-llama/Llama-3.1-8B-Instruct"  # Instruct variant provides the chat template
LORA_RANK = 16
LORA_ALPHA = 32.0
USE_QUANTIZATION = True  # True for QLoRA (4-bit), False for regular LoRA
BATCH_SIZE = 2  # per device; the global batch is BATCH_SIZE * NUM_DEVICES
MAX_SEQ_LENGTH = 512
LEARNING_RATE = 1e-4
MAX_STEPS = 100
OUTPUT_DIR = "/workspace/llama3_lora_output"
CLEAN_START = True  # Set to False to resume from checkpoint

if CLEAN_START and os.path.exists(f"{OUTPUT_DIR}/checkpoints"):
    shutil.rmtree(f"{OUTPUT_DIR}/checkpoints")
    print("Removed old checkpoints (CLEAN_START=True)")
os.makedirs(OUTPUT_DIR, exist_ok=True)
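The helper below shows how much data a run actually touches for a given set of batch and step values. It assumes the global batch is BATCH_SIZE × NUM_DEVICES, as in the data-preparation code, and a 2,000-example training subset. run_budget is a hypothetical helper, and its token count includes padding.

```python
def run_budget(batch_size_per_device: int, num_devices: int, max_steps: int,
               max_seq_length: int, num_train_examples: int) -> dict:
    """Back-of-envelope size of a training run (padded tokens, not real tokens)."""
    global_batch = batch_size_per_device * num_devices
    examples_seen = global_batch * max_steps
    return {
        "global_batch": global_batch,
        "examples_seen": examples_seen,
        "epochs": examples_seen / num_train_examples,
        "padded_tokens": examples_seen * max_seq_length,
    }


# BATCH_SIZE=2, MAX_STEPS=100, MAX_SEQ_LENGTH=512, 2,000 training examples
for n in (1, 8):
    print(n, run_budget(2, n, 100, 512, 2000))
# 1 GPU: 200 examples (0.1 epoch); 8 GPUs: 1600 examples (0.8 epoch)
```

With the defaults, a single-GPU run sees only a tenth of the training subset. Increase MAX_STEPS if you want at least one full epoch.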
Load the model
Load the tokenizer
Load the tokenizer from the Instruct model variant, which includes the chat template for formatting conversations. The pad token is set to the EOS token if not already defined, which is standard for decoder-only models like Llama.
# Load tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_ID, token=HF_TOKEN)
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
print(f"Tokenizer loaded: {TOKENIZER_ID}")
Load the base model
AutoModel.from_pretrained() handles the complete model loading pipeline: downloading weights from the Hugging Face Hub (cached in model_download_path), converting them to JAX-compatible format, and initializing the model architecture with proper sharding across the mesh.
The model is loaded within the mesh context to ensure parameters are distributed correctly across devices from the start.
# Load model using AutoModel
print(f"Loading {MODEL_ID}...")
load_start = time.time()
with mesh:
    base_model, model_path = AutoModel.from_pretrained(
        MODEL_ID,
        mesh,
        model_download_path="/hf_cache",
    )
print(f"Model loaded in {time.time() - load_start:.1f}s")
print(f"Model path: {model_path}")
Apply LoRA / QLoRA
Low-Rank Adaptation (LoRA) freezes the base model weights and injects small trainable matrices into the attention and MLP layers. This cuts the number of trainable parameters to well under 1% of the model while typically preserving most of the quality of full fine-tuning.
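In the standard LoRA formulation, each targeted weight W stays frozen. The layer adds a low-rank update (x·A)·B, scaled by alpha/rank. B is initialized to zero, so training starts exactly from the base model's behavior. The sketch below shows that forward pass for one 4096×4096 projection, using this tutorial's rank 16 and alpha 32. It illustrates the math rather than Qwix's internal implementation, and the initialization shown is an assumption. lora_linear is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def lora_linear(x, w_frozen, lora_a, lora_b, rank, alpha):
    """y = x @ W + (alpha / rank) * (x @ A) @ B, with W frozen and only A, B trained."""
    return x @ w_frozen + (alpha / rank) * ((x @ lora_a) @ lora_b)


d_in, d_out, rank, alpha = 4096, 4096, 16, 32.0
k_w, k_a, k_x = jax.random.split(jax.random.PRNGKey(0), 3)

w = jax.random.normal(k_w, (d_in, d_out)) * 0.02                # frozen base weight
lora_a = jax.random.normal(k_a, (d_in, rank)) / jnp.sqrt(d_in)  # random init
lora_b = jnp.zeros((rank, d_out))                               # zero init
x = jax.random.normal(k_x, (2, d_in))

# With B = 0 the adapted layer starts out identical to the base layer.
print(jnp.allclose(lora_linear(x, w, lora_a, lora_b, rank, alpha), x @ w))  # True

full, lora = d_in * d_out, rank * (d_in + d_out)
print(f"trainable: {lora:,} vs {full:,} ({lora / full:.2%})")
# trainable: 131,072 vs 16,777,216 (0.78%)
```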
QLoRA adds 4-bit NF4 quantization on top of LoRA to further reduce memory:
Base weights are quantized to 4-bit NormalFloat format
Only the small LoRA adapter weights remain in full precision
tile_size=32 controls the quantization block size (it must divide the smallest weight dimension)
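The block-wise idea behind tile_size fits in a few lines of JAX. Each tile of 32 consecutive weights is divided by its own absolute maximum and then rounded to one of 16 levels. Only a 4-bit code per weight plus one scale per tile needs to be stored. This is a simplified stand-in, not Qwix's NF4 kernel, and it differs in three ways. It uses evenly spaced levels instead of NF4's normal-quantile levels. It keeps the codes in uint8 without bit-packing. It tiles along the last axis. quantize_blockwise and dequantize_blockwise are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Simplified stand-in for NF4: 16 evenly spaced levels in [-1, 1].
# Real NF4 places its 16 levels at quantiles of a standard normal distribution.
LEVELS = jnp.linspace(-1.0, 1.0, 16)


def quantize_blockwise(w, tile_size=32):
    """Quantize a 2D weight in tiles of `tile_size` along its last axis.

    Returns 4-bit codes (held in uint8 here, not bit-packed) and one absmax scale per tile.
    """
    rows, cols = w.shape
    if cols % tile_size:
        raise ValueError(f"tile_size={tile_size} must divide the last dimension ({cols})")
    tiles = w.reshape(rows, cols // tile_size, tile_size)
    scale = jnp.max(jnp.abs(tiles), axis=-1, keepdims=True)  # one scale per tile
    scale = jnp.where(scale == 0, 1.0, scale)                # avoid dividing by zero
    normed = tiles / scale                                   # values now lie in [-1, 1]
    codes = jnp.argmin(jnp.abs(normed[..., None] - LEVELS), axis=-1).astype(jnp.uint8)
    return codes, scale


def dequantize_blockwise(codes, scale):
    rows, n_tiles, tile_size = codes.shape
    return (LEVELS[codes.astype(jnp.int32)] * scale).reshape(rows, n_tiles * tile_size)


w = jax.random.normal(jax.random.PRNGKey(0), (64, 128))
codes, scale = quantize_blockwise(w)
w_hat = dequantize_blockwise(codes, scale)
print(codes.shape, scale.shape)  # (64, 4, 32) (64, 4, 1)
```

Because neighboring levels are 2/15 apart, each reconstructed weight is within scale/15 of the original. Smaller tiles mean tighter scales and lower error, at the cost of storing more scales.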
Target modules specify which layers receive LoRA adapters using regex patterns matching attention projections (q_proj, k_proj, v_proj, o_proj) and MLP layers (gate_proj, up_proj, down_proj).
# Apply QLoRA / LoRA
# Regex matching the attention and MLP projection layers that receive adapters
target_modules = ".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*up_proj|.*down_proj"
lora_provider = qwix.LoraProvider(
    module_path=target_modules,
    rank=LORA_RANK,
    alpha=LORA_ALPHA,
    weight_qtype="nf4" if USE_QUANTIZATION else None,
    tile_size=32 if USE_QUANTIZATION else None,
)

# Dummy inputs for the forward pass that Qwix runs while applying LoRA
dummy_input = {
    'input_tokens': jnp.ones((1, 128), dtype=jnp.int32),
    'positions': jnp.arange(128)[None, :],
    'cache': None,
    'attention_mask': jnp.ones((1, 128, 128), dtype=jnp.bool_),
}

print(f"Applying {'QLoRA' if USE_QUANTIZATION else 'LoRA'} (rank={LORA_RANK})...")
lora_model = qwix.apply_lora_to_model(
    base_model, lora_provider,
    rngs=nnx.Rngs(params=0),  # For reproducible LoRA weight initialization
    **dummy_input
)

# Shard the updated model state (including the new LoRA parameters) across the mesh
with mesh:
    state = nnx.state(lora_model)
    sharded = jax.lax.with_sharding_constraint(state, nnx.get_partition_spec(state))
    nnx.update(lora_model, sharded)
print(f"{'QLoRA' if USE_QUANTIZATION else 'LoRA'} applied!")
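The sketch below counts LoRA parameters for the seven projection types matched by target_modules, putting the size of the adapters in concrete numbers. It assumes Llama 3.1-8B's published shapes: hidden size 4096, MLP size 14336, 8 key/value heads of dimension 128, and 32 decoder layers. It also assumes the standard factorization, in which each adapted weight gets an A matrix of shape (d_in, rank) and a B matrix of shape (rank, d_out). Qwix's internal layout may differ, so treat the result as an estimate. lora_param_count is a hypothetical helper.

```python
# Assumed Llama 3.1-8B shapes: hidden 4096, MLP 14336, 8 KV heads x 128 dims, 32 layers.
HIDDEN, MLP, KV_DIM, LAYERS = 4096, 14336, 8 * 128, 32

# (d_in, d_out) of every module matched by target_modules, per decoder layer.
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}


def lora_param_count(rank: int) -> int:
    """Each adapted (d_in, d_out) weight gets A: (d_in, rank) and B: (rank, d_out)."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in TARGET_SHAPES.values())
    return per_layer * LAYERS


n = lora_param_count(16)
print(f"{n:,} trainable LoRA parameters (~{n / 8.03e9:.2%} of ~8.03B)")
# 41,943,040 trainable LoRA parameters (~0.52% of ~8.03B)
```

With LORA_RANK = 16 this comes to about 42 million trainable parameters, roughly half a percent of the model. Doubling the rank doubles the count.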
Prepare the training data
Load the UltraChat 200k dataset, a large collection of multi-turn conversations commonly used for instruction fine-tuning. For this tutorial, a subset of 2,000 training and 200 evaluation examples is used.
The data processing pipeline applies the chat template to format conversations with special tokens. It then tokenizes with truncation and padding to MAX_SEQ_LENGTH, and records an attention mask that marks the padding tokens. Training uses an infinite generator that cycles through the data. Evaluation uses a finite iterator that makes a single pass over the eval set and drops any final partial batch. Batch size is scaled by NUM_DEVICES for data parallelism.
# Prepare dataset
dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft").select(range(2000))
eval_dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="test_sft").select(range(200))

def tokenize(ex):
    text = tokenizer.apply_chat_template(ex["messages"], tokenize=False)
    # The chat template already starts with <|begin_of_text|>, so don't add a second BOS
    tok = tokenizer(text, max_length=MAX_SEQ_LENGTH, padding="max_length",
                    truncation=True, add_special_tokens=False)
    return {
        "input_tokens": np.array(tok["input_ids"], dtype=np.int32),
        "input_mask": np.array(tok["attention_mask"], dtype=bool),
    }

train_data = [tokenize(ex) for ex in dataset]
eval_data = [tokenize(ex) for ex in eval_dataset]

# Infinite generator for training: cycles through the data, wrapping to the start
# whenever the next batch would run past the end
def train_batches(data, bs):
    i = 0
    while True:
        if i + bs > len(data):
            i = 0
        batch = data[i:i + bs]
        yield {k: np.stack([x[k] for x in batch]) for k in batch[0]}
        i += bs

# Reusable eval dataset - returns a fresh finite iterator each time (drops a final partial batch)
class EvalDataset:
    def __init__(self, data, bs):
        self.data = data
        self.bs = bs

    def __iter__(self):
        for i in range(0, len(self.data), self.bs):
            batch = self.data[i:i + self.bs]
            if len(batch) == self.bs:
                yield {k: np.stack([x[k] for x in batch]) for k in batch[0]}

train_ds = train_batches(train_data, BATCH_SIZE * NUM_DEVICES)
eval_ds = EvalDataset(eval_data, BATCH_SIZE * NUM_DEVICES)
print(f"Train: {len(train_data)} examples | Eval: {len(eval_data)} examples | Batch size: {BATCH_SIZE * NUM_DEVICES}")
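The sketch below shows the intended batching behavior on toy data. Training batches wrap around to the start of the data once the next batch would run past the end. Evaluation makes one pass and yields only full batches. cycling_batches and full_batches are hypothetical names that operate on plain Python lists.

```python
from itertools import islice


def cycling_batches(data, bs):
    """Yield consecutive batches forever, restarting at 0 when the next batch would overrun."""
    i = 0
    while True:
        if i + bs > len(data):
            i = 0
        yield data[i:i + bs]
        i += bs


def full_batches(data, bs):
    """One pass over `data`, dropping a trailing partial batch."""
    for i in range(0, len(data) - bs + 1, bs):
        yield data[i:i + bs]


toy = list(range(7))
print(list(islice(cycling_batches(toy, 3), 4)))  # [[0, 1, 2], [3, 4, 5], [0, 1, 2], [3, 4, 5]]
print(list(full_batches(toy, 3)))                # [[0, 1, 2], [3, 4, 5]]
```

With 200 eval examples, the number of dropped examples is 200 mod (BATCH_SIZE × NUM_DEVICES). For example, with 8 GPUs and BATCH_SIZE = 2 the eval pass uses 12 batches and skips 8 examples.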
Provide the training configuration
The model expects specific input formats: position indices for rotary embeddings and a 3D causal attention mask [batch, seq, seq] combining causal attention with padding. The gen_model_input function constructs these from the tokenized batch.
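A tiny worked example makes these two inputs concrete. The sketch below copies the build_positions and build_causal_mask helpers from the code cell in this section. It applies them to two length-4 sequences, one right-padded and one left-padded (True marks a real token). Position ids count only real tokens, with padding slots clipped to valid indices. In the mask, every query row excludes the padding key column.

```python
import jax.numpy as jnp


# Copies of the tutorial's helpers, so this example runs on its own.
def build_positions(mask):
    return jnp.clip(jnp.cumsum(mask, axis=-1) - 1, 0).astype(jnp.int32)


def build_causal_mask(mask):
    n = mask.shape[-1]
    return jnp.tril(jnp.ones((n, n), dtype=jnp.bool_))[None] & mask[:, None, :]


# Row 0 is right-padded, row 1 is left-padded.
mask = jnp.array([[True, True, True, False],
                  [False, True, True, True]])

print(build_positions(mask))
# [[0 1 2 2]
#  [0 0 1 2]]

print(build_causal_mask(mask)[0].astype(jnp.int32))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 0]]
```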
PeftTrainer orchestrates the training loop with an AdamW optimizer, periodic checkpointing, TensorBoard-compatible metrics logging, and evaluation every eval_every_n_steps steps.
# Input processing helpers
def build_positions(mask):
    # Position ids count only real tokens; padding slots are clipped to valid indices
    return jnp.clip(jnp.cumsum(mask, axis=-1) - 1, 0).astype(jnp.int32)

def build_causal_mask(mask):
    # [batch, seq, seq]: query i may attend to key j if j <= i and j is not padding
    n = mask.shape[-1]
    return jnp.tril(jnp.ones((n, n), dtype=jnp.bool_))[None] & mask[:, None, :]

def gen_model_input(x):
    # Use the tokenizer's attention mask instead of comparing against pad_token_id:
    # pad_token was set to eos_token, so an id comparison would also mask the real
    # end-of-turn tokens inside each conversation.
    mask = x["input_mask"]
    return {
        "input_tokens": x["input_tokens"],
        "positions": build_positions(mask),
        "attention_mask": build_causal_mask(mask),
        "input_mask": x["input_mask"],
    }

# Create trainer
trainer = peft_trainer.PeftTrainer(
    lora_model,
    optax.adamw(LEARNING_RATE),
    peft_trainer.TrainingConfig(
        max_steps=MAX_STEPS,
        eval_every_n_steps=25,  # Evaluate every 25 steps
        checkpoint_root_directory=f"{OUTPUT_DIR}/checkpoints",
        metrics_logging_options=metrics_logger.MetricsLoggerOptions(log_dir=f"{OUTPUT_DIR}/logs"),
    ),
).with_gen_model_input_fn(gen_model_input)
print("Trainer ready!")
Run the training
This block launches the PEFT training loop. It runs a baseline evaluation first to measure initial loss, then trains for MAX_STEPS steps with periodic evaluation. The first step is slower due to XLA JIT compilation, which is cached for subsequent steps.
# Training with progress
NUM_EVAL_BATCHES = len(eval_data) // (BATCH_SIZE * NUM_DEVICES)

class Progress:
    def __init__(self, n):
        self.n = n
        self.t0 = None
        self.eval_count = 0
        self.eval_started = False

    def on_train_start(self, _):
        self.t0 = time.time()
        print("Training (first step includes JIT)...")

    def on_train_end(self, _):
        print(f"\nDone in {time.time()-self.t0:.0f}s")

    def on_train_step_start(self, _):
        self.eval_started = False

    def on_train_step_end(self, _, step, loss, dt):
        if step <= 2 or step % 10 == 0:
            print(f"Step {step}/{self.n} | Loss: {float(loss):.4f} | {dt:.1f}s/step")

    def on_eval_step_start(self, _):
        if not self.eval_started:
            self.eval_count += 1
            label = "Baseline eval" if self.eval_count == 1 else f"Eval #{self.eval_count}"
            print(f"{label}...", end=" ", flush=True)
            self.eval_started = True

    def on_eval_step_end(self, _, eval_loss):
        avg_loss = float(eval_loss) / NUM_EVAL_BATCHES
        print(f"loss: {avg_loss:.4f} (avg over {NUM_EVAL_BATCHES} batches)")

trainer.training_hooks = Progress(MAX_STEPS)
print("Starting (baseline eval + JIT compilation first)...")
with mesh:
    trainer.train(train_ds, eval_ds)
print(f"Checkpoints: {OUTPUT_DIR}/checkpoints")
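The slow first step comes from jax.jit tracing and compiling the step function. Later calls with the same input shapes and dtypes reuse the compiled executable, while a new shape triggers another compile. The toy sketch below counts traces with a Python side effect, which runs only during tracing. It illustrates JAX's caching behavior in general, not Tunix's training step. Padding every example to MAX_SEQ_LENGTH keeps batch shapes fixed, so training steps after the first can reuse the compiled step.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def step(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return jnp.tanh(x) * 2.0


step(jnp.ones((2, 512)))  # first call: trace + XLA compile
step(jnp.ones((2, 512)))  # same shape and dtype: cached executable is reused
print(trace_count)        # 1
step(jnp.ones((2, 513)))  # new shape: traced and compiled again
print(trace_count)        # 2
```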
Visualize training with TensorBoard
To monitor training loss and other metrics, launch TensorBoard in a separate terminal:
tensorboard --logdir=/workspace/llama3_lora_output/logs --host 0.0.0.0 --port 6006 --load_fast=false
Then open http://127.0.0.1:6006/ in your browser.
Test inference
A quick sanity check that the fine-tuned model produces coherent output. The code below formats a prompt with the Llama 3.1 chat template and tokenizes it. It then runs greedy autoregressive generation for up to 10 tokens, stopping early if the model emits the EOS token. A sensible answer is a quick signal that the adapted model runs end to end and still produces reasonable predictions.
Note: this is naive autoregressive generation without KV-caching, so each step recomputes attention over the full sequence. For production use, consider a dedicated serving framework with KV-cache support.
# Initialize wandb
wandb.init()

# Quick inference test with the fine-tuned LoRA model
prompt = "What is the capital of France?"
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Tokenize (the chat template already includes <|begin_of_text|>, so skip the extra BOS)
tokens = jnp.array(tokenizer(text, add_special_tokens=False)["input_ids"], dtype=jnp.int32)[None, :]

# Greedy autoregressive generation
max_new_tokens = 10
generated_ids = []
eos_token_id = tokenizer.eos_token_id
for _ in range(max_new_tokens):
    seq_len = tokens.shape[1]
    positions = jnp.arange(seq_len)[None, :]
    attention_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))[None, :]
    with mesh:
        output = lora_model(tokens, positions, None, attention_mask)
    logits = output[0] if isinstance(output, tuple) else output
    next_token_id = int(jnp.argmax(logits[0, -1]))
    generated_ids.append(next_token_id)
    if next_token_id == eos_token_id:
        break
    tokens = jnp.concatenate([tokens, jnp.array([[next_token_id]], dtype=jnp.int32)], axis=1)

# Decode all generated tokens
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
print(f"Prompt: {prompt}")
print(f"Generated ({len(generated_ids)} tokens): '{generated_text}'")
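The sketch below puts numbers on the KV-caching note. It counts how many token positions the model processes during greedy decoding, with and without a cache. It is a simplified cost model: it ignores that attention cost per position also grows with sequence length. The 40-token prompt length is a made-up example, and tokens_processed is a hypothetical helper.

```python
def tokens_processed(prompt_len: int, new_tokens: int, kv_cache: bool) -> int:
    """Count token positions pushed through the model during greedy decoding.

    Without a cache, step t re-runs the whole sequence (prompt + t generated tokens).
    With a cache, the prompt is processed once and each later step feeds one new token.
    """
    if new_tokens <= 0:
        return 0
    if kv_cache:
        return prompt_len + (new_tokens - 1)
    return sum(prompt_len + t for t in range(new_tokens))


# Hypothetical 40-token prompt; the real length depends on the chat template.
for n in (10, 100):
    print(n, tokens_processed(40, n, kv_cache=False), tokens_processed(40, n, kv_cache=True))
# 10 445 49
# 100 8950 139
```

Without a cache, the work grows quadratically with the number of generated tokens. That is negligible for this 10-token sanity check but adds up quickly for longer outputs.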