Skip to content

Commit

Permalink
Add forced PK serialization
Browse files Browse the repository at this point in the history
When PK is not listed among fields it needs proper serialization in case
it's not of a JSON-serializable type already (for example, UUID).
  • Loading branch information
Aleksey Gureiev committed Jul 1, 2024
1 parent 41a59cc commit d58e84d
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 17 deletions.
15 changes: 5 additions & 10 deletions starlette_admin/contrib/sqla/view.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, Union

import anyio.to_thread
from sqlalchemy import (
String,
and_,
cast,
func,
inspect,
or_,
select,
tuple_,
)
from sqlalchemy import String, and_, cast, func, inspect, or_, select, tuple_
from sqlalchemy.exc import DBAPIError, NoInspectionAvailable, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import (
Expand Down Expand Up @@ -606,6 +597,10 @@ def build_order_clauses(
async def get_pk_value(self, request: Request, obj: Any) -> Any:
return await self.pk_field.parse_obj(request, obj)

async def get_serialized_pk_value(self, request: Request, obj: Any) -> Any:
value = await self.get_pk_value(request, obj)
return await self.pk_field.serialize_value(request, value, request.state.action)

def handle_exception(self, exc: Exception) -> None:
try:
"""Automatically handle sqlalchemy_file error"""
Expand Down
34 changes: 27 additions & 7 deletions starlette_admin/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,8 @@ async def serialize(
obj_serialized[field.name] = None
elif isinstance(field, HasOne):
if action == RequestAction.EDIT:
obj_serialized[field.name] = await foreign_model.get_pk_value(
request, value
obj_serialized[field.name] = (
await foreign_model.get_serialized_pk_value(request, value)
)
else:
obj_serialized[field.name] = await foreign_model.serialize(
Expand All @@ -729,7 +729,7 @@ async def serialize(
else:
if action == RequestAction.EDIT:
obj_serialized[field.name] = [
(await foreign_model.get_pk_value(request, obj))
(await foreign_model.get_serialized_pk_value(request, obj))
for obj in value
]
else:
Expand All @@ -750,11 +750,14 @@ async def serialize(
"result": await self.select2_result(obj, request),
}
obj_meta["repr"] = await self.repr(obj, request)

# Make sure the primary key is always available
pk_attr = not_none(self.pk_attr)
if pk_attr not in obj_serialized:
pk_value = await self.get_serialized_pk_value(request, obj)
obj_serialized[pk_attr] = pk_value

pk = await self.get_pk_value(request, obj)
obj_serialized[not_none(self.pk_attr)] = obj_serialized.get(
not_none(self.pk_attr),
pk, # Make sure the primary key is always available
)
route_name = request.app.state.ROUTE_NAME
obj_meta["detailUrl"] = str(
request.url_for(route_name + ":detail", identity=self.identity, pk=pk)
Expand Down Expand Up @@ -879,6 +882,23 @@ async def select2_selection(self, obj: Any, request: Request) -> str:
async def get_pk_value(self, request: Request, obj: Any) -> Any:
return getattr(obj, not_none(self.pk_attr))

async def get_serialized_pk_value(self, request: Request, obj: Any) -> Any:
"""
Return serialized value of the primary key.
!!! note
The returned value should be JSON-serializable.
Parameters:
request: The request being processed
obj: object to get primary key of
Returns:
Any: Serialized value of a PK.
"""
return await self.get_pk_value(request, obj)

def _length_menu(self) -> Any:
return [
self.page_size_options,
Expand Down
72 changes: 72 additions & 0 deletions tests/sqla/test_view_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import uuid

import pytest
from sqlalchemy import Column, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base
from starlette.applications import Starlette
from starlette.requests import Request
from starlette_admin import RequestAction, StringField
from starlette_admin.base import BaseAdmin
from starlette_admin.contrib.sqla.view import ModelView

Base = declarative_base()


class User(Base):
__tablename__ = "user"

id: UUID = Column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True
)
name: str = Column(String())


class UserView(ModelView):
fields = [
StringField("name"),
]


@pytest.fixture
def user_view():
return UserView(User)


@pytest.fixture
def app(user_view: UserView):
app = Starlette()
app.state.ROUTE_NAME = "admin"

admin = BaseAdmin()
admin.add_view(user_view)
admin.mount_to(app)

return app


@pytest.fixture
def req(app: Starlette):
return Request(
{
"app": app,
"router": app.router,
"type": "http",
"headers": [],
}
)


@pytest.mark.asyncio
async def test_ensuring_pk(req: Request, user_view: UserView):
"""
Ensures PK is present in the serialized data and properly serialized as a string.
"""

user_id = uuid.uuid1()
user = User(id=user_id, name="Jack")

req.state.action = RequestAction.LIST
data = await user_view.serialize(user, req, RequestAction.LIST)

assert data["id"] == str(user_id)

0 comments on commit d58e84d

Please sign in to comment.