Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
API change
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhennanQin committed Jan 16, 2019
1 parent 208beaa commit 5695bad
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 40 deletions.
3 changes: 1 addition & 2 deletions example/quantization/imagenet_gen_qsym_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
calib_mode=calib_mode, calib_data=data,
num_calib_examples=num_calib_batches * batch_size,
calib_layer=calib_layer, quantized_dtype=args.quantized_dtype,
label_names=(label_name,), calib_quantize_op = True,
logger=logger)
label_names=(label_name,), logger=logger)
if calib_mode == 'entropy':
suffix = '-quantized-%dbatches-entropy' % num_calib_batches
elif calib_mode == 'naive':
Expand Down
1 change: 0 additions & 1 deletion example/ssd/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ def save_params(fname, arg_params, aux_params, logger=None):
num_calib_examples=num_calib_batches * batch_size,
calib_layer=calib_layer, quantized_dtype=args.quantized_dtype,
label_names=(label_name,),
calib_quantize_op = True,
logger=logger)
sym_name = '%s-symbol.json' % ('./model/cqssd_vgg16_reduced_300')
param_name = '%s-%04d.params' % ('./model/cqssd_vgg16_reduced_300', epoch)
Expand Down
3 changes: 1 addition & 2 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1556,13 +1556,12 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
* \param num_offline number of parameters that are quantized offline
* \param offline_params array of c strings representing the names of params quantized offline
* \param quantized_dtype the quantized destination type for input data.
* \param calib_quantize whether calibrate quantize op with offline calibration data.
*/
MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle,
const mx_uint num_excluded_symbols,
const char **excluded_symbols,
const mx_uint num_offline, const char **offline_params,
const char *quantized_dtype, const bool calib_quantize);
const char *quantized_dtype);

/*!
* \brief Set calibration table to node attributes in the sym
Expand Down
10 changes: 7 additions & 3 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -694,9 +694,13 @@ class NDArray {
/*
* Create NDArray from mkldnn memory.
* mkldnn_mem The mkldnn memory to be managed.
* static_data If true, mkldnn memory won't be freed on destruction.
*/
explicit NDArray(const mkldnn::memory *mkldnn_mem, bool static_data = true);
explicit NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem);
/*
* Create NDArray from mkldnn memory descriptor.
* mem_pd The mkldnn memory descriptor to be created.
*/
explicit NDArray(mkldnn::memory::primitive_desc mem_pd);
/*
* Test if the data is stored in one of special MKLDNN format.
*/
Expand Down Expand Up @@ -776,7 +780,7 @@ class NDArray {
/*!
* \ Fix mkldnn memory descriptor mismatch from NDArray.
*/
void UpdateMKLDNNMemDesc();
void UpdateMKLDNNMemDesc(mkldnn::memory::format format);
#endif

/*!
Expand Down
14 changes: 4 additions & 10 deletions python/mxnet/contrib/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _quantize_params(qsym, params, th_dict):
return quantized_params

def _quantize_symbol(sym, excluded_symbols=None, offline_params=None,
quantized_dtype='int8', calib_quantize_op=False):
quantized_dtype='int8'):
"""Given a symbol object representing a neural network of data type FP32,
quantize it into a INT8 network.
Expand All @@ -98,8 +98,6 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None,
avoided.
quantized_dtype: str
The quantized destination type for input data.
calib_quantize_op : bool
Whether perform offline calibration for quantize op.
"""
num_excluded_symbols = 0
if excluded_symbols is not None:
Expand All @@ -122,8 +120,7 @@ def _quantize_symbol(sym, excluded_symbols=None, offline_params=None,
c_str_array(excluded_symbols),
mx_uint(num_offline),
c_array(ctypes.c_char_p, offline),
c_str(quantized_dtype),
ctypes.c_bool(calib_quantize_op)))
c_str(quantized_dtype)))
return Symbol(out)


Expand Down Expand Up @@ -424,7 +421,7 @@ def quantize_model(sym, arg_params, aux_params,
data_names=('data',), label_names=('softmax_label',),
ctx=cpu(), excluded_sym_names=None, calib_mode='entropy',
calib_data=None, num_calib_examples=None, calib_layer=None,
quantized_dtype='int8', calib_quantize_op=False, logger=logging):
quantized_dtype='int8', logger=logging):
"""User-level API for generating a quantized model from a FP32 model w/ or w/o calibration.
The backend quantized operators are only enabled for Linux systems. Please do not run
inference using the quantized models on Windows for now.
Expand Down Expand Up @@ -477,8 +474,6 @@ def quantize_model(sym, arg_params, aux_params,
quantized_dtype : str
The quantized destination type for input data. Currently support 'int8'
and 'uint8', default value is 'int8'.
calib_quantize_op: bool
Whether calibrate quantize op with its input calibration data. The quantize op's input should be in calib_layer
logger : Object
A logging object for printing information during the process of quantization.
Expand All @@ -501,8 +496,7 @@ def quantize_model(sym, arg_params, aux_params,
' expected `int8` or `uint8`' % quantized_dtype)
qsym = _quantize_symbol(sym, excluded_symbols=excluded_sym_names,
offline_params=list(arg_params.keys()),
quantized_dtype=quantized_dtype,
calib_quantize_op=calib_quantize_op)
quantized_dtype=quantized_dtype)

th_dict = {}
if calib_mode is not None and calib_mode != 'none':
Expand Down
5 changes: 2 additions & 3 deletions src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -650,8 +650,7 @@ int MXQuantizeSymbol(SymbolHandle sym_handle,
const char **excluded_op_names,
const mx_uint num_offline,
const char **offline_params,
const char *quantized_dtype,
const bool calib_quantize) {
const char *quantized_dtype) {
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(sym_handle);
Expand All @@ -668,7 +667,7 @@ int MXQuantizeSymbol(SymbolHandle sym_handle,
g.attrs["excluded_nodes"] = std::make_shared<nnvm::any>(std::move(excluded_node_names));
g.attrs["offline_params"] = std::make_shared<nnvm::any>(std::move(offline));
g.attrs["quantized_dtype"] = std::make_shared<nnvm::any>(std::move(quantized_type));
g.attrs["calib_quantize"] = std::make_shared<nnvm::any>(calib_quantize);
g.attrs["calib_quantize"] = std::make_shared<nnvm::any>(true);
g = ApplyPass(std::move(g), "QuantizeGraph");
s->outputs = g.outputs;
*ret_sym_handle = s;
Expand Down
39 changes: 24 additions & 15 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,28 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {

#if MXNET_USE_MKLDNN == 1

NDArray::NDArray(const mkldnn::memory *mkldnn_mem, bool static_data)
NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd)
: storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
auto mem_pd = mkldnn_mem->get_primitive_desc();
auto mem_desc = mem_pd.desc();
shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
dtype_ = get_mxnet_type(mem_desc.data.data_type);
auto data = TBlob(mkldnn_mem->get_data_handle(), shape_, cpu::kDevMask, dtype_);
ptr_ = std::make_shared<Chunk>(data, 0);
ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
ptr_->CheckAndAlloc(mem_pd.get_size());
ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mem_pd, ptr_->shandle.dptr);
ptr_->static_data = static_data;
}

NDArray::NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem)
: storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
auto mem_pd = mkldnn_mem->get_primitive_desc();
auto mem_desc = mem_pd.desc();
shape_ = TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
dtype_ = get_mxnet_type(mem_desc.data.data_type);
ptr_ = std::make_shared<Chunk>(shape_, Context::CPU(), true, dtype_);
ptr_->shandle.dptr = mkldnn_mem->get_data_handle();
ptr_->shandle.size = mem_pd.get_size();
ptr_->delay_alloc = false;
ptr_->mkl_mem_ = std::make_shared<MKLDNNMemory>(mkldnn_mem);
ptr_->static_data = true;
}

NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const {
Expand Down Expand Up @@ -717,19 +729,16 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &
return ptr_->mkl_mem_->GetRaw();
}

void NDArray::UpdateMKLDNNMemDesc() {
void NDArray::UpdateMKLDNNMemDesc(mkldnn::memory::format format) {
const mkldnn::memory *mem = GetMKLDNNData();
auto mem_desc = mem->get_primitive_desc().desc();
auto this_dtype = get_mkldnn_type(dtype());
if (this_dtype != mem_desc.data.data_type) {
mkldnn::memory::desc data_md(
mkldnn::memory::dims(mem_desc.data.dims,
mem_desc.data.dims + mem_desc.data.ndims),
this_dtype, static_cast<mkldnn::memory::format>(mem_desc.data.format));
mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr));
MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
}
mkldnn::memory::desc data_md(
mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims),
this_dtype, format);
mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
ptr_->mkl_mem_.reset(new MKLDNNMemory(pd, ptr_->shandle.dptr));
MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
}
#endif

Expand Down
12 changes: 9 additions & 3 deletions src/operator/subgraph/mkldnn/mkldnn_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,12 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
}
if (!inplace_) {
auto in_mkl_mem = inputs[in_sum].GetMKLDNNData();
const_cast<NDArray &>(outputs[kOut]).CopyFrom(*in_mkl_mem);
output = NDArray(outputs[kOut].GetMKLDNNData());
auto out_mkl_mem = outputs[kOut].GetMKLDNNData();
mkldnn_mem_ptr tmp_mem(
new mkldnn::memory(in_mkl_mem->get_primitive_desc(), out_mkl_mem->get_data_handle()));
MKLDNNStream::Get()->RegisterMem(tmp_mem);
mxnet::MKLDNNCopy(*in_mkl_mem, tmp_mem.get());
output = NDArray(tmp_mem);
}
}

Expand Down Expand Up @@ -388,7 +392,9 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,

if (mkldnn_param.with_sum) {
auto out = const_cast<NDArray &>(outputs[kOut]);
out.UpdateMKLDNNMemDesc();
auto format = static_cast<mkldnn::memory::format>(
fwd_->fwd_pd.dst_primitive_desc().desc().data.format);
out.UpdateMKLDNNMemDesc(format);
}
}

Expand Down
1 change: 0 additions & 1 deletion tests/python/mkl/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def check_quantize(sym, data_shape, check_conv=True):
calib_mode='naive',
calib_data=calib_data,
calib_layer=calib_layer,
calib_quantize_op=True,
num_calib_examples=5)
qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE")
if check_conv:
Expand Down

0 comments on commit 5695bad

Please sign in to comment.