From 056d97650a8133bd7291718309b735c57dc9a03a Mon Sep 17 00:00:00 2001 From: Aleksey Gureiev Date: Mon, 8 Jul 2024 16:12:21 +0300 Subject: [PATCH] Fix JSON serialization error for UUID primary keys when excluded from list (#553) * Add forced PK serialization When PK is not listed among fields it needs proper serialization in case it's not of a JSON-serializable type already (for example, UUID). * Add string field size for MySQL tests * Update tests/sqla/utils.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update tests/sqla/utils.py * Update tests/sqla/utils.py --------- Co-authored-by: Aleksey Gureiev Co-authored-by: Jocelin Hounon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- starlette_admin/contrib/sqla/view.py | 15 ++--- starlette_admin/views.py | 34 ++++++++-- tests/sqla/test_view_serialization.py | 92 +++++++++++++++++++++++++++ tests/sqla/utils.py | 25 ++++++++ 4 files changed, 149 insertions(+), 17 deletions(-) create mode 100644 tests/sqla/test_view_serialization.py diff --git a/starlette_admin/contrib/sqla/view.py b/starlette_admin/contrib/sqla/view.py index 90558591..c0236255 100644 --- a/starlette_admin/contrib/sqla/view.py +++ b/starlette_admin/contrib/sqla/view.py @@ -1,16 +1,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, Union import anyio.to_thread -from sqlalchemy import ( - String, - and_, - cast, - func, - inspect, - or_, - select, - tuple_, -) +from sqlalchemy import String, and_, cast, func, inspect, or_, select, tuple_ from sqlalchemy.exc import DBAPIError, NoInspectionAvailable, SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import ( @@ -606,6 +597,10 @@ def build_order_clauses( async def get_pk_value(self, request: Request, obj: Any) -> Any: return await self.pk_field.parse_obj(request, obj) + async def get_serialized_pk_value(self, request: Request, obj: Any) -> Any: + value = await self.get_pk_value(request, obj) + return await self.pk_field.serialize_value(request, value, request.state.action) + def handle_exception(self, exc: Exception) -> None: try: """Automatically handle sqlalchemy_file error""" diff --git a/starlette_admin/views.py b/starlette_admin/views.py index b0568a46..8dd197ef 100644 --- a/starlette_admin/views.py +++ b/starlette_admin/views.py @@ -719,8 +719,8 @@ async def serialize( obj_serialized[field.name] = None elif isinstance(field, HasOne): if action == RequestAction.EDIT: - obj_serialized[field.name] = await foreign_model.get_pk_value( - request, value + obj_serialized[field.name] = ( + await foreign_model.get_serialized_pk_value(request, value) ) else: obj_serialized[field.name] = await foreign_model.serialize( @@ -729,7 +729,7 @@ async def serialize( else: if action == RequestAction.EDIT: obj_serialized[field.name] = [ - (await foreign_model.get_pk_value(request, obj)) + (await foreign_model.get_serialized_pk_value(request, obj)) for obj in value ] else: @@ -750,11 +750,14 @@ async def serialize( "result": await self.select2_result(obj, request), } obj_meta["repr"] = await self.repr(obj, request) + + # Make sure the primary key is always available + pk_attr = not_none(self.pk_attr) + if pk_attr not in obj_serialized: + pk_value = await self.get_serialized_pk_value(request, obj) + obj_serialized[pk_attr] = pk_value + pk = await self.get_pk_value(request, obj) - obj_serialized[not_none(self.pk_attr)] = obj_serialized.get( - not_none(self.pk_attr), - pk, # Make sure the primary key is always available - ) route_name = request.app.state.ROUTE_NAME obj_meta["detailUrl"] = str( request.url_for(route_name + ":detail", identity=self.identity, pk=pk) @@ -879,6 +882,23 @@ async def select2_selection(self, obj: Any, request: Request) -> str: async def get_pk_value(self, request: Request, obj: Any) -> Any: return getattr(obj, not_none(self.pk_attr)) + async def get_serialized_pk_value(self, request: Request, obj: Any) -> Any: + """ + Return serialized value of the primary key. + + !!! note + + The returned value should be JSON-serializable. + + Parameters: + request: The request being processed + obj: object to get primary key of + + Returns: + Any: Serialized value of a PK. + """ + return await self.get_pk_value(request, obj) + def _length_menu(self) -> Any: return [ self.page_size_options, diff --git a/tests/sqla/test_view_serialization.py b/tests/sqla/test_view_serialization.py new file mode 100644 index 00000000..294b7aff --- /dev/null +++ b/tests/sqla/test_view_serialization.py @@ -0,0 +1,92 @@ +import uuid + +import pytest +import pytest_asyncio +from httpx import AsyncClient +from sqlalchemy import Boolean, Column, ForeignKey, String +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session, declarative_base, relationship +from starlette.applications import Starlette +from starlette_admin.contrib.sqla import Admin +from starlette_admin.contrib.sqla.view import ModelView + +from tests.sqla.utils import Uuid, get_test_engine + +Base = declarative_base() + + +class User(Base): + __tablename__ = "user" + + id = Column(Uuid, primary_key=True, default=uuid.uuid1, unique=True) + name = Column(String(50)) + membership = relationship("Membership", back_populates="user", uselist=False) + + +class Membership(Base): + __tablename__ = "membership" + + id = Column(Uuid, primary_key=True, default=uuid.uuid1, unique=True) + is_active = Column(Boolean, default=True) + + user_id = Column(Uuid, ForeignKey("user.id"), unique=True) + user = relationship("User", back_populates="membership") + + +class UserView(ModelView): + fields = ["name", "membership"] + + +@pytest.fixture +def engine() -> Engine: + engine = get_test_engine() + + Base.metadata.create_all(engine) + + try: + yield engine + + finally: + Base.metadata.drop_all(engine) + + +@pytest.fixture +def app(engine: Engine): + app = Starlette() + + admin = Admin(engine) + admin.add_view(UserView(User)) + admin.add_view(ModelView(Membership)) + admin.mount_to(app) + + return app + + +@pytest_asyncio.fixture +async def client(app): + async with AsyncClient(app=app, base_url="http://testserver") as c: + yield c + + +@pytest.mark.asyncio +async def test_ensuring_pk(client: AsyncClient, engine: Engine): + """ + Ensures PK is present in the serialized data and properly serialized as a string. + """ + user_id = uuid.uuid1() + membership_id = uuid.uuid1() + + user = User(id=user_id, name="Jack") + membership = Membership(id=membership_id, is_active=True, user=user) + + with Session(engine) as session: + session.add(user) + session.add(membership) + session.commit() + + response = await client.get("/admin/api/user") + data = response.json() + + assert [(str(user_id), str(membership_id))] == [ + (x["id"], x["membership"]["id"]) for x in data["items"] + ] diff --git a/tests/sqla/utils.py b/tests/sqla/utils.py index 06206faa..9babc40d 100644 --- a/tests/sqla/utils.py +++ b/tests/sqla/utils.py @@ -1,10 +1,13 @@ import os +import uuid +import sqlalchemy.types as types from libcloud.storage.base import Container, StorageDriver from libcloud.storage.drivers.local import LocalStorageDriver from libcloud.storage.drivers.minio import MinIOStorageDriver from libcloud.storage.types import ContainerDoesNotExistError from sqlalchemy import create_engine +from sqlalchemy.dialects import mysql from sqlalchemy.engine import Engine from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine @@ -41,3 +44,25 @@ def get_test_container(name: str) -> Container: dir_path = os.environ.get("LOCAL_PATH", "/tmp/storage") os.makedirs(dir_path, 0o777, exist_ok=True) return get_or_create_container(LocalStorageDriver(dir_path), name) + + +class Uuid(types.TypeDecorator): + """ + Platform-independent UUID type for testing. + """ + + impl = types.CHAR + + def load_dialect_impl(self, dialect): + return mysql.CHAR(32) if dialect == "mysql" else types.CHAR(32) + + def process_bind_param(self, value, dialect): + + return value.hex + + def process_result_value(self, value, dialect): + + if not isinstance(value, uuid.UUID): + value = uuid.UUID(value) + + return value