diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala
index b13741bdd3b0..d94b8fb01ed6 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala
@@ -395,8 +395,8 @@ private[mxnet] object ExecutorManager {
* @param paramNames Names of all trainable parameters.
* @param ctx List of devices for training (data parallel)
* @param slices Describes how the data parallel splits data into different devices.
- * @param providedData training data shapes
- * @param providedLabel training label shapes
+ * @param providedDataDesc training data descriptions
+ * @param providedLabelDesc training label descriptions
* @param sharedGroup: DataParallelExecutorGroup
* An existing executor group, if to share parameters with it.
*
@@ -404,8 +404,8 @@ private[mxnet] object ExecutorManager {
private class DataParallelExecutorGroup private(sym: Symbol,
argNames: IndexedSeq[String], paramNames: Set[String],
ctx: Array[Context], private val slices: Array[(Int, Int)],
- providedData: Map[String, Shape],
- providedLabel: Map[String, Shape],
+ providedDataDesc: IndexedSeq[DataDesc],
+ providedLabelDesc: IndexedSeq[DataDesc],
sharedGroup: DataParallelExecutorGroup) {
// make sure the architecture is valid
ExecutorManager.checkArguments(sym)
@@ -417,8 +417,8 @@ private class DataParallelExecutorGroup private(sym: Symbol,
sharedGroup.sharedDataArrays
}
- private[mxnet] val dataNames = providedData.map { case (k, _) => k }.toList
- private[mxnet] val labelNames = providedLabel.map { case (k, _) => k }.toList
+ private[mxnet] val dataNames = providedDataDesc.map(_.name).toList
+ private[mxnet] val labelNames = providedLabelDesc.map(_.name).toList
private[mxnet] val auxNames = sym.listAuxiliaryStates()
private[mxnet] val paramIdx = argNames.zipWithIndex
.filter { case (name, i) => paramNames.contains(name) }
@@ -428,9 +428,10 @@ private class DataParallelExecutorGroup private(sym: Symbol,
private[mxnet] val trainExecs: Array[Executor] =
ctx.zipWithIndex.map { case (ctxi, i) =>
val dataShapes =
- (providedData ++ providedLabel) map { case (name, shape) =>
- name -> (Shape(slices(i)._2 - slices(i)._1) ++ shape.slice(1, shape.length))
- }
+ (providedDataDesc ++ providedLabelDesc).map( desc => {
+ desc.name ->
+ (Shape(slices(i)._2 - slices(i)._1) ++ desc.shape.slice(1, desc.shape.length))
+ }).toMap
val sharedExec: Executor = if (sharedGroup == null) null else sharedGroup.trainExecs(i)
ExecutorManager.bindExec(sym, ctxi, dataShapes, paramNamesComb,
needGrad = true, baseExec = sharedExec,
@@ -479,7 +480,7 @@ private class DataParallelExecutorGroup private(sym: Symbol,
trainData: DataIter,
sharedGroup: DataParallelExecutorGroup) {
this(sym, argNames, paramNames, ctx, slices,
- trainData.provideData, trainData.provideLabel, sharedGroup)
+ trainData.provideDataDesc, trainData.provideLabelDesc, sharedGroup)
}
def this(sym: Symbol,
@@ -487,7 +488,7 @@ private class DataParallelExecutorGroup private(sym: Symbol,
ctx: Array[Context], slices: Array[(Int, Int)],
trainData: DataIter) {
this(sym, argNames, paramNames, ctx, slices,
- trainData.provideData, trainData.provideLabel, null)
+ trainData.provideDataDesc, trainData.provideLabelDesc, null)
}
/**
@@ -509,7 +510,7 @@ private class DataParallelExecutorGroup private(sym: Symbol,
trainData: DataBatch,
sharedGroup: DataParallelExecutorGroup) {
this(sym, argNames, paramNames, ctx, slices,
- trainData.provideData, trainData.provideLabel, sharedGroup)
+ trainData.provideDataDesc, trainData.provideLabelDesc, sharedGroup)
}
def this(sym: Symbol,
@@ -517,7 +518,7 @@ private class DataParallelExecutorGroup private(sym: Symbol,
ctx: Array[Context], slices: Array[(Int, Int)],
trainData: DataBatch) {
this(sym, argNames, paramNames, ctx, slices,
- trainData.provideData, trainData.provideLabel, null)
+ trainData.provideDataDesc, trainData.provideLabelDesc, null)
}
// load data and labels into arrays
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
index 2b1765531824..b8e2ba0b39c8 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
@@ -129,11 +129,11 @@ class FeedForward private(
// Initialize weight parameters and auxiliary states
// The NDArrays associated with the _argParms and _auxParams are not disposed instead
// they are passed a outer scope if available.
- private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean = false)
+ private def initParams(inputShapes: IndexedSeq[DataDesc], overwrite: Boolean = false)
: (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = {
val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes)
val argNames = symbol.listArguments()
- val inputNames = inputShapes.keys.toSet
+ val inputNames = inputShapes.map(_.name).toSet
val paramNames = argNames.filter(!inputNames.contains(_))
val auxNames = symbol.listAuxiliaryStates()
@@ -179,7 +179,7 @@ class FeedForward private(
}
// Initialize the predictor module for running prediction.
- private def initPredictor(inputShapes: Map[String, Shape]): Unit = {
+ private def initPredictor(inputShapes: IndexedSeq[DataDesc]): Unit = {
var shouldInit = true
if (this.predExec != null) {
val (argShapes, _, _) = symbol.inferShape(inputShapes)
@@ -193,7 +193,7 @@ class FeedForward private(
}
if(shouldInit) {
// for now only use the first device
- val predExec = symbol.simpleBind(ctx(0), gradReq = "null", shapeDict = inputShapes)
+ val predExec = symbol.simpleBind(ctx(0), gradReq = "null", inputShapes)
predExec.copyParamsFrom(_argParams, _auxParams)
ExecutorManager.checkArguments(symbol)
this.predExec = predExec
@@ -233,8 +233,8 @@ class FeedForward private(
*/
def predict(data: DataIter, numBatch: Int = -1): Array[NDArray] = {
data.reset()
- val dataShapes = data.provideData
- val dataNames = dataShapes.map(_._1).toArray
+ val dataShapes = data.provideDataDesc
+ val dataNames = dataShapes.map(_.name).toArray
initPredictor(dataShapes)
val batchSize = data.batchSize
val dataArrays = dataNames.map(predExec.argDict(_))
@@ -363,7 +363,7 @@ class FeedForward private(
this.symbol = symGen.generate(trainData.defaultBucketKey)
checkArguments()
}
- initParams(trainData.provideData ++ trainData.provideLabel)
+ initParams(trainData.provideDataDesc ++ trainData.provideLabelDesc)
}
private def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric = new Accuracy(),
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
index b580ad10a04e..1db6d2a6e953 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
@@ -141,28 +141,46 @@ class DataBatch(val data: IndexedSeq[NDArray],
val pad: Int,
// the key for the bucket that should be used for this batch,
// for bucketing io only
- val bucketKey: AnyRef,
+ val bucketKey: AnyRef = null,
// use DataDesc to indicate the order of data/label loading
// (must match the order of input data/label)
- private val providedDataDesc: IndexedSeq[DataDesc],
- private val providedLabelDesc: IndexedSeq[DataDesc]) {
+ private val providedDataDesc: IndexedSeq[DataDesc] = null,
+ private val providedLabelDesc: IndexedSeq[DataDesc] = null) {
// TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)]
// However, since the data and label can be accessed publicly (no getter and setter)
// the change on this will break BC
+
+ @deprecated("Use provideDataDesc and provideDataLabel instead", "1.3.0")
+ def this(data: IndexedSeq[NDArray],
+ label: IndexedSeq[NDArray],
+ index: IndexedSeq[Long],
+ pad: Int,
+ // the key for the bucket that should be used for this batch,
+ // for bucketing io only
+ bucketKey: AnyRef,
+ // use ListMap to indicate the order of data/label loading
+ // (must match the order of input data/label)
+ providedData: ListMap[String, Shape]) {
+ this(data, label, index, pad, bucketKey,
+ DataDesc.ListMap2Descs(providedData))
+ }
+
+ @deprecated("Use provideDataDesc and provideDataLabel instead", "1.3.0")
def this(data: IndexedSeq[NDArray],
label: IndexedSeq[NDArray],
index: IndexedSeq[Long],
pad: Int,
// the key for the bucket that should be used for this batch,
// for bucketing io only
- bucketKey: AnyRef = null,
+ bucketKey: AnyRef,
// use ListMap to indicate the order of data/label loading
// (must match the order of input data/label)
- providedData: ListMap[String, Shape] = null,
- providedLabel: ListMap[String, Shape] = null) {
+ providedData: ListMap[String, Shape],
+ providedLabel: ListMap[String, Shape]) {
this(data, label, index, pad, bucketKey,
DataDesc.ListMap2Descs(providedData), DataDesc.ListMap2Descs(providedLabel))
}
+
/**
* Dispose its data and labels
* The object shall never be used after it is disposed.
@@ -177,6 +195,7 @@ class DataBatch(val data: IndexedSeq[NDArray],
}
// The name and shape of data
+ @deprecated("Use provideDataDesc instead", "1.3.0")
def provideData: ListMap[String, Shape] = {
var temp = ListMap[String, Shape]()
if (providedDataDesc == null) null
@@ -187,6 +206,7 @@ class DataBatch(val data: IndexedSeq[NDArray],
}
// The name and shape of label
+ @deprecated("Use provideLabelDesc instead", "1.3.0")
def provideLabel: ListMap[String, Shape] = {
var temp = ListMap[String, Shape]()
if (providedLabelDesc == null) null
@@ -311,8 +331,7 @@ abstract class DataIter extends Iterator[DataBatch] {
*/
@throws(classOf[NoSuchElementException])
def next(): DataBatch = {
- new DataBatch(getData(), getLabel(), getIndex(), getPad(),
- null, null, null)
+ new DataBatch(getData(), getLabel(), getIndex(), getPad())
}
/**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
index 3a51222cc0b8..de7792850dc1 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
@@ -17,6 +17,8 @@
package org.apache.mxnet
+import scala.language.implicitConversions
+
object MX_PRIMITIVES {
/**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 821e04f08df2..296734d1a4ff 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -21,6 +21,7 @@ import org.apache.mxnet.Base._
import org.apache.mxnet.DType.DType
import org.slf4j.{Logger, LoggerFactory}
+import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.language.implicitConversions
@@ -209,6 +210,33 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso
}
}
+ /**
+ * Infer the shape of outputs and arguments of given known shapes of arguments.
+ * User can either pass in the known shapes in positional way or keyword argument way.
+ * Tuple of Nones is returned if there is not enough information passed in.
+ * An error will be raised if there is inconsistency found in the known shapes passed in.
+ * @param args Provide a list of DataDesc containing the shapes to resolve
+ * @return
+ * argShapes List of shapes of arguments. The order is in the same order as list_arguments()
+ * outShapes List of shapes of outputs. The order is in the same order as list_outputs()
+ * auxShapes List of shapes of outputs. The order is in the same order as list_auxiliary()
+ */
+ def inferShape(args: IndexedSeq[DataDesc]):
+ (IndexedSeq[Shape], IndexedSeq[Shape], IndexedSeq[Shape]) = {
+ val keys = ArrayBuffer.empty[String]
+ val indPtr = ArrayBuffer(0)
+ val sdata = ArrayBuffer.empty[Int]
+ args.foreach { arg =>
+ val shape = arg.shape
+ if (shape != null) {
+ keys += arg.name
+ sdata ++= shape.toVector
+ indPtr += sdata.size
+ }
+ }
+ inferShape(keys.toArray, indPtr.toArray, sdata.toArray)
+ }
+
/**
* Infer the shape of outputs and arguments of given known shapes of arguments.
* User can either pass in the known shapes in positional way or keyword argument way.
@@ -389,6 +417,29 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso
checkCall(_LIB.mxSymbolCompose(handle, name, keys, args))
}
+ /**
+ * Bind current symbol to get an executor, allocate all the ndarrays needed.
+ * Allows specifying data types.
+ * This function will ask user to pass in ndarray of position
+ * they like to bind to, and it will automatically allocate the ndarray
+ * for arguments and auxiliary states that user did not specify explicitly.
+ *
+ * @param ctx The device context the generated executor to run on.
+ * @param gradReq {'write', 'add', 'null'}, or list of str or dict of str to str, optional
+ * Specifies how we should update the gradient to the args_grad.
+ * - 'write' means everytime gradient is write to specified args_grad NDArray.
+ * - 'add' means everytime gradient is add to the specified NDArray.
+ * - 'null' means no action is taken, the gradient may not be calculated.
+ * @param dataDesc List of dataDescriptors
+ * @return The generated Executor
+ */
+ def simpleBind(ctx: Context, gradReq: String,
+ descs: IndexedSeq[DataDesc]) : Executor = {
+ val (shapes, types) = descs.map(desc =>
+ ( desc.name -> desc.shape, desc.name -> desc.dtype )).unzip
+ simpleBind(ctx, gradReq, shapes.toMap, types.toMap)
+ }
+
/**
* Bind current symbol to get an executor, allocate all the ndarrays needed.
* Allows specifying data types.
@@ -1189,7 +1240,7 @@ object Symbol extends SymbolBase {
// a more friendly interface for creating symbols
// all values except symbols in kwargs will be cast to String using its toString() method
- @Deprecated
+ @deprecated("Use Checked version", "0.1.2")
def createFromNamedSymbolsNoCheck(
operator: String, name: String = null, attr: Map[String, String] = null)(
kwargs: Map[String, Any]): Symbol = {
@@ -1208,7 +1259,7 @@ object Symbol extends SymbolBase {
// a more friendly interface for creating symbols
// all values except symbols in kwargs will be cast to String using its toString() method
- @Deprecated
+ @deprecated("Use Checked version", "0.1.2")
def createFromListedSymbolsNoCheck(
operator: String, name: String = null, attr: Map[String, String] = null)(
symbols: Array[Symbol], kwargs: Map[String, Any] = null): Symbol = {
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
index e30098c3088b..66b7d83cedc8 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
@@ -107,8 +107,7 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
checkCall(_LIB.mxDataIterNext(handle, next))
if (next.value > 0) {
currentBatch = new DataBatch(data = getData(), label = getLabel(),
- index = getIndex(), pad = getPad(),
- null, null, null)
+ index = getIndex(), pad = getPad())
} else {
currentBatch = null
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index b205bbe47abb..e9513257c050 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -161,8 +161,7 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
override def next(): DataBatch = {
if (hasNext) {
cursor += dataBatchSize
- new DataBatch(getData(), getLabel(), getIndex(), getPad(),
- null, null, null)
+ new DataBatch(getData(), getLabel(), getIndex(), getPad())
} else {
throw new NoSuchElementException
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
index d277351b124b..9cfcd598197c 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
@@ -42,71 +42,51 @@ class PrefetchingIter(
require(iters.nonEmpty, "Iters length must be greater than 0")
- private val _provideData: ListMap[String, Shape] = {
+ @deprecated("Please use provideDataDesc instead", "1.3.0")
+ override def provideData: ListMap[String, Shape] = {
if (dataNames == null) {
- iters.map(_.provideData).foldLeft(ListMap[String, Shape]()) { (acc, elem) =>
- acc ++ elem
- }
+ iters.map(_.provideData).reduce(_ ++ _)
} else {
- iters.zipWithIndex.map(tu => (tu._1.provideData, tu._2))
- .map(m => m._1.map(t => (dataNames(m._2)(t._1), t._2)))
- .foldLeft(ListMap[String, Shape]()) { (acc, elem) =>
- acc ++ elem
- }
+ iters.map(_.provideData).zip(dataNames).map { case (providedData, names) =>
+ providedData.map { case (oldName, shape) => names(oldName) -> shape }
+ }.reduceLeft(_ ++ _)
}
}
- private val _provideLabel: ListMap[String, Shape] = {
+ @deprecated("Please use provideDataDesc instead", "1.3.0")
+ override def provideLabel: ListMap[String, Shape] = {
if (labelNames == null) {
- iters.map(_.provideLabel).foldLeft(ListMap[String, Shape]()) { (acc, elem) =>
- acc ++ elem
- }
+ iters.map(_.provideLabel).reduce(_ ++ _)
} else {
- iters.zipWithIndex.map(tu => (tu._1.provideLabel, tu._2))
- .map(m => m._1.map(t => (labelNames(m._2)(t._1), t._2)))
- .foldLeft(ListMap[String, Shape]()) { (acc, elem) =>
- acc ++ elem
- }
+ iters.map(_.provideLabel).zip(labelNames).map { case (providedLabel, names) =>
+ providedLabel.map { case (oldName, shape) => names(oldName) -> shape }
+ }.reduceLeft(_ ++ _)
}
}
- private val _provideDataDesc: IndexedSeq[DataDesc] = {
+ override def provideDataDesc: IndexedSeq[DataDesc] = {
if (dataNames == null) {
- iters.map(_.provideDataDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) =>
- acc ++ elem
- }
+ iters.flatMap(_.provideDataDesc)
} else {
- iters.zipWithIndex.map(tu => (tu._1.provideDataDesc, tu._2))
- .map(m =>
- m._1.map(t =>
- new DataDesc(dataNames(m._2)(t.name), t.shape, t.dtype, t.layout)
- )
- )
- .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) =>
- acc ++ elem
- }
+ iters.map(_.provideDataDesc).zip(dataNames).flatMap { case (providedDataDesc, names) =>
+ providedDataDesc.map(desc =>
+ new DataDesc(names(desc.name), desc.shape, desc.dtype, desc.layout))
+ }
}
}
- private val _provideLabelDesc: IndexedSeq[DataDesc] = {
+ override def provideLabelDesc: IndexedSeq[DataDesc] = {
if (labelNames == null) {
- iters.map(_.provideLabelDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) =>
- acc ++ elem
- }
+ iters.flatMap(_.provideLabelDesc)
} else {
- iters.zipWithIndex.map(tu => (tu._1.provideLabelDesc, tu._2))
- .map(m =>
- m._1.map(t =>
- new DataDesc(labelNames(m._2)(t.name), t.shape, t.dtype, t.layout)
- )
- )
- .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) =>
- acc ++ elem
- }
+ iters.map(_.provideLabelDesc).zip(labelNames).flatMap { case (providedLabelDesc, names) =>
+ providedLabelDesc.map(desc =>
+ new DataDesc(names(desc.name), desc.shape, desc.dtype, desc.layout))
+ }
}
}
- private val _batchSize: Int = this._provideData.toList(0)._2(0)
+ private val _batchSize: Int = this.provideDataDesc.head.shape(0)
private val dataReady: IndexedSeq[Semaphore] =
(0 until iters.length).map(i => new Semaphore(0))
private val dataTaken: IndexedSeq[Semaphore] =
@@ -177,20 +157,6 @@ class PrefetchingIter(
*/
override def getPad(): Int = this.currentBatch.pad
- // The name and shape of label provided by this iterator
- @deprecated("Please use provideDataDesc instead", "1.3.0")
- override def provideLabel: ListMap[String, Shape] = this._provideLabel
-
- // The name and shape of data provided by this iterator
- @deprecated("Please use provideLabelDesc instead", "1.3.0")
- override def provideData: ListMap[String, Shape] = this._provideData
-
- // Provide type:DataDesc of the data
- override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
-
- // Provide type:DataDesc of the label
- override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
-
override def hasNext: Boolean = {
for (e <- dataReady) e.acquire()
if (nextBatch(0) == null) {
@@ -209,8 +175,7 @@ class PrefetchingIter(
currentBatch = new DataBatch(datas.toIndexedSeq.flatten,
labels.toIndexedSeq.flatten,
nextBatch(0).index,
- nextBatch(0).pad,
- null, null, null)
+ nextBatch(0).pad)
for (e <- dataTaken) e.release()
true
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
index b73f4ad4b112..3be8e060fd6f 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
@@ -398,8 +398,7 @@ abstract class BaseModule {
fitParams: FitParams = new FitParams): Unit = {
require(fitParams != null, "Undefined fitParams")
require(numEpoch > 0, s"Invalid number of epochs $numEpoch")
- import org.apache.mxnet.DataDesc._
- bind(dataShapes = trainData.provideData, labelShapes = Option(trainData.provideLabel),
+ bind(dataShapes = trainData.provideDataDesc, labelShapes = Option(trainData.provideLabelDesc),
forTraining = true, forceRebind = fitParams.forceRebind)
fitParams.monitor.foreach(installMonitor)
initParams(fitParams.initializer, argParams, auxParams,
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala
index 41a6f69394d2..d5c8c21ea106 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala
@@ -296,7 +296,7 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
require(this.binded && this.paramsInitialized, "bind() and initParams() must be called first.")
val bucketKey = dataBatch.bucketKey
val originalBucketKey = this._currBucketKey
- this.switchBucket(bucketKey, dataBatch.provideData, Option(dataBatch.provideLabel))
+ this.switchBucket(bucketKey, dataBatch.provideDataDesc, Option(dataBatch.provideLabelDesc))
// switch back
this.switchBucket(originalBucketKey, null, None)
}
@@ -308,8 +308,8 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
*/
override def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit = {
require(binded && paramsInitialized, "bind() and initParams() must be called first.")
- this.switchBucket(dataBatch.bucketKey, dataBatch.provideData,
- Option(dataBatch.provideLabel))
+ this.switchBucket(dataBatch.bucketKey, dataBatch.provideDataDesc,
+ Option(dataBatch.provideLabelDesc))
this._currModule.forward(dataBatch, isTrain)
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
index 3255d9346b80..9928f66b2200 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
@@ -435,14 +435,14 @@ class Module(symbolVar: Symbol,
val newDataShapes = dataBatch.data.map(_.shape)
if (currDataShapes != newDataShapes) {
val newDShapes: IndexedSeq[DataDesc] =
- if (dataBatch.provideData != null) dataBatch.provideData
+ if (dataBatch.provideDataDesc != null) dataBatch.provideDataDesc
else {
this.dataShapes.zip(newDataShapes).map { case (i, shape) =>
DataDesc(i.name, shape, i.dtype, i.layout)
}
}
val newLShapes: Option[IndexedSeq[DataDesc]] =
- if (dataBatch.provideLabel != null) Some(dataBatch.provideLabel)
+ if (dataBatch.provideLabelDesc != null) Some(dataBatch.provideLabelDesc)
else if (dataBatch.label != null && dataBatch.label.length > 0
&& this.labelShapes != null) {
Some(this.labelShapes.zip(dataBatch.label).map { case (i, j) =>
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala
index 3c3eeb97f201..d80e6bc6279b 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala
@@ -111,6 +111,16 @@ class SequentialModule extends BaseModule {
this.labelShapesVar.orNull
}
+ /**
+ * Get output shapes.
+ * @return The output shapes of the last
+ * module is the output shape of a SequentialModule.
+ */
+ def outputDesc: IndexedSeq[DataDesc] = {
+ require(this.binded, "bind() must be called first.")
+ this.modules.reverse.head.dataShapes
+ }
+
/**
* Get output shapes.
* @return The output shapes of the last
@@ -306,12 +316,8 @@ class SequentialModule extends BaseModule {
val dataNames = module.outputShapes.map(_._1)
require(dataNames.length == data.data.length,
s"dataNames $dataNames do not match with number of arrays in batch")
- var provideData = ListMap[String, Shape]()
- for ((name, x) <- dataNames.zip(out.head)) {
- provideData += name -> x.shape
- }
data = new DataBatch(out.head, data.label, data.index,
- data.pad, data.bucketKey, provideData, data.provideLabel)
+ data.pad, data.bucketKey, outputDesc, data.provideLabelDesc)
}
}
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala b/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala
index 2cf453ac3d18..c780a9605b12 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala
@@ -17,6 +17,8 @@
package org.apache.mxnet.util
+import scala.language.implicitConversions
+
object OptionConversion {
implicit def someWrapper[A](noSome : A) : Option[A] = Option(noSome)
}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
index 698a2b53a9fa..9839f09e4063 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
@@ -19,6 +19,7 @@ package org.apache.mxnet
import org.apache.mxnet.io.{NDArrayIter, ResizeIter, PrefetchingIter}
import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import scala.language.postfixOps
import scala.sys.process._
@@ -53,10 +54,10 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
// test DataIter
val mnistIter = mnistPack.iterator
// test provideData
- val provideData = mnistIter.provideData
- val provideLabel = mnistIter.provideLabel
- assert(provideData("data") === Shape(100, 784))
- assert(provideLabel("label") === Shape(100))
+ val provideData = mnistIter.provideDataDesc
+ val provideLabel = mnistIter.provideLabelDesc
+ assert(provideData.find(_.name == "data").get.shape === Shape(100, 784))
+ assert(provideLabel.find(_.name == "label").get.shape === Shape(100))
// test_loop
mnistIter.reset()
batchCount = 0
@@ -105,10 +106,10 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
val nBatch = 500
var batchCount = 0
// test provideData
- val provideData = imgRecIter.provideData
- val provideLabel = imgRecIter.provideLabel
- assert(provideData("data").toArray === Array(100, 3, 28, 28))
- assert(provideLabel("label").toArray === Array(100))
+ val provideData = imgRecIter.provideDataDesc
+ val provideLabel = imgRecIter.provideLabelDesc
+ assert(provideData.find(_.name == "data").get.shape.toArray === Array(100, 3, 28, 28))
+ assert(provideLabel.find(_.name == "label").get.shape.toArray === Array(100))
imgRecIter.reset()
while (imgRecIter.hasNext) {
@@ -208,12 +209,12 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
assert(nBatch === batchCount)
// test provideData
- val provideData = prefetchIter.provideData
- val provideLabel = prefetchIter.provideLabel
- assert(provideData("data1") === Shape(100, 784))
- assert(provideData("data2") === Shape(100, 784))
- assert(provideLabel("label1") === Shape(100))
- assert(provideLabel("label2") === Shape(100))
+ val provideData = prefetchIter.provideDataDesc
+ val provideLabel = prefetchIter.provideLabelDesc
+ assert(provideData.find(_.name == "data1").get.shape === Shape(100, 784))
+ assert(provideData.find(_.name == "data2").get.shape === Shape(100, 784))
+ assert(provideLabel.find(_.name == "label1").get.shape === Shape(100))
+ assert(provideLabel.find(_.name == "label2").get.shape === Shape(100))
// test reset
prefetchIter.reset()
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index 88e314e2a72c..3e753a18d247 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -253,8 +253,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
// create module
val mod = new Module(x, contexts = Array(Context.cpu()))
- mod.bind(dataShapes = trainData.provideData,
- Option(trainData.provideLabel))
+ mod.bind(dataShapes = trainData.provideDataDesc,
+ Option(trainData.provideLabelDesc))
mod.installMonitor(mon)
val argParams = Map(
"fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 2)),
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/train/ConvSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/train/ConvSuite.scala
index 44f57c0d1162..eb6d5b7d7175 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/train/ConvSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/train/ConvSuite.scala
@@ -23,6 +23,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
import scala.collection.mutable.ListBuffer
+import scala.language.postfixOps
import scala.sys.process._
class ConvSuite extends FunSuite with BeforeAndAfterAll {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala
index 7745043b23d8..d9902e9dcc75 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala
@@ -188,7 +188,7 @@ object CNNTextClassification {
// decay learning rate
if (iter % 50 == 0 && iter > 0) {
factor *= 0.5f
- opt.setLrScale(paramBlocks.map(_._1 -> factor).toMap)
+ opt.setLrMult(paramBlocks.map(paramBlock => (Left(paramBlock._1), factor)).toMap)
logger.info(s"reset learning to ${opt.learningRate * factor}")
}
// end of training loop
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala
index df79f5b63769..0cfcc49aee04 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOp.scala
@@ -107,7 +107,7 @@ object ExampleCustomOp {
val (trainIter, testIter) =
Data.mnistIterator(dataPath, batchSize = 100, inputShape = Shape(784))
- val datasAndLabels = trainIter.provideData ++ trainIter.provideLabel
+ val datasAndLabels = trainIter.provideDataDesc ++ trainIter.provideLabelDesc
val (argShapes, outputShapes, auxShapes) = mlp.inferShape(datasAndLabels)
val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOpWithRtc.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOpWithRtc.scala
index c3ac347353df..7b0fb349373d 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOpWithRtc.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/ExampleCustomOpWithRtc.scala
@@ -128,7 +128,7 @@ object ExampleCustomOpWithRtc {
val (trainIter, testIter) =
Data.mnistIterator(dataPath, batchSize = 100, inputShape = Shape(784))
- val datasAndLabels = trainIter.provideData ++ trainIter.provideLabel
+ val datasAndLabels = trainIter.provideDataDesc ++ trainIter.provideLabelDesc
val (argShapes, outputShapes, auxShapes) = mlp.inferShape(datasAndLabels)
val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
index e4d3b2ae7c3e..4d22b62bea81 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
@@ -54,7 +54,7 @@ class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape: List[In
override def next(): DataBatch = {
if (hasNext) {
curIter += batchSize
- new DataBatch(data, label, getIndex, getPad, null, null, null)
+ new DataBatch(data, label, getIndex, getPad)
} else {
throw new NoSuchElementException
}
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/MnistMlp.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/MnistMlp.scala
index 4d450c60456b..839f6ac85902 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/MnistMlp.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/MnistMlp.scala
@@ -49,7 +49,7 @@ object MnistMlp {
logger.info("Load checkpoint from epoch {}", loadModelEpoch)
Module.loadCheckpoint("model/mnist_mlp", loadModelEpoch, loadOptimizerStates = true)
}
- mod.bind(dataShapes = train.provideData, labelShapes = Some(train.provideLabel))
+ mod.bind(dataShapes = train.provideDataDesc, labelShapes = Some(train.provideLabelDesc))
mod.initParams()
mod.initOptimizer(optimizer = new SGD(learningRate = 0.01f, momentum = 0.9f))
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/SequentialModuleEx.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/SequentialModuleEx.scala
index ff616a57b1aa..ea2273ebd796 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/SequentialModuleEx.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/module/SequentialModuleEx.scala
@@ -57,7 +57,7 @@ object SequentialModuleEx {
cmdLine: SequentialModuleEx): Unit = {
// Intermediate-level API
val modSeq = getSeqModule()
- modSeq.bind(dataShapes = train.provideData, labelShapes = Some(train.provideLabel))
+ modSeq.bind(dataShapes = train.provideDataDesc, labelShapes = Some(train.provideLabelDesc))
if (cmdLine.loadModelPath != null) {
logger.info(s"Load checkpoint from ${cmdLine.loadModelPath}")
modSeq.loadParams(cmdLine.loadModelPath)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
index bfde55831e26..5c17a3747ab6 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
@@ -31,6 +31,7 @@ import org.apache.mxnet.optimizer.RMSProp
import org.apache.mxnetexamples.Util
import scala.collection.immutable.ListMap
+import scala.language.postfixOps
import scala.sys.process.Process
/**
@@ -65,7 +66,7 @@ object ExampleMultiTask {
new DataBatch(batch.data,
IndexedSeq(label, label),
batch.index,
- batch.pad, null, null, null)
+ batch.pad)
} else {
throw new NoSuchElementException
}
@@ -100,7 +101,7 @@ object ExampleMultiTask {
override def getIndex(): IndexedSeq[Long] = this.dataIter.getIndex()
// The name and shape of label provided by this iterator
- @deprecated
+ @deprecated("Use provideLabelDesc instead", "1.3.0")
override def provideLabel: ListMap[String, Shape] = {
val provideLabel = this.dataIter.provideLabel.toArray
// Different labels should be used here for actual application
@@ -126,7 +127,7 @@ object ExampleMultiTask {
override def getPad(): Int = this.dataIter.getPad()
// The name and shape of data provided by this iterator
- @deprecated
+ @deprecated("Use provideDataDesc instead", "1.3.0")
override def provideData: ListMap[String, Shape] = this.dataIter.provideData
override def provideDataDesc: IndexedSeq[DataDesc] = this.dataIter.provideDataDesc
@@ -229,10 +230,10 @@ object ExampleMultiTask {
val trainMultiIt = new MultiMnistIterator(trainIter)
val valMultiIter = new MultiMnistIterator(valIter)
- val datasAndLabels = trainMultiIt.provideData ++ trainMultiIt.provideLabel
+ val datasAndLabels = trainMultiIt.provideDataDesc ++ trainMultiIt.provideLabelDesc
val (argShapes, outputShapes, auxShapes)
- = network.inferShape(trainMultiIt.provideData("data"))
+ = network.inferShape(trainMultiIt.provideDataDesc.filter(_.name == "data"))
val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
val argNames = network.listArguments
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala
index 1767cabcbae4..475e179f819b 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala
@@ -39,7 +39,7 @@ object NeuralStyle {
private val logger = LoggerFactory.getLogger(classOf[NeuralStyle])
def preprocessContentImage(path: String, longEdge: Int, ctx: Context): NDArray = {
- val img = Image(new File(path))
+ val img = Image.fromFile(new File(path))
logger.info(s"load the content image, size = ${(img.height, img.width)}")
val factor = longEdge.toFloat / Math.max(img.height, img.width)
val (newHeight, newWidth) = ((img.height * factor).toInt, (img.width * factor).toInt)
@@ -60,7 +60,7 @@ object NeuralStyle {
}
def preprocessStyleImage(path: String, shape: Shape, ctx: Context): NDArray = {
- val img = Image(new File(path))
+ val img = Image.fromFile(new File(path))
val resizedImg = img.scaleTo(shape(3), shape(2))
val sample = NDArray.empty(Shape(1, 3, shape(2), shape(3)), ctx)
val datas = {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/DataProcessing.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/DataProcessing.scala
index 80a009ea40c2..5b01d2016467 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/DataProcessing.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/DataProcessing.scala
@@ -29,7 +29,7 @@ object DataProcessing {
def preprocessContentImage(path: String,
dShape: Shape = null, ctx: Context): NDArray = {
- val img = Image(new File(path))
+ val img = Image.fromFile(new File(path))
val resizedImg = img.scaleTo(dShape(3), dShape(2))
val sample = NDArray.empty(Shape(1, 3, resizedImg.height, resizedImg.width), ctx)
val datas = {
@@ -46,7 +46,7 @@ object DataProcessing {
}
def preprocessStyleImage(path: String, shape: Shape, ctx: Context): NDArray = {
- val img = Image(new File(path))
+ val img = Image.fromFile(new File(path))
val resizedImg = img.scaleTo(shape(3), shape(2))
val sample = NDArray.empty(Shape(1, 3, shape(2), shape(3)), ctx)
val datas = {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
index 350e28cf8634..2648f9e3d6bb 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
@@ -251,11 +251,11 @@ object BucketIo {
override def getPad(): Int = 0
// The name and shape of label provided by this iterator
- @deprecated
+ @deprecated("Use provideLabelDesc instead", "1.3.0")
override def provideLabel: ListMap[String, Shape] = this._provideLabel
// The name and shape of data provided by this iterator
- @deprecated
+ @deprecated("Use provideDataDesc instead", "1.3.0")
override def provideData: ListMap[String, Shape] = this._provideData
// Provide type:DataDesc of the data
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
index c90b7637b9b1..68346afe1f47 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/TrainCharRnn.scala
@@ -75,7 +75,7 @@ object TrainCharRnn {
// the network symbol
val symbol = symGen(buckets(0))
- val datasAndLabels = dataTrain.provideData ++ dataTrain.provideLabel
+ val datasAndLabels = dataTrain.provideDataDesc ++ dataTrain.provideLabelDesc
val (argShapes, outputShapes, auxShapes) = symbol.inferShape(datasAndLabels)
val initializer = new Xavier(factorType = "in", magnitude = 2.34f)
@@ -85,12 +85,13 @@ object TrainCharRnn {
val auxNames = symbol.listAuxiliaryStates()
val auxDict = auxNames.zip(auxShapes.map(NDArray.zeros(_, ctx))).toMap
+ val datasAndLabelsNames = datasAndLabels.map(_.name)
val gradDict = argNames.zip(argShapes).filter { case (name, shape) =>
- !datasAndLabels.contains(name)
+ !datasAndLabelsNames.contains(name)
}.map(x => x._1 -> NDArray.empty(x._2, ctx)).toMap
argDict.foreach { case (name, ndArray) =>
- if (!datasAndLabels.contains(name)) {
+ if (!datasAndLabelsNames.contains(name)) {
initializer.initWeight(name, ndArray)
}
}
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmarkSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmarkSuite.scala
index 0b7f4693c5fa..548f2e4122e0 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmarkSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/benchmark/ScalaInferenceBenchmarkSuite.scala
@@ -22,6 +22,7 @@ import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
+import scala.language.postfixOps
import scala.sys.process.Process
class ScalaInferenceBenchmarkSuite extends FunSuite with BeforeAndAfterAll {
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala
index a7e36dfc3a11..ae0ee33002d9 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala
@@ -26,6 +26,7 @@ import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
+import scala.language.postfixOps
import scala.sys.process.Process
/**
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/customop/CustomOpExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/customop/CustomOpExampleSuite.scala
index 6385e062a260..b65f237c8621 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/customop/CustomOpExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/customop/CustomOpExampleSuite.scala
@@ -25,6 +25,7 @@ import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
+import scala.language.postfixOps
import scala.sys.process.Process
class CustomOpExampleSuite extends FunSuite with BeforeAndAfterAll {
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
index 59faba9a3779..709ea77632e0 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
@@ -24,6 +24,7 @@ import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
import org.slf4j.LoggerFactory
+import scala.language.postfixOps
import scala.sys.process.Process
class GanExampleSuite extends FunSuite with BeforeAndAfterAll{
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
index 0daba5a97d77..e6f4f6fcc908 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
@@ -24,6 +24,7 @@ import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
+import scala.language.postfixOps
import scala.sys.process.Process
/**
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
index c5308ac37512..9c16aca420ef 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
@@ -23,6 +23,7 @@ import java.io.File
import org.apache.mxnet.Context
import org.apache.mxnetexamples.Util
+import scala.language.postfixOps
import sys.process.Process
/**
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
index bd960bddebf5..918fb835f76e 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
@@ -23,6 +23,7 @@ import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
+import scala.language.postfixOps
import scala.sys.process.Process
class ObjectDetectorExampleSuite extends FunSuite with BeforeAndAfterAll {
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
index 71c2b35ef444..c93a7d06a452 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
@@ -23,6 +23,7 @@ import org.apache.mxnetexamples.neuralstyle.end2end.{BoostInference, BoostTrain}
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
+import scala.language.postfixOps
import scala.sys.process.Process
/**
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
index 2ccd38fc4f9c..ff3fbe9e05d2 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/rnn/ExampleRNNSuite.scala
@@ -23,6 +23,7 @@ import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
import org.slf4j.LoggerFactory
+import scala.language.postfixOps
import scala.sys.process.Process
class ExampleRNNSuite extends FunSuite with BeforeAndAfterAll {
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
index d9ccec468791..11d418002744 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
@@ -142,7 +142,7 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll {
val result: IndexedSeq[(String, Double)] = testClassifier.
classify(IndexedSeq(inputData), topK = Some(10))
- assert((result(0)_2).getClass == 1d.getClass)
+ assert((result(0)._2).getClass == 1d.getClass)
assertResult(predictResult(0).sortBy(-_)) {
result.map(_._2).toArray
@@ -185,7 +185,7 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll {
val result: IndexedSeq[(String, Double)] = testClassifier.
classify(IndexedSeq(inputData))
- assert((result(0)_2).getClass == 1d.getClass)
+ assert((result(0)._2).getClass == 1d.getClass)
assertResult(predictResult(0)) {
result.map(_._2).toArray
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
index b2033f529c65..a2b8633f07d7 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
@@ -141,7 +141,7 @@ private[mxnet] abstract class GeneratorBase {
throw new IllegalArgumentException(s"Invalid macro input: $ex")
}
// wrap the result up in an Expr, and return it
- val result = c.Expr(Block(modDefs, Literal(Constant())))
+ val result = c.Expr(Block(modDefs, Literal(Constant(()))))
result
}
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 44b784a56f8e..46726bc00472 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -357,6 +357,10 @@
2.1.0
+
+ -feature
+ -deprecation
+
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
index bf1b26e4b48d..44d6f3345bdf 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
@@ -115,13 +115,13 @@ class LabeledPointIter private[mxnet](
}
// The name and shape of label provided by this iterator
- @deprecated
+ @deprecated("Use provideLabelDesc instead", "1.3.0")
override def provideLabel: ListMap[String, Shape] = {
ListMap(labelName -> Shape(_batchSize))
}
// The name and shape of data provided by this iterator
- @deprecated
+ @deprecated("Use provideDataDesc instead", "1.3.0")
override def provideData: ListMap[String, Shape] = {
ListMap(dataName -> dataShape)
}
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
index e3272a4066b5..abf82f6e510c 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
@@ -28,8 +28,7 @@ class LongLivingDataBatch(
override val data: IndexedSeq[NDArray],
override val label: IndexedSeq[NDArray],
override val index: IndexedSeq[Long],
- override val pad: Int) extends DataBatch(data, label, index, pad,
- null, null, null) {
+ override val pad: Int) extends DataBatch(data, label, index, pad) {
override def dispose(): Unit = {}
def disposeForce(): Unit = super.dispose()
}
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
index a955ee74e7e2..1ca23927e123 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
@@ -114,13 +114,13 @@ class PointIter private[mxnet](
}
// The name and shape of label provided by this iterator
- @deprecated
+ @deprecated("Use provideLabelDesc instead", "1.3.0")
override def provideLabel: ListMap[String, Shape] = {
ListMap(labelName -> Shape(_batchSize))
}
// The name and shape of data provided by this iterator
- @deprecated
+ @deprecated("Use provideDataDesc instead", "1.3.0")
override def provideData: ListMap[String, Shape] = {
ListMap(dataName -> dataShape)
}
diff --git a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala
index 7a2417bdea1f..2382ca9fa358 100644
--- a/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala
+++ b/scala-package/spark/src/test/scala/org/apache/mxnet/spark/MXNetGeneralSuite.scala
@@ -20,6 +20,7 @@ package org.apache.mxnet.spark
import java.io.{BufferedReader, File, InputStreamReader}
import java.nio.file.Files
+import scala.language.postfixOps
import scala.sys.process.Process
import org.apache.spark.SparkContext