Skip to content

Commit

Permalink
Ensure exempt routes are exempt from meta limits
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaifee committed Apr 21, 2024
1 parent 7967e8e commit 3d7f20a
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 48 deletions.
2 changes: 2 additions & 0 deletions flask_limiter/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class ExemptionScope(enum.Flag):
#: Exempt from application wide "global" limits
APPLICATION = enum.auto()
#: Exempt from default limits configured on the extension
META = enum.auto()
#: Exempts from meta limits
DEFAULT = enum.auto()
#: Exempts any nested blueprints. See :ref:`recipes:nested blueprints`
DESCENDENTS = enum.auto()
Expand Down
88 changes: 48 additions & 40 deletions flask_limiter/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,35 +651,40 @@ def exempt(
self,
obj: flask.Blueprint,
*,
flags: ExemptionScope = ExemptionScope.APPLICATION | ExemptionScope.DEFAULT,
) -> flask.Blueprint:
...
flags: ExemptionScope = ExemptionScope.APPLICATION
| ExemptionScope.DEFAULT
| ExemptionScope.META,
) -> flask.Blueprint: ...

@overload
def exempt(
self,
obj: Callable[..., R],
*,
flags: ExemptionScope = ExemptionScope.APPLICATION | ExemptionScope.DEFAULT,
) -> Callable[..., R]:
...
flags: ExemptionScope = ExemptionScope.APPLICATION
| ExemptionScope.DEFAULT
| ExemptionScope.META,
) -> Callable[..., R]: ...

@overload
def exempt(
self,
*,
flags: ExemptionScope = ExemptionScope.APPLICATION | ExemptionScope.DEFAULT,
flags: ExemptionScope = ExemptionScope.APPLICATION
| ExemptionScope.DEFAULT
| ExemptionScope.META,
) -> Union[
Callable[[Callable[P, R]], Callable[P, R]],
Callable[[flask.Blueprint], flask.Blueprint],
]:
...
]: ...

def exempt(
self,
obj: Optional[Union[Callable[..., R], flask.Blueprint]] = None,
*,
flags: ExemptionScope = ExemptionScope.APPLICATION | ExemptionScope.DEFAULT,
flags: ExemptionScope = ExemptionScope.APPLICATION
| ExemptionScope.DEFAULT
| ExemptionScope.META,
) -> Union[
Callable[..., R],
flask.Blueprint,
Expand All @@ -692,7 +697,7 @@ def exempt(
:param obj: view function or blueprint to mark as exempt.
:param flags: Controls the scope of the exemption. By default
application wide limits and defaults configured on the extension
application wide limits, defaults configured on the extension and meta limits
are opted out of. Additional flags can be used to control the behavior
when :paramref:`obj` is a Blueprint that is nested under another Blueprint
or has other Blueprints nested under it (See :ref:`recipes:nested blueprints`)
Expand Down Expand Up @@ -1031,29 +1036,35 @@ def __evaluate_limits(self, endpoint: str, limits: List[Limit]) -> None:
limit_for_header: Optional[RequestLimit] = None
view_limits: List[RequestLimit] = []
meta_limits = list(itertools.chain(*self._meta_limits))
for lim in meta_limits:
limit_key, scope = lim.key_func(), lim.scope_for(endpoint, None)
args = [limit_key, scope]
if not self.limiter.test(lim.limit, *args):
breached_meta_limit = RequestLimit(
self, lim.limit, args, True, lim.shared
)
self.context.view_rate_limit = breached_meta_limit
self.context.view_rate_limits = [breached_meta_limit]
meta_breach_response = None
if self._on_meta_breach:
try:
cb_response = self._on_meta_breach(breached_meta_limit)
if isinstance(cb_response, flask.wrappers.Response):
meta_breach_response = cb_response
except Exception as err: # noqa
if self._swallow_errors:
self.logger.exception(
"on_meta_breach callback failed with error %s", err
)
else:
raise err
raise RateLimitExceeded(lim, response=meta_breach_response)
if not (
ExemptionScope.META
& self.limit_manager.exemption_scope(
flask.current_app, endpoint, flask.request.blueprint
)
):
for lim in meta_limits:
limit_key, scope = lim.key_func(), lim.scope_for(endpoint, None)
args = [limit_key, scope]
if not self.limiter.test(lim.limit, *args):
breached_meta_limit = RequestLimit(
self, lim.limit, args, True, lim.shared
)
self.context.view_rate_limit = breached_meta_limit
self.context.view_rate_limits = [breached_meta_limit]
meta_breach_response = None
if self._on_meta_breach:
try:
cb_response = self._on_meta_breach(breached_meta_limit)
if isinstance(cb_response, flask.wrappers.Response):
meta_breach_response = cb_response
except Exception as err: # noqa
if self._swallow_errors:
self.logger.exception(
"on_meta_breach callback failed with error %s", err
)
else:
raise err
raise RateLimitExceeded(lim, response=meta_breach_response)

for lim in sorted(limits, key=lambda x: x.limit):
if lim.is_exempt or lim.method_exempt:
Expand Down Expand Up @@ -1253,16 +1264,13 @@ def __exit__(
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
...
) -> None: ...

@overload
def __call__(self, obj: Callable[P, R]) -> Callable[P, R]:
...
def __call__(self, obj: Callable[P, R]) -> Callable[P, R]: ...

@overload
def __call__(self, obj: flask.Blueprint) -> None:
...
def __call__(self, obj: flask.Blueprint) -> None: ...

def __call__(
self, obj: Union[Callable[P, R], flask.Blueprint]
Expand Down
3 changes: 1 addition & 2 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,8 +705,7 @@ def default_on_breach_with_response(request_limit):
f"default custom response {request_limit.limit} @ {request.path}", 429
)

def on_breach_invalid():
...
def on_breach_invalid(): ...

def on_breach_fail(request_limit):
1 / 0
Expand Down
12 changes: 9 additions & 3 deletions tests/test_flask_ext.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
"""
""" """

import logging
import time
Expand Down Expand Up @@ -912,6 +910,11 @@ def meta_breach_cb(limit):
def root():
return "root"

@app.route("/exempt")
@limiter.exempt
def exempt():
return "exempt"

with hiro.Timeline().freeze() as timeline:
with app.test_client() as cli:
for _ in range(2):
Expand All @@ -922,17 +925,20 @@ def root():

# blocked because of max 2 breaches/minute
assert cli.get("/").status_code == 429
assert cli.get("/exempt").status_code == 200
timeline.forward(59)
assert cli.get("/").status_code == 200
assert cli.get("/").status_code == 200
assert cli.get("/").status_code == 429
assert cli.get("/exempt").status_code == 200
timeline.forward(59)
# blocked because of max 3 breaches/hour
response = cli.get("/")
assert response.text == "Would you like some tea?"
assert response.status_code == 429
assert response.headers.get("X-RateLimit-Limit") == "3"
assert response.headers.get("X-RateLimit-Remaining") == "0"
assert cli.get("/exempt").status_code == 200

# forward to 1 hour since start
timeline.forward(60 * 58)
Expand Down
4 changes: 1 addition & 3 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
"""
""" """

import time

Expand Down

0 comments on commit 3d7f20a

Please sign in to comment.