Skip to content

Commit

Permalink
Add forced PK serialization
Browse files Browse the repository at this point in the history
When PK is not listed among fields it needs proper serialization in case
it's not of a JSON-serializable type already (for example, UUID).
  • Loading branch information
Aleksey Gureiev committed Jul 3, 2024
1 parent 41a59cc commit 92562a5
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 17 deletions.
15 changes: 5 additions & 10 deletions starlette_admin/contrib/sqla/view.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, Union

import anyio.to_thread
from sqlalchemy import (
String,
and_,
cast,
func,
inspect,
or_,
select,
tuple_,
)
from sqlalchemy import String, and_, cast, func, inspect, or_, select, tuple_
from sqlalchemy.exc import DBAPIError, NoInspectionAvailable, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import (
Expand Down Expand Up @@ -606,6 +597,10 @@ def build_order_clauses(
async def get_pk_value(self, request: Request, obj: Any) -> Any:
return await self.pk_field.parse_obj(request, obj)

async def get_serialized_pk_value(self, request: Request, obj: Any) -> Any:
value = await self.get_pk_value(request, obj)
return await self.pk_field.serialize_value(request, value, request.state.action)

def handle_exception(self, exc: Exception) -> None:
try:
"""Automatically handle sqlalchemy_file error"""
Expand Down
34 changes: 27 additions & 7 deletions starlette_admin/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,8 @@ async def serialize(
obj_serialized[field.name] = None
elif isinstance(field, HasOne):
if action == RequestAction.EDIT:
obj_serialized[field.name] = await foreign_model.get_pk_value(
request, value
obj_serialized[field.name] = (
await foreign_model.get_serialized_pk_value(request, value)
)
else:
obj_serialized[field.name] = await foreign_model.serialize(
Expand All @@ -729,7 +729,7 @@ async def serialize(
else:
if action == RequestAction.EDIT:
obj_serialized[field.name] = [
(await foreign_model.get_pk_value(request, obj))
(await foreign_model.get_serialized_pk_value(request, obj))
for obj in value
]
else:
Expand All @@ -750,11 +750,14 @@ async def serialize(
"result": await self.select2_result(obj, request),
}
obj_meta["repr"] = await self.repr(obj, request)

# Make sure the primary key is always available
pk_attr = not_none(self.pk_attr)
if pk_attr not in obj_serialized:
pk_value = await self.get_serialized_pk_value(request, obj)
obj_serialized[pk_attr] = pk_value

pk = await self.get_pk_value(request, obj)
obj_serialized[not_none(self.pk_attr)] = obj_serialized.get(
not_none(self.pk_attr),
pk, # Make sure the primary key is always available
)
route_name = request.app.state.ROUTE_NAME
obj_meta["detailUrl"] = str(
request.url_for(route_name + ":detail", identity=self.identity, pk=pk)
Expand Down Expand Up @@ -879,6 +882,23 @@ async def select2_selection(self, obj: Any, request: Request) -> str:
async def get_pk_value(self, request: Request, obj: Any) -> Any:
return getattr(obj, not_none(self.pk_attr))

async def get_serialized_pk_value(self, request: Request, obj: Any) -> Any:
"""
Return serialized value of the primary key.
!!! note
The returned value should be JSON-serializable.
Parameters:
request: The request being processed
obj: object to get primary key of
Returns:
Any: Serialized value of a PK.
"""
return await self.get_pk_value(request, obj)

def _length_menu(self) -> Any:
return [
self.page_size_options,
Expand Down
92 changes: 92 additions & 0 deletions tests/sqla/test_view_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import uuid

import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy import Boolean, Column, ForeignKey, String
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, declarative_base, relationship
from starlette.applications import Starlette
from starlette_admin.contrib.sqla import Admin
from starlette_admin.contrib.sqla.view import ModelView

from tests.sqla.utils import Uuid, get_test_engine

Base = declarative_base()


class User(Base):
__tablename__ = "user"

id = Column(Uuid, primary_key=True, default=uuid.uuid1, unique=True)
name = Column(String())
membership = relationship("Membership", back_populates="user", uselist=False)


class Membership(Base):
__tablename__ = "membership"

id = Column(Uuid, primary_key=True, default=uuid.uuid1, unique=True)
is_active = Column(Boolean, default=True)

user_id = Column(Uuid, ForeignKey("user.id"), unique=True)
user = relationship("User", back_populates="membership")


class UserView(ModelView):
fields = ["name", "membership"]


@pytest.fixture
def engine() -> Engine:
engine = get_test_engine()

Base.metadata.create_all(engine)

try:
yield engine

finally:
Base.metadata.drop_all(engine)


@pytest.fixture
def app(engine: Engine):
app = Starlette()

admin = Admin(engine)
admin.add_view(UserView(User))
admin.add_view(ModelView(Membership))
admin.mount_to(app)

return app


@pytest_asyncio.fixture
async def client(app):
async with AsyncClient(app=app, base_url="http://testserver") as c:
yield c


@pytest.mark.asyncio
async def test_ensuring_pk(client: AsyncClient, engine: Engine):
"""
Ensures PK is present in the serialized data and properly serialized as a string.
"""
user_id = uuid.uuid1()
membership_id = uuid.uuid1()

user = User(id=user_id, name="Jack")
membership = Membership(id=membership_id, is_active=True, user=user)

with Session(engine) as session:
session.add(user)
session.add(membership)
session.commit()

response = await client.get("/admin/api/user")
data = response.json()

assert [(str(user_id), str(membership_id))] == [
(x["id"], x["membership"]["id"]) for x in data["items"]
]
32 changes: 32 additions & 0 deletions tests/sqla/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import os
import uuid

import sqlalchemy.types as types
from libcloud.storage.base import Container, StorageDriver
from libcloud.storage.drivers.local import LocalStorageDriver
from libcloud.storage.drivers.minio import MinIOStorageDriver
from libcloud.storage.types import ContainerDoesNotExistError
from sqlalchemy import create_engine
from sqlalchemy.dialects import mysql
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

Expand Down Expand Up @@ -41,3 +44,32 @@ def get_test_container(name: str) -> Container:
dir_path = os.environ.get("LOCAL_PATH", "/tmp/storage")
os.makedirs(dir_path, 0o777, exist_ok=True)
return get_or_create_container(LocalStorageDriver(dir_path), name)


class Uuid(types.TypeDecorator):
"""
Platform-independent UUID type for testing.
"""

impl = types.CHAR

def load_dialect_impl(self, dialect):
return mysql.CHAR(32) if dialect == "mysql" else types.CHAR(32)

def process_bind_param(self, value, dialect):
if value is None:
return value

if not isinstance(value, uuid.UUID):
value = uuid.UUID(value)

return value.hex

def process_result_value(self, value, dialect):
if value is None:
return value

if not isinstance(value, uuid.UUID):
value = uuid.UUID(value)

return value

0 comments on commit 92562a5

Please sign in to comment.