From 32e7c9e7f20b57dd081023ac42d6931a8da9b3a3 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Thu, 20 Jun 2019 19:32:02 +1000 Subject: [PATCH] Run Black. (#5482) --- .buildkite/pipeline.yml | 4 +- changelog.d/5482.misc | 1 + contrib/cmdclient/console.py | 309 +++--- contrib/cmdclient/http.py | 38 +- contrib/experiments/cursesio.py | 36 +- contrib/experiments/test_messaging.py | 81 +- contrib/graph/graph.py | 33 +- contrib/graph/graph2.py | 51 +- contrib/graph/graph3.py | 46 +- contrib/jitsimeetbridge/jitsimeetbridge.py | 257 ++--- contrib/scripts/kick_users.py | 39 +- demo/webserver.py | 25 +- docker/start.py | 58 +- docs/sphinx/conf.py | 158 ++-- pyproject.toml | 19 + scripts-dev/check_auth.py | 4 +- scripts-dev/check_event_hash.py | 2 +- scripts-dev/check_signature.py | 5 +- scripts-dev/convert_server_keys.py | 2 +- scripts-dev/definitions.py | 62 +- scripts-dev/federation_client.py | 32 +- scripts-dev/hash_history.py | 2 +- scripts-dev/list_url_patterns.py | 4 +- scripts-dev/tail-synapse.py | 2 +- scripts/generate_signing_key.py | 8 +- scripts/move_remote_media_to_new_store.py | 4 +- setup.cfg | 7 +- setup.py | 31 +- synapse/__init__.py | 1 + synapse/_scripts/register_new_matrix_user.py | 24 +- synapse/api/auth.py | 159 ++-- synapse/api/constants.py | 46 +- synapse/api/errors.py | 112 ++- synapse/api/filtering.py | 189 ++-- synapse/api/ratelimiting.py | 20 +- synapse/api/room_versions.py | 25 +- synapse/api/urls.py | 17 +- synapse/app/__init__.py | 4 +- synapse/app/_base.py | 77 +- synapse/app/appservice.py | 29 +- synapse/app/client_reader.py | 29 +- synapse/app/event_creator.py | 39 +- synapse/app/federation_reader.py | 38 +- synapse/app/federation_sender.py | 46 +- synapse/app/frontend_proxy.py | 69 +- synapse/app/homeserver.py | 163 ++-- synapse/app/media_repository.py | 37 +- synapse/app/pusher.py | 60 +- synapse/app/synchrotron.py | 90 +- synapse/app/user_dir.py | 45 +- synapse/appservice/__init__.py | 39 +- synapse/appservice/api.py | 66 +- synapse/appservice/scheduler.py | 50 +- synapse/config/_base.py | 12 +- synapse/config/api.py | 22 +- synapse/config/appservice.py | 54 +- synapse/config/captcha.py | 1 - synapse/config/consent_config.py | 27 +- synapse/config/database.py | 31 +- synapse/config/emailconfig.py | 71 +- synapse/config/jwt_config.py | 5 +- synapse/config/key.py | 5 +- synapse/config/logger.py | 70 +- synapse/config/metrics.py | 8 +- synapse/config/password_auth_providers.py | 16 +- synapse/config/registration.py | 41 +- synapse/config/repository.py | 74 +- synapse/config/room_directory.py | 19 +- synapse/config/saml2_config.py | 18 +- synapse/config/server.py | 226 +++-- synapse/config/server_notices_config.py | 17 +- synapse/config/tls.py | 52 +- synapse/config/user_directory.py | 8 +- synapse/config/voip.py | 3 +- synapse/config/workers.py | 16 +- synapse/crypto/event_signing.py | 17 +- synapse/crypto/keyring.py | 33 +- synapse/event_auth.py | 177 ++-- synapse/events/__init__.py | 30 +- synapse/events/builder.py | 37 +- synapse/events/snapshot.py | 49 +- synapse/events/spamcheck.py | 4 +- synapse/events/third_party_rules.py | 5 +- synapse/events/utils.py | 58 +- synapse/events/validator.py | 21 +- synapse/federation/federation_base.py | 78 +- synapse/federation/federation_client.py | 217 +++-- synapse/federation/federation_server.py | 234 ++--- synapse/federation/persistence.py | 15 +- synapse/federation/send_queue.py | 145 ++- synapse/federation/sender/__init__.py | 67 +- .../sender/per_destination_queue.py | 7 +- .../federation/sender/transaction_manager.py | 38 +- synapse/federation/transport/client.py | 294 +++--- synapse/federation/transport/server.py | 237 +++-- synapse/federation/units.py | 33 +- synapse/groups/attestations.py | 35 +- synapse/groups/groups_server.py | 317 +++---- synapse/handlers/_base.py | 10 +- synapse/handlers/account_data.py | 23 +- synapse/handlers/account_validity.py | 61 +- synapse/handlers/acme.py | 14 +- synapse/handlers/admin.py | 23 +- synapse/handlers/appservice.py | 74 +- synapse/handlers/auth.py | 266 +++--- synapse/handlers/deactivate_account.py | 12 +- synapse/handlers/device.py | 112 +-- synapse/handlers/devicemessage.py | 7 +- synapse/handlers/directory.py | 136 +-- synapse/handlers/e2e_keys.py | 119 ++- synapse/handlers/e2e_room_keys.py | 38 +- synapse/handlers/events.py | 49 +- synapse/handlers/federation.py | 884 +++++++----------- synapse/handlers/groups_local.py | 95 +- synapse/handlers/identity.py | 132 +-- synapse/handlers/initial_sync.py | 170 ++-- synapse/handlers/message.py | 258 ++--- synapse/handlers/pagination.py | 68 +- synapse/handlers/presence.py | 311 +++--- synapse/handlers/profile.py | 65 +- synapse/handlers/read_marker.py | 9 +- synapse/handlers/receipts.py | 26 +- synapse/handlers/register.py | 224 ++--- synapse/handlers/room.py | 295 +++--- synapse/handlers/room_list.py | 254 +++-- synapse/handlers/room_member.py | 304 +++--- synapse/handlers/room_member_worker.py | 8 +- synapse/handlers/search.py | 129 ++- synapse/handlers/set_password.py | 5 +- synapse/handlers/state_deltas.py | 1 - synapse/handlers/stats.py | 2 +- synapse/handlers/sync.py | 694 +++++++------- synapse/handlers/typing.py | 74 +- synapse/http/__init__.py | 9 +- synapse/http/additional_resource.py | 1 + synapse/http/client.py | 28 +- synapse/http/endpoint.py | 20 +- .../federation/matrix_federation_agent.py | 98 +- synapse/http/federation/srv_resolver.py | 37 +- synapse/http/matrixfederationclient.py | 242 ++--- synapse/http/server.py | 80 +- synapse/http/servlet.py | 45 +- synapse/http/site.py | 48 +- synapse/metrics/__init__.py | 7 +- synapse/metrics/background_process_metrics.py | 33 +- synapse/module_api/__init__.py | 6 +- synapse/notifier.py | 70 +- synapse/push/action_generator.py | 4 +- synapse/push/baserules.py | 404 ++++---- synapse/push/bulk_push_rule_evaluator.py | 63 +- synapse/push/clientformat.py | 38 +- synapse/push/emailpusher.py | 63 +- synapse/push/httppusher.py | 226 +++-- synapse/push/mailer.py | 347 ++++--- synapse/push/presentable_names.py | 45 +- synapse/push/push_rule_evaluator.py | 57 +- synapse/push/push_tools.py | 8 +- synapse/push/pusher.py | 10 +- synapse/push/pusherpool.py | 97 +- synapse/push/rulekinds.py | 10 +- synapse/python_dependencies.py | 24 +- synapse/replication/http/_base.py | 22 +- synapse/replication/http/federation.py | 51 +- synapse/replication/http/login.py | 7 +- synapse/replication/http/membership.py | 34 +- synapse/replication/http/register.py | 17 +- synapse/replication/http/send_event.py | 13 +- synapse/replication/slave/storage/_base.py | 2 +- .../replication/slave/storage/account_data.py | 17 +- .../replication/slave/storage/appservice.py | 5 +- .../replication/slave/storage/client_ips.py | 4 +- .../replication/slave/storage/deviceinbox.py | 6 +- synapse/replication/slave/storage/devices.py | 14 +- synapse/replication/slave/storage/events.py | 77 +- synapse/replication/slave/storage/groups.py | 9 +- synapse/replication/slave/storage/presence.py | 8 +- .../replication/slave/storage/push_rule.py | 6 +- synapse/replication/slave/storage/pushers.py | 4 +- synapse/replication/slave/storage/receipts.py | 1 - synapse/replication/slave/storage/room.py | 4 +- synapse/replication/tcp/client.py | 6 +- synapse/replication/tcp/commands.py | 71 +- synapse/replication/tcp/protocol.py | 76 +- synapse/replication/tcp/resource.py | 54 +- synapse/replication/tcp/streams/_base.py | 160 ++-- synapse/replication/tcp/streams/events.py | 32 +- synapse/replication/tcp/streams/federation.py | 12 +- synapse/rest/__init__.py | 1 + synapse/rest/admin/__init__.py | 153 ++- synapse/rest/admin/server_notice_servlet.py | 13 +- synapse/rest/client/transactions.py | 3 +- synapse/rest/client/v1/directory.py | 31 +- synapse/rest/client/v1/events.py | 7 +- synapse/rest/client/v1/login.py | 121 ++- synapse/rest/client/v1/logout.py | 3 +- synapse/rest/client/v1/presence.py | 2 +- synapse/rest/client/v1/profile.py | 6 +- synapse/rest/client/v1/push_rule.py | 119 +-- synapse/rest/client/v1/pusher.py | 70 +- synapse/rest/client/v1/room.py | 185 ++-- synapse/rest/client/v1/voip.py | 24 +- synapse/rest/client/v2_alpha/_base.py | 12 +- synapse/rest/client/v2_alpha/account.py | 190 ++-- synapse/rest/client/v2_alpha/account_data.py | 16 +- .../rest/client/v2_alpha/account_validity.py | 16 +- synapse/rest/client/v2_alpha/auth.py | 62 +- synapse/rest/client/v2_alpha/devices.py | 19 +- synapse/rest/client/v2_alpha/filter.py | 11 +- synapse/rest/client/v2_alpha/groups.py | 157 ++-- synapse/rest/client/v2_alpha/keys.py | 27 +- synapse/rest/client/v2_alpha/notifications.py | 28 +- synapse/rest/client/v2_alpha/openid.py | 22 +- synapse/rest/client/v2_alpha/read_marker.py | 4 +- synapse/rest/client/v2_alpha/receipts.py | 5 +- synapse/rest/client/v2_alpha/register.py | 132 ++- synapse/rest/client/v2_alpha/relations.py | 5 +- synapse/rest/client/v2_alpha/report_event.py | 4 +- synapse/rest/client/v2_alpha/room_keys.py | 54 +- .../v2_alpha/room_upgrade_rest_servlet.py | 9 +- synapse/rest/client/v2_alpha/sendtodevice.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 129 +-- synapse/rest/client/v2_alpha/tags.py | 14 +- synapse/rest/client/v2_alpha/thirdparty.py | 2 +- synapse/rest/client/v2_alpha/tokenrefresh.py | 1 + .../rest/client/v2_alpha/user_directory.py | 7 +- synapse/rest/client/versions.py | 43 +- synapse/rest/consent/consent_resource.py | 28 +- synapse/rest/key/v2/local_key_resource.py | 28 +- synapse/rest/key/v2/remote_key_resource.py | 56 +- synapse/rest/media/v0/content_repository.py | 15 +- synapse/rest/media/v1/_base.py | 57 +- synapse/rest/media/v1/config_resource.py | 4 +- synapse/rest/media/v1/download_resource.py | 8 +- synapse/rest/media/v1/filepath.py | 131 ++- synapse/rest/media/v1/media_repository.py | 215 +++-- synapse/rest/media/v1/media_storage.py | 17 +- synapse/rest/media/v1/preview_url_resource.py | 215 ++--- synapse/rest/media/v1/storage_provider.py | 6 +- synapse/rest/media/v1/thumbnail_resource.py | 120 ++- synapse/rest/media/v1/thumbnailer.py | 15 +- synapse/rest/media/v1/upload_resource.py | 30 +- synapse/rest/saml2/metadata_resource.py | 2 +- synapse/rest/saml2/response_resource.py | 13 +- synapse/rest/well_known.py | 11 +- synapse/secrets.py | 3 +- synapse/server.py | 156 ++-- synapse/server.pyi | 51 +- .../server_notices/consent_server_notices.py | 24 +- .../resource_limits_server_notices.py | 33 +- .../server_notices/server_notices_manager.py | 30 +- .../server_notices/server_notices_sender.py | 11 +- .../worker_server_notices_sender.py | 1 + synapse/state/__init__.py | 107 +-- synapse/state/v1.py | 56 +- synapse/state/v2.py | 92 +- synapse/storage/__init__.py | 10 +- synapse/storage/_base.py | 6 +- synapse/storage/background_updates.py | 2 +- synapse/storage/devices.py | 12 +- synapse/storage/e2e_room_keys.py | 28 +- synapse/storage/engines/sqlite.py | 2 +- synapse/storage/event_federation.py | 3 +- synapse/storage/event_push_actions.py | 4 +- synapse/storage/events.py | 2 +- synapse/storage/events_bg_updates.py | 16 +- synapse/storage/events_worker.py | 10 +- synapse/storage/group_server.py | 8 +- synapse/storage/keys.py | 2 +- synapse/storage/media_repository.py | 22 +- synapse/storage/monthly_active_users.py | 6 +- synapse/storage/prepare_database.py | 13 +- synapse/storage/profile.py | 2 +- synapse/storage/push_rule.py | 30 +- synapse/storage/pusher.py | 36 +- synapse/storage/receipts.py | 2 +- synapse/storage/registration.py | 107 +-- synapse/storage/relations.py | 6 +- synapse/storage/roommember.py | 8 +- synapse/storage/schema/delta/20/pushers.py | 22 +- synapse/storage/schema/delta/30/as_users.py | 17 +- synapse/storage/schema/delta/31/pushers.py | 24 +- .../schema/delta/33/remote_media_ts.py | 2 +- .../schema/delta/47/state_group_seq.py | 5 +- .../schema/delta/48/group_unique_indexes.py | 14 +- .../delta/50/make_event_content_nullable.py | 12 +- synapse/storage/search.py | 6 +- synapse/storage/stats.py | 34 +- synapse/storage/stream.py | 32 +- synapse/streams/config.py | 27 +- synapse/streams/events.py | 43 +- synapse/types.py | 108 ++- synapse/util/__init__.py | 23 +- synapse/util/async_helpers.py | 41 +- synapse/util/caches/__init__.py | 6 +- synapse/util/caches/descriptors.py | 100 +- synapse/util/caches/dictionary_cache.py | 15 +- synapse/util/caches/expiringcache.py | 18 +- synapse/util/caches/lrucache.py | 14 +- synapse/util/caches/response_cache.py | 22 +- synapse/util/caches/stream_change_cache.py | 13 +- synapse/util/caches/treecache.py | 1 + synapse/util/caches/ttlcache.py | 1 + synapse/util/distributor.py | 29 +- synapse/util/frozenutils.py | 9 +- synapse/util/httpresourcetree.py | 8 +- synapse/util/jsonobject.py | 6 +- synapse/util/logcontext.py | 83 +- synapse/util/logformatter.py | 3 +- synapse/util/logutils.py | 51 +- synapse/util/manhole.py | 18 +- synapse/util/metrics.py | 32 +- synapse/util/module_loader.py | 6 +- synapse/util/msisdn.py | 6 +- synapse/util/ratelimitutils.py | 35 +- synapse/util/stringutils.py | 16 +- synapse/util/threepids.py | 10 +- synapse/util/versionstring.py | 61 +- synapse/util/wheel_timer.py | 4 +- synapse/visibility.py | 69 +- tests/api/test_auth.py | 12 +- tests/config/test_server.py | 8 +- tests/config/test_tls.py | 2 +- tests/crypto/test_event_signing.py | 46 +- tests/events/test_utils.py | 94 +- tests/federation/test_complexity.py | 2 +- tests/federation/test_federation_sender.py | 48 +- tests/handlers/test_auth.py | 10 +- tests/handlers/test_directory.py | 8 +- tests/handlers/test_e2e_room_keys.py | 18 +- tests/handlers/test_register.py | 31 +- tests/handlers/test_stats.py | 5 +- tests/handlers/test_typing.py | 16 +- tests/handlers/test_user_directory.py | 8 +- .../test_matrix_federation_agent.py | 216 ++--- tests/http/federation/test_srv_resolver.py | 2 +- tests/http/test_endpoint.py | 12 +- tests/http/test_fedclient.py | 19 +- tests/push/test_email.py | 6 +- tests/rest/admin/test_admin.py | 100 +- tests/rest/client/test_consent.py | 10 +- tests/rest/client/test_identity.py | 2 +- tests/rest/client/v1/test_profile.py | 7 +- tests/rest/client/v1/test_rooms.py | 46 +- tests/rest/client/v1/utils.py | 8 +- tests/rest/client/v2_alpha/test_account.py | 31 +- .../rest/client/v2_alpha/test_capabilities.py | 14 +- tests/rest/client/v2_alpha/test_register.py | 35 +- tests/rest/client/v2_alpha/test_relations.py | 10 +- tests/rest/media/v1/test_base.py | 14 +- tests/rest/media/v1/test_media_storage.py | 6 +- tests/rest/media/v1/test_url_preview.py | 56 +- tests/server.py | 16 +- .../test_resource_limits_server_notices.py | 4 +- tests/state/test_v2.py | 20 +- tests/storage/test_appservice.py | 4 +- tests/storage/test_client_ips.py | 20 +- tests/storage/test_devices.py | 16 +- tests/storage/test_end_to_end_keys.py | 8 +- tests/storage/test_event_federation.py | 14 +- tests/storage/test_event_metrics.py | 36 +- tests/storage/test_monthly_active_users.py | 22 +- tests/storage/test_redaction.py | 4 +- tests/storage/test_registration.py | 2 +- tests/storage/test_room.py | 4 +- tests/storage/test_state.py | 16 +- tests/test_preview.py | 136 +-- tests/test_server.py | 14 +- tests/test_state.py | 6 +- tests/test_types.py | 4 +- tests/test_utils/logging_setup.py | 2 +- tests/test_visibility.py | 2 +- tests/unittest.py | 18 +- tests/util/caches/test_descriptors.py | 60 +- tests/util/caches/test_ttlcache.py | 46 +- tests/utils.py | 18 +- tox.ini | 9 +- 376 files changed, 9153 insertions(+), 10399 deletions(-) create mode 100644 changelog.d/5482.misc diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 20c7aab5a740..513eb3bde9cc 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -5,8 +5,8 @@ steps: - command: - "python -m pip install tox" - - "tox -e pep8" - label: "\U0001F9F9 PEP-8" + - "tox -e check_codestyle" + label: "\U0001F9F9 Check Style" plugins: - docker#v3.0.1: image: "python:3.6" diff --git a/changelog.d/5482.misc b/changelog.d/5482.misc new file mode 100644 index 000000000000..0332d1133bdc --- /dev/null +++ b/changelog.d/5482.misc @@ -0,0 +1 @@ +Synapse's codebase is now formatted by `black`. diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 462f1461131a..af8f39c8c279 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -37,9 +37,8 @@ CONFIG_JSON = "cmdclient_config.json" -TRUSTED_ID_SERVERS = [ - 'localhost:8001' -] +TRUSTED_ID_SERVERS = ["localhost:8001"] + class SynapseCmd(cmd.Cmd): @@ -59,7 +58,7 @@ def __init__(self, http_client, server_url, identity_server_url, username, token "token": token, "verbose": "on", "complete_usernames": "on", - "send_delivery_receipts": "on" + "send_delivery_receipts": "on", } self.path_prefix = "/_matrix/client/api/v1" self.event_stream_token = "END" @@ -120,12 +119,11 @@ def do_config(self, line): config_rules = [ # key, valid_values ("verbose", ["on", "off"]), ("complete_usernames", ["on", "off"]), - ("send_delivery_receipts", ["on", "off"]) + ("send_delivery_receipts", ["on", "off"]), ] for key, valid_vals in config_rules: if key == args["key"] and args["val"] not in valid_vals: - print("%s value must be one of %s" % (args["key"], - valid_vals)) + print("%s value must be one of %s" % (args["key"], valid_vals)) return # toggle the http client verbosity @@ -159,16 +157,13 @@ def do_register(self, line): else: password = pwd - body = { - "type": "m.login.password" - } + body = {"type": "m.login.password"} if "userid" in args: body["user"] = args["userid"] if password: body["password"] = password - reactor.callFromThread(self._do_register, body, - "noupdate" not in args) + reactor.callFromThread(self._do_register, body, "noupdate" not in args) @defer.inlineCallbacks def _do_register(self, data, update_config): @@ -179,7 +174,9 @@ def _do_register(self, data, update_config): passwordFlow = None for flow in json_res["flows"]: - if flow["type"] == "m.login.recaptcha" or ("stages" in flow and "m.login.recaptcha" in flow["stages"]): + if flow["type"] == "m.login.recaptcha" or ( + "stages" in flow and "m.login.recaptcha" in flow["stages"] + ): print("Unable to register: Home server requires captcha.") return if flow["type"] == "m.login.password" and "stages" not in flow: @@ -202,9 +199,7 @@ def do_login(self, line): """ try: args = self._parse(line, ["user_id"], force_keys=True) - can_login = threads.blockingCallFromThread( - reactor, - self._check_can_login) + can_login = threads.blockingCallFromThread(reactor, self._check_can_login) if can_login: p = getpass.getpass("Enter your password: ") user = args["user_id"] @@ -212,20 +207,16 @@ def do_login(self, line): domain = self._domain() if domain: user = "@" + user + ":" + domain - + reactor.callFromThread(self._do_login, user, p) - #print " got %s " % p + # print " got %s " % p except Exception as e: print(e) @defer.inlineCallbacks def _do_login(self, user, password): path = "/login" - data = { - "user": user, - "password": password, - "type": "m.login.password" - } + data = {"user": user, "password": password, "type": "m.login.password"} url = self._url() + path json_res = yield self.http_client.do_request("POST", url, data=data) print(json_res) @@ -249,12 +240,13 @@ def _check_can_login(self): print("Failed to find any login flows.") defer.returnValue(False) - flow = json_res["flows"][0] # assume first is the one we want. - if ("type" not in flow or "m.login.password" != flow["type"] or - "stages" in flow): + flow = json_res["flows"][0] # assume first is the one we want. + if "type" not in flow or "m.login.password" != flow["type"] or "stages" in flow: fallback_url = self._url() + "/login/fallback" - print ("Unable to login via the command line client. Please visit " - "%s to login." % fallback_url) + print( + "Unable to login via the command line client. Please visit " + "%s to login." % fallback_url + ) defer.returnValue(False) defer.returnValue(True) @@ -264,21 +256,33 @@ def do_emailrequest(self, line): A string of characters generated when requesting an email that you'll supply in subsequent calls to identify yourself The number of times the user has requested an email. Leave this the same between requests to retry the request at the transport level. Increment it to request that the email be sent again. """ - args = self._parse(line, ['address', 'clientSecret', 'sendAttempt']) + args = self._parse(line, ["address", "clientSecret", "sendAttempt"]) - postArgs = {'email': args['address'], 'clientSecret': args['clientSecret'], 'sendAttempt': args['sendAttempt']} + postArgs = { + "email": args["address"], + "clientSecret": args["clientSecret"], + "sendAttempt": args["sendAttempt"], + } reactor.callFromThread(self._do_emailrequest, postArgs) @defer.inlineCallbacks def _do_emailrequest(self, args): - url = self._identityServerUrl()+"/_matrix/identity/api/v1/validate/email/requestToken" - - json_res = yield self.http_client.do_request("POST", url, data=urllib.urlencode(args), jsonreq=False, - headers={'Content-Type': ['application/x-www-form-urlencoded']}) + url = ( + self._identityServerUrl() + + "/_matrix/identity/api/v1/validate/email/requestToken" + ) + + json_res = yield self.http_client.do_request( + "POST", + url, + data=urllib.urlencode(args), + jsonreq=False, + headers={"Content-Type": ["application/x-www-form-urlencoded"]}, + ) print(json_res) - if 'sid' in json_res: - print("Token sent. Your session ID is %s" % (json_res['sid'])) + if "sid" in json_res: + print("Token sent. Your session ID is %s" % (json_res["sid"])) def do_emailvalidate(self, line): """Validate and associate a third party ID @@ -286,18 +290,30 @@ def do_emailvalidate(self, line): The token sent to your third party identifier address The same clientSecret you supplied in requestToken """ - args = self._parse(line, ['sid', 'token', 'clientSecret']) + args = self._parse(line, ["sid", "token", "clientSecret"]) - postArgs = { 'sid' : args['sid'], 'token' : args['token'], 'clientSecret': args['clientSecret'] } + postArgs = { + "sid": args["sid"], + "token": args["token"], + "clientSecret": args["clientSecret"], + } reactor.callFromThread(self._do_emailvalidate, postArgs) @defer.inlineCallbacks def _do_emailvalidate(self, args): - url = self._identityServerUrl()+"/_matrix/identity/api/v1/validate/email/submitToken" - - json_res = yield self.http_client.do_request("POST", url, data=urllib.urlencode(args), jsonreq=False, - headers={'Content-Type': ['application/x-www-form-urlencoded']}) + url = ( + self._identityServerUrl() + + "/_matrix/identity/api/v1/validate/email/submitToken" + ) + + json_res = yield self.http_client.do_request( + "POST", + url, + data=urllib.urlencode(args), + jsonreq=False, + headers={"Content-Type": ["application/x-www-form-urlencoded"]}, + ) print(json_res) def do_3pidbind(self, line): @@ -305,19 +321,24 @@ def do_3pidbind(self, line): The session ID (sid) given to you in the response to requestToken The same clientSecret you supplied in requestToken """ - args = self._parse(line, ['sid', 'clientSecret']) + args = self._parse(line, ["sid", "clientSecret"]) - postArgs = { 'sid' : args['sid'], 'clientSecret': args['clientSecret'] } - postArgs['mxid'] = self.config["user"] + postArgs = {"sid": args["sid"], "clientSecret": args["clientSecret"]} + postArgs["mxid"] = self.config["user"] reactor.callFromThread(self._do_3pidbind, postArgs) @defer.inlineCallbacks def _do_3pidbind(self, args): - url = self._identityServerUrl()+"/_matrix/identity/api/v1/3pid/bind" - - json_res = yield self.http_client.do_request("POST", url, data=urllib.urlencode(args), jsonreq=False, - headers={'Content-Type': ['application/x-www-form-urlencoded']}) + url = self._identityServerUrl() + "/_matrix/identity/api/v1/3pid/bind" + + json_res = yield self.http_client.do_request( + "POST", + url, + data=urllib.urlencode(args), + jsonreq=False, + headers={"Content-Type": ["application/x-www-form-urlencoded"]}, + ) print(json_res) def do_join(self, line): @@ -356,9 +377,7 @@ def do_topic(self, line): if "topic" not in args: print("Must specify a new topic.") return - body = { - "topic": args["topic"] - } + body = {"topic": args["topic"]} reactor.callFromThread(self._run_and_pprint, "PUT", path, body) elif args["action"].lower() == "get": reactor.callFromThread(self._run_and_pprint, "GET", path) @@ -378,45 +397,60 @@ def do_invite(self, line): @defer.inlineCallbacks def _do_invite(self, roomid, userstring): - if (not userstring.startswith('@') and - self._is_on("complete_usernames")): - url = self._identityServerUrl()+"/_matrix/identity/api/v1/lookup" + if not userstring.startswith("@") and self._is_on("complete_usernames"): + url = self._identityServerUrl() + "/_matrix/identity/api/v1/lookup" - json_res = yield self.http_client.do_request("GET", url, qparams={'medium':'email','address':userstring}) + json_res = yield self.http_client.do_request( + "GET", url, qparams={"medium": "email", "address": userstring} + ) mxid = None - if 'mxid' in json_res and 'signatures' in json_res: - url = self._identityServerUrl()+"/_matrix/identity/api/v1/pubkey/ed25519" + if "mxid" in json_res and "signatures" in json_res: + url = ( + self._identityServerUrl() + + "/_matrix/identity/api/v1/pubkey/ed25519" + ) pubKey = None pubKeyObj = yield self.http_client.do_request("GET", url) - if 'public_key' in pubKeyObj: - pubKey = nacl.signing.VerifyKey(pubKeyObj['public_key'], encoder=nacl.encoding.HexEncoder) + if "public_key" in pubKeyObj: + pubKey = nacl.signing.VerifyKey( + pubKeyObj["public_key"], encoder=nacl.encoding.HexEncoder + ) else: print("No public key found in pubkey response!") sigValid = False if pubKey: - for signame in json_res['signatures']: + for signame in json_res["signatures"]: if signame not in TRUSTED_ID_SERVERS: - print("Ignoring signature from untrusted server %s" % (signame)) + print( + "Ignoring signature from untrusted server %s" + % (signame) + ) else: try: verify_signed_json(json_res, signame, pubKey) sigValid = True - print("Mapping %s -> %s correctly signed by %s" % (userstring, json_res['mxid'], signame)) + print( + "Mapping %s -> %s correctly signed by %s" + % (userstring, json_res["mxid"], signame) + ) break except SignatureVerifyException as e: print("Invalid signature from %s" % (signame)) print(e) if sigValid: - print("Resolved 3pid %s to %s" % (userstring, json_res['mxid'])) - mxid = json_res['mxid'] + print("Resolved 3pid %s to %s" % (userstring, json_res["mxid"])) + mxid = json_res["mxid"] else: - print("Got association for %s but couldn't verify signature" % (userstring)) + print( + "Got association for %s but couldn't verify signature" + % (userstring) + ) if not mxid: mxid = "@" + userstring + ":" + self._domain() @@ -435,12 +469,11 @@ def do_send(self, line): """Sends a message. "send " """ args = self._parse(line, ["roomid", "body"]) txn_id = "txn%s" % int(time.time()) - path = "/rooms/%s/send/m.room.message/%s" % (urllib.quote(args["roomid"]), - txn_id) - body_json = { - "msgtype": "m.text", - "body": args["body"] - } + path = "/rooms/%s/send/m.room.message/%s" % ( + urllib.quote(args["roomid"]), + txn_id, + ) + body_json = {"msgtype": "m.text", "body": args["body"]} reactor.callFromThread(self._run_and_pprint, "PUT", path, body_json) def do_list(self, line): @@ -472,8 +505,7 @@ def do_list(self, line): print("Bad query param: %s" % key_value) return - reactor.callFromThread(self._run_and_pprint, "GET", path, - query_params=qp) + reactor.callFromThread(self._run_and_pprint, "GET", path, query_params=qp) def do_create(self, line): """Creates a room. @@ -513,8 +545,16 @@ def do_raw(self, line): return args["method"] = args["method"].upper() - valid_methods = ["PUT", "GET", "POST", "DELETE", - "XPUT", "XGET", "XPOST", "XDELETE"] + valid_methods = [ + "PUT", + "GET", + "POST", + "DELETE", + "XPUT", + "XGET", + "XPOST", + "XDELETE", + ] if args["method"] not in valid_methods: print("Unsupported method: %s" % args["method"]) return @@ -541,10 +581,13 @@ def do_raw(self, line): except: pass - reactor.callFromThread(self._run_and_pprint, args["method"], - args["path"], - args["data"], - query_params=qp) + reactor.callFromThread( + self._run_and_pprint, + args["method"], + args["path"], + args["data"], + query_params=qp, + ) def do_stream(self, line): """Stream data from the server: "stream " """ @@ -561,19 +604,22 @@ def do_stream(self, line): @defer.inlineCallbacks def _do_event_stream(self, timeout): res = yield self.http_client.get_json( - self._url() + "/events", - { - "access_token": self._tok(), - "timeout": str(timeout), - "from": self.event_stream_token - }) + self._url() + "/events", + { + "access_token": self._tok(), + "timeout": str(timeout), + "from": self.event_stream_token, + }, + ) print(json.dumps(res, indent=4)) if "chunk" in res: for event in res["chunk"]: - if (event["type"] == "m.room.message" and - self._is_on("send_delivery_receipts") and - event["user_id"] != self._usr()): # not sent by us + if ( + event["type"] == "m.room.message" + and self._is_on("send_delivery_receipts") + and event["user_id"] != self._usr() + ): # not sent by us self._send_receipt(event, "d") # update the position in the stram @@ -581,18 +627,28 @@ def _do_event_stream(self, timeout): self.event_stream_token = res["end"] def _send_receipt(self, event, feedback_type): - path = ("/rooms/%s/messages/%s/%s/feedback/%s/%s" % - (urllib.quote(event["room_id"]), event["user_id"], event["msg_id"], - self._usr(), feedback_type)) + path = "/rooms/%s/messages/%s/%s/feedback/%s/%s" % ( + urllib.quote(event["room_id"]), + event["user_id"], + event["msg_id"], + self._usr(), + feedback_type, + ) data = {} - reactor.callFromThread(self._run_and_pprint, "PUT", path, data=data, - alt_text="Sent receipt for %s" % event["msg_id"]) + reactor.callFromThread( + self._run_and_pprint, + "PUT", + path, + data=data, + alt_text="Sent receipt for %s" % event["msg_id"], + ) def _do_membership_change(self, roomid, membership, userid): - path = "/rooms/%s/state/m.room.member/%s" % (urllib.quote(roomid), urllib.quote(userid)) - data = { - "membership": membership - } + path = "/rooms/%s/state/m.room.member/%s" % ( + urllib.quote(roomid), + urllib.quote(userid), + ) + data = {"membership": membership} reactor.callFromThread(self._run_and_pprint, "PUT", path, data=data) def do_displayname(self, line): @@ -645,15 +701,20 @@ def _parse(self, line, keys, force_keys=False): for i, arg in enumerate(line_args): for config_key in self.config: if ("$" + config_key) in arg: - arg = arg.replace("$" + config_key, - self.config[config_key]) + arg = arg.replace("$" + config_key, self.config[config_key]) line_args[i] = arg return dict(zip(keys, line_args)) @defer.inlineCallbacks - def _run_and_pprint(self, method, path, data=None, - query_params={"access_token": None}, alt_text=None): + def _run_and_pprint( + self, + method, + path, + data=None, + query_params={"access_token": None}, + alt_text=None, + ): """ Runs an HTTP request and pretty prints the output. Args: @@ -666,9 +727,9 @@ def _run_and_pprint(self, method, path, data=None, if "access_token" in query_params: query_params["access_token"] = self._tok() - json_res = yield self.http_client.do_request(method, url, - data=data, - qparams=query_params) + json_res = yield self.http_client.do_request( + method, url, data=data, qparams=query_params + ) if alt_text: print(alt_text) else: @@ -676,7 +737,7 @@ def _run_and_pprint(self, method, path, data=None, def save_config(config): - with open(CONFIG_JSON, 'w') as out: + with open(CONFIG_JSON, "w") as out: json.dump(config, out) @@ -700,7 +761,7 @@ def main(server_url, identity_server_url, username, token, config_path): global CONFIG_JSON CONFIG_JSON = config_path # bit cheeky, but just overwrite the global try: - with open(config_path, 'r') as config: + with open(config_path, "r") as config: syn_cmd.config = json.load(config) try: http_client.verbose = "on" == syn_cmd.config["verbose"] @@ -717,23 +778,33 @@ def main(server_url, identity_server_url, username, token, config_path): reactor.run() -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser("Starts a synapse client.") parser.add_argument( - "-s", "--server", dest="server", default="http://localhost:8008", - help="The URL of the home server to talk to.") - parser.add_argument( - "-i", "--identity-server", dest="identityserver", default="http://localhost:8090", - help="The URL of the identity server to talk to.") + "-s", + "--server", + dest="server", + default="http://localhost:8008", + help="The URL of the home server to talk to.", + ) parser.add_argument( - "-u", "--username", dest="username", - help="Your username on the server.") + "-i", + "--identity-server", + dest="identityserver", + default="http://localhost:8090", + help="The URL of the identity server to talk to.", + ) parser.add_argument( - "-t", "--token", dest="token", - help="Your access token.") + "-u", "--username", dest="username", help="Your username on the server." + ) + parser.add_argument("-t", "--token", dest="token", help="Your access token.") parser.add_argument( - "-c", "--config", dest="config", default=CONFIG_JSON, - help="The location of the config.json file to read from.") + "-c", + "--config", + dest="config", + default=CONFIG_JSON, + help="The location of the config.json file to read from.", + ) args = parser.parse_args() if not args.server: diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index 1bd600e148be..0e101d2be56b 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -73,9 +73,7 @@ def __init__(self): @defer.inlineCallbacks def put_json(self, url, data): response = yield self._create_put_request( - url, - data, - headers_dict={"Content-Type": ["application/json"]} + url, data, headers_dict={"Content-Type": ["application/json"]} ) body = yield readBody(response) defer.returnValue((response.code, body)) @@ -95,40 +93,34 @@ def _create_put_request(self, url, json_data, headers_dict={}): """ if "Content-Type" not in headers_dict: - raise defer.error( - RuntimeError("Must include Content-Type header for PUTs")) + raise defer.error(RuntimeError("Must include Content-Type header for PUTs")) return self._create_request( - "PUT", - url, - producer=_JsonProducer(json_data), - headers_dict=headers_dict + "PUT", url, producer=_JsonProducer(json_data), headers_dict=headers_dict ) def _create_get_request(self, url, headers_dict={}): """ Wrapper of _create_request to issue a GET request """ - return self._create_request( - "GET", - url, - headers_dict=headers_dict - ) + return self._create_request("GET", url, headers_dict=headers_dict) @defer.inlineCallbacks - def do_request(self, method, url, data=None, qparams=None, jsonreq=True, headers={}): + def do_request( + self, method, url, data=None, qparams=None, jsonreq=True, headers={} + ): if qparams: url = "%s?%s" % (url, urllib.urlencode(qparams, True)) if jsonreq: prod = _JsonProducer(data) - headers['Content-Type'] = ["application/json"]; + headers["Content-Type"] = ["application/json"] else: prod = _RawProducer(data) if method in ["POST", "PUT"]: - response = yield self._create_request(method, url, - producer=prod, - headers_dict=headers) + response = yield self._create_request( + method, url, producer=prod, headers_dict=headers + ) else: response = yield self._create_request(method, url) @@ -155,10 +147,7 @@ def _create_request(self, method, url, producer=None, headers_dict={}): while True: try: response = yield self.agent.request( - method, - url.encode("UTF8"), - Headers(headers_dict), - producer + method, url.encode("UTF8"), Headers(headers_dict), producer ) break except Exception as e: @@ -179,6 +168,7 @@ def sleep(self, seconds): reactor.callLater(seconds, d.callback, seconds) return d + class _RawProducer(object): def __init__(self, data): self.data = data @@ -195,9 +185,11 @@ def pauseProducing(self): def stopProducing(self): pass + class _JsonProducer(object): """ Used by the twisted http client to create the HTTP body from json """ + def __init__(self, jsn): self.data = jsn self.body = json.dumps(jsn).encode("utf8") diff --git a/contrib/experiments/cursesio.py b/contrib/experiments/cursesio.py index 44afe81008a5..ffefe3bb3928 100644 --- a/contrib/experiments/cursesio.py +++ b/contrib/experiments/cursesio.py @@ -19,13 +19,13 @@ from twisted.internet import reactor -class CursesStdIO(): +class CursesStdIO: def __init__(self, stdscr, callback=None): self.statusText = "Synapse test app -" - self.searchText = '' + self.searchText = "" self.stdscr = stdscr - self.logLine = '' + self.logLine = "" self.callback = callback @@ -71,8 +71,7 @@ def redraw(self): i = 0 index = len(self.lines) - 1 while i < (self.rows - 3) and index >= 0: - self.stdscr.addstr(self.rows - 3 - i, 0, self.lines[index], - curses.A_NORMAL) + self.stdscr.addstr(self.rows - 3 - i, 0, self.lines[index], curses.A_NORMAL) i = i + 1 index = index - 1 @@ -85,15 +84,13 @@ def paintStatus(self, text): raise RuntimeError("TextTooLongError") self.stdscr.addstr( - self.rows - 2, 0, - text + ' ' * (self.cols - len(text)), - curses.A_STANDOUT) + self.rows - 2, 0, text + " " * (self.cols - len(text)), curses.A_STANDOUT + ) def printLogLine(self, text): self.stdscr.addstr( - 0, 0, - text + ' ' * (self.cols - len(text)), - curses.A_STANDOUT) + 0, 0, text + " " * (self.cols - len(text)), curses.A_STANDOUT + ) def doRead(self): """ Input is ready! """ @@ -105,7 +102,7 @@ def doRead(self): elif c == curses.KEY_ENTER or c == 10: text = self.searchText - self.searchText = '' + self.searchText = "" self.print_line(">> %s" % text) @@ -122,11 +119,13 @@ def doRead(self): return self.searchText = self.searchText + chr(c) - self.stdscr.addstr(self.rows - 1, 0, - self.searchText + (' ' * ( - self.cols - len(self.searchText) - 2))) + self.stdscr.addstr( + self.rows - 1, + 0, + self.searchText + (" " * (self.cols - len(self.searchText) - 2)), + ) - self.paintStatus(self.statusText + ' %d' % len(self.searchText)) + self.paintStatus(self.statusText + " %d" % len(self.searchText)) self.stdscr.move(self.rows - 1, len(self.searchText)) self.stdscr.refresh() @@ -143,7 +142,6 @@ def close(self): class Callback(object): - def __init__(self, stdio): self.stdio = stdio @@ -152,7 +150,7 @@ def on_line(self, text): def main(stdscr): - screen = CursesStdIO(stdscr) # create Screen object + screen = CursesStdIO(stdscr) # create Screen object callback = Callback(screen) @@ -164,5 +162,5 @@ def main(stdscr): screen.close() -if __name__ == '__main__': +if __name__ == "__main__": curses.wrapper(main) diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index 85c9c1198415..c7e55d8aa7fd 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -28,9 +28,7 @@ """ -from synapse.federation import ( - ReplicationHandler -) +from synapse.federation import ReplicationHandler from synapse.federation.units import Pdu @@ -38,7 +36,7 @@ from synapse.app.homeserver import SynapseHomeServer -#from synapse.util.logutils import log_function +# from synapse.util.logutils import log_function from twisted.internet import reactor, defer from twisted.python import log @@ -83,7 +81,7 @@ def on_line(self, line): room_name, = m.groups() self.print_line("%s joining %s" % (self.user, room_name)) self.server.join_room(room_name, self.user, self.user) - #self.print_line("OK.") + # self.print_line("OK.") return m = re.match("^invite (\S+) (\S+)$", line) @@ -92,7 +90,7 @@ def on_line(self, line): room_name, invitee = m.groups() self.print_line("%s invited to %s" % (invitee, room_name)) self.server.invite_to_room(room_name, self.user, invitee) - #self.print_line("OK.") + # self.print_line("OK.") return m = re.match("^send (\S+) (.*)$", line) @@ -101,7 +99,7 @@ def on_line(self, line): room_name, body = m.groups() self.print_line("%s send to %s" % (self.user, room_name)) self.server.send_message(room_name, self.user, body) - #self.print_line("OK.") + # self.print_line("OK.") return m = re.match("^backfill (\S+)$", line) @@ -125,7 +123,6 @@ def print_log(self, text): class IOLoggerHandler(logging.Handler): - def __init__(self, io): logging.Handler.__init__(self) self.io = io @@ -142,6 +139,7 @@ class Room(object): """ Used to store (in memory) the current membership state of a room, and which home servers we should send PDUs associated with the room to. """ + def __init__(self, room_name): self.room_name = room_name self.invited = set() @@ -175,6 +173,7 @@ class HomeServer(ReplicationHandler): """ A very basic home server implentation that allows people to join a room and then invite other people. """ + def __init__(self, server_name, replication_layer, output): self.server_name = server_name self.replication_layer = replication_layer @@ -197,26 +196,27 @@ def on_receive_pdu(self, pdu): elif pdu.content["membership"] == "invite": self._on_invite(pdu.origin, pdu.context, pdu.state_key) else: - self.output.print_line("#%s (unrec) %s = %s" % - (pdu.context, pdu.pdu_type, json.dumps(pdu.content)) + self.output.print_line( + "#%s (unrec) %s = %s" + % (pdu.context, pdu.pdu_type, json.dumps(pdu.content)) ) - #def on_state_change(self, pdu): - ##self.output.print_line("#%s (state) %s *** %s" % - ##(pdu.context, pdu.state_key, pdu.pdu_type) - ##) + # def on_state_change(self, pdu): + ##self.output.print_line("#%s (state) %s *** %s" % + ##(pdu.context, pdu.state_key, pdu.pdu_type) + ##) - #if "joinee" in pdu.content: - #self._on_join(pdu.context, pdu.content["joinee"]) - #elif "invitee" in pdu.content: - #self._on_invite(pdu.origin, pdu.context, pdu.content["invitee"]) + # if "joinee" in pdu.content: + # self._on_join(pdu.context, pdu.content["joinee"]) + # elif "invitee" in pdu.content: + # self._on_invite(pdu.origin, pdu.context, pdu.content["invitee"]) def _on_message(self, pdu): """ We received a message """ - self.output.print_line("#%s %s %s" % - (pdu.context, pdu.content["sender"], pdu.content["body"]) - ) + self.output.print_line( + "#%s %s %s" % (pdu.context, pdu.content["sender"], pdu.content["body"]) + ) def _on_join(self, context, joinee): """ Someone has joined a room, either a remote user or a local user @@ -224,9 +224,7 @@ def _on_join(self, context, joinee): room = self._get_or_create_room(context) room.add_participant(joinee) - self.output.print_line("#%s %s %s" % - (context, joinee, "*** JOINED") - ) + self.output.print_line("#%s %s %s" % (context, joinee, "*** JOINED")) def _on_invite(self, origin, context, invitee): """ Someone has been invited @@ -234,9 +232,7 @@ def _on_invite(self, origin, context, invitee): room = self._get_or_create_room(context) room.add_invited(invitee) - self.output.print_line("#%s %s %s" % - (context, invitee, "*** INVITED") - ) + self.output.print_line("#%s %s %s" % (context, invitee, "*** INVITED")) if not room.have_got_metadata and origin is not self.server_name: logger.debug("Get room state") @@ -272,14 +268,14 @@ def join_room(self, room_name, sender, joinee): try: pdu = Pdu.create_new( - context=room_name, - pdu_type="sy.room.member", - is_state=True, - state_key=joinee, - content={"membership": "join"}, - origin=self.server_name, - destinations=destinations, - ) + context=room_name, + pdu_type="sy.room.member", + is_state=True, + state_key=joinee, + content={"membership": "join"}, + origin=self.server_name, + destinations=destinations, + ) yield self.replication_layer.send_pdu(pdu) except Exception as e: logger.exception(e) @@ -318,21 +314,21 @@ def backfill(self, room_name, limit=5): return self.replication_layer.backfill(dest, room_name, limit) def _get_room_remote_servers(self, room_name): - return [i for i in self.joined_rooms.setdefault(room_name,).servers] + return [i for i in self.joined_rooms.setdefault(room_name).servers] def _get_or_create_room(self, room_name): return self.joined_rooms.setdefault(room_name, Room(room_name)) def get_servers_for_context(self, context): return defer.succeed( - self.joined_rooms.setdefault(context, Room(context)).servers - ) + self.joined_rooms.setdefault(context, Room(context)).servers + ) def main(stdscr): parser = argparse.ArgumentParser() - parser.add_argument('user', type=str) - parser.add_argument('-v', '--verbose', action='count') + parser.add_argument("user", type=str) + parser.add_argument("-v", "--verbose", action="count") args = parser.parse_args() user = args.user @@ -342,8 +338,9 @@ def main(stdscr): root_logger = logging.getLogger() - formatter = logging.Formatter('%(asctime)s - %(name)s - %(lineno)d - ' - '%(levelname)s - %(message)s') + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(lineno)d - " "%(levelname)s - %(message)s" + ) if not os.path.exists("logs"): os.makedirs("logs") fh = logging.FileHandler("logs/%s" % user) diff --git a/contrib/graph/graph.py b/contrib/graph/graph.py index e174ff5026dc..92736480ebab 100644 --- a/contrib/graph/graph.py +++ b/contrib/graph/graph.py @@ -1,4 +1,5 @@ from __future__ import print_function + # Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -58,9 +59,9 @@ def make_graph(pdus, room, filename_prefix): name = make_name(pdu.get("pdu_id"), pdu.get("origin")) pdu_map[name] = pdu - t = datetime.datetime.fromtimestamp( - float(pdu["ts"]) / 1000 - ).strftime('%Y-%m-%d %H:%M:%S,%f') + t = datetime.datetime.fromtimestamp(float(pdu["ts"]) / 1000).strftime( + "%Y-%m-%d %H:%M:%S,%f" + ) label = ( "<" @@ -80,11 +81,7 @@ def make_graph(pdus, room, filename_prefix): "depth": pdu.get("depth"), } - node = pydot.Node( - name=name, - label=label, - color=color_map[pdu.get("origin")] - ) + node = pydot.Node(name=name, label=label, color=color_map[pdu.get("origin")]) node_map[name] = node graph.add_node(node) @@ -108,14 +105,13 @@ def make_graph(pdus, room, filename_prefix): if prev_state_name in node_map: state_edge = pydot.Edge( - node_map[start_name], node_map[prev_state_name], - style='dotted' + node_map[start_name], node_map[prev_state_name], style="dotted" ) graph.add_edge(state_edge) - graph.write('%s.dot' % filename_prefix, format='raw', prog='dot') -# graph.write_png("%s.png" % filename_prefix, prog='dot') - graph.write_svg("%s.svg" % filename_prefix, prog='dot') + graph.write("%s.dot" % filename_prefix, format="raw", prog="dot") + # graph.write_png("%s.png" % filename_prefix, prog='dot') + graph.write_svg("%s.svg" % filename_prefix, prog="dot") def get_pdus(host, room): @@ -131,15 +127,14 @@ def get_pdus(host, room): if __name__ == "__main__": parser = argparse.ArgumentParser( description="Generate a PDU graph for a given room by talking " - "to the given homeserver to get the list of PDUs. \n" - "Requires pydot." + "to the given homeserver to get the list of PDUs. \n" + "Requires pydot." ) parser.add_argument( - "-p", "--prefix", dest="prefix", - help="String to prefix output files with" + "-p", "--prefix", dest="prefix", help="String to prefix output files with" ) - parser.add_argument('host') - parser.add_argument('room') + parser.add_argument("host") + parser.add_argument("room") args = parser.parse_args() diff --git a/contrib/graph/graph2.py b/contrib/graph/graph2.py index 1ccad6572890..9db8725eee01 100644 --- a/contrib/graph/graph2.py +++ b/contrib/graph/graph2.py @@ -36,10 +36,7 @@ def make_graph(db_name, room_id, file_prefix, limit): args = [room_id] if limit: - sql += ( - " ORDER BY topological_ordering DESC, stream_ordering DESC " - "LIMIT ?" - ) + sql += " ORDER BY topological_ordering DESC, stream_ordering DESC " "LIMIT ?" args.append(limit) @@ -56,9 +53,8 @@ def make_graph(db_name, room_id, file_prefix, limit): for event in events: c = conn.execute( - "SELECT state_group FROM event_to_state_groups " - "WHERE event_id = ?", - (event.event_id,) + "SELECT state_group FROM event_to_state_groups " "WHERE event_id = ?", + (event.event_id,), ) res = c.fetchone() @@ -69,7 +65,7 @@ def make_graph(db_name, room_id, file_prefix, limit): t = datetime.datetime.fromtimestamp( float(event.origin_server_ts) / 1000 - ).strftime('%Y-%m-%d %H:%M:%S,%f') + ).strftime("%Y-%m-%d %H:%M:%S,%f") content = json.dumps(unfreeze(event.get_dict()["content"])) @@ -93,10 +89,7 @@ def make_graph(db_name, room_id, file_prefix, limit): "state_group": state_group, } - node = pydot.Node( - name=event.event_id, - label=label, - ) + node = pydot.Node(name=event.event_id, label=label) node_map[event.event_id] = node graph.add_node(node) @@ -106,10 +99,7 @@ def make_graph(db_name, room_id, file_prefix, limit): try: end_node = node_map[prev_id] except: - end_node = pydot.Node( - name=prev_id, - label="<%s>" % (prev_id,), - ) + end_node = pydot.Node(name=prev_id, label="<%s>" % (prev_id,)) node_map[prev_id] = end_node graph.add_node(end_node) @@ -121,36 +111,33 @@ def make_graph(db_name, room_id, file_prefix, limit): if len(event_ids) <= 1: continue - cluster = pydot.Cluster( - str(group), - label="" % (str(group),) - ) + cluster = pydot.Cluster(str(group), label="" % (str(group),)) for event_id in event_ids: cluster.add_node(node_map[event_id]) graph.add_subgraph(cluster) - graph.write('%s.dot' % file_prefix, format='raw', prog='dot') - graph.write_svg("%s.svg" % file_prefix, prog='dot') + graph.write("%s.dot" % file_prefix, format="raw", prog="dot") + graph.write_svg("%s.svg" % file_prefix, prog="dot") + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Generate a PDU graph for a given room by talking " - "to the given homeserver to get the list of PDUs. \n" - "Requires pydot." + "to the given homeserver to get the list of PDUs. \n" + "Requires pydot." ) parser.add_argument( - "-p", "--prefix", dest="prefix", + "-p", + "--prefix", + dest="prefix", help="String to prefix output files with", - default="graph_output" - ) - parser.add_argument( - "-l", "--limit", - help="Only retrieve the last N events.", + default="graph_output", ) - parser.add_argument('db') - parser.add_argument('room') + parser.add_argument("-l", "--limit", help="Only retrieve the last N events.") + parser.add_argument("db") + parser.add_argument("room") args = parser.parse_args() diff --git a/contrib/graph/graph3.py b/contrib/graph/graph3.py index fe1dc81e9063..7f9e5374a61c 100644 --- a/contrib/graph/graph3.py +++ b/contrib/graph/graph3.py @@ -1,4 +1,5 @@ from __future__ import print_function + # Copyright 2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -42,7 +43,7 @@ def make_graph(file_name, room_id, file_prefix, limit): print("Sorted events") if limit: - events = events[-int(limit):] + events = events[-int(limit) :] node_map = {} @@ -51,7 +52,7 @@ def make_graph(file_name, room_id, file_prefix, limit): for event in events: t = datetime.datetime.fromtimestamp( float(event.origin_server_ts) / 1000 - ).strftime('%Y-%m-%d %H:%M:%S,%f') + ).strftime("%Y-%m-%d %H:%M:%S,%f") content = json.dumps(unfreeze(event.get_dict()["content"]), indent=4) content = content.replace("\n", "
\n") @@ -67,9 +68,10 @@ def make_graph(file_name, room_id, file_prefix, limit): value = json.dumps(value) content.append( - "%s: %s," % ( - cgi.escape(key, quote=True).encode("ascii", 'xmlcharrefreplace'), - cgi.escape(value, quote=True).encode("ascii", 'xmlcharrefreplace'), + "%s: %s," + % ( + cgi.escape(key, quote=True).encode("ascii", "xmlcharrefreplace"), + cgi.escape(value, quote=True).encode("ascii", "xmlcharrefreplace"), ) ) @@ -95,10 +97,7 @@ def make_graph(file_name, room_id, file_prefix, limit): "depth": event.depth, } - node = pydot.Node( - name=event.event_id, - label=label, - ) + node = pydot.Node(name=event.event_id, label=label) node_map[event.event_id] = node graph.add_node(node) @@ -110,10 +109,7 @@ def make_graph(file_name, room_id, file_prefix, limit): try: end_node = node_map[prev_id] except: - end_node = pydot.Node( - name=prev_id, - label="<%s>" % (prev_id,), - ) + end_node = pydot.Node(name=prev_id, label="<%s>" % (prev_id,)) node_map[prev_id] = end_node graph.add_node(end_node) @@ -123,31 +119,31 @@ def make_graph(file_name, room_id, file_prefix, limit): print("Created edges") - graph.write('%s.dot' % file_prefix, format='raw', prog='dot') + graph.write("%s.dot" % file_prefix, format="raw", prog="dot") print("Created Dot") - graph.write_svg("%s.svg" % file_prefix, prog='dot') + graph.write_svg("%s.svg" % file_prefix, prog="dot") print("Created svg") + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Generate a PDU graph for a given room by reading " - "from a file with line deliminated events. \n" - "Requires pydot." + "from a file with line deliminated events. \n" + "Requires pydot." ) parser.add_argument( - "-p", "--prefix", dest="prefix", + "-p", + "--prefix", + dest="prefix", help="String to prefix output files with", - default="graph_output" - ) - parser.add_argument( - "-l", "--limit", - help="Only retrieve the last N events.", + default="graph_output", ) - parser.add_argument('event_file') - parser.add_argument('room') + parser.add_argument("-l", "--limit", help="Only retrieve the last N events.") + parser.add_argument("event_file") + parser.add_argument("room") args = parser.parse_args() diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py index e82d1be5d28a..67fb2cd1a7a5 100644 --- a/contrib/jitsimeetbridge/jitsimeetbridge.py +++ b/contrib/jitsimeetbridge/jitsimeetbridge.py @@ -20,24 +20,25 @@ import subprocess import time -#ACCESS_TOKEN="" # +# ACCESS_TOKEN="" # -MATRIXBASE = 'https://matrix.org/_matrix/client/api/v1/' -MYUSERNAME = '@davetest:matrix.org' +MATRIXBASE = "https://matrix.org/_matrix/client/api/v1/" +MYUSERNAME = "@davetest:matrix.org" -HTTPBIND = 'https://meet.jit.si/http-bind' -#HTTPBIND = 'https://jitsi.vuc.me/http-bind' -#ROOMNAME = "matrix" +HTTPBIND = "https://meet.jit.si/http-bind" +# HTTPBIND = 'https://jitsi.vuc.me/http-bind' +# ROOMNAME = "matrix" ROOMNAME = "pibble" -HOST="guest.jit.si" -#HOST="jitsi.vuc.me" +HOST = "guest.jit.si" +# HOST="jitsi.vuc.me" -TURNSERVER="turn.guest.jit.si" -#TURNSERVER="turn.jitsi.vuc.me" +TURNSERVER = "turn.guest.jit.si" +# TURNSERVER="turn.jitsi.vuc.me" + +ROOMDOMAIN = "meet.jit.si" +# ROOMDOMAIN="conference.jitsi.vuc.me" -ROOMDOMAIN="meet.jit.si" -#ROOMDOMAIN="conference.jitsi.vuc.me" class TrivialMatrixClient: def __init__(self, access_token): @@ -46,38 +47,50 @@ def __init__(self, access_token): def getEvent(self): while True: - url = MATRIXBASE+'events?access_token='+self.access_token+"&timeout=60000" + url = ( + MATRIXBASE + + "events?access_token=" + + self.access_token + + "&timeout=60000" + ) if self.token: - url += "&from="+self.token + url += "&from=" + self.token req = grequests.get(url) resps = grequests.map([req]) obj = json.loads(resps[0].content) - print("incoming from matrix",obj) - if 'end' not in obj: + print("incoming from matrix", obj) + if "end" not in obj: continue - self.token = obj['end'] - if len(obj['chunk']): - return obj['chunk'][0] + self.token = obj["end"] + if len(obj["chunk"]): + return obj["chunk"][0] def joinRoom(self, roomId): - url = MATRIXBASE+'rooms/'+roomId+'/join?access_token='+self.access_token + url = MATRIXBASE + "rooms/" + roomId + "/join?access_token=" + self.access_token print(url) - headers={ 'Content-Type': 'application/json' } - req = grequests.post(url, headers=headers, data='{}') + headers = {"Content-Type": "application/json"} + req = grequests.post(url, headers=headers, data="{}") resps = grequests.map([req]) obj = json.loads(resps[0].content) - print("response: ",obj) + print("response: ", obj) def sendEvent(self, roomId, evType, event): - url = MATRIXBASE+'rooms/'+roomId+'/send/'+evType+'?access_token='+self.access_token + url = ( + MATRIXBASE + + "rooms/" + + roomId + + "/send/" + + evType + + "?access_token=" + + self.access_token + ) print(url) print(json.dumps(event)) - headers={ 'Content-Type': 'application/json' } + headers = {"Content-Type": "application/json"} req = grequests.post(url, headers=headers, data=json.dumps(event)) resps = grequests.map([req]) obj = json.loads(resps[0].content) - print("response: ",obj) - + print("response: ", obj) xmppClients = {} @@ -87,38 +100,39 @@ def matrixLoop(): while True: ev = matrixCli.getEvent() print(ev) - if ev['type'] == 'm.room.member': - print('membership event') - if ev['membership'] == 'invite' and ev['state_key'] == MYUSERNAME: - roomId = ev['room_id'] + if ev["type"] == "m.room.member": + print("membership event") + if ev["membership"] == "invite" and ev["state_key"] == MYUSERNAME: + roomId = ev["room_id"] print("joining room %s" % (roomId)) matrixCli.joinRoom(roomId) - elif ev['type'] == 'm.room.message': - if ev['room_id'] in xmppClients: + elif ev["type"] == "m.room.message": + if ev["room_id"] in xmppClients: print("already have a bridge for that user, ignoring") continue print("got message, connecting") - xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) - gevent.spawn(xmppClients[ev['room_id']].xmppLoop) - elif ev['type'] == 'm.call.invite': + xmppClients[ev["room_id"]] = TrivialXmppClient(ev["room_id"], ev["user_id"]) + gevent.spawn(xmppClients[ev["room_id"]].xmppLoop) + elif ev["type"] == "m.call.invite": print("Incoming call") - #sdp = ev['content']['offer']['sdp'] - #print "sdp: %s" % (sdp) - #xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) - #gevent.spawn(xmppClients[ev['room_id']].xmppLoop) - elif ev['type'] == 'm.call.answer': + # sdp = ev['content']['offer']['sdp'] + # print "sdp: %s" % (sdp) + # xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) + # gevent.spawn(xmppClients[ev['room_id']].xmppLoop) + elif ev["type"] == "m.call.answer": print("Call answered") - sdp = ev['content']['answer']['sdp'] - if ev['room_id'] not in xmppClients: + sdp = ev["content"]["answer"]["sdp"] + if ev["room_id"] not in xmppClients: print("We didn't have a call for that room") continue # should probably check call ID too - xmppCli = xmppClients[ev['room_id']] + xmppCli = xmppClients[ev["room_id"]] xmppCli.sendAnswer(sdp) - elif ev['type'] == 'm.call.hangup': - if ev['room_id'] in xmppClients: - xmppClients[ev['room_id']].stop() - del xmppClients[ev['room_id']] + elif ev["type"] == "m.call.hangup": + if ev["room_id"] in xmppClients: + xmppClients[ev["room_id"]].stop() + del xmppClients[ev["room_id"]] + class TrivialXmppClient: def __init__(self, matrixRoom, userId): @@ -132,130 +146,155 @@ def stop(self): def nextRid(self): self.rid += 1 - return '%d' % (self.rid) + return "%d" % (self.rid) def sendIq(self, xml): - fullXml = "%s" % (self.nextRid(), self.sid, xml) - #print "\t>>>%s" % (fullXml) + fullXml = ( + "%s" + % (self.nextRid(), self.sid, xml) + ) + # print "\t>>>%s" % (fullXml) return self.xmppPoke(fullXml) def xmppPoke(self, xml): - headers = {'Content-Type': 'application/xml'} + headers = {"Content-Type": "application/xml"} req = grequests.post(HTTPBIND, verify=False, headers=headers, data=xml) resps = grequests.map([req]) obj = BeautifulSoup(resps[0].content) return obj def sendAnswer(self, answer): - print("sdp from matrix client",answer) - p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--sdp'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) + print("sdp from matrix client", answer) + p = subprocess.Popen( + ["node", "unjingle/unjingle.js", "--sdp"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) jingle, out_err = p.communicate(answer) jingle = jingle % { - 'tojid': self.callfrom, - 'action': 'session-accept', - 'initiator': self.callfrom, - 'responder': self.jid, - 'sid': self.callsid + "tojid": self.callfrom, + "action": "session-accept", + "initiator": self.callfrom, + "responder": self.jid, + "sid": self.callsid, } - print("answer jingle from sdp",jingle) + print("answer jingle from sdp", jingle) res = self.sendIq(jingle) - print("reply from answer: ",res) + print("reply from answer: ", res) self.ssrcs = {} jingleSoup = BeautifulSoup(jingle) - for cont in jingleSoup.iq.jingle.findAll('content'): + for cont in jingleSoup.iq.jingle.findAll("content"): if cont.description: - self.ssrcs[cont['name']] = cont.description['ssrc'] - print("my ssrcs:",self.ssrcs) + self.ssrcs[cont["name"]] = cont.description["ssrc"] + print("my ssrcs:", self.ssrcs) - gevent.joinall([ - gevent.spawn(self.advertiseSsrcs) - ]) + gevent.joinall([gevent.spawn(self.advertiseSsrcs)]) def advertiseSsrcs(self): time.sleep(7) print("SSRC spammer started") while self.running: - ssrcMsg = "%(nick)s" % { 'tojid': "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), 'nick': self.userId, 'assrc': self.ssrcs['audio'], 'vssrc': self.ssrcs['video'] } + ssrcMsg = ( + "%(nick)s" + % { + "tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), + "nick": self.userId, + "assrc": self.ssrcs["audio"], + "vssrc": self.ssrcs["video"], + } + ) res = self.sendIq(ssrcMsg) - print("reply from ssrc announce: ",res) + print("reply from ssrc announce: ", res) time.sleep(10) - - def xmppLoop(self): self.matrixCallId = time.time() - res = self.xmppPoke("" % (self.nextRid(), HOST)) + res = self.xmppPoke( + "" + % (self.nextRid(), HOST) + ) print(res) - self.sid = res.body['sid'] + self.sid = res.body["sid"] print("sid %s" % (self.sid)) - res = self.sendIq("") + res = self.sendIq( + "" + ) - res = self.xmppPoke("" % (self.nextRid(), self.sid, HOST)) + res = self.xmppPoke( + "" + % (self.nextRid(), self.sid, HOST) + ) - res = self.sendIq("") + res = self.sendIq( + "" + ) print(res) self.jid = res.body.iq.bind.jid.string print("jid: %s" % (self.jid)) - self.shortJid = self.jid.split('-')[0] + self.shortJid = self.jid.split("-")[0] - res = self.sendIq("") + res = self.sendIq( + "" + ) - #randomthing = res.body.iq['to'] - #whatsitpart = randomthing.split('-')[0] + # randomthing = res.body.iq['to'] + # whatsitpart = randomthing.split('-')[0] - #print "other random bind thing: %s" % (randomthing) + # print "other random bind thing: %s" % (randomthing) # advertise preence to the jitsi room, with our nick - res = self.sendIq("%s" % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId)) - self.muc = {'users': []} - for p in res.body.findAll('presence'): + res = self.sendIq( + "%s" + % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId) + ) + self.muc = {"users": []} + for p in res.body.findAll("presence"): u = {} - u['shortJid'] = p['from'].split('/')[1] + u["shortJid"] = p["from"].split("/")[1] if p.c and p.c.nick: - u['nick'] = p.c.nick.string - self.muc['users'].append(u) - print("muc: ",self.muc) + u["nick"] = p.c.nick.string + self.muc["users"].append(u) + print("muc: ", self.muc) # wait for stuff while True: print("waiting...") res = self.sendIq("") - print("got from stream: ",res) + print("got from stream: ", res) if res.body.iq: - jingles = res.body.iq.findAll('jingle') + jingles = res.body.iq.findAll("jingle") if len(jingles): - self.callfrom = res.body.iq['from'] + self.callfrom = res.body.iq["from"] self.handleInvite(jingles[0]) - elif 'type' in res.body and res.body['type'] == 'terminate': + elif "type" in res.body and res.body["type"] == "terminate": self.running = False del xmppClients[self.matrixRoom] return def handleInvite(self, jingle): - self.initiator = jingle['initiator'] - self.callsid = jingle['sid'] - p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--jingle'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) - print("raw jingle invite",str(jingle)) + self.initiator = jingle["initiator"] + self.callsid = jingle["sid"] + p = subprocess.Popen( + ["node", "unjingle/unjingle.js", "--jingle"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + ) + print("raw jingle invite", str(jingle)) sdp, out_err = p.communicate(str(jingle)) - print("transformed remote offer sdp",sdp) + print("transformed remote offer sdp", sdp) inviteEvent = { - 'offer': { - 'type': 'offer', - 'sdp': sdp - }, - 'call_id': self.matrixCallId, - 'version': 0, - 'lifetime': 30000 + "offer": {"type": "offer", "sdp": sdp}, + "call_id": self.matrixCallId, + "version": 0, + "lifetime": 30000, } - matrixCli.sendEvent(self.matrixRoom, 'm.call.invite', inviteEvent) + matrixCli.sendEvent(self.matrixRoom, "m.call.invite", inviteEvent) -matrixCli = TrivialMatrixClient(ACCESS_TOKEN) # Undefined name -gevent.joinall([ - gevent.spawn(matrixLoop) -]) +matrixCli = TrivialMatrixClient(ACCESS_TOKEN) # Undefined name +gevent.joinall([gevent.spawn(matrixLoop)]) diff --git a/contrib/scripts/kick_users.py b/contrib/scripts/kick_users.py index b4a14385d05f..f57e6e7d2599 100755 --- a/contrib/scripts/kick_users.py +++ b/contrib/scripts/kick_users.py @@ -11,22 +11,22 @@ except NameError: # Python 3 raw_input = input + def _mkurl(template, kws): for key in kws: template = template.replace(key, kws[key]) return template + def main(hs, room_id, access_token, user_id_prefix, why): if not why: why = "Automated kick." - print("Kicking members on %s in room %s matching %s" % (hs, room_id, user_id_prefix)) + print( + "Kicking members on %s in room %s matching %s" % (hs, room_id, user_id_prefix) + ) room_state_url = _mkurl( "$HS/_matrix/client/api/v1/rooms/$ROOM/state?access_token=$TOKEN", - { - "$HS": hs, - "$ROOM": room_id, - "$TOKEN": access_token - } + {"$HS": hs, "$ROOM": room_id, "$TOKEN": access_token}, ) print("Getting room state => %s" % room_state_url) res = requests.get(room_state_url) @@ -57,24 +57,16 @@ def main(hs, room_id, access_token, user_id_prefix, why): for uid in kick_list: print(uid) doit = raw_input("Continue? [Y]es\n") - if len(doit) > 0 and doit.lower() == 'y': + if len(doit) > 0 and doit.lower() == "y": print("Kicking members...") # encode them all kick_list = [urllib.quote(uid) for uid in kick_list] for uid in kick_list: kick_url = _mkurl( "$HS/_matrix/client/api/v1/rooms/$ROOM/state/m.room.member/$UID?access_token=$TOKEN", - { - "$HS": hs, - "$UID": uid, - "$ROOM": room_id, - "$TOKEN": access_token - } + {"$HS": hs, "$UID": uid, "$ROOM": room_id, "$TOKEN": access_token}, ) - kick_body = { - "membership": "leave", - "reason": why - } + kick_body = {"membership": "leave", "reason": why} print("Kicking %s" % uid) res = requests.put(kick_url, data=json.dumps(kick_body)) if res.status_code != 200: @@ -83,14 +75,15 @@ def main(hs, room_id, access_token, user_id_prefix, why): print("ERROR: JSON %s" % res.json()) - if __name__ == "__main__": parser = ArgumentParser("Kick members in a room matching a certain user ID prefix.") - parser.add_argument("-u","--user-id",help="The user ID prefix e.g. '@irc_'") - parser.add_argument("-t","--token",help="Your access_token") - parser.add_argument("-r","--room",help="The room ID to kick members in") - parser.add_argument("-s","--homeserver",help="The base HS url e.g. http://matrix.org") - parser.add_argument("-w","--why",help="Reason for the kick. Optional.") + parser.add_argument("-u", "--user-id", help="The user ID prefix e.g. '@irc_'") + parser.add_argument("-t", "--token", help="Your access_token") + parser.add_argument("-r", "--room", help="The room ID to kick members in") + parser.add_argument( + "-s", "--homeserver", help="The base HS url e.g. http://matrix.org" + ) + parser.add_argument("-w", "--why", help="Reason for the kick. Optional.") args = parser.parse_args() if not args.room or not args.token or not args.user_id or not args.homeserver: parser.print_help() diff --git a/demo/webserver.py b/demo/webserver.py index 875095c87789..ba176d3bd2af 100644 --- a/demo/webserver.py +++ b/demo/webserver.py @@ -6,23 +6,25 @@ from daemonize import Daemonize + class SimpleHTTPRequestHandlerWithPOST(SimpleHTTPServer.SimpleHTTPRequestHandler): UPLOAD_PATH = "upload" """ Accept all post request as file upload """ + def do_POST(self): path = os.path.join(self.UPLOAD_PATH, os.path.basename(self.path)) - length = self.headers['content-length'] + length = self.headers["content-length"] data = self.rfile.read(int(length)) - with open(path, 'wb') as fh: + with open(path, "wb") as fh: fh.write(data) self.send_response(200) - self.send_header('Content-Type', 'application/json') + self.send_header("Content-Type", "application/json") self.end_headers() # Return the absolute path of the uploaded file @@ -33,30 +35,25 @@ def setup(): parser = argparse.ArgumentParser() parser.add_argument("directory") parser.add_argument("-p", "--port", dest="port", type=int, default=8080) - parser.add_argument('-P', "--pid-file", dest="pid", default="web.pid") + parser.add_argument("-P", "--pid-file", dest="pid", default="web.pid") args = parser.parse_args() # Get absolute path to directory to serve, as daemonize changes to '/' os.chdir(args.directory) dr = os.getcwd() - httpd = BaseHTTPServer.HTTPServer( - ('', args.port), - SimpleHTTPRequestHandlerWithPOST - ) + httpd = BaseHTTPServer.HTTPServer(("", args.port), SimpleHTTPRequestHandlerWithPOST) def run(): os.chdir(dr) httpd.serve_forever() daemon = Daemonize( - app="synapse-webclient", - pid=args.pid, - action=run, - auto_close_fds=False, - ) + app="synapse-webclient", pid=args.pid, action=run, auto_close_fds=False + ) daemon.start() -if __name__ == '__main__': + +if __name__ == "__main__": setup() diff --git a/docker/start.py b/docker/start.py index 2da555272a45..a7a54dacf7c4 100755 --- a/docker/start.py +++ b/docker/start.py @@ -8,7 +8,10 @@ import codecs # Utility functions -convert = lambda src, dst, environ: open(dst, "w").write(jinja2.Template(open(src).read()).render(**environ)) +convert = lambda src, dst, environ: open(dst, "w").write( + jinja2.Template(open(src).read()).render(**environ) +) + def check_arguments(environ, args): for argument in args: @@ -16,18 +19,22 @@ def check_arguments(environ, args): print("Environment variable %s is mandatory, exiting." % argument) sys.exit(2) + def generate_secrets(environ, secrets): for name, secret in secrets.items(): if secret not in environ: filename = "/data/%s.%s.key" % (environ["SYNAPSE_SERVER_NAME"], name) if os.path.exists(filename): - with open(filename) as handle: value = handle.read() + with open(filename) as handle: + value = handle.read() else: print("Generating a random secret for {}".format(name)) value = codecs.encode(os.urandom(32), "hex").decode() - with open(filename, "w") as handle: handle.write(value) + with open(filename, "w") as handle: + handle.write(value) environ[secret] = value + # Prepare the configuration mode = sys.argv[1] if len(sys.argv) > 1 else None environ = os.environ.copy() @@ -36,12 +43,17 @@ def generate_secrets(environ, secrets): # In generate mode, generate a configuration, missing keys, then exit if mode == "generate": - check_arguments(environ, ("SYNAPSE_SERVER_NAME", "SYNAPSE_REPORT_STATS", "SYNAPSE_CONFIG_PATH")) + check_arguments( + environ, ("SYNAPSE_SERVER_NAME", "SYNAPSE_REPORT_STATS", "SYNAPSE_CONFIG_PATH") + ) args += [ - "--server-name", environ["SYNAPSE_SERVER_NAME"], - "--report-stats", environ["SYNAPSE_REPORT_STATS"], - "--config-path", environ["SYNAPSE_CONFIG_PATH"], - "--generate-config" + "--server-name", + environ["SYNAPSE_SERVER_NAME"], + "--report-stats", + environ["SYNAPSE_REPORT_STATS"], + "--config-path", + environ["SYNAPSE_CONFIG_PATH"], + "--generate-config", ] os.execv("/usr/local/bin/python", args) @@ -51,15 +63,19 @@ def generate_secrets(environ, secrets): config_path = environ["SYNAPSE_CONFIG_PATH"] else: check_arguments(environ, ("SYNAPSE_SERVER_NAME", "SYNAPSE_REPORT_STATS")) - generate_secrets(environ, { - "registration": "SYNAPSE_REGISTRATION_SHARED_SECRET", - "macaroon": "SYNAPSE_MACAROON_SECRET_KEY" - }) + generate_secrets( + environ, + { + "registration": "SYNAPSE_REGISTRATION_SHARED_SECRET", + "macaroon": "SYNAPSE_MACAROON_SECRET_KEY", + }, + ) environ["SYNAPSE_APPSERVICES"] = glob.glob("/data/appservices/*.yaml") - if not os.path.exists("/compiled"): os.mkdir("/compiled") + if not os.path.exists("/compiled"): + os.mkdir("/compiled") config_path = "/compiled/homeserver.yaml" - + # Convert SYNAPSE_NO_TLS to boolean if exists if "SYNAPSE_NO_TLS" in environ: tlsanswerstring = str.lower(environ["SYNAPSE_NO_TLS"]) @@ -69,19 +85,23 @@ def generate_secrets(environ, secrets): if tlsanswerstring in ("false", "off", "0", "no"): environ["SYNAPSE_NO_TLS"] = False else: - print("Environment variable \"SYNAPSE_NO_TLS\" found but value \"" + tlsanswerstring + "\" unrecognized; exiting.") + print( + 'Environment variable "SYNAPSE_NO_TLS" found but value "' + + tlsanswerstring + + '" unrecognized; exiting.' + ) sys.exit(2) convert("/conf/homeserver.yaml", config_path, environ) convert("/conf/log.config", "/compiled/log.config", environ) subprocess.check_output(["chown", "-R", ownership, "/data"]) - args += [ - "--config-path", config_path, - + "--config-path", + config_path, # tell synapse to put any generated keys in /data rather than /compiled - "--keys-directory", "/data", + "--keys-directory", + "/data", ] # Generate missing keys and start synapse diff --git a/docs/sphinx/conf.py b/docs/sphinx/conf.py index 0b15bd89123a..ca4b879526df 100644 --- a/docs/sphinx/conf.py +++ b/docs/sphinx/conf.py @@ -18,226 +18,220 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.coverage', - 'sphinx.ext.ifconfig', - 'sphinxcontrib.napoleon', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.coverage", + "sphinx.ext.ifconfig", + "sphinxcontrib.napoleon", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Synapse' -copyright = u'Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd' +project = "Synapse" +copyright = ( + "Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd" +) # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '1.0' +version = "1.0" # The full version, including alpha/beta/rc tags. -release = '1.0' +release = "1.0" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'default' +html_theme = "default" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'Synapsedoc' +htmlhelp_basename = "Synapsedoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). -latex_documents = [ - ('index', 'Synapse.tex', u'Synapse Documentation', - u'TNG', 'manual'), -] +latex_documents = [("index", "Synapse.tex", "Synapse Documentation", "TNG", "manual")] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'synapse', u'Synapse Documentation', - [u'TNG'], 1) -] +man_pages = [("index", "synapse", "Synapse Documentation", ["TNG"], 1)] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -246,26 +240,32 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'Synapse', u'Synapse Documentation', - u'TNG', 'Synapse', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "Synapse", + "Synapse Documentation", + "TNG", + "Synapse", + "One line description of project.", + "Miscellaneous", + ) ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'http://docs.python.org/': None} +intersphinx_mapping = {"http://docs.python.org/": None} napoleon_include_special_with_doc = True napoleon_use_ivar = True diff --git a/pyproject.toml b/pyproject.toml index dd099dc9c81e..ec23258da8db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,3 +28,22 @@ directory = "misc" name = "Internal Changes" showcontent = true + +[tool.black] +target-version = ['py34'] +exclude = ''' + +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.tox + | \.venv + | _build + | _trial_temp.* + | build + | dist + | debian + )/ +) +''' diff --git a/scripts-dev/check_auth.py b/scripts-dev/check_auth.py index b3d11f49ec00..2a1c5f39d433 100644 --- a/scripts-dev/check_auth.py +++ b/scripts-dev/check_auth.py @@ -39,11 +39,11 @@ def check_auth(auth, auth_chain, events): print("Success:", e.event_id, e.type, e.state_key) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - 'json', nargs='?', type=argparse.FileType('r'), default=sys.stdin + "json", nargs="?", type=argparse.FileType("r"), default=sys.stdin ) args = parser.parse_args() diff --git a/scripts-dev/check_event_hash.py b/scripts-dev/check_event_hash.py index 8535f99697cf..cd5599e9a1ff 100644 --- a/scripts-dev/check_event_hash.py +++ b/scripts-dev/check_event_hash.py @@ -30,7 +30,7 @@ def get_pdu_json(self): def main(): parser = argparse.ArgumentParser() parser.add_argument( - "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin + "input_json", nargs="?", type=argparse.FileType("r"), default=sys.stdin ) args = parser.parse_args() logging.basicConfig() diff --git a/scripts-dev/check_signature.py b/scripts-dev/check_signature.py index 612f17ca7f9c..ecda103cf7c4 100644 --- a/scripts-dev/check_signature.py +++ b/scripts-dev/check_signature.py @@ -1,4 +1,3 @@ - import argparse import json import logging @@ -40,7 +39,7 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument("signature_name") parser.add_argument( - "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin + "input_json", nargs="?", type=argparse.FileType("r"), default=sys.stdin ) args = parser.parse_args() @@ -69,5 +68,5 @@ def main(): print("FAIL %s" % (key_id,)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py index ac152b5c4249..179be61c30d9 100644 --- a/scripts-dev/convert_server_keys.py +++ b/scripts-dev/convert_server_keys.py @@ -116,5 +116,5 @@ def main(): connection.commit() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py index 1deb0fe2b7fe..9eddb6d515d7 100755 --- a/scripts-dev/definitions.py +++ b/scripts-dev/definitions.py @@ -19,10 +19,10 @@ def __init__(self): self.names = {} self.attrs = set() self.definitions = { - 'def': self.functions, - 'class': self.classes, - 'names': self.names, - 'attrs': self.attrs, + "def": self.functions, + "class": self.classes, + "names": self.names, + "attrs": self.attrs, } def visit_Name(self, node): @@ -47,23 +47,23 @@ def visit_FunctionDef(self, node): def non_empty(defs): - functions = {name: non_empty(f) for name, f in defs['def'].items()} - classes = {name: non_empty(f) for name, f in defs['class'].items()} + functions = {name: non_empty(f) for name, f in defs["def"].items()} + classes = {name: non_empty(f) for name, f in defs["class"].items()} result = {} if functions: - result['def'] = functions + result["def"] = functions if classes: - result['class'] = classes - names = defs['names'] + result["class"] = classes + names = defs["names"] uses = [] - for name in names.get('Load', ()): - if name not in names.get('Param', ()) and name not in names.get('Store', ()): + for name in names.get("Load", ()): + if name not in names.get("Param", ()) and name not in names.get("Store", ()): uses.append(name) - uses.extend(defs['attrs']) + uses.extend(defs["attrs"]) if uses: - result['uses'] = uses - result['names'] = names - result['attrs'] = defs['attrs'] + result["uses"] = uses + result["names"] = names + result["attrs"] = defs["attrs"] return result @@ -81,33 +81,33 @@ def definitions_in_file(filepath): def defined_names(prefix, defs, names): - for name, funcs in defs.get('def', {}).items(): - names.setdefault(name, {'defined': []})['defined'].append(prefix + name) + for name, funcs in defs.get("def", {}).items(): + names.setdefault(name, {"defined": []})["defined"].append(prefix + name) defined_names(prefix + name + ".", funcs, names) - for name, funcs in defs.get('class', {}).items(): - names.setdefault(name, {'defined': []})['defined'].append(prefix + name) + for name, funcs in defs.get("class", {}).items(): + names.setdefault(name, {"defined": []})["defined"].append(prefix + name) defined_names(prefix + name + ".", funcs, names) def used_names(prefix, item, defs, names): - for name, funcs in defs.get('def', {}).items(): + for name, funcs in defs.get("def", {}).items(): used_names(prefix + name + ".", name, funcs, names) - for name, funcs in defs.get('class', {}).items(): + for name, funcs in defs.get("class", {}).items(): used_names(prefix + name + ".", name, funcs, names) - path = prefix.rstrip('.') - for used in defs.get('uses', ()): + path = prefix.rstrip(".") + for used in defs.get("uses", ()): if used in names: if item: - names[item].setdefault('uses', []).append(used) - names[used].setdefault('used', {}).setdefault(item, []).append(path) + names[item].setdefault("uses", []).append(used) + names[used].setdefault("used", {}).setdefault(item, []).append(path) -if __name__ == '__main__': +if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Find definitions.') + parser = argparse.ArgumentParser(description="Find definitions.") parser.add_argument( "--unused", action="store_true", help="Only list unused definitions" ) @@ -119,7 +119,7 @@ def used_names(prefix, item, defs, names): ) parser.add_argument( "directories", - nargs='+', + nargs="+", metavar="DIR", help="Directories to search for definitions", ) @@ -164,7 +164,7 @@ def used_names(prefix, item, defs, names): continue if ignore and any(pattern.match(name) for pattern in ignore): continue - if args.unused and definition.get('used'): + if args.unused and definition.get("used"): continue result[name] = definition @@ -196,9 +196,9 @@ def used_names(prefix, item, defs, names): continue result[name] = definition - if args.format == 'yaml': + if args.format == "yaml": yaml.dump(result, sys.stdout, default_flow_style=False) - elif args.format == 'dot': + elif args.format == "dot": print("digraph {") for name, entry in result.items(): print(name) diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 41e7b244187a..7c19e405d451 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -63,7 +63,7 @@ def encode_canonical_json(value): # Encode code-points outside of ASCII as UTF-8 rather than \u escapes ensure_ascii=False, # Remove unecessary white space. - separators=(',', ':'), + separators=(",", ":"), # Sort the keys of dictionaries. sort_keys=True, # Encode the resulting unicode as UTF-8 bytes. @@ -145,7 +145,7 @@ def request_json(method, origin_name, origin_key, destination, path, content): authorization_headers = [] for key, sig in signed_json["signatures"][origin_name].items(): - header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (origin_name, key, sig) + header = 'X-Matrix origin=%s,key="%s",sig="%s"' % (origin_name, key, sig) authorization_headers.append(header.encode("ascii")) print("Authorization: %s" % header, file=sys.stderr) @@ -161,11 +161,7 @@ def request_json(method, origin_name, origin_key, destination, path, content): headers["Content-Type"] = "application/json" result = s.request( - method=method, - url=dest, - headers=headers, - verify=False, - data=content, + method=method, url=dest, headers=headers, verify=False, data=content ) sys.stderr.write("Status Code: %d\n" % (result.status_code,)) return result.json() @@ -241,18 +237,18 @@ def main(): def read_args_from_config(args): - with open(args.config, 'r') as fh: + with open(args.config, "r") as fh: config = yaml.safe_load(fh) if not args.server_name: - args.server_name = config['server_name'] + args.server_name = config["server_name"] if not args.signing_key_path: - args.signing_key_path = config['signing_key_path'] + args.signing_key_path = config["signing_key_path"] class MatrixConnectionAdapter(HTTPAdapter): @staticmethod def lookup(s, skip_well_known=False): - if s[-1] == ']': + if s[-1] == "]": # ipv6 literal (with no port) return s, 8448 @@ -268,9 +264,7 @@ def lookup(s, skip_well_known=False): if not skip_well_known: well_known = MatrixConnectionAdapter.get_well_known(s) if well_known: - return MatrixConnectionAdapter.lookup( - well_known, skip_well_known=True - ) + return MatrixConnectionAdapter.lookup(well_known, skip_well_known=True) try: srv = srvlookup.lookup("matrix", "tcp", s)[0] @@ -280,8 +274,8 @@ def lookup(s, skip_well_known=False): @staticmethod def get_well_known(server_name): - uri = "https://%s/.well-known/matrix/server" % (server_name, ) - print("fetching %s" % (uri, ), file=sys.stderr) + uri = "https://%s/.well-known/matrix/server" % (server_name,) + print("fetching %s" % (uri,), file=sys.stderr) try: resp = requests.get(uri) @@ -294,12 +288,12 @@ def get_well_known(server_name): raise Exception("not a dict") if "m.server" not in parsed_well_known: raise Exception("Missing key 'm.server'") - new_name = parsed_well_known['m.server'] - print("well-known lookup gave %s" % (new_name, ), file=sys.stderr) + new_name = parsed_well_known["m.server"] + print("well-known lookup gave %s" % (new_name,), file=sys.stderr) return new_name except Exception as e: - print("Invalid response from %s: %s" % (uri, e, ), file=sys.stderr) + print("Invalid response from %s: %s" % (uri, e), file=sys.stderr) return None def get_connection(self, url, proxies=None): diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py index 514d80fa606d..d20f6db17656 100644 --- a/scripts-dev/hash_history.py +++ b/scripts-dev/hash_history.py @@ -79,5 +79,5 @@ def main(): conn.commit() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts-dev/list_url_patterns.py b/scripts-dev/list_url_patterns.py index 62e5a07472b1..26ad7c67f483 100755 --- a/scripts-dev/list_url_patterns.py +++ b/scripts-dev/list_url_patterns.py @@ -35,11 +35,11 @@ def find_patterns_in_file(filepath): find_patterns_in_code(f.read()) -parser = argparse.ArgumentParser(description='Find url patterns.') +parser = argparse.ArgumentParser(description="Find url patterns.") parser.add_argument( "directories", - nargs='+', + nargs="+", metavar="DIR", help="Directories to search for definitions", ) diff --git a/scripts-dev/tail-synapse.py b/scripts-dev/tail-synapse.py index 7c9985d9f07f..44e3a6dbf16e 100644 --- a/scripts-dev/tail-synapse.py +++ b/scripts-dev/tail-synapse.py @@ -63,5 +63,5 @@ def main(): streams[update.name] = update.position -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/generate_signing_key.py b/scripts/generate_signing_key.py index 36e9140b5017..16d7c4f38284 100755 --- a/scripts/generate_signing_key.py +++ b/scripts/generate_signing_key.py @@ -24,14 +24,14 @@ parser = argparse.ArgumentParser() parser.add_argument( - "-o", "--output_file", - - type=argparse.FileType('w'), + "-o", + "--output_file", + type=argparse.FileType("w"), default=sys.stdout, help="Where to write the output to", ) args = parser.parse_args() key_id = "a_" + random_string(4) - key = generate_signing_key(key_id), + key = (generate_signing_key(key_id),) write_signing_keys(args.output_file, key) diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py index e630936f7825..12747c6024a0 100755 --- a/scripts/move_remote_media_to_new_store.py +++ b/scripts/move_remote_media_to_new_store.py @@ -50,7 +50,7 @@ def main(src_repo, dest_repo): dest_paths = MediaFilePaths(dest_repo) for line in sys.stdin: line = line.strip() - parts = line.split('|') + parts = line.split("|") if len(parts) != 2: print("Unable to parse input line %s" % line, file=sys.stderr) exit(1) @@ -107,7 +107,7 @@ def mkdir_and_move(original_file, dest_file): parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter ) - parser.add_argument("-v", action='store_true', help='enable debug logging') + parser.add_argument("-v", action="store_true", help="enable debug logging") parser.add_argument("src_repo", help="Path to source content repo") parser.add_argument("dest_repo", help="Path to source content repo") args = parser.parse_args() diff --git a/setup.cfg b/setup.cfg index b6b4aa740d5b..12a7849081a0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,18 +15,17 @@ ignore = tox.ini [flake8] -max-line-length = 90 - # see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes # for error codes. The ones we ignore are: # W503: line break before binary operator # W504: line break after binary operator # E203: whitespace before ':' (which is contrary to pep8?) # E731: do not assign a lambda expression, use a def -ignore=W503,W504,E203,E731 +# E501: Line too long (black enforces this for us) +ignore=W503,W504,E203,E731,E501 [isort] -line_length = 89 +line_length = 88 not_skip = __init__.py sections=FUTURE,STDLIB,COMPAT,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER default_section=THIRDPARTY diff --git a/setup.py b/setup.py index 3492cdc5a0e2..5ce06c898792 100755 --- a/setup.py +++ b/setup.py @@ -60,9 +60,12 @@ def finalize_options(self): pass def run(self): - print ("""Synapse's tests cannot be run via setup.py. To run them, try: + print( + """Synapse's tests cannot be run via setup.py. To run them, try: PYTHONPATH="." trial tests -""") +""" + ) + def read_file(path_segments): """Read a file from the package. Takes a list of strings to join to @@ -84,9 +87,9 @@ def exec_file(path_segments): dependencies = exec_file(("synapse", "python_dependencies.py")) long_description = read_file(("README.rst",)) -REQUIREMENTS = dependencies['REQUIREMENTS'] -CONDITIONAL_REQUIREMENTS = dependencies['CONDITIONAL_REQUIREMENTS'] -ALL_OPTIONAL_REQUIREMENTS = dependencies['ALL_OPTIONAL_REQUIREMENTS'] +REQUIREMENTS = dependencies["REQUIREMENTS"] +CONDITIONAL_REQUIREMENTS = dependencies["CONDITIONAL_REQUIREMENTS"] +ALL_OPTIONAL_REQUIREMENTS = dependencies["ALL_OPTIONAL_REQUIREMENTS"] # Make `pip install matrix-synapse[all]` install all the optional dependencies. CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS) @@ -102,16 +105,16 @@ def exec_file(path_segments): include_package_data=True, zip_safe=False, long_description=long_description, - python_requires='~=3.5', + python_requires="~=3.5", classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Topic :: Communications :: Chat', - 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', + "Development Status :: 5 - Production/Stable", + "Topic :: Communications :: Chat", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", ], scripts=["synctl"] + glob.glob("scripts/*"), - cmdclass={'test': TestCommand}, + cmdclass={"test": TestCommand}, ) diff --git a/synapse/__init__.py b/synapse/__init__.py index 0c0154678935..119359be68f4 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -28,6 +28,7 @@ from twisted.internet import protocol from twisted.internet.protocol import Factory from twisted.names.dns import DNSDatagramProtocol + protocol.Factory.noisy = False Factory.noisy = False DNSDatagramProtocol.noisy = False diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index 6e93f5a0c6fe..bdcd915bbe8b 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -57,18 +57,18 @@ def request_registration( nonce = r.json()["nonce"] - mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1) + mac = hmac.new(key=shared_secret.encode("utf8"), digestmod=hashlib.sha1) - mac.update(nonce.encode('utf8')) + mac.update(nonce.encode("utf8")) mac.update(b"\x00") - mac.update(user.encode('utf8')) + mac.update(user.encode("utf8")) mac.update(b"\x00") - mac.update(password.encode('utf8')) + mac.update(password.encode("utf8")) mac.update(b"\x00") mac.update(b"admin" if admin else b"notadmin") if user_type: mac.update(b"\x00") - mac.update(user_type.encode('utf8')) + mac.update(user_type.encode("utf8")) mac = mac.hexdigest() @@ -134,8 +134,9 @@ def register_new_user(user, password, server_location, shared_secret, admin, use else: admin = False - request_registration(user, password, server_location, shared_secret, - bool(admin), user_type) + request_registration( + user, password, server_location, shared_secret, bool(admin), user_type + ) def main(): @@ -189,7 +190,7 @@ def main(): group.add_argument( "-c", "--config", - type=argparse.FileType('r'), + type=argparse.FileType("r"), help="Path to server config file. Used to read in shared secret.", ) @@ -200,7 +201,7 @@ def main(): parser.add_argument( "server_url", default="https://localhost:8448", - nargs='?', + nargs="?", help="URL to use to talk to the home server. Defaults to " " 'https://localhost:8448'.", ) @@ -220,8 +221,9 @@ def main(): if args.admin or args.no_admin: admin = args.admin - register_new_user(args.user, args.password, args.server_url, secret, - admin, args.user_type) + register_new_user( + args.user, args.password, args.server_url, secret, admin, args.user_type + ) if __name__ == "__main__": diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 79e2808dc5a7..86f145649cfd 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -36,8 +36,11 @@ AuthEventTypes = ( - EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels, - EventTypes.JoinRules, EventTypes.RoomHistoryVisibility, + EventTypes.Create, + EventTypes.Member, + EventTypes.PowerLevels, + EventTypes.JoinRules, + EventTypes.RoomHistoryVisibility, EventTypes.ThirdPartyInvite, ) @@ -54,6 +57,7 @@ class Auth(object): FIXME: This class contains a mix of functions for authenticating users of our client-server API and authenticating events added to room graphs. """ + def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() @@ -70,15 +74,12 @@ def __init__(self, hs): def check_from_context(self, room_version, event, context, do_sig_check=True): prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.compute_auth_events( - event, prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in itervalues(auth_events) - } + auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} self.check( - room_version, event, - auth_events=auth_events, do_sig_check=do_sig_check, + room_version, event, auth_events=auth_events, do_sig_check=do_sig_check ) def check(self, room_version, event, auth_events, do_sig_check=True): @@ -115,15 +116,10 @@ def check_joined_room(self, room_id, user_id, current_state=None): the room. """ if current_state: - member = current_state.get( - (EventTypes.Member, user_id), - None - ) + member = current_state.get((EventTypes.Member, user_id), None) else: member = yield self.state.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) self._check_joined_room(member, user_id, room_id) @@ -143,23 +139,17 @@ def check_user_was_in_room(self, room_id, user_id): the room. This will be the leave event if they have left the room. """ member = yield self.state.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) membership = member.membership if member else None if membership not in (Membership.JOIN, Membership.LEAVE): - raise AuthError(403, "User %s not in room %s" % ( - user_id, room_id - )) + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) if membership == Membership.LEAVE: forgot = yield self.store.did_forget(user_id, room_id) if forgot: - raise AuthError(403, "User %s not in room %s" % ( - user_id, room_id - )) + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) defer.returnValue(member) @@ -171,9 +161,9 @@ def check_host_in_room(self, room_id, host): def _check_joined_room(self, member, user_id, room_id): if not member or member.membership != Membership.JOIN: - raise AuthError(403, "User %s not in room %s (%s)" % ( - user_id, room_id, repr(member) - )) + raise AuthError( + 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member)) + ) def can_federate(self, event, auth_events): creation_event = auth_events.get((EventTypes.Create, "")) @@ -185,11 +175,7 @@ def get_public_keys(self, invite_event): @defer.inlineCallbacks def get_user_by_req( - self, - request, - allow_guest=False, - rights="access", - allow_expired=False, + self, request, allow_guest=False, rights="access", allow_expired=False ): """ Get a registered user's ID. @@ -209,9 +195,8 @@ def get_user_by_req( try: ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( - b"User-Agent", - default=[b""] - )[0].decode('ascii', 'surrogateescape') + b"User-Agent", default=[b""] + )[0].decode("ascii", "surrogateescape") access_token = self.get_access_token_from_request( request, self.TOKEN_NOT_FOUND_HTTP_STATUS @@ -243,11 +228,12 @@ def get_user_by_req( if self._account_validity.enabled and not allow_expired: user_id = user.to_string() expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) - if expiration_ts is not None and self.clock.time_msec() >= expiration_ts: + if ( + expiration_ts is not None + and self.clock.time_msec() >= expiration_ts + ): raise AuthError( - 403, - "User account has expired", - errcode=Codes.EXPIRED_ACCOUNT, + 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT ) # device_id may not be present if get_user_by_access_token has been @@ -265,18 +251,23 @@ def get_user_by_req( if is_guest and not allow_guest: raise AuthError( - 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + 403, + "Guest access not allowed", + errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) request.authenticated_entity = user.to_string() - defer.returnValue(synapse.types.create_requester( - user, token_id, is_guest, device_id, app_service=app_service) + defer.returnValue( + synapse.types.create_requester( + user, token_id, is_guest, device_id, app_service=app_service + ) ) except KeyError: raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", - errcode=Codes.MISSING_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Missing access token.", + errcode=Codes.MISSING_TOKEN, ) @defer.inlineCallbacks @@ -297,20 +288,14 @@ def _get_appservice_user_id(self, request): if b"user_id" not in request.args: defer.returnValue((app_service.sender, app_service)) - user_id = request.args[b"user_id"][0].decode('utf8') + user_id = request.args[b"user_id"][0].decode("utf8") if app_service.sender == user_id: defer.returnValue((app_service.sender, app_service)) if not app_service.is_interested_in_user(user_id): - raise AuthError( - 403, - "Application service cannot masquerade as this user." - ) + raise AuthError(403, "Application service cannot masquerade as this user.") if not (yield self.store.get_user_by_id(user_id)): - raise AuthError( - 403, - "Application service has not registered this user" - ) + raise AuthError(403, "Application service has not registered this user") defer.returnValue((user_id, app_service)) @defer.inlineCallbacks @@ -368,13 +353,13 @@ def get_user_by_access_token(self, token, rights="access"): raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unknown user_id %s" % user_id, - errcode=Codes.UNKNOWN_TOKEN + errcode=Codes.UNKNOWN_TOKEN, ) if not stored_user["is_guest"]: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Guest access token used for regular user", - errcode=Codes.UNKNOWN_TOKEN + errcode=Codes.UNKNOWN_TOKEN, ) ret = { "user": user, @@ -402,8 +387,9 @@ def get_user_by_access_token(self, token, rights="access"): ) as e: logger.warning("Invalid macaroon in auth: %s %s", type(e), e) raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", - errcode=Codes.UNKNOWN_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Invalid macaroon passed.", + errcode=Codes.UNKNOWN_TOKEN, ) def _parse_and_validate_macaroon(self, token, rights="access"): @@ -441,13 +427,13 @@ def _parse_and_validate_macaroon(self, token, rights="access"): guest = True self.validate_macaroon( - macaroon, rights, self.hs.config.expire_access_token, - user_id=user_id, + macaroon, rights, self.hs.config.expire_access_token, user_id=user_id ) except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", - errcode=Codes.UNKNOWN_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "Invalid macaroon passed.", + errcode=Codes.UNKNOWN_TOKEN, ) if not has_expiry and rights == "access": @@ -472,10 +458,11 @@ def get_user_id_from_macaroon(self, macaroon): user_prefix = "user_id = " for caveat in macaroon.caveats: if caveat.caveat_id.startswith(user_prefix): - return caveat.caveat_id[len(user_prefix):] + return caveat.caveat_id[len(user_prefix) :] raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", - errcode=Codes.UNKNOWN_TOKEN + self.TOKEN_NOT_FOUND_HTTP_STATUS, + "No user caveat in macaroon", + errcode=Codes.UNKNOWN_TOKEN, ) def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): @@ -522,7 +509,7 @@ def _verify_expiry(self, caveat): prefix = "time < " if not caveat.startswith(prefix): return False - expiry = int(caveat[len(prefix):]) + expiry = int(caveat[len(prefix) :]) now = self.hs.get_clock().time_msec() return now < expiry @@ -554,14 +541,12 @@ def get_appservice_by_req(self, request): raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", - errcode=Codes.UNKNOWN_TOKEN + errcode=Codes.UNKNOWN_TOKEN, ) request.authenticated_entity = service.sender return defer.succeed(service) except KeyError: - raise AuthError( - self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token." - ) + raise AuthError(self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.") def is_server_admin(self, user): """ Check if the given user is a local server admin. @@ -581,19 +566,19 @@ def compute_auth_events(self, event, current_state_ids, for_verification=False): auth_ids = [] - key = (EventTypes.PowerLevels, "", ) + key = (EventTypes.PowerLevels, "") power_level_event_id = current_state_ids.get(key) if power_level_event_id: auth_ids.append(power_level_event_id) - key = (EventTypes.JoinRules, "", ) + key = (EventTypes.JoinRules, "") join_rule_event_id = current_state_ids.get(key) - key = (EventTypes.Member, event.sender, ) + key = (EventTypes.Member, event.sender) member_event_id = current_state_ids.get(key) - key = (EventTypes.Create, "", ) + key = (EventTypes.Create, "") create_event_id = current_state_ids.get(key) if create_event_id: auth_ids.append(create_event_id) @@ -619,7 +604,7 @@ def compute_auth_events(self, event, current_state_ids, for_verification=False): auth_ids.append(member_event_id) if for_verification: - key = (EventTypes.Member, event.state_key, ) + key = (EventTypes.Member, event.state_key) existing_event_id = current_state_ids.get(key) if existing_event_id: auth_ids.append(existing_event_id) @@ -628,7 +613,7 @@ def compute_auth_events(self, event, current_state_ids, for_verification=False): if "third_party_invite" in event.content: key = ( EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"] + event.content["third_party_invite"]["signed"]["token"], ) third_party_invite_id = current_state_ids.get(key) if third_party_invite_id: @@ -684,7 +669,7 @@ def check_can_change_room_list(self, room_id, user): auth_events[(EventTypes.PowerLevels, "")] = power_level_event send_level = event_auth.get_send_level( - EventTypes.Aliases, "", power_level_event, + EventTypes.Aliases, "", power_level_event ) user_level = event_auth.get_user_power_level(user_id, auth_events) @@ -692,7 +677,7 @@ def check_can_change_room_list(self, room_id, user): raise AuthError( 403, "This server requires you to be a moderator in the room to" - " edit its room list entry" + " edit its room list entry", ) @staticmethod @@ -742,7 +727,7 @@ def get_access_token_from_request(request, token_not_found_http_status=401): ) parts = auth_headers[0].split(b" ") if parts[0] == b"Bearer" and len(parts) == 2: - return parts[1].decode('ascii') + return parts[1].decode("ascii") else: raise AuthError( token_not_found_http_status, @@ -755,10 +740,10 @@ def get_access_token_from_request(request, token_not_found_http_status=401): raise AuthError( token_not_found_http_status, "Missing access token.", - errcode=Codes.MISSING_TOKEN + errcode=Codes.MISSING_TOKEN, ) - return query_params[0].decode('ascii') + return query_params[0].decode("ascii") @defer.inlineCallbacks def check_in_room_or_world_readable(self, room_id, user_id): @@ -785,8 +770,8 @@ def check_in_room_or_world_readable(self, room_id, user_id): room_id, EventTypes.RoomHistoryVisibility, "" ) if ( - visibility and - visibility.content["history_visibility"] == "world_readable" + visibility + and visibility.content["history_visibility"] == "world_readable" ): defer.returnValue((Membership.JOIN, None)) return @@ -820,10 +805,11 @@ def check_auth_blocking(self, user_id=None, threepid=None): if self.hs.config.hs_disabled: raise ResourceLimitError( - 403, self.hs.config.hs_disabled_message, + 403, + self.hs.config.hs_disabled_message, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, admin_contact=self.hs.config.admin_contact, - limit_type=self.hs.config.hs_disabled_limit_type + limit_type=self.hs.config.hs_disabled_limit_type, ) if self.hs.config.limit_usage_by_mau is True: assert not (user_id and threepid) @@ -848,8 +834,9 @@ def check_auth_blocking(self, user_id=None, threepid=None): current_mau = yield self.store.get_monthly_active_count() if current_mau >= self.hs.config.max_mau_value: raise ResourceLimitError( - 403, "Monthly Active User Limit Exceeded", + 403, + "Monthly Active User Limit Exceeded", admin_contact=self.hs.config.admin_contact, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, - limit_type="monthly_active_user" + limit_type="monthly_active_user", ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index ee129c868991..3ffde0d7fc83 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -18,7 +18,7 @@ """Contains constants from the specification.""" # the "depth" field on events is limited to 2**63 - 1 -MAX_DEPTH = 2**63 - 1 +MAX_DEPTH = 2 ** 63 - 1 # the maximum length for a room alias is 255 characters MAX_ALIAS_LENGTH = 255 @@ -30,39 +30,41 @@ class Membership(object): """Represents the membership states of a user in a room.""" - INVITE = u"invite" - JOIN = u"join" - KNOCK = u"knock" - LEAVE = u"leave" - BAN = u"ban" + + INVITE = "invite" + JOIN = "join" + KNOCK = "knock" + LEAVE = "leave" + BAN = "ban" LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) class PresenceState(object): """Represents the presence state of a user.""" - OFFLINE = u"offline" - UNAVAILABLE = u"unavailable" - ONLINE = u"online" + + OFFLINE = "offline" + UNAVAILABLE = "unavailable" + ONLINE = "online" class JoinRules(object): - PUBLIC = u"public" - KNOCK = u"knock" - INVITE = u"invite" - PRIVATE = u"private" + PUBLIC = "public" + KNOCK = "knock" + INVITE = "invite" + PRIVATE = "private" class LoginType(object): - PASSWORD = u"m.login.password" - EMAIL_IDENTITY = u"m.login.email.identity" - MSISDN = u"m.login.msisdn" - RECAPTCHA = u"m.login.recaptcha" - TERMS = u"m.login.terms" - DUMMY = u"m.login.dummy" + PASSWORD = "m.login.password" + EMAIL_IDENTITY = "m.login.email.identity" + MSISDN = "m.login.msisdn" + RECAPTCHA = "m.login.recaptcha" + TERMS = "m.login.terms" + DUMMY = "m.login.dummy" # Only for C/S API v1 - APPLICATION_SERVICE = u"m.login.application_service" - SHARED_SECRET = u"org.matrix.login.shared_secret" + APPLICATION_SERVICE = "m.login.application_service" + SHARED_SECRET = "org.matrix.login.shared_secret" class EventTypes(object): @@ -118,6 +120,7 @@ class UserTypes(object): """Allows for user type specific behaviour. With the benefit of hindsight 'admin' and 'guest' users should also be UserTypes. Normal users are type None """ + SUPPORT = "support" ALL_USER_TYPES = (SUPPORT,) @@ -125,6 +128,7 @@ class UserTypes(object): class RelationTypes(object): """The types of relations known to this server. """ + ANNOTATION = "m.annotation" REPLACE = "m.replace" REFERENCE = "m.reference" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 66201d6efe41..28b5c2af9b39 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -70,6 +70,7 @@ class CodeMessageException(RuntimeError): code (int): HTTP error code msg (str): string describing the error """ + def __init__(self, code, msg): super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) self.code = code @@ -83,6 +84,7 @@ class SynapseError(CodeMessageException): Attributes: errcode (str): Matrix error code e.g 'M_FORBIDDEN' """ + def __init__(self, code, msg, errcode=Codes.UNKNOWN): """Constructs a synapse error. @@ -95,10 +97,7 @@ def __init__(self, code, msg, errcode=Codes.UNKNOWN): self.errcode = errcode def error_dict(self): - return cs_error( - self.msg, - self.errcode, - ) + return cs_error(self.msg, self.errcode) class ProxiedRequestError(SynapseError): @@ -107,27 +106,23 @@ class ProxiedRequestError(SynapseError): Attributes: errcode (str): Matrix error code e.g 'M_FORBIDDEN' """ + def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None): - super(ProxiedRequestError, self).__init__( - code, msg, errcode - ) + super(ProxiedRequestError, self).__init__(code, msg, errcode) if additional_fields is None: self._additional_fields = {} else: self._additional_fields = dict(additional_fields) def error_dict(self): - return cs_error( - self.msg, - self.errcode, - **self._additional_fields - ) + return cs_error(self.msg, self.errcode, **self._additional_fields) class ConsentNotGivenError(SynapseError): """The error returned to the client when the user has not consented to the privacy policy. """ + def __init__(self, msg, consent_uri): """Constructs a ConsentNotGivenError @@ -136,22 +131,17 @@ def __init__(self, msg, consent_uri): consent_url (str): The URL where the user can give their consent """ super(ConsentNotGivenError, self).__init__( - code=http_client.FORBIDDEN, - msg=msg, - errcode=Codes.CONSENT_NOT_GIVEN + code=http_client.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN ) self._consent_uri = consent_uri def error_dict(self): - return cs_error( - self.msg, - self.errcode, - consent_uri=self._consent_uri - ) + return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri) class RegistrationError(SynapseError): """An error raised when a registration event fails.""" + pass @@ -190,15 +180,17 @@ class InteractiveAuthIncompleteError(Exception): result (dict): the server response to the request, which should be passed back to the client """ + def __init__(self, result): super(InteractiveAuthIncompleteError, self).__init__( - "Interactive auth not yet complete", + "Interactive auth not yet complete" ) self.result = result class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" + def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.UNRECOGNIZED @@ -207,21 +199,14 @@ def __init__(self, *args, **kwargs): message = "Unrecognized request" else: message = args[0] - super(UnrecognizedRequestError, self).__init__( - 400, - message, - **kwargs - ) + super(UnrecognizedRequestError, self).__init__(400, message, **kwargs) class NotFoundError(SynapseError): """An error indicating we can't find the thing you asked for""" + def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND): - super(NotFoundError, self).__init__( - 404, - msg, - errcode=errcode - ) + super(NotFoundError, self).__init__(404, msg, errcode=errcode) class AuthError(SynapseError): @@ -238,8 +223,11 @@ class ResourceLimitError(SynapseError): Any error raised when there is a problem with resource usage. For instance, the monthly active user limit for the server has been exceeded """ + def __init__( - self, code, msg, + self, + code, + msg, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, admin_contact=None, limit_type=None, @@ -253,7 +241,7 @@ def error_dict(self): self.msg, self.errcode, admin_contact=self.admin_contact, - limit_type=self.limit_type + limit_type=self.limit_type, ) @@ -268,6 +256,7 @@ def __init__(self, *args, **kwargs): class EventStreamError(SynapseError): """An error raised when there a problem with the event stream.""" + def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.BAD_PAGINATION @@ -276,47 +265,53 @@ def __init__(self, *args, **kwargs): class LoginError(SynapseError): """An error raised when there was a problem logging in.""" + pass class StoreError(SynapseError): """An error raised when there was a problem storing some data.""" + pass class InvalidCaptchaError(SynapseError): - def __init__(self, code=400, msg="Invalid captcha.", error_url=None, - errcode=Codes.CAPTCHA_INVALID): + def __init__( + self, + code=400, + msg="Invalid captcha.", + error_url=None, + errcode=Codes.CAPTCHA_INVALID, + ): super(InvalidCaptchaError, self).__init__(code, msg, errcode) self.error_url = error_url def error_dict(self): - return cs_error( - self.msg, - self.errcode, - error_url=self.error_url, - ) + return cs_error(self.msg, self.errcode, error_url=self.error_url) class LimitExceededError(SynapseError): """A client has sent too many requests and is being throttled. """ - def __init__(self, code=429, msg="Too Many Requests", retry_after_ms=None, - errcode=Codes.LIMIT_EXCEEDED): + + def __init__( + self, + code=429, + msg="Too Many Requests", + retry_after_ms=None, + errcode=Codes.LIMIT_EXCEEDED, + ): super(LimitExceededError, self).__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms def error_dict(self): - return cs_error( - self.msg, - self.errcode, - retry_after_ms=self.retry_after_ms, - ) + return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) class RoomKeysVersionError(SynapseError): """A client has tried to upload to a non-current version of the room_keys store """ + def __init__(self, current_version): """ Args: @@ -331,6 +326,7 @@ def __init__(self, current_version): class UnsupportedRoomVersionError(SynapseError): """The client's request to create a room used a room version that the server does not support.""" + def __init__(self): super(UnsupportedRoomVersionError, self).__init__( code=400, @@ -354,22 +350,19 @@ class IncompatibleRoomVersionError(SynapseError): Unlike UnsupportedRoomVersionError, it is specific to the case of the make_join failing. """ + def __init__(self, room_version): super(IncompatibleRoomVersionError, self).__init__( code=400, msg="Your homeserver does not support the features required to " - "join this room", + "join this room", errcode=Codes.INCOMPATIBLE_ROOM_VERSION, ) self._room_version = room_version def error_dict(self): - return cs_error( - self.msg, - self.errcode, - room_version=self._room_version, - ) + return cs_error(self.msg, self.errcode, room_version=self._room_version) class RequestSendFailed(RuntimeError): @@ -380,11 +373,11 @@ class RequestSendFailed(RuntimeError): networking (e.g. DNS failures, connection timeouts etc), versus unexpected errors (like programming errors). """ + def __init__(self, inner_exception, can_retry): super(RequestSendFailed, self).__init__( - "Failed to send request: %s: %s" % ( - type(inner_exception).__name__, inner_exception, - ) + "Failed to send request: %s: %s" + % (type(inner_exception).__name__, inner_exception) ) self.inner_exception = inner_exception self.can_retry = can_retry @@ -428,7 +421,7 @@ def __init__(self, level, code, reason, affected, source=None): self.affected = affected self.source = source - msg = "%s %s: %s" % (level, code, reason,) + msg = "%s %s: %s" % (level, code, reason) super(FederationError, self).__init__(msg) def get_dict(self): @@ -448,6 +441,7 @@ class HttpResponseException(CodeMessageException): Attributes: response (bytes): body of response """ + def __init__(self, code, msg, response): """ @@ -486,7 +480,7 @@ def to_synapse_error(self): if not isinstance(j, dict): j = {} - errcode = j.pop('errcode', Codes.UNKNOWN) - errmsg = j.pop('error', self.msg) + errcode = j.pop("errcode", Codes.UNKNOWN) + errmsg = j.pop("error", self.msg) return ProxiedRequestError(self.code, errmsg, errcode, j) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 3906475403ac..9b3daca29bbf 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -28,117 +28,55 @@ "additionalProperties": False, "type": "object", "properties": { - "limit": { - "type": "number" - }, - "senders": { - "$ref": "#/definitions/user_id_array" - }, - "not_senders": { - "$ref": "#/definitions/user_id_array" - }, + "limit": {"type": "number"}, + "senders": {"$ref": "#/definitions/user_id_array"}, + "not_senders": {"$ref": "#/definitions/user_id_array"}, # TODO: We don't limit event type values but we probably should... # check types are valid event types - "types": { - "type": "array", - "items": { - "type": "string" - } - }, - "not_types": { - "type": "array", - "items": { - "type": "string" - } - } - } + "types": {"type": "array", "items": {"type": "string"}}, + "not_types": {"type": "array", "items": {"type": "string"}}, + }, } ROOM_FILTER_SCHEMA = { "additionalProperties": False, "type": "object", "properties": { - "not_rooms": { - "$ref": "#/definitions/room_id_array" - }, - "rooms": { - "$ref": "#/definitions/room_id_array" - }, - "ephemeral": { - "$ref": "#/definitions/room_event_filter" - }, - "include_leave": { - "type": "boolean" - }, - "state": { - "$ref": "#/definitions/room_event_filter" - }, - "timeline": { - "$ref": "#/definitions/room_event_filter" - }, - "account_data": { - "$ref": "#/definitions/room_event_filter" - }, - } + "not_rooms": {"$ref": "#/definitions/room_id_array"}, + "rooms": {"$ref": "#/definitions/room_id_array"}, + "ephemeral": {"$ref": "#/definitions/room_event_filter"}, + "include_leave": {"type": "boolean"}, + "state": {"$ref": "#/definitions/room_event_filter"}, + "timeline": {"$ref": "#/definitions/room_event_filter"}, + "account_data": {"$ref": "#/definitions/room_event_filter"}, + }, } ROOM_EVENT_FILTER_SCHEMA = { "additionalProperties": False, "type": "object", "properties": { - "limit": { - "type": "number" - }, - "senders": { - "$ref": "#/definitions/user_id_array" - }, - "not_senders": { - "$ref": "#/definitions/user_id_array" - }, - "types": { - "type": "array", - "items": { - "type": "string" - } - }, - "not_types": { - "type": "array", - "items": { - "type": "string" - } - }, - "rooms": { - "$ref": "#/definitions/room_id_array" - }, - "not_rooms": { - "$ref": "#/definitions/room_id_array" - }, - "contains_url": { - "type": "boolean" - }, - "lazy_load_members": { - "type": "boolean" - }, - "include_redundant_members": { - "type": "boolean" - }, - } + "limit": {"type": "number"}, + "senders": {"$ref": "#/definitions/user_id_array"}, + "not_senders": {"$ref": "#/definitions/user_id_array"}, + "types": {"type": "array", "items": {"type": "string"}}, + "not_types": {"type": "array", "items": {"type": "string"}}, + "rooms": {"$ref": "#/definitions/room_id_array"}, + "not_rooms": {"$ref": "#/definitions/room_id_array"}, + "contains_url": {"type": "boolean"}, + "lazy_load_members": {"type": "boolean"}, + "include_redundant_members": {"type": "boolean"}, + }, } USER_ID_ARRAY_SCHEMA = { "type": "array", - "items": { - "type": "string", - "format": "matrix_user_id" - } + "items": {"type": "string", "format": "matrix_user_id"}, } ROOM_ID_ARRAY_SCHEMA = { "type": "array", - "items": { - "type": "string", - "format": "matrix_room_id" - } + "items": {"type": "string", "format": "matrix_room_id"}, } USER_FILTER_SCHEMA = { @@ -150,22 +88,13 @@ "user_id_array": USER_ID_ARRAY_SCHEMA, "filter": FILTER_SCHEMA, "room_filter": ROOM_FILTER_SCHEMA, - "room_event_filter": ROOM_EVENT_FILTER_SCHEMA + "room_event_filter": ROOM_EVENT_FILTER_SCHEMA, }, "properties": { - "presence": { - "$ref": "#/definitions/filter" - }, - "account_data": { - "$ref": "#/definitions/filter" - }, - "room": { - "$ref": "#/definitions/room_filter" - }, - "event_format": { - "type": "string", - "enum": ["client", "federation"] - }, + "presence": {"$ref": "#/definitions/filter"}, + "account_data": {"$ref": "#/definitions/filter"}, + "room": {"$ref": "#/definitions/room_filter"}, + "event_format": {"type": "string", "enum": ["client", "federation"]}, "event_fields": { "type": "array", "items": { @@ -177,26 +106,25 @@ # # Note that because this is a regular expression, we have to escape # each backslash in the pattern. - "pattern": r"^((?!\\\\).)*$" - } - } + "pattern": r"^((?!\\\\).)*$", + }, + }, }, - "additionalProperties": False + "additionalProperties": False, } -@FormatChecker.cls_checks('matrix_room_id') +@FormatChecker.cls_checks("matrix_room_id") def matrix_room_id_validator(room_id_str): return RoomID.from_string(room_id_str) -@FormatChecker.cls_checks('matrix_user_id') +@FormatChecker.cls_checks("matrix_user_id") def matrix_user_id_validator(user_id_str): return UserID.from_string(user_id_str) class Filtering(object): - def __init__(self, hs): super(Filtering, self).__init__() self.store = hs.get_datastore() @@ -228,8 +156,9 @@ def check_valid_filter(self, user_filter_json): # individual top-level key e.g. public_user_data. Filters are made of # many definitions. try: - jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA, - format_checker=FormatChecker()) + jsonschema.validate( + user_filter_json, USER_FILTER_SCHEMA, format_checker=FormatChecker() + ) except jsonschema.ValidationError as e: raise SynapseError(400, str(e)) @@ -240,10 +169,9 @@ def __init__(self, filter_json): room_filter_json = self._filter_json.get("room", {}) - self._room_filter = Filter({ - k: v for k, v in room_filter_json.items() - if k in ("rooms", "not_rooms") - }) + self._room_filter = Filter( + {k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")} + ) self._room_timeline_filter = Filter(room_filter_json.get("timeline", {})) self._room_state_filter = Filter(room_filter_json.get("state", {})) @@ -252,9 +180,7 @@ def __init__(self, filter_json): self._presence_filter = Filter(filter_json.get("presence", {})) self._account_data = Filter(filter_json.get("account_data", {})) - self.include_leave = filter_json.get("room", {}).get( - "include_leave", False - ) + self.include_leave = filter_json.get("room", {}).get("include_leave", False) self.event_fields = filter_json.get("event_fields", []) self.event_format = filter_json.get("event_format", "client") @@ -299,22 +225,22 @@ def filter_room_account_data(self, events): def blocks_all_presence(self): return ( - self._presence_filter.filters_all_types() or - self._presence_filter.filters_all_senders() + self._presence_filter.filters_all_types() + or self._presence_filter.filters_all_senders() ) def blocks_all_room_ephemeral(self): return ( - self._room_ephemeral_filter.filters_all_types() or - self._room_ephemeral_filter.filters_all_senders() or - self._room_ephemeral_filter.filters_all_rooms() + self._room_ephemeral_filter.filters_all_types() + or self._room_ephemeral_filter.filters_all_senders() + or self._room_ephemeral_filter.filters_all_rooms() ) def blocks_all_room_timeline(self): return ( - self._room_timeline_filter.filters_all_types() or - self._room_timeline_filter.filters_all_senders() or - self._room_timeline_filter.filters_all_rooms() + self._room_timeline_filter.filters_all_types() + or self._room_timeline_filter.filters_all_senders() + or self._room_timeline_filter.filters_all_rooms() ) @@ -375,12 +301,7 @@ def check(self, event): # check if there is a string url field in the content for filtering purposes contains_url = isinstance(content.get("url"), text_type) - return self.check_fields( - room_id, - sender, - ev_type, - contains_url, - ) + return self.check_fields(room_id, sender, ev_type, contains_url) def check_fields(self, room_id, sender, event_type, contains_url): """Checks whether the filter matches the given event fields. @@ -391,7 +312,7 @@ def check_fields(self, room_id, sender, event_type, contains_url): literal_keys = { "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, - "types": lambda v: _matches_wildcard(event_type, v) + "types": lambda v: _matches_wildcard(event_type, v), } for name, match_func in literal_keys.items(): diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 296c4a1c1749..172841f59530 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -44,29 +44,25 @@ def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True): """ self.prune_message_counts(time_now_s) message_count, time_start, _ignored = self.message_counts.get( - key, (0., time_now_s, None), + key, (0.0, time_now_s, None) ) time_delta = time_now_s - time_start sent_count = message_count - time_delta * rate_hz if sent_count < 0: allowed = True time_start = time_now_s - message_count = 1. - elif sent_count > burst_count - 1.: + message_count = 1.0 + elif sent_count > burst_count - 1.0: allowed = False else: allowed = True message_count += 1 if update: - self.message_counts[key] = ( - message_count, time_start, rate_hz - ) + self.message_counts[key] = (message_count, time_start, rate_hz) if rate_hz > 0: - time_allowed = ( - time_start + (message_count - burst_count + 1) / rate_hz - ) + time_allowed = time_start + (message_count - burst_count + 1) / rate_hz if time_allowed < time_now_s: time_allowed = time_now_s else: @@ -76,9 +72,7 @@ def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True): def prune_message_counts(self, time_now_s): for key in list(self.message_counts.keys()): - message_count, time_start, rate_hz = ( - self.message_counts[key] - ) + message_count, time_start, rate_hz = self.message_counts[key] time_delta = time_now_s - time_start if message_count - time_delta * rate_hz > 0: break @@ -92,5 +86,5 @@ def ratelimit(self, key, time_now_s, rate_hz, burst_count, update=True): if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now_s)), + retry_after_ms=int(1000 * (time_allowed - time_now_s)) ) diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index d644803d3843..95292b7dec0d 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -19,9 +19,10 @@ class EventFormatVersions(object): """This is an internal enum for tracking the version of the event format, independently from the room version. """ - V1 = 1 # $id:server event id format - V2 = 2 # MSC1659-style $hash event id format: introduced for room v3 - V3 = 3 # MSC1884-style $hash format: introduced for room v4 + + V1 = 1 # $id:server event id format + V2 = 2 # MSC1659-style $hash event id format: introduced for room v3 + V3 = 3 # MSC1884-style $hash format: introduced for room v4 KNOWN_EVENT_FORMAT_VERSIONS = { @@ -33,8 +34,9 @@ class EventFormatVersions(object): class StateResolutionVersions(object): """Enum to identify the state resolution algorithms""" - V1 = 1 # room v1 state res - V2 = 2 # MSC1442 state res: room v2 and later + + V1 = 1 # room v1 state res + V2 = 2 # MSC1442 state res: room v2 and later class RoomDisposition(object): @@ -46,10 +48,10 @@ class RoomDisposition(object): class RoomVersion(object): """An object which describes the unique attributes of a room version.""" - identifier = attr.ib() # str; the identifier for this version - disposition = attr.ib() # str; one of the RoomDispositions - event_format = attr.ib() # int; one of the EventFormatVersions - state_res = attr.ib() # int; one of the StateResolutionVersions + identifier = attr.ib() # str; the identifier for this version + disposition = attr.ib() # str; one of the RoomDispositions + event_format = attr.ib() # int; one of the EventFormatVersions + state_res = attr.ib() # int; one of the StateResolutionVersions enforce_key_validity = attr.ib() # bool @@ -92,11 +94,12 @@ class RoomVersions(object): KNOWN_ROOM_VERSIONS = { - v.identifier: v for v in ( + v.identifier: v + for v in ( RoomVersions.V1, RoomVersions.V2, RoomVersions.V3, RoomVersions.V4, RoomVersions.V5, ) -} # type: dict[str, RoomVersion] +} # type: dict[str, RoomVersion] diff --git a/synapse/api/urls.py b/synapse/api/urls.py index e16c386a14d2..ff1f39e86ccb 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -42,13 +42,9 @@ def __init__(self, hs_config): hs_config (synapse.config.homeserver.HomeServerConfig): """ if hs_config.form_secret is None: - raise ConfigError( - "form_secret not set in config", - ) + raise ConfigError("form_secret not set in config") if hs_config.public_baseurl is None: - raise ConfigError( - "public_baseurl not set in config", - ) + raise ConfigError("public_baseurl not set in config") self._hmac_secret = hs_config.form_secret.encode("utf-8") self._public_baseurl = hs_config.public_baseurl @@ -64,15 +60,10 @@ def build_user_consent_uri(self, user_id): (str) the URI where the user can do consent """ mac = hmac.new( - key=self._hmac_secret, - msg=user_id.encode('ascii'), - digestmod=sha256, + key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256 ).hexdigest() consent_uri = "%s_matrix/consent?%s" % ( self._public_baseurl, - urlencode({ - "u": user_id, - "h": mac - }), + urlencode({"u": user_id, "h": mac}), ) return consent_uri diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index f56f5fcc13e1..d877c77834de 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -43,7 +43,7 @@ def check_bind_error(e, address, bind_addresses): address (str): Address on which binding was attempted. bind_addresses (list): Addresses on which the service listens. """ - if address == '0.0.0.0' and '::' in bind_addresses: - logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]') + if address == "0.0.0.0" and "::" in bind_addresses: + logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]") else: raise e diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 8cc990399f8f..df4c2d4c971f 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -75,14 +75,14 @@ def start_worker_reactor(appname, config): def start_reactor( - appname, - soft_file_limit, - gc_thresholds, - pid_file, - daemonize, - cpu_affinity, - print_pidfile, - logger, + appname, + soft_file_limit, + gc_thresholds, + pid_file, + daemonize, + cpu_affinity, + print_pidfile, + logger, ): """ Run the reactor in the main process @@ -149,10 +149,10 @@ def run(): def quit_with_error(error_string): message_lines = error_string.split("\n") line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2 - sys.stderr.write("*" * line_length + '\n') + sys.stderr.write("*" * line_length + "\n") for line in message_lines: sys.stderr.write(" %s\n" % (line.rstrip(),)) - sys.stderr.write("*" * line_length + '\n') + sys.stderr.write("*" * line_length + "\n") sys.exit(1) @@ -178,14 +178,7 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50): r = [] for address in bind_addresses: try: - r.append( - reactor.listenTCP( - port, - factory, - backlog, - address - ) - ) + r.append(reactor.listenTCP(port, factory, backlog, address)) except error.CannotListenError as e: check_bind_error(e, address, bind_addresses) @@ -205,13 +198,7 @@ def listen_ssl( for address in bind_addresses: try: r.append( - reactor.listenSSL( - port, - factory, - context_factory, - backlog, - address - ) + reactor.listenSSL(port, factory, context_factory, backlog, address) ) except error.CannotListenError as e: check_bind_error(e, address, bind_addresses) @@ -243,15 +230,13 @@ def refresh_certificate(hs): if isinstance(i.factory, TLSMemoryBIOFactory): addr = i.getHost() logger.info( - "Replacing TLS context factory on [%s]:%i", addr.host, addr.port, + "Replacing TLS context factory on [%s]:%i", addr.host, addr.port ) # We want to replace TLS factories with a new one, with the new # TLS configuration. We do this by reaching in and pulling out # the wrappedFactory, and then re-wrapping it. i.factory = TLSMemoryBIOFactory( - hs.tls_server_context_factory, - False, - i.factory.wrappedFactory + hs.tls_server_context_factory, False, i.factory.wrappedFactory ) logger.info("Context factories updated.") @@ -267,6 +252,7 @@ def start(hs, listeners=None): try: # Set up the SIGHUP machinery. if hasattr(signal, "SIGHUP"): + def handle_sighup(*args, **kwargs): for i in _sighup_callbacks: i(hs) @@ -302,10 +288,8 @@ def setup_sentry(hs): return import sentry_sdk - sentry_sdk.init( - dsn=hs.config.sentry_dsn, - release=get_version_string(synapse), - ) + + sentry_sdk.init(dsn=hs.config.sentry_dsn, release=get_version_string(synapse)) # We set some default tags that give some context to this instance with sentry_sdk.configure_scope() as scope: @@ -326,7 +310,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100): many DNS queries at once """ new_resolver = _LimitedHostnameResolver( - reactor.nameResolver, max_dns_requests_in_flight, + reactor.nameResolver, max_dns_requests_in_flight ) reactor.installNameResolver(new_resolver) @@ -339,11 +323,17 @@ class _LimitedHostnameResolver(object): def __init__(self, resolver, max_dns_requests_in_flight): self._resolver = resolver self._limiter = Linearizer( - name="dns_client_limiter", max_count=max_dns_requests_in_flight, + name="dns_client_limiter", max_count=max_dns_requests_in_flight ) - def resolveHostName(self, resolutionReceiver, hostName, portNumber=0, - addressTypes=None, transportSemantics='TCP'): + def resolveHostName( + self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics="TCP", + ): # We need this function to return `resolutionReceiver` so we do all the # actual logic involving deferreds in a separate function. @@ -363,8 +353,14 @@ def resolveHostName(self, resolutionReceiver, hostName, portNumber=0, return resolutionReceiver @defer.inlineCallbacks - def _resolve(self, resolutionReceiver, hostName, portNumber=0, - addressTypes=None, transportSemantics='TCP'): + def _resolve( + self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics="TCP", + ): with (yield self._limiter.queue(())): # resolveHostName doesn't return a Deferred, so we need to hook into @@ -374,8 +370,7 @@ def _resolve(self, resolutionReceiver, hostName, portNumber=0, receiver = _DeferredResolutionReceiver(resolutionReceiver, deferred) self._resolver.resolveHostName( - receiver, hostName, portNumber, - addressTypes, transportSemantics, + receiver, hostName, portNumber, addressTypes, transportSemantics ) yield deferred diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 33107f56d137..9120bdb1438c 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -44,7 +44,9 @@ class AppserviceSlaveStore( - DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore, + DirectoryStore, + SlavedEventStore, + SlavedApplicationServiceStore, SlavedRegistrationStore, ): pass @@ -74,7 +76,7 @@ def _listen_http(self, listener_config): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse appservice now listening on port %d", port) @@ -88,18 +90,19 @@ def start_listening(self, listeners): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -132,9 +135,7 @@ def _notify_app_services(self, room_stream_id): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse appservice", config_options - ) + config = HomeServerConfig.load_config("Synapse appservice", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -173,6 +174,6 @@ def start(config_options): _base.start_worker_reactor("synapse-appservice", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index a16e037f3227..310cdab2e4bc 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -118,9 +118,7 @@ def _listen_http(self, listener_config): PushRuleRestServlet(self).register(resource) VersionsRestServlet().register(resource) - resources.update({ - "/_matrix/client": resource, - }) + resources.update({"/_matrix/client": resource}) root_resource = create_resource_tree(resources, NoResource()) @@ -133,7 +131,7 @@ def _listen_http(self, listener_config): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse client reader now listening on port %d", port) @@ -147,18 +145,19 @@ def start_listening(self, listeners): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -170,9 +169,7 @@ def build_tcp_replication(self): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse client reader", config_options - ) + config = HomeServerConfig.load_config("Synapse client reader", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -199,6 +196,6 @@ def start(config_options): _base.start_worker_reactor("synapse-client-reader", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index b8e51961528b..ff522e4499f3 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -109,12 +109,14 @@ def _listen_http(self, listener_config): ProfileAvatarURLRestServlet(self).register(resource) ProfileDisplaynameRestServlet(self).register(resource) ProfileRestServlet(self).register(resource) - resources.update({ - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - }) + resources.update( + { + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -127,7 +129,7 @@ def _listen_http(self, listener_config): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse event creator now listening on port %d", port) @@ -141,18 +143,19 @@ def start_listening(self, listeners): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -164,9 +167,7 @@ def build_tcp_replication(self): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse event creator", config_options - ) + config = HomeServerConfig.load_config("Synapse event creator", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -198,6 +199,6 @@ def start(config_options): _base.start_worker_reactor("synapse-event-creator", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 7da79dc82768..94214209301a 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -86,19 +86,18 @@ def _listen_http(self, listener_config): if name == "metrics": resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "federation": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self), - }) + resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) if name == "openid" and "federation" not in res["names"]: # Only load the openid resource separately if federation resource # is not specified since federation resource includes openid # resource. - resources.update({ - FEDERATION_PREFIX: TransportLayerServer( - self, - servlet_groups=["openid"], - ), - }) + resources.update( + { + FEDERATION_PREFIX: TransportLayerServer( + self, servlet_groups=["openid"] + ) + } + ) if name in ["keys", "federation"]: resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) @@ -115,7 +114,7 @@ def _listen_http(self, listener_config): root_resource, self.version_string, ), - reactor=self.get_reactor() + reactor=self.get_reactor(), ) logger.info("Synapse federation reader now listening on port %d", port) @@ -129,18 +128,19 @@ def start_listening(self, listeners): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -181,6 +181,6 @@ def start(config_options): _base.start_worker_reactor("synapse-federation-reader", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 1d43f2b0755b..969be58d0b04 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -52,8 +52,13 @@ class FederationSenderSlaveStore( - SlavedDeviceInboxStore, SlavedTransactionStore, SlavedReceiptsStore, SlavedEventStore, - SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore, + SlavedDeviceInboxStore, + SlavedTransactionStore, + SlavedReceiptsStore, + SlavedEventStore, + SlavedRegistrationStore, + SlavedDeviceStore, + SlavedPresenceStore, ): def __init__(self, db_conn, hs): super(FederationSenderSlaveStore, self).__init__(db_conn, hs) @@ -65,10 +70,7 @@ def __init__(self, db_conn, hs): self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) def _get_federation_out_pos(self, db_conn): - sql = ( - "SELECT stream_id FROM federation_stream_position" - " WHERE type = ?" - ) + sql = "SELECT stream_id FROM federation_stream_position" " WHERE type = ?" sql = self.database_engine.convert_param_style(sql) txn = db_conn.cursor() @@ -103,7 +105,7 @@ def _listen_http(self, listener_config): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse federation_sender now listening on port %d", port) @@ -117,18 +119,19 @@ def start_listening(self, listeners): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -151,7 +154,9 @@ def on_rdata(self, stream_name, token, rows): self.send_handler.process_replication_rows(stream_name, token, rows) def get_streams_to_replicate(self): - args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate() + args = super( + FederationSenderReplicationHandler, self + ).get_streams_to_replicate() args.update(self.send_handler.stream_positions()) return args @@ -203,6 +208,7 @@ class FederationSenderHandler(object): """Processes the replication stream and forwards the appropriate entries to the federation sender. """ + def __init__(self, hs, replication_client): self.store = hs.get_datastore() self._is_mine_id = hs.is_mine_id @@ -241,7 +247,7 @@ def process_replication_rows(self, stream_name, token, rows): # ... and when new receipts happen elif stream_name == ReceiptsStream.NAME: run_as_background_process( - "process_receipts_for_federation", self._on_new_receipts, rows, + "process_receipts_for_federation", self._on_new_receipts, rows ) @defer.inlineCallbacks @@ -278,12 +284,14 @@ def update_token(self, token): # We ACK this token over replication so that the master can drop # its in memory queues - self.replication_client.send_federation_ack(self.federation_position) + self.replication_client.send_federation_ack( + self.federation_position + ) self._last_ack = self.federation_position except Exception: logger.exception("Error updating federation stream position") -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 6504da527825..2fd7d57ebf1c 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -62,14 +62,11 @@ def on_GET(self, request, user_id): # Pass through the auth headers, if any, in case the access token # is there. auth_headers = request.requestHeaders.getRawHeaders("Authorization", []) - headers = { - "Authorization": auth_headers, - } + headers = {"Authorization": auth_headers} try: result = yield self.http_client.get_json( - self.main_uri + request.uri.decode('ascii'), - headers=headers, + self.main_uri + request.uri.decode("ascii"), headers=headers ) except HttpResponseException as e: raise e.to_synapse_error() @@ -105,18 +102,19 @@ def on_POST(self, request, device_id): if device_id is not None: # passing the device_id here is deprecated; however, we allow it # for now for compatibility with older clients. - if (requester.device_id is not None and - device_id != requester.device_id): - logger.warning("Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, device_id) + if requester.device_id is not None and device_id != requester.device_id: + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) else: device_id = requester.device_id if device_id is None: raise SynapseError( - 400, - "To upload keys, you must pass device_id when authenticating" + 400, "To upload keys, you must pass device_id when authenticating" ) if body: @@ -124,13 +122,9 @@ def on_POST(self, request, device_id): # Pass through the auth headers, if any, in case the access token # is there. auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", []) - headers = { - "Authorization": auth_headers, - } + headers = {"Authorization": auth_headers} result = yield self.http_client.post_json_get_json( - self.main_uri + request.uri.decode('ascii'), - body, - headers=headers, + self.main_uri + request.uri.decode("ascii"), body, headers=headers ) defer.returnValue((200, result)) @@ -171,12 +165,14 @@ def _listen_http(self, listener_config): if not self.config.use_presence: PresenceStatusStubServlet(self).register(resource) - resources.update({ - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - }) + resources.update( + { + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -190,7 +186,7 @@ def _listen_http(self, listener_config): root_resource, self.version_string, ), - reactor=self.get_reactor() + reactor=self.get_reactor(), ) logger.info("Synapse client reader now listening on port %d", port) @@ -204,18 +200,19 @@ def start_listening(self, listeners): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -227,9 +224,7 @@ def build_tcp_replication(self): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse frontend proxy", config_options - ) + config = HomeServerConfig.load_config("Synapse frontend proxy", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -258,6 +253,6 @@ def start(config_options): _base.start_worker_reactor("synapse-frontend-proxy", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b27b12e73d98..d19c7c7d7108 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -101,13 +101,12 @@ def _listener_http(self, config, listener_config): # Skip loading openid resource if federation is defined # since federation resource will include openid continue - resources.update(self._configure_named_resource( - name, res.get("compress", False), - )) + resources.update( + self._configure_named_resource(name, res.get("compress", False)) + ) additional_resources = listener_config.get("additional_resources", {}) - logger.debug("Configuring additional resources: %r", - additional_resources) + logger.debug("Configuring additional resources: %r", additional_resources) module_api = ModuleApi(self, self.get_auth_handler()) for path, resmodule in additional_resources.items(): handler_cls, config = load_module(resmodule) @@ -174,59 +173,67 @@ def _configure_named_resource(self, name, compress=False): if compress: client_resource = gz_wrap(client_resource) - resources.update({ - "/_matrix/client/api/v1": client_resource, - "/_matrix/client/r0": client_resource, - "/_matrix/client/unstable": client_resource, - "/_matrix/client/v2_alpha": client_resource, - "/_matrix/client/versions": client_resource, - "/.well-known/matrix/client": WellKnownResource(self), - "/_synapse/admin": AdminRestResource(self), - }) + resources.update( + { + "/_matrix/client/api/v1": client_resource, + "/_matrix/client/r0": client_resource, + "/_matrix/client/unstable": client_resource, + "/_matrix/client/v2_alpha": client_resource, + "/_matrix/client/versions": client_resource, + "/.well-known/matrix/client": WellKnownResource(self), + "/_synapse/admin": AdminRestResource(self), + } + ) if self.get_config().saml2_enabled: from synapse.rest.saml2 import SAML2Resource + resources["/_matrix/saml2"] = SAML2Resource(self) if name == "consent": from synapse.rest.consent.consent_resource import ConsentResource + consent_resource = ConsentResource(self) if compress: consent_resource = gz_wrap(consent_resource) - resources.update({ - "/_matrix/consent": consent_resource, - }) + resources.update({"/_matrix/consent": consent_resource}) if name == "federation": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self), - }) + resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) if name == "openid": - resources.update({ - FEDERATION_PREFIX: TransportLayerServer(self, servlet_groups=["openid"]), - }) + resources.update( + { + FEDERATION_PREFIX: TransportLayerServer( + self, servlet_groups=["openid"] + ) + } + ) if name in ["static", "client"]: - resources.update({ - STATIC_PREFIX: File( - os.path.join(os.path.dirname(synapse.__file__), "static") - ), - }) + resources.update( + { + STATIC_PREFIX: File( + os.path.join(os.path.dirname(synapse.__file__), "static") + ) + } + ) if name in ["media", "federation", "client"]: if self.get_config().enable_media_repo: media_repo = self.get_media_repository_resource() - resources.update({ - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - }) + resources.update( + { + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + } + ) elif name == "media": raise ConfigError( - "'media' resource conflicts with enable_media_repo=False", + "'media' resource conflicts with enable_media_repo=False" ) if name in ["keys", "federation"]: @@ -257,18 +264,14 @@ def start_listening(self, listeners): for listener in listeners: if listener["type"] == "http": - self._listening_services.extend( - self._listener_http(config, listener) - ) + self._listening_services.extend(self._listener_http(config, listener)) elif listener["type"] == "manhole": listen_tcp( listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "replication": services = listen_tcp( @@ -277,16 +280,17 @@ def start_listening(self, listeners): ReplicationStreamProtocolFactory(self), ) for s in services: - reactor.addSystemEventTrigger( - "before", "shutdown", s.stopListening, - ) + reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -312,7 +316,7 @@ def run_startup_checks(self, db_conn, database_engine): max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") registered_reserved_users_mau_gauge = Gauge( "synapse_admin_mau:registered_reserved_users", - "Registered users with reserved threepids" + "Registered users with reserved threepids", ) @@ -327,8 +331,7 @@ def setup(config_options): """ try: config = HomeServerConfig.load_or_generate_config( - "Synapse Homeserver", - config_options, + "Synapse Homeserver", config_options ) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") @@ -339,10 +342,7 @@ def setup(config_options): # generating config files and shouldn't try to continue. sys.exit(0) - synapse.config.logger.setup_logging( - config, - use_worker_options=False - ) + synapse.config.logger.setup_logging(config, use_worker_options=False) events.USE_FROZEN_DICTS = config.use_frozen_dicts @@ -357,7 +357,7 @@ def setup(config_options): database_engine=database_engine, ) - logger.info("Preparing database: %s...", config.database_config['name']) + logger.info("Preparing database: %s...", config.database_config["name"]) try: with hs.get_db_conn(run_new_connection=False) as db_conn: @@ -375,7 +375,7 @@ def setup(config_options): ) sys.exit(1) - logger.info("Database prepared in %s.", config.database_config['name']) + logger.info("Database prepared in %s.", config.database_config["name"]) hs.setup() hs.setup_master() @@ -391,9 +391,7 @@ def do_acme(): acme = hs.get_acme_handler() # Check how long the certificate is active for. - cert_days_remaining = hs.config.is_disk_cert_valid( - allow_self_signed=False - ) + cert_days_remaining = hs.config.is_disk_cert_valid(allow_self_signed=False) # We want to reprovision if cert_days_remaining is None (meaning no # certificate exists), or the days remaining number it returns @@ -401,8 +399,8 @@ def do_acme(): provision = False if ( - cert_days_remaining is None or - cert_days_remaining < hs.config.acme_reprovision_threshold + cert_days_remaining is None + or cert_days_remaining < hs.config.acme_reprovision_threshold ): provision = True @@ -433,10 +431,7 @@ def start(): yield do_acme() # Check if it needs to be reprovisioned every day. - hs.get_clock().looping_call( - reprovision_acme, - 24 * 60 * 60 * 1000 - ) + hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000) _base.start(hs, config.listeners) @@ -463,6 +458,7 @@ class SynapseService(service.Service): A twisted Service class that will start synapse. Used to run synapse via twistd and a .tac. """ + def __init__(self, config): self.config = config @@ -479,6 +475,7 @@ def stopService(self): def run(hs): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: + def profile(func): from cProfile import Profile from threading import current_thread @@ -489,13 +486,14 @@ def profiled(*args, **kargs): func(*args, **kargs) profile.disable() ident = current_thread().ident - profile.dump_stats("/tmp/%s.%s.%i.pstat" % ( - hs.hostname, func.__name__, ident - )) + profile.dump_stats( + "/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident) + ) return profiled from twisted.python.threadpool import ThreadPool + ThreadPool._worker = profile(ThreadPool._worker) reactor.run = profile(reactor.run) @@ -541,7 +539,9 @@ def phone_stats_home(): stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users() - stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() + stats[ + "daily_active_rooms" + ] = yield hs.get_datastore().count_daily_active_rooms() stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() r30_results = yield hs.get_datastore().count_r30_users() @@ -565,8 +565,7 @@ def phone_stats_home(): logger.info("Reporting stats to matrix.org: %s" % (stats,)) try: yield hs.get_simple_http_client().put_json( - "https://matrix.org/report-usage-stats/push", - stats + "https://matrix.org/report-usage-stats/push", stats ) except Exception as e: logger.warn("Error reporting stats: %s", e) @@ -581,14 +580,11 @@ def performance_stats_init(): logger.info("report_stats can use psutil") stats_process.append(process) except (AttributeError): - logger.warning( - "Unable to read memory/cpu stats. Disabling reporting." - ) + logger.warning("Unable to read memory/cpu stats. Disabling reporting.") def generate_user_daily_visit_stats(): return run_as_background_process( - "generate_user_daily_visits", - hs.get_datastore().generate_user_daily_visits, + "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits ) # Rather than update on per session basis, batch up the requests. @@ -599,9 +595,9 @@ def generate_user_daily_visit_stats(): # monthly active user limiting functionality def reap_monthly_active_users(): return run_as_background_process( - "reap_monthly_active_users", - hs.get_datastore().reap_monthly_active_users, + "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users ) + clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) reap_monthly_active_users() @@ -619,8 +615,7 @@ def generate_monthly_active_users(): def start_generate_monthly_active_users(): return run_as_background_process( - "generate_monthly_active_users", - generate_monthly_active_users, + "generate_monthly_active_users", generate_monthly_active_users ) start_generate_monthly_active_users() @@ -660,5 +655,5 @@ def main(): run(hs) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index d4cc4e9443c7..cf0e2036c3fd 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -72,13 +72,15 @@ def _listen_http(self, listener_config): resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) elif name == "media": media_repo = self.get_media_repository_resource() - resources.update({ - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - }) + resources.update( + { + MEDIA_PREFIX: media_repo, + LEGACY_MEDIA_PREFIX: media_repo, + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path + ), + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -91,7 +93,7 @@ def _listen_http(self, listener_config): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse media repository now listening on port %d", port) @@ -105,18 +107,19 @@ def start_listening(self, listeners): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -164,6 +167,6 @@ def start(config_options): _base.start_worker_reactor("synapse-media-repository", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index cbf0d67f51fa..df29ea5ecbea 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -46,36 +46,27 @@ class PusherSlaveStore( - SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, - SlavedAccountDataStore + SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, SlavedAccountDataStore ): - update_pusher_last_stream_ordering_and_success = ( - __func__(DataStore.update_pusher_last_stream_ordering_and_success) + update_pusher_last_stream_ordering_and_success = __func__( + DataStore.update_pusher_last_stream_ordering_and_success ) - update_pusher_failing_since = ( - __func__(DataStore.update_pusher_failing_since) - ) + update_pusher_failing_since = __func__(DataStore.update_pusher_failing_since) - update_pusher_last_stream_ordering = ( - __func__(DataStore.update_pusher_last_stream_ordering) + update_pusher_last_stream_ordering = __func__( + DataStore.update_pusher_last_stream_ordering ) - get_throttle_params_by_room = ( - __func__(DataStore.get_throttle_params_by_room) - ) + get_throttle_params_by_room = __func__(DataStore.get_throttle_params_by_room) - set_throttle_params = ( - __func__(DataStore.set_throttle_params) - ) + set_throttle_params = __func__(DataStore.set_throttle_params) - get_time_of_last_push_action_before = ( - __func__(DataStore.get_time_of_last_push_action_before) + get_time_of_last_push_action_before = __func__( + DataStore.get_time_of_last_push_action_before ) - get_profile_displayname = ( - __func__(DataStore.get_profile_displayname) - ) + get_profile_displayname = __func__(DataStore.get_profile_displayname) class PusherServer(HomeServer): @@ -105,7 +96,7 @@ def _listen_http(self, listener_config): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse pusher now listening on port %d", port) @@ -119,18 +110,19 @@ def start_listening(self, listeners): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -161,9 +153,7 @@ def poke_pushers(self, stream_name, token, rows): else: yield self.start_pusher(row.user_id, row.app_id, row.pushkey) elif stream_name == "events": - yield self.pusher_pool.on_new_notifications( - token, token, - ) + yield self.pusher_pool.on_new_notifications(token, token) elif stream_name == "receipts": yield self.pusher_pool.on_new_receipts( token, token, set(row.room_id for row in rows) @@ -188,9 +178,7 @@ def start_pusher(self, user_id, app_id, pushkey): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse pusher", config_options - ) + config = HomeServerConfig.load_config("Synapse pusher", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -234,6 +222,6 @@ def start(): _base.start_worker_reactor("synapse-pusher", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): ps = start(sys.argv[1:]) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 5388def28ac6..858949910d0d 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -98,10 +98,7 @@ def __init__(self, hs): self.notifier = hs.get_notifier() active_presence = self.store.take_presence_startup_info() - self.user_to_current_state = { - state.user_id: state - for state in active_presence - } + self.user_to_current_state = {state.user_id: state for state in active_presence} # user_id -> last_sync_ms. Lists the users that have stopped syncing # but we haven't notified the master of that yet @@ -196,17 +193,26 @@ def notify_from_replication(self, states, stream_id): room_ids_to_states, users_to_states = parties self.notifier.on_new_event( - "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=users_to_states.keys() + "presence_key", + stream_id, + rooms=room_ids_to_states.keys(), + users=users_to_states.keys(), ) @defer.inlineCallbacks def process_replication_rows(self, token, rows): - states = [UserPresenceState( - row.user_id, row.state, row.last_active_ts, - row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg, - row.currently_active - ) for row in rows] + states = [ + UserPresenceState( + row.user_id, + row.state, + row.last_active_ts, + row.last_federation_update_ts, + row.last_user_sync_ts, + row.status_msg, + row.currently_active, + ) + for row in rows + ] for state in states: self.user_to_current_state[state.user_id] = state @@ -217,7 +223,8 @@ def process_replication_rows(self, token, rows): def get_currently_syncing_users(self): if self.hs.config.use_presence: return [ - user_id for user_id, count in iteritems(self.user_to_num_current_syncs) + user_id + for user_id, count in iteritems(self.user_to_num_current_syncs) if count > 0 ] else: @@ -281,12 +288,14 @@ def _listen_http(self, listener_config): events.register_servlets(self, resource) InitialSyncRestServlet(self).register(resource) RoomInitialSyncRestServlet(self).register(resource) - resources.update({ - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - }) + resources.update( + { + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -299,7 +308,7 @@ def _listen_http(self, listener_config): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse synchrotron now listening on port %d", port) @@ -313,18 +322,19 @@ def start_listening(self, listeners): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -382,40 +392,36 @@ def process_and_notify(self, stream_name, token, rows): ) elif stream_name == "push_rules": self.notifier.on_new_event( - "push_rules_key", token, users=[row.user_id for row in rows], + "push_rules_key", token, users=[row.user_id for row in rows] ) - elif stream_name in ("account_data", "tag_account_data",): + elif stream_name in ("account_data", "tag_account_data"): self.notifier.on_new_event( - "account_data_key", token, users=[row.user_id for row in rows], + "account_data_key", token, users=[row.user_id for row in rows] ) elif stream_name == "receipts": self.notifier.on_new_event( - "receipt_key", token, rooms=[row.room_id for row in rows], + "receipt_key", token, rooms=[row.room_id for row in rows] ) elif stream_name == "typing": self.typing_handler.process_replication_rows(token, rows) self.notifier.on_new_event( - "typing_key", token, rooms=[row.room_id for row in rows], + "typing_key", token, rooms=[row.room_id for row in rows] ) elif stream_name == "to_device": entities = [row.entity for row in rows if row.entity.startswith("@")] if entities: - self.notifier.on_new_event( - "to_device_key", token, users=entities, - ) + self.notifier.on_new_event("to_device_key", token, users=entities) elif stream_name == "device_lists": all_room_ids = set() for row in rows: room_ids = yield self.store.get_rooms_for_user(row.user_id) all_room_ids.update(room_ids) - self.notifier.on_new_event( - "device_list_key", token, rooms=all_room_ids, - ) + self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids) elif stream_name == "presence": yield self.presence_handler.process_replication_rows(token, rows) elif stream_name == "receipts": self.notifier.on_new_event( - "groups_key", token, users=[row.user_id for row in rows], + "groups_key", token, users=[row.user_id for row in rows] ) except Exception: logger.exception("Error processing replication") @@ -423,9 +429,7 @@ def process_and_notify(self, stream_name, token, rows): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse synchrotron", config_options - ) + config = HomeServerConfig.load_config("Synapse synchrotron", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -453,6 +457,6 @@ def start(config_options): _base.start_worker_reactor("synapse-synchrotron", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 355f5aa71d36..2d9d2e1bbc02 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -66,14 +66,16 @@ def __init__(self, db_conn, hs): events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( - db_conn, "current_state_delta_stream", + db_conn, + "current_state_delta_stream", entity_column="room_id", stream_column="stream_id", max_value=events_max, # As we share the stream id with events token limit=1000, ) self._curr_state_delta_stream_cache = StreamChangeCache( - "_curr_state_delta_stream_cache", min_curr_state_delta_id, + "_curr_state_delta_stream_cache", + min_curr_state_delta_id, prefilled_cache=curr_state_delta_prefill, ) @@ -110,12 +112,14 @@ def _listen_http(self, listener_config): elif name == "client": resource = JsonResource(self, canonical_json=False) user_directory.register_servlets(self, resource) - resources.update({ - "/_matrix/client/r0": resource, - "/_matrix/client/unstable": resource, - "/_matrix/client/v2_alpha": resource, - "/_matrix/client/api/v1": resource, - }) + resources.update( + { + "/_matrix/client/r0": resource, + "/_matrix/client/unstable": resource, + "/_matrix/client/v2_alpha": resource, + "/_matrix/client/api/v1": resource, + } + ) root_resource = create_resource_tree(resources, NoResource()) @@ -128,7 +132,7 @@ def _listen_http(self, listener_config): listener_config, root_resource, self.version_string, - ) + ), ) logger.info("Synapse user_dir now listening on port %d", port) @@ -142,18 +146,19 @@ def start_listening(self, listeners): listener["bind_addresses"], listener["port"], manhole( - username="matrix", - password="rabbithole", - globals={"hs": self}, - ) + username="matrix", password="rabbithole", globals={"hs": self} + ), ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn(("Metrics listener configured, but " - "enable_metrics is not True!")) + logger.warn( + ( + "Metrics listener configured, but " + "enable_metrics is not True!" + ) + ) else: - _base.listen_metrics(listener["bind_addresses"], - listener["port"]) + _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: logger.warn("Unrecognized listener type: %s", listener["type"]) @@ -186,9 +191,7 @@ def _notify_directory(self): def start(config_options): try: - config = HomeServerConfig.load_config( - "Synapse user directory", config_options - ) + config = HomeServerConfig.load_config("Synapse user directory", config_options) except ConfigError as e: sys.stderr.write("\n" + str(e) + "\n") sys.exit(1) @@ -227,6 +230,6 @@ def start(config_options): _base.start_worker_reactor("synapse-user-dir", config) -if __name__ == '__main__': +if __name__ == "__main__": with LoggingContext("main"): start(sys.argv[1:]) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 57ed8a3ca201..b26a31dd54d1 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -48,9 +48,7 @@ def send(self, as_api): A Deferred which resolves to True if the transaction was sent. """ return as_api.push_bulk( - service=self.service, - events=self.events, - txn_id=self.id + service=self.service, events=self.events, txn_id=self.id ) def complete(self, store): @@ -64,10 +62,7 @@ def complete(self, store): Returns: A Deferred which resolves to True if the transaction was completed. """ - return store.complete_appservice_txn( - service=self.service, - txn_id=self.id - ) + return store.complete_appservice_txn(service=self.service, txn_id=self.id) class ApplicationService(object): @@ -76,6 +71,7 @@ class ApplicationService(object): Provides methods to check if this service is "interested" in events. """ + NS_USERS = "users" NS_ALIASES = "aliases" NS_ROOMS = "rooms" @@ -84,9 +80,19 @@ class ApplicationService(object): # values. NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] - def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None, - sender=None, id=None, protocols=None, rate_limited=True, - ip_range_whitelist=None): + def __init__( + self, + token, + hostname, + url=None, + namespaces=None, + hs_token=None, + sender=None, + id=None, + protocols=None, + rate_limited=True, + ip_range_whitelist=None, + ): self.token = token self.url = url self.hs_token = hs_token @@ -128,9 +134,7 @@ def _check_namespaces(self, namespaces): if not isinstance(regex_obj, dict): raise ValueError("Expected dict regex for ns '%s'" % ns) if not isinstance(regex_obj.get("exclusive"), bool): - raise ValueError( - "Expected bool for 'exclusive' in ns '%s'" % ns - ) + raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns) group_id = regex_obj.get("group_id") if group_id: if not isinstance(group_id, str): @@ -153,9 +157,7 @@ def _check_namespaces(self, namespaces): if isinstance(regex, string_types): regex_obj["regex"] = re.compile(regex) # Pre-compile regex else: - raise ValueError( - "Expected string for 'regex' in ns '%s'" % ns - ) + raise ValueError("Expected string for 'regex' in ns '%s'" % ns) return namespaces def _matches_regex(self, test_string, namespace_key): @@ -178,8 +180,9 @@ def _matches_user(self, event, store): if self.is_interested_in_user(event.sender): defer.returnValue(True) # also check m.room.member state key - if (event.type == EventTypes.Member and - self.is_interested_in_user(event.state_key)): + if event.type == EventTypes.Member and self.is_interested_in_user( + event.state_key + ): defer.returnValue(True) if not store: diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 9ccc5a80fc59..571881775bf2 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -32,19 +32,17 @@ sent_transactions_counter = Counter( "synapse_appservice_api_sent_transactions", "Number of /transactions/ requests sent", - ["service"] + ["service"], ) failed_transactions_counter = Counter( "synapse_appservice_api_failed_transactions", "Number of /transactions/ requests that failed to send", - ["service"] + ["service"], ) sent_events_counter = Counter( - "synapse_appservice_api_sent_events", - "Number of events sent to the AS", - ["service"] + "synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"] ) HOUR_IN_MS = 60 * 60 * 1000 @@ -92,8 +90,9 @@ def __init__(self, hs): super(ApplicationServiceApi, self).__init__(hs) self.clock = hs.get_clock() - self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta", - timeout_ms=HOUR_IN_MS) + self.protocol_meta_cache = ResponseCache( + hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS + ) @defer.inlineCallbacks def query_user(self, service, user_id): @@ -102,9 +101,7 @@ def query_user(self, service, user_id): uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) response = None try: - response = yield self.get_json(uri, { - "access_token": service.hs_token - }) + response = yield self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object defer.returnValue(True) except CodeMessageException as e: @@ -123,9 +120,7 @@ def query_alias(self, service, alias): uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) response = None try: - response = yield self.get_json(uri, { - "access_token": service.hs_token - }) + response = yield self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object defer.returnValue(True) except CodeMessageException as e: @@ -144,9 +139,7 @@ def query_3pe(self, service, kind, protocol, fields): elif kind == ThirdPartyEntityKind.LOCATION: required_field = "alias" else: - raise ValueError( - "Unrecognised 'kind' argument %r to query_3pe()", kind - ) + raise ValueError("Unrecognised 'kind' argument %r to query_3pe()", kind) if service.url is None: defer.returnValue([]) @@ -154,14 +147,13 @@ def query_3pe(self, service, kind, protocol, fields): service.url, APP_SERVICE_PREFIX, kind, - urllib.parse.quote(protocol) + urllib.parse.quote(protocol), ) try: response = yield self.get_json(uri, fields) if not isinstance(response, list): logger.warning( - "query_3pe to %s returned an invalid response %r", - uri, response + "query_3pe to %s returned an invalid response %r", uri, response ) defer.returnValue([]) @@ -171,8 +163,7 @@ def query_3pe(self, service, kind, protocol, fields): ret.append(r) else: logger.warning( - "query_3pe to %s returned an invalid result %r", - uri, r + "query_3pe to %s returned an invalid result %r", uri, r ) defer.returnValue(ret) @@ -189,27 +180,27 @@ def _get(): uri = "%s%s/thirdparty/protocol/%s" % ( service.url, APP_SERVICE_PREFIX, - urllib.parse.quote(protocol) + urllib.parse.quote(protocol), ) try: info = yield self.get_json(uri, {}) if not _is_valid_3pe_metadata(info): - logger.warning("query_3pe_protocol to %s did not return a" - " valid result", uri) + logger.warning( + "query_3pe_protocol to %s did not return a" " valid result", uri + ) defer.returnValue(None) for instance in info.get("instances", []): network_id = instance.get("network_id", None) if network_id is not None: instance["instance_id"] = ThirdPartyInstanceID( - service.id, network_id, + service.id, network_id ).to_string() defer.returnValue(info) except Exception as ex: - logger.warning("query_3pe_protocol to %s threw exception %s", - uri, ex) + logger.warning("query_3pe_protocol to %s threw exception %s", uri, ex) defer.returnValue(None) key = (service.id, protocol) @@ -223,22 +214,19 @@ def push_bulk(self, service, events, txn_id=None): events = self._serialize(events) if txn_id is None: - logger.warning("push_bulk: Missing txn ID sending events to %s", - service.url) + logger.warning( + "push_bulk: Missing txn ID sending events to %s", service.url + ) txn_id = str(0) txn_id = str(txn_id) - uri = service.url + ("/transactions/%s" % - urllib.parse.quote(txn_id)) + uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id)) try: yield self.put_json( uri=uri, - json_body={ - "events": events - }, - args={ - "access_token": service.hs_token - }) + json_body={"events": events}, + args={"access_token": service.hs_token}, + ) sent_transactions_counter.labels(service.id).inc() sent_events_counter.labels(service.id).inc(len(events)) defer.returnValue(True) @@ -252,6 +240,4 @@ def push_bulk(self, service, events, txn_id=None): def _serialize(self, events): time_now = self.clock.time_msec() - return [ - serialize_event(e, time_now, as_client_event=True) for e in events - ] + return [serialize_event(e, time_now, as_client_event=True) for e in events] diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 685f15c06104..b54bf5411f46 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -112,15 +112,14 @@ def enqueue(self, service, event): return run_as_background_process( - "as-sender-%s" % (service.id, ), - self._send_request, service, + "as-sender-%s" % (service.id,), self._send_request, service ) @defer.inlineCallbacks def _send_request(self, service): # sanity-check: we shouldn't get here if this service already has a sender # running. - assert(service.id not in self.requests_in_flight) + assert service.id not in self.requests_in_flight self.requests_in_flight.add(service.id) try: @@ -137,7 +136,6 @@ def _send_request(self, service): class _TransactionController(object): - def __init__(self, clock, store, as_api, recoverer_fn): self.clock = clock self.store = store @@ -149,10 +147,7 @@ def __init__(self, clock, store, as_api, recoverer_fn): @defer.inlineCallbacks def send(self, service, events): try: - txn = yield self.store.create_appservice_txn( - service=service, - events=events - ) + txn = yield self.store.create_appservice_txn(service=service, events=events) service_is_up = yield self._is_service_up(service) if service_is_up: sent = yield txn.send(self.as_api) @@ -167,12 +162,12 @@ def send(self, service, events): @defer.inlineCallbacks def on_recovered(self, recoverer): self.recoverers.remove(recoverer) - logger.info("Successfully recovered application service AS ID %s", - recoverer.service.id) + logger.info( + "Successfully recovered application service AS ID %s", recoverer.service.id + ) logger.info("Remaining active recoverers: %s", len(self.recoverers)) yield self.store.set_appservice_state( - recoverer.service, - ApplicationServiceState.UP + recoverer.service, ApplicationServiceState.UP ) def add_recoverers(self, recoverers): @@ -184,13 +179,10 @@ def add_recoverers(self, recoverers): @defer.inlineCallbacks def _start_recoverer(self, service): try: - yield self.store.set_appservice_state( - service, - ApplicationServiceState.DOWN - ) + yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) logger.info( "Application service falling behind. Starting recoverer. AS ID %s", - service.id + service.id, ) recoverer = self.recoverer_fn(service, self.on_recovered) self.add_recoverers([recoverer]) @@ -205,19 +197,16 @@ def _is_service_up(self, service): class _Recoverer(object): - @staticmethod @defer.inlineCallbacks def start(clock, store, as_api, callback): - services = yield store.get_appservices_by_state( - ApplicationServiceState.DOWN - ) - recoverers = [ - _Recoverer(clock, store, as_api, s, callback) for s in services - ] + services = yield store.get_appservices_by_state(ApplicationServiceState.DOWN) + recoverers = [_Recoverer(clock, store, as_api, s, callback) for s in services] for r in recoverers: - logger.info("Starting recoverer for AS ID %s which was marked as " - "DOWN", r.service.id) + logger.info( + "Starting recoverer for AS ID %s which was marked as " "DOWN", + r.service.id, + ) r.recover() defer.returnValue(recoverers) @@ -232,9 +221,9 @@ def __init__(self, clock, store, as_api, service, callback): def recover(self): def _retry(): run_as_background_process( - "as-recoverer-%s" % (self.service.id,), - self.retry, + "as-recoverer-%s" % (self.service.id,), self.retry ) + self.clock.call_later((2 ** self.backoff_counter), _retry) def _backoff(self): @@ -248,8 +237,9 @@ def retry(self): try: txn = yield self.store.get_oldest_unsent_txn(self.service) if txn: - logger.info("Retrying transaction %s for AS ID %s", - txn.id, txn.service.id) + logger.info( + "Retrying transaction %s for AS ID %s", txn.id, txn.service.id + ) sent = yield txn.send(self.as_api) if sent: yield txn.complete(self.store) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index f7d7f153bb49..8284aa4c6d36 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -284,8 +284,8 @@ def load_or_generate_config(cls, description, argv): if not config_files: config_parser.error( "Must supply a config file.\nA config file can be automatically" - " generated using \"--generate-config -H SERVER_NAME" - " -c CONFIG-FILE\"" + ' generated using "--generate-config -H SERVER_NAME' + ' -c CONFIG-FILE"' ) (config_path,) = config_files if not cls.path_exists(config_path): @@ -313,9 +313,7 @@ def load_or_generate_config(cls, description, argv): if not cls.path_exists(config_dir_path): os.makedirs(config_dir_path) with open(config_path, "w") as config_file: - config_file.write( - "# vim:ft=yaml\n\n" - ) + config_file.write("# vim:ft=yaml\n\n") config_file.write(config_str) config = yaml.safe_load(config_str) @@ -352,8 +350,8 @@ def load_or_generate_config(cls, description, argv): if not config_files: config_parser.error( "Must supply a config file.\nA config file can be automatically" - " generated using \"--generate-config -H SERVER_NAME" - " -c CONFIG-FILE\"" + ' generated using "--generate-config -H SERVER_NAME' + ' -c CONFIG-FILE"' ) obj.read_config_files( diff --git a/synapse/config/api.py b/synapse/config/api.py index 5eb4f86fa290..23b0ea696261 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -18,15 +18,17 @@ class ApiConfig(Config): - def read_config(self, config): - self.room_invite_state_types = config.get("room_invite_state_types", [ - EventTypes.JoinRules, - EventTypes.CanonicalAlias, - EventTypes.RoomAvatar, - EventTypes.RoomEncryption, - EventTypes.Name, - ]) + self.room_invite_state_types = config.get( + "room_invite_state_types", + [ + EventTypes.JoinRules, + EventTypes.CanonicalAlias, + EventTypes.RoomAvatar, + EventTypes.RoomEncryption, + EventTypes.Name, + ], + ) def default_config(cls, **kwargs): return """\ @@ -40,4 +42,6 @@ def default_config(cls, **kwargs): # - "{RoomAvatar}" # - "{RoomEncryption}" # - "{Name}" - """.format(**vars(EventTypes)) + """.format( + **vars(EventTypes) + ) diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 7e89d345d80c..679ee62480f5 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -29,7 +29,6 @@ class AppServiceConfig(Config): - def read_config(self, config): self.app_service_config_files = config.get("app_service_config_files", []) self.notify_appservices = config.get("notify_appservices", True) @@ -53,9 +52,7 @@ def default_config(cls, **kwargs): def load_appservices(hostname, config_files): """Returns a list of Application Services from the config files.""" if not isinstance(config_files, list): - logger.warning( - "Expected %s to be a list of AS config files.", config_files - ) + logger.warning("Expected %s to be a list of AS config files.", config_files) return [] # Dicts of value -> filename @@ -66,22 +63,20 @@ def load_appservices(hostname, config_files): for config_file in config_files: try: - with open(config_file, 'r') as f: - appservice = _load_appservice( - hostname, yaml.safe_load(f), config_file - ) + with open(config_file, "r") as f: + appservice = _load_appservice(hostname, yaml.safe_load(f), config_file) if appservice.id in seen_ids: raise ConfigError( "Cannot reuse ID across application services: " - "%s (files: %s, %s)" % ( - appservice.id, config_file, seen_ids[appservice.id], - ) + "%s (files: %s, %s)" + % (appservice.id, config_file, seen_ids[appservice.id]) ) seen_ids[appservice.id] = config_file if appservice.token in seen_as_tokens: raise ConfigError( "Cannot reuse as_token across application services: " - "%s (files: %s, %s)" % ( + "%s (files: %s, %s)" + % ( appservice.token, config_file, seen_as_tokens[appservice.token], @@ -98,28 +93,26 @@ def load_appservices(hostname, config_files): def _load_appservice(hostname, as_info, config_filename): - required_string_fields = [ - "id", "as_token", "hs_token", "sender_localpart" - ] + required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"] for field in required_string_fields: if not isinstance(as_info.get(field), string_types): - raise KeyError("Required string field: '%s' (%s)" % ( - field, config_filename, - )) + raise KeyError( + "Required string field: '%s' (%s)" % (field, config_filename) + ) # 'url' must either be a string or explicitly null, not missing # to avoid accidentally turning off push for ASes. - if (not isinstance(as_info.get("url"), string_types) and - as_info.get("url", "") is not None): + if ( + not isinstance(as_info.get("url"), string_types) + and as_info.get("url", "") is not None + ): raise KeyError( "Required string field or explicit null: 'url' (%s)" % (config_filename,) ) localpart = as_info["sender_localpart"] if urlparse.quote(localpart) != localpart: - raise ValueError( - "sender_localpart needs characters which are not URL encoded." - ) + raise ValueError("sender_localpart needs characters which are not URL encoded.") user = UserID(localpart, hostname) user_id = user.to_string() @@ -138,13 +131,12 @@ def _load_appservice(hostname, as_info, config_filename): for regex_obj in as_info["namespaces"][ns]: if not isinstance(regex_obj, dict): raise ValueError( - "Expected namespace entry in %s to be an object," - " but got %s", ns, regex_obj + "Expected namespace entry in %s to be an object," " but got %s", + ns, + regex_obj, ) if not isinstance(regex_obj.get("regex"), string_types): - raise ValueError( - "Missing/bad type 'regex' key in %s", regex_obj - ) + raise ValueError("Missing/bad type 'regex' key in %s", regex_obj) if not isinstance(regex_obj.get("exclusive"), bool): raise ValueError( "Missing/bad type 'exclusive' key in %s", regex_obj @@ -167,10 +159,8 @@ def _load_appservice(hostname, as_info, config_filename): ) ip_range_whitelist = None - if as_info.get('ip_range_whitelist'): - ip_range_whitelist = IPSet( - as_info.get('ip_range_whitelist') - ) + if as_info.get("ip_range_whitelist"): + ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist")) return ApplicationService( token=as_info["as_token"], diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index f7eebf26d238..e2eb473a9232 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -16,7 +16,6 @@ class CaptchaConfig(Config): - def read_config(self, config): self.recaptcha_private_key = config.get("recaptcha_private_key") self.recaptcha_public_key = config.get("recaptcha_public_key") diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index abeb0180d307..5b0bf919c7e9 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -89,29 +89,26 @@ def read_config(self, config): if consent_config is None: return self.user_consent_version = str(consent_config["version"]) - self.user_consent_template_dir = self.abspath( - consent_config["template_dir"] - ) + self.user_consent_template_dir = self.abspath(consent_config["template_dir"]) if not path.isdir(self.user_consent_template_dir): raise ConfigError( - "Could not find template directory '%s'" % ( - self.user_consent_template_dir, - ), + "Could not find template directory '%s'" + % (self.user_consent_template_dir,) ) self.user_consent_server_notice_content = consent_config.get( - "server_notice_content", + "server_notice_content" ) self.block_events_without_consent_error = consent_config.get( - "block_events_error", + "block_events_error" + ) + self.user_consent_server_notice_to_guests = bool( + consent_config.get("send_server_notice_to_guests", False) + ) + self.user_consent_at_registration = bool( + consent_config.get("require_at_registration", False) ) - self.user_consent_server_notice_to_guests = bool(consent_config.get( - "send_server_notice_to_guests", False, - )) - self.user_consent_at_registration = bool(consent_config.get( - "require_at_registration", False, - )) self.user_consent_policy_name = consent_config.get( - "policy_name", "Privacy Policy", + "policy_name", "Privacy Policy" ) def default_config(self, **kwargs): diff --git a/synapse/config/database.py b/synapse/config/database.py index 3c27ed6b4a49..adc0a47ddf70 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -18,29 +18,21 @@ class DatabaseConfig(Config): - def read_config(self, config): - self.event_cache_size = self.parse_size( - config.get("event_cache_size", "10K") - ) + self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) self.database_config = config.get("database") if self.database_config is None: - self.database_config = { - "name": "sqlite3", - "args": {}, - } + self.database_config = {"name": "sqlite3", "args": {}} name = self.database_config.get("name", None) if name == "psycopg2": pass elif name == "sqlite3": - self.database_config.setdefault("args", {}).update({ - "cp_min": 1, - "cp_max": 1, - "check_same_thread": False, - }) + self.database_config.setdefault("args", {}).update( + {"cp_min": 1, "cp_max": 1, "check_same_thread": False} + ) else: raise RuntimeError("Unsupported database type '%s'" % (name,)) @@ -48,7 +40,8 @@ def read_config(self, config): def default_config(self, data_dir_path, **kwargs): database_path = os.path.join(data_dir_path, "homeserver.db") - return """\ + return ( + """\ ## Database ## database: @@ -62,7 +55,9 @@ def default_config(self, data_dir_path, **kwargs): # Number of events to cache in memory. # #event_cache_size: 10K - """ % locals() + """ + % locals() + ) def read_arguments(self, args): self.set_databasepath(args.database_path) @@ -77,6 +72,8 @@ def set_databasepath(self, database_path): def add_arguments(self, parser): db_group = parser.add_argument_group("database") db_group.add_argument( - "-d", "--database-path", metavar="SQLITE_DATABASE_PATH", - help="The path to a sqlite database to use." + "-d", + "--database-path", + metavar="SQLITE_DATABASE_PATH", + help="The path to a sqlite database to use.", ) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 86018dfcce27..3a6cb07206dd 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -56,7 +56,7 @@ def read_config(self, config): if self.email_notif_from is not None: # make sure it's valid parsed = email.utils.parseaddr(self.email_notif_from) - if parsed[1] == '': + if parsed[1] == "": raise RuntimeError("Invalid notif_from address") template_dir = email_config.get("template_dir") @@ -65,19 +65,17 @@ def read_config(self, config): # (Note that loading as package_resources with jinja.PackageLoader doesn't # work for the same reason.) if not template_dir: - template_dir = pkg_resources.resource_filename( - 'synapse', 'res/templates' - ) + template_dir = pkg_resources.resource_filename("synapse", "res/templates") self.email_template_dir = os.path.abspath(template_dir) self.email_enable_notifs = email_config.get("enable_notifs", False) - account_validity_renewal_enabled = config.get( - "account_validity", {}, - ).get("renew_at") + account_validity_renewal_enabled = config.get("account_validity", {}).get( + "renew_at" + ) email_trust_identity_server_for_password_resets = email_config.get( - "trust_identity_server_for_password_resets", False, + "trust_identity_server_for_password_resets", False ) self.email_password_reset_behaviour = ( "remote" if email_trust_identity_server_for_password_resets else "local" @@ -103,62 +101,59 @@ def read_config(self, config): # make sure we can import the required deps import jinja2 import bleach + # prevent unused warnings jinja2 bleach if self.email_password_reset_behaviour == "local": - required = [ - "smtp_host", - "smtp_port", - "notif_from", - ] + required = ["smtp_host", "smtp_port", "notif_from"] missing = [] for k in required: if k not in email_config: missing.append(k) - if (len(missing) > 0): + if len(missing) > 0: raise RuntimeError( "email.password_reset_behaviour is set to 'local' " - "but required keys are missing: %s" % - (", ".join(["email." + k for k in missing]),) + "but required keys are missing: %s" + % (", ".join(["email." + k for k in missing]),) ) # Templates for password reset emails self.email_password_reset_template_html = email_config.get( - "password_reset_template_html", "password_reset.html", + "password_reset_template_html", "password_reset.html" ) self.email_password_reset_template_text = email_config.get( - "password_reset_template_text", "password_reset.txt", + "password_reset_template_text", "password_reset.txt" ) self.email_password_reset_failure_template = email_config.get( - "password_reset_failure_template", "password_reset_failure.html", + "password_reset_failure_template", "password_reset_failure.html" ) # This template does not support any replaceable variables, so we will # read it from the disk once during setup email_password_reset_success_template = email_config.get( - "password_reset_success_template", "password_reset_success.html", + "password_reset_success_template", "password_reset_success.html" ) # Check templates exist - for f in [self.email_password_reset_template_html, - self.email_password_reset_template_text, - self.email_password_reset_failure_template, - email_password_reset_success_template]: + for f in [ + self.email_password_reset_template_html, + self.email_password_reset_template_text, + self.email_password_reset_failure_template, + email_password_reset_success_template, + ]: p = os.path.join(self.email_template_dir, f) if not os.path.isfile(p): - raise ConfigError("Unable to find template file %s" % (p, )) + raise ConfigError("Unable to find template file %s" % (p,)) # Retrieve content of web templates filepath = os.path.join( - self.email_template_dir, - email_password_reset_success_template, + self.email_template_dir, email_password_reset_success_template ) self.email_password_reset_success_html_content = self.read_file( - filepath, - "email.password_reset_template_success_html", + filepath, "email.password_reset_template_success_html" ) if config.get("public_baseurl") is None: @@ -182,10 +177,10 @@ def read_config(self, config): if k not in email_config: missing.append(k) - if (len(missing) > 0): + if len(missing) > 0: raise RuntimeError( - "email.enable_notifs is True but required keys are missing: %s" % - (", ".join(["email." + k for k in missing]),) + "email.enable_notifs is True but required keys are missing: %s" + % (", ".join(["email." + k for k in missing]),) ) if config.get("public_baseurl") is None: @@ -199,27 +194,25 @@ def read_config(self, config): for f in self.email_notif_template_text, self.email_notif_template_html: p = os.path.join(self.email_template_dir, f) if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p, )) + raise ConfigError("Unable to find email template file %s" % (p,)) self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True ) - self.email_riot_base_url = email_config.get( - "riot_base_url", None - ) + self.email_riot_base_url = email_config.get("riot_base_url", None) if account_validity_renewal_enabled: self.email_expiry_template_html = email_config.get( - "expiry_template_html", "notice_expiry.html", + "expiry_template_html", "notice_expiry.html" ) self.email_expiry_template_text = email_config.get( - "expiry_template_text", "notice_expiry.txt", + "expiry_template_text", "notice_expiry.txt" ) for f in self.email_expiry_template_text, self.email_expiry_template_html: p = os.path.join(self.email_template_dir, f) if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p, )) + raise ConfigError("Unable to find email template file %s" % (p,)) def default_config(self, config_dir_path, server_name, **kwargs): return """ diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py index ecb4124096e8..b190dcbe38ca 100644 --- a/synapse/config/jwt_config.py +++ b/synapse/config/jwt_config.py @@ -15,13 +15,11 @@ from ._base import Config, ConfigError -MISSING_JWT = ( - """Missing jwt library. This is required for jwt login. +MISSING_JWT = """Missing jwt library. This is required for jwt login. Install by running: pip install pyjwt """ -) class JWTConfig(Config): @@ -34,6 +32,7 @@ def read_config(self, config): try: import jwt + jwt # To stop unused lint. except ImportError: raise ConfigError(MISSING_JWT) diff --git a/synapse/config/key.py b/synapse/config/key.py index 424875feae3d..94a0f47ea482 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -348,9 +348,8 @@ def _parse_key_servers(key_servers, federation_verify_certificates): result.verify_keys[key_id] = verify_key - if ( - not federation_verify_certificates and - not server.get("accept_keys_insecurely") + if not federation_verify_certificates and not server.get( + "accept_keys_insecurely" ): _assert_keyserver_has_verify_keys(result) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index c1febbe9d3ba..a22655b1257b 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -29,7 +29,8 @@ from ._base import Config -DEFAULT_LOG_CONFIG = Template(""" +DEFAULT_LOG_CONFIG = Template( + """ version: 1 formatters: @@ -68,11 +69,11 @@ root: level: INFO handlers: [file, console] -""") +""" +) class LoggingConfig(Config): - def read_config(self, config): self.verbosity = config.get("verbose", 0) self.no_redirect_stdio = config.get("no_redirect_stdio", False) @@ -81,13 +82,16 @@ def read_config(self, config): def default_config(self, config_dir_path, server_name, **kwargs): log_config = os.path.join(config_dir_path, server_name + ".log.config") - return """\ + return ( + """\ ## Logging ## # A yaml python logging config file # log_config: "%(log_config)s" - """ % locals() + """ + % locals() + ) def read_arguments(self, args): if args.verbose is not None: @@ -102,22 +106,31 @@ def read_arguments(self, args): def add_arguments(cls, parser): logging_group = parser.add_argument_group("logging") logging_group.add_argument( - '-v', '--verbose', dest="verbose", action='count', + "-v", + "--verbose", + dest="verbose", + action="count", help="The verbosity level. Specify multiple times to increase " - "verbosity. (Ignored if --log-config is specified.)" + "verbosity. (Ignored if --log-config is specified.)", ) logging_group.add_argument( - '-f', '--log-file', dest="log_file", - help="File to log to. (Ignored if --log-config is specified.)" + "-f", + "--log-file", + dest="log_file", + help="File to log to. (Ignored if --log-config is specified.)", ) logging_group.add_argument( - '--log-config', dest="log_config", default=None, - help="Python logging config file" + "--log-config", + dest="log_config", + default=None, + help="Python logging config file", ) logging_group.add_argument( - '-n', '--no-redirect-stdio', - action='store_true', default=None, - help="Do not redirect stdout/stderr to the log" + "-n", + "--no-redirect-stdio", + action="store_true", + default=None, + help="Do not redirect stdout/stderr to the log", ) def generate_files(self, config): @@ -125,9 +138,7 @@ def generate_files(self, config): if log_config and not os.path.exists(log_config): log_file = self.abspath("homeserver.log") with open(log_config, "w") as log_config_file: - log_config_file.write( - DEFAULT_LOG_CONFIG.substitute(log_file=log_file) - ) + log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file)) def setup_logging(config, use_worker_options=False): @@ -143,10 +154,8 @@ def setup_logging(config, use_worker_options=False): register_sighup (func | None): Function to call to register a sighup handler. """ - log_config = (config.worker_log_config if use_worker_options - else config.log_config) - log_file = (config.worker_log_file if use_worker_options - else config.log_file) + log_config = config.worker_log_config if use_worker_options else config.log_config + log_file = config.worker_log_file if use_worker_options else config.log_file log_format = ( "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" @@ -164,23 +173,23 @@ def setup_logging(config, use_worker_options=False): if config.verbosity > 1: level_for_storage = logging.DEBUG - logger = logging.getLogger('') + logger = logging.getLogger("") logger.setLevel(level) - logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage) + logging.getLogger("synapse.storage.SQL").setLevel(level_for_storage) formatter = logging.Formatter(log_format) if log_file: # TODO: Customisable file size / backup count handler = logging.handlers.RotatingFileHandler( - log_file, maxBytes=(1000 * 1000 * 100), backupCount=3, - encoding='utf8' + log_file, maxBytes=(1000 * 1000 * 100), backupCount=3, encoding="utf8" ) def sighup(signum, stack): logger.info("Closing log file due to SIGHUP") handler.doRollover() logger.info("Opened new log file due to SIGHUP") + else: handler = logging.StreamHandler() @@ -193,8 +202,9 @@ def sighup(*args): logger.addHandler(handler) else: + def load_log_config(): - with open(log_config, 'r') as f: + with open(log_config, "r") as f: logging.config.dictConfig(yaml.safe_load(f)) def sighup(*args): @@ -209,10 +219,7 @@ def sighup(*args): # make sure that the first thing we log is a thing we can grep backwards # for logging.warn("***** STARTING SERVER *****") - logging.warn( - "Server %s version %s", - sys.argv[0], get_version_string(synapse), - ) + logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) # It's critical to point twisted's internal logging somewhere, otherwise it @@ -242,8 +249,7 @@ def _log(event): return observer(event) globalLogBeginner.beginLoggingTo( - [_log], - redirectStandardIO=not config.no_redirect_stdio, + [_log], redirectStandardIO=not config.no_redirect_stdio ) if not config.no_redirect_stdio: print("Redirected stdout/stderr to logs") diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index 2de51979d84a..c85e234d2244 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -15,11 +15,9 @@ from ._base import Config, ConfigError -MISSING_SENTRY = ( - """Missing sentry-sdk library. This is required to enable sentry +MISSING_SENTRY = """Missing sentry-sdk library. This is required to enable sentry integration. """ -) class MetricsConfig(Config): @@ -39,7 +37,7 @@ def read_config(self, config): self.sentry_dsn = config["sentry"].get("dsn") if not self.sentry_dsn: raise ConfigError( - "sentry.dsn field is required when sentry integration is enabled", + "sentry.dsn field is required when sentry integration is enabled" ) def default_config(self, report_stats=None, **kwargs): @@ -66,6 +64,6 @@ def default_config(self, report_stats=None, **kwargs): if report_stats is None: res += "# report_stats: true|false\n" else: - res += "report_stats: %s\n" % ('true' if report_stats else 'false') + res += "report_stats: %s\n" % ("true" if report_stats else "false") return res diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index f0a6be0679ab..fcf279e8e105 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -17,7 +17,7 @@ from ._base import Config -LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider' +LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider" class PasswordAuthProviderConfig(Config): @@ -29,24 +29,20 @@ def read_config(self, config): # param. ldap_config = config.get("ldap_config", {}) if ldap_config.get("enabled", False): - providers.append({ - 'module': LDAP_PROVIDER, - 'config': ldap_config, - }) + providers.append({"module": LDAP_PROVIDER, "config": ldap_config}) providers.extend(config.get("password_providers", [])) for provider in providers: - mod_name = provider['module'] + mod_name = provider["module"] # This is for backwards compat when the ldap auth provider resided # in this package. if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider": mod_name = LDAP_PROVIDER - (provider_class, provider_config) = load_module({ - "module": mod_name, - "config": provider['config'], - }) + (provider_class, provider_config) = load_module( + {"module": mod_name, "config": provider["config"]} + ) self.password_providers.append((provider_class, provider_config)) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index aad3400819ca..a1e27ba66c7f 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -23,7 +23,7 @@ class AccountValidityConfig(Config): def __init__(self, config, synapse_config): self.enabled = config.get("enabled", False) - self.renew_by_email_enabled = ("renew_at" in config) + self.renew_by_email_enabled = "renew_at" in config if self.enabled: if "period" in config: @@ -39,14 +39,13 @@ def __init__(self, config, synapse_config): else: self.renew_email_subject = "Renew your %(app)s account" - self.startup_job_max_delta = self.period * 10. / 100. + self.startup_job_max_delta = self.period * 10.0 / 100.0 if self.renew_by_email_enabled and "public_baseurl" not in synapse_config: raise ConfigError("Can't send renewal emails without 'public_baseurl'") class RegistrationConfig(Config): - def read_config(self, config): self.enable_registration = bool( strtobool(str(config.get("enable_registration", False))) @@ -57,7 +56,7 @@ def read_config(self, config): ) self.account_validity = AccountValidityConfig( - config.get("account_validity", {}), config, + config.get("account_validity", {}), config ) self.registrations_require_3pid = config.get("registrations_require_3pid", []) @@ -67,24 +66,23 @@ def read_config(self, config): self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.trusted_third_party_id_servers = config.get( - "trusted_third_party_id_servers", - ["matrix.org", "vector.im"], + "trusted_third_party_id_servers", ["matrix.org", "vector.im"] ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) - self.invite_3pid_guest = ( - self.allow_guest_access and config.get("invite_3pid_guest", False) + self.invite_3pid_guest = self.allow_guest_access and config.get( + "invite_3pid_guest", False ) self.auto_join_rooms = config.get("auto_join_rooms", []) for room_alias in self.auto_join_rooms: if not RoomAlias.is_valid(room_alias): - raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,)) + raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) - self.disable_msisdn_registration = ( - config.get("disable_msisdn_registration", False) + self.disable_msisdn_registration = config.get( + "disable_msisdn_registration", False ) def default_config(self, generate_secrets=False, **kwargs): @@ -93,9 +91,12 @@ def default_config(self, generate_secrets=False, **kwargs): random_string_with_symbols(50), ) else: - registration_shared_secret = '# registration_shared_secret: ' + registration_shared_secret = ( + "# registration_shared_secret: " + ) - return """\ + return ( + """\ ## Registration ## # # Registration can be rate-limited using the parameters in the "Ratelimiting" @@ -217,17 +218,19 @@ def default_config(self, generate_secrets=False, **kwargs): # users cannot be auto-joined since they do not exist. # #autocreate_auto_join_rooms: true - """ % locals() + """ + % locals() + ) def add_arguments(self, parser): reg_group = parser.add_argument_group("registration") reg_group.add_argument( - "--enable-registration", action="store_true", default=None, - help="Enable registration for new users." + "--enable-registration", + action="store_true", + default=None, + help="Enable registration for new users.", ) def read_arguments(self, args): if args.enable_registration is not None: - self.enable_registration = bool( - strtobool(str(args.enable_registration)) - ) + self.enable_registration = bool(strtobool(str(args.enable_registration))) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index fbfcecc240d1..9f9669ebb14e 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -20,27 +20,11 @@ from ._base import Config, ConfigError DEFAULT_THUMBNAIL_SIZES = [ - { - "width": 32, - "height": 32, - "method": "crop", - }, { - "width": 96, - "height": 96, - "method": "crop", - }, { - "width": 320, - "height": 240, - "method": "scale", - }, { - "width": 640, - "height": 480, - "method": "scale", - }, { - "width": 800, - "height": 600, - "method": "scale" - }, + {"width": 32, "height": 32, "method": "crop"}, + {"width": 96, "height": 96, "method": "crop"}, + {"width": 320, "height": 240, "method": "scale"}, + {"width": 640, "height": 480, "method": "scale"}, + {"width": 800, "height": 600, "method": "scale"}, ] THUMBNAIL_SIZE_YAML = """\ @@ -49,19 +33,15 @@ # method: %(method)s """ -MISSING_NETADDR = ( - "Missing netaddr library. This is required for URL preview API." -) +MISSING_NETADDR = "Missing netaddr library. This is required for URL preview API." -MISSING_LXML = ( - """Missing lxml library. This is required for URL preview API. +MISSING_LXML = """Missing lxml library. This is required for URL preview API. Install by running: pip install lxml Requires libxslt1-dev system package. """ -) ThumbnailRequirement = namedtuple( @@ -69,7 +49,8 @@ ) MediaStorageProviderConfig = namedtuple( - "MediaStorageProviderConfig", ( + "MediaStorageProviderConfig", + ( "store_local", # Whether to store newly uploaded local files "store_remote", # Whether to store newly downloaded remote files "store_synchronous", # Whether to wait for successful storage for local uploads @@ -100,8 +81,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): requirements.setdefault("image/gif", []).append(png_thumbnail) requirements.setdefault("image/png", []).append(png_thumbnail) return { - media_type: tuple(thumbnails) - for media_type, thumbnails in requirements.items() + media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items() } @@ -127,15 +107,15 @@ def read_config(self, config): "Cannot use both 'backup_media_store_path' and 'storage_providers'" ) - storage_providers = [{ - "module": "file_system", - "store_local": True, - "store_synchronous": synchronous_backup_media_store, - "store_remote": True, - "config": { - "directory": backup_media_store_path, + storage_providers = [ + { + "module": "file_system", + "store_local": True, + "store_synchronous": synchronous_backup_media_store, + "store_remote": True, + "config": {"directory": backup_media_store_path}, } - }] + ] # This is a list of config that can be used to create the storage # providers. The entries are tuples of (Class, class_config, @@ -165,18 +145,19 @@ def read_config(self, config): ) self.media_storage_providers.append( - (provider_class, parsed_config, wrapper_config,) + (provider_class, parsed_config, wrapper_config) ) self.uploads_path = self.ensure_directory(config["uploads_path"]) self.dynamic_thumbnails = config.get("dynamic_thumbnails", False) self.thumbnail_requirements = parse_thumbnail_requirements( - config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES), + config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES) ) self.url_preview_enabled = config.get("url_preview_enabled", False) if self.url_preview_enabled: try: import lxml + lxml # To stop unused lint. except ImportError: raise ConfigError(MISSING_LXML) @@ -199,15 +180,13 @@ def read_config(self, config): # we always blacklist '0.0.0.0' and '::', which are supposed to be # unroutable addresses. - self.url_preview_ip_range_blacklist.update(['0.0.0.0', '::']) + self.url_preview_ip_range_blacklist.update(["0.0.0.0", "::"]) self.url_preview_ip_range_whitelist = IPSet( config.get("url_preview_ip_range_whitelist", ()) ) - self.url_preview_url_blacklist = config.get( - "url_preview_url_blacklist", () - ) + self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ()) def default_config(self, data_dir_path, **kwargs): media_store = os.path.join(data_dir_path, "media_store") @@ -219,7 +198,8 @@ def default_config(self, data_dir_path, **kwargs): # strip final NL formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1] - return r""" + return ( + r""" # Directory where uploaded images and attachments are stored. # media_store_path: "%(media_store)s" @@ -342,4 +322,6 @@ def default_config(self, data_dir_path, **kwargs): # The largest allowed URL preview spidering size in bytes # #max_spider_size: 10M - """ % locals() + """ + % locals() + ) diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 8a9fded4c55c..c1da0e20e0ae 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -20,9 +20,7 @@ class RoomDirectoryConfig(Config): def read_config(self, config): - self.enable_room_list_search = config.get( - "enable_room_list_search", True, - ) + self.enable_room_list_search = config.get("enable_room_list_search", True) alias_creation_rules = config.get("alias_creation_rules") @@ -33,11 +31,7 @@ def read_config(self, config): ] else: self._alias_creation_rules = [ - _RoomDirectoryRule( - "alias_creation_rules", { - "action": "allow", - } - ) + _RoomDirectoryRule("alias_creation_rules", {"action": "allow"}) ] room_list_publication_rules = config.get("room_list_publication_rules") @@ -49,11 +43,7 @@ def read_config(self, config): ] else: self._room_list_publication_rules = [ - _RoomDirectoryRule( - "room_list_publication_rules", { - "action": "allow", - } - ) + _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"}) ] def default_config(self, config_dir_path, server_name, **kwargs): @@ -178,8 +168,7 @@ def __init__(self, option_name, rule): self.action = action else: raise ConfigError( - "%s rules can only have action of 'allow'" - " or 'deny'" % (option_name,) + "%s rules can only have action of 'allow'" " or 'deny'" % (option_name,) ) self._alias_matches_all = alias == "*" diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index aa6eac271ff7..2ec38e48e9a4 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -28,6 +28,7 @@ def read_config(self, config): self.saml2_enabled = True import saml2.config + self.saml2_sp_config = saml2.config.SPConfig() self.saml2_sp_config.load(self._default_saml_config_dict()) self.saml2_sp_config.load(saml2_config.get("sp_config", {})) @@ -41,26 +42,23 @@ def _default_saml_config_dict(self): public_baseurl = self.public_baseurl if public_baseurl is None: - raise ConfigError( - "saml2_config requires a public_baseurl to be set" - ) + raise ConfigError("saml2_config requires a public_baseurl to be set") metadata_url = public_baseurl + "_matrix/saml2/metadata.xml" response_url = public_baseurl + "_matrix/saml2/authn_response" return { "entityid": metadata_url, - "service": { "sp": { "endpoints": { "assertion_consumer_service": [ - (response_url, saml2.BINDING_HTTP_POST), - ], + (response_url, saml2.BINDING_HTTP_POST) + ] }, "required_attributes": ["uid"], "optional_attributes": ["mail", "surname", "givenname"], - }, - } + } + }, } def default_config(self, config_dir_path, server_name, **kwargs): @@ -106,4 +104,6 @@ def default_config(self, config_dir_path, server_name, **kwargs): # # separate pysaml2 configuration file: # # # config_path: "%(config_dir_path)s/sp_conf.py" - """ % {"config_dir_path": config_dir_path} + """ % { + "config_dir_path": config_dir_path + } diff --git a/synapse/config/server.py b/synapse/config/server.py index 6e5b46e6c318..6d3f1da96c8c 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -34,13 +34,12 @@ # # We later check for errors when binding to 0.0.0.0 and ignore them if :: is also in # in the list. -DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0'] +DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"] DEFAULT_ROOM_VERSION = "4" class ServerConfig(Config): - def read_config(self, config): self.server_name = config["server_name"] self.server_context = config.get("server_context", None) @@ -81,27 +80,25 @@ def read_config(self, config): # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. self.require_auth_for_profile_requests = config.get( - "require_auth_for_profile_requests", False, + "require_auth_for_profile_requests", False ) # If set to 'True', requires authentication to access the server's # public rooms directory through the client API, and forbids any other # homeserver to fetch it via federation. self.restrict_public_rooms_to_local_users = config.get( - "restrict_public_rooms_to_local_users", False, + "restrict_public_rooms_to_local_users", False ) - default_room_version = config.get( - "default_room_version", DEFAULT_ROOM_VERSION, - ) + default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION) # Ensure room version is a str default_room_version = str(default_room_version) if default_room_version not in KNOWN_ROOM_VERSIONS: raise ConfigError( - "Unknown default_room_version: %s, known room versions: %s" % - (default_room_version, list(KNOWN_ROOM_VERSIONS.keys())) + "Unknown default_room_version: %s, known room versions: %s" + % (default_room_version, list(KNOWN_ROOM_VERSIONS.keys())) ) # Get the actual room version object rather than just the identifier @@ -116,31 +113,25 @@ def read_config(self, config): # Whether we should block invites sent to users on this server # (other than those sent by local server admins) - self.block_non_admin_invites = config.get( - "block_non_admin_invites", False, - ) + self.block_non_admin_invites = config.get("block_non_admin_invites", False) # Whether to enable experimental MSC1849 (aka relations) support self.experimental_msc1849_support_enabled = config.get( - "experimental_msc1849_support_enabled", False, + "experimental_msc1849_support_enabled", False ) # Options to control access by tracking MAU self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) self.max_mau_value = 0 if self.limit_usage_by_mau: - self.max_mau_value = config.get( - "max_mau_value", 0, - ) + self.max_mau_value = config.get("max_mau_value", 0) self.mau_stats_only = config.get("mau_stats_only", False) self.mau_limits_reserved_threepids = config.get( "mau_limit_reserved_threepids", [] ) - self.mau_trial_days = config.get( - "mau_trial_days", 0, - ) + self.mau_trial_days = config.get("mau_trial_days", 0) # Options to disable HS self.hs_disabled = config.get("hs_disabled", False) @@ -153,9 +144,7 @@ def read_config(self, config): # FIXME: federation_domain_whitelist needs sytests self.federation_domain_whitelist = None - federation_domain_whitelist = config.get( - "federation_domain_whitelist", None, - ) + federation_domain_whitelist = config.get("federation_domain_whitelist", None) if federation_domain_whitelist is not None: # turn the whitelist into a hash for speed of lookup @@ -165,7 +154,7 @@ def read_config(self, config): self.federation_domain_whitelist[domain] = True self.federation_ip_range_blacklist = config.get( - "federation_ip_range_blacklist", [], + "federation_ip_range_blacklist", [] ) # Attempt to create an IPSet from the given ranges @@ -178,13 +167,12 @@ def read_config(self, config): self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) except Exception as e: raise ConfigError( - "Invalid range(s) provided in " - "federation_ip_range_blacklist: %s" % e + "Invalid range(s) provided in " "federation_ip_range_blacklist: %s" % e ) if self.public_baseurl is not None: - if self.public_baseurl[-1] != '/': - self.public_baseurl += '/' + if self.public_baseurl[-1] != "/": + self.public_baseurl += "/" self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, @@ -195,7 +183,7 @@ def read_config(self, config): # Whether to require a user to be in the room to add an alias to it. # Defaults to True. self.require_membership_for_aliases = config.get( - "require_membership_for_aliases", True, + "require_membership_for_aliases", True ) # Whether to allow per-room membership profiles through the send of membership @@ -227,9 +215,9 @@ def read_config(self, config): # if we still have an empty list of addresses, use the default list if not bind_addresses: - if listener['type'] == 'metrics': + if listener["type"] == "metrics": # the metrics listener doesn't support IPv6 - bind_addresses.append('0.0.0.0') + bind_addresses.append("0.0.0.0") else: bind_addresses.extend(DEFAULT_BIND_ADDRESSES) @@ -249,78 +237,72 @@ def read_config(self, config): bind_host = config.get("bind_host", "") gzip_responses = config.get("gzip_responses", True) - self.listeners.append({ - "port": bind_port, - "bind_addresses": [bind_host], - "tls": True, - "type": "http", - "resources": [ - { - "names": ["client"], - "compress": gzip_responses, - }, - { - "names": ["federation"], - "compress": False, - } - ] - }) - - unsecure_port = config.get("unsecure_port", bind_port - 400) - if unsecure_port: - self.listeners.append({ - "port": unsecure_port, + self.listeners.append( + { + "port": bind_port, "bind_addresses": [bind_host], - "tls": False, + "tls": True, "type": "http", "resources": [ - { - "names": ["client"], - "compress": gzip_responses, - }, - { - "names": ["federation"], - "compress": False, - } - ] - }) + {"names": ["client"], "compress": gzip_responses}, + {"names": ["federation"], "compress": False}, + ], + } + ) + + unsecure_port = config.get("unsecure_port", bind_port - 400) + if unsecure_port: + self.listeners.append( + { + "port": unsecure_port, + "bind_addresses": [bind_host], + "tls": False, + "type": "http", + "resources": [ + {"names": ["client"], "compress": gzip_responses}, + {"names": ["federation"], "compress": False}, + ], + } + ) manhole = config.get("manhole") if manhole: - self.listeners.append({ - "port": manhole, - "bind_addresses": ["127.0.0.1"], - "type": "manhole", - "tls": False, - }) + self.listeners.append( + { + "port": manhole, + "bind_addresses": ["127.0.0.1"], + "type": "manhole", + "tls": False, + } + ) metrics_port = config.get("metrics_port") if metrics_port: logger.warn( - ("The metrics_port configuration option is deprecated in Synapse 0.31 " - "in favour of a listener. Please see " - "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst" - " on how to configure the new listener.")) - - self.listeners.append({ - "port": metrics_port, - "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")], - "tls": False, - "type": "http", - "resources": [ - { - "names": ["metrics"], - "compress": False, - }, - ] - }) + ( + "The metrics_port configuration option is deprecated in Synapse 0.31 " + "in favour of a listener. Please see " + "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst" + " on how to configure the new listener." + ) + ) + + self.listeners.append( + { + "port": metrics_port, + "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")], + "tls": False, + "type": "http", + "resources": [{"names": ["metrics"], "compress": False}], + } + ) _check_resource_config(self.listeners) # An experimental option to try and periodically clean up extremities # by sending dummy events. self.cleanup_extremities_with_dummy_events = config.get( - "cleanup_extremities_with_dummy_events", False, + "cleanup_extremities_with_dummy_events", False ) def has_tls_listener(self): @@ -339,7 +321,8 @@ def default_config(self, server_name, data_dir_path, **kwargs): # Bring DEFAULT_ROOM_VERSION into the local-scope for use in the # default config string default_room_version = DEFAULT_ROOM_VERSION - return """\ + return ( + """\ ## Server ## # The domain name of the server, with optional explicit port. @@ -637,7 +620,9 @@ def default_config(self, server_name, data_dir_path, **kwargs): # Defaults to 'true'. # #allow_per_room_profiles: false - """ % locals() + """ + % locals() + ) def read_arguments(self, args): if args.manhole is not None: @@ -649,17 +634,26 @@ def read_arguments(self, args): def add_arguments(self, parser): server_group = parser.add_argument_group("server") - server_group.add_argument("-D", "--daemonize", action='store_true', - default=None, - help="Daemonize the home server") - server_group.add_argument("--print-pidfile", action='store_true', - default=None, - help="Print the path to the pidfile just" - " before daemonizing") - server_group.add_argument("--manhole", metavar="PORT", dest="manhole", - type=int, - help="Turn on the twisted telnet manhole" - " service on the given port.") + server_group.add_argument( + "-D", + "--daemonize", + action="store_true", + default=None, + help="Daemonize the home server", + ) + server_group.add_argument( + "--print-pidfile", + action="store_true", + default=None, + help="Print the path to the pidfile just" " before daemonizing", + ) + server_group.add_argument( + "--manhole", + metavar="PORT", + dest="manhole", + type=int, + help="Turn on the twisted telnet manhole" " service on the given port.", + ) def is_threepid_reserved(reserved_threepids, threepid): @@ -673,7 +667,7 @@ def is_threepid_reserved(reserved_threepids, threepid): """ for tp in reserved_threepids: - if (threepid['medium'] == tp['medium'] and threepid['address'] == tp['address']): + if threepid["medium"] == tp["medium"] and threepid["address"] == tp["address"]: return True return False @@ -686,9 +680,7 @@ def read_gc_thresholds(thresholds): return None try: assert len(thresholds) == 3 - return ( - int(thresholds[0]), int(thresholds[1]), int(thresholds[2]), - ) + return (int(thresholds[0]), int(thresholds[1]), int(thresholds[2])) except Exception: raise ConfigError( "Value of `gc_threshold` must be a list of three integers if set" @@ -706,22 +698,22 @@ def _warn_if_webclient_configured(listeners): for listener in listeners: for res in listener.get("resources", []): for name in res.get("names", []): - if name == 'webclient': + if name == "webclient": logger.warning(NO_MORE_WEB_CLIENT_WARNING) return KNOWN_RESOURCES = ( - 'client', - 'consent', - 'federation', - 'keys', - 'media', - 'metrics', - 'openid', - 'replication', - 'static', - 'webclient', + "client", + "consent", + "federation", + "keys", + "media", + "metrics", + "openid", + "replication", + "static", + "webclient", ) @@ -735,11 +727,9 @@ def _check_resource_config(listeners): for resource in resource_names: if resource not in KNOWN_RESOURCES: - raise ConfigError( - "Unknown listener resource '%s'" % (resource, ) - ) + raise ConfigError("Unknown listener resource '%s'" % (resource,)) if resource == "consent": try: - check_requirements('resources.consent') + check_requirements("resources.consent") except DependencyException as e: raise ConfigError(e.message) diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py index 529dc0a61790..d930eb33b57f 100644 --- a/synapse/config/server_notices_config.py +++ b/synapse/config/server_notices_config.py @@ -58,6 +58,7 @@ class ServerNoticesConfig(Config): The name to use for the server notices room. None if server notices are not enabled. """ + def __init__(self): super(ServerNoticesConfig, self).__init__() self.server_notices_mxid = None @@ -70,18 +71,12 @@ def read_config(self, config): if c is None: return - mxid_localpart = c['system_mxid_localpart'] - self.server_notices_mxid = UserID( - mxid_localpart, self.server_name, - ).to_string() - self.server_notices_mxid_display_name = c.get( - 'system_mxid_display_name', None, - ) - self.server_notices_mxid_avatar_url = c.get( - 'system_mxid_avatar_url', None, - ) + mxid_localpart = c["system_mxid_localpart"] + self.server_notices_mxid = UserID(mxid_localpart, self.server_name).to_string() + self.server_notices_mxid_display_name = c.get("system_mxid_display_name", None) + self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None) # todo: i18n - self.server_notices_room_name = c.get('room_name', "Server Notices") + self.server_notices_room_name = c.get("room_name", "Server Notices") def default_config(self, **kwargs): return DEFAULT_CONFIG diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 658f9dd3618d..7951bf21faf5 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -42,11 +42,11 @@ def read_config(self, config): self.acme_enabled = acme_config.get("enabled", False) # hyperlink complains on py2 if this is not a Unicode - self.acme_url = six.text_type(acme_config.get( - "url", u"https://acme-v01.api.letsencrypt.org/directory" - )) + self.acme_url = six.text_type( + acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory") + ) self.acme_port = acme_config.get("port", 80) - self.acme_bind_addresses = acme_config.get("bind_addresses", ['::', '0.0.0.0']) + self.acme_bind_addresses = acme_config.get("bind_addresses", ["::", "0.0.0.0"]) self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30) self.acme_domain = acme_config.get("domain", config.get("server_name")) @@ -74,12 +74,12 @@ def read_config(self, config): # Whether to verify certificates on outbound federation traffic self.federation_verify_certificates = config.get( - "federation_verify_certificates", True, + "federation_verify_certificates", True ) # Whitelist of domains to not verify certificates for fed_whitelist_entries = config.get( - "federation_certificate_verification_whitelist", [], + "federation_certificate_verification_whitelist", [] ) # Support globs (*) in whitelist values @@ -90,9 +90,7 @@ def read_config(self, config): self.federation_certificate_verification_whitelist.append(entry_regex) # List of custom certificate authorities for federation traffic validation - custom_ca_list = config.get( - "federation_custom_ca_list", None, - ) + custom_ca_list = config.get("federation_custom_ca_list", None) # Read in and parse custom CA certificates self.federation_ca_trust_root = None @@ -101,8 +99,10 @@ def read_config(self, config): # A trustroot cannot be generated without any CA certificates. # Raise an error if this option has been specified without any # corresponding certificates. - raise ConfigError("federation_custom_ca_list specified without " - "any certificate files") + raise ConfigError( + "federation_custom_ca_list specified without " + "any certificate files" + ) certs = [] for ca_file in custom_ca_list: @@ -114,8 +114,9 @@ def read_config(self, config): cert_base = Certificate.loadPEM(content) certs.append(cert_base) except Exception as e: - raise ConfigError("Error parsing custom CA certificate file %s: %s" - % (ca_file, e)) + raise ConfigError( + "Error parsing custom CA certificate file %s: %s" % (ca_file, e) + ) self.federation_ca_trust_root = trustRootFromCertificates(certs) @@ -146,17 +147,21 @@ def is_disk_cert_valid(self, allow_self_signed=True): return None try: - with open(self.tls_certificate_file, 'rb') as f: + with open(self.tls_certificate_file, "rb") as f: cert_pem = f.read() except Exception as e: - raise ConfigError("Failed to read existing certificate file %s: %s" - % (self.tls_certificate_file, e)) + raise ConfigError( + "Failed to read existing certificate file %s: %s" + % (self.tls_certificate_file, e) + ) try: tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) except Exception as e: - raise ConfigError("Failed to parse existing certificate file %s: %s" - % (self.tls_certificate_file, e)) + raise ConfigError( + "Failed to parse existing certificate file %s: %s" + % (self.tls_certificate_file, e) + ) if not allow_self_signed: if tls_certificate.get_subject() == tls_certificate.get_issuer(): @@ -166,7 +171,7 @@ def is_disk_cert_valid(self, allow_self_signed=True): # YYYYMMDDhhmmssZ -- in UTC expires_on = datetime.strptime( - tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ" + tls_certificate.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ" ) now = datetime.utcnow() days_remaining = (expires_on - now).days @@ -191,7 +196,8 @@ def read_certificate_from_disk(self, require_cert_and_key): except Exception as e: logger.info( "Unable to read TLS certificate (%s). Ignoring as no " - "tls listeners enabled.", e, + "tls listeners enabled.", + e, ) self.tls_fingerprints = list(self._original_tls_fingerprints) @@ -205,7 +211,7 @@ def read_certificate_from_disk(self, require_cert_and_key): sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest()) sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints) if sha256_fingerprint not in sha256_fingerprints: - self.tls_fingerprints.append({u"sha256": sha256_fingerprint}) + self.tls_fingerprints.append({"sha256": sha256_fingerprint}) def default_config(self, config_dir_path, server_name, **kwargs): base_key_name = os.path.join(config_dir_path, server_name) @@ -215,8 +221,8 @@ def default_config(self, config_dir_path, server_name, **kwargs): # this is to avoid the max line length. Sorrynotsorry proxypassline = ( - 'ProxyPass /.well-known/acme-challenge ' - 'http://localhost:8009/.well-known/acme-challenge' + "ProxyPass /.well-known/acme-challenge " + "http://localhost:8009/.well-known/acme-challenge" ) return ( diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index 023997ccdeb2..e031b115993f 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -26,11 +26,11 @@ def read_config(self, config): self.user_directory_search_all_users = False user_directory_config = config.get("user_directory", None) if user_directory_config: - self.user_directory_search_enabled = ( - user_directory_config.get("enabled", True) + self.user_directory_search_enabled = user_directory_config.get( + "enabled", True ) - self.user_directory_search_all_users = ( - user_directory_config.get("search_all_users", False) + self.user_directory_search_all_users = user_directory_config.get( + "search_all_users", False ) def default_config(self, config_dir_path, server_name, **kwargs): diff --git a/synapse/config/voip.py b/synapse/config/voip.py index 2a1f005a37d3..82cf8c53a8ce 100644 --- a/synapse/config/voip.py +++ b/synapse/config/voip.py @@ -16,14 +16,13 @@ class VoipConfig(Config): - def read_config(self, config): self.turn_uris = config.get("turn_uris", []) self.turn_shared_secret = config.get("turn_shared_secret") self.turn_username = config.get("turn_username") self.turn_password = config.get("turn_password") self.turn_user_lifetime = self.parse_duration( - config.get("turn_user_lifetime", "1h"), + config.get("turn_user_lifetime", "1h") ) self.turn_allow_guests = config.get("turn_allow_guests", True) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index bfbd8b6c9142..75993abf3563 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -52,12 +52,14 @@ def read_config(self, config): # argument. manhole = config.get("worker_manhole") if manhole: - self.worker_listeners.append({ - "port": manhole, - "bind_addresses": ["127.0.0.1"], - "type": "manhole", - "tls": False, - }) + self.worker_listeners.append( + { + "port": manhole, + "bind_addresses": ["127.0.0.1"], + "type": "manhole", + "tls": False, + } + ) if self.worker_listeners: for listener in self.worker_listeners: @@ -67,7 +69,7 @@ def read_config(self, config): if bind_address: bind_addresses.append(bind_address) elif not bind_addresses: - bind_addresses.append('') + bind_addresses.append("") def read_arguments(self, args): # We support a bunch of command line arguments that override options in diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 99a586655b3a..41eabbe7171d 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -46,9 +46,7 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): if name not in hashes: raise SynapseError( 400, - "Algorithm %s not in hashes %s" % ( - name, list(hashes), - ), + "Algorithm %s not in hashes %s" % (name, list(hashes)), Codes.UNAUTHORIZED, ) message_hash_base64 = hashes[name] @@ -56,9 +54,7 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): message_hash_bytes = decode_base64(message_hash_base64) except Exception: raise SynapseError( - 400, - "Invalid base64: %s" % (message_hash_base64,), - Codes.UNAUTHORIZED, + 400, "Invalid base64: %s" % (message_hash_base64,), Codes.UNAUTHORIZED ) return message_hash_bytes == expected_hash @@ -135,8 +131,9 @@ def compute_event_signature(event_dict, signature_name, signing_key): return redact_json["signatures"] -def add_hashes_and_signatures(event_dict, signature_name, signing_key, - hash_algorithm=hashlib.sha256): +def add_hashes_and_signatures( + event_dict, signature_name, signing_key, hash_algorithm=hashlib.sha256 +): """Add content hash and sign the event Args: @@ -153,7 +150,5 @@ def add_hashes_and_signatures(event_dict, signature_name, signing_key, event_dict.setdefault("hashes", {})[name] = encode_base64(digest) event_dict["signatures"] = compute_event_signature( - event_dict, - signature_name=signature_name, - signing_key=signing_key, + event_dict, signature_name=signature_name, signing_key=signing_key ) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 6f603f19615e..10c2eb7f0fa7 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -505,7 +505,7 @@ def process_v2_response(self, from_server, response_json, time_added_ms): Returns: Deferred[dict[str, FetchKeyResult]]: map from key_id to result object """ - ts_valid_until_ms = response_json[u"valid_until_ts"] + ts_valid_until_ms = response_json["valid_until_ts"] # start by extracting the keys from the response, since they may be required # to validate the signature on the response. @@ -614,10 +614,7 @@ def get_key(key_server): results = yield logcontext.make_deferred_yieldable( defer.gatherResults( - [ - run_in_background(get_key, server) - for server in self.key_servers - ], + [run_in_background(get_key, server) for server in self.key_servers], consumeErrors=True, ).addErrback(unwrapFirstError) ) @@ -630,9 +627,7 @@ def get_key(key_server): defer.returnValue(union_of_keys) @defer.inlineCallbacks - def get_server_verify_key_v2_indirect( - self, keys_to_fetch, key_server - ): + def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): """ Args: keys_to_fetch (dict[str, dict[str, int]]): @@ -661,9 +656,9 @@ def get_server_verify_key_v2_indirect( destination=perspective_name, path="/_matrix/key/v2/query", data={ - u"server_keys": { + "server_keys": { server_name: { - key_id: {u"minimum_valid_until_ts": min_valid_ts} + key_id: {"minimum_valid_until_ts": min_valid_ts} for key_id, min_valid_ts in server_keys.items() } for server_name, server_keys in keys_to_fetch.items() @@ -690,10 +685,7 @@ def get_server_verify_key_v2_indirect( ) try: - self._validate_perspectives_response( - key_server, - response, - ) + self._validate_perspectives_response(key_server, response) processed_response = yield self.process_v2_response( perspective_name, response, time_added_ms=time_now_ms @@ -720,9 +712,7 @@ def get_server_verify_key_v2_indirect( defer.returnValue(keys) - def _validate_perspectives_response( - self, key_server, response, - ): + def _validate_perspectives_response(self, key_server, response): """Optionally check the signature on the result of a /key/query request Args: @@ -739,13 +729,13 @@ def _validate_perspectives_response( return if ( - u"signatures" not in response - or perspective_name not in response[u"signatures"] + "signatures" not in response + or perspective_name not in response["signatures"] ): raise KeyLookupError("Response not signed by the notary server") verified = False - for key_id in response[u"signatures"][perspective_name]: + for key_id in response["signatures"][perspective_name]: if key_id in perspective_keys: verify_signed_json(response, perspective_name, perspective_keys[key_id]) verified = True @@ -754,7 +744,7 @@ def _validate_perspectives_response( raise KeyLookupError( "Response not signed with a known key: signed with: %r, known keys: %r" % ( - list(response[u"signatures"][perspective_name].keys()), + list(response["signatures"][perspective_name].keys()), list(perspective_keys.keys()), ) ) @@ -826,7 +816,6 @@ def get_server_verify_key_v2_direct(self, server_name, key_ids): path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id), ignore_backoff=True, - # we only give the remote server 10s to respond. It should be an # easy request to handle, so if it doesn't reply within 10s, it's # probably not going to. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 203490fc36d8..cd52e3f867fa 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -85,17 +85,14 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru room_id_domain = get_domain_from_id(event.room_id) if room_id_domain != sender_domain: raise AuthError( - 403, - "Creation event's room_id domain does not match sender's" + 403, "Creation event's room_id domain does not match sender's" ) room_version = event.content.get("room_version", "1") if room_version not in KNOWN_ROOM_VERSIONS: raise AuthError( - 403, - "room appears to have unsupported version %s" % ( - room_version, - )) + 403, "room appears to have unsupported version %s" % (room_version,) + ) # FIXME logger.debug("Allowing! %s", event) return @@ -103,46 +100,30 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru creation_event = auth_events.get((EventTypes.Create, ""), None) if not creation_event: - raise AuthError( - 403, - "No create event in auth events", - ) + raise AuthError(403, "No create event in auth events") creating_domain = get_domain_from_id(event.room_id) originating_domain = get_domain_from_id(event.sender) if creating_domain != originating_domain: if not _can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) + raise AuthError(403, "This room has been marked as unfederatable.") # FIXME: Temp hack if event.type == EventTypes.Aliases: if not event.is_state(): - raise AuthError( - 403, - "Alias event must be a state event", - ) + raise AuthError(403, "Alias event must be a state event") if not event.state_key: - raise AuthError( - 403, - "Alias event must have non-empty state_key" - ) + raise AuthError(403, "Alias event must have non-empty state_key") sender_domain = get_domain_from_id(event.sender) if event.state_key != sender_domain: raise AuthError( - 403, - "Alias event's state_key does not match sender's domain" + 403, "Alias event's state_key does not match sender's domain" ) logger.debug("Allowing! %s", event) return if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "Auth events: %s", - [a.event_id for a in auth_events.values()] - ) + logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()]) if event.type == EventTypes.Member: _is_membership_change_allowed(event, auth_events) @@ -159,9 +140,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru invite_level = _get_named_level(auth_events, "invite", 0) if user_level < invite_level: - raise AuthError( - 403, "You don't have permission to invite users", - ) + raise AuthError(403, "You don't have permission to invite users") else: logger.debug("Allowing! %s", event) return @@ -207,7 +186,7 @@ def _is_membership_change_allowed(event, auth_events): # Check if this is the room creator joining: if len(event.prev_event_ids()) == 1 and Membership.JOIN == membership: # Get room creation event: - key = (EventTypes.Create, "", ) + key = (EventTypes.Create, "") create = auth_events.get(key) if create and event.prev_event_ids()[0] == create.event_id: if create.content["creator"] == event.state_key: @@ -219,38 +198,31 @@ def _is_membership_change_allowed(event, auth_events): target_domain = get_domain_from_id(target_user_id) if creating_domain != target_domain: if not _can_federate(event, auth_events): - raise AuthError( - 403, - "This room has been marked as unfederatable." - ) + raise AuthError(403, "This room has been marked as unfederatable.") # get info about the caller - key = (EventTypes.Member, event.user_id, ) + key = (EventTypes.Member, event.user_id) caller = auth_events.get(key) caller_in_room = caller and caller.membership == Membership.JOIN caller_invited = caller and caller.membership == Membership.INVITE # get info about the target - key = (EventTypes.Member, target_user_id, ) + key = (EventTypes.Member, target_user_id) target = auth_events.get(key) target_in_room = target and target.membership == Membership.JOIN target_banned = target and target.membership == Membership.BAN - key = (EventTypes.JoinRules, "", ) + key = (EventTypes.JoinRules, "") join_rule_event = auth_events.get(key) if join_rule_event: - join_rule = join_rule_event.content.get( - "join_rule", JoinRules.INVITE - ) + join_rule = join_rule_event.content.get("join_rule", JoinRules.INVITE) else: join_rule = JoinRules.INVITE user_level = get_user_power_level(event.user_id, auth_events) - target_level = get_user_power_level( - target_user_id, auth_events - ) + target_level = get_user_power_level(target_user_id, auth_events) # FIXME (erikj): What should we do here as the default? ban_level = _get_named_level(auth_events, "ban", 50) @@ -266,29 +238,26 @@ def _is_membership_change_allowed(event, auth_events): "join_rule": join_rule, "target_user_id": target_user_id, "event.user_id": event.user_id, - } + }, ) if Membership.INVITE == membership and "third_party_invite" in event.content: if not _verify_third_party_invite(event, auth_events): raise AuthError(403, "You are not invited to this room.") if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) + raise AuthError(403, "%s is banned from the room" % (target_user_id,)) return if Membership.JOIN != membership: - if (caller_invited - and Membership.LEAVE == membership - and target_user_id == event.user_id): + if ( + caller_invited + and Membership.LEAVE == membership + and target_user_id == event.user_id + ): return if not caller_in_room: # caller isn't joined - raise AuthError( - 403, - "%s not in room %s." % (event.user_id, event.room_id,) - ) + raise AuthError(403, "%s not in room %s." % (event.user_id, event.room_id)) if Membership.INVITE == membership: # TODO (erikj): We should probably handle this more intelligently @@ -296,19 +265,14 @@ def _is_membership_change_allowed(event, auth_events): # Invites are valid iff caller is in the room and target isn't. if target_banned: - raise AuthError( - 403, "%s is banned from the room" % (target_user_id,) - ) + raise AuthError(403, "%s is banned from the room" % (target_user_id,)) elif target_in_room: # the target is already in the room. - raise AuthError(403, "%s is already in the room." % - target_user_id) + raise AuthError(403, "%s is already in the room." % target_user_id) else: invite_level = _get_named_level(auth_events, "invite", 0) if user_level < invite_level: - raise AuthError( - 403, "You don't have permission to invite users", - ) + raise AuthError(403, "You don't have permission to invite users") elif Membership.JOIN == membership: # Joins are valid iff caller == target and they were: # invited: They are accepting the invitation @@ -329,16 +293,12 @@ def _is_membership_change_allowed(event, auth_events): elif Membership.LEAVE == membership: # TODO (erikj): Implement kicks. if target_banned and user_level < ban_level: - raise AuthError( - 403, "You cannot unban user %s." % (target_user_id,) - ) + raise AuthError(403, "You cannot unban user %s." % (target_user_id,)) elif target_user_id != event.user_id: kick_level = _get_named_level(auth_events, "kick", 50) if user_level < kick_level or user_level <= target_level: - raise AuthError( - 403, "You cannot kick user %s." % target_user_id - ) + raise AuthError(403, "You cannot kick user %s." % target_user_id) elif Membership.BAN == membership: if user_level < ban_level or user_level <= target_level: raise AuthError(403, "You don't have permission to ban") @@ -347,21 +307,17 @@ def _is_membership_change_allowed(event, auth_events): def _check_event_sender_in_room(event, auth_events): - key = (EventTypes.Member, event.user_id, ) + key = (EventTypes.Member, event.user_id) member_event = auth_events.get(key) - return _check_joined_room( - member_event, - event.user_id, - event.room_id - ) + return _check_joined_room(member_event, event.user_id, event.room_id) def _check_joined_room(member, user_id, room_id): if not member or member.membership != Membership.JOIN: - raise AuthError(403, "User %s not in room %s (%s)" % ( - user_id, room_id, repr(member) - )) + raise AuthError( + 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member)) + ) def get_send_level(etype, state_key, power_levels_event): @@ -402,26 +358,21 @@ def get_send_level(etype, state_key, power_levels_event): def _can_send_event(event, auth_events): power_levels_event = _get_power_level_event(auth_events) - send_level = get_send_level( - event.type, event.get("state_key"), power_levels_event, - ) + send_level = get_send_level(event.type, event.get("state_key"), power_levels_event) user_level = get_user_power_level(event.user_id, auth_events) if user_level < send_level: raise AuthError( 403, - "You don't have permission to post that to the room. " + - "user_level (%d) < send_level (%d)" % (user_level, send_level) + "You don't have permission to post that to the room. " + + "user_level (%d) < send_level (%d)" % (user_level, send_level), ) # Check state_key if hasattr(event, "state_key"): if event.state_key.startswith("@"): if event.state_key != event.user_id: - raise AuthError( - 403, - "You are not allowed to set others state" - ) + raise AuthError(403, "You are not allowed to set others state") return True @@ -459,10 +410,7 @@ def check_redaction(room_version, event, auth_events): event.internal_metadata.recheck_redaction = True return True - raise AuthError( - 403, - "You don't have permission to redact events" - ) + raise AuthError(403, "You don't have permission to redact events") def _check_power_levels(event, auth_events): @@ -479,7 +427,7 @@ def _check_power_levels(event, auth_events): except Exception: raise SynapseError(400, "Not a valid power level: %s" % (v,)) - key = (event.type, event.state_key, ) + key = (event.type, event.state_key) current_state = auth_events.get(key) if not current_state: @@ -500,16 +448,12 @@ def _check_power_levels(event, auth_events): old_list = current_state.content.get("users", {}) for user in set(list(old_list) + list(user_list)): - levels_to_check.append( - (user, "users") - ) + levels_to_check.append((user, "users")) old_list = current_state.content.get("events", {}) new_list = event.content.get("events", {}) for ev_id in set(list(old_list) + list(new_list)): - levels_to_check.append( - (ev_id, "events") - ) + levels_to_check.append((ev_id, "events")) old_state = current_state.content new_state = event.content @@ -540,7 +484,7 @@ def _check_power_levels(event, auth_events): raise AuthError( 403, "You don't have permission to remove ops level equal " - "to your own" + "to your own", ) # Check if the old and new levels are greater than the user level @@ -550,8 +494,7 @@ def _check_power_levels(event, auth_events): if old_level_too_big or new_level_too_big: raise AuthError( 403, - "You don't have permission to add ops level greater " - "than your own" + "You don't have permission to add ops level greater " "than your own", ) @@ -587,10 +530,9 @@ def get_user_power_level(user_id, auth_events): # some things which call this don't pass the create event: hack around # that. - key = (EventTypes.Create, "", ) + key = (EventTypes.Create, "") create_event = auth_events.get(key) - if (create_event is not None and - create_event.content["creator"] == user_id): + if create_event is not None and create_event.content["creator"] == user_id: return 100 else: return 0 @@ -636,9 +578,7 @@ def _verify_third_party_invite(event, auth_events): token = signed["token"] - invite_event = auth_events.get( - (EventTypes.ThirdPartyInvite, token,) - ) + invite_event = auth_events.get((EventTypes.ThirdPartyInvite, token)) if not invite_event: return False @@ -661,8 +601,7 @@ def _verify_third_party_invite(event, auth_events): if not key_name.startswith("ed25519:"): continue verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) + key_name, decode_base64(public_key) ) verify_signed_json(signed, server, verify_key) @@ -671,7 +610,7 @@ def _verify_third_party_invite(event, auth_events): # The caller is responsible for checking that the signing # server has not revoked that public key. return True - except (KeyError, SignatureVerifyException,): + except (KeyError, SignatureVerifyException): continue return False @@ -679,9 +618,7 @@ def _verify_third_party_invite(event, auth_events): def get_public_keys(invite_event): public_keys = [] if "public_key" in invite_event.content: - o = { - "public_key": invite_event.content["public_key"], - } + o = {"public_key": invite_event.content["public_key"]} if "key_validity_url" in invite_event.content: o["key_validity_url"] = invite_event.content["key_validity_url"] public_keys.append(o) @@ -702,22 +639,22 @@ def auth_types_for_event(event): auth_types = [] - auth_types.append((EventTypes.PowerLevels, "", )) - auth_types.append((EventTypes.Member, event.sender, )) - auth_types.append((EventTypes.Create, "", )) + auth_types.append((EventTypes.PowerLevels, "")) + auth_types.append((EventTypes.Member, event.sender)) + auth_types.append((EventTypes.Create, "")) if event.type == EventTypes.Member: membership = event.content["membership"] if membership in [Membership.JOIN, Membership.INVITE]: - auth_types.append((EventTypes.JoinRules, "", )) + auth_types.append((EventTypes.JoinRules, "")) - auth_types.append((EventTypes.Member, event.state_key, )) + auth_types.append((EventTypes.Member, event.state_key)) if membership == Membership.INVITE: if "third_party_invite" in event.content: key = ( EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"] + event.content["third_party_invite"]["signed"]["token"], ) auth_types.append(key) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 7154bcbea612..d3de70e671c9 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -127,25 +127,25 @@ def delete(self): except KeyError: raise AttributeError(key) - return property( - getter, - setter, - delete, - ) + return property(getter, setter, delete) class EventBase(object): - def __init__(self, event_dict, signatures={}, unsigned={}, - internal_metadata_dict={}, rejected_reason=None): + def __init__( + self, + event_dict, + signatures={}, + unsigned={}, + internal_metadata_dict={}, + rejected_reason=None, + ): self.signatures = signatures self.unsigned = unsigned self.rejected_reason = rejected_reason self._event_dict = event_dict - self.internal_metadata = _EventInternalMetadata( - internal_metadata_dict - ) + self.internal_metadata = _EventInternalMetadata(internal_metadata_dict) auth_events = _event_dict_property("auth_events") depth = _event_dict_property("depth") @@ -168,10 +168,7 @@ def is_state(self): def get_dict(self): d = dict(self._event_dict) - d.update({ - "signatures": self.signatures, - "unsigned": dict(self.unsigned), - }) + d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)}) return d @@ -358,6 +355,7 @@ def __repr__(self): class FrozenEventV3(FrozenEventV2): """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format""" + format_version = EventFormatVersions.V3 # All events of this type are V3 @property @@ -414,6 +412,4 @@ def event_type_from_format_version(format_version): elif format_version == EventFormatVersions.V3: return FrozenEventV3 else: - raise Exception( - "No event format %r" % (format_version,) - ) + raise Exception("No event format %r" % (format_version,)) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 546b6f498275..db011e04078f 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -78,7 +78,9 @@ class EventBuilder(object): _redacts = attr.ib(default=None) _origin_server_ts = attr.ib(default=None) - internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({}))) + internal_metadata = attr.ib( + default=attr.Factory(lambda: _EventInternalMetadata({})) + ) @property def state_key(self): @@ -102,11 +104,9 @@ def build(self, prev_event_ids): """ state_ids = yield self._state.get_current_state_ids( - self.room_id, prev_event_ids, - ) - auth_ids = yield self._auth.compute_auth_events( - self, state_ids, + self.room_id, prev_event_ids ) + auth_ids = yield self._auth.compute_auth_events(self, state_ids) if self.format_version == EventFormatVersions.V1: auth_events = yield self._store.add_event_hashes(auth_ids) @@ -115,9 +115,7 @@ def build(self, prev_event_ids): auth_events = auth_ids prev_events = prev_event_ids - old_depth = yield self._store.get_max_depth_of( - prev_event_ids, - ) + old_depth = yield self._store.get_max_depth_of(prev_event_ids) depth = old_depth + 1 # we cap depth of generated events, to ensure that they are not @@ -217,9 +215,14 @@ def for_room_version(self, room_version, key_values): ) -def create_local_event_from_event_dict(clock, hostname, signing_key, - format_version, event_dict, - internal_metadata_dict=None): +def create_local_event_from_event_dict( + clock, + hostname, + signing_key, + format_version, + event_dict, + internal_metadata_dict=None, +): """Takes a fully formed event dict, ensuring that fields like `origin` and `origin_server_ts` have correct values for a locally produced event, then signs and hashes it. @@ -237,9 +240,7 @@ def create_local_event_from_event_dict(clock, hostname, signing_key, """ if format_version not in KNOWN_EVENT_FORMAT_VERSIONS: - raise Exception( - "No event format defined for version %r" % (format_version,) - ) + raise Exception("No event format defined for version %r" % (format_version,)) if internal_metadata_dict is None: internal_metadata_dict = {} @@ -258,13 +259,9 @@ def create_local_event_from_event_dict(clock, hostname, signing_key, event_dict.setdefault("signatures", {}) - add_hashes_and_signatures( - event_dict, - hostname, - signing_key, - ) + add_hashes_and_signatures(event_dict, hostname, signing_key) return event_type_from_format_version(format_version)( - event_dict, internal_metadata_dict=internal_metadata_dict, + event_dict, internal_metadata_dict=internal_metadata_dict ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index fa09c132a0a1..a96cdada3d35 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -88,8 +88,9 @@ def __init__(self): self.app_service = None @staticmethod - def with_state(state_group, current_state_ids, prev_state_ids, - prev_group=None, delta_ids=None): + def with_state( + state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None + ): context = EventContext() # The current state including the current event @@ -132,17 +133,19 @@ def serialize(self, event, store): else: prev_state_id = None - defer.returnValue({ - "prev_state_id": prev_state_id, - "event_type": event.type, - "event_state_key": event.state_key if event.is_state() else None, - "state_group": self.state_group, - "rejected": self.rejected, - "prev_group": self.prev_group, - "delta_ids": _encode_state_dict(self.delta_ids), - "prev_state_events": self.prev_state_events, - "app_service_id": self.app_service.id if self.app_service else None - }) + defer.returnValue( + { + "prev_state_id": prev_state_id, + "event_type": event.type, + "event_state_key": event.state_key if event.is_state() else None, + "state_group": self.state_group, + "rejected": self.rejected, + "prev_group": self.prev_group, + "delta_ids": _encode_state_dict(self.delta_ids), + "prev_state_events": self.prev_state_events, + "app_service_id": self.app_service.id if self.app_service else None, + } + ) @staticmethod def deserialize(store, input): @@ -194,7 +197,7 @@ def get_current_state_ids(self, store): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background( - self._fill_out_state, store, + self._fill_out_state, store ) yield make_deferred_yieldable(self._fetching_state_deferred) @@ -214,7 +217,7 @@ def get_prev_state_ids(self, store): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background( - self._fill_out_state, store, + self._fill_out_state, store ) yield make_deferred_yieldable(self._fetching_state_deferred) @@ -240,9 +243,7 @@ def _fill_out_state(self, store): if self.state_group is None: return - self._current_state_ids = yield store.get_state_ids_for_group( - self.state_group, - ) + self._current_state_ids = yield store.get_state_ids_for_group(self.state_group) if self._prev_state_id and self._event_state_key is not None: self._prev_state_ids = dict(self._current_state_ids) @@ -252,8 +253,9 @@ def _fill_out_state(self, store): self._prev_state_ids = self._current_state_ids @defer.inlineCallbacks - def update_state(self, state_group, prev_state_ids, current_state_ids, - prev_group, delta_ids): + def update_state( + self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids + ): """Replace the state in the context """ @@ -279,10 +281,7 @@ def _encode_state_dict(state_dict): if state_dict is None: return None - return [ - (etype, state_key, v) - for (etype, state_key), v in iteritems(state_dict) - ] + return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)] def _decode_state_dict(input): @@ -291,4 +290,4 @@ def _decode_state_dict(input): if input is None: return None - return frozendict({(etype, state_key,): v for etype, state_key, v in input}) + return frozendict({(etype, state_key): v for etype, state_key, v in input}) diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 6058077f752d..129771f183ba 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -60,7 +60,9 @@ def user_may_invite(self, inviter_userid, invitee_userid, room_id): if self.spam_checker is None: return True - return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id) + return self.spam_checker.user_may_invite( + inviter_userid, invitee_userid, room_id + ) def user_may_create_room(self, userid): """Checks if a given user may create a room diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 50ceeb1e8ea7..8f5d95696b79 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -36,8 +36,7 @@ def __init__(self, hs): if module is not None: self.third_party_rules = module( - config=config, - http_client=hs.get_simple_http_client(), + config=config, http_client=hs.get_simple_http_client() ) @defer.inlineCallbacks @@ -109,6 +108,6 @@ def check_threepid_can_be_invited(self, medium, address, room_id): state_events[key] = room_state_events[event_id] ret = yield self.third_party_rules.check_threepid_can_be_invited( - medium, address, state_events, + medium, address, state_events ) defer.returnValue(ret) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index e2d4384de199..f24f0c16f001 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -31,7 +31,7 @@ # by a match for 'stuff'. # TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as # the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar" -SPLIT_FIELD_REGEX = re.compile(r'(? MAX_ALIAS_LENGTH: raise SynapseError( 400, - ("Can't create aliases longer than" - " %d characters" % (MAX_ALIAS_LENGTH,)), + ( + "Can't create aliases longer than" + " %d characters" % (MAX_ALIAS_LENGTH,) + ), Codes.INVALID_PARAM, ) @@ -76,11 +76,7 @@ def validate_builder(self, event): event (EventBuilder|FrozenEvent) """ - strings = [ - "room_id", - "sender", - "type", - ] + strings = ["room_id", "sender", "type"] if hasattr(event, "state_key"): strings.append("state_key") @@ -93,10 +89,7 @@ def validate_builder(self, event): UserID.from_string(event.sender) if event.type == EventTypes.Message: - strings = [ - "body", - "msgtype", - ] + strings = ["body", "msgtype"] self._ensure_strings(event.content, strings) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index fc5cfb7d83f0..58b929363fbf 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -44,8 +44,9 @@ def __init__(self, hs): self._clock = hs.get_clock() @defer.inlineCallbacks - def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version, - outlier=False, include_none=False): + def _check_sigs_and_hash_and_fetch( + self, origin, pdus, room_version, outlier=False, include_none=False + ): """Takes a list of PDUs and checks the signatures and hashs of each one. If a PDU fails its signature check then we check if we have it in the database and if not then request if from the originating server of @@ -79,9 +80,7 @@ def handle_check_result(pdu, deferred): if not res: # Check local db. res = yield self.store.get_event( - pdu.event_id, - allow_rejected=True, - allow_none=True, + pdu.event_id, allow_rejected=True, allow_none=True ) if not res and pdu.origin != origin: @@ -98,23 +97,16 @@ def handle_check_result(pdu, deferred): if not res: logger.warn( - "Failed to find copy of %s with valid signature", - pdu.event_id, + "Failed to find copy of %s with valid signature", pdu.event_id ) defer.returnValue(res) handle = logcontext.preserve_fn(handle_check_result) - deferreds2 = [ - handle(pdu, deferred) - for pdu, deferred in zip(pdus, deferreds) - ] + deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)] valid_pdus = yield logcontext.make_deferred_yieldable( - defer.gatherResults( - deferreds2, - consumeErrors=True, - ) + defer.gatherResults(deferreds2, consumeErrors=True) ).addErrback(unwrapFirstError) if include_none: @@ -124,7 +116,7 @@ def handle_check_result(pdu, deferred): def _check_sigs_and_hash(self, room_version, pdu): return logcontext.make_deferred_yieldable( - self._check_sigs_and_hashes(room_version, [pdu])[0], + self._check_sigs_and_hashes(room_version, [pdu])[0] ) def _check_sigs_and_hashes(self, room_version, pdus): @@ -159,11 +151,9 @@ def callback(_, pdu): # received event was probably a redacted copy (but we then use our # *actual* redacted copy to be on the safe side.) redacted_event = prune_event(pdu) - if ( - set(redacted_event.keys()) == set(pdu.keys()) and - set(six.iterkeys(redacted_event.content)) - == set(six.iterkeys(pdu.content)) - ): + if set(redacted_event.keys()) == set(pdu.keys()) and set( + six.iterkeys(redacted_event.content) + ) == set(six.iterkeys(pdu.content)): logger.info( "Event %s seems to have been redacted; using our redacted " "copy", @@ -172,14 +162,16 @@ def callback(_, pdu): else: logger.warning( "Event %s content has been tampered, redacting", - pdu.event_id, pdu.get_pdu_json(), + pdu.event_id, + pdu.get_pdu_json(), ) return redacted_event if self.spam_checker.check_event_for_spam(pdu): logger.warn( "Event contains spam, redacting %s: %s", - pdu.event_id, pdu.get_pdu_json() + pdu.event_id, + pdu.get_pdu_json(), ) return prune_event(pdu) @@ -190,23 +182,24 @@ def errback(failure, pdu): with logcontext.PreserveLoggingContext(ctx): logger.warn( "Signature check failed for %s: %s", - pdu.event_id, failure.getErrorMessage(), + pdu.event_id, + failure.getErrorMessage(), ) return failure for deferred, pdu in zip(deferreds, pdus): deferred.addCallbacks( - callback, errback, - callbackArgs=[pdu], - errbackArgs=[pdu], + callback, errback, callbackArgs=[pdu], errbackArgs=[pdu] ) return deferreds -class PduToCheckSig(namedtuple("PduToCheckSig", [ - "pdu", "redacted_pdu_json", "sender_domain", "deferreds", -])): +class PduToCheckSig( + namedtuple( + "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"] + ) +): pass @@ -260,10 +253,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): # First we check that the sender event is signed by the sender's domain # (except if its a 3pid invite, in which case it may be sent by any server) - pdus_to_check_sender = [ - p for p in pdus_to_check - if not _is_invite_via_3pid(p.pdu) - ] + pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)] more_deferreds = keyring.verify_json_objects_for_server( [ @@ -297,7 +287,8 @@ def sender_err(e, pdu_to_check): # (ie, the room version uses old-style non-hash event IDs). if v.event_format == EventFormatVersions.V1: pdus_to_check_event_id = [ - p for p in pdus_to_check + p + for p in pdus_to_check if p.sender_domain != get_domain_from_id(p.pdu.event_id) ] @@ -315,10 +306,8 @@ def sender_err(e, pdu_to_check): def event_err(e, pdu_to_check): errmsg = ( - "event id %s: unable to verify signature for event id domain: %s" % ( - pdu_to_check.pdu.event_id, - e.getErrorMessage(), - ) + "event id %s: unable to verify signature for event id domain: %s" + % (pdu_to_check.pdu.event_id, e.getErrorMessage()) ) # XX as above: not really sure if these are the right codes raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) @@ -368,21 +357,18 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False): """ # we could probably enforce a bunch of other fields here (room_id, sender, # origin, etc etc) - assert_params_in_dict(pdu_json, ('type', 'depth')) + assert_params_in_dict(pdu_json, ("type", "depth")) - depth = pdu_json['depth'] + depth = pdu_json["depth"] if not isinstance(depth, six.integer_types): - raise SynapseError(400, "Depth %r not an intger" % (depth, ), - Codes.BAD_JSON) + raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: raise SynapseError(400, "Depth too small", Codes.BAD_JSON) elif depth > MAX_DEPTH: raise SynapseError(400, "Depth too large", Codes.BAD_JSON) - event = event_type_from_format_version(event_format_version)( - pdu_json, - ) + event = event_type_from_format_version(event_format_version)(pdu_json) event.internal_metadata.outlier = outlier diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 70573746d6c5..3883eb525eed 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -57,6 +57,7 @@ class InvalidResponseError(RuntimeError): """Helper for _try_destination_list: indicates that the server returned a response we couldn't parse """ + pass @@ -65,9 +66,7 @@ def __init__(self, hs): super(FederationClient, self).__init__(hs) self.pdu_destination_tried = {} - self._clock.looping_call( - self._clear_tried_cache, 60 * 1000, - ) + self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() @@ -99,8 +98,14 @@ def _clear_tried_cache(self): self.pdu_destination_tried[event_id] = destination_dict @log_function - def make_query(self, destination, query_type, args, - retry_on_dns_fail=False, ignore_backoff=False): + def make_query( + self, + destination, + query_type, + args, + retry_on_dns_fail=False, + ignore_backoff=False, + ): """Sends a federation Query to a remote homeserver of the given type and arguments. @@ -120,7 +125,10 @@ def make_query(self, destination, query_type, args, sent_queries_counter.labels(query_type).inc() return self.transport_layer.make_query( - destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail, + destination, + query_type, + args, + retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff, ) @@ -137,9 +145,7 @@ def query_client_keys(self, destination, content, timeout): response """ sent_queries_counter.labels("client_device_keys").inc() - return self.transport_layer.query_client_keys( - destination, content, timeout - ) + return self.transport_layer.query_client_keys(destination, content, timeout) @log_function def query_user_devices(self, destination, user_id, timeout=30000): @@ -147,9 +153,7 @@ def query_user_devices(self, destination, user_id, timeout=30000): server. """ sent_queries_counter.labels("user_devices").inc() - return self.transport_layer.query_user_devices( - destination, user_id, timeout - ) + return self.transport_layer.query_user_devices(destination, user_id, timeout) @log_function def claim_client_keys(self, destination, content, timeout): @@ -164,9 +168,7 @@ def claim_client_keys(self, destination, content, timeout): response """ sent_queries_counter.labels("client_one_time_keys").inc() - return self.transport_layer.claim_client_keys( - destination, content, timeout - ) + return self.transport_layer.claim_client_keys(destination, content, timeout) @defer.inlineCallbacks @log_function @@ -191,7 +193,8 @@ def backfill(self, dest, room_id, limit, extremities): return transaction_data = yield self.transport_layer.backfill( - dest, room_id, extremities, limit) + dest, room_id, extremities, limit + ) logger.debug("backfill transaction_data=%s", repr(transaction_data)) @@ -204,17 +207,19 @@ def backfill(self, dest, room_id, limit, extremities): ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults( - self._check_sigs_and_hashes(room_version, pdus), - consumeErrors=True, - ).addErrback(unwrapFirstError)) + pdus[:] = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True + ).addErrback(unwrapFirstError) + ) defer.returnValue(pdus) @defer.inlineCallbacks @log_function - def get_pdu(self, destinations, event_id, room_version, outlier=False, - timeout=None): + def get_pdu( + self, destinations, event_id, room_version, outlier=False, timeout=None + ): """Requests the PDU with given origin and ID from the remote home servers. @@ -255,7 +260,7 @@ def get_pdu(self, destinations, event_id, room_version, outlier=False, try: transaction_data = yield self.transport_layer.get_event( - destination, event_id, timeout=timeout, + destination, event_id, timeout=timeout ) logger.debug( @@ -282,8 +287,7 @@ def get_pdu(self, destinations, event_id, room_version, outlier=False, except SynapseError as e: logger.info( - "Failed to get PDU %s from %s because %s", - event_id, destination, e, + "Failed to get PDU %s from %s because %s", event_id, destination, e ) continue except NotRetryingDestination as e: @@ -296,8 +300,7 @@ def get_pdu(self, destinations, event_id, room_version, outlier=False, pdu_attempts[destination] = now logger.info( - "Failed to get PDU %s from %s because %s", - event_id, destination, e, + "Failed to get PDU %s from %s because %s", event_id, destination, e ) continue @@ -326,7 +329,7 @@ def get_state_for_room(self, destination, room_id, event_id): # we have most of the state and auth_chain already. # However, this may 404 if the other side has an old synapse. result = yield self.transport_layer.get_room_state_ids( - destination, room_id, event_id=event_id, + destination, room_id, event_id=event_id ) state_event_ids = result["pdu_ids"] @@ -340,12 +343,10 @@ def get_state_for_room(self, destination, room_id, event_id): logger.warning( "Failed to fetch missing state/auth events for %s: %s", room_id, - failed_to_fetch + failed_to_fetch, ) - event_map = { - ev.event_id: ev for ev in fetched_events - } + event_map = {ev.event_id: ev for ev in fetched_events} pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] auth_chain = [ @@ -362,15 +363,14 @@ def get_state_for_room(self, destination, room_id, event_id): raise e result = yield self.transport_layer.get_room_state( - destination, room_id, event_id=event_id, + destination, room_id, event_id=event_id ) room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdus = [ - event_from_pdu_json(p, format_ver, outlier=True) - for p in result["pdus"] + event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"] ] auth_chain = [ @@ -378,9 +378,9 @@ def get_state_for_room(self, destination, room_id, event_id): for p in result.get("auth_chain", []) ] - seen_events = yield self.store.get_events([ - ev.event_id for ev in itertools.chain(pdus, auth_chain) - ]) + seen_events = yield self.store.get_events( + [ev.event_id for ev in itertools.chain(pdus, auth_chain)] + ) signed_pdus = yield self._check_sigs_and_hash_and_fetch( destination, @@ -442,7 +442,7 @@ def get_events_from_store_or_dest(self, destination, room_id, event_ids): batch_size = 20 missing_events = list(missing_events) for i in range(0, len(missing_events), batch_size): - batch = set(missing_events[i:i + batch_size]) + batch = set(missing_events[i : i + batch_size]) deferreds = [ run_in_background( @@ -470,21 +470,17 @@ def get_events_from_store_or_dest(self, destination, room_id, event_ids): @defer.inlineCallbacks @log_function def get_event_auth(self, destination, room_id, event_id): - res = yield self.transport_layer.get_event_auth( - destination, room_id, event_id, - ) + res = yield self.transport_layer.get_event_auth(destination, room_id, event_id) room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ - event_from_pdu_json(p, format_ver, outlier=True) - for p in res["auth_chain"] + event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"] ] signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, auth_chain, - outlier=True, room_version=room_version, + destination, auth_chain, outlier=True, room_version=room_version ) signed_auth.sort(key=lambda e: e.depth) @@ -527,28 +523,26 @@ def _try_destination_list(self, description, destinations, callback): res = yield callback(destination) defer.returnValue(res) except InvalidResponseError as e: - logger.warn( - "Failed to %s via %s: %s", - description, destination, e, - ) + logger.warn("Failed to %s via %s: %s", description, destination, e) except HttpResponseException as e: if not 500 <= e.code < 600: raise e.to_synapse_error() else: logger.warn( "Failed to %s via %s: %i %s", - description, destination, e.code, e.args[0], + description, + destination, + e.code, + e.args[0], ) except Exception: - logger.warn( - "Failed to %s via %s", - description, destination, exc_info=1, - ) + logger.warn("Failed to %s via %s", description, destination, exc_info=1) - raise RuntimeError("Failed to %s via any server" % (description, )) + raise RuntimeError("Failed to %s via any server" % (description,)) - def make_membership_event(self, destinations, room_id, user_id, membership, - content, params): + def make_membership_event( + self, destinations, room_id, user_id, membership, content, params + ): """ Creates an m.room.member event, with context, without participating in the room. @@ -584,14 +578,14 @@ def make_membership_event(self, destinations, room_id, user_id, membership, valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: raise RuntimeError( - "make_membership_event called with membership='%s', must be one of %s" % - (membership, ",".join(valid_memberships)) + "make_membership_event called with membership='%s', must be one of %s" + % (membership, ",".join(valid_memberships)) ) @defer.inlineCallbacks def send_request(destination): ret = yield self.transport_layer.make_membership_event( - destination, room_id, user_id, membership, params, + destination, room_id, user_id, membership, params ) # Note: If not supplied, the room version may be either v1 or v2, @@ -614,16 +608,17 @@ def send_request(destination): pdu_dict["prev_state"] = [] ev = builder.create_local_event_from_event_dict( - self._clock, self.hostname, self.signing_key, - format_version=event_format, event_dict=pdu_dict, + self._clock, + self.hostname, + self.signing_key, + format_version=event_format, + event_dict=pdu_dict, ) - defer.returnValue( - (destination, ev, event_format) - ) + defer.returnValue((destination, ev, event_format)) return self._try_destination_list( - "make_" + membership, destinations, send_request, + "make_" + membership, destinations, send_request ) def send_join(self, destinations, pdu, event_format_version): @@ -655,9 +650,7 @@ def check_authchain_validity(signed_auth_chain): create_event = e break else: - raise InvalidResponseError( - "no %s in auth chain" % (EventTypes.Create,), - ) + raise InvalidResponseError("no %s in auth chain" % (EventTypes.Create,)) # the room version should be sane. room_version = create_event.content.get("room_version", "1") @@ -665,9 +658,8 @@ def check_authchain_validity(signed_auth_chain): # This shouldn't be possible, because the remote server should have # rejected the join attempt during make_join. raise InvalidResponseError( - "room appears to have unsupported version %s" % ( - room_version, - )) + "room appears to have unsupported version %s" % (room_version,) + ) @defer.inlineCallbacks def send_request(destination): @@ -691,10 +683,7 @@ def send_request(destination): for p in content.get("auth_chain", []) ] - pdus = { - p.event_id: p - for p in itertools.chain(state, auth_chain) - } + pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)} room_version = None for e in state: @@ -710,15 +699,13 @@ def send_request(destination): raise SynapseError(400, "No create event in state") valid_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, list(pdus.values()), + destination, + list(pdus.values()), outlier=True, room_version=room_version, ) - valid_pdus_map = { - p.event_id: p - for p in valid_pdus - } + valid_pdus_map = {p.event_id: p for p in valid_pdus} # NB: We *need* to copy to ensure that we don't have multiple # references being passed on, as that causes... issues. @@ -741,11 +728,14 @@ def send_request(destination): check_authchain_validity(signed_auth) - defer.returnValue({ - "state": signed_state, - "auth_chain": signed_auth, - "origin": destination, - }) + defer.returnValue( + { + "state": signed_state, + "auth_chain": signed_auth, + "origin": destination, + } + ) + return self._try_destination_list("send_join", destinations, send_request) @defer.inlineCallbacks @@ -854,6 +844,7 @@ def send_leave(self, destinations, pdu): Fails with a ``RuntimeError`` if no servers were reachable. """ + @defer.inlineCallbacks def send_request(destination): time_now = self._clock.time_msec() @@ -869,14 +860,23 @@ def send_request(destination): return self._try_destination_list("send_leave", destinations, send_request) - def get_public_rooms(self, destination, limit=None, since_token=None, - search_filter=None, include_all_networks=False, - third_party_instance_id=None): + def get_public_rooms( + self, + destination, + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): if destination == self.server_name: return return self.transport_layer.get_public_rooms( - destination, limit, since_token, search_filter, + destination, + limit, + since_token, + search_filter, include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) @@ -891,9 +891,7 @@ def query_auth(self, destination, room_id, event_id, local_auth): """ time_now = self._clock.time_msec() - send_content = { - "auth_chain": [e.get_pdu_json(time_now) for e in local_auth], - } + send_content = {"auth_chain": [e.get_pdu_json(time_now) for e in local_auth]} code, content = yield self.transport_layer.send_query_auth( destination=destination, @@ -905,13 +903,10 @@ def query_auth(self, destination, room_id, event_id, local_auth): room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) - auth_chain = [ - event_from_pdu_json(e, format_ver) - for e in content["auth_chain"] - ] + auth_chain = [event_from_pdu_json(e, format_ver) for e in content["auth_chain"]] signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True, room_version=room_version, + destination, auth_chain, outlier=True, room_version=room_version ) signed_auth.sort(key=lambda e: e.depth) @@ -925,8 +920,16 @@ def query_auth(self, destination, room_id, event_id, local_auth): defer.returnValue(ret) @defer.inlineCallbacks - def get_missing_events(self, destination, room_id, earliest_events_ids, - latest_events, limit, min_depth, timeout): + def get_missing_events( + self, + destination, + room_id, + earliest_events_ids, + latest_events, + limit, + min_depth, + timeout, + ): """Tries to fetch events we are missing. This is called when we receive an event without having received all of its ancestors. @@ -957,12 +960,11 @@ def get_missing_events(self, destination, room_id, earliest_events_ids, format_ver = room_version_to_event_format(room_version) events = [ - event_from_pdu_json(e, format_ver) - for e in content.get("events", []) + event_from_pdu_json(e, format_ver) for e in content.get("events", []) ] signed_events = yield self._check_sigs_and_hash_and_fetch( - destination, events, outlier=False, room_version=room_version, + destination, events, outlier=False, room_version=room_version ) except HttpResponseException as e: if not e.code == 400: @@ -982,17 +984,14 @@ def forward_third_party_invite(self, destinations, room_id, event_dict): try: yield self.transport_layer.exchange_third_party_invite( - destination=destination, - room_id=room_id, - event_dict=event_dict, + destination=destination, room_id=room_id, event_dict=event_dict ) defer.returnValue(None) except CodeMessageException: raise except Exception as e: logger.exception( - "Failed to send_third_party_invite via %s: %s", - destination, str(e) + "Failed to send_third_party_invite via %s: %s", destination, str(e) ) raise RuntimeError("Failed to send to any server.") diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 4c28c1dc3cdf..2e0cebb638c4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -69,7 +69,6 @@ class FederationServer(FederationBase): - def __init__(self, hs): super(FederationServer, self).__init__(hs) @@ -118,11 +117,13 @@ def on_incoming_transaction(self, origin, transaction_data): # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. - with (yield self._transaction_linearizer.queue( - (origin, transaction.transaction_id), - )): + with ( + yield self._transaction_linearizer.queue( + (origin, transaction.transaction_id) + ) + ): result = yield self._handle_incoming_transaction( - origin, transaction, request_time, + origin, transaction, request_time ) defer.returnValue(result) @@ -144,7 +145,7 @@ def _handle_incoming_transaction(self, origin, transaction, request_time): if response: logger.debug( "[%s] We've already responded to this request", - transaction.transaction_id + transaction.transaction_id, ) defer.returnValue(response) return @@ -152,18 +153,15 @@ def _handle_incoming_transaction(self, origin, transaction, request_time): logger.debug("[%s] Transaction is new", transaction.transaction_id) # Reject if PDU count > 50 and EDU count > 100 - if (len(transaction.pdus) > 50 - or (hasattr(transaction, "edus") and len(transaction.edus) > 100)): + if len(transaction.pdus) > 50 or ( + hasattr(transaction, "edus") and len(transaction.edus) > 100 + ): - logger.info( - "Transaction PDU or EDU count too large. Returning 400", - ) + logger.info("Transaction PDU or EDU count too large. Returning 400") response = {} yield self.transaction_actions.set_response( - origin, - transaction, - 400, response + origin, transaction, 400, response ) defer.returnValue((400, response)) @@ -230,9 +228,7 @@ def process_pdus_for_room(room_id): try: yield self.check_server_matches_acl(origin_host, room_id) except AuthError as e: - logger.warn( - "Ignoring PDUs for room %s from banned server", room_id, - ) + logger.warn("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: event_id = pdu.event_id pdu_results[event_id] = e.error_dict() @@ -242,9 +238,7 @@ def process_pdus_for_room(room_id): event_id = pdu.event_id with nested_logging_context(event_id): try: - yield self._handle_received_pdu( - origin, pdu - ) + yield self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: logger.warn("Error handling PDU %s: %s", event_id, e) @@ -259,29 +253,18 @@ def process_pdus_for_room(room_id): ) yield concurrently_execute( - process_pdus_for_room, pdus_by_room.keys(), - TRANSACTION_CONCURRENCY_LIMIT, + process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): - yield self.received_edu( - origin, - edu.edu_type, - edu.content - ) + yield self.received_edu(origin, edu.edu_type, edu.content) - response = { - "pdus": pdu_results, - } + response = {"pdus": pdu_results} logger.debug("Returning: %s", str(response)) - yield self.transaction_actions.set_response( - origin, - transaction, - 200, response - ) + yield self.transaction_actions.set_response(origin, transaction, 200, response) defer.returnValue((200, response)) @defer.inlineCallbacks @@ -311,7 +294,8 @@ def on_context_state_request(self, origin, room_id, event_id): resp = yield self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, - room_id, event_id, + room_id, + event_id, ) defer.returnValue((200, resp)) @@ -328,24 +312,17 @@ def on_state_ids_request(self, origin, room_id, event_id): if not in_room: raise AuthError(403, "Host not in room.") - state_ids = yield self.handler.get_state_ids_for_pdu( - room_id, event_id, - ) + state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) - defer.returnValue((200, { - "pdu_ids": state_ids, - "auth_chain_ids": auth_chain_ids, - })) + defer.returnValue( + (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}) + ) @defer.inlineCallbacks def _on_context_state_request_compute(self, room_id, event_id): - pdus = yield self.handler.get_state_for_pdu( - room_id, event_id, - ) - auth_chain = yield self.store.get_auth_chain( - [pdu.event_id for pdu in pdus] - ) + pdus = yield self.handler.get_state_for_pdu(room_id, event_id) + auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus]) for event in auth_chain: # We sign these again because there was a bug where we @@ -355,14 +332,16 @@ def _on_context_state_request_compute(self, room_id, event_id): compute_event_signature( event.get_pdu_json(), self.hs.hostname, - self.hs.config.signing_key[0] + self.hs.config.signing_key[0], ) ) - defer.returnValue({ - "pdus": [pdu.get_pdu_json() for pdu in pdus], - "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], - }) + defer.returnValue( + { + "pdus": [pdu.get_pdu_json() for pdu in pdus], + "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], + } + ) @defer.inlineCallbacks @log_function @@ -370,9 +349,7 @@ def on_pdu_request(self, origin, event_id): pdu = yield self.handler.get_persisted_pdu(origin, event_id) if pdu: - defer.returnValue( - (200, self._transaction_from_pdus([pdu]).get_dict()) - ) + defer.returnValue((200, self._transaction_from_pdus([pdu]).get_dict())) else: defer.returnValue((404, "")) @@ -394,10 +371,9 @@ def on_make_join_request(self, origin, room_id, user_id, supported_versions): pdu = yield self.handler.on_make_join_request(room_id, user_id) time_now = self._clock.time_msec() - defer.returnValue({ - "event": pdu.get_pdu_json(time_now), - "room_version": room_version, - }) + defer.returnValue( + {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + ) @defer.inlineCallbacks def on_invite_request(self, origin, content, room_version): @@ -431,12 +407,17 @@ def on_send_join_request(self, origin, content, room_id): logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) res_pdus = yield self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() - defer.returnValue((200, { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], - })) + defer.returnValue( + ( + 200, + { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [ + p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] + ], + }, + ) + ) @defer.inlineCallbacks def on_make_leave_request(self, origin, room_id, user_id): @@ -447,10 +428,9 @@ def on_make_leave_request(self, origin, room_id, user_id): room_version = yield self.store.get_room_version(room_id) time_now = self._clock.time_msec() - defer.returnValue({ - "event": pdu.get_pdu_json(time_now), - "room_version": room_version, - }) + defer.returnValue( + {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + ) @defer.inlineCallbacks def on_send_leave_request(self, origin, content, room_id): @@ -475,9 +455,7 @@ def on_event_auth(self, origin, room_id, event_id): time_now = self._clock.time_msec() auth_pdus = yield self.handler.on_event_auth(event_id) - res = { - "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], - } + res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} defer.returnValue((200, res)) @defer.inlineCallbacks @@ -508,12 +486,11 @@ def on_query_auth_request(self, origin, content, room_id, event_id): format_ver = room_version_to_event_format(room_version) auth_chain = [ - event_from_pdu_json(e, format_ver) - for e in content["auth_chain"] + event_from_pdu_json(e, format_ver) for e in content["auth_chain"] ] signed_auth = yield self._check_sigs_and_hash_and_fetch( - origin, auth_chain, outlier=True, room_version=room_version, + origin, auth_chain, outlier=True, room_version=room_version ) ret = yield self.handler.on_query_auth( @@ -527,17 +504,12 @@ def on_query_auth_request(self, origin, content, room_id, event_id): time_now = self._clock.time_msec() send_content = { - "auth_chain": [ - e.get_pdu_json(time_now) - for e in ret["auth_chain"] - ], + "auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]], "rejects": ret.get("rejects", []), "missing": ret.get("missing", []), } - defer.returnValue( - (200, send_content) - ) + defer.returnValue((200, send_content)) @log_function def on_query_client_keys(self, origin, content): @@ -566,20 +538,23 @@ def on_claim_client_keys(self, origin, content): logger.info( "Claimed one-time-keys: %s", - ",".join(( - "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in iteritems(json_result) - for device_id, device_keys in iteritems(user_keys) - for key_id, _ in iteritems(device_keys) - )), + ",".join( + ( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in iteritems(json_result) + for device_id, device_keys in iteritems(user_keys) + for key_id, _ in iteritems(device_keys) + ) + ), ) defer.returnValue({"one_time_keys": json_result}) @defer.inlineCallbacks @log_function - def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit): + def on_get_missing_events( + self, origin, room_id, earliest_events, latest_events, limit + ): with (yield self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) yield self.check_server_matches_acl(origin_host, room_id) @@ -587,11 +562,13 @@ def on_get_missing_events(self, origin, room_id, earliest_events, logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," " limit: %d", - earliest_events, latest_events, limit, + earliest_events, + latest_events, + limit, ) missing_events = yield self.handler.on_get_missing_events( - origin, room_id, earliest_events, latest_events, limit, + origin, room_id, earliest_events, latest_events, limit ) if len(missing_events) < 5: @@ -603,9 +580,9 @@ def on_get_missing_events(self, origin, room_id, earliest_events, time_now = self._clock.time_msec() - defer.returnValue({ - "events": [ev.get_pdu_json(time_now) for ev in missing_events], - }) + defer.returnValue( + {"events": [ev.get_pdu_json(time_now) for ev in missing_events]} + ) @log_function def on_openid_userinfo(self, token): @@ -666,22 +643,17 @@ def _handle_received_pdu(self, origin, pdu): # origin. See bug #1893. This is also true for some third party # invites). if not ( - pdu.type == 'm.room.member' and - pdu.content and - pdu.content.get("membership", None) in ( - Membership.JOIN, Membership.INVITE, - ) + pdu.type == "m.room.member" + and pdu.content + and pdu.content.get("membership", None) + in (Membership.JOIN, Membership.INVITE) ): logger.info( - "Discarding PDU %s from invalid origin %s", - pdu.event_id, origin + "Discarding PDU %s from invalid origin %s", pdu.event_id, origin ) return else: - logger.info( - "Accepting join PDU %s from %s", - pdu.event_id, origin - ) + logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point room_version = yield self.store.get_room_version(pdu.room_id) @@ -690,33 +662,19 @@ def _handle_received_pdu(self, origin, pdu): try: pdu = yield self._check_sigs_and_hash(room_version, pdu) except SynapseError as e: - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=pdu.event_id, - ) + raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - yield self.handler.on_receive_pdu( - origin, pdu, sent_to_us_directly=True, - ) + yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) def __str__(self): return "" % self.server_name @defer.inlineCallbacks def exchange_third_party_invite( - self, - sender_user_id, - target_user_id, - room_id, - signed, + self, sender_user_id, target_user_id, room_id, signed ): ret = yield self.handler.exchange_third_party_invite( - sender_user_id, - target_user_id, - room_id, - signed, + sender_user_id, target_user_id, room_id, signed ) defer.returnValue(ret) @@ -771,7 +729,7 @@ def server_matches_acl_event(server_name, acl_event): allow_ip_literals = True if not allow_ip_literals: # check for ipv6 literals. These start with '['. - if server_name[0] == '[': + if server_name[0] == "[": return False # check for ipv4 literals. We can just lift the routine from twisted. @@ -805,7 +763,9 @@ def server_matches_acl_event(server_name, acl_event): def _acl_entry_matches(server_name, acl_entry): if not isinstance(acl_entry, six.string_types): - logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)) + logger.warn( + "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) + ) return False regex = glob_to_regex(acl_entry) return regex.match(server_name) @@ -815,6 +775,7 @@ class FederationHandlerRegistry(object): """Allows classes to register themselves as handlers for a given EDU or query type for incoming federation traffic. """ + def __init__(self): self.edu_handlers = {} self.query_handlers = {} @@ -848,9 +809,7 @@ def register_query_handler(self, query_type, handler): on and the result used as the response to the query request. """ if query_type in self.query_handlers: - raise KeyError( - "Already have a Query handler for %s" % (query_type,) - ) + raise KeyError("Already have a Query handler for %s" % (query_type,)) logger.info("Registering federation query handler for %r", query_type) @@ -905,14 +864,10 @@ def on_edu(self, edu_type, origin, content): handler = self.edu_handlers.get(edu_type) if handler: return super(ReplicationFederationHandlerRegistry, self).on_edu( - edu_type, origin, content, + edu_type, origin, content ) - return self._send_edu( - edu_type=edu_type, - origin=origin, - content=content, - ) + return self._send_edu(edu_type=edu_type, origin=origin, content=content) def on_query(self, query_type, args): """Overrides FederationHandlerRegistry @@ -921,7 +876,4 @@ def on_query(self, query_type, args): if handler: return handler(args) - return self._get_query_client( - query_type=query_type, - args=args, - ) + return self._get_query_client(query_type=query_type, args=args) diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 74ffd13b4f89..7535f79203b3 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -46,12 +46,9 @@ def have_responded(self, origin, transaction): response code and response body. """ if not transaction.transaction_id: - raise RuntimeError("Cannot persist a transaction with no " - "transaction_id") + raise RuntimeError("Cannot persist a transaction with no " "transaction_id") - return self.store.get_received_txn_response( - transaction.transaction_id, origin - ) + return self.store.get_received_txn_response(transaction.transaction_id, origin) @log_function def set_response(self, origin, transaction, code, response): @@ -61,14 +58,10 @@ def set_response(self, origin, transaction, code, response): Deferred """ if not transaction.transaction_id: - raise RuntimeError("Cannot persist a transaction with no " - "transaction_id") + raise RuntimeError("Cannot persist a transaction with no " "transaction_id") return self.store.set_received_txn_response( - transaction.transaction_id, - origin, - code, - response, + transaction.transaction_id, origin, code, response ) @defer.inlineCallbacks diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 0240b339b061..454456a52d6f 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -77,12 +77,22 @@ def __init__(self, hs): # lambda binds to the queue rather than to the name of the queue which # changes. ARGH. def register(name, queue): - LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,), - "", [], lambda: len(queue)) + LaterGauge( + "synapse_federation_send_queue_%s_size" % (queue_name,), + "", + [], + lambda: len(queue), + ) for queue_name in [ - "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", - "edus", "device_messages", "pos_time", "presence_destinations", + "presence_map", + "presence_changed", + "keyed_edu", + "keyed_edu_changed", + "edus", + "device_messages", + "pos_time", + "presence_destinations", ]: register(queue_name, getattr(self, queue_name)) @@ -121,9 +131,7 @@ def _clear_queue_before_pos(self, position_to_delete): del self.presence_changed[key] user_ids = set( - user_id - for uids in self.presence_changed.values() - for user_id in uids + user_id for uids in self.presence_changed.values() for user_id in uids ) keys = self.presence_destinations.keys() @@ -285,19 +293,21 @@ def get_replication_rows(self, from_token, to_token, limit, federation_ack=None) ] for (key, user_id) in dest_user_ids: - rows.append((key, PresenceRow( - state=self.presence_map[user_id], - ))) + rows.append((key, PresenceRow(state=self.presence_map[user_id]))) # Fetch presence to send to destinations i = self.presence_destinations.bisect_right(from_token) j = self.presence_destinations.bisect_right(to_token) + 1 for pos, (user_id, dests) in self.presence_destinations.items()[i:j]: - rows.append((pos, PresenceDestinationsRow( - state=self.presence_map[user_id], - destinations=list(dests), - ))) + rows.append( + ( + pos, + PresenceDestinationsRow( + state=self.presence_map[user_id], destinations=list(dests) + ), + ) + ) # Fetch changes keyed edus i = self.keyed_edu_changed.bisect_right(from_token) @@ -308,10 +318,14 @@ def get_replication_rows(self, from_token, to_token, limit, federation_ack=None) keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]} for ((destination, edu_key), pos) in iteritems(keyed_edus): - rows.append((pos, KeyedEduRow( - key=edu_key, - edu=self.keyed_edu[(destination, edu_key)], - ))) + rows.append( + ( + pos, + KeyedEduRow( + key=edu_key, edu=self.keyed_edu[(destination, edu_key)] + ), + ) + ) # Fetch changed edus i = self.edus.bisect_right(from_token) @@ -327,9 +341,7 @@ def get_replication_rows(self, from_token, to_token, limit, federation_ack=None) device_messages = {v: k for k, v in self.device_messages.items()[i:j]} for (destination, pos) in iteritems(device_messages): - rows.append((pos, DeviceRow( - destination=destination, - ))) + rows.append((pos, DeviceRow(destination=destination))) # Sort rows based on pos rows.sort() @@ -377,16 +389,14 @@ def add_to_buffer(self, buff): raise NotImplementedError() -class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", ( - "state", # UserPresenceState -))): +class PresenceRow( + BaseFederationRow, namedtuple("PresenceRow", ("state",)) # UserPresenceState +): TypeId = "p" @staticmethod def from_data(data): - return PresenceRow( - state=UserPresenceState.from_dict(data) - ) + return PresenceRow(state=UserPresenceState.from_dict(data)) def to_data(self): return self.state.as_dict() @@ -395,33 +405,35 @@ def add_to_buffer(self, buff): buff.presence.append(self.state) -class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", ( - "state", # UserPresenceState - "destinations", # list[str] -))): +class PresenceDestinationsRow( + BaseFederationRow, + namedtuple( + "PresenceDestinationsRow", + ("state", "destinations"), # UserPresenceState # list[str] + ), +): TypeId = "pd" @staticmethod def from_data(data): return PresenceDestinationsRow( - state=UserPresenceState.from_dict(data["state"]), - destinations=data["dests"], + state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"] ) def to_data(self): - return { - "state": self.state.as_dict(), - "dests": self.destinations, - } + return {"state": self.state.as_dict(), "dests": self.destinations} def add_to_buffer(self, buff): buff.presence_destinations.append((self.state, self.destinations)) -class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( - "key", # tuple(str) - the edu key passed to send_edu - "edu", # Edu -))): +class KeyedEduRow( + BaseFederationRow, + namedtuple( + "KeyedEduRow", + ("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu + ), +): """Streams EDUs that have an associated key that is ued to clobber. For example, typing EDUs clobber based on room_id. """ @@ -430,28 +442,19 @@ class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( @staticmethod def from_data(data): - return KeyedEduRow( - key=tuple(data["key"]), - edu=Edu(**data["edu"]), - ) + return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"])) def to_data(self): - return { - "key": self.key, - "edu": self.edu.get_internal_dict(), - } + return {"key": self.key, "edu": self.edu.get_internal_dict()} def add_to_buffer(self, buff): - buff.keyed_edus.setdefault( - self.edu.destination, {} - )[self.key] = self.edu + buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu -class EduRow(BaseFederationRow, namedtuple("EduRow", ( - "edu", # Edu -))): +class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu """Streams EDUs that don't have keys. See KeyedEduRow """ + TypeId = "e" @staticmethod @@ -465,13 +468,12 @@ def add_to_buffer(self, buff): buff.edus.setdefault(self.edu.destination, []).append(self.edu) -class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( - "destination", # str -))): +class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ("destination",))): # str """Streams the fact that either a) there is pending to device messages for users on the remote, or b) a local users device has changed and needs to be sent to the remote. """ + TypeId = "d" @staticmethod @@ -487,23 +489,20 @@ def add_to_buffer(self, buff): TypeToRow = { Row.TypeId: Row - for Row in ( - PresenceRow, - PresenceDestinationsRow, - KeyedEduRow, - EduRow, - DeviceRow, - ) + for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow, DeviceRow) } -ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( - "presence", # list(UserPresenceState) - "presence_destinations", # list of tuples of UserPresenceState and destinations - "keyed_edus", # dict of destination -> { key -> Edu } - "edus", # dict of destination -> [Edu] - "device_destinations", # set of destinations -)) +ParsedFederationStreamData = namedtuple( + "ParsedFederationStreamData", + ( + "presence", # list(UserPresenceState) + "presence_destinations", # list of tuples of UserPresenceState and destinations + "keyed_edus", # dict of destination -> { key -> Edu } + "edus", # dict of destination -> [Edu] + "device_destinations", # set of destinations + ), +) def process_rows_for_federation(transaction_queue, rows): @@ -542,7 +541,7 @@ def process_rows_for_federation(transaction_queue, rows): for state, destinations in buff.presence_destinations: transaction_queue.send_presence_to_destinations( - states=[state], destinations=destinations, + states=[state], destinations=destinations ) for destination, edu_map in iteritems(buff.keyed_edus): diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 4224b29ecf66..766c5a37cd7a 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -44,8 +44,8 @@ ) sent_pdus_destination_dist_total = Counter( - "synapse_federation_client_sent_pdu_destinations:total", "" - "Total number of PDUs queued for sending across all destinations", + "synapse_federation_client_sent_pdu_destinations:total", + "" "Total number of PDUs queued for sending across all destinations", ) @@ -63,14 +63,15 @@ def __init__(self, hs): self._transaction_manager = TransactionManager(hs) # map from destination to PerDestinationQueue - self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] + self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] LaterGauge( "synapse_federation_transaction_queue_pending_destinations", "", [], lambda: sum( - 1 for d in self._per_destination_queues.values() + 1 + for d in self._per_destination_queues.values() if d.transmission_loop_running ), ) @@ -108,8 +109,9 @@ def __init__(self, hs): # awaiting a call to flush_read_receipts_for_room. The presence of an entry # here for a given room means that we are rate-limiting RR flushes to that room, # and that there is a pending call to _flush_rrs_for_room in the system. - self._queues_awaiting_rr_flush_by_room = { - } # type: dict[str, set[PerDestinationQueue]] + self._queues_awaiting_rr_flush_by_room = ( + {} + ) # type: dict[str, set[PerDestinationQueue]] self._rr_txn_interval_per_room_ms = ( 1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second @@ -141,8 +143,7 @@ def notify_new_events(self, current_id): # fire off a processing loop in the background run_as_background_process( - "process_event_queue_for_federation", - self._process_event_queue_loop, + "process_event_queue_for_federation", self._process_event_queue_loop ) @defer.inlineCallbacks @@ -152,7 +153,7 @@ def _process_event_queue_loop(self): while True: last_token = yield self.store.get_federation_out_pos("events") next_token, events = yield self.store.get_all_new_events_stream( - last_token, self._last_poked_id, limit=100, + last_token, self._last_poked_id, limit=100 ) logger.debug("Handling %s -> %s", last_token, next_token) @@ -179,7 +180,7 @@ def handle_event(event): # banned then it won't receive the event because it won't # be in the room after the ban. destinations = yield self.state.get_current_hosts_in_room( - event.room_id, latest_event_ids=event.prev_event_ids(), + event.room_id, latest_event_ids=event.prev_event_ids() ) except Exception: logger.exception( @@ -209,37 +210,40 @@ def handle_room_events(events): for event in events: events_by_room.setdefault(event.room_id, []).append(event) - yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background(handle_room_events, evs) - for evs in itervalues(events_by_room) - ], - consumeErrors=True - )) - - yield self.store.update_federation_out_pos( - "events", next_token + yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + logcontext.run_in_background(handle_room_events, evs) + for evs in itervalues(events_by_room) + ], + consumeErrors=True, + ) ) + yield self.store.update_federation_out_pos("events", next_token) + if events: now = self.clock.time_msec() ts = yield self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_lag.labels( - "federation_sender").set(now - ts) + "federation_sender" + ).set(now - ts) synapse.metrics.event_processing_last_ts.labels( - "federation_sender").set(ts) + "federation_sender" + ).set(ts) events_processed_counter.inc(len(events)) - event_processing_loop_room_count.labels( - "federation_sender" - ).inc(len(events_by_room)) + event_processing_loop_room_count.labels("federation_sender").inc( + len(events_by_room) + ) event_processing_loop_counter.labels("federation_sender").inc() synapse.metrics.event_processing_positions.labels( - "federation_sender").set(next_token) + "federation_sender" + ).set(next_token) finally: self._is_processing = False @@ -312,9 +316,7 @@ def send_read_receipt(self, receipt): if not domains: return - queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get( - room_id - ) + queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(room_id) # if there is no flush yet scheduled, we will send out these receipts with # immediate flushes, and schedule the next flush for this room. @@ -377,10 +379,9 @@ def send_presence(self, states): # updates in quick succession are correctly handled. # We only want to send presence for our own users, so lets always just # filter here just in case. - self.pending_presence.update({ - state.user_id: state for state in states - if self.is_mine_id(state.user_id) - }) + self.pending_presence.update( + {state.user_id: state for state in states if self.is_mine_id(state.user_id)} + ) # We then handle the new pending presence in batches, first figuring # out the destinations we need to send each state to and then poking it diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 22a2735405df..9aab12c0d328 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -360,7 +360,7 @@ def _get_device_update_edus(self, limit): # Retrieve list of new device updates to send to the destination now_stream_id, results = yield self._store.get_devices_by_remote( - self._destination, last_device_list, limit=limit, + self._destination, last_device_list, limit=limit ) edus = [ Edu( @@ -381,10 +381,7 @@ def _get_to_device_message_edus(self, limit): last_device_stream_id = self._last_device_stream_id to_device_stream_id = self._store.get_to_device_stream_token() contents, stream_id = yield self._store.get_new_device_msgs_for_remote( - self._destination, - last_device_stream_id, - to_device_stream_id, - limit, + self._destination, last_device_stream_id, to_device_stream_id, limit ) edus = [ Edu( diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 35e6b8ff5bd3..c987bb9a0d3c 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -29,9 +29,10 @@ class TransactionManager(object): shared between PerDestinationQueue objects """ + def __init__(self, hs): self._server_name = hs.hostname - self.clock = hs.get_clock() # nb must be called this for @measure_func + self.clock = hs.get_clock() # nb must be called this for @measure_func self._store = hs.get_datastore() self._transaction_actions = TransactionActions(self._store) self._transport_layer = hs.get_federation_transport_client() @@ -55,9 +56,9 @@ def send_new_transaction(self, destination, pending_pdus, pending_edus): txn_id = str(self._next_txn_id) logger.debug( - "TX [%s] {%s} Attempting new transaction" - " (pdus: %d, edus: %d)", - destination, txn_id, + "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)", + destination, + txn_id, len(pdus), len(edus), ) @@ -79,9 +80,9 @@ def send_new_transaction(self, destination, pending_pdus, pending_edus): logger.debug("TX [%s] Persisted transaction", destination) logger.info( - "TX [%s] {%s} Sending transaction [%s]," - " (PDUs: %d, EDUs: %d)", - destination, txn_id, + "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)", + destination, + txn_id, transaction.transaction_id, len(pdus), len(edus), @@ -112,20 +113,12 @@ def json_data_cb(): response = e.response if e.code in (401, 404, 429) or 500 <= e.code: - logger.info( - "TX [%s] {%s} got %d response", - destination, txn_id, code - ) + logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) raise e - logger.info( - "TX [%s] {%s} got %d response", - destination, txn_id, code - ) + logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) - yield self._transaction_actions.delivered( - transaction, code, response - ) + yield self._transaction_actions.delivered(transaction, code, response) logger.debug("TX [%s] {%s} Marked as delivered", destination, txn_id) @@ -134,13 +127,18 @@ def json_data_cb(): if "error" in r: logger.warn( "TX [%s] {%s} Remote returned error for %s: %s", - destination, txn_id, e_id, r, + destination, + txn_id, + e_id, + r, ) else: for p in pdus: logger.warn( "TX [%s] {%s} Failed to send event %s", - destination, txn_id, p.event_id, + destination, + txn_id, + p.event_id, ) success = False diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index e424c40fdf59..aecd1423097b 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -48,12 +48,13 @@ def get_room_state(self, destination, room_id, event_id): Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_room_state dest=%s, room=%s", - destination, room_id) + logger.debug("get_room_state dest=%s, room=%s", destination, room_id) path = _create_v1_path("/state/%s", room_id) return self.client.get_json( - destination, path=path, args={"event_id": event_id}, + destination, + path=path, + args={"event_id": event_id}, try_trailing_slash_on_400=True, ) @@ -71,12 +72,13 @@ def get_room_state_ids(self, destination, room_id, event_id): Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_room_state_ids dest=%s, room=%s", - destination, room_id) + logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) path = _create_v1_path("/state_ids/%s", room_id) return self.client.get_json( - destination, path=path, args={"event_id": event_id}, + destination, + path=path, + args={"event_id": event_id}, try_trailing_slash_on_400=True, ) @@ -94,13 +96,11 @@ def get_event(self, destination, event_id, timeout=None): Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_pdu dest=%s, event_id=%s", - destination, event_id) + logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) path = _create_v1_path("/event/%s", event_id) return self.client.get_json( - destination, path=path, timeout=timeout, - try_trailing_slash_on_400=True, + destination, path=path, timeout=timeout, try_trailing_slash_on_400=True ) @log_function @@ -119,7 +119,10 @@ def backfill(self, destination, room_id, event_tuples, limit): """ logger.debug( "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s", - destination, room_id, repr(event_tuples), str(limit) + destination, + room_id, + repr(event_tuples), + str(limit), ) if not event_tuples: @@ -128,16 +131,10 @@ def backfill(self, destination, room_id, event_tuples, limit): path = _create_v1_path("/backfill/%s", room_id) - args = { - "v": event_tuples, - "limit": [str(limit)], - } + args = {"v": event_tuples, "limit": [str(limit)]} return self.client.get_json( - destination, - path=path, - args=args, - try_trailing_slash_on_400=True, + destination, path=path, args=args, try_trailing_slash_on_400=True ) @defer.inlineCallbacks @@ -163,7 +160,8 @@ def send_transaction(self, transaction, json_data_callback=None): """ logger.debug( "send_data dest=%s, txid=%s", - transaction.destination, transaction.transaction_id + transaction.destination, + transaction.transaction_id, ) if transaction.destination == self.server_name: @@ -189,8 +187,9 @@ def send_transaction(self, transaction, json_data_callback=None): @defer.inlineCallbacks @log_function - def make_query(self, destination, query_type, args, retry_on_dns_fail, - ignore_backoff=False): + def make_query( + self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False + ): path = _create_v1_path("/query/%s", query_type) content = yield self.client.get_json( @@ -235,8 +234,8 @@ def make_membership_event(self, destination, room_id, user_id, membership, param valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: raise RuntimeError( - "make_membership_event called with membership='%s', must be one of %s" % - (membership, ",".join(valid_memberships)) + "make_membership_event called with membership='%s', must be one of %s" + % (membership, ",".join(valid_memberships)) ) path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id) @@ -268,9 +267,7 @@ def send_join(self, destination, room_id, event_id, content): path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=content, + destination=destination, path=path, data=content ) defer.returnValue(response) @@ -284,7 +281,6 @@ def send_leave(self, destination, room_id, event_id, content): destination=destination, path=path, data=content, - # we want to do our best to send this through. The problem is # that if it fails, we won't retry it later, so if the remote # server was just having a momentary blip, the room will be out of @@ -300,10 +296,7 @@ def send_invite_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/invite/%s/%s", room_id, event_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) defer.returnValue(response) @@ -314,26 +307,27 @@ def send_invite_v2(self, destination, room_id, event_id, content): path = _create_v2_path("/invite/%s/%s", room_id, event_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) defer.returnValue(response) @defer.inlineCallbacks @log_function - def get_public_rooms(self, remote_server, limit, since_token, - search_filter=None, include_all_networks=False, - third_party_instance_id=None): + def get_public_rooms( + self, + remote_server, + limit, + since_token, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): path = _create_v1_path("/publicRooms") - args = { - "include_all_networks": "true" if include_all_networks else "false", - } + args = {"include_all_networks": "true" if include_all_networks else "false"} if third_party_instance_id: - args["third_party_instance_id"] = third_party_instance_id, + args["third_party_instance_id"] = (third_party_instance_id,) if limit: args["limit"] = [str(limit)] if since_token: @@ -342,10 +336,7 @@ def get_public_rooms(self, remote_server, limit, since_token, # TODO(erikj): Actually send the search_filter across federation. response = yield self.client.get_json( - destination=remote_server, - path=path, - args=args, - ignore_backoff=True, + destination=remote_server, path=path, args=args, ignore_backoff=True ) defer.returnValue(response) @@ -353,12 +344,10 @@ def get_public_rooms(self, remote_server, limit, since_token, @defer.inlineCallbacks @log_function def exchange_third_party_invite(self, destination, room_id, event_dict): - path = _create_v1_path("/exchange_third_party_invite/%s", room_id,) + path = _create_v1_path("/exchange_third_party_invite/%s", room_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=event_dict, + destination=destination, path=path, data=event_dict ) defer.returnValue(response) @@ -368,10 +357,7 @@ def exchange_third_party_invite(self, destination, room_id, event_dict): def get_event_auth(self, destination, room_id, event_id): path = _create_v1_path("/event_auth/%s/%s", room_id, event_id) - content = yield self.client.get_json( - destination=destination, - path=path, - ) + content = yield self.client.get_json(destination=destination, path=path) defer.returnValue(content) @@ -381,9 +367,7 @@ def send_query_auth(self, destination, room_id, event_id, content): path = _create_v1_path("/query_auth/%s/%s", room_id, event_id) content = yield self.client.post_json( - destination=destination, - path=path, - data=content, + destination=destination, path=path, data=content ) defer.returnValue(content) @@ -416,10 +400,7 @@ def query_client_keys(self, destination, query_content, timeout): path = _create_v1_path("/user/keys/query") content = yield self.client.post_json( - destination=destination, - path=path, - data=query_content, - timeout=timeout, + destination=destination, path=path, data=query_content, timeout=timeout ) defer.returnValue(content) @@ -443,9 +424,7 @@ def query_user_devices(self, destination, user_id, timeout): path = _create_v1_path("/user/devices/%s", user_id) content = yield self.client.get_json( - destination=destination, - path=path, - timeout=timeout, + destination=destination, path=path, timeout=timeout ) defer.returnValue(content) @@ -479,18 +458,23 @@ def claim_client_keys(self, destination, query_content, timeout): path = _create_v1_path("/user/keys/claim") content = yield self.client.post_json( - destination=destination, - path=path, - data=query_content, - timeout=timeout, + destination=destination, path=path, data=query_content, timeout=timeout ) defer.returnValue(content) @defer.inlineCallbacks @log_function - def get_missing_events(self, destination, room_id, earliest_events, - latest_events, limit, min_depth, timeout): - path = _create_v1_path("/get_missing_events/%s", room_id,) + def get_missing_events( + self, + destination, + room_id, + earliest_events, + latest_events, + limit, + min_depth, + timeout, + ): + path = _create_v1_path("/get_missing_events/%s", room_id) content = yield self.client.post_json( destination=destination, @@ -510,7 +494,7 @@ def get_missing_events(self, destination, room_id, earliest_events, def get_group_profile(self, destination, group_id, requester_user_id): """Get a group profile """ - path = _create_v1_path("/groups/%s/profile", group_id,) + path = _create_v1_path("/groups/%s/profile", group_id) return self.client.get_json( destination=destination, @@ -529,7 +513,7 @@ def update_group_profile(self, destination, group_id, requester_user_id, content requester_user_id (str) content (dict): The new profile of the group """ - path = _create_v1_path("/groups/%s/profile", group_id,) + path = _create_v1_path("/groups/%s/profile", group_id) return self.client.post_json( destination=destination, @@ -543,7 +527,7 @@ def update_group_profile(self, destination, group_id, requester_user_id, content def get_group_summary(self, destination, group_id, requester_user_id): """Get a group summary """ - path = _create_v1_path("/groups/%s/summary", group_id,) + path = _create_v1_path("/groups/%s/summary", group_id) return self.client.get_json( destination=destination, @@ -556,7 +540,7 @@ def get_group_summary(self, destination, group_id, requester_user_id): def get_rooms_in_group(self, destination, group_id, requester_user_id): """Get all rooms in a group """ - path = _create_v1_path("/groups/%s/rooms", group_id,) + path = _create_v1_path("/groups/%s/rooms", group_id) return self.client.get_json( destination=destination, @@ -565,11 +549,12 @@ def get_rooms_in_group(self, destination, group_id, requester_user_id): ignore_backoff=True, ) - def add_room_to_group(self, destination, group_id, requester_user_id, room_id, - content): + def add_room_to_group( + self, destination, group_id, requester_user_id, room_id, content + ): """Add a room to a group """ - path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) return self.client.post_json( destination=destination, @@ -579,13 +564,13 @@ def add_room_to_group(self, destination, group_id, requester_user_id, room_id, ignore_backoff=True, ) - def update_room_in_group(self, destination, group_id, requester_user_id, room_id, - config_key, content): + def update_room_in_group( + self, destination, group_id, requester_user_id, room_id, config_key, content + ): """Update room in group """ path = _create_v1_path( - "/groups/%s/room/%s/config/%s", - group_id, room_id, config_key, + "/groups/%s/room/%s/config/%s", group_id, room_id, config_key ) return self.client.post_json( @@ -599,7 +584,7 @@ def update_room_in_group(self, destination, group_id, requester_user_id, room_id def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): """Remove a room from a group """ - path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) return self.client.delete_json( destination=destination, @@ -612,7 +597,7 @@ def remove_room_from_group(self, destination, group_id, requester_user_id, room_ def get_users_in_group(self, destination, group_id, requester_user_id): """Get users in a group """ - path = _create_v1_path("/groups/%s/users", group_id,) + path = _create_v1_path("/groups/%s/users", group_id) return self.client.get_json( destination=destination, @@ -625,7 +610,7 @@ def get_users_in_group(self, destination, group_id, requester_user_id): def get_invited_users_in_group(self, destination, group_id, requester_user_id): """Get users that have been invited to a group """ - path = _create_v1_path("/groups/%s/invited_users", group_id,) + path = _create_v1_path("/groups/%s/invited_users", group_id) return self.client.get_json( destination=destination, @@ -638,16 +623,10 @@ def get_invited_users_in_group(self, destination, group_id, requester_user_id): def accept_group_invite(self, destination, group_id, user_id, content): """Accept a group invite """ - path = _create_v1_path( - "/groups/%s/users/%s/accept_invite", - group_id, user_id, - ) + path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function @@ -657,14 +636,13 @@ def join_group(self, destination, group_id, user_id, content): path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def invite_to_group(self, destination, group_id, user_id, requester_user_id, content): + def invite_to_group( + self, destination, group_id, user_id, requester_user_id, content + ): """Invite a user to a group """ path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) @@ -686,15 +664,13 @@ def invite_to_group_notification(self, destination, group_id, user_id, content): path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def remove_user_from_group(self, destination, group_id, requester_user_id, - user_id, content): + def remove_user_from_group( + self, destination, group_id, requester_user_id, user_id, content + ): """Remove a user fron a group """ path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) @@ -708,8 +684,9 @@ def remove_user_from_group(self, destination, group_id, requester_user_id, ) @log_function - def remove_user_from_group_notification(self, destination, group_id, user_id, - content): + def remove_user_from_group_notification( + self, destination, group_id, user_id, content + ): """Sent by group server to inform a user's server that they have been kicked from the group. """ @@ -717,10 +694,7 @@ def remove_user_from_group_notification(self, destination, group_id, user_id, path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function @@ -732,24 +706,24 @@ def renew_group_attestation(self, destination, group_id, user_id, content): path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def update_group_summary_room(self, destination, group_id, user_id, room_id, - category_id, content): + def update_group_summary_room( + self, destination, group_id, user_id, room_id, category_id, content + ): """Update a room entry in a group summary """ if category_id: path = _create_v1_path( "/groups/%s/summary/categories/%s/rooms/%s", - group_id, category_id, room_id, + group_id, + category_id, + room_id, ) else: - path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) return self.client.post_json( destination=destination, @@ -760,17 +734,20 @@ def update_group_summary_room(self, destination, group_id, user_id, room_id, ) @log_function - def delete_group_summary_room(self, destination, group_id, user_id, room_id, - category_id): + def delete_group_summary_room( + self, destination, group_id, user_id, room_id, category_id + ): """Delete a room entry in a group summary """ if category_id: path = _create_v1_path( "/groups/%s/summary/categories/%s/rooms/%s", - group_id, category_id, room_id, + group_id, + category_id, + room_id, ) else: - path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) return self.client.delete_json( destination=destination, @@ -783,7 +760,7 @@ def delete_group_summary_room(self, destination, group_id, user_id, room_id, def get_group_categories(self, destination, group_id, requester_user_id): """Get all categories in a group """ - path = _create_v1_path("/groups/%s/categories", group_id,) + path = _create_v1_path("/groups/%s/categories", group_id) return self.client.get_json( destination=destination, @@ -796,7 +773,7 @@ def get_group_categories(self, destination, group_id, requester_user_id): def get_group_category(self, destination, group_id, requester_user_id, category_id): """Get category info in a group """ - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.get_json( destination=destination, @@ -806,11 +783,12 @@ def get_group_category(self, destination, group_id, requester_user_id, category_ ) @log_function - def update_group_category(self, destination, group_id, requester_user_id, category_id, - content): + def update_group_category( + self, destination, group_id, requester_user_id, category_id, content + ): """Update a category in a group """ - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.post_json( destination=destination, @@ -821,11 +799,12 @@ def update_group_category(self, destination, group_id, requester_user_id, catego ) @log_function - def delete_group_category(self, destination, group_id, requester_user_id, - category_id): + def delete_group_category( + self, destination, group_id, requester_user_id, category_id + ): """Delete a category in a group """ - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.delete_json( destination=destination, @@ -838,7 +817,7 @@ def delete_group_category(self, destination, group_id, requester_user_id, def get_group_roles(self, destination, group_id, requester_user_id): """Get all roles in a group """ - path = _create_v1_path("/groups/%s/roles", group_id,) + path = _create_v1_path("/groups/%s/roles", group_id) return self.client.get_json( destination=destination, @@ -851,7 +830,7 @@ def get_group_roles(self, destination, group_id, requester_user_id): def get_group_role(self, destination, group_id, requester_user_id, role_id): """Get a roles info """ - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.get_json( destination=destination, @@ -861,11 +840,12 @@ def get_group_role(self, destination, group_id, requester_user_id, role_id): ) @log_function - def update_group_role(self, destination, group_id, requester_user_id, role_id, - content): + def update_group_role( + self, destination, group_id, requester_user_id, role_id, content + ): """Update a role in a group """ - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.post_json( destination=destination, @@ -879,7 +859,7 @@ def update_group_role(self, destination, group_id, requester_user_id, role_id, def delete_group_role(self, destination, group_id, requester_user_id, role_id): """Delete a role in a group """ - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.delete_json( destination=destination, @@ -889,17 +869,17 @@ def delete_group_role(self, destination, group_id, requester_user_id, role_id): ) @log_function - def update_group_summary_user(self, destination, group_id, requester_user_id, - user_id, role_id, content): + def update_group_summary_user( + self, destination, group_id, requester_user_id, user_id, role_id, content + ): """Update a users entry in a group """ if role_id: path = _create_v1_path( - "/groups/%s/summary/roles/%s/users/%s", - group_id, role_id, user_id, + "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id ) else: - path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,) + path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) return self.client.post_json( destination=destination, @@ -910,11 +890,10 @@ def update_group_summary_user(self, destination, group_id, requester_user_id, ) @log_function - def set_group_join_policy(self, destination, group_id, requester_user_id, - content): + def set_group_join_policy(self, destination, group_id, requester_user_id, content): """Sets the join policy for a group """ - path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,) + path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id) return self.client.put_json( destination=destination, @@ -925,17 +904,17 @@ def set_group_join_policy(self, destination, group_id, requester_user_id, ) @log_function - def delete_group_summary_user(self, destination, group_id, requester_user_id, - user_id, role_id): + def delete_group_summary_user( + self, destination, group_id, requester_user_id, user_id, role_id + ): """Delete a users entry in a group """ if role_id: path = _create_v1_path( - "/groups/%s/summary/roles/%s/users/%s", - group_id, role_id, user_id, + "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id ) else: - path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,) + path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) return self.client.delete_json( destination=destination, @@ -953,10 +932,7 @@ def bulk_get_publicised_groups(self, destination, user_ids): content = {"user_ids": user_ids} return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @@ -975,9 +951,8 @@ def _create_v1_path(path, *args): Returns: str """ - return ( - FEDERATION_V1_PREFIX - + path % tuple(urllib.parse.quote(arg, "") for arg in args) + return FEDERATION_V1_PREFIX + path % tuple( + urllib.parse.quote(arg, "") for arg in args ) @@ -996,7 +971,6 @@ def _create_v2_path(path, *args): Returns: str """ - return ( - FEDERATION_V2_PREFIX - + path % tuple(urllib.parse.quote(arg, "") for arg in args) + return FEDERATION_V2_PREFIX + path % tuple( + urllib.parse.quote(arg, "") for arg in args ) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 949a5fb2aa6c..b4854e82f63e 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -66,8 +66,7 @@ def __init__(self, hs, servlet_groups=None): self.authenticator = Authenticator(hs) self.ratelimiter = FederationRateLimiter( - self.clock, - config=hs.config.rc_federation, + self.clock, config=hs.config.rc_federation ) self.register_servlets() @@ -84,11 +83,13 @@ def register_servlets(self): class AuthenticationError(SynapseError): """There was a problem authenticating the request""" + pass class NoAuthenticationError(AuthenticationError): """The request had no authentication information""" + pass @@ -105,8 +106,8 @@ def __init__(self, hs): def authenticate_request(self, request, content): now = self._clock.time_msec() json_request = { - "method": request.method.decode('ascii'), - "uri": request.uri.decode('ascii'), + "method": request.method.decode("ascii"), + "uri": request.uri.decode("ascii"), "destination": self.server_name, "signatures": {}, } @@ -120,7 +121,7 @@ def authenticate_request(self, request, content): if not auth_headers: raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED, + 401, "Missing Authorization headers", Codes.UNAUTHORIZED ) for auth in auth_headers: @@ -130,14 +131,14 @@ def authenticate_request(self, request, content): json_request["signatures"].setdefault(origin, {})[key] = sig if ( - self.federation_domain_whitelist is not None and - origin not in self.federation_domain_whitelist + self.federation_domain_whitelist is not None + and origin not in self.federation_domain_whitelist ): raise FederationDeniedError(origin) if not json_request["signatures"]: raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED, + 401, "Missing Authorization headers", Codes.UNAUTHORIZED ) yield self.keyring.verify_json_for_server( @@ -177,12 +178,12 @@ def _parse_auth_header(header_bytes): AuthenticationError if the header could not be parsed """ try: - header_str = header_bytes.decode('utf-8') + header_str = header_bytes.decode("utf-8") params = header_str.split(" ")[1].split(",") param_dict = dict(kv.split("=") for kv in params) def strip_quotes(value): - if value.startswith("\""): + if value.startswith('"'): return value[1:-1] else: return value @@ -198,11 +199,11 @@ def strip_quotes(value): except Exception as e: logger.warn( "Error parsing auth header '%s': %s", - header_bytes.decode('ascii', 'replace'), + header_bytes.decode("ascii", "replace"), e, ) raise AuthenticationError( - 400, "Malformed Authorization header", Codes.UNAUTHORIZED, + 400, "Malformed Authorization header", Codes.UNAUTHORIZED ) @@ -242,6 +243,7 @@ class BaseFederationServlet(object): Exception: other exceptions will be caught, logged, and a 500 will be returned. """ + REQUIRE_AUTH = True PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version @@ -293,9 +295,7 @@ def new_func(request, *args, **kwargs): origin, content, request.args, *args, **kwargs ) else: - response = yield func( - origin, content, request.args, *args, **kwargs - ) + response = yield func(origin, content, request.args, *args, **kwargs) defer.returnValue(response) @@ -343,14 +343,12 @@ def on_PUT(self, origin, content, query, transaction_id): try: transaction_data = content - logger.debug( - "Decoded %s: %s", - transaction_id, str(transaction_data) - ) + logger.debug("Decoded %s: %s", transaction_id, str(transaction_data)) logger.info( "Received txn %s from %s. (PDUs: %d, EDUs: %d)", - transaction_id, origin, + transaction_id, + origin, len(transaction_data.get("pdus", [])), len(transaction_data.get("edus", [])), ) @@ -361,8 +359,7 @@ def on_PUT(self, origin, content, query, transaction_id): # Add some extra data to the transaction dict that isn't included # in the request body. transaction_data.update( - transaction_id=transaction_id, - destination=self.server_name + transaction_id=transaction_id, destination=self.server_name ) except Exception as e: @@ -372,7 +369,7 @@ def on_PUT(self, origin, content, query, transaction_id): try: code, response = yield self.handler.on_incoming_transaction( - origin, transaction_data, + origin, transaction_data ) except Exception: logger.exception("on_incoming_transaction failed") @@ -416,7 +413,7 @@ class FederationBackfillServlet(BaseFederationServlet): PATH = "/backfill/(?P[^/]*)/?" def on_GET(self, origin, content, query, context): - versions = [x.decode('ascii') for x in query[b"v"]] + versions = [x.decode("ascii") for x in query[b"v"]] limit = parse_integer_from_args(query, "limit", None) if not limit: @@ -432,7 +429,7 @@ class FederationQueryServlet(BaseFederationServlet): def on_GET(self, origin, content, query, query_type): return self.handler.on_query_request( query_type, - {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()} + {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}, ) @@ -456,15 +453,14 @@ def on_GET(self, origin, _content, query, context, user_id): Deferred[(int, object)|None]: either (response code, response object) to return a JSON response, or None if the request has already been handled. """ - versions = query.get(b'ver') + versions = query.get(b"ver") if versions is not None: supported_versions = [v.decode("utf-8") for v in versions] else: supported_versions = ["1"] content = yield self.handler.on_make_join_request( - origin, context, user_id, - supported_versions=supported_versions, + origin, context, user_id, supported_versions=supported_versions ) defer.returnValue((200, content)) @@ -474,9 +470,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_leave_request( - origin, context, user_id, - ) + content = yield self.handler.on_make_leave_request(origin, context, user_id) defer.returnValue((200, content)) @@ -517,7 +511,7 @@ def on_PUT(self, origin, content, query, context, event_id): # state resolution algorithm, and we don't use that for processing # invites content = yield self.handler.on_invite_request( - origin, content, room_version=RoomVersions.V1.identifier, + origin, content, room_version=RoomVersions.V1.identifier ) # V1 federation API is defined to return a content of `[200, {...}]` @@ -545,7 +539,7 @@ def on_PUT(self, origin, content, query, context, event_id): event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state content = yield self.handler.on_invite_request( - origin, event, room_version=room_version, + origin, event, room_version=room_version ) defer.returnValue((200, content)) @@ -629,8 +623,10 @@ def on_POST(self, origin, content, query): for invite in content["invites"]: try: if "signed" not in invite or "token" not in invite["signed"]: - message = ("Rejecting received notification of third-" - "party invite without signed: %s" % (invite,)) + message = ( + "Rejecting received notification of third-" + "party invite without signed: %s" % (invite,) + ) logger.info(message) raise SynapseError(400, message) yield self.handler.exchange_third_party_invite( @@ -671,18 +667,23 @@ class OpenIdUserInfo(BaseFederationServlet): def on_GET(self, origin, content, query): token = query.get(b"access_token", [None])[0] if token is None: - defer.returnValue((401, { - "errcode": "M_MISSING_TOKEN", "error": "Access Token required" - })) + defer.returnValue( + (401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}) + ) return - user_id = yield self.handler.on_openid_userinfo(token.decode('ascii')) + user_id = yield self.handler.on_openid_userinfo(token.decode("ascii")) if user_id is None: - defer.returnValue((401, { - "errcode": "M_UNKNOWN_TOKEN", - "error": "Access Token unknown or expired" - })) + defer.returnValue( + ( + 401, + { + "errcode": "M_UNKNOWN_TOKEN", + "error": "Access Token unknown or expired", + }, + ) + ) defer.returnValue((200, {"sub": user_id})) @@ -722,7 +723,7 @@ class PublicRoomList(BaseFederationServlet): def __init__(self, handler, authenticator, ratelimiter, server_name, deny_access): super(PublicRoomList, self).__init__( - handler, authenticator, ratelimiter, server_name, + handler, authenticator, ratelimiter, server_name ) self.deny_access = deny_access @@ -748,9 +749,7 @@ def on_GET(self, origin, content, query): network_tuple = ThirdPartyInstanceID(None, None) data = yield self.handler.get_local_public_room_list( - limit, since_token, - network_tuple=network_tuple, - from_federation=True, + limit, since_token, network_tuple=network_tuple, from_federation=True ) defer.returnValue((200, data)) @@ -761,17 +760,18 @@ class FederationVersionServlet(BaseFederationServlet): REQUIRE_AUTH = False def on_GET(self, origin, content, query): - return defer.succeed((200, { - "server": { - "name": "Synapse", - "version": get_version_string(synapse) - }, - })) + return defer.succeed( + ( + 200, + {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, + ) + ) class FederationGroupsProfileServlet(BaseFederationServlet): """Get/set the basic profile of a group on behalf of a user """ + PATH = "/groups/(?P[^/]*)/profile" @defer.inlineCallbacks @@ -780,9 +780,7 @@ def on_GET(self, origin, content, query, group_id): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_group_profile( - group_id, requester_user_id - ) + new_content = yield self.handler.get_group_profile(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -808,9 +806,7 @@ def on_GET(self, origin, content, query, group_id): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_group_summary( - group_id, requester_user_id - ) + new_content = yield self.handler.get_group_summary(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -818,6 +814,7 @@ def on_GET(self, origin, content, query, group_id): class FederationGroupsRoomsServlet(BaseFederationServlet): """Get the rooms in a group on behalf of a user """ + PATH = "/groups/(?P[^/]*)/rooms" @defer.inlineCallbacks @@ -826,9 +823,7 @@ def on_GET(self, origin, content, query, group_id): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_rooms_in_group( - group_id, requester_user_id - ) + new_content = yield self.handler.get_rooms_in_group(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -836,6 +831,7 @@ def on_GET(self, origin, content, query, group_id): class FederationGroupsAddRoomsServlet(BaseFederationServlet): """Add/remove room from group """ + PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" @defer.inlineCallbacks @@ -857,7 +853,7 @@ def on_DELETE(self, origin, content, query, group_id, room_id): raise SynapseError(403, "requester_user_id doesn't match origin") new_content = yield self.handler.remove_room_from_group( - group_id, requester_user_id, room_id, + group_id, requester_user_id, room_id ) defer.returnValue((200, new_content)) @@ -866,6 +862,7 @@ def on_DELETE(self, origin, content, query, group_id, room_id): class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): """Update room config in group """ + PATH = ( "/groups/(?P[^/]*)/room/(?P[^/]*)" "/config/(?P[^/]*)" @@ -878,7 +875,7 @@ def on_POST(self, origin, content, query, group_id, room_id, config_key): raise SynapseError(403, "requester_user_id doesn't match origin") result = yield self.groups_handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content, + group_id, requester_user_id, room_id, config_key, content ) defer.returnValue((200, result)) @@ -887,6 +884,7 @@ def on_POST(self, origin, content, query, group_id, room_id, config_key): class FederationGroupsUsersServlet(BaseFederationServlet): """Get the users in a group on behalf of a user """ + PATH = "/groups/(?P[^/]*)/users" @defer.inlineCallbacks @@ -895,9 +893,7 @@ def on_GET(self, origin, content, query, group_id): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_users_in_group( - group_id, requester_user_id - ) + new_content = yield self.handler.get_users_in_group(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -905,6 +901,7 @@ def on_GET(self, origin, content, query, group_id): class FederationGroupsInvitedUsersServlet(BaseFederationServlet): """Get the users that have been invited to a group """ + PATH = "/groups/(?P[^/]*)/invited_users" @defer.inlineCallbacks @@ -923,6 +920,7 @@ def on_GET(self, origin, content, query, group_id): class FederationGroupsInviteServlet(BaseFederationServlet): """Ask a group server to invite someone to the group """ + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" @defer.inlineCallbacks @@ -932,7 +930,7 @@ def on_POST(self, origin, content, query, group_id, user_id): raise SynapseError(403, "requester_user_id doesn't match origin") new_content = yield self.handler.invite_to_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) defer.returnValue((200, new_content)) @@ -941,6 +939,7 @@ def on_POST(self, origin, content, query, group_id, user_id): class FederationGroupsAcceptInviteServlet(BaseFederationServlet): """Accept an invitation from the group server """ + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" @defer.inlineCallbacks @@ -948,9 +947,7 @@ def on_POST(self, origin, content, query, group_id, user_id): if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") - new_content = yield self.handler.accept_invite( - group_id, user_id, content, - ) + new_content = yield self.handler.accept_invite(group_id, user_id, content) defer.returnValue((200, new_content)) @@ -958,6 +955,7 @@ def on_POST(self, origin, content, query, group_id, user_id): class FederationGroupsJoinServlet(BaseFederationServlet): """Attempt to join a group """ + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" @defer.inlineCallbacks @@ -965,9 +963,7 @@ def on_POST(self, origin, content, query, group_id, user_id): if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") - new_content = yield self.handler.join_group( - group_id, user_id, content, - ) + new_content = yield self.handler.join_group(group_id, user_id, content) defer.returnValue((200, new_content)) @@ -975,6 +971,7 @@ def on_POST(self, origin, content, query, group_id, user_id): class FederationGroupsRemoveUserServlet(BaseFederationServlet): """Leave or kick a user from the group """ + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" @defer.inlineCallbacks @@ -984,7 +981,7 @@ def on_POST(self, origin, content, query, group_id, user_id): raise SynapseError(403, "requester_user_id doesn't match origin") new_content = yield self.handler.remove_user_from_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) defer.returnValue((200, new_content)) @@ -993,6 +990,7 @@ def on_POST(self, origin, content, query, group_id, user_id): class FederationGroupsLocalInviteServlet(BaseFederationServlet): """A group server has invited a local user """ + PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" @defer.inlineCallbacks @@ -1000,9 +998,7 @@ def on_POST(self, origin, content, query, group_id, user_id): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") - new_content = yield self.handler.on_invite( - group_id, user_id, content, - ) + new_content = yield self.handler.on_invite(group_id, user_id, content) defer.returnValue((200, new_content)) @@ -1010,6 +1006,7 @@ def on_POST(self, origin, content, query, group_id, user_id): class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): """A group server has removed a local user """ + PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" @defer.inlineCallbacks @@ -1018,7 +1015,7 @@ def on_POST(self, origin, content, query, group_id, user_id): raise SynapseError(403, "user_id doesn't match origin") new_content = yield self.handler.user_removed_from_group( - group_id, user_id, content, + group_id, user_id, content ) defer.returnValue((200, new_content)) @@ -1027,6 +1024,7 @@ def on_POST(self, origin, content, query, group_id, user_id): class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): """A group or user's server renews their attestation """ + PATH = "/groups/(?P[^/]*)/renew_attestation/(?P[^/]*)" @defer.inlineCallbacks @@ -1047,6 +1045,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): - /groups/:group/summary/rooms/:room_id - /groups/:group/summary/categories/:category/rooms/:room_id """ + PATH = ( "/groups/(?P[^/]*)/summary" "(/categories/(?P[^/]+))?" @@ -1063,7 +1062,8 @@ def on_POST(self, origin, content, query, group_id, category_id, room_id): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.update_group_summary_room( - group_id, requester_user_id, + group_id, + requester_user_id, room_id=room_id, category_id=category_id, content=content, @@ -1081,9 +1081,7 @@ def on_DELETE(self, origin, content, query, group_id, category_id, room_id): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.delete_group_summary_room( - group_id, requester_user_id, - room_id=room_id, - category_id=category_id, + group_id, requester_user_id, room_id=room_id, category_id=category_id ) defer.returnValue((200, resp)) @@ -1092,9 +1090,8 @@ def on_DELETE(self, origin, content, query, group_id, category_id, room_id): class FederationGroupsCategoriesServlet(BaseFederationServlet): """Get all categories for a group """ - PATH = ( - "/groups/(?P[^/]*)/categories/?" - ) + + PATH = "/groups/(?P[^/]*)/categories/?" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id): @@ -1102,9 +1099,7 @@ def on_GET(self, origin, content, query, group_id): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_categories( - group_id, requester_user_id, - ) + resp = yield self.handler.get_group_categories(group_id, requester_user_id) defer.returnValue((200, resp)) @@ -1112,9 +1107,8 @@ def on_GET(self, origin, content, query, group_id): class FederationGroupsCategoryServlet(BaseFederationServlet): """Add/remove/get a category in a group """ - PATH = ( - "/groups/(?P[^/]*)/categories/(?P[^/]+)" - ) + + PATH = "/groups/(?P[^/]*)/categories/(?P[^/]+)" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id, category_id): @@ -1138,7 +1132,7 @@ def on_POST(self, origin, content, query, group_id, category_id): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.upsert_group_category( - group_id, requester_user_id, category_id, content, + group_id, requester_user_id, category_id, content ) defer.returnValue((200, resp)) @@ -1153,7 +1147,7 @@ def on_DELETE(self, origin, content, query, group_id, category_id): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.delete_group_category( - group_id, requester_user_id, category_id, + group_id, requester_user_id, category_id ) defer.returnValue((200, resp)) @@ -1162,9 +1156,8 @@ def on_DELETE(self, origin, content, query, group_id, category_id): class FederationGroupsRolesServlet(BaseFederationServlet): """Get roles in a group """ - PATH = ( - "/groups/(?P[^/]*)/roles/?" - ) + + PATH = "/groups/(?P[^/]*)/roles/?" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id): @@ -1172,9 +1165,7 @@ def on_GET(self, origin, content, query, group_id): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_roles( - group_id, requester_user_id, - ) + resp = yield self.handler.get_group_roles(group_id, requester_user_id) defer.returnValue((200, resp)) @@ -1182,9 +1173,8 @@ def on_GET(self, origin, content, query, group_id): class FederationGroupsRoleServlet(BaseFederationServlet): """Add/remove/get a role in a group """ - PATH = ( - "/groups/(?P[^/]*)/roles/(?P[^/]+)" - ) + + PATH = "/groups/(?P[^/]*)/roles/(?P[^/]+)" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id, role_id): @@ -1192,9 +1182,7 @@ def on_GET(self, origin, content, query, group_id, role_id): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_role( - group_id, requester_user_id, role_id - ) + resp = yield self.handler.get_group_role(group_id, requester_user_id, role_id) defer.returnValue((200, resp)) @@ -1208,7 +1196,7 @@ def on_POST(self, origin, content, query, group_id, role_id): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.update_group_role( - group_id, requester_user_id, role_id, content, + group_id, requester_user_id, role_id, content ) defer.returnValue((200, resp)) @@ -1223,7 +1211,7 @@ def on_DELETE(self, origin, content, query, group_id, role_id): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.delete_group_role( - group_id, requester_user_id, role_id, + group_id, requester_user_id, role_id ) defer.returnValue((200, resp)) @@ -1236,6 +1224,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): - /groups/:group/summary/users/:user_id - /groups/:group/summary/roles/:role/users/:user_id """ + PATH = ( "/groups/(?P[^/]*)/summary" "(/roles/(?P[^/]+))?" @@ -1252,7 +1241,8 @@ def on_POST(self, origin, content, query, group_id, role_id, user_id): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.update_group_summary_user( - group_id, requester_user_id, + group_id, + requester_user_id, user_id=user_id, role_id=role_id, content=content, @@ -1270,9 +1260,7 @@ def on_DELETE(self, origin, content, query, group_id, role_id, user_id): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.delete_group_summary_user( - group_id, requester_user_id, - user_id=user_id, - role_id=role_id, + group_id, requester_user_id, user_id=user_id, role_id=role_id ) defer.returnValue((200, resp)) @@ -1281,14 +1269,13 @@ def on_DELETE(self, origin, content, query, group_id, role_id, user_id): class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): """Get roles in a group """ - PATH = ( - "/get_groups_publicised" - ) + + PATH = "/get_groups_publicised" @defer.inlineCallbacks def on_POST(self, origin, content, query): resp = yield self.handler.bulk_get_publicised_groups( - content["user_ids"], proxy=False, + content["user_ids"], proxy=False ) defer.returnValue((200, resp)) @@ -1297,6 +1284,7 @@ def on_POST(self, origin, content, query): class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): """Sets whether a group is joinable without an invite or knock """ + PATH = "/groups/(?P[^/]*)/settings/m.join_policy" @defer.inlineCallbacks @@ -1317,6 +1305,7 @@ class RoomComplexityServlet(BaseFederationServlet): Indicates to other servers how complex (and therefore likely resource-intensive) a public room this server knows about is. """ + PATH = "/rooms/(?P[^/]*)/complexity" PREFIX = FEDERATION_UNSTABLE_PREFIX @@ -1325,9 +1314,7 @@ def on_GET(self, origin, content, query, room_id): store = self.handler.hs.get_datastore() - is_public = yield store.is_room_world_readable_or_publicly_joinable( - room_id - ) + is_public = yield store.is_room_world_readable_or_publicly_joinable(room_id) if not is_public: raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) @@ -1362,13 +1349,9 @@ def on_GET(self, origin, content, query, room_id): RoomComplexityServlet, ) -OPENID_SERVLET_CLASSES = ( - OpenIdUserInfo, -) +OPENID_SERVLET_CLASSES = (OpenIdUserInfo,) -ROOM_LIST_CLASSES = ( - PublicRoomList, -) +ROOM_LIST_CLASSES = (PublicRoomList,) GROUP_SERVER_SERVLET_CLASSES = ( FederationGroupsProfileServlet, @@ -1399,9 +1382,7 @@ def on_GET(self, origin, content, query, room_id): ) -GROUP_ATTESTATION_SERVLET_CLASSES = ( - FederationGroupsRenewAttestaionServlet, -) +GROUP_ATTESTATION_SERVLET_CLASSES = (FederationGroupsRenewAttestaionServlet,) DEFAULT_SERVLET_GROUPS = ( "federation", diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 025a79c02209..14aad8f09d71 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -32,21 +32,11 @@ class Edu(JsonEncodedObject): internal ID or previous references graph. """ - valid_keys = [ - "origin", - "destination", - "edu_type", - "content", - ] + valid_keys = ["origin", "destination", "edu_type", "content"] - required_keys = [ - "edu_type", - ] + required_keys = ["edu_type"] - internal_keys = [ - "origin", - "destination", - ] + internal_keys = ["origin", "destination"] class Transaction(JsonEncodedObject): @@ -75,10 +65,7 @@ class Transaction(JsonEncodedObject): "edus", ] - internal_keys = [ - "transaction_id", - "destination", - ] + internal_keys = ["transaction_id", "destination"] required_keys = [ "transaction_id", @@ -98,9 +85,7 @@ def __init__(self, transaction_id=None, pdus=[], **kwargs): del kwargs["edus"] super(Transaction, self).__init__( - transaction_id=transaction_id, - pdus=pdus, - **kwargs + transaction_id=transaction_id, pdus=pdus, **kwargs ) @staticmethod @@ -109,13 +94,9 @@ def create_new(pdus, **kwargs): transaction_id and origin_server_ts keys. """ if "origin_server_ts" not in kwargs: - raise KeyError( - "Require 'origin_server_ts' to construct a Transaction" - ) + raise KeyError("Require 'origin_server_ts' to construct a Transaction") if "transaction_id" not in kwargs: - raise KeyError( - "Require 'transaction_id' to construct a Transaction" - ) + raise KeyError("Require 'transaction_id' to construct a Transaction") kwargs["pdus"] = [p.get_pdu_json() for p in pdus] diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index 3ba74001d895..e73757570cf7 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -65,6 +65,7 @@ class GroupAttestationSigning(object): """Creates and verifies group attestations. """ + def __init__(self, hs): self.keyring = hs.get_keyring() self.clock = hs.get_clock() @@ -113,11 +114,15 @@ def create_attestation(self, group_id, user_id): validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER) valid_until_ms = int(self.clock.time_msec() + validity_period) - return sign_json({ - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": valid_until_ms, - }, self.server_name, self.signing_key) + return sign_json( + { + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": valid_until_ms, + }, + self.server_name, + self.signing_key, + ) class GroupAttestionRenewer(object): @@ -134,7 +139,7 @@ def __init__(self, hs): if not hs.config.worker_app: self._renew_attestations_loop = self.clock.looping_call( - self._start_renew_attestations, 30 * 60 * 1000, + self._start_renew_attestations, 30 * 60 * 1000 ) @defer.inlineCallbacks @@ -147,9 +152,7 @@ def on_renew_attestation(self, group_id, user_id, content): raise SynapseError(400, "Neither user not group are on this server") yield self.attestations.verify_attestation( - attestation, - user_id=user_id, - group_id=group_id, + attestation, user_id=user_id, group_id=group_id ) yield self.store.update_remote_attestion(group_id, user_id, attestation) @@ -180,7 +183,8 @@ def _renew_attestation(group_id, user_id): else: logger.warn( "Incorrectly trying to do attestations for user: %r in %r", - user_id, group_id, + user_id, + group_id, ) yield self.store.remove_attestation_renewal(group_id, user_id) return @@ -188,8 +192,7 @@ def _renew_attestation(group_id, user_id): attestation = self.attestations.create_attestation(group_id, user_id) yield self.transport_client.renew_group_attestation( - destination, group_id, user_id, - content={"attestation": attestation}, + destination, group_id, user_id, content={"attestation": attestation} ) yield self.store.update_attestation_renewal( @@ -197,12 +200,12 @@ def _renew_attestation(group_id, user_id): ) except (RequestSendFailed, HttpResponseException) as e: logger.warning( - "Failed to renew attestation of %r in %r: %s", - user_id, group_id, e, + "Failed to renew attestation of %r in %r: %s", user_id, group_id, e ) except Exception: - logger.exception("Error renewing attestation of %r in %r", - user_id, group_id) + logger.exception( + "Error renewing attestation of %r in %r", user_id, group_id + ) for row in rows: group_id = row["group_id"] diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 817be4036085..168c9e3f84b5 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -54,8 +54,9 @@ def __init__(self, hs): hs.get_groups_attestation_renewer() @defer.inlineCallbacks - def check_group_is_ours(self, group_id, requester_user_id, - and_exists=False, and_is_admin=None): + def check_group_is_ours( + self, group_id, requester_user_id, and_exists=False, and_is_admin=None + ): """Check that the group is ours, and optionally if it exists. If group does exist then return group. @@ -73,7 +74,9 @@ def check_group_is_ours(self, group_id, requester_user_id, if and_exists and not group: raise SynapseError(404, "Unknown group") - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) if group and not is_user_in_group and not group["is_public"]: raise SynapseError(404, "Unknown group") @@ -96,25 +99,27 @@ def get_group_summary(self, group_id, requester_user_id): """ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) profile = yield self.get_group_profile(group_id, requester_user_id) users, roles = yield self.store.get_users_for_summary_by_role( - group_id, include_private=is_user_in_group, + group_id, include_private=is_user_in_group ) # TODO: Add profiles to users rooms, categories = yield self.store.get_rooms_for_summary_by_category( - group_id, include_private=is_user_in_group, + group_id, include_private=is_user_in_group ) for room_entry in rooms: room_id = room_entry["room_id"] joined_users = yield self.store.get_users_in_room(room_id) entry = yield self.room_list_handler.generate_room_entry( - room_id, len(joined_users), with_alias=False, allow_private=True, + room_id, len(joined_users), with_alias=False, allow_private=True ) entry = dict(entry) # so we don't change whats cached entry.pop("room_id", None) @@ -134,7 +139,7 @@ def get_group_summary(self, group_id, requester_user_id): entry["attestation"] = attestation else: entry["attestation"] = self.attestations.create_attestation( - group_id, user_id, + group_id, user_id ) user_profile = yield self.profile_handler.get_profile_from_cache(user_id) @@ -143,34 +148,34 @@ def get_group_summary(self, group_id, requester_user_id): users.sort(key=lambda e: e.get("order", 0)) membership_info = yield self.store.get_users_membership_info_in_group( - group_id, requester_user_id, + group_id, requester_user_id ) - defer.returnValue({ - "profile": profile, - "users_section": { - "users": users, - "roles": roles, - "total_user_count_estimate": 0, # TODO - }, - "rooms_section": { - "rooms": rooms, - "categories": categories, - "total_room_count_estimate": 0, # TODO - }, - "user": membership_info, - }) + defer.returnValue( + { + "profile": profile, + "users_section": { + "users": users, + "roles": roles, + "total_user_count_estimate": 0, # TODO + }, + "rooms_section": { + "rooms": rooms, + "categories": categories, + "total_room_count_estimate": 0, # TODO + }, + "user": membership_info, + } + ) @defer.inlineCallbacks - def update_group_summary_room(self, group_id, requester_user_id, - room_id, category_id, content): + def update_group_summary_room( + self, group_id, requester_user_id, room_id, category_id, content + ): """Add/update a room to the group summary """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) RoomID.from_string(room_id) # Ensure valid room id @@ -190,21 +195,17 @@ def update_group_summary_room(self, group_id, requester_user_id, defer.returnValue({}) @defer.inlineCallbacks - def delete_group_summary_room(self, group_id, requester_user_id, - room_id, category_id): + def delete_group_summary_room( + self, group_id, requester_user_id, room_id, category_id + ): """Remove a room from the summary """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) yield self.store.remove_room_from_summary( - group_id=group_id, - room_id=room_id, - category_id=category_id, + group_id=group_id, room_id=room_id, category_id=category_id ) defer.returnValue({}) @@ -223,9 +224,7 @@ def set_group_join_policy(self, group_id, requester_user_id, content): join_policy = _parse_join_policy_from_contents(content) if join_policy is None: - raise SynapseError( - 400, "No value specified for 'm.join_policy'" - ) + raise SynapseError(400, "No value specified for 'm.join_policy'") yield self.store.set_group_join_policy(group_id, join_policy=join_policy) @@ -237,9 +236,7 @@ def get_group_categories(self, group_id, requester_user_id): """ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - categories = yield self.store.get_group_categories( - group_id=group_id, - ) + categories = yield self.store.get_group_categories(group_id=group_id) defer.returnValue({"categories": categories}) @defer.inlineCallbacks @@ -249,8 +246,7 @@ def get_group_category(self, group_id, requester_user_id, category_id): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) res = yield self.store.get_group_category( - group_id=group_id, - category_id=category_id, + group_id=group_id, category_id=category_id ) defer.returnValue(res) @@ -260,10 +256,7 @@ def update_group_category(self, group_id, requester_user_id, category_id, conten """Add/Update a group category """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) is_public = _parse_visibility_from_contents(content) @@ -283,15 +276,11 @@ def delete_group_category(self, group_id, requester_user_id, category_id): """Delete a group category """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) yield self.store.remove_group_category( - group_id=group_id, - category_id=category_id, + group_id=group_id, category_id=category_id ) defer.returnValue({}) @@ -302,9 +291,7 @@ def get_group_roles(self, group_id, requester_user_id): """ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - roles = yield self.store.get_group_roles( - group_id=group_id, - ) + roles = yield self.store.get_group_roles(group_id=group_id) defer.returnValue({"roles": roles}) @defer.inlineCallbacks @@ -313,10 +300,7 @@ def get_group_role(self, group_id, requester_user_id, role_id): """ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - res = yield self.store.get_group_role( - group_id=group_id, - role_id=role_id, - ) + res = yield self.store.get_group_role(group_id=group_id, role_id=role_id) defer.returnValue(res) @defer.inlineCallbacks @@ -324,10 +308,7 @@ def update_group_role(self, group_id, requester_user_id, role_id, content): """Add/update a role in a group """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) is_public = _parse_visibility_from_contents(content) @@ -335,10 +316,7 @@ def update_group_role(self, group_id, requester_user_id, role_id, content): profile = content.get("profile") yield self.store.upsert_group_role( - group_id=group_id, - role_id=role_id, - is_public=is_public, - profile=profile, + group_id=group_id, role_id=role_id, is_public=is_public, profile=profile ) defer.returnValue({}) @@ -348,26 +326,21 @@ def delete_group_role(self, group_id, requester_user_id, role_id): """Remove role from group """ yield self.check_group_is_ours( - group_id, - requester_user_id, - and_exists=True, - and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) - yield self.store.remove_group_role( - group_id=group_id, - role_id=role_id, - ) + yield self.store.remove_group_role(group_id=group_id, role_id=role_id) defer.returnValue({}) @defer.inlineCallbacks - def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id, - content): + def update_group_summary_user( + self, group_id, requester_user_id, user_id, role_id, content + ): """Add/update a users entry in the group summary """ yield self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) order = content.get("order", None) @@ -389,13 +362,11 @@ def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_i """Remove a user from the group summary """ yield self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) yield self.store.remove_user_from_summary( - group_id=group_id, - user_id=user_id, - role_id=role_id, + group_id=group_id, user_id=user_id, role_id=role_id ) defer.returnValue({}) @@ -411,8 +382,11 @@ def get_group_profile(self, group_id, requester_user_id): if group: cols = [ - "name", "short_description", "long_description", - "avatar_url", "is_public", + "name", + "short_description", + "long_description", + "avatar_url", + "is_public", ] group_description = {key: group[key] for key in cols} group_description["is_openly_joinable"] = group["join_policy"] == "open" @@ -426,12 +400,11 @@ def update_group_profile(self, group_id, requester_user_id, content): """Update the group profile """ yield self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id, + group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) profile = {} - for keyname in ("name", "avatar_url", "short_description", - "long_description"): + for keyname in ("name", "avatar_url", "short_description", "long_description"): if keyname in content: value = content[keyname] if not isinstance(value, string_types): @@ -449,10 +422,12 @@ def get_users_in_group(self, group_id, requester_user_id): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) user_results = yield self.store.get_users_in_group( - group_id, include_private=is_user_in_group, + group_id, include_private=is_user_in_group ) chunk = [] @@ -470,24 +445,25 @@ def get_users_in_group(self, group_id, requester_user_id): entry["is_privileged"] = bool(is_privileged) if not self.is_mine_id(g_user_id): - attestation = yield self.store.get_remote_attestation(group_id, g_user_id) + attestation = yield self.store.get_remote_attestation( + group_id, g_user_id + ) if not attestation: continue entry["attestation"] = attestation else: entry["attestation"] = self.attestations.create_attestation( - group_id, g_user_id, + group_id, g_user_id ) chunk.append(entry) # TODO: If admin add lists of users whose attestations have timed out - defer.returnValue({ - "chunk": chunk, - "total_user_count_estimate": len(user_results), - }) + defer.returnValue( + {"chunk": chunk, "total_user_count_estimate": len(user_results)} + ) @defer.inlineCallbacks def get_invited_users_in_group(self, group_id, requester_user_id): @@ -498,7 +474,9 @@ def get_invited_users_in_group(self, group_id, requester_user_id): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) if not is_user_in_group: raise SynapseError(403, "User not in group") @@ -508,9 +486,7 @@ def get_invited_users_in_group(self, group_id, requester_user_id): user_profiles = [] for user_id in invited_users: - user_profile = { - "user_id": user_id - } + user_profile = {"user_id": user_id} try: profile = yield self.profile_handler.get_profile_from_cache(user_id) user_profile.update(profile) @@ -518,10 +494,9 @@ def get_invited_users_in_group(self, group_id, requester_user_id): logger.warn("Error getting profile for %s: %s", user_id, e) user_profiles.append(user_profile) - defer.returnValue({ - "chunk": user_profiles, - "total_user_count_estimate": len(invited_users), - }) + defer.returnValue( + {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} + ) @defer.inlineCallbacks def get_rooms_in_group(self, group_id, requester_user_id): @@ -532,10 +507,12 @@ def get_rooms_in_group(self, group_id, requester_user_id): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + is_user_in_group = yield self.store.is_user_in_group( + requester_user_id, group_id + ) room_results = yield self.store.get_rooms_in_group( - group_id, include_private=is_user_in_group, + group_id, include_private=is_user_in_group ) chunk = [] @@ -544,7 +521,7 @@ def get_rooms_in_group(self, group_id, requester_user_id): joined_users = yield self.store.get_users_in_room(room_id) entry = yield self.room_list_handler.generate_room_entry( - room_id, len(joined_users), with_alias=False, allow_private=True, + room_id, len(joined_users), with_alias=False, allow_private=True ) if not entry: @@ -556,10 +533,9 @@ def get_rooms_in_group(self, group_id, requester_user_id): chunk.sort(key=lambda e: -e["num_joined_members"]) - defer.returnValue({ - "chunk": chunk, - "total_room_count_estimate": len(room_results), - }) + defer.returnValue( + {"chunk": chunk, "total_room_count_estimate": len(room_results)} + ) @defer.inlineCallbacks def add_room_to_group(self, group_id, requester_user_id, room_id, content): @@ -578,8 +554,9 @@ def add_room_to_group(self, group_id, requester_user_id, room_id, content): defer.returnValue({}) @defer.inlineCallbacks - def update_room_in_group(self, group_id, requester_user_id, room_id, config_key, - content): + def update_room_in_group( + self, group_id, requester_user_id, room_id, config_key, content + ): """Update room in group """ RoomID.from_string(room_id) # Ensure valid room id @@ -592,8 +569,7 @@ def update_room_in_group(self, group_id, requester_user_id, room_id, config_key, is_public = _parse_visibility_dict(content) yield self.store.update_room_in_group_visibility( - group_id, room_id, - is_public=is_public, + group_id, room_id, is_public=is_public ) else: raise SynapseError(400, "Uknown config option") @@ -625,10 +601,7 @@ def invite_to_group(self, group_id, user_id, requester_user_id, content): # TODO: Check if user is already invited content = { - "profile": { - "name": group["name"], - "avatar_url": group["avatar_url"], - }, + "profile": {"name": group["name"], "avatar_url": group["avatar_url"]}, "inviter": requester_user_id, } @@ -638,9 +611,7 @@ def invite_to_group(self, group_id, user_id, requester_user_id, content): local_attestation = None else: local_attestation = self.attestations.create_attestation(group_id, user_id) - content.update({ - "attestation": local_attestation, - }) + content.update({"attestation": local_attestation}) res = yield self.transport_client.invite_to_group_notification( get_domain_from_id(user_id), group_id, user_id, content @@ -658,31 +629,24 @@ def invite_to_group(self, group_id, user_id, requester_user_id, content): remote_attestation = res["attestation"] yield self.attestations.verify_attestation( - remote_attestation, - user_id=user_id, - group_id=group_id, + remote_attestation, user_id=user_id, group_id=group_id ) else: remote_attestation = None yield self.store.add_user_to_group( - group_id, user_id, + group_id, + user_id, is_admin=False, is_public=False, # TODO local_attestation=local_attestation, remote_attestation=remote_attestation, ) elif res["state"] == "invite": - yield self.store.add_group_invite( - group_id, user_id, - ) - defer.returnValue({ - "state": "invite" - }) + yield self.store.add_group_invite(group_id, user_id) + defer.returnValue({"state": "invite"}) elif res["state"] == "reject": - defer.returnValue({ - "state": "reject" - }) + defer.returnValue({"state": "reject"}) else: raise SynapseError(502, "Unknown state returned by HS") @@ -693,16 +657,12 @@ def _add_user(self, group_id, user_id, content): See accept_invite, join_group. """ if not self.hs.is_mine_id(user_id): - local_attestation = self.attestations.create_attestation( - group_id, user_id, - ) + local_attestation = self.attestations.create_attestation(group_id, user_id) remote_attestation = content["attestation"] yield self.attestations.verify_attestation( - remote_attestation, - user_id=user_id, - group_id=group_id, + remote_attestation, user_id=user_id, group_id=group_id ) else: local_attestation = None @@ -711,7 +671,8 @@ def _add_user(self, group_id, user_id, content): is_public = _parse_visibility_from_contents(content) yield self.store.add_user_to_group( - group_id, user_id, + group_id, + user_id, is_admin=False, is_public=is_public, local_attestation=local_attestation, @@ -731,17 +692,14 @@ def accept_invite(self, group_id, requester_user_id, content): yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_invited = yield self.store.is_user_invited_to_local_group( - group_id, requester_user_id, + group_id, requester_user_id ) if not is_invited: raise SynapseError(403, "User not invited to group") local_attestation = yield self._add_user(group_id, requester_user_id, content) - defer.returnValue({ - "state": "join", - "attestation": local_attestation, - }) + defer.returnValue({"state": "join", "attestation": local_attestation}) @defer.inlineCallbacks def join_group(self, group_id, requester_user_id, content): @@ -753,15 +711,12 @@ def join_group(self, group_id, requester_user_id, content): group_info = yield self.check_group_is_ours( group_id, requester_user_id, and_exists=True ) - if group_info['join_policy'] != "open": + if group_info["join_policy"] != "open": raise SynapseError(403, "Group is not publicly joinable") local_attestation = yield self._add_user(group_id, requester_user_id, content) - defer.returnValue({ - "state": "join", - "attestation": local_attestation, - }) + defer.returnValue({"state": "join", "attestation": local_attestation}) @defer.inlineCallbacks def knock(self, group_id, requester_user_id, content): @@ -800,9 +755,7 @@ def remove_user_from_group(self, group_id, user_id, requester_user_id, content): is_kick = True - yield self.store.remove_user_from_group( - group_id, user_id, - ) + yield self.store.remove_user_from_group(group_id, user_id) if is_kick: if self.hs.is_mine_id(user_id): @@ -830,19 +783,20 @@ def create_group(self, group_id, requester_user_id, content): if group: raise SynapseError(400, "Group already exists") - is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id)) + is_admin = yield self.auth.is_server_admin( + UserID.from_string(requester_user_id) + ) if not is_admin: if not self.hs.config.enable_group_creation: raise SynapseError( - 403, "Only a server admin can create groups on this server", + 403, "Only a server admin can create groups on this server" ) localpart = group_id_obj.localpart if not localpart.startswith(self.hs.config.group_creation_prefix): raise SynapseError( 400, - "Can only create groups with prefix %r on this server" % ( - self.hs.config.group_creation_prefix, - ), + "Can only create groups with prefix %r on this server" + % (self.hs.config.group_creation_prefix,), ) profile = content.get("profile", {}) @@ -865,21 +819,19 @@ def create_group(self, group_id, requester_user_id, content): remote_attestation = content["attestation"] yield self.attestations.verify_attestation( - remote_attestation, - user_id=requester_user_id, - group_id=group_id, + remote_attestation, user_id=requester_user_id, group_id=group_id ) local_attestation = self.attestations.create_attestation( - group_id, - requester_user_id, + group_id, requester_user_id ) else: local_attestation = None remote_attestation = None yield self.store.add_user_to_group( - group_id, requester_user_id, + group_id, + requester_user_id, is_admin=True, is_public=True, # TODO local_attestation=local_attestation, @@ -893,9 +845,7 @@ def create_group(self, group_id, requester_user_id, content): avatar_url=user_profile.get("avatar_url"), ) - defer.returnValue({ - "group_id": group_id, - }) + defer.returnValue({"group_id": group_id}) @defer.inlineCallbacks def delete_group(self, group_id, requester_user_id): @@ -911,29 +861,22 @@ def delete_group(self, group_id, requester_user_id): Deferred """ - yield self.check_group_is_ours( - group_id, requester_user_id, - and_exists=True, - ) + yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) # Only server admins or group admins can delete groups. - is_admin = yield self.store.is_user_admin_in_group( - group_id, requester_user_id - ) + is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id) if not is_admin: is_admin = yield self.auth.is_server_admin( - UserID.from_string(requester_user_id), + UserID.from_string(requester_user_id) ) if not is_admin: raise SynapseError(403, "User is not an admin") # Before deleting the group lets kick everyone out of it - users = yield self.store.get_users_in_group( - group_id, include_private=True, - ) + users = yield self.store.get_users_in_group(group_id, include_private=True) @defer.inlineCallbacks def _kick_user_from_group(user_id): @@ -989,9 +932,7 @@ def _parse_join_policy_dict(join_policy_dict): return "invite" if join_policy_type not in ("invite", "open"): - raise SynapseError( - 400, "Synapse only supports 'invite'/'open' join rule" - ) + raise SynapseError(400, "Synapse only supports 'invite'/'open' join rule") return join_policy_type @@ -1018,7 +959,5 @@ def _parse_visibility_dict(visibility): return True if vis_type not in ("public", "private"): - raise SynapseError( - 400, "Synapse only supports 'public'/'private' visibility" - ) + raise SynapseError(400, "Synapse only supports 'public'/'private' visibility") return vis_type == "public" diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index dca337ec61f3..c29c78bd653f 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -94,14 +94,15 @@ def ratelimit(self, requester, update=True): burst_count = self.hs.config.rc_message.burst_count allowed, time_allowed = self.ratelimiter.can_do_action( - user_id, time_now, + user_id, + time_now, rate_hz=messages_per_second, burst_count=burst_count, update=update, ) if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), + retry_after_ms=int(1000 * (time_allowed - time_now)) ) @defer.inlineCallbacks @@ -139,7 +140,7 @@ def kick_guest_users(self, current_state): if member_event.content["membership"] not in { Membership.JOIN, - Membership.INVITE + Membership.INVITE, }: continue @@ -156,8 +157,7 @@ def kick_guest_users(self, current_state): # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - requester = synapse.types.create_requester( - target_user, is_guest=True) + requester = synapse.types.create_requester(target_user, is_guest=True) handler = self.hs.get_room_member_handler() yield handler.update_membership( requester, diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 7fa5d44d2958..e62e6cab7702 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -20,7 +20,7 @@ class AccountDataEventSource(object): def __init__(self, hs): self.store = hs.get_datastore() - def get_current_key(self, direction='f'): + def get_current_key(self, direction="f"): return self.store.get_max_account_data_stream_id() @defer.inlineCallbacks @@ -34,29 +34,22 @@ def get_new_events(self, user, from_key, **kwargs): tags = yield self.store.get_updated_tags(user_id, last_stream_id) for room_id, room_tags in tags.items(): - results.append({ - "type": "m.tag", - "content": {"tags": room_tags}, - "room_id": room_id, - }) + results.append( + {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id} + ) account_data, room_account_data = ( yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) ) for account_data_type, content in account_data.items(): - results.append({ - "type": account_data_type, - "content": content, - }) + results.append({"type": account_data_type, "content": content}) for room_id, account_data in room_account_data.items(): for account_data_type, content in account_data.items(): - results.append({ - "type": account_data_type, - "content": content, - "room_id": room_id, - }) + results.append( + {"type": account_data_type, "content": content, "room_id": room_id} + ) defer.returnValue((results, current_stream_id)) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 5e0b92eb1cc2..0719da3ab7e0 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -49,12 +49,10 @@ def __init__(self, hs): app_name = self.hs.config.email_app_name self._subject = self._account_validity.renew_email_subject % { - "app": app_name, + "app": app_name } - self._from_string = self.hs.config.email_notif_from % { - "app": app_name, - } + self._from_string = self.hs.config.email_notif_from % {"app": app_name} except Exception: # If substitution failed, fall back to the bare strings. self._subject = self._account_validity.renew_email_subject @@ -69,10 +67,7 @@ def __init__(self, hs): ) # Check the renewal emails to send and send them every 30min. - self.clock.looping_call( - self.send_renewal_emails, - 30 * 60 * 1000, - ) + self.clock.looping_call(self.send_renewal_emails, 30 * 60 * 1000) @defer.inlineCallbacks def send_renewal_emails(self): @@ -86,8 +81,7 @@ def send_renewal_emails(self): if expiring_users: for user in expiring_users: yield self._send_renewal_email( - user_id=user["user_id"], - expiration_ts=user["expiration_ts_ms"], + user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) @defer.inlineCallbacks @@ -146,32 +140,33 @@ def _send_renewal_email(self, user_id, expiration_ts): for address in addresses: raw_to = email.utils.parseaddr(address)[1] - multipart_msg = MIMEMultipart('alternative') - multipart_msg['Subject'] = self._subject - multipart_msg['From'] = self._from_string - multipart_msg['To'] = address - multipart_msg['Date'] = email.utils.formatdate() - multipart_msg['Message-ID'] = email.utils.make_msgid() + multipart_msg = MIMEMultipart("alternative") + multipart_msg["Subject"] = self._subject + multipart_msg["From"] = self._from_string + multipart_msg["To"] = address + multipart_msg["Date"] = email.utils.formatdate() + multipart_msg["Message-ID"] = email.utils.make_msgid() multipart_msg.attach(text_part) multipart_msg.attach(html_part) logger.info("Sending renewal email to %s", address) - yield make_deferred_yieldable(self.sendmail( - self.hs.config.email_smtp_host, - self._raw_from, raw_to, multipart_msg.as_string().encode('utf8'), - reactor=self.hs.get_reactor(), - port=self.hs.config.email_smtp_port, - requireAuthentication=self.hs.config.email_smtp_user is not None, - username=self.hs.config.email_smtp_user, - password=self.hs.config.email_smtp_pass, - requireTransportSecurity=self.hs.config.require_transport_security - )) - - yield self.store.set_renewal_mail_status( - user_id=user_id, - email_sent=True, - ) + yield make_deferred_yieldable( + self.sendmail( + self.hs.config.email_smtp_host, + self._raw_from, + raw_to, + multipart_msg.as_string().encode("utf8"), + reactor=self.hs.get_reactor(), + port=self.hs.config.email_smtp_port, + requireAuthentication=self.hs.config.email_smtp_user is not None, + username=self.hs.config.email_smtp_user, + password=self.hs.config.email_smtp_pass, + requireTransportSecurity=self.hs.config.require_transport_security, + ) + ) + + yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) @defer.inlineCallbacks def _get_email_addresses_for_user(self, user_id): @@ -248,9 +243,7 @@ def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False): expiration_ts = self.clock.time_msec() + self._account_validity.period yield self.store.set_account_validity_for_user( - user_id=user_id, - expiration_ts=expiration_ts, - email_sent=email_sent, + user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent ) defer.returnValue(expiration_ts) diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index 813777bf1812..01e0ef408d2c 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -93,24 +93,20 @@ def start_listening(self): ) well_known = Resource() - well_known.putChild(b'acme-challenge', responder.resource) + well_known.putChild(b"acme-challenge", responder.resource) responder_resource = Resource() - responder_resource.putChild(b'.well-known', well_known) - responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain')) + responder_resource.putChild(b".well-known", well_known) + responder_resource.putChild(b"check", static.Data(b"OK", b"text/plain")) srv = server.Site(responder_resource) bind_addresses = self.hs.config.acme_bind_addresses for host in bind_addresses: logger.info( - "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port, + "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port ) try: - self.reactor.listenTCP( - self.hs.config.acme_port, - srv, - interface=host, - ) + self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host) except twisted.internet.error.CannotListenError as e: check_bind_error(e, host, bind_addresses) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 5d629126fcd5..941ebfa1072b 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -23,7 +23,6 @@ class AdminHandler(BaseHandler): - def __init__(self, hs): super(AdminHandler, self).__init__(hs) @@ -33,23 +32,17 @@ def get_whois(self, user): sessions = yield self.store.get_user_ip_and_agents(user) for session in sessions: - connections.append({ - "ip": session["ip"], - "last_seen": session["last_seen"], - "user_agent": session["user_agent"], - }) + connections.append( + { + "ip": session["ip"], + "last_seen": session["last_seen"], + "user_agent": session["user_agent"], + } + ) ret = { "user_id": user.to_string(), - "devices": { - "": { - "sessions": [ - { - "connections": connections, - } - ] - }, - }, + "devices": {"": {"sessions": [{"connections": connections}]}}, } defer.returnValue(ret) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 17eedf4dbf92..5cc89d43f6a6 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -38,7 +38,6 @@ class ApplicationServicesHandler(object): - def __init__(self, hs): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id @@ -101,9 +100,10 @@ def handle_event(event): yield self._check_user_exists(event.state_key) if not self.started_scheduler: + def start_scheduler(): return self.scheduler.start().addErrback( - log_failure, "Application Services Failure", + log_failure, "Application Services Failure" ) run_as_background_process("as_scheduler", start_scheduler) @@ -118,10 +118,15 @@ def handle_room_events(events): for event in events: yield handle_event(event) - yield make_deferred_yieldable(defer.gatherResults([ - run_in_background(handle_room_events, evs) - for evs in itervalues(events_by_room) - ], consumeErrors=True)) + yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(handle_room_events, evs) + for evs in itervalues(events_by_room) + ], + consumeErrors=True, + ) + ) yield self.store.set_appservice_last_pos(upper_bound) @@ -129,20 +134,23 @@ def handle_room_events(events): ts = yield self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_positions.labels( - "appservice_sender").set(upper_bound) + "appservice_sender" + ).set(upper_bound) events_processed_counter.inc(len(events)) - event_processing_loop_room_count.labels( - "appservice_sender" - ).inc(len(events_by_room)) + event_processing_loop_room_count.labels("appservice_sender").inc( + len(events_by_room) + ) event_processing_loop_counter.labels("appservice_sender").inc() synapse.metrics.event_processing_lag.labels( - "appservice_sender").set(now - ts) + "appservice_sender" + ).set(now - ts) synapse.metrics.event_processing_last_ts.labels( - "appservice_sender").set(ts) + "appservice_sender" + ).set(ts) finally: self.is_processing = False @@ -155,13 +163,9 @@ def query_user_exists(self, user_id): Returns: True if this user exists on at least one application service. """ - user_query_services = yield self._get_services_for_user( - user_id=user_id - ) + user_query_services = yield self._get_services_for_user(user_id=user_id) for user_service in user_query_services: - is_known_user = yield self.appservice_api.query_user( - user_service, user_id - ) + is_known_user = yield self.appservice_api.query_user(user_service, user_id) if is_known_user: defer.returnValue(True) defer.returnValue(False) @@ -179,9 +183,7 @@ def query_room_alias_exists(self, room_alias): room_alias_str = room_alias.to_string() services = self.store.get_app_services() alias_query_services = [ - s for s in services if ( - s.is_interested_in_alias(room_alias_str) - ) + s for s in services if (s.is_interested_in_alias(room_alias_str)) ] for alias_service in alias_query_services: is_known_alias = yield self.appservice_api.query_alias( @@ -189,22 +191,24 @@ def query_room_alias_exists(self, room_alias): ) if is_known_alias: # the alias exists now so don't query more ASes. - result = yield self.store.get_association_from_room_alias( - room_alias - ) + result = yield self.store.get_association_from_room_alias(room_alias) defer.returnValue(result) @defer.inlineCallbacks def query_3pe(self, kind, protocol, fields): services = yield self._get_services_for_3pn(protocol) - results = yield make_deferred_yieldable(defer.DeferredList([ - run_in_background( - self.appservice_api.query_3pe, - service, kind, protocol, fields, + results = yield make_deferred_yieldable( + defer.DeferredList( + [ + run_in_background( + self.appservice_api.query_3pe, service, kind, protocol, fields + ) + for service in services + ], + consumeErrors=True, ) - for service in services - ], consumeErrors=True)) + ) ret = [] for (success, result) in results: @@ -276,18 +280,12 @@ def _get_services_for_event(self, event): def _get_services_for_user(self, user_id): services = self.store.get_app_services() - interested_list = [ - s for s in services if ( - s.is_interested_in_user(user_id) - ) - ] + interested_list = [s for s in services if (s.is_interested_in_user(user_id))] return defer.succeed(interested_list) def _get_services_for_3pn(self, protocol): services = self.store.get_app_services() - interested_list = [ - s for s in services if s.is_interested_in_protocol(protocol) - ] + interested_list = [s for s in services if s.is_interested_in_protocol(protocol)] return defer.succeed(interested_list) @defer.inlineCallbacks diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a0cf37a9f9e0..97b21c40932b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -134,13 +134,9 @@ def validate_user_via_ui_auth(self, requester, request_body, clientip): """ # build a list of supported flows - flows = [ - [login_type] for login_type in self._supported_login_types - ] + flows = [[login_type] for login_type in self._supported_login_types] - result, params, _ = yield self.check_auth( - flows, request_body, clientip, - ) + result, params, _ = yield self.check_auth(flows, request_body, clientip) # find the completed login type for login_type in self._supported_login_types: @@ -151,9 +147,7 @@ def validate_user_via_ui_auth(self, requester, request_body, clientip): break else: # this can't happen - raise Exception( - "check_auth returned True but no successful login type", - ) + raise Exception("check_auth returned True but no successful login type") # check that the UI auth matched the access token if user_id != requester.user.to_string(): @@ -215,11 +209,11 @@ def check_auth(self, flows, clientdict, clientip, password_servlet=False): authdict = None sid = None - if clientdict and 'auth' in clientdict: - authdict = clientdict['auth'] - del clientdict['auth'] - if 'session' in authdict: - sid = authdict['session'] + if clientdict and "auth" in clientdict: + authdict = clientdict["auth"] + del clientdict["auth"] + if "session" in authdict: + sid = authdict["session"] session = self._get_session_info(sid) if len(clientdict) > 0: @@ -232,27 +226,27 @@ def check_auth(self, flows, clientdict, clientip, password_servlet=False): # on a home server. # Revisit: Assumimg the REST APIs do sensible validation, the data # isn't arbintrary. - session['clientdict'] = clientdict + session["clientdict"] = clientdict self._save_session(session) - elif 'clientdict' in session: - clientdict = session['clientdict'] + elif "clientdict" in session: + clientdict = session["clientdict"] if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session), + self._auth_dict_for_flows(flows, session) ) - if 'creds' not in session: - session['creds'] = {} - creds = session['creds'] + if "creds" not in session: + session["creds"] = {} + creds = session["creds"] # check auth type currently being presented errordict = {} - if 'type' in authdict: - login_type = authdict['type'] + if "type" in authdict: + login_type = authdict["type"] try: result = yield self._check_auth_dict( - authdict, clientip, password_servlet=password_servlet, + authdict, clientip, password_servlet=password_servlet ) if result: creds[login_type] = result @@ -281,16 +275,15 @@ def check_auth(self, flows, clientdict, clientip, password_servlet=False): # and is not sensitive). logger.info( "Auth completed with creds: %r. Client dict has keys: %r", - creds, list(clientdict) + creds, + list(clientdict), ) - defer.returnValue((creds, clientdict, session['id'])) + defer.returnValue((creds, clientdict, session["id"])) ret = self._auth_dict_for_flows(flows, session) - ret['completed'] = list(creds) + ret["completed"] = list(creds) ret.update(errordict) - raise InteractiveAuthIncompleteError( - ret, - ) + raise InteractiveAuthIncompleteError(ret) @defer.inlineCallbacks def add_oob_auth(self, stagetype, authdict, clientip): @@ -300,15 +293,13 @@ def add_oob_auth(self, stagetype, authdict, clientip): """ if stagetype not in self.checkers: raise LoginError(400, "", Codes.MISSING_PARAM) - if 'session' not in authdict: + if "session" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) - sess = self._get_session_info( - authdict['session'] - ) - if 'creds' not in sess: - sess['creds'] = {} - creds = sess['creds'] + sess = self._get_session_info(authdict["session"]) + if "creds" not in sess: + sess["creds"] = {} + creds = sess["creds"] result = yield self.checkers[stagetype](authdict, clientip) if result: @@ -329,10 +320,10 @@ def get_session_id(self, clientdict): not send a session ID, returns None. """ sid = None - if clientdict and 'auth' in clientdict: - authdict = clientdict['auth'] - if 'session' in authdict: - sid = authdict['session'] + if clientdict and "auth" in clientdict: + authdict = clientdict["auth"] + if "session" in authdict: + sid = authdict["session"] return sid def set_session_data(self, session_id, key, value): @@ -347,7 +338,7 @@ def set_session_data(self, session_id, key, value): value (any): The data to store """ sess = self._get_session_info(session_id) - sess.setdefault('serverdict', {})[key] = value + sess.setdefault("serverdict", {})[key] = value self._save_session(sess) def get_session_data(self, session_id, key, default=None): @@ -360,7 +351,7 @@ def get_session_data(self, session_id, key, default=None): default (any): Value to return if the key has not been set """ sess = self._get_session_info(session_id) - return sess.setdefault('serverdict', {}).get(key, default) + return sess.setdefault("serverdict", {}).get(key, default) @defer.inlineCallbacks def _check_auth_dict(self, authdict, clientip, password_servlet=False): @@ -378,15 +369,13 @@ def _check_auth_dict(self, authdict, clientip, password_servlet=False): SynapseError if there was a problem with the request LoginError if there was an authentication problem. """ - login_type = authdict['type'] + login_type = authdict["type"] checker = self.checkers.get(login_type) if checker is not None: # XXX: Temporary workaround for having Synapse handle password resets # See AuthHandler.check_auth for further details res = yield checker( - authdict, - clientip=clientip, - password_servlet=password_servlet, + authdict, clientip=clientip, password_servlet=password_servlet ) defer.returnValue(res) @@ -408,13 +397,11 @@ def _check_recaptcha(self, authdict, clientip, **kwargs): # Client tried to provide captcha but didn't give the parameter: # bad request. raise LoginError( - 400, "Captcha response is required", - errcode=Codes.CAPTCHA_NEEDED + 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED ) logger.info( - "Submitting recaptcha response %s with remoteip %s", - user_response, clientip + "Submitting recaptcha response %s with remoteip %s", user_response, clientip ) # TODO: get this from the homeserver rather than creating a new one for @@ -424,34 +411,34 @@ def _check_recaptcha(self, authdict, clientip, **kwargs): resp_body = yield client.post_urlencoded_get_json( self.hs.config.recaptcha_siteverify_api, args={ - 'secret': self.hs.config.recaptcha_private_key, - 'response': user_response, - 'remoteip': clientip, - } + "secret": self.hs.config.recaptcha_private_key, + "response": user_response, + "remoteip": clientip, + }, ) except PartialDownloadError as pde: # Twisted is silly data = pde.response resp_body = json.loads(data) - if 'success' in resp_body: + if "success" in resp_body: # Note that we do NOT check the hostname here: we explicitly # intend the CAPTCHA to be presented by whatever client the # user is using, we just care that they have completed a CAPTCHA. logger.info( "%s reCAPTCHA from hostname %s", - "Successful" if resp_body['success'] else "Failed", - resp_body.get('hostname') + "Successful" if resp_body["success"] else "Failed", + resp_body.get("hostname"), ) - if resp_body['success']: + if resp_body["success"]: defer.returnValue(True) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) def _check_email_identity(self, authdict, **kwargs): - return self._check_threepid('email', authdict, **kwargs) + return self._check_threepid("email", authdict, **kwargs) def _check_msisdn(self, authdict, **kwargs): - return self._check_threepid('msisdn', authdict) + return self._check_threepid("msisdn", authdict) def _check_dummy_auth(self, authdict, **kwargs): return defer.succeed(True) @@ -461,10 +448,10 @@ def _check_terms_auth(self, authdict, **kwargs): @defer.inlineCallbacks def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs): - if 'threepid_creds' not in authdict: + if "threepid_creds" not in authdict: raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) - threepid_creds = authdict['threepid_creds'] + threepid_creds = authdict["threepid_creds"] identity_handler = self.hs.get_handlers().identity_handler @@ -482,31 +469,36 @@ def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs): validated=True, ) - threepid = { - "medium": row["medium"], - "address": row["address"], - "validated_at": row["validated_at"], - } if row else None + threepid = ( + { + "medium": row["medium"], + "address": row["address"], + "validated_at": row["validated_at"], + } + if row + else None + ) if row: # Valid threepid returned, delete from the db yield self.store.delete_threepid_session(threepid_creds["sid"]) else: - raise SynapseError(400, "Password resets are not enabled on this homeserver") + raise SynapseError( + 400, "Password resets are not enabled on this homeserver" + ) if not threepid: raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) - if threepid['medium'] != medium: + if threepid["medium"] != medium: raise LoginError( 401, - "Expecting threepid of type '%s', got '%s'" % ( - medium, threepid['medium'], - ), - errcode=Codes.UNAUTHORIZED + "Expecting threepid of type '%s', got '%s'" + % (medium, threepid["medium"]), + errcode=Codes.UNAUTHORIZED, ) - threepid['threepid_creds'] = authdict['threepid_creds'] + threepid["threepid_creds"] = authdict["threepid_creds"] defer.returnValue(threepid) @@ -520,13 +512,14 @@ def _get_params_terms(self): "version": self.hs.config.user_consent_version, "en": { "name": self.hs.config.user_consent_policy_name, - "url": "%s_matrix/consent?v=%s" % ( + "url": "%s_matrix/consent?v=%s" + % ( self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), }, - }, - }, + } + } } def _auth_dict_for_flows(self, flows, session): @@ -547,9 +540,9 @@ def _auth_dict_for_flows(self, flows, session): params[stage] = get_params[stage]() return { - "session": session['id'], + "session": session["id"], "flows": [{"stages": f} for f in public_flows], - "params": params + "params": params, } def _get_session_info(self, session_id): @@ -560,9 +553,7 @@ def _get_session_info(self, session_id): # create a new session while session_id is None or session_id in self.sessions: session_id = stringutils.random_string(24) - self.sessions[session_id] = { - "id": session_id, - } + self.sessions[session_id] = {"id": session_id} return self.sessions[session_id] @@ -652,7 +643,8 @@ def _find_user_id_and_pwd_hash(self, user_id): logger.warn( "Attempted to login as %s but it matches more than one user " "inexactly: %r", - user_id, user_infos.keys() + user_id, + user_infos.keys(), ) defer.returnValue(result) @@ -690,12 +682,10 @@ def validate_login(self, username, login_submission): user is too high too proceed. """ - if username.startswith('@'): + if username.startswith("@"): qualified_user_id = username else: - qualified_user_id = UserID( - username, self.hs.hostname - ).to_string() + qualified_user_id = UserID(username, self.hs.hostname).to_string() self.ratelimit_login_per_account(qualified_user_id) @@ -713,17 +703,15 @@ def validate_login(self, username, login_submission): raise SynapseError(400, "Missing parameter: password") for provider in self.password_providers: - if (hasattr(provider, "check_password") - and login_type == LoginType.PASSWORD): + if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: known_login_type = True - is_valid = yield provider.check_password( - qualified_user_id, password, - ) + is_valid = yield provider.check_password(qualified_user_id, password) if is_valid: defer.returnValue((qualified_user_id, None)) - if (not hasattr(provider, "get_supported_login_types") - or not hasattr(provider, "check_auth")): + if not hasattr(provider, "get_supported_login_types") or not hasattr( + provider, "check_auth" + ): # this password provider doesn't understand custom login types continue @@ -744,15 +732,12 @@ def validate_login(self, username, login_submission): login_dict[f] = login_submission[f] if missing_fields: raise SynapseError( - 400, "Missing parameters for login type %s: %s" % ( - login_type, - missing_fields, - ), + 400, + "Missing parameters for login type %s: %s" + % (login_type, missing_fields), ) - result = yield provider.check_auth( - username, login_type, login_dict, - ) + result = yield provider.check_auth(username, login_type, login_dict) if result: if isinstance(result, str): result = (result, None) @@ -762,7 +747,7 @@ def validate_login(self, username, login_submission): known_login_type = True canonical_user_id = yield self._check_local_password( - qualified_user_id, password, + qualified_user_id, password ) if canonical_user_id: @@ -773,7 +758,8 @@ def validate_login(self, username, login_submission): # unknown username or invalid password. self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), time_now_s=self._clock.time(), + qualified_user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=True, @@ -781,10 +767,7 @@ def validate_login(self, username, login_submission): # We raise a 403 here, but note that if we're doing user-interactive # login, it turns all LoginErrors into a 401 anyway. - raise LoginError( - 403, "Invalid password", - errcode=Codes.FORBIDDEN - ) + raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks def check_password_provider_3pid(self, medium, address, password): @@ -810,9 +793,7 @@ def check_password_provider_3pid(self, medium, address, password): # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = yield provider.check_3pid_auth( - medium, address, password, - ) + result = yield provider.check_3pid_auth(medium, address, password) if result: # Check if the return value is a str or a tuple if isinstance(result, str): @@ -853,8 +834,7 @@ def _check_local_password(self, user_id, password): @defer.inlineCallbacks def issue_access_token(self, user_id, device_id=None): access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token, - device_id) + yield self.store.add_access_token_to_user(user_id, access_token, device_id) defer.returnValue(access_token) @defer.inlineCallbacks @@ -896,12 +876,13 @@ def delete_access_token(self, access_token): # delete pushers associated with this access token if user_info["token_id"] is not None: yield self.hs.get_pusherpool().remove_pushers_by_access_token( - str(user_info["user"]), (user_info["token_id"], ) + str(user_info["user"]), (user_info["token_id"],) ) @defer.inlineCallbacks - def delete_access_tokens_for_user(self, user_id, except_token_id=None, - device_id=None): + def delete_access_tokens_for_user( + self, user_id, except_token_id=None, device_id=None + ): """Invalidate access tokens belonging to a user Args: @@ -915,7 +896,7 @@ def delete_access_tokens_for_user(self, user_id, except_token_id=None, Deferred """ tokens_and_devices = yield self.store.user_delete_access_tokens( - user_id, except_token_id=except_token_id, device_id=device_id, + user_id, except_token_id=except_token_id, device_id=device_id ) # see if any of our auth providers want to know about this @@ -923,14 +904,12 @@ def delete_access_tokens_for_user(self, user_id, except_token_id=None, if hasattr(provider, "on_logged_out"): for token, token_id, device_id in tokens_and_devices: yield provider.on_logged_out( - user_id=user_id, - device_id=device_id, - access_token=token, + user_id=user_id, device_id=device_id, access_token=token ) # delete pushers associated with the access tokens yield self.hs.get_pusherpool().remove_pushers_by_access_token( - user_id, (token_id for _, token_id, _ in tokens_and_devices), + user_id, (token_id for _, token_id, _ in tokens_and_devices) ) @defer.inlineCallbacks @@ -944,12 +923,11 @@ def add_threepid(self, user_id, medium, address, validated_at): # of specific types of threepid (and fixes the fact that checking # for the presence of an email address during password reset was # case sensitive). - if medium == 'email': + if medium == "email": address = address.lower() yield self.store.user_add_threepid( - user_id, medium, address, validated_at, - self.hs.get_clock().time_msec() + user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) @defer.inlineCallbacks @@ -973,22 +951,15 @@ def delete_threepid(self, user_id, medium, address, id_server=None): """ # 'Canonicalise' email addresses as per above - if medium == 'email': + if medium == "email": address = address.lower() identity_handler = self.hs.get_handlers().identity_handler result = yield identity_handler.try_unbind_threepid( - user_id, - { - 'medium': medium, - 'address': address, - 'id_server': id_server, - }, + user_id, {"medium": medium, "address": address, "id_server": id_server} ) - yield self.store.user_delete_threepid( - user_id, medium, address, - ) + yield self.store.user_delete_threepid(user_id, medium, address) defer.returnValue(result) def _save_session(self, session): @@ -1006,14 +977,15 @@ def hash(self, password): Returns: Deferred(unicode): Hashed password. """ + def _do_hash(): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.hashpw( - pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), + pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), bcrypt.gensalt(self.bcrypt_rounds), - ).decode('ascii') + ).decode("ascii") return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash) @@ -1027,18 +999,19 @@ def validate_hash(self, password, stored_hash): Returns: Deferred(bool): Whether self.hash(password) == stored_hash. """ + def _do_validate_hash(): # Normalise the Unicode in the password pw = unicodedata.normalize("NFKC", password) return bcrypt.checkpw( - pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"), - stored_hash + pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), + stored_hash, ) if stored_hash: if not isinstance(stored_hash, bytes): - stored_hash = stored_hash.encode('ascii') + stored_hash = stored_hash.encode("ascii") return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: @@ -1058,14 +1031,16 @@ def ratelimit_login_per_account(self, user_id): for this user is too high too proceed. """ self._failed_attempts_ratelimiter.ratelimit( - user_id.lower(), time_now_s=self._clock.time(), + user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=False, ) self._account_ratelimiter.ratelimit( - user_id.lower(), time_now_s=self._clock.time(), + user_id.lower(), + time_now_s=self._clock.time(), rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, update=True, @@ -1083,9 +1058,9 @@ def generate_access_token(self, user_id, extra_caveats=None): macaroon.add_first_party_caveat("type = access") # Include a nonce, to make sure that each login gets a different # access token. - macaroon.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) + macaroon.add_first_party_caveat( + "nonce = %s" % (stringutils.random_string_with_symbols(16),) + ) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) return macaroon.serialize() @@ -1116,7 +1091,8 @@ def _generate_base_macaroon(self, user_id): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key, + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 7378b56c1dd0..e8f9da609893 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -28,6 +28,7 @@ class DeactivateAccountHandler(BaseHandler): """Handler which deals with deactivating user accounts.""" + def __init__(self, hs): super(DeactivateAccountHandler, self).__init__(hs) self._auth_handler = hs.get_auth_handler() @@ -78,9 +79,9 @@ def deactivate_account(self, user_id, erase_data, id_server=None): result = yield self._identity_handler.try_unbind_threepid( user_id, { - 'medium': threepid['medium'], - 'address': threepid['address'], - 'id_server': id_server, + "medium": threepid["medium"], + "address": threepid["address"], + "id_server": id_server, }, ) identity_server_supports_unbinding &= result @@ -89,7 +90,7 @@ def deactivate_account(self, user_id, erase_data, id_server=None): logger.exception("Failed to remove threepid from ID server") raise SynapseError(400, "Failed to remove threepid from ID server") yield self.store.user_delete_threepid( - user_id, threepid['medium'], threepid['address'], + user_id, threepid["medium"], threepid["address"] ) # delete any devices belonging to the user, which will also @@ -183,5 +184,6 @@ def _part_user(self, user_id): except Exception: logger.exception( "Failed to part user %r from room %r: ignoring and continuing", - user_id, room_id, + user_id, + room_id, ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index b398848079ef..f59d0479b523 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -58,9 +58,7 @@ def get_devices_by_user(self, user_id): device_map = yield self.store.get_devices_by_user(user_id) - ips = yield self.store.get_last_client_ip_by_device( - user_id, device_id=None - ) + ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None) devices = list(device_map.values()) for device in devices: @@ -85,9 +83,7 @@ def get_device(self, user_id, device_id): device = yield self.store.get_device(user_id, device_id) except errors.StoreError: raise errors.NotFoundError - ips = yield self.store.get_last_client_ip_by_device( - user_id, device_id, - ) + ips = yield self.store.get_last_client_ip_by_device(user_id, device_id) _update_device_from_client_ips(device, ips) defer.returnValue(device) @@ -114,13 +110,11 @@ def get_user_ids_changed(self, user_id, from_token): rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) member_events = yield self.store.get_membership_changes_for_user( - user_id, from_token.room_key, now_room_key, + user_id, from_token.room_key, now_room_key ) rooms_changed.update(event.room_id for event in member_events) - stream_ordering = RoomStreamToken.parse_stream_token( - from_token.room_key - ).stream + stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream possibly_changed = set(changed) possibly_left = set() @@ -206,10 +200,9 @@ def get_user_ids_changed(self, user_id, from_token): possibly_joined = [] possibly_left = [] - defer.returnValue({ - "changed": list(possibly_joined), - "left": list(possibly_left), - }) + defer.returnValue( + {"changed": list(possibly_joined), "left": list(possibly_left)} + ) class DeviceHandler(DeviceWorkerHandler): @@ -223,17 +216,18 @@ def __init__(self, hs): federation_registry = hs.get_federation_registry() federation_registry.register_edu_handler( - "m.device_list_update", self._edu_updater.incoming_device_list_update, + "m.device_list_update", self._edu_updater.incoming_device_list_update ) federation_registry.register_query_handler( - "user_devices", self.on_federation_query_user_devices, + "user_devices", self.on_federation_query_user_devices ) hs.get_distributor().observe("user_left_room", self.user_left_room) @defer.inlineCallbacks - def check_device_registered(self, user_id, device_id, - initial_device_display_name=None): + def check_device_registered( + self, user_id, device_id, initial_device_display_name=None + ): """ If the given device has not been registered, register it with the supplied display name. @@ -297,12 +291,10 @@ def delete_device(self, user_id, device_id): raise yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id, + user_id, device_id=device_id ) - yield self.store.delete_e2e_keys_by_device( - user_id=user_id, device_id=device_id - ) + yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) yield self.notify_device_update(user_id, [device_id]) @@ -349,7 +341,7 @@ def delete_devices(self, user_id, device_ids): # considered as part of a critical path. for device_id in device_ids: yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id, + user_id, device_id=device_id ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id @@ -372,9 +364,7 @@ def update_device(self, user_id, device_id, content): try: yield self.store.update_device( - user_id, - device_id, - new_display_name=content.get("display_name") + user_id, device_id, new_display_name=content.get("display_name") ) yield self.notify_device_update(user_id, [device_id]) except errors.StoreError as e: @@ -404,29 +394,26 @@ def notify_device_update(self, user_id, device_ids): for device_id in device_ids: logger.debug( - "Notifying about update %r/%r, ID: %r", user_id, device_id, - position, + "Notifying about update %r/%r, ID: %r", user_id, device_id, position ) room_ids = yield self.store.get_rooms_for_user(user_id) - yield self.notifier.on_new_event( - "device_list_key", position, rooms=room_ids, - ) + yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids) if hosts: - logger.info("Sending device list update notif for %r to: %r", user_id, hosts) + logger.info( + "Sending device list update notif for %r to: %r", user_id, hosts + ) for host in hosts: self.federation_sender.send_device_messages(host) @defer.inlineCallbacks def on_federation_query_user_devices(self, user_id): stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) - defer.returnValue({ - "user_id": user_id, - "stream_id": stream_id, - "devices": devices, - }) + defer.returnValue( + {"user_id": user_id, "stream_id": stream_id, "devices": devices} + ) @defer.inlineCallbacks def user_left_room(self, user, room_id): @@ -440,10 +427,7 @@ def user_left_room(self, user, room_id): def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) - device.update({ - "last_seen_ts": ip.get("last_seen"), - "last_seen_ip": ip.get("ip"), - }) + device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) class DeviceListEduUpdater(object): @@ -481,13 +465,15 @@ def incoming_device_list_update(self, origin, edu_content): device_id = edu_content.pop("device_id") stream_id = str(edu_content.pop("stream_id")) # They may come as ints prev_ids = edu_content.pop("prev_id", []) - prev_ids = [str(p) for p in prev_ids] # They may come as ints + prev_ids = [str(p) for p in prev_ids] # They may come as ints if get_domain_from_id(user_id) != origin: # TODO: Raise? logger.warning( "Got device list update edu for %r/%r from %r", - user_id, device_id, origin, + user_id, + device_id, + origin, ) return @@ -497,13 +483,12 @@ def incoming_device_list_update(self, origin, edu_content): # probably won't get any further updates. logger.warning( "Got device list update edu for %r/%r, but don't share a room", - user_id, device_id, + user_id, + device_id, ) return - logger.debug( - "Received device list update for %r/%r", user_id, device_id, - ) + logger.debug("Received device list update for %r/%r", user_id, device_id) self._pending_updates.setdefault(user_id, []).append( (device_id, stream_id, prev_ids, edu_content) @@ -525,7 +510,10 @@ def _handle_device_updates(self, user_id): for device_id, stream_id, prev_ids, content in pending_updates: logger.debug( "Handling update %r/%r, ID: %r, prev: %r ", - user_id, device_id, stream_id, prev_ids, + user_id, + device_id, + stream_id, + prev_ids, ) # Given a list of updates we check if we need to resync. This @@ -540,13 +528,13 @@ def _handle_device_updates(self, user_id): try: result = yield self.federation.query_user_devices(origin, user_id) except ( - NotRetryingDestination, RequestSendFailed, HttpResponseException, + NotRetryingDestination, + RequestSendFailed, + HttpResponseException, ): # TODO: Remember that we are now out of sync and try again # later - logger.warn( - "Failed to handle device list update for %s", user_id, - ) + logger.warn("Failed to handle device list update for %s", user_id) # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list # is out of date. If we bail then we will retry the resync @@ -582,18 +570,21 @@ def _handle_device_updates(self, user_id): if len(devices) > 1000: logger.warn( "Ignoring device list snapshot for %s as it has >1K devs (%d)", - user_id, len(devices) + user_id, + len(devices), ) devices = [] for device in devices: logger.debug( "Handling resync update %r/%r, ID: %r", - user_id, device["device_id"], stream_id, + user_id, + device["device_id"], + stream_id, ) yield self.store.update_remote_device_list_cache( - user_id, devices, stream_id, + user_id, devices, stream_id ) device_ids = [device["device_id"] for device in devices] yield self.device_handler.notify_device_update(user_id, device_ids) @@ -606,7 +597,7 @@ def _handle_device_updates(self, user_id): # change (because of the single prev_id matching the current cache) for device_id, stream_id, prev_ids, content in pending_updates: yield self.store.update_remote_device_list_cache_entry( - user_id, device_id, content, stream_id, + user_id, device_id, content, stream_id ) yield self.device_handler.notify_device_update( @@ -624,14 +615,9 @@ def _need_to_do_resync(self, user_id, updates): """ seen_updates = self._seen_updates.get(user_id, set()) - extremity = yield self.store.get_device_list_last_stream_id_for_remote( - user_id - ) + extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id) - logger.debug( - "Current extremity for %r: %r", - user_id, extremity, - ) + logger.debug("Current extremity for %r: %r", user_id, extremity) stream_id_in_updates = set() # stream_ids in updates list for _, stream_id, prev_ids, _ in updates: diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 2e2e5261de8e..e1ebb6346c3a 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -25,7 +25,6 @@ class DeviceMessageHandler(object): - def __init__(self, hs): """ Args: @@ -47,15 +46,15 @@ def on_direct_to_device_edu(self, origin, content): if origin != get_domain_from_id(sender_user_id): logger.warn( "Dropping device message from %r with spoofed sender %r", - origin, sender_user_id + origin, + sender_user_id, ) message_type = content["type"] message_id = content["message_id"] for user_id, by_device in content["messages"].items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): - logger.warning("Request for keys for non-local user %s", - user_id) + logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") messages_by_device = { diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index a12f9508d879..42d5b3db30c6 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -36,7 +36,6 @@ class DirectoryHandler(BaseHandler): - def __init__(self, hs): super(DirectoryHandler, self).__init__(hs) @@ -77,15 +76,19 @@ def _create_association(self, room_alias, room_id, servers=None, creator=None): raise SynapseError(400, "Failed to get server list") yield self.store.create_room_alias_association( - room_alias, - room_id, - servers, - creator=creator, + room_alias, room_id, servers, creator=creator ) @defer.inlineCallbacks - def create_association(self, requester, room_alias, room_id, servers=None, - send_event=True, check_membership=True): + def create_association( + self, + requester, + room_alias, + room_id, + servers=None, + send_event=True, + check_membership=True, + ): """Attempt to create a new alias Args: @@ -115,49 +118,40 @@ def create_association(self, requester, room_alias, room_id, servers=None, if service: if not service.is_interested_in_alias(room_alias.to_string()): raise SynapseError( - 400, "This application service has not reserved" - " this kind of alias.", errcode=Codes.EXCLUSIVE + 400, + "This application service has not reserved" " this kind of alias.", + errcode=Codes.EXCLUSIVE, ) else: if self.require_membership and check_membership: rooms_for_user = yield self.store.get_rooms_for_user(user_id) if room_id not in rooms_for_user: raise AuthError( - 403, - "You must be in the room to create an alias for it", + 403, "You must be in the room to create an alias for it" ) if not self.spam_checker.user_may_create_room_alias(user_id, room_alias): - raise AuthError( - 403, "This user is not permitted to create this alias", - ) + raise AuthError(403, "This user is not permitted to create this alias") if not self.config.is_alias_creation_allowed( - user_id, room_id, room_alias.to_string(), + user_id, room_id, room_alias.to_string() ): # Lets just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? - raise SynapseError( - 403, "Not allowed to create alias", - ) + raise SynapseError(403, "Not allowed to create alias") - can_create = yield self.can_modify_alias( - room_alias, - user_id=user_id - ) + can_create = yield self.can_modify_alias(room_alias, user_id=user_id) if not can_create: raise AuthError( - 400, "This alias is reserved by an application service.", - errcode=Codes.EXCLUSIVE + 400, + "This alias is reserved by an application service.", + errcode=Codes.EXCLUSIVE, ) yield self._create_association(room_alias, room_id, servers, creator=user_id) if send_event: - yield self.send_room_alias_update_event( - requester, - room_id - ) + yield self.send_room_alias_update_event(requester, room_id) @defer.inlineCallbacks def delete_association(self, requester, room_alias, send_event=True): @@ -194,34 +188,24 @@ def delete_association(self, requester, room_alias, send_event=True): raise if not can_delete: - raise AuthError( - 403, "You don't have permission to delete the alias.", - ) + raise AuthError(403, "You don't have permission to delete the alias.") - can_delete = yield self.can_modify_alias( - room_alias, - user_id=user_id - ) + can_delete = yield self.can_modify_alias(room_alias, user_id=user_id) if not can_delete: raise SynapseError( - 400, "This alias is reserved by an application service.", - errcode=Codes.EXCLUSIVE + 400, + "This alias is reserved by an application service.", + errcode=Codes.EXCLUSIVE, ) room_id = yield self._delete_association(room_alias) try: if send_event: - yield self.send_room_alias_update_event( - requester, - room_id - ) + yield self.send_room_alias_update_event(requester, room_id) yield self._update_canonical_alias( - requester, - requester.user.to_string(), - room_id, - room_alias, + requester, requester.user.to_string(), room_id, room_alias ) except AuthError as e: logger.info("Failed to update alias events: %s", e) @@ -234,7 +218,7 @@ def delete_appservice_association(self, service, room_alias): raise SynapseError( 400, "This application service has not reserved this kind of alias", - errcode=Codes.EXCLUSIVE + errcode=Codes.EXCLUSIVE, ) yield self._delete_association(room_alias) @@ -251,9 +235,7 @@ def _delete_association(self, room_alias): def get_association(self, room_alias): room_id = None if self.hs.is_mine(room_alias): - result = yield self.get_association_from_room_alias( - room_alias - ) + result = yield self.get_association_from_room_alias(room_alias) if result: room_id = result.room_id @@ -263,9 +245,7 @@ def get_association(self, room_alias): result = yield self.federation.make_query( destination=room_alias.domain, query_type="directory", - args={ - "room_alias": room_alias.to_string(), - }, + args={"room_alias": room_alias.to_string()}, retry_on_dns_fail=False, ignore_backoff=True, ) @@ -284,7 +264,7 @@ def get_association(self, room_alias): raise SynapseError( 404, "Room alias %s not found" % (room_alias.to_string(),), - Codes.NOT_FOUND + Codes.NOT_FOUND, ) users = yield self.state.get_current_users_in_room(room_id) @@ -293,41 +273,28 @@ def get_association(self, room_alias): # If this server is in the list of servers, return it first. if self.server_name in servers: - servers = ( - [self.server_name] + - [s for s in servers if s != self.server_name] - ) + servers = [self.server_name] + [s for s in servers if s != self.server_name] else: servers = list(servers) - defer.returnValue({ - "room_id": room_id, - "servers": servers, - }) + defer.returnValue({"room_id": room_id, "servers": servers}) return @defer.inlineCallbacks def on_directory_query(self, args): room_alias = RoomAlias.from_string(args["room_alias"]) if not self.hs.is_mine(room_alias): - raise SynapseError( - 400, "Room Alias is not hosted on this Home Server" - ) + raise SynapseError(400, "Room Alias is not hosted on this Home Server") - result = yield self.get_association_from_room_alias( - room_alias - ) + result = yield self.get_association_from_room_alias(room_alias) if result is not None: - defer.returnValue({ - "room_id": result.room_id, - "servers": result.servers, - }) + defer.returnValue({"room_id": result.room_id, "servers": result.servers}) else: raise SynapseError( 404, "Room alias %r not found" % (room_alias.to_string(),), - Codes.NOT_FOUND + Codes.NOT_FOUND, ) @defer.inlineCallbacks @@ -343,7 +310,7 @@ def send_room_alias_update_event(self, requester, room_id): "sender": requester.user.to_string(), "content": {"aliases": aliases}, }, - ratelimit=False + ratelimit=False, ) @defer.inlineCallbacks @@ -365,14 +332,12 @@ def _update_canonical_alias(self, requester, user_id, room_id, room_alias): "sender": user_id, "content": {}, }, - ratelimit=False + ratelimit=False, ) @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): - result = yield self.store.get_association_from_room_alias( - room_alias - ) + result = yield self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists as_handler = self.appservice_handler @@ -421,8 +386,7 @@ def edit_published_room_list(self, requester, room_id, visibility): if not self.spam_checker.user_may_publish_room(user_id, room_id): raise AuthError( - 403, - "This user is not permitted to publish rooms to the room list" + 403, "This user is not permitted to publish rooms to the room list" ) if requester.is_guest: @@ -434,8 +398,7 @@ def edit_published_room_list(self, requester, room_id, visibility): if visibility == "public" and not self.enable_room_list_search: # The room list has been disabled. raise AuthError( - 403, - "This user is not permitted to publish rooms to the room list" + 403, "This user is not permitted to publish rooms to the room list" ) room = yield self.store.get_room(room_id) @@ -452,20 +415,19 @@ def edit_published_room_list(self, requester, room_id, visibility): room_aliases.append(canonical_alias) if not self.config.is_publishing_room_allowed( - user_id, room_id, room_aliases, + user_id, room_id, room_aliases ): # Lets just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? - raise SynapseError( - 403, "Not allowed to publish room", - ) + raise SynapseError(403, "Not allowed to publish room") yield self.store.set_room_is_public(room_id, making_public) @defer.inlineCallbacks - def edit_published_appservice_room_list(self, appservice_id, network_id, - room_id, visibility): + def edit_published_appservice_room_list( + self, appservice_id, network_id, room_id, visibility + ): """Add or remove a room from the appservice/network specific public room list. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 9dc46aa15f75..807900fe52ac 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -99,9 +99,7 @@ def query_devices(self, query_body, timeout): query_list.append((user_id, None)) user_ids_not_in_cache, remote_results = ( - yield self.store.get_user_devices_from_cache( - query_list - ) + yield self.store.get_user_devices_from_cache(query_list) ) for user_id, devices in iteritems(remote_results): user_devices = results.setdefault(user_id, {}) @@ -126,9 +124,7 @@ def do_remote_query(destination): destination_query = remote_queries_not_in_cache[destination] try: remote_result = yield self.federation.query_client_keys( - destination, - {"device_keys": destination_query}, - timeout=timeout + destination, {"device_keys": destination_query}, timeout=timeout ) for user_id, keys in remote_result["device_keys"].items(): @@ -138,14 +134,17 @@ def do_remote_query(destination): except Exception as e: failures[destination] = _exception_to_failure(e) - yield make_deferred_yieldable(defer.gatherResults([ - run_in_background(do_remote_query, destination) - for destination in remote_queries_not_in_cache - ], consumeErrors=True)) + yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(do_remote_query, destination) + for destination in remote_queries_not_in_cache + ], + consumeErrors=True, + ) + ) - defer.returnValue({ - "device_keys": results, "failures": failures, - }) + defer.returnValue({"device_keys": results, "failures": failures}) @defer.inlineCallbacks def query_local_devices(self, query): @@ -165,8 +164,7 @@ def query_local_devices(self, query): for user_id, device_ids in query.items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): - logger.warning("Request for keys for non-local user %s", - user_id) + logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") if not device_ids: @@ -231,9 +229,7 @@ def claim_client_keys(destination): device_keys = remote_queries[destination] try: remote_result = yield self.federation.claim_client_keys( - destination, - {"one_time_keys": device_keys}, - timeout=timeout + destination, {"one_time_keys": device_keys}, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: @@ -241,25 +237,29 @@ def claim_client_keys(destination): except Exception as e: failures[destination] = _exception_to_failure(e) - yield make_deferred_yieldable(defer.gatherResults([ - run_in_background(claim_client_keys, destination) - for destination in remote_queries - ], consumeErrors=True)) + yield make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(claim_client_keys, destination) + for destination in remote_queries + ], + consumeErrors=True, + ) + ) logger.info( "Claimed one-time-keys: %s", - ",".join(( - "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in iteritems(json_result) - for device_id, device_keys in iteritems(user_keys) - for key_id, _ in iteritems(device_keys) - )), + ",".join( + ( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in iteritems(json_result) + for device_id, device_keys in iteritems(user_keys) + for key_id, _ in iteritems(device_keys) + ) + ), ) - defer.returnValue({ - "one_time_keys": json_result, - "failures": failures - }) + defer.returnValue({"one_time_keys": json_result, "failures": failures}) @defer.inlineCallbacks def upload_keys_for_user(self, user_id, device_id, keys): @@ -270,11 +270,13 @@ def upload_keys_for_user(self, user_id, device_id, keys): if device_keys: logger.info( "Updating device_keys for device %r for user %s at %d", - device_id, user_id, time_now + device_id, + user_id, + time_now, ) # TODO: Sign the JSON with the server key changed = yield self.store.set_e2e_device_keys( - user_id, device_id, time_now, device_keys, + user_id, device_id, time_now, device_keys ) if changed: # Only notify about device updates *if* the keys actually changed @@ -283,7 +285,7 @@ def upload_keys_for_user(self, user_id, device_id, keys): one_time_keys = keys.get("one_time_keys", None) if one_time_keys: yield self._upload_one_time_keys_for_user( - user_id, device_id, time_now, one_time_keys, + user_id, device_id, time_now, one_time_keys ) # the device should have been registered already, but it may have been @@ -298,20 +300,22 @@ def upload_keys_for_user(self, user_id, device_id, keys): defer.returnValue({"one_time_key_counts": result}) @defer.inlineCallbacks - def _upload_one_time_keys_for_user(self, user_id, device_id, time_now, - one_time_keys): + def _upload_one_time_keys_for_user( + self, user_id, device_id, time_now, one_time_keys + ): logger.info( "Adding one_time_keys %r for device %r for user %r at %d", - one_time_keys.keys(), device_id, user_id, time_now, + one_time_keys.keys(), + device_id, + user_id, + time_now, ) # make a list of (alg, id, key) tuples key_list = [] for key_id, key_obj in one_time_keys.items(): algorithm, key_id = key_id.split(":") - key_list.append(( - algorithm, key_id, key_obj - )) + key_list.append((algorithm, key_id, key_obj)) # First we check if we have already persisted any of the keys. existing_key_map = yield self.store.get_e2e_one_time_keys( @@ -325,42 +329,35 @@ def _upload_one_time_keys_for_user(self, user_id, device_id, time_now, if not _one_time_keys_match(ex_json, key): raise SynapseError( 400, - ("One time key %s:%s already exists. " - "Old key: %s; new key: %r") % - (algorithm, key_id, ex_json, key) + ( + "One time key %s:%s already exists. " + "Old key: %s; new key: %r" + ) + % (algorithm, key_id, ex_json, key), ) else: - new_keys.append(( - algorithm, key_id, encode_canonical_json(key).decode('ascii'))) + new_keys.append( + (algorithm, key_id, encode_canonical_json(key).decode("ascii")) + ) - yield self.store.add_e2e_one_time_keys( - user_id, device_id, time_now, new_keys - ) + yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) def _exception_to_failure(e): if isinstance(e, CodeMessageException): - return { - "status": e.code, "message": str(e), - } + return {"status": e.code, "message": str(e)} if isinstance(e, NotRetryingDestination): - return { - "status": 503, "message": "Not ready for retry", - } + return {"status": 503, "message": "Not ready for retry"} if isinstance(e, FederationDeniedError): - return { - "status": 403, "message": "Federation Denied", - } + return {"status": 403, "message": "Federation Denied"} # include ConnectionRefused and other errors # # Note that some Exceptions (notably twisted's ResponseFailed etc) don't # give a string for e.message, which json then fails to serialize. - return { - "status": 503, "message": str(e), - } + return {"status": 503, "message": str(e)} def _one_time_keys_match(old_key_json, new_key): diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 7bc174070e6c..ebd807bca682 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -152,14 +152,14 @@ def upload_room_keys(self, user_id, version, room_keys): else: raise - if version_info['version'] != version: + if version_info["version"] != version: # Check that the version we're trying to upload actually exists try: version_info = yield self.store.get_e2e_room_keys_version_info( - user_id, version, + user_id, version ) # if we get this far, the version must exist - raise RoomKeysVersionError(current_version=version_info['version']) + raise RoomKeysVersionError(current_version=version_info["version"]) except StoreError as e: if e.code == 404: raise NotFoundError("Version '%s' not found" % (version,)) @@ -168,8 +168,8 @@ def upload_room_keys(self, user_id, version, room_keys): # go through the room_keys. # XXX: this should/could be done concurrently, given we're in a lock. - for room_id, room in iteritems(room_keys['rooms']): - for session_id, session in iteritems(room['sessions']): + for room_id, room in iteritems(room_keys["rooms"]): + for session_id, session in iteritems(room["sessions"]): yield self._upload_room_key( user_id, version, room_id, session_id, session ) @@ -223,14 +223,14 @@ def _should_replace_room_key(current_room_key, room_key): # spelt out with if/elifs rather than nested boolean expressions # purely for legibility. - if room_key['is_verified'] and not current_room_key['is_verified']: + if room_key["is_verified"] and not current_room_key["is_verified"]: return True elif ( - room_key['first_message_index'] < - current_room_key['first_message_index'] + room_key["first_message_index"] + < current_room_key["first_message_index"] ): return True - elif room_key['forwarded_count'] < current_room_key['forwarded_count']: + elif room_key["forwarded_count"] < current_room_key["forwarded_count"]: return True else: return False @@ -328,16 +328,10 @@ def update_version(self, user_id, version, version_info): A deferred of an empty dict. """ if "version" not in version_info: - raise SynapseError( - 400, - "Missing version in body", - Codes.MISSING_PARAM - ) + raise SynapseError(400, "Missing version in body", Codes.MISSING_PARAM) if version_info["version"] != version: raise SynapseError( - 400, - "Version in body does not match", - Codes.INVALID_PARAM + 400, "Version in body does not match", Codes.INVALID_PARAM ) with (yield self._upload_linearizer.queue(user_id)): try: @@ -350,12 +344,10 @@ def update_version(self, user_id, version, version_info): else: raise if old_info["algorithm"] != version_info["algorithm"]: - raise SynapseError( - 400, - "Algorithm does not match", - Codes.INVALID_PARAM - ) + raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM) - yield self.store.update_e2e_room_keys_version(user_id, version, version_info) + yield self.store.update_e2e_room_keys_version( + user_id, version, version_info + ) defer.returnValue({}) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index eb525070cff8..5836d3c639bd 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -31,7 +31,6 @@ class EventStreamHandler(BaseHandler): - def __init__(self, hs): super(EventStreamHandler, self).__init__(hs) @@ -53,9 +52,17 @@ def __init__(self, hs): @defer.inlineCallbacks @log_function - def get_stream(self, auth_user_id, pagin_config, timeout=0, - as_client_event=True, affect_presence=True, - only_keys=None, room_id=None, is_guest=False): + def get_stream( + self, + auth_user_id, + pagin_config, + timeout=0, + as_client_event=True, + affect_presence=True, + only_keys=None, + room_id=None, + is_guest=False, + ): """Fetches the events stream for a given user. If `only_keys` is not None, events from keys will be sent down. @@ -73,7 +80,7 @@ def get_stream(self, auth_user_id, pagin_config, timeout=0, presence_handler = self.hs.get_presence_handler() context = yield presence_handler.user_syncing( - auth_user_id, affect_presence=affect_presence, + auth_user_id, affect_presence=affect_presence ) with context: if timeout: @@ -85,9 +92,12 @@ def get_stream(self, auth_user_id, pagin_config, timeout=0, timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) events, tokens = yield self.notifier.get_events_for( - auth_user, pagin_config, timeout, + auth_user, + pagin_config, + timeout, only_keys=only_keys, - is_guest=is_guest, explicit_room_id=room_id + is_guest=is_guest, + explicit_room_id=room_id, ) # When the user joins a new room, or another user joins a currently @@ -102,17 +112,15 @@ def get_stream(self, auth_user_id, pagin_config, timeout=0, # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.state.get_current_users_in_room(event.room_id) - states = yield presence_handler.get_states( - users, - as_event=True, + users = yield self.state.get_current_users_in_room( + event.room_id ) + states = yield presence_handler.get_states(users, as_event=True) to_add.extend(states) else: ev = yield presence_handler.get_state( - UserID.from_string(event.state_key), - as_event=True, + UserID.from_string(event.state_key), as_event=True ) to_add.append(ev) @@ -121,7 +129,9 @@ def get_stream(self, auth_user_id, pagin_config, timeout=0, time_now = self.clock.time_msec() chunks = yield self._event_serializer.serialize_events( - events, time_now, as_client_event=as_client_event, + events, + time_now, + as_client_event=as_client_event, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. bundle_aggregations=False, @@ -137,7 +147,6 @@ def get_stream(self, auth_user_id, pagin_config, timeout=0, class EventHandler(BaseHandler): - @defer.inlineCallbacks def get_event(self, user, room_id, event_id): """Retrieve a single specified event. @@ -164,16 +173,10 @@ def get_event(self, user, room_id, event_id): is_peeking = user.to_string() not in users filtered = yield filter_events_for_client( - self.store, - user.to_string(), - [event], - is_peeking=is_peeking + self.store, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: - raise AuthError( - 403, - "You don't have permission to access that event." - ) + raise AuthError(403, "You don't have permission to access that event.") defer.returnValue(event) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d5a605d3bd00..02d397c498e4 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -82,7 +82,7 @@ def shortstr(iterable, maxitems=5): items = list(itertools.islice(iterable, maxitems + 1)) if len(items) <= maxitems: return str(items) - return u"[" + u", ".join(repr(r) for r in items[:maxitems]) + u", ...]" + return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" class FederationHandler(BaseHandler): @@ -115,14 +115,14 @@ def __init__(self, hs): self.config = hs.config self.http_client = hs.get_simple_http_client() - self._send_events_to_master = ( - ReplicationFederationSendEventsRestServlet.make_client(hs) + self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( + hs ) - self._notify_user_membership_change = ( - ReplicationUserJoinedLeftRoomRestServlet.make_client(hs) + self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client( + hs ) - self._clean_room_for_join_client = ( - ReplicationCleanRoomRestServlet.make_client(hs) + self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( + hs ) # When joining a room we need to queue any events for that room up @@ -132,9 +132,7 @@ def __init__(self, hs): self.third_party_event_rules = hs.get_third_party_event_rules() @defer.inlineCallbacks - def on_receive_pdu( - self, origin, pdu, sent_to_us_directly=False, - ): + def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -151,26 +149,19 @@ def on_receive_pdu( room_id = pdu.room_id event_id = pdu.event_id - logger.info( - "[%s %s] handling received PDU: %s", - room_id, event_id, pdu, - ) + logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu) # We reprocess pdus when we have seen them only as outliers existing = yield self.store.get_event( - event_id, - allow_none=True, - allow_rejected=True, + event_id, allow_none=True, allow_rejected=True ) # FIXME: Currently we fetch an event again when we already have it # if it has been marked as an outlier. - already_seen = ( - existing and ( - not existing.internal_metadata.is_outlier() - or pdu.internal_metadata.is_outlier() - ) + already_seen = existing and ( + not existing.internal_metadata.is_outlier() + or pdu.internal_metadata.is_outlier() ) if already_seen: logger.debug("[%s %s]: Already seen pdu", room_id, event_id) @@ -182,20 +173,19 @@ def on_receive_pdu( try: self._sanity_check_event(pdu) except SynapseError as err: - logger.warn("[%s %s] Received event failed sanity checks", room_id, event_id) - raise FederationError( - "ERROR", - err.code, - err.msg, - affected=pdu.event_id, + logger.warn( + "[%s %s] Received event failed sanity checks", room_id, event_id ) + raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) # If we are currently in the process of joining this room, then we # queue up events for later processing. if room_id in self.room_queues: logger.info( "[%s %s] Queuing PDU from %s for now: join in progress", - room_id, event_id, origin, + room_id, + event_id, + origin, ) self.room_queues[room_id].append((pdu, origin)) return @@ -206,14 +196,13 @@ def on_receive_pdu( # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = yield self.auth.check_host_in_room( - room_id, - self.server_name - ) + is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( "[%s %s] Ignoring PDU from %s as we're not in the room", - room_id, event_id, origin, + room_id, + event_id, + origin, ) defer.returnValue(None) @@ -223,14 +212,9 @@ def on_receive_pdu( # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. - min_depth = yield self.get_min_depth_for_context( - pdu.room_id - ) + min_depth = yield self.get_min_depth_for_context(pdu.room_id) - logger.debug( - "[%s %s] min_depth: %d", - room_id, event_id, min_depth, - ) + logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) prevs = set(pdu.prev_event_ids()) seen = yield self.store.have_seen_events(prevs) @@ -248,12 +232,17 @@ def on_receive_pdu( # at a time. logger.info( "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s", - room_id, event_id, len(missing_prevs), shortstr(missing_prevs), + room_id, + event_id, + len(missing_prevs), + shortstr(missing_prevs), ) with (yield self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( "[%s %s] Acquired room lock to fetch %d missing prev_events", - room_id, event_id, len(missing_prevs), + room_id, + event_id, + len(missing_prevs), ) yield self._get_missing_events_for_pdu( @@ -267,12 +256,16 @@ def on_receive_pdu( if not prevs - seen: logger.info( "[%s %s] Found all missing prev_events", - room_id, event_id, + room_id, + event_id, ) elif missing_prevs: logger.info( "[%s %s] Not recursively fetching %d missing prev_events: %s", - room_id, event_id, len(missing_prevs), shortstr(missing_prevs), + room_id, + event_id, + len(missing_prevs), + shortstr(missing_prevs), ) if prevs - seen: @@ -303,7 +296,10 @@ def on_receive_pdu( if sent_to_us_directly: logger.warn( "[%s %s] Rejecting: failed to fetch %d prev events: %s", - room_id, event_id, len(prevs - seen), shortstr(prevs - seen) + room_id, + event_id, + len(prevs - seen), + shortstr(prevs - seen), ) raise FederationError( "ERROR", @@ -318,9 +314,7 @@ def on_receive_pdu( # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. auth_chains = set() - event_map = { - event_id: pdu, - } + event_map = {event_id: pdu} try: # Get the state of the events we know about ours = yield self.store.get_state_groups_ids(room_id, seen) @@ -337,7 +331,9 @@ def on_receive_pdu( for p in prevs - seen: logger.info( "[%s %s] Requesting state at missing prev_event %s", - room_id, event_id, p, + room_id, + event_id, + p, ) room_version = yield self.store.get_room_version(room_id) @@ -348,19 +344,19 @@ def on_receive_pdu( # by the get_pdu_cache in federation_client. remote_state, got_auth_chain = ( yield self.federation_client.get_state_for_room( - origin, room_id, p, + origin, room_id, p ) ) # we want the state *after* p; get_state_for_room returns the # state *before* p. remote_event = yield self.federation_client.get_pdu( - [origin], p, room_version, outlier=True, + [origin], p, room_version, outlier=True ) if remote_event is None: raise Exception( - "Unable to get missing prev_event %s" % (p, ) + "Unable to get missing prev_event %s" % (p,) ) if remote_event.is_state(): @@ -380,7 +376,9 @@ def on_receive_pdu( event_map[x.event_id] = x state_map = yield resolve_events_with_store( - room_version, state_maps, event_map, + room_version, + state_maps, + event_map, state_res_store=StateResolutionStore(self.store), ) @@ -396,15 +394,15 @@ def on_receive_pdu( ) event_map.update(evs) - state = [ - event_map[e] for e in six.itervalues(state_map) - ] + state = [event_map[e] for e in six.itervalues(state_map)] auth_chain = list(auth_chains) except Exception: logger.warn( "[%s %s] Error attempting to resolve state at missing " "prev_events", - room_id, event_id, exc_info=True, + room_id, + event_id, + exc_info=True, ) raise FederationError( "ERROR", @@ -414,10 +412,7 @@ def on_receive_pdu( ) yield self._process_received_pdu( - origin, - pdu, - state=state, - auth_chain=auth_chain, + origin, pdu, state=state, auth_chain=auth_chain ) @defer.inlineCallbacks @@ -447,7 +442,10 @@ def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): logger.info( "[%s %s]: Requesting missing events between %s and %s", - room_id, event_id, shortstr(latest), event_id, + room_id, + event_id, + shortstr(latest), + event_id, ) # XXX: we set timeout to 10s to help workaround @@ -512,15 +510,15 @@ def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): # We failed to get the missing events, but since we need to handle # the case of `get_missing_events` not returning the necessary # events anyway, it is safe to simply log the error and continue. - logger.warn( - "[%s %s]: Failed to get prev_events: %s", - room_id, event_id, e, - ) + logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e) return logger.info( "[%s %s]: Got %d prev_events: %s", - room_id, event_id, len(missing_events), shortstr(missing_events), + room_id, + event_id, + len(missing_events), + shortstr(missing_events), ) # We want to sort these by depth so we process them and @@ -530,20 +528,20 @@ def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): for ev in missing_events: logger.info( "[%s %s] Handling received prev_event %s", - room_id, event_id, ev.event_id, + room_id, + event_id, + ev.event_id, ) with logcontext.nested_logging_context(ev.event_id): try: - yield self.on_receive_pdu( - origin, - ev, - sent_to_us_directly=False, - ) + yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: logger.warn( "[%s %s] Received prev_event %s failed history check.", - room_id, event_id, ev.event_id, + room_id, + event_id, + ev.event_id, ) else: raise @@ -556,10 +554,7 @@ def _process_received_pdu(self, origin, event, state, auth_chain): room_id = event.room_id event_id = event.event_id - logger.debug( - "[%s %s] Processing event: %s", - room_id, event_id, event, - ) + logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) event_ids = set() if state: @@ -581,43 +576,32 @@ def _process_received_pdu(self, origin, event, state, auth_chain): e.internal_metadata.outlier = True auth_ids = e.auth_event_ids() auth = { - (e.type, e.state_key): e for e in auth_chain + (e.type, e.state_key): e + for e in auth_chain if e.event_id in auth_ids or e.type == EventTypes.Create } - event_infos.append({ - "event": e, - "auth_events": auth, - }) + event_infos.append({"event": e, "auth_events": auth}) seen_ids.add(e.event_id) logger.info( "[%s %s] persisting newly-received auth/state events %s", - room_id, event_id, [e["event"].event_id for e in event_infos] + room_id, + event_id, + [e["event"].event_id for e in event_infos], ) yield self._handle_new_events(origin, event_infos) try: - context = yield self._handle_new_event( - origin, - event, - state=state, - ) + context = yield self._handle_new_event(origin, event, state=state) except AuthError as e: - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=event.event_id, - ) + raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) room = yield self.store.get_room(room_id) if not room: try: yield self.store.store_room( - room_id=room_id, - room_creator_user_id="", - is_public=False, + room_id=room_id, room_creator_user_id="", is_public=False ) except StoreError: logger.exception("Failed to store room.") @@ -631,12 +615,10 @@ def _process_received_pdu(self, origin, event, state, auth_chain): prev_state_ids = yield context.get_prev_state_ids(self.store) - prev_state_id = prev_state_ids.get( - (event.type, event.state_key) - ) + prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: prev_state = yield self.store.get_event( - prev_state_id, allow_none=True, + prev_state_id, allow_none=True ) if prev_state and prev_state.membership == Membership.JOIN: newly_joined = False @@ -667,10 +649,7 @@ def backfill(self, dest, room_id, limit, extremities): room_version = yield self.store.get_room_version(room_id) events = yield self.federation_client.backfill( - dest, - room_id, - limit=limit, - extremities=extremities, + dest, room_id, limit=limit, extremities=extremities ) # ideally we'd sanity check the events here for excess prev_events etc, @@ -697,16 +676,9 @@ def backfill(self, dest, room_id, limit, extremities): event_ids = set(e.event_id for e in events) - edges = [ - ev.event_id - for ev in events - if set(ev.prev_event_ids()) - event_ids - ] + edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] - logger.info( - "backfill: Got %d events with %d edges", - len(events), len(edges), - ) + logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) # For each edge get the current state. @@ -715,9 +687,7 @@ def backfill(self, dest, room_id, limit, extremities): events_to_state = {} for e_id in edges: state, auth = yield self.federation_client.get_state_for_room( - destination=dest, - room_id=room_id, - event_id=e_id + destination=dest, room_id=room_id, event_id=e_id ) auth_events.update({a.event_id: a for a in auth}) auth_events.update({s.event_id: s for s in state}) @@ -726,12 +696,14 @@ def backfill(self, dest, room_id, limit, extremities): required_auth = set( a_id - for event in events + list(state_events.values()) + list(auth_events.values()) + for event in events + + list(state_events.values()) + + list(auth_events.values()) for a_id in event.auth_event_ids() ) - auth_events.update({ - e_id: event_map[e_id] for e_id in required_auth if e_id in event_map - }) + auth_events.update( + {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} + ) missing_auth = required_auth - set(auth_events) failed_to_fetch = set() @@ -750,27 +722,30 @@ def backfill(self, dest, room_id, limit, extremities): if missing_auth - failed_to_fetch: logger.info( "Fetching missing auth for backfill: %r", - missing_auth - failed_to_fetch + missing_auth - failed_to_fetch, ) - results = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background( - self.federation_client.get_pdu, - [dest], - event_id, - room_version=room_version, - outlier=True, - timeout=10000, - ) - for event_id in missing_auth - failed_to_fetch - ], - consumeErrors=True - )).addErrback(unwrapFirstError) + results = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + logcontext.run_in_background( + self.federation_client.get_pdu, + [dest], + event_id, + room_version=room_version, + outlier=True, + timeout=10000, + ) + for event_id in missing_auth - failed_to_fetch + ], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) auth_events.update({a.event_id: a for a in results if a}) required_auth.update( a_id - for event in results if event + for event in results + if event for a_id in event.auth_event_ids() ) missing_auth = required_auth - set(auth_events) @@ -802,15 +777,19 @@ def backfill(self, dest, room_id, limit, extremities): continue a.internal_metadata.outlier = True - ev_infos.append({ - "event": a, - "auth_events": { - (auth_events[a_id].type, auth_events[a_id].state_key): - auth_events[a_id] - for a_id in a.auth_event_ids() - if a_id in auth_events + ev_infos.append( + { + "event": a, + "auth_events": { + ( + auth_events[a_id].type, + auth_events[a_id].state_key, + ): auth_events[a_id] + for a_id in a.auth_event_ids() + if a_id in auth_events + }, } - }) + ) # Step 1b: persist the events in the chunk we fetched state for (i.e. # the backwards extremities) as non-outliers. @@ -818,23 +797,24 @@ def backfill(self, dest, room_id, limit, extremities): # For paranoia we ensure that these events are marked as # non-outliers ev = event_map[e_id] - assert(not ev.internal_metadata.is_outlier()) - - ev_infos.append({ - "event": ev, - "state": events_to_state[e_id], - "auth_events": { - (auth_events[a_id].type, auth_events[a_id].state_key): - auth_events[a_id] - for a_id in ev.auth_event_ids() - if a_id in auth_events + assert not ev.internal_metadata.is_outlier() + + ev_infos.append( + { + "event": ev, + "state": events_to_state[e_id], + "auth_events": { + ( + auth_events[a_id].type, + auth_events[a_id].state_key, + ): auth_events[a_id] + for a_id in ev.auth_event_ids() + if a_id in auth_events + }, } - }) + ) - yield self._handle_new_events( - dest, ev_infos, - backfilled=True, - ) + yield self._handle_new_events(dest, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -845,14 +825,12 @@ def backfill(self, dest, room_id, limit, extremities): # For paranoia we ensure that these events are marked as # non-outliers - assert(not event.internal_metadata.is_outlier()) + assert not event.internal_metadata.is_outlier() # We store these one at a time since each event depends on the # previous to work out the state. # TODO: We can probably do something more clever here. - yield self._handle_new_event( - dest, event, backfilled=True, - ) + yield self._handle_new_event(dest, event, backfilled=True) defer.returnValue(events) @@ -861,9 +839,7 @@ def maybe_backfill(self, room_id, current_depth): """Checks the database to see if we should backfill before paginating, and if so do. """ - extremities = yield self.store.get_oldest_events_with_depth_in_room( - room_id - ) + extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id) if not extremities: logger.debug("Not backfilling as no extremeties found.") @@ -895,31 +871,27 @@ def maybe_backfill(self, room_id, current_depth): # state *before* the event, ignoring the special casing certain event # types have. - forward_events = yield self.store.get_successor_events( - list(extremities), - ) + forward_events = yield self.store.get_successor_events(list(extremities)) extremities_events = yield self.store.get_events( - forward_events, - check_redacted=False, - get_prev_content=False, + forward_events, check_redacted=False, get_prev_content=False ) # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = yield filter_events_for_server( - self.store, self.server_name, list(extremities_events.values()), - redact=False, check_history_visibility_only=True, + self.store, + self.server_name, + list(extremities_events.values()), + redact=False, + check_history_visibility_only=True, ) if not filtered_extremities: defer.returnValue(False) # Check if we reached a point where we should start backfilling. - sorted_extremeties_tuple = sorted( - extremities.items(), - key=lambda e: -int(e[1]) - ) + sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1])) max_depth = sorted_extremeties_tuple[0][1] # We don't want to specify too many extremities as it causes the backfill @@ -928,8 +900,7 @@ def maybe_backfill(self, room_id, current_depth): if current_depth > max_depth: logger.debug( - "Not backfilling as we don't need to. %d < %d", - max_depth, current_depth, + "Not backfilling as we don't need to. %d < %d", max_depth, current_depth ) return @@ -954,8 +925,7 @@ def get_domains_from_state(state): joined_users = [ (state_key, int(event.depth)) for (e_type, state_key), event in iteritems(state) - if e_type == EventTypes.Member - and event.membership == Membership.JOIN + if e_type == EventTypes.Member and event.membership == Membership.JOIN ] joined_domains = {} @@ -975,8 +945,7 @@ def get_domains_from_state(state): curr_domains = get_domains_from_state(curr_state) likely_domains = [ - domain for domain, depth in curr_domains - if domain != self.server_name + domain for domain, depth in curr_domains if domain != self.server_name ] @defer.inlineCallbacks @@ -985,28 +954,20 @@ def try_backfill(domains): for dom in domains: try: yield self.backfill( - dom, room_id, - limit=100, - extremities=extremities, + dom, room_id, limit=100, extremities=extremities ) # If this succeeded then we probably already have the # appropriate stuff. # TODO: We can probably do something more intelligent here. defer.returnValue(True) except SynapseError as e: - logger.info( - "Failed to backfill from %s because %s", - dom, e, - ) + logger.info("Failed to backfill from %s because %s", dom, e) continue except CodeMessageException as e: if 400 <= e.code < 500: raise - logger.info( - "Failed to backfill from %s because %s", - dom, e, - ) + logger.info("Failed to backfill from %s because %s", dom, e) continue except NotRetryingDestination as e: logger.info(str(e)) @@ -1015,10 +976,7 @@ def try_backfill(domains): logger.info(e) continue except Exception as e: - logger.exception( - "Failed to backfill from %s because %s", - dom, e, - ) + logger.exception("Failed to backfill from %s because %s", dom, e) continue defer.returnValue(False) @@ -1039,10 +997,11 @@ def try_backfill(domains): resolve = logcontext.preserve_fn( self.state_handler.resolve_state_groups_for_events ) - states = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [resolve(room_id, [e]) for e in event_ids], - consumeErrors=True, - )) + states = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [resolve(room_id, [e]) for e in event_ids], consumeErrors=True + ) + ) # dict[str, dict[tuple, str]], a map from event_id to state map of # event_ids. @@ -1050,23 +1009,23 @@ def try_backfill(domains): state_map = yield self.store.get_events( [e_id for ids in itervalues(states) for e_id in itervalues(ids)], - get_prev_content=False + get_prev_content=False, ) states = { key: { k: state_map[e_id] for k, e_id in iteritems(state_dict) if e_id in state_map - } for key, state_dict in iteritems(states) + } + for key, state_dict in iteritems(states) } for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) - success = yield try_backfill([ - dom for dom, _ in likely_domains - if dom not in tried_domains - ]) + success = yield try_backfill( + [dom for dom, _ in likely_domains if dom not in tried_domains] + ) if success: defer.returnValue(True) @@ -1091,20 +1050,20 @@ def _sanity_check_event(self, ev): SynapseError if the event does not pass muster """ if len(ev.prev_event_ids()) > 20: - logger.warn("Rejecting event %s which has %i prev_events", - ev.event_id, len(ev.prev_event_ids())) - raise SynapseError( - http_client.BAD_REQUEST, - "Too many prev_events", + logger.warn( + "Rejecting event %s which has %i prev_events", + ev.event_id, + len(ev.prev_event_ids()), ) + raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") if len(ev.auth_event_ids()) > 10: - logger.warn("Rejecting event %s which has %i auth_events", - ev.event_id, len(ev.auth_event_ids())) - raise SynapseError( - http_client.BAD_REQUEST, - "Too many auth_events", + logger.warn( + "Rejecting event %s which has %i auth_events", + ev.event_id, + len(ev.auth_event_ids()), ) + raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events") @defer.inlineCallbacks def send_invite(self, target_host, event): @@ -1116,7 +1075,7 @@ def send_invite(self, target_host, event): destination=target_host, room_id=event.room_id, event_id=event.event_id, - pdu=event + pdu=event, ) defer.returnValue(pdu) @@ -1125,8 +1084,7 @@ def send_invite(self, target_host, event): def on_event_auth(self, event_id): event = yield self.store.get_event(event_id) auth = yield self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], - include_given=True + [auth_id for auth_id in event.auth_event_ids()], include_given=True ) defer.returnValue([e for e in auth]) @@ -1152,15 +1110,13 @@ def do_invite_join(self, target_hosts, room_id, joinee, content): joinee, "join", content, - params={ - "ver": KNOWN_ROOM_VERSIONS, - }, + params={"ver": KNOWN_ROOM_VERSIONS}, ) # This shouldn't happen, because the RoomMemberHandler has a # linearizer lock which only allows one operation per user per room # at a time - so this is just paranoia. - assert (room_id not in self.room_queues) + assert room_id not in self.room_queues self.room_queues[room_id] = [] @@ -1177,7 +1133,7 @@ def do_invite_join(self, target_hosts, room_id, joinee, content): except ValueError: pass ret = yield self.federation_client.send_join( - target_hosts, event, event_format_version, + target_hosts, event, event_format_version ) origin = ret["origin"] @@ -1196,17 +1152,13 @@ def do_invite_join(self, target_hosts, room_id, joinee, content): try: yield self.store.store_room( - room_id=room_id, - room_creator_user_id="", - is_public=False + room_id=room_id, room_creator_user_id="", is_public=False ) except Exception: # FIXME pass - yield self._persist_auth_tree( - origin, auth_chain, state, event - ) + yield self._persist_auth_tree(origin, auth_chain, state, event) logger.debug("Finished joining %s to %s", joinee, room_id) finally: @@ -1233,14 +1185,18 @@ def _handle_queued_pdus(self, room_queue): """ for p, origin in room_queue: try: - logger.info("Processing queued PDU %s which was received " - "while we were joining %s", p.event_id, p.room_id) + logger.info( + "Processing queued PDU %s which was received " + "while we were joining %s", + p.event_id, + p.room_id, + ) with logcontext.nested_logging_context(p.event_id): yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warn( - "Error handling queued PDU %s from %s: %s", - p.event_id, origin, e) + "Error handling queued PDU %s from %s: %s", p.event_id, origin, e + ) @defer.inlineCallbacks @log_function @@ -1261,30 +1217,30 @@ def on_make_join_request(self, room_id, user_id): "room_id": room_id, "sender": user_id, "state_key": user_id, - } + }, ) try: event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) except AuthError as e: logger.warn("Failed to create join %r because %s", event, e) raise e event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info("Creation of join %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` yield self.auth.check_from_context( - room_version, event, context, do_sig_check=False, + room_version, event, context, do_sig_check=False ) defer.returnValue(event) @@ -1319,17 +1275,15 @@ def on_send_join_request(self, origin, pdu): # would introduce the danger of backwards-compatibility problems. event.internal_metadata.send_on_behalf_of = origin - context = yield self._handle_new_event( - origin, event - ) + context = yield self._handle_new_event(origin, event) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info("Sending of join %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) logger.debug( @@ -1350,10 +1304,7 @@ def on_send_join_request(self, origin, pdu): state = yield self.store.get_events(list(prev_state_ids.values())) - defer.returnValue({ - "state": list(state.values()), - "auth_chain": auth_chain, - }) + defer.returnValue({"state": list(state.values()), "auth_chain": auth_chain}) @defer.inlineCallbacks def on_invite_request(self, origin, pdu): @@ -1374,7 +1325,7 @@ def on_invite_request(self, origin, pdu): raise SynapseError(403, "This server does not accept room invites") if not self.spam_checker.user_may_invite( - event.sender, event.state_key, event.room_id, + event.sender, event.state_key, event.room_id ): raise SynapseError( 403, "This user is not permitted to send invites to this server/user" @@ -1386,26 +1337,23 @@ def on_invite_request(self, origin, pdu): sender_domain = get_domain_from_id(event.sender) if sender_domain != origin: - raise SynapseError(400, "The invite event was not from the server sending it") + raise SynapseError( + 400, "The invite event was not from the server sending it" + ) if not self.is_mine_id(event.state_key): raise SynapseError(400, "The invite event must be for this server") # block any attempts to invite the server notices mxid if event.state_key == self._server_notices_mxid: - raise SynapseError( - http_client.FORBIDDEN, - "Cannot invite this user", - ) + raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True event.signatures.update( compute_event_signature( - event.get_pdu_json(), - self.hs.hostname, - self.hs.config.signing_key[0] + event.get_pdu_json(), self.hs.hostname, self.hs.config.signing_key[0] ) ) @@ -1417,10 +1365,7 @@ def on_invite_request(self, origin, pdu): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id): origin, event, event_format_version = yield self._make_and_verify_event( - target_hosts, - room_id, - user_id, - "leave" + target_hosts, room_id, user_id, "leave" ) # Mark as outlier as we don't have any state for this event; we're not # even in the room. @@ -1435,10 +1380,7 @@ def do_remotely_reject_invite(self, target_hosts, room_id, user_id): except ValueError: pass - yield self.federation_client.send_leave( - target_hosts, - event - ) + yield self.federation_client.send_leave(target_hosts, event) context = yield self.state_handler.compute_event_context(event) yield self.persist_events_and_notify([(event, context)]) @@ -1446,25 +1388,21 @@ def do_remotely_reject_invite(self, target_hosts, room_id, user_id): defer.returnValue(event) @defer.inlineCallbacks - def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, - content={}, params=None): + def _make_and_verify_event( + self, target_hosts, room_id, user_id, membership, content={}, params=None + ): origin, event, format_ver = yield self.federation_client.make_membership_event( - target_hosts, - room_id, - user_id, - membership, - content, - params=params, + target_hosts, room_id, user_id, membership, content, params=params ) logger.debug("Got response to make_%s: %s", membership, event) # We should assert some things. # FIXME: Do this in a nicer way - assert(event.type == EventTypes.Member) - assert(event.user_id == user_id) - assert(event.state_key == user_id) - assert(event.room_id == room_id) + assert event.type == EventTypes.Member + assert event.user_id == user_id + assert event.state_key == user_id + assert event.room_id == room_id defer.returnValue((origin, event, format_ver)) @defer.inlineCallbacks @@ -1483,27 +1421,27 @@ def on_make_leave_request(self, room_id, user_id): "room_id": room_id, "sender": user_id, "state_key": user_id, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.warning("Creation of leave %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` yield self.auth.check_from_context( - room_version, event, context, do_sig_check=False, + room_version, event, context, do_sig_check=False ) except AuthError as e: logger.warn("Failed to create new leave %r because %s", event, e) @@ -1525,17 +1463,15 @@ def on_send_leave_request(self, origin, pdu): event.internal_metadata.outlier = False - context = yield self._handle_new_event( - origin, event - ) + context = yield self._handle_new_event(origin, event) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info("Sending of leave %s forbidden by third-party rules", event) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) logger.debug( @@ -1552,18 +1488,14 @@ def get_state_for_pdu(self, room_id, event_id): """ event = yield self.store.get_event( - event_id, allow_none=False, check_room_id=room_id, + event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups( - room_id, [event_id] - ) + state_groups = yield self.store.get_state_groups(room_id, [event_id]) if state_groups: _, state = list(iteritems(state_groups)).pop() - results = { - (e.type, e.state_key): e for e in state - } + results = {(e.type, e.state_key): e for e in state} if event.is_state(): # Get previous state @@ -1585,12 +1517,10 @@ def get_state_ids_for_pdu(self, room_id, event_id): """Returns the state at the event. i.e. not including said event. """ event = yield self.store.get_event( - event_id, allow_none=False, check_room_id=room_id, + event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups_ids( - room_id, [event_id] - ) + state_groups = yield self.store.get_state_groups_ids(room_id, [event_id]) if state_groups: _, state = list(state_groups.items()).pop() @@ -1616,11 +1546,7 @@ def on_backfill_request(self, origin, room_id, pdu_list, limit): if not in_room: raise AuthError(403, "Host not in room.") - events = yield self.store.get_backfill_events( - room_id, - pdu_list, - limit - ) + events = yield self.store.get_backfill_events(room_id, pdu_list, limit) events = yield filter_events_for_server(self.store, origin, events) @@ -1644,22 +1570,15 @@ def get_persisted_pdu(self, origin, event_id): AuthError if the server is not currently in the room """ event = yield self.store.get_event( - event_id, - allow_none=True, - allow_rejected=True, + event_id, allow_none=True, allow_rejected=True ) if event: - in_room = yield self.auth.check_host_in_room( - event.room_id, - origin - ) + in_room = yield self.auth.check_host_in_room(event.room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - events = yield filter_events_for_server( - self.store, origin, [event], - ) + events = yield filter_events_for_server(self.store, origin, [event]) event = events[0] defer.returnValue(event) else: @@ -1669,13 +1588,11 @@ def get_min_depth_for_context(self, context): return self.store.get_min_depth(context) @defer.inlineCallbacks - def _handle_new_event(self, origin, event, state=None, auth_events=None, - backfilled=False): + def _handle_new_event( + self, origin, event, state=None, auth_events=None, backfilled=False + ): context = yield self._prep_event( - origin, event, - state=state, - auth_events=auth_events, - backfilled=backfilled, + origin, event, state=state, auth_events=auth_events, backfilled=backfilled ) # reraise does not allow inlineCallbacks to preserve the stacktrace, so we @@ -1688,15 +1605,13 @@ def _handle_new_event(self, origin, event, state=None, auth_events=None, ) yield self.persist_events_and_notify( - [(event, context)], - backfilled=backfilled, + [(event, context)], backfilled=backfilled ) success = True finally: if not success: logcontext.run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, + self.store.remove_push_actions_from_staging, event.event_id ) defer.returnValue(context) @@ -1724,12 +1639,15 @@ def prep(ev_info): ) defer.returnValue(res) - contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background(prep, ev_info) - for ev_info in event_infos - ], consumeErrors=True, - )) + contexts = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + logcontext.run_in_background(prep, ev_info) + for ev_info in event_infos + ], + consumeErrors=True, + ) + ) yield self.persist_events_and_notify( [ @@ -1764,8 +1682,7 @@ def _persist_auth_tree(self, origin, auth_events, state, event): events_to_context[e.event_id] = ctx event_map = { - e.event_id: e - for e in itertools.chain(auth_events, state, [event]) + e.event_id: e for e in itertools.chain(auth_events, state, [event]) } create_event = None @@ -1780,7 +1697,7 @@ def _persist_auth_tree(self, origin, auth_events, state, event): raise SynapseError(400, "No create event in state") room_version = create_event.content.get( - "room_version", RoomVersions.V1.identifier, + "room_version", RoomVersions.V1.identifier ) missing_auth_events = set() @@ -1791,11 +1708,7 @@ def _persist_auth_tree(self, origin, auth_events, state, event): for e_id in missing_auth_events: m_ev = yield self.federation_client.get_pdu( - [origin], - e_id, - room_version=room_version, - outlier=True, - timeout=10000, + [origin], e_id, room_version=room_version, outlier=True, timeout=10000 ) if m_ev and m_ev.event_id == e_id: event_map[e_id] = m_ev @@ -1820,10 +1733,7 @@ def _persist_auth_tree(self, origin, auth_events, state, event): # cause SynapseErrors in auth.check. We don't want to give up # the attempt to federate altogether in such cases. - logger.warn( - "Rejecting %s because %s", - e.event_id, err.msg - ) + logger.warn("Rejecting %s because %s", e.event_id, err.msg) if e == event: raise @@ -1833,16 +1743,14 @@ def _persist_auth_tree(self, origin, auth_events, state, event): [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ], + ] ) new_event_context = yield self.state_handler.compute_event_context( event, old_state=state ) - yield self.persist_events_and_notify( - [(event, new_event_context)], - ) + yield self.persist_events_and_notify([(event, new_event_context)]) @defer.inlineCallbacks def _prep_event(self, origin, event, state, auth_events, backfilled): @@ -1858,40 +1766,30 @@ def _prep_event(self, origin, event, state, auth_events, backfilled): Returns: Deferred, which resolves to synapse.events.snapshot.EventContext """ - context = yield self.state_handler.compute_event_context( - event, old_state=state, - ) + context = yield self.state_handler.compute_event_context(event, old_state=state) if not auth_events: prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in auth_events.values() - } + auth_events = {(e.type, e.state_key): e for e in auth_events.values()} # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. if event.type == EventTypes.Member and not event.auth_event_ids(): if len(event.prev_event_ids()) == 1 and event.depth < 5: c = yield self.store.get_event( - event.prev_event_ids()[0], - allow_none=True, + event.prev_event_ids()[0], allow_none=True ) if c and c.type == EventTypes.Create: auth_events[(c.type, c.state_key)] = c try: - yield self.do_auth( - origin, event, context, auth_events=auth_events - ) + yield self.do_auth(origin, event, context, auth_events=auth_events) except AuthError as e: - logger.warn( - "[%s %s] Rejecting: %s", - event.room_id, event.event_id, e.msg - ) + logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg) context.rejected = RejectedReason.AUTH_ERROR @@ -1922,9 +1820,7 @@ def _check_for_soft_fail(self, event, state, backfilled): # "soft-fail" the event. do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier() if do_soft_fail_check: - extrem_ids = yield self.store.get_latest_event_ids_in_room( - event.room_id, - ) + extrem_ids = yield self.store.get_latest_event_ids_in_room(event.room_id) extrem_ids = set(extrem_ids) prev_event_ids = set(event.prev_event_ids()) @@ -1952,31 +1848,31 @@ def _check_for_soft_fail(self, event, state, backfilled): # like bans, especially with state res v2. state_sets = yield self.store.get_state_groups( - event.room_id, extrem_ids, + event.room_id, extrem_ids ) state_sets = list(state_sets.values()) state_sets.append(state) current_state_ids = yield self.state_handler.resolve_events( - room_version, state_sets, event, + room_version, state_sets, event ) current_state_ids = { k: e.event_id for k, e in iteritems(current_state_ids) } else: current_state_ids = yield self.state_handler.get_current_state_ids( - event.room_id, latest_event_ids=extrem_ids, + event.room_id, latest_event_ids=extrem_ids ) logger.debug( "Doing soft-fail check for %s: state %s", - event.event_id, current_state_ids, + event.event_id, + current_state_ids, ) # Now check if event pass auth against said current state auth_types = auth_types_for_event(event) current_state_ids = [ - e for k, e in iteritems(current_state_ids) - if k in auth_types + e for k, e in iteritems(current_state_ids) if k in auth_types ] current_auth_events = yield self.store.get_events(current_state_ids) @@ -1987,19 +1883,14 @@ def _check_for_soft_fail(self, event, state, backfilled): try: self.auth.check(room_version, event, auth_events=current_auth_events) except AuthError as e: - logger.warn( - "Soft-failing %r because %s", - event, e, - ) + logger.warn("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True @defer.inlineCallbacks - def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects, - missing): - in_room = yield self.auth.check_host_in_room( - room_id, - origin - ) + def on_query_auth( + self, origin, event_id, room_id, remote_auth_chain, rejects, missing + ): + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2017,28 +1908,23 @@ def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects, # Now get the current auth_chain for the event. local_auth_chain = yield self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], - include_given=True + [auth_id for auth_id in event.auth_event_ids()], include_given=True ) # TODO: Check if we would now reject event_id. If so we need to tell # everyone. - ret = yield self.construct_auth_difference( - local_auth_chain, remote_auth_chain - ) + ret = yield self.construct_auth_difference(local_auth_chain, remote_auth_chain) logger.debug("on_query_auth returning: %s", ret) defer.returnValue(ret) @defer.inlineCallbacks - def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit): - in_room = yield self.auth.check_host_in_room( - room_id, - origin - ) + def on_get_missing_events( + self, origin, room_id, earliest_events, latest_events, limit + ): + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2052,7 +1938,7 @@ def on_get_missing_events(self, origin, room_id, earliest_events, ) missing_events = yield filter_events_for_server( - self.store, origin, missing_events, + self.store, origin, missing_events ) defer.returnValue(missing_events) @@ -2140,25 +2026,17 @@ def _update_auth_events_and_context_for_auth( if missing_auth: # TODO: can we use store.have_seen_events here instead? - have_events = yield self.store.get_seen_events_with_rejections( - missing_auth - ) + have_events = yield self.store.get_seen_events_with_rejections(missing_auth) logger.debug("Got events %s from store", have_events) missing_auth.difference_update(have_events.keys()) else: have_events = {} - have_events.update({ - e.event_id: "" - for e in auth_events.values() - }) + have_events.update({e.event_id: "" for e in auth_events.values()}) if missing_auth: # If we don't have all the auth events, we need to get them. - logger.info( - "auth_events contains unknown events: %s", - missing_auth, - ) + logger.info("auth_events contains unknown events: %s", missing_auth) try: try: remote_auth_chain = yield self.federation_client.get_event_auth( @@ -2184,18 +2062,16 @@ def _update_auth_events_and_context_for_auth( try: auth_ids = e.auth_event_ids() auth = { - (e.type, e.state_key): e for e in remote_auth_chain + (e.type, e.state_key): e + for e in remote_auth_chain if e.event_id in auth_ids or e.type == EventTypes.Create } e.internal_metadata.outlier = True logger.debug( - "do_auth %s missing_auth: %s", - event.event_id, e.event_id - ) - yield self._handle_new_event( - origin, e, auth_events=auth + "do_auth %s missing_auth: %s", event.event_id, e.event_id ) + yield self._handle_new_event(origin, e, auth_events=auth) if e.event_id in event_auth_events: auth_events[(e.type, e.state_key)] = e @@ -2231,35 +2107,36 @@ def _update_auth_events_and_context_for_auth( room_version = yield self.store.get_room_version(event.room_id) different_events = yield logcontext.make_deferred_yieldable( - defer.gatherResults([ - logcontext.run_in_background( - self.store.get_event, - d, - allow_none=True, - allow_rejected=False, - ) - for d in different_auth - if d in have_events and not have_events[d] - ], consumeErrors=True) + defer.gatherResults( + [ + logcontext.run_in_background( + self.store.get_event, d, allow_none=True, allow_rejected=False + ) + for d in different_auth + if d in have_events and not have_events[d] + ], + consumeErrors=True, + ) ).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) remote_view = dict(auth_events) - remote_view.update({ - (d.type, d.state_key): d for d in different_events if d - }) + remote_view.update( + {(d.type, d.state_key): d for d in different_events if d} + ) new_state = yield self.state_handler.resolve_events( room_version, [list(local_view.values()), list(remote_view.values())], - event + event, ) logger.info( "After state res: updating auth_events with new state %s", { - (d.type, d.state_key): d.event_id for d in new_state.values() + (d.type, d.state_key): d.event_id + for d in new_state.values() if auth_events.get((d.type, d.state_key)) != d }, ) @@ -2271,7 +2148,7 @@ def _update_auth_events_and_context_for_auth( ) yield self._update_context_for_auth_events( - event, context, auth_events, event_key, + event, context, auth_events, event_key ) if not different_auth: @@ -2305,21 +2182,14 @@ def _update_auth_events_and_context_for_auth( prev_state_ids = yield context.get_prev_state_ids(self.store) # 1. Get what we think is the auth chain. - auth_ids = yield self.auth.compute_auth_events( - event, prev_state_ids - ) - local_auth_chain = yield self.store.get_auth_chain( - auth_ids, include_given=True - ) + auth_ids = yield self.auth.compute_auth_events(event, prev_state_ids) + local_auth_chain = yield self.store.get_auth_chain(auth_ids, include_given=True) try: # 2. Get remote difference. try: result = yield self.federation_client.query_auth( - origin, - event.room_id, - event.event_id, - local_auth_chain, + origin, event.room_id, event.event_id, local_auth_chain ) except RequestSendFailed as e: # The other side isn't around or doesn't implement the @@ -2344,19 +2214,15 @@ def _update_auth_events_and_context_for_auth( auth = { (e.type, e.state_key): e for e in result["auth_chain"] - if e.event_id in auth_ids - or event.type == EventTypes.Create + if e.event_id in auth_ids or event.type == EventTypes.Create } ev.internal_metadata.outlier = True logger.debug( - "do_auth %s different_auth: %s", - event.event_id, e.event_id + "do_auth %s different_auth: %s", event.event_id, e.event_id ) - yield self._handle_new_event( - origin, ev, auth_events=auth - ) + yield self._handle_new_event(origin, ev, auth_events=auth) if ev.event_id in event_auth_events: auth_events[(ev.type, ev.state_key)] = ev @@ -2371,12 +2237,11 @@ def _update_auth_events_and_context_for_auth( # TODO. yield self._update_context_for_auth_events( - event, context, auth_events, event_key, + event, context, auth_events, event_key ) @defer.inlineCallbacks - def _update_context_for_auth_events(self, event, context, auth_events, - event_key): + def _update_context_for_auth_events(self, event, context, auth_events, event_key): """Update the state_ids in an event context after auth event resolution, storing the changes as a new state group. @@ -2393,8 +2258,7 @@ def _update_context_for_auth_events(self, event, context, auth_events, this will not be included in the current_state in the context. """ state_updates = { - k: a.event_id for k, a in iteritems(auth_events) - if k != event_key + k: a.event_id for k, a in iteritems(auth_events) if k != event_key } current_state_ids = yield context.get_current_state_ids(self.store) current_state_ids = dict(current_state_ids) @@ -2404,9 +2268,7 @@ def _update_context_for_auth_events(self, event, context, auth_events, prev_state_ids = yield context.get_prev_state_ids(self.store) prev_state_ids = dict(prev_state_ids) - prev_state_ids.update({ - k: a.event_id for k, a in iteritems(auth_events) - }) + prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) # create a new state group as a delta from the existing one. prev_group = context.state_group @@ -2555,30 +2417,23 @@ def get_next(it, opt=None): logger.debug("construct_auth_difference returning") - defer.returnValue({ - "auth_chain": local_auth, - "rejects": { - e.event_id: { - "reason": reason_map[e.event_id], - "proof": None, - } - for e in base_remote_rejected - }, - "missing": [e.event_id for e in missing_locals], - }) + defer.returnValue( + { + "auth_chain": local_auth, + "rejects": { + e.event_id: {"reason": reason_map[e.event_id], "proof": None} + for e in base_remote_rejected + }, + "missing": [e.event_id for e in missing_locals], + } + ) @defer.inlineCallbacks @log_function def exchange_third_party_invite( - self, - sender_user_id, - target_user_id, - room_id, - signed, + self, sender_user_id, target_user_id, room_id, signed ): - third_party_invite = { - "signed": signed, - } + third_party_invite = {"signed": signed} event_dict = { "type": EventTypes.Member, @@ -2601,7 +2456,7 @@ def exchange_third_party_invite( ) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.info( @@ -2609,7 +2464,7 @@ def exchange_third_party_invite( event, ) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) event, context = yield self.add_display_name_to_third_party_invite( @@ -2634,9 +2489,7 @@ def exchange_third_party_invite( else: destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) yield self.federation_client.forward_third_party_invite( - destinations, - room_id, - event_dict, + destinations, room_id, event_dict ) @defer.inlineCallbacks @@ -2657,19 +2510,18 @@ def on_exchange_third_party_invite_request(self, origin, room_id, event_dict): builder = self.event_builder_factory.new(room_version, event_dict) event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: logger.warning( - "Exchange of threepid invite %s forbidden by third-party rules", - event, + "Exchange of threepid invite %s forbidden by third-party rules", event ) raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) event, context = yield self.add_display_name_to_third_party_invite( @@ -2691,11 +2543,12 @@ def on_exchange_third_party_invite_request(self, origin, room_id, event_dict): yield member_handler.send_membership_event(None, event, context) @defer.inlineCallbacks - def add_display_name_to_third_party_invite(self, room_version, event_dict, - event, context): + def add_display_name_to_third_party_invite( + self, room_version, event_dict, event, context + ): key = ( EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"] + event.content["third_party_invite"]["signed"]["token"], ) original_invite = None prev_state_ids = yield context.get_prev_state_ids(self.store) @@ -2709,8 +2562,7 @@ def add_display_name_to_third_party_invite(self, room_version, event_dict, event_dict["content"]["third_party_invite"]["display_name"] = display_name else: logger.info( - "Could not find invite event for third_party_invite: %r", - event_dict + "Could not find invite event for third_party_invite: %r", event_dict ) # We don't discard here as this is not the appropriate place to do # auth checks. If we need the invite and don't have it then the @@ -2719,7 +2571,7 @@ def add_display_name_to_third_party_invite(self, room_version, event_dict, builder = self.event_builder_factory.new(room_version, event_dict) EventValidator().validate_builder(builder) event, context = yield self.event_creation_handler.create_new_client_event( - builder=builder, + builder=builder ) EventValidator().validate_new(event) defer.returnValue((event, context)) @@ -2743,9 +2595,7 @@ def _check_signature(self, event, context): token = signed["token"] prev_state_ids = yield context.get_prev_state_ids(self.store) - invite_event_id = prev_state_ids.get( - (EventTypes.ThirdPartyInvite, token,) - ) + invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) invite_event = None if invite_event_id: @@ -2769,38 +2619,42 @@ def _check_signature(self, event, context): logger.debug( "Attempting to verify sig with key %s from %r " "against pubkey %r", - key_name, server, public_key_object, + key_name, + server, + public_key_object, ) try: public_key = public_key_object["public_key"] verify_key = decode_verify_key_bytes( - key_name, - decode_base64(public_key) + key_name, decode_base64(public_key) ) verify_signed_json(signed, server, verify_key) logger.debug( "Successfully verified sig with key %s from %r " "against pubkey %r", - key_name, server, public_key_object, + key_name, + server, + public_key_object, ) except Exception: logger.info( "Failed to verify sig with key %s from %r " "against pubkey %r", - key_name, server, public_key_object, + key_name, + server, + public_key_object, ) raise try: if "key_validity_url" in public_key_object: yield self._check_key_revocation( - public_key, - public_key_object["key_validity_url"] + public_key, public_key_object["key_validity_url"] ) except Exception: logger.info( "Failed to query key_validity_url %s", - public_key_object["key_validity_url"] + public_key_object["key_validity_url"], ) raise return @@ -2823,15 +2677,9 @@ def _check_key_revocation(self, public_key, url): for revocation. """ try: - response = yield self.http_client.get_json( - url, - {"public_key": public_key} - ) + response = yield self.http_client.get_json(url, {"public_key": public_key}) except Exception: - raise SynapseError( - 502, - "Third party certificate could not be checked" - ) + raise SynapseError(502, "Third party certificate could not be checked") if "valid" not in response or not response["valid"]: raise AuthError(403, "Third party certificate was invalid") @@ -2852,12 +2700,11 @@ def persist_events_and_notify(self, event_and_contexts, backfilled=False): yield self._send_events_to_master( store=self.store, event_and_contexts=event_and_contexts, - backfilled=backfilled + backfilled=backfilled, ) else: max_stream_id = yield self.store.persist_events( - event_and_contexts, - backfilled=backfilled, + event_and_contexts, backfilled=backfilled ) if not backfilled: # Never notify for backfilled events @@ -2891,13 +2738,10 @@ def _notify_persisted_event(self, event, max_stream_id): event_stream_id = event.internal_metadata.stream_ordering self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users + event, event_stream_id, max_stream_id, extra_users=extra_users ) - return self.pusher_pool.on_new_notifications( - event_stream_id, max_stream_id, - ) + return self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _clean_room_for_join(self, room_id): """Called to clean up any data in DB for a given room, ready for the @@ -2916,9 +2760,7 @@ def user_joined_room(self, user, room_id): """ if self.config.worker_app: return self._notify_user_membership_change( - room_id=room_id, - user_id=user.to_string(), - change="joined", + room_id=room_id, user_id=user.to_string(), change="joined" ) else: return user_joined_room(self.distributor, user, room_id) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index f60ace02e8af..7da63bb64361 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -30,6 +30,7 @@ def _create_rerouter(func_name): """Returns a function that looks at the group id and calls the function on federation or the local group server if the group is local """ + def f(self, group_id, *args, **kwargs): if self.is_mine_id(group_id): return getattr(self.groups_server_handler, func_name)( @@ -58,6 +59,7 @@ def request_failed_errback(failure): d.addErrback(http_response_errback) d.addErrback(request_failed_errback) return d + return f @@ -125,7 +127,7 @@ def get_group_summary(self, group_id, requester_user_id): ) else: res = yield self.transport_client.get_group_summary( - get_domain_from_id(group_id), group_id, requester_user_id, + get_domain_from_id(group_id), group_id, requester_user_id ) group_server_name = get_domain_from_id(group_id) @@ -182,7 +184,7 @@ def create_group(self, group_id, user_id, content): content["user_profile"] = yield self.profile_handler.get_profile(user_id) res = yield self.transport_client.create_group( - get_domain_from_id(group_id), group_id, user_id, content, + get_domain_from_id(group_id), group_id, user_id, content ) remote_attestation = res["attestation"] @@ -195,16 +197,15 @@ def create_group(self, group_id, user_id, content): is_publicised = content.get("publicise", False) token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="join", is_admin=True, local_attestation=local_attestation, remote_attestation=remote_attestation, is_publicised=is_publicised, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) defer.returnValue(res) @@ -221,7 +222,7 @@ def get_users_in_group(self, group_id, requester_user_id): group_server_name = get_domain_from_id(group_id) res = yield self.transport_client.get_users_in_group( - get_domain_from_id(group_id), group_id, requester_user_id, + get_domain_from_id(group_id), group_id, requester_user_id ) chunk = res["chunk"] @@ -250,9 +251,7 @@ def join_group(self, group_id, user_id, content): """Request to join a group """ if self.is_mine_id(group_id): - yield self.groups_server_handler.join_group( - group_id, user_id, content - ) + yield self.groups_server_handler.join_group(group_id, user_id, content) local_attestation = None remote_attestation = None else: @@ -260,7 +259,7 @@ def join_group(self, group_id, user_id, content): content["attestation"] = local_attestation res = yield self.transport_client.join_group( - get_domain_from_id(group_id), group_id, user_id, content, + get_domain_from_id(group_id), group_id, user_id, content ) remote_attestation = res["attestation"] @@ -276,16 +275,15 @@ def join_group(self, group_id, user_id, content): is_publicised = content.get("publicise", False) token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="join", is_admin=False, local_attestation=local_attestation, remote_attestation=remote_attestation, is_publicised=is_publicised, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) defer.returnValue({}) @@ -294,9 +292,7 @@ def accept_invite(self, group_id, user_id, content): """Accept an invite to a group """ if self.is_mine_id(group_id): - yield self.groups_server_handler.accept_invite( - group_id, user_id, content - ) + yield self.groups_server_handler.accept_invite(group_id, user_id, content) local_attestation = None remote_attestation = None else: @@ -304,7 +300,7 @@ def accept_invite(self, group_id, user_id, content): content["attestation"] = local_attestation res = yield self.transport_client.accept_group_invite( - get_domain_from_id(group_id), group_id, user_id, content, + get_domain_from_id(group_id), group_id, user_id, content ) remote_attestation = res["attestation"] @@ -320,16 +316,15 @@ def accept_invite(self, group_id, user_id, content): is_publicised = content.get("publicise", False) token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="join", is_admin=False, local_attestation=local_attestation, remote_attestation=remote_attestation, is_publicised=is_publicised, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) defer.returnValue({}) @@ -337,17 +332,17 @@ def accept_invite(self, group_id, user_id, content): def invite(self, group_id, user_id, requester_user_id, config): """Invite a user to a group """ - content = { - "requester_user_id": requester_user_id, - "config": config, - } + content = {"requester_user_id": requester_user_id, "config": config} if self.is_mine_id(group_id): res = yield self.groups_server_handler.invite_to_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) else: res = yield self.transport_client.invite_to_group( - get_domain_from_id(group_id), group_id, user_id, requester_user_id, + get_domain_from_id(group_id), + group_id, + user_id, + requester_user_id, content, ) @@ -370,13 +365,12 @@ def on_invite(self, group_id, user_id, content): local_profile["avatar_url"] = content["profile"]["avatar_url"] token = yield self.store.register_user_group_membership( - group_id, user_id, + group_id, + user_id, membership="invite", content={"profile": local_profile, "inviter": content["inviter"]}, ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], - ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) try: user_profile = yield self.profile_handler.get_profile(user_id) except Exception as e: @@ -391,25 +385,25 @@ def remove_user_from_group(self, group_id, user_id, requester_user_id, content): """ if user_id == requester_user_id: token = yield self.store.register_user_group_membership( - group_id, user_id, - membership="leave", - ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], + group_id, user_id, membership="leave" ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) # TODO: Should probably remember that we tried to leave so that we can # retry if the group server is currently down. if self.is_mine_id(group_id): res = yield self.groups_server_handler.remove_user_from_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) else: content["requester_user_id"] = requester_user_id res = yield self.transport_client.remove_user_from_group( - get_domain_from_id(group_id), group_id, requester_user_id, - user_id, content, + get_domain_from_id(group_id), + group_id, + requester_user_id, + user_id, + content, ) defer.returnValue(res) @@ -420,12 +414,9 @@ def user_removed_from_group(self, group_id, user_id, content): """ # TODO: Check if user in group token = yield self.store.register_user_group_membership( - group_id, user_id, - membership="leave", - ) - self.notifier.on_new_event( - "groups_key", token, users=[user_id], + group_id, user_id, membership="leave" ) + self.notifier.on_new_event("groups_key", token, users=[user_id]) @defer.inlineCallbacks def get_joined_groups(self, user_id): @@ -445,7 +436,7 @@ def get_publicised_groups_for_user(self, user_id): defer.returnValue({"groups": result}) else: bulk_result = yield self.transport_client.bulk_get_publicised_groups( - get_domain_from_id(user_id), [user_id], + get_domain_from_id(user_id), [user_id] ) result = bulk_result.get("users", {}).get(user_id) # TODO: Verify attestations @@ -460,9 +451,7 @@ def bulk_get_publicised_groups(self, user_ids, proxy=True): if self.hs.is_mine_id(user_id): local_users.add(user_id) else: - destinations.setdefault( - get_domain_from_id(user_id), set() - ).add(user_id) + destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id) if not proxy and destinations: raise SynapseError(400, "Some user_ids are not local") @@ -472,16 +461,14 @@ def bulk_get_publicised_groups(self, user_ids, proxy=True): for destination, dest_user_ids in iteritems(destinations): try: r = yield self.transport_client.bulk_get_publicised_groups( - destination, list(dest_user_ids), + destination, list(dest_user_ids) ) results.update(r["users"]) except Exception: failed_results.extend(dest_user_ids) for uid in local_users: - results[uid] = yield self.store.get_publicised_groups_for_user( - uid - ) + results[uid] = yield self.store.get_publicised_groups_for_user(uid) # Check AS associated groups for this user - this depends on the # RegExps in the AS registration file (under `users`) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 04caf657934d..c82b1933f216 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -36,7 +36,6 @@ class IdentityHandler(BaseHandler): - def __init__(self, hs): super(IdentityHandler, self).__init__(hs) @@ -64,40 +63,38 @@ def _should_trust_id_server(self, id_server): @defer.inlineCallbacks def threepid_from_creds(self, creds): - if 'id_server' in creds: - id_server = creds['id_server'] - elif 'idServer' in creds: - id_server = creds['idServer'] + if "id_server" in creds: + id_server = creds["id_server"] + elif "idServer" in creds: + id_server = creds["idServer"] else: raise SynapseError(400, "No id_server in creds") - if 'client_secret' in creds: - client_secret = creds['client_secret'] - elif 'clientSecret' in creds: - client_secret = creds['clientSecret'] + if "client_secret" in creds: + client_secret = creds["client_secret"] + elif "clientSecret" in creds: + client_secret = creds["clientSecret"] else: raise SynapseError(400, "No client_secret in creds") if not self._should_trust_id_server(id_server): logger.warn( - '%s is not a trusted ID server: rejecting 3pid ' + - 'credentials', id_server + "%s is not a trusted ID server: rejecting 3pid " + "credentials", + id_server, ) defer.returnValue(None) try: data = yield self.http_client.get_json( - "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/3pid/getValidated3pid" - ), - {'sid': creds['sid'], 'client_secret': client_secret} + "https://%s%s" + % (id_server, "/_matrix/identity/api/v1/3pid/getValidated3pid"), + {"sid": creds["sid"], "client_secret": client_secret}, ) except HttpResponseException as e: logger.info("getValidated3pid failed with Matrix error: %r", e) raise e.to_synapse_error() - if 'medium' in data: + if "medium" in data: defer.returnValue(data) defer.returnValue(None) @@ -106,30 +103,24 @@ def bind_threepid(self, creds, mxid): logger.debug("binding threepid %r to %s", creds, mxid) data = None - if 'id_server' in creds: - id_server = creds['id_server'] - elif 'idServer' in creds: - id_server = creds['idServer'] + if "id_server" in creds: + id_server = creds["id_server"] + elif "idServer" in creds: + id_server = creds["idServer"] else: raise SynapseError(400, "No id_server in creds") - if 'client_secret' in creds: - client_secret = creds['client_secret'] - elif 'clientSecret' in creds: - client_secret = creds['clientSecret'] + if "client_secret" in creds: + client_secret = creds["client_secret"] + elif "clientSecret" in creds: + client_secret = creds["clientSecret"] else: raise SynapseError(400, "No client_secret in creds") try: data = yield self.http_client.post_urlencoded_get_json( - "https://%s%s" % ( - id_server, "/_matrix/identity/api/v1/3pid/bind" - ), - { - 'sid': creds['sid'], - 'client_secret': client_secret, - 'mxid': mxid, - } + "https://%s%s" % (id_server, "/_matrix/identity/api/v1/3pid/bind"), + {"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid}, ) logger.debug("bound threepid %r to %s", creds, mxid) @@ -165,9 +156,7 @@ def try_unbind_threepid(self, mxid, threepid): id_servers = [threepid["id_server"]] else: id_servers = yield self.store.get_id_servers_user_bound( - user_id=mxid, - medium=threepid["medium"], - address=threepid["address"], + user_id=mxid, medium=threepid["medium"], address=threepid["address"] ) # We don't know where to unbind, so we don't have a choice but to return @@ -177,7 +166,7 @@ def try_unbind_threepid(self, mxid, threepid): changed = True for id_server in id_servers: changed &= yield self.try_unbind_threepid_with_id_server( - mxid, threepid, id_server, + mxid, threepid, id_server ) defer.returnValue(changed) @@ -201,10 +190,7 @@ def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server): url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) content = { "mxid": mxid, - "threepid": { - "medium": threepid["medium"], - "address": threepid["address"], - }, + "threepid": {"medium": threepid["medium"], "address": threepid["address"]}, } # we abuse the federation http client to sign the request, but we have to send it @@ -212,25 +198,19 @@ def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server): # 'browser-like' HTTPS. auth_headers = self.federation_http_client.build_auth_headers( destination=None, - method='POST', - url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'), + method="POST", + url_bytes="/_matrix/identity/api/v1/3pid/unbind".encode("ascii"), content=content, destination_is=id_server, ) - headers = { - b"Authorization": auth_headers, - } + headers = {b"Authorization": auth_headers} try: - yield self.http_client.post_json_get_json( - url, - content, - headers, - ) + yield self.http_client.post_json_get_json(url, content, headers) changed = True except HttpResponseException as e: changed = False - if e.code in (400, 404, 501,): + if e.code in (400, 404, 501): # The remote server probably doesn't support unbinding (yet) logger.warn("Received %d response while unbinding threepid", e.code) else: @@ -248,35 +228,27 @@ def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server): @defer.inlineCallbacks def requestEmailToken( - self, - id_server, - email, - client_secret, - send_attempt, - next_link=None, + self, id_server, email, client_secret, send_attempt, next_link=None ): if not self._should_trust_id_server(id_server): raise SynapseError( - 400, "Untrusted ID server '%s'" % id_server, - Codes.SERVER_NOT_TRUSTED + 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED ) params = { - 'email': email, - 'client_secret': client_secret, - 'send_attempt': send_attempt, + "email": email, + "client_secret": client_secret, + "send_attempt": send_attempt, } if next_link: - params.update({'next_link': next_link}) + params.update({"next_link": next_link}) try: data = yield self.http_client.post_json_get_json( - "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/validate/email/requestToken" - ), - params + "https://%s%s" + % (id_server, "/_matrix/identity/api/v1/validate/email/requestToken"), + params, ) defer.returnValue(data) except HttpResponseException as e: @@ -285,30 +257,26 @@ def requestEmailToken( @defer.inlineCallbacks def requestMsisdnToken( - self, id_server, country, phone_number, - client_secret, send_attempt, **kwargs + self, id_server, country, phone_number, client_secret, send_attempt, **kwargs ): if not self._should_trust_id_server(id_server): raise SynapseError( - 400, "Untrusted ID server '%s'" % id_server, - Codes.SERVER_NOT_TRUSTED + 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED ) params = { - 'country': country, - 'phone_number': phone_number, - 'client_secret': client_secret, - 'send_attempt': send_attempt, + "country": country, + "phone_number": phone_number, + "client_secret": client_secret, + "send_attempt": send_attempt, } params.update(kwargs) try: data = yield self.http_client.post_json_get_json( - "https://%s%s" % ( - id_server, - "/_matrix/identity/api/v1/validate/msisdn/requestToken" - ), - params + "https://%s%s" + % (id_server, "/_matrix/identity/api/v1/validate/msisdn/requestToken"), + params, ) defer.returnValue(data) except HttpResponseException as e: diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index aaee5db0b7b5..a1fe9d116faf 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -44,8 +44,13 @@ def __init__(self, hs): self.snapshot_cache = SnapshotCache() self._event_serializer = hs.get_event_client_serializer() - def snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): + def snapshot_all_rooms( + self, + user_id=None, + pagin_config=None, + as_client_event=True, + include_archived=False, + ): """Retrieve a snapshot of all rooms the user is invited or has joined. This snapshot may include messages for all rooms where the user is @@ -77,13 +82,22 @@ def snapshot_all_rooms(self, user_id=None, pagin_config=None, if result is not None: return result - return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - )) + return self.snapshot_cache.set( + now_ms, + key, + self._snapshot_all_rooms( + user_id, pagin_config, as_client_event, include_archived + ), + ) @defer.inlineCallbacks - def _snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): + def _snapshot_all_rooms( + self, + user_id=None, + pagin_config=None, + as_client_event=True, + include_archived=False, + ): memberships = [Membership.INVITE, Membership.JOIN] if include_archived: @@ -128,8 +142,7 @@ def handle_room(event): "room_id": event.room_id, "membership": event.membership, "visibility": ( - "public" if event.room_id in public_room_ids - else "private" + "public" if event.room_id in public_room_ids else "private" ), } @@ -139,7 +152,7 @@ def handle_room(event): invite_event = yield self.store.get_event(event.event_id) d["invite"] = yield self._event_serializer.serialize_event( - invite_event, time_now, as_client_event, + invite_event, time_now, as_client_event ) rooms_ret.append(d) @@ -151,14 +164,12 @@ def handle_room(event): if event.membership == Membership.JOIN: room_end_token = now_token.room_key deferred_room_state = run_in_background( - self.state_handler.get_current_state, - event.room_id, + self.state_handler.get_current_state, event.room_id ) elif event.membership == Membership.LEAVE: room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( - self.store.get_state_for_events, - [event.event_id], + self.store.get_state_for_events, [event.event_id] ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -178,9 +189,7 @@ def handle_room(event): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client( - self.store, user_id, messages - ) + messages = yield filter_events_for_client(self.store, user_id, messages) start_token = now_token.copy_and_replace("room_key", token) end_token = now_token.copy_and_replace("room_key", room_end_token) @@ -189,8 +198,7 @@ def handle_room(event): d["messages"] = { "chunk": ( yield self._event_serializer.serialize_events( - messages, time_now=time_now, - as_client_event=as_client_event, + messages, time_now=time_now, as_client_event=as_client_event ) ), "start": start_token.to_string(), @@ -200,23 +208,21 @@ def handle_room(event): d["state"] = yield self._event_serializer.serialize_events( current_state.values(), time_now=time_now, - as_client_event=as_client_event + as_client_event=as_client_event, ) account_data_events = [] tags = tags_by_room.get(event.room_id) if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) + account_data_events.append( + {"type": "m.tag", "content": {"tags": tags}} + ) account_data = account_data_by_room.get(event.room_id, {}) for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append( + {"type": account_data_type, "content": content} + ) d["account_data"] = account_data_events except Exception: @@ -226,10 +232,7 @@ def handle_room(event): account_data_events = [] for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append({"type": account_data_type, "content": content}) now = self.clock.time_msec() @@ -274,7 +277,7 @@ def room_initial_sync(self, requester, room_id, pagin_config=None): user_id = requester.user.to_string() membership, member_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id, + room_id, user_id ) is_peeking = member_event_id is None @@ -290,28 +293,21 @@ def room_initial_sync(self, requester, room_id, pagin_config=None): account_data_events = [] tags = yield self.store.get_tags_for_room(user_id, room_id) if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) + account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) account_data = yield self.store.get_account_data_for_room(user_id, room_id) for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append({"type": account_data_type, "content": content}) result["account_data"] = account_data_events defer.returnValue(result) @defer.inlineCallbacks - def _room_initial_sync_parted(self, user_id, room_id, pagin_config, - membership, member_event_id, is_peeking): - room_state = yield self.store.get_state_for_events( - [member_event_id], - ) + def _room_initial_sync_parted( + self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking + ): + room_state = yield self.store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -319,14 +315,10 @@ def _room_initial_sync_parted(self, user_id, room_id, pagin_config, if limit is None: limit = 10 - stream_token = yield self.store.get_stream_token_for_event( - member_event_id - ) + stream_token = yield self.store.get_stream_token_for_event(member_event_id) messages, token = yield self.store.get_recent_events_for_room( - room_id, - limit=limit, - end_token=stream_token + room_id, limit=limit, end_token=stream_token ) messages = yield filter_events_for_client( @@ -338,34 +330,39 @@ def _room_initial_sync_parted(self, user_id, room_id, pagin_config, time_now = self.clock.time_msec() - defer.returnValue({ - "membership": membership, - "room_id": room_id, - "messages": { - "chunk": (yield self._event_serializer.serialize_events( - messages, time_now, - )), - "start": start_token.to_string(), - "end": end_token.to_string(), - }, - "state": (yield self._event_serializer.serialize_events( - room_state.values(), time_now, - )), - "presence": [], - "receipts": [], - }) + defer.returnValue( + { + "membership": membership, + "room_id": room_id, + "messages": { + "chunk": ( + yield self._event_serializer.serialize_events( + messages, time_now + ) + ), + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": ( + yield self._event_serializer.serialize_events( + room_state.values(), time_now + ) + ), + "presence": [], + "receipts": [], + } + ) @defer.inlineCallbacks - def _room_initial_sync_joined(self, user_id, room_id, pagin_config, - membership, is_peeking): - current_state = yield self.state.get_current_state( - room_id=room_id, - ) + def _room_initial_sync_joined( + self, user_id, room_id, pagin_config, membership, is_peeking + ): + current_state = yield self.state.get_current_state(room_id=room_id) # TODO: These concurrently time_now = self.clock.time_msec() state = yield self._event_serializer.serialize_events( - current_state.values(), time_now, + current_state.values(), time_now ) now_token = yield self.hs.get_event_sources().get_current_token() @@ -375,7 +372,8 @@ def _room_initial_sync_joined(self, user_id, room_id, pagin_config, limit = 10 room_members = [ - m for m in current_state.values() + m + for m in current_state.values() if m.type == EventTypes.Member and m.content["membership"] == Membership.JOIN ] @@ -389,8 +387,7 @@ def get_presence(): defer.returnValue([]) states = yield presence_handler.get_states( - [m.user_id for m in room_members], - as_event=True, + [m.user_id for m in room_members], as_event=True ) defer.returnValue(states) @@ -398,8 +395,7 @@ def get_presence(): @defer.inlineCallbacks def get_receipts(): receipts = yield self.store.get_linearized_receipts_for_room( - room_id, - to_key=now_token.receipt_key, + room_id, to_key=now_token.receipt_key ) if not receipts: receipts = [] @@ -415,14 +411,14 @@ def get_receipts(): room_id, limit=limit, end_token=now_token.room_key, - ) + ), ], consumeErrors=True, - ).addErrback(unwrapFirstError), + ).addErrback(unwrapFirstError) ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking, + self.store, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace("room_key", token) @@ -433,9 +429,9 @@ def get_receipts(): ret = { "room_id": room_id, "messages": { - "chunk": (yield self._event_serializer.serialize_events( - messages, time_now, - )), + "chunk": ( + yield self._event_serializer.serialize_events(messages, time_now) + ), "start": start_token.to_string(), "end": end_token.to_string(), }, @@ -464,8 +460,8 @@ def _check_in_room_or_world_readable(self, room_id, user_id): room_id, EventTypes.RoomHistoryVisibility, "" ) if ( - visibility and - visibility.content["history_visibility"] == "world_readable" + visibility + and visibility.content["history_visibility"] == "world_readable" ): defer.returnValue((Membership.JOIN, None)) return diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7728ea230d2b..683da6bf3204 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -61,8 +61,9 @@ def __init__(self, hs): self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks - def get_room_data(self, user_id=None, room_id=None, - event_type=None, state_key="", is_guest=False): + def get_room_data( + self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False + ): """ Get data from a room. Args: @@ -77,9 +78,7 @@ def get_room_data(self, user_id=None, room_id=None, ) if membership == Membership.JOIN: - data = yield self.state.get_current_state( - room_id, event_type, state_key - ) + data = yield self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) room_state = yield self.store.get_state_for_events( @@ -91,8 +90,12 @@ def get_room_data(self, user_id=None, room_id=None, @defer.inlineCallbacks def get_state_events( - self, user_id, room_id, state_filter=StateFilter.all(), - at_token=None, is_guest=False, + self, + user_id, + room_id, + state_filter=StateFilter.all(), + at_token=None, + is_guest=False, ): """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has @@ -124,50 +127,48 @@ def get_state_events( # does not reliably give you the state at the given stream position. # (/~https://github.com/matrix-org/synapse/issues/3305) last_events, _ = yield self.store.get_recent_events_for_room( - room_id, end_token=at_token.room_key, limit=1, + room_id, end_token=at_token.room_key, limit=1 ) if not last_events: - raise NotFoundError("Can't find event for token %s" % (at_token, )) + raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.store, user_id, last_events, + self.store, user_id, last_events ) event = last_events[0] if visible_events: room_state = yield self.store.get_state_for_events( - [event.event_id], state_filter=state_filter, + [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] else: raise AuthError( 403, - "User %s not allowed to view events in room %s at token %s" % ( - user_id, room_id, at_token, - ) + "User %s not allowed to view events in room %s at token %s" + % (user_id, room_id, at_token), ) else: membership, membership_event_id = ( - yield self.auth.check_in_room_or_world_readable( - room_id, user_id, - ) + yield self.auth.check_in_room_or_world_readable(room_id, user_id) ) if membership == Membership.JOIN: state_ids = yield self.store.get_filtered_current_state_ids( - room_id, state_filter=state_filter, + room_id, state_filter=state_filter ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: room_state = yield self.store.get_state_for_events( - [membership_event_id], state_filter=state_filter, + [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] now = self.clock.time_msec() events = yield self._event_serializer.serialize_events( - room_state.values(), now, + room_state.values(), + now, # We don't bother bundling aggregations in when asked for state # events, as clients won't use them. bundle_aggregations=False, @@ -211,13 +212,15 @@ def get_joined_members(self, requester, room_id): # Loop fell through, AS has no interested users in room raise AuthError(403, "Appservice not in room") - defer.returnValue({ - user_id: { - "avatar_url": profile.avatar_url, - "display_name": profile.display_name, + defer.returnValue( + { + user_id: { + "avatar_url": profile.avatar_url, + "display_name": profile.display_name, + } + for user_id, profile in iteritems(users_with_profile) } - for user_id, profile in iteritems(users_with_profile) - }) + ) class EventCreationHandler(object): @@ -269,14 +272,21 @@ def __init__(self, hs): self.clock.looping_call( lambda: run_as_background_process( "send_dummy_events_to_fill_extremities", - self._send_dummy_events_to_fill_extremities + self._send_dummy_events_to_fill_extremities, ), 5 * 60 * 1000, ) @defer.inlineCallbacks - def create_event(self, requester, event_dict, token_id=None, txn_id=None, - prev_events_and_hashes=None, require_consent=True): + def create_event( + self, + requester, + event_dict, + token_id=None, + txn_id=None, + prev_events_and_hashes=None, + require_consent=True, + ): """ Given a dict from a client, create a new event. @@ -336,8 +346,7 @@ def create_event(self, requester, event_dict, token_id=None, txn_id=None, content["avatar_url"] = yield profile.get_avatar_url(target) except Exception as e: logger.info( - "Failed to get profile information for %r: %s", - target, e + "Failed to get profile information for %r: %s", target, e ) is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester) @@ -373,16 +382,17 @@ def create_event(self, requester, event_dict, token_id=None, txn_id=None, prev_event = yield self.store.get_event(prev_event_id, allow_none=True) if not prev_event or prev_event.membership != Membership.JOIN: logger.warning( - ("Attempt to send `m.room.aliases` in room %s by user %s but" - " membership is %s"), + ( + "Attempt to send `m.room.aliases` in room %s by user %s but" + " membership is %s" + ), event.room_id, event.sender, prev_event.membership if prev_event else None, ) raise AuthError( - 403, - "You must be in the room to create an alias for it", + 403, "You must be in the room to create an alias for it" ) self.validator.validate_new(event) @@ -449,8 +459,8 @@ def assert_accepted_privacy_policy(self, requester): # exempt the system notices user if ( - self.config.server_notices_mxid is not None and - user_id == self.config.server_notices_mxid + self.config.server_notices_mxid is not None + and user_id == self.config.server_notices_mxid ): return @@ -463,15 +473,10 @@ def assert_accepted_privacy_policy(self, requester): return consent_uri = self._consent_uri_builder.build_user_consent_uri( - requester.user.localpart, - ) - msg = self._block_events_without_consent_error % { - 'consent_uri': consent_uri, - } - raise ConsentNotGivenError( - msg=msg, - consent_uri=consent_uri, + requester.user.localpart ) + msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} + raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) @defer.inlineCallbacks def send_nonmember_event(self, requester, event, context, ratelimit=True): @@ -486,8 +491,7 @@ def send_nonmember_event(self, requester, event, context, ratelimit=True): """ if event.type == EventTypes.Member: raise SynapseError( - 500, - "Tried to send member event through non-member codepath" + 500, "Tried to send member event through non-member codepath" ) user = UserID.from_string(event.sender) @@ -499,15 +503,13 @@ def send_nonmember_event(self, requester, event, context, ratelimit=True): if prev_state is not None: logger.info( "Not bothering to persist state event %s duplicated by %s", - event.event_id, prev_state.event_id, + event.event_id, + prev_state.event_id, ) defer.returnValue(prev_state) yield self.handle_new_client_event( - requester=requester, - event=event, - context=context, - ratelimit=ratelimit, + requester=requester, event=event, context=context, ratelimit=ratelimit ) @defer.inlineCallbacks @@ -533,11 +535,7 @@ def deduplicate_state_event(self, event, context): @defer.inlineCallbacks def create_and_send_nonmember_event( - self, - requester, - event_dict, - ratelimit=True, - txn_id=None + self, requester, event_dict, ratelimit=True, txn_id=None ): """ Creates an event, then sends it. @@ -552,32 +550,25 @@ def create_and_send_nonmember_event( # taking longer. with (yield self.limiter.queue(event_dict["room_id"])): event, context = yield self.create_event( - requester, - event_dict, - token_id=requester.access_token_id, - txn_id=txn_id + requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id ) spam_error = self.spam_checker.check_event_for_spam(event) if spam_error: if not isinstance(spam_error, string_types): spam_error = "Spam is not permitted here" - raise SynapseError( - 403, spam_error, Codes.FORBIDDEN - ) + raise SynapseError(403, spam_error, Codes.FORBIDDEN) yield self.send_nonmember_event( - requester, - event, - context, - ratelimit=ratelimit, + requester, event, context, ratelimit=ratelimit ) defer.returnValue(event) @measure_func("create_new_client_event") @defer.inlineCallbacks - def create_new_client_event(self, builder, requester=None, - prev_events_and_hashes=None): + def create_new_client_event( + self, builder, requester=None, prev_events_and_hashes=None + ): """Create a new event for a local client Args: @@ -597,22 +588,21 @@ def create_new_client_event(self, builder, requester=None, """ if prev_events_and_hashes is not None: - assert len(prev_events_and_hashes) <= 10, \ - "Attempting to create an event with %i prev_events" % ( - len(prev_events_and_hashes), + assert len(prev_events_and_hashes) <= 10, ( + "Attempting to create an event with %i prev_events" + % (len(prev_events_and_hashes),) ) else: - prev_events_and_hashes = \ - yield self.store.get_prev_events_for_room(builder.room_id) + prev_events_and_hashes = yield self.store.get_prev_events_for_room( + builder.room_id + ) prev_events = [ (event_id, prev_hashes) for event_id, prev_hashes, _ in prev_events_and_hashes ] - event = yield builder.build( - prev_event_ids=[p for p, _ in prev_events], - ) + event = yield builder.build(prev_event_ids=[p for p, _ in prev_events]) context = yield self.state.compute_event_context(event) if requester: context.app_service = requester.app_service @@ -628,29 +618,19 @@ def create_new_client_event(self, builder, requester=None, aggregation_key = relation["key"] already_exists = yield self.store.has_user_annotated_event( - relates_to, event.type, aggregation_key, event.sender, + relates_to, event.type, aggregation_key, event.sender ) if already_exists: raise SynapseError(400, "Can't send same reaction twice") - logger.debug( - "Created event %s", - event.event_id, - ) + logger.debug("Created event %s", event.event_id) - defer.returnValue( - (event, context,) - ) + defer.returnValue((event, context)) @measure_func("handle_new_client_event") @defer.inlineCallbacks def handle_new_client_event( - self, - requester, - event, - context, - ratelimit=True, - extra_users=[], + self, requester, event, context, ratelimit=True, extra_users=[] ): """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. @@ -666,19 +646,20 @@ def handle_new_client_event( extra_users (list(UserID)): Any extra users to notify about event """ - if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""): - room_version = event.content.get( - "room_version", RoomVersions.V1.identifier - ) + if event.is_state() and (event.type, event.state_key) == ( + EventTypes.Create, + "", + ): + room_version = event.content.get("room_version", RoomVersions.V1.identifier) else: room_version = yield self.store.get_room_version(event.room_id) event_allowed = yield self.third_party_event_rules.check_event_allowed( - event, context, + event, context ) if not event_allowed: raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN, + 403, "This event is not allowed in this context", Codes.FORBIDDEN ) try: @@ -695,9 +676,7 @@ def handle_new_client_event( logger.exception("Failed to encode content: %r", event.content) raise - yield self.action_generator.handle_push_actions_for_event( - event, context - ) + yield self.action_generator.handle_push_actions_for_event(event, context) # reraise does not allow inlineCallbacks to preserve the stacktrace, so we # hack around with a try/finally instead. @@ -718,11 +697,7 @@ def handle_new_client_event( return yield self.persist_and_notify_client_event( - requester, - event, - context, - ratelimit=ratelimit, - extra_users=extra_users, + requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) success = True @@ -731,18 +706,12 @@ def handle_new_client_event( # Ensure that we actually remove the entries in the push actions # staging area, if we calculated them. run_in_background( - self.store.remove_push_actions_from_staging, - event.event_id, + self.store.remove_push_actions_from_staging, event.event_id ) @defer.inlineCallbacks def persist_and_notify_client_event( - self, - requester, - event, - context, - ratelimit=True, - extra_users=[], + self, requester, event, context, ratelimit=True, extra_users=[] ): """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -767,20 +736,16 @@ def persist_and_notify_client_event( if mapping["room_id"] != event.room_id: raise SynapseError( 400, - "Room alias %s does not point to the room" % ( - room_alias_str, - ) + "Room alias %s does not point to the room" % (room_alias_str,), ) federation_handler = self.hs.get_handlers().federation_handler if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: + def is_inviter_member_event(e): - return ( - e.type == EventTypes.Member and - e.sender == event.sender - ) + return e.type == EventTypes.Member and e.sender == event.sender current_state_ids = yield context.get_current_state_ids(self.store) @@ -810,26 +775,21 @@ def is_inviter_member_event(e): # to get them to sign the event. returned_invite = yield federation_handler.send_invite( - invitee.domain, - event, + invitee.domain, event ) event.unsigned.pop("room_state", None) # TODO: Make sure the signatures actually are correct. - event.signatures.update( - returned_invite.signatures - ) + event.signatures.update(returned_invite.signatures) if event.type == EventTypes.Redaction: prev_state_ids = yield context.get_prev_state_ids(self.store) auth_events_ids = yield self.auth.compute_auth_events( - event, prev_state_ids, for_verification=True, + event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in auth_events.values() - } + auth_events = {(e.type, e.state_key): e for e in auth_events.values()} room_version = yield self.store.get_room_version(event.room_id) if self.auth.check_redaction(room_version, event, auth_events=auth_events): original_event = yield self.store.get_event( @@ -837,13 +797,10 @@ def is_inviter_member_event(e): check_redacted=False, get_prev_content=False, allow_rejected=False, - allow_none=False + allow_none=False, ) if event.user_id != original_event.user_id: - raise AuthError( - 403, - "You don't have permission to redact events" - ) + raise AuthError(403, "You don't have permission to redact events") # We've already checked. event.internal_metadata.recheck_redaction = False @@ -851,24 +808,18 @@ def is_inviter_member_event(e): if event.type == EventTypes.Create: prev_state_ids = yield context.get_prev_state_ids(self.store) if prev_state_ids: - raise AuthError( - 403, - "Changing the room create event is forbidden", - ) + raise AuthError(403, "Changing the room create event is forbidden") (event_stream_id, max_stream_id) = yield self.store.persist_event( event, context=context ) - yield self.pusher_pool.on_new_notifications( - event_stream_id, max_stream_id, - ) + yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): try: self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users + event, event_stream_id, max_stream_id, extra_users=extra_users ) except Exception: logger.exception("Error notifying about new room event") @@ -895,23 +846,19 @@ def _send_dummy_events_to_fill_extremities(self): """ room_ids = yield self.store.get_rooms_with_many_extremities( - min_count=10, limit=5, + min_count=10, limit=5 ) for room_id in room_ids: # For each room we need to find a joined member we can use to send # the dummy event with. - prev_events_and_hashes = yield self.store.get_prev_events_for_room( - room_id, - ) + prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) - latest_event_ids = ( - event_id for (event_id, _, _) in prev_events_and_hashes - ) + latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) members = yield self.state.get_current_users_in_room( - room_id, latest_event_ids=latest_event_ids, + room_id, latest_event_ids=latest_event_ids ) user_id = None @@ -941,9 +888,4 @@ def _send_dummy_events_to_fill_extremities(self): event.internal_metadata.proactively_send = False - yield self.send_nonmember_event( - requester, - event, - context, - ratelimit=False, - ) + yield self.send_nonmember_event(requester, event, context, ratelimit=False) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 8f811e24fed9..062e026e5f94 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -55,9 +55,7 @@ def __init__(self): self.status = PurgeStatus.STATUS_ACTIVE def asdict(self): - return { - "status": PurgeStatus.STATUS_TEXT[self.status] - } + return {"status": PurgeStatus.STATUS_TEXT[self.status]} class PaginationHandler(object): @@ -79,8 +77,7 @@ def __init__(self, hs): self._purges_by_id = {} self._event_serializer = hs.get_event_client_serializer() - def start_purge_history(self, room_id, token, - delete_local_events=False): + def start_purge_history(self, room_id, token, delete_local_events=False): """Start off a history purge on a room. Args: @@ -95,8 +92,7 @@ def start_purge_history(self, room_id, token, """ if room_id in self._purges_in_progress_by_room: raise SynapseError( - 400, - "History purge already in progress for %s" % (room_id, ), + 400, "History purge already in progress for %s" % (room_id,) ) purge_id = random_string(16) @@ -107,14 +103,12 @@ def start_purge_history(self, room_id, token, self._purges_by_id[purge_id] = PurgeStatus() run_in_background( - self._purge_history, - purge_id, room_id, token, delete_local_events, + self._purge_history, purge_id, room_id, token, delete_local_events ) return purge_id @defer.inlineCallbacks - def _purge_history(self, purge_id, room_id, token, - delete_local_events): + def _purge_history(self, purge_id, room_id, token, delete_local_events): """Carry out a history purge on a room. Args: @@ -130,16 +124,13 @@ def _purge_history(self, purge_id, room_id, token, self._purges_in_progress_by_room.add(room_id) try: with (yield self.pagination_lock.write(room_id)): - yield self.store.purge_history( - room_id, token, delete_local_events, - ) + yield self.store.purge_history(room_id, token, delete_local_events) logger.info("[purge] complete") self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE except Exception: f = Failure() logger.error( - "[purge] failed", - exc_info=(f.type, f.value, f.getTracebackObject()), + "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) ) self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED finally: @@ -148,6 +139,7 @@ def _purge_history(self, purge_id, room_id, token, # remove the purge from the list 24 hours after it completes def clear_purge(): del self._purges_by_id[purge_id] + self.hs.get_reactor().callLater(24 * 3600, clear_purge) def get_purge_status(self, purge_id): @@ -162,8 +154,14 @@ def get_purge_status(self, purge_id): return self._purges_by_id.get(purge_id) @defer.inlineCallbacks - def get_messages(self, requester, room_id=None, pagin_config=None, - as_client_event=True, event_filter=None): + def get_messages( + self, + requester, + room_id=None, + pagin_config=None, + as_client_event=True, + event_filter=None, + ): """Get messages in a room. Args: @@ -201,7 +199,7 @@ def get_messages(self, requester, room_id=None, pagin_config=None, room_id, user_id ) - if source_config.direction == 'b': + if source_config.direction == "b": # if we're going backwards, we might need to backfill. This # requires that we have a topo token. if room_token.topological: @@ -235,27 +233,24 @@ def get_messages(self, requester, room_id=None, pagin_config=None, event_filter=event_filter, ) - next_token = pagin_config.from_token.copy_and_replace( - "room_key", next_key - ) + next_token = pagin_config.from_token.copy_and_replace("room_key", next_key) if events: if event_filter: events = event_filter.filter(events) events = yield filter_events_for_client( - self.store, - user_id, - events, - is_peeking=(member_event_id is None), + self.store, user_id, events, is_peeking=(member_event_id is None) ) if not events: - defer.returnValue({ - "chunk": [], - "start": pagin_config.from_token.to_string(), - "end": next_token.to_string(), - }) + defer.returnValue( + { + "chunk": [], + "start": pagin_config.from_token.to_string(), + "end": next_token.to_string(), + } + ) state = None if event_filter and event_filter.lazy_load_members() and len(events) > 0: @@ -263,12 +258,11 @@ def get_messages(self, requester, room_id=None, pagin_config=None, # FIXME: we also care about invite targets etc. state_filter = StateFilter.from_types( - (EventTypes.Member, event.sender) - for event in events + (EventTypes.Member, event.sender) for event in events ) state_ids = yield self.store.get_state_ids_for_event( - events[0].event_id, state_filter=state_filter, + events[0].event_id, state_filter=state_filter ) if state_ids: @@ -280,8 +274,7 @@ def get_messages(self, requester, room_id=None, pagin_config=None, chunk = { "chunk": ( yield self._event_serializer.serialize_events( - events, time_now, - as_client_event=as_client_event, + events, time_now, as_client_event=as_client_event ) ), "start": pagin_config.from_token.to_string(), @@ -291,8 +284,7 @@ def get_messages(self, requester, room_id=None, pagin_config=None, if state: chunk["state"] = ( yield self._event_serializer.serialize_events( - state, time_now, - as_client_event=as_client_event, + state, time_now, as_client_event=as_client_event ) ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 557fb5f83ddb..5204073a3801 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -50,16 +50,20 @@ notified_presence_counter = Counter("synapse_handler_presence_notified_presence", "") federation_presence_out_counter = Counter( - "synapse_handler_presence_federation_presence_out", "") + "synapse_handler_presence_federation_presence_out", "" +) presence_updates_counter = Counter("synapse_handler_presence_presence_updates", "") timers_fired_counter = Counter("synapse_handler_presence_timers_fired", "") -federation_presence_counter = Counter("synapse_handler_presence_federation_presence", "") +federation_presence_counter = Counter( + "synapse_handler_presence_federation_presence", "" +) bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", "") get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"]) notify_reason_counter = Counter( - "synapse_handler_presence_notify_reason", "", ["reason"]) + "synapse_handler_presence_notify_reason", "", ["reason"] +) state_transition_counter = Counter( "synapse_handler_presence_state_transition", "", ["from", "to"] ) @@ -90,7 +94,6 @@ class PresenceHandler(object): - def __init__(self, hs): """ @@ -110,31 +113,26 @@ def __init__(self, hs): federation_registry = hs.get_federation_registry() - federation_registry.register_edu_handler( - "m.presence", self.incoming_presence - ) + federation_registry.register_edu_handler("m.presence", self.incoming_presence) active_presence = self.store.take_presence_startup_info() # A dictionary of the current state of users. This is prefilled with # non-offline presence from the DB. We should fetch from the DB if # we can't find a users presence in here. - self.user_to_current_state = { - state.user_id: state - for state in active_presence - } + self.user_to_current_state = {state.user_id: state for state in active_presence} LaterGauge( - "synapse_handlers_presence_user_to_current_state_size", "", [], - lambda: len(self.user_to_current_state) + "synapse_handlers_presence_user_to_current_state_size", + "", + [], + lambda: len(self.user_to_current_state), ) now = self.clock.time_msec() for state in active_presence: self.wheel_timer.insert( - now=now, - obj=state.user_id, - then=state.last_active_ts + IDLE_TIMER, + now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER ) self.wheel_timer.insert( now=now, @@ -193,27 +191,21 @@ def run_timeout_handler(): "handle_presence_timeouts", self._handle_timeouts ) - self.clock.call_later( - 30, - self.clock.looping_call, - run_timeout_handler, - 5000, - ) + self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000) def run_persister(): return run_as_background_process( "persist_presence_changes", self._persist_unpersisted_changes ) - self.clock.call_later( - 60, - self.clock.looping_call, - run_persister, - 60 * 1000, - ) + self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000) - LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [], - lambda: len(self.wheel_timer)) + LaterGauge( + "synapse_handlers_presence_wheel_timer_size", + "", + [], + lambda: len(self.wheel_timer), + ) # Used to handle sending of presence to newly joined users/servers if hs.config.use_presence: @@ -241,15 +233,17 @@ def _on_shutdown(self): logger.info( "Performing _on_shutdown. Persisting %d unpersisted changes", - len(self.user_to_current_state) + len(self.user_to_current_state), ) if self.unpersisted_users_changes: - yield self.store.update_presence([ - self.user_to_current_state[user_id] - for user_id in self.unpersisted_users_changes - ]) + yield self.store.update_presence( + [ + self.user_to_current_state[user_id] + for user_id in self.unpersisted_users_changes + ] + ) logger.info("Finished _on_shutdown") @defer.inlineCallbacks @@ -261,13 +255,10 @@ def _persist_unpersisted_changes(self): self.unpersisted_users_changes = set() if unpersisted: - logger.info( - "Persisting %d upersisted presence updates", len(unpersisted) + logger.info("Persisting %d upersisted presence updates", len(unpersisted)) + yield self.store.update_presence( + [self.user_to_current_state[user_id] for user_id in unpersisted] ) - yield self.store.update_presence([ - self.user_to_current_state[user_id] - for user_id in unpersisted - ]) @defer.inlineCallbacks def _update_states(self, new_states): @@ -303,10 +294,11 @@ def _update_states(self, new_states): ) new_state, should_notify, should_ping = handle_update( - prev_state, new_state, + prev_state, + new_state, is_mine=self.is_mine_id(user_id), wheel_timer=self.wheel_timer, - now=now + now=now, ) self.user_to_current_state[user_id] = new_state @@ -328,7 +320,8 @@ def _update_states(self, new_states): self.unpersisted_users_changes -= set(to_notify.keys()) to_federation_ping = { - user_id: state for user_id, state in to_federation_ping.items() + user_id: state + for user_id, state in to_federation_ping.items() if user_id not in to_notify } if to_federation_ping: @@ -351,8 +344,8 @@ def _handle_timeouts(self): # Check whether the lists of syncing processes from an external # process have expired. expired_process_ids = [ - process_id for process_id, last_update - in self.external_process_last_updated_ms.items() + process_id + for process_id, last_update in self.external_process_last_updated_ms.items() if now - last_update > EXTERNAL_PROCESS_EXPIRY ] for process_id in expired_process_ids: @@ -362,9 +355,7 @@ def _handle_timeouts(self): self.external_process_last_update.pop(process_id) states = [ - self.user_to_current_state.get( - user_id, UserPresenceState.default(user_id) - ) + self.user_to_current_state.get(user_id, UserPresenceState.default(user_id)) for user_id in users_to_check ] @@ -394,9 +385,7 @@ def bump_presence_active_time(self, user): prev_state = yield self.current_state_for_user(user_id) - new_fields = { - "last_active_ts": self.clock.time_msec(), - } + new_fields = {"last_active_ts": self.clock.time_msec()} if prev_state.state == PresenceState.UNAVAILABLE: new_fields["state"] = PresenceState.ONLINE @@ -430,15 +419,23 @@ def user_syncing(self, user_id, affect_presence=True): if prev_state.state == PresenceState.OFFLINE: # If they're currently offline then bring them online, otherwise # just update the last sync times. - yield self._update_states([prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=self.clock.time_msec(), - last_user_sync_ts=self.clock.time_msec(), - )]) + yield self._update_states( + [ + prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=self.clock.time_msec(), + last_user_sync_ts=self.clock.time_msec(), + ) + ] + ) else: - yield self._update_states([prev_state.copy_and_replace( - last_user_sync_ts=self.clock.time_msec(), - )]) + yield self._update_states( + [ + prev_state.copy_and_replace( + last_user_sync_ts=self.clock.time_msec() + ) + ] + ) @defer.inlineCallbacks def _end(): @@ -446,9 +443,13 @@ def _end(): self.user_to_num_current_syncs[user_id] -= 1 prev_state = yield self.current_state_for_user(user_id) - yield self._update_states([prev_state.copy_and_replace( - last_user_sync_ts=self.clock.time_msec(), - )]) + yield self._update_states( + [ + prev_state.copy_and_replace( + last_user_sync_ts=self.clock.time_msec() + ) + ] + ) except Exception: logger.exception("Error updating presence after sync") @@ -469,7 +470,8 @@ def get_currently_syncing_users(self): """ if self.hs.config.use_presence: syncing_user_ids = { - user_id for user_id, count in self.user_to_num_current_syncs.items() + user_id + for user_id, count in self.user_to_num_current_syncs.items() if count } for user_ids in self.external_process_to_current_syncs.values(): @@ -479,7 +481,9 @@ def get_currently_syncing_users(self): return set() @defer.inlineCallbacks - def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec): + def update_external_syncs_row( + self, process_id, user_id, is_syncing, sync_time_msec + ): """Update the syncing users for an external process as a delta. Args: @@ -500,20 +504,22 @@ def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_m updates = [] if is_syncing and user_id not in process_presence: if prev_state.state == PresenceState.OFFLINE: - updates.append(prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=sync_time_msec, - last_user_sync_ts=sync_time_msec, - )) + updates.append( + prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=sync_time_msec, + last_user_sync_ts=sync_time_msec, + ) + ) else: - updates.append(prev_state.copy_and_replace( - last_user_sync_ts=sync_time_msec, - )) + updates.append( + prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) + ) process_presence.add(user_id) elif user_id in process_presence: - updates.append(prev_state.copy_and_replace( - last_user_sync_ts=sync_time_msec, - )) + updates.append( + prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) + ) if not is_syncing: process_presence.discard(user_id) @@ -537,12 +543,12 @@ def update_external_syncs_clear(self, process_id): prev_states = yield self.current_state_for_users(process_presence) time_now_ms = self.clock.time_msec() - yield self._update_states([ - prev_state.copy_and_replace( - last_user_sync_ts=time_now_ms, - ) - for prev_state in itervalues(prev_states) - ]) + yield self._update_states( + [ + prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) + for prev_state in itervalues(prev_states) + ] + ) self.external_process_last_updated_ms.pop(process_id, None) @defer.inlineCallbacks @@ -574,8 +580,7 @@ def current_state_for_users(self, user_ids): missing = [user_id for user_id, state in iteritems(states) if not state] if missing: new = { - user_id: UserPresenceState.default(user_id) - for user_id in missing + user_id: UserPresenceState.default(user_id) for user_id in missing } states.update(new) self.user_to_current_state.update(new) @@ -593,8 +598,10 @@ def _persist_and_notify(self, states): room_ids_to_states, users_to_states = parties self.notifier.on_new_event( - "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states] + "presence_key", + stream_id, + rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states], ) self._push_to_remotes(states) @@ -605,8 +612,10 @@ def notify_for_states(self, state, stream_id): room_ids_to_states, users_to_states = parties self.notifier.on_new_event( - "presence_key", stream_id, rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states] + "presence_key", + stream_id, + rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states], ) def _push_to_remotes(self, states): @@ -631,15 +640,15 @@ def incoming_presence(self, origin, content): user_id = push.get("user_id", None) if not user_id: logger.info( - "Got presence update from %r with no 'user_id': %r", - origin, push, + "Got presence update from %r with no 'user_id': %r", origin, push ) continue if get_domain_from_id(user_id) != origin: logger.info( "Got presence update from %r with bad 'user_id': %r", - origin, user_id, + origin, + user_id, ) continue @@ -647,14 +656,12 @@ def incoming_presence(self, origin, content): if not presence_state: logger.info( "Got presence update from %r with no 'presence_state': %r", - origin, push, + origin, + push, ) continue - new_fields = { - "state": presence_state, - "last_federation_update_ts": now, - } + new_fields = {"state": presence_state, "last_federation_update_ts": now} last_active_ago = push.get("last_active_ago", None) if last_active_ago is not None: @@ -672,10 +679,7 @@ def incoming_presence(self, origin, content): @defer.inlineCallbacks def get_state(self, target_user, as_event=False): - results = yield self.get_states( - [target_user.to_string()], - as_event=as_event, - ) + results = yield self.get_states([target_user.to_string()], as_event=as_event) defer.returnValue(results[0]) @@ -699,13 +703,15 @@ def get_states(self, target_user_ids, as_event=False): now = self.clock.time_msec() if as_event: - defer.returnValue([ - { - "type": "m.presence", - "content": format_user_presence_state(state, now), - } - for state in updates - ]) + defer.returnValue( + [ + { + "type": "m.presence", + "content": format_user_presence_state(state, now), + } + for state in updates + ] + ) else: defer.returnValue(updates) @@ -717,7 +723,9 @@ def set_state(self, target_user, state, ignore_status_msg=False): presence = state["presence"] valid_presence = ( - PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.OFFLINE, ) if presence not in valid_presence: raise SynapseError(400, "Invalid presence state") @@ -726,9 +734,7 @@ def set_state(self, target_user, state, ignore_status_msg=False): prev_state = yield self.current_state_for_user(user_id) - new_fields = { - "state": presence - } + new_fields = {"state": presence} if not ignore_status_msg: msg = status_msg if presence != PresenceState.OFFLINE else None @@ -877,8 +883,7 @@ def _on_user_joined_room(self, room_id, user_id): hosts = set(host for host in hosts if host != self.server_name) self.federation.send_presence_to_destinations( - states=[state], - destinations=hosts, + states=[state], destinations=hosts ) else: # A remote user has joined the room, so we need to: @@ -904,7 +909,8 @@ def _on_user_joined_room(self, room_id, user_id): # default state. now = self.clock.time_msec() states = [ - state for state in states.values() + state + for state in states.values() if state.state != PresenceState.OFFLINE or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000 or state.status_msg is not None @@ -912,8 +918,7 @@ def _on_user_joined_room(self, room_id, user_id): if states: self.federation.send_presence_to_destinations( - states=states, - destinations=[get_domain_from_id(user_id)], + states=states, destinations=[get_domain_from_id(user_id)] ) @@ -937,7 +942,10 @@ def should_notify(old_state, new_state): notify_reason_counter.labels("current_active_change").inc() return True - if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + if ( + new_state.last_active_ts - old_state.last_active_ts + > LAST_ACTIVE_GRANULARITY + ): # Only notify about last active bumps if we're not currently acive if not new_state.currently_active: notify_reason_counter.labels("last_active_change_online").inc() @@ -958,9 +966,7 @@ def format_user_presence_state(state, now, include_user_id=True): The "user_id" is optional so that this function can be used to format presence updates for client /sync responses and for federation /send requests. """ - content = { - "presence": state.state, - } + content = {"presence": state.state} if include_user_id: content["user_id"] = state.user_id if state.last_active_ts: @@ -986,8 +992,15 @@ def __init__(self, hs): @defer.inlineCallbacks @log_function - def get_new_events(self, user, from_key, room_ids=None, include_offline=True, - explicit_room_id=None, **kwargs): + def get_new_events( + self, + user, + from_key, + room_ids=None, + include_offline=True, + explicit_room_id=None, + **kwargs + ): # The process for getting presence events are: # 1. Get the rooms the user is in. # 2. Get the list of user in the rooms. @@ -1030,7 +1043,7 @@ def get_new_events(self, user, from_key, room_ids=None, include_offline=True, if from_key: user_ids_changed = stream_change_cache.get_entities_changed( - users_interested_in, from_key, + users_interested_in, from_key ) else: user_ids_changed = users_interested_in @@ -1040,10 +1053,16 @@ def get_new_events(self, user, from_key, room_ids=None, include_offline=True, if include_offline: defer.returnValue((list(updates.values()), max_token)) else: - defer.returnValue(([ - s for s in itervalues(updates) - if s.state != PresenceState.OFFLINE - ], max_token)) + defer.returnValue( + ( + [ + s + for s in itervalues(updates) + if s.state != PresenceState.OFFLINE + ], + max_token, + ) + ) def get_current_key(self): return self.store.get_current_presence_token() @@ -1061,13 +1080,13 @@ def _get_interested_in(self, user, explicit_room_id, cache_context): users_interested_in.add(user_id) # So that we receive our own presence users_who_share_room = yield self.store.get_users_who_share_room_with_user( - user_id, on_invalidate=cache_context.invalidate, + user_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(users_who_share_room) if explicit_room_id: user_ids = yield self.store.get_users_in_room( - explicit_room_id, on_invalidate=cache_context.invalidate, + explicit_room_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(user_ids) @@ -1123,9 +1142,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): if now - state.last_active_ts > IDLE_TIMER: # Currently online, but last activity ages ago so auto # idle - state = state.copy_and_replace( - state=PresenceState.UNAVAILABLE, - ) + state = state.copy_and_replace(state=PresenceState.UNAVAILABLE) changed = True elif now - state.last_active_ts > LAST_ACTIVE_GRANULARITY: # So that we send down a notification that we've @@ -1145,8 +1162,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): sync_or_active = max(state.last_user_sync_ts, state.last_active_ts) if now - sync_or_active > SYNC_ONLINE_TIMEOUT: state = state.copy_and_replace( - state=PresenceState.OFFLINE, - status_msg=None, + state=PresenceState.OFFLINE, status_msg=None ) changed = True else: @@ -1155,10 +1171,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now): # no one gets stuck online forever. if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: # The other side seems to have disappeared. - state = state.copy_and_replace( - state=PresenceState.OFFLINE, - status_msg=None, - ) + state = state.copy_and_replace(state=PresenceState.OFFLINE, status_msg=None) changed = True return state if changed else None @@ -1193,21 +1206,17 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): if new_state.state == PresenceState.ONLINE: # Idle timer wheel_timer.insert( - now=now, - obj=user_id, - then=new_state.last_active_ts + IDLE_TIMER + now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER ) active = now - new_state.last_active_ts < LAST_ACTIVE_GRANULARITY - new_state = new_state.copy_and_replace( - currently_active=active, - ) + new_state = new_state.copy_and_replace(currently_active=active) if active: wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY + then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY, ) if new_state.state != PresenceState.OFFLINE: @@ -1215,29 +1224,25 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now): wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, ) last_federate = new_state.last_federation_update_ts if now - last_federate > FEDERATION_PING_INTERVAL: # Been a while since we've poked remote servers - new_state = new_state.copy_and_replace( - last_federation_update_ts=now, - ) + new_state = new_state.copy_and_replace(last_federation_update_ts=now) federation_ping = True else: wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT + then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT, ) # Check whether the change was something worth notifying about if should_notify(prev_state, new_state): - new_state = new_state.copy_and_replace( - last_federation_update_ts=now, - ) + new_state = new_state.copy_and_replace(last_federation_update_ts=now) persist_and_notify = True return new_state, persist_and_notify, federation_ping diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 3e0423339433..d8462b75eca5 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -73,18 +73,13 @@ def get_profile(self, user_id): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) raise - defer.returnValue({ - "displayname": displayname, - "avatar_url": avatar_url, - }) + defer.returnValue({"displayname": displayname, "avatar_url": avatar_url}) else: try: result = yield self.federation.make_query( destination=target_user.domain, query_type="profile", - args={ - "user_id": user_id, - }, + args={"user_id": user_id}, ignore_backoff=True, ) defer.returnValue(result) @@ -113,10 +108,7 @@ def get_profile_from_cache(self, user_id): raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) raise - defer.returnValue({ - "displayname": displayname, - "avatar_url": avatar_url, - }) + defer.returnValue({"displayname": displayname, "avatar_url": avatar_url}) else: profile = yield self.store.get_from_remote_profile_cache(user_id) defer.returnValue(profile or {}) @@ -139,10 +131,7 @@ def get_displayname(self, target_user): result = yield self.federation.make_query( destination=target_user.domain, query_type="profile", - args={ - "user_id": target_user.to_string(), - "field": "displayname", - }, + args={"user_id": target_user.to_string(), "field": "displayname"}, ignore_backoff=True, ) except RequestSendFailed as e: @@ -170,15 +159,13 @@ def set_displayname(self, target_user, requester, new_displayname, by_admin=Fals if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( - 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN, ), + 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) ) - if new_displayname == '': + if new_displayname == "": new_displayname = None - yield self.store.set_profile_displayname( - target_user.localpart, new_displayname - ) + yield self.store.set_profile_displayname(target_user.localpart, new_displayname) if self.hs.config.user_directory_search_all_users: profile = yield self.store.get_profileinfo(target_user.localpart) @@ -205,10 +192,7 @@ def get_avatar_url(self, target_user): result = yield self.federation.make_query( destination=target_user.domain, query_type="profile", - args={ - "user_id": target_user.to_string(), - "field": "avatar_url", - }, + args={"user_id": target_user.to_string(), "field": "avatar_url"}, ignore_backoff=True, ) except RequestSendFailed as e: @@ -230,12 +214,10 @@ def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False) if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( - 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN, ), + 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) - yield self.store.set_profile_avatar_url( - target_user.localpart, new_avatar_url - ) + yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) if self.hs.config.user_directory_search_all_users: profile = yield self.store.get_profileinfo(target_user.localpart) @@ -278,9 +260,7 @@ def _update_join_states(self, requester, target_user): yield self.ratelimit(requester) - room_ids = yield self.store.get_rooms_for_user( - target_user.to_string(), - ) + room_ids = yield self.store.get_rooms_for_user(target_user.to_string()) for room_id in room_ids: handler = self.hs.get_room_member_handler() @@ -296,8 +276,7 @@ def _update_join_states(self, requester, target_user): ) except Exception as e: logger.warn( - "Failed to update join event for room %s - %s", - room_id, str(e) + "Failed to update join event for room %s - %s", room_id, str(e) ) @defer.inlineCallbacks @@ -325,11 +304,9 @@ def check_profile_query_allowed(self, target_user, requester=None): return try: - requester_rooms = yield self.store.get_rooms_for_user( - requester.to_string() - ) + requester_rooms = yield self.store.get_rooms_for_user(requester.to_string()) target_user_rooms = yield self.store.get_rooms_for_user( - target_user.to_string(), + target_user.to_string() ) # Check if the room lists have no elements in common. @@ -353,12 +330,12 @@ def __init__(self, hs): assert hs.config.worker_app is None self.clock.looping_call( - self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, + self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS ) def _start_update_remote_profile_cache(self): return run_as_background_process( - "Update remote profile", self._update_remote_profile_cache, + "Update remote profile", self._update_remote_profile_cache ) @defer.inlineCallbacks @@ -372,7 +349,7 @@ def _update_remote_profile_cache(self): for user_id, displayname, avatar_url in entries: is_subscribed = yield self.store.is_subscribed_remote_profile_for_user( - user_id, + user_id ) if not is_subscribed: yield self.store.maybe_delete_remote_profile_cache(user_id) @@ -382,9 +359,7 @@ def _update_remote_profile_cache(self): profile = yield self.federation.make_query( destination=get_domain_from_id(user_id), query_type="profile", - args={ - "user_id": user_id, - }, + args={"user_id": user_id}, ignore_backoff=True, ) except Exception: @@ -399,6 +374,4 @@ def _update_remote_profile_cache(self): new_avatar = profile.get("avatar_url") # We always hit update to update the last_check timestamp - yield self.store.update_remote_profile_cache( - user_id, new_name, new_avatar - ) + yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 32108568c640..3e4d8c93a4a5 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -43,7 +43,7 @@ def received_client_read_marker(self, room_id, user_id, event_id): with (yield self.read_marker_linearizer.queue((room_id, user_id))): existing_read_marker = yield self.store.get_account_data_for_room_and_type( - user_id, room_id, "m.fully_read", + user_id, room_id, "m.fully_read" ) should_update = True @@ -51,14 +51,11 @@ def received_client_read_marker(self, room_id, user_id, event_id): if existing_read_marker: # Only update if the new marker is ahead in the stream should_update = yield self.store.is_event_after( - event_id, - existing_read_marker['event_id'] + event_id, existing_read_marker["event_id"] ) if should_update: - content = { - "event_id": event_id - } + content = {"event_id": event_id} max_id = yield self.store.add_account_data_to_room( user_id, room_id, "m.fully_read", content ) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 274d2946ad5d..a85dd8cdee69 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -88,19 +88,16 @@ def _handle_new_receipts(self, receipts): affected_room_ids = list(set([r.room_id for r in receipts])) - self.notifier.on_new_event( - "receipt_key", max_batch_id, rooms=affected_room_ids - ) + self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. yield self.hs.get_pusherpool().on_new_receipts( - min_batch_id, max_batch_id, affected_room_ids, + min_batch_id, max_batch_id, affected_room_ids ) defer.returnValue(True) @defer.inlineCallbacks - def received_client_receipt(self, room_id, receipt_type, user_id, - event_id): + def received_client_receipt(self, room_id, receipt_type, user_id, event_id): """Called when a client tells us a local user has read up to the given event_id in the room. """ @@ -109,9 +106,7 @@ def received_client_receipt(self, room_id, receipt_type, user_id, receipt_type=receipt_type, user_id=user_id, event_ids=[event_id], - data={ - "ts": int(self.clock.time_msec()), - }, + data={"ts": int(self.clock.time_msec())}, ) is_new = yield self._handle_new_receipts([receipt]) @@ -125,8 +120,7 @@ def get_receipts_for_room(self, room_id, to_key): """Gets all receipts for a room, upto the given key. """ result = yield self.store.get_linearized_receipts_for_room( - room_id, - to_key=to_key, + room_id, to_key=to_key ) if not result: @@ -148,14 +142,12 @@ def get_new_events(self, from_key, room_ids, **kwargs): defer.returnValue(([], to_key)) events = yield self.store.get_linearized_receipts_for_rooms( - room_ids, - from_key=from_key, - to_key=to_key, + room_ids, from_key=from_key, to_key=to_key ) defer.returnValue((events, to_key)) - def get_current_key(self, direction='f'): + def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() @defer.inlineCallbacks @@ -169,9 +161,7 @@ def get_pagination_rows(self, user, config, key): room_ids = yield self.store.get_rooms_for_user(user.to_string()) events = yield self.store.get_linearized_receipts_for_rooms( - room_ids, - from_key=from_key, - to_key=to_key, + room_ids, from_key=from_key, to_key=to_key ) defer.returnValue((events, to_key)) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 9a388ea013f6..e487b90c0821 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -47,7 +47,6 @@ class RegistrationHandler(BaseHandler): - def __init__(self, hs): """ @@ -69,44 +68,37 @@ def __init__(self, hs): self.macaroon_gen = hs.get_macaroon_generator() self._generate_user_id_linearizer = Linearizer( - name="_generate_user_id_linearizer", + name="_generate_user_id_linearizer" ) self._server_notices_mxid = hs.config.server_notices_mxid if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) - self._register_device_client = ( - RegisterDeviceReplicationServlet.make_client(hs) + self._register_device_client = RegisterDeviceReplicationServlet.make_client( + hs ) - self._post_registration_client = ( - ReplicationPostRegisterActionsServlet.make_client(hs) + self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client( + hs ) else: self.device_handler = hs.get_device_handler() self.pusher_pool = hs.get_pusherpool() @defer.inlineCallbacks - def check_username(self, localpart, guest_access_token=None, - assigned_user_id=None): + def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, "User ID can only contain characters a-z, 0-9, or '=_-./'", - Codes.INVALID_USERNAME + Codes.INVALID_USERNAME, ) if not localpart: - raise SynapseError( - 400, - "User ID cannot be empty", - Codes.INVALID_USERNAME - ) + raise SynapseError(400, "User ID cannot be empty", Codes.INVALID_USERNAME) - if localpart[0] == '_': + if localpart[0] == "_": raise SynapseError( - 400, - "User ID may not begin with _", - Codes.INVALID_USERNAME + 400, "User ID may not begin with _", Codes.INVALID_USERNAME ) user = UserID(localpart, self.hs.hostname) @@ -126,19 +118,15 @@ def check_username(self, localpart, guest_access_token=None, if len(user_id) > MAX_USERID_LENGTH: raise SynapseError( 400, - "User ID may not be longer than %s characters" % ( - MAX_USERID_LENGTH, - ), - Codes.INVALID_USERNAME + "User ID may not be longer than %s characters" % (MAX_USERID_LENGTH,), + Codes.INVALID_USERNAME, ) users = yield self.store.get_users_by_id_case_insensitive(user_id) if users: if not guest_access_token: raise SynapseError( - 400, - "User ID already taken.", - errcode=Codes.USER_IN_USE, + 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) user_data = yield self.auth.get_user_by_access_token(guest_access_token) if not user_data["is_guest"] or user_data["user"].localpart != localpart: @@ -203,8 +191,7 @@ def register( try: int(localpart) raise RegistrationError( - 400, - "Numeric user IDs are reserved for guest users." + 400, "Numeric user IDs are reserved for guest users." ) except ValueError: pass @@ -283,9 +270,7 @@ def register( } # Bind email to new account - yield self._register_email_threepid( - user_id, threepid_dict, None, False, - ) + yield self._register_email_threepid(user_id, threepid_dict, None, False) defer.returnValue((user_id, token)) @@ -318,8 +303,8 @@ def _auto_join_rooms(self, user_id): room_alias = RoomAlias.from_string(r) if self.hs.hostname != room_alias.domain: logger.warning( - 'Cannot create room alias %s, ' - 'it does not match server domain', + "Cannot create room alias %s, " + "it does not match server domain", r, ) else: @@ -332,7 +317,7 @@ def _auto_join_rooms(self, user_id): fake_requester, config={ "preset": "public_chat", - "room_alias_name": room_alias_localpart + "room_alias_name": room_alias_localpart, }, ratelimit=False, ) @@ -364,8 +349,9 @@ def appservice_register(self, user_localpart, as_token): raise AuthError(403, "Invalid application service token.") if not service.is_interested_in_user(user_id): raise SynapseError( - 400, "Invalid user localpart for this application service.", - errcode=Codes.EXCLUSIVE + 400, + "Invalid user localpart for this application service.", + errcode=Codes.EXCLUSIVE, ) service_id = service.id if service.is_exclusive_user(user_id) else None @@ -391,17 +377,15 @@ def check_recaptcha(self, ip, private_key, challenge, response): """ captcha_response = yield self._validate_captcha( - ip, - private_key, - challenge, - response + ip, private_key, challenge, response ) if not captcha_response["valid"]: - logger.info("Invalid captcha entered from %s. Error: %s", - ip, captcha_response["error_url"]) - raise InvalidCaptchaError( - error_url=captcha_response["error_url"] + logger.info( + "Invalid captcha entered from %s. Error: %s", + ip, + captcha_response["error_url"], ) + raise InvalidCaptchaError(error_url=captcha_response["error_url"]) else: logger.info("Valid captcha entered from %s", ip) @@ -414,8 +398,11 @@ def register_email(self, threepidCreds): """ for c in threepidCreds: - logger.info("validating threepidcred sid %s on id server %s", - c['sid'], c['idServer']) + logger.info( + "validating threepidcred sid %s on id server %s", + c["sid"], + c["idServer"], + ) try: threepid = yield self.identity_handler.threepid_from_creds(c) except Exception: @@ -424,13 +411,14 @@ def register_email(self, threepidCreds): if not threepid: raise RegistrationError(400, "Couldn't validate 3pid") - logger.info("got threepid with medium '%s' and address '%s'", - threepid['medium'], threepid['address']) + logger.info( + "got threepid with medium '%s' and address '%s'", + threepid["medium"], + threepid["address"], + ) - if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']): - raise RegistrationError( - 403, "Third party identifier is not allowed" - ) + if not check_3pid_allowed(self.hs, threepid["medium"], threepid["address"]): + raise RegistrationError(403, "Third party identifier is not allowed") @defer.inlineCallbacks def bind_emails(self, user_id, threepidCreds): @@ -449,23 +437,23 @@ def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=Non if self._server_notices_mxid is not None: if user_id == self._server_notices_mxid: raise SynapseError( - 400, "This user ID is reserved.", - errcode=Codes.EXCLUSIVE + 400, "This user ID is reserved.", errcode=Codes.EXCLUSIVE ) # valid user IDs must not clash with any user ID namespaces claimed by # application services. services = self.store.get_app_services() interested_services = [ - s for s in services - if s.is_interested_in_user(user_id) - and s != allowed_appservice + s + for s in services + if s.is_interested_in_user(user_id) and s != allowed_appservice ] for service in interested_services: if service.is_exclusive_user(user_id): raise SynapseError( - 400, "This user ID is reserved by an application service.", - errcode=Codes.EXCLUSIVE + 400, + "This user ID is reserved by an application service.", + errcode=Codes.EXCLUSIVE, ) @defer.inlineCallbacks @@ -491,14 +479,13 @@ def _validate_captcha(self, ip_addr, private_key, challenge, response): dict: Containing 'valid'(bool) and 'error_url'(str) if invalid. """ - response = yield self._submit_captcha(ip_addr, private_key, challenge, - response) + response = yield self._submit_captcha(ip_addr, private_key, challenge, response) # parse Google's response. Lovely format.. - lines = response.split('\n') + lines = response.split("\n") json = { - "valid": lines[0] == 'true', - "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?" + - "error=%s" % lines[1] + "valid": lines[0] == "true", + "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?" + + "error=%s" % lines[1], } defer.returnValue(json) @@ -510,17 +497,16 @@ def _submit_captcha(self, ip_addr, private_key, challenge, response): data = yield self.captcha_client.post_urlencoded_get_raw( "http://www.recaptcha.net:80/recaptcha/api/verify", args={ - 'privatekey': private_key, - 'remoteip': ip_addr, - 'challenge': challenge, - 'response': response - } + "privatekey": private_key, + "remoteip": ip_addr, + "challenge": challenge, + "response": response, + }, ) defer.returnValue(data) @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, - password_hash=None): + def get_or_create_user(self, requester, localpart, displayname, password_hash=None): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -565,7 +551,7 @@ def get_or_create_user(self, requester, localpart, displayname, if displayname is not None: logger.info("setting user display name: %s -> %s", user_id, displayname) yield self.profile_handler.set_displayname( - user, requester, displayname, by_admin=True, + user, requester, displayname, by_admin=True ) defer.returnValue((user_id, token)) @@ -587,15 +573,12 @@ def get_or_register_3pid_guest(self, medium, address, inviter_user_id): """ access_token = yield self.store.get_3pid_guest_access_token(medium, address) if access_token: - user_info = yield self.auth.get_user_by_access_token( - access_token - ) + user_info = yield self.auth.get_user_by_access_token(access_token) defer.returnValue((user_info["user"].to_string(), access_token)) user_id, access_token = yield self.register( - generate_token=True, - make_guest=True + generate_token=True, make_guest=True ) access_token = yield self.store.save_or_get_3pid_guest_access_token( medium, address, access_token, inviter_user_id @@ -616,9 +599,9 @@ def _join_user_to_room(self, requester, room_identifier): ) room_id = room_id.to_string() else: - raise SynapseError(400, "%s was not legal room ID or room alias" % ( - room_identifier, - )) + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) yield room_member_handler.update_membership( requester=requester, @@ -629,10 +612,19 @@ def _join_user_to_room(self, requester, room_identifier): ratelimit=False, ) - def register_with_store(self, user_id, token=None, password_hash=None, - was_guest=False, make_guest=False, appservice_id=None, - create_profile_with_displayname=None, admin=False, - user_type=None, address=None): + def register_with_store( + self, + user_id, + token=None, + password_hash=None, + was_guest=False, + make_guest=False, + appservice_id=None, + create_profile_with_displayname=None, + admin=False, + user_type=None, + address=None, + ): """Register user in the datastore. Args: @@ -661,14 +653,15 @@ def register_with_store(self, user_id, token=None, password_hash=None, time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.can_do_action( - address, time_now_s=time_now, + address, + time_now_s=time_now, rate_hz=self.hs.config.rc_registration.per_second, burst_count=self.hs.config.rc_registration.burst_count, ) if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), + retry_after_ms=int(1000 * (time_allowed - time_now)) ) if self.hs.config.worker_app: @@ -698,8 +691,7 @@ def register_with_store(self, user_id, token=None, password_hash=None, ) @defer.inlineCallbacks - def register_device(self, user_id, device_id, initial_display_name, - is_guest=False): + def register_device(self, user_id, device_id, initial_display_name, is_guest=False): """Register a device for a user and generate an access token. Args: @@ -732,14 +724,15 @@ def register_device(self, user_id, device_id, initial_display_name, ) else: access_token = yield self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, + user_id, device_id=device_id ) defer.returnValue((device_id, access_token)) @defer.inlineCallbacks - def post_registration_actions(self, user_id, auth_result, access_token, - bind_email, bind_msisdn): + def post_registration_actions( + self, user_id, auth_result, access_token, bind_email, bind_msisdn + ): """A user has completed registration Args: @@ -773,20 +766,15 @@ def post_registration_actions(self, user_id, auth_result, access_token, yield self.store.upsert_monthly_active_user(user_id) yield self._register_email_threepid( - user_id, threepid, access_token, - bind_email, + user_id, threepid, access_token, bind_email ) if auth_result and LoginType.MSISDN in auth_result: threepid = auth_result[LoginType.MSISDN] - yield self._register_msisdn_threepid( - user_id, threepid, bind_msisdn, - ) + yield self._register_msisdn_threepid(user_id, threepid, bind_msisdn) if auth_result and LoginType.TERMS in auth_result: - yield self._on_user_consented( - user_id, self.hs.config.user_consent_version, - ) + yield self._on_user_consented(user_id, self.hs.config.user_consent_version) @defer.inlineCallbacks def _on_user_consented(self, user_id, consent_version): @@ -798,9 +786,7 @@ def _on_user_consented(self, user_id, consent_version): consented to. """ logger.info("%s has consented to the privacy policy", user_id) - yield self.store.user_set_consent_version( - user_id, consent_version, - ) + yield self.store.user_set_consent_version(user_id, consent_version) yield self.post_consent_actions(user_id) @defer.inlineCallbacks @@ -824,33 +810,30 @@ def _register_email_threepid(self, user_id, threepid, token, bind_email): Returns: defer.Deferred: """ - reqd = ('medium', 'address', 'validated_at') + reqd = ("medium", "address", "validated_at") if any(x not in threepid for x in reqd): # This will only happen if the ID server returns a malformed response logger.info("Can't add incomplete 3pid") return yield self._auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], + user_id, threepid["medium"], threepid["address"], threepid["validated_at"] ) # And we add an email pusher for them by default, but only # if email notifications are enabled (so people don't start # getting mail spam where they weren't before if email # notifs are set up on a home server) - if (self.hs.config.email_enable_notifs and - self.hs.config.email_notif_for_new_users - and token): + if ( + self.hs.config.email_enable_notifs + and self.hs.config.email_notif_for_new_users + and token + ): # Pull the ID of the access token back out of the db # It would really make more sense for this to be passed # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. - user_tuple = yield self.store.get_user_by_access_token( - token - ) + user_tuple = yield self.store.get_user_by_access_token(token) token_id = user_tuple["token_id"] yield self.pusher_pool.add_pusher( @@ -867,11 +850,9 @@ def _register_email_threepid(self, user_id, threepid, token, bind_email): if bind_email: logger.info("bind_email specified: binding") - logger.debug("Binding emails %s to %s" % ( - threepid, user_id - )) + logger.debug("Binding emails %s to %s" % (threepid, user_id)) yield self.identity_handler.bind_threepid( - threepid['threepid_creds'], user_id + threepid["threepid_creds"], user_id ) else: logger.info("bind_email not specified: not binding email") @@ -894,7 +875,7 @@ def _register_msisdn_threepid(self, user_id, threepid, bind_msisdn): defer.Deferred: """ try: - assert_params_in_dict(threepid, ['medium', 'address', 'validated_at']) + assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) except SynapseError as ex: if ex.errcode == Codes.MISSING_PARAM: # This will only happen if the ID server returns a malformed response @@ -903,17 +884,14 @@ def _register_msisdn_threepid(self, user_id, threepid, bind_msisdn): raise yield self._auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], + user_id, threepid["medium"], threepid["address"], threepid["validated_at"] ) if bind_msisdn: logger.info("bind_msisdn specified: binding") logger.debug("Binding msisdn %s to %s", threepid, user_id) yield self.identity_handler.bind_threepid( - threepid['threepid_creds'], user_id + threepid["threepid_creds"], user_id ) else: logger.info("bind_msisdn not specified: not binding msisdn") diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 74793bab335b..89d89fc27ce8 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -101,7 +101,7 @@ def upgrade_room(self, requester, old_room_id, new_version): if r is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) new_room_id = yield self._generate_room_id( - creator_id=user_id, is_public=r["is_public"], + creator_id=user_id, is_public=r["is_public"] ) logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) @@ -110,7 +110,8 @@ def upgrade_room(self, requester, old_room_id, new_version): # room, to check our user has perms in the old room. tombstone_event, tombstone_context = ( yield self.event_creation_handler.create_event( - requester, { + requester, + { "type": EventTypes.Tombstone, "state_key": "", "room_id": old_room_id, @@ -118,14 +119,14 @@ def upgrade_room(self, requester, old_room_id, new_version): "content": { "body": "This room has been replaced", "replacement_room": new_room_id, - } + }, }, token_id=requester.access_token_id, ) ) old_room_version = yield self.store.get_room_version(old_room_id) yield self.auth.check_from_context( - old_room_version, tombstone_event, tombstone_context, + old_room_version, tombstone_event, tombstone_context ) yield self.clone_existing_room( @@ -138,27 +139,27 @@ def upgrade_room(self, requester, old_room_id, new_version): # now send the tombstone yield self.event_creation_handler.send_nonmember_event( - requester, tombstone_event, tombstone_context, + requester, tombstone_event, tombstone_context ) old_room_state = yield tombstone_context.get_current_state_ids(self.store) # update any aliases yield self._move_aliases_to_new_room( - requester, old_room_id, new_room_id, old_room_state, + requester, old_room_id, new_room_id, old_room_state ) # and finally, shut down the PLs in the old room, and update them in the new # room. yield self._update_upgraded_room_pls( - requester, old_room_id, new_room_id, old_room_state, + requester, old_room_id, new_room_id, old_room_state ) defer.returnValue(new_room_id) @defer.inlineCallbacks def _update_upgraded_room_pls( - self, requester, old_room_id, new_room_id, old_room_state, + self, requester, old_room_id, new_room_id, old_room_state ): """Send updated power levels in both rooms after an upgrade @@ -176,7 +177,7 @@ def _update_upgraded_room_pls( if old_room_pl_event_id is None: logger.warning( "Not supported: upgrading a room with no PL event. Not setting PLs " - "in old room.", + "in old room." ) return @@ -197,45 +198,48 @@ def _update_upgraded_room_pls( if current < restricted_level: logger.info( "Setting level for %s in %s to %i (was %i)", - v, old_room_id, restricted_level, current, + v, + old_room_id, + restricted_level, + current, ) pl_content[v] = restricted_level updated = True else: - logger.info( - "Not setting level for %s (already %i)", - v, current, - ) + logger.info("Not setting level for %s (already %i)", v, current) if updated: try: yield self.event_creation_handler.create_and_send_nonmember_event( - requester, { + requester, + { "type": EventTypes.PowerLevels, - "state_key": '', + "state_key": "", "room_id": old_room_id, "sender": requester.user.to_string(), "content": pl_content, - }, ratelimit=False, + }, + ratelimit=False, ) except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) logger.info("Setting correct PLs in new room") yield self.event_creation_handler.create_and_send_nonmember_event( - requester, { + requester, + { "type": EventTypes.PowerLevels, - "state_key": '', + "state_key": "", "room_id": new_room_id, "sender": requester.user.to_string(), "content": old_room_pl_state.content, - }, ratelimit=False, + }, + ratelimit=False, ) @defer.inlineCallbacks def clone_existing_room( - self, requester, old_room_id, new_room_id, new_room_version, - tombstone_event_id, + self, requester, old_room_id, new_room_id, new_room_version, tombstone_event_id ): """Populate a new room based on an old room @@ -257,10 +261,7 @@ def clone_existing_room( creation_content = { "room_version": new_room_version, - "predecessor": { - "room_id": old_room_id, - "event_id": tombstone_event_id, - } + "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, } # Check if old room was non-federatable @@ -289,7 +290,7 @@ def clone_existing_room( ) old_room_state_ids = yield self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types(types_to_copy), + old_room_id, StateFilter.from_types(types_to_copy) ) # map from event_id to BaseEvent old_room_state_events = yield self.store.get_events(old_room_state_ids.values()) @@ -302,11 +303,9 @@ def clone_existing_room( yield self._send_events_for_new_room( requester, new_room_id, - # we expect to override all the presets with initial_state, so this is # somewhat arbitrary. preset_config=RoomCreationPreset.PRIVATE_CHAT, - invite_list=[], initial_state=initial_state, creation_content=creation_content, @@ -314,20 +313,22 @@ def clone_existing_room( # Transfer membership events old_room_member_state_ids = yield self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types([(EventTypes.Member, None)]), + old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) ) # map from event_id to BaseEvent old_room_member_state_events = yield self.store.get_events( - old_room_member_state_ids.values(), + old_room_member_state_ids.values() ) for k, old_event in iteritems(old_room_member_state_events): # Only transfer ban events - if ("membership" in old_event.content and - old_event.content["membership"] == "ban"): + if ( + "membership" in old_event.content + and old_event.content["membership"] == "ban" + ): yield self.room_member_handler.update_membership( requester, - UserID.from_string(old_event['state_key']), + UserID.from_string(old_event["state_key"]), new_room_id, "ban", ratelimit=False, @@ -339,7 +340,7 @@ def clone_existing_room( @defer.inlineCallbacks def _move_aliases_to_new_room( - self, requester, old_room_id, new_room_id, old_room_state, + self, requester, old_room_id, new_room_id, old_room_state ): directory_handler = self.hs.get_handlers().directory_handler @@ -370,14 +371,11 @@ def _move_aliases_to_new_room( alias = RoomAlias.from_string(alias_str) try: yield directory_handler.delete_association( - requester, alias, send_event=False, + requester, alias, send_event=False ) removed_aliases.append(alias_str) except SynapseError as e: - logger.warning( - "Unable to remove alias %s from old room: %s", - alias, e, - ) + logger.warning("Unable to remove alias %s from old room: %s", alias, e) # if we didn't find any aliases, or couldn't remove anyway, we can skip the rest # of this. @@ -393,30 +391,26 @@ def _move_aliases_to_new_room( # as when you remove an alias from the directory normally - it just means that # the aliases event gets out of sync with the directory # (cf /~https://github.com/vector-im/riot-web/issues/2369) - yield directory_handler.send_room_alias_update_event( - requester, old_room_id, - ) + yield directory_handler.send_room_alias_update_event(requester, old_room_id) except AuthError as e: - logger.warning( - "Failed to send updated alias event on old room: %s", e, - ) + logger.warning("Failed to send updated alias event on old room: %s", e) # we can now add any aliases we successfully removed to the new room. for alias in removed_aliases: try: yield directory_handler.create_association( - requester, RoomAlias.from_string(alias), - new_room_id, servers=(self.hs.hostname, ), - send_event=False, check_membership=False, + requester, + RoomAlias.from_string(alias), + new_room_id, + servers=(self.hs.hostname,), + send_event=False, + check_membership=False, ) logger.info("Moved alias %s to new room", alias) except SynapseError as e: # I'm not really expecting this to happen, but it could if the spam # checking module decides it shouldn't, or similar. - logger.error( - "Error adding alias %s to new room: %s", - alias, e, - ) + logger.error("Error adding alias %s to new room: %s", alias, e) try: if canonical_alias and (canonical_alias in removed_aliases): @@ -427,24 +421,19 @@ def _move_aliases_to_new_room( "state_key": "", "room_id": new_room_id, "sender": requester.user.to_string(), - "content": {"alias": canonical_alias, }, + "content": {"alias": canonical_alias}, }, - ratelimit=False + ratelimit=False, ) - yield directory_handler.send_room_alias_update_event( - requester, new_room_id, - ) + yield directory_handler.send_room_alias_update_event(requester, new_room_id) except SynapseError as e: # again I'm not really expecting this to fail, but if it does, I'd rather # we returned the new room to the client at this point. - logger.error( - "Unable to send updated alias events in new room: %s", e, - ) + logger.error("Unable to send updated alias events in new room: %s", e) @defer.inlineCallbacks - def create_room(self, requester, config, ratelimit=True, - creator_join_profile=None): + def create_room(self, requester, config, ratelimit=True, creator_join_profile=None): """ Creates a new room. Args: @@ -474,25 +463,23 @@ def create_room(self, requester, config, ratelimit=True, yield self.auth.check_auth_blocking(user_id) - if (self._server_notices_mxid is not None and - requester.user.to_string() == self._server_notices_mxid): + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): # allow the server notices mxid to create rooms is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) # Check whether the third party rules allows/changes the room create # request. yield self.third_party_event_rules.on_create_room( - requester, - config, - is_requester_admin=is_requester_admin, + requester, config, is_requester_admin=is_requester_admin ) if not is_requester_admin and not self.spam_checker.user_may_create_room( - user_id, + user_id ): raise SynapseError(403, "You are not permitted to create rooms") @@ -500,16 +487,11 @@ def create_room(self, requester, config, ratelimit=True, yield self.ratelimit(requester) room_version = config.get( - "room_version", - self.config.default_room_version.identifier, + "room_version", self.config.default_room_version.identifier ) if not isinstance(room_version, string_types): - raise SynapseError( - 400, - "room_version must be a string", - Codes.BAD_JSON, - ) + raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON) if room_version not in KNOWN_ROOM_VERSIONS: raise SynapseError( @@ -523,20 +505,11 @@ def create_room(self, requester, config, ratelimit=True, if wchar in config["room_alias_name"]: raise SynapseError(400, "Invalid characters in room alias") - room_alias = RoomAlias( - config["room_alias_name"], - self.hs.hostname, - ) - mapping = yield self.store.get_association_from_room_alias( - room_alias - ) + room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname) + mapping = yield self.store.get_association_from_room_alias(room_alias) if mapping: - raise SynapseError( - 400, - "Room alias already taken", - Codes.ROOM_IN_USE - ) + raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) else: room_alias = None @@ -547,9 +520,7 @@ def create_room(self, requester, config, ratelimit=True, except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) - yield self.event_creation_handler.assert_accepted_privacy_policy( - requester, - ) + yield self.event_creation_handler.assert_accepted_privacy_policy(requester) invite_3pid_list = config.get("invite_3pid", []) @@ -573,7 +544,7 @@ def create_room(self, requester, config, ratelimit=True, "preset", RoomCreationPreset.PRIVATE_CHAT if visibility == "private" - else RoomCreationPreset.PUBLIC_CHAT + else RoomCreationPreset.PUBLIC_CHAT, ) raw_initial_state = config.get("initial_state", []) @@ -610,7 +581,8 @@ def create_room(self, requester, config, ratelimit=True, "state_key": "", "content": {"name": name}, }, - ratelimit=False) + ratelimit=False, + ) if "topic" in config: topic = config["topic"] @@ -623,7 +595,8 @@ def create_room(self, requester, config, ratelimit=True, "state_key": "", "content": {"topic": topic}, }, - ratelimit=False) + ratelimit=False, + ) for invitee in invite_list: content = {} @@ -658,30 +631,25 @@ def create_room(self, requester, config, ratelimit=True, if room_alias: result["room_alias"] = room_alias.to_string() - yield directory_handler.send_room_alias_update_event( - requester, room_id - ) + yield directory_handler.send_room_alias_update_event(requester, room_id) defer.returnValue(result) @defer.inlineCallbacks def _send_events_for_new_room( - self, - creator, # A Requester object. - room_id, - preset_config, - invite_list, - initial_state, - creation_content, - room_alias=None, - power_level_content_override=None, - creator_join_profile=None, + self, + creator, # A Requester object. + room_id, + preset_config, + invite_list, + initial_state, + creation_content, + room_alias=None, + power_level_content_override=None, + creator_join_profile=None, ): def create(etype, content, **kwargs): - e = { - "type": etype, - "content": content, - } + e = {"type": etype, "content": content} e.update(event_keys) e.update(kwargs) @@ -693,26 +661,17 @@ def send(etype, content, **kwargs): event = create(etype, content, **kwargs) logger.info("Sending %s in new room", etype) yield self.event_creation_handler.create_and_send_nonmember_event( - creator, - event, - ratelimit=False + creator, event, ratelimit=False ) config = RoomCreationHandler.PRESETS_DICT[preset_config] creator_id = creator.user.to_string() - event_keys = { - "room_id": room_id, - "sender": creator_id, - "state_key": "", - } + event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} creation_content.update({"creator": creator_id}) - yield send( - etype=EventTypes.Create, - content=creation_content, - ) + yield send(etype=EventTypes.Create, content=creation_content) logger.info("Sending %s in new room", EventTypes.Member) yield self.room_member_handler.update_membership( @@ -726,17 +685,12 @@ def send(etype, content, **kwargs): # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. - pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None) + pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - yield send( - etype=EventTypes.PowerLevels, - content=pl_content, - ) + yield send(etype=EventTypes.PowerLevels, content=pl_content) else: power_level_content = { - "users": { - creator_id: 100, - }, + "users": {creator_id: 100}, "users_default": 0, "events": { EventTypes.Name: 50, @@ -760,42 +714,33 @@ def send(etype, content, **kwargs): if power_level_content_override: power_level_content.update(power_level_content_override) - yield send( - etype=EventTypes.PowerLevels, - content=power_level_content, - ) + yield send(etype=EventTypes.PowerLevels, content=power_level_content) - if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state: + if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: yield send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) - if (EventTypes.JoinRules, '') not in initial_state: + if (EventTypes.JoinRules, "") not in initial_state: yield send( - etype=EventTypes.JoinRules, - content={"join_rule": config["join_rules"]}, + etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) - if (EventTypes.RoomHistoryVisibility, '') not in initial_state: + if (EventTypes.RoomHistoryVisibility, "") not in initial_state: yield send( etype=EventTypes.RoomHistoryVisibility, - content={"history_visibility": config["history_visibility"]} + content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: - if (EventTypes.GuestAccess, '') not in initial_state: + if (EventTypes.GuestAccess, "") not in initial_state: yield send( - etype=EventTypes.GuestAccess, - content={"guest_access": "can_join"} + etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - yield send( - etype=etype, - state_key=state_key, - content=content, - ) + yield send(etype=etype, state_key=state_key, content=content) @defer.inlineCallbacks def _generate_room_id(self, creator_id, is_public): @@ -805,12 +750,9 @@ def _generate_room_id(self, creator_id, is_public): while attempts < 5: try: random_string = stringutils.random_string(18) - gen_room_id = RoomID( - random_string, - self.hs.hostname, - ).to_string() + gen_room_id = RoomID(random_string, self.hs.hostname).to_string() if isinstance(gen_room_id, bytes): - gen_room_id = gen_room_id.decode('utf-8') + gen_room_id = gen_room_id.decode("utf-8") yield self.store.store_room( room_id=gen_room_id, room_creator_user_id=creator_id, @@ -844,7 +786,7 @@ def get_event_context(self, user, room_id, event_id, limit, event_filter): Returns: dict, or None if the event isn't found """ - before_limit = math.floor(limit / 2.) + before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit users = yield self.store.get_users_in_room(room_id) @@ -852,24 +794,19 @@ def get_event_context(self, user, room_id, event_id, limit, event_filter): def filter_evts(events): return filter_events_for_client( - self.store, - user.to_string(), - events, - is_peeking=is_peeking + self.store, user.to_string(), events, is_peeking=is_peeking ) - event = yield self.store.get_event(event_id, get_prev_content=True, - allow_none=True) + event = yield self.store.get_event( + event_id, get_prev_content=True, allow_none=True + ) if not event: defer.returnValue(None) return - filtered = yield(filter_evts([event])) + filtered = yield (filter_evts([event])) if not filtered: - raise AuthError( - 403, - "You don't have permission to access that event." - ) + raise AuthError(403, "You don't have permission to access that event.") results = yield self.store.get_events_around( room_id, event_id, before_limit, after_limit, event_filter @@ -901,7 +838,7 @@ def filter_evts(events): # /~https://github.com/matrix-org/matrix-doc/issues/687 state = yield self.store.get_state_for_events( - [last_event_id], state_filter=state_filter, + [last_event_id], state_filter=state_filter ) results["state"] = list(state[last_event_id].values()) @@ -913,9 +850,7 @@ def filter_evts(events): "room_key", results["start"] ).to_string() - results["end"] = token.copy_and_replace( - "room_key", results["end"] - ).to_string() + results["end"] = token.copy_and_replace("room_key", results["end"]).to_string() defer.returnValue(results) @@ -926,13 +861,7 @@ def __init__(self, hs): @defer.inlineCallbacks def get_new_events( - self, - user, - from_key, - limit, - room_ids, - is_guest, - explicit_room_id=None, + self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None ): # We just ignore the key for now. @@ -943,9 +872,7 @@ def get_new_events( logger.warn("Stream has topological part!!!! %r", from_key) from_key = "s%s" % (from_token.stream,) - app_service = self.store.get_app_service_by_user_id( - user.to_string() - ) + app_service = self.store.get_app_service_by_user_id(user.to_string()) if app_service: # We no longer support AS users using /sync directly. # See /~https://github.com/matrix-org/matrix-doc/issues/1144 @@ -960,7 +887,7 @@ def get_new_events( from_key=from_key, to_key=to_key, limit=limit or 10, - order='ASC', + order="ASC", ) events = list(room_events) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 617d1c9ef811..aae696a7e8ce 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -46,13 +46,18 @@ def __init__(self, hs): super(RoomListHandler, self).__init__(hs) self.enable_room_list_search = hs.config.enable_room_list_search self.response_cache = ResponseCache(hs, "room_list") - self.remote_response_cache = ResponseCache(hs, "remote_room_list", - timeout_ms=30 * 1000) + self.remote_response_cache = ResponseCache( + hs, "remote_room_list", timeout_ms=30 * 1000 + ) - def get_local_public_room_list(self, limit=None, since_token=None, - search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID, - from_federation=False): + def get_local_public_room_list( + self, + limit=None, + since_token=None, + search_filter=None, + network_tuple=EMPTY_THIRD_PARTY_ID, + from_federation=False, + ): """Generate a local public room list. There are multiple different lists: the main one plus one per third @@ -68,14 +73,14 @@ def get_local_public_room_list(self, limit=None, since_token=None, Setting to None returns all public rooms across all lists. """ if not self.enable_room_list_search: - return defer.succeed({ - "chunk": [], - "total_room_count_estimate": 0, - }) + return defer.succeed({"chunk": [], "total_room_count_estimate": 0}) logger.info( "Getting public room list: limit=%r, since=%r, search=%r, network=%r", - limit, since_token, bool(search_filter), network_tuple, + limit, + since_token, + bool(search_filter), + network_tuple, ) if search_filter: @@ -88,24 +93,33 @@ def get_local_public_room_list(self, limit=None, since_token=None, # solution at some point timeout = self.clock.time() + 60 return self._get_public_room_list( - limit, since_token, search_filter, - network_tuple=network_tuple, timeout=timeout, + limit, + since_token, + search_filter, + network_tuple=network_tuple, + timeout=timeout, ) key = (limit, since_token, network_tuple) return self.response_cache.wrap( key, self._get_public_room_list, - limit, since_token, - network_tuple=network_tuple, from_federation=from_federation, + limit, + since_token, + network_tuple=network_tuple, + from_federation=from_federation, ) @defer.inlineCallbacks - def _get_public_room_list(self, limit=None, since_token=None, - search_filter=None, - network_tuple=EMPTY_THIRD_PARTY_ID, - from_federation=False, - timeout=None,): + def _get_public_room_list( + self, + limit=None, + since_token=None, + search_filter=None, + network_tuple=EMPTY_THIRD_PARTY_ID, + from_federation=False, + timeout=None, + ): """Generate a public room list. Args: limit (int|None): Maximum amount of rooms to return. @@ -135,15 +149,14 @@ def _get_public_room_list(self, limit=None, since_token=None, current_public_id = yield self.store.get_current_public_room_stream_id() public_room_stream_id = since_token.public_room_stream_id newly_visible, newly_unpublished = yield self.store.get_public_room_changes( - public_room_stream_id, current_public_id, - network_tuple=network_tuple, + public_room_stream_id, current_public_id, network_tuple=network_tuple ) else: stream_token = yield self.store.get_room_max_stream_ordering() public_room_stream_id = yield self.store.get_current_public_room_stream_id() room_ids = yield self.store.get_public_room_ids_at_stream_id( - public_room_stream_id, network_tuple=network_tuple, + public_room_stream_id, network_tuple=network_tuple ) # We want to return rooms in a particular order: the number of joined @@ -168,7 +181,7 @@ def get_order_for_room(room_id): return joined_users = yield self.state_handler.get_current_users_in_room( - room_id, latest_event_ids, + room_id, latest_event_ids ) num_joined_users = len(joined_users) @@ -180,8 +193,9 @@ def get_order_for_room(room_id): # We want larger rooms to be first, hence negating num_joined_users rooms_to_order_value[room_id] = (-num_joined_users, room_id) - logger.info("Getting ordering for %i rooms since %s", - len(room_ids), stream_token) + logger.info( + "Getting ordering for %i rooms since %s", len(room_ids), stream_token + ) yield concurrently_execute(get_order_for_room, room_ids, 10) sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) @@ -193,7 +207,8 @@ def get_order_for_room(room_id): # Filter out rooms that we don't want to return rooms_to_scan = [ - r for r in sorted_rooms + r + for r in sorted_rooms if r not in newly_unpublished and rooms_to_num_joined[r] > 0 ] @@ -204,13 +219,12 @@ def get_order_for_room(room_id): # `since_token.current_limit` is the index of the last room we # sent down, so we exclude it and everything before/after it. if since_token.direction_is_forward: - rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:] + rooms_to_scan = rooms_to_scan[since_token.current_limit + 1 :] else: - rooms_to_scan = rooms_to_scan[:since_token.current_limit] + rooms_to_scan = rooms_to_scan[: since_token.current_limit] rooms_to_scan.reverse() - logger.info("After sorting and filtering, %i rooms remain", - len(rooms_to_scan)) + logger.info("After sorting and filtering, %i rooms remain", len(rooms_to_scan)) # _append_room_entry_to_chunk will append to chunk but will stop if # len(chunk) > limit @@ -237,15 +251,19 @@ def get_order_for_room(room_id): if timeout and self.clock.time() > timeout: raise Exception("Timed out searching room directory") - batch = rooms_to_scan[i:i + step] + batch = rooms_to_scan[i : i + step] logger.info("Processing %i rooms for result", len(batch)) yield concurrently_execute( lambda r: self._append_room_entry_to_chunk( - r, rooms_to_num_joined[r], - chunk, limit, search_filter, + r, + rooms_to_num_joined[r], + chunk, + limit, + search_filter, from_federation=from_federation, ), - batch, 5, + batch, + 5, ) logger.info("Now %i rooms in result", len(chunk)) if len(chunk) >= limit + 1: @@ -273,10 +291,7 @@ def get_order_for_room(room_id): new_limit = sorted_rooms.index(last_room_id) - results = { - "chunk": chunk, - "total_room_count_estimate": total_room_count, - } + results = {"chunk": chunk, "total_room_count_estimate": total_room_count} if since_token: results["new_rooms"] = bool(newly_visible) @@ -313,8 +328,15 @@ def get_order_for_room(room_id): defer.returnValue(results) @defer.inlineCallbacks - def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit, - search_filter, from_federation=False): + def _append_room_entry_to_chunk( + self, + room_id, + num_joined_users, + chunk, + limit, + search_filter, + from_federation=False, + ): """Generate the entry for a room in the public room list and append it to the `chunk` if it matches the search filter @@ -345,8 +367,14 @@ def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit, chunk.append(result) @cachedInlineCallbacks(num_args=1, cache_context=True) - def generate_room_entry(self, room_id, num_joined_users, cache_context, - with_alias=True, allow_private=False): + def generate_room_entry( + self, + room_id, + num_joined_users, + cache_context, + with_alias=True, + allow_private=False, + ): """Returns the entry for a room Args: @@ -360,33 +388,31 @@ def generate_room_entry(self, room_id, num_joined_users, cache_context, Deferred[dict|None]: Returns a room entry as a dictionary, or None if this room was determined not to be shown publicly. """ - result = { - "room_id": room_id, - "num_joined_members": num_joined_users, - } + result = {"room_id": room_id, "num_joined_members": num_joined_users} current_state_ids = yield self.store.get_current_state_ids( - room_id, on_invalidate=cache_context.invalidate, + room_id, on_invalidate=cache_context.invalidate ) - event_map = yield self.store.get_events([ - event_id for key, event_id in iteritems(current_state_ids) - if key[0] in ( - EventTypes.Create, - EventTypes.JoinRules, - EventTypes.Name, - EventTypes.Topic, - EventTypes.CanonicalAlias, - EventTypes.RoomHistoryVisibility, - EventTypes.GuestAccess, - "m.room.avatar", - ) - ]) + event_map = yield self.store.get_events( + [ + event_id + for key, event_id in iteritems(current_state_ids) + if key[0] + in ( + EventTypes.Create, + EventTypes.JoinRules, + EventTypes.Name, + EventTypes.Topic, + EventTypes.CanonicalAlias, + EventTypes.RoomHistoryVisibility, + EventTypes.GuestAccess, + "m.room.avatar", + ) + ] + ) - current_state = { - (ev.type, ev.state_key): ev - for ev in event_map.values() - } + current_state = {(ev.type, ev.state_key): ev for ev in event_map.values()} # Double check that this is actually a public room. @@ -446,14 +472,17 @@ def generate_room_entry(self, room_id, num_joined_users, cache_context, defer.returnValue(result) @defer.inlineCallbacks - def get_remote_public_room_list(self, server_name, limit=None, since_token=None, - search_filter=None, include_all_networks=False, - third_party_instance_id=None,): + def get_remote_public_room_list( + self, + server_name, + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): if not self.enable_room_list_search: - defer.returnValue({ - "chunk": [], - "total_room_count_estimate": 0, - }) + defer.returnValue({"chunk": [], "total_room_count_estimate": 0}) if search_filter: # We currently don't support searching across federation, so we have @@ -462,52 +491,75 @@ def get_remote_public_room_list(self, server_name, limit=None, since_token=None, since_token = None res = yield self._get_remote_list_cached( - server_name, limit=limit, since_token=since_token, + server_name, + limit=limit, + since_token=since_token, include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) if search_filter: - res = {"chunk": [ - entry - for entry in list(res.get("chunk", [])) - if _matches_room_entry(entry, search_filter) - ]} + res = { + "chunk": [ + entry + for entry in list(res.get("chunk", [])) + if _matches_room_entry(entry, search_filter) + ] + } defer.returnValue(res) - def _get_remote_list_cached(self, server_name, limit=None, since_token=None, - search_filter=None, include_all_networks=False, - third_party_instance_id=None,): + def _get_remote_list_cached( + self, + server_name, + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search return repl_layer.get_public_rooms( - server_name, limit=limit, since_token=since_token, - search_filter=search_filter, include_all_networks=include_all_networks, + server_name, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) key = ( - server_name, limit, since_token, include_all_networks, + server_name, + limit, + since_token, + include_all_networks, third_party_instance_id, ) return self.remote_response_cache.wrap( key, repl_layer.get_public_rooms, - server_name, limit=limit, since_token=since_token, + server_name, + limit=limit, + since_token=since_token, search_filter=search_filter, include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) -class RoomListNextBatch(namedtuple("RoomListNextBatch", ( - "stream_ordering", # stream_ordering of the first public room list - "public_room_stream_id", # public room stream id for first public room list - "current_limit", # The number of previous rooms returned - "direction_is_forward", # Bool if this is a next_batch, false if prev_batch -))): +class RoomListNextBatch( + namedtuple( + "RoomListNextBatch", + ( + "stream_ordering", # stream_ordering of the first public room list + "public_room_stream_id", # public room stream id for first public room list + "current_limit", # The number of previous rooms returned + "direction_is_forward", # Bool if this is a next_batch, false if prev_batch + ), + ) +): KEY_DICT = { "stream_ordering": "s", @@ -527,21 +579,19 @@ def from_token(cls, token): decoded = msgpack.loads(decode_base64(token), raw=False) else: decoded = msgpack.loads(decode_base64(token)) - return RoomListNextBatch(**{ - cls.REVERSE_KEY_DICT[key]: val - for key, val in decoded.items() - }) + return RoomListNextBatch( + **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()} + ) def to_token(self): - return encode_base64(msgpack.dumps({ - self.KEY_DICT[key]: val - for key, val in self._asdict().items() - })) + return encode_base64( + msgpack.dumps( + {self.KEY_DICT[key]: val for key, val in self._asdict().items()} + ) + ) def copy_and_replace(self, **kwds): - return self._replace( - **kwds - ) + return self._replace(**kwds) def _matches_room_entry(room_entry, search_filter): diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 458902bb7e15..4d6e8838025c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -166,7 +166,11 @@ def _user_left_room(self, target, room_id): @defer.inlineCallbacks def _local_membership_update( - self, requester, target, room_id, membership, + self, + requester, + target, + room_id, + membership, prev_events_and_hashes, txn_id=None, ratelimit=True, @@ -190,7 +194,6 @@ def _local_membership_update( "room_id": room_id, "sender": requester.user.to_string(), "state_key": user_id, - # For backwards compatibility: "membership": membership, }, @@ -202,26 +205,19 @@ def _local_membership_update( # Check if this event matches the previous membership event for the user. duplicate = yield self.event_creation_handler.deduplicate_state_event( - event, context, + event, context ) if duplicate is not None: # Discard the new event since this membership change is a no-op. defer.returnValue(duplicate) yield self.event_creation_handler.handle_new_client_event( - requester, - event, - context, - extra_users=[target], - ratelimit=ratelimit, + requester, event, context, extra_users=[target], ratelimit=ratelimit ) prev_state_ids = yield context.get_prev_state_ids(self.store) - prev_member_event_id = prev_state_ids.get( - (EventTypes.Member, user_id), - None - ) + prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) if event.membership == Membership.JOIN: # Only fire user_joined_room if the user has actually joined the @@ -243,11 +239,11 @@ def _local_membership_update( if predecessor: # It is an upgraded room. Copy over old tags self.copy_room_tags_and_direct_to_room( - predecessor["room_id"], room_id, user_id, + predecessor["room_id"], room_id, user_id ) # Move over old push rules self.store.move_push_rules_from_room_to_room_for_user( - predecessor["room_id"], room_id, user_id, + predecessor["room_id"], room_id, user_id ) elif event.membership == Membership.LEAVE: if prev_member_event_id: @@ -258,12 +254,7 @@ def _local_membership_update( defer.returnValue(event) @defer.inlineCallbacks - def copy_room_tags_and_direct_to_room( - self, - old_room_id, - new_room_id, - user_id, - ): + def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id): """Copies the tags and direct room state from one room to another. Args: @@ -275,9 +266,7 @@ def copy_room_tags_and_direct_to_room( Deferred[None] """ # Retrieve user account data for predecessor room - user_account_data, _ = yield self.store.get_account_data_for_user( - user_id, - ) + user_account_data, _ = yield self.store.get_account_data_for_user(user_id) # Copy direct message state if applicable direct_rooms = user_account_data.get("m.direct", {}) @@ -291,34 +280,30 @@ def copy_room_tags_and_direct_to_room( # Save back to user's m.direct account data yield self.store.add_account_data_for_user( - user_id, "m.direct", direct_rooms, + user_id, "m.direct", direct_rooms ) break # Copy room tags if applicable - room_tags = yield self.store.get_tags_for_room( - user_id, old_room_id, - ) + room_tags = yield self.store.get_tags_for_room(user_id, old_room_id) # Copy each room tag to the new room for tag, tag_content in room_tags.items(): - yield self.store.add_tag_to_room( - user_id, new_room_id, tag, tag_content - ) + yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) @defer.inlineCallbacks def update_membership( - self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + content=None, + require_consent=True, ): key = (room_id,) @@ -340,17 +325,17 @@ def update_membership( @defer.inlineCallbacks def _update_membership( - self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + content=None, + require_consent=True, ): content_specified = bool(content) if content is None: @@ -384,7 +369,7 @@ def _update_membership( if not remote_room_hosts: remote_room_hosts = [] - if effective_membership_state not in ("leave", "ban",): + if effective_membership_state not in ("leave", "ban"): is_blocked = yield self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") @@ -392,22 +377,19 @@ def _update_membership( if effective_membership_state == Membership.INVITE: # block any attempts to invite the server notices mxid if target.to_string() == self._server_notices_mxid: - raise SynapseError( - http_client.FORBIDDEN, - "Cannot invite this user", - ) + raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") block_invite = False - if (self._server_notices_mxid is not None and - requester.user.to_string() == self._server_notices_mxid): + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): # allow the server notices mxid to send invites is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) if not is_requester_admin: if self.config.block_non_admin_invites: @@ -418,25 +400,19 @@ def _update_membership( block_invite = True if not self.spam_checker.user_may_invite( - requester.user.to_string(), target.to_string(), room_id, + requester.user.to_string(), target.to_string(), room_id ): logger.info("Blocking invite due to spam checker") block_invite = True if block_invite: - raise SynapseError( - 403, "Invites have been disabled on this server", - ) + raise SynapseError(403, "Invites have been disabled on this server") - prev_events_and_hashes = yield self.store.get_prev_events_for_room( - room_id, - ) - latest_event_ids = ( - event_id for (event_id, _, _) in prev_events_and_hashes - ) + prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) + latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) current_state_ids = yield self.state_handler.get_current_state_ids( - room_id, latest_event_ids=latest_event_ids, + room_id, latest_event_ids=latest_event_ids ) # TODO: Refactor into dictionary of explicitly allowed transitions @@ -451,13 +427,13 @@ def _update_membership( 403, "Cannot unban user who was not banned" " (membership=%s)" % old_membership, - errcode=Codes.BAD_STATE + errcode=Codes.BAD_STATE, ) if old_membership == "ban" and action != "unban": raise SynapseError( 403, "Cannot %s user who was banned" % (action,), - errcode=Codes.BAD_STATE + errcode=Codes.BAD_STATE, ) if old_state: @@ -473,8 +449,8 @@ def _update_membership( # we don't allow people to reject invites to the server notice # room, but they can leave it once they are joined. if ( - old_membership == Membership.INVITE and - effective_membership_state == Membership.LEAVE + old_membership == Membership.INVITE + and effective_membership_state == Membership.LEAVE ): is_blocked = yield self._is_server_notice_room(room_id) if is_blocked: @@ -535,7 +511,7 @@ def _update_membership( # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] res = yield self._remote_reject_invite( - requester, remote_room_hosts, room_id, target, + requester, remote_room_hosts, room_id, target ) defer.returnValue(res) @@ -554,12 +530,7 @@ def _update_membership( @defer.inlineCallbacks def send_membership_event( - self, - requester, - event, - context, - remote_room_hosts=None, - ratelimit=True, + self, requester, event, context, remote_room_hosts=None, ratelimit=True ): """ Change the membership status of a user in a room. @@ -585,16 +556,15 @@ def send_membership_event( if requester is not None: sender = UserID.from_string(event.sender) - assert sender == requester.user, ( - "Sender (%s) must be same as requester (%s)" % - (sender, requester.user) - ) + assert ( + sender == requester.user + ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) else: requester = synapse.types.create_requester(target_user) prev_event = yield self.event_creation_handler.deduplicate_state_event( - event, context, + event, context ) if prev_event is not None: return @@ -614,16 +584,11 @@ def send_membership_event( raise SynapseError(403, "This room has been blocked on this server") yield self.event_creation_handler.handle_new_client_event( - requester, - event, - context, - extra_users=[target_user], - ratelimit=ratelimit, + requester, event, context, extra_users=[target_user], ratelimit=ratelimit ) prev_member_event_id = prev_state_ids.get( - (EventTypes.Member, event.state_key), - None + (EventTypes.Member, event.state_key), None ) if event.membership == Membership.JOIN: @@ -693,31 +658,20 @@ def lookup_room_alias(self, room_alias): @defer.inlineCallbacks def _get_inviter(self, user_id, room_id): invite = yield self.store.get_invite_for_user_in_room( - user_id=user_id, - room_id=room_id, + user_id=user_id, room_id=room_id ) if invite: defer.returnValue(UserID.from_string(invite.sender)) @defer.inlineCallbacks def do_3pid_invite( - self, - room_id, - inviter, - medium, - address, - id_server, - requester, - txn_id + self, room_id, inviter, medium, address, id_server, requester, txn_id ): if self.config.block_non_admin_invites: - is_requester_admin = yield self.auth.is_server_admin( - requester.user, - ) + is_requester_admin = yield self.auth.is_server_admin(requester.user) if not is_requester_admin: raise SynapseError( - 403, "Invites have been disabled on this server", - Codes.FORBIDDEN, + 403, "Invites have been disabled on this server", Codes.FORBIDDEN ) # We need to rate limit *before* we send out any 3PID invites, so we @@ -725,35 +679,24 @@ def do_3pid_invite( yield self.base_handler.ratelimit(requester) can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited( - medium, address, room_id, + medium, address, room_id ) if not can_invite: raise SynapseError( - 403, "This third-party identifier can not be invited in this room", + 403, + "This third-party identifier can not be invited in this room", Codes.FORBIDDEN, ) - invitee = yield self._lookup_3pid( - id_server, medium, address - ) + invitee = yield self._lookup_3pid(id_server, medium, address) if invitee: yield self.update_membership( - requester, - UserID.from_string(invitee), - room_id, - "invite", - txn_id=txn_id, + requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: yield self._make_and_store_3pid_invite( - requester, - id_server, - medium, - address, - room_id, - inviter, - txn_id=txn_id + requester, id_server, medium, address, room_id, inviter, txn_id=txn_id ) @defer.inlineCallbacks @@ -771,15 +714,12 @@ def _lookup_3pid(self, id_server, medium, address): """ if not self._enable_lookup: raise SynapseError( - 403, "Looking up third-party identifiers is denied from this server", + 403, "Looking up third-party identifiers is denied from this server" ) try: data = yield self.simple_http_client.get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), - { - "medium": medium, - "address": address, - } + "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server), + {"medium": medium, "address": address}, ) if "mxid" in data: @@ -798,29 +738,25 @@ def _verify_any_signature(self, data, server_hostname): raise AuthError(401, "No signature from server %s" % (server_hostname,)) for key_name, signature in data["signatures"][server_hostname].items(): key_data = yield self.simple_http_client.get_json( - "%s%s/_matrix/identity/api/v1/pubkey/%s" % - (id_server_scheme, server_hostname, key_name,), + "%s%s/_matrix/identity/api/v1/pubkey/%s" + % (id_server_scheme, server_hostname, key_name) ) if "public_key" not in key_data: - raise AuthError(401, "No public key named %s from %s" % - (key_name, server_hostname,)) + raise AuthError( + 401, "No public key named %s from %s" % (key_name, server_hostname) + ) verify_signed_json( data, server_hostname, - decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"])) + decode_verify_key_bytes( + key_name, decode_base64(key_data["public_key"]) + ), ) return @defer.inlineCallbacks def _make_and_store_3pid_invite( - self, - requester, - id_server, - medium, - address, - room_id, - user, - txn_id + self, requester, id_server, medium, address, room_id, user, txn_id ): room_state = yield self.state_handler.get_current_state(room_id) @@ -868,7 +804,7 @@ def _make_and_store_3pid_invite( room_join_rules=room_join_rules, room_name=room_name, inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url + inviter_avatar_url=inviter_avatar_url, ) ) @@ -879,7 +815,6 @@ def _make_and_store_3pid_invite( "content": { "display_name": display_name, "public_keys": public_keys, - # For backwards compatibility: "key_validity_url": fallback_public_key["key_validity_url"], "public_key": fallback_public_key["public_key"], @@ -893,19 +828,19 @@ def _make_and_store_3pid_invite( @defer.inlineCallbacks def _ask_id_server_for_third_party_invite( - self, - requester, - id_server, - medium, - address, - room_id, - inviter_user_id, - room_alias, - room_avatar_url, - room_join_rules, - room_name, - inviter_display_name, - inviter_avatar_url + self, + requester, + id_server, + medium, + address, + room_id, + inviter_user_id, + room_alias, + room_avatar_url, + room_join_rules, + room_name, + inviter_display_name, + inviter_avatar_url, ): """ Asks an identity server for a third party invite. @@ -937,7 +872,8 @@ def _ask_id_server_for_third_party_invite( """ is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( - id_server_scheme, id_server, + id_server_scheme, + id_server, ) invite_config = { @@ -961,14 +897,15 @@ def _ask_id_server_for_third_party_invite( inviter_user_id=inviter_user_id, ) - invite_config.update({ - "guest_access_token": guest_access_token, - "guest_user_id": guest_user_id, - }) + invite_config.update( + { + "guest_access_token": guest_access_token, + "guest_user_id": guest_user_id, + } + ) data = yield self.simple_http_client.post_urlencoded_get_json( - is_url, - invite_config + is_url, invite_config ) # TODO: Check for success token = data["token"] @@ -976,9 +913,8 @@ def _ask_id_server_for_third_party_invite( if "public_key" in data: fallback_public_key = { "public_key": data["public_key"], - "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, id_server, - ), + "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" + % (id_server_scheme, id_server), } else: fallback_public_key = public_keys[0] @@ -1047,10 +983,7 @@ def _remote_join(self, requester, remote_room_hosts, room_id, user, content): # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. yield self.federation_handler.do_invite_join( - remote_room_hosts, - room_id, - user.to_string(), - content, + remote_room_hosts, room_id, user.to_string(), content ) yield self._user_joined_room(user, room_id) @@ -1061,9 +994,7 @@ def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): fed_handler = self.federation_handler try: ret = yield fed_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - target.to_string(), + remote_room_hosts, room_id, target.to_string() ) defer.returnValue(ret) except Exception as e: @@ -1075,9 +1006,7 @@ def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): # logger.warn("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite( - target.to_string(), room_id - ) + yield self.store.locally_reject_invite(target.to_string(), room_id) defer.returnValue({}) def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): @@ -1101,18 +1030,15 @@ def forget(self, user, room_id): user_id = user.to_string() member = yield self.state_handler.get_current_state( - room_id=room_id, - event_type=EventTypes.Member, - state_key=user_id + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) membership = member.membership if member else None if membership is not None and membership not in [ - Membership.LEAVE, Membership.BAN + Membership.LEAVE, + Membership.BAN, ]: - raise SynapseError(400, "User %s in room %s" % ( - user_id, room_id - )) + raise SynapseError(400, "User %s in room %s" % (user_id, room_id)) if membership: yield self.store.forget(user_id, room_id) diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index acc6eb8099fb..da501f38c04e 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -71,18 +71,14 @@ def _user_joined_room(self, target, room_id): """Implements RoomMemberHandler._user_joined_room """ return self._notify_change_client( - user_id=target.to_string(), - room_id=room_id, - change="joined", + user_id=target.to_string(), room_id=room_id, change="joined" ) def _user_left_room(self, target, room_id): """Implements RoomMemberHandler._user_left_room """ return self._notify_change_client( - user_id=target.to_string(), - room_id=room_id, - change="left", + user_id=target.to_string(), room_id=room_id, change="left" ) def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id): diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 9bba74d6c91c..ddc4430d03a1 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -32,7 +32,6 @@ class SearchHandler(BaseHandler): - def __init__(self, hs): super(SearchHandler, self).__init__(hs) self._event_serializer = hs.get_event_client_serializer() @@ -93,7 +92,7 @@ def search(self, user, content, batch=None): batch_token = None if batch: try: - b = decode_base64(batch).decode('ascii') + b = decode_base64(batch).decode("ascii") batch_group, batch_group_key, batch_token = b.split("\n") assert batch_group is not None @@ -104,7 +103,9 @@ def search(self, user, content, batch=None): logger.info( "Search batch properties: %r, %r, %r", - batch_group, batch_group_key, batch_token, + batch_group, + batch_group_key, + batch_token, ) logger.info("Search content: %s", content) @@ -116,9 +117,9 @@ def search(self, user, content, batch=None): search_term = room_cat["search_term"] # Which "keys" to search over in FTS query - keys = room_cat.get("keys", [ - "content.body", "content.name", "content.topic", - ]) + keys = room_cat.get( + "keys", ["content.body", "content.name", "content.topic"] + ) # Filter to apply to results filter_dict = room_cat.get("filter", {}) @@ -130,9 +131,7 @@ def search(self, user, content, batch=None): include_state = room_cat.get("include_state", False) # Include context around each event? - event_context = room_cat.get( - "event_context", None - ) + event_context = room_cat.get("event_context", None) # Group results together? May allow clients to paginate within a # group @@ -140,12 +139,8 @@ def search(self, user, content, batch=None): group_keys = [g["key"] for g in group_by] if event_context is not None: - before_limit = int(event_context.get( - "before_limit", 5 - )) - after_limit = int(event_context.get( - "after_limit", 5 - )) + before_limit = int(event_context.get("before_limit", 5)) + after_limit = int(event_context.get("after_limit", 5)) # Return the historic display name and avatar for the senders # of the events? @@ -159,7 +154,8 @@ def search(self, user, content, batch=None): if set(group_keys) - {"room_id", "sender"}: raise SynapseError( 400, - "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},) + "Invalid group by keys: %r" + % (set(group_keys) - {"room_id", "sender"},), ) search_filter = Filter(filter_dict) @@ -190,15 +186,13 @@ def search(self, user, content, batch=None): room_ids.intersection_update({batch_group_key}) if not room_ids: - defer.returnValue({ - "search_categories": { - "room_events": { - "results": [], - "count": 0, - "highlights": [], + defer.returnValue( + { + "search_categories": { + "room_events": {"results": [], "count": 0, "highlights": []} } } - }) + ) rank_map = {} # event_id -> rank of event allowed_events = [] @@ -213,9 +207,7 @@ def search(self, user, content, batch=None): count = None if order_by == "rank": - search_result = yield self.store.search_msgs( - room_ids, search_term, keys - ) + search_result = yield self.store.search_msgs(room_ids, search_term, keys) count = search_result["count"] @@ -235,19 +227,17 @@ def search(self, user, content, batch=None): ) events.sort(key=lambda e: -rank_map[e.event_id]) - allowed_events = events[:search_filter.limit()] + allowed_events = events[: search_filter.limit()] for e in allowed_events: - rm = room_groups.setdefault(e.room_id, { - "results": [], - "order": rank_map[e.event_id], - }) + rm = room_groups.setdefault( + e.room_id, {"results": [], "order": rank_map[e.event_id]} + ) rm["results"].append(e.event_id) - s = sender_group.setdefault(e.sender, { - "results": [], - "order": rank_map[e.event_id], - }) + s = sender_group.setdefault( + e.sender, {"results": [], "order": rank_map[e.event_id]} + ) s["results"].append(e.event_id) elif order_by == "recent": @@ -262,7 +252,10 @@ def search(self, user, content, batch=None): while len(room_events) < search_filter.limit() and i < 5: i += 1 search_result = yield self.store.search_rooms( - room_ids, search_term, keys, search_filter.limit() * 2, + room_ids, + search_term, + keys, + search_filter.limit() * 2, pagination_token=pagination_token, ) @@ -277,16 +270,14 @@ def search(self, user, content, batch=None): rank_map.update({r["event"].event_id: r["rank"] for r in results}) - filtered_events = search_filter.filter([ - r["event"] for r in results - ]) + filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( self.store, user.to_string(), filtered_events ) room_events.extend(events) - room_events = room_events[:search_filter.limit()] + room_events = room_events[: search_filter.limit()] if len(results) < search_filter.limit() * 2: pagination_token = None @@ -295,9 +286,7 @@ def search(self, user, content, batch=None): pagination_token = results[-1]["pagination_token"] for event in room_events: - group = room_groups.setdefault(event.room_id, { - "results": [], - }) + group = room_groups.setdefault(event.room_id, {"results": []}) group["results"].append(event.event_id) if room_events and len(room_events) >= search_filter.limit(): @@ -309,18 +298,23 @@ def search(self, user, content, batch=None): # it returns more from the same group (if applicable) rather # than reverting to searching all results again. if batch_group and batch_group_key: - global_next_batch = encode_base64(("%s\n%s\n%s" % ( - batch_group, batch_group_key, pagination_token - )).encode('ascii')) + global_next_batch = encode_base64( + ( + "%s\n%s\n%s" + % (batch_group, batch_group_key, pagination_token) + ).encode("ascii") + ) else: - global_next_batch = encode_base64(("%s\n%s\n%s" % ( - "all", "", pagination_token - )).encode('ascii')) + global_next_batch = encode_base64( + ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii") + ) for room_id, group in room_groups.items(): - group["next_batch"] = encode_base64(("%s\n%s\n%s" % ( - "room_id", room_id, pagination_token - )).encode('ascii')) + group["next_batch"] = encode_base64( + ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode( + "ascii" + ) + ) allowed_events.extend(room_events) @@ -338,12 +332,13 @@ def search(self, user, content, batch=None): contexts = {} for event in allowed_events: res = yield self.store.get_events_around( - event.room_id, event.event_id, before_limit, after_limit, + event.room_id, event.event_id, before_limit, after_limit ) logger.info( "Context for search returned %d and %d events", - len(res["events_before"]), len(res["events_after"]), + len(res["events_before"]), + len(res["events_after"]), ) res["events_before"] = yield filter_events_for_client( @@ -403,12 +398,12 @@ def search(self, user, content, batch=None): for context in contexts.values(): context["events_before"] = ( yield self._event_serializer.serialize_events( - context["events_before"], time_now, + context["events_before"], time_now ) ) context["events_after"] = ( yield self._event_serializer.serialize_events( - context["events_after"], time_now, + context["events_after"], time_now ) ) @@ -426,11 +421,15 @@ def search(self, user, content, batch=None): results = [] for e in allowed_events: - results.append({ - "rank": rank_map[e.event_id], - "result": (yield self._event_serializer.serialize_event(e, time_now)), - "context": contexts.get(e.event_id, {}), - }) + results.append( + { + "rank": rank_map[e.event_id], + "result": ( + yield self._event_serializer.serialize_event(e, time_now) + ), + "context": contexts.get(e.event_id, {}), + } + ) rooms_cat_res = { "results": results, @@ -442,7 +441,7 @@ def search(self, user, content, batch=None): s = {} for room_id, state in state_results.items(): s[room_id] = yield self._event_serializer.serialize_events( - state, time_now, + state, time_now ) rooms_cat_res["state"] = s @@ -456,8 +455,4 @@ def search(self, user, content, batch=None): if global_next_batch: rooms_cat_res["next_batch"] = global_next_batch - defer.returnValue({ - "search_categories": { - "room_events": rooms_cat_res - } - }) + defer.returnValue({"search_categories": {"room_events": rooms_cat_res}}) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 7ecdede4dc00..5a0995d4feab 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -25,6 +25,7 @@ class SetPasswordHandler(BaseHandler): """Handler which deals with changing user account passwords""" + def __init__(self, hs): super(SetPasswordHandler, self).__init__(hs) self._auth_handler = hs.get_auth_handler() @@ -47,11 +48,11 @@ def set_password(self, user_id, newpassword, requester=None): # we want to log out all of the user's other sessions. First delete # all his other devices. yield self._device_handler.delete_all_devices_for_user( - user_id, except_device_id=except_device_id, + user_id, except_device_id=except_device_id ) # and now delete any access tokens which weren't associated with # devices (or were associated with this device). yield self._auth_handler.delete_access_tokens_for_user( - user_id, except_token_id=except_access_token_id, + user_id, except_token_id=except_access_token_id ) diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index b268bbcb2c22..6b364befd595 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -21,7 +21,6 @@ class StateDeltasHandler(object): - def __init__(self, hs): self.store = hs.get_datastore() diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 7ad16c85665e..a0ee8db9884f 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -156,7 +156,7 @@ def _handle_deltas(self, deltas): prev_event_content = {} if prev_event_id is not None: prev_event = yield self.store.get_event( - prev_event_id, allow_none=True, + prev_event_id, allow_none=True ) if prev_event: prev_event_content = prev_event.content diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 62fda0c66440..c5188a1f8e68 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -64,20 +64,14 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100 -SyncConfig = collections.namedtuple("SyncConfig", [ - "user", - "filter_collection", - "is_guest", - "request_key", - "device_id", -]) - - -class TimelineBatch(collections.namedtuple("TimelineBatch", [ - "prev_batch", - "events", - "limited", -])): +SyncConfig = collections.namedtuple( + "SyncConfig", ["user", "filter_collection", "is_guest", "request_key", "device_id"] +) + + +class TimelineBatch( + collections.namedtuple("TimelineBatch", ["prev_batch", "events", "limited"]) +): __slots__ = [] def __nonzero__(self): @@ -85,18 +79,24 @@ def __nonzero__(self): to tell if room needs to be part of the sync result. """ return bool(self.events) + __bool__ = __nonzero__ # python3 -class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ - "room_id", # str - "timeline", # TimelineBatch - "state", # dict[(str, str), FrozenEvent] - "ephemeral", - "account_data", - "unread_notifications", - "summary", -])): +class JoinedSyncResult( + collections.namedtuple( + "JoinedSyncResult", + [ + "room_id", # str + "timeline", # TimelineBatch + "state", # dict[(str, str), FrozenEvent] + "ephemeral", + "account_data", + "unread_notifications", + "summary", + ], + ) +): __slots__ = [] def __nonzero__(self): @@ -111,77 +111,93 @@ def __nonzero__(self): # nb the notification count does not, er, count: if there's nothing # else in the result, we don't need to send it. ) + __bool__ = __nonzero__ # python3 -class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [ - "room_id", # str - "timeline", # TimelineBatch - "state", # dict[(str, str), FrozenEvent] - "account_data", -])): +class ArchivedSyncResult( + collections.namedtuple( + "ArchivedSyncResult", + [ + "room_id", # str + "timeline", # TimelineBatch + "state", # dict[(str, str), FrozenEvent] + "account_data", + ], + ) +): __slots__ = [] def __nonzero__(self): """Make the result appear empty if there are no updates. This is used to tell if room needs to be part of the sync result. """ - return bool( - self.timeline - or self.state - or self.account_data - ) + return bool(self.timeline or self.state or self.account_data) + __bool__ = __nonzero__ # python3 -class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ - "room_id", # str - "invite", # FrozenEvent: the invite event -])): +class InvitedSyncResult( + collections.namedtuple( + "InvitedSyncResult", + ["room_id", "invite"], # str # FrozenEvent: the invite event + ) +): __slots__ = [] def __nonzero__(self): """Invited rooms should always be reported to the client""" return True + __bool__ = __nonzero__ # python3 -class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [ - "join", - "invite", - "leave", -])): +class GroupsSyncResult( + collections.namedtuple("GroupsSyncResult", ["join", "invite", "leave"]) +): __slots__ = [] def __nonzero__(self): return bool(self.join or self.invite or self.leave) + __bool__ = __nonzero__ # python3 -class DeviceLists(collections.namedtuple("DeviceLists", [ - "changed", # list of user_ids whose devices may have changed - "left", # list of user_ids whose devices we no longer track -])): +class DeviceLists( + collections.namedtuple( + "DeviceLists", + [ + "changed", # list of user_ids whose devices may have changed + "left", # list of user_ids whose devices we no longer track + ], + ) +): __slots__ = [] def __nonzero__(self): return bool(self.changed or self.left) + __bool__ = __nonzero__ # python3 -class SyncResult(collections.namedtuple("SyncResult", [ - "next_batch", # Token for the next sync - "presence", # List of presence events for the user. - "account_data", # List of account_data events for the user. - "joined", # JoinedSyncResult for each joined room. - "invited", # InvitedSyncResult for each invited room. - "archived", # ArchivedSyncResult for each archived room. - "to_device", # List of direct messages for the device. - "device_lists", # List of user_ids whose devices have changed - "device_one_time_keys_count", # Dict of algorithm to count for one time keys - # for this device - "groups", -])): +class SyncResult( + collections.namedtuple( + "SyncResult", + [ + "next_batch", # Token for the next sync + "presence", # List of presence events for the user. + "account_data", # List of account_data events for the user. + "joined", # JoinedSyncResult for each joined room. + "invited", # InvitedSyncResult for each invited room. + "archived", # ArchivedSyncResult for each archived room. + "to_device", # List of direct messages for the device. + "device_lists", # List of user_ids whose devices have changed + "device_one_time_keys_count", # Dict of algorithm to count for one time keys + # for this device + "groups", + ], + ) +): __slots__ = [] def __nonzero__(self): @@ -190,20 +206,20 @@ def __nonzero__(self): events. """ return bool( - self.presence or - self.joined or - self.invited or - self.archived or - self.account_data or - self.to_device or - self.device_lists or - self.groups + self.presence + or self.joined + or self.invited + or self.archived + or self.account_data + or self.to_device + or self.device_lists + or self.groups ) + __bool__ = __nonzero__ # python3 class SyncHandler(object): - def __init__(self, hs): self.hs_config = hs.config self.store = hs.get_datastore() @@ -217,13 +233,16 @@ def __init__(self, hs): # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) self.lazy_loaded_members_cache = ExpiringCache( - "lazy_loaded_members_cache", self.clock, - max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, + "lazy_loaded_members_cache", + self.clock, + max_len=0, + expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) @defer.inlineCallbacks - def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, - full_state=False): + def wait_for_sync_for_user( + self, sync_config, since_token=None, timeout=0, full_state=False + ): """Get the sync for a client if we have new data for it now. Otherwise wait for new data to arrive on the server. If the timeout expires, then return an empty sync result. @@ -239,13 +258,15 @@ def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, res = yield self.response_cache.wrap( sync_config.request_key, self._wait_for_sync_for_user, - sync_config, since_token, timeout, full_state, + sync_config, + since_token, + timeout, + full_state, ) defer.returnValue(res) @defer.inlineCallbacks - def _wait_for_sync_for_user(self, sync_config, since_token, timeout, - full_state): + def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state): if since_token is None: sync_type = "initial_sync" elif full_state: @@ -261,14 +282,17 @@ def _wait_for_sync_for_user(self, sync_config, since_token, timeout, # we are going to return immediately, so don't bother calling # notifier.wait_for_events. result = yield self.current_sync_for_user( - sync_config, since_token, full_state=full_state, + sync_config, since_token, full_state=full_state ) else: + def current_sync_callback(before_token, after_token): return self.current_sync_for_user(sync_config, since_token) result = yield self.notifier.wait_for_events( - sync_config.user.to_string(), timeout, current_sync_callback, + sync_config.user.to_string(), + timeout, + current_sync_callback, from_token=since_token, ) @@ -281,8 +305,7 @@ def current_sync_callback(before_token, after_token): defer.returnValue(result) - def current_sync_for_user(self, sync_config, since_token=None, - full_state=False): + def current_sync_for_user(self, sync_config, since_token=None, full_state=False): """Get the sync for client needed to match what the server has now. Returns: A Deferred SyncResult. @@ -334,8 +357,7 @@ def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): # result returned by the event source is poor form (it might cache # the object) room_id = event["room_id"] - event_copy = {k: v for (k, v) in iteritems(event) - if k != "room_id"} + event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) receipt_key = since_token.receipt_key if since_token else "0" @@ -353,22 +375,30 @@ def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): for event in receipts: room_id = event["room_id"] # exclude room id, as above - event_copy = {k: v for (k, v) in iteritems(event) - if k != "room_id"} + event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) defer.returnValue((now_token, ephemeral_by_room)) @defer.inlineCallbacks - def _load_filtered_recents(self, room_id, sync_config, now_token, - since_token=None, recents=None, newly_joined_room=False): + def _load_filtered_recents( + self, + room_id, + sync_config, + now_token, + since_token=None, + recents=None, + newly_joined_room=False, + ): """ Returns: a Deferred TimelineBatch """ with Measure(self.clock, "load_filtered_recents"): timeline_limit = sync_config.filter_collection.timeline_limit() - block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline() + block_all_timeline = ( + sync_config.filter_collection.blocks_all_room_timeline() + ) if recents is None or newly_joined_room or timeline_limit < len(recents): limited = True @@ -396,11 +426,9 @@ def _load_filtered_recents(self, room_id, sync_config, now_token, recents = [] if not limited or block_all_timeline: - defer.returnValue(TimelineBatch( - events=recents, - prev_batch=now_token, - limited=False - )) + defer.returnValue( + TimelineBatch(events=recents, prev_batch=now_token, limited=False) + ) filtering_factor = 2 load_limit = max(timeline_limit * filtering_factor, 10) @@ -427,9 +455,7 @@ def _load_filtered_recents(self, room_id, sync_config, now_token, ) else: events, end_key = yield self.store.get_recent_events_for_room( - room_id, - limit=load_limit + 1, - end_token=end_key, + room_id, limit=load_limit + 1, end_token=end_key ) loaded_recents = sync_config.filter_collection.filter_room_timeline( events @@ -462,15 +488,15 @@ def _load_filtered_recents(self, room_id, sync_config, now_token, recents = recents[-timeline_limit:] room_key = recents[0].internal_metadata.before - prev_batch_token = now_token.copy_and_replace( - "room_key", room_key - ) + prev_batch_token = now_token.copy_and_replace("room_key", room_key) - defer.returnValue(TimelineBatch( - events=recents, - prev_batch=prev_batch_token, - limited=limited or newly_joined_room - )) + defer.returnValue( + TimelineBatch( + events=recents, + prev_batch=prev_batch_token, + limited=limited or newly_joined_room, + ) + ) @defer.inlineCallbacks def get_state_after_event(self, event, state_filter=StateFilter.all()): @@ -486,7 +512,7 @@ def get_state_after_event(self, event, state_filter=StateFilter.all()): A Deferred map from ((type, state_key)->Event) """ state_ids = yield self.store.get_state_ids_for_event( - event.event_id, state_filter=state_filter, + event.event_id, state_filter=state_filter ) if event.is_state(): state_ids = state_ids.copy() @@ -511,13 +537,13 @@ def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()) # does not reliably give you the state at the given stream position. # (/~https://github.com/matrix-org/synapse/issues/3305) last_events, _ = yield self.store.get_recent_events_for_room( - room_id, end_token=stream_position.room_key, limit=1, + room_id, end_token=stream_position.room_key, limit=1 ) if last_events: last_event = last_events[-1] state = yield self.get_state_after_event( - last_event, state_filter=state_filter, + last_event, state_filter=state_filter ) else: @@ -549,7 +575,7 @@ def compute_summary(self, room_id, sync_config, batch, state, now_token): # FIXME: this promulgates /~https://github.com/matrix-org/synapse/issues/3305 last_events, _ = yield self.store.get_recent_event_ids_for_room( - room_id, end_token=now_token.room_key, limit=1, + room_id, end_token=now_token.room_key, limit=1 ) if not last_events: @@ -559,28 +585,25 @@ def compute_summary(self, room_id, sync_config, batch, state, now_token): last_event = last_events[-1] state_ids = yield self.store.get_state_ids_for_event( last_event.event_id, - state_filter=StateFilter.from_types([ - (EventTypes.Name, ''), - (EventTypes.CanonicalAlias, ''), - ]), + state_filter=StateFilter.from_types( + [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] + ), ) # this is heavily cached, thus: fast. details = yield self.store.get_room_summary(room_id) - name_id = state_ids.get((EventTypes.Name, '')) - canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, '')) + name_id = state_ids.get((EventTypes.Name, "")) + canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) summary = {} empty_ms = MemberSummary([], 0) # TODO: only send these when they change. - summary["m.joined_member_count"] = ( - details.get(Membership.JOIN, empty_ms).count - ) - summary["m.invited_member_count"] = ( - details.get(Membership.INVITE, empty_ms).count - ) + summary["m.joined_member_count"] = details.get(Membership.JOIN, empty_ms).count + summary["m.invited_member_count"] = details.get( + Membership.INVITE, empty_ms + ).count # if the room has a name or canonical_alias set, we can skip # calculating heroes. Empty strings are falsey, so we check @@ -592,7 +615,7 @@ def compute_summary(self, room_id, sync_config, batch, state, now_token): if canonical_alias_id: canonical_alias = yield self.store.get_event( - canonical_alias_id, allow_none=True, + canonical_alias_id, allow_none=True ) if canonical_alias and canonical_alias.content.get("alias"): defer.returnValue(summary) @@ -600,26 +623,14 @@ def compute_summary(self, room_id, sync_config, batch, state, now_token): me = sync_config.user.to_string() joined_user_ids = [ - r[0] - for r in details.get(Membership.JOIN, empty_ms).members - if r[0] != me + r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me ] invited_user_ids = [ - r[0] - for r in details.get(Membership.INVITE, empty_ms).members - if r[0] != me + r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me ] - gone_user_ids = ( - [ - r[0] - for r in details.get(Membership.LEAVE, empty_ms).members - if r[0] != me - ] + [ - r[0] - for r in details.get(Membership.BAN, empty_ms).members - if r[0] != me - ] - ) + gone_user_ids = [ + r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me + ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me] # FIXME: only build up a member_ids list for our heroes member_ids = {} @@ -627,20 +638,18 @@ def compute_summary(self, room_id, sync_config, batch, state, now_token): Membership.JOIN, Membership.INVITE, Membership.LEAVE, - Membership.BAN + Membership.BAN, ): for user_id, event_id in details.get(membership, empty_ms).members: member_ids[user_id] = event_id # FIXME: order by stream ordering rather than as returned by SQL - if (joined_user_ids or invited_user_ids): - summary['m.heroes'] = sorted( + if joined_user_ids or invited_user_ids: + summary["m.heroes"] = sorted( [user_id for user_id in (joined_user_ids + invited_user_ids)] )[0:5] else: - summary['m.heroes'] = sorted( - [user_id for user_id in gone_user_ids] - )[0:5] + summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5] if not sync_config.filter_collection.lazy_load_members(): defer.returnValue(summary) @@ -652,8 +661,7 @@ def compute_summary(self, room_id, sync_config, batch, state, now_token): # track which members the client should already know about via LL: # Ones which are already in state... existing_members = set( - user_id for (typ, user_id) in state.keys() - if typ == EventTypes.Member + user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member ) # ...or ones which are in the timeline... @@ -664,10 +672,10 @@ def compute_summary(self, room_id, sync_config, batch, state, now_token): # ...and then ensure any missing ones get included in state. missing_hero_event_ids = [ member_ids[hero_id] - for hero_id in summary['m.heroes'] + for hero_id in summary["m.heroes"] if ( - cache.get(hero_id) != member_ids[hero_id] and - hero_id not in existing_members + cache.get(hero_id) != member_ids[hero_id] + and hero_id not in existing_members ) ] @@ -691,8 +699,9 @@ def get_lazy_loaded_members_cache(self, cache_key): return cache @defer.inlineCallbacks - def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token, - full_state): + def compute_state_delta( + self, room_id, batch, sync_config, since_token, now_token, full_state + ): """ Works out the difference in state between the start of the timeline and the previous sync. @@ -745,23 +754,23 @@ def compute_state_delta(self, room_id, batch, sync_config, since_token, now_toke timeline_state = { (event.type, event.state_key): event.event_id - for event in batch.events if event.is_state() + for event in batch.events + if event.is_state() } if full_state: if batch: current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter, + batch.events[-1].event_id, state_filter=state_filter ) state_ids = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter, + batch.events[0].event_id, state_filter=state_filter ) else: current_state_ids = yield self.get_state_at( - room_id, stream_position=now_token, - state_filter=state_filter, + room_id, stream_position=now_token, state_filter=state_filter ) state_ids = current_state_ids @@ -775,7 +784,7 @@ def compute_state_delta(self, room_id, batch, sync_config, since_token, now_toke ) elif batch.limited: state_at_timeline_start = yield self.store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter, + batch.events[0].event_id, state_filter=state_filter ) # for now, we disable LL for gappy syncs - see @@ -793,12 +802,11 @@ def compute_state_delta(self, room_id, batch, sync_config, since_token, now_toke state_filter = StateFilter.all() state_at_previous_sync = yield self.get_state_at( - room_id, stream_position=since_token, - state_filter=state_filter, + room_id, stream_position=since_token, state_filter=state_filter ) current_state_ids = yield self.store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter, + batch.events[-1].event_id, state_filter=state_filter ) state_ids = _calculate_state( @@ -854,8 +862,7 @@ def compute_state_delta(self, room_id, batch, sync_config, since_token, now_toke # add any member IDs we are about to send into our LruCache for t, event_id in itertools.chain( - state_ids.items(), - timeline_state.items(), + state_ids.items(), timeline_state.items() ): if t[0] == EventTypes.Member: cache.set(t[1], event_id) @@ -864,10 +871,14 @@ def compute_state_delta(self, room_id, batch, sync_config, since_token, now_toke if state_ids: state = yield self.store.get_events(list(state_ids.values())) - defer.returnValue({ - (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state(list(state.values())) - }) + defer.returnValue( + { + (e.type, e.state_key): e + for e in sync_config.filter_collection.filter_room_state( + list(state.values()) + ) + } + ) @defer.inlineCallbacks def unread_notifs_for_room_id(self, room_id, sync_config): @@ -875,7 +886,7 @@ def unread_notifs_for_room_id(self, room_id, sync_config): last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, - receipt_type="m.read" + receipt_type="m.read", ) notifs = [] @@ -909,7 +920,9 @@ def generate_sync_result(self, sync_config, since_token=None, full_state=False): logger.info( "Calculating sync response for %r between %s and %s", - sync_config.user, since_token, now_token, + sync_config.user, + since_token, + now_token, ) user_id = sync_config.user.to_string() @@ -920,11 +933,12 @@ def generate_sync_result(self, sync_config, since_token=None, full_state=False): raise NotImplementedError() else: joined_room_ids = yield self.get_rooms_for_user_at( - user_id, now_token.room_stream_id, + user_id, now_token.room_stream_id ) sync_result_builder = SyncResultBuilder( - sync_config, full_state, + sync_config, + full_state, since_token=since_token, now_token=now_token, joined_room_ids=joined_room_ids, @@ -941,8 +955,7 @@ def generate_sync_result(self, sync_config, since_token=None, full_state=False): _, _, newly_left_rooms, newly_left_users = res block_all_presence_data = ( - since_token is None and - sync_config.filter_collection.blocks_all_presence() + since_token is None and sync_config.filter_collection.blocks_all_presence() ) if self.hs_config.use_presence and not block_all_presence_data: yield self._generate_sync_entry_for_presence( @@ -973,22 +986,23 @@ def generate_sync_result(self, sync_config, since_token=None, full_state=False): room_id = joined_room.room_id if room_id in newly_joined_rooms: issue4422_logger.debug( - "Sync result for newly joined room %s: %r", - room_id, joined_room, + "Sync result for newly joined room %s: %r", room_id, joined_room ) - defer.returnValue(SyncResult( - presence=sync_result_builder.presence, - account_data=sync_result_builder.account_data, - joined=sync_result_builder.joined, - invited=sync_result_builder.invited, - archived=sync_result_builder.archived, - to_device=sync_result_builder.to_device, - device_lists=device_lists, - groups=sync_result_builder.groups, - device_one_time_keys_count=one_time_key_counts, - next_batch=sync_result_builder.now_token, - )) + defer.returnValue( + SyncResult( + presence=sync_result_builder.presence, + account_data=sync_result_builder.account_data, + joined=sync_result_builder.joined, + invited=sync_result_builder.invited, + archived=sync_result_builder.archived, + to_device=sync_result_builder.to_device, + device_lists=device_lists, + groups=sync_result_builder.groups, + device_one_time_keys_count=one_time_key_counts, + next_batch=sync_result_builder.now_token, + ) + ) @measure_func("_generate_sync_entry_for_groups") @defer.inlineCallbacks @@ -999,11 +1013,11 @@ def _generate_sync_entry_for_groups(self, sync_result_builder): if since_token and since_token.groups_key: results = yield self.store.get_groups_changes_for_user( - user_id, since_token.groups_key, now_token.groups_key, + user_id, since_token.groups_key, now_token.groups_key ) else: results = yield self.store.get_all_groups_for_user( - user_id, now_token.groups_key, + user_id, now_token.groups_key ) invited = {} @@ -1031,17 +1045,19 @@ def _generate_sync_entry_for_groups(self, sync_result_builder): left[group_id] = content["content"] sync_result_builder.groups = GroupsSyncResult( - join=joined, - invite=invited, - leave=left, + join=joined, invite=invited, leave=left ) @measure_func("_generate_sync_entry_for_device_list") @defer.inlineCallbacks - def _generate_sync_entry_for_device_list(self, sync_result_builder, - newly_joined_rooms, - newly_joined_or_invited_users, - newly_left_rooms, newly_left_users): + def _generate_sync_entry_for_device_list( + self, + sync_result_builder, + newly_joined_rooms, + newly_joined_or_invited_users, + newly_left_rooms, + newly_left_users, + ): user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token @@ -1065,24 +1081,20 @@ def _generate_sync_entry_for_device_list(self, sync_result_builder, changed.update(newly_joined_or_invited_users) if not changed and not newly_left_users: - defer.returnValue(DeviceLists( - changed=[], - left=newly_left_users, - )) + defer.returnValue(DeviceLists(changed=[], left=newly_left_users)) users_who_share_room = yield self.store.get_users_who_share_room_with_user( user_id ) - defer.returnValue(DeviceLists( - changed=users_who_share_room & changed, - left=set(newly_left_users) - users_who_share_room, - )) + defer.returnValue( + DeviceLists( + changed=users_who_share_room & changed, + left=set(newly_left_users) - users_who_share_room, + ) + ) else: - defer.returnValue(DeviceLists( - changed=[], - left=[], - )) + defer.returnValue(DeviceLists(changed=[], left=[])) @defer.inlineCallbacks def _generate_sync_entry_for_to_device(self, sync_result_builder): @@ -1109,8 +1121,9 @@ def _generate_sync_entry_for_to_device(self, sync_result_builder): deleted = yield self.store.delete_messages_for_device( user_id, device_id, since_stream_id ) - logger.debug("Deleted %d to-device messages up to %d", - deleted, since_stream_id) + logger.debug( + "Deleted %d to-device messages up to %d", deleted, since_stream_id + ) messages, stream_id = yield self.store.get_new_messages_for_device( user_id, device_id, since_stream_id, now_token.to_device_key @@ -1118,7 +1131,10 @@ def _generate_sync_entry_for_to_device(self, sync_result_builder): logger.debug( "Returning %d to-device messages between %d and %d (current token: %d)", - len(messages), since_stream_id, stream_id, now_token.to_device_key + len(messages), + since_stream_id, + stream_id, + now_token.to_device_key, ) sync_result_builder.now_token = now_token.copy_and_replace( "to_device_key", stream_id @@ -1145,8 +1161,7 @@ def _generate_sync_entry_for_account_data(self, sync_result_builder): if since_token and not sync_result_builder.full_state: account_data, account_data_by_room = ( yield self.store.get_updated_account_data_for_user( - user_id, - since_token.account_data_key, + user_id, since_token.account_data_key ) ) @@ -1160,27 +1175,28 @@ def _generate_sync_entry_for_account_data(self, sync_result_builder): ) else: account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user( - sync_config.user.to_string() - ) + yield self.store.get_account_data_for_user(sync_config.user.to_string()) ) - account_data['m.push_rules'] = yield self.push_rules_for_user( + account_data["m.push_rules"] = yield self.push_rules_for_user( sync_config.user ) - account_data_for_user = sync_config.filter_collection.filter_account_data([ - {"type": account_data_type, "content": content} - for account_data_type, content in account_data.items() - ]) + account_data_for_user = sync_config.filter_collection.filter_account_data( + [ + {"type": account_data_type, "content": content} + for account_data_type, content in account_data.items() + ] + ) sync_result_builder.account_data = account_data_for_user defer.returnValue(account_data_by_room) @defer.inlineCallbacks - def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms, - newly_joined_or_invited_users): + def _generate_sync_entry_for_presence( + self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users + ): """Generates the presence portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1223,17 +1239,13 @@ def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_ro extra_users_ids.discard(user.to_string()) if extra_users_ids: - states = yield self.presence_handler.get_states( - extra_users_ids, - ) + states = yield self.presence_handler.get_states(extra_users_ids) presence.extend(states) # Deduplicate the presence entries so that there's at most one per user presence = list({p.user_id: p for p in presence}.values()) - presence = sync_config.filter_collection.filter_presence( - presence - ) + presence = sync_config.filter_collection.filter_presence(presence) sync_result_builder.presence = presence @@ -1253,8 +1265,8 @@ def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_ro """ user_id = sync_result_builder.sync_config.user.to_string() block_all_room_ephemeral = ( - sync_result_builder.since_token is None and - sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() + sync_result_builder.since_token is None + and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) if block_all_room_ephemeral: @@ -1275,15 +1287,14 @@ def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_ro have_changed = yield self._have_rooms_changed(sync_result_builder) if not have_changed: tags_by_room = yield self.store.get_updated_tags( - user_id, - since_token.account_data_key, + user_id, since_token.account_data_key ) if not tags_by_room: logger.debug("no-oping sync") defer.returnValue(([], [], [], [])) ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id=user_id, + "m.ignored_user_list", user_id=user_id ) if ignored_account_data: @@ -1296,7 +1307,7 @@ def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_ro room_entries, invited, newly_joined_rooms, newly_left_rooms = res tags_by_room = yield self.store.get_updated_tags( - user_id, since_token.account_data_key, + user_id, since_token.account_data_key ) else: res = yield self._get_all_rooms(sync_result_builder, ignored_users) @@ -1331,8 +1342,8 @@ def handle_room_entries(room_entry): for event in it: if event.type == EventTypes.Member: if ( - event.membership == Membership.JOIN or - event.membership == Membership.INVITE + event.membership == Membership.JOIN + or event.membership == Membership.INVITE ): newly_joined_or_invited_users.add(event.state_key) else: @@ -1343,12 +1354,14 @@ def handle_room_entries(room_entry): newly_left_users -= newly_joined_or_invited_users - defer.returnValue(( - newly_joined_rooms, - newly_joined_or_invited_users, - newly_left_rooms, - newly_left_users, - )) + defer.returnValue( + ( + newly_joined_rooms, + newly_joined_or_invited_users, + newly_left_rooms, + newly_left_users, + ) + ) @defer.inlineCallbacks def _have_rooms_changed(self, sync_result_builder): @@ -1454,7 +1467,9 @@ def _get_rooms_changed(self, sync_result_builder, ignored_users): prev_membership = old_mem_ev.membership issue4422_logger.debug( "Previous membership for room %s with join: %s (event %s)", - room_id, prev_membership, old_mem_ev_id, + room_id, + prev_membership, + old_mem_ev_id, ) if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: @@ -1476,8 +1491,7 @@ def _get_rooms_changed(self, sync_result_builder, ignored_users): if not old_state_ids: old_state_ids = yield self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get( - (EventTypes.Member, user_id), - None, + (EventTypes.Member, user_id), None ) old_mem_ev = None if old_mem_ev_id: @@ -1498,7 +1512,8 @@ def _get_rooms_changed(self, sync_result_builder, ignored_users): # Always include leave/ban events. Just take the last one. # TODO: How do we handle ban -> leave in same batch? leave_events = [ - e for e in non_joins + e + for e in non_joins if e.membership in (Membership.LEAVE, Membership.BAN) ] @@ -1526,15 +1541,17 @@ def _get_rooms_changed(self, sync_result_builder, ignored_users): else: batch_events = None - room_entries.append(RoomSyncResultBuilder( - room_id=room_id, - rtype="archived", - events=batch_events, - newly_joined=room_id in newly_joined_rooms, - full_state=False, - since_token=since_token, - upto_token=leave_token, - )) + room_entries.append( + RoomSyncResultBuilder( + room_id=room_id, + rtype="archived", + events=batch_events, + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=since_token, + upto_token=leave_token, + ) + ) timeline_limit = sync_config.filter_collection.timeline_limit() @@ -1581,7 +1598,8 @@ def _get_rooms_changed(self, sync_result_builder, ignored_users): # debugging for /~https://github.com/matrix-org/synapse/issues/4422 issue4422_logger.debug( "RoomSyncResultBuilder events for newly joined room %s: %r", - room_id, entry.events, + room_id, + entry.events, ) room_entries.append(entry) @@ -1606,12 +1624,14 @@ def _get_all_rooms(self, sync_result_builder, ignored_users): sync_config = sync_result_builder.sync_config membership_list = ( - Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN + Membership.INVITE, + Membership.JOIN, + Membership.LEAVE, + Membership.BAN, ) room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=user_id, - membership_list=membership_list + user_id=user_id, membership_list=membership_list ) room_entries = [] @@ -1619,23 +1639,22 @@ def _get_all_rooms(self, sync_result_builder, ignored_users): for event in room_list: if event.membership == Membership.JOIN: - room_entries.append(RoomSyncResultBuilder( - room_id=event.room_id, - rtype="joined", - events=None, - newly_joined=False, - full_state=True, - since_token=since_token, - upto_token=now_token, - )) + room_entries.append( + RoomSyncResultBuilder( + room_id=event.room_id, + rtype="joined", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=now_token, + ) + ) elif event.membership == Membership.INVITE: if event.sender in ignored_users: continue invite = yield self.store.get_event(event.event_id) - invited.append(InvitedSyncResult( - room_id=event.room_id, - invite=invite, - )) + invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) elif event.membership in (Membership.LEAVE, Membership.BAN): # Always send down rooms we were banned or kicked from. if not sync_config.filter_collection.include_leave: @@ -1646,22 +1665,31 @@ def _get_all_rooms(self, sync_result_builder, ignored_users): leave_token = now_token.copy_and_replace( "room_key", "s%d" % (event.stream_ordering,) ) - room_entries.append(RoomSyncResultBuilder( - room_id=event.room_id, - rtype="archived", - events=None, - newly_joined=False, - full_state=True, - since_token=since_token, - upto_token=leave_token, - )) + room_entries.append( + RoomSyncResultBuilder( + room_id=event.room_id, + rtype="archived", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=leave_token, + ) + ) defer.returnValue((room_entries, invited, [])) @defer.inlineCallbacks - def _generate_room_entry(self, sync_result_builder, ignored_users, - room_builder, ephemeral, tags, account_data, - always_include=False): + def _generate_room_entry( + self, + sync_result_builder, + ignored_users, + room_builder, + ephemeral, + tags, + account_data, + always_include=False, + ): """Populates the `joined` and `archived` section of `sync_result_builder` based on the `room_builder`. @@ -1678,9 +1706,7 @@ def _generate_room_entry(self, sync_result_builder, ignored_users, """ newly_joined = room_builder.newly_joined full_state = ( - room_builder.full_state - or newly_joined - or sync_result_builder.full_state + room_builder.full_state or newly_joined or sync_result_builder.full_state ) events = room_builder.events @@ -1697,7 +1723,8 @@ def _generate_room_entry(self, sync_result_builder, ignored_users, upto_token = room_builder.upto_token batch = yield self._load_filtered_recents( - room_id, sync_config, + room_id, + sync_config, now_token=upto_token, since_token=since_token, recents=events, @@ -1708,7 +1735,8 @@ def _generate_room_entry(self, sync_result_builder, ignored_users, # debug for /~https://github.com/matrix-org/synapse/issues/4422 issue4422_logger.debug( "Timeline events after filtering in newly-joined room %s: %r", - room_id, batch, + room_id, + batch, ) # When we join the room (or the client requests full_state), we should @@ -1726,16 +1754,10 @@ def _generate_room_entry(self, sync_result_builder, ignored_users, account_data_events = [] if tags is not None: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) + account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) + account_data_events.append({"type": account_data_type, "content": content}) account_data_events = sync_config.filter_collection.filter_room_account_data( account_data_events @@ -1743,16 +1765,13 @@ def _generate_room_entry(self, sync_result_builder, ignored_users, ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral) - if not (always_include - or batch - or account_data_events - or ephemeral - or full_state): + if not ( + always_include or batch or account_data_events or ephemeral or full_state + ): return state = yield self.compute_state_delta( - room_id, batch, sync_config, since_token, now_token, - full_state=full_state + room_id, batch, sync_config, since_token, now_token, full_state=full_state ) summary = {} @@ -1760,22 +1779,19 @@ def _generate_room_entry(self, sync_result_builder, ignored_users, # we include a summary in room responses when we're lazy loading # members (as the client otherwise doesn't have enough info to form # the name itself). - if ( - sync_config.filter_collection.lazy_load_members() and - ( - # we recalulate the summary: - # if there are membership changes in the timeline, or - # if membership has changed during a gappy sync, or - # if this is an initial sync. - any(ev.type == EventTypes.Member for ev in batch.events) or - ( - # XXX: this may include false positives in the form of LL - # members which have snuck into state - batch.limited and - any(t == EventTypes.Member for (t, k) in state) - ) or - since_token is None + if sync_config.filter_collection.lazy_load_members() and ( + # we recalulate the summary: + # if there are membership changes in the timeline, or + # if membership has changed during a gappy sync, or + # if this is an initial sync. + any(ev.type == EventTypes.Member for ev in batch.events) + or ( + # XXX: this may include false positives in the form of LL + # members which have snuck into state + batch.limited + and any(t == EventTypes.Member for (t, k) in state) ) + or since_token is None ): summary = yield self.compute_summary( room_id, sync_config, batch, state, now_token @@ -1794,9 +1810,7 @@ def _generate_room_entry(self, sync_result_builder, ignored_users, ) if room_sync or always_include: - notifs = yield self.unread_notifs_for_room_id( - room_id, sync_config - ) + notifs = yield self.unread_notifs_for_room_id(room_id, sync_config) if notifs is not None: unread_notifications["notification_count"] = notifs["notify_count"] @@ -1807,11 +1821,8 @@ def _generate_room_entry(self, sync_result_builder, ignored_users, if batch.limited and since_token: user_id = sync_result_builder.sync_config.user.to_string() logger.info( - "Incremental gappy sync of %s for user %s with %d state events" % ( - room_id, - user_id, - len(state), - ) + "Incremental gappy sync of %s for user %s with %d state events" + % (room_id, user_id, len(state)) ) elif room_builder.rtype == "archived": room_sync = ArchivedSyncResult( @@ -1841,9 +1852,7 @@ def get_rooms_for_user_at(self, user_id, stream_ordering): Deferred[frozenset[str]]: Set of room_ids the user is in at given stream_ordering. """ - joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering( - user_id, - ) + joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id) joined_room_ids = set() @@ -1862,11 +1871,9 @@ def get_rooms_for_user_at(self, user_id, stream_ordering): logger.info("User joined room after current token: %s", room_id) extrems = yield self.store.get_forward_extremeties_for_room( - room_id, stream_ordering, - ) - users_in_room = yield self.state.get_current_users_in_room( - room_id, extrems, + room_id, stream_ordering ) + users_in_room = yield self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: joined_room_ids.add(room_id) @@ -1886,7 +1893,7 @@ def _action_has_highlight(actions): def _calculate_state( - timeline_contains, timeline_start, previous, current, lazy_load_members, + timeline_contains, timeline_start, previous, current, lazy_load_members ): """Works out what state to include in a sync response. @@ -1930,15 +1937,12 @@ def _calculate_state( if lazy_load_members: p_ids.difference_update( - e for t, e in iteritems(timeline_start) - if t[0] == EventTypes.Member + e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member ) state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids - return { - event_id_to_key[e]: e for e in state_ids - } + return {event_id_to_key[e]: e for e in state_ids} class SyncResultBuilder(object): @@ -1961,8 +1965,10 @@ class SyncResultBuilder(object): groups (GroupsSyncResult|None) to_device (list) """ - def __init__(self, sync_config, full_state, since_token, now_token, - joined_room_ids): + + def __init__( + self, sync_config, full_state, since_token, now_token, joined_room_ids + ): """ Args: sync_config (SyncConfig) @@ -1991,8 +1997,10 @@ class RoomSyncResultBuilder(object): """Stores information needed to create either a `JoinedSyncResult` or `ArchivedSyncResult`. """ - def __init__(self, room_id, rtype, events, newly_joined, full_state, - since_token, upto_token): + + def __init__( + self, room_id, rtype, events, newly_joined, full_state, since_token, upto_token + ): """ Args: room_id(str) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 972662eb4819..f8062c867133 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -68,13 +68,10 @@ def __init__(self, hs): # caches which room_ids changed at which serials self._typing_stream_change_cache = StreamChangeCache( - "TypingStreamChangeCache", self._latest_room_serial, + "TypingStreamChangeCache", self._latest_room_serial ) - self.clock.looping_call( - self._handle_timeouts, - 5000, - ) + self.clock.looping_call(self._handle_timeouts, 5000) def _reset(self): """ @@ -108,19 +105,11 @@ def _handle_timeouts(self): if self.hs.is_mine_id(member.user_id): last_fed_poke = self._member_last_federation_poke.get(member, None) if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: - run_in_background( - self._push_remote, - member=member, - typing=True - ) + run_in_background(self._push_remote, member=member, typing=True) # Add a paranoia timer to ensure that we always have a timer for # each person typing. - self.wheel_timer.insert( - now=now, - obj=member, - then=now + 60 * 1000, - ) + self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) def is_typing(self, member): return member.user_id in self._room_typing.get(member.room_id, []) @@ -138,9 +127,7 @@ def started_typing(self, target_user, auth_user, room_id, timeout): yield self.auth.check_joined_room(room_id, target_user_id) - logger.debug( - "%s has started typing in %s", target_user_id, room_id - ) + logger.debug("%s has started typing in %s", target_user_id, room_id) member = RoomMember(room_id=room_id, user_id=target_user_id) @@ -149,20 +136,13 @@ def started_typing(self, target_user, auth_user, room_id, timeout): now = self.clock.time_msec() self._member_typing_until[member] = now + timeout - self.wheel_timer.insert( - now=now, - obj=member, - then=now + timeout, - ) + self.wheel_timer.insert(now=now, obj=member, then=now + timeout) if was_present: # No point sending another notification defer.returnValue(None) - self._push_update( - member=member, - typing=True, - ) + self._push_update(member=member, typing=True) @defer.inlineCallbacks def stopped_typing(self, target_user, auth_user, room_id): @@ -177,9 +157,7 @@ def stopped_typing(self, target_user, auth_user, room_id): yield self.auth.check_joined_room(room_id, target_user_id) - logger.debug( - "%s has stopped typing in %s", target_user_id, room_id - ) + logger.debug("%s has stopped typing in %s", target_user_id, room_id) member = RoomMember(room_id=room_id, user_id=target_user_id) @@ -200,20 +178,14 @@ def _stopped_typing(self, member): self._member_typing_until.pop(member, None) self._member_last_federation_poke.pop(member, None) - self._push_update( - member=member, - typing=False, - ) + self._push_update(member=member, typing=False) def _push_update(self, member, typing): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. run_in_background(self._push_remote, member, typing) - self._push_update_local( - member=member, - typing=typing - ) + self._push_update_local(member=member, typing=typing) @defer.inlineCallbacks def _push_remote(self, member, typing): @@ -223,9 +195,7 @@ def _push_remote(self, member, typing): now = self.clock.time_msec() self.wheel_timer.insert( - now=now, - obj=member, - then=now + FEDERATION_PING_INTERVAL, + now=now, obj=member, then=now + FEDERATION_PING_INTERVAL ) for domain in set(get_domain_from_id(u) for u in users): @@ -256,8 +226,7 @@ def _recv_edu(self, origin, content): if user.domain != origin: logger.info( - "Got typing update from %r with bad 'user_id': %r", - origin, user_id, + "Got typing update from %r with bad 'user_id': %r", origin, user_id ) return @@ -268,15 +237,8 @@ def _recv_edu(self, origin, content): logger.info("Got typing update from %s: %r", user_id, content) now = self.clock.time_msec() self._member_typing_until[member] = now + FEDERATION_TIMEOUT - self.wheel_timer.insert( - now=now, - obj=member, - then=now + FEDERATION_TIMEOUT, - ) - self._push_update_local( - member=member, - typing=content["typing"] - ) + self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT) + self._push_update_local(member=member, typing=content["typing"]) def _push_update_local(self, member, typing): room_set = self._room_typing.setdefault(member.room_id, set()) @@ -288,7 +250,7 @@ def _push_update_local(self, member, typing): self._latest_room_serial += 1 self._room_serials[member.room_id] = self._latest_room_serial self._typing_stream_change_cache.entity_has_changed( - member.room_id, self._latest_room_serial, + member.room_id, self._latest_room_serial ) self.notifier.on_new_event( @@ -300,7 +262,7 @@ def get_all_typing_updates(self, last_id, current_id): return [] changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( - last_id, + last_id ) if changed_rooms is None: @@ -334,9 +296,7 @@ def _make_event_for(self, room_id): return { "type": "m.typing", "room_id": room_id, - "content": { - "user_ids": list(typing), - }, + "content": {"user_ids": list(typing)}, } def get_new_events(self, from_key, room_ids, **kwargs): diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index d36bcd6336ea..3acf772cd16b 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -25,6 +25,7 @@ class RequestTimedOutError(SynapseError): """Exception representing timeout of an outbound request""" + def __init__(self): super(RequestTimedOutError, self).__init__(504, "Timed out") @@ -40,15 +41,12 @@ def cancelled_to_request_timed_out_error(value, timeout): return value -ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$') +ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$") def redact_uri(uri): """Strips access tokens from the uri replaces with """ - return ACCESS_TOKEN_RE.sub( - r'\1\3', - uri - ) + return ACCESS_TOKEN_RE.sub(r"\1\3", uri) class QuieterFileBodyProducer(FileBodyProducer): @@ -57,6 +55,7 @@ class QuieterFileBodyProducer(FileBodyProducer): Workaround for /~https://github.com/matrix-org/synapse/issues/4003 / https://twistedmatrix.com/trac/ticket/6528 """ + def stopProducing(self): try: FileBodyProducer.stopProducing(self) diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 0e10e3f8f740..096619a8c21b 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -28,6 +28,7 @@ class AdditionalResource(Resource): This class is also where we wrap the request handler with logging, metrics, and exception handling. """ + def __init__(self, hs, handler): """Initialise AdditionalResource diff --git a/synapse/http/client.py b/synapse/http/client.py index 5c073fff07f8..9bc7035c8dae 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -103,8 +103,8 @@ def _callback(): ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info( - "Dropped %s from DNS resolution to %s due to blacklist" % - (ip_address, hostname) + "Dropped %s from DNS resolution to %s due to blacklist" + % (ip_address, hostname) ) has_bad_ip = True @@ -156,7 +156,7 @@ def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None): self._ip_blacklist = ip_blacklist def request(self, method, uri, headers=None, bodyProducer=None): - h = urllib.parse.urlparse(uri.decode('ascii')) + h = urllib.parse.urlparse(uri.decode("ascii")) try: ip_address = IPAddress(h.hostname) @@ -164,10 +164,7 @@ def request(self, method, uri, headers=None, bodyProducer=None): if check_against_blacklist( ip_address, self._ip_whitelist, self._ip_blacklist ): - logger.info( - "Blocking access to %s due to blacklist" % - (ip_address,) - ) + logger.info("Blocking access to %s due to blacklist" % (ip_address,)) e = SynapseError(403, "IP address blocked by IP blacklist entry") return defer.fail(Failure(e)) except Exception: @@ -206,7 +203,7 @@ def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None): if hs.config.user_agent_suffix: self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) - self.user_agent = self.user_agent.encode('ascii') + self.user_agent = self.user_agent.encode("ascii") if self._ip_blacklist: real_reactor = hs.get_reactor() @@ -520,8 +517,8 @@ def get_file(self, url, output_stream, max_size=None, headers=None): resp_headers = dict(response.headers.getAllRawHeaders()) if ( - b'Content-Length' in resp_headers - and int(resp_headers[b'Content-Length'][0]) > max_size + b"Content-Length" in resp_headers + and int(resp_headers[b"Content-Length"][0]) > max_size ): logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( @@ -546,18 +543,13 @@ def get_file(self, url, output_stream, max_size=None, headers=None): # This can happen e.g. because the body is too large. raise except Exception as e: - raise_from( - SynapseError( - 502, ("Failed to download remote body: %s" % e), - ), - e - ) + raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e) defer.returnValue( ( length, resp_headers, - response.request.absoluteURI.decode('ascii'), + response.request.absoluteURI.decode("ascii"), response.code, ) ) @@ -647,7 +639,7 @@ def encode_urlencode_args(args): def encode_urlencode_arg(arg): if isinstance(arg, text_type): - return arg.encode('utf-8') + return arg.encode("utf-8") elif isinstance(arg, list): return [encode_urlencode_arg(i) for i in arg] else: diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index cd79ebab6225..92a5b606c8cd 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -31,7 +31,7 @@ def parse_server_name(server_name): ValueError if the server name could not be parsed. """ try: - if server_name[-1] == ']': + if server_name[-1] == "]": # ipv6 literal, hopefully return server_name, None @@ -43,9 +43,7 @@ def parse_server_name(server_name): raise ValueError("Invalid server name '%s'" % server_name) -VALID_HOST_REGEX = re.compile( - "\\A[0-9a-zA-Z.-]+\\Z", -) +VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") def parse_and_validate_server_name(server_name): @@ -67,17 +65,15 @@ def parse_and_validate_server_name(server_name): # that nobody is sneaking IP literals in that look like hostnames, etc. # look for ipv6 literals - if host[0] == '[': - if host[-1] != ']': - raise ValueError("Mismatched [...] in server name '%s'" % ( - server_name, - )) + if host[0] == "[": + if host[-1] != "]": + raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) return host, port # otherwise it should only be alphanumerics. if not VALID_HOST_REGEX.match(host): - raise ValueError("Server name '%s' contains invalid characters" % ( - server_name, - )) + raise ValueError( + "Server name '%s' contains invalid characters" % (server_name,) + ) return host, port diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index b4cbe97b41fa..414cde077776 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -48,7 +48,7 @@ WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 logger = logging.getLogger(__name__) -well_known_cache = TTLCache('well-known') +well_known_cache = TTLCache("well-known") @implementer(IAgent) @@ -78,7 +78,9 @@ class MatrixFederationAgent(object): """ def __init__( - self, reactor, tls_client_options_factory, + self, + reactor, + tls_client_options_factory, _well_known_tls_policy=None, _srv_resolver=None, _well_known_cache=well_known_cache, @@ -100,9 +102,9 @@ def __init__( if _well_known_tls_policy is not None: # the param is called 'contextFactory', but actually passing a # contextfactory is deprecated, and it expects an IPolicyForHTTPS. - agent_args['contextFactory'] = _well_known_tls_policy + agent_args["contextFactory"] = _well_known_tls_policy _well_known_agent = RedirectAgent( - Agent(self._reactor, pool=self._pool, **agent_args), + Agent(self._reactor, pool=self._pool, **agent_args) ) self._well_known_agent = _well_known_agent @@ -149,7 +151,7 @@ def request(self, method, uri, headers=None, bodyProducer=None): tls_options = None else: tls_options = self._tls_client_options_factory.get_options( - res.tls_server_name.decode("ascii"), + res.tls_server_name.decode("ascii") ) # make sure that the Host header is set correctly @@ -158,14 +160,14 @@ def request(self, method, uri, headers=None, bodyProducer=None): else: headers = headers.copy() - if not headers.hasHeader(b'host'): - headers.addRawHeader(b'host', res.host_header) + if not headers.hasHeader(b"host"): + headers.addRawHeader(b"host", res.host_header) class EndpointFactory(object): @staticmethod def endpointForURI(_uri): ep = LoggingHostnameEndpoint( - self._reactor, res.target_host, res.target_port, + self._reactor, res.target_host, res.target_port ) if tls_options is not None: ep = wrapClientTLS(tls_options, ep) @@ -203,21 +205,25 @@ def _route_matrix_uri(self, parsed_uri, lookup_well_known=True): port = parsed_uri.port if port == -1: port = 8448 - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=port, + ) + ) if parsed_uri.port != -1: # there is an explicit port - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=parsed_uri.port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=parsed_uri.port, + ) + ) if lookup_well_known: # try a .well-known lookup @@ -229,8 +235,8 @@ def _route_matrix_uri(self, parsed_uri, lookup_well_known=True): # parse the server name in the .well-known response into host/port. # (This code is lifted from twisted.web.client.URI.fromBytes). - if b':' in well_known_server: - well_known_host, well_known_port = well_known_server.rsplit(b':', 1) + if b":" in well_known_server: + well_known_host, well_known_port = well_known_server.rsplit(b":", 1) try: well_known_port = int(well_known_port) except ValueError: @@ -264,21 +270,27 @@ def _route_matrix_uri(self, parsed_uri, lookup_well_known=True): port = 8448 logger.debug( "No SRV record for %s, using %s:%i", - parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port, + parsed_uri.host.decode("ascii"), + target_host.decode("ascii"), + port, ) else: target_host, port = pick_server_from_list(server_list) logger.debug( "Picked %s:%i from SRV records for %s", - target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"), + target_host.decode("ascii"), + port, + parsed_uri.host.decode("ascii"), ) - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=target_host, - target_port=port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=target_host, + target_port=port, + ) + ) @defer.inlineCallbacks def _get_well_known(self, server_name): @@ -318,18 +330,18 @@ def _do_get_well_known(self, server_name): - None if there was no .well-known file. - INVALID_WELL_KNOWN if the .well-known was invalid """ - uri = b"https://%s/.well-known/matrix/server" % (server_name, ) + uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri_str = uri.decode("ascii") logger.info("Fetching %s", uri_str) try: response = yield make_deferred_yieldable( - self._well_known_agent.request(b"GET", uri), + self._well_known_agent.request(b"GET", uri) ) body = yield make_deferred_yieldable(readBody(response)) if response.code != 200: - raise Exception("Non-200 response %s" % (response.code, )) + raise Exception("Non-200 response %s" % (response.code,)) - parsed_body = json.loads(body.decode('utf-8')) + parsed_body = json.loads(body.decode("utf-8")) logger.info("Response from .well-known: %s", parsed_body) if not isinstance(parsed_body, dict): raise Exception("not a dict") @@ -347,8 +359,7 @@ def _do_get_well_known(self, server_name): result = parsed_body["m.server"].encode("ascii") cache_period = _cache_period_from_headers( - response.headers, - time_now=self._reactor.seconds, + response.headers, time_now=self._reactor.seconds ) if cache_period is None: cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD @@ -364,6 +375,7 @@ def _do_get_well_known(self, server_name): @implementer(IStreamClientEndpoint) class LoggingHostnameEndpoint(object): """A wrapper for HostnameEndpint which logs when it connects""" + def __init__(self, reactor, host, port, *args, **kwargs): self.host = host self.port = port @@ -377,17 +389,17 @@ def connect(self, protocol_factory): def _cache_period_from_headers(headers, time_now=time.time): cache_controls = _parse_cache_control(headers) - if b'no-store' in cache_controls: + if b"no-store" in cache_controls: return 0 - if b'max-age' in cache_controls: + if b"max-age" in cache_controls: try: - max_age = int(cache_controls[b'max-age']) + max_age = int(cache_controls[b"max-age"]) return max_age except ValueError: pass - expires = headers.getRawHeaders(b'expires') + expires = headers.getRawHeaders(b"expires") if expires is not None: try: expires_date = stringToDatetime(expires[-1]) @@ -403,9 +415,9 @@ def _cache_period_from_headers(headers, time_now=time.time): def _parse_cache_control(headers): cache_controls = {} - for hdr in headers.getRawHeaders(b'cache-control', []): - for directive in hdr.split(b','): - splits = [x.strip() for x in directive.split(b'=', 1)] + for hdr in headers.getRawHeaders(b"cache-control", []): + for directive in hdr.split(b","): + splits = [x.strip() for x in directive.split(b"=", 1)] k = splits[0].lower() v = splits[1] if len(splits) > 1 else None cache_controls[k] = v diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 71830c549d4b..1f22f78a755f 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -45,6 +45,7 @@ class Server(object): expires (int): when the cache should expire this record - in *seconds* since the epoch """ + host = attr.ib() port = attr.ib() priority = attr.ib(default=0) @@ -79,9 +80,7 @@ def pick_server_from_list(server_list): return s.host, s.port # this should be impossible. - raise RuntimeError( - "pick_server_from_list got to end of eligible server list.", - ) + raise RuntimeError("pick_server_from_list got to end of eligible server list.") class SrvResolver(object): @@ -95,6 +94,7 @@ class SrvResolver(object): cache (dict): cache object get_time (callable): clock implementation. Should return seconds since the epoch """ + def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time): self._dns_client = dns_client self._cache = cache @@ -124,7 +124,7 @@ def resolve_service(self, service_name): try: answers, _, _ = yield make_deferred_yieldable( - self._dns_client.lookupService(service_name), + self._dns_client.lookupService(service_name) ) except DNSNameError: # TODO: cache this. We can get the SOA out of the exception, and use @@ -136,17 +136,18 @@ def resolve_service(self, service_name): cache_entry = self._cache.get(service_name, None) if cache_entry: logger.warn( - "Failed to resolve %r, falling back to cache. %r", - service_name, e + "Failed to resolve %r, falling back to cache. %r", service_name, e ) defer.returnValue(list(cache_entry)) else: raise e - if (len(answers) == 1 - and answers[0].type == dns.SRV - and answers[0].payload - and answers[0].payload.target == dns.Name(b'.')): + if ( + len(answers) == 1 + and answers[0].type == dns.SRV + and answers[0].payload + and answers[0].payload.target == dns.Name(b".") + ): raise ConnectError("Service %s unavailable" % service_name) servers = [] @@ -157,13 +158,15 @@ def resolve_service(self, service_name): payload = answer.payload - servers.append(Server( - host=payload.target.name, - port=payload.port, - priority=payload.priority, - weight=payload.weight, - expires=now + answer.ttl, - )) + servers.append( + Server( + host=payload.target.name, + port=payload.port, + priority=payload.priority, + weight=payload.weight, + expires=now + answer.ttl, + ) + ) self._cache[service_name] = list(servers) defer.returnValue(servers) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 663ea72a7a3a..5ef8bb60a3ab 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -54,10 +54,12 @@ logger = logging.getLogger(__name__) -outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests", - "", ["method"]) -incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses", - "", ["method", "code"]) +outgoing_requests_counter = Counter( + "synapse_http_matrixfederationclient_requests", "", ["method"] +) +incoming_responses_counter = Counter( + "synapse_http_matrixfederationclient_responses", "", ["method", "code"] +) MAX_LONG_RETRIES = 10 @@ -137,11 +139,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): check_content_type_is_json(response.headers) d = treq.json_content(response) - d = timeout_deferred( - d, - timeout=timeout_sec, - reactor=reactor, - ) + d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) body = yield make_deferred_yieldable(d) except Exception as e: @@ -157,7 +155,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): request.txn_id, request.destination, response.code, - response.phrase.decode('ascii', errors='replace'), + response.phrase.decode("ascii", errors="replace"), ) defer.returnValue(body) @@ -181,7 +179,7 @@ def __init__(self, hs, tls_client_options_factory): # We need to use a DNS resolver which filters out blacklisted IP # addresses, to prevent DNS rebinding. nameResolver = IPBlacklistingResolver( - real_reactor, None, hs.config.federation_ip_range_blacklist, + real_reactor, None, hs.config.federation_ip_range_blacklist ) @implementer(IReactorPluggableNameResolver) @@ -194,21 +192,19 @@ def __getattr__(_self, attr): self.reactor = Reactor() - self.agent = MatrixFederationAgent( - self.reactor, - tls_client_options_factory, - ) + self.agent = MatrixFederationAgent(self.reactor, tls_client_options_factory) # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( - self.agent, self.reactor, + self.agent, + self.reactor, ip_blacklist=hs.config.federation_ip_range_blacklist, ) self.clock = hs.get_clock() self._store = hs.get_datastore() - self.version_string_bytes = hs.version_string.encode('ascii') + self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout = 60 def schedule(x): @@ -218,10 +214,7 @@ def schedule(x): @defer.inlineCallbacks def _send_request_with_optional_trailing_slash( - self, - request, - try_trailing_slash_on_400=False, - **send_request_args + self, request, try_trailing_slash_on_400=False, **send_request_args ): """Wrapper for _send_request which can optionally retry the request upon receiving a combination of a 400 HTTP response code and a @@ -244,9 +237,7 @@ def _send_request_with_optional_trailing_slash( Deferred[Dict]: Parsed JSON response body. """ try: - response = yield self._send_request( - request, **send_request_args - ) + response = yield self._send_request(request, **send_request_args) except HttpResponseException as e: # Received an HTTP error > 300. Check if it meets the requirements # to retry with a trailing slash @@ -262,9 +253,7 @@ def _send_request_with_optional_trailing_slash( logger.info("Retrying request with trailing slash") request.path += "/" - response = yield self._send_request( - request, **send_request_args - ) + response = yield self._send_request(request, **send_request_args) defer.returnValue(response) @@ -329,8 +318,8 @@ def _send_request( _sec_timeout = self.default_timeout if ( - self.hs.config.federation_domain_whitelist is not None and - request.destination not in self.hs.config.federation_domain_whitelist + self.hs.config.federation_domain_whitelist is not None + and request.destination not in self.hs.config.federation_domain_whitelist ): raise FederationDeniedError(request.destination) @@ -350,9 +339,7 @@ def _send_request( else: query_bytes = b"" - headers_dict = { - b"User-Agent": [self.version_string_bytes], - } + headers_dict = {b"User-Agent": [self.version_string_bytes]} with limiter: # XXX: Would be much nicer to retry only at the transaction-layer @@ -362,16 +349,14 @@ def _send_request( else: retries_left = MAX_SHORT_RETRIES - url_bytes = urllib.parse.urlunparse(( - b"matrix", destination_bytes, - path_bytes, None, query_bytes, b"", - )) - url_str = url_bytes.decode('ascii') + url_bytes = urllib.parse.urlunparse( + (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"") + ) + url_str = url_bytes.decode("ascii") - url_to_sign_bytes = urllib.parse.urlunparse(( - b"", b"", - path_bytes, None, query_bytes, b"", - )) + url_to_sign_bytes = urllib.parse.urlunparse( + (b"", b"", path_bytes, None, query_bytes, b"") + ) while True: try: @@ -379,26 +364,27 @@ def _send_request( if json: headers_dict[b"Content-Type"] = [b"application/json"] auth_headers = self.build_auth_headers( - destination_bytes, method_bytes, url_to_sign_bytes, - json, + destination_bytes, method_bytes, url_to_sign_bytes, json ) data = encode_canonical_json(json) producer = QuieterFileBodyProducer( - BytesIO(data), - cooperator=self._cooperator, + BytesIO(data), cooperator=self._cooperator ) else: producer = None auth_headers = self.build_auth_headers( - destination_bytes, method_bytes, url_to_sign_bytes, + destination_bytes, method_bytes, url_to_sign_bytes ) headers_dict[b"Authorization"] = auth_headers logger.info( "{%s} [%s] Sending request: %s %s; timeout %fs", - request.txn_id, request.destination, request.method, - url_str, _sec_timeout, + request.txn_id, + request.destination, + request.method, + url_str, + _sec_timeout, ) try: @@ -430,7 +416,7 @@ def _send_request( request.txn_id, request.destination, response.code, - response.phrase.decode('ascii', errors='replace'), + response.phrase.decode("ascii", errors="replace"), ) if 200 <= response.code < 300: @@ -440,9 +426,7 @@ def _send_request( # Update transactions table? d = treq.content(response) d = timeout_deferred( - d, - timeout=_sec_timeout, - reactor=self.reactor, + d, timeout=_sec_timeout, reactor=self.reactor ) try: @@ -460,9 +444,7 @@ def _send_request( ) body = None - e = HttpResponseException( - response.code, response.phrase, body - ) + e = HttpResponseException(response.code, response.phrase, body) # Retry if the error is a 429 (Too Many Requests), # otherwise just raise a standard HttpResponseException @@ -521,7 +503,7 @@ def _send_request( defer.returnValue(response) def build_auth_headers( - self, destination, method, url_bytes, content=None, destination_is=None, + self, destination, method, url_bytes, content=None, destination_is=None ): """ Builds the Authorization headers for a federation request @@ -538,11 +520,7 @@ def build_auth_headers( Returns: list[bytes]: a list of headers to be added as "Authorization:" headers """ - request = { - "method": method, - "uri": url_bytes, - "origin": self.server_name, - } + request = {"method": method, "uri": url_bytes, "origin": self.server_name} if destination is not None: request["destination"] = destination @@ -558,20 +536,28 @@ def build_auth_headers( auth_headers = [] for key, sig in request["signatures"][self.server_name].items(): - auth_headers.append(( - "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( - self.server_name, key, sig, - )).encode('ascii') + auth_headers.append( + ( + 'X-Matrix origin=%s,key="%s",sig="%s"' + % (self.server_name, key, sig) + ).encode("ascii") ) return auth_headers @defer.inlineCallbacks - def put_json(self, destination, path, args={}, data={}, - json_data_callback=None, - long_retries=False, timeout=None, - ignore_backoff=False, - backoff_on_404=False, - try_trailing_slash_on_400=False): + def put_json( + self, + destination, + path, + args={}, + data={}, + json_data_callback=None, + long_retries=False, + timeout=None, + ignore_backoff=False, + backoff_on_404=False, + try_trailing_slash_on_400=False, + ): """ Sends the specifed json data using PUT Args: @@ -635,14 +621,22 @@ def put_json(self, destination, path, args={}, data={}, ) body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def post_json(self, destination, path, data={}, long_retries=False, - timeout=None, ignore_backoff=False, args={}): + def post_json( + self, + destination, + path, + data={}, + long_retries=False, + timeout=None, + ignore_backoff=False, + args={}, + ): """ Sends the specifed json data using POST Args: @@ -681,11 +675,7 @@ def post_json(self, destination, path, data={}, long_retries=False, """ request = MatrixFederationRequest( - method="POST", - destination=destination, - path=path, - query=args, - json=data, + method="POST", destination=destination, path=path, query=args, json=data ) response = yield self._send_request( @@ -701,14 +691,21 @@ def post_json(self, destination, path, data={}, long_retries=False, _sec_timeout = self.default_timeout body = yield _handle_json_response( - self.reactor, _sec_timeout, request, response, + self.reactor, _sec_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def get_json(self, destination, path, args=None, retry_on_dns_fail=True, - timeout=None, ignore_backoff=False, - try_trailing_slash_on_400=False): + def get_json( + self, + destination, + path, + args=None, + retry_on_dns_fail=True, + timeout=None, + ignore_backoff=False, + try_trailing_slash_on_400=False, + ): """ GETs some json from the given host homeserver and path Args: @@ -745,10 +742,7 @@ def get_json(self, destination, path, args=None, retry_on_dns_fail=True, remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( - method="GET", - destination=destination, - path=path, - query=args, + method="GET", destination=destination, path=path, query=args ) response = yield self._send_request_with_optional_trailing_slash( @@ -761,14 +755,21 @@ def get_json(self, destination, path, args=None, retry_on_dns_fail=True, ) body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def delete_json(self, destination, path, long_retries=False, - timeout=None, ignore_backoff=False, args={}): + def delete_json( + self, + destination, + path, + long_retries=False, + timeout=None, + ignore_backoff=False, + args={}, + ): """Send a DELETE request to the remote expecting some json response Args: @@ -802,10 +803,7 @@ def delete_json(self, destination, path, long_retries=False, remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( - method="DELETE", - destination=destination, - path=path, - query=args, + method="DELETE", destination=destination, path=path, query=args ) response = yield self._send_request( @@ -816,14 +814,21 @@ def delete_json(self, destination, path, long_retries=False, ) body = yield _handle_json_response( - self.reactor, self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response ) defer.returnValue(body) @defer.inlineCallbacks - def get_file(self, destination, path, output_stream, args={}, - retry_on_dns_fail=True, max_size=None, - ignore_backoff=False): + def get_file( + self, + destination, + path, + output_stream, + args={}, + retry_on_dns_fail=True, + max_size=None, + ignore_backoff=False, + ): """GETs a file from a given homeserver Args: destination (str): The remote server to send the HTTP request to. @@ -848,16 +853,11 @@ def get_file(self, destination, path, output_stream, args={}, remote, due to e.g. DNS failures, connection timeouts etc. """ request = MatrixFederationRequest( - method="GET", - destination=destination, - path=path, - query=args, + method="GET", destination=destination, path=path, query=args ) response = yield self._send_request( - request, - retry_on_dns_fail=retry_on_dns_fail, - ignore_backoff=ignore_backoff, + request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff ) headers = dict(response.headers.getAllRawHeaders()) @@ -879,7 +879,7 @@ def get_file(self, destination, path, output_stream, args={}, request.txn_id, request.destination, response.code, - response.phrase.decode('ascii', errors='replace'), + response.phrase.decode("ascii", errors="replace"), length, ) defer.returnValue((length, headers)) @@ -896,11 +896,13 @@ def dataReceived(self, data): self.stream.write(data) self.length += len(data) if self.max_size is not None and self.length >= self.max_size: - self.deferred.errback(SynapseError( - 502, - "Requested file is too large > %r bytes" % (self.max_size,), - Codes.TOO_LARGE, - )) + self.deferred.errback( + SynapseError( + 502, + "Requested file is too large > %r bytes" % (self.max_size,), + Codes.TOO_LARGE, + ) + ) self.deferred = defer.Deferred() self.transport.loseConnection() @@ -920,8 +922,7 @@ def _readBodyToFile(response, stream, max_size): def _flatten_response_never_received(e): if hasattr(e, "reasons"): reasons = ", ".join( - _flatten_response_never_received(f.value) - for f in e.reasons + _flatten_response_never_received(f.value) for f in e.reasons ) return "%s:[%s]" % (type(e).__name__, reasons) @@ -943,16 +944,15 @@ def check_content_type_is_json(headers): """ c_type = headers.getRawHeaders(b"Content-Type") if c_type is None: - raise RequestSendFailed(RuntimeError( - "No Content-Type header" - ), can_retry=False) + raise RequestSendFailed(RuntimeError("No Content-Type header"), can_retry=False) - c_type = c_type[0].decode('ascii') # only the first header + c_type = c_type[0].decode("ascii") # only the first header val, options = cgi.parse_header(c_type) if val != "application/json": - raise RequestSendFailed(RuntimeError( - "Content-Type not application/json: was '%s'" % c_type - ), can_retry=False) + raise RequestSendFailed( + RuntimeError("Content-Type not application/json: was '%s'" % c_type), + can_retry=False, + ) def encode_query_args(args): @@ -967,4 +967,4 @@ def encode_query_args(args): query_bytes = urllib.parse.urlencode(encoded_args, True) - return query_bytes.encode('utf8') + return query_bytes.encode("utf8") diff --git a/synapse/http/server.py b/synapse/http/server.py index 16fb7935dad2..6fd13e87d149 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -81,9 +81,7 @@ def wrapped_request_handler(self, request): yield h(self, request) except SynapseError as e: code = e.code - logger.info( - "%s SynapseError: %s - %s", request, code, e.msg - ) + logger.info("%s SynapseError: %s - %s", request, code, e.msg) # Only respond with an error response if we haven't already started # writing, otherwise lets just kill the connection @@ -96,7 +94,10 @@ def wrapped_request_handler(self, request): pass else: respond_with_json( - request, code, e.error_dict(), send_cors=True, + request, + code, + e.error_dict(), + send_cors=True, pretty_print=_request_user_agent_is_curl(request), ) @@ -124,10 +125,7 @@ def wrapped_request_handler(self, request): respond_with_json( request, 500, - { - "error": "Internal server error", - "errcode": Codes.UNKNOWN, - }, + {"error": "Internal server error", "errcode": Codes.UNKNOWN}, send_cors=True, pretty_print=_request_user_agent_is_curl(request), ) @@ -143,6 +141,7 @@ def wrap_html_request_handler(h): The handler method must have a signature of "handle_foo(self, request)", where "request" must be a SynapseRequest. """ + def wrapped_request_handler(self, request): d = defer.maybeDeferred(h, self, request) d.addErrback(_return_html_error, request) @@ -164,9 +163,7 @@ def _return_html_error(f, request): msg = cme.msg if isinstance(cme, SynapseError): - logger.info( - "%s SynapseError: %s - %s", request, code, msg - ) + logger.info("%s SynapseError: %s - %s", request, code, msg) else: logger.error( "Failed handle request %r", @@ -183,9 +180,7 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) - body = HTML_ERROR_TEMPLATE.format( - code=code, msg=cgi.escape(msg), - ).encode("utf-8") + body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8") request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Length", b"%i" % (len(body),)) @@ -205,6 +200,7 @@ def wrap_async_request_handler(h): The handler may return a deferred, in which case the completion of the request isn't logged until the deferred completes. """ + @defer.inlineCallbacks def wrapped_async_request_handler(self, request): with request.processing(): @@ -306,12 +302,14 @@ def _unquote(s): # URL again (as it was decoded by _get_handler_for request), as # ASCII because it's a URL, and then decode it to get the UTF-8 # characters that were quoted. - return urllib.parse.unquote(s.encode('ascii')).decode('utf8') + return urllib.parse.unquote(s.encode("ascii")).decode("utf8") - kwargs = intern_dict({ - name: _unquote(value) if value else value - for name, value in group_dict.items() - }) + kwargs = intern_dict( + { + name: _unquote(value) if value else value + for name, value in group_dict.items() + } + ) callback_return = yield callback(request, **kwargs) if callback_return is not None: @@ -339,7 +337,7 @@ def _get_handler_for_request(self, request): # Loop through all the registered callbacks to check if the method # and path regex match for path_entry in self.path_regexs.get(request.method, []): - m = path_entry.pattern.match(request.path.decode('ascii')) + m = path_entry.pattern.match(request.path.decode("ascii")) if m: # We found a match! return path_entry.callback, m.groupdict() @@ -347,11 +345,14 @@ def _get_handler_for_request(self, request): # Huh. No one wanted to handle that? Fiiiiiine. Send 400. return _unrecognised_request_handler, {} - def _send_response(self, request, code, response_json_object, - response_code_message=None): + def _send_response( + self, request, code, response_json_object, response_code_message=None + ): # TODO: Only enable CORS for the requests that need it. respond_with_json( - request, code, response_json_object, + request, + code, + response_json_object, send_cors=True, response_code_message=response_code_message, pretty_print=_request_user_agent_is_curl(request), @@ -395,7 +396,7 @@ def __init__(self, path): self.url = path def render_GET(self, request): - return redirectTo(self.url.encode('ascii'), request) + return redirectTo(self.url.encode("ascii"), request) def getChild(self, name, request): if len(name) == 0: @@ -403,16 +404,22 @@ def getChild(self, name, request): return resource.Resource.getChild(self, name, request) -def respond_with_json(request, code, json_object, send_cors=False, - response_code_message=None, pretty_print=False, - canonical_json=True): +def respond_with_json( + request, + code, + json_object, + send_cors=False, + response_code_message=None, + pretty_print=False, + canonical_json=True, +): # could alternatively use request.notifyFinish() and flip a flag when # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. if request._disconnected: logger.warn( - "Not sending response to request %s, already disconnected.", - request) + "Not sending response to request %s, already disconnected.", request + ) return if pretty_print: @@ -425,14 +432,17 @@ def respond_with_json(request, code, json_object, send_cors=False, json_bytes = json.dumps(json_object).encode("utf-8") return respond_with_json_bytes( - request, code, json_bytes, + request, + code, + json_bytes, send_cors=send_cors, response_code_message=response_code_message, ) -def respond_with_json_bytes(request, code, json_bytes, send_cors=False, - response_code_message=None): +def respond_with_json_bytes( + request, code, json_bytes, send_cors=False, response_code_message=None +): """Sends encoded JSON in response to the given request. Args: @@ -474,7 +484,7 @@ def set_cors_headers(request): ) request.setHeader( b"Access-Control-Allow-Headers", - b"Origin, X-Requested-With, Content-Type, Accept, Authorization" + b"Origin, X-Requested-With, Content-Type, Accept, Authorization", ) @@ -498,9 +508,7 @@ def finish_request(request): def _request_user_agent_is_curl(request): - user_agents = request.requestHeaders.getRawHeaders( - b"User-Agent", default=[] - ) + user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[]) for user_agent in user_agents: if b"curl" in user_agent: return True diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 197c652850ed..cd8415acd5ab 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -48,7 +48,7 @@ def parse_integer(request, name, default=None, required=False): def parse_integer_from_args(args, name, default=None, required=False): if not isinstance(name, bytes): - name = name.encode('ascii') + name = name.encode("ascii") if name in args: try: @@ -89,18 +89,14 @@ def parse_boolean(request, name, default=None, required=False): def parse_boolean_from_args(args, name, default=None, required=False): if not isinstance(name, bytes): - name = name.encode('ascii') + name = name.encode("ascii") if name in args: try: - return { - b"true": True, - b"false": False, - }[args[name][0]] + return {b"true": True, b"false": False}[args[name][0]] except Exception: message = ( - "Boolean query parameter %r must be one of" - " ['true', 'false']" + "Boolean query parameter %r must be one of" " ['true', 'false']" ) % (name,) raise SynapseError(400, message) else: @@ -111,8 +107,15 @@ def parse_boolean_from_args(args, name, default=None, required=False): return default -def parse_string(request, name, default=None, required=False, - allowed_values=None, param_type="string", encoding='ascii'): +def parse_string( + request, + name, + default=None, + required=False, + allowed_values=None, + param_type="string", + encoding="ascii", +): """ Parse a string parameter from the request query string. @@ -145,11 +148,18 @@ def parse_string(request, name, default=None, required=False, ) -def parse_string_from_args(args, name, default=None, required=False, - allowed_values=None, param_type="string", encoding='ascii'): +def parse_string_from_args( + args, + name, + default=None, + required=False, + allowed_values=None, + param_type="string", + encoding="ascii", +): if not isinstance(name, bytes): - name = name.encode('ascii') + name = name.encode("ascii") if name in args: value = args[name][0] @@ -159,7 +169,8 @@ def parse_string_from_args(args, name, default=None, required=False, if allowed_values is not None and value not in allowed_values: message = "Query parameter %r must be one of [%s]" % ( - name, ", ".join(repr(v) for v in allowed_values) + name, + ", ".join(repr(v) for v in allowed_values), ) raise SynapseError(400, message) else: @@ -201,7 +212,7 @@ def parse_json_value_from_request(request, allow_empty_body=False): # Decode to Unicode so that simplejson will return Unicode strings on # Python 2 try: - content_unicode = content_bytes.decode('utf8') + content_unicode = content_bytes.decode("utf8") except UnicodeDecodeError: logger.warn("Unable to decode UTF-8") raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) @@ -227,9 +238,7 @@ def parse_json_object_from_request(request, allow_empty_body=False): SynapseError if the request body couldn't be decoded as JSON or if it wasn't a JSON object. """ - content = parse_json_value_from_request( - request, allow_empty_body=allow_empty_body, - ) + content = parse_json_value_from_request(request, allow_empty_body=allow_empty_body) if allow_empty_body and content is None: return {} diff --git a/synapse/http/site.py b/synapse/http/site.py index e508c0bd4f2d..93f679ea4845 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -46,10 +46,11 @@ class SynapseRequest(Request): Attributes: logcontext(LoggingContext) : the log context for this request """ + def __init__(self, site, channel, *args, **kw): Request.__init__(self, channel, *args, **kw) self.site = site - self._channel = channel # this is used by the tests + self._channel = channel # this is used by the tests self.authenticated_entity = None self.start_time = 0 @@ -72,12 +73,12 @@ def __init__(self, site, channel, *args, **kw): def __repr__(self): # We overwrite this so that we don't log ``access_token`` - return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % ( + return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % ( self.__class__.__name__, id(self), self.get_method(), self.get_redacted_uri(), - self.clientproto.decode('ascii', errors='replace'), + self.clientproto.decode("ascii", errors="replace"), self.site.site_tag, ) @@ -87,7 +88,7 @@ def get_request_id(self): def get_redacted_uri(self): uri = self.uri if isinstance(uri, bytes): - uri = self.uri.decode('ascii') + uri = self.uri.decode("ascii") return redact_uri(uri) def get_method(self): @@ -102,7 +103,7 @@ def get_method(self): """ method = self.method if isinstance(method, bytes): - method = self.method.decode('ascii') + method = self.method.decode("ascii") return method def get_user_agent(self): @@ -134,8 +135,7 @@ def render(self, resrc): # dispatching to the handler, so that the handler # can update the servlet name in the request # metrics - requests_counter.labels(self.get_method(), - self.request_metrics.name).inc() + requests_counter.labels(self.get_method(), self.request_metrics.name).inc() @contextlib.contextmanager def processing(self): @@ -200,7 +200,7 @@ def connectionLost(self, reason): # the client disconnects. with PreserveLoggingContext(self.logcontext): logger.warn( - "Error processing request %r: %s %s", self, reason.type, reason.value, + "Error processing request %r: %s %s", self, reason.type, reason.value ) if not self._is_processing: @@ -222,7 +222,7 @@ def _started_processing(self, servlet_name): self.start_time = time.time() self.request_metrics = RequestMetrics() self.request_metrics.start( - self.start_time, name=servlet_name, method=self.get_method(), + self.start_time, name=servlet_name, method=self.get_method() ) self.site.access_logger.info( @@ -230,7 +230,7 @@ def _started_processing(self, servlet_name): self.getClientIP(), self.site.site_tag, self.get_method(), - self.get_redacted_uri() + self.get_redacted_uri(), ) def _finished_processing(self): @@ -282,7 +282,7 @@ def _finished_processing(self): self.site.access_logger.info( "%s - %s - {%s}" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" - " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]", + ' %sB %s "%s %s %s" "%s" [%d dbevts]', self.getClientIP(), self.site.site_tag, authenticated_entity, @@ -297,7 +297,7 @@ def _finished_processing(self): code, self.get_method(), self.get_redacted_uri(), - self.clientproto.decode('ascii', errors='replace'), + self.clientproto.decode("ascii", errors="replace"), user_agent, usage.evt_db_fetch_count, ) @@ -316,14 +316,19 @@ def __init__(self, *args, **kw): Add a layer on top of another request that only uses the value of an X-Forwarded-For header as the result of C{getClientIP}. """ + def getClientIP(self): """ @return: The client address (the first address) in the value of the I{X-Forwarded-For header}. If the header is not present, return C{b"-"}. """ - return self.requestHeaders.getRawHeaders( - b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii') + return ( + self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0] + .split(b",")[0] + .strip() + .decode("ascii") + ) class SynapseRequestFactory(object): @@ -343,8 +348,17 @@ class SynapseSite(Site): Subclass of a twisted http Site that does access logging with python's standard logging """ - def __init__(self, logger_name, site_tag, config, resource, - server_version_string, *args, **kwargs): + + def __init__( + self, + logger_name, + site_tag, + config, + resource, + server_version_string, + *args, + **kwargs + ): Site.__init__(self, resource, *args, **kwargs) self.site_tag = site_tag @@ -352,7 +366,7 @@ def __init__(self, logger_name, site_tag, config, resource, proxied = config.get("x_forwarded", False) self.requestFactory = SynapseRequestFactory(self, proxied) self.access_logger = logging.getLogger(logger_name) - self.server_version_string = server_version_string.encode('ascii') + self.server_version_string = server_version_string.encode("ascii") def log(self, request): pass diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 8aee14a8a86a..1f30179b5169 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -231,10 +231,7 @@ def collect(self): res.append(["+Inf", sum(data.values())]) metric = HistogramMetricFamily( - self.name, - "", - buckets=res, - sum_value=sum([x * y for x, y in data.items()]), + self.name, "", buckets=res, sum_value=sum([x * y for x, y in data.items()]) ) yield metric @@ -263,7 +260,7 @@ def __init__(self): ticks_per_sec = 100 try: # Try and get the system config - ticks_per_sec = os.sysconf('SC_CLK_TCK') + ticks_per_sec = os.sysconf("SC_CLK_TCK") except (ValueError, TypeError, AttributeError): pass diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 037f1c490ed4..167e2c068a45 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -60,8 +60,10 @@ _background_process_db_txn_duration = Counter( "synapse_background_process_db_txn_duration_seconds", - ("Seconds spent by background processes waiting for database " - "transactions, excluding scheduling time"), + ( + "Seconds spent by background processes waiting for database " + "transactions, excluding scheduling time" + ), ["name"], registry=None, ) @@ -94,6 +96,7 @@ class _Collector(object): Ensures that all of the metrics are up-to-date with any in-flight processes before they are returned. """ + def collect(self): background_process_in_flight_count = GaugeMetricFamily( "synapse_background_process_in_flight_count", @@ -105,14 +108,11 @@ def collect(self): # We also copy the process lists as that can also change with _bg_metrics_lock: _background_processes_copy = { - k: list(v) - for k, v in six.iteritems(_background_processes) + k: list(v) for k, v in six.iteritems(_background_processes) } for desc, processes in six.iteritems(_background_processes_copy): - background_process_in_flight_count.add_metric( - (desc,), len(processes), - ) + background_process_in_flight_count.add_metric((desc,), len(processes)) for process in processes: process.update_metrics() @@ -121,11 +121,11 @@ def collect(self): # now we need to run collect() over each of the static Counters, and # yield each metric they return. for m in ( - _background_process_ru_utime, - _background_process_ru_stime, - _background_process_db_txn_count, - _background_process_db_txn_duration, - _background_process_db_sched_duration, + _background_process_ru_utime, + _background_process_ru_stime, + _background_process_db_txn_count, + _background_process_db_txn_duration, + _background_process_db_sched_duration, ): for r in m.collect(): yield r @@ -151,14 +151,12 @@ def update_metrics(self): _background_process_ru_utime.labels(self.desc).inc(diff.ru_utime) _background_process_ru_stime.labels(self.desc).inc(diff.ru_stime) - _background_process_db_txn_count.labels(self.desc).inc( - diff.db_txn_count, - ) + _background_process_db_txn_count.labels(self.desc).inc(diff.db_txn_count) _background_process_db_txn_duration.labels(self.desc).inc( - diff.db_txn_duration_sec, + diff.db_txn_duration_sec ) _background_process_db_sched_duration.labels(self.desc).inc( - diff.db_sched_duration_sec, + diff.db_sched_duration_sec ) @@ -182,6 +180,7 @@ def run_as_background_process(desc, func, *args, **kwargs): Returns: Deferred which returns the result of func, but note that it does not follow the synapse logcontext rules. """ + @defer.inlineCallbacks def run(): with _bg_metrics_lock: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b3abd1b3c6d5..bf43ca09be7c 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -21,6 +21,7 @@ class ModuleApi(object): """A proxy object that gets passed to password auth providers so they can register new users etc if necessary. """ + def __init__(self, hs, auth_handler): self.hs = hs @@ -57,7 +58,7 @@ def get_qualified_user_id(self, username): Returns: str: qualified @user:id """ - if username.startswith('@'): + if username.startswith("@"): return username return UserID(username, self.hs.hostname).to_string() @@ -89,8 +90,7 @@ def register(self, localpart, displayname=None, emails=[]): # Register the user reg = self.hs.get_registration_handler() user_id, access_token = yield reg.register( - localpart=localpart, default_display_name=displayname, - bind_emails=emails, + localpart=localpart, default_display_name=displayname, bind_emails=emails ) defer.returnValue((user_id, access_token)) diff --git a/synapse/notifier.py b/synapse/notifier.py index ff589660dae4..d398078eedbb 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -37,7 +37,8 @@ notified_events_counter = Counter("synapse_notifier_notified_events", "") users_woken_by_stream_counter = Counter( - "synapse_notifier_users_woken_by_stream", "", ["stream"]) + "synapse_notifier_users_woken_by_stream", "", ["stream"] +) # TODO(paul): Should be shared somewhere @@ -55,6 +56,7 @@ class _NotificationListener(object): The events stream handler will have yielded to the deferred, so to notify the handler it is sufficient to resolve the deferred. """ + __slots__ = ["deferred"] def __init__(self, deferred): @@ -95,9 +97,7 @@ def notify(self, stream_key, stream_id, time_now_ms): stream_id(str): The new id for the stream the event came from. time_now_ms(int): The current time in milliseconds. """ - self.current_token = self.current_token.copy_and_advance( - stream_key, stream_id - ) + self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.last_notified_token = self.current_token self.last_notified_ms = time_now_ms noify_deferred = self.notify_deferred @@ -141,6 +141,7 @@ def new_listener(self, token): class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): def __nonzero__(self): return bool(self.events) + __bool__ = __nonzero__ # python3 @@ -190,15 +191,17 @@ def count_listeners(): all_user_streams.add(x) return sum(stream.count_listeners() for stream in all_user_streams) + LaterGauge("synapse_notifier_listeners", "", [], count_listeners) LaterGauge( - "synapse_notifier_rooms", "", [], + "synapse_notifier_rooms", + "", + [], lambda: count(bool, list(self.room_to_user_streams.values())), ) LaterGauge( - "synapse_notifier_users", "", [], - lambda: len(self.user_to_user_stream), + "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream) ) def add_replication_callback(self, cb): @@ -209,8 +212,9 @@ def add_replication_callback(self, cb): """ self.replication_callbacks.append(cb) - def on_new_room_event(self, event, room_stream_id, max_room_stream_id, - extra_users=[]): + def on_new_room_event( + self, event, room_stream_id, max_room_stream_id, extra_users=[] + ): """ Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -222,9 +226,7 @@ def on_new_room_event(self, event, room_stream_id, max_room_stream_id, until all previous events have been persisted before notifying the client streams. """ - self.pending_new_room_events.append(( - room_stream_id, event, extra_users - )) + self.pending_new_room_events.append((room_stream_id, event, extra_users)) self._notify_pending_new_room_events(max_room_stream_id) self.notify_replication() @@ -240,9 +242,9 @@ def _notify_pending_new_room_events(self, max_room_stream_id): self.pending_new_room_events = [] for room_stream_id, event, extra_users in pending: if room_stream_id > max_room_stream_id: - self.pending_new_room_events.append(( - room_stream_id, event, extra_users - )) + self.pending_new_room_events.append( + (room_stream_id, event, extra_users) + ) else: self._on_new_room_event(event, room_stream_id, extra_users) @@ -250,8 +252,7 @@ def _on_new_room_event(self, event, room_stream_id, extra_users=[]): """Notify any user streams that are interested in this room event""" # poke any interested application service. run_as_background_process( - "notify_app_services", - self._notify_app_services, room_stream_id, + "notify_app_services", self._notify_app_services, room_stream_id ) if self.federation_sender: @@ -261,9 +262,7 @@ def _on_new_room_event(self, event, room_stream_id, extra_users=[]): self._user_joined_room(event.state_key, event.room_id) self.on_new_event( - "room_key", room_stream_id, - users=extra_users, - rooms=[event.room_id], + "room_key", room_stream_id, users=extra_users, rooms=[event.room_id] ) @defer.inlineCallbacks @@ -305,8 +304,9 @@ def on_new_replication_data(self): self.notify_replication() @defer.inlineCallbacks - def wait_for_events(self, user_id, timeout, callback, room_ids=None, - from_token=StreamToken.START): + def wait_for_events( + self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START + ): """Wait until the callback returns a non empty response or the timeout fires. """ @@ -339,7 +339,7 @@ def wait_for_events(self, user_id, timeout, callback, room_ids=None, listener = user_stream.new_listener(prev_token) listener.deferred = timeout_deferred( listener.deferred, - (end_time - now) / 1000., + (end_time - now) / 1000.0, self.hs.get_reactor(), ) with PreserveLoggingContext(): @@ -368,9 +368,15 @@ def wait_for_events(self, user_id, timeout, callback, room_ids=None, defer.returnValue(result) @defer.inlineCallbacks - def get_events_for(self, user, pagination_config, timeout, - only_keys=None, - is_guest=False, explicit_room_id=None): + def get_events_for( + self, + user, + pagination_config, + timeout, + only_keys=None, + is_guest=False, + explicit_room_id=None, + ): """ For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any new events to happen before returning. @@ -419,10 +425,7 @@ def check_for_updates(before_token, after_token): if name == "room": new_events = yield filter_events_for_client( - self.store, - user.to_string(), - new_events, - is_peeking=is_peeking, + self.store, user.to_string(), new_events, is_peeking=is_peeking ) elif name == "presence": now = self.clock.time_msec() @@ -450,7 +453,8 @@ def check_for_updates(before_token, after_token): # # I am sorry for what I have done. user_id_for_stream = "_PEEKING_%s_%s" % ( - explicit_room_id, user_id_for_stream + explicit_room_id, + user_id_for_stream, ) result = yield self.wait_for_events( @@ -477,9 +481,7 @@ def _get_room_ids(self, user, explicit_room_id): @defer.inlineCallbacks def _is_world_readable(self, room_id): state = yield self.state_handler.get_current_state( - room_id, - EventTypes.RoomHistoryVisibility, - "", + room_id, EventTypes.RoomHistoryVisibility, "" ) if state and "history_visibility" in state.content: defer.returnValue(state.content["history_visibility"] == "world_readable") diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index a5de75c48ac5..1ffd5e2df352 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -40,6 +40,4 @@ def __init__(self, hs): @defer.inlineCallbacks def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "action_for_event_by_user"): - yield self.bulk_evaluator.action_for_event_by_user( - event, context - ) + yield self.bulk_evaluator.action_for_event_by_user(event, context) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 3523a40108d9..96d087de226a 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -31,48 +31,54 @@ def list_with_base_rules(rawrules): # Grab the base rules that the user has modified. # The modified base rules have a priority_class of -1. - modified_base_rules = { - r['rule_id']: r for r in rawrules if r['priority_class'] < 0 - } + modified_base_rules = {r["rule_id"]: r for r in rawrules if r["priority_class"] < 0} # Remove the modified base rules from the list, They'll be added back # in the default postions in the list. - rawrules = [r for r in rawrules if r['priority_class'] >= 0] + rawrules = [r for r in rawrules if r["priority_class"] >= 0] # shove the server default rules for each kind onto the end of each current_prio_class = list(PRIORITY_CLASS_INVERSE_MAP)[-1] - ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules - )) + ruleslist.extend( + make_base_prepend_rules( + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + ) + ) for r in rawrules: - if r['priority_class'] < current_prio_class: - while r['priority_class'] < current_prio_class: - ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], - modified_base_rules, - )) - current_prio_class -= 1 - if current_prio_class > 0: - ruleslist.extend(make_base_prepend_rules( + if r["priority_class"] < current_prio_class: + while r["priority_class"] < current_prio_class: + ruleslist.extend( + make_base_append_rules( PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules, - )) + ) + ) + current_prio_class -= 1 + if current_prio_class > 0: + ruleslist.extend( + make_base_prepend_rules( + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, + ) + ) ruleslist.append(r) while current_prio_class > 0: - ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], - modified_base_rules, - )) + ruleslist.extend( + make_base_append_rules( + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + ) + ) current_prio_class -= 1 if current_prio_class > 0: - ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], - modified_base_rules, - )) + ruleslist.extend( + make_base_prepend_rules( + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + ) + ) return ruleslist @@ -80,20 +86,20 @@ def list_with_base_rules(rawrules): def make_base_append_rules(kind, modified_base_rules): rules = [] - if kind == 'override': + if kind == "override": rules = BASE_APPEND_OVERRIDE_RULES - elif kind == 'underride': + elif kind == "underride": rules = BASE_APPEND_UNDERRIDE_RULES - elif kind == 'content': + elif kind == "content": rules = BASE_APPEND_CONTENT_RULES # Copy the rules before modifying them rules = copy.deepcopy(rules) for r in rules: # Only modify the actions, keep the conditions the same. - modified = modified_base_rules.get(r['rule_id']) + modified = modified_base_rules.get(r["rule_id"]) if modified: - r['actions'] = modified['actions'] + r["actions"] = modified["actions"] return rules @@ -101,103 +107,86 @@ def make_base_append_rules(kind, modified_base_rules): def make_base_prepend_rules(kind, modified_base_rules): rules = [] - if kind == 'override': + if kind == "override": rules = BASE_PREPEND_OVERRIDE_RULES # Copy the rules before modifying them rules = copy.deepcopy(rules) for r in rules: # Only modify the actions, keep the conditions the same. - modified = modified_base_rules.get(r['rule_id']) + modified = modified_base_rules.get(r["rule_id"]) if modified: - r['actions'] = modified['actions'] + r["actions"] = modified["actions"] return rules BASE_APPEND_CONTENT_RULES = [ { - 'rule_id': 'global/content/.m.rule.contains_user_name', - 'conditions': [ + "rule_id": "global/content/.m.rule.contains_user_name", + "conditions": [ { - 'kind': 'event_match', - 'key': 'content.body', - 'pattern_type': 'user_localpart' + "kind": "event_match", + "key": "content.body", + "pattern_type": "user_localpart", } ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default', - }, { - 'set_tweak': 'highlight' - } - ] - }, + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, + ], + } ] BASE_PREPEND_OVERRIDE_RULES = [ { - 'rule_id': 'global/override/.m.rule.master', - 'enabled': False, - 'conditions': [], - 'actions': [ - "dont_notify" - ] + "rule_id": "global/override/.m.rule.master", + "enabled": False, + "conditions": [], + "actions": ["dont_notify"], } ] BASE_APPEND_OVERRIDE_RULES = [ { - 'rule_id': 'global/override/.m.rule.suppress_notices', - 'conditions': [ + "rule_id": "global/override/.m.rule.suppress_notices", + "conditions": [ { - 'kind': 'event_match', - 'key': 'content.msgtype', - 'pattern': 'm.notice', - '_id': '_suppress_notices', + "kind": "event_match", + "key": "content.msgtype", + "pattern": "m.notice", + "_id": "_suppress_notices", } ], - 'actions': [ - 'dont_notify', - ] + "actions": ["dont_notify"], }, # NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event # otherwise invites will be matched by .m.rule.member_event { - 'rule_id': 'global/override/.m.rule.invite_for_me', - 'conditions': [ + "rule_id": "global/override/.m.rule.invite_for_me", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.member', - '_id': '_member', + "kind": "event_match", + "key": "type", + "pattern": "m.room.member", + "_id": "_member", }, { - 'kind': 'event_match', - 'key': 'content.membership', - 'pattern': 'invite', - '_id': '_invite_member', - }, - { - 'kind': 'event_match', - 'key': 'state_key', - 'pattern_type': 'user_id' + "kind": "event_match", + "key": "content.membership", + "pattern": "invite", + "_id": "_invite_member", }, + {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, + ], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight", "value": False}, ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] }, # Will we sometimes want to know about people joining and leaving? # Perhaps: if so, this could be expanded upon. Seems the most usual case @@ -206,217 +195,164 @@ def make_base_prepend_rules(kind, modified_base_rules): # join/leave/avatar/displayname events. # See also: https://matrix.org/jira/browse/SYN-607 { - 'rule_id': 'global/override/.m.rule.member_event', - 'conditions': [ + "rule_id": "global/override/.m.rule.member_event", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.member', - '_id': '_member', + "kind": "event_match", + "key": "type", + "pattern": "m.room.member", + "_id": "_member", } ], - 'actions': [ - 'dont_notify' - ] + "actions": ["dont_notify"], }, # This was changed from underride to override so it's closer in priority # to the content rules where the user name highlight rule lives. This # way a room rule is lower priority than both but a custom override rule # is higher priority than both. { - 'rule_id': 'global/override/.m.rule.contains_display_name', - 'conditions': [ - { - 'kind': 'contains_display_name' - } + "rule_id": "global/override/.m.rule.contains_display_name", + "conditions": [{"kind": "contains_display_name"}], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight' - } - ] }, { - 'rule_id': 'global/override/.m.rule.roomnotif', - 'conditions': [ + "rule_id": "global/override/.m.rule.roomnotif", + "conditions": [ { - 'kind': 'event_match', - 'key': 'content.body', - 'pattern': '@room', - '_id': '_roomnotif_content', + "kind": "event_match", + "key": "content.body", + "pattern": "@room", + "_id": "_roomnotif_content", }, { - 'kind': 'sender_notification_permission', - 'key': 'room', - '_id': '_roomnotif_pl', + "kind": "sender_notification_permission", + "key": "room", + "_id": "_roomnotif_pl", }, ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': True, - } - ] + "actions": ["notify", {"set_tweak": "highlight", "value": True}], }, { - 'rule_id': 'global/override/.m.rule.tombstone', - 'conditions': [ + "rule_id": "global/override/.m.rule.tombstone", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.tombstone', - '_id': '_tombstone', + "kind": "event_match", + "key": "type", + "pattern": "m.room.tombstone", + "_id": "_tombstone", } ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': True, - } - ] - } + "actions": ["notify", {"set_tweak": "highlight", "value": True}], + }, ] BASE_APPEND_UNDERRIDE_RULES = [ { - 'rule_id': 'global/underride/.m.rule.call', - 'conditions': [ + "rule_id": "global/underride/.m.rule.call", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.call.invite', - '_id': '_call', + "kind": "event_match", + "key": "type", + "pattern": "m.call.invite", + "_id": "_call", } ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'ring' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] + "actions": [ + "notify", + {"set_tweak": "sound", "value": "ring"}, + {"set_tweak": "highlight", "value": False}, + ], }, # XXX: once m.direct is standardised everywhere, we should use it to detect # a DM from the user's perspective rather than this heuristic. { - 'rule_id': 'global/underride/.m.rule.room_one_to_one', - 'conditions': [ + "rule_id": "global/underride/.m.rule.room_one_to_one", + "conditions": [ + {"kind": "room_member_count", "is": "2", "_id": "member_count"}, { - 'kind': 'room_member_count', - 'is': '2', - '_id': 'member_count', + "kind": "event_match", + "key": "type", + "pattern": "m.room.message", + "_id": "_message", }, - { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.message', - '_id': '_message', - } ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight", "value": False}, + ], }, # XXX: this is going to fire for events which aren't m.room.messages # but are encrypted (e.g. m.call.*)... { - 'rule_id': 'global/underride/.m.rule.encrypted_room_one_to_one', - 'conditions': [ + "rule_id": "global/underride/.m.rule.encrypted_room_one_to_one", + "conditions": [ + {"kind": "room_member_count", "is": "2", "_id": "member_count"}, { - 'kind': 'room_member_count', - 'is': '2', - '_id': 'member_count', + "kind": "event_match", + "key": "type", + "pattern": "m.room.encrypted", + "_id": "_encrypted", }, - { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.encrypted', - '_id': '_encrypted', - } ], - 'actions': [ - 'notify', - { - 'set_tweak': 'sound', - 'value': 'default' - }, { - 'set_tweak': 'highlight', - 'value': False - } - ] + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight", "value": False}, + ], }, { - 'rule_id': 'global/underride/.m.rule.message', - 'conditions': [ + "rule_id": "global/underride/.m.rule.message", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.message', - '_id': '_message', + "kind": "event_match", + "key": "type", + "pattern": "m.room.message", + "_id": "_message", } ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': False - } - ] + "actions": ["notify", {"set_tweak": "highlight", "value": False}], }, # XXX: this is going to fire for events which aren't m.room.messages # but are encrypted (e.g. m.call.*)... { - 'rule_id': 'global/underride/.m.rule.encrypted', - 'conditions': [ + "rule_id": "global/underride/.m.rule.encrypted", + "conditions": [ { - 'kind': 'event_match', - 'key': 'type', - 'pattern': 'm.room.encrypted', - '_id': '_encrypted', + "kind": "event_match", + "key": "type", + "pattern": "m.room.encrypted", + "_id": "_encrypted", } ], - 'actions': [ - 'notify', { - 'set_tweak': 'highlight', - 'value': False - } - ] - } + "actions": ["notify", {"set_tweak": "highlight", "value": False}], + }, ] BASE_RULE_IDS = set() for r in BASE_APPEND_CONTENT_RULES: - r['priority_class'] = PRIORITY_CLASS_MAP['content'] - r['default'] = True - BASE_RULE_IDS.add(r['rule_id']) + r["priority_class"] = PRIORITY_CLASS_MAP["content"] + r["default"] = True + BASE_RULE_IDS.add(r["rule_id"]) for r in BASE_PREPEND_OVERRIDE_RULES: - r['priority_class'] = PRIORITY_CLASS_MAP['override'] - r['default'] = True - BASE_RULE_IDS.add(r['rule_id']) + r["priority_class"] = PRIORITY_CLASS_MAP["override"] + r["default"] = True + BASE_RULE_IDS.add(r["rule_id"]) for r in BASE_APPEND_OVERRIDE_RULES: - r['priority_class'] = PRIORITY_CLASS_MAP['override'] - r['default'] = True - BASE_RULE_IDS.add(r['rule_id']) + r["priority_class"] = PRIORITY_CLASS_MAP["override"] + r["default"] = True + BASE_RULE_IDS.add(r["rule_id"]) for r in BASE_APPEND_UNDERRIDE_RULES: - r['priority_class'] = PRIORITY_CLASS_MAP['underride'] - r['default'] = True - BASE_RULE_IDS.add(r['rule_id']) + r["priority_class"] = PRIORITY_CLASS_MAP["underride"] + r["default"] = True + BASE_RULE_IDS.add(r["rule_id"]) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8f9a76147f97..c8a5b381daec 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -39,9 +39,11 @@ push_rules_invalidation_counter = Counter( - "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "") + "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "" +) push_rules_state_size_counter = Counter( - "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", "") + "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", "" +) # Measures whether we use the fast path of using state deltas, or if we have to # recalculate from scratch @@ -83,7 +85,7 @@ def _get_rules_for_event(self, event, context): # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited - if event.type == 'm.room.member' and event.content['membership'] == 'invite': + if event.type == "m.room.member" and event.content["membership"] == "invite": invited = event.state_key if invited and self.hs.is_mine_id(invited): has_pusher = yield self.store.user_has_pusher(invited) @@ -106,7 +108,9 @@ def _get_rules_for_room(self, room_id): # before any lookup methods get called on it as otherwise there may be # a race if invalidate_all gets called (which assumes its in the cache) return RulesForRoom( - self.hs, room_id, self._get_rules_for_room.cache, + self.hs, + room_id, + self._get_rules_for_room.cache, self.room_push_rule_cache_metrics, ) @@ -121,12 +125,10 @@ def _get_power_levels_and_sender_level(self, event, context): auth_events = {POWER_KEY: pl_event} else: auth_events_ids = yield self.auth.compute_auth_events( - event, prev_state_ids, for_verification=False, + event, prev_state_ids, for_verification=False ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = { - (e.type, e.state_key): e for e in itervalues(auth_events) - } + auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} sender_level = get_user_power_level(event.sender, auth_events) @@ -145,16 +147,14 @@ def action_for_event_by_user(self, event, context): rules_by_user = yield self._get_rules_for_event(event, context) actions_by_user = {} - room_members = yield self.store.get_joined_users_from_context( - event, context - ) + room_members = yield self.store.get_joined_users_from_context(event, context) (power_levels, sender_power_level) = ( yield self._get_power_levels_and_sender_level(event, context) ) evaluator = PushRuleEvaluatorForEvent( - event, len(room_members), sender_power_level, power_levels, + event, len(room_members), sender_power_level, power_levels ) condition_cache = {} @@ -180,15 +180,15 @@ def action_for_event_by_user(self, event, context): display_name = event.content.get("displayname", None) for rule in rules: - if 'enabled' in rule and not rule['enabled']: + if "enabled" in rule and not rule["enabled"]: continue matches = _condition_checker( - evaluator, rule['conditions'], uid, display_name, condition_cache + evaluator, rule["conditions"], uid, display_name, condition_cache ) if matches: - actions = [x for x in rule['actions'] if x != 'dont_notify'] - if actions and 'notify' in actions: + actions = [x for x in rule["actions"] if x != "dont_notify"] + if actions and "notify" in actions: # Push rules say we should notify the user of this event actions_by_user[uid] = actions break @@ -196,9 +196,7 @@ def action_for_event_by_user(self, event, context): # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) - yield self.store.add_push_actions_to_staging( - event.event_id, actions_by_user, - ) + yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user) def _condition_checker(evaluator, conditions, uid, display_name, cache): @@ -361,19 +359,19 @@ def get_rules(self, event, context): self.sequence, members={}, # There were no membership changes rules_by_user=ret_rules_by_user, - state_group=state_group + state_group=state_group, ) if logger.isEnabledFor(logging.DEBUG): logger.debug( - "Returning push rules for %r %r", - self.room_id, ret_rules_by_user.keys(), + "Returning push rules for %r %r", self.room_id, ret_rules_by_user.keys() ) defer.returnValue(ret_rules_by_user) @defer.inlineCallbacks - def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids, - state_group, event): + def _update_rules_with_member_event_ids( + self, ret_rules_by_user, member_event_ids, state_group, event + ): """Update the partially filled rules_by_user dict by fetching rules for any newly joined users in the `member_event_ids` list. @@ -391,16 +389,13 @@ def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_id table="room_memberships", column="event_id", iterable=member_event_ids.values(), - retcols=('user_id', 'membership', 'event_id'), + retcols=("user_id", "membership", "event_id"), keyvalues={}, batch_size=500, desc="_get_rules_for_member_event_ids", ) - members = { - row["event_id"]: (row["user_id"], row["membership"]) - for row in rows - } + members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} # If the event is a join event then it will be in current state evnts # map but not in the DB, so we have to explicitly insert it. @@ -413,15 +408,15 @@ def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_id logger.debug("Found members %r: %r", self.room_id, members.values()) interested_in_user_ids = set( - user_id for user_id, membership in itervalues(members) + user_id + for user_id, membership in itervalues(members) if membership == Membership.JOIN ) logger.debug("Joined: %r", interested_in_user_ids) if_users_with_pushers = yield self.store.get_if_users_have_pushers( - interested_in_user_ids, - on_invalidate=self.invalidate_all_cb, + interested_in_user_ids, on_invalidate=self.invalidate_all_cb ) user_ids = set( @@ -431,7 +426,7 @@ def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_id logger.debug("With pushers: %r", user_ids) users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( - self.room_id, on_invalidate=self.invalidate_all_cb, + self.room_id, on_invalidate=self.invalidate_all_cb ) logger.debug("With receipts: %r", users_with_receipts) @@ -442,7 +437,7 @@ def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_id user_ids.add(uid) rules_by_user = yield self.store.bulk_get_push_rules( - user_ids, on_invalidate=self.invalidate_all_cb, + user_ids, on_invalidate=self.invalidate_all_cb ) ret_rules_by_user.update( diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index 8bd96b1178fd..a59b639f1586 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -25,14 +25,14 @@ def format_push_rules_for_user(user, ruleslist): # We're going to be mutating this a lot, so do a deep copy ruleslist = copy.deepcopy(ruleslist) - rules = {'global': {}, 'device': {}} + rules = {"global": {}, "device": {}} - rules['global'] = _add_empty_priority_class_arrays(rules['global']) + rules["global"] = _add_empty_priority_class_arrays(rules["global"]) for r in ruleslist: rulearray = None - template_name = _priority_class_to_template_name(r['priority_class']) + template_name = _priority_class_to_template_name(r["priority_class"]) # Remove internal stuff. for c in r["conditions"]: @@ -44,14 +44,14 @@ def format_push_rules_for_user(user, ruleslist): elif pattern_type == "user_localpart": c["pattern"] = user.localpart - rulearray = rules['global'][template_name] + rulearray = rules["global"][template_name] template_rule = _rule_to_template(r) if template_rule: - if 'enabled' in r: - template_rule['enabled'] = r['enabled'] + if "enabled" in r: + template_rule["enabled"] = r["enabled"] else: - template_rule['enabled'] = True + template_rule["enabled"] = True rulearray.append(template_rule) return rules @@ -65,33 +65,33 @@ def _add_empty_priority_class_arrays(d): def _rule_to_template(rule): unscoped_rule_id = None - if 'rule_id' in rule: - unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id']) + if "rule_id" in rule: + unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"]) - template_name = _priority_class_to_template_name(rule['priority_class']) - if template_name in ['override', 'underride']: + template_name = _priority_class_to_template_name(rule["priority_class"]) + if template_name in ["override", "underride"]: templaterule = {k: rule[k] for k in ["conditions", "actions"]} elif template_name in ["sender", "room"]: - templaterule = {'actions': rule['actions']} - unscoped_rule_id = rule['conditions'][0]['pattern'] - elif template_name == 'content': + templaterule = {"actions": rule["actions"]} + unscoped_rule_id = rule["conditions"][0]["pattern"] + elif template_name == "content": if len(rule["conditions"]) != 1: return None thecond = rule["conditions"][0] if "pattern" not in thecond: return None - templaterule = {'actions': rule['actions']} + templaterule = {"actions": rule["actions"]} templaterule["pattern"] = thecond["pattern"] if unscoped_rule_id: - templaterule['rule_id'] = unscoped_rule_id - if 'default' in rule: - templaterule['default'] = rule['default'] + templaterule["rule_id"] = unscoped_rule_id + if "default" in rule: + templaterule["default"] = rule["default"] return templaterule def _rule_id_from_namespaced(in_rule_id): - return in_rule_id.split('/')[-1] + return in_rule_id.split("/")[-1] def _priority_class_to_template_name(pc): diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index c89a8438a938..424ffa8b682c 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -32,13 +32,13 @@ THROTTLE_START_MS = 10 * 60 * 1000 THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # 24h # THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours -THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day +THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day # If no event triggers a notification for this long after the previous, # the throttle is released. # 12 hours - a gap of 12 hours in conversation is surely enough to merit a new # notification when things get going again... -THROTTLE_RESET_AFTER_MS = (12 * 60 * 60 * 1000) +THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000 # does each email include all unread notifs, or just the ones which have happened # since the last mail? @@ -53,17 +53,18 @@ class EmailPusher(object): This shares quite a bit of code with httpusher: it would be good to factor out the common parts """ + def __init__(self, hs, pusherdict, mailer): self.hs = hs self.mailer = mailer self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self.pusher_id = pusherdict['id'] - self.user_id = pusherdict['user_name'] - self.app_id = pusherdict['app_id'] - self.email = pusherdict['pushkey'] - self.last_stream_ordering = pusherdict['last_stream_ordering'] + self.pusher_id = pusherdict["id"] + self.user_id = pusherdict["user_name"] + self.app_id = pusherdict["app_id"] + self.email = pusherdict["pushkey"] + self.last_stream_ordering = pusherdict["last_stream_ordering"] self.timed_call = None self.throttle_params = None @@ -93,7 +94,9 @@ def on_stop(self): def on_new_notifications(self, min_stream_ordering, max_stream_ordering): if self.max_stream_ordering: - self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) + self.max_stream_ordering = max( + max_stream_ordering, self.max_stream_ordering + ) else: self.max_stream_ordering = max_stream_ordering self._start_processing() @@ -174,14 +177,12 @@ def _unsafe_process(self): return for push_action in unprocessed: - received_at = push_action['received_ts'] + received_at = push_action["received_ts"] if received_at is None: received_at = 0 notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS - room_ready_at = self.room_ready_to_notify_at( - push_action['room_id'] - ) + room_ready_at = self.room_ready_to_notify_at(push_action["room_id"]) should_notify_at = max(notif_ready_at, room_ready_at) @@ -192,25 +193,23 @@ def _unsafe_process(self): # to be delivered. reason = { - 'room_id': push_action['room_id'], - 'now': self.clock.time_msec(), - 'received_at': received_at, - 'delay_before_mail_ms': DELAY_BEFORE_MAIL_MS, - 'last_sent_ts': self.get_room_last_sent_ts(push_action['room_id']), - 'throttle_ms': self.get_room_throttle_ms(push_action['room_id']), + "room_id": push_action["room_id"], + "now": self.clock.time_msec(), + "received_at": received_at, + "delay_before_mail_ms": DELAY_BEFORE_MAIL_MS, + "last_sent_ts": self.get_room_last_sent_ts(push_action["room_id"]), + "throttle_ms": self.get_room_throttle_ms(push_action["room_id"]), } yield self.send_notification(unprocessed, reason) - yield self.save_last_stream_ordering_and_success(max([ - ea['stream_ordering'] for ea in unprocessed - ])) + yield self.save_last_stream_ordering_and_success( + max([ea["stream_ordering"] for ea in unprocessed]) + ) # we update the throttle on all the possible unprocessed push actions for ea in unprocessed: - yield self.sent_notif_update_throttle( - ea['room_id'], ea - ) + yield self.sent_notif_update_throttle(ea["room_id"], ea) break else: if soonest_due_at is None or should_notify_at < soonest_due_at: @@ -236,8 +235,11 @@ def save_last_stream_ordering_and_success(self, last_stream_ordering): self.last_stream_ordering = last_stream_ordering yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, self.email, self.user_id, - last_stream_ordering, self.clock.time_msec() + self.app_id, + self.email, + self.user_id, + last_stream_ordering, + self.clock.time_msec(), ) def seconds_until(self, ts_msec): @@ -276,10 +278,10 @@ def sent_notif_update_throttle(self, room_id, notified_push_action): # THROTTLE_RESET_AFTER_MS after the previous one that triggered a # notif, we release the throttle. Otherwise, the throttle is increased. time_of_previous_notifs = yield self.store.get_time_of_last_push_action_before( - notified_push_action['stream_ordering'] + notified_push_action["stream_ordering"] ) - time_of_this_notifs = notified_push_action['received_ts'] + time_of_this_notifs = notified_push_action["received_ts"] if time_of_previous_notifs is not None and time_of_this_notifs is not None: gap = time_of_this_notifs - time_of_previous_notifs @@ -298,12 +300,11 @@ def sent_notif_update_throttle(self, room_id, notified_push_action): new_throttle_ms = THROTTLE_START_MS else: new_throttle_ms = min( - current_throttle_ms * THROTTLE_MULTIPLIER, - THROTTLE_MAX_MS + current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS ) self.throttle_params[room_id] = { "last_sent_ts": self.clock.time_msec(), - "throttle_ms": new_throttle_ms + "throttle_ms": new_throttle_ms, } yield self.store.set_throttle_params( self.pusher_id, room_id, self.throttle_params[room_id] diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index fac05aa44c08..4e7b6a553124 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -65,16 +65,16 @@ def __init__(self, hs, pusherdict): self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() self.state_handler = self.hs.get_state_handler() - self.user_id = pusherdict['user_name'] - self.app_id = pusherdict['app_id'] - self.app_display_name = pusherdict['app_display_name'] - self.device_display_name = pusherdict['device_display_name'] - self.pushkey = pusherdict['pushkey'] - self.pushkey_ts = pusherdict['ts'] - self.data = pusherdict['data'] - self.last_stream_ordering = pusherdict['last_stream_ordering'] + self.user_id = pusherdict["user_name"] + self.app_id = pusherdict["app_id"] + self.app_display_name = pusherdict["app_display_name"] + self.device_display_name = pusherdict["device_display_name"] + self.pushkey = pusherdict["pushkey"] + self.pushkey_ts = pusherdict["ts"] + self.data = pusherdict["data"] + self.last_stream_ordering = pusherdict["last_stream_ordering"] self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.failing_since = pusherdict['failing_since'] + self.failing_since = pusherdict["failing_since"] self.timed_call = None self._is_processing = False @@ -85,32 +85,26 @@ def __init__(self, hs, pusherdict): # off as None though as we don't know any better. self.max_stream_ordering = None - if 'data' not in pusherdict: - raise PusherConfigException( - "No 'data' key for HTTP pusher" - ) - self.data = pusherdict['data'] + if "data" not in pusherdict: + raise PusherConfigException("No 'data' key for HTTP pusher") + self.data = pusherdict["data"] self.name = "%s/%s/%s" % ( - pusherdict['user_name'], - pusherdict['app_id'], - pusherdict['pushkey'], + pusherdict["user_name"], + pusherdict["app_id"], + pusherdict["pushkey"], ) if self.data is None: - raise PusherConfigException( - "data can not be null for HTTP pusher" - ) + raise PusherConfigException("data can not be null for HTTP pusher") - if 'url' not in self.data: - raise PusherConfigException( - "'url' required in data for HTTP pusher" - ) - self.url = self.data['url'] + if "url" not in self.data: + raise PusherConfigException("'url' required in data for HTTP pusher") + self.url = self.data["url"] self.http_client = hs.get_simple_http_client() self.data_minus_url = {} self.data_minus_url.update(self.data) - del self.data_minus_url['url'] + del self.data_minus_url["url"] def on_started(self, should_check_for_notifs): """Called when this pusher has been started. @@ -124,7 +118,9 @@ def on_started(self, should_check_for_notifs): self._start_processing() def on_new_notifications(self, min_stream_ordering, max_stream_ordering): - self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0) + self.max_stream_ordering = max( + max_stream_ordering, self.max_stream_ordering or 0 + ) self._start_processing() def on_new_receipts(self, min_stream_id, max_stream_id): @@ -192,7 +188,9 @@ def _unsafe_process(self): logger.info( "Processing %i unprocessed push actions for %s starting at " "stream_ordering %s", - len(unprocessed), self.name, self.last_stream_ordering, + len(unprocessed), + self.name, + self.last_stream_ordering, ) for push_action in unprocessed: @@ -200,71 +198,72 @@ def _unsafe_process(self): if processed: http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.last_stream_ordering = push_action['stream_ordering'] + self.last_stream_ordering = push_action["stream_ordering"] yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, self.pushkey, self.user_id, + self.app_id, + self.pushkey, + self.user_id, self.last_stream_ordering, - self.clock.time_msec() + self.clock.time_msec(), ) if self.failing_since: self.failing_since = None yield self.store.update_pusher_failing_since( - self.app_id, self.pushkey, self.user_id, - self.failing_since + self.app_id, self.pushkey, self.user_id, self.failing_since ) else: http_push_failed_counter.inc() if not self.failing_since: self.failing_since = self.clock.time_msec() yield self.store.update_pusher_failing_since( - self.app_id, self.pushkey, self.user_id, - self.failing_since + self.app_id, self.pushkey, self.user_id, self.failing_since ) if ( - self.failing_since and - self.failing_since < - self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS + self.failing_since + and self.failing_since + < self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS ): # we really only give up so that if the URL gets # fixed, we don't suddenly deliver a load # of old notifications. - logger.warn("Giving up on a notification to user %s, " - "pushkey %s", - self.user_id, self.pushkey) + logger.warn( + "Giving up on a notification to user %s, " "pushkey %s", + self.user_id, + self.pushkey, + ) self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC - self.last_stream_ordering = push_action['stream_ordering'] + self.last_stream_ordering = push_action["stream_ordering"] yield self.store.update_pusher_last_stream_ordering( self.app_id, self.pushkey, self.user_id, - self.last_stream_ordering + self.last_stream_ordering, ) self.failing_since = None yield self.store.update_pusher_failing_since( - self.app_id, - self.pushkey, - self.user_id, - self.failing_since + self.app_id, self.pushkey, self.user_id, self.failing_since ) else: logger.info("Push failed: delaying for %ds", self.backoff_delay) self.timed_call = self.hs.get_reactor().callLater( self.backoff_delay, self.on_timer ) - self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC) + self.backoff_delay = min( + self.backoff_delay * 2, self.MAX_BACKOFF_SEC + ) break @defer.inlineCallbacks def _process_one(self, push_action): - if 'notify' not in push_action['actions']: + if "notify" not in push_action["actions"]: defer.returnValue(True) - tweaks = push_rule_evaluator.tweaks_for_actions(push_action['actions']) + tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - event = yield self.store.get_event(push_action['event_id'], allow_none=True) + event = yield self.store.get_event(push_action["event_id"], allow_none=True) if event is None: defer.returnValue(True) # It's been redacted rejected = yield self.dispatch_push(event, tweaks, badge) @@ -277,37 +276,30 @@ def _process_one(self, push_action): # for sanity, we only remove the pushkey if it # was the one we actually sent... logger.warn( - ("Ignoring rejected pushkey %s because we" - " didn't send it"), pk + ("Ignoring rejected pushkey %s because we" " didn't send it"), + pk, ) else: - logger.info( - "Pushkey %s was rejected: removing", - pk - ) - yield self.hs.remove_pusher( - self.app_id, pk, self.user_id - ) + logger.info("Pushkey %s was rejected: removing", pk) + yield self.hs.remove_pusher(self.app_id, pk, self.user_id) defer.returnValue(True) @defer.inlineCallbacks def _build_notification_dict(self, event, tweaks, badge): - if self.data.get('format') == 'event_id_only': + if self.data.get("format") == "event_id_only": d = { - 'notification': { - 'event_id': event.event_id, - 'room_id': event.room_id, - 'counts': { - 'unread': badge, - }, - 'devices': [ + "notification": { + "event_id": event.event_id, + "room_id": event.room_id, + "counts": {"unread": badge}, + "devices": [ { - 'app_id': self.app_id, - 'pushkey': self.pushkey, - 'pushkey_ts': long(self.pushkey_ts / 1000), - 'data': self.data_minus_url, + "app_id": self.app_id, + "pushkey": self.pushkey, + "pushkey_ts": long(self.pushkey_ts / 1000), + "data": self.data_minus_url, } - ] + ], } } defer.returnValue(d) @@ -317,41 +309,41 @@ def _build_notification_dict(self, event, tweaks, badge): ) d = { - 'notification': { - 'id': event.event_id, # deprecated: remove soon - 'event_id': event.event_id, - 'room_id': event.room_id, - 'type': event.type, - 'sender': event.user_id, - 'counts': { # -- we don't mark messages as read yet so - # we have no way of knowing + "notification": { + "id": event.event_id, # deprecated: remove soon + "event_id": event.event_id, + "room_id": event.room_id, + "type": event.type, + "sender": event.user_id, + "counts": { # -- we don't mark messages as read yet so + # we have no way of knowing # Just set the badge to 1 until we have read receipts - 'unread': badge, + "unread": badge, # 'missed_calls': 2 }, - 'devices': [ + "devices": [ { - 'app_id': self.app_id, - 'pushkey': self.pushkey, - 'pushkey_ts': long(self.pushkey_ts / 1000), - 'data': self.data_minus_url, - 'tweaks': tweaks + "app_id": self.app_id, + "pushkey": self.pushkey, + "pushkey_ts": long(self.pushkey_ts / 1000), + "data": self.data_minus_url, + "tweaks": tweaks, } - ] + ], } } - if event.type == 'm.room.member' and event.is_state(): - d['notification']['membership'] = event.content['membership'] - d['notification']['user_is_target'] = event.state_key == self.user_id + if event.type == "m.room.member" and event.is_state(): + d["notification"]["membership"] = event.content["membership"] + d["notification"]["user_is_target"] = event.state_key == self.user_id if self.hs.config.push_include_content and event.content: - d['notification']['content'] = event.content + d["notification"]["content"] = event.content # We no longer send aliases separately, instead, we send the human # readable name of the room, which may be an alias. - if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0: - d['notification']['sender_display_name'] = ctx['sender_display_name'] - if 'name' in ctx and len(ctx['name']) > 0: - d['notification']['room_name'] = ctx['name'] + if "sender_display_name" in ctx and len(ctx["sender_display_name"]) > 0: + d["notification"]["sender_display_name"] = ctx["sender_display_name"] + if "name" in ctx and len(ctx["name"]) > 0: + d["notification"]["room_name"] = ctx["name"] defer.returnValue(d) @@ -361,16 +353,21 @@ def dispatch_push(self, event, tweaks, badge): if not notification_dict: defer.returnValue([]) try: - resp = yield self.http_client.post_json_get_json(self.url, notification_dict) + resp = yield self.http_client.post_json_get_json( + self.url, notification_dict + ) except Exception as e: logger.warning( "Failed to push event %s to %s: %s %s", - event.event_id, self.name, type(e), e, + event.event_id, + self.name, + type(e), + e, ) defer.returnValue(False) rejected = [] - if 'rejected' in resp: - rejected = resp['rejected'] + if "rejected" in resp: + rejected = resp["rejected"] defer.returnValue(rejected) @defer.inlineCallbacks @@ -381,21 +378,19 @@ def _send_badge(self, badge): """ logger.info("Sending updated badge count %d to %s", badge, self.name) d = { - 'notification': { - 'id': '', - 'type': None, - 'sender': '', - 'counts': { - 'unread': badge - }, - 'devices': [ + "notification": { + "id": "", + "type": None, + "sender": "", + "counts": {"unread": badge}, + "devices": [ { - 'app_id': self.app_id, - 'pushkey': self.pushkey, - 'pushkey_ts': long(self.pushkey_ts / 1000), - 'data': self.data_minus_url, + "app_id": self.app_id, + "pushkey": self.pushkey, + "pushkey_ts": long(self.pushkey_ts / 1000), + "data": self.data_minus_url, } - ] + ], } } try: @@ -403,7 +398,6 @@ def _send_badge(self, badge): http_badges_processed_counter.inc() except Exception as e: logger.warning( - "Failed to send badge count to %s: %s %s", - self.name, type(e), e, + "Failed to send badge count to %s: %s %s", self.name, type(e), e ) http_badges_failed_counter.inc() diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 099f9545ab14..17c7d3195a0c 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -42,17 +42,21 @@ logger = logging.getLogger(__name__) -MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \ - "in the %(room)s room..." +MESSAGE_FROM_PERSON_IN_ROOM = ( + "You have a message on %(app)s from %(person)s " "in the %(room)s room..." +) MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..." MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..." MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..." -MESSAGES_IN_ROOM_AND_OTHERS = \ +MESSAGES_IN_ROOM_AND_OTHERS = ( "You have messages on %(app)s in the %(room)s room and others..." -MESSAGES_FROM_PERSON_AND_OTHERS = \ +) +MESSAGES_FROM_PERSON_AND_OTHERS = ( "You have messages on %(app)s from %(person)s and others..." -INVITE_FROM_PERSON_TO_ROOM = "%(person)s has invited you to join the " \ - "%(room)s room on %(app)s..." +) +INVITE_FROM_PERSON_TO_ROOM = ( + "%(person)s has invited you to join the " "%(room)s room on %(app)s..." +) INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..." CONTEXT_BEFORE = 1 @@ -60,12 +64,38 @@ # From /~https://github.com/matrix-org/matrix-react-sdk/blob/master/src/HtmlUtils.js ALLOWED_TAGS = [ - 'font', # custom to matrix for IRC-style font coloring - 'del', # for markdown + "font", # custom to matrix for IRC-style font coloring + "del", # for markdown # deliberately no h1/h2 to stop people shouting. - 'h3', 'h4', 'h5', 'h6', 'blockquote', 'p', 'a', 'ul', 'ol', - 'nl', 'li', 'b', 'i', 'u', 'strong', 'em', 'strike', 'code', 'hr', 'br', 'div', - 'table', 'thead', 'caption', 'tbody', 'tr', 'th', 'td', 'pre' + "h3", + "h4", + "h5", + "h6", + "blockquote", + "p", + "a", + "ul", + "ol", + "nl", + "li", + "b", + "i", + "u", + "strong", + "em", + "strike", + "code", + "hr", + "br", + "div", + "table", + "thead", + "caption", + "tbody", + "tr", + "th", + "td", + "pre", ] ALLOWED_ATTRS = { # custom ones first: @@ -94,13 +124,7 @@ def __init__(self, hs, app_name, template_html, template_text): logger.info("Created Mailer for app_name %s" % app_name) @defer.inlineCallbacks - def send_password_reset_mail( - self, - email_address, - token, - client_secret, - sid, - ): + def send_password_reset_mail(self, email_address, token, client_secret, sid): """Send an email with a password reset link to a user Args: @@ -112,19 +136,16 @@ def send_password_reset_mail( group together multiple email sending attempts sid (str): The generated session ID """ - if email.utils.parseaddr(email_address)[1] == '': + if email.utils.parseaddr(email_address)[1] == "": raise RuntimeError("Invalid 'to' email address") link = ( - self.hs.config.public_baseurl + - "_matrix/client/unstable/password_reset/email/submit_token" - "?token=%s&client_secret=%s&sid=%s" % - (token, client_secret, sid) + self.hs.config.public_baseurl + + "_matrix/client/unstable/password_reset/email/submit_token" + "?token=%s&client_secret=%s&sid=%s" % (token, client_secret, sid) ) - template_vars = { - "link": link, - } + template_vars = {"link": link} yield self.send_email( email_address, @@ -133,15 +154,14 @@ def send_password_reset_mail( ) @defer.inlineCallbacks - def send_notification_mail(self, app_id, user_id, email_address, - push_actions, reason): + def send_notification_mail( + self, app_id, user_id, email_address, push_actions, reason + ): """Send email regarding a user's room notifications""" - rooms_in_order = deduped_ordered_list( - [pa['room_id'] for pa in push_actions] - ) + rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions]) notif_events = yield self.store.get_events( - [pa['event_id'] for pa in push_actions] + [pa["event_id"] for pa in push_actions] ) notifs_by_room = {} @@ -171,9 +191,7 @@ def _fetch_room_state(room_id): yield concurrently_execute(_fetch_room_state, rooms_in_order, 3) # actually sort our so-called rooms_in_order list, most recent room first - rooms_in_order.sort( - key=lambda r: -(notifs_by_room[r][-1]['received_ts'] or 0) - ) + rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0)) rooms = [] @@ -183,9 +201,11 @@ def _fetch_room_state(room_id): ) rooms.append(roomvars) - reason['room_name'] = yield calculate_room_name( - self.store, state_by_room[reason['room_id']], user_id, - fallback_to_members=True + reason["room_name"] = yield calculate_room_name( + self.store, + state_by_room[reason["room_id"]], + user_id, + fallback_to_members=True, ) summary_text = yield self.make_summary_text( @@ -204,25 +224,21 @@ def _fetch_room_state(room_id): } yield self.send_email( - email_address, - "[%s] %s" % (self.app_name, summary_text), - template_vars, + email_address, "[%s] %s" % (self.app_name, summary_text), template_vars ) @defer.inlineCallbacks def send_email(self, email_address, subject, template_vars): """Send an email with the given information and template text""" try: - from_string = self.hs.config.email_notif_from % { - "app": self.app_name - } + from_string = self.hs.config.email_notif_from % {"app": self.app_name} except TypeError: from_string = self.hs.config.email_notif_from raw_from = email.utils.parseaddr(from_string)[1] raw_to = email.utils.parseaddr(email_address)[1] - if raw_to == '': + if raw_to == "": raise RuntimeError("Invalid 'to' address") html_text = self.template_html.render(**template_vars) @@ -231,27 +247,31 @@ def send_email(self, email_address, subject, template_vars): plain_text = self.template_text.render(**template_vars) text_part = MIMEText(plain_text, "plain", "utf8") - multipart_msg = MIMEMultipart('alternative') - multipart_msg['Subject'] = subject - multipart_msg['From'] = from_string - multipart_msg['To'] = email_address - multipart_msg['Date'] = email.utils.formatdate() - multipart_msg['Message-ID'] = email.utils.make_msgid() + multipart_msg = MIMEMultipart("alternative") + multipart_msg["Subject"] = subject + multipart_msg["From"] = from_string + multipart_msg["To"] = email_address + multipart_msg["Date"] = email.utils.formatdate() + multipart_msg["Message-ID"] = email.utils.make_msgid() multipart_msg.attach(text_part) multipart_msg.attach(html_part) logger.info("Sending email push notification to %s" % email_address) - yield make_deferred_yieldable(self.sendmail( - self.hs.config.email_smtp_host, - raw_from, raw_to, multipart_msg.as_string().encode('utf8'), - reactor=self.hs.get_reactor(), - port=self.hs.config.email_smtp_port, - requireAuthentication=self.hs.config.email_smtp_user is not None, - username=self.hs.config.email_smtp_user, - password=self.hs.config.email_smtp_pass, - requireTransportSecurity=self.hs.config.require_transport_security - )) + yield make_deferred_yieldable( + self.sendmail( + self.hs.config.email_smtp_host, + raw_from, + raw_to, + multipart_msg.as_string().encode("utf8"), + reactor=self.hs.get_reactor(), + port=self.hs.config.email_smtp_port, + requireAuthentication=self.hs.config.email_smtp_user is not None, + username=self.hs.config.email_smtp_user, + password=self.hs.config.email_smtp_pass, + requireTransportSecurity=self.hs.config.require_transport_security, + ) + ) @defer.inlineCallbacks def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids): @@ -272,17 +292,18 @@ def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids): if not is_invite: for n in notifs: notifvars = yield self.get_notif_vars( - n, user_id, notif_events[n['event_id']], room_state_ids + n, user_id, notif_events[n["event_id"]], room_state_ids ) # merge overlapping notifs together. # relies on the notifs being in chronological order. merge = False - if room_vars['notifs'] and 'messages' in room_vars['notifs'][-1]: - prev_messages = room_vars['notifs'][-1]['messages'] - for message in notifvars['messages']: - pm = list(filter(lambda pm: pm['id'] == message['id'], - prev_messages)) + if room_vars["notifs"] and "messages" in room_vars["notifs"][-1]: + prev_messages = room_vars["notifs"][-1]["messages"] + for message in notifvars["messages"]: + pm = list( + filter(lambda pm: pm["id"] == message["id"], prev_messages) + ) if pm: if not message["is_historical"]: pm[0]["is_historical"] = False @@ -293,20 +314,22 @@ def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids): prev_messages.append(message) if not merge: - room_vars['notifs'].append(notifvars) + room_vars["notifs"].append(notifvars) defer.returnValue(room_vars) @defer.inlineCallbacks def get_notif_vars(self, notif, user_id, notif_event, room_state_ids): results = yield self.store.get_events_around( - notif['room_id'], notif['event_id'], - before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER + notif["room_id"], + notif["event_id"], + before_limit=CONTEXT_BEFORE, + after_limit=CONTEXT_AFTER, ) ret = { "link": self.make_notif_link(notif), - "ts": notif['received_ts'], + "ts": notif["received_ts"], "messages": [], } @@ -318,7 +341,7 @@ def get_notif_vars(self, notif, user_id, notif_event, room_state_ids): for event in the_events: messagevars = yield self.get_message_vars(notif, event, room_state_ids) if messagevars is not None: - ret['messages'].append(messagevars) + ret["messages"].append(messagevars) defer.returnValue(ret) @@ -340,7 +363,7 @@ def get_message_vars(self, notif, event, room_state_ids): ret = { "msgtype": msgtype, - "is_historical": event.event_id != notif['event_id'], + "is_historical": event.event_id != notif["event_id"], "id": event.event_id, "ts": event.origin_server_ts, "sender_name": sender_name, @@ -379,8 +402,9 @@ def add_image_message_vars(self, messagevars, event): return messagevars @defer.inlineCallbacks - def make_summary_text(self, notifs_by_room, room_state_ids, - notif_events, user_id, reason): + def make_summary_text( + self, notifs_by_room, room_state_ids, notif_events, user_id, reason + ): if len(notifs_by_room) == 1: # Only one room has new stuff room_id = list(notifs_by_room.keys())[0] @@ -404,16 +428,19 @@ def make_summary_text(self, notifs_by_room, room_state_ids, inviter_name = name_from_member_event(inviter_member_event) if room_name is None: - defer.returnValue(INVITE_FROM_PERSON % { - "person": inviter_name, - "app": self.app_name - }) + defer.returnValue( + INVITE_FROM_PERSON + % {"person": inviter_name, "app": self.app_name} + ) else: - defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % { - "person": inviter_name, - "room": room_name, - "app": self.app_name, - }) + defer.returnValue( + INVITE_FROM_PERSON_TO_ROOM + % { + "person": inviter_name, + "room": room_name, + "app": self.app_name, + } + ) sender_name = None if len(notifs_by_room[room_id]) == 1: @@ -427,67 +454,86 @@ def make_summary_text(self, notifs_by_room, room_state_ids, sender_name = name_from_member_event(state_event) if sender_name is not None and room_name is not None: - defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % { - "person": sender_name, - "room": room_name, - "app": self.app_name, - }) + defer.returnValue( + MESSAGE_FROM_PERSON_IN_ROOM + % { + "person": sender_name, + "room": room_name, + "app": self.app_name, + } + ) elif sender_name is not None: - defer.returnValue(MESSAGE_FROM_PERSON % { - "person": sender_name, - "app": self.app_name, - }) + defer.returnValue( + MESSAGE_FROM_PERSON + % {"person": sender_name, "app": self.app_name} + ) else: # There's more than one notification for this room, so just # say there are several if room_name is not None: - defer.returnValue(MESSAGES_IN_ROOM % { - "room": room_name, - "app": self.app_name, - }) + defer.returnValue( + MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name} + ) else: # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" - sender_ids = list(set([ - notif_events[n['event_id']].sender - for n in notifs_by_room[room_id] - ])) - - member_events = yield self.store.get_events([ - room_state_ids[room_id][("m.room.member", s)] - for s in sender_ids - ]) - - defer.returnValue(MESSAGES_FROM_PERSON % { - "person": descriptor_from_member_events(member_events.values()), - "app": self.app_name, - }) + sender_ids = list( + set( + [ + notif_events[n["event_id"]].sender + for n in notifs_by_room[room_id] + ] + ) + ) + + member_events = yield self.store.get_events( + [ + room_state_ids[room_id][("m.room.member", s)] + for s in sender_ids + ] + ) + + defer.returnValue( + MESSAGES_FROM_PERSON + % { + "person": descriptor_from_member_events( + member_events.values() + ), + "app": self.app_name, + } + ) else: # Stuff's happened in multiple different rooms # ...but we still refer to the 'reason' room which triggered the mail - if reason['room_name'] is not None: - defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % { - "room": reason['room_name'], - "app": self.app_name, - }) + if reason["room_name"] is not None: + defer.returnValue( + MESSAGES_IN_ROOM_AND_OTHERS + % {"room": reason["room_name"], "app": self.app_name} + ) else: # If the reason room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" - sender_ids = list(set([ - notif_events[n['event_id']].sender - for n in notifs_by_room[reason['room_id']] - ])) + sender_ids = list( + set( + [ + notif_events[n["event_id"]].sender + for n in notifs_by_room[reason["room_id"]] + ] + ) + ) - member_events = yield self.store.get_events([ - room_state_ids[room_id][("m.room.member", s)] - for s in sender_ids - ]) + member_events = yield self.store.get_events( + [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids] + ) - defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % { - "person": descriptor_from_member_events(member_events.values()), - "app": self.app_name, - }) + defer.returnValue( + MESSAGES_FROM_PERSON_AND_OTHERS + % { + "person": descriptor_from_member_events(member_events.values()), + "app": self.app_name, + } + ) def make_room_link(self, room_id): if self.hs.config.email_riot_base_url: @@ -503,17 +549,17 @@ def make_notif_link(self, notif): if self.hs.config.email_riot_base_url: return "%s/#/room/%s/%s" % ( self.hs.config.email_riot_base_url, - notif['room_id'], notif['event_id'] + notif["room_id"], + notif["event_id"], ) elif self.app_name == "Vector": # need /beta for Universal Links to work on iOS return "https://vector.im/beta/#/room/%s/%s" % ( - notif['room_id'], notif['event_id'] + notif["room_id"], + notif["event_id"], ) else: - return "https://matrix.to/#/%s/%s" % ( - notif['room_id'], notif['event_id'] - ) + return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"]) def make_unsubscribe_link(self, user_id, app_id, email_address): params = { @@ -530,12 +576,18 @@ def make_unsubscribe_link(self, user_id, app_id, email_address): def safe_markup(raw_html): - return jinja2.Markup(bleach.linkify(bleach.clean( - raw_html, tags=ALLOWED_TAGS, attributes=ALLOWED_ATTRS, - # bleach master has this, but it isn't released yet - # protocols=ALLOWED_SCHEMES, - strip=True - ))) + return jinja2.Markup( + bleach.linkify( + bleach.clean( + raw_html, + tags=ALLOWED_TAGS, + attributes=ALLOWED_ATTRS, + # bleach master has this, but it isn't released yet + # protocols=ALLOWED_SCHEMES, + strip=True, + ) + ) + ) def safe_text(raw_text): @@ -543,10 +595,9 @@ def safe_text(raw_text): Process text: treat it as HTML but escape any tags (ie. just escape the HTML) then linkify it. """ - return jinja2.Markup(bleach.linkify(bleach.clean( - raw_text, tags=[], attributes={}, - strip=False - ))) + return jinja2.Markup( + bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False)) + ) def deduped_ordered_list(l): @@ -595,15 +646,11 @@ def mxc_to_http_filter(value, width, height, resize_method="crop"): serverAndMediaId = value[6:] fragment = None - if '#' in serverAndMediaId: - (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1) + if "#" in serverAndMediaId: + (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1) fragment = "#" + fragment - params = { - "width": width, - "height": height, - "method": resize_method, - } + params = {"width": width, "height": height, "method": resize_method} return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( config.public_baseurl, serverAndMediaId, diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 0c66702325ad..06056fbf4fae 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -28,8 +28,13 @@ @defer.inlineCallbacks -def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True, - fallback_to_single_member=True): +def calculate_room_name( + store, + room_state_ids, + user_id, + fallback_to_members=True, + fallback_to_single_member=True, +): """ Works out a user-facing name for the given room as per Matrix spec recommendations. @@ -58,8 +63,10 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True room_state_ids[("m.room.canonical_alias", "")], allow_none=True ) if ( - canon_alias and canon_alias.content and canon_alias.content["alias"] and - _looks_like_an_alias(canon_alias.content["alias"]) + canon_alias + and canon_alias.content + and canon_alias.content["alias"] + and _looks_like_an_alias(canon_alias.content["alias"]) ): defer.returnValue(canon_alias.content["alias"]) @@ -71,9 +78,7 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True if "m.room.aliases" in room_state_bytype_ids: m_room_aliases = room_state_bytype_ids["m.room.aliases"] for alias_id in m_room_aliases.values(): - alias_event = yield store.get_event( - alias_id, allow_none=True - ) + alias_event = yield store.get_event(alias_id, allow_none=True) if alias_event and alias_event.content.get("aliases"): the_aliases = alias_event.content["aliases"] if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]): @@ -89,8 +94,8 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True ) if ( - my_member_event is not None and - my_member_event.content['membership'] == "invite" + my_member_event is not None + and my_member_event.content["membership"] == "invite" ): if ("m.room.member", my_member_event.sender) in room_state_ids: inviter_member_event = yield store.get_event( @@ -100,9 +105,8 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True if inviter_member_event: if fallback_to_single_member: defer.returnValue( - "Invite from %s" % ( - name_from_member_event(inviter_member_event), - ) + "Invite from %s" + % (name_from_member_event(inviter_member_event),) ) else: return @@ -116,8 +120,10 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True list(room_state_bytype_ids["m.room.member"].values()) ) all_members = [ - ev for ev in member_events.values() - if ev.content['membership'] == "join" or ev.content['membership'] == "invite" + ev + for ev in member_events.values() + if ev.content["membership"] == "join" + or ev.content["membership"] == "invite" ] # Sort the member events oldest-first so the we name people in the # order the joined (it should at least be deterministic rather than @@ -134,9 +140,9 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True # or inbound invite, or outbound 3PID invite. if all_members[0].sender == user_id: if "m.room.third_party_invite" in room_state_bytype_ids: - third_party_invites = ( - room_state_bytype_ids["m.room.third_party_invite"].values() - ) + third_party_invites = room_state_bytype_ids[ + "m.room.third_party_invite" + ].values() if len(third_party_invites) > 0: # technically third party invite events are not member @@ -191,8 +197,9 @@ def descriptor_from_member_events(member_events): def name_from_member_event(member_event): if ( - member_event.content and "displayname" in member_event.content and - member_event.content["displayname"] + member_event.content + and "displayname" in member_event.content + and member_event.content["displayname"] ): return member_event.content["displayname"] return member_event.state_key diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index cf6c8b875e04..5ed9147de4da 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -26,8 +26,8 @@ logger = logging.getLogger(__name__) -GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]') -IS_GLOB = re.compile(r'[\?\*\[\]]') +GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]") +IS_GLOB = re.compile(r"[\?\*\[\]]") INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") @@ -36,20 +36,20 @@ def _room_member_count(ev, condition, room_member_count): def _sender_notification_permission(ev, condition, sender_power_level, power_levels): - notif_level_key = condition.get('key') + notif_level_key = condition.get("key") if notif_level_key is None: return False - notif_levels = power_levels.get('notifications', {}) + notif_levels = power_levels.get("notifications", {}) room_notif_level = notif_levels.get(notif_level_key, 50) return sender_power_level >= room_notif_level def _test_ineq_condition(condition, number): - if 'is' not in condition: + if "is" not in condition: return False - m = INEQUALITY_EXPR.match(condition['is']) + m = INEQUALITY_EXPR.match(condition["is"]) if not m: return False ineq = m.group(1) @@ -58,15 +58,15 @@ def _test_ineq_condition(condition, number): return False rhs = int(rhs) - if ineq == '' or ineq == '==': + if ineq == "" or ineq == "==": return number == rhs - elif ineq == '<': + elif ineq == "<": return number < rhs - elif ineq == '>': + elif ineq == ">": return number > rhs - elif ineq == '>=': + elif ineq == ">=": return number >= rhs - elif ineq == '<=': + elif ineq == "<=": return number <= rhs else: return False @@ -77,8 +77,8 @@ def tweaks_for_actions(actions): for a in actions: if not isinstance(a, dict): continue - if 'set_tweak' in a and 'value' in a: - tweaks[a['set_tweak']] = a['value'] + if "set_tweak" in a and "value" in a: + tweaks[a["set_tweak"]] = a["value"] return tweaks @@ -93,26 +93,24 @@ def __init__(self, event, room_member_count, sender_power_level, power_levels): self._value_cache = _flatten_dict(event) def matches(self, condition, user_id, display_name): - if condition['kind'] == 'event_match': + if condition["kind"] == "event_match": return self._event_match(condition, user_id) - elif condition['kind'] == 'contains_display_name': + elif condition["kind"] == "contains_display_name": return self._contains_display_name(display_name) - elif condition['kind'] == 'room_member_count': - return _room_member_count( - self._event, condition, self._room_member_count - ) - elif condition['kind'] == 'sender_notification_permission': + elif condition["kind"] == "room_member_count": + return _room_member_count(self._event, condition, self._room_member_count) + elif condition["kind"] == "sender_notification_permission": return _sender_notification_permission( - self._event, condition, self._sender_power_level, self._power_levels, + self._event, condition, self._sender_power_level, self._power_levels ) else: return True def _event_match(self, condition, user_id): - pattern = condition.get('pattern', None) + pattern = condition.get("pattern", None) if not pattern: - pattern_type = condition.get('pattern_type', None) + pattern_type = condition.get("pattern_type", None) if pattern_type == "user_id": pattern = user_id elif pattern_type == "user_localpart": @@ -123,14 +121,14 @@ def _event_match(self, condition, user_id): return False # XXX: optimisation: cache our pattern regexps - if condition['key'] == 'content.body': + if condition["key"] == "content.body": body = self._event.content.get("body", None) if not body: return False return _glob_matches(pattern, body, word_boundary=True) else: - haystack = self._get_value(condition['key']) + haystack = self._get_value(condition["key"]) if haystack is None: return False @@ -193,16 +191,13 @@ def _glob_to_re(glob, word_boundary): if IS_GLOB.search(glob): r = re.escape(glob) - r = r.replace(r'\*', '.*?') - r = r.replace(r'\?', '.') + r = r.replace(r"\*", ".*?") + r = r.replace(r"\?", ".") # handle [abc], [a-z] and [!a-z] style ranges. r = GLOB_REGEX.sub( lambda x: ( - '[%s%s]' % ( - x.group(1) and '^' or '', - x.group(2).replace(r'\\\-', '-') - ) + "[%s%s]" % (x.group(1) and "^" or "", x.group(2).replace(r"\\\-", "-")) ), r, ) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 8049c298c208..e37269cdb934 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -23,9 +23,7 @@ def get_badge_count(store, user_id): invites = yield store.get_invited_rooms_for_user(user_id) joins = yield store.get_rooms_for_user(user_id) - my_receipts_by_room = yield store.get_receipts_for_user( - user_id, "m.read", - ) + my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") badge = len(invites) @@ -57,10 +55,10 @@ def get_context_for_event(store, state_handler, ev, user_id): store, room_state_ids, user_id, fallback_to_single_member=False ) if name: - ctx['name'] = name + ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] sender_state_event = yield store.get_event(sender_state_event_id) - ctx['sender_display_name'] = name_from_member_event(sender_state_event) + ctx["sender_display_name"] = name_from_member_event(sender_state_event) defer.returnValue(ctx) diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index aff85daeb5b5..a9c64a9c5401 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -36,9 +36,7 @@ class PusherFactory(object): def __init__(self, hs): self.hs = hs - self.pusher_types = { - "http": HttpPusher, - } + self.pusher_types = {"http": HttpPusher} logger.info("email enable notifs: %r", hs.config.email_enable_notifs) if hs.config.email_enable_notifs: @@ -56,7 +54,7 @@ def __init__(self, hs): logger.info("defined email pusher type") def create_pusher(self, pusherdict): - kind = pusherdict['kind'] + kind = pusherdict["kind"] f = self.pusher_types.get(kind, None) if not f: return None @@ -77,8 +75,8 @@ def _create_email_pusher(self, _hs, pusherdict): return EmailPusher(self.hs, pusherdict, mailer) def _app_name_from_pusherdict(self, pusherdict): - if 'data' in pusherdict and 'brand' in pusherdict['data']: - app_name = pusherdict['data']['brand'] + if "data" in pusherdict and "brand" in pusherdict["data"]: + app_name = pusherdict["data"]["brand"] else: app_name = self.hs.config.email_app_name diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 63c583565fa0..df6f67074033 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -40,6 +40,7 @@ class PusherPool: notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and Pusher.on_new_receipts are not expected to return deferreds. """ + def __init__(self, _hs): self.hs = _hs self.pusher_factory = PusherFactory(_hs) @@ -57,9 +58,19 @@ def start(self): run_as_background_process("start_pushers", self._start_pushers) @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, kind, app_id, - app_display_name, device_display_name, pushkey, lang, data, - profile_tag=""): + def add_pusher( + self, + user_id, + access_token, + kind, + app_id, + app_display_name, + device_display_name, + pushkey, + lang, + data, + profile_tag="", + ): """Creates a new pusher and adds it to the pool Returns: @@ -71,21 +82,23 @@ def add_pusher(self, user_id, access_token, kind, app_id, # will then get pulled out of the database, # recreated, added and started: this means we have only one # code path adding pushers. - self.pusher_factory.create_pusher({ - "id": None, - "user_name": user_id, - "kind": kind, - "app_id": app_id, - "app_display_name": app_display_name, - "device_display_name": device_display_name, - "pushkey": pushkey, - "ts": time_now_msec, - "lang": lang, - "data": data, - "last_stream_ordering": None, - "last_success": None, - "failing_since": None - }) + self.pusher_factory.create_pusher( + { + "id": None, + "user_name": user_id, + "kind": kind, + "app_id": app_id, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "pushkey": pushkey, + "ts": time_now_msec, + "lang": lang, + "data": data, + "last_stream_ordering": None, + "last_success": None, + "failing_since": None, + } + ) # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process @@ -113,18 +126,19 @@ def add_pusher(self, user_id, access_token, kind, app_id, defer.returnValue(pusher) @defer.inlineCallbacks - def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, - not_user_id): - to_remove = yield self.store.get_pushers_by_app_id_and_pushkey( - app_id, pushkey - ) + def remove_pushers_by_app_id_and_pushkey_not_user( + self, app_id, pushkey, not_user_id + ): + to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) for p in to_remove: - if p['user_name'] != not_user_id: + if p["user_name"] != not_user_id: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", - app_id, pushkey, p['user_name'] + app_id, + pushkey, + p["user_name"], ) - yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) + yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) @defer.inlineCallbacks def remove_pushers_by_access_token(self, user_id, access_tokens): @@ -138,14 +152,14 @@ def remove_pushers_by_access_token(self, user_id, access_tokens): """ tokens = set(access_tokens) for p in (yield self.store.get_pushers_by_user_id(user_id)): - if p['access_token'] in tokens: + if p["access_token"] in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", - p['app_id'], p['pushkey'], p['user_name'] - ) - yield self.remove_pusher( - p['app_id'], p['pushkey'], p['user_name'], + p["app_id"], + p["pushkey"], + p["user_name"], ) + yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) @defer.inlineCallbacks def on_new_notifications(self, min_stream_id, max_stream_id): @@ -199,13 +213,11 @@ def start_pusher_by_id(self, app_id, pushkey, user_id): if not self._should_start_pushers: return - resultlist = yield self.store.get_pushers_by_app_id_and_pushkey( - app_id, pushkey - ) + resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_dict = None for r in resultlist: - if r['user_name'] == user_id: + if r["user_name"] == user_id: pusher_dict = r pusher = None @@ -245,9 +257,9 @@ def _start_pusher(self, pusherdict): except PusherConfigException as e: logger.warning( "Pusher incorrectly configured user=%s, appid=%s, pushkey=%s: %s", - pusherdict.get('user_name'), - pusherdict.get('app_id'), - pusherdict.get('pushkey'), + pusherdict.get("user_name"), + pusherdict.get("app_id"), + pusherdict.get("pushkey"), e, ) return @@ -258,11 +270,8 @@ def _start_pusher(self, pusherdict): if not p: return - appid_pushkey = "%s:%s" % ( - pusherdict['app_id'], - pusherdict['pushkey'], - ) - byuser = self.pushers.setdefault(pusherdict['user_name'], {}) + appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) + byuser = self.pushers.setdefault(pusherdict["user_name"], {}) if appid_pushkey in byuser: byuser[appid_pushkey].on_stop() @@ -275,7 +284,7 @@ def _start_pusher(self, pusherdict): last_stream_ordering = pusherdict["last_stream_ordering"] if last_stream_ordering: have_notifs = yield self.store.get_if_maybe_push_in_range_for_user( - user_id, last_stream_ordering, + user_id, last_stream_ordering ) else: # We always want to default to starting up the pusher rather than diff --git a/synapse/push/rulekinds.py b/synapse/push/rulekinds.py index 4cae48ac073b..ce7cc1b4eebe 100644 --- a/synapse/push/rulekinds.py +++ b/synapse/push/rulekinds.py @@ -13,10 +13,10 @@ # limitations under the License. PRIORITY_CLASS_MAP = { - 'underride': 1, - 'sender': 2, - 'room': 3, - 'content': 4, - 'override': 5, + "underride": 1, + "sender": 2, + "room": 3, + "content": 4, + "override": 5, } PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()} diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 11ace2bfb17f..77e2d1a0e100 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -45,14 +45,11 @@ "signedjson>=1.0.0", "pynacl>=1.2.1", "idna>=2.5", - # validating SSL certs for IP addresses requires service_identity 18.1. "service_identity>=18.1.0", - # our logcontext handling relies on the ability to cancel inlineCallbacks # (https://twistedmatrix.com/trac/ticket/4632) which landed in Twisted 18.7. "Twisted>=18.7.0", - "treq>=15.1", # Twisted has required pyopenssl 16.0 since about Twisted 16.6. "pyopenssl>=16.0.0", @@ -71,34 +68,27 @@ # prometheus_client 0.4.0 changed the format of counter metrics # (cf /~https://github.com/matrix-org/synapse/issues/4001) "prometheus_client>=0.0.18,<0.4.0", - # we use attr.s(slots), which arrived in 16.0.0 # Twisted 18.7.0 requires attrs>=17.4.0 "attrs>=17.4.0", - "netaddr>=0.7.18", ] CONDITIONAL_REQUIREMENTS = { "email": ["Jinja2>=2.9", "bleach>=1.4.3"], "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], - # we use execute_batch, which arrived in psycopg 2.7. "postgres": ["psycopg2>=2.7"], - # ConsentResource uses select_autoescape, which arrived in jinja 2.9 "resources.consent": ["Jinja2>=2.9"], - # ACME support is required to provision TLS certificates from authorities # that use the protocol, such as Let's Encrypt. "acme": [ "txacme>=0.9.2", - # txacme depends on eliot. Eliot 1.8.0 is incompatible with # python 3.5.2, as per /~https://github.com/itamarst/eliot/issues/418 'eliot<1.8.0;python_version<"3.5.3"', ], - "saml2": ["pysaml2>=4.5.0"], "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], @@ -121,12 +111,14 @@ def list_requirements(): class DependencyException(Exception): @property def message(self): - return "\n".join([ - "Missing Requirements: %s" % (", ".join(self.dependencies),), - "To install run:", - " pip install --upgrade --force %s" % (" ".join(self.dependencies),), - "", - ]) + return "\n".join( + [ + "Missing Requirements: %s" % (", ".join(self.dependencies),), + "To install run:", + " pip install --upgrade --force %s" % (" ".join(self.dependencies),), + "", + ] + ) @property def dependencies(self): diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 0a432a16fa5b..fe482e279fd1 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -83,8 +83,7 @@ class ReplicationEndpoint(object): def __init__(self, hs): if self.CACHE: self.response_cache = ResponseCache( - hs, "repl." + self.NAME, - timeout_ms=30 * 60 * 1000, + hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) assert self.METHOD in ("PUT", "POST", "GET") @@ -134,8 +133,7 @@ def send_request(**kwargs): data = yield cls._serialize_payload(**kwargs) url_args = [ - urllib.parse.quote(kwargs[name], safe='') - for name in cls.PATH_ARGS + urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS ] if cls.CACHE: @@ -156,7 +154,10 @@ def send_request(**kwargs): ) uri = "http://%s:%s/_synapse/replication/%s/%s" % ( - host, port, cls.NAME, "/".join(url_args) + host, + port, + cls.NAME, + "/".join(url_args), ) try: @@ -202,10 +203,7 @@ def register(self, http_server): url_args.append("txn_id") args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) - pattern = re.compile("^/_synapse/replication/%s/%s$" % ( - self.NAME, - args - )) + pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) http_server.register_paths(method, [pattern], handler) @@ -219,8 +217,4 @@ def _cached_handler(self, request, txn_id, **kwargs): assert self.CACHE - return self.response_cache.wrap( - txn_id, - self._handle_request, - request, **kwargs - ) + return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 0f0a07c42256..61eafbe708bd 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -68,18 +68,17 @@ def _serialize_payload(store, event_and_contexts, backfilled): for event, context in event_and_contexts: serialized_context = yield context.serialize(event, store) - event_payloads.append({ - "event": event.get_pdu_json(), - "event_format_version": event.format_version, - "internal_metadata": event.internal_metadata.get_dict(), - "rejected_reason": event.rejected_reason, - "context": serialized_context, - }) - - payload = { - "events": event_payloads, - "backfilled": backfilled, - } + event_payloads.append( + { + "event": event.get_pdu_json(), + "event_format_version": event.format_version, + "internal_metadata": event.internal_metadata.get_dict(), + "rejected_reason": event.rejected_reason, + "context": serialized_context, + } + ) + + payload = {"events": event_payloads, "backfilled": backfilled} defer.returnValue(payload) @@ -103,18 +102,15 @@ def _handle_request(self, request): event = EventType(event_dict, internal_metadata, rejected_reason) context = yield EventContext.deserialize( - self.store, event_payload["context"], + self.store, event_payload["context"] ) event_and_contexts.append((event, context)) - logger.info( - "Got %d events from federation", - len(event_and_contexts), - ) + logger.info("Got %d events from federation", len(event_and_contexts)) yield self.federation_handler.persist_events_and_notify( - event_and_contexts, backfilled, + event_and_contexts, backfilled ) defer.returnValue((200, {})) @@ -146,10 +142,7 @@ def __init__(self, hs): @staticmethod def _serialize_payload(edu_type, origin, content): - return { - "origin": origin, - "content": content, - } + return {"origin": origin, "content": content} @defer.inlineCallbacks def _handle_request(self, request, edu_type): @@ -159,10 +152,7 @@ def _handle_request(self, request, edu_type): origin = content["origin"] edu_content = content["content"] - logger.info( - "Got %r edu from %s", - edu_type, origin, - ) + logger.info("Got %r edu from %s", edu_type, origin) result = yield self.registry.on_edu(edu_type, origin, edu_content) @@ -201,9 +191,7 @@ def _serialize_payload(query_type, args): query_type (str) args (dict): The arguments received for the given query type """ - return { - "args": args, - } + return {"args": args} @defer.inlineCallbacks def _handle_request(self, request, query_type): @@ -212,10 +200,7 @@ def _handle_request(self, request, query_type): args = content["args"] - logger.info( - "Got %r query", - query_type, - ) + logger.info("Got %r query", query_type) result = yield self.registry.on_query(query_type, args) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 63bc0405ea9b..7c1197e5ddcc 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -61,13 +61,10 @@ def _handle_request(self, request, user_id): is_guest = content["is_guest"] device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest, + user_id, device_id, initial_display_name, is_guest ) - defer.returnValue((200, { - "device_id": device_id, - "access_token": access_token, - })) + defer.returnValue((200, {"device_id": device_id, "access_token": access_token})) def register_servlets(hs, http_server): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 81a2b204c7aa..0a76a3762f1c 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -40,7 +40,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): """ NAME = "remote_join" - PATH_ARGS = ("room_id", "user_id",) + PATH_ARGS = ("room_id", "user_id") def __init__(self, hs): super(ReplicationRemoteJoinRestServlet, self).__init__(hs) @@ -50,8 +50,7 @@ def __init__(self, hs): self.clock = hs.get_clock() @staticmethod - def _serialize_payload(requester, room_id, user_id, remote_room_hosts, - content): + def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): """ Args: requester(Requester) @@ -78,16 +77,10 @@ def _handle_request(self, request, room_id, user_id): if requester.user: request.authenticated_entity = requester.user.to_string() - logger.info( - "remote_join: %s into room: %s", - user_id, room_id, - ) + logger.info("remote_join: %s into room: %s", user_id, room_id) yield self.federation_handler.do_invite_join( - remote_room_hosts, - room_id, - user_id, - event_content, + remote_room_hosts, room_id, user_id, event_content ) defer.returnValue((200, {})) @@ -107,7 +100,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): """ NAME = "remote_reject_invite" - PATH_ARGS = ("room_id", "user_id",) + PATH_ARGS = ("room_id", "user_id") def __init__(self, hs): super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs) @@ -141,16 +134,11 @@ def _handle_request(self, request, room_id, user_id): if requester.user: request.authenticated_entity = requester.user.to_string() - logger.info( - "remote_reject_invite: %s out of room: %s", - user_id, room_id, - ) + logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: event = yield self.federation_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - user_id, + remote_room_hosts, room_id, user_id ) ret = event.get_pdu_json() except Exception as e: @@ -162,9 +150,7 @@ def _handle_request(self, request, room_id, user_id): # logger.warn("Failed to reject invite: %s", e) - yield self.store.locally_reject_invite( - user_id, room_id - ) + yield self.store.locally_reject_invite(user_id, room_id) ret = {} defer.returnValue((200, ret)) @@ -228,7 +214,7 @@ def _handle_request(self, request): logger.info("get_or_register_3pid_guest: %r", content) ret = yield self.registeration_handler.get_or_register_3pid_guest( - medium, address, inviter_user_id, + medium, address, inviter_user_id ) defer.returnValue((200, ret)) @@ -264,7 +250,7 @@ def _serialize_payload(room_id, user_id, change): user_id (str) change (str): Either "joined" or "left" """ - assert change in ("joined", "left",) + assert change in ("joined", "left") return {} diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 912a5ac3410c..f81a0f1b8f54 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -37,8 +37,16 @@ def __init__(self, hs): @staticmethod def _serialize_payload( - user_id, token, password_hash, was_guest, make_guest, appservice_id, - create_profile_with_displayname, admin, user_type, address, + user_id, + token, + password_hash, + was_guest, + make_guest, + appservice_id, + create_profile_with_displayname, + admin, + user_type, + address, ): """ Args: @@ -85,7 +93,7 @@ def _handle_request(self, request, user_id): create_profile_with_displayname=content["create_profile_with_displayname"], admin=content["admin"], user_type=content["user_type"], - address=content["address"] + address=content["address"], ) defer.returnValue((200, {})) @@ -104,8 +112,7 @@ def __init__(self, hs): self.registration_handler = hs.get_registration_handler() @staticmethod - def _serialize_payload(user_id, auth_result, access_token, bind_email, - bind_msisdn): + def _serialize_payload(user_id, auth_result, access_token, bind_email, bind_msisdn): """ Args: user_id (str): The user ID that consented diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 3635015eda37..034763fe993a 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -45,6 +45,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "extra_users": [], } """ + NAME = "send_event" PATH_ARGS = ("event_id",) @@ -57,8 +58,9 @@ def __init__(self, hs): @staticmethod @defer.inlineCallbacks - def _serialize_payload(event_id, store, event, context, requester, - ratelimit, extra_users): + def _serialize_payload( + event_id, store, event, context, requester, ratelimit, extra_users + ): """ Args: event_id (str) @@ -108,14 +110,11 @@ def _handle_request(self, request, event_id): request.authenticated_entity = requester.user.to_string() logger.info( - "Got event to send with ID: %s into room: %s", - event.event_id, event.room_id, + "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) yield self.event_creation_handler.persist_and_notify_client_event( - requester, event, context, - ratelimit=ratelimit, - extra_users=extra_users, + requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) defer.returnValue((200, {})) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 817d1f67f95b..182cb2a1d8c8 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -37,7 +37,7 @@ def __init__(self, db_conn, hs): super(BaseSlavedStore, self).__init__(db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( - db_conn, "cache_invalidation_stream", "stream_id", + db_conn, "cache_invalidation_stream", "stream_id" ) else: self._cache_id_gen = None diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index d9ba6d69b10f..3c44d1d48da4 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -21,10 +21,9 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( - db_conn, "account_data_max_stream_id", "stream_id", + db_conn, "account_data_max_stream_id", "stream_id" ) super(SlavedAccountDataStore, self).__init__(db_conn, hs) @@ -45,24 +44,20 @@ def process_replication_rows(self, stream_name, token, rows): self._account_data_id_gen.advance(token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed( - row.user_id, token - ) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) elif stream_name == "account_data": self._account_data_id_gen.advance(token) for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( - (row.data_type, row.user_id,) + (row.data_type, row.user_id) ) self.get_account_data_for_user.invalidate((row.user_id,)) - self.get_account_data_for_room.invalidate((row.user_id, row.room_id,)) + self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) self.get_account_data_for_room_and_type.invalidate( - (row.user_id, row.room_id, row.data_type,), - ) - self._account_data_stream_cache.entity_has_changed( - row.user_id, token + (row.user_id, row.room_id, row.data_type) ) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedAccountDataStore, self).process_replication_rows( stream_name, token, rows ) diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py index b53a4c6bd113..cda12ea70d6c 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py @@ -20,6 +20,7 @@ ) -class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore, - ApplicationServiceWorkerStore): +class SlavedApplicationServiceStore( + ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore +): pass diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 5b8521c7704c..14ced3233360 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -25,9 +25,7 @@ def __init__(self, db_conn, hs): super(SlavedClientIpStore, self).__init__(db_conn, hs) self.client_ip_last_seen = Cache( - name="client_ip_last_seen", - keylen=4, - max_entries=50000 * CACHE_SIZE_FACTOR, + name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR ) def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 4d5977886364..284fd30d896d 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -24,15 +24,15 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( - db_conn, "device_max_stream_id", "stream_id", + db_conn, "device_max_stream_id", "stream_id" ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", - self._device_inbox_id_gen.get_current_token() + self._device_inbox_id_gen.get_current_token(), ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", - self._device_inbox_id_gen.get_current_token() + self._device_inbox_id_gen.get_current_token(), ) self._last_device_delete_cache = ExpiringCache( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 16c9a162c516..d9300fce33a0 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -27,14 +27,14 @@ def __init__(self, db_conn, hs): self.hs = hs self._device_list_id_gen = SlavedIdTracker( - db_conn, "device_lists_stream", "stream_id", + db_conn, "device_lists_stream", "stream_id" ) device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( - "DeviceListStreamChangeCache", device_list_max, + "DeviceListStreamChangeCache", device_list_max ) self._device_list_federation_stream_cache = StreamChangeCache( - "DeviceListFederationStreamChangeCache", device_list_max, + "DeviceListFederationStreamChangeCache", device_list_max ) def stream_positions(self): @@ -46,17 +46,13 @@ def process_replication_rows(self, stream_name, token, rows): if stream_name == "device_lists": self._device_list_id_gen.advance(token) for row in rows: - self._invalidate_caches_for_devices( - token, row.user_id, row.destination, - ) + self._invalidate_caches_for_devices(token, row.user_id, row.destination) return super(SlavedDeviceStore, self).process_replication_rows( stream_name, token, rows ) def _invalidate_caches_for_devices(self, token, user_id, destination): - self._device_list_stream_cache.entity_has_changed( - user_id, token - ) + self._device_list_stream_cache.entity_has_changed(user_id, token) if destination: self._device_list_federation_stream_cache.entity_has_changed( diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index a3952506c13e..ab5937e6384a 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -45,21 +45,20 @@ # the method descriptor on the DataStore and chuck them into our class. -class SlavedEventStore(EventFederationWorkerStore, - RoomMemberWorkerStore, - EventPushActionsWorkerStore, - StreamWorkerStore, - StateGroupWorkerStore, - EventsWorkerStore, - SignatureWorkerStore, - UserErasureWorkerStore, - RelationsWorkerStore, - BaseSlavedStore): - +class SlavedEventStore( + EventFederationWorkerStore, + RoomMemberWorkerStore, + EventPushActionsWorkerStore, + StreamWorkerStore, + StateGroupWorkerStore, + EventsWorkerStore, + SignatureWorkerStore, + UserErasureWorkerStore, + RelationsWorkerStore, + BaseSlavedStore, +): def __init__(self, db_conn, hs): - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", - ) + self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) @@ -90,8 +89,13 @@ def process_replication_rows(self, stream_name, token, rows): self._backfill_id_gen.advance(-token) for row in rows: self.invalidate_caches_for_event( - -token, row.event_id, row.room_id, row.type, row.state_key, - row.redacts, row.relates_to, + -token, + row.event_id, + row.room_id, + row.type, + row.state_key, + row.redacts, + row.relates_to, backfilled=True, ) return super(SlavedEventStore, self).process_replication_rows( @@ -103,41 +107,48 @@ def _process_event_stream_row(self, token, row): if row.type == EventsStreamEventRow.TypeId: self.invalidate_caches_for_event( - token, data.event_id, data.room_id, data.type, data.state_key, - data.redacts, data.relates_to, + token, + data.event_id, + data.room_id, + data.type, + data.state_key, + data.redacts, + data.relates_to, backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: if data.type == EventTypes.Member: self.get_rooms_for_user_with_stream_ordering.invalidate( - (data.state_key, ), + (data.state_key,) ) else: - raise Exception("Unknown events stream row type %s" % (row.type, )) - - def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, - etype, state_key, redacts, relates_to, - backfilled): + raise Exception("Unknown events stream row type %s" % (row.type,)) + + def invalidate_caches_for_event( + self, + stream_ordering, + event_id, + room_id, + etype, + state_key, + redacts, + relates_to, + backfilled, + ): self._invalidate_get_event_cache(event_id) self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_event_push_actions_by_room_for_user.invalidate_many( - (room_id,) - ) + self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: - self._events_stream_cache.entity_has_changed( - room_id, stream_ordering - ) + self._events_stream_cache.entity_has_changed(room_id, stream_ordering) if redacts: self._invalidate_get_event_cache(redacts) if etype == EventTypes.Member: - self._membership_stream_cache.entity_has_changed( - state_key, stream_ordering - ) + self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) self.get_invited_rooms_for_user.invalidate((state_key,)) if relates_to: diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index e933b170bb15..28a46edd2869 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -27,10 +27,11 @@ def __init__(self, db_conn, hs): self.hs = hs self._group_updates_id_gen = SlavedIdTracker( - db_conn, "local_group_updates", "stream_id", + db_conn, "local_group_updates", "stream_id" ) self._group_updates_stream_cache = StreamChangeCache( - "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(), + "_group_updates_stream_cache", + self._group_updates_id_gen.get_current_token(), ) get_groups_changes_for_user = __func__(DataStore.get_groups_changes_for_user) @@ -46,9 +47,7 @@ def process_replication_rows(self, stream_name, token, rows): if stream_name == "groups": self._group_updates_id_gen.advance(token) for row in rows: - self._group_updates_stream_cache.entity_has_changed( - row.user_id, token - ) + self._group_updates_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedGroupServerStore, self).process_replication_rows( stream_name, token, rows diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 0ec1db25ce6d..82d808af4cdc 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -24,9 +24,7 @@ class SlavedPresenceStore(BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedPresenceStore, self).__init__(db_conn, hs) - self._presence_id_gen = SlavedIdTracker( - db_conn, "presence_stream", "stream_id", - ) + self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_on_startup = self._get_active_presence(db_conn) @@ -55,9 +53,7 @@ def process_replication_rows(self, stream_name, token, rows): if stream_name == "presence": self._presence_id_gen.advance(token) for row in rows: - self.presence_stream_cache.entity_has_changed( - row.user_id, token - ) + self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) return super(SlavedPresenceStore, self).process_replication_rows( stream_name, token, rows diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 45fc913c52fc..af7012702e90 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -23,7 +23,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def __init__(self, db_conn, hs): self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id", + db_conn, "push_rules_stream", "stream_id" ) super(SlavedPushRuleStore, self).__init__(db_conn, hs) @@ -47,9 +47,7 @@ def process_replication_rows(self, stream_name, token, rows): for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) - self.push_rules_stream_cache.entity_has_changed( - row.user_id, token - ) + self.push_rules_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedPushRuleStore, self).process_replication_rows( stream_name, token, rows ) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 3b2213c0d43c..8eeb267d6167 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -21,12 +21,10 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): super(SlavedPusherStore, self).__init__(db_conn, hs) self._pushers_id_gen = SlavedIdTracker( - db_conn, "pushers", "id", - extra_tables=[("deleted_pushers", "stream_id")], + db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) def stream_positions(self): diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index ed12342f407c..91afa5a72b30 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -29,7 +29,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 0cb474928c12..f68b3378e3c8 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -38,6 +38,4 @@ def process_replication_rows(self, stream_name, token, rows): if stream_name == "public_rooms": self._public_room_id_gen.advance(token) - return super(RoomStore, self).process_replication_rows( - stream_name, token, rows - ) + return super(RoomStore, self).process_replication_rows(stream_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 206dc3b3971e..a44ceb00e71e 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -39,6 +39,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): Accepts a handler that will be called when new data is available or data is required. """ + maxDelay = 30 # Try at least once every N seconds def __init__(self, hs, client_name, handler): @@ -64,9 +65,7 @@ def clientConnectionLost(self, connector, reason): def clientConnectionFailed(self, connector, reason): logger.error("Failed to connect to replication: %r", reason) - ReconnectingClientFactory.clientConnectionFailed( - self, connector, reason - ) + ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) class ReplicationClientHandler(object): @@ -74,6 +73,7 @@ class ReplicationClientHandler(object): By default proxies incoming replication data to the SlaveStore. """ + def __init__(self, store): self.store = store diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 2098c32a77ea..0ff2a7199fcd 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -23,9 +23,11 @@ if platform.python_implementation() == "PyPy": import json + _json_encoder = json.JSONEncoder() else: import simplejson as json + _json_encoder = json.JSONEncoder(namedtuple_as_object=False) logger = logging.getLogger(__name__) @@ -41,6 +43,7 @@ class Command(object): The default implementation creates a command of form ` ` """ + NAME = None def __init__(self, data): @@ -73,6 +76,7 @@ class ServerCommand(Command): SERVER """ + NAME = "SERVER" @@ -99,6 +103,7 @@ class RdataCommand(Command): RDATA presence batch ["@bar:example.com", "online", ...] RDATA presence 59 ["@baz:example.com", "online", ...] """ + NAME = "RDATA" def __init__(self, stream_name, token, row): @@ -110,17 +115,17 @@ def __init__(self, stream_name, token, row): def from_line(cls, line): stream_name, token, row_json = line.split(" ", 2) return cls( - stream_name, - None if token == "batch" else int(token), - json.loads(row_json) + stream_name, None if token == "batch" else int(token), json.loads(row_json) ) def to_line(self): - return " ".join(( - self.stream_name, - str(self.token) if self.token is not None else "batch", - _json_encoder.encode(self.row), - )) + return " ".join( + ( + self.stream_name, + str(self.token) if self.token is not None else "batch", + _json_encoder.encode(self.row), + ) + ) def get_logcontext_id(self): return "RDATA-" + self.stream_name @@ -133,6 +138,7 @@ class PositionCommand(Command): Sent to the client after all missing updates for a stream have been sent to the client and they're now up to date. """ + NAME = "POSITION" def __init__(self, stream_name, token): @@ -145,19 +151,21 @@ def from_line(cls, line): return cls(stream_name, int(token)) def to_line(self): - return " ".join((self.stream_name, str(self.token),)) + return " ".join((self.stream_name, str(self.token))) class ErrorCommand(Command): """Sent by either side if there was an ERROR. The data is a string describing the error. """ + NAME = "ERROR" class PingCommand(Command): """Sent by either side as a keep alive. The data is arbitary (often timestamp) """ + NAME = "PING" @@ -165,6 +173,7 @@ class NameCommand(Command): """Sent by client to inform the server of the client's identity. The data is the name """ + NAME = "NAME" @@ -184,6 +193,7 @@ class ReplicateCommand(Command): REPLICATE ALL NOW """ + NAME = "REPLICATE" def __init__(self, stream_name, token): @@ -200,7 +210,7 @@ def from_line(cls, line): return cls(stream_name, token) def to_line(self): - return " ".join((self.stream_name, str(self.token),)) + return " ".join((self.stream_name, str(self.token))) def get_logcontext_id(self): return "REPLICATE-" + self.stream_name @@ -218,6 +228,7 @@ class UserSyncCommand(Command): Where is either "start" or "stop" """ + NAME = "USER_SYNC" def __init__(self, user_id, is_syncing, last_sync_ms): @@ -235,9 +246,13 @@ def from_line(cls, line): return cls(user_id, state == "start", int(last_sync_ms)) def to_line(self): - return " ".join(( - self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), - )) + return " ".join( + ( + self.user_id, + "start" if self.is_syncing else "end", + str(self.last_sync_ms), + ) + ) class FederationAckCommand(Command): @@ -251,6 +266,7 @@ class FederationAckCommand(Command): FEDERATION_ACK """ + NAME = "FEDERATION_ACK" def __init__(self, token): @@ -268,6 +284,7 @@ class SyncCommand(Command): """Used for testing. The client protocol implementation allows waiting on a SYNC command with a specified data. """ + NAME = "SYNC" @@ -278,6 +295,7 @@ class RemovePusherCommand(Command): REMOVE_PUSHER """ + NAME = "REMOVE_PUSHER" def __init__(self, app_id, push_key, user_id): @@ -309,6 +327,7 @@ class InvalidateCacheCommand(Command): Where is a json list. """ + NAME = "INVALIDATE_CACHE" def __init__(self, cache_func, keys): @@ -322,9 +341,7 @@ def from_line(cls, line): return cls(cache_func, json.loads(keys_json)) def to_line(self): - return " ".join(( - self.cache_func, _json_encoder.encode(self.keys), - )) + return " ".join((self.cache_func, _json_encoder.encode(self.keys))) class UserIpCommand(Command): @@ -334,6 +351,7 @@ class UserIpCommand(Command): USER_IP , , , , , """ + NAME = "USER_IP" def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen): @@ -350,15 +368,22 @@ def from_line(cls, line): access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) - return cls( - user_id, access_token, ip, user_agent, device_id, last_seen - ) + return cls(user_id, access_token, ip, user_agent, device_id, last_seen) def to_line(self): - return self.user_id + " " + _json_encoder.encode(( - self.access_token, self.ip, self.user_agent, self.device_id, - self.last_seen, - )) + return ( + self.user_id + + " " + + _json_encoder.encode( + ( + self.access_token, + self.ip, + self.user_agent, + self.device_id, + self.last_seen, + ) + ) + ) # Map of command name to command type. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index b51590cf8f22..97efb835ad7b 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -84,7 +84,8 @@ from .streams import STREAMS_MAP connection_close_counter = Counter( - "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]) + "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] +) # A list of all connected protocols. This allows us to send metrics about the # connections. @@ -119,7 +120,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) """ - delimiter = b'\n' + + delimiter = b"\n" VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send @@ -183,10 +185,14 @@ def send_ping(self): if now - self.last_sent_command >= PING_TIME: self.send_command(PingCommand(now)) - if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS: + if ( + self.received_ping + and now - self.last_received_command > PING_TIMEOUT_MS + ): logger.info( "[%s] Connection hasn't received command in %r ms. Closing.", - self.id(), now - self.last_received_command + self.id(), + now - self.last_received_command, ) self.send_error("ping timeout") @@ -208,7 +214,8 @@ def lineReceived(self, line): self.last_received_command = self.clock.time_msec() self.inbound_commands_counter[cmd_name] = ( - self.inbound_commands_counter[cmd_name] + 1) + self.inbound_commands_counter[cmd_name] + 1 + ) cmd_cls = COMMAND_MAP[cmd_name] try: @@ -224,9 +231,7 @@ def lineReceived(self, line): # Now lets try and call on_ function run_as_background_process( - "replication-" + cmd.get_logcontext_id(), - self.handle_command, - cmd, + "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd ) def handle_command(self, cmd): @@ -274,8 +279,9 @@ def send_command(self, cmd, do_buffer=True): return self.outbound_commands_counter[cmd.NAME] = ( - self.outbound_commands_counter[cmd.NAME] + 1) - string = "%s %s" % (cmd.NAME, cmd.to_line(),) + self.outbound_commands_counter[cmd.NAME] + 1 + ) + string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -283,10 +289,8 @@ def send_command(self, cmd, do_buffer=True): if len(encoded_string) > self.MAX_LENGTH: raise Exception( - "Failed to send command %s as too long (%d > %d)" % ( - cmd.NAME, - len(encoded_string), self.MAX_LENGTH, - ) + "Failed to send command %s as too long (%d > %d)" + % (cmd.NAME, len(encoded_string), self.MAX_LENGTH) ) self.sendLine(encoded_string) @@ -379,7 +383,9 @@ def __str__(self): if self.transport: addr = str(self.transport.getPeer()) return "ReplicationConnection" % ( - self.name, self.conn_id, addr, + self.name, + self.conn_id, + addr, ) def id(self): @@ -422,7 +428,7 @@ def on_NAME(self, cmd): def on_USER_SYNC(self, cmd): return self.streamer.on_user_sync( - self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms, + self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) def on_REPLICATE(self, cmd): @@ -432,10 +438,7 @@ def on_REPLICATE(self, cmd): if stream_name == "ALL": # Subscribe to all streams we're publishing to. deferreds = [ - run_in_background( - self.subscribe_to_stream, - stream, token, - ) + run_in_background(self.subscribe_to_stream, stream, token) for stream in iterkeys(self.streamer.streams_by_name) ] @@ -449,16 +452,18 @@ def on_FEDERATION_ACK(self, cmd): return self.streamer.federation_ack(cmd.token) def on_REMOVE_PUSHER(self, cmd): - return self.streamer.on_remove_pusher( - cmd.app_id, cmd.push_key, cmd.user_id, - ) + return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) def on_INVALIDATE_CACHE(self, cmd): return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) def on_USER_IP(self, cmd): return self.streamer.on_user_ip( - cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id, + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, cmd.last_seen, ) @@ -476,7 +481,7 @@ def subscribe_to_stream(self, stream_name, token): try: # Get missing updates updates, current_token = yield self.streamer.get_stream_updates( - stream_name, token, + stream_name, token ) # Send all the missing updates @@ -608,8 +613,7 @@ def on_RDATA(self, cmd): row = STREAMS_MAP[stream_name].parse_row(cmd.row) except Exception: logger.exception( - "[%s] Failed to parse RDATA: %r %r", - self.id(), stream_name, cmd.row + "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row ) raise @@ -643,7 +647,9 @@ def replicate(self, stream_name, token): logger.info( "[%s] Subscribing to replication stream: %r from %r", - self.id(), stream_name, token + self.id(), + stream_name, + token, ) self.streams_connecting.add(stream_name) @@ -661,9 +667,7 @@ def on_connection_closed(self): "synapse_replication_tcp_protocol_pending_commands", "", ["name"], - lambda: { - (p.name,): len(p.pending_commands) for p in connected_connections - }, + lambda: {(p.name,): len(p.pending_commands) for p in connected_connections}, ) @@ -678,9 +682,7 @@ def transport_buffer_size(protocol): "synapse_replication_tcp_protocol_transport_send_buffer", "", ["name"], - lambda: { - (p.name,): transport_buffer_size(p) for p in connected_connections - }, + lambda: {(p.name,): transport_buffer_size(p) for p in connected_connections}, ) @@ -694,7 +696,7 @@ def transport_kernel_read_buffer_size(protocol, read=True): op = SIOCINQ else: op = SIOCOUTQ - size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0] + size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0] return size return 0 @@ -726,7 +728,7 @@ def transport_kernel_read_buffer_size(protocol, read=True): "", ["command", "name"], lambda: { - (k, p.name,): count + (k, p.name): count for p in connected_connections for k, count in iteritems(p.inbound_commands_counter) }, @@ -737,7 +739,7 @@ def transport_kernel_read_buffer_size(protocol, read=True): "", ["command", "name"], lambda: { - (k, p.name,): count + (k, p.name): count for p in connected_connections for k, count in iteritems(p.outbound_commands_counter) }, diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index f6a38f5140bf..d1e98428bcae 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -33,13 +33,15 @@ from .streams import STREAMS_MAP from .streams.federation import FederationStream -stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates", - "", ["stream_name"]) +stream_updates_counter = Counter( + "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] +) user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache", - "") +invalidate_cache_counter = Counter( + "synapse_replication_tcp_resource_invalidate_cache", "" +) user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -48,6 +50,7 @@ class ReplicationStreamProtocolFactory(Factory): """Factory for new replication connections. """ + def __init__(self, hs): self.streamer = ReplicationStreamer(hs) self.clock = hs.get_clock() @@ -55,9 +58,7 @@ def __init__(self, hs): def buildProtocol(self, addr): return ServerReplicationStreamProtocol( - self.server_name, - self.clock, - self.streamer, + self.server_name, self.clock, self.streamer ) @@ -80,29 +81,39 @@ def __init__(self, hs): # Current connections. self.connections = [] - LaterGauge("synapse_replication_tcp_resource_total_connections", "", [], - lambda: len(self.connections)) + LaterGauge( + "synapse_replication_tcp_resource_total_connections", + "", + [], + lambda: len(self.connections), + ) # List of streams that clients can subscribe to. # We only support federation stream if federation sending hase been # disabled on the master. self.streams = [ - stream(hs) for stream in itervalues(STREAMS_MAP) + stream(hs) + for stream in itervalues(STREAMS_MAP) if stream != FederationStream or not hs.config.send_federation ] self.streams_by_name = {stream.NAME: stream for stream in self.streams} LaterGauge( - "synapse_replication_tcp_resource_connections_per_stream", "", + "synapse_replication_tcp_resource_connections_per_stream", + "", ["stream_name"], lambda: { - (stream_name,): len([ - conn for conn in self.connections - if stream_name in conn.replication_streams - ]) + (stream_name,): len( + [ + conn + for conn in self.connections + if stream_name in conn.replication_streams + ] + ) for stream_name in self.streams_by_name - }) + }, + ) self.federation_sender = None if not hs.config.send_federation: @@ -179,7 +190,9 @@ def _run_notifier_loop(self): logger.debug( "Getting stream: %s: %s -> %s", - stream.NAME, stream.last_token, stream.upto_token + stream.NAME, + stream.last_token, + stream.upto_token, ) try: updates, current_token = yield stream.get_updates() @@ -189,7 +202,8 @@ def _run_notifier_loop(self): logger.debug( "Sending %d updates to %d connections", - len(updates), len(self.connections), + len(updates), + len(self.connections), ) if updates: @@ -243,7 +257,7 @@ def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): """ user_sync_counter.inc() yield self.presence_handler.update_external_syncs_row( - conn_id, user_id, is_syncing, last_sync_ms, + conn_id, user_id, is_syncing, last_sync_ms ) @measure_func("repl.on_remove_pusher") @@ -272,7 +286,7 @@ def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen """ user_ip_cache_counter.inc() yield self.store.insert_client_ip( - user_id, access_token, ip, user_agent, device_id, last_seen, + user_id, access_token, ip, user_agent, device_id, last_seen ) yield self._server_notices_sender.on_user_ip(user_id) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index b6ce7a7beee7..7ef67a5a73fc 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -26,78 +26,75 @@ MAX_EVENTS_BEHIND = 10000 -BackfillStreamRow = namedtuple("BackfillStreamRow", ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional -)) -PresenceStreamRow = namedtuple("PresenceStreamRow", ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool -)) -TypingStreamRow = namedtuple("TypingStreamRow", ( - "room_id", # str - "user_ids", # list(str) -)) -ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict -)) -PushRulesStreamRow = namedtuple("PushRulesStreamRow", ( - "user_id", # str -)) -PushersStreamRow = namedtuple("PushersStreamRow", ( - "user_id", # str - "app_id", # str - "pushkey", # str - "deleted", # bool -)) -CachesStreamRow = namedtuple("CachesStreamRow", ( - "cache_func", # str - "keys", # list(str) - "invalidation_ts", # int -)) -PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", ( - "room_id", # str - "visibility", # str - "appservice_id", # str, optional - "network_id", # str, optional -)) -DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ( - "user_id", # str - "destination", # str -)) -ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ( - "entity", # str -)) -TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", ( - "user_id", # str - "room_id", # str - "data", # dict -)) -AccountDataStreamRow = namedtuple("AccountDataStream", ( - "user_id", # str - "room_id", # str - "data_type", # str - "data", # dict -)) -GroupsStreamRow = namedtuple("GroupsStreamRow", ( - "group_id", # str - "user_id", # str - "type", # str - "content", # dict -)) +BackfillStreamRow = namedtuple( + "BackfillStreamRow", + ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional + "relates_to", # str, optional + ), +) +PresenceStreamRow = namedtuple( + "PresenceStreamRow", + ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool + ), +) +TypingStreamRow = namedtuple( + "TypingStreamRow", ("room_id", "user_ids") # str # list(str) +) +ReceiptsStreamRow = namedtuple( + "ReceiptsStreamRow", + ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict + ), +) +PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str +PushersStreamRow = namedtuple( + "PushersStreamRow", + ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool +) +CachesStreamRow = namedtuple( + "CachesStreamRow", + ("cache_func", "keys", "invalidation_ts"), # str # list(str) # int +) +PublicRoomsStreamRow = namedtuple( + "PublicRoomsStreamRow", + ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional + ), +) +DeviceListsStreamRow = namedtuple( + "DeviceListsStreamRow", ("user_id", "destination") # str # str +) +ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str +TagAccountDataStreamRow = namedtuple( + "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict +) +AccountDataStreamRow = namedtuple( + "AccountDataStream", + ("user_id", "room_id", "data_type", "data"), # str # str # str # dict +) +GroupsStreamRow = namedtuple( + "GroupsStreamRow", + ("group_id", "user_id", "type", "content"), # str # str # str # dict +) class Stream(object): @@ -106,6 +103,7 @@ class Stream(object): Provides a `get_updates()` function that returns new updates since the last time it was called up until the point `advance_current_token` was called. """ + NAME = None # The name of the stream ROW_TYPE = None # The type of the row. Used by the default impl of parse_row. _LIMITED = True # Whether the update function takes a limit @@ -185,16 +183,13 @@ def get_updates_since(self, from_token): if self._LIMITED: rows = yield self.update_function( - from_token, current_token, - limit=MAX_EVENTS_BEHIND + 1, + from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 ) # never turn more than MAX_EVENTS_BEHIND + 1 into updates. rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) else: - rows = yield self.update_function( - from_token, current_token, - ) + rows = yield self.update_function(from_token, current_token) updates = [(row[0], row[1:]) for row in rows] @@ -230,6 +225,7 @@ class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. """ + NAME = "backfill" ROW_TYPE = BackfillStreamRow @@ -286,6 +282,7 @@ def __init__(self, hs): class PushRulesStream(Stream): """A user has changed their push rules """ + NAME = "push_rules" ROW_TYPE = PushRulesStreamRow @@ -306,6 +303,7 @@ def update_function(self, from_token, to_token, limit): class PushersStream(Stream): """A user has added/changed/removed a pusher """ + NAME = "pushers" ROW_TYPE = PushersStreamRow @@ -322,6 +320,7 @@ class CachesStream(Stream): """A cache was invalidated on the master and no other stream would invalidate the cache on the workers """ + NAME = "caches" ROW_TYPE = CachesStreamRow @@ -337,6 +336,7 @@ def __init__(self, hs): class PublicRoomsStream(Stream): """The public rooms list changed """ + NAME = "public_rooms" ROW_TYPE = PublicRoomsStreamRow @@ -352,6 +352,7 @@ def __init__(self, hs): class DeviceListsStream(Stream): """Someone added/changed/removed a device """ + NAME = "device_lists" _LIMITED = False ROW_TYPE = DeviceListsStreamRow @@ -368,6 +369,7 @@ def __init__(self, hs): class ToDeviceStream(Stream): """New to_device messages for a client """ + NAME = "to_device" ROW_TYPE = ToDeviceStreamRow @@ -383,6 +385,7 @@ def __init__(self, hs): class TagAccountDataStream(Stream): """Someone added/removed a tag for a room """ + NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow @@ -398,6 +401,7 @@ def __init__(self, hs): class AccountDataStream(Stream): """Global or per room account data was changed """ + NAME = "account_data" ROW_TYPE = AccountDataStreamRow @@ -416,7 +420,7 @@ def update_function(self, from_token, to_token, limit): results = list(room_results) results.extend( - (stream_id, user_id, None, account_data_type, content,) + (stream_id, user_id, None, account_data_type, content) for stream_id, user_id, account_data_type, content in global_results ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index f1290d022a21..3d0694bb1121 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -52,6 +52,7 @@ @attr.s(slots=True, frozen=True) class EventsStreamRow(object): """A parsed row from the events replication stream""" + type = attr.ib() # str: the TypeId of one of the *EventsStreamRows data = attr.ib() # BaseEventsStreamRow @@ -80,11 +81,11 @@ def from_data(cls, data): class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib() # str - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str, optional - redacts = attr.ib() # str, optional + event_id = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str + state_key = attr.ib() # str, optional + redacts = attr.ib() # str, optional relates_to = attr.ib() # str, optional @@ -92,24 +93,21 @@ class EventsStreamEventRow(BaseEventsStreamRow): class EventsStreamCurrentStateRow(BaseEventsStreamRow): TypeId = "state" - room_id = attr.ib() # str - type = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str state_key = attr.ib() # str - event_id = attr.ib() # str, optional + event_id = attr.ib() # str, optional TypeToRow = { - Row.TypeId: Row - for Row in ( - EventsStreamEventRow, - EventsStreamCurrentStateRow, - ) + Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow) } class EventsStream(Stream): """We received a new event, or an event went from being an outlier to not """ + NAME = "events" def __init__(self, hs): @@ -121,19 +119,17 @@ def __init__(self, hs): @defer.inlineCallbacks def update_function(self, from_token, current_token, limit=None): event_rows = yield self._store.get_all_new_forward_event_rows( - from_token, current_token, limit, + from_token, current_token, limit ) event_updates = ( - (row[0], EventsStreamEventRow.TypeId, row[1:]) - for row in event_rows + (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows ) state_rows = yield self._store.get_all_updated_current_state_deltas( from_token, current_token, limit ) state_updates = ( - (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) - for row in state_rows + (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows ) all_updates = heapq.merge(event_updates, state_updates) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 9aa43aa8d231..dc2484109d1c 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -17,16 +17,20 @@ from ._base import Stream -FederationStreamRow = namedtuple("FederationStreamRow", ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow -)) +FederationStreamRow = namedtuple( + "FederationStreamRow", + ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow + ), +) class FederationStream(Stream): """Data to be sent over federation. Only available when master has federation sending disabled. """ + NAME = "federation" ROW_TYPE = FederationStreamRow diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index e6110ad9b1aa..1d20b96d0354 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -66,6 +66,7 @@ class ClientRestResource(JsonResource): * /_matrix/client/unstable * etc """ + def __init__(self, hs): JsonResource.__init__(self, hs, canonical_json=False) self.register_servlets(self, hs) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index d6c4dcdb1816..9843a902c69d 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -61,7 +61,7 @@ def historical_admin_path_patterns(path_regex): "^/_synapse/admin/v1", "^/_matrix/client/api/v1/admin", "^/_matrix/client/unstable/admin", - "^/_matrix/client/r0/admin" + "^/_matrix/client/r0/admin", ) ) @@ -88,12 +88,12 @@ def on_GET(self, request, user_id): class VersionServlet(RestServlet): - PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"), ) + PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),) def __init__(self, hs): self.res = { - 'server_version': get_version_string(synapse), - 'python_version': platform.python_version(), + "server_version": get_version_string(synapse), + "python_version": platform.python_version(), } def on_GET(self, request): @@ -107,6 +107,7 @@ class UserRegisterServlet(RestServlet): nonces (dict[str, int]): The nonces that we will accept. A dict of nonce to the time it was generated, in int seconds. """ + PATTERNS = historical_admin_path_patterns("/register") NONCE_TIMEOUT = 60 @@ -146,28 +147,24 @@ def on_POST(self, request): body = parse_json_object_from_request(request) if "nonce" not in body: - raise SynapseError( - 400, "nonce must be specified", errcode=Codes.BAD_JSON, - ) + raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON) nonce = body["nonce"] if nonce not in self.nonces: - raise SynapseError( - 400, "unrecognised nonce", - ) + raise SynapseError(400, "unrecognised nonce") # Delete the nonce, so it can't be reused, even if it's invalid del self.nonces[nonce] if "username" not in body: raise SynapseError( - 400, "username must be specified", errcode=Codes.BAD_JSON, + 400, "username must be specified", errcode=Codes.BAD_JSON ) else: if ( - not isinstance(body['username'], text_type) - or len(body['username']) > 512 + not isinstance(body["username"], text_type) + or len(body["username"]) > 512 ): raise SynapseError(400, "Invalid username") @@ -177,12 +174,12 @@ def on_POST(self, request): if "password" not in body: raise SynapseError( - 400, "password must be specified", errcode=Codes.BAD_JSON, + 400, "password must be specified", errcode=Codes.BAD_JSON ) else: if ( - not isinstance(body['password'], text_type) - or len(body['password']) > 512 + not isinstance(body["password"], text_type) + or len(body["password"]) > 512 ): raise SynapseError(400, "Invalid password") @@ -202,7 +199,7 @@ def on_POST(self, request): key=self.hs.config.registration_shared_secret.encode(), digestmod=hashlib.sha1, ) - want_mac.update(nonce.encode('utf8')) + want_mac.update(nonce.encode("utf8")) want_mac.update(b"\x00") want_mac.update(username) want_mac.update(b"\x00") @@ -211,13 +208,10 @@ def on_POST(self, request): want_mac.update(b"admin" if admin else b"notadmin") if user_type: want_mac.update(b"\x00") - want_mac.update(user_type.encode('utf8')) + want_mac.update(user_type.encode("utf8")) want_mac = want_mac.hexdigest() - if not hmac.compare_digest( - want_mac.encode('ascii'), - got_mac.encode('ascii') - ): + if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")): raise SynapseError(403, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication @@ -226,7 +220,7 @@ def on_POST(self, request): register = RegisterRestServlet(self.hs) (user_id, _) = yield register.registration_handler.register( - localpart=body['username'].lower(), + localpart=body["username"].lower(), password=body["password"], admin=bool(admin), generate_token=False, @@ -308,7 +302,7 @@ def on_POST(self, request, room_id, event_id): # user can provide an event_id in the URL or the request body, or can # provide a timestamp in the request body. if event_id is None: - event_id = body.get('purge_up_to_event_id') + event_id = body.get("purge_up_to_event_id") if event_id is not None: event = yield self.store.get_event(event_id) @@ -318,44 +312,39 @@ def on_POST(self, request, room_id, event_id): token = yield self.store.get_topological_token_for_event(event_id) - logger.info( - "[purge] purging up to token %s (event_id %s)", - token, event_id, - ) - elif 'purge_up_to_ts' in body: - ts = body['purge_up_to_ts'] + logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) + elif "purge_up_to_ts" in body: + ts = body["purge_up_to_ts"] if not isinstance(ts, int): raise SynapseError( - 400, "purge_up_to_ts must be an int", - errcode=Codes.BAD_JSON, + 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON ) - stream_ordering = ( - yield self.store.find_first_stream_ordering_after_ts(ts) - ) + stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts)) r = ( yield self.store.get_room_event_after_stream_ordering( - room_id, stream_ordering, + room_id, stream_ordering ) ) if not r: logger.warn( "[purge] purging events not possible: No event found " "(received_ts %i => stream_ordering %i)", - ts, stream_ordering, + ts, + stream_ordering, ) raise SynapseError( - 404, - "there is no event to be purged", - errcode=Codes.NOT_FOUND, + 404, "there is no event to be purged", errcode=Codes.NOT_FOUND ) (stream, topo, _event_id) = r token = "t%d-%d" % (topo, stream) logger.info( "[purge] purging up to token %s (received_ts %i => " "stream_ordering %i)", - token, ts, stream_ordering, + token, + ts, + stream_ordering, ) else: raise SynapseError( @@ -365,13 +354,10 @@ def on_POST(self, request, room_id, event_id): ) purge_id = yield self.pagination_handler.start_purge_history( - room_id, token, - delete_local_events=delete_local_events, + room_id, token, delete_local_events=delete_local_events ) - defer.returnValue((200, { - "purge_id": purge_id, - })) + defer.returnValue((200, {"purge_id": purge_id})) class PurgeHistoryStatusRestServlet(RestServlet): @@ -421,16 +407,14 @@ def on_POST(self, request, target_user_id): UserID.from_string(target_user_id) result = yield self._deactivate_account_handler.deactivate_account( - target_user_id, erase, + target_user_id, erase ) if result: id_server_unbind_result = "success" else: id_server_unbind_result = "no-support" - defer.returnValue((200, { - "id_server_unbind_result": id_server_unbind_result, - })) + defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) class ShutdownRoomRestServlet(RestServlet): @@ -439,6 +423,7 @@ class ShutdownRoomRestServlet(RestServlet): to a new room created by `new_room_user_id` and kicked users will be auto joined to the new room. """ + PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P[^/]+)") DEFAULT_MESSAGE = ( @@ -474,9 +459,7 @@ def on_POST(self, request, room_id): config={ "preset": "public_chat", "name": room_name, - "power_level_content_override": { - "users_default": -10, - }, + "power_level_content_override": {"users_default": -10}, }, ratelimit=False, ) @@ -485,8 +468,7 @@ def on_POST(self, request, room_id): requester_user_id = requester.user.to_string() logger.info( - "Shutting down room %r, joining to new room: %r", - room_id, new_room_id, + "Shutting down room %r, joining to new room: %r", room_id, new_room_id ) # This will work even if the room is already blocked, but that is @@ -529,7 +511,7 @@ def on_POST(self, request, room_id): kicked_users.append(user_id) except Exception: logger.exception( - "Failed to leave old room and join new room for %r", user_id, + "Failed to leave old room and join new room for %r", user_id ) failed_to_kick_users.append(user_id) @@ -550,18 +532,24 @@ def on_POST(self, request, room_id): room_id, new_room_id, requester_user_id ) - defer.returnValue((200, { - "kicked_users": kicked_users, - "failed_to_kick_users": failed_to_kick_users, - "local_aliases": aliases_for_room, - "new_room_id": new_room_id, - })) + defer.returnValue( + ( + 200, + { + "kicked_users": kicked_users, + "failed_to_kick_users": failed_to_kick_users, + "local_aliases": aliases_for_room, + "new_room_id": new_room_id, + }, + ) + ) class QuarantineMediaInRoom(RestServlet): """Quarantines all media in a room so that no one can download it via this server. """ + PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P[^/]+)") def __init__(self, hs): @@ -574,7 +562,7 @@ def on_POST(self, request, room_id): yield assert_user_is_admin(self.auth, requester.user) num_quarantined = yield self.store.quarantine_media_ids_in_room( - room_id, requester.user.to_string(), + room_id, requester.user.to_string() ) defer.returnValue((200, {"num_quarantined": num_quarantined})) @@ -583,6 +571,7 @@ def on_POST(self, request, room_id): class ListMediaInRoom(RestServlet): """Lists all of the media in a given room. """ + PATTERNS = historical_admin_path_patterns("/room/(?P[^/]+)/media") def __init__(self, hs): @@ -613,7 +602,10 @@ class ResetPasswordRestServlet(RestServlet): Returns: 200 OK with empty object if success otherwise an error. """ - PATTERNS = historical_admin_path_patterns("/reset_password/(?P[^/]*)") + + PATTERNS = historical_admin_path_patterns( + "/reset_password/(?P[^/]*)" + ) def __init__(self, hs): self.store = hs.get_datastore() @@ -633,7 +625,7 @@ def on_POST(self, request, target_user_id): params = parse_json_object_from_request(request) assert_params_in_dict(params, ["new_password"]) - new_password = params['new_password'] + new_password = params["new_password"] yield self._set_password_handler.set_password( target_user_id, new_password, requester @@ -650,7 +642,10 @@ class GetUsersPaginatedRestServlet(RestServlet): Returns: 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - PATTERNS = historical_admin_path_patterns("/users_paginate/(?P[^/]*)") + + PATTERNS = historical_admin_path_patterns( + "/users_paginate/(?P[^/]*)" + ) def __init__(self, hs): self.store = hs.get_datastore() @@ -676,9 +671,7 @@ def on_GET(self, request, target_user_id): logger.info("limit: %s, start: %s", limit, start) - ret = yield self.handlers.admin_handler.get_users_paginate( - order, start, limit - ) + ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) defer.returnValue((200, ret)) @defer.inlineCallbacks @@ -702,13 +695,11 @@ def on_POST(self, request, target_user_id): order = "name" # order by name in user table params = parse_json_object_from_request(request) assert_params_in_dict(params, ["limit", "start"]) - limit = params['limit'] - start = params['start'] + limit = params["limit"] + start = params["start"] logger.info("limit: %s, start: %s", limit, start) - ret = yield self.handlers.admin_handler.get_users_paginate( - order, start, limit - ) + ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit) defer.returnValue((200, ret)) @@ -722,6 +713,7 @@ class SearchUsersRestServlet(RestServlet): Returns: 200 OK with json object {list[dict[str, Any]], count} or empty object. """ + PATTERNS = historical_admin_path_patterns("/search_users/(?P[^/]*)") def __init__(self, hs): @@ -750,15 +742,14 @@ def on_GET(self, request, target_user_id): term = parse_string(request, "term", required=True) logger.info("term: %s ", term) - ret = yield self.handlers.admin_handler.search_users( - term - ) + ret = yield self.handlers.admin_handler.search_users(term) defer.returnValue((200, ret)) class DeleteGroupAdminRestServlet(RestServlet): """Allows deleting of local groups """ + PATTERNS = historical_admin_path_patterns("/delete_group/(?P[^/]*)") def __init__(self, hs): @@ -800,15 +791,15 @@ def on_POST(self, request): raise SynapseError(400, "Missing property 'user_id' in the request body") expiration_ts = yield self.account_activity_handler.renew_account_for_user( - body["user_id"], body.get("expiration_ts"), + body["user_id"], + body.get("expiration_ts"), not body.get("enable_renewal_emails", True), ) - res = { - "expiration_ts": expiration_ts, - } + res = {"expiration_ts": expiration_ts} defer.returnValue((200, res)) + ######################################################################################## # # please don't add more servlets here: this file is already long and unwieldy. Put diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index ae5aca9dac13..ee66838a0d34 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -46,6 +46,7 @@ class SendServerNoticeServlet(RestServlet): "event_id": "$1895723857jgskldgujpious" } """ + def __init__(self, hs): """ Args: @@ -58,15 +59,9 @@ def __init__(self, hs): def register(self, json_resource): PATTERN = "^/_synapse/admin/v1/send_server_notice" + json_resource.register_paths("POST", (re.compile(PATTERN + "$"),), self.on_POST) json_resource.register_paths( - "POST", - (re.compile(PATTERN + "$"), ), - self.on_POST, - ) - json_resource.register_paths( - "PUT", - (re.compile(PATTERN + "/(?P[^/]*)$",), ), - self.on_PUT, + "PUT", (re.compile(PATTERN + "/(?P[^/]*)$"),), self.on_PUT ) @defer.inlineCallbacks @@ -96,5 +91,5 @@ def on_POST(self, request, txn_id=None): def on_PUT(self, request, txn_id): return self.txns.fetch_or_execute_request( - request, self.on_POST, request, txn_id, + request, self.on_POST, request, txn_id ) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 48c17f1b6d41..36404b797dfe 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -26,7 +26,6 @@ class HttpTransactionCache(object): - def __init__(self, hs): self.hs = hs self.auth = self.hs.get_auth() @@ -53,7 +52,7 @@ def _get_transaction_key(self, request): str: A transaction key """ token = self.auth.get_access_token_from_request(request) - return request.path.decode('utf8') + "/" + token + return request.path.decode("utf8") + "/" + token def fetch_or_execute_request(self, request, fn, *args, **kwargs): """A helper function for fetch_or_execute which extracts diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 0035182bb91d..dd0d38ea5c12 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -56,8 +56,9 @@ def on_PUT(self, request, room_alias): content = parse_json_object_from_request(request) if "room_id" not in content: - raise SynapseError(400, 'Missing params: ["room_id"]', - errcode=Codes.BAD_JSON) + raise SynapseError( + 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON + ) logger.debug("Got content: %s", content) logger.debug("Got room name: %s", room_alias.to_string()) @@ -89,13 +90,11 @@ def on_DELETE(self, request, room_alias): try: service = yield self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_appservice_association( - service, room_alias - ) + yield dir_handler.delete_appservice_association(service, room_alias) logger.info( "Application service at %s deleted alias %s", service.url, - room_alias.to_string() + room_alias.to_string(), ) defer.returnValue((200, {})) except AuthError: @@ -107,14 +106,10 @@ def on_DELETE(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_association( - requester, room_alias - ) + yield dir_handler.delete_association(requester, room_alias) logger.info( - "User %s deleted alias %s", - user.to_string(), - room_alias.to_string() + "User %s deleted alias %s", user.to_string(), room_alias.to_string() ) defer.returnValue((200, {})) @@ -135,9 +130,9 @@ def on_GET(self, request, room_id): if room is None: raise NotFoundError("Unknown room") - defer.returnValue((200, { - "visibility": "public" if room["is_public"] else "private" - })) + defer.returnValue( + (200, {"visibility": "public" if room["is_public"] else "private"}) + ) @defer.inlineCallbacks def on_PUT(self, request, room_id): @@ -147,7 +142,7 @@ def on_PUT(self, request, room_id): visibility = content.get("visibility", "public") yield self.handlers.directory_handler.edit_published_room_list( - requester, room_id, visibility, + requester, room_id, visibility ) defer.returnValue((200, {})) @@ -157,7 +152,7 @@ def on_DELETE(self, request, room_id): requester = yield self.auth.get_user_by_req(request) yield self.handlers.directory_handler.edit_published_room_list( - requester, room_id, "private", + requester, room_id, "private" ) defer.returnValue((200, {})) @@ -191,7 +186,7 @@ def _edit(self, request, network_id, room_id, visibility): ) yield self.handlers.directory_handler.edit_published_appservice_room_list( - requester.app_service.id, network_id, room_id, visibility, + requester.app_service.id, network_id, room_id, visibility ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 84ca36270bf2..d6de2b73604b 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -38,17 +38,14 @@ def __init__(self, hs): @defer.inlineCallbacks def on_GET(self, request): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=True, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) is_guest = requester.is_guest room_id = None if is_guest: if b"room_id" not in request.args: raise SynapseError(400, "Guest users must specify room_id param") if b"room_id" in request.args: - room_id = request.args[b"room_id"][0].decode('ascii') + room_id = request.args[b"room_id"][0].decode("ascii") pagin_config = PaginationConfig.from_request(request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 3b60728628bb..4efb679a042a 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission): to a typed object. """ if "user" in submission: - submission["identifier"] = { - "type": "m.id.user", - "user": submission["user"], - } + submission["identifier"] = {"type": "m.id.user", "user": submission["user"]} del submission["user"] if "medium" in submission and "address" in submission: @@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier): msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"]) - return { - "type": "m.id.thirdparty", - "medium": "msisdn", - "address": msisdn, - } + return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} class LoginRestServlet(RestServlet): @@ -120,9 +113,9 @@ def on_GET(self, request): # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - flows.extend(( - {"type": t} for t in self.auth_handler.get_supported_login_types() - )) + flows.extend( + ({"type": t} for t in self.auth_handler.get_supported_login_types()) + ) return (200, {"flows": flows}) @@ -132,7 +125,8 @@ def on_OPTIONS(self, request): @defer.inlineCallbacks def on_POST(self, request): self._address_ratelimiter.ratelimit( - request.getClientIP(), time_now_s=self.hs.clock.time(), + request.getClientIP(), + time_now_s=self.hs.clock.time(), rate_hz=self.hs.config.rc_login_address.per_second, burst_count=self.hs.config.rc_login_address.burst_count, update=True, @@ -140,8 +134,9 @@ def on_POST(self, request): login_submission = parse_json_object_from_request(request) try: - if self.jwt_enabled and (login_submission["type"] == - LoginRestServlet.JWT_TYPE): + if self.jwt_enabled and ( + login_submission["type"] == LoginRestServlet.JWT_TYPE + ): result = yield self.do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: result = yield self.do_token_login(login_submission) @@ -170,10 +165,10 @@ def _do_other_login(self, login_submission): # field) logger.info( "Got login request with identifier: %r, medium: %r, address: %r, user: %r", - login_submission.get('identifier'), - login_submission.get('medium'), - login_submission.get('address'), - login_submission.get('user'), + login_submission.get("identifier"), + login_submission.get("medium"), + login_submission.get("address"), + login_submission.get("user"), ) login_submission_legacy_convert(login_submission) @@ -190,13 +185,13 @@ def _do_other_login(self, login_submission): # convert threepid identifiers to user IDs if identifier["type"] == "m.id.thirdparty": - address = identifier.get('address') - medium = identifier.get('medium') + address = identifier.get("address") + medium = identifier.get("medium") if medium is None or address is None: raise SynapseError(400, "Invalid thirdparty identifier") - if medium == 'email': + if medium == "email": # For emails, transform the address to lowercase. # We store all email addreses as lowercase in the DB. # (See add_threepid in synapse/handlers/auth.py) @@ -205,34 +200,28 @@ def _do_other_login(self, login_submission): # Check for login providers that support 3pid login types canonical_user_id, callback_3pid = ( yield self.auth_handler.check_password_provider_3pid( - medium, - address, - login_submission["password"], + medium, address, login_submission["password"] ) ) if canonical_user_id: # Authentication through password provider and 3pid succeeded result = yield self._register_device_with_callback( - canonical_user_id, login_submission, callback_3pid, + canonical_user_id, login_submission, callback_3pid ) defer.returnValue(result) # No password providers were able to handle this 3pid # Check local store user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - medium, address, + medium, address ) if not user_id: logger.warn( - "unknown 3pid identifier medium %s, address %r", - medium, address, + "unknown 3pid identifier medium %s, address %r", medium, address ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) - identifier = { - "type": "m.id.user", - "user": user_id, - } + identifier = {"type": "m.id.user", "user": user_id} # by this point, the identifier should be an m.id.user: if it's anything # else, we haven't understood it. @@ -242,22 +231,16 @@ def _do_other_login(self, login_submission): raise SynapseError(400, "User identifier is missing 'user' key") canonical_user_id, callback = yield self.auth_handler.validate_login( - identifier["user"], - login_submission, + identifier["user"], login_submission ) result = yield self._register_device_with_callback( - canonical_user_id, login_submission, callback, + canonical_user_id, login_submission, callback ) defer.returnValue(result) @defer.inlineCallbacks - def _register_device_with_callback( - self, - user_id, - login_submission, - callback=None, - ): + def _register_device_with_callback(self, user_id, login_submission, callback=None): """ Registers a device with a given user_id. Optionally run a callback function after registration has completed. @@ -273,7 +256,7 @@ def _register_device_with_callback( device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, + user_id, device_id, initial_display_name ) result = { @@ -290,7 +273,7 @@ def _register_device_with_callback( @defer.inlineCallbacks def do_token_login(self, login_submission): - token = login_submission['token'] + token = login_submission["token"] auth_handler = self.auth_handler user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) @@ -299,7 +282,7 @@ def do_token_login(self, login_submission): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, + user_id, device_id, initial_display_name ) result = { @@ -316,15 +299,16 @@ def do_jwt_login(self, login_submission): token = login_submission.get("token", None) if token is None: raise LoginError( - 401, "Token field for JWT is missing", - errcode=Codes.UNAUTHORIZED + 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED ) import jwt from jwt.exceptions import InvalidTokenError try: - payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm]) + payload = jwt.decode( + token, self.jwt_secret, algorithms=[self.jwt_algorithm] + ) except jwt.ExpiredSignatureError: raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED) except InvalidTokenError: @@ -342,7 +326,7 @@ def do_jwt_login(self, login_submission): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name, + registered_user_id, device_id, initial_display_name ) result = { @@ -358,7 +342,7 @@ def do_jwt_login(self, login_submission): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name, + registered_user_id, device_id, initial_display_name ) result = { @@ -375,21 +359,20 @@ class CasRedirectServlet(RestServlet): def __init__(self, hs): super(CasRedirectServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url.encode('ascii') - self.cas_service_url = hs.config.cas_service_url.encode('ascii') + self.cas_server_url = hs.config.cas_server_url.encode("ascii") + self.cas_service_url = hs.config.cas_service_url.encode("ascii") def on_GET(self, request): args = request.args if b"redirectUrl" not in args: return (400, "Redirect URL not specified for CAS auth") - client_redirect_url_param = urllib.parse.urlencode({ - b"redirectUrl": args[b"redirectUrl"][0] - }).encode('ascii') - hs_redirect_url = (self.cas_service_url + - b"/_matrix/client/r0/login/cas/ticket") - service_param = urllib.parse.urlencode({ - b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param) - }).encode('ascii') + client_redirect_url_param = urllib.parse.urlencode( + {b"redirectUrl": args[b"redirectUrl"][0]} + ).encode("ascii") + hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket" + service_param = urllib.parse.urlencode( + {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)} + ).encode("ascii") request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param)) finish_request(request) @@ -411,7 +394,7 @@ def on_GET(self, request): uri = self.cas_server_url + "/proxyValidate" args = { "ticket": parse_string(request, "ticket", required=True), - "service": self.cas_service_url + "service": self.cas_service_url, } try: body = yield self._http_client.get_raw(uri, args) @@ -438,7 +421,7 @@ def handle_cas_response(self, request, cas_response_body, client_redirect_url): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) return self._sso_auth_handler.on_successful_auth( - user, request, client_redirect_url, + user, request, client_redirect_url ) def parse_cas_response(self, cas_response_body): @@ -448,7 +431,7 @@ def parse_cas_response(self, cas_response_body): root = ET.fromstring(cas_response_body) if not root.tag.endswith("serviceResponse"): raise Exception("root of CAS response is not serviceResponse") - success = (root[0].tag.endswith("authenticationSuccess")) + success = root[0].tag.endswith("authenticationSuccess") for child in root[0]: if child.tag.endswith("user"): user = child.text @@ -466,11 +449,11 @@ def parse_cas_response(self, cas_response_body): raise Exception("CAS response does not contain user") except Exception: logger.error("Error parsing CAS response", exc_info=1) - raise LoginError(401, "Invalid CAS response", - errcode=Codes.UNAUTHORIZED) + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) if not success: - raise LoginError(401, "Unsuccessful CAS response", - errcode=Codes.UNAUTHORIZED) + raise LoginError( + 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED + ) return user, attributes @@ -482,6 +465,7 @@ class SSOAuthHandler(object): Args: hs (synapse.server.HomeServer) """ + def __init__(self, hs): self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() @@ -490,8 +474,7 @@ def __init__(self, hs): @defer.inlineCallbacks def on_successful_auth( - self, username, request, client_redirect_url, - user_display_name=None, + self, username, request, client_redirect_url, user_display_name=None ): """Called once the user has successfully authenticated with the SSO. diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index b8064f261e11..cd711be5190e 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -46,7 +46,8 @@ def on_POST(self, request): yield self._auth_handler.delete_access_token(access_token) else: yield self._device_handler.delete_device( - requester.user.to_string(), requester.device_id) + requester.user.to_string(), requester.device_id + ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index e263da3cb7e7..3e87f0fdb397 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -47,7 +47,7 @@ def on_GET(self, request, user_id): if requester.user != user: allowed = yield self.presence_handler.is_visible( - observed_user=user, observer_user=requester.user, + observed_user=user, observer_user=requester.user ) if not allowed: diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index e15d9d82a6b6..4d8ab1f47e9e 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -63,8 +63,7 @@ def on_PUT(self, request, user_id): except Exception: defer.returnValue((400, "Unable to parse name")) - yield self.profile_handler.set_displayname( - user, requester, new_name, is_admin) + yield self.profile_handler.set_displayname(user, requester, new_name, is_admin) defer.returnValue((200, {})) @@ -113,8 +112,7 @@ def on_PUT(self, request, user_id): except Exception: defer.returnValue((400, "Unable to parse name")) - yield self.profile_handler.set_avatar_url( - user, requester, new_name, is_admin) + yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 3d6326fe2fc3..e635efb420c5 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -21,7 +21,11 @@ SynapseError, UnrecognizedRequestError, ) -from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string +from synapse.http.servlet import ( + RestServlet, + parse_json_value_from_request, + parse_string, +) from synapse.push.baserules import BASE_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP @@ -32,7 +36,8 @@ class PushRuleRestServlet(RestServlet): PATTERNS = client_patterns("/(?Ppushrules/.*)$", v1=True) SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( - "Unrecognised request: You probably wanted a trailing slash") + "Unrecognised request: You probably wanted a trailing slash" + ) def __init__(self, hs): super(PushRuleRestServlet, self).__init__() @@ -54,27 +59,25 @@ def on_PUT(self, request, path): requester = yield self.auth.get_user_by_req(request) - if '/' in spec['rule_id'] or '\\' in spec['rule_id']: + if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: raise SynapseError(400, "rule_id may not contain slashes") content = parse_json_value_from_request(request) user_id = requester.user.to_string() - if 'attr' in spec: + if "attr" in spec: yield self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) defer.returnValue((200, {})) - if spec['rule_id'].startswith('.'): + if spec["rule_id"].startswith("."): # Rule ids starting with '.' are reserved for server default rules. raise SynapseError(400, "cannot add new rule_ids that start with '.'") try: (conditions, actions) = _rule_tuple_from_request_object( - spec['template'], - spec['rule_id'], - content, + spec["template"], spec["rule_id"], content ) except InvalidRuleException as e: raise SynapseError(400, str(e)) @@ -95,7 +98,7 @@ def on_PUT(self, request, path): conditions=conditions, actions=actions, before=before, - after=after + after=after, ) self.notify_user(user_id) except InconsistentRuleException as e: @@ -118,9 +121,7 @@ def on_DELETE(self, request, path): namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: - yield self.store.delete_push_rule( - user_id, namespaced_rule_id - ) + yield self.store.delete_push_rule(user_id, namespaced_rule_id) self.notify_user(user_id) defer.returnValue((200, {})) except StoreError as e: @@ -149,10 +150,10 @@ def on_GET(self, request, path): PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == "": defer.returnValue((200, rules)) - elif path[0] == 'global': - result = _filter_ruleset_with_path(rules['global'], path[1:]) + elif path[0] == "global": + result = _filter_ruleset_with_path(rules["global"], path[1:]) defer.returnValue((200, result)) else: raise UnrecognizedRequestError() @@ -162,12 +163,10 @@ def on_OPTIONS(self, request, path): def notify_user(self, user_id): stream_id, _ = self.store.get_push_rules_stream_token() - self.notifier.on_new_event( - "push_rules_key", stream_id, users=[user_id] - ) + self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) def set_rule_attr(self, user_id, spec, val): - if spec['attr'] == 'enabled': + if spec["attr"] == "enabled": if isinstance(val, dict) and "enabled" in val: val = val["enabled"] if not isinstance(val, bool): @@ -176,14 +175,12 @@ def set_rule_attr(self, user_id, spec, val): # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - return self.store.set_push_rule_enabled( - user_id, namespaced_rule_id, val - ) - elif spec['attr'] == 'actions': - actions = val.get('actions') + return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val) + elif spec["attr"] == "actions": + actions = val.get("actions") _check_actions(actions) namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec['rule_id'] + rule_id = spec["rule_id"] is_default_rule = rule_id.startswith(".") if is_default_rule: if namespaced_rule_id not in BASE_RULE_IDS: @@ -210,12 +207,12 @@ def _rule_spec_from_path(path): """ if len(path) < 2: raise UnrecognizedRequestError() - if path[0] != 'pushrules': + if path[0] != "pushrules": raise UnrecognizedRequestError() scope = path[1] path = path[2:] - if scope != 'global': + if scope != "global": raise UnrecognizedRequestError() if len(path) == 0: @@ -229,56 +226,40 @@ def _rule_spec_from_path(path): rule_id = path[0] - spec = { - 'scope': scope, - 'template': template, - 'rule_id': rule_id - } + spec = {"scope": scope, "template": template, "rule_id": rule_id} path = path[1:] if len(path) > 0 and len(path[0]) > 0: - spec['attr'] = path[0] + spec["attr"] = path[0] return spec def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): - if rule_template in ['override', 'underride']: - if 'conditions' not in req_obj: + if rule_template in ["override", "underride"]: + if "conditions" not in req_obj: raise InvalidRuleException("Missing 'conditions'") - conditions = req_obj['conditions'] + conditions = req_obj["conditions"] for c in conditions: - if 'kind' not in c: + if "kind" not in c: raise InvalidRuleException("Condition without 'kind'") - elif rule_template == 'room': - conditions = [{ - 'kind': 'event_match', - 'key': 'room_id', - 'pattern': rule_id - }] - elif rule_template == 'sender': - conditions = [{ - 'kind': 'event_match', - 'key': 'user_id', - 'pattern': rule_id - }] - elif rule_template == 'content': - if 'pattern' not in req_obj: + elif rule_template == "room": + conditions = [{"kind": "event_match", "key": "room_id", "pattern": rule_id}] + elif rule_template == "sender": + conditions = [{"kind": "event_match", "key": "user_id", "pattern": rule_id}] + elif rule_template == "content": + if "pattern" not in req_obj: raise InvalidRuleException("Content rule missing 'pattern'") - pat = req_obj['pattern'] + pat = req_obj["pattern"] - conditions = [{ - 'kind': 'event_match', - 'key': 'content.body', - 'pattern': pat - }] + conditions = [{"kind": "event_match", "key": "content.body", "pattern": pat}] else: raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) - if 'actions' not in req_obj: + if "actions" not in req_obj: raise InvalidRuleException("No actions found") - actions = req_obj['actions'] + actions = req_obj["actions"] _check_actions(actions) @@ -290,9 +271,9 @@ def _check_actions(actions): raise InvalidRuleException("No actions found") for a in actions: - if a in ['notify', 'dont_notify', 'coalesce']: + if a in ["notify", "dont_notify", "coalesce"]: pass - elif isinstance(a, dict) and 'set_tweak' in a: + elif isinstance(a, dict) and "set_tweak" in a: pass else: raise InvalidRuleException("Unrecognised action") @@ -304,7 +285,7 @@ def _filter_ruleset_with_path(ruleset, path): PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == "": return ruleset template_kind = path[0] if template_kind not in ruleset: @@ -314,13 +295,13 @@ def _filter_ruleset_with_path(ruleset, path): raise UnrecognizedRequestError( PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == "": return ruleset[template_kind] rule_id = path[0] the_rule = None for r in ruleset[template_kind]: - if r['rule_id'] == rule_id: + if r["rule_id"] == rule_id: the_rule = r if the_rule is None: raise NotFoundError @@ -339,19 +320,19 @@ def _filter_ruleset_with_path(ruleset, path): def _priority_class_from_spec(spec): - if spec['template'] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec['template'])) - pc = PRIORITY_CLASS_MAP[spec['template']] + if spec["template"] not in PRIORITY_CLASS_MAP.keys(): + raise InvalidRuleException("Unknown template: %s" % (spec["template"])) + pc = PRIORITY_CLASS_MAP[spec["template"]] return pc def _namespaced_rule_id_from_spec(spec): - return _namespaced_rule_id(spec, spec['rule_id']) + return _namespaced_rule_id(spec, spec["rule_id"]) def _namespaced_rule_id(spec, rule_id): - return "global/%s/%s" % (spec['template'], rule_id) + return "global/%s/%s" % (spec["template"], rule_id) class InvalidRuleException(Exception): diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 15d860db3762..e9246018df71 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -44,9 +44,7 @@ def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) user = requester.user - pushers = yield self.hs.get_datastore().get_pushers_by_user_id( - user.to_string() - ) + pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) allowed_keys = [ "app_display_name", @@ -87,50 +85,61 @@ def on_POST(self, request): content = parse_json_object_from_request(request) - if ('pushkey' in content and 'app_id' in content - and 'kind' in content and - content['kind'] is None): + if ( + "pushkey" in content + and "app_id" in content + and "kind" in content + and content["kind"] is None + ): yield self.pusher_pool.remove_pusher( - content['app_id'], content['pushkey'], user_id=user.to_string() + content["app_id"], content["pushkey"], user_id=user.to_string() ) defer.returnValue((200, {})) assert_params_in_dict( content, - ['kind', 'app_id', 'app_display_name', - 'device_display_name', 'pushkey', 'lang', 'data'] + [ + "kind", + "app_id", + "app_display_name", + "device_display_name", + "pushkey", + "lang", + "data", + ], ) - logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) + logger.debug("set pushkey %s to kind %s", content["pushkey"], content["kind"]) logger.debug("Got pushers request with body: %r", content) append = False - if 'append' in content: - append = content['append'] + if "append" in content: + append = content["append"] if not append: yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( - app_id=content['app_id'], - pushkey=content['pushkey'], - not_user_id=user.to_string() + app_id=content["app_id"], + pushkey=content["pushkey"], + not_user_id=user.to_string(), ) try: yield self.pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, - kind=content['kind'], - app_id=content['app_id'], - app_display_name=content['app_display_name'], - device_display_name=content['device_display_name'], - pushkey=content['pushkey'], - lang=content['lang'], - data=content['data'], - profile_tag=content.get('profile_tag', ""), + kind=content["kind"], + app_id=content["app_id"], + app_display_name=content["app_display_name"], + device_display_name=content["device_display_name"], + pushkey=content["pushkey"], + lang=content["lang"], + data=content["data"], + profile_tag=content.get("profile_tag", ""), ) except PusherConfigException as pce: - raise SynapseError(400, "Config Error: " + str(pce), - errcode=Codes.MISSING_PARAM) + raise SynapseError( + 400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM + ) self.notifier.on_new_replication_data() @@ -144,6 +153,7 @@ class PushersRemoveRestServlet(RestServlet): """ To allow pusher to be delete by clicking a link (ie. GET request) """ + PATTERNS = client_patterns("/pushers/remove$", v1=True) SUCCESS_HTML = b"You have been unsubscribed" @@ -164,9 +174,7 @@ def on_GET(self, request): try: yield self.pusher_pool.remove_pusher( - app_id=app_id, - pushkey=pushkey, - user_id=user.to_string(), + app_id=app_id, pushkey=pushkey, user_id=user.to_string() ) except StoreError as se: if se.code != 404: @@ -177,9 +185,9 @@ def on_GET(self, request): request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % ( - len(PushersRemoveRestServlet.SUCCESS_HTML), - )) + request.setHeader( + b"Content-Length", b"%d" % (len(PushersRemoveRestServlet.SUCCESS_HTML),) + ) request.write(PushersRemoveRestServlet.SUCCESS_HTML) finish_request(request) defer.returnValue(None) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index e8f672c4ba42..a028337125fb 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -61,18 +61,16 @@ def register(self, http_server): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) # define CORS for all of /rooms in RoomCreateRestServlet for simplicity - http_server.register_paths("OPTIONS", - client_patterns("/rooms(?:/.*)?$", v1=True), - self.on_OPTIONS) + http_server.register_paths( + "OPTIONS", client_patterns("/rooms(?:/.*)?$", v1=True), self.on_OPTIONS + ) # define CORS for /createRoom[/txnid] - http_server.register_paths("OPTIONS", - client_patterns("/createRoom(?:/.*)?$", v1=True), - self.on_OPTIONS) + http_server.register_paths( + "OPTIONS", client_patterns("/createRoom(?:/.*)?$", v1=True), self.on_OPTIONS + ) def on_PUT(self, request, txn_id): - return self.txns.fetch_or_execute_request( - request, self.on_POST, request - ) + return self.txns.fetch_or_execute_request(request, self.on_POST, request) @defer.inlineCallbacks def on_POST(self, request): @@ -107,21 +105,23 @@ def register(self, http_server): no_state_key = "/rooms/(?P[^/]*)/state/(?P[^/]*)$" # /room/$roomid/state/$eventtype/$statekey - state_key = ("/rooms/(?P[^/]*)/state/" - "(?P[^/]*)/(?P[^/]*)$") - - http_server.register_paths("GET", - client_patterns(state_key, v1=True), - self.on_GET) - http_server.register_paths("PUT", - client_patterns(state_key, v1=True), - self.on_PUT) - http_server.register_paths("GET", - client_patterns(no_state_key, v1=True), - self.on_GET_no_state_key) - http_server.register_paths("PUT", - client_patterns(no_state_key, v1=True), - self.on_PUT_no_state_key) + state_key = ( + "/rooms/(?P[^/]*)/state/" + "(?P[^/]*)/(?P[^/]*)$" + ) + + http_server.register_paths( + "GET", client_patterns(state_key, v1=True), self.on_GET + ) + http_server.register_paths( + "PUT", client_patterns(state_key, v1=True), self.on_PUT + ) + http_server.register_paths( + "GET", client_patterns(no_state_key, v1=True), self.on_GET_no_state_key + ) + http_server.register_paths( + "PUT", client_patterns(no_state_key, v1=True), self.on_PUT_no_state_key + ) def on_GET_no_state_key(self, request, room_id, event_type): return self.on_GET(request, room_id, event_type, "") @@ -132,8 +132,9 @@ def on_PUT_no_state_key(self, request, room_id, event_type): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - format = parse_string(request, "format", default="content", - allowed_values=["content", "event"]) + format = parse_string( + request, "format", default="content", allowed_values=["content", "event"] + ) msg_handler = self.message_handler data = yield msg_handler.get_room_data( @@ -145,9 +146,7 @@ def on_GET(self, request, room_id, event_type, state_key): ) if not data: - raise SynapseError( - 404, "Event not found.", errcode=Codes.NOT_FOUND - ) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) if format == "event": event = format_event_for_client_v2(data.get_dict()) @@ -182,9 +181,7 @@ def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): ) else: event = yield self.event_creation_handler.create_and_send_nonmember_event( - requester, - event_dict, - txn_id=txn_id, + requester, event_dict, txn_id=txn_id ) ret = {} @@ -195,7 +192,6 @@ def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(TransactionRestServlet): - def __init__(self, hs): super(RoomSendEventRestServlet, self).__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() @@ -203,7 +199,7 @@ def __init__(self, hs): def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] - PATTERNS = ("/rooms/(?P[^/]*)/send/(?P[^/]*)") + PATTERNS = "/rooms/(?P[^/]*)/send/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) @defer.inlineCallbacks @@ -218,13 +214,11 @@ def on_POST(self, request, room_id, event_type, txn_id=None): "sender": requester.user.to_string(), } - if b'ts' in request.args and requester.app_service: - event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) + if b"ts" in request.args and requester.app_service: + event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) event = yield self.event_creation_handler.create_and_send_nonmember_event( - requester, - event_dict, - txn_id=txn_id, + requester, event_dict, txn_id=txn_id ) defer.returnValue((200, {"event_id": event.event_id})) @@ -247,15 +241,12 @@ def __init__(self, hs): def register(self, http_server): # /join/$room_identifier[/$txn_id] - PATTERNS = ("/join/(?P[^/]*)") + PATTERNS = "/join/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=True, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) try: content = parse_json_object_from_request(request) @@ -268,7 +259,7 @@ def on_POST(self, request, room_identifier, txn_id=None): room_id = room_identifier try: remote_room_hosts = [ - x.decode('ascii') for x in request.args[b"server_name"] + x.decode("ascii") for x in request.args[b"server_name"] ] except Exception: remote_room_hosts = None @@ -278,9 +269,9 @@ def on_POST(self, request, room_identifier, txn_id=None): room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) room_id = room_id.to_string() else: - raise SynapseError(400, "%s was not legal room ID or room alias" % ( - room_identifier, - )) + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) yield self.room_member_handler.update_membership( requester=requester, @@ -339,14 +330,11 @@ def on_GET(self, request): handler = self.hs.get_room_list_handler() if server: data = yield handler.get_remote_public_room_list( - server, - limit=limit, - since_token=since_token, + server, limit=limit, since_token=since_token ) else: data = yield handler.get_local_public_room_list( - limit=limit, - since_token=since_token, + limit=limit, since_token=since_token ) defer.returnValue((200, data)) @@ -439,16 +427,13 @@ def on_GET(self, request, room_id): chunk = [] for event in events: - if ( - (membership and event['content'].get("membership") != membership) or - (not_membership and event['content'].get("membership") == not_membership) + if (membership and event["content"].get("membership") != membership) or ( + not_membership and event["content"].get("membership") == not_membership ): continue chunk.append(event) - defer.returnValue((200, { - "chunk": chunk - })) + defer.returnValue((200, {"chunk": chunk})) # deprecated in favour of /members?membership=join? @@ -466,12 +451,10 @@ def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request) users_with_profile = yield self.message_handler.get_joined_members( - requester, room_id, + requester, room_id ) - defer.returnValue((200, { - "joined": users_with_profile, - })) + defer.returnValue((200, {"joined": users_with_profile})) # TODO: Needs better unit testing @@ -486,9 +469,7 @@ def __init__(self, hs): @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request( - request, default_limit=10, - ) + pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args filter_bytes = parse_string(request, b"filter", encoding=None) if filter_bytes: @@ -544,9 +525,7 @@ def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request) content = yield self.initial_sync_handler.room_initial_sync( - room_id=room_id, - requester=requester, - pagin_config=pagination_config, + room_id=room_id, requester=requester, pagin_config=pagination_config ) defer.returnValue((200, content)) @@ -603,30 +582,24 @@ def on_GET(self, request, room_id, event_id): event_filter = None results = yield self.room_context_handler.get_event_context( - requester.user, - room_id, - event_id, - limit, - event_filter, + requester.user, room_id, event_id, limit, event_filter ) if not results: - raise SynapseError( - 404, "Event not found.", errcode=Codes.NOT_FOUND - ) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() results["events_before"] = yield self._event_serializer.serialize_events( - results["events_before"], time_now, + results["events_before"], time_now ) results["event"] = yield self._event_serializer.serialize_event( - results["event"], time_now, + results["event"], time_now ) results["events_after"] = yield self._event_serializer.serialize_events( - results["events_after"], time_now, + results["events_after"], time_now ) results["state"] = yield self._event_serializer.serialize_events( - results["state"], time_now, + results["state"], time_now ) defer.returnValue((200, results)) @@ -639,20 +612,14 @@ def __init__(self, hs): self.auth = hs.get_auth() def register(self, http_server): - PATTERNS = ("/rooms/(?P[^/]*)/forget") + PATTERNS = "/rooms/(?P[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, txn_id=None): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=False, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=False) - yield self.room_member_handler.forget( - user=requester.user, - room_id=room_id, - ) + yield self.room_member_handler.forget(user=requester.user, room_id=room_id) defer.returnValue((200, {})) @@ -664,7 +631,6 @@ def on_PUT(self, request, room_id, txn_id): # TODO: Needs unit testing class RoomMembershipRestServlet(TransactionRestServlet): - def __init__(self, hs): super(RoomMembershipRestServlet, self).__init__(hs) self.room_member_handler = hs.get_room_member_handler() @@ -672,20 +638,19 @@ def __init__(self, hs): def register(self, http_server): # /rooms/$roomid/[invite|join|leave] - PATTERNS = ("/rooms/(?P[^/]*)/" - "(?Pjoin|invite|leave|ban|unban|kick)") + PATTERNS = ( + "/rooms/(?P[^/]*)/" + "(?Pjoin|invite|leave|ban|unban|kick)" + ) register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=True, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { Membership.JOIN, - Membership.LEAVE + Membership.LEAVE, }: raise AuthError(403, "Guest access not allowed") @@ -704,7 +669,7 @@ def on_POST(self, request, room_id, membership_action, txn_id=None): content["address"], content["id_server"], requester, - txn_id + txn_id, ) defer.returnValue((200, {})) return @@ -715,8 +680,8 @@ def on_POST(self, request, room_id, membership_action, txn_id=None): target = UserID.from_string(content["user_id"]) event_content = None - if 'reason' in content and membership_action in ['kick', 'ban']: - event_content = {'reason': content['reason']} + if "reason" in content and membership_action in ["kick", "ban"]: + event_content = {"reason": content["reason"]} yield self.room_member_handler.update_membership( requester=requester, @@ -755,7 +720,7 @@ def __init__(self, hs): self.auth = hs.get_auth() def register(self, http_server): - PATTERNS = ("/rooms/(?P[^/]*)/redact/(?P[^/]*)") + PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks @@ -817,9 +782,7 @@ def on_PUT(self, request, room_id, user_id): ) else: yield self.typing_handler.stopped_typing( - target_user=target_user, - auth_user=requester.user, - room_id=room_id, + target_user=target_user, auth_user=requester.user, room_id=room_id ) defer.returnValue((200, {})) @@ -841,9 +804,7 @@ def on_POST(self, request): batch = parse_string(request, "next_batch") results = yield self.handlers.search_handler.search( - requester.user, - content, - batch, + requester.user, content, batch ) defer.returnValue((200, results)) @@ -879,20 +840,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): with_get: True to also register respective GET paths for the PUTs. """ http_server.register_paths( - "POST", - client_patterns(regex_string + "$", v1=True), - servlet.on_POST + "POST", client_patterns(regex_string + "$", v1=True), servlet.on_POST ) http_server.register_paths( "PUT", client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_PUT + servlet.on_PUT, ) if with_get: http_server.register_paths( "GET", client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_GET + servlet.on_GET, ) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 638104921001..41b3171ac828 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -34,8 +34,7 @@ def __init__(self, hs): @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req( - request, - self.hs.config.turn_allow_guests + request, self.hs.config.turn_allow_guests ) turnUris = self.hs.config.turn_uris @@ -49,9 +48,7 @@ def on_GET(self, request): username = "%d:%s" % (expiry, requester.user.to_string()) mac = hmac.new( - turnSecret.encode(), - msg=username.encode(), - digestmod=hashlib.sha1 + turnSecret.encode(), msg=username.encode(), digestmod=hashlib.sha1 ) # We need to use standard padded base64 encoding here # encode_base64 because we need to add the standard padding to get the @@ -65,12 +62,17 @@ def on_GET(self, request): else: defer.returnValue((200, {})) - defer.returnValue((200, { - 'username': username, - 'password': password, - 'ttl': userLifetime / 1000, - 'uris': turnUris, - })) + defer.returnValue( + ( + 200, + { + "username": username, + "password": password, + "ttl": userLifetime / 1000, + "uris": turnUris, + }, + ) + ) def on_OPTIONS(self, request): return (200, {}) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 5236d5d566e7..e3d59ac3ac5e 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -52,11 +52,11 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): def set_timeline_upper_limit(filter_json, filter_timeline_limit): if filter_timeline_limit < 0: return # no upper limits - timeline = filter_json.get('room', {}).get('timeline', {}) - if 'limit' in timeline: - filter_json['room']['timeline']["limit"] = min( - filter_json['room']['timeline']['limit'], - filter_timeline_limit) + timeline = filter_json.get("room", {}).get("timeline", {}) + if "limit" in timeline: + filter_json["room"]["timeline"]["limit"] = min( + filter_json["room"]["timeline"]["limit"], filter_timeline_limit + ) def interactive_auth_handler(orig): @@ -74,10 +74,12 @@ def on_POST(self, request): # ... yield self.auth_handler.check_auth """ + def wrapped(*args, **kwargs): res = defer.maybeDeferred(orig, *args, **kwargs) res.addErrback(_catch_incomplete_interactive_auth) return res + return wrapped diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index ab75f6c2b218..f143d8b85cfb 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -52,6 +52,7 @@ def __init__(self, hs): if self.config.email_password_reset_behaviour == "local": from synapse.push.mailer import Mailer, load_jinja2_templates + templates = load_jinja2_templates( config=hs.config, template_html_name=hs.config.email_password_reset_template_html, @@ -72,14 +73,12 @@ def on_POST(self, request): "User password resets have been disabled due to lack of email config" ) raise SynapseError( - 400, "Email-based password resets have been disabled on this server", + 400, "Email-based password resets have been disabled on this server" ) body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'client_secret', 'email', 'send_attempt' - ]) + assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) # Extract params from body client_secret = body["client_secret"] @@ -95,24 +94,24 @@ def on_POST(self, request): ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', email, + "email", email ) if existingUid is None: raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.email_password_reset_behaviour == "remote": - if 'id_server' not in body: + if "id_server" not in body: raise SynapseError(400, "Missing 'id_server' param in body") # Have the identity server handle the password reset flow ret = yield self.identity_handler.requestEmailToken( - body["id_server"], email, client_secret, send_attempt, next_link, + body["id_server"], email, client_secret, send_attempt, next_link ) else: # Send password reset emails from Synapse sid = yield self.send_password_reset( - email, client_secret, send_attempt, next_link, + email, client_secret, send_attempt, next_link ) # Wrap the session id in a JSON object @@ -121,13 +120,7 @@ def on_POST(self, request): defer.returnValue((200, ret)) @defer.inlineCallbacks - def send_password_reset( - self, - email, - client_secret, - send_attempt, - next_link=None, - ): + def send_password_reset(self, email, client_secret, send_attempt, next_link=None): """Send a password reset email Args: @@ -144,14 +137,14 @@ def send_password_reset( # Check that this email/client_secret/send_attempt combo is new or # greater than what we've seen previously session = yield self.datastore.get_threepid_validation_session( - "email", client_secret, address=email, validated=False, + "email", client_secret, address=email, validated=False ) # Check to see if a session already exists and that it is not yet # marked as validated if session and session.get("validated_at") is None: - session_id = session['session_id'] - last_send_attempt = session['last_send_attempt'] + session_id = session["session_id"] + last_send_attempt = session["last_send_attempt"] # Check that the send_attempt is higher than previous attempts if send_attempt <= last_send_attempt: @@ -169,22 +162,27 @@ def send_password_reset( # and session_id try: yield self.mailer.send_password_reset_mail( - email, token, client_secret, session_id, + email, token, client_secret, session_id ) except Exception: - logger.exception( - "Error sending a password reset email to %s", email, - ) + logger.exception("Error sending a password reset email to %s", email) raise SynapseError( 500, "An error was encountered when sending the password reset email" ) - token_expires = (self.hs.clock.time_msec() + - self.config.email_validation_token_lifetime) + token_expires = ( + self.hs.clock.time_msec() + self.config.email_validation_token_lifetime + ) yield self.datastore.start_or_continue_validation_session( - "email", email, session_id, client_secret, send_attempt, - next_link, token, token_expires, + "email", + email, + session_id, + client_secret, + send_attempt, + next_link, + token, + token_expires, ) defer.returnValue(session_id) @@ -203,12 +201,12 @@ def __init__(self, hs): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', - 'country', 'phone_number', 'send_attempt', - ]) + assert_params_in_dict( + body, + ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + ) - msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + msisdn = phone_number_to_msisdn(body["country"], body["phone_number"]) if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( @@ -217,9 +215,7 @@ def on_POST(self, request): Codes.THREEPID_DENIED, ) - existingUid = yield self.datastore.get_user_id_by_threepid( - 'msisdn', msisdn - ) + existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn) if existingUid is None: raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND) @@ -230,10 +226,9 @@ def on_POST(self, request): class PasswordResetSubmitTokenServlet(RestServlet): """Handles 3PID validation token submission""" + PATTERNS = client_patterns( - "/password_reset/(?P[^/]*)/submit_token/*$", - releases=(), - unstable=True, + "/password_reset/(?P[^/]*)/submit_token/*$", releases=(), unstable=True ) def __init__(self, hs): @@ -252,8 +247,7 @@ def __init__(self, hs): def on_GET(self, request, medium): if medium != "email": raise SynapseError( - 400, - "This medium is currently not supported for password resets", + 400, "This medium is currently not supported for password resets" ) if self.config.email_password_reset_behaviour == "off": if self.config.password_resets_were_disabled_due_to_email_config: @@ -261,7 +255,7 @@ def on_GET(self, request, medium): "User password resets have been disabled due to lack of email config" ) raise SynapseError( - 400, "Email-based password resets have been disabled on this server", + 400, "Email-based password resets have been disabled on this server" ) sid = parse_string(request, "sid") @@ -272,10 +266,7 @@ def on_GET(self, request, medium): try: # Mark the session as valid next_link = yield self.datastore.validate_threepid_session( - sid, - client_secret, - token, - self.clock.time_msec(), + sid, client_secret, token, self.clock.time_msec() ) # Perform a 302 redirect if next_link is set @@ -298,13 +289,11 @@ def on_GET(self, request, medium): html = self.load_jinja2_template( self.config.email_template_dir, self.config.email_password_reset_failure_template, - template_vars={ - "failure_reason": e.msg, - } + template_vars={"failure_reason": e.msg}, ) request.setResponseCode(e.code) - request.write(html.encode('utf-8')) + request.write(html.encode("utf-8")) finish_request(request) defer.returnValue(None) @@ -330,20 +319,14 @@ def load_jinja2_template(self, template_dir, template_filename, template_vars): def on_POST(self, request, medium): if medium != "email": raise SynapseError( - 400, - "This medium is currently not supported for password resets", + 400, "This medium is currently not supported for password resets" ) body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'sid', 'client_secret', 'token', - ]) + assert_params_in_dict(body, ["sid", "client_secret", "token"]) valid, _ = yield self.datastore.validate_threepid_validation_token( - body['sid'], - body['client_secret'], - body['token'], - self.clock.time_msec(), + body["sid"], body["client_secret"], body["token"], self.clock.time_msec() ) response_code = 200 if valid else 400 @@ -379,29 +362,30 @@ def on_POST(self, request): if self.auth.has_access_token(request): requester = yield self.auth.get_user_by_req(request) params = yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) user_id = requester.user.to_string() else: requester = None result, params, _ = yield self.auth_handler.check_auth( [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]], - body, self.hs.get_ip_from_request(request), + body, + self.hs.get_ip_from_request(request), password_servlet=True, ) if LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] - if 'medium' not in threepid or 'address' not in threepid: + if "medium" not in threepid or "address" not in threepid: raise SynapseError(500, "Malformed threepid") - if threepid['medium'] == 'email': + if threepid["medium"] == "email": # For emails, transform the address to lowercase. # We store all email addreses as lowercase in the DB. # (See add_threepid in synapse/handlers/auth.py) - threepid['address'] = threepid['address'].lower() + threepid["address"] = threepid["address"].lower() # if using email, we must know about the email they're authing with! threepid_user_id = yield self.datastore.get_user_id_by_threepid( - threepid['medium'], threepid['address'] + threepid["medium"], threepid["address"] ) if not threepid_user_id: raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) @@ -411,11 +395,9 @@ def on_POST(self, request): raise SynapseError(500, "", Codes.UNKNOWN) assert_params_in_dict(params, ["new_password"]) - new_password = params['new_password'] + new_password = params["new_password"] - yield self._set_password_handler.set_password( - user_id, new_password, requester - ) + yield self._set_password_handler.set_password(user_id, new_password, requester) defer.returnValue((200, {})) @@ -450,25 +432,22 @@ def on_POST(self, request): # allow ASes to dectivate their own users if requester.app_service: yield self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, + requester.user.to_string(), erase ) defer.returnValue((200, {})) yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) result = yield self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, - id_server=body.get("id_server"), + requester.user.to_string(), erase, id_server=body.get("id_server") ) if result: id_server_unbind_result = "success" else: id_server_unbind_result = "no-support" - defer.returnValue((200, { - "id_server_unbind_result": id_server_unbind_result, - })) + defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) class EmailThreepidRequestTokenRestServlet(RestServlet): @@ -484,11 +463,10 @@ def __init__(self, hs): def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( - body, - ['id_server', 'client_secret', 'email', 'send_attempt'], + body, ["id_server", "client_secret", "email", "send_attempt"] ) - if not check_3pid_allowed(self.hs, "email", body['email']): + if not check_3pid_allowed(self.hs, "email", body["email"]): raise SynapseError( 403, "Your email domain is not authorized on this server", @@ -496,7 +474,7 @@ def on_POST(self, request): ) existingUid = yield self.datastore.get_user_id_by_threepid( - 'email', body['email'] + "email", body["email"] ) if existingUid is not None: @@ -518,12 +496,12 @@ def __init__(self, hs): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', - 'country', 'phone_number', 'send_attempt', - ]) + assert_params_in_dict( + body, + ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + ) - msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + msisdn = phone_number_to_msisdn(body["country"], body["phone_number"]) if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( @@ -532,9 +510,7 @@ def on_POST(self, request): Codes.THREEPID_DENIED, ) - existingUid = yield self.datastore.get_user_id_by_threepid( - 'msisdn', msisdn - ) + existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn) if existingUid is not None: raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) @@ -558,18 +534,16 @@ def __init__(self, hs): def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - threepids = yield self.datastore.user_get_threepids( - requester.user.to_string() - ) + threepids = yield self.datastore.user_get_threepids(requester.user.to_string()) - defer.returnValue((200, {'threepids': threepids})) + defer.returnValue((200, {"threepids": threepids})) @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - threePidCreds = body.get('threePidCreds') - threePidCreds = body.get('three_pid_creds', threePidCreds) + threePidCreds = body.get("threePidCreds") + threePidCreds = body.get("three_pid_creds", threePidCreds) if threePidCreds is None: raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) @@ -579,30 +553,20 @@ def on_POST(self, request): threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) if not threepid: - raise SynapseError( - 400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED - ) + raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED) - for reqd in ['medium', 'address', 'validated_at']: + for reqd in ["medium", "address", "validated_at"]: if reqd not in threepid: logger.warn("Couldn't add 3pid: invalid response from ID server") raise SynapseError(500, "Invalid response from ID Server") yield self.auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], + user_id, threepid["medium"], threepid["address"], threepid["validated_at"] ) - if 'bind' in body and body['bind']: - logger.debug( - "Binding threepid %s to %s", - threepid, user_id - ) - yield self.identity_handler.bind_threepid( - threePidCreds, user_id - ) + if "bind" in body and body["bind"]: + logger.debug("Binding threepid %s to %s", threepid, user_id) + yield self.identity_handler.bind_threepid(threePidCreds, user_id) defer.returnValue((200, {})) @@ -618,14 +582,14 @@ def __init__(self, hs): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, ['medium', 'address']) + assert_params_in_dict(body, ["medium", "address"]) requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() try: ret = yield self.auth_handler.delete_threepid( - user_id, body['medium'], body['address'], body.get("id_server"), + user_id, body["medium"], body["address"], body.get("id_server") ) except Exception: # NB. This endpoint should succeed if there is nothing to @@ -639,9 +603,7 @@ def on_POST(self, request): else: id_server_unbind_result = "no-support" - defer.returnValue((200, { - "id_server_unbind_result": id_server_unbind_result, - })) + defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) class WhoamiRestServlet(RestServlet): @@ -655,7 +617,7 @@ def __init__(self, hs): def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - defer.returnValue((200, {'user_id': requester.user.to_string()})) + defer.returnValue((200, {"user_id": requester.user.to_string()})) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 574a6298cea2..f155c26259d8 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -30,6 +30,7 @@ class AccountDataServlet(RestServlet): PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1 """ + PATTERNS = client_patterns( "/user/(?P[^/]*)/account_data/(?P[^/]*)" ) @@ -52,9 +53,7 @@ def on_PUT(self, request, user_id, account_data_type): user_id, account_data_type, body ) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) @@ -65,7 +64,7 @@ def on_GET(self, request, user_id, account_data_type): raise AuthError(403, "Cannot get account data for other users.") event = yield self.store.get_global_account_data_by_type_for_user( - account_data_type, user_id, + account_data_type, user_id ) if event is None: @@ -79,6 +78,7 @@ class RoomAccountDataServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 """ + PATTERNS = client_patterns( "/user/(?P[^/]*)" "/rooms/(?P[^/]*)" @@ -103,16 +103,14 @@ def on_PUT(self, request, user_id, room_id, account_data_type): raise SynapseError( 405, "Cannot set m.fully_read through this API." - " Use /rooms/!roomId:server.name/read_markers" + " Use /rooms/!roomId:server.name/read_markers", ) max_id = yield self.store.add_account_data_to_room( user_id, room_id, account_data_type, body ) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) @@ -123,7 +121,7 @@ def on_GET(self, request, user_id, room_id, account_data_type): raise AuthError(403, "Cannot get account data for other users.") event = yield self.store.get_account_data_for_room_and_type( - user_id, room_id, account_data_type, + user_id, room_id, account_data_type ) if event is None: diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 63bdc3356434..d29c10b83d32 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -28,7 +28,9 @@ class AccountValidityRenewServlet(RestServlet): PATTERNS = client_patterns("/account_validity/renew$") - SUCCESS_HTML = b"Your account has been successfully renewed." + SUCCESS_HTML = ( + b"Your account has been successfully renewed." + ) def __init__(self, hs): """ @@ -47,13 +49,13 @@ def on_GET(self, request): raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - yield self.account_activity_handler.renew_account(renewal_token.decode('utf8')) + yield self.account_activity_handler.renew_account(renewal_token.decode("utf8")) request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % ( - len(AccountValidityRenewServlet.SUCCESS_HTML), - )) + request.setHeader( + b"Content-Length", b"%d" % (len(AccountValidityRenewServlet.SUCCESS_HTML),) + ) request.write(AccountValidityRenewServlet.SUCCESS_HTML) finish_request(request) defer.returnValue(None) @@ -77,7 +79,9 @@ def __init__(self, hs): @defer.inlineCallbacks def on_POST(self, request): if not self.account_validity.renew_by_email_enabled: - raise AuthError(403, "Account renewal via email is disabled on this server.") + raise AuthError( + 403, "Account renewal via email is disabled on this server." + ) requester = yield self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 8dfe5cba0298..bebc2951e7d0 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -122,6 +122,7 @@ class AuthRestServlet(RestServlet): cannot be handled in the normal flow (with requests to the same endpoint). Current use is for web fallback auth. """ + PATTERNS = client_patterns(r"/auth/(?P[\w\.]*)/fallback/web") def __init__(self, hs): @@ -138,11 +139,10 @@ def on_GET(self, request, stagetype): if stagetype == LoginType.RECAPTCHA: html = RECAPTCHA_TEMPLATE % { - 'session': session, - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.RECAPTCHA - ), - 'sitekey': self.hs.config.recaptcha_public_key, + "session": session, + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), + "sitekey": self.hs.config.recaptcha_public_key, } html_bytes = html.encode("utf8") request.setResponseCode(200) @@ -154,14 +154,11 @@ def on_GET(self, request, stagetype): return None elif stagetype == LoginType.TERMS: html = TERMS_TEMPLATE % { - 'session': session, - 'terms_url': "%s_matrix/consent?v=%s" % ( - self.hs.config.public_baseurl, - self.hs.config.user_consent_version, - ), - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.TERMS - ), + "session": session, + "terms_url": "%s_matrix/consent?v=%s" + % (self.hs.config.public_baseurl, self.hs.config.user_consent_version), + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.TERMS), } html_bytes = html.encode("utf8") request.setResponseCode(200) @@ -187,26 +184,20 @@ def on_POST(self, request, stagetype): if not response: raise SynapseError(400, "No captcha response supplied") - authdict = { - 'response': response, - 'session': session, - } + authdict = {"response": response, "session": session} success = yield self.auth_handler.add_oob_auth( - LoginType.RECAPTCHA, - authdict, - self.hs.get_ip_from_request(request) + LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) ) if success: html = SUCCESS_TEMPLATE else: html = RECAPTCHA_TEMPLATE % { - 'session': session, - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.RECAPTCHA - ), - 'sitekey': self.hs.config.recaptcha_public_key, + "session": session, + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), + "sitekey": self.hs.config.recaptcha_public_key, } html_bytes = html.encode("utf8") request.setResponseCode(200) @@ -218,31 +209,28 @@ def on_POST(self, request, stagetype): defer.returnValue(None) elif stagetype == LoginType.TERMS: - if ('session' not in request.args or - len(request.args['session'])) == 0: + if ("session" not in request.args or len(request.args["session"])) == 0: raise SynapseError(400, "No session supplied") - session = request.args['session'][0] - authdict = {'session': session} + session = request.args["session"][0] + authdict = {"session": session} success = yield self.auth_handler.add_oob_auth( - LoginType.TERMS, - authdict, - self.hs.get_ip_from_request(request) + LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) ) if success: html = SUCCESS_TEMPLATE else: html = TERMS_TEMPLATE % { - 'session': session, - 'terms_url': "%s_matrix/consent?v=%s" % ( + "session": session, + "terms_url": "%s_matrix/consent?v=%s" + % ( self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.TERMS - ), + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.TERMS), } html_bytes = html.encode("utf8") request.setResponseCode(200) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 78665304a50d..d279229d74e7 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -56,6 +56,7 @@ class DeleteDevicesRestServlet(RestServlet): API for bulk deletion of devices. Accepts a JSON object with a devices key which lists the device_ids to delete. Requires user interactive auth. """ + PATTERNS = client_patterns("/delete_devices") def __init__(self, hs): @@ -84,12 +85,11 @@ def on_POST(self, request): assert_params_in_dict(body, ["devices"]) yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) yield self.device_handler.delete_devices( - requester.user.to_string(), - body['devices'], + requester.user.to_string(), body["devices"] ) defer.returnValue((200, {})) @@ -112,8 +112,7 @@ def __init__(self, hs): def on_GET(self, request, device_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) device = yield self.device_handler.get_device( - requester.user.to_string(), - device_id, + requester.user.to_string(), device_id ) defer.returnValue((200, device)) @@ -134,12 +133,10 @@ def on_DELETE(self, request, device_id): raise yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) - yield self.device_handler.delete_device( - requester.user.to_string(), device_id, - ) + yield self.device_handler.delete_device(requester.user.to_string(), device_id) defer.returnValue((200, {})) @defer.inlineCallbacks @@ -148,9 +145,7 @@ def on_PUT(self, request, device_id): body = parse_json_object_from_request(request) yield self.device_handler.update_device( - requester.user.to_string(), - device_id, - body + requester.user.to_string(), device_id, body ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 65db48c3cc56..3f0adf4a21cf 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -53,8 +53,7 @@ def on_GET(self, request, user_id, filter_id): try: filter = yield self.filtering.get_user_filter( - user_localpart=target_user.localpart, - filter_id=filter_id, + user_localpart=target_user.localpart, filter_id=filter_id ) defer.returnValue((200, filter.get_filter_json())) @@ -84,14 +83,10 @@ def on_POST(self, request, user_id): raise AuthError(403, "Can only create filters for local users") content = parse_json_object_from_request(request) - set_timeline_upper_limit( - content, - self.hs.config.filter_timeline_limit - ) + set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) filter_id = yield self.filtering.add_user_filter( - user_localpart=target_user.localpart, - user_filter=content, + user_localpart=target_user.localpart, user_filter=content ) defer.returnValue((200, {"filter_id": str(filter_id)})) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index d082385ec704..a312dd259389 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -29,6 +29,7 @@ class GroupServlet(RestServlet): """Get the group profile """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/profile$") def __init__(self, hs): @@ -43,8 +44,7 @@ def on_GET(self, request, group_id): requester_user_id = requester.user.to_string() group_description = yield self.groups_handler.get_group_profile( - group_id, - requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, group_description)) @@ -56,7 +56,7 @@ def on_POST(self, request, group_id): content = parse_json_object_from_request(request) yield self.groups_handler.update_group_profile( - group_id, requester_user_id, content, + group_id, requester_user_id, content ) defer.returnValue((200, {})) @@ -65,6 +65,7 @@ def on_POST(self, request, group_id): class GroupSummaryServlet(RestServlet): """Get the full group summary """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/summary$") def __init__(self, hs): @@ -79,8 +80,7 @@ def on_GET(self, request, group_id): requester_user_id = requester.user.to_string() get_group_summary = yield self.groups_handler.get_group_summary( - group_id, - requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, get_group_summary)) @@ -93,6 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): - /groups/:group/summary/rooms/:room_id - /groups/:group/summary/categories/:category/rooms/:room_id """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/summary" "(/categories/(?P[^/]+))?" @@ -112,7 +113,8 @@ def on_PUT(self, request, group_id, category_id, room_id): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_summary_room( - group_id, requester_user_id, + group_id, + requester_user_id, room_id=room_id, category_id=category_id, content=content, @@ -126,9 +128,7 @@ def on_DELETE(self, request, group_id, category_id, room_id): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_summary_room( - group_id, requester_user_id, - room_id=room_id, - category_id=category_id, + group_id, requester_user_id, room_id=room_id, category_id=category_id ) defer.returnValue((200, resp)) @@ -137,6 +137,7 @@ def on_DELETE(self, request, group_id, category_id, room_id): class GroupCategoryServlet(RestServlet): """Get/add/update/delete a group category """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/categories/(?P[^/]+)$" ) @@ -153,8 +154,7 @@ def on_GET(self, request, group_id, category_id): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_category( - group_id, requester_user_id, - category_id=category_id, + group_id, requester_user_id, category_id=category_id ) defer.returnValue((200, category)) @@ -166,9 +166,7 @@ def on_PUT(self, request, group_id, category_id): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_category( - group_id, requester_user_id, - category_id=category_id, - content=content, + group_id, requester_user_id, category_id=category_id, content=content ) defer.returnValue((200, resp)) @@ -179,8 +177,7 @@ def on_DELETE(self, request, group_id, category_id): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_category( - group_id, requester_user_id, - category_id=category_id, + group_id, requester_user_id, category_id=category_id ) defer.returnValue((200, resp)) @@ -189,9 +186,8 @@ def on_DELETE(self, request, group_id, category_id): class GroupCategoriesServlet(RestServlet): """Get all group categories """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/categories/$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/categories/$") def __init__(self, hs): super(GroupCategoriesServlet, self).__init__() @@ -205,7 +201,7 @@ def on_GET(self, request, group_id): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_categories( - group_id, requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, category)) @@ -214,9 +210,8 @@ def on_GET(self, request, group_id): class GroupRoleServlet(RestServlet): """Get/add/update/delete a group role """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/roles/(?P[^/]+)$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/(?P[^/]+)$") def __init__(self, hs): super(GroupRoleServlet, self).__init__() @@ -230,8 +225,7 @@ def on_GET(self, request, group_id, role_id): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_role( - group_id, requester_user_id, - role_id=role_id, + group_id, requester_user_id, role_id=role_id ) defer.returnValue((200, category)) @@ -243,9 +237,7 @@ def on_PUT(self, request, group_id, role_id): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_role( - group_id, requester_user_id, - role_id=role_id, - content=content, + group_id, requester_user_id, role_id=role_id, content=content ) defer.returnValue((200, resp)) @@ -256,8 +248,7 @@ def on_DELETE(self, request, group_id, role_id): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_role( - group_id, requester_user_id, - role_id=role_id, + group_id, requester_user_id, role_id=role_id ) defer.returnValue((200, resp)) @@ -266,9 +257,8 @@ def on_DELETE(self, request, group_id, role_id): class GroupRolesServlet(RestServlet): """Get all group roles """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/roles/$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/$") def __init__(self, hs): super(GroupRolesServlet, self).__init__() @@ -282,7 +272,7 @@ def on_GET(self, request, group_id): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_roles( - group_id, requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, category)) @@ -295,6 +285,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): - /groups/:group/summary/users/:room_id - /groups/:group/summary/roles/:role/users/:user_id """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/summary" "(/roles/(?P[^/]+))?" @@ -314,7 +305,8 @@ def on_PUT(self, request, group_id, role_id, user_id): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_summary_user( - group_id, requester_user_id, + group_id, + requester_user_id, user_id=user_id, role_id=role_id, content=content, @@ -328,9 +320,7 @@ def on_DELETE(self, request, group_id, role_id, user_id): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_summary_user( - group_id, requester_user_id, - user_id=user_id, - role_id=role_id, + group_id, requester_user_id, user_id=user_id, role_id=role_id ) defer.returnValue((200, resp)) @@ -339,6 +329,7 @@ def on_DELETE(self, request, group_id, role_id, user_id): class GroupRoomServlet(RestServlet): """Get all rooms in a group """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/rooms$") def __init__(self, hs): @@ -352,7 +343,9 @@ def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id) + result = yield self.groups_handler.get_rooms_in_group( + group_id, requester_user_id + ) defer.returnValue((200, result)) @@ -360,6 +353,7 @@ def on_GET(self, request, group_id): class GroupUsersServlet(RestServlet): """Get all users in a group """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/users$") def __init__(self, hs): @@ -373,7 +367,9 @@ def on_GET(self, request, group_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id) + result = yield self.groups_handler.get_users_in_group( + group_id, requester_user_id + ) defer.returnValue((200, result)) @@ -381,6 +377,7 @@ def on_GET(self, request, group_id): class GroupInvitedUsersServlet(RestServlet): """Get users invited to a group """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/invited_users$") def __init__(self, hs): @@ -395,8 +392,7 @@ def on_GET(self, request, group_id): requester_user_id = requester.user.to_string() result = yield self.groups_handler.get_invited_users_in_group( - group_id, - requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, result)) @@ -405,6 +401,7 @@ def on_GET(self, request, group_id): class GroupSettingJoinPolicyServlet(RestServlet): """Set group join policy """ + PATTERNS = client_patterns("/groups/(?P[^/]*)/settings/m.join_policy$") def __init__(self, hs): @@ -420,9 +417,7 @@ def on_PUT(self, request, group_id): content = parse_json_object_from_request(request) result = yield self.groups_handler.set_group_join_policy( - group_id, - requester_user_id, - content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -431,6 +426,7 @@ def on_PUT(self, request, group_id): class GroupCreateServlet(RestServlet): """Create a group """ + PATTERNS = client_patterns("/create_group$") def __init__(self, hs): @@ -451,9 +447,7 @@ def on_POST(self, request): group_id = GroupID(localpart, self.server_name).to_string() result = yield self.groups_handler.create_group( - group_id, - requester_user_id, - content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -462,6 +456,7 @@ def on_POST(self, request): class GroupAdminRoomsServlet(RestServlet): """Add a room to the group """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)$" ) @@ -479,7 +474,7 @@ def on_PUT(self, request, group_id, room_id): content = parse_json_object_from_request(request) result = yield self.groups_handler.add_room_to_group( - group_id, requester_user_id, room_id, content, + group_id, requester_user_id, room_id, content ) defer.returnValue((200, result)) @@ -490,7 +485,7 @@ def on_DELETE(self, request, group_id, room_id): requester_user_id = requester.user.to_string() result = yield self.groups_handler.remove_room_from_group( - group_id, requester_user_id, room_id, + group_id, requester_user_id, room_id ) defer.returnValue((200, result)) @@ -499,6 +494,7 @@ def on_DELETE(self, request, group_id, room_id): class GroupAdminRoomsConfigServlet(RestServlet): """Update the config of a room in a group """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)" "/config/(?P[^/]*)$" @@ -517,7 +513,7 @@ def on_PUT(self, request, group_id, room_id, config_key): content = parse_json_object_from_request(request) result = yield self.groups_handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content, + group_id, requester_user_id, room_id, config_key, content ) defer.returnValue((200, result)) @@ -526,6 +522,7 @@ def on_PUT(self, request, group_id, room_id, config_key): class GroupAdminUsersInviteServlet(RestServlet): """Invite a user to the group """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/users/invite/(?P[^/]*)$" ) @@ -546,7 +543,7 @@ def on_PUT(self, request, group_id, user_id): content = parse_json_object_from_request(request) config = content.get("config", {}) result = yield self.groups_handler.invite( - group_id, user_id, requester_user_id, config, + group_id, user_id, requester_user_id, config ) defer.returnValue((200, result)) @@ -555,6 +552,7 @@ def on_PUT(self, request, group_id, user_id): class GroupAdminUsersKickServlet(RestServlet): """Kick a user from the group """ + PATTERNS = client_patterns( "/groups/(?P[^/]*)/admin/users/remove/(?P[^/]*)$" ) @@ -572,7 +570,7 @@ def on_PUT(self, request, group_id, user_id): content = parse_json_object_from_request(request) result = yield self.groups_handler.remove_user_from_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -581,9 +579,8 @@ def on_PUT(self, request, group_id, user_id): class GroupSelfLeaveServlet(RestServlet): """Leave a joined group """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/self/leave$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/leave$") def __init__(self, hs): super(GroupSelfLeaveServlet, self).__init__() @@ -598,7 +595,7 @@ def on_PUT(self, request, group_id): content = parse_json_object_from_request(request) result = yield self.groups_handler.remove_user_from_group( - group_id, requester_user_id, requester_user_id, content, + group_id, requester_user_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -607,9 +604,8 @@ def on_PUT(self, request, group_id): class GroupSelfJoinServlet(RestServlet): """Attempt to join a group, or knock """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/self/join$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/join$") def __init__(self, hs): super(GroupSelfJoinServlet, self).__init__() @@ -624,7 +620,7 @@ def on_PUT(self, request, group_id): content = parse_json_object_from_request(request) result = yield self.groups_handler.join_group( - group_id, requester_user_id, content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -633,9 +629,8 @@ def on_PUT(self, request, group_id): class GroupSelfAcceptInviteServlet(RestServlet): """Accept a group invite """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/self/accept_invite$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/accept_invite$") def __init__(self, hs): super(GroupSelfAcceptInviteServlet, self).__init__() @@ -650,7 +645,7 @@ def on_PUT(self, request, group_id): content = parse_json_object_from_request(request) result = yield self.groups_handler.accept_invite( - group_id, requester_user_id, content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -659,9 +654,8 @@ def on_PUT(self, request, group_id): class GroupSelfUpdatePublicityServlet(RestServlet): """Update whether we publicise a users membership of a group """ - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/self/update_publicity$" - ) + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/update_publicity$") def __init__(self, hs): super(GroupSelfUpdatePublicityServlet, self).__init__() @@ -676,9 +670,7 @@ def on_PUT(self, request, group_id): content = parse_json_object_from_request(request) publicise = content["publicise"] - yield self.store.update_group_publicity( - group_id, requester_user_id, publicise, - ) + yield self.store.update_group_publicity(group_id, requester_user_id, publicise) defer.returnValue((200, {})) @@ -686,9 +678,8 @@ def on_PUT(self, request, group_id): class PublicisedGroupsForUserServlet(RestServlet): """Get the list of groups a user is advertising """ - PATTERNS = client_patterns( - "/publicised_groups/(?P[^/]*)$" - ) + + PATTERNS = client_patterns("/publicised_groups/(?P[^/]*)$") def __init__(self, hs): super(PublicisedGroupsForUserServlet, self).__init__() @@ -701,9 +692,7 @@ def __init__(self, hs): def on_GET(self, request, user_id): yield self.auth.get_user_by_req(request, allow_guest=True) - result = yield self.groups_handler.get_publicised_groups_for_user( - user_id - ) + result = yield self.groups_handler.get_publicised_groups_for_user(user_id) defer.returnValue((200, result)) @@ -711,9 +700,8 @@ def on_GET(self, request, user_id): class PublicisedGroupsForUsersServlet(RestServlet): """Get the list of groups a user is advertising """ - PATTERNS = client_patterns( - "/publicised_groups$" - ) + + PATTERNS = client_patterns("/publicised_groups$") def __init__(self, hs): super(PublicisedGroupsForUsersServlet, self).__init__() @@ -729,9 +717,7 @@ def on_POST(self, request): content = parse_json_object_from_request(request) user_ids = content["user_ids"] - result = yield self.groups_handler.bulk_get_publicised_groups( - user_ids - ) + result = yield self.groups_handler.bulk_get_publicised_groups(user_ids) defer.returnValue((200, result)) @@ -739,9 +725,8 @@ def on_POST(self, request): class GroupsForUserServlet(RestServlet): """Get all groups the logged in user is joined to """ - PATTERNS = client_patterns( - "/joined_groups$" - ) + + PATTERNS = client_patterns("/joined_groups$") def __init__(self, hs): super(GroupsForUserServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 4cbfbf5631b7..45c9928b6510 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -56,6 +56,7 @@ class KeyUploadServlet(RestServlet): }, } """ + PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") def __init__(self, hs): @@ -76,18 +77,19 @@ def on_POST(self, request, device_id): if device_id is not None: # passing the device_id here is deprecated; however, we allow it # for now for compatibility with older clients. - if (requester.device_id is not None and - device_id != requester.device_id): - logger.warning("Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, device_id) + if requester.device_id is not None and device_id != requester.device_id: + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) else: device_id = requester.device_id if device_id is None: raise SynapseError( - 400, - "To upload keys, you must pass device_id when authenticating" + 400, "To upload keys, you must pass device_id when authenticating" ) result = yield self.e2e_keys_handler.upload_keys_for_user( @@ -159,6 +161,7 @@ class KeyChangesServlet(RestServlet): 200 OK { "changed": ["@foo:example.com"] } """ + PATTERNS = client_patterns("/keys/changes$") def __init__(self, hs): @@ -184,9 +187,7 @@ def on_GET(self, request): user_id = requester.user.to_string() - results = yield self.device_handler.get_user_ids_changed( - user_id, from_token, - ) + results = yield self.device_handler.get_user_ids_changed(user_id, from_token) defer.returnValue((200, results)) @@ -209,6 +210,7 @@ class OneTimeKeyServlet(RestServlet): } } } } """ + PATTERNS = client_patterns("/keys/claim$") def __init__(self, hs): @@ -221,10 +223,7 @@ def on_POST(self, request): yield self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.claim_one_time_keys( - body, - timeout, - ) + result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout) defer.returnValue((200, result)) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 53e666989bb0..728a52328f18 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -51,7 +51,7 @@ def on_GET(self, request): ) receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( - user_id, 'm.read' + user_id, "m.read" ) notif_event_ids = [pa["event_id"] for pa in push_actions] @@ -67,11 +67,13 @@ def on_GET(self, request): "profile_tag": pa["profile_tag"], "actions": pa["actions"], "ts": pa["received_ts"], - "event": (yield self._event_serializer.serialize_event( - notif_events[pa["event_id"]], - self.clock.time_msec(), - event_format=format_event_for_client_v2_without_room_id, - )), + "event": ( + yield self._event_serializer.serialize_event( + notif_events[pa["event_id"]], + self.clock.time_msec(), + event_format=format_event_for_client_v2_without_room_id, + ) + ), } if pa["room_id"] not in receipts_by_room: @@ -80,17 +82,15 @@ def on_GET(self, request): receipt = receipts_by_room[pa["room_id"]] returned_pa["read"] = ( - receipt["topological_ordering"], receipt["stream_ordering"] - ) >= ( - pa["topological_ordering"], pa["stream_ordering"] - ) + receipt["topological_ordering"], + receipt["stream_ordering"], + ) >= (pa["topological_ordering"], pa["stream_ordering"]) returned_push_actions.append(returned_pa) next_token = str(pa["stream_ordering"]) - defer.returnValue((200, { - "notifications": returned_push_actions, - "next_token": next_token, - })) + defer.returnValue( + (200, {"notifications": returned_push_actions, "next_token": next_token}) + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index bb927d9f9d86..b1b5385b09ef 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -56,9 +56,8 @@ class IdTokenServlet(RestServlet): "expires_in": 3600, } """ - PATTERNS = client_patterns( - "/user/(?P[^/]*)/openid/request_token" - ) + + PATTERNS = client_patterns("/user/(?P[^/]*)/openid/request_token") EXPIRES_MS = 3600 * 1000 @@ -84,12 +83,17 @@ def on_POST(self, request, user_id): yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) - defer.returnValue((200, { - "access_token": token, - "token_type": "Bearer", - "matrix_server_name": self.server_name, - "expires_in": self.EXPIRES_MS / 1000, - })) + defer.returnValue( + ( + 200, + { + "access_token": token, + "token_type": "Bearer", + "matrix_server_name": self.server_name, + "expires_in": self.EXPIRES_MS / 1000, + }, + ) + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index f4bd0d077f76..e75664279b51 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -48,7 +48,7 @@ def on_POST(self, request, room_id): room_id, "m.read", user_id=requester.user.to_string(), - event_id=read_event_id + event_id=read_event_id, ) read_marker_event_id = body.get("m.fully_read", None) @@ -56,7 +56,7 @@ def on_POST(self, request, room_id): yield self.read_marker_handler.received_client_read_marker( room_id, user_id=requester.user.to_string(), - event_id=read_marker_event_id + event_id=read_marker_event_id, ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index fa12ac3e4d11..488905626a0f 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -49,10 +49,7 @@ def on_POST(self, request, room_id, receipt_type, event_id): yield self.presence_handler.bump_presence_active_time(requester.user) yield self.receipts_handler.received_client_receipt( - room_id, - receipt_type, - user_id=requester.user.to_string(), - event_id=event_id + room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 79c085408bf3..5c120e4dd5d8 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -52,6 +52,7 @@ if hasattr(hmac, "compare_digest"): compare_digest = hmac.compare_digest else: + def compare_digest(a, b): return a == b @@ -75,11 +76,11 @@ def __init__(self, hs): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', 'email', 'send_attempt' - ]) + assert_params_in_dict( + body, ["id_server", "client_secret", "email", "send_attempt"] + ) - if not check_3pid_allowed(self.hs, "email", body['email']): + if not check_3pid_allowed(self.hs, "email", body["email"]): raise SynapseError( 403, "Your email domain is not authorized to register on this server", @@ -87,7 +88,7 @@ def on_POST(self, request): ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', body['email'] + "email", body["email"] ) if existingUid is not None: @@ -113,13 +114,12 @@ def __init__(self, hs): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', - 'country', 'phone_number', - 'send_attempt', - ]) + assert_params_in_dict( + body, + ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + ) - msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + msisdn = phone_number_to_msisdn(body["country"], body["phone_number"]) if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( @@ -129,7 +129,7 @@ def on_POST(self, request): ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'msisdn', msisdn + "msisdn", msisdn ) if existingUid is not None: @@ -165,7 +165,7 @@ def __init__(self, hs): reject_limit=1, # Allow 1 request at a time concurrent_requests=1, - ) + ), ) @defer.inlineCallbacks @@ -212,7 +212,8 @@ def on_POST(self, request): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.can_do_action( - client_addr, time_now_s=time_now, + client_addr, + time_now_s=time_now, rate_hz=self.hs.config.rc_registration.per_second, burst_count=self.hs.config.rc_registration.burst_count, update=False, @@ -220,7 +221,7 @@ def on_POST(self, request): if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), + retry_after_ms=int(1000 * (time_allowed - time_now)) ) kind = b"user" @@ -239,18 +240,22 @@ def on_POST(self, request): # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. desired_password = None - if 'password' in body: - if (not isinstance(body['password'], string_types) or - len(body['password']) > 512): + if "password" in body: + if ( + not isinstance(body["password"], string_types) + or len(body["password"]) > 512 + ): raise SynapseError(400, "Invalid password") desired_password = body["password"] desired_username = None - if 'username' in body: - if (not isinstance(body['username'], string_types) or - len(body['username']) > 512): + if "username" in body: + if ( + not isinstance(body["username"], string_types) + or len(body["username"]) > 512 + ): raise SynapseError(400, "Invalid username") - desired_username = body['username'] + desired_username = body["username"] appservice = None if self.auth.has_access_token(request): @@ -290,7 +295,7 @@ def on_POST(self, request): desired_username = desired_username.lower() # == Shared Secret Registration == (e.g. create new user scripts) - if 'mac' in body: + if "mac" in body: # FIXME: Should we really be determining if this is shared secret # auth based purely on the 'mac' key? result = yield self._do_shared_secret_registration( @@ -305,16 +310,13 @@ def on_POST(self, request): guest_access_token = body.get("guest_access_token", None) - if ( - 'initial_device_display_name' in body and - 'password' not in body - ): + if "initial_device_display_name" in body and "password" not in body: # ignore 'initial_device_display_name' if sent without # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out # the original registration params logger.warn("Ignoring initial_device_display_name without password") - del body['initial_device_display_name'] + del body["initial_device_display_name"] session_id = self.auth_handler.get_session_id(body) registered_user_id = None @@ -336,8 +338,8 @@ def on_POST(self, request): # FIXME: need a better error than "no auth flow found" for scenarios # where we required 3PID for registration but the user didn't give one - require_email = 'email' in self.hs.config.registrations_require_3pid - require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid + require_email = "email" in self.hs.config.registrations_require_3pid + require_msisdn = "msisdn" in self.hs.config.registrations_require_3pid show_msisdn = True if self.hs.config.disable_msisdn_registration: @@ -362,9 +364,9 @@ def on_POST(self, request): if not require_email: flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]]) # always let users provide both MSISDN & email - flows.extend([ - [LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY], - ]) + flows.extend( + [[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY]] + ) else: # only support 3PIDless registration if no 3PIDs are required if not require_email and not require_msisdn: @@ -378,9 +380,7 @@ def on_POST(self, request): if not require_email or require_msisdn: flows.extend([[LoginType.MSISDN]]) # always let users provide both MSISDN & email - flows.extend([ - [LoginType.MSISDN, LoginType.EMAIL_IDENTITY] - ]) + flows.extend([[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]]) # Append m.login.terms to all flows if we're requiring consent if self.hs.config.user_consent_at_registration: @@ -410,21 +410,20 @@ def on_POST(self, request): if auth_result: for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: if login_type in auth_result: - medium = auth_result[login_type]['medium'] - address = auth_result[login_type]['address'] + medium = auth_result[login_type]["medium"] + address = auth_result[login_type]["address"] if not check_3pid_allowed(self.hs, medium, address): raise SynapseError( 403, - "Third party identifiers (email/phone numbers)" + - " are not authorized on this server", + "Third party identifiers (email/phone numbers)" + + " are not authorized on this server", Codes.THREEPID_DENIED, ) if registered_user_id is not None: logger.info( - "Already registered user ID %r for this session", - registered_user_id + "Already registered user ID %r for this session", registered_user_id ) # don't re-register the threepids registered = False @@ -451,11 +450,11 @@ def on_POST(self, request): # the two activation emails, they would register the same 3pid twice. for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: if login_type in auth_result: - medium = auth_result[login_type]['medium'] - address = auth_result[login_type]['address'] + medium = auth_result[login_type]["medium"] + address = auth_result[login_type]["address"] existingUid = yield self.store.get_user_id_by_threepid( - medium, address, + medium, address ) if existingUid is not None: @@ -520,7 +519,7 @@ def _do_shared_secret_registration(self, username, password, body): raise SynapseError(400, "Shared secret registration is not enabled") if not username: raise SynapseError( - 400, "username must be specified", errcode=Codes.BAD_JSON, + 400, "username must be specified", errcode=Codes.BAD_JSON ) # use the username from the original request rather than the @@ -541,12 +540,10 @@ def _do_shared_secret_registration(self, username, password, body): ).hexdigest() if not compare_digest(want_mac, got_mac): - raise SynapseError( - 403, "HMAC incorrect", - ) + raise SynapseError(403, "HMAC incorrect") (user_id, _) = yield self.registration_handler.register( - localpart=username, password=password, generate_token=False, + localpart=username, password=password, generate_token=False ) result = yield self._create_registration_details(user_id, body) @@ -565,21 +562,15 @@ def _create_registration_details(self, user_id, params): Returns: defer.Deferred: (object) dictionary for response from /register """ - result = { - "user_id": user_id, - "home_server": self.hs.hostname, - } + result = {"user_id": user_id, "home_server": self.hs.hostname} if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=False, + user_id, device_id, initial_display_name, is_guest=False ) - result.update({ - "access_token": access_token, - "device_id": device_id, - }) + result.update({"access_token": access_token, "device_id": device_id}) defer.returnValue(result) @defer.inlineCallbacks @@ -587,9 +578,7 @@ def _do_guest_registration(self, params, address=None): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id, _ = yield self.registration_handler.register( - generate_token=False, - make_guest=True, - address=address, + generate_token=False, make_guest=True, address=address ) # we don't allow guests to specify their own device_id, because @@ -597,15 +586,20 @@ def _do_guest_registration(self, params, address=None): device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=True, + user_id, device_id, initial_display_name, is_guest=True ) - defer.returnValue((200, { - "user_id": user_id, - "device_id": device_id, - "access_token": access_token, - "home_server": self.hs.hostname, - })) + defer.returnValue( + ( + 200, + { + "user_id": user_id, + "device_id": device_id, + "access_token": access_token, + "home_server": self.hs.hostname, + }, + ) + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index f8f8742bdc24..8e362782cc3c 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -32,7 +32,10 @@ parse_string, ) from synapse.rest.client.transactions import HttpTransactionCache -from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken +from synapse.storage.relations import ( + AggregationPaginationToken, + RelationPaginationToken, +) from ._base import client_patterns diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index 10198662a9eb..e7578af8040a 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -33,9 +33,7 @@ class ReportEventRestServlet(RestServlet): - PATTERNS = client_patterns( - "/rooms/(?P[^/]*)/report/(?P[^/]*)$" - ) + PATTERNS = client_patterns("/rooms/(?P[^/]*)/report/(?P[^/]*)$") def __init__(self, hs): super(ReportEventRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 87779645f971..8d1b810565ee 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -129,22 +129,12 @@ def on_PUT(self, request, room_id, session_id): version = parse_string(request, "version") if session_id: - body = { - "sessions": { - session_id: body - } - } + body = {"sessions": {session_id: body}} if room_id: - body = { - "rooms": { - room_id: body - } - } + body = {"rooms": {room_id: body}} - yield self.e2e_room_keys_handler.upload_room_keys( - user_id, version, body - ) + yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) defer.returnValue((200, {})) @defer.inlineCallbacks @@ -212,10 +202,10 @@ def on_GET(self, request, room_id, session_id): if session_id: # If the client requests a specific session, but that session was # not backed up, then return an M_NOT_FOUND. - if room_keys['rooms'] == {}: + if room_keys["rooms"] == {}: raise NotFoundError("No room_keys found") else: - room_keys = room_keys['rooms'][room_id]['sessions'][session_id] + room_keys = room_keys["rooms"][room_id]["sessions"][session_id] elif room_id: # If the client requests all sessions from a room, but no sessions # are found, then return an empty result rather than an error, so @@ -223,10 +213,10 @@ def on_GET(self, request, room_id, session_id): # empty result is valid. (Similarly if the client requests all # sessions from the backup, but in that case, room_keys is already # in the right format, so we don't need to do anything about it.) - if room_keys['rooms'] == {}: - room_keys = {'sessions': {}} + if room_keys["rooms"] == {}: + room_keys = {"sessions": {}} else: - room_keys = room_keys['rooms'][room_id] + room_keys = room_keys["rooms"][room_id] defer.returnValue((200, room_keys)) @@ -256,9 +246,7 @@ def on_DELETE(self, request, room_id, session_id): class RoomKeysNewVersionServlet(RestServlet): - PATTERNS = client_patterns( - "/room_keys/version$" - ) + PATTERNS = client_patterns("/room_keys/version$") def __init__(self, hs): """ @@ -304,9 +292,7 @@ def on_POST(self, request): user_id = requester.user.to_string() info = parse_json_object_from_request(request) - new_version = yield self.e2e_room_keys_handler.create_version( - user_id, info - ) + new_version = yield self.e2e_room_keys_handler.create_version(user_id, info) defer.returnValue((200, {"version": new_version})) # we deliberately don't have a PUT /version, as these things really should @@ -314,9 +300,7 @@ def on_POST(self, request): class RoomKeysVersionServlet(RestServlet): - PATTERNS = client_patterns( - "/room_keys/version(/(?P[^/]+))?$" - ) + PATTERNS = client_patterns("/room_keys/version(/(?P[^/]+))?$") def __init__(self, hs): """ @@ -350,9 +334,7 @@ def on_GET(self, request, version): user_id = requester.user.to_string() try: - info = yield self.e2e_room_keys_handler.get_version_info( - user_id, version - ) + info = yield self.e2e_room_keys_handler.get_version_info(user_id, version) except SynapseError as e: if e.code == 404: raise SynapseError(404, "No backup found", Codes.NOT_FOUND) @@ -375,9 +357,7 @@ def on_DELETE(self, request, version): requester = yield self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - yield self.e2e_room_keys_handler.delete_version( - user_id, version - ) + yield self.e2e_room_keys_handler.delete_version(user_id, version) defer.returnValue((200, {})) @defer.inlineCallbacks @@ -407,11 +387,11 @@ def on_PUT(self, request, version): info = parse_json_object_from_request(request) if version is None: - raise SynapseError(400, "No version specified to update", Codes.MISSING_PARAM) + raise SynapseError( + 400, "No version specified to update", Codes.MISSING_PARAM + ) - yield self.e2e_room_keys_handler.update_version( - user_id, version, info - ) + yield self.e2e_room_keys_handler.update_version(user_id, version, info) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index c621a90fbaa2..d7f7faa029b5 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -47,9 +47,10 @@ class RoomUpgradeRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ + PATTERNS = client_patterns( # /rooms/$roomid/upgrade - "/rooms/(?P[^/]*)/upgrade$", + "/rooms/(?P[^/]*)/upgrade$" ) def __init__(self, hs): @@ -63,7 +64,7 @@ def on_POST(self, request, room_id): requester = yield self._auth.get_user_by_req(request) content = parse_json_object_from_request(request) - assert_params_in_dict(content, ("new_version", )) + assert_params_in_dict(content, ("new_version",)) new_version = content["new_version"] if new_version not in KNOWN_ROOM_VERSIONS: @@ -77,9 +78,7 @@ def on_POST(self, request, room_id): requester, room_id, new_version ) - ret = { - "replacement_room": new_room_id, - } + ret = {"replacement_room": new_room_id} defer.returnValue((200, ret)) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index 120a71336198..78075b8fc0f8 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -28,7 +28,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): PATTERNS = client_patterns( - "/sendToDevice/(?P[^/]*)/(?P[^/]*)$", + "/sendToDevice/(?P[^/]*)/(?P[^/]*)$" ) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 148fc6c985b5..02d56dee6cfb 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -96,44 +96,42 @@ def on_GET(self, request): 400, "'from' is not a valid query parameter. Did you mean 'since'?" ) - requester = yield self.auth.get_user_by_req( - request, allow_guest=True - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) user = requester.user device_id = requester.device_id timeout = parse_integer(request, "timeout", default=0) since = parse_string(request, "since") set_presence = parse_string( - request, "set_presence", default="online", - allowed_values=self.ALLOWED_PRESENCE + request, + "set_presence", + default="online", + allowed_values=self.ALLOWED_PRESENCE, ) filter_id = parse_string(request, "filter", default=None) full_state = parse_boolean(request, "full_state", default=False) logger.debug( "/sync: user=%r, timeout=%r, since=%r," - " set_presence=%r, filter_id=%r, device_id=%r" % ( - user, timeout, since, set_presence, filter_id, device_id - ) + " set_presence=%r, filter_id=%r, device_id=%r" + % (user, timeout, since, set_presence, filter_id, device_id) ) request_key = (user, timeout, since, filter_id, full_state, device_id) if filter_id: - if filter_id.startswith('{'): + if filter_id.startswith("{"): try: filter_object = json.loads(filter_id) - set_timeline_upper_limit(filter_object, - self.hs.config.filter_timeline_limit) + set_timeline_upper_limit( + filter_object, self.hs.config.filter_timeline_limit + ) except Exception: raise SynapseError(400, "Invalid filter JSON") self.filtering.check_valid_filter(filter_object) filter = FilterCollection(filter_object) else: - filter = yield self.filtering.get_user_filter( - user.localpart, filter_id - ) + filter = yield self.filtering.get_user_filter(user.localpart, filter_id) else: filter = DEFAULT_FILTER_COLLECTION @@ -156,15 +154,19 @@ def on_GET(self, request): affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: - yield self.presence_handler.set_state(user, {"presence": set_presence}, True) + yield self.presence_handler.set_state( + user, {"presence": set_presence}, True + ) context = yield self.presence_handler.user_syncing( - user.to_string(), affect_presence=affect_presence, + user.to_string(), affect_presence=affect_presence ) with context: sync_result = yield self.sync_handler.wait_for_sync_for_user( - sync_config, since_token=since_token, timeout=timeout, - full_state=full_state + sync_config, + since_token=since_token, + timeout=timeout, + full_state=full_state, ) time_now = self.clock.time_msec() @@ -176,53 +178,54 @@ def on_GET(self, request): @defer.inlineCallbacks def encode_response(self, time_now, sync_result, access_token_id, filter): - if filter.event_format == 'client': + if filter.event_format == "client": event_formatter = format_event_for_client_v2_without_room_id - elif filter.event_format == 'federation': + elif filter.event_format == "federation": event_formatter = format_event_raw else: - raise Exception("Unknown event format %s" % (filter.event_format, )) + raise Exception("Unknown event format %s" % (filter.event_format,)) joined = yield self.encode_joined( - sync_result.joined, time_now, access_token_id, + sync_result.joined, + time_now, + access_token_id, filter.event_fields, event_formatter, ) invited = yield self.encode_invited( - sync_result.invited, time_now, access_token_id, - event_formatter, + sync_result.invited, time_now, access_token_id, event_formatter ) archived = yield self.encode_archived( - sync_result.archived, time_now, access_token_id, + sync_result.archived, + time_now, + access_token_id, filter.event_fields, event_formatter, ) - defer.returnValue({ - "account_data": {"events": sync_result.account_data}, - "to_device": {"events": sync_result.to_device}, - "device_lists": { - "changed": list(sync_result.device_lists.changed), - "left": list(sync_result.device_lists.left), - }, - "presence": SyncRestServlet.encode_presence( - sync_result.presence, time_now - ), - "rooms": { - "join": joined, - "invite": invited, - "leave": archived, - }, - "groups": { - "join": sync_result.groups.join, - "invite": sync_result.groups.invite, - "leave": sync_result.groups.leave, - }, - "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "next_batch": sync_result.next_batch.to_string(), - }) + defer.returnValue( + { + "account_data": {"events": sync_result.account_data}, + "to_device": {"events": sync_result.to_device}, + "device_lists": { + "changed": list(sync_result.device_lists.changed), + "left": list(sync_result.device_lists.left), + }, + "presence": SyncRestServlet.encode_presence( + sync_result.presence, time_now + ), + "rooms": {"join": joined, "invite": invited, "leave": archived}, + "groups": { + "join": sync_result.groups.join, + "invite": sync_result.groups.invite, + "leave": sync_result.groups.leave, + }, + "device_one_time_keys_count": sync_result.device_one_time_keys_count, + "next_batch": sync_result.next_batch.to_string(), + } + ) @staticmethod def encode_presence(events, time_now): @@ -262,7 +265,11 @@ def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter joined = {} for room in rooms: joined[room.room_id] = yield self.encode_room( - room, time_now, token_id, joined=True, only_fields=event_fields, + room, + time_now, + token_id, + joined=True, + only_fields=event_fields, event_formatter=event_formatter, ) @@ -290,7 +297,9 @@ def encode_invited(self, rooms, time_now, token_id, event_formatter): invited = {} for room in rooms: invite = yield self._event_serializer.serialize_event( - room.invite, time_now, token_id=token_id, + room.invite, + time_now, + token_id=token_id, event_format=event_formatter, is_invite=True, ) @@ -298,9 +307,7 @@ def encode_invited(self, rooms, time_now, token_id, event_formatter): invite["unsigned"] = unsigned invited_state = list(unsigned.pop("invite_room_state", [])) invited_state.append(invite) - invited[room.room_id] = { - "invite_state": {"events": invited_state} - } + invited[room.room_id] = {"invite_state": {"events": invited_state}} defer.returnValue(invited) @@ -327,7 +334,10 @@ def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatt joined = {} for room in rooms: joined[room.room_id] = yield self.encode_room( - room, time_now, token_id, joined=False, + room, + time_now, + token_id, + joined=False, only_fields=event_fields, event_formatter=event_formatter, ) @@ -336,8 +346,7 @@ def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatt @defer.inlineCallbacks def encode_room( - self, room, time_now, token_id, joined, - only_fields, event_formatter, + self, room, time_now, token_id, joined, only_fields, event_formatter ): """ Args: @@ -355,9 +364,11 @@ def encode_room( Returns: dict[str, object]: the room, encoded in our response format """ + def serialize(events): return self._event_serializer.serialize_events( - events, time_now=time_now, + events, + time_now=time_now, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. bundle_aggregations=False, @@ -377,7 +388,9 @@ def serialize(events): if event.room_id != room.room_id: logger.warn( "Event %r is under room %r instead of %r", - event.event_id, room.room_id, event.room_id, + event.event_id, + room.room_id, + event.room_id, ) serialized_state = yield serialize(state_events) diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index ebff7cff4516..07b6ede6030c 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -29,9 +29,8 @@ class TagListServlet(RestServlet): """ GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 """ - PATTERNS = client_patterns( - "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags" - ) + + PATTERNS = client_patterns("/user/(?P[^/]*)/rooms/(?P[^/]*)/tags") def __init__(self, hs): super(TagListServlet, self).__init__() @@ -54,6 +53,7 @@ class TagServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 """ + PATTERNS = client_patterns( "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags/(?P[^/]*)" ) @@ -74,9 +74,7 @@ def on_PUT(self, request, user_id, room_id, tag): max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) @@ -88,9 +86,7 @@ def on_DELETE(self, request, user_id, room_id, tag): max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index e7a987466ab5..1e66662a056b 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -57,7 +57,7 @@ def on_GET(self, request, protocol): yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols( - only_protocol=protocol, + only_protocol=protocol ) if protocol in protocols: defer.returnValue((200, protocols[protocol])) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 6c366142e1dc..2da0f5581126 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -26,6 +26,7 @@ class TokenRefreshRestServlet(RestServlet): Exchanges refresh tokens for a pair of an access token and a new refresh token. """ + PATTERNS = client_patterns("/tokenrefresh") def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 69e4efc47a85..e19fb6d58385 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -60,10 +60,7 @@ def on_POST(self, request): user_id = requester.user.to_string() if not self.hs.config.user_directory_search_enabled: - defer.returnValue((200, { - "limited": False, - "results": [], - })) + defer.returnValue((200, {"limited": False, "results": []})) body = parse_json_object_from_request(request) @@ -76,7 +73,7 @@ def on_POST(self, request): raise SynapseError(400, "`search_term` is required field") results = yield self.user_directory_handler.search_users( - user_id, search_term, limit, + user_id, search_term, limit ) defer.returnValue((200, results)) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index babbf6a23ce8..0e0919163267 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -25,27 +25,28 @@ class VersionsRestServlet(RestServlet): PATTERNS = [re.compile("^/_matrix/client/versions$")] def on_GET(self, request): - return (200, { - "versions": [ - # XXX: at some point we need to decide whether we need to include - # the previous version numbers, given we've defined r0.3.0 to be - # backwards compatible with r0.2.0. But need to check how - # conscientious we've been in compatibility, and decide whether the - # middle number is the major revision when at 0.X.Y (as opposed to - # X.Y.Z). And we need to decide whether it's fair to make clients - # parse the version string to figure out what's going on. - "r0.0.1", - "r0.1.0", - "r0.2.0", - "r0.3.0", - "r0.4.0", - "r0.5.0", - ], - # as per MSC1497: - "unstable_features": { - "m.lazy_load_members": True, - } - }) + return ( + 200, + { + "versions": [ + # XXX: at some point we need to decide whether we need to include + # the previous version numbers, given we've defined r0.3.0 to be + # backwards compatible with r0.2.0. But need to check how + # conscientious we've been in compatibility, and decide whether the + # middle number is the major revision when at 0.X.Y (as opposed to + # X.Y.Z). And we need to decide whether it's fair to make clients + # parse the version string to figure out what's going on. + "r0.0.1", + "r0.1.0", + "r0.2.0", + "r0.3.0", + "r0.4.0", + "r0.5.0", + ], + # as per MSC1497: + "unstable_features": {"m.lazy_load_members": True}, + }, + ) def register_servlets(http_server): diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 6b371bfa2fa3..9a32892d8bae 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -42,6 +42,7 @@ if hasattr(hmac, "compare_digest"): compare_digest = hmac.compare_digest else: + def compare_digest(a, b): return a == b @@ -80,6 +81,7 @@ class ConsentResource(Resource): For POST: required; gives the value to be recorded in the database against the user. """ + def __init__(self, hs): """ Args: @@ -98,21 +100,20 @@ def __init__(self, hs): if self._default_consent_version is None: raise ConfigError( "Consent resource is enabled but user_consent section is " - "missing in config file.", + "missing in config file." ) consent_template_directory = hs.config.user_consent_template_dir loader = jinja2.FileSystemLoader(consent_template_directory) self._jinja_env = jinja2.Environment( - loader=loader, - autoescape=jinja2.select_autoescape(['html', 'htm', 'xml']), + loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"]) ) if hs.config.form_secret is None: raise ConfigError( "Consent resource is enabled but form_secret is not set in " - "config file. It should be set to an arbitrary secret string.", + "config file. It should be set to an arbitrary secret string." ) self._hmac_secret = hs.config.form_secret.encode("utf-8") @@ -139,7 +140,7 @@ def _async_render_GET(self, request): self._check_hash(username, userhmac_bytes) - if username.startswith('@'): + if username.startswith("@"): qualified_user_id = username else: qualified_user_id = UserID(username, self.hs.hostname).to_string() @@ -153,7 +154,8 @@ def _async_render_GET(self, request): try: self._render_template( - request, "%s.html" % (version,), + request, + "%s.html" % (version,), user=username, userhmac=userhmac, version=version, @@ -180,7 +182,7 @@ def _async_render_POST(self, request): self._check_hash(username, userhmac) - if username.startswith('@'): + if username.startswith("@"): qualified_user_id = username else: qualified_user_id = UserID(username, self.hs.hostname).to_string() @@ -221,11 +223,13 @@ def _check_hash(self, userid, userhmac): SynapseError if the hash doesn't match """ - want_mac = hmac.new( - key=self._hmac_secret, - msg=userid.encode('utf-8'), - digestmod=sha256, - ).hexdigest().encode('ascii') + want_mac = ( + hmac.new( + key=self._hmac_secret, msg=userid.encode("utf-8"), digestmod=sha256 + ) + .hexdigest() + .encode("ascii") + ) if not compare_digest(want_mac, userhmac): raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect") diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index ec0ec7b431f8..c16280f66806 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -80,33 +80,27 @@ def response_json_object(self): for key in self.config.signing_key: verify_key_bytes = key.verify_key.encode() key_id = "%s:%s" % (key.alg, key.version) - verify_keys[key_id] = { - u"key": encode_base64(verify_key_bytes) - } + verify_keys[key_id] = {"key": encode_base64(verify_key_bytes)} old_verify_keys = {} for key_id, key in self.config.old_signing_keys.items(): verify_key_bytes = key.encode() old_verify_keys[key_id] = { - u"key": encode_base64(verify_key_bytes), - u"expired_ts": key.expired_ts, + "key": encode_base64(verify_key_bytes), + "expired_ts": key.expired_ts, } tls_fingerprints = self.config.tls_fingerprints json_object = { - u"valid_until_ts": self.valid_until_ts, - u"server_name": self.config.server_name, - u"verify_keys": verify_keys, - u"old_verify_keys": old_verify_keys, - u"tls_fingerprints": tls_fingerprints, + "valid_until_ts": self.valid_until_ts, + "server_name": self.config.server_name, + "verify_keys": verify_keys, + "old_verify_keys": old_verify_keys, + "tls_fingerprints": tls_fingerprints, } for key in self.config.signing_key: - json_object = sign_json( - json_object, - self.config.server_name, - key, - ) + json_object = sign_json(json_object, self.config.server_name, key) return json_object def render_GET(self, request): @@ -114,6 +108,4 @@ def render_GET(self, request): # Update the expiry time if less than half the interval remains. if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts: self.update_response_body(time_now) - return respond_with_json_bytes( - request, 200, self.response_body, - ) + return respond_with_json_bytes(request, 200, self.response_body) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 8a730bbc354c..ec8b9d72690f 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -103,20 +103,16 @@ def render_GET(self, request): def async_render_GET(self, request): if len(request.postpath) == 1: server, = request.postpath - query = {server.decode('ascii'): {}} + query = {server.decode("ascii"): {}} elif len(request.postpath) == 2: server, key_id = request.postpath - minimum_valid_until_ts = parse_integer( - request, "minimum_valid_until_ts" - ) + minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") arguments = {} if minimum_valid_until_ts is not None: arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}} + query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}} else: - raise SynapseError( - 404, "Not found %r" % request.postpath, Codes.NOT_FOUND - ) + raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND) yield self.query_keys(request, query, query_remote_on_cache_miss=True) @@ -140,8 +136,8 @@ def query_keys(self, request, query, query_remote_on_cache_miss=False): store_queries = [] for server_name, key_ids in query.items(): if ( - self.federation_domain_whitelist is not None and - server_name not in self.federation_domain_whitelist + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist ): logger.debug("Federation denied with %s", server_name) continue @@ -159,9 +155,7 @@ def query_keys(self, request, query, query_remote_on_cache_miss=False): cache_misses = dict() for (server_name, key_id, from_server), results in cached.items(): - results = [ - (result["ts_added_ms"], result) for result in results - ] + results = [(result["ts_added_ms"], result) for result in results] if not results and key_id is not None: cache_misses.setdefault(server_name, set()).add(key_id) @@ -178,23 +172,30 @@ def query_keys(self, request, query, query_remote_on_cache_miss=False): logger.debug( "Cached response for %r/%r is older than requested" ": valid_until (%r) < minimum_valid_until (%r)", - server_name, key_id, - ts_valid_until_ms, req_valid_until + server_name, + key_id, + ts_valid_until_ms, + req_valid_until, ) miss = True else: logger.debug( "Cached response for %r/%r is newer than requested" ": valid_until (%r) >= minimum_valid_until (%r)", - server_name, key_id, - ts_valid_until_ms, req_valid_until + server_name, + key_id, + ts_valid_until_ms, + req_valid_until, ) elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms: logger.debug( "Cached response for %r/%r is too old" ": (added (%r) + valid_until (%r)) / 2 < now (%r)", - server_name, key_id, - ts_added_ms, ts_valid_until_ms, time_now_ms + server_name, + key_id, + ts_added_ms, + ts_valid_until_ms, + time_now_ms, ) # We more than half way through the lifetime of the # response. We should fetch a fresh copy. @@ -203,8 +204,11 @@ def query_keys(self, request, query, query_remote_on_cache_miss=False): logger.debug( "Cached response for %r/%r is still valid" ": (added (%r) + valid_until (%r)) / 2 < now (%r)", - server_name, key_id, - ts_added_ms, ts_valid_until_ms, time_now_ms + server_name, + key_id, + ts_added_ms, + ts_valid_until_ms, + time_now_ms, ) if miss: @@ -216,12 +220,10 @@ def query_keys(self, request, query, query_remote_on_cache_miss=False): if cache_misses and query_remote_on_cache_miss: yield self.fetcher.get_keys(cache_misses) - yield self.query_keys( - request, query, query_remote_on_cache_miss=False - ) + yield self.query_keys(request, query, query_remote_on_cache_miss=False) else: result_io = BytesIO() - result_io.write(b"{\"server_keys\":") + result_io.write(b'{"server_keys":') sep = b"[" for json_bytes in json_results: result_io.write(sep) @@ -231,6 +233,4 @@ def query_keys(self, request, query, query_remote_on_cache_miss=False): result_io.write(sep) result_io.write(b"]}") - respond_with_json_bytes( - request, 200, result_io.getvalue(), - ) + respond_with_json_bytes(request, 200, result_io.getvalue()) diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index 5a426ff2f63c..86884c0ef478 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -44,6 +44,7 @@ class ContentRepoResource(resource.Resource): - Content type base64d (so we can return it when clients GET it) """ + isLeaf = True def __init__(self, hs, directory): @@ -56,7 +57,7 @@ def render_GET(self, request): # servers. # TODO: A little crude here, we could do this better. - filename = request.path.decode('ascii').split('/')[-1] + filename = request.path.decode("ascii").split("/")[-1] # be paranoid filename = re.sub("[^0-9A-z.-_]", "", filename) @@ -69,17 +70,15 @@ def render_GET(self, request): base64_contentype = filename.split(".")[1] content_type = base64.urlsafe_b64decode(base64_contentype) logger.info("Sending file %s", file_path) - f = open(file_path, 'rb') - request.setHeader('Content-Type', content_type) + f = open(file_path, "rb") + request.setHeader("Content-Type", content_type) # cache for at least a day. # XXX: we might want to turn this off for data we don't want to # recommend caching as it's sensitive or private - or at least # select private. don't bother setting Expires as all our matrix # clients are smart enough to be happy with Cache-Control (right?) - request.setHeader( - b"Cache-Control", b"public,max-age=86400,s-maxage=86400" - ) + request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") d = FileSender().beginFileTransfer(f, request) @@ -87,13 +86,15 @@ def render_GET(self, request): def cbFinished(ignored): f.close() finish_request(request) + d.addCallback(cbFinished) else: respond_with_json_bytes( request, 404, json.dumps(cs_error("Not found", code=Codes.NOT_FOUND)), - send_cors=True) + send_cors=True, + ) return server.NOT_DONE_YET diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 2dcc8f74d699..3318638d3e69 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -38,8 +38,8 @@ def parse_media_id(request): server_name, media_id = request.postpath[:2] if isinstance(server_name, bytes): - server_name = server_name.decode('utf-8') - media_id = media_id.decode('utf8') + server_name = server_name.decode("utf-8") + media_id = media_id.decode("utf8") file_name = None if len(request.postpath) > 2: @@ -120,11 +120,11 @@ def _quote(x): # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we # may as well just do the filename* version. if _can_encode_filename_as_token(upload_name): - disposition = 'inline; filename=%s' % (upload_name, ) + disposition = "inline; filename=%s" % (upload_name,) else: - disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name), ) + disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),) - request.setHeader(b"Content-Disposition", disposition.encode('ascii')) + request.setHeader(b"Content-Disposition", disposition.encode("ascii")) # cache for at least a day. # XXX: we might want to turn this off for data we don't want to @@ -137,10 +137,27 @@ def _quote(x): # separators as defined in RFC2616. SP and HT are handled separately. # see _can_encode_filename_as_token. -_FILENAME_SEPARATOR_CHARS = set(( - "(", ")", "<", ">", "@", ",", ";", ":", "\\", '"', - "/", "[", "]", "?", "=", "{", "}", -)) +_FILENAME_SEPARATOR_CHARS = set( + ( + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", + ) +) def _can_encode_filename_as_token(x): @@ -271,7 +288,7 @@ def get_filename_from_headers(headers): Returns: A Unicode string of the filename, or None. """ - content_disposition = headers.get(b"Content-Disposition", [b'']) + content_disposition = headers.get(b"Content-Disposition", [b""]) # No header, bail out. if not content_disposition[0]: @@ -293,7 +310,7 @@ def get_filename_from_headers(headers): # Once it is decoded, we can then unquote the %-encoded # parts strictly into a unicode string. upload_name = urllib.parse.unquote( - upload_name_utf8.decode('ascii'), errors="strict" + upload_name_utf8.decode("ascii"), errors="strict" ) except UnicodeDecodeError: # Incorrect UTF-8. @@ -302,7 +319,7 @@ def get_filename_from_headers(headers): # On Python 2, we first unquote the %-encoded parts and then # decode it strictly using UTF-8. try: - upload_name = urllib.parse.unquote(upload_name_utf8).decode('utf8') + upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8") except UnicodeDecodeError: pass @@ -310,7 +327,7 @@ def get_filename_from_headers(headers): if not upload_name: upload_name_ascii = params.get(b"filename", None) if upload_name_ascii and is_ascii(upload_name_ascii): - upload_name = upload_name_ascii.decode('ascii') + upload_name = upload_name_ascii.decode("ascii") # This may be None here, indicating we did not find a matching name. return upload_name @@ -328,19 +345,19 @@ def _parse_header(line): Tuple[bytes, dict[bytes, bytes]]: the main content-type, followed by the parameter dictionary """ - parts = _parseparam(b';' + line) + parts = _parseparam(b";" + line) key = next(parts) pdict = {} for p in parts: - i = p.find(b'=') + i = p.find(b"=") if i >= 0: name = p[:i].strip().lower() - value = p[i + 1:].strip() + value = p[i + 1 :].strip() # strip double-quotes if len(value) >= 2 and value[0:1] == value[-1:] == b'"': value = value[1:-1] - value = value.replace(b'\\\\', b'\\').replace(b'\\"', b'"') + value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') pdict[name] = value return key, pdict @@ -357,16 +374,16 @@ def _parseparam(s): Returns: Iterable[bytes]: the split input """ - while s[:1] == b';': + while s[:1] == b";": s = s[1:] # look for the next ; - end = s.find(b';') + end = s.find(b";") # if there is an odd number of " marks between here and the next ;, skip to the # next ; instead while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2: - end = s.find(b';', end + 1) + end = s.find(b";", end + 1) if end < 0: end = len(s) diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 77316033f7e6..fa3d6680fc8d 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -29,9 +29,7 @@ def __init__(self, hs): config = hs.get_config() self.clock = hs.get_clock() self.auth = hs.get_auth() - self.limits_dict = { - "m.upload.size": config.max_upload_size, - } + self.limits_dict = {"m.upload.size": config.max_upload_size} def render_GET(self, request): self._async_render_GET(request) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index bdc5daecc1d5..a21a35f84336 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -54,18 +54,20 @@ def _async_render_GET(self, request): b" plugin-types application/pdf;" b" style-src 'unsafe-inline';" b" media-src 'self';" - b" object-src 'self';" + b" object-src 'self';", ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: yield self.media_repo.get_local_media(request, media_id, name) else: allow_remote = synapse.http.servlet.parse_boolean( - request, "allow_remote", default=True) + request, "allow_remote", default=True + ) if not allow_remote: logger.info( "Rejecting request for remote media %s/%s due to allow_remote", - server_name, media_id, + server_name, + media_id, ) respond_404(request) return diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index c8586fa28073..e25c382c9c66 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -24,6 +24,7 @@ def _wrap_in_base_path(func): """Takes a function that returns a relative path and turns it into an absolute path based on the location of the primary media store """ + @functools.wraps(func) def _wrapped(self, *args, **kwargs): path = func(self, *args, **kwargs) @@ -43,125 +44,102 @@ class MediaFilePaths(object): def __init__(self, primary_base_path): self.base_path = primary_base_path - def default_thumbnail_rel(self, default_top_level, default_sub_type, width, - height, content_type, method): + def default_thumbnail_rel( + self, default_top_level, default_sub_type, width, height, content_type, method + ): top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % ( - width, height, top_level_type, sub_type, method - ) + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( - "default_thumbnails", default_top_level, - default_sub_type, file_name + "default_thumbnails", default_top_level, default_sub_type, file_name ) default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) def local_media_filepath_rel(self, media_id): - return os.path.join( - "local_content", - media_id[0:2], media_id[2:4], media_id[4:] - ) + return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) - def local_media_thumbnail_rel(self, media_id, width, height, content_type, - method): + def local_media_thumbnail_rel(self, media_id, width, height, content_type, method): top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % ( - width, height, top_level_type, sub_type, method - ) + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( - "local_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], - file_name + "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name ) local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) def remote_media_filepath_rel(self, server_name, file_id): return os.path.join( - "remote_content", server_name, - file_id[0:2], file_id[2:4], file_id[4:] + "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] ) remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) - def remote_media_thumbnail_rel(self, server_name, file_id, width, height, - content_type, method): + def remote_media_thumbnail_rel( + self, server_name, file_id, width, height, content_type, method + ): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) return os.path.join( - "remote_thumbnail", server_name, - file_id[0:2], file_id[2:4], file_id[4:], - file_name + "remote_thumbnail", + server_name, + file_id[0:2], + file_id[2:4], + file_id[4:], + file_name, ) remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) def remote_media_thumbnail_dir(self, server_name, file_id): return os.path.join( - self.base_path, "remote_thumbnail", server_name, - file_id[0:2], file_id[2:4], file_id[4:], + self.base_path, + "remote_thumbnail", + server_name, + file_id[0:2], + file_id[2:4], + file_id[4:], ) def url_cache_filepath_rel(self, media_id): if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf - return os.path.join( - "url_cache", - media_id[:10], media_id[11:] - ) + return os.path.join("url_cache", media_id[:10], media_id[11:]) else: - return os.path.join( - "url_cache", - media_id[0:2], media_id[2:4], media_id[4:], - ) + return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:]) url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) def url_cache_filepath_dirs_to_delete(self, media_id): "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): - return [ - os.path.join( - self.base_path, "url_cache", - media_id[:10], - ), - ] + return [os.path.join(self.base_path, "url_cache", media_id[:10])] else: return [ - os.path.join( - self.base_path, "url_cache", - media_id[0:2], media_id[2:4], - ), - os.path.join( - self.base_path, "url_cache", - media_id[0:2], - ), + os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]), + os.path.join(self.base_path, "url_cache", media_id[0:2]), ] - def url_cache_thumbnail_rel(self, media_id, width, height, content_type, - method): + def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % ( - width, height, top_level_type, sub_type, method - ) + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) if NEW_FORMAT_ID_RE.match(media_id): return os.path.join( - "url_cache_thumbnails", - media_id[:10], media_id[11:], - file_name + "url_cache_thumbnails", media_id[:10], media_id[11:], file_name ) else: return os.path.join( "url_cache_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], - file_name + media_id[0:2], + media_id[2:4], + media_id[4:], + file_name, ) url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) @@ -172,13 +150,15 @@ def url_cache_thumbnail_directory(self, media_id): if NEW_FORMAT_ID_RE.match(media_id): return os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[:10], media_id[11:], + self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:] ) else: return os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], + self.base_path, + "url_cache_thumbnails", + media_id[0:2], + media_id[2:4], + media_id[4:], ) def url_cache_thumbnail_dirs_to_delete(self, media_id): @@ -188,26 +168,21 @@ def url_cache_thumbnail_dirs_to_delete(self, media_id): if NEW_FORMAT_ID_RE.match(media_id): return [ os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[:10], media_id[11:], - ), - os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[:10], + self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:] ), + os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]), ] else: return [ os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], - ), - os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], media_id[2:4], + self.base_path, + "url_cache_thumbnails", + media_id[0:2], + media_id[2:4], + media_id[4:], ), os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], + self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4] ), + os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]), ] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index a4929dd5dbd2..df3d985a3855 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -100,17 +100,16 @@ def __init__(self, hs): storage_providers.append(provider) self.media_storage = MediaStorage( - self.hs, self.primary_base_path, self.filepaths, storage_providers, + self.hs, self.primary_base_path, self.filepaths, storage_providers ) self.clock.looping_call( - self._start_update_recently_accessed, - UPDATE_RECENTLY_ACCESSED_TS, + self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS ) def _start_update_recently_accessed(self): return run_as_background_process( - "update_recently_accessed_media", self._update_recently_accessed, + "update_recently_accessed_media", self._update_recently_accessed ) @defer.inlineCallbacks @@ -138,8 +137,9 @@ def mark_recently_accessed(self, server_name, media_id): self.recently_accessed_locals.add(media_id) @defer.inlineCallbacks - def create_content(self, media_type, upload_name, content, content_length, - auth_user): + def create_content( + self, media_type, upload_name, content, content_length, auth_user + ): """Store uploaded content for a local user and return the mxc URL Args: @@ -154,10 +154,7 @@ def create_content(self, media_type, upload_name, content, content_length, """ media_id = random_string(24) - file_info = FileInfo( - server_name=None, - file_id=media_id, - ) + file_info = FileInfo(server_name=None, file_id=media_id) fname = yield self.media_storage.store_file(content, file_info) @@ -172,9 +169,7 @@ def create_content(self, media_type, upload_name, content, content_length, user_id=auth_user, ) - yield self._generate_thumbnails( - None, media_id, media_id, media_type, - ) + yield self._generate_thumbnails(None, media_id, media_id, media_type) defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) @@ -205,14 +200,11 @@ def get_local_media(self, request, media_id, name): upload_name = name if name else media_info["upload_name"] url_cache = media_info["url_cache"] - file_info = FileInfo( - None, media_id, - url_cache=url_cache, - ) + file_info = FileInfo(None, media_id, url_cache=url_cache) responder = yield self.media_storage.fetch_media(file_info) yield respond_with_responder( - request, responder, media_type, media_length, upload_name, + request, responder, media_type, media_length, upload_name ) @defer.inlineCallbacks @@ -232,8 +224,8 @@ def get_remote_media(self, request, server_name, media_id, name): to request """ if ( - self.federation_domain_whitelist is not None and - server_name not in self.federation_domain_whitelist + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist ): raise FederationDeniedError(server_name) @@ -244,7 +236,7 @@ def get_remote_media(self, request, server_name, media_id, name): key = (server_name, media_id) with (yield self.remote_media_linearizer.queue(key)): responder, media_info = yield self._get_remote_media_impl( - server_name, media_id, + server_name, media_id ) # We deliberately stream the file outside the lock @@ -253,7 +245,7 @@ def get_remote_media(self, request, server_name, media_id, name): media_length = media_info["media_length"] upload_name = name if name else media_info["upload_name"] yield respond_with_responder( - request, responder, media_type, media_length, upload_name, + request, responder, media_type, media_length, upload_name ) else: respond_404(request) @@ -272,8 +264,8 @@ def get_remote_media_info(self, server_name, media_id): Deferred[dict]: The media_info of the file """ if ( - self.federation_domain_whitelist is not None and - server_name not in self.federation_domain_whitelist + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist ): raise FederationDeniedError(server_name) @@ -282,7 +274,7 @@ def get_remote_media_info(self, server_name, media_id): key = (server_name, media_id) with (yield self.remote_media_linearizer.queue(key)): responder, media_info = yield self._get_remote_media_impl( - server_name, media_id, + server_name, media_id ) # Ensure we actually use the responder so that it releases resources @@ -305,9 +297,7 @@ def _get_remote_media_impl(self, server_name, media_id): Returns: Deferred[(Responder, media_info)] """ - media_info = yield self.store.get_cached_remote_media( - server_name, media_id - ) + media_info = yield self.store.get_cached_remote_media(server_name, media_id) # file_id is the ID we use to track the file locally. If we've already # seen the file then reuse the existing ID, otherwise genereate a new @@ -331,9 +321,7 @@ def _get_remote_media_impl(self, server_name, media_id): # Failed to find the file anywhere, lets download it. - media_info = yield self._download_remote_file( - server_name, media_id, file_id - ) + media_info = yield self._download_remote_file(server_name, media_id, file_id) responder = yield self.media_storage.fetch_media(file_info) defer.returnValue((responder, media_info)) @@ -354,54 +342,60 @@ def _download_remote_file(self, server_name, media_id, file_id): Deferred[MediaInfo] """ - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - ) + file_info = FileInfo(server_name=server_name, file_id=file_id) with self.media_storage.store_into_file(file_info) as (f, fname, finish): - request_path = "/".join(( - "/_matrix/media/v1/download", server_name, media_id, - )) + request_path = "/".join( + ("/_matrix/media/v1/download", server_name, media_id) + ) try: length, headers = yield self.client.get_file( - server_name, request_path, output_stream=f, - max_size=self.max_upload_size, args={ + server_name, + request_path, + output_stream=f, + max_size=self.max_upload_size, + args={ # tell the remote server to 404 if it doesn't # recognise the server_name, to make sure we don't # end up with a routing loop. - "allow_remote": "false", - } + "allow_remote": "false" + }, ) except RequestSendFailed as e: - logger.warn("Request failed fetching remote media %s/%s: %r", - server_name, media_id, e) + logger.warn( + "Request failed fetching remote media %s/%s: %r", + server_name, + media_id, + e, + ) raise SynapseError(502, "Failed to fetch remote media") except HttpResponseException as e: - logger.warn("HTTP error fetching remote media %s/%s: %s", - server_name, media_id, e.response) + logger.warn( + "HTTP error fetching remote media %s/%s: %s", + server_name, + media_id, + e.response, + ) if e.code == twisted.web.http.NOT_FOUND: raise e.to_synapse_error() raise SynapseError(502, "Failed to fetch remote media") except SynapseError: - logger.warn( - "Failed to fetch remote media %s/%s", - server_name, media_id, - ) + logger.warn("Failed to fetch remote media %s/%s", server_name, media_id) raise except NotRetryingDestination: logger.warn("Not retrying destination %r", server_name) raise SynapseError(502, "Failed to fetch remote media") except Exception: - logger.exception("Failed to fetch remote media %s/%s", - server_name, media_id) + logger.exception( + "Failed to fetch remote media %s/%s", server_name, media_id + ) raise SynapseError(502, "Failed to fetch remote media") yield finish() - media_type = headers[b"Content-Type"][0].decode('ascii') + media_type = headers[b"Content-Type"][0].decode("ascii") upload_name = get_filename_from_headers(headers) time_now_ms = self.clock.time_msec() @@ -425,24 +419,23 @@ def _download_remote_file(self, server_name, media_id, file_id): "filesystem_id": file_id, } - yield self._generate_thumbnails( - server_name, media_id, file_id, media_type, - ) + yield self._generate_thumbnails(server_name, media_id, file_id, media_type) defer.returnValue(media_info) def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) - def _generate_thumbnail(self, thumbnailer, t_width, t_height, - t_method, t_type): + def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): m_width = thumbnailer.width m_height = thumbnailer.height if m_width * m_height >= self.max_image_pixels: logger.info( "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels + m_width, + m_height, + self.max_image_pixels, ) return @@ -462,17 +455,22 @@ def _generate_thumbnail(self, thumbnailer, t_width, t_height, return t_byte_source @defer.inlineCallbacks - def generate_local_exact_thumbnail(self, media_id, t_width, t_height, - t_method, t_type, url_cache): - input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( - None, media_id, url_cache=url_cache, - )) + def generate_local_exact_thumbnail( + self, media_id, t_width, t_height, t_method, t_type, url_cache + ): + input_path = yield self.media_storage.ensure_media_is_in_local_cache( + FileInfo(None, media_id, url_cache=url_cache) + ) thumbnailer = Thumbnailer(input_path) t_byte_source = yield logcontext.defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, - thumbnailer, t_width, t_height, t_method, t_type + thumbnailer, + t_width, + t_height, + t_method, + t_type, ) if t_byte_source: @@ -489,7 +487,7 @@ def generate_local_exact_thumbnail(self, media_id, t_width, t_height, ) output_path = yield self.media_storage.store_file( - t_byte_source, file_info, + t_byte_source, file_info ) finally: t_byte_source.close() @@ -505,17 +503,22 @@ def generate_local_exact_thumbnail(self, media_id, t_width, t_height, defer.returnValue(output_path) @defer.inlineCallbacks - def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, - t_width, t_height, t_method, t_type): - input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( - server_name, file_id, url_cache=False, - )) + def generate_remote_exact_thumbnail( + self, server_name, file_id, media_id, t_width, t_height, t_method, t_type + ): + input_path = yield self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id, url_cache=False) + ) thumbnailer = Thumbnailer(input_path) t_byte_source = yield logcontext.defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, - thumbnailer, t_width, t_height, t_method, t_type + thumbnailer, + t_width, + t_height, + t_method, + t_type, ) if t_byte_source: @@ -531,7 +534,7 @@ def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, ) output_path = yield self.media_storage.store_file( - t_byte_source, file_info, + t_byte_source, file_info ) finally: t_byte_source.close() @@ -541,15 +544,22 @@ def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, t_len = os.path.getsize(output_path) yield self.store.store_remote_media_thumbnail( - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, ) defer.returnValue(output_path) @defer.inlineCallbacks - def _generate_thumbnails(self, server_name, media_id, file_id, media_type, - url_cache=False): + def _generate_thumbnails( + self, server_name, media_id, file_id, media_type, url_cache=False + ): """Generate and store thumbnails for an image. Args: @@ -568,9 +578,9 @@ def _generate_thumbnails(self, server_name, media_id, file_id, media_type, if not requirements: return - input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo( - server_name, file_id, url_cache=url_cache, - )) + input_path = yield self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id, url_cache=url_cache) + ) thumbnailer = Thumbnailer(input_path) m_width = thumbnailer.width @@ -579,14 +589,15 @@ def _generate_thumbnails(self, server_name, media_id, file_id, media_type, if m_width * m_height >= self.max_image_pixels: logger.info( "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels + m_width, + m_height, + self.max_image_pixels, ) return if thumbnailer.transpose_method is not None: m_width, m_height = yield logcontext.defer_to_thread( - self.hs.get_reactor(), - thumbnailer.transpose + self.hs.get_reactor(), thumbnailer.transpose ) # We deduplicate the thumbnail sizes by ignoring the cropped versions if @@ -606,15 +617,11 @@ def _generate_thumbnails(self, server_name, media_id, file_id, media_type, # Generate the thumbnail if t_method == "crop": t_byte_source = yield logcontext.defer_to_thread( - self.hs.get_reactor(), - thumbnailer.crop, - t_width, t_height, t_type, + self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type ) elif t_method == "scale": t_byte_source = yield logcontext.defer_to_thread( - self.hs.get_reactor(), - thumbnailer.scale, - t_width, t_height, t_type, + self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type ) else: logger.error("Unrecognized method: %r", t_method) @@ -636,7 +643,7 @@ def _generate_thumbnails(self, server_name, media_id, file_id, media_type, ) output_path = yield self.media_storage.store_file( - t_byte_source, file_info, + t_byte_source, file_info ) finally: t_byte_source.close() @@ -646,18 +653,21 @@ def _generate_thumbnails(self, server_name, media_id, file_id, media_type, # Write to database if server_name: yield self.store.store_remote_media_thumbnail( - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, ) else: yield self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) - defer.returnValue({ - "width": m_width, - "height": m_height, - }) + defer.returnValue({"width": m_width, "height": m_height}) @defer.inlineCallbacks def delete_old_remote_media(self, before_ts): @@ -749,11 +759,12 @@ def __init__(self, hs): self.putChild(b"upload", UploadResource(hs, media_repo)) self.putChild(b"download", DownloadResource(hs, media_repo)) - self.putChild(b"thumbnail", ThumbnailResource( - hs, media_repo, media_repo.media_storage, - )) + self.putChild( + b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) + ) if hs.config.url_preview_enabled: - self.putChild(b"preview_url", PreviewUrlResource( - hs, media_repo, media_repo.media_storage, - )) + self.putChild( + b"preview_url", + PreviewUrlResource(hs, media_repo, media_repo.media_storage), + ) self.putChild(b"config", MediaConfigResource(hs)) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 896078fe7666..eff86836fb8a 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -66,8 +66,7 @@ def store_file(self, source, file_info): with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository yield logcontext.defer_to_thread( - self.hs.get_reactor(), - _write_file_synchronously, source, f, + self.hs.get_reactor(), _write_file_synchronously, source, f ) yield finish_cb() @@ -179,7 +178,8 @@ def ensure_media_is_in_local_cache(self, file_info): if res: with res: consumer = BackgroundFileConsumer( - open(local_path, "wb"), self.hs.get_reactor()) + open(local_path, "wb"), self.hs.get_reactor() + ) yield res.write_to_consumer(consumer) yield consumer.wait() defer.returnValue(local_path) @@ -217,10 +217,10 @@ def _file_info_to_path(self, file_info): width=file_info.thumbnail_width, height=file_info.thumbnail_height, content_type=file_info.thumbnail_type, - method=file_info.thumbnail_method + method=file_info.thumbnail_method, ) return self.filepaths.remote_media_filepath_rel( - file_info.server_name, file_info.file_id, + file_info.server_name, file_info.file_id ) if file_info.thumbnail: @@ -229,11 +229,9 @@ def _file_info_to_path(self, file_info): width=file_info.thumbnail_width, height=file_info.thumbnail_height, content_type=file_info.thumbnail_type, - method=file_info.thumbnail_method + method=file_info.thumbnail_method, ) - return self.filepaths.local_media_filepath_rel( - file_info.file_id, - ) + return self.filepaths.local_media_filepath_rel(file_info.file_id) def _write_file_synchronously(source, dest): @@ -255,6 +253,7 @@ class FileResponder(Responder): open_file (file): A file like object to be streamed ot the client, is closed when finished streaming. """ + def __init__(self, open_file): self.open_file = open_file diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index acf87709f2ec..de6f292ffb1c 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -92,7 +92,7 @@ def __init__(self, hs, media_repo, media_storage): ) self._cleaner_loop = self.clock.looping_call( - self._start_expire_url_cache_data, 10 * 1000, + self._start_expire_url_cache_data, 10 * 1000 ) def render_OPTIONS(self, request): @@ -121,16 +121,16 @@ def _async_render_GET(self, request): for attrib in entry: pattern = entry[attrib] value = getattr(url_tuple, attrib) - logger.debug(( - "Matching attrib '%s' with value '%s' against" - " pattern '%s'" - ) % (attrib, value, pattern)) + logger.debug( + ("Matching attrib '%s' with value '%s' against" " pattern '%s'") + % (attrib, value, pattern) + ) if value is None: match = False continue - if pattern.startswith('^'): + if pattern.startswith("^"): if not re.match(pattern, getattr(url_tuple, attrib)): match = False continue @@ -139,12 +139,9 @@ def _async_render_GET(self, request): match = False continue if match: - logger.warn( - "URL %s blocked by url_blacklist entry %s", url, entry - ) + logger.warn("URL %s blocked by url_blacklist entry %s", url, entry) raise SynapseError( - 403, "URL blocked by url pattern blacklist entry", - Codes.UNKNOWN + 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN ) # the in-memory cache: @@ -156,14 +153,8 @@ def _async_render_GET(self, request): observable = self._cache.get(url) if not observable: - download = run_in_background( - self._do_preview, - url, requester.user, ts, - ) - observable = ObservableDeferred( - download, - consumeErrors=True - ) + download = run_in_background(self._do_preview, url, requester.user, ts) + observable = ObservableDeferred(download, consumeErrors=True) self._cache[url] = observable else: logger.info("Returning cached response") @@ -187,15 +178,15 @@ def _do_preview(self, url, user, ts): # historical previews, if we have any) cache_result = yield self.store.get_url_cache(url, ts) if ( - cache_result and - cache_result["expires_ts"] > ts and - cache_result["response_code"] / 100 == 2 + cache_result + and cache_result["expires_ts"] > ts + and cache_result["response_code"] / 100 == 2 ): # It may be stored as text in the database, not as bytes (such as # PostgreSQL). If so, encode it back before handing it on. og = cache_result["og"] if isinstance(og, six.text_type): - og = og.encode('utf8') + og = og.encode("utf8") defer.returnValue(og) return @@ -203,33 +194,31 @@ def _do_preview(self, url, user, ts): logger.debug("got media_info of '%s'" % media_info) - if _is_media(media_info['media_type']): - file_id = media_info['filesystem_id'] + if _is_media(media_info["media_type"]): + file_id = media_info["filesystem_id"] dims = yield self.media_repo._generate_thumbnails( - None, file_id, file_id, media_info["media_type"], - url_cache=True, + None, file_id, file_id, media_info["media_type"], url_cache=True ) og = { - "og:description": media_info['download_name'], - "og:image": "mxc://%s/%s" % ( - self.server_name, media_info['filesystem_id'] - ), - "og:image:type": media_info['media_type'], - "matrix:image:size": media_info['media_length'], + "og:description": media_info["download_name"], + "og:image": "mxc://%s/%s" + % (self.server_name, media_info["filesystem_id"]), + "og:image:type": media_info["media_type"], + "matrix:image:size": media_info["media_length"], } if dims: - og["og:image:width"] = dims['width'] - og["og:image:height"] = dims['height'] + og["og:image:width"] = dims["width"] + og["og:image:height"] = dims["height"] else: logger.warn("Couldn't get dims for %s" % url) # define our OG response for this media - elif _is_html(media_info['media_type']): + elif _is_html(media_info["media_type"]): # TODO: somehow stop a big HTML tree from exploding synapse's RAM - with open(media_info['filename'], 'rb') as file: + with open(media_info["filename"], "rb") as file: body = file.read() encoding = None @@ -242,45 +231,43 @@ def _do_preview(self, url, user, ts): # If we find a match, it should take precedence over the # Content-Type header, so set it here. if match: - encoding = match.group(1).decode('ascii') + encoding = match.group(1).decode("ascii") # If we don't find a match, we'll look at the HTTP Content-Type, and # if that doesn't exist, we'll fall back to UTF-8. if not encoding: - match = _content_type_match.match( - media_info['media_type'] - ) + match = _content_type_match.match(media_info["media_type"]) encoding = match.group(1) if match else "utf-8" - og = decode_and_calc_og(body, media_info['uri'], encoding) + og = decode_and_calc_og(body, media_info["uri"], encoding) # pre-cache the image for posterity # FIXME: it might be cleaner to use the same flow as the main /preview_url # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. - if 'og:image' in og and og['og:image']: + if "og:image" in og and og["og:image"]: image_info = yield self._download_url( - _rebase_url(og['og:image'], media_info['uri']), user + _rebase_url(og["og:image"], media_info["uri"]), user ) - if _is_media(image_info['media_type']): + if _is_media(image_info["media_type"]): # TODO: make sure we don't choke on white-on-transparent images - file_id = image_info['filesystem_id'] + file_id = image_info["filesystem_id"] dims = yield self.media_repo._generate_thumbnails( - None, file_id, file_id, image_info["media_type"], - url_cache=True, + None, file_id, file_id, image_info["media_type"], url_cache=True ) if dims: - og["og:image:width"] = dims['width'] - og["og:image:height"] = dims['height'] + og["og:image:width"] = dims["width"] + og["og:image:height"] = dims["height"] else: logger.warn("Couldn't get dims for %s" % og["og:image"]) og["og:image"] = "mxc://%s/%s" % ( - self.server_name, image_info['filesystem_id'] + self.server_name, + image_info["filesystem_id"], ) - og["og:image:type"] = image_info['media_type'] - og["matrix:image:size"] = image_info['media_length'] + og["og:image:type"] = image_info["media_type"] + og["matrix:image:size"] = image_info["media_length"] else: del og["og:image"] else: @@ -289,7 +276,7 @@ def _do_preview(self, url, user, ts): logger.debug("Calculated OG for %s as %s" % (url, og)) - jsonog = json.dumps(og).encode('utf8') + jsonog = json.dumps(og).encode("utf8") # store OG in history-aware DB cache yield self.store.store_url_cache( @@ -310,19 +297,15 @@ def _download_url(self, url, user): # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? - file_id = datetime.date.today().isoformat() + '_' + random_string(16) + file_id = datetime.date.today().isoformat() + "_" + random_string(16) - file_info = FileInfo( - server_name=None, - file_id=file_id, - url_cache=True, - ) + file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) with self.media_storage.store_into_file(file_info) as (f, fname, finish): try: logger.debug("Trying to get url '%s'" % url) length, headers, uri, code = yield self.client.get_file( - url, output_stream=f, max_size=self.max_spider_size, + url, output_stream=f, max_size=self.max_spider_size ) except SynapseError: # Pass SynapseErrors through directly, so that the servlet @@ -334,24 +317,25 @@ def _download_url(self, url, user): # Note: This will also be the case if one of the resolved IP # addresses is blacklisted raise SynapseError( - 502, "DNS resolution failure during URL preview generation", - Codes.UNKNOWN + 502, + "DNS resolution failure during URL preview generation", + Codes.UNKNOWN, ) except Exception as e: # FIXME: pass through 404s and other error messages nicely logger.warn("Error downloading %s: %r", url, e) raise SynapseError( - 500, "Failed to download content: %s" % ( - traceback.format_exception_only(sys.exc_info()[0], e), - ), + 500, + "Failed to download content: %s" + % (traceback.format_exception_only(sys.exc_info()[0], e),), Codes.UNKNOWN, ) yield finish() try: if b"Content-Type" in headers: - media_type = headers[b"Content-Type"][0].decode('ascii') + media_type = headers[b"Content-Type"][0].decode("ascii") else: media_type = "application/octet-stream" time_now_ms = self.clock.time_msec() @@ -375,24 +359,26 @@ def _download_url(self, url, user): # therefore not expire it. raise - defer.returnValue({ - "media_type": media_type, - "media_length": length, - "download_name": download_name, - "created_ts": time_now_ms, - "filesystem_id": file_id, - "filename": fname, - "uri": uri, - "response_code": code, - # FIXME: we should calculate a proper expiration based on the - # Cache-Control and Expire headers. But for now, assume 1 hour. - "expires": 60 * 60 * 1000, - "etag": headers["ETag"][0] if "ETag" in headers else None, - }) + defer.returnValue( + { + "media_type": media_type, + "media_length": length, + "download_name": download_name, + "created_ts": time_now_ms, + "filesystem_id": file_id, + "filename": fname, + "uri": uri, + "response_code": code, + # FIXME: we should calculate a proper expiration based on the + # Cache-Control and Expire headers. But for now, assume 1 hour. + "expires": 60 * 60 * 1000, + "etag": headers["ETag"][0] if "ETag" in headers else None, + } + ) def _start_expire_url_cache_data(self): return run_as_background_process( - "expire_url_cache_data", self._expire_url_cache_data, + "expire_url_cache_data", self._expire_url_cache_data ) @defer.inlineCallbacks @@ -496,7 +482,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None): # blindly try decoding the body as utf-8, which seems to fix # the charset mismatches on https://google.com parser = etree.HTMLParser(recover=True, encoding=request_encoding) - tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser) + tree = etree.fromstring(body.decode("utf-8", "ignore"), parser) og = _calc_og(tree, media_uri) return og @@ -523,8 +509,8 @@ def _calc_og(tree, media_uri): og = {} for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): - if 'content' in tag.attrib: - og[tag.attrib['property']] = tag.attrib['content'] + if "content" in tag.attrib: + og[tag.attrib["property"]] = tag.attrib["content"] # TODO: grab article: meta tags too, e.g.: @@ -535,39 +521,43 @@ def _calc_og(tree, media_uri): # "article:published_time" content="2016-03-31T19:58:24+00:00" /> # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> - if 'og:title' not in og: + if "og:title" not in og: # do some basic spidering of the HTML title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") if title and title[0].text is not None: - og['og:title'] = title[0].text.strip() + og["og:title"] = title[0].text.strip() else: - og['og:title'] = None + og["og:title"] = None - if 'og:image' not in og: + if "og:image" not in og: # TODO: extract a favicon failing all else meta_image = tree.xpath( "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" ) if meta_image: - og['og:image'] = _rebase_url(meta_image[0], media_uri) + og["og:image"] = _rebase_url(meta_image[0], media_uri) else: # TODO: consider inlined CSS styles as well as width & height attribs images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") - images = sorted(images, key=lambda i: ( - -1 * float(i.attrib['width']) * float(i.attrib['height']) - )) + images = sorted( + images, + key=lambda i: ( + -1 * float(i.attrib["width"]) * float(i.attrib["height"]) + ), + ) if not images: images = tree.xpath("//img[@src]") if images: - og['og:image'] = images[0].attrib['src'] + og["og:image"] = images[0].attrib["src"] - if 'og:description' not in og: + if "og:description" not in og: meta_description = tree.xpath( "//*/meta" "[translate(@name, 'DESCRIPTION', 'description')='description']" - "/@content") + "/@content" + ) if meta_description: - og['og:description'] = meta_description[0] + og["og:description"] = meta_description[0] else: # grab any text nodes which are inside the tag... # unless they are within an HTML5 semantic markup tag... @@ -588,18 +578,18 @@ def _calc_og(tree, media_uri): "script", "noscript", "style", - etree.Comment + etree.Comment, ) # Split all the text nodes into paragraphs (by splitting on new # lines) text_nodes = ( - re.sub(r'\s+', '\n', el).strip() + re.sub(r"\s+", "\n", el).strip() for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) ) - og['og:description'] = summarize_paragraphs(text_nodes) + og["og:description"] = summarize_paragraphs(text_nodes) else: - og['og:description'] = summarize_paragraphs([og['og:description']]) + og["og:description"] = summarize_paragraphs([og["og:description"]]) # TODO: delete the url downloads to stop diskfilling, # as we only ever cared about its OG @@ -636,7 +626,7 @@ def _iterate_over_text(tree, *tags_to_ignore): [child, child.tail] if child.tail else [child] for child in el.iterchildren() ), - elements + elements, ) @@ -647,8 +637,8 @@ def _rebase_url(url, base): url[0] = base[0] or "http" if not url[1]: # fix up hostname url[1] = base[1] - if not url[2].startswith('/'): - url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2] + if not url[2].startswith("/"): + url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2] return urlparse.urlunparse(url) @@ -659,9 +649,8 @@ def _is_media(content_type): def _is_html(content_type): content_type = content_type.lower() - if ( - content_type.startswith("text/html") or - content_type.startswith("application/xhtml") + if content_type.startswith("text/html") or content_type.startswith( + "application/xhtml" ): return True @@ -671,19 +660,19 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500): # first paragraph and then word boundaries. # TODO: Respect sentences? - description = '' + description = "" # Keep adding paragraphs until we get to the MIN_SIZE. for text_node in text_nodes: if len(description) < min_size: - text_node = re.sub(r'[\t \r\n]+', ' ', text_node) - description += text_node + '\n\n' + text_node = re.sub(r"[\t \r\n]+", " ", text_node) + description += text_node + "\n\n" else: break description = description.strip() - description = re.sub(r'[\t ]+', ' ', description) - description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description) + description = re.sub(r"[\t ]+", " ", description) + description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description) # If the concatenation of paragraphs to get above MIN_SIZE # took us over MAX_SIZE, then we need to truncate mid paragraph @@ -715,5 +704,5 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500): # We always add an ellipsis because at the very least # we chopped mid paragraph. - description = new_desc.strip() + u"…" + description = new_desc.strip() + "…" return description if description else None diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index d90cbfb56a6a..359b45ebfc4f 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -32,6 +32,7 @@ class StorageProvider(object): """A storage provider is a service that can store uploaded media and retrieve them. """ + def store_file(self, path, file_info): """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. @@ -70,6 +71,7 @@ class StorageProviderWrapper(StorageProvider): uploaded, or todo the upload in the backgroud. store_remote (bool): Whether remote media should be uploaded """ + def __init__(self, backend, store_local, store_synchronous, store_remote): self.backend = backend self.store_local = store_local @@ -92,6 +94,7 @@ def store(): return self.backend.store_file(path, file_info) except Exception: logger.exception("Error storing file") + run_in_background(store) return defer.succeed(None) @@ -123,8 +126,7 @@ def store_file(self, path, file_info): os.makedirs(dirname) return logcontext.defer_to_thread( - self.hs.get_reactor(), - shutil.copyfile, primary_fname, backup_fname, + self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) def fetch(self, path, file_info): diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 35a750923b6f..ca84c9f13941 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -74,19 +74,18 @@ def _async_render_GET(self, request): else: if self.dynamic_thumbnails: yield self._select_or_generate_remote_thumbnail( - request, server_name, media_id, - width, height, method, m_type + request, server_name, media_id, width, height, method, m_type ) else: yield self._respond_remote_thumbnail( - request, server_name, media_id, - width, height, method, m_type + request, server_name, media_id, width, height, method, m_type ) self.media_repo.mark_recently_accessed(server_name, media_id) @defer.inlineCallbacks - def _respond_local_thumbnail(self, request, media_id, width, height, - method, m_type): + def _respond_local_thumbnail( + self, request, media_id, width, height, method, m_type + ): media_info = yield self.store.get_local_media(media_id) if not media_info: @@ -105,7 +104,8 @@ def _respond_local_thumbnail(self, request, media_id, width, height, ) file_info = FileInfo( - server_name=None, file_id=media_id, + server_name=None, + file_id=media_id, url_cache=media_info["url_cache"], thumbnail=True, thumbnail_width=thumbnail_info["thumbnail_width"], @@ -124,9 +124,15 @@ def _respond_local_thumbnail(self, request, media_id, width, height, respond_404(request) @defer.inlineCallbacks - def _select_or_generate_local_thumbnail(self, request, media_id, desired_width, - desired_height, desired_method, - desired_type): + def _select_or_generate_local_thumbnail( + self, + request, + media_id, + desired_width, + desired_height, + desired_method, + desired_type, + ): media_info = yield self.store.get_local_media(media_id) if not media_info: @@ -146,7 +152,8 @@ def _select_or_generate_local_thumbnail(self, request, media_id, desired_width, if t_w and t_h and t_method and t_type: file_info = FileInfo( - server_name=None, file_id=media_id, + server_name=None, + file_id=media_id, url_cache=media_info["url_cache"], thumbnail=True, thumbnail_width=info["thumbnail_width"], @@ -167,7 +174,11 @@ def _select_or_generate_local_thumbnail(self, request, media_id, desired_width, # Okay, so we generate one. file_path = yield self.media_repo.generate_local_exact_thumbnail( - media_id, desired_width, desired_height, desired_method, desired_type, + media_id, + desired_width, + desired_height, + desired_method, + desired_type, url_cache=media_info["url_cache"], ) @@ -178,13 +189,20 @@ def _select_or_generate_local_thumbnail(self, request, media_id, desired_width, respond_404(request) @defer.inlineCallbacks - def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, - desired_width, desired_height, - desired_method, desired_type): + def _select_or_generate_remote_thumbnail( + self, + request, + server_name, + media_id, + desired_width, + desired_height, + desired_method, + desired_type, + ): media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = yield self.store.get_remote_media_thumbnails( - server_name, media_id, + server_name, media_id ) file_id = media_info["filesystem_id"] @@ -197,7 +215,8 @@ def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, if t_w and t_h and t_method and t_type: file_info = FileInfo( - server_name=server_name, file_id=media_info["filesystem_id"], + server_name=server_name, + file_id=media_info["filesystem_id"], thumbnail=True, thumbnail_width=info["thumbnail_width"], thumbnail_height=info["thumbnail_height"], @@ -217,8 +236,13 @@ def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, # Okay, so we generate one. file_path = yield self.media_repo.generate_remote_exact_thumbnail( - server_name, file_id, media_id, desired_width, - desired_height, desired_method, desired_type + server_name, + file_id, + media_id, + desired_width, + desired_height, + desired_method, + desired_type, ) if file_path: @@ -228,15 +252,16 @@ def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, respond_404(request) @defer.inlineCallbacks - def _respond_remote_thumbnail(self, request, server_name, media_id, width, - height, method, m_type): + def _respond_remote_thumbnail( + self, request, server_name, media_id, width, height, method, m_type + ): # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) thumbnail_infos = yield self.store.get_remote_media_thumbnails( - server_name, media_id, + server_name, media_id ) if thumbnail_infos: @@ -244,7 +269,8 @@ def _respond_remote_thumbnail(self, request, server_name, media_id, width, width, height, method, m_type, thumbnail_infos ) file_info = FileInfo( - server_name=server_name, file_id=media_info["filesystem_id"], + server_name=server_name, + file_id=media_info["filesystem_id"], thumbnail=True, thumbnail_width=thumbnail_info["thumbnail_width"], thumbnail_height=thumbnail_info["thumbnail_height"], @@ -261,8 +287,14 @@ def _respond_remote_thumbnail(self, request, server_name, media_id, width, logger.info("Failed to find any generated thumbnails") respond_404(request) - def _select_thumbnail(self, desired_width, desired_height, desired_method, - desired_type, thumbnail_infos): + def _select_thumbnail( + self, + desired_width, + desired_height, + desired_method, + desired_type, + thumbnail_infos, + ): d_w = desired_width d_h = desired_height @@ -280,15 +312,27 @@ def _select_thumbnail(self, desired_width, desired_height, desired_method, type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] if t_w >= d_w or t_h >= d_h: - info_list.append(( - aspect_quality, min_quality, size_quality, type_quality, - length_quality, info - )) + info_list.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, + ) + ) else: - info_list2.append(( - aspect_quality, min_quality, size_quality, type_quality, - length_quality, info - )) + info_list2.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, + ) + ) if info_list: return min(info_list)[-1] else: @@ -304,13 +348,11 @@ def _select_thumbnail(self, desired_width, desired_height, desired_method, type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] if t_method == "scale" and (t_w >= d_w or t_h >= d_h): - info_list.append(( - size_quality, type_quality, length_quality, info - )) + info_list.append((size_quality, type_quality, length_quality, info)) elif t_method == "scale": - info_list2.append(( - size_quality, type_quality, length_quality, info - )) + info_list2.append( + (size_quality, type_quality, length_quality, info) + ) if info_list: return min(info_list)[-1] else: diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 3efd0d80fc16..90d8e6bffe67 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -28,16 +28,13 @@ 5: Image.TRANSPOSE, 6: Image.ROTATE_270, 7: Image.TRANSVERSE, - 8: Image.ROTATE_90 + 8: Image.ROTATE_90, } class Thumbnailer(object): - FORMATS = { - "image/jpeg": "JPEG", - "image/png": "PNG", - } + FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} def __init__(self, input_path): self.image = Image.open(input_path) @@ -110,17 +107,13 @@ def crop(self, width, height, output_type): """ if width * self.height > height * self.width: scaled_height = (width * self.height) // self.width - scaled_image = self.image.resize( - (width, scaled_height), Image.ANTIALIAS - ) + scaled_image = self.image.resize((width, scaled_height), Image.ANTIALIAS) crop_top = (scaled_height - height) // 2 crop_bottom = height + crop_top cropped = scaled_image.crop((0, crop_top, width, crop_bottom)) else: scaled_width = (height * self.width) // self.height - scaled_image = self.image.resize( - (scaled_width, height), Image.ANTIALIAS - ) + scaled_image = self.image.resize((scaled_width, height), Image.ANTIALIAS) crop_left = (scaled_width - width) // 2 crop_right = width + crop_left cropped = scaled_image.crop((crop_left, 0, crop_right, height)) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index c1240e1963dd..d1d7e959f0fa 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -55,48 +55,36 @@ def _async_render_POST(self, request): requester = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point - content_length = request.getHeader(b"Content-Length").decode('ascii') + content_length = request.getHeader(b"Content-Length").decode("ascii") if content_length is None: - raise SynapseError( - msg="Request must specify a Content-Length", code=400 - ) + raise SynapseError(msg="Request must specify a Content-Length", code=400) if int(content_length) > self.max_upload_size: - raise SynapseError( - msg="Upload request body is too large", - code=413, - ) + raise SynapseError(msg="Upload request body is too large", code=413) upload_name = parse_string(request, b"filename", encoding=None) if upload_name: try: - upload_name = upload_name.decode('utf8') + upload_name = upload_name.decode("utf8") except UnicodeDecodeError: raise SynapseError( - msg="Invalid UTF-8 filename parameter: %r" % (upload_name), - code=400, + msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400 ) headers = request.requestHeaders if headers.hasHeader(b"Content-Type"): - media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii') + media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii") else: - raise SynapseError( - msg="Upload request missing 'Content-Type'", - code=400, - ) + raise SynapseError(msg="Upload request missing 'Content-Type'", code=400) # if headers.hasHeader(b"Content-Disposition"): # disposition = headers.getRawHeaders(b"Content-Disposition")[0] # TODO(markjh): parse content-dispostion content_uri = yield self.media_repo.create_content( - media_type, upload_name, request.content, - content_length, requester.user + media_type, upload_name, request.content, content_length, requester.user ) logger.info("Uploaded content with URI %r", content_uri) - respond_with_json( - request, 200, {"content_uri": content_uri}, send_cors=True - ) + respond_with_json(request, 200, {"content_uri": content_uri}, send_cors=True) diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/saml2/metadata_resource.py index e8c680aeb454..1e8526e22e71 100644 --- a/synapse/rest/saml2/metadata_resource.py +++ b/synapse/rest/saml2/metadata_resource.py @@ -30,7 +30,7 @@ def __init__(self, hs): def render_GET(self, request): metadata_xml = saml2.metadata.create_metadata_string( - configfile=None, config=self.sp_config, + configfile=None, config=self.sp_config ) request.setHeader(b"Content-Type", b"text/xml; charset=utf-8") return metadata_xml diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index 69fb77b32267..ab14b7067538 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -46,18 +46,16 @@ def render_POST(self, request): @wrap_html_request_handler def _async_render_POST(self, request): - resp_bytes = parse_string(request, 'SAMLResponse', required=True) - relay_state = parse_string(request, 'RelayState', required=True) + resp_bytes = parse_string(request, "SAMLResponse", required=True) + relay_state = parse_string(request, "RelayState", required=True) try: saml2_auth = self._saml_client.parse_authn_request_response( - resp_bytes, saml2.BINDING_HTTP_POST, + resp_bytes, saml2.BINDING_HTTP_POST ) except Exception as e: logger.warning("Exception parsing SAML2 response", exc_info=1) - raise CodeMessageException( - 400, "Unable to parse SAML2 response: %s" % (e,), - ) + raise CodeMessageException(400, "Unable to parse SAML2 response: %s" % (e,)) if saml2_auth.not_signed: raise CodeMessageException(400, "SAML2 response was not signed") @@ -69,6 +67,5 @@ def _async_render_POST(self, request): displayName = saml2_auth.ava.get("displayName", [None])[0] return self._sso_auth_handler.on_successful_auth( - username, request, relay_state, - user_display_name=displayName, + username, request, relay_state, user_display_name=displayName ) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index a7fa4f39af71..5e8fda4b6575 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -29,6 +29,7 @@ class WellKnownBuilder(object): Args: hs (synapse.server.HomeServer): """ + def __init__(self, hs): self._config = hs.config @@ -37,15 +38,11 @@ def get_well_known(self): if self._config.public_baseurl is None: return None - result = { - "m.homeserver": { - "base_url": self._config.public_baseurl, - }, - } + result = {"m.homeserver": {"base_url": self._config.public_baseurl}} if self._config.default_identity_server: result["m.identity_server"] = { - "base_url": self._config.default_identity_server, + "base_url": self._config.default_identity_server } return result @@ -66,7 +63,7 @@ def render_GET(self, request): if not r: request.setResponseCode(404) request.setHeader(b"Content-Type", b"text/plain") - return b'.well-known not available' + return b".well-known not available" logger.debug("returning: %s", r) request.setHeader(b"Content-Type", b"application/json") diff --git a/synapse/secrets.py b/synapse/secrets.py index f6280f951cdf..0b327a0f8233 100644 --- a/synapse/secrets.py +++ b/synapse/secrets.py @@ -29,6 +29,7 @@ def Secrets(): return secrets + else: import os import binascii @@ -38,4 +39,4 @@ def token_bytes(self, nbytes=32): return os.urandom(nbytes) def token_hex(self, nbytes=32): - return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii') + return binascii.hexlify(self.token_bytes(nbytes)).decode("ascii") diff --git a/synapse/server.py b/synapse/server.py index a54e023cc98c..a9592c396c74 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -91,7 +91,9 @@ from synapse.secrets import Secrets from synapse.server_notices.server_notices_manager import ServerNoticesManager from synapse.server_notices.server_notices_sender import ServerNoticesSender -from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender +from synapse.server_notices.worker_server_notices_sender import ( + WorkerServerNoticesSender, +) from synapse.state import StateHandler, StateResolutionHandler from synapse.streams.events import EventSources from synapse.util import Clock @@ -127,79 +129,76 @@ def build_DEPENDENCY(self) __metaclass__ = abc.ABCMeta DEPENDENCIES = [ - 'http_client', - 'db_pool', - 'federation_client', - 'federation_server', - 'handlers', - 'auth', - 'room_creation_handler', - 'state_handler', - 'state_resolution_handler', - 'presence_handler', - 'sync_handler', - 'typing_handler', - 'room_list_handler', - 'acme_handler', - 'auth_handler', - 'device_handler', - 'stats_handler', - 'e2e_keys_handler', - 'e2e_room_keys_handler', - 'event_handler', - 'event_stream_handler', - 'initial_sync_handler', - 'application_service_api', - 'application_service_scheduler', - 'application_service_handler', - 'device_message_handler', - 'profile_handler', - 'event_creation_handler', - 'deactivate_account_handler', - 'set_password_handler', - 'notifier', - 'event_sources', - 'keyring', - 'pusherpool', - 'event_builder_factory', - 'filtering', - 'http_client_context_factory', - 'simple_http_client', - 'media_repository', - 'media_repository_resource', - 'federation_transport_client', - 'federation_sender', - 'receipts_handler', - 'macaroon_generator', - 'tcp_replication', - 'read_marker_handler', - 'action_generator', - 'user_directory_handler', - 'groups_local_handler', - 'groups_server_handler', - 'groups_attestation_signing', - 'groups_attestation_renewer', - 'secrets', - 'spam_checker', - 'third_party_event_rules', - 'room_member_handler', - 'federation_registry', - 'server_notices_manager', - 'server_notices_sender', - 'message_handler', - 'pagination_handler', - 'room_context_handler', - 'sendmail', - 'registration_handler', - 'account_validity_handler', - 'event_client_serializer', - ] - - REQUIRED_ON_MASTER_STARTUP = [ + "http_client", + "db_pool", + "federation_client", + "federation_server", + "handlers", + "auth", + "room_creation_handler", + "state_handler", + "state_resolution_handler", + "presence_handler", + "sync_handler", + "typing_handler", + "room_list_handler", + "acme_handler", + "auth_handler", + "device_handler", + "stats_handler", + "e2e_keys_handler", + "e2e_room_keys_handler", + "event_handler", + "event_stream_handler", + "initial_sync_handler", + "application_service_api", + "application_service_scheduler", + "application_service_handler", + "device_message_handler", + "profile_handler", + "event_creation_handler", + "deactivate_account_handler", + "set_password_handler", + "notifier", + "event_sources", + "keyring", + "pusherpool", + "event_builder_factory", + "filtering", + "http_client_context_factory", + "simple_http_client", + "media_repository", + "media_repository_resource", + "federation_transport_client", + "federation_sender", + "receipts_handler", + "macaroon_generator", + "tcp_replication", + "read_marker_handler", + "action_generator", "user_directory_handler", - "stats_handler" + "groups_local_handler", + "groups_server_handler", + "groups_attestation_signing", + "groups_attestation_renewer", + "secrets", + "spam_checker", + "third_party_event_rules", + "room_member_handler", + "federation_registry", + "server_notices_manager", + "server_notices_sender", + "message_handler", + "pagination_handler", + "room_context_handler", + "sendmail", + "registration_handler", + "account_validity_handler", + "event_client_serializer", ] + REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] + # This is overridden in derived application classes # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be # instantiated during setup() for future return by get_datastore() @@ -410,9 +409,7 @@ def build_db_pool(self): name = self.db_config["name"] return adbapi.ConnectionPool( - name, - cp_reactor=self.get_reactor(), - **self.db_config.get("args", {}) + name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {}) ) def get_db_conn(self, run_new_connection=True): @@ -424,7 +421,8 @@ def get_db_conn(self, run_new_connection=True): # Any param beginning with cp_ is a parameter for adbapi, and should # not be passed to the database engine. db_params = { - k: v for k, v in self.db_config.get("args", {}).items() + k: v + for k, v in self.db_config.get("args", {}).items() if not k.startswith("cp_") } db_conn = self.database_engine.module.connect(**db_params) @@ -555,9 +553,7 @@ def _get(hs): if builder: # Prevent cyclic dependencies from deadlocking if depname in hs._building: - raise ValueError("Cyclic dependency while building %s" % ( - depname, - )) + raise ValueError("Cyclic dependency while building %s" % (depname,)) hs._building[depname] = 1 dep = builder() @@ -568,9 +564,7 @@ def _get(hs): return dep raise NotImplementedError( - "%s has no %s nor a builder for it" % ( - type(hs).__name__, depname, - ) + "%s has no %s nor a builder for it" % (type(hs).__name__, depname) ) setattr(HomeServer, "get_%s" % (depname), _get) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9583e82d5213..16f8f6b573fe 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -22,60 +22,57 @@ class HomeServer(object): @property def config(self) -> synapse.config.homeserver.HomeServerConfig: pass - def get_auth(self) -> synapse.api.auth.Auth: pass - def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: pass - def get_datastore(self) -> synapse.storage.DataStore: pass - def get_device_handler(self) -> synapse.handlers.device.DeviceHandler: pass - def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler: pass - def get_handlers(self) -> synapse.handlers.Handlers: pass - def get_state_handler(self) -> synapse.state.StateHandler: pass - def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: pass - - def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: + def get_deactivate_account_handler( + self + ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: pass - def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler: pass - def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: pass - - def get_event_creation_handler(self) -> synapse.handlers.message.EventCreationHandler: + def get_event_creation_handler( + self + ) -> synapse.handlers.message.EventCreationHandler: pass - - def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler: + def get_set_password_handler( + self + ) -> synapse.handlers.set_password.SetPasswordHandler: pass - def get_federation_sender(self) -> synapse.federation.sender.FederationSender: pass - - def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient: + def get_federation_transport_client( + self + ) -> synapse.federation.transport.client.TransportLayerClient: pass - - def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: + def get_media_repository_resource( + self + ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: pass - - def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository: + def get_media_repository( + self + ) -> synapse.rest.media.v1.media_repository.MediaRepository: pass - - def get_server_notices_manager(self) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: + def get_server_notices_manager( + self + ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: pass - - def get_server_notices_sender(self) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: + def get_server_notices_sender( + self + ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: pass diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 5e3044d164ca..415e9c17d8cf 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -30,6 +30,7 @@ class ConsentServerNotices(object): """Keeps track of whether we need to send users server_notices about privacy policy consent, and sends one if we do. """ + def __init__(self, hs): """ @@ -49,12 +50,11 @@ def __init__(self, hs): if not self._server_notices_manager.is_enabled(): raise ConfigError( "user_consent configuration requires server notices, but " - "server notices are not enabled.", + "server notices are not enabled." ) - if 'body' not in self._server_notice_content: + if "body" not in self._server_notice_content: raise ConfigError( - "user_consent server_notice_consent must contain a 'body' " - "key.", + "user_consent server_notice_consent must contain a 'body' " "key." ) self._consent_uri_builder = ConsentURIBuilder(hs.config) @@ -95,18 +95,14 @@ def maybe_send_server_notice_to_user(self, user_id): # need to send a message. try: consent_uri = self._consent_uri_builder.build_user_consent_uri( - get_localpart_from_id(user_id), + get_localpart_from_id(user_id) ) content = copy_with_str_subst( - self._server_notice_content, { - 'consent_uri': consent_uri, - }, - ) - yield self._server_notices_manager.send_notice( - user_id, content, + self._server_notice_content, {"consent_uri": consent_uri} ) + yield self._server_notices_manager.send_notice(user_id, content) yield self._store.user_set_consent_server_notice_sent( - user_id, self._current_consent_version, + user_id, self._current_consent_version ) except SynapseError as e: logger.error("Error sending server notice about user consent: %s", e) @@ -128,9 +124,7 @@ def copy_with_str_subst(x, substitutions): if isinstance(x, string_types): return x % substitutions if isinstance(x, dict): - return { - k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x) - } + return {k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)} if isinstance(x, (list, tuple)): return [copy_with_str_subst(y) for y in x] diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index af15cba0ee49..f183743f3181 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -33,6 +33,7 @@ class ResourceLimitsServerNotices(object): """ Keeps track of whether the server has reached it's resource limit and ensures that the client is kept up to date. """ + def __init__(self, hs): """ Args: @@ -104,34 +105,28 @@ def maybe_send_server_notice_to_user(self, user_id): if currently_blocked and not is_auth_blocking: # Room is notifying of a block, when it ought not to be. # Remove block notification - content = { - "pinned": ref_events - } + content = {"pinned": ref_events} yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, '', + user_id, content, EventTypes.Pinned, "" ) elif not currently_blocked and is_auth_blocking: # Room is not notifying of a block, when it ought to be. # Add block notification content = { - 'body': event_content, - 'msgtype': ServerNoticeMsgType, - 'server_notice_type': ServerNoticeLimitReached, - 'admin_contact': self._config.admin_contact, - 'limit_type': event_limit_type + "body": event_content, + "msgtype": ServerNoticeMsgType, + "server_notice_type": ServerNoticeLimitReached, + "admin_contact": self._config.admin_contact, + "limit_type": event_limit_type, } event = yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Message, + user_id, content, EventTypes.Message ) - content = { - "pinned": [ - event.event_id, - ] - } + content = {"pinned": [event.event_id]} yield self._server_notices_manager.send_notice( - user_id, content, EventTypes.Pinned, '', + user_id, content, EventTypes.Pinned, "" ) except SynapseError as e: @@ -156,9 +151,7 @@ def _check_and_set_tags(self, user_id, room_id): max_id = yield self._store.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) - self._notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) @defer.inlineCallbacks def _is_room_currently_blocked(self, room_id): @@ -188,7 +181,7 @@ def _is_room_currently_blocked(self, room_id): referenced_events = [] if pinned_state_event is not None: - referenced_events = list(pinned_state_event.content.get('pinned', [])) + referenced_events = list(pinned_state_event.content.get("pinned", [])) events = yield self._store.get_events(referenced_events) for event_id, event in iteritems(events): diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index c5cc6d728e76..71e7e7532025 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -51,8 +51,7 @@ def is_enabled(self): @defer.inlineCallbacks def send_notice( - self, user_id, event_content, - type=EventTypes.Message, state_key=None + self, user_id, event_content, type=EventTypes.Message, state_key=None ): """Send a notice to the given user @@ -82,10 +81,10 @@ def send_notice( } if state_key is not None: - event_dict['state_key'] = state_key + event_dict["state_key"] = state_key res = yield self._event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, ratelimit=False, + requester, event_dict, ratelimit=False ) defer.returnValue(res) @@ -104,11 +103,10 @@ def get_notice_room_for_user(self, user_id): if not self.is_enabled(): raise Exception("Server notices not enabled") - assert self._is_mine_id(user_id), \ - "Cannot send server notices to remote users" + assert self._is_mine_id(user_id), "Cannot send server notices to remote users" rooms = yield self._store.get_rooms_for_user_where_membership_is( - user_id, [Membership.INVITE, Membership.JOIN], + user_id, [Membership.INVITE, Membership.JOIN] ) system_mxid = self._config.server_notices_mxid for room in rooms: @@ -132,8 +130,8 @@ def get_notice_room_for_user(self, user_id): # avatar, we have to use both. join_profile = None if ( - self._config.server_notices_mxid_display_name is not None or - self._config.server_notices_mxid_avatar_url is not None + self._config.server_notices_mxid_display_name is not None + or self._config.server_notices_mxid_avatar_url is not None ): join_profile = { "displayname": self._config.server_notices_mxid_display_name, @@ -146,22 +144,18 @@ def get_notice_room_for_user(self, user_id): config={ "preset": RoomCreationPreset.PRIVATE_CHAT, "name": self._config.server_notices_room_name, - "power_level_content_override": { - "users_default": -10, - }, - "invite": (user_id,) + "power_level_content_override": {"users_default": -10}, + "invite": (user_id,), }, ratelimit=False, creator_join_profile=join_profile, ) - room_id = info['room_id'] + room_id = info["room_id"] max_id = yield self._store.add_tag_to_room( - user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}, - ) - self._notifier.on_new_event( - "account_data_key", max_id, users=[user_id] + user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) + self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) logger.info("Created server notices room %s for %s", room_id, user_id) defer.returnValue(room_id) diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index 6121b2f26778..652bab58e333 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -24,6 +24,7 @@ class ServerNoticesSender(object): """A centralised place which sends server notices automatically when Certain Events take place """ + def __init__(self, hs): """ @@ -32,7 +33,7 @@ def __init__(self, hs): """ self._server_notices = ( ConsentServerNotices(hs), - ResourceLimitsServerNotices(hs) + ResourceLimitsServerNotices(hs), ) @defer.inlineCallbacks @@ -43,9 +44,7 @@ def on_user_syncing(self, user_id): user_id (str): mxid of user who synced """ for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user( - user_id, - ) + yield sn.maybe_send_server_notice_to_user(user_id) @defer.inlineCallbacks def on_user_ip(self, user_id): @@ -58,6 +57,4 @@ def on_user_ip(self, user_id): # we check for notices to send to the user in on_user_ip as well as # in on_user_syncing for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user( - user_id, - ) + yield sn.maybe_send_server_notice_to_user(user_id) diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py index 4a133026c323..245ec7c64ff4 100644 --- a/synapse/server_notices/worker_server_notices_sender.py +++ b/synapse/server_notices/worker_server_notices_sender.py @@ -17,6 +17,7 @@ class WorkerServerNoticesSender(object): """Stub impl of ServerNoticesSender which does nothing""" + def __init__(self, hs): """ Args: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 36684ef9f6e8..fc20d1eaee93 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -98,8 +98,9 @@ def __init__(self, hs): self._state_resolution_handler = hs.get_state_resolution_handler() @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key="", - latest_event_ids=None): + def get_current_state( + self, room_id, event_type=None, state_key="", latest_event_ids=None + ): """ Retrieves the current state for the room. This is done by calling `get_latest_events_in_room` to get the leading edges of the event graph and then resolving any of the state conflicts. @@ -128,8 +129,9 @@ def get_current_state(self, room_id, event_type=None, state_key="", defer.returnValue(event) return - state_map = yield self.store.get_events(list(state.values()), - get_prev_content=False) + state_map = yield self.store.get_events( + list(state.values()), get_prev_content=False + ) state = { key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map } @@ -211,9 +213,7 @@ def compute_event_context(self, event, old_state=None): # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. if old_state: - prev_state_ids = { - (s.type, s.state_key): s.event_id for s in old_state - } + prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} if event.is_state(): current_state_ids = dict(prev_state_ids) key = (event.type, event.state_key) @@ -239,9 +239,7 @@ def compute_event_context(self, event, old_state=None): # Let's just correctly fill out the context and create a # new state group for it. - prev_state_ids = { - (s.type, s.state_key): s.event_id for s in old_state - } + prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state} if event.is_state(): key = (event.type, event.state_key) @@ -273,7 +271,7 @@ def compute_event_context(self, event, old_state=None): logger.debug("calling resolve_state_groups from compute_event_context") entry = yield self.resolve_state_groups_for_events( - event.room_id, event.prev_event_ids(), + event.room_id, event.prev_event_ids() ) prev_state_ids = entry.state @@ -296,9 +294,7 @@ def compute_event_context(self, event, old_state=None): # If the state at the event has a state group assigned then # we can use that as the prev group prev_group = entry.state_group - delta_ids = { - key: event.event_id - } + delta_ids = {key: event.event_id} elif entry.prev_group: # If the state at the event only has a prev group, then we can # use that as a prev group too. @@ -360,31 +356,31 @@ def resolve_state_groups_for_events(self, room_id, event_ids): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.store.get_state_groups_ids( - room_id, event_ids - ) + state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) if len(state_groups_ids) == 0: - defer.returnValue(_StateCacheEntry( - state={}, - state_group=None, - )) + defer.returnValue(_StateCacheEntry(state={}, state_group=None)) elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() prev_group, delta_ids = yield self.store.get_state_group_delta(name) - defer.returnValue(_StateCacheEntry( - state=state_list, - state_group=name, - prev_group=prev_group, - delta_ids=delta_ids, - )) + defer.returnValue( + _StateCacheEntry( + state=state_list, + state_group=name, + prev_group=prev_group, + delta_ids=delta_ids, + ) + ) room_version = yield self.store.get_room_version(room_id) result = yield self._state_resolution_handler.resolve_state_groups( - room_id, room_version, state_groups_ids, None, + room_id, + room_version, + state_groups_ids, + None, state_res_store=StateResolutionStore(self.store), ) defer.returnValue(result) @@ -394,27 +390,21 @@ def resolve_events(self, room_version, state_sets, event): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) - state_set_ids = [{ - (ev.type, ev.state_key): ev.event_id - for ev in st - } for st in state_sets] - - state_map = { - ev.event_id: ev - for st in state_sets - for ev in st - } + state_set_ids = [ + {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets + ] + + state_map = {ev.event_id: ev for st in state_sets for ev in st} with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( - room_version, state_set_ids, + room_version, + state_set_ids, event_map=state_map, state_res_store=StateResolutionStore(self.store), ) - new_state = { - key: state_map[ev_id] for key, ev_id in iteritems(new_state) - } + new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)} defer.returnValue(new_state) @@ -425,6 +415,7 @@ class StateResolutionHandler(object): Note that the storage layer depends on this handler, so all functions must be storage-independent. """ + def __init__(self, hs): self.clock = hs.get_clock() @@ -444,7 +435,7 @@ def __init__(self, hs): @defer.inlineCallbacks @log_function def resolve_state_groups( - self, room_id, room_version, state_groups_ids, event_map, state_res_store, + self, room_id, room_version, state_groups_ids, event_map, state_res_store ): """Resolves conflicts between a set of state groups @@ -471,10 +462,7 @@ def resolve_state_groups( Returns: Deferred[_StateCacheEntry]: resolved state """ - logger.debug( - "resolve_state_groups state_groups %s", - state_groups_ids.keys() - ) + logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) group_names = frozenset(state_groups_ids.keys()) @@ -529,10 +517,7 @@ def resolve_state_groups( defer.returnValue(cache) -def _make_state_cache_entry( - new_state, - state_groups_ids, -): +def _make_state_cache_entry(new_state, state_groups_ids): """Given a resolved state, and a set of input state groups, pick one to base a new state group on (if any), and return an appropriately-constructed _StateCacheEntry. @@ -562,10 +547,7 @@ def _make_state_cache_entry( old_state_event_ids = set(itervalues(state)) if new_state_event_ids == old_state_event_ids: # got an exact match. - return _StateCacheEntry( - state=new_state, - state_group=sg, - ) + return _StateCacheEntry(state=new_state, state_group=sg) # TODO: We want to create a state group for this set of events, to # increase cache hits, but we need to make sure that it doesn't @@ -576,20 +558,13 @@ def _make_state_cache_entry( delta_ids = None for old_group, old_state in iteritems(state_groups_ids): - n_delta_ids = { - k: v - for k, v in iteritems(new_state) - if old_state.get(k) != v - } + n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v} if not delta_ids or len(n_delta_ids) < len(delta_ids): prev_group = old_group delta_ids = n_delta_ids return _StateCacheEntry( - state=new_state, - state_group=None, - prev_group=prev_group, - delta_ids=delta_ids, + state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids ) @@ -618,11 +593,11 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: return v1.resolve_events_with_store( - state_sets, event_map, state_res_store.get_events, + state_sets, event_map, state_res_store.get_events ) else: return v2.resolve_events_with_store( - room_version, state_sets, event_map, state_res_store, + room_version, state_sets, event_map, state_res_store ) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 29b4e86cfd01..88acd4817e6e 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -57,23 +57,17 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): if len(state_sets) == 1: defer.returnValue(state_sets[0]) - unconflicted_state, conflicted_state = _seperate( - state_sets, - ) + unconflicted_state, conflicted_state = _seperate(state_sets) needed_events = set( - event_id - for event_ids in itervalues(conflicted_state) - for event_id in event_ids + event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids ) needed_event_count = len(needed_events) if event_map is not None: needed_events -= set(iterkeys(event_map)) logger.info( - "Asking for %d/%d conflicted events", - len(needed_events), - needed_event_count, + "Asking for %d/%d conflicted events", len(needed_events), needed_event_count ) # dict[str, FrozenEvent]: a map from state event id to event. Only includes @@ -97,17 +91,17 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): new_needed_events -= set(iterkeys(event_map)) logger.info( - "Asking for %d/%d auth events", - len(new_needed_events), - new_needed_event_count, + "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count ) state_map_new = yield state_map_factory(new_needed_events) state_map.update(state_map_new) - defer.returnValue(_resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map - )) + defer.returnValue( + _resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + ) + ) def _seperate(state_sets): @@ -173,8 +167,9 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma return auth_events -def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids, - state_map): +def _resolve_with_state( + unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map +): conflicted_state = {} for key, event_ids in iteritems(conflicted_state_ids): events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] @@ -190,9 +185,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event } try: - resolved_state = _resolve_state_events( - conflicted_state, auth_events - ) + resolved_state = _resolve_state_events(conflicted_state, auth_events) except Exception: logger.exception("Failed to resolve state") raise @@ -218,37 +211,28 @@ def _resolve_state_events(conflicted_state, auth_events): if POWER_KEY in conflicted_state: events = conflicted_state[POWER_KEY] logger.debug("Resolving conflicted power levels %r", events) - resolved_state[POWER_KEY] = _resolve_auth_events( - events, auth_events) + resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) for key, events in iteritems(conflicted_state): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) + resolved_state[key] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) for key, events in iteritems(conflicted_state): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) + resolved_state[key] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) for key, events in iteritems(conflicted_state): if key not in resolved_state: logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = _resolve_normal_events( - events, auth_events - ) + resolved_state[key] = _resolve_normal_events(events, auth_events) return resolved_state @@ -257,9 +241,7 @@ def _resolve_auth_events(events, auth_events): reverse = [i for i in reversed(_ordered_events(events))] auth_keys = set( - key - for event in events - for key in event_auth.auth_types_for_event(event) + key for event in events for key in event_auth.auth_types_for_event(event) ) new_auth_events = {} @@ -313,6 +295,6 @@ def _ordered_events(events): def key_func(e): # we have to use utf-8 rather than ascii here because it turns out we allow # people to send us events with non-ascii event IDs :/ - return -int(e.depth), hashlib.sha1(e.event_id.encode('utf-8')).hexdigest() + return -int(e.depth), hashlib.sha1(e.event_id.encode("utf-8")).hexdigest() return sorted(events, key=key_func) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 650995c92c62..db969e899752 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -70,19 +70,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # Also fetch all auth events that appear in only some of the state sets' # auth chains. - auth_diff = yield _get_auth_chain_difference( - state_sets, event_map, state_res_store, - ) + auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store) - full_conflicted_set = set(itertools.chain( - itertools.chain.from_iterable(itervalues(conflicted_state)), - auth_diff, - )) + full_conflicted_set = set( + itertools.chain( + itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff + ) + ) - events = yield state_res_store.get_events([ - eid for eid in full_conflicted_set - if eid not in event_map - ], allow_rejected=True) + events = yield state_res_store.get_events( + [eid for eid in full_conflicted_set if eid not in event_map], + allow_rejected=True, + ) event_map.update(events) full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) @@ -91,22 +90,21 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # Get and sort all the power events (kicks/bans/etc) power_events = ( - eid for eid in full_conflicted_set - if _is_power_event(event_map[eid]) + eid for eid in full_conflicted_set if _is_power_event(event_map[eid]) ) sorted_power_events = yield _reverse_topological_power_sort( - power_events, - event_map, - state_res_store, - full_conflicted_set, + power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one resolved_state = yield _iterative_auth_checks( - room_version, sorted_power_events, unconflicted_state, event_map, + room_version, + sorted_power_events, + unconflicted_state, + event_map, state_res_store, ) @@ -116,23 +114,20 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto # events using the mainline of the resolved power level. leftover_events = [ - ev_id - for ev_id in full_conflicted_set - if ev_id not in sorted_power_events + ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events ] logger.debug("sorting %d remaining events", len(leftover_events)) pl = resolved_state.get((EventTypes.PowerLevels, ""), None) leftover_events = yield _mainline_sort( - leftover_events, pl, event_map, state_res_store, + leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") resolved_state = yield _iterative_auth_checks( - room_version, leftover_events, resolved_state, event_map, - state_res_store, + room_version, leftover_events, resolved_state, event_map, state_res_store ) logger.debug("resolved") @@ -209,14 +204,16 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): auth_ids = set( eid for key, eid in iteritems(state_set) - if (key[0] in ( - EventTypes.Member, - EventTypes.ThirdPartyInvite, - ) or key in ( - (EventTypes.PowerLevels, ''), - (EventTypes.Create, ''), - (EventTypes.JoinRules, ''), - )) and eid not in common + if ( + key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite) + or key + in ( + (EventTypes.PowerLevels, ""), + (EventTypes.Create, ""), + (EventTypes.JoinRules, ""), + ) + ) + and eid not in common ) auth_chain = yield state_res_store.get_auth_chain(auth_ids) @@ -274,15 +271,16 @@ def _is_power_event(event): return True if event.type == EventTypes.Member: - if event.membership in ('leave', 'ban'): + if event.membership in ("leave", "ban"): return event.sender != event.state_key return False @defer.inlineCallbacks -def _add_event_and_auth_chain_to_graph(graph, event_id, event_map, - state_res_store, auth_diff): +def _add_event_and_auth_chain_to_graph( + graph, event_id, event_map, state_res_store, auth_diff +): """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph @@ -327,7 +325,7 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ graph = {} for event_id in event_ids: yield _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff, + graph, event_id, event_map, state_res_store, auth_diff ) event_to_pl = {} @@ -342,18 +340,16 @@ def _get_power_order(event_id): return -pl, ev.origin_server_ts, event_id # Note: graph is modified during the sort - it = lexicographical_topological_sort( - graph, - key=_get_power_order, - ) + it = lexicographical_topological_sort(graph, key=_get_power_order) sorted_events = list(it) defer.returnValue(sorted_events) @defer.inlineCallbacks -def _iterative_auth_checks(room_version, event_ids, base_state, event_map, - state_res_store): +def _iterative_auth_checks( + room_version, event_ids, base_state, event_map, state_res_store +): """Sequentially apply auth checks to each event in given list, updating the state as it goes along. @@ -389,9 +385,11 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map, try: event_auth.check( - room_version, event, auth_events, + room_version, + event, + auth_events, do_sig_check=False, - do_size_check=False + do_size_check=False, ) resolved_state[(event.type, event.state_key)] = event_id @@ -402,8 +400,7 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map, @defer.inlineCallbacks -def _mainline_sort(event_ids, resolved_power_event_id, event_map, - state_res_store): +def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store): """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id @@ -436,8 +433,7 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, order_map = {} for ev_id in event_ids: depth = yield _get_mainline_depth_for_event( - event_map[ev_id], mainline_map, - event_map, state_res_store, + event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 0ca6f6121fe0..6b0ca800876c 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -280,7 +280,7 @@ def count_daily_users(self): Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.runInteraction("count_daily_users", self._count_users, yesterday,) + return self.runInteraction("count_daily_users", self._count_users, yesterday) def count_monthly_users(self): """ @@ -291,9 +291,7 @@ def count_monthly_users(self): """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) return self.runInteraction( - "count_monthly_users", - self._count_users, - thirty_days_ago, + "count_monthly_users", self._count_users, thirty_days_ago ) def _count_users(self, txn, time_from): @@ -361,7 +359,7 @@ def _count_r30_users(txn): txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) for row in txn: - if row[0] == 'unknown': + if row[0] == "unknown": pass results[row[0]] = row[1] @@ -388,7 +386,7 @@ def _count_r30_users(txn): txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) count, = txn.fetchone() - results['all'] = count + results["all"] = count return results diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 941c07fce540..c74bcc8f0b06 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -312,9 +312,7 @@ def select_users_with_no_expiration_date_txn(txn): if res: for user in res: self.set_expiration_date_for_user_txn( - txn, - user["name"], - use_delta=True, + txn, user["name"], use_delta=True ) yield self.runInteraction( @@ -1667,7 +1665,7 @@ def db_to_json(db_content): # Decode it to a Unicode string before feeding it to json.loads, so we # consistenty get a Unicode-containing object out. if isinstance(db_content, (bytes, bytearray)): - db_content = db_content.decode('utf8') + db_content = db_content.decode("utf8") try: return json.loads(db_content) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index b8b8273f73e6..50f913a414a5 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -169,7 +169,7 @@ def do_next_background_update(self, desired_duration_ms): in_flight = set(update["update_name"] for update in updates) for update in updates: if update["depends_on"] not in in_flight: - self._background_update_queue.append(update['update_name']) + self._background_update_queue.append(update["update_name"]) if not self._background_update_queue: # no work left to do diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index d102e07372cc..3413a46675a3 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -149,9 +149,7 @@ def get_devices_by_remote(self, destination, from_stream_id, limit): defer.returnValue((stream_id_cutoff, [])) results = yield self._get_device_update_edus_by_remote( - destination, - from_stream_id, - query_map, + destination, from_stream_id, query_map ) defer.returnValue((now_stream_id, results)) @@ -182,9 +180,7 @@ def _get_devices_by_remote_txn( return list(txn) @defer.inlineCallbacks - def _get_device_update_edus_by_remote( - self, destination, from_stream_id, query_map, - ): + def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): """Returns a list of device update EDUs as well as E2EE keys Args: @@ -210,7 +206,7 @@ def _get_device_update_edus_by_remote( # The prev_id for the first row is always the last row before # `from_stream_id` prev_id = yield self._get_last_device_update_for_remote_user( - destination, user_id, from_stream_id, + destination, user_id, from_stream_id ) for device_id, device in iteritems(user_devices): stream_id = query_map[(user_id, device_id)] @@ -238,7 +234,7 @@ def _get_device_update_edus_by_remote( defer.returnValue(results) def _get_last_device_update_for_remote_user( - self, destination, user_id, from_stream_id, + self, destination, user_id, from_stream_id ): def f(txn): prev_sent_id_sql = """ diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py index 521936e3b06d..f40ef2ab6451 100644 --- a/synapse/storage/e2e_room_keys.py +++ b/synapse/storage/e2e_room_keys.py @@ -87,10 +87,10 @@ def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): }, values={ "version": version, - "first_message_index": room_key['first_message_index'], - "forwarded_count": room_key['forwarded_count'], - "is_verified": room_key['is_verified'], - "session_data": json.dumps(room_key['session_data']), + "first_message_index": room_key["first_message_index"], + "forwarded_count": room_key["forwarded_count"], + "is_verified": room_key["is_verified"], + "session_data": json.dumps(room_key["session_data"]), }, lock=False, ) @@ -118,13 +118,13 @@ def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): try: version = int(version) except ValueError: - defer.returnValue({'rooms': {}}) + defer.returnValue({"rooms": {}}) keyvalues = {"user_id": user_id, "version": version} if room_id: - keyvalues['room_id'] = room_id + keyvalues["room_id"] = room_id if session_id: - keyvalues['session_id'] = session_id + keyvalues["session_id"] = session_id rows = yield self._simple_select_list( table="e2e_room_keys", @@ -141,10 +141,10 @@ def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): desc="get_e2e_room_keys", ) - sessions = {'rooms': {}} + sessions = {"rooms": {}} for row in rows: - room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}}) - room_entry['sessions'][row['session_id']] = { + room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) + room_entry["sessions"][row["session_id"]] = { "first_message_index": row["first_message_index"], "forwarded_count": row["forwarded_count"], "is_verified": row["is_verified"], @@ -174,9 +174,9 @@ def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): keyvalues = {"user_id": user_id, "version": int(version)} if room_id: - keyvalues['room_id'] = room_id + keyvalues["room_id"] = room_id if session_id: - keyvalues['session_id'] = session_id + keyvalues["session_id"] = session_id yield self._simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" @@ -191,7 +191,7 @@ def _get_current_version(txn, user_id): ) row = txn.fetchone() if not row: - raise StoreError(404, 'No current backup version') + raise StoreError(404, "No current backup version") return row[0] def get_e2e_room_keys_version_info(self, user_id, version=None): @@ -255,7 +255,7 @@ def _create_e2e_room_keys_version_txn(txn): ) current_version = txn.fetchone()[0] if current_version is None: - current_version = '0' + current_version = "0" new_version = str(int(current_version) + 1) diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 933bcf42c2b9..e9b9caa49a3c 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -85,7 +85,7 @@ def server_version(self): def _parse_match_info(buf): bufsize = len(buf) - return [struct.unpack('@I', buf[i : i + 4])[0] for i in range(0, bufsize, 4)] + return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)] def _rank(raw_match_info): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index e8d16edbc8f1..cb4478342f17 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -215,8 +215,7 @@ def _get_rooms_with_many_extremities_txn(txn): return [room_id for room_id, in txn] return self.runInteraction( - "get_rooms_with_many_extremities", - _get_rooms_with_many_extremities_txn, + "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @cached(max_entries=5000, iterable=True) diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index a729f3e06781..eca77069fd16 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -277,7 +277,7 @@ def get_no_receipt(txn): # contain results from the first query, correctly ordered, followed # by results from the second query, but we want them all ordered # by stream_ordering, oldest first. - notifs.sort(key=lambda r: r['stream_ordering']) + notifs.sort(key=lambda r: r["stream_ordering"]) # Take only up to the limit. We have to stop at the limit because # one of the subqueries may have hit the limit. @@ -379,7 +379,7 @@ def get_no_receipt(txn): # contain results from the first query, correctly ordered, followed # by results from the second query, but we want them all ordered # by received_ts (most recent first) - notifs.sort(key=lambda r: -(r['received_ts'] or 0)) + notifs.sort(key=lambda r: -(r["received_ts"] or 0)) # Now return the first `limit` defer.returnValue(notifs[:limit]) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index f631fb173345..de5965569c2a 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -234,7 +234,7 @@ def __init__(self, db_conn, hs): BucketCollector( "synapse_forward_extremities", lambda: self._current_forward_extremities_amount, - buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"] + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], ) # Read the extrems every 60 minutes diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py index 75c1935bf34a..1ce21d190c85 100644 --- a/synapse/storage/events_bg_updates.py +++ b/synapse/storage/events_bg_updates.py @@ -64,8 +64,7 @@ def __init__(self, db_conn, hs): ) self.register_background_update_handler( - self.DELETE_SOFT_FAILED_EXTREMITIES, - self._cleanup_extremities_bg_update, + self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update ) @defer.inlineCallbacks @@ -269,7 +268,8 @@ def _cleanup_extremities_bg_update_txn(txn): LEFT JOIN events USING (event_id) LEFT JOIN event_json USING (event_id) LEFT JOIN rejections USING (event_id) - """, (batch_size,) + """, + (batch_size,), ) for prev_event_id, event_id, metadata, rejected, outlier in txn: @@ -364,13 +364,12 @@ def _cleanup_extremities_bg_update_txn(txn): column="event_id", iterable=to_delete, keyvalues={}, - retcols=("room_id",) + retcols=("room_id",), ) room_ids = set(row["room_id"] for row in rows) for room_id in room_ids: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, - (room_id,) + self.get_latest_event_ids_in_room.invalidate, (room_id,) ) self._simple_delete_many_txn( @@ -384,7 +383,7 @@ def _cleanup_extremities_bg_update_txn(txn): return len(original_set) num_handled = yield self.runInteraction( - "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn, + "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn ) if not num_handled: @@ -394,8 +393,7 @@ def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") yield self.runInteraction( - "_cleanup_extremities_bg_update_drop_table", - _drop_table_txn, + "_cleanup_extremities_bg_update_drop_table", _drop_table_txn ) defer.returnValue(num_handled) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index cc7df5cf14df..6d680d405ac1 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -27,7 +27,6 @@ from synapse.api.errors import NotFoundError from synapse.api.room_versions import EventFormatVersions from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 -# these are only included to make the type annotations work from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.utils import prune_event from synapse.metrics.background_process_metrics import run_as_background_process @@ -111,8 +110,7 @@ def _get_approximate_received_ts_txn(txn): return ts return self.runInteraction( - "get_approximate_received_ts", - _get_approximate_received_ts_txn, + "get_approximate_received_ts", _get_approximate_received_ts_txn ) @defer.inlineCallbacks @@ -677,7 +675,8 @@ def get_total_state_event_counts(self, room_id): """ return self.runInteraction( "get_total_state_event_counts", - self._get_total_state_event_counts_txn, room_id + self._get_total_state_event_counts_txn, + room_id, ) def _get_current_state_event_counts_txn(self, txn, room_id): @@ -701,7 +700,8 @@ def get_current_state_event_counts(self, room_id): """ return self.runInteraction( "get_current_state_event_counts", - self._get_current_state_event_counts_txn, room_id + self._get_current_state_event_counts_txn, + room_id, ) @defer.inlineCallbacks diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py index dce6a43ac1eb..73e6fc6de2b0 100644 --- a/synapse/storage/group_server.py +++ b/synapse/storage/group_server.py @@ -1179,11 +1179,7 @@ def _delete_group_txn(txn): for table in tables: self._simple_delete_txn( - txn, - table=table, - keyvalues={"group_id": group_id}, + txn, table=table, keyvalues={"group_id": group_id} ) - return self.runInteraction( - "delete_group", _delete_group_txn - ) + return self.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index e3655ad8d759..e72f89e44602 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -131,7 +131,7 @@ def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): def _invalidate(res): f = self._get_server_verify_key.invalidate for i in invalidations: - f((i, )) + f((i,)) return res return self.runInteraction( diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 3ecf47e7a787..6b1238ce4ae2 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -22,11 +22,11 @@ def __init__(self, db_conn, hs): super(MediaRepositoryStore, self).__init__(db_conn, hs) self.register_background_index_update( - update_name='local_media_repository_url_idx', - index_name='local_media_repository_url_idx', - table='local_media_repository', - columns=['created_ts'], - where_clause='url_cache IS NOT NULL', + update_name="local_media_repository_url_idx", + index_name="local_media_repository_url_idx", + table="local_media_repository", + columns=["created_ts"], + where_clause="url_cache IS NOT NULL", ) def get_local_media(self, media_id): @@ -108,12 +108,12 @@ def get_url_cache_txn(txn): return dict( zip( ( - 'response_code', - 'etag', - 'expires_ts', - 'og', - 'media_id', - 'download_ts', + "response_code", + "etag", + "expires_ts", + "og", + "media_id", + "download_ts", ), row, ) diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 8aa8abc47051..081564360fb6 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -86,11 +86,11 @@ def _reap_users(txn): if len(self.reserved_users) > 0: # questionmarks is a hack to overcome sqlite not supporting # tuples in 'WHERE IN %s' - questionmarks = '?' * len(self.reserved_users) + questionmarks = "?" * len(self.reserved_users) query_args.extend(self.reserved_users) sql = base_sql + """ AND user_id NOT IN ({})""".format( - ','.join(questionmarks) + ",".join(questionmarks) ) else: sql = base_sql @@ -124,7 +124,7 @@ def _reap_users(txn): if len(self.reserved_users) > 0: query_args.extend(self.reserved_users) sql = base_sql + """ AND user_id NOT IN ({})""".format( - ','.join(questionmarks) + ",".join(questionmarks) ) else: sql = base_sql diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index f2c1bed487cb..fc10b9534e19 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -146,9 +146,10 @@ def _setup_new_database(cur, database_engine): directory_entries = os.listdir(sql_dir) - for filename in sorted(fnmatch.filter(directory_entries, "*.sql") + fnmatch.filter( - directory_entries, "*.sql." + specific - )): + for filename in sorted( + fnmatch.filter(directory_entries, "*.sql") + + fnmatch.filter(directory_entries, "*.sql." + specific) + ): sql_loc = os.path.join(sql_dir, filename) logger.debug("Applying schema %s", sql_loc) executescript(cur, sql_loc) @@ -313,7 +314,7 @@ def _apply_module_schemas(txn, database_engine, config): application config """ for (mod, _config) in config.password_providers: - if not hasattr(mod, 'get_db_schema_files'): + if not hasattr(mod, "get_db_schema_files"): continue modname = ".".join((mod.__module__, mod.__name__)) _apply_module_schema_files( @@ -343,7 +344,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) continue root_name, ext = os.path.splitext(name) - if ext != '.sql': + if ext != ".sql": raise PrepareDatabaseException( "only .sql files are currently supported for module schemas" ) @@ -407,7 +408,7 @@ def get_statements(f): def executescript(txn, schema_path): - with open(schema_path, 'r') as f: + with open(schema_path, "r") as f: for statement in get_statements(f): txn.execute(statement) diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index aeec2f57c4d4..0ff392bdb4d8 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -41,7 +41,7 @@ def get_profileinfo(self, user_localpart): defer.returnValue( ProfileInfo( - avatar_url=profile['avatar_url'], display_name=profile['displayname'] + avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) ) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 9e406baafaaa..98cec8c82bf6 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -46,12 +46,12 @@ def _load_rules(rawrules, enabled_map): rules = list(list_with_base_rules(ruleslist)) for i, rule in enumerate(rules): - rule_id = rule['rule_id'] + rule_id = rule["rule_id"] if rule_id in enabled_map: - if rule.get('enabled', True) != bool(enabled_map[rule_id]): + if rule.get("enabled", True) != bool(enabled_map[rule_id]): # Rules are cached across users. rule = dict(rule) - rule['enabled'] = bool(enabled_map[rule_id]) + rule["enabled"] = bool(enabled_map[rule_id]) rules[i] = rule return rules @@ -126,12 +126,12 @@ def get_push_rules_for_user(self, user_id): def get_push_rules_enabled_for_user(self, user_id): results = yield self._simple_select_list( table="push_rules_enable", - keyvalues={'user_name': user_id}, + keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), desc="get_push_rules_enabled_for_user", ) defer.returnValue( - {r['rule_id']: False if r['enabled'] == 0 else True for r in results} + {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} ) def have_push_rules_changed_for_user(self, user_id, last_id): @@ -175,7 +175,7 @@ def bulk_get_push_rules(self, user_ids): rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) for row in rows: - results.setdefault(row['user_name'], []).append(row) + results.setdefault(row["user_name"], []).append(row) enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) @@ -194,7 +194,7 @@ def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule): rule (Dict): A push rule. """ # Create new rule id - rule_id_scope = '/'.join(rule["rule_id"].split('/')[:-1]) + rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) new_rule_id = rule_id_scope + "/" + new_room_id # Change room id in each condition @@ -334,8 +334,8 @@ def bulk_get_push_rules_enabled(self, user_ids): desc="bulk_get_push_rules_enabled", ) for row in rows: - enabled = bool(row['enabled']) - results.setdefault(row['user_name'], {})[row['rule_id']] = enabled + enabled = bool(row["enabled"]) + results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled defer.returnValue(results) @@ -568,7 +568,7 @@ def delete_push_rule(self, user_id, rule_id): def delete_push_rule_txn(txn, stream_id, event_stream_ordering): self._simple_delete_one_txn( - txn, "push_rules", {'user_name': user_id, 'rule_id': rule_id} + txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) self._insert_push_rules_update_txn( @@ -605,9 +605,9 @@ def _set_push_rule_enabled_txn( self._simple_upsert_txn( txn, "push_rules_enable", - {'user_name': user_id, 'rule_id': rule_id}, - {'enabled': 1 if enabled else 0}, - {'id': new_id}, + {"user_name": user_id, "rule_id": rule_id}, + {"enabled": 1 if enabled else 0}, + {"id": new_id}, ) self._insert_push_rules_update_txn( @@ -645,8 +645,8 @@ def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): self._simple_update_one_txn( txn, "push_rules", - {'user_name': user_id, 'rule_id': rule_id}, - {'actions': actions_json}, + {"user_name": user_id, "rule_id": rule_id}, + {"actions": actions_json}, ) self._insert_push_rules_update_txn( diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 1567e1df4810..cfe0a94330fb 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -37,24 +37,24 @@ class PusherWorkerStore(SQLBaseStore): def _decode_pushers_rows(self, rows): for r in rows: - dataJson = r['data'] - r['data'] = None + dataJson = r["data"] + r["data"] = None try: if isinstance(dataJson, db_binary_type): dataJson = str(dataJson).decode("UTF8") - r['data'] = json.loads(dataJson) + r["data"] = json.loads(dataJson) except Exception as e: logger.warn( "Invalid JSON in data for pusher %d: %s, %s", - r['id'], + r["id"], dataJson, e.args[0], ) pass - if isinstance(r['pushkey'], db_binary_type): - r['pushkey'] = str(r['pushkey']).decode("UTF8") + if isinstance(r["pushkey"], db_binary_type): + r["pushkey"] = str(r["pushkey"]).decode("UTF8") return rows @@ -195,15 +195,15 @@ def get_if_user_has_pusher(self, user_id): ) def get_if_users_have_pushers(self, user_ids): rows = yield self._simple_select_many_batch( - table='pushers', - column='user_name', + table="pushers", + column="user_name", iterable=user_ids, - retcols=['user_name'], - desc='get_if_users_have_pushers', + retcols=["user_name"], + desc="get_if_users_have_pushers", ) result = {user_id: False for user_id in user_ids} - result.update({r['user_name']: True for r in rows}) + result.update({r["user_name"]: True for r in rows}) defer.returnValue(result) @@ -299,8 +299,8 @@ def update_pusher_last_stream_ordering( ): yield self._simple_update_one( "pushers", - {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, - {'last_stream_ordering': last_stream_ordering}, + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + {"last_stream_ordering": last_stream_ordering}, desc="update_pusher_last_stream_ordering", ) @@ -310,10 +310,10 @@ def update_pusher_last_stream_ordering_and_success( ): yield self._simple_update_one( "pushers", - {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, { - 'last_stream_ordering': last_stream_ordering, - 'last_success': last_success, + "last_stream_ordering": last_stream_ordering, + "last_success": last_success, }, desc="update_pusher_last_stream_ordering_and_success", ) @@ -322,8 +322,8 @@ def update_pusher_last_stream_ordering_and_success( def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): yield self._simple_update_one( "pushers", - {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id}, - {'failing_since': failing_since}, + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, + {"failing_since": failing_since}, desc="update_pusher_failing_since", ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index a1647e50a11f..b477da12b1df 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -58,7 +58,7 @@ def get_max_receipt_stream_id(self): @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") - defer.returnValue(set(r['user_id'] for r in receipts)) + defer.returnValue(set(r["user_id"] for r in receipts)) @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 0b3c656e9029..983ce132910a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -116,8 +116,9 @@ def get_expiration_ts_for_user(self, user_id): defer.returnValue(res) @defer.inlineCallbacks - def set_account_validity_for_user(self, user_id, expiration_ts, email_sent, - renewal_token=None): + def set_account_validity_for_user( + self, user_id, expiration_ts, email_sent, renewal_token=None + ): """Updates the account validity properties of the given account, with the given values. @@ -131,6 +132,7 @@ def set_account_validity_for_user(self, user_id, expiration_ts, email_sent, renewal_token (str): Renewal token the user can use to extend the validity of their account. Defaults to no token. """ + def set_account_validity_for_user_txn(txn): self._simple_update_txn( txn=txn, @@ -143,12 +145,11 @@ def set_account_validity_for_user_txn(txn): }, ) self._invalidate_cache_and_stream( - txn, self.get_expiration_ts_for_user, (user_id,), + txn, self.get_expiration_ts_for_user, (user_id,) ) yield self.runInteraction( - "set_account_validity_for_user", - set_account_validity_for_user_txn, + "set_account_validity_for_user", set_account_validity_for_user_txn ) @defer.inlineCallbacks @@ -217,6 +218,7 @@ def get_users_expiring_soon(self): Returns: Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] """ + def select_users_txn(txn, now_ms, renew_at): sql = ( "SELECT user_id, expiration_ts_ms FROM account_validity" @@ -229,7 +231,8 @@ def select_users_txn(txn, now_ms, renew_at): res = yield self.runInteraction( "get_users_expiring_soon", select_users_txn, - self.clock.time_msec(), self.config.account_validity.renew_at, + self.clock.time_msec(), + self.config.account_validity.renew_at, ) defer.returnValue(res) @@ -369,7 +372,7 @@ def _count_daily_user_type(txn): WHERE creation_ts > ? ) AS t GROUP BY user_type """ - results = {'native': 0, 'guest': 0, 'bridged': 0} + results = {"native": 0, "guest": 0, "bridged": 0} txn.execute(sql, (yesterday,)) for row in txn: results[row[0]] = row[1] @@ -435,7 +438,7 @@ def get_3pid_guest_access_token(self, medium, address): {"medium": medium, "address": address}, ["guest_access_token"], True, - 'get_3pid_guest_access_token', + "get_3pid_guest_access_token", ) if ret: defer.returnValue(ret["guest_access_token"]) @@ -472,11 +475,11 @@ def get_user_id_by_threepid_txn(self, txn, medium, address): txn, "user_threepids", {"medium": medium, "address": address}, - ['user_id'], + ["user_id"], True, ) if ret: - return ret['user_id'] + return ret["user_id"] return None @defer.inlineCallbacks @@ -492,8 +495,8 @@ def user_get_threepids(self, user_id): ret = yield self._simple_select_list( "user_threepids", {"user_id": user_id}, - ['medium', 'address', 'validated_at', 'added_at'], - 'user_get_threepids', + ["medium", "address", "validated_at", "added_at"], + "user_get_threepids", ) defer.returnValue(ret) @@ -572,11 +575,7 @@ def get_id_servers_user_bound(self, user_id, medium, address): """ return self._simple_select_onecol( table="user_threepid_id_server", - keyvalues={ - "user_id": user_id, - "medium": medium, - "address": address, - }, + keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", desc="get_id_servers_user_bound", ) @@ -612,16 +611,16 @@ def __init__(self, db_conn, hs): self.register_noop_background_update("refresh_tokens_device_index") self.register_background_update_handler( - "user_threepids_grandfather", self._bg_user_threepids_grandfather, + "user_threepids_grandfather", self._bg_user_threepids_grandfather ) self.register_background_update_handler( - "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag, + "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag ) # Create a background job for culling expired 3PID validity tokens hs.get_clock().looping_call( - self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS, + self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS ) @defer.inlineCallbacks @@ -677,8 +676,7 @@ def _backgroud_update_set_deactivated_flag_txn(txn): return False end = yield self.runInteraction( - "users_set_deactivated_flag", - _backgroud_update_set_deactivated_flag_txn, + "users_set_deactivated_flag", _backgroud_update_set_deactivated_flag_txn ) if end: @@ -851,7 +849,7 @@ def user_set_password_hash(self, user_id, password_hash): def user_set_password_hash_txn(txn): self._simple_update_one_txn( - txn, 'users', {'name': user_id}, {'password_hash': password_hash} + txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -872,9 +870,9 @@ def user_set_consent_version(self, user_id, consent_version): def f(txn): self._simple_update_one_txn( txn, - table='users', - keyvalues={'name': user_id}, - updatevalues={'consent_version': consent_version}, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_version": consent_version}, ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -896,9 +894,9 @@ def user_set_consent_server_notice_sent(self, user_id, consent_version): def f(txn): self._simple_update_one_txn( txn, - table='users', - keyvalues={'name': user_id}, - updatevalues={'consent_server_notice_sent': consent_version}, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_server_notice_sent": consent_version}, ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -1068,7 +1066,7 @@ def _bg_user_threepids_grandfather_txn(txn): if id_servers: yield self.runInteraction( - "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn, + "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) yield self._end_background_update("user_threepids_grandfather") @@ -1076,12 +1074,7 @@ def _bg_user_threepids_grandfather_txn(txn): defer.returnValue(1) def get_threepid_validation_session( - self, - medium, - client_secret, - address=None, - sid=None, - validated=True, + self, medium, client_secret, address=None, sid=None, validated=True ): """Gets a session_id and last_send_attempt (if available) for a client_secret/medium/(address|session_id) combo @@ -1101,23 +1094,22 @@ def get_threepid_validation_session( latest session_id and send_attempt count for this 3PID. Otherwise None if there hasn't been a previous attempt """ - keyvalues = { - "medium": medium, - "client_secret": client_secret, - } + keyvalues = {"medium": medium, "client_secret": client_secret} if address: keyvalues["address"] = address if sid: keyvalues["session_id"] = sid - assert(address or sid) + assert address or sid def get_threepid_validation_session_txn(txn): sql = """ SELECT address, session_id, medium, client_secret, last_send_attempt, validated_at FROM threepid_validation_session WHERE %s - """ % (" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),) + """ % ( + " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)), + ) if validated is not None: sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") @@ -1132,17 +1124,10 @@ def get_threepid_validation_session_txn(txn): return rows[0] return self.runInteraction( - "get_threepid_validation_session", - get_threepid_validation_session_txn, + "get_threepid_validation_session", get_threepid_validation_session_txn ) - def validate_threepid_session( - self, - session_id, - client_secret, - token, - current_ts, - ): + def validate_threepid_session(self, session_id, client_secret, token, current_ts): """Attempt to validate a threepid session using a token Args: @@ -1174,7 +1159,7 @@ def validate_threepid_session_txn(txn): if retrieved_client_secret != client_secret: raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id", + 400, "This client_secret does not match the provided session_id" ) row = self._simple_select_one_txn( @@ -1187,7 +1172,7 @@ def validate_threepid_session_txn(txn): if not row: raise ThreepidValidationError( - 400, "Validation token not found or has expired", + 400, "Validation token not found or has expired" ) expires = row["expires"] next_link = row["next_link"] @@ -1198,7 +1183,7 @@ def validate_threepid_session_txn(txn): if expires <= current_ts: raise ThreepidValidationError( - 400, "This token has expired. Please request a new one", + 400, "This token has expired. Please request a new one" ) # Looks good. Validate the session @@ -1213,8 +1198,7 @@ def validate_threepid_session_txn(txn): # Return next_link if it exists return self.runInteraction( - "validate_threepid_session_txn", - validate_threepid_session_txn, + "validate_threepid_session_txn", validate_threepid_session_txn ) def upsert_threepid_validation_session( @@ -1281,6 +1265,7 @@ def start_or_continue_validation_session( token_expires (int): The timestamp for which after the token will no longer be valid """ + def start_or_continue_validation_session_txn(txn): # Create or update a validation session self._simple_upsert_txn( @@ -1314,6 +1299,7 @@ def start_or_continue_validation_session_txn(txn): def cull_expired_threepid_validation_tokens(self): """Remove threepid validation tokens with expiry dates that have passed""" + def cull_expired_threepid_validation_tokens_txn(txn, ts): sql = """ DELETE FROM threepid_validation_token WHERE @@ -1335,6 +1321,7 @@ def delete_threepid_session(self, session_id): Args: session_id (str): The ID of the session to delete """ + def delete_threepid_session_txn(txn): self._simple_delete_txn( txn, @@ -1348,8 +1335,7 @@ def delete_threepid_session_txn(txn): ) return self.runInteraction( - "delete_threepid_session", - delete_threepid_session_txn, + "delete_threepid_session", delete_threepid_session_txn ) def set_user_deactivated_status_txn(self, txn, user_id, deactivated): @@ -1360,7 +1346,7 @@ def set_user_deactivated_status_txn(self, txn, user_id, deactivated): updatevalues={"deactivated": 1 if deactivated else 0}, ) self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,), + txn, self.get_user_deactivated_status, (user_id,) ) @defer.inlineCallbacks @@ -1375,7 +1361,8 @@ def set_user_deactivated_status(self, user_id, deactivated): yield self.runInteraction( "set_user_deactivated_status", self.set_user_deactivated_status_txn, - user_id, deactivated, + user_id, + deactivated, ) @cachedInlineCallbacks() diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 4c83800ccaf7..1b01934c19e8 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -468,9 +468,5 @@ def _handle_redaction(self, txn, redacted_event_id): """ self._simple_delete_txn( - txn, - table="event_relations", - keyvalues={ - "event_id": redacted_event_id, - } + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 761791332674..8004aeb90974 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -420,7 +420,7 @@ def _get_joined_users_from_context( table="room_memberships", column="event_id", iterable=missing_member_event_ids, - retcols=('user_id', 'display_name', 'avatar_url'), + retcols=("user_id", "display_name", "avatar_url"), keyvalues={"membership": Membership.JOIN}, batch_size=500, desc="_get_joined_users_from_context", @@ -448,7 +448,7 @@ def _get_joined_users_from_context( @cachedInlineCallbacks(max_entries=10000) def is_host_joined(self, room_id, host): - if '%' in host or '_' in host: + if "%" in host or "_" in host: raise Exception("Invalid host name") sql = """ @@ -490,7 +490,7 @@ def was_host_joined(self, room_id, host): Deferred: Resolves to True if the host is/was in the room, otherwise False. """ - if '%' in host or '_' in host: + if "%" in host or "_" in host: raise Exception("Invalid host name") sql = """ @@ -723,7 +723,7 @@ def add_membership_profile_txn(txn): room_id = row["room_id"] try: event_json = json.loads(row["json"]) - content = event_json['content'] + content = event_json["content"] except Exception: continue diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py index 147496a38b74..3edfcfd78320 100644 --- a/synapse/storage/schema/delta/20/pushers.py +++ b/synapse/storage/schema/delta/20/pushers.py @@ -29,7 +29,8 @@ def run_create(cur, database_engine, *args, **kwargs): logger.info("Porting pushers table...") - cur.execute(""" + cur.execute( + """ CREATE TABLE IF NOT EXISTS pushers2 ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, @@ -48,27 +49,34 @@ def run_create(cur, database_engine, *args, **kwargs): failing_since BIGINT, UNIQUE (app_id, pushkey, user_name) ) - """) - cur.execute("""SELECT + """ + ) + cur.execute( + """SELECT id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers - """) + """ + ) count = 0 for row in cur.fetchall(): row = list(row) row[8] = bytes(row[8]).decode("utf-8") row[11] = bytes(row[11]).decode("utf-8") - cur.execute(database_engine.convert_param_style(""" + cur.execute( + database_engine.convert_param_style( + """ INSERT into pushers2 ( id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since - ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))), - row + ) values (%s)""" + % (",".join(["?" for _ in range(len(row))])) + ), + row, ) count += 1 cur.execute("DROP TABLE pushers") diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index ef7ec34346fd..9b95411fb627 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -40,9 +40,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): logger.warning("Could not get app_service_config_files from config") pass - appservices = load_appservices( - config.server_name, config_files - ) + appservices = load_appservices(config.server_name, config_files) owned = {} @@ -53,20 +51,19 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): if user_id in owned.keys(): logger.error( "user_id %s was owned by more than one application" - " service (IDs %s and %s); assigning arbitrarily to %s" % - (user_id, owned[user_id], appservice.id, owned[user_id]) + " service (IDs %s and %s); assigning arbitrarily to %s" + % (user_id, owned[user_id], appservice.id, owned[user_id]) ) owned.setdefault(appservice.id, []).append(user_id) for as_id, user_ids in owned.items(): n = 100 - user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n)) + user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) for chunk in user_chunks: cur.execute( database_engine.convert_param_style( - "UPDATE users SET appservice_id = ? WHERE name IN (%s)" % ( - ",".join("?" for _ in chunk), - ) + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" + % (",".join("?" for _ in chunk),) ), - [as_id] + chunk + [as_id] + chunk, ) diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/schema/delta/31/pushers.py index 93367fa09e33..9bb504aad5fa 100644 --- a/synapse/storage/schema/delta/31/pushers.py +++ b/synapse/storage/schema/delta/31/pushers.py @@ -24,12 +24,13 @@ def token_to_stream_ordering(token): - return int(token[1:].split('_')[0]) + return int(token[1:].split("_")[0]) def run_create(cur, database_engine, *args, **kwargs): logger.info("Porting pushers table, delta 31...") - cur.execute(""" + cur.execute( + """ CREATE TABLE IF NOT EXISTS pushers2 ( id BIGINT PRIMARY KEY, user_name TEXT NOT NULL, @@ -48,26 +49,33 @@ def run_create(cur, database_engine, *args, **kwargs): failing_since BIGINT, UNIQUE (app_id, pushkey, user_name) ) - """) - cur.execute("""SELECT + """ + ) + cur.execute( + """SELECT id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers - """) + """ + ) count = 0 for row in cur.fetchall(): row = list(row) row[12] = token_to_stream_ordering(row[12]) - cur.execute(database_engine.convert_param_style(""" + cur.execute( + database_engine.convert_param_style( + """ INSERT into pushers2 ( id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_stream_ordering, last_success, failing_since - ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))), - row + ) values (%s)""" + % (",".join(["?" for _ in range(len(row))])) + ), + row, ) count += 1 cur.execute("DROP TABLE pushers") diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py index 9754d3ccfb16..a26057dfb6ef 100644 --- a/synapse/storage/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/schema/delta/33/remote_media_ts.py @@ -26,5 +26,5 @@ def run_upgrade(cur, database_engine, *args, **kwargs): database_engine.convert_param_style( "UPDATE remote_media_cache SET last_access_ts = ?" ), - (int(time.time() * 1000),) + (int(time.time() * 1000),), ) diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py index f6766501d253..9fd1ccf6f792 100644 --- a/synapse/storage/schema/delta/47/state_group_seq.py +++ b/synapse/storage/schema/delta/47/state_group_seq.py @@ -27,10 +27,7 @@ def run_create(cur, database_engine, *args, **kwargs): else: start_val = row[0] + 1 - cur.execute( - "CREATE SEQUENCE state_group_id_seq START WITH %s", - (start_val, ), - ) + cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,)) def run_upgrade(*args, **kwargs): diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py index 2233af87d770..49f5f2c00324 100644 --- a/synapse/storage/schema/delta/48/group_unique_indexes.py +++ b/synapse/storage/schema/delta/48/group_unique_indexes.py @@ -38,16 +38,22 @@ def run_create(cur, database_engine, *args, **kwargs): rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" # remove duplicates from group_users & group_invites tables - cur.execute(""" + cur.execute( + """ DELETE FROM group_users WHERE %s NOT IN ( SELECT min(%s) FROM group_users GROUP BY group_id, user_id ); - """ % (rowid, rowid)) - cur.execute(""" + """ + % (rowid, rowid) + ) + cur.execute( + """ DELETE FROM group_invites WHERE %s NOT IN ( SELECT min(%s) FROM group_invites GROUP BY group_id, user_id ); - """ % (rowid, rowid)) + """ + % (rowid, rowid) + ) for statement in get_statements(FIX_INDEXES.splitlines()): cur.execute(statement) diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/schema/delta/50/make_event_content_nullable.py index 6dd467b6c533..b1684a8441dc 100644 --- a/synapse/storage/schema/delta/50/make_event_content_nullable.py +++ b/synapse/storage/schema/delta/50/make_event_content_nullable.py @@ -65,14 +65,18 @@ def run_create(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs): if isinstance(database_engine, PostgresEngine): - cur.execute(""" + cur.execute( + """ ALTER TABLE events ALTER COLUMN content DROP NOT NULL; - """) + """ + ) return # sqlite is an arse about this. ref: https://www.sqlite.org/lang_altertable.html - cur.execute("SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'") + cur.execute( + "SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'" + ) (oldsql,) = cur.fetchone() sql = oldsql.replace("content TEXT NOT NULL", "content TEXT") @@ -86,7 +90,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute("PRAGMA writable_schema=ON") cur.execute( "UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'", - (sql, ), + (sql,), ) cur.execute("PRAGMA schema_version=%i" % (oldver + 1,)) cur.execute("PRAGMA writable_schema=OFF") diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 10a27c207a2f..f3b1cec93363 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -31,8 +31,8 @@ logger = logging.getLogger(__name__) SearchEntry = namedtuple( - 'SearchEntry', - ['key', 'value', 'event_id', 'room_id', 'stream_ordering', 'origin_server_ts'], + "SearchEntry", + ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"], ) @@ -216,7 +216,7 @@ def _background_reindex_search_order(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - have_added_index = progress['have_added_indexes'] + have_added_index = progress["have_added_indexes"] if not have_added_index: diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index ff266b09b03b..1cec84ee2eeb 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -71,7 +71,8 @@ def _populate_stats_createtables(self, progress, batch_size): # Get all the rooms that we want to process. def _make_staging_area(txn): # Create the temporary tables - stmts = get_statements(""" + stmts = get_statements( + """ -- We just recreate the table, we'll be reinserting the -- correct entries again later anyway. DROP TABLE IF EXISTS {temp}_rooms; @@ -85,7 +86,10 @@ def _make_staging_area(txn): ON {temp}_rooms(events); CREATE INDEX {temp}_rooms_id ON {temp}_rooms(room_id); - """.format(temp=TEMP_TABLE).splitlines()) + """.format( + temp=TEMP_TABLE + ).splitlines() + ) for statement in stmts: txn.execute(statement) @@ -105,7 +109,9 @@ def _make_staging_area(txn): LEFT JOIN room_stats_earliest_token AS t USING (room_id) WHERE t.room_id IS NULL GROUP BY c.room_id - """ % (TEMP_TABLE,) + """ % ( + TEMP_TABLE, + ) txn.execute(sql) new_pos = yield self.get_max_stream_id_in_current_state_deltas() @@ -184,7 +190,8 @@ def _get_next_batch(txn): logger.info( "Processing the next %d rooms of %d remaining", - len(rooms_to_work_on), progress["remaining"], + len(rooms_to_work_on), + progress["remaining"], ) # Number of state events we've processed by going through each room @@ -204,10 +211,17 @@ def _get_next_batch(txn): avatar_id = current_state_ids.get((EventTypes.RoomAvatar, "")) canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, "")) - state_events = yield self.get_events([ - join_rules_id, history_visibility_id, encryption_id, name_id, - topic_id, avatar_id, canonical_alias_id, - ]) + state_events = yield self.get_events( + [ + join_rules_id, + history_visibility_id, + encryption_id, + name_id, + topic_id, + avatar_id, + canonical_alias_id, + ] + ) def _get_or_none(event_id, arg): event = state_events.get(event_id) @@ -271,7 +285,7 @@ def _fetch_data(txn): # We've finished a room. Delete it from the table. self._simple_delete_one_txn( - txn, TEMP_TABLE + "_rooms", {"room_id": room_id}, + txn, TEMP_TABLE + "_rooms", {"room_id": room_id} ) yield self.runInteraction("update_room_stats", _fetch_data) @@ -338,7 +352,7 @@ def update_room_state(self, room_id, fields): "name", "topic", "avatar", - "canonical_alias" + "canonical_alias", ): field = fields.get(col) if field and "\0" in field: diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 6f7f65d96ba3..d9482a384339 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -65,7 +65,7 @@ def generate_pagination_where_clause( - direction, column_names, from_token, to_token, engine, + direction, column_names, from_token, to_token, engine ): """Creates an SQL expression to bound the columns by the pagination tokens. @@ -153,7 +153,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine): str """ - assert(bound in (">", "<", ">=", "<=")) + assert bound in (">", "<", ">=", "<=") name1, name2 = column_names val1, val2 = values @@ -169,11 +169,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine): # Postgres doesn't optimise ``(x < a) OR (x=a AND y" % ( - id(self), self._result, self._deferred, + id(self), + self._result, + self._deferred, ) @@ -150,10 +153,12 @@ def _concurrently_execute_inner(): except StopIteration: pass - return logcontext.make_deferred_yieldable(defer.gatherResults([ - run_in_background(_concurrently_execute_inner) - for _ in range(limit) - ], consumeErrors=True)).addErrback(unwrapFirstError) + return logcontext.make_deferred_yieldable( + defer.gatherResults( + [run_in_background(_concurrently_execute_inner) for _ in range(limit)], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) def yieldable_gather_results(func, iter, *args, **kwargs): @@ -169,10 +174,12 @@ def yieldable_gather_results(func, iter, *args, **kwargs): Deferred[list]: Resolved when all functions have been invoked, or errors if one of the function calls fails. """ - return logcontext.make_deferred_yieldable(defer.gatherResults([ - run_in_background(func, item, *args, **kwargs) - for item in iter - ], consumeErrors=True)).addErrback(unwrapFirstError) + return logcontext.make_deferred_yieldable( + defer.gatherResults( + [run_in_background(func, item, *args, **kwargs) for item in iter], + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) class Linearizer(object): @@ -185,6 +192,7 @@ class Linearizer(object): # do some work. """ + def __init__(self, name=None, max_count=1, clock=None): """ Args: @@ -197,6 +205,7 @@ def __init__(self, name=None, max_count=1, clock=None): if not clock: from twisted.internet import reactor + clock = Clock(reactor) self._clock = clock self.max_count = max_count @@ -221,7 +230,7 @@ def queue(self, key): res = self._await_lock(key) else: logger.debug( - "Acquired uncontended linearizer lock %r for key %r", self.name, key, + "Acquired uncontended linearizer lock %r for key %r", self.name, key ) entry[0] += 1 res = defer.succeed(None) @@ -266,9 +275,7 @@ def _await_lock(self, key): """ entry = self.key_to_defer[key] - logger.debug( - "Waiting to acquire linearizer lock %r for key %r", self.name, key, - ) + logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) new_defer = make_deferred_yieldable(defer.Deferred()) entry[1][new_defer] = 1 @@ -293,14 +300,14 @@ def eb(e): logger.info("defer %r got err %r", new_defer, e) if isinstance(e, CancelledError): logger.debug( - "Cancelling wait for linearizer lock %r for key %r", - self.name, key, + "Cancelling wait for linearizer lock %r for key %r", self.name, key ) else: logger.warn( "Unexpected exception waiting for linearizer lock %r for key %r", - self.name, key, + self.name, + key, ) # we just have to take ourselves back out of the queue. @@ -438,7 +445,7 @@ def time_it_out(): try: deferred.cancel() - except: # noqa: E722, if we throw any exception it'll break time outs + except: # noqa: E722, if we throw any exception it'll break time outs logger.exception("Canceller failed during timeout") if not new_d.called: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index f37d5bec0864..8271229015b3 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -104,8 +104,8 @@ def collect(self): KNOWN_KEYS = { - key: key for key in - ( + key: key + for key in ( "auth_events", "content", "depth", @@ -150,7 +150,7 @@ def intern_dict(dictionary): def _intern_known_values(key, value): - intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",) + intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key") if key in intern_keys: return intern_string(value) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 187510576a36..d2f25063aaa8 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -40,9 +40,7 @@ class CacheEntry(object): - __slots__ = [ - "deferred", "callbacks", "invalidated" - ] + __slots__ = ["deferred", "callbacks", "invalidated"] def __init__(self, deferred, callbacks): self.deferred = deferred @@ -73,7 +71,9 @@ def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False) self._pending_deferred_cache = cache_type() self.cache = LruCache( - max_size=max_entries, keylen=keylen, cache_type=cache_type, + max_size=max_entries, + keylen=keylen, + cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, evicted_callback=self._on_evicted, ) @@ -133,10 +133,7 @@ def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True): def set(self, key, value, callback=None): callbacks = [callback] if callback else [] self.check_thread() - entry = CacheEntry( - deferred=value, - callbacks=callbacks, - ) + entry = CacheEntry(deferred=value, callbacks=callbacks) existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: @@ -191,9 +188,7 @@ def invalidate(self, key): def invalidate_many(self, key): self.check_thread() if not isinstance(key, tuple): - raise TypeError( - "The cache key must be a tuple not %r" % (type(key),) - ) + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the @@ -244,29 +239,25 @@ def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): raise Exception( "Not enough explicit positional arguments to key off for %r: " "got %i args, but wanted %i. (@cached cannot key off *args or " - "**kwargs)" - % (orig.__name__, len(all_args), num_args) + "**kwargs)" % (orig.__name__, len(all_args), num_args) ) self.num_args = num_args # list of the names of the args used as the cache key - self.arg_names = all_args[1:num_args + 1] + self.arg_names = all_args[1 : num_args + 1] # self.arg_defaults is a map of arg name to its default value for each # argument that has a default value if arg_spec.defaults: - self.arg_defaults = dict(zip( - all_args[-len(arg_spec.defaults):], - arg_spec.defaults - )) + self.arg_defaults = dict( + zip(all_args[-len(arg_spec.defaults) :], arg_spec.defaults) + ) else: self.arg_defaults = {} if "cache_context" in self.arg_names: - raise Exception( - "cache_context arg cannot be included among the cache keys" - ) + raise Exception("cache_context arg cannot be included among the cache keys") self.add_cache_context = cache_context @@ -304,12 +295,24 @@ def foo(self, key, cache_context): ``cache_context``) to use as cache keys. Defaults to all named args of the function. """ - def __init__(self, orig, max_entries=1000, num_args=None, tree=False, - inlineCallbacks=False, cache_context=False, iterable=False): + + def __init__( + self, + orig, + max_entries=1000, + num_args=None, + tree=False, + inlineCallbacks=False, + cache_context=False, + iterable=False, + ): super(CacheDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks, - cache_context=cache_context) + orig, + num_args=num_args, + inlineCallbacks=inlineCallbacks, + cache_context=cache_context, + ) max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) @@ -356,7 +359,9 @@ def get_cache_key(args, kwargs): return args[0] else: return self.arg_defaults[nm] + else: + def get_cache_key(args, kwargs): return tuple(get_cache_key_gen(args, kwargs)) @@ -383,8 +388,7 @@ def wrapped(*args, **kwargs): except KeyError: ret = defer.maybeDeferred( - logcontext.preserve_fn(self.function_to_call), - obj, *args, **kwargs + logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs ) def onErr(f): @@ -437,8 +441,9 @@ class CacheListDescriptor(_CacheDescriptorBase): results. """ - def __init__(self, orig, cached_method_name, list_name, num_args=None, - inlineCallbacks=False): + def __init__( + self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False + ): """ Args: orig (function) @@ -451,7 +456,8 @@ def __init__(self, orig, cached_method_name, list_name, num_args=None, be wrapped by defer.inlineCallbacks """ super(CacheListDescriptor, self).__init__( - orig, num_args=num_args, inlineCallbacks=inlineCallbacks) + orig, num_args=num_args, inlineCallbacks=inlineCallbacks + ) self.list_name = list_name @@ -463,7 +469,7 @@ def __init__(self, orig, cached_method_name, list_name, num_args=None, if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cached_method_name,) + % (self.list_name, cached_method_name) ) def __get__(self, obj, objtype=None): @@ -494,8 +500,10 @@ def update_results_dict(res, arg): # If the cache takes a single arg then that is used as the key, # otherwise a tuple is used. if num_args == 1: + def arg_to_cache_key(arg): return arg + else: keylist = list(keyargs) @@ -505,8 +513,7 @@ def arg_to_cache_key(arg): for arg in list_args: try: - res = cache.get(arg_to_cache_key(arg), - callback=invalidate_callback) + res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) if not isinstance(res, ObservableDeferred): results[arg] = res elif not res.has_succeeded(): @@ -554,18 +561,15 @@ def errback(f): args_to_call = dict(arg_dict) args_to_call[self.list_name] = list(missing) - cached_defers.append(defer.maybeDeferred( - logcontext.preserve_fn(self.function_to_call), - **args_to_call - ).addCallbacks(complete_all, errback)) + cached_defers.append( + defer.maybeDeferred( + logcontext.preserve_fn(self.function_to_call), **args_to_call + ).addCallbacks(complete_all, errback) + ) if cached_defers: - d = defer.gatherResults( - cached_defers, - consumeErrors=True, - ).addCallbacks( - lambda _: results, - unwrapFirstError + d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( + lambda _: results, unwrapFirstError ) return logcontext.make_deferred_yieldable(d) else: @@ -586,8 +590,9 @@ def invalidate(self): self.cache.invalidate(self.key) -def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, - iterable=False): +def cached( + max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False +): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, @@ -598,8 +603,9 @@ def cached(max_entries=1000, num_args=None, tree=False, cache_context=False, ) -def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False, - cache_context=False, iterable=False): +def cachedInlineCallbacks( + max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False +): return lambda orig: CacheDescriptor( orig, max_entries=max_entries, diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 6c0b5a409463..6834e6f3ae7a 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -35,6 +35,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va there. value (dict): The full or partial dict value """ + def __len__(self): return len(self.value) @@ -84,13 +85,15 @@ def get(self, key, dict_keys=None): self.metrics.inc_hits() if dict_keys is None: - return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value)) + return DictionaryEntry( + entry.full, entry.known_absent, dict(entry.value) + ) else: - return DictionaryEntry(entry.full, entry.known_absent, { - k: entry.value[k] - for k in dict_keys - if k in entry.value - }) + return DictionaryEntry( + entry.full, + entry.known_absent, + {k: entry.value[k] for k in dict_keys if k in entry.value}, + ) self.metrics.inc_misses() return DictionaryEntry(False, set(), {}) diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index f36978027748..cddf1ed51521 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -28,8 +28,15 @@ class ExpiringCache(object): - def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, - reset_expiry_on_get=False, iterable=False): + def __init__( + self, + cache_name, + clock, + max_len=0, + expiry_ms=0, + reset_expiry_on_get=False, + iterable=False, + ): """ Args: cache_name (str): Name of this cache, used for logging. @@ -67,8 +74,7 @@ def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, def f(): return run_as_background_process( - "prune_cache_%s" % self._cache_name, - self._prune_cache, + "prune_cache_%s" % self._cache_name, self._prune_cache ) self._clock.looping_call(f, self._expiry_ms / 2) @@ -153,7 +159,9 @@ def _prune_cache(self): logger.debug( "[%s] _prune_cache before: %d, after len: %d", - self._cache_name, begin_length, len(self) + self._cache_name, + begin_length, + len(self), ) def __len__(self): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index b684f24e7b3a..1536cb64f3bd 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -49,8 +49,15 @@ class LruCache(object): Can also set callbacks on objects when getting/setting which are fired when that key gets invalidated/evicted. """ - def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None, - evicted_callback=None): + + def __init__( + self, + max_size, + keylen=1, + cache_type=dict, + size_callback=None, + evicted_callback=None, + ): """ Args: max_size (int): @@ -93,9 +100,12 @@ def inner(*args, **kwargs): cached_cache_len = [0] if size_callback is not None: + def cache_len(): return cached_cache_len[0] + else: + def cache_len(): return len(cache) diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index afb03b2e1b7c..b1da81633cd7 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -35,12 +35,10 @@ def __init__(self, hs, name, timeout_ms=0): self.pending_result_cache = {} # Requests that haven't finished yet. self.clock = hs.get_clock() - self.timeout_sec = timeout_ms / 1000. + self.timeout_sec = timeout_ms / 1000.0 self._name = name - self._metrics = register_cache( - "response_cache", name, self - ) + self._metrics = register_cache("response_cache", name, self) def size(self): return len(self.pending_result_cache) @@ -100,8 +98,7 @@ def set(self, key, deferred): def remove(r): if self.timeout_sec: self.clock.call_later( - self.timeout_sec, - self.pending_result_cache.pop, key, None, + self.timeout_sec, self.pending_result_cache.pop, key, None ) else: self.pending_result_cache.pop(key, None) @@ -147,14 +144,15 @@ def handle_request(request): """ result = self.get(key) if not result: - logger.info("[%s]: no cached result for [%s], calculating new one", - self._name, key) + logger.info( + "[%s]: no cached result for [%s], calculating new one", self._name, key + ) d = run_in_background(callback, *args, **kwargs) result = self.set(key, d) elif not isinstance(result, defer.Deferred) or result.called: - logger.info("[%s]: using completed cached result for [%s]", - self._name, key) + logger.info("[%s]: using completed cached result for [%s]", self._name, key) else: - logger.info("[%s]: using incomplete cached result for [%s]", - self._name, key) + logger.info( + "[%s]: using incomplete cached result for [%s]", self._name, key + ) return make_deferred_yieldable(result) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 625aedc940a9..235f64049c95 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -77,9 +77,8 @@ def get_entities_changed(self, entities, stream_pos): if stream_pos >= self._earliest_known_stream_pos: changed_entities = { - self._cache[k] for k in self._cache.islice( - start=self._cache.bisect_right(stream_pos), - ) + self._cache[k] + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) } result = changed_entities.intersection(entities) @@ -114,8 +113,10 @@ def get_all_entities_changed(self, stream_pos): assert type(stream_pos) is int if stream_pos >= self._earliest_known_stream_pos: - return [self._cache[k] for k in self._cache.islice( - start=self._cache.bisect_right(stream_pos))] + return [ + self._cache[k] + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) + ] else: return None @@ -136,7 +137,7 @@ def entity_has_changed(self, entity, stream_pos): while len(self._cache) > self._max_size: k, r = self._cache.popitem(0) self._earliest_known_stream_pos = max( - k, self._earliest_known_stream_pos, + k, self._earliest_known_stream_pos ) self._entity_to_key.pop(r, None) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index dd4c9e60677e..9a72218d85ac 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -9,6 +9,7 @@ class TreeCache(object): efficiently. Keys must be tuples. """ + def __init__(self): self.size = 0 self.root = {} diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 5ba186250667..2af8ca43b123 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -155,6 +155,7 @@ def expire(self): @attr.s(frozen=True, slots=True) class _CacheEntry(object): """TTLCache entry""" + # expiry_time is the first attribute, so that entries are sorted by expiry. expiry_time = attr.ib() key = attr.ib() diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index e14c8bdfda43..5a79db821c98 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -51,9 +51,7 @@ def declare(self, name): if name in self.signals: raise KeyError("%r already has a signal named %s" % (self, name)) - self.signals[name] = Signal( - name, - ) + self.signals[name] = Signal(name) if name in self.pre_registration: signal = self.signals[name] @@ -78,11 +76,7 @@ def fire(self, name, *args, **kwargs): if name not in self.signals: raise KeyError("%r does not have a signal named %s" % (self, name)) - run_as_background_process( - name, - self.signals[name].fire, - *args, **kwargs - ) + run_as_background_process(name, self.signals[name].fire, *args, **kwargs) class Signal(object): @@ -118,22 +112,23 @@ def do(observer): def eb(failure): logger.warning( "%s signal observer %s failed: %r", - self.name, observer, failure, + self.name, + observer, + failure, exc_info=( failure.type, failure.value, - failure.getTracebackObject())) + failure.getTracebackObject(), + ), + ) return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) - deferreds = [ - run_in_background(do, o) - for o in self.observers - ] + deferreds = [run_in_background(do, o) for o in self.observers] - return make_deferred_yieldable(defer.gatherResults( - deferreds, consumeErrors=True, - )) + return make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True) + ) def __repr__(self): return "" % (self.name,) diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 014edea971a7..635b897d6c1a 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -60,11 +60,10 @@ def _handle_frozendict(obj): # fishing the protected dict out of the object is a bit nasty, # but we don't really want the overhead of copying the dict. return obj._dict - raise TypeError('Object of type %s is not JSON serializable' % - obj.__class__.__name__) + raise TypeError( + "Object of type %s is not JSON serializable" % obj.__class__.__name__ + ) # A JSONEncoder which is capable of encoding frozendics without barfing -frozendict_json_encoder = json.JSONEncoder( - default=_handle_frozendict, -) +frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict) diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py index 2d7ddc1cbef1..1a20c596bf12 100644 --- a/synapse/util/httpresourcetree.py +++ b/synapse/util/httpresourcetree.py @@ -45,7 +45,7 @@ def create_resource_tree(desired_tree, root_resource): logger.info("Attaching %s to path %s", res, full_path) last_resource = root_resource - for path_seg in full_path.split(b'/')[1:-1]: + for path_seg in full_path.split(b"/")[1:-1]: if path_seg not in last_resource.listNames(): # resource doesn't exist, so make a "dummy resource" child_resource = NoResource() @@ -60,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource): # =========================== # now attach the actual desired resource - last_path_seg = full_path.split(b'/')[-1] + last_path_seg = full_path.split(b"/")[-1] # if there is already a resource here, thieve its children and # replace it @@ -70,9 +70,7 @@ def create_resource_tree(desired_tree, root_resource): # to be replaced with the desired resource. existing_dummy_resource = resource_mappings[res_id] for child_name in existing_dummy_resource.listNames(): - child_res_id = _resource_id( - existing_dummy_resource, child_name - ) + child_res_id = _resource_id(existing_dummy_resource, child_name) child_resource = resource_mappings[child_res_id] # steal the children res.putChild(child_name, child_resource) diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py index d668e5a6b8cb..6dce03dd3ac8 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py @@ -70,7 +70,8 @@ def get_dict(self): dict """ d = { - k: _encode(v) for (k, v) in self.__dict__.items() + k: _encode(v) + for (k, v) in self.__dict__.items() if k in self.valid_keys and k not in self.internal_keys } d.update(self.unrecognized_keys) @@ -78,7 +79,8 @@ def get_dict(self): def get_internal_dict(self): d = { - k: _encode(v, internal=True) for (k, v) in self.__dict__.items() + k: _encode(v, internal=True) + for (k, v) in self.__dict__.items() if k in self.valid_keys } d.update(self.unrecognized_keys) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index fe412355d858..a9885cb5078b 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -42,6 +42,8 @@ def get_thread_resource_usage(): return resource.getrusage(RUSAGE_THREAD) + + except Exception: # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we # won't track resource usage by returning None. @@ -64,8 +66,11 @@ class ContextResourceUsage(object): """ __slots__ = [ - "ru_stime", "ru_utime", - "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec", + "ru_stime", + "ru_utime", + "db_txn_count", + "db_txn_duration_sec", + "db_sched_duration_sec", "evt_db_fetch_count", ] @@ -91,8 +96,8 @@ def copy(self): return ContextResourceUsage(copy_from=self) def reset(self): - self.ru_stime = 0. - self.ru_utime = 0. + self.ru_stime = 0.0 + self.ru_utime = 0.0 self.db_txn_count = 0 self.db_txn_duration_sec = 0 @@ -100,15 +105,18 @@ def reset(self): self.evt_db_fetch_count = 0 def __repr__(self): - return ("") % ( - self.ru_stime, - self.ru_utime, - self.db_txn_count, - self.db_txn_duration_sec, - self.db_sched_duration_sec, - self.evt_db_fetch_count,) + return ( + "" + ) % ( + self.ru_stime, + self.ru_utime, + self.db_txn_count, + self.db_txn_duration_sec, + self.db_sched_duration_sec, + self.evt_db_fetch_count, + ) def __iadd__(self, other): """Add another ContextResourceUsage's stats to this one's. @@ -159,11 +167,15 @@ class LoggingContext(object): """ __slots__ = [ - "previous_context", "name", "parent_context", + "previous_context", + "name", + "parent_context", "_resource_usage", "usage_start", - "main_thread", "alive", - "request", "tag", + "main_thread", + "alive", + "request", + "tag", ] thread_local = threading.local() @@ -196,6 +208,7 @@ def record_event_fetch(self, event_count): def __nonzero__(self): return False + __bool__ = __nonzero__ # python3 sentinel = Sentinel() @@ -261,7 +274,8 @@ def __enter__(self): if self.previous_context != old_context: logger.warn( "Expected previous context %r, found %r", - self.previous_context, old_context + self.previous_context, + old_context, ) self.alive = True @@ -285,9 +299,8 @@ def __exit__(self, type, value, traceback): self.alive = False # if we have a parent, pass our CPU usage stats on - if ( - self.parent_context is not None - and hasattr(self.parent_context, '_resource_usage') + if self.parent_context is not None and hasattr( + self.parent_context, "_resource_usage" ): self.parent_context._resource_usage += self._resource_usage @@ -320,9 +333,7 @@ def stop(self): # When we stop, let's record the cpu used since we started if not self.usage_start: - logger.warning( - "Called stop on logcontext %s without calling start", self, - ) + logger.warning("Called stop on logcontext %s without calling start", self) return usage_end = get_thread_resource_usage() @@ -381,6 +392,7 @@ class LoggingContextFilter(logging.Filter): **defaults: Default values to avoid formatters complaining about missing fields """ + def __init__(self, **defaults): self.defaults = defaults @@ -416,17 +428,12 @@ def __init__(self, new_context=None): def __enter__(self): """Captures the current logging context""" - self.current_context = LoggingContext.set_current_context( - self.new_context - ) + self.current_context = LoggingContext.set_current_context(self.new_context) if self.current_context: self.has_parent = self.current_context.previous_context is not None if not self.current_context.alive: - logger.debug( - "Entering dead context: %s", - self.current_context, - ) + logger.debug("Entering dead context: %s", self.current_context) def __exit__(self, type, value, traceback): """Restores the current logging context""" @@ -444,10 +451,7 @@ def __exit__(self, type, value, traceback): if self.current_context is not LoggingContext.sentinel: if not self.current_context.alive: - logger.debug( - "Restoring dead context: %s", - self.current_context, - ) + logger.debug("Restoring dead context: %s", self.current_context) def nested_logging_context(suffix, parent_context=None): @@ -474,15 +478,16 @@ def nested_logging_context(suffix, parent_context=None): if parent_context is None: parent_context = LoggingContext.current_context() return LoggingContext( - parent_context=parent_context, - request=parent_context.request + "-" + suffix, + parent_context=parent_context, request=parent_context.request + "-" + suffix ) def preserve_fn(f): """Function decorator which wraps the function with run_in_background""" + def g(*args, **kwargs): return run_in_background(f, *args, **kwargs) + return g @@ -502,7 +507,7 @@ def run_in_background(f, *args, **kwargs): current = LoggingContext.current_context() try: res = f(*args, **kwargs) - except: # noqa: E722 + except: # noqa: E722 # the assumption here is that the caller doesn't want to be disturbed # by synchronous exceptions, so let's turn them into Failures. return defer.fail() @@ -639,6 +644,4 @@ def g(): with LoggingContext(parent_context=logcontext): return f(*args, **kwargs) - return make_deferred_yieldable( - threads.deferToThreadPool(reactor, threadpool, g) - ) + return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g)) diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py index a46bc47ce3d3..fbf570c756d8 100644 --- a/synapse/util/logformatter.py +++ b/synapse/util/logformatter.py @@ -29,6 +29,7 @@ class LogFormatter(logging.Formatter): (Normally only stack frames between the point the exception was raised and where it was caught are logged). """ + def __init__(self, *args, **kwargs): super(LogFormatter, self).__init__(*args, **kwargs) @@ -40,7 +41,7 @@ def formatException(self, ei): # check that we actually have an f_back attribute to work around # https://twistedmatrix.com/trac/ticket/9305 - if tb and hasattr(tb.tb_frame, 'f_back'): + if tb and hasattr(tb.tb_frame, "f_back"): sio.write("Capture point (most recent call last):\n") traceback.print_stack(tb.tb_frame.f_back, None, sio) diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py index ef31458226b3..7df0fa6087bf 100644 --- a/synapse/util/logutils.py +++ b/synapse/util/logutils.py @@ -44,7 +44,7 @@ def _log_debug_as_f(f, msg, msg_args): lineno=lineno, msg=msg, args=msg_args, - exc_info=None + exc_info=None, ) logger.handle(record) @@ -70,20 +70,11 @@ def format(value): r = r[:50] + "..." return r - func_args = [ - "%s=%s" % (k, format(v)) for k, v in bound_args.items() - ] + func_args = ["%s=%s" % (k, format(v)) for k, v in bound_args.items()] - msg_args = { - "func_name": func_name, - "args": ", ".join(func_args) - } + msg_args = {"func_name": func_name, "args": ", ".join(func_args)} - _log_debug_as_f( - f, - "Invoked '%(func_name)s' with args: %(args)s", - msg_args - ) + _log_debug_as_f(f, "Invoked '%(func_name)s' with args: %(args)s", msg_args) return f(*args, **kwargs) @@ -103,19 +94,13 @@ def wrapped(*args, **kwargs): start = time.clock() try: - _log_debug_as_f( - f, - "[FUNC START] {%s-%d}", - (func_name, id), - ) + _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id)) r = f(*args, **kwargs) finally: end = time.clock() _log_debug_as_f( - f, - "[FUNC END] {%s-%d} %.3f sec", - (func_name, id, end - start,), + f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start) ) return r @@ -137,9 +122,8 @@ def wrapped(*args, **kwargs): s = inspect.currentframe().f_back to_print = [ - "\t%s:%s %s. Args: args=%s, kwargs=%s" % ( - pathname, linenum, func_name, args, kwargs - ) + "\t%s:%s %s. Args: args=%s, kwargs=%s" + % (pathname, linenum, func_name, args, kwargs) ] while s: if True or s.f_globals["__name__"].startswith("synapse"): @@ -147,9 +131,7 @@ def wrapped(*args, **kwargs): args_string = inspect.formatargvalues(*inspect.getargvalues(s)) to_print.append( - "\t%s:%d %s. Args: %s" % ( - filename, lineno, function, args_string - ) + "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string) ) s = s.f_back @@ -163,7 +145,7 @@ def wrapped(*args, **kwargs): lineno=lineno, msg=msg, args=None, - exc_info=None + exc_info=None, ) logger.handle(record) @@ -182,13 +164,13 @@ def get_previous_frames(): filename, lineno, function, _, _ = inspect.getframeinfo(s) args_string = inspect.formatargvalues(*inspect.getargvalues(s)) - to_return.append("{{ %s:%d %s - Args: %s }}" % ( - filename, lineno, function, args_string - )) + to_return.append( + "{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string) + ) s = s.f_back - return ", ". join(to_return) + return ", ".join(to_return) def get_previous_frame(ignore=[]): @@ -201,7 +183,10 @@ def get_previous_frame(ignore=[]): args_string = inspect.formatargvalues(*inspect.getargvalues(s)) return "{{ %s:%d %s - Args: %s }}" % ( - filename, lineno, function, args_string + filename, + lineno, + function, + args_string, ) s = s.f_back diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 628a2962d9da..631654f2974e 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -74,27 +74,25 @@ def manhole(username, password, globals): twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` """ if not isinstance(password, bytes): - password = password.encode('ascii') + password = password.encode("ascii") - checker = checkers.InMemoryUsernamePasswordDatabaseDontUse( - **{username: password} - ) + checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password}) rlm = manhole_ssh.TerminalRealm() rlm.chainedProtocolFactory = lambda: insults.ServerProtocol( - SynapseManhole, - dict(globals, __name__="__console__") + SynapseManhole, dict(globals, __name__="__console__") ) factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) - factory.publicKeys[b'ssh-rsa'] = Key.fromString(PUBLIC_KEY) - factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY) + factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY) + factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY) return factory class SynapseManhole(ColoredManhole): """Overrides connectionMade to create our own ManholeInterpreter""" + def connectionMade(self): super(SynapseManhole, self).connectionMade() @@ -127,7 +125,7 @@ def showsyntaxerror(self, filename=None): value = SyntaxError(msg, (filename, lineno, offset, line)) sys.last_value = value lines = traceback.format_exception_only(type, value) - self.write(''.join(lines)) + self.write("".join(lines)) def showtraceback(self): """Display the exception that just occurred. @@ -140,6 +138,6 @@ def showtraceback(self): try: # We remove the first stack item because it is our own code. lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) - self.write(''.join(lines)) + self.write("".join(lines)) finally: last_tb = ei = None diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 4b4ac5f6c7d8..01284d3cf812 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -30,25 +30,31 @@ block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"]) block_ru_utime = Counter( - "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"]) + "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"] +) block_ru_stime = Counter( - "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"]) + "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"] +) block_db_txn_count = Counter( - "synapse_util_metrics_block_db_txn_count", "", ["block_name"]) + "synapse_util_metrics_block_db_txn_count", "", ["block_name"] +) # seconds spent waiting for db txns, excluding scheduling time, in this block block_db_txn_duration = Counter( - "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"]) + "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"] +) # seconds spent waiting for a db connection, in this block block_db_sched_duration = Counter( - "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]) + "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"] +) # Tracks the number of blocks currently active in_flight = InFlightGauge( - "synapse_util_metrics_block_in_flight", "", + "synapse_util_metrics_block_in_flight", + "", labels=["block_name"], sub_metrics=["real_time_max", "real_time_sum"], ) @@ -62,13 +68,18 @@ def measured_func(self, *args, **kwargs): with Measure(self.clock, name): r = yield func(self, *args, **kwargs) defer.returnValue(r) + return measured_func + return wrapper class Measure(object): __slots__ = [ - "clock", "name", "start_context", "start", + "clock", + "name", + "start_context", + "start", "created_context", "start_usage", ] @@ -108,7 +119,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): if context != self.start_context: logger.warn( "Context has unexpectedly changed from '%s' to '%s'. (%r)", - self.start_context, context, self.name + self.start_context, + context, + self.name, ) return @@ -126,8 +139,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) except ValueError: logger.warn( - "Failed to save metrics! OLD: %r, NEW: %r", - self.start_usage, current + "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current ) if self.created_context: diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 4288312b8ad2..522acd5aa80b 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -28,15 +28,13 @@ def load_module(provider): """ # We need to import the module, and then pick the class out of # that, so we split based on the last dot. - module, clz = provider['module'].rsplit(".", 1) + module, clz = provider["module"].rsplit(".", 1) module = importlib.import_module(module) provider_class = getattr(module, clz) try: provider_config = provider_class.parse_config(provider["config"]) except Exception as e: - raise ConfigError( - "Failed to parse config for %r: %r" % (provider['module'], e) - ) + raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e)) return provider_class, provider_config diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py index a6c30e526546..c8bcbe297ab4 100644 --- a/synapse/util/msisdn.py +++ b/synapse/util/msisdn.py @@ -36,6 +36,6 @@ def phone_number_to_msisdn(country, number): phoneNumber = phonenumbers.parse(number, country) except phonenumbers.NumberParseException: raise SynapseError(400, "Unable to parse phone number") - return phonenumbers.format_number( - phoneNumber, phonenumbers.PhoneNumberFormat.E164 - )[1:] + return phonenumbers.format_number(phoneNumber, phonenumbers.PhoneNumberFormat.E164)[ + 1: + ] diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index b146d137f46b..06defa81992e 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -56,11 +56,7 @@ def ratelimit(self, host): _PerHostRatelimiter """ return self.ratelimiters.setdefault( - host, - _PerHostRatelimiter( - clock=self.clock, - config=self._config, - ) + host, _PerHostRatelimiter(clock=self.clock, config=self._config) ).ratelimit() @@ -112,8 +108,7 @@ def _on_enter(self, request_id): # remove any entries from request_times which aren't within the window self.request_times[:] = [ - r for r in self.request_times - if time_now - r < self.window_size + r for r in self.request_times if time_now - r < self.window_size ] # reject the request if we already have too many queued up (either @@ -121,9 +116,7 @@ def _on_enter(self, request_id): queue_size = len(self.ready_request_queue) + len(self.sleeping_requests) if queue_size > self.reject_limit: raise LimitExceededError( - retry_after_ms=int( - self.window_size / self.sleep_limit - ), + retry_after_ms=int(self.window_size / self.sleep_limit) ) self.request_times.append(time_now) @@ -143,22 +136,18 @@ def queue_request(): logger.debug( "Ratelimit [%s]: len(self.request_times)=%d", - id(request_id), len(self.request_times), + id(request_id), + len(self.request_times), ) if len(self.request_times) > self.sleep_limit: - logger.debug( - "Ratelimiter: sleeping request for %f sec", self.sleep_sec, - ) + logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec) ret_defer = run_in_background(self.clock.sleep, self.sleep_sec) self.sleeping_requests.add(request_id) def on_wait_finished(_): - logger.debug( - "Ratelimit [%s]: Finished sleeping", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id)) self.sleeping_requests.discard(request_id) queue_defer = queue_request() return queue_defer @@ -168,10 +157,7 @@ def on_wait_finished(_): ret_defer = queue_request() def on_start(r): - logger.debug( - "Ratelimit [%s]: Processing req", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Processing req", id(request_id)) self.current_processing.add(request_id) return r @@ -193,10 +179,7 @@ def on_both(r): return make_deferred_yieldable(ret_defer) def _on_exit(self, request_id): - logger.debug( - "Ratelimit [%s]: Processed req", - id(request_id), - ) + logger.debug("Ratelimit [%s]: Processed req", id(request_id)) self.current_processing.discard(request_id) try: # start processing the next item on the queue. diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 69dffd824454..982c6d81ca82 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -20,9 +20,7 @@ from six import PY2, PY3 from six.moves import range -_string_with_symbols = ( - string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" -) +_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure @@ -31,13 +29,11 @@ def random_string(length): - return ''.join(rand.choice(string.ascii_letters) for _ in range(length)) + return "".join(rand.choice(string.ascii_letters) for _ in range(length)) def random_string_with_symbols(length): - return ''.join( - rand.choice(_string_with_symbols) for _ in range(length) - ) + return "".join(rand.choice(_string_with_symbols) for _ in range(length)) def is_ascii(s): @@ -45,7 +41,7 @@ def is_ascii(s): if PY3: if isinstance(s, bytes): try: - s.decode('ascii').encode('ascii') + s.decode("ascii").encode("ascii") except UnicodeDecodeError: return False except UnicodeEncodeError: @@ -104,12 +100,12 @@ def exception_to_unicode(e): # and instead look at what is in the args member. if len(e.args) == 0: - return u"" + return "" elif len(e.args) > 1: return six.text_type(repr(e.args)) msg = e.args[0] if isinstance(msg, bytes): - return msg.decode('utf-8', errors='replace') + return msg.decode("utf-8", errors="replace") else: return msg diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 75efa0117bb7..3ec1dfb0c2ea 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -35,11 +35,13 @@ def check_3pid_allowed(hs, medium, address): for constraint in hs.config.allowed_local_3pids: logger.debug( "Checking 3PID %s (%s) against %s (%s)", - address, medium, constraint['pattern'], constraint['medium'], + address, + medium, + constraint["pattern"], + constraint["medium"], ) - if ( - medium == constraint['medium'] and - re.match(constraint['pattern'], address) + if medium == constraint["medium"] and re.match( + constraint["pattern"], address ): return True else: diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py index 3baba3225afa..a4d9a462f790 100644 --- a/synapse/util/versionstring.py +++ b/synapse/util/versionstring.py @@ -23,44 +23,53 @@ def get_version_string(module): try: - null = open(os.devnull, 'w') + null = open(os.devnull, "w") cwd = os.path.dirname(os.path.abspath(module.__file__)) try: - git_branch = subprocess.check_output( - ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_branch = ( + subprocess.check_output( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) git_branch = "b=" + git_branch except subprocess.CalledProcessError: git_branch = "" try: - git_tag = subprocess.check_output( - ['git', 'describe', '--exact-match'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_tag = ( + subprocess.check_output( + ["git", "describe", "--exact-match"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) git_tag = "t=" + git_tag except subprocess.CalledProcessError: git_tag = "" try: - git_commit = subprocess.check_output( - ['git', 'rev-parse', '--short', 'HEAD'], - stderr=null, - cwd=cwd, - ).strip().decode('ascii') + git_commit = ( + subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + ) except subprocess.CalledProcessError: git_commit = "" try: dirty_string = "-this_is_a_dirty_checkout" - is_dirty = subprocess.check_output( - ['git', 'describe', '--dirty=' + dirty_string], - stderr=null, - cwd=cwd, - ).strip().decode('ascii').endswith(dirty_string) + is_dirty = ( + subprocess.check_output( + ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd + ) + .strip() + .decode("ascii") + .endswith(dirty_string) + ) git_dirty = "dirty" if is_dirty else "" except subprocess.CalledProcessError: @@ -68,16 +77,10 @@ def get_version_string(module): if git_branch or git_tag or git_commit or git_dirty: git_version = ",".join( - s for s in - (git_branch, git_tag, git_commit, git_dirty,) - if s + s for s in (git_branch, git_tag, git_commit, git_dirty) if s ) - return ( - "%s (%s)" % ( - module.__version__, git_version, - ) - ) + return "%s (%s)" % (module.__version__, git_version) except Exception as e: logger.info("Failed to check for git repository: %s", e) diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 7a9e45aca90d..9bf6a44f758c 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -69,9 +69,7 @@ def insert(self, now, obj, then): # Add empty entries between the end of the current list and when we want # to insert. This ensures there are no gaps. - self.entries.extend( - _Entry(key) for key in range(last_key, then_key + 1) - ) + self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1)) self.entries[-1].queue.append(obj) diff --git a/synapse/visibility.py b/synapse/visibility.py index 16c40cd74c4e..2a11c8359699 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -29,12 +29,7 @@ logger = logging.getLogger(__name__) -VISIBILITY_PRIORITY = ( - "world_readable", - "shared", - "invited", - "joined", -) +VISIBILITY_PRIORITY = ("world_readable", "shared", "invited", "joined") MEMBERSHIP_PRIORITY = ( @@ -47,8 +42,9 @@ @defer.inlineCallbacks -def filter_events_for_client(store, user_id, events, is_peeking=False, - always_include_ids=frozenset()): +def filter_events_for_client( + store, user_id, events, is_peeking=False, always_include_ids=frozenset() +): """ Check which events a user is allowed to see @@ -71,23 +67,21 @@ def filter_events_for_client(store, user_id, events, is_peeking=False, # to clients. events = list(e for e in events if not e.internal_metadata.is_soft_failed()) - types = ( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, user_id), - ) + types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) event_id_to_state = yield store.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) ignore_dict_content = yield store.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id, + "m.ignored_user_list", user_id ) # FIXME: This will explode if people upload something incorrect. ignore_list = frozenset( ignore_dict_content.get("ignored_users", {}).keys() - if ignore_dict_content else [] + if ignore_dict_content + else [] ) erased_senders = yield store.are_users_erased((e.sender for e in events)) @@ -185,9 +179,7 @@ def allowed(event): elif visibility == "invited": # user can also see the event if they were *invited* at the time # of the event. - return ( - event if membership == Membership.INVITE else None - ) + return event if membership == Membership.INVITE else None elif visibility == "shared" and is_peeking: # if the visibility is shared, users cannot see the event unless @@ -220,8 +212,9 @@ def allowed(event): @defer.inlineCallbacks -def filter_events_for_server(store, server_name, events, redact=True, - check_history_visibility_only=False): +def filter_events_for_server( + store, server_name, events, redact=True, check_history_visibility_only=False +): """Filter a list of events based on whether given server is allowed to see them. @@ -242,15 +235,12 @@ def filter_events_for_server(store, server_name, events, redact=True, def is_sender_erased(event, erased_senders): if erased_senders and erased_senders[event.sender]: - logger.info( - "Sender of %s has been erased, redacting", - event.event_id, - ) + logger.info("Sender of %s has been erased, redacting", event.event_id) return True return False def check_event_is_visible(event, state): - history = state.get((EventTypes.RoomHistoryVisibility, ''), None) + history = state.get((EventTypes.RoomHistoryVisibility, ""), None) if history: visibility = history.content.get("history_visibility", "shared") if visibility in ["invited", "joined"]: @@ -287,8 +277,8 @@ def check_event_is_visible(event, state): event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( - types=((EventTypes.RoomHistoryVisibility, ""),), - ) + types=((EventTypes.RoomHistoryVisibility, ""),) + ), ) visibility_ids = set() @@ -309,9 +299,7 @@ def check_event_is_visible(event, state): ) if not check_history_visibility_only: - erased_senders = yield store.are_users_erased( - (e.sender for e in events), - ) + erased_senders = yield store.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -343,11 +331,8 @@ def check_event_is_visible(event, state): event_to_state_ids = yield store.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( - types=( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, None), - ), - ) + types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) + ), ) # We only want to pull out member events that correspond to the @@ -371,13 +356,15 @@ def include(typ, state_key): idx = state_key.find(":") if idx == -1: return False - return state_key[idx + 1:] == server_name - - event_map = yield store.get_events([ - e_id - for e_id, key in iteritems(event_id_to_state_key) - if include(key[0], key[1]) - ]) + return state_key[idx + 1 :] == server_name + + event_map = yield store.get_events( + [ + e_id + for e_id, key in iteritems(event_id_to_state_key) + if include(key[0], key[1]) + ] + ) event_to_state = { e_id: { diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index d0d36f96fa7e..d4e75b5b2e72 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -172,7 +172,7 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self): request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals( - requester.user.to_string(), masquerading_user_id.decode('utf8') + requester.user.to_string(), masquerading_user_id.decode("utf8") ) def test_get_user_by_req_appservice_valid_token_bad_user_id(self): @@ -264,7 +264,7 @@ def get_user(tok): # check the token works request = Mock(args={}) - request.args[b"access_token"] = [token.encode('ascii')] + request.args[b"access_token"] = [token.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request, allow_guest=True) self.assertEqual(UserID.from_string(USER_ID), requester.user) @@ -277,7 +277,7 @@ def get_user(tok): # the token should *not* work now request = Mock(args={}) - request.args[b"access_token"] = [guest_tok.encode('ascii')] + request.args[b"access_token"] = [guest_tok.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() with self.assertRaises(AuthError) as cm: @@ -321,11 +321,11 @@ def test_reserved_threepid(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 1 self.store.get_monthly_active_count = lambda: defer.succeed(2) - threepid = {'medium': 'email', 'address': 'reserved@server.com'} - unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'} + threepid = {"medium": "email", "address": "reserved@server.com"} + unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.hs.config.mau_limits_reserved_threepids = [threepid] - yield self.store.register(user_id='user1', token="123", password_hash=None) + yield self.store.register(user_id="user1", token="123", password_hash=None) with self.assertRaises(ResourceLimitError): yield self.auth.check_auth_blocking() diff --git a/tests/config/test_server.py b/tests/config/test_server.py index de64965a6069..1ca5ea54ca6e 100644 --- a/tests/config/test_server.py +++ b/tests/config/test_server.py @@ -20,10 +20,10 @@ class ServerConfigTestCase(unittest.TestCase): def test_is_threepid_reserved(self): - user1 = {'medium': 'email', 'address': 'user1@example.com'} - user2 = {'medium': 'email', 'address': 'user2@example.com'} - user3 = {'medium': 'email', 'address': 'user3@example.com'} - user1_msisdn = {'medium': 'msisdn', 'address': '447700000000'} + user1 = {"medium": "email", "address": "user1@example.com"} + user2 = {"medium": "email", "address": "user2@example.com"} + user3 = {"medium": "email", "address": "user3@example.com"} + user1_msisdn = {"medium": "msisdn", "address": "447700000000"} config = [user1, user2] self.assertTrue(is_threepid_reserved(config, user1)) diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index 40ca42877843..0cbbf4e885e2 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -32,7 +32,7 @@ def test_warn_self_signed(self): """ config_dir = self.mktemp() os.mkdir(config_dir) - with open(os.path.join(config_dir, "cert.pem"), 'w') as f: + with open(os.path.join(config_dir, "cert.pem"), "w") as f: f.write( """-----BEGIN CERTIFICATE----- MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 71aa7314397f..126e1760048c 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -41,25 +41,25 @@ def setUp(self): def test_sign_minimal(self): event_dict = { - 'event_id': "$0:domain", - 'origin': "domain", - 'origin_server_ts': 1000000, - 'signatures': {}, - 'type': "X", - 'unsigned': {'age_ts': 1000000}, + "event_id": "$0:domain", + "origin": "domain", + "origin_server_ts": 1000000, + "signatures": {}, + "type": "X", + "unsigned": {"age_ts": 1000000}, } add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key) event = FrozenEvent(event_dict) - self.assertTrue(hasattr(event, 'hashes')) - self.assertIn('sha256', event.hashes) + self.assertTrue(hasattr(event, "hashes")) + self.assertIn("sha256", event.hashes) self.assertEquals( - event.hashes['sha256'], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI" + event.hashes["sha256"], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI" ) - self.assertTrue(hasattr(event, 'signatures')) + self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) self.assertEquals( @@ -70,28 +70,28 @@ def test_sign_minimal(self): def test_sign_message(self): event_dict = { - 'content': {'body': "Here is the message content"}, - 'event_id': "$0:domain", - 'origin': "domain", - 'origin_server_ts': 1000000, - 'type': "m.room.message", - 'room_id': "!r:domain", - 'sender': "@u:domain", - 'signatures': {}, - 'unsigned': {'age_ts': 1000000}, + "content": {"body": "Here is the message content"}, + "event_id": "$0:domain", + "origin": "domain", + "origin_server_ts": 1000000, + "type": "m.room.message", + "room_id": "!r:domain", + "sender": "@u:domain", + "signatures": {}, + "unsigned": {"age_ts": 1000000}, } add_hashes_and_signatures(event_dict, HOSTNAME, self.signing_key) event = FrozenEvent(event_dict) - self.assertTrue(hasattr(event, 'hashes')) - self.assertIn('sha256', event.hashes) + self.assertTrue(hasattr(event, "hashes")) + self.assertIn("sha256", event.hashes) self.assertEquals( - event.hashes['sha256'], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g" + event.hashes["sha256"], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g" ) - self.assertTrue(hasattr(event, 'signatures')) + self.assertTrue(hasattr(event, "signatures")) self.assertIn(HOSTNAME, event.signatures) self.assertIn(KEY_NAME, event.signatures["domain"]) self.assertEquals( diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index d0cc492deb83..9e3d4d0f4756 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -37,88 +37,88 @@ def run_test(self, evdict, matchdict): def test_minimal(self): self.run_test( - {'type': 'A', 'event_id': '$test:domain'}, + {"type": "A", "event_id": "$test:domain"}, { - 'type': 'A', - 'event_id': '$test:domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {}, + "type": "A", + "event_id": "$test:domain", + "content": {}, + "signatures": {}, + "unsigned": {}, }, ) def test_basic_keys(self): self.run_test( { - 'type': 'A', - 'room_id': '!1:domain', - 'sender': '@2:domain', - 'event_id': '$3:domain', - 'origin': 'domain', + "type": "A", + "room_id": "!1:domain", + "sender": "@2:domain", + "event_id": "$3:domain", + "origin": "domain", }, { - 'type': 'A', - 'room_id': '!1:domain', - 'sender': '@2:domain', - 'event_id': '$3:domain', - 'origin': 'domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {}, + "type": "A", + "room_id": "!1:domain", + "sender": "@2:domain", + "event_id": "$3:domain", + "origin": "domain", + "content": {}, + "signatures": {}, + "unsigned": {}, }, ) def test_unsigned_age_ts(self): self.run_test( - {'type': 'B', 'event_id': '$test:domain', 'unsigned': {'age_ts': 20}}, + {"type": "B", "event_id": "$test:domain", "unsigned": {"age_ts": 20}}, { - 'type': 'B', - 'event_id': '$test:domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {'age_ts': 20}, + "type": "B", + "event_id": "$test:domain", + "content": {}, + "signatures": {}, + "unsigned": {"age_ts": 20}, }, ) self.run_test( { - 'type': 'B', - 'event_id': '$test:domain', - 'unsigned': {'other_key': 'here'}, + "type": "B", + "event_id": "$test:domain", + "unsigned": {"other_key": "here"}, }, { - 'type': 'B', - 'event_id': '$test:domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {}, + "type": "B", + "event_id": "$test:domain", + "content": {}, + "signatures": {}, + "unsigned": {}, }, ) def test_content(self): self.run_test( - {'type': 'C', 'event_id': '$test:domain', 'content': {'things': 'here'}}, + {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}}, { - 'type': 'C', - 'event_id': '$test:domain', - 'content': {}, - 'signatures': {}, - 'unsigned': {}, + "type": "C", + "event_id": "$test:domain", + "content": {}, + "signatures": {}, + "unsigned": {}, }, ) self.run_test( { - 'type': 'm.room.create', - 'event_id': '$test:domain', - 'content': {'creator': '@2:domain', 'other_field': 'here'}, + "type": "m.room.create", + "event_id": "$test:domain", + "content": {"creator": "@2:domain", "other_field": "here"}, }, { - 'type': 'm.room.create', - 'event_id': '$test:domain', - 'content': {'creator': '@2:domain'}, - 'signatures': {}, - 'unsigned': {}, + "type": "m.room.create", + "event_id": "$test:domain", + "content": {"creator": "@2:domain"}, + "signatures": {}, + "unsigned": {}, }, ) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 1e3e5aec6643..a5b03005d7aa 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -32,7 +32,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase): login.register_servlets, ] - def default_config(self, name='test'): + def default_config(self, name="test"): config = super(RoomComplexityTests, self).default_config(name=name) config["limit_large_remote_room_joins"] = True config["limit_large_remote_room_complexity"] = 0.05 diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 7bb106b5f7df..cce8d8c6de96 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -51,16 +51,16 @@ def test_send_receipts(self): json_cb = mock_send_transaction.call_args[0][1] data = json_cb() self.assertEqual( - data['edus'], + data["edus"], [ { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, + "edu_type": "m.receipt", + "content": { + "room_id": { + "m.read": { + "user_id": { + "event_ids": ["event_id"], + "data": {"ts": 1234}, } } } @@ -93,16 +93,16 @@ def test_send_receipts_with_backoff(self): json_cb = mock_send_transaction.call_args[0][1] data = json_cb() self.assertEqual( - data['edus'], + data["edus"], [ { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, + "edu_type": "m.receipt", + "content": { + "room_id": { + "m.read": { + "user_id": { + "event_ids": ["event_id"], + "data": {"ts": 1234}, } } } @@ -128,16 +128,16 @@ def test_send_receipts_with_backoff(self): json_cb = mock_send_transaction.call_args[0][1] data = json_cb() self.assertEqual( - data['edus'], + data["edus"], [ { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['other_id'], - 'data': {'ts': 1234}, + "edu_type": "m.receipt", + "content": { + "room_id": { + "m.read": { + "user_id": { + "event_ids": ["other_id"], + "data": {"ts": 1234}, } } } diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 1e39fe0ec201..b204a0700d25 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -117,7 +117,7 @@ def test_short_term_login_token_cannot_replace_user_id(self): def test_mau_limits_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") yield self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() @@ -131,7 +131,7 @@ def test_mau_limits_exceeded_large(self): ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) @@ -150,7 +150,7 @@ def test_mau_limits_parity(self): return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) @@ -166,7 +166,7 @@ def test_mau_limits_parity(self): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") self.hs.get_datastore().user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) @@ -185,7 +185,7 @@ def test_mau_limits_not_exceeded(self): return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception - yield self.auth_handler.get_access_token_for_user_id('user_a') + yield self.auth_handler.get_access_token_for_user_id("user_a") self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 917548bb31d0..91c7a170704b 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -132,7 +132,7 @@ def test_denied(self): request, channel = self.make_request( "PUT", b"directory/room/%23test%3Atest", - ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), ) self.render(request) self.assertEquals(403, channel.code, channel.result) @@ -143,7 +143,7 @@ def test_allowed(self): request, channel = self.make_request( "PUT", b"directory/room/%23unofficial_test%3Atest", - ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ('{"room_id":"%s"}' % (room_id,)).encode("ascii"), ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -158,7 +158,7 @@ def prepare(self, reactor, clock, hs): room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' + "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -190,7 +190,7 @@ def test_disabling_room_list(self): # Room list disabled so we shouldn't be allowed to publish rooms room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' + "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}" ) self.render(request) self.assertEquals(403, channel.code, channel.result) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 2e72a1dd2381..c4503c161191 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -395,37 +395,37 @@ def test_upload_room_keys_merge(self): yield self.handler.upload_room_keys(self.local_user, version, room_keys) new_room_keys = copy.deepcopy(room_keys) - new_room_key = new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33'] + new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"] # test that increasing the message_index doesn't replace the existing session - new_room_key['first_message_index'] = 2 - new_room_key['session_data'] = 'new' + new_room_key["first_message_index"] = 2 + new_room_key["session_data"] = "new" yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], + res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "SSBBTSBBIEZJU0gK", ) # test that marking the session as verified however /does/ replace it - new_room_key['is_verified'] = True + new_room_key["is_verified"] = True yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" + res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # test that a session with a higher forwarded_count doesn't replace one # with a lower forwarding count - new_room_key['forwarded_count'] = 2 - new_room_key['session_data'] = 'other' + new_room_key["forwarded_count"] = 2 + new_room_key["session_data"] = "other" yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" + res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # TODO: check edge cases as well as the common variations here diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 5ffba2ca7aa6..4edce7af435f 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -52,7 +52,7 @@ def prepare(self, reactor, clock, hs): self.mock_distributor.declare("registered_user") self.mock_captcha_client = Mock() self.macaroon_generator = Mock( - generate_access_token=Mock(return_value='secret') + generate_access_token=Mock(return_value="secret") ) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.handler = self.hs.get_registration_handler() @@ -71,7 +71,7 @@ def test_user_is_created_and_logged_in_if_doesnt_exist(self): ) self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) - self.assertEquals(result_token, 'secret') + self.assertEquals(result_token, "secret") def test_if_user_exists(self): store = self.hs.get_datastore() @@ -96,7 +96,7 @@ def test_mau_limits_when_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception self.get_success( - self.handler.get_or_create_user(self.requester, 'a', "display_name") + self.handler.get_or_create_user(self.requester, "a", "display_name") ) def test_get_or_create_user_mau_not_blocked(self): @@ -105,7 +105,7 @@ def test_get_or_create_user_mau_not_blocked(self): return_value=defer.succeed(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception - self.get_success(self.handler.get_or_create_user(self.requester, 'c', "User")) + self.get_success(self.handler.get_or_create_user(self.requester, "c", "User")) def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True @@ -113,7 +113,7 @@ def test_get_or_create_user_mau_blocked(self): return_value=defer.succeed(self.lots_of_users) ) self.get_failure( - self.handler.get_or_create_user(self.requester, 'b', "display_name"), + self.handler.get_or_create_user(self.requester, "b", "display_name"), ResourceLimitError, ) @@ -121,7 +121,7 @@ def test_get_or_create_user_mau_blocked(self): return_value=defer.succeed(self.hs.config.max_mau_value) ) self.get_failure( - self.handler.get_or_create_user(self.requester, 'b', "display_name"), + self.handler.get_or_create_user(self.requester, "b", "display_name"), ResourceLimitError, ) @@ -144,13 +144,13 @@ def test_register_mau_blocked(self): def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = self.get_success(self.handler.register(localpart='jeff')) + res = self.get_success(self.handler.register(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(res[0])) directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) - self.assertTrue(room_id['room_id'] in rooms) + self.assertTrue(room_id["room_id"] in rooms) self.assertEqual(len(rooms), 1) def test_auto_create_auto_join_rooms_with_no_rooms(self): @@ -173,7 +173,7 @@ def test_auto_create_auto_join_where_auto_create_is_false(self): self.hs.config.autocreate_auto_join_rooms = False room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = self.get_success(self.handler.register(localpart='jeff')) + res = self.get_success(self.handler.register(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) @@ -182,7 +182,7 @@ def test_auto_create_auto_join_rooms_when_support_user_exists(self): self.hs.config.auto_join_rooms = [room_alias_str] self.store.is_support_user = Mock(return_value=True) - res = self.get_success(self.handler.register(localpart='support')) + res = self.get_success(self.handler.register(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) directory_handler = self.hs.get_handlers().directory_handler @@ -211,7 +211,7 @@ def test_auto_create_auto_join_where_no_consent(self): # When:- # * the user is registered and post consent actions are called - res = self.get_success(self.handler.register(localpart='jeff')) + res = self.get_success(self.handler.register(localpart="jeff")) self.get_success(self.handler.post_consent_actions(res[0])) # Then:- @@ -221,17 +221,14 @@ def test_auto_create_auto_join_where_no_consent(self): def test_register_support_user(self): res = self.get_success( - self.handler.register(localpart='user', user_type=UserTypes.SUPPORT) + self.handler.register(localpart="user", user_type=UserTypes.SUPPORT) ) self.assertTrue(self.store.is_support_user(res[0])) def test_register_not_support_user(self): - res = self.get_success(self.handler.register(localpart='user')) + res = self.get_success(self.handler.register(localpart="user")) self.assertFalse(self.store.is_support_user(res[0])) def test_invalid_user_id_length(self): invalid_user_id = "x" * 256 - self.get_failure( - self.handler.register(localpart=invalid_user_id), - SynapseError - ) + self.get_failure(self.handler.register(localpart=invalid_user_id), SynapseError) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 2710c991cfec..a8b858eb4ff1 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -265,10 +265,7 @@ def test_redacted_prev_event(self): while not self.get_success(self.store.has_completed_background_updates()): self.get_success(self.store.do_next_background_update(100), by=0.1) - events = { - "a1": None, - "a2": {"membership": Membership.JOIN}, - } + events = {"a1": None, "a2": {"membership": Membership.JOIN}} def get_event(event_id, allow_none=True): if events.get(event_id): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index cb8b4d29138b..5d5e324df262 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -47,7 +47,7 @@ def _expect_edu_transaction(edu_type, content, origin="test"): def _make_edu_transaction_json(edu_type, content): - return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8') + return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8") class TypingNotificationsTestCase(unittest.HomeserverTestCase): @@ -151,7 +151,7 @@ def test_started_typing_local(self): ) ) - self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) @@ -209,12 +209,12 @@ def test_started_typing_remote_recv(self): "typing": True, }, ), - federation_auth_origin=b'farm', + federation_auth_origin=b"farm", ) self.render(request) self.assertEqual(channel.code, 200) - self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) @@ -247,7 +247,7 @@ def test_stopped_typing(self): ) ) - self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) put_json = self.hs.get_http_client().put_json put_json.assert_called_once_with( @@ -285,7 +285,7 @@ def test_typing_timeout(self): ) ) - self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) @@ -303,7 +303,7 @@ def test_typing_timeout(self): self.reactor.pump([16]) - self.on_new_event.assert_has_calls([call('typing_key', 2, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 2) events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) @@ -320,7 +320,7 @@ def test_typing_timeout(self): ) ) - self.on_new_event.assert_has_calls([call('typing_key', 3, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 9021e647feb4..b135486c4877 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -60,15 +60,15 @@ def test_handle_local_profile_change_with_support_user(self): ) profile = self.get_success(self.store.get_user_in_directory(support_user_id)) self.assertTrue(profile is None) - display_name = 'display_name' + display_name = "display_name" - profile_info = ProfileInfo(avatar_url='avatar_url', display_name=display_name) - regular_user_id = '@regular:test' + profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name) + regular_user_id = "@regular:test" self.get_success( self.handler.handle_local_profile_change(regular_user_id, profile_info) ) profile = self.get_success(self.store.get_user_in_directory(regular_user_id)) - self.assertTrue(profile['display_name'] == display_name) + self.assertTrue(profile["display_name"] == display_name) def test_handle_user_deactivated_support_user(self): s_user_id = "@support:test" diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index ecce473b011c..b1094c14483b 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -53,13 +53,15 @@ def get_connection_factory(): # this needs to happen once, but not until we are ready to run the first test global test_server_connection_factory if test_server_connection_factory is None: - test_server_connection_factory = TestServerTLSConnectionFactory(sanlist=[ - b'DNS:testserv', - b'DNS:target-server', - b'DNS:xn--bcher-kva.com', - b'IP:1.2.3.4', - b'IP:::1', - ]) + test_server_connection_factory = TestServerTLSConnectionFactory( + sanlist=[ + b"DNS:testserv", + b"DNS:target-server", + b"DNS:xn--bcher-kva.com", + b"IP:1.2.3.4", + b"IP:::1", + ] + ) return test_server_connection_factory @@ -133,7 +135,7 @@ def _make_get_request(self, uri): Sends a simple GET request via the agent, and checks its logcontext management """ with LoggingContext("one") as context: - fetch_d = self.agent.request(b'GET', uri) + fetch_d = self.agent.request(b"GET", uri) # Nothing happened yet self.assertNoResult(fetch_d) @@ -177,9 +179,9 @@ def _send_well_known_response(self, request, content, headers={}): """Check that an incoming request looks like a valid .well-known request, and send back the response. """ - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/.well-known/matrix/server') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/.well-known/matrix/server") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # send back a response for k, v in headers.items(): request.setHeader(k, v) @@ -202,7 +204,7 @@ def test_get(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -210,20 +212,20 @@ def test_get(self): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'testserv:8448'] + request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"] ) content = request.content.read() - self.assertEqual(content, b'') + self.assertEqual(content, b"") # Deferred is still without a result self.assertNoResult(test_d) # send the headers - request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json']) - request.write('') + request.responseHeaders.setRawHeaders(b"Content-Type", [b"application/json"]) + request.write("") self.reactor.pump((0.1,)) @@ -233,7 +235,7 @@ def test_get(self): self.assertEqual(response.code, 200) # Send the body - request.write('{ "a": 1 }'.encode('ascii')) + request.write('{ "a": 1 }'.encode("ascii")) request.finish() self.reactor.pump((0.1,)) @@ -258,7 +260,7 @@ def test_get_ip_address(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -266,9 +268,9 @@ def test_get_ip_address(self): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'1.2.3.4']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"1.2.3.4"]) # finish the request request.finish() @@ -293,7 +295,7 @@ def test_get_ipv6_address(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '::1') + self.assertEqual(host, "::1") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -301,9 +303,9 @@ def test_get_ipv6_address(self): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"[::1]"]) # finish the request request.finish() @@ -328,7 +330,7 @@ def test_get_ipv6_address_with_port(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '::1') + self.assertEqual(host, "::1") self.assertEqual(port, 80) # make a test server, and wire up the client @@ -336,9 +338,9 @@ def test_get_ipv6_address_with_port(self): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]:80']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"[::1]:80"]) # finish the request request.finish() @@ -364,7 +366,7 @@ def test_get_hostname_bad_cert(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection @@ -382,11 +384,11 @@ def test_get_hostname_bad_cert(self): # we should fall back to a direct connection self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv1') + http_server = self._make_connection(client_factory, expected_sni=b"testserv1") # there should be no requests self.assertEqual(len(http_server.requests), 0) @@ -413,7 +415,7 @@ def test_get_ip_address_bad_cert(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.5') + self.assertEqual(host, "1.2.3.5") self.assertEqual(port, 8448) # make a test server, and wire up the client @@ -447,7 +449,7 @@ def test_get_no_srv_no_well_known(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection @@ -465,17 +467,17 @@ def test_get_no_srv_no_well_known(self): # we should fall back to a direct connection self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv') + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() @@ -499,7 +501,7 @@ def test_get_well_known(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( @@ -516,20 +518,20 @@ def test_get_well_known(self): # now we should get a connection to the target server self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1::f') + self.assertEqual(host, "1::f") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'target-server' + client_factory, expected_sni=b"target-server" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] + request.requestHeaders.getRawHeaders(b"host"), [b"target-server"] ) # finish the request @@ -561,7 +563,7 @@ def test_get_well_known_redirect(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) redirect_server = self._make_connection( @@ -571,7 +573,7 @@ def test_get_well_known_redirect(self): # send a 302 redirect self.assertEqual(len(redirect_server.requests), 1) request = redirect_server.requests[0] - request.redirect(b'https://testserv/even_better_known') + request.redirect(b"https://testserv/even_better_known") request.finish() self.reactor.pump((0.1,)) @@ -580,7 +582,7 @@ def test_get_well_known_redirect(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) well_known_server = self._make_connection( @@ -589,8 +591,8 @@ def test_get_well_known_redirect(self): self.assertEqual(len(well_known_server.requests), 1, "No request after 302") request = well_known_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/even_better_known') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/even_better_known") request.write(b'{ "m.server": "target-server" }') request.finish() @@ -604,20 +606,20 @@ def test_get_well_known_redirect(self): # now we should get a connection to the target server self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1::f') + self.assertEqual(host, "1::f") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'target-server' + client_factory, expected_sni=b"target-server" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] + request.requestHeaders.getRawHeaders(b"host"), [b"target-server"] ) # finish the request @@ -652,11 +654,11 @@ def test_get_invalid_well_known(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", content=b'NOT JSON' + client_factory, expected_sni=b"testserv", content=b"NOT JSON" ) # now there should be a SRV lookup @@ -667,17 +669,17 @@ def test_get_invalid_well_known(self): # we should fall back to a direct connection self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv') + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() @@ -712,12 +714,10 @@ def test_get_well_known_unsigned_cert(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) - http_proto = self._make_connection( - client_factory, expected_sni=b"testserv", - ) + http_proto = self._make_connection(client_factory, expected_sni=b"testserv") # there should be no requests self.assertEqual(len(http_proto.requests), 0) @@ -750,17 +750,17 @@ def test_get_hostname_srv(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8443) # make a test server, and wire up the client - http_server = self._make_connection(client_factory, expected_sni=b'testserv') + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') - self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # finish the request request.finish() @@ -783,7 +783,7 @@ def test_get_well_known_srv(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self.mock_resolver.resolve_service.side_effect = lambda _: [ @@ -804,20 +804,20 @@ def test_get_well_known_srv(self): # now we should get a connection to the target of the SRV record self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '5.6.7.8') + self.assertEqual(host, "5.6.7.8") self.assertEqual(port, 8443) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'target-server' + client_factory, expected_sni=b"target-server" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] + request.requestHeaders.getRawHeaders(b"host"), [b"target-server"] ) # finish the request @@ -846,7 +846,7 @@ def test_idna_servername(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) # fonx the connection @@ -865,20 +865,20 @@ def test_idna_servername(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 2) (host, port, client_factory, _timeout, _bindAddress) = clients[1] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8448) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'xn--bcher-kva.com' + client_factory, expected_sni=b"xn--bcher-kva.com" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] + request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"] ) # finish the request @@ -907,20 +907,20 @@ def test_idna_srv_target(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8443) # make a test server, and wire up the client http_server = self._make_connection( - client_factory, expected_sni=b'xn--bcher-kva.com' + client_factory, expected_sni=b"xn--bcher-kva.com" ) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] - self.assertEqual(request.method, b'GET') - self.assertEqual(request.path, b'/foo/bar') + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] + request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"] ) # finish the request @@ -941,42 +941,42 @@ def do_get_well_known(self, serv): def test_well_known_cache(self): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.do_get_well_known(b'testserv') + fetch_d = self.do_get_well_known(b"testserv") # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) well_known_server = self._handle_well_known_connection( client_factory, expected_sni=b"testserv", - response_headers={b'Cache-Control': b'max-age=10'}, + response_headers={b"Cache-Control": b"max-age=10"}, content=b'{ "m.server": "target-server" }', ) r = self.successResultOf(fetch_d) - self.assertEqual(r, b'target-server') + self.assertEqual(r, b"target-server") # close the tcp connection well_known_server.loseConnection() # repeat the request: it should hit the cache - fetch_d = self.do_get_well_known(b'testserv') + fetch_d = self.do_get_well_known(b"testserv") r = self.successResultOf(fetch_d) - self.assertEqual(r, b'target-server') + self.assertEqual(r, b"target-server") # expire the cache self.reactor.pump((10.0,)) # now it should connect again - fetch_d = self.do_get_well_known(b'testserv') + fetch_d = self.do_get_well_known(b"testserv") self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) self._handle_well_known_connection( @@ -986,7 +986,7 @@ def test_well_known_cache(self): ) r = self.successResultOf(fetch_d) - self.assertEqual(r, b'other-server') + self.assertEqual(r, b"other-server") class TestCachePeriodFromHeaders(TestCase): @@ -994,27 +994,27 @@ def test_cache_control(self): # uppercase self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}) + Headers({b"Cache-Control": [b"foo, Max-Age = 100, bar"]}) ), 100, ) # missing value self.assertIsNone( - _cache_period_from_headers(Headers({b'Cache-Control': [b'max-age=, bar']})) + _cache_period_from_headers(Headers({b"Cache-Control": [b"max-age=, bar"]})) ) # hackernews: bogus due to semicolon self.assertIsNone( _cache_period_from_headers( - Headers({b'Cache-Control': [b'private; max-age=0']}) + Headers({b"Cache-Control": [b"private; max-age=0"]}) ) ) # github self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}) + Headers({b"Cache-Control": [b"max-age=0, private, must-revalidate"]}) ), 0, ) @@ -1022,7 +1022,7 @@ def test_cache_control(self): # google self.assertEqual( _cache_period_from_headers( - Headers({b'cache-control': [b'private, max-age=0']}) + Headers({b"cache-control": [b"private, max-age=0"]}) ), 0, ) @@ -1030,7 +1030,7 @@ def test_cache_control(self): def test_expires(self): self.assertEqual( _cache_period_from_headers( - Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}), + Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}), time_now=lambda: 1548833700, ), 33, @@ -1041,8 +1041,8 @@ def test_expires(self): _cache_period_from_headers( Headers( { - b'cache-control': [b'max-age=10'], - b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'], + b"cache-control": [b"max-age=10"], + b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"], } ), time_now=lambda: 1548833700, @@ -1051,7 +1051,7 @@ def test_expires(self): ) # invalid expires means immediate expiry - self.assertEqual(_cache_period_from_headers(Headers({b'Expires': [b'0']})), 0) + self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0) def _check_logcontext(context): diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 034c0db8d2f7..cf6c6e95b520 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -100,7 +100,7 @@ def test_from_cache_expired_and_dns_fail(self): def test_from_cache(self): clock = MockClock() - dns_client_mock = Mock(spec_set=['lookupService']) + dns_client_mock = Mock(spec_set=["lookupService"]) dns_client_mock.lookupService = Mock(spec_set=[]) service_name = b"test_service.example.com" diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index 3b0155ed03cc..b2e9533b07ca 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -20,12 +20,12 @@ class ServerNameTestCase(unittest.TestCase): def test_parse_server_name(self): test_data = { - 'localhost': ('localhost', None), - 'my-example.com:1234': ('my-example.com', 1234), - '1.2.3.4': ('1.2.3.4', None), - '[0abc:1def::1234]': ('[0abc:1def::1234]', None), - '1.2.3.4:1': ('1.2.3.4', 1), - '[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080), + "localhost": ("localhost", None), + "my-example.com:1234": ("my-example.com", 1234), + "1.2.3.4": ("1.2.3.4", None), + "[0abc:1def::1234]": ("[0abc:1def::1234]", None), + "1.2.3.4:1": ("1.2.3.4", 1), + "[0abc:1def::1234]:8080": ("[0abc:1def::1234]", 8080), } for i, o in test_data.items(): diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index ee767f3a5a07..c4c0d9b96817 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -83,7 +83,7 @@ def do_request(): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8008) # complete the connection and wire it up to a fake transport @@ -99,7 +99,7 @@ def do_request(): self.assertNoResult(test_d) # Send it the HTTP response - res_json = '{ "a": 1 }'.encode('ascii') + res_json = '{ "a": 1 }'.encode("ascii") protocol.dataReceived( b"HTTP/1.1 200 OK\r\n" b"Server: Fake\r\n" @@ -138,7 +138,7 @@ def test_client_connection_refused(self): clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) (host, port, factory, _timeout, _bindAddress) = clients[0] - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 8008) e = Exception("go away") factory.clientConnectionFailed(None, e) @@ -164,7 +164,7 @@ def test_client_never_connect(self): # Make sure treq is trying to connect clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) - self.assertEqual(clients[0][0], '1.2.3.4') + self.assertEqual(clients[0][0], "1.2.3.4") self.assertEqual(clients[0][1], 8008) # Deferred is still without a result @@ -194,7 +194,7 @@ def test_client_connect_no_response(self): # Make sure treq is trying to connect clients = self.reactor.tcpClients self.assertEqual(len(clients), 1) - self.assertEqual(clients[0][0], '1.2.3.4') + self.assertEqual(clients[0][0], "1.2.3.4") self.assertEqual(clients[0][1], 8008) conn = Mock() @@ -215,10 +215,9 @@ def test_client_ip_range_blacklist(self): """Ensure that Synapse does not try to connect to blacklisted IPs""" # Set up the ip_range blacklist - self.hs.config.federation_ip_range_blacklist = IPSet([ - "127.0.0.0/8", - "fe80::/64", - ]) + self.hs.config.federation_ip_range_blacklist = IPSet( + ["127.0.0.0/8", "fe80::/64"] + ) self.reactor.lookups["internal"] = "127.0.0.1" self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337" self.reactor.lookups["fine"] = "10.20.30.40" @@ -382,7 +381,7 @@ def test_client_requires_trailing_slashes(self): b"Content-Type: application/json\r\n" b"Content-Length: 2\r\n" b"\r\n" - b'{}' + b"{}" ) # We should get a successful response diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 72760a0733b7..358b593cd4ab 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -57,7 +57,7 @@ def sendmail(*args, **kwargs): config["email"] = { "enable_notifs": True, "template_dir": os.path.abspath( - pkg_resources.resource_filename('synapse', 'res/templates') + pkg_resources.resource_filename("synapse", "res/templates") ), "expiry_template_html": "notice_expiry.html", "expiry_template_text": "notice_expiry.txt", @@ -120,7 +120,7 @@ def test_simple_sends_email(self): # Create a simple room with two users room = self.helper.create_room_as(self.user_id, tok=self.access_token) self.helper.invite( - room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id, + room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id ) self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token) @@ -141,7 +141,7 @@ def test_multiple_members_email(self): for other in self.others: self.helper.invite( - room=room, src=self.user_id, tok=self.access_token, targ=other.id, + room=room, src=self.user_id, tok=self.access_token, targ=other.id ) self.helper.join(room=room, user=other.id, tok=other.token) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index e5fc2fcd159e..5877bb21337f 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -30,7 +30,7 @@ class VersionTestCase(unittest.HomeserverTestCase): - url = '/_synapse/admin/v1/server_version' + url = "/_synapse/admin/v1/server_version" def create_test_json_resource(self): resource = JsonResource(self.hs) @@ -43,7 +43,7 @@ def test_version_string(self): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - {'server_version', 'python_version'}, set(channel.json_body.keys()) + {"server_version", "python_version"}, set(channel.json_body.keys()) ) @@ -68,7 +68,7 @@ def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() - self.hs.config.registration_shared_secret = u"shared" + self.hs.config.registration_shared_secret = "shared" self.hs.get_media_repository = Mock() self.hs.get_deactivate_account_handler = Mock() @@ -82,12 +82,12 @@ def test_disabled(self): """ self.hs.config.registration_shared_secret = None - request, channel = self.make_request("POST", self.url, b'{}') + request, channel = self.make_request("POST", self.url, b"{}") self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - 'Shared secret registration is not enabled', channel.json_body["error"] + "Shared secret registration is not enabled", channel.json_body["error"] ) def test_get_nonce(self): @@ -118,20 +118,20 @@ def test_expired_nonce(self): self.reactor.advance(59) body = json.dumps({"nonce": nonce}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('username must be specified', channel.json_body["error"]) + self.assertEqual("username must be specified", channel.json_body["error"]) # 61 seconds self.reactor.advance(2) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('unrecognised nonce', channel.json_body["error"]) + self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_register_incorrect_nonce(self): """ @@ -154,7 +154,7 @@ def test_register_incorrect_nonce(self): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -171,7 +171,7 @@ def test_register_correct_nonce(self): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update( - nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin\x00support" + nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" ) want_mac = want_mac.hexdigest() @@ -185,7 +185,7 @@ def test_register_correct_nonce(self): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -200,7 +200,7 @@ def test_nonce_reuse(self): nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") + want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") want_mac = want_mac.hexdigest() body = json.dumps( @@ -212,18 +212,18 @@ def test_nonce_reuse(self): "mac": want_mac, } ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('unrecognised nonce', channel.json_body["error"]) + self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_missing_parts(self): """ @@ -243,11 +243,11 @@ def nonce(): # Must be present body = json.dumps({}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('nonce must be specified', channel.json_body["error"]) + self.assertEqual("nonce must be specified", channel.json_body["error"]) # # Username checks @@ -255,35 +255,35 @@ def nonce(): # Must be present body = json.dumps({"nonce": nonce()}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('username must be specified', channel.json_body["error"]) + self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": 1234}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid username', channel.json_body["error"]) + self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid username', channel.json_body["error"]) + self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid username', channel.json_body["error"]) + self.assertEqual("Invalid username", channel.json_body["error"]) # # Password checks @@ -291,37 +291,35 @@ def nonce(): # Must be present body = json.dumps({"nonce": nonce(), "username": "a"}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('password must be specified', channel.json_body["error"]) + self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid password', channel.json_body["error"]) + self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes - body = json.dumps( - {"nonce": nonce(), "username": "a", "password": u"abcd\u0000"} - ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid password', channel.json_body["error"]) + self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid password', channel.json_body["error"]) + self.assertEqual("Invalid password", channel.json_body["error"]) # # user_type check @@ -336,11 +334,11 @@ def nonce(): "user_type": "invalid", } ) - request, channel = self.make_request("POST", self.url, body.encode('utf8')) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual('Invalid user type', channel.json_body["error"]) + self.assertEqual("Invalid user type", channel.json_body["error"]) class ShutdownRoomTestCase(unittest.HomeserverTestCase): @@ -396,7 +394,7 @@ def test_shutdown_room_consent(self): url = "admin/shutdown_room/" + room_id request, channel = self.make_request( "POST", - url.encode('ascii'), + url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) @@ -421,7 +419,7 @@ def test_shutdown_room_block_peek(self): url = "rooms/%s/state/m.room.history_visibility" % (room_id,) request, channel = self.make_request( "PUT", - url.encode('ascii'), + url.encode("ascii"), json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_token, ) @@ -432,7 +430,7 @@ def test_shutdown_room_block_peek(self): url = "admin/shutdown_room/" + room_id request, channel = self.make_request( "POST", - url.encode('ascii'), + url.encode("ascii"), json.dumps({"new_room_user_id": self.admin_user}), access_token=self.admin_user_tok, ) @@ -449,7 +447,7 @@ def _assert_peek(self, room_id, expect_code): url = "rooms/%s/initialSync" % (room_id,) request, channel = self.make_request( - "GET", url.encode('ascii'), access_token=self.admin_user_tok + "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( @@ -458,7 +456,7 @@ def _assert_peek(self, room_id, expect_code): url = "events?timeout=0&room_id=" + room_id request, channel = self.make_request( - "GET", url.encode('ascii'), access_token=self.admin_user_tok + "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( @@ -486,7 +484,7 @@ def test_delete_group(self): # Create a new group request, channel = self.make_request( "POST", - "/create_group".encode('ascii'), + "/create_group".encode("ascii"), access_token=self.admin_user_tok, content={"localpart": "test"}, ) @@ -502,14 +500,14 @@ def test_delete_group(self): url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) request, channel = self.make_request( - "PUT", url.encode('ascii'), access_token=self.admin_user_tok, content={} + "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) url = "/groups/%s/self/accept_invite" % (group_id,) request, channel = self.make_request( - "PUT", url.encode('ascii'), access_token=self.other_user_token, content={} + "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -522,7 +520,7 @@ def test_delete_group(self): url = "/admin/delete_group/" + group_id request, channel = self.make_request( "POST", - url.encode('ascii'), + url.encode("ascii"), access_token=self.admin_user_tok, content={"localpart": "test"}, ) @@ -544,7 +542,7 @@ def _check_group(self, group_id, expect_code): url = "/groups/%s/profile" % (group_id,) request, channel = self.make_request( - "GET", url.encode('ascii'), access_token=self.admin_user_tok + "GET", url.encode("ascii"), access_token=self.admin_user_tok ) self.render(request) @@ -556,7 +554,7 @@ def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token) """ request, channel = self.make_request( - "GET", "/joined_groups".encode('ascii'), access_token=access_token + "GET", "/joined_groups".encode("ascii"), access_token=access_token ) self.render(request) diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index efc5a99db33c..6803b372ac04 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -42,17 +42,17 @@ def make_homeserver(self, reactor, clock): # Make some temporary templates... temp_consent_path = self.mktemp() os.mkdir(temp_consent_path) - os.mkdir(os.path.join(temp_consent_path, 'en')) + os.mkdir(os.path.join(temp_consent_path, "en")) config["user_consent"] = { "version": "1", "template_dir": os.path.abspath(temp_consent_path), } - with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f: + with open(os.path.join(temp_consent_path, "en/1.html"), "w") as f: f.write("{{version}},{{has_consented}}") - with open(os.path.join(temp_consent_path, "en/success.html"), 'w') as f: + with open(os.path.join(temp_consent_path, "en/success.html"), "w") as f: f.write("yay!") hs = self.setup_test_homeserver(config=config) @@ -88,7 +88,7 @@ def test_accept_consent(self): self.assertEqual(channel.code, 200) # Get the version from the body, and whether we've consented - version, consented = channel.result["body"].decode('ascii').split(",") + version, consented = channel.result["body"].decode("ascii").split(",") self.assertEqual(consented, "False") # POST to the consent page, saying we've agreed @@ -111,6 +111,6 @@ def test_accept_consent(self): # Get the version from the body, and check that it's the version we # agreed to, and that we've consented to it. - version, consented = channel.result["body"].decode('ascii').split(",") + version, consented = channel.result["body"].decode("ascii").split(",") self.assertEqual(consented, "True") self.assertEqual(version, "1") diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index 68949307d97f..c9735219072f 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -56,7 +56,7 @@ def test_3pid_lookup_disabled(self): "address": "test@example.com", } request_data = json.dumps(params) - request_url = ("/rooms/%s/invite" % (room_id)).encode('ascii') + request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") request, channel = self.make_request( b"POST", request_url, request_data, access_token=tok ) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 72c7ed93cb3d..dff9b2f10c55 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -183,7 +183,7 @@ def prepare(self, reactor, clock, hs): def test_set_displayname(self): request, channel = self.make_request( "PUT", - "/profile/%s/displayname" % (self.owner, ), + "/profile/%s/displayname" % (self.owner,), content=json.dumps({"displayname": "test"}), access_token=self.owner_tok, ) @@ -197,7 +197,7 @@ def test_set_displayname_too_long(self): """Attempts to set a stupid displayname should get a 400""" request, channel = self.make_request( "PUT", - "/profile/%s/displayname" % (self.owner, ), + "/profile/%s/displayname" % (self.owner,), content=json.dumps({"displayname": "test" * 100}), access_token=self.owner_tok, ) @@ -209,8 +209,7 @@ def test_set_displayname_too_long(self): def get_displayname(self): request, channel = self.make_request( - "GET", - "/profile/%s/displayname" % (self.owner, ), + "GET", "/profile/%s/displayname" % (self.owner,) ) self.render(request) self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 5f75ad757952..2e3a765bf310 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -79,7 +79,7 @@ def prepare(self, reactor, clock, hs): # send a message in one of the rooms self.created_rmid_msg_path = ( "rooms/%s/send/m.room.message/a1" % (self.created_rmid) - ).encode('ascii') + ).encode("ascii") request, channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) @@ -89,7 +89,7 @@ def prepare(self, reactor, clock, hs): # set topic for public room request, channel = self.make_request( "PUT", - ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'), + ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) self.render(request) @@ -193,7 +193,7 @@ def test_topic_perms(self): request, channel = self.make_request("GET", topic_path) self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assert_dict(json.loads(topic_content.decode('utf8')), channel.json_body) + self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) @@ -497,7 +497,7 @@ def prepare(self, reactor, clock, hs): def test_invalid_puts(self): # missing keys or invalid json - request, channel = self.make_request("PUT", self.path, '{}') + request, channel = self.make_request("PUT", self.path, "{}") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -515,11 +515,11 @@ def test_invalid_puts(self): self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, 'text only') + request, channel = self.make_request("PUT", self.path, "text only") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", self.path, '') + request, channel = self.make_request("PUT", self.path, "") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -572,7 +572,7 @@ def prepare(self, reactor, clock, hs): def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - request, channel = self.make_request("PUT", path, '{}') + request, channel = self.make_request("PUT", path, "{}") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -590,11 +590,11 @@ def test_invalid_puts(self): self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, 'text only') + request, channel = self.make_request("PUT", path, "text only") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, '') + request, channel = self.make_request("PUT", path, "") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -604,7 +604,7 @@ def test_invalid_puts(self): Membership.JOIN, Membership.LEAVE, ) - request, channel = self.make_request("PUT", path, content.encode('ascii')) + request, channel = self.make_request("PUT", path, content.encode("ascii")) self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -616,7 +616,7 @@ def test_rooms_members_self(self): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN - request, channel = self.make_request("PUT", path, content.encode('ascii')) + request, channel = self.make_request("PUT", path, content.encode("ascii")) self.render(request) self.assertEquals(200, channel.code, msg=channel.result["body"]) @@ -678,7 +678,7 @@ def prepare(self, reactor, clock, hs): def test_invalid_puts(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json - request, channel = self.make_request("PUT", path, b'{}') + request, channel = self.make_request("PUT", path, b"{}") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -696,11 +696,11 @@ def test_invalid_puts(self): self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b'text only') + request, channel = self.make_request("PUT", path, b"text only") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = self.make_request("PUT", path, b'') + request, channel = self.make_request("PUT", path, b"") self.render(request) self.assertEquals(400, channel.code, msg=channel.result["body"]) @@ -786,7 +786,7 @@ def test_topo_token_is_accepted(self): self.render(request) self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body['start']) + self.assertEquals(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -798,7 +798,7 @@ def test_stream_token_is_accepted_for_fwd_pagianation(self): self.render(request) self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) - self.assertEquals(token, channel.json_body['start']) + self.assertEquals(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) @@ -961,9 +961,7 @@ def prepare(self, reactor, clock, homeserver): # Set a profile for the test user self.displayname = "test user" - data = { - "displayname": self.displayname, - } + data = {"displayname": self.displayname} request_data = json.dumps(data) request, channel = self.make_request( "PUT", @@ -977,16 +975,12 @@ def prepare(self, reactor, clock, homeserver): self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) def test_per_room_profile_forbidden(self): - data = { - "membership": "join", - "displayname": "other test user" - } + data = {"membership": "join", "displayname": "other test user"} request_data = json.dumps(data) request, channel = self.make_request( "PUT", - "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( - self.room_id, self.user_id, - ), + "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" + % (self.room_id, self.user_id), request_data, access_token=self.tok, ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index f7133fc12ea3..991536714425 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -44,7 +44,7 @@ def create_room_as(self, room_creator, is_public=True, tok=None): path = path + "?access_token=%s" % tok request, channel = make_request( - self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8') + self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8") ) render(request, self.resource, self.hs.get_reactor()) @@ -93,7 +93,7 @@ def change_membership(self, room, src, targ, membership, tok=None, expect_code=2 data = {"membership": membership} request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8') + self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8") ) render(request, self.resource, self.hs.get_reactor()) @@ -117,7 +117,7 @@ def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): path = path + "?access_token=%s" % tok request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8') + self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8") ) render(request, self.resource, self.hs.get_reactor()) @@ -134,7 +134,7 @@ def send_state(self, room_id, event_type, body, tok, expect_code=200): path = path + "?access_token=%s" % tok request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps(body).encode('utf8') + self.hs.get_reactor(), "PUT", path, json.dumps(body).encode("utf8") ) render(request, self.resource, self.hs.get_reactor()) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index a60a4a3b875b..920de41de476 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -135,9 +135,7 @@ def test_cant_reset_password_without_clicking_link(self): self.assertEquals(len(self.email_attempts), 1) # Attempt to reset password without clicking the link - self._reset_password( - new_password, session_id, client_secret, expected_code=401, - ) + self._reset_password(new_password, session_id, client_secret, expected_code=401) # Assert we can log in with the old password self.login("kermit", old_password) @@ -172,9 +170,7 @@ def test_no_valid_token(self): session_id = "weasle" # Attempt to reset password without even requesting an email - self._reset_password( - new_password, session_id, client_secret, expected_code=401, - ) + self._reset_password(new_password, session_id, client_secret, expected_code=401) # Assert we can log in with the old password self.login("kermit", old_password) @@ -258,19 +254,18 @@ def test_deactivate_account(self): user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") - request_data = json.dumps({ - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "test", - }, - "erase": False, - }) + request_data = json.dumps( + { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "test", + }, + "erase": False, + } + ) request, channel = self.make_request( - "POST", - "account/deactivate", - request_data, - access_token=tok, + "POST", "account/deactivate", request_data, access_token=tok ) self.render(request) self.assertEqual(request.code, 200) diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index bce5b0cf4c70..b9e01c9418f1 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -47,15 +47,15 @@ def test_get_room_version_capabilities(self): request, channel = self.make_request("GET", self.url, access_token=access_token) self.render(request) - capabilities = channel.json_body['capabilities'] + capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) - for room_version in capabilities['m.room_versions']['available'].keys(): + for room_version in capabilities["m.room_versions"]["available"].keys(): self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) self.assertEqual( self.config.default_room_version.identifier, - capabilities['m.room_versions']['default'], + capabilities["m.room_versions"]["default"], ) def test_get_change_password_capabilities(self): @@ -66,16 +66,16 @@ def test_get_change_password_capabilities(self): request, channel = self.make_request("GET", self.url, access_token=access_token) self.render(request) - capabilities = channel.json_body['capabilities'] + capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) # Test case where password is handled outside of Synapse - self.assertTrue(capabilities['m.change_password']['enabled']) + self.assertTrue(capabilities["m.change_password"]["enabled"]) self.get_success(self.store.user_set_password_hash(user, None)) request, channel = self.make_request("GET", self.url, access_token=access_token) self.render(request) - capabilities = channel.json_body['capabilities'] + capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) - self.assertFalse(capabilities['m.change_password']['enabled']) + self.assertFalse(capabilities["m.change_password"]["enabled"]) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index b35b21544678..89a3f95c0a8b 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -335,7 +335,7 @@ def sendmail(*args, **kwargs): config["email"] = { "enable_notifs": True, "template_dir": os.path.abspath( - pkg_resources.resource_filename('synapse', 'res/templates') + pkg_resources.resource_filename("synapse", "res/templates") ), "expiry_template_html": "notice_expiry.html", "expiry_template_text": "notice_expiry.txt", @@ -400,19 +400,18 @@ def test_deactivated_user(self): (user_id, tok) = self.create_user() - request_data = json.dumps({ - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "monkey", - }, - "erase": False, - }) + request_data = json.dumps( + { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "monkey", + }, + "erase": False, + } + ) request, channel = self.make_request( - "POST", - "account/deactivate", - request_data, - access_token=tok, + "POST", "account/deactivate", request_data, access_token=tok ) self.render(request) self.assertEqual(request.code, 200) @@ -476,20 +475,16 @@ def test_manual_email_send_expired_account(self): class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - ] + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] def make_homeserver(self, reactor, clock): self.validity_period = 10 - self.max_delta = self.validity_period * 10. / 100. + self.max_delta = self.validity_period * 10.0 / 100.0 config = self.default_config() config["enable_registration"] = True - config["account_validity"] = { - "enabled": False, - } + config["account_validity"] = {"enabled": False} self.hs = self.setup_test_homeserver(config=config) self.hs.config.account_validity.period = self.validity_period diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 43b3049daaf6..3deeed3a70f6 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -56,7 +56,7 @@ def test_send_relation(self): creates the right shape of event. """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key=u"👍") + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") self.assertEquals(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] @@ -76,7 +76,7 @@ def test_send_relation(self): "content": { "m.relates_to": { "event_id": self.parent_id, - "key": u"👍", + "key": "👍", "rel_type": RelationTypes.ANNOTATION, } }, @@ -187,7 +187,7 @@ def test_aggregation_pagination_groups(self): access_tokens.append(token) idx = 0 - sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1} + sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} for key in itertools.chain.from_iterable( itertools.repeat(key, num) for key, num in sent_groups.items() ): @@ -259,7 +259,7 @@ def test_aggregation_pagination_within_group(self): channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", - key=u"👍", + key="👍", access_token=access_tokens[idx], ) self.assertEquals(200, channel.code, channel.json_body) @@ -273,7 +273,7 @@ def test_aggregation_pagination_within_group(self): prev_token = None found_event_ids = [] - encoded_key = six.moves.urllib.parse.quote_plus(u"👍".encode("utf-8")) + encoded_key = six.moves.urllib.parse.quote_plus("👍".encode("utf-8")) for _ in range(20): from_token = "" if prev_token: diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py index 00688a732521..ebd78692082f 100644 --- a/tests/rest/media/v1/test_base.py +++ b/tests/rest/media/v1/test_base.py @@ -21,17 +21,17 @@ class GetFileNameFromHeadersTests(unittest.TestCase): # input -> expected result TEST_CASES = { - b"inline; filename=abc.txt": u"abc.txt", - b'inline; filename="azerty"': u"azerty", - b'inline; filename="aze%20rty"': u"aze%20rty", - b'inline; filename="aze\"rty"': u'aze"rty', - b'inline; filename="azer;ty"': u"azer;ty", - b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar", + b"inline; filename=abc.txt": "abc.txt", + b'inline; filename="azerty"': "azerty", + b'inline; filename="aze%20rty"': "aze%20rty", + b'inline; filename="aze"rty"': 'aze"rty', + b'inline; filename="azer;ty"': "azer;ty", + b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar", } def tests(self): for hdr, expected in self.TEST_CASES.items(): - res = get_filename_from_headers({b'Content-Disposition': [hdr]}) + res = get_filename_from_headers({b"Content-Disposition": [hdr]}) self.assertEqual( res, expected, diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 1069a44145b9..e2d418b1df0c 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -143,7 +143,7 @@ def write_to(r): def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() - self.download_resource = self.media_repo.children[b'download'] + self.download_resource = self.media_repo.children[b"download"] # smol png self.end_content = unhexlify( @@ -171,7 +171,7 @@ def _req(self, content_disposition): headers = { b"Content-Length": [b"%d" % (len(self.end_content))], - b"Content-Type": [b'image/png'], + b"Content-Type": [b"image/png"], } if content_disposition: headers[b"Content-Disposition"] = [content_disposition] @@ -204,7 +204,7 @@ def test_disposition_filenamestar_utf8escaped(self): correctly decode it as the UTF-8 string, and use filename* in the response. """ - filename = parse.quote(u"\u2603".encode('utf8')).encode('ascii') + filename = parse.quote("\u2603".encode("utf8")).encode("ascii") channel = self._req(b"inline; filename*=utf-8''" + filename + b".png") headers = channel.headers diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 1ab0f7293afb..8fe596186640 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -55,10 +55,10 @@ class URLPreviewTests(unittest.HomeserverTestCase): hijack_auth = True user_id = "@test:user" end_content = ( - b'' + b"" b'' b'' - b'' + b"" ) def make_homeserver(self, reactor, clock): @@ -98,7 +98,7 @@ def make_homeserver(self, reactor, clock): def prepare(self, reactor, clock, hs): self.media_repo = hs.get_media_repository_resource() - self.preview_url = self.media_repo.children[b'preview_url'] + self.preview_url = self.media_repo.children[b"preview_url"] self.lookups = {} @@ -109,7 +109,7 @@ def resolveHostName( hostName, portNumber=0, addressTypes=None, - transportSemantics='TCP', + transportSemantics="TCP", ): resolution = HostResolution(hostName) @@ -118,7 +118,7 @@ def resolveHostName( raise DNSLookupError("OH NO") for i in self.lookups[hostName]: - resolutionReceiver.addressResolved(i[0]('TCP', i[1], portNumber)) + resolutionReceiver.addressResolved(i[0]("TCP", i[1], portNumber)) resolutionReceiver.resolutionComplete() return resolutionReceiver @@ -184,11 +184,11 @@ def test_non_ascii_preview_httpequiv(self): self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] end_content = ( - b'' + b"" b'' b'' b'' - b'' + b"" ) request, channel = self.make_request( @@ -204,7 +204,7 @@ def test_non_ascii_preview_httpequiv(self): client.dataReceived( ( b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b"Content-Type: text/html; charset=\"utf8\"\r\n\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' ) % (len(end_content),) + end_content @@ -212,16 +212,16 @@ def test_non_ascii_preview_httpequiv(self): self.pump() self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") + self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") def test_non_ascii_preview_content_type(self): self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] end_content = ( - b'' + b"" b'' b'' - b'' + b"" ) request, channel = self.make_request( @@ -237,7 +237,7 @@ def test_non_ascii_preview_content_type(self): client.dataReceived( ( b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b"Content-Type: text/html; charset=\"windows-1251\"\r\n\r\n" + b'Content-Type: text/html; charset="windows-1251"\r\n\r\n' ) % (len(end_content),) + end_content @@ -245,7 +245,7 @@ def test_non_ascii_preview_content_type(self): self.pump() self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") + self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") def test_ipaddr(self): """ @@ -293,8 +293,8 @@ def test_blacklisted_ip_specific(self): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) @@ -314,8 +314,8 @@ def test_blacklisted_ip_range(self): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) @@ -334,8 +334,8 @@ def test_blacklisted_ip_specific_direct(self): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + "errcode": "M_UNKNOWN", + "error": "IP address blocked by IP blacklist entry", }, ) self.assertEqual(channel.code, 403) @@ -354,8 +354,8 @@ def test_blacklisted_ip_range_direct(self): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'IP address blocked by IP blacklist entry', + "errcode": "M_UNKNOWN", + "error": "IP address blocked by IP blacklist entry", }, ) @@ -396,7 +396,7 @@ def test_blacklisted_ip_with_external_ip(self): non-blacklisted one, it will be rejected. """ # Hardcode the URL resolving to the IP we want. - self.lookups[u"example.com"] = [ + self.lookups["example.com"] = [ (IPv4Address, "1.1.1.2"), (IPv4Address, "8.8.8.8"), ] @@ -410,8 +410,8 @@ def test_blacklisted_ip_with_external_ip(self): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) @@ -435,8 +435,8 @@ def test_blacklisted_ipv6_specific(self): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) @@ -456,7 +456,7 @@ def test_blacklisted_ipv6_range(self): self.assertEqual( channel.json_body, { - 'errcode': 'M_UNKNOWN', - 'error': 'DNS resolution failure during URL preview generation', + "errcode": "M_UNKNOWN", + "error": "DNS resolution failure during URL preview generation", }, ) diff --git a/tests/server.py b/tests/server.py index c15a47f2a405..e573c4e4c5ce 100644 --- a/tests/server.py +++ b/tests/server.py @@ -46,7 +46,7 @@ class FakeChannel(object): def json_body(self): if not self.result: raise Exception("No result yet.") - return json.loads(self.result["body"].decode('utf8')) + return json.loads(self.result["body"].decode("utf8")) @property def code(self): @@ -151,10 +151,10 @@ def make_request( Tuple[synapse.http.site.SynapseRequest, channel] """ if not isinstance(method, bytes): - method = method.encode('ascii') + method = method.encode("ascii") if not isinstance(path, bytes): - path = path.encode('ascii') + path = path.encode("ascii") # Decorate it to be the full path, if we're using shorthand if shorthand and not path.startswith(b"/_matrix"): @@ -165,7 +165,7 @@ def make_request( path = b"/" + path if isinstance(content, text_type): - content = content.encode('utf8') + content = content.encode("utf8") site = FakeSite() channel = FakeChannel(reactor) @@ -173,11 +173,11 @@ def make_request( req = request(site, channel) req.process = lambda: b"" req.content = BytesIO(content) - req.postpath = list(map(unquote, path[1:].split(b'/'))) + req.postpath = list(map(unquote, path[1:].split(b"/"))) if access_token: req.requestHeaders.addRawHeader( - b"Authorization", b"Bearer " + access_token.encode('ascii') + b"Authorization", b"Bearer " + access_token.encode("ascii") ) if federation_auth_origin is not None: @@ -242,7 +242,7 @@ def getHostByName(self, name, timeout=None): self.nameResolver = SimpleResolverComplexifier(FakeResolver()) super(ThreadedMemoryReactorClock, self).__init__() - def listenUDP(self, port, protocol, interface='', maxPacketSize=8196): + def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): p = udp.Port(port, protocol, interface, maxPacketSize, self) p.startListening() self._udp.append(p) @@ -371,7 +371,7 @@ class FakeTransport(object): disconnecting = False disconnected = False - buffer = attr.ib(default=b'') + buffer = attr.ib(default=b"") producer = attr.ib(default=None) autoflush = attr.ib(default=True) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 739ee59ce4b2..984feb623f43 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -109,7 +109,7 @@ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): Test when user has blocked notice, but notice ought to be there (NOOP) """ self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, 'foo') + side_effect=ResourceLimitError(403, "foo") ) mock_event = Mock( @@ -128,7 +128,7 @@ def test_maybe_send_server_notice_to_user_add_blocked_notice(self): """ self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, 'foo') + side_effect=ResourceLimitError(403, "foo") ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 9c5311d916ad..8d3845c870c6 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -181,7 +181,7 @@ def test_ban_vs_pl(self): id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), ] @@ -229,14 +229,14 @@ def test_offtopic_pl(self): id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50, CHARLIE: 50}}, ), FakeEvent( id="PC", sender=CHARLIE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50, CHARLIE: 0}}, ), ] @@ -256,7 +256,7 @@ def test_topic_basic(self): id="PA1", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -266,14 +266,14 @@ def test_topic_basic(self): id="PA2", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -296,7 +296,7 @@ def test_topic_reset(self): id="PA", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -326,7 +326,7 @@ def test_topic(self): id="PA1", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -336,14 +336,14 @@ def test_topic(self): id="PA2", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 25a6c89ef5c1..622b16a071a2 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -74,7 +74,7 @@ def _add_appservice(self, as_token, id, url, hs_token, sender): namespaces={}, ) # use the token as the filename - with open(as_token, 'w') as outfile: + with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) @@ -135,7 +135,7 @@ def _add_service(self, url, as_token, id): namespaces={}, ) # use the token as the filename - with open(as_token, 'w') as outfile: + with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index b62eae7abcaa..59c6f8c227a0 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -94,11 +94,11 @@ def test_insert_new_client_ip_none_device_id(self): result, [ { - 'access_token': 'access_token', - 'ip': 'ip', - 'user_agent': 'user_agent', - 'device_id': None, - 'last_seen': 12345678000, + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "device_id": None, + "last_seen": 12345678000, } ], ) @@ -125,11 +125,11 @@ def test_insert_new_client_ip_none_device_id(self): result, [ { - 'access_token': 'access_token', - 'ip': 'ip', - 'user_agent': 'user_agent', - 'device_id': None, - 'last_seen': 12345878000, + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "device_id": None, + "last_seen": 12345878000, } ], ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 6396ccddb52b..3cc18f9f1cfe 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -77,12 +77,12 @@ def test_get_devices_by_remote(self): # Add two device updates with a single stream_id yield self.store.add_device_change_to_streams( - "user_id", device_ids, ["somehost"], + "user_id", device_ids, ["somehost"] ) # Get all device updates ever meant for this remote now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "somehost", -1, limit=100, + "somehost", -1, limit=100 ) # Check original device_ids are contained within these updates @@ -95,19 +95,19 @@ def test_get_devices_by_remote_limited(self): # first add one device device_ids1 = ["device_id0"] yield self.store.add_device_change_to_streams( - "user_id", device_ids1, ["someotherhost"], + "user_id", device_ids1, ["someotherhost"] ) # then add 101 device_ids2 = ["device_id" + str(i + 1) for i in range(101)] yield self.store.add_device_change_to_streams( - "user_id", device_ids2, ["someotherhost"], + "user_id", device_ids2, ["someotherhost"] ) # then one more device_ids3 = ["newdevice"] yield self.store.add_device_change_to_streams( - "user_id", device_ids3, ["someotherhost"], + "user_id", device_ids3, ["someotherhost"] ) # @@ -116,20 +116,20 @@ def test_get_devices_by_remote_limited(self): # first we should get a single update now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", -1, limit=100, + "someotherhost", -1, limit=100 ) self._check_devices_in_updates(device_ids1, device_updates) # Then we should get an empty list back as the 101 devices broke the limit now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", now_stream_id, limit=100, + "someotherhost", now_stream_id, limit=100 ) self.assertEqual(len(device_updates), 0) # The 101 devices should've been cleared, so we should now just get one device # update now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", now_stream_id, limit=100, + "someotherhost", now_stream_id, limit=100 ) self._check_devices_in_updates(device_ids3, device_updates) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index cd2bcd4ca301..c8ece1528456 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -80,10 +80,10 @@ def test_multiple_devices(self): yield self.store.store_device("user2", "device1", None) yield self.store.store_device("user2", "device2", None) - yield self.store.set_e2e_device_keys("user1", "device1", now, 'json11') - yield self.store.set_e2e_device_keys("user1", "device2", now, 'json12') - yield self.store.set_e2e_device_keys("user2", "device1", now, 'json21') - yield self.store.set_e2e_device_keys("user2", "device2", now, 'json22') + yield self.store.set_e2e_device_keys("user1", "device1", now, "json11") + yield self.store.set_e2e_device_keys("user1", "device2", now, "json12") + yield self.store.set_e2e_device_keys("user2", "device1", now, "json21") + yield self.store.set_e2e_device_keys("user2", "device2", now, "json22") res = yield self.store.get_e2e_device_keys( (("user1", "device1"), ("user2", "device2")) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 0d4e74d63719..86c7ac350d49 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -27,11 +27,11 @@ def setUp(self): @defer.inlineCallbacks def test_get_prev_events_for_room(self): - room_id = '@ROOM:local' + room_id = "@ROOM:local" # add a bunch of events and hashes to act as forward extremities def insert_event(txn, i): - event_id = '$event_%i:local' % i + event_id = "$event_%i:local" % i txn.execute( ( @@ -45,19 +45,19 @@ def insert_event(txn, i): txn.execute( ( - 'INSERT INTO event_forward_extremities (room_id, event_id) ' - 'VALUES (?, ?)' + "INSERT INTO event_forward_extremities (room_id, event_id) " + "VALUES (?, ?)" ), (room_id, event_id), ) txn.execute( ( - 'INSERT INTO event_reference_hashes ' - '(event_id, algorithm, hash) ' + "INSERT INTO event_reference_hashes " + "(event_id, algorithm, hash) " "VALUES (?, 'sha256', ?)" ), - (event_id, b'ffff'), + (event_id, b"ffff"), ) for i in range(0, 11): diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 19f9ccf5e00d..d44359ff9333 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -61,22 +61,24 @@ def test_exposed_to_prometheus(self): ) ) - expected = set([ - b'synapse_forward_extremities_bucket{le="1.0"} 0.0', - b'synapse_forward_extremities_bucket{le="2.0"} 2.0', - b'synapse_forward_extremities_bucket{le="3.0"} 2.0', - b'synapse_forward_extremities_bucket{le="5.0"} 2.0', - b'synapse_forward_extremities_bucket{le="7.0"} 3.0', - b'synapse_forward_extremities_bucket{le="10.0"} 3.0', - b'synapse_forward_extremities_bucket{le="15.0"} 3.0', - b'synapse_forward_extremities_bucket{le="20.0"} 3.0', - b'synapse_forward_extremities_bucket{le="50.0"} 3.0', - b'synapse_forward_extremities_bucket{le="100.0"} 3.0', - b'synapse_forward_extremities_bucket{le="200.0"} 3.0', - b'synapse_forward_extremities_bucket{le="500.0"} 3.0', - b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', - b'synapse_forward_extremities_count 3.0', - b'synapse_forward_extremities_sum 10.0', - ]) + expected = set( + [ + b'synapse_forward_extremities_bucket{le="1.0"} 0.0', + b'synapse_forward_extremities_bucket{le="2.0"} 2.0', + b'synapse_forward_extremities_bucket{le="3.0"} 2.0', + b'synapse_forward_extremities_bucket{le="5.0"} 2.0', + b'synapse_forward_extremities_bucket{le="7.0"} 3.0', + b'synapse_forward_extremities_bucket{le="10.0"} 3.0', + b'synapse_forward_extremities_bucket{le="15.0"} 3.0', + b'synapse_forward_extremities_bucket{le="20.0"} 3.0', + b'synapse_forward_extremities_bucket{le="50.0"} 3.0', + b'synapse_forward_extremities_bucket{le="100.0"} 3.0', + b'synapse_forward_extremities_bucket{le="200.0"} 3.0', + b'synapse_forward_extremities_bucket{le="500.0"} 3.0', + b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', + b"synapse_forward_extremities_count 3.0", + b"synapse_forward_extremities_sum 10.0", + ] + ) self.assertEqual(items, expected) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index f458c03054f0..0ce0b991f98a 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -46,9 +46,9 @@ def test_initialise_reserved_users(self): user3_email = "user3@matrix.org" threepids = [ - {'medium': 'email', 'address': user1_email}, - {'medium': 'email', 'address': user2_email}, - {'medium': 'email', 'address': user3_email}, + {"medium": "email", "address": user1_email}, + {"medium": "email", "address": user2_email}, + {"medium": "email", "address": user3_email}, ] # -1 because user3 is a support user and does not count user_num = len(threepids) - 1 @@ -177,7 +177,7 @@ def test_populate_monthly_users_should_update(self): self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) - self.store.populate_monthly_active_users('user_id') + self.store.populate_monthly_active_users("user_id") self.pump() self.store.upsert_monthly_active_user.assert_called_once() @@ -188,7 +188,7 @@ def test_populate_monthly_users_should_not_update(self): self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) - self.store.populate_monthly_active_users('user_id') + self.store.populate_monthly_active_users("user_id") self.pump() self.store.upsert_monthly_active_user.assert_not_called() @@ -198,13 +198,13 @@ def test_get_reserved_real_user_account(self): self.assertEquals(self.get_success(count), 0) # Test reserved users but no registered users - user1 = '@user1:example.com' - user2 = '@user2:example.com' - user1_email = 'user1@example.com' - user2_email = 'user2@example.com' + user1 = "@user1:example.com" + user2 = "@user2:example.com" + user1_email = "user1@example.com" + user2_email = "user2@example.com" threepids = [ - {'medium': 'email', 'address': user1_email}, - {'medium': 'email', 'address': user2_email}, + {"medium": "email", "address": user1_email}, + {"medium": "email", "address": user2_email}, ] self.hs.config.mau_limits_reserved_threepids = threepids self.store.runInteraction( diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 4823d44decfc..732a778fabca 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -82,7 +82,7 @@ def inject_message(self, room, user, body): "sender": user.to_string(), "state_key": user.to_string(), "room_id": room.to_string(), - "content": {"body": body, "msgtype": u"message"}, + "content": {"body": body, "msgtype": "message"}, }, ) @@ -118,7 +118,7 @@ def inject_redaction(self, room, event_id, user, reason): def test_redact(self): yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - msg_event = yield self.inject_message(self.room1, self.u_alice, u"t") + msg_event = yield self.inject_message(self.room1, self.u_alice, "t") # Check event has not been redacted: event = yield self.store.get_event(msg_event.event_id) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index c0e0155bb4a6..625b651e9123 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -128,4 +128,4 @@ def __init__(self): def generate(self, user_id): self._last_issued_token += 1 - return u"%s-%d" % (user_id, self._last_issued_token) + return "%s-%d" % (user_id, self._last_issued_token) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index a1ea23b06896..1bee45706f0b 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -78,7 +78,7 @@ def inject_room_event(self, **kwargs): @defer.inlineCallbacks def STALE_test_room_name(self): - name = u"A-Room-Name" + name = "A-Room-Name" yield self.inject_room_event( etype=EventTypes.Name, name=name, content={"name": name}, depth=1 @@ -94,7 +94,7 @@ def STALE_test_room_name(self): @defer.inlineCallbacks def STALE_test_room_topic(self): - topic = u"A place for things" + topic = "A place for things" yield self.inject_room_event( etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index b6169436de49..212a7ae765d9 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -76,10 +76,10 @@ def assertStateMapEqual(self, s1, s2): @defer.inlineCallbacks def test_get_state_groups_ids(self): e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, '', {} + self.room, self.u_alice, EventTypes.Create, "", {} ) e2 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) state_group_map = yield self.store.get_state_groups_ids( @@ -89,16 +89,16 @@ def test_get_state_groups_ids(self): state_map = list(state_group_map.values())[0] self.assertDictEqual( state_map, - {(EventTypes.Create, ''): e1.event_id, (EventTypes.Name, ''): e2.event_id}, + {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id}, ) @defer.inlineCallbacks def test_get_state_groups(self): e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, '', {} + self.room, self.u_alice, EventTypes.Create, "", {} ) e2 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) @@ -113,10 +113,10 @@ def test_get_state_for_event(self): # this defaults to a linear DAG as each new injection defaults to whatever # forward extremities are currently in the DB for this room. e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, '', {} + self.room, self.u_alice, EventTypes.Create, "", {} ) e2 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) e3 = yield self.inject_state_event( self.room, @@ -158,7 +158,7 @@ def test_get_state_for_event(self): # check we can filter to the m.room.name event (with a '' state key) state = yield self.store.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, '')]) + e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) diff --git a/tests/test_preview.py b/tests/test_preview.py index 84ef5e5ba4b6..7f67ee9e1f6f 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -24,14 +24,14 @@ class PreviewTestCase(unittest.TestCase): def test_long_summarize(self): example_paras = [ - u"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: + """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in Troms county, Norway. The administrative centre of the municipality is the city of Tromsø. Outside of Norway, Tromso and Tromsö are alternative spellings of the city.Tromsø is considered the northernmost city in the world with a population above 50,000. The most populous town north of it is Alta, Norway, with a population of 14,272 (2013).""", - u"""Tromsø lies in Northern Norway. The municipality has a population of + """Tromsø lies in Northern Norway. The municipality has a population of (2015) 72,066, but with an annual influx of students it has over 75,000 most of the year. It is the largest urban area in Northern Norway and the third largest north of the Arctic Circle (following Murmansk and Norilsk). @@ -44,7 +44,7 @@ def test_long_summarize(self): Sandnessund Bridge. Tromsø Airport connects the city to many destinations in Europe. The city is warmer than most other places located on the same latitude, due to the warming effect of the Gulf Stream.""", - u"""The city centre of Tromsø contains the highest number of old wooden + """The city centre of Tromsø contains the highest number of old wooden houses in Northern Norway, the oldest house dating from 1789. The Arctic Cathedral, a modern church from 1965, is probably the most famous landmark in Tromsø. The city is a cultural centre for its region, with several @@ -58,87 +58,87 @@ def test_long_summarize(self): self.assertEquals( desc, - u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - u" Troms county, Norway. The administrative centre of the municipality is" - u" the city of Tromsø. Outside of Norway, Tromso and Tromsö are" - u" alternative spellings of the city.Tromsø is considered the northernmost" - u" city in the world with a population above 50,000. The most populous town" - u" north of it is Alta, Norway, with a population of 14,272 (2013).", + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway. The administrative centre of the municipality is" + " the city of Tromsø. Outside of Norway, Tromso and Tromsö are" + " alternative spellings of the city.Tromsø is considered the northernmost" + " city in the world with a population above 50,000. The most populous town" + " north of it is Alta, Norway, with a population of 14,272 (2013).", ) desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) self.assertEquals( desc, - u"Tromsø lies in Northern Norway. The municipality has a population of" - u" (2015) 72,066, but with an annual influx of students it has over 75,000" - u" most of the year. It is the largest urban area in Northern Norway and the" - u" third largest north of the Arctic Circle (following Murmansk and Norilsk)." - u" Most of Tromsø, including the city centre, is located on the island of" - u" Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," - u" Tromsøya had a population of 36,088. Substantial parts of the urban…", + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year. It is the largest urban area in Northern Norway and the" + " third largest north of the Arctic Circle (following Murmansk and Norilsk)." + " Most of Tromsø, including the city centre, is located on the island of" + " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," + " Tromsøya had a population of 36,088. Substantial parts of the urban…", ) def test_short_summarize(self): example_paras = [ - u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - u" Troms county, Norway.", - u"Tromsø lies in Northern Norway. The municipality has a population of" - u" (2015) 72,066, but with an annual influx of students it has over 75,000" - u" most of the year.", - u"The city centre of Tromsø contains the highest number of old wooden" - u" houses in Northern Norway, the oldest house dating from 1789. The Arctic" - u" Cathedral, a modern church from 1965, is probably the most famous landmark" - u" in Tromsø.", + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.", + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year.", + "The city centre of Tromsø contains the highest number of old wooden" + " houses in Northern Norway, the oldest house dating from 1789. The Arctic" + " Cathedral, a modern church from 1965, is probably the most famous landmark" + " in Tromsø.", ] desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) self.assertEquals( desc, - u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - u" Troms county, Norway.\n" - u"\n" - u"Tromsø lies in Northern Norway. The municipality has a population of" - u" (2015) 72,066, but with an annual influx of students it has over 75,000" - u" most of the year.", + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.\n" + "\n" + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year.", ) def test_small_then_large_summarize(self): example_paras = [ - u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - u" Troms county, Norway.", - u"Tromsø lies in Northern Norway. The municipality has a population of" - u" (2015) 72,066, but with an annual influx of students it has over 75,000" - u" most of the year." - u" The city centre of Tromsø contains the highest number of old wooden" - u" houses in Northern Norway, the oldest house dating from 1789. The Arctic" - u" Cathedral, a modern church from 1965, is probably the most famous landmark" - u" in Tromsø.", + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.", + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year." + " The city centre of Tromsø contains the highest number of old wooden" + " houses in Northern Norway, the oldest house dating from 1789. The Arctic" + " Cathedral, a modern church from 1965, is probably the most famous landmark" + " in Tromsø.", ] desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) self.assertEquals( desc, - u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" - u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" - u" Troms county, Norway.\n" - u"\n" - u"Tromsø lies in Northern Norway. The municipality has a population of" - u" (2015) 72,066, but with an annual influx of students it has over 75,000" - u" most of the year. The city centre of Tromsø contains the highest number" - u" of old wooden houses in Northern Norway, the oldest house dating from" - u" 1789. The Arctic Cathedral, a modern church from…", + "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" + " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" + " Troms county, Norway.\n" + "\n" + "Tromsø lies in Northern Norway. The municipality has a population of" + " (2015) 72,066, but with an annual influx of students it has over 75,000" + " most of the year. The city centre of Tromsø contains the highest number" + " of old wooden houses in Northern Norway, the oldest house dating from" + " 1789. The Arctic Cathedral, a modern church from…", ) class PreviewUrlTestCase(unittest.TestCase): def test_simple(self): - html = u""" + html = """ Foo @@ -149,10 +149,10 @@ def test_simple(self): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."}) + self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment(self): - html = u""" + html = """ Foo @@ -164,10 +164,10 @@ def test_comment(self): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."}) + self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment2(self): - html = u""" + html = """ Foo @@ -185,13 +185,13 @@ def test_comment2(self): self.assertEquals( og, { - u"og:title": u"Foo", - u"og:description": u"Some text.\n\nSome more text.\n\nText\n\nMore text", + "og:title": "Foo", + "og:description": "Some text.\n\nSome more text.\n\nText\n\nMore text", }, ) def test_script(self): - html = u""" + html = """ Foo @@ -203,10 +203,10 @@ def test_script(self): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."}) + self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."}) def test_missing_title(self): - html = u""" + html = """ Some text. @@ -216,10 +216,10 @@ def test_missing_title(self): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {u"og:title": None, u"og:description": u"Some text."}) + self.assertEquals(og, {"og:title": None, "og:description": "Some text."}) def test_h1_as_title(self): - html = u""" + html = """ @@ -230,10 +230,10 @@ def test_h1_as_title(self): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {u"og:title": u"Title", u"og:description": u"Some text."}) + self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."}) def test_missing_title_and_broken_h1(self): - html = u""" + html = """

@@ -244,4 +244,4 @@ def test_missing_title_and_broken_h1(self): og = decode_and_calc_og(html, "http://example.com/test.html") - self.assertEquals(og, {u"og:title": None, u"og:description": u"Some text."}) + self.assertEquals(og, {"og:title": None, "og:description": "Some text."}) diff --git a/tests/test_server.py b/tests/test_server.py index 08fb3fe02f27..da29ae92ced7 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -69,8 +69,8 @@ def _callback(request, **kwargs): ) render(request, res, self.reactor) - self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) - self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"}) + self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]}) + self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"}) def test_callback_direct_exception(self): """ @@ -87,7 +87,7 @@ def _callback(request, **kwargs): request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) - self.assertEqual(channel.result["code"], b'500') + self.assertEqual(channel.result["code"], b"500") def test_callback_indirect_exception(self): """ @@ -110,7 +110,7 @@ def _callback(request, **kwargs): request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) - self.assertEqual(channel.result["code"], b'500') + self.assertEqual(channel.result["code"], b"500") def test_callback_synapseerror(self): """ @@ -127,7 +127,7 @@ def _callback(request, **kwargs): request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") render(request, res, self.reactor) - self.assertEqual(channel.result["code"], b'403') + self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["error"], "Forbidden!!one!") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -148,7 +148,7 @@ def _callback(request, **kwargs): request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") render(request, res, self.reactor) - self.assertEqual(channel.result["code"], b'400') + self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") @@ -180,7 +180,7 @@ def render(self, request): # Make a resource and a Site, the resource will hang and allow us to # time out the request while it's 'processing' base_resource = Resource() - base_resource.putChild(b'', HangingResource()) + base_resource.putChild(b"", HangingResource()) site = SynapseSite("test", "site_tag", {}, base_resource, "1.0") server = site.buildProtocol(None) diff --git a/tests/test_state.py b/tests/test_state.py index 6491a7105a55..6d33566f474d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -366,11 +366,11 @@ def test_branch_have_perms_conflict(self): def _add_depths(self, nodes, edges): def _get_depth(ev): node = nodes[ev] - if 'depth' not in node: + if "depth" not in node: prevs = edges[ev] depth = max(_get_depth(prev) for prev in prevs) + 1 - node['depth'] = depth - return node['depth'] + node["depth"] = depth + return node["depth"] for n in nodes: _get_depth(n) diff --git a/tests/test_types.py b/tests/test_types.py index d83c36559fb2..9ab5f829b094 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -102,7 +102,7 @@ def testLeadingUnderscore(self): def testNonAscii(self): # this should work with either a unicode or a bytes - self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast") + self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast") self.assertEqual( - map_username_to_mxid_localpart(u'têst'.encode('utf-8')), "t=c3=aast" + map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast" ) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index fde0baee8ee6..813f984199be 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -27,7 +27,7 @@ class ToTwistedHandler(logging.Handler): def emit(self, record): log_entry = self.format(record) - log_level = record.levelname.lower().replace('warning', 'warn') + log_level = record.levelname.lower().replace("warning", "warn") self.tx_log.emit( twisted.logger.LogLevel.levelWithName(log_level), log_entry.replace("{", r"(").replace("}", r")"), diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 6a180ddc3229..118c3bd238bd 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -265,7 +265,7 @@ def test_large_room(self): pr.disable() with open("filter_events_for_server.profile", "w+") as f: - ps = pstats.Stats(pr, stream=f).sort_stats('cumulative') + ps = pstats.Stats(pr, stream=f).sort_stats("cumulative") ps.print_stats() # the result should be 5 redacted events, and 5 unredacted events. diff --git a/tests/unittest.py b/tests/unittest.py index b6dc7932ce5c..d64702b0c211 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -298,7 +298,7 @@ def make_request( Tuple[synapse.http.site.SynapseRequest, channel] """ if isinstance(content, dict): - content = json.dumps(content).encode('utf8') + content = json.dumps(content).encode("utf8") return make_request( self.reactor, @@ -389,7 +389,7 @@ def register_user(self, username, password, admin=False): Returns: The MXID of the new user (unicode). """ - self.hs.config.registration_shared_secret = u"shared" + self.hs.config.registration_shared_secret = "shared" # Create the user request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register") @@ -397,13 +397,13 @@ def register_user(self, username, password, admin=False): nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - nonce_str = b"\x00".join([username.encode('utf8'), password.encode('utf8')]) + nonce_str = b"\x00".join([username.encode("utf8"), password.encode("utf8")]) if admin: nonce_str += b"\x00admin" else: nonce_str += b"\x00notadmin" - want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str) + want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str) want_mac = want_mac.hexdigest() body = json.dumps( @@ -416,7 +416,7 @@ def register_user(self, username, password, admin=False): } ) request, channel = self.make_request( - "POST", "/_matrix/client/r0/admin/register", body.encode('utf8') + "POST", "/_matrix/client/r0/admin/register", body.encode("utf8") ) self.render(request) self.assertEqual(channel.code, 200) @@ -435,7 +435,7 @@ def login(self, username, password, device_id=None): body["device_id"] = device_id request, channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8') + "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.render(request) self.assertEqual(channel.code, 200, channel.result) @@ -481,9 +481,7 @@ def create_and_send_event( if soft_failed: event.internal_metadata.soft_failed = True - self.get_success( - event_creator.send_nonmember_event(requester, event, context) - ) + self.get_success(event_creator.send_nonmember_event(requester, event, context)) return event.event_id @@ -508,7 +506,7 @@ def attempt_wrong_password_login(self, username, password): body = {"type": "m.login.password", "user": username, "password": password} request, channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8') + "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.render(request) self.assertEqual(channel.code, 403, channel.result) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 463a737efaf5..6f8f52537cd3 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -88,24 +88,24 @@ def fn(self, arg1, arg2): obj = Cls() - obj.mock.return_value = 'fish' + obj.mock.return_value = "fish" r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") obj.mock.assert_called_once_with(1, 2) obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = 'chips' + obj.mock.return_value = "chips" r = yield obj.fn(1, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_called_once_with(1, 3) obj.mock.reset_mock() # the two values should now be cached r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") r = yield obj.fn(1, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_not_called() @defer.inlineCallbacks @@ -121,25 +121,25 @@ def fn(self, arg1, arg2): return self.mock(arg1, arg2) obj = Cls() - obj.mock.return_value = 'fish' + obj.mock.return_value = "fish" r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") obj.mock.assert_called_once_with(1, 2) obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = 'chips' + obj.mock.return_value = "chips" r = yield obj.fn(2, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_called_once_with(2, 3) obj.mock.reset_mock() # the two values should now be cached; we should be able to vary # the second argument and still get the cached result. r = yield obj.fn(1, 4) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") r = yield obj.fn(2, 5) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_not_called() def test_cache_logcontexts(self): @@ -248,30 +248,30 @@ def fn(self, arg1, arg2=2, arg3=3): obj = Cls() - obj.mock.return_value = 'fish' + obj.mock.return_value = "fish" r = yield obj.fn(1, 2, 3) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") obj.mock.assert_called_once_with(1, 2, 3) obj.mock.reset_mock() # a call with same params shouldn't call the mock again r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") obj.mock.assert_not_called() obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = 'chips' + obj.mock.return_value = "chips" r = yield obj.fn(2, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_called_once_with(2, 3, 3) obj.mock.reset_mock() # the two values should now be cached r = yield obj.fn(1, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") r = yield obj.fn(2, 3) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") obj.mock.assert_not_called() @@ -297,7 +297,7 @@ def list_fn(self, args1, arg2): with logcontext.LoggingContext() as c1: c1.request = "c1" obj = Cls() - obj.mock.return_value = {10: 'fish', 20: 'chips'} + obj.mock.return_value = {10: "fish", 20: "chips"} d1 = obj.list_fn([10, 20], 2) self.assertEqual( logcontext.LoggingContext.current_context(), @@ -306,26 +306,26 @@ def list_fn(self, args1, arg2): r = yield d1 self.assertEqual(logcontext.LoggingContext.current_context(), c1) obj.mock.assert_called_once_with([10, 20], 2) - self.assertEqual(r, {10: 'fish', 20: 'chips'}) + self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = {30: 'peas'} + obj.mock.return_value = {30: "peas"} r = yield obj.list_fn([20, 30], 2) obj.mock.assert_called_once_with([30], 2) - self.assertEqual(r, {20: 'chips', 30: 'peas'}) + self.assertEqual(r, {20: "chips", 30: "peas"}) obj.mock.reset_mock() # all the values should now be cached r = yield obj.fn(10, 2) - self.assertEqual(r, 'fish') + self.assertEqual(r, "fish") r = yield obj.fn(20, 2) - self.assertEqual(r, 'chips') + self.assertEqual(r, "chips") r = yield obj.fn(30, 2) - self.assertEqual(r, 'peas') + self.assertEqual(r, "peas") r = yield obj.list_fn([10, 20, 30], 2) obj.mock.assert_not_called() - self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'}) + self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"}) @defer.inlineCallbacks def test_invalidate(self): @@ -350,16 +350,16 @@ def list_fn(self, args1, arg2): invalidate1 = mock.Mock() # cache miss - obj.mock.return_value = {10: 'fish', 20: 'chips'} + obj.mock.return_value = {10: "fish", 20: "chips"} r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0) obj.mock.assert_called_once_with([10, 20], 2) - self.assertEqual(r1, {10: 'fish', 20: 'chips'}) + self.assertEqual(r1, {10: "fish", 20: "chips"}) obj.mock.reset_mock() # cache hit r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1) obj.mock.assert_not_called() - self.assertEqual(r2, {10: 'fish', 20: 'chips'}) + self.assertEqual(r2, {10: "fish", 20: "chips"}) invalidate0.assert_not_called() invalidate1.assert_not_called() diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py index 03b3c15db69c..c94cbb662bd5 100644 --- a/tests/util/caches/test_ttlcache.py +++ b/tests/util/caches/test_ttlcache.py @@ -27,57 +27,57 @@ def setUp(self): def test_get(self): """simple set/get tests""" - self.cache.set('one', '1', 10) - self.cache.set('two', '2', 20) - self.cache.set('three', '3', 30) + self.cache.set("one", "1", 10) + self.cache.set("two", "2", 20) + self.cache.set("three", "3", 30) self.assertEqual(len(self.cache), 3) - self.assertTrue('one' in self.cache) - self.assertEqual(self.cache.get('one'), '1') - self.assertEqual(self.cache['one'], '1') - self.assertEqual(self.cache.get_with_expiry('one'), ('1', 110)) + self.assertTrue("one" in self.cache) + self.assertEqual(self.cache.get("one"), "1") + self.assertEqual(self.cache["one"], "1") + self.assertEqual(self.cache.get_with_expiry("one"), ("1", 110)) self.assertEqual(self.cache._metrics.hits, 3) self.assertEqual(self.cache._metrics.misses, 0) - self.cache.set('two', '2.5', 20) - self.assertEqual(self.cache['two'], '2.5') + self.cache.set("two", "2.5", 20) + self.assertEqual(self.cache["two"], "2.5") self.assertEqual(self.cache._metrics.hits, 4) # non-existent-item tests - self.assertEqual(self.cache.get('four', '4'), '4') - self.assertIs(self.cache.get('four', None), None) + self.assertEqual(self.cache.get("four", "4"), "4") + self.assertIs(self.cache.get("four", None), None) with self.assertRaises(KeyError): - self.cache['four'] + self.cache["four"] with self.assertRaises(KeyError): - self.cache.get('four') + self.cache.get("four") with self.assertRaises(KeyError): - self.cache.get_with_expiry('four') + self.cache.get_with_expiry("four") self.assertEqual(self.cache._metrics.hits, 4) self.assertEqual(self.cache._metrics.misses, 5) def test_expiry(self): - self.cache.set('one', '1', 10) - self.cache.set('two', '2', 20) - self.cache.set('three', '3', 30) + self.cache.set("one", "1", 10) + self.cache.set("two", "2", 20) + self.cache.set("three", "3", 30) self.assertEqual(len(self.cache), 3) - self.assertEqual(self.cache['one'], '1') - self.assertEqual(self.cache['two'], '2') + self.assertEqual(self.cache["one"], "1") + self.assertEqual(self.cache["two"], "2") # enough for the first entry to expire, but not the rest self.mock_timer.side_effect = lambda: 110.0 self.assertEqual(len(self.cache), 2) - self.assertFalse('one' in self.cache) - self.assertEqual(self.cache['two'], '2') - self.assertEqual(self.cache['three'], '3') + self.assertFalse("one" in self.cache) + self.assertEqual(self.cache["two"], "2") + self.assertEqual(self.cache["three"], "3") - self.assertEqual(self.cache.get_with_expiry('two'), ('2', 120)) + self.assertEqual(self.cache.get_with_expiry("two"), ("2", 120)) self.assertEqual(self.cache._metrics.hits, 5) self.assertEqual(self.cache._metrics.misses, 0) diff --git a/tests/utils.py b/tests/utils.py index f8c7ad2604e0..bd2c7c954c52 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -358,9 +358,9 @@ def cleanup(): # Need to let the HS build an auth handler and then mess with it # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) - hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest() + hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest() hs.get_auth_handler().validate_hash = ( - lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h + lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h ) fed = kargs.get("resource_for_federation", None) @@ -407,7 +407,7 @@ def __init__(self, prefix=""): def trigger_get(self, path): return self.trigger(b"GET", path, None) - @patch('twisted.web.http.Request') + @patch("twisted.web.http.Request") @defer.inlineCallbacks def trigger( self, http_method, path, content, mock_request, federation_auth_origin=None @@ -431,12 +431,12 @@ def trigger( # annoyingly we return a twisted http request which has chained calls # to get at the http content, hence mock it here. mock_content = Mock() - config = {'read.return_value': content} + config = {"read.return_value": content} mock_content.configure_mock(**config) mock_request.content = mock_content - mock_request.method = http_method.encode('ascii') - mock_request.uri = path.encode('ascii') + mock_request.method = http_method.encode("ascii") + mock_request.uri = path.encode("ascii") mock_request.getClientIP.return_value = "-" @@ -452,14 +452,14 @@ def trigger( # add in query params to the right place try: - mock_request.args = urlparse.parse_qs(path.split('?')[1]) - mock_request.path = path.split('?')[0] + mock_request.args = urlparse.parse_qs(path.split("?")[1]) + mock_request.path = path.split("?")[0] path = mock_request.path except Exception: pass if isinstance(path, bytes): - path = path.decode('utf8') + path = path.decode("utf8") for (method, pattern, func) in self.callbacks: if http_method != method: diff --git a/tox.ini b/tox.ini index 0c4d562766da..09b4b8fc3ca7 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = packaging, py35, py36, py37, pep8, check_isort +envlist = packaging, py35, py36, py37, check_codestyle, check_isort [base] deps = @@ -112,12 +112,15 @@ deps = commands = check-manifest -[testenv:pep8] +[testenv:check_codestyle] skip_install = True basepython = python3.6 deps = flake8 -commands = /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/hash_password scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}" + black +commands = + python -m black --check --diff . + /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/hash_password scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}" [testenv:check_isort] skip_install = True