diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h
index c7de7d734521..baeb6e4e869c 100644
--- a/src/operator/nn/layer_norm-inl.h
+++ b/src/operator/nn/layer_norm-inl.h
@@ -111,7 +111,7 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs,
   // Calculate mean
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-      broadcast::Reduce<...>(
+      broadcast::Reduce<...>(
         s, mean_data, req[0], workspace, in_data);
       Tensor<xpu, 1, DType> mean_data_tensor = mean_data.FlatTo1D<xpu, DType>(s);
       mean_data_tensor /= scalar<DType>(channel_size);
@@ -125,7 +125,7 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs,
   const TBlob centered_out = outputs[0].reshape(red_src_shape);
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-      broadcast::Reduce<...>(
+      broadcast::Reduce<...>(
         s, std_data, req[0], workspace, centered_out);
       Tensor<xpu, 1, DType> std_data_tensor = std_data.FlatTo1D<xpu, DType>(s);
       std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / scalar<DType>(channel_size)
@@ -133,17 +133,17 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs,
     });
   });
   // Calculate data = data / std
-  BinaryBroadcastCompute<...>(attrs, ctx,
-                              {outputs[0], outputs[layernorm::kStd]},
-                              {kWriteTo}, {outputs[0]});
+  BinaryBroadcastCompute<...>(attrs, ctx,
+                              {outputs[0], outputs[layernorm::kStd]},
+                              {kWriteTo}, {outputs[0]});
   // Calculate data = data * gamma
-  BinaryBroadcastCompute<...>(attrs, ctx,
-                              {outputs[0], gamma},
-                              {kWriteTo}, {outputs[0]});
+  BinaryBroadcastCompute<...>(attrs, ctx,
+                              {outputs[0], gamma},
+                              {kWriteTo}, {outputs[0]});
   // Calculate data = data + beta
-  BinaryBroadcastCompute<...>(attrs, ctx,
-                              {outputs[0], beta},
-                              {kWriteTo}, {outputs[0]});
+  BinaryBroadcastCompute<...>(attrs, ctx,
+                              {outputs[0], beta},
+                              {kWriteTo}, {outputs[0]});
 }
 
 /*
@@ -203,14 +203,14 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
       reduce_workspace_size =
         std::max(reduce_workspace_size,
-                 broadcast::ReduceWorkspaceSize<...>(s, red_src_shape,
-                                                     kAddTo, red_dst_shape));
+                 broadcast::ReduceWorkspaceSize<...>(s, red_dst_shape,
+                                                     kAddTo, red_src_shape));
     });
     BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
       reduce_workspace_size =
        std::max(reduce_workspace_size,
-               broadcast::ReduceWorkspaceSize<...>(s, red_exclude_src_shape, kAddTo,
-                                                   red_exclude_dst_shape));
+               broadcast::ReduceWorkspaceSize<...>(s, red_exclude_dst_shape, kAddTo,
+                                                   red_exclude_src_shape));
     });
   });
   workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
@@ -222,17 +222,17 @@
   const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2,
                               mean.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id());
   // Compute normalized_data = (data - mean) / std
-  BinaryBroadcastCompute<...>(attrs, ctx,
-                              {data, mean},
-                              {kWriteTo}, {normalized_data});
-  BinaryBroadcastCompute<...>(attrs, ctx,
-                              {normalized_data, std},
-                              {kWriteTo}, {normalized_data});
+  BinaryBroadcastCompute<...>(attrs, ctx,
+                              {data, mean},
+                              {kWriteTo}, {normalized_data});
+  BinaryBroadcastCompute<...>(attrs, ctx,
+                              {normalized_data, std},
+                              {kWriteTo}, {normalized_data});
   // Calculate grad_beta
   if (req[2] != kNullOp) {
     MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
-        broadcast::Reduce<...>(
+        broadcast::Reduce<...>(
           s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
           ograd.reshape(red_exclude_src_shape));
       });
@@ -244,7 +244,7 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
   if (req[1] != kNullOp) {
     MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
       BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
-        broadcast::Reduce<...>(
+        broadcast::Reduce<...>(
           s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
           ograd_mult.reshape(red_exclude_src_shape));
       });
@@ -263,7 +263,7 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
                               {kWriteTo}, {ograd_mult});
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-      broadcast::Reduce<...>(
+      broadcast::Reduce<...>(
         s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
         ograd_mult.reshape(red_src_shape));
     });
@@ -277,16 +277,16 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
                               {kWriteTo}, {ograd_mult});
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
-      broadcast::Reduce<...>(
+      broadcast::Reduce<...>(
         s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
         ograd_mult.reshape(red_src_shape));
     });
     Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
     red_out_tensor /= scalar<DType>(- channel_size);
   });
-  BinaryBroadcastCompute<...>(attrs, ctx,
-                              {normalized_data, red_out},
-                              {kAddTo}, {outputs[0]});
+  BinaryBroadcastCompute<...>(attrs, ctx,
+                              {normalized_data, red_out},
+                              {kAddTo}, {outputs[0]});
 }
 
 }
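
The LayerNormCompute hunks above walk through the forward pass: reduce the input to a per-instance mean, reduce the centered output to a per-instance std, then divide by the std and apply gamma and beta with broadcast ops. The C++ template parameter lists did not survive extraction (shown as <...> above), so the reducer and binary-op choices on those lines cannot be recovered here; the code comments still name each step. As a reference for what the function computes, here is a minimal standalone sketch of that arithmetic. It is illustrative only, not MXNet's implementation: layer_norm_forward and every name in it are made up, and eps is assumed to sit under the square root, as is conventional for layer norm.

// Per instance over the normalized axis:
//   x_hat = (x - mean) / sqrt(var + eps),  out = x_hat * gamma + beta
#include <cmath>
#include <cstddef>
#include <vector>

void layer_norm_forward(const std::vector<float>& x,       // rows * channel_size
                        const std::vector<float>& gamma,   // channel_size
                        const std::vector<float>& beta,    // channel_size
                        std::size_t channel_size,
                        float eps,
                        std::vector<float>* out) {
  const std::size_t rows = x.size() / channel_size;
  out->resize(x.size());
  for (std::size_t r = 0; r < rows; ++r) {
    const float* row = &x[r * channel_size];
    // Mean over the normalized axis (the first broadcast::Reduce above).
    float mean = 0.0f;
    for (std::size_t c = 0; c < channel_size; ++c) mean += row[c];
    mean /= channel_size;
    // Std of the centered data (the second broadcast::Reduce plus the
    // square-root line); eps under the root is an assumption here.
    float var = 0.0f;
    for (std::size_t c = 0; c < channel_size; ++c) {
      const float d = row[c] - mean;
      var += d * d;
    }
    const float stddev = std::sqrt(var / channel_size + eps);
    // Normalize, scale, shift (the three BinaryBroadcastCompute calls).
    for (std::size_t c = 0; c < channel_size; ++c) {
      (*out)[r * channel_size + c] = (row[c] - mean) / stddev * gamma[c] + beta[c];
    }
  }
}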
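
On the backward side, the one change still visible after extraction is in the @@ -203 hunk: each broadcast::ReduceWorkspaceSize call now passes the destination shape before the source shape. Both parameters are shapes, so the swapped order compiles silently while sizing the scratch buffer for the wrong workload, which is presumably the point of the fix. The rest of LayerNormGradCompute assembles three gradients: grad_beta reduces ograd over the non-normalized axes, grad_gamma reduces ograd * normalized_data the same way, and grad_data combines ograd_mult = ograd * gamma / std with two mean-subtraction terms (the kWriteTo reductions into red_out, then the final kAddTo broadcast). Below is a per-row sketch of that algebra, again with illustrative names rather than MXNet's API.

// For one row, with x_hat the normalized data and g = ograd * gamma / std:
//   grad_beta  += ograd                      (reduce over the "exclude" axes)
//   grad_gamma += ograd * x_hat              (reduce over the "exclude" axes)
//   grad_x      = g - mean(g) - x_hat * mean(g * x_hat)
#include <cstddef>
#include <vector>

void layer_norm_backward_row(const std::vector<float>& ograd,
                             const std::vector<float>& x_hat,   // normalized data
                             const std::vector<float>& gamma,
                             float stddev,
                             std::vector<float>* grad_x,
                             std::vector<float>* grad_gamma,    // accumulated (kAddTo)
                             std::vector<float>* grad_beta) {   // accumulated (kAddTo)
  const std::size_t n = ograd.size();
  grad_x->resize(n);
  std::vector<float> g(n);
  float mean_g = 0.0f, mean_gx = 0.0f;
  for (std::size_t c = 0; c < n; ++c) {
    (*grad_beta)[c] += ograd[c];              // grad_beta reduction
    (*grad_gamma)[c] += ograd[c] * x_hat[c];  // grad_gamma reduction
    g[c] = ograd[c] * gamma[c] / stddev;      // ograd_mult in the diff
    mean_g += g[c];
    mean_gx += g[c] * x_hat[c];               // feeds the -channel_size division
  }
  mean_g /= n;
  mean_gx /= n;
  for (std::size_t c = 0; c < n; ++c) {
    (*grad_x)[c] = g[c] - mean_g - x_hat[c] * mean_gx;  // kWriteTo + kAddTo terms
  }
}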
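
A patch like this is straightforward to sanity-check numerically: compare the analytic input gradient against central finite differences of the forward pass. Here is a small driver for the two sketches above (compile all three together); the sizes, input values, and step size h are arbitrary, and agreement to roughly the step size is what one would expect in single precision.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const std::size_t n = 5;
  const float eps = 1e-5f;
  std::vector<float> x = {0.5f, -1.2f, 2.0f, 0.1f, -0.7f};
  std::vector<float> gamma = {1.0f, 0.9f, 1.1f, 1.2f, 0.8f};
  std::vector<float> beta(n, 0.0f);

  // Recompute mean/std so the backward sketch can be fed x_hat directly.
  float mean = 0.0f, var = 0.0f;
  for (float v : x) mean += v;
  mean /= n;
  for (float v : x) var += (v - mean) * (v - mean);
  const float stddev = std::sqrt(var / n + eps);

  std::vector<float> x_hat(n), ograd(n, 1.0f);  // loss = sum(out), so ograd is ones
  for (std::size_t c = 0; c < n; ++c) x_hat[c] = (x[c] - mean) / stddev;

  std::vector<float> gx, ggamma(n, 0.0f), gbeta(n, 0.0f);
  layer_norm_backward_row(ograd, x_hat, gamma, stddev, &gx, &ggamma, &gbeta);

  // Central differences on the forward pass.
  const float h = 1e-3f;
  for (std::size_t c = 0; c < n; ++c) {
    std::vector<float> xp = x, xm = x, op, om;
    xp[c] += h;
    xm[c] -= h;
    layer_norm_forward(xp, gamma, beta, n, eps, &op);
    layer_norm_forward(xm, gamma, beta, n, eps, &om);
    float num = 0.0f;
    for (std::size_t i = 0; i < n; ++i) num += op[i] - om[i];
    num /= 2.0f * h;
    std::printf("c=%zu analytic=% .6f numeric=% .6f\n", c, gx[c], num);
  }
  return 0;
}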