diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py
index b0315dd8936891..f26599bfbec648 100644
--- a/python/paddle/nn/layer/norm.py
+++ b/python/paddle/nn/layer/norm.py
@@ -35,6 +35,7 @@
 from paddle import _C_ops, in_dynamic_mode, pir_utils
 from paddle.device import get_all_custom_device_type
+from paddle.utils.decorator_utils import param_one_alias
 
 from ...base import dygraph_utils
 from ...base.data_feeder import check_variable_and_dtype
@@ -64,6 +65,7 @@
     DataLayoutND,
     DTypeLike,
     ParamAttrLike,
+    PlaceLike,
     ShapeLike,
 )
 
@@ -602,13 +604,34 @@ class LayerNorm(Layer):
             which is expected to be of that specific size.
         epsilon(float, optional): The small value added to the variance to prevent
             division by zero. Default: 1e-05.
+            Alias: ``eps``.
+        elementwise_affine(bool, optional): Whether to apply an element-wise affine transformation
+            (i.e., learnable scale and bias). If set to ``False``, both the scale (:math:`g`) and
+            bias (:math:`b`) parameters will be disabled, regardless of the settings of `weight_attr`
+            and `bias_attr`. This parameter acts as a master switch. Default: True.
+            **Note: This argument must be passed as a keyword argument.**
+        bias(bool, optional): Whether to include a learnable bias term in the layer. This setting
+            only takes effect when `elementwise_affine` is ``True``. If set to ``False``, no bias
+            parameter will be created, even if `bias_attr` is specified. Default: True.
+            **Note: This argument must be passed as a keyword argument.**
         weight_attr(ParamAttr|bool|None, optional): The parameter attribute for the learnable
-            gain :math:`g`. If False, weight is None. If is None, a default :code:`ParamAttr` would be added as scale. The
-            :attr:`param_attr` is initialized as 1 if it is added. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` .
+            gain :math:`g` (scale). This setting only takes effect when `elementwise_affine` is ``True``.
+            - If set to ``False``, no gain parameter will be created.
+            - If set to ``None`` or ``True``, a default :code:`ParamAttr` will be used, and the
+              parameter will be initialized to 1.
+            - If set to a custom :code:`ParamAttr` object, it will be used to configure the parameter.
+            Default: None.
+            **Note: This argument must be passed as a keyword argument.**
         bias_attr(ParamAttr|bool|None, optional): The parameter attribute for the learnable
-            bias :math:`b`. If is False, bias is None. If is None, a default :code:`ParamAttr` would be added as bias. The
-            :attr:`bias_attr` is initialized as 0 if it is added. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` .
+            bias :math:`b`. This setting only takes effect when both `elementwise_affine` and `bias` are ``True``.
+            - If set to ``False``, no bias parameter will be created.
+            - If set to ``None`` or ``True``, a default :code:`ParamAttr` will be used, and the
+              parameter will be initialized to 0.
+            - If set to a custom :code:`ParamAttr` object, it will be used to configure the parameter.
+            Default: None.
+            **Note: This argument must be passed as a keyword argument.**
         name(str|None, optional): Name for the LayerNorm, default is None. For more information, please refer to :ref:`api_guide_Name` .
+            **Note: This argument must be passed as a keyword argument.**
 
     Shape:
         - x: 2-D, 3-D, 4-D or 5-D tensor.
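
For review context, a rough sketch (not part of the patch) of the precedence the docstring above describes: `elementwise_affine` acts as the master switch, `bias` gates only the bias term, and the `*_attr` values apply last. The helper name below is made up for illustration only:

    # Illustrative only: mirrors the documented precedence, not the patched code itself.
    def resolve_param_attrs(elementwise_affine, bias, weight_attr, bias_attr):
        if not elementwise_affine:
            # Master switch off: neither scale nor bias parameters are created.
            return False, False
        if not bias:
            # bias=False wins over whatever bias_attr was passed.
            bias_attr = False
        return weight_attr, bias_attr
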
@@ -642,10 +665,16 @@ class LayerNorm(Layer):
     weight: Tensor | None
     bias: Tensor | None
 
+    @param_one_alias(["epsilon", "eps"])
     def __init__(
         self,
         normalized_shape: int | Sequence[int],
         epsilon: float = 1e-5,
+        *,
+        elementwise_affine: bool = True,
+        bias: bool = True,
+        device: PlaceLike | None = None,
+        dtype: DTypeLike | None = None,
         weight_attr: bool | ParamAttr | None = None,
         bias_attr: bool | ParamAttr | None = None,
         name: str | None = None,
@@ -656,6 +685,21 @@ def __init__(
 
         self._normalized_shape = list(normalized_shape)
         self._epsilon = epsilon
+        self._device = device
+        self._dtype = (
+            self._helper.get_default_dtype() if dtype is None else dtype
+        )
+
+        if not elementwise_affine:
+            weight_attr = False
+            bias_attr = False
+        else:
+            weight_attr = weight_attr if weight_attr is not False else None
+            if not bias:
+                bias_attr = False
+            else:
+                bias_attr = bias_attr if bias_attr is not False else None
+
         self._weight_attr = weight_attr
         self._bias_attr = bias_attr
         param_shape = [np.prod(self._normalized_shape)]
@@ -665,15 +709,22 @@ def __init__(
         else:
             self.weight = self.create_parameter(
                 attr=self._weight_attr,
+                dtype=self._dtype,
                 shape=param_shape,
                 default_initializer=Constant(1.0),
+                device=self._device,
             )
 
         if bias_attr is False:
             self.bias = None
         else:
             self.bias = self.create_parameter(
-                attr=self._bias_attr, shape=param_shape, is_bias=True
+                attr=self._bias_attr,
+                dtype=self._dtype,
+                shape=param_shape,
+                default_initializer=Constant(0.0),
+                device=self._device,
+                is_bias=True,
             )
 
     def forward(self, input: Tensor) -> Tensor:
diff --git a/test/legacy_test/test_layer_norm_op_v2_dygraph.py b/test/legacy_test/test_layer_norm_op_v2_dygraph.py
new file mode 100644
index 00000000000000..5e759e9b7a3ec0
--- /dev/null
+++ b/test/legacy_test/test_layer_norm_op_v2_dygraph.py
@@ -0,0 +1,296 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from functools import reduce
+from operator import mul
+
+import numpy as np
+
+import paddle
+from paddle import nn
+
+
+def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
+    x_shape = x.shape
+    N = reduce(mul, x_shape[0:begin_norm_axis], 1)
+    D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1)
+    x.shape = [N, D]
+
+    mean = np.mean(x, axis=1)
+    var = np.var(x, axis=1) + epsilon
+    output = np.divide(
+        (x - mean.reshape([N, 1])), (np.sqrt(var)).reshape([N, 1])
+    )
+    if scale is not None:
+        output = scale.reshape([1, D]) * output
+    if beta is not None:
+        output = output + beta.reshape([1, D])
+
+    x.shape, output.shape = x_shape, x_shape
+    return output
+
+
+class TestLayerNormOp(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()
+        self.x_shape = [2, 6, 6, 3]
+        self.epsilon = 1e-5
+        self.begin_norm_axis = 1
+
+    @unittest.skipIf(
+        not paddle.in_dynamic_mode(), "test is only for dynamic mode"
+    )
+    def test_basic_fp32(self):
+        """test basic functionality with float32."""
+        x_np = np.random.random(self.x_shape).astype('float32')
+        scale_np = np.random.random(
+            self.x_shape[self.begin_norm_axis :]
+        ).astype('float32')
+        bias_np = np.random.random(
+            self.x_shape[self.begin_norm_axis :]
+        ).astype('float32')
+        scale = paddle.to_tensor(scale_np).reshape(-1)
+        bias = paddle.to_tensor(bias_np).reshape(-1)
+
+        ln = nn.LayerNorm(
+            normalized_shape=self.x_shape[self.begin_norm_axis :],
+            weight_attr=nn.initializer.Assign(scale),
+            bias_attr=nn.initializer.Assign(bias),
+            epsilon=self.epsilon,
+        )
+
+        x_pd = paddle.to_tensor(x_np)
+        y_pd = ln(x_pd)
+        expect_res = _reference_layer_norm_naive(
+            x_np, scale_np, bias_np, self.epsilon, self.begin_norm_axis
+        )
+
+        np.testing.assert_allclose(
+            y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4
+        )
+
+    @unittest.skipIf(
+        not paddle.in_dynamic_mode(), "test is only for dynamic mode"
+    )
+    def test_no_scale_no_bias_fp32(self):
+        """test the case when both scale and bias are disabled (FP32)."""
+        x_np = np.random.random(self.x_shape).astype('float32')
+        x_pd = paddle.to_tensor(x_np)
+
+        ln = nn.LayerNorm(
+            normalized_shape=self.x_shape[self.begin_norm_axis :],
+            elementwise_affine=False,
+            epsilon=self.epsilon,
+        )
+        y_pd = ln(x_pd)
+
+        expect_res = _reference_layer_norm_naive(
+            x_np, None, None, self.epsilon, self.begin_norm_axis
+        )
+        np.testing.assert_allclose(
+            y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4
+        )
+
+    @unittest.skipIf(
+        not paddle.in_dynamic_mode(), "test is only for dynamic mode"
+    )
+    def test_with_scale_no_bias_fp32(self):
+        """test the case when only scale is enabled (FP32)."""
+        x_np = np.random.random(self.x_shape).astype('float32')
+        scale_np = np.random.random(
+            self.x_shape[self.begin_norm_axis :]
+        ).astype('float32')
+        scale = paddle.to_tensor(scale_np).reshape(-1)
+
+        ln = nn.LayerNorm(
+            normalized_shape=self.x_shape[self.begin_norm_axis :],
+            elementwise_affine=True,
+            bias_attr=False,
+            epsilon=self.epsilon,
+        )
+        with paddle.no_grad():
+            ln.weight.set_value(scale)
+
+        x_pd = paddle.to_tensor(x_np)
+        y_pd = ln(x_pd)
+
+        expect_res = _reference_layer_norm_naive(
+            x_np, scale_np, None, self.epsilon, self.begin_norm_axis
+        )
+        np.testing.assert_allclose(
+            y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4
+        )
+
+    @unittest.skipIf(
+        not paddle.in_dynamic_mode(), "test is only for dynamic mode"
+    )
+    def test_no_scale_with_bias_fp32(self):
+        """test the case when only bias is enabled (FP32)."""
+        x_np = np.random.random(self.x_shape).astype('float32')
+        bias_np = np.random.random(
+            self.x_shape[self.begin_norm_axis :]
+        ).astype('float32')
+        bias = paddle.to_tensor(bias_np).reshape(-1)
+
+        ln = nn.LayerNorm(
+            normalized_shape=self.x_shape[self.begin_norm_axis :],
+            elementwise_affine=True,
+            weight_attr=False,
+            epsilon=self.epsilon,
+        )
+        with paddle.no_grad():
+            ln.bias.set_value(bias)
+
+        x_pd = paddle.to_tensor(x_np)
+        y_pd = ln(x_pd)
+
+        expect_res = _reference_layer_norm_naive(
+            x_np, None, bias_np, self.epsilon, self.begin_norm_axis
+        )
+        np.testing.assert_allclose(
+            y_pd.numpy(), expect_res, rtol=1e-5, atol=1e-4
+        )
+
+    @unittest.skipIf(
+        not paddle.is_compiled_with_cuda(), "bfloat16 test requires CUDA"
+    )
+    def test_bf16_forward_backward(self):
+        """test forward and backward pass with bfloat16 precision."""
+        place = paddle.CUDAPlace(0)
+
+        with paddle.base.dygraph.guard(place):
+            x_np = np.random.random(self.x_shape).astype('float32')
+            scale_np = np.random.random(
+                self.x_shape[self.begin_norm_axis :]
+            ).astype('float32')
+            bias_np = np.random.random(
+                self.x_shape[self.begin_norm_axis :]
+            ).astype('float32')
+
+            x = paddle.to_tensor(x_np).cast(paddle.bfloat16)
+            x.stop_gradient = False
+
+            scale = paddle.to_tensor(scale_np).cast(paddle.bfloat16).reshape(-1)
+            bias = paddle.to_tensor(bias_np).cast(paddle.bfloat16).reshape(-1)
+
+            ln = nn.LayerNorm(
+                normalized_shape=self.x_shape[self.begin_norm_axis :],
+                weight_attr=nn.initializer.Assign(scale),
+                bias_attr=nn.initializer.Assign(bias),
+                epsilon=self.epsilon,
+            )
+            ln.to(device='cuda')
+
+            y = ln(x)
+            loss = y.sum()
+            loss.backward()
+
+            self.assertIsNotNone(x.grad)
+            self.assertIsNotNone(ln.weight.grad)
+            self.assertIsNotNone(ln.bias.grad)
+
+
+class TestLayerNormParam(unittest.TestCase):
+    def setUp(self):
+        self.normalized_shape = [6]
+        self.x_tensor = paddle.randn([2, 4, 4, 6])
+
+    def test_elementwise_affine_false(self):
+        """test that when elementwise_affine=False, no learnable parameters are created."""
+        layer = nn.LayerNorm(
+            normalized_shape=self.normalized_shape, elementwise_affine=False
+        )
+        self.assertIsNone(layer.weight)
+        self.assertIsNone(layer.bias)
+
+        out = layer(self.x_tensor)
+        self.assertEqual(out.shape, self.x_tensor.shape)
+
+    @unittest.skipIf(
+        not paddle.in_dynamic_mode(), "test is only for dynamic mode"
+    )
+    def test_elementwise_affine_true(self):
+        """test that when elementwise_affine=True and attr=None, parameters are created with default initialization."""
+        layer = nn.LayerNorm(
+            normalized_shape=self.normalized_shape, elementwise_affine=True
+        )
+        self.assertIsNotNone(layer.weight)
+        self.assertIsNotNone(layer.bias)
+
+        expected_weight = paddle.ones([6])
+        expected_bias = paddle.zeros([6])
+        self.assertTrue(paddle.allclose(layer.weight, expected_weight))
+        self.assertTrue(paddle.allclose(layer.bias, expected_bias))
+
+    def test_bias_false(self):
+        """test that when bias=False, the bias parameter is disabled even if elementwise_affine=True."""
+        layer = nn.LayerNorm(
+            normalized_shape=self.normalized_shape,
+            elementwise_affine=True,
+            bias=False,
+        )
+        self.assertIsNotNone(layer.weight)
+        self.assertIsNone(layer.bias)
+
+    @unittest.skipIf(
+        not paddle.in_dynamic_mode(), "test is only for dynamic mode"
+    )
+    def test_attr_custom_initialization(self):
+        """test that weight_attr and bias_attr can be used to customize the initialization of the weight and bias parameters."""
+        weight_attr = paddle.nn.initializer.Constant(value=2.0)
+        bias_attr = paddle.nn.initializer.Constant(value=3.0)
+        layer = nn.LayerNorm(
+            normalized_shape=self.normalized_shape,
+            elementwise_affine=True,
+            weight_attr=weight_attr,
+            bias_attr=bias_attr,
+        )
+
+        expected_weight = paddle.full([6], 2.0)
+        expected_bias = paddle.full([6], 3.0)
+        self.assertTrue(paddle.allclose(layer.weight, expected_weight))
+        self.assertTrue(paddle.allclose(layer.bias, expected_bias))
+
+    def test_alias(self):
+        """test parameter alias epsilon/eps"""
+        layer_epsilon = nn.LayerNorm(
+            normalized_shape=self.normalized_shape,
+            elementwise_affine=True,
+            epsilon=1e-5,
+        )
+        layer_eps = nn.LayerNorm(
+            normalized_shape=self.normalized_shape,
+            elementwise_affine=True,
+            eps=1e-5,
+        )
+
+        out_epsilon = layer_epsilon(self.x_tensor)
+        out_eps = layer_eps(self.x_tensor)
+
+        np.testing.assert_array_equal(out_epsilon.numpy(), out_eps.numpy())
+
+    def test_errors(self):
+        """test for errors."""
+        layer_norm = nn.LayerNorm(self.normalized_shape)
+        x1 = np.random.random([3, *self.normalized_shape]).astype('float32')
+        with self.assertRaises(ValueError):
+            layer_norm(x1)
+        with self.assertRaises(TypeError):
+            nn.LayerNorm(self.normalized_shape, 1e-5, None, None, "name")
+        with self.assertRaises(TypeError):
+            nn.LayerNorm(
+                self.normalized_shape, 1e-5, False, "cpu", paddle.float32
+            )
+
+
+if __name__ == '__main__':
+    unittest.main()
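
As a quick reference for reviewers, a minimal usage sketch of the constructor behaviour exercised by the tests above; it assumes the patch is applied and runs in dygraph mode, and is not part of the patch itself:

    import paddle
    from paddle import nn

    x = paddle.randn([2, 4, 6])

    # Default: learnable scale (initialized to 1) and bias (initialized to 0)
    # over the last dimension; `eps` is accepted as an alias of `epsilon`.
    ln = nn.LayerNorm(normalized_shape=6, eps=1e-5)
    y = ln(x)

    # Master switch off: no learnable parameters are created at all.
    ln_plain = nn.LayerNorm(6, elementwise_affine=False)
    assert ln_plain.weight is None and ln_plain.bias is None

    # Keep the learnable scale but drop the bias.
    ln_no_bias = nn.LayerNorm(6, bias=False)
    assert ln_no_bias.weight is not None and ln_no_bias.bias is None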