Skip to content

Commit

Permalink
[XPU]. Fixed the bug in xpu kernel pick cross subgraphs.
Browse files Browse the repository at this point in the history
  • Loading branch information
wbn03 committed Sep 7, 2022
1 parent 51d3dbd commit 42131c2
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,12 @@ void XPUStaticKernelPickPass::InplaceOpScore(lite::mir::Node* node,
CHECK(instruct.op_info()->GetInputArgname(var_name, &tmp));
VLOG(6) << "current kernel input data variable name:" << var_name
<< "Parameter name:" << tmp;
if (in_node->inlinks.empty()) {
if (in_node->inlinks.empty() && xpu_output_type_.count(var_name) == 0) {
continue;
}

// only to match input X
if (tmp != "X") {
continue;
}

Expand Down Expand Up @@ -549,7 +554,7 @@ void XPUStaticKernelPickPass::InplaceOpScore(lite::mir::Node* node,
const auto& var_name = var.name;
std::string tmp;
CHECK(instruct.op_info()->GetOutputArgname(var_name, &tmp));
if (out_node->outlinks.empty()) {
if (out_node->outlinks.empty() && xpu_input_type_.count(var_name) == 0) {
continue;
}

Expand Down Expand Up @@ -584,7 +589,7 @@ void XPUStaticKernelPickPass::SpecialOpScore(lite::mir::Node* node,
const auto& var_name = var.name;
std::string tmp;
CHECK(instruct.op_info()->GetInputArgname(var_name, &tmp));
if (in_node->inlinks.empty()) {
if (in_node->inlinks.empty() && xpu_output_type_.count(var_name) == 0) {
if (kernel.GetInputDeclType(tmp)->precision() == PrecisionType::kFP16) {
*score = 0;
VLOG(6) << "not pick fp16 kernel ,because input weight "
Expand All @@ -601,7 +606,7 @@ void XPUStaticKernelPickPass::SpecialOpScore(lite::mir::Node* node,
const auto& var_name = var.name;
std::string tmp;
CHECK(instruct.op_info()->GetInputArgname(var_name, &tmp));
if (in_node->inlinks.empty()) {
if (in_node->inlinks.empty() && xpu_output_type_.count(var_name) == 0) {
continue;
}

Expand Down Expand Up @@ -644,7 +649,7 @@ void XPUStaticKernelPickPass::SpecialOpScore(lite::mir::Node* node,
std::string tmp;
CHECK(instruct.op_info()->GetOutputArgname(var_name, &tmp));
int output_match_num = xpu_input_type_.count(var_name);
if (out_node->outlinks.empty()) {
if (out_node->outlinks.empty() && output_match_num == 0) {
continue;
}

Expand Down

0 comments on commit 42131c2

Please sign in to comment.