diff --git a/src/operator/quantization/dequantize-inl.h b/src/operator/quantization/dequantize-inl.h
index 92b74b787141..b5f9e385c48e 100644
--- a/src/operator/quantization/dequantize-inl.h
+++ b/src/operator/quantization/dequantize-inl.h
@@ -74,11 +74,18 @@ inline bool DequantizeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), 1U);

+  mxnet::TShape dshape = (*in_attrs)[0];
   for (size_t i = 1; i < 3; ++i) {
     SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape(1, 1));
   }

   SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+
+  if ((*out_attrs)[0].ndim() > 0) {
+    dshape[0] = ((*out_attrs)[0])[0];
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape);
+  }
+
   return shape_is_known(out_attrs->at(0));
 }

diff --git a/src/operator/quantization/quantize-inl.h b/src/operator/quantization/quantize-inl.h
index 7b856579a7b5..5108b130e1ab 100644
--- a/src/operator/quantization/quantize-inl.h
+++ b/src/operator/quantization/quantize-inl.h
@@ -119,13 +119,20 @@ inline bool QuantizeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), 3U);

+  mxnet::TShape dshape = (*in_attrs)[0];
   for (size_t i = 1; i < 3; ++i) {
     SHAPE_ASSIGN_CHECK(*in_attrs, i, mxnet::TShape(1, 1));
   }

   SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-  SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape{1});
-  SHAPE_ASSIGN_CHECK(*out_attrs, 2, mxnet::TShape{1});
+  SHAPE_ASSIGN_CHECK(*out_attrs, 1, mxnet::TShape(1, 1));
+  SHAPE_ASSIGN_CHECK(*out_attrs, 2, mxnet::TShape(1, 1));
+
+  if ((*out_attrs)[0].ndim() > 0) {
+    dshape[0] = ((*out_attrs)[0])[0];
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape);
+  }
+
   return shape_is_known(out_attrs->at(0));
 }

diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h
index a8cbc0b6fdf5..d8814cc6cb20 100644
--- a/src/operator/quantization/quantize_v2-inl.h
+++ b/src/operator/quantization/quantize_v2-inl.h
@@ -109,9 +109,16 @@ static inline bool QuantizeV2Shape(const nnvm::NodeAttrs &attrs, std::vector<TShape> *in_attrs,
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 3U);

+  mxnet::TShape dshape = (*in_attrs)[0];
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-  SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape{1});
-  SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape{1});
+  SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape(1, 1));
+  SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape(1, 1));
+
+  if ((*out_attrs)[0].ndim() > 0) {
+    dshape[0] = ((*out_attrs)[0])[0];
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape);
+  }
+
   return !shape_is_none(out_attrs->at(0));
 }

diff --git a/src/operator/quantization/quantized_activation.cc b/src/operator/quantization/quantized_activation.cc
index 4ab74d0b1c3f..95c17ed30c9a 100644
--- a/src/operator/quantization/quantized_activation.cc
+++ b/src/operator/quantization/quantized_activation.cc
@@ -115,6 +115,9 @@ the float32 data into int8.
 .add_argument("max_data", "NDArray-or-Symbol", "Maximum value of data.")
 .add_arguments(ActivationParam::__FIELDS__());

+// TODO(zhiyuan): need an extra condition check on whether switching this on is beneficial,
+// since the op is not compute-intensive.
+#if 0
 NNVM_REGISTER_OP(Activation)
 .set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
     ActivationParam param;
@@ -133,6 +136,7 @@ NNVM_REGISTER_OP(Activation)
     }
     return node;
   });
+#endif

 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/quantization/quantized_elemwise_add.cc b/src/operator/quantization/quantized_elemwise_add.cc
index f821e6598192..0e7034e88b8c 100644
--- a/src/operator/quantization/quantized_elemwise_add.cc
+++ b/src/operator/quantization/quantized_elemwise_add.cc
@@ -125,6 +125,9 @@ and max thresholds representing the threholds for quantizing the float32 output
 .add_argument("rhs_max", "NDArray-or-Symbol", "6th input");

+// TODO(zhangrong): need an extra condition check on whether switching this on is beneficial,
+// since the op is not compute-intensive.
+#if 0
 NNVM_REGISTER_OP(elemwise_add)
 .set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
   nnvm::NodePtr node = nnvm::Node::Create();
@@ -136,6 +139,7 @@ NNVM_REGISTER_OP(elemwise_add)
   }
   return node;
 });
+#endif

 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index ceac0b6ec9a0..23790ca78b3d 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -47,9 +47,10 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_shape->size(), num_inputs * 3);
   CHECK_EQ(out_shape->size(), 3U);

-  CHECK(shape_is_known(in_shape->at(0)))
-      << "QuantizedFullyConnectedOp input data shape must be given";
-  const mxnet::TShape& dshape = in_shape->at(0);
+  mxnet::TShape dshape = (*in_shape)[0];
+  // require data ndim to be known
+  if (!mxnet::ndim_is_known(dshape)) return false;
+
   index_t num_input;
   if (!param.flatten) {
     num_input = dshape[dshape.ndim() - 1];
@@ -57,7 +58,7 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
     num_input = dshape.ProdShape(1, dshape.ndim());
   }

-  TShape wshape = Shape2(param.num_hidden, num_input);
+  mxnet::TShape wshape = Shape2(param.num_hidden, num_input);
   SHAPE_ASSIGN_CHECK(*in_shape, 1, wshape);
   if (!param.no_bias) {
     mxnet::TShape bshape = Shape1(param.num_hidden);
@@ -65,11 +66,11 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
   }

   for (size_t i = num_inputs; i < 3 * num_inputs; ++i) {
-    SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape{1});
+    SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape(1, 1));
   }

   if (!param.flatten) {
-    TShape result_shape(dshape);
+    mxnet::TShape result_shape(dshape);
     result_shape[dshape.ndim() - 1] = param.num_hidden;
     SHAPE_ASSIGN_CHECK(*out_shape, 0, result_shape);
   } else {
@@ -77,6 +78,11 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
   }
   SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape(1, 1));
   SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape(1, 1));
+
+  if ((*out_shape)[0].ndim() > 0) {
+    dshape[0] = ((*out_shape)[0])[0];
+    SHAPE_ASSIGN_CHECK(*in_shape, 0, dshape);
+  }
   return true;
 }

diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index ce93f9821b9d..294e10763220 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -748,9 +748,6 @@ def check_quantize_model(qdtype):
         if is_test_for_native_cpu():
             print('skipped testing test_quantize_model_with_forward for native cpu since it is not supported yet')
             return
-        elif qdtype == 'int8' and is_test_for_mkldnn():
-            print('skipped testing test_quantize_model_with_forward for mkldnn cpu int8 since it is not supported yet')
-            return
         elif qdtype == 'uint8' and is_test_for_gpu():
             print('skipped testing test_quantize_model_with_forward for gpu uint8 since it is not supported yet')
             return
@@ -782,11 +779,16 @@ def check_qsym_qdtype(qsym, qdtype):
                     assert 'out_type' in v
                     assert v['out_type'] == qdtype

-        def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape):
-            mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
-            mod.bind(for_training=False,
-                     data_shapes=[('data', data_shape)],
-                     label_shapes=[('softmax_label', label_shape)])
+        def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape=None):
+            if label_shape is None:
+                mod = mx.mod.Module(symbol=qsym, label_names=None, context=mx.current_context())
+                mod.bind(for_training=False,
+                         data_shapes=[('data', data_shape)])
+            else:
+                mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+                mod.bind(for_training=False,
+                         data_shapes=[('data', data_shape)],
+                         label_shapes=[('softmax_label', label_shape)])
             mod.set_params(qarg_params, qaux_params)
             data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
             batch = mx.io.DataBatch(data, [])
@@ -794,165 +796,109 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape):
             for output in mod.get_outputs():
                 output.wait_to_read()

-        sym = get_fp32_residual()
         batch_size = 4
-        data_shape = (batch_size, 4, 10, 10)
-        label_shape = (batch_size, 10)
-
         length = batch_size  # specify num of outputs from split op
-        msym = get_fp32_sym_with_multiple_outputs(length)
-        msym_label_shape = (length, 10)
-        msym_data_shape = (length, 4, 4, 10, 10)
+        sym_list = []
+        name_list = []
+        dshape_list = []
+        lshape_list = []
+
+        # sym 1
+        sym_list.append(get_fp32_residual())
+        name_list.append('sym1')
+        dshape_list.append((batch_size, 4, 10, 10))
+        lshape_list.append((batch_size, 10))
+
+        # sym 2
+        sym_list.append(get_fp32_sym_with_multiple_outputs(length))
+        name_list.append('sym2')
+        dshape_list.append((length, 4, 4, 10, 10))
+        lshape_list.append((length, 10))

-        for s, dshape, lshape in zip((sym, msym), (data_shape, msym_data_shape),
-                                     (label_shape, msym_label_shape)):
-            mod = Module(symbol=s)
-            mod.bind(data_shapes=[('data', dshape)], label_shapes=[('softmax_label', lshape)])
+        data = mx.sym.Variable('data')
+        # sym 3
+        sym_list.append(mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0'))
+        name_list.append('sym3')
+        dshape_list.append((batch_size, 4, 10, 10))
+        lshape_list.append(None)
+
+        # sym 4
+        cell = mx.rnn.LSTMCell(num_hidden=64)
+        outputs, _ = cell.unroll(length, data)
+        sym_list.append(mx.sym.Group(outputs))
+        name_list.append('sym4')
+        dshape_list.append((batch_size, length, 32))
+        lshape_list.append(None)
+
+        for s, dshape, lshape, name in zip(sym_list, dshape_list, lshape_list, name_list):
+            if qdtype == 'int8' and is_test_for_mkldnn() and name in ['sym1', 'sym2', 'sym3']:
+                print('skipped testing test_quantize_model_with_forward for mkldnn cpu int8 since it is not supported yet')
+                continue
+
+            if lshape is None:
+                mod = Module(symbol=s, label_names=None)
+                mod.bind(for_training=False,
+                         data_shapes=[('data', dshape)])
+            else:
+                mod = Module(symbol=s)
+                mod.bind(for_training=False,
+                         data_shapes=[('data', dshape)],
+                         label_shapes=[('softmax_label', lshape)])
             mod.init_params()
             arg_params, aux_params = mod.get_params()
-            excluded_names = []
-            if mx.current_context() == mx.cpu():
-                excluded_names += ['fc', 'conv1']
-            if mx.current_context() == mx.gpu():
-                excluded_names += ['sum0', 'relu0', 'relu1']
-            excluded_names += ['concat']
-
-            optional_names = ['pool0']
-            for skip_optional_names in [False, True]:
-                exclude_sym_names = []
-                if skip_optional_names:
-                    excluded_sym_names = excluded_names
-                else:
-                    excluded_sym_names = excluded_names + optional_names
-
-                qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
-                                                                                 arg_params=arg_params,
-                                                                                 aux_params=aux_params,
-                                                                                 excluded_sym_names=excluded_sym_names,
-                                                                                 ctx=mx.current_context(),
-                                                                                 quantized_dtype=qdtype,
-                                                                                 calib_mode='none')
-                check_params(arg_params, qarg_params, qsym)
-                check_params(aux_params, qaux_params)
-                check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)
-
-                calib_data = mx.nd.random.uniform(shape=dshape)
-                calib_data = NDArrayIter(data=calib_data, batch_size=batch_size)
-                calib_data = DummyIter(calib_data)
-                qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
-                                                                                 arg_params=arg_params,
-                                                                                 aux_params=aux_params,
-                                                                                 excluded_sym_names=excluded_sym_names,
-                                                                                 ctx=mx.current_context(),
-                                                                                 quantized_dtype=qdtype,
-                                                                                 calib_mode='naive',
-                                                                                 calib_data=calib_data,
-                                                                                 num_calib_examples=20)
-                check_params(arg_params, qarg_params, qsym)
-                check_params(aux_params, qaux_params)
-                check_qsym_calibrated(qsym)
-                check_qsym_qdtype(qsym, qdtype)
-                check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)
-
-    for qdtype in ['int8', 'uint8']:
-        check_quantize_model(qdtype)
-
-@with_seed()
-def test_quantize_conv_with_forward():
-    def check_quantize_model(qdtype):
-        if is_test_for_native_cpu():
-            print('skipped testing test_quantize_model_with_forward for native cpu since it is not supported yet')
-            return
-        elif qdtype == 'int8' and is_test_for_mkldnn():
-            print('skipped testing test_quantize_model_with_forward for mkldnn cpu int8 since it is not supported yet')
-            return
-        elif qdtype == 'uint8' and is_test_for_gpu():
-            print('skipped testing test_quantize_model_with_forward for gpu uint8 since it is not supported yet')
-            return
-        def check_params(params, qparams, qsym=None):
-            if qsym is None:
-                assert len(params) == len(qparams)
-                for k, v in params.items():
-                    assert k in qparams
-                    assert same(v.asnumpy(), qparams[k].asnumpy())
-            else:
-                qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params, th_dict = {})
-                assert len(qparams) == len(qparams_ground_truth)
-                for k, v in qparams_ground_truth.items():
-                    assert k in qparams
-                    assert same(v.asnumpy(), qparams[k].asnumpy())
+            excluded_sym_names = []
+            # sym3/sym4 don't have such layers
+            if name not in ['sym3', 'sym4']:
+                excluded_names = []
+                if mx.current_context() == mx.cpu():
+                    excluded_names += ['fc', 'conv1']
+                if mx.current_context() == mx.gpu():
+                    excluded_names += ['sum0', 'relu0', 'relu1']
+                excluded_names += ['concat']
+
+                optional_names = ['pool0']
+                for skip_optional_names in [False, True]:
+                    excluded_sym_names = []
+                    if skip_optional_names:
+                        excluded_sym_names = excluded_names
+                    else:
+                        excluded_sym_names = excluded_names + optional_names

-        def check_qsym_calibrated(qsym):
-            attrs = qsym.attr_dict()
-            for k, v in attrs.items():
-                if k.find('requantize_') != -1:
-                    assert 'min_calib_range' in v
-                    assert 'max_calib_range' in v
-
-        def check_qsym_qdtype(qsym, qdtype):
-            attrs = qsym.attr_dict()
-            for k, v in attrs.items():
-                if k.find('_quantize') != -1:
-                    assert 'out_type' in v
-                    assert v['out_type'] == qdtype
+            qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
+                                                                             arg_params=arg_params,
+                                                                             aux_params=aux_params,
+                                                                             excluded_sym_names=excluded_sym_names,
+                                                                             ctx=mx.current_context(),
+                                                                             quantized_dtype=qdtype,
+                                                                             calib_mode='none')
+            check_params(arg_params, qarg_params, qsym)
+            check_params(aux_params, qaux_params)
+            check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)

-        def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape):
-            mod = mx.mod.Module(symbol=qsym, label_names=None, context=mx.current_context())
-            mod.bind(for_training=False,
-                     data_shapes=[('data', data_shape)])
-            mod.set_params(qarg_params, qaux_params)
-            data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
-            batch = mx.io.DataBatch(data, [])
-            mod.forward(batch, is_train=False)
-            for output in mod.get_outputs():
-                output.wait_to_read()
+            calib_data = mx.nd.random.uniform(shape=dshape)
+            calib_data = NDArrayIter(data=calib_data, batch_size=batch_size)
+            calib_data = DummyIter(calib_data)
+            qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
+                                                                             arg_params=arg_params,
+                                                                             aux_params=aux_params,
+                                                                             excluded_sym_names=excluded_sym_names,
+                                                                             ctx=mx.current_context(),
+                                                                             quantized_dtype=qdtype,
+                                                                             calib_mode='naive',
+                                                                             calib_data=calib_data,
+                                                                             num_calib_examples=20)
+            check_params(arg_params, qarg_params, qsym)
+            check_params(aux_params, qaux_params)
+            check_qsym_calibrated(qsym)
+            check_qsym_qdtype(qsym, qdtype)
+            check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)

-        batch_size = 4
-        dshape = (batch_size, 4, 10, 10)
-        data = mx.sym.Variable('data')
-        sym = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0')
-
-        mod = Module(symbol=sym, label_names=None)
-        mod.bind(data_shapes=[('data', dshape)])
-
-        mod.init_params()
-        arg_params, aux_params = mod.get_params()
-        excluded_sym_names = []
-
-        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
-                                                                         arg_params=arg_params,
-                                                                         aux_params=aux_params,
-                                                                         excluded_sym_names=excluded_sym_names,
-                                                                         ctx=mx.current_context(),
-                                                                         quantized_dtype=qdtype,
-                                                                         calib_mode='none')
-        check_params(arg_params, qarg_params, qsym)
-        check_params(aux_params, qaux_params)
-        check_qsym_forward(qsym, qarg_params, qaux_params, dshape)
-
-        calib_data = mx.nd.random.uniform(shape=dshape)
-        calib_data = NDArrayIter(data=calib_data, batch_size=batch_size)
-        calib_data = DummyIter(calib_data)
-        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
-                                                                         arg_params=arg_params,
-                                                                         aux_params=aux_params,
-                                                                         excluded_sym_names=excluded_sym_names,
-                                                                         ctx=mx.current_context(),
-                                                                         quantized_dtype=qdtype,
-                                                                         calib_mode='naive',
-                                                                         calib_data=calib_data,
-                                                                         num_calib_examples=20)
-        check_params(arg_params, qarg_params, qsym)
-        check_params(aux_params, qaux_params)
-        check_qsym_calibrated(qsym)
-        check_qsym_qdtype(qsym, qdtype)
-        check_qsym_forward(qsym, qarg_params, qaux_params, dshape)
-
-    for qdtype in ['uint8', 'int8']:
+    for qdtype in ['int8', 'uint8']:
         check_quantize_model(qdtype)

+
 @with_seed()
 def test_quantize_sym_with_calib():
     sym = get_fp32_sym()
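
The shape functions changed in this patch copy the first dimension of an already-inferred output shape back into the `data` input, so shape inference can complete even when only the output side knows the batch dimension. A minimal sketch of driving a quantized module the same way the test above does, assuming `qsym`, `qarg_params` and `qaux_params` come from one of the `mx.contrib.quant.quantize_model(...)` calls and using an illustrative `(8, 4, 10, 10)` data shape (not part of the patch; names and shapes are assumptions):

    # Sketch only: qsym/qarg_params/qaux_params are assumed to come from
    # mx.contrib.quant.quantize_model(...) as in check_quantize_model above.
    import mxnet as mx

    new_data_shape = (8, 4, 10, 10)  # illustrative batch size, not the calibration batch size
    mod = mx.mod.Module(symbol=qsym, label_names=None, context=mx.current_context())
    mod.bind(for_training=False, data_shapes=[('data', new_data_shape)])
    mod.set_params(qarg_params, qaux_params)
    data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
    mod.forward(mx.io.DataBatch(data, []), is_train=False)
    for output in mod.get_outputs():
        output.wait_to_read()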