diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 65711cce997e..3ffca90b8166 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -197,6 +197,15 @@ object NDArray extends NDArrayBase { "_onehot_encode", Seq(indices, out), Map("out" -> out))(0) } + /** + * Get the String representation of NDArray + * @param nd input NDArray + * @return String + */ + def toString(nd : NDArray) : String = { + nd.visualize + } + /** * Create an empty uninitialized new NDArray, with specified shape. * @@ -395,6 +404,57 @@ object NDArray extends NDArrayBase { arr } + /** + * Create a new NDArray based on the structure of source Array + * @param sourceArr Array[Array...Array[Float]...] + * @param ctx context like to pass in + * @return an NDArray with the same shape of the input + */ + def toNDArray(sourceArr: Array[_], ctx : Context = null) : NDArray = { + val shape = ArrayBuffer[Int]() + shapeGetter(sourceArr, shape, 0) + val finalArr = new Array[Float](shape.product) + arrayCombiner(sourceArr, finalArr, 0, finalArr.length - 1) + array(finalArr, Shape(shape), ctx) + } + + private def shapeGetter(sourceArr : Any, + shape : ArrayBuffer[Int], shapeIdx : Int) : Unit = { + sourceArr match { + case arrFloat : Array[Float] => { + val arrLength = arrFloat.length + if (shape.length == shapeIdx) { + shape += arrLength + } + require(shape(shapeIdx) == arrLength, "Each Array should have equal length") + } + case arr : Array[Any] => { + val arrLength = arr.length + if (shape.length == shapeIdx) { + shape += arrLength + } + require(shape(shapeIdx) == arrLength, + s"Each Array should have equal length, expected ${shape(shapeIdx)}, get $arrLength") + arr.foreach(ele => shapeGetter(ele, shape, shapeIdx + 1)) + } + case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}") + } + } + + private def arrayCombiner(sourceArr : Any, arr : Array[Float], start : Int, end : Int) : Unit = { + sourceArr match { + case arrFloat : Array[Float] => { + for (i <- arrFloat.indices) arr(start + i) = arrFloat(i) + } + case arrAny : Array[Any] => { + val fragment = (end - start + 1) / arrAny.length + for (i <- arrAny.indices) + arrayCombiner(arrAny(i), arr, start + i * fragment, end + (i + 1) * fragment) + } + case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}") + } + } + /** * Returns evenly spaced values within a given interval. * Values are generated within the half-open interval [`start`, `stop`). In other