Reinforcement learning (RL)#
- GRPOConfig: Configuration for GRPO algorithms.
- GRPOLearner: GRPO (Group Relative Policy Optimization) learner.
- PPOConfig: Configuration for PPO learner.
- PPOLearner: PPO (Proximal Policy Optimization) learner.
- ClusterConfig: Cluster config.
- RLCluster: RLCluster.
- RLTrainingConfig: RLTraining config.
- Role: Role of the model.
- RolloutConfig: Configuration for the rollout worker.
- class tunix.GRPOConfig(*, algo_variant: str = 'grpo', advantage_estimator: str = 'grpo', policy_loss_fn: str = 'grpo', reward_manager: str = 'sequence-level', kl_clamp_value: float | None = None, loss_agg_mode: str = 'sequence-mean-token-mean', loss_algo: str = 'grpo', num_generations: int = 2, num_iterations: int = 1, beta: float = 0.04, kl_loss_mode: str = 'kl', epsilon: float = 0.2)#
Configuration for GRPO algorithms.
- algo_variant#
The algorithm variant to use. Default: grpo.
- Type:
str
- advantage_estimator#
The advantage estimator to use. Default: grpo.
- Type:
str
- policy_loss_fn#
The policy loss function to use. Default: grpo.
- Type:
str
- loss_agg_mode#
The aggregation mode for the loss function. Supported values include token-mean, sequence-mean-token-mean, sequence-mean-token-scale, seq-mean-token-sum, and sequence-mean-token-sum-norm. Default: sequence-mean-token-mean.
- Type:
str
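The difference between two of the aggregation modes named above can be sketched in plain Python. This is a minimal illustration, assuming that token-mean averages over every unmasked token in the batch and that sequence-mean-token-mean first averages within each sequence and then across sequences; the function names and the mask layout are hypothetical and are not taken from the library.

```python
from typing import List


def token_mean(losses: List[List[float]], mask: List[List[int]]) -> float:
    """Average the loss over every unmasked token in the batch."""
    total = sum(l * m for row_l, row_m in zip(losses, mask) for l, m in zip(row_l, row_m))
    count = sum(m for row in mask for m in row)
    return total / max(count, 1)


def sequence_mean_token_mean(losses: List[List[float]], mask: List[List[int]]) -> float:
    """Average tokens within each sequence, then average across sequences."""
    per_sequence = []
    for row_l, row_m in zip(losses, mask):
        n_tokens = sum(row_m)
        per_sequence.append(sum(l * m for l, m in zip(row_l, row_m)) / max(n_tokens, 1))
    return sum(per_sequence) / len(per_sequence)


# Two completions: a long one with low per-token loss, a short one with high loss.
losses = [[1.0, 1.0, 1.0, 1.0], [3.0, 0.0, 0.0, 0.0]]
mask = [[1, 1, 1, 1], [1, 0, 0, 0]]

print(token_mean(losses, mask))                # long sequences dominate
print(sequence_mean_token_mean(losses, mask))  # every sequence weighs the same
```

Under these assumptions, token-level averaging gives long completions more weight, while sequence-level averaging weights each completion equally regardless of length.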
- reward_manager#
The reward manager to use. Default: sequence-level.
- Type:
str
- loss_algo#
The loss algorithm to use. To be deprecated.
- Type:
str
- num_generations#
The number of responses the policy generates for a given prompt within a single training step. This corresponds to ‘G’ in Algorithm 1 in the paper. A higher value means more samples are used to compute relative advantages.
- Type:
int
- num_iterations#
The number of iterations per batch (𝜇 in GRPO algo 1).
- Type:
int
- beta#
The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function. This term prevents policy updates from deviating too far from the reference model. A value of 0.0 means no KL penalty is applied.
- Type:
float
- kl_loss_mode#
The divergence mode used for KL penalty estimation. Default: kl.
- Type:
str
- epsilon#
Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, it ensures stable updates.
- Type:
float
- epsilon_high#
Epsilon value for upper bound clipping.
- loss_algo#
Whether to use GRPO or GSPO for loss computation. The GRPO loss is normalized per batch rather than per response as described in the paper. For GSPO, the gspo-token loss is used, which is more flexible.
- Type:
str
References
- advantage_estimator: str#
- algo_variant: str#
- beta: float = 0.04#
- epsilon: float = 0.2#
- kl_clamp_value: float | None#
- kl_loss_mode: str = 'kl'#
- loss_agg_mode: str = 'sequence-mean-token-mean'#
- loss_algo: str = 'grpo'#
- num_generations: int = 2#
- num_iterations: int = 1#
- policy_loss_fn: str#
- reward_manager: str#
- class tunix.GRPOLearner(rl_cluster: RLCluster, algo_config: TGrpoConfig, reward_fns: Callable[[...], List[float]] | List[Callable[[...], List[float]]], metric_fns: Sequence[Callable[[...], Dict[str, Tuple[Array | ndarray | bool | number | bool | int | float | complex | str, Callable[[Array | ndarray | bool | number | bool | int | float | complex], Array | ndarray | bool | number | bool | int | float | complex] | None]]]] | None = None, data_shuffle_seed: int | None = None)#
GRPO (Group Relative Policy Optimization) learner.
GRPO is a reinforcement learning algorithm designed to enhance the reasoning abilities of large language models, such as mathematical problem solving. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group’s performance to update the policy.
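The group-relative advantage described above can be sketched as follows. This is a minimal illustration rather than the learner's implementation: it assumes rewards are normalized by the group mean and the population standard deviation, and the function name and the `eps` stabilizer are hypothetical.

```python
import statistics
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Normalize the rewards of G responses to one prompt against the group itself."""
    mean = statistics.fmean(rewards)
    # Assumption: population standard deviation; eps avoids division by zero
    # when every response in the group receives the same reward.
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Four responses (G = 4) to the same prompt, scored by a reward function.
advantages = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
print(advantages)
```

Responses that score above the group mean receive a positive advantage and those below receive a negative one, which is what removes the need for a separate value model.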
- train(train_ds: Iterable[Dict[str, List[str] | Array | ndarray | bool | number | bool | int | float | complex]], eval_ds: Iterable[Dict[str, List[str] | Array | ndarray | bool | number | bool | int | float | complex]] | None = None, skip_jit: bool = False) None#
GRPO training loop.
The algorithm, extracted from https://arxiv.org/abs/2402.03300, is as follows:
Input: initial policy model πθinit; reward models rφ; task prompts D; hyperparameters ε, β, μ
policy model πθ ← πθinit
for iteration = 1, ..., I do
    reference model πref ← πθ
    for step = 1, ..., M do
        Sample a batch D♭ from D
        Update the old policy model πθold ← πθ
        Sample G outputs {o_i}, i = 1, ..., G ~ πθold(· | q) for each question q ∈ D♭
        Compute rewards {r_i}, i = 1, ..., G for each sampled output o_i by running rφ
        Compute Â_i,t for the t-th token of o_i through group relative advantage estimation.
        for GRPO iteration = 1, ..., μ do
            Update the policy model πθ by maximizing the GRPO objective (Equation 21)
    Update rφ through continuous training using a replay mechanism.
Output πθ
Note
The outer loop (I) is currently ignored because the reference model is never updated.
The sampler and the trainer currently hold the same reference to the model, so the step that updates the sampler model is also omitted.
- Parameters:
train_ds – An iterable of training input data, where each element is a dictionary containing the key ‘prompts’.
eval_ds – An iterable of evaluation input data, where each element is a dictionary containing the key ‘prompts’.
skip_jit – Whether to skip JIT compilation of the training loop.
- tunix.RewardFn#
alias of Callable[[…], List[float]]
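A function matching this alias only needs to return one float per sampled response. The sketch below is a hypothetical example: the keyword names `prompts` and `completions` and the `<answer>` tag format are assumptions made for illustration, since the alias itself only fixes the return type.

```python
import re
from typing import Any, List

# Hypothetical output format: the answer is expected inside <answer>...</answer>.
ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def format_reward(prompts: List[str], completions: List[str], **kwargs: Any) -> List[float]:
    """Return 1.0 for each completion that contains an answer tag, else 0.0."""
    return [1.0 if ANSWER_PATTERN.search(c) else 0.0 for c in completions]


scores = format_reward(
    prompts=["What is 2 + 2?", "What is 3 + 3?"],
    completions=["<answer>4</answer>", "six"],
)
print(scores)
```

Several such functions can be passed as a list, as the `reward_fns` parameter of the learners accepts either a single callable or a list of callables.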
- class tunix.PPOConfig(*, algo_variant: str = 'ppo', advantage_estimator: str = 'gae', policy_loss_fn: str = 'ppo', reward_manager: str = 'sequence-level', kl_clamp_value: float | None = None, value_loss_fn: str = 'ppo', num_iterations: int = 1, gamma: float = 1.0, gae_lambda: float = 0.95, beta: float = 0.04, epsilon: float = 0.2, epsilon_low: float | None = None, epsilon_high: float | None = None, epsilon_c: float | None = None, entropy_coef: float | None = None, clip_range_value: float = 0.2, kl_method: str = 'low_var_kl')#
Configuration for PPO learner.
- algo_variant#
The algorithm variant to use. Default: ppo.
- Type:
str
- advantage_estimator#
The advantage estimator to use. Default: gae.
- Type:
str
- policy_loss_fn#
The policy loss function to use. Default: ppo.
- Type:
str
- reward_manager#
The reward manager to use. Default: sequence-level.
- Type:
str
- num_iterations#
The number of optimization epochs per batch of rollouts. This corresponds to the number of times the policy updates its weights for a given batch of rollouts.
- Type:
int
- mini_batch_size#
The batch size on which the actual model updates happen. The rollout phase (generate_and_compute_advantages) happens on a larger batch, which is then split into “mini-batches”.
- gamma#
The discount factor for future rewards in GAE.
- Type:
float
- gae_lambda#
The lambda parameter for Generalized Advantage Estimation (GAE).
- Type:
float
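How `gamma` and `gae_lambda` enter the advantage computation can be sketched with the standard GAE recursion. This is a generic illustration, not the learner's code; it assumes one reward and one value estimate per token and a value of zero after the final token, and the function name is hypothetical.

```python
from typing import List, Tuple


def gae(
    rewards: List[float],
    values: List[float],
    gamma: float = 1.0,
    gae_lambda: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """Compute GAE advantages and returns for a single trajectory."""
    advantages = [0.0] * len(rewards)
    last_advantage = 0.0
    for t in reversed(range(len(rewards))):
        # Assumption: the value after the last step is zero (episode ends).
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last_advantage = delta + gamma * gae_lambda * last_advantage
        advantages[t] = last_advantage
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


# A sequence-level reward placed on the final token, with a constant critic.
advantages, returns = gae(rewards=[0.0, 0.0, 1.0], values=[0.5, 0.5, 0.5])
print(advantages)
print(returns)
```

With `gamma = 1.0`, only `gae_lambda` decays the credit assigned to earlier tokens, so a smaller `gae_lambda` relies more on the critic and a larger one relies more on the observed reward.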
- beta#
The coefficient for the KL divergence penalty.
- Type:
float
- epsilon#
Epsilon value for clipping the ratio for the policy objective.
- Type:
float
- epsilon_low#
Lower bound for clipping the ratio for the policy objective. Set to epsilon if not provided.
- Type:
float | None
- epsilon_high#
Upper bound for clipping the ratio for the policy objective. Set to epsilon if not provided.
- Type:
float | None
- epsilon_c#
Lower bound for clipping for dual-clip PPO. If not provided, we don’t do dual-clip PPO. Reference: https://arxiv.org/abs/1912.09729.
- Type:
float | None
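The interaction of `epsilon`, `epsilon_low`, `epsilon_high`, and `epsilon_c` can be sketched for a single token as follows. This is a minimal illustration of the clipped surrogate loss with the optional dual-clip bound, written as a loss to minimize; the function name is hypothetical and the learner's actual loss may differ in details such as masking and aggregation.

```python
import math
from typing import Optional


def ppo_policy_loss(
    logp_new: float,
    logp_old: float,
    advantage: float,
    epsilon: float = 0.2,
    epsilon_low: Optional[float] = None,
    epsilon_high: Optional[float] = None,
    epsilon_c: Optional[float] = None,
) -> float:
    """Per-token clipped surrogate loss with optional dual clipping."""
    # Both bounds fall back to epsilon when not provided.
    eps_low = epsilon if epsilon_low is None else epsilon_low
    eps_high = epsilon if epsilon_high is None else epsilon_high

    ratio = math.exp(logp_new - logp_old)
    clipped_ratio = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)
    # Pessimistic choice between the unclipped and clipped objectives.
    loss = max(-advantage * ratio, -advantage * clipped_ratio)

    # Dual clip: for negative advantages, bound the loss by epsilon_c * |A|.
    if epsilon_c is not None and advantage < 0:
        loss = min(loss, -advantage * epsilon_c)
    return loss


# Ratio of 2 with a positive advantage: the gain is capped at 1 + epsilon.
capped = ppo_policy_loss(math.log(2.0), 0.0, advantage=1.0)
# Ratio of 10 with a negative advantage: dual clip bounds the loss at 3.
dual = ppo_policy_loss(math.log(10.0), 0.0, advantage=-1.0, epsilon_c=3.0)
print(capped, dual)
```

Without `epsilon_c`, the second call would return a loss close to 10, which is the unbounded case that dual clipping is meant to limit.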
- entropy_coef#
Entropy coefficient for the policy loss. Set to None or 0.0 to disable entropy regularization.
- Type:
float | None
- clip_range_value#
The range for clipping the value function loss.
- Type:
float
- kl_method#
The method for computing KL divergence. Must be one of ["low_var_kl", "kl", "mse_kl"].
- Type:
str
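The three method names plausibly correspond to the common per-token KL estimators. The sketch below assumes that mapping ("kl" as the plain log-ratio, "mse_kl" as half the squared log-ratio, and "low_var_kl" as the non-negative low-variance estimator); the mapping and the function name are assumptions and are not confirmed by this page.

```python
import math


def kl_estimate(logp: float, ref_logp: float, method: str = "low_var_kl") -> float:
    """Per-token estimate of KL(policy || reference) from two log-probabilities."""
    diff = logp - ref_logp
    if method == "kl":
        # Unbiased but high variance; can be negative for a single sample.
        return diff
    if method == "mse_kl":
        # Half the squared log-ratio; always non-negative.
        return 0.5 * diff * diff
    if method == "low_var_kl":
        # (r - 1) - log(r) with r = p_ref / p; non-negative and low variance.
        return math.exp(-diff) - 1.0 + diff
    raise ValueError(f"Unknown KL method: {method}")


for name in ("kl", "mse_kl", "low_var_kl"):
    print(name, kl_estimate(logp=-2.0, ref_logp=-1.0, method=name))
```

All three estimates are zero when the policy and the reference assign the same log-probability to the token.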
- beta: float#
- clip_range_value: float#
- entropy_coef: float | None#
- epsilon: float#
- epsilon_c: float | None#
- epsilon_high: float | None#
- epsilon_low: float | None#
- gae_lambda: float#
- gamma: float#
- kl_method: str#
- num_iterations: int#
- value_loss_fn: str#
- class tunix.PPOLearner(rl_cluster: RLCluster, ppo_config: PPOConfig, reward_fns: Callable[[...], List[float]] | List[Callable[[...], List[float]]] | None = None, metric_fns: Sequence[Callable[[...], Dict[str, Tuple[Array | ndarray | bool | number | bool | int | float | complex | str, Callable[[Array | ndarray | bool | number | bool | int | float | complex], Array | ndarray | bool | number | bool | int | float | complex] | None]]]] | None = None, data_shuffle_seed: int | None = None)#
PPO (Proximal Policy Optimization) learner.
PPO is a reinforcement learning algorithm that fine-tunes models using an actor-critic architecture. It optimizes a clipped surrogate objective function to ensure stable policy updates, preventing large, destructive changes. The actor (policy model) learns what actions to take, while the critic (value model) estimates the value of states to help calculate advantages. This approach balances exploration and exploitation, making it a robust choice for a wide range of RL tasks.
References: - https://arxiv.org/abs/1707.06347
- class tunix.ClusterConfig(*, role_to_mesh: dict[Role, Mesh], role_to_logical_axis_rule: dict[Role, Sequence[tuple[str, str | tuple[str, ...] | None]]] | None = None, rollout_engine: str | type[BaseRollout] = 'vanilla', offload_to_cpu: bool = False, training_config: RLTrainingConfig, rollout_config: dict[Mode, RolloutConfig] | RolloutConfig)#
Cluster config.
- role_to_mesh#
Mapping from model role to mesh. Key config for colocated vs disaggregated setup.
- Type:
dict[Role, Mesh]
- role_to_logical_axis_rule#
Mapping from model role to logical axis rule. This is used when models are sharded with logical axes and expect a logical-to-physical axis mapping at runtime.
- Type:
dict[tunix.rl.rl_cluster.Role, collections.abc.Sequence[tuple[str, str | tuple[str, …] | None]]] | None
- rollout_engine#
Rollout engine to use, e.g. “vanilla”, “vllm”, “sglang_jax”. Alternatively, if a subclass of base_rollout.BaseRollout is provided, it will be used as the rollout engine.
- Type:
str | type[tunix.rl.rollout.base_rollout.BaseRollout]
- offload_to_cpu#
Whether to offload models to CPU at each step.
- Type:
bool
- training_config#
RL training config.
- rollout_config#
Rollout config. It may be different for different modes, e.g. TRAIN vs EVAL.
- Type:
dict[tunix.rl.rl_cluster.Mode, tunix.rl.rollout.base_rollout.RolloutConfig] | tunix.rl.rollout.base_rollout.RolloutConfig
- rollout_vllm_model_version#
Model version for vllm rollout engine.
- rollout_vllm_lora_config#
LoRA config for vllm rollout engine.
- rollout_vllm_hbm_utilization#
The percentage of TPU/GPU HBM allocated to the vllm rollout engine.
- rollout_vllm_init_with_random_weights#
Init the vllm TPU backend model with random weights instead of loading from the given path.
- rollout_vllm_tpu_backend_type#
The TPU JAX backend type for the vllm rollout engine, e.g. “jax”, “torchax” or “pytorch_xla”.
- offload_to_cpu: bool = False#
- role_to_logical_axis_rule: dict[Role, Sequence[tuple[str, str | tuple[str, ...] | None]]] | None = None#
- rollout_config: dict[Mode, RolloutConfig] | RolloutConfig#
- rollout_engine: str | type[BaseRollout] = 'vanilla'#
- training_config: RLTrainingConfig#
- class tunix.RLCluster(*, actor: Module | str, critic: Module | str | None = None, reference: Module | str | None = None, reward: Module | str | None = None, tokenizer: Any | None, cluster_config: ClusterConfig, perf_config: PerfMetricsConfig | None = None)#
RLCluster.
- property actor_trainer: Trainer#
- buffer_metrics(metrics: Dict[str, Tuple[Array | ndarray | bool | number | bool | int | float | complex | str, Callable[[Array | ndarray | bool | number | bool | int | float | complex], Array | ndarray | bool | number | bool | int | float | complex] | None]], mode: Mode = Mode.TRAIN) None#
Buffers rl metrics to be logged.
Actual logging will happen when global steps are incremented.
- Parameters:
metrics – A dictionary mapping metric names to a tuple containing the metric value and an optional aggregation function.
mode – The mode of the workload, either TRAIN or EVAL.
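The shape of the `metrics` argument, a mapping from name to a (value, optional aggregation function) tuple, can be illustrated with a toy buffer. `ToyMetricsBuffer` is a hypothetical stand-in and not the library's class; in particular, falling back to the mean when no aggregation function is given is an assumption.

```python
from collections import defaultdict
from statistics import fmean
from typing import Callable, Dict, List, Optional, Tuple

Metric = Tuple[float, Optional[Callable[[List[float]], float]]]


class ToyMetricsBuffer:
    """Hypothetical buffer: collects values now, aggregates them on flush."""

    def __init__(self) -> None:
        self._values: Dict[str, List[float]] = defaultdict(list)
        self._agg_fns: Dict[str, Callable[[List[float]], float]] = {}

    def buffer(self, metrics: Dict[str, Metric]) -> None:
        for name, (value, agg_fn) in metrics.items():
            self._values[name].append(value)
            if agg_fn is not None:
                self._agg_fns[name] = agg_fn

    def flush(self) -> Dict[str, float]:
        # Assumption: metrics without an aggregation function are averaged.
        summary = {
            name: self._agg_fns.get(name, fmean)(values)
            for name, values in self._values.items()
        }
        self._values.clear()
        return summary


buf = ToyMetricsBuffer()
buf.buffer({"reward": (1.0, None), "max_len": (10, max)})
buf.buffer({"reward": (0.0, None), "max_len": (12, max)})
summary = buf.flush()  # stands in for the point where the global step advances
print(summary)
```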
- buffer_metrics_async(metrics: Dict[str, Tuple[Array | ndarray | bool | number | bool | int | float | complex | str, Callable[[Array | ndarray | bool | number | bool | int | float | complex], Array | ndarray | bool | number | bool | int | float | complex] | None]], mode: Mode = Mode.TRAIN, step: int = 0) None#
Buffers rl metrics to be logged for async training.
Actual logging will happen when global steps are incremented.
- Parameters:
metrics – A dictionary mapping metric names to a tuple containing the metric value and an optional aggregation function.
mode – The mode of the workload, either TRAIN or EVAL.
step – The step number for the metrics. Only used in TRAIN mode.
- close()#
- property critic_trainer: Trainer#
- generate(prompts: list[str] | list[list[dict[str, str]]], apply_chat_template: bool = False, mode: Mode = Mode.TRAIN, micro_batch_size: int | None = None, trace_tags: Mapping[str, Any] | None = None, max_generation_steps: int | None = None) RolloutOutput#
Generates text from the given prompts.
- Parameters:
prompts – A list of prompts to generate text from. If apply_chat_template is True, this should be a list of conversations (each a list of dictionaries with ‘role’ and ‘content’). Otherwise, it should be a list of strings.
apply_chat_template – Whether to apply chat template to the prompts.
mode – The mode of rollout, either TRAIN or EVAL.
micro_batch_size – The micro-batch size for generation. If None, no micro-batching is performed.
trace_tags – Optional tags to add to the performance tracer.
- Returns:
A RolloutOutput object containing the generated text and other info.
- get_actor_per_token_logps(prompt_tokens: Array, completion_tokens: Array, pad_id: int, eos_id: int, micro_batch_size: int | None = None, temperature: float | None = None) Array#
Gets per-token logps from the actor model on the trainer side.
Mirrors get_ref_per_token_logps. The rollout temperature must be passed through so that the actor’s recomputed logps match the temperature scaling used at sampling time. Otherwise, comparing log_softmax(logits/T_sample) with log_softmax(logits) yields an artifact difference of several nats versus vllm’s processed_logprobs.
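The size of the mismatch that arises when the sampling temperature is not applied during recomputation can be checked with a small standalone example. The logits and the helper below are illustrative only and do not call the library.

```python
import math
from typing import List


def log_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax of temperature-scaled logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


logits = [4.0, 0.0, -4.0]
sampled_token = 2  # a low-probability token that happened to be sampled

# Log-probability as seen by the sampler (temperature 2.0) versus a
# recomputation that forgets the temperature (temperature 1.0).
logp_sampling = log_softmax(logits, temperature=2.0)[sampled_token]
logp_unscaled = log_softmax(logits, temperature=1.0)[sampled_token]
gap = logp_sampling - logp_unscaled
print(f"gap in nats: {gap:.3f}")
```

The two values come from the same logits and the same weights, so the gap is purely an artifact of the missing temperature and would otherwise show up as a spurious importance ratio.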
- get_old_per_token_logps(prompt_tokens: Array, completion_tokens: Array, micro_batch_size: int | None = None) Array#
Gets the per-token logps of the current policy model.
- get_ref_per_token_logps(prompt_tokens: Array, completion_tokens: Array, pad_id: int, eos_id: int, micro_batch_size: int | None = None) Array#
Gets the per-token logps of the reference model.
- get_rollout_config(mode: Mode) RolloutConfig#
Returns the rollout config for the given mode.
- property inference_worker: InferenceWorker#
- property perf: PerfTracer | NoopTracer#
The v1 performance tracer.
- property perf_v2: PerfTracer | NoopTracer#
The v2 performance tracer.
- property rollout: BaseRollout#
- sync_weights()#
Syncs the weights between the sampler model and the trainer model.
- update_actor(train_ds, eval_ds, skip_jit=False)#
- update_critic(train_ds, eval_ds, skip_jit=False)#
- with_external_metrics_logger(external_metrics_logger: Callable[[MetricsBuffer], None])#
- class tunix.RLTrainingConfig(*, eval_every_n_steps: int, max_steps: int | None = None, gradient_accumulation_steps: int | None = None, checkpoint_root_directory: str | None = None, checkpointing_options: CheckpointManagerOptions | None = None, metrics_logging_options: MetricsLoggerOptions | None = None, profiler_options: ProfilerOptions | None = None, perf_metrics_options: PerfMetricsOptions | None = None, data_sharding_axis: Tuple[str, ...] = ('fsdp',), max_inflight_computations: int = 2, metrics_prefix: str = '', pbar_description: str | None = 'Training', max_seq_token_per_tpu: int | None = None, actor_optimizer: GradientTransformation, critic_optimizer: GradientTransformation | None = None, mini_batch_size: int | None = None, train_micro_batch_size: int | None = None, rollout_micro_batch_size: int | None = None, compute_logps_micro_batch_size: int | None = None)#
RLTraining config.
- actor_optimizer#
Optimizer for the actor model.
- critic_optimizer#
Optimizer for the critic model. If None, the critic model will be trained with the same optimizer as the actor model.
- Type:
GradientTransformation | None
- mini_batch_size#
The mini-batch size used for policy weight updates. One mini-batch corresponds to one optimizer update. The global batch size must be divisible by mini_batch_size.
- Type:
int | None
- train_micro_batch_size#
The micro-batch size used for gradient accumulation at training time. mini_batch_size must be divisible by train_micro_batch_size.
- Type:
int | None
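The relationship between the three batch sizes can be sketched as a small validation helper. It assumes that a mini-batch is an even slice of the global batch and that a micro-batch is an even slice of a mini-batch; `plan_updates` is a hypothetical function, not part of the library.

```python
from typing import Dict


def plan_updates(
    global_batch_size: int, mini_batch_size: int, train_micro_batch_size: int
) -> Dict[str, int]:
    """Derive update counts from the batch-size hierarchy, validating divisibility."""
    if global_batch_size % mini_batch_size != 0:
        raise ValueError("global batch size must be a multiple of mini_batch_size")
    if mini_batch_size % train_micro_batch_size != 0:
        raise ValueError("mini_batch_size must be a multiple of train_micro_batch_size")
    return {
        # One optimizer update per mini-batch.
        "optimizer_updates": global_batch_size // mini_batch_size,
        # Micro-batches accumulated before each optimizer update.
        "gradient_accumulation_steps": mini_batch_size // train_micro_batch_size,
    }


plan = plan_updates(global_batch_size=64, mini_batch_size=16, train_micro_batch_size=8)
print(plan)
```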
- rollout_micro_batch_size#
The micro-batch size used for model rollouts.
- Type:
int | None
- compute_logps_micro_batch_size#
The micro-batch size used for computing log probabilities (e.g. for reference and old policy models).
- Type:
int | None
- actor_optimizer: GradientTransformation#
- compute_logps_micro_batch_size: int | None#
- critic_optimizer: GradientTransformation | None#
- mini_batch_size: int | None#
- rollout_micro_batch_size: int | None#
- train_micro_batch_size: int | None#
- class tunix.Role(*values)#
Role of the model.
- ACTOR = 'actor'#
- CRITIC = 'critic'#
- REFERENCE = 'reference'#
- REWARD = 'reward'#
- ROLLOUT = 'rollout'#
- class tunix.RolloutConfig(max_tokens_to_generate: int = 64, temperature: float = 0.9, top_p: float | None = 1.0, top_k: int | None = None, seed: Array | None = None, max_prompt_length: int = 64, kv_cache_size: int = 1024, data_type: dtype | None = None, eos_tokens: list[int] | None = None, rollout_mapping_config: MappingConfig | None = None, tensor_parallel_size: int = -1, data_parallel_size: int = -1, expert_parallel_size: int = 1, return_logprobs: bool = False, rollout_vllm_server_mode: bool = False, rollout_vllm_model_version: str = '', rollout_vllm_lora_config: dict[str, Any] | None = None, rollout_vllm_hbm_utilization: float = 0.2, rollout_vllm_init_with_random_weights: bool = True, rollout_vllm_tpu_backend_type: str | None = None, rollout_vllm_async_scheduling: bool = False, rollout_vllm_logprobs_mode: str = 'processed_logprobs', rollout_vllm_hf_config_path: str | None = None, rollout_vllm_additional_config: dict[str, Any] | None = None, rollout_vllm_enable_dp_attention: bool = False, rollout_vllm_delete_dst_buffers: bool = True, rollout_vllm_max_num_batched_tokens: int | None = None, rollout_vllm_max_num_seqs: int | None = None, rollout_vllm_reshard_chunk_size: int | None = None, rollout_vllm_kwargs: dict[str, Any] = <factory>, rollout_vllm_sampling_kwargs: dict[str, Any] = <factory>, rollout_sglang_jax_model_version: str = '', rollout_sglang_jax_context_length: int | None = None, rollout_sglang_jax_mem_fraction_static: float = 0.2, rollout_sglang_jax_init_with_random_weights: bool = True, rollout_sglang_jax_disable_radix_cache: bool = True, rollout_sglang_jax_enable_deterministic_sampling: bool = False, rollout_sglang_jax_use_sort_for_toppk_minp: bool = True, rollout_sglang_jax_enable_static_lora: bool = False, rollout_sglang_jax_enable_single_process: bool = True, rollout_sglang_jax_lora_target_modules: List[str] | None = None, rollout_sglang_jax_max_lora_rank: int | None = None, rollout_sglang_jax_lora_scaling: float | None = None, rollout_sglang_jax_precompile_bs_paddings: List[int] | None = None, rollout_sglang_jax_precompile_token_paddings: List[int] | None = None, rollout_sglang_jax_chunked_prefill_size: int | None = -1, rollout_sglang_jax_page_size: int = 128, rollout_sglang_jax_load_format: str = 'auto', rollout_sglang_jax_max_running_requests: int | None = None, rollout_sglang_jax_log_level: str | None = 'info', rollout_sglang_jax_kwargs: dict[str, Any] = <factory>)#
Configuration for the rollout worker.
Fields should map to a subset of the vLLM sampling parameters: https://docs.vllm.ai/en/v0.6.4/dev/sampling_params.html
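How the `temperature`, `top_k`, and `top_p` fields typically shape the next-token distribution can be sketched in plain Python. This is a generic illustration of those sampling knobs, not the rollout engine's sampler; the order of operations (temperature, then top-k, then top-p on the renormalized top-k set) and the function name are assumptions.

```python
import math
from typing import List, Optional


def sampling_distribution(
    logits: List[float],
    temperature: float = 0.9,
    top_k: Optional[int] = None,
    top_p: Optional[float] = 1.0,
) -> List[float]:
    """Return the filtered, renormalized next-token probabilities."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    weights = [math.exp(x - m) for x in scaled]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Token indices ordered from most to least likely.
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k is not None:
        keep = keep[:top_k]
    if top_p is not None and top_p < 1.0:
        mass = sum(probs[i] for i in keep)
        cumulative, nucleus = 0.0, []
        for i in keep:
            nucleus.append(i)
            cumulative += probs[i] / mass
            if cumulative >= top_p:
                break
        keep = nucleus

    kept = set(keep)
    mass = sum(probs[i] for i in kept)
    return [probs[i] / mass if i in kept else 0.0 for i in range(len(probs))]


dist = sampling_distribution([2.0, 1.0, 0.0, -1.0], temperature=1.0, top_k=2)
print(dist)
```

With the defaults listed in the signature (`top_k=None`, `top_p=1.0`), no filtering is applied in this sketch and only the temperature changes the distribution.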
- data_parallel_size: int = -1#
- eos_tokens: list[int] | None = None#
- expert_parallel_size: int = 1#
- kv_cache_size: int = 1024#
- max_prompt_length: int = 64#
- max_tokens_to_generate: int = 64#
- return_logprobs: bool = False#
- rollout_mapping_config: MappingConfig | None = None#
- rollout_sglang_jax_chunked_prefill_size: int | None = -1#
- rollout_sglang_jax_context_length: int | None = None#
- rollout_sglang_jax_disable_radix_cache: bool = True#
- rollout_sglang_jax_enable_deterministic_sampling: bool = False#
- rollout_sglang_jax_enable_single_process: bool = True#
- rollout_sglang_jax_enable_static_lora: bool = False#
- rollout_sglang_jax_init_with_random_weights: bool = True#
- rollout_sglang_jax_kwargs: dict[str, Any]#
- rollout_sglang_jax_load_format: str = 'auto'#
- rollout_sglang_jax_log_level: str | None = 'info'#
- rollout_sglang_jax_lora_scaling: float | None = None#
- rollout_sglang_jax_lora_target_modules: List[str] | None = None#
- rollout_sglang_jax_max_lora_rank: int | None = None#
- rollout_sglang_jax_max_running_requests: int | None = None#
- rollout_sglang_jax_mem_fraction_static: float = 0.2#
- rollout_sglang_jax_model_version: str = ''#
- rollout_sglang_jax_page_size: int = 128#
- rollout_sglang_jax_precompile_bs_paddings: List[int] | None = None#
- rollout_sglang_jax_precompile_token_paddings: List[int] | None = None#
- rollout_sglang_jax_use_sort_for_toppk_minp: bool = True#
- rollout_vllm_additional_config: dict[str, Any] | None = None#
- rollout_vllm_async_scheduling: bool = False#
- rollout_vllm_delete_dst_buffers: bool = True#
- rollout_vllm_enable_dp_attention: bool = False#
- rollout_vllm_hbm_utilization: float = 0.2#
- rollout_vllm_hf_config_path: str | None = None#
- rollout_vllm_init_with_random_weights: bool = True#
- rollout_vllm_kwargs: dict[str, Any]#
- rollout_vllm_logprobs_mode: str = 'processed_logprobs'#
- rollout_vllm_lora_config: dict[str, Any] | None = None#
- rollout_vllm_max_num_batched_tokens: int | None = None#
- rollout_vllm_max_num_seqs: int | None = None#
- rollout_vllm_model_version: str = ''#
- rollout_vllm_reshard_chunk_size: int | None = None#
- rollout_vllm_sampling_kwargs: dict[str, Any]#
- rollout_vllm_server_mode: bool = False#
- rollout_vllm_tpu_backend_type: str | None = None#
- temperature: float = 0.9#
- tensor_parallel_size: int = -1#
- top_k: int | None = None#
- top_p: float | None = 1.0#