Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
#13385 [Clojure] - Turn examples into integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hellonico committed Dec 9, 2018
1 parent f2ca66f commit ba812a4
Show file tree
Hide file tree
Showing 22 changed files with 439 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
;;

(ns cnn-text-classification.classifier
(:require [cnn-text-classification.data-helper :as data-helper]
(:require [clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[cnn-text-classification.data-helper :as data-helper]
[org.apache.clojure-mxnet.eval-metric :as eval-metric]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.module :as m]
Expand All @@ -26,12 +28,18 @@
[org.apache.clojure-mxnet.context :as context])
(:gen-class))

(def data-dir "data/")
(def mr-dataset-path "data/mr-data") ;; the MR polarity dataset path
(def glove-file-path "data/glove/glove.6B.50d.txt")
(def num-filter 100)
(def num-label 2)
(def dropout 0.5)



(when-not (.exists (io/file (str data-dir)))
(do (println "Retrieving data for cnn text classification...") (sh "./get_data.sh")))

(defn shuffle-data [test-num {:keys [data label sentence-count sentence-size embedding-size]}]
(println "Shuffling the data and splitting into training and test sets")
(println {:sentence-count sentence-count
Expand Down Expand Up @@ -103,10 +111,10 @@
;;; omit max-examples if you want to run all the examples in the movie review dataset
;; to limit mem consumption set to something like 1000 and adjust test size to 100
(println "Running with context devices of" devs)
(train-convnet {:devs [(context/cpu)] :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000})
(train-convnet {:devs devs :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000})
;; runs all the examples
#_(train-convnet {:embedding-size 50 :batch-size 100 :test-size 1000 :num-epoch 10})))

(comment
(train-convnet {:devs [(context/cpu)] :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000}))
(train-convnet {:devs devs :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000}))

Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns cnn-text-classification.classifier-test
(:require
[clojure.test :refer :all]
[org.apache.clojure-mxnet.module :as module]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.context :as context]
[cnn-text-classification.classifier :as classifier]))

;
; The one and unique classifier test
;
(deftest classifier-test
(let [train
(classifier/train-convnet
{:devs [(context/default-context)]
:embedding-size 50
:batch-size 10
:test-size 100
:num-epoch 1
:max-examples 1000})]
(is (= ["data"] (util/scala-vector->vec (module/data-names train))))
(is (= 20 (count (ndarray/->vec (-> train module/outputs first first)))))))
;(prn (util/scala-vector->vec (data-shapes train)))
;(prn (util/scala-vector->vec (label-shapes train)))
;(prn (output-names train))
;(prn (output-shapes train))
6 changes: 4 additions & 2 deletions contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@

(save-img-diff i n calc-diff))))

(defn train [devs]
(defn train
([devs] (train devs num-epoch))
([devs num-epoch]
(let [mod-d (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]})
(m/bind {:data-shapes (mx-io/provide-data-desc mnist-iter)
:label-shapes (mx-io/provide-label-desc mnist-iter)
Expand Down Expand Up @@ -203,7 +205,7 @@
(save-img-gout i n (ndarray/copy (ffirst out-g)))
(save-img-data i n batch)
(calc-diff i n (ffirst diff-d)))
(inc n)))))))
(inc n))))))))

(defn -main [& args]
(let [[dev dev-num] args
Expand Down
25 changes: 25 additions & 0 deletions contrib/clojure-package/examples/gan/test/gan/gan_test.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns gan.gan_test
(:require
[gan.gan-mnist :refer :all]
[org.apache.clojure-mxnet.context :as context]
[clojure.test :refer :all]))

(deftest check-pdf
(train [(context/cpu)] 1))
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
(def batch-size 10) ;; the batch size
(def optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.0}))
(def eval-metric (eval-metric/accuracy))
(def num-epoch 5) ;; the number of training epochs
(def num-epoch 1) ;; the number of training epochs
(def kvstore "local") ;; the kvstore type
;;; Note to run distributed you might need to complile the engine with an option set
(def role "worker") ;; scheduler/server/worker
Expand Down Expand Up @@ -82,7 +82,9 @@
(sym/fully-connected "fc3" {:data data :num-hidden 10})
(sym/softmax-output "softmax" {:data data})))

(defn start [devs]
(defn start
([devs] (start devs num-epoch))
([devs _num-epoch]
(when scheduler-host
(println "Initing PS enviornments with " envs)
(kvstore-server/init envs))
Expand All @@ -94,14 +96,18 @@
(do
(println "Starting Training of MNIST ....")
(println "Running with context devices of" devs)
(let [mod (m/module (get-symbol) {:contexts devs})]
(m/fit mod {:train-data train-data
(let [_mod (m/module (get-symbol) {:contexts devs})]
(m/fit _mod {:train-data train-data
:eval-data test-data
:num-epoch num-epoch
:num-epoch _num-epoch
:fit-params (m/fit-params {:kvstore kvstore
:optimizer optimizer
:eval-metric eval-metric})}))
(println "Finish fit"))))
:eval-metric eval-metric})})
(println "Finish fit")
_mod
)

))))

(defn -main [& args]
(let [[dev dev-num] args
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns imclassification.train-mnist-test
(:require
[clojure.test :refer :all]
[clojure.java.io :as io]
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.module :as module]
[imclassification.train-mnist :as mnist]))

(deftest mnist-two-epochs-test
(module/save-checkpoint (mnist/start [(context/cpu)] 2) {:prefix "target/test" :epoch 2})
(is (= (slurp "test/test-0002.params") (slurp "target/test-0002.params"))))
29 changes: 29 additions & 0 deletions contrib/clojure-package/examples/module/test/mnist_mlp_test.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;
(ns mnist-mlp-test
(:require
[mnist-mlp :refer :all]
[org.apache.clojure-mxnet.context :as context]
[clojure.test :refer :all]))

(deftest run-those-tests
(let [devs [(context/cpu)]]
(run-intermediate-level-api :devs devs)
(run-intermediate-level-api :devs devs :load-model-epoch (dec num-epoch))
(run-high-level-api devs)
(run-prediction-iterator-api devs)
(run-predication-and-calc-accuracy-manually devs)))
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns multi_label_test
(:require
[multi-label.core :as label]
[clojure.java.io :as io]
[org.apache.clojure-mxnet.context :as context]
[clojure.test :refer :all]))

(deftest run-multi-label
(label/train [(context/cpu)]))
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
(def content-weight 5) ;; the weight for the content image
(def blur-radius 1) ;; the blur filter radius
(def output-dir "output")
(def lr 10) ;; the learning rate
(def lr 10.0) ;; the learning rate
(def tv-weight 0.01) ;; the magnitude on the tv loss
(def num-epochs 1000)
(def num-channels 3)
Expand Down Expand Up @@ -157,9 +157,10 @@
out (ndarray/* out tv-weight)]
(sym/bind out ctx {"img" img "kernel" kernel}))))

(defn train [devs]

(let [dev (first devs)
(defn train
([devs] (train devs 20))
([devs n-epochs]
(let [dev (first devs)
content-np (preprocess-content-image content-image max-long-edge)
content-np-shape (mx-shape/->vec (ndarray/shape content-np))
style-np (preprocess-style-image style-image content-np-shape)
Expand Down Expand Up @@ -212,7 +213,7 @@
tv-grad-executor (get-tv-grad-executor img dev tv-weight)
eps 0.0
e 0]
(doseq [i (range 20)]
(doseq [i (range n-epochs)]
(ndarray/set (:data model-executor) img)
(-> (:executor model-executor)
(executor/forward)
Expand All @@ -237,8 +238,10 @@
(println "Epoch " i "relative change " eps)
(when (zero? (mod i 2))
(save-image (ndarray/copy img) (str output-dir "/out_" i ".png") blur-radius true)))

(ndarray/set old-img img))))
(ndarray/set old-img img))
; (save-image (ndarray/copy img) (str output-dir "/final.png") 0 false)
; (postprocess-image img)
)))

(defn -main [& args]
;;; Note this only works on cpu right now
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns neural-style.vgg-19-test
(:require
[clojure.test :refer :all]
[mikera.image.core :as img]
[clojure.java.io :as io]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.context :as context]
[neural-style.core :as neural]))

(defn pic-to-ndarray-vec[path]
(-> path
img/load-image
neural/image->ndarray
ndarray/->vec))

(defn last-modified-check[x]
(let [t (- (System/currentTimeMillis) (.lastModified x)) ]
(if (> 10000 t) ; 10 seconds
x
(throw (Exception. (str "Generated File Too Old: (" t " ms) [" x "]"))))))

(defn latest-pic-to-ndarray-vec[folder]
(->> folder
io/as-file
(.listFiles)
(sort-by #(.lastModified %))
last
(last-modified-check)
(.getPath)
pic-to-ndarray-vec))

;
; The one and unique classifier test
;
(deftest vgg-19-test
(neural/train [(context/cpu)] 3)
(is (=
(latest-pic-to-ndarray-vec "output")
(pic-to-ndarray-vec "test/ref_out_2.png"))))
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
(def profiler-mode "symbolic") ;; can be symbolic, imperative, api, mem
(def output-path ".") ;; the profile file output directory
(def profiler-name "profile-matmul-20iter.json")
(def iter-num 100)
(def begin-profiling-iter 50)
(def end-profiling-iter 70)
(def iter-num 5)
(def begin-profiling-iter 0)
(def end-profiling-iter 1)
(def gpu? false)

(defn run []
Expand Down
Loading

0 comments on commit ba812a4

Please sign in to comment.