Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fixes #14181, validate model output shape for ObjectDetector. #14215

Merged
merged 1 commit into from
Mar 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class ImageClassifier(modelPathPrefix: String,
protected[infer] val height = inputShape(inputLayout.indexOf('H'))
protected[infer] val width = inputShape(inputLayout.indexOf('W'))

def outputShapes: IndexedSeq[(String, Shape)] = predictor.outputShapes

/**
* To classify the image according to the provided model
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ package org.apache.mxnet.infer
// scalastyle:off
import java.awt.image.BufferedImage

import org.apache.mxnet.Shape

import scala.collection.parallel.mutable.ParArray
// scalastyle:on
import org.apache.mxnet.NDArray
import org.apache.mxnet.DataDesc
import org.apache.mxnet.Context
import scala.collection.mutable.ListBuffer

/**
* The ObjectDetector class helps to run ObjectDetection tasks where the goal
Expand Down Expand Up @@ -174,7 +175,25 @@ class ObjectDetector(modelPathPrefix: String,
contexts: Array[Context] = Context.cpu(),
epoch: Option[Int] = Some(0)):
ImageClassifier = {
new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)
}
val imageClassifier: ImageClassifier =
new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)

val shapes: IndexedSeq[(String, Shape)] = imageClassifier.outputShapes
if (shapes.length != inputDescriptors.length) {
throw new IllegalStateException(s"Invalid output shapes, expected:" +
s" $inputDescriptors.length, actual: $shapes.length.")
}
shapes.map(_._2).foreach(shape => {
if (shape.length < 3) {
throw new IllegalArgumentException("Invalid output shapes, the model doesn't"
+ " support object detection.")
}
if (shape.get(2) < 6) {
throw new IllegalArgumentException("Invalid output shapes, the model doesn't"
+ " support object detection with bounding box.")
}
})

imageClassifier
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ private[infer] trait PredictBase {
*/
def predictWithNDArray(input: IndexedSeq[NDArray]): IndexedSeq[NDArray]

/**
* Get model output shapes.
* @return model output shapes.
*/
def outputShapes: IndexedSeq[(String, Shape)]
}

/**
Expand Down Expand Up @@ -122,6 +127,8 @@ class Predictor(modelPathPrefix: String,

protected[infer] val mod = loadModule()

override def outputShapes: IndexedSeq[(String, Shape)] = mod.outputShapes

/**
* Takes input as IndexedSeq one dimensional arrays and creates the NDArray needed for inference
* The array will be reshaped based on the input descriptors.
Expand Down