diff --git a/src/program/utils/request.py b/src/program/utils/request.py
index cb9c49ec..18a9dbaa 100644
--- a/src/program/utils/request.py
+++ b/src/program/utils/request.py
@@ -4,10 +4,12 @@
 from typing import Dict, Type, Optional, Any
 from requests import Session
 from lxml import etree
+from urllib3.util.retry import Retry
+from requests.adapters import HTTPAdapter
 from requests.exceptions import ConnectTimeout, RequestException, HTTPError
 from requests.models import Response
 from requests_cache import CacheMixin, CachedSession
-from requests_ratelimiter import LimiterMixin, LimiterSession
+from requests_ratelimiter import LimiterMixin, LimiterSession, LimiterAdapter
 from xmltodict import parse as parse_xml
 from loguru import logger
 from program.utils import data_dir_path
@@ -37,7 +39,12 @@ def to_dict(self) -> Dict[str, Any]:
 
 
 class ResponseObject:
-    """Response object to handle different response formats."""
+    """Response object to handle different response formats.
+
+    :param response: The response object to parse.
+    :param response_type: The response type to parse the content as.
+    """
+
     def __init__(self, response: Response, response_type: ResponseType = ResponseType.SIMPLE_NAMESPACE):
         self.response = response
         self.is_ok = response.ok
@@ -47,7 +54,13 @@ def __init__(self, response: Response, response_type: ResponseTyp
 
     def handle_response(self, response: Response, response_type: ResponseType) -> dict | SimpleNamespace:
-        """Parse the response content based on content type."""
+        """Parse the response content based on content type.
+
+        :param response: The response object to parse.
+        :param response_type: The response type to parse the content as.
+        :return: Parsed response content.
+        """
+
         timeout_statuses = [408, 460, 504, 520, 524, 522, 598, 599]
         rate_limit_statuses = [429]
         client_error_statuses = list(range(400, 451))  # 400-450
@@ -84,7 +97,16 @@ def handle_response(self, response: Response, response_type: ResponseType) -> di
         return {}
 
 class BaseRequestHandler:
-    def __init__(self, session: Session, response_type: ResponseType = ResponseType.SIMPLE_NAMESPACE, base_url: Optional[str] = None, base_params: Optional[BaseRequestParameters] = None,
+    """Base request handler for services.
+
+    :param session: The session to use for requests.
+    :param response_type: The response type to parse the content as.
+    :param base_url: Optional base URL to use for requests.
+    :param base_params: Optional base parameters to include in requests.
+    :param custom_exception: Optional custom exception to raise on request failure.
+    :param request_logging: Boolean indicating if request logging should be enabled.
+    """
+
+    def __init__(self, session: Session | LimiterSession, response_type: ResponseType = ResponseType.SIMPLE_NAMESPACE, base_url: Optional[str] = None, base_params: Optional[BaseRequestParameters] = None,
                  custom_exception: Optional[Type[Exception]] = None, request_logging: bool = False):
         self.session = session
         self.response_type = response_type
@@ -94,11 +116,19 @@ def __init__(self, session: Session, response_type: ResponseType.
         self.request_logging = request_logging
 
     def _request(self, method: HttpMethod, endpoint: str, ignore_base_url: Optional[bool] = None, overriden_response_type: ResponseType = None, **kwargs) -> ResponseObject:
-        """Generic request handler with error handling, using kwargs for flexibility."""
+        """Generic request handler with error handling, using kwargs for flexibility.
+
+        :param method: HTTP method to use for the request.
+        :param endpoint: Endpoint to request.
+        :param ignore_base_url: Boolean indicating if the base URL should be ignored.
+        :param overriden_response_type: Optional response type to use for the request.
+        :param kwargs: Additional parameters to pass to the request.
+        :return: ResponseObject with the response data.
+        """
         try:
             url = f"{self.BASE_URL}/{endpoint}".rstrip('/') if not ignore_base_url and self.BASE_URL else endpoint
 
-            # Add base parameters to kwargs if they exist
             request_params = self.BASE_REQUEST_PARAMS.to_dict()
             if request_params:
                 kwargs.setdefault('params', {}).update(request_params)
@@ -128,7 +158,7 @@ def _request(self, method: HttpMethod, endpoint: str, ignore_base_url: Optional[
 
 
 class RateLimitExceeded(Exception):
-    """Rate limit exceeded exception"""
+    """Rate limit exceeded exception for requests."""
     def __init__(self, message, response=None):
         super().__init__(message)
         self.response = response
@@ -141,14 +171,18 @@ def create_service_session(
     rate_limit_params: Optional[dict] = None,
     use_cache: bool = False,
     cache_params: Optional[dict] = None,
+    session_adapter: Optional[HTTPAdapter | LimiterAdapter] = None,
+    retry_policy: Optional[Retry] = None,
     log_config: Optional[bool] = False,
-) -> Session:
+) -> Session | CachedSession | CachedLimiterSession:
     """
     Create a session for a specific service with optional caching and rate-limiting.
 
     :param rate_limit_params: Dictionary of rate-limiting parameters.
     :param use_cache: Boolean indicating if caching should be enabled.
     :param cache_params: Dictionary of caching parameters if caching is enabled.
+    :param session_adapter: Optional custom HTTP adapter to use for the session.
+    :param retry_policy: Optional retry policy to use for the session.
     :param log_config: Boolean indicating if the session configuration should be logged.
     :return: Configured session for the service.
     """
@@ -160,14 +194,20 @@ def create_service_session(
             logger.debug(f"Rate Limit Parameters: {rate_limit_params}")
             logger.debug(f"Cache Parameters: {cache_params}")
         session_class = CachedLimiterSession if rate_limit_params else CachedSession
-        return session_class(**rate_limit_params, **cache_params)
+        cache_session = session_class(**rate_limit_params, **cache_params)
+        _create_and_mount_session_adapter(cache_session, session_adapter, retry_policy, log_config)
+        return cache_session
 
     if rate_limit_params:
         if log_config:
             logger.debug(f"Rate Limit Parameters: {rate_limit_params}")
-        return LimiterSession(**rate_limit_params)
+        limiter_session = LimiterSession(**rate_limit_params)
+        _create_and_mount_session_adapter(limiter_session, session_adapter, retry_policy, log_config)
+        return limiter_session
 
-    return Session()
+    standard_session = Session()
+    _create_and_mount_session_adapter(standard_session, session_adapter, retry_policy, log_config)
+    return standard_session
 
 
 def get_rate_limit_params(
@@ -182,6 +222,7 @@
     use_memory_list: bool = False,
     limit_statuses: Optional[list[int]] = None,
     max_delay: Optional[int] = 0,
+
 ) -> Dict[str, any]:
     """
     Generate rate limit parameters for a service. If `db_name` is not provided,
@@ -231,11 +272,53 @@ def get_rate_limit_params(
 
 
 def get_cache_params(cache_name: str = 'cache', expire_after: Optional[int] = 60) -> dict:
-    """Generate cache parameters for a service, ensuring the cache file is in the specified directory."""
+    """Generate cache parameters for a service, ensuring the cache file is in the specified directory.
+
+    :param cache_name: The name of the cache file excluding the extension.
+    :param expire_after: The time in seconds to expire the cache.
+    :return: Dictionary with cache configuration.
+    """
     cache_path = data_dir_path / f"{cache_name}.db"
     return {'cache_name': cache_path, 'expire_after': expire_after}
 
+
+def get_retry_policy(retries: int = 3, backoff_factor: float = 0.3, status_forcelist: Optional[list[int]] = None) -> Retry:
+    """
+    Create a retry policy for requests.
+
+    :param retries: The maximum number of retry attempts.
+    :param backoff_factor: A backoff factor to apply between attempts.
+    :param status_forcelist: A list of HTTP status codes that we should force a retry on.
+    :return: Configured Retry object.
+    """
+    return Retry(total=retries, backoff_factor=backoff_factor, status_forcelist=status_forcelist or [500, 502, 503, 504])
+
+
+def get_http_adapter(
+    retry_policy: Optional[Retry] = None,
+    pool_connections: Optional[int] = 50,
+    pool_maxsize: Optional[int] = 100,
+    pool_block: Optional[bool] = True
+) -> HTTPAdapter:
+    """
+    Create an HTTP adapter with a retry policy and connection pooling options.
+
+    :param retry_policy: The retry policy to use for the adapter.
+    :param pool_connections: The number of connection pools to allow.
+    :param pool_maxsize: The maximum number of connections to keep in the pool.
+    :param pool_block: Boolean indicating if the pool should block when full.
+    :return: Configured HTTPAdapter.
+    """
+    adapter_kwargs = {
+        'max_retries': retry_policy,
+        'pool_connections': pool_connections,
+        'pool_maxsize': pool_maxsize,
+        'pool_block': pool_block,
+    }
+    return HTTPAdapter(**adapter_kwargs)
+
+
 def xml_to_simplenamespace(xml_string: str) -> SimpleNamespace:
+    """Convert an XML string to a SimpleNamespace object."""
     root = etree.fromstring(xml_string)
     def element_to_simplenamespace(element):
         children_as_ns = {child.tag: element_to_simplenamespace(child) for child in element}
@@ -243,3 +326,22 @@ def element_to_simplenamespace(element):
         attributes.update(children_as_ns)
         return SimpleNamespace(**attributes, text=element.text)
     return element_to_simplenamespace(root)
+
+
+def _create_and_mount_session_adapter(
+    session: Session,
+    adapter_instance: Optional[HTTPAdapter] = None,
+    retry_policy: Optional[Retry] = None,
+    log_config: Optional[bool] = False):
+    """
+    Create and mount an HTTP adapter to a session.
+
+    :param session: The session to mount the adapter to.
+    :param adapter_instance: Optional pre-configured HTTP adapter to mount as-is.
+    :param retry_policy: The retry policy to use for the adapter.
+    :param log_config: Boolean indicating if the mounted adapter configuration should be logged.
+    """
+    adapter = adapter_instance or get_http_adapter(retry_policy)
+    session.mount("https://", adapter)
+    session.mount("http://", adapter)
+
+    if log_config:
+        logger.debug(f"Mounted http adapter with params: {adapter.__dict__} to session.")
diff --git a/src/tests/test_rate_limiting.py b/src/tests/test_rate_limiting.py
index 8e1b0bf3..84342936 100644
--- a/src/tests/test_rate_limiting.py
+++ b/src/tests/test_rate_limiting.py
@@ -1,7 +1,10 @@
 import time
+from unittest.mock import patch
+
+import pytest
 import responses
 from requests.exceptions import HTTPError
-from program.utils.request import create_service_session, get_rate_limit_params, HttpMethod, BaseRequestHandler, ResponseType, RateLimitExceeded
+from program.utils.request import create_service_session, get_rate_limit_params, HttpMethod, BaseRequestHandler, ResponseType, RateLimitExceeded, get_http_adapter, get_retry_policy
 
 @responses.activate
 def test_rate_limiter_with_base_request_handler():
@@ -214,3 +217,70 @@ def test_limiter_session_with_basic_rate_limit():
         f"Test failed: Expected at least {expected_min_time:.2f} seconds "
         f"for {request_count} requests, got {total_elapsed_time:.2f} seconds"
     )
+
+@pytest.fixture
+def retry_policy():
+    return get_retry_policy(retries=5, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504])
+
+@pytest.fixture
+def connection_pool_params():
+    return {
+        'pool_connections': 20,
+        'pool_maxsize': 50,
+        'pool_block': True
+    }
+
+
+def test_session_adapter_configuration(retry_policy, connection_pool_params):
+    with patch("program.utils.request.HTTPAdapter") as MockAdapter:
+        session = create_service_session(
+            retry_policy=retry_policy,
+            session_adapter=get_http_adapter(
+                retry_policy=retry_policy,
+                pool_connections=connection_pool_params["pool_connections"],
+                pool_maxsize=connection_pool_params["pool_maxsize"],
+                pool_block=connection_pool_params["pool_block"]
+            )
+        )
+
+        MockAdapter.assert_called_with(
+            max_retries=retry_policy,
+            **connection_pool_params
+        )
+
+        assert session.adapters["http://"] == MockAdapter.return_value
+        assert session.adapters["https://"] == MockAdapter.return_value
+
+
+def test_session_adapter_pool_configuration_and_request(retry_policy, connection_pool_params):
+    # Mock an HTTP endpoint to test request functionality
+    url = "https://api.example.com/test"
+    with responses.RequestsMock() as rsps:
+        rsps.add(rsps.GET, url, json={"message": "success"}, status=200)
+
+        session = create_service_session(
+            retry_policy=retry_policy,
+            session_adapter=get_http_adapter(
+                retry_policy=retry_policy,
+                pool_connections=connection_pool_params["pool_connections"],
+                pool_maxsize=connection_pool_params["pool_maxsize"],
+                pool_block=connection_pool_params["pool_block"]
+            )
+        )
+
+        adapter_http = session.adapters["http://"]
+        adapter_https = session.adapters["https://"]
+
+        assert adapter_http == adapter_https, "HTTP and HTTPS adapters should be the same instance"
+        assert adapter_http._pool_connections == connection_pool_params["pool_connections"], \
+            f"Expected pool_connections to be {connection_pool_params['pool_connections']}, got {adapter_http._pool_connections}"
+        assert adapter_http._pool_maxsize == connection_pool_params["pool_maxsize"], \
+            f"Expected pool_maxsize to be {connection_pool_params['pool_maxsize']}, got {adapter_http._pool_maxsize}"
+        assert adapter_http._pool_block == connection_pool_params["pool_block"], \
+            f"Expected pool_block to be {connection_pool_params['pool_block']}, got {adapter_http._pool_block}"
+        assert adapter_http.max_retries == retry_policy, \
+            f"Expected max_retries to be {retry_policy}, got {adapter_http.max_retries}"
+
+        response = session.get(url)
+        assert response.status_code == 200
+        assert response.json() == {"message": "success"}
\ No newline at end of file
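
Reviewer note: a minimal usage sketch of the helpers introduced in this diff (get_retry_policy, get_http_adapter, create_service_session), assuming the signatures shown above; the endpoint URL and parameter values are illustrative only and not part of the change.

    from program.utils.request import (
        create_service_session,
        get_http_adapter,
        get_retry_policy,
    )

    # Retry transient failures (defaults to 500/502/503/504) with exponential backoff.
    retry_policy = get_retry_policy(retries=3, backoff_factor=0.3)

    # Pre-build the adapter so pool sizing is explicit; create_service_session
    # mounts it on both "http://" and "https://".
    adapter = get_http_adapter(
        retry_policy=retry_policy,
        pool_connections=50,
        pool_maxsize=100,
        pool_block=True,
    )

    session = create_service_session(session_adapter=adapter, retry_policy=retry_policy)

    # The result is a plain requests.Session (or a Limiter/Cached variant when
    # rate-limit or cache params are given), with retries and pooling applied
    # transparently to every request.
    response = session.get("https://api.example.com/health")  # hypothetical endpoint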