Skip to content

Commit

Permalink
[XPU]. Add fp16 to op split.
Browse files Browse the repository at this point in the history
  • Loading branch information
wbn03 committed Sep 7, 2022
1 parent 288120f commit 51d3dbd
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
35 changes: 24 additions & 11 deletions lite/kernels/xpu/split_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ namespace lite {
namespace kernels {
namespace xpu {

void SplitCompute::Run() {
template <typename T, PrecisionType PType>
void SplitCompute<T, PType>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();
auto& dout = param.output;
Expand All @@ -33,19 +34,19 @@ void SplitCompute::Run() {
height = height * in_dim[i];
}
int width = param.x->numel() / height;
std::vector<float*> out_ptrs;
std::vector<T*> out_ptrs;
std::vector<int> width_out;
for (auto out : dout) {
out->set_lod(param.x->lod());
out_ptrs.push_back(out->mutable_data<float>(TARGET(kXPU)));
out_ptrs.push_back(out->template mutable_data<T>(TARGET(kXPU)));
width_out.push_back(out->numel() / height);
}
int r = xdnn::split<float>(ctx.GetRawContext(),
param.x->data<float>(),
out_ptrs,
{height, width},
width_out,
1);
int r = xdnn::split<T>(ctx.GetRawContext(),
param.x->template data<T>(),
out_ptrs,
{height, width},
width_out,
1);

CHECK_EQ(r, 0);
}
Expand All @@ -55,12 +56,24 @@ void SplitCompute::Run() {
} // namespace lite
} // namespace paddle

REGISTER_LITE_KERNEL(
split, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::SplitCompute, def)
using split_float =
paddle::lite::kernels::xpu::SplitCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(split, kXPU, kFloat, kNCHW, split_float, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("SectionsTensorList",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();

using split_fp16 =
paddle::lite::kernels::xpu::SplitCompute<float16, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(split, kXPU, kFP16, kNCHW, split_fp16, fp16)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("SectionsTensorList",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.Finalize();
3 changes: 2 additions & 1 deletion lite/kernels/xpu/split_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ namespace lite {
namespace kernels {
namespace xpu {

class SplitCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
template <typename T, PrecisionType PType>
class SplitCompute : public KernelLite<TARGET(kXPU), PType> {
public:
using param_t = operators::SplitParam;

Expand Down

0 comments on commit 51d3dbd

Please sign in to comment.