Skip to content

Commit

Permalink
[Quantization] Support zero-size tensor input for quantization flow (apache#15031)
Browse files Browse the repository at this point in the history

* [Quantization] Support zero-size tensor input for quantization flow

* Comment out quantized_act and quantized_sum

* retrigger CI

* Add test cases
  • Loading branch information
ciyongch authored and haohuw committed Jun 23, 2019
1 parent c713ba9 commit f39b0cc
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 165 deletions.
7 changes: 7 additions & 0 deletions src/operator/quantization/dequantize-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,18 @@ inline bool DequantizeShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);

mxnet::TShape dshape = (*in_attrs)[0];
for (size_t i = 1; i < 3; ++i) {
SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape(1, 1));
}

SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));

if ((*out_attrs)[0].ndim() > 0) {
dshape[0] = ((*out_attrs)[0])[0];
SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape);
}

return shape_is_known(out_attrs->at(0));
}

Expand Down
11 changes: 9 additions & 2 deletions src/operator/quantization/quantize-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,20 @@ inline bool QuantizeShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 3U);

mxnet::TShape dshape = (*in_attrs)[0];
for (size_t i = 1; i < 3; ++i) {
SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape(1, 1));
}

SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape{1});
SHAPE_ASSIGN_CHECK(*out_attrs, 2, mxnet::TShape{1});
SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape(1, 1));
SHAPE_ASSIGN_CHECK(*out_attrs, 2, mxnet::TShape(1, 1));

if ((*out_attrs)[0].ndim() > 0) {
dshape[0] = ((*out_attrs)[0])[0];
SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape);
}

return shape_is_known(out_attrs->at(0));
}

Expand Down
11 changes: 9 additions & 2 deletions src/operator/quantization/quantize_v2-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,16 @@ static inline bool QuantizeV2Shape(const nnvm::NodeAttrs &attrs, std::vector<TSh
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 3U);

mxnet::TShape dshape = (*in_attrs)[0];
SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape{1});
SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape{1});
SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape(1, 1));
SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape(1, 1));

if ((*out_attrs)[0].ndim() > 0) {
dshape[0] = ((*out_attrs)[0])[0];
SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape);
}

return !shape_is_none(out_attrs->at(0));
}

Expand Down
4 changes: 4 additions & 0 deletions src/operator/quantization/quantized_activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ the float32 data into int8.
.add_argument("max_data", "NDArray-or-Symbol", "Maximum value of data.")
.add_arguments(ActivationParam::__FIELDS__());

// TODO(zhiyuan): add a condition check for whether switching this on is actually
// beneficial, since the operator is not compute-intensive.
#if 0
NNVM_REGISTER_OP(Activation)
.set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
ActivationParam param;
Expand All @@ -133,6 +136,7 @@ NNVM_REGISTER_OP(Activation)
}
return node;
});
#endif

} // namespace op
} // namespace mxnet
4 changes: 4 additions & 0 deletions src/operator/quantization/quantized_elemwise_add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ and max thresholds representing the thresholds for quantizing the float32 output
.add_argument("rhs_max", "NDArray-or-Symbol", "6th input");


// TODO(zhangrong): add a condition check for whether switching this on is actually
// beneficial, since the operator is not compute-intensive.
#if 0
NNVM_REGISTER_OP(elemwise_add)
.set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
nnvm::NodePtr node = nnvm::Node::Create();
Expand All @@ -136,6 +139,7 @@ NNVM_REGISTER_OP(elemwise_add)
}
return node;
});
#endif

} // namespace op
} // namespace mxnet
18 changes: 12 additions & 6 deletions src/operator/quantization/quantized_fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,36 +47,42 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_shape->size(), num_inputs * 3);
CHECK_EQ(out_shape->size(), 3U);

CHECK(shape_is_known(in_shape->at(0)))
<< "QuantizedFullyConnectedOp input data shape must be given";
const mxnet::TShape& dshape = in_shape->at(0);
mxnet::TShape dshape = (*in_shape)[0];
// require data ndim to be known
if (!mxnet::ndim_is_known(dshape)) return false;

index_t num_input;
if (!param.flatten) {
num_input = dshape[dshape.ndim() - 1];
} else {
num_input = dshape.ProdShape(1, dshape.ndim());
}

TShape wshape = Shape2(param.num_hidden, num_input);
mxnet::TShape wshape = Shape2(param.num_hidden, num_input);
SHAPE_ASSIGN_CHECK(*in_shape, 1, wshape);
if (!param.no_bias) {
mxnet::TShape bshape = Shape1(param.num_hidden);
SHAPE_ASSIGN_CHECK(*in_shape, 2, bshape);
}

for (size_t i = num_inputs; i < 3 * num_inputs; ++i) {
SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape{1});
SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape(1, 1));
}

if (!param.flatten) {
TShape result_shape(dshape);
mxnet::TShape result_shape(dshape);
result_shape[dshape.ndim() - 1] = param.num_hidden;
SHAPE_ASSIGN_CHECK(*out_shape, 0, result_shape);
} else {
SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param.num_hidden));
}
SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape(1, 1));
SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape(1, 1));

if ((*out_shape)[0].ndim() > 0) {
dshape[0] = ((*out_shape)[0])[0];
SHAPE_ASSIGN_CHECK(*in_shape, 0, dshape);
}
return true;
}

Expand Down
Loading

0 comments on commit f39b0cc

Please sign in to comment.