diff --git a/scala-package/.gitignore b/scala-package/.gitignore index 9bf7851716d6..dadc000c612e 100644 --- a/scala-package/.gitignore +++ b/scala-package/.gitignore @@ -9,3 +9,4 @@ core/src/main/scala/org/apache/mxnet/SymbolBase.scala core/src/main/scala/org/apache/mxnet/SymbolRandomAPIBase.scala examples/scripts/infer/images/ examples/scripts/infer/models/ +examples/scripts/infer/objectdetector/boundingImage.png diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md index 8a9ed3e1736b..4c4512f152c8 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md @@ -84,7 +84,7 @@ After the previous steps, you should be able to run the code using the following From the `scala-package/examples/scripts/infer/objectdetector/` folder run: ```bash -./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images +./run_ssd_java_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images ``` **Notes**: diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java index a9c00f7f1d81..31b8514de345 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java @@ -28,12 +28,11 @@ import org.apache.mxnet.infer.javaapi.ObjectDetector; // scalastyle:off +import javax.imageio.ImageIO; import java.awt.image.BufferedImage; // scalastyle:on -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import java.util.*; import java.io.File; @@ -128,22 +127,34 @@ public static void main(String[] args) { try { Shape inputShape = new Shape(new int[]{1, 3, 512, 512}); Shape outputShape = new Shape(new int[]{1, 6132, 6}); - - - int width = inputShape.get(2); - int height = inputShape.get(3); + StringBuilder outputStr = new StringBuilder().append("\n"); List> output = runObjectDetectionSingle(mdprefixDir, imgPath, context); - + + // Creating Bounding box material + BufferedImage buf = ImageIO.read(new File(imgPath)); + int width = buf.getWidth(); + int height = buf.getHeight(); + List> boxes = new ArrayList<>(); + List names = new ArrayList<>(); for (List ele : output) { for (ObjectDetectorOutput i : ele) { outputStr.append("Class: " + i.getClassName() + "\n"); outputStr.append("Probabilties: " + i.getProbability() + "\n"); - - List coord = Arrays.asList(i.getXMin() * width, - i.getXMax() * height, i.getYMin() * width, i.getYMax() * height); + names.add(i.getClassName()); + Map map = new HashMap<>(); + float xmin = i.getXMin() * width; + float xmax = i.getXMax() * width; + float ymin = i.getYMin() * height; + float ymax = i.getYMax() * height; + List coord = Arrays.asList(xmin, xmax, ymin, ymax); + map.put("xmin", (int) xmin); + map.put("xmax", (int) xmax); + map.put("ymin", (int) ymin); + map.put("ymax", (int) ymax); + boxes.add(map); StringBuilder sb = new StringBuilder(); for (float c : coord) { sb.append(", ").append(c); @@ -152,7 +163,12 @@ public static void main(String[] args) { } } logger.info(outputStr.toString()); - + + // Covert to image + Image.drawBoundingBox(buf, boxes, names); + File outputFile = new File("boundingImage.png"); + ImageIO.write(buf, "png", outputFile); + List>> outputList = runObjectDetectionBatch(mdprefixDir, imgDir, context); @@ -177,7 +193,6 @@ public static void main(String[] args) { } } logger.info(outputStr.toString()); - } catch (Exception e) { logger.error(e.getMessage(), e); parser.printUsage(System.err); diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala index e29f068d5558..28a578cae79f 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala @@ -132,7 +132,7 @@ class ObjectDetector(modelPathPrefix: String, if (topK.isDefined) { var sortedIndices = predictResult.zipWithIndex.sortBy(-_._1(1)).map(_._2) sortedIndices = sortedIndices.take(topK.get) - // takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax + // takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax] result = sortedIndices.map(idx => (synset(predictResult(idx)(0).toInt), predictResult(idx).takeRight(5))).toIndexedSeq diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala index 5a6ac7599fa9..32fd87e05f69 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala @@ -52,14 +52,14 @@ class ObjectDetectorOutput (className: String, args: Array[Float]){ * * @return Float of the max X coordinate for the object bounding box */ - def getXMax: Float = args(2) + def getXMax: Float = args(3) /** * Gets the minimum Y coordinate for the bounding box containing the predicted object. * * @return Float of the min Y coordinate for the object bounding box */ - def getYMin: Float = args(3) + def getYMin: Float = args(2) /** * Gets the maximum Y coordinate for the bounding box containing the predicted object. diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java index 04041fcda9bf..6f3df86b8e74 100644 --- a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java +++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java @@ -36,8 +36,8 @@ public void testConstructor() { Assert.assertEquals(odOutput.getClassName(), predictedClassName); Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta); Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta); - Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 2f, delta); - Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 3f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 3f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 2f, delta); Assert.assertEquals("Threshold not matching", odOutput.getYMax(), 4f, delta); }