Commit f48a247

[Feature] Add Multi-Token Prediction (MTP) module implementation
ghstack-source-id: 1b45af1 Pull-Request: InternLM#1570
Parent: 960c475

13 files changed, +773 −43 lines

tests/train/test_trainer.py

Lines changed: 4 additions & 0 deletions
@@ -112,6 +112,7 @@ def create_pg(self, device):

     @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
     @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
+    @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
     @prepare
     def test_save_hf_interval(self):
         """Test save_hf is called at correct intervals during training."""
@@ -184,6 +185,7 @@ def test_save_hf_interval(self):

     @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
     @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
+    @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
     @prepare
     def test_save_checkpoint_interval(self):
         self.create_pg(DEVICE)
@@ -258,6 +260,7 @@ def test_save_checkpoint_interval(self):

     @patch("xtuner.v1.train.trainer.is_hf_model_path", Mock(return_value=True))
     @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
+    @patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
     @prepare
     def test_resume(self):
         self.create_pg(DEVICE)
@@ -738,6 +741,7 @@ def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
     assert len(loaded.get_hooks(HookStage.AFTER_SAVE_DCP)) == 1


+@patch("xtuner.v1.train.trainer.Trainer._prepare_model_input", Mock(return_value=[]))
 @patch("xtuner.v1.train.trainer.Trainer.build_engine", Mock(side_effect=lambda *args, **kwargs: FakeEngine()))
 def test_resume_and_load_checkpoint_cfg(tmp_path: Path):
     # 0. prepare environment
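
Each of these tests now also patches the new `Trainer._prepare_model_input` hook to return an empty list, so `FakeEngine` never sees real batch construction. A minimal sketch of the same patching pattern, using a stand-in class rather than the real xtuner trainer:

from unittest.mock import Mock, patch

class Trainer:  # stand-in for xtuner.v1.train.trainer.Trainer
    def _prepare_model_input(self, batch):
        raise RuntimeError("would build real model inputs")

    def fit(self):
        # The trainer consumes whatever _prepare_model_input yields.
        return self._prepare_model_input(batch=None)

@patch.object(Trainer, "_prepare_model_input", Mock(return_value=[]))
def test_fit_skips_input_prep():
    assert Trainer().fit() == []  # the Mock's return_value stands in for real inputs

test_fit_skips_input_prep()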

xtuner/v1/model/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,7 @@
 from .dense.qwen3 import Qwen3Dense0P6BConfig, Qwen3Dense4BConfig, Qwen3Dense8BConfig, Qwen3DenseConfig
 from .moe.deepseek_v3 import DeepSeekV3Config
 from .moe.gpt_oss import GptOss21BA3P6Config, GptOss117BA5P8Config, GptOssConfig
-from .moe.moe import BalancingLossConfig, MoE, MoEModelOutputs, ZLossConfig
+from .moe.moe import BalancingLossConfig, MoE, MoEConfig, MoEModelOutputs, ZLossConfig
 from .moe.qwen3 import Qwen3MoE30BA3Config, Qwen3MoEConfig, Qwen3MoEFoPEConfig


@@ -87,6 +87,7 @@ def get_model_config_from_hf(model_path: Path):
     "get_model_config",
     "get_model_config_from_hf",
     "MoE",
+    "MoEConfig",
     "MoEModelOutputs",
     "BalancingLossConfig",
     "ZLossConfig",

xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py

Lines changed: 3 additions & 3 deletions
@@ -1,4 +1,4 @@
-from xtuner.v1.model.base import TransformerConfig
+from xtuner.v1.model.moe.moe import MoEConfig
 from xtuner.v1.model.moe.qwen3_5_text import Qwen3_5_VLTextMoE35BA3BConfig
 from xtuner.v1.utils import get_logger

@@ -19,7 +19,7 @@ class Qwen3_5_ProjectorConfig(Qwen3VLProjectorConfig):
 class Qwen3_5_BaseConfig(Qwen3VLBaseConfig):
     vision_config: Qwen3_5_VisionConfig
     projector_config: Qwen3_5_ProjectorConfig
-    text_config: TransformerConfig
+    text_config: MoEConfig

     image_token_id: int = 248056
     video_token_id: int = 248057
@@ -30,4 +30,4 @@ class Qwen3_5_BaseConfig(Qwen3VLBaseConfig):
 class Qwen3_5_VLMoE35BA3Config(Qwen3_5_BaseConfig):
     vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig()
     projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig()
-    text_config: TransformerConfig = Qwen3_5_VLTextMoE35BA3BConfig()
+    text_config: MoEConfig = Qwen3_5_VLTextMoE35BA3BConfig()
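
Narrowing `text_config` from `TransformerConfig` to `MoEConfig` lets MoE-only fields on the text config type-check without casts. A minimal sketch of the effect using stand-in dataclasses (the field names are illustrative; the real configs are richer pydantic models):

from dataclasses import dataclass

@dataclass
class TransformerConfig:
    hidden_size: int = 2048

@dataclass
class MoEConfig(TransformerConfig):
    n_routed_experts: int = 128  # illustrative MoE-only field

@dataclass
class VLBaseConfig:
    text_config: MoEConfig  # was TransformerConfig; MoE fields now visible to type checkers

cfg = VLBaseConfig(text_config=MoEConfig())
print(cfg.text_config.n_routed_experts)  # no cast or isinstance() needed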

xtuner/v1/model/dense/qwen3vl_text.py

Lines changed: 16 additions & 8 deletions
@@ -1,9 +1,10 @@
 import re

 import torch
+import torch.nn.functional as F

 from xtuner.v1.data_proto import SequenceContext
-from xtuner.v1.loss import CELossContext
+from xtuner.v1.loss import BaseLossContext
 from xtuner.v1.model.base import ModelOutputs

 from .qwen3 import Qwen3Dense, Qwen3Dense4BConfig, Qwen3Dense8BConfig
@@ -34,10 +35,10 @@ def _deepstack_process(
         hidden_states[visual_pos_masks, :] = local_this
         return hidden_states

-    def forward(
+    def forward(  # type: ignore[override]
         self,
         seq_ctx: SequenceContext,  # todo(@yehaochen): support intra layer micro-batch
-        loss_ctx: CELossContext,
+        loss_ctx: dict[str, BaseLossContext | list[BaseLossContext]] | None = None,
     ) -> ModelOutputs:
         input_ids = seq_ctx.input_ids
         position_ids = seq_ctx.position_ids
@@ -78,11 +79,18 @@ def forward(

         hidden_states = self.norm(hidden_states)

-        loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx)
-        output["loss"] = loss
-        output["logits"] = logits
-        output["extra_info"] = extra_info
-        return ModelOutputs(**output)  # type: ignore[typeddict-item]
+        if loss_ctx is None:
+            # Inference mode
+            logits = F.linear(hidden_states, self.lm_head.weight, self.lm_head.bias)
+            output["logits"] = logits
+        else:
+            # Training mode
+            loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"])  # type: ignore[call-overload]
+            output["loss"] = loss
+            output["logits"] = logits
+            output["extra_info"] = extra_info
+
+        return ModelOutputs(**output)


 class Qwen3VLTextDense4BConfig(Qwen3Dense4BConfig):
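
The reworked signature makes `loss_ctx` optional and keyed by loss name, so additional entries (e.g. for MTP heads) can ride alongside the main "lm" context, while `loss_ctx=None` selects a logits-only inference path via `F.linear`. A runnable stand-in for the dual-mode contract (the real `lm_head` fuses loss computation with a `BaseLossContext`; plain labels and cross-entropy are used here for brevity):

import torch
import torch.nn.functional as F

class TinyLMModel(torch.nn.Module):
    """Stand-in mirroring the patched forward: None -> inference, dict -> training."""

    def __init__(self, hidden: int = 8, vocab: int = 32):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)

    def forward(self, hidden_states, loss_ctx=None):
        if loss_ctx is None:
            # Inference mode: raw logits straight from the head's weights.
            return {"logits": F.linear(hidden_states, self.lm_head.weight, self.lm_head.bias)}
        # Training mode: the "lm" entry drives the loss computation.
        logits = self.lm_head(hidden_states)
        loss = F.cross_entropy(logits.flatten(0, 1), loss_ctx["lm"].flatten())
        return {"logits": logits, "loss": loss}

model = TinyLMModel()
h = torch.randn(2, 5, 8)
print(model(h)["logits"].shape)  # inference path -> torch.Size([2, 5, 32])
labels = torch.randint(0, 32, (2, 5))
print(model(h, loss_ctx={"lm": labels})["loss"])  # training path -> scalar loss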
