From ace62b5a203e71a3779db4f9922ed357e4ef2f32 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Fri, 7 Apr 2023 05:31:34 +0000 Subject: [PATCH 1/2] while reset_scope --- .../fluid/operators/controlflow/while_op.cc | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 3017a1e0fc4b79..7dc6cbf9d885b9 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -271,18 +271,19 @@ class WhileOp : public framework::OperatorBase { scope.FindVar(Input(kCondition))->Get()); } } else { - auto ¤t_scope = scope.NewScope(); - - if (FLAGS_control_flow_use_new_executor) { - BuildScopeForControlFlowOp(*core_, *block, ¤t_scope); - core_->reset_scope(¤t_scope); - } else { - executor_->CreateVariables(*program, ¤t_scope, block->ID()); + if (inference_scope_ == nullptr) { + inference_scope_ = &(scope.NewScope()); + if (FLAGS_control_flow_use_new_executor) { + BuildScopeForControlFlowOp(*core_, *block, inference_scope_); + core_->reset_scope(inference_scope_); + } else { + executor_->CreateVariables(*program, inference_scope_, block->ID()); + } } while (cond_data) { - for (auto &name : current_scope.LocalVarNames()) { - auto *var = current_scope.Var(name); + for (auto &name : inference_scope_->LocalVarNames()) { + auto *var = inference_scope_->Var(name); if (var->IsType()) { // Clear all lod information for all lod_tensors. auto *t = var->GetMutable(); @@ -299,14 +300,12 @@ class WhileOp : public framework::OperatorBase { core_->Run({}, false); } else { executor_->RunPreparedContext( - ctx_.get(), ¤t_scope, false, false, false); + ctx_.get(), inference_scope_, false, false, false); } cond_data = GetCondData( scope.FindVar(Input(kCondition))->Get()); } - - scope.DeleteScope(¤t_scope); } } @@ -314,6 +313,7 @@ class WhileOp : public framework::OperatorBase { mutable std::shared_ptr executor_{nullptr}; mutable std::unique_ptr ctx_{nullptr}; mutable std::shared_ptr core_{nullptr}; + mutable framework::Scope *inference_scope_{nullptr}; }; class WhileOpMaker : public framework::OpProtoAndCheckerMaker { From 9ff85b605f397490210b7ced2d7950383731c91d Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Tue, 11 Apr 2023 11:31:05 +0000 Subject: [PATCH 2/2] test --- .../fluid/operators/controlflow/while_op.cc | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 3a6f39c261abc9..4c7578c0104739 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -22,6 +22,13 @@ #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif + +PADDLE_DEFINE_EXPORTED_bool( + cache_inference_while_scope, + false, + "Cache the scope of the while op to avoid repeated creation of the scope " + "for each iteration and improve inference performance."); + namespace paddle { namespace framework { class InferShapeContext; @@ -257,15 +264,23 @@ class WhileOp : public framework::OperatorBase { scope.FindVar(Input(kCondition))->Get()); } } else { - if (inference_scope_ == nullptr) { - inference_scope_ = &(scope.NewScope()); - BuildScopeForControlFlowOp(*core_, *block, inference_scope_); - core_->reset_scope(inference_scope_); + framework::Scope *current_scope = nullptr; + if (!FLAGS_cache_inference_while_scope) { + current_scope = &(scope.NewScope()); + BuildScopeForControlFlowOp(*core_, *block, current_scope); + core_->reset_scope(current_scope); + } else { + if (cached_inference_scope_ == nullptr) { + cached_inference_scope_ = &(scope.NewScope()); + BuildScopeForControlFlowOp(*core_, *block, cached_inference_scope_); + core_->reset_scope(cached_inference_scope_); + } + current_scope = cached_inference_scope_; } while (cond_data) { - for (auto &name : inference_scope_->LocalVarNames()) { - auto *var = inference_scope_->Var(name); + for (auto &name : current_scope->LocalVarNames()) { + auto *var = current_scope->Var(name); if (var->IsType()) { // Clear all lod information for all lod_tensors. auto *t = var->GetMutable(); @@ -283,6 +298,10 @@ class WhileOp : public framework::OperatorBase { cond_data = GetCondData( scope.FindVar(Input(kCondition))->Get()); } + + if (!FLAGS_cache_inference_while_scope) { + scope.DeleteScope(current_scope); + } } } @@ -290,7 +309,7 @@ class WhileOp : public framework::OperatorBase { mutable std::shared_ptr executor_{nullptr}; mutable std::unique_ptr ctx_{nullptr}; mutable std::shared_ptr core_{nullptr}; - mutable framework::Scope *inference_scope_{nullptr}; + mutable framework::Scope *cached_inference_scope_{nullptr}; }; class WhileOpMaker : public framework::OpProtoAndCheckerMaker {