Skip to content

Commit

Permalink
clone
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang committed Jun 14, 2023
1 parent d06a1f3 commit 05b2c5f
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 7 deletions.
2 changes: 1 addition & 1 deletion lite/api/cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class LITE_API Predictor {
std::map<TargetType, std::shared_ptr<void>> target_configs_;
std::shared_ptr<cpp::ProgramDesc> program_desc_;
std::shared_ptr<Scope> scope_;
Scope* exec_scope_;
Scope* exec_scope_{nullptr};
std::shared_ptr<RuntimeProgram> program_;
bool program_generated_{false};
std::vector<std::string> input_names_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1906,7 +1906,7 @@ class XPUMultiEncoderFuser {
if (is_qkv_already_fusion_) {
end = i + 1;
}
scope->NewTensor(update_tag);
scope->MutableParent()->NewTensor(update_tag);
// Update weight, including tranpose\convert type\fuse qkv
// weight\findmax.
update_weight(scope,
Expand Down
20 changes: 16 additions & 4 deletions lite/core/optimizer/mir/type_target_cast_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,15 @@ void TypeTargetTransformPass::AddInputIoCopyInst(
// So there will be a new Argument node and a new IoCopy Statement Node.

CHECK(in->IsArg());
auto io_copy_output_name =
string_format("%s/target_trans", in->AsArg().name.c_str());
bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
std::string io_copy_output_name;
if (in_persist) {
io_copy_output_name =
string_format("%s/target_trans_persistable", in->AsArg().name.c_str());
} else {
io_copy_output_name =
string_format("%s/target_trans", in->AsArg().name.c_str());
}

if (copied_nodes->count(in->AsArg().name)) {
// Remove the old link
Expand All @@ -292,7 +299,13 @@ void TypeTargetTransformPass::AddInputIoCopyInst(
// TODO(MyPandaShaoxiang) should set same place with input?
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
// Create the new var manually.
auto new_var = inst_node->AsStmt().op()->scope()->Var(io_copy_output_name);
Variable* new_var = nullptr;
if (in_persist) {
new_var = inst_node->AsStmt().op()->scope()->MutableParent()->Var(
io_copy_output_name);
} else {
new_var = inst_node->AsStmt().op()->scope()->Var(io_copy_output_name);
}
// Set the place for io_copy_output_arg node, the target should be equal to
// to.target()
// The precision and layout should be equal to from.precision(),
Expand All @@ -316,7 +329,6 @@ void TypeTargetTransformPass::AddInputIoCopyInst(
}
auto* io_copy_inst = graph->NewInstructNode();

bool in_persist = in->AsArg().is_weight || in->AsArg().is_persist;
std::string io_copy_type = in_persist ? "io_copy_once" : "io_copy";
io_copy_output_arg->AsArg().is_persist = in_persist;
// create Op and kernels.
Expand Down
7 changes: 6 additions & 1 deletion lite/core/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,12 @@ void Program::PrepareWorkspace(
// Create tensors or weights from variable description.
if (!var_desc->Persistable()) {
vars_.push_back(var_name);
auto* var = exec_scope_->Var(var_name);
Variable* var = nullptr;
if (var_name.find("/target_trans_persistable") != std::string::npos) {
var = scope_->Var(var_name);
} else {
var = exec_scope_->Var(var_name);
}
if (var_type == lite::VarDescAPI::Type::LOD_TENSOR) {
const auto& var_data_type =
VarDescType2PrecisionType(var_desc->GetDataType());
Expand Down

0 comments on commit 05b2c5f

Please sign in to comment.