20 changes: 15 additions & 5 deletions xtuner/v1/rl/base/worker.py
@@ -28,6 +28,7 @@
 from xtuner.v1.float8.float8_handler import Float8Handler
 from xtuner.v1.loss import CELossConfig, LogProbConfig
 from xtuner.v1.loss.ce_loss import CELossContext
+from xtuner.v1.loss.utils import sp_gather
 from xtuner.v1.model.base import BaseModel as XtunerBaseModel
 from xtuner.v1.model.base import ModelItem, TransformerConfig
 from xtuner.v1.model.compose.base import BaseComposeConfig, BaseComposeModel
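The newly imported sp_gather is not shown in this diff. As a rough mental model (a hypothetical sketch, not xtuner's actual implementation), it can be read as an all-gather along the sequence dimension across the sequence-parallel process group:

# Hypothetical sketch of sp_gather's behavior (an assumption; the real
# xtuner.v1.loss.utils.sp_gather may differ). It all-gathers a tensor that
# is sharded along `dim` across the sp group and concatenates the shards,
# so every sp rank ends up holding the full sequence. Assumes equal shard
# sizes and a 1-D sp mesh.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def sp_gather_sketch(tensor: torch.Tensor, sp_mesh: DeviceMesh, dim: int = 1) -> torch.Tensor:
    group = sp_mesh.get_group()
    shards = [torch.empty_like(tensor) for _ in range(dist.get_world_size(group))]
    dist.all_gather(shards, tensor.contiguous(), group=group)
    return torch.cat(shards, dim=dim)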
@@ -89,14 +90,19 @@ def calculate_entropy(
     shifted_labels_list: Sequence[torch.Tensor],
     old_logprobs_list: Sequence[torch.Tensor | None],
     global_grad_tokens: torch.Tensor,
+    sp_mesh: DeviceMesh | None = None,
 ) -> torch.Tensor | None:
     if len(old_logprobs_list) == 0 or old_logprobs_list[0] is None:
         return None
     sum_entropy: torch.Tensor | None = None
     for i, shifted_labels in enumerate(shifted_labels_list):
+        old_logprobs = old_logprobs_list[i]
+        assert old_logprobs is not None
+        if sp_mesh is not None:
Collaborator: Each sp rank should calculate its own rank loss. This change makes all ranks in an sp group calculate the same loss, which is later summed by all_reduce, so sum_entropy ends up multiplied by sp_size.

+            shifted_labels = sp_gather(shifted_labels, sp_mesh=sp_mesh, dim=1)
+            old_logprobs = sp_gather(old_logprobs, sp_mesh=sp_mesh, dim=1)
         mask = shifted_labels != -100
-        assert old_logprobs_list[i] is not None
-        entropy = -(cast(torch.Tensor, old_logprobs_list[i]) * mask).sum()
+        entropy = -(old_logprobs * mask).sum()
         sum_entropy = entropy if sum_entropy is None else sum_entropy + entropy
     sum_entropy = cast(torch.Tensor, sum_entropy)
     dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM)
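To make the review comment concrete, here is a minimal single-process sketch (sp ranks simulated as list entries, no real process group; masking omitted for brevity) of how gathering before the sum inflates the all-reduced entropy:

# Single-process simulation of the reviewer's concern. Each "rank" holds
# one shard of old_logprobs; after sp_gather every rank computes the
# entropy of the *full* sequence, so a SUM all-reduce over the world
# multiplies the result by sp_size.
import torch

sp_size = 2
shards = [torch.tensor([-0.5, -0.7]), torch.tensor([-0.2, -0.1])]  # per-rank logprob shards

true_entropy = -torch.cat(shards).sum()                 # 1.5, each token counted once

gathered = torch.cat(shards)                            # what each rank sees after sp_gather
per_rank_entropy = [-gathered.sum() for _ in range(sp_size)]
all_reduced = sum(per_rank_entropy)                     # 3.0 == sp_size * true_entropy

assert all_reduced == sp_size * true_entropy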
@@ -531,15 +537,19 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo
         # compute entropy
         rank_grad_tokens: torch.Tensor | None = None
         for shifted_labels in shifted_labels_list:
+            if self.sp_mesh is not None:
+                shifted_labels = sp_gather(shifted_labels, self.sp_mesh, dim=1)
             mask = shifted_labels != -100
             grad_tokens = mask.sum()
             rank_grad_tokens = grad_tokens if rank_grad_tokens is None else rank_grad_tokens + grad_tokens
         rank_grad_tokens = cast(torch.Tensor, rank_grad_tokens)
         global_grad_tokens = rank_grad_tokens
         dist.all_reduce(global_grad_tokens, op=dist.ReduceOp.SUM)

-        avg_sum_entropy = calculate_entropy(shifted_labels_list, old_logprobs_list, global_grad_tokens)
-        avg_rollout_entropy = calculate_entropy(shifted_labels_list, rollout_logprobs_list, global_grad_tokens)
+        avg_sum_entropy = calculate_entropy(shifted_labels_list, old_logprobs_list, global_grad_tokens, self.sp_mesh)
+        avg_rollout_entropy = calculate_entropy(
+            shifted_labels_list, rollout_logprobs_list, global_grad_tokens, self.sp_mesh
+        )

         assert avg_sum_entropy is not None
         worker_log_item["train_entropy"] = avg_sum_entropy.item()
@@ -604,7 +614,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo
         for i in range(0, len(loss_ctx_list), iters_per_step):
             batches_loss_ctx = loss_ctx_list[i : i + iters_per_step]
             batched_loss_ctx_list.extend(
-                LossContext.build_batches(batches_loss_ctx)  # type: ignore[arg-type]
+                LossContext.build_batches(batches_loss_ctx, sp_mesh=self.sp_mesh)  # type: ignore[arg-type]
             )

         # train optimizer steps
23 changes: 21 additions & 2 deletions xtuner/v1/rl/grpo/loss.py
@@ -4,7 +4,9 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from torch.distributed.device_mesh import DeviceMesh

+from xtuner.v1.loss.utils import sp_gather
 from xtuner.v1.utils import get_logger

 from ..base import (
@@ -75,15 +77,32 @@ def __init__(self, loss_cfg: GRPOLossConfig, loss_kwargs: GRPOLossKwargs):
         self.policy_loss_fn = get_policy_loss_fn(self.loss_cfg.policy_loss_cfg.get("loss_type", "vanilla"))

     @staticmethod
-    def build_batches(loss_ctx_list: list["GRPOLossContext"]) -> list["GRPOLossContext"]:  # type: ignore[override]
+    def build_batches(  # type: ignore[override]
+        loss_ctx_list: list["GRPOLossContext"],
+        *args: Any,
+        **kwargs: Any,
+    ) -> list["GRPOLossContext"]:
+        sp_mesh = cast(DeviceMesh | None, kwargs.get("sp_mesh"))
+        if sp_mesh is None and len(args) >= 2:
+            sp_mesh = cast(DeviceMesh | None, args[1])
+
         assert len(loss_ctx_list) > 0, "loss_ctx_list can not be empty"

         loss_cfg = loss_ctx_list[0].loss_cfg

         shifted_labels_list = [loss_ctx.loss_kwargs.shifted_labels for loss_ctx in loss_ctx_list]
+        rank_grad_tokens: torch.Tensor | None = None
+        for shifted_labels in shifted_labels_list:
+            if sp_mesh is not None:
+                # gather shifted_labels from different sp ranks to compute the correct loss weight
+                shifted_labels = sp_gather(shifted_labels, sp_mesh=sp_mesh, dim=1)
+            rank_grad_tokens = (
+                (shifted_labels != loss_cfg.ignore_idx).sum()
+                if rank_grad_tokens is None
+                else rank_grad_tokens + (shifted_labels != loss_cfg.ignore_idx).sum()
+            )
Collaborator: All sp ranks in one sp group get the same rank_grad_tokens. Same problem as above.


         # Compute the denominator used in the global calibration of the loss
-        rank_grad_tokens = sum((labels != loss_cfg.ignore_idx).sum() for labels in shifted_labels_list)
         rank_grad_tokens = cast(torch.Tensor, rank_grad_tokens)
         global_grad_tokens = rank_grad_tokens
         if dist.is_initialized():
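For contrast, the deleted generator-sum above is essentially the per-shard accounting the reviewer appears to prefer (a sketch of apparent intent, not code from this PR): count only the local shard and let the SUM all-reduce that follows count each token exactly once across ranks:

# Per-rank counting without the gather (a sketch of the reviewer's
# apparent suggestion, matching the deleted generator-sum above).
# Each sp rank counts only its own shard; the SUM all-reduce in
# build_batches then accumulates disjoint shard counts instead of
# sp_size copies of the full-sequence count.
import torch

ignore_idx = -100  # assumption: matches loss_cfg.ignore_idx
local_label_shard = torch.tensor([[1, -100, 2, 3]])  # this rank's shard only

rank_grad_tokens = (local_label_shard != ignore_idx).sum()  # 3, local tokens only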
Expand Down
Loading