Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-1285] Draw bounding box with Scala/Java Image API (#14474)
Browse files Browse the repository at this point in the history
* new feature to draw bounding box

* add Java support

* add point wise verification

* cancel the check on top-left corner

* add this example to Java world and fixing bugs
  • Loading branch information
lanking520 authored and nswamy committed Apr 5, 2019
1 parent 0ab1da2 commit f1354b4
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 34 deletions.
1 change: 1 addition & 0 deletions scala-package/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ core/src/main/scala/org/apache/mxnet/SymbolBase.scala
core/src/main/scala/org/apache/mxnet/SymbolRandomAPIBase.scala
examples/scripts/infer/images/
examples/scripts/infer/models/
examples/scripts/infer/objectdetector/boundingImage.png
54 changes: 54 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.mxnet
// scalastyle:off
import java.awt.{BasicStroke, Color, Graphics2D}
import java.awt.image.BufferedImage
// scalastyle:on
import java.io.InputStream
Expand Down Expand Up @@ -182,4 +183,57 @@ object Image {
img
}

/**
* Helper function to generate ramdom colors
* @param transparency The transparency level
* @return Color
*/
private def randomColor(transparency: Option[Float] = Some(1.0f)) : Color = {
new Color(
Math.random().toFloat, Math.random().toFloat, Math.random().toFloat,
transparency.get
)
}

/**
* Method to draw bounding boxes for an image
* @param src Source of the buffered image
* @param coordinate Contains Map of xmin, xmax, ymin, ymax
* corresponding to top-left and down-right points
* @param names The name set of the bounding box
* @param stroke Thickness of the bounding box
* @param fontSizeMult Font size multiplier
* @param transparency Transparency of the bounding box
*/
def drawBoundingBox(src: BufferedImage, coordinate: Array[Map[String, Int]],
names: Option[Array[String]] = None,
stroke : Option[Int] = Some(3),
fontSizeMult : Option[Float] = Some(1.0f),
transparency: Option[Float] = Some(1.0f)): Unit = {
val g2d : Graphics2D = src.createGraphics()
g2d.setStroke(new BasicStroke(stroke.get))
// Increase the size of font
val currentFont = g2d.getFont
val newFont = currentFont.deriveFont(currentFont.getSize * fontSizeMult.get)
g2d.setFont(newFont)
// Get font metrics to draw the font box
val fm = g2d.getFontMetrics(newFont)
for (idx <- coordinate.indices) {
val map = coordinate(idx)
g2d.setColor(randomColor(transparency).darker())
g2d.drawRect(map("xmin"), map("ymin"), map("xmax") - map("xmin"), map("ymax") - map("ymin"))
// Write the name of the bounding box
if (names.isDefined) {
val x = map("xmin") - stroke.get
val y = map("ymin")
val h = fm.getHeight
val w = fm.charsWidth(names.get(idx).toCharArray, 0, names.get(idx).length())
g2d.fillRect(x, y - h, w, h)
g2d.setColor(Color.WHITE)
g2d.drawString(names.get(idx), x, y)
}
}
g2d.dispose()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ package org.apache.mxnet.javaapi
import java.awt.image.BufferedImage
// scalastyle:on
import java.io.InputStream
import scala.collection.JavaConverters._

object Image {
/**
* Decode image with OpenCV.
* Note: return image in RGB by default, instead of OpenCV's default BGR.
* @param buf Buffer containing binary encoded image
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param buf Buffer containing binary encoded image
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image
* to mxnet's default RGB format (instead of opencv's default BGR).
* to mxnet's default RGB format (instead of opencv's default BGR).
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = {
Expand All @@ -43,8 +44,8 @@ object Image {
* Same imageDecode with InputStream
*
* @param inputStream the inputStream of the image
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(inputStream: InputStream, flag: Int, toRGB: Boolean): NDArray = {
Expand All @@ -60,7 +61,7 @@ object Image {
* Note: return image in RGB by default, instead of OpenCV's default BGR.
* @param filename Name of the image file to be loaded.
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image to mxnet's default RGB format
* @param toRGB Whether to convert decoded image to mxnet's default RGB format
* (instead of opencv's default BGR).
* @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
*/
Expand All @@ -74,10 +75,10 @@ object Image {

/**
* Resize image with OpenCV.
* @param src source image in NDArray
* @param w Width of resized image.
* @param h Height of resized image.
* @param interp Interpolation method (default=cv2.INTER_LINEAR).
* @param src source image in NDArray
* @param w Width of resized image.
* @param h Height of resized image.
* @param interp Interpolation method (default=cv2.INTER_LINEAR).
* @return org.apache.mxnet.NDArray
*/
def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = {
Expand All @@ -92,10 +93,10 @@ object Image {
/**
* Do a fixed crop on the image
* @param src Src image in NDArray
* @param x0 starting x point
* @param y0 starting y point
* @param w width of the image
* @param h height of the image
* @param x0 starting x point
* @param y0 starting y point
* @param w width of the image
* @param h height of the image
* @return cropped NDArray
*/
def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = {
Expand All @@ -111,4 +112,21 @@ object Image {
def toImage(src: NDArray): BufferedImage = {
org.apache.mxnet.Image.toImage(src)
}

/**
* Draw bounding boxes on the image
* @param src buffered image to draw on
* @param coordinate Contains Map of xmin, xmax, ymin, ymax
* corresponding to top-left and down-right points
* @param names The name set of the bounding box
*/
def drawBoundingBox(src: BufferedImage,
coordinate: java.util.List[
java.util.Map[java.lang.String, java.lang.Integer]],
names: java.util.List[java.lang.String]): Unit = {
val coord = coordinate.asScala.map(
_.asScala.map{case (name, value) => (name, Integer2int(value))}.toMap).toArray
org.apache.mxnet.Image.drawBoundingBox(src, coord, Option(names.asScala.toArray))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@
import org.apache.commons.io.FileUtils;
import org.junit.BeforeClass;
import org.junit.Test;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertArrayEquals;

Expand Down Expand Up @@ -56,12 +63,23 @@ public static void downloadFile() throws Exception {
}

@Test
public void testImageProcess() {
public void testImageProcess() throws Exception {
NDArray nd = Image.imRead(imLocation, 1, true);
assertArrayEquals(nd.shape().toArray(), new int[]{576, 1024, 3});
NDArray nd2 = Image.imResize(nd, 224, 224, null);
assertArrayEquals(nd2.shape().toArray(), new int[]{224, 224, 3});
NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224);
Image.toImage(cropped);
BufferedImage buf = ImageIO.read(new File(imLocation));
Map<String, Integer> map = new HashMap<>();
map.put("xmin", 190);
map.put("xmax", 850);
map.put("ymin", 50);
map.put("ymax", 450);
List<Map<String, Integer>> box = new ArrayList<>();
box.add(map);
List<String> names = new ArrayList<>();
names.add("pug");
Image.drawBoundingBox(buf, box, names);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,25 @@ class ImageSuite extends FunSuite with BeforeAndAfterAll {
logger.info(s"converted image stored in ${tempDirPath + "/inputImages/out.png"}")
}

test("Test draw Bounding box") {
val buf = ImageIO.read(new File(imLocation))
val box = Array(
Map("xmin" -> 190, "xmax" -> 850, "ymin" -> 50, "ymax" -> 450),
Map("xmin" -> 200, "xmax" -> 350, "ymin" -> 440, "ymax" -> 530)
)
val names = Array("pug", "cookie")
Image.drawBoundingBox(buf, box, Some(names), fontSizeMult = Some(1.4f))
val tempDirPath = System.getProperty("java.io.tmpdir")
ImageIO.write(buf, "png", new File(tempDirPath + "/inputImages/out2.png"))
logger.info(s"converted image stored in ${tempDirPath + "/inputImages/out2.png"}")
for (coord <- box) {
val topLeft = buf.getRGB(coord("xmin"), coord("ymin"))
val downLeft = buf.getRGB(coord("xmin"), coord("ymax"))
val topRight = buf.getRGB(coord("xmax"), coord("ymin"))
val downRight = buf.getRGB(coord("xmax"), coord("ymax"))
require(downLeft == downRight)
require(topRight == downRight)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ After the previous steps, you should be able to run the code using the following
From the `scala-package/examples/scripts/infer/objectdetector/` folder run:

```bash
./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
./run_ssd_java_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
```

**Notes**:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@
import org.apache.mxnet.infer.javaapi.ObjectDetector;

// scalastyle:off
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
// scalastyle:on

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.*;

import java.io.File;

Expand Down Expand Up @@ -128,22 +127,34 @@ public static void main(String[] args) {
try {
Shape inputShape = new Shape(new int[]{1, 3, 512, 512});
Shape outputShape = new Shape(new int[]{1, 6132, 6});


int width = inputShape.get(2);
int height = inputShape.get(3);

StringBuilder outputStr = new StringBuilder().append("\n");

List<List<ObjectDetectorOutput>> output
= runObjectDetectionSingle(mdprefixDir, imgPath, context);


// Creating Bounding box material
BufferedImage buf = ImageIO.read(new File(imgPath));
int width = buf.getWidth();
int height = buf.getHeight();
List<Map<String, Integer>> boxes = new ArrayList<>();
List<String> names = new ArrayList<>();
for (List<ObjectDetectorOutput> ele : output) {
for (ObjectDetectorOutput i : ele) {
outputStr.append("Class: " + i.getClassName() + "\n");
outputStr.append("Probabilties: " + i.getProbability() + "\n");

List<Float> coord = Arrays.asList(i.getXMin() * width,
i.getXMax() * height, i.getYMin() * width, i.getYMax() * height);
names.add(i.getClassName());
Map<String, Integer> map = new HashMap<>();
float xmin = i.getXMin() * width;
float xmax = i.getXMax() * width;
float ymin = i.getYMin() * height;
float ymax = i.getYMax() * height;
List<Float> coord = Arrays.asList(xmin, xmax, ymin, ymax);
map.put("xmin", (int) xmin);
map.put("xmax", (int) xmax);
map.put("ymin", (int) ymin);
map.put("ymax", (int) ymax);
boxes.add(map);
StringBuilder sb = new StringBuilder();
for (float c : coord) {
sb.append(", ").append(c);
Expand All @@ -152,7 +163,12 @@ public static void main(String[] args) {
}
}
logger.info(outputStr.toString());


// Covert to image
Image.drawBoundingBox(buf, boxes, names);
File outputFile = new File("boundingImage.png");
ImageIO.write(buf, "png", outputFile);

List<List<List<ObjectDetectorOutput>>> outputList =
runObjectDetectionBatch(mdprefixDir, imgDir, context);

Expand All @@ -177,7 +193,6 @@ public static void main(String[] args) {
}
}
logger.info(outputStr.toString());

} catch (Exception e) {
logger.error(e.getMessage(), e);
parser.printUsage(System.err);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class ObjectDetector(modelPathPrefix: String,
if (topK.isDefined) {
var sortedIndices = predictResult.zipWithIndex.sortBy(-_._1(1)).map(_._2)
sortedIndices = sortedIndices.take(topK.get)
// takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax
// takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax]
result = sortedIndices.map(idx
=> (synset(predictResult(idx)(0).toInt),
predictResult(idx).takeRight(5))).toIndexedSeq
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ class ObjectDetectorOutput (className: String, args: Array[Float]){
*
* @return Float of the max X coordinate for the object bounding box
*/
def getXMax: Float = args(2)
def getXMax: Float = args(3)

/**
* Gets the minimum Y coordinate for the bounding box containing the predicted object.
*
* @return Float of the min Y coordinate for the object bounding box
*/
def getYMin: Float = args(3)
def getYMin: Float = args(2)

/**
* Gets the maximum Y coordinate for the bounding box containing the predicted object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ public void testConstructor() {
Assert.assertEquals(odOutput.getClassName(), predictedClassName);
Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 2f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 3f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 3f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 2f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMax(), 4f, delta);

}
Expand Down

0 comments on commit f1354b4

Please sign in to comment.