diff --git a/.ci/docker/requirements-transformers-backend.txt b/.ci/docker/requirements-transformers-backend.txt
index 6b0cc637db..76e8886ed0 100644
--- a/.ci/docker/requirements-transformers-backend.txt
+++ b/.ci/docker/requirements-transformers-backend.txt
@@ -1 +1 @@
-transformers==4.55.4
+transformers==4.57.1
diff --git a/torchtitan/experiments/transformers_backend/README.md b/torchtitan/experiments/transformers_backend/README.md
index a5b4059c07..3d1a2dcf0d 100644
--- a/torchtitan/experiments/transformers_backend/README.md
+++ b/torchtitan/experiments/transformers_backend/README.md
@@ -2,22 +2,21 @@
 
 ## Quick start
 
-- Requirements `transformers==4.55.4`
+- Requirements `transformers==4.57.1`
 - Config: `torchtitan/torchtitan/experiments/transformers_backend/configs/qwen3_fsdp2_tp2_pp2.toml`
 
 ```diff
 ...
 [model]
 - name = "llama3"
-+ name = "transformers_backend"
++ name = "Qwen/Qwen3-4B-Instruct-2507"
 flavor = "debugmodel"
 hf_assets_path = "./tests/assets/tokenizer"
-
-+[hf_transformers]
-+model = "Qwen/Qwen3-4B-Instruct-2507"
 ...
 ```
 
-- Train: `LOG_RANK=7 CONFIG_FILE=/torchtitan/experiments/transformers_backend/configs/qwen3_fsdp2_tp2_pp2.toml ./run_train.sh --compile.enable`
 - Make sure you have created the tokenizers beforehand
 
 image
diff --git a/torchtitan/experiments/transformers_backend/__init__.py b/torchtitan/experiments/transformers_backend/__init__.py
index fd0cd9b689..5ebf531043 100644
--- a/torchtitan/experiments/transformers_backend/__init__.py
+++ b/torchtitan/experiments/transformers_backend/__init__.py
@@ -5,11 +5,22 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
+import torch.nn as nn
+
+from torchtitan.components.ft import FTManager
+from torchtitan.models.moe import MoEArgs
 from torchtitan.components.loss import build_cross_entropy_loss
 from torchtitan.components.lr_scheduler import build_lr_schedulers
-from torchtitan.components.optimizer import build_optimizers
+from torchtitan.components.optimizer import (
+    build_optimizers,
+    build_optimizers_with_moe_load_balancing,
+    OptimizersContainer,
+)
 from torchtitan.components.tokenizer import build_hf_tokenizer
+from torchtitan.config import Optimizer as OptimizerConfig
+from torchtitan.distributed import ParallelDims
 from torchtitan.hf_datasets.text_datasets import build_text_dataloader
+from torchtitan.models.moe import MoEArgs
 from torchtitan.protocols.train_spec import TrainSpec
 
 from .infra.parallelize import parallelize_hf_transformers
@@ -24,7 +35,6 @@
     "HFTransformerModel",
 ]
 
-
 @dataclass
 class TitanDenseModelArgs:
     """Arguments for the base TorchTitan model."""
@@ -43,6 +53,30 @@ class TitanDenseModelArgs:
     use_flex_attn: bool = False
     attn_mask_type: str = "causal"
 
+@dataclass
+class TitanMoeModelArgs:
+    """Arguments for TorchTitan MoE models (DeepSeekV3-style)."""
+
+    moe_args: MoEArgs | None = None
+    n_group: int | None = None
+    topk_group: int | None = None
+    inter_dim: int | None = None
+    moe_inter_dim: int | None = None
+    n_dense_layers: int | None = None
+    n_expert_groups: int | None = None
+    n_limited_groups: int | None = None
+    q_lora_rank: int | None = None
+    kv_lora_rank: int | None = None
+    qk_nope_head_dim: int | None = None
+    qk_rope_head_dim: int | None = None
+    v_head_dim: int | None = None
+    original_seq_len: int | None = None
+    rope_factor: float | None = None
+    beta_fast: int | None = None
+    beta_slow: int | None = None
+    mscale: float | None = None
+    partial_rotary_factor: float | None = None
+    rope_interleave: bool = True
 
 flavors = {
     "debugmodel": HFTransformerModelArgs(
@@ -53,11 +87,73 @@ class TitanDenseModelArgs:
             n_kv_heads=16,
         ),
     ),
+    "debugmodel_moe": HFTransformerModelArgs(
+        titan_dense_args=TitanDenseModelArgs(
+            dim=256,
+            n_layers=3,
+            n_heads=16,
+            n_kv_heads=16,
+        ),
+        titan_moe_args=TitanMoeModelArgs(
+            partial_rotary_factor=4.0,
+            inter_dim=1024,
+            moe_inter_dim=256,
+            n_dense_layers=1,
+            n_group=2,
+            topk_group=1,
+            kv_lora_rank=512,
+            q_lora_rank=0,
+            qk_nope_head_dim=128,
+            qk_rope_head_dim=64,
+            v_head_dim=128,
+            mscale=0.70,
+            moe_args=MoEArgs(
+                num_experts=8,
+                num_shared_experts=2,
+                top_k=3,
+                score_func="softmax",
+                route_norm=True,
+                score_before_experts=False,
+                load_balance_coeff=1e-3,
+            ),
+        ),
+    ),
     "full": HFTransformerModelArgs(
         titan_dense_args=TitanDenseModelArgs(),
    ),
 }
 
+def build_optimizers_auto_detect_moe(
+    model_parts: list[nn.Module],
+    optimizer_config: OptimizerConfig,
+    parallel_dims: ParallelDims,
+    ft_manager: FTManager | None = None,
+) -> OptimizersContainer:
+
+    # Check if any model part has MoE enabled
+    has_moe = False
+    for model_part in model_parts:
+        if hasattr(model_part, "layers"):
+            for layer in model_part.layers:
+                if hasattr(layer, "moe_enabled") and layer.moe_enabled:
+                    has_moe = True
+                    break
+        if has_moe:
+            break
+
+    if has_moe:
+        # NOTE(3outeille): Temporary monkey-patch for compatibility; otherwise optimizer.py would have to be copied just to iterate over `layers` instead of `layers.values()`.
+        for model_part in model_parts:
+            if hasattr(model_part, "layers") and not hasattr(model_part.layers, "values"):
+                model_part.layers.values = lambda self=model_part.layers: iter(self)
+
+    return_val = (build_optimizers_with_moe_load_balancing if has_moe else build_optimizers)(
+        model_parts=model_parts,
+        optimizer_config=optimizer_config,
+        parallel_dims=parallel_dims,
+        ft_manager=ft_manager,
+    )
+    return return_val
 
 def get_train_spec() -> TrainSpec:
     return TrainSpec(
@@ -65,7 +161,7 @@ def get_train_spec() -> TrainSpec:
         model_args=flavors,
         parallelize_fn=parallelize_hf_transformers,
         pipelining_fn=pipeline_hf_transformers,
-        build_optimizers_fn=build_optimizers,
+        build_optimizers_fn=build_optimizers_auto_detect_moe,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_text_dataloader,
         build_tokenizer_fn=build_hf_tokenizer,
diff --git a/torchtitan/experiments/transformers_backend/configs/qwen3_fsdp2_tp2_pp2.toml b/torchtitan/experiments/transformers_backend/configs/qwen3_fsdp2_tp2_pp2.toml
index 13e3f4ddf0..b0e294ccbe 100644
--- a/torchtitan/experiments/transformers_backend/configs/qwen3_fsdp2_tp2_pp2.toml
+++ b/torchtitan/experiments/transformers_backend/configs/qwen3_fsdp2_tp2_pp2.toml
@@ -20,15 +20,12 @@ save_tb_folder = "tb"
 enable_wandb = false
 
 [model]
-name = "transformers_backend"
+name = "Qwen/Qwen3-4B-Instruct-2507"
 flavor = "debugmodel"
 # test folder with tokenizer.json, for debug purpose only
 hf_assets_path = "./tests/assets/tokenizer"
 # converters = ["float8"]
 
-[hf_transformers]
-model = "Qwen/Qwen3-4B-Instruct-2507"
-
 [optimizer]
 name = "AdamW"
 lr = 8e-4
diff --git a/torchtitan/experiments/transformers_backend/infra/parallelize.py b/torchtitan/experiments/transformers_backend/infra/parallelize.py
index b2ae3f02a1..a4b7e66ad8 100644
--- a/torchtitan/experiments/transformers_backend/infra/parallelize.py
+++ b/torchtitan/experiments/transformers_backend/infra/parallelize.py
@@ -17,12 +17,12 @@
     SequenceParallel,
 )
 from torchtitan.config import TORCH_DTYPE_MAP
+from torchtitan.config.job_config import JobConfig
 from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
-from torchtitan.experiments.transformers_backend.job_config import JobConfig
 from torchtitan.models.llama3.infra.parallelize import apply_compile, apply_ddp
 from torchtitan.tools.logging import logger
 
diff --git a/torchtitan/experiments/transformers_backend/infra/pipeline.py b/torchtitan/experiments/transformers_backend/infra/pipeline.py
index 6a891bb271..b813225fe6 100644
--- a/torchtitan/experiments/transformers_backend/infra/pipeline.py
+++ b/torchtitan/experiments/transformers_backend/infra/pipeline.py
@@ -14,13 +14,13 @@
 )
 
 from torchtitan.components.loss import LossFunction
+from torchtitan.config.job_config import JobConfig
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.pipeline_parallel import (
     build_pipeline_schedule,
     generate_llm_fqn_per_model_part,
     pipeline_module_split,
 )
-from torchtitan.experiments.transformers_backend.job_config import JobConfig
 from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
 from torchtitan.tools.logging import logger
 
diff --git a/torchtitan/experiments/transformers_backend/job_config.py b/torchtitan/experiments/transformers_backend/job_config.py
deleted file mode 100644
index f3b1667798..0000000000
--- a/torchtitan/experiments/transformers_backend/job_config.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from dataclasses import dataclass, field
-
-
-@dataclass
-class HFTransformers:
-    model: str = ""
-    """HuggingFace model ID (e.g., 'Qwen/Qwen3-4B-Instruct-2507')"""
-
-
-@dataclass
-class JobConfig:
-    hf_transformers: HFTransformers = field(default_factory=HFTransformers)
diff --git a/torchtitan/experiments/transformers_backend/model/args.py b/torchtitan/experiments/transformers_backend/model/args.py
index db1696e7a5..899a333d6d 100644
--- a/torchtitan/experiments/transformers_backend/model/args.py
+++ b/torchtitan/experiments/transformers_backend/model/args.py
@@ -7,8 +7,11 @@
 from dataclasses import dataclass
 
 from torch import nn
-from torchtitan.config.job_config import JobConfig
-from torchtitan.models.utils import get_dense_model_nparams_and_flops
+from torchtitan.config import JobConfig
+from torchtitan.models.utils import (
+    get_dense_model_nparams_and_flops,
+    get_moe_model_nparams_and_flops,
+)
 from torchtitan.protocols import BaseModelArgs
 from transformers import AutoConfig
 from transformers.configuration_utils import PretrainedConfig
@@ -36,7 +39,12 @@ class HFTransformerModelArgs(PretrainedConfig, BaseModelArgs):
             "norm_eps": "rms_norm_eps",
             "max_seq_len": "max_position_embeddings",
             "eos_id": "eos_token_id",
-        }
+        },
+        "moe": {
+            # TorchTitan MoE-specific mappings (only used when titan_moe_args is provided)
+            "inter_dim": "intermediate_size",
+            "n_dense_layers": "first_k_dense_replace",
+        },
     }
 
     # Declarative list of TorchTitan-only attributes (no HF equivalent)
@@ -48,9 +56,23 @@ class HFTransformerModelArgs(PretrainedConfig, BaseModelArgs):
         "attn_mask_type",
     ]
 
+    # MoE attributes that should be copied directly
+    _MOE_SHARED_ATTRIBUTES = [
+        "rope_interleave",
+        "partial_rotary_factor",
+        "n_group",
+        "topk_group",
+        "kv_lora_rank",
+        "q_lora_rank",
+        "qk_nope_head_dim",
+        "qk_rope_head_dim",
+        "v_head_dim",
+    ]
+
     def __init__(
         self,
         titan_dense_args,
+        titan_moe_args=None,
         # HuggingFace specific args
         attn_implementation: str = "sdpa_torchtitan",
         **kwargs,
@@ -59,7 +81,7 @@ def __init__(
         assert titan_dense_args is not None, "titan_dense_args is required"
 
         # Create getter/setter dynamically for TT <-> HF attribute mappings
-        self._create_getter_setter_dynamically(has_moe=False)
+        self._create_getter_setter_dynamically(titan_moe_args is not None)
 
         self._titan_injected_model_args = {}
         self._titan_injected_model_args.update(kwargs)
@@ -67,6 +89,9 @@ def __init__(
 
         self._initialize_dense_attributes(titan_dense_args)
 
+        if titan_moe_args is not None:
+            self._initialize_moe_attributes(titan_moe_args)
+
     def _initialize_dense_attributes(self, titan_dense_args):
         """Initialize all dense model attributes."""
         # Set mapped attributes (TorchTitan <-> HuggingFace)
@@ -83,6 +108,45 @@ def _initialize_dense_attributes(self, titan_dense_args):
         # Update passed_args
         self._titan_injected_model_args.update(titan_dense_args.__dict__)
 
+    def _initialize_moe_attributes(self, titan_moe_args):
+        """Initialize all MoE-specific attributes."""
+        if titan_moe_args.moe_args is None:
+            self._titan_injected_model_args.update(titan_moe_args.__dict__)
+            return
+
+        moe_args = titan_moe_args.moe_args
+
+        # Convert q_lora_rank (0 -> None for HuggingFace compatibility)
+        self.q_lora_rank = (
+            None if titan_moe_args.q_lora_rank == 0 else titan_moe_args.q_lora_rank
+        )
+
+        # Set core MoE attributes
+        self.moe_args = moe_args
+        self.num_experts_per_tok = moe_args.top_k
+        self.n_routed_experts = moe_args.num_experts
+        self.n_shared_experts = moe_args.num_shared_experts
+        self.moe_intermediate_size = titan_moe_args.moe_inter_dim
+
+        # Set remaining architecture-specific MoE attributes
+        for attr in self._MOE_SHARED_ATTRIBUTES:
+            if attr == "q_lora_rank":
+                continue  # Already set above
+            if hasattr(titan_moe_args, attr):
+                setattr(self, attr, getattr(titan_moe_args, attr))
+
+        # Track all MoE arguments
+        self._titan_injected_model_args.update(titan_moe_args.__dict__)
+        self._titan_injected_model_args.update(
+            {
+                "num_experts_per_tok": moe_args.top_k,
+                "n_routed_experts": moe_args.num_experts,
+                "n_shared_experts": moe_args.num_shared_experts,
+                "moe_intermediate_size": titan_moe_args.moe_inter_dim,
+                "q_lora_rank": self.q_lora_rank,
+            }
+        )
+
     def _configure_hf_attention(self, attn_implementation: str):
         """Configure HuggingFace attention settings."""
         self._titan_injected_model_args["attn_implementation"] = attn_implementation
@@ -151,6 +215,9 @@ def update_from_config(self, job_config: JobConfig):
             if hasattr(self, key) and value is not None:
                 setattr(self, key, value)
 
+        # MoE
+        if hasattr(self, "qk_nope_head_dim") and hasattr(self, "qk_rope_head_dim"):
+            self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
         self.max_seq_len = job_config.training.seq_len
 
         # Configure HF-specific settings to match TorchTitan settings
@@ -174,4 +241,9 @@ def update_from_config(self, job_config: JobConfig):
         return self
 
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
-        return get_dense_model_nparams_and_flops(self, model, head_dims=self.head_dim, seq_len=seq_len)
+        is_moe = hasattr(self, "n_routed_experts")
+
+        if is_moe:
+            return get_moe_model_nparams_and_flops(self, model, head_dims=self.head_dim, seq_len=seq_len)
+        else:
+            return get_dense_model_nparams_and_flops(self, model, head_dims=self.head_dim, seq_len=seq_len)
diff --git a/torchtitan/experiments/transformers_backend/model/model.py b/torchtitan/experiments/transformers_backend/model/model.py
index 8c35ac4e94..f7d1538096 100644
--- a/torchtitan/experiments/transformers_backend/model/model.py
+++ b/torchtitan/experiments/transformers_backend/model/model.py
@@ -74,24 +74,55 @@ def __init__(self, model_args: HFTransformerModelArgs):
                 model_module, f"{model_name_prefix}DecoderLayer", None
             )
 
-            required_classes = {
-                "Attention": attention_cls,
-                "DecoderLayer": decoder_layer_cls,
-            }
-
-            if all(required_classes.values()):
-                logger.info(f"Applying Llama-like patch for {model_name_prefix}")
-                self._patch_hf_llama_like(
-                    decoder_layer_cls=decoder_layer_cls,
-                    attention_cls=attention_cls,
-                    mlp_cls=mlp_cls,  # mlp_cls can be None
-                )
+            is_moe = hasattr(
+                model_args, "n_routed_experts"
+            )  # TODO(3outeille): check whether this is the most reliable way to detect a MoE model
+            if is_moe:
+                moe_cls = getattr(model_module, f"{model_name_prefix}MoE", None)
+                required_classes = {
+                    "Attention": attention_cls,
+                    "MLP": mlp_cls,
+                    "DecoderLayer": decoder_layer_cls,
+                    "MoE": moe_cls,
+                }
+
+                if all(required_classes.values()):
+                    logger.info(f"Applying MoE-like patch for {model_name_prefix}")
+                    self._patch_hf_moe_like(
+                        decoder_layer_cls=decoder_layer_cls,
+                        attention_cls=attention_cls,
+                        mlp_cls=mlp_cls,
+                        moe_cls=moe_cls,
+                    )
+                else:
+                    missing = [
+                        name for name, cls in required_classes.items() if not cls
+                    ]
+                    logger.warning(
+                        f"Could not find required classes ({', '.join(missing)}) for MoE patching of {model_name_prefix}. "
+                        "Skipping MoE-like patch."
+                    )
             else:
-                missing = [name for name, cls in required_classes.items() if not cls]
-                logger.warning(
-                    f"Could not find required classes ({', '.join(missing)}) for {model_name_prefix}. "
-                    "Skipping Llama-like patch."
-                )
+                required_classes = {
+                    "Attention": attention_cls,
+                    "DecoderLayer": decoder_layer_cls,
+                }
+
+                if all(required_classes.values()):
+                    logger.info(f"Applying Llama-like patch for {model_name_prefix}")
+                    self._patch_hf_llama_like(
+                        decoder_layer_cls=decoder_layer_cls,
+                        attention_cls=attention_cls,
+                        mlp_cls=mlp_cls,  # mlp_cls can be None
+                    )
+                else:
+                    missing = [
+                        name for name, cls in required_classes.items() if not cls
+                    ]
+                    logger.warning(
+                        f"Could not find required classes ({', '.join(missing)}) for {model_name_prefix}. "
+                        "Skipping Llama-like patch."
+                    )
 
         except Exception as e:
             logger.warning(
@@ -101,7 +132,6 @@ def __init__(self, model_args: HFTransformerModelArgs):
 
         self.model = model_cls(config=model_args)
         self.max_seq_len = model_args.max_seq_len
-        self.cp_mesh = None
 
         # Convert ModuleList to ModuleDict to preserve original indices
         # This ensures state dict keys match checkpoint keys
@@ -111,7 +141,13 @@
         )
 
         for layer in self.model.model.layers.values():
-            layer.moe_enabled = False
+            if (
+                hasattr(model_args, "first_k_dense_replace")
+                and layer.layer_idx >= model_args.first_k_dense_replace
+            ):
+                layer.moe_enabled = True
+            else:
+                layer.moe_enabled = False
 
     def set_cp_mesh(self, mesh):
         self.cp_mesh = mesh
@@ -285,6 +321,175 @@ def _init_weights_patched(self, module):
         PreTrainedModel._init_weights = _init_weights_patched
         PreTrainedModel._initialize_weights = _initialize_weights_patched
 
+    def _patch_hf_moe_like(self, decoder_layer_cls, attention_cls, mlp_cls, moe_cls):
+        """
+        This patch modifies a Hugging Face MoE (Mixture-of-Experts) model's weight
+        initialization to match the initialization scheme used in TorchTitan,
+        drawing from patterns in models like DeepseekV3.
+
+        The patch targets:
+        - `PreTrainedModel._initialize_weights`: For correct meta device initialization.
+        - `PreTrainedModel._init_weights`: To implement TorchTitan's specific initialization
+          for attention, MLP, MoE, embedding, and layer norm layers.
+        - `DecoderLayer.__init__`: Adds `layer_idx` to attention, MLP, and MoE expert
+          modules, required for depth-dependent initialization.
+        """
+
+        _original_decoder_layer_init = decoder_layer_cls.__init__
+
+        def _decoder_layer_init_patched(self, config: PretrainedConfig, layer_idx: int):
+            _original_decoder_layer_init(self, config, layer_idx)
+            self.layer_idx = layer_idx
+
+            if hasattr(self, "self_attn"):
+                self.self_attn.layer_idx = layer_idx
+
+            if hasattr(self, "mlp"):
+                self.mlp.layer_idx = layer_idx
+                if hasattr(self.mlp, "experts"):
+                    for expert in self.mlp.experts:
+                        expert.layer_idx = layer_idx
+                if hasattr(self.mlp, "shared_experts"):
+                    # Not all MoE models have shared experts
+                    if self.mlp.shared_experts is not None:
+                        self.mlp.shared_experts.layer_idx = layer_idx
+
+        def _initialize_weights_patched(self, module):
+            if getattr(module, "_is_hf_initialized", False):
+                return
+            for param in module.parameters(recurse=True):
+                if param.device.type == "meta":
+                    return
+            self._init_weights(module)
+            module._is_hf_initialized = True
+
+        def _init_weights_patched(self, module):
+            """
+            Patched version of _init_weights for MoE models.
+            """
+            config = self.config
+            init_std = None
+
+            if isinstance(module, (attention_cls, mlp_cls, moe_cls)):
+                if hasattr(module, "layer_idx"):
+                    layer_idx = module.layer_idx
+                    if hasattr(config, "depth_init") and config.depth_init:
+                        init_std = 0.02 / (2 * (layer_idx + 1)) ** 0.5
+                    else:
+                        # Fallback for models without depth_init
+                        init_std = 0.02 / (2 * config.num_hidden_layers) ** 0.5
+
+            if isinstance(module, attention_cls):
+                # Handle different attention projection layer names by initializing if they exist
+                if hasattr(module, "q_proj"):
+                    nn.init.trunc_normal_(module.q_proj.weight, mean=0.0, std=0.02)
+                if hasattr(module, "k_proj"):
+                    nn.init.trunc_normal_(module.k_proj.weight, mean=0.0, std=0.02)
+                if hasattr(module, "v_proj"):
+                    nn.init.trunc_normal_(module.v_proj.weight, mean=0.0, std=0.02)
+
+                if hasattr(module, "q_a_proj"):
+                    nn.init.trunc_normal_(module.q_a_proj.weight, mean=0.0, std=0.02)
+                if hasattr(module, "q_b_proj"):
+                    nn.init.trunc_normal_(module.q_b_proj.weight, mean=0.0, std=0.02)
+
+                if hasattr(module, "kv_a_proj_with_mqa"):
+                    nn.init.trunc_normal_(
+                        module.kv_a_proj_with_mqa.weight, mean=0.0, std=0.02
+                    )
+                if hasattr(module, "kv_b_proj"):
+                    nn.init.trunc_normal_(module.kv_b_proj.weight, mean=0.0, std=0.02)
+
+                if hasattr(module, "o_proj") and init_std is not None:
+                    nn.init.trunc_normal_(module.o_proj.weight, mean=0.0, std=init_std)
+
+            elif isinstance(module, mlp_cls):
+                nn.init.trunc_normal_(module.gate_proj.weight, mean=0.0, std=0.02)
+                # DeepseekV3 uses std=0.02 for up_proj, unlike Llama
+                nn.init.trunc_normal_(module.up_proj.weight, mean=0.0, std=0.02)
+                if init_std is not None:
+                    nn.init.trunc_normal_(
+                        module.down_proj.weight, mean=0.0, std=init_std
+                    )
+
+            elif isinstance(module, moe_cls):
+                if hasattr(module, "gate") and init_std is not None:
+                    nn.init.trunc_normal_(module.gate.weight, mean=0.0, std=init_std)
+                if hasattr(module, "experts"):
+                    for expert in module.experts:
+                        nn.init.trunc_normal_(
+                            expert.gate_proj.weight, mean=0.0, std=0.02
+                        )
+                        nn.init.trunc_normal_(expert.up_proj.weight, mean=0.0, std=0.02)
+                        if init_std is not None:
+                            nn.init.trunc_normal_(
+                                expert.down_proj.weight, mean=0.0, std=init_std
+                            )
+                if (
+                    hasattr(module, "shared_experts")
+                    and module.shared_experts is not None
+                ):
+                    nn.init.trunc_normal_(
+                        module.shared_experts.gate_proj.weight, mean=0.0, std=0.02
+                    )
+                    nn.init.trunc_normal_(
+                        module.shared_experts.up_proj.weight, mean=0.0, std=0.02
+                    )
+                    if init_std is not None:
+                        nn.init.trunc_normal_(
+                            module.shared_experts.down_proj.weight,
+                            mean=0.0,
+                            std=init_std,
+                        )
+
+            elif module is getattr(self, "lm_head", None):
+                final_out_std = config.hidden_size**-0.5
+                cutoff_factor = 3
+                nn.init.trunc_normal_(
+                    module.weight,
+                    mean=0.0,
+                    std=final_out_std,
+                    a=-cutoff_factor * final_out_std,
+                    b=cutoff_factor * final_out_std,
+                )
+                if module.bias is not None:
+                    module.bias.data.zero_()
+
+            elif isinstance(module, nn.Embedding):
+                # When tie_word_embeddings is True, use lm_head initialization
+                if (
+                    hasattr(config, "tie_word_embeddings")
+                    and config.tie_word_embeddings
+                ):
+                    final_out_std = config.hidden_size**-0.5
+                    cutoff_factor = 3
+                    nn.init.trunc_normal_(
+                        module.weight,
+                        mean=0.0,
+                        std=final_out_std,
+                        a=-cutoff_factor * final_out_std,
+                        b=cutoff_factor * final_out_std,
+                    )
+                else:
+                    std = config.initializer_range
+                    module.weight.data.normal_(mean=0.0, std=std)
+
+                if module.padding_idx is not None:
+                    module.weight.data[module.padding_idx].zero_()
+
+            elif (
+                "LayerNorm" in module.__class__.__name__
+                or "RMSNorm" in module.__class__.__name__
+            ):
+                if hasattr(module, "weight") and module.weight is not None:
+                    module.weight.data.fill_(1.0)
+                if hasattr(module, "bias") and module.bias is not None:
+                    module.bias.data.zero_()
+
+        decoder_layer_cls.__init__ = _decoder_layer_init_patched
+        PreTrainedModel._init_weights = _init_weights_patched
+        PreTrainedModel._initialize_weights = _initialize_weights_patched
+
     @property
     def tok_embeddings(self):
         """Returns the model's embed_tokens, handling different Hugging Face model structures."""
diff --git a/torchtitan/protocols/train_spec.py b/torchtitan/protocols/train_spec.py
index 22bfa7df9b..12a9b68e71 100644
--- a/torchtitan/protocols/train_spec.py
+++ b/torchtitan/protocols/train_spec.py
@@ -77,6 +77,11 @@ def get_train_spec(name: str) -> TrainSpec:
     from torchtitan.experiments import _supported_experiments
     from torchtitan.models import _supported_models
 
+    # If name contains '/', treat it as a HuggingFace model ID and use transformers_backend
+    if "/" in name:
+        module = import_module("torchtitan.experiments.transformers_backend")
+        return module.get_train_spec()
+
    if name in _supported_models:
         module = import_module(f"torchtitan.models.{name}")
         return module.get_train_spec()
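
Illustrative usage (not part of the diff): a minimal sketch of how the model-name routing introduced in `torchtitan/protocols/train_spec.py` is expected to behave after this change, assuming the `flavors` dict defined in `torchtitan/experiments/transformers_backend/__init__.py`.

```python
from torchtitan.protocols.train_spec import get_train_spec

# A built-in name without "/" still resolves through the existing registries.
llama_spec = get_train_spec("llama3")

# A name containing "/" is treated as a HuggingFace model ID and routed to the
# transformers backend; the TOML `flavor` key then selects one of its flavors,
# e.g. "debugmodel" or the new "debugmodel_moe".
hf_spec = get_train_spec("Qwen/Qwen3-4B-Instruct-2507")
moe_debug_args = hf_spec.model_args["debugmodel_moe"]
```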