diff --git a/example/cmd/device-simple/Attribution.txt b/example/cmd/device-simple/Attribution.txt index 6f36144f..7d6d5edf 100644 --- a/example/cmd/device-simple/Attribution.txt +++ b/example/cmd/device-simple/Attribution.txt @@ -1,8 +1,5 @@ The following open source projects are referenced by Device Service SDK Go: -gorilla/mux 1.6.2 (BSD-3) /~https://github.com/gorilla/mux -/~https://github.com/gorilla/mux/blob/master/LICENSE - hashicorp/consul 1.1.0 (Mozilla Public License 2.0) /~https://github.com/hashicorp/consul /~https://github.com/hashicorp/consul/blob/master/LICENSE @@ -225,3 +222,5 @@ /~https://github.com/klauspost/compress/blob/master/LICENSE github.com/gabriel-vasile/mimetype (MIT) /~https://github.com/gabriel-vasile/mimetype /~https://github.com/gabriel-vasile/mimetype/blob/master/LICENSE +github.com/labstack/echo/v4 (MIT) /~https://github.com/labstack/echo +/~https://github.com/labstack/echo/blob/master/LICENSE diff --git a/go.mod b/go.mod index 076720fc..4d160641 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,9 @@ go 1.20 require ( github.com/OneOfOne/xxhash v1.2.8 github.com/edgexfoundry/go-mod-bootstrap/v3 v3.1.0-dev.14 - github.com/edgexfoundry/go-mod-core-contracts/v3 v3.1.0-dev.2 + github.com/edgexfoundry/go-mod-core-contracts/v3 v3.1.0-dev.3 github.com/edgexfoundry/go-mod-messaging/v3 v3.1.0-dev.12 github.com/google/uuid v1.3.0 - github.com/gorilla/mux v1.8.0 github.com/hashicorp/go-multierror v1.1.1 github.com/labstack/echo/v4 v4.11.1 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 diff --git a/go.sum b/go.sum index 515c7aa5..d2f65145 100644 --- a/go.sum +++ b/go.sum @@ -32,8 +32,8 @@ github.com/edgexfoundry/go-mod-bootstrap/v3 v3.1.0-dev.14 h1:yDDt0qwMDjhEDbV7rEX github.com/edgexfoundry/go-mod-bootstrap/v3 v3.1.0-dev.14/go.mod h1:Qe0cE8xdcddeTKVhKt079SnA6obquNSiIFLfAfBKdDE= github.com/edgexfoundry/go-mod-configuration/v3 v3.1.0-dev.4 h1:RmZVvR9sa3CdRNIiBNtuYG1gkP3Y98jOf69kL2SabGI= github.com/edgexfoundry/go-mod-configuration/v3 v3.1.0-dev.4/go.mod h1:hVlzVoVpbgOZKppX+vwYsp7Kf0eBfRmBZP2FsyMwe9I= -github.com/edgexfoundry/go-mod-core-contracts/v3 v3.1.0-dev.2 h1:H3ls1vyxCv6pigZ/RZlRhj9lUTQq7CiT5/dnZRsgVmQ= -github.com/edgexfoundry/go-mod-core-contracts/v3 v3.1.0-dev.2/go.mod h1:dadH49hOQlIKkbOfRj5gYEez8l9kUQ3aj9cc8uA5hv4= +github.com/edgexfoundry/go-mod-core-contracts/v3 v3.1.0-dev.3 h1:NJa87GfnUZw/GvX2aeeZ4ZQoSzwkRotuLQej2ay6XUc= +github.com/edgexfoundry/go-mod-core-contracts/v3 v3.1.0-dev.3/go.mod h1:dadH49hOQlIKkbOfRj5gYEez8l9kUQ3aj9cc8uA5hv4= github.com/edgexfoundry/go-mod-messaging/v3 v3.1.0-dev.12 h1:b0ZNEkF0Xnsjlvdd4RjAvdqiq/bGLsVe72yysg92tQQ= github.com/edgexfoundry/go-mod-messaging/v3 v3.1.0-dev.12/go.mod h1:1f0XKGDQifMd/2D11aHvAOpVJ19BmaPzr1f0T4nVpdw= github.com/edgexfoundry/go-mod-registry/v3 v3.1.0-dev.4 h1:qYWqqsXzJie8HBj6103RCuFRi8Ks4VbDT45rdZjJ8Z8= @@ -89,8 +89,6 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/consul/api v1.23.0 h1:L6e4v1AfoumqAHq/Rrsmuulev+nd7vltM3k8H329tyI= diff --git a/internal/controller/http/command.go b/internal/controller/http/command.go index f2f5ac31..5236448d 100644 --- a/internal/controller/http/command.go +++ b/internal/controller/http/command.go @@ -18,25 +18,25 @@ import ( commonDTO "github.com/edgexfoundry/go-mod-core-contracts/v3/dtos/common" "github.com/edgexfoundry/go-mod-core-contracts/v3/dtos/responses" "github.com/edgexfoundry/go-mod-core-contracts/v3/errors" - "github.com/gorilla/mux" "github.com/edgexfoundry/device-sdk-go/v3/internal/application" sdkCommon "github.com/edgexfoundry/device-sdk-go/v3/internal/common" -) -func (c *RestController) GetCommand(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - deviceName := vars[common.Name] - commandName := vars[common.Command] + "github.com/labstack/echo/v4" +) +func (c *RestController) GetCommand(e echo.Context) error { + deviceName := e.Param(common.Name) + commandName := e.Param(common.Command) + r := e.Request() + w := e.Response() ctx := r.Context() correlationId := utils.FromContext(ctx, common.CorrelationHeader) // parse query parameter queryParams, reserved, err := filterQueryParams(r.URL.RawQuery) if err != nil { - c.sendEdgexError(w, r, err, common.ApiDeviceNameCommandNameRoute) - return + return c.sendEdgexError(w, r, err, common.ApiDeviceNameCommandNameRoute) } regexCmd := true @@ -46,8 +46,7 @@ func (c *RestController) GetCommand(w http.ResponseWriter, r *http.Request) { event, err := application.GetCommand(ctx, deviceName, commandName, queryParams, regexCmd, c.dic) if err != nil { - c.sendEdgexError(w, r, err, common.ApiDeviceNameCommandNameRoute) - return + return c.sendEdgexError(w, r, err, common.ApiDeviceNameCommandNameRoute) } // push event to CoreData if specified (default false) @@ -58,40 +57,38 @@ func (c *RestController) GetCommand(w http.ResponseWriter, r *http.Request) { // return event in http response if specified (default true) if returnEvent := reserved.Get(common.ReturnEvent); returnEvent == "" || returnEvent == common.ValueTrue { res := responses.NewEventResponse("", "", http.StatusOK, *event) - c.sendEventResponse(w, r, res, http.StatusOK) - return + return c.sendEventResponse(w, r, res, http.StatusOK) } w.WriteHeader(http.StatusOK) + return nil } -func (c *RestController) SetCommand(w http.ResponseWriter, r *http.Request) { +func (c *RestController) SetCommand(e echo.Context) error { + r := e.Request() + w := e.Response() if r.Body != nil { defer func() { _ = r.Body.Close() }() } ctx := r.Context() - vars := mux.Vars(r) - deviceName := vars[common.Name] - commandName := vars[common.Command] + deviceName := e.Param(common.Name) + commandName := e.Param(common.Command) // parse query parameter queryParams, _, err := filterQueryParams(r.URL.RawQuery) if err != nil { - c.sendEdgexError(w, r, err, common.ApiDeviceNameCommandNameRoute) - return + return c.sendEdgexError(w, r, err, common.ApiDeviceNameCommandNameRoute) } requestParamsMap, err := parseRequestBody(r) if err != nil { - c.sendEdgexError(w, r, err, common.ApiDeviceNameCommandNameRoute) - return + return c.sendEdgexError(w, r, err, common.ApiDeviceNameCommandNameRoute) } event, err := application.SetCommand(ctx, deviceName, commandName, queryParams, requestParamsMap, c.dic) if err != nil { - c.sendEdgexError(w, r, err, common.ApiDeviceNameCommandNameRoute) - return + return c.sendEdgexError(w, r, err, common.ApiDeviceNameCommandNameRoute) } if event != nil { @@ -100,7 +97,7 @@ func (c *RestController) SetCommand(w http.ResponseWriter, r *http.Request) { } res := commonDTO.NewBaseResponse("", "", http.StatusOK) - c.sendResponse(w, r, common.ApiDeviceNameCommandNameRoute, res, http.StatusOK) + return c.sendResponse(w, r, common.ApiDeviceNameCommandNameRoute, res, http.StatusOK) } func parseRequestBody(req *http.Request) (map[string]interface{}, errors.EdgeX) { diff --git a/internal/controller/http/command_test.go b/internal/controller/http/command_test.go index 2783d0ac..5e6908aa 100644 --- a/internal/controller/http/command_test.go +++ b/internal/controller/http/command_test.go @@ -25,7 +25,6 @@ import ( "github.com/edgexfoundry/go-mod-core-contracts/v3/dtos/responses" "github.com/edgexfoundry/go-mod-core-contracts/v3/models" messagingMocks "github.com/edgexfoundry/go-mod-messaging/v3/messaging/mocks" - "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -209,8 +208,8 @@ func TestRestController_GetCommand(t *testing.T) { e := echo.New() dic := mockDic() - err := cache.InitCache(testService, testService, dic) - require.NoError(t, err) + edgexErr := cache.InitCache(testService, testService, dic) + require.NoError(t, edgexErr) controller := NewRestController(e, dic, testService) assert.NotNil(t, controller) @@ -237,15 +236,15 @@ func TestRestController_GetCommand(t *testing.T) { } for _, testCase := range tests { t.Run(testCase.name, func(t *testing.T) { - req, err := http.NewRequest(http.MethodGet, common.ApiDeviceNameCommandNameRoute, http.NoBody) - req = mux.SetURLVars(req, map[string]string{common.Name: testCase.deviceName, common.Command: testCase.commandName}) - require.NoError(t, err) + req := httptest.NewRequest(http.MethodGet, common.ApiDeviceNameCommandNameRoute, http.NoBody) // Act recorder := httptest.NewRecorder() - handler := WrapHandler(controller.GetCommand) c := e.NewContext(req, recorder) - err = handler(c) + c.SetParamNames(common.Name, common.Command) + c.SetParamValues(testCase.deviceName, testCase.commandName) + + err := controller.GetCommand(c) assert.NoError(t, err) var res responses.EventResponse @@ -283,15 +282,14 @@ func TestRestController_GetCommand_ServiceLocked(t *testing.T) { controller := NewRestController(e, dic, testService) assert.NotNil(t, controller) - req, err := http.NewRequest(http.MethodGet, common.ApiDeviceNameCommandNameRoute, http.NoBody) - req = mux.SetURLVars(req, map[string]string{common.Name: testDevice, common.Command: testResource}) - require.NoError(t, err) + req := httptest.NewRequest(http.MethodGet, common.ApiDeviceNameCommandNameRoute, http.NoBody) // Act recorder := httptest.NewRecorder() - handler := WrapHandler(controller.GetCommand) c := e.NewContext(req, recorder) - err = handler(c) + c.SetParamNames(common.Name, common.Command) + c.SetParamValues(testDevice, testResource) + err := controller.GetCommand(c) assert.NoError(t, err) var res responses.EventResponse @@ -315,18 +313,17 @@ func TestRestController_GetCommand_ReturnEvent(t *testing.T) { controller := NewRestController(e, dic, testService) assert.NotNil(t, controller) - req, err := http.NewRequest(http.MethodGet, common.ApiDeviceNameCommandNameRoute, http.NoBody) - req = mux.SetURLVars(req, map[string]string{common.Name: testDevice, common.Command: testResource}) - require.NoError(t, err) + req := httptest.NewRequest(http.MethodGet, common.ApiDeviceNameCommandNameRoute, http.NoBody) query := req.URL.Query() query.Add("ds-returnevent", common.ValueFalse) req.URL.RawQuery = query.Encode() // Act recorder := httptest.NewRecorder() - handler := WrapHandler(controller.GetCommand) c := e.NewContext(req, recorder) - err = handler(c) + c.SetParamNames(common.Name, common.Command) + c.SetParamValues(testDevice, testResource) + err := controller.GetCommand(c) assert.NoError(t, err) // Assert @@ -380,9 +377,7 @@ func TestRestController_SetCommand(t *testing.T) { require.NoError(t, err) reader := strings.NewReader(string(jsonData)) - req, err := http.NewRequest(http.MethodPut, common.ApiDeviceNameCommandNameRoute, reader) - req = mux.SetURLVars(req, map[string]string{common.Name: testCase.deviceName, common.Command: testCase.commandName}) - require.NoError(t, err) + req := httptest.NewRequest(http.MethodPut, common.ApiDeviceNameCommandNameRoute, reader) var wg sync.WaitGroup if testCase.commandName != writeOnlyCommand && testCase.commandName != writeOnlyResource { @@ -402,9 +397,10 @@ func TestRestController_SetCommand(t *testing.T) { // Act recorder := httptest.NewRecorder() - handler := WrapHandler(controller.SetCommand) c := e.NewContext(req, recorder) - err = handler(c) + c.SetParamNames(common.Name, common.Command) + c.SetParamValues(testCase.deviceName, testCase.commandName) + err = controller.SetCommand(c) assert.NoError(t, err) var res commonDTO.BaseResponse @@ -452,15 +448,14 @@ func TestRestController_SetCommand_ServiceLocked(t *testing.T) { require.NoError(t, err) reader := strings.NewReader(string(jsonData)) - req, err := http.NewRequest(http.MethodPut, common.ApiDeviceNameCommandNameRoute, reader) - req = mux.SetURLVars(req, map[string]string{common.Name: testDevice, common.Command: testResource}) - require.NoError(t, err) + req := httptest.NewRequest(http.MethodPut, common.ApiDeviceNameCommandNameRoute, reader) // Act recorder := httptest.NewRecorder() - handler := WrapHandler(controller.SetCommand) c := e.NewContext(req, recorder) - err = handler(c) + c.SetParamNames(common.Name, common.Command) + c.SetParamValues(testDevice, testResource) + err = controller.SetCommand(c) assert.NoError(t, err) var res commonDTO.BaseResponse diff --git a/internal/controller/http/correlation/middleware.go b/internal/controller/http/correlation/middleware.go deleted file mode 100644 index 21f2059e..00000000 --- a/internal/controller/http/correlation/middleware.go +++ /dev/null @@ -1,67 +0,0 @@ -// -*- Mode: Go; indent-tabs-mode: t -*- -// -// Copyright (C) 2019-2023 IOTech Ltd -// -// SPDX-License-Identifier: Apache-2.0 - -package correlation - -import ( - "context" - "github.com/gorilla/mux" - "net/http" - "net/url" - "time" - - "github.com/edgexfoundry/go-mod-core-contracts/v3/clients/logger" - "github.com/edgexfoundry/go-mod-core-contracts/v3/common" - "github.com/edgexfoundry/go-mod-core-contracts/v3/models" - "github.com/google/uuid" -) - -func ManageHeader(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - hdr := r.Header.Get(common.CorrelationHeader) - if hdr == "" { - hdr = uuid.New().String() - } - ctx := context.WithValue(r.Context(), common.CorrelationHeader, hdr) // nolint:staticcheck - r = r.WithContext(ctx) - next.ServeHTTP(w, r) - }) -} - -func LoggingMiddleware(lc logger.LoggingClient) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if lc.LogLevel() == models.TraceLog { - begin := time.Now() - correlationId := IdFromContext(r.Context()) - lc.Trace("Begin request", common.CorrelationHeader, correlationId, "path", r.URL.Path) - next.ServeHTTP(w, r) - lc.Trace("Response complete", common.CorrelationHeader, correlationId, "duration", time.Since(begin).String()) - } else { - next.ServeHTTP(w, r) - } - }) - } -} - -// UrlDecodeMiddleware decode the path variables -// After invoking the router.UseEncodedPath() func, the path variables needs to decode before passing to the controller -func UrlDecodeMiddleware(lc logger.LoggingClient) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - for k, v := range vars { - unescape, err := url.PathUnescape(v) - if err != nil { - lc.Debugf("failed to decode the %s from the value %s", k, v) - return - } - vars[k] = unescape - } - next.ServeHTTP(w, r) - }) - } -} diff --git a/internal/controller/http/discovery.go b/internal/controller/http/discovery.go index 17fcd492..16a14062 100644 --- a/internal/controller/http/discovery.go +++ b/internal/controller/http/discovery.go @@ -15,25 +15,27 @@ import ( "github.com/edgexfoundry/device-sdk-go/v3/internal/autodiscovery" "github.com/edgexfoundry/device-sdk-go/v3/internal/container" + + "github.com/labstack/echo/v4" ) -func (c *RestController) Discovery(writer http.ResponseWriter, request *http.Request) { +func (c *RestController) Discovery(e echo.Context) error { + request := e.Request() + writer := e.Response() ds := container.DeviceServiceFrom(c.dic.Get) if ds.AdminState == models.Locked { err := errors.NewCommonEdgeX(errors.KindServiceLocked, "service locked", nil) - c.sendEdgexError(writer, request, err, common.ApiDiscoveryRoute) - return + return c.sendEdgexError(writer, request, err, common.ApiDiscoveryRoute) } configuration := container.ConfigurationFrom(c.dic.Get) if !configuration.Device.Discovery.Enabled { err := errors.NewCommonEdgeX(errors.KindServiceUnavailable, "device discovery disabled", nil) - c.sendEdgexError(writer, request, err, common.ApiDiscoveryRoute) - return + return c.sendEdgexError(writer, request, err, common.ApiDiscoveryRoute) } driver := container.ProtocolDriverFrom(c.dic.Get) go autodiscovery.DiscoveryWrapper(driver, c.lc) - c.sendResponse(writer, request, common.ApiDiscoveryRoute, nil, http.StatusAccepted) + return c.sendResponse(writer, request, common.ApiDiscoveryRoute, nil, http.StatusAccepted) } diff --git a/internal/controller/http/restrouter.go b/internal/controller/http/restrouter.go index 1ae63053..23709d16 100644 --- a/internal/controller/http/restrouter.go +++ b/internal/controller/http/restrouter.go @@ -59,10 +59,10 @@ func (c *RestController) InitRestRoutes() { authenticationHook := handlers.AutoConfigAuthenticationFunc(secretProvider, c.lc) // discovery - c.addReservedRoute(common.ApiDiscoveryRoute, WrapHandler(c.Discovery), http.MethodPost, authenticationHook) + c.addReservedRoute(common.ApiDiscoveryRoute, c.Discovery, http.MethodPost, authenticationHook) // device command - c.addReservedRoute(common.ApiDeviceNameCommandNameRoute, WrapHandler(c.GetCommand), http.MethodGet, authenticationHook) - c.addReservedRoute(common.ApiDeviceNameCommandNameRoute, WrapHandler(c.SetCommand), http.MethodPut, authenticationHook) + c.addReservedRoute(common.ApiDeviceNameCommandNameEchoRoute, c.GetCommand, http.MethodGet, authenticationHook) + c.addReservedRoute(common.ApiDeviceNameCommandNameEchoRoute, c.SetCommand, http.MethodPut, authenticationHook) } func (c *RestController) addReservedRoute(route string, handler func(e echo.Context) error, method string, @@ -71,12 +71,12 @@ func (c *RestController) addReservedRoute(route string, handler func(e echo.Cont return c.router.Add(method, route, handler, middlewareFunc...) } -func (c *RestController) AddRoute(route string, handler func(http.ResponseWriter, *http.Request), methods []string, middlewareFunc ...echo.MiddlewareFunc) errors.EdgeX { +func (c *RestController) AddRoute(route string, handler func(e echo.Context) error, methods []string, middlewareFunc ...echo.MiddlewareFunc) errors.EdgeX { if c.reservedRoutes[route] { return errors.NewCommonEdgeX(errors.KindServerError, "route is reserved", nil) } - c.router.Match(methods, route, WrapHandler(handler), middlewareFunc...) + c.router.Match(methods, route, handler, middlewareFunc...) c.lc.Debug("Route added", "route", route, "methods", fmt.Sprintf("%v", methods)) return nil @@ -88,11 +88,11 @@ func (c *RestController) Router() *echo.Echo { // sendResponse puts together the response packet for the V2 API func (c *RestController) sendResponse( - writer http.ResponseWriter, + writer *echo.Response, request *http.Request, api string, response interface{}, - statusCode int) { + statusCode int) error { correlationID := request.Header.Get(common.CorrelationHeader) @@ -104,32 +104,36 @@ func (c *RestController) sendResponse( data, err := json.Marshal(response) if err != nil { c.lc.Error(fmt.Sprintf("Unable to marshal %s response", api), "error", err.Error(), common.CorrelationHeader, correlationID) - http.Error(writer, err.Error(), http.StatusInternalServerError) - return + // set Response.Committed to true in order to rewrite the status code + writer.Committed = false + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } _, err = writer.Write(data) if err != nil { c.lc.Error(fmt.Sprintf("Unable to write %s response", api), "error", err.Error(), common.CorrelationHeader, correlationID) - http.Error(writer, err.Error(), http.StatusInternalServerError) - return + // set Response.Committed to true in order to rewrite the status code + writer.Committed = false + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } } + return nil } // sendEventResponse puts together the EventResponse packet for the V2 API func (c *RestController) sendEventResponse( - writer http.ResponseWriter, + writer *echo.Response, request *http.Request, response responses.EventResponse, - statusCode int) { + statusCode int) error { correlationID := request.Header.Get(common.CorrelationHeader) data, encoding, err := response.Encode() if err != nil { c.lc.Errorf("Unable to marshal EventResponse: %s; %s: %s", err.Error(), common.CorrelationHeader, correlationID) - http.Error(writer, err.Error(), http.StatusInternalServerError) - return + // set Response.Committed to true in order to rewrite the status code + writer.Committed = false + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } writer.Header().Set(common.CorrelationHeader, correlationID) @@ -139,19 +143,21 @@ func (c *RestController) sendEventResponse( _, err = writer.Write(data) if err != nil { c.lc.Errorf("Unable to write DeviceCommand response: %s; %s: %s", err.Error(), common.CorrelationHeader, correlationID) - http.Error(writer, err.Error(), http.StatusInternalServerError) - return + // set Response.Committed to true in order to rewrite the status code + writer.Committed = false + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } + return nil } func (c *RestController) sendEdgexError( - writer http.ResponseWriter, + writer *echo.Response, request *http.Request, err errors.EdgeX, - api string) { + api string) error { correlationID := request.Header.Get(common.CorrelationHeader) c.lc.Error(err.Error(), common.CorrelationHeader, correlationID) c.lc.Debug(err.DebugMessages(), common.CorrelationHeader, correlationID) response := commonDTO.NewBaseResponse("", err.Error(), err.Code()) - c.sendResponse(writer, request, api, response, err.Code()) + return c.sendResponse(writer, request, api, response, err.Code()) } diff --git a/internal/controller/http/restrouter_test.go b/internal/controller/http/restrouter_test.go index a30b9e59..d82d65fb 100644 --- a/internal/controller/http/restrouter_test.go +++ b/internal/controller/http/restrouter_test.go @@ -1,6 +1,6 @@ // // Copyright (c) 2019 Intel Corporation -// Copyright (C) 2020 IOTech Ltd +// Copyright (C) 2020-2023 IOTech Ltd // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,19 +19,28 @@ package http import ( "fmt" - "github.com/google/uuid" + "io" "net/http" + "net/http/httptest" "testing" bootstrapContainer "github.com/edgexfoundry/go-mod-bootstrap/v3/bootstrap/container" "github.com/edgexfoundry/go-mod-bootstrap/v3/di" "github.com/edgexfoundry/go-mod-core-contracts/v3/clients/logger" "github.com/edgexfoundry/go-mod-core-contracts/v3/common" - "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "github.com/edgexfoundry/device-sdk-go/v3/internal/config" "github.com/edgexfoundry/device-sdk-go/v3/internal/container" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" +) + +var ( + handlerFunc = func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + } ) func TestAddRoute(t *testing.T) { @@ -39,10 +48,11 @@ func TestAddRoute(t *testing.T) { tests := []struct { Name string Route string + Method string ErrorExpected bool }{ - {"Success", "/api/v2/test", false}, - {"Reserved Route", common.ApiDiscoveryRoute, true}, + {"Success", "/api/v2/test", http.MethodPost, false}, + {"Reserved Route", common.ApiDiscoveryRoute, "", true}, } lc := logger.NewMockClient() @@ -56,11 +66,11 @@ func TestAddRoute(t *testing.T) { }) for _, test := range tests { - r := mux.NewRouter() + r := echo.New() controller := NewRestController(r, dic, uuid.NewString()) controller.InitRestRoutes() - err := controller.AddRoute(test.Route, func(http.ResponseWriter, *http.Request) {}, http.MethodPost) + err := controller.AddRoute(test.Route, handlerFunc, []string{test.Method}) if test.ErrorExpected { assert.Error(t, err, "Expected an error") } else { @@ -68,28 +78,28 @@ func TestAddRoute(t *testing.T) { t.Fatal() } - err := controller.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { - path, err := route.GetPathTemplate() - if err != nil { - return err - } - - // Have to skip all the reserved routes that have previously been added. - if controller.reservedRoutes[path] { - return nil - } - - routeMethods, err := route.GetMethods() - if err != nil { - return err - } + req := httptest.NewRequest(test.Method, test.Route, nil) + rec := httptest.NewRecorder() + c := r.NewContext(req, rec) + // Find the matched handler function from router with the matching method and url path + r.Router().Find(test.Method, test.Route, c) + // Apply the handler function to echo.Context + handlerErr := c.Handler()(c) + assert.NoError(t, handlerErr) + + // Have to skip all the reserved routes that have previously been added. + if controller.reservedRoutes[test.Route] { + return + } - assert.Equal(t, test.Route, path) - assert.Equal(t, http.MethodPost, routeMethods[0], "Expected POST Method") - return nil - }) + assert.Equal(t, test.Route, c.Path()) + assert.Equal(t, http.StatusOK, c.Response().Status) - assert.NoError(t, err, "Unexpected error examining route") + if body, err := io.ReadAll(rec.Body); err == nil { + assert.Equal(t, "OK", string(body), "unexpected handler function response") + } else { + assert.NoError(t, err) + } } } } @@ -104,22 +114,17 @@ func TestInitRestRoutes(t *testing.T) { return &config.ConfigurationStruct{} }, }) - r := mux.NewRouter() + r := echo.New() controller := NewRestController(r, dic, uuid.NewString()) controller.InitRestRoutes() - err := controller.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { - path, err := route.GetPathTemplate() - if err != nil { - return err - } + // Traverse all registered routes for the router + for _, route := range r.Routes() { + path := route.Path // Verify the route is reserved by attempting to add it as 'external' route. // If tests fails then the route was not added to the reserved list - err = controller.AddRoute(path, func(http.ResponseWriter, *http.Request) {}) + err := controller.AddRoute(path, func(c echo.Context) error { return nil }, nil) assert.Error(t, err, path, fmt.Sprintf("Expected error for '%s'", path)) - return nil - }) - - assert.NoError(t, err, "Unexpected error examining route") + } } diff --git a/pkg/interfaces/service.go b/pkg/interfaces/service.go index 13b9a486..9b200a3d 100644 --- a/pkg/interfaces/service.go +++ b/pkg/interfaces/service.go @@ -16,6 +16,8 @@ import ( "github.com/edgexfoundry/go-mod-core-contracts/v3/models" sdkModels "github.com/edgexfoundry/device-sdk-go/v3/pkg/models" + + "github.com/labstack/echo/v4" ) // UpdatableConfig interface allows services to have custom configuration populated from configuration stored @@ -128,7 +130,7 @@ type DeviceServiceSDK interface { AddRoute(route string, handler func(http.ResponseWriter, *http.Request), methods ...string) error // AddCustomRoute allows leveraging the existing internal web server to add routes specific to Device Service. - AddCustomRoute(route string, authentication Authentication, handler func(http.ResponseWriter, *http.Request), methods ...string) error + AddCustomRoute(route string, authentication Authentication, handler func(e echo.Context) error, methods ...string) error // LoadCustomConfig uses the Config Processor from go-mod-bootstrap to attempt to load service's // custom configuration. It uses the same command line flags to process the custom config in the same manner diff --git a/pkg/service/service.go b/pkg/service/service.go index 03a04970..f75a0433 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -33,6 +33,7 @@ import ( restController "github.com/edgexfoundry/device-sdk-go/v3/internal/controller/http" "github.com/edgexfoundry/device-sdk-go/v3/pkg/interfaces" sdkModels "github.com/edgexfoundry/device-sdk-go/v3/pkg/models" + "github.com/edgexfoundry/device-sdk-go/v3/pkg/utils" bootstrapConfig "github.com/edgexfoundry/go-mod-bootstrap/v3/bootstrap/config" bootstrapContainer "github.com/edgexfoundry/go-mod-bootstrap/v3/bootstrap/container" @@ -215,11 +216,11 @@ func (s *deviceService) DiscoveredDeviceChannel() chan []sdkModels.DiscoveredDev // AddRoute allows leveraging the existing internal web server to add routes specific to Device Service. // Deprecated: It is recommended to use AddCustomRoute() instead and enable authentication for custom routes func (s *deviceService) AddRoute(route string, handler func(http.ResponseWriter, *http.Request), methods ...string) error { - return s.AddCustomRoute(route, interfaces.Unauthenticated, handler, methods...) + return s.AddCustomRoute(route, interfaces.Unauthenticated, utils.WrapHandler(handler), methods...) } // AddCustomRoute allows leveraging the existing internal web server to add routes specific to Device Service. -func (s *deviceService) AddCustomRoute(route string, authentication interfaces.Authentication, handler func(http.ResponseWriter, *http.Request), methods ...string) error { +func (s *deviceService) AddCustomRoute(route string, authentication interfaces.Authentication, handler func(e echo.Context) error, methods ...string) error { if authentication == interfaces.Authenticated { lc := bootstrapContainer.LoggingClientFrom(s.dic.Get) secretProvider := bootstrapContainer.SecretProviderExtFrom(s.dic.Get) diff --git a/internal/controller/http/utils.go b/pkg/utils/handler.go similarity index 96% rename from internal/controller/http/utils.go rename to pkg/utils/handler.go index 5252d8d0..7cf982cc 100644 --- a/internal/controller/http/utils.go +++ b/pkg/utils/handler.go @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 -package http +package utils import ( "net/http" diff --git a/pkg/utils/handler_test.go b/pkg/utils/handler_test.go new file mode 100644 index 00000000..0338c093 --- /dev/null +++ b/pkg/utils/handler_test.go @@ -0,0 +1,33 @@ +// +// Copyright (C) 2023 IOTech Ltd +// +// SPDX-License-Identifier: Apache-2.0 + +package utils + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestWrapHandler(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := WrapHandler(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("test")) + if err != nil { + assert.Fail(t, err.Error()) + } + }) + if assert.NoError(t, h(c)) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "test", rec.Body.String()) + } +}