Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

#13385 [Clojure] - Turn examples into integration tests #13554

Merged
merged 1 commit into from
Dec 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
;;

(ns cnn-text-classification.classifier
(:require [cnn-text-classification.data-helper :as data-helper]
(:require [clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[cnn-text-classification.data-helper :as data-helper]
[org.apache.clojure-mxnet.eval-metric :as eval-metric]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.module :as m]
Expand All @@ -26,12 +28,18 @@
[org.apache.clojure-mxnet.context :as context])
(:gen-class))

(def data-dir "data/")
(def mr-dataset-path "data/mr-data") ;; the MR polarity dataset path
(def glove-file-path "data/glove/glove.6B.50d.txt")
(def num-filter 100)
(def num-label 2)
(def dropout 0.5)



(when-not (.exists (io/file (str data-dir)))
(do (println "Retrieving data for cnn text classification...") (sh "./get_data.sh")))

(defn shuffle-data [test-num {:keys [data label sentence-count sentence-size embedding-size]}]
(println "Shuffling the data and splitting into training and test sets")
(println {:sentence-count sentence-count
Expand Down Expand Up @@ -103,10 +111,10 @@
;;; omit max-examples if you want to run all the examples in the movie review dataset
;; to limit mem consumption set to something like 1000 and adjust test size to 100
(println "Running with context devices of" devs)
(train-convnet {:devs [(context/cpu)] :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000})
(train-convnet {:devs devs :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000})
;; runs all the examples
#_(train-convnet {:embedding-size 50 :batch-size 100 :test-size 1000 :num-epoch 10})))

(comment
(train-convnet {:devs [(context/cpu)] :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000}))
(train-convnet {:devs devs :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000}))

Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns cnn-text-classification.classifier-test
(:require
[clojure.test :refer :all]
[org.apache.clojure-mxnet.module :as module]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.context :as context]
[cnn-text-classification.classifier :as classifier]))

;
; The one and unique classifier test
;
(deftest classifier-test
(let [train
(classifier/train-convnet
{:devs [(context/default-context)]
:embedding-size 50
:batch-size 10
:test-size 100
:num-epoch 1
:max-examples 1000})]
(is (= ["data"] (util/scala-vector->vec (module/data-names train))))
(is (= 20 (count (ndarray/->vec (-> train module/outputs first first)))))))
;(prn (util/scala-vector->vec (data-shapes train)))
;(prn (util/scala-vector->vec (label-shapes train)))
;(prn (output-names train))
;(prn (output-shapes train))
3 changes: 2 additions & 1 deletion contrib/clojure-package/examples/gan/project.clj
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
:plugins [[lein-cljfmt "0.5.7"]]
:dependencies [[org.clojure/clojure "1.9.0"]
[org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]
[nu.pattern/opencv "2.4.9-7"]]
[org.openpnp/opencv "3.4.2-1"]
]
:main gan.gan-mnist)
6 changes: 4 additions & 2 deletions contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@

(save-img-diff i n calc-diff))))

(defn train [devs]
(defn train
([devs] (train devs num-epoch))
([devs num-epoch]
(let [mod-d (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]})
(m/bind {:data-shapes (mx-io/provide-data-desc mnist-iter)
:label-shapes (mx-io/provide-label-desc mnist-iter)
Expand Down Expand Up @@ -203,7 +205,7 @@
(save-img-gout i n (ndarray/copy (ffirst out-g)))
(save-img-data i n batch)
(calc-diff i n (ffirst diff-d)))
(inc n)))))))
(inc n))))))))

(defn -main [& args]
(let [[dev dev-num] args
Expand Down
4 changes: 2 additions & 2 deletions contrib/clojure-package/examples/gan/src/gan/viz.clj
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
(:import (nu.pattern OpenCV)
(org.opencv.core Core CvType Mat Size)
(org.opencv.imgproc Imgproc)
(org.opencv.highgui Highgui)))
(org.opencv.imgcodecs Imgcodecs)))

;;; Viz stuff
(OpenCV/loadShared)
Expand Down Expand Up @@ -83,5 +83,5 @@
_ (Core/vconcat (java.util.ArrayList. line-mats) result)]
(do
(Imgproc/resize result resized-img (new Size (* (.width result) 1.5) (* (.height result) 1.5)))
(Highgui/imwrite (str output-path title ".jpg") resized-img)
(Imgcodecs/imwrite (str output-path title ".jpg") resized-img)
(Thread/sleep 1000))))
25 changes: 25 additions & 0 deletions contrib/clojure-package/examples/gan/test/gan/gan_test.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns gan.gan_test
(:require
[gan.gan-mnist :refer :all]
[org.apache.clojure-mxnet.context :as context]
[clojure.test :refer :all]))

(deftest check-pdf
(train [(context/cpu)] 1))
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
(def batch-size 10) ;; the batch size
(def optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.0}))
(def eval-metric (eval-metric/accuracy))
(def num-epoch 5) ;; the number of training epochs
(def num-epoch 1) ;; the number of training epochs
(def kvstore "local") ;; the kvstore type
;;; Note to run distributed you might need to complile the engine with an option set
(def role "worker") ;; scheduler/server/worker
Expand Down Expand Up @@ -82,7 +82,9 @@
(sym/fully-connected "fc3" {:data data :num-hidden 10})
(sym/softmax-output "softmax" {:data data})))

(defn start [devs]
(defn start
([devs] (start devs num-epoch))
([devs _num-epoch]
(when scheduler-host
(println "Initing PS enviornments with " envs)
(kvstore-server/init envs))
Expand All @@ -94,14 +96,18 @@
(do
(println "Starting Training of MNIST ....")
(println "Running with context devices of" devs)
(let [mod (m/module (get-symbol) {:contexts devs})]
(m/fit mod {:train-data train-data
(let [_mod (m/module (get-symbol) {:contexts devs})]
(m/fit _mod {:train-data train-data
:eval-data test-data
:num-epoch num-epoch
:num-epoch _num-epoch
:fit-params (m/fit-params {:kvstore kvstore
:optimizer optimizer
:eval-metric eval-metric})}))
(println "Finish fit"))))
:eval-metric eval-metric})})
(println "Finish fit")
_mod
)

))))

(defn -main [& args]
(let [[dev dev-num] args
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns imclassification.train-mnist-test
(:require
[clojure.test :refer :all]
[clojure.java.io :as io]
[clojure.string :as s]
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.module :as module]
[imclassification.train-mnist :as mnist]))

(defn- file-to-filtered-seq [file]
(->>
file
(io/file)
(io/reader)
(line-seq)
(filter #(not (s/includes? % "mxnet_version")))))

(deftest mnist-two-epochs-test
(module/save-checkpoint (mnist/start [(context/cpu)] 2) {:prefix "target/test" :epoch 2})
(is (=
(file-to-filtered-seq "test/test-symbol.json.ref")
(file-to-filtered-seq "target/test-symbol.json"))))
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
{
"nodes": [
{
"op": "null",
"name": "data",
"inputs": []
},
{
"op": "null",
"name": "fc1_weight",
"attrs": {"num_hidden": "128"},
"inputs": []
},
{
"op": "null",
"name": "fc1_bias",
"attrs": {"num_hidden": "128"},
"inputs": []
},
{
"op": "FullyConnected",
"name": "fc1",
"attrs": {"num_hidden": "128"},
"inputs": [[0, 0, 0], [1, 0, 0], [2, 0, 0]]
},
{
"op": "Activation",
"name": "relu1",
"attrs": {"act_type": "relu"},
"inputs": [[3, 0, 0]]
},
{
"op": "null",
"name": "fc2_weight",
"attrs": {"num_hidden": "64"},
"inputs": []
},
{
"op": "null",
"name": "fc2_bias",
"attrs": {"num_hidden": "64"},
"inputs": []
},
{
"op": "FullyConnected",
"name": "fc2",
"attrs": {"num_hidden": "64"},
"inputs": [[4, 0, 0], [5, 0, 0], [6, 0, 0]]
},
{
"op": "Activation",
"name": "relu2",
"attrs": {"act_type": "relu"},
"inputs": [[7, 0, 0]]
},
{
"op": "null",
"name": "fc3_weight",
"attrs": {"num_hidden": "10"},
"inputs": []
},
{
"op": "null",
"name": "fc3_bias",
"attrs": {"num_hidden": "10"},
"inputs": []
},
{
"op": "FullyConnected",
"name": "fc3",
"attrs": {"num_hidden": "10"},
"inputs": [[8, 0, 0], [9, 0, 0], [10, 0, 0]]
},
{
"op": "null",
"name": "softmax_label",
"inputs": []
},
{
"op": "SoftmaxOutput",
"name": "softmax",
"inputs": [[11, 0, 0], [12, 0, 0]]
}
],
"arg_nodes": [0, 1, 2, 5, 6, 9, 10, 12],
"node_row_ptr": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14
],
"heads": [[13, 0, 0]],
"attrs": {"mxnet_version": ["int", 10400]}
}
29 changes: 29 additions & 0 deletions contrib/clojure-package/examples/module/test/mnist_mlp_test.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;
(ns mnist-mlp-test
(:require
[mnist-mlp :refer :all]
[org.apache.clojure-mxnet.context :as context]
[clojure.test :refer :all]))

(deftest run-those-tests
(let [devs [(context/cpu)]]
(run-intermediate-level-api :devs devs)
(run-intermediate-level-api :devs devs :load-model-epoch (dec num-epoch))
(run-high-level-api devs)
(run-prediction-iterator-api devs)
(run-predication-and-calc-accuracy-manually devs)))
Loading