Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions docs/source/paper_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,41 @@ trainer = GRPOTrainer(
)
```

### FIPO: Eliciting Deep Reasoning with Future-KL Influenced Policy Optimization

**📜 Paper**: https://huggingface.co/papers/2603.19835

FIPO keeps the DAPO training scaffold but replaces the uniform token weighting with a discounted Future-KL influence
weight. In TRL, this is available as an additional GRPO loss type. To mirror the upstream FIPO recipe, use:

```python
from trl import GRPOConfig, GRPOTrainer
from trl.rewards import get_soft_overlong_punishment

training_args = GRPOConfig(
mask_truncated_completions=True,
loss_type="fipo",
epsilon_high=0.28,
epsilon=0.2,
per_device_train_batch_size=512,
num_generations=16,
max_completion_length=20480,
beta=0.0,
fipo_clip_ratio_c=10.0,
fipo_decay_rate=32.0,
fipo_chunk_size=128,
fipo_influence_clip_ratio=0.2,
fipo_influence_clip_high_only=True,
fipo_safety_threshold=10.0,
)
sop_reward = get_soft_overlong_punishment(max_completion_len=20480, soft_punish_cache=4096)
trainer = GRPOTrainer(
...,
args=training_args,
reward_funcs=[..., sop_reward],
)
```

### INTELLECT-2: A Reasoning Model Trained Through Globally Decentralized Reinforcement Learning

**📜 Paper**: https://huggingface.co/papers/2505.07291
Expand Down
2 changes: 1 addition & 1 deletion tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def test_training(self, config_name):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo", "vespo"])
@pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "fipo", "sapo", "luspo", "vespo"])
def test_training_loss_types(self, loss_type):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

Expand Down
68 changes: 66 additions & 2 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,23 @@ class GRPOConfig(_BaseConfig):
lambda parameter for negative advantages, it is the exponential decay factor in the VESPO loss. Controls
how aggressively we down-weight samples with high importance weights (when the importance sampling ratio >
1).
fipo_decay_rate (`float`, *optional*, defaults to `32.0`):
Decay horizon for the discounted Future-KL accumulation used by the FIPO loss. The effective decay factor
is computed as `2 ** (-1 / fipo_decay_rate)`.
fipo_chunk_size (`int`, *optional*, defaults to `128`):
Chunk size used when accumulating the Future-KL signal in the FIPO loss. This affects memory/runtime but
not the mathematical objective.
fipo_clip_ratio_c (`float`, *optional*, defaults to `10.0`):
Dual-clip threshold used by the FIPO loss for negative-advantage tokens. This follows the upstream FIPO
implementation and should be greater than `1.0`.
fipo_influence_clip_ratio (`float`, *optional*, defaults to `0.2`):
Clipping ratio applied to FIPO's influence weights `exp(FutureKL_t)`.
fipo_influence_clip_high_only (`bool`, *optional*, defaults to `True`):
Whether FIPO clips only the upper bound of the influence weights. When `False`, both lower and upper
bounds are clipped symmetrically.
fipo_safety_threshold (`float`, *optional*, defaults to `10.0`):
Safety threshold used by FIPO to cap influence weights for negative-advantage tokens with large importance
ratios.
importance_sampling_level (`str`, *optional*, defaults to `"token"`):
Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"`
keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the
Expand Down Expand Up @@ -665,6 +682,45 @@ class GRPOConfig(_BaseConfig):
"sampling ratio > 1)."
},
)
fipo_decay_rate: float = field(
default=32.0,
metadata={
"help": "Decay horizon for the discounted Future-KL accumulation used by the FIPO loss. The effective "
"decay factor is computed as `2 ** (-1 / fipo_decay_rate)`."
},
)
fipo_chunk_size: int = field(
default=128,
metadata={
"help": "Chunk size used when accumulating the Future-KL signal in the FIPO loss. This affects "
"memory/runtime but not the mathematical objective."
},
)
fipo_clip_ratio_c: float = field(
default=10.0,
metadata={
"help": "Dual-clip threshold used by the FIPO loss for negative-advantage tokens. This follows the "
"upstream FIPO implementation and should be greater than `1.0`."
},
)
fipo_influence_clip_ratio: float = field(
default=0.2,
metadata={"help": "Clipping ratio applied to FIPO's influence weights `exp(FutureKL_t)`."},
)
fipo_influence_clip_high_only: bool = field(
default=True,
metadata={
"help": "Whether FIPO clips only the upper bound of the influence weights. When `False`, both lower and "
"upper bounds are clipped symmetrically."
},
)
fipo_safety_threshold: float = field(
default=10.0,
metadata={
"help": "Safety threshold used by FIPO to cap influence weights for negative-advantage tokens with large "
"importance ratios."
},
)
importance_sampling_level: str = field(
default="token",
metadata={
Expand Down Expand Up @@ -709,8 +765,7 @@ class GRPOConfig(_BaseConfig):
loss_type: str = field(
default="dapo",
metadata={
"help": "Specifies the loss formulation to use. Supported values are 'grpo', 'dapo', 'bnpo', and "
"'dr_grpo'. "
"help": "Specifies the loss formulation to use. "
"'grpo': Aggregates token-level losses by normalizing over sequence length. Not recommended due to length "
"bias—this approach tends to prefer shorter completions with positive advantages and longer ones with "
"negative advantages. "
Expand All @@ -728,6 +783,9 @@ class GRPOConfig(_BaseConfig):
"Individual token losses are aggregated by normalizing with the number of active tokens in "
"the global accumulated batch. This method was introduced in the "
"[MiniMax-M1 paper](https://huggingface.co/papers/2506.13585). "
"'fipo': Future-KL Influenced Policy Optimization. Reweights each token's advantage using a discounted "
"sum of future log-probability shifts before applying a DAPO-style clipped policy loss. Introduced in "
"the [FIPO paper](https://huggingface.co/papers/2603.19835). "
"'sapo': Soft Adaptive Policy Optimization loss, as introduced in the "
"[Soft Adaptive Policy Optimization paper](https://huggingface.co/papers/2511.20347). "
"Replaces hard clipping with a smooth, temperature-controlled gate that adaptively attenuates "
Expand Down Expand Up @@ -926,3 +984,9 @@ def __post_init__(self):

if self.delta is not None and self.use_liger_kernel:
raise ValueError("Liger kernel does not support two-sided GRPO loss yet.")

if self.fipo_clip_ratio_c <= 1.0:
raise ValueError("fipo_clip_ratio_c must be greater than 1.0.")

if self.fipo_chunk_size <= 0:
raise ValueError("fipo_chunk_size must be strictly positive.")
143 changes: 137 additions & 6 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,12 @@ def __init__(
"set to `'token'` (the default)."
)

if args.loss_type == "fipo" and args.importance_sampling_level != "token":
logger.warning(
"FIPO computes token-level Future-KL weights internally. `importance_sampling_level` is ignored and "
"should be left at `'token'`."
)

if self.loss_type == "vespo" and self.use_vllm and self.vllm_importance_sampling_correction:
if self.vllm_importance_sampling_mode not in ["token_truncate", "token_mask"]:
raise ValueError(
Expand Down Expand Up @@ -676,6 +682,8 @@ def cast_outputs_to_original_dtype(module, args, output):

# Liger loss
if self.use_liger_kernel:
if self.loss_type == "fipo":
raise NotImplementedError("Liger kernel does not support the FIPO loss yet.")
if not is_liger_kernel_available():
raise ImportError(
"Liger is required to use `use_liger_kernel` as the GRPO loss. Run `pip install liger-kernel`."
Expand Down Expand Up @@ -2255,6 +2263,90 @@ def get_gamma_weights(

return phi_seq # (B, 1)

@staticmethod
@torch.no_grad()
def get_fipo_loss_components(
    advantages: torch.Tensor,
    log_ratio_per_token: torch.Tensor,
    mask: torch.Tensor,
    epsilon_low: float,
    epsilon_high: float,
    clip_ratio_c: float = 10.0,
    decay_rate: float = 32.0,
    chunk_size: int = 128,
    influence_clip_ratio: float = 0.2,
    influence_clip_high_only: bool = True,
    safety_threshold: float = 10.0,
) -> dict[str, torch.Tensor]:
    """
    Computes the Future-KL influence weights, the per-token loss mask, and diagnostic metrics for the FIPO loss.

    Runs entirely under `torch.no_grad()`: every returned tensor is detached. The caller is expected to
    recompute the differentiable clipped policy-gradient terms itself and use `influence_weights` and
    `loss_mask` as constant multipliers.

    Args:
        advantages: Per-token advantages. Broadcastable against `log_ratio_per_token` (shape `(B, T)` or
            `(B, 1)` — not pinned down here; TODO confirm against the caller).
        log_ratio_per_token: `log(pi_theta / pi_old)` per completion token, shape `(B, T)`.
        mask: Completion mask, shape `(B, T)`; nonzero marks valid tokens.
        epsilon_low: Lower PPO-style clipping range for the importance ratio.
        epsilon_high: Upper PPO-style clipping range for the importance ratio.
        clip_ratio_c: Dual-clip threshold for negative-advantage tokens; also sets the `log(clip_ratio_c)`
            cutoff above which a token's log-ratio is excluded from the Future-KL signal.
        decay_rate: Decay horizon; the per-step discount is `2 ** (-1 / decay_rate)`.
        chunk_size: Column-chunk width for the discounted accumulation (memory/runtime knob only).
        influence_clip_ratio: Clipping half-width applied to the influence weights `exp(FutureKL_t)`;
            `0.0` disables ratio-based clipping (a fixed upper cap of 10.0 is used instead).
        influence_clip_high_only: If True, only the upper bound is clipped (lower bound fixed at 1.0).
        safety_threshold: Importance-ratio level above which negative-advantage tokens get their influence
            weight forced into `[0.8, 1.0]`.

    Returns:
        Dict with:
            - `loss_mask`: `mask` with entire sequences zeroed out when they contain more than one
              dual-clipped negative-advantage token.
            - `influence_weights`: clipped `exp(FutureKL_t)` per token.
            - `influence_weights_mean_raw`, `influence_weight_clip_ratio_upper`,
              `influence_weight_clip_ratio_lower`, `sequence_drop_ratio`: scalar diagnostics.
    """
    # Bound the log-ratio before exponentiating so `ratio` cannot overflow.
    log_ratio_per_token = torch.clamp(log_ratio_per_token, min=-20.0, max=20.0)
    ratio = torch.exp(log_ratio_per_token)
    batch_size, response_len = log_ratio_per_token.shape
    device = log_ratio_per_token.device
    mask_bool = mask.to(dtype=torch.bool)
    mask_float = mask.to(dtype=log_ratio_per_token.dtype)
    # Clamp avoids 0/0 when a batch has no valid completion tokens.
    completion_token_count = mask_float.sum().clamp(min=1.0)

    def masked_mean(values: torch.Tensor) -> torch.Tensor:
        # Mean over valid completion tokens only (used for the scalar diagnostics).
        return (values * mask_float).sum() / completion_token_count

    future_kl = torch.zeros((batch_size, response_len), device=device, dtype=log_ratio_per_token.dtype)
    pos_i = torch.arange(response_len, device=device).unsqueeze(1)
    # Tokens whose log-ratio exceeds log(clip_ratio_c) are treated as outliers and do not
    # contribute to anyone's Future-KL signal.
    filter_threshold = math.log(clip_ratio_c)
    ignore_mask = log_ratio_per_token > filter_threshold
    participation_mask = ~ignore_mask
    kl_response = log_ratio_per_token * mask_float
    kl_response = kl_response * participation_mask.to(dtype=log_ratio_per_token.dtype)

    # FutureKL_i = sum_{j >= i} gamma^(j - i) * kl_response_j, accumulated in column chunks so the
    # (T x T) decay matrix is never fully materialized.
    gamma_t = torch.tensor(2 ** (-1.0 / decay_rate), dtype=log_ratio_per_token.dtype, device=device)
    for j_start in range(0, response_len, chunk_size):
        j_end = min(response_len, j_start + chunk_size)
        j_idx = torch.arange(j_start, j_end, device=device).unsqueeze(0)
        distance = j_idx - pos_i
        # Only future (and current) positions contribute: distance >= 0.
        distance_mask = distance >= 0
        distance_clamped = distance.clamp(min=0)
        decay_block = torch.pow(gamma_t, distance_clamped) * distance_mask.to(dtype=log_ratio_per_token.dtype)
        future_kl += torch.matmul(kl_response[:, j_start:j_end], decay_block.t())

    raw_influence_weights = torch.exp(future_kl)
    if influence_clip_ratio != 0.0:
        if influence_clip_high_only:
            # High-only mode: weights below 1.0 are pulled up to 1.0.
            lower_bound = 1.0
            upper_bound = 1.0 + influence_clip_ratio
        else:
            lower_bound = 1.0 - influence_clip_ratio
            upper_bound = 1.0 + influence_clip_ratio
        influence_weights = torch.clamp(raw_influence_weights, min=lower_bound, max=upper_bound)
    else:
        # Clipping disabled: only a fixed safety cap at 10.0 is applied. The bounds below exist
        # solely so the clip-ratio diagnostics remain well defined.
        lower_bound = 0.0
        upper_bound = 10.0
        influence_weights = torch.clamp(raw_influence_weights, max=10.0)

    # Safety valve: negative-advantage tokens whose importance ratio already exceeds the safety
    # threshold get their influence weight forced into [0.8, 1.0] so they are not amplified further.
    mask_neg_high_is = (advantages < 0) & (ratio > safety_threshold)
    influence_weights = torch.where(mask_neg_high_is, torch.clamp(influence_weights, min=0.8, max=1.0), influence_weights)

    # The clipped losses below are computed only to identify dual-clipped tokens; they are not
    # returned (the caller rebuilds them with gradients enabled).
    weighted_advantages = advantages * influence_weights
    pg_losses1 = -weighted_advantages * ratio
    pg_losses2 = -weighted_advantages * torch.clamp(ratio, 1 - epsilon_low, 1 + epsilon_high)
    clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
    pg_losses3 = -weighted_advantages * clip_ratio_c

    # A valid token is "lower clipped" when the dual-clip branch (pg_losses3) would bind for a
    # negative advantage. Sequences with more than one such token are dropped entirely.
    lower_clip_mask = (advantages < 0) & (clip_pg_losses1 > pg_losses3) & mask_bool
    low_clip_token_counts = lower_clip_mask.sum(dim=1)
    seq_valid_mask = (low_clip_token_counts <= 1).unsqueeze(1)
    loss_mask = (mask_bool & seq_valid_mask).to(dtype=log_ratio_per_token.dtype)

    return {
        "loss_mask": loss_mask,
        "influence_weights": influence_weights,
        "influence_weights_mean_raw": masked_mean(raw_influence_weights),
        # Fraction of valid tokens sitting at (within 1e-7 of) each clipping bound.
        "influence_weight_clip_ratio_upper": masked_mean((influence_weights >= upper_bound - 1e-7).float()),
        "influence_weight_clip_ratio_lower": masked_mean((influence_weights <= lower_bound + 1e-7).float()),
        # Fraction of sequences removed by the dual-clip sequence filter.
        "sequence_drop_ratio": (~seq_valid_mask.squeeze(1)).float().mean(),
    }

def _compute_loss(self, model, inputs):
# Compute the per-token log probabilities for the model
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
Expand Down Expand Up @@ -2328,6 +2420,7 @@ def _compute_loss(self, model, inputs):
)

coef_1 = torch.exp(log_importance_weights)
loss_mask = mask

# Compute the KL divergence between the model and the reference model
if self.beta != 0.0:
Expand All @@ -2353,6 +2446,30 @@ def _compute_loss(self, model, inputs):
per_token_loss1 = coef_1 * advantages
per_token_loss2 = coef_2 * advantages
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
elif self.loss_type == "fipo":
fipo_outputs = self.get_fipo_loss_components(
advantages=advantages,
log_ratio_per_token=log_ratio,
mask=mask,
epsilon_low=self.epsilon_low,
epsilon_high=self.epsilon_high,
clip_ratio_c=self.args.fipo_clip_ratio_c,
decay_rate=self.args.fipo_decay_rate,
chunk_size=self.args.fipo_chunk_size,
influence_clip_ratio=self.args.fipo_influence_clip_ratio,
influence_clip_high_only=self.args.fipo_influence_clip_high_only,
safety_threshold=self.args.fipo_safety_threshold,
)
coef_1 = torch.exp(torch.clamp(log_ratio, min=-20.0, max=20.0))
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
weighted_advantages = advantages * fipo_outputs["influence_weights"]
pg_losses1 = -weighted_advantages * coef_1
pg_losses2 = -weighted_advantages * coef_2
clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
pg_losses3 = -weighted_advantages * self.args.fipo_clip_ratio_c
clip_pg_losses2 = torch.minimum(pg_losses3, clip_pg_losses1)
per_token_loss = torch.where(weighted_advantages < 0, clip_pg_losses2, clip_pg_losses1)
loss_mask = fipo_outputs["loss_mask"]
elif self.loss_type == "sapo":
temperatures = torch.where(advantages > 0, self.args.sapo_temperature_pos, self.args.sapo_temperature_neg)
soft_coef_1 = torch.sigmoid(temperatures * (coef_1 - 1)) * 4 / temperatures
Expand Down Expand Up @@ -2386,23 +2503,27 @@ def _compute_loss(self, model, inputs):

mode = "train" if self.model.training else "eval"
if self.loss_type in ["grpo", "sapo"]:
loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean()
loss = ((per_token_loss * loss_mask).sum(-1) / loss_mask.sum(-1).clamp(min=1.0)).mean()
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
loss = loss / normalizer
elif self.loss_type == "bnpo":
loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
loss = (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1.0)
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
loss = loss / normalizer
elif self.loss_type == "dr_grpo":
loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
loss = (per_token_loss * loss_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
loss = loss / normalizer
elif self.loss_type == "fipo":
loss = (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1.0)
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
loss = loss / normalizer
elif self.loss_type in ["cispo", "dapo", "vespo"]:
normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
loss = (per_token_loss * mask).sum() / normalizer
loss = (per_token_loss * loss_mask).sum() / normalizer
elif self.loss_type == "luspo":
# Unless importance_sampling_level="token" (not recommended here), per_token_loss is expected to be (B, 1)
loss = (per_token_loss * mask.sum(1, keepdim=True)).mean()
loss = (per_token_loss * loss_mask.sum(1, keepdim=True)).mean()
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
loss = loss / normalizer
else:
Expand All @@ -2424,7 +2545,17 @@ def masked_batch_mean(x):
mean_entropy = masked_batch_mean(entropies)
self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())

if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]:
if self.loss_type == "fipo":
for key in [
"influence_weights_mean_raw",
"influence_weight_clip_ratio_upper",
"influence_weight_clip_ratio_lower",
"sequence_drop_ratio",
]:
gathered_metric = self.accelerator.gather(fipo_outputs[key])
self._metrics[mode][f"fipo/{key}"].append(gathered_metric.nanmean().item())

if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "fipo", "luspo"]:
# Compute the clipped probability ratios
is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)
is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)
Expand Down