Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
roywei committed May 20, 2019
1 parent ba11198 commit c7cd156
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 91 deletions.
11 changes: 3 additions & 8 deletions src/operator/tensor/dot-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1207,14 +1207,9 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& lshape = (*in_attrs)[0];
mxnet::TShape& rshape = (*in_attrs)[1];
// check if lhs ndim is larger than 1 and last dim is known
if (lshape.ndim() < 1 || !dim_size_is_known(lshape, lshape.ndim() - 1)) {
return false;
}
// check if rhs ndim is larger than 1 and first dim is known
if (rshape.ndim() < 1 || !dim_size_is_known(rshape, 0)) {
return false;
}
if (!ndim_is_known(lshape) || !ndim_is_known(rshape)) return false;
CHECK_GT(lshape.ndim(), 0) << "scalar tensor is not supported by this operator.";
CHECK_GT(rshape.ndim(), 0) << "scalar tensor is not supported by this operator.";
if (lshape.ndim() == 1 && rshape.ndim() == 1) {
CHECK(!param.transpose_a && !param.transpose_b) << "Cannot transpose vectors";
CHECK_EQ(lshape[0], rshape[0]) << "dot shape error: " << lshape << " X " << rshape;
Expand Down
83 changes: 83 additions & 0 deletions tests/python/unittest/test_infer_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,89 @@ def test_shape_completely_unknown():
assert out_shapes[0] is None


def test_dot_partial_shape():
    """Partial shape inference for ``mx.sym.dot`` with unknown dimensions."""
    lhs = mx.sym.Variable("x")
    rhs = mx.sym.Variable("y")
    prod = mx.sym.dot(lhs, rhs)
    # First dim of rhs unknown (0): the output shape cannot be inferred.
    _, out_shapes, _ = prod.infer_shape_partial(x=(0, 3, 4), y=(0, 4, 5))
    assert out_shapes == [()]
    # Only the lhs batch (first) dim is unknown: it propagates to the output.
    _, out_shapes, _ = prod.infer_shape_partial(x=(0, 3, 4), y=(4, 5))
    assert out_shapes == [(0, 3, 5)]
    # Same checks under numpy compatibility, where -1 marks an unknown dim.
    with mx.np_compat(True):
        _, out_shapes, _ = prod.infer_shape_partial(x=(-1, 3, 4), y=(4, 5))
        assert out_shapes == [(-1, 3, 5)]
        _, out_shapes, _ = prod.infer_shape_partial(x=(-1, 3, 4), y=(-1, 4, 5))
        assert out_shapes == [None]


def test_batch_dot_partial_shape():
    """Partial shape inference for ``mx.sym.batch_dot`` with unknown dimensions."""
    lhs = mx.sym.Variable("x")
    rhs = mx.sym.Variable("y")
    prod = mx.sym.batch_dot(lhs, rhs)
    # Unknown batch dim on both operands: trailing dims still infer.
    _, out_shapes, _ = prod.infer_shape_partial(x=(0, 3, 4), y=(0, 4, 5))
    assert out_shapes == [(0, 3, 5)]
    # Second (reduction) dim of rhs unknown: output cannot be inferred.
    _, out_shapes, _ = prod.infer_shape_partial(x=(0, 3, 4), y=(0, 0, 5))
    assert out_shapes == [()]
    # Same checks under numpy compatibility, where -1 marks an unknown dim.
    with mx.np_compat(True):
        _, out_shapes, _ = prod.infer_shape_partial(x=(-1, 3, 4), y=(-1, 4, 5))
        assert out_shapes == [(-1, 3, 5)]
        _, out_shapes, _ = prod.infer_shape_partial(x=(-1, 3, 4), y=(-1, -1, 5))
        assert out_shapes == [None]


def test_embedding_partial_shape():
    """Embedding infers its output shape even when the data batch dim is unknown."""
    data = mx.sym.Variable("x")
    weight = mx.sym.Variable("w")
    embed = mx.sym.Embedding(data=data, weight=weight, input_dim=100, output_dim=10)
    _, out_shapes, _ = embed.infer_shape_partial(x=(0, 5), w=(100, 10))
    assert out_shapes == [(0, 5, 10)]
    # Under numpy compatibility -1 marks the unknown batch dim.
    with mx.np_compat(True):
        _, out_shapes, _ = embed.infer_shape_partial(x=(-1, 5), w=(100, 10))
        assert out_shapes == [(-1, 5, 10)]


def test_transpose_partial_shape():
    """Transpose from channels-first to channels-last with unknown batch dim."""
    perm = [0, 3, 2, 1]
    data = mx.sym.Variable("x")
    trans = mx.sym.transpose(data, axes=perm)
    # The unknown batch dim (0) stays in place; remaining axes are permuted.
    _, out_shapes, _ = trans.infer_shape_partial(x=(0, 3, 224, 224))
    assert out_shapes == [(0, 224, 224, 3)]

    # Under numpy compatibility -1 marks the unknown batch dim.
    with mx.np_compat(True):
        _, out_shapes, _ = trans.infer_shape_partial(x=(-1, 3, 224, 224))
        assert out_shapes == [(-1, 224, 224, 3)]


def test_pick_partial_shape():
    """``pick`` along axis 1 keeps the unknown batch dim in its result."""
    data = mx.sym.Variable("x")
    idx = mx.sym.Variable("index")
    picked = mx.sym.pick(data, idx, axis=1)
    _, out_shapes, _ = picked.infer_shape_partial(x=(0, 3, 3), index=(0, 3))
    assert out_shapes == [(0, 3)]
    # Under numpy compatibility -1 marks the unknown batch dim.
    with mx.np_compat(True):
        _, out_shapes, _ = picked.infer_shape_partial(x=(-1, 3, 3), index=(-1, 3))
        assert out_shapes == [(-1, 3)]


def test_where_partial_shape():
    """Partial shape inference for ``mx.sym.where`` with an unknown batch dim.

    The original version only checked that inference does not raise; this
    version additionally asserts that exactly one output shape is produced.
    """
    x = mx.sym.Variable("x")
    y = mx.sym.Variable("y")
    cond = mx.sym.Variable("cond")
    where_op = mx.sym.where(cond, x, y)
    # Unknown batch dim (0): inference must complete and yield one output.
    _, result, _ = where_op.infer_shape_partial(cond=(0, 2), x=(0, 2), y=(0, 2))
    assert len(result) == 1
    with mx.np_compat(True):
        # Under numpy compatibility -1 marks the unknown batch dim.
        _, result, _ = where_op.infer_shape_partial(cond=(-1, 2), x=(-1, 2), y=(-1, 2))
        assert len(result) == 1


if __name__ == "__main__":
test_mlp2_infer_shape()
test_mlp2_infer_error()
Expand Down
83 changes: 0 additions & 83 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8436,89 +8436,6 @@ def test_add_n():
assert_almost_equal(rslt.asnumpy(), add_n_rslt.asnumpy(), atol=1e-5)


def test_dot_partial_shape():
    """Partial shape inference for ``mx.sym.dot`` with unknown dimensions."""
    a = mx.sym.Variable("x")
    b = mx.sym.Variable("y")
    out = mx.sym.dot(a, b)
    # First dim of rhs unknown (0): the output shape cannot be inferred.
    _, shapes, _ = out.infer_shape_partial(x=(0, 3, 4), y=(0, 4, 5))
    assert shapes == [()]
    # Only the lhs batch (first) dim is unknown: it propagates to the output.
    _, shapes, _ = out.infer_shape_partial(x=(0, 3, 4), y=(4, 5))
    assert shapes == [(0, 3, 5)]
    # Same checks under numpy compatibility, where -1 marks an unknown dim.
    with mx.np_compat(True):
        _, shapes, _ = out.infer_shape_partial(x=(-1, 3, 4), y=(4, 5))
        assert shapes == [(-1, 3, 5)]
        _, shapes, _ = out.infer_shape_partial(x=(-1, 3, 4), y=(-1, 4, 5))
        assert shapes == [None]


def test_batch_dot_partial_shape():
    """Partial shape inference for ``mx.sym.batch_dot`` with unknown dimensions."""
    a = mx.sym.Variable("x")
    b = mx.sym.Variable("y")
    out = mx.sym.batch_dot(a, b)
    # Unknown batch dim on both operands: trailing dims still infer.
    _, shapes, _ = out.infer_shape_partial(x=(0, 3, 4), y=(0, 4, 5))
    assert shapes == [(0, 3, 5)]
    # Second (reduction) dim of rhs unknown: output cannot be inferred.
    _, shapes, _ = out.infer_shape_partial(x=(0, 3, 4), y=(0, 0, 5))
    assert shapes == [()]
    # Same checks under numpy compatibility, where -1 marks an unknown dim.
    with mx.np_compat(True):
        _, shapes, _ = out.infer_shape_partial(x=(-1, 3, 4), y=(-1, 4, 5))
        assert shapes == [(-1, 3, 5)]
        _, shapes, _ = out.infer_shape_partial(x=(-1, 3, 4), y=(-1, -1, 5))
        assert shapes == [None]


def test_embedding_partial_shape():
    """Embedding infers its output shape even when the data batch dim is unknown."""
    inputs = mx.sym.Variable("x")
    table = mx.sym.Variable("w")
    lookup = mx.sym.Embedding(data=inputs, weight=table, input_dim=100, output_dim=10)
    _, shapes, _ = lookup.infer_shape_partial(x=(0, 5), w=(100, 10))
    assert shapes == [(0, 5, 10)]
    # Under numpy compatibility -1 marks the unknown batch dim.
    with mx.np_compat(True):
        _, shapes, _ = lookup.infer_shape_partial(x=(-1, 5), w=(100, 10))
        assert shapes == [(-1, 5, 10)]


def test_transpose_partial_shape():
    """Transpose from channels-first to channels-last with unknown batch dim."""
    permutation = [0, 3, 2, 1]
    inputs = mx.sym.Variable("x")
    transposed = mx.sym.transpose(inputs, axes=permutation)
    # The unknown batch dim (0) stays in place; remaining axes are permuted.
    _, shapes, _ = transposed.infer_shape_partial(x=(0, 3, 224, 224))
    assert shapes == [(0, 224, 224, 3)]

    # Under numpy compatibility -1 marks the unknown batch dim.
    with mx.np_compat(True):
        _, shapes, _ = transposed.infer_shape_partial(x=(-1, 3, 224, 224))
        assert shapes == [(-1, 224, 224, 3)]


def test_pick_partial_shape():
    """``pick`` along axis 1 keeps the unknown batch dim in its result."""
    inputs = mx.sym.Variable("x")
    indices = mx.sym.Variable("index")
    selected = mx.sym.pick(inputs, indices, axis=1)
    _, shapes, _ = selected.infer_shape_partial(x=(0, 3, 3), index=(0, 3))
    assert shapes == [(0, 3)]
    # Under numpy compatibility -1 marks the unknown batch dim.
    with mx.np_compat(True):
        _, shapes, _ = selected.infer_shape_partial(x=(-1, 3, 3), index=(-1, 3))
        assert shapes == [(-1, 3)]


def test_where_partial_shape():
    """Partial shape inference for ``mx.sym.where`` with an unknown batch dim.

    The original version only checked that inference does not raise; this
    version additionally asserts that exactly one output shape is produced.
    """
    x = mx.sym.Variable("x")
    y = mx.sym.Variable("y")
    cond = mx.sym.Variable("cond")
    where_op = mx.sym.where(cond, x, y)
    # Unknown batch dim (0): inference must complete and yield one output.
    _, result, _ = where_op.infer_shape_partial(cond=(0, 2), x=(0, 2), y=(0, 2))
    assert len(result) == 1
    with mx.np_compat(True):
        # Under numpy compatibility -1 marks the unknown batch dim.
        _, result, _ = where_op.infer_shape_partial(cond=(-1, 2), x=(-1, 2), y=(-1, 2))
        assert len(result) == 1


if __name__ == "__main__":
    # Discover and run every test in this module via nose when executed
    # directly as a script.
    import nose
    nose.runmodule()

0 comments on commit c7cd156

Please sign in to comment.