Distillation
Distillation trainer.
- class tunix.DistillationTrainer(student_model: Module, teacher_model: Module, strategy: BaseStrategy, optimizer: GradientTransformation, training_config: TrainingConfig)
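Trains student_model to match teacher_model according to the distillation strategy, using optimizer and the settings in training_config.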
- close()
Closes the trainer and its associated resources.
This includes writing any buffered metrics, saving the last checkpoint, and closing the checkpoint manager and metrics logger.
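Because close() flushes buffered metrics and saves the last checkpoint, you should make sure it runs even when training raises. The following is a minimal sketch using the standard-library `contextlib.closing`, which calls `.close()` on exit from the `with` block. `_StubTrainer` is a hypothetical stand-in for a DistillationTrainer so the snippet runs without Tunix installed.

```python
import contextlib


class _StubTrainer:
    """Hypothetical stand-in for tunix.DistillationTrainer; only close() matters here."""

    def __init__(self):
        self.closed = False

    def close(self):
        # The real close() writes buffered metrics, saves the last checkpoint,
        # and closes the checkpoint manager and metrics logger.
        self.closed = True


trainer = _StubTrainer()
try:
    with contextlib.closing(trainer):
        # Training would run here; simulate a failure partway through.
        raise RuntimeError("training interrupted")
except RuntimeError:
    pass

# close() was still called even though the block raised.
print(trainer.closed)  # True
```

With a real trainer, the same `with contextlib.closing(trainer):` wrapper applies unchanged, because it relies only on the documented close() method.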
- get_eval_loss(model: Module, teacher_output: Any, inputs: dict[str, Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]) → tuple[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, dict[str, Any]]
- get_train_loss(model: Module, teacher_output: Any, inputs: dict[str, Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]) → tuple[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, dict[str, Any]]
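get_train_loss and get_eval_loss share one signature. Each receives the student model, the precomputed teacher_output, and the model inputs, and returns a loss together with a dictionary of auxiliary values. This page does not say how the loss is computed; it presumably depends on the chosen strategy. Purely as an illustration of that (loss, aux) contract, the sketch below implements the classic soft-target logit-distillation loss in NumPy: a temperature-scaled KL divergence blended with hard-label cross-entropy. The function name, the `temperature` and `alpha` parameters, and the aux keys are hypothetical and are not Tunix's API.

```python
import numpy as np


def log_softmax(x, axis=-1):
    # Numerically stable log-softmax: subtract the max before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def logit_distillation_loss(student_logits, teacher_logits, labels,
                            temperature=2.0, alpha=0.5):
    """Hypothetical soft-target distillation loss returning (loss, aux)."""
    # Temperature-softened teacher and student log-distributions.
    t_logp = log_softmax(teacher_logits / temperature)
    s_logp = log_softmax(student_logits / temperature)
    # KL(teacher || student) per position, averaged, then scaled by T**2,
    # the usual correction in soft-target distillation for softened gradients.
    kl = (np.exp(t_logp) * (t_logp - s_logp)).sum(axis=-1).mean()
    kl = kl * temperature**2
    # Cross-entropy of the unsoftened student logits against hard labels.
    s_logp_hard = log_softmax(student_logits)
    ce = -np.take_along_axis(s_logp_hard, labels[..., None], axis=-1).mean()
    # alpha weights the distillation term; (1 - alpha) weights the label term.
    loss = alpha * kl + (1.0 - alpha) * ce
    return loss, {"kl": kl, "ce": ce}


# Toy usage with random logits of shape (batch, seq, vocab).
rng = np.random.default_rng(0)
student = rng.normal(size=(2, 4, 8))
teacher = rng.normal(size=(2, 4, 8))
labels = rng.integers(0, 8, size=(2, 4))
loss, aux = logit_distillation_loss(student, teacher, labels)
```

Returning the individual terms in the aux dictionary lets the two parts of the loss be logged separately. That matches the shape of the dict[str, Any] half of the documented return type.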
- with_gen_model_input_fn(gen_model_input_fn: Callable[[Any], dict[str, Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]]) → DistillationTrainer
Sets the function that generates model input from training input.
Note: the output of this function is passed to the loss function, so the keys it produces must match the arguments the loss function expects.
- Parameters:
gen_model_input_fn – A function that generates model input from training input.
- Returns:
DistillationTrainer.
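A gen_model_input_fn maps one training input to the dictionary of keyword inputs that the loss function consumes. The following is a minimal NumPy sketch. It assumes a batch that carries padded token IDs and a padding mask, and derives positions and a causal attention mask from them. The `TrainingInput` class, its field names, and the output keys (`positions`, `attention_mask`, and so on) are all hypothetical; replace them with whatever your loss function actually expects.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class TrainingInput:
    """Hypothetical training input: padded token IDs plus a padding mask."""
    input_tokens: np.ndarray  # (batch, seq) integer token IDs
    input_mask: np.ndarray    # (batch, seq) bool, True for real tokens


def gen_model_input_fn(x: TrainingInput) -> dict[str, np.ndarray]:
    # Positions count only real tokens, so left padding does not shift them.
    positions = np.cumsum(x.input_mask, axis=-1) - 1
    positions = np.where(x.input_mask, positions, 0)
    # Causal mask (query i may attend to key j <= i), restricted to real keys.
    seq_len = x.input_tokens.shape[-1]
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    attention_mask = causal[None, :, :] & x.input_mask[:, None, :]
    # The keys of this dict must match the loss function's arguments.
    return {
        "input_tokens": x.input_tokens,
        "input_mask": x.input_mask,
        "positions": positions,
        "attention_mask": attention_mask,
    }


# Example: the first row is left-padded by two tokens.
batch = TrainingInput(
    input_tokens=np.array([[0, 0, 5, 6], [7, 8, 9, 10]]),
    input_mask=np.array([[False, False, True, True], [True, True, True, True]]),
)
out = gen_model_input_fn(batch)
# out["positions"] -> [[0, 0, 0, 1], [0, 1, 2, 3]]
```

Given the documented signature, the function would then be registered with `trainer = trainer.with_gen_model_input_fn(gen_model_input_fn)`.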
- tunix.DistillationTrainingConfig
alias of TrainingConfig
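DistillationTrainingConfig is an alias rather than a subclass, so both names refer to the same class object. Configs built under either name are therefore interchangeable and pass the same isinstance checks. The following self-contained sketch shows that Python aliasing behavior. It uses a hypothetical stand-in dataclass whose fields are illustrative only and are not TrainingConfig's actual fields.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainingConfig:
    # Hypothetical fields for illustration only; consult the real TrainingConfig.
    eval_every_n_steps: int = 100
    max_steps: Optional[int] = None


# An alias is just a second name bound to the same class object.
DistillationTrainingConfig = TrainingConfig

cfg = DistillationTrainingConfig(eval_every_n_steps=50)
assert DistillationTrainingConfig is TrainingConfig
assert isinstance(cfg, TrainingConfig)
```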