Distillation

DistillationTrainer(student_model, ...)    Distillation trainer.
DistillationTrainingConfig                 alias of TrainingConfig


class tunix.DistillationTrainer(student_model: Module, teacher_model: Module, strategy: BaseStrategy, optimizer: GradientTransformation, training_config: TrainingConfig)

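Distillation trainer. Trains student_model to imitate teacher_model, with the distillation objective supplied by strategy.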

close()

Closes the trainer and its associated resources.

This includes writing any buffered metrics, saving the last checkpoint, and closing the checkpoint manager and metrics logger.
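Because close() flushes metrics and saves the last checkpoint, it is worth guaranteeing that it runs even when training fails. One standard way is contextlib.closing from the Python standard library, which calls close() on exit from the with block. The sketch below is an illustration only: FakeTrainer and its train_step method are hypothetical stand-ins for a real DistillationTrainer, used so the snippet runs without Tunix.

```python
import contextlib


class FakeTrainer:
    """Hypothetical stand-in for DistillationTrainer; only records whether close() ran."""

    def __init__(self) -> None:
        self.closed = False

    def train_step(self) -> None:
        # Simulate a failure partway through training.
        raise RuntimeError("simulated failure")

    def close(self) -> None:
        # The real trainer would write buffered metrics, save the last
        # checkpoint, and close the checkpoint manager and metrics logger here.
        self.closed = True


trainer = FakeTrainer()
try:
    # contextlib.closing calls trainer.close() on exit, even if the body raises.
    with contextlib.closing(trainer):
        trainer.train_step()
except RuntimeError:
    pass

print(trainer.closed)  # True
```

With a real trainer, the body of the with block would hold the training call instead of train_step.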

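get_eval_loss(model: Module, teacher_output: Any, inputs: dict[str, Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]) → tuple[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, dict[str, Any]]

Computes the evaluation loss of model on inputs, given the teacher's output teacher_output. Returns the loss together with a dictionary of auxiliary values.

get_train_loss(model: Module, teacher_output: Any, inputs: dict[str, Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]) → tuple[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, dict[str, Any]]

Computes the training loss of model on inputs, given the teacher's output teacher_output. Returns the loss together with a dictionary of auxiliary values.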
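How teacher_output is combined with the student's predictions is decided by the configured strategy, not by this page. As an illustration only, a widely used objective is the soft-target loss of Hinton et al. (2015): the KL divergence between temperature-softened teacher and student distributions. The sketch assumes teacher_output is a logits array of shape [batch, vocab]. It uses NumPy so it runs standalone, and logit_distillation_loss is a hypothetical name, not a Tunix function. It returns a (loss, aux) pair to mirror the return type above.

```python
import numpy as np


def log_softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable log-softmax over the last (vocabulary) axis.
    z = x - x.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def logit_distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """Hypothetical soft-target loss: KL(teacher || student) at a given temperature.

    Returns (loss, aux) to mirror the tuple[loss, dict[str, Any]] shape of
    get_train_loss and get_eval_loss.
    """
    t_log_p = log_softmax(teacher_logits / temperature)
    s_log_p = log_softmax(student_logits / temperature)
    # Per-example KL divergence, summed over the vocabulary axis.
    kl = (np.exp(t_log_p) * (t_log_p - s_log_p)).sum(axis=-1)
    # Multiply by T^2 so gradient magnitudes stay comparable across temperatures.
    loss = float(kl.mean()) * temperature**2
    return loss, {"kl_mean": float(kl.mean()), "temperature": temperature}


rng = np.random.default_rng(0)
teacher_logits = rng.normal(size=(4, 10))  # [batch, vocab]
same_loss, _ = logit_distillation_loss(teacher_logits, teacher_logits)
diff_loss, aux = logit_distillation_loss(rng.normal(size=(4, 10)), teacher_logits)
```

A student that reproduces the teacher's logits exactly gets zero loss (same_loss). Any mismatch gives a positive value (diff_loss). The same code ports directly to jax.numpy if gradients are needed.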
with_gen_model_input_fn(gen_model_input_fn: Callable[[Any], dict[str, Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]]) → DistillationTrainer

Sets the function that generates model input from training input.

Note: the output of this function is passed to the loss function, so its entries must match the arguments the loss function expects.

Parameters:

gen_model_input_fn – A function that generates model input from training input.

Returns:

This DistillationTrainer instance, so calls can be chained.
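The requirement that the generated input match the loss function's arguments can be checked mechanically. The following is a minimal, hypothetical sketch. TrainingInput, its tokens field, the pad id 0, the input_tokens and input_mask keys, loss_fn, and check_keys_match are illustrative names, not Tunix API. It also assumes the dictionary is passed to the loss function as keyword arguments after the model.

```python
import dataclasses
import inspect

import numpy as np


@dataclasses.dataclass
class TrainingInput:
    """Hypothetical training example: a batch of padded token ids."""

    tokens: np.ndarray  # shape [batch, seq_len]; 0 is assumed to be the pad id


def gen_model_input_fn(x: TrainingInput) -> dict[str, np.ndarray]:
    # Map one training input to the keyword arguments of the loss function.
    return {"input_tokens": x.tokens, "input_mask": x.tokens != 0}


def loss_fn(model, input_tokens, input_mask):
    """Hypothetical loss signature; only its parameter names matter here."""
    raise NotImplementedError


def check_keys_match(fn, model_input: dict) -> None:
    # Every parameter after the leading `model` must be supplied, and nothing else.
    expected = set(list(inspect.signature(fn).parameters)[1:])
    provided = set(model_input)
    if expected != provided:
        raise ValueError(
            f"missing={sorted(expected - provided)}, extra={sorted(provided - expected)}"
        )


batch = TrainingInput(tokens=np.array([[0, 5, 6], [7, 8, 9]]))
model_input = gen_model_input_fn(batch)
check_keys_match(loss_fn, model_input)  # no error: keys and parameters agree
```

On a real trainer, the function would be registered with with_gen_model_input_fn(gen_model_input_fn). Running a check like check_keys_match once on a sample batch surfaces key mismatches before training starts.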

with_loss_fn(loss_fn: Callable[[...], Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | tuple[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, Any]], has_aux: bool = False) → DistillationTrainer
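The type annotation accepts a loss function in one of two shapes, selected by has_aux: it returns either the loss alone or a (loss, aux) tuple. In JAX, a flag of this name (for example on jax.value_and_grad) means the function returns such a pair, and the aux part is carried along without being differentiated. The sketch assumes the trainer uses the flag the same way. loss_plain, loss_with_metrics, and call_loss are hypothetical names, and call_loss only imitates how such a flag might be handled.

```python
from typing import Any, Callable


def loss_plain(model: Any, x: float) -> float:
    # has_aux=False: the function returns the loss alone.
    return 0.5 * x**2


def loss_with_metrics(model: Any, x: float) -> tuple[float, dict[str, float]]:
    # has_aux=True: the function returns (loss, aux); aux may hold metrics.
    loss = 0.5 * x**2
    return loss, {"abs_x": abs(x)}


def call_loss(loss_fn: Callable[..., Any], has_aux: bool, *args: Any) -> tuple[Any, Any]:
    """Normalize both conventions to a (loss, aux) pair; aux is None without has_aux."""
    out = loss_fn(*args)
    if has_aux:
        loss, aux = out
        return loss, aux
    return out, None


print(call_loss(loss_plain, False, None, 2.0))         # (2.0, None)
print(call_loss(loss_with_metrics, True, None, -2.0))  # (2.0, {'abs_x': 2.0})
```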

tunix.DistillationTrainingConfig

alias of TrainingConfig