diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala index 39da3811f883..7fbdae5b3e21 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala @@ -247,23 +247,23 @@ abstract class BaseModule { /** * Run prediction and collect the outputs. - * The concatenation process will be like - * {{{ - * outputBatches = [ - * [a1, a2, a3], // batch a - * [b1, b2, b3] // batch b - * ] - * result = [ - * NDArray, // [a1, b1] - * NDArray, // [a2, b2] - * NDArray, // [a3, b3] - * ] - * }}} - * @param evalData + * @param evalData dataIter to do the Inference * @param numBatch Default is -1, indicating running all the batches in the data iterator. * @param reset Default is `True`, indicating whether we should reset the data iter before start * doing prediction. * @return The return value will be a list `[out1, out2, out3]`. + * The concatenation process will be like + * {{{ + * outputBatches = [ + * [a1, a2, a3], // batch a + * [b1, b2, b3] // batch b + * ] + * result = [ + * NDArray, // [a1, b1] + * NDArray, // [a2, b2] + * NDArray, // [a3, b3] + * ] + * }}} * Where each element is concatenation of the outputs for all the mini-batches. */ def predict(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true)