Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add java example
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jan 8, 2019
1 parent 7b7fd15 commit 0a3bc8d
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 75 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.javaapi
// scalastyle:off
import java.awt.image.BufferedImage
// scalastyle:on
import java.io.InputStream

object Image {
/**
* Decode image with OpenCV.
* Note: return image in RGB by default, instead of OpenCV's default BGR.
* @param buf Buffer containing binary encoded image
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image
* to mxnet's default RGB format (instead of opencv's default BGR).
* @return NDArray in HWC format
*/
def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean, out: NDArray): NDArray = {
org.apache.mxnet.Image.imDecode(buf, flag, toRGB, Some(out))
}

/**
* Same imageDecode with InputStream
* @param inputStream the inputStream of the image
* @return NDArray in HWC format
*/
def imDecode(inputStream: InputStream, flag: Int = 1, toRGB: Boolean = true,
out: NDArray): NDArray = {
org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, Some(out))
}

/**
* Read and decode image with OpenCV.
* Note: return image in RGB by default, instead of OpenCV's default BGR.
* @param filename Name of the image file to be loaded.
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image to mxnet's default RGB format
* (instead of opencv's default BGR).
* @return org.apache.mxnet.NDArray in HWC format
*/
def imRead(filename: String, flag: Int, toRGB: Boolean = true, out: NDArray): NDArray = {
org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), Some(out))
}

/**
* Resize image with OpenCV.
* @param src source image in NDArray
* @param w Width of resized image.
* @param h Height of resized image.
* @param interp Interpolation method (default=cv2.INTER_LINEAR).
* @return org.apache.mxnet.NDArray
*/
def imResize(src: NDArray, w: Int, h: Int,
interp: Integer, out: NDArray): NDArray = {
org.apache.mxnet.Image.imResize(src, w, h, Some(interp), Some(out))
}

/**
* Do a fixed crop on the image
* @param src Src image in NDArray
* @param x0 starting x point
* @param y0 starting y point
* @param w width of the image
* @param h height of the image
* @return cropped NDArray
*/
def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = {
org.apache.mxnet.Image.fixedCrop(src, x0, y0, w, h)
}

/**
* Convert a NDArray image to a real image
* The time cost will increase if the image resolution is big
* @param src Source image file in RGB
* @return Buffered Image
*/
def toImage(src: NDArray): BufferedImage = {
org.apache.mxnet.Image.toImage(src)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.imageio.ImageIO;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
Expand All @@ -48,76 +46,6 @@ public class PredictorExample {

final static Logger logger = LoggerFactory.getLogger(PredictorExample.class);

/**
* Load the image from file to buffered image
* It can be replaced by loadImageFromFile from ObjectDetector
* @param inputImagePath input image Path in String
* @return Buffered image
*/
private static BufferedImage loadIamgeFromFile(String inputImagePath) {
BufferedImage buf = null;
try {
buf = ImageIO.read(new File(inputImagePath));
} catch (IOException e) {
System.err.println(e);
}
return buf;
}

/**
* Reshape the current image using ImageIO and Graph2D
* It can be replaced by reshapeImage from ObjectDetector
* @param buf Buffered image
* @param newWidth desired width
* @param newHeight desired height
* @return a reshaped bufferedImage
*/
private static BufferedImage reshapeImage(BufferedImage buf, int newWidth, int newHeight) {
BufferedImage resizedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB);
Graphics2D g = resizedImage.createGraphics();
g.drawImage(buf, 0, 0, newWidth, newHeight, null);
g.dispose();
return resizedImage;
}

/**
* Convert an image from a buffered image into pixels float array
* It can be replaced by bufferedImageToPixels from ObjectDetector
* @param buf buffered image
* @return Float array
*/
private static float[] imagePreprocess(BufferedImage buf) {
// Get height and width of the image
int w = buf.getWidth();
int h = buf.getHeight();

// get an array of integer pixels in the default RGB color mode
int[] pixels = buf.getRGB(0, 0, w, h, null, 0, w);

// 3 times height and width for R,G,B channels
float[] result = new float[3 * h * w];

int row = 0;
// copy pixels to array vertically
while (row < h) {
int col = 0;
// copy pixels to array horizontally
while (col < w) {
int rgb = pixels[row * w + col];
// getting red color
result[0 * h * w + row * w + col] = (rgb >> 16) & 0xFF;
// getting green color
result[1 * h * w + row * w + col] = (rgb >> 8) & 0xFF;
// getting blue color
result[2 * h * w + row * w + col] = rgb & 0xFF;
col += 1;
}
row += 1;
}
buf.flush();
return result;
}

/**
* Helper class to print the maximum prediction result
* @param probabilities The float array of probability
Expand Down Expand Up @@ -170,9 +98,8 @@ public static void main(String[] args) {
inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0);
// Prepare data
BufferedImage img = loadIamgeFromFile(inst.inputImagePath);

img = reshapeImage(img, 224, 224);
NDArray img = Image.imRead(inst.inputImagePath, 1, true, null);
img = Image.imResize(img, 224, 224, null, null);
// predict
float[][] result = predictor.predict(new float[][]{imagePreprocess(img)});
try {
Expand Down

0 comments on commit 0a3bc8d

Please sign in to comment.