This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Make to_tensor and normalize to accept 3D or 4D tensor inputs #13614

Closed
32 changes: 27 additions & 5 deletions python/mxnet/gluon/data/vision/transforms.py
@@ -96,17 +96,20 @@ def hybrid_forward(self, F, x):


class ToTensor(HybridBlock):
"""Converts an image NDArray to a tensor NDArray.
"""Converts an image NDArray or batch of image NDArray to a tensor NDArray.

Converts an image NDArray of shape (H x W x C) in the range
[0, 255] to a float32 tensor NDArray of shape (C x H x W) in
the range [0, 1).

For batch input, converts a batch of image NDArrays of shape (N x H x W x C) in the
range [0, 255] to a float32 tensor NDArray of shape (N x C x H x W).

Inputs:
- **data**: input tensor with (H x W x C) shape and uint8 type.
- **data**: input tensor with (H x W x C) or (N x H x W x C) shape and uint8 type.

Outputs:
- **out**: output tensor with (C x H x W) shape and float32 type.
- **out**: output tensor with (C x H x W) or (N x C x H x W) shape and float32 type.

Examples
--------
@@ -135,7 +138,7 @@ def hybrid_forward(self, F, x):


class Normalize(HybridBlock):
"""Normalize an tensor of shape (C x H x W) with mean and
"""Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and
standard deviation.

Given mean `(m1, ..., mn)` and std `(s1, ..., sn)` for `n` channels,
@@ -154,10 +157,29 @@ class Normalize(HybridBlock):


Inputs:
- **data**: input tensor with (C x H x W) shape.
- **data**: input tensor with (C x H x W) or (N x C x H x W) shape.

Outputs:
- **out**: output tensor with the shape as `data`.

Examples
--------
>>> transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
>>> image = mx.nd.random.uniform(0, 1, (3, 4, 2))
Contributor: Can you give an example of 4D (N x C x H x W) here?

>>> transformer(image)
[[[ 0.18293785 0.19761486]
[ 0.23839645 0.28142193]
[ 0.20092112 0.28598186]
[ 0.18162774 0.28241724]]
[[-0.2881726 -0.18821815]
[-0.17705294 -0.30780914]
[-0.2812064 -0.3512327 ]
[-0.05411351 -0.4716435 ]]
[[-1.0363373 -1.7273437 ]
[-1.6165586 -1.5223348 ]
[-1.208275 -1.1878313 ]
[-1.4711051 -1.5200229 ]]]
<NDArray 3x4x2 @cpu(0)>
"""
def __init__(self, mean, std):
super(Normalize, self).__init__()
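In response to the reviewer's request above for a 4D example, the following is a hedged sketch (not part of the PR diff) of how the batched path could be exercised once this change is in place; shapes are illustrative and the numeric output is omitted because it depends on the random data.

```python
import mxnet as mx
from mxnet.gluon.data.vision import transforms

# A batch of 2 images, each 4x5 pixels with 3 channels, in (N, H, W, C) layout.
batch = mx.nd.random.uniform(0, 255, (2, 4, 5, 3)).astype('uint8')

# ToTensor moves channels first and scales by 1/255: (N, H, W, C) -> (N, C, H, W).
tensor = transforms.ToTensor()(batch)
print(tensor.shape)      # expected: (2, 3, 4, 5)

# Normalize subtracts the per-channel mean and divides by the per-channel std.
normalized = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(tensor)
print(normalized.shape)  # expected: (2, 3, 4, 5)
```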
128 changes: 93 additions & 35 deletions src/operator/image/image_random-inl.h
@@ -47,9 +47,14 @@ inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
TShape &shp = (*in_attrs)[0];
if (!shp.ndim()) return false;
CHECK_EQ(shp.ndim(), 3)
<< "Input image must have shape (height, width, channels), but got " << shp;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]}));
CHECK((shp.ndim() == 3) || (shp.ndim() == 4))
<< "Input image must have shape (height, width, channels), or "
<< "(N, height, width, channels) but got " << shp;
if (shp.ndim() == 3) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]}));
Contributor: It would be clearer to define enum constants N, W, C, H instead of using 0, 1, 2, 3.

} else if (shp.ndim() == 4) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[0], shp[3], shp[1], shp[2]}));
Contributor: Can shp[0] be zero?

}
return true;
}

@@ -62,6 +67,23 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
return (*in_attrs)[0] != -1;
}

void ToTensorImpl(const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const int length,
const int channel,
const int step = 0) {
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
float* output = outputs[0].dptr<float>();
DType* input = inputs[0].dptr<DType>();

Contributor: Suggested change: `#pragma omp parallel for collapse(2)`. Would this give better performance?
Contributor (author): It cannot be used inside the macro.
Contributor: How about checking the DType ourselves without using the macro?

for (int l = 0; l < length; ++l) {
for (int c = 0; c < channel; ++c) {
output[step + c*length + l] = static_cast<float>(input[step + l*channel + c]) / 255.0f;
Contributor: 255.0f is already a float, so this cast may not be needed here.
Contributor: Another question: why 255.0f? Using a constant variable with a clear name would be more readable.
Contributor (author): Agreed. Making the change.

}
}
});
}
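To make the flattened index arithmetic above easier to follow, here is a hedged NumPy sketch (illustration only, not part of the PR) of what one ToTensorImpl call computes for a single (H x W x C) image; the step argument simply offsets into the flat buffer to select an image within a batch.

```python
import numpy as np

def to_tensor_reference(image_hwc):
    """Reference for one (H, W, C) uint8 image: channel-first float32 scaled by 1/255."""
    return np.transpose(image_hwc.astype(np.float32) / 255.0, (2, 0, 1))

img = np.random.randint(0, 256, (4, 5, 3), dtype=np.uint8)
out = to_tensor_reference(img)
assert out.shape == (3, 4, 5)
# Element-wise, out[c, h, w] == img[h, w, c] / 255.0, which matches
# output[c*length + l] = input[l*channel + c] / 255.0f with l = h*W + w.
```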

void ToTensor(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
@@ -70,19 +92,23 @@ void ToTensor(const nnvm::NodeAttrs &attrs,
CHECK_EQ(req[0], kWriteTo)
<< "`to_tensor` does not support inplace";

int length = inputs[0].shape_[0] * inputs[0].shape_[1];
int channel = inputs[0].shape_[2];

MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
float* output = outputs[0].dptr<float>();
DType* input = inputs[0].dptr<DType>();

for (int l = 0; l < length; ++l) {
for (int c = 0; c < channel; ++c) {
output[c*length + l] = static_cast<float>(input[l*channel + c]) / 255.0f;
}
// 3D Input - 1 image
if (inputs[0].ndim() == 3) {
const int length = inputs[0].shape_[0] * inputs[0].shape_[1];
const int channel = inputs[0].shape_[2];
ToTensorImpl(inputs, outputs, length, channel);
} else if (inputs[0].ndim() == 4) {
// 4D input batch of images
const int batch_size = inputs[0].shape_[0];
Contributor: How large can this value be in practice?

const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
const int channel = inputs[0].shape_[3];
const int step = channel * length;
Contributor: How large can this value be in practice?
Contributor: Good point. IMO, making step an unsigned long long int would be enough?
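To put rough numbers on the question above (an illustrative calculation, not taken from the PR): for 1080p RGB frames the per-image offset is already in the millions, so a sufficiently large batch can push n*step past what a 32-bit signed int holds.

```python
# Illustrative only: how quickly n*step can outgrow a 32-bit signed int.
step = 3 * 1920 * 1080       # 6,220,800 elements per 1080p RGB image
int32_max = 2**31 - 1        # 2,147,483,647
print(int32_max // step)     # 345 -> a batch of roughly 346 such frames would overflow
```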


#pragma omp parallel for
for (auto n = 0; n < batch_size; ++n) {
ToTensorImpl(inputs, outputs, length, channel, n*step);
Contributor: Good change!

}
});
}
}

struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
@@ -103,14 +129,24 @@ inline bool NormalizeShape(const nnvm::NodeAttrs& attrs,
const auto& dshape = (*in_attrs)[0];
if (!dshape.ndim()) return false;

CHECK_EQ(dshape.ndim(), 3)
<< "Input tensor must have shape (channels, height, width), but got "
<< dshape;
auto nchannels = dshape[0];
CHECK(nchannels == 3 || nchannels == 1)
CHECK((dshape.ndim() == 3) || (dshape.ndim() == 4))
<< "Input tensor must have shape (channels, height, width), or "
<< "(N, channels, height, width), but got " << dshape;

uint32_t nchannels;
Contributor: (comment text not captured)
Contributor (author): Thanks. Good learning!

if (dshape.ndim() == 3) {
nchannels = dshape[0];
CHECK(nchannels == 3 || nchannels == 1)
<< "The first dimension of input tensor must be the channel dimension with "
<< "either 1 or 3 elements, but got input with shape " << dshape;
CHECK(param.mean.ndim() == 1 || param.mean.ndim() == nchannels)
} else if (dshape.ndim() == 4) {
nchannels = dshape[1];
CHECK(nchannels == 3 || nchannels == 1)
<< "The second dimension of input tensor must be the channel dimension with "
<< "either 1 or 3 elements, but got input with shape " << dshape;
}

CHECK((param.mean.ndim() == 1) || (param.mean.ndim() == nchannels))
<< "Invalid mean for input with shape " << dshape
<< ". mean must have either 1 or " << nchannels
<< " elements, but got " << param.mean;
@@ -123,28 +159,50 @@
return true;
}

void NormalizeImpl(const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const NormalizeParam &param,
const int length,
const int channel,
const int step = 0) {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
DType* input = inputs[0].dptr<DType>();
DType* output = outputs[0].dptr<DType>();

for (int i = 0; i < channel; ++i) {
DType mean = param.mean[param.mean.ndim() > 1 ? i : 0];
DType std_dev = param.std[param.std.ndim() > 1 ? i : 0];
Contributor: mean and std_dev should be the float type defined at lines 115 and 116.

for (int j = 0; j < length; ++j) {
output[step + i*length + j] = (input[step + i*length + j] - mean) / std_dev;
Contributor: If the input is int, should it be int or float after normalization? I prefer float here.
Contributor (author): Good point, it should be float. Making the change.

}
}
});
}
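As a rough NumPy counterpart to NormalizeImpl above (an illustration only, with broadcasting standing in for the explicit channel loop), the per-channel computation is sketched below.

```python
import numpy as np

def normalize_reference(tensor_chw, mean, std):
    """Reference for one (C, H, W) tensor: (x - mean[c]) / std[c] for each channel c."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (tensor_chw - mean) / std

x = np.random.rand(3, 4, 5).astype(np.float32)
y = normalize_reference(x, mean=(0.0, 1.0, 2.0), std=(3.0, 2.0, 1.0))
assert y.shape == (3, 4, 5)
```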

void Normalize(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);

int nchannels = inputs[0].shape_[0];
int length = inputs[0].shape_[1] * inputs[0].shape_[2];

MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
DType* input = inputs[0].dptr<DType>();
DType* output = outputs[0].dptr<DType>();

for (int i = 0; i < nchannels; ++i) {
DType mean = param.mean[param.mean.ndim() > 1 ? i : 0];
DType std = param.std[param.std.ndim() > 1 ? i : 0];
for (int j = 0; j < length; ++j) {
output[i*length + j] = (input[i*length + j] - mean) / std;
}
// 3D input (c, h, w)
if (inputs[0].ndim() == 3) {
const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
const int channel = inputs[0].shape_[0];
NormalizeImpl(inputs, outputs, param, length, channel);
} else if (inputs[0].ndim() == 4) {
// 4D input (n, c, h, w)
const int batch_size = inputs[0].shape_[0];
const int length = inputs[0].shape_[2] * inputs[0].shape_[3];
const int channel = inputs[0].shape_[1];
const int step = channel*length;
sandeep-krishnamurthy marked this conversation as resolved.

#pragma omp parallel for
for (auto n = 0; n < batch_size; ++n) {
NormalizeImpl(inputs, outputs, param, length, channel, n*step);
}
});
}
}

template<typename DType>
58 changes: 47 additions & 11 deletions tests/python/unittest/test_gluon_data_vision.py
@@ -19,30 +19,66 @@
import mxnet.ndarray as nd
import numpy as np
from mxnet import gluon
from mxnet.base import MXNetError
from mxnet.gluon.data.vision import transforms
from mxnet.test_utils import assert_almost_equal
from mxnet.test_utils import almost_equal
from common import setup_module, with_seed, teardown

from common import assertRaises, setup_module, with_seed, teardown

@with_seed()
def test_to_tensor():
# 3D Input
data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
sandeep-krishnamurthy marked this conversation as resolved.
out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
out_nd = transforms.ToTensor()(nd.array(data_in))
Contributor: Why remove the dtype in the original 3D input test?

assert_almost_equal(out_nd.asnumpy(), np.transpose(
data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1)))

# 4D Input
data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8)
out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
assert_almost_equal(out_nd.asnumpy(), np.transpose(
data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2)))

# Invalid Input
invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8)
transformer = transforms.ToTensor()
assertRaises(MXNetError, transformer, invalid_data_in)


@with_seed()
def test_normalize():
data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
data_in = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
out_nd = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in)
data_expected = data_in.asnumpy()
data_expected[:][:][0] = data_expected[:][:][0] / 3.0
data_expected[:][:][1] = (data_expected[:][:][1] - 1.0) / 2.0
data_expected[:][:][2] = data_expected[:][:][2] - 2.0
assert_almost_equal(data_expected, out_nd.asnumpy())
# 3D Input
data_in_3d = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
data_in_3d = transforms.ToTensor()(nd.array(data_in_3d, dtype='uint8'))
out_nd_3d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_3d)
data_expected_3d = data_in_3d.asnumpy()
data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0
data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0
data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0
assert_almost_equal(data_expected_3d, out_nd_3d.asnumpy())

# 4D Input
data_in_4d = np.random.uniform(0, 255, (2, 300, 300, 3)).astype(dtype=np.uint8)
data_in_4d = transforms.ToTensor()(nd.array(data_in_4d, dtype='uint8'))
out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d)
data_expected_4d = data_in_4d.asnumpy()
data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0
data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0
data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0
data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0
data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0
data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0
assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy())

# Invalid Input - Neither 3D or 4D input
invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300)).astype(dtype=np.float32)
Contributor: 👍

normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
assertRaises(MXNetError, normalize_transformer, invalid_data_in)

# Invalid Input - Channel neither 1 or 3
invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300)).astype(dtype=np.float32)
normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
assertRaises(MXNetError, normalize_transformer, invalid_data_in)
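One more hedged sketch (not part of the PR) of an invariant the batched path should satisfy: normalizing a 4D batch should give the same result as normalizing each 3D image separately.

```python
import mxnet as mx
from mxnet.gluon.data.vision import transforms
from mxnet.test_utils import assert_almost_equal

batch = mx.nd.random.uniform(0, 255, (2, 32, 32, 3)).astype('uint8')
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))

batched = normalize(to_tensor(batch))                                    # 4D path
per_image = mx.nd.stack(*[normalize(to_tensor(img)) for img in batch])   # 3D path, image by image
assert_almost_equal(batched.asnumpy(), per_image.asnumpy())
```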


@with_seed()