Support stream priority for standalone executor (#49939)
* Support stream priority for standalone executor

* Fix compile error

* Fix compile error

* Fix compile error

* Fix compile error

* Fix compile error
From00 authored Jan 30, 2023
1 parent f12f2a9 commit 172d1de
Showing 20 changed files with 206 additions and 95 deletions.
9 changes: 9 additions & 0 deletions paddle/fluid/distributed/auto_parallel/dist_attr.cc
@@ -293,6 +293,7 @@ std::vector<std::string> OperatorDistAttr::fields_{"process_mesh",
"impl_idx",
"is_recompute",
"execution_stream",
"stream_priority",
"scheduling_priority"};

OperatorDistAttr::OperatorDistAttr(const OpDesc& op) {
@@ -318,6 +319,8 @@ OperatorDistAttr& OperatorDistAttr::operator=(
std::swap(this->impl_idx_, tmp.impl_idx_);
std::swap(this->is_recompute_, tmp.is_recompute_);
std::swap(this->execution_stream_, tmp.execution_stream_);
std::swap(this->stream_priority_, tmp.stream_priority_);
std::swap(this->scheduling_priority_, tmp.scheduling_priority_);
std::swap(this->annotated_, tmp.annotated_);
// Note: Make sure all tensor dist attr has the same process_mesh
set_process_mesh(this->process_mesh_);
@@ -349,6 +352,7 @@ void OperatorDistAttr::initialize(const OpDesc* op) {
impl_idx_ = 0;
is_recompute_ = false;
execution_stream_ = kDefault;
stream_priority_ = 0;
scheduling_priority_ = 0;
}

@@ -361,6 +365,7 @@ void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) {
set_impl_idx(dist_attr.impl_idx());
set_is_recompute(dist_attr.is_recompute());
set_execution_stream(dist_attr.execution_stream());
set_stream_priority(dist_attr.stream_priority());
set_scheduling_priority(dist_attr.scheduling_priority());
set_annotated(dist_attr.annotated());
}
@@ -599,6 +604,7 @@ std::string OperatorDistAttr::to_string() const {
str += "{impl_type: " + impl_type_ + ", ";
str += "impl_idx: " + std::to_string(impl_idx_) + ", ";
str += "execution_stream: " + execution_stream_ + ", ";
str += "stream_priority: " + std::to_string(stream_priority_) + ", ";
str += "scheduling_priority: " + std::to_string(scheduling_priority_) + ", ";
str += "annotated: [" + str_join(annotated_) + "], ";
str += "\nprocess_mesh: " + process_mesh_.to_string() + ", ";
@@ -684,6 +690,9 @@ bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs) {
if (lhs.execution_stream() != rhs.execution_stream()) {
return false;
}
if (lhs.stream_priority() != rhs.stream_priority()) {
return false;
}
if (lhs.scheduling_priority() != rhs.scheduling_priority()) {
return false;
}
7 changes: 7 additions & 0 deletions paddle/fluid/distributed/auto_parallel/dist_attr.h
@@ -222,6 +222,12 @@ class OperatorDistAttr {
execution_stream_ = execution_stream;
}

int stream_priority() const { return stream_priority_; }

void set_stream_priority(int stream_priority) {
stream_priority_ = stream_priority;
}

int64_t scheduling_priority() const { return scheduling_priority_; }

void set_scheduling_priority(int64_t scheduling_priority) {
@@ -289,6 +295,7 @@ class OperatorDistAttr {
int64_t impl_idx_ = 0;
bool is_recompute_ = false;
std::string execution_stream_ = kDefault;
int stream_priority_ = 0; // lower value, higher priority
int64_t scheduling_priority_ = 0; // lower value, higher priority
std::map<std::string, bool> annotated_;
};
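Taken together, the new accessors suggest usage along the following lines. This is a hypothetical sketch, not code from this commit; the namespace and the `op_desc` setup are assumptions.

```cpp
// Hypothetical usage sketch; OperatorDistAttr is declared in
// paddle/fluid/distributed/auto_parallel/dist_attr.h (namespace assumed).
using paddle::distributed::auto_parallel::OperatorDistAttr;

void ConfigureOp(const paddle::framework::OpDesc& op_desc) {
  OperatorDistAttr dist_attr(op_desc);
  dist_attr.set_execution_stream("allreduce_stream");  // named custom stream
  dist_attr.set_stream_priority(-1);     // lower value, higher priority
  dist_attr.set_scheduling_priority(1);  // dispatched after priority-0 ops
}
```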
@@ -46,7 +46,8 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
platform::EmplaceDeviceContexts(
&fetch_ctxs_,
places,
/*disable_setting_default_stream_for_allocator=*/true);
/*disable_setting_default_stream_for_allocator=*/true,
/*stream_priority=*/0);
if (ir::IsTopologySortOperationsUnique(*graph_)) {
VLOG(10)
<< "Change thread number to 1 because the toposort order is unique";
@@ -41,7 +41,8 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
platform::EmplaceDeviceContexts(
&fetch_ctxs_,
places,
/*disable_setting_default_stream_for_allocator=*/true);
/*disable_setting_default_stream_for_allocator=*/true,
/*stream_priority=*/0);

if (strategy_.num_iteration_per_run_ > 1) {
int read_op_num = 0;
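Both legacy SSA graph executors above now thread the new `stream_priority` argument through `platform::EmplaceDeviceContexts`, keeping their fetch contexts at default priority. A minimal sketch of the updated call, with the map type of `fetch_ctxs_` assumed since its declaration is not part of this diff:

```cpp
// Sketch only; the exact declaration of fetch_ctxs_ is assumed here.
std::map<platform::Place,
         std::shared_future<std::unique_ptr<platform::DeviceContext>>>
    fetch_ctxs;
platform::EmplaceDeviceContexts(
    &fetch_ctxs,
    places,
    /*disable_setting_default_stream_for_allocator=*/true,
    /*stream_priority=*/0);  // 0 keeps the default-priority stream
```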
33 changes: 20 additions & 13 deletions paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
@@ -666,21 +666,28 @@ bool BuildOpFuncList(const platform::Place& place,
op_func_node.output_index = outs_name2id;

const OperatorDistAttr* dist_attr = block.Op(i)->DistAttr();
if (dist_attr &&
dist_attr->execution_stream() != distributed::auto_parallel::kDefault) {
op_func_node.execution_stream_ = dist_attr->execution_stream();
}

if (dist_attr) {
op_func_node.priority_ = dist_attr->scheduling_priority();
} else if (interpreter::IsCommunicationOp(op_type)) {
// NOTE(Ruibiao): Dispatching computation before communication improves
// multi-stream overlap when the time cost of communication less than that
// of the calculation (e.g., ResNet50_bs128_pure_fp16 N4C32 training).
op_func_node.priority_ = 1;
if (dist_attr->execution_stream() !=
distributed::auto_parallel::kDefault) {
op_func_node.execution_stream_ = dist_attr->execution_stream();
}
op_func_node.stream_priority_ = dist_attr->stream_priority();
op_func_node.scheduling_priority_ = dist_attr->scheduling_priority();
} else {
if (interpreter::IsCommunicationOp(op_type)) {
// NOTE(Ruibiao): Dispatching computation before communication improves
// multi-stream overlap when the time cost of communication less than
// that of the calculation (e.g., ResNet50_bs128_pure_fp16 N4C32
// training).
op_func_node.scheduling_priority_ = 1;
}
}
VLOG(6) << "scheduling priority of " << op_type << " : "
<< op_func_node.priority_;

VLOG(6) << op_type
<< " : [execution_stream, stream_priority, scheduling_priority] = ["
<< op_func_node.execution_stream_ << ", "
<< op_func_node.stream_priority_ << ", "
<< op_func_node.scheduling_priority_ << "]";

SingleStreamGuard single_stream_guard(ops[i]);

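The hunk above reworks how `BuildOpFuncList` fills the three fields: a `dist_attr`, when present, now supplies `execution_stream`, `stream_priority`, and `scheduling_priority` together; without one, communication ops are merely deprioritized. A distilled, standalone restatement of that precedence follows, using hypothetical names and types rather than Paddle API:

```cpp
#include <cstdint>
#include <string>

// Hypothetical distillation of the resolution order above. kDefaultTag stands
// in for distributed::auto_parallel::kDefault, whose value is not shown here.
constexpr const char* kDefaultTag = "Default";

struct Priorities {
  std::string execution_stream = kDefaultTag;
  int stream_priority = 0;          // lower value, higher priority
  int64_t scheduling_priority = 0;  // lower value, higher priority
};

Priorities Resolve(const Priorities* dist_attr, bool is_communication_op) {
  Priorities out;
  if (dist_attr != nullptr) {
    if (dist_attr->execution_stream != kDefaultTag) {
      out.execution_stream = dist_attr->execution_stream;
    }
    out.stream_priority = dist_attr->stream_priority;
    out.scheduling_priority = dist_attr->scheduling_priority;
  } else if (is_communication_op) {
    out.scheduling_priority = 1;  // dispatch computation before communication
  }
  return out;
}
```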
20 changes: 15 additions & 5 deletions paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc
@@ -39,7 +39,9 @@ class ContextManager {
}

std::shared_future<std::unique_ptr<DeviceContext>> Get(
const std::string& type, const platform::Place& place) {
const std::string& type,
const platform::Place& place,
int stream_priority) {
std::lock_guard<std::mutex> lk(ctx_mtx_);
VLOG(6) << "Get dev_ctx for " << type << " - " << place;

@@ -48,7 +50,8 @@
platform::EmplaceDeviceContexts(
&ctxs,
{place},
/*disable_setting_default_stream_for_allocator=*/true);
/*disable_setting_default_stream_for_allocator=*/true,
stream_priority);
}
return ctxs[place];
}
@@ -142,6 +145,7 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext(
auto& op = op_func_node.operator_base_;
auto& op_type = op->Type();
const std::string& execution_stream = op_func_node.execution_stream_;
const int stream_priority = op_func_node.stream_priority_;
ContextManager& ctx_manager = ContextManager::Instance();

// only gpu/npu need update. xpu not need, because xpu memcpy op kernel is
@@ -152,15 +156,21 @@
<< ", execution stream = " << execution_stream;
if (execution_stream != kDefaultStream) {
return ctx_manager
.Get(std::string(kCustomStream) + "-" + execution_stream, place_)
.Get(std::string(kCustomStream) + "-" + execution_stream,
place_,
stream_priority)
.get()
.get();
}

if (op_type == interpreter::kMemcpyD2H) {
return ctx_manager.Get(std::string(kD2HStream), place_).get().get();
return ctx_manager.Get(std::string(kD2HStream), place_, stream_priority)
.get()
.get();
} else if (op_type == interpreter::kMemcpyH2D) {
return ctx_manager.Get(std::string(kH2DStream), place_).get().get();
return ctx_manager.Get(std::string(kH2DStream), place_, stream_priority)
.get()
.get();
}

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
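`EmplaceDeviceContexts` itself is outside this diff, but the `stream_priority` passed here ultimately decides the priority of the GPU stream backing the context. For background, this is how a prioritized stream is created with the plain CUDA runtime API; a standalone sketch, not Paddle code. Note that CUDA follows the same convention: numerically lower means higher priority.

```cpp
#include <cstdio>

#include <cuda_runtime.h>

int main() {
  // Query the priority range supported by the device; "greatest" is the
  // numerically lowest (i.e., highest-priority) value.
  int least = 0, greatest = 0;
  cudaDeviceGetStreamPriorityRange(&least, &greatest);
  std::printf("stream priority range: least=%d greatest=%d\n", least, greatest);

  // Create a non-blocking stream at the highest available priority.
  cudaStream_t stream;
  cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, greatest);

  // ... launch kernels on `stream`; the scheduler favors them over
  // default-priority work ...

  cudaStreamDestroy(stream);
  return 0;
}
```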
16 changes: 9 additions & 7 deletions paddle/fluid/framework/new_executor/interpretercore.cc
@@ -139,13 +139,15 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
}
var_scope_.SetLocalScope(local_scope_);

instruction_prority_less = [this](size_t lhs, size_t rhs) {
Priority lhs_prority = vec_instruction_[lhs].GetPriority();
Priority rhs_prority = vec_instruction_[rhs].GetPriority();
if (lhs_prority == rhs_prority) {
instruction_scheduling_priority_less = [this](size_t lhs, size_t rhs) {
SchedulingPriority lhs_scheduling_priority =
vec_instruction_[lhs].GetSchedulingPriority();
SchedulingPriority rhs_scheduling_priority =
vec_instruction_[rhs].GetSchedulingPriority();
if (lhs_scheduling_priority == rhs_scheduling_priority) {
return lhs < rhs;
}
return lhs_prority > rhs_prority;
return lhs_scheduling_priority > rhs_scheduling_priority;
};

PrepareForCUDAGraphCapture();
@@ -1089,7 +1091,7 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
// scheduling, the priority order involved cross-thread scheduling is not
// guaranteed. Only Ops scheduled by the same AddTask call have the guarantee
// of priority order.
SchedulingQueue ready_ops(instruction_prority_less);
SchedulingQueue ready_ops(instruction_scheduling_priority_less);
ready_ops.push(instr_id);
while (!ready_ops.empty()) {
instr_id = ready_ops.top();
@@ -1427,7 +1429,7 @@ void InterpreterCore::AnalyseExecuteOrderForTrace() {
};

std::vector<size_t> trace_order;
SchedulingQueue ready_ops(instruction_prority_less);
SchedulingQueue ready_ops(instruction_scheduling_priority_less);

for (size_t instr_id = 0; instr_id < dependecy_count_.size(); ++instr_id) {
if (dependecy_count_[instr_id] == 0) {
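One detail worth calling out: `std::priority_queue` pops the greatest element under its comparator, so returning `lhs_scheduling_priority > rhs_scheduling_priority` makes the instruction with the lowest scheduling-priority value come off the queue first, matching the "lower value, higher priority" convention. A self-contained demo of the ordering:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <queue>
#include <vector>

int main() {
  // Scheduling priority per instruction id; lower value, higher priority.
  std::vector<int64_t> scheduling_priority = {1, 0, 0};
  auto less = [&](size_t lhs, size_t rhs) {
    if (scheduling_priority[lhs] == scheduling_priority[rhs]) {
      return lhs < rhs;
    }
    return scheduling_priority[lhs] > scheduling_priority[rhs];
  };
  std::priority_queue<size_t, std::vector<size_t>, decltype(less)> ready_ops(
      less);
  for (size_t id : {0, 1, 2}) ready_ops.push(id);
  while (!ready_ops.empty()) {
    // Prints "2 1 0": the two priority-0 instructions pop before the
    // priority-1 one.
    std::cout << ready_ops.top() << " ";
    ready_ops.pop();
  }
  std::cout << "\n";
  return 0;
}
```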
8 changes: 5 additions & 3 deletions paddle/fluid/framework/new_executor/interpretercore.h
@@ -79,9 +79,11 @@ class InterpreterCore {
const platform::Place& GetPlace() const { return place_; }

private:
using InstructionPriorityLess = std::function<bool(size_t, size_t)>;
using InstructionSchedulingPriorityLess = std::function<bool(size_t, size_t)>;
using SchedulingQueue =
std::priority_queue<size_t, std::vector<size_t>, InstructionPriorityLess>;
std::priority_queue<size_t,
std::vector<size_t>,
InstructionSchedulingPriorityLess>;

// build graph
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
@@ -181,7 +183,7 @@ class InterpreterCore {
int64_t sync_op_num_{-1};
std::vector<size_t> trace_execute_order_;

InstructionPriorityLess instruction_prority_less;
InstructionSchedulingPriorityLess instruction_scheduling_priority_less;
};

std::shared_ptr<InterpreterCore> CreateInterpreterCore(
11 changes: 7 additions & 4 deletions paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -32,12 +32,12 @@ namespace framework {

using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;

using Priority = int64_t;
using SchedulingPriority = int64_t;

constexpr const char* kCoalesceTensor = "coalesce_tensor";

// stream types
constexpr const char* kCustomStream = "CustromStream";
constexpr const char* kCustomStream = "CustomStream";
constexpr const char* kDefaultStream = "DefaultStream";
constexpr const char* kD2HStream = "D2HStream";
constexpr const char* kH2DStream = "H2DStream";
@@ -263,6 +263,7 @@ enum class OpFuncType {
class RuntimeInferShapeContext;

struct OpFuncNode {
int stream_priority_{0}; // lower value, higher priority
// fit for phi kernel
phi::Kernel* phi_kernel_{nullptr}; // not owned
platform::DeviceContext* dev_ctx_; // not owned
@@ -279,7 +280,7 @@
OpFuncType type_;
OpKernelComputeFunc kernel_func_;

Priority priority_{0}; // lower value, higher priority
SchedulingPriority scheduling_priority_{0}; // lower value, higher priority
};

class Instruction {
@@ -369,7 +370,9 @@ class Instruction {

void ClearInplace();

Priority GetPriority() const { return op_func_node_.priority_; }
SchedulingPriority GetSchedulingPriority() const {
return op_func_node_.scheduling_priority_;
}

private:
bool is_artificial_; // Instruction is artificial means that it is only used
Expand Down