5 changes: 2 additions & 3 deletions tests/model/test_qwen3_5.py
@@ -9,7 +9,7 @@
import torch.distributed as dist
from xtuner.v1.model import Qwen3_5_VLMoE35BA3Config
from xtuner.v1.loss.ce_loss import CELossConfig
from xtuner.v1.model.moe.moe import SequenceContext
from xtuner.v1.model.moe.moe import SequenceContext, MTPConfig
from xtuner.v1.utils.test_utils import init_data_mesh
from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig
from xtuner.v1.config import FSDPConfig
@@ -233,6 +233,7 @@ def test_save_hf_with_mtp(self, device, sp_size):

with torch.device("meta"):
model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False)
model_cfg.text_config.mtp_config = MTPConfig(num_layers=1)
qwen3vl_model = model_cfg.build().to(torch.bfloat16)

fsdp_config = FSDPConfig(cpu_offload=False)
@@ -262,8 +263,6 @@ def test_save_hf_with_mtp(self, device, sp_size):

# Verify all original HF weights are preserved correctly
for key in origin_index["weight_map"].keys():
if "mtp" in key:
continue # TODO: remove this after MTP is implemented
origin_safetensor_name = origin_index["weight_map"][key]
saved_safetensor_name = saved_index["weight_map"][key]

4 changes: 4 additions & 0 deletions tests/train/test_trainer.py
@@ -112,6 +112,7 @@ def create_pg(self, device):

@patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
@patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
@prepare
def test_save_hf_interval(self):
"""Test save_hf is called at correct intervals during training."""
@@ -184,6 +185,7 @@ def test_save_hf_interval(self):

@patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
@patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
@prepare
def test_save_checkpoint_interval(self):
self.create_pg(DEVICE)
@@ -258,6 +260,7 @@ def test_save_checkpoint_interval(self):

@patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
@patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
@prepare
def test_resume(self):
self.create_pg(DEVICE)
@@ -738,6 +741,7 @@ def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
assert len(loaded.get_hooks(HookStage.AFTER_SAVE_DCP)) == 1


@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
@patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
def test_resume_and_load_checkpoint_cfg(tmp_path: Path):
# 0. prepare environment
3 changes: 3 additions & 0 deletions xtuner/v1/loss/__init__.py
@@ -11,6 +11,7 @@
ZLossContext,
ZLossKwargs,
)
from .mtp_loss import MTPLossContext
from .rl_loss import LogProbConfig, LogProbContext


@@ -29,6 +30,8 @@
"BaseLossConfig",
"BaseLossContext",
"BaseLossKwargs",
"LMHeadLossContext",
"MTPLossContext",
"LogProbConfig",
"LogProbContext",
]
92 changes: 92 additions & 0 deletions xtuner/v1/loss/mtp_loss.py
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch.distributed.device_mesh import DeviceMesh

from xtuner.v1.loss.ce_loss import CELossConfig, CELossKwargs, LMHeadLossContext
from xtuner.v1.module.mtp.utils import roll_packed_tensor
from xtuner.v1.utils.device import get_device


DEVICE = get_device()


class MTPLossKwargs(CELossKwargs):
"""Keyword arguments for MTP loss computation.

Inherits all fields from CELossKwargs. The ``shifted_labels`` field is
expected to be pre-rolled by ``MTPLossConfig.build()`` before this object
is constructed, so no additional fields are required.

Args:
shifted_labels (torch.Tensor): The shifted and rolled labels for MTP
loss computation.
loss_weight (torch.Tensor | None): Per-token loss weight.
"""


class MTPLossConfig(CELossConfig):
"""Loss configuration for Multi-Token Prediction (MTP).

Extends ``CELossConfig`` with a ``mtp_depth`` field that controls how many
additional positions the labels are rolled during ``build()``. This class
is intended for internal use by the model and is not exposed to users.

Args:
mtp_depth (int): 1-indexed MTP layer depth. The first MTP layer uses
``mtp_depth=1`` (shift=-1 on top of the existing label shift).
"""

mtp_depth: int

@property
def loss_ctx_cls(self) -> type["MTPLossContext"]:
return MTPLossContext

@property
def _loss_kwargs_cls(self) -> type["MTPLossKwargs"]:
return MTPLossKwargs

def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContext | None":
"""Build MTPLossContext from data dict.

Rolls ``shifted_labels`` by ``-mtp_depth`` positions (per-sequence,
respecting packed-sequence boundaries) before constructing the loss
context. The roll is performed on the full sequence prior to any
sequence-parallel split so that boundary positions and ``cu_seq_lens``
are always consistent.

Args:
data (dict): Data dict containing loss-related fields.
Required keys: ``shifted_labels``, ``seq_ctx``.
sp_mesh (DeviceMesh | None): Sequence parallel mesh.

Returns:
MTPLossContext | None: Built loss context, or ``None`` if
``shifted_labels`` is not present in ``data``.
"""
if "shifted_labels" not in data:
return None

shifted_labels = data["shifted_labels"]
cu_seq_lens = data["seq_ctx"].cu_seq_lens_k

rolled = roll_packed_tensor(shifted_labels, cu_seq_lens, shifts=-self.mtp_depth, dim=-1, fill_value=-100)

loss_kwargs = MTPLossKwargs(shifted_labels=rolled).to(DEVICE)
if sp_mesh is not None and sp_mesh.size() > 1:
loss_kwargs = loss_kwargs.sp_split(sp_mesh)

return MTPLossContext(self, loss_kwargs)


class MTPLossContext(LMHeadLossContext):
"""Loss context for Multi-Token Prediction (MTP).

Inherits all computation logic from ``LMHeadLossContext``. The label
rolling is handled upstream in ``MTPLossConfig.build()``, so no override
is needed here.

Args:
loss_cfg (MTPLossConfig): The MTP loss configuration.
loss_kwargs (MTPLossKwargs): Pre-rolled keyword arguments for loss
computation.
"""
3 changes: 2 additions & 1 deletion xtuner/v1/model/__init__.py
@@ -23,7 +23,7 @@
from .dense.qwen3 import Qwen3Dense0P6BConfig, Qwen3Dense4BConfig, Qwen3Dense8BConfig, Qwen3DenseConfig
from .moe.deepseek_v3 import DeepSeekV3Config
from .moe.gpt_oss import GptOss21BA3P6Config, GptOss117BA5P8Config, GptOssConfig
from .moe.moe import BalancingLossConfig, MoE, MoEModelOutputs, ZLossConfig
from .moe.moe import BalancingLossConfig, MoE, MoEConfig, MoEModelOutputs, ZLossConfig
from .moe.qwen3 import Qwen3MoE30BA3Config, Qwen3MoEConfig, Qwen3MoEFoPEConfig


@@ -87,6 +87,7 @@ def get_model_config_from_hf(model_path: Path):
"get_model_config",
"get_model_config_from_hf",
"MoE",
"MoEConfig",
"MoEModelOutputs",
"BalancingLossConfig",
"ZLossConfig",
3 changes: 3 additions & 0 deletions xtuner/v1/model/base.py
@@ -1521,6 +1521,9 @@ def _load_fused_hf_param(
continue
_loaded_tensor.append(weight.to(local_tensor.device))

if not _loaded_tensor:
return missing_keys

if not hf_keys:
# fp8 pad
assert self.config.float8_cfg is not None
6 changes: 3 additions & 3 deletions xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py
@@ -1,4 +1,4 @@
from xtuner.v1.model.base import TransformerConfig
from xtuner.v1.model.moe.moe import MoEConfig
from xtuner.v1.model.moe.qwen3_5_text import Qwen3_5_VLTextMoE35BA3BConfig
from xtuner.v1.utils import get_logger

@@ -19,7 +19,7 @@ class Qwen3_5_ProjectorConfig(Qwen3VLProjectorConfig):
class Qwen3_5_BaseConfig(Qwen3VLBaseConfig):
vision_config: Qwen3_5_VisionConfig
projector_config: Qwen3_5_ProjectorConfig
text_config: TransformerConfig
text_config: MoEConfig

image_token_id: int = 248056
video_token_id: int = 248057
@@ -30,4 +30,4 @@ class Qwen3_5_BaseConfig(Qwen3VLBaseConfig):
class Qwen3_5_VLMoE35BA3Config(Qwen3_5_BaseConfig):
vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig()
projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig()
text_config: TransformerConfig = Qwen3_5_VLTextMoE35BA3BConfig()
text_config: MoEConfig = Qwen3_5_VLTextMoE35BA3BConfig()
24 changes: 16 additions & 8 deletions xtuner/v1/model/dense/qwen3vl_text.py
@@ -1,9 +1,10 @@
import re

import torch
import torch.nn.functional as F

from xtuner.v1.data_proto import SequenceContext
from xtuner.v1.loss import CELossContext
from xtuner.v1.loss import BaseLossContext
from xtuner.v1.model.base import ModelOutputs

from .qwen3 import Qwen3Dense, Qwen3Dense4BConfig, Qwen3Dense8BConfig
@@ -34,10 +35,10 @@ def _deepstack_process(
hidden_states[visual_pos_masks, :] = local_this
return hidden_states

def forward(
def forward( # type: ignore[override]
self,
seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch
loss_ctx: CELossContext,
loss_ctx: dict[str, BaseLossContext | list[BaseLossContext]] | None = None,
) -> ModelOutputs:
input_ids = seq_ctx.input_ids
position_ids = seq_ctx.position_ids
@@ -78,11 +79,18 @@ def forward(

hidden_states = self.norm(hidden_states)

loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx)
output["loss"] = loss
output["logits"] = logits
output["extra_info"] = extra_info
return ModelOutputs(**output) # type: ignore[typeddict-item]
if loss_ctx is None:
# Inference mode
logits = F.linear(hidden_states, self.lm_head.weight, self.lm_head.bias)
output["logits"] = logits
else:
# Training mode
loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"]) # type: ignore[call-overload]
output["loss"] = loss
output["logits"] = logits
output["extra_info"] = extra_info

return ModelOutputs(**output)


class Qwen3VLTextDense4BConfig(Qwen3Dense4BConfig):
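The `forward()` change above splits inference from training on whether a loss context is supplied. Below is a self-contained toy of that control flow, with stand-in classes rather than xtuner's `Qwen3VLTextDense` or loss contexts: passing `loss_ctx=None` only projects hidden states to logits, while a dict with an `"lm"` entry delegates the loss computation to the loss context.

```python
# Toy illustration of the branching only; ToyCELossContext and ToyTextModel are
# hypothetical stand-ins, not xtuner classes.
import torch
import torch.nn.functional as F
from torch import nn


class ToyCELossContext:
    def __init__(self, shifted_labels: torch.Tensor):
        self.shifted_labels = shifted_labels

    def __call__(self, hidden: torch.Tensor, head: nn.Linear):
        # The loss context owns projection + loss, so the model never
        # has to handle labels itself.
        logits = F.linear(hidden, head.weight, head.bias)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            self.shifted_labels.view(-1),
            ignore_index=-100,
        )
        return loss, logits


class ToyTextModel(nn.Module):
    def __init__(self, hidden_size: int = 16, vocab_size: int = 32):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden: torch.Tensor, loss_ctx: dict | None = None) -> dict:
        output: dict = {}
        if loss_ctx is None:
            # Inference mode: plain lm_head projection, no labels required.
            output["logits"] = F.linear(hidden, self.lm_head.weight, self.lm_head.bias)
        else:
            # Training mode: the "lm" entry computes the loss from hidden states.
            loss, logits = loss_ctx["lm"](hidden, self.lm_head)
            output["loss"] = loss
            output["logits"] = logits
        return output


model = ToyTextModel()
hidden = torch.randn(1, 8, 16)
labels = torch.randint(0, 32, (1, 8))
print(model(hidden)["logits"].shape)                            # inference path
print(model(hidden, {"lm": ToyCELossContext(labels)})["loss"])  # training path
```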