Skip to content

Commit

Permalink
MXNet Java bug fixes and experience improvement (apache#14213)
Browse files Browse the repository at this point in the history
* improve Java user experience

* add the new examples

* fixed based on the comments
  • Loading branch information
lanking520 committed Apr 26, 2019
1 parent 915a4ff commit 532f3e2
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
this(NDArray.array(arr, shape, ctx))
}

override def toString: String = nd.toString

def serialize(): Array[Byte] = nd.serialize()

/**
Expand Down
22 changes: 14 additions & 8 deletions scala-package/mxnet-demo/java-demo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,17 @@
<!--- under the License. -->

# MXNet Java Sample Project
This is an project created to use Maven-published Scala/Java package with two Java examples.
This is a project demonstrating how to use the Maven published Scala/Java MXNet package.
The examples provided include:
* NDArray creation
* NDArray operation
* Object Detection using the Inference API
* Image Classification using the Predictor API

## Setup
You can use the `Makefile` to make the Java package. Simply do the following:
```Bash
make javademo
You are required to use Maven to build the package with the following commands under `java-demo`:
```
mvn package
```
This will load the default parameter for all the environment variable.
If you want to run with GPU on Linux, just simply add `USE_CUDA=1` when you run the make file
Expand All @@ -41,16 +47,16 @@ The `SCALA_PKG_PROFILE` should be chosen from `osx-x86_64-cpu`, `linux-x86_64-cp


## Run
### Hello World
The Scala file is being executed using Java. You can execute the helloWorld example as follows:
### NDArrayCreation
The Scala file is being executed using Java. You can execute the `NDArrayCreation` example as follows:
```Bash
bash bin/java_sample.sh
```
You can also run the following command manually:
```Bash
java -cp $CLASSPATH sample.HelloWorld
java -cp $CLASSPATH sample.NDArrayCreation
```
However, you have to define the Classpath before you run the demo code. More information can be found in the `java_sample.sh`.
However, you have to define the Classpath before you run the demo code. More information can be found in `bin/java_sample.sh`.
The `CLASSPATH` should point to the jar file you have downloaded.

It will load the library automatically and run the example
Expand Down
4 changes: 2 additions & 2 deletions scala-package/mxnet-demo/java-demo/bin/java_sample.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
# under the License.
#!/bin/bash
CURR_DIR=$(cd $(dirname $0)/../; pwd)
CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/target/classes/lib/*
java -Xmx8G -cp $CLASSPATH mxnet.HelloWorld
CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/target/dependency/*
java -Xmx8G -cp $CLASSPATH mxnet.NDArrayCreation
4 changes: 2 additions & 2 deletions scala-package/mxnet-demo/java-demo/bin/run_od.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
# under the License.
#!/bin/bash
CURR_DIR=$(cd $(dirname $0)/../; pwd)
CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/target/classes/lib/*
java -Xmx8G -cp $CLASSPATH mxnet.ObjectDetection
CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/target/dependency/*
java -Xmx8G -cp $CLASSPATH mxnet.ObjectDetection
6 changes: 6 additions & 0 deletions scala-package/mxnet-demo/java-demo/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@
<artifactId>mxnet-full_${mxnet.scalaprofile}-${mxnet.profile}</artifactId>
<version>${mxnet.version}</version>
</dependency>
<dependency>
<groupId>org.apache.mxnet</groupId>
<artifactId>mxnet-full_${mxnet.scalaprofile}-${mxnet.profile}</artifactId>
<version>${mxnet.version}</version>
<classifier>sources</classifier>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package mxnet;

import org.apache.commons.io.FileUtils;
import org.apache.mxnet.infer.javaapi.Predictor;
import org.apache.mxnet.javaapi.*;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;

public class ImageClassification {
private static String modelPath;
private static String imagePath;

private static void downloadUrl(String url, String filePath) {
File tmpFile = new File(filePath);
if (!tmpFile.exists()) {
try {
FileUtils.copyURLToFile(new URL(url), tmpFile);
} catch (Exception exception) {
System.err.println(exception);
}
}
}

public static void downloadModelImage() {
String tempDirPath = System.getProperty("java.io.tmpdir");
String baseUrl = "https://s3.us-east-2.amazonaws.com/scala-infer-models";
downloadUrl(baseUrl + "/resnet-18/resnet-18-symbol.json",
tempDirPath + "/resnet18/resnet-18-symbol.json");
downloadUrl(baseUrl + "/resnet-18/resnet-18-0000.params",
tempDirPath + "/resnet18/resnet-18-0000.params");
downloadUrl(baseUrl + "/resnet-18/synset.txt",
tempDirPath + "/resnet18/synset.txt");
downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg");
modelPath = tempDirPath + File.separator + "resnet18/resnet-18";
imagePath = tempDirPath + File.separator +
"inputImages/resnet18/Pug-Cookie.jpg";
}

/**
* Helper class to print the maximum prediction result
* @param probabilities The float array of probability
* @param modelPathPrefix model Path needs to load the synset.txt
*/
private static String printMaximumClass(float[] probabilities,
String modelPathPrefix) throws IOException {
String synsetFilePath = modelPathPrefix.substring(0,
1 + modelPathPrefix.lastIndexOf(File.separator)) + "/synset.txt";
BufferedReader reader = new BufferedReader(new FileReader(synsetFilePath));
ArrayList<String> list = new ArrayList<>();
String line = reader.readLine();

while (line != null){
list.add(line);
line = reader.readLine();
}
reader.close();

int maxIdx = 0;
for (int i = 1;i<probabilities.length;i++) {
if (probabilities[i] > probabilities[maxIdx]) {
maxIdx = i;
}
}

return "Probability : " + probabilities[maxIdx] + " Class : " + list.get(maxIdx) ;
}

public static void main(String[] args) {
// Download the model and Image
downloadModelImage();

// Prepare the model
List<Context> context = new ArrayList<Context>();
context.add(Context.cpu());
List<DataDesc> inputDesc = new ArrayList<>();
Shape inputShape = new Shape(new int[]{1, 3, 224, 224});
inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
Predictor predictor = new Predictor(modelPath, inputDesc, context,0);

// Prepare data
NDArray nd = Image.imRead(imagePath, 1, true);
nd = Image.imResize(nd, 224, 224, null);
nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0]; // HWC to CHW
nd = NDArray.expand_dims(nd, 0, null)[0]; // Add N -> NCHW
nd = nd.asType(DType.Float32()); // Inference with Float32

// Predict directly
float[][] result = predictor.predict(new float[][]{nd.toArray()});
try {
System.out.println("Predict with Float input");
System.out.println(printMaximumClass(result[0], modelPath));
} catch (IOException e) {
System.err.println(e);
}

// predict with NDArray
List<NDArray> ndList = new ArrayList<>();
ndList.add(nd);
List<NDArray> ndResult = predictor.predictWithNDArray(ndList);
try {
System.out.println("Predict with NDArray");
System.out.println(printMaximumClass(ndResult.get(0).toArray(), modelPath));
} catch (IOException e) {
System.err.println(e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package mxnet;

import org.apache.mxnet.javaapi.*;

public class NDArrayCreation {
static NDArray$ NDArray = NDArray$.MODULE$;
public static void main(String[] args) {

// Create new NDArray
NDArray nd = new NDArray(new float[]{2.0f, 3.0f}, new Shape(new int[]{1, 2}), Context.cpu());
System.out.println(nd);

// create new Double NDArray
NDArray ndDouble = new NDArray(new double[]{2.0d, 3.0d}, new Shape(new int[]{2, 1}), Context.cpu());
System.out.println(ndDouble);

// create ones
NDArray ones = NDArray.ones(Context.cpu(), new int[] {1, 2, 3});
System.out.println(ones);

// random
NDArray random = NDArray.random_uniform(
NDArray.new random_uniformParam()
.setLow(0.0f)
.setHigh(2.0f)
.setShape(new Shape(new int[]{10, 10}))
)[0];
System.out.println(random);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,31 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package mxnet;

import org.apache.mxnet.javaapi.*;
import java.util.Arrays;

public class HelloWorld {
public class NDArrayOperation {
static NDArray$ NDArray = NDArray$.MODULE$;
public static void main(String[] args) {
System.out.println("Hello World!");
NDArray nd = new NDArray(new float[]{2.0f, 3.0f}, new Shape(new int[]{1, 2}), Context.cpu());
System.out.println(nd.shape());

// Transpose
NDArray ndT = nd.T();
System.out.println(nd);
System.out.println(ndT);

// change Data Type
NDArray ndInt = nd.asType(DType.Int32());
System.out.println(ndInt);

// element add
NDArray eleAdd = NDArray.elemwise_add(nd, nd, null)[0];
System.out.println(eleAdd);

// norm (L2 Norm)
NDArray normed = NDArray.norm(NDArray.new normParam(nd))[0];
System.out.println(normed);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,18 @@ public static void downloadModelImage() {

public static void main(String[] args) {
List<Context> context = new ArrayList<Context>();
if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
context.add(Context.gpu());
} else {
context.add(Context.cpu());
}
context.add(Context.cpu());
downloadModelImage();

List<List<ObjectDetectorOutput>> output
= runObjectDetectionSingle(modelPath, imagePath, context);

Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
Shape outputShape = new Shape(new int[] {1, 6132, 6});
int width = inputShape.get(2);
int height = inputShape.get(3);
List<List<ObjectDetectorOutput>> output
= runObjectDetectionSingle(modelPath, imagePath, context);
String outputStr = "\n";

for (List<ObjectDetectorOutput> ele : output) {
for (ObjectDetectorOutput i : ele) {
outputStr += "Class: " + i.getClassName() + "\n";
Expand All @@ -98,4 +96,4 @@ public static void main(String[] args) {
}
System.out.println(outputStr);
}
}
}

0 comments on commit 532f3e2

Please sign in to comment.