diff --git a/scripts/rl/create_debug_moe_ckpt.py b/scripts/rl/create_debug_moe_ckpt.py
new file mode 100644
index 0000000000..6074dda7a1
--- /dev/null
+++ b/scripts/rl/create_debug_moe_ckpt.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+"""Generate a debug HF-compatible checkpoint for the debugmodel_moe Qwen3 config.
+
+Uses the model's own ``init_weights()`` for properly-scaled initialization
+(scaled output projections, etc.), which prevents activation explosion that
+naive ``randn * 0.02`` produces across deep models in bf16.
+
+Output: /tmp/debug_moe_ckpt/ with config.json, model.safetensors, tokenizer files.
+"""
+import json
+import os
+import shutil
+
+import torch
+from safetensors.torch import save_file
+
+from torchtitan.models.qwen3 import model_registry
+
+OUT = "/tmp/debug_moe_ckpt"
+os.makedirs(OUT, exist_ok=True)
+
+torch.manual_seed(42)
+
+ms = model_registry("debugmodel_moe", attn_backend="varlen")
+with torch.device("meta"):
+    model = ms.model.build()
+model.to_empty(device="cpu")
+with torch.no_grad():
+    model.init_weights(buffer_device=None)
+
+sd_adapter = ms.state_dict_adapter(ms.model, OUT)
+hf_state_dict = sd_adapter.to_hf(model.state_dict())
+hf_state_dict = {k: v.to(torch.bfloat16) for k, v in hf_state_dict.items()}
+
+save_file(hf_state_dict, os.path.join(OUT, "model.safetensors"))
+
+mc = ms.model
+attn = mc.layers[0].attention
+moe = mc.layers[0].moe
+n_heads = attn.n_heads
+n_kv_heads = attn.n_kv_heads or n_heads
+head_dim = attn.head_dim if attn.head_dim is not None else mc.dim // n_heads
+
+config = {
+    "architectures": ["Qwen3MoeForCausalLM"],
+    "attention_bias": False,
+    "attention_dropout": 0.0,
+    "bos_token_id": 0,
+    "decoder_sparse_step": 1,
+    "eos_token_id": 1,
+    "head_dim": head_dim,
+    "hidden_act": "silu",
+    "hidden_size": mc.dim,
+    "initializer_range": 0.02,
+    "intermediate_size": moe.experts.hidden_dim,
+    "max_position_embeddings": 4096,
+    "max_window_layers": len(mc.layers),
+    "mlp_only_layers": [],
+    "model_type": "qwen3_moe",
+    "moe_intermediate_size": moe.experts.hidden_dim,
+    "norm_topk_prob": True,
+    "num_attention_heads": n_heads,
+    "num_experts": moe.experts.num_experts,
+    "num_experts_per_tok": moe.router.top_k,
+    "num_hidden_layers": len(mc.layers),
+    "num_key_value_heads": n_kv_heads,
+    "output_router_logits": False,
+    "rms_norm_eps": 1e-6,
+    "rope_scaling": None,
+    "rope_theta": 1000000.0,
+    "router_aux_loss_coef": 0.001,
+    "sliding_window": None,
+    "tie_word_embeddings": False,
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.45.0",
+    "use_cache": True,
+    "use_sliding_window": False,
+    "vocab_size": mc.vocab_size,
+}
+with open(os.path.join(OUT, "config.json"), "w") as f:
+    json.dump(config, f, indent=2)
+
+with open(os.path.join(OUT, "generation_config.json"), "w") as f:
+    json.dump(
+        {
+            "bos_token_id": 0,
+            "eos_token_id": 1,
+            "transformers_version": "4.45.0",
+        },
+        f,
+        indent=2,
+    )
+
+src = "/data/users/jianiw/model/Qwen3-30B-A3B"
+for f in (
+    "tokenizer.json",
+    "tokenizer_config.json",
+    "special_tokens_map.json",
+    "merges.txt",
+    "vocab.json",
+):
+    sp = os.path.join(src, f)
+    if os.path.exists(sp):
+        shutil.copy(sp, OUT)
+
+total = sum(t.numel() for t in hf_state_dict.values())
+print(f"Created debug MoE checkpoint at {OUT}")
+print(f" Total params: {total / 1e6:.1f}M")
+print(f" Files: {sorted(os.listdir(OUT))}")
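
A quick sanity check for the checkpoint this script writes (an illustrative sketch, not part of the patch; it assumes a transformers release with Qwen3-MoE support and that the tokenizer files were found and copied):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    ckpt = "/tmp/debug_moe_ckpt"
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16)

    inputs = tokenizer("hello world", return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs)
    # init_weights() scaling should keep the bf16 forward pass finite.
    assert torch.isfinite(out.logits.float()).all()
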
diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py
index 149f418346..4c84786340
--- a/torchtitan/distributed/parallel_dims.py
+++ b/torchtitan/distributed/parallel_dims.py
@@ -75,6 +75,10 @@ def _mesh_exist(self, name: str, degree: int) -> bool:
             # Always keep fsdp mesh with real backend so fully_shard()
             # can apply MixedPrecisionPolicy even at degree 1.
             return True
+        if name == "dp_replicate":
+            # Always keep dp_replicate mesh so replicate() / fully_shard()
+            # can use it for data parallel replication even at degree 1.
+            return True
         if name == "efsdp":
             # We always keep the efsdp if EP is larger than 1 because we need
             # FSDP wrapping to help the MoE layers do mixed precision training.
diff --git a/torchtitan/experiments/rl/actors/trainer.py b/torchtitan/experiments/rl/actors/trainer.py
index 0367109694..824ad997ce
--- a/torchtitan/experiments/rl/actors/trainer.py
+++ b/torchtitan/experiments/rl/actors/trainer.py
@@ -185,6 +185,9 @@ def _load_initial_hf_weights(self, model, checkpoint_path: str) -> None:
         dcp.load(hf_state_dict, storage_reader=storage_reader)
         torchtitan_state_dict = self.sd_adapter.from_hf(hf_state_dict)
 
+        # strict=False: some buffers (e.g. expert_bias for MoE load
+        # balancing) exist in the model but not in HF checkpoints.
+        # They are zero-initialized by init_states().
         set_model_state_dict(
             model=model,
             model_state_dict=torchtitan_state_dict,
diff --git a/torchtitan/experiments/rl/actors/utils.py b/torchtitan/experiments/rl/actors/utils.py
index 002b495017..eb445cc4f8
--- a/torchtitan/experiments/rl/actors/utils.py
+++ b/torchtitan/experiments/rl/actors/utils.py
@@ -41,7 +41,9 @@ def compute_logprobs(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Ten
     # code (gather with plain-tensor indices, slicing per-sample) expects a
     # plain tensor — materialize once here.
     if isinstance(logits, DTensor):
-        logits = logits.to_local()
+        # logits has Shard(-1) on the TP mesh (vocab dim sharded by TP).
+        # full_tensor() all-gathers to get the full vocab dimension.
+        logits = logits.full_tensor()
     shift_logits = logits[:, :-1, :].float()
     shift_targets = token_ids[:, 1:]
     logprobs = F.log_softmax(shift_logits, dim=-1)
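
For context on the compute_logprobs change above, this is a plain-tensor reference of the shift-and-gather math that runs after the full_tensor() materialization (an illustrative sketch with a hypothetical helper name; the final gather step is inferred from the surrounding comment):

    import torch
    import torch.nn.functional as F

    def compute_logprobs_reference(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
        # logits: [batch, seq, vocab] as a plain, fully materialized tensor
        # token_ids: [batch, seq]
        shift_logits = logits[:, :-1, :].float()  # position t predicts token t+1
        shift_targets = token_ids[:, 1:]          # drop the first token, which has no prediction
        logprobs = F.log_softmax(shift_logits, dim=-1)
        # pick out the log-prob assigned to the token that was actually generated
        return torch.gather(logprobs, -1, shift_targets.unsqueeze(-1)).squeeze(-1)
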
diff --git a/torchtitan/experiments/rl/config_registry.py b/torchtitan/experiments/rl/config_registry.py
index 4996110a71..fa2a93469a 100644
--- a/torchtitan/experiments/rl/config_registry.py
+++ b/torchtitan/experiments/rl/config_registry.py
@@ -207,6 +207,9 @@ def rl_grpo_qwen3_moe_debug_ep() -> RLTrainer.Config:
     Generator uses TP=2 for dense layers and EP=2 for MoE experts. The RL loop
     auto-rebuilds the model spec with AllToAllTokenDispatcher when generator
     EP > 1.
+
+    Generate the debug checkpoint with:
+        python scripts/rl/create_debug_moe_ckpt.py
     """
     return RLTrainer.Config(
         model_spec=model_registry("debugmodel_moe", attn_backend="varlen"),
@@ -224,10 +227,11 @@ def rl_grpo_qwen3_moe_debug_ep() -> RLTrainer.Config:
             ),
             training=TrainingConfig(),
             parallelism=ParallelismConfig(
-                tensor_parallel_degree=1,
+                tensor_parallel_degree=2,
                 data_parallel_replicate_degree=1,
+                expert_parallel_degree=2,
             ),
-            compile=CompileConfig(enable=True, backend="aot_eager"),
+            compile=CompileConfig(enable=False),
             loss=GRPOLoss.Config(),
         ),
         generator=VLLMGenerator.Config(
@@ -334,7 +338,7 @@ def rl_grpo_qwen3_30b_a3b() -> RLTrainer.Config:
             ),
            training=TrainingConfig(dtype="bfloat16"),
             parallelism=ParallelismConfig(
-                tensor_parallel_degree=4,
+                tensor_parallel_degree=8,
                 disable_loss_parallel=True,
             ),
             compile=CompileConfig(enable=True, backend="aot_eager"),
@@ -347,9 +351,9 @@
                 cudagraph_mode="piecewise",
             ),
             parallelism=ParallelismConfig(
-                tensor_parallel_degree=4,
+                tensor_parallel_degree=8,
                 data_parallel_replicate_degree=1,
-                expert_parallel_degree=4,
+                expert_parallel_degree=8,
             ),
             sampling=SamplingConfig(
                 n=8,
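
A worked example of the vLLM-to-TorchTitan mapping implemented in the vllm_wrapper change below, using the generator settings from rl_grpo_qwen3_30b_a3b above (TP=8, EP enabled, one DP group); the arithmetic is illustrative only:

    # vLLM parallel_config values for the 30B generator above
    tp_size = 8
    dp_size = 1
    enable_expert_parallel = True

    # EP spans all ranks within a DP*TP group
    ep_size = dp_size * tp_size if enable_expert_parallel else 1  # 8

    # vLLM DP is mapped to dp_shard; no FSDP wrapping actually happens at inference
    dp_shard = dp_size                                            # 1

    # efsdp = dp_shard * tp / ep, per the mesh-math comment in the wrapper
    efsdp = dp_shard * tp_size // ep_size                         # 1
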
diff --git a/torchtitan/experiments/rl/models/vllm_wrapper.py b/torchtitan/experiments/rl/models/vllm_wrapper.py
index 9c0e3f3c62..86c7d51718 100644
--- a/torchtitan/experiments/rl/models/vllm_wrapper.py
+++ b/torchtitan/experiments/rl/models/vllm_wrapper.py
@@ -111,14 +111,20 @@ def create_torchtitan_config_from_vllm_config(
     world_size = dist.get_world_size()
     parallel_config = vllm_config.parallel_config
 
-    # When EP is enabled, all TP ranks are repurposed for expert parallelism
-    # (each rank holds a shard of experts): ep_size = tp_size.
     tp_size = parallel_config.tensor_parallel_size
-    ep_size = tp_size if parallel_config.enable_expert_parallel else 1
+    dp_size = parallel_config.data_parallel_size
+    # When EP is enabled, it spans all ranks within a DP*TP group:
+    # each rank holds a shard of experts.
+    ep_size = dp_size * tp_size if parallel_config.enable_expert_parallel else 1
+
+    # Map vLLM DP to dp_shard for two purposes (no actual FSDP wrapping):
+    # 1. Mesh math: efsdp = dp_shard * tp / ep — makes EP dims work
+    # 2. Data routing: vLLM routes different requests to different DP groups
+    # FSDP is skipped for inference (skip_dp=True) — no backward, no grad sync.
 
     parallel_dims = ParallelDims(
-        dp_replicate=parallel_config.data_parallel_size,
-        dp_shard=1,
+        dp_replicate=1,
+        dp_shard=dp_size,
         cp=1,
         tp=tp_size,
         pp=parallel_config.pipeline_parallel_size,
@@ -127,8 +133,8 @@ def create_torchtitan_config_from_vllm_config(
     )
 
     parallelism = ParallelismConfig(
-        data_parallel_replicate_degree=parallel_config.data_parallel_size,
-        data_parallel_shard_degree=1,
+        data_parallel_replicate_degree=1,
+        data_parallel_shard_degree=dp_size,
         context_parallel_degree=1,
         tensor_parallel_degree=tp_size,
         pipeline_parallel_degree=parallel_config.pipeline_parallel_size,
@@ -137,13 +143,11 @@
         enable_sequence_parallel=False,
     )
 
-    # Build the full device mesh so all dimensions (tp, ep, efsdp, etc.)
-    # are available to the core parallelize function.
     parallel_dims.build_mesh()
 
     logger.info(
         f"Created TorchTitan config from vLLM: "
-        f"DP={parallel_dims.dp_replicate}, TP={parallel_dims.tp}, "
+        f"DP={dp_size}, TP={parallel_dims.tp}, "
         f"CP={parallel_dims.cp}, PP={parallel_dims.pp}, "
        f"EP={parallel_dims.ep}"
     )
@@ -308,7 +312,6 @@ def __init__(
 
         # Initial load model weights from HuggingFace checkpoint path.
         import os as _os
-
        if _os.environ.get("TORCHTITAN_SKIP_INITIAL_HF_LOAD") != "1":
             self._initial_load_weights(checkpoint_path=vllm_config.model_config.model)
         else:
@@ -518,7 +521,7 @@ def _initial_load_weights(self, checkpoint_path):
                 torchtitan_state_dict[name] = DTensor.from_local(
                     tensor.to(device_mesh.device_type),
                     device_mesh=device_mesh,
-                    placements=[Replicate()],
+                    placements=[Replicate()] * device_mesh.ndim,
                 )
 
         return self.load_weights_from_state_dict(torchtitan_state_dict)
diff --git a/torchtitan/models/llama4/parallelize.py b/torchtitan/models/llama4/parallelize.py
index f23a21a50e..464da4838f 100644
--- a/torchtitan/models/llama4/parallelize.py
+++ b/torchtitan/models/llama4/parallelize.py
@@ -265,8 +265,12 @@ def _experts_shard_placement_fn(
     )
     assert edp_mesh is not None
-    edp_mesh_info = FSDPMeshInfo(mesh=edp_mesh, shard_mesh_dim=0)
-    dp_mesh_info = FSDPMeshInfo(mesh=dp_mesh, shard_mesh_dim=0)
+    edp_mesh_info = FSDPMeshInfo(
+        mesh=edp_mesh, shard_mesh_dim=0
+    )  # edp mesh is also 2D, [replicate, efsdp]
+    dp_mesh_info = FSDPMeshInfo(
+        mesh=dp_mesh, shard_mesh_dim=0
+    )  # dp mesh is 2D, [replicate, shard]
 
     def _shard_placement_fn(
         param: nn.Parameter,
@@ -279,7 +283,7 @@ def _shard_placement_fn(
             return ShardPlacementResult(
                 placement=_expert_placement, mesh_info=_edp_mesh_info
             )
-        else:
+        else:  # moe.router / moe.shared_experts, apply dense dp shard
             return ShardPlacementResult(
                 placement=Shard(0), mesh_info=_dp_mesh_info
             )
diff --git a/torchtitan/models/qwen3/parallelize.py b/torchtitan/models/qwen3/parallelize.py
index ab89c5b358..368485bad4 100644
--- a/torchtitan/models/qwen3/parallelize.py
+++ b/torchtitan/models/qwen3/parallelize.py
@@ -79,10 +79,8 @@ def parallelize_qwen3(
     if model_compile_enabled:
         apply_compile(model, compile_config)
 
-    # Skip FSDP wrapper for inference. FSDP's forward hooks
-    # are incompatible with torch.inference_mode() used by vLLM.
-    # AC and compile are disabled via config (mode="none", enable=False).
     if skip_dp:
+        # Inference path: no FSDP / DDP wrapping is needed for inference.
         return model
 
     dp_mesh_names = (
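
A condensed sketch of the placement routing that the llama4 hunk above annotates (hypothetical names, shown only to illustrate the branch): expert parameters are sharded with the expert placement on the 2D [replicate, efsdp] edp mesh, while moe.router and moe.shared_experts parameters fall back to ordinary dim-0 sharding on the 2D [replicate, shard] dp mesh:

    from torch.distributed.tensor import Shard

    def pick_placement(param_fqn: str, expert_placement, edp_mesh_info, dp_mesh_info):
        if ".experts." in param_fqn:
            # MoE expert weights: use the expert placement on the edp mesh
            return expert_placement, edp_mesh_info
        # moe.router / moe.shared_experts are dense: plain FSDP Shard(0) on the dp mesh
        return Shard(0), dp_mesh_info
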