
Commit 55b79e8

[API Compatibility] Support paddle.tanh (#76122)
* sink tanh
* fix
* fix UT
* fix win UT
1 parent 11b63f5 commit 55b79e8

File tree

8 files changed: +77 -74 lines changed


paddle/phi/ops/yaml/python_api_info.yaml

Lines changed: 5 additions & 0 deletions
@@ -197,6 +197,11 @@
   args_mapper :
     func : ArgSumMapper
 
+- op : tanh
+  name : [paddle.tanh, paddle.Tensor.tanh, paddle.nn.functional.tanh]
+  args_alias:
+    use_default_mapping : True
+
 - op : exp
   name : [paddle.exp, paddle.Tensor.exp]
   args_alias:
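
The new YAML entry registers paddle.tanh, paddle.Tensor.tanh, and paddle.nn.functional.tanh against a single op and, via use_default_mapping, accepts ``input`` as an alias for ``x``. A minimal sketch of what the registration enables (assuming the default mapping behaves like the neighboring exp entry's):

    import paddle

    x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])

    # All three registered entry points resolve to the same op.
    a = paddle.tanh(x)
    b = x.tanh()
    c = paddle.nn.functional.tanh(x)

    # With use_default_mapping, ``input`` is accepted as an alias for ``x``.
    d = paddle.tanh(input=x)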

python/paddle/_paddle_docs.py

Lines changed: 55 additions & 5 deletions
@@ -2314,6 +2314,47 @@ def dot(
     """,
 )
 
+add_doc_and_signature(
+    "tanh",
+    r"""
+
+Tanh Activation Operator.
+
+.. math::
+    out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}
+
+.. note::
+    Alias Support:
+    1. The parameter name ``input`` can be used as an alias for ``x``.
+
+Args:
+    x (Tensor): Input of Tanh operator, an N-D Tensor, with data type bfloat16, float32, float64,
+        float16, uint8, int8, int16, int32, int64. Alias: ``input``.
+    name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+    out (Tensor|None, optional): The output tensor. Default: None.
+
+Returns:
+    Output of Tanh operator, a Tensor with same data type and shape as input
+    (integer types are autocasted into float32).
+
+Examples:
+    .. code-block:: python
+
+        >>> import paddle
+
+        >>> x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
+        >>> out = paddle.tanh(x)
+        >>> out
+        Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
+        [-0.37994900, -0.19737528, 0.09966799, 0.29131261])
+""",
+    """
+def tanh(
+    x: Tensor, *, out: Tensor | None = None, name: str | None = None,
+) -> Tensor
+""",
+)
+
 add_doc_and_signature(
     "exp",
     """
@@ -2331,7 +2372,7 @@ def dot(
     x (Tensor): Input of Exp operator, an N-D Tensor, with data type int32, int64, bfloat16, float16, float32, float64, complex64 or complex128.
         Alias: ``input``.
     name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
-    out (Tensor|None, optional): The output tensor.
+    out (Tensor|None, optional): The output tensor. Default: None.
 
 Returns:
     Tensor. Output of Exp operator, a Tensor with shape same as input.
@@ -2371,7 +2412,7 @@ def exp(
     x (Tensor): Input of Expm1 operator, an N-D Tensor, with data type int32, int64, bfloat16, float16, float32, float64, complex64 or complex128.
         Alias: ``input``.
     name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
-    out (Tensor|None, optional): The output tensor.
+    out (Tensor|None, optional): The output tensor. Default: None.
 
 Returns:
     Tensor. Output of Expm1 operator, a Tensor with shape same as input.
@@ -2494,10 +2535,16 @@ def diagonal(
         out.shape = [4]
         out.data = [1., -1., 3., 1.]
 
+.. note::
+    Alias Support:
+    1. The parameter name ``input`` can be used as an alias for ``x``.
+
 Args:
     x (Tensor): Input of Round operator, an N-D Tensor, with data type bfloat16, int32, int64, float32, float64, float16, complex64 or complex128.
+        Alias: ``input``.
     decimals(int): Rounded decimal place (default: 0).
     name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+    out (Tensor|None, optional): The output tensor. Default: None.
 
 Returns:
     Tensor. Output of Round operator, a Tensor with shape same as input.
@@ -2529,12 +2576,15 @@ def round(
 
     out = |x|
 
+.. note::
+    Alias Support:
+    1. The parameter name ``input`` can be used as an alias for ``x``.
+
 Args:
     x (Tensor): The input Tensor with data type int32, int64, float16, float32, float64, complex64 and complex128.
+        Alias: ``input``.
     name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
-
-Keyword args:
-    out (Tensor|None, optional): The output tensor.
+    out (Tensor|None, optional): The output tensor. Default: None.
 
 Returns:
     Tensor.A Tensor with the same data type and shape as :math:`x`.
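
The registered signature makes ``out`` keyword-only. A hedged usage sketch of the updated paddle.tanh surface (assuming ``out``, as for the other out= parameters in this file, receives the result):

    import paddle

    x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])

    # ``input`` aliases ``x`` per the new docstring note.
    y = paddle.tanh(input=x)

    # ``out`` is keyword-only; the result is written into the provided tensor.
    buf = paddle.empty([4], dtype='float32')
    paddle.tanh(x, out=buf)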

python/paddle/tensor/math.py

Lines changed: 1 addition & 54 deletions
@@ -40,6 +40,7 @@
     sign,
     sin,
     sum,
+    tanh,
 )
 from paddle.base.libpaddle import DataType
 from paddle.common_ops_import import VarDesc, dygraph_utils
@@ -4459,60 +4460,6 @@ def prod(
     return out
 
 
-def tanh(x: Tensor, name: str | None = None) -> Tensor:
-    r"""
-    Tanh Activation Operator.
-
-    .. math::
-        out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}
-
-    Args:
-        x (Tensor): Input of Tanh operator, an N-D Tensor, with data type bfloat16, float32, float64,
-            float16, uint8, int8, int16, int32, int64.
-        name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
-
-    Returns:
-        Output of Tanh operator, a Tensor with same data type and shape as input
-        (integer types are autocasted into float32).
-
-    Examples:
-
-        .. code-block:: python
-
-            >>> import paddle
-
-            >>> x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
-            >>> out = paddle.tanh(x)
-            >>> out
-            Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
-            [-0.37994900, -0.19737528, 0.09966799, 0.29131261])
-    """
-    if in_dynamic_or_pir_mode():
-        return _C_ops.tanh(x)
-    else:
-        check_variable_and_dtype(
-            x,
-            'x',
-            [
-                'uint16',
-                'float16',
-                'float32',
-                'float64',
-                'uint8',
-                'int8',
-                'int16',
-                'int32',
-                'int64',
-            ],
-            'tanh',
-        )
-        check_type(x, 'x', (Variable), 'tanh')
-        helper = LayerHelper('tanh', **locals())
-        out = helper.create_variable_for_type_inference(x.dtype)
-        helper.append_op(type='tanh', inputs={'X': x}, outputs={'Out': out})
-        return out
-
-
 @inplace_apis_in_dygraph_only
 def tanh_(x: Tensor, name: str | None = None) -> Tensor:
     r"""

test/legacy_test/test_activation_op.py

Lines changed: 1 addition & 0 deletions
@@ -6214,6 +6214,7 @@ class TestActivationAPI_Compatibility(unittest.TestCase):
         ("paddle.exp", np.exp, {'min_val': -1.0, 'max_val': 1.0}),
         ("paddle.expm1", np.expm1, {'min_val': -1.0, 'max_val': 1.0}),
         ("paddle.round", np.round, {'min_val': -5.0, 'max_val': 5.0}),
+        ("paddle.tanh", np.tanh, {'min_val': -1.0, 'max_val': 1.0}),
     ]
 
     def setUp(self):
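
The new tuple feeds the same parametrized harness as exp, expm1, and round: each API is presumably checked against its NumPy reference on inputs drawn from [min_val, max_val]. A standalone sketch of that kind of check:

    import numpy as np
    import paddle

    # Compare paddle.tanh with np.tanh on inputs from [-1.0, 1.0].
    x_np = np.random.uniform(-1.0, 1.0, [4, 16]).astype('float32')
    np.testing.assert_allclose(
        paddle.tanh(paddle.to_tensor(x_np)).numpy(),
        np.tanh(x_np),
        rtol=1e-5,
    )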

test/legacy_test/test_weight_decay.py

Lines changed: 3 additions & 3 deletions
@@ -62,9 +62,9 @@ def bow_net(
     bow = paddle.static.nn.sequence_lod.sequence_pool(
         input=emb, pool_type='sum'
     )
-    bow_tanh = paddle.tanh(bow)
-    fc_1 = paddle.static.nn.fc(x=bow_tanh, size=hid_dim, activation="tanh")
-    fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim2, activation="tanh")
+    bow_silu = paddle.nn.functional.silu(bow)
+    fc_1 = paddle.static.nn.fc(x=bow_silu, size=hid_dim, activation="silu")
+    fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim2, activation="silu")
     prediction = paddle.static.nn.fc(
         x=[fc_2], size=class_dim, activation="softmax"
     )
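
This test (and the standalone-executor tests below) swaps paddle.tanh for paddle.nn.functional.silu, which keeps the graphs shape-compatible since silu(x) = x * sigmoid(x) is also elementwise. A quick equivalence check for the substituted activation:

    import paddle

    x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
    silu_ref = x * paddle.nn.functional.sigmoid(x)
    assert paddle.allclose(paddle.nn.functional.silu(x), silu_ref)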

test/standalone_executor/test_standalone_custom_event.py

Lines changed: 6 additions & 6 deletions
@@ -40,12 +40,12 @@ def build_program():
     matmul_out = data @ weight
     bias = paddle.ones([1024, 2048], dtype='float32', name='bias')
     add_out = paddle.add(matmul_out, bias, name='add_out')
-    # add_out -> [sub] -> sub_out -> [tanh] -> tanh_out
+    # add_out -> [sub] -> sub_out -> [silu] -> silu_out
     sub_out = paddle.subtract(add_out, data, name='sub_out')
-    tanh_out = paddle.tanh(sub_out, name='tanh_out')
+    silu_out = paddle.nn.functional.silu(sub_out, name='silu_out')
     bias_1 = paddle.add(bias, sub_out, name='bias_1')
-    out_before = paddle.tanh(bias_1, name='out_before')
-    out_last = paddle.subtract(tanh_out, data, name='out_last')
+    out_before = paddle.nn.functional.silu(bias_1, name='out_before')
+    out_last = paddle.subtract(silu_out, data, name='out_last')
     out_last2 = out_last @ weight
 
     out = paddle.add(out_before, out_last2, name='out')
@@ -64,9 +64,9 @@ class TestManualEvent(unittest.TestCase):
     |     |      |      |
     |  elementwise_sub(s1)    |
     |     |      |      |
-    |  tanh(s1)   elementwise_add(s1)
+    |  silu(s1)   elementwise_add(s1)
     |     |      |
-    elementwise_sub(s1)  tanh(s1)
+    elementwise_sub(s1)  silu(s1)
     |     |
     matmul_v2(s1)  |
     |     |   ---split prog----

test/standalone_executor/test_standalone_custom_stream.py

Lines changed: 2 additions & 2 deletions
@@ -35,9 +35,9 @@ class TestCustomStream(unittest.TestCase):
     |     |      |      |
     |  elementwise_sub(cpu)   |
     |     |      |      |
-    |  tanh(cpu)   elementwise_add(s2)
+    |  silu(cpu)   elementwise_add(s2)
     |     |      |
-    elementwise_sub(s1)  tanh(s2)
+    elementwise_sub(s1)  silu(s2)
     |     |
     elementwise_add(s2)
     |

test/standalone_executor/test_standalone_executor.py

Lines changed: 4 additions & 4 deletions
@@ -44,15 +44,15 @@ def build_program():
     bias = paddle.ones([4, 64], dtype='float32', name='bias')
     add_out = paddle.add(matmul_out, bias, name='add_out')
 
-    # add_out -> [memcpy_d2h] -> add_out' -> [sub] -> sub_out -> [tanh] -> tanh_out
+    # add_out -> [memcpy_d2h] -> add_out' -> [sub] -> sub_out -> [silu] -> silu_out
     with paddle.static.device_guard('cpu'):
         sub_out = paddle.subtract(add_out, data, name='sub_out')
-        tanh_out = paddle.tanh(sub_out, name='tanh_out')
+        silu_out = paddle.nn.functional.silu(sub_out, name='silu_out')
 
     with paddle.static.device_guard('gpu'):
         bias_1 = paddle.add(bias, sub_out, name='bias_1')
-        out_before = paddle.tanh(bias_1, name='out_before')
-        out_last = paddle.subtract(tanh_out, data, name='out_last')
+        out_before = paddle.nn.functional.silu(bias_1, name='out_before')
+        out_last = paddle.subtract(silu_out, data, name='out_last')
 
     out = paddle.add(out_before, out_last, name='out')
     mean = paddle.mean(out, name='mean_out')
