Skip to content

Commit

Permalink
[fleet_executor] Move IntraSend to Carrier. Using blocking queue (#38322
Browse files Browse the repository at this point in the history
)
  • Loading branch information
LiYuRio authored Dec 22, 2021
1 parent 142ea17 commit ddc15a1
Show file tree
Hide file tree
Showing 19 changed files with 228 additions and 262 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/fleet_executor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ endif()
proto_library(interceptor_message_proto SRCS interceptor_message.proto)

if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog)
set(BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog)
else()
set(BRPC_DEPS "")
endif()

cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
executor_gc_helper ${BRPC_DEPS})
executor_gc_helper gflags glog ${BRPC_DEPS})

if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
Expand Down
71 changes: 48 additions & 23 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ namespace distributed {
USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);

void Carrier::Init(std::shared_ptr<RuntimeGraph> runtime_graph,
void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope,
framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"Carrier is already init."));
rank_ = rank;
runtime_graph_ = runtime_graph;
interceptor_id_to_rank_ = runtime_graph_->interceptor_id_to_rank();
minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes;
place_ = place;
Expand All @@ -48,12 +50,6 @@ void Carrier::Release() {
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete.

// Sending STOP msg to the source interceptor
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
<< ".";
Expand All @@ -75,10 +71,10 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }

bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) {
// enqueue message to interceptor
if (interceptor_message.ctrl_message()) {
// handle control message
return true;
VLOG(3) << "Receiving control message from rank "
<< interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id();
} else {
{
std::unique_lock<std::mutex> lock_creating(creating_flag_mutex_);
Expand All @@ -93,15 +89,9 @@ bool Carrier::EnqueueInterceptorMessage(
}
int64_t dst_id = interceptor_message.dst_id();
Interceptor* dst_interceptor = GetInterceptor(dst_id);
bool rst =
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
if (rst) {
std::condition_variable& interceptor_cond_var =
dst_interceptor->GetCondVar();
interceptor_cond_var.notify_all();
}
return rst;
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
}
return true;
}

Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
Expand Down Expand Up @@ -144,9 +134,44 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }

bool Carrier::IsInit() const { return is_init_; }

// TODO(liyurui): Move SendIntra into carrier
bool Carrier::Send(const InterceptorMessage& msg) const {
return msg_bus_->Send(msg);
int64_t Carrier::GetRank(int64_t interceptor_id) const {
PADDLE_ENFORCE_NE(
interceptor_id_to_rank_.find(interceptor_id),
interceptor_id_to_rank_.end(),
platform::errors::NotFound("Cannot find rank for interceptor id %lld.",
interceptor_id));
return interceptor_id_to_rank_.at(interceptor_id);
}

bool Carrier::Send(const InterceptorMessage& msg) {
int64_t src_id = (msg.src_id() == -1) ? msg.dst_id() : msg.src_id();
int64_t dst_id = msg.dst_id();
int64_t src_rank = GetRank(src_id);
int64_t dst_rank = GetRank(dst_id);
PADDLE_ENFORCE_EQ(
src_rank, rank_,
platform::errors::Fatal("The source rank id %lld, which is not equal to "
"the carrier rank id %lld.",
src_rank, rank_));
if (src_rank == dst_rank) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
return EnqueueInterceptorMessage(msg);
} else {
PADDLE_ENFORCE_NOT_NULL(
msg_bus_.get(),
platform::errors::Unavailable("Message bus is released accidently"));
PADDLE_ENFORCE_EQ(
msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
return msg_bus_->Send(dst_rank, msg);
}
}

Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
Expand Down Expand Up @@ -222,13 +247,13 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}

void Carrier::CreateInterceptors() {
if (runtime_graph_->intercepter_id_to_node().empty()) return;
if (runtime_graph_->interceptor_id_to_node().empty()) return;

auto gc = GetGC(place_);

// create each Interceptor
// no auto init since there is no config
for (const auto& item : runtime_graph_->intercepter_id_to_node()) {
for (const auto& item : runtime_graph_->interceptor_id_to_node()) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;

Expand Down
10 changes: 8 additions & 2 deletions paddle/fluid/distributed/fleet_executor/carrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ class MessageBus;
class Carrier final {
public:
Carrier() = default;
Carrier(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank)
: rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {}
~Carrier();
void Init(std::shared_ptr<RuntimeGraph> runtime_graph,
void Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);
Expand Down Expand Up @@ -75,7 +78,7 @@ class Carrier final {

bool IsInit() const;

bool Send(const InterceptorMessage& msg) const;
bool Send(const InterceptorMessage& msg);

// NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run
Expand All @@ -90,6 +93,8 @@ class Carrier final {

void HandleTmpMessages();

int64_t GetRank(int64_t interceptor_id) const;

// interceptor logic id to actually interceptor
std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_;
Expand All @@ -111,6 +116,7 @@ class Carrier final {
paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<RuntimeGraph> runtime_graph_;
std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
};

Expand Down
30 changes: 15 additions & 15 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
Expand All @@ -28,6 +27,8 @@
namespace paddle {
namespace distributed {

std::unique_ptr<Carrier> FleetExecutor::carrier_;

FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
Expand All @@ -36,12 +37,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {

FleetExecutor::~FleetExecutor() {
root_scope_->DropKids();
GetCarrier().Release();
GetCarrier()->Release();
}

Carrier& FleetExecutor::GetCarrier() {
static Carrier carrier;
return carrier;
Carrier* FleetExecutor::GetCarrier() {
PADDLE_ENFORCE_NOT_NULL(carrier_.get(), platform::errors::NotFound(
"Carrier has not been created."));
return carrier_.get();
}

void FleetExecutor::Init(
Expand Down Expand Up @@ -84,16 +86,16 @@ void FleetExecutor::Init(
}
VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>();
CreateCarrier();
InitCarrier();
InitMessageBus();
}

void FleetExecutor::InitCarrier() {
Carrier& carrier = GetCarrier();
if (!carrier.IsInit()) {
carrier.SetMsgBus(msg_bus_);
carrier.Init(runtime_graph_, root_scope_, minibatch_scope_,
microbatch_scopes_, place_);
if (!GetCarrier()->IsInit()) {
GetCarrier()->SetMsgBus(msg_bus_);
GetCarrier()->Init(exe_desc_.cur_rank(), runtime_graph_, root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
}
}

Expand Down Expand Up @@ -128,21 +130,19 @@ void FleetExecutor::InitMessageBus() {
<< (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << ".";
VLOG(5) << ss.str();
if (!msg_bus_->IsInit()) {
msg_bus_->Init(runtime_graph_->intercepter_id_to_rank(), rank_to_addr,
addr);
msg_bus_->Init(cur_rank, rank_to_addr, addr);
}
}

void FleetExecutor::Run() {
// Run
Carrier& carrier = GetCarrier();
PADDLE_ENFORCE_EQ(
carrier.IsInit(), true,
GetCarrier()->IsInit(), true,
platform::errors::Unavailable("Carrier has not been init yet."));
PADDLE_ENFORCE_EQ(
msg_bus_->IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet."));
carrier.Start();
GetCarrier()->Start();
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
Expand Down
13 changes: 11 additions & 2 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <memory>
#include <string>

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
Expand All @@ -30,7 +31,6 @@ namespace distributed {
class RuntimeGraph;
class MessageBus;
class TaskNode;
class Carrier;

class FleetExecutor final {
public:
Expand All @@ -43,7 +43,15 @@ class FleetExecutor final {
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run();
// TODO(liyurui): Change to use registry table for multi-carrier.
static Carrier& GetCarrier();
static Carrier* GetCarrier();
template <typename... Args>
static Carrier* CreateCarrier(Args&&... args) {
PADDLE_ENFORCE_EQ(
carrier_.get(), nullptr,
platform::errors::AlreadyExists("Carrier has been created already."));
carrier_ = std::make_unique<Carrier>(std::forward<Args>(args)...);
return carrier_.get();
}

private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
Expand All @@ -59,6 +67,7 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_;
static std::unique_ptr<Carrier> carrier_;
};

} // namespace distributed
Expand Down
28 changes: 5 additions & 23 deletions paddle/fluid/distributed/fleet_executor/interceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,17 @@ void Interceptor::StopCarrier() {
cond_var.notify_all();
}

std::condition_variable& Interceptor::GetCondVar() {
// get the conditional var
return cond_var_;
}

int64_t Interceptor::GetInterceptorId() const {
// return the interceptor id
return interceptor_id_;
}

bool Interceptor::EnqueueRemoteInterceptorMessage(
void Interceptor::EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message) {
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
VLOG(3) << "Enqueue message: " << interceptor_message.message_type()
<< " into " << interceptor_id_ << "'s remote mailbox.";
std::unique_lock<std::mutex> lock(remote_mailbox_mutex_);
remote_mailbox_.push(interceptor_message);
return true;
remote_mailbox_.Push(interceptor_message);
}

bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
Expand All @@ -92,7 +85,7 @@ void Interceptor::PoolTheMailbox() {
"Error encountered when fetch remote mailbox."));
}
const InterceptorMessage interceptor_message = local_mailbox_.front();
local_mailbox_.pop();
local_mailbox_.pop_front();
const MessageType message_type = interceptor_message.message_type();
VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< " from interceptor " << interceptor_message.src_id()
Expand All @@ -109,19 +102,8 @@ void Interceptor::PoolTheMailbox() {
}

bool Interceptor::FetchRemoteMailbox() {
// fetch all Message from remote mailbox to local mailbox
// return true if remote mailbox not empty, otherwise return false
std::unique_lock<std::mutex> lock(remote_mailbox_mutex_);
cond_var_.wait(lock, [this]() { return !remote_mailbox_.empty(); });
if (remote_mailbox_.empty()) {
// the thread has been unblocked accidentally
return false;
}
while (!remote_mailbox_.empty()) {
local_mailbox_.push(std::move(remote_mailbox_.front()));
remote_mailbox_.pop();
}
return true;
remote_mailbox_.PopAll(&local_mailbox_);
return !local_mailbox_.empty();
}

static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() {
Expand Down
Loading

0 comments on commit ddc15a1

Please sign in to comment.