Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.4.1] Java bug-fix cherry pick (#14834)
Browse files Browse the repository at this point in the history
* clean up submodule (#14645)

* Scala/Java Predict API fix #14756 (#14804)

* add fix in the code

* add unit test

* update comments

* add fixes to code gen
  • Loading branch information
lanking520 authored Apr 29, 2019
1 parent 19d78e5 commit 69515c2
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,23 @@ abstract class BaseModule {

/**
* Run prediction and collect the outputs.
* @param evalData
* @param evalData dataIter to do the Inference
* @param numBatch Default is -1, indicating running all the batches in the data iterator.
* @param reset Default is `True`, indicating whether we should reset the data iter before start
* doing prediction.
* @return The return value will be a list `[out1, out2, out3]`.
* The concatenation process will be like
* {{{
* outputBatches = [
* [a1, a2, a3], // batch a
* [b1, b2, b3] // batch b
* ]
* result = [
* NDArray, // [a1, b1]
* NDArray, // [a2, b2]
* NDArray, // [a3, b3]
* ]
* }}}
* Where each element is concatenation of the outputs for all the mini-batches.
*/
def predict(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true)
Expand All @@ -264,7 +276,8 @@ abstract class BaseModule {
s"in mini-batches (${out.size})." +
"Maybe bucketing is used?")
)
val concatenatedOutput = outputBatches.map(out => NDArray.concatenate(out))
val oBT = outputBatches.transpose
val concatenatedOutput = oBT.map(out => NDArray.concatenate(out))
outputBatches.foreach(_.foreach(_.dispose()))
concatenatedOutput
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ public void testGenerated(){
NDArray$ NDArray = NDArray$.MODULE$;
float[] arr = new float[]{1.0f, 2.0f, 3.0f};
NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0));
float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0];
float result = NDArray.norm(new normParam(nd))[0].toArray()[0];
float cal = 0.0f;
for (float ele : arr) {
cal += ele * ele;
}
cal = (float) Math.sqrt(cal);
assertTrue(Math.abs(result - cal) < 1e-5);
NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0));
NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult));
NDArray.dot(new dotParam(nd, nd).setOut(dotResult));
assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,34 @@ import org.apache.mxnet.optimizer._
import org.apache.mxnet.io._

class ModuleSuite extends FunSuite with BeforeAndAfterAll {

class myModule(symbol : Symbol) extends Module (symbol) {
override def predictEveryBatch(evalData: DataIter,
numBatch: Int = 1, reset: Boolean = true):
IndexedSeq[IndexedSeq[NDArray]] = {
val data = IndexedSeq(
NDArray.ones(Shape(1, 10, 1)),
NDArray.ones(Shape(1, 10, 1)),
NDArray.ones(Shape(1, 10, 4))
)
List.fill(numBatch)(data).toIndexedSeq
}
}

test("predict") {
val sym = Symbol.Variable("data")
val mod = new myModule(sym)
val dummyIter = new NDArrayIter(IndexedSeq(NDArray.ones(1)))
var output = mod.predict(dummyIter, 1)
require(output(0).shape == Shape(1, 10, 1))
require(output(1).shape == Shape(1, 10, 1))
require(output(2).shape == Shape(1, 10, 4))
output = mod.predict(dummyIter, 2)
require(output(0).shape == Shape(2, 10, 1))
require(output(1).shape == Shape(2, 10, 1))
require(output(2).shape == Shape(2, 10, 4))
}

test ("model dtype") {
val dType = DType.Float32
val dShape = Shape(3, 8, 7)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ import java.security.MessageDigest
import scala.collection.mutable.ListBuffer

/**
* This object will generate the Scala documentation of the new Scala API
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* This object will generate the Scala documentation of the Scala/Java APIs
* The code will be executed during Macros stage and file live in Core stage
*/
private[mxnet] object APIDocGenerator extends GeneratorBase {

/**
* Main method used to generate code and write to files
* A hash check placed at the end to verify changes
* @param args Input args
*/
def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
Expand All @@ -40,13 +44,25 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
val finalHash = hashCollector.mkString("\n")
}

/**
* Generate MD5 result from an input string
* Encoded in UTF-8
* @param input The input string
* @return A MD5 value from the string
*/
def MD5Generator(input: String): String = {
val md = MessageDigest.getInstance("MD5")
md.update(input.getBytes("UTF-8"))
val digest = md.digest()
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}

/**
* Type-safe class body generation for NDArray/Symbol
* @param FILE_PATH File path write the file to
* @param isSymbol Check if write the Symbol API, NDArray otherwise
* @return MD5 String
*/
def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false)
.map { func =>
Expand All @@ -57,11 +73,22 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {

writeFile(
FILE_PATH,
if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
"package org.apache.mxnet",
if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
"import org.apache.mxnet.annotation.Experimental",
generated)
}

/**
* Non Type-safe interface of Scala Symbol/NDArray
* It includes class definition : e.g class SymbolBase
* and function definitions : e.g def softmax(...)(...)(...) : NDArray
* Users can directly use the api by calling NDArray.<function_name>
* It support both positional input or Map input
* @param FILE_PATH File path write the file to
* @param isSymbol Check if write the Symbol API, NDArray otherwise
* @return MD5 String
*/
def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val absFuncs = functionsToGenerate(isSymbol, isContrib = false)
.map { func =>
Expand All @@ -85,34 +112,53 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {

writeFile(
FILE_PATH,
if (isSymbol) "SymbolBase" else "NDArrayBase",
"package org.apache.mxnet",
if (isSymbol) "SymbolBase" else "NDArrayBase",
"import org.apache.mxnet.annotation.Experimental",
absFuncs)
}

def javaClassGen(filePath : String) : String = {
/**
* Type-safe interface of Java NDArray
* @param FILE_PATH File path write the file to
* @return MD5 String
*/
def javaClassGen(FILE_PATH : String) : String = {
val notGenerated = Set("Custom")
val absClassFunctions = functionsToGenerate(false, false, true)
val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
.groupBy(_.name.toLowerCase).map(ele => {
/* Pattern matching for not generating deprecated method
* Group all method name in lowercase
* Kill the capital lettered method such as Cast vs cast
* As it defined by default it deprecated
*/
if (ele._2.length == 1) ele._2.head
else {
if (ele._2.head.name.head.isLower) ele._2.head
else ele._2.last
}
}).map(absClassFunction => {
val (absFuncs, paramClassUncleaned) =
absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
.groupBy(_.name.toLowerCase).map(ele => {
/* Pattern matching for not generating deprecated method
* Group all method name in lowercase
* Kill the capital lettered method such as Cast vs cast
* As it defined by default it deprecated
*/
if (ele._2.length == 1) ele._2.head
else {
if (ele._2.head.name.head.isLower) ele._2.head
else ele._2.last
}
}).map(absClassFunction => {
generateJavaAPISignature(absClassFunction)
}).toSeq
}).toSeq.unzip
val paramClass = paramClassUncleaned.filterNot(_.isEmpty)
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs)
writeFile(
FILE_PATH + "javaapi/",
packageDef,
packageName,
"import org.apache.mxnet.annotation.Experimental",
absFuncs, Some(paramClass))
}

/**
* Generate Scala docs from the function description
* @param func The function case class
* @param withParam Whether to generate param field
* @return A formatted string for the function description
*/
def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = {
def fixDesc(desc: String): String = {
var curDesc = desc
Expand Down Expand Up @@ -146,7 +192,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
}
}

def generateAPISignature(func: Func, isSymbol: Boolean): String = {
/**
* Generate the function interface
* e.g: def softmax(data: NDArray, name ...): NDArrayFunctionReturn
* @param func The function case class
* @param isSymbol Check if generate Symbol function, NDArray otherwise
* @param typeParameter Type param specifically used in Random Module
* @return Formatted string for the function
*/
def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter: String = ""): String = {
val argDef = ListBuffer[String]()

argDef ++= typedFunctionCommonArgDef(func)
Expand All @@ -162,10 +216,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
val returnType = func.returnType

s"""@Experimental
|def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
|def ${func.name}$typeParameter (${argDef.mkString(", ")}): $returnType""".stripMargin
}

def generateJavaAPISignature(func : Func) : String = {
/**
* Generate Java function interface
* @param func The function case class
* @return A formatted string for the function
*/
def generateJavaAPISignature(func : Func) : (String, String) = {
val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
var argDef = ListBuffer[String]()
var classDef = ListBuffer[String]()
Expand Down Expand Up @@ -204,54 +263,67 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
| }
| def getOut() = this.out
| """.stripMargin
s"""$scalaDocNoParam
| $experimentalTag
| def ${func.name}(po: ${func.name}Param) : $returnType
| /**
| * This Param Object is specifically used for ${func.name}
| ${requiredParam.mkString("\n")}
| */
| class ${func.name}Param(${argDef.mkString(",")}) {
| ${classDef.mkString("\n ")}
| }""".stripMargin
(s"""$scalaDocNoParam
| $experimentalTag
| def ${func.name}(po: ${func.name}Param) : $returnType
| """.stripMargin,
s"""/**
| * This Param Object is specifically used for ${func.name}
| ${requiredParam.mkString("\n")}
| */
| class ${func.name}Param(${argDef.mkString(",")}) {
| ${classDef.mkString("\n ")}
| }""".stripMargin)
} else {
argDef += "out : NDArray"
s"""$scalaDoc
|$experimentalTag
| def ${func.name}(${argDef.mkString(", ")}) : $returnType
| """.stripMargin
(s"""$scalaDoc
|$experimentalTag
| def ${func.name}(${argDef.mkString(", ")}) : $returnType
| """.stripMargin, "")
}
}

def writeFile(FILE_PATH: String, className: String, packageDef: String,
absFuncs: Seq[String]): String = {
/**
* Write the formatted string to file
* @param FILE_PATH Location of the file writes to
* @param packageDef Package definition
* @param className Class name
* @param imports Packages need to import
* @param absFuncs All formatted functions
* @return A MD5 string
*/
def writeFile(FILE_PATH: String, packageDef: String, className: String,
imports: String, absFuncs: Seq[String],
paramClass: Option[Seq[String]] = None): String = {

val finalStr =
s"""/*
|* Licensed to the Apache Software Foundation (ASF) under one or more
|* contributor license agreements. See the NOTICE file distributed with
|* this work for additional information regarding copyright ownership.
|* The ASF licenses this file to You under the Apache License, Version 2.0
|* (the "License"); you may not use this file except in compliance with
|* the License. You may obtain a copy of the License at
|*
|* http://www.apache.org/licenses/LICENSE-2.0
|*
|* Unless required by applicable law or agreed to in writing, software
|* distributed under the License is distributed on an "AS IS" BASIS,
|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|* See the License for the specific language governing permissions and
|* limitations under the License.
|*/
| * Licensed to the Apache Software Foundation (ASF) under one or more
| * contributor license agreements. See the NOTICE file distributed with
| * this work for additional information regarding copyright ownership.
| * The ASF licenses this file to You under the Apache License, Version 2.0
| * (the "License"); you may not use this file except in compliance with
| * the License. You may obtain a copy of the License at
| *
| * http://www.apache.org/licenses/LICENSE-2.0
| *
| * Unless required by applicable law or agreed to in writing, software
| * distributed under the License is distributed on an "AS IS" BASIS,
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
| * See the License for the specific language governing permissions and
| * limitations under the License.
| */
|
|$packageDef
|
|import org.apache.mxnet.annotation.Experimental
|$imports
|
|// scalastyle:off
|abstract class $className {
|${absFuncs.mkString("\n")}
|}""".stripMargin
|}
|${paramClass.getOrElse(Seq()).mkString("\n")}
|""".stripMargin


val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))
Expand Down

0 comments on commit 69515c2

Please sign in to comment.