Skip to content

Commit

Permalink
Fix JSON serialization error for UUID primary keys when excluded from…
Browse files Browse the repository at this point in the history
… list (#553)

* Add forced PK serialization

When PK is not listed among fields it needs proper serialization in case
it's not of a JSON-serializable type already (for example, UUID).

* Add string field size for MySQL tests

* Update tests/sqla/utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update tests/sqla/utils.py

* Update tests/sqla/utils.py

---------

Co-authored-by: Aleksey Gureiev <agureiev@shakuro.com>
Co-authored-by: Jocelin Hounon <hounonj@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Jul 8, 2024
1 parent 13e902f commit 056d976
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 17 deletions.
15 changes: 5 additions & 10 deletions starlette_admin/contrib/sqla/view.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, Union

import anyio.to_thread
from sqlalchemy import (
String,
and_,
cast,
func,
inspect,
or_,
select,
tuple_,
)
from sqlalchemy import String, and_, cast, func, inspect, or_, select, tuple_
from sqlalchemy.exc import DBAPIError, NoInspectionAvailable, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import (
Expand Down Expand Up @@ -606,6 +597,10 @@ def build_order_clauses(
async def get_pk_value(self, request: Request, obj: Any) -> Any:
return await self.pk_field.parse_obj(request, obj)

async def get_serialized_pk_value(self, request: Request, obj: Any) -> Any:
value = await self.get_pk_value(request, obj)
return await self.pk_field.serialize_value(request, value, request.state.action)

def handle_exception(self, exc: Exception) -> None:
try:
"""Automatically handle sqlalchemy_file error"""
Expand Down
34 changes: 27 additions & 7 deletions starlette_admin/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,8 @@ async def serialize(
obj_serialized[field.name] = None
elif isinstance(field, HasOne):
if action == RequestAction.EDIT:
obj_serialized[field.name] = await foreign_model.get_pk_value(
request, value
obj_serialized[field.name] = (
await foreign_model.get_serialized_pk_value(request, value)
)
else:
obj_serialized[field.name] = await foreign_model.serialize(
Expand All @@ -729,7 +729,7 @@ async def serialize(
else:
if action == RequestAction.EDIT:
obj_serialized[field.name] = [
(await foreign_model.get_pk_value(request, obj))
(await foreign_model.get_serialized_pk_value(request, obj))
for obj in value
]
else:
Expand All @@ -750,11 +750,14 @@ async def serialize(
"result": await self.select2_result(obj, request),
}
obj_meta["repr"] = await self.repr(obj, request)

# Make sure the primary key is always available
pk_attr = not_none(self.pk_attr)
if pk_attr not in obj_serialized:
pk_value = await self.get_serialized_pk_value(request, obj)
obj_serialized[pk_attr] = pk_value

pk = await self.get_pk_value(request, obj)
obj_serialized[not_none(self.pk_attr)] = obj_serialized.get(
not_none(self.pk_attr),
pk, # Make sure the primary key is always available
)
route_name = request.app.state.ROUTE_NAME
obj_meta["detailUrl"] = str(
request.url_for(route_name + ":detail", identity=self.identity, pk=pk)
Expand Down Expand Up @@ -879,6 +882,23 @@ async def select2_selection(self, obj: Any, request: Request) -> str:
async def get_pk_value(self, request: Request, obj: Any) -> Any:
return getattr(obj, not_none(self.pk_attr))

async def get_serialized_pk_value(self, request: Request, obj: Any) -> Any:
"""
Return serialized value of the primary key.
!!! note
The returned value should be JSON-serializable.
Parameters:
request: The request being processed
obj: object to get primary key of
Returns:
Any: Serialized value of a PK.
"""
return await self.get_pk_value(request, obj)

def _length_menu(self) -> Any:
return [
self.page_size_options,
Expand Down
92 changes: 92 additions & 0 deletions tests/sqla/test_view_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import uuid

import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy import Boolean, Column, ForeignKey, String
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, declarative_base, relationship
from starlette.applications import Starlette
from starlette_admin.contrib.sqla import Admin
from starlette_admin.contrib.sqla.view import ModelView

from tests.sqla.utils import Uuid, get_test_engine

Base = declarative_base()


class User(Base):
__tablename__ = "user"

id = Column(Uuid, primary_key=True, default=uuid.uuid1, unique=True)
name = Column(String(50))
membership = relationship("Membership", back_populates="user", uselist=False)


class Membership(Base):
__tablename__ = "membership"

id = Column(Uuid, primary_key=True, default=uuid.uuid1, unique=True)
is_active = Column(Boolean, default=True)

user_id = Column(Uuid, ForeignKey("user.id"), unique=True)
user = relationship("User", back_populates="membership")


class UserView(ModelView):
fields = ["name", "membership"]


@pytest.fixture
def engine() -> Engine:
engine = get_test_engine()

Base.metadata.create_all(engine)

try:
yield engine

finally:
Base.metadata.drop_all(engine)


@pytest.fixture
def app(engine: Engine):
app = Starlette()

admin = Admin(engine)
admin.add_view(UserView(User))
admin.add_view(ModelView(Membership))
admin.mount_to(app)

return app


@pytest_asyncio.fixture
async def client(app):
async with AsyncClient(app=app, base_url="http://testserver") as c:
yield c


@pytest.mark.asyncio
async def test_ensuring_pk(client: AsyncClient, engine: Engine):
"""
Ensures PK is present in the serialized data and properly serialized as a string.
"""
user_id = uuid.uuid1()
membership_id = uuid.uuid1()

user = User(id=user_id, name="Jack")
membership = Membership(id=membership_id, is_active=True, user=user)

with Session(engine) as session:
session.add(user)
session.add(membership)
session.commit()

response = await client.get("/admin/api/user")
data = response.json()

assert [(str(user_id), str(membership_id))] == [
(x["id"], x["membership"]["id"]) for x in data["items"]
]
25 changes: 25 additions & 0 deletions tests/sqla/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import os
import uuid

import sqlalchemy.types as types
from libcloud.storage.base import Container, StorageDriver
from libcloud.storage.drivers.local import LocalStorageDriver
from libcloud.storage.drivers.minio import MinIOStorageDriver
from libcloud.storage.types import ContainerDoesNotExistError
from sqlalchemy import create_engine
from sqlalchemy.dialects import mysql
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

Expand Down Expand Up @@ -41,3 +44,25 @@ def get_test_container(name: str) -> Container:
dir_path = os.environ.get("LOCAL_PATH", "/tmp/storage")
os.makedirs(dir_path, 0o777, exist_ok=True)
return get_or_create_container(LocalStorageDriver(dir_path), name)


class Uuid(types.TypeDecorator):
"""
Platform-independent UUID type for testing.
"""

impl = types.CHAR

def load_dialect_impl(self, dialect):
return mysql.CHAR(32) if dialect == "mysql" else types.CHAR(32)

def process_bind_param(self, value, dialect):

return value.hex

def process_result_value(self, value, dialect):

if not isinstance(value, uuid.UUID):
value = uuid.UUID(value)

return value

0 comments on commit 056d976

Please sign in to comment.