Reinforcement learning (RL)


GRPOConfig(*[, algo_variant, ...])

Configuration for GRPO algorithms.

GRPOLearner(rl_cluster, algo_config, reward_fns)

GRPO (Group Relative Policy Optimization) learner.

RewardFn

PPOConfig(*[, algo_variant, ...])

Configuration for PPO learner.

PPOLearner(rl_cluster, ppo_config[, ...])

PPO (Proximal Policy Optimization) learner.

ClusterConfig(*, role_to_mesh[, ...])

Cluster config.

RLCluster(*, actor[, critic, reference, ...])

RLCluster.

RLTrainingConfig(*, eval_every_n_steps[, ...])

RLTraining config.

Role(*values)

Role of the model.

RolloutConfig(max_tokens_to_generate, ...)

Configuration for the rollout worker.


class tunix.GRPOConfig(*, algo_variant: str = 'grpo', advantage_estimator: str = 'grpo', policy_loss_fn: str = 'grpo', reward_manager: str = 'sequence-level', kl_clamp_value: float | None = None, loss_agg_mode: str = 'sequence-mean-token-mean', loss_algo: str = 'grpo', num_generations: int = 2, num_iterations: int = 1, beta: float = 0.04, kl_loss_mode: str = 'kl', epsilon: float = 0.2)#

Configuration for GRPO algorithms.

algo_variant#

The algorithm variant to use. Default: grpo.

Type:

str

advantage_estimator#

The advantage estimator to use. Default: grpo.

Type:

str

policy_loss_fn#

The policy loss function to use. Default: grpo.

Type:

str

loss_agg_mode#

The aggregation mode for the loss function. Supported values include token-mean, sequence-mean-token-mean, sequence-mean-token-scale, seq-mean-token-sum, and sequence-mean-token-sum-norm. Default: sequence-mean-token-mean.

Type:

str
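To make the difference between aggregation modes concrete, the sketch below contrasts token-mean with sequence-mean-token-mean on a toy batch. The helper names and the exact definitions are assumptions inferred from the mode names; they are not taken from the Tunix implementation.

```python
def token_mean(losses, mask):
    """Average over every unmasked token in the whole batch."""
    total = sum(l * m for row_l, row_m in zip(losses, mask)
                for l, m in zip(row_l, row_m))
    count = sum(m for row in mask for m in row)
    return total / count


def sequence_mean_token_mean(losses, mask):
    """Average tokens within each sequence, then average across sequences."""
    per_sequence = []
    for row_l, row_m in zip(losses, mask):
        n_tokens = sum(row_m)
        per_sequence.append(sum(l * m for l, m in zip(row_l, row_m)) / n_tokens)
    return sum(per_sequence) / len(per_sequence)


# Two responses: one with four valid tokens, one with a single valid token.
losses = [[1.0, 1.0, 1.0, 1.0], [3.0, 0.0, 0.0, 0.0]]
mask = [[1, 1, 1, 1], [1, 0, 0, 0]]

tm = token_mean(losses, mask)                  # long responses dominate
smtm = sequence_mean_token_mean(losses, mask)  # each response counts equally
```

Under these assumed definitions, token-mean weights long responses more heavily, while sequence-mean-token-mean gives every response the same weight regardless of length.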

reward_manager#

The reward manager to use. Default: sequence-level.

Type:

str

loss_algo#

The loss algorithm to use. To be deprecated.

Type:

str

num_generations#

The number of responses the policy generates for each prompt within a single training step. This corresponds to ‘G’ in Algorithm 1 in the paper. A higher value means more samples are used to compute relative advantages.

Type:

int
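As an illustration of how the G responses for one prompt become relative advantages, here is a minimal sketch that normalizes each reward by the group mean and standard deviation. The function name, the use of the population standard deviation, and the eps stabilizer are assumptions for illustration only.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize the rewards of one prompt's G responses within the group."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # assumption: population std
    # eps (hypothetical) avoids division by zero when all rewards are equal.
    return [(r - mean) / (std + eps) for r in rewards]


# G = 4 responses to the same prompt; two were judged correct.
rewards = [1.0, 0.0, 0.0, 1.0]
advantages = group_relative_advantages(rewards)
```

Responses that score above the group mean receive positive advantages and those below receive negative ones, which is why no separate value model is needed.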

num_iterations#

The number of policy update iterations per batch (𝜇 in Algorithm 1 of the GRPO paper).

Type:

int

beta#

The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function. This term prevents policy updates from deviating too far from the reference model. A value of 0.0 means no KL penalty is applied.

Type:

float

kl_loss_mode#

The divergence mode used for KL penalty estimation. Default: kl.

Type:

str

epsilon#

Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, it ensures stable updates.

Type:

float
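A minimal per-token sketch of how epsilon and beta could enter a GRPO-style loss, assuming the clipped surrogate and the KL estimator given in the GRPO paper. The function name and arguments are illustrative; the configured policy_loss_fn or kl_loss_mode may select a different formulation in Tunix.

```python
import math


def grpo_token_loss(logp, old_logp, ref_logp, advantage,
                    epsilon=0.2, beta=0.04):
    """Loss for a single token (to be minimized)."""
    # Probability ratio between the current and the old policy.
    ratio = math.exp(logp - old_logp)
    clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
    surrogate = min(ratio * advantage, clipped_ratio * advantage)
    # KL estimator from the GRPO paper:
    # pi_ref / pi_theta - log(pi_ref / pi_theta) - 1, which is always >= 0.
    diff = ref_logp - logp
    kl = math.exp(diff) - diff - 1.0
    return -surrogate + beta * kl


# The ratio is 2.0 here, so the surrogate is clipped at 1 + epsilon = 1.2.
loss = grpo_token_loss(logp=math.log(2.0), old_logp=0.0,
                       ref_logp=math.log(2.0), advantage=1.0)
```

With beta set to 0.0 the KL term vanishes and only the clipped surrogate remains.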

epsilon_high#

Epsilon value for upper bound clipping.

loss_algo#

Whether to use GRPO or GSPO for loss computation. The GRPO loss is normalized per batch rather than per response as described in the paper. For GSPO, the gspo-token loss is used because it is more flexible.

Type:

str

advantage_estimator: str#
algo_variant: str#
beta: float = 0.04#
epsilon: float = 0.2#
kl_clamp_value: float | None#
kl_loss_mode: str = 'kl'#
loss_agg_mode: str = 'sequence-mean-token-mean'#
loss_algo: str = 'grpo'#
num_generations: int = 2#
num_iterations: int = 1#
policy_loss_fn: str#
reward_manager: str#

class tunix.GRPOLearner(rl_cluster: RLCluster, algo_config: TGrpoConfig, reward_fns: Callable[[...], List[float]] | List[Callable[[...], List[float]]], metric_fns: Sequence[Callable[[...], Dict[str, Tuple[Array | ndarray | bool | number | bool | int | float | complex | str, Callable[[Array | ndarray | bool | number | bool | int | float | complex], Array | ndarray | bool | number | bool | int | float | complex] | None]]]] | None = None, data_shuffle_seed: int | None = None)#

GRPO (Group Relative Policy Optimization) learner.

GRPO is a reinforcement learning algorithm designed to enhance the reasoning abilities of large language models, such as mathematical problem-solving. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by generating multiple responses for a given prompt, evaluating these responses using a reward model, and then calculating a relative advantage based on the group’s performance to update the policy.

train(train_ds: Iterable[Dict[str, List[str] | Array | ndarray | bool | number | bool | int | float | complex]], eval_ds: Iterable[Dict[str, List[str] | Array | ndarray | bool | number | bool | int | float | complex]] | None = None, skip_jit: bool = False) None#

GRPO training loop.

The algorithm below is extracted from https://arxiv.org/abs/2402.03300:

Input:
    initial policy model π_θinit;
    reward models r_φ;
    task prompts D;
    hyperparameters ε, β, μ

policy model π_θ ← π_θinit
for iteration = 1, ..., I do
  reference model π_ref ← π_θ
  for step = 1, ..., M do
    Sample a batch D_b from D
    Update the old policy model π_θold ← π_θ
    Sample G outputs {o_i}_{i=1}^G ~ π_θold(· | q) for each question q ∈ D_b
    Compute rewards {r_i}_{i=1}^G for each sampled output o_i by running r_φ
    Compute Â_{i,t} for the t-th token of o_i through group relative
    advantage estimation.
    for GRPO iteration = 1, ..., μ do
      Update the policy model π_θ by maximizing the GRPO objective
      (Equation 21)
  Update r_φ through continuous training using a replay mechanism.
Output π_θ

Note

  1. The outer loop (I) is currently ignored because the reference model is never updated.

  2. Currently, the sampler and the trainer hold the same reference to the model, so the step that updates the sampler model is also omitted.

Parameters:
  • train_ds – An iterable of training input data, where each element is a dictionary containing the key ‘prompts’.

  • eval_ds – An iterable of evaluation input data, where each element is a dictionary containing the key ‘prompts’.

  • skip_jit – Whether to skip JIT compilation of the training loop.
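The structure of the training loop described above can be sketched on a toy problem. The example below is not the Tunix implementation: it assumes a two-answer softmax "policy", a hand-written reward, no KL term (beta = 0), and manual gradients of the clipped surrogate. All names are hypothetical.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def grpo_toy_train(num_steps=20, num_generations=8, num_iterations=2,
                   epsilon=0.2, learning_rate=0.5, seed=0):
    rng = random.Random(seed)
    theta = [0.0, 0.0]  # policy logits over two candidate answers
    for _ in range(num_steps):
        old_probs = softmax(theta)  # pi_old <- pi_theta
        # Sample G outputs from the old policy for a single prompt.
        outputs = rng.choices([0, 1], weights=old_probs, k=num_generations)
        rewards = [float(o == 1) for o in outputs]  # toy reward: answer 1 wins
        # Group relative advantage estimation.
        mean = sum(rewards) / len(rewards)
        std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / len(rewards))
        advantages = [(r - mean) / (std + 1e-6) for r in rewards]
        for _ in range(num_iterations):  # mu updates per batch
            probs = softmax(theta)
            grad = [0.0, 0.0]
            for o, adv in zip(outputs, advantages):
                ratio = probs[o] / old_probs[o]
                # The clipped branch is constant, so it contributes no gradient.
                if (adv > 0 and ratio > 1 + epsilon) or \
                   (adv < 0 and ratio < 1 - epsilon):
                    continue
                for k in range(2):
                    indicator = 1.0 if k == o else 0.0
                    # d(ratio)/d(theta_k) = ratio * (1[k == o] - pi(k))
                    grad[k] += adv * ratio * (indicator - probs[k])
            # Gradient ascent on the mean clipped surrogate.
            theta = [t + learning_rate * g / num_generations
                     for t, g in zip(theta, grad)]
    return softmax(theta)


final_probs = grpo_toy_train()
```

In this toy setting every mixed group pushes probability mass toward the rewarded answer, and groups with identical rewards produce zero advantages and therefore no update.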


tunix.RewardFn#

alias of Callable[[…], List[float]]
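For illustration, a reward function matching this alias might look like the sketch below. The keyword names prompts, completions, and answers are assumptions for this example; the alias itself only requires a callable that returns one float per sample.

```python
from typing import List


def exact_match_reward(prompts: List[str], completions: List[str],
                       answers: List[str], **kwargs) -> List[float]:
    """Return 1.0 when a completion matches its reference answer, else 0.0."""
    return [
        1.0 if completion.strip() == answer.strip() else 0.0
        for completion, answer in zip(completions, answers)
    ]


scores = exact_match_reward(
    prompts=["2+2=", "3+3="],
    completions=[" 4", "7"],
    answers=["4", "6"],
)
```

Accepting **kwargs lets the same function tolerate extra dataset columns that it does not use.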


class tunix.PPOConfig(*, algo_variant: str = 'ppo', advantage_estimator: str = 'gae', policy_loss_fn: str = 'ppo', reward_manager: str = 'sequence-level', kl_clamp_value: float | None = None, value_loss_fn: str = 'ppo', num_iterations: int = 1, gamma: float = 1.0, gae_lambda: float = 0.95, beta: float = 0.04, epsilon: float = 0.2, epsilon_low: float | None = None, epsilon_high: float | None = None, epsilon_c: float | None = None, entropy_coef: float | None = None, clip_range_value: float = 0.2, kl_method: str = 'low_var_kl')#

Configuration for PPO learner.

algo_variant#

The algorithm variant to use. Default: ppo.

Type:

str

advantage_estimator#

The advantage estimator to use. Default: gae.

Type:

str

policy_loss_fn#

The policy loss function to use. Default: ppo.

Type:

str

reward_manager#

The reward manager to use. Default: sequence-level.

Type:

str

num_iterations#

The number of optimization epochs per batch of rollouts. This corresponds to the number of times the policy updates its weights for a given batch of rollouts.

Type:

int

mini_batch_size#

The batch size on which the actual model updates happen. The rollout phase (generate_and_compute_advantages) happens on a larger batch, which is then split into “mini-batches”.

gamma#

The discount factor for future rewards in GAE.

Type:

float

gae_lambda#

The lambda parameter for Generalized Advantage Estimation (GAE).

Type:

float
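To show how gamma and gae_lambda interact, here is a minimal sketch of Generalized Advantage Estimation over one trajectory. It assumes the value after the final step is zero; the function name is illustrative and not part of Tunix.

```python
def compute_gae(rewards, values, gamma=1.0, gae_lambda=0.95):
    """Compute GAE advantages for one trajectory, scanning backwards."""
    advantages = [0.0] * len(rewards)
    next_value = 0.0  # assumption: the episode terminates after the last step
    running = 0.0
    for t in reversed(range(len(rewards))):
        # One-step TD error.
        delta = rewards[t] + gamma * next_value - values[t]
        # Exponentially weighted sum of future TD errors.
        running = delta + gamma * gae_lambda * running
        advantages[t] = running
        next_value = values[t]
    return advantages


# A sparse reward at the final token, with a critic that predicts zero.
gae = compute_gae([0.0, 0.0, 1.0], [0.0, 0.0, 0.0], gamma=1.0, gae_lambda=0.5)
```

A gae_lambda close to 1 propagates the final reward further back (lower bias, higher variance), while a smaller value relies more on the critic's estimates.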

beta#

The coefficient for the KL divergence penalty.

Type:

float

epsilon#

Epsilon value for clipping the ratio for the policy objective.

Type:

float

epsilon_low#

Lower bound for clipping the ratio for the policy objective. Set to epsilon if not provided.

Type:

float | None

epsilon_high#

Upper bound for clipping the ratio for the policy objective. Set to epsilon if not provided.

Type:

float | None

epsilon_c#

Lower bound for clipping for dual-clip PPO. If not provided, we don’t do dual-clip PPO. Reference: https://arxiv.org/abs/1912.09729.

Type:

float | None
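The way epsilon, epsilon_low, epsilon_high, and epsilon_c could combine is sketched below for a single token. This assumes the standard clipped surrogate plus the dual-clip rule from the referenced paper; the function is hypothetical and only mirrors the fallback behaviour described in the attribute docs.

```python
def ppo_policy_objective(ratio, advantage, epsilon=0.2,
                         epsilon_low=None, epsilon_high=None, epsilon_c=None):
    """Per-token objective (to be maximized)."""
    # Bounds fall back to epsilon when not provided.
    low = epsilon if epsilon_low is None else epsilon_low
    high = epsilon if epsilon_high is None else epsilon_high
    clipped_ratio = min(max(ratio, 1.0 - low), 1.0 + high)
    objective = min(ratio * advantage, clipped_ratio * advantage)
    # Dual clip: for negative advantages, bound the objective from below.
    if epsilon_c is not None and advantage < 0:
        objective = max(objective, epsilon_c * advantage)
    return objective


plain = ppo_policy_objective(ratio=5.0, advantage=-1.0)
dual = ppo_policy_objective(ratio=5.0, advantage=-1.0, epsilon_c=3.0)
wide = ppo_policy_objective(ratio=2.0, advantage=1.0, epsilon_high=0.5)
```

Without dual clipping, a very large ratio with a negative advantage yields an unbounded penalty; epsilon_c caps it.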

entropy_coef#

Entropy coefficient for the policy loss. Set to None or 0.0 to disable entropy regularization.

Type:

float | None

clip_range_value#

The range for clipping the value function loss.

Type:

float
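A minimal sketch of value-function clipping as commonly used in PPO implementations, shown for a single prediction. Whether Tunix uses exactly this form (including the 0.5 factor) is an assumption; the function name is illustrative.

```python
def clipped_value_loss(value, old_value, returns, clip_range_value=0.2):
    """Value loss where the new prediction may move at most clip_range_value."""
    delta = min(max(value - old_value, -clip_range_value), clip_range_value)
    value_clipped = old_value + delta
    unclipped_loss = (value - returns) ** 2
    clipped_loss = (value_clipped - returns) ** 2
    # Taking the maximum keeps the more pessimistic of the two estimates.
    return 0.5 * max(unclipped_loss, clipped_loss)


# The new value jumped from 0.0 to 1.0, well beyond the clip range.
value_loss = clipped_value_loss(value=1.0, old_value=0.0, returns=1.0)
```

Because the maximum is taken, a prediction that moves far from the old value is not rewarded for doing so within a single update.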

kl_method#

The method for computing KL divergence. Must be one of ["low_var_kl", "kl", "mse_kl"].

Type:

str
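The three method names resemble the widely used k1, k2, and k3 per-token KL estimators. The sketch below shows those estimators under that assumption; the mapping from names to formulas is a guess based on naming and may not match the Tunix implementation.

```python
import math


def kl_penalty(logp, ref_logp, kl_method="low_var_kl"):
    """Per-token KL estimate between the policy and the reference model."""
    diff = logp - ref_logp
    if kl_method == "kl":
        # k1: unbiased, but can be negative for individual tokens.
        return diff
    if kl_method == "mse_kl":
        # k2: half the squared log-ratio, always non-negative.
        return 0.5 * diff * diff
    if kl_method == "low_var_kl":
        # k3: exp(ref - logp) - (ref - logp) - 1, non-negative, low variance.
        return math.exp(-diff) + diff - 1.0
    raise ValueError(f"Unknown kl_method: {kl_method!r}")


estimates = {m: kl_penalty(-1.0, -1.5, m) for m in ("kl", "mse_kl", "low_var_kl")}
```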

beta: float#
clip_range_value: float#
entropy_coef: float | None#
epsilon: float#
epsilon_c: float | None#
epsilon_high: float | None#
epsilon_low: float | None#
gae_lambda: float#
gamma: float#
kl_method: str#
num_iterations: int#
value_loss_fn: str#

class tunix.PPOLearner(rl_cluster: RLCluster, ppo_config: PPOConfig, reward_fns: Callable[[...], List[float]] | List[Callable[[...], List[float]]] | None = None, metric_fns: Sequence[Callable[[...], Dict[str, Tuple[Array | ndarray | bool | number | bool | int | float | complex | str, Callable[[Array | ndarray | bool | number | bool | int | float | complex], Array | ndarray | bool | number | bool | int | float | complex] | None]]]] | None = None, data_shuffle_seed: int | None = None)#

PPO (Proximal Policy Optimization) learner.

PPO is a reinforcement learning algorithm that fine-tunes models using an actor-critic architecture. It optimizes a clipped surrogate objective function to ensure stable policy updates, preventing large, destructive changes. The actor (policy model) learns what actions to take, while the critic (value model) estimates the value of states to help calculate advantages. This approach balances exploration and exploitation, making it a robust choice for a wide range of RL tasks.

References: https://arxiv.org/abs/1707.06347

train(train_ds: Iterable[Dict[str, List[str] | Array | ndarray | bool | number | bool | int | float | complex]], eval_ds: Iterable[Dict[str, List[str] | Array | ndarray | bool | number | bool | int | float | complex]] | None = None, skip_jit: bool = False) None#

PPO training loop.


class tunix.ClusterConfig(*, role_to_mesh: dict[Role, Mesh], role_to_logical_axis_rule: dict[Role, Sequence[tuple[str, str | tuple[str, ...] | None]]] | None = None, rollout_engine: str | type[BaseRollout] = 'vanilla', offload_to_cpu: bool = False, training_config: RLTrainingConfig, rollout_config: dict[Mode, RolloutConfig] | RolloutConfig)#

Cluster config.

role_to_mesh#

Mapping from model role to mesh. Key config for colocated vs disaggregated setup.

Type:

dict[tunix.rl.rl_cluster.Role, jax._src.mesh.Mesh]

role_to_logical_axis_rule#

Mapping from model role to logical axis rules. This is used when models are sharded with logical axes and expect a logical-to-physical axis mapping at runtime.

Type:

dict[tunix.rl.rl_cluster.Role, collections.abc.Sequence[tuple[str, str | tuple[str, …] | None]]] | None

rollout_engine#

Rollout engine to use. E.g. “vanilla”, “vllm”, “sglang_jax”. Alternatively, if a subclass of base_rollout.BaseRollout is provided, it will be used as the rollout engine.

Type:

str | type[tunix.rl.rollout.base_rollout.BaseRollout]

offload_to_cpu#

Whether to offload models to CPU at each step.

Type:

bool

training_config#

RL training config.

Type:

tunix.rl.rl_cluster.RLTrainingConfig

rollout_config#

Rollout config. It may be different for different modes, e.g. TRAIN vs EVAL.

Type:

dict[tunix.rl.rl_cluster.Mode, tunix.rl.rollout.base_rollout.RolloutConfig] | tunix.rl.rollout.base_rollout.RolloutConfig

rollout_vllm_model_version#

Model version for vllm rollout engine.

rollout_vllm_lora_config#

LoRA config for vllm rollout engine.

rollout_vllm_hbm_utilization#

The fraction of TPU/GPU HBM allocated to the vllm rollout engine.

rollout_vllm_init_with_random_weights#

Init the vllm TPU backend model with random weights instead of loading from the given path.

rollout_vllm_tpu_backend_type#

The TPU Jax backend type for the vllm rollout engine, e.g. “jax”, “torchax” or “pytorch_xla”.

offload_to_cpu: bool = False#
role_to_logical_axis_rule: dict[Role, Sequence[tuple[str, str | tuple[str, ...] | None]]] | None = None#
role_to_mesh: dict[Role, Mesh]#
rollout_config: dict[Mode, RolloutConfig] | RolloutConfig#
rollout_engine: str | type[BaseRollout] = 'vanilla'#
training_config: RLTrainingConfig#

class tunix.RLCluster(*, actor: Module | str, critic: Module | str | None = None, reference: Module | str | None = None, reward: Module | str | None = None, tokenizer: Any | None, cluster_config: ClusterConfig, perf_config: PerfMetricsConfig | None = None)#

RLCluster.

property actor_trainer: Trainer#
buffer_metrics(metrics: Dict[str, Tuple[Array | ndarray | bool | number | bool | int | float | complex | str, Callable[[Array | ndarray | bool | number | bool | int | float | complex], Array | ndarray | bool | number | bool | int | float | complex] | None]], mode: Mode = Mode.TRAIN) None#

Buffers rl metrics to be logged.

Actual logging will happen when global steps are incremented.

Parameters:
  • metrics – A dictionary mapping metric names to a tuple containing the metric value and an optional aggregation function.

  • mode – The mode of the workload, either TRAIN or EVAL.
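To illustrate the shape of the metrics argument, the sketch below buffers (value, aggregation function) tuples and aggregates them on flush. ToyMetricsBuffer is a hypothetical stand-in, not the Tunix class, and reporting the latest value when no aggregation function is given is an assumption.

```python
from statistics import fmean


class ToyMetricsBuffer:
    """Hypothetical buffer that aggregates metrics when flushed."""

    def __init__(self):
        self._values = {}
        self._agg_fns = {}

    def buffer(self, metrics):
        # metrics maps a name to (value, optional aggregation function).
        for name, (value, agg_fn) in metrics.items():
            self._values.setdefault(name, []).append(value)
            self._agg_fns[name] = agg_fn

    def flush(self):
        # Assumption: without an aggregation function, keep the latest value.
        out = {}
        for name, values in self._values.items():
            agg_fn = self._agg_fns[name]
            out[name] = agg_fn(values) if agg_fn is not None else values[-1]
        self._values.clear()
        self._agg_fns.clear()
        return out


buf = ToyMetricsBuffer()
buf.buffer({"rewards/mean": (0.25, fmean), "lr": (1e-5, None)})
buf.buffer({"rewards/mean": (0.75, fmean), "lr": (2e-5, None)})
logged = buf.flush()
```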

buffer_metrics_async(metrics: Dict[str, Tuple[Array | ndarray | bool | number | bool | int | float | complex | str, Callable[[Array | ndarray | bool | number | bool | int | float | complex], Array | ndarray | bool | number | bool | int | float | complex] | None]], mode: Mode = Mode.TRAIN, step: int = 0) None#

Buffers rl metrics to be logged for async training.

Actual logging will happen when global steps are incremented.

Parameters:
  • metrics – A dictionary mapping metric names to a tuple containing the metric value and an optional aggregation function.

  • mode – The mode of the workload, either TRAIN or EVAL.

  • step – The step number for the metrics. Only used in TRAIN mode.

close()#
property critic_trainer: Trainer#
generate(prompts: list[str] | list[list[dict[str, str]]], apply_chat_template: bool = False, mode: Mode = Mode.TRAIN, micro_batch_size: int | None = None, trace_tags: Mapping[str, Any] | None = None, max_generation_steps: int | None = None) RolloutOutput#

Generates text from the given prompts.

Parameters:
  • prompts – A list of prompts to generate text from. If apply_chat_template is True, this should be a list of conversations (each a list of dictionaries with ‘role’ and ‘content’). Otherwise, it should be a list of strings.

  • apply_chat_template – Whether to apply chat template to the prompts.

  • mode – The mode of rollout, either TRAIN or EVAL.

  • micro_batch_size – The micro-batch size for generation. If None, no micro-batching is performed.

  • trace_tags – Optional tags to add to the performance tracer.

Returns:

A RolloutOutput object containing the generated text and other info.

get_actor_per_token_logps(prompt_tokens: Array, completion_tokens: Array, pad_id: int, eos_id: int, micro_batch_size: int | None = None, temperature: float | None = None) Array#

Gets per-token logps from the actor model on the trainer side.

Mirrors get_ref_per_token_logps. The rollout temperature must be passed through so that the actor’s recomputed logps match the temperature scaling used at sampling time. Otherwise, comparing log_softmax(logits/T_sample) with log_softmax(logits) yields an artifact difference of several nats relative to vllm’s processed_logprobs.
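The size of this mismatch can be seen in a small standalone example. The sketch below compares log-probabilities computed with and without temperature scaling for the same logits; the logits and the temperature are made-up values chosen for illustration.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_sum_exp for x in scaled]


logits = [4.0, 1.0, 0.0]
sampled = log_softmax(logits, temperature=0.5)  # what the sampler used
unscaled = log_softmax(logits)                  # recomputed without temperature
gap = [abs(a - b) for a, b in zip(sampled, unscaled)]
```

For low-probability tokens the gap exceeds one nat, which would distort importance ratios if the two sides used different temperatures.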

get_old_per_token_logps(prompt_tokens: Array, completion_tokens: Array, micro_batch_size: int | None = None) Array#

Gets the per-token logps of the current policy model.

get_ref_per_token_logps(prompt_tokens: Array, completion_tokens: Array, pad_id: int, eos_id: int, micro_batch_size: int | None = None) Array#

Gets the per-token logps of the reference model.

get_rewards(prompt_tokens: Array, completion_tokens: Array, pad_id: int, eos_id: int) Array#
get_rollout_config(mode: Mode) RolloutConfig#

Returns the rollout config for the given mode.

get_values(prompt_tokens: Array, completion_tokens: Array, pad_id: int, eos_id: int) Array#
property inference_worker: InferenceWorker#
property perf: PerfTracer | NoopTracer#

The v1 performance tracer.

property perf_v2: PerfTracer | NoopTracer#

The v2 performance tracer.

property rollout: BaseRollout#
sync_weights()#

Syncs the weights between the sampler model and the trainer model.

update_actor(train_ds, eval_ds, skip_jit=False)#
update_critic(train_ds, eval_ds, skip_jit=False)#
with_external_metrics_logger(external_metrics_logger: Callable[[MetricsBuffer], None])#

class tunix.RLTrainingConfig(*, eval_every_n_steps: int, max_steps: int | None = None, gradient_accumulation_steps: int | None = None, checkpoint_root_directory: str | None = None, checkpointing_options: CheckpointManagerOptions | None = None, metrics_logging_options: MetricsLoggerOptions | None = None, profiler_options: ProfilerOptions | None = None, perf_metrics_options: PerfMetricsOptions | None = None, data_sharding_axis: Tuple[str, ...] = ('fsdp',), max_inflight_computations: int = 2, metrics_prefix: str = '', pbar_description: str | None = 'Training', max_seq_token_per_tpu: int | None = None, actor_optimizer: GradientTransformation, critic_optimizer: GradientTransformation | None = None, mini_batch_size: int | None = None, train_micro_batch_size: int | None = None, rollout_micro_batch_size: int | None = None, compute_logps_micro_batch_size: int | None = None)#

RLTraining config.

actor_optimizer#

Optimizer for the actor model.

Type:

optax._src.base.GradientTransformation

critic_optimizer#

Optimizer for the critic model. If None, the critic model will be trained with the same optimizer as the actor model.

Type:

optax._src.base.GradientTransformation | None

mini_batch_size#

The mini-batch size used for policy weight updates. One mini-batch corresponds to one optimizer update. The global batch size must be divisible by mini_batch_size.

Type:

int | None

train_micro_batch_size#

The micro-batch size used for gradient accumulation at training time. mini_batch_size must be divisible by train_micro_batch_size.

Type:

int | None
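The relationship between the batch sizes can be sketched as follows, assuming the global batch is split evenly into mini-batches (one optimizer update each) and every mini-batch is split evenly into micro-batches for gradient accumulation. The helper is hypothetical and only checks the arithmetic.

```python
def plan_batches(global_batch_size, mini_batch_size, train_micro_batch_size):
    """Return (optimizer updates per global batch, accumulation steps per update)."""
    if global_batch_size % mini_batch_size != 0:
        raise ValueError("global batch size must be a multiple of mini_batch_size")
    if mini_batch_size % train_micro_batch_size != 0:
        raise ValueError(
            "mini_batch_size must be a multiple of train_micro_batch_size")
    optimizer_updates = global_batch_size // mini_batch_size
    accumulation_steps = mini_batch_size // train_micro_batch_size
    return optimizer_updates, accumulation_steps


# 64 rollouts, updated in mini-batches of 16, accumulated 4 samples at a time.
updates, accumulation = plan_batches(64, 16, 4)
```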

rollout_micro_batch_size#

The micro-batch size used for model rollouts.

Type:

int | None

compute_logps_micro_batch_size#

The micro-batch size used for computing log probabilities (e.g. for reference and old policy models).

Type:

int | None

actor_optimizer: GradientTransformation#
compute_logps_micro_batch_size: int | None#
critic_optimizer: GradientTransformation | None#
mini_batch_size: int | None#
rollout_micro_batch_size: int | None#
train_micro_batch_size: int | None#

class tunix.Role(*values)#

Role of the model.

ACTOR = 'actor'#
CRITIC = 'critic'#
REFERENCE = 'reference'#
REWARD = 'reward'#
ROLLOUT = 'rollout'#

class tunix.RolloutConfig(max_tokens_to_generate: int = 64, temperature: float = 0.9, top_p: float | None = 1.0, top_k: int | None = None, seed: Array | None = None, max_prompt_length: int = 64, kv_cache_size: int = 1024, data_type: dtype | None = None, eos_tokens: list[int] | None = None, rollout_mapping_config: MappingConfig | None = None, tensor_parallel_size: int = -1, data_parallel_size: int = -1, expert_parallel_size: int = 1, return_logprobs: bool = False, rollout_vllm_server_mode: bool = False, rollout_vllm_model_version: str = '', rollout_vllm_lora_config: dict[str, Any] | None = None, rollout_vllm_hbm_utilization: float = 0.2, rollout_vllm_init_with_random_weights: bool = True, rollout_vllm_tpu_backend_type: str | None = None, rollout_vllm_async_scheduling: bool = False, rollout_vllm_logprobs_mode: str = 'processed_logprobs', rollout_vllm_hf_config_path: str | None = None, rollout_vllm_additional_config: dict[str, Any] | None = None, rollout_vllm_enable_dp_attention: bool = False, rollout_vllm_delete_dst_buffers: bool = True, rollout_vllm_max_num_batched_tokens: int | None = None, rollout_vllm_max_num_seqs: int | None = None, rollout_vllm_reshard_chunk_size: int | None = None, rollout_vllm_kwargs: dict[str, Any] = <factory>, rollout_vllm_sampling_kwargs: dict[str, Any] = <factory>, rollout_sglang_jax_model_version: str = '', rollout_sglang_jax_context_length: int | None = None, rollout_sglang_jax_mem_fraction_static: float = 0.2, rollout_sglang_jax_init_with_random_weights: bool = True, rollout_sglang_jax_disable_radix_cache: bool = True, rollout_sglang_jax_enable_deterministic_sampling: bool = False, rollout_sglang_jax_use_sort_for_toppk_minp: bool = True, rollout_sglang_jax_enable_static_lora: bool = False, rollout_sglang_jax_enable_single_process: bool = True, rollout_sglang_jax_lora_target_modules: List[str] | None = None, rollout_sglang_jax_max_lora_rank: int | None = None, rollout_sglang_jax_lora_scaling: float | None = None, rollout_sglang_jax_precompile_bs_paddings: List[int] | None = None, rollout_sglang_jax_precompile_token_paddings: List[int] | None = None, rollout_sglang_jax_chunked_prefill_size: int | None = -1, rollout_sglang_jax_page_size: int = 128, rollout_sglang_jax_load_format: str = 'auto', rollout_sglang_jax_max_running_requests: int | None = None, rollout_sglang_jax_log_level: str | None = 'info', rollout_sglang_jax_kwargs: dict[str, Any] = <factory>)#

Configuration for the rollout worker.

Fields should map to a subset of the vLLM sampling knobs; see https://docs.vllm.ai/en/v0.6.4/dev/sampling_params.html

data_parallel_size: int = -1#
data_type: dtype | None = None#
eos_tokens: list[int] | None = None#
expert_parallel_size: int = 1#
kv_cache_size: int = 1024#
max_prompt_length: int = 64#
max_tokens_to_generate: int = 64#
return_logprobs: bool = False#
rollout_mapping_config: MappingConfig | None = None#
rollout_sglang_jax_chunked_prefill_size: int | None = -1#
rollout_sglang_jax_context_length: int | None = None#
rollout_sglang_jax_disable_radix_cache: bool = True#
rollout_sglang_jax_enable_deterministic_sampling: bool = False#
rollout_sglang_jax_enable_single_process: bool = True#
rollout_sglang_jax_enable_static_lora: bool = False#
rollout_sglang_jax_init_with_random_weights: bool = True#
rollout_sglang_jax_kwargs: dict[str, Any]#
rollout_sglang_jax_load_format: str = 'auto'#
rollout_sglang_jax_log_level: str | None = 'info'#
rollout_sglang_jax_lora_scaling: float | None = None#
rollout_sglang_jax_lora_target_modules: List[str] | None = None#
rollout_sglang_jax_max_lora_rank: int | None = None#
rollout_sglang_jax_max_running_requests: int | None = None#
rollout_sglang_jax_mem_fraction_static: float = 0.2#
rollout_sglang_jax_model_version: str = ''#
rollout_sglang_jax_page_size: int = 128#
rollout_sglang_jax_precompile_bs_paddings: List[int] | None = None#
rollout_sglang_jax_precompile_token_paddings: List[int] | None = None#
rollout_sglang_jax_use_sort_for_toppk_minp: bool = True#
rollout_vllm_additional_config: dict[str, Any] | None = None#
rollout_vllm_async_scheduling: bool = False#
rollout_vllm_delete_dst_buffers: bool = True#
rollout_vllm_enable_dp_attention: bool = False#
rollout_vllm_hbm_utilization: float = 0.2#
rollout_vllm_hf_config_path: str | None = None#
rollout_vllm_init_with_random_weights: bool = True#
rollout_vllm_kwargs: dict[str, Any]#
rollout_vllm_logprobs_mode: str = 'processed_logprobs'#
rollout_vllm_lora_config: dict[str, Any] | None = None#
rollout_vllm_max_num_batched_tokens: int | None = None#
rollout_vllm_max_num_seqs: int | None = None#
rollout_vllm_model_version: str = ''#
rollout_vllm_reshard_chunk_size: int | None = None#
rollout_vllm_sampling_kwargs: dict[str, Any]#
rollout_vllm_server_mode: bool = False#
rollout_vllm_tpu_backend_type: str | None = None#
seed: Array | None = None#
temperature: float = 0.9#
tensor_parallel_size: int = -1#
top_k: int | None = None#
top_p: float | None = 1.0#