diff --git a/Makefile b/Makefile index 5a2bd3f7301d..caffe9403928 100644 --- a/Makefile +++ b/Makefile @@ -37,6 +37,7 @@ check-repository-consistency: copies,\ modular_conversion,\ doc_toc,\ + modeling_rules_doc,\ docstrings,\ dummies,\ repo,\ @@ -63,6 +64,7 @@ check-repo: copies,\ modular_conversion,\ doc_toc,\ + modeling_rules_doc,\ docstrings,\ dummies,\ repo,\ @@ -84,6 +86,7 @@ fix-repo: init_isort,\ auto_mappings,\ doc_toc,\ + modeling_rules_doc,\ copies,\ modular_conversion,\ dummies,\ diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0a718ba6d512..d7aa78b531db 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -403,6 +403,8 @@ title: Contribute to Transformers - local: testing title: Transformers model tests + - local: modeling_rules + title: Model structure rules - local: pr_checks title: Pull request checks title: Contribute diff --git a/docs/source/en/add_new_model.md b/docs/source/en/add_new_model.md index f23117b0050f..60aa7f592ad1 100644 --- a/docs/source/en/add_new_model.md +++ b/docs/source/en/add_new_model.md @@ -659,3 +659,8 @@ There are four timelines for model additions depending on the model contributor - **Hub-first release**: Transformers [remote-code](./models#custom-models) feature allows Transformers-based projects to be shared directly on the Hub. This is a good option if you don't have the bandwidth to add a model directly to Transformers. If a model ends up being very popular, then it's very likely that we'll integrate it in Transformers ourselves to enable better support (documentation, maintenance, optimization, etc.) for it. A Hub-first release is the most frictionless way to add a model. + +## See also + +- [Model structure rules](./modeling_rules) — static rules enforced on all `modeling_*.py` and `configuration_*.py` files. Run `make typing` to check them before opening a PR. +- [Pull request checks](./pr_checks) — full reference for what CI checks run on your PR and how to pass them. diff --git a/docs/source/en/modeling_rules.md b/docs/source/en/modeling_rules.md new file mode 100644 index 000000000000..d93b48c4024c --- /dev/null +++ b/docs/source/en/modeling_rules.md @@ -0,0 +1,243 @@ + + +# Model structure rules + +Transformers enforces a set of static rules on every `modeling_*.py`, `modular_*.py`, and `configuration_*.py` file. The [mlinter](https://github.com/huggingface/transformers/tree/main/utils/mlinter) tool checks them as part of `make typing` and errors out if violations are found. + +These are the expected model conventions for adding or changing modeling code. They keep the codebase consistent and ensure compatibility with features like pipeline parallelism, device maps, and weight tying. + +## Running the checker + +`make typing` runs `mlinter` alongside the `ty` type checker. Run `mlinter` on its own with the following commands. + +```bash +python -m utils.mlinter # check all modeling files +python -m utils.mlinter --changed-only # check only files changed vs origin/main +python -m utils.mlinter --list-rules # list all rules and their enabled status +python -m utils.mlinter --rule TRF001 # show built-in docs for a specific rule +``` + +The `--changed-only` flag is the fastest option during development. It only checks the files you've modified relative to the main branch. + +## Fixing a violation + +When a rule violation is detected, the error looks like this: + +``` +src/transformers/models/acme/modeling_acme.py:18: TRF013: AcmeModel.__init__ does not call self.post_init(). +``` + +Use the rule ID to look up the fix in the [rules reference](#rules-reference). TRF013 is triggered when a [`PreTrainedModel`] subclass doesn't call `self.post_init()`. That method performs essential finalization steps, and omitting it causes runtime bugs. + +```diff + class AcmeModel(AcmePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList( + [AcmeDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) ++ self.post_init() +``` + +## Rules reference + +Each rule below lists what it enforces and a diff showing the fix. Run `python -m utils.mlinter --rule TRF001` to see the built-in docs for any rule. + + + +### TRF001 + +Checks naming consistency between PreTrainedModel and config_class. Mismatched config_class can break loading, auto classes, and developer expectations. + +```diff +class AcmePreTrainedModel(PreTrainedModel): +- config_class = WileConfig ++ config_class = AcmeConfig +``` + +### TRF002 + +Checks that base_model_prefix, when set, is a non-empty, whitespace-free string literal. Invalid prefixes can break weight loading key mapping and base model access patterns. + +```diff +class AcmePreTrainedModel(PreTrainedModel): +- base_model_prefix = "" ++ base_model_prefix = "model" +``` + +### TRF003 + +Detects forward methods that use the old 'if not return_dict: return (x,)' pattern. The old return_dict branching pattern is error-prone and verbose. Use the capture_output or can_return_tuple decorators instead. + +```diff +-def forward(self, x, return_dict=None): +- if not return_dict: +- return (x,) +- return AcmeModelOutput(last_hidden_state=x) ++@can_return_tuple ++def forward(self, x): ++ return AcmeModelOutput(last_hidden_state=x) +``` + +### TRF004 + +Checks that no model class defines a tie_weights method. Overriding tie_weights leads to bad consequences for loading, device_map computation, and saving. Use _tied_weights_keys class attribute to declare tied weights instead. + +```diff +-def tie_weights(self): +- self.lm_head.weight = self.emb.weight ++class AcmeForCausalLM(AcmePreTrainedModel): ++ _tied_weights_keys = ["lm_head.weight"] +``` + +### TRF005 + +Checks the shape of _no_split_modules when present. Malformed values can break device-map partitioning and sharding behavior. + +```diff +-_no_split_modules = [SomeLayerClass, ""] ++_no_split_modules = ["AcmeDecoderLayer", "AcmeAttention"] +``` + +### TRF006 + +Checks forward signatures that expose cache arguments for usage of those arguments in method body. Unused cache arguments can indicate incomplete caching support and inconsistent API behavior. + +```diff +def forward(self, x, past_key_values=None, use_cache=False): ++ if use_cache: ++ ... + return x +``` + +### TRF007 + +Checks for self attribute assignments after self.post_init() in __init__. Mutating model structure after post_init can bypass intended initialization/finalization logic. + +```diff +def __init__(self, config): + ... +- self.post_init() +- self.proj = nn.Linear(...) ++ self.proj = nn.Linear(...) ++ self.post_init() +``` + +### TRF008 + +Checks add_start_docstrings usage on model classes for non-empty docstring arguments. Empty decorator usage produces unclear docs and weakens generated API documentation quality. + +```diff +-@add_start_docstrings("") ++@add_start_docstrings("The Acme model.") + class AcmeModel(AcmePreTrainedModel): + ... +``` + +### TRF009 + +Checks modeling files for cross-model imports such as transformers.models.other_model.* or from ..other_model.* imports. Cross-model implementation imports violate the single-file policy and make model behavior harder to inspect and maintain. + +```diff +-from transformers.models.llama.modeling_llama import LlamaAttention ++# Keep implementation local to this file. ++# If reusing code, copy it with a # Copied from comment. +``` + +### TRF010 + +Checks direct PreTrainedConfig/PretrainedConfig subclasses in configuration_*.py and modular_*.py for an explicit @strict(accept_kwargs=True) decorator. Without strict, new config classes miss the repo's runtime type-validation contract and drift from the dataclass-based config standard. + +```diff ++@strict(accept_kwargs=True) + class AcmeConfig(PreTrainedConfig): + ... +``` + +### TRF011 + +In forward() methods of PreTrainedModel subclasses, checks for attribute accesses on submodules that would not exist on torch.nn.Identity. This includes attribute accesses on loop variables iterating over self.layers, and self.. chains where is not a standard nn.Module attribute. Pipeline parallelism may replace any submodule with torch.nn.Identity. Accessing custom attributes (e.g. decoder_layer.attention_type) on a replaced module raises AttributeError at runtime. Per-layer metadata should be read from self.config instead. + +```diff +def forward(self, ...): +- for decoder_layer in self.layers: ++ for i, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, +- attention_mask=causal_mask_mapping[decoder_layer.attention_type], ++ attention_mask=causal_mask_mapping[self.config.layer_types[i]], + ) +``` + +### TRF012 + +Checks that _init_weights(self, module) does not use in-place operations (e.g. .normal_(), .zero_()) directly on module weights. We rely on internal flags set on parameters to track whether they need re-initialization. In-place ops bypass this mechanism. Use the `init` primitives instead. + +```diff ++from transformers import initialization as init ++ + def _init_weights(self, module): +- module.weight.normal_(mean=0.0, std=0.02) ++ init.normal_(module.weight, mean=0.0, std=0.02) +``` + +### TRF013 + +Checks that every PreTrainedModel subclass with an __init__ method calls self.post_init(). In modular files, calling super().__init__() is also accepted since it propagates post_init from the parent. post_init performs essential finalization (weight initialization, gradient checkpointing setup, etc.). Omitting it causes subtle runtime bugs. + +```diff +class AcmeModel(AcmePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList(...) ++ self.post_init() +``` + +### TRF014 + +Checks whether `trust_remote_code` is passed or used in code (e.g. as kwarg) within native model integration files. `trust_remote_code` allows arbitrary loading, including binaries, which should only be a power feature for users, not a standard use-case. Native integrations must not depend on it, as remote code cannot be reviewed or maintained within transformers. + +```diff +class AcmeModel(AcmePreTrainedModel): + def __init__(self, config): + super().__init__(config) +- self.model = AutoModel.from_pretrained(..., trust_remote_code=True) ++ self.model = AutoModel.from_pretrained(...) +``` + + + +## Suppressing violations + +If you need to suppress a rule violation, use one of the two options below. + +### Inline suppression + +Add a `# trf-ignore: RULE_ID` comment on the violating line. Include an explanation so reviewers understand why the suppression is justified. + +```py +# trf-ignore: TRF011 — mask is derived from self.config, not the layer +hidden_states = layer(hidden_states, attention_mask=mask_from_config) +``` + +Don't use `trf-ignore` to silence violations that should be fixed in the code. + +### `allowlist_models` + +For models with legacy code that can't be fixed immediately, add the model's directory name to the relevant rule's `allowlist_models` list in `utils/mlinter/rules.toml`. + +```toml +[rules.TRF004] +allowlist_models = ["existing_model", "your_model_name"] +``` diff --git a/docs/source/en/modular_transformers.md b/docs/source/en/modular_transformers.md index 7bda0075e0d3..620e877d971d 100644 --- a/docs/source/en/modular_transformers.md +++ b/docs/source/en/modular_transformers.md @@ -581,3 +581,7 @@ class Emu3TextMLP(LlamaMLP): ## Config docstrings When inheriting a `Config` class or adding and deleting attributes, you may want to only redefine the new attributes in the docstring. However, the linter doesn't support this yet. You need to directly add the while docstring directly in the modular file under the class definition. + +## See also + +- [Model structure rules](./modeling_rules) — static rules enforced on all `modeling_*.py`, `modular_*.py`, and `configuration_*.py` files. Run `make typing` to check them before opening a PR. diff --git a/utils/check_modeling_rules_doc.py b/utils/check_modeling_rules_doc.py new file mode 100644 index 000000000000..14fc12e070ed --- /dev/null +++ b/utils/check_modeling_rules_doc.py @@ -0,0 +1,100 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Keep `## Rules reference` section ofdocs/source/en/modeling_rules.m in sync +with utils/mlinter/rules.toml. + +Usage (from the root of the repo): + +Check everything is up to date (used in ``make check-repo``): + +```bash +python utils/check_modeling_rules_doc.py +``` + +Auto-regenerate if out of date (used in ``make fix-repo``): + +```bash +python utils/check_modeling_rules_doc.py --fix_and_overwrite +``` +""" + +import argparse +import os +import sys + + +CHECKER_CONFIG = { + "name": "modeling_rules_doc", + "label": "Modeling rules documentation", + "file_globs": ["utils/mlinter/rules.toml", "docs/source/en/modeling_rules.md"], + "check_args": [], + "fix_args": ["--fix_and_overwrite"], +} + +ROOT = os.path.dirname(os.path.dirname(__file__)) +DOC_PATH = os.path.join(ROOT, "docs", "source", "en", "modeling_rules.md") + +BEGIN_MARKER = "" +END_MARKER = "" + + +sys.path.insert(0, ROOT) +from utils.mlinter.mlinter import TRF_RULE_SPECS, format_rule_details # noqa: E402 + + +def generate_rules_reference() -> str: + sections = [] + for rule_id in sorted(TRF_RULE_SPECS): + sections.append(format_rule_details(rule_id)) + return "\n\n".join(sections) + "\n" + + +def check_modeling_rules_doc(overwrite: bool = False): + with open(DOC_PATH, encoding="utf-8") as f: + content = f.read() + + begin_idx = content.find(BEGIN_MARKER) + end_idx = content.find(END_MARKER) + if begin_idx == -1 or end_idx == -1: + raise ValueError( + f"Could not find {BEGIN_MARKER} and {END_MARKER} markers in {DOC_PATH}. " + "These markers delimit the auto-generated rules reference section." + ) + + after_begin = begin_idx + len(BEGIN_MARKER) + expected = "\n\n" + generate_rules_reference() + "\n" + current = content[after_begin:end_idx] + + if current == expected: + return + + if overwrite: + new_content = content[:after_begin] + expected + content[end_idx:] + with open(DOC_PATH, "w", encoding="utf-8") as f: + f.write(new_content) + print(f"Updated rules reference in {DOC_PATH}") + else: + raise ValueError( + "The rules reference section in docs/source/en/modeling_rules.md is out of sync " + "with utils/mlinter/rules.toml. Run `make fix-repo` to regenerate it." + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") + args = parser.parse_args() + + check_modeling_rules_doc(args.fix_and_overwrite) diff --git a/utils/mlinter/mlinter.py b/utils/mlinter/mlinter.py index 443379b8f2d8..4fc1a3e2740c 100644 --- a/utils/mlinter/mlinter.py +++ b/utils/mlinter/mlinter.py @@ -45,7 +45,7 @@ def _load_rule_specs() -> dict[str, dict]: if not isinstance(rules, dict): raise ValueError(f"Invalid rule spec file: missing [rules] table in {RULE_SPECS_PATH}") - required_explanation_keys = {"what_it_does", "why_bad", "bad_example", "good_example"} + required_explanation_keys = {"what_it_does", "why_bad", "diff"} specs: dict[str, dict] = {} for rule_id, spec in rules.items(): if not isinstance(spec, dict): @@ -332,29 +332,15 @@ def format_rule_summary(rule_id: str) -> str: def format_rule_details(rule_id: str) -> str: spec = TRF_RULE_SPECS[rule_id] explanation = spec["explanation"] - default_label = "yes" if spec["default_enabled"] else "no" return "\n".join( [ - rule_id, + f"### {rule_id}", "", - f"Summary: {spec['description']}", - f"Default enabled: {default_label}", + f"{explanation['what_it_does']} {explanation['why_bad']}", "", - "What it does", - "", - explanation["what_it_does"], - "", - "Why is this bad?", - "", - explanation["why_bad"], - "", - "Example", - "", - explanation["bad_example"], - "", - "Use instead:", - "", - explanation["good_example"], + "```diff", + explanation["diff"].strip(), + "```", ] ) diff --git a/utils/mlinter/rules.toml b/utils/mlinter/rules.toml index 4294f53f3e14..80c31e1654af 100644 --- a/utils/mlinter/rules.toml +++ b/utils/mlinter/rules.toml @@ -6,13 +6,10 @@ allowlist_models = ["qwen3_omni_moe"] [rules.TRF001.explanation] what_it_does = "Checks naming consistency between PreTrainedModel and config_class." why_bad = "Mismatched config_class can break loading, auto classes, and developer expectations." -bad_example = ''' -class FooPreTrainedModel(PreTrainedModel): - config_class = BarConfig -''' -good_example = ''' -class FooPreTrainedModel(PreTrainedModel): - config_class = FooConfig +diff = ''' + class AcmePreTrainedModel(PreTrainedModel): +- config_class = WileConfig ++ config_class = AcmeConfig ''' [rules.TRF002] @@ -23,13 +20,10 @@ allowlist_models = ["lighton_ocr"] [rules.TRF002.explanation] what_it_does = "Checks that base_model_prefix, when set, is a non-empty, whitespace-free string literal." why_bad = "Invalid prefixes can break weight loading key mapping and base model access patterns." -bad_example = ''' -class FooPreTrainedModel(PreTrainedModel): - base_model_prefix = "" -''' -good_example = ''' -class FooPreTrainedModel(PreTrainedModel): - base_model_prefix = "model" +diff = ''' + class AcmePreTrainedModel(PreTrainedModel): +- base_model_prefix = "" ++ base_model_prefix = "model" ''' [rules.TRF003] @@ -40,16 +34,14 @@ allowlist_models = [] [rules.TRF003.explanation] what_it_does = "Detects forward methods that use the old 'if not return_dict: return (x,)' pattern." why_bad = "The old return_dict branching pattern is error-prone and verbose. Use the capture_output or can_return_tuple decorators instead." -bad_example = ''' -def forward(self, x, return_dict=None): - if not return_dict: - return (x,) - return FooOutput(last_hidden_state=x) -''' -good_example = ''' -@can_return_tuple -def forward(self, x): - return FooOutput(last_hidden_state=x) +diff = ''' +-def forward(self, x, return_dict=None): +- if not return_dict: +- return (x,) +- return AcmeModelOutput(last_hidden_state=x) ++@can_return_tuple ++def forward(self, x): ++ return AcmeModelOutput(last_hidden_state=x) ''' [rules.TRF004] @@ -60,13 +52,11 @@ allowlist_models = ["data2vec", "hubert", "sew", "sew_d", "unispeech", "unispeec [rules.TRF004.explanation] what_it_does = "Checks that no model class defines a tie_weights method." why_bad = "Overriding tie_weights leads to bad consequences for loading, device_map computation, and saving. Use _tied_weights_keys class attribute to declare tied weights instead." -bad_example = ''' -def tie_weights(self): - self.lm_head.weight = self.emb.weight -''' -good_example = ''' -class FooForCausalLM(FooPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] +diff = ''' +-def tie_weights(self): +- self.lm_head.weight = self.emb.weight ++class AcmeForCausalLM(AcmePreTrainedModel): ++ _tied_weights_keys = ["lm_head.weight"] ''' [rules.TRF005] @@ -77,11 +67,9 @@ allowlist_models = ["d_fine", "deformable_detr", "glm46v", "lw_detr", "pp_doclay [rules.TRF005.explanation] what_it_does = "Checks the shape of _no_split_modules when present." why_bad = "Malformed values can break device-map partitioning and sharding behavior." -bad_example = ''' -_no_split_modules = [SomeLayerClass, ""] -''' -good_example = ''' -_no_split_modules = ["FooDecoderLayer", "FooAttention"] +diff = ''' +-_no_split_modules = [SomeLayerClass, ""] ++_no_split_modules = ["AcmeDecoderLayer", "AcmeAttention"] ''' [rules.TRF006] @@ -92,15 +80,11 @@ allowlist_models = ["chinese_clip", "evolla", "idefics2", "llama4"] [rules.TRF006.explanation] what_it_does = "Checks forward signatures that expose cache arguments for usage of those arguments in method body." why_bad = "Unused cache arguments can indicate incomplete caching support and inconsistent API behavior." -bad_example = ''' -def forward(self, x, past_key_values=None, use_cache=False): - return x -''' -good_example = ''' -def forward(self, x, past_key_values=None, use_cache=False): - if use_cache: - ... - return x +diff = ''' + def forward(self, x, past_key_values=None, use_cache=False): ++ if use_cache: ++ ... + return x ''' [rules.TRF007] @@ -111,17 +95,13 @@ allowlist_models = ["distilbert", "lxmert", "mt5", "pix2struct", "pop2piano", "s [rules.TRF007.explanation] what_it_does = "Checks for self attribute assignments after self.post_init() in __init__." why_bad = "Mutating model structure after post_init can bypass intended initialization/finalization logic." -bad_example = ''' -def __init__(self, config): - ... - self.post_init() - self.proj = nn.Linear(...) -''' -good_example = ''' -def __init__(self, config): - ... - self.proj = nn.Linear(...) - self.post_init() +diff = ''' + def __init__(self, config): + ... +- self.post_init() +- self.proj = nn.Linear(...) ++ self.proj = nn.Linear(...) ++ self.post_init() ''' [rules.TRF008] @@ -131,15 +111,11 @@ default_enabled = true [rules.TRF008.explanation] what_it_does = "Checks add_start_docstrings usage on model classes for non-empty docstring arguments." why_bad = "Empty decorator usage produces unclear docs and weakens generated API documentation quality." -bad_example = ''' -@add_start_docstrings("") -class FooModel(FooPreTrainedModel): - ... -''' -good_example = ''' -@add_start_docstrings("The Foo model.") -class FooModel(FooPreTrainedModel): - ... +diff = ''' +-@add_start_docstrings("") ++@add_start_docstrings("The Acme model.") + class AcmeModel(AcmePreTrainedModel): + ... ''' [rules.TRF009] @@ -150,12 +126,10 @@ allowlist_models = ["dpr", "maskformer", "sam3_video", "vision_text_dual_encoder [rules.TRF009.explanation] what_it_does = "Checks modeling files for cross-model imports such as transformers.models.other_model.* or from ..other_model.* imports." why_bad = "Cross-model implementation imports violate the single-file policy and make model behavior harder to inspect and maintain." -bad_example = ''' -from transformers.models.llama.modeling_llama import LlamaAttention -''' -good_example = ''' -# Keep implementation local in this modeling file. -# If code is reused, copy it with an appropriate # Copied from ... statement. +diff = ''' +-from transformers.models.llama.modeling_llama import LlamaAttention ++# Keep implementation local to this file. ++# If reusing code, copy it with a # Copied from comment. ''' [rules.TRF010] @@ -166,14 +140,10 @@ allowlist_models = ["nemotron_h", "vibevoice_asr"] [rules.TRF010.explanation] what_it_does = "Checks direct PreTrainedConfig/PretrainedConfig subclasses in configuration_*.py and modular_*.py for an explicit @strict(accept_kwargs=True) decorator." why_bad = "Without strict, new config classes miss the repo's runtime type-validation contract and drift from the dataclass-based config standard." -bad_example = ''' -class FooConfig(PreTrainedConfig): - ... -''' -good_example = ''' -@strict(accept_kwargs=True) -class FooConfig(PreTrainedConfig): - ... +diff = ''' ++@strict(accept_kwargs=True) + class AcmeConfig(PreTrainedConfig): + ... ''' [rules.TRF011] @@ -184,21 +154,15 @@ allowlist_models = [] [rules.TRF011.explanation] what_it_does = "In forward() methods of PreTrainedModel subclasses, checks for attribute accesses on submodules that would not exist on torch.nn.Identity. This includes attribute accesses on loop variables iterating over self.layers, and self.. chains where is not a standard nn.Module attribute." why_bad = "Pipeline parallelism may replace any submodule with torch.nn.Identity. Accessing custom attributes (e.g. decoder_layer.attention_type) on a replaced module raises AttributeError at runtime. Per-layer metadata should be read from self.config instead." -bad_example = ''' -def forward(self, ...): - for decoder_layer in self.layers: - hidden_states = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - ) -''' -good_example = ''' -def forward(self, ...): - for i, decoder_layer in enumerate(self.layers): - hidden_states = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[self.config.layer_types[i]], - ) +diff = ''' + def forward(self, ...): +- for decoder_layer in self.layers: ++ for i, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + hidden_states, +- attention_mask=causal_mask_mapping[decoder_layer.attention_type], ++ attention_mask=causal_mask_mapping[self.config.layer_types[i]], + ) ''' [rules.TRF012] @@ -209,15 +173,12 @@ allowlist_models = [] [rules.TRF012.explanation] what_it_does = "Checks that _init_weights(self, module) does not use in-place operations (e.g. .normal_(), .zero_()) directly on module weights." why_bad = "We rely on internal flags set on parameters to track whether they need re-initialization. In-place ops bypass this mechanism. Use the `init` primitives instead." -bad_example = ''' -def _init_weights(self, module): - module.weight.normal_(mean=0.0, std=0.02) -''' -good_example = ''' -from transformers import initialization as init - -def _init_weights(self, module): - init.normal_(module.weight, mean=0.0, std=0.02) +diff = ''' ++from transformers import initialization as init ++ + def _init_weights(self, module): +- module.weight.normal_(mean=0.0, std=0.02) ++ init.normal_(module.weight, mean=0.0, std=0.02) ''' [rules.TRF013] @@ -228,18 +189,12 @@ allowlist_models = [] [rules.TRF013.explanation] what_it_does = "Checks that every PreTrainedModel subclass with an __init__ method calls self.post_init(). In modular files, calling super().__init__() is also accepted since it propagates post_init from the parent." why_bad = "post_init performs essential finalization (weight initialization, gradient checkpointing setup, etc.). Omitting it causes subtle runtime bugs." -bad_example = ''' -class FooModel(FooPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.layers = nn.ModuleList(...) -''' -good_example = ''' -class FooModel(FooPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.layers = nn.ModuleList(...) - self.post_init() +diff = ''' + class AcmeModel(AcmePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList(...) ++ self.post_init() ''' [rules.TRF014] @@ -250,15 +205,10 @@ allowlist_models = [] [rules.TRF014.explanation] what_it_does = "Checks whether `trust_remote_code` is passed or used in code (e.g. as kwarg) within native model integration files." why_bad = "`trust_remote_code` allows arbitrary loading, including binaries, which should only be a power feature for users, not a standard use-case. Native integrations must not depend on it, as remote code cannot be reviewed or maintained within transformers." -bad_example = ''' -class FooModel(FooPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.model = AutoModel.from_pretrained(..., trust_remote_code=True) -''' -good_example = ''' -class FooModel(FooPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.model = AutoModel.from_pretrained(...) +diff = ''' + class AcmeModel(AcmePreTrainedModel): + def __init__(self, config): + super().__init__(config) +- self.model = AutoModel.from_pretrained(..., trust_remote_code=True) ++ self.model = AutoModel.from_pretrained(...) '''