Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions tests/model/test_qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
from xtuner.v1.config import FSDPConfig
from xtuner.v1.model.compose.qwen3_vl.modeling_vision import init_world_mesh

import tempfile
from pathlib import Path
import json
from safetensors import safe_open


VIDEO_ROOT = os.environ["VIDEO_ROOT"]

Expand Down Expand Up @@ -216,6 +221,82 @@ def test_qwen3_5_vl_run(self, device, sp_size, tol):
self.assertTrue(torch.allclose(loss_xtuner_image_fsdp, loss_xtuner_image, atol=tol, rtol=tol))
self.assertTrue(torch.allclose(loss_xtuner_video_fsdp, loss_xtuner_video, atol=tol, rtol=tol))

@parametrize.parametrize(
"device,sp_size",
[
("cuda", 1),
],
)
def test_save_hf_with_mtp(self, device, sp_size):
self.create_pg(device)
QWEN3_VL_MOE_PATH = os.environ["QWEN3_5_MOE_PATH"]

with torch.device("meta"):
model_cfg = Qwen3_5_VLMoE35BA3Config(compile_cfg=False)
qwen3vl_model = model_cfg.build().to(torch.bfloat16)

fsdp_config = FSDPConfig(cpu_offload=False)
fsdp_mesh = init_world_mesh()
qwen3vl_model.vision_tower.fsdp_mesh = fsdp_mesh
qwen3vl_model.vision_tower.fsdp_config = fsdp_config
qwen3vl_model.fully_shard(fsdp_config=fsdp_config)

with tempfile.TemporaryDirectory() as tmpdir:
syncdir = [tmpdir]
dist.broadcast_object_list(syncdir, src=0)
tmpdir = Path(syncdir[0])
qwen3vl_model.from_hf(QWEN3_VL_MOE_PATH)
qwen3vl_model.save_hf(tmpdir)

origin_hf_path = Path(QWEN3_VL_MOE_PATH)
origin_index_path = origin_hf_path / "model.safetensors.index.json"
saved_index_path = tmpdir / "model.safetensors.index.json"

if dist.get_rank() == 0:
with open(origin_index_path, "r") as f:
origin_index = json.load(f)
with open(saved_index_path, "r") as f:
saved_index = json.load(f)

cache_save_fh: dict = {}

# Verify all original HF weights are preserved correctly
for key in origin_index["weight_map"].keys():
if "mtp" in key:
continue # TODO: remove this after MTP is implemented
origin_safetensor_name = origin_index["weight_map"][key]
saved_safetensor_name = saved_index["weight_map"][key]

origin_sf_fh_name = str(origin_hf_path / origin_safetensor_name)
saved_sf_fh_name = str(tmpdir / saved_safetensor_name)

if origin_sf_fh_name not in cache_save_fh:
cache_save_fh[origin_sf_fh_name] = safe_open(origin_sf_fh_name, framework="pt")
if saved_sf_fh_name not in cache_save_fh:
cache_save_fh[saved_sf_fh_name] = safe_open(saved_sf_fh_name, framework="pt")

origin_tensor = cache_save_fh[origin_sf_fh_name].get_tensor(key)
saved_tensor = cache_save_fh[saved_sf_fh_name].get_tensor(key)

self.assertTrue(torch.equal(origin_tensor, saved_tensor), f"Tensor mismatch for key: {key}")

# Verify MTP weights are present in the saved output
mtp_keys = [key for key in saved_index["weight_map"].keys() if key.startswith("mtp.")]
# TODO: remove skip after MTP is implemented
_ = mtp_keys

# Verify the tensor count in safetensors matches the saved index
safetensor_keys: list[str] = []
for safetensor_path in tmpdir.glob("*.safetensors"):
fh = safe_open(str(safetensor_path), framework="pt")
safetensor_keys.extend(fh.keys())
safetensor_keys.sort()
model_index_keys = list(saved_index["weight_map"].keys())
model_index_keys.sort()
self.assertListEqual(safetensor_keys, model_index_keys)

dist.barrier()

@property
def world_size(self) -> int:
return int(os.getenv("XTUNER_TEST_WORLD_SIZE", "4"))
146 changes: 137 additions & 9 deletions xtuner/v1/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
MixedPrecisionPolicy,
fully_shard,
)
from torch.distributed.tensor import DTensor, Placement, Shard
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, distribute_tensor
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
from typing_extensions import NotRequired, Self, TypedDict, overload

Expand Down Expand Up @@ -82,6 +82,12 @@ class HFSaveCfg(PydanticBaseModel):
worker_per_rank: Annotated[int, Parameter(group="model")] = 16
max_save_rank: Annotated[int, Parameter(group="model")] = 16
bucket_size: Annotated[int, Parameter(group="model")] = 1024**3 * 4
# TODO: `XTunerBaseModel` should also be able to specify which parameters to be trained in fp32,
# currently it could only be specified in HFSaveCfg
# Each entry is a **regex** pattern (passed to `re.search`) matched against the HF parameter name.
# Remember to escape literal dots, e.g. use r"model\.layers\.\d+\.weight" instead of
# r"model.layers.\d+.weight" to avoid unintended wildcard matches.
fp32_keys_pattern: Annotated[list[str] | None, Parameter(group="model")] = None


class XTunerBaseModelConfig(PydanticBaseModel):
Expand Down Expand Up @@ -313,6 +319,7 @@ def fully_shard(
"""Fully shard the model parameters."""
self.fsdp_config = fsdp_config
self.fsdp_mesh = self._init_world_mesh()
self._world_mesh = self.fsdp_mesh

if self.fsdp_config.requires_grad:
for name, module in self.named_modules():
Expand All @@ -337,15 +344,79 @@ def fully_shard(
mp_policy = MixedPrecisionPolicy(
param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
)
fully_shard(
self,
self._fully_shard(
mesh=self.fsdp_mesh,
mp_policy=mp_policy,
reshard_after_forward=fsdp_config.reshard_after_forward,
offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
)
return self

    def _fully_shard(
        self,
        mesh: DeviceMesh,
        mp_policy: MixedPrecisionPolicy,
        reshard_after_forward: bool,
        offload_policy: CPUOffloadPolicy | None,
        module: nn.Module | None = None,
    ) -> None:
        """Apply FSDP ``fully_shard`` to ``module`` (or ``self``), keeping fp32 params out.

        Parameters whose HF-style name matches any regex in
        ``hf_save_cfg.fp32_keys_pattern`` are excluded from FSDP sharding:
        plain tensors are first replicated across ``self.world_mesh`` as
        DTensors, then all matches are passed to ``fully_shard`` via
        ``ignored_params`` so FSDP neither shards nor dtype-casts them.

        Args:
            mesh: Device mesh FSDP shards over.
            mp_policy: Mixed-precision policy forwarded to ``fully_shard``.
            reshard_after_forward: Forwarded to ``fully_shard``.
            offload_policy: CPU offload policy, or ``None`` to disable offload.
            module: Sub-module to shard; defaults to ``self`` when ``None``.
        """

        def traverse(module):
            # recurse=False: visit each parameter exactly once, at the module
            # that owns it, so register_parameter targets the right module.
            for name, param in module.named_parameters(recurse=False):
                full_name = full_param_name_mapping[id(param)]
                full_name = self._clean_param_name(full_name)
                hf_name_list = self.to_hf_key_list(full_name)

                for hf_name in hf_name_list:
                    if any(re.search(p, hf_name) for p in patterns):  # type: ignore
                        if not isinstance(param, DTensor):
                            # Replicate the plain tensor on every rank of
                            # world_mesh so it stays identical everywhere.
                            dist_param = nn.Parameter(
                                distribute_tensor(
                                    param, self.world_mesh, [Replicate() for _ in range(self.world_mesh.ndim)]
                                ),
                                requires_grad=param.requires_grad,
                            )
                            module.register_parameter(name, dist_param)
                            ignored_params.add(dist_param)
                        else:
                            # param is already a DTensor (e.g. distributed by
                            # MoE._replicate_other_params on ep_mesh before _fully_shard
                            # is called). We skip re-distributing on world_mesh and just
                            # add it to ignored_params so FSDP leaves it alone.
                            # ASSUMPTION: fp32 distribution always happens AFTER any
                            # prior EP distribution, so the existing placement is correct.
                            ignored_params.add(param)
                        # One match is enough; stop scanning this param's HF keys.
                        break

            for child in module.children():
                traverse(child)

        # Collect the parameters of `target` that match any fp32 pattern so they can be
        # excluded from FSDP sharding (passed as `ignored_params`).
        #
        # We intentionally iterate over `self.named_parameters()` rather than
        # `target.named_parameters()` so that `name` is always relative to the root model
        # (`self`). This matters when `target` is a sub-module (e.g. `self.embed_tokens`):
        # `target.named_parameters()` would yield bare names like `"weight"`, which
        # `to_hf_key_list` cannot resolve correctly. By iterating from `self` we get the
        # full path (e.g. `"embed_tokens.weight"`) and filter to `target`'s parameters
        # using identity comparison.
        full_param_name_mapping = {id(param): name for name, param in self.named_parameters()}
        ignored_params: set[nn.Parameter] = set()
        patterns = self.config.hf_save_cfg.fp32_keys_pattern

        target = module or self
        if patterns:
            traverse(target)

        fully_shard(
            target,
            mesh=mesh,
            mp_policy=mp_policy,
            reshard_after_forward=reshard_after_forward,
            offload_policy=offload_policy,
            # FSDP rejects an empty ignored set; pass None instead.
            ignored_params=ignored_params if ignored_params else None,
        )

def save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16, safetensors_prefix: str = "model"):
with profile_time_and_memory(f"[Saving HF to [{safetensors_prefix}]{hf_dir} cost]"):
self._save_hf(hf_dir=hf_dir, save_dtype=save_dtype, safetensors_prefix=safetensors_prefix)
Expand Down Expand Up @@ -396,6 +467,12 @@ def device(self) -> torch.device:
return torch.device("cpu")
return torch.device(DEVICE)

@property
def world_mesh(self) -> DeviceMesh | None:
if not hasattr(self, "_world_mesh"):
self._world_mesh = self._init_world_mesh()
return self._world_mesh

@property
def default_compile_cfg(self) -> dict[str, TorchCompileOption]:
return {}
Expand Down Expand Up @@ -670,6 +747,12 @@ def post_micro_batch_forward(self, batch_outputs: Sequence[ModelOutputs]) -> Bat
)
return ret

def _get_save_dtype(self, name: str, dtype: torch.dtype) -> torch.dtype:
patterns = self.config.hf_save_cfg.fp32_keys_pattern
if patterns and any(re.search(p, name) for p in patterns):
return torch.float32
return dtype

def _get_shard_hf_param(
self,
params: list[tuple[torch.Tensor, LoadSpec]],
Expand All @@ -679,6 +762,16 @@ def _get_shard_hf_param(
) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]:
if not params:
return

ignored_params, params = self._split_ignored_params(params)
if ignored_params:
name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
hf_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params]
yield name_list, hf_params

if not params:
return

if dtype != torch.bfloat16:
raise NotImplementedError

Expand All @@ -696,7 +789,7 @@ def _get_hf_params(fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]]) -> lis
# Get unsharded params
_unsharded_tensor_list = foreach_all_gather(fsdp_unsharded_tensor_list, load_spec0.group)
unsharded_tensor_list = [
torch.cat([i.to(dtype) for i in tensors], dim=load_spec0.dim) for tensors in _unsharded_tensor_list
torch.cat(list(tensors), dim=load_spec0.dim) for tensors in _unsharded_tensor_list
]
name_list = [spec.hf_keys[0] for _, spec in fsdp_tensor_list]
unsharded_tensor_list = [
Expand All @@ -711,11 +804,11 @@ def _get_hf_params(fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]]) -> lis

safetensor_size = 0
tensor_list: list[tuple[torch.Tensor, LoadSpec]] = []
name_list: list[str] = []
name_list = []

for param, load_spec in params:
local_tensor = param._local_tensor if isinstance(param, DTensor) else param
local_tensor = local_tensor.to(dtype=dtype)
local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
tensor_size = self._get_tensor_size(param, dtype)
if safetensor_size + tensor_size > bucket_size and tensor_list:
hf_params = _get_hf_params(tensor_list)
Expand Down Expand Up @@ -744,6 +837,12 @@ def _get_fused_hf_param(
if not params:
return

ignored_params, params = self._split_ignored_params(params)
if ignored_params:
fp32_name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
fp32_params = [param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params]
yield fp32_name_list, fp32_params

def _get_hf_params(
fsdp_tensor_list: list[tuple[torch.Tensor, LoadSpec]],
name_list: list[str],
Expand Down Expand Up @@ -867,7 +966,7 @@ def _get_hf_params(

for param, load_spec in params:
local_tensor = param._local_tensor if isinstance(param, DTensor) else param
local_tensor = local_tensor.bfloat16()
local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
tensor_size = self._get_tensor_size(param, dtype)
if safetensor_size + tensor_size > bucket_size and tensor_list:
hf_params, name_list = _get_hf_params(tensor_list, name_list)
Expand All @@ -893,6 +992,15 @@ def _get_same_hf_param(
) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]:
if not params:
return

ignored_params, params = self._split_ignored_params(params)
if ignored_params:
fp32_name_list: list[str] = [load_spec.hf_keys[0] for _, load_spec in ignored_params]
fp32_tensor_list: list[torch.Tensor] = [
param._local_tensor if isinstance(param, DTensor) else param for param, _ in ignored_params
]
yield fp32_name_list, fp32_tensor_list

if bucket_size is None:
bucket_size = self.config.hf_save_cfg.bucket_size
safetensor_size = 0
Expand All @@ -909,7 +1017,7 @@ def _get_same_hf_param(
buffer_name_list.append(load_spec.hf_keys[0])
continue
local_tensor = param._local_tensor if isinstance(param, DTensor) else param
local_tensor = local_tensor.bfloat16()
local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
tensor_size = self._get_tensor_size(param, dtype)
if safetensor_size + tensor_size > bucket_size and tensor_list:
if self.fsdp_mesh is not None:
Expand Down Expand Up @@ -953,6 +1061,21 @@ def _get_same_hf_param(
if buffer_tensor_list:
yield buffer_name_list, buffer_tensor_list

def _is_ignored_params(self, key: str):
patterns = self.config.hf_save_cfg.fp32_keys_pattern
if patterns is None:
return False
return any(re.search(p, key) for p in patterns)

def _split_ignored_params(
self, params: list[tuple[torch.Tensor, LoadSpec]]
) -> tuple[list[tuple[torch.Tensor, LoadSpec]], list[tuple[torch.Tensor, LoadSpec]]]:
if not self.config.hf_save_cfg.fp32_keys_pattern:
return [], params
ignored_params = [(p, l) for p, l in params if self._is_ignored_params(l.hf_keys[0])]
remaining = [(p, l) for p, l in params if not self._is_ignored_params(l.hf_keys[0])]
return ignored_params, remaining

# TODO: Using `xtuenr.v1.utils.misc.clean_param_name`
def _clean_param_name(self, name: str) -> str:
if "_checkpoint_wrapped_module." in name:
Expand Down Expand Up @@ -1230,7 +1353,12 @@ def _load_same_hf_param(

loaded_tensor = loaded_tensor.to(local_tensor.device)

if self.fsdp_mesh is not None and isinstance(param, nn.Parameter):
if (
self.fsdp_mesh is not None
and isinstance(param, nn.Parameter)
and isinstance(param, DTensor)
and any(isinstance(p, Shard) for p in param.placements)
):
shape_before_fsdp = load_spec.shape
_, _offset = compute_local_shape_and_global_offset(
shape_before_fsdp, self.fsdp_mesh, [Shard(self.FSDP_SHARD_DIM)]
Expand Down
4 changes: 1 addition & 3 deletions xtuner/v1/model/compose/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
CPUOffloadPolicy,
FSDPModule,
MixedPrecisionPolicy,
fully_shard,
)
from typing_extensions import override

Expand Down Expand Up @@ -108,8 +107,7 @@ def fully_shard(
# Note: 非常关键,不能删除这个 assert
assert self.fsdp_mesh is not None

fully_shard(
self,
self._fully_shard(
mesh=self.fsdp_mesh,
mp_policy=mp_policy,
reshard_after_forward=fsdp_config.reshard_after_forward,
Expand Down
3 changes: 1 addition & 2 deletions xtuner/v1/model/compose/intern_s1/modeling_intern_s1.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ def fully_shard(
# Note: 非常关键,不能删除这个 assert
assert self.fsdp_mesh is not None

fully_shard(
self,
self._fully_shard(
mesh=self.fsdp_mesh,
mp_policy=mp_policy,
reshard_after_forward=fsdp_config.reshard_after_forward,
Expand Down
3 changes: 1 addition & 2 deletions xtuner/v1/model/compose/intern_s1/modeling_projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def fully_shard(
for param in self.parameters():
param.requires_grad = False

fully_shard(
self,
self._fully_shard(
mesh=self.fsdp_mesh,
mp_policy=mp_policy,
reshard_after_forward=True,
Expand Down
Loading
Loading