This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

use safe accumulation for norm #14240

Closed
wants to merge 1 commit
68 changes: 64 additions & 4 deletions src/operator/mshadow_op.h
@@ -898,13 +898,13 @@ struct nanprod {
/*! \brief compute l2 norm */
struct nrm2 {
/*! \brief do reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src) { // NOLINT(*)
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, volatile DType src) { // NOLINT(*)
sum_of_squares += src * src;
}
/*! \brief do stable reduction into dst */
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*)
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& sum_of_squares, volatile DType src, volatile DType& scale) { // NOLINT(*)
if (src != 0) {
DType abs = mshadow_op::abs::Map(src);
if (scale < abs) {
@@ -965,6 +965,66 @@ struct nrm2 {
}
};

/*! \brief sum reducer */
struct sum {
/*! \brief do reduction into dst */
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src) { // NOLINT(*)
dst += src;
}
/*! \brief do stable reduction into dst */
template<typename AType, typename DType>
MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
DType y = src - residual;
DType t = dst + y;
residual = (t - dst) - y;
dst = t;
}
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
Reduce(dst_val, src_val);
}
/*! \brief combine the results of two reducers */
template<typename DType>
MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
DType t1 = dst_val + src_val;
DType e = t1 - dst_val;
DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
dst_val = t1 + t2;
dst_residual = t2 - (dst_val - t1);
}
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
/*! \brief finalize reduction */
template<typename DType>
MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
/*!
*\brief calculate gradient of redres with respect to redsrc,
* redres: reduced result, redsrc: one of reduction element
*/
template<typename DType>
MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
return 1;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
initv = 0;
}
/*!
*\brief set the initial value during reduction
*/
template<typename DType>
MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*)
SetInitValue(initv);
residual = 0;
}
};
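The stable Reduce/Merge overloads above implement compensated (Kahan) summation: residual carries the low-order bits that are lost when a small src is added to a large running sum, and Merge combines two partial sums together with their residuals so the compensation survives a tree reduction. The two-argument Reduce is the plain, uncompensated fallback. A minimal standalone sketch of the same update rule in plain C++ (illustrative only, not the mshadow types):

#include <cstdio>
#include <vector>

// Compensated (Kahan) summation: acc is the running sum, residual holds the
// rounding error left over from the previous addition.
float kahan_sum(const std::vector<float>& xs) {
  float acc = 0.0f, residual = 0.0f;
  for (float x : xs) {
    float y = x - residual;    // re-inject the bits lost last time
    float t = acc + y;         // low-order bits of y may be lost here
    residual = (t - acc) - y;  // recover exactly what was lost
    acc = t;
  }
  return acc;
}

int main() {
  std::vector<float> xs(10000000, 0.1f);
  float naive = 0.0f;
  for (float x : xs) naive += x;
  // The naive float sum drifts far from 1e6; the compensated sum stays close to it.
  std::printf("naive: %f  kahan: %f\n", naive, kahan_sum(xs));
  return 0;
}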

struct nanprod_grad : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
33 changes: 25 additions & 8 deletions src/operator/mxnet_op.h
@@ -273,25 +273,42 @@ inline int get_num_threads<cpu>(const int N) {
} \
break; \
case mshadow::kUint8: \
LOG(FATAL) << "This operation only support " \
"floating point types not uint8"; \
{ \
typedef uint8_t DType; \
typedef uint8_t AType; \
LOG(FATAL) << "This operation only support " \
"floating point types not uint8"; \
} \
break; \
case mshadow::kInt8: \
LOG(FATAL) << "This operation only support " \
"floating point types not int8"; \
{ \
typedef int8_t DType; \
typedef int8_t AType; \
Review comment (Member): can we support acc in int types, too?

LOG(FATAL) << "This operation only support " \
"floating point types not int8"; \
} \
break; \
case mshadow::kInt32: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int32"; \
{ \
typedef int32_t DType; \
typedef int32_t AType; \
LOG(FATAL) << "This operation only support " \
"floating point types, not int32"; \
} \
break; \
case mshadow::kInt64: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int64"; \
{ \
typedef int64_t DType; \
typedef int64_t AType; \
LOG(FATAL) << "This operation only support " \
"floating point types, not int64"; \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}
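Each case of the switch now binds two typedefs: DType, the storage type of the tensor, and AType, the type the reducer accumulates in. The integer cases still reject the operation; the floating-point cases sit outside this hunk, where AType is expected to be a wider type than DType. A minimal sketch of such a storage-to-accumulation mapping, where the float-to-double widening is an assumption of the sketch rather than the macro's actual table:

#include <cstdint>
#include <type_traits>

// Illustrative mapping from storage type to accumulation type. The real table is
// supplied by the type-switch macros in this header (e.g. MXNET_REAL_ACC_TYPE_SWITCH,
// used further down in this PR); float -> double is an assumed widening here.
template <typename DType> struct AccType { using type = DType; };  // default: no widening
template <> struct AccType<float> { using type = double; };        // assumed widening

int main() {
  static_assert(std::is_same<AccType<float>::type, double>::value,
                "float data accumulates in double");
  static_assert(std::is_same<AccType<int32_t>::type, int32_t>::value,
                "integer data keeps its own accumulation type");
  return 0;
}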


/*!
* \brief assign the val to out according
* to request in Kernel::Launch
41 changes: 24 additions & 17 deletions src/operator/tensor/broadcast_reduce-inl.cuh
@@ -72,15 +72,15 @@ void BinaryBroadcastComputeImpl(Stream<gpu> *s, const OpReqType req,
}

const int nthread_reduce = kMaxThreadsPerBlock;
template<typename Reducer, int ndim, typename DType, typename OP, int unroll>
template<typename Reducer, int ndim, typename AType, typename DType, typename OP, int unroll>
__launch_bounds__(nthread_reduce)
__global__ void reduce_kernel(const int N, const int M, const bool addto,
const DType* __restrict big, DType *small,
const Shape<ndim> big_shape0, const Shape<ndim> small_shape,
const Shape<ndim> big_shape, const Shape<ndim> big_stride,
const int Mnext, const bool do_transpose) {
extern __shared__ char shTileChar[];
DType* shTile = (DType*)(shTileChar);
AType* shTile = (AType*)(shTileChar);
const int tid = threadIdx.x + threadIdx.y*blockDim.x;
const int bx = (do_transpose) ? blockDim.y : blockDim.x;
const int by = (do_transpose) ? blockDim.x : blockDim.y;
@@ -95,7 +95,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
Shape<ndim> coord = unravel(idx, small_shape);
int idx_big0 = ravel(coord, big_shape0);

DType val, residual;
AType val, residual;
Reducer::SetInitValue(val, residual);
if (idx < N) {
for (int k = tidy + Mstart; k < Mend; k += by*unroll) {
@@ -113,7 +113,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
}
#pragma unroll
for (int u=0;u < unroll;u++) {
if (k + u*by < Mend) Reducer::Reduce(val, tmp[u], residual);
if (k + u*by < Mend) Reducer::Reduce(val, AType(tmp[u]), residual);
}
}
}
@@ -127,7 +127,7 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
shTile[it0 * 2 + 1] = residual;
__syncthreads();
for (int t=1;t < by;t <<= 1) {
DType tmp, tmp_residual;
AType tmp, tmp_residual;
Reducer::SetInitValue(tmp, tmp_residual);
if (tidy + t < by) {
tmp = shTile[(it0 + t*fbx) * 2];
@@ -139,12 +139,12 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
}
if (idx < N && tidy == 0) {
Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
assign(&small[idx + m0*N], addto, shTile[tidx * 2]);
assign(&small[idx + m0*N], addto, DType(shTile[tidx * 2]));
}
} else {
if (idx < N) {
Reducer::Finalize(val, residual);
assign(&small[idx + m0*N], addto, val);
assign(&small[idx + m0*N], addto, DType(val));
}
}
}
@@ -261,18 +261,18 @@ __global__ void reduce_lines_kernel(const int N, const int M, const bool addto,
}
}

template<typename Reducer, int ndim, typename DType, typename OP>
template<typename Reducer, int ndim, typename AType, typename DType, typename OP>
__global__ void reduce_kernel_M1(const int N, const bool addto,
const DType* __restrict big, DType *small, const Shape<ndim> bshape,
const Shape<ndim> sshape) {
for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
Shape<ndim> coord = unravel(idx, sshape);
int j = ravel(coord, bshape);
DType val, residual;
AType val, residual;
Reducer::SetInitValue(val, residual);
Reducer::Reduce(val, OP::Map(big[j]), residual);
Reducer::Reduce(val, AType(OP::Map(big[j])), residual);
Reducer::Finalize(val, residual);
assign(&small[idx], addto, val);
assign(&small[idx], addto, DType(val));
}
}
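Both kernels follow the same widen-accumulate-narrow pattern: each element is loaded as DType, cast to AType before it reaches Reducer::Reduce, and the finished value is cast back to DType only when assign writes it into small. A simplified host-side sketch of that pattern (plain C++, not the CUDA kernel, with the accumulation type chosen by the caller):

#include <cstddef>

// Widen-accumulate-narrow: DType is the storage type, AType the (wider)
// accumulation type.
template <typename AType, typename DType>
DType reduce_sum(const DType* big, size_t M) {
  AType val = AType(0);
  for (size_t k = 0; k < M; ++k) {
    val += AType(big[k]);  // widen each element before accumulating
  }
  return DType(val);       // narrow once, at the very end
}

int main() {
  float data[4] = {1.f, 2.f, 3.f, 4.f};
  return reduce_sum<double, float>(data, 4) == 10.f ? 0 : 1;
}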

@@ -516,18 +516,17 @@ ReduceImplConfig<ndim> ConfigureReduceImpl(const TShape& small, const TShape& bi
{__VA_ARGS__} \
}

template<typename Reducer, int ndim, typename DType, typename OP>
template<typename Reducer, int ndim, typename AType, typename DType, typename OP>
void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req,
const TBlob& big, const Tensor<gpu, 1, char>& workspace,
const ReduceImplConfig<ndim>& config) {
if (config.M == 1) {
reduce_kernel_M1<Reducer, ndim, DType, OP>
reduce_kernel_M1<Reducer, ndim, AType, DType, OP>
<<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
config.N, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(), big.shape_.get<ndim>(),
small.shape_.get<ndim>());
MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1);
} else {

DType* small_dptr = small.dptr<DType>();
bool addto = (req == kAddTo);
if (config.Mnext > 1) {
@@ -544,7 +543,7 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req,
config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
const bool do_unroll = ( config.M / (by*config.Mnext) >= config.unroll_reduce );
KERNEL_UNROLL_SWITCH(do_unroll, ReduceImplConfig<ndim>::unroll_reduce, UNROLL, {
reduce_kernel<Reducer, ndim, DType, OP, UNROLL>
reduce_kernel<Reducer, ndim, AType, DType, OP, UNROLL>
<<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>(
config.N, config.M, addto, big.dptr<DType>(), small_dptr, big.shape_.get<ndim>(),
small.shape_.get<ndim>(), config.rshape, config.rstride, config.Mnext,
@@ -610,14 +609,22 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const

#undef KERNEL_UNROLL_SWITCH

template<typename Reducer, int ndim, typename DType, typename OP>
template<typename Reducer, int ndim, typename DType, typename OP, bool safe_acc = false>
void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
const Tensor<gpu, 1, char>& workspace, const TBlob& big) {
if (req == kNullOp) return;
cudaStream_t stream = Stream<gpu>::GetStream(s);
ReduceImplConfig<ndim> config =
ConfigureReduceImpl<ndim, DType>(small.shape_, big.shape_, NULL, NULL);
ReduceImpl<Reducer, ndim, DType, OP>(stream, small, req, big, workspace, config);
if (safe_acc) {
MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
ReduceImpl<Reducer, ndim, AccType, DataType, OP>(
stream, small, req, big, workspace, config);
});
} else {
ReduceImpl<Reducer, ndim, DType, DType, OP>(stream, small, req, big, workspace, config);
}
}
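Because safe_acc is a compile-time template parameter, the choice between the widened and the in-place accumulation type costs nothing at runtime: std::conditional resolves to a single AccType per instantiation, and only the safe path goes through the accumulation-type switch. A small sketch of the same dispatch shape, where the float-to-double widening is an assumption of the sketch rather than the macro's mapping:

#include <type_traits>
#include <vector>

// Compile-time selection of the accumulation type via a bool template parameter,
// mirroring the safe_acc dispatch above (widening rule assumed for the sketch).
template <typename DType, bool safe_acc>
DType sum(const std::vector<DType>& xs) {
  using AccType = typename std::conditional<safe_acc, double, DType>::type;
  AccType acc = AccType(0);
  for (DType x : xs) acc += AccType(x);
  return DType(acc);
}

int main() {
  std::vector<float> xs(1000, 0.5f);  // exactly representable, so both paths agree
  return (sum<float, true>(xs) == 500.0f && sum<float, false>(xs) == 500.0f) ? 0 : 1;
}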

template <typename Reducer, int ndim, typename DType, typename OP>
31 changes: 20 additions & 11 deletions src/operator/tensor/broadcast_reduce-inl.h
@@ -153,21 +153,21 @@ MSHADOW_XINLINE void binary_broadcast_assign(const index_t idx, const bool addto
assign(&out[idx], addto, OP::Map(lhs[j], rhs[k]));
}

template<typename Reducer, int ndim, typename DType, typename OP>
template<typename Reducer, int ndim, typename AType, typename DType, typename OP>
MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const bool addto,
const DType* __restrict big, DType *small,
const Shape<ndim>& bshape, const Shape<ndim>& sshape,
const Shape<ndim>& rshape, const Shape<ndim>& rstride) {
Shape<ndim> coord = unravel(idx, sshape);
index_t j = ravel(coord, bshape);
DType val, residual;
AType val, residual;
Reducer::SetInitValue(val, residual);
for (size_t k = 0; k < M; ++k) {
coord = unravel(k, rshape);
Reducer::Reduce(val, OP::Map(big[j + dot(coord, rstride)]), residual);
Reducer::Reduce(val, AType(OP::Map(big[j + dot(coord, rstride)])), residual);
}
Reducer::Finalize(val, residual);
assign(&small[idx], addto, val);
assign(&small[idx], addto, DType(val));
}

#ifdef __CUDACC__
@@ -194,15 +194,15 @@ void BinaryBroadcastComputeImpl(Stream<cpu> *s, const OpReqType req,
out.shape_.get<ndim>());
}

template<typename Reducer, int ndim, typename DType, typename OP>
template<typename Reducer, int ndim, typename AType, typename DType, typename OP>
void seq_reduce_compute(const size_t N, const size_t M, const bool addto,
const DType *big, DType *small, const Shape<ndim> bshape,
const Shape<ndim> sshape, const Shape<ndim> rshape,
const Shape<ndim> rstride) {
#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
for (index_t idx = 0; idx < static_cast<index_t>(N); ++idx) {
seq_reduce_assign<Reducer, ndim, DType, OP>(idx, M, addto, big, small, bshape, sshape, rshape,
rstride);
seq_reduce_assign<Reducer, ndim, AType, DType, OP>(idx, M, addto, big, small,
bshape, sshape, rshape, rstride);
}
}
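On the CPU path the OpenMP loop parallelizes over output elements, so each thread owns an independent accumulator for its own idx: the widened (and, with a stable reducer, compensated) accumulation stays thread-local and no atomics or shared accumulators are needed. A simplified analogue with flattened, row-major indexing assumed in place of the real rshape/rstride bookkeeping:

#include <cstddef>

// N independent outputs, each the reduction of M inputs, accumulated in double
// (assumed widening) and narrowed to float on the single store per output.
void reduce_rows(const float* big, float* small, size_t N, size_t M) {
  #pragma omp parallel for
  for (long idx = 0; idx < static_cast<long>(N); ++idx) {
    double val = 0.0;                      // thread-local accumulator
    for (size_t k = 0; k < M; ++k) {
      val += static_cast<double>(big[idx * M + k]);
    }
    small[idx] = static_cast<float>(val);  // narrow once per output
  }
}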

@@ -227,16 +227,25 @@ void seq_reduce_compute_extra_mem(const size_t N, const size_t M, const bool add
}
}

template <typename Reducer, int ndim, typename DType, typename OP>
template <typename Reducer, int ndim, typename DType, typename OP, bool safe_acc = false>
void Reduce(Stream<cpu>* s, const TBlob& small, const OpReqType req,
const Tensor<cpu, 1, char>& workspace, const TBlob& big) {
if (req == kNullOp) return;
Shape<ndim> rshape, rstride;
diff(small.shape_.get<ndim>(), big.shape_.get<ndim>(), &rshape, &rstride);
size_t N = small.shape_.Size(), M = rshape.Size();
seq_reduce_compute<Reducer, ndim, DType, OP>(
N, M, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(),
big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
if (!safe_acc) {
seq_reduce_compute<Reducer, ndim, DType, DType, OP>(
N, M, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(),
big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
} else {
MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
seq_reduce_compute<Reducer, ndim, AccType, DataType, OP>(
N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<DataType>(),
big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
});
}
}

template <typename Reducer, int ndim, typename DType, typename OP>
12 changes: 6 additions & 6 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -523,7 +523,7 @@ void SearchAxisCompute(const nnvm::NodeAttrs& attrs,
});
}

template<typename xpu, typename reducer, bool normalize = false,
template<typename xpu, typename reducer, bool safe_acc, bool normalize = false,
typename OP = op::mshadow_op::identity>
void ReduceAxesComputeImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
@@ -544,7 +544,7 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
s, out_data.shape_, req[0], in_data.shape_);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
broadcast::Reduce<reducer, NDim, DType, OP>(
broadcast::Reduce<reducer, NDim, DType, OP, safe_acc>(
s, out_data, req[0], workspace, in_data);
if (normalize) {
auto out = out_data.FlatTo2D<xpu, DType>(s);
@@ -569,7 +569,7 @@ void ReduceAxesCompute(const nnvm::NodeAttrs& attrs,
small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, param.exclude);
}

ReduceAxesComputeImpl<xpu, reducer, normalize, OP>(ctx, inputs, req, outputs, small);
ReduceAxesComputeImpl<xpu, reducer, false, normalize, OP>(ctx, inputs, req, outputs, small);
}

template <typename red_op, int req, int axis>
@@ -1088,10 +1088,10 @@ void LpNormCompute(const nnvm::NodeAttrs& attrs,
small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false);
}
if (param.ord == 1) {
ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, mshadow_op::abs>(
ctx, inputs, req, outputs, small);
ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, false, mshadow_op::abs>(
ctx, inputs, req, outputs, small);
} else if (param.ord == 2) {
ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, false, mshadow_op::identity>(
ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, true, false, mshadow_op::identity>(
ctx, inputs, req, outputs, small);
}
}
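For the norm itself, the change means the sum of absolute values (ord=1, via the new mshadow_op::sum reducer) or the sum of squares (ord=2, via mshadow_op::nrm2) is accumulated in the wider AType, and only the final result is narrowed back to the output dtype, which helps keep the result accurate when the data is stored in low precision. A scalar sketch of the ord=2 case, with the float-to-double widening assumed for illustration:

#include <cmath>
#include <vector>

// L2 norm with widened accumulation: squares of float inputs are summed in
// double (assumed widening), then the root is narrowed back to float.
float l2_norm_safe(const std::vector<float>& x) {
  double sum_of_squares = 0.0;
  for (float v : x) {
    sum_of_squares += static_cast<double>(v) * static_cast<double>(v);
  }
  return static_cast<float>(std::sqrt(sum_of_squares));
}

int main() {
  std::vector<float> x(3, 2.0f);           // norm of (2, 2, 2) is sqrt(12) ~ 3.4641
  return l2_norm_safe(x) > 3.46f ? 0 : 1;
}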