Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jan 10, 2019
1 parent cfca296 commit 3e96e35
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 84 deletions.
61 changes: 54 additions & 7 deletions scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.language.implicitConversions
import scala.ref.WeakReference
import scala.util.Try

/**
* NDArray Object extends from NDArrayBase for abstract function signatures
Expand Down Expand Up @@ -718,7 +719,6 @@ object NDArray extends NDArrayBase {
genericNDArrayFunctionInvoke("_crop_assign", args, kwargs)
}

// TODO: imdecode
}

/**
Expand All @@ -745,6 +745,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
// we use weak reference to prevent gc blocking
private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]

private val traceProperty = "mxnet.setNDArrayPrintLength"
private lazy val printLength = {
val value = Try(System.getProperty(traceProperty).toInt).getOrElse(1000)
value
}

def serialize(): Array[Byte] = {
val buf = ArrayBuffer.empty[Byte]
checkCall(_LIB.mxNDArraySaveRawBytes(handle, buf))
Expand Down Expand Up @@ -808,13 +814,54 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
checkCall(_LIB.mxNDArraySyncCopyFromCPU(handle, source, source.length))
}

private def syncCopyfrom(source: Array[Double]): Unit = {
require(source.length == size,
s"array size (${source.length}) do not match the size of NDArray ($size)")
checkCall(_LIB.mxFloat64NDArraySyncCopyFromCPU(handle, source, source.length))
/**
* Visualize the internal structure of NDArray
* @return String that show the structure
*/
override def toString: String = {
val abstractND = buildStringHelper(this, this.shape.length)
val otherInfo = s"<NDArray ${this.shape} ${this.context} ${this.dtype}>"
s"$abstractND\n$otherInfo"
}

override def toString() : String = {
s"<NDArray ${this.shape} ${this.context}>"
/**
* Helper function to create formatted NDArray output
* The NDArray will be represented in a reduced version if too large
* @param nd NDArray as the input
* @param totalSpace totalSpace of the lowest dimension
* @return String format of NDArray
*/
private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = {
var result = ""
val THRESHOLD = 10 // longest NDArray[NDArray[...]] to show in full
val ARRAYTHRESHOLD = printLength // longest array to show in full
val shape = nd.shape
val space = totalSpace - shape.length
if (shape.length != 1) {
val (length, postfix) =
if (shape(0) > THRESHOLD) {
// reduced NDArray
(10, s"\n${" " * (space + 1)}... with length ${shape(0)}\n")
} else {
(shape(0), "")
}
for (num <- 0 until length) {
val output = buildStringHelper(nd.at(num), totalSpace)
result += s"$output\n"
}
result = s"${" " * space}[\n$result${" " * space}$postfix${" " * space}]"
} else {
if (shape(0) > ARRAYTHRESHOLD) {
// reduced Array
val front = nd.slice(0, 10)
val back = nd.slice(shape(0) - 10, shape(0) - 1)
result = s"""${" " * space}[${front.toArray.mkString(",")}
| ... ${back.toArray.mkString(",")}]""".stripMargin
} else {
result = s"${" " * space}[${nd.toArray.mkString(",")}]"
}
}
result
}

/**
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ import java.io.File
import java.util.concurrent.atomic.AtomicInteger

import org.apache.mxnet.NDArrayConversions._
import org.apache.mxnet.util.Visualize
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
import org.slf4j.LoggerFactory
import scala.collection.mutable.ArrayBuffer

class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
private val sequence: AtomicInteger = new AtomicInteger(0)

private val logger = LoggerFactory.getLogger(classOf[NDArraySuite])

test("to java array") {
val ndarray = NDArray.zeros(2, 2)
assert(ndarray.toArray === Array(0f, 0f, 0f, 0f))
Expand Down Expand Up @@ -106,10 +108,10 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
}

test("test Visualize") {
var nd = NDArray.ones(Shape(1, 2, 100, 1))
Visualize.toString(nd)
var nd = NDArray.ones(Shape(1, 2, 1000, 1))
logger.info(s"Test print large ndarray:\n$nd")
nd = NDArray.ones(Shape(1, 4))
Visualize.toString(nd)
logger.info(s"Test print small ndarray:\n$nd")
}

test("plus") {
Expand Down

0 comments on commit 3e96e35

Please sign in to comment.