Skip to content

Commit

Permalink
feat: Enable customizing the PyJWT object
Browse files Browse the repository at this point in the history
  • Loading branch information
taion committed Apr 23, 2019
1 parent 0b63b31 commit 3d0df10
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 41 deletions.
2 changes: 1 addition & 1 deletion flask_resty/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@
from .view import ApiView, GenericModelView, ModelView

try:
from .jwt import JwtAuthentication, JwkSetAuthentication
from .jwt import JwtAuthentication, JwkSetAuthentication, JwkSetPyJwt
except ImportError:
pass
88 changes: 48 additions & 40 deletions flask_resty/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from cryptography.x509 import load_der_x509_certificate
import flask
import jwt
from jwt.algorithms import get_default_algorithms
from jwt.exceptions import InvalidAlgorithmError, InvalidTokenError
from jwt import InvalidAlgorithmError, InvalidTokenError, PyJWT

from .authentication import AuthenticationBase
from .exceptions import ApiError
Expand Down Expand Up @@ -78,7 +77,11 @@ def get_token_from_request(self):
return flask.request.args.get(self.id_token_arg)

def decode_token(self, token):
return jwt.decode(token, **self.get_jwt_decode_args())
return self.pyjwt.decode(token, **self.get_jwt_decode_args())

@property
def pyjwt(self):
return jwt

def get_jwt_decode_args(self):
config = flask.current_app.config
Expand All @@ -98,62 +101,67 @@ def get_credentials(self, payload):
return payload


class JwkSetAuthentication(JwtAuthentication):
def __init__(self, jwk_set=None, **kwargs):
super(JwkSetAuthentication, self).__init__(**kwargs)
class JwkSetPyJwt(PyJWT):
def __init__(self, jwk_set, *args, **kwargs):
super(JwkSetPyJwt, self).__init__(*args, **kwargs)

self.jwk_set = jwk_set
self.algorithms = get_default_algorithms()

def get_jwk_set(self):
config = flask.current_app.config
return (
self.jwk_set if self.jwk_set
else config[self.get_config_key('jwk_set')]
)
def decode(self, jwt, **kwargs):
unverified_header = self.get_unverified_header(jwt)

def get_key_from_jwk(self, jwk, algorithm):
if 'x5c' in jwk:
return load_der_x509_certificate(
base64.b64decode(jwk['x5c'][0]),
default_backend(),
).public_key()
jwk = self.get_jwk_from_jwt(unverified_header)

# awkward
return algorithm.from_jwk(json.dumps(jwk))
# It's safe to use alg from the header here, as we verify that against
# the algorithm whitelist.
alg = jwk['alg'] if 'alg' in jwk else unverified_header['alg']

# jwt.decode will also check this, but this is more defensive.
if alg not in kwargs['algorithms']:
raise InvalidAlgorithmError(
"The specified alg value is not allowed",
)

def get_jwk_for_token(self, token):
unverified_header = jwt.get_unverified_header(token)
return super(JwkSetPyJwt, self).decode(
jwt,
key=self.get_key_from_jwk(jwk, alg),
**kwargs
)

def get_jwk_from_jwt(self, unverified_header):
try:
token_kid = unverified_header['kid']
except KeyError:
raise InvalidTokenError("Key ID header parameter is missing")

for jwk in self.get_jwk_set()['keys']:
for jwk in self.jwk_set['keys']:
if jwk['kid'] == token_kid:
return jwk

raise InvalidTokenError("no key found")

def decode_token(self, token):
args = self.get_jwt_decode_args()
def get_key_from_jwk(self, jwk, alg):
if 'x5c' in jwk:
return load_der_x509_certificate(
base64.b64decode(jwk['x5c'][0]),
default_backend(),
).public_key()

unverified_header = jwt.get_unverified_header(token)
jwk = self.get_jwk_for_token(token)
algorithm = self._algorithms[alg]

# It's safe to use alg from the header here, as we verify that against
# the algorithm whitelist.
alg = jwk['alg'] if 'alg' in jwk else unverified_header['alg']
# Awkward:
return algorithm.from_jwk(json.dumps(jwk))

# jwt.decode will also check this, but this is more defensive.
if alg not in args['algorithms']:
raise InvalidAlgorithmError(
"The specified alg value is not allowed",
)

return jwt.decode(
token,
key=self.get_key_from_jwk(jwk, self.algorithms[alg]),
**args
class JwkSetAuthentication(JwtAuthentication):
def __init__(self, jwk_set=None, **kwargs):
super(JwkSetAuthentication, self).__init__(**kwargs)

self.jwk_set = jwk_set

@property
def pyjwt(self):
return JwkSetPyJwt(
self.jwk_set or
flask.current_app.config[self.get_config_key('jwk_set')]
)

0 comments on commit 3d0df10

Please sign in to comment.