diff --git a/middleware/context.go b/middleware/context.go
index d890ed3..5e9ecf5 100644
--- a/middleware/context.go
+++ b/middleware/context.go
@@ -18,6 +18,8 @@ import (
stdContext "context"
"fmt"
"net/http"
+ "net/url"
+ "path"
"strings"
"sync"
@@ -584,45 +586,92 @@ func (c *Context) Respond(rw http.ResponseWriter, r *http.Request, produces []st
c.api.ServeErrorFor(route.Operation.ID)(rw, r, errors.New(http.StatusInternalServerError, "can't produce response"))
}
-func (c *Context) APIHandlerSwaggerUI(builder Builder) http.Handler {
+// APIHandlerSwaggerUI returns a handler to serve the API.
+//
+// This handler includes a swagger spec, router and the contract defined in the swagger spec.
+//
+// A spec UI (SwaggerUI) is served at {API base path}/docs and the spec document at /swagger.json
+// (these can be modified with uiOptions).
+func (c *Context) APIHandlerSwaggerUI(builder Builder, opts ...UIOption) http.Handler {
b := builder
if b == nil {
b = PassthroughBuilder
}
- var title string
- sp := c.spec.Spec()
- if sp != nil && sp.Info != nil && sp.Info.Title != "" {
- title = sp.Info.Title
- }
+ specPath, uiOpts, specOpts := c.uiOptionsForHandler(opts)
+ var swaggerUIOpts SwaggerUIOpts
+ fromCommonToAnyOptions(uiOpts, &swaggerUIOpts)
+
+ return Spec(specPath, c.spec.Raw(), SwaggerUI(swaggerUIOpts, c.RoutesHandler(b)), specOpts...)
+}
- swaggerUIOpts := SwaggerUIOpts{
- BasePath: c.BasePath(),
- Title: title,
+// APIHandlerRapiDoc returns a handler to serve the API.
+//
+// This handler includes a swagger spec, router and the contract defined in the swagger spec.
+//
+// A spec UI (RapiDoc) is served at {API base path}/docs and the spec document at /swagger.json
+// (these can be modified with uiOptions).
+func (c *Context) APIHandlerRapiDoc(builder Builder, opts ...UIOption) http.Handler {
+ b := builder
+ if b == nil {
+ b = PassthroughBuilder
}
- return Spec("", c.spec.Raw(), SwaggerUI(swaggerUIOpts, c.RoutesHandler(b)))
+ specPath, uiOpts, specOpts := c.uiOptionsForHandler(opts)
+ var rapidocUIOpts RapiDocOpts
+ fromCommonToAnyOptions(uiOpts, &rapidocUIOpts)
+
+ return Spec(specPath, c.spec.Raw(), RapiDoc(rapidocUIOpts, c.RoutesHandler(b)), specOpts...)
}
-// APIHandler returns a handler to serve the API, this includes a swagger spec, router and the contract defined in the swagger spec
-func (c *Context) APIHandler(builder Builder) http.Handler {
+// APIHandler returns a handler to serve the API.
+//
+// This handler includes a swagger spec, router and the contract defined in the swagger spec.
+//
+// A spec UI (Redoc) is served at {API base path}/docs and the spec document at /swagger.json
+// (these can be modified with uiOptions).
+func (c *Context) APIHandler(builder Builder, opts ...UIOption) http.Handler {
b := builder
if b == nil {
b = PassthroughBuilder
}
+ specPath, uiOpts, specOpts := c.uiOptionsForHandler(opts)
+ var redocOpts RedocOpts
+ fromCommonToAnyOptions(uiOpts, &redocOpts)
+
+ return Spec(specPath, c.spec.Raw(), Redoc(redocOpts, c.RoutesHandler(b)), specOpts...)
+}
+
+func (c Context) uiOptionsForHandler(opts []UIOption) (string, uiOptions, []SpecOption) {
var title string
sp := c.spec.Spec()
if sp != nil && sp.Info != nil && sp.Info.Title != "" {
title = sp.Info.Title
}
- redocOpts := RedocOpts{
- BasePath: c.BasePath(),
- Title: title,
+ // default options (may be overridden)
+ optsForContext := []UIOption{
+ WithUIBasePath(c.BasePath()),
+ WithUITitle(title),
+ }
+ optsForContext = append(optsForContext, opts...)
+ uiOpts := uiOptionsWithDefaults(optsForContext)
+
+ // If spec URL is provided, there is a non-default path to serve the spec.
+ // This makes sure that the UI middleware is aligned with the Spec middleware.
+ u, _ := url.Parse(uiOpts.SpecURL)
+ var specPath string
+ if u != nil {
+ specPath = u.Path
+ }
+
+ pth, doc := path.Split(specPath)
+ if pth == "." {
+ pth = ""
}
- return Spec("", c.spec.Raw(), Redoc(redocOpts, c.RoutesHandler(b)))
+ return pth, uiOpts, []SpecOption{WithSpecDocument(doc)}
}
// RoutesHandler returns a handler to serve the API, just the routes and the contract defined in the swagger spec
diff --git a/middleware/context_test.go b/middleware/context_test.go
index 4c0a62d..e9ed568 100644
--- a/middleware/context_test.go
+++ b/middleware/context_test.go
@@ -17,6 +17,7 @@ package middleware
import (
stdcontext "context"
"errors"
+ "fmt"
"net/http"
"net/http/httptest"
"strings"
@@ -32,8 +33,6 @@ import (
"github.com/stretchr/testify/require"
)
-const applicationJSON = "application/json"
-
type stubBindRequester struct {
}
@@ -131,28 +130,226 @@ func TestContentType_Issue174(t *testing.T) {
assert.Equal(t, http.StatusOK, recorder.Code)
}
+const (
+ testHost = "https://localhost:8080"
+
+ // how to get the spec document?
+ defaultSpecPath = "/swagger.json"
+ defaultSpecURL = testHost + defaultSpecPath
+ // how to get the UI asset?
+ defaultUIURL = testHost + "/api/docs"
+)
+
func TestServe(t *testing.T) {
spec, api := petstore.NewAPI(t)
handler := Serve(spec, api)
- // serve spec document
- request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, "http://localhost:8080/swagger.json", nil)
- require.NoError(t, err)
+ t.Run("serve spec document", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, defaultSpecURL, nil)
+ require.NoError(t, err)
- request.Header.Add("Content-Type", runtime.JSONMime)
- request.Header.Add("Accept", runtime.JSONMime)
- recorder := httptest.NewRecorder()
+ request.Header.Add("Content-Type", runtime.JSONMime)
+ request.Header.Add("Accept", runtime.JSONMime)
+ recorder := httptest.NewRecorder()
- handler.ServeHTTP(recorder, request)
- assert.Equal(t, http.StatusOK, recorder.Code)
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ })
- request, err = http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, "http://localhost:8080/swagger-ui", nil)
- require.NoError(t, err)
+ t.Run("should not find UI there", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, testHost+"/swagger-ui", nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
- recorder = httptest.NewRecorder()
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusNotFound, recorder.Code)
+ })
- handler.ServeHTTP(recorder, request)
- assert.Equal(t, 404, recorder.Code)
+ t.Run("should find UI here", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, defaultUIURL, nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ htmlResponse := recorder.Body.String()
+ assert.Containsf(t, htmlResponse, "
Swagger Petstore", "should default to the API's title")
+ assert.Containsf(t, htmlResponse, "", "should default to /swagger.json spec document")
+ })
+}
+
+func TestServeWithUIs(t *testing.T) {
+ spec, api := petstore.NewAPI(t)
+ ctx := NewContext(spec, api, nil)
+
+ const (
+ alternateSpecURL = testHost + "/specs/petstore.json"
+ alternateSpecPath = "/specs/petstore.json"
+ alternateUIURL = testHost + "/ui/docs"
+ )
+
+ uiOpts := []UIOption{
+ WithUIBasePath("ui"), // override the base path from the spec, implies /ui
+ WithUIPath("docs"),
+ WithUISpecURL("/specs/petstore.json"),
+ }
+
+ t.Run("with APIHandler", func(t *testing.T) {
+ t.Run("with defaults", func(t *testing.T) {
+ handler := ctx.APIHandler(nil)
+
+ t.Run("should find UI", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, defaultUIURL, nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ htmlResponse := recorder.Body.String()
+ assert.Containsf(t, htmlResponse, "", alternateSpecPath))
+ })
+
+ t.Run("should find spec", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, alternateSpecURL, nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ })
+ })
+ })
+
+ t.Run("with APIHandlerSwaggerUI", func(t *testing.T) {
+ t.Run("with defaults", func(t *testing.T) {
+ handler := ctx.APIHandlerSwaggerUI(nil)
+
+ t.Run("should find UI", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, defaultUIURL, nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ htmlResponse := recorder.Body.String()
+ assert.Contains(t, htmlResponse, fmt.Sprintf(`url: '%s',`, strings.ReplaceAll(defaultSpecPath, `/`, `\/`)))
+ })
+
+ t.Run("should find spec", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, defaultSpecURL, nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ })
+ })
+
+ t.Run("with options", func(t *testing.T) {
+ handler := ctx.APIHandlerSwaggerUI(nil, uiOpts...)
+
+ t.Run("should find UI", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, alternateUIURL, nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ htmlResponse := recorder.Body.String()
+ assert.Contains(t, htmlResponse, fmt.Sprintf(`url: '%s',`, strings.ReplaceAll(alternateSpecPath, `/`, `\/`)))
+ })
+
+ t.Run("should find spec", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, alternateSpecURL, nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ })
+ })
+ })
+
+ t.Run("with APIHandlerRapiDoc", func(t *testing.T) {
+ t.Run("with defaults", func(t *testing.T) {
+ handler := ctx.APIHandlerRapiDoc(nil)
+
+ t.Run("should find UI", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, defaultUIURL, nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ htmlResponse := recorder.Body.String()
+ assert.Contains(t, htmlResponse, fmt.Sprintf("", defaultSpecPath))
+ })
+
+ t.Run("should find spec", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, defaultSpecURL, nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ })
+ })
+
+ t.Run("with options", func(t *testing.T) {
+ handler := ctx.APIHandlerRapiDoc(nil, uiOpts...)
+
+ t.Run("should find UI", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, alternateUIURL, nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ htmlResponse := recorder.Body.String()
+ assert.Contains(t, htmlResponse, fmt.Sprintf("", alternateSpecPath))
+ })
+ t.Run("should find spec", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(stdcontext.Background(), http.MethodGet, alternateSpecURL, nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ })
+ })
+ })
}
func TestContextAuthorize(t *testing.T) {
diff --git a/middleware/rapidoc.go b/middleware/rapidoc.go
index 5cb5314..ef75e74 100644
--- a/middleware/rapidoc.go
+++ b/middleware/rapidoc.go
@@ -1,4 +1,3 @@
-//nolint:dupl
package middleware
import (
@@ -11,66 +10,57 @@ import (
// RapiDocOpts configures the RapiDoc middlewares
type RapiDocOpts struct {
- // BasePath for the UI path, defaults to: /
+ // BasePath for the UI, defaults to: /
BasePath string
- // Path combines with BasePath for the full UI path, defaults to: docs
+
+ // Path combines with BasePath to construct the path to the UI, defaults to: "docs".
Path string
- // SpecURL the url to find the spec for
+
+ // SpecURL is the URL of the spec document.
+ //
+ // Defaults to: /swagger.json
SpecURL string
- // RapiDocURL for the js that generates the rapidoc site, defaults to: https://cdn.jsdelivr.net/npm/rapidoc/bundles/rapidoc.standalone.js
- RapiDocURL string
+
// Title for the documentation site, default to: API documentation
Title string
+
+ // Template specifies a custom template to serve the UI
+ Template string
+
+ // RapiDocURL points to the js asset that generates the rapidoc site.
+ //
+ // Defaults to https://unpkg.com/rapidoc/dist/rapidoc-min.js
+ RapiDocURL string
}
-// EnsureDefaults in case some options are missing
func (r *RapiDocOpts) EnsureDefaults() {
- if r.BasePath == "" {
- r.BasePath = "/"
- }
- if r.Path == "" {
- r.Path = defaultDocsPath
- }
- if r.SpecURL == "" {
- r.SpecURL = defaultDocsURL
- }
+ common := toCommonUIOptions(r)
+ common.EnsureDefaults()
+ fromCommonToAnyOptions(common, r)
+
+ // rapidoc-specifics
if r.RapiDocURL == "" {
r.RapiDocURL = rapidocLatest
}
- if r.Title == "" {
- r.Title = defaultDocsTitle
+ if r.Template == "" {
+ r.Template = rapidocTemplate
}
}
// RapiDoc creates a middleware to serve a documentation site for a swagger spec.
+//
// This allows for altering the spec before starting the http listener.
func RapiDoc(opts RapiDocOpts, next http.Handler) http.Handler {
opts.EnsureDefaults()
pth := path.Join(opts.BasePath, opts.Path)
- tmpl := template.Must(template.New("rapidoc").Parse(rapidocTemplate))
-
- buf := bytes.NewBuffer(nil)
- _ = tmpl.Execute(buf, opts)
- b := buf.Bytes()
-
- return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- if r.URL.Path == pth {
- rw.Header().Set("Content-Type", "text/html; charset=utf-8")
- rw.WriteHeader(http.StatusOK)
-
- _, _ = rw.Write(b)
- return
- }
+ tmpl := template.Must(template.New("rapidoc").Parse(opts.Template))
+ assets := bytes.NewBuffer(nil)
+ if err := tmpl.Execute(assets, opts); err != nil {
+ panic(fmt.Errorf("cannot execute template: %w", err))
+ }
- if next == nil {
- rw.Header().Set("Content-Type", "text/plain")
- rw.WriteHeader(http.StatusNotFound)
- _, _ = rw.Write([]byte(fmt.Sprintf("%q not found", pth)))
- return
- }
- next.ServeHTTP(rw, r)
- })
+ return serveUI(pth, assets.Bytes(), next)
}
const (
diff --git a/middleware/rapidoc_test.go b/middleware/rapidoc_test.go
index 8c9f0ce..50682ee 100644
--- a/middleware/rapidoc_test.go
+++ b/middleware/rapidoc_test.go
@@ -12,17 +12,33 @@ import (
)
func TestRapiDocMiddleware(t *testing.T) {
- rapidoc := RapiDoc(RapiDocOpts{}, nil)
+ t.Run("with defaults", func(t *testing.T) {
+ rapidoc := RapiDoc(RapiDocOpts{}, nil)
- req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/docs", nil)
- require.NoError(t, err)
- recorder := httptest.NewRecorder()
- rapidoc.ServeHTTP(recorder, req)
- assert.Equal(t, http.StatusOK, recorder.Code)
- assert.Equal(t, "text/html; charset=utf-8", recorder.Header().Get("Content-Type"))
- var o RapiDocOpts
- o.EnsureDefaults()
- assert.Contains(t, recorder.Body.String(), fmt.Sprintf("%s", o.Title))
- assert.Contains(t, recorder.Body.String(), fmt.Sprintf("", o.SpecURL))
- assert.Contains(t, recorder.Body.String(), rapidocLatest)
+ req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/docs", nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+ rapidoc.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ assert.Equal(t, "text/html; charset=utf-8", recorder.Header().Get(contentTypeHeader))
+ var o RapiDocOpts
+ o.EnsureDefaults()
+ assert.Contains(t, recorder.Body.String(), fmt.Sprintf("%s", o.Title))
+ assert.Contains(t, recorder.Body.String(), fmt.Sprintf("", o.SpecURL))
+ assert.Contains(t, recorder.Body.String(), rapidocLatest)
+ })
+
+ t.Run("edge cases", func(t *testing.T) {
+ t.Run("with custom template that fails to execute", func(t *testing.T) {
+ assert.Panics(t, func() {
+ RapiDoc(RapiDocOpts{
+ Template: `
+
+ spec-url='{{ .Unknown }}'
+
+`,
+ }, nil)
+ })
+ })
+ })
}
diff --git a/middleware/redoc.go b/middleware/redoc.go
index ca1d4ed..b96b01e 100644
--- a/middleware/redoc.go
+++ b/middleware/redoc.go
@@ -1,4 +1,3 @@
-//nolint:dupl
package middleware
import (
@@ -11,66 +10,58 @@ import (
// RedocOpts configures the Redoc middlewares
type RedocOpts struct {
- // BasePath for the UI path, defaults to: /
+ // BasePath for the UI, defaults to: /
BasePath string
- // Path combines with BasePath for the full UI path, defaults to: docs
+
+ // Path combines with BasePath to construct the path to the UI, defaults to: "docs".
Path string
- // SpecURL the url to find the spec for
+
+ // SpecURL is the URL of the spec document.
+ //
+ // Defaults to: /swagger.json
SpecURL string
- // RedocURL for the js that generates the redoc site, defaults to: https://cdn.jsdelivr.net/npm/redoc/bundles/redoc.standalone.js
- RedocURL string
+
// Title for the documentation site, default to: API documentation
Title string
+
+ // Template specifies a custom template to serve the UI
+ Template string
+
+ // RedocURL points to the js that generates the redoc site.
+ //
+ // Defaults to: https://cdn.jsdelivr.net/npm/redoc/bundles/redoc.standalone.js
+ RedocURL string
}
// EnsureDefaults in case some options are missing
func (r *RedocOpts) EnsureDefaults() {
- if r.BasePath == "" {
- r.BasePath = "/"
- }
- if r.Path == "" {
- r.Path = defaultDocsPath
- }
- if r.SpecURL == "" {
- r.SpecURL = defaultDocsURL
- }
+ common := toCommonUIOptions(r)
+ common.EnsureDefaults()
+ fromCommonToAnyOptions(common, r)
+
+ // redoc-specifics
if r.RedocURL == "" {
r.RedocURL = redocLatest
}
- if r.Title == "" {
- r.Title = defaultDocsTitle
+ if r.Template == "" {
+ r.Template = redocTemplate
}
}
// Redoc creates a middleware to serve a documentation site for a swagger spec.
+//
// This allows for altering the spec before starting the http listener.
func Redoc(opts RedocOpts, next http.Handler) http.Handler {
opts.EnsureDefaults()
pth := path.Join(opts.BasePath, opts.Path)
- tmpl := template.Must(template.New("redoc").Parse(redocTemplate))
-
- buf := bytes.NewBuffer(nil)
- _ = tmpl.Execute(buf, opts)
- b := buf.Bytes()
-
- return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- if r.URL.Path == pth {
- rw.Header().Set("Content-Type", "text/html; charset=utf-8")
- rw.WriteHeader(http.StatusOK)
-
- _, _ = rw.Write(b)
- return
- }
+ tmpl := template.Must(template.New("redoc").Parse(opts.Template))
+ assets := bytes.NewBuffer(nil)
+ if err := tmpl.Execute(assets, opts); err != nil {
+ panic(fmt.Errorf("cannot execute template: %w", err))
+ }
- if next == nil {
- rw.Header().Set("Content-Type", "text/plain")
- rw.WriteHeader(http.StatusNotFound)
- _, _ = rw.Write([]byte(fmt.Sprintf("%q not found", pth)))
- return
- }
- next.ServeHTTP(rw, r)
- })
+ return serveUI(pth, assets.Bytes(), next)
}
const (
diff --git a/middleware/redoc_test.go b/middleware/redoc_test.go
index 71a7c1b..f117898 100644
--- a/middleware/redoc_test.go
+++ b/middleware/redoc_test.go
@@ -12,17 +12,105 @@ import (
)
func TestRedocMiddleware(t *testing.T) {
- redoc := Redoc(RedocOpts{}, nil)
-
- req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/docs", nil)
- require.NoError(t, err)
- recorder := httptest.NewRecorder()
- redoc.ServeHTTP(recorder, req)
- assert.Equal(t, http.StatusOK, recorder.Code)
- assert.Equal(t, "text/html; charset=utf-8", recorder.Header().Get("Content-Type"))
- var o RedocOpts
- o.EnsureDefaults()
- assert.Contains(t, recorder.Body.String(), fmt.Sprintf("%s", o.Title))
- assert.Contains(t, recorder.Body.String(), fmt.Sprintf("", o.SpecURL))
- assert.Contains(t, recorder.Body.String(), redocLatest)
+ t.Run("with defaults", func(t *testing.T) {
+ redoc := Redoc(RedocOpts{}, nil)
+
+ req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/docs", nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+ redoc.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ assert.Equal(t, "text/html; charset=utf-8", recorder.Header().Get(contentTypeHeader))
+ var o RedocOpts
+ o.EnsureDefaults()
+ assert.Contains(t, recorder.Body.String(), fmt.Sprintf("%s", o.Title))
+ assert.Contains(t, recorder.Body.String(), fmt.Sprintf("", o.SpecURL))
+ assert.Contains(t, recorder.Body.String(), redocLatest)
+ })
+
+ t.Run("with alternate path and spec URL", func(t *testing.T) {
+ redoc := Redoc(RedocOpts{
+ BasePath: "/base",
+ Path: "ui",
+ SpecURL: "/ui/swagger.json",
+ }, nil)
+
+ req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/base/ui", nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+ redoc.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "")
+ })
+
+ t.Run("with custom template", func(t *testing.T) {
+ redoc := Redoc(RedocOpts{
+ Template: `
+
+
+ {{ .Title }}
+
+
+
+
+
+
+
+
+
+
+
+
+`,
+ }, nil)
+
+ req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/docs", nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+ redoc.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ assert.Contains(t, recorder.Body.String(), "required-props-first=true")
+ })
+
+ t.Run("edge cases", func(t *testing.T) {
+ t.Run("with invalid custom template", func(t *testing.T) {
+ assert.Panics(t, func() {
+ Redoc(RedocOpts{
+ Template: `
+
+
+ spec-url='{{ .Spec
+
+`,
+ }, nil)
+ })
+ })
+
+ t.Run("with custom template that fails to execute", func(t *testing.T) {
+ assert.Panics(t, func() {
+ Redoc(RedocOpts{
+ Template: `
+
+ spec-url='{{ .Unknown }}'
+
+`,
+ }, nil)
+ })
+ })
+ })
}
diff --git a/middleware/spec.go b/middleware/spec.go
index c288a2b..87e17e3 100644
--- a/middleware/spec.go
+++ b/middleware/spec.go
@@ -19,29 +19,84 @@ import (
"path"
)
-// Spec creates a middleware to serve a swagger spec.
+const (
+ contentTypeHeader = "Content-Type"
+ applicationJSON = "application/json"
+)
+
+// SpecOption can be applied to the Spec serving middleware
+type SpecOption func(*specOptions)
+
+var defaultSpecOptions = specOptions{
+ Path: "",
+ Document: "swagger.json",
+}
+
+type specOptions struct {
+ Path string
+ Document string
+}
+
+func specOptionsWithDefaults(opts []SpecOption) specOptions {
+ o := defaultSpecOptions
+ for _, apply := range opts {
+ apply(&o)
+ }
+
+ return o
+}
+
+// Spec creates a middleware to serve a swagger spec as a JSON document.
+//
// This allows for altering the spec before starting the http listener.
-// This can be useful if you want to serve the swagger spec from another path than /swagger.json
-func Spec(basePath string, b []byte, next http.Handler) http.Handler {
+//
+// The basePath argument indicates the path of the spec document (defaults to "/").
+// Additional SpecOption can be used to change the name of the document (defaults to "swagger.json").
+func Spec(basePath string, b []byte, next http.Handler, opts ...SpecOption) http.Handler {
if basePath == "" {
basePath = "/"
}
- pth := path.Join(basePath, "swagger.json")
+ o := specOptionsWithDefaults(opts)
+ pth := path.Join(basePath, o.Path, o.Document)
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- if r.URL.Path == pth {
- rw.Header().Set("Content-Type", "application/json")
+ if path.Clean(r.URL.Path) == pth {
+ rw.Header().Set(contentTypeHeader, applicationJSON)
rw.WriteHeader(http.StatusOK)
- //#nosec
_, _ = rw.Write(b)
+
return
}
- if next == nil {
- rw.Header().Set("Content-Type", "application/json")
- rw.WriteHeader(http.StatusNotFound)
+ if next != nil {
+ next.ServeHTTP(rw, r)
+
return
}
- next.ServeHTTP(rw, r)
+
+ rw.Header().Set(contentTypeHeader, applicationJSON)
+ rw.WriteHeader(http.StatusNotFound)
})
}
+
+// WithSpecPath sets the path to be joined to the base path of the Spec middleware.
+//
+// This is empty by default.
+func WithSpecPath(pth string) SpecOption {
+ return func(o *specOptions) {
+ o.Path = pth
+ }
+}
+
+// WithSpecDocument sets the name of the JSON document served as a spec.
+//
+// By default, this is "swagger.json"
+func WithSpecDocument(doc string) SpecOption {
+ return func(o *specOptions) {
+ if doc == "" {
+ return
+ }
+
+ o.Document = doc
+ }
+}
diff --git a/middleware/spec_test.go b/middleware/spec_test.go
index 2e02077..efdd9f5 100644
--- a/middleware/spec_test.go
+++ b/middleware/spec_test.go
@@ -30,38 +30,76 @@ import (
func TestServeSpecMiddleware(t *testing.T) {
spec, api := petstore.NewAPI(t)
ctx := NewContext(spec, api, nil)
- handler := Spec("", ctx.spec.Raw(), nil)
- t.Run("serves spec", func(t *testing.T) {
- request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/swagger.json", nil)
- require.NoError(t, err)
- request.Header.Add(runtime.HeaderContentType, runtime.JSONMime)
- recorder := httptest.NewRecorder()
+ t.Run("Spec handler", func(t *testing.T) {
+ handler := Spec("", ctx.spec.Raw(), nil)
- handler.ServeHTTP(recorder, request)
- assert.Equal(t, http.StatusOK, recorder.Code)
- })
+ t.Run("serves spec", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/swagger.json", nil)
+ require.NoError(t, err)
+ request.Header.Add(runtime.HeaderContentType, runtime.JSONMime)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ responseHeaders := recorder.Result().Header //nolint:bodyclose // false positive from linter
+ responseContentType := responseHeaders.Get("Content-Type")
+ assert.Equal(t, applicationJSON, responseContentType)
+
+ responseBody := recorder.Body
+ require.NotNil(t, responseBody)
+ require.JSONEq(t, string(spec.Raw()), responseBody.String())
+ })
+
+ t.Run("returns 404 when no next handler", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/api/pets", nil)
+ require.NoError(t, err)
+ request.Header.Add(runtime.HeaderContentType, runtime.JSONMime)
+ recorder := httptest.NewRecorder()
- t.Run("returns 404 when no next handler", func(t *testing.T) {
- request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/api/pets", nil)
- require.NoError(t, err)
- request.Header.Add(runtime.HeaderContentType, runtime.JSONMime)
- recorder := httptest.NewRecorder()
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusNotFound, recorder.Code)
+ })
- handler.ServeHTTP(recorder, request)
- assert.Equal(t, http.StatusNotFound, recorder.Code)
+ t.Run("forwards to next handler for other url", func(t *testing.T) {
+ handler = Spec("", ctx.spec.Raw(), http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
+ rw.WriteHeader(http.StatusOK)
+ }))
+ request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/api/pets", nil)
+ require.NoError(t, err)
+ request.Header.Add(runtime.HeaderContentType, runtime.JSONMime)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ })
})
- t.Run("forwards to next handler for other url", func(t *testing.T) {
- handler = Spec("", ctx.spec.Raw(), http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
- rw.WriteHeader(http.StatusOK)
- }))
- request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/api/pets", nil)
- require.NoError(t, err)
- request.Header.Add(runtime.HeaderContentType, runtime.JSONMime)
- recorder := httptest.NewRecorder()
-
- handler.ServeHTTP(recorder, request)
- assert.Equal(t, http.StatusOK, recorder.Code)
+ t.Run("Spec handler with options", func(t *testing.T) {
+ handler := Spec("/swagger", ctx.spec.Raw(), nil,
+ WithSpecPath("spec"),
+ WithSpecDocument("myapi-swagger.json"),
+ )
+
+ t.Run("serves spec", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/swagger/spec/myapi-swagger.json", nil)
+ require.NoError(t, err)
+ request.Header.Add(runtime.HeaderContentType, runtime.JSONMime)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ })
+
+ t.Run("should not find spec there", func(t *testing.T) {
+ request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/swagger.json", nil)
+ require.NoError(t, err)
+ request.Header.Add(runtime.HeaderContentType, runtime.JSONMime)
+ recorder := httptest.NewRecorder()
+
+ handler.ServeHTTP(recorder, request)
+ assert.Equal(t, http.StatusNotFound, recorder.Code)
+ })
})
}
diff --git a/middleware/swaggerui.go b/middleware/swaggerui.go
index 846e3cf..ec3c10c 100644
--- a/middleware/swaggerui.go
+++ b/middleware/swaggerui.go
@@ -8,40 +8,65 @@ import (
"path"
)
-// SwaggerUIOpts configures the Swaggerui middlewares
+// SwaggerUIOpts configures the SwaggerUI middleware
type SwaggerUIOpts struct {
- // BasePath for the UI path, defaults to: /
+ // BasePath for the API, defaults to: /
BasePath string
- // Path combines with BasePath for the full UI path, defaults to: docs
+
+ // Path combines with BasePath to construct the path to the UI, defaults to: "docs".
Path string
- // SpecURL the url to find the spec for
+
+ // SpecURL is the URL of the spec document.
+ //
+ // Defaults to: /swagger.json
SpecURL string
+
+ // Title for the documentation site, default to: API documentation
+ Title string
+
+ // Template specifies a custom template to serve the UI
+ Template string
+
// OAuthCallbackURL the url called after OAuth2 login
OAuthCallbackURL string
// The three components needed to embed swagger-ui
- SwaggerURL string
+
+ // SwaggerURL points to the js that generates the SwaggerUI site.
+ //
+ // Defaults to: https://unpkg.com/swagger-ui-dist/swagger-ui-bundle.js
+ SwaggerURL string
+
SwaggerPresetURL string
SwaggerStylesURL string
Favicon32 string
Favicon16 string
-
- // Title for the documentation site, default to: API documentation
- Title string
}
// EnsureDefaults in case some options are missing
func (r *SwaggerUIOpts) EnsureDefaults() {
- if r.BasePath == "" {
- r.BasePath = "/"
- }
- if r.Path == "" {
- r.Path = defaultDocsPath
+ r.ensureDefaults()
+
+ if r.Template == "" {
+ r.Template = swaggeruiTemplate
}
- if r.SpecURL == "" {
- r.SpecURL = defaultDocsURL
+}
+
+func (r *SwaggerUIOpts) EnsureDefaultsOauth2() {
+ r.ensureDefaults()
+
+ if r.Template == "" {
+ r.Template = swaggerOAuthTemplate
}
+}
+
+func (r *SwaggerUIOpts) ensureDefaults() {
+ common := toCommonUIOptions(r)
+ common.EnsureDefaults()
+ fromCommonToAnyOptions(common, r)
+
+ // swaggerui-specifics
if r.OAuthCallbackURL == "" {
r.OAuthCallbackURL = path.Join(r.BasePath, r.Path, "oauth2-callback")
}
@@ -60,40 +85,22 @@ func (r *SwaggerUIOpts) EnsureDefaults() {
if r.Favicon32 == "" {
r.Favicon32 = swaggerFavicon32Latest
}
- if r.Title == "" {
- r.Title = defaultDocsTitle
- }
}
// SwaggerUI creates a middleware to serve a documentation site for a swagger spec.
+//
// This allows for altering the spec before starting the http listener.
func SwaggerUI(opts SwaggerUIOpts, next http.Handler) http.Handler {
opts.EnsureDefaults()
pth := path.Join(opts.BasePath, opts.Path)
- tmpl := template.Must(template.New("swaggerui").Parse(swaggeruiTemplate))
-
- buf := bytes.NewBuffer(nil)
- _ = tmpl.Execute(buf, &opts)
- b := buf.Bytes()
-
- return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- if path.Join(r.URL.Path) == pth {
- rw.Header().Set("Content-Type", "text/html; charset=utf-8")
- rw.WriteHeader(http.StatusOK)
-
- _, _ = rw.Write(b)
- return
- }
-
- if next == nil {
- rw.Header().Set("Content-Type", "text/plain")
- rw.WriteHeader(http.StatusNotFound)
- _, _ = rw.Write([]byte(fmt.Sprintf("%q not found", pth)))
- return
- }
- next.ServeHTTP(rw, r)
- })
+ tmpl := template.Must(template.New("swaggerui").Parse(opts.Template))
+ assets := bytes.NewBuffer(nil)
+ if err := tmpl.Execute(assets, opts); err != nil {
+ panic(fmt.Errorf("cannot execute template: %w", err))
+ }
+
+ return serveUI(pth, assets.Bytes(), next)
}
const (
diff --git a/middleware/swaggerui_oauth2.go b/middleware/swaggerui_oauth2.go
index 576f600..e81212f 100644
--- a/middleware/swaggerui_oauth2.go
+++ b/middleware/swaggerui_oauth2.go
@@ -4,37 +4,20 @@ import (
"bytes"
"fmt"
"net/http"
- "path"
"text/template"
)
func SwaggerUIOAuth2Callback(opts SwaggerUIOpts, next http.Handler) http.Handler {
- opts.EnsureDefaults()
+ opts.EnsureDefaultsOauth2()
pth := opts.OAuthCallbackURL
- tmpl := template.Must(template.New("swaggeroauth").Parse(swaggerOAuthTemplate))
+ tmpl := template.Must(template.New("swaggeroauth").Parse(opts.Template))
+ assets := bytes.NewBuffer(nil)
+ if err := tmpl.Execute(assets, opts); err != nil {
+ panic(fmt.Errorf("cannot execute template: %w", err))
+ }
- buf := bytes.NewBuffer(nil)
- _ = tmpl.Execute(buf, &opts)
- b := buf.Bytes()
-
- return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
- if path.Join(r.URL.Path) == pth {
- rw.Header().Set("Content-Type", "text/html; charset=utf-8")
- rw.WriteHeader(http.StatusOK)
-
- _, _ = rw.Write(b)
- return
- }
-
- if next == nil {
- rw.Header().Set("Content-Type", "text/plain")
- rw.WriteHeader(http.StatusNotFound)
- _, _ = rw.Write([]byte(fmt.Sprintf("%q not found", pth)))
- return
- }
- next.ServeHTTP(rw, r)
- })
+ return serveUI(pth, assets.Bytes(), next)
}
const (
diff --git a/middleware/swaggerui_oauth2_test.go b/middleware/swaggerui_oauth2_test.go
index c9b28c9..a19c430 100644
--- a/middleware/swaggerui_oauth2_test.go
+++ b/middleware/swaggerui_oauth2_test.go
@@ -12,15 +12,35 @@ import (
)
func TestSwaggerUIOAuth2CallbackMiddleware(t *testing.T) {
- redoc := SwaggerUIOAuth2Callback(SwaggerUIOpts{}, nil)
+ t.Run("with defaults", func(t *testing.T) {
+ doc := SwaggerUIOAuth2Callback(SwaggerUIOpts{}, nil)
- req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/docs/oauth2-callback", nil)
- require.NoError(t, err)
- recorder := httptest.NewRecorder()
- redoc.ServeHTTP(recorder, req)
- assert.Equal(t, http.StatusOK, recorder.Code)
- assert.Equal(t, "text/html; charset=utf-8", recorder.Header().Get("Content-Type"))
- var o SwaggerUIOpts
- o.EnsureDefaults()
- assert.Contains(t, recorder.Body.String(), fmt.Sprintf("%s", o.Title))
+ req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/docs/oauth2-callback", nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ doc.ServeHTTP(recorder, req)
+ require.Equal(t, http.StatusOK, recorder.Code)
+ assert.Equal(t, "text/html; charset=utf-8", recorder.Header().Get(contentTypeHeader))
+
+ var o SwaggerUIOpts
+ o.EnsureDefaultsOauth2()
+ htmlResponse := recorder.Body.String()
+ assert.Contains(t, htmlResponse, fmt.Sprintf("%s", o.Title))
+ assert.Contains(t, htmlResponse, `oauth2.auth.schema.get("flow") === "accessCode"`)
+ })
+
+ t.Run("edge cases", func(t *testing.T) {
+ t.Run("with custom template that fails to execute", func(t *testing.T) {
+ assert.Panics(t, func() {
+ SwaggerUIOAuth2Callback(SwaggerUIOpts{
+ Template: `
+
+ spec-url='{{ .Unknown }}'
+
+`,
+ }, nil)
+ })
+ })
+ })
}
diff --git a/middleware/swaggerui_test.go b/middleware/swaggerui_test.go
index 4047e66..b299c09 100644
--- a/middleware/swaggerui_test.go
+++ b/middleware/swaggerui_test.go
@@ -17,17 +17,53 @@ func TestSwaggerUIMiddleware(t *testing.T) {
o.EnsureDefaults()
swui := SwaggerUI(o, nil)
- req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/docs", nil)
- require.NoError(t, err)
- recorder := httptest.NewRecorder()
- swui.ServeHTTP(recorder, req)
- assert.Equal(t, http.StatusOK, recorder.Code)
- assert.Equal(t, "text/html; charset=utf-8", recorder.Header().Get("Content-Type"))
- assert.Contains(t, recorder.Body.String(), fmt.Sprintf("%s", o.Title))
- assert.Contains(t, recorder.Body.String(), fmt.Sprintf(`url: '%s',`, strings.ReplaceAll(o.SpecURL, `/`, `\/`)))
- assert.Contains(t, recorder.Body.String(), swaggerLatest)
- assert.Contains(t, recorder.Body.String(), swaggerPresetLatest)
- assert.Contains(t, recorder.Body.String(), swaggerStylesLatest)
- assert.Contains(t, recorder.Body.String(), swaggerFavicon16Latest)
- assert.Contains(t, recorder.Body.String(), swaggerFavicon32Latest)
+ t.Run("with defaults ", func(t *testing.T) {
+ req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/docs", nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ swui.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+
+ assert.Equal(t, "text/html; charset=utf-8", recorder.Header().Get(contentTypeHeader))
+ assert.Contains(t, recorder.Body.String(), fmt.Sprintf("%s", o.Title))
+ assert.Contains(t, recorder.Body.String(), fmt.Sprintf(`url: '%s',`, strings.ReplaceAll(o.SpecURL, `/`, `\/`)))
+ assert.Contains(t, recorder.Body.String(), swaggerLatest)
+ assert.Contains(t, recorder.Body.String(), swaggerPresetLatest)
+ assert.Contains(t, recorder.Body.String(), swaggerStylesLatest)
+ assert.Contains(t, recorder.Body.String(), swaggerFavicon16Latest)
+ assert.Contains(t, recorder.Body.String(), swaggerFavicon32Latest)
+ })
+
+ t.Run("with path with a trailing / (issue #238)", func(t *testing.T) {
+ req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/docs/", nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ swui.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusOK, recorder.Code)
+ })
+
+ t.Run("should yield not found", func(t *testing.T) {
+ req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/nowhere", nil)
+ require.NoError(t, err)
+ recorder := httptest.NewRecorder()
+
+ swui.ServeHTTP(recorder, req)
+ assert.Equal(t, http.StatusNotFound, recorder.Code)
+ })
+
+ t.Run("edge cases", func(t *testing.T) {
+ t.Run("with custom template that fails to execute", func(t *testing.T) {
+ assert.Panics(t, func() {
+ SwaggerUI(SwaggerUIOpts{
+ Template: `
+
+ spec-url='{{ .Unknown }}'
+
+`,
+ }, nil)
+ })
+ })
+ })
}
diff --git a/middleware/ui_defaults.go b/middleware/ui_defaults.go
deleted file mode 100644
index 25817d2..0000000
--- a/middleware/ui_defaults.go
+++ /dev/null
@@ -1,8 +0,0 @@
-package middleware
-
-const (
- // constants that are common to all UI-serving middlewares
- defaultDocsPath = "docs"
- defaultDocsURL = "/swagger.json"
- defaultDocsTitle = "API Documentation"
-)
diff --git a/middleware/ui_options.go b/middleware/ui_options.go
new file mode 100644
index 0000000..b86efa0
--- /dev/null
+++ b/middleware/ui_options.go
@@ -0,0 +1,173 @@
+package middleware
+
+import (
+ "bytes"
+ "encoding/gob"
+ "fmt"
+ "net/http"
+ "path"
+ "strings"
+)
+
+const (
+ // constants that are common to all UI-serving middlewares
+ defaultDocsPath = "docs"
+ defaultDocsURL = "/swagger.json"
+ defaultDocsTitle = "API Documentation"
+)
+
+// uiOptions defines common options for UI serving middlewares.
+type uiOptions struct {
+ // BasePath for the UI, defaults to: /
+ BasePath string
+
+ // Path combines with BasePath to construct the path to the UI, defaults to: "docs".
+ Path string
+
+ // SpecURL is the URL of the spec document.
+ //
+ // Defaults to: /swagger.json
+ SpecURL string
+
+ // Title for the documentation site, default to: API documentation
+ Title string
+
+ // Template specifies a custom template to serve the UI
+ Template string
+}
+
+// toCommonUIOptions converts any UI option type to retain the common options.
+//
+// This uses gob encoding/decoding to convert common fields from one struct to another.
+func toCommonUIOptions(opts interface{}) uiOptions {
+ var buf bytes.Buffer
+ enc := gob.NewEncoder(&buf)
+ dec := gob.NewDecoder(&buf)
+ var o uiOptions
+ err := enc.Encode(opts)
+ if err != nil {
+ panic(err)
+ }
+
+ err = dec.Decode(&o)
+ if err != nil {
+ panic(err)
+ }
+
+ return o
+}
+
+func fromCommonToAnyOptions[T any](source uiOptions, target *T) {
+ var buf bytes.Buffer
+ enc := gob.NewEncoder(&buf)
+ dec := gob.NewDecoder(&buf)
+ err := enc.Encode(source)
+ if err != nil {
+ panic(err)
+ }
+
+ err = dec.Decode(target)
+ if err != nil {
+ panic(err)
+ }
+}
+
+// UIOption can be applied to UI serving middleware, such as Context.APIHandler or
+// Context.APIHandlerSwaggerUI to alter the defaut behavior.
+type UIOption func(*uiOptions)
+
+func uiOptionsWithDefaults(opts []UIOption) uiOptions {
+ var o uiOptions
+ for _, apply := range opts {
+ apply(&o)
+ }
+
+ return o
+}
+
+// WithUIBasePath sets the base path from where to serve the UI assets.
+//
+// By default, Context middleware sets this value to the API base path.
+func WithUIBasePath(base string) UIOption {
+ return func(o *uiOptions) {
+ if !strings.HasPrefix(base, "/") {
+ base = "/" + base
+ }
+ o.BasePath = base
+ }
+}
+
+// WithUIPath sets the path from where to serve the UI assets (i.e. /{basepath}/{path}.
+func WithUIPath(pth string) UIOption {
+ return func(o *uiOptions) {
+ o.Path = pth
+ }
+}
+
+// WithUISpecURL sets the path from where to serve swagger spec document.
+//
+// This may be specified as a full URL or a path.
+//
+// By default, this is "/swagger.json"
+func WithUISpecURL(specURL string) UIOption {
+ return func(o *uiOptions) {
+ o.SpecURL = specURL
+ }
+}
+
+// WithUITitle sets the title of the UI.
+//
+// By default, Context middleware sets this value to the title found in the API spec.
+func WithUITitle(title string) UIOption {
+ return func(o *uiOptions) {
+ o.Title = title
+ }
+}
+
+// WithTemplate allows to set a custom template for the UI.
+//
+// UI middleware will panic if the template does not parse or execute properly.
+func WithTemplate(tpl string) UIOption {
+ return func(o *uiOptions) {
+ o.Template = tpl
+ }
+}
+
+// EnsureDefaults in case some options are missing
+func (r *uiOptions) EnsureDefaults() {
+ if r.BasePath == "" {
+ r.BasePath = "/"
+ }
+ if r.Path == "" {
+ r.Path = defaultDocsPath
+ }
+ if r.SpecURL == "" {
+ r.SpecURL = defaultDocsURL
+ }
+ if r.Title == "" {
+ r.Title = defaultDocsTitle
+ }
+}
+
+// serveUI creates a middleware that serves a templated asset as text/html.
+func serveUI(pth string, assets []byte, next http.Handler) http.Handler {
+ return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
+ if path.Clean(r.URL.Path) == pth {
+ rw.Header().Set(contentTypeHeader, "text/html; charset=utf-8")
+ rw.WriteHeader(http.StatusOK)
+ _, _ = rw.Write(assets)
+
+ return
+ }
+
+ if next != nil {
+ next.ServeHTTP(rw, r)
+
+ return
+ }
+
+ rw.Header().Set(contentTypeHeader, "text/plain")
+ rw.WriteHeader(http.StatusNotFound)
+ _, _ = rw.Write([]byte(fmt.Sprintf("%q not found", pth)))
+ })
+}
diff --git a/middleware/ui_options_test.go b/middleware/ui_options_test.go
new file mode 100644
index 0000000..7183369
--- /dev/null
+++ b/middleware/ui_options_test.go
@@ -0,0 +1,105 @@
+package middleware
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestConvertOptions(t *testing.T) {
+ t.Run("from any UI options to uiOptions", func(t *testing.T) {
+ t.Run("from RedocOpts", func(t *testing.T) {
+ in := RedocOpts{
+ BasePath: "a",
+ Path: "b",
+ SpecURL: "c",
+ Template: "d",
+ Title: "e",
+ RedocURL: "f",
+ }
+ out := toCommonUIOptions(in)
+
+ require.Equal(t, "a", out.BasePath)
+ require.Equal(t, "b", out.Path)
+ require.Equal(t, "c", out.SpecURL)
+ require.Equal(t, "d", out.Template)
+ require.Equal(t, "e", out.Title)
+ })
+
+ t.Run("from RapiDocOpts", func(t *testing.T) {
+ in := RapiDocOpts{
+ BasePath: "a",
+ Path: "b",
+ SpecURL: "c",
+ Template: "d",
+ Title: "e",
+ RapiDocURL: "f",
+ }
+ out := toCommonUIOptions(in)
+
+ require.Equal(t, "a", out.BasePath)
+ require.Equal(t, "b", out.Path)
+ require.Equal(t, "c", out.SpecURL)
+ require.Equal(t, "d", out.Template)
+ require.Equal(t, "e", out.Title)
+ })
+
+ t.Run("from SwaggerUIOpts", func(t *testing.T) {
+ in := SwaggerUIOpts{
+ BasePath: "a",
+ Path: "b",
+ SpecURL: "c",
+ Template: "d",
+ Title: "e",
+ SwaggerURL: "f",
+ }
+ out := toCommonUIOptions(in)
+
+ require.Equal(t, "a", out.BasePath)
+ require.Equal(t, "b", out.Path)
+ require.Equal(t, "c", out.SpecURL)
+ require.Equal(t, "d", out.Template)
+ require.Equal(t, "e", out.Title)
+ })
+ })
+
+ t.Run("from uiOptions to any UI options", func(t *testing.T) {
+ in := uiOptions{
+ BasePath: "a",
+ Path: "b",
+ SpecURL: "c",
+ Template: "d",
+ Title: "e",
+ }
+
+ t.Run("to RedocOpts", func(t *testing.T) {
+ var out RedocOpts
+ fromCommonToAnyOptions(in, &out)
+ require.Equal(t, "a", out.BasePath)
+ require.Equal(t, "b", out.Path)
+ require.Equal(t, "c", out.SpecURL)
+ require.Equal(t, "d", out.Template)
+ require.Equal(t, "e", out.Title)
+ })
+
+ t.Run("to RapiDocOpts", func(t *testing.T) {
+ var out RapiDocOpts
+ fromCommonToAnyOptions(in, &out)
+ require.Equal(t, "a", out.BasePath)
+ require.Equal(t, "b", out.Path)
+ require.Equal(t, "c", out.SpecURL)
+ require.Equal(t, "d", out.Template)
+ require.Equal(t, "e", out.Title)
+ })
+
+ t.Run("to SwaggerUIOpts", func(t *testing.T) {
+ var out SwaggerUIOpts
+ fromCommonToAnyOptions(in, &out)
+ require.Equal(t, "a", out.BasePath)
+ require.Equal(t, "b", out.Path)
+ require.Equal(t, "c", out.SpecURL)
+ require.Equal(t, "d", out.Template)
+ require.Equal(t, "e", out.Title)
+ })
+ })
+}