[ARM] Add viterbi_decode op #10066

Merged: 3 commits, Mar 13, 2023
403 changes: 403 additions & 0 deletions lite/backends/arm/math/viterbi_decode.cc

Large diffs are not rendered by default.

35 changes: 35 additions & 0 deletions lite/backends/arm/math/viterbi_decode.h
@@ -0,0 +1,35 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "lite/operators/op_params.h"
#include "lite/utils/log/cp_logging.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

void viterbi_decode(const lite::Tensor &input,
                    const lite::Tensor &transition,
                    const lite::Tensor &length,
                    bool include_bos_eos_tag,
                    Tensor *scores,
                    Tensor *path);

} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
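
The 403-line ARM math implementation (lite/backends/arm/math/viterbi_decode.cc) is not rendered in this diff. For orientation only, here is a minimal standalone sketch of classic Viterbi decoding over the kind of data the declared interface takes: per-step emission scores plus a tag-transition matrix, with a dynamic-programming forward pass followed by backtracking. Everything in it (the name viterbi_decode_sketch, the std::vector-based signature, the single-sequence, no-BOS/EOS simplification) is illustrative; it is not the merged kernel.

#include <cstdint>
#include <vector>

// Minimal scalar sketch (not the merged ARM code). emissions is
// [seq_len][num_tags] unary scores, trans is [num_tags][num_tags] where
// trans[i][j] scores moving from tag i to tag j. Assumes seq_len >= 1.
// Returns the best tag path and writes its total score to *best_score.
std::vector<int64_t> viterbi_decode_sketch(
    const std::vector<std::vector<float>> &emissions,
    const std::vector<std::vector<float>> &trans,
    float *best_score) {
  const int seq_len = static_cast<int>(emissions.size());
  const int num_tags = static_cast<int>(emissions[0].size());

  // alpha[t][j]: best score of any path ending in tag j at step t.
  std::vector<std::vector<float>> alpha(seq_len, std::vector<float>(num_tags));
  // backptr[t][j]: previous tag on that best path.
  std::vector<std::vector<int>> backptr(seq_len, std::vector<int>(num_tags));

  alpha[0] = emissions[0];
  for (int t = 1; t < seq_len; ++t) {
    for (int j = 0; j < num_tags; ++j) {
      float best = alpha[t - 1][0] + trans[0][j];
      int best_i = 0;
      for (int i = 1; i < num_tags; ++i) {
        float score = alpha[t - 1][i] + trans[i][j];
        if (score > best) {
          best = score;
          best_i = i;
        }
      }
      alpha[t][j] = best + emissions[t][j];
      backptr[t][j] = best_i;
    }
  }

  // Pick the best final tag, then backtrack to recover the path.
  int last = 0;
  for (int j = 1; j < num_tags; ++j) {
    if (alpha[seq_len - 1][j] > alpha[seq_len - 1][last]) last = j;
  }
  *best_score = alpha[seq_len - 1][last];

  std::vector<int64_t> path(seq_len);
  path[seq_len - 1] = last;
  for (int t = seq_len - 1; t > 0; --t) {
    last = backptr[t][last];
    path[t - 1] = last;
  }
  return path;
}

The merged backend presumably builds on this same recurrence but operates on lite::Tensor data, handles batched inputs via the Length tensor, applies the include_bos_eos_tag attribute, and vectorizes the per-tag maximum; see viterbi_decode.cc for the actual code.
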
31 changes: 15 additions & 16 deletions lite/core/optimizer/mir/memory_optimize_pass.cc
@@ -64,22 +64,21 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
};

// The input and output variables of these ops will not be reused.
std::set<std::string> invalid_op_nodes = {
"while",
"conditional_block",
"conditional_block_infer",
"merge_lod_tensor_infer",
"merge_lod_tensor",
"equal",
"lod_reset",
"yolo_box",
"subgraph",
"feed",
"fetch",
"cast",
"expand",
"share_data",
};
std::set<std::string> invalid_op_nodes = {"while",
"conditional_block",
"conditional_block_infer",
"merge_lod_tensor_infer",
"merge_lod_tensor",
"equal",
"lod_reset",
"yolo_box",
"subgraph",
"feed",
"fetch",
"cast",
"expand",
"share_data",
"viterbi_decode"};

auto insert_invalid_op_nodes_for_specific_target = [&](
std::set<std::string> op_node_set, TargetType specific_target) {
1 change: 1 addition & 0 deletions lite/kernels/arm/CMakeLists.txt
@@ -77,6 +77,7 @@ add_kernel(matmul_v2_compute ARM extra SRCS matmul_v2_compute.cc)
add_kernel(sum_compute ARM extra SRCS sum_compute.cc)
add_kernel(dequantize_log_compute ARM extra SRCS dequantize_log_compute.cc)
add_kernel(fused_attention_compute_arm ARM extra SRCS fused_attention_compute.cc)
add_kernel(viterbi_decode_compute ARM extra SRCS viterbi_decode_compute.cc)

# for OCR specific
add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc)
58 changes: 58 additions & 0 deletions lite/kernels/arm/viterbi_decode_compute.cc
@@ -0,0 +1,58 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/arm/viterbi_decode_compute.h"
#include "lite/backends/arm/math/viterbi_decode.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace arm {

void ViterbiDecodeCompute::Run() {
  auto& param = Param<operators::ViterbiDecodeParam>();
  auto input = param.input;
  auto transition = param.transition;
  auto length = param.length;
  auto include_bos_eos_tag = param.include_bos_eos_tag;
  auto scores = param.scores;
  auto path = param.path;
  lite::arm::math::viterbi_decode(
      *input, *transition, *length, include_bos_eos_tag, scores, path);
}

} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle

REGISTER_LITE_KERNEL(viterbi_decode,
                     kARM,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::arm::ViterbiDecodeCompute,
                     def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Length",
               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindInput("Transition", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Path",
                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindOutput("Scores", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindPaddleOpVersion("viterbi_decode", 1)
    .Finalize();
37 changes: 37 additions & 0 deletions lite/kernels/arm/viterbi_decode_compute.h
@@ -0,0 +1,37 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/core/kernel.h"
#include "lite/operators/viterbi_decode_op.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace arm {

class ViterbiDecodeCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ViterbiDecodeParam;

void Run() override;

virtual ~ViterbiDecodeCompute() = default;
};

} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
1 change: 1 addition & 0 deletions lite/operators/CMakeLists.txt
@@ -209,6 +209,7 @@ add_operator(strided_slice_op extra SRCS strided_slice_op.cc)
add_operator(where_op extra SRCS where_op.cc)
add_operator(unique_with_counts_op extra SRCS unique_with_counts_op.cc)
add_operator(unique_op extra SRCS unique_op.cc)
add_operator(viterbi_decode extra SRCS viterbi_decode_op.cc)

# for content-dnn specific
add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc)
9 changes: 9 additions & 0 deletions lite/operators/op_params.h
@@ -2602,6 +2602,15 @@ struct TemporalShiftParam : ParamBase {
std::string data_format{"NCHW"};
};

struct ViterbiDecodeParam : ParamBase {
const lite::Tensor* input{};
const lite::Tensor* length{};
const lite::Tensor* transition{};
lite::Tensor* path{};
lite::Tensor* scores{};
bool include_bos_eos_tag{};
};

} // namespace operators
} // namespace lite
} // namespace paddle
58 changes: 58 additions & 0 deletions lite/operators/viterbi_decode_op.cc
@@ -0,0 +1,58 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/viterbi_decode_op.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace operators {

bool ViterbiDecodeOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.input);
CHECK_OR_FALSE(param_.length);
CHECK_OR_FALSE(param_.transition);
CHECK_OR_FALSE(param_.path);
CHECK_OR_FALSE(param_.scores);
return true;
}

bool ViterbiDecodeOpLite::InferShapeImpl() const { return true; }

bool ViterbiDecodeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
CHECK(!op_desc.Input("Input").empty());
CHECK(!op_desc.Input("Length").empty());
CHECK(!op_desc.Input("Transition").empty());
CHECK(!op_desc.Output("Path").empty());
CHECK(!op_desc.Output("Scores").empty());
auto Input = op_desc.Input("Input").front();
auto Length = op_desc.Input("Length").front();
auto Transition = op_desc.Input("Transition").front();
auto Path = op_desc.Output("Path").front();
auto Scores = op_desc.Output("Scores").front();
param_.input = GetVar<lite::Tensor>(scope, Input);
param_.length = GetVar<lite::Tensor>(scope, Length);
param_.transition = GetVar<lite::Tensor>(scope, Transition);
param_.scores = GetMutableVar<lite::Tensor>(scope, Scores);
param_.path = GetMutableVar<lite::Tensor>(scope, Path);
param_.include_bos_eos_tag = op_desc.GetAttr<bool>("include_bos_eos_tag");
return true;
}

} // namespace operators
} // namespace lite
} // namespace paddle

REGISTER_LITE_OP(viterbi_decode, paddle::lite::operators::ViterbiDecodeOpLite);
52 changes: 52 additions & 0 deletions lite/operators/viterbi_decode_op.h
@@ -0,0 +1,52 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"

namespace paddle {
namespace lite {
namespace operators {

class ViterbiDecodeOpLite : public OpLite {
public:
ViterbiDecodeOpLite() {}

explicit ViterbiDecodeOpLite(const std::string &type) : OpLite(type) {}

bool CheckShape() const override;

bool InferShapeImpl() const override;

bool InferShapeWithCache() const override { return true; }

void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }

bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;

std::string DebugString() const override { return "viterbi_decode"; }

private:
mutable ViterbiDecodeParam param_;
};

} // namespace operators
} // namespace lite
} // namespace paddle