[Auto Parallel] Update reshard #40865

Merged (2 commits) on Mar 28, 2022
python/paddle/distributed/auto_parallel/__init__.py (1 addition, 1 deletion)

@@ -15,7 +15,7 @@
from .interface import shard_tensor # noqa: F401
from .interface import shard_op # noqa: F401
from .process_mesh import ProcessMesh
from .reshard import reshard # noqa: F401
from .reshard import Resharder # noqa: F401
from .cost_model import estimate_cost

__all__ = []
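
For downstream code, the effect of this hunk is that the package's reshard entry point is now the Resharder class rather than the reshard function. A minimal sketch of the migration, assuming the absolute module path paddle.distributed.auto_parallel.reshard; the call signature is taken from the engine.py and parallelizer.py hunks later in this diff, and apply_reshard is an illustrative helper, not part of the PR:

from paddle.distributed.auto_parallel.reshard import Resharder

def apply_reshard(dist_main_prog, dist_startup_prog, rank, dist_context,
                  dist_params_grads):
    # Old API (removed by this PR):
    #   reshard(dist_main_prog, dist_startup_prog, rank, dist_context,
    #           dist_params_grads)
    # New API: construct a Resharder, then call its reshard() method.
    resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
                          dist_context, dist_params_grads)
    resharder.reshard()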
python/paddle/distributed/auto_parallel/converter.py (10 additions, 10 deletions)

@@ -235,19 +235,19 @@ def merge_and_slice(tensor_list, pre_dist_attr, cur_dist_attr):
@staticmethod
def merge_with_dist_attr(tensor_list, dist_attr):
""" Merge tensor with distributed attribute """
from .reshard import _compute_complete_shape, _compute_partition_index
from .reshard import Resharder

dims_mapping = dist_attr["dims_mapping"]
process_shape = dist_attr["process_shape"]
process_group = dist_attr["process_group"]
# get the complete shape of the tensor
complete_shape = _compute_complete_shape(tensor_list[0].shape,
process_shape, dims_mapping)
complete_shape = Resharder.compute_complete_shape(
tensor_list[0].shape, process_shape, dims_mapping)
# merge the tensor with dist_attr
partition_tensor_list = []
merged_partiton = []
for process in process_group:
partition_index = _compute_partition_index(
partition_index = Resharder.compute_partition_index(
process, complete_shape, dims_mapping, process_shape,
process_group)
index = process_group.index(process)
@@ -302,7 +302,7 @@ def merge(partition_tensor_list, tensor, partition_index, complete_shape):
_merge_tensor(partition_tensor_list, tensor, partition_index)
# partition_tensor_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])]
"""
from .reshard import _compute_concat_info
from .reshard import Resharder

if len(partition_tensor_list) == 1:
is_complete_data = True
@@ -318,7 +318,7 @@ def merge(partition_tensor_list, tensor, partition_index, complete_shape):
else:
i = 0
while i < len(partition_tensor_list):
concat_axis, first_order, new_partition = _compute_concat_info(
concat_axis, first_order, new_partition = Resharder.compute_concat_info(
partition_tensor_list[i][1], partition_index)
if concat_axis != -1:
if first_order == 0:
@@ -391,11 +391,11 @@ def _get_split_indices(complete_shape, dims_mapping, process_shape,
index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group)
# index: [[], [], [2, 4]]
"""
from .reshard import _compute_partition_index
from .reshard import Resharder

split_indices_list = []
for process in process_group:
partition_index = _compute_partition_index(
partition_index = Resharder.compute_partition_index(
process, complete_shape, dims_mapping, process_shape,
process_group)
if split_indices_list:
@@ -437,9 +437,9 @@ def _get_sliced_index(rank_id, complete_shape, dims_mapping, process_shape,
process_shape, process_group)
# index: 2
"""
from .reshard import _compute_partition_index
from .reshard import Resharder

partition_index = _compute_partition_index(
partition_index = Resharder.compute_partition_index(
rank_id, complete_shape, dims_mapping, process_shape, process_group)
sliced_index = 0
for i, shape in enumerate(complete_shape):
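
The helpers that converter.py used to import as private functions from reshard.py now live on the Resharder class and are looked up directly on the class, as in the hunks above. A hedged sketch of that call pattern: the argument order matches the call sites above, while the concrete values are illustrative and not taken from this PR (compute_concat_info, formerly _compute_concat_info, is migrated the same way):

from paddle.distributed.auto_parallel.reshard import Resharder

# Illustrative sharding description: a [4, 4] tensor split along dim 0
# across a 1-D process mesh of two ranks.
complete_shape = [4, 4]
dims_mapping = [0, -1]
process_shape = [2]
process_group = [0, 1]

# Former _compute_partition_index: per-dimension index range owned by rank 0.
partition_index = Resharder.compute_partition_index(
    0, complete_shape, dims_mapping, process_shape, process_group)

# Former _compute_complete_shape: recover the full shape from a local shard's shape.
full_shape = Resharder.compute_complete_shape(
    [2, 4], process_shape, dims_mapping)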
python/paddle/distributed/auto_parallel/engine.py (7 additions, 5 deletions)

@@ -32,7 +32,7 @@

from .mapper import mapping
from .cluster import Cluster
from .reshard import reshard
from .reshard import Resharder
from .planner import Planner
from .completion import Completer
from .partitioner import Partitioner
@@ -187,8 +187,9 @@ def _parallel_program(self, mode, rank):
# Do reshard process
set_grad_var_shape(dist_main_prog, dist_context)
make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
reshard(dist_main_prog, dist_startup_prog, rank, dist_context,
dist_params_grads)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
dist_context, dist_params_grads)
resharder.reshard()
# Apply post optimization passes
self._apply_post_optimization(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
@@ -199,8 +200,9 @@ def _parallel_program(self, mode, rank):
serial_main_program, serial_startup_program, [])
# Do reshard process
make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
reshard(dist_main_prog, dist_startup_prog, rank, dist_context, [],
1)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
dist_context, [], 1)
resharder.reshard()

# clone program for test
if mode != 'train':
python/paddle/distributed/auto_parallel/parallelizer.py (4 additions, 6 deletions)

@@ -42,7 +42,7 @@
from .utils import set_grad_var_shape
from .utils import print_program_with_dist_attr
from .utils import SerialProgramInfo
from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
from .reshard import Resharder
from .cluster import Cluster
from .mapper import mapping
from .dist_op import DistributedOperator
@@ -213,17 +213,15 @@ def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):

make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)

reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context,
dist_params_grads)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, dist_params_grads)
resharder.reshard()

self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
g_process_group_map = None
if not relaunch_phase:
g_process_group_map = copy.deepcopy(_g_process_group_map)
HAS_SENT.clear()
HAS_RECV.clear()
HAS_ALLGATHER.clear()
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
for process_mesh in dist_context._process_meshes:
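
After this hunk, the only caller-side state reset left on the relaunch path is the process-group map; the reshard-specific module-level caches (HAS_SENT, HAS_RECV, HAS_ALLGATHER) are no longer imported or cleared here. A hedged sketch of the remaining cleanup, with names taken from the hunk above; the import path for the process-group helpers, and the assumption that the reshard bookkeeping now lives inside Resharder, are mine and not stated in this diff:

import copy

# Assumed import location; ProcessGroup and _g_process_group_map appear in the
# hunk above, but their module path is not shown in this diff.
from paddle.distributed.auto_parallel.process_group import (
    ProcessGroup, _g_process_group_map)

def reset_process_groups(dist_context):
    # Snapshot the current process-group map, then rebuild it from the process
    # meshes. The old HAS_SENT / HAS_RECV / HAS_ALLGATHER clears are gone from
    # this path after the PR.
    g_process_group_map = copy.deepcopy(_g_process_group_map)
    _g_process_group_map.clear()
    _g_process_group_map[0] = ProcessGroup(0, [])
    for process_mesh in dist_context._process_meshes:
        ...  # loop body is collapsed in this diff view and left elided here too
    return g_process_group_map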