diff --git a/docs/authentication.md b/docs/authentication.md index d5970aa4f..8c2a895a6 100644 --- a/docs/authentication.md +++ b/docs/authentication.md @@ -131,6 +131,29 @@ async def dashboard(request): ... ``` +When redirecting users, the page you redirect them to will include URL they originally requested at the `next` query param: + +```python +from starlette.authentication import requires +from starlette.responses import RedirectResponse + + +@requires('authenticated', redirect='login') +async def admin(request): + ... + + +async def login(request): + if request.method == "POST": + # Now that the user is authenticated, + # we can send them to their original request destination + if request.user.is_authenticated: + next_url = request.query_params.get("next") + if next_url: + return RedirectResponse(next_url) + return RedirectResponse("/") +``` + For class-based endpoints, you should wrap the decorator around a method on the class. diff --git a/starlette/authentication.py b/starlette/authentication.py index 1fc6acaeb..92064b23b 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -2,6 +2,7 @@ import functools import inspect import typing +from urllib.parse import urlencode from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection, Request @@ -62,7 +63,12 @@ async def async_wrapper( if not has_required_scope(request, scopes_list): if redirect is not None: - return RedirectResponse(url=request.url_for(redirect)) + orig_request_qparam = urlencode({"next": str(request.url)}) + next_url = "{redirect_path}?{orig_request}".format( + redirect_path=request.url_for(redirect), + orig_request=orig_request_qparam, + ) + return RedirectResponse(url=next_url) raise HTTPException(status_code=status_code) return await func(*args, **kwargs) @@ -77,7 +83,12 @@ def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response: if not has_required_scope(request, scopes_list): if redirect is not None: - return RedirectResponse(url=request.url_for(redirect)) + orig_request_qparam = urlencode({"next": str(request.url)}) + next_url = "{redirect_path}?{orig_request}".format( + redirect_path=request.url_for(redirect), + orig_request=orig_request_qparam, + ) + return RedirectResponse(url=next_url) raise HTTPException(status_code=status_code) return func(*args, **kwargs) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 372ea81d8..becb8a54b 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,5 +1,6 @@ import base64 import binascii +from urllib.parse import urlencode import pytest @@ -183,7 +184,10 @@ def test_authentication_redirect(): with TestClient(app) as client: response = client.get("/admin") assert response.status_code == 200 - assert response.url == "http://testserver/" + url = "{}?{}".format( + "http://testserver/", urlencode({"next": "http://testserver/admin"}) + ) + assert response.url == url response = client.get("/admin", auth=("tomchristie", "example")) assert response.status_code == 200 @@ -191,7 +195,10 @@ def test_authentication_redirect(): response = client.get("/admin/sync") assert response.status_code == 200 - assert response.url == "http://testserver/" + url = "{}?{}".format( + "http://testserver/", urlencode({"next": "http://testserver/admin/sync"}) + ) + assert response.url == url response = client.get("/admin/sync", auth=("tomchristie", "example")) assert response.status_code == 200