@mxnet-label-bot add [Operator, pr-awaiting-review]
    int num_batch) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < N) {
    const int32_t previous = tid > 0 ? __ldg(valid_batch_id + tid - 1) : -1;
Using the __ldg intrinsic will fail to compile on some early CUDA architectures.
It will fail on sm 3.0 and earlier (so Fermi and the first Kepler), since __ldg requires compute capability 3.5 or higher. I can put an #ifdef there, but do we care about those?
In the Makefile, sm 30 is in KNOWN_CUDA_ARCHS:
https://github.com/apache/incubator-mxnet/blob/master/Makefile#L385
Then we do ;-). I will introduce the guard, thanks!
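A minimal sketch of the kind of guard discussed in this thread, not the code that was actually committed. The names `HOST_DEVICE`, `load_readonly`, and `previous_batch_id` are hypothetical and are not taken from the PR or from MXNet. The sketch relies on `__CUDA_ARCH__` being defined only while compiling device code, and on `__ldg` being available from compute capability 3.5 onward. The `__CUDACC__` check is an extra assumption that lets the same snippet build with a plain host compiler.

```cpp
#include <cstddef>
#include <cstdint>

// HOST_DEVICE is a hypothetical convenience macro, not part of MXNet.
#ifdef __CUDACC__
#define HOST_DEVICE __host__ __device__
#else
#define HOST_DEVICE
#endif

// Hypothetical helper: use the read-only data cache where the __ldg
// intrinsic exists (device code built for sm_35 or newer), and fall back
// to an ordinary load everywhere else, including host code.
template <typename T>
HOST_DEVICE inline T load_readonly(const T* ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
  return __ldg(ptr);
#else
  return *ptr;
#endif
}

// The expression from the reviewed line, routed through the guarded helper.
HOST_DEVICE inline int32_t previous_batch_id(const int32_t* valid_batch_id,
                                             std::size_t tid) {
  return tid > 0 ? load_readonly(valid_batch_id + tid - 1) : -1;
}
```

Wrapping the intrinsic in one helper keeps the `#if` in a single place, so every call site stays identical across all architectures listed in KNOWN_CUDA_ARCHS.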
* Optimize NMS part 2
* Guarding ldg intrinsics
Description
This PR moves the batch_start calculation in the BoxNMSForward op to a custom kernel, which is much faster than the mshadow-generated one. In the MaskRCNN model it reduces the runtime of that part from 20 ms to 2 us, speeding up single-GPU training by 20% in fp16 mode.
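The kernel body is not shown in this excerpt, so the following is only one plausible reading of a batch_start calculation, written as a host-side reference in which each loop iteration stands in for one GPU thread. It assumes that `valid_batch_id` is sorted in non-decreasing order, that every id is smaller than `num_batch`, and that `batch_start` holds `num_batch + 1` offsets where entry `b` is the index of the first element whose batch id is at least `b`. The function name `batch_start_reference` is hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference, not the PR's kernel. Each iteration of
// the outer loop does the work one GPU thread with index tid would do.
std::vector<int32_t> batch_start_reference(
    const std::vector<int32_t>& valid_batch_id, int num_batch) {
  std::vector<int32_t> batch_start(num_batch + 1, 0);
  const std::size_t n = valid_batch_id.size();
  for (std::size_t tid = 0; tid < n; ++tid) {
    const int32_t previous = tid > 0 ? valid_batch_id[tid - 1] : -1;
    const int32_t mine = valid_batch_id[tid];
    // A jump in batch id marks the start of every batch in (previous, mine].
    for (int32_t b = previous + 1; b <= mine; ++b) {
      batch_start[b] = static_cast<int32_t>(tid);
    }
    // The last element closes all remaining batches and the end marker.
    if (tid == n - 1) {
      for (int32_t b = mine + 1; b <= num_batch; ++b) {
        batch_start[b] = static_cast<int32_t>(tid + 1);
      }
    }
  }
  return batch_start;
}
```

Under these assumptions each element only inspects its left neighbour, which is why the work maps naturally onto one thread per element.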