Skip to content

Commit

Permalink
[Auto parallel] Mixed Precision FP16 Pass (#40615)
Browse files Browse the repository at this point in the history
*  add FP16 Pass 

* Support the auto completion of while_op

*  acc aligned
  • Loading branch information
JZ-LIANG authored Mar 28, 2022
1 parent 5c5a366 commit b99c1d0
Show file tree
Hide file tree
Showing 9 changed files with 670 additions and 11 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ message AMPConfig {
repeated string custom_black_varnames = 9;
optional bool use_pure_fp16 = 10 [ default = false ];
optional bool use_fp16_guard = 11 [ default = true ];
optional bool use_optimizer_fp16 = 12
[ default = false ]; // auto parallel effective only
}

message LocalSGDConfig {
Expand Down
12 changes: 9 additions & 3 deletions python/paddle/distributed/auto_parallel/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,15 @@ def _apply_pre_optimization_passes(self, main_program, startup_program,
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["loss"] = loss
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_context)
if config["use_pure_fp16"]:
config["base_opt"] = self._optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply(
[main_program], [startup_program], self._pass_context)
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_context)

# apply recompute pass
if self._dist_strategy.recompute:
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/distributed/auto_parallel/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,11 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
src_var = src_block.var(src_varname)

if src_var.type in __not_shape_var_type__:
persist = getattr(src_var, 'persistable', False)
new_var = dst_block.create_var(
type=src_var.type,
name=dst_varname,
persistable=True,
persistable=persist,
stop_gradient=True)
target_shape = None
else:
Expand Down
3 changes: 1 addition & 2 deletions python/paddle/distributed/auto_parallel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,8 +1047,7 @@ def set_grad_var_shape(program, dist_context):

forward_input_dist_attr = op_dist_attr.get_input_dist_attr(
forward_var_name)

assert forward_input_dist_attr is not None, f"{forward_var_name}"
assert forward_input_dist_attr is not None, f"{forward_var_name, str(op)}"
forward_var = vars[forward_var_name]
forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
forward_var)
Expand Down
1 change: 1 addition & 0 deletions python/paddle/distributed/passes/__init__.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .auto_parallel_gradient_merge import *
from .auto_parallel_sharding import *
from .auto_parallel_amp import *
from .auto_parallel_fp16 import *
from .auto_parallel_recompute import *
from .cpp_pass import *
import os
Expand Down
43 changes: 38 additions & 5 deletions python/paddle/distributed/passes/auto_parallel_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,6 @@ def _check_self(self):
return False
if self.get_attr("decr_ratio") < 0:
return False
if len(self.get_attr("params_grads")) <= 0:
return False
if self.get_attr("dist_context") is None:
return False
return True
Expand Down Expand Up @@ -576,13 +574,46 @@ def _scale_loss(self):

main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()

loss = self.get_attr("loss")
assert loss is not None
loss_op = loss.op
loss_op_dist_attr = self.dist_context.get_op_dist_attr_for_program(
loss_op)

if loss.dtype != core.VarDesc.VarType.FP32:
# cast loss here will change the effective loss tensor for the computation graph
# and therefore will effect all following passes whose logic is based on the loss tensor(Recompute & Gradient Merge),
# so we it is not allowed by now. fixed it in future.
raise NotImplementedError(
"Loss's generator op is not support in FP16 in Auto Parallel by now, please put that op into your black-list."
)

tmp_name = unique_name.generate(loss.name + ".cast_fp32")
cast_loss = main_block.create_var(name=tmp_name, dtype=dtype)
loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
loss)
ref_mesh = loss_op_dist_attr.process_mesh
self.dist_context.set_tensor_dist_attr_for_program(cast_loss,
loss_dist_attr)

loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
cast_op = main_block._insert_op(
loss_op_idx + 1,
type='cast',
inputs={'X': [loss]},
outputs={'Out': [cast_loss]},
attrs={
"in_dtype": loss.dtype,
"out_dtype": core.VarDesc.VarType.FP32,
'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
})

loss_op._set_attr(OP_ROLE_KEY,
core.op_proto_and_checker_maker.OpRole.Forward)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, [-1], self.dist_context)
loss = loss.astype('float32')

if self.get_attr("use_dynamic_loss_scaling") or self.get_attr(
Expand All @@ -600,7 +631,6 @@ def _scale_loss(self):
set_var_dist_attr(self.dist_context, self._scaled_loss, [-1],
ref_mesh)

OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
elementwise_mul_op = main_block._insert_op(
loss_op_idx + 1,
type='elementwise_mul',
Expand Down Expand Up @@ -667,8 +697,11 @@ def _update_loss_scaling(self, grads, found_inf):
for e in grads:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'update_loss_scaling')
assert self._loss_scaling.dtype == e.dtype, \
"The dtype of prev_loss_scaling should be equal to the dtype of x."
if e.dtype == core.VarDesc.VarType.FP16:
assert self._loss_scaling.dtype == core.VarDesc.VarType.FP32, \
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
else:
assert self._loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."

inputs = {
'X': grads,
Expand Down
Loading

0 comments on commit b99c1d0

Please sign in to comment.