Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
hellonico committed Dec 7, 2018
1 parent c9894f0 commit 1fe4fc3
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
(def batch-size 10) ;; the batch size
(def optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.0}))
(def eval-metric (eval-metric/accuracy))
(def num-epoch 5) ;; the number of training epochs
(def num-epoch 1) ;; the number of training epochs
(def kvstore "local") ;; the kvstore type
;;; Note to run distributed you might need to complile the engine with an option set
(def role "worker") ;; scheduler/server/worker
Expand Down Expand Up @@ -82,7 +82,9 @@
(sym/fully-connected "fc3" {:data data :num-hidden 10})
(sym/softmax-output "softmax" {:data data})))

(defn start [devs]
(defn start
([devs] (start devs num-epoch))
([devs _num-epoch]
(when scheduler-host
(println "Initing PS enviornments with " envs)
(kvstore-server/init envs))
Expand All @@ -94,14 +96,18 @@
(do
(println "Starting Training of MNIST ....")
(println "Running with context devices of" devs)
(let [mod (m/module (get-symbol) {:contexts devs})]
(m/fit mod {:train-data train-data
(let [_mod (m/module (get-symbol) {:contexts devs})]
(m/fit _mod {:train-data train-data
:eval-data test-data
:num-epoch num-epoch
:num-epoch _num-epoch
:fit-params (m/fit-params {:kvstore kvstore
:optimizer optimizer
:eval-metric eval-metric})}))
(println "Finish fit"))))
:eval-metric eval-metric})})
(println "Finish fit")
_mod
)

))))

(defn -main [& args]
(let [[dev dev-num] args
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns imclassification.train-mnist-test
(:require
[clojure.test :refer :all]
[clojure.java.io :as io]
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.module :as module]
[imclassification.train-mnist :as mnist]))

(deftest mnist-two-epochs-test
(module/save-checkpoint (mnist/start [(context/cpu)] 2) {:prefix "target/test" :epoch 2})
(is (= (slurp "test/test-0002.params") (slurp "target/test-0002.params"))))

0 comments on commit 1fe4fc3

Please sign in to comment.