Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add multiple types support
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jan 16, 2019
1 parent b169f9e commit 9f2ea34
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,10 @@ object MX_PRIMITIVES {

implicit def MX_DoubleToDouble(d: MX_Double) : Double = d.data

def isValidType(num : Any) : Boolean = {
num match {
case valid @ (_: Float | _: Double) => true
case _ => false
}
}
}
33 changes: 18 additions & 15 deletions scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -512,29 +512,34 @@ object NDArray extends NDArrayBase {

/**
* Create a new NDArray based on the structure of source Array
* @param sourceArr Array[Array...Array[Float]...]
* @param sourceArr Array[Array...Array[MX_PRIMITIVE_TYPE]...]
* @param ctx context like to pass in
* @return an NDArray with the same shape of the input
*/
def toNDArray(sourceArr: Array[_], ctx : Context = null) : NDArray = {
val shape = ArrayBuffer[Int]()
shapeGetter(sourceArr, shape, 0)
val finalArr = new Array[Float](shape.product)
arrayCombiner(sourceArr, finalArr, 0, finalArr.length - 1)
array(finalArr, Shape(shape), ctx)
val container = new Array[Any](shape.product)
arrayCombiner(sourceArr, container, 0, container.length - 1)
val finalArr = container(0) match {
case f: Float => array(container.map(_.asInstanceOf[Float]), Shape(shape), ctx)
case d: Double => array(container.map(_.asInstanceOf[Double]), Shape(shape), ctx)
case _ => throw new IllegalArgumentException(s"Unsupported type ${container(0).getClass}")
}
finalArr
}

private def shapeGetter(sourceArr : Any,
shape : ArrayBuffer[Int], shapeIdx : Int) : Unit = {
sourceArr match {
case arrFloat : Array[Float] => {
val arrLength = arrFloat.length
case arr: Array[_] if MX_PRIMITIVES.isValidType(arr(0)) => {
val arrLength = arr.length
if (shape.length == shapeIdx) {
shape += arrLength
}
require(shape(shapeIdx) == arrLength, "Each Array should have equal length")
}
case arr : Array[Any] => {
case arr: Array[_] => {
val arrLength = arr.length
if (shape.length == shapeIdx) {
shape += arrLength
Expand All @@ -547,12 +552,13 @@ object NDArray extends NDArrayBase {
}
}

private def arrayCombiner(sourceArr : Any, arr : Array[Float], start : Int, end : Int) : Unit = {
private def arrayCombiner(sourceArr : Any, arr : Array[Any],
start : Int, end : Int) : Unit = {
sourceArr match {
case arrFloat : Array[Float] => {
for (i <- arrFloat.indices) arr(start + i) = arrFloat(i)
case arrValid: Array[_] if MX_PRIMITIVES.isValidType(arrValid(0)) => {
for (i <- arrValid.indices) arr(start + i) = arrValid(i)
}
case arrAny : Array[Any] => {
case arrAny: Array[_] => {
val fragment = (end - start + 1) / arrAny.length
for (i <- arrAny.indices)
arrayCombiner(arrAny(i), arr, start + i * fragment, start + (i + 1) * fragment)
Expand Down Expand Up @@ -746,10 +752,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]

private val traceProperty = "mxnet.setNDArrayPrintLength"
private lazy val printLength = {
val value = Try(System.getProperty(traceProperty).toInt).getOrElse(1000)
value
}
private lazy val printLength = Try(System.getProperty(traceProperty).toInt).getOrElse(1000)

def serialize(): Array[Byte] = {
val buf = ArrayBuffer.empty[Byte]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,28 +90,44 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
}

test("create NDArray based on Java Matrix") {
val arrBuf = ArrayBuffer[Array[Float]]()
for (i <- 0 until 100) arrBuf += Array(1.0f, 1.0f, 1.0f, 1.0f)
val arr = Array(
def arrayGen(num : Any) : Array[Any] = {
val arrayBuf = num match {
case f: Float =>
val arr = ArrayBuffer[Array[Float]]()
for (_ <- 0 until 100) arr += Array(1.0f, 1.0f, 1.0f, 1.0f)
arr
case d: Double =>
val arr = ArrayBuffer[Array[Double]]()
for (_ <- 0 until 100) arr += Array(1.0d, 1.0d, 1.0d, 1.0d)
arr
case _ => throw new IllegalArgumentException(s"Unsupported Type ${num.getClass}")
}
Array(
arrBuf.toArray
),
Array(
arrBuf.toArray
Array(
arrayBuf.toArray
),
Array(
arrayBuf.toArray
)
)
var nd = NDArray.toNDArray(arr)
)
}
val floatData = 1.0f
var nd = NDArray.toNDArray(arrayGen(floatData))
require(nd.shape == Shape(2, 1, 100, 4))
val arr2 = Array(1.0f, 1.0f, 1.0f, 1.0f)
nd = NDArray.toNDArray(arr2)
require(nd.shape == Shape(4))
val doubleData = 1.0d
nd = NDArray.toNDArray(arrayGen(doubleData))
require(nd.shape == Shape(2, 1, 100, 4))
require(nd.dtype == DType.Float64)
}

test("test Visualize") {
var nd = NDArray.ones(Shape(1, 2, 1000, 1))
logger.info(s"Test print large ndarray:\n$nd")
require(nd.toString.split("\n").length == 33)
nd = NDArray.ones(Shape(1, 4))
logger.info(s"Test print small ndarray:\n$nd")
require(nd.toString.split("\n").length == 4)
}

test("plus") {
Expand Down

0 comments on commit 9f2ea34

Please sign in to comment.