diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index 7f1a9448bbc0a6..7f977c85464fc9 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -2275,7 +2275,7 @@ PHI_DEFINE_EXPORTED_bool(use_default_stream, * Note: Whether use Stride_Compute_Kernel. */ PHI_DEFINE_EXPORTED_bool(use_stride_compute_kernel, - false, + true, "Whether use Stride_Compute_Kernel."); /** diff --git a/paddle/phi/kernels/stride/elementwise_grad_stride_kernel.cu b/paddle/phi/kernels/stride/elementwise_grad_stride_kernel.cu index de9fe2b6ce1c06..697377e936659b 100644 --- a/paddle/phi/kernels/stride/elementwise_grad_stride_kernel.cu +++ b/paddle/phi/kernels/stride/elementwise_grad_stride_kernel.cu @@ -278,7 +278,7 @@ using bfloat16 = phi::bfloat16; using complex64 = ::phi::complex64; using complex128 = ::phi::complex128; -PD_REGISTER_KERNEL(add_grad, +PD_REGISTER_KERNEL(add_grad_stride, GPU, STRIDED, phi::AddGradStrideKernel, @@ -291,7 +291,7 @@ PD_REGISTER_KERNEL(add_grad, phi::complex64, phi::complex128) {} -PD_REGISTER_KERNEL(subtract_grad, +PD_REGISTER_KERNEL(subtract_grad_stride, GPU, STRIDED, phi::SubtractGradStrideKernel, @@ -304,7 +304,7 @@ PD_REGISTER_KERNEL(subtract_grad, phi::complex64, phi::complex128) {} -PD_REGISTER_KERNEL(multiply_grad, +PD_REGISTER_KERNEL(multiply_grad_stride, GPU, STRIDED, phi::MultiplyGradStrideKernel, diff --git a/paddle/phi/kernels/stride/reduce_grad_stride_kernel.cu b/paddle/phi/kernels/stride/reduce_grad_stride_kernel.cu index 437094d1422d35..f47153b89f63c8 100644 --- a/paddle/phi/kernels/stride/reduce_grad_stride_kernel.cu +++ b/paddle/phi/kernels/stride/reduce_grad_stride_kernel.cu @@ -171,7 +171,7 @@ using bfloat16 = phi::bfloat16; using complex64 = ::phi::complex64; using complex128 = ::phi::complex128; -PD_REGISTER_KERNEL(sum_grad, +PD_REGISTER_KERNEL(sum_grad_stride, GPU, STRIDED, phi::ReduceSumGradStrideKernel,