diff --git a/docs/api/python/ndarray/ndarray.md b/docs/api/python/ndarray/ndarray.md
index dda534151a1e..849412021e1e 100644
--- a/docs/api/python/ndarray/ndarray.md
+++ b/docs/api/python/ndarray/ndarray.md
@@ -131,6 +131,7 @@ The `ndarray` package provides several classes:
     NDArray.flatten
     NDArray.expand_dims
     NDArray.split
+    NDArray.diag
 ```
 
 ### Array expand elements
@@ -364,6 +365,7 @@ The `ndarray` package provides several classes:
     ones_like
     full
     arange
+    diag
     load
     save
 ```
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index 304b17803eda..a59a92745c73 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -182,6 +182,7 @@ Composite multiple symbols into a new one by an operator.
     Symbol.zeros_like
     Symbol.ones_like
+    Symbol.diag
 ```
 
 ### Changing shape and type
@@ -381,6 +382,7 @@ Composite multiple symbols into a new one by an operator.
     reshape_like
     flatten
     expand_dims
+    diag
 ```
 
 ### Expanding elements
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 09395e2ec828..ff9aac05c7c3 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1302,6 +1302,14 @@ def flip(self, *args, **kwargs):
         """
         return op.flip(self, *args, **kwargs)
 
+    def diag(self, k=0, **kwargs):
+        """Convenience fluent method for :py:func:`diag`.
+
+        The arguments are the same as for :py:func:`diag`, with
+        this array as data.
+        """
+        return op.diag(self, k, **kwargs)
+
     def sum(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sum`.
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index b041f4ef646f..88f92cde0fe4 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -2038,6 +2038,14 @@ def flip(self, *args, **kwargs):
         """
         return op.flip(self, *args, **kwargs)
 
+    def diag(self, k=0, **kwargs):
+        """Convenience fluent method for :py:func:`diag`.
+
+        The arguments are the same as for :py:func:`diag`, with
+        this symbol as data.
+        """
+        return op.diag(self, k, **kwargs)
+
     def sum(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sum`.
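For orientation, here is how the new fluent method is expected to behave once the operator implemented below is registered. This is an illustrative session, not part of the patch:

```python
import mxnet as mx

x = mx.nd.array([[1, 2, 3],
                 [4, 5, 6]])

# The fluent form forwards to the diag operator with this array as data...
print(x.diag(k=1).asnumpy())         # expected: [2. 6.]
# ...and is equivalent to the functional form.
print(mx.nd.diag(x, k=1).asnumpy())  # expected: [2. 6.]
```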
diff --git a/src/operator/tensor/diag_op-inl.h b/src/operator/tensor/diag_op-inl.h
new file mode 100644
index 000000000000..3bc240f206b4
--- /dev/null
+++ b/src/operator/tensor/diag_op-inl.h
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* Copyright (c) 2015 by Contributors
+* \file diag_op-inl.h
+* \brief CPU Implementation of the diag op
+* \author Istvan Fehervari
+*/
+
+#ifndef MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
+#define MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
+
+#include <dmlc/parameter.h>
+#include <vector>
+#include <algorithm>
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct DiagParam : public dmlc::Parameter<DiagParam> {
+  dmlc::optional<int> k;
+  DMLC_DECLARE_PARAMETER(DiagParam) {
+    DMLC_DECLARE_FIELD(k)
+    .set_default(dmlc::optional<int>(0))
+    .describe("Diagonal in question. The default is 0. "
+              "Use k>0 for diagonals above the main diagonal, "
+              "and k<0 for diagonals below the main diagonal. "
+              "If input has shape (S0 S1) k must be between -S0 and S1");
+  }
+};
+
+inline TShape DiagShapeImpl(const TShape& ishape, const nnvm::dim_t k) {
+  if (ishape.ndim() == 1) {
+    auto s = ishape[0] + std::abs(k);
+    return TShape({s, s});
+  }
+
+  auto h = ishape[0];
+  auto w = ishape[1];
+
+  if (k > 0) {
+    w -= k;
+  } else if (k < 0) {
+    h += k;
+  }
+
+  auto s = std::min(h, w);
+  if (s < 0) {
+    s = 0;
+  }
+
+  return TShape({s});
+}
+
+inline bool DiagOpShape(const nnvm::NodeAttrs& attrs,
+                        std::vector<TShape>* in_attrs,
+                        std::vector<TShape>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  const TShape& ishape = (*in_attrs)[0];
+  if (ishape.ndim() == 0) return false;
+  if (ishape.ndim() > 2) LOG(FATAL) << "Input must be 1- or 2-d.";
+
+  const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
+
+  TShape oshape = DiagShapeImpl(ishape, param.k.value());
+  if (shape_is_none(oshape)) {
+    LOG(FATAL) << "Diagonal does not exist.";
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
+
+  return out_attrs->at(0).ndim() != 0U;
+}
+
+inline bool DiagOpType(const nnvm::NodeAttrs& attrs,
+                       std::vector<int> *in_attrs,
+                       std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]);
+  return (*out_attrs)[0] != -1;
+}
+
+template<int req>
+struct diag {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
+                                  mshadow::Shape<2> ishape, int k) {
+    using namespace mxnet_op;
+    int j = 0;
+    if (k > 0) {
+      j = ravel(mshadow::Shape2(i, i + k), ishape);
+    } else if (k < 0) {
+      j = ravel(mshadow::Shape2(i - k, i), ishape);
+    } else {
+      j = ravel(mshadow::Shape2(i, i), ishape);
+    }
+
+    KERNEL_ASSIGN(out[i], req, a[j]);
+  }
+};
+
+template<int req>
+struct diag_gen {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* a,
+                                  mshadow::Shape<2> oshape, int k) {
+    using namespace mxnet_op;
+
+    auto j = unravel(i, oshape);
+    if (j[1] == (j[0] + k)) {
+      auto l = j[0] < j[1] ? j[0] : j[1];
+      KERNEL_ASSIGN(out[i], req, a[l]);
+    } else {
+      KERNEL_ASSIGN(out[i], req, static_cast<DType>(0));
+    }
+  }
+};
+
+template<typename xpu>
+void DiagOpForward(const nnvm::NodeAttrs& attrs,
+                   const OpContext& ctx,
+                   const std::vector<TBlob>& inputs,
+                   const std::vector<OpReqType>& req,
+                   const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_EQ(req[0], kWriteTo);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  const TShape& ishape = inputs[0].shape_;
+  const TShape& oshape = outputs[0].shape_;
+  const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
+
+  if (ishape.ndim() == 2) {
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
+                    in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
+      });
+    });
+  } else {
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
+                    in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
+      });
+    });
+  }
+}
+
+template<typename xpu>
+void DiagOpBackward(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  const TShape& ishape = inputs[0].shape_;
+  const TShape& oshape = outputs[0].shape_;
+  const DiagParam& param = nnvm::get<DiagParam>(attrs.parsed);
+
+  if (oshape.ndim() == 2) {
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<diag_gen<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
+                    in_data.dptr<DType>(), Shape2(oshape[0], oshape[1]), param.k.value());
+      });
+    });
+  } else {
+    MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<diag<req_type>, xpu>::Launch(s, out_data.Size(), out_data.dptr<DType>(),
+                    in_data.dptr<DType>(), Shape2(ishape[0], ishape[1]), param.k.value());
+      });
+    });
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_TENSOR_DIAG_OP_INL_H_
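As a sanity check on `DiagShapeImpl`, the same shape rule can be restated in plain Python. `diag_out_shape` is a hypothetical helper mirroring the C++ above, not code from the patch:

```python
def diag_out_shape(ishape, k=0):
    # Mirrors DiagShapeImpl: a 1-d input of length n yields an
    # (n+|k|, n+|k|) matrix; a 2-d input yields the length of the
    # k-th diagonal, clamped at zero when it falls outside the matrix.
    if len(ishape) == 1:
        s = ishape[0] + abs(k)
        return (s, s)
    h, w = ishape
    if k > 0:
        w -= k
    elif k < 0:
        h += k
    return (max(min(h, w), 0),)

assert diag_out_shape((3,), k=1) == (4, 4)
assert diag_out_shape((2, 3), k=-1) == (1,)
assert diag_out_shape((2, 3), k=5) == (0,)   # no such diagonal
```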
diff --git a/src/operator/tensor/diag_op.cc b/src/operator/tensor/diag_op.cc
new file mode 100644
index 000000000000..1ad3b8adc028
--- /dev/null
+++ b/src/operator/tensor/diag_op.cc
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* Copyright (c) 2015 by Contributors
+* \file diag_op.cc
+* \brief CPU registration of the diag op
+* \author Istvan Fehervari
+*/
+
+#include "./diag_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(DiagParam);
+
+NNVM_REGISTER_OP(diag)
+.describe(R"code(Extracts a diagonal or constructs a diagonal array.
+
+``diag``'s behavior depends on the input array dimensions:
+
+- 1-D arrays: constructs a 2-D array with the input as its diagonal, all other elements are zero
+- 2-D arrays: returns elements in the diagonal as a new 1-D array
+- N-D arrays: not supported yet
+
+Examples::
+
+  x = [[1, 2, 3],
+       [4, 5, 6]]
+
+  diag(x) = [1, 5]
+
+  diag(x, k=1) = [2, 6]
+
+  diag(x, k=-1) = [4]
+
+  x = [1, 2, 3]
+
+  diag(x) = [[1, 0, 0],
+             [0, 2, 0],
+             [0, 0, 3]]
+
+  diag(x, k=1) = [[0, 1, 0],
+                  [0, 0, 2],
+                  [0, 0, 0]]
+
+  diag(x, k=-1) = [[0, 0, 0],
+                   [1, 0, 0],
+                   [0, 2, 0]]
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<DiagParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", DiagOpShape)
+.set_attr<nnvm::FInferType>("FInferType", DiagOpType)
+.set_attr<FCompute>("FCompute", DiagOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_diag"})
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(DiagParam::__FIELDS__());
+
+
+NNVM_REGISTER_OP(_backward_diag)
+.set_attr_parser(ParamParser<DiagParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute", DiagOpBackward<cpu>);
+
+
+}  // namespace op
+}  // namespace mxnet
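Because `FGradient` uses `ElemwiseGradUseNone`, the backward op sees only the incoming gradient: a 1-d gradient is scattered back onto the k-th diagonal via `diag_gen`, while a 2-d gradient has its diagonal gathered via `diag`. A minimal autograd round-trip check, illustrative only and assuming the patch is built:

```python
import mxnet as mx

x = mx.nd.array([[1., 2.],
                 [3., 4.]])
x.attach_grad()
with mx.autograd.record():
    y = x.diag()              # gather: [1., 4.]
y.backward(mx.nd.ones_like(y))
# The ones flow back onto the main diagonal only.
print(x.grad.asnumpy())       # expected: [[1. 0.] [0. 1.]]
```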
diff --git a/src/operator/tensor/diag_op.cu b/src/operator/tensor/diag_op.cu
new file mode 100644
index 000000000000..a3928f763869
--- /dev/null
+++ b/src/operator/tensor/diag_op.cu
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+* Copyright (c) 2015 by Contributors
+* \file diag_op.cu
+* \brief GPU Implementation of the diag op
+* \author Istvan Fehervari
+*/
+
+#include "./diag_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(diag)
+.set_attr<FCompute>("FCompute", DiagOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_diag)
+.set_attr<FCompute>("FCompute", DiagOpBackward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
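With the GPU registration above, the same calls should run unchanged on a CUDA device. A sketch assuming a GPU build and device 0:

```python
import mxnet as mx

a = mx.nd.array([1, 2, 3], ctx=mx.gpu(0))
# The diag_gen kernel builds a 4x4 matrix with the input on the k=1 diagonal.
print(mx.nd.diag(a, k=1))
```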
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 814266ad9aa3..a763037409ab 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -27,7 +27,7 @@
 from numpy.testing import assert_allclose, assert_array_equal
 from mxnet.test_utils import *
 from mxnet.base import py_str, MXNetError, _as_list
-from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled
+from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled, assertRaises
 import unittest
 
@@ -7033,6 +7033,79 @@ def test_roi_align_autograd(sampling_ratio=0):
     test_roi_align_value(2)
     test_roi_align_autograd()
 
+@with_seed()
+def test_diag():
+
+    # Test 2d input
+    h = np.random.randint(2, 9)
+    w = np.random.randint(2, 9)
+    a_np = np.random.random((h, w)).astype(np.float32)
+    a = mx.nd.array(a_np).astype('float32')
+
+    # k == 0
+    r = mx.nd.diag(a)
+    assert_almost_equal(r.asnumpy(), np.diag(a_np))
+
+    # k == 1
+    k = 1
+    r = mx.nd.diag(a, k=k)
+    assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+    # k == -1
+    k = -1
+    r = mx.nd.diag(a, k=k)
+    assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+    # random k
+    k = np.random.randint(-min(h, w) + 1, min(h, w))
+    r = mx.nd.diag(a, k=k)
+    assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+    # invalid k
+    k = max(h, w) + 1
+    assertRaises(MXNetError, mx.nd.diag, a, k=k)
+
+    # Test 2d backward, k=0
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 2d backward, k=1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=1)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 2d backward, k=-1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=-1)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 1d input
+    d = np.random.randint(2, 9)
+    a_np = np.random.random((d))
+    a = mx.nd.array(a_np)
+
+    # random k
+    k = np.random.randint(-d, d)
+    r = mx.nd.diag(a, k=k)
+    assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
+
+    # Test 1d backward, k=0
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 1d backward, k=1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=1)
+    check_numeric_gradient(diag_sym, [a_np])
+
+    # Test 1d backward, k=-1
+    data = mx.sym.Variable('data')
+    diag_sym = mx.sym.diag(data=data, k=-1)
+    check_numeric_gradient(diag_sym, [a_np])
+
 if __name__ == '__main__':
     import nose
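To run just the new test through nose, mirroring the module's own `__main__` entry point (hypothetical invocation, from `tests/python/unittest`):

```python
import nose
nose.run(argv=['nosetests', 'test_operator.py:test_diag'])
```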