Skip to content

Commit

Permalink
cache scope in while (#52628)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang authored Apr 12, 2023
1 parent cea62c0 commit 8e7c378
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions paddle/fluid/operators/controlflow/while_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

PADDLE_DEFINE_EXPORTED_bool(
cache_inference_while_scope,
false,
"Cache the scope of the while op to avoid repeated creation of the scope "
"for each iteration and improve inference performance.");

namespace paddle {
namespace framework {
class InferShapeContext;
Expand Down Expand Up @@ -257,14 +264,23 @@ class WhileOp : public framework::OperatorBase {
scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
}
} else {
auto &current_scope = scope.NewScope();

BuildScopeForControlFlowOp(*core_, *block, &current_scope);
core_->reset_scope(&current_scope);
framework::Scope *current_scope = nullptr;
if (!FLAGS_cache_inference_while_scope) {
current_scope = &(scope.NewScope());
BuildScopeForControlFlowOp(*core_, *block, current_scope);
core_->reset_scope(current_scope);
} else {
if (cached_inference_scope_ == nullptr) {
cached_inference_scope_ = &(scope.NewScope());
BuildScopeForControlFlowOp(*core_, *block, cached_inference_scope_);
core_->reset_scope(cached_inference_scope_);
}
current_scope = cached_inference_scope_;
}

while (cond_data) {
for (auto &name : current_scope.LocalVarNames()) {
auto *var = current_scope.Var(name);
for (auto &name : current_scope->LocalVarNames()) {
auto *var = current_scope->Var(name);
if (var->IsType<phi::DenseTensor>()) {
// Clear all lod information for all lod_tensors.
auto *t = var->GetMutable<phi::DenseTensor>();
Expand All @@ -283,14 +299,17 @@ class WhileOp : public framework::OperatorBase {
scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
}

scope.DeleteScope(&current_scope);
if (!FLAGS_cache_inference_while_scope) {
scope.DeleteScope(current_scope);
}
}
}

private:
mutable std::shared_ptr<framework::Executor> executor_{nullptr};
mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx_{nullptr};
mutable std::shared_ptr<framework::InterpreterCore> core_{nullptr};
mutable framework::Scope *cached_inference_scope_{nullptr};
};

class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down

0 comments on commit 8e7c378

Please sign in to comment.