[dist attr migrated to phi] Dist attr (#53848)
* merge code from forsish

* polish

* paddle/fluid/pybind/auto_parallel_py.cc

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
liuzhenhai93 authored May 23, 2023
1 parent 4af0f14 commit be1152a
Showing 26 changed files with 607 additions and 483 deletions.
24 changes: 4 additions & 20 deletions paddle/fluid/distributed/auto_parallel/CMakeLists.txt
@@ -1,23 +1,7 @@
 proto_library(auto_parallel_proto SRCS auto_parallel.proto)

-cc_library(
-  device_mesh
-  SRCS device_mesh.cc
-  DEPS auto_parallel_proto phi_enforce)
-
-cc_library(
-  process_mesh
-  SRCS process_mesh.cc
-  DEPS auto_parallel_proto phi_enforce)
-
 cc_library(
-  dist_attr
+  op_dist_attr
   SRCS dist_attr.cc
-  DEPS process_mesh auto_parallel_proto proto_desc phi_enforce)
-
-cc_library(
-  dist_mapper
-  SRCS dist_mapper.cc
-  DEPS device_mesh auto_parallel_proto phi_enforce)
+  DEPS dist_attr process_mesh dist_mapper auto_parallel_proto proto_desc
+       phi_enforce)

-cc_library(auto_parallel DEPS device_mesh process_mesh dist_attr dist_mapper)
 add_subdirectory(test)
260 changes: 8 additions & 252 deletions paddle/fluid/distributed/auto_parallel/dist_attr.cc
@@ -26,10 +26,9 @@ namespace paddle {
 namespace distributed {
 namespace auto_parallel {

-std::vector<std::string> TensorDistAttr::fields_{
-    "process_mesh", "dims_mapping", "batch_dim", "dynamic_dims"};
+using phi::distributed::auto_parallel::str_join;

-static inline std::vector<int64_t> get_tensor_shape(const VarDesc* tensor) {
+std::vector<int64_t> get_tensor_shape(const VarDesc* tensor) {
   if (tensor == nullptr) return std::vector<int64_t>();
   switch (tensor->GetType()) {
     case framework::proto::VarType::READER:
@@ -43,251 +42,6 @@ static inline std::vector<int64_t> get_tensor_shape(const VarDesc* tensor) {
   }
 }

-TensorDistAttr::TensorDistAttr(const VarDesc& tensor) {
-  VLOG(4) << "[TensorDistAttr constructor] tensor name: " << tensor.Name();
-  std::vector<int64_t> tensor_shape = get_tensor_shape(&tensor);
-  set_default_dims_mapping(tensor_shape);
-  set_default_dynamic_dims(tensor_shape);
-}
-
-TensorDistAttr::TensorDistAttr(const TensorDistAttr& dist_attr) {
-  copy_from(dist_attr);
-}
-
-TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
-  if (this == &dist_attr) return *this;
-  TensorDistAttr tmp(dist_attr);
-  std::swap(this->process_mesh_, tmp.process_mesh_);
-  std::swap(this->dims_mapping_, tmp.dims_mapping_);
-  std::swap(this->batch_dim_, tmp.batch_dim_);
-  std::swap(this->dynamic_dims_, tmp.dynamic_dims_);
-  std::swap(this->annotated_, tmp.annotated_);
-  return *this;
-}
-
-void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) {
-  set_process_mesh(dist_attr.process_mesh());
-  set_dims_mapping(dist_attr.dims_mapping());
-  set_batch_dim(dist_attr.batch_dim());
-  set_dynamic_dims(dist_attr.dynamic_dims());
-  set_annotated(dist_attr.annotated());
-}
-
-void TensorDistAttr::set_process_mesh(const ProcessMesh& process_mesh) {
-  process_mesh_ = process_mesh;
-}
-
-void TensorDistAttr::set_dims_mapping(
-    const std::vector<int64_t>& dims_mapping) {
-  dims_mapping_ = dims_mapping;
-}
-
-void TensorDistAttr::set_batch_dim(int64_t batch_dim) {
-  batch_dim_ = batch_dim;
-}
-
-void TensorDistAttr::set_dynamic_dims(const std::vector<bool>& dynamic_dims) {
-  dynamic_dims_ = dynamic_dims;
-}
-
-void TensorDistAttr::set_annotated(
-    const std::map<std::string, bool>& annotated) {
-  annotated_ = annotated;
-}
-
-void TensorDistAttr::set_default_dims_mapping(
-    const std::vector<int64_t>& tensor_shape) {
-  if (tensor_shape.size() != 0) {
-    dims_mapping_ = std::vector<int64_t>(tensor_shape.size(), -1);
-  }
-}
-
-void TensorDistAttr::set_default_dynamic_dims(
-    const std::vector<int64_t>& tensor_shape) {
-  if (tensor_shape.size() != 0) {
-    dynamic_dims_ = std::vector<bool>(tensor_shape.size(), false);
-  }
-}
-
-void TensorDistAttr::mark_annotated(const std::string& name) {
-  auto result = std::find(std::begin(fields_), std::end(fields_), name);
-  if (result != std::end(fields_)) {
-    annotated_[name] = true;
-  }
-}
-
-bool TensorDistAttr::verify_process_mesh(
-    const ProcessMesh& process_mesh) const {
-  VLOG(4) << "[TensorDistAttr verify_process_mesh] "
-          << process_mesh.to_string();
-  if (!process_mesh_.empty()) {
-    for (int64_t dim_mapping : dims_mapping_) {
-      if (dim_mapping >= process_mesh_.ndim()) {
-        return false;
-      }
-    }
-  }
-  return true;
-}
-
-bool TensorDistAttr::verify_dims_mapping(
-    const std::vector<int64_t>& dims_mapping,
-    const std::vector<int64_t>& tensor_shape) const {
-  VLOG(4) << "[TensorDistAttr verify_dims_mapping] " << str_join(dims_mapping);
-  if (dims_mapping.size() != tensor_shape.size()) {
-    return false;
-  }
-  std::unordered_map<int64_t, int64_t> map;
-  if (!process_mesh_.empty()) {
-    for (int64_t i : dims_mapping) {
-      if (i < -1 || i >= process_mesh_.ndim()) {
-        return false;
-      }
-      ++map[i];
-      if (i != -1 && map[i] > 1) {
-        return false;
-      }
-    }
-  } else {
-    for (int64_t i : dims_mapping) {
-      ++map[i];
-      if (i != -1 && map[i] > 1) {
-        return false;
-      }
-    }
-  }
-  return true;
-}
-
-bool TensorDistAttr::verify_batch_dim(
-    int64_t dim, const std::vector<int64_t>& tensor_shape) const {
-  VLOG(4) << "[TensorDistAttr verify_batch_dim] " << dim;
-  int64_t ndim = tensor_shape.size();
-  if (ndim > 0) {
-    if (dim < 0) {
-      dim = dim + ndim;
-    }
-    if (dim < 0 || dim >= ndim) {
-      return false;
-    }
-  }
-  return true;
-}
-
-bool TensorDistAttr::verify_dynamic_dims(
-    const std::vector<bool>& dynamic_dims,
-    const std::vector<int64_t>& tensor_shape) const {
-  VLOG(4) << "[TensorDistAttr verify_dynamic_dims] " << str_join(dynamic_dims);
-  if (dynamic_dims.size() > 0 && dynamic_dims.size() != tensor_shape.size()) {
-    return false;
-  }
-  return true;
-}
-
-bool TensorDistAttr::verify_annotated(
-    const std::map<std::string, bool>& annotated) const {
-  VLOG(4) << "[TensorDistAttr verify_annotated] " << str_join(annotated);
-  for (const auto& item : annotated) {
-    auto result = std::find(std::begin(fields_), std::end(fields_), item.first);
-    if (result == std::end(fields_)) {
-      return false;
-    }
-  }
-  return true;
-}
-
-bool TensorDistAttr::verify(const VarDesc* tensor) const {
-  auto tensor_shape = get_tensor_shape(tensor);
-  if (!verify_process_mesh(process_mesh_)) {
-    return false;
-  }
-  if (!verify_dims_mapping(dims_mapping_, tensor_shape)) {
-    return false;
-  }
-  if (!verify_batch_dim(batch_dim_, tensor_shape)) {
-    return false;
-  }
-  if (!verify_dynamic_dims(dynamic_dims_, tensor_shape)) {
-    return false;
-  }
-  if (!verify_annotated(annotated_)) {
-    return false;
-  }
-  return true;
-}
-
-std::string TensorDistAttr::to_string() const {
-  std::string dist_str;
-  dist_str += "{process_mesh: " + process_mesh_.to_string() + ", ";
-  dist_str += "dims_mappings: [" + str_join(dims_mapping_) + "], ";
-  dist_str += "batch_dim: " + std::to_string(batch_dim_) + ", ";
-  dist_str += "dynamic_dims: [" + str_join(dynamic_dims_) + "], ";
-  dist_str += "annotated: [" + str_join(annotated_) + "]}";
-  return dist_str;
-}
-
-void TensorDistAttr::from_proto(const TensorDistAttrProto& proto) {
-  process_mesh_ = ProcessMesh::from_proto(proto.process_mesh());
-  dims_mapping_.resize(proto.dims_mapping_size());
-  for (int64_t i = 0; i < proto.dims_mapping_size(); ++i) {
-    dims_mapping_[i] = proto.dims_mapping(i);
-  }
-  batch_dim_ = proto.batch_dim();
-  dynamic_dims_.resize(proto.dynamic_dims_size());
-  for (int64_t i = 0; i < proto.dynamic_dims_size(); ++i) {
-    dynamic_dims_[i] = proto.dynamic_dims(i);
-  }
-}
-
-TensorDistAttrProto TensorDistAttr::to_proto() const {
-  TensorDistAttrProto proto;
-  proto.mutable_process_mesh()->CopyFrom(process_mesh_.to_proto());
-  for (const auto& i : dims_mapping_) {
-    proto.add_dims_mapping(i);
-  }
-  proto.set_batch_dim(batch_dim_);
-  for (const auto& i : dynamic_dims_) {
-    proto.add_dynamic_dims(i);
-  }
-  return proto;
-}
-
-std::string TensorDistAttr::serialize_to_string() {
-  std::string data;
-  auto proto = to_proto();
-  proto.SerializeToString(&data);
-  PADDLE_ENFORCE_EQ(to_proto().SerializeToString(&data),
-                    true,
-                    platform::errors::InvalidArgument(
-                        "Failed to serialize tensor dist attr to string."));
-  return data;
-}
-
-void TensorDistAttr::parse_from_string(const std::string& data) {
-  TensorDistAttrProto proto;
-  PADDLE_ENFORCE_EQ(proto.ParseFromString(data),
-                    true,
-                    platform::errors::InvalidArgument(
-                        "Failed to parse tensor dist attr from string."));
-  from_proto(proto);
-}
-
-bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
-  if (lhs.process_mesh() != rhs.process_mesh()) {
-    return false;
-  }
-  if (lhs.dims_mapping() != rhs.dims_mapping()) {
-    return false;
-  }
-  if (lhs.batch_dim() != rhs.batch_dim()) {
-    return false;
-  }
-  if (lhs.dynamic_dims() != rhs.dynamic_dims()) {
-    return false;
-  }
-  return true;
-}
-
 std::vector<std::string> OperatorDistAttr::fields_{"process_mesh",
                                                    "impl_type",
                                                    "impl_idx",
@@ -335,7 +89,7 @@ void OperatorDistAttr::initialize(const OpDesc* op) {
     if (input == nullptr || op->Type() == "create_py_reader") {
       input_dist_attrs_[name] = TensorDistAttr();
     } else {
-      input_dist_attrs_[name] = TensorDistAttr(*input);
+      input_dist_attrs_[name] = TensorDistAttr(get_tensor_shape(input));
     }
   }
   for (std::string name : op->OutputArgumentNames()) {
@@ -344,7 +98,7 @@
     if (output == nullptr) {
       output_dist_attrs_[name] = TensorDistAttr();
     } else {
-      output_dist_attrs_[name] = TensorDistAttr(*output);
+      output_dist_attrs_[name] = TensorDistAttr(get_tensor_shape(output));
     }
   }
   op_type_ = op->Type();
@@ -465,7 +219,8 @@ bool OperatorDistAttr::verify_input_dist_attr(const std::string& name,
                                               const VarDesc* tensor) const {
   VLOG(4) << "[OperatorDistAttr verify_input_dist_attr] " << name << " "
           << dist_attr.to_string();
-  if (!dist_attr.verify(tensor)) {
+  auto tensor_shape = get_tensor_shape(tensor);
+  if (!dist_attr.verify(tensor_shape)) {
     return false;
   }
   if (tensor != nullptr) {
@@ -484,7 +239,8 @@ bool OperatorDistAttr::verify_output_dist_attr(const std::string& name,
                                                const VarDesc* tensor) const {
   VLOG(4) << "[OperatorDistAttr verify_output_dist_attr] " << name << " "
           << dist_attr.to_string();
-  if (!dist_attr.verify(tensor)) {
+  auto tensor_shape = get_tensor_shape(tensor);
+  if (!dist_attr.verify(tensor_shape)) {
     return false;
   }
   if (tensor != nullptr) {
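The hunks above all make the same change: TensorDistAttr moves from a VarDesc-based interface to a shape-based one, with get_tensor_shape(tensor) as the bridge, so the constructor now takes the shape vector and verify() checks against that same shape. A minimal hypothetical call site (illustrative names, not from this patch), assuming only the signatures visible in this diff:

  #include <cstdint>
  #include <vector>

  // Hypothetical caller; mirrors the pattern in OperatorDistAttr::initialize
  // and verify_input_dist_attr above. `tensor` may be null.
  void attach_and_check(const framework::VarDesc* tensor) {
    // get_tensor_shape returns an empty vector for nullptr, so a null tensor
    // degrades to a default-constructed TensorDistAttr.
    std::vector<int64_t> shape = get_tensor_shape(tensor);
    TensorDistAttr dist_attr(shape);  // was: TensorDistAttr(*tensor)
    if (!dist_attr.verify(shape)) {   // was: dist_attr.verify(tensor)
      // reject or repair the distributed attribute here
    }
  }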