Skip to content

Commit

Permalink
feat(core): new base middleware (#3996)
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut authored Feb 25, 2025
1 parent 27a5b1d commit 6054667
Show file tree
Hide file tree
Showing 11 changed files with 489 additions and 13 deletions.
17 changes: 17 additions & 0 deletions docs/examples/middleware/abstract_middleware_migration_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import anyio

from litestar import Litestar
from litestar.middleware import ASGIMiddleware
from litestar.types import ASGIApp, Receive, Scope, Send


class TimeoutMiddleware(ASGIMiddleware):
def __init__(self, timeout: float):
self.timeout = timeout

async def handle(self, scope: Scope, receive: Receive, send: Send, next_app: ASGIApp) -> None:
with anyio.move_on_after(self.timeout):
await next_app(scope, receive, send)


app = Litestar(middleware=[TimeoutMiddleware(timeout=5)])
32 changes: 32 additions & 0 deletions docs/examples/middleware/abstract_middleware_migration_old.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import anyio

from litestar import Litestar
from litestar.middleware import AbstractMiddleware, DefineMiddleware
from litestar.types import ASGIApp, Receive, Scope, Scopes, Send


class TimeoutMiddleware(AbstractMiddleware):
def __init__(
self,
app: ASGIApp,
timeout: float,
exclude: str | list[str] | None = None,
exclude_opt_key: str | None = None,
scopes: Scopes | None = None,
):
self.timeout = timeout
super().__init__(app=app, exclude=exclude, exclude_opt_key=exclude_opt_key, scopes=scopes)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
with anyio.move_on_after(self.timeout):
await self.app(scope, receive, send)


app = Litestar(
middleware=[
DefineMiddleware(
TimeoutMiddleware,
timeout=5,
)
]
)
8 changes: 8 additions & 0 deletions docs/examples/middleware/middleware_protocol_migration_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from litestar.middleware import ASGIMiddleware
from litestar.types import ASGIApp, Receive, Scope, Send


class MyMiddleware(ASGIMiddleware):
async def handle(self, scope: Scope, receive: Receive, send: Send, next_app: ASGIApp) -> None:
# do stuff
await next_app(scope, receive, send)
11 changes: 11 additions & 0 deletions docs/examples/middleware/middleware_protocol_migration_old.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from litestar.middleware import MiddlewareProtocol
from litestar.types import ASGIApp, Receive, Scope, Send


class MyMiddleware(MiddlewareProtocol):
def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# do stuff
await self.app(scope, receive, send)
22 changes: 22 additions & 0 deletions docs/examples/middleware/request_timing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import time

from litestar.datastructures import MutableScopeHeaders
from litestar.enums import ScopeType
from litestar.middleware import ASGIMiddleware
from litestar.types import ASGIApp, Message, Receive, Scope, Send


class ProcessTimeHeader(ASGIMiddleware):
scopes = (ScopeType.HTTP, ScopeType.ASGI)

async def handle(self, scope: Scope, receive: Receive, send: Send, next_app: ASGIApp) -> None:
start_time = time.monotonic()

async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
process_time = time.monotonic() - start_time
headers = MutableScopeHeaders.from_message(message=message)
headers["X-Process-Time"] = str(process_time)
await send(message)

await next_app(scope, receive, send_wrapper)
62 changes: 62 additions & 0 deletions docs/examples/middleware/using_asgi_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import anyio

from litestar import Litestar, get
from litestar.enums import ScopeType
from litestar.exceptions import ClientException
from litestar.middleware import ASGIMiddleware
from litestar.types import ASGIApp, Receive, Scope, Send


class TimeoutMiddleware(ASGIMiddleware):
# we can configure some things on the class level here, related to when our
# middleware should be applied.

# if the requests' 'scope["type"]' is not "http", the middleware will be skipped
scopes = (ScopeType.HTTP,)

# if the handler for a request has set an opt of 'no_timeout=True', the middleware
# will be skipped
exclude_opt_key = "no_timeout"

# the base class does not define an '__init__' method, so we're free to overwrite
# this, which we're making use of to add some configuration
def __init__(
self,
timeout: float,
exclude_path_pattern: str | tuple[str, ...] | None = None,
) -> None:
self.timeout = timeout

# we can also dynamically configure the options provided by the base class on
# the instance level
self.exclude_path_pattern = exclude_path_pattern

async def handle(self, scope: Scope, receive: Receive, send: Send, next_app: ASGIApp) -> None:
try:
with anyio.fail_after(self.timeout):
# call the next app in the chain
await next_app(scope, receive, send)
except TimeoutError:
# if the request has timed out, raise an exception. since the whole
# application is wrapped in an exception handling middleware, it will
# transform this exception into a response for us
raise ClientException(status_code=408) from None


@get("/", no_timeout=True)
async def handler_with_opt_skip() -> None:
pass


@get("/not-this-path")
async def handler_with_path_skip() -> None:
pass


app = Litestar(
route_handlers=[
handler_with_opt_skip,
handler_with_path_skip,
],
middleware=[TimeoutMiddleware(timeout=5)],
)
76 changes: 70 additions & 6 deletions docs/usage/middleware/creating-middleware.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,74 @@ The example previously given was using a factory function, i.e.:
return my_middleware
While using functions is a perfectly viable approach, you can also use classes to do the same. See the next sections on
two base classes you can use for this purpose - the :class:`~litestar.middleware.base.MiddlewareProtocol` ,
which gives a bare-bones type, or the :class:`~litestar.middleware.base.AbstractMiddleware` that offers a
base class with some built in functionality.
Extending ``ASGIMiddleware``
----------------------------

While using functions is a perfectly viable approach, the recommended way to handle this
is by using the :class:`~litestar.middleware.ASGIMiddleware` abstract base class, which
also includes functionality to dynamically skip the middleware based on ASGI
``scope["type"]``, handler ``opt`` keys or path patterns and a simple way to pass
configuration to middlewares; It does not implement an ``__init__`` method, so
subclasses are free to use it to customize the middleware's configuration.


Modifying Requests and Responses
++++++++++++++++++++++++++++++++

Middlewares can not only be used to execute *around* other ASGI callable, they can also
intercept and modify both incoming and outgoing data in a request / response cycle by
"wrapping" the respective ``receive`` and ``send`` ASGI callables.

The following demonstrates how to add a request timing header with a timestamp to all
outgoing responses:

.. literalinclude:: /examples/middleware/request_timing.py
:language: python



Migrating to ``ASGIMiddleware`` from ``MiddlewareProtocol`` / ``AbstractMiddleware``
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

:class:`~litestar.middleware.ASGIMiddleware` was introduced in Litestar 2.15. If you've
been using ``MiddlewareProtocol`` / ``AbstractMiddleware`` to implement your middlewares
before, there's a simple migration path to using ``ASGIMiddleware``.

**Migrating from ``MiddlewareProtocol``**

.. tab-set::

.. tab-item:: ``MiddlewareProtocol``

.. literalinclude:: /examples/middleware/middleware_protocol_migration_old.py
:language: python

.. tab-item:: ``ASGIMiddleware``

.. literalinclude:: /examples/middleware/middleware_protocol_migration_new.py
:language: python



**Migrating from ``AbstractMiddleware``**

.. tab-set::

.. tab-item:: ``MiddlewareProtocol``

.. literalinclude:: /examples/middleware/abstract_middleware_migration_old.py
:language: python

.. tab-item:: ``ASGIMiddleware``

.. literalinclude:: /examples/middleware/abstract_middleware_migration_new.py
:language: python






Using MiddlewareProtocol
------------------------
Expand Down Expand Up @@ -85,7 +149,7 @@ specifies:


Responding using the MiddlewareProtocol
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+++++++++++++++++++++++++++++++++++++++

Once a middleware finishes doing whatever its doing, it should pass ``scope``, ``receive``, and ``send`` to an ASGI app
and await it. This is what's happening in the above example with: ``await self.app(scope, receive, send)``. Let's
Expand Down Expand Up @@ -115,7 +179,7 @@ As you can see in the above, given some condition (``request.session`` being ``N
:class:`~litestar.response.redirect.ASGIRedirectResponse` and then await it. Otherwise, we await ``self.app``

Modifying ASGI Requests and Responses using the MiddlewareProtocol
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

.. important::

Expand Down
2 changes: 2 additions & 0 deletions litestar/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
)
from litestar.middleware.base import (
AbstractMiddleware,
ASGIMiddleware,
DefineMiddleware,
MiddlewareProtocol,
)

__all__ = (
"ASGIMiddleware",
"AbstractAuthenticationMiddleware",
"AbstractMiddleware",
"AuthenticationResult",
Expand Down
6 changes: 3 additions & 3 deletions litestar/middleware/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Pattern, Sequence
from typing import TYPE_CHECKING, Iterable, Pattern, Sequence

from litestar.exceptions import ImproperlyConfiguredException

Expand All @@ -15,7 +15,7 @@

def build_exclude_path_pattern(
*,
exclude: str | list[str] | None = None,
exclude: str | Iterable[str] | None = None,
middleware_cls: type | None = None,
) -> Pattern | None:
"""Build single path pattern from list of patterns to opt-out from middleware processing.
Expand All @@ -32,7 +32,7 @@ def build_exclude_path_pattern(
return None

try:
pattern = re.compile("|".join(exclude)) if isinstance(exclude, list) else re.compile(exclude)
pattern = re.compile("|".join(exclude)) if not isinstance(exclude, str) else re.compile(exclude)
if pattern.match("/") and pattern.match("/982c7064-6ac7-44b7-9be5-07a2ff6d8a92"):
# match a UUID to ensure that it matches paths greedily and not just a literal /
warn_middleware_excluded_on_all_routes(pattern, middleware_cls=middleware_cls)
Expand Down
Loading

0 comments on commit 6054667

Please sign in to comment.