Skip to content

Commit

Permalink
fix gather of xpu
Browse files Browse the repository at this point in the history
  • Loading branch information
shentanyue committed May 17, 2022
1 parent 1df7e0e commit 4c57c7d
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 9 deletions.
99 changes: 90 additions & 9 deletions lite/kernels/xpu/gather_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,88 @@ void GatherCompute<DataType, IndexType>::Run() {
axis += x_dims.size();
}

int r = xdnn::gather<DataType, IndexType>(
ctx.GetRawContext(),
x->template data<DataType>(),
index->template data<IndexType>(),
out->template mutable_data<DataType>(TARGET(kXPU)),
x_dims,
index->numel(),
axis);
if (param.X->precision() == PrecisionType::kInt64 &&
param.Index->precision() == PrecisionType::kInt64) {
auto* index_int64 = param.Index->template data<int64_t>();
int size = param.Index->dims().production();
XPUScratchPadGuard index_xpu_guard_ =
TargetWrapperXPU::MallocScratchPad(size * sizeof(int));
int* index_int32_device = reinterpret_cast<int*>(index_xpu_guard_->addr_);

CHECK_EQ(r, 0);
int r0 = xdnn::cast_v2<int64_t, int32_t>(
ctx.GetRawContext(), index_int64, index_int32_device, index->numel());
CHECK_EQ(r0, 0);

int r1 = xdnn::gather<int64_t, int32_t>(
ctx.GetRawContext(),
x->template data<int64_t>(),
index_int32_device,
out->template mutable_data<int64_t>(TARGET(kXPU)),
x_dims,
index->numel(),
axis);
CHECK_EQ(r1, 0);
} else if (param.X->precision() == PrecisionType::kInt64 &&
param.Index->precision() == PrecisionType::kInt32) {
int r = xdnn::gather<int64_t, int32_t>(
ctx.GetRawContext(),
x->template data<int64_t>(),
index->template data<int32_t>(),
out->template mutable_data<int64_t>(TARGET(kXPU)),
x_dims,
index->numel(),
axis);
CHECK_EQ(r, 0);
} else if (param.X->precision() == PrecisionType::kInt32 &&
param.Index->precision() == PrecisionType::kInt32) {
int r = xdnn::gather<int32_t, int32_t>(
ctx.GetRawContext(),
x->template data<int32_t>(),
index->template data<int32_t>(),
out->template mutable_data<int32_t>(TARGET(kXPU)),
x_dims,
index->numel(),
axis);
CHECK_EQ(r, 0);
} else if (param.X->precision() == PrecisionType::kInt32 &&
param.Index->precision() == PrecisionType::kInt64) {
int r = xdnn::gather<int32_t, int64_t>(
ctx.GetRawContext(),
x->template data<int32_t>(),
index->template data<int64_t>(),
out->template mutable_data<int32_t>(TARGET(kXPU)),
x_dims,
index->numel(),
axis);
CHECK_EQ(r, 0);
} else if (param.X->precision() == PrecisionType::kFloat &&
param.Index->precision() == PrecisionType::kInt32) {
int r = xdnn::gather<float, int32_t>(
ctx.GetRawContext(),
x->template data<float>(),
index->template data<int32_t>(),
out->template mutable_data<float>(TARGET(kXPU)),
x_dims,
index->numel(),
axis);
CHECK_EQ(r, 0);
} else if (param.X->precision() == PrecisionType::kFloat &&
param.Index->precision() == PrecisionType::kInt64) {
int r = xdnn::gather<float, int64_t>(
ctx.GetRawContext(),
x->template data<float>(),
index->template data<int64_t>(),
out->template mutable_data<float>(TARGET(kXPU)),
x_dims,
index->numel(),
axis);
CHECK_EQ(r, 0);
} else {
LOG(FATAL) << "Unsupported gather op with x dtype: "
<< lite_api::PrecisionToStr(param.X->precision())
<< " and index dtype: "
<< lite_api::PrecisionToStr(param.Index->precision());
}
}

} // namespace xpu
Expand Down Expand Up @@ -107,3 +179,12 @@ REGISTER_LITE_KERNEL(
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.Finalize();
REGISTER_LITE_KERNEL(
gather, kXPU, kFloat, kNCHW, GatherXPUInt64Int64, gather_i64_i64)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindInput("Index",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindInput("Axis",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.Finalize();
2 changes: 2 additions & 0 deletions lite/kernels/xpu/gather_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,5 @@ typedef paddle::lite::kernels::xpu::GatherCompute<float, int64_t>
GatherXPUFloatInt64;
typedef paddle::lite::kernels::xpu::GatherCompute<int64_t, int32_t>
GatherXPUInt64Int32;
typedef paddle::lite::kernels::xpu::GatherCompute<int64_t, int64_t>
GatherXPUInt64Int64;

0 comments on commit 4c57c7d

Please sign in to comment.