diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index b1e3bb67ad79..fcd0fb4218be 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -222,7 +222,7 @@ Graph QuantizeGraph(Graph &&src) { // skip non-quantized input continue; } - if (quantized_op_map.count(e.node->op())) { + if (NeedQuantize(e.node, excluded_nodes)) { // here we calculate the output number (exclude min/max, in order to // calculate min/max index from mirror node) based on assumption that // there is only 1min and 1max output from mirror node (which is @@ -314,7 +314,8 @@ Graph QuantizeGraph(Graph &&src) { std::vector<NodeEntry> outputs; for (const auto& e : src.outputs) { - if (quantized_op_map.count(e.node->op())) { + if (NeedQuantize(e.node, excluded_nodes)) { + // Only insert dequantize for ops that support quantization and are not excluded. NodePtr mirror_node = mirror_map.at(e.node.get()); NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version}; size_t num_inputs = e.node->num_inputs(); diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index ca8070cfc224..518b69626246 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -406,12 +406,16 @@ def get_fp32_sym(): def get_fp32_residual(): data = mx.sym.Variable('data') - conv = mx.sym.Convolution(data=data, num_filter=4, kernel=(1,1), pad=(0,0), - no_bias=True, name='conv') - bn = mx.sym.BatchNorm(data=conv, fix_gamma=False, eps=2e-5, momentum=0.9, name='bn') - act = mx.sym.Activation(data=bn + data, act_type='relu', name='relu') - pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool') - fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc') + conv0 = mx.sym.Convolution(data=data, num_filter=4, kernel=(1,1), pad=(0,0), + no_bias=True, name='conv0') + bn = 
mx.sym.BatchNorm(data=conv0, fix_gamma=False, eps=2e-5, momentum=0.9, name='bn') + act0 = mx.sym.Activation(data=bn + data, act_type='relu', name='relu0') + pool0 = mx.sym.Pooling(act0, kernel=(4, 4), pool_type='avg', name='pool0') + conv1 = mx.sym.Convolution(data=pool0, num_filter=4, kernel=(1,1), pad=(0,0), + no_bias=False, name='conv1') + act1 = mx.sym.Activation(data=conv1, act_type='relu', name='relu1') + pool1 = mx.sym.Pooling(act1, kernel=(4, 4), pool_type='avg', name='pool1') + fc = mx.sym.FullyConnected(pool1, num_hidden=10, flatten=True, name='fc') sym = mx.sym.SoftmaxOutput(fc, grad_scale=1, ignore_label=-1, multi_output=False, out_grad=False, preserve_shape=False, use_ignore=False, name='softmax') return sym @@ -574,38 +578,46 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape): mod.init_params() arg_params, aux_params = mod.get_params() - excluded_sym_names = [] + excluded_names = [] if mx.current_context() == mx.cpu(): - excluded_sym_names += ['fc'] - excluded_sym_names += ['concat'] - qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s, - arg_params=arg_params, - aux_params=aux_params, - excluded_sym_names=excluded_sym_names, - ctx=mx.current_context(), - quantized_dtype=qdtype, - calib_mode='none') - check_params(arg_params, qarg_params, qsym) - check_params(aux_params, qaux_params) - check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape) - - calib_data = mx.nd.random.uniform(shape=dshape) - calib_data = NDArrayIter(data=calib_data, batch_size=batch_size) - calib_data = DummyIter(calib_data) - qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s, - arg_params=arg_params, - aux_params=aux_params, - excluded_sym_names=excluded_sym_names, - ctx=mx.current_context(), - quantized_dtype=qdtype, - calib_mode='naive', - calib_data=calib_data, - num_calib_examples=20) - check_params(arg_params, qarg_params, qsym) - check_params(aux_params, qaux_params) - 
check_qsym_calibrated(qsym) - check_qsym_qdtype(qsym, qdtype) - check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape) + excluded_names += ['fc'] + excluded_names += ['concat'] + + optional_names = ['pool0'] + for skip_optional_names in [False, True]: + if skip_optional_names: + excluded_sym_names = excluded_names + else: + excluded_sym_names = excluded_names + optional_names + + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s, + arg_params=arg_params, + aux_params=aux_params, + excluded_sym_names=excluded_sym_names, + ctx=mx.current_context(), + quantized_dtype=qdtype, + calib_mode='none') + check_params(arg_params, qarg_params, qsym) + check_params(aux_params, qaux_params) + check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape) + + calib_data = mx.nd.random.uniform(shape=dshape) + calib_data = NDArrayIter(data=calib_data, batch_size=batch_size) + calib_data = DummyIter(calib_data) + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s, + arg_params=arg_params, + aux_params=aux_params, + excluded_sym_names=excluded_sym_names, + ctx=mx.current_context(), + quantized_dtype=qdtype, + calib_mode='naive', + calib_data=calib_data, + num_calib_examples=20) + check_params(arg_params, qarg_params, qsym) + check_params(aux_params, qaux_params) + check_qsym_calibrated(qsym) + check_qsym_qdtype(qsym, qdtype) + check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape) for qdtype in ['int8', 'uint8']: check_quantize_model(qdtype)