Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add aio_peek aio_frist #312

Merged
merged 2 commits into from
Feb 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/peewee_async/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,12 @@ AioModelSelect

.. autoclass:: peewee_async.aio_model.AioModelSelect

.. automethod:: peewee_async.aio_model.AioModelSelect.aio_peek

.. automethod:: peewee_async.aio_model.AioModelSelect.aio_scalar

.. automethod:: peewee_async.aio_model.AioModelSelect.aio_first

.. automethod:: peewee_async.aio_model.AioModelSelect.aio_get

.. automethod:: peewee_async.aio_model.AioModelSelect.aio_count
Expand Down
88 changes: 53 additions & 35 deletions peewee_async/aio_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,50 +84,72 @@ class AioModelRaw(peewee.ModelRaw, AioQueryMixin):
pass


class AioSelectMixin(AioQueryMixin):
class AioSelectMixin(AioQueryMixin, peewee.SelectBase):


@peewee.database_required
async def aio_scalar(self, database: AioDatabase, as_tuple: bool = False) -> Any:
async def aio_peek(self, database: AioDatabase, n: int = 1) -> Any:
"""
Get single value from ``select()`` query, i.e. for aggregation.

:return: result is the same as after sync ``query.scalar()`` call

See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.scalar
Asynchronous version of
`peewee.SelectBase.peek <https://docs.peewee-orm.com/en/latest/peewee/api.html#SelectBase.peek>`_
"""

async def fetch_results(cursor: CursorProtocol) -> Any:
return await cursor.fetchone()
return await fetch_models(cursor, self, n)

rows = await database.aio_execute(self, fetch_results=fetch_results)
if rows:
return rows[0] if n == 1 else rows

return rows[0] if rows and not as_tuple else rows
@peewee.database_required
async def aio_scalar(
self,
database: AioDatabase,
as_tuple: bool = False,
as_dict: bool = False
) -> Any:
"""
Asynchronous version of `peewee.SelectBase.scalar
<https://docs.peewee-orm.com/en/latest/peewee/api.html#SelectBase.scalar>`_
"""
if as_dict:
return await self.dicts().aio_peek(database)
row = await self.tuples().aio_peek(database)

async def aio_get(self, database: Optional[AioDatabase] = None) -> Any:
return row[0] if row and not as_tuple else row

@peewee.database_required
async def aio_first(self, database: AioDatabase, n: int = 1) -> Any:
"""
Asynchronous version of `peewee.SelectBase.first
<https://docs.peewee-orm.com/en/latest/peewee/api.html#SelectBase.first>`_
"""
Async version of **peewee.SelectBase.get**

See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.get
if self._limit != n: # type: ignore
self._limit = n
return await self.aio_peek(database, n=n)

async def aio_get(self, database: Optional[AioDatabase] = None) -> Any:
"""
clone = self.paginate(1, 1) # type: ignore
Asynchronous version of `peewee.SelectBase.get
<https://docs.peewee-orm.com/en/latest/peewee/api.html#SelectBase.get>`_
"""
clone = self.paginate(1, 1)
try:
return (await clone.aio_execute(database))[0]
except IndexError:
sql, params = clone.sql()
raise self.model.DoesNotExist('%s instance matching query does ' # type: ignore
raise self.model.DoesNotExist('%s instance matching query does '
'not exist:\nSQL: %s\nParams: %s' %
(clone.model, sql, params))

@peewee.database_required
async def aio_count(self, database: AioDatabase, clear_limit: bool = False) -> int:
"""
Async version of **peewee.SelectBase.count**

See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.count
Asynchronous version of `peewee.SelectBase.count
<https://docs.peewee-orm.com/en/latest/peewee/api.html#SelectBase.count>`_
"""
clone = self.order_by().alias('_wrapped') # type: ignore
clone = self.order_by().alias('_wrapped')
if clear_limit:
clone._limit = clone._offset = None
try:
Expand All @@ -145,38 +167,34 @@ async def aio_count(self, database: AioDatabase, clear_limit: bool = False) -> i
@peewee.database_required
async def aio_exists(self, database: AioDatabase) -> bool:
"""
Async version of **peewee.SelectBase.exists**

See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#SelectBase.exists
Asynchronous version of `peewee.SelectBase.exists
<https://docs.peewee-orm.com/en/latest/peewee/api.html#SelectBase.exists>`_
"""
clone = self.columns(peewee.SQL('1')) # type: ignore
clone = self.columns(peewee.SQL('1'))
clone._limit = 1
clone._offset = None
return bool(await clone.aio_scalar())

def union_all(self, rhs: Any) -> "AioModelCompoundSelectQuery":
return AioModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs) # type: ignore
return AioModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs)
__add__ = union_all

def union(self, rhs: Any) -> "AioModelCompoundSelectQuery":
return AioModelCompoundSelectQuery(self.model, self, 'UNION', rhs) # type: ignore
return AioModelCompoundSelectQuery(self.model, self, 'UNION', rhs)
__or__ = union

def intersect(self, rhs: Any) -> "AioModelCompoundSelectQuery":
return AioModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs) # type: ignore
return AioModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs)
__and__ = intersect

def except_(self, rhs: Any) -> "AioModelCompoundSelectQuery":
return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs) # type: ignore
return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs)
__sub__ = except_

def aio_prefetch(self, *subqueries: Any, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE) -> Any:
"""
Async version of **peewee.ModelSelect.prefetch**

See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#ModelSelect.prefetch
Asynchronous version of `peewee.ModelSelect.prefetch
<https://docs.peewee-orm.com/en/latest/peewee/api.html#ModelSelect.prefetch>`_
"""
return aio_prefetch(self, *subqueries, prefetch_type=prefetch_type)

Expand All @@ -186,7 +204,7 @@ class AioSelect(AioSelectMixin, peewee.Select):


class AioModelSelect(AioSelectMixin, peewee.ModelSelect):
"""Async version of **peewee.ModelSelect** that provides async versions of ModelSelect methods
"""Asynchronous version of **peewee.ModelSelect** that provides async versions of ModelSelect methods
"""
pass

Expand Down
70 changes: 57 additions & 13 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,21 @@ class AioDatabase(peewee.Database):
connection and **async connections pool** interface.

:param pool_params: parameters that are passed to the pool
:param min_connections: min connections pool size. Alias for pool_params.minsize
:param max_connections: max connections pool size. Alias for pool_params.maxsize

Example::

database = PooledPostgresqlExtDatabase(
'test',
'min_connections': 1,
'max_connections': 5,
'pool_params': {"timeout": 30, 'pool_recycle': 1.5}
'database': 'postgres',
'host': '127.0.0.1',
'port':5432,
'password': 'postgres',
'user': 'postgres',
'pool_params': {
"minsize": 0,
"maxsize": 5,
"timeout": 30,
'pool_recycle': 1.5
}
)

See also:
Expand Down Expand Up @@ -189,8 +194,23 @@ class PsycopgDatabase(AioDatabase, Psycopg3Database):
"""Extension for `peewee.PostgresqlDatabase` providing extra methods
for managing async connection based on psycopg3 pool backend.

Example::

database = PsycopgDatabase(
'database': 'postgres',
'host': '127.0.0.1',
'port': 5432,
'password': 'postgres',
'user': 'postgres',
'pool_params': {
"min_size": 0,
"max_size": 5,
'max_lifetime': 15
}
)

See also:
https://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase
https://www.psycopg.org/psycopg3/docs/advanced/pool.html
"""

pool_backend_cls = PsycopgPoolBackend
Expand All @@ -205,6 +225,23 @@ class PooledPostgresqlDatabase(AioDatabase, peewee.PostgresqlDatabase):
"""Extension for `peewee.PostgresqlDatabase` providing extra methods
for managing async connection based on aiopg pool backend.


Example::

database = PooledPostgresqlExtDatabase(
'database': 'postgres',
'host': '127.0.0.1',
'port':5432,
'password': 'postgres',
'user': 'postgres',
'pool_params': {
"minsize": 0,
"maxsize": 5,
"timeout": 30,
'pool_recycle': 1.5
}
)

See also:
https://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase
"""
Expand All @@ -230,11 +267,6 @@ class PooledPostgresqlExtDatabase(
JSON fields support is enabled by default, HStore supports is disabled by
default, but can be enabled through pool_params or with ``register_hstore=False`` argument.

Example::

database = PooledPostgresqlExtDatabase('test', register_hstore=False,
max_connections=20)

See also:
https://peewee.readthedocs.io/en/latest/peewee/playhouse.html#PostgresqlExtDatabase
"""
Expand All @@ -251,7 +283,19 @@ class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):

Example::

database = PooledMySQLDatabase('test', max_connections=10)
database = PooledMySQLDatabase(
'database': 'mysql',
'host': '127.0.0.1',
'port': 3306,
'user': 'root',
'password': 'mysql',
'connect_timeout': 30,
"pool_params": {
"minsize": 0,
"maxsize": 5,
"pool_recycle": 2
}
)

See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#MySQLDatabase
Expand Down
7 changes: 5 additions & 2 deletions peewee_async/result_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ def close(self) -> None:
pass


async def fetch_models(cursor: CursorProtocol, query: BaseQuery) -> List[Any]:
rows = await cursor.fetchall()
async def fetch_models(cursor: CursorProtocol, query: BaseQuery, count: Optional[int] = None) -> List[Any]:
if count is None:
rows = await cursor.fetchall()
else:
rows = await cursor.fetchmany(count)
sync_cursor = SyncCursorAdapter(rows, cursor.description)
_result_wrapper = query._get_cursor_wrapper(sync_cursor)
return list(_result_wrapper)
3 changes: 3 additions & 0 deletions peewee_async/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ async def fetchone(self) -> Any:
async def fetchall(self) -> List[Any]:
...

async def fetchmany(self, size: int) -> List[Any]:
...

@property
def lastrowid(self) -> int:
...
Expand Down
64 changes: 64 additions & 0 deletions tests/aio_model/test_shortcuts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List, Union
import uuid

import peewee
Expand Down Expand Up @@ -35,6 +36,64 @@ async def test_aio_get_or_none(db: AioDatabase) -> None:
assert result is None


@dbs_all
@pytest.mark.parametrize(
["peek_num", "expected"],
(
(1, 1),
(2, [1,2]),
(5, [1,2,3]),
)
)
async def test_aio_peek(
db: AioDatabase,
peek_num: int,
expected: Union[int, List[int]]
) -> None:
await IntegerTestModel.aio_create(num=1)
await IntegerTestModel.aio_create(num=2)
await IntegerTestModel.aio_create(num=3)

rows = await IntegerTestModel.select().order_by(
IntegerTestModel.num
).aio_peek(n=peek_num)

if isinstance(rows, list):
result = [r.num for r in rows]
else:
result = rows.num
assert result == expected


@dbs_all
@pytest.mark.parametrize(
["first_num", "expected"],
(
(1, 1),
(2, [1,2]),
(5, [1,2,3]),
)
)
async def test_aio_first(
db: AioDatabase,
first_num: int,
expected: Union[int, List[int]]
) -> None:
await IntegerTestModel.aio_create(num=1)
await IntegerTestModel.aio_create(num=2)
await IntegerTestModel.aio_create(num=3)

rows = await IntegerTestModel.select().order_by(
IntegerTestModel.num
).aio_first(n=first_num)

if isinstance(rows, list):
result = [r.num for r in rows]
else:
result = rows.num
assert result == expected


@dbs_all
async def test_aio_scalar(db: AioDatabase) -> None:
await IntegerTestModel.aio_create(num=1)
Expand All @@ -46,6 +105,11 @@ async def test_aio_scalar(db: AioDatabase) -> None:
fn.MAX(IntegerTestModel.num),fn.Min(IntegerTestModel.num)
).aio_scalar(as_tuple=True) == (2, 1)

assert await IntegerTestModel.select(
fn.MAX(IntegerTestModel.num).alias('max'),
fn.Min(IntegerTestModel.num).alias('min')
).aio_scalar(as_dict=True) == {'max': 2, 'min': 1}

assert await TestModel.select().aio_scalar() is None


Expand Down