Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add tests for generation op info
Browse files Browse the repository at this point in the history
  • Loading branch information
kedarbellare committed Apr 7, 2019
1 parent 2ea22a7 commit 37e86b3
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 1 deletion.
2 changes: 1 addition & 1 deletion contrib/clojure-package/src/dev/generator.clj
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@
(do (.nnGetOpHandle libinfo op-name ref)
(.value ref))))

(defn- gen-op-info [op-name]
(defn gen-op-info [op-name]
(let [handle (get-op-handle op-name)
name (new Base$RefString nil)
desc (new Base$RefString nil)
Expand Down
102 changes: 102 additions & 0 deletions contrib/clojure-package/test/dev/generator_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,108 @@
(is (= transformed-params (gen/symbol-transform-param-name
(:parameter-types (symbol-reflect-info "floor")))))))

(deftest test-gen-op-info
(testing "activation"
(let [activation-info (gen/gen-op-info "Activation")]
(is (= "activation" (:fn-name activation-info)))
(is (string? (:fn-description activation-info)))
(is (= 2 (-> activation-info :args count)))
(is (= "" (:key-var-num-args activation-info)))

(is (= "data" (-> activation-info :args first :name)))
(is (= "NDArray-or-Symbol" (-> activation-info :args first :type)))
(is (false? (-> activation-info :args first :optional?)))
(is (nil? (-> activation-info :args first :default)))
(is (string? (-> activation-info :args first :description)))

(is (= "act-type" (-> activation-info :args second :name)))
(is (= "'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'" (-> activation-info :args second :type)))
(is (false? (-> activation-info :args second :optional?)))
(is (nil? (-> activation-info :args second :default)))
(is (string? (-> activation-info :args second :description)))))

(testing "concat"
(let [concat-info (gen/gen-op-info "Concat")]
(is (= "concat" (:fn-name concat-info)))
(is (= 3 (-> concat-info :args count)))
(is (= "num-args" (:key-var-num-args concat-info)))

(is (= "data" (-> concat-info :args (nth 0) :name)))
(is (= "NDArray-or-Symbol[]" (-> concat-info :args (nth 0) :type)))
(is (false? (-> concat-info :args (nth 0) :optional?)))

(is (= "num-args" (-> concat-info :args (nth 1) :name)))
(is (= "int" (-> concat-info :args (nth 1) :type)))
(is (false? (-> concat-info :args (nth 1) :optional?)))

(is (= "dim" (-> concat-info :args (nth 2) :name)))
(is (= "int" (-> concat-info :args (nth 2) :type)))
(is (= "'1'" (-> concat-info :args (nth 2) :default)))
(is (true? (-> concat-info :args (nth 2) :optional?)))))

(testing "convolution"
(let [convolution-info (gen/gen-op-info "Convolution")]

(is (= "convolution" (:fn-name convolution-info)))
(is (= 14 (-> convolution-info :args count)))
(is (= "" (:key-var-num-args convolution-info)))

(is (= "data" (-> convolution-info :args (nth 0) :name)))
(is (= "NDArray-or-Symbol" (-> convolution-info :args (nth 0) :type)))
(is (false? (-> convolution-info :args (nth 0) :optional?)))

(is (= "weight" (-> convolution-info :args (nth 1) :name)))
(is (= "NDArray-or-Symbol" (-> convolution-info :args (nth 1) :type)))
(is (false? (-> convolution-info :args (nth 1) :optional?)))

(is (= "kernel" (-> convolution-info :args (nth 3) :name)))
(is (= "Shape" (-> convolution-info :args (nth 3) :type)))
(is (= "(tuple)" (-> convolution-info :args (nth 3) :spec)))
(is (false? (-> convolution-info :args (nth 3) :optional?)))

(is (= "stride" (-> convolution-info :args (nth 4) :name)))
(is (= "Shape" (-> convolution-info :args (nth 4) :type)))
(is (= "(tuple)" (-> convolution-info :args (nth 4) :spec)))
(is (= "[]" (-> convolution-info :args (nth 4) :default)))
(is (true? (-> convolution-info :args (nth 4) :optional?)))

(is (= "num-filter" (-> convolution-info :args (nth 7) :name)))
(is (= "int" (-> convolution-info :args (nth 7) :type)))
(is (= "(non-negative)" (-> convolution-info :args (nth 7) :spec)))
(is (false? (-> convolution-info :args (nth 7) :optional?)))

(is (= "num-group" (-> convolution-info :args (nth 8) :name)))
(is (= "int" (-> convolution-info :args (nth 8) :type)))
(is (= "(non-negative)" (-> convolution-info :args (nth 8) :spec)))
(is (= "1" (-> convolution-info :args (nth 8) :default)))
(is (true? (-> convolution-info :args (nth 8) :optional?)))

(is (= "workspace" (-> convolution-info :args (nth 9) :name)))
(is (= "long" (-> convolution-info :args (nth 9) :type)))
(is (= "(non-negative)" (-> convolution-info :args (nth 9) :spec)))
(is (= "1024" (-> convolution-info :args (nth 9) :default)))
(is (true? (-> convolution-info :args (nth 9) :optional?)))

(is (= "no-bias" (-> convolution-info :args (nth 10) :name)))
(is (= "boolean" (-> convolution-info :args (nth 10) :type)))
(is (= "0" (-> convolution-info :args (nth 10) :default)))
(is (true? (-> convolution-info :args (nth 10) :optional?)))

(is (= "layout" (-> convolution-info :args (nth 13) :name)))
(is (= "None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'" (-> convolution-info :args (nth 13) :type)))
(is (= "'None'" (-> convolution-info :args (nth 13) :default)))
(is (true? (-> convolution-info :args (nth 13) :optional?)))))

(testing "element wise sum"
(let [element-wise-sum-info (gen/gen-op-info "ElementWiseSum")]
(is (= "add-n" (:fn-name element-wise-sum-info)))
(is (= 1 (-> element-wise-sum-info :args count)))
(is (= "num-args" (:key-var-num-args element-wise-sum-info)))

(is (= "args" (-> element-wise-sum-info :args (nth 0) :name)))
(is (= "NDArray-or-Symbol[]" (-> element-wise-sum-info :args (nth 0) :type)))
(is (false? (-> element-wise-sum-info :args (nth 0) :optional?))))))

(deftest test-ndarray-transform-param-name
(let [params ["scala.collection.immutable.Map"
"scala.collection.Seq"]
Expand Down

0 comments on commit 37e86b3

Please sign in to comment.