Skip to content

Commit

Permalink
[MXNET-918] Random module (apache#13039)
Browse files Browse the repository at this point in the history
* introduce random API

* revert useless changes

* shorter types in APIDoc gen code

* fix after merge from master

* Trigger CI

* temp code / diag on CI

* cleanup type-class code

* cleanup type-class code

* fix scalastyle
  • Loading branch information
mdespriee authored and Ubuntu committed Dec 18, 2018
1 parent d59dae6 commit 948ea01
Show file tree
Hide file tree
Showing 11 changed files with 435 additions and 85 deletions.
18 changes: 18 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,21 @@ private[mxnet] object Base {
}

class MXNetError(val err: String) extends Exception(err)

// Some type-classes to ease the work in Symbol.random and NDArray.random modules

class SymbolOrScalar[T](val isScalar: Boolean)
object SymbolOrScalar {
def apply[T](implicit ev: SymbolOrScalar[T]): SymbolOrScalar[T] = ev
implicit object FloatWitness extends SymbolOrScalar[Float](true)
implicit object IntWitness extends SymbolOrScalar[Int](true)
implicit object SymbolWitness extends SymbolOrScalar[Symbol](false)
}

class NDArrayOrScalar[T](val isScalar: Boolean)
object NDArrayOrScalar {
def apply[T](implicit ev: NDArrayOrScalar[T]): NDArrayOrScalar[T] = ev
implicit object FloatWitness extends NDArrayOrScalar[Float](true)
implicit object IntWitness extends NDArrayOrScalar[Int](true)
implicit object NDArrayWitness extends NDArrayOrScalar[NDArray](false)
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ object NDArray extends NDArrayBase {
private val functions: Map[String, NDArrayFunction] = initNDArrayModule()

val api = NDArrayAPI
val random = NDArrayRandomAPI

private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit = {
froms.foreach { from =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,22 @@
* limitations under the License.
*/
package org.apache.mxnet
@AddNDArrayAPIs(false)

/**
* typesafe NDArray API: NDArray.api._
* Main code will be generated during compile time through Macros
*/
@AddNDArrayAPIs(false)
object NDArrayAPI extends NDArrayAPIBase {
// TODO: Implement CustomOp for NDArray
}

/**
* typesafe NDArray random module: NDArray.random._
* Main code will be generated during compile time through Macros
*/
@AddNDArrayRandomAPIs(false)
object NDArrayRandomAPI extends NDArrayRandomAPIBase {

}

Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,7 @@ object Symbol extends SymbolBase {
private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)

val api = SymbolAPI
val random = SymbolRandomAPI

def pow(sym1: Symbol, sym2: Symbol): Symbol = {
Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ package org.apache.mxnet
import scala.collection.mutable


@AddSymbolAPIs(false)
/**
* typesafe Symbol API: Symbol.api._
* Main code will be generated during compile time through Macros
*/
@AddSymbolAPIs(false)
object SymbolAPI extends SymbolAPIBase {
def Custom (op_type : String, kwargs : mutable.Map[String, Any],
name : String = null, attr : Map[String, String] = null) : Symbol = {
Expand All @@ -32,3 +32,13 @@ object SymbolAPI extends SymbolAPIBase {
Symbol.createSymbolGeneral("Custom", name, attr, Seq(), map.toMap)
}
}

/**
* typesafe Symbol random module: Symbol.random._
* Main code will be generated during compile time through Macros
*/
@AddSymbolRandomAPIs(false)
object SymbolRandomAPI extends SymbolRandomAPIBase {

}

Original file line number Diff line number Diff line change
Expand Up @@ -576,4 +576,21 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
assert(arr.internal.toDoubleArray === Array(2d, 2d))
assert(arr.internal.toByteArray === Array(2.toByte, 2.toByte))
}

test("NDArray random module is generated properly") {
val lam = NDArray.ones(1, 2)
val rnd = NDArray.random.poisson(lam = Some(lam), shape = Some(Shape(3, 4)))
val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4)))
assert(rnd.shape === Shape(1, 2, 3, 4))
assert(rnd2.shape === Shape(3, 4))
}

test("NDArray random module is generated properly - special case of 'normal'") {
val mu = NDArray.ones(1, 2)
val sigma = NDArray.ones(1, 2) * 2
val rnd = NDArray.random.normal(mu = Some(mu), sigma = Some(sigma), shape = Some(Shape(3, 4)))
val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(3, 4)))
assert(rnd.shape === Shape(1, 2, 3, 4))
assert(rnd2.shape === Shape(3, 4))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.mxnet
import org.scalatest.{BeforeAndAfterAll, FunSuite}

class SymbolSuite extends FunSuite with BeforeAndAfterAll {

test("symbol compose") {
val data = Symbol.Variable("data")

Expand Down Expand Up @@ -71,4 +72,25 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll {
val data2 = data.clone()
assert(data.toJson === data2.toJson)
}

test("Symbol random module is generated properly") {
val lam = Symbol.Variable("lam")
val rnd = Symbol.random.poisson(lam = Some(lam), shape = Some(Shape(2, 2)))
val rnd2 = Symbol.random.poisson(lam = Some(1f), shape = Some(Shape(2, 2)))
// scalastyle:off println
println(s"Symbol.random.poisson debug info: ${rnd.debugStr}")
println(s"Symbol.random.poisson debug info: ${rnd2.debugStr}")
// scalastyle:on println
}

test("Symbol random module is generated properly - special case of 'normal'") {
val loc = Symbol.Variable("loc")
val scale = Symbol.Variable("scale")
val rnd = Symbol.random.normal(mu = Some(loc), sigma = Some(scale), shape = Some(Shape(2, 2)))
val rnd2 = Symbol.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(2, 2)))
// scalastyle:off println
println(s"Symbol.random.sample_normal debug info: ${rnd.debugStr}")
println(s"Symbol.random.random_normal debug info: ${rnd2.debugStr}")
// scalastyle:on println
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ import scala.collection.mutable.ListBuffer
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* The code will be executed during Macros stage and file live in Core stage
*/
private[mxnet] object APIDocGenerator extends GeneratorBase {
private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {

def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
hashCollector += typeSafeClassGen(FILE_PATH, true)
hashCollector += typeSafeClassGen(FILE_PATH, false)
hashCollector += typeSafeRandomClassGen(FILE_PATH, true)
hashCollector += typeSafeRandomClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
hashCollector += javaClassGen(FILE_PATH)
Expand All @@ -57,8 +59,27 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {

writeFile(
FILE_PATH,
"package org.apache.mxnet",
if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
"import org.apache.mxnet.annotation.Experimental",
generated)
}

def typeSafeRandomClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val generated = typeSafeRandomFunctionsToGenerate(isSymbol)
.map { func =>
val scalaDoc = generateAPIDocFromBackend(func)
val typeParameter = randomGenericTypeSpec(isSymbol, false)
val decl = generateAPISignature(func, isSymbol, typeParameter)
s"$scalaDoc\n$decl"
}

writeFile(
FILE_PATH,
"package org.apache.mxnet",
if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase",
"""import org.apache.mxnet.annotation.Experimental
|import scala.reflect.ClassTag""".stripMargin,
generated)
}

Expand All @@ -85,8 +106,9 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {

writeFile(
FILE_PATH,
if (isSymbol) "SymbolBase" else "NDArrayBase",
"package org.apache.mxnet",
if (isSymbol) "SymbolBase" else "NDArrayBase",
"import org.apache.mxnet.annotation.Experimental",
absFuncs)
}

Expand All @@ -110,7 +132,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
}).toSeq
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs)
writeFile(
filePath + "javaapi/",
packageDef,
packageName,
"import org.apache.mxnet.annotation.Experimental",
absFuncs)
}

def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = {
Expand Down Expand Up @@ -146,7 +173,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
}
}

def generateAPISignature(func: Func, isSymbol: Boolean): String = {
def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter: String = ""): String = {
val argDef = ListBuffer[String]()

argDef ++= typedFunctionCommonArgDef(func)
Expand All @@ -162,7 +189,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
val returnType = func.returnType

s"""@Experimental
|def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
|def ${func.name}$typeParameter (${argDef.mkString(", ")}): $returnType""".stripMargin
}

def generateJavaAPISignature(func : Func) : String = {
Expand Down Expand Up @@ -223,8 +250,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
}
}

def writeFile(FILE_PATH: String, className: String, packageDef: String,
absFuncs: Seq[String]): String = {
def writeFile(FILE_PATH: String, packageDef: String, className: String,
imports: String, absFuncs: Seq[String]): String = {

val finalStr =
s"""/*
Expand All @@ -246,7 +273,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
|
|$packageDef
|
|import org.apache.mxnet.annotation.Experimental
|$imports
|
|// scalastyle:off
|abstract class $className {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
import scala.collection.mutable.ListBuffer
import scala.reflect.macros.blackbox

abstract class GeneratorBase {
private[mxnet] abstract class GeneratorBase {
type Handle = Long

case class Arg(argName: String, argType: String, argDesc: String, isOptional: Boolean) {
Expand All @@ -46,7 +46,8 @@ abstract class GeneratorBase {
}
}

def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = {
// filter the operators to generate in the type-safe Symbol.api and NDArray.api
protected def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = {
// Operators that should not be generated
val notGenerated = Set("Custom")

Expand Down Expand Up @@ -144,8 +145,8 @@ abstract class GeneratorBase {
result
}

// build function argument definition, with optionality, and safe names
protected def typedFunctionCommonArgDef(func: Func): List[String] = {
// build function argument definition, with optionality, and safe names
func.listOfArgs.map(arg =>
if (arg.isOptional) {
// let's avoid a stupid Option[Array[...]]
Expand All @@ -161,3 +162,71 @@ abstract class GeneratorBase {
)
}
}

// a mixin to ease generating the Random module
private[mxnet] trait RandomHelpers {
self: GeneratorBase =>

// a generic type spec used in Symbol.random and NDArray.random modules
protected def randomGenericTypeSpec(isSymbol: Boolean, fullPackageSpec: Boolean): String = {
val classTag = if (fullPackageSpec) "scala.reflect.ClassTag" else "ClassTag"
if (isSymbol) s"[T: SymbolOrScalar : $classTag]"
else s"[T: NDArrayOrScalar : $classTag]"
}

// filter the operators to generate in the type-safe Symbol.random and NDArray.random
protected def typeSafeRandomFunctionsToGenerate(isSymbol: Boolean): List[Func] = {
getBackEndFunctions(isSymbol)
.filter(f => f.name.startsWith("_sample_") || f.name.startsWith("_random_"))
.map(f => f.copy(name = f.name.stripPrefix("_")))
// unify _random and _sample
.map(f => unifyRandom(f, isSymbol))
// deduplicate
.groupBy(_.name)
.mapValues(_.head)
.values
.toList
}

// unify call targets (random_xyz and sample_xyz) and unify their argument types
private def unifyRandom(func: Func, isSymbol: Boolean): Func = {
var typeConv = Set("org.apache.mxnet.NDArray", "org.apache.mxnet.Symbol",
"java.lang.Float", "java.lang.Integer")

func.copy(
name = func.name.replaceAll("(random|sample)_", ""),
listOfArgs = func.listOfArgs
.map(hackNormalFunc)
.map(arg =>
if (typeConv(arg.argType)) arg.copy(argType = "T")
else arg
)
// TODO: some functions are non consistent in random_ vs sample_ regarding optionality
// we may try to unify that as well here.
)
}

// hacks to manage the fact that random_normal and sample_normal have
// non-consistent parameter naming in the back-end
// this first one, merge loc/scale and mu/sigma
protected def hackNormalFunc(arg: Arg): Arg = {
if (arg.argName == "loc") arg.copy(argName = "mu")
else if (arg.argName == "scale") arg.copy(argName = "sigma")
else arg
}

// this second one reverts this merge prior to back-end call
protected def unhackNormalFunc(func: Func): String = {
if (func.name.equals("normal")) {
s"""if(target.equals("random_normal")) {
| if(map.contains("mu")) { map("loc") = map("mu"); map.remove("mu") }
| if(map.contains("sigma")) { map("scale") = map("sigma"); map.remove("sigma") }
|}
""".stripMargin
} else {
""
}

}

}
Loading

0 comments on commit 948ea01

Please sign in to comment.