-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Cherry-pick][NNAdapter][TIM-VX] fuse sigmoid&mul into swish (#9623)
- Loading branch information
1 parent
d6a6d56
commit 6b88c17
Showing
7 changed files
with
183 additions
and
4 deletions.
There are no files selected for viewing
23 changes: 23 additions & 0 deletions
23
lite/backends/nnadapter/nnadapter/include/nnadapter/optimizer/fuse_sigmoid_mul_into_swish.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "core/types.h" | ||
|
||
namespace nnadapter { | ||
|
||
// Rewrites `model` in place: each matched SIGMOID + MUL pair, where the | ||
// SIGMOID input is also MUL input 0 and the SIGMOID output is MUL input 1, | ||
// is replaced by a single NNADAPTER_SWISH operation. | ||
void FuseSigmoidMulIntoSwish(core::Model *model); | ||
|
||
} // namespace nnadapter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
59 changes: 59 additions & 0 deletions
59
lite/backends/nnadapter/nnadapter/src/driver/verisilicon_timvx/converter/reduce.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "operation/reduce.h" | ||
#include "driver/verisilicon_timvx/converter/converter.h" | ||
#include "utility/debug.h" | ||
#include "utility/logging.h" | ||
|
||
namespace nnadapter { | ||
namespace verisilicon_timvx { | ||
|
||
// Converts an NNAdapter reduce operation (REDUCE_MEAN, REDUCE_SUM or | ||
// REDUCE_MAX) into the matching tim-vx reduce operator and binds its input | ||
// and output tensors. Any other operation type is reported through a FATAL | ||
// log. The function's final statement returns NNADAPTER_NO_ERROR. | ||
int ConvertReduce(Converter* converter, core::Operation* operation) { | ||
// Presumably declares input_operand, output_operand, axes_data, axes_size | ||
// and keep_dim used below; the macro is defined elsewhere - TODO confirm. | ||
REDUCE_OPERATION_EXTRACT_INPUTS_OUTPUTS | ||
|
||
// Convert to tim-vx tensors and operators | ||
// Reuse the tensor already mapped for the input operand, converting it | ||
// only when no mapping exists yet. | ||
auto input_tensor = converter->GetMappedTensor(input_operand); | ||
if (!input_tensor) { | ||
input_tensor = converter->ConvertOperand(input_operand); | ||
} | ||
auto output_tensor = converter->ConvertOperand(output_operand); | ||
std::vector<int32_t> axis; | ||
// NOTE(review): every axis is shifted by -2 before being passed to tim-vx. | ||
// The reason is not visible here (possibly assumes a 4-D input reduced over | ||
// its last two dimensions) - confirm against tim-vx's axis convention. | ||
for (int i = 0; i < axes_size; i++) { | ||
axis.push_back(axes_data[i] - 2); | ||
} | ||
switch (operation->type) { | ||
// CONVERT_REDUCE(type, class_name) expands to one `case NNADAPTER_<type>` | ||
// that creates tim::vx::ops::<class_name>(axis, keep_dim) and binds the | ||
// input/output tensors. | ||
#define CONVERT_REDUCE(type, class_name) \ | ||
case NNADAPTER_##type: { \ | ||
auto reduce_op = \ | ||
converter->graph()->CreateOperation<tim::vx::ops::class_name>( \ | ||
axis, keep_dim); \ | ||
reduce_op->BindInputs({input_tensor}); \ | ||
reduce_op->BindOutputs({output_tensor}); \ | ||
} break; | ||
CONVERT_REDUCE(REDUCE_MEAN, ReduceMean); | ||
CONVERT_REDUCE(REDUCE_SUM, ReduceSum); | ||
CONVERT_REDUCE(REDUCE_MAX, ReduceMax); | ||
#undef CONVERT_REDUCE | ||
default: | ||
NNADAPTER_LOG(FATAL) << "Unsupported reduce operation type " | ||
<< OperationTypeToString(operation->type) | ||
<< " is found."; | ||
break; | ||
} | ||
return NNADAPTER_NO_ERROR; | ||
} | ||
|
||
} // namespace verisilicon_timvx | ||
} // namespace nnadapter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
88 changes: 88 additions & 0 deletions
88
lite/backends/nnadapter/nnadapter/src/optimizer/fuse_sigmoid_mul_into_swish.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "optimizer/fuse_sigmoid_mul_into_swish.h" | ||
#include <algorithm> | ||
#include <map> | ||
#include <vector> | ||
#include "optimizer/pattern_matcher.h" | ||
#include "utility/debug.h" | ||
#include "utility/logging.h" | ||
#include "utility/micros.h" | ||
#include "utility/modeling.h" | ||
#include "utility/utility.h" | ||
|
||
namespace nnadapter { | ||
|
||
// PatternMatcher specialization that recognizes a SIGMOID feeding a MUL | ||
// together with the SIGMOID's own input, and replaces the pair with a single | ||
// NNADAPTER_SWISH operation. | ||
class SigmoidMulFuser : public PatternMatcher { | ||
public: | ||
SigmoidMulFuser() {} | ||
// Declares the operation/operand patterns and how they connect. | ||
void BuildPattern() override; | ||
// Called with the nodes of a matched subgraph, keyed by pattern name. | ||
bool HandleMatchedResults(core::Model* model, | ||
const std::map<std::string, Node*>& nodes) override; | ||
}; | ||
|
||
// Describes the subgraph to match: | ||
//   sigmoid_input -> SIGMOID -> sigmoid_output | ||
//   {sigmoid_input, sigmoid_output, mul_fuse_code} -> MUL -> mul_output | ||
// Patterns marked IsIntermediate() are the ones expected to disappear once | ||
// the match is fused. | ||
// NOTE(review): only the ordering (x as MUL input 0, sigmoid(x) as MUL | ||
// input 1) is declared, and the value of the fuse-code operand is not | ||
// constrained here - confirm that this is intended. | ||
void SigmoidMulFuser::BuildPattern() { | ||
// Operation patterns | ||
auto sigmoid_pattern = | ||
CreatePattern("sigmoid", NNADAPTER_SIGMOID)->IsIntermediate(); | ||
auto mul_pattern = CreatePattern("mul", NNADAPTER_MUL)->IsIntermediate(); | ||
// Operand patterns | ||
// The same operand must be input 0 of both the SIGMOID and the MUL. | ||
auto sigmoid_input_pattern = | ||
CreatePattern("sigmoid_input") | ||
->IsOperationInputOperand(NNADAPTER_SIGMOID, 0) | ||
->IsOperationInputOperand(NNADAPTER_MUL, 0); | ||
// The SIGMOID result must be input 1 of the MUL. | ||
auto sigmoid_output_pattern = | ||
CreatePattern("sigmoid_output") | ||
->IsOperationOutputOperand(NNADAPTER_SIGMOID, 0) | ||
->IsOperationInputOperand(NNADAPTER_MUL, 1) | ||
->IsIntermediate(); | ||
// Input 2 of the MUL (named as its fuse code) is matched as intermediate. | ||
auto mul_fuse_code_pattern = CreatePattern("mul_fuse_code") | ||
->IsOperationInputOperand(NNADAPTER_MUL, 2) | ||
->IsIntermediate(); | ||
auto mul_output_pattern = | ||
CreatePattern("mul_output")->IsOperationOutputOperand(NNADAPTER_MUL, 0); | ||
// Create the topological connections for the above patterns | ||
std::vector<Pattern*> mul_input_patterns{ | ||
sigmoid_input_pattern, mul_fuse_code_pattern, sigmoid_output_pattern}; | ||
*sigmoid_input_pattern >> *sigmoid_pattern >> *sigmoid_output_pattern; | ||
mul_input_patterns >> *mul_pattern >> *mul_output_pattern; | ||
} | ||
|
||
// Replaces one matched SIGMOID + MUL subgraph with a SWISH operation. | ||
// `nodes` maps the pattern names declared in BuildPattern() to the matched | ||
// graph nodes. Always returns true. | ||
// NOTE(review): the matched "mul_fuse_code" operand is never inspected here, | ||
// so any activation fused into the MUL would presumably be dropped - confirm | ||
// that only a "no fused activation" code can reach this point. | ||
bool SigmoidMulFuser::HandleMatchedResults( | ||
core::Model* model, const std::map<std::string, Node*>& nodes) { | ||
// Get the operands and operations from the matched subgraph nodes. | ||
auto sigmoid_input_operand = nodes.at("sigmoid_input")->operand; | ||
auto mul_output_operand = nodes.at("mul_output")->operand; | ||
// Create a new NNADAPTER_SWISH operation and replace the matched | ||
// subgraph nodes. | ||
// The SWISH reads the original SIGMOID input and writes straight to the | ||
// operand that used to be the MUL output. | ||
auto* swish_operation = AddOperation(model); | ||
swish_operation->type = NNADAPTER_SWISH; | ||
swish_operation->input_operands = {sigmoid_input_operand}; | ||
swish_operation->output_operands = {mul_output_operand}; | ||
// The matched intermediate operands and operations will be deleted only when | ||
// it returns true. | ||
return true; | ||
} | ||
|
||
// Entry point of the pass: applies a freshly constructed SigmoidMulFuser to | ||
// `model` over and over until Apply() returns 0. | ||
// Presumably Apply() returns the number of subgraphs it fused, so the loop | ||
// ends once nothing more matches - TODO confirm against PatternMatcher. | ||
NNADAPTER_EXPORT void FuseSigmoidMulIntoSwish(core::Model* model) { | ||
NNADAPTER_VLOG(5) << "Apply SigmoidMulFuser"; | ||
bool stop; | ||
do { | ||
// A new fuser is built for every pass over the model. | ||
SigmoidMulFuser sigmoid_mul_fuser; | ||
stop = sigmoid_mul_fuser.Apply(model) == 0; | ||
} while (!stop); | ||
} | ||
|
||
} // namespace nnadapter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters