diff --git a/src/operator/nn/moments-inl.h b/src/operator/nn/moments-inl.h new file mode 100644 index 000000000000..b88171719c58 --- /dev/null +++ b/src/operator/nn/moments-inl.h @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file moments-inl.h + * \brief Moments operator + * \author Hao Jin +*/ + +#ifndef MXNET_OPERATOR_NN_MOMENTS_INL_H_ +#define MXNET_OPERATOR_NN_MOMENTS_INL_H_ + +#include +#include "../tensor/broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +struct MomentsParam : public dmlc::Parameter { + dmlc::optional axes; + bool keepdims; + DMLC_DECLARE_PARAMETER(MomentsParam) { + DMLC_DECLARE_FIELD(axes).set_default(dmlc::optional()) + .describe("Array of ints. Axes along which to compute mean and variance."); + DMLC_DECLARE_FIELD(keepdims).set_default(false) + .describe("produce moments with the same dimensionality as the input."); + } +}; + +inline bool MomentsShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + const MomentsParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 2U); + + mxnet::TShape out_shape = + ReduceAxesShapeImpl((*in_attrs)[0], param.axes, param.keepdims, false); + if (!param.axes.has_value() || param.axes.value().ndim() == 0) { + LOG(FATAL) << "Empty axes is not supported, if you would like to do global moments, " + << "please pass all axes to axes argument"; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, out_shape); + return true; +} + +inline bool MomentsType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 2U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(1)); + return out_attrs->at(0) != -1 && out_attrs->at(1) != -1; +} + +struct VarBroadcastKernel { + template + MSHADOW_XINLINE static void Map(int i, + DType *out, + const DType *data, + const DType *mean, + mshadow::Shape<5> data_shape, + mshadow::Shape<5> mean_shape) { + size_t data_idx = i; + size_t mean_idx = i; + size_t data_stride = 1; + size_t mean_stride = 1; + for (int axis = 4; axis >= 0; --axis) { + size_t axis_idx = data_idx % data_shape[axis]; + mean_idx -= axis_idx * data_stride; + if (mean_shape[axis] != 1) { + mean_idx += axis_idx * mean_stride; + } + data_idx /= data_shape[axis]; + data_stride *= data_shape[axis]; + mean_stride *= mean_shape[axis]; + } + DType res = (data[i] - mean[mean_idx]); + out[i] = res * res; + } +}; + +template +inline void MomentsForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow_op; + using namespace mxnet_op; + + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 2U); + + const MomentsParam& param = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + + const TBlob& data = inputs[0]; + const TBlob& mean = outputs[0]; + const TBlob& var = outputs[1]; + + mxnet::TShape small; + if (param.keepdims) { + small = outputs[0].shape_; + } else { + small = ReduceAxesShapeImpl(inputs[0].shape_, param.axes, true, false); + } + + ReduceAxesComputeImpl(ctx, {data}, {req[0]}, {mean}, small); + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { + Shape<5> data_shape, mean_shape; + for (int i = 0; i < 5; ++i) { + data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1; + mean_shape[i] = (i < small.ndim()) ? small[i] : 1; + } + Tensor temp_data = + ctx.requested[0].get_space_typed(Shape1(data.shape_.Size()), s);; + Kernel::Launch(s, data.shape_.Size(), temp_data.dptr_, + data.dptr(), mean.dptr(), data_shape, mean_shape); + ReduceAxesComputeImpl( + ctx, {TBlob(temp_data).reshape(data.shape_)}, {kWriteTo}, {var}, small); + }); +} + +template +struct VarBackwardKernel { + template + MSHADOW_XINLINE static void Map(int i, + DType *igrad, + const DType *ograd, + const DType *data, + const DType *mean, + mshadow::Shape<5> data_shape, + mshadow::Shape<5> mean_shape, + const float N, + const float ddof = 0.0f) { + size_t data_idx = i; + size_t mean_idx = i; + size_t data_stride = 1; + size_t mean_stride = 1; + for (int axis = 4; axis >= 0; --axis) { + size_t axis_idx = data_idx % data_shape[axis]; + mean_idx -= axis_idx * data_stride; + if (mean_shape[axis] != 1) { + mean_idx += axis_idx * mean_stride; + } + data_idx /= data_shape[axis]; + data_stride *= data_shape[axis]; + mean_stride *= mean_shape[axis]; + } + KERNEL_ASSIGN(igrad[i], req, ograd[mean_idx] * (data[i] - mean[mean_idx]) * 2 / (N - ddof)); + } +}; + +template +inline void MomentsBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + using namespace mxnet_op; + + CHECK_EQ(inputs.size(), 5U); + CHECK_EQ(outputs.size(), 1U); + + const MomentsParam& param = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + + const TBlob& mean_grad = inputs[0]; + const TBlob& var_grad = inputs[1]; + const TBlob& data = inputs[2]; + const TBlob& mean = inputs[3]; + const TBlob& var = inputs[4]; + const TBlob& data_grad = outputs[0]; + + mxnet::TShape small = ReduceAxesShapeImpl(data.shape_, param.axes, true, false); + BroadcastComputeImpl(attrs, ctx, {mean_grad}, req, outputs, small); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor igrad = outputs[0].FlatTo1D(s); + igrad /= scalar(outputs[0].Size()/inputs[0].Size()); + }); + + Shape<5> data_shape, var_shape; + float N = data_grad.Size() / var.Size(); + for (int i = 0; i < 5; ++i) { + data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1; + var_shape[i] = (i < small.ndim()) ? small[i] : 1; + } + MSHADOW_TYPE_SWITCH(data_grad.type_flag_, DType, { + Kernel, xpu>::Launch( + s, data_grad.shape_.Size(), data_grad.dptr(), var_grad.dptr(), + data.dptr(), mean.dptr(), data_shape, var_shape, N); + }); +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_NN_MOMENTS_INL_H_ diff --git a/src/operator/nn/moments.cc b/src/operator/nn/moments.cc new file mode 100644 index 000000000000..37b8cdf18750 --- /dev/null +++ b/src/operator/nn/moments.cc @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file moments.cc + * \brief Moments operator + * \author Hao Jin +*/ + +#include "./moments-inl.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(MomentsParam); + +NNVM_REGISTER_OP(moments) +.describe(R"code( +Calculate the mean and variance of `data`. + +The mean and variance are calculated by aggregating the contents of data across axes. +If x is 1-D and axes = [0] this is just the mean and variance of a vector. + +Example: + + x = [[1, 2, 3], [4, 5, 6]] + mean, var = moments(data=x, axes=[0]) + mean = [2.5, 3.5, 4.5] + var = [2.25, 2.25, 2.25] + mean, var = moments(data=x, axes=[1]) + mean = [2.0, 5.0] + var = [0.66666667, 0.66666667] + mean, var = moments(data=x, axis=[0, 1]) + mean = [3.5] + var = [2.9166667] + +)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(2) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", MomentsShape) +.set_attr("FInferType", MomentsType) +.set_attr("FCompute", MomentsForward) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_moments"}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.add_argument("data", "NDArray-or-Symbol", "Input ndarray") +.add_arguments(MomentsParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_moments) +.set_attr_parser(ParamParser) +.set_num_inputs(5) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FCompute", MomentsBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/nn/moments.cu b/src/operator/nn/moments.cu new file mode 100644 index 000000000000..a45ae33281be --- /dev/null +++ b/src/operator/nn/moments.cu @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file moments.cu + * \brief Moments operator + * \author Hao Jin +*/ + +#include "./moments-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(moments) +.set_attr("FCompute", MomentsForward); + +NNVM_REGISTER_OP(_backward_moments) +.set_attr("FCompute", MomentsBackward); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index e8bfaba4736d..09a4c1b497e7 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -7981,6 +7981,36 @@ def test_split_v2(): check_symbolic_backward(sym, {"data": mx_data}, out_grad, [np.concatenate(out_grad, axis=axis)]) +@with_seed() +def test_moments(): + dim = random.randint(2, 5) + shape = rand_shape_nd(dim, dim=5) + axes = [i for i in range(dim)] + test_dims = random.sample(axes, random.randint(1, dim)) + test_axes = tuple(sorted(test_dims)) + np_a = np.random.uniform(-1.0, 1.0, shape) + a = mx.nd.array(np_a) + for keepdims in [True, False]: + print(shape, test_axes, keepdims) + eps = 1e-3 + np_a[abs(np_a) < eps] = 2 * eps + np_mean = np.mean(np_a, axis=test_axes, keepdims=keepdims) + np_var = np.var(np_a, axis=test_axes, keepdims=keepdims) + mx_mean, mx_var = mx.nd.moments(a, keepdims=keepdims, axes=test_axes) + N = np_a.size / np_mean.size + mx_sym = mx.sym.Variable("data") + mx_moments = mx.sym.moments(mx_sym, axes=test_axes, keepdims=keepdims) + mx_test_sym = mx.sym.elemwise_add(mx_moments[0], mx_moments[1]) + if len(np_mean.shape) == 0: + np_mean = np_mean.reshape(mx_mean.shape) + np_var = np_var.reshape(mx_var.shape) + print(np_mean.shape, mx_mean.shape) + assert np_mean.shape == mx_mean.shape + assert np_var.shape == mx_var.shape + check_symbolic_forward(mx_test_sym, [np_a], [np_mean + np_var], rtol=1e-3, atol=1e-5) + check_numeric_gradient(mx_test_sym, [np_a], numeric_eps=eps, rtol=1e-2, atol=2e-4) + + @with_seed() def test_invalid_kernel_size(): invalid_kernel_size = 28