[Fix] fix SP for InternS1 VL RL#1656

Open
tina-wen wants to merge 1 commit intoInternLM:mainfrom
tina-wen:rl_sp

Conversation

Contributor

@tina-wen tina-wen commented Apr 6, 2026

Root Cause

Under sequence parallelism, entropy statistics and GRPO batch loss calibration were computed from sharded tensors, which could produce incorrect token counts and inconsistent metrics across SP ranks. In addition, GRPO loss batching needed to accept the forwarded SP context from the worker.

Fix

Gather shifted labels and logprobs before entropy aggregation under SP, pass sp_mesh into RL loss batching from the training worker, and make GRPO batch construction use the SP-aware token counting path.
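To make the token-counting issue concrete, here is a minimal pure-Python sketch (no torch, no distributed runtime): the SP gather is simulated by concatenating rank-local shards, and `IGNORE_IDX` mirrors the usual `-100` label convention. The helper `count_grad_tokens` is hypothetical, named only for this illustration.

```python
# Hypothetical illustration of the SP token-counting bug: counting grad
# tokens on a single shard undercounts, while gathering the shards first
# recovers the full-sequence count. sp_gather is simulated by simple
# list concatenation.
IGNORE_IDX = -100

def count_grad_tokens(shifted_labels):
    """Number of tokens that contribute to the loss."""
    return sum(1 for t in shifted_labels if t != IGNORE_IDX)

# A sequence of 8 label positions split across sp_size=2 ranks (dim=1 split).
full_labels = [5, 7, -100, 3, -100, 9, 2, -100]
shards = [full_labels[:4], full_labels[4:]]   # rank-local views

local_counts = [count_grad_tokens(s) for s in shards]
gathered = shards[0] + shards[1]              # simulated sp_gather

assert count_grad_tokens(gathered) == 5       # correct full count
assert local_counts == [3, 2]                 # each rank alone undercounts
print(count_grad_tokens(gathered))            # → 5
```

The fix in this PR follows the same idea: gather the sharded `shifted_labels` across the SP group before counting, so every rank derives its loss calibration from the full sequence rather than its local shard.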

@tina-wen tina-wen changed the title [Fix] fix SP for VL-241B RL [Fix] fix SP for InternS1 VL RL Apr 7, 2026
for i, shifted_labels in enumerate(shifted_labels_list):
    old_logprobs = old_logprobs_list[i]
    assert old_logprobs is not None
    if sp_mesh is not None:
Collaborator
Each SP rank should calculate its own rank-local loss. This change makes all ranks in an sp_group calculate the same loss, which is later summed via all_reduce, so sum_entropy ends up multiplied by sp_size.
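The reviewer's double-counting concern can be checked with a small numeric sketch (hypothetical values; the all_reduce(SUM) is simulated by summing over simulated ranks):

```python
# Hypothetical sketch: with sp_size ranks, an all_reduce(SUM) over
# per-rank partial entropies gives the true total, but if every rank
# computes the *full*-sequence entropy the summed result is sp_size
# times too large. Per-token entropies below are made up.
sp_size = 2
per_token_entropy = [0.5, 1.0, 0.25, 0.75, 0.5, 1.5, 0.25, 0.25]
shards = [per_token_entropy[:4], per_token_entropy[4:]]

true_sum = sum(per_token_entropy)                 # 5.0

# Correct: each rank sums its own shard, then all_reduce(SUM).
correct = sum(sum(s) for s in shards)

# Buggy: each rank computes the full sum, then all_reduce(SUM).
buggy = sum(true_sum for _ in range(sp_size))

assert correct == true_sum                        # 5.0
assert buggy == sp_size * true_sum                # 10.0
print(correct, buggy)                             # → 5.0 10.0
```

This is why gathering for metrics must not be followed by every rank reporting the gathered total into a cross-rank sum, or the statistic scales with sp_size.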

    (shifted_labels != loss_cfg.ignore_idx).sum()
    if rank_grad_tokens is None
    else rank_grad_tokens + (shifted_labels != loss_cfg.ignore_idx).sum()
)
Collaborator
All SP ranks in one SP group get the same rank_grad_tokens. Same problem as above.

@CyCle1024
Collaborator

@tina-wen reference code block:

if sp_mesh is not None:
    # gather shifted_labels from different sp ranks to compute the correct loss weight
    shifted_labels = sp_gather(shifted_labels, sp_mesh=sp_mesh, dim=1)
mask = (shifted_labels != loss_cfg.ignore_idx).int()
num_grad_tokens = torch.zeros_like(boundaries, dtype=torch.int32)
prev_idx = 0
for j, boundary in enumerate(boundaries):
    num_grad_tokens[j] = mask[0, prev_idx:boundary].sum()
    prev_idx = boundary
if loss_cfg.loss_reduction == "sample":
    loss_weight = 1.0 / num_grad_tokens
elif loss_cfg.loss_reduction == "square":
    loss_weight = 1.0 / torch.sqrt(num_grad_tokens.float())
else:
    raise NotImplementedError(loss_cfg.loss_reduction)
loss_weight = loss_weight.repeat_interleave(num_tokens).unsqueeze(0)
if sp_mesh is not None:
    loss_weight = sp_split(loss_weight, sp_mesh=sp_mesh, split_dim=1, padding_value=0.0)
    shifted_labels = sp_split(shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100)
Not the same compute logic as your change, but this shows the intended SP processing: gather, compute on the full sequence, then split back.
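The boundary loop in the reference snippet can be walked through in plain Python (no torch). The `mask` and `boundaries` values below are made-up stand-ins for the gathered tensors; `boundaries` holds the exclusive end index of each packed sample:

```python
# Hypothetical pure-Python walk-through of the reference snippet's
# boundary loop: after gathering, per-sample grad-token counts are taken
# from a 0/1 mask between consecutive packed-sample boundaries.
mask = [1, 1, 0, 1, 0, 1, 1, 0]    # (shifted_labels != ignore_idx)
boundaries = [3, 5, 8]             # exclusive end index of each sample

num_grad_tokens = []
prev_idx = 0
for boundary in boundaries:
    num_grad_tokens.append(sum(mask[prev_idx:boundary]))
    prev_idx = boundary

# "sample" reduction: each sample's tokens are weighted by 1/count.
loss_weight = [1.0 / n for n in num_grad_tokens]

assert num_grad_tokens == [2, 1, 2]
print(num_grad_tokens)             # → [2, 1, 2]
```

Because the counts are taken on the gathered sequence, every SP rank derives the same per-sample weights; the trailing `sp_split` then hands each rank only its shard of the weight vector, so the later all_reduce does not double-count.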
