From e87e9139c2f30b69c6d36a3d686a405e7e7ded62 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Tue, 9 Apr 2019 14:14:46 -0700 Subject: [PATCH] Address feature and deprecation Warnings (mostly provideData) --- .../org/apache/mxnet/ExecutorManager.scala | 27 +++--- .../scala/org/apache/mxnet/FeedForward.scala | 14 +-- .../src/main/scala/org/apache/mxnet/IO.scala | 35 ++++++-- .../main/scala/org/apache/mxnet/Symbol.scala | 51 +++++++++++ .../org/apache/mxnet/io/MXDataIter.scala | 3 +- .../org/apache/mxnet/io/NDArrayIter.scala | 3 +- .../org/apache/mxnet/io/PrefetchingIter.scala | 87 ++++++------------- .../org/apache/mxnet/module/BaseModule.scala | 3 +- .../apache/mxnet/module/BucketingModule.scala | 6 +- .../org/apache/mxnet/module/Module.scala | 4 +- .../mxnet/module/SequentialModule.scala | 16 ++-- .../test/scala/org/apache/mxnet/IOSuite.scala | 28 +++--- .../scala/org/apache/mxnet/ModuleSuite.scala | 4 +- .../customop/ExampleCustomOp.scala | 2 +- .../customop/ExampleCustomOpWithRtc.scala | 2 +- .../datasets/SyntheticDataIter.scala | 2 +- .../mxnetexamples/module/MnistMlp.scala | 2 +- .../module/SequentialModuleEx.scala | 2 +- .../multitask/ExampleMultiTask.scala | 6 +- .../mxnetexamples/rnn/TrainCharRnn.scala | 2 +- .../mxnet/spark/io/LongLivingDataBatch.scala | 3 +- 21 files changed, 170 insertions(+), 132 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala index b13741bdd3b0..d94b8fb01ed6 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala @@ -395,8 +395,8 @@ private[mxnet] object ExecutorManager { * @param paramNames Names of all trainable parameters. * @param ctx List of devices for training (data parallel) * @param slices Describes how the data parallel splits data into different devices. 
- * @param providedData training data shapes - * @param providedLabel training label shapes + * @param providedDataDesc training data descriptions + * @param providedLabelDesc training label descriptions * @param sharedGroup: DataParallelExecutorGroup * An existing executor group, if to share parameters with it. * @@ -404,8 +404,8 @@ private[mxnet] object ExecutorManager { private class DataParallelExecutorGroup private(sym: Symbol, argNames: IndexedSeq[String], paramNames: Set[String], ctx: Array[Context], private val slices: Array[(Int, Int)], - providedData: Map[String, Shape], - providedLabel: Map[String, Shape], + providedDataDesc: IndexedSeq[DataDesc], + providedLabelDesc: IndexedSeq[DataDesc], sharedGroup: DataParallelExecutorGroup) { // make sure the architecture is valid ExecutorManager.checkArguments(sym) @@ -417,8 +417,8 @@ private class DataParallelExecutorGroup private(sym: Symbol, sharedGroup.sharedDataArrays } - private[mxnet] val dataNames = providedData.map { case (k, _) => k }.toList - private[mxnet] val labelNames = providedLabel.map { case (k, _) => k }.toList + private[mxnet] val dataNames = providedDataDesc.map(_.name).toList + private[mxnet] val labelNames = providedLabelDesc.map(_.name).toList private[mxnet] val auxNames = sym.listAuxiliaryStates() private[mxnet] val paramIdx = argNames.zipWithIndex .filter { case (name, i) => paramNames.contains(name) } @@ -428,9 +428,10 @@ private class DataParallelExecutorGroup private(sym: Symbol, private[mxnet] val trainExecs: Array[Executor] = ctx.zipWithIndex.map { case (ctxi, i) => val dataShapes = - (providedData ++ providedLabel) map { case (name, shape) => - name -> (Shape(slices(i)._2 - slices(i)._1) ++ shape.slice(1, shape.length)) - } + (providedDataDesc ++ providedLabelDesc).map( desc => { + desc.name -> + (Shape(slices(i)._2 - slices(i)._1) ++ desc.shape.slice(1, desc.shape.length)) + }).toMap val sharedExec: Executor = if (sharedGroup == null) null else sharedGroup.trainExecs(i) 
ExecutorManager.bindExec(sym, ctxi, dataShapes, paramNamesComb, needGrad = true, baseExec = sharedExec, @@ -479,7 +480,7 @@ private class DataParallelExecutorGroup private(sym: Symbol, trainData: DataIter, sharedGroup: DataParallelExecutorGroup) { this(sym, argNames, paramNames, ctx, slices, - trainData.provideData, trainData.provideLabel, sharedGroup) + trainData.provideDataDesc, trainData.provideLabelDesc, sharedGroup) } def this(sym: Symbol, @@ -487,7 +488,7 @@ private class DataParallelExecutorGroup private(sym: Symbol, ctx: Array[Context], slices: Array[(Int, Int)], trainData: DataIter) { this(sym, argNames, paramNames, ctx, slices, - trainData.provideData, trainData.provideLabel, null) + trainData.provideDataDesc, trainData.provideLabelDesc, null) } /** @@ -509,7 +510,7 @@ private class DataParallelExecutorGroup private(sym: Symbol, trainData: DataBatch, sharedGroup: DataParallelExecutorGroup) { this(sym, argNames, paramNames, ctx, slices, - trainData.provideData, trainData.provideLabel, sharedGroup) + trainData.provideDataDesc, trainData.provideLabelDesc, sharedGroup) } def this(sym: Symbol, @@ -517,7 +518,7 @@ private class DataParallelExecutorGroup private(sym: Symbol, ctx: Array[Context], slices: Array[(Int, Int)], trainData: DataBatch) { this(sym, argNames, paramNames, ctx, slices, - trainData.provideData, trainData.provideLabel, null) + trainData.provideDataDesc, trainData.provideLabelDesc, null) } // load data and labels into arrays diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala index 2b1765531824..b8e2ba0b39c8 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala @@ -129,11 +129,11 @@ class FeedForward private( // Initialize weight parameters and auxiliary states // The NDArrays associated with the _argParms and _auxParams are not disposed 
instead // they are passed a outer scope if available. - private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false) + private def initParams(inputShapes: IndexedSeq[DataDesc], overwrite: Boolean = false) : (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = { val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes) val argNames = symbol.listArguments() - val inputNames = inputShapes.keys.toSet + val inputNames = inputShapes.map(_.name).toSet val paramNames = argNames.filter(!inputNames.contains(_)) val auxNames = symbol.listAuxiliaryStates() @@ -179,7 +179,7 @@ class FeedForward private( } // Initialize the predictor module for running prediction. - private def initPredictor(inputShapes: Map[String, Shape]): Unit = { + private def initPredictor(inputShapes: IndexedSeq[DataDesc]): Unit = { var shouldInit = true if (this.predExec != null) { val (argShapes, _, _) = symbol.inferShape(inputShapes) @@ -193,7 +193,7 @@ class FeedForward private( } if(shouldInit) { // for now only use the first device - val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = inputShapes) + val predExec = symbol.simpleBind(ctx(0), gradReq = "null", inputShapes) predExec.copyParamsFrom(_argParams, _auxParams) ExecutorManager.checkArguments(symbol) this.predExec = predExec @@ -233,8 +233,8 @@ class FeedForward private( */ def predict(data: DataIter, numBatch: Int = -1): Array[NDArray] = { data.reset() - val dataShapes = data.provideData - val dataNames = dataShapes.map(_._1).toArray + val dataShapes = data.provideDataDesc + val dataNames = dataShapes.map(_.name).toArray initPredictor(dataShapes) val batchSize = data.batchSize val dataArrays = dataNames.map(predExec.argDict(_)) @@ -363,7 +363,7 @@ class FeedForward private( this.symbol = symGen.generate(trainData.defaultBucketKey) checkArguments() } - initParams(trainData.provideData ++ trainData.provideLabel) + initParams(trainData.provideDataDesc ++ trainData.provideLabelDesc) } 
private def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric = new Accuracy(), diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index b580ad10a04e..1db6d2a6e953 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -141,28 +141,46 @@ class DataBatch(val data: IndexedSeq[NDArray], val pad: Int, // the key for the bucket that should be used for this batch, // for bucketing io only - val bucketKey: AnyRef, + val bucketKey: AnyRef = null, // use DataDesc to indicate the order of data/label loading // (must match the order of input data/label) - private val providedDataDesc: IndexedSeq[DataDesc], - private val providedLabelDesc: IndexedSeq[DataDesc]) { + private val providedDataDesc: IndexedSeq[DataDesc] = null, + private val providedLabelDesc: IndexedSeq[DataDesc] = null) { // TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)] // However, since the data and label can be accessed publicly (no getter and setter) // the change on this will break BC + + @deprecated("Use provideDataDesc and provideLabelDesc instead", "1.3.0") + def this(data: IndexedSeq[NDArray], + label: IndexedSeq[NDArray], + index: IndexedSeq[Long], + pad: Int, + // the key for the bucket that should be used for this batch, + // for bucketing io only + bucketKey: AnyRef, + // use ListMap to indicate the order of data/label loading + // (must match the order of input data/label) + providedData: ListMap[String, Shape]) { + this(data, label, index, pad, bucketKey, + DataDesc.ListMap2Descs(providedData)) + } + + @deprecated("Use provideDataDesc and provideLabelDesc instead", "1.3.0") def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray], index: IndexedSeq[Long], pad: Int, // the key for the bucket that should be used for this batch, // for bucketing io only - bucketKey: AnyRef = null, + 
bucketKey: AnyRef, // use ListMap to indicate the order of data/label loading // (must match the order of input data/label) - providedData: ListMap[String, Shape] = null, - providedLabel: ListMap[String, Shape] = null) { + providedData: ListMap[String, Shape], + providedLabel: ListMap[String, Shape]) { this(data, label, index, pad, bucketKey, DataDesc.ListMap2Descs(providedData), DataDesc.ListMap2Descs(providedLabel)) } + /** * Dispose its data and labels * The object shall never be used after it is disposed. @@ -177,6 +195,7 @@ class DataBatch(val data: IndexedSeq[NDArray], } // The name and shape of data + @deprecated("Use provideDataDesc instead", "1.3.0") def provideData: ListMap[String, Shape] = { var temp = ListMap[String, Shape]() if (providedDataDesc == null) null @@ -187,6 +206,7 @@ class DataBatch(val data: IndexedSeq[NDArray], } // The name and shape of label + @deprecated("Use provideLabelDesc instead", "1.3.0") def provideLabel: ListMap[String, Shape] = { var temp = ListMap[String, Shape]() if (providedLabelDesc == null) null @@ -311,8 +331,7 @@ abstract class DataIter extends Iterator[DataBatch] { */ @throws(classOf[NoSuchElementException]) def next(): DataBatch = { - new DataBatch(getData(), getLabel(), getIndex(), getPad(), - null, null, null) + new DataBatch(getData(), getLabel(), getIndex(), getPad()) } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 52f715285b35..296734d1a4ff 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -21,6 +21,7 @@ import org.apache.mxnet.Base._ import org.apache.mxnet.DType.DType import org.slf4j.{Logger, LoggerFactory} +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.language.implicitConversions @@ -209,6 +210,33 @@ class Symbol private(private[mxnet] val handle: 
SymbolHandle) extends NativeReso } } + /** + * Infer the shape of outputs and arguments of given known shapes of arguments. + * Known shapes are taken from each DataDesc's name and shape; null shapes are skipped. + * Tuple of Nones is returned if there is not enough information passed in. + * An error will be raised if there is inconsistency found in the known shapes passed in. + * @param args Provide a list of DataDesc containing the shapes to resolve + * @return + * argShapes List of shapes of arguments. The order is in the same order as list_arguments() + * outShapes List of shapes of outputs. The order is in the same order as list_outputs() + * auxShapes List of shapes of auxiliary states. The order is in the same order as list_auxiliary() + */ + def inferShape(args: IndexedSeq[DataDesc]): + (IndexedSeq[Shape], IndexedSeq[Shape], IndexedSeq[Shape]) = { + val keys = ArrayBuffer.empty[String] + val indPtr = ArrayBuffer(0) + val sdata = ArrayBuffer.empty[Int] + args.foreach { arg => + val shape = arg.shape + if (shape != null) { + keys += arg.name + sdata ++= shape.toVector + indPtr += sdata.size + } + } + inferShape(keys.toArray, indPtr.toArray, sdata.toArray) + } + /** * Infer the shape of outputs and arguments of given known shapes of arguments. * User can either pass in the known shapes in positional way or keyword argument way. @@ -389,6 +417,29 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso checkCall(_LIB.mxSymbolCompose(handle, name, keys, args)) } + /** + * Bind current symbol to get an executor, allocate all the ndarrays needed. + * Allows specifying data types. + * This function will ask user to pass in ndarray of position + * they like to bind to, and it will automatically allocate the ndarray + * for arguments and auxiliary states that user did not specify explicitly. + * + * @param ctx The device context the generated executor to run on. 
+ * @param gradReq {'write', 'add', 'null'}, or list of str or dict of str to str, optional + * Specifies how we should update the gradient to the args_grad. + * - 'write' means everytime gradient is written to specified args_grad NDArray. + * - 'add' means everytime gradient is added to the specified NDArray. + * - 'null' means no action is taken, the gradient may not be calculated. + * @param descs List of data descriptors (name, shape and dtype of each input) + * @return The generated Executor + */ + def simpleBind(ctx: Context, gradReq: String, + descs: IndexedSeq[DataDesc]) : Executor = { + val (shapes, types) = descs.map(desc => + ( desc.name -> desc.shape, desc.name -> desc.dtype )).unzip + simpleBind(ctx, gradReq, shapes.toMap, types.toMap) + } + /** * Bind current symbol to get an executor, allocate all the ndarrays needed. * Allows specifying data types. diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index e30098c3088b..66b7d83cedc8 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -107,8 +107,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, checkCall(_LIB.mxDataIterNext(handle, next)) if (next.value > 0) { currentBatch = new DataBatch(data = getData(), label = getLabel(), - index = getIndex(), pad = getPad(), - null, null, null) + index = getIndex(), pad = getPad()) } else { currentBatch = null } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index b205bbe47abb..e9513257c050 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -161,8 +161,7 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)], override def next(): 
DataBatch = { if (hasNext) { cursor += dataBatchSize - new DataBatch(getData(), getLabel(), getIndex(), getPad(), - null, null, null) + new DataBatch(getData(), getLabel(), getIndex(), getPad()) } else { throw new NoSuchElementException } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala index d277351b124b..9cfcd598197c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala @@ -42,71 +42,51 @@ class PrefetchingIter( require(iters.nonEmpty, "Iters length must be greater than 0") - private val _provideData: ListMap[String, Shape] = { + @deprecated("Please use provideDataDesc instead", "1.3.0") + override def provideData: ListMap[String, Shape] = { if (dataNames == null) { - iters.map(_.provideData).foldLeft(ListMap[String, Shape]()) { (acc, elem) => - acc ++ elem - } + iters.map(_.provideData).reduce(_ ++ _) } else { - iters.zipWithIndex.map(tu => (tu._1.provideData, tu._2)) - .map(m => m._1.map(t => (dataNames(m._2)(t._1), t._2))) - .foldLeft(ListMap[String, Shape]()) { (acc, elem) => - acc ++ elem - } + iters.map(_.provideData).zip(dataNames).map { case (providedData, names) => + providedData.map { case (oldName, shape) => names(oldName) -> shape } + }.reduceLeft(_ ++ _) } } - private val _provideLabel: ListMap[String, Shape] = { + @deprecated("Please use provideLabelDesc instead", "1.3.0") + override def provideLabel: ListMap[String, Shape] = { if (labelNames == null) { - iters.map(_.provideLabel).foldLeft(ListMap[String, Shape]()) { (acc, elem) => - acc ++ elem - } + iters.map(_.provideLabel).reduce(_ ++ _) } else { - iters.zipWithIndex.map(tu => (tu._1.provideLabel, tu._2)) - .map(m => m._1.map(t => (labelNames(m._2)(t._1), t._2))) - .foldLeft(ListMap[String, Shape]()) { (acc, elem) => - acc ++ elem - } + 
iters.map(_.provideLabel).zip(labelNames).map { case (providedLabel, names) => + providedLabel.map { case (oldName, shape) => names(oldName) -> shape } + }.reduceLeft(_ ++ _) } } - private val _provideDataDesc: IndexedSeq[DataDesc] = { + override def provideDataDesc: IndexedSeq[DataDesc] = { if (dataNames == null) { - iters.map(_.provideDataDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => - acc ++ elem - } + iters.flatMap(_.provideDataDesc) } else { - iters.zipWithIndex.map(tu => (tu._1.provideDataDesc, tu._2)) - .map(m => - m._1.map(t => - new DataDesc(dataNames(m._2)(t.name), t.shape, t.dtype, t.layout) - ) - ) - .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => - acc ++ elem - } + iters.map(_.provideDataDesc).zip(dataNames).flatMap { case (providedDataDesc, names) => + providedDataDesc.map(desc => + new DataDesc(names(desc.name), desc.shape, desc.dtype, desc.layout)) + } } } - private val _provideLabelDesc: IndexedSeq[DataDesc] = { + override def provideLabelDesc: IndexedSeq[DataDesc] = { if (labelNames == null) { - iters.map(_.provideLabelDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => - acc ++ elem - } + iters.flatMap(_.provideLabelDesc) } else { - iters.zipWithIndex.map(tu => (tu._1.provideLabelDesc, tu._2)) - .map(m => - m._1.map(t => - new DataDesc(labelNames(m._2)(t.name), t.shape, t.dtype, t.layout) - ) - ) - .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => - acc ++ elem - } + iters.map(_.provideLabelDesc).zip(labelNames).flatMap { case (providedLabelDesc, names) => + providedLabelDesc.map(desc => + new DataDesc(names(desc.name), desc.shape, desc.dtype, desc.layout)) + } } } - private val _batchSize: Int = this._provideData.toList(0)._2(0) + private val _batchSize: Int = this.provideDataDesc.head.shape(0) private val dataReady: IndexedSeq[Semaphore] = (0 until iters.length).map(i => new Semaphore(0)) private val dataTaken: IndexedSeq[Semaphore] = @@ -177,20 +157,6 @@ class PrefetchingIter( */ override def getPad(): Int = 
this.currentBatch.pad - // The name and shape of label provided by this iterator - @deprecated("Please use provideDataDesc instead", "1.3.0") - override def provideLabel: ListMap[String, Shape] = this._provideLabel - - // The name and shape of data provided by this iterator - @deprecated("Please use provideLabelDesc instead", "1.3.0") - override def provideData: ListMap[String, Shape] = this._provideData - - // Provide type:DataDesc of the data - override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc - - // Provide type:DataDesc of the label - override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc - override def hasNext: Boolean = { for (e <- dataReady) e.acquire() if (nextBatch(0) == null) { @@ -209,8 +175,7 @@ class PrefetchingIter( currentBatch = new DataBatch(datas.toIndexedSeq.flatten, labels.toIndexedSeq.flatten, nextBatch(0).index, - nextBatch(0).pad, - null, null, null) + nextBatch(0).pad) for (e <- dataTaken) e.release() true } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala index b73f4ad4b112..3be8e060fd6f 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala @@ -398,8 +398,7 @@ abstract class BaseModule { fitParams: FitParams = new FitParams): Unit = { require(fitParams != null, "Undefined fitParams") require(numEpoch > 0, s"Invalid number of epochs $numEpoch") - import org.apache.mxnet.DataDesc._ - bind(dataShapes = trainData.provideData, labelShapes = Option(trainData.provideLabel), + bind(dataShapes = trainData.provideDataDesc, labelShapes = Option(trainData.provideLabelDesc), forTraining = true, forceRebind = fitParams.forceRebind) fitParams.monitor.foreach(installMonitor) initParams(fitParams.initializer, argParams, auxParams, diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala index 41a6f69394d2..d5c8c21ea106 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala @@ -296,7 +296,7 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[ require(this.binded && this.paramsInitialized, "bind() and initParams() must be called first.") val bucketKey = dataBatch.bucketKey val originalBucketKey = this._currBucketKey - this.switchBucket(bucketKey, dataBatch.provideData, Option(dataBatch.provideLabel)) + this.switchBucket(bucketKey, dataBatch.provideDataDesc, Option(dataBatch.provideLabelDesc)) // switch back this.switchBucket(originalBucketKey, null, None) } @@ -308,8 +308,8 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[ */ override def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit = { require(binded && paramsInitialized, "bind() and initParams() must be called first.") - this.switchBucket(dataBatch.bucketKey, dataBatch.provideData, - Option(dataBatch.provideLabel)) + this.switchBucket(dataBatch.bucketKey, dataBatch.provideDataDesc, + Option(dataBatch.provideLabelDesc)) this._currModule.forward(dataBatch, isTrain) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala index 3255d9346b80..9928f66b2200 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala @@ -435,14 +435,14 @@ class Module(symbolVar: Symbol, val newDataShapes = dataBatch.data.map(_.shape) if (currDataShapes != newDataShapes) { val newDShapes: IndexedSeq[DataDesc] = - if (dataBatch.provideData != null) 
dataBatch.provideData + if (dataBatch.provideDataDesc != null) dataBatch.provideDataDesc else { this.dataShapes.zip(newDataShapes).map { case (i, shape) => DataDesc(i.name, shape, i.dtype, i.layout) } } val newLShapes: Option[IndexedSeq[DataDesc]] = - if (dataBatch.provideLabel != null) Some(dataBatch.provideLabel) + if (dataBatch.provideLabelDesc != null) Some(dataBatch.provideLabelDesc) else if (dataBatch.label != null && dataBatch.label.length > 0 && this.labelShapes != null) { Some(this.labelShapes.zip(dataBatch.label).map { case (i, j) => diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala index 3c3eeb97f201..d80e6bc6279b 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala @@ -111,6 +111,16 @@ class SequentialModule extends BaseModule { this.labelShapesVar.orNull } + /** + * Get the data descriptors of the last module. + * @return The dataShapes of the last module in the chain. + * NOTE(review): this is not derived from outputShapes - confirm intent. + */ + def outputDesc: IndexedSeq[DataDesc] = { + require(this.binded, "bind() must be called first.") + this.modules.reverse.head.dataShapes + } + /** * Get output shapes. 
* @return The output shapes of the last @@ -306,12 +316,8 @@ class SequentialModule extends BaseModule { val dataNames = module.outputShapes.map(_._1) require(dataNames.length == data.data.length, s"dataNames $dataNames do not match with number of arrays in batch") - var provideData = ListMap[String, Shape]() - for ((name, x) <- dataNames.zip(out.head)) { - provideData += name -> x.shape - } data = new DataBatch(out.head, data.label, data.index, - data.pad, data.bucketKey, provideData, data.provideLabel) + data.pad, data.bucketKey, outputDesc, data.provideLabelDesc) } } } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 6471250ed786..9839f09e4063 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -54,10 +54,10 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // test DataIter val mnistIter = mnistPack.iterator // test provideData - val provideData = mnistIter.provideData - val provideLabel = mnistIter.provideLabel - assert(provideData("data") === Shape(100, 784)) - assert(provideLabel("label") === Shape(100)) + val provideData = mnistIter.provideDataDesc + val provideLabel = mnistIter.provideLabelDesc + assert(provideData.find(_.name == "data").get.shape === Shape(100, 784)) + assert(provideLabel.find(_.name == "label").get.shape === Shape(100)) // test_loop mnistIter.reset() batchCount = 0 @@ -106,10 +106,10 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { val nBatch = 500 var batchCount = 0 // test provideData - val provideData = imgRecIter.provideData - val provideLabel = imgRecIter.provideLabel - assert(provideData("data").toArray === Array(100, 3, 28, 28)) - assert(provideLabel("label").toArray === Array(100)) + val provideData = imgRecIter.provideDataDesc + val provideLabel = imgRecIter.provideLabelDesc + assert(provideData.find(_.name == 
"data").get.shape.toArray === Array(100, 3, 28, 28)) + assert(provideLabel.find(_.name == "label").get.shape.toArray === Array(100)) imgRecIter.reset() while (imgRecIter.hasNext) { @@ -209,12 +209,12 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(nBatch === batchCount) // test provideData - val provideData = prefetchIter.provideData - val provideLabel = prefetchIter.provideLabel - assert(provideData("data1") === Shape(100, 784)) - assert(provideData("data2") === Shape(100, 784)) - assert(provideLabel("label1") === Shape(100)) - assert(provideLabel("label2") === Shape(100)) + val provideData = prefetchIter.provideDataDesc + val provideLabel = prefetchIter.provideLabelDesc + assert(provideData.find(_.name == "data1").get.shape === Shape(100, 784)) + assert(provideData.find(_.name == "data2").get.shape === Shape(100, 784)) + assert(provideLabel.find(_.name == "label1").get.shape === Shape(100)) + assert(provideLabel.find(_.name == "label2").get.shape === Shape(100)) // test reset prefetchIter.reset() diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala index 88e314e2a72c..3e753a18d247 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala @@ -253,8 +253,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // create module val mod = new Module(x, contexts = Array(Context.cpu())) - mod.bind(dataShapes = trainData.provideData, - Option(trainData.provideLabel)) + mod.bind(dataShapes = trainData.provideDataDesc, + Option(trainData.provideLabelDesc)) mod.installMonitor(mon) val argParams = Map( "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 2)), diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala index df79f5b63769..0cfcc49aee04 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala @@ -107,7 +107,7 @@ object ExampleCustomOp { val (trainIter, testIter) = Data.mnistIterator(dataPath, batchSize = 100, inputShape = Shape(784)) - val datasAndLabels = trainIter.provideData ++ trainIter.provideLabel + val datasAndLabels = trainIter.provideDataDesc ++ trainIter.provideLabelDesc val (argShapes, outputShapes, auxShapes) = mlp.inferShape(datasAndLabels) val initializer = new Xavier(factorType = "in", magnitude = 2.34f) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOpWithRtc.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOpWithRtc.scala index c3ac347353df..7b0fb349373d 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOpWithRtc.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOpWithRtc.scala @@ -128,7 +128,7 @@ object ExampleCustomOpWithRtc { val (trainIter, testIter) = Data.mnistIterator(dataPath, batchSize = 100, inputShape = Shape(784)) - val datasAndLabels = trainIter.provideData ++ trainIter.provideLabel + val datasAndLabels = trainIter.provideDataDesc ++ trainIter.provideLabelDesc val (argShapes, outputShapes, auxShapes) = mlp.inferShape(datasAndLabels) val initializer = new Xavier(factorType = "in", magnitude = 2.34f) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala index e4d3b2ae7c3e..4d22b62bea81 100644 --- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala @@ -54,7 +54,7 @@ class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape: List[In override def next(): DataBatch = { if (hasNext) { curIter += batchSize - new DataBatch(data, label, getIndex, getPad, null, null, null) + new DataBatch(data, label, getIndex, getPad) } else { throw new NoSuchElementException } diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/MnistMlp.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/MnistMlp.scala index 4d450c60456b..839f6ac85902 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/MnistMlp.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/MnistMlp.scala @@ -49,7 +49,7 @@ object MnistMlp { logger.info("Load checkpoint from epoch {}", loadModelEpoch) Module.loadCheckpoint("model/mnist_mlp", loadModelEpoch, loadOptimizerStates = true) } - mod.bind(dataShapes = train.provideData, labelShapes = Some(train.provideLabel)) + mod.bind(dataShapes = train.provideDataDesc, labelShapes = Some(train.provideLabelDesc)) mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 0.01f, momentum = 0.9f)) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/SequentialModuleEx.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/SequentialModuleEx.scala index ff616a57b1aa..ea2273ebd796 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/SequentialModuleEx.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/SequentialModuleEx.scala @@ -57,7 +57,7 @@ object SequentialModuleEx { cmdLine: SequentialModuleEx): Unit = { // Intermediate-level API val modSeq = 
getSeqModule() - modSeq.bind(dataShapes = train.provideData, labelShapes = Some(train.provideLabel)) + modSeq.bind(dataShapes = train.provideDataDesc, labelShapes = Some(train.provideLabelDesc)) if (cmdLine.loadModelPath != null) { logger.info(s"Load checkpoint from ${cmdLine.loadModelPath}") modSeq.loadParams(cmdLine.loadModelPath) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala index 766fafda940c..5c17a3747ab6 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala @@ -66,7 +66,7 @@ object ExampleMultiTask { new DataBatch(batch.data, IndexedSeq(label, label), batch.index, - batch.pad, null, null, null) + batch.pad) } else { throw new NoSuchElementException } @@ -230,10 +230,10 @@ object ExampleMultiTask { val trainMultiIt = new MultiMnistIterator(trainIter) val valMultiIter = new MultiMnistIterator(valIter) - val datasAndLabels = trainMultiIt.provideData ++ trainMultiIt.provideLabel + val datasAndLabels = trainMultiIt.provideDataDesc ++ trainMultiIt.provideLabelDesc val (argShapes, outputShapes, auxShapes) - = network.inferShape(trainMultiIt.provideData("data")) + = network.inferShape(trainMultiIt.provideDataDesc.filter(_.name == "data")) val initializer = new Xavier(factorType = "in", magnitude = 2.34f) val argNames = network.listArguments diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala index c90b7637b9b1..32a3cfcd45d6 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala @@ -75,7 
+75,7 @@ object TrainCharRnn { // the network symbol val symbol = symGen(buckets(0)) - val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel + val datasAndLabels = dataTrain.provideDataDesc ++ dataTrain.provideLabelDesc val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels) val initializer = new Xavier(factorType = "in", magnitude = 2.34f) diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala index e3272a4066b5..abf82f6e510c 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala @@ -28,8 +28,7 @@ class LongLivingDataBatch( override val data: IndexedSeq[NDArray], override val label: IndexedSeq[NDArray], override val index: IndexedSeq[Long], - override val pad: Int) extends DataBatch(data, label, index, pad, - null, null, null) { + override val pad: Int) extends DataBatch(data, label, index, pad) { override def dispose(): Unit = {} def disposeForce(): Unit = super.dispose() }