[Auto Parallel] Adapt Partitioner & DistOp for ERNIE3.0 Inference and cache (#39895)

* adapt dist op

* add dist_fill_constant_batch_size_like

* remove print

* update compatibility

* add unittest
JZ-LIANG authored Mar 2, 2022
1 parent 6af2729 commit c9cd47d
Showing 7 changed files with 168 additions and 4 deletions.
@@ -27,3 +27,4 @@
from . import dist_check_finite_and_unscale
from . import dist_update_loss_scaling
from . import dist_split
from . import dist_fill_constant_batch_size_like
Empty file.
@@ -155,7 +155,7 @@ def forward(ctx, *args, **kwargs):
kwargs['Out'])

Ids_var = main_block.var(kwargs['Ids'][0])
Weight_var = main_block.var(kwargs['W'][0])
Weight_var = main_block._var_recursive(kwargs['W'][0])
Out_var = main_block.var(kwargs['Out'][0])

# got dist attribute info
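
Note on the lookup change above: switching from main_block.var to main_block._var_recursive lets the embedding weight be resolved even when the serial op sits in a sub-block (e.g. a control-flow block used for decoding with cache) while the parameter is defined in an ancestor block. A minimal sketch of that parent-walking lookup, using illustrative vars/parent attributes rather than Paddle's actual Block API:

    def var_recursive(block, name):
        # walk from the current block up through its enclosing blocks
        cur = block
        while cur is not None:
            if name in cur.vars:       # assumed: a block exposes its local variables as a dict
                return cur.vars[name]
            cur = cur.parent           # assumed: link to the enclosing block, None at the root
        raise KeyError("variable {} not found in any enclosing block".format(name))
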
@@ -277,7 +277,8 @@ def forward(ctx, *args, **kwargs):

# param initialization sync
if Weight_var.is_parameter and not op_dist_attr.is_recompute:
assert Weight_var.name not in dist_op_context.already_init_sync_vars
if Weight_var.name in dist_op_context.already_init_sync_vars:
return
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
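
Note: the hard assert above is relaxed into an early return so that a weight referenced by several distributed ops (as happens in the exported inference graph with cache) is only scheduled for initialization sync once; the same relaxation is applied to _init_param_sync further below. A minimal sketch of the guard pattern (illustrative names, not Paddle's API):

    already_init_sync_vars = set()

    def init_param_sync_once(param_name, do_sync):
        # skip weights that another op has already scheduled for synchronization
        if param_name in already_init_sync_vars:
            return
        already_init_sync_vars.add(param_name)
        do_sync(param_name)   # e.g. broadcast the parameter from the group's root rank
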
@@ -0,0 +1,127 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..utils import set_dist_op_desc_original_id
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from .dist_default import DistributedDefaultImpl0


class DistributedFillConstantBatchSizeLike(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedFillConstantBatchSizeLike, self).__init__(op_type)


register_distributed_operator_impl_container(
DistributedFillConstantBatchSizeLike("fill_constant_batch_size_like"))


class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedFillConstantBatchSizeLikeImpl0, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True

def is_input_compatible(self, dist_op):

return True

def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
shape_list = op_desc.attr("shape")

if len(shape_list) != len(out_dims_mapping):
return False

return True

def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False

op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
in_name = op_desc.input('Input')[0]
in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)

# the dim_mapping of batch dimension should be the same
return out_dims_mapping[0] == in_dims_mapping[0]

def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

# only the batch-size dimensions of the input and output are related.
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [0, 0])
if dim_changed:
changed = True

return changed

@staticmethod
def forward(ctx, *args, **kwargs):
"""
kwargs: inputname_mapping & outputname_mapping
"""
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
dist_op_context = ctx.dist_op_context
src_op = dist_op_context.cur_src_op
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
main_block = dist_op_context.work_block
op = main_block.ops[-1]
assert op.type == "fill_constant_batch_size_like"

# modify the shape attr according to how the output is partitioned
out_name = op.output('Out')[0]
dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
process_mesh_shape = op_dist_attr.process_mesh.topology
shape_list = op.attr("shape")
# modify target shape
for idx, axis in enumerate(dims_mapping):
if axis >= 0:
shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]

op._set_attr("shape", shape_list)
main_block._sync_with_cpp()

@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)


register_distributed_operator_impl(
"fill_constant_batch_size_like",
DistributedFillConstantBatchSizeLikeImpl0("fill_by_shape"))
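
The core of the new implementation is forward(): it delegates op insertion to DistributedDefaultImpl0, then rewrites the "shape" attribute of the op it just appended, dividing each dimension whose dims_mapping entry is >= 0 (i.e. sharded along a mesh axis) by the size of that mesh axis, while replicated dimensions keep their value. A standalone sketch of that arithmetic, reusing the shapes exercised by the unit test later in this commit:

    def sharded_shape(shape, dims_mapping, mesh_topology):
        out = list(shape)
        for i, mesh_axis in enumerate(dims_mapping):
            if mesh_axis >= 0:                       # this tensor dim is split along mesh_axis
                out[i] = out[i] // mesh_topology[mesh_axis]
        return out

    # shape attr [-1, 16, 0, 48] with dims_mapping [-1, 0, -1, -1] on a mesh of
    # 2 processes along axis 0 becomes [-1, 8, 0, 48]
    print(sharded_shape([-1, 16, 0, 48], [-1, 0, -1, -1], [2]))
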
@@ -433,8 +433,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):

def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):

assert Weight_var.name not in dist_op_context.already_init_sync_vars, "{} is in {}.".format(
Weight_var.name, dist_op_context.already_init_sync_vars)
if Weight_var.name in dist_op_context.already_init_sync_vars:
return
assert startup_block.has_var(Weight_var.name)
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
@@ -819,6 +819,8 @@ def forward(ctx, *args, **kwargs):
out_var_dist_attr)

intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_allreduce_sum", 'tmp'])),
shape=Out_var.shape,
dtype=Out_var.dtype,
type=Out_var.type,
@@ -1323,6 +1325,8 @@ def forward(ctx, *args, **kwargs):
out_var_dist_attr)

intermediate_var_0 = main_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_allreduce_sum", 'tmp'])),
shape=Out_var.shape,
dtype=Out_var.dtype,
type=Out_var.type,
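
Both forward hunks above add an explicit name for the temporary allreduce output, generated with unique_name.generate_with_ignorable_key, so repeated lowerings produce distinct, recognizable temporaries instead of anonymous auto-generated ones. A small check of the uniqueness this relies on (the exact suffixes shown are an assumption):

    from paddle.fluid import unique_name

    name_a = unique_name.generate_with_ignorable_key(".".join(["c_allreduce_sum", "tmp"]))
    name_b = unique_name.generate_with_ignorable_key(".".join(["c_allreduce_sum", "tmp"]))
    assert name_a != name_b   # e.g. "c_allreduce_sum.tmp_0" vs "c_allreduce_sum.tmp_1"
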
3 changes: 3 additions & 0 deletions python/paddle/distributed/auto_parallel/partitioner.py
@@ -285,6 +285,9 @@ def _get_dist_shape(var, dist_attr):
var_shape = var.shape
mapping = dist_attr.dims_mapping
mesh = dist_attr.process_mesh.topology
if mapping == []:
return var_shape

assert len(var_shape) == len(
mapping
), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
@@ -174,6 +174,7 @@ def get_program():
dtype='float32')
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32')

data_holder = [input, label]
# dataloader
dataloader = paddle.io.DataLoader.from_generator(
@@ -194,6 +195,17 @@ def get_program():
"dims_mapping": [-1, -1, -1]
})

# fill constant bsz like
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=input, shape=[-1, 16, 0, 48], dtype='float32', value=0)
auto.shard_tensor(
tmp,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, 0, -1, -1]
})

# model
mlp_start = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
@@ -395,6 +407,9 @@ def completion(train_program, start_program, dist_context):
op_dist_attr.impl_idx = 0
else:
op_dist_attr.impl_idx = 1
elif op.type == "fill_constant_batch_size_like":
op_dist_attr.impl_type = "fill_constant_batch_size_like"
op_dist_attr.impl_idx = 0
else:
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = 0
@@ -428,13 +443,26 @@ def test_partitioner(self):
dist_main_prog, dist_startup_prog = partition(
train_program, start_program, dist_context)
global_block_ops = dist_main_prog.blocks[0].ops

fill_op = None
for op in global_block_ops:
if op.type == "fill_constant_batch_size_like":
fill_op = op

global_block_ops = [op.type for op in global_block_ops]
sub_block_ops = dist_main_prog.blocks[1].ops
sub_block_ops = [op.type for op in sub_block_ops]

self.assertTrue("c_allreduce_sum" in global_block_ops)
self.assertTrue("c_allreduce_sum" in sub_block_ops)

# test fill_constant_batch_size_like

self.assertTrue(fill_op is not None)
ref_shape = [-1, 8, 0, 48]
shape = fill_op.attr("shape")
self.assertTrue(ref_shape == shape)


if __name__ == "__main__":
unittest.main()
