diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 4d602f34b6..08254a8dd5 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -195,6 +195,7 @@ USERNAME_PLACEHOLDER = "hf_user" _REGEX_DISCUSSION_URL = re.compile(r".*/discussions/(\d+)$") +_REGEX_HTTP_PROTOCOL = re.compile(r"https?://") _CREATE_COMMIT_NO_REPO_ERROR_MESSAGE = ( "\nNote: Creating a commit assumes that the repo already exists on the" @@ -239,28 +240,62 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tu """ input_hf_id = hf_id - hub_url = re.sub(r"https?://", "", hub_url if hub_url is not None else constants.ENDPOINT) - is_hf_url = hub_url in hf_id and "@" not in hf_id + # Get the hub_url (with or without protocol) + full_hub_url = hub_url if hub_url is not None else constants.ENDPOINT + hub_url_without_protocol = _REGEX_HTTP_PROTOCOL.sub("", full_hub_url) + + # Check if hf_id is a URL containing the hub_url (check both with and without protocol) + hf_id_without_protocol = _REGEX_HTTP_PROTOCOL.sub("", hf_id) + is_hf_url = hub_url_without_protocol in hf_id_without_protocol and "@" not in hf_id HFFS_PREFIX = "hf://" if hf_id.startswith(HFFS_PREFIX): # Remove "hf://" prefix if exists hf_id = hf_id[len(HFFS_PREFIX) :] + # If it's a URL, strip the endpoint prefix to get the path + if is_hf_url: + # Remove protocol if present + hf_id_normalized = _REGEX_HTTP_PROTOCOL.sub("", hf_id) + + # Remove the hub_url prefix to get the relative path + if hf_id_normalized.startswith(hub_url_without_protocol): + # Strip the hub URL and any leading slashes + hf_id = hf_id_normalized[len(hub_url_without_protocol) :].lstrip("/") + url_segments = hf_id.split("/") is_hf_id = len(url_segments) <= 3 namespace: Optional[str] if is_hf_url: - namespace, repo_id = url_segments[-2:] - if namespace == hub_url: - namespace = None - if len(url_segments) > 2 and hub_url not in url_segments[-3]: - repo_type = url_segments[-3] - elif namespace in constants.REPO_TYPES_MAPPING: - # Mean canonical dataset or model - repo_type = constants.REPO_TYPES_MAPPING[namespace] - namespace = None + # For URLs, we need to extract repo_type, namespace, repo_id + # Expected format after stripping endpoint: [repo_type]/namespace/repo_id or namespace/repo_id + + if len(url_segments) >= 3: + # Check if first segment is a repo type + if url_segments[0] in constants.REPO_TYPES_MAPPING: + repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]] + namespace = url_segments[1] + repo_id = url_segments[2] + else: + # First segment is namespace + namespace = url_segments[0] + repo_id = url_segments[1] + repo_type = None + elif len(url_segments) == 2: + namespace = url_segments[0] + repo_id = url_segments[1] + + # Check if namespace is actually a repo type mapping + if namespace in constants.REPO_TYPES_MAPPING: + # Mean canonical dataset or model + repo_type = constants.REPO_TYPES_MAPPING[namespace] + namespace = None + else: + repo_type = None else: + # Single segment + repo_id = url_segments[0] + namespace = None repo_type = None elif is_hf_id: if len(url_segments) == 3: diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index bf9839e613..1dfb355d85 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -2800,25 +2800,31 @@ def test_git_push_end_to_end(self): class ParseHFUrlTest(unittest.TestCase): def test_repo_type_and_id_from_hf_id_on_correct_values(self): possible_values = { - "https://huggingface.co/id": [None, None, "id"], - "https://huggingface.co/user/id": [None, "user", "id"], - "https://huggingface.co/datasets/user/id": ["dataset", "user", "id"], - "https://huggingface.co/spaces/user/id": ["space", "user", "id"], - "user/id": [None, "user", "id"], - "dataset/user/id": ["dataset", "user", "id"], - "space/user/id": ["space", "user", "id"], - "id": [None, None, "id"], - "hf://id": [None, None, "id"], - "hf://user/id": [None, "user", "id"], - "hf://model/user/name": ["model", "user", "name"], # 's' is optional - "hf://models/user/name": ["model", "user", "name"], + "hub": { + "https://huggingface.co/id": [None, None, "id"], + "https://huggingface.co/user/id": [None, "user", "id"], + "https://huggingface.co/datasets/user/id": ["dataset", "user", "id"], + "https://huggingface.co/spaces/user/id": ["space", "user", "id"], + "user/id": [None, "user", "id"], + "dataset/user/id": ["dataset", "user", "id"], + "space/user/id": ["space", "user", "id"], + "id": [None, None, "id"], + "hf://id": [None, None, "id"], + "hf://user/id": [None, "user", "id"], + "hf://model/user/name": ["model", "user", "name"], # 's' is optional + "hf://models/user/name": ["model", "user", "name"], + }, + "self-hosted": { + "http://localhost:8080/hf/user/id": [None, "user", "id"], + "http://localhost:8080/hf/datasets/user/id": ["dataset", "user", "id"], + "http://localhost:8080/hf/models/user/id": ["model", "user", "id"], + }, } for key, value in possible_values.items(): - self.assertEqual( - repo_type_and_id_from_hf_id(key, hub_url=ENDPOINT_PRODUCTION), - tuple(value), - ) + hub_url = ENDPOINT_PRODUCTION if key == "hub" else "http://localhost:8080/hf" + for key, value in value.items(): + assert repo_type_and_id_from_hf_id(key, hub_url=hub_url) == tuple(value) def test_repo_type_and_id_from_hf_id_on_wrong_values(self): for hub_id in [