From b7d83241191232fe7b4ee960daa1bded18acd725 Mon Sep 17 00:00:00 2001
From: Ciyong Chen
Date: Sat, 2 Mar 2019 16:56:17 +0800
Subject: [PATCH] change to use mxnet::Tuple and update tests

---
 src/operator/subgraph/mkldnn/mkldnn_fc.cc | 14 +++++++-------
 tests/python/mkl/test_subgraph.py         |  4 ++--
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
index cfb5df32793c..885781e0beb2 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -109,7 +109,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
   NDArray data = in_data[fullc::kData];
   NDArray weight = in_data[fullc::kWeight];
   NDArray output = out_data[fullc::kOut];
-  const TShape &ishape = data.shape();
+  const mxnet::TShape &ishape = data.shape();
   if (mkldnn_param.quantized && ishape.ndim() != 2) {
     CHECK(default_param.flatten)
       << "QuantizedFullyConnected only supports flatten=true when ishape.ndim() != 2 for now.";
@@ -265,12 +265,12 @@ static inline void FillBaseInputOutputInfo(const FullyConnectedParam &param,
 }
 
 static bool SgMKLDNNFCInferShape(const nnvm::NodeAttrs &attrs,
-                                 std::vector<TShape> *in_shapes,
-                                 std::vector<TShape> *out_shapes) {
+                                 mxnet::ShapeVector *in_shapes,
+                                 mxnet::ShapeVector *out_shapes) {
   auto const &full_param = nnvm::get<MKLDNNFCFullParam>(attrs.parsed);
   if (full_param.mkldnn_param.quantized) {
-    std::vector<TShape> base_in_shapes;
-    std::vector<TShape> base_out_shapes;
+    mxnet::ShapeVector base_in_shapes;
+    mxnet::ShapeVector base_out_shapes;
     FillBaseInputOutputInfo(full_param.default_param, &base_in_shapes, &base_out_shapes,
                             in_shapes, out_shapes);
     bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes);
@@ -368,7 +368,7 @@ static bool SgMKLDNNFCStorageType(const nnvm::NodeAttrs &attrs,
 
 static OpStatePtr CreateSgMKLDNNFCState(const nnvm::NodeAttrs &attrs,
                                         Context ctx,
-                                        const std::vector<TShape> &in_shapes,
+                                        const mxnet::ShapeVector &in_shapes,
                                         const std::vector<int> &in_types) {
   return OpStatePtr::Create<SgMKLDNNFCOp>(attrs);
 }
@@ -414,7 +414,7 @@ NNVM_REGISTER_OP(_sg_mkldnn_fully_connected)
.set_attr_parser(SgMKLDNNFCParamParser)
.set_attr<nnvm::FListInputNames>("FListInputNames", SgMKLDNNFCListInputNames)
.set_attr<nnvm::FListOutputNames>("FListOutputNames", SgMKLDNNFCListOutputNames)
-.set_attr<nnvm::FInferShape>("FInferShape", SgMKLDNNFCInferShape)
+.set_attr<mxnet::FInferShape>("FInferShape", SgMKLDNNFCInferShape)
.set_attr<nnvm::FInferType>("FInferType", SgMKLDNNFCInferType)
.set_attr<FInferStorageType>("FInferStorageType", SgMKLDNNFCStorageType)
.set_attr<FCreateOpState>("FCreateOpState", CreateSgMKLDNNFCState)
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index d0579de1df39..469bd3895384 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -62,7 +62,7 @@ def check_qsym_calibrated(qsym, out_type, name='conv'):
     if k.find('_quantize') != -1:
       assert v['out_type'] == out_type
     if k.find(quantized_op_name) != -1:
-      if name == 'fc' and 'fuse_dequantize' in v:
+      if name == 'fc' and 'enable_float_output' in v:
        continue
       assert 'min_calib_range' in v
       assert 'max_calib_range' in v
@@ -155,7 +155,7 @@ def check_fusion(sym, data_shape, attrs_op, name='conv', check_quantization=True
   for k, v in sym_sg.attr_dict().items():
     if k.find(op_name) != -1:
       for attr_op in attrs_op:
-        assert v[attr_op] == 'true'
+        assert v[attr_op] in ['true', 'True']
 
   arg_shapes, _, aux_shapes = sym.infer_shape()
   arg_array = [mx.nd.random.uniform(-1, 1, shape=shape) for shape in arg_shapes]
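
Note on the relaxed assertion in check_fusion (not part of the patch): Symbol.attr_dict() returns operator attributes as strings, and a boolean parameter may be serialized as either 'true' or 'True' depending on which code path set it. The helper below is a minimal, hypothetical sketch of that check; assert_attr_true and the sample attribute names are illustrative and do not appear in test_subgraph.py.

# Hypothetical helper mirroring the relaxed check in check_fusion.
def assert_attr_true(attrs, key):
    """Assert a serialized boolean attribute is set, accepting either
    capitalization ('true' or 'True')."""
    value = attrs.get(key)
    assert value in ('true', 'True'), \
        "expected %s to be true, got %r" % (key, value)

# Example usage with stand-in attribute dictionaries:
assert_attr_true({'with_relu': 'True'}, 'with_relu')
assert_attr_true({'quantized': 'true'}, 'quantized')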