This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

softmax for fp16 with fp32 accumulator
szha committed Feb 8, 2019
1 parent 26ca37c commit 5b0ddcd
Showing 3 changed files with 112 additions and 40 deletions.
42 changes: 42 additions & 0 deletions src/operator/mxnet_op.h
@@ -249,6 +249,48 @@ inline int get_num_threads<cpu>(const int N) {
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MXNET_REAL_ACC_TYPE_SWITCH(type, DType, AType, ...)\
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
typedef float AType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
typedef double AType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
typedef float AType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
LOG(FATAL) << "This operation only support " \
"floating point types not uint8"; \
break; \
case mshadow::kInt8: \
LOG(FATAL) << "This operation only support " \
"floating point types not int8"; \
break; \
case mshadow::kInt32: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int32"; \
break; \
case mshadow::kInt64: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int64"; \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}

/*!
* \brief assign the val to out according
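The new MXNET_REAL_ACC_TYPE_SWITCH macro above pairs each storage type DType with an accumulation type AType: float and double accumulate in their own precision, while kFloat16 accumulates in float. A minimal sketch of how a kernel author could use it follows; sum_all and its arguments are hypothetical, only the macro itself comes from this commit, and mxnet_op.h is assumed to be on the include path.

// Hypothetical helper: reduce a buffer of any real type, accumulating in AType.
#include <cstddef>
#include "mxnet_op.h"   // for MXNET_REAL_ACC_TYPE_SWITCH (assumed include path)

double sum_all(int type_flag, const void *data, size_t n) {
  double result = 0.0;
  MXNET_REAL_ACC_TYPE_SWITCH(type_flag, DType, AType, {
    const DType *ptr = static_cast<const DType*>(data);
    AType acc = AType(0);              // float accumulator when DType is fp16
    for (size_t i = 0; i < n; ++i) {
      acc += AType(ptr[i]);            // widen each element before adding
    }
    result = static_cast<double>(acc);
  });
  return result;
}

For an mshadow::kFloat16 buffer this expands to half_t loads with a float running sum, which is the combination the softmax kernels below switch to.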
84 changes: 44 additions & 40 deletions src/operator/nn/softmax-inl.h
@@ -36,22 +36,22 @@ namespace op {
namespace mxnet_op {

struct softmax_fwd {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
template<typename DType, typename AType>
MSHADOW_XINLINE static DType Map(DType a, AType b) {
return DType(expf(a)/b);
}
};


struct log_softmax_fwd {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
template<typename DType, typename AType>
MSHADOW_XINLINE static DType Map(DType a, AType b) {
return DType(a - logf(b));
}
};


template<typename OP, bool negate, typename DType, int ndim>
template<typename OP, bool negate, typename DType, typename AType, int ndim>
inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
Shape<ndim> shape, int axis, const DType temperature) {
index_t M = shape[axis];
@@ -72,7 +72,7 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,
if (mmax < val) mmax = val;
}

DType sum = DType(0);
AType sum = AType(0);
DType in_val;
// By default temperature is 1.0, and only in reinforcement training
// users would set it to other values.
@@ -103,22 +103,22 @@ inline void Softmax(Stream<cpu> *s, DType *in, DType *out,


struct softmax_bwd {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType ograd, DType out, DType sum) {
template<typename DType, typename AType>
MSHADOW_XINLINE static DType Map(DType ograd, DType out, AType sum) {
return DType(out * (ograd - sum));
}
};


struct log_softmax_bwd {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType ograd, DType out, DType sum) {
template<typename DType, typename AType>
MSHADOW_XINLINE static DType Map(DType ograd, DType out, AType sum) {
return DType(ograd - expf(out)*sum);
}
};


template<typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
template<typename OP1, typename OP2, int Req, bool negate, typename DType, typename AType, int ndim>
inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
DType *igrad, Shape<ndim> shape, int axis,
const DType temperature) {
@@ -133,7 +133,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,
for (int i = 0; i < static_cast<int>(N); ++i) {
index_t base = unravel_dot(i, sshape, stride);

DType sum = DType(0);
AType sum = AType(0);
for (index_t j = 0; j < M; ++j) {
sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]);
}
@@ -162,19 +162,19 @@ inline void SoftmaxGrad(Stream<cpu> *s, DType *out, DType *ograd,


#ifdef __CUDACC__
template<int x_bits, typename OP, bool negate, typename DType, int ndim>
template<int x_bits, typename OP, bool negate, typename DType, typename AType, int ndim>
__global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
Shape<ndim> sshape, Shape<ndim> stride,
const double temperature) {
const unsigned x_size = 1 << x_bits;
__shared__ DType smem[x_size];
__shared__ AType smem[x_size];
index_t sa = stride[axis];
index_t base = unravel_dot(blockIdx.x, sshape, stride);
index_t x = threadIdx.x;

red::maximum::SetInitValue(smem[x]);
for (index_t i = x; i < M; i += x_size) {
red::maximum::Reduce(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
smem[x] = ::max(smem[x], negate ? -in[base + i*sa] : in[base + i*sa]);
}
__syncthreads();
cuda::Reduce1D<red::maximum, x_bits>(smem);
@@ -186,13 +186,12 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
DType val;
for (index_t i = x; i < M; i += x_size) {
val = negate ? -in[base + i*sa]:in[base + i*sa];
red::sum::Reduce(
smem[x], static_cast<DType>(expf((val - smax) / static_cast<DType>(temperature))));
smem[x] += static_cast<AType>(expf((val - smax) / static_cast<AType>(temperature)));
}
__syncthreads();
cuda::Reduce1D<red::sum, x_bits>(smem);
__syncthreads();
DType ssum = smem[0];
AType ssum = smem[0];
__syncthreads();

for (index_t i = x; i < M; i += x_size) {
@@ -201,7 +200,7 @@ __global__ void softmax_compute_kernel(DType *in, DType *out, index_t M, int axis,
}
}

template<typename OP, bool negate, typename DType, int ndim>
template<typename OP, bool negate, typename DType, typename AType, int ndim>
inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
Shape<ndim> shape, int axis, const double temperature) {
const int x_bits = 7;
@@ -212,31 +211,32 @@ inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
Shape<ndim> sshape = shape;
sshape[axis] = 1;

softmax_compute_kernel<x_bits, OP, negate, DType, ndim>
softmax_compute_kernel<x_bits, OP, negate, DType, AType, ndim>
<<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
in, out, M, axis, sshape, stride, temperature);
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
}


template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
template<int x_bits, typename OP1, typename OP2, int Req, bool negate,
typename DType, typename AType, int ndim>
__global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
index_t M, int axis, Shape<ndim> sshape,
Shape<ndim> stride, const double temperature) {
const unsigned x_size = 1 << x_bits;
__shared__ DType smem[x_size];
__shared__ AType smem[x_size];
index_t sa = stride[axis];
index_t base = unravel_dot(blockIdx.x, sshape, stride);
index_t x = threadIdx.x;

red::sum::SetInitValue(smem[x]);
for (index_t i = x; i < M; i += x_size) {
red::sum::Reduce(smem[x], OP1::Map(ograd[base + i*sa], out[base + i*sa]));
smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]);
}
__syncthreads();
cuda::Reduce1D<red::sum, x_bits>(smem);
__syncthreads();
DType ssum = smem[0];
AType ssum = smem[0];
__syncthreads();

DType final_result;
@@ -250,7 +250,7 @@ __global__ void softmax_gradient_kernel(DType *out, DType *ograd, DType *igrad,
}


template<typename OP1, typename OP2, int Req, bool negate, typename DType, int ndim>
template<typename OP1, typename OP2, int Req, bool negate, typename DType, typename AType, int ndim>
inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
DType *igrad, Shape<ndim> shape, int axis,
const double temperature) {
@@ -262,7 +262,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
Shape<ndim> sshape = shape;
sshape[axis] = 1;

softmax_gradient_kernel<x_bits, OP1, OP2, Req, negate, DType, ndim>
softmax_gradient_kernel<x_bits, OP1, OP2, Req, negate, DType, AType, ndim>
<<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
out, ograd, igrad, M, axis, sshape, stride, temperature);
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
@@ -297,15 +297,17 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
const double temperature = param.temperature.has_value() ?
param.temperature.value() : 1.0;
TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
if (shape.ndim() == 2) {
Softmax<OP, negate>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<DType>(), shape.get<2>(), axis,
static_cast<DType>(temperature));
Softmax<OP, negate, DType, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<DType>(), shape.get<2>(), axis,
static_cast<DType>(temperature));
} else {
Softmax<OP, negate>(ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<DType>(), shape.get<3>(), axis,
static_cast<DType>(temperature));
Softmax<OP, negate, DType, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<DType>(), shape.get<3>(), axis,
static_cast<DType>(temperature));
}
});
}
@@ -324,16 +326,18 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
const double temperature = param.temperature.has_value() ?
param.temperature.value() : 1.0;
TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, DType, AType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
if (shape.ndim() == 2) {
SoftmaxGrad<OP1, OP2, Req, negate>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
shape.get<2>(), axis, static_cast<DType>(temperature));
SoftmaxGrad<OP1, OP2, Req, negate, DType, AType>(
ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
shape.get<2>(), axis, static_cast<DType>(temperature));
} else {
SoftmaxGrad<OP1, OP2, Req, negate>(ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
shape.get<3>(), axis, static_cast<DType>(temperature));
SoftmaxGrad<OP1, OP2, Req, negate, DType, AType>(
ctx.get_stream<xpu>(), inputs[1].dptr<DType>(),
inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
shape.get<3>(), axis, static_cast<DType>(temperature));
}
});
});
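In both the CPU and GPU paths above, the quantity that changes type is the reduction: the softmax normalizer (the sum of exponentials) and the gradient's inner sum are now held in AType, which MXNET_REAL_ACC_TYPE_SWITCH resolves to float for fp16 tensors. fp16 has a 10-bit significand and overflows above 65504, so summing a long row directly in fp16 drops low-order contributions (and can even saturate for very long axes); widening only the accumulator fixes that without changing the stored dtype. A simplified single-row sketch of the same pattern, not the MXNet kernels themselves (temperature, negation and axis handling are omitted):

// Row-wise softmax with a widened accumulator; DType is the storage type,
// AType the accumulation type (half_t/float in MXNet, float/double here).
#include <algorithm>
#include <cmath>

template<typename DType, typename AType>
void softmax_row(const DType *in, DType *out, int M) {
  DType mmax = in[0];                       // row max stays in DType, as in the CPU kernel
  for (int j = 1; j < M; ++j) mmax = std::max(mmax, in[j]);

  AType sum = AType(0);                     // the normalizer is accumulated in the wider type
  for (int j = 0; j < M; ++j) {
    sum += std::exp(static_cast<AType>(in[j] - mmax));
  }
  for (int j = 0; j < M; ++j) {
    out[j] = DType(std::exp(static_cast<AType>(in[j] - mmax)) / sum);
  }
}

Instantiated as softmax_row<float, double> this compiles and runs standalone; inside the operator the interesting case is the half_t/float combination selected by the type switch.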
26 changes: 26 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -4515,6 +4515,32 @@ def softmax_forward(input_data, true_output):
softmax_forward(mx.nd.array([[[[-3.4e38,-3.4e38]]]]), np.array([1.0,1.0]))
softmax_forward(mx.nd.array([[[[3.4e38,3.4e38]]]]), np.array([1.0,1.0]))

@with_seed()
def test_softmax_fp16():
def check_fp16_fp32_almost_equal(input_data):
fp16_input = input_data.astype('float16')
fp32_input = input_data.astype('float32')
fp16_input.attach_grad()
fp32_input.attach_grad()
with mx.autograd.record():
fp16_softmax = fp16_input.softmax(axis=-1)
fp32_softmax = fp32_input.softmax(axis=-1)
fp16_softmax.backward()
fp32_softmax.backward()
assert_almost_equal(fp16_softmax.asnumpy(), fp32_softmax.asnumpy(), rtol=1e-5, atol=1e-5)
assert_almost_equal(fp16_input.grad.asnumpy(), fp32_input.grad.asnumpy(), rtol=1e-5, atol=1e-5)

with mx.autograd.record():
fp16_log_softmax = fp16_input.log_softmax(axis=-1)
fp32_log_softmax = fp32_input.log_softmax(axis=-1)
fp16_log_softmax.backward()
fp32_log_softmax.backward()
assert_almost_equal(fp16_log_softmax.asnumpy(), fp32_log_softmax.asnumpy(), rtol=1e-2, atol=1e-2)
assert_almost_equal(fp16_input.grad.asnumpy(), fp32_input.grad.asnumpy(), rtol=1e-2, atol=1e-2)

for _ in range(5):
check_fp16_fp32_almost_equal(mx.random.uniform(shape=(100, 500)))

@with_seed()
def test_pick():
def test_pick_helper(index_type=np.int32):
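The test compares a full forward/backward pass in fp16 against the same computation in fp32 over rows of 500 elements; with the normalizer accumulated in fp32 the two agree within the stated tolerances. The failure mode a narrow accumulator introduces, small contributions being rounded away once the running sum is large, is easy to reproduce on the host with float and double standing in for fp16 and fp32 (a deterministic toy example, not part of the commit):

// Once the running sum reaches 2^24, adding 1.0f no longer changes a float
// accumulator; the double accumulator keeps counting exactly.
#include <cstdio>

int main() {
  float  narrow = 16777216.0f;   // 2^24
  double wide   = 16777216.0;
  for (int i = 0; i < 1000; ++i) {
    narrow += 1.0f;              // rounds back to 16777216 every time
    wide   += 1.0;
  }
  std::printf("narrow: %.1f  wide: %.1f\n", narrow, wide);
  // prints  narrow: 16777216.0  wide: 16778216.0
  return 0;
}

The same stalling happens for fp16 at a much smaller magnitude (adding 1.0 stops having an effect past 2048), which is why the accumulator, rather than the storage, is the right place to spend the extra precision.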
