From ce32044cb572ea9617b6389db1ee0bd4255f4e2e Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 4 Apr 2023 03:37:12 +0000 Subject: [PATCH 01/27] test --- paddle/phi/api/yaml/backward.yaml | 10 +++++ paddle/phi/api/yaml/op_compat.yaml | 7 ++++ paddle/phi/api/yaml/ops.yaml | 9 ++++ paddle/phi/infermeta/binary.cc | 20 +++++++++ paddle/phi/infermeta/binary.h | 4 ++ .../phi/kernels/cpu/nextafter_grad_kernel.cc | 31 ++++++++++++++ paddle/phi/kernels/cpu/nextafter_kernel.cc | 42 +++++++++++++++++++ paddle/phi/kernels/nextafter_grad_kernel.h | 28 +++++++++++++ paddle/phi/kernels/nextafter_kernel.h | 28 +++++++++++++ 9 files changed, 179 insertions(+) create mode 100644 paddle/phi/kernels/cpu/nextafter_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/nextafter_kernel.cc create mode 100644 paddle/phi/kernels/nextafter_grad_kernel.h create mode 100644 paddle/phi/kernels/nextafter_kernel.h diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index bc85a1d0ca7bf..22c810bedbcc9 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1092,6 +1092,16 @@ data_transform : skip_transform : out_size, size_tensor, scale_tensor +- backward_op : nextafter_grad + forward : nextafter (Tensor x, Tensor y) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out_grad) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : nextafter_grad + param: [x, y] + kernel : + func : nextafter_grad + - backward_op : nll_loss_grad forward : nll_loss (Tensor input, Tensor label, Tensor weight, int64_t ignore_index = -100, str reduction = "mean") -> Tensor(out), Tensor(total_weight) args : (Tensor input, Tensor label, Tensor weight, Tensor total_weight, Tensor out_grad, int64_t ignore_index, str reduction) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 303edec5af7be..95d70b6c041df 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1381,6 +1381,13 @@ extra : attrs : [bool use_mkldnn = false] +- op : nextafter + backward : nextafter_grad + inputs : + {x: X, y : Y} + outputs : + out: Out + - op : nll_loss backward : nll_loss_grad inputs : diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 473b18dbcc476..96cd066468ed7 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1117,6 +1117,15 @@ data_transform : skip_transform : out_size, size_tensor, scale_tensor +- op : nextafter + args : (Tensor x, Tensor y) + output : Tensor(out) + infer_meta : + func : ElementwiseInferMeta + kernel : + func : nextafter + backward : nextafter_grad + - op : nll_loss args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index = -100, str reduction = "mean") output : Tensor(out), Tensor(total_weight) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 0a3e31054b0f5..21c3dcb2d2674 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2282,6 +2282,26 @@ void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) { out->share_lod(x); } +void NextafterInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + + if (x_dims.size() > 0) + PADDLE_ENFORCE_GE(x_dims[0], + y_dims[0], + phi::errors::InvalidArgument( + "The count (%d) of elements of X shall " + "greater than count (%d) of elements of Y.", + x_dims[0], + y_dims[0])); + + out->set_dims(x_dims); + out->set_dtype(x.dtype()); + out->share_lod(x); +} + void PReluInferMeta(const MetaTensor& x, const MetaTensor& alpha, const std::string& data_format, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index ed4da703ce520..e643c33b33c39 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -366,6 +366,10 @@ void MatrixRankTolInferMeta(const MetaTensor& x, void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out); +void NextafterInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out); + void PReluInferMeta(const MetaTensor& x, const MetaTensor& alpha, const std::string& data_format, diff --git a/paddle/phi/kernels/cpu/nextafter_grad_kernel.cc b/paddle/phi/kernels/cpu/nextafter_grad_kernel.cc new file mode 100644 index 0000000000000..b8f6b54a84d37 --- /dev/null +++ b/paddle/phi/kernels/cpu/nextafter_grad_kernel.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nextafter_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void NextafterGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + DenseTensor* x_grad, + DenseTensor* y_grad) {} +} // namespace phi + +PD_REGISTER_KERNEL( + nextafter_grad, CPU, ALL_LAYOUT, phi::NextafterGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/nextafter_kernel.cc b/paddle/phi/kernels/cpu/nextafter_kernel.cc new file mode 100644 index 0000000000000..abd0bf483198d --- /dev/null +++ b/paddle/phi/kernels/cpu/nextafter_kernel.cc @@ -0,0 +1,42 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nextafter_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void NextafterKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + if (x.numel() == 0 || y.numel() == 0) { + return; + } + auto* out_data = dev_ctx.template Alloc(out); + auto* x_data = x.data(); + auto* y_data = y.data(); + int x_numel = x.numel(); + + for (int i = 0; i < x_numel; ++i) { + out_data[i] = std::nextafter(x_data[i], y_data[i]); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + nextafter, CPU, ALL_LAYOUT, phi::NextafterKernel, float, double) {} diff --git a/paddle/phi/kernels/nextafter_grad_kernel.h b/paddle/phi/kernels/nextafter_grad_kernel.h new file mode 100644 index 0000000000000..52ddf889cbf11 --- /dev/null +++ b/paddle/phi/kernels/nextafter_grad_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void NextafterGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + DenseTensor* x_grad, + DenseTensor* y_grad); +} // namespace phi diff --git a/paddle/phi/kernels/nextafter_kernel.h b/paddle/phi/kernels/nextafter_kernel.h new file mode 100644 index 0000000000000..047917dbd6753 --- /dev/null +++ b/paddle/phi/kernels/nextafter_kernel.h @@ -0,0 +1,28 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void NextafterRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); + +} // namespace phi From f260eb5a0f62d91a867033f31b95ca1e69fb6ba7 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 4 Apr 2023 08:04:57 +0000 Subject: [PATCH 02/27] fix --- paddle/phi/api/yaml/backward.yaml | 10 ------ paddle/phi/api/yaml/op_compat.yaml | 1 - paddle/phi/api/yaml/ops.yaml | 1 - .../phi/kernels/cpu/nextafter_grad_kernel.cc | 31 ------------------ paddle/phi/kernels/cpu/nextafter_kernel.cc | 8 +++-- python/paddle/__init__.py | 2 ++ python/paddle/tensor/math.py | 32 +++++++++++++++++++ 7 files changed, 39 insertions(+), 46 deletions(-) delete mode 100644 paddle/phi/kernels/cpu/nextafter_grad_kernel.cc diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 22c810bedbcc9..bc85a1d0ca7bf 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1092,16 +1092,6 @@ data_transform : skip_transform : out_size, size_tensor, scale_tensor -- backward_op : nextafter_grad - forward : nextafter (Tensor x, Tensor y) -> Tensor(out) - args : (Tensor x, Tensor y, Tensor out_grad) - output : Tensor(x_grad), Tensor(y_grad) - infer_meta : - func : nextafter_grad - param: [x, y] - kernel : - func : nextafter_grad - - backward_op : nll_loss_grad forward : nll_loss (Tensor input, Tensor label, Tensor weight, int64_t ignore_index = -100, str reduction = "mean") -> Tensor(out), Tensor(total_weight) args : (Tensor input, Tensor label, Tensor weight, Tensor total_weight, Tensor out_grad, int64_t ignore_index, str reduction) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 95d70b6c041df..12f7ca7526e50 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1382,7 +1382,6 @@ attrs : [bool use_mkldnn = false] - op : nextafter - backward : nextafter_grad inputs : {x: X, y : Y} outputs : diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 96cd066468ed7..d96a7f905fec6 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1124,7 +1124,6 @@ func : ElementwiseInferMeta kernel : func : nextafter - backward : nextafter_grad - op : nll_loss args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index = -100, str reduction = "mean") diff --git a/paddle/phi/kernels/cpu/nextafter_grad_kernel.cc b/paddle/phi/kernels/cpu/nextafter_grad_kernel.cc deleted file mode 100644 index b8f6b54a84d37..0000000000000 --- a/paddle/phi/kernels/cpu/nextafter_grad_kernel.cc +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/nextafter_grad_kernel.h" -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" - -namespace phi { - -template -void NextafterGradKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out_grad, - DenseTensor* x_grad, - DenseTensor* y_grad) {} -} // namespace phi - -PD_REGISTER_KERNEL( - nextafter_grad, CPU, ALL_LAYOUT, phi::NextafterGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/nextafter_kernel.cc b/paddle/phi/kernels/cpu/nextafter_kernel.cc index abd0bf483198d..6cc743caba5a5 100644 --- a/paddle/phi/kernels/cpu/nextafter_kernel.cc +++ b/paddle/phi/kernels/cpu/nextafter_kernel.cc @@ -13,8 +13,10 @@ // limitations under the License. #include "paddle/phi/kernels/nextafter_kernel.h" +#include #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math.h" namespace phi { @@ -26,9 +28,9 @@ void NextafterKernel(const Context& ctx, if (x.numel() == 0 || y.numel() == 0) { return; } - auto* out_data = dev_ctx.template Alloc(out); - auto* x_data = x.data(); - auto* y_data = y.data(); + auto out_data = dev_ctx.template Alloc(out); + auto x_data = x.data(); + auto y_data = y.data(); int x_numel = x.numel(); for (int i = 0; i < x_numel; ++i) { diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index f6244c51fea83..6cab2c374e1be 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -295,6 +295,7 @@ from .tensor.math import trapezoid # noqa: F401 from .tensor.math import cumulative_trapezoid # noqa: F401 from .tensor.math import vander # noqa: F401 +from .tensor.math import nextafter # noqa: F401 from .tensor.random import bernoulli # noqa: F401 from .tensor.random import poisson # noqa: F401 @@ -687,4 +688,5 @@ 'cumulative_trapezoid', 'polar', 'vander', + 'nextafter', ] diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index cc662d83457fa..db834f379ba95 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5413,3 +5413,35 @@ def vander(x, n=None, increasing=False, name=None): res[:, 1:] = paddle.cumprod(res[:, 1:], dim=-1) res = res[:, ::-1] if not increasing else res return res + + +def nextafter(x, y, name=None): + r""" + Return the next floating-point value after input towards other, elementwise. + The shapes of input and other must be broadcastable. + + Args: + x (Tensor): An N-D Tensor, the data type is float32, float64. + y (Tensor): An N-D Tensor, the data type is float32, float64. + name(str, optional):Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): An N-D Tensor, the shape and data type is the same with input. + + Examples: + .. code-block:: python + + import paddle + + """ + if in_dygraph_mode(): + return _C_ops.nextafter(x, y) + else: + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'nextafter') + check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'nextafter') + helper = LayerHelper('nextafter', **locals()) + out = helper.create_variable_for_type_inference(dtype=paddle.float32) + helper.append_op( + type='trace', inputs={'X': x, 'Y': y}, outputs={'Out': out} + ) + return out From 96f562f6bc91d6d5cd7b945d5bd4f31b7f998e4e Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 4 Apr 2023 16:06:47 +0000 Subject: [PATCH 03/27] add unittest --- paddle/phi/api/yaml/ops.yaml | 2 +- paddle/phi/infermeta/binary.cc | 11 ---- paddle/phi/kernels/cpu/nextafter_kernel.cc | 2 +- .../tests/unittests/test_nextafter_op.py | 65 +++++++++++++++++++ python/paddle/tensor/math.py | 7 +- 5 files changed, 72 insertions(+), 15 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_nextafter_op.py diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index d96a7f905fec6..d79093d1ea023 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1121,7 +1121,7 @@ args : (Tensor x, Tensor y) output : Tensor(out) infer_meta : - func : ElementwiseInferMeta + func : NextafterInferMeta kernel : func : nextafter diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 21c3dcb2d2674..8c349fa877924 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2286,17 +2286,6 @@ void NextafterInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { auto x_dims = x.dims(); - auto y_dims = y.dims(); - - if (x_dims.size() > 0) - PADDLE_ENFORCE_GE(x_dims[0], - y_dims[0], - phi::errors::InvalidArgument( - "The count (%d) of elements of X shall " - "greater than count (%d) of elements of Y.", - x_dims[0], - y_dims[0])); - out->set_dims(x_dims); out->set_dtype(x.dtype()); out->share_lod(x); diff --git a/paddle/phi/kernels/cpu/nextafter_kernel.cc b/paddle/phi/kernels/cpu/nextafter_kernel.cc index 6cc743caba5a5..e4ae3af65057b 100644 --- a/paddle/phi/kernels/cpu/nextafter_kernel.cc +++ b/paddle/phi/kernels/cpu/nextafter_kernel.cc @@ -28,7 +28,7 @@ void NextafterKernel(const Context& ctx, if (x.numel() == 0 || y.numel() == 0) { return; } - auto out_data = dev_ctx.template Alloc(out); + auto out_data = ctx.template Alloc(out); auto x_data = x.data(); auto y_data = y.data(); int x_numel = x.numel(); diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py new file mode 100644 index 0000000000000..2de978afeb7c0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -0,0 +1,65 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle + + +def ref_nextafter(x, y): + out = np.nextafter(x, y) + return out + + +class TestNextafterAPI(unittest.TestCase): + def setUp(self): + self.x_np = np.random.rand(1, 2).astype('float32') + self.y_np = np.random.rand(1, 2).astype('float32') + self.place = ( + paddle.CUDAPlace(0) + if paddle.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data( + name="X", shape=self.x_np.shape, dtype='float32' + ) + y = paddle.static.data( + name="Y", shape=self.y_np.shape, dtype='float32' + ) + out = paddle.nextafter(x, y) + exe = paddle.static.Executor(self.place) + res = exe.run( + feed={'X': self.x_np, 'Y': self.y_np}, fetch_list=[out] + ) + out_ref = ref_nextafter(self.x_np, self.y_np) + np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + y = paddle.to_tensor(self.y_np) + out = paddle.nextafter(x, y) + out_ref = ref_nextafter(self.x_np, self.y_np) + np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05) + paddle.enable_static() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index db834f379ba95..abfe44f86a2ed 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5432,7 +5432,10 @@ def nextafter(x, y, name=None): .. code-block:: python import paddle - + out = paddle.nextafter(paddle.to_tensor([1.0,2.0]),paddle.to_tensor([2.0,1.0])) + print(out) + #Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + # [1.00000012, 1.99999988]) """ if in_dygraph_mode(): return _C_ops.nextafter(x, y) @@ -5442,6 +5445,6 @@ def nextafter(x, y, name=None): helper = LayerHelper('nextafter', **locals()) out = helper.create_variable_for_type_inference(dtype=paddle.float32) helper.append_op( - type='trace', inputs={'X': x, 'Y': y}, outputs={'Out': out} + type='nextafter', inputs={'X': x, 'Y': y}, outputs={'Out': out} ) return out From 798cc119d183d29d8538a0f91b49dd5b60a690f6 Mon Sep 17 00:00:00 2001 From: enkilee Date: Wed, 5 Apr 2023 00:51:44 +0000 Subject: [PATCH 04/27] add gpu file --- paddle/phi/kernels/gpu/nextafter_kernel.cu | 44 ++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 paddle/phi/kernels/gpu/nextafter_kernel.cu diff --git a/paddle/phi/kernels/gpu/nextafter_kernel.cu b/paddle/phi/kernels/gpu/nextafter_kernel.cu new file mode 100644 index 0000000000000..35d3e2ca9d742 --- /dev/null +++ b/paddle/phi/kernels/gpu/nextafter_kernel.cu @@ -0,0 +1,44 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nextafter_kernel.h" +#include +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math.h" + +namespace phi { + +template +void NextafterKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + if (x.numel() == 0 || y.numel() == 0) { + return; + } + auto out_data = ctx.template Alloc(out); + auto x_data = x.data(); + auto y_data = y.data(); + int x_numel = x.numel(); + + for (int i = 0; i < x_numel; ++i) { + out_data[i] = std::nextafter(x_data[i], y_data[i]); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + nextafter, GPU, ALL_LAYOUT, phi::NextafterKernel, float, double) {} From 4e771429c61ec4fdf2cf0f8d086e778698e52b53 Mon Sep 17 00:00:00 2001 From: enkilee Date: Wed, 5 Apr 2023 13:35:12 +0000 Subject: [PATCH 05/27] fix --- paddle/phi/kernels/nextafter_grad_kernel.h | 28 ---------------------- paddle/phi/kernels/nextafter_kernel.h | 8 +++---- 2 files changed, 4 insertions(+), 32 deletions(-) delete mode 100644 paddle/phi/kernels/nextafter_grad_kernel.h diff --git a/paddle/phi/kernels/nextafter_grad_kernel.h b/paddle/phi/kernels/nextafter_grad_kernel.h deleted file mode 100644 index 52ddf889cbf11..0000000000000 --- a/paddle/phi/kernels/nextafter_grad_kernel.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -namespace phi { - -template -void NextafterGradKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - const DenseTensor& out_grad, - DenseTensor* x_grad, - DenseTensor* y_grad); -} // namespace phi diff --git a/paddle/phi/kernels/nextafter_kernel.h b/paddle/phi/kernels/nextafter_kernel.h index 047917dbd6753..3a185e39bd940 100644 --- a/paddle/phi/kernels/nextafter_kernel.h +++ b/paddle/phi/kernels/nextafter_kernel.h @@ -20,9 +20,9 @@ namespace phi { template -void NextafterRawKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out); +void NextafterKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); } // namespace phi From 28781c98305b7cde53d51a77499f1aa758a4c32c Mon Sep 17 00:00:00 2001 From: enkilee Date: Wed, 5 Apr 2023 14:24:08 +0000 Subject: [PATCH 06/27] fix --- paddle/phi/kernels/cpu/nextafter_kernel.cc | 26 +----------- paddle/phi/kernels/gpu/nextafter_kernel.cu | 26 +----------- .../phi/kernels/impl/nextafter_kernel_impl.h | 40 +++++++++++++++++++ 3 files changed, 44 insertions(+), 48 deletions(-) create mode 100644 paddle/phi/kernels/impl/nextafter_kernel_impl.h diff --git a/paddle/phi/kernels/cpu/nextafter_kernel.cc b/paddle/phi/kernels/cpu/nextafter_kernel.cc index e4ae3af65057b..ac4ab00a4d3fe 100644 --- a/paddle/phi/kernels/cpu/nextafter_kernel.cc +++ b/paddle/phi/kernels/cpu/nextafter_kernel.cc @@ -13,32 +13,10 @@ // limitations under the License. #include "paddle/phi/kernels/nextafter_kernel.h" -#include + #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/math.h" - -namespace phi { - -template -void NextafterKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { - if (x.numel() == 0 || y.numel() == 0) { - return; - } - auto out_data = ctx.template Alloc(out); - auto x_data = x.data(); - auto y_data = y.data(); - int x_numel = x.numel(); - - for (int i = 0; i < x_numel; ++i) { - out_data[i] = std::nextafter(x_data[i], y_data[i]); - } -} - -} // namespace phi +#include "paddle/phi/kernels/impl/nextafter_kernel_impl.h" PD_REGISTER_KERNEL( nextafter, CPU, ALL_LAYOUT, phi::NextafterKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/nextafter_kernel.cu b/paddle/phi/kernels/gpu/nextafter_kernel.cu index 35d3e2ca9d742..e0ac8212853c9 100644 --- a/paddle/phi/kernels/gpu/nextafter_kernel.cu +++ b/paddle/phi/kernels/gpu/nextafter_kernel.cu @@ -13,32 +13,10 @@ // limitations under the License. #include "paddle/phi/kernels/nextafter_kernel.h" -#include + #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/math.h" - -namespace phi { - -template -void NextafterKernel(const Context& ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* out) { - if (x.numel() == 0 || y.numel() == 0) { - return; - } - auto out_data = ctx.template Alloc(out); - auto x_data = x.data(); - auto y_data = y.data(); - int x_numel = x.numel(); - - for (int i = 0; i < x_numel; ++i) { - out_data[i] = std::nextafter(x_data[i], y_data[i]); - } -} - -} // namespace phi +#include "paddle/phi/kernels/impl/nextafter_kernel_impl.h" PD_REGISTER_KERNEL( nextafter, GPU, ALL_LAYOUT, phi::NextafterKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/nextafter_kernel_impl.h b/paddle/phi/kernels/impl/nextafter_kernel_impl.h new file mode 100644 index 0000000000000..2f88c849162e8 --- /dev/null +++ b/paddle/phi/kernels/impl/nextafter_kernel_impl.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/phi/kernels/funcs/math.h" +#include "paddle/phi/kernels/nextafter_kernel.h" + +namespace phi { + +template +void NextafterKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + if (x.numel() == 0 || y.numel() == 0) { + return; + } + auto out_data = ctx.template Alloc(out); + auto x_data = x.data(); + auto y_data = y.data(); + int x_numel = x.numel(); + + for (int i = 0; i < x_numel; ++i) { + out_data[i] = std::nextafter(x_data[i], y_data[i]); + } +} + +} // namespace phi From 6b88bfc543bebac529fe31e1258b8a7c47de1da9 Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 6 Apr 2023 14:38:56 +0000 Subject: [PATCH 07/27] fix gpu --- .../phi/kernels/impl/nextafter_kernel_impl.h | 62 ++++++++++++++++--- 1 file changed, 53 insertions(+), 9 deletions(-) diff --git a/paddle/phi/kernels/impl/nextafter_kernel_impl.h b/paddle/phi/kernels/impl/nextafter_kernel_impl.h index 2f88c849162e8..d90f9b9640468 100644 --- a/paddle/phi/kernels/impl/nextafter_kernel_impl.h +++ b/paddle/phi/kernels/impl/nextafter_kernel_impl.h @@ -14,27 +14,71 @@ limitations under the License. */ #pragma once #include +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/math.h" #include "paddle/phi/kernels/nextafter_kernel.h" - namespace phi { +template +struct NextafterOut { + using type = T; +}; + +template <> +struct NextafterOut { + using type = double; +}; + +template <> +struct NextafterOut { + using type = double; +}; +template +struct NextafterFunctor { + NextafterFunctor(const T* x, + const T* y, + typename NextafterOut::type* out, + int64_t numel) + : x_(x), y_(y), out_(out), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + out_[idx] = static_cast::type>( + ::nextafter(static_cast(x_[idx]), static_cast(y_[idx]))); + } + const T* x_; + const T* y_; + typename NextafterOut::type* out_; + int64_t numel_; +}; +template <> +struct NextafterFunctor { + NextafterFunctor(const double* x, const double* y, double* out, int64_t numel) + : x_(x), y_(y), out_(out), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + out_[idx] = ::nextafter(x_[idx], y_[idx]); + } + + const double* x_; + const double* y_; + double* out_; + int64_t numel_; +}; template void NextafterKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* out) { - if (x.numel() == 0 || y.numel() == 0) { - return; - } - auto out_data = ctx.template Alloc(out); + auto* out_data = ctx.template Alloc(out); auto x_data = x.data(); auto y_data = y.data(); - int x_numel = x.numel(); + auto x_numel = x.numel(); - for (int i = 0; i < x_numel; ++i) { - out_data[i] = std::nextafter(x_data[i], y_data[i]); - } + phi::funcs::ForRange for_range(ctx, x_numel); + phi::NextafterFunctor functor(x_data, y_data, out_data, x_numel); + for_range(functor); } } // namespace phi From 9d4b9236b8f7a927b91617f89818c41367718734 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 18 Apr 2023 01:02:20 +0000 Subject: [PATCH 08/27] add Optest --- .../tests/unittests/test_nextafter_op.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py index 2de978afeb7c0..e441862da8ca5 100644 --- a/python/paddle/fluid/tests/unittests/test_nextafter_op.py +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from eager_op_test import OpTest import paddle @@ -61,5 +62,24 @@ def test_dygraph_api(self): paddle.enable_static() +class TextNextafterOP(OpTest): + def setUp(self): + self.op_type = "nextafter" + self.python_api = paddle.nextafter + self.init_dtype() + + x = np.array([1, 2]).astype(self.dtype) + y = np.array([2, 1]).astype(self.dtype) + out = np.nextafter(x, y) + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def init_dtype(self): + self.dtype = np.float64 + + if __name__ == "__main__": unittest.main() From 6ad9198f13caa0ef319688d1e7ff04d9d31ead63 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 18 Apr 2023 06:46:49 +0000 Subject: [PATCH 09/27] add fp32 optest --- python/paddle/fluid/tests/unittests/test_nextafter_op.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py index e441862da8ca5..17224929fdcd5 100644 --- a/python/paddle/fluid/tests/unittests/test_nextafter_op.py +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -62,7 +62,7 @@ def test_dygraph_api(self): paddle.enable_static() -class TextNextafterOP(OpTest): +class TestNextafterOP(OpTest): def setUp(self): self.op_type = "nextafter" self.python_api = paddle.nextafter @@ -81,5 +81,10 @@ def init_dtype(self): self.dtype = np.float64 +class TestNextafterOPFP32(TestNextafterOP): + def init_dtype(self): + self.dtype = np.float32 + + if __name__ == "__main__": unittest.main() From 233fa1d611f2deca7eaaa45341fd75746113f1cb Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 21 Apr 2023 05:05:18 +0000 Subject: [PATCH 10/27] fix --- paddle/phi/api/yaml/ops.yaml | 2 +- paddle/phi/infermeta/binary.cc | 9 --------- paddle/phi/kernels/impl/nextafter_kernel_impl.h | 6 +++--- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 737be63bef89d..9614c0893602d 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1132,7 +1132,7 @@ args : (Tensor x, Tensor y) output : Tensor(out) infer_meta : - func : NextafterInferMeta + func : ElementwiseInferMeta kernel : func : nextafter diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 8c349fa877924..0a3e31054b0f5 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2282,15 +2282,6 @@ void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) { out->share_lod(x); } -void NextafterInferMeta(const MetaTensor& x, - const MetaTensor& y, - MetaTensor* out) { - auto x_dims = x.dims(); - out->set_dims(x_dims); - out->set_dtype(x.dtype()); - out->share_lod(x); -} - void PReluInferMeta(const MetaTensor& x, const MetaTensor& alpha, const std::string& data_format, diff --git a/paddle/phi/kernels/impl/nextafter_kernel_impl.h b/paddle/phi/kernels/impl/nextafter_kernel_impl.h index d90f9b9640468..6d54009282528 100644 --- a/paddle/phi/kernels/impl/nextafter_kernel_impl.h +++ b/paddle/phi/kernels/impl/nextafter_kernel_impl.h @@ -43,8 +43,8 @@ struct NextafterFunctor { : x_(x), y_(y), out_(out), numel_(numel) {} HOSTDEVICE void operator()(int64_t idx) const { - out_[idx] = static_cast::type>( - ::nextafter(static_cast(x_[idx]), static_cast(y_[idx]))); + out_[idx] = static_cast::type>(std::nextafter( + static_cast(x_[idx]), static_cast(y_[idx]))); } const T* x_; const T* y_; @@ -57,7 +57,7 @@ struct NextafterFunctor { : x_(x), y_(y), out_(out), numel_(numel) {} HOSTDEVICE void operator()(int64_t idx) const { - out_[idx] = ::nextafter(x_[idx], y_[idx]); + out_[idx] = std::nextafter(x_[idx], y_[idx]); } const double* x_; From 7cbffb0b1146f493ca8dc774b8d7d559c2bded96 Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 21 Apr 2023 05:06:55 +0000 Subject: [PATCH 11/27] fix --- paddle/phi/api/yaml/op_compat.yaml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index d21b42ecece0d..e97af1e9c3b14 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1393,12 +1393,6 @@ extra : attrs : [bool use_mkldnn = false] -- op : nextafter - inputs : - {x: X, y : Y} - outputs : - out: Out - - op : nll_loss backward : nll_loss_grad inputs : From 4dbb5fcf337e151692fff9171f96c6ad99c2d722 Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 21 Apr 2023 06:41:47 +0000 Subject: [PATCH 12/27] fix --- python/paddle/fluid/tests/unittests/test_nextafter_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py index 17224929fdcd5..ba7866ab172ca 100644 --- a/python/paddle/fluid/tests/unittests/test_nextafter_op.py +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -39,10 +39,10 @@ def test_static_api(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.static.data( - name="X", shape=self.x_np.shape, dtype='float32' + name='X', shape=self.x_np.shape, dtype='float32' ) y = paddle.static.data( - name="Y", shape=self.y_np.shape, dtype='float32' + name='Y', shape=self.y_np.shape, dtype='float32' ) out = paddle.nextafter(x, y) exe = paddle.static.Executor(self.place) From 5e736caa6df902d63969008a055907e7bdf48683 Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 21 Apr 2023 08:13:18 +0000 Subject: [PATCH 13/27] fix --- paddle/phi/api/yaml/ops.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 9614c0893602d..5423ba22a4fef 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1132,6 +1132,7 @@ args : (Tensor x, Tensor y) output : Tensor(out) infer_meta : + param: [x, y] func : ElementwiseInferMeta kernel : func : nextafter From a0ef1da1ae13421ea7dc37a1c38c23efd2b32944 Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 21 Apr 2023 14:55:33 +0000 Subject: [PATCH 14/27] fix --- paddle/phi/api/yaml/ops.yaml | 1 + python/paddle/__init__.py | 3 +-- python/paddle/tensor/__init__.py | 2 ++ python/paddle/tensor/math.py | 9 +-------- 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 5423ba22a4fef..8d4f3331343f8 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1136,6 +1136,7 @@ func : ElementwiseInferMeta kernel : func : nextafter + backend : x - op : nll_loss args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index = -100, str reduction = "mean") diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 6cab2c374e1be..417f83a8452b2 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -295,7 +295,7 @@ from .tensor.math import trapezoid # noqa: F401 from .tensor.math import cumulative_trapezoid # noqa: F401 from .tensor.math import vander # noqa: F401 -from .tensor.math import nextafter # noqa: F401 + from .tensor.random import bernoulli # noqa: F401 from .tensor.random import poisson # noqa: F401 @@ -688,5 +688,4 @@ 'cumulative_trapezoid', 'polar', 'vander', - 'nextafter', ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b78ac0e57c22e..bea2fd7323d9f 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -251,6 +251,7 @@ from .math import sigmoid # noqa: F401 from .math import sigmoid_ # noqa: F401 from .math import vander # noqa: F401 +from .math import nextafter # noqa: F401 from .random import multinomial # noqa: F401 from .random import standard_normal # noqa: F401 @@ -540,6 +541,7 @@ 'sigmoid', 'sigmoid_', 'vander', + 'nextafter', ] # this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 31177ebad9e31..004b01270bf8c 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5442,11 +5442,4 @@ def nextafter(x, y, name=None): if in_dygraph_mode(): return _C_ops.nextafter(x, y) else: - check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'nextafter') - check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'nextafter') - helper = LayerHelper('nextafter', **locals()) - out = helper.create_variable_for_type_inference(dtype=paddle.float32) - helper.append_op( - type='nextafter', inputs={'X': x, 'Y': y}, outputs={'Out': out} - ) - return out + return _elementwise_op(LayerHelper('nextafter', **locals())) From 88ec2979c09898ec101205e05107ba083424d2de Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 21 Apr 2023 15:35:34 +0000 Subject: [PATCH 15/27] fix --- python/paddle/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 417f83a8452b2..6cab2c374e1be 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -295,7 +295,7 @@ from .tensor.math import trapezoid # noqa: F401 from .tensor.math import cumulative_trapezoid # noqa: F401 from .tensor.math import vander # noqa: F401 - +from .tensor.math import nextafter # noqa: F401 from .tensor.random import bernoulli # noqa: F401 from .tensor.random import poisson # noqa: F401 @@ -688,4 +688,5 @@ 'cumulative_trapezoid', 'polar', 'vander', + 'nextafter', ] From 38cf25c5187aef4a7d18ec41b78b0a3dc3a73bfa Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 21 Apr 2023 16:13:56 +0000 Subject: [PATCH 16/27] re --- paddle/phi/api/yaml/ops.yaml | 4 +--- paddle/phi/infermeta/binary.cc | 9 +++++++++ python/paddle/tensor/math.py | 9 +++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 8d4f3331343f8..737be63bef89d 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1132,11 +1132,9 @@ args : (Tensor x, Tensor y) output : Tensor(out) infer_meta : - param: [x, y] - func : ElementwiseInferMeta + func : NextafterInferMeta kernel : func : nextafter - backend : x - op : nll_loss args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index = -100, str reduction = "mean") diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 0a3e31054b0f5..8c349fa877924 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2282,6 +2282,15 @@ void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) { out->share_lod(x); } +void NextafterInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* out) { + auto x_dims = x.dims(); + out->set_dims(x_dims); + out->set_dtype(x.dtype()); + out->share_lod(x); +} + void PReluInferMeta(const MetaTensor& x, const MetaTensor& alpha, const std::string& data_format, diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 004b01270bf8c..6539314201bf2 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5443,3 +5443,12 @@ def nextafter(x, y, name=None): return _C_ops.nextafter(x, y) else: return _elementwise_op(LayerHelper('nextafter', **locals())) + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'nextafter') + check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'nextafter') + op_type = "nextafter" + helper = LayerHelper(op_type, **locals()) + inputs = {"X": x, "Y": y} + out = helper.create_variable_for_type_inference(dtype=paddle.float32) + outputs = {"Out": out} + helper.append_op(type=op_type, inputs=inputs, outputs=outputs) + return out From 20b587f1d9333e13ef6c7083b3b11076ae7c6fe5 Mon Sep 17 00:00:00 2001 From: enkilee Date: Sat, 22 Apr 2023 05:46:01 +0000 Subject: [PATCH 17/27] fix --- paddle/phi/api/yaml/ops.yaml | 4 +++- paddle/phi/infermeta/binary.cc | 9 --------- python/paddle/tensor/math.py | 3 +-- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 737be63bef89d..e98b2c66d19f7 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1132,9 +1132,11 @@ args : (Tensor x, Tensor y) output : Tensor(out) infer_meta : - func : NextafterInferMeta + func : ElementwiseInferMeta + param: [x, y] kernel : func : nextafter + data_type : x - op : nll_loss args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index = -100, str reduction = "mean") diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 8c349fa877924..0a3e31054b0f5 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2282,15 +2282,6 @@ void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) { out->share_lod(x); } -void NextafterInferMeta(const MetaTensor& x, - const MetaTensor& y, - MetaTensor* out) { - auto x_dims = x.dims(); - out->set_dims(x_dims); - out->set_dtype(x.dtype()); - out->share_lod(x); -} - void PReluInferMeta(const MetaTensor& x, const MetaTensor& alpha, const std::string& data_format, diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 6539314201bf2..eb5ddb5c3922f 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5442,7 +5442,6 @@ def nextafter(x, y, name=None): if in_dygraph_mode(): return _C_ops.nextafter(x, y) else: - return _elementwise_op(LayerHelper('nextafter', **locals())) check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'nextafter') check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'nextafter') op_type = "nextafter" @@ -5451,4 +5450,4 @@ def nextafter(x, y, name=None): out = helper.create_variable_for_type_inference(dtype=paddle.float32) outputs = {"Out": out} helper.append_op(type=op_type, inputs=inputs, outputs=outputs) - return out + return out From de13f4b7570d4aa6268d6be64a8981850f2fe077 Mon Sep 17 00:00:00 2001 From: enkilee Date: Sat, 22 Apr 2023 06:26:31 +0000 Subject: [PATCH 18/27] fix --- .../paddle/fluid/tests/unittests/test_nextafter_op.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py index ba7866ab172ca..91d8085a88e0e 100644 --- a/python/paddle/fluid/tests/unittests/test_nextafter_op.py +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -39,15 +39,15 @@ def test_static_api(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.static.data( - name='X', shape=self.x_np.shape, dtype='float32' + name='x', shape=self.x_np.shape, dtype='float32' ) y = paddle.static.data( - name='Y', shape=self.y_np.shape, dtype='float32' + name='y', shape=self.y_np.shape, dtype='float32' ) out = paddle.nextafter(x, y) exe = paddle.static.Executor(self.place) res = exe.run( - feed={'X': self.x_np, 'Y': self.y_np}, fetch_list=[out] + feed={'x': self.x_np, 'y': self.y_np}, fetch_list=[out] ) out_ref = ref_nextafter(self.x_np, self.y_np) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) @@ -71,8 +71,8 @@ def setUp(self): x = np.array([1, 2]).astype(self.dtype) y = np.array([2, 1]).astype(self.dtype) out = np.nextafter(x, y) - self.inputs = {'X': x, 'Y': y} - self.outputs = {'Out': out} + self.inputs = {'x': x, 'y': y} + self.outputs = {'out': out} def test_check_output(self): self.check_output() From d0068ef052292179c921d5929e8a8083fb3166be Mon Sep 17 00:00:00 2001 From: enkilee Date: Sat, 22 Apr 2023 07:02:49 +0000 Subject: [PATCH 19/27] fix --- .../fluid/tests/unittests/test_nextafter_op.py | 14 ++++++-------- python/paddle/tensor/math.py | 4 ++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py index 91d8085a88e0e..b6608c0466eac 100644 --- a/python/paddle/fluid/tests/unittests/test_nextafter_op.py +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -27,8 +27,8 @@ def ref_nextafter(x, y): class TestNextafterAPI(unittest.TestCase): def setUp(self): - self.x_np = np.random.rand(1, 2).astype('float32') - self.y_np = np.random.rand(1, 2).astype('float32') + self.x = np.random.rand(1, 2).astype('float32') + self.y = np.random.rand(1, 2).astype('float32') self.place = ( paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() @@ -39,17 +39,15 @@ def test_static_api(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.static.data( - name='x', shape=self.x_np.shape, dtype='float32' + name='x', shape=self.x.shape, dtype='float32' ) y = paddle.static.data( - name='y', shape=self.y_np.shape, dtype='float32' + name='y', shape=self.y.shape, dtype='float32' ) out = paddle.nextafter(x, y) exe = paddle.static.Executor(self.place) - res = exe.run( - feed={'x': self.x_np, 'y': self.y_np}, fetch_list=[out] - ) - out_ref = ref_nextafter(self.x_np, self.y_np) + res = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out]) + out_ref = ref_nextafter(self.x, self.y) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) def test_dygraph_api(self): diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index eb5ddb5c3922f..8996b1ec61097 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5446,8 +5446,8 @@ def nextafter(x, y, name=None): check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'nextafter') op_type = "nextafter" helper = LayerHelper(op_type, **locals()) - inputs = {"X": x, "Y": y} + inputs = {"x": x, "y": y} out = helper.create_variable_for_type_inference(dtype=paddle.float32) - outputs = {"Out": out} + outputs = {"out": out} helper.append_op(type=op_type, inputs=inputs, outputs=outputs) return out From e8db15f11ee57443d904a52265effa5f021fa43f Mon Sep 17 00:00:00 2001 From: enkilee Date: Sat, 22 Apr 2023 07:04:49 +0000 Subject: [PATCH 20/27] fix --- python/paddle/fluid/tests/unittests/test_nextafter_op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py index b6608c0466eac..34c4cb344b600 100644 --- a/python/paddle/fluid/tests/unittests/test_nextafter_op.py +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -52,10 +52,10 @@ def test_static_api(self): def test_dygraph_api(self): paddle.disable_static(self.place) - x = paddle.to_tensor(self.x_np) - y = paddle.to_tensor(self.y_np) + x = paddle.to_tensor(self.x) + y = paddle.to_tensor(self.y) out = paddle.nextafter(x, y) - out_ref = ref_nextafter(self.x_np, self.y_np) + out_ref = ref_nextafter(self.x, self.y) np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05) paddle.enable_static() From ef9aac7a0ac8053cb4180860e4e8b065e35556d2 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 25 Apr 2023 02:33:13 +0000 Subject: [PATCH 21/27] fix --- paddle/phi/infermeta/binary.h | 4 --- .../tests/unittests/test_nextafter_op.py | 30 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index e643c33b33c39..ed4da703ce520 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -366,10 +366,6 @@ void MatrixRankTolInferMeta(const MetaTensor& x, void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out); -void NextafterInferMeta(const MetaTensor& x, - const MetaTensor& y, - MetaTensor* out); - void PReluInferMeta(const MetaTensor& x, const MetaTensor& alpha, const std::string& data_format, diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py index 34c4cb344b600..8079ef4661ebf 100644 --- a/python/paddle/fluid/tests/unittests/test_nextafter_op.py +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -29,6 +29,10 @@ class TestNextafterAPI(unittest.TestCase): def setUp(self): self.x = np.random.rand(1, 2).astype('float32') self.y = np.random.rand(1, 2).astype('float32') + self.x1 = np.array([0, 0, 10]).astype("float32") + self.y1 = np.array([np.inf, -np.inf, 10]).astype("float32") + self.x2 = np.array([np.nan, 0]).astype("float32") + self.y2 = np.array([0, np.nan]).astype("float32") self.place = ( paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() @@ -50,6 +54,32 @@ def test_static_api(self): out_ref = ref_nextafter(self.x, self.y) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) + with paddle.static.program_guard(paddle.static.Program()): + x1 = paddle.static.data( + name='x', shape=self.x1.shape, dtype='float32' + ) + y1 = paddle.static.data( + name='y', shape=self.y1.shape, dtype='float32' + ) + out = paddle.nextafter(x1, y1) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'x': self.x1, 'y': self.y1}, fetch_list=[out]) + out_ref = ref_nextafter(self.x1, self.y1) + np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) + + with paddle.static.program_guard(paddle.static.Program()): + x2 = paddle.static.data( + name='x', shape=self.x2.shape, dtype='float32' + ) + y2 = paddle.static.data( + name='y', shape=self.y2.shape, dtype='float32' + ) + out = paddle.nextafter(x2, y2) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'x': self.x2, 'y': self.y2}, fetch_list=[out]) + out_ref = ref_nextafter(self.x2, self.y2) + self.assertTrue((out.numpy() == out_ref).all(), True) + def test_dygraph_api(self): paddle.disable_static(self.place) x = paddle.to_tensor(self.x) From 458071ac6cd1ec5db04eb69af218ed13c72a09c3 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 25 Apr 2023 03:47:08 +0000 Subject: [PATCH 22/27] fix --- python/paddle/fluid/tests/unittests/test_nextafter_op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py index 8079ef4661ebf..5913a278610e5 100644 --- a/python/paddle/fluid/tests/unittests/test_nextafter_op.py +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -31,8 +31,8 @@ def setUp(self): self.y = np.random.rand(1, 2).astype('float32') self.x1 = np.array([0, 0, 10]).astype("float32") self.y1 = np.array([np.inf, -np.inf, 10]).astype("float32") - self.x2 = np.array([np.nan, 0]).astype("float32") - self.y2 = np.array([0, np.nan]).astype("float32") + self.x2 = np.random.rand(100).astype("float32") + self.y2 = np.random.rand(100).astype("float32") self.place = ( paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() @@ -78,7 +78,7 @@ def test_static_api(self): exe = paddle.static.Executor(self.place) res = exe.run(feed={'x': self.x2, 'y': self.y2}, fetch_list=[out]) out_ref = ref_nextafter(self.x2, self.y2) - self.assertTrue((out.numpy() == out_ref).all(), True) + np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) def test_dygraph_api(self): paddle.disable_static(self.place) From 4e433d260515d65dc6c052b76d3a72f572dcc583 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 25 Apr 2023 04:33:26 +0000 Subject: [PATCH 23/27] add test --- .../paddle/fluid/tests/unittests/test_nextafter_op.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py index 5913a278610e5..c7e0319b1d457 100644 --- a/python/paddle/fluid/tests/unittests/test_nextafter_op.py +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -80,6 +80,17 @@ def test_static_api(self): out_ref = ref_nextafter(self.x2, self.y2) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) + def test_nan_case(self): + paddle.disable_static() + x_data = np.array([0, np.nan]).astype("float32") + y_data = np.array([np.nan, 0]).astype("float32") + x = paddle.to_tensor(x_data) + y = paddle.to_tensor(y_data) + out = paddle.nextafter(x, y) + expected_out = np.nextafter(x_data, y_data) + self.assertTrue((out.numpy() == expected_out).all(), True) + paddle.enable_static() + def test_dygraph_api(self): paddle.disable_static(self.place) x = paddle.to_tensor(self.x) From fe308eabf78a9effdfa2689a4fa117b1719ead33 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 25 Apr 2023 06:53:11 +0000 Subject: [PATCH 24/27] fix --- python/paddle/fluid/tests/unittests/test_nextafter_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py index c7e0319b1d457..2a56c9932c123 100644 --- a/python/paddle/fluid/tests/unittests/test_nextafter_op.py +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -87,8 +87,8 @@ def test_nan_case(self): x = paddle.to_tensor(x_data) y = paddle.to_tensor(y_data) out = paddle.nextafter(x, y) - expected_out = np.nextafter(x_data, y_data) - self.assertTrue((out.numpy() == expected_out).all(), True) + expected_out = ref_nextafter(x_data, y_data) + self.assertEqual(out.numpy(), expected_out) paddle.enable_static() def test_dygraph_api(self): From 9cd6e14dc222577d766843f9a2825868acc886c3 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 25 Apr 2023 07:07:53 +0000 Subject: [PATCH 25/27] fix --- python/paddle/fluid/tests/unittests/test_nextafter_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py index 2a56c9932c123..d8356622e5a04 100644 --- a/python/paddle/fluid/tests/unittests/test_nextafter_op.py +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -88,7 +88,7 @@ def test_nan_case(self): y = paddle.to_tensor(y_data) out = paddle.nextafter(x, y) expected_out = ref_nextafter(x_data, y_data) - self.assertEqual(out.numpy(), expected_out) + self.assertEqual((out.numpy() == expected_out).all(), True) paddle.enable_static() def test_dygraph_api(self): From 6cdfa21ba0217774edab9b6d8902a2b6efd56784 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 25 Apr 2023 09:12:32 +0000 Subject: [PATCH 26/27] fix --- .../paddle/fluid/tests/unittests/test_nextafter_op.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py index d8356622e5a04..5913a278610e5 100644 --- a/python/paddle/fluid/tests/unittests/test_nextafter_op.py +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -80,17 +80,6 @@ def test_static_api(self): out_ref = ref_nextafter(self.x2, self.y2) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) - def test_nan_case(self): - paddle.disable_static() - x_data = np.array([0, np.nan]).astype("float32") - y_data = np.array([np.nan, 0]).astype("float32") - x = paddle.to_tensor(x_data) - y = paddle.to_tensor(y_data) - out = paddle.nextafter(x, y) - expected_out = ref_nextafter(x_data, y_data) - self.assertEqual((out.numpy() == expected_out).all(), True) - paddle.enable_static() - def test_dygraph_api(self): paddle.disable_static(self.place) x = paddle.to_tensor(self.x) From 03478c305805b3f25313fedb76557b74c727ed3f Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 25 Apr 2023 15:25:23 +0000 Subject: [PATCH 27/27] fix --- python/paddle/fluid/tests/unittests/test_nextafter_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py index 5913a278610e5..5048778e1b7b7 100644 --- a/python/paddle/fluid/tests/unittests/test_nextafter_op.py +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -27,8 +27,8 @@ def ref_nextafter(x, y): class TestNextafterAPI(unittest.TestCase): def setUp(self): - self.x = np.random.rand(1, 2).astype('float32') - self.y = np.random.rand(1, 2).astype('float32') + self.x = np.random.rand(2, 3, 4, 5).astype('float32') + self.y = np.random.rand(2, 3, 4, 5).astype('float32') self.x1 = np.array([0, 0, 10]).astype("float32") self.y1 = np.array([np.inf, -np.inf, 10]).astype("float32") self.x2 = np.random.rand(100).astype("float32")