diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index 855bc589a..0f2cd9487 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -28,6 +28,7 @@
 from xtuner.v1.float8.float8_handler import Float8Handler
 from xtuner.v1.loss import CELossConfig, LogProbConfig
 from xtuner.v1.loss.ce_loss import CELossContext
+from xtuner.v1.loss.utils import sp_gather
 from xtuner.v1.model.base import BaseModel as XtunerBaseModel
 from xtuner.v1.model.base import ModelItem, TransformerConfig
 from xtuner.v1.model.compose.base import BaseComposeConfig, BaseComposeModel
@@ -89,14 +90,19 @@
     shifted_labels_list: Sequence[torch.Tensor],
     old_logprobs_list: Sequence[torch.Tensor | None],
     global_grad_tokens: torch.Tensor,
+    sp_mesh: DeviceMesh | None = None,
 ) -> torch.Tensor | None:
     if len(old_logprobs_list) == 0 or old_logprobs_list[0] is None:
         return None
     sum_entropy: torch.Tensor | None = None
     for i, shifted_labels in enumerate(shifted_labels_list):
+        old_logprobs = old_logprobs_list[i]
+        assert old_logprobs is not None
+        if sp_mesh is not None:
+            shifted_labels = sp_gather(shifted_labels, sp_mesh=sp_mesh, dim=1)
+            old_logprobs = sp_gather(old_logprobs, sp_mesh=sp_mesh, dim=1)
         mask = shifted_labels != -100
-        assert old_logprobs_list[i] is not None
-        entropy = -(cast(torch.Tensor, old_logprobs_list[i]) * mask).sum()
+        entropy = -(old_logprobs * mask).sum()
         sum_entropy = entropy if sum_entropy is None else sum_entropy + entropy
     sum_entropy = cast(torch.Tensor, sum_entropy)
     dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM)
@@ -531,6 +537,8 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo
         # compute entropy
         rank_grad_tokens: torch.Tensor | None = None
         for shifted_labels in shifted_labels_list:
+            if self.sp_mesh is not None:
+                shifted_labels = sp_gather(shifted_labels, self.sp_mesh, dim=1)
             mask = shifted_labels != -100
             grad_tokens = mask.sum()
             rank_grad_tokens = grad_tokens if rank_grad_tokens is None else rank_grad_tokens + grad_tokens
@@ -538,8 +546,10 @@
         global_grad_tokens = rank_grad_tokens
         dist.all_reduce(global_grad_tokens, op=dist.ReduceOp.SUM)
 
-        avg_sum_entropy = calculate_entropy(shifted_labels_list, old_logprobs_list, global_grad_tokens)
-        avg_rollout_entropy = calculate_entropy(shifted_labels_list, rollout_logprobs_list, global_grad_tokens)
+        avg_sum_entropy = calculate_entropy(shifted_labels_list, old_logprobs_list, global_grad_tokens, self.sp_mesh)
+        avg_rollout_entropy = calculate_entropy(
+            shifted_labels_list, rollout_logprobs_list, global_grad_tokens, self.sp_mesh
+        )
 
         assert avg_sum_entropy is not None
         worker_log_item["train_entropy"] = avg_sum_entropy.item()
@@ -604,7 +614,7 @@
         for i in range(0, len(loss_ctx_list), iters_per_step):
             batches_loss_ctx = loss_ctx_list[i : i + iters_per_step]
             batched_loss_ctx_list.extend(
-                LossContext.build_batches(batches_loss_ctx)  # type: ignore[arg-type]
+                LossContext.build_batches(batches_loss_ctx, sp_mesh=self.sp_mesh)  # type: ignore[arg-type]
             )
 
         # train optimizer steps
diff --git a/xtuner/v1/rl/grpo/loss.py b/xtuner/v1/rl/grpo/loss.py
index ec34c0b15..b137fc316 100644
--- a/xtuner/v1/rl/grpo/loss.py
+++ b/xtuner/v1/rl/grpo/loss.py
@@ -4,7 +4,9 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from torch.distributed.device_mesh import DeviceMesh
 
+from xtuner.v1.loss.utils import sp_gather
 from xtuner.v1.utils import get_logger
 
 from ..base import (
@@ -75,15 +77,32 @@ def __init__(self, loss_cfg: GRPOLossConfig, loss_kwargs: GRPOLossKwargs):
         self.policy_loss_fn = get_policy_loss_fn(self.loss_cfg.policy_loss_cfg.get("loss_type", "vanilla"))
 
     @staticmethod
-    def build_batches(loss_ctx_list: list["GRPOLossContext"]) -> list["GRPOLossContext"]:  # type: ignore[override]
+    def build_batches(  # type: ignore[override]
+        loss_ctx_list: list["GRPOLossContext"],
+        *args: Any,
+        **kwargs: Any,
+    ) -> list["GRPOLossContext"]:
+        sp_mesh = cast(DeviceMesh | None, kwargs.get("sp_mesh"))
+        if sp_mesh is None and len(args) >= 2:
+            sp_mesh = cast(DeviceMesh | None, args[1])
+
         assert len(loss_ctx_list) > 0, "loss_ctx_list can not be empty"
         loss_cfg = loss_ctx_list[0].loss_cfg
         shifted_labels_list = [loss_ctx.loss_kwargs.shifted_labels for loss_ctx in loss_ctx_list]
+        rank_grad_tokens: torch.Tensor | None = None
+        for shifted_labels in shifted_labels_list:
+            if sp_mesh is not None:
+                # gather shifted_labels from different sp ranks to compute the correct loss weight
+                shifted_labels = sp_gather(shifted_labels, sp_mesh=sp_mesh, dim=1)
+            rank_grad_tokens = (
+                (shifted_labels != loss_cfg.ignore_idx).sum()
+                if rank_grad_tokens is None
+                else rank_grad_tokens + (shifted_labels != loss_cfg.ignore_idx).sum()
+            )
 
         # Compute the denominator used in the global calibration of the loss
-        rank_grad_tokens = sum((labels != loss_cfg.ignore_idx).sum() for labels in shifted_labels_list)
         rank_grad_tokens = cast(torch.Tensor, rank_grad_tokens)
         global_grad_tokens = rank_grad_tokens
         if dist.is_initialized():
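Note on `sp_gather` (not part of the patch): the patch assumes `xtuner.v1.loss.utils.sp_gather`
reassembles a tensor that sequence parallelism has sharded along the sequence dimension, so each
SP rank sees the full labels/logprobs row before counting gradient tokens or summing entropy.
A minimal sketch of that behaviour, assuming equal-sized shards concatenated in rank order; the
name `sp_gather_sketch` is illustrative, not xtuner's actual implementation, and a plain
`all_gather` suffices here because these tensors feed statistics, not gradients:

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import DeviceMesh

    def sp_gather_sketch(x: torch.Tensor, sp_mesh: DeviceMesh, dim: int = 1) -> torch.Tensor:
        # Each SP rank holds one contiguous, equal-sized shard along `dim`;
        # all-gather the shards and stitch them back together in rank order.
        group = sp_mesh.get_group()
        shards = [torch.empty_like(x) for _ in range(dist.get_world_size(group))]
        dist.all_gather(shards, x.contiguous(), group=group)
        return torch.cat(shards, dim=dim)

Under that assumption, a rank holding shifted labels of shape (batch, seq_len // sp_size) gets
back (batch, seq_len), so `mask = shifted_labels != -100` counts every token of the sequence
rather than only the local shard.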