From ed7ca26a23881f9b7474fbcb12a576c2b544bee6 Mon Sep 17 00:00:00 2001
From: Piyush Ghai
Date: Thu, 10 Jan 2019 15:12:43 -0800
Subject: [PATCH] [MXNET-1260] Float64 DType computation support in Scala/Java
(#13678)
* Added Float64 as a supported datatype in NDArray
* Added unit tests for Float64 in NDArray
* Fix for failing Clojure unit tests
* Added Float and Double as MX_PRIMITIVES for computation in Scala
* Trying out second approach --> Private Impl methods with generic signature, and public methods calling the Impls
* Fixed errors in *= method
* Added Float64 in IO.scala and DataIter.scala
* Added another testcase for IO.DataDesc creation
* Fixed failing CI
* Added Float64 in Predictor class
* Added Float64 in Classifier class
* Added Double as a possible return type to : classifyWithNDArray
* Added unit tests for Classifier and Predictor.scala classes for Float64/Double
* Approach 3 --> Using a trait to mirror Float and Double in Scala
* Added comments on MX_PRIMITIVES.scala
* Added Float64/Double support for inference in ImageClassifier APIs
* Added unary- and compareTo in MX_NUMBER_LIKE
* Renamed MX_NUMBER_LIKE to MX_PRIMITIVE_TYPE
* Fixed linting issue
* Now specifying dType from the available data in copyTo and MXDataIter.scala for creating a new DataIterator
* Add primitives support handling to the generator for proper conversion
* Reduced code duplication in classify method in Classifier.scala
* Fix infer package for new signatures and address some bugs
* Removed code duplication in getPixelsArray
* remove debugging
* Changed classifyWithNDArray method in Classifier.scala
* Removed code duplication in predictImpl
* Satisfying lint god _/\_
* Fixed failing PredictorSuite test
* Renamed MX_FLOAT to Camel case
* Revert "Renamed MX_FLOAT to Camel case"
This reverts commit 9d7c3ce6f9c4d6ed2c46041a02e23c0f1df8dfe5.
* Added an implicit conversion from int--> float to support int operations in NDArrays. (These ops were already supported in the previous versions)
* Added Float64 as a training option to ImClassification Suite. Also added integration tests for it
* Satisfy Lint God _/\_
* Added Float64 support in Java NDArray
* Added Float64 support in Java's Predictor API
* Added yours truly to the Contributors list
* Added method comments on Predictor.predict with Array[Double] as a possible input
* Added method comments explaining what MX_PRIMITIVE_TYPE is
* Fixed errors cause by rebasing with master
* Added licences to the files
---
CONTRIBUTORS.md | 1 +
.../src/org/apache/clojure_mxnet/infer.clj | 242 ++++++-----
.../org/apache/clojure_mxnet/primitives.clj | 46 ++
.../src/org/apache/clojure_mxnet/util.clj | 7 +-
.../test/good-test-ndarray.clj | 7 +-
.../infer/imageclassifier_test.clj | 12 +-
.../infer/objectdetector_test.clj | 4 +
.../org/apache/clojure_mxnet/ndarray_test.clj | 2 +-
.../apache/clojure_mxnet/primitives_test.clj | 45 ++
.../org/apache/clojure_mxnet/util_test.clj | 10 +
.../main/scala/org/apache/mxnet/Base.scala | 7 +-
.../main/scala/org/apache/mxnet/LibInfo.scala | 3 +
.../org/apache/mxnet/MX_PRIMITIVES.scala | 85 ++++
.../main/scala/org/apache/mxnet/NDArray.scala | 230 ++++++++--
.../org/apache/mxnet/io/MXDataIter.scala | 6 +-
.../org/apache/mxnet/io/NDArrayIter.scala | 7 +-
.../org/apache/mxnet/javaapi/NDArray.scala | 65 +++
.../org/apache/mxnet/javaapi/NDArrayTest.java | 15 +
.../test/scala/org/apache/mxnet/IOSuite.scala | 27 ++
.../scala/org/apache/mxnet/NDArraySuite.scala | 396 +++++++++++++++---
.../imclassification/TrainModel.scala | 24 +-
.../datasets/SyntheticDataIter.scala | 8 +-
.../imclassification/models/Lenet.scala | 5 +-
.../models/MultiLayerPerceptron.scala | 5 +-
.../imclassification/models/Resnet.scala | 16 +-
.../IMClassificationExampleSuite.scala | 10 +-
.../org/apache/mxnet/infer/Classifier.scala | 39 +-
.../apache/mxnet/infer/ImageClassifier.scala | 48 ++-
.../org/apache/mxnet/infer/Predictor.scala | 46 +-
.../mxnet/infer/javaapi/Predictor.scala | 24 ++
.../apache/mxnet/infer/ClassifierSuite.scala | 47 ++-
.../mxnet/infer/ImageClassifierSuite.scala | 7 +
.../apache/mxnet/infer/PredictorSuite.scala | 32 +-
.../native/org_apache_mxnet_native_c_api.cc | 9 +
.../native/org_apache_mxnet_native_c_api.h | 8 +
35 files changed, 1251 insertions(+), 294 deletions(-)
create mode 100644 contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj
create mode 100644 contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj
create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index b9f84d592a70..5b5fdce712f1 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -193,6 +193,7 @@ List of Contributors
* [Yuxi Hu](/~https://github.com/yuxihu)
* [Harsh Patel](/~https://github.com/harshp8l)
* [Xiao Wang](/~https://github.com/BeyonderXX)
+* [Piyush Ghai](/~https://github.com/piyushghai)
Label Bot
---------
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
index b2b23da6274e..224a39275dac 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
@@ -18,6 +18,7 @@
(ns org.apache.clojure-mxnet.infer
(:refer-clojure :exclude [type])
(:require [org.apache.clojure-mxnet.context :as context]
+ [org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.shape :as shape]
[org.apache.clojure-mxnet.util :as util]
@@ -62,10 +63,12 @@
(defprotocol AImageClassifier
(classify-image
[wrapped-image-classifier image]
- [wrapped-image-classifier image topk])
+ [wrapped-image-classifier image topk]
+ [wrapped-image-classifier image topk dtype])
(classify-image-batch
[wrapped-image-classifier images]
- [wrapped-image-classifier images topk]))
+ [wrapped-image-classifier images topk]
+ [wrapped-image-classifier images topk dtype]))
(defprotocol AObjectDetector
(detect-objects
@@ -80,7 +83,8 @@
(extend-protocol APredictor
WrappedPredictor
- (predict [wrapped-predictor inputs]
+ (predict
+ [wrapped-predictor inputs]
(util/validate! ::wrapped-predictor wrapped-predictor
"Invalid predictor")
(util/validate! ::vec-of-float-arrays inputs
@@ -101,62 +105,50 @@
(extend-protocol AClassifier
WrappedClassifier
- (classify [wrapped-classifier inputs]
- (util/validate! ::wrapped-classifier wrapped-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-float-arrays inputs
- "Invalid inputs")
- (classify wrapped-classifier inputs nil))
- (classify [wrapped-classifier inputs topk]
- (util/validate! ::wrapped-classifier wrapped-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-float-arrays inputs
- "Invalid inputs")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classify (:classifier wrapped-classifier)
- (util/vec->indexed-seq inputs)
- (util/->int-option topk))))
- (classify-with-ndarray [wrapped-classifier inputs]
- (util/validate! ::wrapped-classifier wrapped-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-ndarrays inputs
- "Invalid inputs")
- (classify-with-ndarray wrapped-classifier inputs nil))
- (classify-with-ndarray [wrapped-classifier inputs topk]
- (util/validate! ::wrapped-classifier wrapped-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-ndarrays inputs
- "Invalid inputs")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classifyWithNDArray (:classifier wrapped-classifier)
- (util/vec->indexed-seq inputs)
- (util/->int-option topk))))
+ (classify
+ ([wrapped-classifier inputs]
+ (classify wrapped-classifier inputs nil))
+ ([wrapped-classifier inputs topk]
+ (util/validate! ::wrapped-classifier wrapped-classifier
+ "Invalid classifier")
+ (util/validate! ::vec-of-float-arrays inputs
+ "Invalid inputs")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/coerce-return-recursive
+ (.classify (:classifier wrapped-classifier)
+ (util/vec->indexed-seq inputs)
+ (util/->int-option topk)))))
+ (classify-with-ndarray
+ ([wrapped-classifier inputs]
+ (classify-with-ndarray wrapped-classifier inputs nil))
+ ([wrapped-classifier inputs topk]
+ (util/validate! ::wrapped-classifier wrapped-classifier
+ "Invalid classifier")
+ (util/validate! ::vec-of-ndarrays inputs
+ "Invalid inputs")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/coerce-return-recursive
+ (.classifyWithNDArray (:classifier wrapped-classifier)
+ (util/vec->indexed-seq inputs)
+ (util/->int-option topk)))))
WrappedImageClassifier
- (classify [wrapped-image-classifier inputs]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-float-arrays inputs
- "Invalid inputs")
- (classify wrapped-image-classifier inputs nil))
- (classify [wrapped-image-classifier inputs topk]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-float-arrays inputs
- "Invalid inputs")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classify (:image-classifier wrapped-image-classifier)
- (util/vec->indexed-seq inputs)
- (util/->int-option topk))))
- (classify-with-ndarray [wrapped-image-classifier inputs]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-ndarrays inputs
- "Invalid inputs")
- (classify-with-ndarray wrapped-image-classifier inputs nil))
- (classify-with-ndarray [wrapped-image-classifier inputs topk]
+ (classify
+ ([wrapped-image-classifier inputs]
+ (classify wrapped-image-classifier inputs nil))
+ ([wrapped-image-classifier inputs topk]
+ (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+ "Invalid classifier")
+ (util/validate! ::vec-of-float-arrays inputs
+ "Invalid inputs")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/coerce-return-recursive
+ (.classify (:image-classifier wrapped-image-classifier)
+ (util/vec->indexed-seq inputs)
+ (util/->int-option topk)))))
+ (classify-with-ndarray
+ ([wrapped-image-classifier inputs]
+ (classify-with-ndarray wrapped-image-classifier inputs nil))
+ ([wrapped-image-classifier inputs topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::vec-of-ndarrays inputs
@@ -165,83 +157,83 @@
(util/coerce-return-recursive
(.classifyWithNDArray (:image-classifier wrapped-image-classifier)
(util/vec->indexed-seq inputs)
- (util/->int-option topk)))))
+ (util/->int-option topk))))))
(s/def ::image #(instance? BufferedImage %))
+(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})
(extend-protocol AImageClassifier
WrappedImageClassifier
- (classify-image [wrapped-image-classifier image]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (util/validate! ::image image "Invalid image")
- (classify-image wrapped-image-classifier image nil))
- (classify-image [wrapped-image-classifier image topk]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (util/validate! ::image image "Invalid image")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classifyImage (:image-classifier wrapped-image-classifier)
- image
- (util/->int-option topk))))
- (classify-image-batch [wrapped-image-classifier images]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (classify-image-batch wrapped-image-classifier images nil))
- (classify-image-batch [wrapped-image-classifier images topk]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classifyImageBatch (:image-classifier wrapped-image-classifier)
- images
- (util/->int-option topk)))))
+ (classify-image
+ ([wrapped-image-classifier image]
+ (classify-image wrapped-image-classifier image nil dtype/FLOAT32))
+ ([wrapped-image-classifier image topk]
+ (classify-image wrapped-image-classifier image topk dtype/FLOAT32))
+ ([wrapped-image-classifier image topk dtype]
+ (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+ "Invalid classifier")
+ (util/validate! ::image image "Invalid image")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/validate! ::dtype dtype "Invalid dtype")
+ (util/coerce-return-recursive
+ (.classifyImage (:image-classifier wrapped-image-classifier)
+ image
+ (util/->int-option topk)
+ dtype))))
+ (classify-image-batch
+ ([wrapped-image-classifier images]
+ (classify-image-batch wrapped-image-classifier images nil dtype/FLOAT32))
+ ([wrapped-image-classifier images topk]
+ (classify-image-batch wrapped-image-classifier images topk dtype/FLOAT32))
+ ([wrapped-image-classifier images topk dtype]
+ (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+ "Invalid classifier")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/validate! ::dtype dtype "Invalid dtype")
+ (util/coerce-return-recursive
+ (.classifyImageBatch (:image-classifier wrapped-image-classifier)
+ images
+ (util/->int-option topk)
+ dtype)))))
(extend-protocol AObjectDetector
WrappedObjectDetector
- (detect-objects [wrapped-detector image]
- (util/validate! ::wrapped-detector wrapped-detector
- "Invalid object detector")
- (util/validate! ::image image "Invalid image")
- (detect-objects wrapped-detector image nil))
- (detect-objects [wrapped-detector image topk]
- (util/validate! ::wrapped-detector wrapped-detector
- "Invalid object detector")
- (util/validate! ::image image "Invalid image")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.imageObjectDetect (:object-detector wrapped-detector)
- image
- (util/->int-option topk))))
- (detect-objects-batch [wrapped-detector images]
- (util/validate! ::wrapped-detector wrapped-detector
- "Invalid object detector")
- (detect-objects-batch wrapped-detector images nil))
- (detect-objects-batch [wrapped-detector images topk]
- (util/validate! ::wrapped-detector wrapped-detector
- "Invalid object detector")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.imageBatchObjectDetect (:object-detector wrapped-detector)
- images
- (util/->int-option topk))))
- (detect-objects-with-ndarrays [wrapped-detector input-arrays]
- (util/validate! ::wrapped-detector wrapped-detector
- "Invalid object detector")
- (util/validate! ::vec-of-ndarrays input-arrays
- "Invalid inputs")
- (detect-objects-with-ndarrays wrapped-detector input-arrays nil))
- (detect-objects-with-ndarrays [wrapped-detector input-arrays topk]
+ (detect-objects
+ ([wrapped-detector image]
+ (detect-objects wrapped-detector image nil))
+ ([wrapped-detector image topk]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
- (util/validate! ::vec-of-ndarrays input-arrays
- "Invalid inputs")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.objectDetectWithNDArray (:object-detector wrapped-detector)
- (util/vec->indexed-seq input-arrays)
+ (util/validate! ::image image "Invalid image")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/coerce-return-recursive
+ (.imageObjectDetect (:object-detector wrapped-detector)
+ image
+ (util/->int-option topk)))))
+ (detect-objects-batch
+ ([wrapped-detector images]
+ (detect-objects-batch wrapped-detector images nil))
+ ([wrapped-detector images topk]
+ (util/validate! ::wrapped-detector wrapped-detector
+ "Invalid object detector")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/coerce-return-recursive
+ (.imageBatchObjectDetect (:object-detector wrapped-detector)
+ images
(util/->int-option topk)))))
+ (detect-objects-with-ndarrays
+ ([wrapped-detector input-arrays]
+ (detect-objects-with-ndarrays wrapped-detector input-arrays nil))
+ ([wrapped-detector input-arrays topk]
+ (util/validate! ::wrapped-detector wrapped-detector
+ "Invalid object detector")
+ (util/validate! ::vec-of-ndarrays input-arrays
+ "Invalid inputs")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/coerce-return-recursive
+ (.objectDetectWithNDArray (:object-detector wrapped-detector)
+ (util/vec->indexed-seq input-arrays)
+ (util/->int-option topk))))))
(defprotocol AInferenceFactory
(create-predictor [factory] [factory opts])
@@ -335,7 +327,7 @@
[image input-shape-vec]
(util/validate! ::image image "Invalid image")
(util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector")
- (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec)))
+ (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) dtype/FLOAT32))
(s/def ::image-path string?)
(s/def ::image-paths (s/coll-of ::image-path))
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj
new file mode 100644
index 000000000000..0967df2289d8
--- /dev/null
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj
@@ -0,0 +1,46 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.primitives
+ (:import (org.apache.mxnet MX_PRIMITIVES$MX_FLOAT MX_PRIMITIVES$MX_Double
+ MX_PRIMITIVES$MX_PRIMITIVE_TYPE)))
+
+
+;;; Defines customer mx primitives that can be used for mathematical computations
+;;; in NDArrays to control precision. Currently Float and Double are supported
+
+;;; For purposes of automatic conversion in ndarray functions, doubles are default
+;; to specify using floats you must use a Float
+
+(defn mx-float
+ "Creates a MXNet float primitive"
+ [num]
+ (new MX_PRIMITIVES$MX_FLOAT num))
+
+(defn mx-double
+ "Creates a MXNet double primitive"
+ [num]
+ (new MX_PRIMITIVES$MX_Double num))
+
+(defn ->num
+ "Returns the underlying number value"
+ [primitive]
+ (.data primitive))
+
+(defn primitive? [x]
+ (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE x))
+
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
index 21e31baa3a9b..43970c0abd79 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
@@ -19,6 +19,7 @@
(:require [clojure.spec.alpha :as s]
[t6.from-scala.core :refer [$ $$] :as $]
[clojure.string :as string]
+ [org.apache.clojure-mxnet.primitives :as primitives]
[org.apache.clojure-mxnet.shape :as mx-shape])
(:import (org.apache.mxnet NDArray)
(scala Product Tuple2 Tuple3)
@@ -36,7 +37,8 @@
"byte<>" "byte-array"
"java.lang.String<>" "vec-or-strings"
"org.apache.mxnet.NDArray" "ndarray"
- "org.apache.mxnet.Symbol" "sym"})
+ "org.apache.mxnet.Symbol" "sym"
+ "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE" "double-or-float"})
(def symbol-param-coerce {"java.lang.String" "sym-name"
"float" "num"
@@ -144,6 +146,8 @@
(and (get targets "int<>") (vector? param)) (int-array param)
(and (get targets "float<>") (vector? param)) (float-array param)
(and (get targets "java.lang.String<>") (vector? param)) (into-array param)
+ (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (instance? Float param)) (primitives/mx-float param)
+ (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (number? param)) (primitives/mx-double param)
:else param))
(defn nil-or-coerce-param [param targets]
@@ -177,6 +181,7 @@
(instance? Map return-val) (scala-map->map return-val)
(instance? Tuple2 return-val) (tuple->vec return-val)
(instance? Tuple3 return-val) (tuple->vec return-val)
+ (primitives/primitive? return-val) (primitives/->num return-val)
:else return-val))
(defn coerce-return-recursive [return-val]
diff --git a/contrib/clojure-package/test/good-test-ndarray.clj b/contrib/clojure-package/test/good-test-ndarray.clj
index 3b53b1906006..b048a819c642 100644
--- a/contrib/clojure-package/test/good-test-ndarray.clj
+++ b/contrib/clojure-package/test/good-test-ndarray.clj
@@ -27,11 +27,12 @@
(defn
div
- ([ndarray num-or-ndarray]
+ ([ndarray ndarray-or-double-or-float]
(util/coerce-return
(.$div
ndarray
(util/coerce-param
- num-or-ndarray
- #{"float" "org.apache.mxnet.NDArray"})))))
+ ndarray-or-double-or-float
+ #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"
+ "org.apache.mxnet.NDArray"})))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
index 9badfed933a5..b459b06132b2 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
@@ -40,7 +40,11 @@
(deftest test-single-classification
(let [classifier (create-classifier)
image (infer/load-image-from-file "test/test-images/kitten.jpg")
- [predictions] (infer/classify-image classifier image 5)]
+ [predictions-all] (infer/classify-image classifier image)
+ [predictions-with-default-dtype] (infer/classify-image classifier image 10)
+ [predictions] (infer/classify-image classifier image 5 dtype/FLOAT32)]
+ (is (= 1000 (count predictions-all)))
+ (is (= 10 (count predictions-with-default-dtype)))
(is (some? predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
@@ -58,8 +62,12 @@
(let [classifier (create-classifier)
image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
"test/test-images/Pug-Cookie.jpg"])
- batch-predictions (infer/classify-image-batch classifier image-batch 5)
+ batch-predictions-all (infer/classify-image-batch classifier image-batch)
+ batch-predictions-with-default-dtype (infer/classify-image-batch classifier image-batch 10)
+ batch-predictions (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32)
predictions (first batch-predictions)]
+ (is (= 1000 (count (first batch-predictions-all))))
+ (is (= 10 (count (first batch-predictions-with-default-dtype))))
(is (some? batch-predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
index 788a59491095..3a0e3d30a1d9 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
@@ -40,9 +40,11 @@
(deftest test-single-detection
(let [detector (create-detector)
image (infer/load-image-from-file "test/test-images/kitten.jpg")
+ [predictions-all] (infer/detect-objects detector image)
[predictions] (infer/detect-objects detector image 5)]
(is (some? predictions))
(is (= 5 (count predictions)))
+ (is (= 13 (count predictions-all)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
(is (every? #(= 5 (count (second %))) predictions))
@@ -53,9 +55,11 @@
(let [detector (create-detector)
image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
"test/test-images/Pug-Cookie.jpg"])
+ batch-predictions-all (infer/detect-objects-batch detector image-batch)
batch-predictions (infer/detect-objects-batch detector image-batch 5)
predictions (first batch-predictions)]
(is (some? batch-predictions))
+ (is (= 13 (count (first batch-predictions-all))))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
index 79e94412d0df..9ffd3abed2f9 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
@@ -97,7 +97,7 @@
(is (= [1.0 1.0] (->vec ndhalves)))))
(deftest test-full
- (let [nda (full [1 2] 3)]
+ (let [nda (full [1 2] 3.0)]
(is (= (shape nda) (mx-shape/->shape [1 2])))
(is (= [3.0 3.0] (->vec nda)))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj
new file mode 100644
index 000000000000..1a538e537b8b
--- /dev/null
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj
@@ -0,0 +1,45 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.primitives-test
+ (:require [org.apache.clojure-mxnet.primitives :as primitives]
+ [clojure.test :refer :all])
+ (:import (org.apache.mxnet MX_PRIMITIVES$MX_PRIMITIVE_TYPE
+ MX_PRIMITIVES$MX_FLOAT
+ MX_PRIMITIVES$MX_Double)))
+
+(deftest test-primitive-types
+ (is (not (primitives/primitive? 3)))
+ (is (primitives/primitive? (primitives/mx-float 3)))
+ (is (primitives/primitive? (primitives/mx-double 3))))
+
+(deftest test-float-primitives
+ (is (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE (primitives/mx-float 3)))
+ (is (instance? MX_PRIMITIVES$MX_FLOAT (primitives/mx-float 3)))
+ (is (instance? Float (-> (primitives/mx-float 3)
+ (primitives/->num))))
+ (is (= 3.0 (-> (primitives/mx-float 3)
+ (primitives/->num)))))
+
+(deftest test-double-primitives
+ (is (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE (primitives/mx-double 2)))
+ (is (instance? MX_PRIMITIVES$MX_Double (primitives/mx-double 2)))
+ (is (instance? Double (-> (primitives/mx-double 2)
+ (primitives/->num))))
+ (is (= 2.0 (-> (primitives/mx-double 2)
+ (primitives/->num)))))
+
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
index bd77a8a0edc6..c26f83d5aa49 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
@@ -20,6 +20,7 @@
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.ndarray :as ndarray]
+ [org.apache.clojure-mxnet.primitives :as primitives]
[org.apache.clojure-mxnet.symbol :as sym]
[org.apache.clojure-mxnet.test-util :as test-util]
[clojure.spec.alpha :as s])
@@ -133,6 +134,9 @@
(is (= "[F" (->> (util/coerce-param [1 2] #{"float<>"}) str (take 2) (apply str))))
(is (= "[L" (->> (util/coerce-param [1 2] #{"java.lang.String<>"}) str (take 2) (apply str))))
+ (is (primitives/primitive? (util/coerce-param 1.0 #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"})))
+ (is (primitives/primitive? (util/coerce-param (float 1.0) #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"})))
+
(is (= 1 (util/coerce-param 1 #{"unknown"}))))
(deftest test-nil-or-coerce-param
@@ -171,6 +175,12 @@
(util/convert-tuple [1 2]))))
(is (= [1 2 3] (util/coerce-return
(util/convert-tuple [1 2 3]))))
+
+ (is (instance? Double (util/coerce-return (primitives/mx-double 3))))
+ (is (= 3.0 (util/coerce-return (primitives/mx-double 3))))
+ (is (instance? Float (util/coerce-return (primitives/mx-float 2))))
+ (is (= 2.0 (util/coerce-return (primitives/mx-float 2))))
+
(is (= "foo" (util/coerce-return "foo"))))
(deftest test-translate-keyword-shape
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
index ed7aff602f63..001bd04d2c95 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
@@ -18,7 +18,9 @@
package org.apache.mxnet
import org.apache.mxnet.util.NativeLibraryLoader
-import org.slf4j.{LoggerFactory, Logger}
+import org.slf4j.{Logger, LoggerFactory}
+
+import scala.Specializable.Group
private[mxnet] object Base {
private val logger: Logger = LoggerFactory.getLogger("MXNetJVM")
@@ -57,6 +59,9 @@ private[mxnet] object Base {
val MX_REAL_TYPE = DType.Float32
+ // The primitives currently supported for NDArray operations
+ val MX_PRIMITIVES = new Group ((Double, Float))
+
try {
try {
tryLoadLibraryOS("mxnet-scala")
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
index 0a5683aa7ab3..20b6ed9fc806 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
@@ -93,6 +93,9 @@ private[mxnet] class LibInfo {
@native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle,
source: Array[MXFloat],
size: Int): Int
+ @native def mxFloat64NDArraySyncCopyFromCPU(handle: NDArrayHandle,
+ source: Array[Double],
+ size: Int): Int
@native def mxNDArrayLoad(fname: String,
outSize: MXUintRef,
handles: ArrayBuffer[NDArrayHandle],
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
new file mode 100644
index 000000000000..cb978856963c
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+object MX_PRIMITIVES {
+
+ /**
+ * This defines the basic primitives we can use in Scala for mathematical
+ * computations in NDArrays.This gives us a flexibility to expand to
+ * more supported primitives in the future. Currently Float and Double
+ * are supported. The functions which accept MX_PRIMITIVE_TYPE as input can also accept
+ * plain old Float and Double data as inputs because of the underlying
+ * implicit conversion between primitives to MX_PRIMITIVE_TYPE.
+ */
+ trait MX_PRIMITIVE_TYPE extends Ordered[MX_PRIMITIVE_TYPE]{
+
+ def toString: String
+
+ def unary_- : MX_PRIMITIVE_TYPE
+ }
+
+ trait MXPrimitiveOrdering extends Ordering[MX_PRIMITIVE_TYPE] {
+
+ def compare(x: MX_PRIMITIVE_TYPE, y: MX_PRIMITIVE_TYPE): Int = x.compare(y)
+
+ }
+
+ implicit object MX_PRIMITIVE_TYPE extends MXPrimitiveOrdering
+
+ /**
+ * Wrapper over Float in Scala.
+ * @param data
+ */
+ class MX_FLOAT(val data: Float) extends MX_PRIMITIVE_TYPE {
+
+ override def toString: String = data.toString
+
+ override def unary_- : MX_PRIMITIVE_TYPE = new MX_FLOAT(data.unary_-)
+
+ override def compare(that: MX_PRIMITIVE_TYPE): Int = {
+ this.data.compareTo(that.asInstanceOf[MX_FLOAT].data)
+ }
+ }
+
+ implicit def FloatToMX_Float(d : Float): MX_FLOAT = new MX_FLOAT(d)
+
+ implicit def MX_FloatToFloat(d: MX_FLOAT) : Float = d.data
+
+ implicit def IntToMX_Float(d: Int): MX_FLOAT = new MX_FLOAT(d.toFloat)
+
+ /**
+ * Wrapper over Double in Scala.
+ * @param data
+ */
+ class MX_Double(val data: Double) extends MX_PRIMITIVE_TYPE {
+
+ override def toString: String = data.toString
+
+ override def unary_- : MX_PRIMITIVE_TYPE = new MX_Double(data.unary_-)
+
+ override def compare(that: MX_PRIMITIVE_TYPE): Int = {
+ this.data.compareTo(that.asInstanceOf[MX_Double].data)
+ }
+ }
+
+ implicit def DoubleToMX_Double(d : Double): MX_Double = new MX_Double(d)
+
+ implicit def MX_DoubleToDouble(d: MX_Double) : Double = d.data
+
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 125958150b72..163ed2682532 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -21,6 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.mxnet.Base._
import org.apache.mxnet.DType.DType
+import org.apache.mxnet.MX_PRIMITIVES.{MX_PRIMITIVE_TYPE}
import org.slf4j.LoggerFactory
import scala.collection.mutable
@@ -262,16 +263,46 @@ object NDArray extends NDArrayBase {
arr
}
- // Perform power operator
+ def full(shape: Shape, value: Double, ctx: Context): NDArray = {
+ val arr = empty(shape, ctx, DType.Float64)
+ arr.set(value)
+ arr
+ }
+
+ /**
+ * Create a new NDArray filled with given value, with specified shape.
+ * @param shape shape of the NDArray.
+ * @param value value to be filled with
+ */
+ def full(shape: Shape, value: Double): NDArray = {
+ full(shape, value, null)
+ }
+
+
+ /**
+ * Perform power operation on NDArray. Returns result as NDArray
+ * @param lhs
+ * @param rhs
+ */
def power(lhs: NDArray, rhs: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_power", Seq(lhs, rhs))
}
- def power(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Perform scalar power operation on NDArray. Returns result as NDArray
+ * @param lhs NDArray on which to perform the operation on.
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def power(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(lhs, rhs))
}
- def power(lhs: Float, rhs: NDArray): NDArray = {
+ /**
+ * Perform scalar power operation on NDArray. Returns result as NDArray
+ * @param lhs The scalar input. Can be of type Float/Double
+ * @param rhs NDArray on which to perform the operation on.
+ */
+ def power(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_rpower_scalar", Seq(lhs, rhs))
}
@@ -280,11 +311,21 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("_maximum", Seq(lhs, rhs))
}
- def maximum(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Perform the max operation on NDArray. Returns the result as NDArray.
+ * @param lhs NDArray on which to perform the operation on.
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def maximum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs))
}
- def maximum(lhs: Float, rhs: NDArray): NDArray = {
+ /**
+ * Perform the max operation on NDArray. Returns the result as NDArray.
+ * @param lhs The scalar input. Can be of type Float/Double
+ * @param rhs NDArray on which to perform the operation on.
+ */
+ def maximum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs))
}
@@ -293,11 +334,21 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("_minimum", Seq(lhs, rhs))
}
- def minimum(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Perform the min operation on NDArray. Returns the result as NDArray.
+ * @param lhs NDArray on which to perform the operation on.
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def minimum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs))
}
- def minimum(lhs: Float, rhs: NDArray): NDArray = {
+ /**
+ * Perform the min operation on NDArray. Returns the result as NDArray.
+ * @param lhs The scalar input. Can be of type Float/Double
+ * @param rhs NDArray on which to perform the operation on.
+ */
+ def minimum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs))
}
@@ -310,7 +361,15 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_equal", Seq(lhs, rhs))
}
- def equal(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Returns the result of element-wise **equal to** (==) comparison operation with broadcasting.
+ * For each element in input arrays, return 1(true) if corresponding elements are same,
+ * otherwise return 0(false).
+ *
+ * @param lhs NDArray
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def equal(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_equal_scalar", Seq(lhs, rhs))
}
@@ -324,7 +383,15 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_not_equal", Seq(lhs, rhs))
}
- def notEqual(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Returns the result of element-wise **not equal to** (!=) comparison operation
+ * with broadcasting.
+ * For each element in input arrays, return 1(true) if corresponding elements are different,
+ * otherwise return 0(false).
+ * @param lhs NDArray
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def notEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_not_equal_scalar", Seq(lhs, rhs))
}
@@ -338,7 +405,16 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_greater", Seq(lhs, rhs))
}
- def greater(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Returns the result of element-wise **greater than** (>) comparison operation
+ * with broadcasting.
+ * For each element in input arrays, return 1(true) if lhs elements are greater than rhs,
+ * otherwise return 0(false).
+ *
+ * @param lhs NDArray
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def greater(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_greater_scalar", Seq(lhs, rhs))
}
@@ -352,7 +428,16 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_greater_equal", Seq(lhs, rhs))
}
- def greaterEqual(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Returns the result of element-wise **greater than or equal to** (>=) comparison
+ * operation with broadcasting.
+ * For each element in input arrays, return 1(true) if lhs elements are greater than equal to
+ * rhs, otherwise return 0(false).
+ *
+ * @param lhs NDArray
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def greaterEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_greater_equal_scalar", Seq(lhs, rhs))
}
@@ -366,7 +451,15 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_lesser", Seq(lhs, rhs))
}
- def lesser(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Returns the result of element-wise **lesser than** (<) comparison operation
+ * with broadcasting.
+ * For each element in input arrays, return 1(true) if lhs elements are less than rhs,
+ * otherwise return 0(false).
+ * @param lhs NDArray
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def lesser(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_lesser_scalar", Seq(lhs, rhs))
}
@@ -380,7 +473,16 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_lesser_equal", Seq(lhs, rhs))
}
- def lesserEqual(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Returns the result of element-wise **lesser than or equal to** (<=) comparison
+ * operation with broadcasting.
+ * For each element in input arrays, return 1(true) if lhs elements are
+ * lesser than equal to rhs, otherwise return 0(false).
+ *
+ * @param lhs NDArray
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def lesserEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_lesser_equal_scalar", Seq(lhs, rhs))
}
@@ -397,6 +499,16 @@ object NDArray extends NDArrayBase {
arr
}
+ def array(sourceArr: Array[Double], shape: Shape, ctx: Context): NDArray = {
+ val arr = empty(shape, ctx, dtype = DType.Float64)
+ arr.set(sourceArr)
+ arr
+ }
+
+ def array(sourceArr: Array[Double], shape: Shape): NDArray = {
+ array(sourceArr, shape, null)
+ }
+
/**
* Returns evenly spaced values within a given interval.
* Values are generated within the half-open interval [`start`, `stop`). In other
@@ -645,6 +757,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
checkCall(_LIB.mxNDArraySyncCopyFromCPU(handle, source, source.length))
}
+ private def syncCopyfrom(source: Array[Double]): Unit = {
+ require(source.length == size,
+ s"array size (${source.length}) do not match the size of NDArray ($size)")
+ checkCall(_LIB.mxFloat64NDArraySyncCopyFromCPU(handle, source, source.length))
+ }
+
/**
* Return a sliced NDArray that shares memory with current one.
* NDArray only support continuous slicing on axis 0
@@ -759,7 +877,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
* @param value Value to set
* @return Current NDArray
*/
- def set(value: Float): NDArray = {
+ def set(value: MX_PRIMITIVE_TYPE): NDArray = {
require(writable, "trying to assign to a readonly NDArray")
NDArray.genericNDArrayFunctionInvoke("_set_value", Seq(value), Map("out" -> this))
this
@@ -776,11 +894,17 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this
}
+ def set(other: Array[Double]): NDArray = {
+ require(writable, "trying to assign to a readonly NDArray")
+ syncCopyfrom(other)
+ this
+ }
+
def +(other: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_plus", Seq(this, other))
}
- def +(other: Float): NDArray = {
+ def +(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_plus_scalar", Seq(this, other))
}
@@ -792,7 +916,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this
}
- def +=(other: Float): NDArray = {
+ def +=(other: MX_PRIMITIVE_TYPE): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to add to a readonly NDArray")
}
@@ -804,7 +928,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.genericNDArrayFunctionInvoke("_minus", Seq(this, other))
}
- def -(other: Float): NDArray = {
+ def -(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_minus_scalar", Seq(this, other))
}
@@ -816,7 +940,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this
}
- def -=(other: Float): NDArray = {
+ def -=(other: MX_PRIMITIVE_TYPE): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to subtract from a readonly NDArray")
}
@@ -828,7 +952,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.genericNDArrayFunctionInvoke("_mul", Seq(this, other))
}
- def *(other: Float): NDArray = {
+ def *(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_mul_scalar", Seq(this, other))
}
@@ -844,7 +968,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this
}
- def *=(other: Float): NDArray = {
+ def *=(other: MX_PRIMITIVE_TYPE): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to multiply to a readonly NDArray")
}
@@ -856,7 +980,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.genericNDArrayFunctionInvoke("_div", Seq(this, other))
}
- def /(other: Float): NDArray = {
+ def /(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_div_scalar", Seq(this, other))
}
@@ -868,7 +992,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this
}
- def /=(other: Float): NDArray = {
+ def /=(other: MX_PRIMITIVE_TYPE): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to divide from a readonly NDArray")
}
@@ -880,7 +1004,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.power(this, other)
}
- def **(other: Float): NDArray = {
+ def **(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.power(this, other)
}
@@ -888,7 +1012,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.genericNDArrayFunctionInvoke("_power", Seq(this, other), Map("out" -> this))
}
- def **=(other: Float): NDArray = {
+ def **=(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(this, other), Map("out" -> this))
}
@@ -896,7 +1020,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.greater(this, other)
}
- def >(other: Float): NDArray = {
+ def >(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.greater(this, other)
}
@@ -904,7 +1028,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.greaterEqual(this, other)
}
- def >=(other: Float): NDArray = {
+ def >=(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.greaterEqual(this, other)
}
@@ -912,7 +1036,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.lesser(this, other)
}
- def <(other: Float): NDArray = {
+ def <(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.lesser(this, other)
}
@@ -920,7 +1044,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.lesserEqual(this, other)
}
- def <=(other: Float): NDArray = {
+ def <=(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.lesserEqual(this, other)
}
@@ -928,7 +1052,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.genericNDArrayFunctionInvoke("_mod", Seq(this, other))
}
- def %(other: Float): NDArray = {
+ def %(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_mod_scalar", Seq(this, other))
}
@@ -940,7 +1064,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this
}
- def %=(other: Float): NDArray = {
+ def %=(other: MX_PRIMITIVE_TYPE): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to take modulo from a readonly NDArray")
}
@@ -956,6 +1080,14 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
internal.toFloatArray
}
+ /**
+ * Return a copied flat java array of current array (row-major) with datatype as Float64/Double.
+ * @return A copy of array content.
+ */
+ def toFloat64Array: Array[Double] = {
+ internal.toDoubleArray
+ }
+
def internal: NDArrayInternal = {
val myType = dtype
val arrLength = DType.numOfBytes(myType) * size
@@ -975,6 +1107,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this.toArray(0)
}
+ def toFloat64Scalar: Double = {
+ require(shape == Shape(1), "The current array is not a scalar")
+ this.toFloat64Array(0)
+ }
+
/**
* Copy the content of current array to other.
*
@@ -997,7 +1134,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
* @return The copy target NDArray
*/
def copyTo(ctx: Context): NDArray = {
- val ret = new NDArray(NDArray.newAllocHandle(shape, ctx, delayAlloc = true))
+ val ret = new NDArray(NDArray.newAllocHandle(shape, ctx, delayAlloc = true, dtype = dtype))
copyTo(ret)
}
@@ -1047,11 +1184,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
private[mxnet] object NDArrayConversions {
implicit def int2Scalar(x: Int): NDArrayConversions = new NDArrayConversions(x.toFloat)
- implicit def double2Scalar(x: Double): NDArrayConversions = new NDArrayConversions(x.toFloat)
+ implicit def double2Scalar(x: Double): NDArrayConversions = new NDArrayConversions(x)
implicit def float2Scalar(x: Float): NDArrayConversions = new NDArrayConversions(x)
}
-private[mxnet] class NDArrayConversions(val value: Float) {
+private[mxnet] class NDArrayConversions(val value: MX_PRIMITIVE_TYPE) {
def +(other: NDArray): NDArray = {
other + value
}
@@ -1145,34 +1282,39 @@ private[mxnet] class NDArrayFuncReturn(private[mxnet] val arr: Array[NDArray]) {
def waitToRead(): Unit = head.waitToRead()
def context: Context = head.context
def set(value: Float): NDArray = head.set(value)
+ def set(value: Double): NDArray = head.set(value)
def set(other: NDArray): NDArray = head.set(other)
def set(other: Array[Float]): NDArray = head.set(other)
+ def set(other: Array[Double]): NDArray = head.set(other)
def +(other: NDArray): NDArray = head + other
- def +(other: Float): NDArray = head + other
+ def +(other: MX_PRIMITIVE_TYPE): NDArray = head + other
def +=(other: NDArray): NDArray = head += other
- def +=(other: Float): NDArray = head += other
+ def +=(other: MX_PRIMITIVE_TYPE): NDArray = head += other
def -(other: NDArray): NDArray = head - other
- def -(other: Float): NDArray = head - other
+ def -(other: MX_PRIMITIVE_TYPE): NDArray = head - other
def -=(other: NDArray): NDArray = head -= other
- def -=(other: Float): NDArray = head -= other
+ def -=(other: MX_PRIMITIVE_TYPE): NDArray = head -= other
def *(other: NDArray): NDArray = head * other
- def *(other: Float): NDArray = head * other
+ def *(other: MX_PRIMITIVE_TYPE): NDArray = head * other
def unary_-(): NDArray = -head
def *=(other: NDArray): NDArray = head *= other
- def *=(other: Float): NDArray = head *= other
+ def *=(other: MX_PRIMITIVE_TYPE): NDArray = head *= other
def /(other: NDArray): NDArray = head / other
+ def /(other: MX_PRIMITIVE_TYPE): NDArray = head / other
def **(other: NDArray): NDArray = head ** other
- def **(other: Float): NDArray = head ** other
+ def **(other: MX_PRIMITIVE_TYPE): NDArray = head ** other
def >(other: NDArray): NDArray = head > other
- def >(other: Float): NDArray = head > other
+ def >(other: MX_PRIMITIVE_TYPE): NDArray = head > other
def >=(other: NDArray): NDArray = head >= other
- def >=(other: Float): NDArray = head >= other
+ def >=(other: MX_PRIMITIVE_TYPE): NDArray = head >= other
def <(other: NDArray): NDArray = head < other
- def <(other: Float): NDArray = head < other
+ def <(other: MX_PRIMITIVE_TYPE): NDArray = head < other
def <=(other: NDArray): NDArray = head <= other
- def <=(other: Float): NDArray = head <= other
+ def <=(other: MX_PRIMITIVE_TYPE): NDArray = head <= other
def toArray: Array[Float] = head.toArray
+ def toFloat64Array: Array[Double] = head.toFloat64Array
def toScalar: Float = head.toScalar
+ def toFloat64Scalar: Double = head.toFloat64Scalar
def copyTo(other: NDArray): NDArray = head.copyTo(other)
def copyTo(ctx: Context): NDArray = head.copyTo(ctx)
def copy(): NDArray = head.copy()
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
index a84bd106b763..e30098c3088b 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
@@ -53,9 +53,9 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
val label = currentBatch.label(0)
// properties
val res = (
- // TODO: need to allow user to specify DType and Layout
- IndexedSeq(new DataDesc(dataName, data.shape, DType.Float32, Layout.UNDEFINED)),
- IndexedSeq(new DataDesc(labelName, label.shape, DType.Float32, Layout.UNDEFINED)),
+ // TODO: need to allow user to specify Layout
+ IndexedSeq(new DataDesc(dataName, data.shape, data.dtype, Layout.UNDEFINED)),
+ IndexedSeq(new DataDesc(labelName, label.shape, label.dtype, Layout.UNDEFINED)),
ListMap(dataName -> data.shape),
ListMap(labelName -> label.shape),
data.shape(0))
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index 0032a54dd802..e690abba0d13 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -61,7 +61,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad",
dataName: String = "data", labelName: String = "label") {
- this(IO.initDataDesc(data, allowEmpty = false, dataName, MX_REAL_TYPE, Layout.UNDEFINED),
+ this(IO.initDataDesc(data, allowEmpty = false, dataName,
+ if (data == null || data.isEmpty) MX_REAL_TYPE else data(0).dtype, Layout.UNDEFINED),
IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, Layout.UNDEFINED),
dataBatchSize, shuffle, lastBatchHandle)
}
@@ -272,7 +273,7 @@ object NDArrayIter {
*/
def addData(name: String, data: NDArray): Builder = {
this.data = this.data ++ IndexedSeq((new DataDesc(name,
- data.shape, DType.Float32, Layout.UNDEFINED), data))
+ data.shape, data.dtype, Layout.UNDEFINED), data))
this
}
@@ -284,7 +285,7 @@ object NDArrayIter {
*/
def addLabel(name: String, label: NDArray): Builder = {
this.label = this.label ++ IndexedSeq((new DataDesc(name,
- label.shape, DType.Float32, Layout.UNDEFINED), label))
+ label.shape, label.dtype, Layout.UNDEFINED), label))
this
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
index 198102d2377f..67809c158aff 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -91,17 +91,26 @@ object NDArray extends NDArrayBase {
def full(shape: Shape, value: Float, ctx: Context): NDArray
= org.apache.mxnet.NDArray.full(shape, value, ctx)
+ def full(shape: Shape, value: Double, ctx: Context): NDArray
+ = org.apache.mxnet.NDArray.full(shape, value, ctx)
+
def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
def power(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+ def power(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+ def power(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
def maximum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
def maximum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
def maximum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+ def maximum(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+ def maximum(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
def minimum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+ def minimum(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+ def minimum(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
/**
@@ -111,6 +120,7 @@ object NDArray extends NDArrayBase {
*/
def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
+ def equal(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
/**
* Returns the result of element-wise **not equal to** (!=) comparison operation
@@ -120,6 +130,7 @@ object NDArray extends NDArrayBase {
*/
def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+ def notEqual(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
/**
* Returns the result of element-wise **greater than** (>) comparison operation
@@ -129,6 +140,7 @@ object NDArray extends NDArrayBase {
*/
def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
+ def greater(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
/**
* Returns the result of element-wise **greater than or equal to** (>=) comparison
@@ -140,6 +152,8 @@ object NDArray extends NDArrayBase {
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
def greaterEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+ def greaterEqual(lhs: NDArray, rhs: Double): NDArray
+ = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
/**
* Returns the result of element-wise **lesser than** (<) comparison operation
@@ -149,6 +163,7 @@ object NDArray extends NDArrayBase {
*/
def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
+ def lesser(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
/**
* Returns the result of element-wise **lesser than or equal to** (<=) comparison
@@ -160,6 +175,8 @@ object NDArray extends NDArrayBase {
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
def lesserEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+ def lesserEqual(lhs: NDArray, rhs: Double): NDArray
+ = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
/**
* Create a new NDArray that copies content from source_array.
@@ -172,6 +189,18 @@ object NDArray extends NDArrayBase {
= org.apache.mxnet.NDArray.array(
sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx)
+ /**
+ * Create a new NDArray that copies content from source_array.
+ * @param sourceArr Source data (list of Doubles) to create NDArray from.
+ * @param shape shape of the NDArray
+ * @param ctx The context of the NDArray, default to current default context.
+ * @return The created NDArray.
+ */
+ def arrayWithDouble(sourceArr: java.util.List[java.lang.Double], shape: Shape,
+ ctx: Context = null): NDArray
+ = org.apache.mxnet.NDArray.array(
+ sourceArr.asScala.map(ele => Double.unbox(ele)).toArray, shape)
+
/**
* Returns evenly spaced values within a given interval.
* Values are generated within the half-open interval [`start`, `stop`). In other
@@ -205,6 +234,10 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
}
+ def this(arr: Array[Double], shape: Shape, ctx: Context) = {
+ this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
+ }
+
def this(arr: java.util.List[java.lang.Float], shape: Shape, ctx: Context) = {
this(NDArray.array(arr, shape, ctx))
}
@@ -304,41 +337,59 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
* @return Current NDArray
*/
def set(value: Float): NDArray = nd.set(value)
+ def set(value: Double): NDArray = nd.set(value)
def set(other: NDArray): NDArray = nd.set(other)
def set(other: Array[Float]): NDArray = nd.set(other)
+ def set(other: Array[Double]): NDArray = nd.set(other)
def add(other: NDArray): NDArray = this.nd + other.nd
def add(other: Float): NDArray = this.nd + other
+ def add(other: Double): NDArray = this.nd + other
def addInplace(other: NDArray): NDArray = this.nd += other
def addInplace(other: Float): NDArray = this.nd += other
+ def addInplace(other: Double): NDArray = this.nd += other
def subtract(other: NDArray): NDArray = this.nd - other
def subtract(other: Float): NDArray = this.nd - other
+ def subtract(other: Double): NDArray = this.nd - other
def subtractInplace(other: NDArray): NDArray = this.nd -= other
def subtractInplace(other: Float): NDArray = this.nd -= other
+ def subtractInplace(other: Double): NDArray = this.nd -= other
def multiply(other: NDArray): NDArray = this.nd * other
def multiply(other: Float): NDArray = this.nd * other
+ def multiply(other: Double): NDArray = this.nd * other
def multiplyInplace(other: NDArray): NDArray = this.nd *= other
def multiplyInplace(other: Float): NDArray = this.nd *= other
+ def multiplyInplace(other: Double): NDArray = this.nd *= other
def div(other: NDArray): NDArray = this.nd / other
def div(other: Float): NDArray = this.nd / other
+ def div(other: Double): NDArray = this.nd / other
def divInplace(other: NDArray): NDArray = this.nd /= other
def divInplace(other: Float): NDArray = this.nd /= other
+ def divInplace(other: Double): NDArray = this.nd /= other
def pow(other: NDArray): NDArray = this.nd ** other
def pow(other: Float): NDArray = this.nd ** other
+ def pow(other: Double): NDArray = this.nd ** other
def powInplace(other: NDArray): NDArray = this.nd **= other
def powInplace(other: Float): NDArray = this.nd **= other
+ def powInplace(other: Double): NDArray = this.nd **= other
def mod(other: NDArray): NDArray = this.nd % other
def mod(other: Float): NDArray = this.nd % other
+ def mod(other: Double): NDArray = this.nd % other
def modInplace(other: NDArray): NDArray = this.nd %= other
def modInplace(other: Float): NDArray = this.nd %= other
+ def modInplace(other: Double): NDArray = this.nd %= other
def greater(other: NDArray): NDArray = this.nd > other
def greater(other: Float): NDArray = this.nd > other
+ def greater(other: Double): NDArray = this.nd > other
def greaterEqual(other: NDArray): NDArray = this.nd >= other
def greaterEqual(other: Float): NDArray = this.nd >= other
+ def greaterEqual(other: Double): NDArray = this.nd >= other
def lesser(other: NDArray): NDArray = this.nd < other
def lesser(other: Float): NDArray = this.nd < other
+ def lesser(other: Double): NDArray = this.nd < other
def lesserEqual(other: NDArray): NDArray = this.nd <= other
def lesserEqual(other: Float): NDArray = this.nd <= other
+ def lesserEqual(other: Double): NDArray = this.nd <= other
/**
* Return a copied flat java array of current array (row-major).
@@ -346,6 +397,12 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
*/
def toArray: Array[Float] = nd.toArray
+ /**
+ * Return a copied flat java array of current array (row-major).
+ * @return A copy of array content.
+ */
+ def toFloat64Array: Array[Double] = nd.toFloat64Array
+
/**
* Return a CPU scalar(float) of current ndarray.
* This ndarray must have shape (1,)
@@ -354,6 +411,14 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
*/
def toScalar: Float = nd.toScalar
+ /**
+ * Return a CPU scalar(float) of current ndarray.
+ * This ndarray must have shape (1,)
+ *
+ * @return The scalar representation of the ndarray.
+ */
+ def toFloat64Scalar: Double = nd.toFloat64Scalar
+
/**
* Copy the content of current array to other.
*
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
index 2659b7848bc6..86c7eb29d2ef 100644
--- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
@@ -40,6 +40,15 @@ public void testCreateNDArray() {
new Shape(new int[]{1, 3}),
new Context("cpu", 0));
assertTrue(Arrays.equals(nd.shape().toArray(), arr));
+
+ List list2 = Arrays.asList(1d, 1d, 1d);
+ nd = NDArray.arrayWithDouble(list2,
+ new Shape(new int[]{1, 3}),
+ new Context("cpu", 0));
+
+ // Float64 assertion
+ assertTrue(nd.dtype() == DType.Float64());
+
}
@Test
@@ -64,6 +73,12 @@ public void testComparison(){
nd = nd.subtract(nd2);
float[] lesser = new float[]{0, 0, 0};
assertTrue(Arrays.equals(nd.greater(nd2).toArray(), lesser));
+
+ NDArray nd3 = new NDArray(new double[]{1.0, 2.0, 3.0}, new Shape(new int[]{3}), new Context("cpu", 0));
+ nd3 = nd3.add(1.0);
+ double[] smaller = new double[] {2, 3, 4};
+ assertTrue(Arrays.equals(smaller, nd3.toFloat64Array()));
+
}
@Test
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
index 2ec6f668dbcc..d3969b0ce77d 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
@@ -303,5 +303,32 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
assert(dataDesc(0).layout == Layout.NTC)
assert(labelDesc(0).dtype == DType.Int32)
assert(labelDesc(0).layout == Layout.NT)
+
+
+ // Test with passing Float64 hardcoded as Dtype of data
+ val dataIter4 = new NDArrayIter(
+ IO.initDataDesc(data, false, "data", DType.Float64, Layout.NTC),
+ IO.initDataDesc(label, false, "label", DType.Int32, Layout.NT),
+ 128, false, "pad")
+ val dataDesc4 = dataIter4.provideDataDesc
+ val labelDesc4 = dataIter4.provideLabelDesc
+ assert(dataDesc4(0).dtype == DType.Float64)
+ assert(dataDesc4(0).layout == Layout.NTC)
+ assert(labelDesc4(0).dtype == DType.Int32)
+ assert(labelDesc4(0).layout == Layout.NT)
+
+ // Test with Float64 coming from the data itself
+ val dataF64 = IndexedSeq(NDArray.ones(shape0, dtype = DType.Float64),
+ NDArray.zeros(shape0, dtype = DType.Float64))
+
+ val dataIter5 = new NDArrayIter(
+ IO.initDataDesc(dataF64, false, "data", DType.Float64, Layout.NTC),
+ IO.initDataDesc(label, false, "label", DType.Int32, Layout.NT),
+ 128, false, "pad")
+ val dataDesc5 = dataIter5.provideDataDesc
+ assert(dataDesc5(0).dtype == DType.Float64)
+ assert(dataDesc5(0).dtype != DType.Float32)
+ assert(dataDesc5(0).layout == Layout.NTC)
+
}
}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 2f3b1676d272..bc7a0a026bc3 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -21,7 +21,7 @@ import java.io.File
import java.util.concurrent.atomic.AtomicInteger
import org.apache.mxnet.NDArrayConversions._
-import org.scalatest.{Matchers, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
private val sequence: AtomicInteger = new AtomicInteger(0)
@@ -29,6 +29,9 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
test("to java array") {
val ndarray = NDArray.zeros(2, 2)
assert(ndarray.toArray === Array(0f, 0f, 0f, 0f))
+
+ val float64Array = NDArray.zeros(Shape(2, 2), dtype = DType.Float64)
+ assert(float64Array.toFloat64Array === Array(0d, 0d, 0d, 0d))
}
test("to scalar") {
@@ -38,8 +41,17 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
assert(ndones.toScalar === 1f)
}
+ test("to float 64 scalar") {
+ val ndzeros = NDArray.zeros(Shape(1), dtype = DType.Float64)
+ assert(ndzeros.toFloat64Scalar === 0d)
+ val ndones = NDArray.ones(Shape(1), dtype = DType.Float64)
+ assert(ndones.toFloat64Scalar === 1d)
+ }
+
test ("call toScalar on an ndarray which is not a scalar") {
intercept[Exception] { NDArray.zeros(1, 1).toScalar }
+ intercept[Exception] { NDArray.zeros(shape = Shape (1, 1),
+ dtype = DType.Float64).toFloat64Scalar }
}
test("size and shape") {
@@ -51,12 +63,20 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
test("dtype") {
val arr = NDArray.zeros(3, 2)
assert(arr.dtype === DType.Float32)
+
+ val float64Array = NDArray.zeros(shape = Shape(3, 2), dtype = DType.Float64)
+ assert(float64Array.dtype === DType.Float64)
}
test("set scalar value") {
val ndarray = NDArray.empty(2, 1)
ndarray.set(10f)
assert(ndarray.toArray === Array(10f, 10f))
+
+ val float64array = NDArray.empty(shape = Shape(2, 1), dtype = DType.Float64)
+ float64array.set(10d)
+ assert(float64array.toFloat64Array === Array(10d, 10d))
+
}
test("copy from java array") {
@@ -66,19 +86,29 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
}
test("plus") {
- val ndzeros = NDArray.zeros(2, 1)
- val ndones = ndzeros + 1f
+ var ndzeros = NDArray.zeros(2, 1)
+ var ndones = ndzeros + 1f
assert(ndones.toArray === Array(1f, 1f))
assert((ndones + ndzeros).toArray === Array(1f, 1f))
assert((1 + ndones).toArray === Array(2f, 2f))
// in-place
ndones += ndones
assert(ndones.toArray === Array(2f, 2f))
+
+ // Float64 method test
+ ndzeros = NDArray.zeros(shape = Shape(2, 1), dtype = DType.Float64)
+ ndones = ndzeros + 1d
+ assert(ndones.toFloat64Array === Array(1d, 1d))
+ assert((ndones + ndzeros).toFloat64Array === Array(1d, 1d))
+ assert((1d + ndones).toArray === Array(2d, 2d))
+ // in-place
+ ndones += ndones
+ assert(ndones.toFloat64Array === Array(2d, 2d))
}
test("minus") {
- val ndones = NDArray.ones(2, 1)
- val ndzeros = ndones - 1f
+ var ndones = NDArray.ones(2, 1)
+ var ndzeros = ndones - 1f
assert(ndzeros.toArray === Array(0f, 0f))
assert((ndones - ndzeros).toArray === Array(1f, 1f))
assert((ndzeros - ndones).toArray === Array(-1f, -1f))
@@ -86,23 +116,46 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
// in-place
ndones -= ndones
assert(ndones.toArray === Array(0f, 0f))
+
+ // Float64 methods test
+ ndones = NDArray.ones(shape = Shape(2, 1))
+ ndzeros = ndones - 1d
+ assert(ndzeros.toFloat64Array === Array(0d, 0d))
+ assert((ndones - ndzeros).toFloat64Array === Array(1d , 1d))
+ assert((ndzeros - ndones).toFloat64Array === Array(-1d , -1d))
+ assert((ndones - 1).toFloat64Array === Array(0d, 0d))
+ // in-place
+ ndones -= ndones
+ assert(ndones.toArray === Array(0d, 0d))
+
}
test("multiplication") {
- val ndones = NDArray.ones(2, 1)
- val ndtwos = ndones * 2
+ var ndones = NDArray.ones(2, 1)
+ var ndtwos = ndones * 2
assert(ndtwos.toArray === Array(2f, 2f))
assert((ndones * ndones).toArray === Array(1f, 1f))
assert((ndtwos * ndtwos).toArray === Array(4f, 4f))
ndtwos *= ndtwos
// in-place
assert(ndtwos.toArray === Array(4f, 4f))
+
+ // Float64 methods test
+ ndones = NDArray.ones(shape = Shape(2, 1), dtype = DType.Float64)
+ ndtwos = ndones * 2d
+ assert(ndtwos.toFloat64Array === Array(2d, 2d))
+ assert((ndones * ndones).toFloat64Array === Array(1d, 1d))
+ assert((ndtwos * ndtwos).toFloat64Array === Array(4d, 4d))
+ ndtwos *= ndtwos
+ // in-place
+ assert(ndtwos.toFloat64Array === Array(4d, 4d))
+
}
test("division") {
- val ndones = NDArray.ones(2, 1)
- val ndzeros = ndones - 1f
- val ndhalves = ndones / 2
+ var ndones = NDArray.ones(2, 1)
+ var ndzeros = ndones - 1f
+ var ndhalves = ndones / 2
assert(ndhalves.toArray === Array(0.5f, 0.5f))
assert((ndhalves / ndhalves).toArray === Array(1f, 1f))
assert((ndones / ndones).toArray === Array(1f, 1f))
@@ -110,37 +163,75 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
ndhalves /= ndhalves
// in-place
assert(ndhalves.toArray === Array(1f, 1f))
+
+ // Float64 methods test
+ ndones = NDArray.ones(shape = Shape (2, 1), dtype = DType.Float64)
+ ndzeros = ndones - 1d
+ ndhalves = ndones / 2d
+ assert(ndhalves.toFloat64Array === Array(0.5d, 0.5d))
+ assert((ndhalves / ndhalves).toFloat64Array === Array(1d, 1d))
+ assert((ndones / ndones).toFloat64Array === Array(1d, 1d))
+ assert((ndzeros / ndones).toFloat64Array === Array(0d, 0d))
+ ndhalves /= ndhalves
+ // in-place
+ assert(ndhalves.toFloat64Array === Array(1d, 1d))
}
test("full") {
- val arr = NDArray.full(Shape(1, 2), 3f)
+ var arr = NDArray.full(Shape(1, 2), 3f)
assert(arr.shape === Shape(1, 2))
assert(arr.toArray === Array(3f, 3f))
+
+ // Float64 methods test
+ arr = NDArray.full(Shape(1, 2), value = 5d, Context.cpu())
+ assert(arr.toFloat64Array === Array (5d, 5d))
}
test("clip") {
- val ndarray = NDArray.empty(3, 2)
+ var ndarray = NDArray.empty(3, 2)
ndarray.set(Array(1f, 2f, 3f, 4f, 5f, 6f))
assert(NDArray.clip(ndarray, 2f, 5f).toArray === Array(2f, 2f, 3f, 4f, 5f, 5f))
+
+ // Float64 methods test
+ ndarray = NDArray.empty(shape = Shape(3, 2), dtype = DType.Float64)
+ ndarray.set(Array(1d, 2d, 3d, 4d, 5d, 6d))
+ assert(NDArray.clip(ndarray, 2d, 5d).toFloat64Array === Array(2d, 2d, 3d, 4d, 5d, 5d))
}
test("sqrt") {
- val ndarray = NDArray.empty(4, 1)
+ var ndarray = NDArray.empty(4, 1)
ndarray.set(Array(0f, 1f, 4f, 9f))
assert(NDArray.sqrt(ndarray).toArray === Array(0f, 1f, 2f, 3f))
+
+ // Float64 methods test
+ ndarray = NDArray.empty(shape = Shape(4, 1), dtype = DType.Float64)
+ ndarray.set(Array(0d, 1d, 4d, 9d))
+ assert(NDArray.sqrt(ndarray).toFloat64Array === Array(0d, 1d, 2d, 3d))
}
test("rsqrt") {
- val ndarray = NDArray.array(Array(1f, 4f), shape = Shape(2, 1))
+ var ndarray = NDArray.array(Array(1f, 4f), shape = Shape(2, 1))
assert(NDArray.rsqrt(ndarray).toArray === Array(1f, 0.5f))
+
+ // Float64 methods test
+ ndarray = NDArray.array(Array(1d, 4d, 25d), shape = Shape(3, 1), Context.cpu())
+ assert(NDArray.rsqrt(ndarray).toFloat64Array === Array(1d, 0.5d, 0.2d))
}
test("norm") {
- val ndarray = NDArray.empty(3, 1)
+ var ndarray = NDArray.empty(3, 1)
ndarray.set(Array(1f, 2f, 3f))
- val normed = NDArray.norm(ndarray)
+ var normed = NDArray.norm(ndarray)
assert(normed.shape === Shape(1))
assert(normed.toScalar === math.sqrt(14.0).toFloat +- 1e-3f)
+
+ // Float64 methods test
+ ndarray = NDArray.empty(shape = Shape(3, 1), dtype = DType.Float64)
+ ndarray.set(Array(1d, 2d, 3d))
+ normed = NDArray.norm(ndarray)
+ assert(normed.get.dtype === DType.Float64)
+ assert(normed.shape === Shape(1))
+ assert(normed.toFloat64Scalar === math.sqrt(14.0) +- 1e-3d)
}
test("one hot encode") {
@@ -176,25 +267,26 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
}
test("power") {
- val arr = NDArray.array(Array(3f, 5f), shape = Shape(2, 1))
+ var arr = NDArray.array(Array(3f, 5f), shape = Shape(2, 1))
- val arrPower1 = NDArray.power(2f, arr)
+ var arrPower1 = NDArray.power(2f, arr)
assert(arrPower1.shape === Shape(2, 1))
assert(arrPower1.toArray === Array(8f, 32f))
- val arrPower2 = NDArray.power(arr, 2f)
+ var arrPower2 = NDArray.power(arr, 2f)
assert(arrPower2.shape === Shape(2, 1))
assert(arrPower2.toArray === Array(9f, 25f))
- val arrPower3 = NDArray.power(arr, arr)
+ var arrPower3 = NDArray.power(arr, arr)
assert(arrPower3.shape === Shape(2, 1))
assert(arrPower3.toArray === Array(27f, 3125f))
- val arrPower4 = arr ** 2f
+ var arrPower4 = arr ** 2f
+
assert(arrPower4.shape === Shape(2, 1))
assert(arrPower4.toArray === Array(9f, 25f))
- val arrPower5 = arr ** arr
+ var arrPower5 = arr ** arr
assert(arrPower5.shape === Shape(2, 1))
assert(arrPower5.toArray === Array(27f, 3125f))
@@ -206,84 +298,211 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
arr **= arr
assert(arr.shape === Shape(2, 1))
assert(arr.toArray === Array(27f, 3125f))
+
+ // Float64 tests
+ arr = NDArray.array(Array(3d, 5d), shape = Shape(2, 1))
+
+ arrPower1 = NDArray.power(2d, arr)
+ assert(arrPower1.shape === Shape(2, 1))
+ assert(arrPower1.dtype === DType.Float64)
+ assert(arrPower1.toFloat64Array === Array(8d, 32d))
+
+ arrPower2 = NDArray.power(arr, 2d)
+ assert(arrPower2.shape === Shape(2, 1))
+ assert(arrPower2.dtype === DType.Float64)
+ assert(arrPower2.toFloat64Array === Array(9d, 25d))
+
+ arrPower3 = NDArray.power(arr, arr)
+ assert(arrPower3.shape === Shape(2, 1))
+ assert(arrPower3.dtype === DType.Float64)
+ assert(arrPower3.toFloat64Array === Array(27d, 3125d))
+
+ arrPower4 = arr ** 2f
+ assert(arrPower4.shape === Shape(2, 1))
+ assert(arrPower4.dtype === DType.Float64)
+ assert(arrPower4.toFloat64Array === Array(9d, 25d))
+
+ arrPower5 = arr ** arr
+ assert(arrPower5.shape === Shape(2, 1))
+ assert(arrPower5.dtype === DType.Float64)
+ assert(arrPower5.toFloat64Array === Array(27d, 3125d))
+
+ arr **= 2d
+ assert(arr.shape === Shape(2, 1))
+ assert(arr.dtype === DType.Float64)
+ assert(arr.toFloat64Array === Array(9d, 25d))
+
+ arr.set(Array(3d, 5d))
+ arr **= arr
+ assert(arr.shape === Shape(2, 1))
+ assert(arr.dtype === DType.Float64)
+ assert(arr.toFloat64Array === Array(27d, 3125d))
}
test("equal") {
- val arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2))
- val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+ var arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2))
+ var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
- val arrEqual1 = NDArray.equal(arr1, arr2)
+ var arrEqual1 = NDArray.equal(arr1, arr2)
assert(arrEqual1.shape === Shape(2, 2))
assert(arrEqual1.toArray === Array(1f, 0f, 1f, 0f))
- val arrEqual2 = NDArray.equal(arr1, 3f)
+ var arrEqual2 = NDArray.equal(arr1, 3f)
assert(arrEqual2.shape === Shape(2, 2))
assert(arrEqual2.toArray === Array(0f, 0f, 1f, 0f))
+
+
+ // Float64 methods test
+ arr1 = NDArray.array(Array(1d, 2d, 3d, 5d), shape = Shape(2, 2))
+ arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+ arrEqual1 = NDArray.equal(arr1, arr2)
+ assert(arrEqual1.shape === Shape(2, 2))
+ assert(arrEqual1.dtype === DType.Float64)
+ assert(arrEqual1.toFloat64Array === Array(1d, 0d, 1d, 0d))
+
+ arrEqual2 = NDArray.equal(arr1, 3d)
+ assert(arrEqual2.shape === Shape(2, 2))
+ assert(arrEqual2.dtype === DType.Float64)
+ assert(arrEqual2.toFloat64Array === Array(0d, 0d, 1d, 0d))
}
test("not_equal") {
- val arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2))
- val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+ var arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2))
+ var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
- val arrEqual1 = NDArray.notEqual(arr1, arr2)
+ var arrEqual1 = NDArray.notEqual(arr1, arr2)
assert(arrEqual1.shape === Shape(2, 2))
assert(arrEqual1.toArray === Array(0f, 1f, 0f, 1f))
- val arrEqual2 = NDArray.notEqual(arr1, 3f)
+ var arrEqual2 = NDArray.notEqual(arr1, 3f)
assert(arrEqual2.shape === Shape(2, 2))
assert(arrEqual2.toArray === Array(1f, 1f, 0f, 1f))
+
+ // Float64 methods test
+
+ arr1 = NDArray.array(Array(1d, 2d, 3d, 5d), shape = Shape(2, 2))
+ arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+ arrEqual1 = NDArray.notEqual(arr1, arr2)
+ assert(arrEqual1.shape === Shape(2, 2))
+ assert(arrEqual1.dtype === DType.Float64)
+ assert(arrEqual1.toFloat64Array === Array(0d, 1d, 0d, 1d))
+
+ arrEqual2 = NDArray.notEqual(arr1, 3d)
+ assert(arrEqual2.shape === Shape(2, 2))
+ assert(arrEqual2.dtype === DType.Float64)
+ assert(arrEqual2.toFloat64Array === Array(1d, 1d, 0d, 1d))
+
}
test("greater") {
- val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
- val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+ var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
+ var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
- val arrEqual1 = arr1 > arr2
+ var arrEqual1 = arr1 > arr2
assert(arrEqual1.shape === Shape(2, 2))
assert(arrEqual1.toArray === Array(0f, 0f, 1f, 0f))
- val arrEqual2 = arr1 > 2f
+ var arrEqual2 = arr1 > 2f
assert(arrEqual2.shape === Shape(2, 2))
assert(arrEqual2.toArray === Array(0f, 0f, 1f, 1f))
+
+ // Float64 methods test
+ arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2))
+ arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+ arrEqual1 = arr1 > arr2
+ assert(arrEqual1.shape === Shape(2, 2))
+ assert(arrEqual1.dtype === DType.Float64)
+ assert(arrEqual1.toFloat64Array === Array(0d, 0d, 1d, 0d))
+
+ arrEqual2 = arr1 > 2d
+ assert(arrEqual2.shape === Shape(2, 2))
+ assert(arrEqual2.dtype === DType.Float64)
+ assert(arrEqual2.toFloat64Array === Array(0d, 0d, 1d, 1d))
}
test("greater_equal") {
- val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
- val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+ var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
+ var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
- val arrEqual1 = arr1 >= arr2
+ var arrEqual1 = arr1 >= arr2
assert(arrEqual1.shape === Shape(2, 2))
assert(arrEqual1.toArray === Array(1f, 0f, 1f, 0f))
- val arrEqual2 = arr1 >= 2f
+ var arrEqual2 = arr1 >= 2f
assert(arrEqual2.shape === Shape(2, 2))
assert(arrEqual2.toArray === Array(0f, 1f, 1f, 1f))
+
+ // Float64 methods test
+ arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2))
+ arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+ arrEqual1 = arr1 >= arr2
+ assert(arrEqual1.shape === Shape(2, 2))
+ assert(arrEqual1.dtype === DType.Float64)
+ assert(arrEqual1.toFloat64Array === Array(1d, 0d, 1d, 0d))
+
+ arrEqual2 = arr1 >= 2d
+ assert(arrEqual2.shape === Shape(2, 2))
+ assert(arrEqual2.dtype === DType.Float64)
+ assert(arrEqual2.toFloat64Array === Array(0d, 1d, 1d, 1d))
}
test("lesser") {
- val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
- val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+ var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
+ var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
- val arrEqual1 = arr1 < arr2
+ var arrEqual1 = arr1 < arr2
assert(arrEqual1.shape === Shape(2, 2))
assert(arrEqual1.toArray === Array(0f, 1f, 0f, 1f))
- val arrEqual2 = arr1 < 2f
+ var arrEqual2 = arr1 < 2f
assert(arrEqual2.shape === Shape(2, 2))
assert(arrEqual2.toArray === Array(1f, 0f, 0f, 0f))
+
+ // Float64 methods test
+ arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2))
+ arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+ arrEqual1 = arr1 < arr2
+ assert(arrEqual1.shape === Shape(2, 2))
+ assert(arrEqual1.dtype === DType.Float64)
+ assert(arrEqual1.toFloat64Array === Array(0d, 1d, 0d, 1d))
+
+ arrEqual2 = arr1 < 2d
+ assert(arrEqual2.shape === Shape(2, 2))
+ assert(arrEqual2.dtype === DType.Float64)
+ assert(arrEqual2.toFloat64Array === Array(1d, 0d, 0d, 0d))
+
}
test("lesser_equal") {
- val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
- val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+ var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
+ var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
- val arrEqual1 = arr1 <= arr2
+ var arrEqual1 = arr1 <= arr2
assert(arrEqual1.shape === Shape(2, 2))
assert(arrEqual1.toArray === Array(1f, 1f, 0f, 1f))
- val arrEqual2 = arr1 <= 2f
+ var arrEqual2 = arr1 <= 2f
assert(arrEqual2.shape === Shape(2, 2))
assert(arrEqual2.toArray === Array(1f, 1f, 0f, 0f))
+
+ // Float64 methods test
+ arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2))
+ arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+ arrEqual1 = arr1 <= arr2
+ assert(arrEqual1.shape === Shape(2, 2))
+ assert(arrEqual1.dtype === DType.Float64)
+ assert(arrEqual1.toFloat64Array === Array(1d, 1d, 0d, 1d))
+
+ arrEqual2 = arr1 <= 2d
+ assert(arrEqual2.shape === Shape(2, 2))
+ assert(arrEqual2.dtype === DType.Float64)
+ assert(arrEqual2.toFloat64Array === Array(1d, 1d, 0d, 0d))
}
test("choose_element_0index") {
@@ -294,11 +513,18 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
}
test("copy to") {
- val source = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3))
- val dest = NDArray.empty(1, 3)
+ var source = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3))
+ var dest = NDArray.empty(1, 3)
source.copyTo(dest)
assert(dest.shape === Shape(1, 3))
assert(dest.toArray === Array(1f, 2f, 3f))
+
+ // Float64 methods test
+ source = NDArray.array(Array(1d, 2d, 3d), shape = Shape(1, 3))
+ dest = NDArray.empty(shape = Shape(1, 3), dtype = DType.Float64)
+ source.copyTo(dest)
+ assert(dest.dtype === DType.Float64)
+ assert(dest.toFloat64Array === Array(1d, 2d, 3d))
}
test("abs") {
@@ -365,6 +591,12 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
val arr = NDArray.maximum(arr1, arr2)
assert(arr.shape === Shape(3, 1))
assert(arr.toArray === Array(4f, 2.1f, 3.7f))
+
+ // Float64 methods test
+ val arr3 = NDArray.array(Array(1d, 2d, 3d), shape = Shape(3, 1))
+ val maxArr = NDArray.maximum(arr3, 10d)
+ assert(maxArr.shape === Shape(3, 1))
+ assert(maxArr.toArray === Array(10d, 10d, 10d))
}
test("min") {
@@ -378,11 +610,18 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
val arr = NDArray.minimum(arr1, arr2)
assert(arr.shape === Shape(3, 1))
assert(arr.toArray === Array(1.5f, 1f, 3.5f))
+
+ // Float64 methods test
+ val arr3 = NDArray.array(Array(4d, 5d, 6d), shape = Shape(3, 1))
+ val minArr = NDArray.minimum(arr3, 5d)
+ assert(minArr.shape === Shape(3, 1))
+ assert(minArr.toFloat64Array === Array(4d, 5d, 5d))
}
test("sum") {
- val arr = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Shape(2, 2))
+ var arr = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Shape(2, 2))
assert(NDArray.sum(arr).toScalar === 10f +- 1e-3f)
+
}
test("argmaxChannel") {
@@ -398,6 +637,12 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
val arr = NDArray.concatenate(arr1, arr2)
assert(arr.shape === Shape(3, 3))
assert(arr.toArray === Array(1f, 2f, 4f, 3f, 3f, 3f, 8f, 7f, 6f))
+
+ // Try concatenating float32 arr with float64 arr. Should get exception
+ intercept[Exception] {
+ val arr3 = NDArray.array(Array (5d, 6d, 7d), shape = Shape(1, 3))
+ NDArray.concatenate(Array(arr1, arr3))
+ }
}
test("concatenate axis-1") {
@@ -406,6 +651,12 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
val arr = NDArray.concatenate(Array(arr1, arr2), axis = 1)
assert(arr.shape === Shape(2, 3))
assert(arr.toArray === Array(1f, 2f, 5f, 3f, 4f, 6f))
+
+ // Try concatenating float32 arr with float64 arr. Should get exception
+ intercept[Exception] {
+ val arr3 = NDArray.array(Array (5d, 6d), shape = Shape(2, 1))
+ NDArray.concatenate(Array(arr1, arr3), axis = 1)
+ }
}
test("transpose") {
@@ -428,6 +679,24 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
val loadedArray = arrays(0)
assert(loadedArray.shape === Shape(3, 1))
assert(loadedArray.toArray === Array(1f, 2f, 3f))
+ assert(loadedArray.dtype === DType.Float32)
+ } finally {
+ val file = new File(filename)
+ file.delete()
+ }
+
+ // Try the same for Float64 array
+ try {
+ val ndarray = NDArray.array(Array(1d, 2d, 3d), shape = Shape(3, 1), ctx = Context.cpu())
+ NDArray.save(filename, Map("local" -> ndarray))
+ val (keys, arrays) = NDArray.load(filename)
+ assert(keys.length === 1)
+ assert(keys(0) === "local")
+ assert(arrays.length === 1)
+ val loadedArray = arrays(0)
+ assert(loadedArray.shape === Shape(3, 1))
+ assert(loadedArray.toArray === Array(1d, 2d, 3d))
+ assert(loadedArray.dtype === DType.Float64)
} finally {
val file = new File(filename)
file.delete()
@@ -446,6 +715,24 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
val loadedArray = arrays(0)
assert(loadedArray.shape === Shape(3, 1))
assert(loadedArray.toArray === Array(1f, 2f, 3f))
+ assert(loadedArray.dtype === DType.Float32)
+ } finally {
+ val file = new File(filename)
+ file.delete()
+ }
+
+ // Try the same thing for Float64 array :
+
+ try {
+ val ndarray = NDArray.array(Array(1d, 2d, 3d), shape = Shape(3, 1), ctx = Context.cpu())
+ NDArray.save(filename, Array(ndarray))
+ val (keys, arrays) = NDArray.load(filename)
+ assert(keys.length === 0)
+ assert(arrays.length === 1)
+ val loadedArray = arrays(0)
+ assert(loadedArray.shape === Shape(3, 1))
+ assert(loadedArray.toArray === Array(1d, 2d, 3d))
+ assert(loadedArray.dtype === DType.Float64)
} finally {
val file = new File(filename)
file.delete()
@@ -464,9 +751,11 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
val ndarray2 = NDArray.array(Array(1f, 2f, 3f), shape = Shape(3, 1))
val ndarray3 = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3))
val ndarray4 = NDArray.array(Array(3f, 2f, 3f), shape = Shape(3, 1))
+ val ndarray5 = NDArray.array(Array(3d, 2d, 3d), shape = Shape(3, 1), ctx = Context.cpu())
ndarray1 shouldEqual ndarray2
ndarray1 shouldNot equal(ndarray3)
ndarray1 shouldNot equal(ndarray4)
+ ndarray5 shouldNot equal(ndarray3)
}
test("slice") {
@@ -545,6 +834,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
val bytes = arr.serialize()
val arrCopy = NDArray.deserialize(bytes)
assert(arr === arrCopy)
+ assert(arrCopy.dtype === DType.Float32)
}
test("dtype int32") {
@@ -580,18 +870,22 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
test("NDArray random module is generated properly") {
val lam = NDArray.ones(1, 2)
val rnd = NDArray.random.poisson(lam = Some(lam), shape = Some(Shape(3, 4)))
- val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4)))
+ val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4)),
+ dtype = Some("float64"))
assert(rnd.shape === Shape(1, 2, 3, 4))
assert(rnd2.shape === Shape(3, 4))
+ assert(rnd2.head.dtype === DType.Float64)
}
test("NDArray random module is generated properly - special case of 'normal'") {
val mu = NDArray.ones(1, 2)
val sigma = NDArray.ones(1, 2) * 2
val rnd = NDArray.random.normal(mu = Some(mu), sigma = Some(sigma), shape = Some(Shape(3, 4)))
- val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(3, 4)))
+ val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(3, 4)),
+ dtype = Some("float64"))
assert(rnd.shape === Shape(1, 2, 3, 4))
assert(rnd2.shape === Shape(3, 4))
+ assert(rnd2.head.dtype === DType.Float64)
}
test("Generated api") {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
index f6c283c3dfb2..9f0430eaada6 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
@@ -19,6 +19,7 @@ package org.apache.mxnetexamples.imclassification
import java.util.concurrent._
+import org.apache.mxnet.DType.DType
import org.apache.mxnetexamples.imclassification.models._
import org.apache.mxnetexamples.imclassification.util.Trainer
import org.apache.mxnet._
@@ -42,12 +43,13 @@ object TrainModel {
* @return The final validation accuracy
*/
def test(model: String, dataPath: String, numExamples: Int = 60000,
- numEpochs: Int = 10, benchmark: Boolean = false): Float = {
+ numEpochs: Int = 10, benchmark: Boolean = false,
+ dtype: DType = DType.Float32): Float = {
ResourceScope.using() {
val devs = Array(Context.cpu(0))
val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
val (dataLoader, net) = dataLoaderAndModel("mnist", model, dataPath,
- numExamples = numExamples, benchmark = benchmark)
+ numExamples = numExamples, benchmark = benchmark, dtype = dtype)
val Acc = Trainer.fit(batchSize = 128, numExamples, devs = devs,
network = net, dataLoader = dataLoader,
kvStore = "local", numEpochs = numEpochs)
@@ -69,7 +71,7 @@ object TrainModel {
*/
def dataLoaderAndModel(dataset: String, model: String, dataDir: String = "",
numLayers: Int = 50, numExamples: Int = 60000,
- benchmark: Boolean = false
+ benchmark: Boolean = false, dtype: DType = DType.Float32
): ((Int, KVStore) => (DataIter, DataIter), Symbol) = {
val (imageShape, numClasses) = dataset match {
case "mnist" => (List(1, 28, 28), 10)
@@ -80,16 +82,17 @@ object TrainModel {
val List(channels, height, width) = imageShape
val dataSize: Int = channels * height * width
val (datumShape, net) = model match {
- case "mlp" => (List(dataSize), MultiLayerPerceptron.getSymbol(numClasses))
- case "lenet" => (List(channels, height, width), Lenet.getSymbol(numClasses))
+ case "mlp" => (List(dataSize), MultiLayerPerceptron.getSymbol(numClasses, dtype = dtype))
+ case "lenet" => (List(channels, height, width), Lenet.getSymbol(numClasses, dtype = dtype))
case "resnet" => (List(channels, height, width), Resnet.getSymbol(numClasses,
- numLayers, imageShape))
+ numLayers, imageShape, dtype = dtype))
case _ => throw new Exception("Invalid model name")
}
val dataLoader: (Int, KVStore) => (DataIter, DataIter) = if (benchmark) {
(batchSize: Int, kv: KVStore) => {
- val iter = new SyntheticDataIter(numClasses, batchSize, datumShape, List(), numExamples)
+ val iter = new SyntheticDataIter(numClasses, batchSize, datumShape, List(), numExamples,
+ dtype)
(iter, iter)
}
} else {
@@ -116,8 +119,10 @@ object TrainModel {
val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
else inst.dataDir
+ val dtype = DType.withName(inst.dType)
+
val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network, dataPath,
- inst.numLayers, inst.numExamples, inst.benchmark)
+ inst.numLayers, inst.numExamples, inst.benchmark, dtype)
val devs =
if (inst.gpus != null) inst.gpus.split(',').map(id => Context.gpu(id.trim.toInt))
@@ -210,5 +215,8 @@ class TrainModel {
private val numWorker: Int = 1
@Option(name = "--num-server", usage = "# of servers")
private val numServer: Int = 1
+ @Option(name = "--dtype", usage = "data type of the model to train. " +
+ "Can be float32/float64. Works only with synthetic data currently")
+ private val dType: String = "float32"
}
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
index 9421f1021619..e4d3b2ae7c3e 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
@@ -24,7 +24,7 @@ import scala.collection.immutable.ListMap
import scala.util.Random
class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape: List[Int],
- labelShape: List[Int], maxIter: Int, dtype: DType = DType.Float32
+ labelShape: List[Int], maxIter: Int, dType: DType = DType.Float32
) extends DataIter {
var curIter = 0
val random = new Random()
@@ -35,12 +35,12 @@ class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape: List[In
var label: IndexedSeq[NDArray] = IndexedSeq(
NDArray.api.random_uniform(Some(0f), Some(maxLabel), shape = Some(batchLabelShape)))
var data: IndexedSeq[NDArray] = IndexedSeq(
- NDArray.api.random_uniform(shape = Some(shape)))
+ NDArray.api.random_uniform(shape = Some(shape), dtype = Some(dType.toString)))
val provideDataDesc: IndexedSeq[DataDesc] = IndexedSeq(
- new DataDesc("data", shape, dtype, Layout.UNDEFINED))
+ new DataDesc("data", shape, data(0).dtype, Layout.UNDEFINED))
val provideLabelDesc: IndexedSeq[DataDesc] = IndexedSeq(
- new DataDesc("softmax_label", batchLabelShape, dtype, Layout.UNDEFINED))
+ new DataDesc("softmax_label", batchLabelShape, label(0).dtype, Layout.UNDEFINED))
val getPad: Int = 0
override def getData(): IndexedSeq[NDArray] = data
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
index 76fb7bb66022..6f8b138d5ccb 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
@@ -17,6 +17,7 @@
package org.apache.mxnetexamples.imclassification.models
+import org.apache.mxnet.DType.DType
import org.apache.mxnet._
object Lenet {
@@ -26,8 +27,8 @@ object Lenet {
* @param numClasses Number of classes to classify into
* @return model symbol
*/
- def getSymbol(numClasses: Int): Symbol = {
- val data = Symbol.Variable("data")
+ def getSymbol(numClasses: Int, dtype: DType = DType.Float32): Symbol = {
+ val data = Symbol.Variable("data", dType = dtype)
// first conv
val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5), num_filter = 20)
val tanh1 = Symbol.api.tanh(data = Some(conv1))
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
index 5d880bbe0619..089b65f24a65 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
@@ -17,6 +17,7 @@
package org.apache.mxnetexamples.imclassification.models
+import org.apache.mxnet.DType.DType
import org.apache.mxnet._
object MultiLayerPerceptron {
@@ -26,8 +27,8 @@ object MultiLayerPerceptron {
* @param numClasses Number of classes to classify into
* @return model symbol
*/
- def getSymbol(numClasses: Int): Symbol = {
- val data = Symbol.Variable("data")
+ def getSymbol(numClasses: Int, dtype: DType = DType.Float32): Symbol = {
+ val data = Symbol.Variable("data", dType = dtype)
val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128, name = "fc1")
val act1 = Symbol.api.Activation(data = Some(fc1), "relu", name = "relu")
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
index c3f43d97e898..e5f597680f99 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
@@ -17,6 +17,7 @@
package org.apache.mxnetexamples.imclassification.models
+import org.apache.mxnet.DType.DType
import org.apache.mxnet._
object Resnet {
@@ -77,13 +78,14 @@ object Resnet {
*/
def resnet(units: List[Int], numStages: Int, filterList: List[Int], numClasses: Int,
imageShape: List[Int], bottleNeck: Boolean = true, bnMom: Float = 0.9f,
- workspace: Int = 256, dtype: String = "float32", memonger: Boolean = false): Symbol = {
+ workspace: Int = 256, dtype: DType = DType.Float32,
+ memonger: Boolean = false): Symbol = {
assert(units.size == numStages)
var data = Symbol.Variable("data", shape = Shape(List(4) ::: imageShape), dType = DType.Float32)
- if (dtype == "float32") {
+ if (dtype == DType.Float32) {
data = Symbol.api.identity(Some(data), "id")
- } else if (dtype == "float16") {
- data = Symbol.api.cast(Some(data), "float16")
+ } else if (dtype == DType.Float16) {
+ data = Symbol.api.cast(Some(data), DType.Float16.toString)
}
data = Symbol.api.BatchNorm(Some(data), fix_gamma = Some(true), eps = Some(2e-5),
momentum = Some(bnMom), name = "bn_data")
@@ -118,8 +120,8 @@ object Resnet {
kernel = Some(Shape(7, 7)), pool_type = Some("avg"), name = "pool1")
val flat = Symbol.api.Flatten(Some(pool1))
var fc1 = Symbol.api.FullyConnected(Some(flat), num_hidden = numClasses, name = "fc1")
- if (dtype == "float16") {
- fc1 = Symbol.api.cast(Some(fc1), "float32")
+ if (dtype == DType.Float16) {
+ fc1 = Symbol.api.cast(Some(fc1), DType.Float32.toString)
}
Symbol.api.SoftmaxOutput(Some(fc1), name = "softmax")
}
@@ -134,7 +136,7 @@ object Resnet {
* @return Model symbol
*/
def getSymbol(numClasses: Int, numLayers: Int, imageShape: List[Int], convWorkspace: Int = 256,
- dtype: String = "float32"): Symbol = {
+ dtype: DType = DType.Float32): Symbol = {
val List(channels, height, width) = imageShape
val (numStages, units, filterList, bottleNeck): (Int, List[Int], List[Int], Boolean) =
if (height <= 28) {
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
index 6e9667abe9c0..0daba5a97d77 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
@@ -19,7 +19,7 @@ package org.apache.mxnetexamples.imclassification
import java.io.File
-import org.apache.mxnet.Context
+import org.apache.mxnet.{Context, DType}
import org.apache.mxnetexamples.Util
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory
@@ -55,9 +55,15 @@ class IMClassificationExampleSuite extends FunSuite with BeforeAndAfterAll {
for(model <- List("mlp", "lenet", "resnet")) {
test(s"Example CI: Test Image Classification Model ${model}") {
- var context = Context.cpu()
val valAccuracy = TrainModel.test(model, "", 10, 1, benchmark = true)
}
}
+ for(model <- List("mlp", "lenet", "resnet")) {
+ test(s"Example CI: Test Image Classification Model ${model} with Float64 input") {
+ val valAccuracy = TrainModel.test(model, "", 10, 1, benchmark = true,
+ dtype = DType.Float64)
+ }
+ }
+
}
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
index 5208923275f6..bf6581588114 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
@@ -17,9 +17,10 @@
package org.apache.mxnet.infer
-import org.apache.mxnet.{Context, DataDesc, NDArray}
+import org.apache.mxnet._
import java.io.File
+import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE
import org.slf4j.LoggerFactory
import scala.io
@@ -30,13 +31,13 @@ trait ClassifierBase {
/**
* Takes an array of floats and returns corresponding (Label, Score) tuples
- * @param input Indexed sequence one-dimensional array of floats
+ * @param input Indexed sequence one-dimensional array of floats/doubles
* @param topK (Optional) How many result (sorting based on the last axis)
* elements to return. Default returns unsorted output.
* @return Indexed sequence of (Label, Score) tuples
*/
- def classify(input: IndexedSeq[Array[Float]],
- topK: Option[Int] = None): IndexedSeq[(String, Float)]
+ def classify[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]],
+ topK: Option[Int] = None): IndexedSeq[(String, T)]
/**
* Takes a sequence of NDArrays and returns (Label, Score) tuples
@@ -78,17 +79,35 @@ class Classifier(modelPathPrefix: String,
/**
* Takes flat arrays as input and returns (Label, Score) tuples.
- * @param input Indexed sequence one-dimensional array of floats
+ * @param input Indexed sequence one-dimensional array of floats/doubles
* @param topK (Optional) How many result (sorting based on the last axis)
* elements to return. Default returns unsorted output.
* @return Indexed sequence of (Label, Score) tuples
*/
- override def classify(input: IndexedSeq[Array[Float]],
- topK: Option[Int] = None): IndexedSeq[(String, Float)] = {
+ override def classify[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]],
+ topK: Option[Int] = None): IndexedSeq[(String, T)] = {
+
+ // considering only the first output
+ val result = input(0)(0) match {
+ case d: Double => {
+ classifyImpl(input.asInstanceOf[IndexedSeq[Array[Double]]], topK)
+ }
+ case _ => {
+ classifyImpl(input.asInstanceOf[IndexedSeq[Array[Float]]], topK)
+ }
+ }
+
+ result.asInstanceOf[IndexedSeq[(String, T)]]
+ }
+
+ private def classifyImpl[B, A <: MX_PRIMITIVE_TYPE]
+ (input: IndexedSeq[Array[B]], topK: Option[Int] = None)(implicit ev: B => A)
+ : IndexedSeq[(String, B)] = {
// considering only the first output
val predictResult = predictor.predict(input)(0)
- var result: IndexedSeq[(String, Float)] = IndexedSeq.empty
+
+ var result: IndexedSeq[(String, B)] = IndexedSeq.empty
if (topK.isDefined) {
val sortedIndex = predictResult.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get)
@@ -105,7 +124,7 @@ class Classifier(modelPathPrefix: String,
* @param input Indexed sequence of NDArrays
* @param topK (Optional) How many result (sorting based on the last axis)
* elements to return. Default returns unsorted output.
- * @return Traversable sequence of (Label, Score) tuples
+ * @return Traversable sequence of (Label, Score) tuples.
*/
override def classifyWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int] = None)
: IndexedSeq[IndexedSeq[(String, Float)]] = {
@@ -113,7 +132,7 @@ class Classifier(modelPathPrefix: String,
// considering only the first output
// Copy NDArray to CPU to avoid frequent GPU to CPU copying
val predictResultND: NDArray =
- predictor.predictWithNDArray(input)(0).asInContext(Context.cpu())
+ predictor.predictWithNDArray(input)(0).asInContext(Context.cpu())
// Parallel Execution with ParArray for better performance
val predictResultPar: ParArray[Array[Float]] =
new ParArray[Array[Float]](predictResultND.shape(0))
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
index 96be12179d42..3c80f9226399 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
@@ -17,7 +17,8 @@
package org.apache.mxnet.infer
-import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
import scala.collection.mutable.ListBuffer
@@ -70,14 +71,18 @@ class ImageClassifier(modelPathPrefix: String,
*
* @param inputImage Path prefix of the input image
* @param topK Number of result elements to return, sorted by probability
+ * @param dType The precision at which to run the inference.
+ * specify the DType as DType.Float64 for Double precision.
+ * Defaults to DType.Float32
* @return List of list of tuples of (Label, Probability)
*/
- def classifyImage(inputImage: BufferedImage,
- topK: Option[Int] = None): IndexedSeq[IndexedSeq[(String, Float)]] = {
+ def classifyImage
+ (inputImage: BufferedImage, topK: Option[Int] = None, dType: DType = DType.Float32):
+ IndexedSeq[IndexedSeq[(String, Float)]] = {
val scaledImage = ImageClassifier.reshapeImage(inputImage, width, height)
val imageShape = inputShape.drop(1)
- val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
+ val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape, dType)
val imgWithBatchNum = NDArray.api.expand_dims(pixelsNDArray, 0)
inputImage.flush()
scaledImage.flush()
@@ -95,16 +100,19 @@ class ImageClassifier(modelPathPrefix: String,
*
* @param inputBatch Input array of buffered images
* @param topK Number of result elements to return, sorted by probability
+ * @param dType The precision at which to run the inference.
+ * specify the DType as DType.Float64 for Double precision.
+ * Defaults to DType.Float32
* @return List of list of tuples of (Label, Probability)
*/
- def classifyImageBatch(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None):
- IndexedSeq[IndexedSeq[(String, Float)]] = {
+ def classifyImageBatch(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None,
+ dType: DType = DType.Float32): IndexedSeq[IndexedSeq[(String, Float)]] = {
val inputBatchSeq = inputBatch.toIndexedSeq
val imageBatch = inputBatchSeq.indices.par.map(idx => {
val scaledImage = ImageClassifier.reshapeImage(inputBatchSeq(idx), width, height)
val imageShape = inputShape.drop(1)
- val imgND = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
+ val imgND = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape, dType)
val imgWithBatch = NDArray.api.expand_dims(imgND, 0).get
handler.execute(imgND.dispose())
imgWithBatch
@@ -152,11 +160,29 @@ object ImageClassifier {
* returned by this method after the use.
*
* @param resizedImage BufferedImage to get pixels from
+ *
* @param inputImageShape Input shape; for example for resnet it is (3,224,224).
Should be same as inputDescriptor shape.
+ * @param dType The DataType of the NDArray created from the image
+ * that should be returned.
+ * Currently it defaults to Dtype.Float32
* @return NDArray pixels array with shape (3, 224, 224) in CHW format
*/
- def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape): NDArray = {
+ def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape,
+ dType : DType = DType.Float32): NDArray = {
+
+ if (dType == DType.Float64) {
+ val result = getFloatPixelsArray(resizedImage)
+ NDArray.array(result.map(_.toDouble), shape = inputImageShape)
+ }
+ else {
+ val result = getFloatPixelsArray(resizedImage)
+ NDArray.array(result, shape = inputImageShape)
+ }
+ }
+
+ private def getFloatPixelsArray(resizedImage: BufferedImage): Array[Float] = {
+
// Get height and width of the image
val w = resizedImage.getWidth()
val h = resizedImage.getHeight()
@@ -166,7 +192,6 @@ object ImageClassifier {
// 3 times height and width for R,G,B channels
val result = new Array[Float](3 * h * w)
-
var row = 0
// copy pixels to array vertically
while (row < h) {
@@ -184,11 +209,10 @@ object ImageClassifier {
}
row += 1
}
+
resizedImage.flush()
- // creating NDArray according to the input shape
- val pixelsArray = NDArray.array(result, shape = inputImageShape)
- pixelsArray
+ result
}
/**
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
index d4bce9f0d71e..67692a316cc4 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
@@ -17,8 +17,9 @@
package org.apache.mxnet.infer
+import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE
import org.apache.mxnet.io.NDArrayIter
-import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.apache.mxnet._
import org.apache.mxnet.module.Module
import scala.collection.mutable.ListBuffer
@@ -36,11 +37,13 @@ private[infer] trait PredictBase {
*
* This method will take input as IndexedSeq one dimensional arrays and creates the
* NDArray needed for inference. The array will be reshaped based on the input descriptors.
- * @param input: An IndexedSequence of a one-dimensional array.
+ * @param input: An Indexed Sequence of a one-dimensional array of datatype
+ * Float or Double
An IndexedSequence is needed when the model has more than one input.
* @return Indexed sequence array of outputs
*/
- def predict(input: IndexedSeq[Array[Float]]): IndexedSeq[Array[Float]]
+ def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]])
+ : IndexedSeq[Array[T]]
/**
* Predict using NDArray as input.
@@ -123,13 +126,13 @@ class Predictor(modelPathPrefix: String,
* Takes input as IndexedSeq one dimensional arrays and creates the NDArray needed for inference
* The array will be reshaped based on the input descriptors.
*
- * @param input: An IndexedSequence of a one-dimensional array.
+ * @param input: An IndexedSequence of a one-dimensional array
+ * of data type Float or Double.
An IndexedSequence is needed when the model has more than one input.
* @return Indexed sequence array of outputs
*/
- override def predict(input: IndexedSeq[Array[Float]])
- : IndexedSeq[Array[Float]] = {
-
+ override def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]])
+ : IndexedSeq[Array[T]] = {
require(input.length == inputDescriptors.length,
s"number of inputs provided: ${input.length} does not match number of inputs " +
s"in inputDescriptors: ${inputDescriptors.length}")
@@ -139,12 +142,30 @@ class Predictor(modelPathPrefix: String,
s"number of elements:${i.length} in the input does not match the shape:" +
s"${d.shape.toString()}")
}
+
+ // Infer the dtype of input and call relevant method
+ val result = input(0)(0) match {
+ case d: Double => predictImpl(input.asInstanceOf[IndexedSeq[Array[Double]]])
+ case _ => predictImpl(input.asInstanceOf[IndexedSeq[Array[Float]]])
+ }
+
+ result.asInstanceOf[IndexedSeq[Array[T]]]
+ }
+
+ private def predictImpl[B, A <: MX_PRIMITIVE_TYPE]
+ (input: IndexedSeq[Array[B]])(implicit ev: B => A)
+ : IndexedSeq[Array[B]] = {
+
var inputND: ListBuffer[NDArray] = ListBuffer.empty[NDArray]
for((i, d) <- input.zip(inputDescriptors)) {
val shape = d.shape.toVector.patch(from = batchIndex, patch = Vector(1), replaced = 1)
-
- inputND += mxNetHandler.execute(NDArray.array(i, Shape(shape)))
+ if (d.dtype == DType.Float64) {
+ inputND += mxNetHandler.execute(NDArray.array(i.asInstanceOf[Array[Double]], Shape(shape)))
+ }
+ else {
+ inputND += mxNetHandler.execute(NDArray.array(i.asInstanceOf[Array[Float]], Shape(shape)))
+ }
}
// rebind with batchsize 1
@@ -158,7 +179,8 @@ class Predictor(modelPathPrefix: String,
val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(
inputND.toIndexedSeq, dataBatchSize = 1)))
- val result = resultND.map((f : NDArray) => f.toArray)
+ val result =
+ resultND.map((f : NDArray) => if (f.dtype == DType.Float64) f.toFloat64Array else f.toArray)
mxNetHandler.execute(inputND.foreach(_.dispose))
mxNetHandler.execute(resultND.foreach(_.dispose))
@@ -168,9 +190,11 @@ class Predictor(modelPathPrefix: String,
mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false, forceRebind = true))
}
- result
+ result.asInstanceOf[IndexedSeq[Array[B]]]
}
+
+
/**
* Predict using NDArray as input
* This method is useful when the input is a batch of data
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
index 0466693be9bc..146fe93105e4 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
@@ -72,6 +72,30 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor)
predictor.predict(input).toArray
}
+ /**
+ * Takes input as Array of one dimensional arrays and creates the NDArray needed for inference
+ * The array will be reshaped based on the input descriptors. Example of calling in Java:
+ *
+ *
+ * {@code
+ * double tmp[][] = new double[1][224];
+ * for (int x = 0; x < 1; x++)
+ * for (int y = 0; y < 224; y++)
+ * tmp[x][y] = (int)(Math.random()*10);
+ * predictor.predict(tmp);
+ * }
+ *
+ *
+ * @param input: An Array of a one-dimensional array.
+ An extra Array is needed for when the model has more than one input.
+ * @return Indexed sequence array of outputs
+ */
+
+ def predict(input: Array[Array[Double]]):
+ Array[Array[Double]] = {
+ predictor.predict(input).toArray
+ }
+
/**
* Takes input as List of one dimensional arrays and creates the NDArray needed for inference
* The array will be reshaped based on the input descriptors.
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
index b28aeba1deed..d9ccec468791 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
@@ -22,7 +22,7 @@ import java.nio.file.{Files, Paths}
import java.util
import org.apache.mxnet.module.Module
-import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.apache.mxnet.{Context, DType, DataDesc, NDArray, Shape}
import org.mockito.Matchers._
import org.mockito.Mockito
import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -127,6 +127,29 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll {
}
+ test("ClassifierSuite-flatFloat64Array-topK") {
+ val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+ val inputData = Array.fill[Double](12)(1d)
+
+ val predictResult : IndexedSeq[Array[Double]] =
+ IndexedSeq[Array[Double]](Array(.98d, 0.97d, 0.96d, 0.99d))
+
+ val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+ Mockito.doReturn(predictResult).when(testClassifier.predictor)
+ .predict(any(classOf[IndexedSeq[Array[Double]]]))
+
+ val result: IndexedSeq[(String, Double)] = testClassifier.
+ classify(IndexedSeq(inputData), topK = Some(10))
+
+ assert((result(0)_2).getClass == 1d.getClass)
+
+ assertResult(predictResult(0).sortBy(-_)) {
+ result.map(_._2).toArray
+ }
+
+ }
+
test("ClassifierSuite-flatArrayInput") {
val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
val inputData = Array.fill[Float](12)(1)
@@ -147,6 +170,28 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll {
}
}
+ test("ClassifierSuite-flatArrayFloat64Input") {
+ val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+ val inputData = Array.fill[Double](12)(1d)
+
+ val predictResult : IndexedSeq[Array[Double]] =
+ IndexedSeq[Array[Double]](Array(.98d, 0.97d, 0.96d, 0.99d))
+
+ val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+ Mockito.doReturn(predictResult).when(testClassifier.predictor)
+ .predict(any(classOf[IndexedSeq[Array[Double]]]))
+
+ val result: IndexedSeq[(String, Double)] = testClassifier.
+ classify(IndexedSeq(inputData))
+
+ assert((result(0)_2).getClass == 1d.getClass)
+
+ assertResult(predictResult(0)) {
+ result.map(_._2).toArray
+ }
+ }
+
test("ClassifierSuite-NDArray1InputWithoutTopK") {
val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
val inputDataShape = Shape(1, 3, 2, 2)
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
index 1c291e1e7b3c..5198c4a1f309 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
@@ -68,6 +68,10 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
val result = ImageClassifier.bufferedImageToPixels(image2, Shape(3, 2, 2))
assert(result.shape == inputDescriptor(0).shape.drop(1))
+ assert(result.dtype == DType.Float32)
+
+ val resultFloat64 = ImageClassifier.bufferedImageToPixels(image2, Shape(3, 2, 2), DType.Float64)
+ assert(resultFloat64.dtype == DType.Float64)
}
test("ImageClassifierSuite-testWithInputImage") {
@@ -106,8 +110,10 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
predictResult(i).map(_._2).toArray
}
}
+
}
+
test("ImageClassifierSuite-testWithInputBatchImage") {
val dType = DType.Float32
val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512),
@@ -152,4 +158,5 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
}
}
}
+
}
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
index 509ffb35db8d..9afbc9b3d4a8 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.mxnet.infer
import org.apache.mxnet.io.NDArrayIter
import org.apache.mxnet.module.{BaseModule, Module}
-import org.apache.mxnet.{DataDesc, Layout, NDArray, Shape}
+import org.apache.mxnet._
import org.mockito.Matchers._
import org.mockito.Mockito
import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -91,6 +91,36 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll {
, any[Option[BaseModule]], any[String])
}
+ test("PredictorSuite-testWithFlatFloat64Arrays") {
+
+ val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
+ layout = Layout.NCHW, dtype = DType.Float64))
+ val inputData = Array.fill[Double](12)(1d)
+
+ // this will disposed at the end of the predict call on Predictor.
+ val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2), dtype = DType.Float64))
+
+ val testPredictor = new MyPredictor("xyz", inputDescriptor)
+
+ Mockito.doReturn(predictResult).when(testPredictor.mockModule)
+ .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])
+
+ val testFun = testPredictor.predict(IndexedSeq(inputData))
+
+ assert(testFun.size == 1, "output size should be 1 ")
+
+ assert(testFun(0)(0).getClass == 1d.getClass)
+
+ assert(Array.fill[Double](12)(1d).mkString == testFun(0).mkString)
+
+ // Verify that the module was bound with batch size 1 and rebound back to the original
+ // input descriptor. the number of times is twice here because loadModule overrides the
+ // initial bind.
+ Mockito.verify(testPredictor.mockModule, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]],
+ any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean]
+ , any[Option[BaseModule]], any[String])
+ }
+
test("PredictorSuite-testWithNDArray") {
val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
layout = Layout.NCHW))
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index d684c6d13564..ea6e9c8f5ba4 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -424,6 +424,15 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU
return ret;
}
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU
+ (JNIEnv *env, jobject obj, jlong arrayPtr, jdoubleArray sourceArr, jint arrSize) {
+ jdouble *sourcePtr = env->GetDoubleArrayElements(sourceArr, NULL);
+ int ret = MXNDArraySyncCopyFromCPU(reinterpret_cast(arrayPtr),
+ static_cast(sourcePtr), arrSize);
+ env->ReleaseDoubleArrayElements(sourceArr, sourcePtr, 0);
+ return ret;
+}
+
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetContext
(JNIEnv *env, jobject obj, jlong arrayPtr, jobject devTypeId, jobject devId) {
int outDevType;
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
index 40230ac6daae..7e8e03de9124 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
@@ -175,6 +175,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU
(JNIEnv *, jobject, jlong, jfloatArray, jint);
+/*
+ * Class: org_apache_mxnet_LibInfo
+ * Method: mxFloat64NDArraySyncCopyFromCPU
+ * Signature: (J[DI)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU
+ (JNIEnv *, jobject, jlong, jdoubleArray, jint);
+
/*
* Class: org_apache_mxnet_LibInfo
* Method: mxNDArrayLoad