Skip to content

Commit

Permalink
code refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
vkuznet committed Mar 12, 2023
1 parent 4e70137 commit ad28e9e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 13 deletions.
31 changes: 27 additions & 4 deletions src/Go/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,31 @@ func DataHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}

// ImageTensorHandler send prediction from TF ML model
func ImageTensorHandler(w http.ResponseWriter, r *http.Request) {
// ImageHandler send prediction from TF ML model
func ImageHandler(w http.ResponseWriter, r *http.Request) {
model := r.FormValue("model")
if model == "" {
msg := fmt.Sprintf("unable to read %s model", model)
responseError(w, msg, nil, http.StatusInternalServerError)
return
}
tfModel, err := tfVersion(model)
if err != nil {
msg := fmt.Sprintf("unable to read %s model", model)
responseError(w, msg, nil, http.StatusInternalServerError)
return
}
if tfModel == "tf1" {
log.Println("use ImageTF1Handler")
ImageTF1Handler(w, r)
return
}
log.Println("use ImageTF2Handler")
ImageTF2Handler(w, r)
}

// ImageTF2Handler send prediction from TF2 ML model
func ImageTF2Handler(w http.ResponseWriter, r *http.Request) {
model := r.FormValue("model")
if model == "" {
msg := fmt.Sprintf("unable to read %s model", model)
Expand Down Expand Up @@ -136,8 +159,8 @@ func ImageTensorHandler(w http.ResponseWriter, r *http.Request) {
responseJSON(w, probs)
}

// ImageHandler send prediction from TF ML model
func ImageHandler(w http.ResponseWriter, r *http.Request) {
// ImageTF1Handler send prediction from TF ML model
func ImageTF1Handler(w http.ResponseWriter, r *http.Request) {
model := r.FormValue("model")
if model == "" {
msg := fmt.Sprintf("unable to read %s model", model)
Expand Down
1 change: 0 additions & 1 deletion src/Go/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ func handlers() *mux.Router {
router.HandleFunc(basePath("/predict/json"), PredictHandler).Methods("POST")
router.HandleFunc(basePath("/predict/proto"), PredictProtobufHandler).Methods("POST")
router.HandleFunc(basePath("/predict/image"), ImageHandler).Methods("POST")
router.HandleFunc(basePath("/predict/imagetensor"), ImageTensorHandler).Methods("POST")
router.HandleFunc(basePath("/json"), PredictHandler).Methods("POST")
router.HandleFunc(basePath("/proto"), PredictProtobufHandler).Methods("POST")
router.HandleFunc(basePath("/image"), ImageHandler).Methods("POST")
Expand Down
28 changes: 20 additions & 8 deletions src/Go/tfaas.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,25 +229,37 @@ func loadModel(fname, flabels string) (*tf.Graph, []string, error) {
return graph, labels, nil
}

// helper function to generate predictions based on given row values
// either TF 2.X models via tfgo or TF 1.X models via graph loading
func makePredictions(row *Row) ([]float32, error) {
name := _params.Name
if row.Model != "" {
name = row.Model
}
// helper function to determine which model in our repository for given model name
func tfVersion(name string) (string, error) {
// if model area has assets, variables and saved_model.pb
// we will use TF 2.X approach based on tfgo
path := fmt.Sprintf("%s/%s", _config.ModelDir, name)
files, err := ioutil.ReadDir(path)
if err != nil {
return []float32{}, err
return "", err
}
var fnames []string
for _, file := range files {
fnames = append(fnames, file.Name())
}
if InList("assets", fnames) && InList("variables", fnames) && InList("saved_model.pb", fnames) {
return "tf2", nil
}
return "tf1", nil
}

// helper function to generate predictions based on given row values
// either TF 2.X models via tfgo or TF 1.X models via graph loading
func makePredictions(row *Row) ([]float32, error) {
name := _params.Name
if row.Model != "" {
name = row.Model
}
tfModel, err := tfVersion(name)
if err != nil {
return []float32{}, err
}
if tfModel == "tf2" {
return makePredictions2(row)
}
return makePredictions1(row)
Expand Down

0 comments on commit ad28e9e

Please sign in to comment.