Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add primitives support handling to the generator for proper conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
gigasquid authored and piyushghai committed Dec 28, 2018
1 parent 1d9b8bb commit d5595a0
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
(ns org.apache.clojure-mxnet.primitives
(:import (org.apache.mxnet MX_PRIMITIVES$MX_FLOAT MX_PRIMITIVES$MX_Double
MX_PRIMITIVES$MX_PRIMITIVE_TYPE)))


;;; Defines customer mx primitives that can be used for mathematical computations
;;; in NDArrays to control precision. Currently Float and Double are supported

;;; For purposes of automatic conversion in ndarray functions, doubles are default
;; to specify using floats you must use a Float

(defn mx-float
"Creates a MXNet float primitive"
[num]
(new MX_PRIMITIVES$MX_FLOAT num))

(defn mx-double
"Creates a MXNet double primitive"
[num]
(new MX_PRIMITIVES$MX_Double num))

(defn ->num
"Returns the underlying number value"
[primitive]
(.data primitive))

(defn primitive? [x]
(instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE x))

Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
(:require [clojure.spec.alpha :as s]
[t6.from-scala.core :refer [$ $$] :as $]
[clojure.string :as string]
[org.apache.clojure-mxnet.primitives :as primitives]
[org.apache.clojure-mxnet.shape :as mx-shape])
(:import (org.apache.mxnet NDArray)
(scala Product Tuple2 Tuple3)
Expand All @@ -36,7 +37,8 @@
"byte<>" "byte-array"
"java.lang.String<>" "vec-or-strings"
"org.apache.mxnet.NDArray" "ndarray"
"org.apache.mxnet.Symbol" "sym"})
"org.apache.mxnet.Symbol" "sym"
"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE" "double-or-float"})

(def symbol-param-coerce {"java.lang.String" "sym-name"
"float" "num"
Expand Down Expand Up @@ -144,6 +146,8 @@
(and (get targets "int<>") (vector? param)) (int-array param)
(and (get targets "float<>") (vector? param)) (float-array param)
(and (get targets "java.lang.String<>") (vector? param)) (into-array param)
(and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (instance? Float param)) (primitives/mx-float param)
(and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (number? param)) (primitives/mx-double param)
:else param))

(defn nil-or-coerce-param [param targets]
Expand Down Expand Up @@ -177,6 +181,7 @@
(instance? Map return-val) (scala-map->map return-val)
(instance? Tuple2 return-val) (tuple->vec return-val)
(instance? Tuple3 return-val) (tuple->vec return-val)
(primitives/primitive? return-val) (primitives/->num return-val)
:else return-val))

(defn coerce-return-recursive [return-val]
Expand Down
7 changes: 4 additions & 3 deletions contrib/clojure-package/test/good-test-ndarray.clj
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@

(defn
div
([ndarray num-or-double-or-ndarray]
([ndarray ndarray-or-double-or-float]
(util/coerce-return
(.$div
ndarray
(util/coerce-param
num-or-double-or-ndarray
#{"float" "double" "org.apache.mxnet.NDArray"})))))
ndarray-or-double-or-float
#{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"
"org.apache.mxnet.NDArray"})))))

Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
(ns org.apache.clojure-mxnet.primitives-test
(:require [org.apache.clojure-mxnet.primitives :as primitives]
[clojure.test :refer :all])
(:import (org.apache.mxnet MX_PRIMITIVES$MX_PRIMITIVE_TYPE
MX_PRIMITIVES$MX_FLOAT
MX_PRIMITIVES$MX_Double)))

(deftest test-primitive-types
(is (not (primitives/primitive? 3)))
(is (primitives/primitive? (primitives/mx-float 3)))
(is (primitives/primitive? (primitives/mx-double 3))))

(deftest test-float-primitives
(is (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE (primitives/mx-float 3)))
(is (instance? MX_PRIMITIVES$MX_FLOAT (primitives/mx-float 3)))
(is (instance? Float (-> (primitives/mx-float 3)
(primitives/->num))))
(is (= 3.0 (-> (primitives/mx-float 3)
(primitives/->num)))))

(deftest test-double-primitives
(is (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE (primitives/mx-double 2)))
(is (instance? MX_PRIMITIVES$MX_Double (primitives/mx-double 2)))
(is (instance? Double (-> (primitives/mx-double 2)
(primitives/->num))))
(is (= 2.0 (-> (primitives/mx-double 2)
(primitives/->num)))))

Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.primitives :as primitives]
[org.apache.clojure-mxnet.symbol :as sym]
[org.apache.clojure-mxnet.test-util :as test-util]
[clojure.spec.alpha :as s])
Expand Down Expand Up @@ -133,6 +134,9 @@
(is (= "[F" (->> (util/coerce-param [1 2] #{"float<>"}) str (take 2) (apply str))))
(is (= "[L" (->> (util/coerce-param [1 2] #{"java.lang.String<>"}) str (take 2) (apply str))))

(is (primitives/primitive? (util/coerce-param 1.0 #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"})))
(is (primitives/primitive? (util/coerce-param (float 1.0) #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"})))

(is (= 1 (util/coerce-param 1 #{"unknown"}))))

(deftest test-nil-or-coerce-param
Expand Down Expand Up @@ -171,6 +175,12 @@
(util/convert-tuple [1 2]))))
(is (= [1 2 3] (util/coerce-return
(util/convert-tuple [1 2 3]))))

(is (instance? Double (util/coerce-return (primitives/mx-double 3))))
(is (= 3.0 (util/coerce-return (primitives/mx-double 3))))
(is (instance? Float (util/coerce-return (primitives/mx-float 2))))
(is (= 2.0 (util/coerce-return (primitives/mx-float 2))))

(is (= "foo" (util/coerce-return "foo"))))

(deftest test-translate-keyword-shape
Expand Down

0 comments on commit d5595a0

Please sign in to comment.