Skip to content

Commit 5509d98

Browse files
committed
support decollate for numpy scalars
fix linter Signed-off-by: Arthur Dujardin <arthurdujardin.dev@gmail.com> fix numpy decollate multi arrays Signed-off-by: Arthur Dujardin <arthurdujardin.dev@gmail.com> fix linter Signed-off-by: Arthur Dujardin <arthurdujardin.dev@gmail.com> fix numpy scalar support Signed-off-by: Arthur Dujardin <arthurdujardin.dev@gmail.com>
1 parent c3a317d commit 5509d98

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

monai/data/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -621,10 +621,13 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
621621
"""
622622
if batch is None:
623623
return batch
624-
if isinstance(batch, (float, int, str, bytes)) or (
625-
type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
626-
):
624+
if isinstance(batch, (float, int, str, bytes)):
627625
return batch
626+
if type(batch).__module__ == "numpy":
627+
if not isinstance(batch, Iterable):
628+
return batch
629+
if batch.ndim == 0:
630+
return batch.item() if detach else batch
628631
if isinstance(batch, torch.Tensor):
629632
if detach:
630633
batch = batch.detach()

0 commit comments

Comments
 (0)