initialize processgroupnccl with store #40181

Merged 16 commits on Mar 7, 2022
45 changes: 18 additions & 27 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -156,36 +156,27 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
// Same as Wait
void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }

ProcessGroupNCCL::ProcessGroupNCCL(const ProcessGroupStrategy& strategy,
ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank, int size)
: ProcessGroup(rank, size), strategy_(strategy) {}

void ProcessGroupNCCL::BcastNCCLId(
std::vector<ncclUniqueId>& nccl_ids, // NOLINT
int root, int server_fd) {
if (strategy_.local_rank_ == root) {
std::vector<std::string> other_trainers;
for (auto& ep : strategy_.trainer_endpoints_) {
if (ep != strategy_.current_endpoint_) {
other_trainers.push_back(ep);
}
}
platform::SendBroadCastCommID(other_trainers, &nccl_ids);
} else {
platform::RecvBroadCastCommID(server_fd, strategy_.current_endpoint_,
&nccl_ids);
}
}
: ProcessGroup(rank, size), store_(store) {}

void ProcessGroupNCCL::BroadcastUniqueNCCLID(
std::vector<ncclUniqueId>& nccl_ids) { // NOLINT

int server_fd = -1;
if (rank_ != 0) {
server_fd = platform::SocketServer::GetInstance(strategy_.current_endpoint_)
.socket();
if (rank_ == 0) {
for (size_t i = 0; i < nccl_ids.size(); i++) {
auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(i);
auto nccl_id = std::vector<uint8_t>(
reinterpret_cast<uint8_t*>(&nccl_ids[i]),
reinterpret_cast<uint8_t*>(&nccl_ids[i]) + NCCL_UNIQUE_ID_BYTES);
store_->set(key, nccl_id);
}
} else {
for (size_t i = 0; i < nccl_ids.size(); i++) {
auto key = "ProcessGroupNCCL/nccl_ids/" + std::to_string(i);
auto ret = store_->get(key);
std::memcpy(&nccl_ids[i], ret.data(), ret.size());
}
}
BcastNCCLId(nccl_ids, 0, server_fd);
}

// create NCCLManager cache for places_key
@@ -213,8 +204,8 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
}
BroadcastUniqueNCCLID(nccl_ids);

VLOG(3) << "init nccl rank: " << strategy_.local_rank_
<< ", nranks: " << strategy_.nranks_ << ", place: " << places_key
VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
<< ", place: " << places_key
<< ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);

std::vector<std::unique_ptr<CUDADeviceContext>> dev_ctx;
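The rewritten BroadcastUniqueNCCLID above replaces the socket-based broadcast with a store rendezvous: rank 0 serializes each ncclUniqueId into raw bytes and publishes it under the key "ProcessGroupNCCL/nccl_ids/<i>", while every other rank fetches the same key and copies the bytes back. A minimal Python sketch of that pattern, with a toy dict-backed store standing in for the real blocking store (all names here are placeholders, not Paddle APIs):

# Sketch of the rank-0-publishes / others-fetch rendezvous; `store`, `rank`,
# `num_ids`, and `gen_id` are placeholders, not Paddle APIs.
def broadcast_unique_ids(store, rank, num_ids, gen_id):
    prefix = "ProcessGroupNCCL/nccl_ids/"
    if rank == 0:
        ids = [gen_id() for _ in range(num_ids)]          # e.g. ncclGetUniqueId
        for i, uid in enumerate(ids):
            store.set(prefix + str(i), uid)               # publish raw bytes
    else:
        ids = [store.get(prefix + str(i)) for i in range(num_ids)]
    return ids

# Toy single-process check; a real store's get() would block until the key
# has been set by rank 0, instead of failing when it is absent.
class DictStore:
    def __init__(self):
        self._data = {}
    def set(self, key, value):
        self._data[key] = value
    def get(self, key):
        return self._data[key]

store = DictStore()
broadcast_unique_ids(store, 0, 2, gen_id=lambda: b"fake-unique-id")
print(broadcast_unique_ids(store, 1, 2, gen_id=None))     # [b'fake-unique-id', b'fake-unique-id']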
5 changes: 3 additions & 2 deletions paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -25,6 +25,7 @@
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_context.h"

#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
@@ -75,7 +76,7 @@ class ProcessGroupNCCL : public ProcessGroup {
private:
};

ProcessGroupNCCL(const ProcessGroupStrategy& strategy, int rank, int size);
ProcessGroupNCCL(const std::shared_ptr<Store>& store, int rank, int size);

const std::string GetBackendName() const override {
return std::string(NCCL_BACKEND_NAME);
@@ -118,7 +119,7 @@ class ProcessGroupNCCL : public ProcessGroup {
const std::vector<Tensor>& inputs);

protected:
ProcessGroupStrategy strategy_;
std::shared_ptr<Store> store_;
std::shared_ptr<NCCLCommManager> nccl_comm_;
std::mutex mutex_;
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLCommManager>>>
23 changes: 17 additions & 6 deletions paddle/fluid/distributed/store/store.h
@@ -25,15 +25,26 @@ namespace distributed {

class Store {
public:
Store() = delete;
Store() : _timeout(tcputils::kNoTimeout) {}
explicit Store(const std::chrono::seconds& timeout) : _timeout(timeout) {}
virtual ~Store() = default;

virtual int64_t add(const std::string& key, int64_t value) = 0;
virtual std::vector<uint8_t> get(const std::string& key) = 0;
virtual void wait(const std::string& key) = 0;
virtual void set(const std::string& key,
const std::vector<uint8_t>& value) = 0;
virtual int64_t add(const std::string& key, int64_t value) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Implement the add method in the subclass."));
}
virtual std::vector<uint8_t> get(const std::string& key) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Implement the get method in the subclass."));
}
virtual void wait(const std::string& key) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Implement the wait method in the subclass."));
}
virtual void set(const std::string& key, const std::vector<uint8_t>& value) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Implement the set method in the subclass."));
}

virtual const std::chrono::seconds& timeout() const { return _timeout; }

36 changes: 30 additions & 6 deletions paddle/fluid/pybind/communication.cc
@@ -30,18 +30,42 @@ namespace pybind {

using TCPStore = paddle::distributed::TCPStore;

void BindTCPStore(py::module* m) {
py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore")
void BindTCPStore(py::module *m) {
auto Store =
py::class_<distributed::Store, std::shared_ptr<distributed::Store>>(
*m, "Store")
.def(py::init<>())
.def("set",
[](distributed::Store &self, const std::string &key,
const std::string &value) {
std::vector<uint8_t> data(value.begin(), value.end());
self.set(key, data);
},
py::arg("key"), py::arg("value"),
py::call_guard<py::gil_scoped_release>())
.def("get",
[](distributed::Store &self,
const std::string &key) -> py::bytes {
auto data = self.get(key);
return py::bytes(reinterpret_cast<char *>(data.data()),
data.size());
},
py::arg("key"), py::call_guard<py::gil_scoped_release>())
.def("add", &distributed::Store::add,
py::call_guard<py::gil_scoped_release>())
.def("wait", &distributed::Store::wait,
py::call_guard<py::gil_scoped_release>());

py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore", Store)
.def(py::init([](std::string hostname, uint16_t port, bool is_master,
size_t world_size, std::chrono::seconds timeout) {
return std::make_shared<TCPStore>(hostname, port, is_master,
world_size, timeout);
}),
py::arg("hostname"), py::arg("port"), py::arg("is_master"),
py::arg("world_size"), py::arg("timeout"),
py::call_guard<py::gil_scoped_release>())
.def("add", &TCPStore::add)
.def("get", &TCPStore::get);
py::arg("world_size"),
py::arg("timeout") = distributed::tcputils::kNoTimeout,
py::call_guard<py::gil_scoped_release>());
}

} // namespace pybind
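These bindings make the store usable directly from Python: set() accepts a str and stores it as bytes, get() returns bytes, and add() and wait() forward to the C++ methods, all with the GIL released. A hedged usage sketch; the host, port, and single-process world_size below are example values, and the commented results are what the bindings above suggest rather than documented guarantees:

# Example values only: pick a free port, and set is_master/world_size
# per rank in a real multi-process launch.
from paddle.fluid import core

store = core.TCPStore("127.0.0.1", 6170, True, 1)   # timeout defaults to kNoTimeout
store.set("nccl/id0", "opaque-bytes")               # str payload is stored as uint8 bytes
print(store.get("nccl/id0"))                        # expected: b'opaque-bytes'
print(store.add("counter", 1))                      # atomic add; presumably returns the updated value
store.wait("nccl/id0")                              # returns once the key exists

# The Store base class is exposed as well, but its default methods just throw,
# so a bare core.Store() is only a placeholder for subclasses like TCPStore.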
44 changes: 2 additions & 42 deletions paddle/fluid/pybind/distributed_py.cc
@@ -197,7 +197,7 @@ void BindDistributed(py::module *m) {
py::class_<distributed::ProcessGroupNCCL,
std::shared_ptr<distributed::ProcessGroupNCCL>>(
*m, "ProcessGroupNCCL", ProcessGroup)
.def(py::init<const distributed::ProcessGroupStrategy &, int, int>(),
.def(py::init<const std::shared_ptr<distributed::Store> &, int, int>(),
py::call_guard<py::gil_scoped_release>());
#endif

@@ -210,44 +210,6 @@ void BindDistributed(py::module *m) {
.def("synchronize", &distributed::ProcessGroup::Task::Synchronize,
py::call_guard<py::gil_scoped_release>());

// define parallel strategy, it will be removed
py::class_<distributed::ProcessGroupStrategy> pg_strategy(
*m, "ProcessGroupStrategy", "");
pg_strategy.def(py::init())
.def_property("nranks",
[](const distributed::ProcessGroupStrategy &self) {
return self.nranks_;
},
[](distributed::ProcessGroupStrategy &self, int nranks) {
self.nranks_ = nranks;
})
.def_property("local_rank",
[](const distributed::ProcessGroupStrategy &self) {
return self.local_rank_;
},
[](distributed::ProcessGroupStrategy &self,
int local_rank) { self.local_rank_ = local_rank; })
.def_property(
"trainer_endpoints",
[](const distributed::ProcessGroupStrategy &self) {
return self.trainer_endpoints_;
},
[](distributed::ProcessGroupStrategy &self,
std::vector<std::string> eps) { self.trainer_endpoints_ = eps; })
.def_property("current_endpoint",
[](const distributed::ProcessGroupStrategy &self) {
return self.current_endpoint_;
},
[](distributed::ProcessGroupStrategy &self,
const std::string &ep) { self.current_endpoint_ = ep; })
.def_property("nrings",
[](const distributed::ProcessGroupStrategy &self) {
return self.nrings_;
},
[](distributed::ProcessGroupStrategy &self, int nrings) {
self.nrings_ = nrings;
});

#if defined(PADDLE_WITH_GLOO)
py::class_<GlooOptions>(*m, "GlooOptions")
.def(py::init<>())
@@ -279,9 +241,7 @@ void BindDistributed(py::module *m) {
return std::make_shared<ProcessGroupGloo>(store, rank, world_size,
opts);
}),
py::arg("store"), py::arg("rank"),
py::arg("world_size"), // py::arg("timeout") =
// kProcessGroupDefaultTimeout,
py::arg("store"), py::arg("rank"), py::arg("world_size"),
py::call_guard<py::gil_scoped_release>())
.def_static("create_default_device",
&ProcessGroupGloo::createDefaultDevice);
19 changes: 5 additions & 14 deletions python/paddle/fluid/tests/unittests/process_group_nccl.py
@@ -27,22 +27,13 @@
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv

ProcessGroupStrategy = core.ProcessGroupStrategy


def init_process_group(strategy=None):
# this will remove
if strategy is None:
strategy = ProcessGroupStrategy()
strategy.nranks = ParallelEnv().nranks
strategy.local_rank = ParallelEnv().local_rank
strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
strategy.current_endpoint = ParallelEnv().current_endpoint
if strategy.nranks < 2:
return

pg_group = core.ProcessGroupNCCL(strategy, strategy.local_rank,
strategy.nranks)
nranks = ParallelEnv().nranks
rank = ParallelEnv().local_rank
is_master = rank == 0
store = paddle.fluid.core.TCPStore("127.0.0.1", 6173, is_master, nranks)
pg_group = core.ProcessGroupNCCL(store, rank, nranks)

return pg_group
