Skip to content

Commit

Permalink
refactor: adopt jwt.MapClaims to replace custom MapClaims type across…
Browse files Browse the repository at this point in the history
… codebase (#342)

- Remove the custom `MapClaims` type definition
- Replace `MapClaims` with `jwt.MapClaims` in `PayloadFunc`, `GetClaimsFromJWT`, `ExtractClaims`, and `ExtractClaimsFromToken` functions
- Update test cases to use `jwt.MapClaims` instead of the custom `MapClaims`
- Update `ConvertClaims` function to accept `jwt.MapClaims`

Signed-off-by: appleboy <appleboy.tw@gmail.com>
  • Loading branch information
appleboy authored Feb 24, 2025
1 parent be787cc commit 74e77e5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 30 deletions.
22 changes: 9 additions & 13 deletions auth_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ import (
"github.com/youmark/pkcs8"
)

// MapClaims type that uses the map[string]interface{} for JSON decoding
// This is the default claims type if you don't supply one
type MapClaims map[string]interface{}

// GinJWTMiddleware provides a Json-Web-Token authentication implementation. On failure, a 401 HTTP response
// is returned. On success, the wrapped middleware is called, and the userID is made available as
// c.Get("userID").(string).
Expand Down Expand Up @@ -65,7 +61,7 @@ type GinJWTMiddleware struct {
// Note that the payload is not encrypted.
// The attributes mentioned on jwt.io can't be used as keys for the map.
// Optional, by default no additional data will be set.
PayloadFunc func(data interface{}) MapClaims
PayloadFunc func(data interface{}) jwt.MapClaims

// User can define own Unauthorized func.
Unauthorized func(c *gin.Context, code int, message string)
Expand Down Expand Up @@ -487,7 +483,7 @@ func (mw *GinJWTMiddleware) middlewareImpl(c *gin.Context) {
}

// GetClaimsFromJWT get claims from JWT token
func (mw *GinJWTMiddleware) GetClaimsFromJWT(c *gin.Context) (MapClaims, error) {
func (mw *GinJWTMiddleware) GetClaimsFromJWT(c *gin.Context) (jwt.MapClaims, error) {
token, err := mw.ParseToken(c)
if err != nil {
return nil, err
Expand All @@ -499,7 +495,7 @@ func (mw *GinJWTMiddleware) GetClaimsFromJWT(c *gin.Context) (MapClaims, error)
}
}

claims := MapClaims{}
claims := jwt.MapClaims{}
for key, value := range token.Claims.(jwt.MapClaims) {
claims[key] = value
}
Expand Down Expand Up @@ -801,22 +797,22 @@ func (mw *GinJWTMiddleware) unauthorized(c *gin.Context, code int, message strin
}

// ExtractClaims help to extract the JWT claims
func ExtractClaims(c *gin.Context) MapClaims {
func ExtractClaims(c *gin.Context) jwt.MapClaims {
claims, exists := c.Get("JWT_PAYLOAD")
if !exists {
return make(MapClaims)
return make(jwt.MapClaims)
}

return claims.(MapClaims)
return claims.(jwt.MapClaims)
}

// ExtractClaimsFromToken help to extract the JWT claims from token
func ExtractClaimsFromToken(token *jwt.Token) MapClaims {
func ExtractClaimsFromToken(token *jwt.Token) jwt.MapClaims {
if token == nil {
return make(MapClaims)
return make(jwt.MapClaims)
}

claims := MapClaims{}
claims := jwt.MapClaims{}
for key, value := range token.Claims.(jwt.MapClaims) {
claims[key] = value
}
Expand Down
34 changes: 17 additions & 17 deletions auth_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,9 @@ func TestLoginHandler(t *testing.T) {
authMiddleware, err := New(&GinJWTMiddleware{
Realm: "test zone",
Key: key,
PayloadFunc: func(data interface{}) MapClaims {
PayloadFunc: func(data interface{}) jwt.MapClaims {
// Set custom claim, to be checked in Authorizator method
return MapClaims{"testkey": "testval", "exp": 0}
return jwt.MapClaims{"testkey": "testval", "exp": 0}
},
Authenticator: func(c *gin.Context) (interface{}, error) {
var loginVals Login
Expand Down Expand Up @@ -699,13 +699,13 @@ func TestClaimsDuringAuthorization(t *testing.T) {
Key: key,
Timeout: time.Hour,
MaxRefresh: time.Hour * 24,
PayloadFunc: func(data interface{}) MapClaims {
if v, ok := data.(MapClaims); ok {
PayloadFunc: func(data interface{}) jwt.MapClaims {
if v, ok := data.(jwt.MapClaims); ok {
return v
}

if reflect.TypeOf(data).String() != "string" {
return MapClaims{}
return jwt.MapClaims{}
}

var testkey string
Expand All @@ -718,7 +718,7 @@ func TestClaimsDuringAuthorization(t *testing.T) {
testkey = ""
}
// Set custom claim, to be checked in Authorizator method
return MapClaims{"identity": data.(string), "testkey": testkey, "exp": 0}
return jwt.MapClaims{"identity": data.(string), "testkey": testkey, "exp": 0}
},
Authenticator: func(c *gin.Context) (interface{}, error) {
var loginVals Login
Expand Down Expand Up @@ -762,7 +762,7 @@ func TestClaimsDuringAuthorization(t *testing.T) {
r := gofight.New()
handler := ginHandler(authMiddleware)

userToken, _, _ := authMiddleware.TokenGenerator(MapClaims{
userToken, _, _ := authMiddleware.TokenGenerator(jwt.MapClaims{
"identity": "administrator",
})

Expand Down Expand Up @@ -813,12 +813,12 @@ func TestClaimsDuringAuthorization(t *testing.T) {
})
}

func ConvertClaims(claims MapClaims) map[string]interface{} {
func ConvertClaims(claims jwt.MapClaims) map[string]interface{} {
return map[string]interface{}{}
}

func TestEmptyClaims(t *testing.T) {
var jwtClaims MapClaims
var jwtClaims jwt.MapClaims

// the middleware to test
authMiddleware, _ := New(&GinJWTMiddleware{
Expand Down Expand Up @@ -905,7 +905,7 @@ func TestTokenExpire(t *testing.T) {

r := gofight.New()

userToken, _, _ := authMiddleware.TokenGenerator(MapClaims{
userToken, _, _ := authMiddleware.TokenGenerator(jwt.MapClaims{
"identity": "admin",
})

Expand Down Expand Up @@ -935,7 +935,7 @@ func TestTokenFromQueryString(t *testing.T) {

r := gofight.New()

userToken, _, _ := authMiddleware.TokenGenerator(MapClaims{
userToken, _, _ := authMiddleware.TokenGenerator(jwt.MapClaims{
"identity": "admin",
})

Expand Down Expand Up @@ -973,7 +973,7 @@ func TestTokenFromParamPath(t *testing.T) {

r := gofight.New()

userToken, _, _ := authMiddleware.TokenGenerator(MapClaims{
userToken, _, _ := authMiddleware.TokenGenerator(jwt.MapClaims{
"identity": "admin",
})

Expand Down Expand Up @@ -1008,7 +1008,7 @@ func TestTokenFromCookieString(t *testing.T) {

r := gofight.New()

userToken, _, _ := authMiddleware.TokenGenerator(MapClaims{
userToken, _, _ := authMiddleware.TokenGenerator(jwt.MapClaims{
"identity": "admin",
})

Expand Down Expand Up @@ -1253,8 +1253,8 @@ func TestCheckTokenString(t *testing.T) {
Unauthorized: func(c *gin.Context, code int, message string) {
c.String(code, message)
},
PayloadFunc: func(data interface{}) MapClaims {
if v, ok := data.(MapClaims); ok {
PayloadFunc: func(data interface{}) jwt.MapClaims {
if v, ok := data.(jwt.MapClaims); ok {
return v
}

Expand All @@ -1266,7 +1266,7 @@ func TestCheckTokenString(t *testing.T) {

r := gofight.New()

userToken, _, _ := authMiddleware.TokenGenerator(MapClaims{
userToken, _, _ := authMiddleware.TokenGenerator(jwt.MapClaims{
"identity": "admin",
})

Expand Down Expand Up @@ -1295,7 +1295,7 @@ func TestCheckTokenString(t *testing.T) {

_, err = authMiddleware.ParseTokenString(userToken)
assert.Error(t, err)
assert.Equal(t, MapClaims{}, ExtractClaimsFromToken(nil))
assert.Equal(t, jwt.MapClaims{}, ExtractClaimsFromToken(nil))
}

func TestLogout(t *testing.T) {
Expand Down

0 comments on commit 74e77e5

Please sign in to comment.