Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
adding Any type input to form NDArray
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Sep 27, 2018
1 parent c5b1c48 commit 34b14f1
Showing 1 changed file with 60 additions and 0 deletions.
60 changes: 60 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,15 @@ object NDArray extends NDArrayBase {
"_onehot_encode", Seq(indices, out), Map("out" -> out))(0)
}

/**
* Get the String representation of NDArray
* @param nd input NDArray
* @return String
*/
def toString(nd : NDArray) : String = {
nd.visualize
}

/**
* Create an empty uninitialized new NDArray, with specified shape.
*
Expand Down Expand Up @@ -395,6 +404,57 @@ object NDArray extends NDArrayBase {
arr
}

/**
* Create a new NDArray based on the structure of source Array
* @param sourceArr Array[Array...Array[Float]...]
* @param ctx context like to pass in
* @return an NDArray with the same shape of the input
*/
def toNDArray(sourceArr: Array[_], ctx : Context = null) : NDArray = {
val shape = ArrayBuffer[Int]()
shapeGetter(sourceArr, shape, 0)
val finalArr = new Array[Float](shape.product)
arrayCombiner(sourceArr, finalArr, 0, finalArr.length - 1)
array(finalArr, Shape(shape), ctx)
}

private def shapeGetter(sourceArr : Any,
shape : ArrayBuffer[Int], shapeIdx : Int) : Unit = {
sourceArr match {
case arrFloat : Array[Float] => {
val arrLength = arrFloat.length
if (shape.length == shapeIdx) {
shape += arrLength
}
require(shape(shapeIdx) == arrLength, "Each Array should have equal length")
}
case arr : Array[Any] => {
val arrLength = arr.length
if (shape.length == shapeIdx) {
shape += arrLength
}
require(shape(shapeIdx) == arrLength,
s"Each Array should have equal length, expected ${shape(shapeIdx)}, get $arrLength")
arr.foreach(ele => shapeGetter(ele, shape, shapeIdx + 1))
}
case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}")
}
}

private def arrayCombiner(sourceArr : Any, arr : Array[Float], start : Int, end : Int) : Unit = {
sourceArr match {
case arrFloat : Array[Float] => {
for (i <- arrFloat.indices) arr(start + i) = arrFloat(i)
}
case arrAny : Array[Any] => {
val fragment = (end - start + 1) / arrAny.length
for (i <- arrAny.indices)
arrayCombiner(arrAny(i), arr, start + i * fragment, end + (i + 1) * fragment)
}
case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}")
}
}

/**
* Returns evenly spaced values within a given interval.
* Values are generated within the half-open interval [`start`, `stop`). In other
Expand Down

0 comments on commit 34b14f1

Please sign in to comment.