[AutoParallel] adapt for clip #49249

Merged · 7 commits · Dec 28, 2022
21 changes: 21 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -39,6 +39,7 @@
)

__op_not_need_param_init__ = ["while", "cond"]
__op_has_shape_attr__ = ["fill_constant_batch_size_like", "fill_constant"]


def prim_operator_data_parallel_functor(ctx, src_op):
@@ -476,6 +477,26 @@ def forward(ctx, *args, **kwargs):
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])

if (
src_op.has_attr('shape')
and src_op.attr('shape')
and src_op.type in __op_has_shape_attr__
):
shape_list = src_op.attr('shape')
Out_var = main_block._var_recursive(kwargs['Out'][0])
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.process_mesh.shape
assert len(shape_list) == len(dim_mapping)
# modify target shape
for idx, axis in enumerate(dim_mapping):
if axis >= 0:
if len(shape_list) > idx:
shape_list[idx] = (
shape_list[idx] // process_mesh_shape[axis]
)
dist_op_desc._set_attr('shape', shape_list)

# data parallel synchronization for primitive operators
from paddle.incubate.autograd import prim_enabled

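The new branch above shards the literal shape attribute of fill_constant / fill_constant_batch_size_like inside the generic default rule, which is why the op-specific patch in the next hunk is removed. A minimal, self-contained sketch of that partitioning rule (the helper name and the example shapes are illustrative, not from the PR):

def shard_shape(shape, dims_mapping, process_mesh_shape):
    # An entry >= 0 in dims_mapping names the mesh axis that shards that dim;
    # -1 means the dim is replicated and keeps its global size.
    local = list(shape)
    for idx, mesh_axis in enumerate(dims_mapping):
        if mesh_axis >= 0:
            local[idx] //= process_mesh_shape[mesh_axis]
    return local

# A [8, 1024] fill_constant output whose dim 0 is sharded over a 2-way
# data-parallel mesh must be created as [4, 1024] on each rank.
assert shard_shape([8, 1024], [0, -1], [2]) == [4, 1024]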
python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py
@@ -129,24 +129,6 @@ def forward(ctx, *args, **kwargs):
kwargs: inputname_mapping & outputname_mapping
"""
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
dist_op_context = ctx.dist_op_context
src_op = dist_op_context.cur_src_op
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
main_block = dist_op_context.work_block
op = main_block.ops[-1]
assert op.type == "fill_constant_batch_size_like"

# modify shape attr according to how output are partitioned
out_name = op.output('Out')[0]
dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
process_mesh_shape = op_dist_attr.process_mesh.shape
shape_list = op.attr("shape")
# modify target shape
for idx, axis in enumerate(dims_mapping):
if axis >= 0:
shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]

op._set_attr("shape", shape_list)

@staticmethod
def backward(ctx, *args, **kwargs):
51 changes: 40 additions & 11 deletions python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
@@ -47,8 +47,23 @@ def __init__(self, op_type):
register_distributed_operator_impl_container(DistributedPNorm("p_norm"))


# Row Parallel
class DistributedPNormImpl(DistributedOperatorImpl):
# Data Parallel
class DistributedPNormImpl0(DistributedOperatorImpl):
"""
TODO: p_norm scenarios

1. axis == None, isinstance(p, (int, float)), asvector = True
1.1 x_dims_mapping == [0, -1, -1]
allgather input if it is split by the dp group
1.2 x_dims_mapping == [-1, 0, -1]
allgather, split and concat input if it is split by the mp group
2. isinstance(axis, int), asvector = False
2.1 axis == 0 and x_dims_mapping == [0, -1, -1]
allgather input if its dim 0 is split by the dp group
2.2 axis == 1 and x_dims_mapping == [-1, 0, -1]
allgather, split and concat input if its dim 1 is split by the mp group
"""

def __init__(self, name):
super().__init__(name)
self._forward_implemented = True
@@ -57,6 +72,8 @@ def is_input_compatible(self, dist_op):
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
axis = op_desc.attr('axis')
asvector = op_desc.attr('asvector')
x_name = op_desc.input('X')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
if is_dim_replicate(x_dims_mapping[0]):
Expand All @@ -65,6 +82,8 @@ def is_input_compatible(self, dist_op):
for mapping in x_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
if not (axis == -1 and asvector) and not (axis == 0 and not asvector):
return False
return True

def is_output_compatible(self, dist_op):
@@ -90,6 +109,8 @@ def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
axis = op_desc.attr('axis')
keepdim = op_desc.attr('keepdim')

batch_dim_mappings = []
for arg_name in op_desc.input_arg_names():
@@ -115,14 +136,22 @@
):
dims_mapping[0] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if (
len(dims_mapping) >= 1
and compatible_dim_mapping != dims_mapping[0]
):
dims_mapping[0] = compatible_dim_mapping
changed = True

if axis == 0 and not keepdim:
for arg_name in op_desc.output_arg_names():
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if len(dims_mapping) >= 1 and dims_mapping[0] != -1:
dims_mapping[0] = -1
changed = True
else:
for arg_name in op_desc.output_arg_names():
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if (
len(dims_mapping) >= 1
and compatible_dim_mapping != dims_mapping[0]
):
dims_mapping[0] = compatible_dim_mapping
changed = True

return changed
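
A minimal sketch of what the new axis == 0, keepdim == False branch enforces (the helper name and example mapping are illustrative, not from the PR): the reduced batch axis is consumed by p_norm, so the output's leading dims mapping is forced back to -1 instead of inheriting the input's batch sharding.

def fix_output_mapping_for_axis0_reduce(out_dims_mapping):
    # Mirror of the new branch: p_norm(axis=0, keepdim=False) reduces the
    # sharded batch axis, so the output's leading dim must be replicated.
    out = list(out_dims_mapping)
    if len(out) >= 1 and out[0] != -1:
        out[0] = -1
    return out

assert fix_output_mapping_for_axis0_reduce([0, -1]) == [-1, -1]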

@@ -350,5 +379,5 @@ def backward(ctx, *args, **kwargs):


register_distributed_operator_impl(
"p_norm", DistributedPNormImpl("row_parallel")
"p_norm", DistributedPNormImpl0("data_parallel")
)
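
For intuition, a small NumPy sketch (shapes and mesh size are illustrative) of why the data-parallel implementation gathers the dim-0 shards before computing the norm: after the c_allgather each rank sees the full tensor, so the p_norm result matches the serial program, and the backward pass only needs to slice the full gradient back down to the local shard.

import numpy as np

np.random.seed(0)
full = np.random.rand(4, 5, 6).astype("float32")
shards = [full[0:2], full[2:4]]            # what rank 0 / rank 1 hold under dp2

gathered = np.concatenate(shards, axis=0)  # what c_allgather reconstructs
assert np.allclose(np.linalg.norm(gathered), np.linalg.norm(full))

# backward: each rank keeps only its own slice of the full gradient
full_grad = np.ones_like(gathered)
local_grad = full_grad[0:2]                # rank 0's slice
assert local_grad.shape == shards[0].shape
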
4 changes: 4 additions & 0 deletions python/paddle/distributed/auto_parallel/utils.py
@@ -1261,6 +1261,8 @@ def set_grad_var_shape(program, dist_context):
"fused_softmax_mask_upper_triangle_grad",
"flatten_contiguous_range_grad",
"relu_grad",
"exp_grad",
"sigmoid_grad",
]
forward_list = [
"reshape2",
@@ -1279,6 +1281,8 @@
"fused_softmax_mask_upper_triangle",
"flatten_contiguous_range",
"relu",
"exp",
"sigmoid",
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
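
exp and sigmoid are added here because, like the other entries in these lists, their gradients are elementwise and therefore have exactly the shape of the corresponding forward tensor, which is the property set_grad_var_shape relies on. A tiny NumPy check of that property for sigmoid (illustrative only, not PR code):

import numpy as np

x = np.random.rand(4, 6).astype("float32")  # forward input on one rank
y = 1.0 / (1.0 + np.exp(-x))                # sigmoid forward
dy = np.ones_like(y)                        # upstream gradient
dx = dy * y * (1.0 - y)                     # sigmoid backward
assert dx.shape == x.shape                  # grad shape == forward var shape
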
134 changes: 92 additions & 42 deletions python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py
@@ -22,7 +22,7 @@
paddle.enable_static()


def make_program_dp2():
def make_program_dp2_axis_None():
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
@@ -35,6 +35,32 @@ def make_program_dp2():
return main_program, start_program, tmp_0


def make_program_dp2_axis_0():
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False
auto.shard_tensor(
x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
)
tmp_0 = paddle.norm(x, p=2, axis=0)
return main_program, start_program, tmp_0


def make_program_dp2_axis_1():
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False
auto.shard_tensor(
x, auto.ProcessMesh([0, 1], dim_names=["x"]), ["x", None, None]
)
tmp_0 = paddle.norm(x, p=2, axis=1)
return main_program, start_program, tmp_0


def make_program_serial():
main_program = paddle.fluid.Program()
start_program = paddle.fluid.Program()
@@ -76,47 +102,71 @@ def parallelizer(program_func, rank):


class TestDistPNorm(unittest.TestCase):
def test_dist_pnorm_dp2(self):

for rank in range(2):
dist_main_prog, dist_context = parallelizer(make_program_dp2, rank)
ops = dist_main_prog.global_block().ops
op_types = []
for op in ops:
op_types.append(op.type)
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if op.type == "p_norm":
assert op_dist_attr.impl_type == "p_norm"
if op.type in ["p_norm", "p_norm_grad"]:
for input_attr in op_dist_attr.inputs_dist_attrs.values():
assert set(input_attr.dims_mapping) == set([-1])
for output_attr in op_dist_attr.outputs_dist_attrs.values():
assert set(output_attr.dims_mapping) == set([-1])
if op.type == 'c_allgather':
for input_attr in op_dist_attr.inputs_dist_attrs.values():
assert input_attr.dims_mapping[0] == 0
assert set(input_attr.dims_mapping[1:]) == set([-1])
for output_attr in op_dist_attr.outputs_dist_attrs.values():
assert set(output_attr.dims_mapping) == set([-1])
if op.type == 'slice':
for input_attr in op_dist_attr.inputs_dist_attrs.values():
assert set(input_attr.dims_mapping) == set([-1])
for output_attr in op_dist_attr.outputs_dist_attrs.values():
assert output_attr.dims_mapping[0] == 0
assert set(output_attr.dims_mapping[1:]) == set([-1])
assert op_types == [
"c_allgather",
"p_norm",
"fill_constant",
"p_norm_grad",
"slice",
]

def test_dist_pnorm_serial(self):
dist_main_prog, dist_context = parallelizer(make_program_serial, 0)
ops = dist_main_prog.global_block().ops
for op in ops:
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
def prepare(self, func):
self.dist_main_prog, self.dist_context = parallelizer(func, 0)
self.ops = self.dist_main_prog.global_block().ops

def test_dist_pnorm(self):
pass


class TestDistPNormDP(TestDistPNorm):
def test_dist_pnorm(self):
self.prepare(make_program_dp2_axis_None)
self.check_program()

def check_program(self):
op_types = []
for op in self.ops:
op_types.append(op.type)
op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
if op.type == "p_norm":
assert op_dist_attr.impl_type == "p_norm"
if op.type in ["p_norm", "p_norm_grad"]:
for input_attr in op_dist_attr.inputs_dist_attrs.values():
assert set(input_attr.dims_mapping) == set([-1])
for output_attr in op_dist_attr.outputs_dist_attrs.values():
assert set(output_attr.dims_mapping) == set([-1])
if op.type == 'c_allgather':
for input_attr in op_dist_attr.inputs_dist_attrs.values():
assert input_attr.dims_mapping[0] == 0
assert set(input_attr.dims_mapping[1:]) == set([-1])
for output_attr in op_dist_attr.outputs_dist_attrs.values():
assert set(output_attr.dims_mapping) == set([-1])
if op.type == 'slice':
for input_attr in op_dist_attr.inputs_dist_attrs.values():
assert set(input_attr.dims_mapping) == set([-1])
for output_attr in op_dist_attr.outputs_dist_attrs.values():
assert output_attr.dims_mapping[0] == 0
assert set(output_attr.dims_mapping[1:]) == set([-1])
assert op_types == [
"c_allgather",
"p_norm",
"fill_constant",
"p_norm_grad",
"slice",
]


class TestDistPNormDP1(TestDistPNormDP):
def test_dist_pnorm(self):
self.prepare(make_program_dp2_axis_0)
self.check_program()


class TestDistPNormSerial(TestDistPNorm):
def test_dist_pnorm(self):
self.prepare(make_program_serial)
for op in self.ops:
op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == "default"


class TestDistPNormDPAxis1(TestDistPNorm):
def test_dist_pnorm(self):
self.prepare(make_program_dp2_axis_1)
for op in self.ops:
op_dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr.impl_type == "default"


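The split of the old make_program_dp2 test into axis-specific variants mirrors the compatibility check added in dist_pnorm.py: the data-parallel rule only claims the axis == -1 with asvector == True and axis == 0 with asvector == False cases, so the axis == 1 program above is expected to fall back to the default implementation. A sketch of that predicate (function name and comments are illustrative):

def dp_rule_claims(axis, asvector):
    # Mirrors the new check in DistributedPNormImpl0.is_input_compatible.
    return (axis == -1 and asvector) or (axis == 0 and not asvector)

assert dp_rule_claims(-1, True)        # make_program_dp2_axis_None
assert dp_rule_claims(0, False)        # make_program_dp2_axis_0
assert not dp_rule_claims(1, False)    # make_program_dp2_axis_1 -> default impl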