[Fleet Executor] Support multi carrier #38535

Merged
merged 2 commits into from
Dec 30, 2021
59 changes: 42 additions & 17 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
@@ -27,16 +28,32 @@ namespace distributed {
USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);

void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope,
framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"Carrier is already init."));
void Carrier::Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids) {
rank_ = rank;
runtime_graph_ = runtime_graph;
interceptor_id_to_rank_ = runtime_graph_->interceptor_id_to_rank();
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_ids_ = interceptor_ids;

// TODO(fleet_exe dev): thread pool
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();
}

void Carrier::Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
rank_ = rank;
interceptor_id_to_rank_ = interceptor_id_to_rank;
interceptor_ids_ = interceptor_ids;
interceptor_id_to_node_ = interceptor_id_to_node;
minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes;
place_ = place;
@@ -72,8 +89,6 @@ bool Carrier::EnqueueInterceptorMessage(
return true;
}

void Carrier::Barrier() { msg_bus_->Barrier(); }

Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
PADDLE_ENFORCE_NE(iter, interceptor_idx_to_interceptor_.end(),
@@ -100,7 +115,8 @@ void Carrier::Start() {
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));

PADDLE_ENFORCE_EQ(is_init_, true, platform::errors::PreconditionNotMet(
"Using carrier before initialized."));
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Start is sending start to source interceptor " << id
<< ".";
@@ -140,7 +156,9 @@ bool Carrier::Send(const InterceptorMessage& msg) {
if (src_rank == dst_rank) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
return EnqueueInterceptorMessage(msg);
int64_t carrier_id = *GlobalMap<int64_t, int64_t>::Get(dst_id);
return GlobalMap<int64_t, Carrier>::Get(carrier_id)
Contributor:
There is probably no need for different carriers to communicate with each other.

Contributor Author @LiYuRio, Dec 30, 2021:
This is needed because the src interceptor and the dst interceptor may belong to different carriers that sit on the same rank. The src interceptor calls Send on its own carrier, so if this case is not handled here, delivery will go wrong.

Contributor:
Right, but I still think different carriers do not need to communicate. In your scenario src and dst are clearly related, so they should belong to one carrier. Different carriers can correspond to different programs, for example a training program and a prediction program: train for one epoch, then run prediction. The two are independent of each other in both time and relationship.

->EnqueueInterceptorMessage(msg);
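
The exchange above is about a source and a destination interceptor that share a rank but belong to different carriers. Below is a minimal sketch of the two-step lookup the new code performs (interceptor id to carrier id, then carrier id to carrier). `ToyCarrier`, `ToyRegistry`, `SendSameRank`, and `MakeDemoRegistry` are hypothetical stand-ins for illustration, not Paddle APIs, and missing entries are reported with `false` rather than `PADDLE_ENFORCE`.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a carrier: it only records delivered payloads.
struct ToyCarrier {
  explicit ToyCarrier(int64_t id) : carrier_id(id) {}
  bool Enqueue(const std::string& payload) {
    inbox.push_back(payload);
    return true;
  }
  int64_t carrier_id;
  std::vector<std::string> inbox;
};

// Hypothetical registries playing the role of the two GlobalMap tables.
struct ToyRegistry {
  std::unordered_map<int64_t, int64_t> interceptor_to_carrier;
  std::unordered_map<int64_t, std::unique_ptr<ToyCarrier>> carriers;
};

// Deliver a same-rank message: resolve the carrier that owns the destination
// interceptor, then enqueue on that carrier rather than on the sender's.
bool SendSameRank(ToyRegistry* reg, int64_t dst_id,
                  const std::string& payload) {
  assert(reg != nullptr);
  auto owner = reg->interceptor_to_carrier.find(dst_id);
  if (owner == reg->interceptor_to_carrier.end()) return false;
  auto carrier = reg->carriers.find(owner->second);
  if (carrier == reg->carriers.end()) return false;
  return carrier->second->Enqueue(payload);
}

// Two carriers on one rank: interceptor 1 belongs to carrier 0 and
// interceptor 2 belongs to carrier 1.
ToyRegistry MakeDemoRegistry() {
  ToyRegistry reg;
  reg.carriers.emplace(0, std::make_unique<ToyCarrier>(0));
  reg.carriers.emplace(1, std::make_unique<ToyCarrier>(1));
  reg.interceptor_to_carrier = {{1, 0}, {2, 1}};
  return reg;
}
```

With this layout, a message addressed to interceptor 2 lands in carrier 1's inbox no matter which carrier's send path was invoked. If every related interceptor were kept in a single carrier, as the reviewer suggests, the first lookup would be unnecessary.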
} else {
PADDLE_ENFORCE_NOT_NULL(
msg_bus_.get(),
@@ -174,6 +192,9 @@ Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
loop, platform::errors::Fatal("thread task loop must not null"));
interceptor->RegisterTaskLoop(loop);

// TODO(liyurui): Using struct InterceptorID replace int64_t
GlobalMap<int64_t, int64_t>::Create(interceptor_id, carrier_id_);

auto* ptr = interceptor.get();
interceptor_idx_to_interceptor_.insert(
std::make_pair(interceptor_id, std::move(interceptor)));
@@ -199,15 +220,19 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}

void Carrier::CreateInterceptors() {
if (runtime_graph_->interceptor_id_to_node().empty()) return;
if (interceptor_ids_.empty()) return;

auto gc = GetGC(place_);

// create each Interceptor
// no auto init since there is no config
for (const auto& item : runtime_graph_->interceptor_id_to_node()) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;
for (int64_t interceptor_id : interceptor_ids_) {
const auto& task_node_iter = interceptor_id_to_node_.find(interceptor_id);
PADDLE_ENFORCE_NE(
task_node_iter, interceptor_id_to_node_.end(),
platform::errors::NotFound("Can not find task node for interceptor %ld",
interceptor_id));
TaskNode* task_node = task_node_iter->second;

PADDLE_ENFORCE_LT(
task_node->run_at_offset(), task_node->run_per_steps(),
32 changes: 16 additions & 16 deletions paddle/fluid/distributed/fleet_executor/carrier.h
@@ -45,19 +45,19 @@ class MessageBus;

class Carrier final {
public:
Carrier() = default;
Carrier(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank)
: rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {
thread_num_ = 1;
thread_pool_.SetThreadNum(thread_num_);
thread_pool_.Start();
}
explicit Carrier(int64_t carrier_id) : carrier_id_(carrier_id) {}
~Carrier();
void Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);
void Init(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids);
void Init(
int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
const std::unordered_set<int64_t>& interceptor_ids,
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);

void Release();
void Wait();
@@ -83,10 +83,9 @@ class Carrier final {

bool Send(const InterceptorMessage& msg);

void Barrier();

private:
DISABLE_COPY_AND_ASSIGN(Carrier);
Carrier() = delete;

// create each Interceptor
void CreateInterceptors();
@@ -108,13 +107,14 @@ class Carrier final {
framework::Scope* minibatch_scope_;
paddle::platform::Place place_;
paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<RuntimeGraph> runtime_graph_;
std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_;
int64_t carrier_id_;
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;

int thread_num_;
TaskLoopThreadPool thread_pool_;
std::unordered_set<int64_t> interceptor_ids_;
};

} // namespace distributed
54 changes: 28 additions & 26 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
@@ -27,8 +28,6 @@
namespace paddle {
namespace distributed {

std::unique_ptr<Carrier> FleetExecutor::carrier_;

FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
@@ -37,13 +36,9 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {

FleetExecutor::~FleetExecutor() {
root_scope_->DropKids();
GetCarrier()->Release();
}

Carrier* FleetExecutor::GetCarrier() {
PADDLE_ENFORCE_NOT_NULL(carrier_.get(), platform::errors::NotFound(
"Carrier has not been created."));
return carrier_.get();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Get(item.first)->Release();
}
}

void FleetExecutor::Init(
@@ -63,13 +58,20 @@ void FleetExecutor::Init(
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
runtime_graph_ = std::make_shared<RuntimeGraph>();
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
std::unordered_map<int64_t, std::unordered_set<int64_t>>
carrier_id_to_interceptor_ids;
std::unordered_set<int64_t> interceptor_ids;
for (auto task_node : task_nodes) {
task_node->SetUnusedVars(unused_vars);
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
interceptor_ids.insert(interceptor_id);
}
carrier_id_to_interceptor_ids.emplace(0, interceptor_ids);
runtime_graph_->SetCarrierIdToInterceptorIds(carrier_id_to_interceptor_ids);
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
runtime_graph_->SetCarrierIdToInterceptorIds(carrier_id_to_interceptor_ids);
for (auto& unique_op : ops) {
unique_op.release();
}
@@ -86,21 +88,26 @@ void FleetExecutor::Init(
}
VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>();
CreateCarrier();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
GlobalMap<int64_t, Carrier>::Create(item.first, item.first);
}
InitCarrier();
InitMessageBus();

// refine this? wait all carrier ready
// NOTE(wangxi): must add after Carrier::SetMsgBus, for we use
// MessageBus::IncreaseBarrierCount when receive barrier msg.
GetCarrier()->Barrier();
// Wait for all message bus connected.
msg_bus_->Barrier();
}

void FleetExecutor::InitCarrier() {
if (!GetCarrier()->IsInit()) {
GetCarrier()->SetMsgBus(msg_bus_);
GetCarrier()->Init(exe_desc_.cur_rank(), runtime_graph_, root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
Carrier* carrier = GlobalMap<int64_t, Carrier>::Get(item.first);
PADDLE_ENFORCE_NOT_NULL(carrier, platform::errors::InvalidArgument(
"Carrier has not been created."));
carrier->SetMsgBus(msg_bus_);
carrier->Init(exe_desc_.cur_rank(),
runtime_graph_->interceptor_id_to_rank(), item.second,
runtime_graph_->interceptor_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
}
}

@@ -140,14 +147,9 @@ void FleetExecutor::InitMessageBus() {
}

void FleetExecutor::Run() {
// Run
PADDLE_ENFORCE_EQ(
GetCarrier()->IsInit(), true,
platform::errors::Unavailable("Carrier has not been init yet."));
PADDLE_ENFORCE_EQ(
msg_bus_->IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet."));
GetCarrier()->Start();
for (const auto& item : runtime_graph_->carrier_id_to_interceptor_ids()) {
Contributor:
Are multiple carriers run at the same time?

Contributor:
A CPU carrier and a GPU carrier?

Contributor Author:
The intent is to start them all together; a carrier that cannot run yet can simply wait.

Contributor:
Only one carrier should be running at any given moment. Concurrency should come from the graph's topological dependencies: nodes can run at the same time where the topology allows it.

Contributor:
This should run a single carrier. Different Programs then correspond to different carriers, and an external interface selects which carrier to run.

GlobalMap<int64_t, Carrier>::Get(item.first)->Start();
}
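
The thread above contrasts two policies: starting every registered carrier (what this loop does) and letting the caller choose one carrier per program (what the reviewers propose). A small sketch of the difference follows. `ToyCarrier`, `ToyExecutor`, `RunAll`, and `RunOne` are hypothetical names for illustration only; the PR does not define a per-carrier run interface.

```cpp
#include <cassert>
#include <cstdint>
#include <map>

// Hypothetical stand-in for a carrier; it only counts Start() calls.
struct ToyCarrier {
  int start_count = 0;
  void Start() { ++start_count; }
};

// Hypothetical executor that owns one carrier per program.
class ToyExecutor {
 public:
  void AddCarrier(int64_t carrier_id) { carriers_[carrier_id]; }

  // Policy in this diff: every registered carrier is started.
  void RunAll() {
    for (auto& item : carriers_) item.second.Start();
  }

  // Reviewers' alternative: the caller selects which carrier to start.
  bool RunOne(int64_t carrier_id) {
    auto it = carriers_.find(carrier_id);
    if (it == carriers_.end()) return false;
    it->second.Start();
    return true;
  }

  int StartCount(int64_t carrier_id) const {
    auto it = carriers_.find(carrier_id);
    assert(it != carriers_.end());
    return it->second.start_count;
  }

 private:
  std::map<int64_t, ToyCarrier> carriers_;
};
```

Under the second policy a training carrier and a prediction carrier could be registered once and started alternately, which matches the "train one epoch, then predict" scenario raised in the earlier review thread.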
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
11 changes: 0 additions & 11 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.h
@@ -42,16 +42,6 @@ class FleetExecutor final {
const std::vector<TaskNode*>& task_nodes,
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run();
// TODO(liyurui): Change to use registry table for multi-carrier.
static Carrier* GetCarrier();
template <typename... Args>
static Carrier* CreateCarrier(Args&&... args) {
PADDLE_ENFORCE_EQ(
carrier_.get(), nullptr,
platform::errors::AlreadyExists("Carrier has been created already."));
carrier_ = std::make_unique<Carrier>(std::forward<Args>(args)...);
return carrier_.get();
}

private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
@@ -67,7 +57,6 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_;
static std::unique_ptr<Carrier> carrier_;
};

} // namespace distributed
49 changes: 49 additions & 0 deletions paddle/fluid/distributed/fleet_executor/global_map.h
@@ -0,0 +1,49 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

namespace paddle {
namespace distributed {

template <typename KeyT, typename ValueT>
Contributor:
In that case, is the meaning of each type-to-type combination fixed?
If we later need other int64-to-int64 mappings, such as interceptor id to thread id or carrier id to stream id, how can this class be reused?

Contributor Author:
There is a TODO for this; the next PR will change it.

Contributor @FeixLiu, Dec 29, 2021:
Do you mean the TODO about using an InterceptorID struct in place of int64? If so, how would it work when mappings from interceptor id to thread id and from interceptor id to carrier id both exist? Or will the global map class itself be updated later, with the TODO placed in that class?

Contributor Author @LiYuRio, Dec 29, 2021:
That scenario does not exist, and if it did, we could add ThreadID and CarrierID as well.


class GlobalMap final {
public:
static ValueT* Get(KeyT id) {
ValueT* item = GetPPtr(id)->get();
PADDLE_ENFORCE_NOT_NULL(
item, platform::errors::NotFound("This value is not in global map."));
return item;
}
template <typename... Args>
static ValueT* Create(KeyT id, Args&&... args) {
auto* ptr = GetPPtr(id);
PADDLE_ENFORCE_EQ(ptr->get(), nullptr,
platform::errors::AlreadyExists(
"This value has already in global map."));
ValueT* item = new ValueT(std::forward<Args>(args)...);
ptr->reset(item);
return item;
}

private:
static std::unique_ptr<ValueT>* GetPPtr(KeyT id) {
static std::mutex mutex;
Contributor:
At present carriers are always fixed during initialization. With no concurrent reads and writes, the lock is unnecessary. Of course, carriers may become dynamic later.

Contributor Author:
Besides multi-carrier, this global map also backs the interceptor id to carrier id map, where concurrency does occur, so I unified the two. That said, since interceptors are not added or removed dynamically right now, building that map could also be moved into initialization.

Contributor Author:
Let's add a class without the lock later; locking on every access like this really is expensive.

static std::unordered_map<KeyT, std::unique_ptr<ValueT>> id_to_ptr;
std::unique_lock<std::mutex> lock(mutex);
return &id_to_ptr[id];
}
};
} // namespace distributed
} // namespace paddle
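
The review threads on this file raise two points: `GlobalMap<int64_t, int64_t>` is one table per type pair, so two different int64-to-int64 mappings would collide, and the author replies that dedicated id types such as ThreadID and CarrierID could be added. The sketch below illustrates that idea under stated assumptions. `TypedId`, the tag structs, and `ToyGlobalMap` are hypothetical names; `ToyGlobalMap` is a simplified copy of the pattern that returns `nullptr` instead of calling `PADDLE_ENFORCE`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>

// Hypothetical strong id wrapper: same int64_t payload, distinct C++ types.
template <typename Tag>
struct TypedId {
  int64_t value;
  bool operator==(const TypedId& other) const { return value == other.value; }
};
struct InterceptorTag {};
struct CarrierTag {};
struct ThreadTag {};
using InterceptorId = TypedId<InterceptorTag>;
using CarrierId = TypedId<CarrierTag>;
using ThreadId = TypedId<ThreadTag>;

namespace std {
// Hash support so TypedId can be used as an unordered_map key.
template <typename Tag>
struct hash<TypedId<Tag>> {
  size_t operator()(const TypedId<Tag>& id) const {
    return hash<int64_t>()(id.value);
  }
};
}  // namespace std

// Simplified registry modeled on GlobalMap. Each (KeyT, ValueT) pair gets its
// own function-local static table, guarded by its own mutex.
template <typename KeyT, typename ValueT>
class ToyGlobalMap final {
 public:
  // Returns nullptr when nothing has been created for this id.
  static ValueT* Get(KeyT id) { return Slot(id)->get(); }

  // Returns nullptr when a value already exists for this id.
  template <typename... Args>
  static ValueT* Create(KeyT id, Args&&... args) {
    std::unique_ptr<ValueT>* slot = Slot(id);
    if (*slot != nullptr) return nullptr;
    slot->reset(new ValueT(std::forward<Args>(args)...));
    return slot->get();
  }

 private:
  static std::unique_ptr<ValueT>* Slot(KeyT id) {
    static std::mutex mutex;
    static std::unordered_map<KeyT, std::unique_ptr<ValueT>> table;
    std::lock_guard<std::mutex> lock(mutex);
    return &table[id];
  }
};

// Two mappings keyed by interceptor id that no longer share one table.
using InterceptorToCarrier = ToyGlobalMap<InterceptorId, CarrierId>;
using InterceptorToThread = ToyGlobalMap<InterceptorId, ThreadId>;
```

As in the original, the mutex covers only the table lookup. The returned slot pointer stays valid because `std::unordered_map` does not move elements on rehash, but the slot itself is read and written outside the lock, so concurrent `Create` calls on the same key would still need external coordination. A table that is populated once during initialization and only read afterwards could drop the mutex entirely, which is the direction the last review comment points to.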
@@ -16,7 +16,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/global_map.h"

namespace paddle {
namespace distributed {
@@ -29,7 +29,15 @@ void InterceptorMessageServiceImpl::InterceptorMessageService(
VLOG(3) << "Interceptor Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
bool flag = FleetExecutor::GetCarrier()->EnqueueInterceptorMessage(*request);
// TODO(liyurui): Remove this hard code.
int64_t carrier_id;
if (request->ctrl_message()) {
carrier_id = 0;
Contributor:
The barrier logic could probably be moved right here.

Contributor:
By passing the message bus instance down? Putting the message bus into the global map would also seem to work.

Contributor:
Yes, although that is a bit cumbersome too.

Contributor Author:
Regarding "The barrier logic could probably be moved right here": yes.

Contributor:
With the global map, wouldn't the carrier no longer need to hold the message bus instance at all, and just store an id instead?


} else {
carrier_id = *GlobalMap<int64_t, int64_t>::Get(request->dst_id());
Contributor:
The carrier is looked up here, so the earlier lookup inside Carrier is not needed.

Contributor Author:
This is a different scenario from the one in Carrier above. This path handles messages sent between different ranks; the earlier one handles messages between different carriers on the same rank.


}
bool flag = GlobalMap<int64_t, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(*request);
response->set_rst(flag);
}
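
The routing rule in this handler can be isolated as a small function: control messages go to a fixed carrier, and every other message goes to the carrier that owns the destination interceptor. The sketch below assumes a plain map in place of `GlobalMap<int64_t, int64_t>`. `ResolveCarrierId` and `kControlCarrierId` are hypothetical names, with the constant mirroring the hard-coded `0` that the TODO in the diff marks for removal.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>

// Mirrors the hard-coded carrier id used for control messages in the diff.
constexpr int64_t kControlCarrierId = 0;

// Picks the carrier that should receive an incoming message.
// Returns false when the destination interceptor is not registered.
bool ResolveCarrierId(
    bool is_ctrl_message, int64_t dst_id,
    const std::unordered_map<int64_t, int64_t>& interceptor_to_carrier,
    int64_t* carrier_id) {
  assert(carrier_id != nullptr);
  if (is_ctrl_message) {
    // Control messages are not addressed to a registered interceptor.
    *carrier_id = kControlCarrierId;
    return true;
  }
  auto it = interceptor_to_carrier.find(dst_id);
  if (it == interceptor_to_carrier.end()) return false;
  *carrier_id = it->second;
  return true;
}
```

Keeping the rule in one place would also give the barrier handling discussed in the review a natural home, since it depends on the same control-message branch.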
