Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1180] Java Image API #13807

Merged
merged 11 commits into from
Jan 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions scala-package/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,6 @@
<version>INTERNAL</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image
* to mxnet's default RGB format (instead of opencv's default BGR).
* @return NDArray in HWC format
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(buf: Array[Byte], flag: Int,
to_rgb: Boolean,
Expand All @@ -56,7 +56,7 @@ object Image {
/**
* Same imageDecode with InputStream
* @param inputStream the inputStream of the image
* @return NDArray in HWC format
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(inputStream: InputStream, flag: Int = 1,
to_rgb: Boolean = true,
Expand All @@ -78,7 +78,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image to mxnet's default RGB format
* (instead of opencv's default BGR).
* @return org.apache.mxnet.NDArray in HWC format
* @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
*/
def imRead(filename: String, flag: Option[Int] = None,
to_rgb: Option[Boolean] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,11 @@ object NDArray extends NDArrayBase {
case ndArr: Seq[NDArray @unchecked] =>
if (ndArr.head.isInstanceOf[NDArray]) (ndArr.toArray, ndArr.toArray.map(_.handle))
else throw new IllegalArgumentException(
"Unsupported out var type, should be NDArray or subclass of Seq[NDArray]")
s"""Unsupported out ${output.getClass} type,
| should be NDArray or subclass of Seq[NDArray]""".stripMargin)
case _ => throw new IllegalArgumentException(
"Unsupported out var type, should be NDArray or subclass of Seq[NDArray]")
s"""Unsupported out ${output.getClass} type,
| should be NDArray or subclass of Seq[NDArray]""".stripMargin)
}
} else {
(null, null)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.javaapi
// scalastyle:off
import java.awt.image.BufferedImage
// scalastyle:on
import java.io.InputStream

object Image {
  /**
    * Decode image with OpenCV.
    * Note: return image in RGB by default, instead of OpenCV's default BGR.
    * @param buf    Buffer containing binary encoded image
    * @param flag   Convert decoded image to grayscale (0) or color (1).
    * @param toRGB  Whether to convert decoded image
    *               to mxnet's default RGB format (instead of opencv's default BGR).
    * @return NDArray in HWC format with DType [[DType.UInt8]]
    */
  def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = {
    // NOTE(review): the trailing None is presumably the optional output array of the
    // Scala API -- confirm against org.apache.mxnet.Image.imDecode.
    org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None)
  }

  /**
    * Decode image with OpenCV using flag = 1 (color) and toRGB = true.
    * @param buf Buffer containing binary encoded image
    * @return NDArray in HWC format with DType [[DType.UInt8]]
    */
  def imDecode(buf: Array[Byte]): NDArray = {
    imDecode(buf, 1, true)
  }

  /**
    * Same imageDecode with InputStream
    *
    * @param inputStream the inputStream of the image
    * @param flag        Convert decoded image to grayscale (0) or color (1).
    * @param toRGB       Whether to convert decoded image
    *                    to mxnet's default RGB format (instead of opencv's default BGR).
    * @return NDArray in HWC format with DType [[DType.UInt8]]
    */
  def imDecode(inputStream: InputStream, flag: Int, toRGB: Boolean): NDArray = {
    org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None)
  }

  /**
    * Same imageDecode with InputStream, using flag = 1 (color) and toRGB = true.
    * @param inputStream the inputStream of the image
    * @return NDArray in HWC format with DType [[DType.UInt8]]
    */
  def imDecode(inputStream: InputStream): NDArray = {
    imDecode(inputStream, 1, true)
  }

  /**
    * Read and decode image with OpenCV.
    * Note: return image in RGB by default, instead of OpenCV's default BGR.
    * @param filename Name of the image file to be loaded.
    * @param flag     Convert decoded image to grayscale (0) or color (1).
    * @param toRGB    Whether to convert decoded image to mxnet's default RGB format
    *                 (instead of opencv's default BGR).
    * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
    */
  def imRead(filename: String, flag: Int, toRGB: Boolean): NDArray = {
    org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None)
  }

  /**
    * Read and decode image with OpenCV using flag = 1 (color) and toRGB = true.
    * @param filename Name of the image file to be loaded.
    * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
    */
  def imRead(filename: String): NDArray = {
    imRead(filename, 1, true)
  }

  /**
    * Resize image with OpenCV.
    * @param src     source image in NDArray
    * @param w       Width of resized image.
    * @param h       Height of resized image.
    * @param interp  Interpolation method (default=cv2.INTER_LINEAR).
    *                May be null, in which case no interpolation value is forwarded.
    * @return org.apache.mxnet.NDArray
    */
  def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = {
    // Wrap the boxed reference itself, not interp.intValue(): Java callers may pass
    // null, and unboxing a null Integer would throw a NullPointerException.
    val interpVal = Option(interp).map(_.intValue())
    org.apache.mxnet.Image.imResize(src, w, h, interpVal, None)
  }

  /**
    * Resize image with OpenCV without specifying an interpolation method
    * (a null interp is passed to the four-argument overload).
    * @param src source image in NDArray
    * @param w   Width of resized image.
    * @param h   Height of resized image.
    * @return org.apache.mxnet.NDArray
    */
  def imResize(src: NDArray, w: Int, h: Int): NDArray = {
    imResize(src, w, h, null)
  }

  /**
    * Do a fixed crop on the image
    * @param src Src image in NDArray
    * @param x0  starting x point
    * @param y0  starting y point
    * @param w   width of the image
    * @param h   height of the image
    * @return cropped NDArray
    */
  def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = {
    org.apache.mxnet.Image.fixedCrop(src, x0, y0, w, h)
  }

  /**
    * Convert a NDArray image to a real image
    * The time cost will increase if the image resolution is big
    * @param src Source image file in RGB
    * @return Buffered Image
    */
  def toImage(src: NDArray): BufferedImage = {
    org.apache.mxnet.Image.toImage(src)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.javaapi;

import org.apache.commons.io.FileUtils;
import org.junit.BeforeClass;
import org.junit.Test;
import java.io.File;
import java.net.URL;

import static org.junit.Assert.assertArrayEquals;

/**
 * JUnit test for the Java Image API: downloads a sample image once, then
 * exercises imRead, imResize, fixedCrop and toImage on it.
 */
public class ImageTest {

    /** Local path of the downloaded sample image; set in {@link #downloadFile()}. */
    private static String imLocation;

    /**
     * Downloads {@code url} to {@code filePath} unless the file already exists.
     *
     * @param url      remote location of the file
     * @param filePath local destination path
     * @param maxRetry maximum number of download attempts
     * @throws Exception if the file does not exist and every attempt failed;
     *                   the last download error (if any) is attached as the cause
     */
    private static void downloadUrl(String url, String filePath, int maxRetry) throws Exception {
        File tmpFile = new File(filePath);
        if (tmpFile.exists()) {
            // Already downloaded by a previous run; nothing to do.
            return;
        }
        Exception lastError = null;
        for (int attempt = 0; attempt < maxRetry; attempt++) {
            try {
                FileUtils.copyURLToFile(new URL(url), tmpFile);
                return;
            } catch (Exception e) {
                lastError = e;
                // Drop any partially written file so a later run does not
                // mistake it for a completed download.
                FileUtils.deleteQuietly(tmpFile);
            }
        }
        // Java has no string interpolation: concatenate the url explicitly.
        throw new Exception(url + " Download failed!", lastError);
    }

    /** Fetches the sample image into the JVM temp directory before any test runs. */
    @BeforeClass
    public static void downloadFile() throws Exception {
        String tempDirPath = System.getProperty("java.io.tmpdir");
        imLocation = tempDirPath + "/inputImages/Pug-Cookie.jpg";
        downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
                imLocation, 3);
    }

    /** Reads, resizes, crops and converts the sample image, checking the resulting shapes. */
    @Test
    public void testImageProcess() {
        NDArray nd = Image.imRead(imLocation, 1, true);
        // assertArrayEquals takes (expected, actual).
        assertArrayEquals(new int[]{576, 1024, 3}, nd.shape().toArray());
        NDArray nd2 = Image.imResize(nd, 224, 224, null);
        assertArrayEquals(new int[]{224, 224, 3}, nd2.shape().toArray());
        NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224);
        Image.toImage(cropped);
    }
}
1 change: 1 addition & 0 deletions scala-package/examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

<properties>
<skipTests>true</skipTests>
<skipJavaTests>${skipTests}</skipJavaTests>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this for? I don't think it is valid in Maven to set one property based on another property, and it is just set to false in modules which have Java tests to run.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This controls whether a JUnit test is going to run or not; by default it is set to false. In fact, this is working...

</properties>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.imageio.ImageIO;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
Expand All @@ -47,76 +45,7 @@ public class PredictorExample {
private String inputImagePath = "/images/dog.jpg";

final static Logger logger = LoggerFactory.getLogger(PredictorExample.class);

/**
 * Reads an image file from disk into a BufferedImage.
 * It can be replaced by loadImageFromFile from ObjectDetector
 * @param inputImagePath path of the image file, as a String
 * @return the decoded image; null if an IOException occurred while reading
 *         (the exception is printed to standard error)
 */
private static BufferedImage loadIamgeFromFile(String inputImagePath) {
    try {
        return ImageIO.read(new File(inputImagePath));
    } catch (IOException readFailure) {
        // Best effort: report the problem and hand back null.
        System.err.println(readFailure);
        return null;
    }
}

/**
 * Reshape the current image using ImageIO and Graph2D
 * It can be replaced by reshapeImage from ObjectDetector
 * @param buf Buffered image
 * @param newWidth desired width
 * @param newHeight desired height
 * @return a new TYPE_INT_RGB bufferedImage of size newWidth x newHeight
 *         with the source image drawn scaled onto it
 */
private static BufferedImage reshapeImage(BufferedImage buf, int newWidth, int newHeight) {
    BufferedImage resizedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB);
    Graphics2D g = resizedImage.createGraphics();
    try {
        g.drawImage(buf, 0, 0, newWidth, newHeight, null);
    } finally {
        // Release the graphics context even if drawing throws.
        g.dispose();
    }
    return resizedImage;
}

/**
 * Convert a buffered image into a float array of pixel values, laid out as
 * three consecutive planes: all red values, then all green, then all blue.
 * Within each plane the pixels are stored row by row.
 * It can be replaced by bufferedImageToPixels from ObjectDetector
 * @param buf buffered image; flush() is called on it before returning
 * @return Float array of length 3 * height * width
 */
private static float[] imagePreprocess(BufferedImage buf) {
    final int width = buf.getWidth();
    final int height = buf.getHeight();
    final int planeSize = height * width;

    // Packed integer pixels in the default RGB color model, row-major.
    final int[] packed = buf.getRGB(0, 0, width, height, null, 0, width);

    // One plane per channel: R, G, B.
    final float[] planar = new float[3 * planeSize];

    // A row-major walk over (row, col) visits indices 0 .. planeSize - 1 in order.
    for (int idx = 0; idx < planeSize; idx++) {
        final int px = packed[idx];
        planar[idx] = (px >> 16) & 0xFF;               // red
        planar[planeSize + idx] = (px >> 8) & 0xFF;    // green
        planar[2 * planeSize + idx] = px & 0xFF;       // blue
    }
    buf.flush();
    return planar;
}
private static NDArray$ NDArray = NDArray$.MODULE$;

/**
* Helper class to print the maximum prediction result
Expand Down Expand Up @@ -170,22 +99,21 @@ public static void main(String[] args) {
inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0);
// Prepare data
BufferedImage img = loadIamgeFromFile(inst.inputImagePath);

img = reshapeImage(img, 224, 224);
NDArray img = Image.imRead(inst.inputImagePath, 1, true);
img = Image.imResize(img, 224, 224, null);
// predict
float[][] result = predictor.predict(new float[][]{imagePreprocess(img)});
float[][] result = predictor.predict(new float[][]{img.toArray()});
try {
System.out.println("Predict with Float input");
System.out.println(printMaximumClass(result[0], inst.modelPathPrefix));
} catch (IOException e) {
System.err.println(e);
}
// predict with NDArray
NDArray nd = new NDArray(
imagePreprocess(img),
new Shape(new int[]{1, 3, 224, 224}),
Context.cpu());
NDArray nd = img;
nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0];
nd = NDArray.expand_dims(nd, 0, null)[0];
nd = nd.asType(DType.Float32());
List<NDArray> ndList = new ArrayList<>();
ndList.add(nd);
List<NDArray> ndResult = predictor.predictWithNDArray(ndList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ import org.apache.commons.io.FileUtils

object Util {

def downloadUrl(url: String, filePath: String, maxRetry: Option[Int] = None) : Unit = {
def downloadUrl(url: String, filePath: String, maxRetry: Int = 3) : Unit = {
val tmpFile = new File(filePath)
var retry = maxRetry.getOrElse(3)
var retry = maxRetry
var success = false
if (!tmpFile.exists()) {
while (retry > 0 && !success) {
Expand Down
Loading