Skip to content

Commit

Permalink
Image normalize operator - GPU support, 3D/4D inputs (apache#13802)
Browse files Browse the repository at this point in the history
* CPU version of normalize operator is working and unit test added

* Add GPU implementation and tests

* Working GPU normalize transforms

* Add default values, fix imports, fix documentation

* Add backward implmentation for image normalize

* Add tests for backward pass

* Move back operators to its original files

* Add review comments

* Add 4D example

* Make infer type generic

* Fix inline function build error

* make functions as inline to avoid multiple definition conflict across cc and cu

* Fix build errors

* Fix failing GPU tests
  • Loading branch information
sandeep-krishnamurthy authored and haohuw committed Jun 23, 2019
1 parent 9ed5759 commit 58f27de
Show file tree
Hide file tree
Showing 7 changed files with 497 additions and 57 deletions.
25 changes: 22 additions & 3 deletions python/mxnet/gluon/data/vision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def hybrid_forward(self, F, x):


class Normalize(HybridBlock):
"""Normalize an tensor of shape (C x H x W) with mean and
"""Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and
standard deviation.
Given mean `(m1, ..., mn)` and std `(s1, ..., sn)` for `n` channels,
Expand All @@ -154,12 +154,31 @@ class Normalize(HybridBlock):
Inputs:
- **data**: input tensor with (C x H x W) shape.
- **data**: input tensor with (C x H x W) or (N x C x H x W) shape.
Outputs:
- **out**: output tensor with the shape as `data`.
Examples
--------
>>> transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
>>> image = mx.nd.random.uniform(0, 1, (3, 4, 2))
>>> transformer(image)
[[[ 0.18293785 0.19761486]
[ 0.23839645 0.28142193]
[ 0.20092112 0.28598186]
[ 0.18162774 0.28241724]]
[[-0.2881726 -0.18821815]
[-0.17705294 -0.30780914]
[-0.2812064 -0.3512327 ]
[-0.05411351 -0.4716435 ]]
[[-1.0363373 -1.7273437 ]
[-1.6165586 -1.5223348 ]
[-1.208275 -1.1878313 ]
[-1.4711051 -1.5200229 ]]]
<NDArray 3x4x2 @cpu(0)>
"""
def __init__(self, mean, std):
def __init__(self, mean=0.0, std=1.0):
super(Normalize, self).__init__()
self._mean = mean
self._std = std
Expand Down
Loading

0 comments on commit 58f27de

Please sign in to comment.