diff --git a/docs/peewee_async/api.rst b/docs/peewee_async/api.rst index 3e2799f..4b7843a 100644 --- a/docs/peewee_async/api.rst +++ b/docs/peewee_async/api.rst @@ -57,8 +57,12 @@ AioModelSelect .. autoclass:: peewee_async.aio_model.AioModelSelect +.. automethod:: peewee_async.aio_model.AioModelSelect.aio_peek + .. automethod:: peewee_async.aio_model.AioModelSelect.aio_scalar +.. automethod:: peewee_async.aio_model.AioModelSelect.aio_first + .. automethod:: peewee_async.aio_model.AioModelSelect.aio_get .. automethod:: peewee_async.aio_model.AioModelSelect.aio_count diff --git a/peewee_async/aio_model.py b/peewee_async/aio_model.py index 37d822a..d19e5f7 100644 --- a/peewee_async/aio_model.py +++ b/peewee_async/aio_model.py @@ -84,50 +84,72 @@ class AioModelRaw(peewee.ModelRaw, AioQueryMixin): pass -class AioSelectMixin(AioQueryMixin): +class AioSelectMixin(AioQueryMixin, peewee.SelectBase): + @peewee.database_required - async def aio_scalar(self, database: AioDatabase, as_tuple: bool = False) -> Any: + async def aio_peek(self, database: AioDatabase, n: int = 1) -> Any: """ - Get single value from ``select()`` query, i.e. for aggregation. - - :return: result is the same as after sync ``query.scalar()`` call - - See also: - http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.scalar + Asynchronous version of + `peewee.SelectBase.peek `_ """ + async def fetch_results(cursor: CursorProtocol) -> Any: - return await cursor.fetchone() + return await fetch_models(cursor, self, n) rows = await database.aio_execute(self, fetch_results=fetch_results) + if rows: + return rows[0] if n == 1 else rows - return rows[0] if rows and not as_tuple else rows + @peewee.database_required + async def aio_scalar( + self, + database: AioDatabase, + as_tuple: bool = False, + as_dict: bool = False + ) -> Any: + """ + Asynchronous version of `peewee.SelectBase.scalar + `_ + """ + if as_dict: + return await self.dicts().aio_peek(database) + row = await self.tuples().aio_peek(database) - async def aio_get(self, database: Optional[AioDatabase] = None) -> Any: + return row[0] if row and not as_tuple else row + + @peewee.database_required + async def aio_first(self, database: AioDatabase, n: int = 1) -> Any: + """ + Asynchronous version of `peewee.SelectBase.first + `_ """ - Async version of **peewee.SelectBase.get** - See also: - http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.get + if self._limit != n: # type: ignore + self._limit = n + return await self.aio_peek(database, n=n) + + async def aio_get(self, database: Optional[AioDatabase] = None) -> Any: """ - clone = self.paginate(1, 1) # type: ignore + Asynchronous version of `peewee.SelectBase.get + `_ + """ + clone = self.paginate(1, 1) try: return (await clone.aio_execute(database))[0] except IndexError: sql, params = clone.sql() - raise self.model.DoesNotExist('%s instance matching query does ' # type: ignore + raise self.model.DoesNotExist('%s instance matching query does ' 'not exist:\nSQL: %s\nParams: %s' % (clone.model, sql, params)) @peewee.database_required async def aio_count(self, database: AioDatabase, clear_limit: bool = False) -> int: """ - Async version of **peewee.SelectBase.count** - - See also: - http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.count + Asynchronous version of `peewee.SelectBase.count + `_ """ - clone = self.order_by().alias('_wrapped') # type: ignore + clone = self.order_by().alias('_wrapped') if clear_limit: clone._limit = clone._offset = None try: @@ -145,38 +167,34 @@ async def aio_count(self, database: AioDatabase, clear_limit: bool = False) -> i @peewee.database_required async def aio_exists(self, database: AioDatabase) -> bool: """ - Async version of **peewee.SelectBase.exists** - - See also: - http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.exists + Asynchronous version of `peewee.SelectBase.exists + `_ """ - clone = self.columns(peewee.SQL('1')) # type: ignore + clone = self.columns(peewee.SQL('1')) clone._limit = 1 clone._offset = None return bool(await clone.aio_scalar()) def union_all(self, rhs: Any) -> "AioModelCompoundSelectQuery": - return AioModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs) # type: ignore + return AioModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs) __add__ = union_all def union(self, rhs: Any) -> "AioModelCompoundSelectQuery": - return AioModelCompoundSelectQuery(self.model, self, 'UNION', rhs) # type: ignore + return AioModelCompoundSelectQuery(self.model, self, 'UNION', rhs) __or__ = union def intersect(self, rhs: Any) -> "AioModelCompoundSelectQuery": - return AioModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs) # type: ignore + return AioModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs) __and__ = intersect def except_(self, rhs: Any) -> "AioModelCompoundSelectQuery": - return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs) # type: ignore + return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs) __sub__ = except_ def aio_prefetch(self, *subqueries: Any, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE) -> Any: """ - Async version of **peewee.ModelSelect.prefetch** - - See also: - http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#ModelSelect.prefetch + Asynchronous version of `peewee.ModelSelect.prefetch + `_ """ return aio_prefetch(self, *subqueries, prefetch_type=prefetch_type) @@ -186,7 +204,7 @@ class AioSelect(AioSelectMixin, peewee.Select): class AioModelSelect(AioSelectMixin, peewee.ModelSelect): - """Async version of **peewee.ModelSelect** that provides async versions of ModelSelect methods + """Asynchronous version of **peewee.ModelSelect** that provides async versions of ModelSelect methods """ pass diff --git a/peewee_async/databases.py b/peewee_async/databases.py index dfa7e62..e981f62 100644 --- a/peewee_async/databases.py +++ b/peewee_async/databases.py @@ -18,16 +18,21 @@ class AioDatabase(peewee.Database): connection and **async connections pool** interface. :param pool_params: parameters that are passed to the pool - :param min_connections: min connections pool size. Alias for pool_params.minsize - :param max_connections: max connections pool size. Alias for pool_params.maxsize Example:: database = PooledPostgresqlExtDatabase( - 'test', - 'min_connections': 1, - 'max_connections': 5, - 'pool_params': {"timeout": 30, 'pool_recycle': 1.5} + 'database': 'postgres', + 'host': '127.0.0.1', + 'port':5432, + 'password': 'postgres', + 'user': 'postgres', + 'pool_params': { + "minsize": 0, + "maxsize": 5, + "timeout": 30, + 'pool_recycle': 1.5 + } ) See also: @@ -189,8 +194,23 @@ class PsycopgDatabase(AioDatabase, Psycopg3Database): """Extension for `peewee.PostgresqlDatabase` providing extra methods for managing async connection based on psycopg3 pool backend. + Example:: + + database = PsycopgDatabase( + 'database': 'postgres', + 'host': '127.0.0.1', + 'port': 5432, + 'password': 'postgres', + 'user': 'postgres', + 'pool_params': { + "min_size": 0, + "max_size": 5, + 'max_lifetime': 15 + } + ) + See also: - https://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase + https://www.psycopg.org/psycopg3/docs/advanced/pool.html """ pool_backend_cls = PsycopgPoolBackend @@ -205,6 +225,23 @@ class PooledPostgresqlDatabase(AioDatabase, peewee.PostgresqlDatabase): """Extension for `peewee.PostgresqlDatabase` providing extra methods for managing async connection based on aiopg pool backend. + + Example:: + + database = PooledPostgresqlExtDatabase( + 'database': 'postgres', + 'host': '127.0.0.1', + 'port':5432, + 'password': 'postgres', + 'user': 'postgres', + 'pool_params': { + "minsize": 0, + "maxsize": 5, + "timeout": 30, + 'pool_recycle': 1.5 + } + ) + See also: https://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase """ @@ -230,11 +267,6 @@ class PooledPostgresqlExtDatabase( JSON fields support is enabled by default, HStore supports is disabled by default, but can be enabled through pool_params or with ``register_hstore=False`` argument. - Example:: - - database = PooledPostgresqlExtDatabase('test', register_hstore=False, - max_connections=20) - See also: https://peewee.readthedocs.io/en/latest/peewee/playhouse.html#PostgresqlExtDatabase """ @@ -251,7 +283,19 @@ class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase): Example:: - database = PooledMySQLDatabase('test', max_connections=10) + database = PooledMySQLDatabase( + 'database': 'mysql', + 'host': '127.0.0.1', + 'port': 3306, + 'user': 'root', + 'password': 'mysql', + 'connect_timeout': 30, + "pool_params": { + "minsize": 0, + "maxsize": 5, + "pool_recycle": 2 + } + ) See also: http://peewee.readthedocs.io/en/latest/peewee/api.html#MySQLDatabase diff --git a/peewee_async/result_wrappers.py b/peewee_async/result_wrappers.py index 777fc74..40eaf78 100644 --- a/peewee_async/result_wrappers.py +++ b/peewee_async/result_wrappers.py @@ -23,8 +23,11 @@ def close(self) -> None: pass -async def fetch_models(cursor: CursorProtocol, query: BaseQuery) -> List[Any]: - rows = await cursor.fetchall() +async def fetch_models(cursor: CursorProtocol, query: BaseQuery, count: Optional[int] = None) -> List[Any]: + if count is None: + rows = await cursor.fetchall() + else: + rows = await cursor.fetchmany(count) sync_cursor = SyncCursorAdapter(rows, cursor.description) _result_wrapper = query._get_cursor_wrapper(sync_cursor) return list(_result_wrapper) diff --git a/peewee_async/utils.py b/peewee_async/utils.py index 1d17da1..9dfd794 100644 --- a/peewee_async/utils.py +++ b/peewee_async/utils.py @@ -33,6 +33,9 @@ async def fetchone(self) -> Any: async def fetchall(self) -> List[Any]: ... + async def fetchmany(self, size: int) -> List[Any]: + ... + @property def lastrowid(self) -> int: ... diff --git a/tests/aio_model/test_shortcuts.py b/tests/aio_model/test_shortcuts.py index d388de5..e4fd6e9 100644 --- a/tests/aio_model/test_shortcuts.py +++ b/tests/aio_model/test_shortcuts.py @@ -1,3 +1,4 @@ +from typing import List, Union import uuid import peewee @@ -35,6 +36,64 @@ async def test_aio_get_or_none(db: AioDatabase) -> None: assert result is None +@dbs_all +@pytest.mark.parametrize( + ["peek_num", "expected"], + ( + (1, 1), + (2, [1,2]), + (5, [1,2,3]), + ) +) +async def test_aio_peek( + db: AioDatabase, + peek_num: int, + expected: Union[int, List[int]] +) -> None: + await IntegerTestModel.aio_create(num=1) + await IntegerTestModel.aio_create(num=2) + await IntegerTestModel.aio_create(num=3) + + rows = await IntegerTestModel.select().order_by( + IntegerTestModel.num + ).aio_peek(n=peek_num) + + if isinstance(rows, list): + result = [r.num for r in rows] + else: + result = rows.num + assert result == expected + + +@dbs_all +@pytest.mark.parametrize( + ["first_num", "expected"], + ( + (1, 1), + (2, [1,2]), + (5, [1,2,3]), + ) +) +async def test_aio_first( + db: AioDatabase, + first_num: int, + expected: Union[int, List[int]] +) -> None: + await IntegerTestModel.aio_create(num=1) + await IntegerTestModel.aio_create(num=2) + await IntegerTestModel.aio_create(num=3) + + rows = await IntegerTestModel.select().order_by( + IntegerTestModel.num + ).aio_first(n=first_num) + + if isinstance(rows, list): + result = [r.num for r in rows] + else: + result = rows.num + assert result == expected + + @dbs_all async def test_aio_scalar(db: AioDatabase) -> None: await IntegerTestModel.aio_create(num=1) @@ -46,6 +105,11 @@ async def test_aio_scalar(db: AioDatabase) -> None: fn.MAX(IntegerTestModel.num),fn.Min(IntegerTestModel.num) ).aio_scalar(as_tuple=True) == (2, 1) + assert await IntegerTestModel.select( + fn.MAX(IntegerTestModel.num).alias('max'), + fn.Min(IntegerTestModel.num).alias('min') + ).aio_scalar(as_dict=True) == {'max': 2, 'min': 1} + assert await TestModel.select().aio_scalar() is None