Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Support DataFrame init from raw SQLAlchemy rows #19820

Merged
merged 2 commits into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions py-polars/polars/_utils/construction/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import contextlib
from collections.abc import Generator, Mapping
from collections.abc import Generator, Mapping, Sequence
from datetime import date, datetime, time, timedelta
from functools import singledispatch
from itertools import islice, zip_longest
Expand All @@ -20,6 +20,7 @@
is_namedtuple,
is_pydantic_model,
is_simple_numpy_backed_pandas_series,
is_sqlalchemy,
nt_unpack,
try_get_type_hints,
)
Expand Down Expand Up @@ -59,7 +60,7 @@
from polars.polars import PyDataFrame

if TYPE_CHECKING:
from collections.abc import Iterable, MutableMapping, Sequence
from collections.abc import Iterable, MutableMapping

from polars import DataFrame, Expr, Series
from polars._typing import (
Expand Down Expand Up @@ -480,9 +481,9 @@ def _sequence_to_pydf_dispatcher(
infer_schema_length: int | None,
) -> PyDataFrame:
# note: ONLY python-native data should participate in singledispatch registration
# via top-level decorators. third-party libraries (such as numpy/pandas) should
# first be identified inline (here) and THEN registered for dispatch dynamically
# so as not to break lazy-loading behaviour.
# via top-level decorators, otherwise we have to import the associated module.
# third-party libraries (such as numpy/pandas) should instead be identified inline
# and THEN registered for dispatch (here) so as not to break lazy-loading behaviour.

common_params = {
"data": data,
Expand All @@ -492,7 +493,6 @@ def _sequence_to_pydf_dispatcher(
"orient": orient,
"infer_schema_length": infer_schema_length,
}

to_pydf: Callable[..., PyDataFrame]
register_with_singledispatch = True

Expand All @@ -518,6 +518,12 @@ def _sequence_to_pydf_dispatcher(

elif is_pydantic_model(first_element):
to_pydf = _sequence_of_pydantic_models_to_pydf

elif is_sqlalchemy(first_element):
to_pydf = _sequence_of_tuple_to_pydf

elif isinstance(first_element, Sequence) and not isinstance(first_element, str):
to_pydf = _sequence_of_sequence_to_pydf
else:
to_pydf = _sequence_of_elements_to_pydf

Expand Down Expand Up @@ -652,7 +658,7 @@ def _sequence_of_tuple_to_pydf(
infer_schema_length: int | None,
) -> PyDataFrame:
# infer additional meta information if namedtuple
if is_namedtuple(first_element.__class__):
if is_namedtuple(first_element.__class__) or is_sqlalchemy(first_element):
if schema is None:
schema = first_element._fields # type: ignore[attr-defined]
annotations = getattr(first_element, "__annotations__", None)
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/_utils/construction/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
is_namedtuple,
is_pydantic_model,
is_simple_numpy_backed_pandas_series,
is_sqlalchemy,
)
from polars._utils.various import (
range_to_series,
Expand Down Expand Up @@ -105,6 +106,7 @@ def sequence_to_pyseries(
dataclasses.is_dataclass(value)
or is_pydantic_model(value)
or is_namedtuple(value.__class__)
or is_sqlalchemy(value)
) and dtype != Object:
return pl.DataFrame(values).to_struct(name)._s
elif isinstance(value, range) and dtype is None:
Expand Down
5 changes: 5 additions & 0 deletions py-polars/polars/_utils/construction/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ def is_pydantic_model(value: Any) -> bool:
return _check_for_pydantic(value) and isinstance(value, pydantic.BaseModel)


def is_sqlalchemy(value: Any) -> bool:
"""Check whether value is an instance of a SQLAlchemy object."""
return getattr(value, "__module__", "").startswith("sqlalchemy.")


def get_first_non_none(values: Sequence[Any | None]) -> Any:
"""
Return the first value from a sequence that isn't None.
Expand Down
22 changes: 13 additions & 9 deletions py-polars/polars/io/database/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,24 +178,24 @@ def _fetch_arrow(
yield arrow

@staticmethod
def _fetchall_rows(result: Cursor) -> Iterable[Sequence[Any]]:
def _fetchall_rows(result: Cursor, *, is_alchemy: bool) -> Iterable[Sequence[Any]]:
"""Fetch row data in a single call, returning the complete result set."""
rows = result.fetchall()
return (
[tuple(row) for row in rows]
if rows and not isinstance(rows[0], (list, tuple, dict))
else rows
rows
if rows and (is_alchemy or isinstance(rows[0], (list, tuple, dict)))
else [tuple(row) for row in rows]
)

def _fetchmany_rows(
self, result: Cursor, batch_size: int | None
self, result: Cursor, *, batch_size: int | None, is_alchemy: bool
) -> Iterable[Sequence[Any]]:
"""Fetch row data incrementally, yielding over the complete result set."""
while True:
rows = result.fetchmany(batch_size)
if not rows:
break
elif isinstance(rows[0], (list, tuple, dict)):
elif is_alchemy or isinstance(rows[0], (list, tuple, dict)):
yield rows
else:
yield [tuple(row) for row in rows]
Expand Down Expand Up @@ -267,7 +267,7 @@ def _from_rows(
self.result = _run_async(self.result)
try:
if hasattr(self.result, "fetchall"):
if self.driver_name == "sqlalchemy":
if is_alchemy := (self.driver_name == "sqlalchemy"):
if hasattr(self.result, "cursor"):
cursor_desc = [
(d[0], d[1:]) for d in self.result.cursor.description
Expand Down Expand Up @@ -297,9 +297,13 @@ def _from_rows(
orient="row",
)
for rows in (
self._fetchmany_rows(self.result, batch_size)
self._fetchmany_rows(
self.result,
batch_size=batch_size,
is_alchemy=is_alchemy,
)
if iter_batches
else [self._fetchall_rows(self.result)] # type: ignore[list-item]
else [self._fetchall_rows(self.result, is_alchemy=is_alchemy)] # type: ignore[list-item]
)
)
return frames if iter_batches else next(frames) # type: ignore[arg-type]
Expand Down
36 changes: 30 additions & 6 deletions py-polars/tests/unit/io/database/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from polars._utils.various import parse_version
from polars.exceptions import DuplicateError, UnsuitableSQLError
from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY
from polars.testing import assert_frame_equal
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
from polars._typing import (
Expand Down Expand Up @@ -390,11 +390,13 @@ def test_read_database_alchemy_textclause(tmp_sqlite_db: Path) -> None:
alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()

# establish sqlalchemy "textclause" and validate usage
textclause_query = text("""
SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
FROM test_data
WHERE value < 0
""")
textclause_query = text(
"""
SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
FROM test_data
WHERE value < 0
"""
)

expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

Expand Down Expand Up @@ -815,3 +817,25 @@ def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None:
schema={"a.name": pl.Utf8, "f.since": pl.Int64, "b.name": pl.Utf8}
),
)


def test_sqlalchemy_row_init(tmp_sqlite_db: Path) -> None:
expected_frame = pl.DataFrame(
{
"id": [1, 2],
"name": ["misc", "other"],
"value": [100.0, -99.5],
"date": ["2020-01-01", "2021-12-31"],
}
)
expected_series = expected_frame.to_struct()

alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
with alchemy_engine.connect() as conn:
query_result = conn.execute(text("SELECT * FROM test_data ORDER BY name"))
df = pl.DataFrame(list(query_result))
assert_frame_equal(expected_frame, df)

query_result = conn.execute(text("SELECT * FROM test_data ORDER BY name"))
s = pl.Series(list(query_result))
assert_series_equal(expected_series, s)
Loading