diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala index d4bce9f0d71e..54a08a81f7ff 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala @@ -18,7 +18,7 @@ package org.apache.mxnet.infer import org.apache.mxnet.io.NDArrayIter -import org.apache.mxnet.{Context, DataDesc, NDArray, Shape} +import org.apache.mxnet._ import org.apache.mxnet.module.Module import scala.collection.mutable.ListBuffer @@ -36,11 +36,13 @@ private[infer] trait PredictBase { *
* This method will take input as IndexedSeq one dimensional arrays and creates the * NDArray needed for inference. The array will be reshaped based on the input descriptors. - * @param input: An IndexedSequence of a one-dimensional array. + * @param input: An IndexedSequence of a one-dimensional array of data type + * Float or Double. An IndexedSequence is needed when the model has more than one input. * @return Indexed sequence array of outputs */ - def predict(input: IndexedSeq[Array[Float]]): IndexedSeq[Array[Float]] + def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]]) + : IndexedSeq[Array[T]] /** * Predict using NDArray as input. @@ -123,13 +125,13 @@ class Predictor(modelPathPrefix: String, * Takes input as IndexedSeq one dimensional arrays and creates the NDArray needed for inference * The array will be reshaped based on the input descriptors. * - * @param input: An IndexedSequence of a one-dimensional array. + * @param input: An IndexedSequence of a one-dimensional array + * of data type Float or Double. An IndexedSequence is needed when the model has more than one input. 
* @return Indexed sequence array of outputs */ - override def predict(input: IndexedSeq[Array[Float]]) - : IndexedSeq[Array[Float]] = { - + override def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]]) + : IndexedSeq[Array[T]] = { require(input.length == inputDescriptors.length, s"number of inputs provided: ${input.length} does not match number of inputs " + s"in inputDescriptors: ${inputDescriptors.length}") @@ -139,6 +141,19 @@ class Predictor(modelPathPrefix: String, s"number of elements:${i.length} in the input does not match the shape:" + s"${d.shape.toString()}") } + + // Infer the dtype of input and call relevant method + val result = input(0)(0) match { + case d: Double => predictWithDoubleImpl(input.asInstanceOf[IndexedSeq[Array[Double]]]) + case _ => predictWithFloatImpl(input.asInstanceOf[IndexedSeq[Array[Float]]]) + } + + result.asInstanceOf[IndexedSeq[Array[T]]] + } + + private def predictWithFloatImpl(input: IndexedSeq[Array[Float]]) + : IndexedSeq[Array[Float]] = { + var inputND: ListBuffer[NDArray] = ListBuffer.empty[NDArray] for((i, d) <- input.zip(inputDescriptors)) { @@ -171,6 +186,42 @@ class Predictor(modelPathPrefix: String, result } + private def predictWithDoubleImpl(input: IndexedSeq[Array[Double]]) + : IndexedSeq[Array[Double]] = { + + var inputND: ListBuffer[NDArray] = ListBuffer.empty[NDArray] + + for((i, d) <- input.zip(inputDescriptors)) { + val shape = d.shape.toVector.patch(from = batchIndex, patch = Vector(1), replaced = 1) + + inputND += mxNetHandler.execute(NDArray.array(i, Shape(shape))) + } + + // rebind with batchsize 1 + if (batchSize != 1) { + val desc = iDescriptors.map((f : DataDesc) => new DataDesc(f.name, + Shape(f.shape.toVector.patch(batchIndex, Vector(1), 1)), f.dtype, f.layout) ) + mxNetHandler.execute(mod.bind(desc, forceRebind = true, + forTraining = false)) + } + + val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter( + inputND.toIndexedSeq, dataBatchSize = 1))) + + val result = 
resultND.map((f : NDArray) => f.toFloat64Array) + + mxNetHandler.execute(inputND.foreach(_.dispose)) + mxNetHandler.execute(resultND.foreach(_.dispose)) + + // rebind to batchSize + if (batchSize != 1) { + mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false, forceRebind = true)) + } + + result + } + + /** * Predict using NDArray as input * This method is useful when the input is a batch of data