Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add GPU implementation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sandeep-krishnamurthy committed Jan 8, 2019
1 parent ca4d8ee commit 853caa3
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 28 deletions.
23 changes: 21 additions & 2 deletions python/mxnet/gluon/data/vision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def hybrid_forward(self, F, x):


class Normalize(HybridBlock):
"""Normalize an tensor of shape (C x H x W) with mean and
"""Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and
standard deviation.
Given mean `(m1, ..., mn)` and std `(s1, ..., sn)` for `n` channels,
Expand All @@ -154,10 +154,29 @@ class Normalize(HybridBlock):
Inputs:
- **data**: input tensor with (C x H x W) shape.
- **data**: input tensor with (C x H x W) or (N x C x H x W) shape.
Outputs:
- **out**: output tensor with the shape as `data`.
Examples
--------
>>> transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
>>> image = mx.nd.random.uniform(0, 1, (3, 4, 2))
>>> transformer(image)
[[[ 0.18293785 0.19761486]
[ 0.23839645 0.28142193]
[ 0.20092112 0.28598186]
[ 0.18162774 0.28241724]]
[[-0.2881726 -0.18821815]
[-0.17705294 -0.30780914]
[-0.2812064 -0.3512327 ]
[-0.05411351 -0.4716435 ]]
[[-1.0363373 -1.7273437 ]
[-1.6165586 -1.5223348 ]
[-1.208275 -1.1878313 ]
[-1.4711051 -1.5200229 ]]]
<NDArray 3x4x2 @cpu(0)>
"""
def __init__(self, mean, std):
super(Normalize, self).__init__()
Expand Down
26 changes: 13 additions & 13 deletions src/operator/image/normalize_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
* \file normalize_op-inl.h
* \brief Image normalization operator
*/
#ifndef MXNET_OPERATOR_IMAGE_NORMALIZE_INL_H_
#define MXNET_OPERATOR_IMAGE_NORMALIZE_INL_H_
#ifndef MXNET_OPERATOR_IMAGE_NORMALIZE_OP_INL_H_
#define MXNET_OPERATOR_IMAGE_NORMALIZE_OP_INL_H_


#include <mxnet/base.h>
Expand Down Expand Up @@ -60,15 +60,15 @@ inline bool NormalizeOpShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);

const auto& dshape = (*in_attrs)[0];
if (!dshape.ndim()) return false;

CHECK((dshape.ndim() == 3) || (dshape.ndim() == 4))
<< "Input tensor must have shape (channels, height, width), or "
<< "(N, channels, height, width), but got " << dshape;
uint32_t nchannels;

int32_t nchannels;
if (dshape.ndim() == 3) {
nchannels = dshape[0];
CHECK(nchannels == 3 || nchannels == 1)
Expand Down Expand Up @@ -103,7 +103,7 @@ inline bool NormalizeOpType(const nnvm::NodeAttrs& attrs,

// Normalized Tensor will be a float
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
return out_attrs->at(0) != -1;
return out_attrs->at(0) != -1;
}

template<int req>
Expand All @@ -112,7 +112,7 @@ struct normalize_forward {
MSHADOW_XINLINE static void Map(int j, DType* out_data, const DType* in_data,
const int i, const int length, const int step,
const DType mean, const DType std_dev) {
KERNEL_ASSIGN(out_data[step + i*length + j], req,
KERNEL_ASSIGN(out_data[step + i*length + j], req,
(in_data[step + i*length + j] - mean) / std_dev);
}
};
Expand All @@ -139,7 +139,7 @@ void NormalizeImpl(const OpContext &ctx,
mxnet_op::Kernel<normalize_forward<req_type>, xpu>::Launch(
s, length, output, input,
i, length, step, mean, std_dev);
}
}
});
});
}
Expand All @@ -166,7 +166,7 @@ void NormalizeOpForward(const nnvm::NodeAttrs &attrs,
const int batch_size = inputs[0].shape_[0];
const int length = inputs[0].shape_[2] * inputs[0].shape_[3];
const int channel = inputs[0].shape_[1];
const int step = channel*length;
const int step = channel * length;

#pragma omp parallel for
for (auto n = 0; n < batch_size; ++n) {
Expand All @@ -175,7 +175,7 @@ void NormalizeOpForward(const nnvm::NodeAttrs &attrs,
}
}

} // namespace image
} // namespace op
} // namespace mxnet
#endif //MXNET_OPERATOR_IMAGE_NORMALIZE_INL_H_
} // namespace image
} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_IMAGE_NORMALIZE_OP_INL_H_
24 changes: 11 additions & 13 deletions src/operator/image/normalize_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,16 @@
* \file normalize_op.cu
* \brief GPU Implementation of Normalize op
*/
#include "./normalize_op-inl.h"
#include "./normalize_op-inl.h"

namespace mxnet {
namespace op {
namespace image {
namespace mxnet {
namespace op {
namespace image {

NNVM_REGISTER_OP(_image_normalize)
.set_attr<FComputeEx>("FComputeEx<gpu>", NormalizeOpForward<gpu>)
.set_attr<FCompute>("FCompute<gpu>", NormalizeOpForward<gpu>);
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })

} // namespace image
} // namespace op
} // namespace mxnet

NNVM_REGISTER_OP(_image_normalize)
.set_attr<FCompute>("FCompute", NormalizeOpForward<gpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" });

} // namespace image
} // namespace op
} // namespace mxnet
66 changes: 66 additions & 0 deletions tests/python/gpu/test_gluon_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import print_function
import os
import sys
import mxnet as mx
import mxnet.ndarray as nd
import numpy as np
from mxnet import gluon
from mxnet.base import MXNetError
from mxnet.gluon.data.vision import transforms
from mxnet.test_utils import assert_almost_equal, set_default_context
from mxnet.test_utils import almost_equal
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import assertRaises, setup_module, with_seed, teardown


set_default_context(mx.gpu(0))

@with_seed()
def test_normalize():
# 3D Input
data_in_3d = nd.random.uniform(0, 1, (3, 300, 300))
out_nd_3d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_3d)
data_expected_3d = data_in_3d.asnumpy()
data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0
data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0
data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0
assert_almost_equal(data_expected_3d, out_nd_3d.asnumpy())

# 4D Input
data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300))
out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d)
data_expected_4d = data_in_4d.asnumpy()
data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0
data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0
data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0
data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0
data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0
data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0
assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy())

# Invalid Input - Neither 3D or 4D input
invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300))
normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
assertRaises(MXNetError, normalize_transformer, invalid_data_in)

# Invalid Input - Channel neither 1 or 3
invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300))
normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
assertRaises(MXNetError, normalize_transformer, invalid_data_in)

0 comments on commit 853caa3

Please sign in to comment.