Skip to content

Commit

Permalink
multithread_memory_optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
JZZ-NOTE committed Dec 7, 2021
1 parent cf2c4ec commit bf33242
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ typedef struct {
// used.
void MemoryOptimizePass::CollectLifeCycle(
Graph* graph, std::unordered_map<std::string, lifecycle_t>* lifecycles,
int sort_kind, int max_lifecycle) const {
max_lifecycle = 0;
int sort_kind) const {
int max_lifecycle = 0;
for (auto* op_node : framework::ir::TopologyVarientSort(
*graph, static_cast<framework::ir::SortKind>(sort_kind))) {
if (!op_node->IsOp()) continue;
Expand Down Expand Up @@ -304,19 +304,18 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
// 3. Perform reuse plan: Replace all var's name in the model according to the
// mapping table.
if (!argument->enable_memory_optim()) return;
// Because of pass is a singleton, graph and max_lifecycle can not be member
// Because of pass is a singleton, graph can not be member
// variables,otherwise,errors will be caused under multithreading
// conditions.
auto graph = argument->main_graph_ptr();
int max_lifecycle = -1;

int sort_kind = 0;
std::unordered_map<std::string, lifecycle_t> lifecycles;
space_table_t space_table;
std::unordered_map<std::string, std::string> node2cluster;
std::unordered_map<std::string, int> cluster_size;

CollectLifeCycle(graph, &lifecycles, sort_kind, max_lifecycle);
CollectLifeCycle(graph, &lifecycles, sort_kind);
CollectVarMemorySize(graph, &space_table);
MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size);
UpdateOpDescsByReuse(graph, node2cluster, sort_kind);
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/inference/analysis/passes/memory_optimize_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ class MemoryOptimizePass : public AnalysisPass {
private:
void CollectLifeCycle(
framework::ir::Graph *graph,
std::unordered_map<std::string, lifecycle_t> *lifecycles, int sort_kind,
int max_lifecycle) const;
std::unordered_map<std::string, lifecycle_t> *lifecycles,
int sort_kind) const;

void CollectVarMemorySize(framework::ir::Graph *graph,
space_table_t *space_table) const;
Expand Down

1 comment on commit bf33242

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.