This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

safe norm with DataType, AccuType and OutType
haojin2 committed Apr 9, 2019
1 parent 98f1d47 commit 15f00ef
Showing 4 changed files with 80 additions and 42 deletions.
40 changes: 23 additions & 17 deletions src/operator/tensor/broadcast_reduce-inl.cuh
@@ -72,10 +72,10 @@ void BinaryBroadcastComputeImpl(Stream<gpu> *s, const OpReqType req,
}

const int nthread_reduce = kMaxThreadsPerBlock;
template<typename Reducer, int ndim, typename AType, typename DType, typename OP, int unroll>
template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP, int unroll>
__launch_bounds__(nthread_reduce)
__global__ void reduce_kernel(const int N, const int M, const bool addto,
const DType* __restrict big, DType *small,
const DType* __restrict big, OType *small,
const Shape<ndim> big_shape0, const Shape<ndim> small_shape,
const Shape<ndim> big_shape, const Shape<ndim> big_stride,
const int Mnext, const bool do_transpose) {
@@ -139,12 +139,12 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
}
if (idx < N && tidy == 0) {
Reducer::Finalize(shTile[tidx * 2], shTile[tidx * 2 + 1]);
assign(&small[idx + m0*N], addto, DType(shTile[tidx * 2]));
assign(&small[idx + m0*N], addto, OType(shTile[tidx * 2]));
}
} else {
if (idx < N) {
Reducer::Finalize(val, residual);
assign(&small[idx + m0*N], addto, DType(val));
assign(&small[idx + m0*N], addto, OType(val));
}
}
}
@@ -261,9 +261,9 @@ __global__ void reduce_lines_kernel(const int N, const int M, const bool addto,
}
}

template<typename Reducer, int ndim, typename AType, typename DType, typename OP>
template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
__global__ void reduce_kernel_M1(const int N, const bool addto,
const DType* __restrict big, DType *small, const Shape<ndim> bshape,
const DType* __restrict big, OType *small, const Shape<ndim> bshape,
const Shape<ndim> sshape) {
for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
Shape<ndim> coord = unravel(idx, sshape);
Expand All @@ -272,7 +272,7 @@ __global__ void reduce_kernel_M1(const int N, const bool addto,
Reducer::SetInitValue(val, residual);
Reducer::Reduce(val, AType(OP::Map(big[j])), residual);
Reducer::Finalize(val, residual);
assign(&small[idx], addto, DType(val));
assign(&small[idx], addto, OType(val));
}
}

@@ -516,22 +516,23 @@ ReduceImplConfig<ndim> ConfigureReduceImpl(const mxnet::TShape& small, const mxn
{__VA_ARGS__} \
}

template<typename Reducer, int ndim, typename AType, typename DType, typename OP>
template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req,
const TBlob& big, const Tensor<gpu, 1, char>& workspace,
const ReduceImplConfig<ndim>& config) {
if (config.M == 1) {
reduce_kernel_M1<Reducer, ndim, AType, DType, OP>
std::cout << "here1" << std::endl;
reduce_kernel_M1<Reducer, ndim, AType, DType, OType, OP>
<<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
config.N, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(), big.shape_.get<ndim>(),
config.N, req == kAddTo, big.dptr<DType>(), small.dptr<OType>(), big.shape_.get<ndim>(),
small.shape_.get<ndim>());
MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1);
} else {
DType* small_dptr = small.dptr<DType>();
OType* small_dptr = small.dptr<OType>();
bool addto = (req == kAddTo);
if (config.Mnext > 1) {
// small_dptr[] is N*Mnext*sizeof(DType) bytes
small_dptr = reinterpret_cast<DType*>(workspace.dptr_);
small_dptr = reinterpret_cast<OType*>(workspace.dptr_);
addto = false;
// Check that the workspace is contiguous
CHECK_EQ(workspace.CheckContiguous(), true);
@@ -543,7 +544,7 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req,
config.kernel_1.blockDim.x : config.kernel_1.blockDim.y;
const bool do_unroll = ( config.M / (by*config.Mnext) >= config.unroll_reduce );
KERNEL_UNROLL_SWITCH(do_unroll, ReduceImplConfig<ndim>::unroll_reduce, UNROLL, {
reduce_kernel<Reducer, ndim, AType, DType, OP, UNROLL>
reduce_kernel<Reducer, ndim, AType, DType, OType, OP, UNROLL>
<<< config.kernel_1.gridDim, config.kernel_1.blockDim, config.kernel_1.shMemSize, stream>>>(
config.N, config.M, addto, big.dptr<DType>(), small_dptr, big.shape_.get<ndim>(),
small.shape_.get<ndim>(), config.rshape, config.rstride, config.Mnext,
Expand All @@ -552,9 +553,10 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req,
MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel);

if (config.Mnext > 1) {
reduce_lines_kernel<Reducer, DType>
std::cout << "here3" << std::endl;
reduce_lines_kernel<Reducer, OType>
<<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
(config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<DType>());
(config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<OType>());
MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel);
}
}
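
For the M > 1 path above, reduce_kernel first writes Mnext partial results per output element into the workspace, and reduce_lines_kernel then folds those partials into small. Below is a minimal host-side C++ sketch of that two-stage scheme under simplified assumptions (a contiguous [N][M] layout and plain summation; two_stage_sum and its arguments are illustrative, not MXNet API), keeping the commit's split between accumulation type and output type.

// Minimal two-stage reduction sketch (illustrative only, not MXNet code).
// Stage 1 writes Mnext partial sums per output element into a workspace;
// stage 2 folds the partials into the final output, mirroring
// reduce_kernel + reduce_lines_kernel above.
#include <algorithm>
#include <cstdio>
#include <vector>

template <typename AType, typename DType, typename OType>
void two_stage_sum(const std::vector<DType>& big, int N, int M, int Mnext,
                   std::vector<OType>* small) {
  std::vector<AType> workspace(static_cast<size_t>(N) * Mnext);  // N*Mnext partials
  const int chunk = (M + Mnext - 1) / Mnext;
  // Stage 1: each (idx, m0) pair reduces one chunk of the M axis.
  for (int m0 = 0; m0 < Mnext; ++m0) {
    for (int idx = 0; idx < N; ++idx) {
      AType acc = AType(0);
      for (int m = m0 * chunk; m < std::min(M, (m0 + 1) * chunk); ++m) {
        acc += AType(big[static_cast<size_t>(idx) * M + m]);
      }
      workspace[static_cast<size_t>(m0) * N + idx] = acc;
    }
  }
  // Stage 2: fold the Mnext partials for each output element.
  for (int idx = 0; idx < N; ++idx) {
    AType acc = AType(0);
    for (int m0 = 0; m0 < Mnext; ++m0) {
      acc += workspace[static_cast<size_t>(m0) * N + idx];
    }
    (*small)[idx] = OType(acc);  // cast to the output type only at the end
  }
}

int main() {
  const int N = 2, M = 6, Mnext = 3;
  std::vector<float> big(N * M, 0.5f);
  std::vector<double> small(N);
  two_stage_sum<double, float, double>(big, N, M, Mnext, &small);
  std::printf("%f %f\n", small[0], small[1]);  // prints 3.000000 3.000000
  return 0;
}
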
@@ -619,11 +621,15 @@ void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req,
if (safe_acc) {
MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
ReduceImpl<Reducer, ndim, AccType, DataType, OP>(
config = ConfigureReduceImpl<ndim, AccType>(small.shape_, big.shape_, NULL, NULL);
MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
ReduceImpl<Reducer, ndim, AccType, DataType, OutType, OP>(
stream, small, req, big, workspace, config);
});
});
} else {
ReduceImpl<Reducer, ndim, DType, DType, OP>(stream, small, req, big, workspace, config);
ReduceImpl<Reducer, ndim, DType, DType, DType, OP>(stream, small, req, big, workspace, config);
}
}

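The common thread in this file is the split between the input storage type (DType), the accumulation type (AType/AccType) and the output type (OType/OutType): values are read as DType, reduced in AType, and only the finalized result is cast to OType. A self-contained C++ sketch of that idea for an L2 norm follows; safe_l2_norm and its Kahan-style residual are illustrative stand-ins for mshadow's Reducer, not the real implementation.

// Sketch of the DType / AType / OType split used by the commit (illustrative
// only): data is read as DType, accumulated in AType, and the finished value
// is written out as OType.
#include <cmath>
#include <cstdio>
#include <vector>

template <typename AType, typename DType, typename OType>
OType safe_l2_norm(const std::vector<DType>& data) {
  AType acc = AType(0);
  AType residual = AType(0);          // Kahan-style residual, as a Reducer might keep
  for (const DType& x : data) {
    AType y = AType(x) * AType(x) - residual;
    AType t = acc + y;
    residual = (t - acc) - y;
    acc = t;
  }
  return OType(std::sqrt(acc));       // Finalize in AType, then cast to the output type
}

int main() {
  // Many small float values: accumulating in double avoids the precision loss
  // a pure-float running sum would suffer.
  std::vector<float> data(1 << 20, 1e-4f);
  float out = safe_l2_norm<double, float, float>(data);
  std::printf("%.8f\n", out);
  return 0;
}
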
23 changes: 13 additions & 10 deletions src/operator/tensor/broadcast_reduce-inl.h
@@ -153,9 +153,9 @@ MSHADOW_XINLINE void binary_broadcast_assign(const index_t idx, const bool addto
assign(&out[idx], addto, OP::Map(lhs[j], rhs[k]));
}

template<typename Reducer, int ndim, typename AType, typename DType, typename OP>
template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const bool addto,
const DType* __restrict big, DType *small,
const DType* __restrict big, OType *small,
const Shape<ndim>& bshape, const Shape<ndim>& sshape,
const Shape<ndim>& rshape, const Shape<ndim>& rstride) {
Shape<ndim> coord = unravel(idx, sshape);
@@ -167,7 +167,7 @@ MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const
Reducer::Reduce(val, AType(OP::Map(big[j + dot(coord, rstride)])), residual);
}
Reducer::Finalize(val, residual);
assign(&small[idx], addto, DType(val));
assign(&small[idx], addto, OType(val));
}

#ifdef __CUDACC__
@@ -194,14 +194,14 @@ void BinaryBroadcastComputeImpl(Stream<cpu> *s, const OpReqType req,
out.shape_.get<ndim>());
}

template<typename Reducer, int ndim, typename AType, typename DType, typename OP>
template<typename Reducer, int ndim, typename AType, typename DType, typename OType, typename OP>
void seq_reduce_compute(const size_t N, const size_t M, const bool addto,
const DType *big, DType *small, const Shape<ndim> bshape,
const DType *big, OType *small, const Shape<ndim> bshape,
const Shape<ndim> sshape, const Shape<ndim> rshape,
const Shape<ndim> rstride) {
#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
for (index_t idx = 0; idx < static_cast<index_t>(N); ++idx) {
seq_reduce_assign<Reducer, ndim, AType, DType, OP>(idx, M, addto, big, small,
seq_reduce_assign<Reducer, ndim, AType, DType, OType, OP>(idx, M, addto, big, small,
bshape, sshape, rshape, rstride);
}
}
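
Each OpenMP iteration above handles one output element: the output index is unraveled into a coordinate, and the reduced extent M is walked via the rshape/rstride pair to locate the matching input elements. A fixed 2-D C++ sketch of that coordinate arithmetic follows; unravel2 and dot2 are invented stand-ins for MXNet's ndim-templated unravel/dot helpers, and the shapes and strides are hard-coded for a 3x4 row-sum.

// Sketch of the seq_reduce pattern: for each output index, unravel it into a
// coordinate, then iterate the reduced extent via shape/stride arithmetic.
#include <cstdio>
#include <vector>

struct Shape2 { int s0, s1; };

static inline void unravel2(int idx, const Shape2& shape, int coord[2]) {
  coord[1] = idx % shape.s1;
  coord[0] = idx / shape.s1;
}

static inline int dot2(const int coord[2], const Shape2& stride) {
  return coord[0] * stride.s0 + coord[1] * stride.s1;
}

int main() {
  // big is 3x4, reduced over axis 1 -> small is 3x1.
  const Shape2 sshape{3, 1};
  const Shape2 bstride{4, 1};          // strides of big
  const Shape2 rshape{1, 4};           // extent of the reduced axes
  const Shape2 rstride{4, 1};          // strides used while walking the reduced axes
  std::vector<float> big(12);
  for (int i = 0; i < 12; ++i) big[i] = static_cast<float>(i);
  std::vector<float> small(3, 0.0f);

  const int N = 3, M = rshape.s0 * rshape.s1;   // 3 outputs, 4 reduced elements each
  #pragma omp parallel for
  for (int idx = 0; idx < N; ++idx) {
    int coord[2];
    unravel2(idx, sshape, coord);
    const int base = dot2(coord, bstride);      // start of this output's slice in big
    double acc = 0.0;                           // accumulate in a wider type
    for (int k = 0; k < M; ++k) {
      int rcoord[2];
      unravel2(k, rshape, rcoord);
      acc += big[base + dot2(rcoord, rstride)];
    }
    small[idx] = static_cast<float>(acc);
  }
  for (int idx = 0; idx < N; ++idx) std::printf("%g ", small[idx]);  // 6 22 38
  std::printf("\n");
  return 0;
}
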
@@ -235,15 +235,18 @@ void Reduce(Stream<cpu>* s, const TBlob& small, const OpReqType req,
diff(small.shape_.get<ndim>(), big.shape_.get<ndim>(), &rshape, &rstride);
size_t N = small.shape_.Size(), M = rshape.Size();
if (!safe_acc) {
seq_reduce_compute<Reducer, ndim, DType, DType, OP>(
seq_reduce_compute<Reducer, ndim, DType, DType, DType, OP>(
N, M, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(),
big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
} else {
MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
typedef typename std::conditional<safe_acc, AType, DataType>::type AccType;
seq_reduce_compute<Reducer, ndim, AccType, DataType, OP>(
N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<DataType>(),
big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
typedef typename std::conditional<safe_acc, OType, DataType>::type OutType;
seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>(
N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(),
big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
});
});
}
}
57 changes: 43 additions & 14 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -67,6 +67,7 @@ struct ReduceAxesParam : public dmlc::Parameter<ReduceAxesParam> {
struct NormParam : public dmlc::Parameter<NormParam> {
int ord;
dmlc::optional<mxnet::TShape> axis;
dmlc::optional<int> out_dtype;
bool keepdims;
DMLC_DECLARE_PARAMETER(NormParam) {
DMLC_DECLARE_FIELD(ord).set_default(2)
@@ -78,6 +79,15 @@ struct NormParam : public dmlc::Parameter<NormParam> {
If `axis` is int, a reduction is performed on a particular axis.
If `axis` is a 2-tuple, it specifies the axes that hold 2-D matrices,
and the matrix norms of these matrices are computed.)code");
DMLC_DECLARE_FIELD(out_dtype)
.add_enum("float16", mshadow::kFloat16)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("int64", mshadow::kInt64)
.add_enum("int32", mshadow::kInt32)
.add_enum("int8", mshadow::kInt8)
.set_default(dmlc::optional<int>())
.describe(R"code(The data type of the output.)code");
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axis is left "
"in the result as dimension with size one.");
@@ -302,6 +312,23 @@ inline bool ReduceAxesShape(const nnvm::NodeAttrs& attrs,
return true;
}

inline bool NormType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
if (param.out_dtype.has_value()) {
CHECK_NE(in_attrs->at(0), -1)
<< "input data type should be specified when out_dtype is not null";
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.out_dtype.value());
} else {
TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]);
}
return (*out_attrs)[0] != -1;
}
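
The rule NormType encodes: if out_dtype is set, the output takes that type and the input type must already be known; otherwise input and output must share a dtype (inference can flow in either direction). A small illustrative C++ sketch of the forward direction of that rule, outside of nnvm:

// Illustrative sketch of the resolution rule in NormType above (not the real
// nnvm code): when out_dtype is given the output takes it, otherwise the
// output type follows the input type.
#include <cassert>
#include <optional>

int resolve_norm_out_dtype(int in_dtype, std::optional<int> out_dtype) {
  if (out_dtype.has_value()) {
    assert(in_dtype != -1 && "input dtype must be known when out_dtype is given");
    return *out_dtype;
  }
  return in_dtype;  // may still be -1, meaning "not inferable yet"
}

int main() {
  // Flag values used purely as examples: 2 = float16, 0 = float32.
  assert(resolve_norm_out_dtype(/*float16*/ 2, /*out_dtype=float32*/ 0) == 0);
  assert(resolve_norm_out_dtype(/*float16*/ 2, std::nullopt) == 2);
  return 0;
}
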

inline bool NormShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
@@ -538,20 +565,22 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
mxnet::TShape src_shape, dst_shape;
BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
const TBlob in_data = inputs[0].reshape(src_shape);
const TBlob out_data = outputs[0].reshape(dst_shape);
BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
size_t workspace_size = broadcast::ReduceWorkspaceSize<NDim, DType>(
s, out_data.shape_, req[0], in_data.shape_);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
broadcast::Reduce<reducer, NDim, DType, OP, safe_acc>(
s, out_data, req[0], workspace, in_data);
if (normalize) {
auto out = out_data.FlatTo2D<xpu, DType>(s);
out /= scalar<DType>(src_shape.Size()/dst_shape.Size());
}
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
const TBlob in_data = inputs[0].reshape(src_shape);
const TBlob out_data = outputs[0].reshape(dst_shape);
BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
size_t workspace_size = broadcast::ReduceWorkspaceSize<NDim, DType>(
s, out_data.shape_, req[0], in_data.shape_);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
broadcast::Reduce<reducer, NDim, DType, OP, safe_acc>(
s, out_data, req[0], workspace, in_data);
if (normalize) {
auto out = out_data.FlatTo2D<xpu, OType>(s);
out /= scalar<OType>(src_shape.Size()/dst_shape.Size());
}
});
});
});
}
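
ReduceAxesComputeImpl now dispatches on the input and output dtypes independently: the outer MSHADOW_TYPE_SWITCH fixes DType from inputs[0], the inner one fixes OType from outputs[0], and the reduction is instantiated for that pair. Below is a simplified C++ sketch of such nested runtime-to-compile-time dispatch, using a templated lambda instead of mshadow's macros; the flag values and names are illustrative only.

// Simplified stand-in for a nested MSHADOW_TYPE_SWITCH: map two runtime dtype
// flags to compile-time types and instantiate the kernel for that pair.
#include <cstdio>

enum : int { kF32 = 0, kF64 = 1 };  // example flags only

template <typename Func>
void type_switch(int flag, Func&& f) {
  switch (flag) {
    case kF32: f(float{}); break;
    case kF64: f(double{}); break;
    default: std::printf("unsupported dtype flag %d\n", flag);
  }
}

template <typename DType, typename OType>
void reduce_for(const char* tag) {
  std::printf("%s: sizeof(DType)=%zu sizeof(OType)=%zu\n",
              tag, sizeof(DType), sizeof(OType));
}

int main() {
  int in_flag = kF32, out_flag = kF64;     // e.g. float32 data, float64 output
  type_switch(in_flag, [&](auto d) {       // outer switch: input dtype
    using DType = decltype(d);
    type_switch(out_flag, [&](auto o) {    // inner switch: output dtype
      using OType = decltype(o);
      reduce_for<DType, OType>("norm");    // instantiated per (DType, OType) pair
    });
  });
  return 0;
}
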
2 changes: 1 addition & 1 deletion src/operator/tensor/broadcast_reduce_op_value.cc
@@ -352,7 +352,7 @@ Examples::
.set_num_outputs(1)
.set_attr_parser(ParamParser<NormParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NormShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", NormType)
.set_attr<FInferStorageType>("FInferStorageType", LpNormStorageType)
.set_attr<nnvm::FGradient>("FGradient", ReduceGrad{ "_backward_norm" })
.set_attr<FResourceRequest>("FResourceRequest",
