Skip to content

Commit

Permalink
begin=end not a valid input (apache#14403)
Browse files Browse the repository at this point in the history
refactoring logic for indexing
  • Loading branch information
mseth10 authored and haohuw committed Jun 23, 2019
1 parent cc1593a commit 433bd8b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 36 deletions.
65 changes: 29 additions & 36 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -653,50 +653,43 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
}

for (index_t i = 0; i < param_begin.ndim(); ++i) {
index_t b = 0, e = dshape[i], s = 1;
const index_t len = dshape[i];
if (param_step.ndim() != 0U) {
const auto& opt_step_val = param_step[i];
if (opt_step_val.has_value()) {
s = opt_step_val.value();
CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0";
}
}
index_t s = param_step.ndim() != 0U && param_step[i].has_value() ? param_step[i].value() : 1;
CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0";

if (len) {
if (param_begin[i].has_value()) {
b = param_begin[i].value();
if (b < 0) {
b += len;
CHECK_GE(b, 0) << "slicing with begin[" << i << "]="
<< b - len << " exceeds limit of " << len;
}
} else if (s < 0) {
b = len - 1;
index_t b = 0, e = 0;
const index_t len = dshape[i];
if (len > 0) {
b = param_begin[i].has_value() ? param_begin[i].value() : (s < 0 ? len - 1 : 0);
e = param_end[i].has_value() ? param_end[i].value() : (s < 0 ? -1 : len);

// checking upper and lower bounds for begin
if (b < 0) {
b += len;
CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" << b - len
<< " exceeds limit of input dimension[" << i << "]=" << len;
}
CHECK_LT(b, len) << "slicing with begin[" << i << "]="
<< b << " exceends limit of " << len;

if (param_end[i].has_value()) {
e = param_end[i].value();
if (e < 0) {
e += len;
CHECK_GE(e, 0) << "slicing with end[" << i << "]="
<< e - len << " exceeds limit of " << len;
}
} else if (s < 0) {
e = -1;
CHECK_LT(b, len) << "slicing with begin[" << i << "]=" << b
<< " exceeds limit of input dimension[" << i << "]=" << len;

// checking upper and lower bounds for end
if (e < 0 && param_end[i].has_value()) {
e += len;
CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len
<< " exceeds limit of input dimension[" << i << "]=" << len;
}
CHECK_LE(e, len) << "slicing with end[" << i << "]="
<< e << " exceeds limit of " << len;
} else {
b = 0;
e = 0;
CHECK_LE(e, len) << "slicing with end[" << i << "]=" << e
<< " exceeds limit of input dimension[" << i << "]=" << len;

// checking begin==end case which is not supported
CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]="
<< e << " results in an empty tensor and is not supported";
}

(*begin)[i] = b;
(*end)[i] = e;
(*step)[i] = s;
}

for (index_t i = param_begin.ndim(); i < dshape.ndim(); ++i) {
(*begin)[i] = 0;
(*end)[i] = dshape[i];
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6606,6 +6606,15 @@ def test_slice_forward_backward(a, index):
for index in index_list:
test_slice_forward_backward(arr, index)

def test_begin_equals_end(shape, begin, end, step):
in_arr = mx.nd.arange(np.prod(shape)).reshape(shape=shape)
out_arr = mx.nd.slice(in_arr, begin=begin, end=end, step=step)

assertRaises(MXNetError, test_begin_equals_end, (4,), (2,), (2,), (1,))
assertRaises(MXNetError, test_begin_equals_end, (1, 5), (None, 3), (None, 3), (-1, 1))
assertRaises(MXNetError, test_begin_equals_end, (3, 4, 5), (1, 3, 1), (3, 3, 1), (1, -3, 2))
assertRaises(MXNetError, test_begin_equals_end, (2, 4), (None, 2), (None, 2), (1, -1))

# check numeric gradient
in_data = np.arange(36).reshape(2, 2, 3, 3)
data = mx.sym.Variable('data')
Expand Down

0 comments on commit 433bd8b

Please sign in to comment.