diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 30fdb90ce1069..4c7578c010473 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -22,6 +22,13 @@ #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif + +PADDLE_DEFINE_EXPORTED_bool( + cache_inference_while_scope, + false, + "Cache the scope of the while op to avoid repeated creation of the scope " + "for each iteration and improve inference performance."); + namespace paddle { namespace framework { class InferShapeContext; @@ -257,14 +264,23 @@ class WhileOp : public framework::OperatorBase { scope.FindVar(Input(kCondition))->Get()); } } else { - auto ¤t_scope = scope.NewScope(); - - BuildScopeForControlFlowOp(*core_, *block, ¤t_scope); - core_->reset_scope(¤t_scope); + framework::Scope *current_scope = nullptr; + if (!FLAGS_cache_inference_while_scope) { + current_scope = &(scope.NewScope()); + BuildScopeForControlFlowOp(*core_, *block, current_scope); + core_->reset_scope(current_scope); + } else { + if (cached_inference_scope_ == nullptr) { + cached_inference_scope_ = &(scope.NewScope()); + BuildScopeForControlFlowOp(*core_, *block, cached_inference_scope_); + core_->reset_scope(cached_inference_scope_); + } + current_scope = cached_inference_scope_; + } while (cond_data) { - for (auto &name : current_scope.LocalVarNames()) { - auto *var = current_scope.Var(name); + for (auto &name : current_scope->LocalVarNames()) { + auto *var = current_scope->Var(name); if (var->IsType()) { // Clear all lod information for all lod_tensors. auto *t = var->GetMutable(); @@ -283,7 +299,9 @@ class WhileOp : public framework::OperatorBase { scope.FindVar(Input(kCondition))->Get()); } - scope.DeleteScope(¤t_scope); + if (!FLAGS_cache_inference_while_scope) { + scope.DeleteScope(current_scope); + } } } @@ -291,6 +309,7 @@ class WhileOp : public framework::OperatorBase { mutable std::shared_ptr executor_{nullptr}; mutable std::unique_ptr ctx_{nullptr}; mutable std::shared_ptr core_{nullptr}; + mutable framework::Scope *cached_inference_scope_{nullptr}; }; class WhileOpMaker : public framework::OpProtoAndCheckerMaker {