Skip to content

Commit e124bb4

Browse files
committed
fix: use fine-grained xet-core callback for smoother tqdm progress
Switch xet_get() from a 1-arg to 2-arg callback signature so xet-core provides network-level progress updates (~200KB) instead of coarse disk-write updates (~8MB). Network transfer progress is scaled to file size so the bar always tracks 0-100% correctly. Fixes #4058
1 parent 6983a4d commit e124bb4

File tree

2 files changed

+211
-5
lines changed

2 files changed

+211
-5
lines changed

src/huggingface_hub/file_download.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
from dataclasses import dataclass
1111
from pathlib import Path
12-
from typing import Any, BinaryIO, Literal, NoReturn, overload
12+
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, NoReturn, overload
1313
from urllib.parse import quote, urlparse
1414

1515
import httpx
@@ -56,6 +56,10 @@
5656
from .utils.tqdm import _get_progress_bar_context
5757

5858

59+
if TYPE_CHECKING:
60+
from hf_xet import PyItemProgressUpdate, PyTotalProgressUpdate
61+
62+
5963
logger = logging.get_logger(__name__)
6064

6165
# Return value when trying to load a file from cache but the file does not exist in the distant repo.
@@ -551,16 +555,36 @@ def token_refresher() -> tuple[str, int]:
551555
xet_headers.pop("authorization", None)
552556

553557
with progress_cm as progress:
554-
555-
def progress_updater(progress_bytes: float):
556-
progress.update(progress_bytes)
558+
# Use 2-arg callback signature so xet-core provides fine-grained
559+
# network-level progress (~200KB) instead of coarse disk-write progress (~8MB).
560+
# Scale network transfer fraction to expected file size so the bar
561+
# always tracks 0-100% correctly even when transfer bytes differ from
562+
# file size (e.g. due to xet deduplication).
563+
_cumulative_transfer = 0
564+
565+
def _xet_progress_updater(
566+
total_update: "PyTotalProgressUpdate",
567+
item_updates: "list[PyItemProgressUpdate]",
568+
) -> None:
569+
nonlocal _cumulative_transfer
570+
increment = total_update.total_transfer_bytes_completion_increment
571+
if increment > 0:
572+
_cumulative_transfer += increment
573+
transfer_total = total_update.total_transfer_bytes
574+
if transfer_total and transfer_total > 0 and expected_size:
575+
new_n = min(int(_cumulative_transfer / transfer_total * expected_size), expected_size)
576+
advance = new_n - progress.n
577+
if advance > 0:
578+
progress.update(advance)
579+
else:
580+
progress.update(int(increment))
557581

558582
download_files(
559583
xet_download_info,
560584
endpoint=connection_info.endpoint,
561585
token_info=(connection_info.access_token, connection_info.expiration_unix_epoch),
562586
token_refresher=token_refresher,
563-
progress_updater=[progress_updater],
587+
progress_updater=[_xet_progress_updater],
564588
request_headers=xet_headers,
565589
)
566590

tests/test_xet_download.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,185 @@ def token_refresher() -> tuple[str, int]:
354354
)
355355

356356
assert os.path.exists(file_path)
357+
358+
359+
@requires("hf_xet")
360+
class TestXetProgressGranularity:
361+
"""Test that xet_get uses the fine-grained 2-arg callback for tqdm progress."""
362+
363+
def _make_mock_total_update(self, transfer_increment: float, transfer_total: float):
364+
"""Create a mock PyTotalProgressUpdate with the given transfer fields."""
365+
update = Mock()
366+
update.total_transfer_bytes_completion_increment = transfer_increment
367+
update.total_transfer_bytes = transfer_total
368+
return update
369+
370+
def test_callback_uses_two_arg_signature(self, tmp_path):
371+
"""Verify xet_get passes a 2-arg callback to download_files, triggering
372+
xet-core's fine-grained network-level progress dispatch."""
373+
incomplete_path = tmp_path / "test_file.bin"
374+
incomplete_path.touch()
375+
376+
xet_file_data = XetFileData(file_hash="mock_hash", refresh_route="mock/route")
377+
connection_info = XetConnectionInfo(
378+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
379+
)
380+
381+
with patch("huggingface_hub.file_download.refresh_xet_connection_info", return_value=connection_info):
382+
with patch("hf_xet.download_files") as mock_download:
383+
xet_get(
384+
incomplete_path=incomplete_path,
385+
xet_file_data=xet_file_data,
386+
headers={"authorization": "Bearer token"},
387+
expected_size=1000,
388+
)
389+
390+
mock_download.assert_called_once()
391+
callbacks = mock_download.call_args.kwargs["progress_updater"]
392+
assert len(callbacks) == 1
393+
callback = callbacks[0]
394+
395+
# Call with 2 args (total_update, item_updates) to confirm it accepts them.
396+
# A 1-arg callback would raise TypeError here.
397+
total_update = self._make_mock_total_update(transfer_increment=200, transfer_total=1000)
398+
callback(total_update, []) # should not raise
399+
400+
def test_progress_bar_scales_network_to_file_size(self, tmp_path):
401+
"""When transfer bytes differ from file size, the progress bar should
402+
scale to expected_size so it always reaches 100%."""
403+
incomplete_path = tmp_path / "test_file.bin"
404+
incomplete_path.touch()
405+
406+
xet_file_data = XetFileData(file_hash="mock_hash", refresh_route="mock/route")
407+
connection_info = XetConnectionInfo(
408+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
409+
)
410+
411+
expected_size = 10_000
412+
# Simulate xet transferring fewer bytes than file size (deduplication)
413+
transfer_total = 5_000
414+
415+
captured_callback = None
416+
417+
def capture_callback(*args, **kwargs):
418+
nonlocal captured_callback
419+
captured_callback = kwargs["progress_updater"][0]
420+
421+
with patch("huggingface_hub.file_download.refresh_xet_connection_info", return_value=connection_info):
422+
with patch("hf_xet.download_files", side_effect=capture_callback):
423+
with patch("huggingface_hub.file_download._get_progress_bar_context") as mock_progress_cm:
424+
mock_bar = Mock()
425+
mock_bar.n = 0
426+
mock_progress_cm.return_value.__enter__ = Mock(return_value=mock_bar)
427+
mock_progress_cm.return_value.__exit__ = Mock(return_value=False)
428+
429+
xet_get(
430+
incomplete_path=incomplete_path,
431+
xet_file_data=xet_file_data,
432+
headers={"authorization": "Bearer token"},
433+
expected_size=expected_size,
434+
)
435+
436+
assert captured_callback is not None
437+
438+
# Simulate 5 updates of 1000 transfer bytes each (total: 5000 transfer bytes)
439+
for i in range(5):
440+
total_update = self._make_mock_total_update(
441+
transfer_increment=1000,
442+
transfer_total=transfer_total,
443+
)
444+
445+
# Simulate tqdm advancing bar.n after each update call
446+
def update_side_effect(n):
447+
mock_bar.n += n
448+
449+
mock_bar.update = Mock(side_effect=update_side_effect)
450+
captured_callback(total_update, [])
451+
452+
# After transferring 5000/5000 bytes, bar should be at expected_size (10000)
453+
assert mock_bar.n == expected_size
454+
455+
def test_progress_bar_capped_at_expected_size(self, tmp_path):
456+
"""Progress bar should never exceed expected_size."""
457+
incomplete_path = tmp_path / "test_file.bin"
458+
incomplete_path.touch()
459+
460+
xet_file_data = XetFileData(file_hash="mock_hash", refresh_route="mock/route")
461+
connection_info = XetConnectionInfo(
462+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
463+
)
464+
465+
expected_size = 1000
466+
transfer_total = 1000
467+
468+
captured_callback = None
469+
470+
def capture_callback(*args, **kwargs):
471+
nonlocal captured_callback
472+
captured_callback = kwargs["progress_updater"][0]
473+
474+
with patch("huggingface_hub.file_download.refresh_xet_connection_info", return_value=connection_info):
475+
with patch("hf_xet.download_files", side_effect=capture_callback):
476+
with patch("huggingface_hub.file_download._get_progress_bar_context") as mock_progress_cm:
477+
mock_bar = Mock()
478+
mock_bar.n = 0
479+
mock_progress_cm.return_value.__enter__ = Mock(return_value=mock_bar)
480+
mock_progress_cm.return_value.__exit__ = Mock(return_value=False)
481+
482+
xet_get(
483+
incomplete_path=incomplete_path,
484+
xet_file_data=xet_file_data,
485+
headers={"authorization": "Bearer token"},
486+
expected_size=expected_size,
487+
)
488+
489+
assert captured_callback is not None
490+
491+
def update_side_effect(n):
492+
mock_bar.n += n
493+
494+
mock_bar.update = Mock(side_effect=update_side_effect)
495+
496+
# Send more transfer bytes than total (edge case)
497+
total_update = self._make_mock_total_update(transfer_increment=1200, transfer_total=transfer_total)
498+
captured_callback(total_update, [])
499+
500+
assert mock_bar.n <= expected_size
501+
502+
def test_zero_increment_skipped(self, tmp_path):
503+
"""Zero-increment updates should not call progress.update."""
504+
incomplete_path = tmp_path / "test_file.bin"
505+
incomplete_path.touch()
506+
507+
xet_file_data = XetFileData(file_hash="mock_hash", refresh_route="mock/route")
508+
connection_info = XetConnectionInfo(
509+
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
510+
)
511+
512+
captured_callback = None
513+
514+
def capture_callback(*args, **kwargs):
515+
nonlocal captured_callback
516+
captured_callback = kwargs["progress_updater"][0]
517+
518+
with patch("huggingface_hub.file_download.refresh_xet_connection_info", return_value=connection_info):
519+
with patch("hf_xet.download_files", side_effect=capture_callback):
520+
with patch("huggingface_hub.file_download._get_progress_bar_context") as mock_progress_cm:
521+
mock_bar = Mock()
522+
mock_bar.n = 0
523+
mock_progress_cm.return_value.__enter__ = Mock(return_value=mock_bar)
524+
mock_progress_cm.return_value.__exit__ = Mock(return_value=False)
525+
526+
xet_get(
527+
incomplete_path=incomplete_path,
528+
xet_file_data=xet_file_data,
529+
headers={"authorization": "Bearer token"},
530+
expected_size=1000,
531+
)
532+
533+
assert captured_callback is not None
534+
535+
total_update = self._make_mock_total_update(transfer_increment=0, transfer_total=1000)
536+
captured_callback(total_update, [])
537+
538+
mock_bar.update.assert_not_called()

0 commit comments

Comments
 (0)