diff --git a/docker-compose.yml b/docker-compose.yml
index 26e724f9..933c9d17 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -44,7 +44,7 @@ services:
         condition: service_healthy
   riven_postgres:
-    image: postgres:16.3-alpine3.20
+    image: postgres:17.0-alpine3.20
     container_name: riven-db
     environment:
       PGDATA: /var/lib/postgresql/data/pgdata
diff --git a/src/controllers/items.py b/src/controllers/items.py
index e213f562..ded2335a 100644
--- a/src/controllers/items.py
+++ b/src/controllers/items.py
@@ -201,44 +201,41 @@ async def get_items_by_imdb_ids(request: Request, imdb_ids: str):
     return {"success": True, "items": [item.to_extended_dict() for item in items]}
 
 @router.post(
-    "/reset",
-    summary="Reset Media Items",
-    description="Reset media items with bases on item IDs",
+    "/reset",
+    summary="Reset Media Items",
+    description="Reset media items based on item IDs",
 )
-async def reset_items(
-    request: Request, ids: str
-):
+async def reset_items(request: Request, ids: str):
     ids = handle_ids(ids)
     try:
-        media_items = get_media_items_by_ids(ids)
-        if not media_items or len(media_items) != len(ids):
-            raise ValueError("Invalid item ID(s) provided. Some items may not exist.")
-        for media_item in media_items:
+        media_items_generator = get_media_items_by_ids(ids)
+        for media_item in media_items_generator:
             try:
                 request.app.program.em.cancel_job(media_item)
                 clear_streams(media_item)
                 reset_media_item(media_item)
-            except Exception as e:
+            except ValueError as e:
                 logger.error(f"Failed to reset item with id {media_item._id}: {str(e)}")
                 continue
+            except Exception as e:
+                logger.error(f"Unexpected error while resetting item with id {media_item._id}: {str(e)}")
+                continue
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
     return {"success": True, "message": f"Reset items with id {ids}"}
 
 @router.post(
-    "/retry",
-    summary="Retry Media Items",
-    description="Retry media items with bases on item IDs",
+    "/retry",
+    summary="Retry Media Items",
+    description="Retry media items based on item IDs",
 )
 async def retry_items(request: Request, ids: str):
     ids = handle_ids(ids)
     try:
-        media_items = get_media_items_by_ids(ids)
-        if not media_items or len(media_items) != len(ids):
-            raise ValueError("Invalid item ID(s) provided. Some items may not exist.")
-        for media_item in media_items:
+        media_items_generator = get_media_items_by_ids(ids)
+        for media_item in media_items_generator:
             request.app.program.em.cancel_job(media_item)
-            await asyncio.sleep(0.1) # Ensure cancellation is processed
+            await asyncio.sleep(0.1)  # Ensure cancellation is processed
             request.app.program.em.add_item(media_item)
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
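Note on the two endpoints above: get_media_items_by_ids (reworked in src/program/db/db_functions.py further down) now yields items lazily instead of returning a list, so the old up-front length check against the requested ids is gone and per-item failures are logged and skipped instead of failing the whole request. A minimal, self-contained sketch of that consumption pattern; items_by_ids and reset_one are hypothetical stand-ins, not the project's code:

    from typing import Iterator

    def items_by_ids(ids: list[int]) -> Iterator[dict]:
        # Stand-in for get_media_items_by_ids: yields only the ids it knows about.
        known = {1: {"_id": 1}, 3: {"_id": 3}}
        for media_item_id in ids:
            if media_item_id in known:
                yield known[media_item_id]

    def reset_one(item: dict) -> None:
        # Hypothetical per-item operation that can fail for a single item.
        if item["_id"] == 3:
            raise ValueError("stream state out of sync")

    failed = []
    for item in items_by_ids([1, 2, 3]):
        try:
            reset_one(item)
        except ValueError as exc:   # expected, recoverable per-item failure
            failed.append((item["_id"], str(exc)))
        except Exception as exc:    # unexpected failure; keep the loop going
            failed.append((item["_id"], f"unexpected: {exc}"))

    print(failed)  # [(3, 'stream state out of sync')]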
Some items may not exist.") - for media_item in media_items: + media_items_generator = get_media_items_by_ids(ids) + for media_item in media_items_generator: request.app.program.em.cancel_job(media_item) - await asyncio.sleep(0.1) # Ensure cancellation is processed + await asyncio.sleep(0.1) # Ensure cancellation is processed request.app.program.em.add_item(media_item) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) diff --git a/src/program/db/db.py b/src/program/db/db.py index 67c706e6..8485b555 100644 --- a/src/program/db/db.py +++ b/src/program/db/db.py @@ -25,11 +25,11 @@ # cursor.execute("SET statement_timeout = 300000") # cursor.close() -db = SQLAlchemy(settings_manager.settings.database.host, engine_options=engine_options) +db_host = settings_manager.settings.database.host +db = SQLAlchemy(db_host, engine_options=engine_options) script_location = data_dir_path / "alembic/" - if not os.path.exists(script_location): os.makedirs(script_location) @@ -37,6 +37,19 @@ alembic.init(script_location) +def create_database_if_not_exists(): + """Create the database if it doesn't exist.""" + db_name = db_host.split('/')[-1] + db_base_host = '/'.join(db_host.split('/')[:-1]) + try: + temp_db = SQLAlchemy(db_base_host, engine_options=engine_options) + with temp_db.engine.connect() as connection: + connection.execution_options(isolation_level="AUTOCOMMIT").execute(text(f"CREATE DATABASE {db_name}")) + return True + except Exception as e: + logger.error(f"Failed to create database {db_name}: {e}") + return False + # https://stackoverflow.com/questions/61374525/how-do-i-check-if-alembic-migrations-need-to-be-generated def need_upgrade_check() -> bool: """Check if there are any pending migrations.""" diff --git a/src/program/db/db_functions.py b/src/program/db/db_functions.py index b49c7e50..0ad9f5b9 100644 --- a/src/program/db/db_functions.py +++ b/src/program/db/db_functions.py @@ -4,7 +4,7 @@ import alembic from sqlalchemy import delete, func, insert, select, text, union_all -from sqlalchemy.orm import Session, aliased, joinedload +from sqlalchemy.orm import Session, aliased, selectinload from program.libraries.symlink import fix_broken_symlinks from program.media.stream import Stream, StreamBlacklistRelation, StreamRelation @@ -21,41 +21,42 @@ def get_media_items_by_ids(media_item_ids: list[int]): """Retrieve multiple MediaItems by a list of MediaItem _ids using the _get_item_from_db method.""" from program.media.item import Episode, MediaItem, Movie, Season, Show - items = [] + + def get_item(session, media_item_id, item_type): + match item_type: + case "movie": + return session.execute( + select(Movie) + .where(MediaItem._id == media_item_id) + ).unique().scalar_one() + case "show": + return session.execute( + select(Show) + .where(MediaItem._id == media_item_id) + .options(selectinload(Show.seasons).selectinload(Season.episodes)) + ).unique().scalar_one() + case "season": + return session.execute( + select(Season) + .where(Season._id == media_item_id) + .options(selectinload(Season.episodes)) + ).unique().scalar_one() + case "episode": + return session.execute( + select(Episode) + .where(Episode._id == media_item_id) + ).unique().scalar_one() + case _: + return None with db.Session() as session: for media_item_id in media_item_ids: - item_type = session.execute(select(MediaItem.type).where(MediaItem._id==media_item_id)).scalar_one() + item_type = session.execute(select(MediaItem.type).where(MediaItem._id == media_item_id)).scalar_one() if not item_type: continue - 
diff --git a/src/program/db/db_functions.py b/src/program/db/db_functions.py
index b49c7e50..0ad9f5b9 100644
--- a/src/program/db/db_functions.py
+++ b/src/program/db/db_functions.py
@@ -4,7 +4,7 @@
 import alembic
 from sqlalchemy import delete, func, insert, select, text, union_all
-from sqlalchemy.orm import Session, aliased, joinedload
+from sqlalchemy.orm import Session, aliased, selectinload
 
 from program.libraries.symlink import fix_broken_symlinks
 from program.media.stream import Stream, StreamBlacklistRelation, StreamRelation
@@ -21,41 +21,42 @@ def get_media_items_by_ids(media_item_ids: list[int]):
     """Retrieve multiple MediaItems by a list of MediaItem _ids using the _get_item_from_db method."""
     from program.media.item import Episode, MediaItem, Movie, Season, Show
-    items = []
+
+    def get_item(session, media_item_id, item_type):
+        match item_type:
+            case "movie":
+                return session.execute(
+                    select(Movie)
+                    .where(MediaItem._id == media_item_id)
+                ).unique().scalar_one()
+            case "show":
+                return session.execute(
+                    select(Show)
+                    .where(MediaItem._id == media_item_id)
+                    .options(selectinload(Show.seasons).selectinload(Season.episodes))
+                ).unique().scalar_one()
+            case "season":
+                return session.execute(
+                    select(Season)
+                    .where(Season._id == media_item_id)
+                    .options(selectinload(Season.episodes))
+                ).unique().scalar_one()
+            case "episode":
+                return session.execute(
+                    select(Episode)
+                    .where(Episode._id == media_item_id)
+                ).unique().scalar_one()
+            case _:
+                return None
 
     with db.Session() as session:
         for media_item_id in media_item_ids:
-            item_type = session.execute(select(MediaItem.type).where(MediaItem._id==media_item_id)).scalar_one()
+            item_type = session.execute(select(MediaItem.type).where(MediaItem._id == media_item_id)).scalar_one_or_none()
             if not item_type:
                 continue
-            item = None
-            match item_type:
-                case "movie":
-                    item = session.execute(
-                        select(Movie)
-                        .where(MediaItem._id == media_item_id)
-                    ).unique().scalar_one()
-                case "show":
-                    item = session.execute(
-                        select(Show)
-                        .where(MediaItem._id == media_item_id)
-                        .options(joinedload(Show.seasons).joinedload(Season.episodes))
-                    ).unique().scalar_one()
-                case "season":
-                    item = session.execute(
-                        select(Season)
-                        .where(Season._id == media_item_id)
-                        .options(joinedload(Season.episodes))
-                    ).unique().scalar_one()
-                case "episode":
-                    item = session.execute(
-                        select(Episode)
-                        .where(Episode._id == media_item_id)
-                    ).unique().scalar_one()
+            item = get_item(session, media_item_id, item_type)
             if item:
-                items.append(item)
-
-    return items
+                yield item
 
 
 def get_parent_items_by_ids(media_item_ids: list[int]):
     """Retrieve multiple MediaItems of type 'movie' or 'show' by a list of MediaItem _ids."""
@@ -312,36 +313,29 @@ def _get_item_from_db(session, item: "MediaItem"):
     if not _ensure_item_exists_in_db(item):
         return None
     session.expire_on_commit = False
-    type = _get_item_type_from_db(item)
-    match type:
+    match item.type:
         case "movie":
-            r = session.execute(
+            return session.execute(
                 select(Movie)
                 .where(MediaItem.imdb_id == item.imdb_id)
-            ).unique().scalar_one()
-            return r
+            ).unique().scalar_one_or_none()
         case "show":
-            r = session.execute(
+            return session.execute(
                 select(Show)
                 .where(MediaItem.imdb_id == item.imdb_id)
-                .options(joinedload(Show.seasons).joinedload(Season.episodes))
-            ).unique().scalar_one()
-            return r
+            ).unique().scalar_one_or_none()
         case "season":
-            r = session.execute(
+            return session.execute(
                 select(Season)
                 .where(Season._id == item._id)
-                .options(joinedload(Season.episodes))
-            ).unique().scalar_one()
-            return r
+            ).unique().scalar_one_or_none()
         case "episode":
-            r = session.execute(
+            return session.execute(
                 select(Episode)
                 .where(Episode._id == item._id)
-            ).unique().scalar_one()
-            return r
+            ).unique().scalar_one_or_none()
         case _:
-            logger.error(f"_get_item_from_db Failed to create item from type: {type}")
+            logger.error(f"_get_item_from_db Failed to create item from type: {item.type}")
             return None
 
 def _check_for_and_run_insertion_required(session, item: "MediaItem") -> bool:
diff --git a/src/program/media/item.py b/src/program/media/item.py
index 82dd42ea..7bb386b7 100644
--- a/src/program/media/item.py
+++ b/src/program/media/item.py
@@ -33,8 +33,8 @@ class MediaItem(db.Model):
     scraped_at: Mapped[Optional[datetime]] = mapped_column(sqlalchemy.DateTime, nullable=True)
     scraped_times: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, default=0)
     active_stream: Mapped[Optional[dict]] = mapped_column(sqlalchemy.JSON, nullable=True)
-    streams: Mapped[list[Stream]] = relationship(secondary="StreamRelation", back_populates="parents", lazy="select", cascade="all")
-    blacklisted_streams: Mapped[list[Stream]] = relationship(secondary="StreamBlacklistRelation", back_populates="blacklisted_parents", lazy="select", cascade="all")
+    streams: Mapped[list[Stream]] = relationship(secondary="StreamRelation", back_populates="parents", lazy="selectin", cascade="all")
+    blacklisted_streams: Mapped[list[Stream]] = relationship(secondary="StreamBlacklistRelation", back_populates="blacklisted_parents", lazy="selectin", cascade="all")
     symlinked: Mapped[Optional[bool]] = mapped_column(sqlalchemy.Boolean, default=False)
     symlinked_at: Mapped[Optional[datetime]] = mapped_column(sqlalchemy.DateTime, nullable=True)
     symlinked_times: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, default=0)
@@ -59,7 +59,7 @@ class MediaItem(db.Model):
     update_folder: Mapped[Optional[str]] = mapped_column(sqlalchemy.String, nullable=True)
     overseerr_id: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, nullable=True)
     last_state: Mapped[Optional[States]] = mapped_column(sqlalchemy.Enum(States), default=States.Unknown)
-    subtitles: Mapped[list[Subtitle]] = relationship(Subtitle, back_populates="parent", lazy="joined", cascade="all, delete-orphan")
+    subtitles: Mapped[list[Subtitle]] = relationship(Subtitle, back_populates="parent", lazy="selectin", cascade="all, delete-orphan")
 
     __mapper_args__ = {
         "polymorphic_identity": "mediaitem",
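Context for the joinedload to selectinload switch and the lazy="selectin" relationship defaults above: selectin loading fetches child collections with a second SELECT ... WHERE <parent id> IN (...) instead of multiplying parent rows through a JOIN, which generally suits nested collections like Show, Season and Episode. A small illustration with throwaway models; this uses in-memory SQLite and is not the project's schema:

    from sqlalchemy import ForeignKey, create_engine, select
    from sqlalchemy.orm import (DeclarativeBase, Mapped, Session, mapped_column,
                                relationship, selectinload)

    class Base(DeclarativeBase):
        pass

    class Show(Base):
        __tablename__ = "show"
        id: Mapped[int] = mapped_column(primary_key=True)
        # lazy="selectin" loads seasons for a batch of shows in one extra IN-based query
        seasons: Mapped[list["Season"]] = relationship(back_populates="show", lazy="selectin")

    class Season(Base):
        __tablename__ = "season"
        id: Mapped[int] = mapped_column(primary_key=True)
        show_id: Mapped[int] = mapped_column(ForeignKey("show.id"))
        show: Mapped["Show"] = relationship(back_populates="seasons")

    engine = create_engine("sqlite://", echo=True)  # echo shows the emitted SELECT ... IN (...)
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Show(seasons=[Season(), Season()]))
        session.commit()
        # Explicit option form, equivalent to the lazy="selectin" default on the relationship:
        show = session.execute(select(Show).options(selectinload(Show.seasons))).scalar_one()
        print(len(show.seasons))  # 2, already loaded; no extra lazy query on access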
diff --git a/src/program/media/stream.py b/src/program/media/stream.py
index 8a4d5aad..4182ce7e 100644
--- a/src/program/media/stream.py
+++ b/src/program/media/stream.py
@@ -47,8 +47,8 @@ class Stream(db.Model):
     rank: Mapped[int] = mapped_column(sqlalchemy.Integer, nullable=False)
     lev_ratio: Mapped[float] = mapped_column(sqlalchemy.Float, nullable=False)
 
-    parents: Mapped[list["MediaItem"]] = relationship(secondary="StreamRelation", back_populates="streams")
-    blacklisted_parents: Mapped[list["MediaItem"]] = relationship(secondary="StreamBlacklistRelation", back_populates="blacklisted_streams")
+    parents: Mapped[list["MediaItem"]] = relationship(secondary="StreamRelation", back_populates="streams", lazy="selectin")
+    blacklisted_parents: Mapped[list["MediaItem"]] = relationship(secondary="StreamBlacklistRelation", back_populates="blacklisted_streams", lazy="selectin")
 
     __table_args__ = (
         Index('ix_stream_infohash', 'infohash'),
diff --git a/src/program/program.py b/src/program/program.py
index af869329..f50e20e5 100644
--- a/src/program/program.py
+++ b/src/program/program.py
@@ -36,7 +36,7 @@
 from sqlalchemy import func, select, text
 
 import program.db.db_functions as DB
-from program.db.db import db, run_migrations, vacuum_and_analyze_index_maintenance
+from program.db.db import create_database_if_not_exists, db, run_migrations, vacuum_and_analyze_index_maintenance
 
 
 class Program(threading.Thread):
@@ -136,7 +136,11 @@ def start(self):
         if not self.validate_database():
             # We should really make this configurable via frontend...
-            return
+            logger.log("PROGRAM", "Database not found, trying to create database")
+            if not create_database_if_not_exists():
+                logger.error("Failed to create database, exiting")
+                return
+            logger.success("Database created successfully")
         run_migrations()
         self._init_db_from_symlinks()
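validate_database() is not part of this diff, so its behavior is assumed here; the point of the start() change above is the ordering: detect a missing or unreachable database, create it, then run migrations so the freshly created, empty database gets its schema. A rough sketch of that bootstrap order with a plausible connectivity check, illustrative only and not the project's implementation:

    from typing import Callable

    from sqlalchemy import create_engine, text

    def database_reachable(url: str) -> bool:
        """Assumed shape of a validate_database-style check: can we connect and query?"""
        try:
            engine = create_engine(url)
            try:
                with engine.connect() as conn:
                    conn.execute(text("SELECT 1"))
                return True
            finally:
                engine.dispose()
        except Exception:
            return False

    def bootstrap(url: str, create_database: Callable[[], bool], run_migrations: Callable[[], None]) -> bool:
        # Mirrors the order in Program.start(): only create the database when the
        # check fails, bail out if creation fails, and run migrations either way.
        if not database_reachable(url):
            if not create_database():
                return False
        run_migrations()
        return True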
diff --git a/src/utils/event_manager.py b/src/utils/event_manager.py
index 17e175b2..80a549e2 100644
--- a/src/utils/event_manager.py
+++ b/src/utils/event_manager.py
@@ -7,7 +7,7 @@
 from loguru import logger
 from sqlalchemy.orm.exc import StaleDataError
-from subliminal import Episode, Movie
+from concurrent.futures import CancelledError
 
 import utils.websockets.manager as ws_manager
 from program.db.db import db
@@ -69,12 +69,11 @@ def _process_future(self, future, service):
             if item:
                 self.remove_item_from_running(item)
                 self.add_event(Event(emitted_by=service, item=item, run_at=timestamp))
-        except concurrent.futures.CancelledError:
-            # This is expected behavior when cancelling tasks
-            return
-        except StaleDataError:
-            # This is expected behavior when cancelling tasks
+        except (StaleDataError, CancelledError):
+            # Expected behavior when cancelling tasks or when the item was removed
             return
+        except ValueError as e:
+            logger.error(f"Error in future for {future}: {e}")
         except Exception as e:
             logger.error(f"Error in future for {future}: {e}")
             logger.exception(traceback.format_exc())
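On the event_manager change above: the subliminal Episode/Movie import is dropped, and CancelledError is imported from concurrent.futures so the cancelled-future case and StaleDataError collapse into one expected-and-ignored branch. A quick self-contained demonstration of where that exception actually comes from:

    import time
    from concurrent.futures import CancelledError, ThreadPoolExecutor

    def work(tag: str) -> str:
        time.sleep(0.2)
        return tag

    with ThreadPoolExecutor(max_workers=1) as pool:
        running = pool.submit(work, "first")   # occupies the single worker immediately
        queued = pool.submit(work, "second")   # still pending, so cancel() can succeed
        queued.cancel()

        try:
            queued.result()
        except CancelledError:
            # The situation _process_future() now treats as expected and returns on.
            print("queued future was cancelled before it started")

        print(running.result())  # "first" completes normally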