From 12d6f75e2de60bc9f8245a374824d6c004b28bc2 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 2 Aug 2024 00:10:20 -0400 Subject: [PATCH] fix failing test cases --- src/guidellm/core/serializable.py | 25 +++++++++++++------------ src/guidellm/scheduler/scheduler.py | 2 +- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/guidellm/core/serializable.py b/src/guidellm/core/serializable.py index f700794..2d81ebb 100644 --- a/src/guidellm/core/serializable.py +++ b/src/guidellm/core/serializable.py @@ -105,21 +105,22 @@ def save_file( if isinstance(path, str): path = Path(path) - if path.is_dir(): + if path.suffix: + # is a file + ext = path.suffix[1:].upper() + if ext not in SerializableFileType.__members__: + raise ValueError( + f"Unsupported file extension: {ext}. " + f"Expected one of {', '.join(SerializableFileType.__members__)}) " + f"for {path}" + ) + type_ = SerializableFileType[ext] + else: + # is a directory file_name = f"{self.__class__.__name__.lower()}.{type_.value.lower()}" path = path / file_name - elif path.is_file(): - type_ = SerializableFileType(path.suffix[1:].upper()) - else: - raise ValueError(f"Invalid path: {path}") - - if type_.name not in SerializableFileType.__members__: - raise ValueError( - f"Unsupported file format: {type_} " - f"(expected 'yaml' or 'json') for {path}" - ) - path.mkdir(parents=True, exist_ok=True) + path.parent.mkdir(parents=True, exist_ok=True) with path.open("w") as file: if type_ == SerializableFileType.YAML: diff --git a/src/guidellm/scheduler/scheduler.py b/src/guidellm/scheduler/scheduler.py index 614f176..7c46963 100644 --- a/src/guidellm/scheduler/scheduler.py +++ b/src/guidellm/scheduler/scheduler.py @@ -151,7 +151,7 @@ async def _run_async(self) -> TextGenerationBenchmark: if ( self._max_requests is not None - and requests_counter >= self._max_requests + and requests_counter >= self._max_requests - 1 ): break