Skip to content

Commit afb1cb1

Browse files
authored
Fix tf dataset detection logic. (#21794)
In some contexts, matching the begining of the module name won't work, we need to match the tensorflow module in any part of the module name. This was already done in other contexts: https://github.com/keras-team/keras/blob/0512fdb3c70c6ef5499615793e40d5ad46f3b301/keras/src/trainers/data_adapters/data_adapter_utils.py#L307
1 parent 0512fdb commit afb1cb1

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

keras/src/utils/dataset_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ def is_tf_dataset(dataset):
461461
return _mro_matches(
462462
dataset,
463463
class_names=("DatasetV2", "Dataset"),
464-
module_prefixes=(
464+
module_substrings=(
465465
"tensorflow.python.data", # TF classic
466466
"tensorflow.data", # newer TF paths
467467
),
@@ -480,14 +480,18 @@ def is_torch_dataset(dataset):
480480
return _mro_matches(dataset, ("Dataset",), ("torch.utils.data",))
481481

482482

483-
def _mro_matches(dataset, class_names, module_prefixes):
483+
def _mro_matches(
484+
dataset, class_names, module_prefixes=(), module_substrings=()
485+
):
484486
if not hasattr(dataset, "__class__"):
485487
return False
486488
for parent in dataset.__class__.__mro__:
487489
if parent.__name__ in class_names:
488490
mod = str(parent.__module__)
489491
if any(mod.startswith(pref) for pref in module_prefixes):
490492
return True
493+
if any(subs in mod for subs in module_substrings):
494+
return True
491495
return False
492496

493497

0 commit comments

Comments
 (0)