From 69515c2f9b1ac6fd4b661d5411a97de968cf4e2e Mon Sep 17 00:00:00 2001 From: Lanking Date: Mon, 29 Apr 2019 16:03:36 -0700 Subject: [PATCH] [v1.4.1] Java bug-fix cherry pick (#14834) * clean up submodule (#14645) * Scala/Java Predict API fix #14756 (#14804) * add fix in the code * add unit test * update comments * add fixes to code gen --- .../org/apache/mxnet/module/BaseModule.scala | 17 +- .../org/apache/mxnet/javaapi/NDArrayTest.java | 4 +- .../scala/org/apache/mxnet/ModuleSuite.scala | 28 +++ .../org/apache/mxnet/APIDocGenerator.scala | 184 ++++++++++++------ 4 files changed, 173 insertions(+), 60 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala index b73f4ad4b112..73ccef2c355c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala @@ -247,11 +247,23 @@ abstract class BaseModule { /** * Run prediction and collect the outputs. - * @param evalData + * @param evalData dataIter to do the Inference * @param numBatch Default is -1, indicating running all the batches in the data iterator. * @param reset Default is `True`, indicating whether we should reset the data iter before start * doing prediction. * @return The return value will be a list `[out1, out2, out3]`. + * The concatenation process will be like + * {{{ + * outputBatches = [ + * [a1, a2, a3], // batch a + * [b1, b2, b3] // batch b + * ] + * result = [ + * NDArray, // [a1, b1] + * NDArray, // [a2, b2] + * NDArray, // [a3, b3] + * ] + * }}} * Where each element is concatenation of the outputs for all the mini-batches. */ def predict(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true) @@ -264,7 +276,8 @@ abstract class BaseModule { s"in mini-batches (${out.size})." + "Maybe bucketing is used?") ) - val concatenatedOutput = outputBatches.map(out => NDArray.concatenate(out)) + val oBT = outputBatches.transpose + val concatenatedOutput = oBT.map(out => NDArray.concatenate(out)) outputBatches.foreach(_.foreach(_.dispose())) concatenatedOutput } diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java index 2659b7848bc6..5bbe8bbd97ea 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java @@ -71,7 +71,7 @@ public void testGenerated(){ NDArray$ NDArray = NDArray$.MODULE$; float[] arr = new float[]{1.0f, 2.0f, 3.0f}; NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0)); - float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0]; + float result = NDArray.norm(new normParam(nd))[0].toArray()[0]; float cal = 0.0f; for (float ele : arr) { cal += ele * ele; @@ -79,7 +79,7 @@ public void testGenerated(){ cal = (float) Math.sqrt(cal); assertTrue(Math.abs(result - cal) < 1e-5); NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0)); - NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult)); + NDArray.dot(new dotParam(nd, nd).setOut(dotResult)); assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f})); } } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala index 88e314e2a72c..e6ebfd3c30d3 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala @@ -23,6 +23,34 @@ import org.apache.mxnet.optimizer._ import org.apache.mxnet.io._ class ModuleSuite extends FunSuite with BeforeAndAfterAll { + + class myModule(symbol : Symbol) extends Module (symbol) { + override def predictEveryBatch(evalData: DataIter, + numBatch: Int = 1, reset: Boolean = true): + IndexedSeq[IndexedSeq[NDArray]] = { + val data = IndexedSeq( + NDArray.ones(Shape(1, 10, 1)), + NDArray.ones(Shape(1, 10, 1)), + NDArray.ones(Shape(1, 10, 4)) + ) + List.fill(numBatch)(data).toIndexedSeq + } + } + + test("predict") { + val sym = Symbol.Variable("data") + val mod = new myModule(sym) + val dummyIter = new NDArrayIter(IndexedSeq(NDArray.ones(1))) + var output = mod.predict(dummyIter, 1) + require(output(0).shape == Shape(1, 10, 1)) + require(output(1).shape == Shape(1, 10, 1)) + require(output(2).shape == Shape(1, 10, 4)) + output = mod.predict(dummyIter, 2) + require(output(0).shape == Shape(2, 10, 1)) + require(output(1).shape == Shape(2, 10, 1)) + require(output(2).shape == Shape(2, 10, 4)) + } + test ("model dtype") { val dType = DType.Float32 val dShape = Shape(3, 8, 7) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index ce12dc7cd5a0..77a2704a071b 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -23,12 +23,16 @@ import java.security.MessageDigest import scala.collection.mutable.ListBuffer /** - * This object will generate the Scala documentation of the new Scala API - * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala + * This object will generate the Scala documentation of the Scala/Java APIs * The code will be executed during Macros stage and file live in Core stage */ private[mxnet] object APIDocGenerator extends GeneratorBase { + /** + * Main method used to generate code and write to files + * A hash check placed at the end to verify changes + * @param args Input args + */ def main(args: Array[String]): Unit = { val FILE_PATH = args(0) val hashCollector = ListBuffer[String]() @@ -40,6 +44,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { val finalHash = hashCollector.mkString("\n") } + /** + * Generate MD5 result from an input string + * Encoded in UTF-8 + * @param input The input string + * @return A MD5 value from the string + */ def MD5Generator(input: String): String = { val md = MessageDigest.getInstance("MD5") md.update(input.getBytes("UTF-8")) @@ -47,6 +57,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest) } + /** + * Type-safe class body generation for NDArray/Symbol + * @param FILE_PATH File path write the file to + * @param isSymbol Check if write the Symbol API, NDArray otherwise + * @return MD5 String + */ def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = { val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false) .map { func => @@ -57,11 +73,22 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { writeFile( FILE_PATH, - if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase", "package org.apache.mxnet", + if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase", + "import org.apache.mxnet.annotation.Experimental", generated) } + /** + * Non Type-safe interface of Scala Symbol/NDArray + * It includes class definition : e.g class SymbolBase + * and function definitions : e.g def softmax(...)(...)(...) : NDArray + * Users can directly use the api by calling NDArray. + * It support both positional input or Map input + * @param FILE_PATH File path write the file to + * @param isSymbol Check if write the Symbol API, NDArray otherwise + * @return MD5 String + */ def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = { val absFuncs = functionsToGenerate(isSymbol, isContrib = false) .map { func => @@ -85,34 +112,53 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { writeFile( FILE_PATH, - if (isSymbol) "SymbolBase" else "NDArrayBase", "package org.apache.mxnet", + if (isSymbol) "SymbolBase" else "NDArrayBase", + "import org.apache.mxnet.annotation.Experimental", absFuncs) } - def javaClassGen(filePath : String) : String = { + /** + * Type-safe interface of Java NDArray + * @param FILE_PATH File path write the file to + * @return MD5 String + */ + def javaClassGen(FILE_PATH : String) : String = { val notGenerated = Set("Custom") val absClassFunctions = functionsToGenerate(false, false, true) - val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name)) - .groupBy(_.name.toLowerCase).map(ele => { - /* Pattern matching for not generating deprecated method - * Group all method name in lowercase - * Kill the capital lettered method such as Cast vs cast - * As it defined by default it deprecated - */ - if (ele._2.length == 1) ele._2.head - else { - if (ele._2.head.name.head.isLower) ele._2.head - else ele._2.last - } - }).map(absClassFunction => { + val (absFuncs, paramClassUncleaned) = + absClassFunctions.filterNot(ele => notGenerated.contains(ele.name)) + .groupBy(_.name.toLowerCase).map(ele => { + /* Pattern matching for not generating deprecated method + * Group all method name in lowercase + * Kill the capital lettered method such as Cast vs cast + * As it defined by default it deprecated + */ + if (ele._2.length == 1) ele._2.head + else { + if (ele._2.head.name.head.isLower) ele._2.head + else ele._2.last + } + }).map(absClassFunction => { generateJavaAPISignature(absClassFunction) - }).toSeq + }).toSeq.unzip + val paramClass = paramClassUncleaned.filterNot(_.isEmpty) val packageName = "NDArrayBase" val packageDef = "package org.apache.mxnet.javaapi" - writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs) + writeFile( + FILE_PATH + "javaapi/", + packageDef, + packageName, + "import org.apache.mxnet.annotation.Experimental", + absFuncs, Some(paramClass)) } + /** + * Generate Scala docs from the function description + * @param func The function case class + * @param withParam Whether to generate param field + * @return A formatted string for the function description + */ def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = { def fixDesc(desc: String): String = { var curDesc = desc @@ -146,7 +192,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { } } - def generateAPISignature(func: Func, isSymbol: Boolean): String = { + /** + * Generate the function interface + * e.g: def softmax(data: NDArray, name ...): NDArrayFunctionReturn + * @param func The function case class + * @param isSymbol Check if generate Symbol function, NDArray otherwise + * @param typeParameter Type param specifically used in Random Module + * @return Formatted string for the function + */ + def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter: String = ""): String = { val argDef = ListBuffer[String]() argDef ++= typedFunctionCommonArgDef(func) @@ -162,10 +216,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { val returnType = func.returnType s"""@Experimental - |def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin + |def ${func.name}$typeParameter (${argDef.mkString(", ")}): $returnType""".stripMargin } - def generateJavaAPISignature(func : Func) : String = { + /** + * Generate Java function interface + * @param func The function case class + * @return A formatted string for the function + */ + def generateJavaAPISignature(func : Func) : (String, String) = { val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2 var argDef = ListBuffer[String]() var classDef = ListBuffer[String]() @@ -204,54 +263,67 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { | } | def getOut() = this.out | """.stripMargin - s"""$scalaDocNoParam - | $experimentalTag - | def ${func.name}(po: ${func.name}Param) : $returnType - | /** - | * This Param Object is specifically used for ${func.name} - | ${requiredParam.mkString("\n")} - | */ - | class ${func.name}Param(${argDef.mkString(",")}) { - | ${classDef.mkString("\n ")} - | }""".stripMargin + (s"""$scalaDocNoParam + | $experimentalTag + | def ${func.name}(po: ${func.name}Param) : $returnType + | """.stripMargin, + s"""/** + | * This Param Object is specifically used for ${func.name} + | ${requiredParam.mkString("\n")} + | */ + | class ${func.name}Param(${argDef.mkString(",")}) { + | ${classDef.mkString("\n ")} + | }""".stripMargin) } else { argDef += "out : NDArray" - s"""$scalaDoc - |$experimentalTag - | def ${func.name}(${argDef.mkString(", ")}) : $returnType - | """.stripMargin + (s"""$scalaDoc + |$experimentalTag + | def ${func.name}(${argDef.mkString(", ")}) : $returnType + | """.stripMargin, "") } } - def writeFile(FILE_PATH: String, className: String, packageDef: String, - absFuncs: Seq[String]): String = { + /** + * Write the formatted string to file + * @param FILE_PATH Location of the file writes to + * @param packageDef Package definition + * @param className Class name + * @param imports Packages need to import + * @param absFuncs All formatted functions + * @return A MD5 string + */ + def writeFile(FILE_PATH: String, packageDef: String, className: String, + imports: String, absFuncs: Seq[String], + paramClass: Option[Seq[String]] = None): String = { val finalStr = s"""/* - |* Licensed to the Apache Software Foundation (ASF) under one or more - |* contributor license agreements. See the NOTICE file distributed with - |* this work for additional information regarding copyright ownership. - |* The ASF licenses this file to You under the Apache License, Version 2.0 - |* (the "License"); you may not use this file except in compliance with - |* the License. You may obtain a copy of the License at - |* - |* http://www.apache.org/licenses/LICENSE-2.0 - |* - |* Unless required by applicable law or agreed to in writing, software - |* distributed under the License is distributed on an "AS IS" BASIS, - |* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - |* See the License for the specific language governing permissions and - |* limitations under the License. - |*/ + | * Licensed to the Apache Software Foundation (ASF) under one or more + | * contributor license agreements. See the NOTICE file distributed with + | * this work for additional information regarding copyright ownership. + | * The ASF licenses this file to You under the Apache License, Version 2.0 + | * (the "License"); you may not use this file except in compliance with + | * the License. You may obtain a copy of the License at + | * + | * http://www.apache.org/licenses/LICENSE-2.0 + | * + | * Unless required by applicable law or agreed to in writing, software + | * distributed under the License is distributed on an "AS IS" BASIS, + | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + | * See the License for the specific language governing permissions and + | * limitations under the License. + | */ | |$packageDef | - |import org.apache.mxnet.annotation.Experimental + |$imports | |// scalastyle:off |abstract class $className { |${absFuncs.mkString("\n")} - |}""".stripMargin + |} + |${paramClass.getOrElse(Seq()).mkString("\n")} + |""".stripMargin val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))