Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/huggingface_hub/_snapshot_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
)
from .file_download import REGEX_COMMIT_HASH, DryRunFileInfo, hf_hub_download, repo_folder_name
from .hf_api import DatasetInfo, HfApi, ModelInfo, RepoFile, SpaceInfo
from .utils import OfflineModeIsEnabled, filter_repo_objects, is_tqdm_disabled, logging, validate_hf_hub_args
from .utils import tqdm as hf_tqdm
from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
from .utils.tqdm import _create_progress_bar
from .utils.tqdm import tqdm as hf_tqdm


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -385,14 +386,15 @@ def snapshot_download(
# Create a progress bar for the bytes downloaded
# This progress bar is shared across threads/files and gets updated each time we fetch
# metadata for a file.
bytes_progress = tqdm_class(
bytes_progress = _create_progress_bar(
cls=tqdm_class,
log_level=logger.getEffectiveLevel(),
name="huggingface_hub.snapshot_download",
desc="Downloading (incomplete total...)",
disable=is_tqdm_disabled(log_level=logger.getEffectiveLevel()),
total=0,
initial=0,
unit="B",
unit_scale=True,
name="huggingface_hub.snapshot_download",
)

class _AggregatedTqdm:
Expand Down
33 changes: 30 additions & 3 deletions src/huggingface_hub/utils/tqdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,32 @@ def _inner_read(size: int | None = -1) -> bytes:
pbar.close()


def _create_progress_bar(*, cls: type[old_tqdm], log_level: int, name: str | None = None, **kwargs) -> old_tqdm:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added this private helper to centralize how we initialize the tqdm progress bar. is_tqdm_disabled is no longer needed internally but kept it since it's exposed in huggingface_hub.utils module

"""Create a progress bar.

For our `tqdm` subclass (or subclasses of it): respects all disable signals
(`HF_HUB_DISABLE_PROGRESS_BARS`, `disable_progress_bars()`, log level) and uses
`disable=None` for TTY auto-detection (see https://github.com/huggingface/huggingface_hub/pull/2000),
unless `TQDM_POSITION=-1` forces bars on (https://github.com/huggingface/huggingface_hub/pull/2698).

For other classes: does not inject `disable` or `name`. the custom class is fully
responsible for its own behavior. Vanilla tqdm defaults to `disable=False` (bar shows).
Omits `name` which vanilla tqdm rejects with `TqdmKeyError`. See https://github.com/huggingface/huggingface_hub/issues/4050.
"""
# issubclass() crashes on non-class callables (e.g. functools.partial), guard with isinstance.
if not (isinstance(cls, type) and issubclass(cls, tqdm)):
return cls(**kwargs) # type: ignore[return-value]

# HF subclass: apply all disable signals + TTY auto-detection.
if are_progress_bars_disabled(name) or log_level == logging.NOTSET:
disable: bool | None = True
elif os.getenv("TQDM_POSITION") == "-1":
disable = False
else:
disable = None
return cls(disable=disable, name=name, **kwargs) # type: ignore[return-value]
Copy link
Copy Markdown

@tobocop2 tobocop2 Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_create_progress_bar should wrap instantiation in try/except (OSError, ValueError), log a warning, and fall back to a no-op. A progress bar failing to init shouldn't crash the download.

See #4056 (comment)



def _get_progress_bar_context(
*,
desc: str,
Expand All @@ -297,12 +323,13 @@ def _get_progress_bar_context(
# Makes it easier to use the same code path for both cases but in the later
# case, the progress bar is not closed when exiting the context manager.

return (tqdm_class or tqdm)( # type: ignore
return _create_progress_bar( # type: ignore
cls=tqdm_class or tqdm,
log_level=log_level,
name=name,
unit=unit,
unit_scale=unit_scale,
total=total,
initial=initial,
desc=desc,
disable=is_tqdm_disabled(log_level=log_level),
name=name,
)
51 changes: 51 additions & 0 deletions tests/test_utils_tqdm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import io
import logging
import sys
import time
import unittest
from pathlib import Path
from unittest.mock import patch

import pytest
from pytest import CaptureFixture
from tqdm.auto import tqdm as vanilla_tqdm

from huggingface_hub.utils import (
SoftTemporaryDirectory,
Expand All @@ -14,6 +18,7 @@
tqdm,
tqdm_stream_file,
)
from huggingface_hub.utils.tqdm import _get_progress_bar_context


class CapsysBaseTest(unittest.TestCase):
Expand Down Expand Up @@ -235,3 +240,49 @@ def test_progress_bar_respects_group(self) -> None:
captured = self.capsys.readouterr()
assert captured.out == ""
assert "10/10" in captured.err


class TestCreateProgressBarCustomClass:
"""Regression tests for https://github.com/huggingface/huggingface_hub/issues/4050."""

def test_custom_tqdm_class_not_disabled_in_non_tty(self):
"""Custom tqdm_class should not be silently disabled in non-TTY."""
fake_stderr = io.StringIO()
with patch.object(sys, "stderr", fake_stderr):
bar = _get_progress_bar_context(
desc="test",
log_level=logging.INFO,
total=100,
tqdm_class=vanilla_tqdm,
name="huggingface_hub.test",
)
with bar as pbar:
assert not pbar.disable
pbar.update(50)
pbar.update(50)
assert pbar.n == 100

@patch("huggingface_hub.utils._tqdm.HF_HUB_DISABLE_PROGRESS_BARS", True)
def test_custom_tqdm_class_ignores_hf_disable_signal(self):
"""Custom tqdm_class is not affected by HF_HUB_DISABLE_PROGRESS_BARS."""
bar = _get_progress_bar_context(
desc="test",
log_level=logging.INFO,
total=10,
tqdm_class=vanilla_tqdm,
name="huggingface_hub.test",
)
with bar as pbar:
assert not pbar.disable

def test_custom_tqdm_class_no_name_kwarg(self):
"""Custom tqdm_class should not receive HF-specific 'name' kwarg."""
bar = _get_progress_bar_context(
desc="test",
log_level=logging.INFO,
total=10,
tqdm_class=vanilla_tqdm,
name="huggingface_hub.test",
)
with bar as pbar:
pbar.update(10)
Loading