From 6a8f8a3314ee6800a19f49116b38ef40d6762c0e Mon Sep 17 00:00:00 2001 From: Tobias Perelstein <5562156+tobocop2@users.noreply.github.com> Date: Mon, 6 Apr 2026 12:58:35 -0400 Subject: [PATCH] fix: use fine-grained xet-core callback for smoother tqdm progress Switch xet_get() from a 1-arg to 2-arg callback so xet-core reports progress frequently instead of barely at all. Fixes #4058 --- src/huggingface_hub/file_download.py | 63 ++++++- src/huggingface_hub/hf_api.py | 10 +- tests/test_xet_download.py | 265 +++++++++++++++++++++++++++ 3 files changed, 326 insertions(+), 12 deletions(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index ccfb19be1d..bfd167d540 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -7,9 +7,10 @@ import time import uuid import warnings +from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import Any, BinaryIO, Literal, NoReturn, overload +from typing import TYPE_CHECKING, Any, BinaryIO, Literal, NoReturn, Protocol, overload from urllib.parse import quote, urlparse import httpx @@ -56,8 +57,62 @@ from .utils.tqdm import _get_progress_bar_context +if TYPE_CHECKING: + from hf_xet import PyItemProgressUpdate, PyTotalProgressUpdate + + logger = logging.get_logger(__name__) + +class ProgressBar(Protocol): + """Any object with a tqdm-compatible ``update(n)`` method.""" + + def update(self, n: int) -> None: ... + + +def make_xet_progress_callback( + progress_bar: ProgressBar, file_size: int | None +) -> Callable[["PyTotalProgressUpdate", "list[PyItemProgressUpdate]"], None]: + """Create a xet-core progress callback that scales network transfer to *file_size*. + + Each callback tracks its own cumulative bytes so multiple callbacks can + safely share a single progress bar (e.g. multi-file bucket downloads). + """ + _cumulative_transfer = 0 + _last_contributed = 0 + + def _callback( + total_update: "PyTotalProgressUpdate", + item_updates: "list[PyItemProgressUpdate]", + ) -> None: + nonlocal _cumulative_transfer, _last_contributed + increment = total_update.total_transfer_bytes_completion_increment + if increment <= 0: + return + + _cumulative_transfer += increment + transfer_total = total_update.total_transfer_bytes + + # Indeterminate bar (unknown file size): pass through raw bytes + if file_size is None: + progress_bar.update(int(increment)) + return + + # File size known but transfer total not yet reported: skip to avoid + # injecting unscaled bytes into a determinate bar + if not transfer_total or transfer_total <= 0: + return + + # Scale network transfer fraction to file size + contributed = min(round(_cumulative_transfer / transfer_total * file_size), file_size) + advance = contributed - _last_contributed + if advance > 0: + _last_contributed = contributed + progress_bar.update(advance) + + return _callback + + # Return value when trying to load a file from cache but the file does not exist in the distant repo. _CACHED_NO_EXIST = object() _CACHED_NO_EXIST_T = Any @@ -551,16 +606,12 @@ def token_refresher() -> tuple[str, int]: xet_headers.pop("authorization", None) with progress_cm as progress: - - def progress_updater(progress_bytes: float): - progress.update(progress_bytes) - download_files( xet_download_info, endpoint=connection_info.endpoint, token_info=(connection_info.access_token, connection_info.expiration_unix_epoch), token_refresher=token_refresher, - progress_updater=[progress_updater], + progress_updater=[make_xet_progress_callback(progress, expected_size)], request_headers=xet_headers, ) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index e5013d0746..ad99ffc4d5 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -100,7 +100,7 @@ XetAuthorizationError, XetRefreshTokenError, ) -from .file_download import DryRunFileInfo, HfFileMetadata, get_hf_file_metadata, hf_hub_url +from .file_download import DryRunFileInfo, HfFileMetadata, get_hf_file_metadata, hf_hub_url, make_xet_progress_callback from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData from .utils import ( DEFAULT_IGNORE_PATTERNS, @@ -12823,16 +12823,14 @@ def token_refresher() -> tuple[str, int]: ) with progress_cm as progress: - - def progress_updater(progress_bytes: float): - progress.update(progress_bytes) - download_files( non_zero_download_infos, endpoint=connection_info.endpoint, token_info=(connection_info.access_token, connection_info.expiration_unix_epoch), token_refresher=token_refresher, - progress_updater=[progress_updater] * len(non_zero_download_infos), + progress_updater=[ + make_xet_progress_callback(progress, info.file_size) for info in non_zero_download_infos + ], ) @validate_hf_hub_args diff --git a/tests/test_xet_download.py b/tests/test_xet_download.py index 01c90b156d..55f556b2e3 100644 --- a/tests/test_xet_download.py +++ b/tests/test_xet_download.py @@ -354,3 +354,268 @@ def token_refresher() -> tuple[str, int]: ) assert os.path.exists(file_path) + + +def _make_mock_total_update(transfer_increment: float, transfer_total: float): + """Create a mock PyTotalProgressUpdate with the given transfer fields.""" + update = Mock() + update.total_transfer_bytes_completion_increment = transfer_increment + update.total_transfer_bytes = transfer_total + return update + + +@requires("hf_xet") +class TestXetProgressGranularity: + """Test that xet_get uses the fine-grained 2-arg callback for tqdm progress.""" + + _XET_FILE_DATA = XetFileData(file_hash="mock_hash", refresh_route="mock/route") + + def _call_xet_get_and_capture(self, tmp_path, mock_download, expected_size=1000, mock_progress_cm=None): + """Call xet_get and return the captured progress callback.""" + incomplete_path = tmp_path / "test_file.bin" + incomplete_path.touch() + + captured = {} + + def capture(*args, **kwargs): + captured["callback"] = kwargs["progress_updater"][0] + + mock_download.side_effect = capture + + if mock_progress_cm is not None: + mock_bar = Mock() + mock_bar.n = 0 + mock_progress_cm.return_value.__enter__ = Mock(return_value=mock_bar) + mock_progress_cm.return_value.__exit__ = Mock(return_value=False) + else: + mock_bar = None + + xet_get( + incomplete_path=incomplete_path, + xet_file_data=self._XET_FILE_DATA, + headers={"authorization": "Bearer token"}, + expected_size=expected_size, + ) + + return captured["callback"], mock_bar + + @patch("huggingface_hub.file_download._get_progress_bar_context") + @patch("hf_xet.download_files") + @patch( + "huggingface_hub.file_download.refresh_xet_connection_info", + return_value=XetConnectionInfo( + endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999 + ), + ) + def test_callback_uses_two_arg_signature(self, _mock_conn, mock_download, mock_progress_cm, tmp_path): + """Verify xet_get passes a 2-arg callback to download_files, triggering + xet-core's fine-grained network-level progress dispatch.""" + callback, _ = self._call_xet_get_and_capture(tmp_path, mock_download, mock_progress_cm=mock_progress_cm) + + # Call with 2 args (total_update, item_updates) to confirm it accepts them. + # A 1-arg callback would raise TypeError here. + total_update = _make_mock_total_update(transfer_increment=200, transfer_total=1000) + callback(total_update, []) # should not raise + + @patch("huggingface_hub.file_download._get_progress_bar_context") + @patch("hf_xet.download_files") + @patch( + "huggingface_hub.file_download.refresh_xet_connection_info", + return_value=XetConnectionInfo( + endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999 + ), + ) + def test_progress_bar_scales_network_to_file_size(self, _mock_conn, mock_download, mock_progress_cm, tmp_path): + """When transfer bytes differ from file size, the progress bar should + scale to expected_size so it always reaches 100%.""" + expected_size = 10_000 + transfer_total = 5_000 # fewer bytes due to deduplication + + callback, mock_bar = self._call_xet_get_and_capture( + tmp_path, mock_download, expected_size=expected_size, mock_progress_cm=mock_progress_cm + ) + + def update_side_effect(n): + mock_bar.n += n + + mock_bar.update = Mock(side_effect=update_side_effect) + + # Simulate 5 updates of 1000 transfer bytes each (total: 5000 transfer bytes) + for _ in range(5): + total_update = _make_mock_total_update(transfer_increment=1000, transfer_total=transfer_total) + callback(total_update, []) + + # After transferring 5000/5000 bytes, bar should be at expected_size (10000) + assert mock_bar.n == expected_size + + @patch("huggingface_hub.file_download._get_progress_bar_context") + @patch("hf_xet.download_files") + @patch( + "huggingface_hub.file_download.refresh_xet_connection_info", + return_value=XetConnectionInfo( + endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999 + ), + ) + def test_progress_bar_capped_at_expected_size(self, _mock_conn, mock_download, mock_progress_cm, tmp_path): + """Progress bar should never exceed expected_size.""" + expected_size = 1000 + + callback, mock_bar = self._call_xet_get_and_capture( + tmp_path, mock_download, expected_size=expected_size, mock_progress_cm=mock_progress_cm + ) + + def update_side_effect(n): + mock_bar.n += n + + mock_bar.update = Mock(side_effect=update_side_effect) + + # Send more transfer bytes than total (edge case) + total_update = _make_mock_total_update(transfer_increment=1200, transfer_total=1000) + callback(total_update, []) + + assert mock_bar.n <= expected_size + + @patch("huggingface_hub.file_download._get_progress_bar_context") + @patch("hf_xet.download_files") + @patch( + "huggingface_hub.file_download.refresh_xet_connection_info", + return_value=XetConnectionInfo( + endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999 + ), + ) + def test_zero_increment_skipped(self, _mock_conn, mock_download, mock_progress_cm, tmp_path): + """Zero-increment updates should not call progress.update.""" + callback, mock_bar = self._call_xet_get_and_capture(tmp_path, mock_download, mock_progress_cm=mock_progress_cm) + + total_update = _make_mock_total_update(transfer_increment=0, transfer_total=1000) + callback(total_update, []) + + mock_bar.update.assert_not_called() + + @patch("huggingface_hub.file_download._get_progress_bar_context") + @patch("hf_xet.download_files") + @patch( + "huggingface_hub.file_download.refresh_xet_connection_info", + return_value=XetConnectionInfo( + endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999 + ), + ) + def test_expected_size_none_passes_raw_bytes(self, _mock_conn, mock_download, mock_progress_cm, tmp_path): + """When expected_size is None, raw transfer bytes are passed through.""" + callback, mock_bar = self._call_xet_get_and_capture( + tmp_path, mock_download, expected_size=None, mock_progress_cm=mock_progress_cm + ) + + def update_side_effect(n): + mock_bar.n += n + + mock_bar.update = Mock(side_effect=update_side_effect) + + total_update = _make_mock_total_update(transfer_increment=500, transfer_total=0) + callback(total_update, []) + + assert mock_bar.n == 500 + + @patch("huggingface_hub.file_download._get_progress_bar_context") + @patch("hf_xet.download_files") + @patch( + "huggingface_hub.file_download.refresh_xet_connection_info", + return_value=XetConnectionInfo( + endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999 + ), + ) + def test_expected_size_none_with_known_transfer_total(self, _mock_conn, mock_download, mock_progress_cm, tmp_path): + """When expected_size is None, raw bytes pass through even if transfer_total is known.""" + callback, mock_bar = self._call_xet_get_and_capture( + tmp_path, mock_download, expected_size=None, mock_progress_cm=mock_progress_cm + ) + + def update_side_effect(n): + mock_bar.n += n + + mock_bar.update = Mock(side_effect=update_side_effect) + + total_update = _make_mock_total_update(transfer_increment=500, transfer_total=2000) + callback(total_update, []) + + assert mock_bar.n == 500 + + @patch("huggingface_hub.file_download._get_progress_bar_context") + @patch("hf_xet.download_files") + @patch( + "huggingface_hub.file_download.refresh_xet_connection_info", + return_value=XetConnectionInfo( + endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999 + ), + ) + def test_transfer_total_zero_skips_when_expected_size_set( + self, _mock_conn, mock_download, mock_progress_cm, tmp_path + ): + """When expected_size is set but transfer_total is 0 (not yet known), + updates are skipped to avoid injecting unscaled bytes.""" + callback, mock_bar = self._call_xet_get_and_capture( + tmp_path, mock_download, expected_size=1000, mock_progress_cm=mock_progress_cm + ) + + total_update = _make_mock_total_update(transfer_increment=500, transfer_total=0) + callback(total_update, []) + + mock_bar.update.assert_not_called() + + +@requires("hf_xet") +class TestMakeXetProgressCallback: + """Direct tests for make_xet_progress_callback shared helper.""" + + def test_multi_file_shared_bar(self): + """Multiple callbacks sharing one bar should each contribute independently.""" + from huggingface_hub.file_download import make_xet_progress_callback + + mock_bar = Mock() + mock_bar.n = 0 + + def update_side_effect(n): + mock_bar.n += n + + mock_bar.update = Mock(side_effect=update_side_effect) + + # Two files: 600 bytes and 400 bytes, sharing a bar with total=1000 + cb_a = make_xet_progress_callback(mock_bar, file_size=600) + cb_b = make_xet_progress_callback(mock_bar, file_size=400) + + # File A: 50% done (transfers 500/1000 network bytes -> contributes 300 of 600 file bytes) + cb_a(_make_mock_total_update(transfer_increment=500, transfer_total=1000), []) + assert mock_bar.n == 300 + + # File B: 100% done (transfers 800/800 -> contributes 400 of 400 file bytes) + cb_b(_make_mock_total_update(transfer_increment=800, transfer_total=800), []) + assert mock_bar.n == 700 # 300 + 400 + + # File A: 100% done (transfers remaining 500/1000 -> contributes remaining 300) + cb_a(_make_mock_total_update(transfer_increment=500, transfer_total=1000), []) + assert mock_bar.n == 1000 # 600 + 400 + + def test_no_regression_on_duplicate_progress(self): + """When cumulative doesn't advance (e.g. duplicate update), bar should not update.""" + from huggingface_hub.file_download import make_xet_progress_callback + + mock_bar = Mock() + mock_bar.n = 0 + + def update_side_effect(n): + mock_bar.n += n + + mock_bar.update = Mock(side_effect=update_side_effect) + + cb = make_xet_progress_callback(mock_bar, file_size=1000) + + # First update: 500/1000 transfer -> 500 file bytes + cb(_make_mock_total_update(transfer_increment=500, transfer_total=1000), []) + assert mock_bar.n == 500 + assert mock_bar.update.call_count == 1 + + # Tiny increment that doesn't move int() forward (1 byte of 1000 transfer = 0.001 * 1000 = 1) + # contributed = int(501/1000 * 1000) = 501, advance = 501 - 500 = 1 + cb(_make_mock_total_update(transfer_increment=1, transfer_total=1000), []) + assert mock_bar.n == 501 + assert mock_bar.update.call_count == 2