Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix the issue with JUnit test
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jan 15, 2019
1 parent 2632c0d commit 4936b6a
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image
* to mxnet's default RGB format (instead of opencv's default BGR).
* @return NDArray in HWC format
* @return NDArray in HWC format with DType uint8
*/
def imDecode(buf: Array[Byte], flag: Int,
to_rgb: Boolean,
Expand All @@ -56,7 +56,7 @@ object Image {
/**
* Same imageDecode with InputStream
* @param inputStream the inputStream of the image
* @return NDArray in HWC format
* @return NDArray in HWC format with DType uint8
*/
def imDecode(inputStream: InputStream, flag: Int = 1,
to_rgb: Boolean = true,
Expand All @@ -78,7 +78,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image to mxnet's default RGB format
* (instead of opencv's default BGR).
* @return org.apache.mxnet.NDArray in HWC format
* @return org.apache.mxnet.NDArray in HWC format with DType uint8
*/
def imRead(filename: String, flag: Option[Int] = None,
to_rgb: Option[Boolean] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,11 @@ object NDArray extends NDArrayBase {
case ndArr: Seq[NDArray @unchecked] =>
if (ndArr.head.isInstanceOf[NDArray]) (ndArr.toArray, ndArr.toArray.map(_.handle))
else throw new IllegalArgumentException(
"Unsupported out var type, should be NDArray or subclass of Seq[NDArray]")
s"""Unsupported out ${output.getClass} type,
| should be NDArray or subclass of Seq[NDArray]""".stripMargin)
case _ => throw new IllegalArgumentException(
"Unsupported out var type, should be NDArray or subclass of Seq[NDArray]")
s"""Unsupported out ${output.getClass} type,
| should be NDArray or subclass of Seq[NDArray]""".stripMargin)
}
} else {
(null, null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image
* to mxnet's default RGB format (instead of opencv's default BGR).
* @return NDArray in HWC format
* @return NDArray in HWC format with DType uint8
*/
def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = {
org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None)
Expand All @@ -38,7 +38,7 @@ object Image {
/**
* Same imageDecode with InputStream
* @param inputStream the inputStream of the image
* @return NDArray in HWC format
* @return NDArray in HWC format with DType uint8
*/
def imDecode(inputStream: InputStream, flag: Int = 1, toRGB: Boolean = true): NDArray = {
org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None)
Expand All @@ -51,7 +51,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image to mxnet's default RGB format
* (instead of opencv's default BGR).
* @return org.apache.mxnet.NDArray in HWC format
* @return org.apache.mxnet.NDArray in HWC format with DType uint8
*/
def imRead(filename: String, flag: Int, toRGB: Boolean = true): NDArray = {
org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class PredictorExample {
private String inputImagePath = "/images/dog.jpg";

final static Logger logger = LoggerFactory.getLogger(PredictorExample.class);
private static NDArray$ NDArray = NDArray$.MODULE$;

/**
* Helper class to print the maximum prediction result
Expand Down Expand Up @@ -110,6 +111,9 @@ public static void main(String[] args) {
}
// predict with NDArray
NDArray nd = img;
nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0];
nd = NDArray.expand_dims(nd, 0, null)[0];
nd = nd.asType(DType.Float32());
List<NDArray> ndList = new ArrayList<>();
ndList.add(nd);
List<NDArray> ndResult = predictor.predictWithNDArray(ndList);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.apache.mxnetexamples.javaapi.infer.predictor;

import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.apache.mxnetexamples.Util;
Expand All @@ -8,9 +9,9 @@

import java.io.File;

public class PredictorExampleSuite {
public class PredictorExampleTest {

final static Logger logger = LoggerFactory.getLogger(PredictorExampleSuite.class);
final static Logger logger = LoggerFactory.getLogger(PredictorExampleTest.class);
private static String modelPathPrefix = "";
private static String inputImagePath = "";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ private[mxnet] object JavaNDArrayMacro extends GeneratorBase {
// add default out parameter
argDef += s"out: org.apache.mxnet.javaapi.NDArray"
if (useParamObject) {
impl += "if (po.getOut() != null) map(\"out\") = po.getOut()"
impl += "if (po.getOut() != null) map(\"out\") = po.getOut().nd"
} else {
impl += "if (out != null) map(\"out\") = out"
impl += "if (out != null) map(\"out\") = out.nd"
}
val returnType = "Array[org.apache.mxnet.javaapi.NDArray]"
// scalastyle:off
Expand Down

0 comments on commit 4936b6a

Please sign in to comment.