Algorithms#
Tunix supports a range of state-of-the-art algorithms for reinforcement learning (RL) and supervised fine-tuning (SFT). Its modular design also lets you extend Tunix with custom algorithms, as described below.
Supported Algorithms#
Supervised Fine-Tuning (SFT) & Preference
Reinforcement Learning (RL)
PPO (Proximal Policy Optimization)
GRPO (Group Relative Policy Optimization)
GSPO-Token (Token-level Group Sequence Policy Optimization)
DAPO (Decoupled Clip and Dynamic sAmpling Policy Optimization)
Dr.GRPO (GRPO Done Right)
Add a New RL Algorithm#
Tunix is designed to be highly extensible. You can introduce new algorithms by
subclassing AlgorithmConfig (or its descendants) and implementing a
corresponding Learner.
The system uses a parallel inheritance pattern: extending a Configuration often requires extending a Learner to consume it.
Class Hierarchy & Interaction Diagram#
   CONFIGURATION                        LEARNER (The Engine)
  (Defines Params)                    (Orchestrates Execution)

+-------------------+               +-----------------------+
|  AlgorithmConfig  | <---(binds)-- |       RLLearner       |
+-------------------+               +-----------------------+
          ^                               ^           |
          |                               |           +---(Uses)---> [Function Registry]
    (Inheritance)                   (Inheritance)                    (Loss, Advantage, Reward)
          |                               |
+-------------------+               +-----------------------+
|    GRPOConfig     | <---(binds)-- |      GRPOLearner      |
+-------------------+               +-----------------------+
          ^                               ^
          |                               |
+-------------------+               +-----------------------+
|    DAPOConfig     | <---(binds)-- |      DAPOLearner      |
+-------------------+               +-----------------------+
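The parallel hierarchy in the diagram can be sketched with stand-in classes. This is a simplified illustration, not the Tunix implementation: the `Sketch*` classes, the `config_type` attribute, the `num_generations` field, and the `describe` method are all hypothetical names chosen for the example.

```python
import dataclasses


# Stand-in classes for illustration only; these are NOT the Tunix classes.
@dataclasses.dataclass
class SketchAlgorithmConfig:
    algo_variant: str = "base"


@dataclasses.dataclass
class SketchGRPOConfig(SketchAlgorithmConfig):
    algo_variant: str = "grpo"
    num_generations: int = 4  # hypothetical hyperparameter


class SketchRLLearner:
    # Each learner declares which config type it binds to.
    config_type = SketchAlgorithmConfig

    def __init__(self, algo_config):
        if not isinstance(algo_config, self.config_type):
            raise TypeError(
                f"{type(self).__name__} expects {self.config_type.__name__}, "
                f"got {type(algo_config).__name__}"
            )
        self.algo_config = algo_config

    def describe(self) -> str:
        return f"{type(self).__name__} running {self.algo_config.algo_variant}"


class SketchGRPOLearner(SketchRLLearner):
    # The learner subclass consumes the matching config subclass.
    config_type = SketchGRPOConfig


learner = SketchGRPOLearner(SketchGRPOConfig())
print(learner.describe())
```

In this sketch, passing a base config to the GRPO stand-in raises a `TypeError`, which mirrors the idea that a more specific learner needs the fields of its more specific config.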
1. Defining the Configuration & Learner#
To add a new algorithm, you typically define a config (to hold your params) and a learner (to use them).
Step 1: The Configuration. Inherit from AlgorithmConfig (or a specific
child like GRPOConfig if your algorithm is a variant of it). Use
__post_init__ to validate your new settings.
import dataclasses

@dataclasses.dataclass(slots=True, kw_only=True)
class MyNewAlgoConfig(AlgorithmConfig):
    # 1. Identity
    algo_variant: str = "my_new_algo"
    # 2. Components (strings that reference entries in the Function Registry)
    advantage_estimator: str = "gae"
    policy_loss_fn: str = "my_custom_loss"
    reward_manager: str = "sequence-level"
    # 3. Custom Hyperparameters
    my_hyperparam: float = 0.5

    def __post_init__(self):
        ...
Step 2: The Learner. Inherit from RLLearner (or a subclass such as
GRPOLearner). This is where you add algorithm-specific execution logic, such
as modifying the training loop or injecting custom reward functions.
class MyNewAlgorithmLearner(RLLearner):
    def __init__(self, rl_cluster, algo_config: MyNewAlgoConfig, reward_fns, ...):
        # Custom initialization (e.g., adding specific reward shaping).
        # my_custom_reward_fn is a user-defined reward function.
        if algo_config.my_hyperparam > 0.1:
            # Build a new list so the caller's reward_fns is not mutated.
            reward_fns = [*reward_fns, my_custom_reward_fn]
        super().__init__(
            rl_cluster=rl_cluster,
            algo_config=algo_config,
            reward_fns=reward_fns,
            ...
        )
2. Custom Loss & Advantage (The Registry)#
Tunix uses a Function Registry to manage mathematical components. This allows you to hot-swap loss functions or advantage estimators in your config without changing the Learner code.
How to Register a New Loss Function: Define your loss function and decorate
it with @register_policy_loss_fn.
import jax.numpy as jnp

from tunix.registry import register_policy_loss_fn

@register_policy_loss_fn("my_custom_loss")
def compute_my_custom_loss(log_probs, advantages, **kwargs):
    """
    Args:
        log_probs: Array of log probabilities.
        advantages: Array of calculated advantages.
    Returns:
        Scalar loss array.
    """
    # Tunix is JAX-based, so the reduction uses jax.numpy rather than torch.
    return -jnp.mean(log_probs * advantages)
Usage: Once registered, simply reference it in your config: policy_loss_fn: str = "my_custom_loss"
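To make the registry idea concrete, here is a minimal sketch of how a name-to-function registry with a decorator could work. It is not Tunix's implementation: the `_POLICY_LOSS_FNS` dictionary and the `get_policy_loss_fn` lookup helper are hypothetical, and the loss uses plain Python lists instead of arrays so the example has no dependencies.

```python
from typing import Callable, Dict, Sequence

# Hypothetical in-memory registry; Tunix's real registry may differ.
_POLICY_LOSS_FNS: Dict[str, Callable[..., float]] = {}


def register_policy_loss_fn(name: str):
    """Returns a decorator that stores the decorated function under `name`."""

    def decorator(fn):
        if name in _POLICY_LOSS_FNS:
            raise ValueError(f"Policy loss '{name}' is already registered")
        _POLICY_LOSS_FNS[name] = fn
        return fn

    return decorator


def get_policy_loss_fn(name: str):
    """Looks up a loss function by the string used in the config."""
    try:
        return _POLICY_LOSS_FNS[name]
    except KeyError:
        raise KeyError(f"Unknown policy loss '{name}'") from None


@register_policy_loss_fn("my_custom_loss")
def compute_my_custom_loss(
    log_probs: Sequence[float], advantages: Sequence[float], **kwargs
) -> float:
    # Plain-Python version of -mean(log_probs * advantages).
    products = [lp * adv for lp, adv in zip(log_probs, advantages)]
    return -sum(products) / len(products)


# The caller only needs the string from the config to find the function.
loss_fn = get_policy_loss_fn("my_custom_loss")
loss = loss_fn([-1.0, -2.0], [0.5, 1.5])
```

Because the lookup key is an ordinary string, switching losses is a config change rather than a code change, which is the property the registry is meant to provide.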
3. Custom Reward Management#
Rewards are handled by a Manager pattern. The Learner delegates to the Manager to compute rewards from model output and log the results.
Reward Function: A simple callable that calculates scores based on completion texts (e.g., regex matching, length constraints, keyword presence).
Reward Manager: The orchestrator that calls reward functions, formats the output, and handles logging.
When to use what?
Simple Case: Just add a new function to the reward_fns list.
Complex Case: Subclass AbstractRewardManager if you need custom aggregation (e.g., weighted sums) or specialized logging strategies.
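For the simple case, a reward function can be as small as the two sketches below, one based on regex matching and one on a length constraint. The `(prompts, completions, **kwargs)` signature returning one float per completion is an assumption inferred from how reward functions are called in this guide; the function names, the `<answer>` tag format, and the `max_chars` parameter are hypothetical.

```python
import re
from typing import List


def contains_answer_tag(
    prompts: List[str], completions: List[str], **kwargs
) -> List[float]:
    """Scores 1.0 when a completion contains <answer>...</answer>, else 0.0."""
    pattern = re.compile(r"<answer>.+?</answer>", re.DOTALL)
    return [1.0 if pattern.search(c) else 0.0 for c in completions]


def length_penalty(
    prompts: List[str], completions: List[str], max_chars: int = 40, **kwargs
) -> List[float]:
    """Scores 0.0 within the character budget and -1.0 beyond it."""
    return [0.0 if len(c) <= max_chars else -1.0 for c in completions]


prompts = ["What is 2 + 2?", "What is 3 + 3?"]
completions = [
    "<answer>4</answer>",
    "I think the result is probably six, but I am not fully sure.",
]

reward_fns = [contains_answer_tag, length_penalty]
# One list of scores per reward function, one score per completion.
scores = [fn(prompts, completions) for fn in reward_fns]
```

Keeping each function focused on a single signal makes it easy to add or remove signals by editing the `reward_fns` list.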
Example: Custom Manager
Below is an example of a manager that performs custom aggregation and injects specific intermediate logs.
from typing import Any, Dict, List

import numpy as np

class MyCustomRewardManager(AbstractRewardManager):
    def __call__(self, prompts: List[str], completions: List[str], **kwargs) -> Dict[str, Any]:
        """
        Orchestrates reward calculation.
        """
        # 1. Run all reward functions and stack the results into an array of
        #    shape [num_completions, num_reward_fns], so that column i holds
        #    the scores from reward function i.
        raw_scores = np.stack(
            [np.asarray(fn(prompts, completions)) for fn in self.reward_fns], axis=1
        )
        # 2. Custom Aggregation (e.g., multi-objective weighted sum).
        #    aggregate_logic is a user-defined method, not shown here.
        final_rewards = self.aggregate_logic(raw_scores)
        # 3. Calculate log metrics
        log_metrics = self._prepare_log_metrics(prompts, completions, raw_scores, final_rewards)
        # 4. Return format required by Learner
        return {
            "rewards": final_rewards,
            "log_metrics": log_metrics,
        }

    def _prepare_log_metrics(
        self,
        prompts: List[str],
        completions: List[str],
        rewards: np.ndarray,
        sum_rewards: np.ndarray,
    ) -> Dict[str, Any]:
        """
        Logs individual and summed rewards.
        """
        # 1. Standard Logs (prompts, completions, sum, min, max)
        metrics_to_log = super()._prepare_log_metrics(prompts, completions, rewards, sum_rewards)
        # 2. Custom Intermediate Logging
        # You can add any intermediate reward logging here.
        # Example: Logging specific components of the reward signal separately
        if hasattr(self, "reward_fns"):
            for i, fn in enumerate(self.reward_fns):
                # Log the contribution of each function (e.g. rewards/grammar_score)
                name = getattr(fn, "__name__", f"fn_{i}")
                metrics_to_log[f"rewards/{name}"] = (rewards[:, i], np.mean)
        return metrics_to_log
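The aggregation step is left to the user. One possible shape for it, a multi-objective weighted sum, is sketched below. The `weighted_sum` helper and its weights are hypothetical, and it assumes scores are laid out with one row per completion and one column per reward function, which is the layout that the `rewards[:, i]` indexing in the logging code implies. Plain lists are used so the sketch runs without dependencies.

```python
from typing import List, Sequence


def weighted_sum(
    score_rows: Sequence[Sequence[float]], weights: Sequence[float]
) -> List[float]:
    """Combines per-function scores into one reward per completion.

    score_rows[j][i] is the score of completion j under reward function i.
    """
    for row in score_rows:
        if len(row) != len(weights):
            raise ValueError("Need exactly one weight per reward function")
    return [sum(w * s for w, s in zip(weights, row)) for row in score_rows]


# Two completions scored by two reward functions.
score_rows = [[1.0, 0.0], [0.0, -1.0]]
final_rewards = weighted_sum(score_rows, weights=[2.0, 0.5])
```

Weights let you express that one signal (for example, correctness) matters more than another (for example, brevity) without changing the reward functions themselves.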