
Commit e36f888

fix quantize pass error when the quantization supported Op are excluded in the model (#13596)

ciyongch authored and reminisce committed Dec 12, 2018
1 parent 002e0bb commit e36f888
Showing 2 changed files with 53 additions and 39 deletions.
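
For context: the quantization pass previously decided whether a node was quantized by checking only whether its op has a registered quantized version, ignoring the user-supplied exclusion list, so excluding an op that does support quantization made the pass fail. The sketch below shows the kind of call this commit fixes. It is illustrative only — the toy network, layer names, and shapes are not from the commit — and assumes an MXNet build from around this time with the contrib quantization API.

import mxnet as mx

# Hypothetical toy network; layer names are illustrative.
data = mx.sym.Variable('data')
conv = mx.sym.Convolution(data=data, num_filter=4, kernel=(1, 1), no_bias=True, name='conv0')
act = mx.sym.Activation(data=conv, act_type='relu', name='relu0')
pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool0')
fc = mx.sym.FullyConnected(pool, num_hidden=10, name='fc')
sym = mx.sym.SoftmaxOutput(fc, name='softmax')

mod = mx.mod.Module(symbol=sym, label_names=['softmax_label'])
mod.bind(data_shapes=[('data', (4, 4, 10, 10))], label_shapes=[('softmax_label', (4,))])
mod.init_params()
arg_params, aux_params = mod.get_params()

# 'pool0' and 'fc' both have quantized implementations but are excluded here;
# graphs like this triggered the error that this commit fixes.
qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
    sym=sym, arg_params=arg_params, aux_params=aux_params,
    excluded_sym_names=['pool0', 'fc'],
    ctx=mx.cpu(), quantized_dtype='int8', calib_mode='none')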
5 changes: 3 additions & 2 deletions src/operator/quantization/quantize_graph_pass.cc
@@ -222,7 +222,7 @@ Graph QuantizeGraph(Graph &&src) {
           // skip non-quantized input
           continue;
         }
-        if (quantized_op_map.count(e.node->op())) {
+        if (NeedQuantize(e.node, excluded_nodes)) {
           // here we calculate the output number (exclude min/max, in order to
           // calculate min/max index from mirror node) based on assumption that
           // there is only 1 min and 1 max output from mirror node (which is
@@ -314,7 +314,8 @@ Graph QuantizeGraph(Graph &&src) {

   std::vector<NodeEntry> outputs;
   for (const auto& e : src.outputs) {
-    if (quantized_op_map.count(e.node->op())) {
+    if (NeedQuantize(e.node, excluded_nodes)) {
+      // Only insert dequantize for those Ops supports quantize and not excluded.
       NodePtr mirror_node = mirror_map.at(e.node.get());
       NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version};
       size_t num_inputs = e.node->num_inputs();
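
Both hunks replace a bare lookup in the quantized-op registry (quantized_op_map.count(e.node->op())) with the NeedQuantize helper, which additionally honours the user's exclusion list. A rough Python rendering of the predicate, with hypothetical names chosen to mirror the C++ identifiers — this is a sketch of the logic, not MXNet source:

def need_quantize(node_name, op_name, quantized_op_map, excluded_nodes):
    # A node is treated as quantized only if its op has a registered quantized
    # version AND the node was not listed in excluded_sym_names.
    # The old check only tested the first condition.
    return op_name in quantized_op_map and node_name not in excluded_nodes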
87 changes: 50 additions & 37 deletions tests/python/quantization/test_quantization.py
@@ -406,12 +406,16 @@ def get_fp32_sym():

 def get_fp32_residual():
     data = mx.sym.Variable('data')
-    conv = mx.sym.Convolution(data=data, num_filter=4, kernel=(1,1), pad=(0,0),
-                              no_bias=True, name='conv')
-    bn = mx.sym.BatchNorm(data=conv, fix_gamma=False, eps=2e-5, momentum=0.9, name='bn')
-    act = mx.sym.Activation(data=bn + data, act_type='relu', name='relu')
-    pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool')
-    fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc')
+    conv0 = mx.sym.Convolution(data=data, num_filter=4, kernel=(1,1), pad=(0,0),
+                               no_bias=True, name='conv0')
+    bn = mx.sym.BatchNorm(data=conv0, fix_gamma=False, eps=2e-5, momentum=0.9, name='bn')
+    act0 = mx.sym.Activation(data=bn + data, act_type='relu', name='relu0')
+    pool0 = mx.sym.Pooling(act0, kernel=(4, 4), pool_type='avg', name='pool0')
+    conv1 = mx.sym.Convolution(data=pool0, num_filter=4, kernel=(1,1), pad=(0,0),
+                               no_bias=False, name='conv1')
+    act1 = mx.sym.Activation(data=conv1, act_type='relu', name='relu1')
+    pool1 = mx.sym.Pooling(act1, kernel=(4, 4), pool_type='avg', name='pool1')
+    fc = mx.sym.FullyConnected(pool1, num_hidden=10, flatten=True, name='fc')
     sym = mx.sym.SoftmaxOutput(fc, grad_scale=1, ignore_label=-1, multi_output=False,
                                out_grad=False, preserve_shape=False, use_ignore=False, name='softmax')
     return sym
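
The residual test network gains a second convolution/relu/pool stage after the shortcut, so the test below can optionally drop 'pool0' from quantization (see optional_names in the next hunk) while conv1, pool1 and fc still follow it in the graph. A quick sanity check of the node names the test relies on; this snippet is not part of the commit and assumes the get_fp32_residual definition above has been pasted into the session:

import json
import mxnet as mx  # required by get_fp32_residual

sym = get_fp32_residual()
op_nodes = [n['name'] for n in json.loads(sym.tojson())['nodes'] if n['op'] != 'null']
print(op_nodes)  # expected to include conv0, bn, relu0, pool0, conv1, relu1, pool1, fc, softmax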
@@ -574,38 +578,47 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape):

         mod.init_params()
         arg_params, aux_params = mod.get_params()
-        excluded_sym_names = []
+        excluded_names = []
         if mx.current_context() == mx.cpu():
-            excluded_sym_names += ['fc']
-        excluded_sym_names += ['concat']
-        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
-                arg_params=arg_params,
-                aux_params=aux_params,
-                excluded_sym_names=excluded_sym_names,
-                ctx=mx.current_context(),
-                quantized_dtype=qdtype,
-                calib_mode='none')
-        check_params(arg_params, qarg_params, qsym)
-        check_params(aux_params, qaux_params)
-        check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)
-
-        calib_data = mx.nd.random.uniform(shape=dshape)
-        calib_data = NDArrayIter(data=calib_data, batch_size=batch_size)
-        calib_data = DummyIter(calib_data)
-        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
-                arg_params=arg_params,
-                aux_params=aux_params,
-                excluded_sym_names=excluded_sym_names,
-                ctx=mx.current_context(),
-                quantized_dtype=qdtype,
-                calib_mode='naive',
-                calib_data=calib_data,
-                num_calib_examples=20)
-        check_params(arg_params, qarg_params, qsym)
-        check_params(aux_params, qaux_params)
-        check_qsym_calibrated(qsym)
-        check_qsym_qdtype(qsym, qdtype)
-        check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)
+            excluded_names += ['fc']
+        excluded_names += ['concat']
+
+        optional_names = ['pool0']
+        for skip_optional_names in [False, True]:
+            exclude_sym_names = []
+            if skip_optional_names:
+                excluded_sym_names = excluded_names
+            else:
+                excluded_sym_names = excluded_names + optional_names
+
+            qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
+                    arg_params=arg_params,
+                    aux_params=aux_params,
+                    excluded_sym_names=excluded_sym_names,
+                    ctx=mx.current_context(),
+                    quantized_dtype=qdtype,
+                    calib_mode='none')
+            check_params(arg_params, qarg_params, qsym)
+            check_params(aux_params, qaux_params)
+            check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)
+
+            calib_data = mx.nd.random.uniform(shape=dshape)
+            calib_data = NDArrayIter(data=calib_data, batch_size=batch_size)
+            calib_data = DummyIter(calib_data)
+            qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s,
+                    arg_params=arg_params,
+                    aux_params=aux_params,
+                    excluded_sym_names=excluded_sym_names,
+                    ctx=mx.current_context(),
+                    quantized_dtype=qdtype,
+                    calib_mode='naive',
+                    calib_data=calib_data,
+                    num_calib_examples=20)
+            check_params(arg_params, qarg_params, qsym)
+            check_params(aux_params, qaux_params)
+            check_qsym_calibrated(qsym)
+            check_qsym_qdtype(qsym, qdtype)
+            check_qsym_forward(qsym, qarg_params, qaux_params, dshape, lshape)

     for qdtype in ['int8', 'uint8']:
         check_quantize_model(qdtype)
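
With this change the whole quantize / calibrate / forward sequence runs twice: once excluding only the mandatory names and once also excluding 'pool0', an op that does support quantization, which is exactly the case the pass used to trip over. A standalone illustration of the two exclusion lists the loop produces on CPU (names copied from the hunk above; output shown as comments):

excluded_names = ['fc', 'concat']
optional_names = ['pool0']
for skip_optional_names in [False, True]:
    excluded_sym_names = excluded_names if skip_optional_names else excluded_names + optional_names
    print(skip_optional_names, excluded_sym_names)
# False ['fc', 'concat', 'pool0']
# True ['fc', 'concat']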
