Skip to content
113 changes: 113 additions & 0 deletions scripts/rl/create_debug_moe_ckpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Generate a debug HF-compatible checkpoint for the debugmodel_moe Qwen3 config.

Uses the model's own ``init_weights()`` for properly-scaled initialization
(scaled output projections, etc.), which prevents activation explosion that
naive ``randn * 0.02`` produces across deep models in bf16.

Output: /tmp/debug_moe_ckpt/ with config.json, model.safetensors, tokenizer files.
"""
import json
import os
import shutil

import torch
from safetensors.torch import save_file

from torchtitan.models.qwen3 import model_registry

# Output directory; override with DEBUG_MOE_CKPT_OUT (defaults to /tmp/debug_moe_ckpt).
OUT = os.environ.get("DEBUG_MOE_CKPT_OUT", "/tmp/debug_moe_ckpt")
os.makedirs(OUT, exist_ok=True)

# Fixed seed so the debug checkpoint is reproducible across runs.
torch.manual_seed(42)

ms = model_registry("debugmodel_moe", attn_backend="varlen")
# Build on the meta device (no real allocation), then materialize on CPU and
# run the model's own init_weights() for properly-scaled initialization.
with torch.device("meta"):
    model = ms.model.build()
model.to_empty(device="cpu")
with torch.no_grad():
    model.init_weights(buffer_device=None)

# Convert the TorchTitan state dict to HF layout, then down-cast to bf16.
sd_adapter = ms.state_dict_adapter(ms.model, OUT)
hf_state_dict = sd_adapter.to_hf(model.state_dict())
# Only cast floating-point tensors: integer buffers (e.g. counters) must keep
# their dtype. .contiguous() because safetensors requires contiguous storage.
hf_state_dict = {
    k: (v.to(torch.bfloat16) if v.is_floating_point() else v).contiguous()
    for k, v in hf_state_dict.items()
}

save_file(hf_state_dict, os.path.join(OUT, "model.safetensors"))

# Derive the HF config fields from the model spec itself so config.json
# always matches the weights we just serialized.
model_cfg = ms.model
attn_cfg = model_cfg.layers[0].attention
moe_cfg = model_cfg.layers[0].moe
num_heads = attn_cfg.n_heads
num_kv_heads = attn_cfg.n_kv_heads or num_heads
if attn_cfg.head_dim is not None:
    dim_per_head = attn_cfg.head_dim
else:
    dim_per_head = model_cfg.dim // num_heads

# HF-style Qwen3-MoE config mirroring the debug model's dimensions.
hf_config = {
    "architectures": ["Qwen3MoeForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 0,
    "decoder_sparse_step": 1,
    "eos_token_id": 1,
    "head_dim": dim_per_head,
    "hidden_act": "silu",
    "hidden_size": model_cfg.dim,
    "initializer_range": 0.02,
    "intermediate_size": moe_cfg.experts.hidden_dim,
    "max_position_embeddings": 4096,
    "max_window_layers": len(model_cfg.layers),
    "mlp_only_layers": [],
    "model_type": "qwen3_moe",
    "moe_intermediate_size": moe_cfg.experts.hidden_dim,
    "norm_topk_prob": True,
    "num_attention_heads": num_heads,
    "num_experts": moe_cfg.experts.num_experts,
    "num_experts_per_tok": moe_cfg.router.top_k,
    "num_hidden_layers": len(model_cfg.layers),
    "num_key_value_heads": num_kv_heads,
    "output_router_logits": False,
    "rms_norm_eps": 1e-6,
    "rope_scaling": None,
    "rope_theta": 1000000.0,
    "router_aux_loss_coef": 0.001,
    "sliding_window": None,
    "tie_word_embeddings": False,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0",
    "use_cache": True,
    "use_sliding_window": False,
    "vocab_size": model_cfg.vocab_size,
}
with open(os.path.join(OUT, "config.json"), "w") as cfg_file:
    cfg_file.write(json.dumps(hf_config, indent=2))

# Minimal generation_config.json so HF generation utilities pick up the
# special-token ids consistent with config.json.
gen_config = {
    "bos_token_id": 0,
    "eos_token_id": 1,
    "transformers_version": "4.45.0",
}
gen_path = os.path.join(OUT, "generation_config.json")
with open(gen_path, "w") as gen_file:
    json.dump(gen_config, gen_file, indent=2)

# Copy tokenizer artifacts from a real Qwen3 checkpoint so the debug model
# can be loaded end-to-end. Override the source directory with
# QWEN3_TOKENIZER_SRC; the default is a user-specific path that may not
# exist on other machines, in which case we warn instead of failing silently.
src = os.environ.get(
    "QWEN3_TOKENIZER_SRC", "/data/users/jianiw/model/Qwen3-30B-A3B"
)
copied = 0
for fname in (
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "merges.txt",
    "vocab.json",
):
    sp = os.path.join(src, fname)
    if os.path.exists(sp):
        shutil.copy(sp, OUT)
        copied += 1
if copied == 0:
    print(
        f"WARNING: no tokenizer files found under {src}; "
        "set QWEN3_TOKENIZER_SRC to a Qwen3 checkpoint directory."
    )

# Report the parameter count and emitted files for a quick sanity check.
param_count = 0
for tensor in hf_state_dict.values():
    param_count += tensor.numel()
print(f"Created debug MoE checkpoint at {OUT}")
print(f" Total params: {param_count / 1e6:.1f}M")
print(f" Files: {sorted(os.listdir(OUT))}")
4 changes: 4 additions & 0 deletions torchtitan/distributed/parallel_dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def _mesh_exist(self, name: str, degree: int) -> bool:
# Always keep fsdp mesh with real backend so fully_shard()
# can apply MixedPrecisionPolicy even at degree 1.
return True
if name == "dp_replicate":
# Always keep dp_replicate mesh so replicate() / fully_shard()
# can use it for data parallel replication even at degree 1.
return True
if name == "efsdp":
# We always keep the efsdp if EP is larger than 1 because we need
# FSDP wrapping to help the MoE layers do mixed precision training.
Expand Down
3 changes: 3 additions & 0 deletions torchtitan/experiments/rl/actors/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def _load_initial_hf_weights(self, model, checkpoint_path: str) -> None:
dcp.load(hf_state_dict, storage_reader=storage_reader)
torchtitan_state_dict = self.sd_adapter.from_hf(hf_state_dict)

# strict=False: some buffers (e.g. expert_bias for MoE load
# balancing) exist in the model but not in HF checkpoints.
# They are zero-initialized by init_states().
set_model_state_dict(
model=model,
model_state_dict=torchtitan_state_dict,
Expand Down
4 changes: 3 additions & 1 deletion torchtitan/experiments/rl/actors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def compute_logprobs(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Ten
# code (gather with plain-tensor indices, slicing per-sample) expects a
# plain tensor — materialize once here.
if isinstance(logits, DTensor):
logits = logits.to_local()
# logits has Shard(-1) on the TP mesh (vocab dim sharded by TP).
# full_tensor() all-gathers to get the full vocab dimension.
logits = logits.full_tensor()
shift_logits = logits[:, :-1, :].float()
shift_targets = token_ids[:, 1:]
logprobs = F.log_softmax(shift_logits, dim=-1)
Expand Down
14 changes: 9 additions & 5 deletions torchtitan/experiments/rl/config_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ def rl_grpo_qwen3_moe_debug_ep() -> RLTrainer.Config:
Generator uses TP=2 for dense layers and EP=2 for MoE experts.
The RL loop auto-rebuilds the model spec with AllToAllTokenDispatcher
when generator EP > 1.

Generate the debug checkpoint with:
python scripts/rl/create_debug_moe_ckpt.py
"""
return RLTrainer.Config(
model_spec=model_registry("debugmodel_moe", attn_backend="varlen"),
Expand All @@ -224,10 +227,11 @@ def rl_grpo_qwen3_moe_debug_ep() -> RLTrainer.Config:
),
training=TrainingConfig(),
parallelism=ParallelismConfig(
tensor_parallel_degree=1,
tensor_parallel_degree=2,
data_parallel_replicate_degree=1,
expert_parallel_degree=2,
),
compile=CompileConfig(enable=True, backend="aot_eager"),
compile=CompileConfig(enable=False),
loss=GRPOLoss.Config(),
),
generator=VLLMGenerator.Config(
Expand Down Expand Up @@ -334,7 +338,7 @@ def rl_grpo_qwen3_30b_a3b() -> RLTrainer.Config:
),
training=TrainingConfig(dtype="bfloat16"),
parallelism=ParallelismConfig(
tensor_parallel_degree=4,
tensor_parallel_degree=8,
disable_loss_parallel=True,
),
compile=CompileConfig(enable=True, backend="aot_eager"),
Expand All @@ -347,9 +351,9 @@ def rl_grpo_qwen3_30b_a3b() -> RLTrainer.Config:
cudagraph_mode="piecewise",
),
parallelism=ParallelismConfig(
tensor_parallel_degree=4,
tensor_parallel_degree=8,
data_parallel_replicate_degree=1,
expert_parallel_degree=4,
expert_parallel_degree=8,
),
sampling=SamplingConfig(
n=8,
Expand Down
27 changes: 15 additions & 12 deletions torchtitan/experiments/rl/models/vllm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,20 @@ def create_torchtitan_config_from_vllm_config(
world_size = dist.get_world_size()
parallel_config = vllm_config.parallel_config

# When EP is enabled, all TP ranks are repurposed for expert parallelism
# (each rank holds a shard of experts): ep_size = tp_size.
tp_size = parallel_config.tensor_parallel_size
ep_size = tp_size if parallel_config.enable_expert_parallel else 1
dp_size = parallel_config.data_parallel_size

# When EP is enabled, it spans all ranks within a DP*TP group:
# each rank holds a shard of experts.
ep_size = dp_size * tp_size if parallel_config.enable_expert_parallel else 1

# Map vLLM DP to dp_shard for two purposes (no actual FSDP wrapping):
# 1. Mesh math: efsdp = dp_shard * tp / ep — makes EP dims work
# 2. Data routing: vLLM routes different requests to different DP groups
# FSDP is skipped for inference (skip_dp=True) — no backward, no grad sync.
parallel_dims = ParallelDims(
dp_replicate=parallel_config.data_parallel_size,
dp_shard=1,
dp_replicate=1,
dp_shard=dp_size,
cp=1,
tp=tp_size,
pp=parallel_config.pipeline_parallel_size,
Expand All @@ -127,8 +133,8 @@ def create_torchtitan_config_from_vllm_config(
)

parallelism = ParallelismConfig(
data_parallel_replicate_degree=parallel_config.data_parallel_size,
data_parallel_shard_degree=1,
data_parallel_replicate_degree=1,
data_parallel_shard_degree=dp_size,
context_parallel_degree=1,
tensor_parallel_degree=tp_size,
pipeline_parallel_degree=parallel_config.pipeline_parallel_size,
Expand All @@ -137,13 +143,11 @@ def create_torchtitan_config_from_vllm_config(
enable_sequence_parallel=False,
)

# Build the full device mesh so all dimensions (tp, ep, efsdp, etc.)
# are available to the core parallelize function.
parallel_dims.build_mesh()

logger.info(
f"Created TorchTitan config from vLLM: "
f"DP={parallel_dims.dp_replicate}, TP={parallel_dims.tp}, "
f"DP={dp_size}, TP={parallel_dims.tp}, "
f"CP={parallel_dims.cp}, PP={parallel_dims.pp}, "
f"EP={parallel_dims.ep}"
)
Expand Down Expand Up @@ -308,7 +312,6 @@ def __init__(

# Initial load model weights from HuggingFace checkpoint path.
import os as _os

if _os.environ.get("TORCHTITAN_SKIP_INITIAL_HF_LOAD") != "1":
self._initial_load_weights(checkpoint_path=vllm_config.model_config.model)
else:
Expand Down Expand Up @@ -518,7 +521,7 @@ def _initial_load_weights(self, checkpoint_path):
torchtitan_state_dict[name] = DTensor.from_local(
tensor.to(device_mesh.device_type),
device_mesh=device_mesh,
placements=[Replicate()],
placements=[Replicate()] * device_mesh.ndim,
)

return self.load_weights_from_state_dict(torchtitan_state_dict)
Expand Down
10 changes: 7 additions & 3 deletions torchtitan/models/llama4/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,12 @@ def _experts_shard_placement_fn(
)

assert edp_mesh is not None
edp_mesh_info = FSDPMeshInfo(mesh=edp_mesh, shard_mesh_dim=0)
dp_mesh_info = FSDPMeshInfo(mesh=dp_mesh, shard_mesh_dim=0)
edp_mesh_info = FSDPMeshInfo(
mesh=edp_mesh, shard_mesh_dim=0
) # edp mesh is also 2D, [replicated, efsdp]
dp_mesh_info = FSDPMeshInfo(
mesh=dp_mesh, shard_mesh_dim=0
) # dp mesh is 2D, [replicate, shard]

def _shard_placement_fn(
param: nn.Parameter,
Expand All @@ -279,7 +283,7 @@ def _shard_placement_fn(
return ShardPlacementResult(
placement=_expert_placement, mesh_info=_edp_mesh_info
)
else:
else: # moe.router / moe.shared_experts, apply dense dp shard
return ShardPlacementResult(
placement=Shard(0), mesh_info=_dp_mesh_info
)
Expand Down
4 changes: 1 addition & 3 deletions torchtitan/models/qwen3/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,8 @@ def parallelize_qwen3(
if model_compile_enabled:
apply_compile(model, compile_config)

# Skip FSDP wrapper for inference. FSDP's forward hooks
# are incompatible with torch.inference_mode() used by vLLM.
# AC and compile are disabled via config (mode="none", enable=False).
if skip_dp:
# Inference path: FSDP / DDP wrapping is not needed for inference.
return model

dp_mesh_names = (
Expand Down
Loading