diff --git a/changelog.d/10707.misc b/changelog.d/10707.misc new file mode 100644 index 000000000000..39a37b90b1b3 --- /dev/null +++ b/changelog.d/10707.misc @@ -0,0 +1 @@ +Add missing type hints to REST servlets. diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index 7517e9304e8d..d1badbdf3bec 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -13,12 +13,19 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError, NotFoundError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -32,13 +39,15 @@ class AccountDataServlet(RestServlet): "/user/(?P[^/]*)/account_data/(?P[^/]*)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.handler = hs.get_account_data_handler() - async def on_PUT(self, request, user_id, account_data_type): + async def on_PUT( + self, request: SynapseRequest, user_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -49,7 +58,9 @@ async def on_PUT(self, request, user_id, account_data_type): return 200, {} - async def on_GET(self, request, user_id, account_data_type): + async def on_GET( + self, request: SynapseRequest, user_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") @@ -76,13 +87,19 @@ class RoomAccountDataServlet(RestServlet): "/account_data/(?P[^/]*)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.handler = hs.get_account_data_handler() - async def on_PUT(self, request, user_id, room_id, account_data_type): + async def on_PUT( + self, + request: SynapseRequest, + user_id: str, + room_id: str, + account_data_type: str, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -102,7 +119,13 @@ async def on_PUT(self, request, user_id, room_id, account_data_type): return 200, {} - async def on_GET(self, request, user_id, room_id, account_data_type): + async def on_GET( + self, + request: SynapseRequest, + user_id: str, + room_id: str, + account_data_type: str, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") @@ -117,6 +140,6 @@ async def on_GET(self, request, user_id, room_id, account_data_type): return 200, event -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: AccountDataServlet(hs).register(http_server) RoomAccountDataServlet(hs).register(http_server) diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py index c3667ff8aaad..0950f43f2f1b 100644 --- a/synapse/rest/client/groups.py +++ b/synapse/rest/client/groups.py @@ -156,7 +156,7 @@ async def on_PUT( group_id: str, category_id: Optional[str], room_id: str, - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -188,7 +188,7 @@ async def on_PUT( @_validate_group_id async def on_DELETE( self, request: SynapseRequest, group_id: str, category_id: str, room_id: str - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -451,7 +451,7 @@ async def on_PUT( @_validate_group_id async def on_DELETE( self, request: SynapseRequest, group_id: str, role_id: str, user_id: str - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -674,7 +674,7 @@ def __init__(self, hs: "HomeServer"): @_validate_group_id async def on_PUT( self, request: SynapseRequest, group_id: str, room_id: str, config_key: str - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -706,7 +706,7 @@ def __init__(self, hs: "HomeServer"): @_validate_group_id async def on_PUT( - self, request: SynapseRequest, group_id, user_id + self, request: SynapseRequest, group_id: str, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -738,7 +738,7 @@ def __init__(self, hs: "HomeServer"): @_validate_group_id async def on_PUT( - self, request: SynapseRequest, group_id, user_id + self, request: SynapseRequest, group_id: str, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index d9ab836cd892..9770413c618d 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -13,13 +13,20 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.constants import ReadReceiptEventFields from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -30,14 +37,16 @@ class ReceiptRestServlet(RestServlet): "/(?P[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() self.presence_handler = hs.get_presence_handler() - async def on_POST(self, request, room_id, receipt_type, event_id): + async def on_POST( + self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if receipt_type != "m.read": @@ -67,5 +76,5 @@ async def on_POST(self, request, room_id, receipt_type, event_id): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 7b5f49d635cb..a28acd4041c2 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -14,7 +14,9 @@ # limitations under the License. import logging import random -from typing import List, Union +from typing import TYPE_CHECKING, List, Optional, Tuple + +from twisted.web.server import Request import synapse import synapse.api.auth @@ -29,15 +31,13 @@ ) from synapse.api.ratelimiting import Ratelimiter from synapse.config import ConfigError -from synapse.config.captcha import CaptchaConfig -from synapse.config.consent import ConsentConfig from synapse.config.emailconfig import ThreepidBehaviour +from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.config.registration import RegistrationConfig from synapse.config.server import is_threepid_reserved from synapse.handlers.auth import AuthHandler from synapse.handlers.ui_auth import UIAuthSessionDataConstants -from synapse.http.server import finish_request, respond_with_html +from synapse.http.server import HttpServer, finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -45,6 +45,7 @@ parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer from synapse.types import JsonDict @@ -59,17 +60,16 @@ from ._base import client_patterns, interactive_auth_handler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class EmailRegisterRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/register/email/requestToken$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() @@ -83,7 +83,7 @@ def __init__(self, hs): template_text=self.config.email_registration_template_text, ) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -171,16 +171,12 @@ async def on_POST(self, request): class MsisdnRegisterRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/register/msisdn/requestToken$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) assert_params_in_dict( @@ -255,11 +251,7 @@ class RegistrationSubmitTokenServlet(RestServlet): "/registration/(?P[^/]*)/submit_token$", releases=(), unstable=True ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -272,7 +264,7 @@ def __init__(self, hs): self.config.email_registration_template_failure_html ) - async def on_GET(self, request, medium): + async def on_GET(self, request: Request, medium: str) -> None: if medium != "email": raise SynapseError( 400, "This medium is currently not supported for registration" @@ -326,11 +318,7 @@ async def on_GET(self, request, medium): class UsernameAvailabilityRestServlet(RestServlet): PATTERNS = client_patterns("/register/available") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.registration_handler = hs.get_registration_handler() @@ -350,7 +338,7 @@ def __init__(self, hs): ), ) - async def on_GET(self, request): + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if not self.hs.config.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN @@ -419,11 +407,7 @@ async def on_GET(self, request): class RegisterRestServlet(RestServlet): PATTERNS = client_patterns("/register$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs @@ -445,23 +429,21 @@ def __init__(self, hs): ) @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) client_addr = request.getClientIP() await self.ratelimiter.ratelimit(None, client_addr, update=False) - kind = b"user" - if b"kind" in request.args: - kind = request.args[b"kind"][0] + kind = parse_string(request, "kind", default="user") - if kind == b"guest": + if kind == "guest": ret = await self._do_guest_registration(body, address=client_addr) return ret - elif kind != b"user": + elif kind != "user": raise UnrecognizedRequestError( - "Do not understand membership kind: %s" % (kind.decode("utf8"),) + f"Do not understand membership kind: {kind}", ) if self._msc2918_enabled: @@ -749,7 +731,7 @@ async def on_POST(self, request): async def _do_appservice_registration( self, username, as_token, body, should_issue_refresh_token: bool = False - ): + ) -> JsonDict: user_id = await self.registration_handler.appservice_register( username, as_token ) @@ -766,7 +748,7 @@ async def _create_registration_details( params: JsonDict, is_appservice_ghost: bool = False, should_issue_refresh_token: bool = False, - ): + ) -> JsonDict: """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. @@ -810,7 +792,9 @@ async def _create_registration_details( return result - async def _do_guest_registration(self, params, address=None): + async def _do_guest_registration( + self, params: JsonDict, address: Optional[str] = None + ) -> Tuple[int, JsonDict]: if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id = await self.registration_handler.register_user( @@ -848,9 +832,7 @@ async def _do_guest_registration(self, params, address=None): def _calculate_registration_flows( - # technically `config` has to provide *all* of these interfaces, not just one - config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], - auth_handler: AuthHandler, + config: HomeServerConfig, auth_handler: AuthHandler ) -> List[List[str]]: """Get a suitable flows list for registration @@ -929,7 +911,7 @@ def _calculate_registration_flows( return flows -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EmailRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) UsernameAvailabilityRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 0821cd285fd3..0b0711c03c48 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -19,25 +19,32 @@ """ import logging +from typing import TYPE_CHECKING, Awaitable, Optional, Tuple from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import ShadowBanError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_integer, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.relations import ( AggregationPaginationToken, PaginationChunk, RelationPaginationToken, ) +from synapse.types import JsonDict from synapse.util.stringutils import random_string from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -59,13 +66,13 @@ class RelationSendServlet(RestServlet): "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.event_creation_handler = hs.get_event_creation_handler() self.txns = HttpTransactionCache(hs) - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: http_server.register_paths( "POST", client_patterns(self.PATTERN + "$", releases=()), @@ -79,14 +86,35 @@ def register(self, http_server): self.__class__.__name__, ) - def on_PUT(self, request, *args, **kwargs): + def on_PUT( + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: str, + event_type: str, + txn_id: Optional[str] = None, + ) -> Awaitable[Tuple[int, JsonDict]]: return self.txns.fetch_or_execute_request( - request, self.on_PUT_or_POST, request, *args, **kwargs + request, + self.on_PUT_or_POST, + request, + room_id, + parent_id, + relation_type, + event_type, + txn_id, ) async def on_PUT_or_POST( - self, request, room_id, parent_id, relation_type, event_type, txn_id=None - ): + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: str, + event_type: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) if event_type == EventTypes.Member: @@ -136,7 +164,7 @@ class RelationPaginationServlet(RestServlet): releases=(), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -145,8 +173,13 @@ def __init__(self, hs): self.event_handler = hs.get_event_handler() async def on_GET( - self, request, room_id, parent_id, relation_type=None, event_type=None - ): + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( @@ -156,6 +189,8 @@ async def on_GET( # This gets the original event and checks that a) the event exists and # b) the user is allowed to view it. event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") limit = parse_integer(request, "limit", default=5) from_token_str = parse_string(request, "from") @@ -233,15 +268,20 @@ class RelationAggregationPaginationServlet(RestServlet): releases=(), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_handler = hs.get_event_handler() async def on_GET( - self, request, room_id, parent_id, relation_type=None, event_type=None - ): + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( @@ -253,6 +293,8 @@ async def on_GET( # This checks that a) the event exists and b) the user is allowed to # view it. event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") if relation_type not in (RelationTypes.ANNOTATION, None): raise SynapseError(400, "Relation type must be 'annotation'") @@ -315,7 +357,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): releases=(), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -323,7 +365,15 @@ def __init__(self, hs): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + async def on_GET( + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: str, + event_type: str, + key: str, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( @@ -374,7 +424,7 @@ async def on_GET(self, request, room_id, parent_id, relation_type, event_type, k return 200, return_value -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationSendServlet(hs).register(http_server) RelationPaginationServlet(hs).register(http_server) RelationAggregationPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index c5c54564bed3..9b0c54650533 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -16,9 +16,11 @@ """ This module contains REST servlets to do with rooms: /rooms/ """ import logging import re -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple from urllib import parse as urlparse +from twisted.web.server import Request + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, @@ -30,6 +32,7 @@ ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 +from synapse.http.server import HttpServer from synapse.http.servlet import ( ResolveRoomIdMixin, RestServlet, @@ -57,7 +60,7 @@ class TransactionRestServlet(RestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.txns = HttpTransactionCache(hs) @@ -65,20 +68,22 @@ def __init__(self, hs): class RoomCreateRestServlet(TransactionRestServlet): # No PATTERN; we have custom dispatch rules here - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._room_creation_handler = hs.get_room_creation_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) - def on_PUT(self, request, txn_id): + def on_PUT( + self, request: SynapseRequest, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request(request, self.on_POST, request) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) info, _ = await self._room_creation_handler.create_room( @@ -87,21 +92,21 @@ async def on_POST(self, request): return 200, info - def get_room_config(self, request): + def get_room_config(self, request: Request) -> JsonDict: user_supplied_config = parse_json_object_from_request(request) return user_supplied_config # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /room/$roomid/state/$eventtype no_state_key = "/rooms/(?P[^/]*)/state/(?P[^/]*)$" @@ -136,13 +141,19 @@ def register(self, http_server): self.__class__.__name__, ) - def on_GET_no_state_key(self, request, room_id, event_type): + def on_GET_no_state_key( + self, request: SynapseRequest, room_id: str, event_type: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.on_GET(request, room_id, event_type, "") - def on_PUT_no_state_key(self, request, room_id, event_type): + def on_PUT_no_state_key( + self, request: SynapseRequest, room_id: str, event_type: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.on_PUT(request, room_id, event_type, "") - async def on_GET(self, request, room_id, event_type, state_key): + async def on_GET( + self, request: SynapseRequest, room_id: str, event_type: str, state_key: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) format = parse_string( request, "format", default="content", allowed_values=["content", "event"] @@ -165,7 +176,17 @@ async def on_GET(self, request, room_id, event_type, state_key): elif format == "content": return 200, data.get_dict()["content"] - async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + # Format must be event or content, per the parse_string call above. + raise RuntimeError(f"Unknown format: {format:r}.") + + async def on_PUT( + self, + request: SynapseRequest, + room_id: str, + event_type: str, + state_key: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if txn_id: @@ -211,27 +232,35 @@ async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/send/$event_type[/$txn_id] PATTERNS = "/rooms/(?P[^/]*)/send/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) - async def on_POST(self, request, room_id, event_type, txn_id=None): + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + event_type: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) - event_dict = { + event_dict: JsonDict = { "type": event_type, "content": content, "room_id": room_id, "sender": requester.user.to_string(), } + # Twisted will have processed the args by now. + assert request.args is not None if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) @@ -249,10 +278,14 @@ async def on_POST(self, request, room_id, event_type, txn_id=None): set_tag("event_id", event_id) return 200, {"event_id": event_id} - def on_GET(self, request, room_id, event_type, txn_id): + def on_GET( + self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str + ) -> Tuple[int, str]: return 200, "Not implemented" - def on_PUT(self, request, room_id, event_type, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -262,12 +295,12 @@ def on_PUT(self, request, room_id, event_type, txn_id): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /join/$room_identifier[/$txn_id] PATTERNS = "/join/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server) @@ -277,7 +310,7 @@ async def on_POST( request: SynapseRequest, room_identifier: str, txn_id: Optional[str] = None, - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) try: @@ -308,7 +341,9 @@ async def on_POST( return 200, {"room_id": room_id} - def on_PUT(self, request, room_identifier, txn_id): + def on_PUT( + self, request: SynapseRequest, room_identifier: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -320,12 +355,12 @@ def on_PUT(self, request, room_identifier, txn_id): class PublicRoomListRestServlet(TransactionRestServlet): PATTERNS = client_patterns("/publicRooms$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: server = parse_string(request, "server") try: @@ -374,7 +409,7 @@ async def on_GET(self, request): return 200, data - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) server = parse_string(request, "server") @@ -438,13 +473,15 @@ async def on_POST(self, request): class RoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/members$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: # TODO support Pagination stream API (limit/tokens) requester = await self.auth.get_user_by_req(request, allow_guest=True) handler = self.message_handler @@ -490,12 +527,14 @@ async def on_GET(self, request, room_id): class JoinedRoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/joined_members$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) users_with_profile = await self.message_handler.get_joined_members( @@ -509,17 +548,21 @@ async def on_GET(self, request, room_id): class RoomMessageListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/messages$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = await PaginationConfig.from_request( self.store, request, default_limit=10 ) + # Twisted will have processed the args by now. + assert request.args is not None as_client_event = b"raw" not in request.args filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: @@ -549,12 +592,14 @@ async def on_GET(self, request, room_id): class RoomStateRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/state$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, List[JsonDict]]: requester = await self.auth.get_user_by_req(request, allow_guest=True) # Get all the current state for this room events = await self.message_handler.get_state_events( @@ -569,13 +614,15 @@ async def on_GET(self, request, room_id): class RoomInitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/initialSync$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = await PaginationConfig.from_request(self.store, request) content = await self.initial_sync_handler.room_initial_sync( @@ -589,14 +636,16 @@ class RoomEventServlet(RestServlet): "/rooms/(?P[^/]*)/event/(?P[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - async def on_GET(self, request, room_id, event_id): + async def on_GET( + self, request: SynapseRequest, room_id: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) try: event = await self.event_handler.get_event( @@ -610,10 +659,10 @@ async def on_GET(self, request, room_id, event_id): time_now = self.clock.time_msec() if event: - event = await self._event_serializer.serialize_event(event, time_now) - return 200, event + event_dict = await self._event_serializer.serialize_event(event, time_now) + return 200, event_dict - return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) class RoomEventContextServlet(RestServlet): @@ -621,14 +670,16 @@ class RoomEventContextServlet(RestServlet): "/rooms/(?P[^/]*)/context/(?P[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - async def on_GET(self, request, room_id, event_id): + async def on_GET( + self, request: SynapseRequest, room_id: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) limit = parse_integer(request, "limit", default=10) @@ -669,23 +720,27 @@ async def on_GET(self, request, room_id, event_id): class RoomForgetRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) - async def on_POST(self, request, room_id, txn_id=None): + async def on_POST( + self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=False) await self.room_member_handler.forget(user=requester.user, room_id=room_id) return 200, {} - def on_PUT(self, request, room_id, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -695,12 +750,12 @@ def on_PUT(self, request, room_id, txn_id): # TODO: Needs unit testing class RoomMembershipRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/[invite|join|leave] PATTERNS = ( "/rooms/(?P[^/]*)/" @@ -708,7 +763,13 @@ def register(self, http_server): ) register_txn_path(self, PATTERNS, http_server) - async def on_POST(self, request, room_id, membership_action, txn_id=None): + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + membership_action: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { @@ -771,13 +832,15 @@ async def on_POST(self, request, room_id, membership_action, txn_id=None): return 200, return_value - def _has_3pid_invite_keys(self, content): + def _has_3pid_invite_keys(self, content: JsonDict) -> bool: for key in {"id_server", "medium", "address"}: if key not in content: return False return True - def on_PUT(self, request, room_id, membership_action, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -786,16 +849,22 @@ def on_PUT(self, request, room_id, membership_action, txn_id): class RoomRedactEventRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server) - async def on_POST(self, request, room_id, event_id, txn_id=None): + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + event_id: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) @@ -821,7 +890,9 @@ async def on_POST(self, request, room_id, event_id, txn_id=None): set_tag("event_id", event_id) return 200, {"event_id": event_id} - def on_PUT(self, request, room_id, event_id, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -846,7 +917,9 @@ def __init__(self, hs: "HomeServer"): hs.config.worker.writers.typing == hs.get_instance_name() ) - async def on_PUT(self, request, room_id, user_id): + async def on_PUT( + self, request: SynapseRequest, room_id: str, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if not self._is_typing_writer: @@ -897,7 +970,9 @@ def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.directory_handler = hs.get_directory_handler() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) alias_list = await self.directory_handler.get_aliases_for_room( @@ -910,12 +985,12 @@ async def on_GET(self, request, room_id): class SearchRestServlet(RestServlet): PATTERNS = client_patterns("/search$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.search_handler = hs.get_search_handler() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) @@ -929,19 +1004,24 @@ async def on_POST(self, request): class JoinedRoomsRestServlet(RestServlet): PATTERNS = client_patterns("/joined_rooms$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) room_ids = await self.store.get_rooms_for_user(requester.user.to_string()) return 200, {"joined_rooms": list(room_ids)} -def register_txn_path(servlet, regex_string, http_server, with_get=False): +def register_txn_path( + servlet: RestServlet, + regex_string: str, + http_server: HttpServer, + with_get: bool = False, +) -> None: """Registers a transaction-based path. This registers two paths: @@ -949,28 +1029,37 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): POST regex_string Args: - regex_string (str): The regex string to register. Must NOT have a - trailing $ as this string will be appended to. - http_server : The http_server to register paths with. + regex_string: The regex string to register. Must NOT have a + trailing $ as this string will be appended to. + http_server: The http_server to register paths with. with_get: True to also register respective GET paths for the PUTs. """ + on_POST = getattr(servlet, "on_POST", None) + on_PUT = getattr(servlet, "on_PUT", None) + if on_POST is None or on_PUT is None: + raise RuntimeError("on_POST and on_PUT must exist when using register_txn_path") http_server.register_paths( "POST", client_patterns(regex_string + "$", v1=True), - servlet.on_POST, + on_POST, servlet.__class__.__name__, ) http_server.register_paths( "PUT", client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_PUT, + on_PUT, servlet.__class__.__name__, ) + on_GET = getattr(servlet, "on_GET", None) if with_get: + if on_GET is None: + raise RuntimeError( + "register_txn_path called with with_get = True, but no on_GET method exists" + ) http_server.register_paths( "GET", client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_GET, + on_GET, servlet.__class__.__name__, ) @@ -1120,7 +1209,9 @@ async def on_GET( ) -def register_servlets(hs: "HomeServer", http_server, is_worker=False): +def register_servlets( + hs: "HomeServer", http_server: HttpServer, is_worker: bool = False +) -> None: RoomStateEventRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server) @@ -1148,5 +1239,5 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False): RoomForgetRestServlet(hs).register(http_server) -def register_deprecated_servlets(hs, http_server): +def register_deprecated_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomInitialSyncRestServlet(hs).register(http_server)