From b49d73236406df21b8ee882ac471afe875d3943b Mon Sep 17 00:00:00 2001 From: shufan wu Date: Sun, 21 Apr 2019 12:03:37 +0800 Subject: [PATCH] Revert "Add reshape op supported by MKL-DNN (#12980)" This reverts commit c4a619c6808df82e65081bf55e815c7a7ec55566. Conflicts: src/operator/tensor/matrix_op.cc tests/python/mkl/test_mkldnn.py --- src/operator/tensor/matrix_op.cc | 2 + tests/python/mkl/test_mkldnn.py | 443 +++++++++++++++++++++++++++++++ 2 files changed, 445 insertions(+) create mode 100644 tests/python/mkl/test_mkldnn.py diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 29d493ae5a54..904f7bd85dc4 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -101,6 +101,8 @@ DMLC_REGISTER_PARAMETER(TileParam); DMLC_REGISTER_PARAMETER(ReverseParam); DMLC_REGISTER_PARAMETER(StackParam); DMLC_REGISTER_PARAMETER(SqueezeParam); +DMLC_REGISTER_PARAMETER(DepthToSpaceParam); +DMLC_REGISTER_PARAMETER(SplitParam); NNVM_REGISTER_OP(Reshape) .add_alias("reshape") diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py new file mode 100644 index 000000000000..d9d3abfc3ced --- /dev/null +++ b/tests/python/mkl/test_mkldnn.py @@ -0,0 +1,443 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
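+#
+# NOTE: these tests assume an MXNet build with MKL-DNN support enabled
+# (e.g. compiled with USE_MKLDNN=1); on a build without MKL-DNN they fall
+# back to the default CPU implementations.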
+
+"""
+MKL-DNN related test cases
+"""
+import sys
+import os
+import numpy as np
+import mxnet as mx
+import unittest
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet.test_utils import *
+curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+sys.path.append(os.path.join(curr_path, '../unittest/'))
+from common import with_seed
+
+
+def test_mkldnn_model():
+    # Load a serialized symbol and run two forward/backward passes through it.
+    model = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data",
+                         "test_mkldnn_test_mkldnn_model_model1.json")
+    shape = (32, 3, 300, 300)
+    ctx = mx.cpu()
+
+    sym = mx.sym.load(model)
+    args = sym.list_arguments()
+    shapes = sym.infer_shape(data=shape)
+
+    def get_tensors(args, shapes, ctx):
+        return {x: mx.nd.ones(y, ctx) for x, y in zip(args, shapes)}
+
+    inputs = get_tensors(args, shapes[0], ctx)
+    grads = get_tensors(args, shapes[0], ctx)
+
+    try:
+        exe = sym.bind(ctx, inputs, args_grad=grads)
+        for _ in range(2):
+            exe.forward(is_train=True)
+            for y in exe.outputs:
+                y.wait_to_read()
+            exe.backward()
+            for y in exe.grad_arrays:
+                y.wait_to_read()
+    except:  # pylint: disable=bare-except
+        assert 0, "test_mkldnn_model exception in bind and execution"
+
+def test_mkldnn_ndarray_slice():
+    ctx = mx.cpu()
+    net = gluon.nn.HybridSequential()
+    with net.name_scope():
+        net.add(gluon.nn.Conv2D(channels=32, kernel_size=3, activation=None))
+    net.collect_params().initialize(ctx=ctx)
+    x = mx.nd.array(np.ones([32, 3, 224, 224]), ctx)
+    y = net(x)
+
+    # trigger computation on an ndarray slice
+    assert_almost_equal(y[0].asnumpy()[0, 0, 0], 0.3376348)
+
+def test_mkldnn_engine_threading():
+    net = gluon.nn.HybridSequential()
+    with net.name_scope():
+        net.add(gluon.nn.Conv2D(channels=32, kernel_size=3, activation=None))
+    net.collect_params().initialize(ctx=mx.cpu())
+    class Dummy(gluon.data.Dataset):
+        def __len__(self):
+            return 2
+        def __getitem__(self, key):
+            return key, np.ones((3, 224, 224)), np.ones((10,))
+
+    loader = gluon.data.DataLoader(Dummy(), batch_size=2, num_workers=1)
+
+    X = (32, 3, 32, 32)
+    # trigger the mkldnn execution thread
+    y = net(mx.nd.array(np.ones(X))).asnumpy()
+
+    # Use a Gluon DataLoader worker so the next forward runs on a different thread.
+    for _ in loader:
+        y = net(mx.nd.array(np.ones(X))).asnumpy()
+        # output should be 0.016711406 (non-MKLDNN mode output)
+        assert_almost_equal(y[0, 0, 0, 0], 0.016711406)
+        break
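+
+
+# A standalone sketch of the same idea as test_mkldnn_engine_threading: run a
+# forward pass from a second thread so the MKL-DNN engine is exercised off the
+# main thread. The helper below is illustrative only and is not used by the
+# tests in this file.
+def _mkldnn_second_thread_sketch():
+    import threading
+    results = []
+    net = gluon.nn.HybridSequential()
+    with net.name_scope():
+        net.add(gluon.nn.Conv2D(channels=8, kernel_size=3))
+    net.collect_params().initialize(ctx=mx.cpu())
+
+    def worker():
+        # asnumpy() blocks until the result is ready and converts any
+        # MKL-DNN internal layout back to the default layout
+        results.append(net(mx.nd.ones((1, 3, 32, 32))).asnumpy())
+
+    t = threading.Thread(target=worker)
+    t.start()
+    t.join()
+    return results[0]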
+
+
+@with_seed()
+def test_reshape_before_conv():
+    class Net(gluon.HybridBlock):
+        """
+        Reshape the input before each of two convolutions.
+        """
+        def __init__(self, **kwargs):
+            super(Net, self).__init__(**kwargs)
+            with self.name_scope():
+                self.conv0 = nn.Conv2D(10, (3, 3))
+                self.conv1 = nn.Conv2D(5, (3, 3))
+
+        def hybrid_forward(self, F, x, *args, **kwargs):
+            x_reshape = x.reshape((0, 0, 20, 5))
+            y = self.conv0(x_reshape)
+            y_reshape = y.reshape((0, 0, 9, 6))
+            out = self.conv1(y_reshape)
+            return out
+    x = mx.nd.random.uniform(shape=(2, 4, 10, 10))
+    x.attach_grad()
+    net = Net()
+    net.collect_params().initialize()
+    with mx.autograd.record():
+        out1 = net(x)
+        out1.backward()
+    # copy the gradient: the same x.grad buffer is overwritten by the second backward
+    dx1 = x.grad.asnumpy()
+    net.hybridize()
+    with mx.autograd.record():
+        out2 = net(x)
+        out2.backward()
+    mx.test_utils.assert_almost_equal(dx1, x.grad.asnumpy(), rtol=1e-5, atol=1e-6)
+    mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6)
+
+
+@with_seed()
+def test_slice_before_conv():
+    class Net(gluon.HybridBlock):
+        """
+        Slice the input before each of two convolutions.
+        """
+        def __init__(self, **kwargs):
+            super(Net, self).__init__(**kwargs)
+            with self.name_scope():
+                self.conv0 = nn.Conv2D(4, (3, 3))
+                self.conv1 = nn.Conv2D(4, (3, 3))
+
+        def hybrid_forward(self, F, x, *args, **kwargs):
+            x_slice = x.slice(begin=(0, 0, 0, 0), end=(2, 4, 10, 10))
+            y = self.conv0(x_slice)
+            y_slice = y.slice(begin=(1, 0, 2, 2), end=(2, 1, 7, 7))
+            out = self.conv1(y_slice)
+            return out
+    x = mx.nd.random.uniform(shape=(2, 10, 10, 10))
+    x.attach_grad()
+    net = Net()
+    net.collect_params().initialize()
+    with mx.autograd.record():
+        out1 = net(x)
+        out1.backward()
+    dx1 = x.grad.asnumpy()
+    net.hybridize()
+    with mx.autograd.record():
+        out2 = net(x)
+        out2.backward()
+    mx.test_utils.assert_almost_equal(dx1, x.grad.asnumpy(), rtol=1e-5, atol=1e-6)
+    mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6)
+
+
+@with_seed()
+def test_slice_reshape_before_conv():
+    class Net(gluon.HybridBlock):
+        """
+        Slice the input before the first convolution, reshape before the second.
+        """
+        def __init__(self, **kwargs):
+            super(Net, self).__init__(**kwargs)
+            with self.name_scope():
+                self.conv0 = nn.Conv2D(4, (3, 3))
+                self.conv1 = nn.Conv2D(4, (3, 3))
+
+        def hybrid_forward(self, F, x, *args, **kwargs):
+            x_slice = x.slice(begin=(0, 0, 0, 0), end=(2, 4, 8, 9))
+            y = self.conv0(x_slice)
+            y_reshape = y.reshape((0, 0, 14, 3))
+            out = self.conv1(y_reshape)
+            return out
+    x = mx.nd.random.uniform(shape=(2, 10, 10, 10))
+    x.attach_grad()
+    net = Net()
+    net.collect_params().initialize()
+    with mx.autograd.record():
+        out1 = net(x)
+        out1.backward()
+    dx1 = x.grad.asnumpy()
+    net.hybridize()
+    with mx.autograd.record():
+        out2 = net(x)
+        out2.backward()
+    mx.test_utils.assert_almost_equal(dx1, x.grad.asnumpy(), rtol=1e-5, atol=1e-6)
+    mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6)
+
+
+def test_mkldnn_sum_inplace_with_cpu_layout():
+    x_shape = (32, 3, 224, 224)
+    x_npy = np.ones(x_shape)
+    y_shape = (32, 32, 222, 222)
+    y_npy = np.ones(y_shape)
+    x = mx.sym.Variable("x")
+    y = mx.sym.Variable("y")
+    z = mx.symbol.Convolution(data=x, num_filter=32, kernel=(3, 3))
+    z = mx.sym.add_n(z, y)
+    exe = z.simple_bind(ctx=mx.cpu(), x=x_shape, y=y_shape)
+    out = exe.forward(is_train=False, x=x_npy, y=y_npy)[0]
+    # the conv weights start out as zeros, so the in-place sum equals y (all ones)
+    assert_almost_equal(out[0].asnumpy()[0, 0, 0], 1.0)
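+
+
+# The reshape/slice tests above all follow the same pattern: run the block
+# imperatively, then hybridized, and require identical outputs and input
+# gradients. A generic sketch of that check (a hypothetical helper mirroring
+# the test bodies above, shown for clarity; not used by the tests):
+def _check_imperative_vs_hybrid(net, x):
+    x.attach_grad()
+    net.collect_params().initialize()
+    with mx.autograd.record():
+        out1 = net(x)
+        out1.backward()
+    dx1 = x.grad.asnumpy()
+    net.hybridize()
+    with mx.autograd.record():
+        out2 = net(x)
+        out2.backward()
+    assert_almost_equal(dx1, x.grad.asnumpy(), rtol=1e-5, atol=1e-6)
+    assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6)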
+
+
+@with_seed()
+def test_batchnorm():
+    def check_batchnorm_training(stype):
+        for shape in [(2, 3), (2, 3, 2, 2)]:
+            data_tmp = np.random.normal(-0.1, 0.1, size=shape)
+            s = (shape[1],)
+            gamma = np.ones(s)
+            beta = np.ones(s)
+            gamma[1] = 3
+            beta[0] = 3
+
+            rolling_mean = np.random.uniform(size=s)
+            rolling_std = np.random.uniform(size=s)
+
+            data = mx.symbol.Variable('data', stype=stype)
+            in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
+                           mx.nd.array(beta).tostype(stype)]
+            mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=False)
+            check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
+
+    stypes = ['row_sparse', 'default']
+    for stype in stypes:
+        check_batchnorm_training(stype)
+
+
+@with_seed()
+def test_softmax():
+    def check_softmax_training(stype):
+        for shape in [(2, 3), (2, 3, 2, 2)]:
+            data_tmp = np.random.normal(-0.1, 0.1, size=shape)
+
+            data = mx.symbol.Variable('data', stype=stype)
+            in_location = [mx.nd.array(data_tmp).tostype(stype)]
+
+            test = mx.symbol.softmax(data, axis=-1)
+            check_numeric_gradient(test, in_location, numeric_eps=1e-2, rtol=0.16, atol=1e-4)
+
+    stypes = ['row_sparse', 'default']
+    for stype in stypes:
+        check_softmax_training(stype)
+
+
+@with_seed()
+def test_pooling():
+    def check_pooling_training(stype):
+        for shape in [(3, 3, 10), (3, 3, 20, 20)]:
+            data_tmp = np.random.normal(-0.1, 0.1, size=shape)
+            data = mx.symbol.Variable('data', stype=stype)
+            in_location = [mx.nd.array(data_tmp).tostype(stype)]
+
+            if len(shape) == 3:
+                test = mx.symbol.Pooling(data=data, kernel=(3,), stride=(2,), pool_type='avg')
+            elif len(shape) == 4:
+                test = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type='avg')
+            else:
+                return
+            check_numeric_gradient(test, in_location, numeric_eps=1e-2, rtol=0.16, atol=1e-4)
+
+    stypes = ['row_sparse', 'default']
+    for stype in stypes:
+        check_pooling_training(stype)
+
+
+@with_seed()
+def test_activation():
+    def check_activation_training(stype):
+        for shape in [(2, 3, 3), (2, 3, 2, 2)]:
+            data_tmp = np.random.normal(-0.1, 1, size=shape)
+
+            data = mx.symbol.Variable('data', stype=stype)
+            in_location = [mx.nd.array(data_tmp).tostype(stype)]
+
+            test = mx.symbol.Activation(data, act_type="relu")
+            check_numeric_gradient(test, in_location, numeric_eps=1e-5, rtol=0.16, atol=1e-4)
+
+    stypes = ['row_sparse', 'default']
+    for stype in stypes:
+        check_activation_training(stype)
+
+
+@with_seed()
+def test_convolution():
+    def check_convolution_training(stype):
+        for shape in [(3, 3, 10), (3, 3, 10, 10)]:
+            data_tmp = np.random.normal(-0.1, 1, size=shape)
+            data = mx.symbol.Variable('data', stype=stype)
+
+            if len(shape) == 3:
+                test = mx.symbol.Convolution(data=data, kernel=(3,), stride=(2,), num_filter=4)
+                weight_tmp = np.random.normal(-0.1, 0.1, size=(4, 3, 3))
+            elif len(shape) == 4:
+                test = mx.symbol.Convolution(data=data, kernel=(3, 3), stride=(2, 2), num_filter=4)
+                weight_tmp = np.random.normal(-0.1, 0.1, size=(4, 3, 3, 3))
+            else:
+                return
+            bias_tmp = np.random.normal(0.1, 0.1, size=(4,))
+            in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(weight_tmp).tostype(stype),
+                           mx.nd.array(bias_tmp).tostype(stype)]
+            check_numeric_gradient(test, in_location, numeric_eps=1e-2, rtol=0.16, atol=1e-4)
+
+    stypes = ['row_sparse', 'default']
+    for stype in stypes:
+        check_convolution_training(stype)
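+
+
+# The training tests above all delegate to mx.test_utils.check_numeric_gradient,
+# which compares the symbolic backward pass against a finite-difference
+# estimate of the gradient. The core idea, as a minimal sketch (illustrative
+# only, not the actual implementation):
+def _central_difference(f, x, eps=1e-4):
+    # df/dx ~= (f(x + eps) - f(x - eps)) / (2 * eps)
+    return (f(x + eps) - f(x - eps)) / (2 * eps)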
+
+
+@with_seed()
+@unittest.skip("Flaky test /~https://github.com/apache/incubator-mxnet/issues/12579")
+def test_Deconvolution():
+    def check_Deconvolution_training(stype):
+        for shape in [(3, 3, 10), (3, 3, 10, 10)]:
+            data_tmp = np.random.randint(256, size=shape)
+            data = mx.symbol.Variable('data', stype=stype)
+
+            if len(shape) == 3:
+                test = mx.symbol.Deconvolution(data=data, kernel=(3,), stride=(2,), num_filter=4)
+                weight_tmp = np.random.normal(-0.1, 0.1, size=(3, 4, 3))
+            elif len(shape) == 4:
+                test = mx.symbol.Deconvolution(data=data, kernel=(3, 3), stride=(2, 2), num_filter=4)
+                weight_tmp = np.random.normal(-0.1, 0.1, size=(3, 4, 3, 3))
+            else:
+                return
+            bias_tmp = np.random.normal(0.1, 0.1, size=(4,))
+            in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(weight_tmp).tostype(stype),
+                           mx.nd.array(bias_tmp).tostype(stype)]
+            check_numeric_gradient(test, in_location, numeric_eps=1e-2, rtol=0.16, atol=1e-4)
+
+    stypes = ['row_sparse', 'default']
+    for stype in stypes:
+        check_Deconvolution_training(stype)
+
+
+@with_seed()
+def test_LRN():
+    def check_LRN_training(stype):
+        for shape in [(3, 4, 5, 5)]:
+            data_tmp = np.random.normal(-0.1, 0.1, size=shape)
+            data = mx.symbol.Variable('data', stype=stype)
+            in_location = [mx.nd.array(data_tmp).tostype(stype)]
+
+            test = mx.symbol.LRN(data, nsize=3)
+            check_numeric_gradient(test, in_location, numeric_eps=1e-2, rtol=0.16, atol=1e-4)
+
+    stypes = ['row_sparse', 'default']
+    for stype in stypes:
+        check_LRN_training(stype)
+
+
+@with_seed()
+def test_fullyconnected():
+    def check_fullyconnected_training(stype):
+        data_shape = rand_shape_nd(2)
+        weight_shape = rand_shape_nd(2)
+        weight_shape = (weight_shape[0], data_shape[1])
+        for density in [1.0, 0.5, 0.0]:
+            x = rand_ndarray(shape=data_shape, stype=stype, density=density)
+            w = rand_ndarray(shape=weight_shape, stype=stype, density=density)
+            x_sym = mx.sym.Variable("data")
+            w_sym = mx.sym.Variable("weight")
+            sym = mx.sym.FullyConnected(data=x_sym, weight=w_sym, num_hidden=weight_shape[0], no_bias=True)
+            in_location = [x, w]
+            check_numeric_gradient(sym, in_location, numeric_eps=1e-3, rtol=1e-3, atol=5e-3)
+    stypes = ['row_sparse', 'default']
+    for stype in stypes:
+        check_fullyconnected_training(stype)
+
+def test_softmax_with_large_inputs():
+    def softmax_forward(input_data, true_output):
+        data = mx.sym.Variable('data')
+        out1 = data.softmax(axis=1)
+        exec1 = out1.bind(mx.cpu(), args={'data': input_data})
+        exec1.forward()[0].wait_to_read()
+        ndarr = exec1.outputs[0][0][0][0]
+        nparr = ndarr.asnumpy()
+        assert_almost_equal(nparr, true_output, rtol=1e-5, atol=1e-5)
+
+    softmax_forward(mx.nd.array([[[[-1e30, -1e30]]]]), np.array([1.0, 1.0]))
+    softmax_forward(mx.nd.array([[[[1e30, 1e30]]]]), np.array([1.0, 1.0]))
+    softmax_forward(mx.nd.array([[[[-3.4e38, -3.4e38]]]]), np.array([1.0, 1.0]))
+    softmax_forward(mx.nd.array([[[[3.4e38, 3.4e38]]]]), np.array([1.0, 1.0]))
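+
+
+# The large-input cases above only pass because softmax subtracts the row
+# maximum before exponentiating, so values like +/-3.4e38 never overflow in
+# exp(). A NumPy sketch of that standard trick (illustrative, not MXNet's
+# actual kernel):
+def _stable_softmax(x, axis=-1):
+    shifted = x - np.max(x, axis=axis, keepdims=True)
+    e = np.exp(shifted)
+    return e / np.sum(e, axis=axis, keepdims=True)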
+
+
+@with_seed()
+def test_non_mkldnn_fcomputeex():
+    # Test the special case where an MKL-DNN formatted NDArray feeds into a
+    # non-MKLDNN FComputeEx operator. Convolution is an example of an op that
+    # creates MKL-DNN NDArrays from regular NDArrays; CustomOp is an example
+    # of a non-MKLDNN FComputeEx operator.
+
+    @mx.operator.register("custom")
+    class CustomProp(mx.operator.CustomOpProp):
+        def __init__(self):
+            super(CustomProp, self).__init__(need_top_grad=False)
+
+        def list_arguments(self):
+            return ['data']
+
+        def list_outputs(self):
+            return ['output']
+
+        def infer_shape(self, in_shape):
+            data_shape = in_shape[0]
+            output_shape = in_shape[0]
+            return [data_shape], [output_shape], []
+
+        def infer_type(self, in_type):
+            dtype = in_type[0]
+            return [dtype], [dtype], []
+
+        def create_operator(self, ctx, shapes, dtypes):
+            return Custom()
+
+
+    class Custom(mx.operator.CustomOp):
+        def forward(self, is_train, req, in_data, out_data, aux):
+            # printing reads the input on the Python side, which requires the
+            # MKL-DNN formatted array to be converted to the default layout
+            print(in_data[0])
+            self.assign(out_data[0], req[0], in_data[0])
+
+        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+            self.assign(in_grad[0], req[0], out_grad[0])
+
+    data = mx.symbol.Variable('data')
+    conv = mx.sym.Convolution(data=data, kernel=(5, 5), pad=(1, 1), stride=(1, 1), num_filter=8, name="conv", no_bias=True)
+    custom = mx.symbol.Custom(name='custom', data=conv, op_type='custom')
+    exec1 = custom.bind(mx.cpu(), args={'data': mx.nd.ones([10, 3, 96, 96]), 'conv_weight': mx.nd.ones([8, 3, 5, 5])})
+    exec1.forward()[0].wait_to_read()
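+
+
+# Entry point so the file can be run directly; a nose runner is assumed here,
+# matching the usual footer of MXNet test modules.
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()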