Skip to content

Commit 13ed689

Browse files
committed
Fix tests and format code
1 parent bb4a627 commit 13ed689

File tree

1 file changed

+54
-34
lines changed

1 file changed

+54
-34
lines changed

src/huggingface_hub/hf_api.py

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -236,61 +236,81 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tu
236236
"""
237237
input_hf_id = hf_id
238238

239-
hub_url = hub_url or constants.ENDPOINT
240-
hub_url_no_proto = re.sub(r"^https?://", "", hub_url).rstrip("/")
239+
# Get the hub_url (with or without protocol)
240+
full_hub_url = hub_url if hub_url is not None else constants.ENDPOINT
241+
hub_url_without_protocol = re.sub(r"https?://", "", full_hub_url)
241242

242-
hf_id_no_proto = re.sub(r"^https?://", "", hf_id)
243-
244-
is_hf_url = hf_id_no_proto.startswith(hub_url_no_proto) and "@" not in hf_id
245-
246-
if is_hf_url:
247-
hf_id = hf_id_no_proto[len(hub_url_no_proto) :].lstrip("/")
243+
# Check if hf_id is a URL containing the hub_url (check both with and without protocol)
244+
hf_id_without_protocol = re.sub(r"https?://", "", hf_id)
245+
is_hf_url = hub_url_without_protocol in hf_id_without_protocol and "@" not in hf_id
248246

249247
HFFS_PREFIX = "hf://"
250248
if hf_id.startswith(HFFS_PREFIX): # Remove "hf://" prefix if exists
251249
hf_id = hf_id[len(HFFS_PREFIX) :]
252250

253-
url_segments = [s for s in hf_id.split("/") if s]
254-
seg_len = len(url_segments)
251+
# If it's a URL, strip the endpoint prefix to get the path
252+
if is_hf_url:
253+
# Remove protocol if present
254+
hf_id_normalized = re.sub(r"https?://", "", hf_id)
255255

256-
repo_type: Optional[str] = None
257-
namespace: Optional[str] = None
258-
repo_id: str
256+
# Remove the hub_url prefix to get the relative path
257+
if hf_id_normalized.startswith(hub_url_without_protocol):
258+
# Strip the hub URL and any leading slashes
259+
hf_id = hf_id_normalized[len(hub_url_without_protocol) :].lstrip("/")
259260

261+
url_segments = hf_id.split("/")
262+
is_hf_id = len(url_segments) <= 3
263+
264+
namespace: Optional[str]
260265
if is_hf_url:
261-
if seg_len == 1:
262-
repo_id = url_segments[0]
263-
namespace = None
264-
repo_type = None
265-
elif seg_len == 2:
266-
namespace, repo_id = url_segments
267-
repo_type = None
268-
else:
269-
namespace, repo_id = url_segments[-2:]
270-
repo_type = url_segments[-3] if seg_len >= 3 else None
266+
# For URLs, we need to extract repo_type, namespace, repo_id
267+
# Expected format after stripping endpoint: [repo_type]/namespace/repo_id or namespace/repo_id
268+
269+
if len(url_segments) >= 3:
270+
# Check if first segment is a repo type
271+
if url_segments[0] in constants.REPO_TYPES_MAPPING:
272+
repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]]
273+
namespace = url_segments[1]
274+
repo_id = url_segments[2]
275+
else:
276+
# First segment is namespace
277+
namespace = url_segments[0]
278+
repo_id = url_segments[1]
279+
repo_type = None
280+
elif len(url_segments) == 2:
281+
namespace = url_segments[0]
282+
repo_id = url_segments[1]
283+
284+
# Check if namespace is actually a repo type mapping
271285
if namespace in constants.REPO_TYPES_MAPPING:
272-
# canonical dataset/model
286+
# Mean canonical dataset or model
273287
repo_type = constants.REPO_TYPES_MAPPING[namespace]
274288
namespace = None
275-
276-
elif seg_len <= 3:
277-
if seg_len == 3:
289+
else:
290+
repo_type = None
291+
else:
292+
# Single segment
293+
repo_id = url_segments[0]
294+
namespace = None
295+
repo_type = None
296+
elif is_hf_id:
297+
if len(url_segments) == 3:
278298
# Passed <repo_type>/<user>/<model_id> or <repo_type>/<org>/<model_id>
279-
repo_type, namespace, repo_id = url_segments
280-
elif seg_len == 2:
299+
repo_type, namespace, repo_id = url_segments[-3:]
300+
elif len(url_segments) == 2:
281301
if url_segments[0] in constants.REPO_TYPES_MAPPING:
282302
# Passed '<model_id>' or 'datasets/<dataset_id>' for a canonical model or dataset
283303
repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]]
284304
namespace = None
285-
repo_id = url_segments[1]
305+
repo_id = hf_id.split("/")[-1]
286306
else:
287307
# Passed <user>/<model_id> or <org>/<model_id>
288-
namespace, repo_id = url_segments
308+
namespace, repo_id = hf_id.split("/")[-2:]
289309
repo_type = None
290310
else:
311+
# Passed <model_id>
291312
repo_id = url_segments[0]
292-
namespace = None
293-
repo_type = None
313+
namespace, repo_type = None, None
294314
else:
295315
raise ValueError(f"Unable to retrieve user and repo ID from the passed HF ID: {hf_id}")
296316

@@ -299,7 +319,7 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tu
299319
repo_type = constants.REPO_TYPES_MAPPING[repo_type]
300320
if repo_type == "":
301321
repo_type = None
302-
if repo_type not in constants.REPO_TYPES and repo_type is not None:
322+
if repo_type not in constants.REPO_TYPES:
303323
raise ValueError(f"Unknown `repo_type`: '{repo_type}' ('{input_hf_id}')")
304324

305325
return repo_type, namespace, repo_id

0 commit comments

Comments
 (0)