Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1287] Feat dep #14668

Merged
merged 3 commits into from
Apr 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -395,17 +395,17 @@ private[mxnet] object ExecutorManager {
* @param paramNames Names of all trainable parameters.
* @param ctx List of devices for training (data parallel)
* @param slices Describes how the data parallel splits data into different devices.
* @param providedData training data shapes
* @param providedLabel training label shapes
* @param providedDataDesc training data descriptions
* @param providedLabelDesc training label descriptions
* @param sharedGroup: DataParallelExecutorGroup
* An existing executor group, if to share parameters with it.
*
*/
private class DataParallelExecutorGroup private(sym: Symbol,
argNames: IndexedSeq[String], paramNames: Set[String],
ctx: Array[Context], private val slices: Array[(Int, Int)],
providedData: Map[String, Shape],
providedLabel: Map[String, Shape],
providedDataDesc: IndexedSeq[DataDesc],
providedLabelDesc: IndexedSeq[DataDesc],
sharedGroup: DataParallelExecutorGroup) {
// make sure the architecture is valid
ExecutorManager.checkArguments(sym)
Expand All @@ -417,8 +417,8 @@ private class DataParallelExecutorGroup private(sym: Symbol,
sharedGroup.sharedDataArrays
}

private[mxnet] val dataNames = providedData.map { case (k, _) => k }.toList
private[mxnet] val labelNames = providedLabel.map { case (k, _) => k }.toList
private[mxnet] val dataNames = providedDataDesc.map(_.name).toList
private[mxnet] val labelNames = providedLabelDesc.map(_.name).toList
private[mxnet] val auxNames = sym.listAuxiliaryStates()
private[mxnet] val paramIdx = argNames.zipWithIndex
.filter { case (name, i) => paramNames.contains(name) }
Expand All @@ -428,9 +428,10 @@ private class DataParallelExecutorGroup private(sym: Symbol,
private[mxnet] val trainExecs: Array[Executor] =
ctx.zipWithIndex.map { case (ctxi, i) =>
val dataShapes =
(providedData ++ providedLabel) map { case (name, shape) =>
name -> (Shape(slices(i)._2 - slices(i)._1) ++ shape.slice(1, shape.length))
}
(providedDataDesc ++ providedLabelDesc).map( desc => {
desc.name ->
(Shape(slices(i)._2 - slices(i)._1) ++ desc.shape.slice(1, desc.shape.length))
}).toMap
val sharedExec: Executor = if (sharedGroup == null) null else sharedGroup.trainExecs(i)
ExecutorManager.bindExec(sym, ctxi, dataShapes, paramNamesComb,
needGrad = true, baseExec = sharedExec,
Expand Down Expand Up @@ -479,15 +480,15 @@ private class DataParallelExecutorGroup private(sym: Symbol,
trainData: DataIter,
sharedGroup: DataParallelExecutorGroup) {
this(sym, argNames, paramNames, ctx, slices,
trainData.provideData, trainData.provideLabel, sharedGroup)
trainData.provideDataDesc, trainData.provideLabelDesc, sharedGroup)
}

def this(sym: Symbol,
argNames: IndexedSeq[String], paramNames: Set[String],
ctx: Array[Context], slices: Array[(Int, Int)],
trainData: DataIter) {
this(sym, argNames, paramNames, ctx, slices,
trainData.provideData, trainData.provideLabel, null)
trainData.provideDataDesc, trainData.provideLabelDesc, null)
}

/**
Expand All @@ -509,15 +510,15 @@ private class DataParallelExecutorGroup private(sym: Symbol,
trainData: DataBatch,
sharedGroup: DataParallelExecutorGroup) {
this(sym, argNames, paramNames, ctx, slices,
trainData.provideData, trainData.provideLabel, sharedGroup)
trainData.provideDataDesc, trainData.provideLabelDesc, sharedGroup)
}

def this(sym: Symbol,
argNames: IndexedSeq[String], paramNames: Set[String],
ctx: Array[Context], slices: Array[(Int, Int)],
trainData: DataBatch) {
this(sym, argNames, paramNames, ctx, slices,
trainData.provideData, trainData.provideLabel, null)
trainData.provideDataDesc, trainData.provideLabelDesc, null)
}

// load data and labels into arrays
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ class FeedForward private(
// Initialize weight parameters and auxiliary states
// The NDArrays associated with the _argParms and _auxParams are not disposed instead
// they are passed a outer scope if available.
private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false)
private def initParams(inputShapes: IndexedSeq[DataDesc], overwrite: Boolean = false)
: (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = {
val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes)
val argNames = symbol.listArguments()
val inputNames = inputShapes.keys.toSet
val inputNames = inputShapes.map(_.name).toSet
val paramNames = argNames.filter(!inputNames.contains(_))
val auxNames = symbol.listAuxiliaryStates()

Expand Down Expand Up @@ -179,7 +179,7 @@ class FeedForward private(
}

// Initialize the predictor module for running prediction.
private def initPredictor(inputShapes: Map[String, Shape]): Unit = {
private def initPredictor(inputShapes: IndexedSeq[DataDesc]): Unit = {
var shouldInit = true
if (this.predExec != null) {
val (argShapes, _, _) = symbol.inferShape(inputShapes)
Expand All @@ -193,7 +193,7 @@ class FeedForward private(
}
if(shouldInit) {
// for now only use the first device
val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = inputShapes)
val predExec = symbol.simpleBind(ctx(0), gradReq = "null", inputShapes)
predExec.copyParamsFrom(_argParams, _auxParams)
ExecutorManager.checkArguments(symbol)
this.predExec = predExec
Expand Down Expand Up @@ -233,8 +233,8 @@ class FeedForward private(
*/
def predict(data: DataIter, numBatch: Int = -1): Array[NDArray] = {
data.reset()
val dataShapes = data.provideData
val dataNames = dataShapes.map(_._1).toArray
val dataShapes = data.provideDataDesc
val dataNames = dataShapes.map(_.name).toArray
initPredictor(dataShapes)
val batchSize = data.batchSize
val dataArrays = dataNames.map(predExec.argDict(_))
Expand Down Expand Up @@ -363,7 +363,7 @@ class FeedForward private(
this.symbol = symGen.generate(trainData.defaultBucketKey)
checkArguments()
}
initParams(trainData.provideData ++ trainData.provideLabel)
initParams(trainData.provideDataDesc ++ trainData.provideLabelDesc)
}

private def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric = new Accuracy(),
Expand Down
35 changes: 27 additions & 8 deletions scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,28 +141,46 @@ class DataBatch(val data: IndexedSeq[NDArray],
val pad: Int,
// the key for the bucket that should be used for this batch,
// for bucketing io only
val bucketKey: AnyRef,
val bucketKey: AnyRef = null,
// use DataDesc to indicate the order of data/label loading
// (must match the order of input data/label)
private val providedDataDesc: IndexedSeq[DataDesc],
private val providedLabelDesc: IndexedSeq[DataDesc]) {
private val providedDataDesc: IndexedSeq[DataDesc] = null,
private val providedLabelDesc: IndexedSeq[DataDesc] = null) {
// TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)]
// However, since the data and label can be accessed publicly (no getter and setter)
// the change on this will break BC

@deprecated("Use provideDataDesc and provideDataLabel instead", "1.3.0")
def this(data: IndexedSeq[NDArray],
label: IndexedSeq[NDArray],
index: IndexedSeq[Long],
pad: Int,
// the key for the bucket that should be used for this batch,
// for bucketing io only
bucketKey: AnyRef,
// use ListMap to indicate the order of data/label loading
// (must match the order of input data/label)
providedData: ListMap[String, Shape]) {
this(data, label, index, pad, bucketKey,
DataDesc.ListMap2Descs(providedData))
}

@deprecated("Use provideDataDesc and provideDataLabel instead", "1.3.0")
def this(data: IndexedSeq[NDArray],
label: IndexedSeq[NDArray],
index: IndexedSeq[Long],
pad: Int,
// the key for the bucket that should be used for this batch,
// for bucketing io only
bucketKey: AnyRef = null,
bucketKey: AnyRef,
// use ListMap to indicate the order of data/label loading
// (must match the order of input data/label)
providedData: ListMap[String, Shape] = null,
providedLabel: ListMap[String, Shape] = null) {
providedData: ListMap[String, Shape],
providedLabel: ListMap[String, Shape]) {
this(data, label, index, pad, bucketKey,
DataDesc.ListMap2Descs(providedData), DataDesc.ListMap2Descs(providedLabel))
}

/**
* Dispose its data and labels
* The object shall never be used after it is disposed.
Expand All @@ -177,6 +195,7 @@ class DataBatch(val data: IndexedSeq[NDArray],
}

// The name and shape of data
@deprecated("Use provideDataDesc instead", "1.3.0")
def provideData: ListMap[String, Shape] = {
var temp = ListMap[String, Shape]()
if (providedDataDesc == null) null
Expand All @@ -187,6 +206,7 @@ class DataBatch(val data: IndexedSeq[NDArray],
}

// The name and shape of label
@deprecated("Use provideLabelDesc instead", "1.3.0")
def provideLabel: ListMap[String, Shape] = {
var temp = ListMap[String, Shape]()
if (providedLabelDesc == null) null
Expand Down Expand Up @@ -311,8 +331,7 @@ abstract class DataIter extends Iterator[DataBatch] {
*/
@throws(classOf[NoSuchElementException])
def next(): DataBatch = {
new DataBatch(getData(), getLabel(), getIndex(), getPad(),
null, null, null)
new DataBatch(getData(), getLabel(), getIndex(), getPad())
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.mxnet

import scala.language.implicitConversions

object MX_PRIMITIVES {

/**
Expand Down
55 changes: 53 additions & 2 deletions scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.mxnet.Base._
import org.apache.mxnet.DType.DType
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.language.implicitConversions

Expand Down Expand Up @@ -209,6 +210,33 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso
}
}

/**
* Infer the shape of outputs and arguments of given known shapes of arguments.
* User can either pass in the known shapes in positional way or keyword argument way.
* Tuple of Nones is returned if there is not enough information passed in.
* An error will be raised if there is inconsistency found in the known shapes passed in.
* @param args Provide a list of DataDesc containing the shapes to resolve
* @return
* argShapes List of shapes of arguments. The order is in the same order as list_arguments()
* outShapes List of shapes of outputs. The order is in the same order as list_outputs()
* auxShapes List of shapes of outputs. The order is in the same order as list_auxiliary()
*/
def inferShape(args: IndexedSeq[DataDesc]):
(IndexedSeq[Shape], IndexedSeq[Shape], IndexedSeq[Shape]) = {
val keys = ArrayBuffer.empty[String]
val indPtr = ArrayBuffer(0)
val sdata = ArrayBuffer.empty[Int]
args.foreach { arg =>
val shape = arg.shape
if (shape != null) {
keys += arg.name
sdata ++= shape.toVector
indPtr += sdata.size
}
}
inferShape(keys.toArray, indPtr.toArray, sdata.toArray)
}

/**
* Infer the shape of outputs and arguments of given known shapes of arguments.
* User can either pass in the known shapes in positional way or keyword argument way.
Expand Down Expand Up @@ -389,6 +417,29 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso
checkCall(_LIB.mxSymbolCompose(handle, name, keys, args))
}

/**
* Bind current symbol to get an executor, allocate all the ndarrays needed.
* Allows specifying data types.
* This function will ask user to pass in ndarray of position
* they like to bind to, and it will automatically allocate the ndarray
* for arguments and auxiliary states that user did not specify explicitly.
*
* @param ctx The device context the generated executor to run on.
* @param gradReq {'write', 'add', 'null'}, or list of str or dict of str to str, optional
* Specifies how we should update the gradient to the args_grad.
* - 'write' means everytime gradient is write to specified args_grad NDArray.
* - 'add' means everytime gradient is add to the specified NDArray.
* - 'null' means no action is taken, the gradient may not be calculated.
* @param dataDesc List of dataDescriptors
* @return The generated Executor
*/
def simpleBind(ctx: Context, gradReq: String,
descs: IndexedSeq[DataDesc]) : Executor = {
val (shapes, types) = descs.map(desc =>
( desc.name -> desc.shape, desc.name -> desc.dtype )).unzip
simpleBind(ctx, gradReq, shapes.toMap, types.toMap)
}

/**
* Bind current symbol to get an executor, allocate all the ndarrays needed.
* Allows specifying data types.
Expand Down Expand Up @@ -1189,7 +1240,7 @@ object Symbol extends SymbolBase {

// a more friendly interface for creating symbols
// all values except symbols in kwargs will be cast to String using its toString() method
@Deprecated
@deprecated("Use Checked version", "0.1.2")
def createFromNamedSymbolsNoCheck(
operator: String, name: String = null, attr: Map[String, String] = null)(
kwargs: Map[String, Any]): Symbol = {
Expand All @@ -1208,7 +1259,7 @@ object Symbol extends SymbolBase {

// a more friendly interface for creating symbols
// all values except symbols in kwargs will be cast to String using its toString() method
@Deprecated
@deprecated("Use Checked version", "0.1.2")
def createFromListedSymbolsNoCheck(
operator: String, name: String = null, attr: Map[String, String] = null)(
symbols: Array[Symbol], kwargs: Map[String, Any] = null): Symbol = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
checkCall(_LIB.mxDataIterNext(handle, next))
if (next.value > 0) {
currentBatch = new DataBatch(data = getData(), label = getLabel(),
index = getIndex(), pad = getPad(),
null, null, null)
index = getIndex(), pad = getPad())
} else {
currentBatch = null
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
override def next(): DataBatch = {
if (hasNext) {
cursor += dataBatchSize
new DataBatch(getData(), getLabel(), getIndex(), getPad(),
null, null, null)
new DataBatch(getData(), getLabel(), getIndex(), getPad())
} else {
throw new NoSuchElementException
}
Expand Down
Loading