Skip to content

Commit

Permalink
refactor: Refactor codebase to improve logging, error handling, and session management
Browse files Browse the repository at this point in the history

- Added detailed logging for various scenarios in `realdebrid.py` to improve debugging.
- Enhanced error handling in `_download_item` method in `realdebrid.py`.
- Introduced session refresh for blacklisted streams in `realdebrid.py` and `item.py`.
- Added `requested_id` attribute to `MediaItem` class in `item.py`.
- Updated `copy_attributes` method in `trakt.py` to include new attributes.
- Refactored `delete_item_symlinks` method in `symlink.py` for better readability.
- Improved session management in `db_functions.py` by adding `store_state` calls.
- Enhanced `Overseerr` class in `overseerr.py` to include `requested_id` and improved logging.
- Minor code cleanups and added type hints for better code quality.

feat: Remove symlinks and Overseerr requests when an item is removed from Riven.
  • Loading branch information
iPromKnight authored and Gaisberg committed Aug 20, 2024
1 parent 072a352 commit 276ed79
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 117 deletions.
12 changes: 11 additions & 1 deletion src/controllers/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@

import Levenshtein
from fastapi import APIRouter, HTTPException, Request

from program.content import Overseerr
from program.db.db import db
from program.db.db_functions import get_media_items_by_ids, delete_media_item, reset_media_item
from program.media.item import MediaItem
from program.media.state import States
from sqlalchemy import func, select

from program.symlink import Symlinker
from utils.logger import logger
from sqlalchemy.orm import joinedload

router = APIRouter(
prefix="/items",
Expand Down Expand Up @@ -200,7 +203,14 @@ async def remove_item(request: Request, ids: str):
if not media_items or len(media_items) != len(ids):
raise ValueError("Invalid item ID(s) provided. Some items may not exist.")
for media_item in media_items:
logger.debug(f"Removing item {media_item.title} with ID {media_item._id}")
request.app.program.em.cancel_job(media_item)
symlink_service = request.app.program.services.get(Symlinker)
if symlink_service:
symlink_service.delete_item_symlinks(media_item)
if media_item.requested_by == "overseerr" and media_item.requested_id:
logger.debug(f"Item was originally requested by Overseerr, deleting request within Overseerr...")
Overseerr.delete_request(media_item.requested_id)
delete_media_item(media_item)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
Expand Down
6 changes: 3 additions & 3 deletions src/program/content/overseerr.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def run(self):
if not imdb_id or imdb_id in self.recurring_items:
continue
self.recurring_items.add(imdb_id)
media_item = MediaItem({"imdb_id": imdb_id, "requested_by": self.key, "overseerr_id": mediaId})
media_item = MediaItem({"imdb_id": imdb_id, "requested_by": self.key, "overseerr_id": mediaId, "requested_id": item.id})
if media_item:
yield media_item
else:
Expand Down Expand Up @@ -156,8 +156,8 @@ def delete_request(mediaId: int) -> bool:
settings.url + f"/api/v1/request/{mediaId}",
additional_headers=headers,
)
logger.success(f"Deleted request {mediaId} from overseerr")
return response.is_ok
logger.debug(f"Deleted request {mediaId} from overseerr")
return response.is_ok == True
except Exception as e:
logger.error(f"Failed to delete request from overseerr: {str(e)}")
return False
Expand Down
5 changes: 3 additions & 2 deletions src/program/db/db_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def reset_media_item(item: "MediaItem"):
def reset_streams(item: "MediaItem", active_stream_hash: str = None):
"""Reset streams associated with a MediaItem."""
with db.Session() as session:
item.store_state()
item = session.merge(item)
if active_stream_hash:
stream = session.query(Stream).filter(Stream.infohash == active_stream_hash).first()
Expand All @@ -121,7 +122,6 @@ def reset_streams(item: "MediaItem", active_stream_hash: str = None):
)
item.active_stream = None
session.commit()
session.refresh(item)

def blacklist_stream(item: "MediaItem", stream: Stream, session: Session = None) -> bool:
"""Blacklist a stream for a media item."""
Expand All @@ -131,6 +131,7 @@ def blacklist_stream(item: "MediaItem", stream: Stream, session: Session = None)
close_session = True

try:
item.store_state()
item = session.merge(item)
association_exists = session.query(
session.query(StreamRelation)
Expand All @@ -151,7 +152,6 @@ def blacklist_stream(item: "MediaItem", stream: Stream, session: Session = None)
)

session.commit()
session.refresh(item)
return True
return False
# except Exception as e:
Expand Down Expand Up @@ -249,6 +249,7 @@ def _store_item(item: "MediaItem"):
from program.media.item import Movie, Show, Season, Episode
if isinstance(item, (Movie, Show, Season, Episode)) and item._id is not None:
with db.Session() as session:
item.store_state()
session.merge(item)
session.commit()
logger.log("DATABASE", f"{item.log_string} Updated!")
Expand Down
51 changes: 34 additions & 17 deletions src/program/downloaders/realdebrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from pathlib import Path
from types import SimpleNamespace
from typing import Generator, List

from sqlalchemy.orm import Session

from program.db.db import db
from program.db.db_functions import get_stream_count, load_streams_in_pages
from program.media.item import Episode, MediaItem, Movie, Season, Show
Expand Down Expand Up @@ -110,7 +113,8 @@ def run(self, item: MediaItem) -> bool:
if self.is_cached(item) and not self._is_downloaded(item):
self._download_item(item)
return_value = True
self.log_item(item)
if return_value:
self.log_item(item)
return return_value

@staticmethod
Expand Down Expand Up @@ -169,7 +173,7 @@ def is_cached(self, item: MediaItem) -> bool:
try:
response = get(f"{RD_BASE_URL}/torrents/instantAvailability/{streams}/", additional_headers=self.auth_headers, proxies=self.proxy, response_type=dict, specific_rate_limiter=self.torrents_rate_limiter, overall_rate_limiter=self.overall_rate_limiter)
if response.is_ok and response.data and isinstance(response.data, dict):
if self._evaluate_stream_response(response.data, processed_stream_hashes, item, stream_hashes):
if self._evaluate_stream_response(response.data, processed_stream_hashes, item, stream_hashes, session):
return True
processed_stream_hashes.update(stream_chunk)
except Exception as e:
Expand All @@ -184,7 +188,7 @@ def _chunked(lst: List, n: int) -> Generator[List, None, None]:
for i in range(0, len(lst), n):
yield lst[i:i + n]

def _evaluate_stream_response(self, data: dict, processed_stream_hashes: set, item: MediaItem, stream_hashes: dict[str, "Stream"]) -> bool:
def _evaluate_stream_response(self, data: dict, processed_stream_hashes: set, item: "MediaItem", stream_hashes: dict[str, "Stream"], session: "Session") -> bool:
stream_items = list(data.items())

def sorting_key(stream_item):
Expand All @@ -209,6 +213,7 @@ def sorting_key(stream_item):

if not provider_list or not provider_list.get("rd"):
if item.blacklist_stream(stream):
session.refresh(stream)
logger.debug(f"Blacklisted un-cached stream for {item.log_string} with hash: {stream_hash}")
continue

Expand All @@ -217,6 +222,7 @@ def sorting_key(stream_item):
return True
else:
item.blacklist_stream(stream)
session.refresh(stream)
logger.debug(f"Blacklisted stream for {item.log_string} with hash: {stream_hash}")

return False
Expand Down Expand Up @@ -286,6 +292,7 @@ def _is_wanted_movie(self, container: dict, item: Movie) -> bool:
)

if not filenames:
logger.debug(f"No valid files found for {item.log_string} matching your filter settings: min_size: {min_size}, max_size: {max_size}, wanted_formats: {WANTED_FORMATS}")
return False

for file in filenames:
Expand Down Expand Up @@ -318,6 +325,7 @@ def _is_wanted_episode(self, container: dict, item: Episode) -> bool:
]

if not filenames:
logger.debug(f"No valid files found for {item.log_string} matching your filter settings: min_size: {min_size}, max_size: {max_size}, wanted_formats: {WANTED_FORMATS}")
return False

one_season = len(item.parent.parent.seasons) == 1
Expand Down Expand Up @@ -354,6 +362,7 @@ def _is_wanted_season(self, container: dict, item: Season) -> bool:
]

if not filenames:
logger.debug(f"No valid files found for {item.log_string} matching your filter settings: min_size: {min_size}b, max_size: {max_size}, wanted_formats: {WANTED_FORMATS}")
return False

acceptable_states = [States.Indexed, States.Scraped, States.Unknown, States.Failed, States.PartiallyCompleted]
Expand All @@ -364,6 +373,7 @@ def _is_wanted_season(self, container: dict, item: Season) -> bool:
needed_episodes.append(episode.number)

if not needed_episodes:
logger.debug(f"No needed episodes found for {item.log_string}")
return False

# Dictionary to hold the matched files for each episode
Expand Down Expand Up @@ -411,6 +421,7 @@ def _is_wanted_show(self, container: dict, item: Show) -> bool:
]

if not filenames:
logger.debug(f"No valid files found for {item.log_string} matching your filter settings: min_size: {min_size}, max_size: {max_size}, wanted_formats: {WANTED_FORMATS}")
return False

# Create a dictionary to map seasons and episodes needed
Expand All @@ -424,6 +435,7 @@ def _is_wanted_show(self, container: dict, item: Show) -> bool:
needed_episodes[season.number] = needed_episode_numbers

if not any(needed_episodes.values()):
logger.debug(f"No needed episodes found for {item.log_string}")
return False

# logger.debug(f"Checking {len(filenames)} files in container for {item.log_string}")
Expand Down Expand Up @@ -496,18 +508,23 @@ def _is_downloaded(self, item: MediaItem) -> bool:
logger.debug(f"Set active files for item: {item.log_string} with {len(item.active_stream.get('files', {}))} total files")
return True

def _download_item(self, item: MediaItem):
def _download_item(self, item: MediaItem) -> bool:
"""Download item from real-debrid.com"""
logger.debug(f"Starting download for item: {item.log_string}")
request_id = self.add_magnet(item) # uses item.active_stream.hash
logger.debug(f"Magnet added to Real-Debrid, request ID: {request_id} for {item.log_string}")
item.set("active_stream.id", request_id)
self.set_active_files(item)
logger.debug(f"Active files set for item: {item.log_string} with {len(item.active_stream.get('files', {}))} total files")
time.sleep(0.5)
self.select_files(request_id, item)
logger.debug(f"Files selected for request ID: {request_id} for {item.log_string}")
logger.debug(f"Item marked as downloaded: {item.log_string}")
try:
logger.debug(f"Starting download for item: {item.log_string}")
request_id = self.add_magnet(item) # uses item.active_stream.hash
logger.debug(f"Magnet added to Real-Debrid, request ID: {request_id} for {item.log_string}")
item.set("active_stream.id", request_id)
self.set_active_files(item)
logger.debug(f"Active files set for item: {item.log_string} with {len(item.active_stream.get('files', {}))} total files")
time.sleep(0.5)
self.select_files(request_id, item)
logger.debug(f"Files selected for request ID: {request_id} for {item.log_string}")
logger.debug(f"Item marked as downloaded: {item.log_string}")
return True
except Exception as e:
logger.error(f"Error downloading item: {item.log_string}: {e}")
return False

def set_active_files(self, item: MediaItem) -> None:
"""Set active files for item from real-debrid.com"""
Expand Down Expand Up @@ -548,7 +565,7 @@ def set_active_files(self, item: MediaItem) -> None:

### API Methods for Real-Debrid below

def add_magnet(self, item: MediaItem) -> str:
def add_magnet(self, item: MediaItem) -> str | None:
"""Add magnet link to real-debrid.com"""
if not item.active_stream.get("hash"):
logger.error(f"No active stream or hash found for {item.log_string}")
Expand Down Expand Up @@ -610,7 +627,7 @@ def select_files(self, request_id: str, item: MediaItem) -> bool:
specific_rate_limiter=self.torrents_rate_limiter,
overall_rate_limiter=self.overall_rate_limiter
)
return response.is_ok
return response.is_ok == True
except Exception as e:
logger.error(f"Error selecting files for {item.log_string}: {e}")
return False
Expand Down Expand Up @@ -667,7 +684,7 @@ def check_episode():
return True
return False

def check_season(season):
def check_season(season: "Season"):
season_number = season.number
episodes_in_season = {episode.number for episode in season.episodes}
matched_episodes = set()
Expand Down
9 changes: 5 additions & 4 deletions src/program/indexers/trakt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,22 @@ def __init__(self):
self.initialized = True
self.settings = settings_manager.settings.indexer

def copy_attributes(self, source, target):
@staticmethod
def copy_attributes(source, target):
"""Copy attributes from source to target."""
attributes = ["file", "folder", "update_folder", "symlinked", "is_anime", "symlink_path", "subtitles"]
attributes = ["file", "folder", "update_folder", "symlinked", "is_anime", "symlink_path", "subtitles", "requested_by", "requested_at", "overseerr_id", "active_stream", "requested_id"]
for attr in attributes:
target.set(attr, getattr(source, attr, None))

def copy_items(self, itema: MediaItem, itemb: MediaItem):
"""Copy attributes from itema to itemb recursively."""
if isinstance(itema, Show) and isinstance(itemb, Show):
if isinstance(itema, (MediaItem, Show)) and isinstance(itemb, Show):
for seasona, seasonb in zip(itema.seasons, itemb.seasons):
for episodea, episodeb in zip(seasona.episodes, seasonb.episodes):
self.copy_attributes(episodea, episodeb)
seasonb.set("is_anime", itema.is_anime)
itemb.set("is_anime", itema.is_anime)
elif isinstance(itema, Movie) and isinstance(itemb, Movie):
elif isinstance(itema, (MediaItem, Movie)) and isinstance(itemb, Movie):
self.copy_attributes(itema, itemb)
return itemb

Expand Down
14 changes: 11 additions & 3 deletions src/program/media/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from program.db.db import db
from program.media.state import States
from RTN import parse
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, mapped_column, relationship, object_session

from program.media.subtitle import Subtitle
from .stream import Stream
Expand All @@ -28,6 +28,7 @@ class MediaItem(db.Model):
type: Mapped[str] = mapped_column(sqlalchemy.String, nullable=False)
requested_at: Mapped[Optional[datetime]] = mapped_column(sqlalchemy.DateTime, default=datetime.now())
requested_by: Mapped[Optional[str]] = mapped_column(sqlalchemy.String, nullable=True)
requested_id: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, nullable=True)
indexed_at: Mapped[Optional[datetime]] = mapped_column(sqlalchemy.DateTime, nullable=True)
scraped_at: Mapped[Optional[datetime]] = mapped_column(sqlalchemy.DateTime, nullable=True)
scraped_times: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, default=0)
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(self, item: dict | None) -> None:
return
self.requested_at = item.get("requested_at", datetime.now())
self.requested_by = item.get("requested_by")
self.requested_id = item.get("requested_id")

self.indexed_at = None

Expand Down Expand Up @@ -135,9 +137,12 @@ def store_state(self) -> None:
if self.last_state != self._determine_state().name:
ws_manager.send_item_update(json.dumps(self.to_dict()))
self.last_state = self._determine_state().name

def is_stream_blacklisted(self, stream: Stream):
"""Check if a stream is blacklisted for this item."""
session = object_session(self)
if session:
session.refresh(self, attribute_names=['blacklisted_streams'])
return stream in self.blacklisted_streams

def blacklist_stream(self, stream: Stream):
Expand Down Expand Up @@ -200,8 +205,11 @@ def copy_other_media_attr(self, other):
self.overseerr_id = getattr(other, "overseerr_id", None)

def is_scraped(self):
session = object_session(self)
if session:
session.refresh(self, attribute_names=['blacklisted_streams']) # Prom: Ensure these reflect the state of whats in the db.
return (len(self.streams) > 0
and any(not stream in self.blacklisted_streams for stream in self.streams))
and any(not stream in self.blacklisted_streams for stream in self.streams))

def to_dict(self):
"""Convert item to dictionary (API response)"""
Expand Down
19 changes: 3 additions & 16 deletions src/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,30 +202,17 @@ def _retry_library(self) -> None:
for page_number in range(0, (count // number_of_rows_per_page) + 1):
with db.Session() as session:
items_to_submit = session.execute(
select(
MediaItem._id,
MediaItem.type,
MediaItem.last_state,
MediaItem.requested_at,
MediaItem.imdb_id
)
select(MediaItem)
.where(MediaItem.type.in_(["movie", "show"]))
.where(MediaItem.last_state != "Completed")
.order_by(MediaItem.requested_at.desc())
.limit(number_of_rows_per_page)
.offset(page_number * number_of_rows_per_page)
).all()
).unique().scalars().all()

session.expunge_all()
session.close()

for item_data in items_to_submit:
item = MediaItem(None)
item._id = item_data[0]
item.type = item_data[1]
item.last_state = item_data[2]
item.requested_at = item_data[3]
item.imdb_id = item_data[4]
for item in items_to_submit:
self.em.add_event(Event(emitted_by="RetryLibrary", item=item))

def _schedule_functions(self) -> None:
Expand Down
Loading

0 comments on commit 276ed79

Please sign in to comment.