diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 240e49eb4..758d01422 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -695,7 +695,7 @@ def _forward( for idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)): shifted_tensor = mtp_ctx.loss_kwargs.shifted_labels mtp_ctx.loss_kwargs.shifted_labels = roll_packed_tensor( - shifted_tensor, seq_ctx.cu_seq_lens_k, -idx - 1, dim=-1 + shifted_tensor, seq_ctx.cu_seq_lens_k, -idx - 1, dim=-1, fill_value=-100 ) mtp_hidden_states, mtp_router_results, mtp_router_weights = mtp_hidden diff --git a/xtuner/v1/module/mtp/utils.py b/xtuner/v1/module/mtp/utils.py index 4e174a16a..9d3b4d91f 100644 --- a/xtuner/v1/module/mtp/utils.py +++ b/xtuner/v1/module/mtp/utils.py @@ -10,6 +10,7 @@ def roll_packed_tensor( cu_seq_lens: torch.IntTensor, shifts: int = -1, dim: int = -1, + fill_value: float | int = 0, ) -> torch.Tensor: """Roll a packed tensor along the specified dimension. @@ -24,9 +25,12 @@ def roll_packed_tensor( Only negative shifts are supported. dim (int): Dimension along which to roll. The ``cu_seq_lens`` boundaries are applied on this dimension. Default is -1 (last dimension). + fill_value (float | int): Value used to fill boundary positions after rolling. + Defaults to 0. Use the loss ignore index (e.g., -100) when rolling label + tensors to ensure boundary positions are excluded from loss computation. Returns: - torch.Tensor: Rolled tensor with boundary positions zeroed. + torch.Tensor: Rolled tensor with boundary positions filled with ``fill_value``. 
Example: For packed sequences [1,2,3] and [4,5,6] with shifts=-1, dim=-1: @@ -39,7 +43,7 @@ def roll_packed_tensor( >>> tensor = torch.arange(12).reshape(1, 6, 2) >>> cu_seq_lens = torch.tensor([0, 3, 6], dtype=torch.int32) >>> rolled = roll_packed_tensor(tensor, cu_seq_lens, shifts=-1, dim=-2) - >>> rolled[0, 2] # tensor([0, 0]) (boundary zeroed) + >>> rolled[0, 2] # tensor([0, 0]) (boundary filled with fill_value=0) """ assert shifts <= 0, "Only negative shift is supported" @@ -57,13 +61,13 @@ def roll_packed_tensor( seq_slice = tensor.narrow(dim, start_idx, end_idx - start_idx) # type: ignore[arg-type] rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dim) - # Zero out the last |shifts| positions along dim to avoid information + # Fill the last |shifts| positions along dim to avoid information leakage across sequences. For shifts=-1 the last 1 position is - # zeroed; for shifts=-2 the last 2 positions are zeroed, etc. - zero_len = -shifts - zero_start = (end_idx - start_idx) - zero_len - zero_slice = rolled_seq.narrow(dim, zero_start, zero_len) # type: ignore[arg-type] - zero_slice.zero_() + # filled with ``fill_value``; for shifts=-2 the last 2 positions are filled, etc. + fill_len = -shifts + fill_start = (end_idx - start_idx) - fill_len + fill_slice = rolled_seq.narrow(dim, fill_start, fill_len) # type: ignore[arg-type] + fill_slice.fill_(fill_value) # Write back to the rolled tensor rolled_tensor.narrow(dim, start_idx, end_idx - start_idx).copy_(rolled_seq) # type: ignore[arg-type]