From fc1b5f55dfc6eccdeed1a078736de266167bd433 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 4 Mar 2019 16:22:08 -0800 Subject: [PATCH] Optimize NMS (#14290) * Optimize NMS * Fix lint --- src/operator/contrib/bounding_box-common.h | 118 +++++++++++ src/operator/contrib/bounding_box-inl.cuh | 223 +++++++++++++++++++++ src/operator/contrib/bounding_box-inl.h | 115 +++-------- 3 files changed, 368 insertions(+), 88 deletions(-) create mode 100644 src/operator/contrib/bounding_box-common.h diff --git a/src/operator/contrib/bounding_box-common.h b/src/operator/contrib/bounding_box-common.h new file mode 100644 index 000000000000..70215ab25d64 --- /dev/null +++ b/src/operator/contrib/bounding_box-common.h @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file bounding_box-common.h + * \brief bounding box util functions and operators commonly used by CPU and GPU implementations + * \author Joshua Zhang +*/ +#ifndef MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_COMMON_H_ +#define MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_COMMON_H_ +#include "../mshadow_op.h" +#include "../mxnet_op.h" +#include "../operator_common.h" + +namespace mxnet { +namespace op { +namespace box_common_enum { +enum BoxType {kCorner, kCenter}; +} + +// compute line intersect along either height or width +template +MSHADOW_XINLINE DType Intersect(const DType *a, const DType *b, int encode) { + DType a1 = a[0]; + DType a2 = a[2]; + DType b1 = b[0]; + DType b2 = b[2]; + DType w; + if (box_common_enum::kCorner == encode) { + DType left = a1 > b1 ? a1 : b1; + DType right = a2 < b2 ? a2 : b2; + w = right - left; + } else { + DType aw = a2 / 2; + DType bw = b2 / 2; + DType al = a1 - aw; + DType ar = a1 + aw; + DType bl = b1 - bw; + DType br = b1 + bw; + DType left = bl > al ? bl : al; + DType right = br < ar ? br : ar; + w = right - left; + } + return w > 0 ? w : DType(0); +} + +/*! + * \brief Implementation of the non-maximum suppression operation + * + * \param i the launched thread index + * \param index sorted index in descending order + * \param batch_start map (b, k) to compact index by indices[batch_start[b] + k] + * \param input the input of nms op + * \param areas pre-computed box areas + * \param k nms topk number + * \param ref compare reference position + * \param num number of input boxes in each batch + * \param stride input stride, usually 6 (id-score-x1-y1-x2-y2) + * \param offset_box box offset, usually 2 + * \param thresh nms threshold + * \param force force suppress regardless of class id + * \param offset_id class id offset, used when force == false, usually 0 + * \param encode box encoding type, corner(0) or center(1) + * \param DType the data type + */ +struct nms_impl { + template + MSHADOW_XINLINE static void Map(int i, int32_t *index, const int32_t *batch_start, + const DType *input, const DType *areas, + int k, int ref, int num, + int stride, int offset_box, int offset_id, + float thresh, bool force, int encode) { + int b = i / k; // batch + int pos = i % k + ref + 1; // position + ref = static_cast(batch_start[b]) + ref; + pos = static_cast(batch_start[b]) + pos; + if (ref >= static_cast(batch_start[b + 1])) return; + if (pos >= static_cast(batch_start[b + 1])) return; + if (index[ref] < 0) return; // reference has been suppressed + if (index[pos] < 0) return; // self been suppressed + int ref_offset = static_cast(index[ref]) * stride + offset_box; + int pos_offset = static_cast(index[pos]) * stride + offset_box; + if (!force && offset_id >=0) { + int ref_id = static_cast(input[ref_offset - offset_box + offset_id]); + int pos_id = static_cast(input[pos_offset - offset_box + offset_id]); + if (ref_id != pos_id) return; // different class + } + DType intersect = Intersect(input + ref_offset, input + pos_offset, encode); + intersect *= Intersect(input + ref_offset + 1, input + pos_offset + 1, encode); + int ref_area_offset = static_cast(index[ref]); + int pos_area_offset = static_cast(index[pos]); + DType iou = intersect / (areas[ref_area_offset] + areas[pos_area_offset] - intersect); + if (iou > thresh) { + index[pos] = -1; + } + } +}; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_COMMON_H_ diff --git a/src/operator/contrib/bounding_box-inl.cuh b/src/operator/contrib/bounding_box-inl.cuh index fd5e30b25b2d..4b7cf3476448 100644 --- a/src/operator/contrib/bounding_box-inl.cuh +++ b/src/operator/contrib/bounding_box-inl.cuh @@ -24,12 +24,15 @@ */ #ifndef MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_INL_CUH_ #define MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_INL_CUH_ +#include +#include #include #include #include #include "../mshadow_op.h" #include "../mxnet_op.h" #include "../operator_common.h" +#include "./bounding_box-common.h" namespace mxnet { namespace op { @@ -57,6 +60,226 @@ int FilterScores(mshadow::Tensor out_scores, return end_scores - out_scores.dptr_; } +// compute line intersect along either height or width +template +MSHADOW_XINLINE DType Intersect2(const DType *a, const DType b1, const DType b2, int encode) { + const DType a1 = a[0]; + const DType a2 = a[2]; + DType left, right; + if (box_common_enum::kCorner == encode) { + left = a1 > b1 ? a1 : b1; + right = a2 < b2 ? a2 : b2; + } else { + const DType aw = a2 / 2; + const DType bw = b2 / 2; + const DType al = a1 - aw; + const DType ar = a1 + aw; + const DType bl = b1 - bw; + const DType br = b1 + bw; + left = bl > al ? bl : al; + right = br < ar ? br : ar; + } + const DType w = right - left; + return w > 0 ? w : DType(0); +} + +template +__launch_bounds__(512) +__global__ void nms_apply_kernel(const int topk, int32_t *index, + const int32_t *batch_start, + const DType *input, + const DType *areas, + const int num, const int stride, + const int offset_box, const int offset_id, + const float thresh, const bool force, + const int encode, const int start_offset) { + constexpr int block_size = 512; + const int start = static_cast(batch_start[blockIdx.x]) + start_offset; + const int size_of_batch = static_cast(batch_start[blockIdx.x + 1]) - start; + const int end = min(min(size_of_batch, topk - start_offset), N * block_size); + __shared__ int s_index[N * block_size]; + + for (int i = threadIdx.x; i < end; i += block_size) { + s_index[i] = static_cast(index[start + i]); + } + + __syncthreads(); + for (int ref = 0; ref < end; ++ref) { + const int ref_area_offset = static_cast(s_index[ref]); + if (ref_area_offset >= 0) { + const int ref_offset = ref_area_offset * stride + offset_box; + int ref_id = 0; + if (check_class) { + ref_id = static_cast(input[ref_offset - offset_box + offset_id]); + } + for (int i = 0; i < N; ++i) { + const int my_pos = threadIdx.x + i * block_size; + if (my_pos > ref && my_pos < end && s_index[my_pos] >= 0) { + const int pos_area_offset = static_cast(s_index[my_pos]); + const int pos_offset = pos_area_offset * stride + offset_box; + if (check_class) { + const int pos_id = static_cast(input[pos_offset - offset_box + offset_id]); + if (ref_id != pos_id) continue; // different class + } + DType intersect = Intersect(input + ref_offset, input + pos_offset, encode); + intersect *= Intersect(input + ref_offset + 1, input + pos_offset + 1, encode); + const DType iou = intersect / + (areas[ref_area_offset] + areas[pos_area_offset] - intersect); + if (iou > thresh) { + s_index[my_pos] = -1; + } + } + } + __syncthreads(); + } + } + + for (int i = threadIdx.x; i < end; i += block_size) { + index[start + i] = s_index[i]; + } +} + +template +__launch_bounds__(512) +__global__ void nms_apply_kernel_rest(const int topk, int32_t *index, + const int32_t *batch_start, + const DType *input, + const DType *areas, + const int num, const int stride, + const int offset_box, const int offset_id, + const float thresh, const bool force, + const int encode, const int start_offset, + const int blocks_per_batch) { + constexpr int block_size = 512; + const int batch = blockIdx.x / blocks_per_batch; + const int start_ref = static_cast(batch_start[batch]) + start_offset; + const int block_offset = (N + blockIdx.x % blocks_per_batch) * block_size; + const int start = start_ref + block_offset; + + const int size_of_batch = static_cast(batch_start[batch + 1]) - start; + const int end = min(size_of_batch, topk - start_offset - block_offset); + const int my_pos = start + threadIdx.x; + if (threadIdx.x < end && index[my_pos] >= 0) { + const int pos_area_offset = static_cast(index[my_pos]); + const int pos_offset = pos_area_offset * stride + offset_box; + DType my_box[4]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + my_box[i] = input[pos_offset + i]; + } + const DType my_area = areas[pos_area_offset]; + int pos_id = 0; + if (check_class) { + pos_id = static_cast(input[pos_offset - offset_box + offset_id]); + } + + for (int ref = start_ref; ref < start_ref + N * block_size; ++ref) { + const int ref_area_offset = static_cast(index[ref]); + if (ref_area_offset >= 0) { + const int ref_offset = ref_area_offset * stride + offset_box; + int ref_id = 0; + if (check_class) { + ref_id = static_cast(input[ref_offset - offset_box + offset_id]); + if (ref_id != pos_id) continue; // different class + } + DType intersect = Intersect2(input + ref_offset, my_box[0], my_box[2], encode); + intersect *= Intersect2(input + ref_offset + 1, my_box[1], my_box[3], encode); + const DType iou = intersect / + (areas[ref_area_offset] + my_area - intersect); + if (iou > thresh) { + index[my_pos] = -1; + break; + } + } + } + } +} + +template +void NMSApply(mshadow::Stream *s, + int num_batch, int topk, + mshadow::Tensor* sorted_index, + mshadow::Tensor* batch_start, + mshadow::Tensor* buffer, + mshadow::Tensor* areas, + int num_elem, int width_elem, + int coord_start, int id_index, + float threshold, bool force_suppress, + int in_format) { + using namespace mxnet_op; + constexpr int THRESHOLD = 1024; + for (int ref = 0; ref < topk; ref += THRESHOLD) { + constexpr int block_size = 512; + constexpr int N = THRESHOLD / block_size; + auto stream = mshadow::Stream::GetStream(s); + if (!force_suppress && id_index >= 0) { + nms_apply_kernel<<>>(topk, + sorted_index->dptr_, + batch_start->dptr_, + buffer->dptr_, + areas->dptr_, + num_elem, + width_elem, + coord_start, + id_index, + threshold, + force_suppress, + in_format, + ref); + int blocks_per_batch = (topk - ref - THRESHOLD + block_size - 1)/block_size; + int blocks = blocks_per_batch * num_batch; + if (blocks > 0) { + nms_apply_kernel_rest<<>>(topk, + sorted_index->dptr_, + batch_start->dptr_, + buffer->dptr_, + areas->dptr_, + num_elem, + width_elem, + coord_start, + id_index, + threshold, + force_suppress, + in_format, + ref, + blocks_per_batch); + } + } else { + nms_apply_kernel<<>>(topk, + sorted_index->dptr_, + batch_start->dptr_, + buffer->dptr_, + areas->dptr_, + num_elem, + width_elem, + coord_start, + id_index, + threshold, + force_suppress, + in_format, + ref); + int blocks_per_batch = (topk - ref - THRESHOLD + block_size - 1)/block_size; + int blocks = blocks_per_batch * num_batch; + if (blocks > 0) { + nms_apply_kernel_rest<<>>(topk, + sorted_index->dptr_, + batch_start->dptr_, + buffer->dptr_, + areas->dptr_, + num_elem, + width_elem, + coord_start, + id_index, + threshold, + force_suppress, + in_format, + ref, + blocks_per_batch); + } + } + } +} + } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/bounding_box-inl.h b/src/operator/contrib/bounding_box-inl.h index 650e58d0e0cd..35ab19d01a19 100644 --- a/src/operator/contrib/bounding_box-inl.h +++ b/src/operator/contrib/bounding_box-inl.h @@ -34,12 +34,10 @@ #include "../mxnet_op.h" #include "../operator_common.h" #include "../tensor/sort_op.h" +#include "./bounding_box-common.h" namespace mxnet { namespace op { -namespace box_common_enum { -enum BoxType {kCorner, kCenter}; -} namespace box_nms_enum { enum BoxNMSOpInputs {kData}; enum BoxNMSOpOutputs {kOut, kTemp}; @@ -254,84 +252,31 @@ struct compute_area { } }; -// compute line intersect along either height or width template -MSHADOW_XINLINE DType Intersect(const DType *a, const DType *b, int encode) { - DType a1 = a[0]; - DType a2 = a[2]; - DType b1 = b[0]; - DType b2 = b[2]; - DType w; - if (box_common_enum::kCorner == encode) { - DType left = a1 > b1 ? a1 : b1; - DType right = a2 < b2 ? a2 : b2; - w = right - left; - } else { - DType aw = a2 / 2; - DType bw = b2 / 2; - DType al = a1 - aw; - DType ar = a1 + aw; - DType bl = b1 - bw; - DType br = b1 + bw; - DType left = bl > al ? bl : al; - DType right = br < ar ? br : ar; - w = right - left; +void NMSApply(mshadow::Stream *s, + int num_batch, int topk, + mshadow::Tensor* sorted_index, + mshadow::Tensor* batch_start, + mshadow::Tensor* buffer, + mshadow::Tensor* areas, + int num_elem, int width_elem, + int coord_start, int id_index, + float threshold, bool force_suppress, + int in_format) { + using namespace mxnet_op; + // go through each box as reference, suppress if overlap > threshold + // sorted_index with -1 is marked as suppressed + for (int ref = 0; ref < topk; ++ref) { + int num_worker = topk - ref - 1; + if (num_worker < 1) continue; + Kernel::Launch(s, num_batch * num_worker, + sorted_index->dptr_, batch_start->dptr_, buffer->dptr_, areas->dptr_, + num_worker, ref, num_elem, + width_elem, coord_start, id_index, + threshold, force_suppress, in_format); } - return w > 0 ? w : DType(0); } -/*! - * \brief Implementation of the non-maximum suppression operation - * - * \param i the launched thread index - * \param index sorted index in descending order - * \param batch_start map (b, k) to compact index by indices[batch_start[b] + k] - * \param input the input of nms op - * \param areas pre-computed box areas - * \param k nms topk number - * \param ref compare reference position - * \param num number of input boxes in each batch - * \param stride input stride, usually 6 (id-score-x1-y1-x2-y2) - * \param offset_box box offset, usually 2 - * \param thresh nms threshold - * \param force force suppress regardless of class id - * \param offset_id class id offset, used when force == false, usually 0 - * \param encode box encoding type, corner(0) or center(1) - * \param DType the data type - */ -struct nms_impl { - template - MSHADOW_XINLINE static void Map(int i, int32_t *index, const int32_t *batch_start, - const DType *input, const DType *areas, - int k, int ref, int num, - int stride, int offset_box, int offset_id, - float thresh, bool force, int encode) { - int b = i / k; // batch - int pos = i % k + ref + 1; // position - ref = static_cast(batch_start[b]) + ref; - pos = static_cast(batch_start[b]) + pos; - if (ref >= static_cast(batch_start[b + 1])) return; - if (pos >= static_cast(batch_start[b + 1])) return; - if (index[ref] < 0) return; // reference has been suppressed - if (index[pos] < 0) return; // self been suppressed - int ref_offset = static_cast(index[ref]) * stride + offset_box; - int pos_offset = static_cast(index[pos]) * stride + offset_box; - if (!force && offset_id >=0) { - int ref_id = static_cast(input[ref_offset - offset_box + offset_id]); - int pos_id = static_cast(input[pos_offset - offset_box + offset_id]); - if (ref_id != pos_id) return; // different class - } - DType intersect = Intersect(input + ref_offset, input + pos_offset, encode); - intersect *= Intersect(input + ref_offset + 1, input + pos_offset + 1, encode); - int ref_area_offset = static_cast(index[ref]); - int pos_area_offset = static_cast(index[pos]); - DType iou = intersect / (areas[ref_area_offset] + areas[pos_area_offset] - intersect); - if (iou > thresh) { - index[pos] = -1; - } - } -}; - /*! * \brief Assign output of nms by indexing input * @@ -502,17 +447,11 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs, topk, num_elem, width_elem, param.in_format); // apply nms - // go through each box as reference, suppress if overlap > threshold - // sorted_index with -1 is marked as suppressed - for (int ref = 0; ref < topk; ++ref) { - int num_worker = topk - ref - 1; - if (num_worker < 1) continue; - Kernel::Launch(s, num_batch * num_worker, - sorted_index.dptr_, batch_start.dptr_, buffer.dptr_, areas.dptr_, - num_worker, ref, num_elem, - width_elem, coord_start, id_index, - param.overlap_thresh, param.force_suppress, param.in_format); - } + mxnet::op::NMSApply(s, num_batch, topk, &sorted_index, + &batch_start, &buffer, &areas, + num_elem, width_elem, coord_start, + id_index, param.overlap_thresh, + param.force_suppress, param.in_format); // store the results to output, keep a record for backward record = -1;