Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Protect module callbacks with read semantics against cancellation
Browse files Browse the repository at this point in the history
The `on_*` callbacks have been left alone, since they are presumed to
be run on code paths that aren't cancellation-friendly.

Signed-off-by: Sean Quah <seanq@element.io>
  • Loading branch information
Sean Quah committed Apr 13, 2022
1 parent b15f04f commit 66ac827
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 27 deletions.
6 changes: 3 additions & 3 deletions synapse/events/presence_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from twisted.internet.defer import CancelledError

from synapse.api.presence import UserPresenceState
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -149,7 +149,7 @@ async def get_users_for_states(
# run all the callbacks for get_users_for_states and combine the results
for callback in self._get_users_for_states_callbacks:
try:
result = await callback(state_updates)
result = await delay_cancellation(callback(state_updates))
except CancelledError:
raise
except Exception as e:
Expand Down Expand Up @@ -203,7 +203,7 @@ async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
# run all the callbacks for get_interested_users and combine the results
for callback in self._get_interested_users_callbacks:
try:
result = await callback(user_id)
result = await delay_cancellation(callback(user_id))
except CancelledError:
raise
except Exception as e:
Expand Down
38 changes: 27 additions & 11 deletions synapse/events/spamcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
Union,
)

from twisted.internet import defer

from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserProfile
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable

if TYPE_CHECKING:
import synapse.events
Expand Down Expand Up @@ -255,7 +257,7 @@ async def check_event_for_spam(
will be used as the error message returned to the user.
"""
for callback in self._check_event_for_spam_callbacks:
res: Union[bool, str] = await callback(event)
res: Union[bool, str] = await delay_cancellation(callback(event))
if res:
return res

Expand All @@ -276,7 +278,10 @@ async def user_may_join_room(
Whether the user may join the room
"""
for callback in self._user_may_join_room_callbacks:
if await callback(user_id, room_id, is_invited) is False:
may_join_room = await delay_cancellation(
callback(user_id, room_id, is_invited)
)
if may_join_room is False:
return False

return True
Expand All @@ -297,7 +302,10 @@ async def user_may_invite(
True if the user may send an invite, otherwise False
"""
for callback in self._user_may_invite_callbacks:
if await callback(inviter_userid, invitee_userid, room_id) is False:
may_invite = await delay_cancellation(
callback(inviter_userid, invitee_userid, room_id)
)
if may_invite is False:
return False

return True
Expand All @@ -322,7 +330,10 @@ async def user_may_send_3pid_invite(
True if the user may send the invite, otherwise False
"""
for callback in self._user_may_send_3pid_invite_callbacks:
if await callback(inviter_userid, medium, address, room_id) is False:
may_send_3pid_invite = await delay_cancellation(
callback(inviter_userid, medium, address, room_id)
)
if may_send_3pid_invite is False:
return False

return True
Expand All @@ -339,7 +350,8 @@ async def user_may_create_room(self, userid: str) -> bool:
True if the user may create a room, otherwise False
"""
for callback in self._user_may_create_room_callbacks:
if await callback(userid) is False:
may_create_room = await delay_cancellation(callback(userid))
if may_create_room is False:
return False

return True
Expand All @@ -359,7 +371,10 @@ async def user_may_create_room_alias(
True if the user may create a room alias, otherwise False
"""
for callback in self._user_may_create_room_alias_callbacks:
if await callback(userid, room_alias) is False:
may_create_room_alias = await delay_cancellation(
callback(userid, room_alias)
)
if may_create_room_alias is False:
return False

return True
Expand All @@ -377,7 +392,8 @@ async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
True if the user may publish the room, otherwise False
"""
for callback in self._user_may_publish_room_callbacks:
if await callback(userid, room_id) is False:
may_publish_room = await delay_cancellation(callback(userid, room_id))
if may_publish_room is False:
return False

return True
Expand All @@ -400,7 +416,7 @@ async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
for callback in self._check_username_for_spam_callbacks:
# Make a copy of the user profile object to ensure the spam checker cannot
# modify it.
if await callback(user_profile.copy()):
if await delay_cancellation(callback(user_profile.copy())):
return True

return False
Expand Down Expand Up @@ -428,7 +444,7 @@ async def check_registration_for_spam(
"""

for callback in self._check_registration_for_spam_callbacks:
behaviour = await (
behaviour = await delay_cancellation(
callback(email_threepid, username, request_info, auth_provider_id)
)
assert isinstance(behaviour, RegistrationBehaviour)
Expand Down Expand Up @@ -472,7 +488,7 @@ async def check_media_file_for_spam(
"""

for callback in self._check_media_file_for_spam_callbacks:
spam = await callback(file_wrapper, file_info)
spam = await delay_cancellation(callback(file_wrapper, file_info))
if spam:
return True

Expand Down
24 changes: 18 additions & 6 deletions synapse/events/third_party_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from synapse.events.snapshot import EventContext
from synapse.storage.roommember import ProfileInfo
from synapse.types import Requester, StateMap
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -265,7 +265,9 @@ async def check_event_allowed(

for callback in self._check_event_allowed_callbacks:
try:
res, replacement_data = await callback(event, state_events)
res, replacement_data = await delay_cancellation(
callback(event, state_events)
)
except CancelledError:
raise
except SynapseError as e:
Expand Down Expand Up @@ -337,7 +339,10 @@ async def check_threepid_can_be_invited(

for callback in self._check_threepid_can_be_invited_callbacks:
try:
if await callback(medium, address, state_events) is False:
threepid_can_be_invited = await delay_cancellation(
callback(medium, address, state_events)
)
if threepid_can_be_invited is False:
return False
except CancelledError:
raise
Expand Down Expand Up @@ -367,7 +372,10 @@ async def check_visibility_can_be_modified(

for callback in self._check_visibility_can_be_modified_callbacks:
try:
if await callback(room_id, state_events, new_visibility) is False:
visibility_can_be_modified = await delay_cancellation(
callback(room_id, state_events, new_visibility)
)
if visibility_can_be_modified is False:
return False
except CancelledError:
raise
Expand Down Expand Up @@ -408,7 +416,8 @@ async def check_can_shutdown_room(self, user_id: str, room_id: str) -> bool:
"""
for callback in self._check_can_shutdown_room_callbacks:
try:
if await callback(user_id, room_id) is False:
can_shutdown_room = await delay_cancellation(callback(user_id, room_id))
if can_shutdown_room is False:
return False
except CancelledError:
raise
Expand All @@ -432,7 +441,10 @@ async def check_can_deactivate_user(
"""
for callback in self._check_can_deactivate_user_callbacks:
try:
if await callback(user_id, by_admin) is False:
can_deactivate_user = await delay_cancellation(
callback(user_id, by_admin)
)
if can_deactivate_user is False:
return False
except CancelledError:
raise
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/account_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
from synapse.util.async_helpers import delay_cancellation

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -150,7 +151,7 @@ async def is_user_expired(self, user_id: str) -> bool:
Whether the user has expired.
"""
for callback in self._is_user_expired_callbacks:
expired = await callback(user_id)
expired = await delay_cancellation(callback(user_id))
if expired is not None:
return expired

Expand Down
15 changes: 9 additions & 6 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import unpaddedbase64
from pymacaroons.exceptions import MacaroonVerificationFailedException

from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.web.server import Request

Expand Down Expand Up @@ -68,7 +69,7 @@
from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import base62_encode
Expand Down Expand Up @@ -2203,7 +2204,9 @@ async def check_auth(
# other than None (i.e. until a callback returns a success)
for callback in self.auth_checker_callbacks[login_type]:
try:
result = await callback(username, login_type, login_dict)
result = await delay_cancellation(
callback(username, login_type, login_dict)
)
except CancelledError:
raise
except Exception as e:
Expand Down Expand Up @@ -2266,7 +2269,7 @@ async def check_3pid_auth(

for callback in self.check_3pid_auth_callbacks:
try:
result = await callback(medium, address, password)
result = await delay_cancellation(callback(medium, address, password))
except CancelledError:
raise
except Exception as e:
Expand Down Expand Up @@ -2350,7 +2353,7 @@ async def get_username_for_registration(
"""
for callback in self.get_username_for_registration_callbacks:
try:
res = await callback(uia_results, params)
res = await delay_cancellation(callback(uia_results, params))

if isinstance(res, str):
return res
Expand Down Expand Up @@ -2395,7 +2398,7 @@ async def get_displayname_for_registration(
"""
for callback in self.get_displayname_for_registration_callbacks:
try:
res = await callback(uia_results, params)
res = await delay_cancellation(callback(uia_results, params))

if isinstance(res, str):
return res
Expand Down Expand Up @@ -2438,7 +2441,7 @@ async def is_3pid_allowed(
"""
for callback in self.is_3pid_allowed_callbacks:
try:
res = await callback(medium, address, registration)
res = await delay_cancellation(callback(medium, address, registration))

if res is False:
return res
Expand Down

0 comments on commit 66ac827

Please sign in to comment.