Skip to content

Commit

Permalink
feat: Types for the FastAPI API and API refactor (#748)
Browse files Browse the repository at this point in the history
* feat: add response models to items.py

* feat: ignore ruff linter raising exceptions in except block

* refactor: ruff format all the api files

* feat: add response models to scrape.py

* feat: add response models to settings.py

* feat: add response models to tmdb.py

* feat: add missing type annotations to tmdb.py

* fix: add default values for some pydantic models

* feat: add types to default.py

* fix: bad pydantic types causing serialization error

* fix: add some model validation where needed

* feat: add mypy to dev dependencies for static type checking

* fix: wrong type in realdebrid

* feat: add some options for easier querying of items

* fix: pass with_streams argument in to_extended_dict to chidren

* feat: remove the old json response format from services and stats endpoints

* feat: migrate the settings api to the new response types

* feat: add type annotation to get_all_settings

* feat: migrate the rest of the APIs to the new response schema

* fix: remove old imports
  • Loading branch information
Filip Trplan authored Oct 10, 2024
1 parent 24904fc commit 9eec02d
Show file tree
Hide file tree
Showing 14 changed files with 803 additions and 256 deletions.
61 changes: 59 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ codecov = "^2.1.13"
httpx = "^0.27.0"
memray = "^1.13.4"
testcontainers = "^4.8.0"
mypy = "^1.11.2"

[tool.poetry.group.test]
optional = true
Expand Down Expand Up @@ -91,7 +92,8 @@ ignore = [
"S101", # ruff: Ignore assert warnings on tests
"RET505", #
"RET503", # ruff: Ignore required explicit returns (is this desired?)
"SLF001" # private member accessing from pickle
"SLF001", # private member accessing from pickle
"B904" # ruff: ignore raising exceptions from except for the API
]
extend-select = [
"I", # isort
Expand Down
140 changes: 98 additions & 42 deletions src/controllers/default.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,75 @@
from typing import Literal

import requests
from controllers.models.shared import MessageResponse
from fastapi import APIRouter, HTTPException, Request
from loguru import logger
from sqlalchemy import func, select

from program.content.trakt import TraktContent
from program.db.db import db
from program.media.item import Episode, MediaItem, Movie, Season, Show
from program.media.state import States
from program.settings.manager import settings_manager
from pydantic import BaseModel, Field
from sqlalchemy import func, select
from utils.event_manager import EventUpdate

router = APIRouter(
responses={404: {"description": "Not found"}},
)


class RootResponse(MessageResponse):
version: str


@router.get("/", operation_id="root")
async def root():
async def root() -> RootResponse:
return {
"success": True,
"message": "Riven is running!",
"version": settings_manager.settings.version,
}


@router.get("/health", operation_id="health")
async def health(request: Request):
async def health(request: Request) -> MessageResponse:
return {
"success": True,
"message": request.app.program.initialized,
}


class RDUser(BaseModel):
id: int
username: str
email: str
points: int = Field(description="User's RD points")
locale: str
avatar: str = Field(description="URL to the user's avatar")
type: Literal["free", "premium"]
premium: int = Field(description="Premium subscription left in seconds")


@router.get("/rd", operation_id="rd")
async def get_rd_user():
async def get_rd_user() -> RDUser:
api_key = settings_manager.settings.downloaders.real_debrid.api_key
headers = {"Authorization": f"Bearer {api_key}"}

proxy = settings_manager.settings.downloaders.real_debrid.proxy_url if settings_manager.settings.downloaders.real_debrid.proxy_enabled else None
proxy = (
settings_manager.settings.downloaders.real_debrid.proxy_url
if settings_manager.settings.downloaders.real_debrid.proxy_enabled
else None
)

response = requests.get(
"https://api.real-debrid.com/rest/1.0/user",
headers=headers,
proxies=proxy if proxy else None,
timeout=10
timeout=10,
)

if response.status_code != 200:
return {"success": False, "message": response.json()}

return {
"success": True,
"data": response.json(),
}
return response.json()


@router.get("/torbox", operation_id="torbox")
Expand All @@ -65,7 +83,7 @@ async def get_torbox_user():


@router.get("/services", operation_id="services")
async def get_services(request: Request):
async def get_services(request: Request) -> dict[str, bool]:
data = {}
if hasattr(request.app.program, "services"):
for service in request.app.program.all_services.values():
Expand All @@ -74,11 +92,15 @@ async def get_services(request: Request):
continue
for sub_service in service.services.values():
data[sub_service.key] = sub_service.initialized
return {"success": True, "data": data}
return data


class TraktOAuthInitiateResponse(BaseModel):
auth_url: str


@router.get("/trakt/oauth/initiate", operation_id="trakt_oauth_initiate")
async def initiate_trakt_oauth(request: Request):
async def initiate_trakt_oauth(request: Request) -> TraktOAuthInitiateResponse:
trakt = request.app.program.services.get(TraktContent)
if trakt is None:
raise HTTPException(status_code=404, detail="Trakt service not found")
Expand All @@ -87,24 +109,41 @@ async def initiate_trakt_oauth(request: Request):


@router.get("/trakt/oauth/callback", operation_id="trakt_oauth_callback")
async def trakt_oauth_callback(code: str, request: Request):
async def trakt_oauth_callback(code: str, request: Request) -> MessageResponse:
trakt = request.app.program.services.get(TraktContent)
if trakt is None:
raise HTTPException(status_code=404, detail="Trakt service not found")
success = trakt.handle_oauth_callback(code)
if success:
return {"success": True, "message": "OAuth token obtained successfully"}
return {"message": "OAuth token obtained successfully"}
else:
raise HTTPException(status_code=400, detail="Failed to obtain OAuth token")


class StatsResponse(BaseModel):
total_items: int
total_movies: int
total_shows: int
total_seasons: int
total_episodes: int
total_symlinks: int
incomplete_items: int
incomplete_retries: dict[str, int] = Field(
description="Media item log string: number of retries"
)
states: dict[States, int]


@router.get("/stats", operation_id="stats")
async def get_stats(_: Request):
async def get_stats(_: Request) -> StatsResponse:
payload = {}
with db.Session() as session:

movies_symlinks = session.execute(select(func.count(Movie._id)).where(Movie.symlinked == True)).scalar_one()
episodes_symlinks = session.execute(select(func.count(Episode._id)).where(Episode.symlinked == True)).scalar_one()
movies_symlinks = session.execute(
select(func.count(Movie._id)).where(Movie.symlinked == True)
).scalar_one()
episodes_symlinks = session.execute(
select(func.count(Episode._id)).where(Episode.symlinked == True)
).scalar_one()
total_symlinks = movies_symlinks + episodes_symlinks

total_movies = session.execute(select(func.count(Movie._id))).scalar_one()
Expand All @@ -113,21 +152,30 @@ async def get_stats(_: Request):
total_episodes = session.execute(select(func.count(Episode._id))).scalar_one()
total_items = session.execute(select(func.count(MediaItem._id))).scalar_one()

# Select only the IDs of incomplete items
_incomplete_items = session.execute(
select(MediaItem._id)
.where(MediaItem.last_state != States.Completed)
).scalars().all()

# Select only the IDs of incomplete items
_incomplete_items = (
session.execute(
select(MediaItem._id).where(MediaItem.last_state != States.Completed)
)
.scalars()
.all()
)

incomplete_retries = {}
if _incomplete_items:
media_items = session.query(MediaItem).filter(MediaItem._id.in_(_incomplete_items)).all()
media_items = (
session.query(MediaItem)
.filter(MediaItem._id.in_(_incomplete_items))
.all()
)
for media_item in media_items:
incomplete_retries[media_item.log_string] = media_item.scraped_times

states = {}
for state in States:
states[state] = session.execute(select(func.count(MediaItem._id)).where(MediaItem.last_state == state)).scalar_one()
states[state] = session.execute(
select(func.count(MediaItem._id)).where(MediaItem.last_state == state)
).scalar_one()

payload["total_items"] = total_items
payload["total_movies"] = total_movies
Expand All @@ -138,11 +186,15 @@ async def get_stats(_: Request):
payload["incomplete_items"] = len(_incomplete_items)
payload["incomplete_retries"] = incomplete_retries
payload["states"] = states
return payload


class LogsResponse(BaseModel):
logs: str

return {"success": True, "data": payload}

@router.get("/logs", operation_id="logs")
async def get_logs():
async def get_logs() -> str:
log_file_path = None
for handler in logger._core.handlers.values():
if ".log" in handler._name:
Expand All @@ -153,24 +205,29 @@ async def get_logs():
return {"success": False, "message": "Log file handler not found"}

try:
with open(log_file_path, 'r') as log_file:
with open(log_file_path, "r") as log_file:
log_contents = log_file.read()
return {"success": True, "logs": log_contents}
return {"logs": log_contents}
except Exception as e:
logger.error(f"Failed to read log file: {e}")
return {"success": False, "message": "Failed to read log file"}

raise HTTPException(status_code=500, detail="Failed to read log file")


@router.get("/events", operation_id="events")
async def get_events(request: Request):
return {"success": True, "data": request.app.program.em.get_event_updates()}
async def get_events(
request: Request,
) -> dict[str, list[EventUpdate]]:
return request.app.program.em.get_event_updates()


@router.get("/mount", operation_id="mount")
async def get_rclone_files():
async def get_rclone_files() -> dict[str, str]:
"""Get all files in the rclone mount."""
import os

rclone_dir = settings_manager.settings.symlink.rclone_path
file_map = {}

def scan_dir(path):
with os.scandir(path) as entries:
for entry in entries:
Expand All @@ -179,6 +236,5 @@ def scan_dir(path):
elif entry.is_dir():
scan_dir(entry.path)

scan_dir(rclone_dir) # dict of `filename: filepath``
return {"success": True, "data": file_map}

scan_dir(rclone_dir) # dict of `filename: filepath``
return file_map
Loading

0 comments on commit 9eec02d

Please sign in to comment.