Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
create Util class make Some conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Dec 11, 2018
1 parent 501ee37 commit ef65ec4
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.util

object SomeConversion {
implicit def someWrapper[A](noSome : A) : Option[A] = Option(noSome)
}
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,9 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
}

test("Generated api") {
import org.apache.mxnet.util.SomeConversion._
val arr = NDArray.ones(Shape(1, 2), dtype = DType.Float64)
NDArray.api.norm(arr, Some(0), out = arr)
NDArray.api.norm(arr, ord = 0, out = arr)
val result = NDArray.api.dot(arr, arr)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ abstract class GeneratorBase {
else if (isSymbol) "org.apache.mxnet.Symbol"
else "org.apache.mxnet.NDArray"
val typeAndOption =
CToScalaUtils.argumentCleaner(argName, argType, family)
CToScalaUtils.argumentCleaner(argName, argType, family, isJava)
Arg(argName, typeAndOption._1, argDesc, typeAndOption._2)
}
val returnType =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,26 @@ package org.apache.mxnet.utils

private[mxnet] object CToScalaUtils {


private val javaType = Array("java.lang.Float", "java.lang.Integer",
"java.lang.Long", "java.lang.Double", "java.lang.Boolean")
private val scalaType = Array("Float", "Int", "Long", "Double", "Boolean")

// Convert C++ Types to Scala Types
def typeConversion(in : String, argType : String = "", argName : String,
returnType : String) : String = {
returnType : String, isJava : Boolean) : String = {
val header = returnType.split("\\.").dropRight(1)
val types = if (isJava) javaType else scalaType
in match {
case "Shape(tuple)" | "ShapeorNone" => s"${header.mkString(".")}.Shape"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
=> s"Array[$returnType]"
case "float" | "real_t" | "floatorNone" => "java.lang.Float"
case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer"
case "long" | "long(non-negative)" => "java.lang.Long"
case "double" | "doubleorNone" => "java.lang.Double"
case "float" | "real_t" | "floatorNone" => types(0)
case "int" | "intorNone" | "int(non-negative)" => types(1)
case "long" | "long(non-negative)" => types(2)
case "double" | "doubleorNone" => types(3)
case "string" => "String"
case "boolean" | "booleanorNone" => "java.lang.Boolean"
case "boolean" | "booleanorNone" => types(4)
case "tupleof<float>" | "tupleof<double>" | "tupleof<>" | "ptr" | "" => "Any"
case default => throw new IllegalArgumentException(
s"Invalid type for args: $default\nString argType: $argType\nargName: $argName")
Expand All @@ -54,7 +57,7 @@ private[mxnet] object CToScalaUtils {
* @return (Scala_Type, isOptional)
*/
def argumentCleaner(argName: String, argType : String,
returnType : String) : (String, Boolean) = {
returnType : String, isJava : Boolean) : (String, Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
// Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
Expand All @@ -72,9 +75,9 @@ private[mxnet] object CToScalaUtils {
s"""expected "optional" got ${commaRemoved(1)}""")
require(commaRemoved(2).startsWith("default="),
s"""expected "default=..." got ${commaRemoved(2)}""")
(typeConversion(commaRemoved(0), argType, argName, returnType), true)
(typeConversion(commaRemoved(0), argType, argName, returnType, isJava), true)
} else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
val tempType = typeConversion(commaRemoved(0), argType, argName, returnType)
val tempType = typeConversion(commaRemoved(0), argType, argName, returnType, isJava)
val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
(tempType, tempOptional)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll {
)
val output = List(
("org.apache.mxnet.Symbol", true),
("java.lang.Integer", false),
("Int", false),
("org.apache.mxnet.Shape", true),
("String", true),
("Any", false)
)

for (idx <- input.indices) {
val result = CToScalaUtils.argumentCleaner("Sample", input(idx), "org.apache.mxnet.Symbol")
val result = CToScalaUtils.argumentCleaner("Sample", input(idx),
"org.apache.mxnet.Symbol", false)
assert(result._1 === output(idx)._1 && result._2 === output(idx)._2)
}
}
Expand Down

0 comments on commit ef65ec4

Please sign in to comment.