Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Clojure] Helper function for n-dim vector to ndarray #14305

Merged
merged 4 commits into from
Mar 11, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,22 @@
([start stop]
(arange start stop {})))

(defn ->ndarray
"Creates a new NDArray based on the given n-dimensional
float/double vector.
`nd-vec`: n-dimensional vector with floats or doubles.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to allow also integers here? or other numerical types that can be understood by MXNet?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the Scala functions will not allow it but we can cast to float/double from clojure?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently, scala doesn't allow numerical types other than float/double.

Copy link
Contributor Author

@kedarbellare kedarbellare Mar 4, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usually i've been using ndarray/as-type for type conversion instead of doing conversion within clojure. my hunch is as-type is faster but i've not verified this. however, one disadvantage is that it returns a copy of the array in the new type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From a user perspective, I think it would be nice to be able to handle ((ndarray/->ndarray [5 -4 3]) and not throw an exception. What do you think about adding a check in the util function to see if any element is an integer and if so, convert it to double. Ex:

(if (some int? s)
      (to-array (mapv double s))
      (to-array s))

`opts-map` {
`ctx`: Context of the output ndarray, will use default context if unspecified.
}
returns: `ndarray` with the given values and matching the shape of the input vector.
Ex:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love your docstring!

(->ndarray [5.0 -4.0])
(->ndarray [[1.0 2.0 3.0] [4.0 5.0 6.0]])
(->ndarray [[[1.0] [2.0]]]"
([nd-vec {:keys [ctx] :as opts}]
(NDArray/toNDArray (util/to-array-nd nd-vec) ctx))
([nd-vec] (->ndarray nd-vec {})))

(defn slice
"Return a sliced NDArray that shares memory with current one."
([ndarray i]
Expand Down
10 changes: 10 additions & 0 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,16 @@
(throw (ex-info error-msg
(s/explain-data spec value)))))

(s/def ::non-empty-seq sequential?)
(defn to-array-nd
"Converts any N-D sequential structure to an array
with the same dimensions."
[s]
(validate! ::non-empty-seq s "Invalid N-D sequence")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for putting this validation in

(if (sequential? (first s))
(to-array (mapv to-array-nd s))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

recursion for the win! 🥇

(to-array s)))

(defn map->scala-tuple-seq
"* Convert a map to a scala-Seq of scala-Tubple.
* Should also work if a seq of seq of 2 things passed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@
(is (= [0.0 0.0 0.5 0.5 1.0 1.0 1.5 1.5 2.0 2.0 2.5 2.5 3.0 3.0 3.5 3.5 4.0 4.0 4.5 4.5]
(->vec (ndarray/arange start stop {:step step :repeat repeat}))))))

(deftest test->ndarray
(let [nda1 (ndarray/->ndarray [5.0 -4.0])
nda2 (ndarray/->ndarray [[1.0 2.0 3.0]
[4.0 5.0 6.0]])
nda3 (ndarray/->ndarray [[[7.0] [8.0]]])]
(is (= [5.0 -4.0] (->vec nda1)))
(is (= [2] (mx-shape/->vec (shape nda1))))
(is (= [1.0 2.0 3.0 4.0 5.0 6.0] (->vec nda2)))
(is (= [2 3] (mx-shape/->vec (shape nda2))))
(is (= [7.0 8.0] (->vec nda3)))
(is (= [1 2 1] (mx-shape/->vec (shape nda3))))))

(deftest test-power
(let [nda (ndarray/array [3 5] [2 1])]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,26 @@
(is (= [1 2] (-> (util/convert-tuple [1 2])
(util/tuple->vec)))))

(deftest test-to-array-nd
(let [a1 (util/to-array-nd '())
a2 (util/to-array-nd [1.0 2.0])
a3 (util/to-array-nd [[3.0] [4.0]])
a4 (util/to-array-nd [[[5 -5]]])]
(is (= 0 (alength a1)))
(is (= [] (->> a1 vec)))
(is (= 2 (alength a2)))
(is (= 2.0 (aget a2 1)))
(is (= [1.0 2.0] (->> a2 vec)))
(is (= 2 (alength a3)))
(is (= 1 (alength (aget a3 0))))
(is (= 4.0 (aget a3 1 0)))
(is (= [[3.0] [4.0]] (->> a3 vec (mapv vec))))
(is (= 1 (alength a4)))
(is (= 1 (alength (aget a4 0))))
(is (= 2 (alength (aget a4 0 0))))
(is (= 5 (aget a4 0 0 0)))
(is (= [[[5 -5]]] (->> a4 vec (mapv vec) (mapv #(mapv vec %)))))))

(deftest test-coerce-return
(is (= [] (util/coerce-return (ArrayBuffer.))))
(is (= [1 2 3] (util/coerce-return (util/vec->indexed-seq [1 2 3]))))
Expand Down