Skip to content

Commit 6a8f8a3

Browse files
committed
fix: use fine-grained xet-core callback for smoother tqdm progress
Switch xet_get() from a 1-arg to 2-arg callback so xet-core reports progress frequently instead of barely at all. Fixes #4058
1 parent 6983a4d commit 6a8f8a3

3 files changed

Lines changed: 326 additions & 12 deletions

File tree

src/huggingface_hub/file_download.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
import time
88
import uuid
99
import warnings
10+
from collections.abc import Callable
1011
from dataclasses import dataclass
1112
from pathlib import Path
12-
from typing import Any, BinaryIO, Literal, NoReturn, overload
13+
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, NoReturn, Protocol, overload
1314
from urllib.parse import quote, urlparse
1415

1516
import httpx
@@ -56,8 +57,62 @@
5657
from .utils.tqdm import _get_progress_bar_context
5758

5859

60+
if TYPE_CHECKING:
61+
from hf_xet import PyItemProgressUpdate, PyTotalProgressUpdate
62+
63+
5964
logger = logging.get_logger(__name__)
6065

66+
67+
class ProgressBar(Protocol):
68+
"""Any object with a tqdm-compatible ``update(n)`` method."""
69+
70+
def update(self, n: int) -> None: ...
71+
72+
73+
def make_xet_progress_callback(
74+
progress_bar: ProgressBar, file_size: int | None
75+
) -> Callable[["PyTotalProgressUpdate", "list[PyItemProgressUpdate]"], None]:
76+
"""Create a xet-core progress callback that scales network transfer to *file_size*.
77+
78+
Each callback tracks its own cumulative bytes so multiple callbacks can
79+
safely share a single progress bar (e.g. multi-file bucket downloads).
80+
"""
81+
_cumulative_transfer = 0
82+
_last_contributed = 0
83+
84+
def _callback(
85+
total_update: "PyTotalProgressUpdate",
86+
item_updates: "list[PyItemProgressUpdate]",
87+
) -> None:
88+
nonlocal _cumulative_transfer, _last_contributed
89+
increment = total_update.total_transfer_bytes_completion_increment
90+
if increment <= 0:
91+
return
92+
93+
_cumulative_transfer += increment
94+
transfer_total = total_update.total_transfer_bytes
95+
96+
# Indeterminate bar (unknown file size): pass through raw bytes
97+
if file_size is None:
98+
progress_bar.update(int(increment))
99+
return
100+
101+
# File size known but transfer total not yet reported: skip to avoid
102+
# injecting unscaled bytes into a determinate bar
103+
if not transfer_total or transfer_total <= 0:
104+
return
105+
106+
# Scale network transfer fraction to file size
107+
contributed = min(round(_cumulative_transfer / transfer_total * file_size), file_size)
108+
advance = contributed - _last_contributed
109+
if advance > 0:
110+
_last_contributed = contributed
111+
progress_bar.update(advance)
112+
113+
return _callback
114+
115+
61116
# Return value when trying to load a file from cache but the file does not exist in the distant repo.
62117
_CACHED_NO_EXIST = object()
63118
_CACHED_NO_EXIST_T = Any
@@ -551,16 +606,12 @@ def token_refresher() -> tuple[str, int]:
551606
xet_headers.pop("authorization", None)
552607

553608
with progress_cm as progress:
554-
555-
def progress_updater(progress_bytes: float):
556-
progress.update(progress_bytes)
557-
558609
download_files(
559610
xet_download_info,
560611
endpoint=connection_info.endpoint,
561612
token_info=(connection_info.access_token, connection_info.expiration_unix_epoch),
562613
token_refresher=token_refresher,
563-
progress_updater=[progress_updater],
614+
progress_updater=[make_xet_progress_callback(progress, expected_size)],
564615
request_headers=xet_headers,
565616
)
566617

src/huggingface_hub/hf_api.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@
100100
XetAuthorizationError,
101101
XetRefreshTokenError,
102102
)
103-
from .file_download import DryRunFileInfo, HfFileMetadata, get_hf_file_metadata, hf_hub_url
103+
from .file_download import DryRunFileInfo, HfFileMetadata, get_hf_file_metadata, hf_hub_url, make_xet_progress_callback
104104
from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData
105105
from .utils import (
106106
DEFAULT_IGNORE_PATTERNS,
@@ -12823,16 +12823,14 @@ def token_refresher() -> tuple[str, int]:
1282312823
)
1282412824

1282512825
with progress_cm as progress:
12826-
12827-
def progress_updater(progress_bytes: float):
12828-
progress.update(progress_bytes)
12829-
1283012826
download_files(
1283112827
non_zero_download_infos,
1283212828
endpoint=connection_info.endpoint,
1283312829
token_info=(connection_info.access_token, connection_info.expiration_unix_epoch),
1283412830
token_refresher=token_refresher,
12835-
progress_updater=[progress_updater] * len(non_zero_download_infos),
12831+
progress_updater=[
12832+
make_xet_progress_callback(progress, info.file_size) for info in non_zero_download_infos
12833+
],
1283612834
)
1283712835

1283812836
@validate_hf_hub_args

tests/test_xet_download.py

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,268 @@ def token_refresher() -> tuple[str, int]:
354354
)
355355

356356
assert os.path.exists(file_path)
357+
358+
359+
def _make_mock_total_update(transfer_increment: float, transfer_total: float):
360+
"""Create a mock PyTotalProgressUpdate with the given transfer fields."""
361+
update = Mock()
362+
update.total_transfer_bytes_completion_increment = transfer_increment
363+
update.total_transfer_bytes = transfer_total
364+
return update
365+
366+
367+
@requires("hf_xet")
368+
class TestXetProgressGranularity:
369+
"""Test that xet_get uses the fine-grained 2-arg callback for tqdm progress."""
370+
371+
_XET_FILE_DATA = XetFileData(file_hash="mock_hash", refresh_route="mock/route")
372+
373+
def _call_xet_get_and_capture(self, tmp_path, mock_download, expected_size=1000, mock_progress_cm=None):
374+
"""Call xet_get and return the captured progress callback."""
375+
incomplete_path = tmp_path / "test_file.bin"
376+
incomplete_path.touch()
377+
378+
captured = {}
379+
380+
def capture(*args, **kwargs):
381+
captured["callback"] = kwargs["progress_updater"][0]
382+
383+
mock_download.side_effect = capture
384+
385+
if mock_progress_cm is not None:
386+
mock_bar = Mock()
387+
mock_bar.n = 0
388+
mock_progress_cm.return_value.__enter__ = Mock(return_value=mock_bar)
389+
mock_progress_cm.return_value.__exit__ = Mock(return_value=False)
390+
else:
391+
mock_bar = None
392+
393+
xet_get(
394+
incomplete_path=incomplete_path,
395+
xet_file_data=self._XET_FILE_DATA,
396+
headers={"authorization": "Bearer token"},
397+
expected_size=expected_size,
398+
)
399+
400+
return captured["callback"], mock_bar
401+
402+
@patch("huggingface_hub.file_download._get_progress_bar_context")
403+
@patch("hf_xet.download_files")
404+
@patch(
405+
"huggingface_hub.file_download.refresh_xet_connection_info",
406+
return_value=XetConnectionInfo(
407+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
408+
),
409+
)
410+
def test_callback_uses_two_arg_signature(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
411+
"""Verify xet_get passes a 2-arg callback to download_files, triggering
412+
xet-core's fine-grained network-level progress dispatch."""
413+
callback, _ = self._call_xet_get_and_capture(tmp_path, mock_download, mock_progress_cm=mock_progress_cm)
414+
415+
# Call with 2 args (total_update, item_updates) to confirm it accepts them.
416+
# A 1-arg callback would raise TypeError here.
417+
total_update = _make_mock_total_update(transfer_increment=200, transfer_total=1000)
418+
callback(total_update, []) # should not raise
419+
420+
@patch("huggingface_hub.file_download._get_progress_bar_context")
421+
@patch("hf_xet.download_files")
422+
@patch(
423+
"huggingface_hub.file_download.refresh_xet_connection_info",
424+
return_value=XetConnectionInfo(
425+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
426+
),
427+
)
428+
def test_progress_bar_scales_network_to_file_size(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
429+
"""When transfer bytes differ from file size, the progress bar should
430+
scale to expected_size so it always reaches 100%."""
431+
expected_size = 10_000
432+
transfer_total = 5_000 # fewer bytes due to deduplication
433+
434+
callback, mock_bar = self._call_xet_get_and_capture(
435+
tmp_path, mock_download, expected_size=expected_size, mock_progress_cm=mock_progress_cm
436+
)
437+
438+
def update_side_effect(n):
439+
mock_bar.n += n
440+
441+
mock_bar.update = Mock(side_effect=update_side_effect)
442+
443+
# Simulate 5 updates of 1000 transfer bytes each (total: 5000 transfer bytes)
444+
for _ in range(5):
445+
total_update = _make_mock_total_update(transfer_increment=1000, transfer_total=transfer_total)
446+
callback(total_update, [])
447+
448+
# After transferring 5000/5000 bytes, bar should be at expected_size (10000)
449+
assert mock_bar.n == expected_size
450+
451+
@patch("huggingface_hub.file_download._get_progress_bar_context")
452+
@patch("hf_xet.download_files")
453+
@patch(
454+
"huggingface_hub.file_download.refresh_xet_connection_info",
455+
return_value=XetConnectionInfo(
456+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
457+
),
458+
)
459+
def test_progress_bar_capped_at_expected_size(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
460+
"""Progress bar should never exceed expected_size."""
461+
expected_size = 1000
462+
463+
callback, mock_bar = self._call_xet_get_and_capture(
464+
tmp_path, mock_download, expected_size=expected_size, mock_progress_cm=mock_progress_cm
465+
)
466+
467+
def update_side_effect(n):
468+
mock_bar.n += n
469+
470+
mock_bar.update = Mock(side_effect=update_side_effect)
471+
472+
# Send more transfer bytes than total (edge case)
473+
total_update = _make_mock_total_update(transfer_increment=1200, transfer_total=1000)
474+
callback(total_update, [])
475+
476+
assert mock_bar.n <= expected_size
477+
478+
@patch("huggingface_hub.file_download._get_progress_bar_context")
479+
@patch("hf_xet.download_files")
480+
@patch(
481+
"huggingface_hub.file_download.refresh_xet_connection_info",
482+
return_value=XetConnectionInfo(
483+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
484+
),
485+
)
486+
def test_zero_increment_skipped(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
487+
"""Zero-increment updates should not call progress.update."""
488+
callback, mock_bar = self._call_xet_get_and_capture(tmp_path, mock_download, mock_progress_cm=mock_progress_cm)
489+
490+
total_update = _make_mock_total_update(transfer_increment=0, transfer_total=1000)
491+
callback(total_update, [])
492+
493+
mock_bar.update.assert_not_called()
494+
495+
@patch("huggingface_hub.file_download._get_progress_bar_context")
496+
@patch("hf_xet.download_files")
497+
@patch(
498+
"huggingface_hub.file_download.refresh_xet_connection_info",
499+
return_value=XetConnectionInfo(
500+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
501+
),
502+
)
503+
def test_expected_size_none_passes_raw_bytes(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
504+
"""When expected_size is None, raw transfer bytes are passed through."""
505+
callback, mock_bar = self._call_xet_get_and_capture(
506+
tmp_path, mock_download, expected_size=None, mock_progress_cm=mock_progress_cm
507+
)
508+
509+
def update_side_effect(n):
510+
mock_bar.n += n
511+
512+
mock_bar.update = Mock(side_effect=update_side_effect)
513+
514+
total_update = _make_mock_total_update(transfer_increment=500, transfer_total=0)
515+
callback(total_update, [])
516+
517+
assert mock_bar.n == 500
518+
519+
@patch("huggingface_hub.file_download._get_progress_bar_context")
520+
@patch("hf_xet.download_files")
521+
@patch(
522+
"huggingface_hub.file_download.refresh_xet_connection_info",
523+
return_value=XetConnectionInfo(
524+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
525+
),
526+
)
527+
def test_expected_size_none_with_known_transfer_total(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
528+
"""When expected_size is None, raw bytes pass through even if transfer_total is known."""
529+
callback, mock_bar = self._call_xet_get_and_capture(
530+
tmp_path, mock_download, expected_size=None, mock_progress_cm=mock_progress_cm
531+
)
532+
533+
def update_side_effect(n):
534+
mock_bar.n += n
535+
536+
mock_bar.update = Mock(side_effect=update_side_effect)
537+
538+
total_update = _make_mock_total_update(transfer_increment=500, transfer_total=2000)
539+
callback(total_update, [])
540+
541+
assert mock_bar.n == 500
542+
543+
@patch("huggingface_hub.file_download._get_progress_bar_context")
544+
@patch("hf_xet.download_files")
545+
@patch(
546+
"huggingface_hub.file_download.refresh_xet_connection_info",
547+
return_value=XetConnectionInfo(
548+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
549+
),
550+
)
551+
def test_transfer_total_zero_skips_when_expected_size_set(
552+
self, _mock_conn, mock_download, mock_progress_cm, tmp_path
553+
):
554+
"""When expected_size is set but transfer_total is 0 (not yet known),
555+
updates are skipped to avoid injecting unscaled bytes."""
556+
callback, mock_bar = self._call_xet_get_and_capture(
557+
tmp_path, mock_download, expected_size=1000, mock_progress_cm=mock_progress_cm
558+
)
559+
560+
total_update = _make_mock_total_update(transfer_increment=500, transfer_total=0)
561+
callback(total_update, [])
562+
563+
mock_bar.update.assert_not_called()
564+
565+
566+
@requires("hf_xet")
567+
class TestMakeXetProgressCallback:
568+
"""Direct tests for make_xet_progress_callback shared helper."""
569+
570+
def test_multi_file_shared_bar(self):
571+
"""Multiple callbacks sharing one bar should each contribute independently."""
572+
from huggingface_hub.file_download import make_xet_progress_callback
573+
574+
mock_bar = Mock()
575+
mock_bar.n = 0
576+
577+
def update_side_effect(n):
578+
mock_bar.n += n
579+
580+
mock_bar.update = Mock(side_effect=update_side_effect)
581+
582+
# Two files: 600 bytes and 400 bytes, sharing a bar with total=1000
583+
cb_a = make_xet_progress_callback(mock_bar, file_size=600)
584+
cb_b = make_xet_progress_callback(mock_bar, file_size=400)
585+
586+
# File A: 50% done (transfers 500/1000 network bytes -> contributes 300 of 600 file bytes)
587+
cb_a(_make_mock_total_update(transfer_increment=500, transfer_total=1000), [])
588+
assert mock_bar.n == 300
589+
590+
# File B: 100% done (transfers 800/800 -> contributes 400 of 400 file bytes)
591+
cb_b(_make_mock_total_update(transfer_increment=800, transfer_total=800), [])
592+
assert mock_bar.n == 700 # 300 + 400
593+
594+
# File A: 100% done (transfers remaining 500/1000 -> contributes remaining 300)
595+
cb_a(_make_mock_total_update(transfer_increment=500, transfer_total=1000), [])
596+
assert mock_bar.n == 1000 # 600 + 400
597+
598+
def test_no_regression_on_duplicate_progress(self):
599+
"""When cumulative doesn't advance (e.g. duplicate update), bar should not update."""
600+
from huggingface_hub.file_download import make_xet_progress_callback
601+
602+
mock_bar = Mock()
603+
mock_bar.n = 0
604+
605+
def update_side_effect(n):
606+
mock_bar.n += n
607+
608+
mock_bar.update = Mock(side_effect=update_side_effect)
609+
610+
cb = make_xet_progress_callback(mock_bar, file_size=1000)
611+
612+
# First update: 500/1000 transfer -> 500 file bytes
613+
cb(_make_mock_total_update(transfer_increment=500, transfer_total=1000), [])
614+
assert mock_bar.n == 500
615+
assert mock_bar.update.call_count == 1
616+
617+
# Tiny increment that doesn't move int() forward (1 byte of 1000 transfer = 0.001 * 1000 = 1)
618+
# contributed = int(501/1000 * 1000) = 501, advance = 501 - 500 = 1
619+
cb(_make_mock_total_update(transfer_increment=1, transfer_total=1000), [])
620+
assert mock_bar.n == 501
621+
assert mock_bar.update.call_count == 2

0 commit comments

Comments
 (0)