Add type hints to synapse/storage/databases/main/account_data.py
#11546
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
@@ -14,15 +14,25 @@
 # limitations under the License.

 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast

 from synapse.api.constants import AccountDataTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage._base import db_to_json
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -34,14 +44,24 @@
 logger = logging.getLogger(__name__)


-class AccountDataWorkerStore(SQLBaseStore):
+class AccountDataWorkerStore(CacheInvalidationWorkerStore):
     """This is an abstract base class where subclasses must implement
     `get_max_account_data_stream_id` which can be called in the initializer.
     """

-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self._instance_name = hs.get_instance_name()

+        # `_can_write_to_account_data` indicates whether the current worker is allowed
+        # to write account data. A value of `True` implies that `_account_data_id_gen`
+        # is an `AbstractStreamIdGenerator` and not just a tracker.
+        self._account_data_id_gen: AbstractStreamIdTracker
+
         if isinstance(database.engine, PostgresEngine):
             self._can_write_to_account_data = (
                 self._instance_name in hs.config.worker.writers.account_data
@@ -61,8 +81,6 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
                 writers=hs.config.worker.writers.account_data,
             )
         else:
-            self._can_write_to_account_data = True
-
             # We shouldn't be running in worker mode with SQLite, but its useful
             # to support it for unit tests.
             #
@@ -71,6 +89,7 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
             # updated over replication. (Multiple writers are not supported for
             # SQLite).
             if hs.get_instance_name() in hs.config.worker.writers.account_data:
+                self._can_write_to_account_data = True
                 self._account_data_id_gen = StreamIdGenerator(
                     db_conn,
                     "room_account_data",

Review thread on the added `self._can_write_to_account_data = True` line (marked as resolved by DMRobertson):

Comment: We have […]

Reply: Good spot!
@@ -113,7 +132,9 @@ async def get_account_data_for_user(
             room_id string to per room account_data dicts.
         """

-        def get_account_data_for_user_txn(txn):
+        def get_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
             rows = self.db_pool.simple_select_list_txn(
                 txn,
                 "account_data",
@@ -132,7 +153,7 @@ def get_account_data_for_user_txn(txn):
                 ["room_id", "account_data_type", "content"],
             )

-            by_room = {}
+            by_room: Dict[str, Dict[str, JsonDict]] = {}
             for row in rows:
                 room_data = by_room.setdefault(row["room_id"], {})
                 room_data[row["account_data_type"]] = db_to_json(row["content"])
@@ -177,7 +198,9 @@ async def get_account_data_for_room(
             A dict of the room account_data
         """

-        def get_account_data_for_room_txn(txn):
+        def get_account_data_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, JsonDict]:
             rows = self.db_pool.simple_select_list_txn(
                 txn,
                 "room_account_data",
@@ -207,7 +230,9 @@ async def get_account_data_for_room_and_type(
             The room account_data for that type, or None if there isn't any set.
         """

-        def get_account_data_for_room_and_type_txn(txn):
+        def get_account_data_for_room_and_type_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[JsonDict]:
             content_json = self.db_pool.simple_select_one_onecol_txn(
                 txn,
                 table="room_account_data",
@@ -243,14 +268,16 @@ async def get_updated_global_account_data(
         if last_id == current_id:
             return []

-        def get_updated_global_account_data_txn(txn):
+        def get_updated_global_account_data_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str]]:
             sql = (
                 "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str]], txn.fetchall())

         return await self.db_pool.runInteraction(
             "get_updated_global_account_data", get_updated_global_account_data_txn

DMRobertson marked a conversation on the `cast` line as resolved.
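
An aside on the `cast` above: typeshed types `Cursor.fetchall()` as returning rows of `Any`, so the cast merely records the row shape implied by the SELECT column list for the type checker; it does nothing at runtime. A standalone illustration using sqlite3 rather than Synapse's database wrappers:

import sqlite3
from typing import List, Tuple, cast

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE account_data (stream_id INTEGER, user_id TEXT, account_data_type TEXT)"
)
conn.execute("INSERT INTO account_data VALUES (1, '@alice:example.com', 'm.direct')")

cur = conn.execute("SELECT stream_id, user_id, account_data_type FROM account_data")
# fetchall() is typed as returning rows of Any; the cast is a runtime no-op
# that records the (int, str, str) shape guaranteed by the SELECT list.
rows = cast(List[Tuple[int, str, str]], cur.fetchall())
for stream_id, user_id, data_type in rows:
    print(stream_id, user_id, data_type)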
@@ -273,14 +300,16 @@ async def get_updated_room_account_data(
         if last_id == current_id:
             return []

-        def get_updated_room_account_data_txn(txn):
+        def get_updated_room_account_data_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str]]:
             sql = (
                 "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str]], txn.fetchall())

         return await self.db_pool.runInteraction(
             "get_updated_room_account_data", get_updated_room_account_data_txn
@@ -299,7 +328,9 @@ async def get_updated_account_data_for_user(
             mapping from room_id string to per room account_data dicts.
         """

-        def get_updated_account_data_for_user_txn(txn):
+        def get_updated_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
             sql = (
                 "SELECT account_data_type, content FROM account_data"
                 " WHERE user_id = ? AND stream_id > ?"
@@ -316,7 +347,7 @@ def get_updated_account_data_for_user_txn(txn):

             txn.execute(sql, (user_id, stream_id))

-            account_data_by_room = {}
+            account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
             for row in txn:
                 room_account_data = account_data_by_room.setdefault(row[0], {})
                 room_account_data[row[1]] = db_to_json(row[2])
@@ -353,12 +384,15 @@ async def ignored_by(self, user_id: str) -> Set[str]:
             )
         )

-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == TagAccountDataStream.NAME:
             self._account_data_id_gen.advance(instance_name, token)
             for row in rows:
                 self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)

Review thread on the new `rows: Iterable[Any]` annotation:

Comment: I think this is an iterable over some kind of rowproxy? But it looks like this gets used throughout the inheritance hierarchy in a way that makes type checking not feasible.

Reply: I couldn't figure out what the […]. I think it's a […]. None of the row types share any fields in common, so […].

Reply: Maybe we could make […]
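
To illustrate the difficulty discussed in that thread: the row classes below are hypothetical stand-ins for the per-stream row types (the real ones live in `synapse.replication.tcp.streams`), but they show why a precise union type would force `isinstance` checks before every attribute access:

from dataclasses import dataclass
from typing import Any, Iterable, Union

# Hypothetical stand-ins for per-stream row types; the real classes differ
# per stream and carry unrelated fields.
@dataclass
class AccountDataStreamRow:
    user_id: str
    room_id: str
    data_type: str


@dataclass
class OtherStreamRow:
    entity: str


StreamRow = Union[AccountDataStreamRow, OtherStreamRow]


def process(stream_name: str, rows: Iterable[Any]) -> None:
    # Had this been Iterable[StreamRow], mypy would reject row.user_id until
    # an isinstance() check picked a variant, because the union members share
    # no common fields. Iterable[Any] trades that precision for workability.
    for row in rows:
        print(stream_name, row.user_id)


process(
    "account_data",
    [AccountDataStreamRow("@alice:example.com", "!room:example.com", "m.fully_read")],
)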
Review thread on lines -359 to -361 (the `TagAccountDataStream` branch above):

Comment: At some point in the past, I left the call to […]. I think we don't need to call […]

Reply: Say we did instantiate such a store. What does leaving the […]

Reply: Without the […] I think the effect of that would be that various bits of Synapse would not get notified about data written after […]
         elif stream_name == AccountDataStream.NAME:
             self._account_data_id_gen.advance(instance_name, token)
             for row in rows:
@@ -372,7 +406,8 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
                     (row.user_id, row.room_id, row.data_type)
                 )
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)

     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -389,6 +424,7 @@ async def add_account_data_to_room(
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

         content_json = json_encoder.encode(content)

@@ -431,6 +467,7 @@ async def add_account_data_for_user(
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

         async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
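
For context on the `async with` here: `get_next()` hands back an async context manager, so an id is allocated on entry, the caller persists its row inside the block, and the generator does its bookkeeping on exit. A toy sketch of that shape (assumed semantics, not Synapse's actual generator):

import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator


class ToyIdGenerator:
    def __init__(self) -> None:
        self._current = 0
        self._lock = asyncio.Lock()

    @asynccontextmanager
    async def get_next(self) -> AsyncIterator[int]:
        async with self._lock:
            self._current += 1
            next_id = self._current
        try:
            # The caller persists its row with this id inside the block.
            yield next_id
        finally:
            # A real generator would do bookkeeping here, e.g. marking the
            # id as persisted so that readers may advance past it.
            pass


async def main() -> None:
    gen = ToyIdGenerator()
    async with gen.get_next() as next_id:
        print(f"writing row with stream id {next_id}")


asyncio.run(main())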
@@ -452,7 +489,7 @@ async def add_account_data_for_user(

     def _add_account_data_for_user(
         self,
-        txn,
+        txn: LoggingTransaction,
         next_id: int,
         user_id: str,
         account_data_type: str,
Review thread:

Comment: Is this comment accurate? There appears to be an implementation of this function on +114 with no override that PyCharm recognises.

Reply: It looks inaccurate, I'll remove it.