Skip to content

Commit

Permalink
[Frontend][Tensorflow2] Stridedslice and concat_v2 fix (apache#8483)
Browse files Browse the repository at this point in the history
* fix for strided_slice when begin > end in case of shrinkaxis_mask

* fix for name_hint missing error for concat_v2 op

* removing a local fix

* adding more testing capability to concat_v2
  • Loading branch information
srinidhigoud authored and ylc committed Sep 29, 2021
1 parent 7ba8b6a commit b5295c8
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
10 changes: 8 additions & 2 deletions python/tvm/relay/frontend/tensorflow_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,7 +1483,13 @@ def _impl(inputs, attr, params, mod):
def _concatV2():
def _impl(inputs, attr, params, mod):
pop_node = inputs.pop(len(inputs) - 1)
axis = int(_get_num_param(params, pop_node))
try:
axis = int(_get_num_param(params, pop_node))
except (IndexError, KeyError, AttributeError):
try:
axis = int(_infer_value(pop_node, params, mod).numpy())
except Exception:
axis = int(pop_node)
return AttrCvt(op_name="concatenate", ignores=["T", "N", "Tidx"], extras={"axis": axis})(
[inputs], attr
)
Expand Down Expand Up @@ -2244,7 +2250,7 @@ def _transform_mask(stride_dim, ellipsis_mask):
if begin[index] < 0
else begin[index]
)
m_end[final_index] = begin[index] + 1
m_end[final_index] = m_begin[final_index] + 1
m_stride[final_index] = 1
fshape_indices.append(-2)
else:
Expand Down
2 changes: 2 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2568,7 +2568,9 @@ def test_forward_stridedslice():

_test_stridedslice([], [0], [0], [1], "float32", new_axis_mask=1)
_test_stridedslice([2], [1], [1], [1], "float32", shrink_axis_mask=1)
_test_stridedslice([4], [-1], [0], [1], "float32", shrink_axis_mask=1)
_test_stridedslice([2, 1], [0], [1], [1], "float32", shrink_axis_mask=1)
_test_stridedslice([2, 3, 4], [-2], [0], [1], "float32", shrink_axis_mask=8)
_test_stridedslice([2, 3, 4], [0], [1], [1], "float32", shrink_axis_mask=8)
_test_stridedslice([3, 4, 3], [1, -1, 0], [4, -5, 3], [2, -1, 1], "float32")
_test_stridedslice([3, 4, 3], [1, 0], [4, 3], [2, 1], "float32", ellipsis_mask=8)
Expand Down
3 changes: 2 additions & 1 deletion tests/python/frontend/tensorflow2/test_functional_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ def get_input(self):
@tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)])
def func(self, x):
a, b, c = tf.split(x, 3, axis=1)
return tf.raw_ops.ConcatV2(values=[a, b, c], axis=1)
axis = tf.add(tf.constant(1, dtype="int32"), tf.constant(0, dtype="int32"))
return tf.raw_ops.ConcatV2(values=[a, b, c], axis=axis)

run_all(ConcatV2)

Expand Down

0 comments on commit b5295c8

Please sign in to comment.