From bd775690df2edb28d508099d952411554ba6f387 Mon Sep 17 00:00:00 2001 From: zhouwei25 Date: Tue, 15 Nov 2022 05:19:29 +0000 Subject: [PATCH] [Zero-Dim] Make auto parallel judge dim more strict --- paddle/fluid/operators/batch_norm_op.cc | 2 +- python/paddle/distributed/auto_parallel/completion.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 6c6591f34abcef..878ab18432cdcf 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -164,7 +164,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ctx->SetOutputDim("SavedMean", {C}); ctx->SetOutputDim("SavedVariance", {C}); ctx->ShareLoD("X", "Y"); - if (ctx->HasInput("ReserveSpace")) { + if (ctx->HasOutput("ReserveSpace")) { ctx->SetOutputDim("ReserveSpace", {-1}); } } diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index c0f70f482dd17f..7f5e0fee775267 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -1239,7 +1239,7 @@ def _get_op_by_id(ops, id): input_var ).dims_mapping else: - if fwd_op_dist_attr.get_input_dims_mapping(input_name): + if input_name in forward_op.input_arg_names: ref_dims_mapping = ( fwd_op_dist_attr.get_input_dims_mapping( input_name @@ -1544,7 +1544,7 @@ def _get_op_by_id(ops, id): input_var ).dims_mapping else: - if fwd_op_dist_attr.get_input_dims_mapping(input_name): + if input_name in forward_op.input_arg_names: ref_dims_mapping = ( fwd_op_dist_attr.get_input_dims_mapping( input_name