Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/transformers/utils/output_capturing.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def reset(self, token):
_active_collector = CompileableContextVar("output_collector")


def install_output_capuring_hook(module: nn.Module, key: str, index: int) -> None:
def install_output_capturing_hook(module: nn.Module, key: str, index: int) -> None:
"""Install the forward hook needed to capture the output described by `key` and `index` in `module`."""

def output_capturing_hook(module, args, output):
Expand Down Expand Up @@ -144,12 +144,12 @@ def recursively_install_hooks(
):
if specs.layer_name is not None and specs.layer_name not in module_name:
continue
install_output_capuring_hook(parent_module, key, specs.index)
install_output_capturing_hook(parent_module, key, specs.index)


def install_all_output_capturing_hooks(model: PreTrainedModel, prefix: str | None = None) -> None:
"""
Install the output recording hooks on all the modules in `model`. Tis will take care of correctly dispatching
Install the output recording hooks on all the modules in `model`. This will take care of correctly dispatching
the `_can_record_outputs` property of each individual submodels in case of composite models.
"""
# _can_record_outputs is None by default
Expand Down
Loading