Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fleet_executor] Add retry for message bus, drop kids for micro-scope #37809

Merged
merged 1 commit into from
Dec 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 49 additions & 22 deletions paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ void ComputeInterceptor::PrepareDeps() {
"Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld",
node_->max_run_times()));
in_readys_.emplace(-1,
std::make_pair(std::numeric_limits<int64_t>::max(), 0));
}

// If there is no downstream or every downstream is in different rank,
Expand All @@ -55,14 +57,17 @@ void ComputeInterceptor::PrepareDeps() {
}

void ComputeInterceptor::IncreaseReady(int64_t up_id) {
// source node has no upstream, data_is_ready is sent by carrier or others
if (is_source_ && up_id == -1) return;

auto it = in_readys_.find(up_id);
PADDLE_ENFORCE_NE(it, in_readys_.end(),
platform::errors::NotFound(
"Cannot find upstream=%lld in in_readys.", up_id));

// source node has no upstream, data_is_ready is sent by carrier or others
if (is_source_ && up_id == -1) {
it->second.second = GetTaskNode()->max_run_times();
return;
}

auto max_ready_size = it->second.first;
auto ready_size = it->second.second;
ready_size += 1;
Expand Down Expand Up @@ -93,7 +98,11 @@ bool ComputeInterceptor::IsInputReady() {
for (auto& ins : in_readys_) {
auto ready_size = ins.second.second;
// not ready, return false
if (ready_size == 0) return false;
if (ready_size == 0) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< "'s upstreams aren't all ready.";
return false;
}
}
return true;
}
Expand All @@ -103,14 +112,23 @@ bool ComputeInterceptor::CanWriteOutput() {
auto max_buffer_size = outs.second.first;
auto used_size = outs.second.second;
// full, return false
if (used_size == max_buffer_size) return false;
if (used_size == max_buffer_size) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< "'s out buffer is full.";
return false;
}
}
return true;
}

// only the source node needs to reset
bool ComputeInterceptor::ShouldReset() {
return is_source_ && (step_ == node_->max_run_times());
if (is_source_ && step_ == node_->max_run_times()) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< " should reset for step: " << step_ << ".";
return true;
}
return false;
}

void ComputeInterceptor::SendDataReadyToDownStream() {
Expand All @@ -130,7 +148,8 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id;
<< " Send data_is_ready msg to " << down_id
<< " for step: " << step_;
Send(down_id, ready_msg);
}
}
Expand All @@ -147,23 +166,43 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
ready_size));
ins.second.second = ready_size;

VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id
<< " for step: " << step_;
if (up_id == -1) return;

InterceptorMessage reply_msg;
reply_msg.set_message_type(DATE_IS_USELESS);
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id;
Send(up_id, reply_msg);
}
}

void ComputeInterceptor::RunOps() {
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
<< step_ << " time.";
<< step_ + 1 << " time.";
for (auto op : node_->ops()) {
op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
}
}

void ComputeInterceptor::Run() {
// If there is no limit, source interceptor can be executed
// an unlimited number of times.
// Now source node can only run max_run_times.
if (ShouldReset()) {
for (auto& out_buff : out_buffs_) {
// buffer is still in use
if (out_buff.second.second != 0) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< " out buffer for downstream: " << out_buff.first
<< "'s counter is: " << out_buff.second.second
<< ". Cannot be reset.";
return;
}
}
step_ = 0; // reset
}

while (IsInputReady() && CanWriteOutput() && !ShouldReset()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";

Expand All @@ -181,18 +220,6 @@ void ComputeInterceptor::Run() {
StopCarrier();
}
}

// If there is no limit, source interceptor can be executed
// an unlimited number of times.
// Now source node can only run max_run_times.
if (ShouldReset()) {
for (auto& out_buff : out_buffs_) {
// buffer is still in use
if (out_buff.second.second != 0) return;
}
step_ = 0; // reset
return;
}
}

void ComputeInterceptor::ReceivedStop(int64_t up_id) {
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,15 @@ void FleetExecutor::Run() {
message_bus_instance.IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet."));
carrier_instance.Start();
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
// But when while_op also creates a local executor to run its sub block,
// the sub scopes it created should not be dropped immediately, because
// while_grad_op will use some variables created during while_op run, so
// we need to keep the kids and wait for the outer executor to drop them.
micro_scop->DropKids();
}
}

void FleetExecutor::CopyParameters(int microbatch_id,
Expand Down
33 changes: 30 additions & 3 deletions paddle/fluid/distributed/fleet_executor/message_bus.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <chrono>
#include <memory>
#include <set>
#include <thread>

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
Expand Down Expand Up @@ -56,11 +57,11 @@ void MessageBus::Init(
bool MessageBus::IsInit() const { return is_init_; }

MessageBus::~MessageBus() {
VLOG(3) << "Message bus releases resource.";
// NOTE: fleet_executor inits carrier before message bus,
// therefore the message bus's destructor will be called first
Carrier& carrier = Carrier::Instance();
carrier.Release();
VLOG(3) << "Message bus releases resource.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
server_.Stop(1000);
Expand Down Expand Up @@ -90,6 +91,8 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
<< retry_time << " times retries.";
return true;
}
VLOG(3) << "Message bus sends failed, retry after 1 seconds.";
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "Message bus sends inter rank fail after 10 times retries.";
return false;
Expand Down Expand Up @@ -121,16 +124,40 @@ void MessageBus::ListenPort() {
brpc::ServerOptions options;
options.idle_timeout_sec = -1;
int retry_times = 0;
int interval = 1000;
int interval = 100;
while (server_.Start(ip_for_brpc, &options) != 0) {
++retry_times;
LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times
<< " times. And will retry after " << interval / 1000
<< " seconds.";
std::this_thread::sleep_for(std::chrono::milliseconds(interval));
interval += 2000;
interval += 500;
}
LOG(INFO) << "Message bus's listen port thread starts successful.";

std::set<int64_t> visit;
InterceptorMessage tmp_msg;
tmp_msg.set_ctrl_message(true);
for (auto pair : interceptor_id_to_rank_) {
if (rank_to_addr_.at(pair.second) == addr_) {
tmp_msg.set_src_id(pair.first);
}
}
for (auto pair : interceptor_id_to_rank_) {
int64_t rank = pair.second;
if (rank_to_addr_.at(rank) == addr_) {
continue;
}
tmp_msg.set_dst_id(pair.first);
if (visit.find(rank) == visit.end()) {
VLOG(3) << "Message bus is testing connection for rank: " << rank << ".";
visit.insert(rank);
while (!Send(tmp_msg)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "Message bus has connected to rank: " << rank << ".";
}
}
#else
LOG(WARNING)
<< "Fleet executor's ListenPort() is a fake function when Paddle is "
Expand Down