Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-1231] Allow not using Some in the Scala operators (#13619)
Browse files Browse the repository at this point in the history
* add initial commit

* update image classifier as well

* create Util class make Some conversion

* add test changes

* adress Comments

* fix the spacing problem

* fix generator base

* change name to Option
  • Loading branch information
lanking520 authored Jan 3, 2019
1 parent e30d973 commit fe46cd9
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.util

object OptionConversion {
implicit def someWrapper[A](noSome : A) : Option[A] = Option(noSome)
}
Original file line number Diff line number Diff line change
Expand Up @@ -593,4 +593,17 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
assert(rnd.shape === Shape(1, 2, 3, 4))
assert(rnd2.shape === Shape(3, 4))
}

test("Generated api") {
// Without SomeConversion
val arr3 = NDArray.ones(Shape(1, 2), dtype = DType.Float64)
val arr4 = NDArray.ones(Shape(1), dtype = DType.Float64)
val arr5 = NDArray.api.norm(arr3, ord = Some(1), out = Some(arr4))
// With SomeConversion
import org.apache.mxnet.util.OptionConversion._
val arr = NDArray.ones(Shape(1, 2), dtype = DType.Float64)
val arr2 = NDArray.ones(Shape(1), dtype = DType.Float64)
NDArray.api.norm(arr, ord = 1, out = arr2)
val result = NDArray.api.dot(arr2, arr2)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ class Classifier(modelPathPrefix: String,
})

val predictResult = predictResultPar.toArray

var result: ListBuffer[IndexedSeq[(String, Float)]] =
ListBuffer.empty[IndexedSeq[(String, Float)]]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ private[mxnet] abstract class GeneratorBase {
else if (isSymbol) "org.apache.mxnet.Symbol"
else "org.apache.mxnet.NDArray"
val typeAndOption =
CToScalaUtils.argumentCleaner(argName, argType, family)
CToScalaUtils.argumentCleaner(argName, argType, family, isJava)
Arg(argName, typeAndOption._1, argDesc, typeAndOption._2)
}
val returnType =
Expand Down Expand Up @@ -191,7 +191,7 @@ private[mxnet] trait RandomHelpers {
// unify call targets (random_xyz and sample_xyz) and unify their argument types
private def unifyRandom(func: Func, isSymbol: Boolean): Func = {
var typeConv = Set("org.apache.mxnet.NDArray", "org.apache.mxnet.Symbol",
"java.lang.Float", "java.lang.Integer")
"Float", "Int")

func.copy(
name = func.name.replaceAll("(random|sample)_", ""),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,35 @@ package org.apache.mxnet.utils

private[mxnet] object CToScalaUtils {


private val javaType = Map(
"float" -> "java.lang.Float",
"int" -> "java.lang.Integer",
"long" -> "java.lang.Long",
"double" -> "java.lang.Double",
"bool" -> "java.lang.Boolean")
private val scalaType = Map(
"float" -> "Float",
"int" -> "Int",
"long" -> "Long",
"double" -> "Double",
"bool" -> "Boolean")

// Convert C++ Types to Scala Types
def typeConversion(in : String, argType : String = "", argName : String,
returnType : String) : String = {
returnType : String, isJava : Boolean) : String = {
val header = returnType.split("\\.").dropRight(1)
val types = if (isJava) javaType else scalaType
in match {
case "Shape(tuple)" | "ShapeorNone" => s"${header.mkString(".")}.Shape"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
=> s"Array[$returnType]"
case "float" | "real_t" | "floatorNone" => "java.lang.Float"
case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer"
case "long" | "long(non-negative)" => "java.lang.Long"
case "double" | "doubleorNone" => "java.lang.Double"
case "float" | "real_t" | "floatorNone" => types("float")
case "int" | "intorNone" | "int(non-negative)" => types("int")
case "long" | "long(non-negative)" => types("long")
case "double" | "doubleorNone" => types("double")
case "string" => "String"
case "boolean" | "booleanorNone" => "java.lang.Boolean"
case "boolean" | "booleanorNone" => types("bool")
case "tupleof<float>" | "tupleof<double>" | "tupleof<>" | "ptr" | "" => "Any"
case default => throw new IllegalArgumentException(
s"Invalid type for args: $default\nString argType: $argType\nargName: $argName")
Expand All @@ -54,7 +66,7 @@ private[mxnet] object CToScalaUtils {
* @return (Scala_Type, isOptional)
*/
def argumentCleaner(argName: String, argType : String,
returnType : String) : (String, Boolean) = {
returnType : String, isJava : Boolean) : (String, Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
// Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
Expand All @@ -72,9 +84,9 @@ private[mxnet] object CToScalaUtils {
s"""expected "optional" got ${commaRemoved(1)}""")
require(commaRemoved(2).startsWith("default="),
s"""expected "default=..." got ${commaRemoved(2)}""")
(typeConversion(commaRemoved(0), argType, argName, returnType), true)
(typeConversion(commaRemoved(0), argType, argName, returnType, isJava), true)
} else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
val tempType = typeConversion(commaRemoved(0), argType, argName, returnType)
val tempType = typeConversion(commaRemoved(0), argType, argName, returnType, isJava)
val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
(tempType, tempOptional)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll {
)
val output = List(
("org.apache.mxnet.Symbol", true),
("java.lang.Integer", false),
("Int", false),
("org.apache.mxnet.Shape", true),
("String", true),
("Any", false)
)

for (idx <- input.indices) {
val result = CToScalaUtils.argumentCleaner("Sample", input(idx), "org.apache.mxnet.Symbol")
val result = CToScalaUtils.argumentCleaner("Sample", input(idx),
"org.apache.mxnet.Symbol", false)
assert(result._1 === output(idx)._1 && result._2 === output(idx)._2)
}
}
Expand Down

0 comments on commit fe46cd9

Please sign in to comment.