Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Prefer make_awaitable over defer.succeed in tests #12505

Merged
merged 9 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12505.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `make_awaitable` instead of `defer.succeed` for return values of mocks in tests.
26 changes: 17 additions & 9 deletions synapse/logging/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,11 @@ def nested_logging_context(suffix: str) -> LoggingContext:
R = TypeVar("R")


async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R:
"""Unwraps an arbitrary awaitable by awaiting it."""
return await awaitable


@overload
def preserve_fn( # type: ignore[misc]
f: Callable[P, Awaitable[R]],
Expand Down Expand Up @@ -802,17 +807,20 @@ def run_in_background( # type: ignore[misc]
# by synchronous exceptions, so let's turn them into Failures.
return defer.fail()

# `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain
# value. Convert it to a `Deferred`.
if isinstance(res, typing.Coroutine):
# Wrap the coroutine in a `Deferred`.
res = defer.ensureDeferred(res)

# At this point we should have a Deferred, if not then f was a synchronous
# function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred):
# `res` is not a `Deferred` and not a `Coroutine`.
# There are no other types of `Awaitable`s we expect to encounter in Synapse.
assert not isinstance(res, Awaitable)

return defer.succeed(res)
elif isinstance(res, defer.Deferred):
pass
elif isinstance(res, Awaitable):
# `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
# or `Future` from `make_awaitable`.
res = defer.ensureDeferred(_unwrap_awaitable(res))
else:
# `res` is a plain value. Wrap it in a `Deferred`.
res = defer.succeed(res)

if res.called and not res.paused:
# The function should have maintained the logcontext, so we can
Expand Down
2 changes: 1 addition & 1 deletion tests/federation/test_federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_get_room_state(self):
)

# mock up the response, and have the agent return it
self._mock_agent.request.return_value = defer.succeed(
self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
clokep marked this conversation as resolved.
Show resolved Hide resolved
_mock_response(
{
"pdus": [
Expand Down
2 changes: 1 addition & 1 deletion tests/federation/test_federation_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def test_dont_send_device_updates_for_remote_users(self):
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
self.hs.get_federation_transport_client().query_user_devices.return_value = (
defer.succeed(
make_awaitable(
{
"stream_id": "1",
"user_id": "@user2:host2",
Expand Down
7 changes: 3 additions & 4 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from parameterized import parameterized
from signedjson import key as key, sign as sign

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import RoomEncryptionAlgorithms
Expand Down Expand Up @@ -704,7 +703,7 @@ def test_query_devices_remote_no_sync(self) -> None:
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"

self.hs.get_federation_client().query_client_keys = mock.Mock(
return_value=defer.succeed(
return_value=make_awaitable(
{
"device_keys": {remote_user_id: {}},
"master_keys": {
Expand Down Expand Up @@ -777,14 +776,14 @@ def test_query_devices_remote_sync(self) -> None:
# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"})
return_value=make_awaitable({"some_room_id"})
)

remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"

self.hs.get_federation_client().query_user_devices = mock.Mock(
return_value=defer.succeed(
return_value=make_awaitable(
{
"user_id": remote_user_id,
"stream_id": 1,
Expand Down
34 changes: 16 additions & 18 deletions tests/handlers/test_password_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from typing import Any, Type, Union
from unittest.mock import Mock

from twisted.internet import defer

import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
Expand Down Expand Up @@ -190,7 +188,7 @@ def password_only_auth_provider_login_test_body(self):
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)

# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
Expand Down Expand Up @@ -226,13 +224,13 @@ def password_only_auth_provider_ui_auth_test_body(self):
self.get_success(module_api.register_user("u"))

# log in twice, to get two devices
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)
tok1 = self.login("u", "p")
self.login("u", "p", device_id="dev2")
mock_password_provider.reset_mock()

# have the auth provider deny the request to start with
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)

# make the initial request which returns a 401
session = self._start_delete_device_session(tok1, "dev2")
Expand All @@ -246,7 +244,7 @@ def password_only_auth_provider_ui_auth_test_body(self):
mock_password_provider.reset_mock()

# Finally, check the request goes through when we allow it
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
Expand All @@ -260,7 +258,7 @@ def local_user_fallback_login_test_body(self):
self.register_user("localuser", "localpass")

# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 403, channel.result)

Expand All @@ -277,7 +275,7 @@ def local_user_fallback_ui_auth_test_body(self):
self.register_user("localuser", "localpass")

# have the auth provider deny the request
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)

# log in twice, to get two devices
tok1 = self.login("localuser", "localpass")
Expand Down Expand Up @@ -320,7 +318,7 @@ def no_local_user_fallback_login_test_body(self):
self.register_user("localuser", "localpass")

# check_password must return an awaitable
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
Expand All @@ -342,7 +340,7 @@ def no_local_user_fallback_ui_auth_test_body(self):
self.register_user("localuser", "localpass")

# allow login via the auth provider
mock_password_provider.check_password.return_value = defer.succeed(True)
mock_password_provider.check_password.return_value = make_awaitable(True)

# log in twice, to get two devices
tok1 = self.login("localuser", "p")
Expand All @@ -359,7 +357,7 @@ def no_local_user_fallback_ui_auth_test_body(self):
mock_password_provider.check_password.assert_not_called()

# now try deleting with the local password
mock_password_provider.check_password.return_value = defer.succeed(False)
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
Expand Down Expand Up @@ -413,7 +411,7 @@ def custom_auth_provider_login_test_body(self):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()

mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
Expand All @@ -427,7 +425,7 @@ def custom_auth_provider_login_test_body(self):
# try a weird username. Again, it's unclear what we *expect* to happen
# in these cases, but at least we can guard against the API changing
# unexpectedly
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
Expand Down Expand Up @@ -477,7 +475,7 @@ def custom_auth_provider_ui_auth_test_body(self):
mock_password_provider.reset_mock()

# right params, but authing as the wrong user
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
body["auth"]["test_field"] = "foo"
Expand All @@ -490,7 +488,7 @@ def custom_auth_provider_ui_auth_test_body(self):
mock_password_provider.reset_mock()

# and finally, succeed
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
)
channel = self._delete_device(tok1, "dev2", body)
Expand All @@ -508,9 +506,9 @@ def test_custom_auth_provider_callback(self):
self.custom_auth_provider_callback_test_body()

def custom_auth_provider_callback_test_body(self):
callback = Mock(return_value=defer.succeed(None))
callback = Mock(return_value=make_awaitable(None))

mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
Expand Down Expand Up @@ -646,7 +644,7 @@ def password_custom_auth_password_disabled_ui_auth_test_body(self):
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = defer.succeed(
mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
Expand Down
6 changes: 3 additions & 3 deletions tests/handlers/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
mock_keyring.verify_json_for_server.return_value = defer.succeed(True)
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)

# we mock out the federation client too
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))

# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
Expand Down Expand Up @@ -98,7 +98,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:

self.datastore = hs.get_datastores().main
self.datastore.get_destination_retry_timings = Mock(
return_value=defer.succeed(None)
return_value=make_awaitable(None)
)

self.datastore.get_device_updates_by_remote = Mock(
Expand Down
6 changes: 3 additions & 3 deletions tests/handlers/test_user_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from unittest.mock import Mock, patch
from urllib.parse import quote

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

import synapse.rest.admin
Expand All @@ -30,6 +29,7 @@

from tests import unittest
from tests.storage.test_user_directory import GetUserDirectoryTables
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config

Expand Down Expand Up @@ -439,7 +439,7 @@ def test_handle_user_deactivated_support_user(self) -> None:
)
)

mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):
Expand All @@ -454,7 +454,7 @@ def test_handle_user_deactivated_regular_user(self) -> None:
self.store.register_user(user_id=r_user_id, password_hash=None)
)

mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):
Expand Down
4 changes: 2 additions & 2 deletions tests/rest/client/test_presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from http import HTTPStatus
from unittest.mock import Mock

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

from synapse.handlers.presence import PresenceHandler
Expand All @@ -24,6 +23,7 @@
from synapse.util import Clock

from tests import unittest
from tests.test_utils import make_awaitable


class PresenceTestCase(unittest.HomeserverTestCase):
Expand All @@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:

presence_handler = Mock(spec=PresenceHandler)
presence_handler.set_state.return_value = defer.succeed(None)
presence_handler.set_state.return_value = make_awaitable(None)

hs = self.setup_test_homeserver(
"red",
Expand Down
7 changes: 2 additions & 5 deletions tests/rest/client/test_rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from unittest.mock import Mock, call
from urllib import parse as urlparse

from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor

import synapse.rest.admin
Expand Down Expand Up @@ -1426,9 +1425,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:

def test_simple(self) -> None:
"Simple test for searching rooms over federation"
self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined]
{}
)
self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]

search_filter = {"generic_search_term": "foobar"}

Expand Down Expand Up @@ -1456,7 +1453,7 @@ def test_fallback(self) -> None:
# with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
HttpResponseException(404, "Not Found", b""),
defer.succeed({}),
make_awaitable({}),
)

search_filter = {"generic_search_term": "foobar"}
Expand Down
7 changes: 4 additions & 3 deletions tests/rest/client/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from synapse.util import Clock

from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import MockClock


Expand All @@ -38,7 +39,7 @@ def setUp(self) -> None:

@defer.inlineCallbacks
def test_executes_given_function(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response))
cb = Mock(return_value=make_awaitable(self.mock_http_response))
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg"
)
Expand All @@ -47,7 +48,7 @@ def test_executes_given_function(self):

@defer.inlineCallbacks
def test_deduplicates_based_on_key(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response))
cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
Expand Down Expand Up @@ -130,7 +131,7 @@ def cb():

@defer.inlineCallbacks
def test_cleans_up(self):
cb = Mock(return_value=defer.succeed(self.mock_http_response))
cb = Mock(return_value=make_awaitable(self.mock_http_response))
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
# should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
Expand Down
Loading