Skip to content

Commit

Permalink
Merge pull request #70 from geoadmin/develop
Browse files Browse the repository at this point in the history
New Release v4.1.0 - #minor
  • Loading branch information
ltshb authored Oct 28, 2024
2 parents b99fd63 + f6a2f59 commit 183a2b3
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .env.default
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ AWS_SECRET_ACCESS_KEY=dummy123
AWS_ENDPOINT_URL=http://localhost:8080
AWS_DEFAULT_REGION=eu-central-1
AWS_DYNAMODB_TABLE_NAME=test-db
ALLOWED_DOMAINS=.*localhost((:[0-9]*)?|\/)?,.*admin\.ch,.*bgdi\.ch
ALLOWED_DOMAINS=localhost,.*\.geo\.admin\.ch,.*\.bgdi\.ch
STAGING=local
2 changes: 1 addition & 1 deletion .env.testing
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ALLOWED_DOMAINS=.*\.geo\.admin\.ch,.*\.bgdi\.ch,http://localhost((:[0-9]*)?|\/)?
ALLOWED_DOMAINS=localhost,.*\.geo\.admin\.ch,.*\.bgdi\.ch
AWS_ACCESS_KEY_ID=testing
AWS_SECRET_ACCESS_KEY=testing
AWS_SECURITY_TOKEN=testing
Expand Down
6 changes: 1 addition & 5 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

from app.helpers.utils import get_redirect_param
from app.helpers.utils import get_registered_method
from app.helpers.utils import is_domain_allowed
from app.helpers.utils import make_error_msg
from app.settings import ALLOWED_DOMAINS_PATTERN
from app.settings import CACHE_CONTROL
from app.settings import CACHE_CONTROL_4XX

Expand All @@ -25,10 +25,6 @@
app.config.from_mapping({"TRAP_HTTP_EXCEPTIONS": True})


def is_domain_allowed(domain):
return re.fullmatch(ALLOWED_DOMAINS_PATTERN, domain) is not None


@app.before_request
# Add quick log of the routes used to all request.
# Important: this should be the first before_request method, to ensure
Expand Down
17 changes: 15 additions & 2 deletions app/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,12 @@ def get_url():
f"The url given as parameter was too long. (limit is 2046 "
f"characters, {len(url)} given)"
)
if not re.fullmatch(ALLOWED_DOMAINS_PATTERN, urlparse(url).netloc):
logger.error('URL(%s) given as a parameter is not allowed', url)
if not is_domain_allowed(url):
logger.error(
'URL(%s) given as a parameter is not allowed, test pattern %s',
url,
ALLOWED_DOMAINS_PATTERN
)
abort(400, 'URL given as a parameter is not allowed.')

return url
Expand All @@ -132,3 +136,12 @@ def strtobool(value) -> bool:
if value in ('n', 'no', 'f', 'false', 'off', '0'):
return False
raise ValueError(f"invalid truth value \'{value}\'")


def is_domain_allowed(url):
"""Check if the url contain a domain that is allowed
"""
domain = urlparse(url).hostname
if domain:
return re.fullmatch(ALLOWED_DOMAINS_PATTERN, domain) is not None
return False
24 changes: 13 additions & 11 deletions tests/unit_tests/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import re
import unittest
from urllib.parse import urlparse

import boto3

Expand Down Expand Up @@ -83,18 +84,19 @@ def setUp(self):
def tearDown(self):
self.table.delete()

def assertCors(
self,
response,
expected_allowed_methods,
origin_pattern=ALLOWED_DOMAINS_PATTERN
): # pylint: disable=invalid-name
def assertCors(self, response, expected_allowed_methods, all_origin=False): # pylint: disable=invalid-name
self.assertIn('Access-Control-Allow-Origin', response.headers)
self.assertIsNotNone(
re.fullmatch(origin_pattern, response.headers['Access-Control-Allow-Origin']),
msg=f"Access-Control-Allow-Origin={response.headers['Access-Control-Allow-Origin']}"
f" doesn't match {origin_pattern}"
)
if all_origin:
self.assertEqual(response.headers['Access-Control-Allow-Origin'], '*')
else:
allow_origin_domain = urlparse(response.headers['Access-Control-Allow-Origin']).hostname
self.assertIsNotNone(
re.fullmatch(
ALLOWED_DOMAINS_PATTERN, allow_origin_domain if allow_origin_domain else ''
),
msg=f"Access-Control-Allow-Origin={response.headers['Access-Control-Allow-Origin']}"
f" doesn't match {ALLOWED_DOMAINS_PATTERN}"
)
self.assertIn('Access-Control-Allow-Methods', response.headers)
self.assertListEqual(
sorted(expected_allowed_methods),
Expand Down
54 changes: 29 additions & 25 deletions tests/unit_tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TestRoutes(BaseShortlinkTestCase):

def test_checker_ok(self):
# checker
response = self.app.get(url_for('checker'), headers={"Origin": "map.geo.admin.ch"})
response = self.app.get(url_for('checker'), headers={"Origin": "https://map.geo.admin.ch"})
self.assertEqual(response.status_code, 200)
self.assertNotIn('Cache-Control', response.headers)
self.assertEqual(response.content_type, "application/json; charset=utf-8")
Expand All @@ -27,7 +27,9 @@ def test_checker_ok(self):
def test_create_shortlink_ok(self):
url = "https://map.geo.admin.ch/#/map?lang=en&center=2647850.83,1120124.2&z=1.812&bgLayer=ch.swisstopo.pixelkarte-farbe&top" # pylint: disable=line-too-long
response = self.app.post(
url_for('create_shortlink'), json={"url": url}, headers={"Origin": "map.geo.admin.ch"}
url_for('create_shortlink'),
json={"url": url},
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 201)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -49,7 +51,9 @@ def test_create_shortlink_ok(self):
)
# Check that second call returns 200 and the same short url
response = self.app.post(
url_for('create_shortlink'), json={"url": url}, headers={"Origin": "map.geo.admin.ch"}
url_for('create_shortlink'),
json={"url": url},
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 200)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -59,7 +63,7 @@ def test_create_shortlink_ok(self):

def test_create_shortlink_no_json(self):
response = self.app.post(
url_for('create_shortlink'), headers={"Origin": "map.geo.admin.ch"}
url_for('create_shortlink'), headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(415, response.status_code)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -77,7 +81,7 @@ def test_create_shortlink_no_json(self):

def test_create_shortlink_no_url(self):
response = self.app.post(
url_for('create_shortlink'), json={}, headers={"Origin": "map.geo.admin.ch"}
url_for('create_shortlink'), json={}, headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(400, response.status_code)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -97,7 +101,7 @@ def test_create_shortlink_no_hostname(self):
response = self.app.post(
url_for('create_shortlink'),
json={"url": f"{wrong_url}"},
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 400)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -116,7 +120,7 @@ def test_create_shortlink_non_allowed_hostname(self):
response = self.app.post(
url_for('create_shortlink'),
json={"url": "https://non-allowed.hostname.ch/test"},
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 400)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -135,7 +139,7 @@ def test_create_shortlink_non_allowed_hostname_containing_admin_address(self):
response = self.app.post(
url_for('create_shortlink'),
json={"url": "https://map.geo.admin.ch.non-allowed.hostname.ch/test"},
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 400)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -156,7 +160,7 @@ def test_create_shortlink_url_too_long(self):
url_for('create_shortlink'),
json={"url": url},
content_type="application/json",
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 400)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -178,7 +182,7 @@ def test_redirect_shortlink_ok(self):
for short_id, url in self.uuid_to_url_dict.items():
response = self.app.get(url_for('get_shortlink', shortlink_id=short_id))
self.assertEqual(response.status_code, 301)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$")
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True)
self.assertIn('Cache-Control', response.headers)
self.assertIn('max-age=', response.headers['Cache-Control'])
self.assertEqual(response.content_type, "text/html; charset=utf-8")
Expand All @@ -192,7 +196,7 @@ def test_redirect_shortlink_ok_with_query(self):
headers={"Origin": "www.example.com"}
)
self.assertEqual(response.status_code, 301)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$")
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True)
self.assertIn('Cache-Control', response.headers)
self.assertIn('max-age=', response.headers['Cache-Control'])
self.assertEqual(response.content_type, "text/html; charset=utf-8")
Expand All @@ -204,7 +208,7 @@ def test_shortlink_fetch_nok_invalid_redirect_parameter(self):
url_for('get_shortlink', shortlink_id=short_id),
query_string={'redirect': 'banana'},
content_type="text/html",
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
expected_json = {
'success': False,
Expand All @@ -226,7 +230,7 @@ def test_shortlink_fetch_nok_invalid_redirect_parameter(self):
def test_redirect_shortlink_url_not_found(self):
response = self.app.get(
url_for('get_shortlink', shortlink_id='nonexistent'),
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
expected_json = {
'success': False,
Expand All @@ -235,7 +239,7 @@ def test_redirect_shortlink_url_not_found(self):
}
}
self.assertEqual(response.status_code, 404)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$")
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True)
self.assertIn('Cache-Control', response.headers)
self.assertIn('max-age=3600', response.headers['Cache-Control'])
self.assertIn('application/json', response.content_type)
Expand All @@ -246,7 +250,7 @@ def test_fetch_full_url_from_shortlink_ok(self):
response = self.app.get(
url_for('get_shortlink', shortlink_id=short_id),
query_string={'redirect': 'false'},
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 200)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'])
Expand All @@ -262,7 +266,7 @@ def test_fetch_full_url_from_shortlink_ok_explicit_parameter(self):
response = self.app.get(
url_for('get_shortlink', shortlink_id=short_id),
query_string={'redirect': 'false'},
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 200)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'])
Expand All @@ -277,7 +281,7 @@ def test_fetch_full_url_from_shortlink_url_not_found(self):
response = self.app.get(
url_for('get_shortlink', shortlink_id='nonexistent'),
query_string={'redirect': 'false'},
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 404)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'])
Expand Down Expand Up @@ -325,12 +329,12 @@ def test_create_shortlink_origin_not_allowed(self, headers):
)

@params(
{'Origin': 'map.geo.admin.ch'},
{'Origin': 'https://map.geo.admin.ch'},
{
'Origin': 'map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site'
'Origin': 'https://map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site'
},
{
'Origin': 's.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin'
'Origin': 'https://s.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin'
},
{
'Origin': 'http://localhost', 'Sec-Fetch-Site': 'cross-site'
Expand Down Expand Up @@ -389,19 +393,19 @@ def test_get_shortlink_redirect_origin_allowed(self, headers):
headers=headers
)
self.assertEqual(response.status_code, 301)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$")
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True)

response = self.app.get(url_for('get_shortlink', shortlink_id=short_id), headers=headers)
self.assertEqual(response.status_code, 301)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$")
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True)

@params(
{'Origin': 'map.geo.admin.ch'},
{'Origin': 'https://map.geo.admin.ch'},
{
'Origin': 'map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site'
'Origin': 'https://map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site'
},
{
'Origin': 's.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin'
'Origin': 'https://s.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin'
},
{
'Origin': 'http://localhost', 'Sec-Fetch-Site': 'cross-site'
Expand Down

0 comments on commit 183a2b3

Please sign in to comment.