Skip to content

Commit

Permalink
feat: add retry policy and connection pool configuration to request u…
Browse files Browse the repository at this point in the history
…tils (#864)

fix: add retry policy and connection pool configuration to request utils

- Introduced retry policy and connection pool configuration in `create_service_session`.
- Added `get_retry_policy` and `get_http_adapter` functions for customizable retry and connection settings.
- Updated tests to cover new session adapter configurations and ensure correct behavior.
  • Loading branch information
iPromKnight authored Nov 5, 2024
1 parent 91d3f7d commit 1713a51
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 13 deletions.
126 changes: 114 additions & 12 deletions src/program/utils/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from typing import Dict, Type, Optional, Any
from requests import Session
from lxml import etree
from urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectTimeout, RequestException, HTTPError
from requests.models import Response
from requests_cache import CacheMixin, CachedSession
from requests_ratelimiter import LimiterMixin, LimiterSession
from requests_ratelimiter import LimiterMixin, LimiterSession, LimiterAdapter
from xmltodict import parse as parse_xml
from loguru import logger
from program.utils import data_dir_path
Expand Down Expand Up @@ -37,7 +39,12 @@ def to_dict(self) -> Dict[str, Any]:


class ResponseObject:
"""Response object to handle different response formats."""
"""Response object to handle different response formats.
:param response: The response object to parse.
:param response_type: The response type to parse the content as.
"""

def __init__(self, response: Response, response_type: ResponseType = ResponseType.SIMPLE_NAMESPACE):
self.response = response
self.is_ok = response.ok
Expand All @@ -47,7 +54,13 @@ def __init__(self, response: Response, response_type: ResponseType = ResponseTyp


def handle_response(self, response: Response, response_type: ResponseType) -> dict | SimpleNamespace:
"""Parse the response content based on content type."""
"""Parse the response content based on content type.
:param response: The response object to parse.
:param response_type: The response type to parse the content as.
:return: Parsed response content.
"""

timeout_statuses = [408, 460, 504, 520, 524, 522, 598, 599]
rate_limit_statuses = [429]
client_error_statuses = list(range(400, 451)) # 400-450
Expand Down Expand Up @@ -84,7 +97,16 @@ def handle_response(self, response: Response, response_type: ResponseType) -> di
return {}

class BaseRequestHandler:
def __init__(self, session: Session, response_type: ResponseType = ResponseType.SIMPLE_NAMESPACE, base_url: Optional[str] = None, base_params: Optional[BaseRequestParameters] = None,
"""Base request handler for services.
:param session: The session to use for requests.
:param response_type: The response type to parse the content as.
:param base_url: Optional base URL to use for requests.
:param base_params: Optional base parameters to include in requests.
:param custom_exception: Optional custom exception to raise on request failure.
:param request_logging: Boolean indicating if request logging should be enabled.
"""
def __init__(self, session: Session | LimiterSession, response_type: ResponseType = ResponseType.SIMPLE_NAMESPACE, base_url: Optional[str] = None, base_params: Optional[BaseRequestParameters] = None,
custom_exception: Optional[Type[Exception]] = None, request_logging: bool = False):
self.session = session
self.response_type = response_type
Expand All @@ -94,11 +116,19 @@ def __init__(self, session: Session, response_type: ResponseType = ResponseType.
self.request_logging = request_logging

def _request(self, method: HttpMethod, endpoint: str, ignore_base_url: Optional[bool] = None, overriden_response_type: ResponseType = None, **kwargs) -> ResponseObject:
"""Generic request handler with error handling, using kwargs for flexibility."""
"""Generic request handler with error handling, using kwargs for flexibility.
:param method: HTTP method to use for the request.
:param endpoint: Endpoint to request.
:param ignore_base_url: Boolean indicating if the base URL should be ignored.
:param overriden_response_type: Optional response type to use for the request.
:param retry_policy: Optional retry policy to use for the request.
:param kwargs: Additional parameters to pass to the request.
:return: ResponseObject with the response data.
"""
try:
url = f"{self.BASE_URL}/{endpoint}".rstrip('/') if not ignore_base_url and self.BASE_URL else endpoint

# Add base parameters to kwargs if they exist
request_params = self.BASE_REQUEST_PARAMS.to_dict()
if request_params:
kwargs.setdefault('params', {}).update(request_params)
Expand Down Expand Up @@ -128,7 +158,7 @@ def _request(self, method: HttpMethod, endpoint: str, ignore_base_url: Optional[


class RateLimitExceeded(Exception):
"""Rate limit exceeded exception"""
"""Rate limit exceeded exception for requests."""
def __init__(self, message, response=None):
super().__init__(message)
self.response = response
Expand All @@ -141,14 +171,18 @@ def create_service_session(
rate_limit_params: Optional[dict] = None,
use_cache: bool = False,
cache_params: Optional[dict] = None,
session_adapter: Optional[HTTPAdapter | LimiterAdapter] = None,
retry_policy: Optional[Retry] = None,
log_config: Optional[bool] = False,
) -> Session:
) -> Session | CachedSession | CachedLimiterSession:
"""
Create a session for a specific service with optional caching and rate-limiting.
:param rate_limit_params: Dictionary of rate-limiting parameters.
:param use_cache: Boolean indicating if caching should be enabled.
:param cache_params: Dictionary of caching parameters if caching is enabled.
:param session_adapter: Optional custom HTTP adapter to use for the session.
:param retry_policy: Optional retry policy to use for the session.
:param log_config: Boolean indicating if the session configuration should be logged.
:return: Configured session for the service.
"""
Expand All @@ -160,14 +194,20 @@ def create_service_session(
logger.debug(f"Rate Limit Parameters: {rate_limit_params}")
logger.debug(f"Cache Parameters: {cache_params}")
session_class = CachedLimiterSession if rate_limit_params else CachedSession
return session_class(**rate_limit_params, **cache_params)
cache_session = session_class(**rate_limit_params, **cache_params)
_create_and_mount_session_adapter(cache_session, session_adapter, retry_policy, log_config)
return cache_session

if rate_limit_params:
if log_config:
logger.debug(f"Rate Limit Parameters: {rate_limit_params}")
return LimiterSession(**rate_limit_params)
limiter_session = LimiterSession(**rate_limit_params)
_create_and_mount_session_adapter(limiter_session, session_adapter, retry_policy, log_config)
return limiter_session

return Session()
standard_session = Session()
_create_and_mount_session_adapter(standard_session, session_adapter, retry_policy, log_config)
return standard_session


def get_rate_limit_params(
Expand All @@ -182,6 +222,7 @@ def get_rate_limit_params(
use_memory_list: bool = False,
limit_statuses: Optional[list[int]] = None,
max_delay: Optional[int] = 0,

) -> Dict[str, any]:
"""
Generate rate limit parameters for a service. If `db_name` is not provided,
Expand Down Expand Up @@ -231,15 +272,76 @@ def get_rate_limit_params(


def get_cache_params(cache_name: str = 'cache', expire_after: Optional[int] = 60) -> dict:
"""Generate cache parameters for a service, ensuring the cache file is in the specified directory."""
"""Generate cache parameters for a service, ensuring the cache file is in the specified directory.
:param cache_name: The name of the cache file excluding the extension.
:param expire_after: The time in seconds to expire the cache.
:return: Dictionary with cache configuration.
"""
cache_path = data_dir_path / f"{cache_name}.db"
return {'cache_name': cache_path, 'expire_after': expire_after}


def get_retry_policy(retries: int = 3, backoff_factor: float = 0.3, status_forcelist: Optional[list[int]] = None) -> Retry:
"""
Create a retry policy for requests.
:param retries: The maximum number of retry attempts.
:param backoff_factor: A backoff factor to apply between attempts.
:param status_forcelist: A list of HTTP status codes that we should force a retry on.
:return: Configured Retry object.
"""
return Retry(total=retries, backoff_factor=backoff_factor, status_forcelist=status_forcelist or [500, 502, 503, 504])


def get_http_adapter(
retry_policy: Optional[Retry] = None,
pool_connections: Optional[int] = 50,
pool_maxsize: Optional[int] = 100,
pool_block: Optional[bool] = True
) -> HTTPAdapter:
"""
Create an HTTP adapter with retry policy and optional rate limiting.
:param retry_policy: The retry policy to use for the adapter.
:param pool_connections: The number of connection pools to allow.
:param pool_maxsize: The maximum number of connections to keep in the pool.
:param pool_block: Boolean indicating if the pool should block when full.
"""
adapter_kwargs = {
'max_retries': retry_policy,
'pool_connections': pool_connections,
'pool_maxsize': pool_maxsize,
'pool_block': pool_block,
}
return HTTPAdapter(**adapter_kwargs)


def xml_to_simplenamespace(xml_string: str) -> SimpleNamespace:
"""Convert an XML string to a SimpleNamespace object."""
root = etree.fromstring(xml_string)
def element_to_simplenamespace(element):
children_as_ns = {child.tag: element_to_simplenamespace(child) for child in element}
attributes = {key: value for key, value in element.attrib.items()}
attributes.update(children_as_ns)
return SimpleNamespace(**attributes, text=element.text)
return element_to_simplenamespace(root)


def _create_and_mount_session_adapter(
session: Session,
adapter_instance: Optional[HTTPAdapter] = None,
retry_policy: Optional[Retry] = None,
log_config: Optional[bool] = False):
"""
Create and mount an HTTP adapter to a session.
:param session: The session to mount the adapter to.
:param retry_policy: The retry policy to use for the adapter.
"""
adapter = adapter_instance or get_http_adapter(retry_policy)
session.mount("https://", adapter)
session.mount("http://", adapter)

if log_config:
logger.debug(f"Mounted http adapter with params: {adapter.__dict__} to session.")
72 changes: 71 additions & 1 deletion src/tests/test_rate_limiting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import time
from unittest.mock import patch

import pytest
import responses
from requests.exceptions import HTTPError
from program.utils.request import create_service_session, get_rate_limit_params, HttpMethod, BaseRequestHandler, ResponseType, RateLimitExceeded
from program.utils.request import create_service_session, get_rate_limit_params, HttpMethod, BaseRequestHandler, ResponseType, RateLimitExceeded, get_http_adapter, get_retry_policy

@responses.activate
def test_rate_limiter_with_base_request_handler():
Expand Down Expand Up @@ -214,3 +217,70 @@ def test_limiter_session_with_basic_rate_limit():
f"Test failed: Expected at least {expected_min_time:.2f} seconds "
f"for {request_count} requests, got {total_elapsed_time:.2f} seconds"
)

@pytest.fixture
def retry_policy():
return get_retry_policy(retries=5, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504])

@pytest.fixture
def connection_pool_params():
return {
'pool_connections': 20,
'pool_maxsize': 50,
'pool_block': True
}


def test_session_adapter_configuration(retry_policy, connection_pool_params):
with patch("program.utils.request.HTTPAdapter") as MockAdapter:
session = create_service_session(
retry_policy=retry_policy,
session_adapter=get_http_adapter(
retry_policy=retry_policy,
pool_connections=connection_pool_params["pool_connections"],
pool_maxsize=connection_pool_params["pool_maxsize"],
pool_block=connection_pool_params["pool_block"]
)
)

MockAdapter.assert_called_with(
max_retries=retry_policy,
**connection_pool_params
)

assert session.adapters["http://"] == MockAdapter.return_value
assert session.adapters["https://"] == MockAdapter.return_value


def test_session_adapter_pool_configuration_and_request(retry_policy, connection_pool_params):
# Mock an HTTP endpoint to test request functionality
url = "https://api.example.com/test"
with responses.RequestsMock() as rsps:
rsps.add(rsps.GET, url, json={"message": "success"}, status=200)

session = create_service_session(
retry_policy=retry_policy,
session_adapter=get_http_adapter(
retry_policy=retry_policy,
pool_connections=connection_pool_params["pool_connections"],
pool_maxsize=connection_pool_params["pool_maxsize"],
pool_block=connection_pool_params["pool_block"]
)
)

adapter_http = session.adapters["http://"]
adapter_https = session.adapters["https://"]

assert adapter_http == adapter_https, "HTTP and HTTPS adapters should be the same instance"
assert adapter_http._pool_connections == connection_pool_params["pool_connections"], \
f"Expected pool_connections to be {connection_pool_params['pool_connections']}, got {adapter_http._pool_connections}"
assert adapter_http._pool_maxsize == connection_pool_params["pool_maxsize"], \
f"Expected pool_maxsize to be {connection_pool_params['pool_maxsize']}, got {adapter_http._pool_maxsize}"
assert adapter_http._pool_block == connection_pool_params["pool_block"], \
f"Expected pool_block to be {connection_pool_params['pool_block']}, got {adapter_http._pool_block}"
assert adapter_http.max_retries == retry_policy, \
f"Expected max_retries to be {retry_policy}, got {adapter_http.max_retries}"

response = session.get(url)
assert response.status_code == 200
assert response.json() == {"message": "success"}

0 comments on commit 1713a51

Please sign in to comment.