From 4063aec1edb25102b6cf9335f7cf055aac53f23a Mon Sep 17 00:00:00 2001 From: Jan Willhaus Date: Sat, 26 Oct 2024 21:02:03 +0200 Subject: [PATCH] feat: Add continuous mode with --sleep-seconds --- README.md | 19 +++++++ cspell.config.yaml | 1 + docker-compose.yml | 17 ++++++ podcast_archiver/base.py | 14 +++-- podcast_archiver/cli.py | 20 +++++-- podcast_archiver/config.py | 14 ++++- podcast_archiver/constants.py | 2 + podcast_archiver/download.py | 97 +++++++++++++-------------------- podcast_archiver/enums.py | 6 +-- podcast_archiver/exceptions.py | 22 +++++++- podcast_archiver/logging.py | 87 +++++++++++++++++++++++------- podcast_archiver/models.py | 98 +++++++++++++++++++++------------- podcast_archiver/processor.py | 35 +++++++----- podcast_archiver/session.py | 42 ++++++++++----- podcast_archiver/utils.py | 24 +++++++-- tests/test_download.py | 17 ++++++ tests/test_fetch.py | 4 +- tests/test_models.py | 48 ++++++++++++++--- 18 files changed, 399 insertions(+), 168 deletions(-) create mode 100644 docker-compose.yml diff --git a/README.md b/README.md index adeab2e..af9bd89 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,25 @@ Feeds can also be "fetched" from a local file: podcast-archiver -f file:/Users/janw/downloaded_feed.xml ``` +#### Continuous mode + +When the `--sleep-seconds` option is set to a non-zero value, Podcast Archiver operates in continuous mode. After successfully populating the archive, it will not terminate but rather sleep for the given number of seconds until it refreshes the feeds again and downloads episodes that have been published in the meantime. + +If no new episodes have been published, no download attempts will be made, and the archiver will go to sleep again. 
This mode of operation is ideal for running in a containerized setup, for example using [docker compose](https://docs.docker.com/compose/install/): + +```yaml +services: + podcast-archiver: + restart: always + image: ghcr.io/janw/podcast-archiver + volumes: + - ./archive:/archive + command: + - --sleep-seconds=3600 # sleep for 1 hour between updates + - --feed=https://feeds.feedburner.com/TheAnthropoceneReviewed + - --feed=https://feeds.megaphone.fm/heavyweight-spot +``` + ### Changing the filename format Podcast Archiver has a `--filename-template` option that allows you to change the particular naming scheme of the archive. The default value for `--filename-template`. is shown in `podcast-archiver --help`, as well as all the available variables. The basic ones are: diff --git a/cspell.config.yaml b/cspell.config.yaml index d3cc02e..c78bf71 100644 --- a/cspell.config.yaml +++ b/cspell.config.yaml @@ -44,6 +44,7 @@ words: - PYTHONUNBUFFERED - pyyaml - rprint + - signum - subdirs - tini - tmpl diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..6e731f2 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,17 @@ +services: + podcast-archiver: + restart: always + image: ghcr.io/janw/podcast-archiver:v1 + build: + context: . 
+ dockerfile: Dockerfile + cache_from: + - ghcr.io/janw/podcast-archiver:edge + - ghcr.io/janw/podcast-archiver:latest + volumes: + - ./archive:/archive + command: + - --sleep-seconds=3600 + - --ignore-database + - --feed=https://feeds.feedburner.com/TheAnthropoceneReviewed + - --feed=https://feeds.megaphone.fm/heavyweight-spot diff --git a/podcast_archiver/base.py b/podcast_archiver/base.py index f5ad83d..90cb94f 100755 --- a/podcast_archiver/base.py +++ b/podcast_archiver/base.py @@ -1,7 +1,9 @@ from __future__ import annotations +import signal +import sys import xml.etree.ElementTree as etree -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from podcast_archiver.logging import logger, rprint from podcast_archiver.processor import FeedProcessor @@ -31,9 +33,15 @@ def __init__(self, settings: Settings): self.add_from_opml(opml) def register_cleanup(self, ctx: click.RichContext) -> None: - @ctx.call_on_close - def _cleanup() -> None: + def _cleanup(signum: int, *args: Any) -> None: + logger.debug("Signal %s received", signum) + rprint("[error]Terminating.[/]") self.processor.shutdown() + ctx.close() + sys.exit(0) + + signal.signal(signal.SIGINT, _cleanup) + signal.signal(signal.SIGTERM, _cleanup) def add_feed(self, feed: Path | str) -> None: new_feeds = [feed] if isinstance(feed, str) else feed.read_text().strip().splitlines() diff --git a/podcast_archiver/cli.py b/podcast_archiver/cli.py index 17cb33f..b5b29bd 100644 --- a/podcast_archiver/cli.py +++ b/podcast_archiver/cli.py @@ -3,18 +3,18 @@ import os import pathlib import stat +import time from os import getenv from typing import TYPE_CHECKING, Any import rich_click as click -from rich import get_console from podcast_archiver import __version__ as version from podcast_archiver import constants from podcast_archiver.base import PodcastArchiver from podcast_archiver.config import Settings, in_ci from podcast_archiver.exceptions import InvalidSettings -from podcast_archiver.logging import 
configure_logging +from podcast_archiver.logging import configure_logging, rprint if TYPE_CHECKING: from click.shell_completion import CompletionItem @@ -49,6 +49,7 @@ "--update", "--max-episodes", "--ignore-database", + "--sleep", ], }, ] @@ -215,6 +216,7 @@ def generate_default_config(ctx: click.Context, param: click.Parameter, value: b "-v", "--verbose", count=True, + metavar="", show_envvar=True, help=Settings.model_fields["verbose"].description, ) @@ -281,10 +283,16 @@ def generate_default_config(ctx: click.Context, param: click.Parameter, value: b show_envvar=True, help=Settings.model_fields["ignore_database"].description, ) +@click.option( + "--sleep-seconds", + type=int, + default=0, + show_envvar=True, + help=Settings.model_fields["sleep_seconds"].description, +) @click.pass_context def main(ctx: click.RichContext, /, **kwargs: Any) -> int: - get_console().quiet = kwargs["quiet"] - configure_logging(kwargs["verbose"]) + configure_logging(kwargs["verbose"], kwargs["quiet"]) try: settings = Settings.load_from_dict(kwargs) @@ -296,6 +304,10 @@ def main(ctx: click.RichContext, /, **kwargs: Any) -> int: pa = PodcastArchiver(settings=settings) pa.register_cleanup(ctx) pa.run() + while settings.sleep_seconds > 0: + rprint(f"Sleeping for {settings.sleep_seconds} seconds.") + time.sleep(settings.sleep_seconds) + pa.run() except InvalidSettings as exc: raise click.BadParameter(f"Invalid settings: {exc}") from exc except KeyboardInterrupt as exc: # pragma: no cover diff --git a/podcast_archiver/config.py b/podcast_archiver/config.py index ff9bfaa..c053568 100644 --- a/podcast_archiver/config.py +++ b/podcast_archiver/config.py @@ -87,7 +87,11 @@ class Settings(BaseModel): verbose: int = Field( default=0, - description="Increase the level of verbosity while downloading.", + description=( + "Increase the level of verbosity while downloading. Can be passed multiple times. Increased verbosity and " + "non-interactive execution (in a cronjob, docker compose, etc.) 
will disable progress bars. " + "Non-interactive execution also always raises the verbosity unless --quiet is passed." + ), ) slugify_paths: bool = Field( @@ -136,6 +140,14 @@ class Settings(BaseModel): ), ) + sleep_seconds: int = Field( + default=0, + description=( + f"Run {constants.PROG_NAME} continuously. Set to a non-zero number of seconds to sleep after all available " + "episodes have been downloaded. Otherwise the application exits after all downloads have been completed." + ), + ) + config: FilePath | None = Field( default=None, exclude=True, diff --git a/podcast_archiver/constants.py b/podcast_archiver/constants.py index cf04cf1..2b22699 100644 --- a/podcast_archiver/constants.py +++ b/podcast_archiver/constants.py @@ -15,6 +15,8 @@ MAX_TITLE_LENGTH = 96 + +DEFAULT_DATETIME_FORMAT = "%Y-%m-%d" DEFAULT_ARCHIVE_DIRECTORY = pathlib.Path(".") DEFAULT_FILENAME_TEMPLATE = "{show.title}/{episode.published_time:%Y-%m-%d} - {episode.title}.{ext}" DEFAULT_CONCURRENCY = 4 diff --git a/podcast_archiver/download.py b/podcast_archiver/download.py index e32c1fe..4d1a548 100644 --- a/podcast_archiver/download.py +++ b/podcast_archiver/download.py @@ -1,15 +1,13 @@ from __future__ import annotations -from contextlib import nullcontext +from contextlib import contextmanager from threading import Event -from typing import IO, TYPE_CHECKING, NoReturn - -from tqdm import tqdm -from tqdm.contrib.logging import logging_redirect_tqdm +from typing import IO, TYPE_CHECKING, Generator from podcast_archiver import constants from podcast_archiver.enums import DownloadResult -from podcast_archiver.logging import logger +from podcast_archiver.exceptions import NotCompleted +from podcast_archiver.logging import logger, wrapped_tqdm from podcast_archiver.session import session from podcast_archiver.types import EpisodeResult from podcast_archiver.utils import atomic_write @@ -28,38 +26,31 @@ class DownloadJob: target: Path stop_event: Event - _debug_partial: bool + _max_download_bytes: 
int | None = None _write_info_json: bool - _no_progress: bool def __init__( self, episode: Episode, *, target: Path, - debug_partial: bool = False, + max_download_bytes: int | None = None, write_info_json: bool = False, - no_progress: bool = False, stop_event: Event | None = None, ) -> None: self.episode = episode self.target = target - self._debug_partial = debug_partial + self._max_download_bytes = max_download_bytes self._write_info_json = write_info_json - self._no_progress = no_progress self.stop_event = stop_event or Event() - def __repr__(self) -> str: - return f"EpisodeDownload({self})" - - def __str__(self) -> str: - return str(self.episode) - def __call__(self) -> EpisodeResult: try: return self.run() + except NotCompleted: + return EpisodeResult(self.episode, DownloadResult.ABORTED) except Exception as exc: - logger.error(f"Download failed: {exc}") + logger.error("Download failed: %s; %s", self.episode, exc) logger.debug("Exception while downloading", exc_info=exc) return EpisodeResult(self.episode, DownloadResult.FAILED) @@ -68,58 +59,44 @@ def run(self) -> EpisodeResult: return EpisodeResult(self.episode, DownloadResult.ALREADY_EXISTS) self.target.parent.mkdir(parents=True, exist_ok=True) - self.write_info_json() - - response = session.get( - self.episode.enclosure.href, - stream=True, - allow_redirects=True, - ) - response.raise_for_status() - total_size = int(response.headers.get("content-length", "0")) - with ( - logging_redirect_tqdm() if not self._no_progress else nullcontext(), - tqdm( - desc=f"{self.episode.title} ({self.episode.published_time:%Y-%m-%d})", - total=total_size, - unit_scale=True, - unit="B", - disable=self._no_progress, - ) as progresser, - ): - with atomic_write(self.target, mode="wb") as fp: - receive_complete = self.receive_data(fp, response, progresser=progresser) - - if not receive_complete: - self.target.unlink(missing_ok=True) - return EpisodeResult(self.episode, DownloadResult.ABORTED) + logger.info("Downloading: %s", 
self.episode) + response = session.get_and_raise(self.episode.enclosure.href, stream=True) + with self.write_info_json(), atomic_write(self.target, mode="wb") as fp: + self.receive_data(fp, response) - logger.info("Completed download of %s", self.target) + logger.info("Completed: %s", self.episode) return EpisodeResult(self.episode, DownloadResult.COMPLETED_SUCCESSFULLY) @property def infojsonfile(self) -> Path: return self.target.with_suffix(".info.json") - def receive_data(self, fp: IO[str], response: Response, progresser: tqdm[NoReturn]) -> bool: + def receive_data(self, fp: IO[bytes], response: Response) -> None: + total_size = int(response.headers.get("content-length", "0")) total_written = 0 - for chunk in response.iter_content(chunk_size=constants.DOWNLOAD_CHUNK_SIZE): - written = fp.write(chunk) - total_written += written - progresser.update(written) - - if self._debug_partial and total_written >= constants.DEBUG_PARTIAL_SIZE: - logger.debug("Partial download completed.") - return True - if self.stop_event.is_set(): - logger.debug("Stop event is set, bailing.") - return False + max_bytes = self._max_download_bytes + for chunk in wrapped_tqdm( + response.iter_content(chunk_size=constants.DOWNLOAD_CHUNK_SIZE), + desc=str(self.episode), + total=total_size, + ): + total_written += fp.write(chunk) - return True + if max_bytes and total_written >= max_bytes: + fp.truncate(max_bytes) + logger.debug("Partial download of first %s bytes completed.", max_bytes) + return + + if self.stop_event.is_set(): + logger.debug("Stop event is set, bailing on %s.", self.episode) + raise NotCompleted - def write_info_json(self) -> None: + @contextmanager + def write_info_json(self) -> Generator[None, None, None]: if not self._write_info_json: + yield return - logger.info("Writing episode metadata to %s", self.infojsonfile.name) with atomic_write(self.infojsonfile) as fp: fp.write(self.episode.model_dump_json(indent=2) + "\n") + yield + logger.debug("Wrote episode metadata to %s", 
self.infojsonfile.name) diff --git a/podcast_archiver/enums.py b/podcast_archiver/enums.py index 2da0584..1ee7527 100644 --- a/podcast_archiver/enums.py +++ b/podcast_archiver/enums.py @@ -7,9 +7,9 @@ def __str__(self) -> str: class QueueCompletionType(StrEnum): - COMPLETED = "Archived all episodes." - FOUND_EXISTING = "Archive is up to date." - MAX_EPISODES = "Maximum episode count reached." + COMPLETED = "Archived all episodes" + FOUND_EXISTING = "Archive is up to date" + MAX_EPISODES = "Maximum episode count reached" class DownloadResult(StrEnum): diff --git a/podcast_archiver/exceptions.py b/podcast_archiver/exceptions.py index e6665b4..941bbc6 100644 --- a/podcast_archiver/exceptions.py +++ b/podcast_archiver/exceptions.py @@ -1,6 +1,11 @@ -from typing import Any +from __future__ import annotations -import pydantic_core +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + import pydantic_core + + from podcast_archiver.models import FeedInfo class PodcastArchiverException(Exception): @@ -27,3 +32,16 @@ def __str__(self) -> str: class MissingDownloadUrl(ValueError): pass + + +class NotCompleted(RuntimeError): + pass + + +class NotModified(PodcastArchiverException): + info: FeedInfo + last_modified: str | None = None + + def __init__(self, info: FeedInfo, *args: object) -> None: + super().__init__(*args) + self.info = info diff --git a/podcast_archiver/logging.py b/podcast_archiver/logging.py index 6be95e2..d5b597e 100644 --- a/podcast_archiver/logging.py +++ b/podcast_archiver/logging.py @@ -2,45 +2,92 @@ import logging import logging.config -from typing import Any +import sys +from os import environ +from typing import Any, Generator, Iterable -from rich import get_console from rich import print as _print from rich.logging import RichHandler +from rich.text import Text +from tqdm import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm logger = logging.getLogger("podcast_archiver") -def rprint(*objects: Any, sep: str = " ", end: str = "\n", 
**kwargs: Any) -> None: - if logger.level == logging.NOTSET or logger.level >= logging.WARNING: - _print(*objects, sep=sep, end=end, **kwargs) +_REDIRECT_VIA_TQDM: bool = False +_REDIRECT_VIA_LOGGING: bool = False + + +def rprint(msg: str, **kwargs: Any) -> None: + if not _REDIRECT_VIA_TQDM and not _REDIRECT_VIA_LOGGING: + _print(msg, **kwargs) + return + + text = Text.from_markup(msg.strip()).plain.strip() + logger.info(text) + + +def wrapped_tqdm(iterable: Iterable[bytes], desc: str, total: int) -> Generator[bytes, None, None]: + if _REDIRECT_VIA_LOGGING: + yield from iterable return - logger.info(objects[0].strip(), *objects[1:]) + with ( + logging_redirect_tqdm(), + tqdm(desc=desc, total=total, unit_scale=True, unit="B") as progress, + ): + global _REDIRECT_VIA_TQDM + _REDIRECT_VIA_TQDM = True + try: + for chunk in iterable: + progress.update(len(chunk)) + yield chunk + finally: + _REDIRECT_VIA_TQDM = False + + +def is_interactive() -> bool: + return sys.stdout.isatty() and environ.get("TERM", "").lower() not in ("dumb", "unknown") -def configure_logging(verbosity: int) -> None: - if verbosity > 1: - level = logging.DEBUG - elif verbosity == 1: - level = logging.WARNING + +def configure_level(verbosity: int, quiet: bool) -> int: + global _REDIRECT_VIA_LOGGING + interactive = is_interactive() + if not interactive or quiet or verbosity > 0: + _REDIRECT_VIA_LOGGING = True + + if verbosity > 1 and not quiet: + return logging.DEBUG + elif (verbosity == 1 or not interactive) and not quiet: + return logging.INFO else: - level = logging.ERROR + return logging.ERROR - logging.basicConfig( - level=level, - format="%(message)s", - datefmt="[%X]", - handlers=[ + +def configure_logging(verbosity: int, quiet: bool) -> None: + level = configure_level(verbosity, quiet) + handlers: list[logging.Handler] | None = None + logformat: str = "%(asctime)s %(name)-16s %(levelname)-5s %(message)s" + + if is_interactive(): + logformat = "%(message)s" + handlers = [ RichHandler( - 
log_time_format="[%X]", markup=True, - rich_tracebacks=get_console().is_terminal, + rich_tracebacks=verbosity > 1, tracebacks_suppress=[ "click", ], tracebacks_show_locals=True, ) - ], + ] + + logging.basicConfig( + level=logging.ERROR, + format=logformat, + datefmt="[%Y-%m-%d %H:%M:%S]", + handlers=handlers, ) logger.setLevel(level) logger.debug("Running in debug mode.") diff --git a/podcast_archiver/models.py b/podcast_archiver/models.py index 919f14b..e4e17ae 100644 --- a/podcast_archiver/models.py +++ b/podcast_archiver/models.py @@ -2,9 +2,10 @@ from datetime import datetime, timezone from functools import cached_property +from http import HTTPStatus from pathlib import Path from time import mktime, struct_time -from typing import Any, Iterator +from typing import TYPE_CHECKING, Annotated, Any, Iterator from urllib.parse import urlparse import feedparser @@ -16,13 +17,26 @@ field_validator, model_validator, ) +from pydantic.functional_validators import BeforeValidator -from podcast_archiver.constants import MAX_TITLE_LENGTH, REQUESTS_TIMEOUT -from podcast_archiver.exceptions import MissingDownloadUrl +from podcast_archiver.constants import DEFAULT_DATETIME_FORMAT, MAX_TITLE_LENGTH +from podcast_archiver.exceptions import MissingDownloadUrl, NotModified from podcast_archiver.logging import logger from podcast_archiver.session import session from podcast_archiver.utils import get_generic_extension, truncate +if TYPE_CHECKING: + from requests import Response + + +def parse_from_struct_time(value: Any) -> Any: + if isinstance(value, struct_time): + return datetime.fromtimestamp(mktime(value)).replace(tzinfo=timezone.utc) + return value + + +LenientDatetime = Annotated[datetime, BeforeValidator(parse_from_struct_time)] + class Link(BaseModel): rel: str = "" @@ -44,9 +58,9 @@ class Episode(BaseModel): title: str = Field(default="Untitled Episode", title="episode.title") subtitle: str = Field("", repr=False, title="episode.subtitle") author: str = Field("", 
repr=False) - links: list[Link] = Field(default_factory=list) + links: list[Link] = Field(default_factory=list, repr=False) enclosure: Link = Field(default=None, repr=False) # type: ignore[assignment] - published_time: datetime = Field(alias="published_parsed", title="episode.published_time") + published_time: LenientDatetime = Field(alias="published_parsed", title="episode.published_time") original_filename: str = Field(default="", repr=False, title="episode.original_filename") @@ -68,18 +82,8 @@ class Episode(BaseModel): guid: str = Field(default=None, alias="id") # type: ignore[assignment] - def __hash__(self) -> int: - return hash(self.guid) - - def __eq__(self, other: Episode | Any) -> bool: - return isinstance(other, Episode) and self.guid == other.guid - - @field_validator("published_time", mode="before") - @classmethod - def parse_from_struct_time(cls, value: Any) -> Any: - if isinstance(value, struct_time): - return datetime.fromtimestamp(mktime(value)).replace(tzinfo=timezone.utc) - return value + def __str__(self) -> str: + return f"{self.title} ({self.published_time.strftime(DEFAULT_DATETIME_FORMAT)})" @field_validator("title", mode="after") @classmethod @@ -141,6 +145,12 @@ class FeedInfo(BaseModel): language: str | None = Field(default=None, title="show.language") links: list[Link] = [] + updated_time: LenientDatetime | None = Field(default=None, alias="updated_parsed") + last_modified: str | None = Field(default=None) + + def __str__(self) -> str: + return self.title + @field_validator("title", mode="after") @classmethod def truncate_title(cls, value: str) -> str: @@ -162,40 +172,56 @@ class FeedPage(BaseModel): bozo_exception: Exception | None = None @classmethod - def from_url(cls, url: str) -> FeedPage: + def from_url(cls, url: str, *, known_info: FeedInfo | None = None) -> FeedPage: parsed = urlparse(url) if parsed.scheme == "file": feedobj = feedparser.parse(parsed.path) - else: - response = session.get(url, allow_redirects=True, 
timeout=REQUESTS_TIMEOUT) - response.raise_for_status() - feedobj = feedparser.parse(response.content) - return cls.model_validate(feedobj) + return cls.model_validate(feedobj) + + if not known_info: + return cls.from_response(session.get_and_raise(url)) + + response = session.get_and_raise(url, last_modified=known_info.last_modified) + if response.status_code == HTTPStatus.NOT_MODIFIED: + logger.debug("Server reported 'not modified' from %s, skipping fetch.", known_info.last_modified) + raise NotModified(known_info) + + instance = cls.from_response(response) + if instance.feed.updated_time == known_info.updated_time: + logger.debug("Feed's updated time %s did not change, skipping fetch.", known_info.updated_time) + raise NotModified(known_info) + + return instance + + @classmethod + def from_response(cls, response: Response) -> FeedPage: + feedobj = feedparser.parse(response.content) + instance = cls.model_validate(feedobj) + instance.feed.last_modified = response.headers.get("Last-Modified") + return instance class Feed: - info: FeedInfo url: str + info: FeedInfo - _page: FeedPage | None + _page: FeedPage | None = None - def __init__(self, page: FeedPage, url: str) -> None: - self.info = page.feed + def __init__(self, url: str, *, known_info: FeedInfo | None = None) -> None: self.url = url - self._page = page + if known_info: + self.info = known_info + + self._page = FeedPage.from_url(url, known_info=known_info) + self.info = self._page.feed - logger.info("Loaded feed for '%s' by %s", self.info.title, self.info.author) + logger.debug("Loaded feed for '%s' by %s from %s", self.info.title, self.info.author, url) def __repr__(self) -> str: return f"Feed(name='{self}', url='{self.url}')" def __str__(self) -> str: - return f"{self.info.title}" - - @classmethod - def from_url(cls, url: str) -> Feed: - logger.info("Parsing feed %s", url) - return cls(page=FeedPage.from_url(url), url=url) + return str(self.info) def episode_iter(self, maximum_episode_count: int = 0) -> 
Iterator[Episode]: episode_count_total = 0 @@ -208,7 +234,7 @@ def episode_iter(self, maximum_episode_count: int = 0) -> Iterator[Episode]: episode_count_page += 1 episode_count_total += 1 - logger.info("Found %s episodes on page %s", episode_count_page, page_count) + logger.debug("Found %s episodes on page %s", episode_count_page, page_count) self._get_next_page() def _get_next_page(self) -> None: diff --git a/podcast_archiver/processor.py b/podcast_archiver/processor.py index cbcf9ee..6bdb3f7 100644 --- a/podcast_archiver/processor.py +++ b/podcast_archiver/processor.py @@ -5,6 +5,7 @@ from threading import Event from typing import TYPE_CHECKING +from podcast_archiver import constants from podcast_archiver.download import DownloadJob from podcast_archiver.enums import DownloadResult, QueueCompletionType from podcast_archiver.logging import logger, rprint @@ -35,33 +36,36 @@ class FeedProcessor: pool_executor: ThreadPoolExecutor stop_event: Event + known_feeds: dict[str, FeedInfo] + def __init__(self, settings: Settings) -> None: self.settings = settings self.filename_formatter = FilenameFormatter(settings) self.database = settings.get_database() self.pool_executor = ThreadPoolExecutor(max_workers=self.settings.concurrency) self.stop_event = Event() + self.known_feeds = {} def process(self, url: str) -> ProcessingResult: result = ProcessingResult() with handle_feed_request(url): - result.feed = Feed.from_url(url) + result.feed = Feed(url=url, known_info=self.known_feeds.get(url)) if result.feed: - rprint(f"\n[bold bright_magenta]Downloading archive for: {result.feed.info.title}[/]\n") + rprint(f"\n[bold bright_magenta]Downloading archive for: {result.feed}[/]\n") episode_results, completion_msg = self._process_episodes(feed=result.feed) self._handle_results(episode_results, result=result) - rprint(f"\n[bar.finished]✔ {completion_msg}[/]") + rprint(f"\n[bar.finished]✔ {completion_msg} for: {result.feed}[/]") return result def _preflight_check(self, episode: 
Episode, target: Path) -> DownloadResult | None: if self.database.exists(episode): - logger.debug("Pre-flight check on episode '%s': already in database.", episode.title) + logger.debug("Pre-flight check on episode '%s': already in database.", episode) return DownloadResult.ALREADY_EXISTS if target.exists(): - logger.debug("Pre-flight check on episode '%s': already on disk.", episode.title) + logger.debug("Pre-flight check on episode '%s': already on disk.", episode) return DownloadResult.ALREADY_EXISTS return None @@ -73,7 +77,7 @@ def _process_episodes(self, feed: Feed) -> tuple[EpisodeResultsList, QueueComple return results, completion if (max_count := self.settings.maximum_episode_count) and idx == max_count: - logger.info("Reached requested maximum episode count of %s", max_count) + logger.debug("Reached requested maximum episode count of %s", max_count) return results, QueueCompletionType.MAX_EPISODES return results, QueueCompletionType.COMPLETED @@ -83,22 +87,21 @@ def _process_episode( ) -> QueueCompletionType | None: target = self.filename_formatter.format(episode=episode, feed_info=feed_info) if result := self._preflight_check(episode, target): - rprint(f"[bar.finished]✔ {result}: {episode.title} ({episode.published_time.strftime('%Y-%m-%d')})[/]") + rprint(f"[bar.finished]✔ {result}: {episode}[/]") results.append(EpisodeResult(episode, result)) if self.settings.update_archive: - logger.info("Up to date with %r", episode) + logger.debug("Up to date with %r", episode) return QueueCompletionType.FOUND_EXISTING return None - logger.info("Queueing download for %r", episode) + logger.debug("Queueing download for %r", episode) results.append( self.pool_executor.submit( DownloadJob( episode, target=target, - debug_partial=self.settings.debug_partial, + max_download_bytes=constants.DEBUG_PARTIAL_SIZE if self.settings.debug_partial else None, write_info_json=self.settings.write_info_json, - no_progress=self.settings.verbose > 2 or self.settings.quiet, 
stop_event=self.stop_event, ) ) @@ -106,6 +109,8 @@ def _process_episode( return None def _handle_results(self, episode_results: EpisodeResultsList, *, result: ProcessingResult) -> None: + if not result.feed: + return for episode_result in episode_results: if not isinstance(episode_result, EpisodeResult): try: @@ -116,9 +121,11 @@ def _handle_results(self, episode_results: EpisodeResultsList, *, result: Proces continue self.database.add(episode_result.episode) result.success += 1 + self.known_feeds[result.feed.url] = result.feed.info def shutdown(self) -> None: - self.stop_event.set() - self.pool_executor.shutdown(cancel_futures=True) + if not self.stop_event.is_set(): + self.stop_event.set() + self.pool_executor.shutdown(cancel_futures=True) - logger.debug("Completed processor shutdown") + logger.debug("Completed processor shutdown") diff --git a/podcast_archiver/session.py b/podcast_archiver/session.py index 7a4081a..9bc9c7d 100644 --- a/podcast_archiver/session.py +++ b/podcast_archiver/session.py @@ -1,21 +1,15 @@ -from typing import Any +from __future__ import annotations -from requests import PreparedRequest, Session +from typing import TYPE_CHECKING, Any + +from requests import Session from requests.adapters import HTTPAdapter -from requests.models import Response as Response from urllib3.util import Retry from podcast_archiver.constants import REQUESTS_TIMEOUT, USER_AGENT - -class DefaultTimeoutHTTPAdapter(HTTPAdapter): - def send( - self, - request: PreparedRequest, - timeout: None | float | tuple[float, float] | tuple[float, None] = None, - **kwargs: Any, - ) -> Response: - return super().send(request, timeout=timeout or REQUESTS_TIMEOUT, **kwargs) +if TYPE_CHECKING: + from requests.models import Response _retries = Retry( @@ -25,9 +19,29 @@ def send( status_forcelist=[500, 501, 502, 503, 504], ) -_adapter = DefaultTimeoutHTTPAdapter(max_retries=_retries) +_adapter = HTTPAdapter(max_retries=_retries) + + +class ArchiverSession(Session): + def 
get_and_raise( + self, + url: str, + *, + timeout: None | float | tuple[float, float] | tuple[float, None] = REQUESTS_TIMEOUT, + last_modified: str | None = None, + headers: dict[str, Any] | None = None, + **kwargs: Any, + ) -> Response: + if last_modified: + headers = headers or {} + headers["If-Modified-Since"] = last_modified + + response = self.get(url, timeout=timeout, headers=headers, **kwargs) + response.raise_for_status() + return response + -session = Session() +session = ArchiverSession() session.mount("http://", _adapter) session.mount("https://", _adapter) session.headers.update({"user-agent": USER_AGENT}) diff --git a/podcast_archiver/utils.py b/podcast_archiver/utils.py index 818b0ef..1a1eeed 100644 --- a/podcast_archiver/utils.py +++ b/podcast_archiver/utils.py @@ -4,12 +4,13 @@ import re from contextlib import contextmanager from string import Formatter -from typing import IO, TYPE_CHECKING, Any, Generator, Iterable, Iterator, TypedDict +from typing import IO, TYPE_CHECKING, Any, Generator, Iterable, Iterator, Literal, TypedDict, overload from pydantic import ValidationError from requests import HTTPError from slugify import slugify as _slugify +from podcast_archiver.exceptions import NotModified from podcast_archiver.logging import logger, rprint if TYPE_CHECKING: @@ -109,16 +110,29 @@ def format(self, episode: Episode, feed_info: FeedInfo) -> Path: # type: ignore return self._path_root / self.vformat(self._template, args=(), kwargs=kwargs) +@overload @contextmanager -def atomic_write(target: Path, mode: str = "w") -> Iterator[IO[Any]]: +def atomic_write(target: Path, mode: Literal["w"] = "w") -> Iterator[IO[str]]: ... + + +@overload +@contextmanager +def atomic_write(target: Path, mode: Literal["wb"]) -> Iterator[IO[bytes]]: ... 
+ + +@contextmanager +def atomic_write(target: Path, mode: Literal["w", "wb"] = "w") -> Iterator[IO[bytes]] | Iterator[IO[str]]: tempfile = target.with_suffix(".part") try: with tempfile.open(mode) as fp: yield fp fp.flush() os.fsync(fp.fileno()) - logger.debug("Moving file %s => %s", tempfile, target) + logger.debug("Moving file '%s' => '%s'", tempfile, target) os.rename(tempfile, target) + except Exception: + target.unlink(missing_ok=True) + raise finally: tempfile.unlink(missing_ok=True) @@ -139,6 +153,10 @@ def handle_feed_request(url: str) -> Generator[None, Any, None]: logger.debug("Feed validation failed for %s", url, exc_info=exc) rprint(f"[error]Received invalid feed from {url}[/]") + except NotModified as exc: + logger.debug("Skipping retrieval for %s", exc.info) + rprint(f"\n[bar.finished]⏲ Feed of {exc.info} is unchanged, skipping.[/]") + except Exception as exc: logger.debug("Unexpected error for url %s", url, exc_info=exc) rprint(f"[error]Failed to retrieve feed {url}: {exc}[/]") diff --git a/tests/test_download.py b/tests/test_download.py index 5f7604e..0edc79f 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -39,6 +39,23 @@ def test_download_already_exists(tmp_path_cd: Path, feedobj_lautsprecher_notcons assert result == (episode, DownloadResult.ALREADY_EXISTS) +def test_download_partial( + tmp_path_cd: Path, + feedobj_lautsprecher: dict[str, Any], + caplog: pytest.LogCaptureFixture, +) -> None: + feed = FeedPage.model_validate(feedobj_lautsprecher) + episode = feed.episodes[0] + + job = download.DownloadJob(episode=episode, target=Path("file.mp3"), max_download_bytes=2) + with caplog.at_level(logging.DEBUG, "podcast_archiver"): + result = job() + + assert result == (episode, DownloadResult.COMPLETED_SUCCESSFULLY) + assert "Partial download of first 2 bytes completed." 
in caplog.messages + assert len(job.target.read_bytes()) + + def test_download_aborted(tmp_path_cd: Path, feedobj_lautsprecher: dict[str, Any]) -> None: feed = FeedPage.model_validate(feedobj_lautsprecher) episode = feed.episodes[0] diff --git a/tests/test_fetch.py b/tests/test_fetch.py index 7ff6540..36acae0 100644 --- a/tests/test_fetch.py +++ b/tests/test_fetch.py @@ -8,7 +8,7 @@ def test_fetch_from_http(feed_lautsprecher_onlyfeed: str) -> None: - assert Feed.from_url(feed_lautsprecher_onlyfeed) + assert Feed(feed_lautsprecher_onlyfeed) @pytest.mark.parametrize( @@ -19,4 +19,4 @@ def test_fetch_from_http(feed_lautsprecher_onlyfeed: str) -> None: ], ) def test_fetch_from_file(url: str) -> None: - assert Feed.from_url(url) + assert Feed(url) diff --git a/tests/test_models.py b/tests/test_models.py index 0cd24c7..546d0fd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,14 +2,17 @@ import time from copy import deepcopy -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Protocol import pytest from pydantic import ValidationError +from responses import RequestsMock from typing_extensions import TypedDict -from podcast_archiver.models import Episode +from podcast_archiver.exceptions import NotModified +from podcast_archiver.models import Episode, Feed, FeedInfo, FeedPage from podcast_archiver.utils import MIMETYPE_EXTENSION_MAPPING +from tests.conftest import FEED_CONTENT if TYPE_CHECKING: @@ -147,8 +150,9 @@ def test_episode_validation_shownotes_fallback() -> None: assert episode.shownotes.endswith("Grab a beverage and hit play!") +@pytest.mark.parametrize("urlpath", ["", "zenandtech/debug83"]) @pytest.mark.parametrize("mimetype,expected_ext", list(MIMETYPE_EXTENSION_MAPPING.items())) -def test_episode_missing_ext(mimetype: str, expected_ext: str) -> None: +def test_episode_missing_ext(urlpath: str, mimetype: str, expected_ext: str) -> None: episode = Episode.model_validate( { "title": "83: …", @@ -157,15 +161,14 @@ def 
test_episode_missing_ext(mimetype: str, expected_ext: str) -> None: { "length": "85468157", "type": mimetype, - "href": "http://traffic.libsyn.com/zenandtech/debug83", + "href": f"http://traffic.libsyn.com/{urlpath}", "rel": "enclosure", } ], } ) - assert episode.enclosure.href == "http://traffic.libsyn.com/zenandtech/debug83" - # assert episode.original_filename == "debug83" + assert episode.enclosure.href == f"http://traffic.libsyn.com/{urlpath}" assert episode.ext == expected_ext @@ -184,3 +187,36 @@ def test_invalid_link_length() -> None: ], } ) + + +class FeedConstructor(Protocol): + def __call__(self, url: str, *, known_info: FeedInfo | None = None) -> object: ... + + +@pytest.mark.parametrize("constructor", [FeedPage.from_url, Feed]) +def test_feed_with_known_info_not_modified(constructor: FeedConstructor, feed_lautsprecher_onlyfeed: str) -> None: + info = FeedPage.from_url(feed_lautsprecher_onlyfeed).feed + info.last_modified = "sometime" + + with RequestsMock() as responses, pytest.raises(NotModified): + responses.get(feed_lautsprecher_onlyfeed, status=304) + constructor(feed_lautsprecher_onlyfeed, known_info=info) + + +@pytest.mark.parametrize("constructor", [FeedPage.from_url, Feed]) +def test_feed_with_known_info_updated_time(constructor: FeedConstructor, feed_lautsprecher_onlyfeed: str) -> None: + info = FeedPage.from_url(feed_lautsprecher_onlyfeed).feed + info.last_modified = None + + with RequestsMock() as responses, pytest.raises(NotModified): + responses.get(feed_lautsprecher_onlyfeed, FEED_CONTENT) + constructor(feed_lautsprecher_onlyfeed, known_info=info) + + +@pytest.mark.parametrize("constructor", [FeedPage.from_url, Feed]) +def test_feed_with_known_info_updated_time_empty(constructor: FeedConstructor, feed_lautsprecher_onlyfeed: str) -> None: + info = FeedPage.from_url(feed_lautsprecher_onlyfeed).feed + info.updated_time = None + with RequestsMock() as responses: + responses.get(feed_lautsprecher_onlyfeed, FEED_CONTENT) + assert 
constructor(feed_lautsprecher_onlyfeed, known_info=info)