From 5cdff468e0a953a0bf618d05b4358733d516fc8a Mon Sep 17 00:00:00 2001 From: Piyush Ghai Date: Tue, 8 Jan 2019 13:42:58 -0800 Subject: [PATCH] Addressed PR comments --- .../javaapi/ObjectDetectorOutputTest.java | 29 ++++++--- .../infer/javaapi/ObjectDetectorTest.java | 63 ++++++++++++------- .../mxnet/infer/javaapi/PredictorTest.java | 19 +----- 3 files changed, 62 insertions(+), 49 deletions(-) diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java index 4087f3b1bdaa..04041fcda9bf 100644 --- a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java +++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java @@ -17,23 +17,29 @@ package org.apache.mxnet.infer.javaapi; +import org.junit.Assert; import org.junit.Test; public class ObjectDetectorOutputTest { + private String predictedClassName = "lion"; + + private float delta = 0.00001f; + @Test public void testConstructor() { float[] arr = new float[]{0f, 1f, 2f, 3f, 4f}; - ObjectDetectorOutput odOutput = new ObjectDetectorOutput("simba", arr); + ObjectDetectorOutput odOutput = new ObjectDetectorOutput(predictedClassName, arr); + + Assert.assertEquals(odOutput.getClassName(), predictedClassName); + Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 2f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 3f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getYMax(), 4f, delta); - assert (odOutput.getClassName().equals("simba")); - assert (odOutput.getProbability() == 0); - assert (odOutput.getXMin() == 1); - assert (odOutput.getXMax() == 2); - assert (odOutput.getYMin() == 3); - assert (odOutput.getYMax() == 4); } @Test (expected = ArrayIndexOutOfBoundsException.class) @@ -41,8 +47,13 @@ public void testIncompleteArgsConstructor() { float[] arr = new float[]{0f, 1f}; - ObjectDetectorOutput odOutput = new ObjectDetectorOutput("simba", arr); + ObjectDetectorOutput odOutput = new ObjectDetectorOutput(predictedClassName, arr); + + Assert.assertEquals(odOutput.getClassName(), predictedClassName); + Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta); - odOutput.getYMax(); + // This is where exception will be thrown + odOutput.getXMax(); } } diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java index 789d47f42219..a5e64911d141 100644 --- a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java +++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java @@ -33,32 +33,51 @@ public class ObjectDetectorTest { - List inputDesc; - BufferedImage inputImage; + private List inputDesc; + private BufferedImage inputImage; - List> result; + private List> expectedResult; - ObjectDetector objectDetector; + private ObjectDetector objectDetector; + + private int batchSize = 1; + + private int channels = 3; + + private int imageHeight = 512; + + private int imageWidth = 512; + + private String dataName = "data"; + + private int topK = 5; + + private String predictedClassName = "lion"; // Random string + + private Shape getTestShape() { + + return new Shape(new int[] {batchSize, channels, imageHeight, imageWidth}); + } @Before public void setUp() { inputDesc = new ArrayList<>(); - inputDesc.add(new DataDesc("", new Shape(new int[]{1, 3, 512, 512}), DType.Float32(), Layout.NCHW())); - inputImage = new BufferedImage(512, 512, BufferedImage.TYPE_INT_RGB); + inputDesc.add(new DataDesc(dataName, getTestShape(), DType.Float32(), Layout.NCHW())); + inputImage = new BufferedImage(imageWidth, imageHeight, BufferedImage.TYPE_INT_RGB); objectDetector = Mockito.mock(ObjectDetector.class); - result = new ArrayList<>(); - result.add(new ArrayList()); - result.get(0).add(new ObjectDetectorOutput("simbaa", new float[]{})); + expectedResult = new ArrayList<>(); + expectedResult.add(new ArrayList()); + expectedResult.get(0).add(new ObjectDetectorOutput(predictedClassName, new float[]{})); } @Test public void testObjectDetectorWithInputImage() { - Mockito.when(objectDetector.imageObjectDetect(inputImage, 5)).thenReturn(result); - List> actualResult = objectDetector.imageObjectDetect(inputImage, 5); - Mockito.verify(objectDetector, Mockito.times(1)).imageObjectDetect(inputImage, 5); - Assert.assertEquals(result, actualResult); + Mockito.when(objectDetector.imageObjectDetect(inputImage, topK)).thenReturn(expectedResult); + List> actualResult = objectDetector.imageObjectDetect(inputImage, topK); + Mockito.verify(objectDetector, Mockito.times(1)).imageObjectDetect(inputImage, topK); + Assert.assertEquals(expectedResult, actualResult); } @@ -67,21 +86,21 @@ public void testObjectDetectorWithBatchImage() { List batchImage = new ArrayList<>(); batchImage.add(inputImage); - Mockito.when(objectDetector.imageBatchObjectDetect(batchImage, 5)).thenReturn(result); - List> actualResult = objectDetector.imageBatchObjectDetect(batchImage, 5); - Mockito.verify(objectDetector, Mockito.times(1)).imageBatchObjectDetect(batchImage, 5); - Assert.assertEquals(result, actualResult); + Mockito.when(objectDetector.imageBatchObjectDetect(batchImage, topK)).thenReturn(expectedResult); + List> actualResult = objectDetector.imageBatchObjectDetect(batchImage, topK); + Mockito.verify(objectDetector, Mockito.times(1)).imageBatchObjectDetect(batchImage, topK); + Assert.assertEquals(expectedResult, actualResult); } @Test public void testObjectDetectorWithNDArrayInput() { - NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, new Shape(new int[] {1, 3, 512, 512})); + NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, getTestShape()); List inputL = new ArrayList<>(); inputL.add(inputArr); - Mockito.when(objectDetector.objectDetectWithNDArray(inputL, 5)).thenReturn(result); - List> actualResult = objectDetector.objectDetectWithNDArray(inputL, 5); - Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, 5); - Assert.assertEquals(result, actualResult); + Mockito.when(objectDetector.objectDetectWithNDArray(inputL, 5)).thenReturn(expectedResult); + List> actualResult = objectDetector.objectDetectWithNDArray(inputL, topK); + Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK); + Assert.assertEquals(expectedResult, actualResult); } } diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java index cebd7ced7196..e7a6c9652346 100644 --- a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java +++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java @@ -39,7 +39,7 @@ public void setUp() { } @Test - public void testPredictWithFloatArry() { + public void testPredictWithFloatArray() { float tmp[][] = new float[1][224]; for (int x = 0; x < 1; x++) { @@ -55,23 +55,6 @@ public void testPredictWithFloatArry() { Assert.assertArrayEquals(expectedResult, actualResult); } - @Test - public void testPredictWithDoubleArry() { - - double tmp[][] = new double[1][224]; - for (int x = 0; x < 1; x++) { - for (int y = 0; y < 224; y++) - tmp[x][y] = (int) (Math.random() * 10); - } - - double [][] expectedResult = new double[][] {{1d, 2d}}; - Mockito.when(mockPredictor.predict(tmp)).thenReturn(expectedResult); - double[][] actualResult = mockPredictor.predict(tmp); - - Mockito.verify(mockPredictor, Mockito.times(1)).predict(tmp); - Assert.assertArrayEquals(expectedResult, actualResult); - } - @Test public void testPredictWithNDArray() {