From 7d9dac9a8e0c572b15d1edff8aa502f812bfc975 Mon Sep 17 00:00:00 2001 From: lijialin03 Date: Tue, 4 Nov 2025 12:45:52 +0000 Subject: [PATCH 1/6] [API compatibility] update paddle LayerNorm api --- python/paddle/nn/layer/norm.py | 61 +++- .../test_layer_norm_op_v2_dygraph.py | 296 ++++++++++++++++++ 2 files changed, 352 insertions(+), 5 deletions(-) create mode 100644 test/legacy_test/test_layer_norm_op_v2_dygraph.py diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 27d0c672d2490a..7d89f34ff98c14 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -22,6 +22,7 @@ from paddle import _C_ops, in_dynamic_mode, pir_utils from paddle.device import get_all_custom_device_type +from paddle.utils.decorator_utils import param_one_alias from ...base import dygraph_utils from ...base.data_feeder import check_variable_and_dtype @@ -51,6 +52,7 @@ DataLayoutND, DTypeLike, ParamAttrLike, + PlaceLike, ShapeLike, ) @@ -589,13 +591,34 @@ class LayerNorm(Layer): which is expected to be of that specific size. epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-05. + alias: ``eps``. + elementwise_affine(bool, optional): Whether to apply element-wise affine transformation + (i.e., learnable scale and bias). If set to ``False``, both the scale (:math:`g`) and + bias (:math:`b`) parameters will be disabled, regardless of the settings of `weight_attr` + and `bias_attr`. This parameter acts as a master switch. Defaults to True. + **Note: This argument must be passed as a keyword argument.** + bias(bool, optional): Whether to include a learnable bias term in the layer. This setting + only takes effect when `elementwise_affine` is ``True``. If set to ``False``, no bias + parameter will be created, even if `bias_attr` is specified. Defaults to True. + **Note: This argument must be passed as a keyword argument.** weight_attr(ParamAttr|bool|None, optional): The parameter attribute for the learnable - gain :math:`g`. If False, weight is None. If is None, a default :code:`ParamAttr` would be added as scale. The - :attr:`param_attr` is initialized as 1 if it is added. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` . + gain :math:`g` (scale). This setting only takes effect when `elementwise_affine` is ``True``. + - If set to ``False``, no gain parameter will be created. + - If set to ``None`` or ``True``, a default :code:`ParamAttr` will be used, and the + parameter will be initialized to 1. + - If set to a custom :code:`ParamAttr` object, it will be used to configure the parameter. + Default: None. + **Note: This argument must be passed as a keyword argument.** bias_attr(ParamAttr|bool|None, optional): The parameter attribute for the learnable - bias :math:`b`. If is False, bias is None. If is None, a default :code:`ParamAttr` would be added as bias. The - :attr:`bias_attr` is initialized as 0 if it is added. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` . + bias :math:`b`. This setting only takes effect when both `elementwise_affine` and `bias` are ``True``. + - If set to ``False``, no bias parameter will be created. + - If set to ``None`` or ``True``, a default :code:`ParamAttr` will be used, and the + parameter will be initialized to 0. + - If set to a custom :code:`ParamAttr` object, it will be used to configure the parameter. + Default: None. 
+ **Note: This argument must be passed as a keyword argument.** name(str|None, optional): Name for the LayerNorm, default is None. For more information, please refer to :ref:`api_guide_Name` . + **Note: This argument must be passed as a keyword argument.** Shape: - x: 2-D, 3-D, 4-D or 5-D tensor. @@ -629,10 +652,16 @@ class LayerNorm(Layer): weight: Tensor | None bias: Tensor | None + @param_one_alias(["epsilon", "eps"]) def __init__( self, normalized_shape: int | Sequence[int], epsilon: float = 1e-5, + *, + elementwise_affine: bool = True, + bias: bool = True, + device: PlaceLike | None = None, + dtype: DTypeLike | None = None, weight_attr: bool | ParamAttr | None = None, bias_attr: bool | ParamAttr | None = None, name: str | None = None, @@ -643,6 +672,21 @@ def __init__( self._normalized_shape = list(normalized_shape) self._epsilon = epsilon + self._device = device + self._dtype = ( + self._helper.get_default_dtype() if dtype is None else dtype + ) + + if not elementwise_affine: + weight_attr = False + bias_attr = False + else: + weight_attr = weight_attr if weight_attr is not False else None + if not bias: + bias_attr = False + else: + bias_attr = bias_attr if bias_attr is not False else None + self._weight_attr = weight_attr self._bias_attr = bias_attr param_shape = [np.prod(self._normalized_shape)] @@ -652,15 +696,22 @@ def __init__( else: self.weight = self.create_parameter( attr=self._weight_attr, + dtype=self._dtype, shape=param_shape, default_initializer=Constant(1.0), + device=self._device, ) if bias_attr is False: self.bias = None else: self.bias = self.create_parameter( - attr=self._bias_attr, shape=param_shape, is_bias=True + attr=self._bias_attr, + dtype=self._dtype, + shape=param_shape, + default_initializer=Constant(0.0), + device=self._device, + is_bias=True, ) def forward(self, input: Tensor) -> Tensor: diff --git a/test/legacy_test/test_layer_norm_op_v2_dygraph.py b/test/legacy_test/test_layer_norm_op_v2_dygraph.py new file mode 100644 index 00000000000000..5e759e9b7a3ec0 --- /dev/null +++ b/test/legacy_test/test_layer_norm_op_v2_dygraph.py @@ -0,0 +1,296 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from functools import reduce +from operator import mul + +import numpy as np + +import paddle +from paddle import nn + + +def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1): + x_shape = x.shape + N = reduce(mul, x_shape[0:begin_norm_axis], 1) + D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1) + x.shape = [N, D] + + mean = np.mean(x, axis=1) + var = np.var(x, axis=1) + epsilon + output = np.divide( + (x - mean.reshape([N, 1])), (np.sqrt(var)).reshape([N, 1]) + ) + if scale is not None: + output = scale.reshape([1, D]) * output + if beta is not None: + output = output + beta.reshape([1, D]) + + x.shape, output.shape = x_shape, x_shape + return output + + +class TestLayerNormOp(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.x_shape = [2, 6, 6, 3] + self.epsilon = 1e-5 + self.begin_norm_axis = 1 + + @unittest.skipIf( + not paddle.in_dynamic_mode(), "test is only for dynamic mode" + ) + def test_basic_fp32(self): + """test basic functionality with float32.""" + x_np = np.random.random(self.x_shape).astype('float32') + scale_np = np.random.random( + self.x_shape[self.begin_norm_axis :] + ).astype('float32') + bias_np = np.random.random(self.x_shape[self.begin_norm_axis :]).astype( + 'float32' + ) + scale = paddle.to_tensor(scale_np).reshape(-1) + bias = paddle.to_tensor(bias_np).reshape(-1) + + ln = nn.LayerNorm( + normalized_shape=self.x_shape[self.begin_norm_axis :], + weight_attr=nn.initializer.Assign(scale), + bias_attr=nn.initializer.Assign(bias), + epsilon=self.epsilon, + ) + + x_pd = paddle.to_tensor(x_np) + y_pd = ln(x_pd) + expect_res = _reference_layer_norm_naive( + x_np, scale_np, bias_np, self.epsilon, self.begin_norm_axis + ) + + np.testing.assert_allclose( + y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 + ) + + @unittest.skipIf( + not paddle.in_dynamic_mode(), "test is only for dynamic mode" + ) + def test_no_scale_no_bias_fp32(self): + """test the case when both scale and bias are disabled (FP32).""" + x_np = np.random.random(self.x_shape).astype('float32') + x_pd = paddle.to_tensor(x_np) + + ln = nn.LayerNorm( + normalized_shape=self.x_shape[self.begin_norm_axis :], + elementwise_affine=False, + epsilon=self.epsilon, + ) + y_pd = ln(x_pd) + + expect_res = _reference_layer_norm_naive( + x_np, None, None, self.epsilon, self.begin_norm_axis + ) + np.testing.assert_allclose( + y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 + ) + + @unittest.skipIf( + not paddle.in_dynamic_mode(), "test is only for dynamic mode" + ) + def test_with_scale_no_bias_fp32(self): + """test the case when only scale is enabled (FP32).""" + x_np = np.random.random(self.x_shape).astype('float32') + scale_np = np.random.random( + self.x_shape[self.begin_norm_axis :] + ).astype('float32') + scale = paddle.to_tensor(scale_np).reshape(-1) + + ln = nn.LayerNorm( + normalized_shape=self.x_shape[self.begin_norm_axis :], + elementwise_affine=True, + bias_attr=False, + epsilon=self.epsilon, + ) + with paddle.no_grad(): + ln.weight.set_value(scale) + + x_pd = paddle.to_tensor(x_np) + y_pd = ln(x_pd) + + expect_res = _reference_layer_norm_naive( + x_np, scale_np, None, self.epsilon, self.begin_norm_axis + ) + np.testing.assert_allclose( + y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 + ) + + @unittest.skipIf( + not paddle.in_dynamic_mode(), "test is only for dynamic mode" + ) + def test_no_scale_with_bias_fp32(self): + """test the case when only bias is enabled (FP32).""" + x_np = np.random.random(self.x_shape).astype('float32') + bias_np = 
np.random.random(self.x_shape[self.begin_norm_axis :]).astype( + 'float32' + ) + bias = paddle.to_tensor(bias_np).reshape(-1) + + ln = nn.LayerNorm( + normalized_shape=self.x_shape[self.begin_norm_axis :], + elementwise_affine=True, + weight_attr=False, + epsilon=self.epsilon, + ) + with paddle.no_grad(): + ln.bias.set_value(bias) + + x_pd = paddle.to_tensor(x_np) + y_pd = ln(x_pd) + + expect_res = _reference_layer_norm_naive( + x_np, None, bias_np, self.epsilon, self.begin_norm_axis + ) + np.testing.assert_allclose( + y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 + ) + + def test_bf16_forward_backward(self): + """test forward and backward pass with bfloat16 precision.""" + place = paddle.CUDAPlace(0) + + with paddle.base.dygraph.guard(place): + x_np = np.random.random(self.x_shape).astype('float32') + scale_np = np.random.random( + self.x_shape[self.begin_norm_axis :] + ).astype('float32') + bias_np = np.random.random( + self.x_shape[self.begin_norm_axis :] + ).astype('float32') + + x = paddle.to_tensor(x_np).cast(paddle.bfloat16) + x.stop_gradient = False + + scale = paddle.to_tensor(scale_np).cast(paddle.bfloat16).reshape(-1) + bias = paddle.to_tensor(bias_np).cast(paddle.bfloat16).reshape(-1) + + ln = nn.LayerNorm( + normalized_shape=self.x_shape[self.begin_norm_axis :], + weight_attr=nn.initializer.Assign(scale), + bias_attr=nn.initializer.Assign(bias), + epsilon=self.epsilon, + ) + ln.to(device='cuda') + + y = ln(x) + loss = y.sum() + loss.backward() + + self.assertIsNotNone(x.grad) + self.assertIsNotNone(ln.weight.grad) + self.assertIsNotNone(ln.bias.grad) + + +class TestLayerNormParam(unittest.TestCase): + def setUp(self): + self.normalized_shape = [6] + self.x_tensor = paddle.randn([2, 4, 4, 6]) + + def test_elementwise_affine_false(self): + """test that when elementwise_affine=False, no learnable parameters are created.""" + layer = nn.LayerNorm( + normalized_shape=self.normalized_shape, elementwise_affine=False + ) + self.assertIsNone(layer.weight) + self.assertIsNone(layer.bias) + + out = layer(self.x_tensor) + self.assertEqual(out.shape, self.x_tensor.shape) + + @unittest.skipIf( + not paddle.in_dynamic_mode(), "test is only for dynamic mode" + ) + def test_elementwise_affine_true(self): + """test that when elementwise_affine=True and attr=None, parameters are created with default initialization.""" + layer = nn.LayerNorm( + normalized_shape=self.normalized_shape, elementwise_affine=True + ) + self.assertIsNotNone(layer.weight) + self.assertIsNotNone(layer.bias) + + expected_weight = paddle.ones([6]) + expected_bias = paddle.zeros([6]) + self.assertTrue(paddle.allclose(layer.weight, expected_weight)) + self.assertTrue(paddle.allclose(layer.bias, expected_bias)) + + def test_bias_false(self): + """test that when bias=False, the bias parameter is disabled even if elementwise_affine=True.""" + layer = nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + bias=False, + ) + self.assertIsNotNone(layer.weight) + self.assertIsNone(layer.bias) + + @unittest.skipIf( + not paddle.in_dynamic_mode(), "test is only for dynamic mode" + ) + def test_attr_custom_initialization(self): + """test that weight_attr and bias_attr can be used to customize the initialization of the weight parameter.""" + weight_attr = paddle.nn.initializer.Constant(value=2.0) + bias_attr = paddle.nn.initializer.Constant(value=3.0) + layer = nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + weight_attr=weight_attr, + bias_attr=bias_attr, + ) + + 
expected_weight = paddle.full([6], 2.0) + expected_bias = paddle.full([6], 3.0) + self.assertTrue(paddle.allclose(layer.weight, expected_weight)) + self.assertTrue(paddle.allclose(layer.bias, expected_bias)) + + def test_alias(self): + """test parameter alias epsilon/eps""" + layer_epsilon = nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + epsilon=1e-5, + ) + layer_eps = nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + eps=1e-5, + ) + + out_epsilon = layer_epsilon(self.x_tensor) + out_eps = layer_eps(self.x_tensor) + + np.testing.assert_array_equal(out_epsilon.numpy(), out_eps.numpy()) + + def test_errors(self): + """test for errors.""" + layer_norm = nn.LayerNorm(self.normalized_shape) + x1 = np.random.random([3, *self.normalized_shape]).astype('float32') + with self.assertRaises(ValueError): + layer_norm(x1) + with self.assertRaises(TypeError): + nn.LayerNorm(self.normalized_shape, 1e-5, None, None, "name") + with self.assertRaises(TypeError): + nn.LayerNorm( + self.normalized_shape, 1e-5, False, "cpu", paddle.float32 + ) + + +if __name__ == '__main__': + unittest.main() From ba68146374135923f6a7385df331337b61d7bf26 Mon Sep 17 00:00:00 2001 From: lijialin03 Date: Mon, 10 Nov 2025 07:19:59 +0000 Subject: [PATCH 2/6] update:modify tests to fit static mode --- .../test_layer_norm_op_v2_dygraph.py | 360 +++++++++--------- 1 file changed, 182 insertions(+), 178 deletions(-) diff --git a/test/legacy_test/test_layer_norm_op_v2_dygraph.py b/test/legacy_test/test_layer_norm_op_v2_dygraph.py index 5e759e9b7a3ec0..a6d8b9dd00f357 100644 --- a/test/legacy_test/test_layer_norm_op_v2_dygraph.py +++ b/test/legacy_test/test_layer_norm_op_v2_dygraph.py @@ -17,6 +17,7 @@ from operator import mul import numpy as np +from op_test import get_places import paddle from paddle import nn @@ -48,161 +49,161 @@ def setUp(self): self.x_shape = [2, 6, 6, 3] self.epsilon = 1e-5 self.begin_norm_axis = 1 + self.places = get_places() - @unittest.skipIf( - not paddle.in_dynamic_mode(), "test is only for dynamic mode" - ) def test_basic_fp32(self): """test basic functionality with float32.""" - x_np = np.random.random(self.x_shape).astype('float32') - scale_np = np.random.random( - self.x_shape[self.begin_norm_axis :] - ).astype('float32') - bias_np = np.random.random(self.x_shape[self.begin_norm_axis :]).astype( - 'float32' - ) - scale = paddle.to_tensor(scale_np).reshape(-1) - bias = paddle.to_tensor(bias_np).reshape(-1) - - ln = nn.LayerNorm( - normalized_shape=self.x_shape[self.begin_norm_axis :], - weight_attr=nn.initializer.Assign(scale), - bias_attr=nn.initializer.Assign(bias), - epsilon=self.epsilon, - ) - - x_pd = paddle.to_tensor(x_np) - y_pd = ln(x_pd) - expect_res = _reference_layer_norm_naive( - x_np, scale_np, bias_np, self.epsilon, self.begin_norm_axis - ) - - np.testing.assert_allclose( - y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 - ) + for place in self.places: + with paddle.base.dygraph.guard(place): + x_np = np.random.random(self.x_shape).astype('float32') + scale_np = np.random.random( + self.x_shape[self.begin_norm_axis :] + ).astype('float32') + bias_np = np.random.random( + self.x_shape[self.begin_norm_axis :] + ).astype('float32') + scale = paddle.to_tensor(scale_np).reshape(-1) + bias = paddle.to_tensor(bias_np).reshape(-1) + + ln = nn.LayerNorm( + normalized_shape=self.x_shape[self.begin_norm_axis :], + weight_attr=nn.initializer.Assign(scale), + bias_attr=nn.initializer.Assign(bias), + epsilon=self.epsilon, + ) + + x_pd = 
paddle.to_tensor(x_np) + y_pd = ln(x_pd) + expect_res = _reference_layer_norm_naive( + x_np, scale_np, bias_np, self.epsilon, self.begin_norm_axis + ) + + np.testing.assert_allclose( + y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 + ) - @unittest.skipIf( - not paddle.in_dynamic_mode(), "test is only for dynamic mode" - ) def test_no_scale_no_bias_fp32(self): """test the case when both scale and bias are disabled (FP32).""" - x_np = np.random.random(self.x_shape).astype('float32') - x_pd = paddle.to_tensor(x_np) - - ln = nn.LayerNorm( - normalized_shape=self.x_shape[self.begin_norm_axis :], - elementwise_affine=False, - epsilon=self.epsilon, - ) - y_pd = ln(x_pd) - - expect_res = _reference_layer_norm_naive( - x_np, None, None, self.epsilon, self.begin_norm_axis - ) - np.testing.assert_allclose( - y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 - ) + for place in self.places: + with paddle.base.dygraph.guard(place): + x_np = np.random.random(self.x_shape).astype('float32') + x_pd = paddle.to_tensor(x_np) + + ln = nn.LayerNorm( + normalized_shape=self.x_shape[self.begin_norm_axis :], + elementwise_affine=False, + epsilon=self.epsilon, + ) + y_pd = ln(x_pd) + + expect_res = _reference_layer_norm_naive( + x_np, None, None, self.epsilon, self.begin_norm_axis + ) + np.testing.assert_allclose( + y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 + ) - @unittest.skipIf( - not paddle.in_dynamic_mode(), "test is only for dynamic mode" - ) def test_with_scale_no_bias_fp32(self): """test the case when only scale is enabled (FP32).""" - x_np = np.random.random(self.x_shape).astype('float32') - scale_np = np.random.random( - self.x_shape[self.begin_norm_axis :] - ).astype('float32') - scale = paddle.to_tensor(scale_np).reshape(-1) - - ln = nn.LayerNorm( - normalized_shape=self.x_shape[self.begin_norm_axis :], - elementwise_affine=True, - bias_attr=False, - epsilon=self.epsilon, - ) - with paddle.no_grad(): - ln.weight.set_value(scale) - - x_pd = paddle.to_tensor(x_np) - y_pd = ln(x_pd) - - expect_res = _reference_layer_norm_naive( - x_np, scale_np, None, self.epsilon, self.begin_norm_axis - ) - np.testing.assert_allclose( - y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 - ) + for place in self.places: + with paddle.base.dygraph.guard(place): + x_np = np.random.random(self.x_shape).astype('float32') + scale_np = np.random.random( + self.x_shape[self.begin_norm_axis :] + ).astype('float32') + scale = paddle.to_tensor(scale_np).reshape(-1) + + ln = nn.LayerNorm( + normalized_shape=self.x_shape[self.begin_norm_axis :], + elementwise_affine=True, + bias_attr=False, + epsilon=self.epsilon, + ) + with paddle.no_grad(): + ln.weight.set_value(scale) + + x_pd = paddle.to_tensor(x_np) + y_pd = ln(x_pd) + + expect_res = _reference_layer_norm_naive( + x_np, scale_np, None, self.epsilon, self.begin_norm_axis + ) + np.testing.assert_allclose( + y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 + ) - @unittest.skipIf( - not paddle.in_dynamic_mode(), "test is only for dynamic mode" - ) def test_no_scale_with_bias_fp32(self): """test the case when only bias is enabled (FP32).""" - x_np = np.random.random(self.x_shape).astype('float32') - bias_np = np.random.random(self.x_shape[self.begin_norm_axis :]).astype( - 'float32' - ) - bias = paddle.to_tensor(bias_np).reshape(-1) - - ln = nn.LayerNorm( - normalized_shape=self.x_shape[self.begin_norm_axis :], - elementwise_affine=True, - weight_attr=False, - epsilon=self.epsilon, - ) - with paddle.no_grad(): - ln.bias.set_value(bias) - - x_pd = paddle.to_tensor(x_np) - y_pd = ln(x_pd) - 
- expect_res = _reference_layer_norm_naive( - x_np, None, bias_np, self.epsilon, self.begin_norm_axis - ) - np.testing.assert_allclose( - y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 - ) + for place in self.places: + with paddle.base.dygraph.guard(place): + x_np = np.random.random(self.x_shape).astype('float32') + bias_np = np.random.random( + self.x_shape[self.begin_norm_axis :] + ).astype('float32') + bias = paddle.to_tensor(bias_np).reshape(-1) + + ln = nn.LayerNorm( + normalized_shape=self.x_shape[self.begin_norm_axis :], + elementwise_affine=True, + weight_attr=False, + epsilon=self.epsilon, + ) + with paddle.no_grad(): + ln.bias.set_value(bias) + + x_pd = paddle.to_tensor(x_np) + y_pd = ln(x_pd) + + expect_res = _reference_layer_norm_naive( + x_np, None, bias_np, self.epsilon, self.begin_norm_axis + ) + np.testing.assert_allclose( + y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 + ) def test_bf16_forward_backward(self): """test forward and backward pass with bfloat16 precision.""" - place = paddle.CUDAPlace(0) - - with paddle.base.dygraph.guard(place): - x_np = np.random.random(self.x_shape).astype('float32') - scale_np = np.random.random( - self.x_shape[self.begin_norm_axis :] - ).astype('float32') - bias_np = np.random.random( - self.x_shape[self.begin_norm_axis :] - ).astype('float32') - - x = paddle.to_tensor(x_np).cast(paddle.bfloat16) - x.stop_gradient = False - - scale = paddle.to_tensor(scale_np).cast(paddle.bfloat16).reshape(-1) - bias = paddle.to_tensor(bias_np).cast(paddle.bfloat16).reshape(-1) - - ln = nn.LayerNorm( - normalized_shape=self.x_shape[self.begin_norm_axis :], - weight_attr=nn.initializer.Assign(scale), - bias_attr=nn.initializer.Assign(bias), - epsilon=self.epsilon, - ) - ln.to(device='cuda') - - y = ln(x) - loss = y.sum() - loss.backward() - - self.assertIsNotNone(x.grad) - self.assertIsNotNone(ln.weight.grad) - self.assertIsNotNone(ln.bias.grad) + for place in self.places: + with paddle.base.dygraph.guard(place): + x_np = np.random.random(self.x_shape).astype('float32') + scale_np = np.random.random( + self.x_shape[self.begin_norm_axis :] + ).astype('float32') + bias_np = np.random.random( + self.x_shape[self.begin_norm_axis :] + ).astype('float32') + + x = paddle.to_tensor(x_np).cast(paddle.bfloat16) + x.stop_gradient = False + + scale = ( + paddle.to_tensor(scale_np).cast(paddle.bfloat16).reshape(-1) + ) + bias = ( + paddle.to_tensor(bias_np).cast(paddle.bfloat16).reshape(-1) + ) + + ln = nn.LayerNorm( + normalized_shape=self.x_shape[self.begin_norm_axis :], + weight_attr=nn.initializer.Assign(scale), + bias_attr=nn.initializer.Assign(bias), + epsilon=self.epsilon, + ) + + y = ln(x) + loss = y.sum() + loss.backward() + + self.assertIsNotNone(x.grad) + self.assertIsNotNone(ln.weight.grad) + self.assertIsNotNone(ln.bias.grad) class TestLayerNormParam(unittest.TestCase): def setUp(self): self.normalized_shape = [6] self.x_tensor = paddle.randn([2, 4, 4, 6]) + self.places = get_places() def test_elementwise_affine_false(self): """test that when elementwise_affine=False, no learnable parameters are created.""" @@ -215,21 +216,21 @@ def test_elementwise_affine_false(self): out = layer(self.x_tensor) self.assertEqual(out.shape, self.x_tensor.shape) - @unittest.skipIf( - not paddle.in_dynamic_mode(), "test is only for dynamic mode" - ) def test_elementwise_affine_true(self): """test that when elementwise_affine=True and attr=None, parameters are created with default initialization.""" - layer = nn.LayerNorm( - normalized_shape=self.normalized_shape, 
elementwise_affine=True - ) - self.assertIsNotNone(layer.weight) - self.assertIsNotNone(layer.bias) - - expected_weight = paddle.ones([6]) - expected_bias = paddle.zeros([6]) - self.assertTrue(paddle.allclose(layer.weight, expected_weight)) - self.assertTrue(paddle.allclose(layer.bias, expected_bias)) + for place in self.places: + with paddle.base.dygraph.guard(place): + layer = nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + ) + self.assertIsNotNone(layer.weight) + self.assertIsNotNone(layer.bias) + + expected_weight = paddle.ones([6]) + expected_bias = paddle.zeros([6]) + self.assertTrue(paddle.allclose(layer.weight, expected_weight)) + self.assertTrue(paddle.allclose(layer.bias, expected_bias)) def test_bias_false(self): """test that when bias=False, the bias parameter is disabled even if elementwise_affine=True.""" @@ -241,42 +242,45 @@ def test_bias_false(self): self.assertIsNotNone(layer.weight) self.assertIsNone(layer.bias) - @unittest.skipIf( - not paddle.in_dynamic_mode(), "test is only for dynamic mode" - ) def test_attr_custom_initialization(self): """test that weight_attr and bias_attr can be used to customize the initialization of the weight parameter.""" - weight_attr = paddle.nn.initializer.Constant(value=2.0) - bias_attr = paddle.nn.initializer.Constant(value=3.0) - layer = nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - weight_attr=weight_attr, - bias_attr=bias_attr, - ) - - expected_weight = paddle.full([6], 2.0) - expected_bias = paddle.full([6], 3.0) - self.assertTrue(paddle.allclose(layer.weight, expected_weight)) - self.assertTrue(paddle.allclose(layer.bias, expected_bias)) + for place in self.places: + with paddle.base.dygraph.guard(place): + weight_attr = paddle.nn.initializer.Constant(value=2.0) + bias_attr = paddle.nn.initializer.Constant(value=3.0) + layer = nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + weight_attr=weight_attr, + bias_attr=bias_attr, + ) + + expected_weight = paddle.full([6], 2.0) + expected_bias = paddle.full([6], 3.0) + self.assertTrue(paddle.allclose(layer.weight, expected_weight)) + self.assertTrue(paddle.allclose(layer.bias, expected_bias)) def test_alias(self): """test parameter alias epsilon/eps""" - layer_epsilon = nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - epsilon=1e-5, - ) - layer_eps = nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - eps=1e-5, - ) - - out_epsilon = layer_epsilon(self.x_tensor) - out_eps = layer_eps(self.x_tensor) - - np.testing.assert_array_equal(out_epsilon.numpy(), out_eps.numpy()) + for place in self.places: + with paddle.base.dygraph.guard(place): + layer_epsilon = nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + epsilon=1e-5, + ) + layer_eps = nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + eps=1e-5, + ) + + out_epsilon = layer_epsilon(self.x_tensor) + out_eps = layer_eps(self.x_tensor) + + np.testing.assert_array_equal( + out_epsilon.numpy(), out_eps.numpy() + ) def test_errors(self): """test for errors.""" From 6606a7a03b4bd026c8113cb295ca47044227219b Mon Sep 17 00:00:00 2001 From: lijialin03 Date: Mon, 10 Nov 2025 09:46:21 +0000 Subject: [PATCH 3/6] update:optimize tests --- test/legacy_test/test_layer_norm_op_v2.py | 160 ++++++++++ .../test_layer_norm_op_v2_dygraph.py | 300 ------------------ 2 files changed, 160 insertions(+), 300 deletions(-) delete 
mode 100644 test/legacy_test/test_layer_norm_op_v2_dygraph.py diff --git a/test/legacy_test/test_layer_norm_op_v2.py b/test/legacy_test/test_layer_norm_op_v2.py index a8bfec46252114..652d9d8194aa24 100644 --- a/test/legacy_test/test_layer_norm_op_v2.py +++ b/test/legacy_test/test_layer_norm_op_v2.py @@ -159,6 +159,166 @@ def compute_v4(x): ) +class TestLayerNormParam(unittest.TestCase): + def setUp(self): + self.normalized_shape = [6] + self.x_shape = [2, 4, 4, 6] + self.epsilon = 1e-5 + self.places = get_places() + + def test_elementwise_affine_false(self): + """test that when elementwise_affine=False, weight and bias parameters are not created.""" + for p in self.places: + with base.dygraph.guard(p): + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=False, + ) + self.assertIsNone( + layer.weight, + "Weight should be None when elementwise_affine=False", + ) + self.assertIsNone( + layer.bias, + "Bias should be None when elementwise_affine=False", + ) + + x_tensor = paddle.randn(self.x_shape) + out = layer(x_tensor) + self.assertEqual(out.shape, x_tensor.shape) + + def test_elementwise_affine_true(self): + """test that when elementwise_affine=True and attr=None, parameters are created with default initialization.""" + for place in self.places: + with paddle.base.dygraph.guard(place): + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + ) + self.assertIsNotNone( + layer.weight, + "Weight should not be None when elementwise_affine=True", + ) + self.assertIsNotNone( + layer.bias, + "Weight should not be None when elementwise_affine=True", + ) + + expected_weight = paddle.ones(self.normalized_shape) + expected_bias = paddle.zeros(self.normalized_shape) + + self.assertTrue(paddle.allclose(layer.weight, expected_weight)) + self.assertTrue(paddle.allclose(layer.bias, expected_bias)) + + def test_bias_false(self): + """test that when bias=False, the bias parameter is disabled even if elementwise_affine=True.""" + for p in self.places: + with base.dygraph.guard(p): + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + bias=False, + ) + self.assertIsNotNone( + layer.weight, + "Weight should exist when elementwise_affine=True", + ) + self.assertIsNone( + layer.bias, "Bias should be None when bias_attr=False" + ) + + def test_weight_and_bias_false(self): + """test that when weight_attr=False and bias_attr=False, both parameters are disabled.""" + for p in self.places: + with base.dygraph.guard(p): + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + weight_attr=False, + bias_attr=False, + ) + self.assertIsNotNone( + layer.weight, + "Weight should not be None when elementwise_affine=True although weight_attr=False", + ) + self.assertIsNotNone( + layer.bias, + "Bias should not be None when elementwise_affine=True although bias_attr=False", + ) + + def test_custom_initialization(self): + """test custom initialization using weight_attr and bias_attr.""" + for p in self.places: + with base.dygraph.guard(p): + weight_val = 2.5 + bias_val = -1.0 + weight_initializer = paddle.nn.initializer.Constant( + value=weight_val + ) + bias_initializer = paddle.nn.initializer.Constant( + value=bias_val + ) + + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + weight_attr=weight_initializer, + bias_attr=bias_initializer, + ) + + expected_weight = paddle.full( + self.normalized_shape, 
weight_val, dtype=layer.weight.dtype + ) + expected_bias = paddle.full( + self.normalized_shape, bias_val, dtype=layer.bias.dtype + ) + + self.assertTrue( + paddle.allclose(layer.weight, expected_weight), + f"Weight initialization failed. Got {layer.weight.numpy()}, expected {expected_weight.numpy()}", + ) + self.assertTrue( + paddle.allclose(layer.bias, expected_bias), + f"Bias initialization failed. Got {layer.bias.numpy()}, expected {expected_bias.numpy()}", + ) + + def test_alias(self): + """test parameter alias epsilon/eps""" + for place in self.places: + with paddle.base.dygraph.guard(place): + layer_epsilon = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + epsilon=1e-5, + ) + layer_eps = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + eps=1e-5, + ) + + x_tensor = paddle.randn(self.x_shape) + out_epsilon = layer_epsilon(x_tensor) + out_eps = layer_eps(x_tensor) + + np.testing.assert_array_equal( + out_epsilon.numpy(), out_eps.numpy() + ) + + def test_errors(self): + """test for errors.""" + layer_norm = paddle.nn.LayerNorm(self.normalized_shape) + x1 = np.random.random([3, *self.normalized_shape]).astype('float32') + with self.assertRaises(TypeError): + layer_norm(x1) + with self.assertRaises(TypeError): + paddle.nn.LayerNorm(self.normalized_shape, 1e-5, None, None, "name") + with self.assertRaises(TypeError): + paddle.nn.LayerNorm( + self.normalized_shape, 1e-5, False, "cpu", paddle.float32 + ) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/test/legacy_test/test_layer_norm_op_v2_dygraph.py b/test/legacy_test/test_layer_norm_op_v2_dygraph.py deleted file mode 100644 index a6d8b9dd00f357..00000000000000 --- a/test/legacy_test/test_layer_norm_op_v2_dygraph.py +++ /dev/null @@ -1,300 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest -from functools import reduce -from operator import mul - -import numpy as np -from op_test import get_places - -import paddle -from paddle import nn - - -def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1): - x_shape = x.shape - N = reduce(mul, x_shape[0:begin_norm_axis], 1) - D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1) - x.shape = [N, D] - - mean = np.mean(x, axis=1) - var = np.var(x, axis=1) + epsilon - output = np.divide( - (x - mean.reshape([N, 1])), (np.sqrt(var)).reshape([N, 1]) - ) - if scale is not None: - output = scale.reshape([1, D]) * output - if beta is not None: - output = output + beta.reshape([1, D]) - - x.shape, output.shape = x_shape, x_shape - return output - - -class TestLayerNormOp(unittest.TestCase): - def setUp(self): - paddle.disable_static() - self.x_shape = [2, 6, 6, 3] - self.epsilon = 1e-5 - self.begin_norm_axis = 1 - self.places = get_places() - - def test_basic_fp32(self): - """test basic functionality with float32.""" - for place in self.places: - with paddle.base.dygraph.guard(place): - x_np = np.random.random(self.x_shape).astype('float32') - scale_np = np.random.random( - self.x_shape[self.begin_norm_axis :] - ).astype('float32') - bias_np = np.random.random( - self.x_shape[self.begin_norm_axis :] - ).astype('float32') - scale = paddle.to_tensor(scale_np).reshape(-1) - bias = paddle.to_tensor(bias_np).reshape(-1) - - ln = nn.LayerNorm( - normalized_shape=self.x_shape[self.begin_norm_axis :], - weight_attr=nn.initializer.Assign(scale), - bias_attr=nn.initializer.Assign(bias), - epsilon=self.epsilon, - ) - - x_pd = paddle.to_tensor(x_np) - y_pd = ln(x_pd) - expect_res = _reference_layer_norm_naive( - x_np, scale_np, bias_np, self.epsilon, self.begin_norm_axis - ) - - np.testing.assert_allclose( - y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 - ) - - def test_no_scale_no_bias_fp32(self): - """test the case when both scale and bias are disabled (FP32).""" - for place in self.places: - with paddle.base.dygraph.guard(place): - x_np = np.random.random(self.x_shape).astype('float32') - x_pd = paddle.to_tensor(x_np) - - ln = nn.LayerNorm( - normalized_shape=self.x_shape[self.begin_norm_axis :], - elementwise_affine=False, - epsilon=self.epsilon, - ) - y_pd = ln(x_pd) - - expect_res = _reference_layer_norm_naive( - x_np, None, None, self.epsilon, self.begin_norm_axis - ) - np.testing.assert_allclose( - y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 - ) - - def test_with_scale_no_bias_fp32(self): - """test the case when only scale is enabled (FP32).""" - for place in self.places: - with paddle.base.dygraph.guard(place): - x_np = np.random.random(self.x_shape).astype('float32') - scale_np = np.random.random( - self.x_shape[self.begin_norm_axis :] - ).astype('float32') - scale = paddle.to_tensor(scale_np).reshape(-1) - - ln = nn.LayerNorm( - normalized_shape=self.x_shape[self.begin_norm_axis :], - elementwise_affine=True, - bias_attr=False, - epsilon=self.epsilon, - ) - with paddle.no_grad(): - ln.weight.set_value(scale) - - x_pd = paddle.to_tensor(x_np) - y_pd = ln(x_pd) - - expect_res = _reference_layer_norm_naive( - x_np, scale_np, None, self.epsilon, self.begin_norm_axis - ) - np.testing.assert_allclose( - y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 - ) - - def test_no_scale_with_bias_fp32(self): - """test the case when only bias is enabled (FP32).""" - for place in self.places: - with paddle.base.dygraph.guard(place): - x_np = np.random.random(self.x_shape).astype('float32') - bias_np = 
np.random.random( - self.x_shape[self.begin_norm_axis :] - ).astype('float32') - bias = paddle.to_tensor(bias_np).reshape(-1) - - ln = nn.LayerNorm( - normalized_shape=self.x_shape[self.begin_norm_axis :], - elementwise_affine=True, - weight_attr=False, - epsilon=self.epsilon, - ) - with paddle.no_grad(): - ln.bias.set_value(bias) - - x_pd = paddle.to_tensor(x_np) - y_pd = ln(x_pd) - - expect_res = _reference_layer_norm_naive( - x_np, None, bias_np, self.epsilon, self.begin_norm_axis - ) - np.testing.assert_allclose( - y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4 - ) - - def test_bf16_forward_backward(self): - """test forward and backward pass with bfloat16 precision.""" - for place in self.places: - with paddle.base.dygraph.guard(place): - x_np = np.random.random(self.x_shape).astype('float32') - scale_np = np.random.random( - self.x_shape[self.begin_norm_axis :] - ).astype('float32') - bias_np = np.random.random( - self.x_shape[self.begin_norm_axis :] - ).astype('float32') - - x = paddle.to_tensor(x_np).cast(paddle.bfloat16) - x.stop_gradient = False - - scale = ( - paddle.to_tensor(scale_np).cast(paddle.bfloat16).reshape(-1) - ) - bias = ( - paddle.to_tensor(bias_np).cast(paddle.bfloat16).reshape(-1) - ) - - ln = nn.LayerNorm( - normalized_shape=self.x_shape[self.begin_norm_axis :], - weight_attr=nn.initializer.Assign(scale), - bias_attr=nn.initializer.Assign(bias), - epsilon=self.epsilon, - ) - - y = ln(x) - loss = y.sum() - loss.backward() - - self.assertIsNotNone(x.grad) - self.assertIsNotNone(ln.weight.grad) - self.assertIsNotNone(ln.bias.grad) - - -class TestLayerNormParam(unittest.TestCase): - def setUp(self): - self.normalized_shape = [6] - self.x_tensor = paddle.randn([2, 4, 4, 6]) - self.places = get_places() - - def test_elementwise_affine_false(self): - """test that when elementwise_affine=False, no learnable parameters are created.""" - layer = nn.LayerNorm( - normalized_shape=self.normalized_shape, elementwise_affine=False - ) - self.assertIsNone(layer.weight) - self.assertIsNone(layer.bias) - - out = layer(self.x_tensor) - self.assertEqual(out.shape, self.x_tensor.shape) - - def test_elementwise_affine_true(self): - """test that when elementwise_affine=True and attr=None, parameters are created with default initialization.""" - for place in self.places: - with paddle.base.dygraph.guard(place): - layer = nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - ) - self.assertIsNotNone(layer.weight) - self.assertIsNotNone(layer.bias) - - expected_weight = paddle.ones([6]) - expected_bias = paddle.zeros([6]) - self.assertTrue(paddle.allclose(layer.weight, expected_weight)) - self.assertTrue(paddle.allclose(layer.bias, expected_bias)) - - def test_bias_false(self): - """test that when bias=False, the bias parameter is disabled even if elementwise_affine=True.""" - layer = nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - bias=False, - ) - self.assertIsNotNone(layer.weight) - self.assertIsNone(layer.bias) - - def test_attr_custom_initialization(self): - """test that weight_attr and bias_attr can be used to customize the initialization of the weight parameter.""" - for place in self.places: - with paddle.base.dygraph.guard(place): - weight_attr = paddle.nn.initializer.Constant(value=2.0) - bias_attr = paddle.nn.initializer.Constant(value=3.0) - layer = nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - weight_attr=weight_attr, - bias_attr=bias_attr, - ) - - expected_weight = 
paddle.full([6], 2.0) - expected_bias = paddle.full([6], 3.0) - self.assertTrue(paddle.allclose(layer.weight, expected_weight)) - self.assertTrue(paddle.allclose(layer.bias, expected_bias)) - - def test_alias(self): - """test parameter alias epsilon/eps""" - for place in self.places: - with paddle.base.dygraph.guard(place): - layer_epsilon = nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - epsilon=1e-5, - ) - layer_eps = nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - eps=1e-5, - ) - - out_epsilon = layer_epsilon(self.x_tensor) - out_eps = layer_eps(self.x_tensor) - - np.testing.assert_array_equal( - out_epsilon.numpy(), out_eps.numpy() - ) - - def test_errors(self): - """test for errors.""" - layer_norm = nn.LayerNorm(self.normalized_shape) - x1 = np.random.random([3, *self.normalized_shape]).astype('float32') - with self.assertRaises(ValueError): - layer_norm(x1) - with self.assertRaises(TypeError): - nn.LayerNorm(self.normalized_shape, 1e-5, None, None, "name") - with self.assertRaises(TypeError): - nn.LayerNorm( - self.normalized_shape, 1e-5, False, "cpu", paddle.float32 - ) - - -if __name__ == '__main__': - unittest.main() From 24d7488619be73ddfaf0b00c1458bc0162531c18 Mon Sep 17 00:00:00 2001 From: lijialin03 Date: Thu, 13 Nov 2025 06:44:52 +0000 Subject: [PATCH 4/6] update:add all static tests and fix api --- python/paddle/nn/layer/norm.py | 8 +- test/legacy_test/test_layer_norm_op_v2.py | 483 ++++++++++++++++------ 2 files changed, 364 insertions(+), 127 deletions(-) diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 7d89f34ff98c14..efc2c601e51697 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -680,12 +680,8 @@ def __init__( if not elementwise_affine: weight_attr = False bias_attr = False - else: - weight_attr = weight_attr if weight_attr is not False else None - if not bias: - bias_attr = False - else: - bias_attr = bias_attr if bias_attr is not False else None + elif not bias: + bias_attr = False self._weight_attr = weight_attr self._bias_attr = bias_attr diff --git a/test/legacy_test/test_layer_norm_op_v2.py b/test/legacy_test/test_layer_norm_op_v2.py index 652d9d8194aa24..89152389f5db4a 100644 --- a/test/legacy_test/test_layer_norm_op_v2.py +++ b/test/legacy_test/test_layer_norm_op_v2.py @@ -16,6 +16,7 @@ import numpy as np from op_test import get_places +from utils import static_guard import paddle from paddle import base @@ -159,164 +160,404 @@ def compute_v4(x): ) -class TestLayerNormParam(unittest.TestCase): +class TestLayerNormParamDygraph(unittest.TestCase): def setUp(self): + paddle.disable_static() self.normalized_shape = [6] self.x_shape = [2, 4, 4, 6] - self.epsilon = 1e-5 self.places = get_places() - def test_elementwise_affine_false(self): - """test that when elementwise_affine=False, weight and bias parameters are not created.""" + def _run_test_on_places(self, test_func): + """Helper to run the test function on all places.""" for p in self.places: with base.dygraph.guard(p): - layer = paddle.nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=False, - ) - self.assertIsNone( - layer.weight, - "Weight should be None when elementwise_affine=False", - ) - self.assertIsNone( - layer.bias, - "Bias should be None when elementwise_affine=False", - ) + test_func(p) + + def test_elementwise_affine_false(self): + """test that when elementwise_affine=False, weight and bias parameters are not created.""" + + def 
run_test(p): + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, elementwise_affine=False + ) + assert layer.weight is None + assert layer.bias is None + + x_tensor = paddle.randn(self.x_shape) + out = layer(x_tensor) + assert out.shape == self.x_shape - x_tensor = paddle.randn(self.x_shape) - out = layer(x_tensor) - self.assertEqual(out.shape, x_tensor.shape) + self._run_test_on_places(run_test) def test_elementwise_affine_true(self): """test that when elementwise_affine=True and attr=None, parameters are created with default initialization.""" - for place in self.places: - with paddle.base.dygraph.guard(place): - layer = paddle.nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - ) - self.assertIsNotNone( - layer.weight, - "Weight should not be None when elementwise_affine=True", - ) - self.assertIsNotNone( - layer.bias, - "Weight should not be None when elementwise_affine=True", - ) - expected_weight = paddle.ones(self.normalized_shape) - expected_bias = paddle.zeros(self.normalized_shape) + def run_test(p): + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + ) + assert layer.weight is not None + assert layer.bias is not None + + expected_weight = paddle.ones(self.normalized_shape) + expected_bias = paddle.zeros(self.normalized_shape) - self.assertTrue(paddle.allclose(layer.weight, expected_weight)) - self.assertTrue(paddle.allclose(layer.bias, expected_bias)) + np.testing.assert_allclose( + layer.weight.numpy(), expected_weight.numpy() + ) + np.testing.assert_allclose( + layer.bias.numpy(), expected_bias.numpy() + ) + + self._run_test_on_places(run_test) def test_bias_false(self): """test that when bias=False, the bias parameter is disabled even if elementwise_affine=True.""" - for p in self.places: - with base.dygraph.guard(p): - layer = paddle.nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - bias=False, - ) - self.assertIsNotNone( - layer.weight, - "Weight should exist when elementwise_affine=True", - ) - self.assertIsNone( - layer.bias, "Bias should be None when bias_attr=False" - ) + + def run_test(p): + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + bias=False, + ) + assert layer.weight is not None + assert layer.bias is None + + self._run_test_on_places(run_test) def test_weight_and_bias_false(self): """test that when weight_attr=False and bias_attr=False, both parameters are disabled.""" - for p in self.places: - with base.dygraph.guard(p): - layer = paddle.nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - weight_attr=False, - bias_attr=False, - ) - self.assertIsNotNone( - layer.weight, - "Weight should not be None when elementwise_affine=True although weight_attr=False", - ) - self.assertIsNotNone( - layer.bias, - "Bias should not be None when elementwise_affine=True although bias_attr=False", - ) + + def run_test(p): + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + weight_attr=False, + bias_attr=False, + ) + assert layer.weight is None + assert layer.bias is None + + self._run_test_on_places(run_test) def test_custom_initialization(self): """test custom initialization using weight_attr and bias_attr.""" - for p in self.places: - with base.dygraph.guard(p): - weight_val = 2.5 - bias_val = -1.0 - weight_initializer = paddle.nn.initializer.Constant( - value=weight_val - ) - bias_initializer = 
paddle.nn.initializer.Constant( - value=bias_val - ) - layer = paddle.nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - weight_attr=weight_initializer, - bias_attr=bias_initializer, + def run_test(p): + weight_val = 2.5 + bias_val = -1.0 + weight_initializer = paddle.nn.initializer.Constant( + value=weight_val + ) + bias_initializer = paddle.nn.initializer.Constant(value=bias_val) + + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + weight_attr=weight_initializer, + bias_attr=bias_initializer, + ) + + expected_weight = paddle.full( + self.normalized_shape, weight_val, dtype=layer.weight.dtype + ) + expected_bias = paddle.full( + self.normalized_shape, bias_val, dtype=layer.bias.dtype + ) + + np.testing.assert_allclose( + layer.weight.numpy(), expected_weight.numpy() + ) + np.testing.assert_allclose( + layer.bias.numpy(), expected_bias.numpy() + ) + + self._run_test_on_places(run_test) + + def test_alias(self): + """test parameter alias epsilon/eps""" + + def run_test(p): + layer_epsilon = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + epsilon=1e-5, + ) + layer_eps = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + eps=1e-5, + ) + + x_tensor = paddle.randn(self.x_shape) + out_epsilon = layer_epsilon(x_tensor) + out_eps = layer_eps(x_tensor) + + np.testing.assert_array_equal(out_epsilon.numpy(), out_eps.numpy()) + + self._run_test_on_places(run_test) + + def test_errors(self): + """test for errors.""" + + def run_test(p): + try: + layer_norm = paddle.nn.LayerNorm(self.normalized_shape) + x1 = np.random.random([3, *self.normalized_shape]).astype( + 'float32' + ) + layer_norm(x1) + self.fail( + "Expected ValueError for wrong input type in dygraph mode" ) + except ValueError: + pass - expected_weight = paddle.full( - self.normalized_shape, weight_val, dtype=layer.weight.dtype + try: + paddle.nn.LayerNorm( + self.normalized_shape, 1e-5, None, None, "name" ) - expected_bias = paddle.full( - self.normalized_shape, bias_val, dtype=layer.bias.dtype + self.fail( + "Expected TypeError for positional args mismatch in dygraph mode" ) + except TypeError: + pass - self.assertTrue( - paddle.allclose(layer.weight, expected_weight), - f"Weight initialization failed. Got {layer.weight.numpy()}, expected {expected_weight.numpy()}", + try: + paddle.nn.LayerNorm( + self.normalized_shape, 1e-5, False, "cpu", paddle.float32 ) - self.assertTrue( - paddle.allclose(layer.bias, expected_bias), - f"Bias initialization failed. 
Got {layer.bias.numpy()}, expected {expected_bias.numpy()}", + self.fail( + "Expected TypeError for positional args mismatch in dygraph mode" ) + except TypeError: + pass - def test_alias(self): - """test parameter alias epsilon/eps""" - for place in self.places: - with paddle.base.dygraph.guard(place): - layer_epsilon = paddle.nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - epsilon=1e-5, + self._run_test_on_places(run_test) + + +class TestLayerNormParamStatic(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.normalized_shape = [6] + self.x_shape = [2, 4, 4, 6] + self.places = get_places() + + def test_static_elementwise_affine_false(self): + """test elementwise_affine=False in static graph mode.""" + for p in self.places: + with static_guard(): + main = base.Program() + start = base.Program() + with ( + base.unique_name.guard(), + base.program_guard(main, start), + ): + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=False, + ) + x = paddle.static.data( + name='x', shape=self.x_shape, dtype='float32' + ) + out = layer(x) + + exe = base.Executor(p) + exe.run(start) + paddle.device.synchronize(p) + input_np = np.random.randn(*self.x_shape).astype('float32') + result = exe.run(main, feed={'x': input_np}, fetch_list=[out])[ + 0 + ] + + assert result.shape == tuple(self.x_shape) + + def test_static_elementwise_affine_true(self): + """test elementwise_affine=True with default init in static graph mode.""" + for p in self.places: + with static_guard(): + main = base.Program() + start = base.Program() + with ( + base.unique_name.guard(), + base.program_guard(main, start), + ): + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + ) + + exe = base.Executor(p) + exe.run(start) + paddle.device.synchronize(p) + weight_np, bias_np = exe.run( + main, fetch_list=[layer.weight, layer.bias] ) - layer_eps = paddle.nn.LayerNorm( - normalized_shape=self.normalized_shape, - elementwise_affine=True, - eps=1e-5, + + assert weight_np is not None + assert bias_np is not None + + expected_weight = np.ones(self.normalized_shape) + expected_bias = np.zeros(self.normalized_shape) + + np.testing.assert_allclose(weight_np, expected_weight) + np.testing.assert_allclose(bias_np, expected_bias) + + def test_static_bias_false(self): + """test bias=False in static graph mode.""" + for p in self.places: + with static_guard(): + main = base.Program() + start = base.Program() + with ( + base.unique_name.guard(), + base.program_guard(main, start), + ): + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + bias=False, + ) + assert layer.bias is None + + exe = base.Executor(p) + exe.run(start) + paddle.device.synchronize(p) + weight_np = exe.run(main, fetch_list=[layer.weight])[0] + assert weight_np is not None + assert weight_np.shape == tuple(self.normalized_shape) + + def test_static_weight_and_bias_false(self): + """test weight_attr=False and bias_attr=False in static graph mode.""" + for p in self.places: + with static_guard(): + main = base.Program() + start = base.Program() + with ( + base.unique_name.guard(), + base.program_guard(main, start), + ): + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + weight_attr=False, + bias_attr=False, + ) + assert layer.weight is None + assert layer.bias is None + + def test_static_custom_initialization(self): + """test custom initialization 
in static graph mode.""" + for p in self.places: + with static_guard(): + main = base.Program() + start = base.Program() + with ( + base.unique_name.guard(), + base.program_guard(main, start), + ): + weight_val = 2.5 + bias_val = -1.0 + weight_initializer = paddle.nn.initializer.Constant( + value=weight_val + ) + bias_initializer = paddle.nn.initializer.Constant( + value=bias_val + ) + + layer = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + weight_attr=weight_initializer, + bias_attr=bias_initializer, + ) + + exe = base.Executor(p) + exe.run(start) + paddle.device.synchronize(p) + weight_np, bias_np = exe.run( + main, fetch_list=[layer.weight, layer.bias] ) - x_tensor = paddle.randn(self.x_shape) - out_epsilon = layer_epsilon(x_tensor) - out_eps = layer_eps(x_tensor) + expected_weight = np.full(self.normalized_shape, weight_val) + expected_bias = np.full(self.normalized_shape, bias_val) + + np.testing.assert_allclose(weight_np, expected_weight) + np.testing.assert_allclose(bias_np, expected_bias) + + def test_static_alias(self): + """test parameter alias epsilon/eps in static graph mode.""" + for p in self.places: + with static_guard(): + main = base.Program() + start = base.Program() + with ( + base.unique_name.guard(), + base.program_guard(main, start), + ): + layer_epsilon = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + epsilon=1e-5, + ) + layer_eps = paddle.nn.LayerNorm( + normalized_shape=self.normalized_shape, + elementwise_affine=True, + eps=1e-5, + ) - np.testing.assert_array_equal( - out_epsilon.numpy(), out_eps.numpy() + x = paddle.static.data( + name='x', shape=self.x_shape, dtype='float32' + ) + out_epsilon = layer_epsilon(x) + out_eps = layer_eps(x) + + exe = base.Executor(p) + exe.run(start) + paddle.device.synchronize(p) + input_np = np.random.randn(*self.x_shape).astype('float32') + out_eps_val, out_epsilon_val = exe.run( + main, + feed={'x': input_np}, + fetch_list=[out_eps, out_epsilon], ) - def test_errors(self): - """test for errors.""" - layer_norm = paddle.nn.LayerNorm(self.normalized_shape) - x1 = np.random.random([3, *self.normalized_shape]).astype('float32') - with self.assertRaises(TypeError): - layer_norm(x1) - with self.assertRaises(TypeError): - paddle.nn.LayerNorm(self.normalized_shape, 1e-5, None, None, "name") - with self.assertRaises(TypeError): - paddle.nn.LayerNorm( - self.normalized_shape, 1e-5, False, "cpu", paddle.float32 - ) + np.testing.assert_array_equal(out_epsilon_val, out_eps_val) + + def test_static_errors(self): + """test errors in static graph mode.""" + for p in self.places: + with static_guard(): + main = base.Program() + start = base.Program() + with ( + base.unique_name.guard(), + base.program_guard(main, start), + ): + try: + paddle.nn.LayerNorm( + self.normalized_shape, 1e-5, None, None, "name" + ) + self.fail( + "Expected TypeError for positional args mismatch in static mode" + ) + except TypeError: + pass + + try: + paddle.nn.LayerNorm( + self.normalized_shape, + 1e-5, + False, + "cpu", + paddle.float32, + ) + self.fail( + "Expected TypeError for positional args mismatch in static mode" + ) + except TypeError: + pass if __name__ == '__main__': From 01fa13599323eb85fd08d07009b4a2e92bc87327 Mon Sep 17 00:00:00 2001 From: lijialin03 Date: Thu, 13 Nov 2025 11:43:52 +0000 Subject: [PATCH 5/6] update:update and fix code --- test/legacy_test/test_layer_norm_op_v2.py | 45 +++++------------------ 1 file changed, 10 insertions(+), 35 deletions(-) diff 
--git a/test/legacy_test/test_layer_norm_op_v2.py b/test/legacy_test/test_layer_norm_op_v2.py index 89152389f5db4a..0fe377bc57f629 100644 --- a/test/legacy_test/test_layer_norm_op_v2.py +++ b/test/legacy_test/test_layer_norm_op_v2.py @@ -302,37 +302,22 @@ def test_errors(self): """test for errors.""" def run_test(p): - try: + with self.assertRaises(ValueError): layer_norm = paddle.nn.LayerNorm(self.normalized_shape) x1 = np.random.random([3, *self.normalized_shape]).astype( 'float32' ) layer_norm(x1) - self.fail( - "Expected ValueError for wrong input type in dygraph mode" - ) - except ValueError: - pass - try: + with self.assertRaises(TypeError): paddle.nn.LayerNorm( self.normalized_shape, 1e-5, None, None, "name" ) - self.fail( - "Expected TypeError for positional args mismatch in dygraph mode" - ) - except TypeError: - pass - try: + with self.assertRaises(TypeError): paddle.nn.LayerNorm( self.normalized_shape, 1e-5, False, "cpu", paddle.float32 ) - self.fail( - "Expected TypeError for positional args mismatch in dygraph mode" - ) - except TypeError: - pass self._run_test_on_places(run_test) @@ -365,7 +350,7 @@ def test_static_elementwise_affine_false(self): exe = base.Executor(p) exe.run(start) - paddle.device.synchronize(p) + paddle.device.synchronize() input_np = np.random.randn(*self.x_shape).astype('float32') result = exe.run(main, feed={'x': input_np}, fetch_list=[out])[ 0 @@ -390,7 +375,7 @@ def test_static_elementwise_affine_true(self): exe = base.Executor(p) exe.run(start) - paddle.device.synchronize(p) + paddle.device.synchronize() weight_np, bias_np = exe.run( main, fetch_list=[layer.weight, layer.bias] ) @@ -423,7 +408,7 @@ def test_static_bias_false(self): exe = base.Executor(p) exe.run(start) - paddle.device.synchronize(p) + paddle.device.synchronize() weight_np = exe.run(main, fetch_list=[layer.weight])[0] assert weight_np is not None assert weight_np.shape == tuple(self.normalized_shape) @@ -475,7 +460,7 @@ def test_static_custom_initialization(self): exe = base.Executor(p) exe.run(start) - paddle.device.synchronize(p) + paddle.device.synchronize() weight_np, bias_np = exe.run( main, fetch_list=[layer.weight, layer.bias] ) @@ -515,7 +500,7 @@ def test_static_alias(self): exe = base.Executor(p) exe.run(start) - paddle.device.synchronize(p) + paddle.device.synchronize() input_np = np.random.randn(*self.x_shape).astype('float32') out_eps_val, out_epsilon_val = exe.run( main, @@ -535,17 +520,12 @@ def test_static_errors(self): base.unique_name.guard(), base.program_guard(main, start), ): - try: + with self.assertRaises(TypeError): paddle.nn.LayerNorm( self.normalized_shape, 1e-5, None, None, "name" ) - self.fail( - "Expected TypeError for positional args mismatch in static mode" - ) - except TypeError: - pass - try: + with self.assertRaises(TypeError): paddle.nn.LayerNorm( self.normalized_shape, 1e-5, @@ -553,11 +533,6 @@ def test_static_errors(self): "cpu", paddle.float32, ) - self.fail( - "Expected TypeError for positional args mismatch in static mode" - ) - except TypeError: - pass if __name__ == '__main__': From 4460687a00ce00e2de58aaac33481e55c40c455a Mon Sep 17 00:00:00 2001 From: lijialin03 Date: Fri, 14 Nov 2025 10:02:38 +0000 Subject: [PATCH 6/6] update:remove paddle.device.synchronize() --- test/legacy_test/test_layer_norm_op_v2.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/legacy_test/test_layer_norm_op_v2.py b/test/legacy_test/test_layer_norm_op_v2.py index 0fe377bc57f629..97e681733df801 100644 --- a/test/legacy_test/test_layer_norm_op_v2.py +++ 
b/test/legacy_test/test_layer_norm_op_v2.py @@ -350,7 +350,6 @@ def test_static_elementwise_affine_false(self): exe = base.Executor(p) exe.run(start) - paddle.device.synchronize() input_np = np.random.randn(*self.x_shape).astype('float32') result = exe.run(main, feed={'x': input_np}, fetch_list=[out])[ 0 @@ -375,7 +374,6 @@ def test_static_elementwise_affine_true(self): exe = base.Executor(p) exe.run(start) - paddle.device.synchronize() weight_np, bias_np = exe.run( main, fetch_list=[layer.weight, layer.bias] ) @@ -408,7 +406,6 @@ def test_static_bias_false(self): exe = base.Executor(p) exe.run(start) - paddle.device.synchronize() weight_np = exe.run(main, fetch_list=[layer.weight])[0] assert weight_np is not None assert weight_np.shape == tuple(self.normalized_shape) @@ -460,7 +457,6 @@ def test_static_custom_initialization(self): exe = base.Executor(p) exe.run(start) - paddle.device.synchronize() weight_np, bias_np = exe.run( main, fetch_list=[layer.weight, layer.bias] ) @@ -500,7 +496,6 @@ def test_static_alias(self): exe = base.Executor(p) exe.run(start) - paddle.device.synchronize() input_np = np.random.randn(*self.x_shape).astype('float32') out_eps_val, out_epsilon_val = exe.run( main,
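
For reference, a minimal dygraph sketch of the constructor behavior documented in this series — the ``eps`` alias for ``epsilon`` and the ``elementwise_affine`` / ``bias`` master switches. It assumes the patches above are applied and simply mirrors the cases exercised by the new tests:

    import paddle

    x = paddle.randn([2, 4, 4, 6])

    # `eps` is accepted as an alias for `epsilon`; `bias=False` disables the
    # learnable bias while the learnable scale is still created and
    # initialized to 1.0.
    ln = paddle.nn.LayerNorm([6], eps=1e-5, elementwise_affine=True, bias=False)
    y = ln(x)
    print(ln.weight.shape)  # [6]
    print(ln.bias)          # None, because bias=False

    # `elementwise_affine=False` disables both parameters regardless of
    # weight_attr / bias_attr.
    ln_plain = paddle.nn.LayerNorm([6], elementwise_affine=False)
    print(ln_plain.weight, ln_plain.bias)  # None None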