diff --git a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj
index 29ff36fe1ec0..94fd4f518c60 100644
--- a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj
+++ b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj
@@ -16,7 +16,9 @@
 ;;
 
 (ns cnn-text-classification.classifier
-  (:require [cnn-text-classification.data-helper :as data-helper]
+  (:require [clojure.java.io :as io]
+            [clojure.java.shell :refer [sh]]
+            [cnn-text-classification.data-helper :as data-helper]
             [org.apache.clojure-mxnet.eval-metric :as eval-metric]
             [org.apache.clojure-mxnet.io :as mx-io]
             [org.apache.clojure-mxnet.module :as m]
@@ -26,12 +28,19 @@
             [org.apache.clojure-mxnet.context :as context])
   (:gen-class))
 
+(def data-dir "data/")
 (def mr-dataset-path "data/mr-data") ;; the MR polarity dataset path
 (def glove-file-path "data/glove/glove.6B.50d.txt")
 (def num-filter 100)
 (def num-label 2)
 (def dropout 0.5)
 
+;; Runs when this namespace is loaded: if the data directory does not exist,
+;; fetch the data with get_data.sh (path relative to the working directory).
+(when-not (.exists (io/file data-dir))
+  (println "Retrieving data for cnn text classification...")
+  (sh "./get_data.sh"))
+
 (defn shuffle-data [test-num {:keys [data label sentence-count sentence-size embedding-size]}]
   (println "Shuffling the data and splitting into training and test sets")
   (println {:sentence-count sentence-count
@@ -103,6 +112,6 @@
     ;;; omit max-examples if you want to run all the examples in the movie review dataset
     ;; to limit mem consumption set to something like 1000 and adjust test size to 100
     (println "Running with context devices of" devs)
-    (train-convnet {:devs [(context/cpu)] :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000})
+    (train-convnet {:devs devs :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000})
     ;; runs all the examples
     #_(train-convnet {:embedding-size 50 :batch-size 100 :test-size 1000 :num-epoch 10})))
diff --git a/contrib/clojure-package/examples/cnn-text-classification/test/cnn_text_classification/classifier_test.clj b/contrib/clojure-package/examples/cnn-text-classification/test/cnn_text_classification/classifier_test.clj
new file mode 100644
index 000000000000..883ba2da8c8e
--- /dev/null
+++ b/contrib/clojure-package/examples/cnn-text-classification/test/cnn_text_classification/classifier_test.clj
@@ -0,0 +1,26 @@
+(ns cnn-text-classification.classifier-test
+  (:require
+    [clojure.test :refer :all]
+    [org.apache.clojure-mxnet.module :as module]
+    [org.apache.clojure-mxnet.ndarray :as ndarray]
+    [org.apache.clojure-mxnet.util :as util]
+    [org.apache.clojure-mxnet.context :as context]
+    [cnn-text-classification.classifier :as classifier]))
+
+;
+; The one and only classifier test: calls train-convnet with :num-epoch 1 and
+; :max-examples 1000, then checks the data names and the element count of the
+; first output of the value it returns.
+;
+(deftest classifier-test
+  (let [train
+        (classifier/train-convnet
+          {:devs [(context/default-context)]
+           :embedding-size 50
+           :batch-size 10
+           :test-size 100
+           :num-epoch 1
+           :max-examples 1000})]
+    (is (= ["data"] (util/scala-vector->vec (module/data-names train))))
+    ;; NOTE(review): 20 is presumably batch-size (10) * num-label (2) -- confirm.
+    (is (= 20 (count (ndarray/->vec (-> train module/outputs first first)))))))