-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[WIP]飞桨PaddlePaddle 分布式强化学习功能研发 (#45998)
* add rpc module in cpp side * add rpc module in python side * support win32 and mac for rpc * 代码优化 * 优化代码 * update rpc * update rpc launch * rpc remove rank and world_size api * fix logger import bug * remove support for win and mac * remove support for xpu, npu, cinn and rocm * remove support for xpu, npu, cinn and rocm * fix shutdown barrier timeout bug * update:python_rpc_handler to shared ptr * fix master shutdown first bug * tests support for cpu * update log to vlog * update get service info api * add single process test case * remove process group * remove some useless dependencies * update rpc api comments * update rpc comments: Example to Examples * update rpc api comments * update rpc api comments * update launch api comments * update init_rpc comments * update rpc sync and async comments * fix bug: init_rpc can't be called repeatedly in a process * update rpc api comment: make master endpoint unique * update rpc api:service to worker, timeout_ms to timeout * rename ServiceInfo to WorkerInfo * refactor: rename server to worker, log to vlog * add launch test * remove unused codes * refine
- Loading branch information
Showing
27 changed files
with
1,803 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Translation units that make up the paddle_rpc library.
set(PADDLE_RPC_SRCS python_rpc_handler.cc rpc_agent.cc)

# Every RPC source is compiled with the distributed compile flags.
set_source_files_properties(
  ${PADDLE_RPC_SRCS} PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

# Libraries the RPC module links against.
set(PADDLE_RPC_DEPS brpc protobuf glog pybind)

# Generated code for rpc.proto.
proto_library(paddle_rpc_proto SRCS rpc.proto)

cc_library(
  paddle_rpc
  SRCS ${PADDLE_RPC_SRCS}
  DEPS ${PADDLE_RPC_DEPS} paddle_rpc_proto)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <pybind11/pybind11.h> | ||
|
||
#include <cassert> | ||
#include <future> | ||
#include <string> | ||
|
||
#include "paddle/fluid/distributed/rpc/python_rpc_handler.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
#include "paddle/fluid/platform/macros.h" | ||
|
||
namespace py = pybind11; | ||
namespace paddle { | ||
namespace distributed { | ||
class FutureWrapper { | ||
public: | ||
FutureWrapper() {} | ||
explicit FutureWrapper(std::future<std::string> fut) : fut_(std::move(fut)) {} | ||
py::object wait() { | ||
// GIL must be released, otherwise fut_.get() blocking will cause the | ||
// service to fail to process RPC requests, leading to deadlock | ||
PADDLE_ENFORCE_EQ( | ||
PyGILState_Check(), | ||
false, | ||
platform::errors::Fatal( | ||
"GIL must be released before fut.wait(), otherwise fut_.get() " | ||
"blocking will cause the service to fail to " | ||
"process RPC requests, leading to deadlock")); | ||
auto s = fut_.get(); | ||
py::gil_scoped_acquire ag; | ||
std::shared_ptr<PythonRpcHandler> python_handler = | ||
PythonRpcHandler::GetInstance(); | ||
py::object obj = python_handler->Deserialize(py::bytes(s)); | ||
return obj; | ||
} | ||
|
||
private: | ||
DISABLE_COPY_AND_ASSIGN(FutureWrapper); | ||
std::future<std::string> fut_; | ||
}; | ||
} // namespace distributed | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/distributed/rpc/python_rpc_handler.h" | ||
|
||
namespace paddle { | ||
namespace distributed { | ||
// Name of the Python module that provides the RPC helper functions.
constexpr const char* kInternalModule = "paddle.distributed.rpc.internal";
|
||
// Looks up attribute `name` on `module` and returns it as a py::object.
py::object getFunction(const py::object& module, const char* name) {
  return module.attr(name);
}
|
||
// Imports the internal RPC Python module and caches the three helper
// functions used by this handler. The GIL is held for the whole constructor.
PythonRpcHandler::PythonRpcHandler() {
  py::gil_scoped_acquire gil;
  const py::object internal_module = py::module::import(kInternalModule);
  py_run_function_ = getFunction(internal_module, "_run_py_func");
  py_serialize_ = getFunction(internal_module, "_serialize");
  py_deserialize_ = getFunction(internal_module, "_deserialize");
}
|
||
// Calls the cached `_run_py_func` helper with `python_func` while holding the
// GIL and returns whatever the helper returns.
py::object PythonRpcHandler::RunPythonFunc(const py::object& python_func) {
  py::gil_scoped_acquire gil;
  py::object result = py_run_function_(python_func);
  return result;
}
|
||
// Passes `obj` to the cached `_serialize` helper (GIL held) and converts the
// helper's return value to a std::string.
std::string PythonRpcHandler::Serialize(const py::object& obj) {
  py::gil_scoped_acquire gil;
  py::object serialized = py_serialize_(obj);
  return serialized.cast<std::string>();
}
|
||
// Wraps `obj` in a Python bytes object and hands it to the cached
// `_deserialize` helper (GIL held); returns the helper's result.
py::object PythonRpcHandler::Deserialize(const std::string& obj) {
  py::gil_scoped_acquire gil;
  py::bytes payload(obj);
  return py_deserialize_(payload);
}
|
||
// Shared handler instance; starts out empty and is filled in by GetInstance().
std::shared_ptr<PythonRpcHandler> PythonRpcHandler::python_rpc_handler_;

// Mutex used by GetInstance() while creating the shared instance.
std::mutex PythonRpcHandler::lock_;
|
||
std::shared_ptr<PythonRpcHandler> PythonRpcHandler::GetInstance() { | ||
if (python_rpc_handler_ == nullptr) { | ||
std::lock_guard<std::mutex> guard(lock_); | ||
if (python_rpc_handler_ == nullptr) { | ||
python_rpc_handler_ = std::make_shared<PythonRpcHandler>(); | ||
return python_rpc_handler_; | ||
} | ||
} | ||
return python_rpc_handler_; | ||
} | ||
|
||
} // namespace distributed | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <pybind11/pybind11.h> | ||
|
||
#include <memory> | ||
#include <mutex> | ||
#include <string> | ||
|
||
#include "paddle/fluid/platform/macros.h" | ||
|
||
namespace py = pybind11; | ||
|
||
namespace paddle { | ||
namespace distributed { | ||
|
||
class PYBIND11_EXPORT PythonRpcHandler { | ||
public: | ||
PythonRpcHandler(); | ||
~PythonRpcHandler() = default; | ||
static std::shared_ptr<PythonRpcHandler> GetInstance(); | ||
// Run a pickled Python function and return the result py::object | ||
py::object RunPythonFunc(const py::object& python_func); | ||
|
||
// Serialized a py::object into a string | ||
std::string Serialize(const py::object& obj); | ||
|
||
// Deserialize a string into a py::object | ||
py::object Deserialize(const std::string& obj); | ||
|
||
private: | ||
DISABLE_COPY_AND_ASSIGN(PythonRpcHandler); | ||
|
||
static std::shared_ptr<PythonRpcHandler> python_rpc_handler_; | ||
// Ref to `paddle.distributed.rpc.internal.run_py_func`. | ||
py::object py_run_function_; | ||
|
||
// Ref to `paddle.distributed.rpc.internal.serialize`. | ||
py::object py_serialize_; | ||
|
||
// Ref to `paddle.distributed.rpc.internal.deserialize`. | ||
py::object py_deserialize_; | ||
|
||
// Lock to protect initialization. | ||
static std::mutex lock_; | ||
}; | ||
|
||
} // namespace distributed | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
|
||
syntax = "proto2";
package paddle.distributed;

option cc_generic_services = true;
option cc_enable_arenas = true;

// Request envelope: an opaque byte payload.
message RpcRequest {
  required bytes message = 1;
}

// Response envelope: an opaque byte payload.
message RpcResponse {
  required bytes message = 1;
}

// Service exposed by every RPC worker.
service RpcBaseService {
  rpc Send(RpcRequest) returns (RpcResponse);
  rpc InvokeRpc(RpcRequest) returns (RpcResponse);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/distributed/rpc/rpc_agent.h" | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <vector> | ||
|
||
#include "paddle/fluid/platform/enforce.h" | ||
|
||
namespace paddle { | ||
namespace distributed { | ||
|
||
// Default RPC timeout, in milliseconds.
constexpr int kTimeoutMs = 500000;
// Timeout for establishing a channel connection, in milliseconds.
constexpr int kConnectTimeoutMs = 10000;
// Maximum number of retries configured on each channel.
constexpr int kMaxRetry = 5;
// Value passed to the server's Stop() call on shutdown, in milliseconds.
constexpr int kCloseWaitMs = 1000;
// Agent registered through SetAgentInstance(); empty until then.
std::shared_ptr<RpcAgent> RpcAgent::rpc_agent_instance_;
|
||
// Builds the agent for worker `name` out of the full list of worker infos.
//
// Indexes `infos` by name and by id, records this worker's rank (the id of
// the entry whose name equals `name`) and registers the RPC service on the
// embedded server. Fails through PADDLE_ENFORCE if `name` is not among
// `infos` or if the service cannot be added.
RpcAgent::RpcAgent(std::string name, std::vector<WorkerInfo> infos) {
  name_ = std::move(name);
  // Iterate by reference: each entry is copied into the maps anyway, there is
  // no need for an additional temporary copy per element.
  for (const auto &info : infos) {
    name_to_infos_.insert({info.name_, info});
    id_to_infos_.insert({info.id_, info});
  }
  this->infos_ = std::move(infos);
  auto it = name_to_infos_.find(name_);
  // Dereferencing end() would be undefined behavior, so reject an agent
  // whose own name is missing from the worker list.
  PADDLE_ENFORCE_NE(
      it,
      name_to_infos_.end(),
      platform::errors::OutOfRange("Worker %s doesn't exist!", name_));
  this->rank_ = it->second.id_;
  rpc_service_ = std::make_shared<RpcService>();
  // `name` has been moved into `name_` above, so the message must use `name_`.
  PADDLE_ENFORCE_EQ(
      server_.AddService(rpc_service_.get(), brpc::SERVER_DOESNT_OWN_SERVICE),
      0,
      platform::errors::Fatal("Fail to add service: %s", name_));
}
|
||
int RpcAgent::StartWorker() { | ||
auto info = GetWorkerInfo(name_); | ||
// Start the server. | ||
int port = info.port_; | ||
brpc::ServerOptions options; | ||
PADDLE_ENFORCE_EQ(server_.Start(port, &options), | ||
0, | ||
platform::errors::Fatal("Fail to start worker: %s", name_)); | ||
VLOG(0) << "Start worker : " << name_; | ||
return 0; | ||
} | ||
|
||
int RpcAgent::StartClient() { | ||
// Initialize the channel, NULL means using default options. | ||
brpc::ChannelOptions channel_options; | ||
channel_options.protocol = "baidu_std"; | ||
channel_options.timeout_ms = kTimeoutMs; | ||
channel_options.connection_type = "pooled"; | ||
channel_options.connect_timeout_ms = kConnectTimeoutMs; | ||
channel_options.max_retry = kMaxRetry; | ||
channels_.resize(name_to_infos_.size()); | ||
// build connection from client to all servers | ||
for (std::size_t i = 0; i < channels_.size(); i++) { | ||
auto info = id_to_infos_.find(i)->second; | ||
channels_[i].reset(new brpc::Channel()); | ||
PADDLE_ENFORCE_EQ( | ||
channels_[i]->Init(info.ip_.c_str(), info.port_, &channel_options), | ||
0, | ||
platform::errors::Fatal( | ||
"Fail to initialize channel: %d, ip: %s, port: %d", | ||
i, | ||
info.ip_, | ||
info.port_)); | ||
} | ||
VLOG(0) << "Init Channels: " << name_; | ||
return 0; | ||
} | ||
|
||
// Stops the embedded server, waits for it to finish and clears the
// registered agent instance. Always returns 0.
int RpcAgent::Stop() {
  VLOG(0) << "Worker: " << name_ << " is going to stop.";
  server_.Stop(kCloseWaitMs);
  server_.Join();
  VLOG(0) << "Worker: " << name_ << " has stopped";
  // Clear the registered instance last: if it holds the final reference to
  // this object, resetting it destroys `*this`, so no member may be touched
  // afterwards (the previous order read `name_` after the reset).
  rpc_agent_instance_ = nullptr;
  return 0;
}
void OnRpcDone::Run() { | ||
// delete this after Run | ||
std::unique_ptr<OnRpcDone> self_guard(this); | ||
PADDLE_ENFORCE_EQ( | ||
cntl_.Failed(), false, platform::errors::Fatal(cntl_.ErrorText())); | ||
promise_->set_value(response_.message()); | ||
VLOG(2) << "Received response from " << cntl_.remote_side() << " to " | ||
<< cntl_.local_side() << " (attached=" << cntl_.response_attachment() | ||
<< ")" | ||
<< " latency=" << cntl_.latency_us() << "us"; | ||
} | ||
|
||
std::future<std::string> RpcAgent::InvokeRpc(const std::string &py_func, | ||
const std::string &to, | ||
int timeout_ms = kTimeoutMs) { | ||
auto it = name_to_infos_.find(to); | ||
PADDLE_ENFORCE_NE( | ||
it, | ||
name_to_infos_.end(), | ||
platform::errors::OutOfRange("Worker %s doesn't exist!", to)); | ||
uint32_t id = it->second.id_; | ||
auto channel = channels_[id]; | ||
// `done` must be allocated on the heap because its life cycle is after | ||
// calling done.Run(). | ||
OnRpcDone *done = new OnRpcDone; | ||
done->cntl_.set_timeout_ms(timeout_ms); | ||
done->request_.set_message(py_func); | ||
std::future<std::string> fut = done->GetFuture(); | ||
RpcBaseService_Stub stub(channel.get()); | ||
stub.InvokeRpc(&done->cntl_, &done->request_, &done->response_, done); | ||
return fut; | ||
} | ||
|
||
// Returns the agent registered through SetAgentInstance(); fails through
// PADDLE_ENFORCE when no agent has been registered.
std::shared_ptr<RpcAgent> RpcAgent::RpcAgentInstance() {
  // The message points the user at init_rpc(); it previously named a
  // non-existent "int_rpc()".
  PADDLE_ENFORCE_NE(rpc_agent_instance_,
                    nullptr,
                    platform::errors::Fatal(
                        "RpcAgent is not set, please call "
                        "paddle.distributed.rpc.init_rpc() to init rpc agent."));
  return rpc_agent_instance_;
}
// Registers `agent` as the agent returned by RpcAgentInstance(). Fails
// through PADDLE_ENFORCE when an agent is already registered.
void RpcAgent::SetAgentInstance(std::shared_ptr<RpcAgent> agent) {
  PADDLE_ENFORCE_EQ(
      rpc_agent_instance_,
      nullptr,
      platform::errors::Fatal(
          "RpcAgent has been set, please don't set rpc agent repeatedly."));
  // `agent` is a by-value sink parameter: move it instead of copying to avoid
  // an extra reference-count round trip.
  rpc_agent_instance_ = std::move(agent);
}
} // namespace distributed | ||
} // namespace paddle |
Oops, something went wrong.