Skip to content

Commit

Permalink
fix: future cancellation resulted in reset, retry endpoints fialing (#…
Browse files Browse the repository at this point in the history
…817)

* fix: future cancellation resulted in reset, retry endpoints fialing

* fix: update reset func to check if indexed

---------

Co-authored-by: Gaisberg <None>
Co-authored-by: Spoked <dreu.lavelle@gmail.com>
  • Loading branch information
Gaisberg and dreulavelle authored Oct 26, 2024
1 parent 2676fe8 commit 19cedc8
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 64 deletions.
36 changes: 12 additions & 24 deletions src/program/db/db_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import shutil
from threading import Event
from typing import TYPE_CHECKING

import alembic
Expand Down Expand Up @@ -171,15 +172,9 @@ def reset_media_item(item: "MediaItem"):
item.reset()
session.commit()

def reset_streams(item: "MediaItem", active_stream_hash: str = None):
def reset_streams(item: "MediaItem"):
"""Reset streams associated with a MediaItem."""
with db.Session() as session:
item.store_state()
item = session.merge(item)
if active_stream_hash:
stream = session.query(Stream).filter(Stream.infohash == active_stream_hash).first()
if stream:
blacklist_stream(item, stream, session)

session.execute(
delete(StreamRelation).where(StreamRelation.parent_id == item._id)
Expand All @@ -188,20 +183,11 @@ def reset_streams(item: "MediaItem", active_stream_hash: str = None):
session.execute(
delete(StreamBlacklistRelation).where(StreamBlacklistRelation.media_item_id == item._id)
)
item.active_stream = {}
session.commit()

def clear_streams(item: "MediaItem"):
"""Clear all streams for a media item."""
with db.Session() as session:
item = session.merge(item)
session.execute(
delete(StreamRelation).where(StreamRelation.parent_id == item._id)
)
session.execute(
delete(StreamBlacklistRelation).where(StreamBlacklistRelation.media_item_id == item._id)
)
session.commit()
reset_streams(item)

def clear_streams_by_id(media_item_id: int):
"""Clear all streams for a media item by the MediaItem _id."""
Expand Down Expand Up @@ -358,7 +344,7 @@ def store_item(item: "MediaItem"):
finally:
session.close()

def run_thread_with_db_item(fn, service, program, input_id: int = None):
def run_thread_with_db_item(fn, service, program, input_id, cancellation_event: Event):
from program.media.item import MediaItem
if input_id:
with db.Session() as session:
Expand All @@ -378,11 +364,12 @@ def run_thread_with_db_item(fn, service, program, input_id: int = None):
logger.log("PROGRAM", f"Service {service.__name__} emitted {item} from input item {input_item} of type {type(item).__name__}, backing off.")
program.em.remove_id_from_queues(input_item._id)

input_item.store_state()
session.commit()
if not cancellation_event.is_set():
input_item.store_state()
session.commit()

session.expunge_all()
yield res
return res
else:
# Indexing returns a copy of the item, was too lazy to create a copy attr func so this will do for now
indexed_item = next(fn(input_item), None)
Expand All @@ -393,9 +380,10 @@ def run_thread_with_db_item(fn, service, program, input_id: int = None):
indexed_item.store_state()
session.delete(input_item)
indexed_item = session.merge(indexed_item)
session.commit()
logger.debug(f"{input_item._id} is now {indexed_item._id} after indexing...")
yield indexed_item._id
if not cancellation_event.is_set():
session.commit()
logger.debug(f"{input_item._id} is now {indexed_item._id} after indexing...")
return indexed_item._id
return
else:
# Content services
Expand Down
51 changes: 25 additions & 26 deletions src/program/media/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def __init__(self, item: dict | None) -> None:
#Post processing
self.subtitles = item.get("subtitles", [])

def store_state(self) -> None:
new_state = self._determine_state()
def store_state(self, given_state=None) -> None:
new_state = given_state if given_state else self._determine_state()
if self.last_state and self.last_state != new_state:
sse_manager.publish_event("item_update", {"last_state": self.last_state, "new_state": new_state, "item_id": self._id})
self.last_state = new_state
Expand All @@ -145,6 +145,10 @@ def is_stream_blacklisted(self, stream: Stream):
session.refresh(self, attribute_names=['blacklisted_streams'])
return stream in self.blacklisted_streams

def blacklist_active_stream(self):
stream = next(stream for stream in self.streams if stream.infohash == self.active_stream["infohash"])
self.blacklist_stream(stream)

def blacklist_stream(self, stream: Stream):
value = blacklist_stream(self, stream)
if value:
Expand Down Expand Up @@ -321,20 +325,23 @@ def get_aliases(self) -> dict:
def __hash__(self):
return hash(self._id)

def reset(self, soft_reset: bool = False):
def reset(self):
"""Reset item attributes."""
if self.type == "show":
for season in self.seasons:
for episode in season.episodes:
episode._reset(soft_reset)
season._reset(soft_reset)
episode._reset()
season._reset()
elif self.type == "season":
for episode in self.episodes:
episode._reset(soft_reset)
self._reset(soft_reset)
self.store_state()
episode._reset()
self._reset()
if self.title:
self.store_state(States.Indexed)
else:
self.store_state(States.Requested)

def _reset(self, soft_reset):
def _reset(self):
"""Reset item attributes for rescraping."""
if self.symlink_path:
if Path(self.symlink_path).exists():
Expand All @@ -351,16 +358,8 @@ def _reset(self, soft_reset):
self.set("folder", None)
self.set("alternative_folder", None)

if not self.active_stream:
self.active_stream = {}
if not soft_reset:
if self.active_stream.get("infohash", False):
reset_streams(self, self.active_stream["infohash"])
else:
if self.active_stream.get("infohash", False):
stream = next((stream for stream in self.streams if stream.infohash == self.active_stream["infohash"]), None)
if stream:
self.blacklist_stream(stream)
reset_streams(self)
self.active_stream = {}

self.set("active_stream", {})
self.set("symlinked", False)
Expand All @@ -371,7 +370,7 @@ def _reset(self, soft_reset):
self.set("symlinked_times", 0)
self.set("scraped_times", 0)

logger.debug(f"Item {self.log_string} reset for rescraping")
logger.debug(f"Item {self.log_string} has been reset")

@property
def log_string(self):
Expand Down Expand Up @@ -456,10 +455,10 @@ def _determine_state(self):
return States.Requested
return States.Unknown

def store_state(self) -> None:
def store_state(self, given_state: States =None) -> None:
for season in self.seasons:
season.store_state()
super().store_state()
season.store_state(given_state)
super().store_state(given_state)

def __repr__(self):
return f"Show:{self.log_string}:{self.state.name}"
Expand Down Expand Up @@ -527,10 +526,10 @@ class Season(MediaItem):
"polymorphic_load": "inline",
}

def store_state(self) -> None:
def store_state(self, given_state: States = None) -> None:
for episode in self.episodes:
episode.store_state()
super().store_state()
episode.store_state(given_state)
super().store_state(given_state)

def __init__(self, item):
self.type = "season"
Expand Down
3 changes: 2 additions & 1 deletion src/program/symlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def run(self, item: Union[Movie, Show, Season, Episode]):
if not self._should_submit(items):
if item.symlinked_times == 5:
logger.debug(f"Soft resetting {item.log_string} because required files were not found")
item.reset(True)
item.blacklist_active_stream()
item.reset()
yield item
next_attempt = self._calculate_next_attempt(item)
logger.debug(f"Waiting for {item.log_string} to become available, next attempt in {round((next_attempt - datetime.now()).total_seconds())} seconds")
Expand Down
2 changes: 1 addition & 1 deletion src/routers/secure/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def set_torrent_rd(request: Request, id: int, torrent_id: str) -> SetTorrentRDRe
# downloader = request.app.program.services.get(Downloader).service
# with db.Session() as session:
# item = session.execute(select(MediaItem).where(MediaItem._id == id)).unique().scalar_one()
# item.reset(True)
# item.reset()
# downloader.download_cached(item, hash)
# request.app.program.add_to_queue(item)
# return {"success": True, "message": f"Downloading {item.title} with hash {hash}"}
25 changes: 13 additions & 12 deletions src/utils/event_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import threading
import traceback

from datetime import datetime
Expand All @@ -8,8 +9,7 @@

from loguru import logger
from pydantic import BaseModel
from sqlalchemy.orm.exc import StaleDataError
from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor

from utils.sse_manager import sse_manager
from program.db.db import db
Expand Down Expand Up @@ -37,6 +37,7 @@ def __init__(self):
self._futures: list[Future] = []
self._queued_events: list[Event] = []
self._running_events: list[Event] = []
self._canceled_futures: list[Future] = []
self.mutex = Lock()

def _find_or_create_executor(self, service_cls) -> ThreadPoolExecutor:
Expand Down Expand Up @@ -71,7 +72,7 @@ def _process_future(self, future, service):
service (type): The service class associated with the future.
"""
try:
result = next(future.result(), None)
result = future.result()
if future in self._futures:
self._futures.remove(future)
sse_manager.publish_event("event_update", self.get_event_updates())
Expand All @@ -81,10 +82,10 @@ def _process_future(self, future, service):
item_id, timestamp = result, datetime.now()
if item_id:
self.remove_event_from_running(item_id)
if future.cancellation_event.is_set():
logger.debug(f"Future with Item ID: {item_id} was cancelled discarding results...")
return
self.add_event(Event(emitted_by=service, item_id=item_id, run_at=timestamp))
except (StaleDataError, CancelledError):
# Expected behavior when cancelling tasks or when the item was removed
return
except Exception as e:
logger.error(f"Error in future for {future}: {e}")
logger.exception(traceback.format_exc())
Expand Down Expand Up @@ -166,8 +167,10 @@ def submit_job(self, service, program, event=None):
log_message += f" with Item ID: {item_id}"
logger.debug(log_message)

cancellation_event = threading.Event()
executor = self._find_or_create_executor(service)
future = executor.submit(run_thread_with_db_item, program.all_services[service].run, service, program, item_id)
future = executor.submit(run_thread_with_db_item, program.all_services[service].run, service, program, item_id, cancellation_event)
future.cancellation_event = cancellation_event
if event:
future.event = event
self._futures.append(future)
Expand All @@ -186,27 +189,25 @@ def cancel_job(self, item_id: int, suppress_logs=False):
item_id, related_ids = get_item_ids(session, item_id)
ids_to_cancel = set([item_id] + related_ids)

futures_to_remove = []
for future in self._futures:
future_item_id = None
future_related_ids = []

if hasattr(future, 'event') and hasattr(future.event, 'item'):
if hasattr(future, 'event') and hasattr(future.event, 'item_id'):
future_item = future.event.item_id
future_item_id, future_related_ids = get_item_ids(session, future_item)

if future_item_id in ids_to_cancel or any(rid in ids_to_cancel for rid in future_related_ids):
self.remove_id_from_queues(future_item)
futures_to_remove.append(future)
if not future.done() and not future.cancelled():
try:
future.cancellation_event.set()
future.cancel()
self._canceled_futures.append(future)
except Exception as e:
if not suppress_logs:
logger.error(f"Error cancelling future for {future_item.log_string}: {str(e)}")

for future in futures_to_remove:
self._futures.remove(future)

logger.debug(f"Canceled jobs for Item ID {item_id} and its children.")

Expand Down

0 comments on commit 19cedc8

Please sign in to comment.