diff --git a/flask_resty/__init__.py b/flask_resty/__init__.py index c8662f9..53439b5 100644 --- a/flask_resty/__init__.py +++ b/flask_resty/__init__.py @@ -33,6 +33,6 @@ from .view import ApiView, GenericModelView, ModelView try: - from .jwt import JwtAuthentication, JwkSetAuthentication + from .jwt import JwtAuthentication, JwkSetAuthentication, JwkSetPyJwt except ImportError: pass diff --git a/flask_resty/jwt.py b/flask_resty/jwt.py index ba66246..b676eba 100644 --- a/flask_resty/jwt.py +++ b/flask_resty/jwt.py @@ -7,8 +7,7 @@ from cryptography.x509 import load_der_x509_certificate import flask import jwt -from jwt.algorithms import get_default_algorithms -from jwt.exceptions import InvalidAlgorithmError, InvalidTokenError +from jwt import InvalidAlgorithmError, InvalidTokenError, PyJWT from .authentication import AuthenticationBase from .exceptions import ApiError @@ -78,7 +77,11 @@ def get_token_from_request(self): return flask.request.args.get(self.id_token_arg) def decode_token(self, token): - return jwt.decode(token, **self.get_jwt_decode_args()) + return self.pyjwt.decode(token, **self.get_jwt_decode_args()) + + @property + def pyjwt(self): + return jwt def get_jwt_decode_args(self): config = flask.current_app.config @@ -98,62 +101,67 @@ def get_credentials(self, payload): return payload -class JwkSetAuthentication(JwtAuthentication): - def __init__(self, jwk_set=None, **kwargs): - super(JwkSetAuthentication, self).__init__(**kwargs) +class JwkSetPyJwt(PyJWT): + def __init__(self, jwk_set, *args, **kwargs): + super(JwkSetPyJwt, self).__init__(*args, **kwargs) self.jwk_set = jwk_set - self.algorithms = get_default_algorithms() - def get_jwk_set(self): - config = flask.current_app.config - return ( - self.jwk_set if self.jwk_set - else config[self.get_config_key('jwk_set')] - ) + def decode(self, jwt, **kwargs): + unverified_header = self.get_unverified_header(jwt) - def get_key_from_jwk(self, jwk, algorithm): - if 'x5c' in jwk: - return load_der_x509_certificate( - base64.b64decode(jwk['x5c'][0]), - default_backend(), - ).public_key() + jwk = self.get_jwk_from_jwt(unverified_header) - # awkward - return algorithm.from_jwk(json.dumps(jwk)) + # It's safe to use alg from the header here, as we verify that against + # the algorithm whitelist. + alg = jwk['alg'] if 'alg' in jwk else unverified_header['alg'] + + # jwt.decode will also check this, but this is more defensive. + if alg not in kwargs['algorithms']: + raise InvalidAlgorithmError( + "The specified alg value is not allowed", + ) - def get_jwk_for_token(self, token): - unverified_header = jwt.get_unverified_header(token) + return super(JwkSetPyJwt, self).decode( + jwt, + key=self.get_key_from_jwk(jwk, alg), + **kwargs + ) + def get_jwk_from_jwt(self, unverified_header): try: token_kid = unverified_header['kid'] except KeyError: raise InvalidTokenError("Key ID header parameter is missing") - for jwk in self.get_jwk_set()['keys']: + for jwk in self.jwk_set['keys']: if jwk['kid'] == token_kid: return jwk raise InvalidTokenError("no key found") - def decode_token(self, token): - args = self.get_jwt_decode_args() + def get_key_from_jwk(self, jwk, alg): + if 'x5c' in jwk: + return load_der_x509_certificate( + base64.b64decode(jwk['x5c'][0]), + default_backend(), + ).public_key() - unverified_header = jwt.get_unverified_header(token) - jwk = self.get_jwk_for_token(token) + algorithm = self._algorithms[alg] - # It's safe to use alg from the header here, as we verify that against - # the algorithm whitelist. - alg = jwk['alg'] if 'alg' in jwk else unverified_header['alg'] + # Awkward: + return algorithm.from_jwk(json.dumps(jwk)) - # jwt.decode will also check this, but this is more defensive. - if alg not in args['algorithms']: - raise InvalidAlgorithmError( - "The specified alg value is not allowed", - ) - return jwt.decode( - token, - key=self.get_key_from_jwk(jwk, self.algorithms[alg]), - **args +class JwkSetAuthentication(JwtAuthentication): + def __init__(self, jwk_set=None, **kwargs): + super(JwkSetAuthentication, self).__init__(**kwargs) + + self.jwk_set = jwk_set + + @property + def pyjwt(self): + return JwkSetPyJwt( + self.jwk_set or + flask.current_app.config[self.get_config_key('jwk_set')] )