Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import time
import uuid
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any, BinaryIO, Literal, NoReturn, overload
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, NoReturn, Protocol, overload
from urllib.parse import quote, urlparse

import httpx
Expand Down Expand Up @@ -56,8 +57,62 @@
from .utils.tqdm import _get_progress_bar_context


if TYPE_CHECKING:
from hf_xet import PyItemProgressUpdate, PyTotalProgressUpdate


logger = logging.get_logger(__name__)


class ProgressBar(Protocol):
"""Any object with a tqdm-compatible ``update(n)`` method."""

def update(self, n: int) -> None: ...


def make_xet_progress_callback(
progress_bar: ProgressBar, file_size: int | None
) -> Callable[["PyTotalProgressUpdate", "list[PyItemProgressUpdate]"], None]:
"""Create a xet-core progress callback that scales network transfer to *file_size*.

Each callback tracks its own cumulative bytes so multiple callbacks can
safely share a single progress bar (e.g. multi-file bucket downloads).
"""
_cumulative_transfer = 0
_last_contributed = 0

def _callback(
total_update: "PyTotalProgressUpdate",
item_updates: "list[PyItemProgressUpdate]",
) -> None:
nonlocal _cumulative_transfer, _last_contributed
increment = total_update.total_transfer_bytes_completion_increment
if increment <= 0:
return

_cumulative_transfer += increment
transfer_total = total_update.total_transfer_bytes

# Indeterminate bar (unknown file size): pass through raw bytes
if file_size is None:
progress_bar.update(int(increment))
return

# File size known but transfer total not yet reported: skip to avoid
# injecting unscaled bytes into a determinate bar
if not transfer_total or transfer_total <= 0:
return

# Scale network transfer fraction to file size
contributed = min(round(_cumulative_transfer / transfer_total * file_size), file_size)
advance = contributed - _last_contributed
if advance > 0:
_last_contributed = contributed
progress_bar.update(advance)

return _callback


# Return value when trying to load a file from cache but the file does not exist in the distant repo.
_CACHED_NO_EXIST = object()
_CACHED_NO_EXIST_T = Any
Expand Down Expand Up @@ -551,16 +606,12 @@ def token_refresher() -> tuple[str, int]:
xet_headers.pop("authorization", None)

with progress_cm as progress:

def progress_updater(progress_bytes: float):
progress.update(progress_bytes)

download_files(
xet_download_info,
endpoint=connection_info.endpoint,
token_info=(connection_info.access_token, connection_info.expiration_unix_epoch),
token_refresher=token_refresher,
progress_updater=[progress_updater],
progress_updater=[make_xet_progress_callback(progress, expected_size)],
request_headers=xet_headers,
)

Expand Down
10 changes: 4 additions & 6 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
XetAuthorizationError,
XetRefreshTokenError,
)
from .file_download import DryRunFileInfo, HfFileMetadata, get_hf_file_metadata, hf_hub_url
from .file_download import DryRunFileInfo, HfFileMetadata, get_hf_file_metadata, hf_hub_url, make_xet_progress_callback
from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData
from .utils import (
DEFAULT_IGNORE_PATTERNS,
Expand Down Expand Up @@ -12823,16 +12823,14 @@ def token_refresher() -> tuple[str, int]:
)

with progress_cm as progress:

def progress_updater(progress_bytes: float):
progress.update(progress_bytes)

download_files(
non_zero_download_infos,
endpoint=connection_info.endpoint,
token_info=(connection_info.access_token, connection_info.expiration_unix_epoch),
token_refresher=token_refresher,
progress_updater=[progress_updater] * len(non_zero_download_infos),
progress_updater=[
make_xet_progress_callback(progress, info.file_size) for info in non_zero_download_infos
],
)

@validate_hf_hub_args
Expand Down
265 changes: 265 additions & 0 deletions tests/test_xet_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,268 @@ def token_refresher() -> tuple[str, int]:
)

assert os.path.exists(file_path)


def _make_mock_total_update(transfer_increment: float, transfer_total: float):
"""Create a mock PyTotalProgressUpdate with the given transfer fields."""
update = Mock()
update.total_transfer_bytes_completion_increment = transfer_increment
update.total_transfer_bytes = transfer_total
return update


@requires("hf_xet")
class TestXetProgressGranularity:
"""Test that xet_get uses the fine-grained 2-arg callback for tqdm progress."""

_XET_FILE_DATA = XetFileData(file_hash="mock_hash", refresh_route="mock/route")

def _call_xet_get_and_capture(self, tmp_path, mock_download, expected_size=1000, mock_progress_cm=None):
"""Call xet_get and return the captured progress callback."""
incomplete_path = tmp_path / "test_file.bin"
incomplete_path.touch()

captured = {}

def capture(*args, **kwargs):
captured["callback"] = kwargs["progress_updater"][0]

mock_download.side_effect = capture

if mock_progress_cm is not None:
mock_bar = Mock()
mock_bar.n = 0
mock_progress_cm.return_value.__enter__ = Mock(return_value=mock_bar)
mock_progress_cm.return_value.__exit__ = Mock(return_value=False)
else:
mock_bar = None

xet_get(
incomplete_path=incomplete_path,
xet_file_data=self._XET_FILE_DATA,
headers={"authorization": "Bearer token"},
expected_size=expected_size,
)

return captured["callback"], mock_bar

@patch("huggingface_hub.file_download._get_progress_bar_context")
@patch("hf_xet.download_files")
@patch(
"huggingface_hub.file_download.refresh_xet_connection_info",
return_value=XetConnectionInfo(
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
),
)
def test_callback_uses_two_arg_signature(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
"""Verify xet_get passes a 2-arg callback to download_files, triggering
xet-core's fine-grained network-level progress dispatch."""
callback, _ = self._call_xet_get_and_capture(tmp_path, mock_download, mock_progress_cm=mock_progress_cm)

# Call with 2 args (total_update, item_updates) to confirm it accepts them.
# A 1-arg callback would raise TypeError here.
total_update = _make_mock_total_update(transfer_increment=200, transfer_total=1000)
callback(total_update, []) # should not raise

@patch("huggingface_hub.file_download._get_progress_bar_context")
@patch("hf_xet.download_files")
@patch(
"huggingface_hub.file_download.refresh_xet_connection_info",
return_value=XetConnectionInfo(
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
),
)
def test_progress_bar_scales_network_to_file_size(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
"""When transfer bytes differ from file size, the progress bar should
scale to expected_size so it always reaches 100%."""
expected_size = 10_000
transfer_total = 5_000 # fewer bytes due to deduplication

callback, mock_bar = self._call_xet_get_and_capture(
tmp_path, mock_download, expected_size=expected_size, mock_progress_cm=mock_progress_cm
)

def update_side_effect(n):
mock_bar.n += n

mock_bar.update = Mock(side_effect=update_side_effect)

# Simulate 5 updates of 1000 transfer bytes each (total: 5000 transfer bytes)
for _ in range(5):
total_update = _make_mock_total_update(transfer_increment=1000, transfer_total=transfer_total)
callback(total_update, [])

# After transferring 5000/5000 bytes, bar should be at expected_size (10000)
assert mock_bar.n == expected_size

@patch("huggingface_hub.file_download._get_progress_bar_context")
@patch("hf_xet.download_files")
@patch(
"huggingface_hub.file_download.refresh_xet_connection_info",
return_value=XetConnectionInfo(
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
),
)
def test_progress_bar_capped_at_expected_size(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
"""Progress bar should never exceed expected_size."""
expected_size = 1000

callback, mock_bar = self._call_xet_get_and_capture(
tmp_path, mock_download, expected_size=expected_size, mock_progress_cm=mock_progress_cm
)

def update_side_effect(n):
mock_bar.n += n

mock_bar.update = Mock(side_effect=update_side_effect)

# Send more transfer bytes than total (edge case)
total_update = _make_mock_total_update(transfer_increment=1200, transfer_total=1000)
callback(total_update, [])

assert mock_bar.n <= expected_size

@patch("huggingface_hub.file_download._get_progress_bar_context")
@patch("hf_xet.download_files")
@patch(
"huggingface_hub.file_download.refresh_xet_connection_info",
return_value=XetConnectionInfo(
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
),
)
def test_zero_increment_skipped(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
"""Zero-increment updates should not call progress.update."""
callback, mock_bar = self._call_xet_get_and_capture(tmp_path, mock_download, mock_progress_cm=mock_progress_cm)

total_update = _make_mock_total_update(transfer_increment=0, transfer_total=1000)
callback(total_update, [])

mock_bar.update.assert_not_called()

@patch("huggingface_hub.file_download._get_progress_bar_context")
@patch("hf_xet.download_files")
@patch(
"huggingface_hub.file_download.refresh_xet_connection_info",
return_value=XetConnectionInfo(
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
),
)
def test_expected_size_none_passes_raw_bytes(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
"""When expected_size is None, raw transfer bytes are passed through."""
callback, mock_bar = self._call_xet_get_and_capture(
tmp_path, mock_download, expected_size=None, mock_progress_cm=mock_progress_cm
)

def update_side_effect(n):
mock_bar.n += n

mock_bar.update = Mock(side_effect=update_side_effect)

total_update = _make_mock_total_update(transfer_increment=500, transfer_total=0)
callback(total_update, [])

assert mock_bar.n == 500

@patch("huggingface_hub.file_download._get_progress_bar_context")
@patch("hf_xet.download_files")
@patch(
"huggingface_hub.file_download.refresh_xet_connection_info",
return_value=XetConnectionInfo(
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
),
)
def test_expected_size_none_with_known_transfer_total(self, _mock_conn, mock_download, mock_progress_cm, tmp_path):
"""When expected_size is None, raw bytes pass through even if transfer_total is known."""
callback, mock_bar = self._call_xet_get_and_capture(
tmp_path, mock_download, expected_size=None, mock_progress_cm=mock_progress_cm
)

def update_side_effect(n):
mock_bar.n += n

mock_bar.update = Mock(side_effect=update_side_effect)

total_update = _make_mock_total_update(transfer_increment=500, transfer_total=2000)
callback(total_update, [])

assert mock_bar.n == 500

@patch("huggingface_hub.file_download._get_progress_bar_context")
@patch("hf_xet.download_files")
@patch(
"huggingface_hub.file_download.refresh_xet_connection_info",
return_value=XetConnectionInfo(
endpoint="mock_endpoint", access_token="mock_token", expiration_unix_epoch=9999999999
),
)
def test_transfer_total_zero_skips_when_expected_size_set(
self, _mock_conn, mock_download, mock_progress_cm, tmp_path
):
"""When expected_size is set but transfer_total is 0 (not yet known),
updates are skipped to avoid injecting unscaled bytes."""
callback, mock_bar = self._call_xet_get_and_capture(
tmp_path, mock_download, expected_size=1000, mock_progress_cm=mock_progress_cm
)

total_update = _make_mock_total_update(transfer_increment=500, transfer_total=0)
callback(total_update, [])

mock_bar.update.assert_not_called()


@requires("hf_xet")
class TestMakeXetProgressCallback:
"""Direct tests for make_xet_progress_callback shared helper."""

def test_multi_file_shared_bar(self):
"""Multiple callbacks sharing one bar should each contribute independently."""
from huggingface_hub.file_download import make_xet_progress_callback

mock_bar = Mock()
mock_bar.n = 0

def update_side_effect(n):
mock_bar.n += n

mock_bar.update = Mock(side_effect=update_side_effect)

# Two files: 600 bytes and 400 bytes, sharing a bar with total=1000
cb_a = make_xet_progress_callback(mock_bar, file_size=600)
cb_b = make_xet_progress_callback(mock_bar, file_size=400)

# File A: 50% done (transfers 500/1000 network bytes -> contributes 300 of 600 file bytes)
cb_a(_make_mock_total_update(transfer_increment=500, transfer_total=1000), [])
assert mock_bar.n == 300

# File B: 100% done (transfers 800/800 -> contributes 400 of 400 file bytes)
cb_b(_make_mock_total_update(transfer_increment=800, transfer_total=800), [])
assert mock_bar.n == 700 # 300 + 400

# File A: 100% done (transfers remaining 500/1000 -> contributes remaining 300)
cb_a(_make_mock_total_update(transfer_increment=500, transfer_total=1000), [])
assert mock_bar.n == 1000 # 600 + 400

def test_no_regression_on_duplicate_progress(self):
"""When cumulative doesn't advance (e.g. duplicate update), bar should not update."""
from huggingface_hub.file_download import make_xet_progress_callback

mock_bar = Mock()
mock_bar.n = 0

def update_side_effect(n):
mock_bar.n += n

mock_bar.update = Mock(side_effect=update_side_effect)

cb = make_xet_progress_callback(mock_bar, file_size=1000)

# First update: 500/1000 transfer -> 500 file bytes
cb(_make_mock_total_update(transfer_increment=500, transfer_total=1000), [])
assert mock_bar.n == 500
assert mock_bar.update.call_count == 1

# Tiny increment that doesn't move int() forward (1 byte of 1000 transfer = 0.001 * 1000 = 1)
# contributed = int(501/1000 * 1000) = 501, advance = 501 - 500 = 1
cb(_make_mock_total_update(transfer_increment=1, transfer_total=1000), [])
assert mock_bar.n == 501
assert mock_bar.update.call_count == 2