diff --git a/src/huggingface_hub/_buckets.py b/src/huggingface_hub/_buckets.py index f291890d06..d1cdd8cd37 100644 --- a/src/huggingface_hub/_buckets.py +++ b/src/huggingface_hub/_buckets.py @@ -31,6 +31,7 @@ from . import constants, logging from .errors import BucketNotFoundError from .utils import XetFileData, disable_progress_bars, enable_progress_bars, parse_datetime +from .utils._hf_uri import parse_hf_url from .utils._terminal import StatusLine @@ -56,12 +57,8 @@ def _split_bucket_id_and_prefix(path: str) -> tuple[str, str]: Returns (bucket_id, prefix) where prefix may be empty string. Raises ValueError if path doesn't contain at least namespace/name. """ - parts = path.split("/", 2) - if len(parts) < 2 or not parts[0] or not parts[1]: - raise ValueError(f"Invalid bucket path: '{path}'. Expected format: namespace/bucket_name") - bucket_id = f"{parts[0]}/{parts[1]}" - prefix = parts[2] if len(parts) > 2 else "" - return bucket_id, prefix + parsed = parse_hf_url(f"buckets/{path}") + return parsed.bucket_id, parsed.path @dataclass @@ -242,9 +239,8 @@ def _parse_bucket_path(path: str) -> tuple[str, str]: Returns: tuple: (bucket_id, prefix) where bucket_id is "namespace/bucket_name" and prefix may be empty string. """ - if not path.startswith(BUCKET_PREFIX): - raise ValueError(f"Invalid bucket path: {path}. Must start with {BUCKET_PREFIX}") - return _split_bucket_id_and_prefix(path.removeprefix(BUCKET_PREFIX)) + parsed = parse_hf_url(path) + return parsed.bucket_id, parsed.path def _is_bucket_path(path: str) -> bool: diff --git a/src/huggingface_hub/cli/jobs.py b/src/huggingface_hub/cli/jobs.py index 4ece8c5f8a..d67b8bc0d1 100644 --- a/src/huggingface_hub/cli/jobs.py +++ b/src/huggingface_hub/cli/jobs.py @@ -78,6 +78,7 @@ from huggingface_hub._jobs_api import Volume from huggingface_hub.errors import CLIError, HfHubHTTPError from huggingface_hub.utils import logging +from huggingface_hub.utils._hf_uri import ParsedBucketUrl, parse_hf_url from huggingface_hub.utils._cache_manager import _format_size from ._cli_utils import ( @@ -1099,12 +1100,6 @@ def _parse_volumes(volumes: Optional[list[str]]) -> Optional[list[Volume]]: return None HF_PREFIX = "hf://" - HF_TYPES_MAPPING = { - "models": constants.REPO_TYPE_MODEL, - "datasets": constants.REPO_TYPE_DATASET, - "spaces": constants.REPO_TYPE_SPACE, - "buckets": "bucket", - } result: list[Volume] = [] for raw_spec in volumes: @@ -1124,45 +1119,29 @@ def _parse_volumes(volumes: Optional[list[str]]) -> Optional[list[Volume]]: f"Invalid volume format: '{raw_spec}'. Source must start with 'hf://'. " f"Expected hf://[TYPE/]SOURCE:/MOUNT_PATH[:ro]. E.g. hf://gpt2:/data" ) - spec = spec[len(HF_PREFIX) :] # Find the mount path: look for :/ pattern - colon_slash_idx = spec.find(":/") + # We search in the part after "hf://" to avoid matching the "://" in the prefix. + after_prefix = spec[len(HF_PREFIX) :] + colon_slash_idx = after_prefix.find(":/") if colon_slash_idx == -1: raise CLIError( f"Invalid volume format: '{raw_spec}'. Expected hf://[TYPE/]SOURCE:/MOUNT_PATH[:ro]. E.g. hf://gpt2:/data" ) - source_part = spec[:colon_slash_idx] - mount_path = spec[colon_slash_idx + 1 :] - - # Parse type from source_part (first segment before /) - # Then split remaining into source (namespace/name or name) and optional path. - slash_idx = source_part.find("/") - if slash_idx == -1: - # No slash: bare source like "gpt2" -> model type - vol_type_str = constants.REPO_TYPE_MODEL - source = source_part - path = None + source_uri = HF_PREFIX + after_prefix[:colon_slash_idx] + mount_path = after_prefix[colon_slash_idx + 1 :] + + # Parse the source URI using the central parser. + parsed = parse_hf_url(source_uri) + + if isinstance(parsed, ParsedBucketUrl): + vol_type_str = "bucket" + source = parsed.bucket_id + path = parsed.path or None else: - first_segment = source_part[:slash_idx] - if first_segment in HF_TYPES_MAPPING: - vol_type_str = HF_TYPES_MAPPING[first_segment] - remaining = source_part[slash_idx + 1 :] - else: - # First segment isn't a known type -> model type - vol_type_str = constants.REPO_TYPE_MODEL - remaining = source_part - - # Split remaining into source (namespace/name) and optional path. - # Repo/bucket IDs are "namespace/name" (2 segments) or "name" (1 segment). - # Any extra segments are the path inside the repo/bucket. - parts = remaining.split("/", 2) - if len(parts) >= 3: - source = parts[0] + "/" + parts[1] - path = parts[2] - else: - source = remaining - path = None + vol_type_str = parsed.repo_type + source = parsed.repo_id + path = parsed.path_in_repo or None result.append( Volume( diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 1397baee05..850d27ac5f 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -301,6 +301,9 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tu [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) If `repo_type` is unknown. """ + + from .utils._hf_uri import ParsedBucketUrl, parse_hf_url + input_hf_id = hf_id # Get the hub_url (with or without protocol) @@ -315,86 +318,45 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tu if hf_id.startswith(HFFS_PREFIX): # Remove "hf://" prefix if exists hf_id = hf_id[len(HFFS_PREFIX) :] - # If it's a URL, strip the endpoint prefix to get the path + # If it's a URL, strip the endpoint prefix to get the relative path if is_hf_url: - # Remove protocol if present hf_id_normalized = _REGEX_HTTP_PROTOCOL.sub("", hf_id) - - # Remove the hub_url prefix to get the relative path if hf_id_normalized.startswith(hub_url_without_protocol): - # Strip the hub URL and any leading slashes hf_id = hf_id_normalized[len(hub_url_without_protocol) :].lstrip("/") + # At this point hf_id is a relative path like "datasets/user/repo", "user/repo", or "repo". url_segments = hf_id.split("/") - is_hf_id = len(url_segments) <= 3 - namespace: Optional[str] - if is_hf_url: - # For URLs, we need to extract repo_type, namespace, repo_id - # Expected format after stripping endpoint: [repo_type]/namespace/repo_id or namespace/repo_id - - if len(url_segments) >= 3: - # Check if first segment is a repo type - if url_segments[0] in constants.REPO_TYPES_MAPPING: - repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]] - namespace = url_segments[1] - repo_id = url_segments[2] - elif url_segments[0] == "buckets": - # Special case for buckets - repo_type = "bucket" - namespace = url_segments[1] - repo_id = url_segments[2] - else: - # First segment is namespace - namespace = url_segments[0] - repo_id = url_segments[1] - repo_type = None - elif len(url_segments) == 2: - namespace = url_segments[0] - repo_id = url_segments[1] - - # Check if namespace is actually a repo type mapping - if namespace in constants.REPO_TYPES_MAPPING: - # Mean canonical dataset or model - repo_type = constants.REPO_TYPES_MAPPING[namespace] - namespace = None - elif namespace == "buckets": - # Special case for buckets - repo_type = "bucket" - namespace = None - else: - repo_type = None - else: - # Single segment - repo_id = url_segments[0] - namespace = None - repo_type = None - elif is_hf_id: - if len(url_segments) == 3: - # Passed // or // - repo_type, namespace, repo_id = url_segments[-3:] - elif len(url_segments) == 2: - if url_segments[0] in constants.REPO_TYPES_MAPPING: - # Passed '' or 'datasets/' for a canonical model or dataset - repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]] - namespace = None - repo_id = hf_id.split("/")[-1] - elif url_segments[0] == "buckets": - # Special case for buckets - repo_type = "bucket" - namespace = None - repo_id = hf_id.split("/")[-1] - else: - # Passed / or / - namespace, repo_id = hf_id.split("/")[-2:] - repo_type = None - else: - # Passed - repo_id = url_segments[0] - namespace, repo_type = None, None - else: + # For non-URL inputs, reject paths with more than 3 segments. + # URL inputs can have extra segments (e.g. /blob/main/file.txt) which are ignored. + if not is_hf_url and len(url_segments) > 3: raise ValueError(f"Unable to retrieve user and repo ID from the passed HF ID: {hf_id}") + if not is_hf_url and len(url_segments) == 3: + # Passed // — accept singular type names + # (e.g. "dataset/user/id") which parse_hf_url doesn't handle. + repo_type, namespace, repo_id = url_segments + elif not is_hf_url and len(url_segments) == 2 and url_segments[0] == "buckets": + # Special case: "buckets/name" (no namespace) — parse_hf_url expects + # namespace/name for buckets, but this function accepts bare bucket names. + repo_type = "bucket" + namespace = None + repo_id = url_segments[1] + else: + # Delegate to the central parser for type detection, bucket handling, etc. + parsed = parse_hf_url(hf_id) + + if isinstance(parsed, ParsedBucketUrl): + repo_type: Optional[str] = "bucket" + namespace = parsed.namespace + repo_id = parsed.bucket_name + else: + # When no type prefix is present, parse_hf_url defaults to "model". + # This function returns None instead. + repo_type = parsed.repo_type if parsed.has_explicit_type else None + namespace = parsed.namespace + repo_id = parsed.repo_name + # Check if repo type is known (mapping "spaces" => "space" + empty value => `None`) if repo_type in constants.REPO_TYPES_MAPPING: repo_type = constants.REPO_TYPES_MAPPING[repo_type] diff --git a/src/huggingface_hub/hf_file_system.py b/src/huggingface_hub/hf_file_system.py index 45d149b5c7..826c668b27 100644 --- a/src/huggingface_hub/hf_file_system.py +++ b/src/huggingface_hub/hf_file_system.py @@ -1,5 +1,4 @@ import os -import re import tempfile import threading from collections import deque @@ -29,20 +28,10 @@ from .file_download import hf_hub_url, http_get from .hf_api import BucketFile, BucketFolder, HfApi, LastCommitInfo, RepoFile, RepoFolder from .utils import HFValidationError, hf_raise_for_status, http_backoff, http_stream_backoff +from .utils._hf_uri import SPECIAL_REFS_REVISION_REGEX, ParsedBucketUrl, parse_hf_url from .utils.insecure_hashlib import md5 -# Regex used to match special revisions with "/" in them (see #1710) -SPECIAL_REFS_REVISION_REGEX = re.compile( - r""" - (^refs\/convert\/\w+) # `refs/convert/parquet` revisions - | - (^refs\/pr\/\d+) # PR revisions - """, - re.VERBOSE, -) - - @dataclass class HfFileSystemResolvedPath: """Top level Data structure containing information about a resolved Hugging Face file system path.""" @@ -319,73 +308,89 @@ def _align_revision_in_path_with_revision( path = self._strip_protocol(path) if not path: - # can't list repositories at root raise NotImplementedError("Access to buckets and repositories lists is not implemented.") - elif path.split("/")[0] == "buckets": - bucket_id = "/".join(path.split("/")[1:3]) - path = "/".join(path.split("/")[3:]) - bucket_exists, err = self._bucket_exists(bucket_id) + + # --- Use the central parser for type detection, bucket handling, + # revision extraction, and repo_id / path splitting. ---------- + parsed = parse_hf_url(path) + + # --- Bucket paths ------------------------------------------------ + if isinstance(parsed, ParsedBucketUrl): + bucket_exists, err = self._bucket_exists(parsed.bucket_id) if not bucket_exists: _raise_file_not_found(path, err) - return HfFileSystemResolvedBucketPath(bucket_id=bucket_id, path=path) - elif path.split("/")[0] + "/" in constants.REPO_TYPES_URL_PREFIXES.values(): - if "/" not in path: - # can't list repositories at the repository type level - raise NotImplementedError("Access to repositories lists is not implemented.") - repo_type, path = path.split("/", 1) - repo_type = constants.REPO_TYPES_MAPPING[repo_type] - else: - repo_type = constants.REPO_TYPE_MODEL - if path.count("/") > 0: - if "@" in "/".join(path.split("/")[:2]): - repo_id, revision_in_path = path.split("@", 1) - if "/" in revision_in_path: - match = SPECIAL_REFS_REVISION_REGEX.search(revision_in_path) - if match is not None and revision in (None, match.group()): - # Handle `refs/convert/parquet` and PR revisions separately - path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_in_path).lstrip("/") - revision_in_path = match.group() - else: - revision_in_path, path_in_repo = revision_in_path.split("/", 1) + return HfFileSystemResolvedBucketPath(bucket_id=parsed.bucket_id, path=parsed.path) + + # --- Repository paths -------------------------------------------- + repo_type = parsed.repo_type + repo_id = parsed.repo_id + path_in_repo = parsed.path_in_repo + revision_in_path: Optional[str] = parsed.revision + + if not repo_id: + raise NotImplementedError("Access to repositories lists is not implemented.") + + # Handle special refs: if the caller supplies a revision argument that + # conflicts with a special-ref parsed from the path, re-parse the + # revision portion with a simple "/" split (matching legacy behaviour). + if revision_in_path is not None and revision is not None and revision_in_path != revision: + # The caller's explicit revision wins; re-split assuming the first + # "/" after "@" separates revision from path. + raw_after_at = path.split("@", 1)[1] if "@" in path else "" + if "/" in raw_after_at: + revision_in_path_candidate, path_in_repo = raw_after_at.split("/", 1) + revision_in_path = unquote(revision_in_path_candidate) + # Fall through to conflict detection below. + + # Reconcile revision from path vs explicit argument. + revision = _align_revision_in_path_with_revision(revision_in_path, revision) + + # Keep the raw (possibly URL-encoded) revision string for unresolve(). + raw_revision = None + if "@" in path: + raw_after_at = path.split("@", 1)[1] + if "/" in raw_after_at: + match = SPECIAL_REFS_REVISION_REGEX.search(raw_after_at) + if match is not None and revision in (None, match.group()): + raw_revision = match.group() else: - path_in_repo = "" - revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision) - repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) - if not repo_and_revision_exist: - _raise_file_not_found(path, err) + raw_revision = raw_after_at.split("/", 1)[0] else: - revision_in_path = None - repo_id_with_namespace = "/".join(path.split("/")[:2]) - path_in_repo_with_namespace = "/".join(path.split("/")[2:]) - repo_id_without_namespace = path.split("/")[0] - path_in_repo_without_namespace = "/".join(path.split("/")[1:]) - repo_id = repo_id_with_namespace - path_in_repo = path_in_repo_with_namespace - repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) - if not repo_and_revision_exist: - if isinstance(err, (RepositoryNotFoundError, HFValidationError)): - repo_id = repo_id_without_namespace - path_in_repo = path_in_repo_without_namespace - repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision) - if not repo_and_revision_exist: - _raise_file_not_found(path, err) - else: + raw_revision = raw_after_at + + # --- I/O validation (namespace-fallback for ambiguous repo ids) --- + if "/" in repo_id and parsed.revision is None: + # Ambiguous: "a/b/c" could be namespace/repo with path "c", + # or bare repo "a" with path "b/c". Try namespace/repo first, + # then fall back to single-segment repo_id on 404. + repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + if isinstance(err, (RepositoryNotFoundError, HFValidationError)): + parts = repo_id.split("/", 1) + fallback_repo_id = parts[0] + fallback_path = parts[1] + ("/" + path_in_repo if path_in_repo else "") + repo_and_revision_exist, _ = self._repo_and_revision_exist( + repo_type, fallback_repo_id, revision + ) + if not repo_and_revision_exist: _raise_file_not_found(path, err) - else: - repo_id = path - path_in_repo = "" - if "@" in path: - repo_id, revision_in_path = path.split("@", 1) - revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision) - else: - revision_in_path = None - repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision) + repo_id = fallback_repo_id + path_in_repo = fallback_path + else: + _raise_file_not_found(path, err) + elif repo_id: + # Unambiguous: either has explicit revision, or no "/" in repo_id. + repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) if not repo_and_revision_exist: - raise NotImplementedError("Access to repositories lists is not implemented.") + if not path_in_repo and parsed.revision is None: + raise NotImplementedError("Access to repositories lists is not implemented.") + _raise_file_not_found(path, err) + else: + raise NotImplementedError("Access to repositories lists is not implemented.") revision = revision if revision is not None else constants.DEFAULT_REVISION return HfFileSystemResolvedRepositoryPath( - repo_type, repo_id, revision, path_in_repo, _raw_revision=revision_in_path + repo_type, repo_id, revision, path_in_repo, _raw_revision=raw_revision ) def invalidate_cache(self, path: Optional[str] = None) -> None: diff --git a/src/huggingface_hub/utils/_hf_uri.py b/src/huggingface_hub/utils/_hf_uri.py new file mode 100644 index 0000000000..c7a595b2b9 --- /dev/null +++ b/src/huggingface_hub/utils/_hf_uri.py @@ -0,0 +1,213 @@ +"""Central parsing for ``hf://`` URIs and plain HF identifiers.""" + +import re +from dataclasses import dataclass +from typing import Optional, Union +from urllib.parse import unquote + +from .. import constants + + +# --------------------------------------------------------------------------- +# Regex for special revision references that contain "/" characters +# (e.g. refs/convert/parquet, refs/pr/123). Moved here from +# hf_file_system.py so every consumer can share it. +# --------------------------------------------------------------------------- +SPECIAL_REFS_REVISION_REGEX = re.compile( + r""" + (^refs\/convert\/\w+) # `refs/convert/parquet` revisions + | + (^refs\/pr\/\d+) # PR revisions + """, + re.VERBOSE, +) + +# Mapping that includes "buckets" alongside the standard repo-type plurals. +_HF_TYPES_MAPPING: dict[str, str] = { + **constants.REPO_TYPES_MAPPING, # datasets, spaces, models + "buckets": "bucket", +} + +HF_PREFIX = "hf://" + + +# --------------------------------------------------------------------------- +# Result dataclasses +# --------------------------------------------------------------------------- +@dataclass(frozen=True) +class ParsedHfUrl: + """Result of parsing an ``hf://`` URI or plain identifier for a repository.""" + + repo_type: str + """Canonical repo type: ``"model"``, ``"dataset"``, or ``"space"``.""" + + repo_id: str + """Repository identifier, e.g. ``"namespace/name"`` or ``"name"``.""" + + revision: Optional[str] + """Revision parsed from the ``@revision`` syntax, or *None* if absent.""" + + path_in_repo: str + """Sub-path inside the repository (empty string when absent).""" + + has_explicit_type: bool = False + """Whether a type prefix (``datasets/``, ``spaces/``, ``models/``) was present in the input.""" + + @property + def namespace(self) -> Optional[str]: + """The namespace (user/org) portion of the repo_id, or *None* for canonical repos.""" + if "/" in self.repo_id: + return self.repo_id.split("/", 1)[0] + return None + + @property + def repo_name(self) -> str: + """The repo name portion of the repo_id (without namespace).""" + if "/" in self.repo_id: + return self.repo_id.split("/", 1)[1] + return self.repo_id + + +@dataclass(frozen=True) +class ParsedBucketUrl: + """Result of parsing an ``hf://buckets/`` URI or plain bucket identifier.""" + + bucket_id: str + """Bucket identifier: ``"namespace/name"``.""" + + path: str + """Prefix / sub-path inside the bucket (empty string when absent).""" + + @property + def namespace(self) -> Optional[str]: + """The namespace (user/org) portion of the bucket_id, or *None*.""" + if "/" in self.bucket_id: + return self.bucket_id.split("/", 1)[0] + return None + + @property + def bucket_name(self) -> str: + """The bucket name portion of the bucket_id (without namespace).""" + if "/" in self.bucket_id: + return self.bucket_id.split("/", 1)[1] + return self.bucket_id + + +# --------------------------------------------------------------------------- +# Core parser +# --------------------------------------------------------------------------- +def parse_hf_url( + url: str, + *, + default_type: str = constants.REPO_TYPE_MODEL, +) -> Union[ParsedHfUrl, ParsedBucketUrl]: + """Parse an ``hf://`` URI or plain HF identifier into structured components. + + Pure parsing only - **no HTTP calls, no existence validation**. + + Accepted formats:: + + hf://[TYPE/]REPO_ID[@REVISION][/PATH] + [TYPE/]REPO_ID[@REVISION][/PATH] + hf://buckets/NAMESPACE/NAME[/PREFIX] + buckets/NAMESPACE/NAME[/PREFIX] + + Where *TYPE* is one of ``models``, ``datasets``, ``spaces`` (mapped to + the canonical singular form). + + Args: + url: The URI or identifier to parse. + default_type: Repo type to assume when none is present in *url*. + Defaults to ``"model"``. + + Returns: + :class:`ParsedHfUrl` for repository paths, :class:`ParsedBucketUrl` + for bucket paths. + + Raises: + ValueError: If the URL cannot be structurally parsed (e.g. a bucket + path missing ``namespace/name``). + """ + # Strip the hf:// prefix if present. + path = url[len(HF_PREFIX) :] if url.startswith(HF_PREFIX) else url + + if not path: + raise ValueError(f"Empty path in URL: '{url}'") + + # ------------------------------------------------------------------ + # Detect type from first segment + # ------------------------------------------------------------------ + first_segment = path.split("/", 1)[0] + + has_explicit_type = first_segment in _HF_TYPES_MAPPING + if has_explicit_type: + mapped_type = _HF_TYPES_MAPPING[first_segment] + path = path.split("/", 1)[1] if "/" in path else "" + + if mapped_type == "bucket": + return _parse_bucket_path(path, original_url=url) + repo_type = mapped_type + else: + repo_type = default_type + + # ------------------------------------------------------------------ + # Parse repo_id, optional @revision, and path_in_repo + # ------------------------------------------------------------------ + return _parse_repo_path(path, repo_type=repo_type, has_explicit_type=has_explicit_type) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- +def _parse_bucket_path(path: str, *, original_url: str = "") -> ParsedBucketUrl: + """Parse the portion after ``buckets/`` into a :class:`ParsedBucketUrl`.""" + parts = path.split("/", 2) + if len(parts) < 2 or not parts[0] or not parts[1]: + raise ValueError( + f"Invalid bucket path: '{original_url or path}'. " + "Expected format: hf://buckets/namespace/bucket_name[/prefix]" + ) + bucket_id = f"{parts[0]}/{parts[1]}" + prefix = parts[2] if len(parts) > 2 else "" + return ParsedBucketUrl(bucket_id=bucket_id, path=prefix) + + +def _parse_repo_path(path: str, *, repo_type: str, has_explicit_type: bool = False) -> ParsedHfUrl: + """Parse ``REPO_ID[@REVISION][/PATH]`` into a :class:`ParsedHfUrl`. + + Assumes the type prefix (if any) has already been stripped. + """ + if not path: + return ParsedHfUrl(repo_type=repo_type, repo_id="", revision=None, path_in_repo="", has_explicit_type=has_explicit_type) + + # -- Handle @revision syntax ------------------------------------------ + # The "@" can appear in the first two path segments (repo_id part). + # We only split on "@" if it appears in those first two segments. + if "@" in "/".join(path.split("/")[:2]): + repo_id, revision_and_rest = path.split("@", 1) + + if "/" in revision_and_rest: + # Check for special refs (refs/pr/N, refs/convert/…) that contain "/" + match = SPECIAL_REFS_REVISION_REGEX.search(revision_and_rest) + if match is not None: + revision = match.group() + path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_and_rest).lstrip("/") + else: + revision, path_in_repo = revision_and_rest.split("/", 1) + else: + revision = revision_and_rest + path_in_repo = "" + + revision = unquote(revision) + return ParsedHfUrl(repo_type=repo_type, repo_id=repo_id, revision=revision, path_in_repo=path_in_repo, has_explicit_type=has_explicit_type) + + # -- No @revision: split into repo_id (max 2 segments) + path -------- + parts = path.split("/") + if len(parts) >= 2: + repo_id = f"{parts[0]}/{parts[1]}" + path_in_repo = "/".join(parts[2:]) + else: + repo_id = parts[0] + path_in_repo = "" + + return ParsedHfUrl(repo_type=repo_type, repo_id=repo_id, revision=None, path_in_repo=path_in_repo, has_explicit_type=has_explicit_type) diff --git a/tests/test_hf_uri.py b/tests/test_hf_uri.py new file mode 100644 index 0000000000..64113ef609 --- /dev/null +++ b/tests/test_hf_uri.py @@ -0,0 +1,77 @@ +"""Tests for huggingface_hub.utils._hf_uri – centralised HF path parsing.""" + +import pytest + +from huggingface_hub.utils._hf_uri import ( + ParsedBucketUrl, + ParsedHfUrl, + parse_hf_url, +) + + +@pytest.mark.parametrize( + "url, expected", + [ + # bare / namespaced repo ids (no explicit type → has_explicit_type=False) + ("gpt2", ParsedHfUrl("model", "gpt2", None, "")), + ("username/my-model", ParsedHfUrl("model", "username/my-model", None, "")), + # hf:// prefix + ("hf://gpt2", ParsedHfUrl("model", "gpt2", None, "")), + ("hf://username/my-model", ParsedHfUrl("model", "username/my-model", None, "")), + # explicit type prefixes (has_explicit_type=True) + ("hf://datasets/squad", ParsedHfUrl("dataset", "squad", None, "", True)), + ("hf://datasets/username/ds", ParsedHfUrl("dataset", "username/ds", None, "", True)), + ("hf://spaces/username/app", ParsedHfUrl("space", "username/app", None, "", True)), + ("hf://models/username/m", ParsedHfUrl("model", "username/m", None, "", True)), + ("datasets/username/ds", ParsedHfUrl("dataset", "username/ds", None, "", True)), + # path_in_repo + ("hf://datasets/user/ds/train/data.csv", ParsedHfUrl("dataset", "user/ds", None, "train/data.csv", True)), + ("hf://user/model/config.json", ParsedHfUrl("model", "user/model", None, "config.json")), + # @revision + ("hf://datasets/user/ds@dev/train.csv", ParsedHfUrl("dataset", "user/ds", "dev", "train.csv", True)), + ("hf://user/model@v1.0", ParsedHfUrl("model", "user/model", "v1.0", "")), + ("gpt2@main", ParsedHfUrl("model", "gpt2", "main", "")), + # special refs revisions + ("hf://user/m@refs/pr/10/file.txt", ParsedHfUrl("model", "user/m", "refs/pr/10", "file.txt")), + ("hf://user/m@refs/pr/10", ParsedHfUrl("model", "user/m", "refs/pr/10", "")), + ("hf://datasets/u/ds@refs/convert/parquet", ParsedHfUrl("dataset", "u/ds", "refs/convert/parquet", "", True)), + ("hf://datasets/u/ds@refs/convert/parquet/default/train/0000.parquet", ParsedHfUrl("dataset", "u/ds", "refs/convert/parquet", "default/train/0000.parquet", True)), + # URL-encoded revision + ("hf://user/m@refs%2Fpr%2F10", ParsedHfUrl("model", "user/m", "refs/pr/10", "")), + # buckets + ("hf://buckets/ns/bucket", ParsedBucketUrl("ns/bucket", "")), + ("hf://buckets/ns/bucket/some/prefix", ParsedBucketUrl("ns/bucket", "some/prefix")), + ("buckets/ns/bucket", ParsedBucketUrl("ns/bucket", "")), + ("buckets/ns/bucket/data/logs", ParsedBucketUrl("ns/bucket", "data/logs")), + ], +) +def test_parse_hf_url(url, expected): + assert parse_hf_url(url) == expected + + +def test_default_type_override(): + assert parse_hf_url("user/thing", default_type="dataset").repo_type == "dataset" + # explicit type prefix wins over default_type + assert parse_hf_url("spaces/user/app", default_type="dataset").repo_type == "space" + + +@pytest.mark.parametrize("url", ["", "hf://", "hf://buckets/only-ns", "hf://buckets/"]) +def test_parse_hf_url_errors(url): + with pytest.raises(ValueError): + parse_hf_url(url) + + +def test_properties(): + repo = parse_hf_url("hf://datasets/username/ds") + assert repo.namespace == "username" + assert repo.repo_name == "ds" + + bare = parse_hf_url("gpt2") + assert bare.namespace is None + assert bare.repo_name == "gpt2" + + bucket = parse_hf_url("hf://buckets/org/my-bucket") + assert bucket.namespace == "org" + assert bucket.bucket_name == "my-bucket" + +