Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Initial implementation of MSC3981: recursive relations API #15315

Merged
merged 6 commits into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15315.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support to recursively provide relations per [MSC3981](/~https://github.com/matrix-org/matrix-spec-proposals/pull/3981).
5 changes: 5 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,5 +192,10 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# MSC2659: Application service ping endpoint
self.msc2659_enabled = experimental.get("msc2659_enabled", False)

# MSC3981: Recurse relations
self.msc3981_recurse_relations = experimental.get(
"msc3981_recurse_relations", False
)

# MSC3970: Scope transaction IDs to devices
self.msc3970_enabled = experimental.get("msc3970_enabled", False)
3 changes: 3 additions & 0 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ async def get_relations(
event_id: str,
room_id: str,
pagin_config: PaginationConfig,
recurse: bool,
include_original_event: bool,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
Expand All @@ -98,6 +99,7 @@ async def get_relations(
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
pagin_config: The pagination config rules to apply, if any.
recurse: Whether to recursively find relations.
include_original_event: Whether to include the parent event.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
Expand Down Expand Up @@ -132,6 +134,7 @@ async def get_relations(
direction=pagin_config.direction,
from_token=pagin_config.from_token,
to_token=pagin_config.to_token,
recurse=recurse,
)

events = await self._main_store.get_events_as_list(
Expand Down
10 changes: 9 additions & 1 deletion synapse/rest/client/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from synapse.api.constants import Direction
from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.storage.databases.main.relations import ThreadsNextBatch
Expand Down Expand Up @@ -49,6 +49,7 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self._store = hs.get_datastores().main
self._relations_handler = hs.get_relations_handler()
self._support_recurse = hs.config.experimental.msc3981_recurse_relations

async def on_GET(
self,
Expand All @@ -63,6 +64,12 @@ async def on_GET(
pagination_config = await PaginationConfig.from_request(
self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
)
if self._support_recurse:
recurse = parse_boolean(
request, "org.matrix.msc3981.recurse", default=False
)
else:
recurse = False

# The unstable version of this API returns an extra field for client
# compatibility, see /~https://github.com/matrix-org/synapse/issues/12930.
Expand All @@ -75,6 +82,7 @@ async def on_GET(
event_id=parent_id,
room_id=room_id,
pagin_config=pagination_config,
recurse=recurse,
include_original_event=include_original_event,
relation_type=relation_type,
event_type=event_type,
Expand Down
65 changes: 48 additions & 17 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ async def get_relations_for_event(
direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
recurse: bool = False,
) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.

Expand All @@ -186,6 +187,7 @@ async def get_relations_for_event(
oldest first (forwards).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
recurse: Whether to recursively find relations.

Returns:
A tuple of:
Expand All @@ -200,8 +202,8 @@ async def get_relations_for_event(
# Ensure bad limits aren't being passed in.
assert limit >= 0

where_clause = ["relates_to_id = ?", "room_id = ?"]
where_args: List[Union[str, int]] = [event.event_id, room_id]
where_clause = ["room_id = ?"]
where_args: List[Union[str, int]] = [room_id]
is_redacted = event.internal_metadata.is_redacted()

if relation_type is not None:
Expand Down Expand Up @@ -229,23 +231,52 @@ async def get_relations_for_event(
if pagination_clause:
where_clause.append(pagination_clause)

sql = """
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
ORDER BY topological_ordering %s, stream_ordering %s
LIMIT ?
""" % (
" AND ".join(where_clause),
order,
order,
)
# If a recursive query is requested then the filters are applied after
# recursively following relationships from the requested event to children
# up to 3-relations deep.
#
# If no recursion is needed then the event_relations table is queried
# for direct children of the requested event.
if recurse:
sql = """
WITH RECURSIVE related_events AS (
SELECT event_id, relation_type, relates_to_id, 0 AS depth
FROM event_relations
WHERE relates_to_id = ?
UNION SELECT e.event_id, e.relation_type, e.relates_to_id, depth + 1
FROM event_relations e
INNER JOIN related_events r ON r.event_id = e.relates_to_id
WHERE depth <= 3
)
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM related_events
INNER JOIN events USING (event_id)
WHERE %s
ORDER BY topological_ordering %s, stream_ordering %s
LIMIT ?;
Comment on lines +251 to +256
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SELECT portion of these queries are the same -- would be better / clearer to structure this as a preamble & table name to query and only have one copy of the shared bit?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is probably fine for now

""" % (
" AND ".join(where_clause),
order,
order,
)
else:
sql = """
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE relates_to_id = ? AND %s
ORDER BY topological_ordering %s, stream_ordering %s
LIMIT ?
""" % (
" AND ".join(where_clause),
order,
order,
)

def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
txn.execute(sql, where_args + [limit + 1])
txn.execute(sql, [event.event_id] + where_args + [limit + 1])

events = []
topo_orderings: List[int] = []
Expand Down Expand Up @@ -965,7 +996,7 @@ async def get_thread_id(self, event_id: str) -> str:
# relation.
sql = """
WITH RECURSIVE related_events AS (
SELECT event_id, relates_to_id, relation_type, 0 depth
SELECT event_id, relates_to_id, relation_type, 0 AS depth
FROM event_relations
WHERE event_id = ?
UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
Expand Down Expand Up @@ -1025,7 +1056,7 @@ async def get_thread_id_for_receipts(self, event_id: str) -> str:
sql = """
SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
WITH RECURSIVE related_events AS (
SELECT event_id, relates_to_id, relation_type, 0 depth
SELECT event_id, relates_to_id, relation_type, 0 AS depth
FROM event_relations
WHERE event_id = ?
UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
Expand Down
120 changes: 120 additions & 0 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event
from tests.unittest import override_config


class BaseRelationsTestCase(unittest.HomeserverTestCase):
Expand Down Expand Up @@ -949,6 +950,125 @@ def test_pagination_from_sync_and_messages(self) -> None:
)


class RecursiveRelationTestCase(BaseRelationsTestCase):
@override_config({"experimental_features": {"msc3981_recurse_relations": True}})
def test_recursive_relations(self) -> None:
"""Generate a complex, multi-level relationship tree and query it."""
# Create a thread with a few messages in it.
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_1 = channel.json_body["event_id"]

channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_2 = channel.json_body["event_id"]

# Add annotations.
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2
)
annotation_1 = channel.json_body["event_id"]

channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=thread_1
)
annotation_2 = channel.json_body["event_id"]

# Add a reference to part of the thread, then edit the reference and annotate it.
channel = self._send_relation(
RelationTypes.REFERENCE, "m.room.test", parent_id=thread_2
)
reference_1 = channel.json_body["event_id"]

channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "c", parent_id=reference_1
)
annotation_3 = channel.json_body["event_id"]

channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.test",
parent_id=reference_1,
)
edit = channel.json_body["event_id"]

# Also more events off the root.
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "d")
annotation_4 = channel.json_body["event_id"]

channel = self.make_request(
"GET",
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}"
"?dir=f&limit=20&org.matrix.msc3981.recurse=true",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)

# The above events should be returned in creation order.
event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(
event_ids,
[
thread_1,
thread_2,
annotation_1,
annotation_2,
reference_1,
annotation_3,
edit,
annotation_4,
],
)

@override_config({"experimental_features": {"msc3981_recurse_relations": True}})
def test_recursive_relations_with_filter(self) -> None:
"""The event_type and rel_type still apply."""
# Create a thread with a few messages in it.
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_1 = channel.json_body["event_id"]

# Add annotations.
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=thread_1
)
annotation_1 = channel.json_body["event_id"]

# Add a reference to part of the thread, then edit the reference and annotate it.
channel = self._send_relation(
RelationTypes.REFERENCE, "m.room.test", parent_id=thread_1
)
reference_1 = channel.json_body["event_id"]

channel = self._send_relation(
RelationTypes.ANNOTATION, "org.matrix.reaction", "c", parent_id=reference_1
)
annotation_2 = channel.json_body["event_id"]

# Fetch only annotations, but recursively.
channel = self.make_request(
"GET",
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}"
"?dir=f&limit=20&org.matrix.msc3981.recurse=true",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)

# The above events should be returned in creation order.
event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(event_ids, [annotation_1, annotation_2])

# Fetch only m.reactions, but recursively.
channel = self.make_request(
"GET",
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}/m.reaction"
"?dir=f&limit=20&org.matrix.msc3981.recurse=true",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)

# The above events should be returned in creation order.
event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(event_ids, [annotation_1])


class BundledAggregationsTestCase(BaseRelationsTestCase):
"""
See RelationsTestCase.test_edit for a similar test for edits.
Expand Down