From 1eea662780a6325af0a61ceb447b4c91a2d3ac98 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 2 Mar 2023 18:27:00 +0000 Subject: [PATCH] Add a `get_next_txn` method to `StreamIdGenerator` to match `MultiWriterIdGenerator` (#15191 --- changelog.d/15191.misc | 1 + .../storage/databases/main/account_data.py | 11 +---- synapse/storage/util/id_generators.py | 45 ++++++++++++++++++- synapse/storage/util/sequence.py | 2 +- 4 files changed, 48 insertions(+), 11 deletions(-) create mode 100644 changelog.d/15191.misc diff --git a/changelog.d/15191.misc b/changelog.d/15191.misc new file mode 100644 index 000000000000..579f76d451ff --- /dev/null +++ b/changelog.d/15191.misc @@ -0,0 +1 @@ +Add a `get_next_txn` method to `StreamIdGenerator` to match `MultiWriterIdGenerator`. \ No newline at end of file diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 308d19440f63..2d2ba74347e1 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -40,7 +40,6 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, MultiWriterIdGenerator, StreamIdGenerator, ) @@ -64,14 +63,12 @@ def __init__( ): super().__init__(database, db_conn, hs) - # `_can_write_to_account_data` indicates whether the current worker is allowed - # to write account data. A value of `True` implies that `_account_data_id_gen` - # is an `AbstractStreamIdGenerator` and not just a tracker. - self._account_data_id_gen: AbstractStreamIdTracker self._can_write_to_account_data = ( self._instance_name in hs.config.worker.writers.account_data ) + self._account_data_id_gen: AbstractStreamIdGenerator + if isinstance(database.engine, PostgresEngine): self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, @@ -558,7 +555,6 @@ async def add_account_data_to_room( The maximum stream ID. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -598,7 +594,6 @@ async def remove_account_data_for_room( data to delete. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def _remove_account_data_for_room_txn( txn: LoggingTransaction, next_id: int @@ -663,7 +658,6 @@ async def add_account_data_for_user( The maximum stream ID. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( @@ -770,7 +764,6 @@ async def remove_account_data_for_user( to delete. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def _remove_account_data_for_user_txn( txn: LoggingTransaction, next_id: int diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 9adff3f4f523..334d3d718b4b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -158,6 +158,15 @@ def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: """ raise NotImplementedError() + @abc.abstractmethod + def get_next_txn(self, txn: LoggingTransaction) -> int: + """ + Usage: + stream_id_gen.get_next_txn(txn) + # ... persist events ... + """ + raise NotImplementedError() + class StreamIdGenerator(AbstractStreamIdGenerator): """Generates and tracks stream IDs for a stream with a single writer. @@ -263,6 +272,40 @@ def manager() -> Generator[Sequence[int], None, None]: return _AsyncCtxManagerWrapper(manager()) + def get_next_txn(self, txn: LoggingTransaction) -> int: + """ + Retrieve the next stream ID from within a database transaction. + + Clean-up functions will be called when the transaction finishes. + + Args: + txn: The database transaction object. + + Returns: + The next stream ID. + """ + if not self._is_writer: + raise Exception("Tried to allocate stream ID on non-writer") + + # Get the next stream ID. + with self._lock: + self._current += self._step + next_id = self._current + + self._unfinished_ids[next_id] = next_id + + def clear_unfinished_id(id_to_clear: int) -> None: + """A function to mark processing this ID as finished""" + with self._lock: + self._unfinished_ids.pop(id_to_clear) + + # Mark this ID as finished once the database transaction itself finishes. + txn.call_after(clear_unfinished_id, next_id) + txn.call_on_exception(clear_unfinished_id, next_id) + + # Return the new ID. + return next_id + def get_current_token(self) -> int: if not self._is_writer: return self._current @@ -568,7 +611,7 @@ def get_next_txn(self, txn: LoggingTransaction) -> int: """ Usage: - stream_id = stream_id_gen.get_next(txn) + stream_id = stream_id_gen.get_next_txn(txn) # ... persist event ... """ diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 75268cbe1595..80915216de94 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -205,7 +205,7 @@ def __init__(self, get_first_callback: GetFirstCallbackType): """ Args: get_first_callback: a callback which is called on the first call to - get_next_id_txn; should return the curreent maximum id + get_next_id_txn; should return the current maximum id """ # the callback. this is cleared after it is called, so that it can be GCed. self._callback: Optional[GetFirstCallbackType] = get_first_callback