Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add test and change PredictorExample
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jan 8, 2019
1 parent 0a3bc8d commit d286beb
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package org.apache.mxnet.javaapi;

import org.apache.commons.io.FileUtils;
import org.junit.BeforeClass;
import org.junit.Test;
import java.io.File;
import java.net.URL;

public class ImageTest {

private String imLocation;

private void downloadUrl(String url, String filePath, int maxRetry) throws Exception{
File tmpFile = new File(filePath);
Boolean success = false;
if (!tmpFile.exists()) {
while (maxRetry > 0 && !success) {
try {
FileUtils.copyURLToFile(new URL(url), tmpFile);
success = true;
} catch(Exception e){
maxRetry -= 1;
}
}
} else {
success = true;
}
if (!success) throw new Exception("$url Download failed!");
}

@BeforeClass
public void downloadFile() throws Exception {
String tempDirPath = System.getProperty("java.io.tmpdir");
imLocation = tempDirPath + "/inputImages/Pug-Cookie.jpg";
try {
downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
imLocation, 3);
} catch (Exception e) {
throw e;
}
}

@Test
public void testImageProcess() {
NDArray nd = Image.imRead(imLocation, 1, true, null);
NDArray nd2 = Image.imResize(nd, 224, 224, null, null);
NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224);
Image.toImage(cropped);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,15 @@ public static void main(String[] args) {
NDArray img = Image.imRead(inst.inputImagePath, 1, true, null);
img = Image.imResize(img, 224, 224, null, null);
// predict
float[][] result = predictor.predict(new float[][]{imagePreprocess(img)});
float[][] result = predictor.predict(new float[][]{img.toArray()});
try {
System.out.println("Predict with Float input");
System.out.println(printMaximumClass(result[0], inst.modelPathPrefix));
} catch (IOException e) {
System.err.println(e);
}
// predict with NDArray
NDArray nd = new NDArray(
imagePreprocess(img),
new Shape(new int[]{1, 3, 224, 224}),
Context.cpu());
NDArray nd = img;
List<NDArray> ndList = new ArrayList<>();
ndList.add(nd);
List<NDArray> ndResult = predictor.predictWithNDArray(ndList);
Expand Down

0 comments on commit d286beb

Please sign in to comment.