[Auto Parallel] Completion Dist Attribute for Backward & Update stage #36744

Merged
Changes from 4 commits
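In outline, the backward completion added here walks every op after the first backward op (the fill_constant that seeds the loss gradient) and derives its distributed attributes from the matching forward op recorded in dist_op_helper.gradopidx2opidx; the grad vars it produces then inherit those attributes as well. The snippet below is a deliberately simplified, self-contained sketch of that propagation rule, using toy stand-ins (DistAttr, complete_backward) rather than the real dist_context APIs:

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class DistAttr:
    # Simplified stand-in for the Operator/Tensor DistributedAttribute classes.
    process_mesh: List[int]
    dims_mapping: Dict[str, List[int]] = field(default_factory=dict)


def complete_backward(forward_attrs: Dict[str, DistAttr],
                      grad_to_forward: Dict[str, str]) -> Dict[str, DistAttr]:
    # For each grad op, reuse the process mesh and dims mappings of its
    # forward op -- the same propagation rule the completion pass applies.
    grad_attrs = {}
    for grad_op, fwd_op in grad_to_forward.items():
        fwd = forward_attrs[fwd_op]
        grad_attrs[grad_op] = DistAttr(
            process_mesh=list(fwd.process_mesh),
            dims_mapping={name + "@GRAD": mapping
                          for name, mapping in fwd.dims_mapping.items()})
    return grad_attrs


# Toy usage: one matmul whose tensors are row-sharded over a 2-way mesh.
fwd = {"matmul": DistAttr([0, 1], {"X": [0, -1], "Out": [0, -1]})}
print(complete_backward(fwd, {"matmul_grad": "matmul"}))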
230 changes: 146 additions & 84 deletions python/paddle/distributed/auto_parallel/completion.py
100644 → 100755
@@ -623,24 +623,35 @@ def _get_op_by_id(ops, id):
if dist_context is None:
dist_context = get_default_distributed_context()

grad_start_idx = -1
first_backward_op_idx = -1
for idx, op in enumerate(auto_parallel_main_prog.global_block().ops):
if int(op.attr('op_role')) == int(
int(core.op_proto_and_checker_maker.OpRole.Backward) | int(
core.op_proto_and_checker_maker.OpRole.Loss)):
assert op.type == "fill_constant"
grad_start_idx = idx
first_backward_op_idx = idx
break

assert grad_start_idx >= 0, "No backward procedure found in this program."
assert first_backward_op_idx >= 0, "No backward procedure found in this program."

ops = list(auto_parallel_main_prog.global_block().ops)
vars = auto_parallel_main_prog.global_block().vars
dist_op_helper = dist_context.get_dist_op_helper()

for idx in range(grad_start_idx, len(ops)):
for idx in range(first_backward_op_idx, len(ops)):

# complete the initial loss grad op
if idx == grad_start_idx:
if idx == first_backward_op_idx:
assert ops[idx].type == "fill_constant"
assert len(
ops[idx].input_arg_names
) == 0, "first backward op should have NO input, but got [{}]".format(
len(ops[idx].input_arg_names))
assert len(
ops[idx].output_arg_names
) == 1, "first backward op should have only ONE output, but got [{}]".format(
len(ops[idx].output_arg_names))

grad_var = vars[ops[idx].output_arg_names[0]]
forward_var_name = _get_forward_varname_from_grad_varname(
grad_var.name)
@@ -659,48 +670,17 @@ def _get_op_by_id(ops, id):

op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
op_attr.set_process_mesh(process_mesh)
dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr)
continue

# TODO remove this when dist op handles its own grad scale
# in data parallel mode, the loss op is followed by a scale op.
if ops[idx].type == "scale" and idx == grad_start_idx + 1:
assert grad_var.name in ops[
idx].input_arg_names and grad_var.name in ops[
idx].output_arg_names
grad_var = vars[ops[idx].output_arg_names[0]]
forward_var_name = _get_forward_varname_from_grad_varname(
grad_var.name)
forward_var = vars[forward_var_name]
process_mesh = dist_context.get_tensor_distributed_attr_for_program(
forward_var).get_process_mesh()
op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
op_attr.set_process_mesh(process_mesh)
dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr)
continue

# TODO remove this when dist op handles its own communication
# TODO should distinguish dp allreduce from mp allreduce
# complete the c_allreduce_sum op for gradients in data parallel mode.
if ops[idx].type == "c_allreduce_sum" and ops[
idx].input_arg_names == ops[idx].output_arg_names:
grad_var = vars[ops[idx].output_arg_names[0]]
op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
process_mesh = dist_context.get_tensor_distributed_attr_for_program(
grad_var).get_process_mesh()
op_attr.set_process_mesh(process_mesh)
op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr)
continue

# complete the annotation of grad op (xxx_grad op or sum op)
grad_op = ops[idx]

# xxx_grad op will have a corresponding forward op in gradopidx2opidx
dist_op_helper = dist_context.get_dist_op_helper()
grad_op = ops[idx]
if grad_op.desc.id() in dist_op_helper.gradopidx2opidx:
# TODO support the case where one forward op corresponds to multiple xxx_grad ops
forward_op = _get_op_by_id(
ops[:grad_start_idx],
ops[:first_backward_op_idx],
dist_op_helper.gradopidx2opidx[grad_op.desc.id()])
assert forward_op is not None

@@ -710,39 +690,60 @@ def _get_op_by_id(ops, id):
grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context)
grad_op_attr.set_process_mesh(forward_op_attr.get_process_mesh())

for var_name in grad_op.input_arg_names:
if "@GRAD" in var_name:
dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
vars[var_name]).get_dims_mapping()
grad_op_attr.set_input_dims_mapping(var_name, dims_mapping)
# var
for output_name in grad_op.desc.output_names():
assert len(grad_op.desc.output(output_name)) in [0, 1]
# if grad_op.type == "cast":
Reviewer (Contributor): This is unused; it can be removed.
# input_name = "X"
# else:
if _is_grad_var_name(output_name):
input_name = _get_forward_varname_from_grad_varname(
output_name)
else:
dims_mapping = forward_op_attr.get_input_dims_mapping(
var_name)
# TODO fixed here
if dims_mapping == None:
dims_mapping = forward_op_attr.get_output_dims_mapping(
var_name)
assert dims_mapping is not None, "[{}]'s dims_mapping is None".format(
var_name)
grad_op_attr.set_input_dims_mapping(var_name, dims_mapping)
assert grad_op.type in [
"cast", "c_identity", "c_allreduce_sum"
]
input_name = "X"
assert input_name in forward_op.desc.input_names(
), "var [{}] in op [{}]'s output but could not find [{}] in its forward op".format(
output_name, grad_op.type, input_name)
if len(grad_op.desc.output(output_name)) == 1:
assert len(forward_op.desc.input(input_name)) == 1
input_var = vars[forward_op.desc.input(input_name)[0]]
input_var_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
input_var)
assert input_var_dist_attr is not None, "[{}] has no dist attribute".format(
input_var.name)
ref_dims_mapping = input_var_dist_attr.get_dims_mapping()
ref_process_mesh = input_var_dist_attr.get_process_mesh()

# tensor dist attr
output_var = vars[grad_op.desc.output(output_name)[0]]
output_var_attr = TensorDistributedAttribute(output_var,
dist_context)
output_var_attr.set_dims_mapping(ref_dims_mapping)
output_var_attr.set_process_mesh(ref_process_mesh)
Reviewer (Contributor): Shouldn't the grad var's process mesh be the same as the grad op's, rather than the forward var's?

Author (Contributor): Yes, this was a bug; it has been fixed.
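# A hedged reading of the review thread above (not necessarily the exact
# committed fix): the grad var should take its process mesh from the grad op
# rather than from the forward input var, e.g.
#   output_var_attr.set_process_mesh(grad_op_attr.get_process_mesh())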

dist_context.set_tensor_distributed_attr_for_program(
output_var, output_var_attr)

# op dist attr
grad_op_attr.set_output_dims_mapping(output_var.name,
ref_dims_mapping)

for input_name in grad_op.input_arg_names:
input_var = vars[input_name]
input_var_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
input_var)
assert input_var_dist_attr is not None, "[{}] has no dist attribute".format(
input_var.name)
ref_dims_mapping = input_var_dist_attr.get_dims_mapping()
assert ref_dims_mapping is not None, "[{}]'s dims mapping is None".format(
input_var.name)
grad_op_attr.set_input_dims_mapping(input_name,
ref_dims_mapping)

dist_context.set_op_distributed_attr_for_program(grad_op,
grad_op_attr)
# var dist attr
for var_name in grad_op.output_arg_names:
if _is_grad_var_name(var_name):

forward_var_name = _get_forward_varname_from_grad_varname(
var_name)
forward_var = vars[forward_var_name]
tensor_attr = TensorDistributedAttribute(vars[var_name],
dist_context)
process_mesh = grad_op_attr.get_process_mesh()
dims_mapping = grad_op_attr.get_input_dims_mapping(
forward_var_name)
tensor_attr.set_process_mesh(process_mesh)
tensor_attr.set_dims_mapping(dims_mapping)
dist_context.set_tensor_distributed_attr_for_program(
vars[var_name], tensor_attr)

# only the sum op that merges multiple versions of a grad var has no corresponding mapping in gradopidx2opidx
else:
@@ -775,6 +776,9 @@ def _get_op_by_id(ops, id):
var_name) == ref_forward_var_name
grad_op_attr.set_input_dims_mapping(
var_name, ref_forward_var_dims_mapping)

grad_op_attr.set_output_dims_mapping(grad_op.output_arg_names[0],
ref_forward_var_dims_mapping)
dist_context.set_op_distributed_attr_for_program(grad_op,
grad_op_attr)

@@ -787,28 +791,86 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):

ops = list(auto_parallel_main_prog.global_block().ops)
vars = auto_parallel_main_prog.global_block().vars
learning_rate_completed = False

for idx in range(len(ops)):

# complete the annotation of the optimizer op.
# TODO add attributes for moment vars
if int(ops[idx].attr('op_role')) == int(OpRole.Optimize):
if "Grad" in ops[idx].input_names and "Param" in ops[
idx].input_names:
assert len(ops[idx].input(
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Optimize):

if "Grad" in op.input_names and "Param" in op.input_names:
assert len(op.input(
"Param")) == 1, "Only support one-to-one now."
assert len(ops[idx].input(
assert len(op.input(
"Grad")) == 1, "Only support one-to-one now."
param = vars[ops[idx].input("Param")[0]]
grad_var = vars[ops[idx].input("Grad")[0]]
process_mesh = dist_context.get_tensor_distributed_attr_for_program(
param = vars[op.input("Param")[0]]
grad_var = vars[op.input("Grad")[0]]

param_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
param)
grad_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
grad_var)

assert param_dist_attr is not None
assert grad_dist_attr is not None
assert param_dist_attr.get_dims_mapping(
) == grad_dist_attr.get_dims_mapping()

ref_process_mesh = dist_context.get_tensor_distributed_attr_for_program(
param).get_process_mesh()
dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
assert ref_process_mesh is not None
ref_dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
param).get_dims_mapping()
op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
op_attr.set_process_mesh(process_mesh)
op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
op_attr.set_input_dims_mapping(param.name, dims_mapping)
dist_context.set_op_distributed_attr_for_program(ops[idx],
op_attr)
assert ref_dims_mapping is not None
op_attr = OperatorDistributedAttribute(op, dist_context)
op_attr.set_process_mesh(ref_process_mesh)
op_attr.set_input_dims_mapping(grad_var.name, ref_dims_mapping)
op_attr.set_input_dims_mapping(param.name, ref_dims_mapping)
op_attr.set_output_dims_mapping(param.name, ref_dims_mapping)
learning_var = vars[op.input("LearningRate")[0]]
op_attr.set_input_dims_mapping(learning_var.name, [-1])
op_attr.set_output_dims_mapping(learning_var.name, [-1])

if not learning_rate_completed:
learning_rate_completed = True
var_dist_attr = TensorDistributedAttribute(learning_var,
dist_context)
var_dist_attr.set_process_mesh(ref_process_mesh)
var_dist_attr.set_dims_mapping([-1])
dist_context.set_tensor_distributed_attr_for_program(
learning_var, var_dist_attr)

for input_name in op.desc.input_names():

if input_name in [
'Param', 'Grad', 'LearningRate', "SkipUpdate",
"Beta1Tensor", "Beta2Tensor", "EpsilonTensor",
"MasterParam"
]:
continue

assert len(op.desc.input(input_name)) == 1
input_var = vars[op.desc.input(input_name)[0]]
input_var_attr = TensorDistributedAttribute(input_var,
dist_context)

if "Beta1Pow" in input_name or "Beta2Pow" in input_name:
input_var_attr.set_dims_mapping([-1])
op_attr.set_input_dims_mapping(input_var.name, [-1])
op_attr.set_output_dims_mapping(input_var.name, [-1])
else:
assert "Moment" in input_name
input_var_attr.set_dims_mapping(ref_dims_mapping)
op_attr.set_input_dims_mapping(input_var.name,
ref_dims_mapping)
op_attr.set_output_dims_mapping(input_var.name,
ref_dims_mapping)

input_var_attr.set_process_mesh(ref_process_mesh)
dist_context.set_tensor_distributed_attr_for_program(
input_var, input_var_attr)

dist_context.set_op_distributed_attr_for_program(op, op_attr)
continue
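For the update stage, the rule implemented above is: inputs shaped like the parameter (Param, Grad, the Moment accumulators) inherit the parameter's dims mapping and process mesh, while scalar auxiliaries (LearningRate, Beta1Pow, Beta2Pow, and similar) get a replicated [-1] mapping. Below is a hedged, self-contained sketch of that rule with hypothetical helper names, not the Paddle API:

from typing import Dict, List

REPLICATED = [-1]  # dims mapping used for tensors that are not sharded


def update_stage_mappings(param_dims_mapping: List[int],
                          input_slots: List[str]) -> Dict[str, List[int]]:
    # Assign a dims mapping to every input slot of an optimizer op:
    # parameter-shaped slots follow the parameter, scalar slots are replicated.
    scalar_slots = {"LearningRate", "Beta1Pow", "Beta2Pow",
                    "Beta1Tensor", "Beta2Tensor", "EpsilonTensor", "SkipUpdate"}
    mappings = {}
    for slot in input_slots:
        if slot in scalar_slots:
            mappings[slot] = list(REPLICATED)
        else:  # Param, Grad, Moment1, Moment2, ...
            mappings[slot] = list(param_dims_mapping)
    return mappings


# Example: a parameter sharded along its first dimension on mesh axis 0.
print(update_stage_mappings([0, -1],
                            ["Param", "Grad", "Moment1", "Moment2",
                             "LearningRate", "Beta1Pow", "Beta2Pow"]))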
@@ -55,6 +55,35 @@ def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit):
return True


def is_valid_completed_program(dist_context, program):

# TODO (ZJ-LIANG) should check all blocks
ops = program.global_block().ops
vars_ = program.list_vars()
for op in ops:
op_dist_attrs = dist_context.get_op_distributed_attr_for_program(op)
if op_dist_attrs is None:
return False

if op_dist_attrs.get_process_mesh() is None:
return False

if None in op_dist_attrs._dims_mapping.values():
return False

for var in vars_:
var_dist_attrs = dist_context.get_tensor_distributed_attr_for_program(
var)
if var_dist_attrs is None:
return False
elif var_dist_attrs.get_process_mesh() is None:
return False
elif var_dist_attrs.get_dims_mapping() is None:
return False

return True


class MultiHeadAttention(nn.Layer):
"""
Attention mapps queries and a set of key-value pairs to outputs, and
Expand Down Expand Up @@ -874,6 +903,9 @@ def test_gpt_dp_mp(self):
self.assertTrue(all_params == data_parallel_allreduce_vars)
self.assertTrue(allreduce_grads == tensor_parallel_allreduce_vars)

self.assertTrue(
is_valid_completed_program(dist_context, auto_parallel_main_prog))


if __name__ == "__main__":
unittest.main()