diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj index 651bdcb3f315..151e18bcb482 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj @@ -16,15 +16,18 @@ ;; (ns org.apache.clojure-mxnet.ndarray + "NDArray API for Clojure package." (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max min repeat reverse set sort take to-array empty shuffle ref]) - (:require [org.apache.clojure-mxnet.base :as base] - [org.apache.clojure-mxnet.context :as mx-context] - [org.apache.clojure-mxnet.shape :as mx-shape] - [org.apache.clojure-mxnet.util :as util] - [clojure.reflect :as r] - [t6.from-scala.core :refer [$] :as $]) + (:require + [clojure.spec.alpha :as s] + + [org.apache.clojure-mxnet.base :as base] + [org.apache.clojure-mxnet.context :as mx-context] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.util :as util] + [t6.from-scala.core :refer [$] :as $]) (:import (org.apache.mxnet NDArray))) ;; loads the generated functions into the namespace @@ -167,3 +170,46 @@ (defn shape-vec [ndarray] (mx-shape/->vec (shape ndarray))) + +(s/def ::ndarray #(instance? NDArray %)) +(s/def ::vector vector?) +(s/def ::sequential sequential?) +(s/def ::shape-vec-match-vec + (fn [[v vec-shape]] (= (count v) (reduce clojure.core/* 1 vec-shape)))) + +(s/fdef vec->nd-vec + :args (s/cat :v ::sequential :shape-vec ::sequential) + :ret ::vector) + +(defn- vec->nd-vec + "Convert a vector `v` into a n-dimensional vector given the `shape-vec` + Ex: + (vec->nd-vec [1 2 3] [1 1 3]) ;[[[1 2 3]]] + (vec->nd-vec [1 2 3 4 5 6] [2 3 1]) ;[[[1] [2] [3]] [[4] [5] [6]]] + (vec->nd-vec [1 2 3 4 5 6] [1 2 3]) ;[[[1 2 3]] [4 5 6]]] + (vec->nd-vec [1 2 3 4 5 6] [3 1 2]) ;[[[1 2]] [[3 4]] [[5 6]]] + (vec->nd-vec [1 2 3 4 5 6] [3 2]) ;[[1 2] [3 4] [5 6]]" + [v [s1 & ss :as shape-vec]] + (util/validate! ::sequential v "Invalid input vector `v`") + (util/validate! ::sequential shape-vec "Invalid input vector `shape-vec`") + (util/validate! ::shape-vec-match-vec + [v shape-vec] + "Mismatch between vector `v` and vector `shape-vec`") + (if-not (seq ss) + (vec v) + (->> v + (partition (clojure.core// (count v) s1)) + vec + (mapv #(vec->nd-vec % ss))))) + +(s/fdef ->nd-vec :args (s/cat :ndarray ::ndarray) :ret ::vector) + +(defn ->nd-vec + "Convert an ndarray `ndarray` into a n-dimensional Clojure vector. + Ex: + (->nd-vec (array [1] [1 1 1])) ;[[[1.0]]] + (->nd-vec (array [1 2 3] [3 1 1])) ;[[[1.0]] [[2.0]] [[3.0]]] + (->nd-vec (array [1 2 3 4 5 6]) [3 1 2]) ;[[[1.0 2.0]] [[3.0 4.0]] [[5.0 6.0]]]" + [ndarray] + (util/validate! ::ndarray ndarray "Invalid input array") + (vec->nd-vec (->vec ndarray) (shape-vec ndarray))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj index 9ffd3abed2f9..a9ae2966db89 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj @@ -473,3 +473,15 @@ (is (= [2 2] (ndarray/->int-vec nda))) (is (= [2.0 2.0] (ndarray/->double-vec nda))) (is (= [(byte 2) (byte 2)] (ndarray/->byte-vec nda))))) + +(deftest test->nd-vec + (is (= [[[1.0]]] + (ndarray/->nd-vec (ndarray/array [1] [1 1 1])))) + (is (= [[[1.0]] [[2.0]] [[3.0]]] + (ndarray/->nd-vec (ndarray/array [1 2 3] [3 1 1])))) + (is (= [[[1.0 2.0]] [[3.0 4.0]] [[5.0 6.0]]] + (ndarray/->nd-vec (ndarray/array [1 2 3 4 5 6] [3 1 2])))) + (is (= [[[1.0] [2.0]] [[3.0] [4.0]] [[5.0] [6.0]]] + (ndarray/->nd-vec (ndarray/array [1 2 3 4 5 6] [3 2 1])))) + (is (thrown-with-msg? Exception #"Invalid input array" + (ndarray/->nd-vec [1 2 3 4 5]))))