From 37e86b3497537f78f66841b0b28d7099c34abcc4 Mon Sep 17 00:00:00 2001 From: Kedar Bellare Date: Sun, 7 Apr 2019 09:08:19 -0700 Subject: [PATCH] Add tests for generation op info --- contrib/clojure-package/src/dev/generator.clj | 2 +- .../test/dev/generator_test.clj | 102 ++++++++++++++++++ 2 files changed, 103 insertions(+), 1 deletion(-) diff --git a/contrib/clojure-package/src/dev/generator.clj b/contrib/clojure-package/src/dev/generator.clj index 5581587224dd..0038653d04a1 100644 --- a/contrib/clojure-package/src/dev/generator.clj +++ b/contrib/clojure-package/src/dev/generator.clj @@ -152,7 +152,7 @@ (do (.nnGetOpHandle libinfo op-name ref) (.value ref)))) -(defn- gen-op-info [op-name] +(defn gen-op-info [op-name] (let [handle (get-op-handle op-name) name (new Base$RefString nil) desc (new Base$RefString nil) diff --git a/contrib/clojure-package/test/dev/generator_test.clj b/contrib/clojure-package/test/dev/generator_test.clj index 4c6651713d7c..bafd3b905586 100644 --- a/contrib/clojure-package/test/dev/generator_test.clj +++ b/contrib/clojure-package/test/dev/generator_test.clj @@ -50,6 +50,108 @@ (is (= transformed-params (gen/symbol-transform-param-name (:parameter-types (symbol-reflect-info "floor"))))))) +(deftest test-gen-op-info + (testing "activation" + (let [activation-info (gen/gen-op-info "Activation")] + (is (= "activation" (:fn-name activation-info))) + (is (string? (:fn-description activation-info))) + (is (= 2 (-> activation-info :args count))) + (is (= "" (:key-var-num-args activation-info))) + + (is (= "data" (-> activation-info :args first :name))) + (is (= "NDArray-or-Symbol" (-> activation-info :args first :type))) + (is (false? (-> activation-info :args first :optional?))) + (is (nil? (-> activation-info :args first :default))) + (is (string? (-> activation-info :args first :description))) + + (is (= "act-type" (-> activation-info :args second :name))) + (is (= "'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'" (-> activation-info :args second :type))) + (is (false? (-> activation-info :args second :optional?))) + (is (nil? (-> activation-info :args second :default))) + (is (string? (-> activation-info :args second :description))))) + + (testing "concat" + (let [concat-info (gen/gen-op-info "Concat")] + (is (= "concat" (:fn-name concat-info))) + (is (= 3 (-> concat-info :args count))) + (is (= "num-args" (:key-var-num-args concat-info))) + + (is (= "data" (-> concat-info :args (nth 0) :name))) + (is (= "NDArray-or-Symbol[]" (-> concat-info :args (nth 0) :type))) + (is (false? (-> concat-info :args (nth 0) :optional?))) + + (is (= "num-args" (-> concat-info :args (nth 1) :name))) + (is (= "int" (-> concat-info :args (nth 1) :type))) + (is (false? (-> concat-info :args (nth 1) :optional?))) + + (is (= "dim" (-> concat-info :args (nth 2) :name))) + (is (= "int" (-> concat-info :args (nth 2) :type))) + (is (= "'1'" (-> concat-info :args (nth 2) :default))) + (is (true? (-> concat-info :args (nth 2) :optional?))))) + + (testing "convolution" + (let [convolution-info (gen/gen-op-info "Convolution")] + + (is (= "convolution" (:fn-name convolution-info))) + (is (= 14 (-> convolution-info :args count))) + (is (= "" (:key-var-num-args convolution-info))) + + (is (= "data" (-> convolution-info :args (nth 0) :name))) + (is (= "NDArray-or-Symbol" (-> convolution-info :args (nth 0) :type))) + (is (false? (-> convolution-info :args (nth 0) :optional?))) + + (is (= "weight" (-> convolution-info :args (nth 1) :name))) + (is (= "NDArray-or-Symbol" (-> convolution-info :args (nth 1) :type))) + (is (false? (-> convolution-info :args (nth 1) :optional?))) + + (is (= "kernel" (-> convolution-info :args (nth 3) :name))) + (is (= "Shape" (-> convolution-info :args (nth 3) :type))) + (is (= "(tuple)" (-> convolution-info :args (nth 3) :spec))) + (is (false? (-> convolution-info :args (nth 3) :optional?))) + + (is (= "stride" (-> convolution-info :args (nth 4) :name))) + (is (= "Shape" (-> convolution-info :args (nth 4) :type))) + (is (= "(tuple)" (-> convolution-info :args (nth 4) :spec))) + (is (= "[]" (-> convolution-info :args (nth 4) :default))) + (is (true? (-> convolution-info :args (nth 4) :optional?))) + + (is (= "num-filter" (-> convolution-info :args (nth 7) :name))) + (is (= "int" (-> convolution-info :args (nth 7) :type))) + (is (= "(non-negative)" (-> convolution-info :args (nth 7) :spec))) + (is (false? (-> convolution-info :args (nth 7) :optional?))) + + (is (= "num-group" (-> convolution-info :args (nth 8) :name))) + (is (= "int" (-> convolution-info :args (nth 8) :type))) + (is (= "(non-negative)" (-> convolution-info :args (nth 8) :spec))) + (is (= "1" (-> convolution-info :args (nth 8) :default))) + (is (true? (-> convolution-info :args (nth 8) :optional?))) + + (is (= "workspace" (-> convolution-info :args (nth 9) :name))) + (is (= "long" (-> convolution-info :args (nth 9) :type))) + (is (= "(non-negative)" (-> convolution-info :args (nth 9) :spec))) + (is (= "1024" (-> convolution-info :args (nth 9) :default))) + (is (true? (-> convolution-info :args (nth 9) :optional?))) + + (is (= "no-bias" (-> convolution-info :args (nth 10) :name))) + (is (= "boolean" (-> convolution-info :args (nth 10) :type))) + (is (= "0" (-> convolution-info :args (nth 10) :default))) + (is (true? (-> convolution-info :args (nth 10) :optional?))) + + (is (= "layout" (-> convolution-info :args (nth 13) :name))) + (is (= "None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'" (-> convolution-info :args (nth 13) :type))) + (is (= "'None'" (-> convolution-info :args (nth 13) :default))) + (is (true? (-> convolution-info :args (nth 13) :optional?))))) + + (testing "element wise sum" + (let [element-wise-sum-info (gen/gen-op-info "ElementWiseSum")] + (is (= "add-n" (:fn-name element-wise-sum-info))) + (is (= 1 (-> element-wise-sum-info :args count))) + (is (= "num-args" (:key-var-num-args element-wise-sum-info))) + + (is (= "args" (-> element-wise-sum-info :args (nth 0) :name))) + (is (= "NDArray-or-Symbol[]" (-> element-wise-sum-info :args (nth 0) :type))) + (is (false? (-> element-wise-sum-info :args (nth 0) :optional?)))))) + (deftest test-ndarray-transform-param-name (let [params ["scala.collection.immutable.Map" "scala.collection.Seq"]