diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala index b13741bdd3b0..d94b8fb01ed6 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala @@ -395,8 +395,8 @@ private[mxnet] object ExecutorManager { * @param paramNames Names of all trainable parameters. * @param ctx List of devices for training (data parallel) * @param slices Describes how the data parallel splits data into different devices. - * @param providedData training data shapes - * @param providedLabel training label shapes + * @param providedDataDesc training data descriptions + * @param providedLabelDesc training label descriptions * @param sharedGroup: DataParallelExecutorGroup * An existing executor group, if to share parameters with it. * @@ -404,8 +404,8 @@ private[mxnet] object ExecutorManager { private class DataParallelExecutorGroup private(sym: Symbol, argNames: IndexedSeq[String], paramNames: Set[String], ctx: Array[Context], private val slices: Array[(Int, Int)], - providedData: Map[String, Shape], - providedLabel: Map[String, Shape], + providedDataDesc: IndexedSeq[DataDesc], + providedLabelDesc: IndexedSeq[DataDesc], sharedGroup: DataParallelExecutorGroup) { // make sure the architecture is valid ExecutorManager.checkArguments(sym) @@ -417,8 +417,8 @@ private class DataParallelExecutorGroup private(sym: Symbol, sharedGroup.sharedDataArrays } - private[mxnet] val dataNames = providedData.map { case (k, _) => k }.toList - private[mxnet] val labelNames = providedLabel.map { case (k, _) => k }.toList + private[mxnet] val dataNames = providedDataDesc.map(_.name).toList + private[mxnet] val labelNames = providedLabelDesc.map(_.name).toList private[mxnet] val auxNames = sym.listAuxiliaryStates() private[mxnet] val paramIdx = argNames.zipWithIndex .filter { case (name, i) => paramNames.contains(name) } @@ -428,9 +428,10 @@ private class DataParallelExecutorGroup private(sym: Symbol, private[mxnet] val trainExecs: Array[Executor] = ctx.zipWithIndex.map { case (ctxi, i) => val dataShapes = - (providedData ++ providedLabel) map { case (name, shape) => - name -> (Shape(slices(i)._2 - slices(i)._1) ++ shape.slice(1, shape.length)) - } + (providedDataDesc ++ providedLabelDesc).map( desc => { + desc.name -> + (Shape(slices(i)._2 - slices(i)._1) ++ desc.shape.slice(1, desc.shape.length)) + }).toMap val sharedExec: Executor = if (sharedGroup == null) null else sharedGroup.trainExecs(i) ExecutorManager.bindExec(sym, ctxi, dataShapes, paramNamesComb, needGrad = true, baseExec = sharedExec, @@ -479,7 +480,7 @@ private class DataParallelExecutorGroup private(sym: Symbol, trainData: DataIter, sharedGroup: DataParallelExecutorGroup) { this(sym, argNames, paramNames, ctx, slices, - trainData.provideData, trainData.provideLabel, sharedGroup) + trainData.provideDataDesc, trainData.provideLabelDesc, sharedGroup) } def this(sym: Symbol, @@ -487,7 +488,7 @@ private class DataParallelExecutorGroup private(sym: Symbol, ctx: Array[Context], slices: Array[(Int, Int)], trainData: DataIter) { this(sym, argNames, paramNames, ctx, slices, - trainData.provideData, trainData.provideLabel, null) + trainData.provideDataDesc, trainData.provideLabelDesc, null) } /** @@ -509,7 +510,7 @@ private class DataParallelExecutorGroup private(sym: Symbol, trainData: DataBatch, sharedGroup: DataParallelExecutorGroup) { this(sym, argNames, paramNames, ctx, slices, - trainData.provideData, trainData.provideLabel, sharedGroup) + trainData.provideDataDesc, trainData.provideLabelDesc, sharedGroup) } def this(sym: Symbol, @@ -517,7 +518,7 @@ private class DataParallelExecutorGroup private(sym: Symbol, ctx: Array[Context], slices: Array[(Int, Int)], trainData: DataBatch) { this(sym, argNames, paramNames, ctx, slices, - trainData.provideData, trainData.provideLabel, null) + trainData.provideDataDesc, trainData.provideLabelDesc, null) } // load data and labels into arrays diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala index 2b1765531824..b8e2ba0b39c8 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala @@ -129,11 +129,11 @@ class FeedForward private( // Initialize weight parameters and auxiliary states // The NDArrays associated with the _argParms and _auxParams are not disposed instead // they are passed a outer scope if available. - private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false) + private def initParams(inputShapes: IndexedSeq[DataDesc], overwrite: Boolean = false) : (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = { val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes) val argNames = symbol.listArguments() - val inputNames = inputShapes.keys.toSet + val inputNames = inputShapes.map(_.name).toSet val paramNames = argNames.filter(!inputNames.contains(_)) val auxNames = symbol.listAuxiliaryStates() @@ -179,7 +179,7 @@ class FeedForward private( } // Initialize the predictor module for running prediction. - private def initPredictor(inputShapes: Map[String, Shape]): Unit = { + private def initPredictor(inputShapes: IndexedSeq[DataDesc]): Unit = { var shouldInit = true if (this.predExec != null) { val (argShapes, _, _) = symbol.inferShape(inputShapes) @@ -193,7 +193,7 @@ class FeedForward private( } if(shouldInit) { // for now only use the first device - val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = inputShapes) + val predExec = symbol.simpleBind(ctx(0), gradReq = "null", inputShapes) predExec.copyParamsFrom(_argParams, _auxParams) ExecutorManager.checkArguments(symbol) this.predExec = predExec @@ -233,8 +233,8 @@ class FeedForward private( */ def predict(data: DataIter, numBatch: Int = -1): Array[NDArray] = { data.reset() - val dataShapes = data.provideData - val dataNames = dataShapes.map(_._1).toArray + val dataShapes = data.provideDataDesc + val dataNames = dataShapes.map(_.name).toArray initPredictor(dataShapes) val batchSize = data.batchSize val dataArrays = dataNames.map(predExec.argDict(_)) @@ -363,7 +363,7 @@ class FeedForward private( this.symbol = symGen.generate(trainData.defaultBucketKey) checkArguments() } - initParams(trainData.provideData ++ trainData.provideLabel) + initParams(trainData.provideDataDesc ++ trainData.provideLabelDesc) } private def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric = new Accuracy(), diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index b580ad10a04e..1db6d2a6e953 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -141,28 +141,46 @@ class DataBatch(val data: IndexedSeq[NDArray], val pad: Int, // the key for the bucket that should be used for this batch, // for bucketing io only - val bucketKey: AnyRef, + val bucketKey: AnyRef = null, // use DataDesc to indicate the order of data/label loading // (must match the order of input data/label) - private val providedDataDesc: IndexedSeq[DataDesc], - private val providedLabelDesc: IndexedSeq[DataDesc]) { + private val providedDataDesc: IndexedSeq[DataDesc] = null, + private val providedLabelDesc: IndexedSeq[DataDesc] = null) { // TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)] // However, since the data and label can be accessed publicly (no getter and setter) // the change on this will break BC + + @deprecated("Use provideDataDesc and provideDataLabel instead", "1.3.0") + def this(data: IndexedSeq[NDArray], + label: IndexedSeq[NDArray], + index: IndexedSeq[Long], + pad: Int, + // the key for the bucket that should be used for this batch, + // for bucketing io only + bucketKey: AnyRef, + // use ListMap to indicate the order of data/label loading + // (must match the order of input data/label) + providedData: ListMap[String, Shape]) { + this(data, label, index, pad, bucketKey, + DataDesc.ListMap2Descs(providedData)) + } + + @deprecated("Use provideDataDesc and provideDataLabel instead", "1.3.0") def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray], index: IndexedSeq[Long], pad: Int, // the key for the bucket that should be used for this batch, // for bucketing io only - bucketKey: AnyRef = null, + bucketKey: AnyRef, // use ListMap to indicate the order of data/label loading // (must match the order of input data/label) - providedData: ListMap[String, Shape] = null, - providedLabel: ListMap[String, Shape] = null) { + providedData: ListMap[String, Shape], + providedLabel: ListMap[String, Shape]) { this(data, label, index, pad, bucketKey, DataDesc.ListMap2Descs(providedData), DataDesc.ListMap2Descs(providedLabel)) } + /** * Dispose its data and labels * The object shall never be used after it is disposed. @@ -177,6 +195,7 @@ class DataBatch(val data: IndexedSeq[NDArray], } // The name and shape of data + @deprecated("Use provideDataDesc instead", "1.3.0") def provideData: ListMap[String, Shape] = { var temp = ListMap[String, Shape]() if (providedDataDesc == null) null @@ -187,6 +206,7 @@ class DataBatch(val data: IndexedSeq[NDArray], } // The name and shape of label + @deprecated("Use provideLabelDesc instead", "1.3.0") def provideLabel: ListMap[String, Shape] = { var temp = ListMap[String, Shape]() if (providedLabelDesc == null) null @@ -311,8 +331,7 @@ abstract class DataIter extends Iterator[DataBatch] { */ @throws(classOf[NoSuchElementException]) def next(): DataBatch = { - new DataBatch(getData(), getLabel(), getIndex(), getPad(), - null, null, null) + new DataBatch(getData(), getLabel(), getIndex(), getPad()) } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala index 3a51222cc0b8..de7792850dc1 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala @@ -17,6 +17,8 @@ package org.apache.mxnet +import scala.language.implicitConversions + object MX_PRIMITIVES { /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 821e04f08df2..296734d1a4ff 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -21,6 +21,7 @@ import org.apache.mxnet.Base._ import org.apache.mxnet.DType.DType import org.slf4j.{Logger, LoggerFactory} +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.language.implicitConversions @@ -209,6 +210,33 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso } } + /** + * Infer the shape of outputs and arguments of given known shapes of arguments. + * User can either pass in the known shapes in positional way or keyword argument way. + * Tuple of Nones is returned if there is not enough information passed in. + * An error will be raised if there is inconsistency found in the known shapes passed in. + * @param args Provide a list of DataDesc containing the shapes to resolve + * @return + * argShapes List of shapes of arguments. The order is in the same order as list_arguments() + * outShapes List of shapes of outputs. The order is in the same order as list_outputs() + * auxShapes List of shapes of outputs. The order is in the same order as list_auxiliary() + */ + def inferShape(args: IndexedSeq[DataDesc]): + (IndexedSeq[Shape], IndexedSeq[Shape], IndexedSeq[Shape]) = { + val keys = ArrayBuffer.empty[String] + val indPtr = ArrayBuffer(0) + val sdata = ArrayBuffer.empty[Int] + args.foreach { arg => + val shape = arg.shape + if (shape != null) { + keys += arg.name + sdata ++= shape.toVector + indPtr += sdata.size + } + } + inferShape(keys.toArray, indPtr.toArray, sdata.toArray) + } + /** * Infer the shape of outputs and arguments of given known shapes of arguments. * User can either pass in the known shapes in positional way or keyword argument way. @@ -389,6 +417,29 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso checkCall(_LIB.mxSymbolCompose(handle, name, keys, args)) } + /** + * Bind current symbol to get an executor, allocate all the ndarrays needed. + * Allows specifying data types. + * This function will ask user to pass in ndarray of position + * they like to bind to, and it will automatically allocate the ndarray + * for arguments and auxiliary states that user did not specify explicitly. + * + * @param ctx The device context the generated executor to run on. + * @param gradReq {'write', 'add', 'null'}, or list of str or dict of str to str, optional + * Specifies how we should update the gradient to the args_grad. + * - 'write' means everytime gradient is write to specified args_grad NDArray. + * - 'add' means everytime gradient is add to the specified NDArray. + * - 'null' means no action is taken, the gradient may not be calculated. + * @param dataDesc List of dataDescriptors + * @return The generated Executor + */ + def simpleBind(ctx: Context, gradReq: String, + descs: IndexedSeq[DataDesc]) : Executor = { + val (shapes, types) = descs.map(desc => + ( desc.name -> desc.shape, desc.name -> desc.dtype )).unzip + simpleBind(ctx, gradReq, shapes.toMap, types.toMap) + } + /** * Bind current symbol to get an executor, allocate all the ndarrays needed. * Allows specifying data types. @@ -1189,7 +1240,7 @@ object Symbol extends SymbolBase { // a more friendly interface for creating symbols // all values except symbols in kwargs will be cast to String using its toString() method - @Deprecated + @deprecated("Use Checked version", "0.1.2") def createFromNamedSymbolsNoCheck( operator: String, name: String = null, attr: Map[String, String] = null)( kwargs: Map[String, Any]): Symbol = { @@ -1208,7 +1259,7 @@ object Symbol extends SymbolBase { // a more friendly interface for creating symbols // all values except symbols in kwargs will be cast to String using its toString() method - @Deprecated + @deprecated("Use Checked version", "0.1.2") def createFromListedSymbolsNoCheck( operator: String, name: String = null, attr: Map[String, String] = null)( symbols: Array[Symbol], kwargs: Map[String, Any] = null): Symbol = { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index e30098c3088b..66b7d83cedc8 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -107,8 +107,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, checkCall(_LIB.mxDataIterNext(handle, next)) if (next.value > 0) { currentBatch = new DataBatch(data = getData(), label = getLabel(), - index = getIndex(), pad = getPad(), - null, null, null) + index = getIndex(), pad = getPad()) } else { currentBatch = null } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index b205bbe47abb..e9513257c050 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -161,8 +161,7 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)], override def next(): DataBatch = { if (hasNext) { cursor += dataBatchSize - new DataBatch(getData(), getLabel(), getIndex(), getPad(), - null, null, null) + new DataBatch(getData(), getLabel(), getIndex(), getPad()) } else { throw new NoSuchElementException } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala index d277351b124b..9cfcd598197c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala @@ -42,71 +42,51 @@ class PrefetchingIter( require(iters.nonEmpty, "Iters length must be greater than 0") - private val _provideData: ListMap[String, Shape] = { + @deprecated("Please use provideDataDesc instead", "1.3.0") + override def provideData: ListMap[String, Shape] = { if (dataNames == null) { - iters.map(_.provideData).foldLeft(ListMap[String, Shape]()) { (acc, elem) => - acc ++ elem - } + iters.map(_.provideData).reduce(_ ++ _) } else { - iters.zipWithIndex.map(tu => (tu._1.provideData, tu._2)) - .map(m => m._1.map(t => (dataNames(m._2)(t._1), t._2))) - .foldLeft(ListMap[String, Shape]()) { (acc, elem) => - acc ++ elem - } + iters.map(_.provideData).zip(dataNames).map { case (providedData, names) => + providedData.map { case (oldName, shape) => names(oldName) -> shape } + }.reduceLeft(_ ++ _) } } - private val _provideLabel: ListMap[String, Shape] = { + @deprecated("Please use provideDataDesc instead", "1.3.0") + override def provideLabel: ListMap[String, Shape] = { if (labelNames == null) { - iters.map(_.provideLabel).foldLeft(ListMap[String, Shape]()) { (acc, elem) => - acc ++ elem - } + iters.map(_.provideLabel).reduce(_ ++ _) } else { - iters.zipWithIndex.map(tu => (tu._1.provideLabel, tu._2)) - .map(m => m._1.map(t => (labelNames(m._2)(t._1), t._2))) - .foldLeft(ListMap[String, Shape]()) { (acc, elem) => - acc ++ elem - } + iters.map(_.provideLabel).zip(labelNames).map { case (providedLabel, names) => + providedLabel.map { case (oldName, shape) => names(oldName) -> shape } + }.reduceLeft(_ ++ _) } } - private val _provideDataDesc: IndexedSeq[DataDesc] = { + override def provideDataDesc: IndexedSeq[DataDesc] = { if (dataNames == null) { - iters.map(_.provideDataDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => - acc ++ elem - } + iters.flatMap(_.provideDataDesc) } else { - iters.zipWithIndex.map(tu => (tu._1.provideDataDesc, tu._2)) - .map(m => - m._1.map(t => - new DataDesc(dataNames(m._2)(t.name), t.shape, t.dtype, t.layout) - ) - ) - .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => - acc ++ elem - } + iters.map(_.provideDataDesc).zip(dataNames).flatMap { case (providedDataDesc, names) => + providedDataDesc.map(desc => + new DataDesc(names(desc.name), desc.shape, desc.dtype, desc.layout)) + } } } - private val _provideLabelDesc: IndexedSeq[DataDesc] = { + override def provideLabelDesc: IndexedSeq[DataDesc] = { if (labelNames == null) { - iters.map(_.provideLabelDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => - acc ++ elem - } + iters.flatMap(_.provideLabelDesc) } else { - iters.zipWithIndex.map(tu => (tu._1.provideLabelDesc, tu._2)) - .map(m => - m._1.map(t => - new DataDesc(labelNames(m._2)(t.name), t.shape, t.dtype, t.layout) - ) - ) - .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) => - acc ++ elem - } + iters.map(_.provideLabelDesc).zip(labelNames).flatMap { case (providedLabelDesc, names) => + providedLabelDesc.map(desc => + new DataDesc(names(desc.name), desc.shape, desc.dtype, desc.layout)) + } } } - private val _batchSize: Int = this._provideData.toList(0)._2(0) + private val _batchSize: Int = this.provideDataDesc.head.shape(0) private val dataReady: IndexedSeq[Semaphore] = (0 until iters.length).map(i => new Semaphore(0)) private val dataTaken: IndexedSeq[Semaphore] = @@ -177,20 +157,6 @@ class PrefetchingIter( */ override def getPad(): Int = this.currentBatch.pad - // The name and shape of label provided by this iterator - @deprecated("Please use provideDataDesc instead", "1.3.0") - override def provideLabel: ListMap[String, Shape] = this._provideLabel - - // The name and shape of data provided by this iterator - @deprecated("Please use provideLabelDesc instead", "1.3.0") - override def provideData: ListMap[String, Shape] = this._provideData - - // Provide type:DataDesc of the data - override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc - - // Provide type:DataDesc of the label - override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc - override def hasNext: Boolean = { for (e <- dataReady) e.acquire() if (nextBatch(0) == null) { @@ -209,8 +175,7 @@ class PrefetchingIter( currentBatch = new DataBatch(datas.toIndexedSeq.flatten, labels.toIndexedSeq.flatten, nextBatch(0).index, - nextBatch(0).pad, - null, null, null) + nextBatch(0).pad) for (e <- dataTaken) e.release() true } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala index b73f4ad4b112..3be8e060fd6f 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala @@ -398,8 +398,7 @@ abstract class BaseModule { fitParams: FitParams = new FitParams): Unit = { require(fitParams != null, "Undefined fitParams") require(numEpoch > 0, s"Invalid number of epochs $numEpoch") - import org.apache.mxnet.DataDesc._ - bind(dataShapes = trainData.provideData, labelShapes = Option(trainData.provideLabel), + bind(dataShapes = trainData.provideDataDesc, labelShapes = Option(trainData.provideLabelDesc), forTraining = true, forceRebind = fitParams.forceRebind) fitParams.monitor.foreach(installMonitor) initParams(fitParams.initializer, argParams, auxParams, diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala index 41a6f69394d2..d5c8c21ea106 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala @@ -296,7 +296,7 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[ require(this.binded && this.paramsInitialized, "bind() and initParams() must be called first.") val bucketKey = dataBatch.bucketKey val originalBucketKey = this._currBucketKey - this.switchBucket(bucketKey, dataBatch.provideData, Option(dataBatch.provideLabel)) + this.switchBucket(bucketKey, dataBatch.provideDataDesc, Option(dataBatch.provideLabelDesc)) // switch back this.switchBucket(originalBucketKey, null, None) } @@ -308,8 +308,8 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[ */ override def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit = { require(binded && paramsInitialized, "bind() and initParams() must be called first.") - this.switchBucket(dataBatch.bucketKey, dataBatch.provideData, - Option(dataBatch.provideLabel)) + this.switchBucket(dataBatch.bucketKey, dataBatch.provideDataDesc, + Option(dataBatch.provideLabelDesc)) this._currModule.forward(dataBatch, isTrain) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala index 3255d9346b80..9928f66b2200 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala @@ -435,14 +435,14 @@ class Module(symbolVar: Symbol, val newDataShapes = dataBatch.data.map(_.shape) if (currDataShapes != newDataShapes) { val newDShapes: IndexedSeq[DataDesc] = - if (dataBatch.provideData != null) dataBatch.provideData + if (dataBatch.provideDataDesc != null) dataBatch.provideDataDesc else { this.dataShapes.zip(newDataShapes).map { case (i, shape) => DataDesc(i.name, shape, i.dtype, i.layout) } } val newLShapes: Option[IndexedSeq[DataDesc]] = - if (dataBatch.provideLabel != null) Some(dataBatch.provideLabel) + if (dataBatch.provideLabelDesc != null) Some(dataBatch.provideLabelDesc) else if (dataBatch.label != null && dataBatch.label.length > 0 && this.labelShapes != null) { Some(this.labelShapes.zip(dataBatch.label).map { case (i, j) => diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala index 3c3eeb97f201..d80e6bc6279b 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala @@ -111,6 +111,16 @@ class SequentialModule extends BaseModule { this.labelShapesVar.orNull } + /** + * Get output shapes. + * @return The output shapes of the last + * module is the output shape of a SequentialModule. + */ + def outputDesc: IndexedSeq[DataDesc] = { + require(this.binded, "bind() must be called first.") + this.modules.reverse.head.dataShapes + } + /** * Get output shapes. * @return The output shapes of the last @@ -306,12 +316,8 @@ class SequentialModule extends BaseModule { val dataNames = module.outputShapes.map(_._1) require(dataNames.length == data.data.length, s"dataNames $dataNames do not match with number of arrays in batch") - var provideData = ListMap[String, Shape]() - for ((name, x) <- dataNames.zip(out.head)) { - provideData += name -> x.shape - } data = new DataBatch(out.head, data.label, data.index, - data.pad, data.bucketKey, provideData, data.provideLabel) + data.pad, data.bucketKey, outputDesc, data.provideLabelDesc) } } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala b/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala index 2cf453ac3d18..c780a9605b12 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala @@ -17,6 +17,8 @@ package org.apache.mxnet.util +import scala.language.implicitConversions + object OptionConversion { implicit def someWrapper[A](noSome : A) : Option[A] = Option(noSome) } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 698a2b53a9fa..9839f09e4063 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -19,6 +19,7 @@ package org.apache.mxnet import org.apache.mxnet.io.{NDArrayIter, ResizeIter, PrefetchingIter} import org.scalatest.{BeforeAndAfterAll, FunSuite} +import scala.language.postfixOps import scala.sys.process._ @@ -53,10 +54,10 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { // test DataIter val mnistIter = mnistPack.iterator // test provideData - val provideData = mnistIter.provideData - val provideLabel = mnistIter.provideLabel - assert(provideData("data") === Shape(100, 784)) - assert(provideLabel("label") === Shape(100)) + val provideData = mnistIter.provideDataDesc + val provideLabel = mnistIter.provideLabelDesc + assert(provideData.find(_.name == "data").get.shape === Shape(100, 784)) + assert(provideLabel.find(_.name == "label").get.shape === Shape(100)) // test_loop mnistIter.reset() batchCount = 0 @@ -105,10 +106,10 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { val nBatch = 500 var batchCount = 0 // test provideData - val provideData = imgRecIter.provideData - val provideLabel = imgRecIter.provideLabel - assert(provideData("data").toArray === Array(100, 3, 28, 28)) - assert(provideLabel("label").toArray === Array(100)) + val provideData = imgRecIter.provideDataDesc + val provideLabel = imgRecIter.provideLabelDesc + assert(provideData.find(_.name == "data").get.shape.toArray === Array(100, 3, 28, 28)) + assert(provideLabel.find(_.name == "label").get.shape.toArray === Array(100)) imgRecIter.reset() while (imgRecIter.hasNext) { @@ -208,12 +209,12 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(nBatch === batchCount) // test provideData - val provideData = prefetchIter.provideData - val provideLabel = prefetchIter.provideLabel - assert(provideData("data1") === Shape(100, 784)) - assert(provideData("data2") === Shape(100, 784)) - assert(provideLabel("label1") === Shape(100)) - assert(provideLabel("label2") === Shape(100)) + val provideData = prefetchIter.provideDataDesc + val provideLabel = prefetchIter.provideLabelDesc + assert(provideData.find(_.name == "data1").get.shape === Shape(100, 784)) + assert(provideData.find(_.name == "data2").get.shape === Shape(100, 784)) + assert(provideLabel.find(_.name == "label1").get.shape === Shape(100)) + assert(provideLabel.find(_.name == "label2").get.shape === Shape(100)) // test reset prefetchIter.reset() diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala index 88e314e2a72c..3e753a18d247 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala @@ -253,8 +253,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { // create module val mod = new Module(x, contexts = Array(Context.cpu())) - mod.bind(dataShapes = trainData.provideData, - Option(trainData.provideLabel)) + mod.bind(dataShapes = trainData.provideDataDesc, + Option(trainData.provideLabelDesc)) mod.installMonitor(mon) val argParams = Map( "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 2)), diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/train/ConvSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/train/ConvSuite.scala index 44f57c0d1162..eb6d5b7d7175 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/train/ConvSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/train/ConvSuite.scala @@ -23,6 +23,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory import scala.collection.mutable.ListBuffer +import scala.language.postfixOps import scala.sys.process._ class ConvSuite extends FunSuite with BeforeAndAfterAll { diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala index 7745043b23d8..d9902e9dcc75 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala @@ -188,7 +188,7 @@ object CNNTextClassification { // decay learning rate if (iter % 50 == 0 && iter > 0) { factor *= 0.5f - opt.setLrScale(paramBlocks.map(_._1 -> factor).toMap) + opt.setLrMult(paramBlocks.map(paramBlock => (Left(paramBlock._1), factor)).toMap) logger.info(s"reset learning to ${opt.learningRate * factor}") } // end of training loop diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala index df79f5b63769..0cfcc49aee04 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala @@ -107,7 +107,7 @@ object ExampleCustomOp { val (trainIter, testIter) = Data.mnistIterator(dataPath, batchSize = 100, inputShape = Shape(784)) - val datasAndLabels = trainIter.provideData ++ trainIter.provideLabel + val datasAndLabels = trainIter.provideDataDesc ++ trainIter.provideLabelDesc val (argShapes, outputShapes, auxShapes) = mlp.inferShape(datasAndLabels) val initializer = new Xavier(factorType = "in", magnitude = 2.34f) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOpWithRtc.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOpWithRtc.scala index c3ac347353df..7b0fb349373d 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOpWithRtc.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOpWithRtc.scala @@ -128,7 +128,7 @@ object ExampleCustomOpWithRtc { val (trainIter, testIter) = Data.mnistIterator(dataPath, batchSize = 100, inputShape = Shape(784)) - val datasAndLabels = trainIter.provideData ++ trainIter.provideLabel + val datasAndLabels = trainIter.provideDataDesc ++ trainIter.provideLabelDesc val (argShapes, outputShapes, auxShapes) = mlp.inferShape(datasAndLabels) val initializer = new Xavier(factorType = "in", magnitude = 2.34f) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala index e4d3b2ae7c3e..4d22b62bea81 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala @@ -54,7 +54,7 @@ class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape: List[In override def next(): DataBatch = { if (hasNext) { curIter += batchSize - new DataBatch(data, label, getIndex, getPad, null, null, null) + new DataBatch(data, label, getIndex, getPad) } else { throw new NoSuchElementException } diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/MnistMlp.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/MnistMlp.scala index 4d450c60456b..839f6ac85902 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/MnistMlp.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/MnistMlp.scala @@ -49,7 +49,7 @@ object MnistMlp { logger.info("Load checkpoint from epoch {}", loadModelEpoch) Module.loadCheckpoint("model/mnist_mlp", loadModelEpoch, loadOptimizerStates = true) } - mod.bind(dataShapes = train.provideData, labelShapes = Some(train.provideLabel)) + mod.bind(dataShapes = train.provideDataDesc, labelShapes = Some(train.provideLabelDesc)) mod.initParams() mod.initOptimizer(optimizer = new SGD(learningRate = 0.01f, momentum = 0.9f)) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/SequentialModuleEx.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/SequentialModuleEx.scala index ff616a57b1aa..ea2273ebd796 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/SequentialModuleEx.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/SequentialModuleEx.scala @@ -57,7 +57,7 @@ object SequentialModuleEx { cmdLine: SequentialModuleEx): Unit = { // Intermediate-level API val modSeq = getSeqModule() - modSeq.bind(dataShapes = train.provideData, labelShapes = Some(train.provideLabel)) + modSeq.bind(dataShapes = train.provideDataDesc, labelShapes = Some(train.provideLabelDesc)) if (cmdLine.loadModelPath != null) { logger.info(s"Load checkpoint from ${cmdLine.loadModelPath}") modSeq.loadParams(cmdLine.loadModelPath) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala index bfde55831e26..5c17a3747ab6 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala @@ -31,6 +31,7 @@ import org.apache.mxnet.optimizer.RMSProp import org.apache.mxnetexamples.Util import scala.collection.immutable.ListMap +import scala.language.postfixOps import scala.sys.process.Process /** @@ -65,7 +66,7 @@ object ExampleMultiTask { new DataBatch(batch.data, IndexedSeq(label, label), batch.index, - batch.pad, null, null, null) + batch.pad) } else { throw new NoSuchElementException } @@ -100,7 +101,7 @@ object ExampleMultiTask { override def getIndex(): IndexedSeq[Long] = this.dataIter.getIndex() // The name and shape of label provided by this iterator - @deprecated + @deprecated("Use provideLabelDesc instead", "1.3.0") override def provideLabel: ListMap[String, Shape] = { val provideLabel = this.dataIter.provideLabel.toArray // Different labels should be used here for actual application @@ -126,7 +127,7 @@ object ExampleMultiTask { override def getPad(): Int = this.dataIter.getPad() // The name and shape of data provided by this iterator - @deprecated + @deprecated("Use provideDataDesc instead", "1.3.0") override def provideData: ListMap[String, Shape] = this.dataIter.provideData override def provideDataDesc: IndexedSeq[DataDesc] = this.dataIter.provideDataDesc @@ -229,10 +230,10 @@ object ExampleMultiTask { val trainMultiIt = new MultiMnistIterator(trainIter) val valMultiIter = new MultiMnistIterator(valIter) - val datasAndLabels = trainMultiIt.provideData ++ trainMultiIt.provideLabel + val datasAndLabels = trainMultiIt.provideDataDesc ++ trainMultiIt.provideLabelDesc val (argShapes, outputShapes, auxShapes) - = network.inferShape(trainMultiIt.provideData("data")) + = network.inferShape(trainMultiIt.provideDataDesc.filter(_.name == "data")) val initializer = new Xavier(factorType = "in", magnitude = 2.34f) val argNames = network.listArguments diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala index 1767cabcbae4..475e179f819b 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala @@ -39,7 +39,7 @@ object NeuralStyle { private val logger = LoggerFactory.getLogger(classOf[NeuralStyle]) def preprocessContentImage(path: String, longEdge: Int, ctx: Context): NDArray = { - val img = Image(new File(path)) + val img = Image.fromFile(new File(path)) logger.info(s"load the content image, size = ${(img.height, img.width)}") val factor = longEdge.toFloat / Math.max(img.height, img.width) val (newHeight, newWidth) = ((img.height * factor).toInt, (img.width * factor).toInt) @@ -60,7 +60,7 @@ object NeuralStyle { } def preprocessStyleImage(path: String, shape: Shape, ctx: Context): NDArray = { - val img = Image(new File(path)) + val img = Image.fromFile(new File(path)) val resizedImg = img.scaleTo(shape(3), shape(2)) val sample = NDArray.empty(Shape(1, 3, shape(2), shape(3)), ctx) val datas = { diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/DataProcessing.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/DataProcessing.scala index 80a009ea40c2..5b01d2016467 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/DataProcessing.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/DataProcessing.scala @@ -29,7 +29,7 @@ object DataProcessing { def preprocessContentImage(path: String, dShape: Shape = null, ctx: Context): NDArray = { - val img = Image(new File(path)) + val img = Image.fromFile(new File(path)) val resizedImg = img.scaleTo(dShape(3), dShape(2)) val sample = NDArray.empty(Shape(1, 3, resizedImg.height, resizedImg.width), ctx) val datas = { @@ -46,7 +46,7 @@ object DataProcessing { } def preprocessStyleImage(path: String, shape: Shape, ctx: Context): NDArray = { - val img = Image(new File(path)) + val img = Image.fromFile(new File(path)) val resizedImg = img.scaleTo(shape(3), shape(2)) val sample = NDArray.empty(Shape(1, 3, shape(2), shape(3)), ctx) val datas = { diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala index 350e28cf8634..2648f9e3d6bb 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala @@ -251,11 +251,11 @@ object BucketIo { override def getPad(): Int = 0 // The name and shape of label provided by this iterator - @deprecated + @deprecated("Use provideLabelDesc instead", "1.3.0") override def provideLabel: ListMap[String, Shape] = this._provideLabel // The name and shape of data provided by this iterator - @deprecated + @deprecated("Use provideDataDesc instead", "1.3.0") override def provideData: ListMap[String, Shape] = this._provideData // Provide type:DataDesc of the data diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala index c90b7637b9b1..68346afe1f47 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala @@ -75,7 +75,7 @@ object TrainCharRnn { // the network symbol val symbol = symGen(buckets(0)) - val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel + val datasAndLabels = dataTrain.provideDataDesc ++ dataTrain.provideLabelDesc val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels) val initializer = new Xavier(factorType = "in", magnitude = 2.34f) @@ -85,12 +85,13 @@ object TrainCharRnn { val auxNames = symbol.listAuxiliaryStates() val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap + val datasAndLabelsNames = datasAndLabels.map(_.name) val gradDict = argNames.zip(argShapes).filter { case (name, shape) => - !datasAndLabels.contains(name) + !datasAndLabelsNames.contains(name) }.map(x => x._1 -> NDArray.empty(x._2, ctx)).toMap argDict.foreach { case (name, ndArray) => - if (!datasAndLabels.contains(name)) { + if (!datasAndLabelsNames.contains(name)) { initializer.initWeight(name, ndArray) } } diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmarkSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmarkSuite.scala index 0b7f4693c5fa..548f2e4122e0 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmarkSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmarkSuite.scala @@ -22,6 +22,7 @@ import org.apache.mxnetexamples.Util import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory +import scala.language.postfixOps import scala.sys.process.Process class ScalaInferenceBenchmarkSuite extends FunSuite with BeforeAndAfterAll { diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala index a7e36dfc3a11..ae0ee33002d9 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala @@ -26,6 +26,7 @@ import org.apache.mxnetexamples.Util import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory +import scala.language.postfixOps import scala.sys.process.Process /** diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/customop/CustomOpExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/customop/CustomOpExampleSuite.scala index 6385e062a260..b65f237c8621 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/customop/CustomOpExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/customop/CustomOpExampleSuite.scala @@ -25,6 +25,7 @@ import org.apache.mxnetexamples.Util import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory +import scala.language.postfixOps import scala.sys.process.Process class CustomOpExampleSuite extends FunSuite with BeforeAndAfterAll { diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala index 59faba9a3779..709ea77632e0 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala @@ -24,6 +24,7 @@ import org.apache.mxnetexamples.Util import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore} import org.slf4j.LoggerFactory +import scala.language.postfixOps import scala.sys.process.Process class GanExampleSuite extends FunSuite with BeforeAndAfterAll{ diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala index 0daba5a97d77..e6f4f6fcc908 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala @@ -24,6 +24,7 @@ import org.apache.mxnetexamples.Util import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory +import scala.language.postfixOps import scala.sys.process.Process /** diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala index c5308ac37512..9c16aca420ef 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala @@ -23,6 +23,7 @@ import java.io.File import org.apache.mxnet.Context import org.apache.mxnetexamples.Util +import scala.language.postfixOps import sys.process.Process /** diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala index bd960bddebf5..918fb835f76e 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala @@ -23,6 +23,7 @@ import org.apache.mxnetexamples.Util import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory +import scala.language.postfixOps import scala.sys.process.Process class ObjectDetectorExampleSuite extends FunSuite with BeforeAndAfterAll { diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala index 71c2b35ef444..c93a7d06a452 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala @@ -23,6 +23,7 @@ import org.apache.mxnetexamples.neuralstyle.end2end.{BoostInference, BoostTrain} import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory +import scala.language.postfixOps import scala.sys.process.Process /** diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala index 2ccd38fc4f9c..ff3fbe9e05d2 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala @@ -23,6 +23,7 @@ import org.apache.mxnetexamples.Util import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore} import org.slf4j.LoggerFactory +import scala.language.postfixOps import scala.sys.process.Process class ExampleRNNSuite extends FunSuite with BeforeAndAfterAll { diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala index d9ccec468791..11d418002744 100644 --- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala +++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala @@ -142,7 +142,7 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll { val result: IndexedSeq[(String, Double)] = testClassifier. classify(IndexedSeq(inputData), topK = Some(10)) - assert((result(0)_2).getClass == 1d.getClass) + assert((result(0)._2).getClass == 1d.getClass) assertResult(predictResult(0).sortBy(-_)) { result.map(_._2).toArray @@ -185,7 +185,7 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll { val result: IndexedSeq[(String, Double)] = testClassifier. classify(IndexedSeq(inputData)) - assert((result(0)_2).getClass == 1d.getClass) + assert((result(0)._2).getClass == 1d.getClass) assertResult(predictResult(0)) { result.map(_._2).toArray diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala index b2033f529c65..a2b8633f07d7 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala @@ -141,7 +141,7 @@ private[mxnet] abstract class GeneratorBase { throw new IllegalArgumentException(s"Invalid macro input: $ex") } // wrap the result up in an Expr, and return it - val result = c.Expr(Block(modDefs, Literal(Constant()))) + val result = c.Expr(Block(modDefs, Literal(Constant(())))) result } diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 44b784a56f8e..46726bc00472 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -357,6 +357,10 @@ 2.1.0 + + -feature + -deprecation + diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala index bf1b26e4b48d..44d6f3345bdf 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala @@ -115,13 +115,13 @@ class LabeledPointIter private[mxnet]( } // The name and shape of label provided by this iterator - @deprecated + @deprecated("Use provideLabelDesc instead", "1.3.0") override def provideLabel: ListMap[String, Shape] = { ListMap(labelName -> Shape(_batchSize)) } // The name and shape of data provided by this iterator - @deprecated + @deprecated("Use provideDataDesc instead", "1.3.0") override def provideData: ListMap[String, Shape] = { ListMap(dataName -> dataShape) } diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala index e3272a4066b5..abf82f6e510c 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala @@ -28,8 +28,7 @@ class LongLivingDataBatch( override val data: IndexedSeq[NDArray], override val label: IndexedSeq[NDArray], override val index: IndexedSeq[Long], - override val pad: Int) extends DataBatch(data, label, index, pad, - null, null, null) { + override val pad: Int) extends DataBatch(data, label, index, pad) { override def dispose(): Unit = {} def disposeForce(): Unit = super.dispose() } diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala index a955ee74e7e2..1ca23927e123 100644 --- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala +++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala @@ -114,13 +114,13 @@ class PointIter private[mxnet]( } // The name and shape of label provided by this iterator - @deprecated + @deprecated("Use provideLabelDesc instead", "1.3.0") override def provideLabel: ListMap[String, Shape] = { ListMap(labelName -> Shape(_batchSize)) } // The name and shape of data provided by this iterator - @deprecated + @deprecated("Use provideDataDesc instead", "1.3.0") override def provideData: ListMap[String, Shape] = { ListMap(dataName -> dataShape) } diff --git a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala index 7a2417bdea1f..2382ca9fa358 100644 --- a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala +++ b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala @@ -20,6 +20,7 @@ package org.apache.mxnet.spark import java.io.{BufferedReader, File, InputStreamReader} import java.nio.file.Files +import scala.language.postfixOps import scala.sys.process.Process import org.apache.spark.SparkContext