Skip to content

Commit 3d3c3a9

Browse files
committed
[Feature] Add Multi-Token Prediction (MTP) module implementation
ghstack-source-id: e63ad27 Pull-Request: InternLM#1570
1 parent f794ae6 commit 3d3c3a9

File tree

17 files changed

+888
-53
lines changed

17 files changed: +888 / -53 lines changed

tests/model/test_qwen3_5.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch.distributed as dist
1010
from xtuner.v1.model import Qwen3_5_VLMoE35BA3Config
1111
from xtuner.v1.loss.ce_loss import CELossConfig
12-
from xtuner.v1.model.moe.moe import SequenceContext
12+
from xtuner.v1.model.moe.moe import SequenceContext, MTPConfig
1313
from xtuner.v1.utils.test_utils import init_data_mesh
1414
from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig
1515
from xtuner.v1.config import FSDPConfig
@@ -233,6 +233,7 @@ def test_save_hf_with_mtp(self, device, sp_size):
233233

234234
with torch.device("meta"):
235235
model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False)
236+
model_cfg.text_config.mtp_config = MTPConfig(num_layers=1)
236237
qwen3vl_model = model_cfg.build().to(torch.bfloat16)
237238

238239
fsdp_config = FSDPConfig(cpu_offload=False)
@@ -262,8 +263,6 @@ def test_save_hf_with_mtp(self, device, sp_size):
262263

263264
# Verify all original HF weights are preserved correctly
264265
for key in origin_index["weight_map"].keys():
265-
if "mtp" in key:
266-
continue # TODO: remove this after MTP is implemented
267266
origin_safetensor_name = origin_index["weight_map"][key]
268267
saved_safetensor_name = saved_index["weight_map"][key]
269268

tests/train/test_trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def create_pg(self, device):
112112

113113
@patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
114114
@patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
115+
@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
115116
@prepare
116117
def test_save_hf_interval(self):
117118
"""Test save_hf is called at correct intervals during training."""
@@ -184,6 +185,7 @@ def test_save_hf_interval(self):
184185

185186
@patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
186187
@patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
188+
@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
187189
@prepare
188190
def test_save_checkpoint_interval(self):
189191
self.create_pg(DEVICE)
@@ -258,6 +260,7 @@ def test_save_checkpoint_interval(self):
258260

259261
@patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
260262
@patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
263+
@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
261264
@prepare
262265
def test_resume(self):
263266
self.create_pg(DEVICE)
@@ -738,6 +741,7 @@ def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
738741
assert len(loaded.get_hooks(HookStage.AFTER_SAVE_DCP)) == 1
739742

740743

744+
@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
741745
@patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
742746
def test_resume_and_load_checkpoint_cfg(tmp_path: Path):
743747
# 0. prepare environment

xtuner/v1/loss/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
ZLossContext,
1212
ZLossKwargs,
1313
)
14+
from .mtp_loss import MTPLossContext
1415
from .rl_loss import LogProbConfig, LogProbContext
1516

1617

@@ -29,6 +30,8 @@
2930
"BaseLossConfig",
3031
"BaseLossContext",
3132
"BaseLossKwargs",
33+
"LMHeadLossContext",
34+
"MTPLossContext",
3235
"LogProbConfig",
3336
"LogProbContext",
3437
]

xtuner/v1/loss/mtp_loss.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from torch.distributed.device_mesh import DeviceMesh
3+
4+
from xtuner.v1.loss.ce_loss import CELossConfig, CELossKwargs, LMHeadLossContext
5+
from xtuner.v1.module.mtp.utils import roll_packed_tensor
6+
from xtuner.v1.utils.device import get_device
7+
8+
9+
DEVICE = get_device()
10+
11+
12+
class MTPLossKwargs(CELossKwargs):
13+
"""Keyword arguments for MTP loss computation.
14+
15+
Inherits all fields from CELossKwargs. The ``shifted_labels`` field is
16+
expected to be pre-rolled by ``MTPLossConfig.build()`` before this object
17+
is constructed, so no additional fields are required.
18+
19+
Args:
20+
shifted_labels (torch.Tensor): The shifted and rolled labels for MTP
21+
loss computation.
22+
loss_weight (torch.Tensor | None): Per-token loss weight.
23+
"""
24+
25+
26+
class MTPLossConfig(CELossConfig):
27+
"""Loss configuration for Multi-Token Prediction (MTP).
28+
29+
Extends ``CELossConfig`` with a ``mtp_depth`` field that controls how many
30+
additional positions the labels are rolled during ``build()``. This class
31+
is intended for internal use by the model and is not exposed to users.
32+
33+
Args:
34+
mtp_depth (int): 1-indexed MTP layer depth. The first MTP layer uses
35+
``mtp_depth=1`` (shift=-1 on top of the existing label shift).
36+
"""
37+
38+
mtp_depth: int
39+
40+
@property
41+
def loss_ctx_cls(self) -> type["MTPLossContext"]:
42+
return MTPLossContext
43+
44+
@property
45+
def _loss_kwargs_cls(self) -> type["MTPLossKwargs"]:
46+
return MTPLossKwargs
47+
48+
def build(self, data: dict, sp_mesh: DeviceMesh | None = None) -> "MTPLossContext | None":
49+
"""Build MTPLossContext from data dict.
50+
51+
Rolls ``shifted_labels`` by ``-mtp_depth`` positions (per-sequence,
52+
respecting packed-sequence boundaries) before constructing the loss
53+
context. The roll is performed on the full sequence prior to any
54+
sequence-parallel split so that boundary positions and ``cu_seq_lens``
55+
are always consistent.
56+
57+
Args:
58+
data (dict): Data dict containing loss-related fields.
59+
Required keys: ``shifted_labels``, ``seq_ctx``.
60+
sp_mesh (DeviceMesh | None): Sequence parallel mesh.
61+
62+
Returns:
63+
MTPLossContext | None: Built loss context, or ``None`` if
64+
``shifted_labels`` is not present in ``data``.
65+
"""
66+
if "shifted_labels" not in data:
67+
return None
68+
69+
shifted_labels = data["shifted_labels"]
70+
cu_seq_lens = data["seq_ctx"].cu_seq_lens_k
71+
72+
rolled = roll_packed_tensor(shifted_labels, cu_seq_lens, shifts=-self.mtp_depth, dim=-1, fill_value=-100)
73+
74+
loss_kwargs = MTPLossKwargs(shifted_labels=rolled).to(DEVICE)
75+
if sp_mesh is not None and sp_mesh.size() > 1:
76+
loss_kwargs = loss_kwargs.sp_split(sp_mesh)
77+
78+
return MTPLossContext(self, loss_kwargs)
79+
80+
81+
class MTPLossContext(LMHeadLossContext):
82+
"""Loss context for Multi-Token Prediction (MTP).
83+
84+
Inherits all computation logic from ``LMHeadLossContext``. The label
85+
rolling is handled upstream in ``MTPLossConfig.build()``, so no override
86+
is needed here.
87+
88+
Args:
89+
loss_cfg (MTPLossConfig): The MTP loss configuration.
90+
loss_kwargs (MTPLossKwargs): Pre-rolled keyword arguments for loss
91+
computation.
92+
"""

xtuner/v1/model/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from .dense.qwen3 import Qwen3Dense0P6BConfig, Qwen3Dense4BConfig, Qwen3Dense8BConfig, Qwen3DenseConfig
2424
from .moe.deepseek_v3 import DeepSeekV3Config
2525
from .moe.gpt_oss import GptOss21BA3P6Config, GptOss117BA5P8Config, GptOssConfig
26-
from .moe.moe import BalancingLossConfig, MoE, MoEModelOutputs, ZLossConfig
26+
from .moe.moe import BalancingLossConfig, MoE, MoEConfig, MoEModelOutputs, ZLossConfig
2727
from .moe.qwen3 import Qwen3MoE30BA3Config, Qwen3MoEConfig, Qwen3MoEFoPEConfig
2828

2929

@@ -87,6 +87,7 @@ def get_model_config_from_hf(model_path: Path):
8787
"get_model_config",
8888
"get_model_config_from_hf",
8989
"MoE",
90+
"MoEConfig",
9091
"MoEModelOutputs",
9192
"BalancingLossConfig",
9293
"ZLossConfig",

xtuner/v1/model/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,6 +1557,9 @@ def _load_fused_hf_param(
15571557
continue
15581558
_loaded_tensor.append(weight.to(local_tensor.device))
15591559

1560+
if not _loaded_tensor:
1561+
return missing_keys
1562+
15601563
if not hf_keys:
15611564
# fp8 pad
15621565
assert self.config.float8_cfg is not None

xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from xtuner.v1.model.base import TransformerConfig
1+
from xtuner.v1.model.moe.moe import MoEConfig
22
from xtuner.v1.model.moe.qwen3_5_text import Qwen3_5_VLTextMoE35BA3BConfig
33
from xtuner.v1.utils import get_logger
44

@@ -19,7 +19,7 @@ class Qwen3_5_ProjectorConfig(Qwen3VLProjectorConfig):
1919
class Qwen3_5_BaseConfig(Qwen3VLBaseConfig):
2020
vision_config: Qwen3_5_VisionConfig
2121
projector_config: Qwen3_5_ProjectorConfig
22-
text_config: TransformerConfig
22+
text_config: MoEConfig
2323

2424
image_token_id: int = 248056
2525
video_token_id: int = 248057
@@ -30,4 +30,4 @@ class Qwen3_5_BaseConfig(Qwen3VLBaseConfig):
3030
class Qwen3_5_VLMoE35BA3Config(Qwen3_5_BaseConfig):
3131
vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig()
3232
projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig()
33-
text_config: TransformerConfig = Qwen3_5_VLTextMoE35BA3BConfig()
33+
text_config: MoEConfig = Qwen3_5_VLTextMoE35BA3BConfig()

xtuner/v1/model/dense/qwen3vl_text.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import re
22

33
import torch
4+
import torch.nn.functional as F
45

56
from xtuner.v1.data_proto import SequenceContext
6-
from xtuner.v1.loss import CELossContext
7+
from xtuner.v1.loss import BaseLossContext
78
from xtuner.v1.model.base import ModelOutputs
89

910
from .qwen3 import Qwen3Dense, Qwen3Dense4BConfig, Qwen3Dense8BConfig
@@ -34,10 +35,10 @@ def _deepstack_process(
3435
hidden_states[visual_pos_masks, :] = local_this
3536
return hidden_states
3637

37-
def forward(
38+
def forward( # type: ignore[override]
3839
self,
3940
seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch
40-
loss_ctx: CELossContext,
41+
loss_ctx: dict[str, BaseLossContext | list[BaseLossContext]] | None = None,
4142
) -> ModelOutputs:
4243
input_ids = seq_ctx.input_ids
4344
position_ids = seq_ctx.position_ids
@@ -78,11 +79,18 @@ def forward(
7879

7980
hidden_states = self.norm(hidden_states)
8081

81-
loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx)
82-
output["loss"] = loss
83-
output["logits"] = logits
84-
output["extra_info"] = extra_info
85-
return ModelOutputs(**output) # type: ignore[typeddict-item]
82+
if loss_ctx is None:
83+
# Inference mode
84+
logits = F.linear(hidden_states, self.lm_head.weight, self.lm_head.bias)
85+
output["logits"] = logits
86+
else:
87+
# Training mode
88+
loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"]) # type: ignore[call-overload]
89+
output["loss"] = loss
90+
output["logits"] = logits
91+
output["extra_info"] = extra_info
92+
93+
return ModelOutputs(**output)
8694

8795

8896
class Qwen3VLTextDense4BConfig(Qwen3Dense4BConfig):

0 commit comments

Comments (0)