Skip to content

Commit

Permalink
[HOST][X86] fix equal_int64; cast support bool, int32; slice support …
Browse files Browse the repository at this point in the history
…int32; test=develop
  • Loading branch information
zhupengyang committed Feb 22, 2021
1 parent fd8bac9 commit 509cf5d
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 21 deletions.
18 changes: 18 additions & 0 deletions lite/kernels/host/compare_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,24 @@ REGISTER_LITE_KERNEL(equal, kHost, kInt64, kAny, equal_int64, def)
.BindPaddleOpVersion("equal", 1)
.Finalize();

// float kernel has higher score when picking kernel.
// TODO(zhupengyang): merge equal_int64 later
using equal_int64_f = paddle::lite::kernels::host::CompareCompute<
PRECISION(kFloat),
paddle::lite::kernels::host::_EqualFunctor<int64_t>>;
REGISTER_LITE_KERNEL(equal, kHost, kFloat, kAny, equal_int64_f, int64)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.BindPaddleOpVersion("equal", 1)
.Finalize();

using equal_int32 = paddle::lite::kernels::host::CompareCompute<
PRECISION(kInt32),
paddle::lite::kernels::host::_EqualFunctor<int32_t>>;
Expand Down
20 changes: 20 additions & 0 deletions lite/kernels/x86/cast_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,23 @@ REGISTER_LITE_KERNEL(
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();

REGISTER_LITE_KERNEL(cast,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::CastCompute<bool>,
bool_to_any)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kBool))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.Finalize();

REGISTER_LITE_KERNEL(cast,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::CastCompute<int>,
int32_to_any)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kAny))})
.Finalize();
15 changes: 15 additions & 0 deletions lite/kernels/x86/slice_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,18 @@ REGISTER_LITE_KERNEL(slice,
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();

REGISTER_LITE_KERNEL(slice,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SliceCompute<int>,
int32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("StartsTensorList", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("EndsTensorList", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.Finalize();
80 changes: 59 additions & 21 deletions lite/kernels/xpu/reshape_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/xpu/reshape_compute.h"
#include <algorithm>
#include "lite/backends/xpu/xpu_header_sitter.h"
Expand All @@ -21,9 +22,10 @@ namespace lite {
namespace kernels {
namespace xpu {

void ReshapeCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->As<XPUContext>();
template <class T>
void ReshapeCompute<T>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();
auto x = param.x;
auto output = param.output;
auto output_dims = output->dims();
Expand All @@ -32,10 +34,10 @@ void ReshapeCompute::Run() {
output->ShareDataWith(*x);
output->Resize(output_dims);
} else {
int r = xdnn::copy<float>(ctx.GetRawContext(),
param.x->data<float>(),
param.output->mutable_data<float>(TARGET(kXPU)),
param.x->numel());
int r = xdnn::copy<T>(ctx.GetRawContext(),
x->template data<T>(),
output->template mutable_data<T>(TARGET(kXPU)),
x->numel());

CHECK_EQ(r, 0);
}
Expand All @@ -50,46 +52,82 @@ REGISTER_LITE_KERNEL(reshape2,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::ReshapeCompute,
def)
paddle::lite::kernels::xpu::ReshapeCompute<float>,
float32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("ShapeTensor", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();

REGISTER_LITE_KERNEL(reshape2,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::ReshapeCompute<int>,
int32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();

REGISTER_LITE_KERNEL(reshape2,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::ReshapeCompute<int64_t>,
int64)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt64))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();

REGISTER_LITE_KERNEL(reshape,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::ReshapeCompute,
def)
paddle::lite::kernels::xpu::ReshapeCompute<float>,
float32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("ShapeTensor", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();

REGISTER_LITE_KERNEL(flatten,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::ReshapeCompute,
def)
paddle::lite::kernels::xpu::ReshapeCompute<float>,
float32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();

REGISTER_LITE_KERNEL(flatten2,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::ReshapeCompute,
def)
paddle::lite::kernels::xpu::ReshapeCompute<float>,
float32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
1 change: 1 addition & 0 deletions lite/kernels/xpu/reshape_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace lite {
namespace kernels {
namespace xpu {

template <class T>
class ReshapeCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::ReshapeParam;
Expand Down

0 comments on commit 509cf5d

Please sign in to comment.