diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d7aa78b531db..10b5dc49faf4 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -25,7 +25,7 @@ - local: serialization title: Exporting to production - local: modular_transformers - title: Contributing a new model to Transformers + title: Add a model with modular transformers - local: add_new_model title: Legacy model contribution - local: auto_docstring diff --git a/docs/source/en/modular_transformers.md b/docs/source/en/modular_transformers.md index 620e877d971d..8c2715165674 100644 --- a/docs/source/en/modular_transformers.md +++ b/docs/source/en/modular_transformers.md @@ -1,98 +1,34 @@ -# Contributing a new model to Transformers + -Modular Transformers addresses these issues by adding a *modular* file to a model folder. The modular file can import code from other models and inherit code from other classes unlike traditional modeling and processing files. +# Add a model with modular transformers -> [!TIP] -> Modular Transformers isn't meant to replace the modeling code, and if your model isn't based on an existing model, you'll need to add a `modeling.py` file manually. Likewise, if a configuration, tokenization or processing file can't easily inherit from a similar file, you can add that file directly. - -A modular file contains model, processor, and configuration class code that would otherwise be in separate files under the single model, single file policy. - -Model users still import and use the single-file interface they've grown familiar with. In doing so, we hope to enable simpler contributions while sticking to our philosophy. - -## Create a modeling.py file - -A linter "unravels" the modular file into a `modeling.py` file to preserve the single model, single file directory structure (modeling, processor, etc.). Inheritance is flattened to only a **single** level. - -Run the command below to automatically generate a `modeling.py` file from a modular file (assuming the snake lowercase name of the model you want to convert is `your_model`). - -```bash -python utils/modular_model_converter.py your_model -``` - -For example: - -- If a configuration class inherits from another class, but adds and deletes an argument, the generated file directly references it if an argument is added or completely removes it if an argument is deleted. -- If a class inherits from another, like `GemmaModel(LlamaModel)`, the dependencies are automatically inferred. All submodules are also automatically inferred from the superclass. -- If a new function is defined in the modular file and used inside classes, the linter automatically infers these as well. - -You should be able to write everything (tokenizer, image processor, model, config, etc.) in a modular and their corresponding single-files are generated. - -The example below demonstrates how a model can be added with significantly fewer lines of code with Modular Transformers. - -### BERT and RoBERTa +Modular transformers reduces the code needed to add a model by allowing imports and inheritance, in contrast to the [single model, single file](https://huggingface.co/blog/transformers-design-philosophy) policy. Instead of repeating model components across files, add a *modular* file to your model folder and inherit from existing classes. -BERT and RoBERTa, two very similar models, differ solely in how the embedding layer is implemented. +A converter generates standalone files from the modular file. 
Users import and use the same single-file interface they're already familiar with. -Instead of redefining the model entirely, consider the `modular_roberta.py` file shown below for the modeling and configuration classes (the tokenizer isn't shown in this example). - -```py -from torch import nn -from ..bert.configuration_bert import BertConfig -from ..bert.modeling_bert import ( - BertModel, - BertEmbeddings, - BertForMaskedLM -) - -# RoBERTa and BERT config is identical -class RobertaConfig(BertConfig): - model_type = 'roberta' - -# Redefine the embeddings to highlight the padding id difference, and redefine the position embeddings -class RobertaEmbeddings(BertEmbeddings): - def __init__(self, config): - super().__init__(config()) - - self.padding_idx = config.pad_token_id - self.position_embeddings = nn.Embedding( - config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx - ) - -# RoBERTa and BERT model is identical except for the embedding layer, which is defined above, so no need for additional changes here -class RobertaModel(BertModel): - def __init__(self, config): - super().__init__(config) - self.embeddings = RobertaEmbeddings(config) - - -# The model heads now only need to redefine the model inside to `RobertaModel` -class RobertaForMaskedLM(BertForMaskedLM): - def __init__(self, config): - super().__init__(config) - self.model = RobertaModel(config) -``` - -If you don't use the defined dependency, you'll receive the following error. - -```text -ValueError: You defined `RobertaEmbeddings` in the modular_roberta.py, it should be used when you define `BertModel`, as it is one of it's direct dependencies. Make sure you use it in the `__init__` function. -``` +> [!NOTE] +> Modular transformers isn't meant to replace the [legacy modeling code](./add_new_model). If your model isn't based on an existing model, add a `modeling.py` file manually. The same applies to configuration, tokenization, or processing files that can't cleanly inherit from a similar file. ## Implementing a modular file -The easiest way to start is by browsing Transformers for a model similar to yours in order to inherit from it. Some good starting points are [Mistral](./model_doc/mistral), [Qwen2](./model_doc/qwen2), [Cohere](./model_doc/cohere) and [Cohere2](./model_doc/cohere2), and [Llama](./model_doc/llama). Refer to the table below for components your model might be using and where you can inherit from. +Start by finding a model in Transformers similar to yours. Good starting points are [Mistral](./model_doc/mistral), [Qwen2](./model_doc/qwen2), [Cohere](./model_doc/cohere) and [Cohere2](./model_doc/cohere2), and [Llama](./model_doc/llama). The table below maps common components to models you can inherit from. | Component | Model | |---|---| -| Mixture of expert | SwitchTransformers or Mixtral | +| Mixture of experts | Mixtral or Qwen2-MoE | | Interleaved (and/or partial) rotary embedding | GLM, Phi | | State space models | Jamba, Bamba, Zamba, Mamba2 | | Recurrent hidden states | Gemma2 | @@ -101,248 +37,185 @@ The easiest way to start is by browsing Transformers for a model similar to your | QK normalization | Olmo2, Cohere | | Fused QKV (not recommended) | Phi3 | -This section will walk you through how to implement [Olmo2](./model_doc/olmo2) from [Olmo](./model_doc/olmo) with modular Transformers (you can refer to the original [modeling.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modular_olmo2.py) file). 
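+For a model that reuses another model's blocks wholesale, a first draft of the modular file (created as `src/transformers/models/<your_model>/modular_<your_model>.py`) can be little more than a handful of `pass`-inheriting classes. The sketch below is hypothetical: `MyNewModel` is a placeholder name and Llama is an arbitrary choice of parent, but it shows the minimal skeleton the rest of this guide builds on.

```py
# Hypothetical modular_my_new_model.py: every class reuses Llama unchanged.
# The linter later expands each `pass` into a full, renamed copy of the parent.
from ..llama.configuration_llama import LlamaConfig
from ..llama.modeling_llama import LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel


class MyNewModelConfig(LlamaConfig):
    model_type = "my_new_model"


class MyNewModelPreTrainedModel(LlamaPreTrainedModel):
    pass


class MyNewModelModel(LlamaModel):
    pass


class MyNewModelForCausalLM(LlamaForCausalLM):
    pass


__all__ = ["MyNewModelConfig", "MyNewModelPreTrainedModel", "MyNewModelModel", "MyNewModelForCausalLM"]
```

+The sections below show how to replace such stubs with real overrides wherever the new model diverges from its parent, using Olmo2 as the example.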
+> [!TIP] +> Use the [modular-detector-v2](https://huggingface.co/spaces/Molbap/modular-detector-v2) tool to find existing implementations to inherit from. Paste a code snippet and it returns the most similar methods already in Transformers, so you can identify the best parent class before you start writing. -### Config +Create `src/transformers/models//modular_.py`, where `` matches the snake_case model directory name. This section walks you through implementing [Olmo2](./model_doc/olmo2) from [Olmo](./model_doc/olmo) with the modular approach (refer to the original [modular_olmo2.py](../../../src/transformers/models/olmo2/modular_olmo2) file). -The modular `Olmo2Config` is shown below. +### Config -```py -from ..olmo.configuration_olmo import OlmoConfig - -class Olmo2Config(OlmoConfig): - r""" - This is the configuration class to store the configuration of a [Olmo2Model](/docs/transformers/main/en/model_doc/olmo2#transformers.Olmo2Model). - """ - - def __init__( - self, - vocab_size=50304, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - use_cache=True, - pad_token_id=1, - bos_token_id=None, - eos_token_id=50279, - tie_word_embeddings=False, - rope_theta=10000.0, - rope_scaling=None, - attention_bias=False, - attention_dropout=0.0, - rms_norm_eps=1e-5, - **kwargs, - ): - super().__init__( - vocab_size=vocab_size, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - num_hidden_layers=num_hidden_layers, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - hidden_act=hidden_act, - max_position_embeddings=max_position_embeddings, - initializer_range=initializer_range, - use_cache=use_cache, - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - attention_bias=attention_bias, - attention_dropout=attention_dropout, - **kwargs, - ) - - self.rms_norm_eps = rms_norm_eps - del self.clip_qkv +There are two points where [`Olmo2Config`] differs from [`OlmoConfig`]. + +1. There is a new argument, `rms_norm_eps`. +2. The `clip_qkv` argument is no longer used. + +Declare new arguments as class-level type annotations with a default value. For removed arguments, assign `AttributeError()` to suppress the inherited attribute in the generated file (see [Removing attributes](#removing-attributes)). + +```diff +-@auto_docstring(checkpoint="allenai/OLMo-7B-hf") ++@auto_docstring(checkpoint="allenai/Olmo2-7B-1124-hf") ++@strict +-class OlmoConfig(PreTrainedConfig): ++class Olmo2Config(OlmoConfig): + ... +- model_type = "olmo" ++ model_type = "olmo2" + ... ++ rms_norm_eps: float = 1e-5 +- clip_qkv: float | None = None ++ clip_qkv = AttributeError() ``` -There are three points where the `Olmo2Config` is different from the original `OlmoConfig`. +`@auto_docstring` generates standard argument docs automatically (see the [@auto_docstring](./auto_docstring) guide). `@strict` rejects unknown kwargs at instantiation time, catching typos and stale arguments early. Add both to every config class — neither is inherited from the parent, so you must declare them explicitly even if the parent config already has them. -1. The default value of most arguments have changed. -2. There is a new argument, `rms_norm_eps`. -3. The `clip_qkv` argument isn't used anymore. 
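+Putting the diff together, the full class in `modular_olmo2.py` stays short. The snippet below is a condensed sketch (the real file also carries the class docstring and any other Olmo2-specific defaults), and the decorator imports are assumed to match whatever the parent configuration file uses.

```py
# Condensed sketch of the resulting config class in modular_olmo2.py.
# Copy the imports for `auto_docstring` and `strict` from the parent's configuration file.
from ..olmo.configuration_olmo import OlmoConfig


@auto_docstring(checkpoint="allenai/Olmo2-7B-1124-hf")  # not inherited, declare it on every config
@strict  # rejects unknown kwargs at instantiation time
class Olmo2Config(OlmoConfig):
    model_type = "olmo2"

    # New argument: a class-level annotation with a default value.
    rms_norm_eps: float = 1e-5
    # Removed argument: suppress the attribute inherited from OlmoConfig.
    clip_qkv = AttributeError()
```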
+To set a derived attribute or handle backward-compatibility logic, use `__post_init__` instead of `__init__`. For example, Cohere2 computes `head_dim` and derives `layer_types` at init time. -For the new default values and argument, overwrite the `__init__` function with the new default values and add `rms_norm_eps`. Assign `rms_norm_eps` to `self` in the body of `__init__`. For the `clip_qkv` argument, use `del self.clip_qkv` to remove the assignment of this attribute in the unraveled code (post-linter conversion). - -Notice how the `super().__init__(...)` is used. Typically, it calls the parent `__init__`. +```py +def __post_init__(self, **kwargs): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + self.head_dim = self.hidden_size // self.num_attention_heads + super().__post_init__(**kwargs) +``` -But in modular Transformers, if there is a call like `super().my_function(...)`, the linter takes the body of `my_function` in the parent and unravels it where the call to `super().my_function(...)` occurred. The `del self.clip_qkv` statement removes the reference to `self.clip_qkv` in the unraveled body. +For models with tensor or pipeline parallelism support, define `base_model_tp_plan` and `base_model_pp_plan` as class-level dictionaries on the config. These define how to shard the model across devices. See existing configs like [Olmo2](../../../src/transformers/models/olmo2/modular_olmo2) or [Cohere2](../../../src/transformers/models/cohere2/modular_cohere2) for examples. -`del self.` and `super().my_function(..)` work together, and it should always be placed after `super().my_function(...)`. You can add whatever you want *before* calling `super()`, and it is placed before the parents body. +```py +class MyNewModelConfig(PreTrainedConfig): + model_type = "my_new_model" + + # Tensor parallelism: maps layer name patterns to sharding strategies. + # Use "colwise" / "rowwise" for standard sharding, or the "gather_output" / + # "split_input" variants when an extra op (e.g. a QK norm) prevents fusing. + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + # Pipeline parallelism: maps submodule names to their (input, output) tensor names. + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } +``` ### Norm +To copy a parent class without changes, inherit with `pass`. The linter copies the parent's content and renames all references to match the new model. + ```py -from ..llama.modeling_llama import LlamaRMSNorm +from ..olmo.modeling_olmo import OlmoRotaryEmbedding -class Olmo2RMSNorm(LlamaRMSNorm): +class Olmo2RotaryEmbedding(OlmoRotaryEmbedding): pass ``` -Nothing needs to be modified in `LlamaRMSNorm`. The linter unravels the exact content of `LlamaRMSNorm` into `Olmo2RMSNorm`. References to Llama in the docstrings, type hints, and comments are also changed to Olmo2. +To change specific behavior, inherit and override only what differs. [`Olmo2RMSNorm`] differs from [`LlamaRMSNorm`] on one line. The multiply happens *before* casting back to the input dtype, not after. 
-### Attention +```diff + from ..llama.modeling_llama import LlamaRMSNorm -The modular `Olmo2Attention` is shown below. - -```py -from ..llama.modeling_llama import eager_attention_forward -from ..olmo.modeling_olmo import OlmoAttention, apply_rotary_pos_emb - - -# Olmo2 attention is identical to OLMo attention except: -# - Norm is applied to attention queries and keys. -# - No qkv clipping. -class Olmo2Attention(OlmoAttention): - def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx=layer_idx) - self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) - self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_values: Optional[Cache] = None, - **kwargs, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_norm(self.q_proj(hidden_states)) - key_states = self.k_norm(self.k_proj(hidden_states)) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(hidden_shape).transpose(1, 2) - key_states = key_states.view(hidden_shape).transpose(1, 2) - value_states = value_states.view(hidden_shape).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_values is not None: - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + class Olmo2RMSNorm(LlamaRMSNorm): + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) +- return self.weight * hidden_states.to(input_dtype) ++ return (self.weight * hidden_states).to(input_dtype) ``` -The `super().__init__(...)` copies the parent definition and adds 2 new layers from `Olmo2RMSNorm`. The forward pass needs to be overwritten to use these 2 new layers. A pass with the norm layers is added before projecting with `q_proj` and `k_proj`. To make it easier, the `eager_attention_forward` function is directly imported from Llama and the `apply_rotary_pos_emb` is imported from Olmo. - -The linter automatically adds these imported functions in the final `modeling_olmo2.py` file by copying their definitions from the source files. The `rotate_half` and `repeat_kv` functions are also added because they are used inside `apply_rotary_pos_emb` and `eager_attention_forward`. +### Attention -The `Attention` class had to be redefined because there weren't any existing models with an `Attention` layer that included a `RMSNorm` layer. 
+Olmo2's attention is identical to Olmo's except it applies [`RMSNorm`] to the queries and keys, and removes qkv clipping. `super().__init__(...)` copies the parent body and appends the two new norm lines. The `forward` is fully redefined because queries and keys now pass through norms before projection. The linter pulls in any imported functions — `apply_rotary_pos_emb`, `eager_attention_forward`, and their own dependencies — into the generated file automatically. + +```diff + class Olmo2Attention(OlmoAttention): + def __init__(self, config: Olmo2Config, layer_idx: int | None = None): + super().__init__(config, layer_idx=layer_idx) ++ self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) ++ self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) + + def forward(self, ...): + ... +- query_states = self.q_proj(hidden_states) +- key_states = self.k_proj(hidden_states) ++ query_states = self.q_norm(self.q_proj(hidden_states)) ++ key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + +- if self.config.clip_qkv is not None: +- query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) +- key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) +- value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) +- + ... +``` ### DecoderLayer -The modular `DecoderLayer` is shown below. - -```py -from ..olmo.modeling_olmo import OlmoDecoderLayer - -# The OLMo2 layers are identical to those of the OLMo model except: -# - RMSNorm is used instead of standard layer norm. -# - Norm is applied after attention/feedforward rather than before. -class Olmo2DecoderLayer(OlmoDecoderLayer): - def __init__(self, config: Olmo2Config, layer_idx: int): - super().__init__(config, layer_idx=layer_idx) - self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx) - del self.input_layernorm - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.mlp(hidden_states) - hidden_states = self.post_feedforward_layernorm(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - - return outputs +After `super().__init__(...)`, overwrite the norm attributes with `Olmo2RMSNorm` instances and reassign `self.self_attn` to the new `Olmo2Attention` class. 
The `del self.input_layernorm` removes the parent's `input_layernorm` assignment since Olmo2 applies the norm *after* attention rather than before. See [Removing attributes](#removing-attributes) for details on what `del` does and doesn't remove. + +The `forward` is rewritten to reflect the post-attention norm placement. Switching only the norm type without renaming the attribute wouldn't require a `forward` rewrite. + +```diff + class Olmo2DecoderLayer(OlmoDecoderLayer): + def __init__(self, config: Olmo2Config, layer_idx: int): + super().__init__(config, layer_idx=layer_idx) +- self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) +- self.input_layernorm = OlmoLayerNorm(config.hidden_size) +- self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) ++ self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx) ++ del self.input_layernorm ++ self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, ...): + residual = hidden_states +- hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn(...) +- hidden_states = residual + hidden_states ++ hidden_states = self.post_attention_layernorm(hidden_states) ++ hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states +- hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) +- hidden_states = residual + hidden_states ++ hidden_states = self.post_feedforward_layernorm(hidden_states) ++ hidden_states = residual + hidden_states + return hidden_states ``` -The norm type is switched in `__init__` by overwriting `self.post_attention_layernorm` after the call to `super().__init__(...)`. Delete the `self.input_layernorm` attributed and replace it with `self.post_feedforward_layernorm` because it is applied after in Olmo2. The forward method is overwritten to reflect this change. - -If you only switched `self.post_feedforward_layernorm` and `self.input_layernorm` from `LayerNorm` to `RMSNorm` without also changing the name and logic of `self.input_layernorm`, then you wouldn't have to rewrite the forward method. - ### Model -The modular `Olmo2Model` class is shown below. - -```py -from ..olmo.modeling_olmo import OlmoModel - -# The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of -# standard layer norm for the output norm. -class Olmo2Model(OlmoModel): - def __init__(self, config: Olmo2Config): - super().__init__(config) - self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.layers = nn.ModuleList( - [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) +Only the type of `self.norm` changes here. The `forward` method is identical to the parent's, so the linter carries it over automatically. 
+ +```diff + class Olmo2Model(OlmoModel): + def __init__(self, config: Olmo2Config): + super().__init__(config) +- self.layers = nn.ModuleList( +- [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] +- ) +- self.norm = OlmoLayerNorm(config.hidden_size) ++ self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ self.layers = nn.ModuleList( ++ [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ++ ) ``` -You only need to change the *type* of the `self.norm` attribute to use `RMSNorm` instead of `LayerNorm`. This change doesn't affect the logic in the forward method (layer name and usage is identical to the parent class), so you don't need to overwrite it. The linter automatically unravels it. - ### Model head -The modular causal modeling head is shown below. +The logic is identical to [`OlmoForCausalLM`], so no changes are needed. ```py from ..olmo.modeling_olmo import OlmoForCausalLM @@ -351,15 +224,13 @@ class Olmo2ForCausalLM(OlmoForCausalLM): pass ``` -The logic is identical to `OlmoForCausalLM` which means you don't need to make any changes here. - ### Other classes -The [modeling_olmo2.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py) generated by the linter also contains some classes (`Olmo2MLP`, `Olmo2RotaryEmbedding`, `Olmo2PreTrainedModel`) that weren't explicitly defined in `modular_olmo2.py`. +The [modeling_olmo2.py](../../../src/transformers/models/olmo2/modeling_olmo2) generated by the linter also contains classes ([`Olmo2MLP`], [`Olmo2RotaryEmbedding`], [`Olmo2PreTrainedModel`]) that weren't explicitly defined in `modular_olmo2.py`. -Classes that are a dependency of an inherited class but aren't explicitly defined are automatically added as a part of dependency tracing. This is similar to how some functions were added to the `Attention` class without directly importing them. +Any class that an inherited class depends on is pulled in automatically if you don't explicitly redefine it. Imported functions like `apply_rotary_pos_emb` follow the same rule. -For example, `OlmoDecoderLayer` has an attribute defined as `self.mlp = OlmoMLP(config)`. This class was never explicitly redefined in `Olmo2MLP`, so the linter automatically created a `Olmo2MLP` class similar to `OlmoMLP`. It is identical to the code below if it was explicitly written in `modular_olmo2.py`. +For example, [`OlmoDecoderLayer`] has `self.mlp = OlmoMLP(config)`. Since [`Olmo2MLP`] was never defined in the modular file, the linter automatically creates it. This is equivalent to using `pass`. ```py from ..olmo.modeling_olmo import OlmoMLP @@ -368,42 +239,98 @@ class Olmo2MLP(OlmoMLP): pass ``` -However, it was necessary to rewrite `Olmo2RMSNorm` because the layer norm needed to be redefined in the `Attention` and `DecoderLayer` classes. Similarly, this is why you didn't need to create the `Olmo2PreTrainedModel` and `Olmo2RotaryEmbedding` classes. - -Classes that aren't rewritten are copied from the file where the inherited module first uses them. This means if you wanted `Olmo2MLP` to inherit from `MistralMLP` instead, you would need to be more explicit as shown below. +If you want [`Olmo2MLP`] to inherit from a different model instead, be explicit. 
```py -# switch to mistral definition +# switch to Mistral definition from ..mistral.modeling_mistral import MistralMLP class Olmo2MLP(MistralMLP): pass ``` +### Finishing the file + +Every modular file must declare a `logger` and an `__all__` list at the module level. + +```py +logger = logging.get_logger(__name__) + +__all__ = [ + "Olmo2Config", + "Olmo2ForCausalLM", + "Olmo2Model", + "Olmo2PreTrainedModel", +] +``` + +`__all__` must list every public class in the file. The converter and downstream imports depend on it. A class missing from `__all__` won't be exported correctly. + +## Generate the modeling files + +The `modular_model_converter.py` script generates standalone `modeling.py`, `configuration.py`, and other files from your modular file. It does this by copying each inherited parent class body into the child, renaming all references to match the new model, and pulling in any helper functions or classes those parents depend on. + +The output files contain no cross-model imports and no inheritance from other model directories. Inheritance is flattened to only a *single* level. If [`Olmo2Attention`] inherits from [`OlmoAttention`], the generated `Olmo2Attention` is fully self-contained. But if `OlmoAttention` itself inherited from something else, that grandparent is not inlined. + +Run the command below to generate files from a modular file. + +```bash +python utils/modular_model_converter.py your_model +``` + +Never edit the generated files directly because any changes will be overwritten on the next run. + ## Removing attributes -You can `del` to remove attributes defined in the parent after using `super().__init__()`. However, this doesn't work if the attribute is also used somewhere else as shown below. It only suppresses the assignment. The `self.attribute = config.attribute` line is removed, but the `if` statement remains and references the attribute. +Removing an inherited attribute depends on whether you're working with a config class or an `nn.Module` subclass. + +For a config class, assign `AttributeError()` to the attribute at the class level. ```py -class DummyModel(nn.Module): +class MyNewConfig(ParentConfig): + removed_attr = AttributeError() +``` - def __init__(self, config: DummyConfig): - super().__init__() - self.attribute = config.attribute - if self.attribute: - # do more stuff with `self.attribute` here - ... +The linter removes the attribute declaration from the generated config file entirely. Config classes use a dataclass-style layout with no `__init__`, so this is the right pattern to use. -class MyNewDummyModel(DummyModel): +For an `nn.Module` subclass, use `del self.attribute` after `super().__init__(...)`. + +```py +class MyNewModel(ParentModel): + def __init__(self, config: MyNewConfig): + super().__init__(config) + del self.attribute +``` + +`del self.attribute` removes only the `self.attribute = ...` assignment line from the copied parent body. It does not remove any other lines that reference `self.attribute`. If the parent's `forward` or other methods also reference the attribute, override those methods too. - def __init__(self, config: MyNewDummyConfig): - super().__init__(config) - del self.attribute +```py +class DummyModel(nn.Module): + def __init__(self, config: DummyConfig): + super().__init__() + self.attribute = config.attribute + if self.attribute: + # do more stuff with `self.attribute` here + ... 
+ +class MyNewDummyModel(DummyModel): + def __init__(self, config: MyNewDummyConfig): + super().__init__(config) + del self.attribute + # 'self.attribute = config.attribute' is removed, but the 'if self.attribute:' block remains. + # Override forward() or any other method that references self.attribute. ``` -## Calling parent methods without unravelling their definition +## Working with `super()` -If you want to inherit from a module `DummyModule` and want to call `super()` WITHOUT unravelling the parent's code (that is, you want to call `super()` on the *generated* class parent), be explicit about which class' `super()` you're calling. The example below shows how to call the `super()` of `nn.Module` (unraveled code shown on the right). In this example, as `DummyModule` is itself a `nn.Module`, it makes sense to call `nn.Module.__init__(self)` as it's what was the initial intention. It's then unravelled as `super()` in `MyNewDummyModule` to follow Python's best-practices. +`super().__init__(config)` tells the converter to copy the parent body into the child. Two patterns let you override this behavior. + +- Call a specific parent class directly when you need the generated output to call a grandparent (`nn.Module.__init__`) rather than the modular parent. +- Use `**super_kwargs` when you want to inherit a parent method's full signature while adding a custom docstring or swapping a decorator. + +### Call a grandparent class directly + +To call `super()` on the *generated* class parent rather than the modular parent, be explicit about which class you're calling. The example below calls `nn.Module.__init__(self)` directly. `DummyModule` is itself an `nn.Module`, so the converter writes this as `super().__init__()` in the generated `MyNewDummyModule`. ```py class MyNewDummyModule(DummyModule): | class MyNewDummyModule(nn.Module): @@ -414,9 +341,34 @@ class MyNewDummyModule(DummyModule): | class MyNewDummyMod ... | ... ``` +### super_kwargs + +Use `**super_kwargs` when you want to inherit a parent method's full signature while adding a custom docstring or swapping a decorator. In the overridden signature, `**super_kwargs` tells the linter to expand all the parent's arguments in the generated output. + +The most common use is adding a model-specific docstring — for example, documenting the `labels` argument — without rewriting the full signature. Gemma does exactly this: + +```py +# modular_gemma.py +class GemmaForCausalLM(LlamaForCausalLM): + def forward(**super_kwargs): + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") + ... + ```""" + return super().forward(**super_kwargs) +``` + +The generated `GemmaForCausalLM.forward` has the full `LlamaForCausalLM` signature — no manual copying needed. + +`**super_kwargs` is a shortcut for a few niche cases. Don't use it to avoid writing explicit signatures when overriding a method with real logic changes. If you change behavior, write the full signature explicitly. + ## Deleting unused methods -Remove an attribute by overwriting it with a `raise AttributeError("")` statement to mimic the behavior you want when you remove a parent function in Python. The example below removes the methods in the unraveled code. +Remove a parent method by overriding it with a `raise AttributeError("")` statement. The linter removes the method from the generated file. 
```py class GemmaTokenizer(LlamaTokenizer): @@ -429,24 +381,13 @@ class GemmaTokenizer(LlamaTokenizer): raise AttributeError("Not needed for Gemma") ``` -## Defining new functions - -By default, if you inherit from a class and override a method with one or more decorators in the parent method, the decorators are also added to the unraveled code *only if you don't add any yourself*. Otherwise, the redefined decorator is used. - -Two decorators appear throughout the library. One enables [capturing model intermediate outputs](./model_output_tracing), and another for [auto-generating docstrings](./auto_docstring). +## Overriding decorated methods -In the example below, a subclass inherits from a decorated parent. The parent's decorator carries over to the unraveled code. +When you override a decorated parent method, the parent's decorators carry over automatically unless you add your own, in which case your decorator replaces theirs. -```py -class DummyModel(nn.Module): - ... - - @decorator(...) - def forward(...) - # do stuff here -``` +Two decorators appear throughout the library, one for [capturing model intermediate outputs](./model_output_tracing) and one for [auto-generating docstrings](./auto_docstring). -Modular code is shown on the left, and the unraveled code is shown on the right. +In the example below, a subclass overrides a decorated parent method without adding its own decorator. The parent's decorator carries over. ```py class NewModel(DummyModel): | class NewModel(nn.Module): @@ -457,7 +398,7 @@ class NewModel(DummyModel): | class NewModel(nn.Module): | ... ``` -But if you add a new decorator, your new decorator is used instead. +If you add a new decorator, your decorator replaces the parent's. ```py class NewModel(DummyModel): | class NewModel(nn.Module): @@ -468,73 +409,12 @@ class NewModel(DummyModel): | class NewModel(nn.Module): ... | ... ``` -## super_kwargs - -In scenarios where a forward method is really long and you want to switch decorators, you don't need to redefine everything and copy/paste the function. You can use `super().forward(...)` to unravel the parent body. When there are a lot of arguments in the function signature, use the special `**super_kwargs` syntax in the overwritten signature. - -This syntax indicates to the linter to unravel all the parent signature arguments here. An example signature in a [`AutoModelForCausalLM`] model is shown below, with lots of arguments. - -```py -class LlamaForCausalLM(nn.Module): - ... - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - num_logits_to_keep: int = 0, - **kwargs: Unpack[KwargsForCausalLM], - ) -> Union[Tuple, CausalLMOutputWithPast]: - ... -``` - -Instead of rewriting and copying/pasting all of those arguments, use the `super().forward(**super_kwargs)` statement (modular code shown on the left, unraveled code on the right). - -```py -class NewModelForCausalLM(LlamaForCausalLM): | class LlamaForCausalLM(nn.Module): - ... | ... 
- | - @my_new_decorator | @my_new_decorator - def forward(self, **super_kwargs): | def forward( - super().forward(**super_kwargs) | self, - | input_ids: torch.LongTensor = None, - | attention_mask: Optional[torch.Tensor] = None, - | position_ids: Optional[torch.LongTensor] = None, - | past_key_values: Optional[Cache] = |None, - | inputs_embeds: Optional[torch.FloatTensor] = None, - | labels: Optional[torch.LongTensor] = None, - | use_cache: Optional[bool] = None, - | output_attentions: Optional[bool] = None, - | output_hidden_states: Optional[bool] = None, - | return_dict: Optional[bool] = None, - | num_logits_to_keep: int = 0, - | **kwargs: Unpack[KwargsForCausalLM], - | ) -> Union[Tuple, CausalLMOutputWithPast]: - | ... -``` - -This makes it very easy to switch decorators and makes it explicit that the only change you want to apply is the decorator. - -`**super_kwargs` should not be used to avoid being explicit when redefining methods though. If you overwrite a method, you should explicitly write the signature as you normally would. The `**super_kwargs` syntax is a shortcut for switching decorators and a few other niche cases. - ## Docstring variables > [!TIP] -> Refer to the [Documeting a model](./auto_docstring) guide for more information about how you can use the `@auto_docstring` decorator to help automatically generate consistent docstring arguments. +> Refer to the [Documenting a model](./auto_docstring) guide for more information about how to use the `@auto_docstring` decorator to automatically generate consistent docstring arguments. For most models, `@auto_docstring` removes the need for explicit docstring variables entirely. -If an object defined in both the modular and modeling file from which it inherits, the modular definition has precedence unless for assignments containing the pattern `DOCSTRING`. These variables are typically used in `MODEL_START_DOCSTRING` and `MODEL_INPUT_DOCSTRING` in the modeling files. They are big blocks of docstrings and the linter rewrites the names everywhere. For this reason, assignments containing the `DOCSTRING` variable can use the definition found in the source file without copying the whole docstring, by simply setting the variable to `None` in the modular file. - -This is very useful if you need the variable reference somewhere but you don't want to clutter the modular file with docstrings which are always the same. The example code below allows you to automatically use the same docstrings from [Mistral](./model_doc/mistral) in [Starcoder2](./model_doc/starcoder2). +The modular definition takes precedence when an object appears in both the modular and source modeling file. The exception is if assignments matches the pattern `DOCSTRING`. These variables (`MODEL_START_DOCSTRING`, `MODEL_INPUT_DOCSTRING`) contain large blocks of text. Set a docstring variable to `None` in the modular file to use the source file's definition instead. ```py STARCODER2_INPUTS_DOCSTRING = None # will be automatically redefined @@ -547,13 +427,13 @@ class Starcoder2Model(MistralModel): ... ``` -Setting the variable to anything other than `None` will override the docstring, so that you can customize the docstrings if needed. +Setting the variable to anything other than `None` overrides the docstring with your custom value. ## Special naming -The linter automatically renames everything when inheriting from a class. For consistency, you should always use the same class name prefix when inheriting from different classes from the same file. 
+The linter automatically renames everything when inheriting from a class. Use the same class name prefix across all classes in the same file. -The example below is not recommended. It breaks standards in the library, `MyModelIncredibleMLP` instead of `LlamaMLP`, and because the linter doesn't know how to rename potential higher-order dependencies (`MyModelIncredible` or just `MyModel`). +Avoid mixing prefixes like in the example below. `MyModelIncredibleMLP` breaks naming conventions, and the linter won't know whether to use `MyModelIncredible` or `MyModel` when renaming higher-order dependencies. ```py class MyModelIncredibleMLP(LlamaMLP): @@ -563,15 +443,15 @@ class MyModelDecoderLayer(LlamaDecoderLayer): ... ``` -However, if there aren't any [implicit dependencies](#other-classes), then you can locally rename a single class. Make sure you still explicitly redefine every other mention of the class with the new name pattern though. For example, all mentions of `LlamaMLP` should be renamed to `MyModelIncredibleMLP` otherwise the linter may add a new and unwanted `MyModelMLP` class. +With no [implicit dependencies](#other-classes), you can rename a single class locally. Explicitly redefine every other mention of that class with the new name pattern. Otherwise, the linter adds an unwanted `MyModelMLP` class alongside `MyModelIncredibleMLP`. -The linter raises a warning if an ambiguous case is detected. It explains what is happening and which prefix is used by default for getting the dependencies. These warning and renaming pattern complications usually only come up when defining multimodal models. For example, adding `Text` to class names in a multimodal model to make it clear which modality it refers to. +The linter raises a warning when it detects an ambiguous prefix. -```py +```text We detected multiple prefix names when inheriting from transformers.models.llama.modeling_llama: ('Emu3Text', 'Emu3'). We will only use the most used 'Emu3' prefix when grabbing args and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different from 'Emu3') or use a single prefix in all the modular (best). ``` -If there are automatic dependencies with a prefix, but you want another one, explicitly rename the classes locally with a `pass` class as shown in the following. +This most commonly comes up in multimodal models where class names include a modality qualifier like `Text`. If you want a dependency to use a specific prefix, explicitly rename it with a `pass`. ```py class Emu3TextMLP(LlamaMLP): @@ -580,8 +460,8 @@ class Emu3TextMLP(LlamaMLP): ## Config docstrings -When inheriting a `Config` class or adding and deleting attributes, you may want to only redefine the new attributes in the docstring. However, the linter doesn't support this yet. You need to directly add the while docstring directly in the modular file under the class definition. +The linter doesn't support partial docstring inheritance yet. When adding or removing config attributes, add the full docstring directly in the modular file under the class definition. ## See also -- [Model structure rules](./modeling_rules) — static rules enforced on all `modeling_*.py`, `modular_*.py`, and `configuration_*.py` files. Run `make typing` to check them before opening a PR. +- [Model structure rules](./modeling_rules) are static rules enforced on all `modeling_*.py`, `modular_*.py`, and `configuration_*.py` files. Run `make typing` to check them before opening a PR.