From 926871e9a5c26c7d87f62d4ce71aff6323655f2e Mon Sep 17 00:00:00 2001 From: Piyush Ghai Date: Thu, 27 Dec 2018 17:24:38 -0800 Subject: [PATCH] Reduced code duplication in classify method in Classifier.scala --- .../org/apache/mxnet/MX_PRIMITIVES.scala | 8 +++++ .../org/apache/mxnet/infer/Classifier.scala | 30 +++++-------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala index 84f16f6ae8a0..147b8f1260fb 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala @@ -38,6 +38,14 @@ object MX_PRIMITIVES { def unary_- : MX_PRIMITIVE_TYPE } + trait MXPrimitiveOrdering extends Ordering[MX_PRIMITIVE_TYPE] { + + def compare(x: MX_PRIMITIVE_TYPE, y: MX_PRIMITIVE_TYPE) = x.compare(y) + + } + + implicit object MX_PRIMITIVE_TYPE extends MXPrimitiveOrdering + /** * Mimics Float in Scala. * @param data diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala index 2fddc1cd82e3..095138d2a7dd 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala @@ -20,7 +20,9 @@ package org.apache.mxnet.infer import org.apache.mxnet._ import java.io.File +import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE import org.slf4j.LoggerFactory + import scala.io import scala.collection.mutable.ListBuffer import scala.collection.parallel.mutable.ParArray @@ -88,40 +90,24 @@ class Classifier(modelPathPrefix: String, // considering only the first output val result = input(0)(0) match { case d: Double => { - classifyWithDoubleImpl(input.asInstanceOf[IndexedSeq[Array[Double]]], topK) + classifyImpl(input.asInstanceOf[IndexedSeq[Array[Double]]], topK) } case _ => { - classifyWithFloatImpl(input.asInstanceOf[IndexedSeq[Array[Float]]], topK) + classifyImpl(input.asInstanceOf[IndexedSeq[Array[Float]]], topK) } } result.asInstanceOf[IndexedSeq[(String, T)]] } - private def classifyWithFloatImpl(input: IndexedSeq[Array[Float]], topK: Option[Int] = None) - : IndexedSeq[(String, Float)] = { - - // considering only the first output - val predictResult = predictor.predict(input)(0) - - var result: IndexedSeq[(String, Float)] = IndexedSeq.empty - - if (topK.isDefined) { - val sortedIndex = predictResult.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get) - result = sortedIndex.map(i => (synset(i), predictResult(i))).toIndexedSeq - } else { - result = synset.zip(predictResult).toIndexedSeq - } - result - } - - private def classifyWithDoubleImpl(input: IndexedSeq[Array[Double]], topK: Option[Int] = None) - : IndexedSeq[(String, Double)] = { + private def classifyImpl[B, A <: MX_PRIMITIVE_TYPE] + (input: IndexedSeq[Array[B]], topK: Option[Int] = None)(implicit ev: B => A) + : IndexedSeq[(String, B)] = { // considering only the first output val predictResult = predictor.predict(input)(0) - var result: IndexedSeq[(String, Double)] = IndexedSeq.empty + var result: IndexedSeq[(String, B)] = IndexedSeq.empty if (topK.isDefined) { val sortedIndex = predictResult.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get)