Skip to content

Commit

Permalink
make tcp store a global instance
Browse files Browse the repository at this point in the history
  • Loading branch information
LiYuRio committed Aug 4, 2023
1 parent dc4b48f commit 2085a82
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 48 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/pybind/communication.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License. */
#include <memory>
#include <string>

#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"

Expand Down Expand Up @@ -109,6 +110,9 @@ void BindTCPStore(py::module *m) {
py::arg("world_size"),
py::arg("timeout") = 900,
py::call_guard<py::gil_scoped_release>());

m->def("create_or_get_tcp_store",
&phi::distributed::CreateOrGetGlobalTCPStore);
}

} // namespace pybind
Expand Down
11 changes: 3 additions & 8 deletions paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,8 @@ proto_library(auto_parallel_proto SRCS auto_parallel.proto)
set(DISTRIBUTED_SRCS "")

if(WITH_DISTRIBUTE)
list(
APPEND
DISTRIBUTED_SRCS
dist_tensor.cc
reshard_function.cc
reshard_split_functor.cc
reshard_utils.cc
r_to_s_reshard_function.cc)
list(APPEND DISTRIBUTED_SRCS dist_tensor.cc reshard_function.cc
reshard_split_functor.cc r_to_s_reshard_function.cc)
endif()

collect_srcs(
Expand All @@ -20,4 +14,5 @@ collect_srcs(
process_mesh.cc
dist_attr.cc
dist_mapper.cc
reshard_utils.cc
${DISTRIBUTED_SRCS})
72 changes: 63 additions & 9 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
#include <cstdlib>
#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"

namespace phi {
namespace distributed {
using auto_parallel::str_split;

bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping) {
return std::any_of(dims_mapping.begin(),
Expand All @@ -33,15 +35,6 @@ bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping) {
[](int64_t value) { return value == -1; });
}

int64_t GetCurGlobalRank() {
const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
PADDLE_ENFORCE_NOT_NULL(
cur_rank,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
return std::atoi(cur_rank);
}

std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh) {
const auto& process_shape = process_mesh.shape();
const auto& process_ids = process_mesh.process_ids();
Expand Down Expand Up @@ -80,5 +73,66 @@ std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
return split_axis_to_mesh_axis;
}

// Returns the global rank of the current process, parsed from the
// 'PADDLE_TRAINER_ID' environment variable. Raises NotFound when the
// variable is not set.
int64_t GetCurGlobalRank() {
  const char* rank_str = std::getenv("PADDLE_TRAINER_ID");
  PADDLE_ENFORCE_NOT_NULL(
      rank_str,
      phi::errors::NotFound(
          "The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
  return static_cast<int64_t>(std::atoi(rank_str));
}

// Returns the total number of trainer processes, parsed from the
// 'PADDLE_TRAINERS_NUM' environment variable. Raises NotFound when the
// variable is not set.
int64_t GetGlobalWorldSize() {
  const char* num_str = std::getenv("PADDLE_TRAINERS_NUM");
  PADDLE_ENFORCE_NOT_NULL(
      num_str,
      phi::errors::NotFound(
          "The environment variable 'PADDLE_TRAINERS_NUM' cannot be found."));
  return static_cast<int64_t>(std::atoi(num_str));
}

namespace {
// Returns the "host:port" endpoint of the master process. Prefers the
// 'PADDLE_MASTER' environment variable; when it is unset, falls back to the
// first entry of the comma-separated 'PADDLE_TRAINER_ENDPOINTS' list.
// Raises NotFound when neither variable is set.
//
// NOTE(review): the previous version re-checked master_endpoint for null
// after the fallback branch; that check was unreachable (the null case had
// already early-returned), so it has been removed.
std::string GetMasterEndpoint() {
  const char* master_endpoint = std::getenv("PADDLE_MASTER");
  if (master_endpoint) {
    return master_endpoint;
  }
  // PADDLE_MASTER is unset: use the first trainer endpoint instead.
  const char* trainer_endpoints = std::getenv("PADDLE_TRAINER_ENDPOINTS");
  PADDLE_ENFORCE_NOT_NULL(
      trainer_endpoints,
      phi::errors::NotFound("The environment variable "
                            "'PADDLE_TRAINER_ENDPOINTS' cannot be found."));
  return str_split(trainer_endpoints, ",")[0];
}

}  // namespace

std::string GetMasterAddr() {
std::string master_endpoint = GetMasterEndpoint();
return str_split(master_endpoint, ":")[0];
}

uint16_t GetMasterPort() {
std::string master_endpoint = GetMasterEndpoint();
return std::stoi(str_split(master_endpoint, ":")[1]);
}

std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore() {
std::string host = GetMasterAddr();
uint16_t port = GetMasterPort();
int64_t cur_rank = GetCurGlobalRank();
int64_t world_size = GetGlobalWorldSize();
bool is_master = (cur_rank == 0);

static std::shared_ptr<TCPStore> store =
std::make_shared<TCPStore>(host, port, is_master, world_size);
return store;
}

} // namespace distributed
} // namespace phi
16 changes: 14 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

namespace phi {
namespace distributed {
class TCPStore;

namespace auto_parallel {

class ProcessMesh;
Expand All @@ -31,8 +35,6 @@ bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping);

bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping);

int64_t GetCurGlobalRank();

// Get the coordinate of cur rank in process mesh. For example, the process mesh
// is [[0, 1], [2, 3], [4, 5], [6, 7]], if the current rank is 4, then will
// return [2, 0]; if the current rank is 3, then will return [1, 1].
Expand All @@ -46,5 +48,15 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping);

int64_t GetCurGlobalRank();

std::string GetMasterAddr();

int64_t GetGlobalWorldSize();

uint16_t GetMasterPort();

std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore();

} // namespace distributed
} // namespace phi
43 changes: 14 additions & 29 deletions python/paddle/distributed/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import datetime
import os

import paddle

Expand Down Expand Up @@ -320,32 +319,18 @@ def is_available():


def _init_parallel_env(backend):
master_endpoint = os.getenv("PADDLE_MASTER", None)
if master_endpoint is None:
master_endpoint = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')[0]
assert (
master_endpoint is not None
), "Please set PADDLE_MASTER enviroment variable."
if master_endpoint:
master_addr = master_endpoint.split(":")[0]
master_port = int(master_endpoint.split(":")[1])
global_env = _get_global_env()
rank = global_env.rank
world_size = global_env.world_size
dev_id = global_env.device_id
is_master = rank == 0
store = core.TCPStore(
master_addr,
master_port,
is_master,
world_size,
store = core.get_global_tcp_store()
global_env = _get_global_env()
rank = global_env.rank
world_size = global_env.world_size
dev_id = global_env.device_id

if backend == "gloo":
core.CommContextManager.create_gloo_comm_context(
store, "0", rank, world_size
)
elif backend == "nccl":
core.CommContextManager.set_cuda_device_id(dev_id)
core.CommContextManager.create_nccl_comm_context(
store, "0", rank, world_size
)
if backend == "gloo":
core.CommContextManager.create_gloo_comm_context(
store, "0", rank, world_size
)
elif backend == "nccl":
core.CommContextManager.set_cuda_device_id(dev_id)
core.CommContextManager.create_nccl_comm_context(
store, "0", rank, world_size
)

0 comments on commit 2085a82

Please sign in to comment.