Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Added Float64 in Predictor class
Browse files Browse the repository at this point in the history
  • Loading branch information
piyushghai committed Dec 21, 2018
1 parent 00ce147 commit d1014b3
Showing 1 changed file with 58 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.mxnet.infer

import org.apache.mxnet.io.NDArrayIter
import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
import org.apache.mxnet._
import org.apache.mxnet.module.Module

import scala.collection.mutable.ListBuffer
Expand All @@ -36,11 +36,13 @@ private[infer] trait PredictBase {
* <p>
* This method will take input as IndexedSeq one dimensional arrays and creates the
* NDArray needed for inference. The array will be reshaped based on the input descriptors.
* @param input: An IndexedSequence of a one-dimensional array.
* @param input: An Indexed Sequence of a one-dimensional array of datatype
* Float or Double
An IndexedSequence is needed when the model has more than one input.
* @return Indexed sequence array of outputs
*/
def predict(input: IndexedSeq[Array[Float]]): IndexedSeq[Array[Float]]
def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]])
: IndexedSeq[Array[T]]

/**
* Predict using NDArray as input.
Expand Down Expand Up @@ -123,13 +125,13 @@ class Predictor(modelPathPrefix: String,
* Takes input as IndexedSeq one dimensional arrays and creates the NDArray needed for inference
* The array will be reshaped based on the input descriptors.
*
* @param input: An IndexedSequence of a one-dimensional array.
* @param input: An IndexedSequence of a one-dimensional array
* of data type Float or Double.
An IndexedSequence is needed when the model has more than one input.
* @return Indexed sequence array of outputs
*/
override def predict(input: IndexedSeq[Array[Float]])
: IndexedSeq[Array[Float]] = {

override def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]])
: IndexedSeq[Array[T]] = {
require(input.length == inputDescriptors.length,
s"number of inputs provided: ${input.length} does not match number of inputs " +
s"in inputDescriptors: ${inputDescriptors.length}")
Expand All @@ -139,6 +141,19 @@ class Predictor(modelPathPrefix: String,
s"number of elements:${i.length} in the input does not match the shape:" +
s"${d.shape.toString()}")
}

// Infer the dtype of input and call relevant method
val result = input(0)(0) match {
case d: Double => predictWithDoubleImpl(input.asInstanceOf[IndexedSeq[Array[Double]]])
case _ => predictWithFloatImpl(input.asInstanceOf[IndexedSeq[Array[Float]]])
}

result.asInstanceOf[IndexedSeq[Array[T]]]
}

private def predictWithFloatImpl(input: IndexedSeq[Array[Float]])
: IndexedSeq[Array[Float]] = {

var inputND: ListBuffer[NDArray] = ListBuffer.empty[NDArray]

for((i, d) <- input.zip(inputDescriptors)) {
Expand Down Expand Up @@ -171,6 +186,42 @@ class Predictor(modelPathPrefix: String,
result
}

private def predictWithDoubleImpl(input: IndexedSeq[Array[Double]])
: IndexedSeq[Array[Double]] = {

var inputND: ListBuffer[NDArray] = ListBuffer.empty[NDArray]

for((i, d) <- input.zip(inputDescriptors)) {
val shape = d.shape.toVector.patch(from = batchIndex, patch = Vector(1), replaced = 1)

inputND += mxNetHandler.execute(NDArray.array(i, Shape(shape)))
}

// rebind with batchsize 1
if (batchSize != 1) {
val desc = iDescriptors.map((f : DataDesc) => new DataDesc(f.name,
Shape(f.shape.toVector.patch(batchIndex, Vector(1), 1)), f.dtype, f.layout) )
mxNetHandler.execute(mod.bind(desc, forceRebind = true,
forTraining = false))
}

val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(
inputND.toIndexedSeq, dataBatchSize = 1)))

val result = resultND.map((f : NDArray) => f.toFloat64Array)

mxNetHandler.execute(inputND.foreach(_.dispose))
mxNetHandler.execute(resultND.foreach(_.dispose))

// rebind to batchSize
if (batchSize != 1) {
mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false, forceRebind = true))
}

result
}


/**
* Predict using NDArray as input
* This method is useful when the input is a batch of data
Expand Down

0 comments on commit d1014b3

Please sign in to comment.