diff --git a/changelog.d/12888.misc b/changelog.d/12888.misc new file mode 100644 index 000000000000..8ed2ea65b5a8 --- /dev/null +++ b/changelog.d/12888.misc @@ -0,0 +1 @@ +Refactor receipt linearization code. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index cfa4d4924d54..09f78c61c0cb 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -597,7 +597,7 @@ def process_replication_rows( return super().process_replication_rows(stream_name, instance_name, token, rows) - def insert_linearized_receipt_txn( + def _insert_linearized_receipt_txn( self, txn: LoggingTransaction, room_id: str, @@ -686,6 +686,44 @@ def insert_linearized_receipt_txn( return rx_ts + def _graph_to_linear( + self, txn: LoggingTransaction, room_id: str, event_ids: List[str] + ) -> str: + """ + Generate a linearized event from a list of events (i.e. a list of forward + extremities in the room). + + This should allow for calculation of the correct read receipt even if + servers have different event ordering. + + Args: + txn: The transaction + room_id: The room ID the events are in. + event_ids: The list of event IDs to linearize. + + Returns: + The linearized event ID. + """ + # TODO: Make this better. + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", event_ids + ) + + sql = """ + SELECT event_id WHERE room_id = ? AND stream_ordering IN ( + SELECT max(stream_ordering) WHERE %s + ) + """ % ( + clause, + ) + + txn.execute(sql, [room_id] + list(args)) + rows = txn.fetchall() + if rows: + return rows[0][0] + else: + raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) + async def insert_receipt( self, room_id: str, @@ -712,35 +750,14 @@ async def insert_receipt( linearized_event_id = event_ids[0] else: # we need to points in graph -> linearized form. - # TODO: Make this better. - def graph_to_linear(txn: LoggingTransaction) -> str: - clause, args = make_in_list_sql_clause( - self.database_engine, "event_id", event_ids - ) - - sql = """ - SELECT event_id WHERE room_id = ? AND stream_ordering IN ( - SELECT max(stream_ordering) WHERE %s - ) - """ % ( - clause, - ) - - txn.execute(sql, [room_id] + list(args)) - rows = txn.fetchall() - if rows: - return rows[0][0] - else: - raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = await self.db_pool.runInteraction( - "insert_receipt_conv", graph_to_linear + "insert_receipt_conv", self._graph_to_linear, room_id, event_ids ) async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined] event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", - self.insert_linearized_receipt_txn, + self._insert_linearized_receipt_txn, room_id, receipt_type, user_id, @@ -761,25 +778,9 @@ def graph_to_linear(txn: LoggingTransaction) -> str: now - event_ts, ) - await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) - - max_persisted_id = self._receipts_id_gen.get_current_token() - - return stream_id, max_persisted_id - - async def insert_graph_receipt( - self, - room_id: str, - receipt_type: str, - user_id: str, - event_ids: List[str], - data: JsonDict, - ) -> None: - assert self._can_write_to_receipts - await self.db_pool.runInteraction( "insert_graph_receipt", - self.insert_graph_receipt_txn, + self._insert_graph_receipt_txn, room_id, receipt_type, user_id, @@ -787,7 +788,11 @@ async def insert_graph_receipt( data, ) - def insert_graph_receipt_txn( + max_persisted_id = self._receipts_id_gen.get_current_token() + + return stream_id, max_persisted_id + + def _insert_graph_receipt_txn( self, txn: LoggingTransaction, room_id: str,