Skip to content

Commit

Permalink
remove mod from arity 2 version of load-checkpoint in clojure-package (
Browse files Browse the repository at this point in the history
…apache#11808)

* remove mod from arity 2 version of load-checkpoint

* load-checkpoint arity 2 test
  • Loading branch information
jimdunn authored and nswamy committed Aug 2, 2018
1 parent 52abe1f commit 1373b28
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,6 @@

(defn load-checkpoint
"Create a model from previously saved checkpoint.
- mod module
- opts map of
- prefix Path prefix of saved model files. You should have prefix-symbol.json,
prefix-xxxx.params, and optionally prefix-xxxx.states,
Expand Down Expand Up @@ -341,7 +340,7 @@
(util/->option (when workload-list (util/vec->indexed-seq workload-list)))
(util/->option (when fixed-param-names (util/vec->set fixed-param-names)))))
([prefix epoch]
(load-checkpoint mod {:prefix prefix :epoch epoch})))
(load-checkpoint {:prefix prefix :epoch epoch})))

(defn load-optimizer-states [mod fname]
(.mod load fname))
Expand Down Expand Up @@ -670,4 +669,3 @@

(fit-params {:allow-missing true})
(fit-params {}))

Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,20 @@
(m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1 :momentum 0.9})})
(m/update)
(m/save-checkpoint {:prefix "test" :epoch 0 :save-opt-states true}))

(let [mod2 (m/load-checkpoint {:prefix "test" :epoch 0 :load-optimizer-states true})]
(-> mod2
(m/bind {:data-shapes [{:name "data" :shape [10 10] :layout "NT"}]})
(m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1 :momentum 0.9})}))
(is (= (-> mod m/symbol sym/to-json) (-> mod2 m/symbol sym/to-json)))
(is (= (-> mod m/params first) (-> mod2 m/params first))))))
(is (= (-> mod m/symbol sym/to-json) (-> mod2 m/symbol sym/to-json)))
(is (= (-> mod m/params first) (-> mod2 m/params first))))
;; arity 2 version of above. `load-optimizer-states` is `false` here by default,
;; but optimizers states aren't checked here so it's not relevant to the test outcome.
(let [mod3 (m/load-checkpoint "test" 0)]
(-> mod3
(m/bind {:data-shapes [{:name "data" :shape [10 10] :layout "NT"}]})
(m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1 :momentum 0.9})}))
(is (= (-> mod m/symbol sym/to-json) (-> mod3 m/symbol sym/to-json)))
(is (= (-> mod m/params first) (-> mod3 m/params first))))))

(deftest test-module-save-load-multi-device
(let [s (sym/variable "data")
Expand Down Expand Up @@ -321,4 +328,3 @@
(comment

(m/data-shapes x))

0 comments on commit 1373b28

Please sign in to comment.