Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
new use of ParamObject
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Apr 25, 2019
1 parent acf53fd commit 1790d72
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 20 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/googletest
Submodule googletest updated 355 files
2 changes: 1 addition & 1 deletion 3rdparty/mshadow
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ public void testGenerated(){
NDArray$ NDArray = NDArray$.MODULE$;
float[] arr = new float[]{1.0f, 2.0f, 3.0f};
NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0));
float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0];
float result = NDArray.norm(new normParam(nd))[0].toArray()[0];
float cal = 0.0f;
for (float ele : arr) {
cal += ele * ele;
}
cal = (float) Math.sqrt(cal);
assertTrue(Math.abs(result - cal) < 1e-5);
NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0));
NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult));
NDArray.dot(new dotParam(nd, nd).setOut(dotResult));
assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ private static int argmax(float[] prob) {
*/
static List<String> postProcessing(NDArray result, List<String> tokens) {
NDArray[] output = NDArray.split(
NDArray.new splitParam(result, 2).setAxis(2));
new splitParam(result, 2).setAxis(2));
// Get the formatted logits result
NDArray startLogits = output[0].reshape(new int[]{0, -3});
NDArray endLogits = output[1].reshape(new int[]{0, -3});
// Get Probability distribution
float[] startProb = NDArray.softmax(
NDArray.new softmaxParam(startLogits))[0].toArray();
new softmaxParam(startLogits))[0].toArray();
float[] endProb = NDArray.softmax(
NDArray.new softmaxParam(endLogits))[0].toArray();
new softmaxParam(endLogits))[0].toArray();
int startIdx = argmax(startProb);
int endIdx = argmax(endProb);
return tokens.subList(startIdx, endIdx + 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
def javaClassGen(FILE_PATH : String) : String = {
val notGenerated = Set("Custom")
val absClassFunctions = functionsToGenerate(false, false, true)
val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
val (absFuncs, paramClassUncleaned) =
absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
.groupBy(_.name.toLowerCase).map(ele => {
/* Pattern matching for not generating deprecated method
* Group all method name in lowercase
Expand All @@ -166,15 +167,16 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
}
}).map(absClassFunction => {
generateJavaAPISignature(absClassFunction)
}).toSeq
}).toSeq.unzip
val paramClass = paramClassUncleaned.filter(!_.isEmpty)
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
writeFile(
FILE_PATH + "javaapi/",
packageDef,
packageName,
"import org.apache.mxnet.annotation.Experimental",
absFuncs)
absFuncs, Some(paramClass))
}

/**
Expand Down Expand Up @@ -248,7 +250,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
* @param func The function case class
* @return A formatted string for the function
*/
def generateJavaAPISignature(func : Func) : String = {
def generateJavaAPISignature(func : Func) : (String, String) = {
val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
var argDef = ListBuffer[String]()
var classDef = ListBuffer[String]()
Expand Down Expand Up @@ -287,22 +289,23 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
| }
| def getOut() = this.out
| """.stripMargin
s"""$scalaDocNoParam
(s"""$scalaDocNoParam
| $experimentalTag
| def ${func.name}(po: ${func.name}Param) : $returnType
| /**
| """.stripMargin,
s"""/**
| * This Param Object is specifically used for ${func.name}
| ${requiredParam.mkString("\n")}
| */
| class ${func.name}Param(${argDef.mkString(",")}) {
| ${classDef.mkString("\n ")}
| }""".stripMargin
| }""".stripMargin)
} else {
argDef += "out : NDArray"
s"""$scalaDoc
(s"""$scalaDoc
|$experimentalTag
| def ${func.name}(${argDef.mkString(", ")}) : $returnType
| """.stripMargin
| """.stripMargin, "")
}
}

Expand All @@ -316,7 +319,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
* @return A MD5 string
*/
def writeFile(FILE_PATH: String, packageDef: String, className: String,
imports: String, absFuncs: Seq[String]): String = {
imports: String, absFuncs: Seq[String],
paramClass: Option[Seq[String]] = None): String = {

val finalStr =
s"""/*
Expand All @@ -343,7 +347,9 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
|// scalastyle:off
|abstract class $className {
|${absFuncs.mkString("\n")}
|}""".stripMargin
|}
|${paramClass.getOrElse(Seq()).mkString("\n")}
|""".stripMargin


val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public static void main(String[] args) {

// random
NDArray random = NDArray.random_uniform(
NDArray.new random_uniformParam()
new random_uniformParam()
.setLow(0.0f)
.setHigh(2.0f)
.setShape(new Shape(new int[]{10, 10}))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public static void main(String[] args) {
System.out.println(eleAdd);

// norm (L2 Norm)
NDArray normed = NDArray.norm(NDArray.new normParam(nd))[0];
NDArray normed = NDArray.norm(new normParam(nd))[0];
System.out.println(normed);
}
}

0 comments on commit 1790d72

Please sign in to comment.