diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index 2e02de300e8f..2bc321832af6 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -32,6 +32,12 @@ namespace mxnet { namespace op { +bool SupportMKLDNNFC(const NDArray& input) { + int ndim = input.shape().ndim(); + return input.dtype() == mshadow::kFloat32 && (ndim >= 1 && ndim <= 4) && + input.storage_type() == kDefaultStorage; +} + static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape) { @@ -94,7 +100,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs, #if MXNET_USE_MKLDNN == 1 if (common::ContainsOnlyStorage(inputs, kDefaultStorage) && common::ContainsOnlyStorage(outputs, kDefaultStorage)) { - if (SupportMKLDNN(inputs[0])) { + if (SupportMKLDNNFC(inputs[0])) { MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); MKLDNNFCForward(attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(FullyConnectedCompute, attrs, ctx, inputs, req, @@ -141,7 +147,7 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - if (SupportMKLDNN(inputs[0])) { + if (SupportMKLDNNFC(inputs[0])) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); MKLDNNFCBackward(attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute, attrs, ctx, inputs, req, diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc index 36def0002073..39f8116379c2 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc @@ -50,17 +50,6 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs, NDArray data = in_data[fullc::kData]; NDArray weight = in_data[fullc::kWeight]; - const TShape &ishape = data.shape(); - - CHECK(data.dtype() == mshadow::kUint8) - << "MKLDNNQuantizedFullyConnected Op only supports uint8 for now, but got " - << mxnet::op::type_string(data.dtype()); - - if (ishape.ndim() != 2) { - CHECK(param.flatten) - << "QuantizedFullyConnected Op only supports flatten=true when ishape.ndim()!=2 for now."; - data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim()))); - } const float min_data = in_data[num_inputs + quantized_fc_enum::kDataMin].data().dptr()[0]; diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc index 742825c7a477..4718b3b673eb 100644 --- a/src/operator/quantization/quantized_fully_connected.cc +++ b/src/operator/quantization/quantized_fully_connected.cc @@ -50,11 +50,14 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs, CHECK(!shape_is_none(in_shape->at(0))) << "QuantizedFullyConnectedOp input data shape must be given"; const mxnet::TShape& dshape = in_shape->at(0); - mxnet::TShape wshape = Shape2(param.num_hidden, dshape.ProdShape(1, dshape.ndim())); - if (dshape.ndim() != 2) { - CHECK(param.flatten) - << "QuantizedFullyConnectedOp only supports flatten=true when ishape.ndim()!=2 for now. "; + index_t num_input; + if (!param.flatten) { + num_input = dshape[dshape.ndim() - 1]; + } else { + num_input = dshape.ProdShape(1, dshape.ndim()); } + + TShape wshape = Shape2(param.num_hidden, num_input); SHAPE_ASSIGN_CHECK(*in_shape, 1, wshape); if (!param.no_bias) { mxnet::TShape bshape = Shape1(param.num_hidden); @@ -65,7 +68,13 @@ bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs, SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape{1}); } - SHAPE_ASSIGN_CHECK(*out_shape, 0, mxnet::TShape({dshape[0], wshape[0]})); + if (!param.flatten) { + TShape result_shape(dshape); + result_shape[dshape.ndim() - 1] = param.num_hidden; + SHAPE_ASSIGN_CHECK(*out_shape, 0, result_shape); + } else { + SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param.num_hidden)); + } SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1})); SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1})); return true; @@ -80,9 +89,9 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_type->size(), 3U); #if MXNET_USE_MKLDNN == 1 - // TODO(ciyong): currently, only uint8 fully_connected is upported, - // int8 fully_connected will be supported after mkldnn v0.18 - TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kUint8); + CHECK(in_type->at(0) == mshadow::kInt8 || in_type->at(0) == mshadow::kUint8) + << "QuantizedFullyConnected only supports int8/uint8 input, while " + << in_type->at(0) << " is given."; #else TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8); #endif @@ -182,7 +191,8 @@ void QuantizedFullyConnectedForwardCPU(const nnvm::NodeAttrs& attrs, if (dshape.ndim() != 2) CHECK(param.flatten) - << "QuantizedFullyConnectedOp only supports flatten=true when input_shape!=2 for now. "; + << "QuantizedFullyConnectedForwardCPU only supports flatten=true " + << "when dshape.ndim() != 2 for now."; Tensor weight = in_data[fullc::kWeight].get(s); Tensor data = in_data[fullc::kData].get_with_shape( @@ -276,11 +286,6 @@ void QuantizedFullyConnectedForwardExCPU(const nnvm::NodeAttrs &attrs, const std::vector &in_data, const std::vector &req, const std::vector &out_data) { - if (in_data[fullc::kData].dtype() == mshadow::kInt8) { - FallBackCompute(QuantizedFullyConnectedForwardCPU, attrs, ctx, in_data, req, out_data); - return; - } - MKLDNNQuantizedFullyConnectedForward(attrs, ctx, in_data, req, out_data); } #endif diff --git a/src/operator/quantization/quantized_fully_connected.cu b/src/operator/quantization/quantized_fully_connected.cu index e8580e2e2c9d..d1cbdc98d535 100644 --- a/src/operator/quantization/quantized_fully_connected.cu +++ b/src/operator/quantization/quantized_fully_connected.cu @@ -75,6 +75,11 @@ void QuantizedFullyConnectedForwardGPU(const nnvm::NodeAttrs& attrs, mxnet::TShape oshape = out.shape_; // (m, n) * (k, n).T = (m, k) // A * B.T = C + if (dshape.ndim() != 2) { + CHECK(param.flatten) + << "Currently, QuantizedFullyConnected Op only supports flatten=true " + << "when ishape.ndim()!=2 for GPU."; + } // row_C = col_C(T) = cublas(col_B * col_A(T)) = cublas(row_B(T), row_A) // row_C = col_C(T) = cublas(col_B(T) * col_A(T)) = cublas(row_B, row_A) diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc index 94e2bda1e16c..8829404b9576 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc @@ -116,11 +116,6 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx, NDArray data = in_data[fullc::kData]; NDArray weight = in_data[fullc::kWeight]; NDArray output = out_data[fullc::kOut]; - const mxnet::TShape &ishape = data.shape(); - if (mkldnn_param.quantized && ishape.ndim() != 2) { - CHECK(default_param.flatten) - << "QuantizedFullyConnected only supports flatten=true when ishape.ndim() != 2 for now."; - } mkldnn::memory::desc out_md = GetMemDesc(output); MKLDNNFCFlattenData(default_param, out_data[fullc::kOut], &data, &out_md); @@ -307,9 +302,10 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs &attrs, if (full_param.mkldnn_param.quantized) { size_t base_num_inputs = full_param.default_param.no_bias ? 2 : 3; - // TODO(ciyong): currently, only uint8 fully_connected is upported, - // int8 fully_connected will be supported after mkldnn v0.18 - TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kUint8); + CHECK(in_types->at(0) == mshadow::kInt8 || + in_types->at(0) == mshadow::kUint8) + << "QuantizedFullyConnected only supports int8/uint8 input, while " + << in_types->at(0) << " is given."; for (size_t i = 1; i < in_types->size(); ++i) { if (i < base_num_inputs) { TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kInt8); diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index e6fe0011af19..871c1e3d566a 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -124,7 +124,6 @@ def check_quantize(sym, data_shape, out_type, name='conv', mod.bind(for_training=False, data_shapes=[('data', data_shape)], label_shapes=[('softmax_label', label_shape)]) - mod.init_params(mx.init.Normal(0.5)) arg_params, aux_params = mod.get_params() @@ -136,12 +135,10 @@ def check_quantize(sym, data_shape, out_type, name='conv', output.wait_to_read() ref_out = mod.get_outputs() - # TODO(ciyong), exclude the second fc due to int8 fully_connected is not - # supported before mkldnn 0.18 excluded_sym_names = [] - if mx.current_context() == mx.cpu(): + if mx.current_context() == mx.cpu() and gluon_forward == True: + excluded_sym_names += ['sg_mkldnn_fully_connected_0'] excluded_sym_names += ['fc_softmax'] - excluded_sym_names += ['sg_mkldnn_fully_connected_1'] calib_data = mx.nd.random.uniform(shape=data_shape) calib_data = NDArrayIter(data=calib_data) @@ -193,11 +190,7 @@ def check_fusion(sym, data_shape, attrs_op, name='conv', check_quantization=True assert_almost_equal(exe.outputs[i].asnumpy(), exe_sg.outputs[i].asnumpy(), rtol=1e-3, atol=1e-3) # fp32 to int8 - # TODO(ciyong), int8 fully_connected will be supported after mkldnn 0.18 - if name == 'fc': - out_type_list = ['uint8', 'auto'] - else: - out_type_list = ['uint8', 'int8', 'auto'] + out_type_list = ['uint8', 'int8', 'auto'] if check_quantization: for out_type in out_type_list: diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index e2457c7a4d50..eedc867ce8d3 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -288,9 +288,6 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True): if hasMKL == False: print('skipped testing quantized_fc on cpu since s8u8s32 is only supported by MKL BLAS library') return - elif qdtype == 'int8' and is_test_for_mkldnn(): - print('skipped testing test_quantized_fc for mkldnn cpu int8 since it is not supported yet') - return elif qdtype == 'uint8' and is_test_for_gpu(): print('skipped testing quantized_fc for gpu uint8 since it is not supported yet') return @@ -377,6 +374,11 @@ def maxabs(a, b): assert cond == 0 for qdtype in ['int8', 'uint8']: + if is_test_for_mkldnn(): + check_quantized_fc((32, 512, 2), 100, True, qdtype, flatten=False) + check_quantized_fc((32, 512, 2), 100, False, qdtype, flatten=False) + check_quantized_fc((32, 512, 2, 2), 100, True, qdtype, flatten=False) + check_quantized_fc((32, 512, 2, 2), 100, False, qdtype, flatten=False) check_quantized_fc((32, 512, 2, 2), 100, True, qdtype) check_quantized_fc((32, 111, 2, 2), 100, True, qdtype) check_quantized_fc((32, 512, 2, 2), 100, False, qdtype)