Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Remove more usages of cursor_to_dict #16551

Merged
merged 12 commits into from
Oct 26, 2023
18 changes: 9 additions & 9 deletions synapse/handlers/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import urllib.parse
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple

import attr

from synapse.api.errors import (
CodeMessageException,
Codes,
Expand Down Expand Up @@ -357,9 +359,9 @@ async def send_threepid_validation(

# Check to see if a session already exists and that it is not yet
# marked as validated
if session and session.get("validated_at") is None:
session_id = session["session_id"]
last_send_attempt = session["last_send_attempt"]
if session and session.validated_at is None:
session_id = session.session_id
last_send_attempt = session.last_send_attempt

# Check that the send_attempt is higher than previous attempts
if send_attempt <= last_send_attempt:
Expand Down Expand Up @@ -480,27 +482,25 @@ async def validate_threepid_session(

# We don't actually know which medium this 3PID is. Thus we first assume it's email,
# and if validation fails we try msisdn
validation_session = None

# Try to validate as email
if self.hs.config.email.can_verify_email:
# Get a validated session matching these details
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)

if validation_session:
return validation_session
if validation_session:
return attr.asdict(validation_session)

# Try to validate as msisdn
if self.hs.config.registration.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
validation_session = await self.threepid_from_creds(
return await self.threepid_from_creds(
self.hs.config.registration.account_threepid_delegate_msisdn,
threepid_creds,
)

return validation_session
return None

async def proxy_msisdn_submit_token(
self, id_server: str, client_secret: str, sid: str, token: str
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/ui_auth/checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ async def _check_threepid(self, medium: str, authdict: dict) -> dict:

if row:
threepid = {
"medium": row["medium"],
"address": row["address"],
"validated_at": row["validated_at"],
"medium": row.medium,
"address": row.address,
"validated_at": row.validated_at,
}

# Valid threepid returned, delete from the db
Expand Down
5 changes: 1 addition & 4 deletions synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,10 +949,7 @@ async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:

deleted = 0

for media in old_media:
origin = media["media_origin"]
media_id = media["media_id"]
file_id = media["filesystem_id"]
for origin, media_id, file_id in old_media:
key = (origin, media_id)

logger.info("Deleting: %r", key)
Expand Down
12 changes: 11 additions & 1 deletion synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,17 @@ async def on_GET(
room_id, _ = await self.resolve_room_id(room_identifier)

extremities = await self.store.get_forward_extremities_for_room(room_id)
return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
result = [
{
"event_id": ex[0],
"state_group": ex[1],
"depth": ex[2],
"received_ts": ex[3],
}
for ex in extremities
]

return HTTPStatus.OK, {"count": len(extremities), "results": result}


class RoomEventContextServlet(RestServlet):
Expand Down
13 changes: 12 additions & 1 deletion synapse/rest/admin/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,18 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
users_media, total = await self.store.get_users_media_usage_paginate(
start, limit, from_ts, until_ts, order_by, direction, search_term
)
ret = {"users": users_media, "total": total}
ret = {
"users": [
{
"user_id": r[0],
"displayname": r[1],
"media_count": r[2],
"media_length": r[3],
}
for r in users_media
],
"total": total,
}
if (start + limit) < total:
ret["next_token"] = start + len(users_media)

Expand Down
15 changes: 10 additions & 5 deletions synapse/storage/databases/main/events_forward_extremities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Any, Dict, List
from typing import List, Tuple, cast

from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
Expand Down Expand Up @@ -91,12 +91,17 @@ def delete_forward_extremities_for_room_txn(txn: LoggingTransaction) -> int:

async def get_forward_extremities_for_room(
self, room_id: str
) -> List[Dict[str, Any]]:
"""Get list of forward extremities for a room."""
) -> List[Tuple[str, int, int, int]]:
"""
Get list of forward extremities for a room.

Returns:
A list of tuples of event_id, state_group, depth, and received_ts.
"""

def get_forward_extremities_for_room_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Any]]:
) -> List[Tuple[str, int, int, int]]:
clokep marked this conversation as resolved.
Show resolved Hide resolved
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
Expand All @@ -106,7 +111,7 @@ def get_forward_extremities_for_room_txn(
"""

txn.execute(sql, (room_id,))
return self.db_pool.cursor_to_dict(txn)
return cast(List[Tuple[str, int, int, int]], txn.fetchall())

return await self.db_pool.runInteraction(
"get_forward_extremities_for_room",
Expand Down
19 changes: 11 additions & 8 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ async def store_remote_media_thumbnail(

async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool
) -> List[Dict[str, str]]:
) -> List[Tuple[str, str, str]]:
"""
Retrieve a list of server name, media ID tuples from the remote media cache.

Expand All @@ -666,21 +666,24 @@ async def get_remote_media_ids(
A list of tuples containing:
* The server name of homeserver where the media originates from,
* The ID of the media.
* The filesystem ID.
"""

sql = """
SELECT media_origin, media_id, filesystem_id
FROM remote_media_cache
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
WHERE last_access_ts < ?
"""
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)

if include_quarantined_media is False:
# Only include media that has not been quarantined
sql += """
AND quarantined_by IS NULL
"""

return await self.db_pool.execute(
"get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
return cast(
List[Tuple[str, str, str]],
await self.db_pool.execute("get_remote_media_ids", None, sql, before_ts),
)

async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
Expand Down
43 changes: 29 additions & 14 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,22 @@ class ThreepidResult:
added_at: int


@attr.s(frozen=True, slots=True, auto_attribs=True)
class ThreepidValidationSession:
address: str
"""address of the 3pid"""
medium: str
"""medium of the 3pid"""
client_secret: str
"""a secret provided by the client for this validation session"""
session_id: str
"""ID of the validation session"""
last_send_attempt: int
"""a number serving to dedupe send attempts for this session"""
validated_at: Optional[int]
"""timestamp of when this session was validated if so"""


class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
Expand Down Expand Up @@ -1156,7 +1172,7 @@ async def get_threepid_validation_session(
address: Optional[str] = None,
sid: Optional[str] = None,
validated: Optional[bool] = True,
) -> Optional[Dict[str, Any]]:
) -> Optional[ThreepidValidationSession]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata

Expand All @@ -1171,15 +1187,7 @@ async def get_threepid_validation_session(
perform no filtering

Returns:
A dict containing the following:
* address - address of the 3pid
* medium - medium of the 3pid
* client_secret - a secret provided by the client for this validation session
* session_id - ID of the validation session
* send_attempt - a number serving to dedupe send attempts for this session
* validated_at - timestamp of when this session was validated if so

Otherwise None if a validation session is not found
A ThreepidValidationSession or None if a validation session is not found
"""
if not client_secret:
raise SynapseError(
Expand All @@ -1198,7 +1206,7 @@ async def get_threepid_validation_session(

def get_threepid_validation_session_txn(
txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]:
) -> Optional[ThreepidValidationSession]:
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
Expand All @@ -1213,11 +1221,18 @@ def get_threepid_validation_session_txn(
sql += " LIMIT 1"

txn.execute(sql, list(keyvalues.values()))
rows = self.db_pool.cursor_to_dict(txn)
if not rows:
row = txn.fetchone()
if not row:
return None

return rows[0]
return ThreepidValidationSession(
address=row[0],
session_id=row[1],
medium=row[2],
client_secret=row[3],
last_send_attempt=row[4],
validated_at=row[5],
)

return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
Expand Down
15 changes: 10 additions & 5 deletions synapse/storage/databases/main/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ async def get_users_media_usage_paginate(
order_by: Optional[str] = UserSortOrder.USER_ID.value,
direction: Direction = Direction.FORWARDS,
search_term: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
"""Function to retrieve a paginated list of users and their uploaded local media
(size and number). This will return a json list of users and the
total number of users matching the filter criteria.
Expand All @@ -692,14 +692,19 @@ async def get_users_media_usage_paginate(
order_by: the sort order of the returned list
direction: sort ascending or descending
search_term: a string to filter user names by

Returns:
A list of user dicts and an integer representing the total number of
users that exist given this query
A tuple of:
A list of tuples of user information (the user ID, displayname,
total number of media, total length of media) and

An integer representing the total number of users that exist
given this query
"""

def get_users_media_usage_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
filters = []
args: list = []

Expand Down Expand Up @@ -773,7 +778,7 @@ def get_users_media_usage_paginate_txn(

args += [limit, start]
txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn)
users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall())
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

return users, count

Expand Down