diff --git a/src/Go/handlers.go b/src/Go/handlers.go index a36cdd4..5caf4d3 100644 --- a/src/Go/handlers.go +++ b/src/Go/handlers.go @@ -80,8 +80,31 @@ func DataHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) } -// ImageTensorHandler send prediction from TF ML model -func ImageTensorHandler(w http.ResponseWriter, r *http.Request) { +// ImageHandler send prediction from TF ML model +func ImageHandler(w http.ResponseWriter, r *http.Request) { + model := r.FormValue("model") + if model == "" { + msg := fmt.Sprintf("unable to read %s model", model) + responseError(w, msg, nil, http.StatusInternalServerError) + return + } + tfModel, err := tfVersion(model) + if err != nil { + msg := fmt.Sprintf("unable to read %s model", model) + responseError(w, msg, nil, http.StatusInternalServerError) + return + } + if tfModel == "tf1" { + log.Println("use ImageTF1Handler") + ImageTF1Handler(w, r) + return + } + log.Println("use ImageTF2Handler") + ImageTF2Handler(w, r) +} + +// ImageTF2Handler send prediction from TF2 ML model +func ImageTF2Handler(w http.ResponseWriter, r *http.Request) { model := r.FormValue("model") if model == "" { msg := fmt.Sprintf("unable to read %s model", model) @@ -136,8 +159,8 @@ func ImageTensorHandler(w http.ResponseWriter, r *http.Request) { responseJSON(w, probs) } -// ImageHandler send prediction from TF ML model -func ImageHandler(w http.ResponseWriter, r *http.Request) { +// ImageTF1Handler send prediction from TF ML model +func ImageTF1Handler(w http.ResponseWriter, r *http.Request) { model := r.FormValue("model") if model == "" { msg := fmt.Sprintf("unable to read %s model", model) diff --git a/src/Go/server.go b/src/Go/server.go index 7023889..70f070b 100644 --- a/src/Go/server.go +++ b/src/Go/server.go @@ -63,7 +63,6 @@ func handlers() *mux.Router { router.HandleFunc(basePath("/predict/json"), PredictHandler).Methods("POST") router.HandleFunc(basePath("/predict/proto"), PredictProtobufHandler).Methods("POST") router.HandleFunc(basePath("/predict/image"), ImageHandler).Methods("POST") - router.HandleFunc(basePath("/predict/imagetensor"), ImageTensorHandler).Methods("POST") router.HandleFunc(basePath("/json"), PredictHandler).Methods("POST") router.HandleFunc(basePath("/proto"), PredictProtobufHandler).Methods("POST") router.HandleFunc(basePath("/image"), ImageHandler).Methods("POST") diff --git a/src/Go/tfaas.go b/src/Go/tfaas.go index 42e92f7..c95fd6e 100644 --- a/src/Go/tfaas.go +++ b/src/Go/tfaas.go @@ -229,25 +229,37 @@ func loadModel(fname, flabels string) (*tf.Graph, []string, error) { return graph, labels, nil } -// helper function to generate predictions based on given row values -// either TF 2.X models via tfgo or TF 1.X models via graph loading -func makePredictions(row *Row) ([]float32, error) { - name := _params.Name - if row.Model != "" { - name = row.Model - } +// helper function to determine which model in our repository for given model name +func tfVersion(name string) (string, error) { // if model area has assets, variables and saved_model.pb // we will use TF 2.X approach based on tfgo path := fmt.Sprintf("%s/%s", _config.ModelDir, name) files, err := ioutil.ReadDir(path) if err != nil { - return []float32{}, err + return "", err } var fnames []string for _, file := range files { fnames = append(fnames, file.Name()) } if InList("assets", fnames) && InList("variables", fnames) && InList("saved_model.pb", fnames) { + return "tf2", nil + } + return "tf1", nil +} + +// helper function to generate predictions based on given row values +// either TF 2.X models via tfgo or TF 1.X models via graph loading +func makePredictions(row *Row) ([]float32, error) { + name := _params.Name + if row.Model != "" { + name = row.Model + } + tfModel, err := tfVersion(name) + if err != nil { + return []float32{}, err + } + if tfModel == "tf2" { return makePredictions2(row) } return makePredictions1(row)