From 614ad53755300ab11e739b9dca3af54f49fa5850 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 25 Apr 2019 15:53:06 -0700 Subject: [PATCH] add fix in the code --- .../org/apache/mxnet/module/BaseModule.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala index 3be8e060fd6f..39da3811f883 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala @@ -247,6 +247,18 @@ abstract class BaseModule { /** * Run prediction and collect the outputs. + * The concatenation process will be like + * {{{ + * outputBatches = [ + * [a1, a2, a3], // batch a + * [b1, b2, b3] // batch b + * ] + * result = [ + * NDArray, // [a1, b1] + * NDArray, // [a2, b2] + * NDArray, // [a3, b3] + * ] + * }}} * @param evalData * @param numBatch Default is -1, indicating running all the batches in the data iterator. * @param reset Default is `True`, indicating whether we should reset the data iter before start @@ -264,7 +276,8 @@ abstract class BaseModule { s"in mini-batches (${out.size})." + "Maybe bucketing is used?") ) - val concatenatedOutput = outputBatches.map(out => NDArray.concatenate(out)) + val oBT = outputBatches.transpose + val concatenatedOutput = oBT.map(out => NDArray.concatenate(out)) outputBatches.foreach(_.foreach(_.dispose())) concatenatedOutput }