From 21c5464b5b4c3d5be71db9ac631258e4d9723616 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Thu, 13 Apr 2023 08:25:21 +0000 Subject: [PATCH 01/24] add index_put api --- paddle/phi/api/yaml/backward.yaml | 10 + paddle/phi/api/yaml/ops.yaml | 10 + paddle/phi/infermeta/multiary.cc | 15 + paddle/phi/infermeta/multiary.h | 6 + .../phi/kernels/cpu/index_put_grad_kernel.cc | 393 +++++++++++++++++ paddle/phi/kernels/cpu/index_put_kernel.cc | 274 ++++++++++++ .../phi/kernels/gpu/index_put_grad_kernel.cu | 414 ++++++++++++++++++ paddle/phi/kernels/gpu/index_put_kernel.cu | 284 ++++++++++++ paddle/phi/kernels/index_put_grad_kernel.h | 185 ++++++++ paddle/phi/kernels/index_put_kernel.h | 182 ++++++++ python/paddle/__init__.py | 4 + .../tests/unittests/test_index_put_op.py | 341 +++++++++++++++ python/paddle/tensor/__init__.py | 5 +- python/paddle/tensor/manipulation.py | 110 +++++ 14 files changed, 2232 insertions(+), 1 deletion(-) create mode 100644 paddle/phi/kernels/cpu/index_put_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/index_put_kernel.cc create mode 100644 paddle/phi/kernels/gpu/index_put_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/index_put_kernel.cu create mode 100644 paddle/phi/kernels/index_put_grad_kernel.h create mode 100644 paddle/phi/kernels/index_put_kernel.h create mode 100644 python/paddle/fluid/tests/unittests/test_index_put_op.py diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 6c0184d7d0ced4..d6e416a80bed6f 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -796,6 +796,16 @@ data_type : out_grad inplace : (out_grad -> x_grad) +- backward_op : index_put_grad + forward : index_put (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false) -> Tensor(out) + args : (Tensor x, Tensor[] indices, Tensor value, Tensor out_grad, bool accumulate=false) + output : Tensor(x_grad), Tensor(value_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, value] + kernel : + func : index_put_grad + - backward_op : index_sample_grad forward : index_sample (Tensor x, Tensor index) -> Tensor(out) args : (Tensor x, Tensor index, Tensor out_grad) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index f1cc5d1b5395f7..bdbcd12f56d15b 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -5,6 +5,16 @@ # are consistent and correspond one-to-one. It's forbidden that the # operator configured in this yaml file does not have Python API. 
+ - op : index_put + args : (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false) + output : Tensor(out) + infer_meta : + func : IndexPutInferMeta + kernel : + func : index_put + inplace : (x -> out) + backward : index_put_grad + - op : accuracy args : (Tensor x, Tensor indices, Tensor label) output : Tensor(accuracy), Tensor(correct), Tensor(total) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index bd38a2ec521d9b..0d681d9749f8c8 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -3241,5 +3241,20 @@ void MoeInferMeta(const MetaTensor& x, out->set_layout(x.layout()); } +void IndexPutInferMeta(const MetaTensor& x, + const std::vector& indices, + const MetaTensor& value, + bool accumulate, + MetaTensor* out) { + auto in_dims = x.dims(); + PADDLE_ENFORCE_LT( + in_dims.size(), + 7, + phi::errors::InvalidArgument( + "The rank of input should be less than 7, but received %d.", + in_dims.size())); + out->share_meta(x); +} + } // namespace phi PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 307e6115cfd566..d924942fc5ef1b 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -605,4 +605,10 @@ void MoeInferMeta(const MetaTensor& x, const std::string& act_type, MetaTensor* out); +void IndexPutInferMeta(const MetaTensor& x, + const std::vector& indices, + const MetaTensor& value, + bool accumulate, + MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc new file mode 100644 index 00000000000000..32a6c8b30c3d62 --- /dev/null +++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc @@ -0,0 +1,393 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
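+
+// Backward pass implemented in this file: x_grad starts as a copy of out_grad
+// and set_zero_kernel clears the indexed positions depending on the accumulate
+// flag; value_grad is filled by gathering out_grad at the indexed positions
+// (index_put_grad_kernel) and, when value was broadcast to the index shape, is
+// reduced back to value's shape with SumKernel and ReshapeInferKernel.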
+ +#include "paddle/phi/kernels/index_put_grad_kernel.h" +#include +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/array.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" +#include "paddle/phi/kernels/reshape_kernel.h" +namespace phi { + +template +void range_kernel(int64_t N, T* out) { + for (int64_t idx = 0; idx < N; ++idx) { + out[idx] = idx; + } +} + +template +phi::DenseTensor GetRangeTensor(const Context& dev_ctx, + int64_t N, + phi::DataType dtype) { + phi::DenseTensor res(dtype); + res.Resize(phi::make_ddim({N})); + DenseTensor* p_res = &res; + T* out = dev_ctx.template Alloc(p_res); + range_kernel(N, out); + return res; +} + +template +void set_zero_kernel(const int64_t N, + const int64_t** indices, + phi::Array stride, + phi::Array shape, + T* out) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t idx = 0; idx < N; ++idx) { + int64_t cur_ix = 0; + int64_t offset = 0; + + for (size_t i = 0; i < Rank; ++i) { + cur_ix = (int64_t(*(indices[i] + idx))); + if (cur_ix < 0) { + cur_ix += shape[i]; + } + offset += stride[i] * cur_ix; + } + *(out + idx) = 0; + } +} + +template +void index_put_grad_kernel(const int64_t N, + const T* out_grad, + const int64_t** indices, + phi::Array stride, + phi::Array shape, + T* value_grad) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t idx = 0; idx < N; ++idx) { + int64_t cur_ix = 0; + int64_t offset = 0; + + for (size_t i = 0; i < Rank; ++i) { + cur_ix = (int64_t(*(indices[i] + idx))); + if (cur_ix < 0) { + cur_ix += shape[i]; + } + offset += stride[i] * cur_ix; + } + *(value_grad + idx) = *(out_grad + offset); + } +} + +template +void LaunchIndexPutGradKernel(const Context& dev_ctx, + const std::vector& indices_v, + const DenseTensor& out_grad, + bool accumulate, + DenseTensor* value_grad, + DenseTensor* x_grad) { + if (x_grad) { + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); + if (accumulate) { + T* x_grad_data = x_grad->data(); + + auto x_grad_dims = x_grad->dims(); + const int64_t numel = indices_v[0]->numel(); + auto x_grad_stride = phi::stride(x_grad_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = x_grad_stride[idx]; + shape_a[idx] = x_grad_dims[idx]; + } + + const int64_t* pd_indices[Rank]; + for (size_t i = 0; i < Rank; ++i) { + pd_indices[i] = indices_v[i]->data(); + } + set_zero_kernel( + numel, pd_indices, stride_a, shape_a, x_grad_data); + } + } + if (value_grad) { + if (value_grad->numel() == 1) { + DenseTensor tmp_value_grad(value_grad->dtype()); + tmp_value_grad.Resize(indices_v[0]->dims()); + + T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); + auto out_grad_data = out_grad.data(); + + auto out_grad_dims = out_grad.dims(); + const int64_t numel = indices_v[0]->numel(); + auto out_grad_stride = phi::stride(out_grad_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = out_grad_stride[idx]; + shape_a[idx] = out_grad_dims[idx]; + } + + const int64_t* pd_indices[Rank]; + for (size_t i = 0; i < Rank; ++i) { + pd_indices[i] = indices_v[i]->data(); + } + index_put_grad_kernel(numel, + out_grad_data, + pd_indices, + stride_a, + shape_a, + tmp_value_grad_data); + + std::vector 
v_dims(tmp_value_grad.dims().size()); + std::iota(v_dims.begin(), v_dims.end(), 0); + IntArray v_axis(v_dims); + SumKernel(dev_ctx, + tmp_value_grad, + v_axis, + value_grad->dtype(), + false, + value_grad); + } else if (value_grad->numel() == indices_v[0]->numel()) { + T* value_grad_data = dev_ctx.template Alloc(value_grad); + auto out_grad_data = out_grad.data(); + + auto out_grad_dims = out_grad.dims(); + const int64_t numel = indices_v[0]->numel(); + auto out_grad_stride = phi::stride(out_grad_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = out_grad_stride[idx]; + shape_a[idx] = out_grad_dims[idx]; + } + + const int64_t* pd_indices[Rank]; + for (size_t i = 0; i < Rank; ++i) { + pd_indices[i] = indices_v[i]->data(); + } + index_put_grad_kernel( + numel, out_grad_data, pd_indices, stride_a, shape_a, value_grad_data); + } else { + DenseTensor tmp_value_grad(value_grad->dtype()); + tmp_value_grad.Resize(indices_v[0]->dims()); + + T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); + auto out_grad_data = out_grad.data(); + + auto out_grad_dims = out_grad.dims(); + const int64_t numel = indices_v[0]->numel(); + auto out_grad_stride = phi::stride(out_grad_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = out_grad_stride[idx]; + shape_a[idx] = out_grad_dims[idx]; + } + + const int64_t* pd_indices[Rank]; + for (size_t i = 0; i < Rank; ++i) { + pd_indices[i] = indices_v[i]->data(); + } + index_put_grad_kernel(numel, + out_grad_data, + pd_indices, + stride_a, + shape_a, + tmp_value_grad_data); + + std::vector after_dims = phi::vectorize(tmp_value_grad.dims()); + std::vector before_dims = phi::vectorize(value_grad->dims()); + std::vector compress_dims; + std::vector dims_without_1; + size_t i = after_dims.size(); + size_t j = before_dims.size(); + if (i < j) { + PADDLE_THROW(phi::errors::InvalidArgument( + "shape of value can't not be broadcast to shape of x[indices]")); + } + while ((i--) && (j--)) { + if (after_dims[i] == before_dims[j]) { + dims_without_1.push_back(before_dims[j]); + continue; + } else if (before_dims[j] == 1) { + compress_dims.push_back(i); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "shape of value can't not be broadcast to shape of x[indices]")); + } + } + while (i--) { + compress_dims.push_back(i); + } + + phi::DenseTensor value_grad_dims_without1(value_grad->dtype()); + value_grad_dims_without1.Resize(phi::make_ddim(dims_without_1)); + IntArray v_axis(compress_dims); + SumKernel(dev_ctx, + tmp_value_grad, + v_axis, + value_grad->dtype(), + false, + &value_grad_dims_without1); + phi::ReshapeInferKernel( + dev_ctx, + value_grad_dims_without1, + phi::IntArray(phi::vectorize(value_grad->dims())), + value_grad); + } + } +} + +template +void IndexPutGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices_v, + const DenseTensor& value, + const DenseTensor& out_grad, + bool accumulate, + DenseTensor* x_grad, + DenseTensor* value_grad) { + PADDLE_ENFORCE_EQ( + x.dtype(), + value.dtype(), + phi::errors::InvalidArgument( + "The data type of tensor in indices must be same to the data type " + "of tensor x.")); + std::vector tmp_args; + std::vector int_indices_v = + DealWithBoolIndices(dev_ctx, indices_v, &tmp_args); + const size_t total_dims = x.dims().size(); + auto bd_dim = BroadCastTensorsDims(int_indices_v); + + std::vector res_dim_v(phi::vectorize(bd_dim)); + std::vector 
res_indices_v(x.dims().size(), nullptr); + std::vector tmp_res_indices_v; + + if (int_indices_v.size() < total_dims) { + std::vector tmp_x_dims = phi::vectorize(x.dims()); + int len_bd_dim = bd_dim.size(); + res_dim_v.insert( + res_dim_v.end(), tmp_x_dims.begin() + len_bd_dim, tmp_x_dims.end()); + + std::vector reshaped_indices_v; + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (int_indices_v[i]->dtype() == phi::DataType::INT32) { + reshaped_indices_v.emplace_back(phi::Cast( + dev_ctx, *int_indices_v[i], phi::DataType::INT64)); + } else { + reshaped_indices_v.emplace_back(*int_indices_v[i]); + } + } + for (size_t i = int_indices_v.size(); i < total_dims; ++i) { + reshaped_indices_v.emplace_back(GetRangeTensor( + dev_ctx, res_dim_v[i], phi::DataType::INT64)); + } + phi::DDim res_dim = phi::make_ddim(res_dim_v); + + for (size_t i = 0; i < reshaped_indices_v.size(); ++i) { + tmp_res_indices_v.emplace_back( + GetReshapeAndExpandTensor( + dev_ctx, + reshaped_indices_v[i], + res_dim, + ((i < int_indices_v.size()) ? 0 : i))); + } + for (size_t i = 0; i < res_indices_v.size(); ++i) { + res_indices_v[i] = &tmp_res_indices_v[i]; + } + + } else { + std::vector int_indices_v_tmp; + + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (int_indices_v[i]->dtype() == phi::DataType::INT32) { + int_indices_v_tmp.emplace_back(phi::Cast( + dev_ctx, *int_indices_v[i], phi::DataType::INT64)); + } else { + int_indices_v_tmp.emplace_back(*int_indices_v[i]); + } + } + + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (bd_dim != int_indices_v[i]->dims()) { + tmp_res_indices_v.emplace_back( + DenseTensor(phi::DataType::INT64).Resize(bd_dim)); + ExpandKernel( + dev_ctx, + int_indices_v_tmp[i], + IntArray(phi::vectorize(bd_dim)), + &tmp_res_indices_v[i]); + } else { + tmp_res_indices_v.emplace_back(int_indices_v_tmp[i]); + } + } + + for (size_t i = 0; i < res_indices_v.size(); ++i) { + res_indices_v[i] = &tmp_res_indices_v[i]; + } + } + + switch (total_dims) { + case 1: + LaunchIndexPutGradKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 2: + LaunchIndexPutGradKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 3: + LaunchIndexPutGradKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 4: + LaunchIndexPutGradKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 5: + LaunchIndexPutGradKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 6: + LaunchIndexPutGradKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "dims of input tensor should be less than 7, But received" + "%d", + x.dims().size())); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(index_put_grad, + CPU, + ALL_LAYOUT, + phi::IndexPutGradKernel, + float, + double, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/cpu/index_put_kernel.cc b/paddle/phi/kernels/cpu/index_put_kernel.cc new file mode 100644 index 00000000000000..363f874277d653 --- /dev/null +++ b/paddle/phi/kernels/cpu/index_put_kernel.cc @@ -0,0 +1,274 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/index_put_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/array.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/expand_kernel.h" + +namespace phi { + +template +void range_kernel(int64_t N, T* out) { + for (int64_t idx = 0; idx < N; ++idx) { + out[idx] = idx; + } +} + +template +phi::DenseTensor GetRangeTensor(const Context& dev_ctx, + int64_t N, + phi::DataType dtype) { + phi::DenseTensor res(dtype); + res.Resize(phi::make_ddim({N})); + DenseTensor* p_res = &res; + T* out = dev_ctx.template Alloc(p_res); + range_kernel(N, out); + return res; +} + +template +void index_put_kernel(const int64_t N, + const T* x, + const T* vals, + const int64_t** indices, + phi::Array stride, + phi::Array shape, + int64_t isSingleValTensor, + bool accumulate, + T* out) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t idx = 0; idx < N; ++idx) { + int64_t cur_ix = 0; + int64_t offset = 0; + + for (size_t i = 0; i < Rank; ++i) { + cur_ix = (int64_t(*(indices[i] + idx))); + if (cur_ix < 0) { + cur_ix += shape[i]; + } + offset += stride[i] * cur_ix; + } + + if (accumulate) { + *(out + offset) += *(vals + (idx & isSingleValTensor)); + } else { + *(out + offset) = *(vals + (idx & isSingleValTensor)); + } + } +} + +template +void LaunchIndexPutKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices_v, + const DenseTensor& value, + bool accumulate, + DenseTensor* out) { + auto* x_data = x.data(); + auto* val_data = value.data(); + bool isInitialized = out->initialized(); + T* out_data = dev_ctx.template Alloc(out); + + if (!isInitialized) { + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); + } + + auto x_dims = x.dims(); + const int64_t numel = indices_v[0]->numel(); + auto x_stride = phi::stride(x_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = x_stride[idx]; + shape_a[idx] = x_dims[idx]; + } + + int64_t isSingleValTensor = (value.numel() == 1) ? 
0 : INT64_MAX; + + const int64_t* pd_indices[Rank]; + for (size_t i = 0; i < Rank; ++i) { + pd_indices[i] = indices_v[i]->data(); + } + + index_put_kernel(numel, + x_data, + val_data, + pd_indices, + stride_a, + shape_a, + isSingleValTensor, + accumulate, + out_data); +} + +template +void IndexPutKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices_v, + const DenseTensor& value, + bool accumulate, + DenseTensor* out) { + PADDLE_ENFORCE_EQ( + x.dtype(), + value.dtype(), + phi::errors::InvalidArgument( + "The data type of tensor in indices must be same to the data type " + "of tensor x.")); + std::vector tmp_args; + std::vector int_indices_v = + DealWithBoolIndices(dev_ctx, indices_v, &tmp_args); + const size_t total_dims = x.dims().size(); + auto bd_dim = BroadCastTensorsDims(int_indices_v); + + std::vector res_dim_v(phi::vectorize(bd_dim)); + std::vector res_indices_v(x.dims().size(), nullptr); + std::vector tmp_res_indices_v; + std::vector tmp_value_v; + const DenseTensor* ptr_value = nullptr; + + if (int_indices_v.size() < total_dims) { + std::vector tmp_x_dims = phi::vectorize(x.dims()); + int len_bd_dim = bd_dim.size(); + res_dim_v.insert( + res_dim_v.end(), tmp_x_dims.begin() + len_bd_dim, tmp_x_dims.end()); + + std::vector reshaped_indices_v; + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (int_indices_v[i]->dtype() == phi::DataType::INT32) { + reshaped_indices_v.emplace_back(phi::Cast( + dev_ctx, *int_indices_v[i], phi::DataType::INT64)); + } else { + reshaped_indices_v.emplace_back(*int_indices_v[i]); + } + } + for (size_t i = int_indices_v.size(); i < total_dims; ++i) { + reshaped_indices_v.emplace_back(GetRangeTensor( + dev_ctx, res_dim_v[i], phi::DataType::INT64)); + } + phi::DDim res_dim = phi::make_ddim(res_dim_v); + + for (size_t i = 0; i < reshaped_indices_v.size(); ++i) { + tmp_res_indices_v.emplace_back( + GetReshapeAndExpandTensor( + dev_ctx, + reshaped_indices_v[i], + res_dim, + ((i < int_indices_v.size()) ? 
0 : i))); + } + for (size_t i = 0; i < res_indices_v.size(); ++i) { + res_indices_v[i] = &tmp_res_indices_v[i]; + } + // value至少需要满足与已有的indices为可broadcast_to关系 + if (value.numel() != 1) { + tmp_value_v.emplace_back(DenseTensor(value.dtype()).Resize(res_dim)); + ExpandKernel(dev_ctx, + value, + IntArray(phi::vectorize(res_dim)), + &tmp_value_v[0]); + ptr_value = &tmp_value_v[0]; + } else { + ptr_value = &value; + } + } else { + std::vector int_indices_v_tmp; + + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (int_indices_v[i]->dtype() == phi::DataType::INT32) { + int_indices_v_tmp.emplace_back(phi::Cast( + dev_ctx, *int_indices_v[i], phi::DataType::INT64)); + } else { + int_indices_v_tmp.emplace_back(*int_indices_v[i]); + } + } + + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (bd_dim != int_indices_v[i]->dims()) { + tmp_res_indices_v.emplace_back( + DenseTensor(phi::DataType::INT64).Resize(bd_dim)); + ExpandKernel( + dev_ctx, + int_indices_v_tmp[i], + IntArray(phi::vectorize(bd_dim)), + &tmp_res_indices_v[i]); + } else { + tmp_res_indices_v.emplace_back(int_indices_v_tmp[i]); + } + } + + for (size_t i = 0; i < res_indices_v.size(); ++i) { + res_indices_v[i] = &tmp_res_indices_v[i]; + } + + if (value.numel() != 1) { + tmp_value_v.emplace_back(DenseTensor(value.dtype()).Resize(bd_dim)); + ExpandKernel(dev_ctx, + value, + IntArray(phi::vectorize(bd_dim)), + &tmp_value_v[0]); + ptr_value = &tmp_value_v[0]; + } else { + ptr_value = &value; + } + } + + switch (total_dims) { + case 1: + LaunchIndexPutKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 2: + LaunchIndexPutKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 3: + LaunchIndexPutKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 4: + LaunchIndexPutKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 5: + LaunchIndexPutKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 6: + LaunchIndexPutKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "dims of input tensor should be less than 7, But received" + "%d", + x.dims().size())); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(index_put, + CPU, + ALL_LAYOUT, + phi::IndexPutKernel, + float, + double, + int, + int64_t, + bool) {} diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu new file mode 100644 index 00000000000000..608ea982007d8c --- /dev/null +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -0,0 +1,414 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
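+
+// CUDA counterpart of cpu/index_put_grad_kernel.cc: the same copy / zero /
+// gather / reduce steps are performed by set_zero_cuda_kernel and
+// index_put_grad_cuda_kernel, with the per-dimension index data pointers
+// staged on the device through GetDevicePointerArray before each launch.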
+ +#include "paddle/phi/kernels/index_put_grad_kernel.h" +#include +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/array.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/reduce_sum_kernel.h" +#include "paddle/phi/kernels/reshape_kernel.h" + +namespace phi { + +template +__global__ void range_cuda_kernel(int64_t N, T* out) { + int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; + + if (idx >= N) { + return; + } + out[idx] = idx; +} + +template +phi::DenseTensor GetRangeCudaTensor(const Context& dev_ctx, + int64_t N, + phi::DataType dtype) { + phi::DenseTensor res(dtype); + res.Resize(phi::make_ddim({N})); + DenseTensor* p_res = &res; + T* out = dev_ctx.template Alloc(p_res); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, N); + range_cuda_kernel + <<>>( + N, out); + return res; +} + +template +__global__ void set_zero_cuda_kernel(const int64_t N, + int64_t** indices, + phi::Array stride, + phi::Array shape, + T* out) { + int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; + int64_t cur_ix = 0; + + if (idx >= N) { + return; + } + int64_t offset = 0; + for (int i = 0; i < Rank; ++i) { + cur_ix = (int64_t(*(indices[i] + idx))); + if (cur_ix < 0) { + cur_ix += shape[i]; + } + offset += stride[i] * cur_ix; + } + + *(out + offset) = 0; +} + +template +__global__ void index_put_grad_cuda_kernel(const int64_t N, + const T* out_grad, + int64_t** indices, + phi::Array stride, + phi::Array shape, + T* value_grad) { + int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; + int64_t cur_ix = 0; + + if (idx >= N) { + return; + } + int64_t offset = 0; + for (int i = 0; i < Rank; ++i) { + cur_ix = (int64_t(*(indices[i] + idx))); + if (cur_ix < 0) { + cur_ix += shape[i]; + } + offset += stride[i] * cur_ix; + } + + *(value_grad + idx) = *(out_grad + offset); +} + +template +void LaunchIndexPutGradCudaKernel( + const Context& dev_ctx, + const std::vector& indices_v, + const DenseTensor& out_grad, + bool accumulate, + DenseTensor* value_grad, + DenseTensor* x_grad) { + if (x_grad) { + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); + if (accumulate) { + T* x_grad_data = x_grad->data(); + + auto x_grad_dims = x_grad->dims(); + const int64_t numel = indices_v[0]->numel(); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); + auto x_grad_stride = phi::stride(x_grad_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = x_grad_stride[idx]; + shape_a[idx] = x_grad_dims[idx]; + } + + auto pd_indices = + GetDevicePointerArray(dev_ctx, indices_v); + set_zero_cuda_kernel<<>>( + numel, pd_indices, stride_a, shape_a, x_grad_data); + } + } + + if (value_grad) { + if (value_grad->numel() == 1) { + DenseTensor tmp_value_grad(value_grad->dtype()); + tmp_value_grad.Resize(indices_v[0]->dims()); + + T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); + auto out_grad_data = out_grad.data(); + + auto out_grad_dims = out_grad.dims(); + const int64_t numel = indices_v[0]->numel(); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); + auto out_grad_stride = phi::stride(out_grad_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = out_grad_stride[idx]; + shape_a[idx] = out_grad_dims[idx]; + } + + auto 
pd_indices = + GetDevicePointerArray(dev_ctx, indices_v); + index_put_grad_cuda_kernel + <<>>(numel, + out_grad_data, + pd_indices, + stride_a, + shape_a, + tmp_value_grad_data); + + std::vector v_dims(tmp_value_grad.dims().size()); + std::iota(v_dims.begin(), v_dims.end(), 0); + IntArray v_axis(v_dims); + SumKernel(dev_ctx, + tmp_value_grad, + v_axis, + value_grad->dtype(), + false, + value_grad); + } else if (value_grad->numel() == indices_v[0]->numel()) { + T* value_grad_data = dev_ctx.template Alloc(value_grad); + auto out_grad_data = out_grad.data(); + + auto out_grad_dims = out_grad.dims(); + const int64_t numel = indices_v[0]->numel(); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); + auto out_grad_stride = phi::stride(out_grad_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = out_grad_stride[idx]; + shape_a[idx] = out_grad_dims[idx]; + } + + auto pd_indices = + GetDevicePointerArray(dev_ctx, indices_v); + index_put_grad_cuda_kernel<<>>( + numel, out_grad_data, pd_indices, stride_a, shape_a, value_grad_data); + } else { + DenseTensor tmp_value_grad(value_grad->dtype()); + tmp_value_grad.Resize(indices_v[0]->dims()); + + T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); + auto out_grad_data = out_grad.data(); + + auto out_grad_dims = out_grad.dims(); + const int64_t numel = indices_v[0]->numel(); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); + auto out_grad_stride = phi::stride(out_grad_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = out_grad_stride[idx]; + shape_a[idx] = out_grad_dims[idx]; + } + + auto pd_indices = + GetDevicePointerArray(dev_ctx, indices_v); + index_put_grad_cuda_kernel + <<>>(numel, + out_grad_data, + pd_indices, + stride_a, + shape_a, + tmp_value_grad_data); + + std::vector after_dims = phi::vectorize(tmp_value_grad.dims()); + std::vector before_dims = phi::vectorize(value_grad->dims()); + std::vector compress_dims; + std::vector dims_without_1; + size_t i = after_dims.size(); + size_t j = before_dims.size(); + if (i < j) { + PADDLE_THROW(phi::errors::InvalidArgument( + "shape of value can't not be broadcast to shape of x[indices]")); + } + while ((i--) && (j--)) { + if (after_dims[i] == before_dims[j]) { + dims_without_1.push_back(before_dims[j]); + continue; + } else if (before_dims[j] == 1) { + compress_dims.push_back(i); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "shape of value can't not be broadcast to shape of x[indices]")); + } + } + while (i--) { + compress_dims.push_back(i); + } + + phi::DenseTensor value_grad_dims_without1(value_grad->dtype()); + value_grad_dims_without1.Resize(phi::make_ddim(dims_without_1)); + IntArray v_axis(compress_dims); + SumKernel(dev_ctx, + tmp_value_grad, + v_axis, + value_grad->dtype(), + false, + &value_grad_dims_without1); + phi::ReshapeInferKernel( + dev_ctx, + value_grad_dims_without1, + phi::IntArray(phi::vectorize(value_grad->dims())), + value_grad); + } + } +} + +template +void IndexPutGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices_v, + const DenseTensor& value, + const DenseTensor& out_grad, + bool accumulate, + DenseTensor* x_grad, + DenseTensor* value_grad) { + PADDLE_ENFORCE_EQ( + x.dtype(), + value.dtype(), + phi::errors::InvalidArgument( + "The data type of tensor in indices must be same to the data type " + "of tensor x.")); + std::vector tmp_args; + 
std::vector int_indices_v = + DealWithBoolIndices(dev_ctx, indices_v, &tmp_args); + const size_t total_dims = x.dims().size(); + auto bd_dim = BroadCastTensorsDims(int_indices_v); + + std::vector res_dim_v(phi::vectorize(bd_dim)); + std::vector res_indices_v(x.dims().size(), nullptr); + std::vector tmp_res_indices_v; + std::vector tmp_value_v; + const DenseTensor* ptr_value = nullptr; + + if (int_indices_v.size() < total_dims) { + std::vector tmp_x_dims = phi::vectorize(x.dims()); + int len_bd_dim = bd_dim.size(); + res_dim_v.insert( + res_dim_v.end(), tmp_x_dims.begin() + len_bd_dim, tmp_x_dims.end()); + + std::vector reshaped_indices_v; + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (int_indices_v[i]->dtype() == phi::DataType::INT32) { + reshaped_indices_v.emplace_back(phi::Cast( + dev_ctx, *int_indices_v[i], phi::DataType::INT64)); + } else { + reshaped_indices_v.emplace_back(*int_indices_v[i]); + } + } + for (size_t i = int_indices_v.size(); i < total_dims; ++i) { + reshaped_indices_v.emplace_back(GetRangeCudaTensor( + dev_ctx, res_dim_v[i], phi::DataType::INT64)); + } + phi::DDim res_dim = phi::make_ddim(res_dim_v); + + for (size_t i = 0; i < reshaped_indices_v.size(); ++i) { + tmp_res_indices_v.emplace_back( + GetReshapeAndExpandTensor( + dev_ctx, + reshaped_indices_v[i], + res_dim, + ((i < int_indices_v.size()) ? 0 : i))); + } + for (size_t i = 0; i < res_indices_v.size(); ++i) { + res_indices_v[i] = &tmp_res_indices_v[i]; + } + } else { + std::vector int_indices_v_tmp; + + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (int_indices_v[i]->dtype() == phi::DataType::INT32) { + int_indices_v_tmp.emplace_back(phi::Cast( + dev_ctx, *int_indices_v[i], phi::DataType::INT64)); + } else { + int_indices_v_tmp.emplace_back(*int_indices_v[i]); + } + } + + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (bd_dim != int_indices_v[i]->dims()) { + tmp_res_indices_v.emplace_back( + DenseTensor(phi::DataType::INT64).Resize(bd_dim)); + ExpandKernel( + dev_ctx, + int_indices_v_tmp[i], + IntArray(phi::vectorize(bd_dim)), + &tmp_res_indices_v[i]); + } else { + tmp_res_indices_v.emplace_back(int_indices_v_tmp[i]); + } + } + + for (size_t i = 0; i < res_indices_v.size(); ++i) { + res_indices_v[i] = &tmp_res_indices_v[i]; + } + } + + switch (total_dims) { + case 1: + LaunchIndexPutGradCudaKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 2: + LaunchIndexPutGradCudaKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 3: + LaunchIndexPutGradCudaKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 4: + LaunchIndexPutGradCudaKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 5: + LaunchIndexPutGradCudaKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + case 6: + LaunchIndexPutGradCudaKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "dims of input tensor should be less than 7, But received" + "%d", + x.dims().size())); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(index_put_grad, + GPU, + ALL_LAYOUT, + phi::IndexPutGradKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu new file mode 100644 index 00000000000000..b81e49813ff85d --- /dev/null +++ 
b/paddle/phi/kernels/gpu/index_put_kernel.cu @@ -0,0 +1,284 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/index_put_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/array.h" +#include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/nonzero_kernel.h" +#include "paddle/phi/kernels/split_kernel.h" + +namespace phi { +template +__global__ void range_cuda_kernel(int64_t N, T* out) { + int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; + + if (idx >= N) { + return; + } + out[idx] = idx; +} + +template +phi::DenseTensor GetRangeCudaTensor(const Context& dev_ctx, + int64_t N, + phi::DataType dtype) { + phi::DenseTensor res(dtype); + res.Resize(phi::make_ddim({N})); + DenseTensor* p_res = &res; + T* out = dev_ctx.template Alloc(p_res); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, N); + range_cuda_kernel + <<>>( + N, out); + return res; +} + +template +__global__ void index_put_cuda_kernel(const int64_t N, + const T* x, + const T* vals, + int64_t** indices, + phi::Array stride, + phi::Array shape, + int64_t isSingleValTensor, + bool accumulate, + T* out) { + int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; + int64_t cur_ix = 0; + + if (idx >= N) { + return; + } + int64_t offset = 0; + for (int i = 0; i < Rank; ++i) { + cur_ix = (int64_t(*(indices[i] + idx))); + if (cur_ix < 0) { + cur_ix += shape[i]; + } + offset += stride[i] * cur_ix; + } + // 能不能加到模板里面去 + if (accumulate) { + *(out + offset) += *(vals + (idx & isSingleValTensor)); + } else { + *(out + offset) = *(vals + (idx & isSingleValTensor)); + } +} + +template +void LaunchIndexPutCudaKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices_v, + const DenseTensor& value, + bool accumulate, + DenseTensor* out) { + auto* x_data = x.data(); + auto* val_data = value.data(); + bool isInitialized = out->initialized(); + T* out_data = dev_ctx.template Alloc(out); + + if (!isInitialized) { + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); + } + + auto x_dims = x.dims(); + const int64_t numel = indices_v[0]->numel(); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); + auto x_stride = phi::stride(x_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = x_stride[idx]; + shape_a[idx] = x_dims[idx]; + } + + int64_t isSingleValTensor = (value.numel() == 1) ? 
0 : INT64_MAX; + + auto pd_indices = GetDevicePointerArray(dev_ctx, indices_v); + index_put_cuda_kernel + <<>>( + numel, + x_data, + val_data, + pd_indices, + stride_a, + shape_a, + isSingleValTensor, + accumulate, + out_data); +} + +template +void IndexPutKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices_v, + const DenseTensor& value, + bool accumulate, + DenseTensor* out) { + PADDLE_ENFORCE_EQ( + x.dtype(), + value.dtype(), + phi::errors::InvalidArgument( + "The data type of tensor in indices must be same to the data type " + "of tensor x.")); + std::vector tmp_args; + std::vector int_indices_v = + DealWithBoolIndices(dev_ctx, indices_v, &tmp_args); + std::cout << "line 143" << std::endl; + const size_t total_dims = x.dims().size(); + auto bd_dim = BroadCastTensorsDims(int_indices_v); + + std::cout << "line 147" << std::endl; + std::vector res_dim_v(phi::vectorize(bd_dim)); + std::vector res_indices_v(x.dims().size(), nullptr); + std::vector tmp_res_indices_v; + std::vector tmp_value_v; + const DenseTensor* ptr_value = nullptr; + + if (int_indices_v.size() < total_dims) { + std::vector tmp_x_dims = phi::vectorize(x.dims()); + int len_bd_dim = bd_dim.size(); + res_dim_v.insert( + res_dim_v.end(), tmp_x_dims.begin() + len_bd_dim, tmp_x_dims.end()); + + std::vector reshaped_indices_v; + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (int_indices_v[i]->dtype() == phi::DataType::INT32) { + reshaped_indices_v.emplace_back(phi::Cast( + dev_ctx, *int_indices_v[i], phi::DataType::INT64)); + } else { + reshaped_indices_v.emplace_back(*int_indices_v[i]); + } + } + for (size_t i = int_indices_v.size(); i < total_dims; ++i) { + reshaped_indices_v.emplace_back(GetRangeCudaTensor( + dev_ctx, res_dim_v[i], phi::DataType::INT64)); + } + phi::DDim res_dim = phi::make_ddim(res_dim_v); + + for (size_t i = 0; i < reshaped_indices_v.size(); ++i) { + tmp_res_indices_v.emplace_back( + GetReshapeAndExpandTensor( + dev_ctx, + reshaped_indices_v[i], + res_dim, + ((i < int_indices_v.size()) ? 
0 : i))); + } + for (size_t i = 0; i < res_indices_v.size(); ++i) { + res_indices_v[i] = &tmp_res_indices_v[i]; + } + + if (value.numel() != 1) { + tmp_value_v.emplace_back(DenseTensor(value.dtype()).Resize(res_dim)); + ExpandKernel(dev_ctx, + value, + IntArray(phi::vectorize(res_dim)), + &tmp_value_v[0]); + ptr_value = &tmp_value_v[0]; + } else { + ptr_value = &value; + } + } else { + std::vector int_indices_v_tmp; + + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (int_indices_v[i]->dtype() == phi::DataType::INT32) { + int_indices_v_tmp.emplace_back(phi::Cast( + dev_ctx, *int_indices_v[i], phi::DataType::INT64)); + } else { + int_indices_v_tmp.emplace_back(*int_indices_v[i]); + } + } + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (bd_dim != int_indices_v[i]->dims()) { + tmp_res_indices_v.emplace_back( + DenseTensor(phi::DataType::INT64).Resize(bd_dim)); + ExpandKernel( + dev_ctx, + int_indices_v_tmp[i], + IntArray(phi::vectorize(bd_dim)), + &tmp_res_indices_v[i]); + } else { + tmp_res_indices_v.emplace_back(int_indices_v_tmp[i]); + } + } + + for (size_t i = 0; i < res_indices_v.size(); ++i) { + res_indices_v[i] = &tmp_res_indices_v[i]; + } + + if (value.numel() != 1) { + tmp_value_v.emplace_back(DenseTensor(value.dtype()).Resize(bd_dim)); + ExpandKernel(dev_ctx, + value, + IntArray(phi::vectorize(bd_dim)), + &tmp_value_v[0]); + ptr_value = &tmp_value_v[0]; + } else { + ptr_value = &value; + } + } + std::cout << "line 249" << std::endl; + + switch (total_dims) { + case 1: + LaunchIndexPutCudaKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 2: + LaunchIndexPutCudaKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 3: + LaunchIndexPutCudaKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 4: + LaunchIndexPutCudaKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 5: + LaunchIndexPutCudaKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + case 6: + LaunchIndexPutCudaKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "dims of input tensor should be less than 7, But received" + "%d", + x.dims().size())); + } + std::cout << "line 276" << std::endl; +} +} // namespace phi + +PD_REGISTER_KERNEL(index_put, + GPU, + ALL_LAYOUT, + phi::IndexPutKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/index_put_grad_kernel.h b/paddle/phi/kernels/index_put_grad_kernel.h new file mode 100644 index 00000000000000..a1879a929435a4 --- /dev/null +++ b/paddle/phi/kernels/index_put_grad_kernel.h @@ -0,0 +1,185 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
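+
+// Helpers shared with the forward kernels: DealWithBoolIndices expands a
+// single bool mask into per-dimension int64 index tensors via NonZeroKernel +
+// SplitWithNumKernel, BroadCastTensorsDims computes the common broadcast shape
+// of the index tensors, GetReshapeAndExpandTensor reshapes and expands an
+// index tensor to that shape, and GetDevicePointerArray copies the host-side
+// array of index data pointers onto the device for the GPU kernels.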
+ +#pragma once + +#include +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/utils/array.h" +#include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/nonzero_kernel.h" +#include "paddle/phi/kernels/reshape_kernel.h" +#include "paddle/phi/kernels/split_kernel.h" + +namespace phi { + +template +static phi::DenseTensor GetReshapeAndExpandTensor( + const Context& dev_ctx, + const phi::DenseTensor& tensor, + const phi::DDim& res_dim, + int index) { + std::vector before_dims = phi::vectorize(tensor.dims()); + std::vector mid_dims(res_dim.size(), 1); + + for (size_t i = 0; i < before_dims.size(); ++i) { + mid_dims[i + index] = before_dims[i]; + } + phi::DenseTensor mid_tensor(tensor.dtype()); + mid_tensor.Resize(phi::make_ddim(mid_dims)); + ReshapeInferKernel(dev_ctx, tensor, IntArray(mid_dims), &mid_tensor); + + phi::DenseTensor res_tensor(tensor.dtype()); + res_tensor.Resize(res_dim); + ExpandKernel( + dev_ctx, mid_tensor, IntArray(phi::vectorize(res_dim)), &res_tensor); + return res_tensor; +} + +template +static std::vector DealWithBoolIndices( + const Context& dev_ctx, + const std::vector& indices_v, + std::vector* tmp_indices_v) { + std::vector res(indices_v.begin(), indices_v.end()); + bool contains_bool_tensor = false; + for (size_t i = 0; i < indices_v.size(); ++i) { + if (indices_v[i]->dtype() == phi::DataType::BOOL) { + contains_bool_tensor = true; + } else if ((indices_v[i]->dtype() == phi::DataType::INT64) || + (indices_v[i]->dtype() == phi::DataType::INT32)) { + if (contains_bool_tensor) { + PADDLE_THROW(phi::errors::InvalidArgument( + "indices contains bool tensor and int32/int64 tensor at the same " + "time")); + } + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "data type of tensor in indices must be int32, int64 or bool")); + } + } + + if (contains_bool_tensor) { + if (indices_v.size() != 1) { + PADDLE_THROW(phi::errors::InvalidArgument( + "the size of indices must be 1 when it containts bool tensor")); + } + int rank = indices_v[0]->dims().size(); + PADDLE_ENFORCE_GE( + rank, + 1UL, + phi::errors::InvalidArgument("the only bool tensor in indices should " + "have number of dimension at least 1")); + phi::DenseTensor nonzero_indices(phi::DataType::INT64); + nonzero_indices.Resize(phi::make_ddim({-1, rank})); + NonZeroKernel(dev_ctx, *indices_v[0], &nonzero_indices); + + std::vector integer_indices(rank, nullptr); + for (int i = 0; i < rank; ++i) { + // tmp_indices_v.emplace_back(DenseTensor(phi::DataType::INT64).Resize(phi::make_ddim({nonzero_indices.dims()[0],1}))); + // 理论上这里应该要加个1的 + tmp_indices_v->emplace_back( + DenseTensor(phi::DataType::INT64) + .Resize(phi::make_ddim({nonzero_indices.dims()[0]}))); + } + for (int i = 0; i < rank; ++i) { + integer_indices[i] = &((*tmp_indices_v)[i]); + } + SplitWithNumKernel( + dev_ctx, nonzero_indices, rank, 1, integer_indices); + + std::vector res_tmp(integer_indices.size(), + nullptr); + for (int i = 0; i < rank; ++i) { + res_tmp[i] = &((*tmp_indices_v)[i]); + } + res.swap(res_tmp); + } + return res; +} + +static phi::DDim BroadCastTensorsDims( + const std::vector& tensors) { + int target_rank = 0; + for (const auto& tensor : tensors) { + target_rank = std::max(target_rank, tensor->dims().size()); + } + + PADDLE_ENFORCE_GT(target_rank, + 0, + errors::InvalidArgument("BroadCastTensorsDims requires at " + "least one input tensor 
to have " + "rank greater than zero")); + + std::vector target_dims(target_rank, 0); + for (int index = 0; index < target_rank; index++) { + int target_dim_size = 1; + for (const auto& tensor : tensors) { + auto input_ddim = tensor->dims(); + int axis = static_cast(input_ddim.size()) - index - 1; + int dim_size = 1; + if (axis >= 0) { + dim_size = input_ddim[axis]; + } + + if (target_dim_size != 1 && dim_size != 1 && + target_dim_size != dim_size) { + PADDLE_THROW(errors::InvalidArgument( + "BroadCastTensorsDims inputs does not satisfy bcast semantics, " + "please check axis = %d in reverse order", + index)); + } + + target_dim_size = dim_size == 1 ? target_dim_size : dim_size; + } + target_dims[target_rank - index - 1] = target_dim_size; + } + return phi::make_ddim(target_dims); +} + +template +T** GetDevicePointerArray(const Context& ctx, + const std::vector& indices_v) { + std::vector h_indices_v(indices_v.size()); + for (int i = 0; i < indices_v.size(); ++i) { + h_indices_v[i] = indices_v[i]->data(); + } + auto d_indices_data = paddle::memory::Alloc( + ctx.GetPlace(), + h_indices_v.size() * sizeof(T*), + phi::Stream(reinterpret_cast(ctx.stream()))); + paddle::memory::Copy(ctx.GetPlace(), + d_indices_data->ptr(), + phi::CPUPlace(), + reinterpret_cast(h_indices_v.data()), + h_indices_v.size() * sizeof(T*), + ctx.stream()); + return reinterpret_cast(d_indices_data->ptr()); +} + +template +void IndexPutGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices_v, + const DenseTensor& value, + const DenseTensor& out_grad, + bool accumulate, + DenseTensor* x_grad, + DenseTensor* value_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/index_put_kernel.h b/paddle/phi/kernels/index_put_kernel.h new file mode 100644 index 00000000000000..ef8d2bd987e75c --- /dev/null +++ b/paddle/phi/kernels/index_put_kernel.h @@ -0,0 +1,182 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
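+
+// This header duplicates the index helpers from index_put_grad_kernel.h
+// (GetReshapeAndExpandTensor, DealWithBoolIndices, BroadCastTensorsDims,
+// GetDevicePointerArray) for the forward kernels and declares IndexPutKernel.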
+ +#pragma once + +#include +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/utils/array.h" +#include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/nonzero_kernel.h" +#include "paddle/phi/kernels/reshape_kernel.h" +#include "paddle/phi/kernels/split_kernel.h" + +namespace phi { + +template +static phi::DenseTensor GetReshapeAndExpandTensor( + const Context& dev_ctx, + const phi::DenseTensor& tensor, + const phi::DDim& res_dim, + int index) { + std::vector before_dims = phi::vectorize(tensor.dims()); + std::vector mid_dims(res_dim.size(), 1); + + for (size_t i = 0; i < before_dims.size(); ++i) { + mid_dims[i + index] = before_dims[i]; + } + phi::DenseTensor mid_tensor(tensor.dtype()); + mid_tensor.Resize(phi::make_ddim(mid_dims)); + ReshapeInferKernel(dev_ctx, tensor, IntArray(mid_dims), &mid_tensor); + + phi::DenseTensor res_tensor(tensor.dtype()); + res_tensor.Resize(res_dim); + ExpandKernel( + dev_ctx, mid_tensor, IntArray(phi::vectorize(res_dim)), &res_tensor); + return res_tensor; +} + +template +static std::vector DealWithBoolIndices( + const Context& dev_ctx, + const std::vector& indices_v, + std::vector* tmp_indices_v) { + std::vector res(indices_v.begin(), indices_v.end()); + bool contains_bool_tensor = false; + for (size_t i = 0; i < indices_v.size(); ++i) { + if (indices_v[i]->dtype() == phi::DataType::BOOL) { + contains_bool_tensor = true; + } else if ((indices_v[i]->dtype() == phi::DataType::INT64) || + (indices_v[i]->dtype() == phi::DataType::INT32)) { + if (contains_bool_tensor) { + PADDLE_THROW(phi::errors::InvalidArgument( + "indices contains bool tensor and int32/int64 tensor at the same " + "time")); + } + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "data type of tensor in indices must be int32, int64 or bool")); + } + } + + if (contains_bool_tensor) { + if (indices_v.size() != 1) { + PADDLE_THROW(phi::errors::InvalidArgument( + "the size of indices must be 1 when it containts bool tensor")); + } + int rank = indices_v[0]->dims().size(); + PADDLE_ENFORCE_GE( + rank, + 1UL, + phi::errors::InvalidArgument("the only bool tensor in indices should " + "have number of dimension at least 1")); + phi::DenseTensor nonzero_indices(phi::DataType::INT64); + nonzero_indices.Resize(phi::make_ddim({-1, rank})); + NonZeroKernel(dev_ctx, *indices_v[0], &nonzero_indices); + + std::vector integer_indices(rank, nullptr); + for (int i = 0; i < rank; ++i) { + // tmp_indices_v.emplace_back(DenseTensor(phi::DataType::INT64).Resize(phi::make_ddim({nonzero_indices.dims()[0],1}))); + tmp_indices_v->emplace_back( + DenseTensor(phi::DataType::INT64) + .Resize(phi::make_ddim({nonzero_indices.dims()[0]}))); + } + for (int i = 0; i < rank; ++i) { + integer_indices[i] = &((*tmp_indices_v)[i]); + } + SplitWithNumKernel( + dev_ctx, nonzero_indices, rank, 1, integer_indices); + + std::vector res_tmp(integer_indices.size(), + nullptr); + for (int i = 0; i < rank; ++i) { + res_tmp[i] = &((*tmp_indices_v)[i]); + } + res.swap(res_tmp); + } + return res; +} + +static phi::DDim BroadCastTensorsDims( + const std::vector& tensors) { + int target_rank = 0; + for (const auto& tensor : tensors) { + target_rank = std::max(target_rank, tensor->dims().size()); + } + + PADDLE_ENFORCE_GT(target_rank, + 0, + errors::InvalidArgument("BroadCastTensorsDims requires at " + "least one input tensor to have " + "rank 
greater than zero")); + + std::vector target_dims(target_rank, 0); + for (int index = 0; index < target_rank; index++) { + int target_dim_size = 1; + for (const auto& tensor : tensors) { + auto input_ddim = tensor->dims(); + int axis = static_cast(input_ddim.size()) - index - 1; + int dim_size = 1; + if (axis >= 0) { + dim_size = input_ddim[axis]; + } + + if (target_dim_size != 1 && dim_size != 1 && + target_dim_size != dim_size) { + PADDLE_THROW(errors::InvalidArgument( + "BroadCastTensorsDims inputs does not satisfy bcast semantics, " + "please check axis = %d in reverse order", + index)); + } + + target_dim_size = dim_size == 1 ? target_dim_size : dim_size; + } + target_dims[target_rank - index - 1] = target_dim_size; + } + return phi::make_ddim(target_dims); +} + +template +T** GetDevicePointerArray(const Context& ctx, + const std::vector& indices_v) { + std::vector h_indices_v(indices_v.size()); + for (int i = 0; i < indices_v.size(); ++i) { + h_indices_v[i] = indices_v[i]->data(); + } + auto d_indices_data = paddle::memory::Alloc( + ctx.GetPlace(), + h_indices_v.size() * sizeof(T*), + phi::Stream(reinterpret_cast(ctx.stream()))); + paddle::memory::Copy(ctx.GetPlace(), + d_indices_data->ptr(), + phi::CPUPlace(), + reinterpret_cast(h_indices_v.data()), + h_indices_v.size() * sizeof(T*), + ctx.stream()); + return reinterpret_cast(d_indices_data->ptr()); +} + +template +void IndexPutKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& indices_v, + const DenseTensor& value, + bool accumulate, + DenseTensor* out); + +} // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index ca237df8e53fe1..ac79e6f5b2d3c2 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -194,6 +194,8 @@ from .tensor.manipulation import repeat_interleave # noqa: F401 from .tensor.manipulation import index_add # noqa: F401 from .tensor.manipulation import index_add_ # noqa: F401 +from .tensor.manipulation import index_put # noqa: F401 +from .tensor.manipulation import index_put_ # noqa: F401 from .tensor.math import abs # noqa: F401 from .tensor.math import acos # noqa: F401 from .tensor.math import asin # noqa: F401 @@ -676,6 +678,8 @@ 'tril_indices', 'index_add', "index_add_", + "index_put", + "index_put_", 'sgn', 'triu_indices', 'take', diff --git a/python/paddle/fluid/tests/unittests/test_index_put_op.py b/python/paddle/fluid/tests/unittests/test_index_put_op.py new file mode 100644 index 00000000000000..249860d0c55d27 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_index_put_op.py @@ -0,0 +1,341 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
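+
+# The tests below compare paddle.index_put / the x[indices] = value path
+# against the NumPy reference compute_index_put_ref (assignment, or += when
+# accumulate=True), covering float/int/bool data types, int64 and bool index
+# tensors, broadcastable index shapes, and gradients w.r.t. both x and value.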
+ +import copy +import unittest + +import numpy as np + +import paddle +from paddle import _C_ops + + +def compute_index_put_ref(x_np, indices_np, value_np, accumulate=False): + if accumulate: + x_np[indices_np] += value_np + return x_np + else: + x_np[indices_np] = value_np + return x_np + + +def raw_index_put(x, indices, value): + return _C_ops.index_put(x, indices, value) + + +def has_duplicate_index(indices, shapes): + bd_shape = np.broadcast_shapes(*shapes) + bd_indices = [ + list(np.broadcast_to(indice, bd_shape).flatten()) for indice in indices + ] + + zip_res = list(zip(*bd_indices)) + if len(zip_res) == len(set(zip_res)): + return False + else: + return True + + +def gen_indices_np(x_shape, indices_shapes, index_type): + indices = [] + if index_type == np.bool_: + indice = np.zeros(indices_shapes[0], dtype=np.bool_) + indice.flatten() + for i in range(len(indice)): + indice[i] = (i & 1) == 0 + indice = indice.reshape(indices_shapes[0]) + indices.append(indice) + else: + while True: + indices = [] + for i in range(len(indices_shapes)): + np.random.seed() + index_np = np.random.randint( + low=0, + high=x_shape[i], + size=indices_shapes[i], + dtype=index_type, + ) + indices.append(index_np) + if not has_duplicate_index( + copy.deepcopy(indices), copy.deepcopy(indices_shapes) + ): + break + return tuple(indices) + + +class TestIndexPutOp(unittest.TestCase): + def setUp(self): + self.init_dtype_type() + self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) + self.value_np = np.random.random(self.value_shape).astype(self.dtype_np) + self.indices_np = gen_indices_np( + self.x_shape, self.indices_shapes, self.index_type_np + ) + + self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd) + self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd) + self.indices_pd = [ + paddle.to_tensor(indice, dtype=self.index_type_pd) + for indice in self.indices_np + ] + + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + + def test_forward(self): + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np + ) + pd_res = raw_index_put(self.x_pd, self.indices_pd, self.value_pd) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + + def test_backward(self): + value = paddle.ones(shape=[4], dtype=self.dtype_pd) + x = paddle.ones(shape=[16, 21], dtype=self.dtype_pd) + ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) + ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) + value.stop_gradient = False + x[ix1, ix2] = value + + dvalue = paddle.grad( + outputs=[x], inputs=[value], create_graph=False, retain_graph=True + )[0] + + np.testing.assert_allclose( + np.array([1.0, 1.0, 1.0, 1.0], dtype=self.dtype_np), + dvalue.numpy(), + atol=1e-7, + ) + + def test_backwardScalarVal(self): + value = paddle.ones(shape=[1], dtype=self.dtype_pd) + x = paddle.ones(shape=[16, 21], dtype=self.dtype_pd) + ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) + ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) + value.stop_gradient = False + x[ix1, ix2] = value + + dvalue = paddle.grad( + outputs=[x], inputs=[value], create_graph=False, retain_graph=True + )[0] + + np.testing.assert_allclose( + np.array([4.0], dtype=self.dtype_np), dvalue.numpy(), atol=1e-7 + ) + + +class TestIndexPutOpFloat32(TestIndexPutOp): + def 
init_dtype_type(self): + self.dtype_np = np.float32 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float32 + self.index_type_pd = paddle.int64 + self.dtype_pd = paddle.float32 + + +class TestIndexPutOpFloat16(TestIndexPutOp): + def init_dtype_type(self): + self.dtype_np = np.float16 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float16 + self.index_type_pd = paddle.int64 + self.dtype_pd = paddle.float16 + + +class TestIndexPutOpInt32(TestIndexPutOp): + def init_dtype_type(self): + self.dtype_np = np.int32 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.int32 + self.index_type_pd = paddle.int64 + self.dtype_pd = paddle.int32 + + +class TestIndexPutOpInt64(TestIndexPutOp): + def init_dtype_type(self): + self.dtype_np = np.int64 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.int64 + self.index_type_pd = paddle.int64 + self.dtype_pd = paddle.int64 + + +class TestIndexPutOpBool(TestIndexPutOp): + def init_dtype_type(self): + self.dtype_np = np.bool_ + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.bool + self.index_type_pd = paddle.int64 + self.dtype_pd = paddle.bool + + +class TestIndexPutAPIBase(unittest.TestCase): + def setUp(self): + self.init_dtype_type() + self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) + self.value_np = np.random.random(self.value_shape).astype(self.dtype_np) + self.indices_np = gen_indices_np( + self.x_shape, self.indices_shapes, self.index_type_np + ) + + self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd) + self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd) + self.indices_pd = [ + paddle.to_tensor(indice, dtype=self.index_type_pd) + for indice in self.indices_np + ] + self.indices_pd = tuple(self.indices_pd) + + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = False + + +class TestIndexPutAPI0(TestIndexPutAPIBase): + def test_forward(self): + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + + +class TestIndexPutAPI1(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16), (1, 16)) + self.value_shape = (16, 16) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = False + + def test_forward(self): + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + + +class TestIndexPutAPI2(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = 
np.float64 + self.index_type_np = np.bool_ + self.x_shape = (110, 94) + self.indices_shapes = [(110, 94)] + self.value_shape = 5170 + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool + self.accumulate = False + + def test_forward(self): + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + + +class TestIndexPutAPIBackward0(TestIndexPutAPIBase): + def test_backward(self): + value = paddle.ones(shape=[4], dtype=self.dtype_pd) + x = paddle.ones(shape=[16, 21], dtype=self.dtype_pd) + ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) + ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) + value.stop_gradient = False + x.stop_gradient = False + out = paddle.index_put(x, (ix1, ix2), value, self.accumulate) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=self.dtype_np) + ref_dx[ix1, ix2] = 0 + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([1.0, 1.0, 1.0, 1.0], dtype=self.dtype_np), + dvalue.numpy(), + atol=1e-7, + ) + + +class TestIndexPutAPIBackward1(TestIndexPutAPIBase): + def test_backwardScalarVal(self): + value = paddle.ones(shape=[1], dtype=self.dtype_pd) + x = paddle.ones(shape=[16, 21], dtype=self.dtype_pd) + ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) + ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) + value.stop_gradient = False + x.stop_gradient = False + out = paddle.index_put(x, (ix1, ix2), value, self.accumulate) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=self.dtype_np) + ref_dx[ix1, ix2] = 0 + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([4.0], dtype=self.dtype_np), dvalue.numpy(), atol=1e-7 + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b78ac0e57c22e8..cda233c17fe542 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -135,6 +135,8 @@ from .manipulation import repeat_interleave # noqa: F401 from .manipulation import index_add # noqa: F401 from .manipulation import index_add_ # noqa: F401 +from .manipulation import index_put # noqa: F401 +from .manipulation import index_put_ # noqa: F401 from .math import abs # noqa: F401 from .math import acos # noqa: F401 from .math import asin # noqa: F401 @@ -530,7 +532,8 @@ 'heaviside', 'index_add', "index_add_", - 'take', + "index_put", + "index_put_" 'take', 'bucketize', 'sgn', 'frexp', diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 09aaff08c3ca5e..330c15c8bd77b5 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4750,6 +4750,116 @@ def index_add_(x, index, axis, value, name=None): return _C_ops.index_add_(x, index, value, axis) +@inplace_apis_in_dygraph_only +def index_put_(x, indices, value, accumulate=False, name=None): + """ + Puts values from the tensor values into the tensor x using the indices specified in indices (which is a tuple of Tensors). 
+ The expression paddle.index_put_(x, indices, values) is equivalent to tensor[indices] = values. Returns x. + If accumulate is True, the elements in values are added to x. If accumulate is False, the behavior is undefined if indices contain duplicate elements. + + Args: + x (Tensor) : The Source Tensor. Supported data types are int32, int64, float16, float32, float64, bool, complex64, complex128. + indices (Tensor): The tuple of Tensor containing the indices to index. + The data type of ``tensor in indices`` must be int32, int64 or bool + value (Tensor): The tensor used to be assigned to x. + accummulate (Bool): Whether the elements in values are added to x + name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + + Returns: + Tensor, same dimention and dtype with x. + + Examples: + .. code-block:: python + + import paddle + + x = paddle.zeros([3, 3]) + value = paddle.ones([3]) + ix1 = paddle.to_tensor([0,1,2]) + ix2 = paddle.to_tensor([1,2,1]) + indices=(ix1,ix2) + + out = paddle.index_put_(x,indices,value) + print(x) + # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[0., 1., 0.], + # [0., 0., 1.], + # [0., 1., 0.]]) + print(out) + # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[0., 1., 0.], + # [0., 0., 1.], + # [0., 1., 0.]]) + + + + + """ + assert len(indices) != 0, "indices can't be empty" + return _C_ops.index_put_(x, indices, value, accumulate) + + +def index_put(x, indices, value, accumulate=False, name=None): + """ + Outplace version of ``index_put_`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_index_put`. + + Examples: + .. code-block:: python + + import paddle + + x = paddle.zeros([3, 3]) + value = paddle.ones([3]) + ix1 = paddle.to_tensor([0,1,2]) + ix2 = paddle.to_tensor([1,2,1]) + indices=(ix1,ix2) + + out = paddle.index_put(x,indices,value) + print(x) + # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[0., 0., 0.], + # [0., 0., 0.], + # [0., 0., 0.]]) + print(out) + # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + # [[0., 1., 0.], + # [0., 0., 1.], + # [0., 1., 0.]]) + """ + + if in_dygraph_mode(): + return _C_ops.index_put(x, indices, value, accumulate) + + helper = LayerHelper("index_put", **locals()) + check_variable_and_dtype( + x, + 'x', + ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'], + 'paddle.tensor.manipulation.index_put', + ) + check_variable_and_dtype( + value, + 'add_value', + ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'], + 'paddle.tensor.manipulation.index_put', + ) + + out = helper.create_variable_for_type_inference(x.dtype) + + helper.append_op( + type='index_put', + inputs={ + 'x': x, + 'indices': indices, + 'value': value, + }, + outputs={'out': out}, + attrs={'accumulate': accumulate}, + ) + return out + + # TODO(dev): We need avoid implementing it by this way. 
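For clarity on the accumulate flag documented above, a minimal dygraph sketch contrasting the two modes of the paddle.index_put API added in this patch (the printed results follow the semantics described in the docstrings, not output captured from a run):

import paddle

x = paddle.zeros([3, 3])
value = paddle.ones([3])
indices = (paddle.to_tensor([0, 1, 2]), paddle.to_tensor([1, 2, 1]))

# accumulate=False performs plain assignment: out[indices] = value
out_assign = paddle.index_put(x, indices, value, accumulate=False)

# accumulate=True adds into the indexed positions: out[indices] += value
out_accum = paddle.index_put(paddle.ones([3, 3]), indices, value, accumulate=True)

print(out_assign)  # indexed entries become 1., all others stay 0.
print(out_accum)   # indexed entries become 2., all others stay 1.

Note that when accumulate is False the behaviour is undefined if the indices contain duplicates, which is why gen_indices_np in the test file above keeps regenerating integer indices until has_duplicate_index reports none.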
__METHODS = { 'fill_': fill_, From 9da71b6de89bad7deff154ea5d13e8d71cc91877 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Fri, 14 Apr 2023 09:14:50 +0000 Subject: [PATCH 02/24] fix some bugs --- paddle/phi/api/yaml/ops.yaml | 20 +- .../phi/kernels/cpu/index_put_grad_kernel.cc | 14 +- paddle/phi/kernels/cpu/index_put_kernel.cc | 12 +- .../phi/kernels/gpu/index_put_grad_kernel.cu | 14 +- paddle/phi/kernels/gpu/index_put_kernel.cu | 21 +- paddle/phi/kernels/index_put_grad_kernel.h | 10 +- paddle/phi/kernels/index_put_kernel.h | 14 +- .../tests/unittests/test_index_put_op.py | 406 ++++++++++++------ 8 files changed, 354 insertions(+), 157 deletions(-) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index bdbcd12f56d15b..2873212f57e4e2 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -5,16 +5,6 @@ # are consistent and correspond one-to-one. It's forbidden that the # operator configured in this yaml file does not have Python API. - - op : index_put - args : (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false) - output : Tensor(out) - infer_meta : - func : IndexPutInferMeta - kernel : - func : index_put - inplace : (x -> out) - backward : index_put_grad - - op : accuracy args : (Tensor x, Tensor indices, Tensor label) output : Tensor(accuracy), Tensor(correct), Tensor(total) @@ -878,6 +868,16 @@ inplace : (x -> out) backward : index_add_grad +- op : index_put + args : (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false) + output : Tensor(out) + infer_meta : + func : IndexPutInferMeta + kernel : + func : index_put + inplace : (x -> out) + backward : index_put_grad + - op : index_sample args : (Tensor x, Tensor index) output : Tensor diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc index 32a6c8b30c3d62..0e71979a596a47 100644 --- a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc @@ -101,7 +101,7 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, DenseTensor* x_grad) { if (x_grad) { phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); - if (accumulate) { + if (!accumulate) { T* x_grad_data = x_grad->data(); auto x_grad_dims = x_grad->dims(); @@ -287,8 +287,9 @@ void IndexPutGradKernel(const Context& dev_ctx, if (int_indices_v.size() < total_dims) { std::vector tmp_x_dims = phi::vectorize(x.dims()); int len_bd_dim = bd_dim.size(); - res_dim_v.insert( - res_dim_v.end(), tmp_x_dims.begin() + len_bd_dim, tmp_x_dims.end()); + res_dim_v.insert(res_dim_v.end(), + tmp_x_dims.begin() + int_indices_v.size(), + tmp_x_dims.end()); std::vector reshaped_indices_v; for (size_t i = 0; i < int_indices_v.size(); ++i) { @@ -299,7 +300,7 @@ void IndexPutGradKernel(const Context& dev_ctx, reshaped_indices_v.emplace_back(*int_indices_v[i]); } } - for (size_t i = int_indices_v.size(); i < total_dims; ++i) { + for (size_t i = len_bd_dim; i < res_dim_v.size(); ++i) { reshaped_indices_v.emplace_back(GetRangeTensor( dev_ctx, res_dim_v[i], phi::DataType::INT64)); } @@ -311,7 +312,10 @@ void IndexPutGradKernel(const Context& dev_ctx, dev_ctx, reshaped_indices_v[i], res_dim, - ((i < int_indices_v.size()) ? 0 : i))); + bd_dim, + ((i < int_indices_v.size()) + ? 
0 + : i - int_indices_v.size() + len_bd_dim))); } for (size_t i = 0; i < res_indices_v.size(); ++i) { res_indices_v[i] = &tmp_res_indices_v[i]; diff --git a/paddle/phi/kernels/cpu/index_put_kernel.cc b/paddle/phi/kernels/cpu/index_put_kernel.cc index 363f874277d653..8a1ec36e01b55f 100644 --- a/paddle/phi/kernels/cpu/index_put_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_kernel.cc @@ -147,8 +147,9 @@ void IndexPutKernel(const Context& dev_ctx, if (int_indices_v.size() < total_dims) { std::vector tmp_x_dims = phi::vectorize(x.dims()); int len_bd_dim = bd_dim.size(); - res_dim_v.insert( - res_dim_v.end(), tmp_x_dims.begin() + len_bd_dim, tmp_x_dims.end()); + res_dim_v.insert(res_dim_v.end(), + tmp_x_dims.begin() + int_indices_v.size(), + tmp_x_dims.end()); std::vector reshaped_indices_v; for (size_t i = 0; i < int_indices_v.size(); ++i) { @@ -159,7 +160,7 @@ void IndexPutKernel(const Context& dev_ctx, reshaped_indices_v.emplace_back(*int_indices_v[i]); } } - for (size_t i = int_indices_v.size(); i < total_dims; ++i) { + for (size_t i = len_bd_dim; i < res_dim_v.size(); ++i) { reshaped_indices_v.emplace_back(GetRangeTensor( dev_ctx, res_dim_v[i], phi::DataType::INT64)); } @@ -171,7 +172,10 @@ void IndexPutKernel(const Context& dev_ctx, dev_ctx, reshaped_indices_v[i], res_dim, - ((i < int_indices_v.size()) ? 0 : i))); + bd_dim, + ((i < int_indices_v.size()) + ? 0 + : i - int_indices_v.size() + len_bd_dim))); } for (size_t i = 0; i < res_indices_v.size(); ++i) { res_indices_v[i] = &tmp_res_indices_v[i]; diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu index 608ea982007d8c..f2802127539a1a 100644 --- a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -109,7 +109,7 @@ void LaunchIndexPutGradCudaKernel( DenseTensor* x_grad) { if (x_grad) { phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); - if (accumulate) { + if (!accumulate) { T* x_grad_data = x_grad->data(); auto x_grad_dims = x_grad->dims(); @@ -308,8 +308,9 @@ void IndexPutGradKernel(const Context& dev_ctx, if (int_indices_v.size() < total_dims) { std::vector tmp_x_dims = phi::vectorize(x.dims()); int len_bd_dim = bd_dim.size(); - res_dim_v.insert( - res_dim_v.end(), tmp_x_dims.begin() + len_bd_dim, tmp_x_dims.end()); + res_dim_v.insert(res_dim_v.end(), + tmp_x_dims.begin() + int_indices_v.size(), + tmp_x_dims.end()); std::vector reshaped_indices_v; for (size_t i = 0; i < int_indices_v.size(); ++i) { @@ -320,7 +321,7 @@ void IndexPutGradKernel(const Context& dev_ctx, reshaped_indices_v.emplace_back(*int_indices_v[i]); } } - for (size_t i = int_indices_v.size(); i < total_dims; ++i) { + for (size_t i = len_bd_dim; i < res_dim_v.size(); ++i) { reshaped_indices_v.emplace_back(GetRangeCudaTensor( dev_ctx, res_dim_v[i], phi::DataType::INT64)); } @@ -332,7 +333,10 @@ void IndexPutGradKernel(const Context& dev_ctx, dev_ctx, reshaped_indices_v[i], res_dim, - ((i < int_indices_v.size()) ? 0 : i))); + bd_dim, + ((i < int_indices_v.size()) + ? 
0 + : i - int_indices_v.size() + len_bd_dim))); } for (size_t i = 0; i < res_indices_v.size(); ++i) { res_indices_v[i] = &tmp_res_indices_v[i]; diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu index b81e49813ff85d..5314beb5aa3fd8 100644 --- a/paddle/phi/kernels/gpu/index_put_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_kernel.cu @@ -141,11 +141,9 @@ void IndexPutKernel(const Context& dev_ctx, std::vector tmp_args; std::vector int_indices_v = DealWithBoolIndices(dev_ctx, indices_v, &tmp_args); - std::cout << "line 143" << std::endl; const size_t total_dims = x.dims().size(); auto bd_dim = BroadCastTensorsDims(int_indices_v); - std::cout << "line 147" << std::endl; std::vector res_dim_v(phi::vectorize(bd_dim)); std::vector res_indices_v(x.dims().size(), nullptr); std::vector tmp_res_indices_v; @@ -155,8 +153,9 @@ void IndexPutKernel(const Context& dev_ctx, if (int_indices_v.size() < total_dims) { std::vector tmp_x_dims = phi::vectorize(x.dims()); int len_bd_dim = bd_dim.size(); - res_dim_v.insert( - res_dim_v.end(), tmp_x_dims.begin() + len_bd_dim, tmp_x_dims.end()); + res_dim_v.insert(res_dim_v.end(), + tmp_x_dims.begin() + int_indices_v.size(), + tmp_x_dims.end()); std::vector reshaped_indices_v; for (size_t i = 0; i < int_indices_v.size(); ++i) { @@ -167,7 +166,12 @@ void IndexPutKernel(const Context& dev_ctx, reshaped_indices_v.emplace_back(*int_indices_v[i]); } } - for (size_t i = int_indices_v.size(); i < total_dims; ++i) { + + for (auto dim : res_dim_v) { + std::cout << dim << std::endl; + } + + for (size_t i = len_bd_dim; i < res_dim_v.size(); ++i) { reshaped_indices_v.emplace_back(GetRangeCudaTensor( dev_ctx, res_dim_v[i], phi::DataType::INT64)); } @@ -179,7 +183,10 @@ void IndexPutKernel(const Context& dev_ctx, dev_ctx, reshaped_indices_v[i], res_dim, - ((i < int_indices_v.size()) ? 0 : i))); + bd_dim, + ((i < int_indices_v.size()) + ? 
0 + : i - int_indices_v.size() + len_bd_dim))); } for (size_t i = 0; i < res_indices_v.size(); ++i) { res_indices_v[i] = &tmp_res_indices_v[i]; @@ -235,7 +242,6 @@ void IndexPutKernel(const Context& dev_ctx, ptr_value = &value; } } - std::cout << "line 249" << std::endl; switch (total_dims) { case 1: @@ -268,7 +274,6 @@ void IndexPutKernel(const Context& dev_ctx, "%d", x.dims().size())); } - std::cout << "line 276" << std::endl; } } // namespace phi diff --git a/paddle/phi/kernels/index_put_grad_kernel.h b/paddle/phi/kernels/index_put_grad_kernel.h index a1879a929435a4..f28a4c61dfdbc3 100644 --- a/paddle/phi/kernels/index_put_grad_kernel.h +++ b/paddle/phi/kernels/index_put_grad_kernel.h @@ -33,13 +33,19 @@ static phi::DenseTensor GetReshapeAndExpandTensor( const Context& dev_ctx, const phi::DenseTensor& tensor, const phi::DDim& res_dim, + const phi::DDim& bd_dim, int index) { std::vector before_dims = phi::vectorize(tensor.dims()); std::vector mid_dims(res_dim.size(), 1); - for (size_t i = 0; i < before_dims.size(); ++i) { - mid_dims[i + index] = before_dims[i]; + if (index == 0) { + for (size_t i = 0; i < before_dims.size(); ++i) { + mid_dims[bd_dim.size() - i - 1] = before_dims[before_dims.size() - i - 1]; + } + } else { + mid_dims[index] = before_dims[0]; } + phi::DenseTensor mid_tensor(tensor.dtype()); mid_tensor.Resize(phi::make_ddim(mid_dims)); ReshapeInferKernel(dev_ctx, tensor, IntArray(mid_dims), &mid_tensor); diff --git a/paddle/phi/kernels/index_put_kernel.h b/paddle/phi/kernels/index_put_kernel.h index ef8d2bd987e75c..e9ab52b64f3cd8 100644 --- a/paddle/phi/kernels/index_put_kernel.h +++ b/paddle/phi/kernels/index_put_kernel.h @@ -33,13 +33,23 @@ static phi::DenseTensor GetReshapeAndExpandTensor( const Context& dev_ctx, const phi::DenseTensor& tensor, const phi::DDim& res_dim, + const phi::DDim& bd_dim, int index) { std::vector before_dims = phi::vectorize(tensor.dims()); std::vector mid_dims(res_dim.size(), 1); - for (size_t i = 0; i < before_dims.size(); ++i) { - mid_dims[i + index] = before_dims[i]; + if (index == 0) { + for (size_t i = 0; i < before_dims.size(); ++i) { + mid_dims[bd_dim.size() - i - 1] = before_dims[before_dims.size() - i - 1]; + } + } else { + mid_dims[index] = before_dims[0]; + } + std::cout << "this is mid_dim" << std::endl; + for (auto dim : mid_dims) { + std::cout << dim << std::endl; } + phi::DenseTensor mid_tensor(tensor.dtype()); mid_tensor.Resize(phi::make_ddim(mid_dims)); ReshapeInferKernel(dev_ctx, tensor, IntArray(mid_dims), &mid_tensor); diff --git a/python/paddle/fluid/tests/unittests/test_index_put_op.py b/python/paddle/fluid/tests/unittests/test_index_put_op.py index 249860d0c55d27..ec2a8c4ff6c2c0 100644 --- a/python/paddle/fluid/tests/unittests/test_index_put_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_put_op.py @@ -75,7 +75,7 @@ def gen_indices_np(x_shape, indices_shapes, index_type): return tuple(indices) -class TestIndexPutOp(unittest.TestCase): +class TestIndexPutAPIBase(unittest.TestCase): def setUp(self): self.init_dtype_type() self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) @@ -90,6 +90,7 @@ def setUp(self): paddle.to_tensor(indice, dtype=self.index_type_pd) for indice in self.indices_np ] + self.indices_pd = tuple(self.indices_pd) def init_dtype_type(self): self.dtype_np = np.float64 @@ -99,183 +100,280 @@ def init_dtype_type(self): self.value_shape = (21,) self.dtype_pd = paddle.float64 self.index_type_pd = paddle.int64 + self.accumulate = False + +class TestIndexPutAPI0(TestIndexPutAPIBase): def 
test_forward(self): + self.accumulate = False ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate ) - pd_res = raw_index_put(self.x_pd, self.indices_pd, self.value_pd) np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - def test_backward(self): - value = paddle.ones(shape=[4], dtype=self.dtype_pd) - x = paddle.ones(shape=[16, 21], dtype=self.dtype_pd) - ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) - ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) - value.stop_gradient = False - x[ix1, ix2] = value + def test_forward1(self): + self.accumulate = True + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - dvalue = paddle.grad( - outputs=[x], inputs=[value], create_graph=False, retain_graph=True - )[0] - np.testing.assert_allclose( - np.array([1.0, 1.0, 1.0, 1.0], dtype=self.dtype_np), - dvalue.numpy(), - atol=1e-7, +class TestIndexPutAPI1(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16), (1, 16)) + self.value_shape = (16, 16) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = False + + def test_forward(self): + self.accumulate = False + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - def test_backwardScalarVal(self): - value = paddle.ones(shape=[1], dtype=self.dtype_pd) - x = paddle.ones(shape=[16, 21], dtype=self.dtype_pd) - ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) - ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) - value.stop_gradient = False - x[ix1, ix2] = value + def test_forward1(self): + self.accumulate = True + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - dvalue = paddle.grad( - outputs=[x], inputs=[value], create_graph=False, retain_graph=True - )[0] - np.testing.assert_allclose( - np.array([4.0], dtype=self.dtype_np), dvalue.numpy(), atol=1e-7 +class TestIndexPutAPI2(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (110, 94) + self.indices_shapes = [(110, 94)] + self.value_shape = 5170 + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool + self.accumulate = False + + def test_forward(self): + self.accumulate = False + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + + def test_forward1(self): + self.accumulate = True + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = 
paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutOpFloat32(TestIndexPutOp): +class TestIndexPutAPI3(TestIndexPutAPIBase): def init_dtype_type(self): - self.dtype_np = np.float32 + self.dtype_np = np.float64 self.index_type_np = np.int64 - self.x_shape = (100, 110) - self.indices_shapes = [(21,), (21,)] - self.value_shape = (21,) - self.dtype_pd = paddle.float32 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (16, 16, 56) + self.dtype_pd = paddle.float64 self.index_type_pd = paddle.int64 - self.dtype_pd = paddle.float32 + self.accumulate = False + def test_forward(self): + self.accumulate = False + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutOpFloat16(TestIndexPutOp): - def init_dtype_type(self): - self.dtype_np = np.float16 - self.index_type_np = np.int64 - self.x_shape = (100, 110) - self.indices_shapes = [(21,), (21,)] - self.value_shape = (21,) - self.dtype_pd = paddle.float16 - self.index_type_pd = paddle.int64 - self.dtype_pd = paddle.float16 + def test_forward1(self): + self.accumulate = True + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutOpInt32(TestIndexPutOp): +class TestIndexPutAPI4(TestIndexPutAPIBase): def init_dtype_type(self): - self.dtype_np = np.int32 - self.index_type_np = np.int64 - self.x_shape = (100, 110) - self.indices_shapes = [(21,), (21,)] - self.value_shape = (21,) - self.dtype_pd = paddle.int32 - self.index_type_pd = paddle.int64 - self.dtype_pd = paddle.int32 + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (110, 94) + self.indices_shapes = [(110)] + self.value_shape = (55, 94) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool + self.accumulate = False + def test_forward(self): + self.accumulate = False + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutOpInt64(TestIndexPutOp): - def init_dtype_type(self): - self.dtype_np = np.int64 - self.index_type_np = np.int64 - self.x_shape = (100, 110) - self.indices_shapes = [(21,), (21,)] - self.value_shape = (21,) - self.dtype_pd = paddle.int64 - self.index_type_pd = paddle.int64 - self.dtype_pd = paddle.int64 + def test_forward1(self): + self.accumulate = True + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutOpBool(TestIndexPutOp): +class TestIndexPutAPI5(TestIndexPutAPIBase): def init_dtype_type(self): - self.dtype_np = np.bool_ + self.dtype_np = np.float64 self.index_type_np = np.int64 - self.x_shape = (100, 110) - self.indices_shapes = [(21,), (21,)] - self.value_shape = (21,) - self.dtype_pd 
= paddle.bool + self.x_shape = (24, 100, 110, 98) + self.indices_shapes = ((21, 21), (1, 21), (1, 21)) + self.value_shape = 98 + self.dtype_pd = paddle.float64 self.index_type_pd = paddle.int64 - self.dtype_pd = paddle.bool + self.accumulate = False + def test_forward(self): + self.accumulate = False + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutAPIBase(unittest.TestCase): - def setUp(self): - self.init_dtype_type() - self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) - self.value_np = np.random.random(self.value_shape).astype(self.dtype_np) - self.indices_np = gen_indices_np( - self.x_shape, self.indices_shapes, self.index_type_np + def test_forward1(self): + self.accumulate = True + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd) - self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd) - self.indices_pd = [ - paddle.to_tensor(indice, dtype=self.index_type_pd) - for indice in self.indices_np - ] - self.indices_pd = tuple(self.indices_pd) +class TestIndexPutAPI6(TestIndexPutAPIBase): def init_dtype_type(self): self.dtype_np = np.float64 self.index_type_np = np.int64 - self.x_shape = (100, 110) - self.indices_shapes = [(21,), (21,)] - self.value_shape = (21,) + self.x_shape = (24, 100, 110, 98) + self.indices_shapes = ((21, 21), (1, 21), (1, 21)) + self.value_shape = 1 self.dtype_pd = paddle.float64 self.index_type_pd = paddle.int64 self.accumulate = False - -class TestIndexPutAPI0(TestIndexPutAPIBase): def test_forward(self): + self.accumulate = False ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np + self.x_np, self.indices_np, self.value_np, self.accumulate ) pd_res = paddle.index_put( self.x_pd, self.indices_pd, self.value_pd, self.accumulate ) np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + def test_forward1(self): + self.accumulate = True + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutAPI1(TestIndexPutAPIBase): + +class TestIndexPutAPI7(TestIndexPutAPIBase): def init_dtype_type(self): self.dtype_np = np.float64 - self.index_type_np = np.int64 - self.x_shape = (110, 42, 56, 56) - self.indices_shapes = ((16, 16), (16, 16), (1, 16), (1, 16)) - self.value_shape = (16, 16) + self.index_type_np = np.bool_ + self.x_shape = (44, 94) + self.indices_shapes = [(44)] + self.value_shape = 94 self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 + self.index_type_pd = paddle.bool self.accumulate = False def test_forward(self): + self.accumulate = False ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np + self.x_np, self.indices_np, self.value_np, self.accumulate ) pd_res = paddle.index_put( self.x_pd, self.indices_pd, self.value_pd, self.accumulate ) np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + def test_forward1(self): + self.accumulate = True + 
ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutAPI2(TestIndexPutAPIBase): + +class TestIndexPutAPI8(TestIndexPutAPIBase): def init_dtype_type(self): self.dtype_np = np.float64 self.index_type_np = np.bool_ - self.x_shape = (110, 94) - self.indices_shapes = [(110, 94)] - self.value_shape = 5170 + self.x_shape = (44, 94) + self.indices_shapes = [(44)] + self.value_shape = 1 self.dtype_pd = paddle.float64 self.index_type_pd = paddle.bool self.accumulate = False def test_forward(self): + self.accumulate = False + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + + def test_forward1(self): + self.accumulate = True ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np + self.x_np, self.indices_np, self.value_np, self.accumulate ) pd_res = paddle.index_put( self.x_pd, self.indices_pd, self.value_pd, self.accumulate @@ -283,15 +381,15 @@ def test_forward(self): np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutAPIBackward0(TestIndexPutAPIBase): +class TestIndexPutAPIBackward(unittest.TestCase): def test_backward(self): - value = paddle.ones(shape=[4], dtype=self.dtype_pd) - x = paddle.ones(shape=[16, 21], dtype=self.dtype_pd) - ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) - ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) + value = paddle.ones(shape=[4], dtype=paddle.float64) + x = paddle.ones(shape=[16, 21], dtype=paddle.float64) + ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) + ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) value.stop_gradient = False x.stop_gradient = False - out = paddle.index_put(x, (ix1, ix2), value, self.accumulate) + out = paddle.index_put(x, (ix1, ix2), value, False) dx, dvalue = paddle.grad( outputs=[out], @@ -299,28 +397,79 @@ def test_backward(self): create_graph=False, retain_graph=True, ) - ref_dx = np.ones(shape=[16, 21], dtype=self.dtype_np) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) ref_dx[ix1, ix2] = 0 np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) + + out = paddle.index_put(x, (ix1, ix2), value, True) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) np.testing.assert_allclose( - np.array([1.0, 1.0, 1.0, 1.0], dtype=self.dtype_np), + np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float64), dvalue.numpy(), atol=1e-7, ) - -class TestIndexPutAPIBackward1(TestIndexPutAPIBase): def test_backwardScalarVal(self): - value = paddle.ones(shape=[1], dtype=self.dtype_pd) - x = paddle.ones(shape=[16, 21], dtype=self.dtype_pd) - ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) - ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=self.index_type_pd) + value = paddle.ones(shape=[1], dtype=paddle.float64) + x = paddle.ones(shape=[16, 21], dtype=paddle.float64) + ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) + ix2 = 
paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) + value.stop_gradient = False + x.stop_gradient = False + out = paddle.index_put(x, (ix1, ix2), value, False) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + ref_dx[ix1, ix2] = 0 + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([4.0], dtype=np.float64), dvalue.numpy(), atol=1e-7 + ) + + out = paddle.index_put(x, (ix1, ix2), value, True) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([4.0], dtype=np.float64), dvalue.numpy(), atol=1e-7 + ) + + def test_backwardBroadCastValue(self): + value = paddle.ones(shape=[2], dtype=paddle.float64) + x = paddle.ones(shape=[16, 21], dtype=paddle.float64) + ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) value.stop_gradient = False x.stop_gradient = False - out = paddle.index_put(x, (ix1, ix2), value, self.accumulate) + out = paddle.index_put(x, (ix1, ix2), value, False) dx, dvalue = paddle.grad( outputs=[out], @@ -328,12 +477,27 @@ def test_backwardScalarVal(self): create_graph=False, retain_graph=True, ) - ref_dx = np.ones(shape=[16, 21], dtype=self.dtype_np) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) ref_dx[ix1, ix2] = 0 np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) np.testing.assert_allclose( - np.array([4.0], dtype=self.dtype_np), dvalue.numpy(), atol=1e-7 + np.array([2.0, 2.0], dtype=np.float64), dvalue.numpy(), atol=1e-7 + ) + + out = paddle.index_put(x, (ix1, ix2), value, False) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([2.0, 2.0], dtype=np.float64), dvalue.numpy(), atol=1e-7 ) From 4538c1a1fda0cd334d285d60d257c9a607e05591 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Sun, 16 Apr 2023 03:42:24 +0000 Subject: [PATCH 03/24] fix value broadcast in backward and add test case in static --- .../phi/kernels/cpu/index_put_grad_kernel.cc | 14 +- .../phi/kernels/gpu/index_put_grad_kernel.cu | 14 +- paddle/phi/kernels/gpu/index_put_kernel.cu | 4 - paddle/phi/kernels/index_put_kernel.h | 4 - .../tests/unittests/test_index_put_op.py | 349 +++++++++--------- 5 files changed, 187 insertions(+), 198 deletions(-) diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc index 0e71979a596a47..f24df760d35517 100644 --- a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc @@ -220,25 +220,31 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, std::vector before_dims = phi::vectorize(value_grad->dims()); std::vector compress_dims; std::vector dims_without_1; - size_t i = after_dims.size(); - size_t j = before_dims.size(); + int i = static_cast(after_dims.size()) - 1; + int j = static_cast(before_dims.size()) - 1; if (i < j) { PADDLE_THROW(phi::errors::InvalidArgument( "shape of value can't not be broadcast to shape of x[indices]")); } - while ((i--) && (j--)) { + + while ((i >= 0) && (j >= 0)) { 
if (after_dims[i] == before_dims[j]) { dims_without_1.push_back(before_dims[j]); + i--; + j--; continue; } else if (before_dims[j] == 1) { compress_dims.push_back(i); + i--; + j--; } else { PADDLE_THROW(phi::errors::InvalidArgument( "shape of value can't not be broadcast to shape of x[indices]")); } } - while (i--) { + while (i >= 0) { compress_dims.push_back(i); + i--; } phi::DenseTensor value_grad_dims_without1(value_grad->dtype()); diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu index f2802127539a1a..b2ad4b15c4ab4d 100644 --- a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -239,25 +239,31 @@ void LaunchIndexPutGradCudaKernel( std::vector before_dims = phi::vectorize(value_grad->dims()); std::vector compress_dims; std::vector dims_without_1; - size_t i = after_dims.size(); - size_t j = before_dims.size(); + int i = static_cast(after_dims.size()) - 1; + int j = static_cast(before_dims.size()) - 1; if (i < j) { PADDLE_THROW(phi::errors::InvalidArgument( "shape of value can't not be broadcast to shape of x[indices]")); } - while ((i--) && (j--)) { + + while ((i >= 0) && (j >= 0)) { if (after_dims[i] == before_dims[j]) { dims_without_1.push_back(before_dims[j]); + i--; + j--; continue; } else if (before_dims[j] == 1) { compress_dims.push_back(i); + i--; + j--; } else { PADDLE_THROW(phi::errors::InvalidArgument( "shape of value can't not be broadcast to shape of x[indices]")); } } - while (i--) { + while (i >= 0) { compress_dims.push_back(i); + i--; } phi::DenseTensor value_grad_dims_without1(value_grad->dtype()); diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu index 5314beb5aa3fd8..de64e4705a7987 100644 --- a/paddle/phi/kernels/gpu/index_put_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_kernel.cu @@ -167,10 +167,6 @@ void IndexPutKernel(const Context& dev_ctx, } } - for (auto dim : res_dim_v) { - std::cout << dim << std::endl; - } - for (size_t i = len_bd_dim; i < res_dim_v.size(); ++i) { reshaped_indices_v.emplace_back(GetRangeCudaTensor( dev_ctx, res_dim_v[i], phi::DataType::INT64)); diff --git a/paddle/phi/kernels/index_put_kernel.h b/paddle/phi/kernels/index_put_kernel.h index e9ab52b64f3cd8..5e8ce922a2dfc6 100644 --- a/paddle/phi/kernels/index_put_kernel.h +++ b/paddle/phi/kernels/index_put_kernel.h @@ -45,10 +45,6 @@ static phi::DenseTensor GetReshapeAndExpandTensor( } else { mid_dims[index] = before_dims[0]; } - std::cout << "this is mid_dim" << std::endl; - for (auto dim : mid_dims) { - std::cout << dim << std::endl; - } phi::DenseTensor mid_tensor(tensor.dtype()); mid_tensor.Resize(phi::make_ddim(mid_dims)); diff --git a/python/paddle/fluid/tests/unittests/test_index_put_op.py b/python/paddle/fluid/tests/unittests/test_index_put_op.py index ec2a8c4ff6c2c0..ee85b523c49308 100644 --- a/python/paddle/fluid/tests/unittests/test_index_put_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_put_op.py @@ -19,6 +19,7 @@ import paddle from paddle import _C_ops +from paddle.fluid import Program def compute_index_put_ref(x_np, indices_np, value_np, accumulate=False): @@ -78,20 +79,13 @@ def gen_indices_np(x_shape, indices_shapes, index_type): class TestIndexPutAPIBase(unittest.TestCase): def setUp(self): self.init_dtype_type() + self.setPlace() self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) self.value_np = np.random.random(self.value_shape).astype(self.dtype_np) self.indices_np = 
gen_indices_np( self.x_shape, self.indices_shapes, self.index_type_np ) - self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd) - self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd) - self.indices_pd = [ - paddle.to_tensor(indice, dtype=self.index_type_pd) - for indice in self.indices_np - ] - self.indices_pd = tuple(self.indices_pd) - def init_dtype_type(self): self.dtype_np = np.float64 self.index_type_np = np.int64 @@ -102,27 +96,79 @@ def init_dtype_type(self): self.index_type_pd = paddle.int64 self.accumulate = False + def setPlace(self): + self.place = ['cpu'] + if paddle.is_compiled_with_cuda(): + self.place.append('gpu') + + def test_dygraph_forward(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd) + self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd) + self.indices_pd = [ + paddle.to_tensor(indice, dtype=self.index_type_pd) + for indice in self.indices_np + ] + self.indices_pd = tuple(self.indices_pd) + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + pd_res = paddle.index_put( + self.x_pd, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + + def test_static_forward(self): + paddle.enable_static() + for place in self.place: + with paddle.static.program_guard(Program()): + x = paddle.static.data( + name="x", shape=self.x_shape, dtype=self.dtype_pd + ) + indices = tuple( + [ + paddle.static.data( + name="indice" + str(i), + shape=self.indices_shapes[i], + dtype=self.index_type_pd, + ) + for i in range(len(self.indices_shapes)) + ] + ) + value = paddle.static.data( + name="value", shape=self.value_shape, dtype=self.dtype_pd + ) -class TestIndexPutAPI0(TestIndexPutAPIBase): - def test_forward(self): - self.accumulate = False - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + out = paddle.index_put(x, indices, value, self.accumulate) + exe = paddle.static.Executor(place=place) + feed_list = {} + feed_list.update({"x": self.x_np}) + for i in range(len(indices)): + feed_list.update({"indice" + str(i): self.indices_np[i]}) + feed_list.update({"value": self.value_np}) + pd_res = exe.run( + paddle.static.default_main_program(), + feed=feed_list, + fetch_list=[out], + )[0] + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res, atol=1e-7) - def test_forward1(self): + +class TestIndexPutAPI0(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 self.accumulate = True - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) class TestIndexPutAPI1(TestIndexPutAPIBase): @@ -136,60 +182,44 @@ def init_dtype_type(self): self.index_type_pd = paddle.int64 self.accumulate = False - def test_forward(self): - self.accumulate = 
False - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - def test_forward1(self): +class TestIndexPutAPI2(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16), (1, 16)) + self.value_shape = (16, 16) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 self.accumulate = True - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutAPI2(TestIndexPutAPIBase): +class TestIndexPutAPI3(TestIndexPutAPIBase): def init_dtype_type(self): self.dtype_np = np.float64 self.index_type_np = np.bool_ self.x_shape = (110, 94) self.indices_shapes = [(110, 94)] - self.value_shape = 5170 + self.value_shape = (5170,) self.dtype_pd = paddle.float64 self.index_type_pd = paddle.bool self.accumulate = False - def test_forward(self): - self.accumulate = False - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - def test_forward1(self): +class TestIndexPutAPI4(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (110, 94) + self.indices_shapes = [(110, 94)] + self.value_shape = (5170,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool self.accumulate = True - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutAPI3(TestIndexPutAPIBase): +class TestIndexPutAPI5(TestIndexPutAPIBase): def init_dtype_type(self): self.dtype_np = np.float64 self.index_type_np = np.int64 @@ -200,189 +230,142 @@ def init_dtype_type(self): self.index_type_pd = paddle.int64 self.accumulate = False - def test_forward(self): - self.accumulate = False - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - def test_forward1(self): +class TestIndexPutAPI6(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (16, 16, 56) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 self.accumulate = True - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutAPI4(TestIndexPutAPIBase): +class TestIndexPutAPI7(TestIndexPutAPIBase): def init_dtype_type(self): 
self.dtype_np = np.float64 self.index_type_np = np.bool_ self.x_shape = (110, 94) - self.indices_shapes = [(110)] + self.indices_shapes = [(110,)] self.value_shape = (55, 94) self.dtype_pd = paddle.float64 self.index_type_pd = paddle.bool self.accumulate = False - def test_forward(self): - self.accumulate = False - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - def test_forward1(self): +class TestIndexPutAPI8(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (110, 94) + self.indices_shapes = [(110,)] + self.value_shape = (55, 94) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool self.accumulate = True - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutAPI5(TestIndexPutAPIBase): +class TestIndexPutAPI9(TestIndexPutAPIBase): def init_dtype_type(self): self.dtype_np = np.float64 self.index_type_np = np.int64 - self.x_shape = (24, 100, 110, 98) - self.indices_shapes = ((21, 21), (1, 21), (1, 21)) - self.value_shape = 98 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (56,) self.dtype_pd = paddle.float64 self.index_type_pd = paddle.int64 self.accumulate = False - def test_forward(self): - self.accumulate = False - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - def test_forward1(self): +class TestIndexPutAPI10(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (56,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 self.accumulate = True - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutAPI6(TestIndexPutAPIBase): +class TestIndexPutAPI11(TestIndexPutAPIBase): def init_dtype_type(self): self.dtype_np = np.float64 self.index_type_np = np.int64 - self.x_shape = (24, 100, 110, 98) - self.indices_shapes = ((21, 21), (1, 21), (1, 21)) - self.value_shape = 1 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (1,) self.dtype_pd = paddle.float64 self.index_type_pd = paddle.int64 self.accumulate = False - def test_forward(self): - self.accumulate = False - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - def test_forward1(self): +class TestIndexPutAPI12(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + 
self.index_type_np = np.int64 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (1,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 self.accumulate = True - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutAPI7(TestIndexPutAPIBase): +class TestIndexPutAPI13(TestIndexPutAPIBase): def init_dtype_type(self): self.dtype_np = np.float64 self.index_type_np = np.bool_ self.x_shape = (44, 94) - self.indices_shapes = [(44)] - self.value_shape = 94 + self.indices_shapes = [(44,)] + self.value_shape = (94,) self.dtype_pd = paddle.float64 self.index_type_pd = paddle.bool self.accumulate = False - def test_forward(self): - self.accumulate = False - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - def test_forward1(self): +class TestIndexPutAPI14(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (44, 94) + self.indices_shapes = [(44,)] + self.value_shape = (94,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool self.accumulate = True - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) -class TestIndexPutAPI8(TestIndexPutAPIBase): +class TestIndexPutAPI15(TestIndexPutAPIBase): def init_dtype_type(self): self.dtype_np = np.float64 self.index_type_np = np.bool_ self.x_shape = (44, 94) - self.indices_shapes = [(44)] - self.value_shape = 1 + self.indices_shapes = [(44,)] + self.value_shape = (1,) self.dtype_pd = paddle.float64 self.index_type_pd = paddle.bool self.accumulate = False - def test_forward(self): - self.accumulate = False - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - def test_forward1(self): +class TestIndexPutAPI16(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.bool_ + self.x_shape = (44, 94) + self.indices_shapes = [(44,)] + self.value_shape = (1,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.bool self.accumulate = True - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - pd_res = paddle.index_put( - self.x_pd, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) class TestIndexPutAPIBackward(unittest.TestCase): def test_backward(self): + paddle.disable_static() value = paddle.ones(shape=[4], dtype=paddle.float64) x = paddle.ones(shape=[16, 21], dtype=paddle.float64) ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) @@ -425,6 +408,7 @@ def test_backward(self): ) def test_backwardScalarVal(self): + paddle.disable_static() value = paddle.ones(shape=[1], dtype=paddle.float64) x = 
paddle.ones(shape=[16, 21], dtype=paddle.float64) ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) @@ -463,6 +447,7 @@ def test_backwardScalarVal(self): ) def test_backwardBroadCastValue(self): + paddle.disable_static() value = paddle.ones(shape=[2], dtype=paddle.float64) x = paddle.ones(shape=[16, 21], dtype=paddle.float64) ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) @@ -485,7 +470,7 @@ def test_backwardBroadCastValue(self): np.array([2.0, 2.0], dtype=np.float64), dvalue.numpy(), atol=1e-7 ) - out = paddle.index_put(x, (ix1, ix2), value, False) + out = paddle.index_put(x, (ix1, ix2), value, True) dx, dvalue = paddle.grad( outputs=[out], From 244d02d9ef0fcc01929b889ceedf415a0f6275ee Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Mon, 17 Apr 2023 03:10:00 +0000 Subject: [PATCH 04/24] fix cpu backward bug --- paddle/phi/kernels/cpu/index_put_grad_kernel.cc | 2 +- paddle/phi/kernels/gpu/index_put_grad_kernel.cu | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc index f24df760d35517..c55a82b6b828ac 100644 --- a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc @@ -63,7 +63,7 @@ void set_zero_kernel(const int64_t N, } offset += stride[i] * cur_ix; } - *(out + idx) = 0; + *(out + offset) = 0; } } diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu index b2ad4b15c4ab4d..cbcf0f1cd18a4a 100644 --- a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -308,8 +308,6 @@ void IndexPutGradKernel(const Context& dev_ctx, std::vector res_dim_v(phi::vectorize(bd_dim)); std::vector res_indices_v(x.dims().size(), nullptr); std::vector tmp_res_indices_v; - std::vector tmp_value_v; - const DenseTensor* ptr_value = nullptr; if (int_indices_v.size() < total_dims) { std::vector tmp_x_dims = phi::vectorize(x.dims()); From 01672f806b3248f7c19cc6da6df984c87dc79fdc Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Mon, 17 Apr 2023 05:01:04 +0000 Subject: [PATCH 05/24] add timeout=120s for index_put --- python/paddle/fluid/tests/unittests/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index e0d89932a29213..ef41e2f6c2e56b 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -927,6 +927,7 @@ set_tests_properties(test_imperative_selected_rows_to_lod_tensor PROPERTIES TIMEOUT 200) set_tests_properties(test_index_select_op PROPERTIES TIMEOUT 120) set_tests_properties(test_index_add_op PROPERTIES TIMEOUT 120) +set_tests_properties(test_index_put_op PROPERTIES TIMEOUT 120) set_tests_properties(test_tensordot PROPERTIES TIMEOUT 200) set_tests_properties(test_partial_eager_deletion_transformer PROPERTIES TIMEOUT 120) From 5a361ea9edb521f4484fa3f556dfb43206abc275 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Mon, 17 Apr 2023 06:52:43 +0000 Subject: [PATCH 06/24] add op_compat for index_put --- paddle/phi/api/yaml/op_compat.yaml | 8 ++++++++ python/paddle/tensor/manipulation.py | 3 ++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 19069eeac9a58e..c41423101dd6dd 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ 
b/paddle/phi/api/yaml/op_compat.yaml @@ -1061,6 +1061,14 @@ outputs : out : Out +- op : index_put + backward : index_put_grad + inputs : + {x : x, indices : indices, value : value} + outputs : + out : out + attrs : [accumulate = false] + - op : index_sample inputs : {x : X, index : Index} diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 330c15c8bd77b5..478b6bbe9fbbb0 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4828,6 +4828,7 @@ def index_put(x, indices, value, accumulate=False, name=None): # [0., 1., 0.]]) """ + assert len(indices) != 0, "indices can't be empty" if in_dygraph_mode(): return _C_ops.index_put(x, indices, value, accumulate) @@ -4840,7 +4841,7 @@ def index_put(x, indices, value, accumulate=False, name=None): ) check_variable_and_dtype( value, - 'add_value', + 'value', ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'], 'paddle.tensor.manipulation.index_put', ) From a7f2d4230cbd87ace11b96be0da6a19f2bbc7235 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Mon, 17 Apr 2023 07:04:57 +0000 Subject: [PATCH 07/24] delete input_put in op_compat.yaml --- paddle/phi/api/yaml/op_compat.yaml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index c41423101dd6dd..19069eeac9a58e 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1061,14 +1061,6 @@ outputs : out : Out -- op : index_put - backward : index_put_grad - inputs : - {x : x, indices : indices, value : value} - outputs : - out : out - attrs : [accumulate = false] - - op : index_sample inputs : {x : X, index : Index} From d996d36f866e78c8fc7f221977c72b9ca9554979 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Mon, 17 Apr 2023 12:39:11 +0000 Subject: [PATCH 08/24] add inplace index_put test --- .../tests/unittests/test_index_put_op.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_index_put_op.py b/python/paddle/fluid/tests/unittests/test_index_put_op.py index ee85b523c49308..47a07e25e98e11 100644 --- a/python/paddle/fluid/tests/unittests/test_index_put_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_put_op.py @@ -486,5 +486,64 @@ def test_backwardBroadCastValue(self): ) +class TestIndexPutInplaceAPI(unittest.TestCase): + def setUp(self): + self.init_dtype_type() + self.setPlace() + self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) + self.value_np = np.random.random(self.value_shape).astype(self.dtype_np) + self.indices_np = gen_indices_np( + self.x_shape, self.indices_shapes, self.index_type_np + ) + + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = False + + def setPlace(self): + self.place = ['cpu'] + if paddle.is_compiled_with_cuda(): + self.place.append('gpu') + + def test_dygraph_forward(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd) + self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd) + self.indices_pd = [ + paddle.to_tensor(indice, dtype=self.index_type_pd) + for indice in self.indices_np + ] + self.indices_pd = tuple(self.indices_pd) + ref_res = compute_index_put_ref( 
+ self.x_np, self.indices_np, self.value_np, self.accumulate + ) + x_pd_bk = self.x_pd.clone() + pd_res = paddle.index_put_( + x_pd_bk, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + np.testing.assert_allclose(ref_res, x_pd_bk.numpy(), atol=1e-7) + + +class TestIndexPutInplaceAPI1(TestIndexPutInplaceAPI): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = True + + if __name__ == '__main__': unittest.main() From 8a3fef496d784ed6a0de2cab798b1cbd9bc9c4a3 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Tue, 18 Apr 2023 03:41:14 +0000 Subject: [PATCH 09/24] refactor code --- .../phi/kernels/cpu/index_put_grad_kernel.cc | 1 + paddle/phi/kernels/cpu/index_put_kernel.cc | 1 + .../phi/kernels/gpu/index_put_grad_kernel.cu | 1 + paddle/phi/kernels/gpu/index_put_kernel.cu | 1 + paddle/phi/kernels/index_put_grad_kernel.h | 151 ---------- paddle/phi/kernels/index_put_kernel.h | 150 ---------- paddle/phi/kernels/index_put_utils.h | 184 ++++++++++++ .../tests/unittests/test_index_put_op.py | 269 ++++++++++++++---- 8 files changed, 395 insertions(+), 363 deletions(-) create mode 100644 paddle/phi/kernels/index_put_utils.h diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc index c55a82b6b828ac..5c42db4a6d60a5 100644 --- a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc @@ -20,6 +20,7 @@ #include "paddle/phi/core/utils/array.h" #include "paddle/phi/kernels/cast_kernel.h" #include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/index_put_utils.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" #include "paddle/phi/kernels/reshape_kernel.h" namespace phi { diff --git a/paddle/phi/kernels/cpu/index_put_kernel.cc b/paddle/phi/kernels/cpu/index_put_kernel.cc index 8a1ec36e01b55f..3b530ce057f98a 100644 --- a/paddle/phi/kernels/cpu/index_put_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_kernel.cc @@ -18,6 +18,7 @@ #include "paddle/phi/core/utils/array.h" #include "paddle/phi/kernels/cast_kernel.h" #include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/index_put_utils.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu index cbcf0f1cd18a4a..84bb847adb5abb 100644 --- a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -20,6 +20,7 @@ #include "paddle/phi/core/utils/array.h" #include "paddle/phi/kernels/cast_kernel.h" #include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/index_put_utils.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" #include "paddle/phi/kernels/reshape_kernel.h" diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu index de64e4705a7987..a8df483b86a249 100644 --- a/paddle/phi/kernels/gpu/index_put_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_kernel.cu @@ -19,6 +19,7 @@ #include "paddle/phi/core/utils/array.h" #include "paddle/phi/kernels/cast_kernel.h" #include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/index_put_utils.h" #include "paddle/phi/kernels/nonzero_kernel.h" #include "paddle/phi/kernels/split_kernel.h" diff --git 
a/paddle/phi/kernels/index_put_grad_kernel.h b/paddle/phi/kernels/index_put_grad_kernel.h index f28a4c61dfdbc3..d5313ac10dd5ae 100644 --- a/paddle/phi/kernels/index_put_grad_kernel.h +++ b/paddle/phi/kernels/index_put_grad_kernel.h @@ -27,157 +27,6 @@ #include "paddle/phi/kernels/split_kernel.h" namespace phi { - -template -static phi::DenseTensor GetReshapeAndExpandTensor( - const Context& dev_ctx, - const phi::DenseTensor& tensor, - const phi::DDim& res_dim, - const phi::DDim& bd_dim, - int index) { - std::vector before_dims = phi::vectorize(tensor.dims()); - std::vector mid_dims(res_dim.size(), 1); - - if (index == 0) { - for (size_t i = 0; i < before_dims.size(); ++i) { - mid_dims[bd_dim.size() - i - 1] = before_dims[before_dims.size() - i - 1]; - } - } else { - mid_dims[index] = before_dims[0]; - } - - phi::DenseTensor mid_tensor(tensor.dtype()); - mid_tensor.Resize(phi::make_ddim(mid_dims)); - ReshapeInferKernel(dev_ctx, tensor, IntArray(mid_dims), &mid_tensor); - - phi::DenseTensor res_tensor(tensor.dtype()); - res_tensor.Resize(res_dim); - ExpandKernel( - dev_ctx, mid_tensor, IntArray(phi::vectorize(res_dim)), &res_tensor); - return res_tensor; -} - -template -static std::vector DealWithBoolIndices( - const Context& dev_ctx, - const std::vector& indices_v, - std::vector* tmp_indices_v) { - std::vector res(indices_v.begin(), indices_v.end()); - bool contains_bool_tensor = false; - for (size_t i = 0; i < indices_v.size(); ++i) { - if (indices_v[i]->dtype() == phi::DataType::BOOL) { - contains_bool_tensor = true; - } else if ((indices_v[i]->dtype() == phi::DataType::INT64) || - (indices_v[i]->dtype() == phi::DataType::INT32)) { - if (contains_bool_tensor) { - PADDLE_THROW(phi::errors::InvalidArgument( - "indices contains bool tensor and int32/int64 tensor at the same " - "time")); - } - } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "data type of tensor in indices must be int32, int64 or bool")); - } - } - - if (contains_bool_tensor) { - if (indices_v.size() != 1) { - PADDLE_THROW(phi::errors::InvalidArgument( - "the size of indices must be 1 when it containts bool tensor")); - } - int rank = indices_v[0]->dims().size(); - PADDLE_ENFORCE_GE( - rank, - 1UL, - phi::errors::InvalidArgument("the only bool tensor in indices should " - "have number of dimension at least 1")); - phi::DenseTensor nonzero_indices(phi::DataType::INT64); - nonzero_indices.Resize(phi::make_ddim({-1, rank})); - NonZeroKernel(dev_ctx, *indices_v[0], &nonzero_indices); - - std::vector integer_indices(rank, nullptr); - for (int i = 0; i < rank; ++i) { - // tmp_indices_v.emplace_back(DenseTensor(phi::DataType::INT64).Resize(phi::make_ddim({nonzero_indices.dims()[0],1}))); - // 理论上这里应该要加个1的 - tmp_indices_v->emplace_back( - DenseTensor(phi::DataType::INT64) - .Resize(phi::make_ddim({nonzero_indices.dims()[0]}))); - } - for (int i = 0; i < rank; ++i) { - integer_indices[i] = &((*tmp_indices_v)[i]); - } - SplitWithNumKernel( - dev_ctx, nonzero_indices, rank, 1, integer_indices); - - std::vector res_tmp(integer_indices.size(), - nullptr); - for (int i = 0; i < rank; ++i) { - res_tmp[i] = &((*tmp_indices_v)[i]); - } - res.swap(res_tmp); - } - return res; -} - -static phi::DDim BroadCastTensorsDims( - const std::vector& tensors) { - int target_rank = 0; - for (const auto& tensor : tensors) { - target_rank = std::max(target_rank, tensor->dims().size()); - } - - PADDLE_ENFORCE_GT(target_rank, - 0, - errors::InvalidArgument("BroadCastTensorsDims requires at " - "least one input tensor to have " - "rank greater than 
zero")); - - std::vector target_dims(target_rank, 0); - for (int index = 0; index < target_rank; index++) { - int target_dim_size = 1; - for (const auto& tensor : tensors) { - auto input_ddim = tensor->dims(); - int axis = static_cast(input_ddim.size()) - index - 1; - int dim_size = 1; - if (axis >= 0) { - dim_size = input_ddim[axis]; - } - - if (target_dim_size != 1 && dim_size != 1 && - target_dim_size != dim_size) { - PADDLE_THROW(errors::InvalidArgument( - "BroadCastTensorsDims inputs does not satisfy bcast semantics, " - "please check axis = %d in reverse order", - index)); - } - - target_dim_size = dim_size == 1 ? target_dim_size : dim_size; - } - target_dims[target_rank - index - 1] = target_dim_size; - } - return phi::make_ddim(target_dims); -} - -template -T** GetDevicePointerArray(const Context& ctx, - const std::vector& indices_v) { - std::vector h_indices_v(indices_v.size()); - for (int i = 0; i < indices_v.size(); ++i) { - h_indices_v[i] = indices_v[i]->data(); - } - auto d_indices_data = paddle::memory::Alloc( - ctx.GetPlace(), - h_indices_v.size() * sizeof(T*), - phi::Stream(reinterpret_cast(ctx.stream()))); - paddle::memory::Copy(ctx.GetPlace(), - d_indices_data->ptr(), - phi::CPUPlace(), - reinterpret_cast(h_indices_v.data()), - h_indices_v.size() * sizeof(T*), - ctx.stream()); - return reinterpret_cast(d_indices_data->ptr()); -} - template void IndexPutGradKernel(const Context& dev_ctx, const DenseTensor& x, diff --git a/paddle/phi/kernels/index_put_kernel.h b/paddle/phi/kernels/index_put_kernel.h index 5e8ce922a2dfc6..1f5d17fe451e79 100644 --- a/paddle/phi/kernels/index_put_kernel.h +++ b/paddle/phi/kernels/index_put_kernel.h @@ -27,156 +27,6 @@ #include "paddle/phi/kernels/split_kernel.h" namespace phi { - -template -static phi::DenseTensor GetReshapeAndExpandTensor( - const Context& dev_ctx, - const phi::DenseTensor& tensor, - const phi::DDim& res_dim, - const phi::DDim& bd_dim, - int index) { - std::vector before_dims = phi::vectorize(tensor.dims()); - std::vector mid_dims(res_dim.size(), 1); - - if (index == 0) { - for (size_t i = 0; i < before_dims.size(); ++i) { - mid_dims[bd_dim.size() - i - 1] = before_dims[before_dims.size() - i - 1]; - } - } else { - mid_dims[index] = before_dims[0]; - } - - phi::DenseTensor mid_tensor(tensor.dtype()); - mid_tensor.Resize(phi::make_ddim(mid_dims)); - ReshapeInferKernel(dev_ctx, tensor, IntArray(mid_dims), &mid_tensor); - - phi::DenseTensor res_tensor(tensor.dtype()); - res_tensor.Resize(res_dim); - ExpandKernel( - dev_ctx, mid_tensor, IntArray(phi::vectorize(res_dim)), &res_tensor); - return res_tensor; -} - -template -static std::vector DealWithBoolIndices( - const Context& dev_ctx, - const std::vector& indices_v, - std::vector* tmp_indices_v) { - std::vector res(indices_v.begin(), indices_v.end()); - bool contains_bool_tensor = false; - for (size_t i = 0; i < indices_v.size(); ++i) { - if (indices_v[i]->dtype() == phi::DataType::BOOL) { - contains_bool_tensor = true; - } else if ((indices_v[i]->dtype() == phi::DataType::INT64) || - (indices_v[i]->dtype() == phi::DataType::INT32)) { - if (contains_bool_tensor) { - PADDLE_THROW(phi::errors::InvalidArgument( - "indices contains bool tensor and int32/int64 tensor at the same " - "time")); - } - } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "data type of tensor in indices must be int32, int64 or bool")); - } - } - - if (contains_bool_tensor) { - if (indices_v.size() != 1) { - PADDLE_THROW(phi::errors::InvalidArgument( - "the size of indices must be 1 when it containts 
bool tensor")); - } - int rank = indices_v[0]->dims().size(); - PADDLE_ENFORCE_GE( - rank, - 1UL, - phi::errors::InvalidArgument("the only bool tensor in indices should " - "have number of dimension at least 1")); - phi::DenseTensor nonzero_indices(phi::DataType::INT64); - nonzero_indices.Resize(phi::make_ddim({-1, rank})); - NonZeroKernel(dev_ctx, *indices_v[0], &nonzero_indices); - - std::vector integer_indices(rank, nullptr); - for (int i = 0; i < rank; ++i) { - // tmp_indices_v.emplace_back(DenseTensor(phi::DataType::INT64).Resize(phi::make_ddim({nonzero_indices.dims()[0],1}))); - tmp_indices_v->emplace_back( - DenseTensor(phi::DataType::INT64) - .Resize(phi::make_ddim({nonzero_indices.dims()[0]}))); - } - for (int i = 0; i < rank; ++i) { - integer_indices[i] = &((*tmp_indices_v)[i]); - } - SplitWithNumKernel( - dev_ctx, nonzero_indices, rank, 1, integer_indices); - - std::vector res_tmp(integer_indices.size(), - nullptr); - for (int i = 0; i < rank; ++i) { - res_tmp[i] = &((*tmp_indices_v)[i]); - } - res.swap(res_tmp); - } - return res; -} - -static phi::DDim BroadCastTensorsDims( - const std::vector& tensors) { - int target_rank = 0; - for (const auto& tensor : tensors) { - target_rank = std::max(target_rank, tensor->dims().size()); - } - - PADDLE_ENFORCE_GT(target_rank, - 0, - errors::InvalidArgument("BroadCastTensorsDims requires at " - "least one input tensor to have " - "rank greater than zero")); - - std::vector target_dims(target_rank, 0); - for (int index = 0; index < target_rank; index++) { - int target_dim_size = 1; - for (const auto& tensor : tensors) { - auto input_ddim = tensor->dims(); - int axis = static_cast(input_ddim.size()) - index - 1; - int dim_size = 1; - if (axis >= 0) { - dim_size = input_ddim[axis]; - } - - if (target_dim_size != 1 && dim_size != 1 && - target_dim_size != dim_size) { - PADDLE_THROW(errors::InvalidArgument( - "BroadCastTensorsDims inputs does not satisfy bcast semantics, " - "please check axis = %d in reverse order", - index)); - } - - target_dim_size = dim_size == 1 ? target_dim_size : dim_size; - } - target_dims[target_rank - index - 1] = target_dim_size; - } - return phi::make_ddim(target_dims); -} - -template -T** GetDevicePointerArray(const Context& ctx, - const std::vector& indices_v) { - std::vector h_indices_v(indices_v.size()); - for (int i = 0; i < indices_v.size(); ++i) { - h_indices_v[i] = indices_v[i]->data(); - } - auto d_indices_data = paddle::memory::Alloc( - ctx.GetPlace(), - h_indices_v.size() * sizeof(T*), - phi::Stream(reinterpret_cast(ctx.stream()))); - paddle::memory::Copy(ctx.GetPlace(), - d_indices_data->ptr(), - phi::CPUPlace(), - reinterpret_cast(h_indices_v.data()), - h_indices_v.size() * sizeof(T*), - ctx.stream()); - return reinterpret_cast(d_indices_data->ptr()); -} - template void IndexPutKernel(const Context& dev_ctx, const DenseTensor& x, diff --git a/paddle/phi/kernels/index_put_utils.h b/paddle/phi/kernels/index_put_utils.h new file mode 100644 index 00000000000000..aef3e4a60afab7 --- /dev/null +++ b/paddle/phi/kernels/index_put_utils.h @@ -0,0 +1,184 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/utils/array.h" +#include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/nonzero_kernel.h" +#include "paddle/phi/kernels/reshape_kernel.h" +#include "paddle/phi/kernels/split_kernel.h" + +namespace phi { + +template +static phi::DenseTensor GetReshapeAndExpandTensor( + const Context& dev_ctx, + const phi::DenseTensor& tensor, + const phi::DDim& res_dim, + const phi::DDim& bd_dim, + int index) { + std::vector before_dims = phi::vectorize(tensor.dims()); + std::vector mid_dims(res_dim.size(), 1); + + if (index == 0) { + for (size_t i = 0; i < before_dims.size(); ++i) { + mid_dims[bd_dim.size() - i - 1] = before_dims[before_dims.size() - i - 1]; + } + } else { + mid_dims[index] = before_dims[0]; + } + + phi::DenseTensor mid_tensor(tensor.dtype()); + mid_tensor.Resize(phi::make_ddim(mid_dims)); + ReshapeInferKernel(dev_ctx, tensor, IntArray(mid_dims), &mid_tensor); + + phi::DenseTensor res_tensor(tensor.dtype()); + res_tensor.Resize(res_dim); + ExpandKernel( + dev_ctx, mid_tensor, IntArray(phi::vectorize(res_dim)), &res_tensor); + return res_tensor; +} + +template +static std::vector DealWithBoolIndices( + const Context& dev_ctx, + const std::vector& indices_v, + std::vector* tmp_indices_v) { + std::vector res(indices_v.begin(), indices_v.end()); + bool contains_bool_tensor = false; + for (size_t i = 0; i < indices_v.size(); ++i) { + PADDLE_ENFORCE( + (indices_v[i]->dtype() == phi::DataType::INT64) || + (indices_v[i]->dtype() == phi::DataType::INT32) || + (indices_v[i]->dtype() == phi::DataType::BOOL), + phi::errors::InvalidArgument( + "indices contains bool tensor and int32/int64 tensor at the same " + "time")); + if (indices_v[i]->dtype() == phi::DataType::BOOL) { + contains_bool_tensor = true; + } else { + PADDLE_ENFORCE_EQ( + contains_bool_tensor, + false, + phi::errors::InvalidArgument( + "indices contains bool tensor and int32/int64 tensor at the same " + "time")); + } + } + + if (contains_bool_tensor) { + if (indices_v.size() != 1) { + PADDLE_THROW(phi::errors::InvalidArgument( + "the size of indices must be 1 when it containts bool tensor")); + } + int rank = indices_v[0]->dims().size(); + PADDLE_ENFORCE_GE( + rank, + 1UL, + phi::errors::InvalidArgument("the only bool tensor in indices should " + "have number of dimension at least 1")); + phi::DenseTensor nonzero_indices(phi::DataType::INT64); + nonzero_indices.Resize(phi::make_ddim({-1, rank})); + NonZeroKernel(dev_ctx, *indices_v[0], &nonzero_indices); + + std::vector integer_indices(rank, nullptr); + for (int i = 0; i < rank; ++i) { + // here should be + // tmp_indices_v.emplace_back(DenseTensor(phi::DataType::INT64).Resize(phi::make_ddim({nonzero_indices.dims()[0],1}))); + tmp_indices_v->emplace_back( + DenseTensor(phi::DataType::INT64) + .Resize(phi::make_ddim({nonzero_indices.dims()[0]}))); + } + for (int i = 0; i < rank; ++i) { + 
integer_indices[i] = &((*tmp_indices_v)[i]); + } + SplitWithNumKernel( + dev_ctx, nonzero_indices, rank, 1, integer_indices); + + std::vector res_tmp(integer_indices.size(), + nullptr); + for (int i = 0; i < rank; ++i) { + res_tmp[i] = &((*tmp_indices_v)[i]); + } + res.swap(res_tmp); + } + return res; +} + +static phi::DDim BroadCastTensorsDims( + const std::vector& tensors) { + int target_rank = 0; + for (const auto& tensor : tensors) { + target_rank = std::max(target_rank, tensor->dims().size()); + } + + PADDLE_ENFORCE_GT(target_rank, + 0, + errors::InvalidArgument("BroadCastTensorsDims requires at " + "least one input tensor to have " + "rank greater than zero")); + + std::vector target_dims(target_rank, 0); + for (int index = 0; index < target_rank; index++) { + int target_dim_size = 1; + for (const auto& tensor : tensors) { + auto input_ddim = tensor->dims(); + int axis = static_cast(input_ddim.size()) - index - 1; + int dim_size = 1; + if (axis >= 0) { + dim_size = input_ddim[axis]; + } + + if (target_dim_size != 1 && dim_size != 1 && + target_dim_size != dim_size) { + PADDLE_THROW(errors::InvalidArgument( + "BroadCastTensorsDims inputs does not satisfy bcast semantics, " + "please check axis = %d in reverse order", + index)); + } + + target_dim_size = dim_size == 1 ? target_dim_size : dim_size; + } + target_dims[target_rank - index - 1] = target_dim_size; + } + return phi::make_ddim(target_dims); +} + +template +T** GetDevicePointerArray(const Context& ctx, + const std::vector& indices_v) { + std::vector h_indices_v(indices_v.size()); + for (int i = 0; i < indices_v.size(); ++i) { + h_indices_v[i] = indices_v[i]->data(); + } + auto d_indices_data = paddle::memory::Alloc( + ctx.GetPlace(), + h_indices_v.size() * sizeof(T*), + phi::Stream(reinterpret_cast(ctx.stream()))); + paddle::memory::Copy(ctx.GetPlace(), + d_indices_data->ptr(), + phi::CPUPlace(), + reinterpret_cast(h_indices_v.data()), + h_indices_v.size() * sizeof(T*), + ctx.stream()); + return reinterpret_cast(d_indices_data->ptr()); +} +} // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_index_put_op.py b/python/paddle/fluid/tests/unittests/test_index_put_op.py index 47a07e25e98e11..8ec40c40198125 100644 --- a/python/paddle/fluid/tests/unittests/test_index_put_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_put_op.py @@ -18,7 +18,6 @@ import numpy as np import paddle -from paddle import _C_ops from paddle.fluid import Program @@ -31,8 +30,8 @@ def compute_index_put_ref(x_np, indices_np, value_np, accumulate=False): return x_np -def raw_index_put(x, indices, value): - return _C_ops.index_put(x, indices, value) +def raw_index_put(x, indices, value, accummulate): + return paddle.index_put(x, indices, value, accummulate) def has_duplicate_index(indices, shapes): @@ -98,6 +97,8 @@ def init_dtype_type(self): def setPlace(self): self.place = ['cpu'] + if self.dtype_np is np.float16: + self.place = [] if paddle.is_compiled_with_cuda(): self.place.append('gpu') @@ -363,6 +364,209 @@ def init_dtype_type(self): self.accumulate = True +class TestIndexPutAPI17(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI18(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int32 + 
self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int32 + self.accumulate = True + + +class TestIndexPutAPI19(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float32 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float32 + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI20(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float32 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float32 + self.index_type_pd = paddle.int32 + self.accumulate = True + + +class TestIndexPutAPI21(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float16 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float16 + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI22(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float16 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float16 + self.index_type_pd = paddle.int32 + self.accumulate = True + + +class TestIndexPutAPI23(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.int32 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.int32 + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI24(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.int32 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.int32 + self.index_type_pd = paddle.int32 + self.accumulate = True + + +class TestIndexPutAPI25(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.int64 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.int64 + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI26(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.int64 + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.int64 + self.index_type_pd = paddle.int32 + self.accumulate = True + + +class TestIndexPutAPI27(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.bool_ + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.bool + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI28(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.bool_ + self.index_type_np = np.int32 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.bool + self.index_type_pd = paddle.int32 + self.accumulate = True + + +class TestIndexPutInplaceAPI(unittest.TestCase): + def setUp(self): + self.init_dtype_type() + 
self.setPlace() + self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) + self.value_np = np.random.random(self.value_shape).astype(self.dtype_np) + self.indices_np = gen_indices_np( + self.x_shape, self.indices_shapes, self.index_type_np + ) + + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = False + + def setPlace(self): + self.place = ['cpu'] + if paddle.is_compiled_with_cuda(): + self.place.append('gpu') + + def test_dygraph_forward(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd) + self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd) + self.indices_pd = [ + paddle.to_tensor(indice, dtype=self.index_type_pd) + for indice in self.indices_np + ] + self.indices_pd = tuple(self.indices_pd) + ref_res = compute_index_put_ref( + self.x_np, self.indices_np, self.value_np, self.accumulate + ) + x_pd_bk = self.x_pd.clone() + pd_res = paddle.index_put_( + x_pd_bk, self.indices_pd, self.value_pd, self.accumulate + ) + np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) + np.testing.assert_allclose(ref_res, x_pd_bk.numpy(), atol=1e-7) + + +class TestIndexPutInplaceAPI1(TestIndexPutInplaceAPI): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int64 + self.x_shape = (100, 110) + self.indices_shapes = [(21,), (21,)] + self.value_shape = (21,) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int64 + self.accumulate = True + + class TestIndexPutAPIBackward(unittest.TestCase): def test_backward(self): paddle.disable_static() @@ -486,64 +690,5 @@ def test_backwardBroadCastValue(self): ) -class TestIndexPutInplaceAPI(unittest.TestCase): - def setUp(self): - self.init_dtype_type() - self.setPlace() - self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) - self.value_np = np.random.random(self.value_shape).astype(self.dtype_np) - self.indices_np = gen_indices_np( - self.x_shape, self.indices_shapes, self.index_type_np - ) - - def init_dtype_type(self): - self.dtype_np = np.float64 - self.index_type_np = np.int64 - self.x_shape = (100, 110) - self.indices_shapes = [(21,), (21,)] - self.value_shape = (21,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 - self.accumulate = False - - def setPlace(self): - self.place = ['cpu'] - if paddle.is_compiled_with_cuda(): - self.place.append('gpu') - - def test_dygraph_forward(self): - paddle.disable_static() - for place in self.place: - paddle.device.set_device(place) - self.x_pd = paddle.to_tensor(self.x_np, dtype=self.dtype_pd) - self.value_pd = paddle.to_tensor(self.value_np, dtype=self.dtype_pd) - self.indices_pd = [ - paddle.to_tensor(indice, dtype=self.index_type_pd) - for indice in self.indices_np - ] - self.indices_pd = tuple(self.indices_pd) - ref_res = compute_index_put_ref( - self.x_np, self.indices_np, self.value_np, self.accumulate - ) - x_pd_bk = self.x_pd.clone() - pd_res = paddle.index_put_( - x_pd_bk, self.indices_pd, self.value_pd, self.accumulate - ) - np.testing.assert_allclose(ref_res, pd_res.numpy(), atol=1e-7) - np.testing.assert_allclose(ref_res, x_pd_bk.numpy(), atol=1e-7) - - -class TestIndexPutInplaceAPI1(TestIndexPutInplaceAPI): - def init_dtype_type(self): - self.dtype_np = np.float64 - 
self.index_type_np = np.int64 - self.x_shape = (100, 110) - self.indices_shapes = [(21,), (21,)] - self.value_shape = (21,) - self.dtype_pd = paddle.float64 - self.index_type_pd = paddle.int64 - self.accumulate = True - - if __name__ == '__main__': unittest.main() From 5f77bb58498ab49e5f37861288d6611cf3bdc17a Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Tue, 18 Apr 2023 05:02:21 +0000 Subject: [PATCH 10/24] add test case when index tensor in indices is int32 when indices.size less than x.dims --- paddle/phi/kernels/index_put_utils.h | 13 ++++------ .../tests/unittests/test_index_put_op.py | 24 +++++++++++++++++++ 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/index_put_utils.h b/paddle/phi/kernels/index_put_utils.h index aef3e4a60afab7..89ac5ab4e0d485 100644 --- a/paddle/phi/kernels/index_put_utils.h +++ b/paddle/phi/kernels/index_put_utils.h @@ -65,22 +65,19 @@ static std::vector DealWithBoolIndices( std::vector res(indices_v.begin(), indices_v.end()); bool contains_bool_tensor = false; for (size_t i = 0; i < indices_v.size(); ++i) { - PADDLE_ENFORCE( - (indices_v[i]->dtype() == phi::DataType::INT64) || - (indices_v[i]->dtype() == phi::DataType::INT32) || - (indices_v[i]->dtype() == phi::DataType::BOOL), - phi::errors::InvalidArgument( - "indices contains bool tensor and int32/int64 tensor at the same " - "time")); if (indices_v[i]->dtype() == phi::DataType::BOOL) { contains_bool_tensor = true; - } else { + } else if ((indices_v[i]->dtype() == phi::DataType::INT64) || + (indices_v[i]->dtype() == phi::DataType::INT32)) { PADDLE_ENFORCE_EQ( contains_bool_tensor, false, phi::errors::InvalidArgument( "indices contains bool tensor and int32/int64 tensor at the same " "time")); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "data type of tensor in indices must be int32, int64 or bool")); } } diff --git a/python/paddle/fluid/tests/unittests/test_index_put_op.py b/python/paddle/fluid/tests/unittests/test_index_put_op.py index 8ec40c40198125..c311e51206e57c 100644 --- a/python/paddle/fluid/tests/unittests/test_index_put_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_put_op.py @@ -508,6 +508,30 @@ def init_dtype_type(self): self.accumulate = True +class TestIndexPutAPI29(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int32 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (16, 16, 56) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int32 + self.accumulate = False + + +class TestIndexPutAPI30(TestIndexPutAPIBase): + def init_dtype_type(self): + self.dtype_np = np.float64 + self.index_type_np = np.int32 + self.x_shape = (110, 42, 56, 56) + self.indices_shapes = ((16, 16), (16, 16), (1, 16)) + self.value_shape = (16, 16, 56) + self.dtype_pd = paddle.float64 + self.index_type_pd = paddle.int32 + self.accumulate = True + + class TestIndexPutInplaceAPI(unittest.TestCase): def setUp(self): self.init_dtype_type() From 6267d328f3e94c54259b9a274604647166d6303c Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Tue, 18 Apr 2023 05:12:40 +0000 Subject: [PATCH 11/24] add index_put api backward in cpu place --- .../tests/unittests/test_index_put_op.py | 222 ++++++++++-------- 1 file changed, 120 insertions(+), 102 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_index_put_op.py b/python/paddle/fluid/tests/unittests/test_index_put_op.py index c311e51206e57c..2c2e6892ce39fe 100644 --- 
a/python/paddle/fluid/tests/unittests/test_index_put_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_put_op.py @@ -592,126 +592,144 @@ def init_dtype_type(self): class TestIndexPutAPIBackward(unittest.TestCase): + def setUp(self): + self.setPlace() + + def setPlace(self): + self.place = ['cpu'] + if paddle.is_compiled_with_cuda(): + self.place.append('gpu') + def test_backward(self): paddle.disable_static() - value = paddle.ones(shape=[4], dtype=paddle.float64) - x = paddle.ones(shape=[16, 21], dtype=paddle.float64) - ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) - ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) - value.stop_gradient = False - x.stop_gradient = False - out = paddle.index_put(x, (ix1, ix2), value, False) - - dx, dvalue = paddle.grad( - outputs=[out], - inputs=[x, value], - create_graph=False, - retain_graph=True, - ) - ref_dx = np.ones(shape=[16, 21], dtype=np.float64) - ref_dx[ix1, ix2] = 0 - - np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) - np.testing.assert_allclose( - np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float64), - dvalue.numpy(), - atol=1e-7, - ) + for place in self.place: + paddle.device.set_device(place) + value = paddle.ones(shape=[4], dtype=paddle.float64) + x = paddle.ones(shape=[16, 21], dtype=paddle.float64) + ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) + ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) + value.stop_gradient = False + x.stop_gradient = False + out = paddle.index_put(x, (ix1, ix2), value, False) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + ref_dx[ix1, ix2] = 0 + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) - out = paddle.index_put(x, (ix1, ix2), value, True) + out = paddle.index_put(x, (ix1, ix2), value, True) - dx, dvalue = paddle.grad( - outputs=[out], - inputs=[x, value], - create_graph=False, - retain_graph=True, - ) - ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) - np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) - np.testing.assert_allclose( - np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float64), - dvalue.numpy(), - atol=1e-7, - ) + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) def test_backwardScalarVal(self): paddle.disable_static() - value = paddle.ones(shape=[1], dtype=paddle.float64) - x = paddle.ones(shape=[16, 21], dtype=paddle.float64) - ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) - ix2 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) - value.stop_gradient = False - x.stop_gradient = False - out = paddle.index_put(x, (ix1, ix2), value, False) - - dx, dvalue = paddle.grad( - outputs=[out], - inputs=[x, value], - create_graph=False, - retain_graph=True, - ) - ref_dx = np.ones(shape=[16, 21], dtype=np.float64) - ref_dx[ix1, ix2] = 0 + for place in self.place: + paddle.device.set_device(place) + value = paddle.ones(shape=[1], dtype=paddle.float64) + x = paddle.ones(shape=[16, 21], dtype=paddle.float64) + ix1 = paddle.to_tensor([0, 1, 2, 3], dtype=paddle.int64) + ix2 = paddle.to_tensor([0, 1, 2, 3], 
dtype=paddle.int64) + value.stop_gradient = False + x.stop_gradient = False + out = paddle.index_put(x, (ix1, ix2), value, False) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + ref_dx[ix1, ix2] = 0 - np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) - np.testing.assert_allclose( - np.array([4.0], dtype=np.float64), dvalue.numpy(), atol=1e-7 - ) + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([4.0], dtype=np.float64), dvalue.numpy(), atol=1e-7 + ) - out = paddle.index_put(x, (ix1, ix2), value, True) + out = paddle.index_put(x, (ix1, ix2), value, True) - dx, dvalue = paddle.grad( - outputs=[out], - inputs=[x, value], - create_graph=False, - retain_graph=True, - ) - ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) - np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) - np.testing.assert_allclose( - np.array([4.0], dtype=np.float64), dvalue.numpy(), atol=1e-7 - ) + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([4.0], dtype=np.float64), dvalue.numpy(), atol=1e-7 + ) def test_backwardBroadCastValue(self): paddle.disable_static() - value = paddle.ones(shape=[2], dtype=paddle.float64) - x = paddle.ones(shape=[16, 21], dtype=paddle.float64) - ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) - ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) - value.stop_gradient = False - x.stop_gradient = False - out = paddle.index_put(x, (ix1, ix2), value, False) - - dx, dvalue = paddle.grad( - outputs=[out], - inputs=[x, value], - create_graph=False, - retain_graph=True, - ) - ref_dx = np.ones(shape=[16, 21], dtype=np.float64) - ref_dx[ix1, ix2] = 0 - - np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) - np.testing.assert_allclose( - np.array([2.0, 2.0], dtype=np.float64), dvalue.numpy(), atol=1e-7 - ) + for place in self.place: + paddle.device.set_device(place) + value = paddle.ones(shape=[2], dtype=paddle.float64) + x = paddle.ones(shape=[16, 21], dtype=paddle.float64) + ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + value.stop_gradient = False + x.stop_gradient = False + out = paddle.index_put(x, (ix1, ix2), value, False) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + ref_dx[ix1, ix2] = 0 + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([2.0, 2.0], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) - out = paddle.index_put(x, (ix1, ix2), value, True) + out = paddle.index_put(x, (ix1, ix2), value, True) - dx, dvalue = paddle.grad( - outputs=[out], - inputs=[x, value], - create_graph=False, - retain_graph=True, - ) - ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) - np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) - np.testing.assert_allclose( - np.array([2.0, 2.0], dtype=np.float64), dvalue.numpy(), atol=1e-7 - ) + 
np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([2.0, 2.0], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) if __name__ == '__main__': From fdd04363b291b7ab87dd68ab60b811bcf825315c Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Tue, 18 Apr 2023 09:12:09 +0000 Subject: [PATCH 12/24] add backward test case --- .../tests/unittests/test_index_put_op.py | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_index_put_op.py b/python/paddle/fluid/tests/unittests/test_index_put_op.py index 2c2e6892ce39fe..39d342517db9d5 100644 --- a/python/paddle/fluid/tests/unittests/test_index_put_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_put_op.py @@ -731,6 +731,96 @@ def test_backwardBroadCastValue(self): atol=1e-7, ) + def test_backwardBroadCastValue1(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + value = paddle.ones(shape=[1, 2], dtype=paddle.float64) + x = paddle.ones(shape=[16, 21], dtype=paddle.float64) + ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + value.stop_gradient = False + x.stop_gradient = False + out = paddle.index_put(x, (ix1, ix2), value, False) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + ref_dx[ix1, ix2] = 0 + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([[2.0, 2.0]], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) + + out = paddle.index_put(x, (ix1, ix2), value, True) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([[2.0, 2.0]], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) + + def test_backwardBroadCastValue2(self): + paddle.disable_static() + for place in self.place: + paddle.device.set_device(place) + value = paddle.ones(shape=[2, 1], dtype=paddle.float64) + x = paddle.ones(shape=[16, 21], dtype=paddle.float64) + ix1 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + ix2 = paddle.to_tensor([[0, 1], [2, 3]], dtype=paddle.int64) + value.stop_gradient = False + x.stop_gradient = False + out = paddle.index_put(x, (ix1, ix2), value, False) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + ref_dx[ix1, ix2] = 0 + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([[2.0], [2.0]], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) + + out = paddle.index_put(x, (ix1, ix2), value, True) + + dx, dvalue = paddle.grad( + outputs=[out], + inputs=[x, value], + create_graph=False, + retain_graph=True, + ) + ref_dx = np.ones(shape=[16, 21], dtype=np.float64) + + np.testing.assert_allclose(ref_dx, dx.numpy(), atol=1e-7) + np.testing.assert_allclose( + np.array([[2.0], [2.0]], dtype=np.float64), + dvalue.numpy(), + atol=1e-7, + ) + if __name__ == '__main__': unittest.main() From 7b71a3a6dcd9463a48b77f1b590c3f63a7c8f7c5 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Fri, 28 Apr 2023 10:01:27 +0000 Subject: [PATCH 13/24] fix take in init.py bug --- 
python/paddle/tensor/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 5529f113d46501..db4d16b82c9864 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -533,8 +533,9 @@ 'heaviside', 'index_add', "index_add_", - "index_put", - "index_put_" 'take', + 'index_put', + 'index_put_', + 'take', 'bucketize', 'sgn', 'frexp', From 48a03c661cb531babec4864bdd8024fdb059ccb3 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Sat, 6 May 2023 07:52:24 +0000 Subject: [PATCH 14/24] refactor code according to review result --- .../phi/kernels/cpu/index_put_grad_kernel.cc | 148 ++++++------------ paddle/phi/kernels/cpu/index_put_kernel.cc | 93 ++++------- .../phi/kernels/{ => funcs}/index_put_utils.h | 18 +-- .../phi/kernels/gpu/index_put_grad_kernel.cu | 9 +- paddle/phi/kernels/gpu/index_put_kernel.cu | 13 +- paddle/phi/kernels/index_put_grad_kernel.h | 9 -- paddle/phi/kernels/index_put_kernel.h | 8 - .../tests/unittests/test_index_put_op.py | 2 +- python/paddle/tensor/manipulation.py | 3 - 9 files changed, 95 insertions(+), 208 deletions(-) rename paddle/phi/kernels/{ => funcs}/index_put_utils.h (93%) diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc index 5c42db4a6d60a5..41d6ac21c3dc89 100644 --- a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc @@ -14,15 +14,11 @@ #include "paddle/phi/kernels/index_put_grad_kernel.h" #include -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/core/utils/array.h" #include "paddle/phi/kernels/cast_kernel.h" -#include "paddle/phi/kernels/expand_kernel.h" -#include "paddle/phi/kernels/index_put_utils.h" +#include "paddle/phi/kernels/funcs/index_put_utils.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" -#include "paddle/phi/kernels/reshape_kernel.h" + namespace phi { template @@ -44,11 +40,11 @@ phi::DenseTensor GetRangeTensor(const Context& dev_ctx, return res; } -template +template void set_zero_kernel(const int64_t N, const int64_t** indices, - phi::Array stride, - phi::Array shape, + const phi::DDim& stride, + const phi::DDim& shape, T* out) { #ifdef PADDLE_WITH_MKLML #pragma omp parallel for @@ -57,8 +53,8 @@ void set_zero_kernel(const int64_t N, int64_t cur_ix = 0; int64_t offset = 0; - for (size_t i = 0; i < Rank; ++i) { - cur_ix = (int64_t(*(indices[i] + idx))); + for (int i = 0; i < shape.size(); ++i) { + cur_ix = (static_cast(*(indices[i] + idx))); if (cur_ix < 0) { cur_ix += shape[i]; } @@ -68,12 +64,12 @@ void set_zero_kernel(const int64_t N, } } -template +template void index_put_grad_kernel(const int64_t N, const T* out_grad, const int64_t** indices, - phi::Array stride, - phi::Array shape, + const phi::DDim& stride, + const phi::DDim& shape, T* value_grad) { #ifdef PADDLE_WITH_MKLML #pragma omp parallel for @@ -82,8 +78,8 @@ void index_put_grad_kernel(const int64_t N, int64_t cur_ix = 0; int64_t offset = 0; - for (size_t i = 0; i < Rank; ++i) { - cur_ix = (int64_t(*(indices[i] + idx))); + for (int i = 0; i < shape.size(); ++i) { + cur_ix = (static_cast(*(indices[i] + idx))); if (cur_ix < 0) { cur_ix += shape[i]; } @@ -93,7 +89,7 @@ void index_put_grad_kernel(const int64_t N, } } -template +template void LaunchIndexPutGradKernel(const Context& dev_ctx, const std::vector& indices_v, 
const DenseTensor& out_grad, @@ -109,20 +105,12 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, const int64_t numel = indices_v[0]->numel(); auto x_grad_stride = phi::stride(x_grad_dims); - phi::Array stride_a; - phi::Array shape_a; - - for (size_t idx = 0; idx < Rank; ++idx) { - stride_a[idx] = x_grad_stride[idx]; - shape_a[idx] = x_grad_dims[idx]; - } - - const int64_t* pd_indices[Rank]; - for (size_t i = 0; i < Rank; ++i) { + const int64_t* pd_indices[7]; + for (size_t i = 0; i < indices_v.size(); ++i) { pd_indices[i] = indices_v[i]->data(); } - set_zero_kernel( - numel, pd_indices, stride_a, shape_a, x_grad_data); + set_zero_kernel( + numel, pd_indices, x_grad_stride, x_grad_dims, x_grad_data); } } if (value_grad) { @@ -137,24 +125,16 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, const int64_t numel = indices_v[0]->numel(); auto out_grad_stride = phi::stride(out_grad_dims); - phi::Array stride_a; - phi::Array shape_a; - - for (size_t idx = 0; idx < Rank; ++idx) { - stride_a[idx] = out_grad_stride[idx]; - shape_a[idx] = out_grad_dims[idx]; - } - - const int64_t* pd_indices[Rank]; - for (size_t i = 0; i < Rank; ++i) { + const int64_t* pd_indices[7]; + for (size_t i = 0; i < indices_v.size(); ++i) { pd_indices[i] = indices_v[i]->data(); } - index_put_grad_kernel(numel, - out_grad_data, - pd_indices, - stride_a, - shape_a, - tmp_value_grad_data); + index_put_grad_kernel(numel, + out_grad_data, + pd_indices, + out_grad_stride, + out_grad_dims, + tmp_value_grad_data); std::vector v_dims(tmp_value_grad.dims().size()); std::iota(v_dims.begin(), v_dims.end(), 0); @@ -173,20 +153,16 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, const int64_t numel = indices_v[0]->numel(); auto out_grad_stride = phi::stride(out_grad_dims); - phi::Array stride_a; - phi::Array shape_a; - - for (size_t idx = 0; idx < Rank; ++idx) { - stride_a[idx] = out_grad_stride[idx]; - shape_a[idx] = out_grad_dims[idx]; - } - - const int64_t* pd_indices[Rank]; - for (size_t i = 0; i < Rank; ++i) { + const int64_t* pd_indices[7]; + for (size_t i = 0; i < indices_v.size(); ++i) { pd_indices[i] = indices_v[i]->data(); } - index_put_grad_kernel( - numel, out_grad_data, pd_indices, stride_a, shape_a, value_grad_data); + index_put_grad_kernel(numel, + out_grad_data, + pd_indices, + out_grad_stride, + out_grad_dims, + value_grad_data); } else { DenseTensor tmp_value_grad(value_grad->dtype()); tmp_value_grad.Resize(indices_v[0]->dims()); @@ -198,24 +174,16 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, const int64_t numel = indices_v[0]->numel(); auto out_grad_stride = phi::stride(out_grad_dims); - phi::Array stride_a; - phi::Array shape_a; - - for (size_t idx = 0; idx < Rank; ++idx) { - stride_a[idx] = out_grad_stride[idx]; - shape_a[idx] = out_grad_dims[idx]; - } - - const int64_t* pd_indices[Rank]; - for (size_t i = 0; i < Rank; ++i) { + const int64_t* pd_indices[7]; + for (size_t i = 0; i < indices_v.size(); ++i) { pd_indices[i] = indices_v[i]->data(); } - index_put_grad_kernel(numel, - out_grad_data, - pd_indices, - stride_a, - shape_a, - tmp_value_grad_data); + index_put_grad_kernel(numel, + out_grad_data, + pd_indices, + out_grad_stride, + out_grad_dims, + tmp_value_grad_data); std::vector after_dims = phi::vectorize(tmp_value_grad.dims()); std::vector before_dims = phi::vectorize(value_grad->dims()); @@ -290,6 +258,7 @@ void IndexPutGradKernel(const Context& dev_ctx, std::vector res_dim_v(phi::vectorize(bd_dim)); std::vector res_indices_v(x.dims().size(), nullptr); std::vector 
tmp_res_indices_v; + std::vector range_tensor_v; if (int_indices_v.size() < total_dims) { std::vector tmp_x_dims = phi::vectorize(x.dims()); @@ -359,37 +328,8 @@ void IndexPutGradKernel(const Context& dev_ctx, } } - switch (total_dims) { - case 1: - LaunchIndexPutGradKernel( - dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); - break; - case 2: - LaunchIndexPutGradKernel( - dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); - break; - case 3: - LaunchIndexPutGradKernel( - dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); - break; - case 4: - LaunchIndexPutGradKernel( - dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); - break; - case 5: - LaunchIndexPutGradKernel( - dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); - break; - case 6: - LaunchIndexPutGradKernel( - dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "dims of input tensor should be less than 7, But received" - "%d", - x.dims().size())); - } + LaunchIndexPutGradKernel( + dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); } } // namespace phi diff --git a/paddle/phi/kernels/cpu/index_put_kernel.cc b/paddle/phi/kernels/cpu/index_put_kernel.cc index 3b530ce057f98a..3374805ac28a04 100644 --- a/paddle/phi/kernels/cpu/index_put_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_kernel.cc @@ -15,10 +15,8 @@ #include "paddle/phi/kernels/index_put_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/core/utils/array.h" #include "paddle/phi/kernels/cast_kernel.h" -#include "paddle/phi/kernels/expand_kernel.h" -#include "paddle/phi/kernels/index_put_utils.h" +#include "paddle/phi/kernels/funcs/index_put_utils.h" namespace phi { @@ -41,13 +39,13 @@ phi::DenseTensor GetRangeTensor(const Context& dev_ctx, return res; } -template +template void index_put_kernel(const int64_t N, const T* x, const T* vals, const int64_t** indices, - phi::Array stride, - phi::Array shape, + const phi::DDim& stride, + const phi::DDim& shape, int64_t isSingleValTensor, bool accumulate, T* out) { @@ -58,8 +56,8 @@ void index_put_kernel(const int64_t N, int64_t cur_ix = 0; int64_t offset = 0; - for (size_t i = 0; i < Rank; ++i) { - cur_ix = (int64_t(*(indices[i] + idx))); + for (int i = 0; i < shape.size(); ++i) { + cur_ix = (static_cast(*(indices[i] + idx))); if (cur_ix < 0) { cur_ix += shape[i]; } @@ -74,7 +72,7 @@ void index_put_kernel(const int64_t N, } } -template +template void LaunchIndexPutKernel(const Context& dev_ctx, const DenseTensor& x, const std::vector& indices_v, @@ -94,30 +92,22 @@ void LaunchIndexPutKernel(const Context& dev_ctx, const int64_t numel = indices_v[0]->numel(); auto x_stride = phi::stride(x_dims); - phi::Array stride_a; - phi::Array shape_a; - - for (size_t idx = 0; idx < Rank; ++idx) { - stride_a[idx] = x_stride[idx]; - shape_a[idx] = x_dims[idx]; - } - int64_t isSingleValTensor = (value.numel() == 1) ? 
0 : INT64_MAX; - const int64_t* pd_indices[Rank]; - for (size_t i = 0; i < Rank; ++i) { + const int64_t* pd_indices[7]; + for (size_t i = 0; i < indices_v.size(); ++i) { pd_indices[i] = indices_v[i]->data(); } - index_put_kernel(numel, - x_data, - val_data, - pd_indices, - stride_a, - shape_a, - isSingleValTensor, - accumulate, - out_data); + index_put_kernel(numel, + x_data, + val_data, + pd_indices, + x_stride, + x_dims, + isSingleValTensor, + accumulate, + out_data); } template @@ -133,10 +123,20 @@ void IndexPutKernel(const Context& dev_ctx, phi::errors::InvalidArgument( "The data type of tensor in indices must be same to the data type " "of tensor x.")); + PADDLE_ENFORCE_EQ(indices_v.empty(), + false, + phi::errors::InvalidArgument("Indices cannot be empty.")); + + const size_t total_dims = x.dims().size(); + PADDLE_ENFORCE_LE(total_dims, + 6, + phi::errors::InvalidArgument( + "Dims of input tensor should be less than 7.")); + std::vector tmp_args; std::vector int_indices_v = DealWithBoolIndices(dev_ctx, indices_v, &tmp_args); - const size_t total_dims = x.dims().size(); + auto bd_dim = BroadCastTensorsDims(int_indices_v); std::vector res_dim_v(phi::vectorize(bd_dim)); @@ -181,7 +181,7 @@ void IndexPutKernel(const Context& dev_ctx, for (size_t i = 0; i < res_indices_v.size(); ++i) { res_indices_v[i] = &tmp_res_indices_v[i]; } - // value至少需要满足与已有的indices为可broadcast_to关系 + if (value.numel() != 1) { tmp_value_v.emplace_back(DenseTensor(value.dtype()).Resize(res_dim)); ExpandKernel(dev_ctx, @@ -234,37 +234,8 @@ void IndexPutKernel(const Context& dev_ctx, } } - switch (total_dims) { - case 1: - LaunchIndexPutKernel( - dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); - break; - case 2: - LaunchIndexPutKernel( - dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); - break; - case 3: - LaunchIndexPutKernel( - dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); - break; - case 4: - LaunchIndexPutKernel( - dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); - break; - case 5: - LaunchIndexPutKernel( - dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); - break; - case 6: - LaunchIndexPutKernel( - dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "dims of input tensor should be less than 7, But received" - "%d", - x.dims().size())); - } + LaunchIndexPutKernel( + dev_ctx, x, res_indices_v, *ptr_value, accumulate, out); } } // namespace phi diff --git a/paddle/phi/kernels/index_put_utils.h b/paddle/phi/kernels/funcs/index_put_utils.h similarity index 93% rename from paddle/phi/kernels/index_put_utils.h rename to paddle/phi/kernels/funcs/index_put_utils.h index 89ac5ab4e0d485..79ac9cdc6be615 100644 --- a/paddle/phi/kernels/index_put_utils.h +++ b/paddle/phi/kernels/funcs/index_put_utils.h @@ -15,9 +15,8 @@ #pragma once #include -#include "paddle/fluid/memory/malloc.h" -#include "paddle/fluid/memory/memcpy.h" #include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/utils/array.h" @@ -166,16 +165,17 @@ T** GetDevicePointerArray(const Context& ctx, for (int i = 0; i < indices_v.size(); ++i) { h_indices_v[i] = indices_v[i]->data(); } - auto d_indices_data = paddle::memory::Alloc( + auto d_indices_data = phi::memory_utils::Alloc( ctx.GetPlace(), h_indices_v.size() * sizeof(T*), phi::Stream(reinterpret_cast(ctx.stream()))); - paddle::memory::Copy(ctx.GetPlace(), - 
d_indices_data->ptr(), - phi::CPUPlace(), - reinterpret_cast(h_indices_v.data()), - h_indices_v.size() * sizeof(T*), - ctx.stream()); + phi::memory_utils::Copy(ctx.GetPlace(), + d_indices_data->ptr(), + phi::CPUPlace(), + reinterpret_cast(h_indices_v.data()), + h_indices_v.size() * sizeof(T*), + ctx.stream()); return reinterpret_cast(d_indices_data->ptr()); } + } // namespace phi diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu index 84bb847adb5abb..8d81279655e2b0 100644 --- a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -17,12 +17,9 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/core/utils/array.h" #include "paddle/phi/kernels/cast_kernel.h" -#include "paddle/phi/kernels/expand_kernel.h" -#include "paddle/phi/kernels/index_put_utils.h" +#include "paddle/phi/kernels/funcs/index_put_utils.h" #include "paddle/phi/kernels/reduce_sum_kernel.h" -#include "paddle/phi/kernels/reshape_kernel.h" namespace phi { @@ -65,7 +62,7 @@ __global__ void set_zero_cuda_kernel(const int64_t N, } int64_t offset = 0; for (int i = 0; i < Rank; ++i) { - cur_ix = (int64_t(*(indices[i] + idx))); + cur_ix = (static_cast(*(indices[i] + idx))); if (cur_ix < 0) { cur_ix += shape[i]; } @@ -90,7 +87,7 @@ __global__ void index_put_grad_cuda_kernel(const int64_t N, } int64_t offset = 0; for (int i = 0; i < Rank; ++i) { - cur_ix = (int64_t(*(indices[i] + idx))); + cur_ix = (static_cast(*(indices[i] + idx))); if (cur_ix < 0) { cur_ix += shape[i]; } diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu index a8df483b86a249..5fd8e063fd43cd 100644 --- a/paddle/phi/kernels/gpu/index_put_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_kernel.cu @@ -16,12 +16,8 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/core/utils/array.h" #include "paddle/phi/kernels/cast_kernel.h" -#include "paddle/phi/kernels/expand_kernel.h" -#include "paddle/phi/kernels/index_put_utils.h" -#include "paddle/phi/kernels/nonzero_kernel.h" -#include "paddle/phi/kernels/split_kernel.h" +#include "paddle/phi/kernels/funcs/index_put_utils.h" namespace phi { template @@ -67,13 +63,13 @@ __global__ void index_put_cuda_kernel(const int64_t N, } int64_t offset = 0; for (int i = 0; i < Rank; ++i) { - cur_ix = (int64_t(*(indices[i] + idx))); + cur_ix = (static_cast(*(indices[i] + idx))); if (cur_ix < 0) { cur_ix += shape[i]; } offset += stride[i] * cur_ix; } - // 能不能加到模板里面去 + if (accumulate) { *(out + offset) += *(vals + (idx & isSingleValTensor)); } else { @@ -139,6 +135,9 @@ void IndexPutKernel(const Context& dev_ctx, phi::errors::InvalidArgument( "The data type of tensor in indices must be same to the data type " "of tensor x.")); + PADDLE_ENFORCE_EQ(indices_v.empty(), + false, + phi::errors::InvalidArgument("Indices cannot be empty.")); std::vector tmp_args; std::vector int_indices_v = DealWithBoolIndices(dev_ctx, indices_v, &tmp_args); diff --git a/paddle/phi/kernels/index_put_grad_kernel.h b/paddle/phi/kernels/index_put_grad_kernel.h index d5313ac10dd5ae..e0d4471a6e9b36 100644 --- a/paddle/phi/kernels/index_put_grad_kernel.h +++ b/paddle/phi/kernels/index_put_grad_kernel.h @@ -15,16 +15,8 @@ #pragma once #include -#include 
"paddle/fluid/memory/malloc.h" -#include "paddle/fluid/memory/memcpy.h" -#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/utils/array.h" -#include "paddle/phi/kernels/expand_kernel.h" -#include "paddle/phi/kernels/nonzero_kernel.h" -#include "paddle/phi/kernels/reshape_kernel.h" -#include "paddle/phi/kernels/split_kernel.h" namespace phi { template @@ -36,5 +28,4 @@ void IndexPutGradKernel(const Context& dev_ctx, bool accumulate, DenseTensor* x_grad, DenseTensor* value_grad); - } // namespace phi diff --git a/paddle/phi/kernels/index_put_kernel.h b/paddle/phi/kernels/index_put_kernel.h index 1f5d17fe451e79..04577291faa626 100644 --- a/paddle/phi/kernels/index_put_kernel.h +++ b/paddle/phi/kernels/index_put_kernel.h @@ -15,16 +15,8 @@ #pragma once #include -#include "paddle/fluid/memory/malloc.h" -#include "paddle/fluid/memory/memcpy.h" -#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/utils/array.h" -#include "paddle/phi/kernels/expand_kernel.h" -#include "paddle/phi/kernels/nonzero_kernel.h" -#include "paddle/phi/kernels/reshape_kernel.h" -#include "paddle/phi/kernels/split_kernel.h" namespace phi { template diff --git a/python/paddle/fluid/tests/unittests/test_index_put_op.py b/python/paddle/fluid/tests/unittests/test_index_put_op.py index 39d342517db9d5..5f0257f25535c7 100644 --- a/python/paddle/fluid/tests/unittests/test_index_put_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_put_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 8e0ce3b1199442..46c40e61194ae1 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4840,7 +4840,6 @@ def index_put_(x, indices, value, accumulate=False, name=None): """ - assert len(indices) != 0, "indices can't be empty" return _C_ops.index_put_(x, indices, value, accumulate) @@ -4872,8 +4871,6 @@ def index_put(x, indices, value, accumulate=False, name=None): # [0., 0., 1.], # [0., 1., 0.]]) """ - - assert len(indices) != 0, "indices can't be empty" if in_dygraph_mode(): return _C_ops.index_put(x, indices, value, accumulate) From 9b2d4550a4cabd6454bfcc5392907d163816467e Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Sat, 6 May 2023 07:57:20 +0000 Subject: [PATCH 15/24] alter 2022 to 2023 in copyright declaration --- paddle/phi/kernels/cpu/index_put_kernel.cc | 2 +- paddle/phi/kernels/funcs/index_put_utils.h | 2 +- paddle/phi/kernels/index_put_grad_kernel.h | 2 +- paddle/phi/kernels/index_put_kernel.h | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/cpu/index_put_kernel.cc b/paddle/phi/kernels/cpu/index_put_kernel.cc index 3374805ac28a04..b8b9dad972e9ce 100644 --- a/paddle/phi/kernels/cpu/index_put_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_kernel.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
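[Editor's note] For readers skimming the series, here is a minimal usage sketch of the Python API these patches expose. The values are illustrative only (the authoritative cases live in test_index_put_op.py), and the top-level names paddle.index_put / paddle.index_put_ are assumed from the __init__.py and manipulation.py changes above.

    import paddle

    x = paddle.zeros([3, 3])
    # indices is a tuple/list of integer (or bool) tensors, one per indexed dimension
    indices = (paddle.to_tensor([0, 1, 2]), paddle.to_tensor([2, 1, 0]))
    value = paddle.ones([3], dtype=x.dtype)

    # functional form: returns a tensor with x[indices] set to value, x unchanged
    out = paddle.index_put(x, indices, value)

    # in-place form; accumulate=True adds value at the indexed positions instead of overwriting
    paddle.index_put_(x, indices, value, accumulate=True)
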
diff --git a/paddle/phi/kernels/funcs/index_put_utils.h b/paddle/phi/kernels/funcs/index_put_utils.h index 79ac9cdc6be615..00e79ff2095de7 100644 --- a/paddle/phi/kernels/funcs/index_put_utils.h +++ b/paddle/phi/kernels/funcs/index_put_utils.h @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/paddle/phi/kernels/index_put_grad_kernel.h b/paddle/phi/kernels/index_put_grad_kernel.h index e0d4471a6e9b36..f439573230d99b 100644 --- a/paddle/phi/kernels/index_put_grad_kernel.h +++ b/paddle/phi/kernels/index_put_grad_kernel.h @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/paddle/phi/kernels/index_put_kernel.h b/paddle/phi/kernels/index_put_kernel.h index 04577291faa626..a7d313c252c52d 100644 --- a/paddle/phi/kernels/index_put_kernel.h +++ b/paddle/phi/kernels/index_put_kernel.h @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. From 0c6545a32507e32aa6ae610a292740d906fc34f7 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Sat, 6 May 2023 09:38:45 +0000 Subject: [PATCH 16/24] refactor code to delete some duplicated code --- .../phi/kernels/cpu/index_put_grad_kernel.cc | 139 +++------------ paddle/phi/kernels/cpu/index_put_kernel.cc | 105 +++-------- paddle/phi/kernels/funcs/index_put_utils.h | 115 +++++++++++- .../phi/kernels/gpu/index_put_grad_kernel.cu | 165 ++++-------------- paddle/phi/kernels/gpu/index_put_kernel.cu | 106 +++-------- 5 files changed, 207 insertions(+), 423 deletions(-) diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc index 41d6ac21c3dc89..b5a1da17ebd451 100644 --- a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc @@ -113,6 +113,15 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, numel, pd_indices, x_grad_stride, x_grad_dims, x_grad_data); } } + + auto out_grad_dims = out_grad.dims(); + const int64_t numel = indices_v[0]->numel(); + auto out_grad_stride = phi::stride(out_grad_dims); + + const int64_t* pd_indices[7]; + for (size_t i = 0; i < indices_v.size(); ++i) { + pd_indices[i] = indices_v[i]->data(); + } if (value_grad) { if (value_grad->numel() == 1) { DenseTensor tmp_value_grad(value_grad->dtype()); @@ -121,14 +130,6 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); auto out_grad_data = out_grad.data(); - auto out_grad_dims = out_grad.dims(); - const int64_t numel = indices_v[0]->numel(); - auto out_grad_stride = phi::stride(out_grad_dims); - - const int64_t* pd_indices[7]; - for (size_t i = 0; i < indices_v.size(); ++i) { - pd_indices[i] = indices_v[i]->data(); - } index_put_grad_kernel(numel, out_grad_data, pd_indices, @@ -149,14 +150,6 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, T* value_grad_data = dev_ctx.template Alloc(value_grad); auto 
out_grad_data = out_grad.data(); - auto out_grad_dims = out_grad.dims(); - const int64_t numel = indices_v[0]->numel(); - auto out_grad_stride = phi::stride(out_grad_dims); - - const int64_t* pd_indices[7]; - for (size_t i = 0; i < indices_v.size(); ++i) { - pd_indices[i] = indices_v[i]->data(); - } index_put_grad_kernel(numel, out_grad_data, pd_indices, @@ -170,14 +163,6 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); auto out_grad_data = out_grad.data(); - auto out_grad_dims = out_grad.dims(); - const int64_t numel = indices_v[0]->numel(); - auto out_grad_stride = phi::stride(out_grad_dims); - - const int64_t* pd_indices[7]; - for (size_t i = 0; i < indices_v.size(); ++i) { - pd_indices[i] = indices_v[i]->data(); - } index_put_grad_kernel(numel, out_grad_data, pd_indices, @@ -189,32 +174,9 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, std::vector before_dims = phi::vectorize(value_grad->dims()); std::vector compress_dims; std::vector dims_without_1; - int i = static_cast(after_dims.size()) - 1; - int j = static_cast(before_dims.size()) - 1; - if (i < j) { - PADDLE_THROW(phi::errors::InvalidArgument( - "shape of value can't not be broadcast to shape of x[indices]")); - } - while ((i >= 0) && (j >= 0)) { - if (after_dims[i] == before_dims[j]) { - dims_without_1.push_back(before_dims[j]); - i--; - j--; - continue; - } else if (before_dims[j] == 1) { - compress_dims.push_back(i); - i--; - j--; - } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "shape of value can't not be broadcast to shape of x[indices]")); - } - } - while (i >= 0) { - compress_dims.push_back(i); - i--; - } + CalCompressedDimsWith1AndWithout1( + &after_dims, &before_dims, &compress_dims, &dims_without_1); phi::DenseTensor value_grad_dims_without1(value_grad->dtype()); value_grad_dims_without1.Resize(phi::make_ddim(dims_without_1)); @@ -252,7 +214,6 @@ void IndexPutGradKernel(const Context& dev_ctx, std::vector tmp_args; std::vector int_indices_v = DealWithBoolIndices(dev_ctx, indices_v, &tmp_args); - const size_t total_dims = x.dims().size(); auto bd_dim = BroadCastTensorsDims(int_indices_v); std::vector res_dim_v(phi::vectorize(bd_dim)); @@ -260,74 +221,20 @@ void IndexPutGradKernel(const Context& dev_ctx, std::vector tmp_res_indices_v; std::vector range_tensor_v; - if (int_indices_v.size() < total_dims) { - std::vector tmp_x_dims = phi::vectorize(x.dims()); - int len_bd_dim = bd_dim.size(); - res_dim_v.insert(res_dim_v.end(), - tmp_x_dims.begin() + int_indices_v.size(), - tmp_x_dims.end()); - - std::vector reshaped_indices_v; - for (size_t i = 0; i < int_indices_v.size(); ++i) { - if (int_indices_v[i]->dtype() == phi::DataType::INT32) { - reshaped_indices_v.emplace_back(phi::Cast( - dev_ctx, *int_indices_v[i], phi::DataType::INT64)); - } else { - reshaped_indices_v.emplace_back(*int_indices_v[i]); - } - } - for (size_t i = len_bd_dim; i < res_dim_v.size(); ++i) { - reshaped_indices_v.emplace_back(GetRangeTensor( - dev_ctx, res_dim_v[i], phi::DataType::INT64)); - } - phi::DDim res_dim = phi::make_ddim(res_dim_v); - - for (size_t i = 0; i < reshaped_indices_v.size(); ++i) { - tmp_res_indices_v.emplace_back( - GetReshapeAndExpandTensor( - dev_ctx, - reshaped_indices_v[i], - res_dim, - bd_dim, - ((i < int_indices_v.size()) - ? 
0 - : i - int_indices_v.size() + len_bd_dim))); - } - for (size_t i = 0; i < res_indices_v.size(); ++i) { - res_indices_v[i] = &tmp_res_indices_v[i]; - } - - } else { - std::vector int_indices_v_tmp; - - for (size_t i = 0; i < int_indices_v.size(); ++i) { - if (int_indices_v[i]->dtype() == phi::DataType::INT32) { - int_indices_v_tmp.emplace_back(phi::Cast( - dev_ctx, *int_indices_v[i], phi::DataType::INT64)); - } else { - int_indices_v_tmp.emplace_back(*int_indices_v[i]); - } - } - - for (size_t i = 0; i < int_indices_v.size(); ++i) { - if (bd_dim != int_indices_v[i]->dims()) { - tmp_res_indices_v.emplace_back( - DenseTensor(phi::DataType::INT64).Resize(bd_dim)); - ExpandKernel( - dev_ctx, - int_indices_v_tmp[i], - IntArray(phi::vectorize(bd_dim)), - &tmp_res_indices_v[i]); - } else { - tmp_res_indices_v.emplace_back(int_indices_v_tmp[i]); - } - } - - for (size_t i = 0; i < res_indices_v.size(); ++i) { - res_indices_v[i] = &tmp_res_indices_v[i]; - } + for (int i = indices_v.size(); i < x.dims().size(); ++i) { + range_tensor_v.emplace_back(GetRangeTensor( + dev_ctx, x.dims()[i], phi::DataType::INT64)); } + DealWithIndices(dev_ctx, + x, + int_indices_v, + &res_indices_v, + &tmp_res_indices_v, + range_tensor_v, + bd_dim, + &res_dim_v); + LaunchIndexPutGradKernel( dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); } diff --git a/paddle/phi/kernels/cpu/index_put_kernel.cc b/paddle/phi/kernels/cpu/index_put_kernel.cc index b8b9dad972e9ce..73cf62b6d6ef21 100644 --- a/paddle/phi/kernels/cpu/index_put_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_kernel.cc @@ -143,95 +143,30 @@ void IndexPutKernel(const Context& dev_ctx, std::vector res_indices_v(x.dims().size(), nullptr); std::vector tmp_res_indices_v; std::vector tmp_value_v; + std::vector range_tensor_v; const DenseTensor* ptr_value = nullptr; - if (int_indices_v.size() < total_dims) { - std::vector tmp_x_dims = phi::vectorize(x.dims()); - int len_bd_dim = bd_dim.size(); - res_dim_v.insert(res_dim_v.end(), - tmp_x_dims.begin() + int_indices_v.size(), - tmp_x_dims.end()); - - std::vector reshaped_indices_v; - for (size_t i = 0; i < int_indices_v.size(); ++i) { - if (int_indices_v[i]->dtype() == phi::DataType::INT32) { - reshaped_indices_v.emplace_back(phi::Cast( - dev_ctx, *int_indices_v[i], phi::DataType::INT64)); - } else { - reshaped_indices_v.emplace_back(*int_indices_v[i]); - } - } - for (size_t i = len_bd_dim; i < res_dim_v.size(); ++i) { - reshaped_indices_v.emplace_back(GetRangeTensor( - dev_ctx, res_dim_v[i], phi::DataType::INT64)); - } - phi::DDim res_dim = phi::make_ddim(res_dim_v); - - for (size_t i = 0; i < reshaped_indices_v.size(); ++i) { - tmp_res_indices_v.emplace_back( - GetReshapeAndExpandTensor( - dev_ctx, - reshaped_indices_v[i], - res_dim, - bd_dim, - ((i < int_indices_v.size()) - ? 
0 - : i - int_indices_v.size() + len_bd_dim))); - } - for (size_t i = 0; i < res_indices_v.size(); ++i) { - res_indices_v[i] = &tmp_res_indices_v[i]; - } + for (int i = indices_v.size(); i < x.dims().size(); ++i) { + range_tensor_v.emplace_back(GetRangeTensor( + dev_ctx, x.dims()[i], phi::DataType::INT64)); + } - if (value.numel() != 1) { - tmp_value_v.emplace_back(DenseTensor(value.dtype()).Resize(res_dim)); - ExpandKernel(dev_ctx, - value, - IntArray(phi::vectorize(res_dim)), - &tmp_value_v[0]); - ptr_value = &tmp_value_v[0]; - } else { - ptr_value = &value; - } + DealWithIndices(dev_ctx, + x, + int_indices_v, + &res_indices_v, + &tmp_res_indices_v, + range_tensor_v, + bd_dim, + &res_dim_v); + if (value.numel() != 1) { + tmp_value_v.emplace_back( + DenseTensor(value.dtype()).Resize(phi::make_ddim(res_dim_v))); + ExpandKernel( + dev_ctx, value, IntArray(res_dim_v), &tmp_value_v[0]); + ptr_value = &tmp_value_v[0]; } else { - std::vector int_indices_v_tmp; - - for (size_t i = 0; i < int_indices_v.size(); ++i) { - if (int_indices_v[i]->dtype() == phi::DataType::INT32) { - int_indices_v_tmp.emplace_back(phi::Cast( - dev_ctx, *int_indices_v[i], phi::DataType::INT64)); - } else { - int_indices_v_tmp.emplace_back(*int_indices_v[i]); - } - } - - for (size_t i = 0; i < int_indices_v.size(); ++i) { - if (bd_dim != int_indices_v[i]->dims()) { - tmp_res_indices_v.emplace_back( - DenseTensor(phi::DataType::INT64).Resize(bd_dim)); - ExpandKernel( - dev_ctx, - int_indices_v_tmp[i], - IntArray(phi::vectorize(bd_dim)), - &tmp_res_indices_v[i]); - } else { - tmp_res_indices_v.emplace_back(int_indices_v_tmp[i]); - } - } - - for (size_t i = 0; i < res_indices_v.size(); ++i) { - res_indices_v[i] = &tmp_res_indices_v[i]; - } - - if (value.numel() != 1) { - tmp_value_v.emplace_back(DenseTensor(value.dtype()).Resize(bd_dim)); - ExpandKernel(dev_ctx, - value, - IntArray(phi::vectorize(bd_dim)), - &tmp_value_v[0]); - ptr_value = &tmp_value_v[0]; - } else { - ptr_value = &value; - } + ptr_value = &value; } LaunchIndexPutKernel( diff --git a/paddle/phi/kernels/funcs/index_put_utils.h b/paddle/phi/kernels/funcs/index_put_utils.h index 00e79ff2095de7..5c832a3b0fb79a 100644 --- a/paddle/phi/kernels/funcs/index_put_utils.h +++ b/paddle/phi/kernels/funcs/index_put_utils.h @@ -20,6 +20,7 @@ #include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/utils/array.h" +#include "paddle/phi/kernels/cast_kernel.h" #include "paddle/phi/kernels/expand_kernel.h" #include "paddle/phi/kernels/nonzero_kernel.h" #include "paddle/phi/kernels/reshape_kernel.h" @@ -97,8 +98,6 @@ static std::vector DealWithBoolIndices( std::vector integer_indices(rank, nullptr); for (int i = 0; i < rank; ++i) { - // here should be - // tmp_indices_v.emplace_back(DenseTensor(phi::DataType::INT64).Resize(phi::make_ddim({nonzero_indices.dims()[0],1}))); tmp_indices_v->emplace_back( DenseTensor(phi::DataType::INT64) .Resize(phi::make_ddim({nonzero_indices.dims()[0]}))); @@ -178,4 +177,116 @@ T** GetDevicePointerArray(const Context& ctx, return reinterpret_cast(d_indices_data->ptr()); } +template +static void DealWithIndices( + const Context& dev_ctx, + const DenseTensor& x, + const std::vector& int_indices_v, + std::vector* res_indices_v, + std::vector* tmp_res_indices_v, + const std::vector& range_tensor_v, + const phi::DDim& bd_dim, + std::vector* res_dim_v) { + size_t total_dims = x.dims().size(); + if (int_indices_v.size() < total_dims) { + std::vector tmp_x_dims = phi::vectorize(x.dims()); + int len_bd_dim = 
bd_dim.size(); + res_dim_v->insert(res_dim_v->end(), + tmp_x_dims.begin() + int_indices_v.size(), + tmp_x_dims.end()); + + std::vector reshaped_indices_v; + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (int_indices_v[i]->dtype() == phi::DataType::INT32) { + reshaped_indices_v.emplace_back(phi::Cast( + dev_ctx, *int_indices_v[i], phi::DataType::INT64)); + } else { + reshaped_indices_v.emplace_back(*int_indices_v[i]); + } + } + reshaped_indices_v.insert( + reshaped_indices_v.end(), range_tensor_v.begin(), range_tensor_v.end()); + + phi::DDim res_dim = phi::make_ddim(*res_dim_v); + + for (size_t i = 0; i < reshaped_indices_v.size(); ++i) { + tmp_res_indices_v->emplace_back( + GetReshapeAndExpandTensor( + dev_ctx, + reshaped_indices_v[i], + res_dim, + bd_dim, + ((i < int_indices_v.size()) + ? 0 + : i - int_indices_v.size() + len_bd_dim))); + } + for (size_t i = 0; i < res_indices_v->size(); ++i) { + (*res_indices_v)[i] = &(*tmp_res_indices_v)[i]; + } + + } else { + std::vector int_indices_v_tmp; + + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (int_indices_v[i]->dtype() == phi::DataType::INT32) { + int_indices_v_tmp.emplace_back(phi::Cast( + dev_ctx, *int_indices_v[i], phi::DataType::INT64)); + } else { + int_indices_v_tmp.emplace_back(*int_indices_v[i]); + } + } + + for (size_t i = 0; i < int_indices_v.size(); ++i) { + if (bd_dim != int_indices_v[i]->dims()) { + tmp_res_indices_v->emplace_back( + DenseTensor(phi::DataType::INT64).Resize(bd_dim)); + ExpandKernel( + dev_ctx, + int_indices_v_tmp[i], + IntArray(phi::vectorize(bd_dim)), + &(*tmp_res_indices_v)[i]); + } else { + tmp_res_indices_v->emplace_back(int_indices_v_tmp[i]); + } + } + + for (size_t i = 0; i < res_indices_v->size(); ++i) { + (*res_indices_v)[i] = &(*tmp_res_indices_v)[i]; + } + } +} + +static void CalCompressedDimsWith1AndWithout1( + std::vector* after_dims, + std::vector* before_dims, + std::vector* compress_dims, + std::vector* dims_without_1) { + int i = static_cast(after_dims->size()) - 1; + int j = static_cast(before_dims->size()) - 1; + if (i < j) { + PADDLE_THROW(phi::errors::InvalidArgument( + "shape of value can't not be broadcast to shape of x[indices]")); + } + + while ((i >= 0) && (j >= 0)) { + if ((*after_dims)[i] == (*before_dims)[j]) { + dims_without_1->push_back((*before_dims)[j]); + i--; + j--; + continue; + } else if ((*before_dims)[j] == 1) { + compress_dims->push_back(i); + i--; + j--; + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "shape of value can't not be broadcast to shape of x[indices]")); + } + } + while (i >= 0) { + compress_dims->push_back(i); + i--; + } +} + } // namespace phi diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu index 8d81279655e2b0..6adc18f425e041 100644 --- a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -133,6 +133,21 @@ void LaunchIndexPutGradCudaKernel( } } + auto out_grad_dims = out_grad.dims(); + const int64_t numel = indices_v[0]->numel(); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); + auto out_grad_stride = phi::stride(out_grad_dims); + + phi::Array stride_a; + phi::Array shape_a; + + for (size_t idx = 0; idx < Rank; ++idx) { + stride_a[idx] = out_grad_stride[idx]; + shape_a[idx] = out_grad_dims[idx]; + } + + auto pd_indices = GetDevicePointerArray(dev_ctx, indices_v); + if (value_grad) { if (value_grad->numel() == 1) { DenseTensor tmp_value_grad(value_grad->dtype()); @@ -141,21 +156,6 @@ 
void LaunchIndexPutGradCudaKernel( T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); auto out_grad_data = out_grad.data(); - auto out_grad_dims = out_grad.dims(); - const int64_t numel = indices_v[0]->numel(); - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); - auto out_grad_stride = phi::stride(out_grad_dims); - - phi::Array stride_a; - phi::Array shape_a; - - for (size_t idx = 0; idx < Rank; ++idx) { - stride_a[idx] = out_grad_stride[idx]; - shape_a[idx] = out_grad_dims[idx]; - } - - auto pd_indices = - GetDevicePointerArray(dev_ctx, indices_v); index_put_grad_cuda_kernel <<(value_grad); auto out_grad_data = out_grad.data(); - auto out_grad_dims = out_grad.dims(); - const int64_t numel = indices_v[0]->numel(); - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); - auto out_grad_stride = phi::stride(out_grad_dims); - - phi::Array stride_a; - phi::Array shape_a; - - for (size_t idx = 0; idx < Rank; ++idx) { - stride_a[idx] = out_grad_stride[idx]; - shape_a[idx] = out_grad_dims[idx]; - } - - auto pd_indices = - GetDevicePointerArray(dev_ctx, indices_v); index_put_grad_cuda_kernel<<(&tmp_value_grad); auto out_grad_data = out_grad.data(); - auto out_grad_dims = out_grad.dims(); - const int64_t numel = indices_v[0]->numel(); - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); - auto out_grad_stride = phi::stride(out_grad_dims); - - phi::Array stride_a; - phi::Array shape_a; - - for (size_t idx = 0; idx < Rank; ++idx) { - stride_a[idx] = out_grad_stride[idx]; - shape_a[idx] = out_grad_dims[idx]; - } - - auto pd_indices = - GetDevicePointerArray(dev_ctx, indices_v); index_put_grad_cuda_kernel << before_dims = phi::vectorize(value_grad->dims()); std::vector compress_dims; std::vector dims_without_1; - int i = static_cast(after_dims.size()) - 1; - int j = static_cast(before_dims.size()) - 1; - if (i < j) { - PADDLE_THROW(phi::errors::InvalidArgument( - "shape of value can't not be broadcast to shape of x[indices]")); - } - while ((i >= 0) && (j >= 0)) { - if (after_dims[i] == before_dims[j]) { - dims_without_1.push_back(before_dims[j]); - i--; - j--; - continue; - } else if (before_dims[j] == 1) { - compress_dims.push_back(i); - i--; - j--; - } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "shape of value can't not be broadcast to shape of x[indices]")); - } - } - while (i >= 0) { - compress_dims.push_back(i); - i--; - } + CalCompressedDimsWith1AndWithout1( + &after_dims, &before_dims, &compress_dims, &dims_without_1); phi::DenseTensor value_grad_dims_without1(value_grad->dtype()); value_grad_dims_without1.Resize(phi::make_ddim(dims_without_1)); @@ -306,74 +253,22 @@ void IndexPutGradKernel(const Context& dev_ctx, std::vector res_dim_v(phi::vectorize(bd_dim)); std::vector res_indices_v(x.dims().size(), nullptr); std::vector tmp_res_indices_v; + std::vector range_tensor_v; - if (int_indices_v.size() < total_dims) { - std::vector tmp_x_dims = phi::vectorize(x.dims()); - int len_bd_dim = bd_dim.size(); - res_dim_v.insert(res_dim_v.end(), - tmp_x_dims.begin() + int_indices_v.size(), - tmp_x_dims.end()); - - std::vector reshaped_indices_v; - for (size_t i = 0; i < int_indices_v.size(); ++i) { - if (int_indices_v[i]->dtype() == phi::DataType::INT32) { - reshaped_indices_v.emplace_back(phi::Cast( - dev_ctx, *int_indices_v[i], phi::DataType::INT64)); - } else { - reshaped_indices_v.emplace_back(*int_indices_v[i]); - } - } - for (size_t i = len_bd_dim; i < res_dim_v.size(); ++i) { - 
reshaped_indices_v.emplace_back(GetRangeCudaTensor( - dev_ctx, res_dim_v[i], phi::DataType::INT64)); - } - phi::DDim res_dim = phi::make_ddim(res_dim_v); - - for (size_t i = 0; i < reshaped_indices_v.size(); ++i) { - tmp_res_indices_v.emplace_back( - GetReshapeAndExpandTensor( - dev_ctx, - reshaped_indices_v[i], - res_dim, - bd_dim, - ((i < int_indices_v.size()) - ? 0 - : i - int_indices_v.size() + len_bd_dim))); - } - for (size_t i = 0; i < res_indices_v.size(); ++i) { - res_indices_v[i] = &tmp_res_indices_v[i]; - } - } else { - std::vector int_indices_v_tmp; - - for (size_t i = 0; i < int_indices_v.size(); ++i) { - if (int_indices_v[i]->dtype() == phi::DataType::INT32) { - int_indices_v_tmp.emplace_back(phi::Cast( - dev_ctx, *int_indices_v[i], phi::DataType::INT64)); - } else { - int_indices_v_tmp.emplace_back(*int_indices_v[i]); - } - } - - for (size_t i = 0; i < int_indices_v.size(); ++i) { - if (bd_dim != int_indices_v[i]->dims()) { - tmp_res_indices_v.emplace_back( - DenseTensor(phi::DataType::INT64).Resize(bd_dim)); - ExpandKernel( - dev_ctx, - int_indices_v_tmp[i], - IntArray(phi::vectorize(bd_dim)), - &tmp_res_indices_v[i]); - } else { - tmp_res_indices_v.emplace_back(int_indices_v_tmp[i]); - } - } - - for (size_t i = 0; i < res_indices_v.size(); ++i) { - res_indices_v[i] = &tmp_res_indices_v[i]; - } + for (int i = indices_v.size(); i < x.dims().size(); ++i) { + range_tensor_v.emplace_back(GetRangeCudaTensor( + dev_ctx, x.dims()[i], phi::DataType::INT64)); } + DealWithIndices(dev_ctx, + x, + int_indices_v, + &res_indices_v, + &tmp_res_indices_v, + range_tensor_v, + bd_dim, + &res_dim_v); + switch (total_dims) { case 1: LaunchIndexPutGradCudaKernel( diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu index 5fd8e063fd43cd..4efe3e5a43d8cb 100644 --- a/paddle/phi/kernels/gpu/index_put_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_kernel.cu @@ -148,95 +148,31 @@ void IndexPutKernel(const Context& dev_ctx, std::vector res_indices_v(x.dims().size(), nullptr); std::vector tmp_res_indices_v; std::vector tmp_value_v; + std::vector range_tensor_v; const DenseTensor* ptr_value = nullptr; - if (int_indices_v.size() < total_dims) { - std::vector tmp_x_dims = phi::vectorize(x.dims()); - int len_bd_dim = bd_dim.size(); - res_dim_v.insert(res_dim_v.end(), - tmp_x_dims.begin() + int_indices_v.size(), - tmp_x_dims.end()); - - std::vector reshaped_indices_v; - for (size_t i = 0; i < int_indices_v.size(); ++i) { - if (int_indices_v[i]->dtype() == phi::DataType::INT32) { - reshaped_indices_v.emplace_back(phi::Cast( - dev_ctx, *int_indices_v[i], phi::DataType::INT64)); - } else { - reshaped_indices_v.emplace_back(*int_indices_v[i]); - } - } - - for (size_t i = len_bd_dim; i < res_dim_v.size(); ++i) { - reshaped_indices_v.emplace_back(GetRangeCudaTensor( - dev_ctx, res_dim_v[i], phi::DataType::INT64)); - } - phi::DDim res_dim = phi::make_ddim(res_dim_v); - - for (size_t i = 0; i < reshaped_indices_v.size(); ++i) { - tmp_res_indices_v.emplace_back( - GetReshapeAndExpandTensor( - dev_ctx, - reshaped_indices_v[i], - res_dim, - bd_dim, - ((i < int_indices_v.size()) - ? 
0 - : i - int_indices_v.size() + len_bd_dim))); - } - for (size_t i = 0; i < res_indices_v.size(); ++i) { - res_indices_v[i] = &tmp_res_indices_v[i]; - } + for (int i = indices_v.size(); i < x.dims().size(); ++i) { + range_tensor_v.emplace_back(GetRangeCudaTensor( + dev_ctx, x.dims()[i], phi::DataType::INT64)); + } - if (value.numel() != 1) { - tmp_value_v.emplace_back(DenseTensor(value.dtype()).Resize(res_dim)); - ExpandKernel(dev_ctx, - value, - IntArray(phi::vectorize(res_dim)), - &tmp_value_v[0]); - ptr_value = &tmp_value_v[0]; - } else { - ptr_value = &value; - } + DealWithIndices(dev_ctx, + x, + int_indices_v, + &res_indices_v, + &tmp_res_indices_v, + range_tensor_v, + bd_dim, + &res_dim_v); + + if (value.numel() != 1) { + tmp_value_v.emplace_back( + DenseTensor(value.dtype()).Resize(phi::make_ddim(res_dim_v))); + ExpandKernel( + dev_ctx, value, IntArray(res_dim_v), &tmp_value_v[0]); + ptr_value = &tmp_value_v[0]; } else { - std::vector int_indices_v_tmp; - - for (size_t i = 0; i < int_indices_v.size(); ++i) { - if (int_indices_v[i]->dtype() == phi::DataType::INT32) { - int_indices_v_tmp.emplace_back(phi::Cast( - dev_ctx, *int_indices_v[i], phi::DataType::INT64)); - } else { - int_indices_v_tmp.emplace_back(*int_indices_v[i]); - } - } - for (size_t i = 0; i < int_indices_v.size(); ++i) { - if (bd_dim != int_indices_v[i]->dims()) { - tmp_res_indices_v.emplace_back( - DenseTensor(phi::DataType::INT64).Resize(bd_dim)); - ExpandKernel( - dev_ctx, - int_indices_v_tmp[i], - IntArray(phi::vectorize(bd_dim)), - &tmp_res_indices_v[i]); - } else { - tmp_res_indices_v.emplace_back(int_indices_v_tmp[i]); - } - } - - for (size_t i = 0; i < res_indices_v.size(); ++i) { - res_indices_v[i] = &tmp_res_indices_v[i]; - } - - if (value.numel() != 1) { - tmp_value_v.emplace_back(DenseTensor(value.dtype()).Resize(bd_dim)); - ExpandKernel(dev_ctx, - value, - IntArray(phi::vectorize(bd_dim)), - &tmp_value_v[0]); - ptr_value = &tmp_value_v[0]; - } else { - ptr_value = &value; - } + ptr_value = &value; } switch (total_dims) { From 894adb134c557f7251e29cf1113d5e7ed984aa56 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Mon, 8 May 2023 06:57:11 +0000 Subject: [PATCH 17/24] replaace reshape with resize for decrease extra memcpy --- paddle/phi/api/yaml/ops.yaml | 1 + paddle/phi/infermeta/multiary.cc | 30 +++---- paddle/phi/infermeta/multiary.h | 12 +-- .../phi/kernels/cpu/index_put_grad_kernel.cc | 82 +++++++------------ paddle/phi/kernels/cpu/index_put_kernel.cc | 67 ++++++--------- paddle/phi/kernels/funcs/index_put_utils.h | 49 +++++++++++ .../phi/kernels/gpu/index_put_grad_kernel.cu | 80 +++++++----------- paddle/phi/kernels/gpu/index_put_kernel.cu | 71 ++++++---------- paddle/phi/kernels/index_put_grad_kernel.h | 1 - paddle/phi/kernels/index_put_kernel.h | 1 - 10 files changed, 181 insertions(+), 213 deletions(-) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 143be86998b549..4f18403d44dca4 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -938,6 +938,7 @@ func : IndexPutInferMeta kernel : func : index_put + data_type : x inplace : (x -> out) backward : index_put_grad diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 36e0dfe2391969..8ea2dc65d9a540 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1962,6 +1962,21 @@ void InterpolateInferMeta( } } +void IndexPutInferMeta(const MetaTensor& x, + const std::vector& indices, + const MetaTensor& value, + bool accumulate, + 
MetaTensor* out) { + auto in_dims = x.dims(); + PADDLE_ENFORCE_LT( + in_dims.size(), + 7, + phi::errors::InvalidArgument( + "The rank of input should be less than 7, but received %d.", + in_dims.size())); + out->share_meta(x); +} + void LambInferMeta(const MetaTensor& param, const MetaTensor& grad, const MetaTensor& learning_rate, @@ -3249,21 +3264,6 @@ void MoeInferMeta(const MetaTensor& x, out->set_layout(x.layout()); } -void IndexPutInferMeta(const MetaTensor& x, - const std::vector& indices, - const MetaTensor& value, - bool accumulate, - MetaTensor* out) { - auto in_dims = x.dims(); - PADDLE_ENFORCE_LT( - in_dims.size(), - 7, - phi::errors::InvalidArgument( - "The rank of input should be less than 7, but received %d.", - in_dims.size())); - out->share_meta(x); -} - void WeightedSampleNeighborsInferMeta(const MetaTensor& row, const MetaTensor& col_ptr, const MetaTensor& edge_weight, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 02bdaf2bd0d5ff..2a924ecbb30a08 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -332,6 +332,12 @@ void InterpolateInferMeta( MetaTensor* output, MetaConfig config = MetaConfig()); +void IndexPutInferMeta(const MetaTensor& x, + const std::vector& indices, + const MetaTensor& value, + bool accumulate, + MetaTensor* out); + void LambInferMeta(const MetaTensor& param, const MetaTensor& grad, const MetaTensor& learning_rate, @@ -615,10 +621,4 @@ void MoeInferMeta(const MetaTensor& x, const std::string& act_type, MetaTensor* out); -void IndexPutInferMeta(const MetaTensor& x, - const std::vector& indices, - const MetaTensor& value, - bool accumulate, - MetaTensor* out); - } // namespace phi diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc index b5a1da17ebd451..ce2771c8c304bb 100644 --- a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc @@ -21,24 +21,9 @@ namespace phi { -template -void range_kernel(int64_t N, T* out) { - for (int64_t idx = 0; idx < N; ++idx) { - out[idx] = idx; - } -} +UNROLL_RANGE_KERNEL_DEFINITION -template -phi::DenseTensor GetRangeTensor(const Context& dev_ctx, - int64_t N, - phi::DataType dtype) { - phi::DenseTensor res(dtype); - res.Resize(phi::make_ddim({N})); - DenseTensor* p_res = &res; - T* out = dev_ctx.template Alloc(p_res); - range_kernel(N, out); - return res; -} +UNROLL_GET_RANGE_TENSOR_DEFINITION template void set_zero_kernel(const int64_t N, @@ -91,41 +76,38 @@ void index_put_grad_kernel(const int64_t N, template void LaunchIndexPutGradKernel(const Context& dev_ctx, - const std::vector& indices_v, + const std::vector& indices, const DenseTensor& out_grad, bool accumulate, DenseTensor* value_grad, DenseTensor* x_grad) { + const int64_t* pd_indices[7]; + for (size_t i = 0; i < indices.size(); ++i) { + pd_indices[i] = indices[i]->data(); + } + if (x_grad) { phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad); if (!accumulate) { T* x_grad_data = x_grad->data(); auto x_grad_dims = x_grad->dims(); - const int64_t numel = indices_v[0]->numel(); + const int64_t numel = indices[0]->numel(); auto x_grad_stride = phi::stride(x_grad_dims); - const int64_t* pd_indices[7]; - for (size_t i = 0; i < indices_v.size(); ++i) { - pd_indices[i] = indices_v[i]->data(); - } set_zero_kernel( numel, pd_indices, x_grad_stride, x_grad_dims, x_grad_data); } } auto out_grad_dims = out_grad.dims(); - const int64_t numel = indices_v[0]->numel(); + const int64_t numel 
= indices[0]->numel(); auto out_grad_stride = phi::stride(out_grad_dims); - const int64_t* pd_indices[7]; - for (size_t i = 0; i < indices_v.size(); ++i) { - pd_indices[i] = indices_v[i]->data(); - } if (value_grad) { if (value_grad->numel() == 1) { DenseTensor tmp_value_grad(value_grad->dtype()); - tmp_value_grad.Resize(indices_v[0]->dims()); + tmp_value_grad.Resize(indices[0]->dims()); T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); auto out_grad_data = out_grad.data(); @@ -146,7 +128,7 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, value_grad->dtype(), false, value_grad); - } else if (value_grad->numel() == indices_v[0]->numel()) { + } else if (value_grad->numel() == indices[0]->numel()) { T* value_grad_data = dev_ctx.template Alloc(value_grad); auto out_grad_data = out_grad.data(); @@ -158,7 +140,7 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, value_grad_data); } else { DenseTensor tmp_value_grad(value_grad->dtype()); - tmp_value_grad.Resize(indices_v[0]->dims()); + tmp_value_grad.Resize(indices[0]->dims()); T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); auto out_grad_data = out_grad.data(); @@ -175,23 +157,19 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, std::vector compress_dims; std::vector dims_without_1; - CalCompressedDimsWith1AndWithout1( + funcs::CalCompressedDimsWith1AndWithout1( &after_dims, &before_dims, &compress_dims, &dims_without_1); - phi::DenseTensor value_grad_dims_without1(value_grad->dtype()); - value_grad_dims_without1.Resize(phi::make_ddim(dims_without_1)); + auto pre_dims = value_grad->dims(); + value_grad->Resize(phi::make_ddim(dims_without_1)); IntArray v_axis(compress_dims); SumKernel(dev_ctx, tmp_value_grad, v_axis, value_grad->dtype(), false, - &value_grad_dims_without1); - phi::ReshapeInferKernel( - dev_ctx, - value_grad_dims_without1, - phi::IntArray(phi::vectorize(value_grad->dims())), - value_grad); + value_grad); + value_grad->Resize(pre_dims); } } } @@ -199,7 +177,7 @@ void LaunchIndexPutGradKernel(const Context& dev_ctx, template void IndexPutGradKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& indices_v, + const std::vector& indices, const DenseTensor& value, const DenseTensor& out_grad, bool accumulate, @@ -213,27 +191,27 @@ void IndexPutGradKernel(const Context& dev_ctx, "of tensor x.")); std::vector tmp_args; std::vector int_indices_v = - DealWithBoolIndices(dev_ctx, indices_v, &tmp_args); - auto bd_dim = BroadCastTensorsDims(int_indices_v); + funcs::DealWithBoolIndices(dev_ctx, indices, &tmp_args); + auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v); std::vector res_dim_v(phi::vectorize(bd_dim)); std::vector res_indices_v(x.dims().size(), nullptr); std::vector tmp_res_indices_v; std::vector range_tensor_v; - for (int i = indices_v.size(); i < x.dims().size(); ++i) { + for (int i = indices.size(); i < x.dims().size(); ++i) { range_tensor_v.emplace_back(GetRangeTensor( dev_ctx, x.dims()[i], phi::DataType::INT64)); } - DealWithIndices(dev_ctx, - x, - int_indices_v, - &res_indices_v, - &tmp_res_indices_v, - range_tensor_v, - bd_dim, - &res_dim_v); + funcs::DealWithIndices(dev_ctx, + x, + int_indices_v, + &res_indices_v, + &tmp_res_indices_v, + range_tensor_v, + bd_dim, + &res_dim_v); LaunchIndexPutGradKernel( dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad); diff --git a/paddle/phi/kernels/cpu/index_put_kernel.cc b/paddle/phi/kernels/cpu/index_put_kernel.cc index 73cf62b6d6ef21..24bb29af6a3914 100644 --- 
a/paddle/phi/kernels/cpu/index_put_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_kernel.cc @@ -20,24 +20,9 @@ namespace phi { -template -void range_kernel(int64_t N, T* out) { - for (int64_t idx = 0; idx < N; ++idx) { - out[idx] = idx; - } -} +UNROLL_RANGE_KERNEL_DEFINITION -template -phi::DenseTensor GetRangeTensor(const Context& dev_ctx, - int64_t N, - phi::DataType dtype) { - phi::DenseTensor res(dtype); - res.Resize(phi::make_ddim({N})); - DenseTensor* p_res = &res; - T* out = dev_ctx.template Alloc(p_res); - range_kernel(N, out); - return res; -} +UNROLL_GET_RANGE_TENSOR_DEFINITION template void index_put_kernel(const int64_t N, @@ -46,7 +31,7 @@ void index_put_kernel(const int64_t N, const int64_t** indices, const phi::DDim& stride, const phi::DDim& shape, - int64_t isSingleValTensor, + int64_t is_single_val_tensor, bool accumulate, T* out) { #ifdef PADDLE_WITH_MKLML @@ -65,9 +50,9 @@ void index_put_kernel(const int64_t N, } if (accumulate) { - *(out + offset) += *(vals + (idx & isSingleValTensor)); + *(out + offset) += *(vals + (idx & is_single_val_tensor)); } else { - *(out + offset) = *(vals + (idx & isSingleValTensor)); + *(out + offset) = *(vals + (idx & is_single_val_tensor)); } } } @@ -75,28 +60,28 @@ void index_put_kernel(const int64_t N, template void LaunchIndexPutKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& indices_v, + const std::vector& indices, const DenseTensor& value, bool accumulate, DenseTensor* out) { auto* x_data = x.data(); auto* val_data = value.data(); - bool isInitialized = out->initialized(); + bool is_initialized = out->initialized(); T* out_data = dev_ctx.template Alloc(out); - if (!isInitialized) { + if (!is_initialized) { phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); } auto x_dims = x.dims(); - const int64_t numel = indices_v[0]->numel(); + const int64_t numel = indices[0]->numel(); auto x_stride = phi::stride(x_dims); - int64_t isSingleValTensor = (value.numel() == 1) ? 0 : INT64_MAX; + int64_t is_single_val_tensor = (value.numel() == 1) ? 
0 : INT64_MAX; const int64_t* pd_indices[7]; - for (size_t i = 0; i < indices_v.size(); ++i) { - pd_indices[i] = indices_v[i]->data(); + for (size_t i = 0; i < indices.size(); ++i) { + pd_indices[i] = indices[i]->data(); } index_put_kernel(numel, @@ -105,7 +90,7 @@ void LaunchIndexPutKernel(const Context& dev_ctx, pd_indices, x_stride, x_dims, - isSingleValTensor, + is_single_val_tensor, accumulate, out_data); } @@ -113,7 +98,7 @@ void LaunchIndexPutKernel(const Context& dev_ctx, template void IndexPutKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& indices_v, + const std::vector& indices, const DenseTensor& value, bool accumulate, DenseTensor* out) { @@ -123,7 +108,7 @@ void IndexPutKernel(const Context& dev_ctx, phi::errors::InvalidArgument( "The data type of tensor in indices must be same to the data type " "of tensor x.")); - PADDLE_ENFORCE_EQ(indices_v.empty(), + PADDLE_ENFORCE_EQ(indices.empty(), false, phi::errors::InvalidArgument("Indices cannot be empty.")); @@ -135,9 +120,9 @@ void IndexPutKernel(const Context& dev_ctx, std::vector tmp_args; std::vector int_indices_v = - DealWithBoolIndices(dev_ctx, indices_v, &tmp_args); + funcs::DealWithBoolIndices(dev_ctx, indices, &tmp_args); - auto bd_dim = BroadCastTensorsDims(int_indices_v); + auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v); std::vector res_dim_v(phi::vectorize(bd_dim)); std::vector res_indices_v(x.dims().size(), nullptr); @@ -146,19 +131,19 @@ void IndexPutKernel(const Context& dev_ctx, std::vector range_tensor_v; const DenseTensor* ptr_value = nullptr; - for (int i = indices_v.size(); i < x.dims().size(); ++i) { + for (int i = indices.size(); i < x.dims().size(); ++i) { range_tensor_v.emplace_back(GetRangeTensor( dev_ctx, x.dims()[i], phi::DataType::INT64)); } - DealWithIndices(dev_ctx, - x, - int_indices_v, - &res_indices_v, - &tmp_res_indices_v, - range_tensor_v, - bd_dim, - &res_dim_v); + funcs::DealWithIndices(dev_ctx, + x, + int_indices_v, + &res_indices_v, + &tmp_res_indices_v, + range_tensor_v, + bd_dim, + &res_dim_v); if (value.numel() != 1) { tmp_value_v.emplace_back( DenseTensor(value.dtype()).Resize(phi::make_ddim(res_dim_v))); diff --git a/paddle/phi/kernels/funcs/index_put_utils.h b/paddle/phi/kernels/funcs/index_put_utils.h index 5c832a3b0fb79a..0c30c1622fed3c 100644 --- a/paddle/phi/kernels/funcs/index_put_utils.h +++ b/paddle/phi/kernels/funcs/index_put_utils.h @@ -28,6 +28,8 @@ namespace phi { +namespace funcs { + template static phi::DenseTensor GetReshapeAndExpandTensor( const Context& dev_ctx, @@ -289,4 +291,51 @@ static void CalCompressedDimsWith1AndWithout1( } } +} // namespace funcs } // namespace phi + +#define UNROLL_RANGE_CUDA_KERNEL_DEFINITION \ + template \ + __global__ void range_cuda_kernel(int64_t N, T* out) { \ + int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; \ + if (idx >= N) { \ + return; \ + } \ + out[idx] = idx; \ + } + +#define UNROLL_GET_RANGE_CUDA_TENSOR_DEFINITION \ + template \ + phi::DenseTensor GetRangeCudaTensor( \ + const Context& dev_ctx, int64_t N, phi::DataType dtype) { \ + phi::DenseTensor res(dtype); \ + res.Resize(phi::make_ddim({N})); \ + DenseTensor* p_res = &res; \ + T* out = dev_ctx.template Alloc(p_res); \ + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, N); \ + range_cuda_kernel<<>>(N, out); \ + return res; \ + } + +#define UNROLL_RANGE_KERNEL_DEFINITION \ + template \ + void range_kernel(int64_t N, T* out) { \ + for (int64_t idx = 0; idx < N; ++idx) { \ + out[idx] = idx; \ + } \ + } + +#define 
UNROLL_GET_RANGE_TENSOR_DEFINITION \ + template \ + phi::DenseTensor GetRangeTensor( \ + const Context& dev_ctx, int64_t N, phi::DataType dtype) { \ + phi::DenseTensor res(dtype); \ + res.Resize(phi::make_ddim({N})); \ + DenseTensor* p_res = &res; \ + T* out = dev_ctx.template Alloc(p_res); \ + range_kernel(N, out); \ + return res; \ + } diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu index 6adc18f425e041..a5b27c1072dc8e 100644 --- a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -23,30 +23,9 @@ namespace phi { -template -__global__ void range_cuda_kernel(int64_t N, T* out) { - int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; +UNROLL_RANGE_CUDA_KERNEL_DEFINITION - if (idx >= N) { - return; - } - out[idx] = idx; -} - -template -phi::DenseTensor GetRangeCudaTensor(const Context& dev_ctx, - int64_t N, - phi::DataType dtype) { - phi::DenseTensor res(dtype); - res.Resize(phi::make_ddim({N})); - DenseTensor* p_res = &res; - T* out = dev_ctx.template Alloc(p_res); - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, N); - range_cuda_kernel - <<>>( - N, out); - return res; -} +UNROLL_GET_RANGE_CUDA_TENSOR_DEFINITION template __global__ void set_zero_cuda_kernel(const int64_t N, @@ -100,7 +79,7 @@ __global__ void index_put_grad_cuda_kernel(const int64_t N, template void LaunchIndexPutGradCudaKernel( const Context& dev_ctx, - const std::vector& indices_v, + const std::vector& indices, const DenseTensor& out_grad, bool accumulate, DenseTensor* value_grad, @@ -111,7 +90,7 @@ void LaunchIndexPutGradCudaKernel( T* x_grad_data = x_grad->data(); auto x_grad_dims = x_grad->dims(); - const int64_t numel = indices_v[0]->numel(); + const int64_t numel = indices[0]->numel(); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); auto x_grad_stride = phi::stride(x_grad_dims); @@ -124,7 +103,7 @@ void LaunchIndexPutGradCudaKernel( } auto pd_indices = - GetDevicePointerArray(dev_ctx, indices_v); + funcs::GetDevicePointerArray(dev_ctx, indices); set_zero_cuda_kernel<<numel(); + const int64_t numel = indices[0]->numel(); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); auto out_grad_stride = phi::stride(out_grad_dims); @@ -146,12 +125,13 @@ void LaunchIndexPutGradCudaKernel( shape_a[idx] = out_grad_dims[idx]; } - auto pd_indices = GetDevicePointerArray(dev_ctx, indices_v); + auto pd_indices = + funcs::GetDevicePointerArray(dev_ctx, indices); if (value_grad) { if (value_grad->numel() == 1) { DenseTensor tmp_value_grad(value_grad->dtype()); - tmp_value_grad.Resize(indices_v[0]->dims()); + tmp_value_grad.Resize(indices[0]->dims()); T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); auto out_grad_data = out_grad.data(); @@ -176,7 +156,7 @@ void LaunchIndexPutGradCudaKernel( value_grad->dtype(), false, value_grad); - } else if (value_grad->numel() == indices_v[0]->numel()) { + } else if (value_grad->numel() == indices[0]->numel()) { T* value_grad_data = dev_ctx.template Alloc(value_grad); auto out_grad_data = out_grad.data(); @@ -187,7 +167,7 @@ void LaunchIndexPutGradCudaKernel( numel, out_grad_data, pd_indices, stride_a, shape_a, value_grad_data); } else { DenseTensor tmp_value_grad(value_grad->dtype()); - tmp_value_grad.Resize(indices_v[0]->dims()); + tmp_value_grad.Resize(indices[0]->dims()); T* tmp_value_grad_data = dev_ctx.template Alloc(&tmp_value_grad); auto out_grad_data = out_grad.data(); @@ -208,23 +188,19 @@ void 
LaunchIndexPutGradCudaKernel( std::vector compress_dims; std::vector dims_without_1; - CalCompressedDimsWith1AndWithout1( + funcs::CalCompressedDimsWith1AndWithout1( &after_dims, &before_dims, &compress_dims, &dims_without_1); - phi::DenseTensor value_grad_dims_without1(value_grad->dtype()); - value_grad_dims_without1.Resize(phi::make_ddim(dims_without_1)); + auto pre_dims = value_grad->dims(); + value_grad->Resize(phi::make_ddim(dims_without_1)); IntArray v_axis(compress_dims); SumKernel(dev_ctx, tmp_value_grad, v_axis, value_grad->dtype(), false, - &value_grad_dims_without1); - phi::ReshapeInferKernel( - dev_ctx, - value_grad_dims_without1, - phi::IntArray(phi::vectorize(value_grad->dims())), - value_grad); + value_grad); + value_grad->Resize(pre_dims); } } } @@ -232,7 +208,7 @@ void LaunchIndexPutGradCudaKernel( template void IndexPutGradKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& indices_v, + const std::vector& indices, const DenseTensor& value, const DenseTensor& out_grad, bool accumulate, @@ -246,28 +222,28 @@ void IndexPutGradKernel(const Context& dev_ctx, "of tensor x.")); std::vector tmp_args; std::vector int_indices_v = - DealWithBoolIndices(dev_ctx, indices_v, &tmp_args); + funcs::DealWithBoolIndices(dev_ctx, indices, &tmp_args); const size_t total_dims = x.dims().size(); - auto bd_dim = BroadCastTensorsDims(int_indices_v); + auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v); std::vector res_dim_v(phi::vectorize(bd_dim)); std::vector res_indices_v(x.dims().size(), nullptr); std::vector tmp_res_indices_v; std::vector range_tensor_v; - for (int i = indices_v.size(); i < x.dims().size(); ++i) { + for (int i = indices.size(); i < x.dims().size(); ++i) { range_tensor_v.emplace_back(GetRangeCudaTensor( dev_ctx, x.dims()[i], phi::DataType::INT64)); } - DealWithIndices(dev_ctx, - x, - int_indices_v, - &res_indices_v, - &tmp_res_indices_v, - range_tensor_v, - bd_dim, - &res_dim_v); + funcs::DealWithIndices(dev_ctx, + x, + int_indices_v, + &res_indices_v, + &tmp_res_indices_v, + range_tensor_v, + bd_dim, + &res_dim_v); switch (total_dims) { case 1: diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu index 4efe3e5a43d8cb..0c1cbc695c7e51 100644 --- a/paddle/phi/kernels/gpu/index_put_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_kernel.cu @@ -20,30 +20,10 @@ #include "paddle/phi/kernels/funcs/index_put_utils.h" namespace phi { -template -__global__ void range_cuda_kernel(int64_t N, T* out) { - int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; - if (idx >= N) { - return; - } - out[idx] = idx; -} +UNROLL_RANGE_CUDA_KERNEL_DEFINITION -template -phi::DenseTensor GetRangeCudaTensor(const Context& dev_ctx, - int64_t N, - phi::DataType dtype) { - phi::DenseTensor res(dtype); - res.Resize(phi::make_ddim({N})); - DenseTensor* p_res = &res; - T* out = dev_ctx.template Alloc(p_res); - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, N); - range_cuda_kernel - <<>>( - N, out); - return res; -} +UNROLL_GET_RANGE_CUDA_TENSOR_DEFINITION template __global__ void index_put_cuda_kernel(const int64_t N, @@ -52,7 +32,7 @@ __global__ void index_put_cuda_kernel(const int64_t N, int64_t** indices, phi::Array stride, phi::Array shape, - int64_t isSingleValTensor, + int64_t is_single_val_tensor, bool accumulate, T* out) { int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; @@ -71,30 +51,30 @@ __global__ void index_put_cuda_kernel(const int64_t N, } if (accumulate) { - *(out + offset) += *(vals + (idx & 
isSingleValTensor)); + *(out + offset) += *(vals + (idx & is_single_val_tensor)); } else { - *(out + offset) = *(vals + (idx & isSingleValTensor)); + *(out + offset) = *(vals + (idx & is_single_val_tensor)); } } template void LaunchIndexPutCudaKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& indices_v, + const std::vector& indices, const DenseTensor& value, bool accumulate, DenseTensor* out) { auto* x_data = x.data(); auto* val_data = value.data(); - bool isInitialized = out->initialized(); + bool is_initialized = out->initialized(); T* out_data = dev_ctx.template Alloc(out); - if (!isInitialized) { + if (!is_initialized) { phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); } auto x_dims = x.dims(); - const int64_t numel = indices_v[0]->numel(); + const int64_t numel = indices[0]->numel(); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel); auto x_stride = phi::stride(x_dims); @@ -106,9 +86,10 @@ void LaunchIndexPutCudaKernel(const Context& dev_ctx, shape_a[idx] = x_dims[idx]; } - int64_t isSingleValTensor = (value.numel() == 1) ? 0 : INT64_MAX; + int64_t is_single_val_tensor = (value.numel() == 1) ? 0 : INT64_MAX; - auto pd_indices = GetDevicePointerArray(dev_ctx, indices_v); + auto pd_indices = + funcs::GetDevicePointerArray(dev_ctx, indices); index_put_cuda_kernel <<>>( numel, @@ -117,7 +98,7 @@ void LaunchIndexPutCudaKernel(const Context& dev_ctx, pd_indices, stride_a, shape_a, - isSingleValTensor, + is_single_val_tensor, accumulate, out_data); } @@ -125,7 +106,7 @@ void LaunchIndexPutCudaKernel(const Context& dev_ctx, template void IndexPutKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& indices_v, + const std::vector& indices, const DenseTensor& value, bool accumulate, DenseTensor* out) { @@ -135,14 +116,14 @@ void IndexPutKernel(const Context& dev_ctx, phi::errors::InvalidArgument( "The data type of tensor in indices must be same to the data type " "of tensor x.")); - PADDLE_ENFORCE_EQ(indices_v.empty(), + PADDLE_ENFORCE_EQ(indices.empty(), false, phi::errors::InvalidArgument("Indices cannot be empty.")); std::vector tmp_args; std::vector int_indices_v = - DealWithBoolIndices(dev_ctx, indices_v, &tmp_args); + funcs::DealWithBoolIndices(dev_ctx, indices, &tmp_args); const size_t total_dims = x.dims().size(); - auto bd_dim = BroadCastTensorsDims(int_indices_v); + auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v); std::vector res_dim_v(phi::vectorize(bd_dim)); std::vector res_indices_v(x.dims().size(), nullptr); @@ -151,19 +132,19 @@ void IndexPutKernel(const Context& dev_ctx, std::vector range_tensor_v; const DenseTensor* ptr_value = nullptr; - for (int i = indices_v.size(); i < x.dims().size(); ++i) { + for (int i = indices.size(); i < x.dims().size(); ++i) { range_tensor_v.emplace_back(GetRangeCudaTensor( dev_ctx, x.dims()[i], phi::DataType::INT64)); } - DealWithIndices(dev_ctx, - x, - int_indices_v, - &res_indices_v, - &tmp_res_indices_v, - range_tensor_v, - bd_dim, - &res_dim_v); + funcs::DealWithIndices(dev_ctx, + x, + int_indices_v, + &res_indices_v, + &tmp_res_indices_v, + range_tensor_v, + bd_dim, + &res_dim_v); if (value.numel() != 1) { tmp_value_v.emplace_back( diff --git a/paddle/phi/kernels/index_put_grad_kernel.h b/paddle/phi/kernels/index_put_grad_kernel.h index f439573230d99b..575b7df5f27397 100644 --- a/paddle/phi/kernels/index_put_grad_kernel.h +++ b/paddle/phi/kernels/index_put_grad_kernel.h @@ -15,7 +15,6 @@ #pragma once #include -#include "paddle/phi/common/place.h" #include 
"paddle/phi/core/dense_tensor.h" namespace phi { diff --git a/paddle/phi/kernels/index_put_kernel.h b/paddle/phi/kernels/index_put_kernel.h index a7d313c252c52d..4410a508244571 100644 --- a/paddle/phi/kernels/index_put_kernel.h +++ b/paddle/phi/kernels/index_put_kernel.h @@ -15,7 +15,6 @@ #pragma once #include -#include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { From ed7a1416c42f9cb777202af6ec9ce52e8d7672ee Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Mon, 8 May 2023 09:33:45 +0000 Subject: [PATCH 18/24] add datatype flag in backward yaml --- paddle/phi/api/yaml/backward.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 81d053962f4354..707c772e6f220d 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -805,6 +805,7 @@ param : [x, value] kernel : func : index_put_grad + data_type : out_grad - backward_op : index_sample_grad forward : index_sample (Tensor x, Tensor index) -> Tensor(out) From c92f75e5a8b845e10e838237a4b81eea8e74d7a7 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Mon, 8 May 2023 11:49:41 +0000 Subject: [PATCH 19/24] replace macro with template with conditional complilation --- .../phi/kernels/cpu/index_put_grad_kernel.cc | 6 +- paddle/phi/kernels/cpu/index_put_kernel.cc | 6 +- paddle/phi/kernels/funcs/index_put_utils.h | 119 +++++++++--------- .../phi/kernels/gpu/index_put_grad_kernel.cu | 6 +- paddle/phi/kernels/gpu/index_put_kernel.cu | 6 +- 5 files changed, 65 insertions(+), 78 deletions(-) diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc index ce2771c8c304bb..7374bcd403d12f 100644 --- a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc @@ -21,10 +21,6 @@ namespace phi { -UNROLL_RANGE_KERNEL_DEFINITION - -UNROLL_GET_RANGE_TENSOR_DEFINITION - template void set_zero_kernel(const int64_t N, const int64_t** indices, @@ -200,7 +196,7 @@ void IndexPutGradKernel(const Context& dev_ctx, std::vector range_tensor_v; for (int i = indices.size(); i < x.dims().size(); ++i) { - range_tensor_v.emplace_back(GetRangeTensor( + range_tensor_v.emplace_back(funcs::GetRangeTensor( dev_ctx, x.dims()[i], phi::DataType::INT64)); } diff --git a/paddle/phi/kernels/cpu/index_put_kernel.cc b/paddle/phi/kernels/cpu/index_put_kernel.cc index 24bb29af6a3914..da3e37ac242364 100644 --- a/paddle/phi/kernels/cpu/index_put_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_kernel.cc @@ -20,10 +20,6 @@ namespace phi { -UNROLL_RANGE_KERNEL_DEFINITION - -UNROLL_GET_RANGE_TENSOR_DEFINITION - template void index_put_kernel(const int64_t N, const T* x, @@ -132,7 +128,7 @@ void IndexPutKernel(const Context& dev_ctx, const DenseTensor* ptr_value = nullptr; for (int i = indices.size(); i < x.dims().size(); ++i) { - range_tensor_v.emplace_back(GetRangeTensor( + range_tensor_v.emplace_back(funcs::GetRangeTensor( dev_ctx, x.dims()[i], phi::DataType::INT64)); } diff --git a/paddle/phi/kernels/funcs/index_put_utils.h b/paddle/phi/kernels/funcs/index_put_utils.h index 0c30c1622fed3c..4d2aa1c2c0e7d0 100644 --- a/paddle/phi/kernels/funcs/index_put_utils.h +++ b/paddle/phi/kernels/funcs/index_put_utils.h @@ -26,17 +26,21 @@ #include "paddle/phi/kernels/reshape_kernel.h" #include "paddle/phi/kernels/split_kernel.h" +#if defined(__NVCC__) +#include +#include +#endif + namespace phi { namespace funcs { template -static phi::DenseTensor 
GetReshapeAndExpandTensor( - const Context& dev_ctx, - const phi::DenseTensor& tensor, - const phi::DDim& res_dim, - const phi::DDim& bd_dim, - int index) { +phi::DenseTensor GetReshapeAndExpandTensor(const Context& dev_ctx, + const phi::DenseTensor& tensor, + const phi::DDim& res_dim, + const phi::DDim& bd_dim, + int index) { std::vector before_dims = phi::vectorize(tensor.dims()); std::vector mid_dims(res_dim.size(), 1); @@ -60,7 +64,7 @@ static phi::DenseTensor GetReshapeAndExpandTensor( } template -static std::vector DealWithBoolIndices( +std::vector DealWithBoolIndices( const Context& dev_ctx, const std::vector& indices_v, std::vector* tmp_indices_v) { @@ -180,15 +184,14 @@ T** GetDevicePointerArray(const Context& ctx, } template -static void DealWithIndices( - const Context& dev_ctx, - const DenseTensor& x, - const std::vector& int_indices_v, - std::vector* res_indices_v, - std::vector* tmp_res_indices_v, - const std::vector& range_tensor_v, - const phi::DDim& bd_dim, - std::vector* res_dim_v) { +void DealWithIndices(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& int_indices_v, + std::vector* res_indices_v, + std::vector* tmp_res_indices_v, + const std::vector& range_tensor_v, + const phi::DDim& bd_dim, + std::vector* res_dim_v) { size_t total_dims = x.dims().size(); if (int_indices_v.size() < total_dims) { std::vector tmp_x_dims = phi::vectorize(x.dims()); @@ -291,51 +294,51 @@ static void CalCompressedDimsWith1AndWithout1( } } -} // namespace funcs -} // namespace phi +#if defined(__NVCC__) +template +__global__ void range_cuda_kernel(int64_t N, T* out) { + int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; -#define UNROLL_RANGE_CUDA_KERNEL_DEFINITION \ - template \ - __global__ void range_cuda_kernel(int64_t N, T* out) { \ - int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; \ - if (idx >= N) { \ - return; \ - } \ - out[idx] = idx; \ + if (idx >= N) { + return; } + out[idx] = idx; +} -#define UNROLL_GET_RANGE_CUDA_TENSOR_DEFINITION \ - template \ - phi::DenseTensor GetRangeCudaTensor( \ - const Context& dev_ctx, int64_t N, phi::DataType dtype) { \ - phi::DenseTensor res(dtype); \ - res.Resize(phi::make_ddim({N})); \ - DenseTensor* p_res = &res; \ - T* out = dev_ctx.template Alloc(p_res); \ - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, N); \ - range_cuda_kernel<<>>(N, out); \ - return res; \ - } +template +phi::DenseTensor GetRangeCudaTensor(const Context& dev_ctx, + int64_t N, + phi::DataType dtype) { + phi::DenseTensor res(dtype); + res.Resize(phi::make_ddim({N})); + DenseTensor* p_res = &res; + T* out = dev_ctx.template Alloc(p_res); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, N); + range_cuda_kernel + <<>>( + N, out); + return res; +} +#endif -#define UNROLL_RANGE_KERNEL_DEFINITION \ - template \ - void range_kernel(int64_t N, T* out) { \ - for (int64_t idx = 0; idx < N; ++idx) { \ - out[idx] = idx; \ - } \ +template +void range_kernel(int64_t N, T* out) { + for (int64_t idx = 0; idx < N; ++idx) { + out[idx] = idx; } +} -#define UNROLL_GET_RANGE_TENSOR_DEFINITION \ - template \ - phi::DenseTensor GetRangeTensor( \ - const Context& dev_ctx, int64_t N, phi::DataType dtype) { \ - phi::DenseTensor res(dtype); \ - res.Resize(phi::make_ddim({N})); \ - DenseTensor* p_res = &res; \ - T* out = dev_ctx.template Alloc(p_res); \ - range_kernel(N, out); \ - return res; \ - } +template +phi::DenseTensor GetRangeTensor(const Context& dev_ctx, + int64_t N, + phi::DataType dtype) { + phi::DenseTensor res(dtype); + 
res.Resize(phi::make_ddim({N})); + DenseTensor* p_res = &res; + T* out = dev_ctx.template Alloc(p_res); + range_kernel(N, out); + return res; +} + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu index a5b27c1072dc8e..7ae1e42c067cbd 100644 --- a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -23,10 +23,6 @@ namespace phi { -UNROLL_RANGE_CUDA_KERNEL_DEFINITION - -UNROLL_GET_RANGE_CUDA_TENSOR_DEFINITION - template __global__ void set_zero_cuda_kernel(const int64_t N, int64_t** indices, @@ -232,7 +228,7 @@ void IndexPutGradKernel(const Context& dev_ctx, std::vector range_tensor_v; for (int i = indices.size(); i < x.dims().size(); ++i) { - range_tensor_v.emplace_back(GetRangeCudaTensor( + range_tensor_v.emplace_back(funcs::GetRangeCudaTensor( dev_ctx, x.dims()[i], phi::DataType::INT64)); } diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu index 0c1cbc695c7e51..ad27993c35244c 100644 --- a/paddle/phi/kernels/gpu/index_put_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_kernel.cu @@ -21,10 +21,6 @@ namespace phi { -UNROLL_RANGE_CUDA_KERNEL_DEFINITION - -UNROLL_GET_RANGE_CUDA_TENSOR_DEFINITION - template __global__ void index_put_cuda_kernel(const int64_t N, const T* x, @@ -133,7 +129,7 @@ void IndexPutKernel(const Context& dev_ctx, const DenseTensor* ptr_value = nullptr; for (int i = indices.size(); i < x.dims().size(); ++i) { - range_tensor_v.emplace_back(GetRangeCudaTensor( + range_tensor_v.emplace_back(funcs::GetRangeCudaTensor( dev_ctx, x.dims()[i], phi::DataType::INT64)); } From 4de9b48b012da2b4225a05e59d92c22d158c1975 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Tue, 9 May 2023 02:20:43 +0000 Subject: [PATCH 20/24] fix rocmn bug --- paddle/phi/kernels/funcs/index_put_utils.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/funcs/index_put_utils.h b/paddle/phi/kernels/funcs/index_put_utils.h index 4d2aa1c2c0e7d0..88e86aa8534b7a 100644 --- a/paddle/phi/kernels/funcs/index_put_utils.h +++ b/paddle/phi/kernels/funcs/index_put_utils.h @@ -26,7 +26,7 @@ #include "paddle/phi/kernels/reshape_kernel.h" #include "paddle/phi/kernels/split_kernel.h" -#if defined(__NVCC__) +#if defined(__NVCC__) || defined(__HIPCC__) #include #include #endif @@ -294,7 +294,7 @@ static void CalCompressedDimsWith1AndWithout1( } } -#if defined(__NVCC__) +#if defined(__NVCC__) || defined(__HIPCC__) template __global__ void range_cuda_kernel(int64_t N, T* out) { int64_t idx = threadIdx.x + blockDim.x * blockIdx.x; From ed00d81777327fd4e6730dce88e33ab63fa82467 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Tue, 9 May 2023 04:59:26 +0000 Subject: [PATCH 21/24] fix note and rocmn bug --- paddle/phi/kernels/funcs/index_put_utils.h | 4 ++++ python/paddle/tensor/manipulation.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/funcs/index_put_utils.h b/paddle/phi/kernels/funcs/index_put_utils.h index 88e86aa8534b7a..51e918c852347c 100644 --- a/paddle/phi/kernels/funcs/index_put_utils.h +++ b/paddle/phi/kernels/funcs/index_put_utils.h @@ -27,8 +27,12 @@ #include "paddle/phi/kernels/split_kernel.h" #if defined(__NVCC__) || defined(__HIPCC__) +#ifdef __NVCC__ #include #include +#elif defined(__HIPCC__) +#include +#endif #endif namespace phi { diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 
46c40e61194ae1..0a494113293dd5 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4803,7 +4803,7 @@ def index_put_(x, indices, value, accumulate=False, name=None): If accumulate is True, the elements in values are added to x. If accumulate is False, the behavior is undefined if indices contain duplicate elements. Args: - x (Tensor) : The Source Tensor. Supported data types are int32, int64, float16, float32, float64, bool, complex64, complex128. + x (Tensor) : The Source Tensor. Supported data types are int32, int64, float16, float32, float64, bool. indices (Tensor): The tuple of Tensor containing the indices to index. The data type of ``tensor in indices`` must be int32, int64 or bool value (Tensor): The tensor used to be assigned to x. From 43167abd145be960ef6e031c122199397684961d Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Tue, 9 May 2023 10:34:58 +0000 Subject: [PATCH 22/24] fix conflict between flatten and index_put --- python/paddle/tensor/manipulation.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index c003fb3adaa6ba..217ea26d4c5e0c 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4911,6 +4911,12 @@ def unflatten(x, axis, shape, name=None): Returns: Tensor, return the unflatten tensor of :attr:`x`. + + Examples: + .. code-block:: python + + import paddle + x = paddle.randn(shape=[4, 6, 8]) shape = [2, 3] axis = 1 @@ -4932,9 +4938,9 @@ def unflatten(x, axis, shape, name=None): print(res.shape) # [2, 2, 6, 8] """ + # determine whether the input axis is valid. axis = non_negative_axis(x, axis) - if isinstance(shape, (list, tuple)): new_shape = ( list(x.shape[:axis]) + list(shape) + list(x.shape[axis + 1 :]) From b09221f8d970abd9f7a325ab253fef13ac7a505a Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Tue, 9 May 2023 12:37:07 +0000 Subject: [PATCH 23/24] fix bug in documentation --- python/paddle/tensor/manipulation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 217ea26d4c5e0c..1b9f6a6c241b7f 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4807,11 +4807,12 @@ def index_put_(x, indices, value, accumulate=False, name=None): indices (Tuple of Tensor): The tuple of Tensor containing the indices to index. The data type of ``tensor in indices`` must be int32, int64 or bool value (Tensor): The tensor used to be assigned to x. - accummulate (Bool): Whether the elements in values are added to x + accummulate (Bool): Whether the elements in values are added to x. Default: False. name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: Tensor, same dimention and dtype with x. + Examples: .. 
code-block:: python import paddle From db0209febe719f8a24c7ad8e3c2afe9322aee4d4 Mon Sep 17 00:00:00 2001 From: Ligoml <39876205+Ligoml@users.noreply.github.com> Date: Tue, 9 May 2023 20:43:47 +0800 Subject: [PATCH 24/24] Update python/paddle/tensor/manipulation.py --- python/paddle/tensor/manipulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 1b9f6a6c241b7f..7fe717383cd3b6 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4807,7 +4807,7 @@ def index_put_(x, indices, value, accumulate=False, name=None): indices (Tuple of Tensor): The tuple of Tensor containing the indices to index. The data type of ``tensor in indices`` must be int32, int64 or bool value (Tensor): The tensor used to be assigned to x. - accummulate (Bool): Whether the elements in values are added to x. Default: False. + accummulate (Bool, optional): Whether the elements in values are added to x. Default: False. name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns:
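
Taken together, the patches above expose the in-place paddle.index_put_ API whose docstring is being touched up in the last few commits. The short sketch below shows the intended call pattern; it assumes a Paddle build that already contains this patch series, and the commented results simply follow the scatter semantics the docstring describes (indices passed as a tuple of integer tensors, with ``accumulate`` choosing between assignment and addition):

.. code-block:: python

    import paddle

    # x[row[i], col[i]] receives value[i]
    x = paddle.zeros([3, 3], dtype='float32')
    row = paddle.to_tensor([0, 1, 2], dtype='int64')
    col = paddle.to_tensor([1, 2, 0], dtype='int64')
    value = paddle.to_tensor([1.0, 2.0, 3.0], dtype='float32')

    # accumulate=False (the default): plain assignment at the indexed positions
    paddle.index_put_(x, (row, col), value)
    # x -> [[0., 1., 0.],
    #       [0., 0., 2.],
    #       [3., 0., 0.]]

    # accumulate=True: the values are added to the entries already in x
    paddle.index_put_(x, (row, col), value, accumulate=True)
    # x[0, 1] == 2., x[1, 2] == 4., x[2, 0] == 6.

Because the docstring states the result is undefined for duplicate indices when ``accumulate`` is False, the index pairs in the sketch are kept unique.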