Skip to content

Commit

Permalink
[WIP]飞桨PaddlePaddle 分布式强化学习功能研发 (#45998)
Browse files Browse the repository at this point in the history
* add rpc module in cpp side

* add rpc module in python side

* support win32 and mac for rpc

* 代码优化

* 优化代码

* update rpc

* update rpc launch

* rpc remove rank and world_size api

* fix logger import bug

* remove support for win and mac

* remove support for xpu, npu, cinn and rocm

* remove support for xpu, npu, cinn and rocm

* fix shutdown barrier timeout bug

* update:python_rpc_handler to shared ptr

* fix master shutdown first bug

* tests support for cpu

* update log to vlog

* update get service info api

* add single process test case

* remove process group

* remove some useless dependencies

* update rpc api comments

* update rpc comments: Example to Examples

* update rpc api comments

* update rpc api comments

* update launch api comments

* update init_rpc comments

* update rpc sync and async comments

* fix bug: init_rpc can't be called repeatedly in a process

* update rpc api comment: make master endpoint unique

* update rpc api:service to worker, timeout_ms to timeout

* rename ServiceInfo to WorkerInfo

* refactor: rename server to worker, log to vlog

* add launch test

* remove unused codes

* refine
  • Loading branch information
Ningsir authored Oct 13, 2022
1 parent 8474392 commit f0afcab
Show file tree
Hide file tree
Showing 27 changed files with 1,803 additions and 3 deletions.
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ set(DISTRIBUTE_COMPILE_FLAGS
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
endif()

# The rpc subdirectory is only added to the build when LINUX is set.
if(LINUX)
add_subdirectory(rpc)
endif()
add_subdirectory(common)
add_subdirectory(ps)
add_subdirectory(test)
Expand Down
13 changes: 13 additions & 0 deletions paddle/fluid/distributed/rpc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Source files that make up the paddle_rpc library.
set(PADDLE_RPC_SRCS python_rpc_handler.cc rpc_agent.cc)

# Compile each RPC source with the flags shared by the distributed module.
set_source_files_properties(
python_rpc_handler.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(rpc_agent.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})

# Libraries the RPC library depends on, in addition to the code generated
# from rpc.proto (paddle_rpc_proto).
set(PADDLE_RPC_DEPS brpc protobuf glog pybind)
proto_library(paddle_rpc_proto SRCS rpc.proto)
cc_library(
paddle_rpc
SRCS ${PADDLE_RPC_SRCS}
DEPS ${PADDLE_RPC_DEPS} paddle_rpc_proto)
57 changes: 57 additions & 0 deletions paddle/fluid/distributed/rpc/future_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <pybind11/pybind11.h>

#include <cassert>
#include <future>
#include <string>

#include "paddle/fluid/distributed/rpc/python_rpc_handler.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"

namespace py = pybind11;
namespace paddle {
namespace distributed {
// Wraps the std::future<std::string> produced by an RPC call so that the
// serialized reply can be waited on and turned back into a Python object.
//
// The wrapper is move-constructed from a future and is neither copyable nor
// assignable. The result can be retrieved at most once.
class FutureWrapper {
 public:
  FutureWrapper() {}
  explicit FutureWrapper(std::future<std::string> fut) : fut_(std::move(fut)) {}

  // Blocks until the wrapped future is ready, then deserializes the received
  // bytes through PythonRpcHandler and returns the resulting object.
  //
  // Must be called with the GIL released; the GIL is re-acquired internally
  // only after the reply has arrived. Raises an enforce error if the GIL is
  // held, or if there is no pending result (default-constructed wrapper, or
  // wait() already consumed the result).
  py::object wait() {
    // GIL must be released, otherwise fut_.get() blocking will cause the
    // service to fail to process RPC requests, leading to deadlock
    PADDLE_ENFORCE_EQ(
        PyGILState_Check(),
        false,
        platform::errors::Fatal(
            "GIL must be released before fut.wait(), otherwise fut_.get() "
            "blocking will cause the service to fail to "
            "process RPC requests, leading to deadlock"));
    // std::future::get() has undefined behavior when valid() is false, so
    // reject that case explicitly instead of relying on the implementation.
    PADDLE_ENFORCE_EQ(
        fut_.valid(),
        true,
        platform::errors::Fatal(
            "FutureWrapper holds no pending result: it was either "
            "default-constructed or wait() has already been called."));
    const std::string payload = fut_.get();
    py::gil_scoped_acquire ag;
    std::shared_ptr<PythonRpcHandler> python_handler =
        PythonRpcHandler::GetInstance();
    // Deserialize() takes the raw string and wraps it in py::bytes itself, so
    // no intermediate py::bytes object is needed here.
    return python_handler->Deserialize(payload);
  }

 private:
  DISABLE_COPY_AND_ASSIGN(FutureWrapper);
  std::future<std::string> fut_;
};
} // namespace distributed
} // namespace paddle
67 changes: 67 additions & 0 deletions paddle/fluid/distributed/rpc/python_rpc_handler.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/rpc/python_rpc_handler.h"

namespace paddle {
namespace distributed {
constexpr auto kInternalModule = "paddle.distributed.rpc.internal";

// Fetches the attribute called `name` from `module` and hands it back as a
// py::object.
py::object getFunction(const py::object& module, const char* name) {
  return module.attr(name);
}

// Imports the internal RPC Python module while holding the GIL and caches the
// three helper callables it exposes.
PythonRpcHandler::PythonRpcHandler() {
  py::gil_scoped_acquire gil;
  const py::object internal_module = py::module::import(kInternalModule);
  py_run_function_ = getFunction(internal_module, "_run_py_func");
  py_serialize_ = getFunction(internal_module, "_serialize");
  py_deserialize_ = getFunction(internal_module, "_deserialize");
}

// Calls the cached `_run_py_func` helper on `python_func` with the GIL held
// and returns whatever the helper returns.
py::object PythonRpcHandler::RunPythonFunc(const py::object& python_func) {
  py::gil_scoped_acquire gil;
  py::object result = py_run_function_(python_func);
  return result;
}

// Runs the cached `_serialize` helper on `obj` with the GIL held and converts
// the helper's result to a std::string.
std::string PythonRpcHandler::Serialize(const py::object& obj) {
  py::gil_scoped_acquire gil;
  return py_serialize_(obj).cast<std::string>();
}

// Wraps `obj` in a Python bytes object and passes it to the cached
// `_deserialize` helper with the GIL held; returns the helper's result.
py::object PythonRpcHandler::Deserialize(const std::string& obj) {
  py::gil_scoped_acquire gil;
  py::bytes raw(obj);
  return py_deserialize_(raw);
}

// Storage for the shared handler and the mutex that guards it.
std::shared_ptr<PythonRpcHandler> PythonRpcHandler::python_rpc_handler_ =
    nullptr;
std::mutex PythonRpcHandler::lock_;

// Returns the shared PythonRpcHandler, constructing it on the first call.
//
// The mutex is taken on every call: `python_rpc_handler_` is a plain
// std::shared_ptr, so reading it without the lock while another thread
// assigns it under the lock would be a data race.
std::shared_ptr<PythonRpcHandler> PythonRpcHandler::GetInstance() {
  std::lock_guard<std::mutex> guard(lock_);
  if (python_rpc_handler_ == nullptr) {
    python_rpc_handler_ = std::make_shared<PythonRpcHandler>();
  }
  return python_rpc_handler_;
}

} // namespace distributed
} // namespace paddle
62 changes: 62 additions & 0 deletions paddle/fluid/distributed/rpc/python_rpc_handler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <pybind11/pybind11.h>

#include <memory>
#include <mutex>
#include <string>

#include "paddle/fluid/platform/macros.h"

namespace py = pybind11;

namespace paddle {
namespace distributed {

// Bridge to the Python helpers of `paddle.distributed.rpc.internal`.
// A single shared instance is handed out by GetInstance().
class PYBIND11_EXPORT PythonRpcHandler {
public:
// Imports the internal Python module and caches the helper callables below.
PythonRpcHandler();
~PythonRpcHandler() = default;
// Returns the shared handler instance, constructing it on the first call.
static std::shared_ptr<PythonRpcHandler> GetInstance();
// Invoke the Python helper `_run_py_func` on `python_func` and return the
// resulting py::object
py::object RunPythonFunc(const py::object& python_func);

// Serialize a py::object into a string via the Python helper `_serialize`
std::string Serialize(const py::object& obj);

// Deserialize a string into a py::object via the Python helper `_deserialize`
py::object Deserialize(const std::string& obj);

private:
DISABLE_COPY_AND_ASSIGN(PythonRpcHandler);

// The shared instance returned by GetInstance().
static std::shared_ptr<PythonRpcHandler> python_rpc_handler_;
// Ref to `paddle.distributed.rpc.internal._run_py_func`.
py::object py_run_function_;

// Ref to `paddle.distributed.rpc.internal._serialize`.
py::object py_serialize_;

// Ref to `paddle.distributed.rpc.internal._deserialize`.
py::object py_deserialize_;

// Lock to protect initialization.
static std::mutex lock_;
};

} // namespace distributed
} // namespace paddle
33 changes: 33 additions & 0 deletions paddle/fluid/distributed/rpc/rpc.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


syntax="proto2";
package paddle.distributed;

option cc_generic_services = true;
option cc_enable_arenas = true;

// Request envelope: a single opaque byte string.
message RpcRequest {
required bytes message = 1;
};

// Response envelope: a single opaque byte string.
message RpcResponse {
required bytes message = 1;
};

// RPC service exchanged between workers. InvokeRpc is the method that
// RpcAgent::InvokeRpc calls through the generated stub.
service RpcBaseService {
rpc Send(RpcRequest) returns (RpcResponse);
rpc InvokeRpc(RpcRequest) returns (RpcResponse);
};
145 changes: 145 additions & 0 deletions paddle/fluid/distributed/rpc/rpc_agent.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/rpc/rpc_agent.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace distributed {

// Channel timeout in milliseconds; also the default `timeout_ms` of
// RpcAgent::InvokeRpc.
const int kTimeoutMs = 500000;
// Channel connect timeout in milliseconds.
const int kConnectTimeoutMs = 10000;
// Value assigned to the channel option `max_retry`.
const int kMaxRetry = 5;
// Argument passed to server_.Stop() in RpcAgent::Stop().
const int kCloseWaitMs = 1000;
// Agent shared by the whole process: installed by SetAgentInstance() and
// cleared by Stop().
std::shared_ptr<RpcAgent> RpcAgent::rpc_agent_instance_ = nullptr;

// Builds the agent for the worker called `name`.
//
// `infos` describes the workers known to this agent and is indexed both by
// name and by id. It must contain an entry whose name_ equals `name`; that
// entry's id_ becomes this agent's rank. The RPC service is created and
// registered on the server, which does not take ownership of it.
//
// Raises an enforce error if `name` is not listed in `infos` or if the
// service cannot be added to the server.
RpcAgent::RpcAgent(std::string name, std::vector<WorkerInfo> infos) {
  name_ = std::move(name);
  // Iterate by reference: each entry is copied into the maps anyway.
  for (const auto &info : infos) {
    name_to_infos_.insert({info.name_, info});
    id_to_infos_.insert({info.id_, info});
  }
  this->infos_ = std::move(infos);
  auto it = name_to_infos_.find(name_);
  // Dereferencing end() would be undefined behavior, so fail loudly instead.
  PADDLE_ENFORCE_EQ(
      it != name_to_infos_.end(),
      true,
      platform::errors::Fatal(
          "Worker %s is not contained in the given worker infos", name_));
  this->rank_ = it->second.id_;
  rpc_service_ = std::make_shared<RpcService>();
  // `name` has been moved from above, so the message must use `name_`.
  PADDLE_ENFORCE_EQ(
      server_.AddService(rpc_service_.get(), brpc::SERVER_DOESNT_OWN_SERVICE),
      0,
      platform::errors::Fatal("Fail to add service: %s", name_));
}

int RpcAgent::StartWorker() {
auto info = GetWorkerInfo(name_);
// Start the server.
int port = info.port_;
brpc::ServerOptions options;
PADDLE_ENFORCE_EQ(server_.Start(port, &options),
0,
platform::errors::Fatal("Fail to start worker: %s", name_));
VLOG(0) << "Start worker : " << name_;
return 0;
}

// Creates one brpc channel per known worker, stored in `channels_` at the
// index equal to the worker's id. Returns 0; any failure raises an enforce
// error.
//
// Worker ids are therefore required to be exactly 0 .. N-1, where N is the
// number of known workers.
int RpcAgent::StartClient() {
  // Options shared by every channel.
  brpc::ChannelOptions channel_options;
  channel_options.protocol = "baidu_std";
  channel_options.timeout_ms = kTimeoutMs;
  channel_options.connection_type = "pooled";
  channel_options.connect_timeout_ms = kConnectTimeoutMs;
  channel_options.max_retry = kMaxRetry;
  channels_.resize(name_to_infos_.size());
  // build connection from client to all servers
  for (std::size_t i = 0; i < channels_.size(); i++) {
    auto it = id_to_infos_.find(i);
    // A missing id would otherwise dereference end(), which is undefined.
    PADDLE_ENFORCE_EQ(
        it != id_to_infos_.end(),
        true,
        platform::errors::Fatal(
            "Worker ids must cover [0, %d), but no worker has id %d",
            channels_.size(),
            i));
    const auto &info = it->second;
    channels_[i].reset(new brpc::Channel());
    PADDLE_ENFORCE_EQ(
        channels_[i]->Init(info.ip_.c_str(), info.port_, &channel_options),
        0,
        platform::errors::Fatal(
            "Fail to initialize channel: %d, ip: %s, port: %d",
            i,
            info.ip_,
            info.port_));
  }
  VLOG(0) << "Init Channels: " << name_;
  return 0;
}

// Stops the brpc server, waits for it to finish, and clears the process-wide
// agent instance. Returns 0.
int RpcAgent::Stop() {
  VLOG(0) << "Worker: " << name_ << " is going to stop.";
  server_.Stop(kCloseWaitMs);
  server_.Join();
  VLOG(0) << "Worker: " << name_ << " has stopped";
  // Keep this as the last statement that precedes the return:
  // `rpc_agent_instance_` may be the only owner of `*this`, in which case
  // resetting it destroys the agent and no member may be touched afterwards.
  rpc_agent_instance_ = nullptr;
  return 0;
}
// Completion callback passed as `done` to stub.InvokeRpc (see
// RpcAgent::InvokeRpc). Fulfils the promise with the response message, or
// with the enforce error when the controller reports a failure, and deletes
// this closure on exit.
void OnRpcDone::Run() {
  // delete this after Run
  std::unique_ptr<OnRpcDone> self_guard(this);
  try {
    PADDLE_ENFORCE_EQ(
        cntl_.Failed(), false, platform::errors::Fatal(cntl_.ErrorText()));
  } catch (...) {
    // Do not let the error escape from the callback: hand it to whoever is
    // waiting on the future, so that it is rethrown from future::get().
    promise_->set_exception(std::current_exception());
    return;
  }
  promise_->set_value(response_.message());
  VLOG(2) << "Received response from " << cntl_.remote_side() << " to "
          << cntl_.local_side() << " (attached=" << cntl_.response_attachment()
          << ")"
          << " latency=" << cntl_.latency_us() << "us";
}

std::future<std::string> RpcAgent::InvokeRpc(const std::string &py_func,
const std::string &to,
int timeout_ms = kTimeoutMs) {
auto it = name_to_infos_.find(to);
PADDLE_ENFORCE_NE(
it,
name_to_infos_.end(),
platform::errors::OutOfRange("Worker %s doesn't exist!", to));
uint32_t id = it->second.id_;
auto channel = channels_[id];
// `done` must be allocated on the heap because its life cycle is after
// calling done.Run().
OnRpcDone *done = new OnRpcDone;
done->cntl_.set_timeout_ms(timeout_ms);
done->request_.set_message(py_func);
std::future<std::string> fut = done->GetFuture();
RpcBaseService_Stub stub(channel.get());
stub.InvokeRpc(&done->cntl_, &done->request_, &done->response_, done);
return fut;
}

// Returns the process-wide agent installed by SetAgentInstance().
// Raises an enforce error if no agent is currently set.
std::shared_ptr<RpcAgent> RpcAgent::RpcAgentInstance() {
  // The message points the user at init_rpc(); the previous text named a
  // non-existent `int_rpc()`.
  PADDLE_ENFORCE_NE(rpc_agent_instance_,
                    nullptr,
                    platform::errors::Fatal(
                        "RpcAgent is not set, please call "
                        "paddle.distributed.rpc.init_rpc() to init rpc agent."));
  return rpc_agent_instance_;
}
// Installs `agent` as the process-wide agent. Only allowed while no agent is
// set; otherwise an enforce error is raised.
void RpcAgent::SetAgentInstance(std::shared_ptr<RpcAgent> agent) {
  const bool already_set = (rpc_agent_instance_ != nullptr);
  PADDLE_ENFORCE_EQ(
      already_set,
      false,
      platform::errors::Fatal(
          "RpcAgent has been set, please don't set rpc agent repeatly."));
  rpc_agent_instance_ = std::move(agent);
}
} // namespace distributed
} // namespace paddle
Loading

0 comments on commit f0afcab

Please sign in to comment.