diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index a220fe18fb35d3..c68129bfb0a494 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -373,32 +373,31 @@ std::vector RunBackward( auto add_next_node_func = [&node_in_degree_map, &queue](GradNodeBase* next_node) { - if (node_in_degree_map[next_node] == 0) { - if (dynamic_cast(next_node)) { - queue.push_front(std::move(next_node)); - } else { - queue.push_back(std::move(next_node)); - } + if (dynamic_cast(next_node)) { + queue.push_front(std::move(next_node)); + } else { + queue.push_back(std::move(next_node)); } }; - - if (force_sequential_nodes_set.count(next_node)) { - if (force_sequential_nodes_queue.front() == next_node) { - force_sequential_nodes_queue.pop_front(); - add_next_node_func(next_node); - while (ready_force_sequential_nodes.count( - force_sequential_nodes_queue.front())) { - ready_force_sequential_nodes.erase( - force_sequential_nodes_queue.front()); - add_next_node_func(force_sequential_nodes_queue.front()); + if (node_in_degree_map[next_node] == 0) { + if (force_sequential_nodes_set.count(next_node)) { + if (force_sequential_nodes_queue.front() == next_node) { force_sequential_nodes_queue.pop_front(); + add_next_node_func(next_node); + while (ready_force_sequential_nodes.count( + force_sequential_nodes_queue.front())) { + ready_force_sequential_nodes.erase( + force_sequential_nodes_queue.front()); + add_next_node_func(force_sequential_nodes_queue.front()); + force_sequential_nodes_queue.pop_front(); + } + } else { + ready_force_sequential_nodes.insert(next_node); + continue; } } else { - ready_force_sequential_nodes.insert(next_node); - continue; + add_next_node_func(next_node); } - } else { - add_next_node_func(next_node); } } }