Algorithms

Tunix supports a range of state-of-the-art algorithms for reinforcement learning (RL) and supervised fine-tuning (SFT). Its modular design also lets you extend Tunix with custom algorithms, as described below.

Supported Algorithms

  • Supervised Fine-Tuning (SFT) & Preference Tuning

    • PEFT (Parameter-Efficient Fine-Tuning)

    • DPO (Direct Preference Optimization)

      • ORPO (Odds Ratio Preference Optimization)

  • Reinforcement Learning (RL)

    • PPO (Proximal Policy Optimization)

    • GRPO (Group Relative Policy Optimization)

      • GSPO-Token (Token-level Group Sequence Policy Optimization)

      • DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization)

      • Dr.GRPO (GRPO Done Right)

Add a New RL Algorithm

Tunix is designed to be highly extensible. You can introduce new algorithms by subclassing AlgorithmConfig (or its descendants) and implementing a corresponding Learner.

The system uses a parallel inheritance pattern: extending a Configuration often requires extending a Learner to consume it.

Class Hierarchy & Interaction Diagram

       CONFIGURATION                            LEARNER (The Engine)
     (Defines Params)                       (Orchestrates Execution)
   +-------------------+                   +-----------------------+
   |  AlgorithmConfig  | <---(binds)--     |       RLLearner       |
   +-------------------+                   +-----------------------+
             ^                                         ^      |
             |                                         |      +---(Uses)---> [Function Registry]
     (Inheritance)                               (Inheritance)                (Loss, Advantage, Reward)
             |                                         |
   +-------------------+                   +-----------------------+
   |    GRPOConfig     | <---(binds)--     |      GRPOLearner      |
   +-------------------+                   +-----------------------+
             ^                                         ^
             |                                         |
   +-------------------+                   +-----------------------+
   |    DAPOConfig     | <---(binds)--     |      DAPOLearner      |
   +-------------------+                   +-----------------------+

1. Defining the Configuration & Learner

To add a new algorithm, you typically define a config (to hold your params) and a learner (to use them).

Step 1: The Configuration. Inherit from AlgorithmConfig (or from a more specific child such as GRPOConfig if your algorithm is a variant of it). Use __post_init__ to validate your new settings.

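import dataclasses

# AlgorithmConfig is the Tunix base config class; import it from your Tunix installation.
@dataclasses.dataclass(slots=True, kw_only=True)
class MyNewAlgoConfig(AlgorithmConfig):
    # 1. Identity
    algo_variant: str = "my_new_algo"

    # 2. Components (string keys that are looked up in the Function Registry)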
    advantage_estimator: str = "gae"
    policy_loss_fn: str = "my_custom_loss"
    reward_manager: str = "sequence-level"

    # 3. Custom Hyperparameters
    my_hyperparam: float = 0.5

    def __post_init__(self):
        ...

Step 2: The Learner. Inherit from RLLearner (or from a child such as GrpoLearner). This is where you add algorithm-specific execution logic, such as modifying the training loop or injecting custom reward functions.

class MyNewAlgorithmLearner(RLLearner):
    def __init__(self, rl_cluster, algo_config: MyNewAlgoConfig, reward_fns, ...):
        # Custom initialization (e.g., adding specific reward shaping).
        # Copy the list first so the caller's reward_fns is not mutated.
        reward_fns = list(reward_fns)
        if algo_config.my_hyperparam > 0.1:
            reward_fns.append(my_custom_reward_fn)

        super().__init__(
            rl_cluster=rl_cluster,
            algo_config=algo_config,
            reward_fns=reward_fns,
            ...
        )
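
The reward-injection step can be tried in isolation with the sketch below. `RLLearnerStub`, `ShapedLearner`, and `brevity_bonus` are hypothetical names; the stub only stores the reward functions and does not reproduce the RLLearner constructor.

```python
from typing import Callable, List, Sequence

RewardFn = Callable[..., List[float]]


class RLLearnerStub:
    """Hypothetical stand-in for RLLearner that only stores reward functions."""

    def __init__(self, reward_fns: Sequence[RewardFn]):
        self.reward_fns = list(reward_fns)


def brevity_bonus(prompts, completions, **kwargs) -> List[float]:
    """Hypothetical shaping term: small bonus for short completions."""
    return [0.1 if len(c) < 50 else 0.0 for c in completions]


class ShapedLearner(RLLearnerStub):
    def __init__(self, reward_fns: Sequence[RewardFn], my_hyperparam: float = 0.5):
        fns = list(reward_fns)  # copy, so a list shared by callers stays untouched
        if my_hyperparam > 0.1:
            fns.append(brevity_bonus)
        super().__init__(reward_fns=fns)
```

Copying before appending matters when the same `reward_fns` list is reused to build several learners; otherwise each construction would append the shaping function again.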


2. Custom Loss & Advantage (The Registry)

Tunix uses a Function Registry to manage mathematical components. This allows you to hot-swap loss functions or advantage estimators in your config without changing the Learner code.

How to Register a New Loss Function: Define your loss function and decorate it with @register_policy_loss_fn.

import jax.numpy as jnp

from tunix.registry import register_policy_loss_fn

@register_policy_loss_fn("my_custom_loss")
def compute_my_custom_loss(log_probs, advantages, **kwargs):
    """
    Args:
        log_probs: Array of log probabilities.
        advantages: Array of calculated advantages.
    Returns:
        Scalar loss array.
    """
    # Policy-gradient surrogate: minimizing this maximizes the
    # advantage-weighted log-likelihood.
    return -jnp.mean(log_probs * advantages)

Usage: Once registered, simply reference it in your config: policy_loss_fn: str = "my_custom_loss"
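
To make the lookup-by-string mechanism concrete, here is a minimal sketch of a decorator-based registry. It is not the Tunix implementation: `_POLICY_LOSS_FNS` and `get_policy_loss_fn` are hypothetical, `register_policy_loss_fn` is re-implemented locally as a stand-in, and the loss uses plain Python lists instead of arrays so the example runs without dependencies.

```python
from typing import Callable, Dict, Sequence

# Hypothetical in-memory registry; the real Tunix registry may differ.
_POLICY_LOSS_FNS: Dict[str, Callable[..., float]] = {}


def register_policy_loss_fn(name: str):
    """Returns a decorator that stores the decorated function under `name`."""
    def decorator(fn):
        if name in _POLICY_LOSS_FNS:
            raise ValueError(f"policy loss {name!r} is already registered")
        _POLICY_LOSS_FNS[name] = fn
        return fn
    return decorator


def get_policy_loss_fn(name: str) -> Callable[..., float]:
    """Resolves the string stored in a config to the registered callable."""
    try:
        return _POLICY_LOSS_FNS[name]
    except KeyError:
        known = sorted(_POLICY_LOSS_FNS)
        raise KeyError(f"unknown policy loss {name!r}; registered: {known}") from None


@register_policy_loss_fn("my_custom_loss")
def compute_my_custom_loss(
    log_probs: Sequence[float], advantages: Sequence[float], **kwargs
) -> float:
    # Plain-Python version of -mean(log_probs * advantages).
    products = [lp * adv for lp, adv in zip(log_probs, advantages)]
    return -sum(products) / len(products)


# A learner would resolve the config string once and then call the function.
loss_fn = get_policy_loss_fn("my_custom_loss")
loss = loss_fn([-1.0, -2.0], [1.0, 0.5])  # 1.0
```

Because the config stores only the key, swapping the loss is a one-line config change as long as the replacement is registered before the learner resolves it.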


3. Custom Reward Management

Rewards are handled by a Manager pattern. The Learner delegates to the Manager to compute rewards from model output and log the results.

  • Reward Function: A simple callable that calculates scores based on completion texts (e.g., regex matching, length constraints, keyword presence).

  • Reward Manager: The orchestrator that calls reward functions, formats the output, and handles logging.

When to use what?

  • Simple Case: Just add a new function to the reward_fns list.

  • Complex Case: Subclass AbstractRewardManager if you need custom aggregation (e.g., weighted sums) or specialized logging strategies.
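
For the simple case, a reward function can be as small as the sketch below. The `(prompts, completions, **kwargs)` signature follows the way reward functions are called in the manager example on this page; the `<answer>` tag format, the function names, and the `max_chars` threshold are hypothetical.

```python
import re
from typing import List

# Hypothetical output format: the model should wrap its answer in <answer> tags.
_ANSWER_RE = re.compile(r"<answer>.*?</answer>", re.DOTALL)


def format_reward(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """Scores 1.0 when a completion contains an <answer>...</answer> span."""
    return [1.0 if _ANSWER_RE.search(c) else 0.0 for c in completions]


def length_penalty(
    prompts: List[str], completions: List[str], max_chars: int = 200, **kwargs
) -> List[float]:
    """Scores -0.5 for completions longer than max_chars, otherwise 0.0."""
    return [0.0 if len(c) <= max_chars else -0.5 for c in completions]


# The list that would be handed to the learner.
reward_fns = [format_reward, length_penalty]
```

Each function returns one score per completion, so several functions can be evaluated on the same batch and combined afterwards.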

Example: Custom Manager

Below is an example of a manager that performs custom aggregation and injects specific intermediate logs.

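from typing import Any, Dict, List

import numpy as np

class MyCustomRewardManager(AbstractRewardManager):
    def __call__(self, prompts: List[str], completions: List[str], **kwargs) -> Dict[str, Any]:
        """
        Orchestrates reward calculation.
        """
        # 1. Run all reward functions and stack the results into an array of
        #    shape (num_completions, num_reward_fns), the layout that
        #    _prepare_log_metrics indexes with rewards[:, i].
        raw_scores = np.stack(
            [np.asarray(fn(prompts, completions)) for fn in self.reward_fns],
            axis=1,
        )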

        # 2. Custom Aggregation (e.g., Multi-objective weighted sum).
        #    aggregate_logic is a method you define on this class.
        final_rewards = self.aggregate_logic(raw_scores)

        # 3. Calculate log metrics
        log_metrics = self._prepare_log_metrics(prompts, completions, raw_scores, final_rewards)

        # 4. Return format required by Learner
        return {
            "rewards": final_rewards,
            "log_metrics": log_metrics
        }

    def _prepare_log_metrics(
        self,
        prompts: List[str],
        completions: List[str],
        rewards: np.ndarray,
        sum_rewards: np.ndarray,
    ) -> Dict[str, Any]:
        """
        Logs individual and summed rewards.
        """
        # 1. Standard Logs (prompts, completions, sum, min, max)
        metrics_to_log = super()._prepare_log_metrics(prompts, completions, rewards, sum_rewards)

        # 2. Custom Intermediate Logging
        # Add any intermediate reward logs you need here.
        # Example: Logging specific components of the reward signal separately
        if hasattr(self, "reward_fns"):
            for i, fn in enumerate(self.reward_fns):
                # Log the specific contribution of each function (e.g. rewards/grammar_score)
                name = getattr(fn, "__name__", f"fn_{i}")
                metrics_to_log[f"rewards/{name}"] = (rewards[:, i], np.mean)

        return metrics_to_log
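
The manager above leaves the aggregation step open. One option is a weighted sum over the per-function scores, sketched here in plain Python. `weighted_sum` and the weights are hypothetical; the row-per-completion, column-per-function layout is assumed to match the `rewards[:, i]` indexing used in the logging code.

```python
from typing import List, Sequence


def weighted_sum(
    raw_scores: Sequence[Sequence[float]], weights: Sequence[float]
) -> List[float]:
    """Combines per-function scores into one reward per completion.

    raw_scores[i][j] is the score of completion i under reward function j.
    """
    if any(len(row) != len(weights) for row in raw_scores):
        raise ValueError("each row needs exactly one score per weight")
    return [sum(w * s for w, s in zip(weights, row)) for row in raw_scores]


# Two completions scored by two reward functions.
scores = [[1.0, 0.0], [0.0, -0.5]]
final_rewards = weighted_sum(scores, weights=[2.0, 1.0])  # [2.0, -0.5]
```

Keeping the weights outside the reward functions lets you rebalance objectives from the config without touching the functions themselves.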