diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 97d6b94b57aa..8388a373a3e5 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -26,6 +26,7 @@ from huggingface_hub import create_repo from huggingface_hub.dataclasses import strict from packaging import version +from typing_extensions import dataclass_transform from . import __version__ from .dynamic_module_utils import custom_object_save @@ -75,6 +76,8 @@ # copied from huggingface_hub.dataclasses.strict when `accept_kwargs=True` def wrap_init_to_accept_kwargs(cls: dataclass): + + # Get the original dataclass-generated __init__ original_init = cls.__init__ @wraps(original_init) @@ -113,6 +116,7 @@ def __init__(self, *args, **kwargs: Any) -> None: return cls +@dataclass_transform(kw_only_default=True) @strict(accept_kwargs=True) @dataclass(repr=False) class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin):