diff --git a/contrib/clojure-package/examples/captcha/.gitignore b/contrib/clojure-package/examples/captcha/.gitignore new file mode 100644 index 000000000000..e1569bd89020 --- /dev/null +++ b/contrib/clojure-package/examples/captcha/.gitignore @@ -0,0 +1,3 @@ +/.lein-* +/.nrepl-port +images/* diff --git a/contrib/clojure-package/examples/captcha/README.md b/contrib/clojure-package/examples/captcha/README.md new file mode 100644 index 000000000000..6b593b2f1c65 --- /dev/null +++ b/contrib/clojure-package/examples/captcha/README.md @@ -0,0 +1,61 @@ +# Captcha + +This is the clojure version of [captcha recognition](/~https://github.com/xlvector/learning-dl/tree/master/mxnet/ocr) +example by xlvector and mirrors the R captcha example. It can be used as an +example of multi-label training. For the following captcha example, we consider it as an +image with 4 labels and train a CNN over the data set. + +![captcha example](captcha_example.png) + +## Installation + +Before you run this example, make sure that you have the clojure package +installed. In the main clojure package directory, do `lein install`. +Then you can run `lein install` in this directory. + +## Usage + +### Training + +First the OCR model needs to be trained based on [labeled data](https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/captcha_example.zip). +The training can be started using the following: +``` +$ lein train [:cpu|:gpu] [num-devices] +``` +This downloads the training/evaluation data using the `get_data.sh` script +before starting training. + +It is possible that you will encounter some out-of-memory issues while training using :gpu on Ubuntu +linux (18.04). However, the command `lein train` (training on one CPU) may resolve the issue. + +The training runs for 10 iterations by default and saves the model with the +prefix `ocr-`. The model achieved an exact match accuracy of ~0.954 and +~0.628 on training and validation data respectively. + +### Inference + +Once the model has been saved, it can be used for prediction. This can be done +by running: +``` +$ lein infer +INFO MXNetJVM: Try loading mxnet-scala from native path. +INFO MXNetJVM: Try loading mxnet-scala-linux-x86_64-gpu from native path. +INFO MXNetJVM: Try loading mxnet-scala-linux-x86_64-cpu from native path. +WARN MXNetJVM: MXNet Scala native library not found in path. Copying native library from the archive. Consider installing the library somewhere in the path (for Windows: PATH, for Linux: LD_LIBRARY_PATH), or specifying by Java cmd option -Djava.library.path=[lib path]. +WARN org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis +INFO org.apache.mxnet.infer.Predictor: Latency increased due to batchSize mismatch 8 vs 1 +WARN org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis +WARN org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis +CAPTCHA output: 6643 +INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865/libmxnet.so +INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865/mxnet-scala +INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865 +``` +The model runs on `captcha_example.png` by default. + +It can be run on other generated captcha images as well. The script +`gen_captcha.py` generates random captcha images for length 4. +Before running the python script, you will need to install the [captcha](https://pypi.org/project/captcha/) +library using `pip3 install --user captcha`. The captcha images are generated +in the `images/` folder and we can run the prediction using +`lein infer images/7534.png`. diff --git a/contrib/clojure-package/examples/captcha/captcha_example.png b/contrib/clojure-package/examples/captcha/captcha_example.png new file mode 100644 index 000000000000..09b84f7190fa Binary files /dev/null and b/contrib/clojure-package/examples/captcha/captcha_example.png differ diff --git a/contrib/clojure-package/examples/captcha/gen_captcha.py b/contrib/clojure-package/examples/captcha/gen_captcha.py new file mode 100755 index 000000000000..43e0d26fb961 --- /dev/null +++ b/contrib/clojure-package/examples/captcha/gen_captcha.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from captcha.image import ImageCaptcha +import os +import random + +length = 4 +width = 160 +height = 60 +IMAGE_DIR = "images" + + +def random_text(): + return ''.join(str(random.randint(0, 9)) + for _ in range(length)) + + +if __name__ == '__main__': + image = ImageCaptcha(width=width, height=height) + captcha_text = random_text() + if not os.path.exists(IMAGE_DIR): + os.makedirs(IMAGE_DIR) + image.write(captcha_text, os.path.join(IMAGE_DIR, captcha_text + ".png")) diff --git a/contrib/clojure-package/examples/captcha/get_data.sh b/contrib/clojure-package/examples/captcha/get_data.sh new file mode 100755 index 000000000000..baa7f9eb818f --- /dev/null +++ b/contrib/clojure-package/examples/captcha/get_data.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -evx + +EXAMPLE_ROOT=$(cd "$(dirname $0)"; pwd) + +data_path=$EXAMPLE_ROOT + +if [ ! -f "$data_path/captcha_example.zip" ]; then + wget https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/captcha_example.zip -P $data_path +fi + +if [ ! -f "$data_path/captcha_example/captcha_train.rec" ]; then + unzip $data_path/captcha_example.zip -d $data_path +fi diff --git a/contrib/clojure-package/examples/captcha/project.clj b/contrib/clojure-package/examples/captcha/project.clj new file mode 100644 index 000000000000..fa37fecbe035 --- /dev/null +++ b/contrib/clojure-package/examples/captcha/project.clj @@ -0,0 +1,28 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject captcha "0.1.0-SNAPSHOT" + :description "Captcha recognition via multi-label classification" + :plugins [[lein-cljfmt "0.5.7"]] + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]] + :main ^:skip-aot captcha.train-ocr + :profiles {:train {:main captcha.train-ocr} + :infer {:main captcha.infer-ocr} + :uberjar {:aot :all}} + :aliases {"train" ["with-profile" "train" "run"] + "infer" ["with-profile" "infer" "run"]}) diff --git a/contrib/clojure-package/examples/captcha/src/captcha/consts.clj b/contrib/clojure-package/examples/captcha/src/captcha/consts.clj new file mode 100644 index 000000000000..318e0d806873 --- /dev/null +++ b/contrib/clojure-package/examples/captcha/src/captcha/consts.clj @@ -0,0 +1,27 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns captcha.consts) + +(def batch-size 8) +(def channels 3) +(def height 30) +(def width 80) +(def data-shape [channels height width]) +(def num-labels 10) +(def label-width 4) +(def model-prefix "ocr") diff --git a/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj b/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj new file mode 100644 index 000000000000..f6a648e9867b --- /dev/null +++ b/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj @@ -0,0 +1,56 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns captcha.infer-ocr + (:require [captcha.consts :refer :all] + [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.infer :as infer] + [org.apache.clojure-mxnet.layout :as layout] + [org.apache.clojure-mxnet.ndarray :as ndarray])) + +(defn create-predictor + [] + (let [data-desc {:name "data" + :shape [batch-size channels height width] + :layout layout/NCHW + :dtype dtype/FLOAT32} + label-desc {:name "label" + :shape [batch-size label-width] + :layout layout/NT + :dtype dtype/FLOAT32} + factory (infer/model-factory model-prefix + [data-desc label-desc])] + (infer/create-predictor factory))) + +(defn -main + [& args] + (let [[filename] args + image-fname (or filename "captcha_example.png") + image-ndarray (-> image-fname + infer/load-image-from-file + (infer/reshape-image width height) + (infer/buffered-image-to-pixels [channels height width]) + (ndarray/expand-dims 0)) + label-ndarray (ndarray/zeros [1 label-width]) + predictor (create-predictor) + predictions (-> (infer/predict-with-ndarray + predictor + [image-ndarray label-ndarray]) + first + (ndarray/argmax 1) + ndarray/->vec)] + (println "CAPTCHA output:" (apply str (mapv int predictions))))) diff --git a/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj new file mode 100644 index 000000000000..91ec2fff3af7 --- /dev/null +++ b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj @@ -0,0 +1,156 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns captcha.train-ocr + (:require [captcha.consts :refer :all] + [clojure.java.io :as io] + [clojure.java.shell :refer [sh]] + [org.apache.clojure-mxnet.callback :as callback] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.initializer :as initializer] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.optimizer :as optimizer] + [org.apache.clojure-mxnet.symbol :as sym]) + (:gen-class)) + +(when-not (.exists (io/file "captcha_example/captcha_train.lst")) + (sh "./get_data.sh")) + +(defonce train-data + (mx-io/image-record-iter {:path-imgrec "captcha_example/captcha_train.rec" + :path-imglist "captcha_example/captcha_train.lst" + :batch-size batch-size + :label-width label-width + :data-shape data-shape + :shuffle true + :seed 42})) + +(defonce eval-data + (mx-io/image-record-iter {:path-imgrec "captcha_example/captcha_test.rec" + :path-imglist "captcha_example/captcha_test.lst" + :batch-size batch-size + :label-width label-width + :data-shape data-shape})) + +(defn accuracy + [label pred & {:keys [by-character] + :or {by-character false} :as opts}] + (let [[nr nc] (ndarray/shape-vec label) + pred-context (ndarray/context pred) + label-t (-> label + ndarray/transpose + (ndarray/reshape [-1]) + (ndarray/as-in-context pred-context)) + pred-label (ndarray/argmax pred 1) + matches (ndarray/equal label-t pred-label) + [digit-matches] (-> matches + ndarray/sum + ndarray/->vec) + [complete-matches] (-> matches + (ndarray/reshape [nc nr]) + (ndarray/sum 0) + (ndarray/equal label-width) + ndarray/sum + ndarray/->vec)] + (if by-character + (float (/ digit-matches nr nc)) + (float (/ complete-matches nr))))) + +(defn get-data-symbol + [] + (let [data (sym/variable "data") + ;; normalize the input pixels + scaled (sym/div (sym/- data 127) 128) + + conv1 (sym/convolution {:data scaled :kernel [5 5] :num-filter 32}) + pool1 (sym/pooling {:data conv1 :pool-type "max" :kernel [2 2] :stride [1 1]}) + relu1 (sym/activation {:data pool1 :act-type "relu"}) + + conv2 (sym/convolution {:data relu1 :kernel [5 5] :num-filter 32}) + pool2 (sym/pooling {:data conv2 :pool-type "avg" :kernel [2 2] :stride [1 1]}) + relu2 (sym/activation {:data pool2 :act-type "relu"}) + + conv3 (sym/convolution {:data relu2 :kernel [3 3] :num-filter 32}) + pool3 (sym/pooling {:data conv3 :pool-type "avg" :kernel [2 2] :stride [1 1]}) + relu3 (sym/activation {:data pool3 :act-type "relu"}) + + conv4 (sym/convolution {:data relu3 :kernel [3 3] :num-filter 32}) + pool4 (sym/pooling {:data conv4 :pool-type "avg" :kernel [2 2] :stride [1 1]}) + relu4 (sym/activation {:data pool4 :act-type "relu"}) + + flattened (sym/flatten {:data relu4}) + fc1 (sym/fully-connected {:data flattened :num-hidden 256}) + fc21 (sym/fully-connected {:data fc1 :num-hidden num-labels}) + fc22 (sym/fully-connected {:data fc1 :num-hidden num-labels}) + fc23 (sym/fully-connected {:data fc1 :num-hidden num-labels}) + fc24 (sym/fully-connected {:data fc1 :num-hidden num-labels})] + (sym/concat "concat" nil [fc21 fc22 fc23 fc24] {:dim 0}))) + +(defn get-label-symbol + [] + (as-> (sym/variable "label") label + (sym/transpose {:data label}) + (sym/reshape {:data label :shape [-1]}))) + +(defn create-captcha-net + [] + (let [scores (get-data-symbol) + labels (get-label-symbol)] + (sym/softmax-output {:data scores :label labels}))) + +(def optimizer + (optimizer/adam + {:learning-rate 0.0002 + :wd 0.00001 + :clip-gradient 10})) + +(defn train-ocr + [devs] + (println "Starting the captcha training ...") + (let [model (m/module + (create-captcha-net) + {:data-names ["data"] :label-names ["label"] + :contexts devs})] + (m/fit model {:train-data train-data + :eval-data eval-data + :num-epoch 10 + :fit-params (m/fit-params + {:kvstore "local" + :batch-end-callback + (callback/speedometer batch-size 100) + :initializer + (initializer/xavier {:factor-type "in" + :magnitude 2.34}) + :optimizer optimizer + :eval-metric (eval-metric/custom-metric + #(accuracy %1 %2) + "accuracy")})}) + (println "Finished the fit") + model)) + +(defn -main + [& args] + (let [[dev dev-num] args + num-devices (Integer/parseInt (or dev-num "1")) + devs (if (= dev ":gpu") + (mapv #(context/gpu %) (range num-devices)) + (mapv #(context/cpu %) (range num-devices))) + model (train-ocr devs)] + (m/save-checkpoint model {:prefix model-prefix :epoch 0}))) diff --git a/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj b/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj new file mode 100644 index 000000000000..ab785f7fedf2 --- /dev/null +++ b/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj @@ -0,0 +1,119 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns captcha.train-ocr-test + (:require [clojure.test :refer :all] + [captcha.consts :refer :all] + [captcha.train-ocr :refer :all] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.shape :as shape] + [org.apache.clojure-mxnet.util :as util])) + +(deftest test-consts + (is (= 8 batch-size)) + (is (= [3 30 80] data-shape)) + (is (= 4 label-width)) + (is (= 10 num-labels))) + +(deftest test-labeled-data + (let [train-batch (mx-io/next train-data) + eval-batch (mx-io/next eval-data) + allowed-labels (into #{} (map float (range 10)))] + (is (= 8 (-> train-batch mx-io/batch-index count))) + (is (= 8 (-> eval-batch mx-io/batch-index count))) + (is (= [8 3 30 80] (-> train-batch + mx-io/batch-data + first + ndarray/shape-vec))) + (is (= [8 3 30 80] (-> eval-batch + mx-io/batch-data + first + ndarray/shape-vec))) + (is (every? #(<= 0 % 255) (-> train-batch + mx-io/batch-data + first + ndarray/->vec))) + (is (every? #(<= 0 % 255) (-> eval-batch + mx-io/batch-data + first + ndarray/->vec))) + (is (= [8 4] (-> train-batch + mx-io/batch-label + first + ndarray/shape-vec))) + (is (= [8 4] (-> eval-batch + mx-io/batch-label + first + ndarray/shape-vec))) + (is (every? allowed-labels (-> train-batch + mx-io/batch-label + first + ndarray/->vec))) + (is (every? allowed-labels (-> eval-batch + mx-io/batch-label + first + ndarray/->vec))))) + +(deftest test-model + (let [batch (mx-io/next train-data) + model (m/module (create-captcha-net) + {:data-names ["data"] :label-names ["label"]}) + _ (m/bind model + {:data-shapes (mx-io/provide-data-desc train-data) + :label-shapes (mx-io/provide-label-desc train-data)}) + _ (m/init-params model) + _ (m/forward-backward model batch) + output-shapes (-> model + m/output-shapes + util/coerce-return-recursive) + outputs (-> model + m/outputs-merged + first) + grads (->> model m/grad-arrays (map first))] + (is (= [["softmaxoutput0_output" (shape/->shape [8 10])]] + output-shapes)) + (is (= [32 10] (-> outputs ndarray/shape-vec))) + (is (every? #(<= 0.0 % 1.0) (-> outputs ndarray/->vec))) + (is (= [[32 3 5 5] [32] ; convolution1 weights+bias + [32 32 5 5] [32] ; convolution2 weights+bias + [32 32 3 3] [32] ; convolution3 weights+bias + [32 32 3 3] [32] ; convolution4 weights+bias + [256 28672] [256] ; fully-connected1 weights+bias + [10 256] [10] ; 1st label scores + [10 256] [10] ; 2nd label scores + [10 256] [10] ; 3rd label scores + [10 256] [10]] ; 4th label scores + (map ndarray/shape-vec grads))))) + +(deftest test-accuracy + (let [labels (ndarray/array [1 2 3 4, + 5 6 7 8] + [2 4]) + pred-labels (ndarray/array [1 0, + 2 6, + 3 0, + 4 8] + [8]) + preds (ndarray/one-hot pred-labels 10)] + (is (float? (accuracy labels preds))) + (is (float? (accuracy labels preds :by-character false))) + (is (float? (accuracy labels preds :by-character true))) + (is (= 0.5 (accuracy labels preds))) + (is (= 0.5 (accuracy labels preds :by-character false))) + (is (= 0.75 (accuracy labels preds :by-character true)))))