Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to synapse/storage/databases/main/account_data.py #11546

Merged
merged 4 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11546.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to storage classes.
4 changes: 3 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ exclude = (?x)
^(
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/account_data.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/e2e_room_keys.py
Expand Down Expand Up @@ -181,6 +180,9 @@ disallow_untyped_defs = True
[mypy-synapse.state.*]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.account_data]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.client_ips]
disallow_untyped_defs = True

Expand Down
85 changes: 61 additions & 24 deletions synapse/storage/databases/main/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,25 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast

from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage._base import db_to_json
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
Expand All @@ -34,14 +44,24 @@
logger = logging.getLogger(__name__)


class AccountDataWorkerStore(SQLBaseStore):
class AccountDataWorkerStore(CacheInvalidationWorkerStore):
"""This is an abstract base class where subclasses must implement
`get_max_account_data_stream_id` which can be called in the initializer.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment accurate? There appears to be an implementation of this function on +114 with no override that PyCharm recognises.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks inaccurate, I'll remove it.

"""

def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self._instance_name = hs.get_instance_name()

# `_can_write_to_account_data` indicates whether the current worker is allowed
# to write account data. A value of `True` implies that `_account_data_id_gen`
# is an `AbstractStreamIdGenerator` and not just a tracker.
self._account_data_id_gen: AbstractStreamIdTracker

if isinstance(database.engine, PostgresEngine):
self._can_write_to_account_data = (
self._instance_name in hs.config.worker.writers.account_data
Expand All @@ -61,8 +81,6 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
writers=hs.config.worker.writers.account_data,
)
else:
self._can_write_to_account_data = True

# We shouldn't be running in worker mode with SQLite, but it's useful
# to support it for unit tests.
#
Expand All @@ -71,6 +89,7 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
# updated over replication. (Multiple writers are not supported for
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.account_data:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have self._instance_name already defined on this class, but we also have that on the parent class. Can we use self._instance_name everywhere and remove the redundant overload?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good spot!

self._can_write_to_account_data = True
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"room_account_data",
Expand Down Expand Up @@ -113,7 +132,9 @@ async def get_account_data_for_user(
room_id string to per room account_data dicts.
"""

def get_account_data_for_user_txn(txn):
def get_account_data_for_user_txn(
txn: LoggingTransaction,
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
rows = self.db_pool.simple_select_list_txn(
txn,
"account_data",
Expand All @@ -132,7 +153,7 @@ def get_account_data_for_user_txn(txn):
["room_id", "account_data_type", "content"],
)

by_room = {}
by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
room_data[row["account_data_type"]] = db_to_json(row["content"])
Expand Down Expand Up @@ -177,7 +198,9 @@ async def get_account_data_for_room(
A dict of the room account_data
"""

def get_account_data_for_room_txn(txn):
def get_account_data_for_room_txn(
txn: LoggingTransaction,
) -> Dict[str, JsonDict]:
rows = self.db_pool.simple_select_list_txn(
txn,
"room_account_data",
Expand Down Expand Up @@ -207,7 +230,9 @@ async def get_account_data_for_room_and_type(
The room account_data for that type, or None if there isn't any set.
"""

def get_account_data_for_room_and_type_txn(txn):
def get_account_data_for_room_and_type_txn(
txn: LoggingTransaction,
) -> Optional[JsonDict]:
content_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="room_account_data",
Expand Down Expand Up @@ -243,14 +268,16 @@ async def get_updated_global_account_data(
if last_id == current_id:
return []

def get_updated_global_account_data_txn(txn):
def get_updated_global_account_data_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str]]:
sql = (
"SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return cast(List[Tuple[int, str, str]], txn.fetchall())
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

return await self.db_pool.runInteraction(
"get_updated_global_account_data", get_updated_global_account_data_txn
Expand All @@ -273,14 +300,16 @@ async def get_updated_room_account_data(
if last_id == current_id:
return []

def get_updated_room_account_data_txn(txn):
def get_updated_room_account_data_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str]]:
sql = (
"SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return cast(List[Tuple[int, str, str, str]], txn.fetchall())

return await self.db_pool.runInteraction(
"get_updated_room_account_data", get_updated_room_account_data_txn
Expand All @@ -299,7 +328,9 @@ async def get_updated_account_data_for_user(
mapping from room_id string to per room account_data dicts.
"""

def get_updated_account_data_for_user_txn(txn):
def get_updated_account_data_for_user_txn(
txn: LoggingTransaction,
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
sql = (
"SELECT account_data_type, content FROM account_data"
" WHERE user_id = ? AND stream_id > ?"
Expand All @@ -316,7 +347,7 @@ def get_updated_account_data_for_user_txn(txn):

txn.execute(sql, (user_id, stream_id))

account_data_by_room = {}
account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = db_to_json(row[2])
Expand Down Expand Up @@ -353,12 +384,15 @@ async def ignored_by(self, user_id: str) -> Set[str]:
)
)

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is an iterable over some kind of rowproxy? But it looks like this gets used throughout the inheritance hierarchy in way that makes type checking not feasible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't figure out what the Any was supposed to be. For now I've been using Iterable[Any] for all overrides of this method.

I think it's a Union[any type in Stream.ROW_TYPE]?

None of the row types share any fields in common, so Any or object is the best we can do without refactoring.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could make Stream generic over its row type? One for another day though.

) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
Comment on lines -359 to -361
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At some point in the past, TagsWorkerStore was split out from AccountDataWorkerStore (ca9b9d9). I think this was missed.

I left the call to advance in, in case we ever instantiate an AccountDataWorkerStore that is not a TagsWorkerStore.
advance ought to be idempotent so it should be safe to call it twice.

I think we don't need to call entity_has_changed here, since it only impacts one spot, which doesn't appear to have anything to do with tags:

changed = self._account_data_stream_cache.has_entity_changed(

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left the call to advance in, in case we ever instantiate an AccountDataWorkerStore that is not a TagsWorkerStore.

Say we did instantiate such a store. What does leaving the advance call in gain us? It feels a bit odd to have something that looks like it's responding to/handling tag changes, but doesn't really do anything with them.

Copy link
Contributor Author

@squahtx squahtx Dec 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without the advance, _account_data_id_gen.get_current_token() would lag behind since it's the minimum stream position across all writers. In the worst case it'd get stuck entirely if there is a writer that only writes tag data.

I think the effect of that would be that various bits of Synapse would not get notified about data written after get_current_token(), since it can't be sure that all writes had completed. I can't say which bits for sure without doing more digging.

elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
Expand All @@ -372,7 +406,8 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
(row.user_id, row.room_id, row.data_type)
)
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

super().process_replication_rows(stream_name, instance_name, token, rows)

async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
Expand All @@ -389,6 +424,7 @@ async def add_account_data_to_room(
The maximum stream ID.
"""
assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

content_json = json_encoder.encode(content)

Expand Down Expand Up @@ -431,6 +467,7 @@ async def add_account_data_for_user(
The maximum stream ID.
"""
assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
Expand All @@ -452,7 +489,7 @@ async def add_account_data_for_user(

def _add_account_data_for_user(
self,
txn,
txn: LoggingTransaction,
next_id: int,
user_id: str,
account_data_type: str,
Expand Down
22 changes: 21 additions & 1 deletion synapse/storage/databases/main/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
# limitations under the License.

import logging
from typing import Dict, List, Tuple, cast
from typing import Any, Dict, Iterable, List, Tuple, cast

from synapse.replication.tcp.streams import TagAccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
Expand Down Expand Up @@ -204,6 +206,7 @@ async def add_tag_to_room(
The next account data ID.
"""
assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

content_json = json_encoder.encode(content)

Expand All @@ -230,6 +233,7 @@ async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> in
The next account data ID.
"""
assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
sql = (
Expand Down Expand Up @@ -258,6 +262,7 @@ def _update_revision_txn(
next_id: The revision to advance to.
"""
assert self._can_write_to_account_data
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

txn.call_after(
self._account_data_stream_cache.entity_has_changed, user_id, next_id
Expand Down Expand Up @@ -287,6 +292,21 @@ def _update_revision_txn(
# than the id that the client has.
pass

def process_replication_rows(
    self,
    stream_name: str,
    instance_name: str,
    token: int,
    rows: Iterable[Any],
) -> None:
    """Apply rows received over replication to this store's local state.

    For the tag account data stream, the account data ID generator is
    advanced to `token` for `instance_name`, and for every row the cached
    tags of the row's user are invalidated and the user is marked as changed
    at `token` in the account data stream cache. Rows from every stream are
    then passed on to the parent class.

    Args:
        stream_name: Name of the replication stream the rows came from.
        instance_name: Name of the instance that wrote the rows.
        token: Stream position associated with the rows.
        rows: The rows to process; tag rows are read via their `user_id`
            attribute.
    """
    is_tag_stream = stream_name == TagAccountDataStream.NAME
    if is_tag_stream:
        self._account_data_id_gen.advance(instance_name, token)
        for tag_row in rows:
            changed_user = tag_row.user_id
            self.get_tags_for_user.invalidate((changed_user,))
            self._account_data_stream_cache.entity_has_changed(changed_user, token)

    # Always delegate so that parent classes see rows from every stream.
    super().process_replication_rows(stream_name, instance_name, token, rows)


# Tags store with an empty body: it adds nothing beyond what it inherits
# from TagsWorkerStore.
class TagsStore(TagsWorkerStore):
pass