diff --git a/poetry.lock b/poetry.lock index c3d74356..335d741e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,14 +1,14 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "alembic" -version = "1.13.3" +version = "1.14.0" description = "A database migration tool for SQLAlchemy." optional = false python-versions = ">=3.8" files = [ - {file = "alembic-1.13.3-py3-none-any.whl", hash = "sha256:908e905976d15235fae59c9ac42c4c5b75cfcefe3d27c0fbf7ae15a37715d80e"}, - {file = "alembic-1.13.3.tar.gz", hash = "sha256:203503117415561e203aa14541740643a611f641517f0209fcae63e9fa09f1a2"}, + {file = "alembic-1.14.0-py3-none-any.whl", hash = "sha256:99bd884ca390466db5e27ffccff1d179ec5c05c965cfefc0607e69f9e411cb25"}, + {file = "alembic-1.14.0.tar.gz", hash = "sha256:b00892b53b3642d0b8dbedba234dbf1924b69be83a9a769d5a624b01094e304b"}, ] [package.dependencies] @@ -2394,6 +2394,25 @@ requests = ">=2.20" [package.extras] docs = ["furo (>=2023.3,<2024.0)", "myst-parser (>=1.0)", "sphinx (>=5.2,<6.0)", "sphinx-autodoc-typehints (>=1.22,<2.0)", "sphinx-copybutton (>=0.5)"] +[[package]] +name = "responses" +version = "0.25.3" +description = "A utility library for mocking out the `requests` Python library." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "responses-0.25.3-py3-none-any.whl", hash = "sha256:521efcbc82081ab8daa588e08f7e8a64ce79b91c39f6e62199b19159bea7dbcb"}, + {file = "responses-0.25.3.tar.gz", hash = "sha256:617b9247abd9ae28313d57a75880422d55ec63c29d33d629697590a034358dba"}, +] + +[package.dependencies] +pyyaml = "*" +requests = ">=2.30.0,<3.0" +urllib3 = ">=1.25.10,<3.0" + +[package.extras] +tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "tomli", "tomli-w", "types-PyYAML", "types-requests"] + [[package]] name = "rich" version = "13.9.4" @@ -3315,4 +3334,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "f3f7b9e1063da57c644d70d106ddd1b4d617335a2b062f8da37f150ba9432f7f" +content-hash = "2757aff75c37be8d41e01a73644d010cd2a08b75cd1438758d5bf052b3e205b8" diff --git a/pyproject.toml b/pyproject.toml index 1357a972..b933b70e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ pyright = "^1.1.352" pyperf = "^2.2.0" pytest = "^8.3.2" pytest-mock = "^3.14.0" +responses = "0.25.3" pyfakefs = "^5.4.1" ruff = "^0.7.2" isort = "^5.10.1" diff --git a/src/program/utils/request.py b/src/program/utils/request.py index df772d57..cb9c49ec 100644 --- a/src/program/utils/request.py +++ b/src/program/utils/request.py @@ -4,7 +4,7 @@ from typing import Dict, Type, Optional, Any from requests import Session from lxml import etree -from requests.exceptions import ConnectTimeout, RequestException +from requests.exceptions import ConnectTimeout, RequestException, HTTPError from requests.models import Response from requests_cache import CacheMixin, CachedSession from requests_ratelimiter import LimiterMixin, LimiterSession @@ -96,7 +96,7 @@ def __init__(self, session: Session, response_type: ResponseType = ResponseType. 
def _request(self, method: HttpMethod, endpoint: str, ignore_base_url: Optional[bool] = None, overriden_response_type: ResponseType = None, **kwargs) -> ResponseObject: """Generic request handler with error handling, using kwargs for flexibility.""" try: - url = f"{self.BASE_URL}/{endpoint}" if not ignore_base_url and self.BASE_URL else endpoint + url = f"{self.BASE_URL}/{endpoint}".rstrip('/') if not ignore_base_url and self.BASE_URL else endpoint # Add base parameters to kwargs if they exist request_params = self.BASE_REQUEST_PARAMS.to_dict() @@ -118,9 +118,16 @@ def _request(self, method: HttpMethod, endpoint: str, ignore_base_url: Optional[ logger.debug(f"ResponseObject: status_code={response_obj.status_code}, data={response_obj.data}") return response_obj - except RequestException as e: - logger.error(f"Request failed: {e}") - raise self.custom_exception(f"Request failed: {e}") from e + except HTTPError as e: + if e.response is not None and e.response.status_code == 429: + logger.warning(f"Rate limit hit: {e}") + raise RateLimitExceeded(f"Rate limit exceeded for {url}", response=e.response) from e + else: + logger.error(f"Request failed: {e}") + raise self.custom_exception(f"Request failed: {e}") from e + except RequestException as e: + logger.error(f"Request failed: {e}") + raise self.custom_exception(f"Request failed: {e}") from e class RateLimitExceeded(Exception): @@ -136,7 +140,8 @@ class CachedLimiterSession(CacheMixin, LimiterMixin, Session): def create_service_session( rate_limit_params: Optional[dict] = None, use_cache: bool = False, - cache_params: Optional[dict] = None + cache_params: Optional[dict] = None, + log_config: Optional[bool] = False, ) -> Session: """ Create a session for a specific service with optional caching and rate-limiting. @@ -144,23 +149,29 @@ :param rate_limit_params: Dictionary of rate-limiting parameters. :param use_cache: Boolean indicating if caching should be enabled. :param cache_params: Dictionary of caching parameters if caching is enabled. 
+ :param log_config: Boolean indicating if the session configuration should be logged. :return: Configured session for the service. """ if use_cache and not cache_params: raise ValueError("Cache parameters must be provided if use_cache is True.") if use_cache and cache_params: - if rate_limit_params: - return CachedLimiterSession(**rate_limit_params, **cache_params) - else: - return CachedSession(**cache_params) + if log_config: + logger.debug(f"Rate Limit Parameters: {rate_limit_params}") + logger.debug(f"Cache Parameters: {cache_params}") + session_class = CachedLimiterSession if rate_limit_params else CachedSession + return session_class(**(rate_limit_params or {}), **cache_params) if rate_limit_params: + if log_config: + logger.debug(f"Rate Limit Parameters: {rate_limit_params}") return LimiterSession(**rate_limit_params) return Session() + def get_rate_limit_params( + custom_limiter: Optional[Limiter] = None, per_second: Optional[int] = None, per_minute: Optional[int] = None, per_hour: Optional[int] = None, @@ -168,12 +179,15 @@ max_calls: Optional[int] = None, period: Optional[int] = None, db_name: Optional[str] = None, - use_memory_list: bool = False + use_memory_list: bool = False, + limit_statuses: Optional[list[int]] = None, + max_delay: Optional[int] = 0, ) -> Dict[str, any]: """ Generate rate limit parameters for a service. If `db_name` is not provided, use an in-memory bucket for rate limiting. + :param custom_limiter: Optional custom limiter to use for rate limiting. :param per_second: Requests per second limit. :param per_minute: Requests per minute limit. :param per_hour: Requests per hour limit. @@ -182,13 +196,14 @@ :param period: Time period in seconds for max_calls. :param db_name: Optional name for the SQLite database file for persistent rate limiting. :param use_memory_list: If true, use MemoryListBucket instead of MemoryQueueBucket for in-memory limiting. 
+ :param limit_statuses: Optional list of status codes to track for rate limiting. + :param max_delay: Optional maximum delay for rate limiting. :return: Dictionary with rate limit configuration. """ - # Choose the appropriate bucket type based on the presence of db_name + bucket_class = SQLiteBucket if db_name else (MemoryListBucket if use_memory_list else MemoryQueueBucket) bucket_kwargs = {"path": data_dir_path / f"{db_name}.db"} if db_name else {} - # Create a list of request rates based on provided limits rate_limits = [] if per_second: rate_limits.append(RequestRate(per_second, Duration.SECOND)) @@ -201,17 +216,17 @@ def get_rate_limit_params( if max_calls and period: rate_limits.append(RequestRate(max_calls, Duration.SECOND * period)) - # Raise an error if no limits are provided if not rate_limits: raise ValueError("At least one rate limit (per_second, per_minute, per_hour, calculated_rate, or max_calls and period) must be specified.") - # Initialize the limiter with all applicable rate limits - limiter = Limiter(*rate_limits) + limiter = custom_limiter or Limiter(*rate_limits, bucket_class=bucket_class, bucket_kwargs=bucket_kwargs) return { 'limiter': limiter, 'bucket_class': bucket_class, - 'bucket_kwargs': bucket_kwargs + 'bucket_kwargs': bucket_kwargs, + 'limit_statuses': limit_statuses or [429], + 'max_delay': max_delay, } diff --git a/src/tests/test_rate_limiting.py b/src/tests/test_rate_limiting.py new file mode 100644 index 00000000..8e1b0bf3 --- /dev/null +++ b/src/tests/test_rate_limiting.py @@ -0,0 +1,216 @@ +import time +import responses +from requests.exceptions import HTTPError +from program.utils.request import create_service_session, get_rate_limit_params, HttpMethod, BaseRequestHandler, ResponseType, RateLimitExceeded + +@responses.activate +def test_rate_limiter_with_base_request_handler(): + # Setup: Define the URL and rate-limiting parameters + url = "https://api.example.com/endpoint" + rate_limit_params = 
get_rate_limit_params(per_second=1) # 1 request per second as an example + session = create_service_session(rate_limit_params=rate_limit_params) + + # Initialize the BaseRequestHandler with the rate-limited session + request_handler = BaseRequestHandler(session=session, response_type=ResponseType.DICT, base_url=url) + + for _ in range(3): + responses.add(responses.GET, url, json={"message": "OK"}, status=200) + + for _ in range(5): + responses.add(responses.GET, url, json={"error": "Rate limit exceeded"}, status=429) + + success_count = 0 + rate_limited_count = 0 + + for i in range(8): + try: + # Use BaseRequestHandler's _request method + response_obj = request_handler._request(HttpMethod.GET, "") + print(f"Request {i + 1}: Status {response_obj.status_code} - Success") + success_count += 1 + except RateLimitExceeded as e: + print(f"Request {i + 1}: Rate limit hit - {e}") + rate_limited_count += 1 + except HTTPError as e: + print(f"Request {i + 1}: Failed with error - {e}") + time.sleep(0.1) # Interval shorter than rate limit threshold + + # Assertions + assert success_count == 3, "Expected 3 successful requests before rate limiting" + assert rate_limited_count == 5, "Expected 5 rate-limited requests after threshold exceeded" + + +@responses.activate +def test_successful_requests_within_limit(): + """Test that requests succeed if within the rate limit.""" + url = "https://api.example.com/endpoint" + rate_limit_params = get_rate_limit_params(per_second=2) # 2 requests per second + session = create_service_session(rate_limit_params=rate_limit_params) + request_handler = BaseRequestHandler(session=session, response_type=ResponseType.DICT, base_url=url) + + # Mock responses for the first 2 requests + responses.add(responses.GET, url, json={"message": "OK"}, status=200) + responses.add(responses.GET, url, json={"message": "OK"}, status=200) + + success_count = 0 + + for i in range(2): + response_obj = request_handler._request(HttpMethod.GET, "") + print(f"Request {i + 
1}: Status {response_obj.status_code} - Success") + success_count += 1 + + assert success_count == 2, "Expected both requests to succeed within the rate limit" + + +@responses.activate +def test_rate_limit_exceeded(): + """Test that requests are blocked after rate limit is reached.""" + url = "https://api.example.com/endpoint" + rate_limit_params = get_rate_limit_params(per_second=1) # 1 request per second + session = create_service_session(rate_limit_params=rate_limit_params) + request_handler = BaseRequestHandler(session=session, response_type=ResponseType.DICT, base_url=url) + + # First request is mocked as 200 OK, subsequent as 429 + responses.add(responses.GET, url, json={"message": "OK"}, status=200) + responses.add(responses.GET, url, json={"error": "Rate limit exceeded"}, status=429) + + # First request should succeed + success_count = 0 + rate_limited_count = 0 + + try: + response_obj = request_handler._request(HttpMethod.GET, "") + print(f"Request 1: Status {response_obj.status_code} - Success") + success_count += 1 + except RateLimitExceeded: + rate_limited_count += 1 + + # Second request should be rate-limited + try: + request_handler._request(HttpMethod.GET, "") + except RateLimitExceeded as e: + print("Request 2: Rate limit hit -", e) + rate_limited_count += 1 + + assert success_count == 1, "Expected the first request to succeed" + assert rate_limited_count == 1, "Expected the second request to be rate-limited" + + +@responses.activate +def test_rate_limit_reset(): + """Test that requests succeed after waiting for the rate limit to reset.""" + url = "https://api.example.com/endpoint" + rate_limit_params = get_rate_limit_params(per_second=1) # 1 request per second + session = create_service_session(rate_limit_params=rate_limit_params) + request_handler = BaseRequestHandler(session=session, response_type=ResponseType.DICT, base_url=url) + + # Mock the first request with 200 OK + responses.add(responses.GET, url, json={"message": "OK"}, status=200) + + # 
Mock the second request with 429 to simulate rate limit + responses.add(responses.GET, url, json={"error": "Rate limit exceeded"}, status=429) + + # Mock the third request after rate limit reset with 200 OK + responses.add(responses.GET, url, json={"message": "OK"}, status=200) + + success_count = 0 + rate_limited_count = 0 + + # First request should succeed + try: + response_obj = request_handler._request(HttpMethod.GET, "") + print(f"Request 1: Status {response_obj.status_code} - Success") + success_count += 1 + except RateLimitExceeded: + rate_limited_count += 1 + + # Second request immediately should be rate-limited + try: + request_handler._request(HttpMethod.GET, "") + except RateLimitExceeded as e: + print("Request 2: Rate limit hit -", e) + rate_limited_count += 1 + + # Wait for the rate limit to reset, then try again + time.sleep(1.1) + try: + response_obj = request_handler._request(HttpMethod.GET, "") + print(f"Request 3: Status {response_obj.status_code} - Success after reset") + success_count += 1 + except RateLimitExceeded: + rate_limited_count += 1 + + assert success_count == 2, "Expected two successful requests (first and after reset)" + assert rate_limited_count == 1, "Expected one rate-limited request (second request)" + + +def test_direct_rate_limiter(): + """Test the Limiter directly to confirm it enforces rate limiting.""" + from pyrate_limiter import Limiter, RequestRate, Duration + + rate_limits = [] + rate_limits.append(RequestRate(1, Duration.SECOND)) + rate_limits.append(RequestRate(60, Duration.MINUTE)) + limiter = Limiter(*rate_limits) # 1 request per second and 60 requests per minute + + success_count = 0 + rate_limited_count = 0 + + # First request should succeed + try: + limiter.try_acquire("test_key") + print("Request 1: Success") + success_count += 1 + except Exception as e: + print("Request 1: Rate limit hit") + rate_limited_count += 1 + + # Additional requests should be rate-limited + for i in range(4): + try: + 
limiter.try_acquire("test_key") + print(f"Request {i + 2}: Success") + success_count += 1 + except Exception as e: + print(f"Request {i + 2}: Rate limit hit") + rate_limited_count += 1 + time.sleep(0.2) # Short interval to exceed rate limit + + # Assertions + assert success_count == 1, "Expected only one successful request within the rate limit" + assert rate_limited_count >= 1, "Expected at least one rate-limited request after hitting the limit" + + +def test_limiter_session_with_basic_rate_limit(): + """Test a basic LimiterSession that enforces a rate limit of 5 requests per second.""" + rate_limit_params = get_rate_limit_params(per_second=1) + session = create_service_session(rate_limit_params=rate_limit_params) + start = time.time() + request_count = 20 + interval_limit = 5 + buffer_time = 0.8 + + # Store timestamps to analyze intervals + request_timestamps = [] + + # Send 20 requests, observing the time intervals to confirm rate limiting + for i in range(request_count): + response = session.get('https://httpbin.org/get') + current_time = time.time() + request_timestamps.append(current_time) + print(f'[t+{current_time - start:.2f}] Sent request {i + 1} - Status code: {response.status_code}') + + # Check time intervals every 5 requests to confirm rate limiting is applied + if (i + 1) % interval_limit == 0: + elapsed_time = request_timestamps[-1] - request_timestamps[-interval_limit] + assert elapsed_time >= 1 - buffer_time, ( + f"Rate limit exceeded: {interval_limit} requests in {elapsed_time:.2f} seconds" + ) + + # Final assertion to ensure all requests respected the rate limit + total_elapsed_time = request_timestamps[-1] - request_timestamps[0] + expected_min_time = (request_count / interval_limit) - buffer_time + assert total_elapsed_time >= expected_min_time, ( + f"Test failed: Expected at least {expected_min_time:.2f} seconds " + f"for {request_count} requests, got {total_elapsed_time:.2f} seconds" + )