-
Notifications
You must be signed in to change notification settings - Fork 6.8k
fix add_n bug: when input mem overlap with output mem, results is wrong #14889
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch! Would you mind adding an unit test to test_ndarray.py? e.g.
x = mx.nd.ones((2,2))
mx.nd.ElemwiseSum(x,x,out=x)
Thanks for your advice. Testcase has been added. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@sandeep-krishnamurthy please take a review too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks
@mxnet-label-bot add [pr-awaiting-merge] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch! LGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor comments.
src/ndarray/ndarray_function.cc
Outdated
@@ -207,7 +207,10 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream<cpu>* s, | |||
using namespace mxnet::op::mxnet_op; | |||
const TBlob& out_data = out->data(); | |||
MSHADOW_TYPE_SWITCH(out->dtype(), DType, { // data type | |||
Kernel<set_zero, cpu>::Launch(s, out_data.Size(), out_data.dptr<DType>()); | |||
// Do not set_zero if output mem inplace with input mem: elemwise_sum.cc FInplaceOption |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make this comment easy to understand. Add comment that output can be in-placed with the first input.
@@ -8284,6 +8284,18 @@ def check_concat(shape1, shape2, axis): | |||
check_concat((8, 0, 0), (8, 0, 0), 2) | |||
|
|||
|
|||
@with_seed() | |||
def test_elemwise_sum_add_n(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_add_n() ?
rslt = mx.nd.zeros(shape=data_shape) | ||
for i in range(input_num): | ||
rslt += data[i] | ||
add_n_rslt = mx.nd.add_n(*data,out=data[0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add a space before out=data[0]
.
…ng (apache#14889) * fix add_n bug: when input mem overlap with output mem, results is wrong * add testcase for bugfix verification * add more comments for modification and change testcase name to test_add_n
…ng (apache#14889) * fix add_n bug: when input mem overlap with output mem, results is wrong * add testcase for bugfix verification * add more comments for modification and change testcase name to test_add_n
Description
Fix add_n forword bug
Due to "FInplaceOption" is active, output memory should not set_zero before actually do "add_n" which will make inputs memory set zero too when overlap.
Fix issue for #14858
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments
@pengzhao-intel @TaoLv