Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow staticfiles to follow symlinks outside directory #1377

Merged
merged 17 commits into from
May 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions starlette/staticfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import stat
import typing
from email.utils import parsedate
from pathlib import Path

import anyio

Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(
self.all_directories = self.get_directories(directory, packages)
self.html = html
self.config_checked = False
if check_dir and directory is not None and not os.path.isdir(directory):
if check_dir and directory is not None and not Path(directory).is_dir():
raise RuntimeError(f"Directory '{directory}' does not exist")

def get_directories(
Expand All @@ -77,11 +78,9 @@ def get_directories(
spec = importlib.util.find_spec(package)
assert spec is not None, f"Package {package!r} could not be found."
assert spec.origin is not None, f"Package {package!r} could not be found."
package_directory = os.path.normpath(
os.path.join(spec.origin, "..", statics_dir)
)
assert os.path.isdir(
package_directory
package_directory = Path(spec.origin).joinpath("..", statics_dir).resolve()
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
assert (
package_directory.is_dir()
), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
directories.append(package_directory)

Expand All @@ -101,14 +100,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
response = await self.get_response(path, scope)
await response(scope, receive, send)

def get_path(self, scope: Scope) -> str:
def get_path(self, scope: Scope) -> Path:
"""
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
return os.path.normpath(os.path.join(*scope["path"].split("/")))
return Path(*scope["path"].split("/"))

async def get_response(self, path: str, scope: Scope) -> Response:
async def get_response(self, path: Path, scope: Scope) -> Response:
"""
Returns an HTTP response, given the incoming path, method and request headers.
"""
Expand All @@ -131,7 +130,7 @@ async def get_response(self, path: str, scope: Scope) -> Response:
elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
# We're in HTML mode, and have got a directory URL.
# Check if we have 'index.html' file to serve.
index_path = os.path.join(path, "index.html")
index_path = path.joinpath("index.html")
full_path, stat_result = await anyio.to_thread.run_sync(
self.lookup_path, index_path
)
Expand All @@ -158,20 +157,25 @@ async def get_response(self, path: str, scope: Scope) -> Response:
raise HTTPException(status_code=404)

def lookup_path(
self, path: str
) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
self, path: Path
) -> typing.Tuple[Path, typing.Optional[os.stat_result]]:
for directory in self.all_directories:
full_path = os.path.realpath(os.path.join(directory, path))
directory = os.path.realpath(directory)
if os.path.commonprefix([full_path, directory]) != directory:
# Don't allow misbehaving clients to break out of the static files
# directory.
continue
original_path = Path(directory).joinpath(path)
full_path = original_path.resolve()
directory = Path(directory).resolve()
try:
return full_path, os.stat(full_path)
stat_result = os.lstat(original_path)
full_path.relative_to(directory)
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
return full_path, stat_result
except ValueError:
# Allow clients to break out of the static files directory
# if following symlinks.
if stat.S_ISLNK(stat_result.st_mode):
stat_result = os.lstat(full_path)
Kludex marked this conversation as resolved.
Show resolved Hide resolved
return full_path, stat_result
except (FileNotFoundError, NotADirectoryError):
continue
return "", None
return Path(), None

def file_response(
self,
Expand Down
29 changes: 27 additions & 2 deletions tests/test_staticfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir):
directory = os.path.join(tmpdir, "foo")
os.mkdir(directory)

path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file_path = os.path.join(tmpdir, "example.txt")
euri10 marked this conversation as resolved.
Show resolved Hide resolved
with open(file_path, "w") as file:
file.write("outside root dir")

app = StaticFiles(directory=directory)
Expand Down Expand Up @@ -441,3 +441,28 @@ def mock_timeout(*args, **kwargs):
response = client.get("/example.txt")
assert response.status_code == 500
assert response.text == "Internal Server Error"


def test_staticfiles_follows_symlinks_to_break_out_of_dir(
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
tmp_path: pathlib.Path, test_client_factory
):
statics_path = tmp_path.joinpath("statics")
statics_path.mkdir()

symlink_path = tmp_path.joinpath("symlink")
symlink_path.mkdir()

symlink_file_path = symlink_path.joinpath("index.html")
with open(symlink_file_path, "w") as file:
file.write("<h1>Hello</h1>")

statics_file_path = statics_path.joinpath("index.html")
statics_file_path.symlink_to(symlink_file_path)

app = StaticFiles(directory=statics_path)
client = test_client_factory(app)

response = client.get("/index.html")
assert response.url == "http://testserver/index.html"
assert response.status_code == 200
assert response.text == "<h1>Hello</h1>"