Skip to content

Commit

Permalink
[MXNET-1263] Unit Tests for Java Predictor and Object Detector APIs (a…
Browse files Browse the repository at this point in the history
…pache#13794)

* Added unit tests for Predictor API in Java

* Added unit tests for ObjectDetectorOutput

* Added unit tests for ObjectDetector API in Java

* Addressed PR comments

* Added Maven SureFire plugin to run the Java UTs

* Pom file clean up -- moved surefire plugin to parent pom.xml

* Renamed skipTests to SkipJavaTests
  • Loading branch information
piyushghai authored and haohuw committed Jun 23, 2019
1 parent 4d4fe8b commit 2302420
Show file tree
Hide file tree
Showing 6 changed files with 287 additions and 11 deletions.
14 changes: 4 additions & 10 deletions scala-package/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
<artifactId>mxnet-core</artifactId>
<name>MXNet Scala Package - Core</name>

<properties>
<skipJavaTests>false</skipJavaTests>
</properties>

<build>
<plugins>
<plugin>
Expand Down Expand Up @@ -115,16 +119,6 @@
</argLine>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.22.0</version>
<configuration>
<argLine>
-Djava.library.path=${project.parent.basedir}/native/target
</argLine>
</configuration>
</plugin>
<plugin>
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
Expand Down
12 changes: 12 additions & 0 deletions scala-package/infer/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
<relativePath>../pom.xml</relativePath>
</parent>

<properties>
<skipJavaTests>false</skipJavaTests>
</properties>

<artifactId>mxnet-infer</artifactId>
<name>MXNet Scala Package - Inference</name>

Expand Down Expand Up @@ -60,5 +64,13 @@
<version>1.10.19</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>

</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.infer.javaapi;

import org.junit.Assert;
import org.junit.Test;

public class ObjectDetectorOutputTest {

private String predictedClassName = "lion";

private float delta = 0.00001f;

@Test
public void testConstructor() {

float[] arr = new float[]{0f, 1f, 2f, 3f, 4f};

ObjectDetectorOutput odOutput = new ObjectDetectorOutput(predictedClassName, arr);

Assert.assertEquals(odOutput.getClassName(), predictedClassName);
Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 2f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 3f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMax(), 4f, delta);

}

@Test (expected = ArrayIndexOutOfBoundsException.class)
public void testIncompleteArgsConstructor() {

float[] arr = new float[]{0f, 1f};

ObjectDetectorOutput odOutput = new ObjectDetectorOutput(predictedClassName, arr);

Assert.assertEquals(odOutput.getClassName(), predictedClassName);
Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta);

// This is where exception will be thrown
odOutput.getXMax();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.infer.javaapi;

import org.apache.mxnet.Layout;
import org.apache.mxnet.javaapi.DType;
import org.apache.mxnet.javaapi.DataDesc;
import org.apache.mxnet.javaapi.NDArray;
import org.apache.mxnet.javaapi.Shape;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.List;

public class ObjectDetectorTest {

private List<DataDesc> inputDesc;
private BufferedImage inputImage;

private List<List<ObjectDetectorOutput>> expectedResult;

private ObjectDetector objectDetector;

private int batchSize = 1;

private int channels = 3;

private int imageHeight = 512;

private int imageWidth = 512;

private String dataName = "data";

private int topK = 5;

private String predictedClassName = "lion"; // Random string

private Shape getTestShape() {

return new Shape(new int[] {batchSize, channels, imageHeight, imageWidth});
}

@Before
public void setUp() {

inputDesc = new ArrayList<>();
inputDesc.add(new DataDesc(dataName, getTestShape(), DType.Float32(), Layout.NCHW()));
inputImage = new BufferedImage(imageWidth, imageHeight, BufferedImage.TYPE_INT_RGB);
objectDetector = Mockito.mock(ObjectDetector.class);
expectedResult = new ArrayList<>();
expectedResult.add(new ArrayList<ObjectDetectorOutput>());
expectedResult.get(0).add(new ObjectDetectorOutput(predictedClassName, new float[]{}));
}

@Test
public void testObjectDetectorWithInputImage() {

Mockito.when(objectDetector.imageObjectDetect(inputImage, topK)).thenReturn(expectedResult);
List<List<ObjectDetectorOutput>> actualResult = objectDetector.imageObjectDetect(inputImage, topK);
Mockito.verify(objectDetector, Mockito.times(1)).imageObjectDetect(inputImage, topK);
Assert.assertEquals(expectedResult, actualResult);
}


@Test
public void testObjectDetectorWithBatchImage() {

List<BufferedImage> batchImage = new ArrayList<>();
batchImage.add(inputImage);
Mockito.when(objectDetector.imageBatchObjectDetect(batchImage, topK)).thenReturn(expectedResult);
List<List<ObjectDetectorOutput>> actualResult = objectDetector.imageBatchObjectDetect(batchImage, topK);
Mockito.verify(objectDetector, Mockito.times(1)).imageBatchObjectDetect(batchImage, topK);
Assert.assertEquals(expectedResult, actualResult);
}

@Test
public void testObjectDetectorWithNDArrayInput() {

NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, getTestShape());
List<NDArray> inputL = new ArrayList<>();
inputL.add(inputArr);
Mockito.when(objectDetector.objectDetectWithNDArray(inputL, 5)).thenReturn(expectedResult);
List<List<ObjectDetectorOutput>> actualResult = objectDetector.objectDetectWithNDArray(inputL, topK);
Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK);
Assert.assertEquals(expectedResult, actualResult);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.infer.javaapi;

import org.apache.mxnet.javaapi.Context;
import org.apache.mxnet.javaapi.NDArray;
import org.apache.mxnet.javaapi.Shape;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class PredictorTest {

Predictor mockPredictor;

@Before
public void setUp() {
mockPredictor = Mockito.mock(Predictor.class);
}

@Test
public void testPredictWithFloatArray() {

float tmp[][] = new float[1][224];
for (int x = 0; x < 1; x++) {
for (int y = 0; y < 224; y++)
tmp[x][y] = (int) (Math.random() * 10);
}

float [][] expectedResult = new float[][] {{1f, 2f}};
Mockito.when(mockPredictor.predict(tmp)).thenReturn(expectedResult);
float[][] actualResult = mockPredictor.predict(tmp);

Mockito.verify(mockPredictor, Mockito.times(1)).predict(tmp);
Assert.assertArrayEquals(expectedResult, actualResult);
}

@Test
public void testPredictWithNDArray() {

float[] tmpArr = new float[224];
for (int y = 0; y < 224; y++)
tmpArr[y] = (int) (Math.random() * 10);

NDArray arr = new org.apache.mxnet.javaapi.NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0));

List<NDArray> inputList = new ArrayList<>();
inputList.add(arr);

NDArray expected = new NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0));
List<NDArray> expectedResult = new ArrayList<>();
expectedResult.add(expected);

Mockito.when(mockPredictor.predictWithNDArray(inputList)).thenReturn(expectedResult);

List<NDArray> actualOutput = mockPredictor.predictWithNDArray(inputList);

Mockito.verify(mockPredictor, Mockito.times(1)).predictWithNDArray(inputList);

Assert.assertEquals(expectedResult, actualOutput);
}

@Test
public void testPredictWithListOfFloatsAsInput() {
List<List<Float>> input = new ArrayList<>();

input.add(Arrays.asList(new Float[] {1f, 2f}));

List<List<Float>> expectedOutput = new ArrayList<>(input);

Mockito.when(mockPredictor.predict(input)).thenReturn(expectedOutput);

List<List<Float>> actualOutput = mockPredictor.predict(input);

Mockito.verify(mockPredictor, Mockito.times(1)).predict(input);

Assert.assertEquals(expectedOutput, actualOutput);

}
}
7 changes: 6 additions & 1 deletion scala-package/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
<cxx>g++</cxx>
<dollar>$</dollar>
<MXNET_DIR>${project.basedir}/..</MXNET_DIR>
<skipJavaTests>true</skipJavaTests>
</properties>

<packaging>pom</packaging>
Expand Down Expand Up @@ -228,8 +229,12 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.19</version>
<version>2.22.0</version>
<configuration>
<skipTests>${skipJavaTests}</skipTests>
<argLine>
-Djava.library.path=${project.parent.basedir}/native/target
</argLine>
<useSystemClassLoader>false</useSystemClassLoader>
</configuration>
</plugin>
Expand Down

0 comments on commit 2302420

Please sign in to comment.