Skip to content

Commit

Permalink
Optimize NMS part 2 (apache#14352)
Browse files Browse the repository at this point in the history
* Optimize NMS part 2

* Guarding ldg intrinsics
  • Loading branch information
ptrendx authored and haohuw committed Jun 23, 2019
1 parent b7fdc65 commit 183d99d
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 13 deletions.
10 changes: 10 additions & 0 deletions src/operator/contrib/bounding_box-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ struct nms_impl {
}
};

namespace mshadow_op {
struct less_than : public mxnet_op::tunable {
// a is x, b is sigma
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return static_cast<DType>(a < b);
}
}; // struct equal_to
} // namespace mshadow_op

} // namespace op
} // namespace mxnet

Expand Down
44 changes: 44 additions & 0 deletions src/operator/contrib/bounding_box-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,50 @@ void NMSApply(mshadow::Stream<gpu> *s,
}
}

__launch_bounds__(512)
__global__ void nms_calculate_batch_start_kernel(int32_t * batch_start,
int32_t * valid_batch_id,
size_t N,
int num_batch) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < N) {
#if __CUDA_ARCH__ >= 350
const int32_t previous = tid > 0 ? __ldg(valid_batch_id + tid - 1) : -1;
const int32_t my = __ldg(valid_batch_id + tid);
#else
const int32_t previous = tid > 0 ? valid_batch_id[tid - 1] : -1;
const int32_t my = valid_batch_id[tid];
#endif
if (my > previous) {
for (int32_t current = previous + 1; current <= my; ++current) {
batch_start[current] = tid;
}
}
if (tid == N - 1) {
for (int32_t current = my + 1; current <= num_batch; ++current) {
batch_start[current] = tid + 1;
}
}
}
}

inline void NMSCalculateBatchStart(mshadow::Stream<gpu> *s,
mshadow::Tensor<gpu, 1, int32_t>* batch_start,
mshadow::Tensor<gpu, 1, int32_t>* valid_batch_id,
int num_batch) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mxnet_op;
auto stream = mshadow::Stream<gpu>::GetStream(s);
constexpr int block_size = 512;
const int num_elements = valid_batch_id->size(0);
const int blocks = (num_elements + block_size - 1) / block_size;
nms_calculate_batch_start_kernel<<<blocks, block_size, 0, stream>>>(batch_start->dptr_,
valid_batch_id->dptr_,
num_elements,
num_batch);
}

} // namespace op
} // namespace mxnet

Expand Down
27 changes: 14 additions & 13 deletions src/operator/contrib/bounding_box-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,6 @@ int FilterScores(mshadow::Tensor<cpu, 1, DType> out_scores,
return j;
}

namespace mshadow_op {
struct less_than : public mxnet_op::tunable {
// a is x, b is sigma
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return static_cast<DType>(a < b);
}
}; // struct equal_to
} // namespace mshadow_op

struct corner_to_center {
template<typename DType>
Expand Down Expand Up @@ -277,6 +268,19 @@ void NMSApply(mshadow::Stream<cpu> *s,
}
}

inline void NMSCalculateBatchStart(mshadow::Stream<cpu> *s,
mshadow::Tensor<cpu, 1, int32_t>* batch_start,
mshadow::Tensor<cpu, 1, int32_t>* valid_batch_id,
int num_batch) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mxnet_op;
for (int b = 0; b < num_batch + 1; b++) {
slice<0>(*batch_start, b, b + 1) = reduce_keepdim<red::sum, false>(
F<mshadow_op::less_than>(*valid_batch_id, ScalarExp<int32_t>(b)), 0);
}
}

/*!
* \brief Assign output of nms by indexing input
*
Expand Down Expand Up @@ -435,10 +439,7 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs,

// calculate batch_start: accumulated sum to denote 1st sorted_index for a given batch_index
valid_batch_id = (valid_sorted_index / ScalarExp<int32_t>(num_elem));
for (int b = 0; b < num_batch + 1; b++) {
slice<0>(batch_start, b, b + 1) = reduce_keepdim<red::sum, false>(
F<mshadow_op::less_than>(valid_batch_id, ScalarExp<int32_t>(b)), 0);
}
mxnet::op::NMSCalculateBatchStart(s, &batch_start, &valid_batch_id, num_batch);

// pre-compute areas of candidates
areas = 0;
Expand Down

0 comments on commit 183d99d

Please sign in to comment.