Skip to content

Commit dcc4347

Browse files
committed
fix: use fine-grained xet-core callback for smoother tqdm progress
Switch xet_get() from a 1-arg to 2-arg callback signature so xet-core provides network-level progress updates (~200KB) instead of coarse disk-write updates (~8MB). Network transfer progress is scaled to file size so the bar always tracks 0-100% correctly. Fixes #4058
1 parent 6983a4d commit dcc4347

File tree

2 files changed

+168
-4
lines changed

2 files changed

+168
-4
lines changed

src/huggingface_hub/file_download.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
from dataclasses import dataclass
1111
from pathlib import Path
12-
from typing import Any, BinaryIO, Literal, NoReturn, overload
12+
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, NoReturn, overload
1313
from urllib.parse import quote, urlparse
1414

1515
import httpx
@@ -56,6 +56,10 @@
5656
from .utils.tqdm import _get_progress_bar_context
5757

5858

59+
if TYPE_CHECKING:
60+
from hf_xet import PyItemProgressUpdate, PyTotalProgressUpdate
61+
62+
5963
logger = logging.get_logger(__name__)
6064

6165
# Return value when trying to load a file from cache but the file does not exist in the distant repo.
@@ -551,9 +555,29 @@ def token_refresher() -> tuple[str, int]:
551555
xet_headers.pop("authorization", None)
552556

553557
with progress_cm as progress:
554-
555-
def progress_updater(progress_bytes: float):
556-
progress.update(progress_bytes)
558+
# Use 2-arg callback signature so xet-core provides fine-grained
559+
# network-level progress (~200KB) instead of coarse disk-write progress (~8MB).
560+
# Scale network transfer fraction to expected file size so the bar
561+
# always tracks 0-100% correctly even when transfer bytes differ from
562+
# file size (e.g. due to xet deduplication).
563+
_cumulative_transfer = 0
564+
565+
def progress_updater(
566+
total_update: "PyTotalProgressUpdate",
567+
item_updates: "list[PyItemProgressUpdate]",
568+
) -> None:
569+
nonlocal _cumulative_transfer
570+
increment = total_update.total_transfer_bytes_completion_increment
571+
if increment > 0:
572+
_cumulative_transfer += increment
573+
transfer_total = total_update.total_transfer_bytes
574+
if transfer_total and transfer_total > 0 and expected_size:
575+
new_n = min(int(_cumulative_transfer / transfer_total * expected_size), expected_size)
576+
advance = new_n - progress.n
577+
if advance > 0:
578+
progress.update(advance)
579+
else:
580+
progress.update(int(increment))
557581

558582
download_files(
559583
xet_download_info,

tests/test_xet_download.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,143 @@ def token_refresher() -> tuple[str, int]:
354354
)
355355

356356
assert os.path.exists(file_path)
357+
358+
359+
@requires("hf_xet")
360+
class TestXetProgressGranularity:
361+
"""Test that xet_get uses the fine-grained 2-arg callback for tqdm progress."""
362+
363+
_XET_FILE_DATA = XetFileData(file_hash="mock_hash", refresh_route="mock/route")
364+
_CONNECTION_INFO = XetConnectionInfo(
365+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
366+
)
367+
368+
@staticmethod
369+
def _make_mock_total_update(transfer_increment: float, transfer_total: float):
370+
"""Create a mock PyTotalProgressUpdate with the given transfer fields."""
371+
update = Mock()
372+
update.total_transfer_bytes_completion_increment = transfer_increment
373+
update.total_transfer_bytes = transfer_total
374+
return update
375+
376+
def _call_xet_get_and_capture(self, tmp_path, mock_download, expected_size=1000, mock_progress_cm=None):
377+
"""Call xet_get and return the captured progress callback."""
378+
incomplete_path = tmp_path / "test_file.bin"
379+
incomplete_path.touch()
380+
381+
captured = {}
382+
383+
def capture(*args, **kwargs):
384+
captured["callback"] = kwargs["progress_updater"][0]
385+
386+
mock_download.side_effect = capture
387+
388+
if mock_progress_cm is not None:
389+
mock_bar = Mock()
390+
mock_bar.n = 0
391+
mock_progress_cm.return_value.__enter__ = Mock(return_value=mock_bar)
392+
mock_progress_cm.return_value.__exit__ = Mock(return_value=False)
393+
else:
394+
mock_bar = None
395+
396+
xet_get(
397+
incomplete_path=incomplete_path,
398+
xet_file_data=self._XET_FILE_DATA,
399+
headers={"authorization": "Bearer token"},
400+
expected_size=expected_size,
401+
)
402+
403+
return captured["callback"], mock_bar
404+
405+
@patch("huggingface_hub.file_download._get_progress_bar_context")
406+
@patch("hf_xet.download_files")
407+
@patch(
408+
"huggingface_hub.file_download.refresh_xet_connection_info",
409+
return_value=XetConnectionInfo(
410+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
411+
),
412+
)
413+
def test_callback_uses_two_arg_signature(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
414+
"""Verify xet_get passes a 2-arg callback to download_files, triggering
415+
xet-core's fine-grained network-level progress dispatch."""
416+
callback, _ = self._call_xet_get_and_capture(tmp_path, mock_download, mock_progress_cm=mock_progress_cm)
417+
418+
# Call with 2 args (total_update, item_updates) to confirm it accepts them.
419+
# A 1-arg callback would raise TypeError here.
420+
total_update = self._make_mock_total_update(transfer_increment=200, transfer_total=1000)
421+
callback(total_update, []) # should not raise
422+
423+
@patch("huggingface_hub.file_download._get_progress_bar_context")
424+
@patch("hf_xet.download_files")
425+
@patch(
426+
"huggingface_hub.file_download.refresh_xet_connection_info",
427+
return_value=XetConnectionInfo(
428+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
429+
),
430+
)
431+
def test_progress_bar_scales_network_to_file_size(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
432+
"""When transfer bytes differ from file size, the progress bar should
433+
scale to expected_size so it always reaches 100%."""
434+
expected_size = 10_000
435+
transfer_total = 5_000 # fewer bytes due to deduplication
436+
437+
callback, mock_bar = self._call_xet_get_and_capture(
438+
tmp_path, mock_download, expected_size=expected_size, mock_progress_cm=mock_progress_cm
439+
)
440+
441+
def update_side_effect(n):
442+
mock_bar.n += n
443+
444+
mock_bar.update = Mock(side_effect=update_side_effect)
445+
446+
# Simulate 5 updates of 1000 transfer bytes each (total: 5000 transfer bytes)
447+
for _ in range(5):
448+
total_update = self._make_mock_total_update(transfer_increment=1000, transfer_total=transfer_total)
449+
callback(total_update, [])
450+
451+
# After transferring 5000/5000 bytes, bar should be at expected_size (10000)
452+
assert mock_bar.n == expected_size
453+
454+
@patch("huggingface_hub.file_download._get_progress_bar_context")
455+
@patch("hf_xet.download_files")
456+
@patch(
457+
"huggingface_hub.file_download.refresh_xet_connection_info",
458+
return_value=XetConnectionInfo(
459+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
460+
),
461+
)
462+
def test_progress_bar_capped_at_expected_size(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
463+
"""Progress bar should never exceed expected_size."""
464+
expected_size = 1000
465+
466+
callback, mock_bar = self._call_xet_get_and_capture(
467+
tmp_path, mock_download, expected_size=expected_size, mock_progress_cm=mock_progress_cm
468+
)
469+
470+
def update_side_effect(n):
471+
mock_bar.n += n
472+
473+
mock_bar.update = Mock(side_effect=update_side_effect)
474+
475+
# Send more transfer bytes than total (edge case)
476+
total_update = self._make_mock_total_update(transfer_increment=1200, transfer_total=1000)
477+
callback(total_update, [])
478+
479+
assert mock_bar.n <= expected_size
480+
481+
@patch("huggingface_hub.file_download._get_progress_bar_context")
482+
@patch("hf_xet.download_files")
483+
@patch(
484+
"huggingface_hub.file_download.refresh_xet_connection_info",
485+
return_value=XetConnectionInfo(
486+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
487+
),
488+
)
489+
def test_zero_increment_skipped(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
490+
"""Zero-increment updates should not call progress.update."""
491+
callback, mock_bar = self._call_xet_get_and_capture(tmp_path, mock_download, mock_progress_cm=mock_progress_cm)
492+
493+
total_update = self._make_mock_total_update(transfer_increment=0, transfer_total=1000)
494+
callback(total_update, [])
495+
496+
mock_bar.update.assert_not_called()

0 commit comments

Comments
 (0)