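Supervised fine-tuning (SFT)

| Class | Description |
|---|---|
| tunix.PeftTrainer | PEFT trainer for LoRA. |
| tunix.TrainingConfig | Configuration for the trainer. |
| tunix.DPOTrainer | Direct Preference Optimization (DPO) and ORPO trainer. |
| tunix.DPOTrainingConfig | DPO/ORPO training config. |
| tunix.MetricsLogger | Simple metrics logger. |
| tunix.MetricsLoggerOptions | Metrics logger options. |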
- class tunix.PeftTrainer(model: Module, optimizer: GradientTransformation, training_config: TrainingConfig, metrics_logger: MetricsLogger | None = None, perf_tracer: PerfTracer | NoopTracer | None = None)
PEFT trainer for LoRA. Only LoRA parameters are updated.
- model
The model to train.
- config
The training config.
- optimizer
The optimizer to use. To monitor the learning rate at each step, use optax.schedules.inject_hyperparams to inject the learning rate as a hyperparameter. For example: optimizer = optax.schedules.inject_hyperparams(optax.sgd)(learning_rate=learning_rate_schedule)
- loss_fn
The loss function to use.
- eval_loss_fn
The loss function to use for evaluation.
- gen_model_input_fn
The function to generate model input from training input.
- checkpoint_manager
The checkpoint manager to use.
- metrics_logger
The metrics logger to use.
- metrics_prefix
The prefix for metric names for logging.
- is_managed_externally
Whether the trainer is managed externally.
- training_hooks
The training hooks to use.
- data_hooks
The data hooks to use.
- clear_jit_cache()
Clears the JIT cache of the train and eval step functions.
Call this when reusing the trainer after overriding training-related state, such as the loss function.
- close()
Closes the trainer and its associated resources.
This includes writing any buffered metrics, saving the last checkpoint, and closing the checkpoint manager and metrics logger.
- create_eval_step_fn() → Callable[[...], Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]
Creates the eval step function.
- create_train_step_fn() → Callable[[...], Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]
Creates the train step function.
- custom_checkpoint_metadata() → dict[str, Any]
Override this function to return the custom metadata for the checkpoint manager.
- property iter_steps: int
Returns the number of iterator steps taken.
- jit_train_and_eval_step(skip_jit: bool = False, cache_nnx_graph: bool = False)
Creates and returns the train and eval step functions.
Returns the cached functions if they are available.
- Parameters:
skip_jit – If True, the train and eval step functions are not JIT-compiled.
cache_nnx_graph – If True, the NNX graph is cached.
- Returns:
A tuple of train and eval step functions.
- train(train_ds: Iterable[Any], eval_ds: Iterable[Any] | None = None, skip_jit: bool = False, *, cache_nnx_graph: bool = True) → None
Runs the training loop.
- property train_steps: int
Returns the number of train steps taken.
- with_data_hooks(data_hooks: DataHooks)
- with_gen_model_input_fn(gen_model_input_fn: Callable[[Any], Dict[str, Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]])
Sets the function that generates model input from training input.
Note: the output of this function is passed to the loss function, so its entries must match the arguments the loss function expects.
- Parameters:
gen_model_input_fn – A function that generates model input from training input.
- Returns:
PeftTrainer.
- with_loss_fn(loss_fn: Callable[[Concatenate[Module, P]], Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray | Tuple[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray, Any]], has_aux: bool = False)
- with_training_hooks(training_hooks: TrainingHooks)
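The advice on clear_jit_cache follows from the caching in jit_train_and_eval_step: once a step function has been built, it keeps using whatever loss function it captured, roughly the way a JIT-compiled step bakes in the Python callables it traced. The pure-Python sketch below models only that caching pattern; StepCacheSketch, set_loss_fn, get_step, and clear_cache are hypothetical names, not Tunix's API, and no real JIT compilation happens.

```python
from typing import Callable, Optional

LossFn = Callable[[float], float]


class StepCacheSketch:
    """Hypothetical stand-in for a trainer that caches its built step function.

    The cached step captures the loss function that was current when it was
    built, so replacing the loss function alone does not change the step.
    """

    def __init__(self, loss_fn: LossFn) -> None:
        self.loss_fn = loss_fn
        self._cached_step: Optional[LossFn] = None
        self.num_builds = 0

    def set_loss_fn(self, loss_fn: LossFn) -> "StepCacheSketch":
        self.loss_fn = loss_fn  # note: does NOT invalidate the cached step
        return self

    def _build_step(self) -> LossFn:
        loss_fn = self.loss_fn  # captured once, at build time

        def step(x: float) -> float:
            return loss_fn(x)

        return step

    def get_step(self) -> LossFn:
        # Reuse the cached step if one exists; otherwise build and cache it.
        if self._cached_step is None:
            self.num_builds += 1
            self._cached_step = self._build_step()
        return self._cached_step

    def clear_cache(self) -> None:
        self._cached_step = None


trainer = StepCacheSketch(loss_fn=lambda x: x * x)
print(trainer.get_step()(3.0))  # 9.0
trainer.set_loss_fn(lambda x: abs(x))
print(trainer.get_step()(3.0))  # 9.0 (stale: still the squared loss)
trainer.clear_cache()
print(trainer.get_step()(3.0))  # 3.0
```

Without the clear step, the second call silently keeps the old loss, which is the situation the clear_jit_cache note guards against.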
- class tunix.TrainingConfig(*, eval_every_n_steps: int, max_steps: int | None = None, gradient_accumulation_steps: int | None = None, checkpoint_root_directory: str | None = None, checkpointing_options: CheckpointManagerOptions | None = None, metrics_logging_options: MetricsLoggerOptions | None = None, profiler_options: ProfilerOptions | None = None, perf_metrics_options: PerfMetricsOptions | None = None, data_sharding_axis: Tuple[str, ...] = ('fsdp',), max_inflight_computations: int = 2, metrics_prefix: str = '', pbar_description: str | None = 'Training')
Configuration for the trainer.
- checkpoint_root_directory: str | None
- checkpointing_options: CheckpointManagerOptions | None
- data_sharding_axis: Tuple[str, ...]
- eval_every_n_steps: int
- get_with_default(key: str, default: Any) → Any
- gradient_accumulation_steps: int | None
- max_inflight_computations: int
- max_steps: int | None
- metrics_logging_options: MetricsLoggerOptions | None
- metrics_prefix: str
- pbar_description: str | None
- perf_metrics_options: PerfMetricsOptions | None
- profiler_options: ProfilerOptions | None
- class tunix.DPOTrainer(model: Module, ref_model: Module | None, optimizer: GradientTransformation, training_config: DPOTrainingConfig, tokenizer: Any | None = None)
Direct Preference Optimization (DPO) and ORPO trainer.
DPO is a preference-tuning method for aligning large language models with human or AI preferences. It is a more efficient and performant alternative to reinforcement learning from human feedback (RLHF).
DPO is simpler because it needs no text generation inside the training loop and skips the reward-modeling step entirely: no separate reward model is trained. Instead, it uses a dataset of preferences (pairs of “chosen” and “rejected” responses) to optimize the policy model directly with a classification-style loss.
ORPO (Odds Ratio Preference Optimization) is a memory-efficient variant that combines supervised fine-tuning with preference alignment. Because it does not require a separate reference model, it is approximately 50% more memory-efficient than DPO.
References:
DPO: https://arxiv.org/abs/2305.18290
ORPO: https://arxiv.org/abs/2403.07691
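The classification-style loss can be sketched for a single preference pair. This follows the sigmoid loss from the DPO paper; treating label_smoothing as a conservative-DPO mixing weight is an assumption about how DPOTrainingConfig uses it, and dpo_loss and log_sigmoid are hypothetical helpers, not Tunix's API. Each input is the summed log-probability of a full response under the policy or the frozen reference model.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
    label_smoothing: float = 0.0,
) -> float:
    """Hypothetical per-pair DPO loss (sigmoid form from the DPO paper).

    label_smoothing is treated as the probability that the preference label
    is flipped (conservative DPO); this mapping is an assumption.
    """
    # Implicit reward margin: how much more the policy (relative to the
    # reference) prefers the chosen response over the rejected one.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    logits = beta * (chosen_logratio - rejected_logratio)
    # Binary cross-entropy on the (optionally smoothed) preference label.
    return (-(1.0 - label_smoothing) * log_sigmoid(logits)
            - label_smoothing * log_sigmoid(-logits))


# Before training the policy equals the reference, so the margin is 0 and the loss is log(2).
print(round(dpo_loss(-12.0, -15.0, -12.0, -15.0), 4))  # 0.6931
# Raising the chosen response's likelihood relative to the reference lowers the loss.
print(dpo_loss(-10.0, -15.0, -12.0, -15.0) < math.log(2))  # True
```

No sampling and no reward model appear anywhere: the only inputs are log-probabilities of fixed responses, which is what makes DPO simpler than RLHF.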
- class tunix.DPOTrainingConfig(*, eval_every_n_steps: int, max_steps: int | None = None, gradient_accumulation_steps: int | None = None, checkpoint_root_directory: str | None = None, checkpointing_options: CheckpointManagerOptions | None = None, metrics_logging_options: MetricsLoggerOptions | None = None, profiler_options: ProfilerOptions | None = None, perf_metrics_options: PerfMetricsOptions | None = None, data_sharding_axis: Tuple[str, ...] = ('fsdp',), max_inflight_computations: int = 2, metrics_prefix: str = '', pbar_description: str | None = 'Training', algorithm: str = 'dpo', beta: float = 0.1, lambda_orpo: float = 0.1, label_smoothing: float = 0.0, max_prompt_length: int | None = None, max_response_length: int | None = None)
DPO/ORPO training config.
- algorithm: str
- beta: float
- label_smoothing: float
- lambda_orpo: float
- max_prompt_length: int | None
- max_response_length: int | None
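The ORPO objective from the ORPO paper adds a weighted odds-ratio term to the ordinary SFT loss on the chosen response, using only the policy's own probabilities. The sketch below assumes that lambda_orpo is that weight and that probabilities are length-normalized (mean per-token log-probability) as in the paper; orpo_loss and log_odds are hypothetical helpers, not Tunix's API.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def log_odds(avg_logp: float) -> float:
    """log(p / (1 - p)) for p = exp(avg_logp); requires avg_logp < 0."""
    # -expm1(x) = 1 - exp(x) = 1 - p, computed without cancellation.
    return avg_logp - math.log(-math.expm1(avg_logp))


def orpo_loss(chosen_avg_logp: float, rejected_avg_logp: float,
              lambda_orpo: float = 0.1) -> float:
    """Hypothetical per-pair ORPO loss: SFT NLL plus lambda_orpo * odds-ratio loss.

    *_avg_logp is the mean per-token log-probability under the policy.
    No reference-model log-probabilities are needed.
    """
    sft_loss = -chosen_avg_logp  # negative log-likelihood of the chosen response
    log_odds_ratio = log_odds(chosen_avg_logp) - log_odds(rejected_avg_logp)
    odds_ratio_loss = -log_sigmoid(log_odds_ratio)
    return sft_loss + lambda_orpo * odds_ratio_loss


print(round(orpo_loss(-1.0, -1.0), 4))  # 1.0693: SFT term 1.0 plus 0.1 * log(2)
print(orpo_loss(-1.0, -3.0) < orpo_loss(-1.0, -1.0))  # True
```

Because the function never takes reference log-probabilities, only one model has to be held in memory, which is the source of ORPO's memory savings over DPO.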
- class tunix.MetricsLogger(metrics_logger_options: MetricsLoggerOptions | None = None)
Simple metrics logger.
Logs metrics to multiple backends. If no backends are specified, it logs to the default backends.
- close()
Closes all registered logging backends.
- get_metric(metrics_prefix, metric_name: str, mode: Mode | str)
Returns the mean metric value for the given metric name and mode.
- get_metric_history(metrics_prefix, metric_name: str, mode: Mode | str)
Returns all past metric values for the given metric name and mode.
- log(metrics_prefix: str, metric_name: str, scalar_value: float | ndarray, mode: Mode | str, step: int)
Logs the scalar metric value to local history and via jax.monitoring.
- metric_exists(metrics_prefix, metric_name: str, mode: Mode | str) → bool
Checks if the metric exists for the given metric name and mode.
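The local-history side of the logger (log, get_metric, get_metric_history, metric_exists) can be modeled with a dictionary of lists. This is a minimal in-memory sketch, not Tunix's implementation: InMemoryMetricHistory is hypothetical, the keying by (prefix, mode, name), the KeyError on unknown metrics, and the "train"/"eval" mode strings are assumptions, and no backends or jax.monitoring calls are involved.

```python
from statistics import fmean
from typing import Dict, List, Tuple


class InMemoryMetricHistory:
    """Hypothetical sketch of per-metric local history; not Tunix's implementation."""

    def __init__(self) -> None:
        # One list of values per (metrics_prefix, mode, metric_name), in logging order.
        self._history: Dict[Tuple[str, str, str], List[float]] = {}

    def log(self, metrics_prefix: str, metric_name: str, scalar_value: float,
            mode: str, step: int) -> None:
        # A real logger would also forward (metric_name, scalar_value, step) to its backends.
        key = (metrics_prefix, mode, metric_name)
        self._history.setdefault(key, []).append(float(scalar_value))

    def metric_exists(self, metrics_prefix: str, metric_name: str, mode: str) -> bool:
        return (metrics_prefix, mode, metric_name) in self._history

    def get_metric_history(self, metrics_prefix: str, metric_name: str,
                           mode: str) -> List[float]:
        key = (metrics_prefix, mode, metric_name)
        if key not in self._history:
            raise KeyError(f"no {mode!r} metric {metric_name!r} with prefix {metrics_prefix!r}")
        return list(self._history[key])  # copy, so callers cannot mutate the history

    def get_metric(self, metrics_prefix: str, metric_name: str, mode: str) -> float:
        # Mean over every value logged so far for this metric and mode.
        return fmean(self.get_metric_history(metrics_prefix, metric_name, mode))


history = InMemoryMetricHistory()
for step, loss in enumerate([2.0, 1.0, 0.0], start=1):
    history.log("", "loss", loss, "train", step)
print(history.get_metric("", "loss", "train"))  # 1.0
print(history.get_metric_history("", "loss", "train"))  # [2.0, 1.0, 0.0]
print(history.metric_exists("", "loss", "eval"))  # False
```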
- class tunix.MetricsLoggerOptions(log_dir: str, project_name: str = 'tunix', run_name: str = '', flush_every_n_steps: int = 100, backend_factories: list[Callable[[], LoggingBackend]] | None = None)
Metrics logger options.
- backend_factories: list[Callable[[], LoggingBackend]] | None = None
- create_backends() → list[LoggingBackend]
Factory method to create a fresh set of live backends.
- flush_every_n_steps: int = 100
- log_dir: str
- project_name: str = 'tunix'
- run_name: str = ''
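Storing zero-argument factories instead of backend instances is what lets create_backends return a fresh set of live backends on every call. One plausible benefit is that closing one logger, which closes its backends, never leaves a later logger holding closed ones. The sketch below shows that design in plain Python; ListBackend, log_scalar, and BackendOptionsSketch are hypothetical stand-ins for LoggingBackend and MetricsLoggerOptions, and the fallback when no factories are given only stands in for the library's default backends.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


class ListBackend:
    """Hypothetical backend that records scalars in memory (stand-in for LoggingBackend)."""

    def __init__(self) -> None:
        self.records: List[Tuple[str, float, int]] = []
        self.closed = False

    def log_scalar(self, name: str, value: float, step: int) -> None:
        if self.closed:
            raise RuntimeError("backend already closed")
        self.records.append((name, value, step))

    def close(self) -> None:
        self.closed = True


@dataclass
class BackendOptionsSketch:
    """Hypothetical options object mirroring the backend_factories / create_backends idea."""

    log_dir: str
    backend_factories: Optional[List[Callable[[], ListBackend]]] = None

    def create_backends(self) -> List[ListBackend]:
        # Stand-in default when no factories are given (assumption).
        factories = self.backend_factories if self.backend_factories is not None else [ListBackend]
        # Calling each factory yields brand-new, open backend objects every time.
        return [factory() for factory in factories]


options = BackendOptionsSketch(log_dir="/tmp/metrics", backend_factories=[ListBackend, ListBackend])
first_run = options.create_backends()
for backend in first_run:
    backend.close()  # e.g. a logger was closed at the end of a training run
second_run = options.create_backends()
print([b.closed for b in second_run])  # [False, False]
```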