diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala index 9276452c3ba0..9d5216cbc862 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala @@ -31,8 +31,8 @@ object Image { * to mxnet's default RGB format (instead of opencv's default BGR). * @return NDArray in HWC format */ - def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean, out: NDArray): NDArray = { - org.apache.mxnet.Image.imDecode(buf, flag, toRGB, Some(out)) + def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = { + org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None) } /** @@ -40,9 +40,8 @@ object Image { * @param inputStream the inputStream of the image * @return NDArray in HWC format */ - def imDecode(inputStream: InputStream, flag: Int = 1, toRGB: Boolean = true, - out: NDArray): NDArray = { - org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, Some(out)) + def imDecode(inputStream: InputStream, flag: Int = 1, toRGB: Boolean = true): NDArray = { + org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None) } /** @@ -54,8 +53,8 @@ object Image { * (instead of opencv's default BGR). * @return org.apache.mxnet.NDArray in HWC format */ - def imRead(filename: String, flag: Int, toRGB: Boolean = true, out: NDArray): NDArray = { - org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), Some(out)) + def imRead(filename: String, flag: Int, toRGB: Boolean = true): NDArray = { + org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None) } /** @@ -66,9 +65,9 @@ object Image { * @param interp Interpolation method (default=cv2.INTER_LINEAR). * @return org.apache.mxnet.NDArray */ - def imResize(src: NDArray, w: Int, h: Int, - interp: Integer, out: NDArray): NDArray = { - org.apache.mxnet.Image.imResize(src, w, h, Some(interp), Some(out)) + def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = { + val interpVal = if (interp == null) None else Some(interp.intValue()) + org.apache.mxnet.Image.imResize(src, w, h, interpVal, None) } /** diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java index 7a028be30527..049b1d6156be 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java @@ -6,11 +6,13 @@ import java.io.File; import java.net.URL; +import static org.junit.Assert.assertArrayEquals; + public class ImageTest { - private String imLocation; + private static String imLocation; - private void downloadUrl(String url, String filePath, int maxRetry) throws Exception{ + private static void downloadUrl(String url, String filePath, int maxRetry) throws Exception{ File tmpFile = new File(filePath); Boolean success = false; if (!tmpFile.exists()) { @@ -29,7 +31,7 @@ private void downloadUrl(String url, String filePath, int maxRetry) throws Excep } @BeforeClass - public void downloadFile() throws Exception { + public static void downloadFile() throws Exception { String tempDirPath = System.getProperty("java.io.tmpdir"); imLocation = tempDirPath + "/inputImages/Pug-Cookie.jpg"; try { @@ -42,8 +44,10 @@ public void downloadFile() throws Exception { @Test public void testImageProcess() { - NDArray nd = Image.imRead(imLocation, 1, true, null); - NDArray nd2 = Image.imResize(nd, 224, 224, null, null); + NDArray nd = Image.imRead(imLocation, 1, true); + assertArrayEquals(nd.shape().toArray(), new int[]{576, 1024, 3}); + NDArray nd2 = Image.imResize(nd, 224, 224, null); + assertArrayEquals(nd.shape().toArray(), new int[]{224, 224, 3}); NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224); Image.toImage(cropped); } diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java index a1c3401c2378..4559315866ba 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java @@ -98,8 +98,8 @@ public static void main(String[] args) { inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW")); Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0); // Prepare data - NDArray img = Image.imRead(inst.inputImagePath, 1, true, null); - img = Image.imResize(img, 224, 224, null, null); + NDArray img = Image.imRead(inst.inputImagePath, 1, true); + img = Image.imResize(img, 224, 224, null); // predict float[][] result = predictor.predict(new float[][]{img.toArray()}); try {