Skip to content

Commit

Permalink
[Clojure] Add methods based on NDArrayAPI/SymbolAPI (apache#14195)
Browse files Browse the repository at this point in the history
* [Clojure] Add methods based on NDArrayAPI/SymbolAPI

* Add symbol API methods and ndarray API unit tests

* Some more ndarray API unit tests

* Explore direct use of JNI

* Use library info directly instead of reflection

* Add tests for generation op info

* Fix ordering of keys using array-map

* Ignore generated test files

* Minor style changes

* Refactor code for better readability

* Address comments

* Small tweaks to symbol api coercion
  • Loading branch information
kedarbellare authored and haohuw committed Jun 23, 2019
1 parent f677827 commit 834b48f
Show file tree
Hide file tree
Showing 11 changed files with 1,257 additions and 121 deletions.
2 changes: 2 additions & 0 deletions contrib/clojure-package/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ examples/visualization/test-vis.pdf
src/.DS_Store
src/org/.DS_Store
test/test-ndarray.clj
test/test-ndarray-api.clj
test/test-symbol.clj
test/test-symbol-api.clj
src/org/apache/clojure_mxnet/gen/*

460 changes: 353 additions & 107 deletions contrib/clojure-package/src/dev/generator.clj

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns org.apache.clojure-mxnet.ndarray-api
"Experimental NDArray API"
(:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max
min repeat reverse set sort take to-array empty shuffle
ref])

(:require [org.apache.clojure-mxnet.base :as base]
[org.apache.clojure-mxnet.context :as mx-context]
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[clojure.reflect :as r]
[t6.from-scala.core :refer [$] :as $])
(:import (org.apache.mxnet NDArrayAPI)))

;; loads the generated functions into the namespace
(do (clojure.core/load "gen/ndarray_api"))
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns org.apache.clojure-mxnet.symbol-api
"Experimental Symbol API"
(:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max
min repeat reverse set sort take to-array empty sin
get apply shuffle ref])
(:require [org.apache.clojure-mxnet.base :as base]
[org.apache.clojure-mxnet.context :as mx-context]
[org.apache.clojure-mxnet.executor :as ex]
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[t6.from-scala.core :refer [$] :as $]
[org.apache.clojure-mxnet.ndarray :as ndarray])
(:import (org.apache.mxnet SymbolAPI)))

;; loads the generated functions into the namespace
(do (clojure.core/load "gen/symbol_api"))
6 changes: 4 additions & 2 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
"int<>" "vec-of-ints"
"float<>" "vec-of-floats"
"byte<>" "byte-array"
"java.lang.String<>" "vec-or-strings"
"org.apache.mxnet.NDArray" "ndarray"
"org.apache.mxnet.Symbol" "sym"
"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE" "double-or-float"})
Expand All @@ -49,7 +48,7 @@
"int<>" "vec-of-ints"
"float<>" "vec-of-floats"
"byte<>" "byte-array"
"java.lang.String<>" "vec-or-strings"
"java.lang.String<>" "vec-of-strings"
"org.apache.mxnet.Symbol" "sym"
"java.lang.Object" "object"})

Expand Down Expand Up @@ -152,9 +151,12 @@
(and (get targets "scala.collection.Seq") (instance? org.apache.mxnet.Symbol param)) ($/immutable-list param)
(and (get targets "scala.collection.Seq") (and (or (vector? param) (seq? param)) (empty? param))) (empty-list)
(and (get targets "scala.collection.Seq") (or (vector? param) (seq? param))) (apply $/immutable-list param)
(and (get targets "org.apache.mxnet.Shape") (or (vector? param) (seq? param) (empty? param))) (mx-shape/->shape param)
(and (get targets "int<>") (vector? param)) (int-array param)
(and (get targets "float<>") (vector? param)) (float-array param)
(and (get targets "java.lang.String<>") (vector? param)) (into-array param)
(and (get targets "org.apache.mxnet.NDArray<>") (vector? param)) (into-array param)
(and (get targets "org.apache.mxnet.Symbol<>") (vector? param)) (into-array param)
(and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (instance? Float param)) (primitives/mx-float param)
(and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (number? param)) (primitives/mx-double param)
:else param))
Expand Down
148 changes: 146 additions & 2 deletions contrib/clojure-package/test/dev/generator_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,127 @@
(is (= transformed-params (gen/symbol-transform-param-name
(:parameter-types (symbol-reflect-info "floor")))))))

(deftest test-gen-op-info
(testing "activation"
(let [activation-info (gen/gen-op-info "Activation")]
(is (= "activation" (:fn-name activation-info)))
(is (string? (:fn-description activation-info)))
(is (= 2 (-> activation-info :args count)))
(is (= "" (:key-var-num-args activation-info)))

(is (= "data" (-> activation-info :args first :name)))
(is (= "NDArray-or-Symbol" (-> activation-info :args first :type)))
(is (false? (-> activation-info :args first :optional?)))
(is (nil? (-> activation-info :args first :default)))
(is (string? (-> activation-info :args first :description)))

(is (= "act-type" (-> activation-info :args second :name)))
(is (= "'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'" (-> activation-info :args second :type)))
(is (false? (-> activation-info :args second :optional?)))
(is (nil? (-> activation-info :args second :default)))
(is (string? (-> activation-info :args second :description)))))

(testing "argmin"
(let [argmin-info (gen/gen-op-info "argmin")]
(is (= "argmin" (:fn-name argmin-info)))
(is (= 3 (-> argmin-info :args count)))

(is (= "data" (-> argmin-info :args (nth 0) :name)))
(is (= "NDArray-or-Symbol" (-> argmin-info :args (nth 0) :type)))
(is (false? (-> argmin-info :args (nth 0) :optional?)))

(is (= "axis" (-> argmin-info :args (nth 1) :name)))
(is (= "int or None" (-> argmin-info :args (nth 1) :type)))
(is (= "'None'" (-> argmin-info :args (nth 1) :default)))
(is (true? (-> argmin-info :args (nth 1) :optional?)))

(is (= "keepdims" (-> argmin-info :args (nth 2) :name)))
(is (= "boolean" (-> argmin-info :args (nth 2) :type)))
(is (= "0" (-> argmin-info :args (nth 2) :default)))
(is (true? (-> argmin-info :args (nth 2) :optional?)))))

(testing "concat"
(let [concat-info (gen/gen-op-info "Concat")]
(is (= "concat" (:fn-name concat-info)))
(is (= 3 (-> concat-info :args count)))
(is (= "num-args" (:key-var-num-args concat-info)))

(is (= "data" (-> concat-info :args (nth 0) :name)))
(is (= "NDArray-or-Symbol[]" (-> concat-info :args (nth 0) :type)))
(is (false? (-> concat-info :args (nth 0) :optional?)))

(is (= "num-args" (-> concat-info :args (nth 1) :name)))
(is (= "int" (-> concat-info :args (nth 1) :type)))
(is (false? (-> concat-info :args (nth 1) :optional?)))

(is (= "dim" (-> concat-info :args (nth 2) :name)))
(is (= "int" (-> concat-info :args (nth 2) :type)))
(is (= "'1'" (-> concat-info :args (nth 2) :default)))
(is (true? (-> concat-info :args (nth 2) :optional?)))))

(testing "convolution"
(let [convolution-info (gen/gen-op-info "Convolution")]

(is (= "convolution" (:fn-name convolution-info)))
(is (= 14 (-> convolution-info :args count)))
(is (= "" (:key-var-num-args convolution-info)))

(is (= "data" (-> convolution-info :args (nth 0) :name)))
(is (= "NDArray-or-Symbol" (-> convolution-info :args (nth 0) :type)))
(is (false? (-> convolution-info :args (nth 0) :optional?)))

(is (= "weight" (-> convolution-info :args (nth 1) :name)))
(is (= "NDArray-or-Symbol" (-> convolution-info :args (nth 1) :type)))
(is (false? (-> convolution-info :args (nth 1) :optional?)))

(is (= "kernel" (-> convolution-info :args (nth 3) :name)))
(is (= "Shape" (-> convolution-info :args (nth 3) :type)))
(is (= "(tuple)" (-> convolution-info :args (nth 3) :spec)))
(is (false? (-> convolution-info :args (nth 3) :optional?)))

(is (= "stride" (-> convolution-info :args (nth 4) :name)))
(is (= "Shape" (-> convolution-info :args (nth 4) :type)))
(is (= "(tuple)" (-> convolution-info :args (nth 4) :spec)))
(is (= "[]" (-> convolution-info :args (nth 4) :default)))
(is (true? (-> convolution-info :args (nth 4) :optional?)))

(is (= "num-filter" (-> convolution-info :args (nth 7) :name)))
(is (= "int" (-> convolution-info :args (nth 7) :type)))
(is (= "(non-negative)" (-> convolution-info :args (nth 7) :spec)))
(is (false? (-> convolution-info :args (nth 7) :optional?)))

(is (= "num-group" (-> convolution-info :args (nth 8) :name)))
(is (= "int" (-> convolution-info :args (nth 8) :type)))
(is (= "(non-negative)" (-> convolution-info :args (nth 8) :spec)))
(is (= "1" (-> convolution-info :args (nth 8) :default)))
(is (true? (-> convolution-info :args (nth 8) :optional?)))

(is (= "workspace" (-> convolution-info :args (nth 9) :name)))
(is (= "long" (-> convolution-info :args (nth 9) :type)))
(is (= "(non-negative)" (-> convolution-info :args (nth 9) :spec)))
(is (= "1024" (-> convolution-info :args (nth 9) :default)))
(is (true? (-> convolution-info :args (nth 9) :optional?)))

(is (= "no-bias" (-> convolution-info :args (nth 10) :name)))
(is (= "boolean" (-> convolution-info :args (nth 10) :type)))
(is (= "0" (-> convolution-info :args (nth 10) :default)))
(is (true? (-> convolution-info :args (nth 10) :optional?)))

(is (= "layout" (-> convolution-info :args (nth 13) :name)))
(is (= "None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'" (-> convolution-info :args (nth 13) :type)))
(is (= "'None'" (-> convolution-info :args (nth 13) :default)))
(is (true? (-> convolution-info :args (nth 13) :optional?)))))

(testing "element wise sum"
(let [element-wise-sum-info (gen/gen-op-info "ElementWiseSum")]
(is (= "add-n" (:fn-name element-wise-sum-info)))
(is (= 1 (-> element-wise-sum-info :args count)))
(is (= "num-args" (:key-var-num-args element-wise-sum-info)))

(is (= "args" (-> element-wise-sum-info :args (nth 0) :name)))
(is (= "NDArray-or-Symbol[]" (-> element-wise-sum-info :args (nth 0) :type)))
(is (false? (-> element-wise-sum-info :args (nth 0) :optional?))))))

(deftest test-ndarray-transform-param-name
(let [params ["scala.collection.immutable.Map"
"scala.collection.Seq"]
Expand All @@ -68,7 +189,10 @@

(deftest test-rename-duplicate-params
(is (= ["foo" "bar" "baz"] (gen/rename-duplicate-params ["foo" "bar" "baz"])))
(is (= ["foo" "bar" "bar-1"] (gen/rename-duplicate-params ["foo" "bar" "bar"]))))
(is (= ["foo" "bar" "bar-1"] (gen/rename-duplicate-params ["foo" "bar" "bar"])))
(is (= ["foo" "bar" "bar-1" "foo-1"] (gen/rename-duplicate-params ["foo" "bar" "bar" "foo"])))
(is (= ["foo" "bar" "bar-1" "bar-2"] (gen/rename-duplicate-params ["foo" "bar" "bar" "bar"])))
(is (= ["foo" "bar" "bar-1" "bar-2" "foo-1" "baz"] (gen/rename-duplicate-params ["foo" "bar" "bar" "bar" "foo" "baz"]))))

(deftest test-is-symbol-hand-gen?
(is (not (false? (gen/is-symbol-hand-gen? (symbol-reflect-info "max")))))
Expand Down Expand Up @@ -191,7 +315,17 @@
(gen/gen-ndarray-function-arity op-name op-values)))))

(deftest test-write-to-file
(testing "symbol"
(testing "symbol-api"
(let [fname "test/test-symbol-api.clj"
_ (gen/write-to-file [(first gen/all-symbol-api-functions)
(second gen/all-symbol-api-functions)]
gen/symbol-api-gen-ns
fname)
good-contents (slurp "test/good-test-symbol-api.clj")
contents (slurp fname)]
(is (= good-contents contents))))

(testing "symbol"
(let [fname "test/test-symbol.clj"
_ (gen/write-to-file [(first gen/all-symbol-functions)]
gen/symbol-gen-ns
Expand All @@ -200,6 +334,16 @@
contents (slurp fname)]
(is (= good-contents contents))))

(testing "ndarray-api"
(let [fname "test/test-ndarray-api.clj"
_ (gen/write-to-file [(first gen/all-ndarray-api-functions)
(second gen/all-ndarray-api-functions)]
gen/ndarray-api-gen-ns
fname)
good-contents (slurp "test/good-test-ndarray-api.clj")
contents (slurp fname)]
(is (= good-contents contents))))

(testing "ndarray"
(let [fname "test/test-ndarray.clj"
_ (gen/write-to-file [(first gen/all-ndarray-functions)]
Expand Down
89 changes: 89 additions & 0 deletions contrib/clojure-package/test/good-test-ndarray-api.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
(ns
^{:doc "Experimental"}
org.apache.clojure-mxnet.ndarray-api
(:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max
min repeat reverse set sort take to-array empty shuffle
ref])
(:require [org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util])
(:import (org.apache.mxnet NDArrayAPI)))

;; Do not edit - this is auto-generated

;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;




(defn
activation
"Applies an activation function element-wise to the input.\n\nThe following activation functions are supported:\n\n- `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`\n- `sigmoid`: :math:`y = \\frac{1}{1 + exp(-x)}`\n- `tanh`: Hyperbolic tangent, :math:`y = \\frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`\n- `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`\n- `softsign`: :math:`y = \\frac{x}{1 + abs(x)}`\n\n\n\nDefined in src/operator/nn/activation.cc:L167\n\n`data`: The input array.\n`act-type`: Activation function to be applied.\n`out`: Output array. (optional)\n"
([data act-type] (activation {:data data, :act-type act-type}))
([{:keys [data act-type out], :or {out nil}, :as opts}]
(util/coerce-return
(NDArrayAPI/Activation data act-type (util/->option out)))))

(defn
batch-norm
"Batch normalization.\n\nNormalizes a data batch by mean and variance, and applies a scale ``gamma`` as\nwell as offset ``beta``.\n\nAssume the input has more than one dimension and we normalize along axis 1.\nWe first compute the mean and variance along this axis:\n\n.. math::\n\n data\\_mean[i] = mean(data[:,i,:,...]) \\\\\n data\\_var[i] = var(data[:,i,:,...])\n\nThen compute the normalized output, which has the same shape as input, as following:\n\n.. math::\n\n out[:,i,:,...] = \\frac{data[:,i,:,...] - data\\_mean[i]}{\\sqrt{data\\_var[i]+\\epsilon}} * gamma[i] + beta[i]\n\nBoth *mean* and *var* returns a scalar by treating the input as a vector.\n\nAssume the input has size *k* on axis 1, then both ``gamma`` and ``beta``\nhave shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and\nthe inverse of ``data_var``, which are needed for the backward pass. Note that gradient of these\ntwo outputs are blocked.\n\nBesides the inputs and the outputs, this operator accepts two auxiliary\nstates, ``moving_mean`` and ``moving_var``, which are *k*-length\nvectors. They are global statistics for the whole dataset, which are updated\nby::\n\n moving_mean = moving_mean * momentum + data_mean * (1 - momentum)\n moving_var = moving_var * momentum + data_var * (1 - momentum)\n\nIf ``use_global_stats`` is set to be true, then ``moving_mean`` and\n``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute\nthe output. It is often used during inference.\n\nThe parameter ``axis`` specifies which axis of the input shape denotes\nthe 'channel' (separately normalized groups). The default is 1. Specifying -1 sets the channel\naxis to be the last item in the input shape.\n\nBoth ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,\nthen set ``gamma`` to 1 and its gradient to 0.\n\n.. Note::\n When ``fix_gamma`` is set to True, no sparse support is provided. If ``fix_gamma is`` set to False,\n the sparse tensors will fallback.\n\n\n\nDefined in src/operator/nn/batch_norm.cc:L574\n\n`data`: Input data to batch normalization\n`gamma`: gamma array\n`beta`: beta array\n`moving-mean`: running mean of input\n`moving-var`: running variance of input\n`eps`: Epsilon to prevent div 0. Must be no less than CUDNN_BN_MIN_EPSILON defined in cudnn.h when using cudnn (usually 1e-5) (optional)\n`momentum`: Momentum for moving average (optional)\n`fix-gamma`: Fix gamma while training (optional)\n`use-global-stats`: Whether use global moving statistics instead of local batch-norm. This will force change batch-norm into a scale shift operator. (optional)\n`output-mean-var`: Output the mean and inverse std (optional)\n`axis`: Specify which shape axis the channel is specified (optional)\n`cudnn-off`: Do not select CUDNN operator, if available (optional)\n`out`: Output array. (optional)\n"
([data gamma beta moving-mean moving-var]
(batch-norm
{:data data,
:gamma gamma,
:beta beta,
:moving-mean moving-mean,
:moving-var moving-var}))
([{:keys
[data
gamma
beta
moving-mean
moving-var
eps
momentum
fix-gamma
use-global-stats
output-mean-var
axis
cudnn-off
out],
:or
{eps nil,
momentum nil,
fix-gamma nil,
use-global-stats nil,
output-mean-var nil,
axis nil,
cudnn-off nil,
out nil},
:as opts}]
(util/coerce-return
(NDArrayAPI/BatchNorm
data
gamma
beta
moving-mean
moving-var
(util/->option eps)
(util/->option momentum)
(util/->option fix-gamma)
(util/->option use-global-stats)
(util/->option output-mean-var)
(util/->option axis)
(util/->option cudnn-off)
(util/->option out)))))

Loading

0 comments on commit 834b48f

Please sign in to comment.