Skip to content

Commit

Permalink
[Cherry-pick][NNAdapter][TIM-VX] fuse sigmoid&mul into swish (#9623)
Browse files Browse the repository at this point in the history
  • Loading branch information
yingshengBD authored Nov 2, 2022
1 parent d6a6d56 commit 6b88c17
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "core/types.h"

namespace nnadapter {

// Rewrites every matched SIGMOID + MUL pair in `model` (a MUL whose first
// input is the SIGMOID's input and whose second input is the SIGMOID's
// output) into a single NNADAPTER_SWISH operation. The model is modified in
// place.
void FuseSigmoidMulIntoSwish(core::Model *model);

} // namespace nnadapter
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ REGISTER_CONVERTER(LEAKY_RELU, ConvertLeakyRelu)
REGISTER_CONVERTER(MAX_POOL_2D, ConvertPool2D)
REGISTER_CONVERTER(MAT_MUL, ConvertMatMul)
REGISTER_CONVERTER(MUL, ConvertElementwise)
REGISTER_CONVERTER(REDUCE_MEAN, ConvertReduce)
REGISTER_CONVERTER(REDUCE_SUM, ConvertReduce)
REGISTER_CONVERTER(REDUCE_MAX, ConvertReduce)
REGISTER_CONVERTER(RELU, ConvertUnaryActivations)
REGISTER_CONVERTER(RELU6, ConvertUnaryActivations)
REGISTER_CONVERTER(RESHAPE, ConvertReshape)
Expand All @@ -43,6 +46,7 @@ REGISTER_CONVERTER(SUB, ConvertElementwise)
REGISTER_CONVERTER(SQUEEZE, ConvertSqueeze)
REGISTER_CONVERTER(SPLIT, ConvertSplit)
REGISTER_CONVERTER(SLICE, ConvertSlice)
REGISTER_CONVERTER(SWISH, ConvertUnaryActivations)
REGISTER_CONVERTER(TANH, ConvertUnaryActivations)
REGISTER_CONVERTER(TRANSPOSE, ConvertTranspose)
REGISTER_CONVERTER(UNSQUEEZE, ConvertUnsqueeze)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "operation/reduce.h"
#include "driver/verisilicon_timvx/converter/converter.h"
#include "utility/debug.h"
#include "utility/logging.h"

namespace nnadapter {
namespace verisilicon_timvx {

// Converts an NNAdapter REDUCE_MEAN / REDUCE_SUM / REDUCE_MAX operation into
// the matching tim-vx reduce operator and binds its input/output tensors.
//
// `converter` supplies the tim-vx graph and the operand-to-tensor mapping;
// `operation` is the NNAdapter operation being converted.
// Returns NNADAPTER_NO_ERROR once the switch below has been executed.
int ConvertReduce(Converter* converter, core::Operation* operation) {
  // Presumably declares input_operand, output_operand, axes_data, axes_size
  // and keep_dim, which are used below without a local declaration -
  // TODO confirm against operation/reduce.h.
  REDUCE_OPERATION_EXTRACT_INPUTS_OUTPUTS

  // Reuse the tim-vx tensor already mapped to the input operand; convert the
  // operand only when no mapping exists yet.
  auto input_tensor = converter->GetMappedTensor(input_operand);
  if (!input_tensor) {
    input_tensor = converter->ConvertOperand(input_operand);
  }
  auto output_tensor = converter->ConvertOperand(output_operand);

  // NOTE(review): every axis is shifted by a fixed -2 before being handed to
  // tim-vx. This looks like it assumes a specific input rank/layout and
  // non-negative axes - confirm against the tim-vx axis convention.
  std::vector<int32_t> timvx_axes;
  for (int idx = 0; idx < axes_size; ++idx) {
    timvx_axes.push_back(axes_data[idx] - 2);
  }

  switch (operation->type) {
#define CONVERT_REDUCE(op_type, timvx_class)                            \
  case NNADAPTER_##op_type: {                                           \
    auto timvx_op =                                                     \
        converter->graph()->CreateOperation<tim::vx::ops::timvx_class>( \
            timvx_axes, keep_dim);                                      \
    timvx_op->BindInputs({input_tensor});                               \
    timvx_op->BindOutputs({output_tensor});                             \
  } break;
    CONVERT_REDUCE(REDUCE_MEAN, ReduceMean);
    CONVERT_REDUCE(REDUCE_SUM, ReduceSum);
    CONVERT_REDUCE(REDUCE_MAX, ReduceMax);
#undef CONVERT_REDUCE
    default:
      // Any other operation type is reported through the FATAL log channel.
      NNADAPTER_LOG(FATAL) << "Unsupported reduce operation type "
                           << OperationTypeToString(operation->type)
                           << " is found.";
      break;
  }
  return NNADAPTER_NO_ERROR;
}

} // namespace verisilicon_timvx
} // namespace nnadapter
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ int ConvertUnaryActivations(Converter* converter, core::Operation* operation) {
CONVERT_UNARY_ACTIVATION(RELU, Relu);
CONVERT_UNARY_ACTIVATION(RELU6, Relu6);
CONVERT_UNARY_ACTIVATION(SIGMOID, Sigmoid);
CONVERT_UNARY_ACTIVATION(SWISH, Swish);
CONVERT_UNARY_ACTIVATION(TANH, Tanh);
#undef CONVERT_UNARY_ACTIVATION
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "optimizer/fuse_conv2d_batch_norm_into_conv2d.h"
#include "optimizer/fuse_matmul_add_into_fully_connected.h"
#include "optimizer/fuse_reshape_transpose_reshape_into_channel_shuffle.h"
#include "optimizer/fuse_sigmoid_mul_into_swish.h"
#include "utility/debug.h"
#include "utility/logging.h"
#include "utility/modeling.h"
Expand Down Expand Up @@ -122,6 +123,7 @@ int Program::Build(core::Model* model, core::Cache* cache) {
FuseConv2DActivationIntoConv2D(model);
FuseMatMulAddIntoFullyConnected(model);
FuseReshapeTransposeReshapeIntoChannelShuffle(model);
FuseSigmoidMulIntoSwish(model);
ConvertFillLikeIntoMulAdd(model);
ConstantFoldOperations(model);
UnpackOpFusion(model);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "optimizer/fuse_sigmoid_mul_into_swish.h"
#include <algorithm>
#include <map>
#include <vector>
#include "optimizer/pattern_matcher.h"
#include "utility/debug.h"
#include "utility/logging.h"
#include "utility/micros.h"
#include "utility/modeling.h"
#include "utility/utility.h"

namespace nnadapter {

// Pattern matcher that looks for a SIGMOID operation feeding a MUL operation
// together with the SIGMOID's own input, and replaces the matched pair with a
// single SWISH operation.
class SigmoidMulFuser : public PatternMatcher {
 public:
  SigmoidMulFuser() = default;

  // Declares the operation/operand patterns and how they are connected.
  void BuildPattern() override;

  // Invoked with the nodes of one matched subgraph, keyed by pattern name.
  // Returning true lets the matched intermediate operands and operations be
  // deleted.
  bool HandleMatchedResults(core::Model* model,
                            const std::map<std::string, Node*>& nodes) override;
};

void SigmoidMulFuser::BuildPattern() {
// Operation patterns
auto sigmoid_pattern =
CreatePattern("sigmoid", NNADAPTER_SIGMOID)->IsIntermediate();
auto mul_pattern = CreatePattern("mul", NNADAPTER_MUL)->IsIntermediate();
// Operand patterns
auto sigmoid_input_pattern =
CreatePattern("sigmoid_input")
->IsOperationInputOperand(NNADAPTER_SIGMOID, 0)
->IsOperationInputOperand(NNADAPTER_MUL, 0);
auto sigmoid_output_pattern =
CreatePattern("sigmoid_output")
->IsOperationOutputOperand(NNADAPTER_SIGMOID, 0)
->IsOperationInputOperand(NNADAPTER_MUL, 1)
->IsIntermediate();
auto mul_fuse_code_pattern = CreatePattern("mul_fuse_code")
->IsOperationInputOperand(NNADAPTER_MUL, 2)
->IsIntermediate();
auto mul_output_pattern =
CreatePattern("mul_output")->IsOperationOutputOperand(NNADAPTER_MUL, 0);
// Create the topological connections for the above patterns
std::vector<Pattern*> mul_input_patterns{
sigmoid_input_pattern, mul_fuse_code_pattern, sigmoid_output_pattern};
*sigmoid_input_pattern >> *sigmoid_pattern >> *sigmoid_output_pattern;
mul_input_patterns >> *mul_pattern >> *mul_output_pattern;
}

// Replaces one matched SIGMOID + MUL subgraph with a single SWISH operation.
//
// `model` receives the new operation; `nodes` maps pattern names to the
// matched graph nodes. Always returns true, which allows the matched
// intermediate operands and operations to be deleted.
bool SigmoidMulFuser::HandleMatchedResults(
    core::Model* model, const std::map<std::string, Node*>& nodes) {
  // Only the boundary operands of the matched subgraph are kept.
  auto swish_input = nodes.at("sigmoid_input")->operand;
  auto swish_output = nodes.at("mul_output")->operand;

  // NOTE(review): the MUL's fuse-code operand ("mul_fuse_code") is not
  // inspected here, so any activation fused into the MUL would be dropped by
  // this rewrite - confirm that only an unfused MUL can reach this point.
  auto* swish = AddOperation(model);
  swish->type = NNADAPTER_SWISH;
  swish->input_operands = {swish_input};
  swish->output_operands = {swish_output};

  return true;
}

// Applies SigmoidMulFuser to `model` repeatedly, using a freshly constructed
// matcher for every pass, and stops after the first pass whose Apply() call
// returns 0 (presumably meaning nothing was fused - TODO confirm against
// PatternMatcher::Apply).
NNADAPTER_EXPORT void FuseSigmoidMulIntoSwish(core::Model* model) {
  NNADAPTER_VLOG(5) << "Apply SigmoidMulFuser";
  for (;;) {
    SigmoidMulFuser fuser;
    if (fuser.Apply(model) == 0) {
      break;
    }
  }
}

} // namespace nnadapter
10 changes: 6 additions & 4 deletions lite/kernels/nnadapter/converter/all.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ REGISTER_CONVERTER(log,
REGISTER_CONVERTER(swish,
ConvertUnaryActivations,
"huawei_ascend_npu,huawei_kirin_npu,nvidia_tensorrt,intel_"
"openvino,qualcomm_qnn,kunlunxin_xtcl");
"openvino,qualcomm_qnn,kunlunxin_xtcl,verisilicon_timvx");
REGISTER_CONVERTER(
prelu,
ConvertPRelu,
Expand Down Expand Up @@ -311,13 +311,15 @@ REGISTER_CONVERTER(
reduce_mean,
ConvertReduce,
"huawei_ascend_npu,cambricon_mlu,huawei_kirin_npu,intel_openvino,"
"kunlunxin_xtcl,qualcomm_qnn");
REGISTER_CONVERTER(reduce_max, ConvertReduce, "intel_openvino,qualcomm_qnn");
"kunlunxin_xtcl,qualcomm_qnn,verisilicon_timvx");
REGISTER_CONVERTER(reduce_max,
ConvertReduce,
"intel_openvino,qualcomm_qnn,verisilicon_timvx");
REGISTER_CONVERTER(
reduce_sum,
ConvertReduce,
"huawei_ascend_npu,cambricon_mlu,huawei_kirin_npu,nvidia_tensorrt,"
"kunlunxin_xtcl,qualcomm_qnn");
"kunlunxin_xtcl,qualcomm_qnn,verisilicon_timvx");
REGISTER_CONVERTER(top_k,
ConvertTopK,
"huawei_ascend_npu,cambricon_mlu,kunlunxin_xtcl");
Expand Down

0 comments on commit 6b88c17

Please sign in to comment.