diff --git a/.ci/docker/requirements-transformers-modeling-backend.txt b/.ci/docker/requirements-transformers-modeling-backend.txt
index 76e8886ed0..4a4a39c9ba 100644
--- a/.ci/docker/requirements-transformers-modeling-backend.txt
+++ b/.ci/docker/requirements-transformers-modeling-backend.txt
@@ -1 +1 @@
-transformers==4.57.1
+transformers==5.0.0rc0
diff --git a/torchtitan/experiments/transformers_modeling_backend/README.md b/torchtitan/experiments/transformers_modeling_backend/README.md
index fb70d03a1f..d47290947f 100644
--- a/torchtitan/experiments/transformers_modeling_backend/README.md
+++ b/torchtitan/experiments/transformers_modeling_backend/README.md
@@ -4,7 +4,7 @@ This enables HF transformers models to be trained with `4D parallelism + torch.c
 
 ## Quick start
 
-- Requirements `transformers==4.57.1`
+- Requirements `transformers==5.0.0rc0`
 - Config: `torchtitan/torchtitan/experiments/transformers_modeling_backend/configs/qwen3.toml`
 
 ```diff
diff --git a/torchtitan/experiments/transformers_modeling_backend/model/model.py b/torchtitan/experiments/transformers_modeling_backend/model/model.py
index b88fffc54b..5e937c598a 100644
--- a/torchtitan/experiments/transformers_modeling_backend/model/model.py
+++ b/torchtitan/experiments/transformers_modeling_backend/model/model.py
@@ -9,6 +9,7 @@ import torch
 
 from torch import nn
+from torch.distributed.tensor import DTensor
 from torch.nn import init
 
 from torchtitan.tools.logging import logger
 from transformers.configuration_utils import PretrainedConfig
@@ -270,7 +271,10 @@ def _init_weights_patched(self, module):
         module.weight.data.normal_(mean=0.0, std=std)
 
         if module.padding_idx is not None:
-            module.weight.data[module.padding_idx].zero_()
+            if isinstance(module.weight.data, DTensor):
+                module.weight.data._local_tensor[module.padding_idx].zero_()
+            else:
+                module.weight.data[module.padding_idx].zero_()
 
     elif (
         isinstance(
@@ -429,10 +433,6 @@ def init_weights(self, *args, **kwargs):
         # This method replicates the behavior of the original PreTrainedModel.init_weights,
         # but with a custom weight initialization function that skips nn.Identity modules (when PP is enabled)
 
-        if self.model.config.pruned_heads:
-            logger.info("Pruning heads as per model configuration.")
-            self.model.prune_heads(self.model.config.pruned_heads)
-
         original_init_weights_fn = self.model._init_weights
 
         def selective_init(module):
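
Note on the DTensor guard in `_init_weights_patched`: when the embedding weight has been parallelized, `module.weight.data` is a `DTensor`, and the patch zeroes the padding row on the local shard through the private `_local_tensor` attribute instead of indexing the distributed tensor directly. The sketch below is a minimal, self-contained illustration of that same pattern, not code from this PR; the single-rank `gloo` process group, the CPU device mesh, and the replicated placement are assumptions made purely so it runs standalone.

```python
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

# Single-process group so the sketch runs standalone (assumption, not PR code).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

emb = nn.Embedding(10, 4, padding_idx=0)
with torch.no_grad():
    emb.weight.fill_(1.0)  # make the padding row nonzero so the zeroing is visible

# Turn the weight into a DTensor, as a parallelized model would have it.
emb.weight = nn.Parameter(distribute_tensor(emb.weight.data, mesh, [Replicate()]))

# The guard from the patch: write to the local shard when the weight is a
# DTensor, otherwise index the plain tensor as before.
weight = emb.weight.data
if isinstance(weight, DTensor):
    weight._local_tensor[emb.padding_idx].zero_()
else:
    weight[emb.padding_idx].zero_()

assert emb.weight.data.to_local()[emb.padding_idx].abs().sum() == 0
dist.destroy_process_group()
```

Writing through `_local_tensor` bypasses DTensor dispatch, so it only touches rows present in the local shard; with the replicated placement above, that is exactly the padding row on every rank.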