diff --git a/src/program/db/db_functions.py b/src/program/db/db_functions.py index 72803fc5..551a87a5 100644 --- a/src/program/db/db_functions.py +++ b/src/program/db/db_functions.py @@ -1,5 +1,6 @@ import os import shutil +from threading import Event from typing import TYPE_CHECKING import alembic @@ -171,15 +172,9 @@ def reset_media_item(item: "MediaItem"): item.reset() session.commit() -def reset_streams(item: "MediaItem", active_stream_hash: str = None): +def reset_streams(item: "MediaItem"): """Reset streams associated with a MediaItem.""" with db.Session() as session: - item.store_state() - item = session.merge(item) - if active_stream_hash: - stream = session.query(Stream).filter(Stream.infohash == active_stream_hash).first() - if stream: - blacklist_stream(item, stream, session) session.execute( delete(StreamRelation).where(StreamRelation.parent_id == item._id) @@ -188,20 +183,11 @@ def reset_streams(item: "MediaItem", active_stream_hash: str = None): session.execute( delete(StreamBlacklistRelation).where(StreamBlacklistRelation.media_item_id == item._id) ) - item.active_stream = {} session.commit() def clear_streams(item: "MediaItem"): """Clear all streams for a media item.""" - with db.Session() as session: - item = session.merge(item) - session.execute( - delete(StreamRelation).where(StreamRelation.parent_id == item._id) - ) - session.execute( - delete(StreamBlacklistRelation).where(StreamBlacklistRelation.media_item_id == item._id) - ) - session.commit() + reset_streams(item) def clear_streams_by_id(media_item_id: int): """Clear all streams for a media item by the MediaItem _id.""" @@ -358,7 +344,7 @@ def store_item(item: "MediaItem"): finally: session.close() -def run_thread_with_db_item(fn, service, program, input_id: int = None): +def run_thread_with_db_item(fn, service, program, input_id, cancellation_event: Event): from program.media.item import MediaItem if input_id: with db.Session() as session: @@ -378,11 +364,12 @@ def run_thread_with_db_item(fn, service, program, input_id: int = None): logger.log("PROGRAM", f"Service {service.__name__} emitted {item} from input item {input_item} of type {type(item).__name__}, backing off.") program.em.remove_id_from_queues(input_item._id) - input_item.store_state() - session.commit() + if not cancellation_event.is_set(): + input_item.store_state() + session.commit() session.expunge_all() - yield res + return res else: # Indexing returns a copy of the item, was too lazy to create a copy attr func so this will do for now indexed_item = next(fn(input_item), None) @@ -393,9 +380,10 @@ def run_thread_with_db_item(fn, service, program, input_id: int = None): indexed_item.store_state() session.delete(input_item) indexed_item = session.merge(indexed_item) - session.commit() - logger.debug(f"{input_item._id} is now {indexed_item._id} after indexing...") - yield indexed_item._id + if not cancellation_event.is_set(): + session.commit() + logger.debug(f"{input_item._id} is now {indexed_item._id} after indexing...") + return indexed_item._id return else: # Content services diff --git a/src/program/media/item.py b/src/program/media/item.py index e00ffd3a..5f17b13d 100644 --- a/src/program/media/item.py +++ b/src/program/media/item.py @@ -132,8 +132,8 @@ def __init__(self, item: dict | None) -> None: #Post processing self.subtitles = item.get("subtitles", []) - def store_state(self) -> None: - new_state = self._determine_state() + def store_state(self, given_state=None) -> None: + new_state = given_state if given_state else self._determine_state() if self.last_state and self.last_state != new_state: sse_manager.publish_event("item_update", {"last_state": self.last_state, "new_state": new_state, "item_id": self._id}) self.last_state = new_state @@ -145,6 +145,10 @@ def is_stream_blacklisted(self, stream: Stream): session.refresh(self, attribute_names=['blacklisted_streams']) return stream in self.blacklisted_streams + def blacklist_active_stream(self): + stream = next(stream for stream in self.streams if stream.infohash == self.active_stream["infohash"]) + self.blacklist_stream(stream) + def blacklist_stream(self, stream: Stream): value = blacklist_stream(self, stream) if value: @@ -321,20 +325,23 @@ def get_aliases(self) -> dict: def __hash__(self): return hash(self._id) - def reset(self, soft_reset: bool = False): + def reset(self): """Reset item attributes.""" if self.type == "show": for season in self.seasons: for episode in season.episodes: - episode._reset(soft_reset) - season._reset(soft_reset) + episode._reset() + season._reset() elif self.type == "season": for episode in self.episodes: - episode._reset(soft_reset) - self._reset(soft_reset) - self.store_state() + episode._reset() + self._reset() + if self.title: + self.store_state(States.Indexed) + else: + self.store_state(States.Requested) - def _reset(self, soft_reset): + def _reset(self): """Reset item attributes for rescraping.""" if self.symlink_path: if Path(self.symlink_path).exists(): @@ -351,16 +358,8 @@ def _reset(self, soft_reset): self.set("folder", None) self.set("alternative_folder", None) - if not self.active_stream: - self.active_stream = {} - if not soft_reset: - if self.active_stream.get("infohash", False): - reset_streams(self, self.active_stream["infohash"]) - else: - if self.active_stream.get("infohash", False): - stream = next((stream for stream in self.streams if stream.infohash == self.active_stream["infohash"]), None) - if stream: - self.blacklist_stream(stream) + reset_streams(self) + self.active_stream = {} self.set("active_stream", {}) self.set("symlinked", False) @@ -371,7 +370,7 @@ def _reset(self, soft_reset): self.set("symlinked_times", 0) self.set("scraped_times", 0) - logger.debug(f"Item {self.log_string} reset for rescraping") + logger.debug(f"Item {self.log_string} has been reset") @property def log_string(self): @@ -456,10 +455,10 @@ def _determine_state(self): return States.Requested return States.Unknown - def store_state(self) -> None: + def store_state(self, given_state: States =None) -> None: for season in self.seasons: - season.store_state() - super().store_state() + season.store_state(given_state) + super().store_state(given_state) def __repr__(self): return f"Show:{self.log_string}:{self.state.name}" @@ -527,10 +526,10 @@ class Season(MediaItem): "polymorphic_load": "inline", } - def store_state(self) -> None: + def store_state(self, given_state: States = None) -> None: for episode in self.episodes: - episode.store_state() - super().store_state() + episode.store_state(given_state) + super().store_state(given_state) def __init__(self, item): self.type = "season" diff --git a/src/program/symlink.py b/src/program/symlink.py index 0123ab0f..8f3cf7bb 100644 --- a/src/program/symlink.py +++ b/src/program/symlink.py @@ -94,7 +94,8 @@ def run(self, item: Union[Movie, Show, Season, Episode]): if not self._should_submit(items): if item.symlinked_times == 5: logger.debug(f"Soft resetting {item.log_string} because required files were not found") - item.reset(True) + item.blacklist_active_stream() + item.reset() yield item next_attempt = self._calculate_next_attempt(item) logger.debug(f"Waiting for {item.log_string} to become available, next attempt in {round((next_attempt - datetime.now()).total_seconds())} seconds") diff --git a/src/routers/secure/items.py b/src/routers/secure/items.py index 91347400..4b8ee20e 100644 --- a/src/routers/secure/items.py +++ b/src/routers/secure/items.py @@ -533,7 +533,7 @@ def set_torrent_rd(request: Request, id: int, torrent_id: str) -> SetTorrentRDRe # downloader = request.app.program.services.get(Downloader).service # with db.Session() as session: # item = session.execute(select(MediaItem).where(MediaItem._id == id)).unique().scalar_one() -# item.reset(True) +# item.reset() # downloader.download_cached(item, hash) # request.app.program.add_to_queue(item) # return {"success": True, "message": f"Downloading {item.title} with hash {hash}"} diff --git a/src/utils/event_manager.py b/src/utils/event_manager.py index 32c1c442..ab3e19b4 100644 --- a/src/utils/event_manager.py +++ b/src/utils/event_manager.py @@ -1,4 +1,5 @@ import os +import threading import traceback from datetime import datetime @@ -8,8 +9,7 @@ from loguru import logger from pydantic import BaseModel -from sqlalchemy.orm.exc import StaleDataError -from concurrent.futures import CancelledError, Future, ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor from utils.sse_manager import sse_manager from program.db.db import db @@ -37,6 +37,7 @@ def __init__(self): self._futures: list[Future] = [] self._queued_events: list[Event] = [] self._running_events: list[Event] = [] + self._canceled_futures: list[Future] = [] self.mutex = Lock() def _find_or_create_executor(self, service_cls) -> ThreadPoolExecutor: @@ -71,7 +72,7 @@ def _process_future(self, future, service): service (type): The service class associated with the future. """ try: - result = next(future.result(), None) + result = future.result() if future in self._futures: self._futures.remove(future) sse_manager.publish_event("event_update", self.get_event_updates()) @@ -81,10 +82,10 @@ def _process_future(self, future, service): item_id, timestamp = result, datetime.now() if item_id: self.remove_event_from_running(item_id) + if future.cancellation_event.is_set(): + logger.debug(f"Future with Item ID: {item_id} was cancelled discarding results...") + return self.add_event(Event(emitted_by=service, item_id=item_id, run_at=timestamp)) - except (StaleDataError, CancelledError): - # Expected behavior when cancelling tasks or when the item was removed - return except Exception as e: logger.error(f"Error in future for {future}: {e}") logger.exception(traceback.format_exc()) @@ -166,8 +167,10 @@ def submit_job(self, service, program, event=None): log_message += f" with Item ID: {item_id}" logger.debug(log_message) + cancellation_event = threading.Event() executor = self._find_or_create_executor(service) - future = executor.submit(run_thread_with_db_item, program.all_services[service].run, service, program, item_id) + future = executor.submit(run_thread_with_db_item, program.all_services[service].run, service, program, item_id, cancellation_event) + future.cancellation_event = cancellation_event if event: future.event = event self._futures.append(future) @@ -186,27 +189,25 @@ def cancel_job(self, item_id: int, suppress_logs=False): item_id, related_ids = get_item_ids(session, item_id) ids_to_cancel = set([item_id] + related_ids) - futures_to_remove = [] for future in self._futures: future_item_id = None future_related_ids = [] - if hasattr(future, 'event') and hasattr(future.event, 'item'): + if hasattr(future, 'event') and hasattr(future.event, 'item_id'): future_item = future.event.item_id future_item_id, future_related_ids = get_item_ids(session, future_item) if future_item_id in ids_to_cancel or any(rid in ids_to_cancel for rid in future_related_ids): self.remove_id_from_queues(future_item) - futures_to_remove.append(future) if not future.done() and not future.cancelled(): try: + future.cancellation_event.set() future.cancel() + self._canceled_futures.append(future) except Exception as e: if not suppress_logs: logger.error(f"Error cancelling future for {future_item.log_string}: {str(e)}") - for future in futures_to_remove: - self._futures.remove(future) logger.debug(f"Canceled jobs for Item ID {item_id} and its children.")