
add JEPO trainer #5411

Open

zbills wants to merge 12 commits into huggingface:main from zbills:main

Conversation


@zbills zbills commented Mar 31, 2026

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • [ ] No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.


Note

Medium Risk
Introduces a large new trainer with non-trivial reward/advantage and loss computation paths (plus optional vLLM/FSDP/DeepSpeed integration), so subtle training/runtime issues are possible despite being mostly additive.

Overview
Adds a new JEPO training flow to TRL by introducing JEPOConfig and JEPOTrainer, and exporting them via the top-level trl and trl.trainer lazy import structures.

JEPOTrainer implements JEPO-style generation, reward aggregation, JEPO-specific advantage computation (including CoT-based fabricated completions), and a JEPO loss (optionally combined with supervised and KL terms), with support hooks for standard transformers generation and optional vLLM execution.

Includes a new examples/notebooks/jepo_math.ipynb notebook demonstrating dataset prep, reward/cot helpers, and end-to-end JEPO training on a math dataset.

Written by Cursor Bugbot for commit 6bf8c8e. This will update automatically on new commits.

class JEPOConfig(TrainingArguments):
    r"""
    Configuration class for the [`JEPOTrainer`], which serves as a variation of GRPO for unverifiable RL training.
    JEPO [https://arxiv.org/pdf/2503.19618]

Missing paper_index.md update for JEPO paper

Medium Severity

This PR implements the JEPO algorithm from a research paper ("Beyond Verifiable Rewards: Scaling RL for Language Models to Unverifiable Data"), but does not add a corresponding subsection to paper_index.md. The project rule in .ai/AGENTS.md states: "If a PR implements a method, algorithm, or training approach from a research paper, it must also add a corresponding subsection to paper_index.md."


Triggered by project rule: ../.ai/AGENTS.md

if self.loss_type == 'unnorm_jepo':
    loss = -per_token_adv.sum(dim=1).sum() / self.num_generations  # sum over the tokens and average over the batch
else:
    loss = -(per_token_adv.sum(dim=1) / cot_mask.sum(dim=1)).sum() / self.num_generations  # sum over the tokens and average over the batch

Division by zero when cot_mask is all zeros

Medium Severity

In _compute_loss, the norm_jepo loss path divides by cot_mask.sum(dim=1) and the supervised loss divides by answer_mask.sum(dim=1). When a sample doesn't have the correct format (no CoT or answer extracted), these sums can be zero, producing NaN values that propagate through the loss and corrupt gradients.

Additional Locations (1)
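One common way to guard this kind of masked mean against empty masks is to clamp the denominator. The sketch below is a hypothetical illustration, not the PR's code: the tensor names mirror the `norm_jepo` path in `_compute_loss`, the values are made up, and `num_generations` is assumed to be a plain int.

```python
import torch

# Assumed shapes from the snippet above: (batch, seq_len).
per_token_adv = torch.tensor([[0.5, 0.3, 0.0], [0.2, 0.0, 0.0]])
cot_mask = torch.tensor([[1.0, 1.0, 0.0], [0.0, 0.0, 0.0]])  # second sample has no CoT tokens
num_generations = 2

# clamp(min=1.0) keeps the all-zero row from dividing by zero; since its
# masked numerator is also zero, that sample simply contributes 0 to the loss
# instead of NaN.
denom = cot_mask.sum(dim=1).clamp(min=1.0)
loss = -((per_token_adv * cot_mask).sum(dim=1) / denom).sum() / num_generations
```

The same clamp would apply to the supervised term's `answer_mask.sum(dim=1)` denominator.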

year = 2025,
eprint = {arXiv:2503.19618},
}"""),
}

Missing paper_index.md update for JEPO paper

Low Severity

This PR implements the JEPO method from the paper "Beyond Verifiable Rewards" (2503.19618) but does not add a corresponding subsection to paper_index.md. The project rules require that any PR implementing a research paper must update paper_index.md.


Triggered by project rule: ../.ai/AGENTS.md


@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 7 total unresolved issues (including 6 from previous reviews).


Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

    advantage = torch.log(mean_reward) - torch.log(variance)
else:
    # If only one sample is applicable, use log(mean_reward).
    advantage = torch.log(mean_reward)
torch.log on potentially zero rewards produces NaN

Medium Severity

_compute_jepo_advantages calls torch.log(mean_reward) and torch.log(variance) where these values are sums of token-level probabilities. If the model assigns near-zero probability to all answer tokens, mean_reward or variance can be zero or extremely small, producing -inf or NaN. These NaN advantages then propagate into the loss computation.

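A standard mitigation is to floor the statistics with a small epsilon before the log. This is a hedged sketch, not the PR's implementation: the variable names echo `_compute_jepo_advantages`, the values are illustrative, and the epsilon would need tuning.

```python
import torch

eps = 1e-8  # assumed floor to keep log finite; actual value is a design choice
mean_reward = torch.tensor([0.0, 0.5])  # first entry: model assigned ~0 probability to answer tokens
variance = torch.tensor([0.0, 0.1])

# Clamping before log turns a zero reward into log(eps) instead of -inf,
# so the advantage stays finite and gradients remain usable.
advantage = torch.log(mean_reward.clamp(min=eps)) - torch.log(variance.clamp(min=eps))
```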
