Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add this example to Java world and fixing bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Mar 22, 2019
1 parent ff60626 commit 3ea2c7d
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 19 deletions.
1 change: 1 addition & 0 deletions scala-package/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ core/src/main/scala/org/apache/mxnet/SymbolBase.scala
core/src/main/scala/org/apache/mxnet/SymbolRandomAPIBase.scala
examples/scripts/infer/images/
examples/scripts/infer/models/
examples/scripts/infer/objectdetector/boundingImage.png
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ After the previous steps, you should be able to run the code using the following
From the `scala-package/examples/scripts/infer/objectdetector/` folder run:

```bash
./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
./run_ssd_java_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
```

**Notes**:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@
import org.apache.mxnet.infer.javaapi.ObjectDetector;

// scalastyle:off
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
// scalastyle:on

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.*;

import java.io.File;

Expand Down Expand Up @@ -128,22 +127,34 @@ public static void main(String[] args) {
try {
Shape inputShape = new Shape(new int[]{1, 3, 512, 512});
Shape outputShape = new Shape(new int[]{1, 6132, 6});


int width = inputShape.get(2);
int height = inputShape.get(3);

StringBuilder outputStr = new StringBuilder().append("\n");

List<List<ObjectDetectorOutput>> output
= runObjectDetectionSingle(mdprefixDir, imgPath, context);


// Creating Bounding box material
BufferedImage buf = ImageIO.read(new File(imgPath));
int width = buf.getWidth();
int height = buf.getHeight();
List<Map<String, Integer>> boxes = new ArrayList<>();
List<String> names = new ArrayList<>();
for (List<ObjectDetectorOutput> ele : output) {
for (ObjectDetectorOutput i : ele) {
outputStr.append("Class: " + i.getClassName() + "\n");
outputStr.append("Probabilties: " + i.getProbability() + "\n");

List<Float> coord = Arrays.asList(i.getXMin() * width,
i.getXMax() * height, i.getYMin() * width, i.getYMax() * height);
names.add(i.getClassName());
Map<String, Integer> map = new HashMap<>();
float xmin = i.getXMin() * width;
float xmax = i.getXMax() * width;
float ymin = i.getYMin() * height;
float ymax = i.getYMax() * height;
List<Float> coord = Arrays.asList(xmin, xmax, ymin, ymax);
map.put("xmin", (int) xmin);
map.put("xmax", (int) xmax);
map.put("ymin", (int) ymin);
map.put("ymax", (int) ymax);
boxes.add(map);
StringBuilder sb = new StringBuilder();
for (float c : coord) {
sb.append(", ").append(c);
Expand All @@ -152,7 +163,12 @@ public static void main(String[] args) {
}
}
logger.info(outputStr.toString());


// Covert to image
Image.drawBoundingBox(buf, boxes, names);
File outputFile = new File("boundingImage.png");
ImageIO.write(buf, "png", outputFile);

List<List<List<ObjectDetectorOutput>>> outputList =
runObjectDetectionBatch(mdprefixDir, imgDir, context);

Expand All @@ -177,7 +193,6 @@ public static void main(String[] args) {
}
}
logger.info(outputStr.toString());

} catch (Exception e) {
logger.error(e.getMessage(), e);
parser.printUsage(System.err);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class ObjectDetector(modelPathPrefix: String,
if (topK.isDefined) {
var sortedIndices = predictResult.zipWithIndex.sortBy(-_._1(1)).map(_._2)
sortedIndices = sortedIndices.take(topK.get)
// takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax
// takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax]
result = sortedIndices.map(idx
=> (synset(predictResult(idx)(0).toInt),
predictResult(idx).takeRight(5))).toIndexedSeq
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ class ObjectDetectorOutput (className: String, args: Array[Float]){
*
* @return Float of the max X coordinate for the object bounding box
*/
def getXMax: Float = args(2)
def getXMax: Float = args(3)

/**
* Gets the minimum Y coordinate for the bounding box containing the predicted object.
*
* @return Float of the min Y coordinate for the object bounding box
*/
def getYMin: Float = args(3)
def getYMin: Float = args(2)

/**
* Gets the maximum Y coordinate for the bounding box containing the predicted object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ public void testConstructor() {
Assert.assertEquals(odOutput.getClassName(), predictedClassName);
Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 2f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 3f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 3f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 2f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMax(), 4f, delta);

}
Expand Down

0 comments on commit 3ea2c7d

Please sign in to comment.