Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xtuner/v1/model/moe/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def _forward(
for idx, (mtp_hidden, mtp_ctx) in enumerate(zip(mtp_outputs, mtp_loss_ctx_list)):
shifted_tensor = mtp_ctx.loss_kwargs.shifted_labels
mtp_ctx.loss_kwargs.shifted_labels = roll_packed_tensor(
shifted_tensor, seq_ctx.cu_seq_lens_k, -idx - 1, dim=-1
shifted_tensor, seq_ctx.cu_seq_lens_k, -idx - 1, dim=-1, fill_value=-100
)

mtp_hidden_states, mtp_router_results, mtp_router_weights = mtp_hidden
Expand Down
20 changes: 12 additions & 8 deletions xtuner/v1/module/mtp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def roll_packed_tensor(
cu_seq_lens: torch.IntTensor,
shifts: int = -1,
dim: int = -1,
fill_value: float | int = 0,
) -> torch.Tensor:
"""Roll a packed tensor along the specified dimension.

Expand All @@ -24,9 +25,12 @@ def roll_packed_tensor(
Only non-positive shifts (``shifts <= 0``) are supported.
dim (int): Dimension along which to roll. The ``cu_seq_lens`` boundaries
are applied on this dimension. Default is -1 (last dimension).
fill_value (float | int): Value used to fill boundary positions after rolling.
Defaults to 0. Use the loss ignore index (e.g., -100) when rolling label
tensors to ensure boundary positions are excluded from loss computation.

Returns:
torch.Tensor: Rolled tensor with boundary positions zeroed.
torch.Tensor: Rolled tensor with boundary positions filled with ``fill_value``.

Example:
For packed sequences [1,2,3] and [4,5,6] with shifts=-1, dim=-1:
Expand All @@ -39,7 +43,7 @@ def roll_packed_tensor(
>>> tensor = torch.arange(12).reshape(1, 6, 2)
>>> cu_seq_lens = torch.tensor([0, 3, 6], dtype=torch.int32)
>>> rolled = roll_packed_tensor(tensor, cu_seq_lens, shifts=-1, dim=-2)
>>> rolled[0, 2] # tensor([0, 0]) (boundary zeroed)
>>> rolled[0, 2] # tensor([0, 0]) (boundary filled with fill_value=0)
"""
assert shifts <= 0, "Only negative shift is supported"

Expand All @@ -57,13 +61,13 @@ def roll_packed_tensor(
seq_slice = tensor.narrow(dim, start_idx, end_idx - start_idx) # type: ignore[arg-type]
rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dim)

# Zero out the last |shifts| positions along dim to avoid information
# Fill the last |shifts| positions along dim to avoid information
# leakage across sequences. For shifts=-1 the last 1 position is
# zeroed; for shifts=-2 the last 2 positions are zeroed, etc.
zero_len = -shifts
zero_start = (end_idx - start_idx) - zero_len
zero_slice = rolled_seq.narrow(dim, zero_start, zero_len) # type: ignore[arg-type]
zero_slice.zero_()
# filled; for shifts=-2 the last 2 positions are filled, etc.
fill_len = -shifts
fill_start = (end_idx - start_idx) - fill_len
fill_slice = rolled_seq.narrow(dim, fill_start, fill_len) # type: ignore[arg-type]
fill_slice.fill_(fill_value)

# Write back to the rolled tensor
rolled_tensor.narrow(dim, start_idx, end_idx - start_idx).copy_(rolled_seq) # type: ignore[arg-type]
Expand Down
Loading