refactor: rework request handling and rate limiting for downloaders
- Introduced `BaseRequestHandler` and `BaseRequestParameters` for standardized request handling.
- Implemented `RealDebridRequestHandler` and `AllDebridRequestHandler` to manage API requests with improved error handling.
- Updated `RealDebridAPI` and `AllDebridAPI` to use new request handlers.
- Enhanced rate limiting configuration for `AllDebridAPI`.
- Improved logging and error management across downloader classes.
iPromKnight authored and Gaisberg committed Nov 4, 2024
1 parent 877ffec commit 0d31e41
Showing 3 changed files with 134 additions and 80 deletions.
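
For orientation, a minimal sketch of the pattern this commit introduces: a downloader builds a rate-limited session via `create_service_session`, wraps it in a service-specific `BaseRequestHandler` subclass, and routes every API call through `execute`. The wiring below mirrors the Real-Debrid classes in the diff; treat it as an illustration of the intended usage, not as the literal committed code.

    from program.utils.request import get_rate_limit_params, create_service_session

    # Throttled session; alldebrid.py below uses per_minute=600, per_second=12.
    session = create_service_session(
        rate_limit_params=get_rate_limit_params(per_minute=600, per_second=12)
    )
    session.headers.update({"Authorization": "Bearer <api-key>"})

    # The handler prefixes the base URL, merges any base params, and converts
    # request failures into the service's custom exception type.
    handler = RealDebridRequestHandler(session, "https://api.real-debrid.com/rest/1.0")
    user_info = handler.execute("GET", "user")  # raises RealDebridError on failure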
55 changes: 28 additions & 27 deletions src/program/services/downloaders/alldebrid.py
@@ -1,9 +1,11 @@
from datetime import datetime
from typing import Dict, Iterator, List, Optional, Tuple

import requests
from requests import Session
from loguru import logger
from requests.exceptions import ConnectTimeout, RequestException
from requests.exceptions import ConnectTimeout
from program.utils.request import get_rate_limit_params, create_service_session, BaseRequestHandler, \
BaseRequestParameters

from program.settings.manager import settings_manager

@@ -13,39 +15,38 @@
class AllDebridError(Exception):
"""Base exception for AllDebrid related errors"""

class AllDebridBaseRequestParameters(BaseRequestParameters):
"""AllDebrid base request parameters"""
agent: Optional[str] = None

class AllDebridRequestHandler(BaseRequestHandler):
def __init__(self, session: Session, base_url: str, base_params: AllDebridBaseRequestParameters, request_logging: bool = False):
super().__init__(session, base_url, base_params, custom_exception=AllDebridError, request_logging=request_logging)

def execute(self, method: str, endpoint: str, **kwargs) -> dict:
data, status_code = super()._request(method, endpoint, **kwargs)
if not data or "data" not in data:
raise AllDebridError("Invalid response from AllDebrid")
return data["data"]

class AllDebridAPI:
"""Handles AllDebrid API communication"""
BASE_URL = "https://api.alldebrid.com/v4"
AGENT = "Riven"

def __init__(self, api_key: str, proxy_url: Optional[str] = None):
self.api_key = api_key
self.session = requests.Session()
rate_limit_params = get_rate_limit_params(per_minute=600, per_second=12)
self.session = create_service_session(rate_limit_params=rate_limit_params)
self.session.headers.update({
"Authorization": f"Bearer {api_key}"
})
if proxy_url:
self.session.proxies = {"http": proxy_url, "https": proxy_url}
base_params = AllDebridBaseRequestParameters()
base_params.agent = self.AGENT
self.request_handler = AllDebridRequestHandler(self.session, self.BASE_URL, base_params)

def _request(self, method: str, endpoint: str, **params) -> dict:
"""Generic request handler with error handling"""
try:
params["agent"] = self.AGENT
url = f"{self.BASE_URL}/{endpoint}"
response = self.session.request(method, url, params=params)
response.raise_for_status()
data = response.json() if response.content else {}

if not data or "data" not in data:
raise AllDebridError("Invalid response from AllDebrid")

return data["data"]
except requests.exceptions.JSONDecodeError as e:
logger.error(f"Failed to decode JSON response: {e}")
raise AllDebridError("Invalid JSON response") from e
except RequestException as e:
logger.error(f"Request failed: {e}")
raise

class AllDebridDownloader(DownloaderBase):
"""Main AllDebrid downloader class implementing DownloaderBase"""
@@ -92,7 +93,7 @@ def _validate_settings(self) -> bool:
def _validate_premium(self) -> bool:
"""Validate premium status"""
try:
user_info = self.api._request("GET", "user")
user_info = self.api.request_handler.execute("GET", "user")
user = user_info.get("user", {})

if not user.get("isPremium", False):
@@ -120,7 +121,7 @@ def get_instant_availability(self, infohashes: List[str]) -> Dict[str, list]:

try:
params = {f"magnets[{i}]": infohash for i, infohash in enumerate(infohashes)}
response = self.api._request("GET", "magnet/instant", **params)
response = self.api.request_handler.execute("GET", "magnet/instant", **params)
magnets = response.get("magnets", [])

availability = {}
@@ -173,7 +174,7 @@ def add_torrent(self, infohash: str) -> str:
raise AllDebridError("Downloader not properly initialized")

try:
response = self.api._request(
response = self.api.request_handler.execute(
"GET",
"magnet/upload",
**{"magnets[]": infohash}
@@ -215,7 +216,7 @@ def get_torrent_info(self, torrent_id: str) -> dict:
raise AllDebridError("Downloader not properly initialized")

try:
response = self.api._request("GET", "magnet/status", id=torrent_id)
response = self.api.request_handler.execute("GET", "magnet/status", id=torrent_id)
info = response.get("magnets", {})
if "filename" not in info:
raise AllDebridError("Invalid torrent info response")
@@ -233,7 +234,7 @@ def delete_torrent(self, torrent_id: str):
raise AllDebridError("Downloader not properly initialized")

try:
self.api._request("GET", "magnet/delete", id=torrent_id)
self.api.request_handler.execute("GET", "magnet/delete", id=torrent_id)
except Exception as e:
logger.error(f"Failed to delete torrent {torrent_id}: {e}")
raise
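
A note on the base-parameters mechanism above: `AllDebridBaseRequestParameters.to_dict()` omits `None`-valued fields, and `BaseRequestHandler._request` (see request.py below) merges the result into each request's query parameters, so every AllDebrid call carries `agent=Riven` without repeating it at each call site. A short sketch using only names from this diff:

    base_params = AllDebridBaseRequestParameters()
    base_params.agent = "Riven"
    assert base_params.to_dict() == {"agent": "Riven"}  # None fields dropped

    # Inside BaseRequestHandler._request the dict is merged via
    #     kwargs.setdefault('params', {}).update(request_params)
    # so the query string of every request includes agent=Riven.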
44 changes: 21 additions & 23 deletions src/program/services/downloaders/realdebrid.py
@@ -3,15 +3,14 @@
from enum import Enum
from typing import Dict, List, Optional, Union

import requests
from requests import Session
from loguru import logger
from pydantic import BaseModel
from requests.exceptions import RequestException

from program.settings.manager import settings_manager

from .shared import VIDEO_EXTENSIONS, DownloaderBase, FileFinder, premium_days_left
from program.utils.request import get_rate_limit_params, get_cache_params, create_service_session
from program.utils.request import get_rate_limit_params, create_service_session, BaseRequestHandler


class RDTorrentStatus(str, Enum):
@@ -43,6 +42,18 @@ class RDTorrent(BaseModel):
class RealDebridError(Exception):
"""Base exception for Real-Debrid related errors"""

class RealDebridRequestHandler(BaseRequestHandler):
def __init__(self, session: Session, base_url: str, request_logging: bool = False):
super().__init__(session, base_url, custom_exception=RealDebridError, request_logging=request_logging)

def execute(self, method: str, endpoint: str, **kwargs) -> Union[dict, list]:
data, status_code = super()._request(method, endpoint, **kwargs)
if status_code == 204:
return {}
if not data:
raise RealDebridError("Invalid JSON response from RealDebrid")
return data

class RealDebridAPI:
"""Handles Real-Debrid API communication"""
BASE_URL = "https://api.real-debrid.com/rest/1.0"
@@ -54,20 +65,7 @@ def __init__(self, api_key: str, proxy_url: Optional[str] = None):
self.session.headers.update({"Authorization": f"Bearer {api_key}"})
if proxy_url:
self.session.proxies = {"http": proxy_url, "https": proxy_url}

def _request(self, method: str, endpoint: str, **kwargs) -> Union[dict, list]:
"""Generic request handler with error handling"""
try:
url = f"{self.BASE_URL}/{endpoint}"
response = self.session.request(method, url, **kwargs)
response.raise_for_status()
return response.json() if response.content else {}
except requests.exceptions.JSONDecodeError as e:
logger.error(f"Failed to decode JSON response: {e}")
raise RealDebridError("Invalid JSON response") from e
except RequestException as e:
logger.error(f"Request failed: {e}")
raise
self.request_handler = RealDebridRequestHandler(self.session, self.BASE_URL)

class RealDebridDownloader(DownloaderBase):
"""Main Real-Debrid downloader class implementing DownloaderBase"""
@@ -112,7 +110,7 @@ def _validate_settings(self) -> bool:
def _validate_premium(self) -> bool:
"""Validate premium status"""
try:
user_info = self.api._request("GET", "user")
user_info = self.api.request_handler.execute("GET", "user")
if not user_info.get("premium"):
logger.error("Premium membership required")
return False
@@ -137,7 +135,7 @@ def get_instant_availability(self, infohashes: List[str]) -> Dict[str, list]:

for attempt in range(self.MAX_RETRIES):
try:
response = self.api._request(
response = self.api.request_handler.execute(
"GET",
f"torrents/instantAvailability/{'/'.join(infohashes)}"
)
@@ -195,7 +193,7 @@ def add_torrent(self, infohash: str) -> str:

try:
magnet = f"magnet:?xt=urn:btih:{infohash}"
response = self.api._request(
response = self.api.request_handler.execute(
"POST",
"torrents/addMagnet",
data={"magnet": magnet.lower()}
@@ -214,7 +212,7 @@ def select_files(self, torrent_id: str, files: List[str]):
raise RealDebridError("Downloader not properly initialized")

try:
self.api._request(
self.api.request_handler.execute(
"POST",
f"torrents/selectFiles/{torrent_id}",
data={"files": ",".join(files)}
@@ -232,7 +230,7 @@ def get_torrent_info(self, torrent_id: str) -> dict:
raise RealDebridError("Downloader not properly initialized")

try:
return self.api._request("GET", f"torrents/info/{torrent_id}")
return self.api.request_handler.execute("GET", f"torrents/info/{torrent_id}")
except Exception as e:
logger.error(f"Failed to get torrent info for {torrent_id}: {e}")
raise
@@ -247,7 +245,7 @@ def delete_torrent(self, torrent_id: str):
raise RealDebridError("Downloader not properly initialized")

try:
self.api._request("DELETE", f"torrents/delete/{torrent_id}")
self.api.request_handler.execute("DELETE", f"torrents/delete/{torrent_id}")
except Exception as e:
logger.error(f"Failed to delete torrent {torrent_id}: {e}")
raise
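
One detail worth noting in `RealDebridRequestHandler.execute` above: Real-Debrid answers destructive endpoints such as `torrents/delete` with `204 No Content`, and the handler maps that empty body to `{}` instead of raising. A brief sketch of the resulting caller contract (illustrative, assuming a successful delete):

    # _request returns (None, 204) when the API sends no body; execute turns
    # that into {}, so delete_torrent() needs no status-code handling of its own.
    result = api.request_handler.execute("DELETE", f"torrents/delete/{torrent_id}")
    assert result == {}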
115 changes: 85 additions & 30 deletions src/program/utils/request.py
@@ -1,21 +1,66 @@
import json
import logging
from types import SimpleNamespace
from typing import Optional
from typing import Optional, Dict, Union, Type, Any, Tuple
from requests import Session
from lxml import etree
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectTimeout, RequestException
from requests.exceptions import ConnectTimeout, RequestException, JSONDecodeError
from requests.models import Response
from requests_cache import CacheMixin, CachedSession
from requests_ratelimiter import LimiterMixin, SQLiteBucket, LimiterSession, MemoryQueueBucket, MemoryListBucket
from pyrate_limiter import RequestRate, Duration, Limiter
from requests_ratelimiter import LimiterMixin, LimiterSession
from urllib3.util.retry import Retry
from xmltodict import parse as parse_xml

from loguru import logger
from program.utils import data_dir_path

logger = logging.getLogger(__name__)
class BaseRequestParameters:
"""Holds base parameters that may be included in every request."""

def to_dict(self) -> Dict[str, Optional[str]]:
"""Convert all non-None attributes to a dictionary for inclusion in requests."""
return {key: value for key, value in self.__dict__.items() if value is not None}


class BaseRequestHandler:
def __init__(self, session: Session, base_url: str, base_params: Optional[BaseRequestParameters] = None,
custom_exception: Optional[Type[Exception]] = None, request_logging: bool = False):
self.session = session
self.BASE_URL = base_url
self.BASE_REQUEST_PARAMS = base_params or BaseRequestParameters()
self.custom_exception = custom_exception or Exception
self.request_logging = request_logging

def _request(self, method: str, endpoint: str, **kwargs) -> Tuple[Optional[Any], int]:
"""Generic request handler with error handling, using kwargs for flexibility."""
try:
url = f"{self.BASE_URL}/{endpoint}"

request_params = self.BASE_REQUEST_PARAMS.to_dict()
if request_params:
kwargs.setdefault('params', {}).update(request_params)
elif 'params' in kwargs and not kwargs['params']:
del kwargs['params']

if self.request_logging:
logger.debug(f"Making request to {url} with kwargs: {kwargs}")

response = self.session.request(method, url, **kwargs)
response.raise_for_status()

if response.content:
try:
data = response.json()
if self.request_logging:
logger.debug(f"Response JSON from {endpoint}: {data}")
return data, response.status_code
except JSONDecodeError:
logger.error("Received non-JSON response")
raise self.custom_exception("Non-JSON response received from API")
else:
return None, response.status_code
except RequestException as e:
logger.error(f"Request failed: {e}")
raise self.custom_exception(f"Request failed: {e}") from e

class RateLimitExceeded(Exception):
"""Rate limit exceeded exception"""
@@ -147,15 +192,15 @@ def ping(session: Session, url: str, timeout: int = 10, additional_headers=None,


def get_rate_limit_params(
per_second=None,
per_minute=None,
per_hour=None,
calculated_rate=None,
per_second: Optional[int] = None,
per_minute: Optional[int] = None,
per_hour: Optional[int] = None,
calculated_rate: Optional[int] = None,
max_calls: Optional[int] = None,
period: Optional[int] = None,
db_name: Optional[str] = None,
use_memory_list: bool = False
) -> dict:
) -> Dict[str, Any]:
"""
Generate rate limit parameters for a service. If `db_name` is not provided,
use an in-memory bucket for rate limiting.
@@ -170,25 +215,35 @@
:param use_memory_list: If true, use MemoryListBucket instead of MemoryQueueBucket for in-memory limiting.
:return: Dictionary with rate limit configuration.
"""
# Choose bucket type based on whether db_name is provided
if db_name:
bucket_class = SQLiteBucket
bucket_kwargs = {"path": data_dir_path / f"{db_name}.db"}
else:
bucket_class = MemoryListBucket if use_memory_list else MemoryQueueBucket
bucket_kwargs = {}

# Set up the limiter based on available rate parameters
# Choose the appropriate bucket type based on the presence of db_name
bucket_class = SQLiteBucket if db_name else (MemoryListBucket if use_memory_list else MemoryQueueBucket)
bucket_kwargs = {"path": data_dir_path / f"{db_name}.db"} if db_name else {}

# Create a list of request rates based on provided limits
rate_limits = []
if per_second:
rate_limits.append(RequestRate(per_second, Duration.SECOND))
if per_minute:
rate_limits.append(RequestRate(per_minute, Duration.MINUTE))
if per_hour:
rate_limits.append(RequestRate(per_hour, Duration.HOUR))
if calculated_rate:
rate_limits.append(RequestRate(calculated_rate, Duration.MINUTE))
if max_calls and period:
limiter = Limiter(RequestRate(max_calls, Duration.SECOND * period))
return {'limiter': limiter, 'bucket_class': bucket_class, 'bucket_kwargs': bucket_kwargs}
elif calculated_rate:
return {'per_minute': calculated_rate, 'bucket_class': bucket_class, 'bucket_kwargs': bucket_kwargs}
else:
limit_key = ('per_second' if per_second else 'per_minute' if per_minute else 'per_hour' if per_hour else None)
if not limit_key:
raise ValueError("One of max_calls and period, per_second, per_minute, or per_hour must be provided.")
return {limit_key: locals()[limit_key], 'bucket_class': bucket_class, 'bucket_kwargs': bucket_kwargs}
rate_limits.append(RequestRate(max_calls, Duration.SECOND * period))

# Raise an error if no limits are provided
if not rate_limits:
raise ValueError("At least one rate limit (per_second, per_minute, per_hour, calculated_rate, or max_calls and period) must be specified.")

# Initialize the limiter with all applicable rate limits
limiter = Limiter(*rate_limits)

return {
'limiter': limiter,
'bucket_class': bucket_class,
'bucket_kwargs': bucket_kwargs
}


def get_cache_params(cache_name: str = 'cache', expire_after: Optional[int] = 60) -> dict:
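
Tying the request.py changes together: multiple limits now stack into one `pyrate_limiter.Limiter`, whereas the old code honored only one limit at a time. A hedged usage example (the 600/12 values are the AllDebrid settings from this commit; the `db_name` variant is an assumption based on the docstring):

    from program.utils.request import get_rate_limit_params, create_service_session

    # Both rates are enforced together: bursts capped at 12 req/s while the
    # sustained rate stays under 600 req/min.
    params = get_rate_limit_params(per_minute=600, per_second=12)
    session = create_service_session(rate_limit_params=params)

    # With db_name set, SQLiteBucket persists the bucket state across restarts.
    persistent = get_rate_limit_params(per_hour=1000, db_name="alldebrid")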
