
Add warning for fp16 inputs with MXNET_SAFE_ACCUMULATION=0 #15046

Merged 1 commit on May 23, 2019
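Background note (not part of the diff): float16 carries only about 11 significand bits, so a long running sum kept in float16 soon stops absorbing new terms, which is why the warning added here recommends MXNET_SAFE_ACCUMULATION=1 (accumulate in a wider type) for float16 inputs. Below is a minimal standalone C++ sketch of the effect, with float standing in for float16 and double for the wider accumulator, since standard C++ has no portable half type; in real float16 the same stall already happens at 2048.

// Standalone illustration only, not MXNet code: narrow vs. wide accumulation.
#include <iostream>

int main() {
  const long n = 20000000;   // sum 20 million ones
  float narrow_acc = 0.0f;   // accumulate at input precision ("unsafe")
  double wide_acc = 0.0;     // accumulate at wider precision ("safe")
  for (long i = 0; i < n; ++i) {
    narrow_acc += 1.0f;      // stalls at 2^24 = 16777216, where adding 1.0 rounds back down
    wide_acc += 1.0;
  }
  std::cout << "narrow: " << narrow_acc << "\n";  // 1.67772e+07 (wrong)
  std::cout << "wide:   " << wide_acc << "\n";    // 2e+07 (exact)
  return 0;
}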
21 changes: 15 additions & 6 deletions src/operator/nn/layer_norm-inl.h
@@ -113,10 +113,18 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
});
});
workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);

bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for float16 inputs for LayerNorm. "
"See https://mxnet.incubator.apache.org/versions/master/faq/env_var.html "
"for more details.");
}

// Calculate mean
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, mean_data, req[0], workspace, in_data);
} else {
@@ -135,7 +143,7 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
const TBlob centered_out = outputs[0].reshape(red_src_shape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, false>(
s, std_data, req[0], workspace, centered_out);
} else {
@@ -250,10 +258,11 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
{normalized_data, std},
{kWriteTo}, {normalized_data});
// Calculate grad_beta
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
if (req[2] != kNullOp) {
MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
ograd.reshape(red_exclude_src_shape));
@@ -271,7 +280,7 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
if (req[1] != kNullOp) {
MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
ograd_mult.reshape(red_exclude_src_shape));
@@ -296,7 +305,7 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
{kWriteTo}, {ograd_mult});
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
@@ -316,7 +325,7 @@ void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
{kWriteTo}, {ograd_mult});
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, false>(
s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
ograd_mult.reshape(red_src_shape));
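The remaining diffs below add the same one-time warning to the CUDA LayerNorm kernel, softmax, and LpNorm. As a reading aid, here is a hypothetical standalone sketch of what a warn-once helper like common::LogOnce does; it illustrates the behavior only and is not MXNet's actual implementation.

// Hypothetical sketch of a "log each distinct message once" helper;
// not MXNet's common::LogOnce, just the behavior the warning relies on.
#include <iostream>
#include <mutex>
#include <string>
#include <unordered_set>

void LogOnce(const std::string& msg) {
  static std::mutex m;
  static std::unordered_set<std::string> seen;
  std::lock_guard<std::mutex> lock(m);
  if (seen.insert(msg).second) {   // insert() reports whether msg was new
    std::cerr << "[warning] " << msg << std::endl;
  }
}

int main() {
  for (int i = 0; i < 3; ++i) {
    LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for float16 inputs.");
  }
  return 0;  // the warning above is printed exactly once
}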
5 changes: 5 additions & 0 deletions src/operator/nn/layer_norm.cu
@@ -340,6 +340,11 @@ void LayerNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
if (axis == inputs[0].ndim() - 1) {
// Try to use the accelerated CUDA kernels
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for LayerNorm with float16 inputs. "
"See https://mxnet.incubator.apache.org/versions/master/faq/env_var.html "
"for more details.");
}
if (safe_acc) {
return LayerNormGPUContig<true>(param, ctx, inputs, req, outputs);
} else {
5 changes: 5 additions & 0 deletions src/operator/nn/softmax-inl.h
@@ -411,6 +411,11 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
param.temperature.value() : 1.0;
mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for softmax with float16 inputs. "
"See https://mxnet.incubator.apache.org/versions/master/faq/env_var.html "
"for more details.");
}

MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
11 changes: 8 additions & 3 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -1183,17 +1183,22 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs,
} else {
small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false);
}

bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
if (!safe_acc && inputs[0].type_flag_ == mshadow::kFloat16) {
common::LogOnce("MXNET_SAFE_ACCUMULATION=1 is recommended for LpNorm with float16 inputs. "
"See https://mxnet.incubator.apache.org/versions/master/faq/env_var.html "
"for more details.");
}
if (param.ord == 1) {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, false, mshadow_op::abs>(
ctx, inputs, req, outputs, small);
} else {
ReduceAxesComputeImpl<xpu, mshadow_op::sum, false, false, mshadow_op::abs>(
ctx, inputs, req, outputs, small);
}
} else if (param.ord == 2) {
if (dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false)) {
if (safe_acc) {
ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, true, false, mshadow_op::identity>(
ctx, inputs, req, outputs, small);
} else {