Skip to content
15 changes: 14 additions & 1 deletion keras/src/layers/preprocessing/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.data_layer import DataLayer
from keras.src.utils.module_utils import tensorflow as tf
from keras.utils import PyDataset


@keras_export("keras.layers.Normalization")
Expand Down Expand Up @@ -229,6 +230,18 @@ def adapt(self, data):
# Batch dataset if it isn't batched
data = data.batch(128)
input_shape = tuple(data.element_spec.shape)
elif isinstance(data, PyDataset):
data = data[0]
if isinstance(data, tuple):
# handling (x, y) or (x, y, sample_weight)
data = data[0]
input_shape = data.shape
else:
raise TypeError(
f"Unsupported data type: {type(data)}. `adapt` supports "
f"`np.ndarray`, backend tensors, `tf.data.Dataset`, and "
f"`keras.utils.PyDataset`."
)

if not self.built:
self.build(input_shape)
Expand All @@ -248,7 +261,7 @@ def adapt(self, data):
elif backend.is_tensor(data):
total_mean = ops.mean(data, axis=self._reduce_axis)
total_var = ops.var(data, axis=self._reduce_axis)
elif isinstance(data, tf.data.Dataset):
elif isinstance(data, (tf.data.Dataset, PyDataset)):
total_mean = ops.zeros(self._mean_and_var_shape)
total_var = ops.zeros(self._mean_and_var_shape)
total_count = 0
Expand Down
32 changes: 32 additions & 0 deletions keras/src/layers/preprocessing/normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,35 @@ def test_normalization_with_scalar_mean_var(self):
input_data = np.array([[1, 2, 3]], dtype="float32")
layer = layers.Normalization(mean=3.0, variance=2.0)
layer(input_data)

@parameterized.parameters([("x",), ("x_and_y",), ("x_y_and_weights",)])
def test_adapt_pydataset_compat(self, pydataset_type):
import keras

class CustomDataset(keras.utils.PyDataset):
def __len__(self):
return 100

def __getitem__(self, idx):
x = np.random.rand(32, 32, 3)
y = np.random.randint(0, 10, size=(1,))
weights = np.random.randint(0, 10, size=(1,))
if pydataset_type == "x":
return x
elif pydataset_type == "x_and_y":
return x, y
elif pydataset_type == "x_y_and_weights":
return x, y, weights
else:
raise NotImplementedError(pydataset_type)

normalizer = keras.layers.Normalization()
normalizer.adapt(CustomDataset())
self.assertTrue(normalizer.built)
self.assertIsNotNone(normalizer.mean)
self.assertIsNotNone(normalizer.variance)
self.assertEqual(normalizer.mean.shape[-1], 3)
self.assertEqual(normalizer.variance.shape[-1], 3)
sample_input = np.random.rand(1, 32, 32, 3)
output = normalizer(sample_input)
self.assertEqual(output.shape, (1, 32, 32, 3))
Loading