Knowledge Distillation with Tunix: Gemma 7B to Gemma 2B


Install necessary libraries

This notebook demonstrates how to use the Tunix library to perform knowledge distillation. Specifically, we will use logit-based distillation to transfer knowledge from a larger, more capable teacher model (Gemma 7B) to a smaller, more efficient student model (Gemma 2B).

What is Knowledge Distillation?

Knowledge distillation is a model compression technique where a smaller “student” model is trained to mimic the behavior of a larger, pre-trained “teacher” model. Instead of training the student solely on the ground-truth labels, we also train it to replicate the teacher’s outputs.

Logit-Based Distillation

In this specific strategy, the student model learns to match the teacher’s logits (the raw, unnormalized outputs before the final softmax layer). By doing so, the student learns the nuanced probability distribution that the teacher model has learned, which is often more informative than the hard labels alone.
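Distillation setups typically make this "nuanced probability distribution" visible by dividing the logits by a temperature T before the softmax. This notebook does the same later with its TEMPERATURE setting. A higher temperature flattens the distribution and exposes how the teacher ranks the less likely tokens. The sketch below uses plain NumPy and made-up logits over a 3-token vocabulary. It only illustrates the effect of the temperature and is not Tunix code.

```python
import numpy as np


def softmax_with_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Softmax over the last axis after dividing the logits by a temperature."""
    scaled = logits / temperature
    # Subtract the max for numerical stability; this does not change the result.
    scaled = scaled - scaled.max(axis=-1, keepdims=True)
    exp = np.exp(scaled)
    return exp / exp.sum(axis=-1, keepdims=True)


# Made-up teacher logits over a toy 3-token vocabulary.
teacher_logits = np.array([4.0, 2.0, 0.0])

p_sharp = softmax_with_temperature(teacher_logits, temperature=1.0)
p_soft = softmax_with_temperature(teacher_logits, temperature=2.0)

print("T=1:", p_sharp.round(3))
print("T=2:", p_soft.round(3))
```

With these logits, the top token's probability drops from about 0.87 at T = 1 to about 0.67 at T = 2, and the two lower-ranked tokens gain the difference. That relative ranking is the extra signal the student can learn from.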

The core components we’ll use are:

  • Teacher Model: Gemma 7B

  • Student Model: Gemma 2B

  • Distillation Strategy: tunix.distillation.strategies.LogitStrategy

  • Trainer: tunix.distillation.DistillationTrainer

In this tutorial, we use a v5e-8 TPU. Let’s get started!

!pip install -q kagglehub

!pip install -q tensorflow
!pip install -q tensorboardX
!pip install -q grain-nightly
!pip install -q datasets
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

import gc
import os

from flax import nnx
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
from tunix.distillation import distillation_trainer
from tunix.distillation import strategies
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma import model as gemma_lib
from tunix.models.gemma import params as params_lib
from tunix.examples.data import translation_dataset as data_lib

Utility function to check HBM usage

from tunix.sft import utils

show_hbm_usage = utils.show_hbm_usage
show_hbm_usage()

# --- Data ---
BATCH_SIZE = 4
MAX_TARGET_LENGTH = 128
NUM_TRAIN_EPOCHS = 1

# --- Model ---
MESH = [(1, 8), ("fsdp", "tp")]

# --- Training ---
MAX_STEPS = 200
EVAL_EVERY_N_STEPS = 50
LEARNING_RATE = 1e-4

# --- Distillation ---
TEMPERATURE = 2.0  # Softens the teacher's and student's output distributions
ALPHA = 0.7  # Balances distillation loss and student's own task loss

# --- Checkpointing ---
TEACHER_CKPT_DIR = "/content/intermediate_ckpt/teacher/"
STUDENT_CKPT_DIR = "/content/intermediate_ckpt/student/"

First, we need to load our teacher and student models. We’ll use Gemma 7B as the teacher and Gemma 2B as the student.

Important: You must have a Kaggle account and agree to the Gemma license to download the models. The first time you run this, you will be prompted to log in to Kaggle.

# Log in to Kaggle
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
  kagglehub.login()


def load_and_save_model(model_handle, version, ckpt_dir):
  """Loads a model from Kaggle, saves it locally, and cleans up memory."""
  print(f"Loading {model_handle}...")
  kaggle_ckpt_path = kagglehub.model_download(model_handle)
  ckpt_version = "2b-it"
  if "7b" in version:
    ckpt_version = "7b-it"
  # Temporarily set the default device to CPU for loading the full model
  with jax.default_device(jax.devices("cpu")[0]):
    params = params_lib.load_and_format_params(
        os.path.join(kaggle_ckpt_path, ckpt_version)
    )
    gemma = gemma_lib.Gemma.from_params(params, version=ckpt_version)

  print(f"Saving checkpoint to {ckpt_dir}...")
  checkpointer = ocp.StandardCheckpointer()
  _, state = nnx.split(gemma)
  checkpointer.save(os.path.join(ckpt_dir, "state"), state)
  checkpointer.wait_until_finished()
  # Clean up to save memory
  del params
  del gemma
  del state
  gc.collect()
  print(f"Finished processing {model_handle}.")


# Load Teacher Model (Gemma 7B)
load_and_save_model(
    "google/gemma/flax/1.1-7b-it", "1.1-7b-it", TEACHER_CKPT_DIR
)

# Load Student Model (Gemma 2B)
load_and_save_model(
    "google/gemma/flax/1.1-2b-it", "1.1-2b-it", STUDENT_CKPT_DIR
)

Now that the checkpoints are saved locally, we can load them into sharded models. Sharding distributes each model's weights and computation across the TPU devices in the mesh, which is essential for training models of this size efficiently.
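With MESH = [(1, 8), ("fsdp", "tp")], the 8 devices form a 1 x 8 grid, so any dimension mapped to the "tp" (tensor-parallel) axis is split 8 ways. The NumPy sketch below simulates this on a single host for one matrix multiplication. It assumes a column-wise split of the weight. In the real model, the sharding annotations in the Tunix model definition decide which dimension of each Gemma weight is split, and nnx.get_named_sharding turns them into concrete shardings.

```python
import numpy as np

# Mirrors MESH = [(1, 8), ("fsdp", "tp")]: 1 device along "fsdp", 8 along "tp".
mesh_shape = (1, 8)
num_tp_shards = mesh_shape[1]

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16))   # activations: (batch, d_model)
w = rng.normal(size=(16, 32))  # a weight matrix: (d_model, d_ff)

# Split the weight column-wise into one shard per "tp" device.
w_shards = np.split(w, num_tp_shards, axis=1)  # 8 shards of shape (16, 4)

# Each simulated device computes its slice of the output from its own shard only...
partial_outputs = [x @ shard for shard in w_shards]

# ...and concatenating the slices reproduces the unsharded result.
y_sharded = np.concatenate(partial_outputs, axis=1)
y_full = x @ w
print(np.allclose(y_sharded, y_full))  # True
```

Each device holds only 1/8 of the weight, which is what lets models larger than a single chip's memory fit on the slice.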

def get_sharded_model(ckpt_path, model_config, mesh):
  """Loads a checkpoint into a sharded model."""
  abs_gemma: nnx.Module = nnx.eval_shape(
      lambda: gemma_lib.Gemma(model_config, rngs=nnx.Rngs(params=0))
  )
  abs_state = nnx.state(abs_gemma)
  abs_state = jax.tree.map(
      lambda a, s: jax.ShapeDtypeStruct(a.shape, jnp.bfloat16, sharding=s),
      abs_state,
      nnx.get_named_sharding(abs_state, mesh),
  )
  checkpointer = ocp.StandardCheckpointer()
  restored_params = checkpointer.restore(ckpt_path, target=abs_state)

  graph_def, _ = nnx.split(abs_gemma)
  gemma = nnx.merge(graph_def, restored_params)
  return gemma


mesh = jax.make_mesh(*MESH, axis_types=(jax.sharding.AxisType.Auto,) * len(MESH[0]))

# Create Teacher Model
print("Creating sharded teacher model (Gemma 7B)...")
teacher_config = gemma_lib.ModelConfig.gemma_7b()
teacher_model = get_sharded_model(
    os.path.join(TEACHER_CKPT_DIR, "state"), teacher_config, mesh
)
print("Teacher model created.")
# nnx.display(teacher_model) # Optional: view model structure

# Create Student Model
print("\nCreating sharded student model (Gemma 2B)...")
student_config = gemma_lib.ModelConfig.gemma_2b()
student_model = get_sharded_model(
    os.path.join(STUDENT_CKPT_DIR, "state"), student_config, mesh
)
print("Student model created.")
# nnx.display(student_model) # Optional: view model structure

show_hbm_usage()
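A back-of-the-envelope estimate helps interpret the show_hbm_usage() output. The sketch makes two assumptions. First, it uses the approximate parameter counts published for the first-generation Gemma models: about 8.54B for "7B" and 2.51B for "2B", embeddings included. Second, it uses the bfloat16 dtype that get_sharded_model restores into. It counts weights only. Activations, the AdamW moment estimates for the student, and compilation buffers all add to the real figure.

```python
# Approximate parameter counts (assumption: first-generation Gemma, embeddings included).
PARAM_COUNTS = {
    "teacher (Gemma 7B)": 8.54e9,
    "student (Gemma 2B)": 2.51e9,
}
BYTES_PER_PARAM = 2  # bfloat16, as requested in get_sharded_model
NUM_DEVICES = 8      # the 1 x 8 mesh on a v5e-8
GIB = 1024**3

total_bytes = 0.0
for name, count in PARAM_COUNTS.items():
    nbytes = count * BYTES_PER_PARAM
    total_bytes += nbytes
    print(f"{name}: {nbytes / GIB:.1f} GiB of weights")

# Assumes the weights are spread evenly across all devices.
per_device = total_bytes / NUM_DEVICES
print(f"total: {total_bytes / GIB:.1f} GiB, about {per_device / GIB:.1f} GiB per device")
```

The teacher's weights alone come to about 17 GB (15.9 GiB). That is more than the 16 GB of HBM on a single v5e chip, which is why both models are sharded across the mesh instead of being placed on one device.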
print("Loading tokenizer...")
gemma_tokenizer_path = os.path.join(
    kagglehub.model_download("google/gemma/flax/1.1-2b-it"), "tokenizer.model"
)
gemma_tokenizer = tokenizer_lib.Tokenizer(
    tokenizer_type='sentencepiece',
    tokenizer_path=gemma_tokenizer_path
)
print("Tokenizer loaded.")

print("\nCreating datasets...")
train_ds, validation_ds = data_lib.create_datasets(
    dataset_name="mtnt/en-fr",
    global_batch_size=BATCH_SIZE,
    max_target_length=MAX_TARGET_LENGTH,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    tokenizer=gemma_tokenizer,
    instruct_tuned=True,
)
print("Datasets created.")

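The distillation setup needs three helper functions. The first two are passed to the LogitStrategy, and the third is registered with the trainer:

  1. model_forward_fn: A function that performs a forward pass for a given model and returns its logits. Both models are from the Gemma family and share the same tokenizer and vocabulary. We can therefore use the same function for both, and their logits can be compared token by token.

  2. labels_fn: A function that builds the ground-truth labels for the standard cross-entropy loss. These are one-hot encodings of the next token at each position, with padding positions masked out.

  3. gen_model_input_fn: A helper function that converts each batch from the data loader into the keyword arguments used by the functions above. This includes the positions and the causal attention mask derived from the padding mask.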
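These functions must agree on how logits and labels line up. The logits at position t are scored against the token at position t + 1. As a result, the last prediction and the first token are dropped, and padded targets get all-zero label rows. The toy NumPy sketch below shows this shift-by-one alignment, which model_forward_fn and labels_fn implement. Its 5-token vocabulary and 4-token sequence are made up.

```python
import numpy as np

VOCAB = 5  # toy vocabulary size (the notebook uses student_config.num_embed)

# One sequence of 4 tokens; the last position is padding (mask = 0).
input_tokens = np.array([[2, 4, 1, 0]])
input_mask = np.array([[1, 1, 1, 0]])

# A forward pass yields one logit vector per input position: (batch, seq, vocab).
logits = np.zeros((1, 4, VOCAB))

# Position t predicts token t + 1, so drop the last prediction...
pred_logits = logits[:, :-1, :]       # (1, 3, VOCAB)
# ...and drop the first token from the targets.
target_tokens = input_tokens[:, 1:]   # [[4, 1, 0]]
target_mask = input_mask[:, 1:]       # [[1, 1, 0]]

# One-hot labels with padded targets zeroed out, mirroring labels_fn.
labels = np.eye(VOCAB)[target_tokens] * target_mask[..., None]
print(pred_logits.shape, labels.shape)  # (1, 3, 5) (1, 3, 5)
```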

VOCAB_SIZE = student_config.num_embed


def model_forward_fn(
    model: nnx.Module,
    input_tokens: jax.Array,
    input_mask: jax.Array,
    positions: jax.Array,
    attention_mask: jax.Array,
):
  """Performs a forward pass and returns the logits."""
  logits, _ = model(
      input_tokens,
      positions,
      None,  # No KV cache is needed during training.
      attention_mask,
  )
  # Drop the last position's prediction: it would predict a token beyond the
  # end of the sequence, which has no target.
  return logits[:, :-1, :]


def labels_fn(
    input_tokens: jax.Array,
    input_mask: jax.Array,
    **kwargs,
):
  """Creates one-hot encoded labels for the next-token prediction task."""
  target_tokens = input_tokens[:, 1:]
  target_mask = input_mask[:, 1:]
  labels = jax.nn.one_hot(target_tokens, VOCAB_SIZE)
  # Mask out the padding tokens from the loss calculation.
  return labels * target_mask.astype(labels.dtype)[..., None]


def gen_model_input_fn(x: distillation_trainer.TrainingInput):
  """Formats a batch from the data loader into the model's expected input format."""
  pad_mask = x.input_tokens != gemma_tokenizer.pad_id()
  positions = utils.build_positions_from_mask(pad_mask)
  attention_mask = utils.make_causal_attn_mask(pad_mask)
  return {
      'input_tokens': x.input_tokens,
      'input_mask': x.input_mask,
      'positions': positions,
      'attention_mask': attention_mask,
  }
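gen_model_input_fn relies on utils.build_positions_from_mask and utils.make_causal_attn_mask. The NumPy sketch below shows what such helpers are expected to compute for a padded batch. The function names in it are hypothetical. The exact Tunix implementation may differ in details, such as the position value it assigns to padding slots.

```python
import numpy as np


def positions_from_mask(pad_mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: position index of each token, counting only real tokens."""
    # Cumulative count of real tokens minus one; clamp so leading padding gets 0.
    return np.maximum(np.cumsum(pad_mask, axis=-1) - 1, 0)


def causal_attention_mask(pad_mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (batch, seq, seq) mask, True where query i may attend to key j."""
    seq_len = pad_mask.shape[-1]
    # Lower-triangular matrix: query i sees keys 0..i.
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    # Additionally hide padded keys from every query.
    return causal[None, :, :] & pad_mask[:, None, :].astype(bool)


# Two sequences of length 4: the first right-padded, the second left-padded.
pad_mask = np.array([
    [True, True, True, False],
    [False, True, True, True],
])
print(positions_from_mask(pad_mask))
# [[0 1 2 2]
#  [0 0 1 2]]
print(causal_attention_mask(pad_mask)[0].astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 0]]
```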

Now we can assemble all the components. We’ll instantiate the LogitStrategy, configure the DistillationTrainer, and start the training process. The trainer will handle the distributed training loop across the available TPU cores.
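The LogitStrategy is configured with TEMPERATURE and ALPHA. The sketch below shows the common Hinton-style way to combine them, in plain NumPy. It computes a temperature-softened KL divergence from teacher to student and scales it by T^2, so its gradients stay comparable in size as T changes. It then mixes that term with the ordinary cross-entropy on the labels. This formulation is an assumption made for illustration. Consult the Tunix source for the exact weighting and averaging that LogitStrategy uses.

```python
import numpy as np


def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable log-softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def logit_distillation_loss(student_logits, teacher_logits, labels, temperature, alpha):
    """Assumed loss: alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE."""
    # Temperature-softened log-probabilities for both models.
    log_p_student_t = log_softmax(student_logits / temperature)
    log_p_teacher_t = log_softmax(teacher_logits / temperature)
    p_teacher_t = np.exp(log_p_teacher_t)
    # KL(teacher || student) per position, averaged, then scaled by T^2.
    kl = (p_teacher_t * (log_p_teacher_t - log_p_student_t)).sum(axis=-1)
    kl_loss = kl.mean() * temperature**2
    # Standard cross-entropy against the one-hot labels, at temperature 1.
    ce = -(labels * log_softmax(student_logits)).sum(axis=-1)
    ce_loss = ce.mean()
    return alpha * kl_loss + (1.0 - alpha) * ce_loss


# Usage with made-up logits: batch 2, sequence length 3, vocabulary 5.
rng = np.random.default_rng(0)
student_logits = rng.normal(size=(2, 3, 5))
teacher_logits = rng.normal(size=(2, 3, 5))
labels = np.eye(5)[rng.integers(0, 5, size=(2, 3))]
print(logit_distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.7))
```

With ALPHA = 0.7, 70% of the loss comes from matching the teacher and 30% from the ground-truth labels. If the student's logits equal the teacher's, the KL term vanishes and only the cross-entropy part remains.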

# 1. Setup the distillation strategy
strategy = strategies.LogitStrategy(
    student_forward_fn=model_forward_fn,
    teacher_forward_fn=model_forward_fn,
    labels_fn=labels_fn,
    temperature=TEMPERATURE,
    alpha=ALPHA,
)

# 2. Setup the training configuration
config = distillation_trainer.TrainingConfig(
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=MAX_STEPS,
)

# 3. Setup the optimizer
optimizer = optax.adamw(LEARNING_RATE)


# Set teacher model in eval mode
teacher_model.eval()
# Set student model in train mode
student_model.train()
# 4. Instantiate the trainer
trainer = distillation_trainer.DistillationTrainer(
    student_model=student_model,
    teacher_model=teacher_model,
    strategy=strategy,
    optimizer=optimizer,
    training_config=config,
).with_gen_model_input_fn(gen_model_input_fn)

# 5. Run training within the mesh context. The first couple of training steps
# can take up to 5 minutes to finish (mostly one-time compilation), so please be
# patient. If steps keep taking much longer (e.g., more than 10 minutes per
# step), please open a bug.
print("Starting distillation training...")
with mesh:
  trainer.train(train_ds, validation_ds)
print("Training complete.")

After training, the student model should have improved its ability to perform the translation task by learning from the teacher. Let’s test it with a few sample prompts.

print("Setting up sampler for evaluation...")
sampler = sampler_lib.Sampler(
    transformer=student_model,
    tokenizer=gemma_tokenizer,
    cache_config=sampler_lib.CacheConfig(
        cache_size=MAX_TARGET_LENGTH + 64,
        num_layers=student_config.num_layers,
        num_kv_heads=student_config.num_kv_heads,
        head_dim=student_config.head_dim,
    ),
)
input_batch = [
    "Translate this into French:\nHello, my name is Morgane.\n",
    "Translate this into French:\nThis dish is delicious!\n",
    "Translate this into French:\nI am a student.\n",
]

print("Generating translations with the distilled student model...")
with mesh:
  out_data = sampler(
      input_strings=input_batch,
      max_generation_steps=20,
  )

print("\n--- Evaluation Results ---")
for input_string, out_string in zip(input_batch, out_data.text):
  print("----------------------")
  print(f"Prompt:\n{input_string}")
  print(f"Distilled Student's Output:\n{out_string}")
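The cache_size passed to CacheConfig bounds how many tokens the sampler can hold per sequence. The arithmetic below estimates the resulting KV-cache footprint for the three prompts. It assumes the Gemma 2B attention dimensions (18 layers, a single shared KV head, head dimension 256) and a bfloat16 cache. Compare these constants with student_config before relying on the numbers.

```python
# Assumed Gemma 2B attention dimensions (verify against student_config).
NUM_LAYERS = 18
NUM_KV_HEADS = 1     # Gemma 2B uses multi-query attention
HEAD_DIM = 256
BYTES_PER_VALUE = 2  # assuming a bfloat16 cache

MAX_TARGET_LENGTH = 128
cache_size = MAX_TARGET_LENGTH + 64  # as passed to CacheConfig above
batch_size = 3                       # three prompts in input_batch

# Keys and values (factor 2) for every layer, sequence slot, and KV head.
cache_bytes = (
    2 * NUM_LAYERS * batch_size * cache_size * NUM_KV_HEADS * HEAD_DIM * BYTES_PER_VALUE
)
print(f"cache_size = {cache_size} tokens, KV cache about {cache_bytes / 2**20:.1f} MiB")
```

At roughly 10 MiB, the cache is negligible next to the model weights. It grows linearly with both the batch size and cache_size.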