From 468807332caf69e52816ea9dae31cc4000bd7667 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Thu, 9 Jan 2025 14:16:10 -0500 Subject: [PATCH] Bring full tests directory to typing correctly --- tests/test_ssl.py | 141 +++++++++++++++++++++++++++------------------- tox.ini | 2 +- 2 files changed, 84 insertions(+), 59 deletions(-) diff --git a/tests/test_ssl.py b/tests/test_ssl.py index f28fa05e..7127477b 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -9,6 +9,7 @@ import datetime import gc +import os import pathlib import select import sys @@ -25,7 +26,6 @@ ) from gc import collect, get_referrers from os import makedirs -from os.path import join from socket import ( AF_INET, AF_INET6, @@ -124,6 +124,7 @@ WantWriteError, ZeroReturnError, _make_requires, + _NoOverlappingProtocols, ) from .test_crypto import ( @@ -166,25 +167,10 @@ def loopback_address(socket: socket) -> str: return "::1" -def join_bytes_or_unicode(prefix, suffix): - """ - Join two path components of either ``bytes`` or ``unicode``. - - The return type is the same as the type of ``prefix``. - """ - # If the types are the same, nothing special is necessary. - if type(prefix) is type(suffix): - return join(prefix, suffix) - - # Otherwise, coerce suffix to the type of prefix. - if isinstance(prefix, str): - return join(prefix, suffix.decode(getfilesystemencoding())) - else: - return join(prefix, suffix.encode(getfilesystemencoding())) - - -def verify_cb(conn, cert, errnum, depth, ok): - return ok +def verify_cb( + conn: Connection, cert: X509, errnum: int, depth: int, ok: int +) -> bool: + return bool(ok) def socket_pair() -> tuple[socket, socket]: @@ -360,7 +346,7 @@ def loopback( def interact_in_memory( client_conn: Connection, server_conn: Connection -) -> tuple[Connection, bytes]: +) -> tuple[Connection, bytes] | None: """ Try to read application bytes from each of the two `Connection` objects. Copy bytes back and forth between their send/receive buffers for as long @@ -405,6 +391,8 @@ def interact_in_memory( wrote = True write.bio_write(dirty) + return None + def handshake_in_memory( client_conn: Connection, server_conn: Connection @@ -1021,9 +1009,9 @@ def info(conn: Connection, where: int, ret: int) -> None: for (conn, where, ret) in called if not isinstance(conn, Connection) ] - assert ( - [] == notConnections - ), "Some info callback arguments were not Connection instances." + assert [] == notConnections, ( + "Some info callback arguments were not Connection instances." + ) @pytest.mark.skipif( not getattr(_lib, "Cryptography_HAS_KEYLOG", None), @@ -1168,7 +1156,9 @@ def test_load_verify_invalid_file(self, tmpfile: bytes) -> None: with pytest.raises(Error): clientContext.load_verify_locations(tmpfile) - def _load_verify_directory_locations_capath(self, capath: bytes) -> None: + def _load_verify_directory_locations_capath( + self, capath: str | bytes + ) -> None: """ Verify that if path to a directory containing certificate files is passed to ``Context.load_verify_locations`` for the ``capath`` @@ -1180,7 +1170,11 @@ def _load_verify_directory_locations_capath(self, capath: bytes) -> None: # c_rehash in the test suite. One is from OpenSSL 0.9.8, the other # from OpenSSL 1.0.0. for name in [b"c7adac82.0", b"c3705638.0"]: - cafile = join_bytes_or_unicode(capath, name) + cafile: str | bytes + if isinstance(capath, str): + cafile = os.path.join(capath, name.decode()) + else: + cafile = os.path.join(capath, name) with open(cafile, "w") as fObj: fObj.write(root_cert_pem.decode("ascii")) @@ -1209,9 +1203,13 @@ def test_load_verify_directory_capath( """ if pathtype == "unicode_path": tmpfile += NON_ASCII.encode(getfilesystemencoding()) + if argtype == "unicode_arg": - tmpfile = tmpfile.decode(getfilesystemencoding()) - self._load_verify_directory_locations_capath(tmpfile) + self._load_verify_directory_locations_capath( + tmpfile.decode(getfilesystemencoding()) + ) + else: + self._load_verify_directory_locations_capath(tmpfile) def test_load_verify_locations_wrong_args(self) -> None: """ @@ -1393,7 +1391,14 @@ def test_set_verify_callback_connection_argument(self) -> None: serverConnection = Connection(serverContext, None) class VerifyCallback: - def callback(self, connection: Connection, *args) -> bool: + def callback( + self, + connection: Connection, + cert: X509, + err: int, + depth: int, + ok: int, + ) -> bool: self.connection = connection return True @@ -1452,7 +1457,9 @@ def test_set_verify_callback_exception(self) -> None: clientContext = Context(TLSv1_2_METHOD) - def verify_callback(*args): + def verify_callback( + conn: Connection, cert: X509, err: int, depth: int, ok: int + ) -> bool: raise Exception("silly verify failure") clientContext.set_verify(VERIFY_PEER, verify_callback) @@ -1482,7 +1489,7 @@ def test_set_verify_callback_reference(self) -> None: for i in range(5): - def verify_callback(*args): + def verify_callback(*args: object) -> bool: return True serverSocket, clientSocket = socket_pair() @@ -1589,8 +1596,14 @@ def _use_certificate_chain_file_test(self, certdir: str | bytes) -> None: makedirs(certdir) - chainFile = join_bytes_or_unicode(certdir, "chain.pem") - caFile = join_bytes_or_unicode(certdir, "ca.pem") + chainFile: str | bytes + caFile: str | bytes + if isinstance(certdir, str): + chainFile = os.path.join(certdir, "chain.pem") + caFile = os.path.join(certdir, "ca.pem") + else: + chainFile = os.path.join(certdir, b"chain.pem") + caFile = os.path.join(certdir, b"ca.pem") # Write out the chain file. with open(chainFile, "wb") as fObj: @@ -1848,9 +1861,9 @@ def replacement(connection: Connection) -> None: # pragma: no cover collect() collect() - callback = tracker() - if callback is not None: - referrers = get_referrers(callback) + callback_ref = tracker() + if callback_ref is not None: + referrers = get_referrers(callback_ref) assert len(referrers) == 1 def test_no_servername(self) -> None: @@ -2064,7 +2077,9 @@ def test_alpn_no_server_overlap(self) -> None: """ refusal_args = [] - def refusal(conn: Connection, options: list[bytes]): + def refusal( + conn: Connection, options: list[bytes] + ) -> _NoOverlappingProtocols: refusal_args.append((conn, options)) return NO_OVERLAPPING_PROTOCOLS @@ -2218,7 +2233,7 @@ def test_construction(self) -> None: @pytest.fixture(params=["context", "connection"]) -def ctx_or_conn(request) -> Context | Connection: +def ctx_or_conn(request: pytest.FixtureRequest) -> Context | Connection: ctx = Context(SSLv23_METHOD) if request.param == "context": return ctx @@ -2823,9 +2838,9 @@ def callback( ) collect() collect() - callback = tracker() - if callback is not None: # pragma: nocover - referrers = get_referrers(callback) + callback_ref = tracker() + if callback_ref is not None: # pragma: nocover + referrers = get_referrers(callback_ref) assert len(referrers) == 1 def test_get_session_unconnected(self) -> None: @@ -3862,7 +3877,9 @@ def test_outgoing_overflow(self) -> None: # meaningless. assert sent < size - receiver, received = interact_in_memory(client, server) + result = interact_in_memory(client, server) + assert result is not None + receiver, received = result assert receiver is server # We can rely on all of these bytes being received at once because @@ -4249,7 +4266,7 @@ def test_callbacks_arent_called_by_default(self) -> None: called. """ - def ocsp_callback(*args, **kwargs): # pragma: nocover + def ocsp_callback(*args: object) -> typing.NoReturn: # pragma: nocover pytest.fail("Should not be called") client = self._client_connection( @@ -4284,7 +4301,7 @@ def test_client_receives_servers_data(self) -> None: """ calls = [] - def server_callback(*args, **kwargs): + def server_callback(*args: object, **kwargs: object) -> bytes: return self.sample_ocsp_data def client_callback( @@ -4307,11 +4324,15 @@ def test_callbacks_are_invoked_with_connections(self) -> None: client_calls = [] server_calls = [] - def client_callback(conn, *args, **kwargs): + def client_callback( + conn: Connection, *args: object, **kwargs: object + ) -> bool: client_calls.append(conn) return True - def server_callback(conn, *args, **kwargs): + def server_callback( + conn: Connection, *args: object, **kwargs: object + ) -> bytes: server_calls.append(conn) return self.sample_ocsp_data @@ -4331,11 +4352,11 @@ def test_opaque_data_is_passed_through(self) -> None: """ calls = [] - def server_callback(*args): + def server_callback(*args: object) -> bytes: calls.append(args) return self.sample_ocsp_data - def client_callback(*args): + def client_callback(*args: object) -> bool: calls.append(args) return True @@ -4360,7 +4381,7 @@ def test_server_returns_empty_string(self) -> None: """ client_calls = [] - def server_callback(*args): + def server_callback(*args: object) -> bytes: return b"" def client_callback( @@ -4381,10 +4402,10 @@ def test_client_returns_false_terminates_handshake(self) -> None: If the client returns False from its callback, the handshake fails. """ - def server_callback(*args): + def server_callback(*args: object) -> bytes: return self.sample_ocsp_data - def client_callback(*args): + def client_callback(*args: object) -> bool: return False client = self._client_connection(callback=client_callback, data=None) @@ -4401,10 +4422,10 @@ def test_exceptions_in_client_bubble_up(self) -> None: class SentinelException(Exception): pass - def server_callback(*args): + def server_callback(*args: object) -> bytes: return self.sample_ocsp_data - def client_callback(*args): + def client_callback(*args: object) -> typing.NoReturn: raise SentinelException() client = self._client_connection(callback=client_callback, data=None) @@ -4421,10 +4442,12 @@ def test_exceptions_in_server_bubble_up(self) -> None: class SentinelException(Exception): pass - def server_callback(*args): + def server_callback(*args: object) -> typing.NoReturn: raise SentinelException() - def client_callback(*args): # pragma: nocover + def client_callback( + *args: object, + ) -> typing.NoReturn: # pragma: nocover pytest.fail("Should not be called") client = self._client_connection(callback=client_callback, data=None) @@ -4438,14 +4461,16 @@ def test_server_must_return_bytes(self) -> None: The server callback must return a bytestring, or a TypeError is thrown. """ - def server_callback(*args): + def server_callback(*args: object) -> str: return self.sample_ocsp_data.decode("ascii") - def client_callback(*args): # pragma: nocover + def client_callback( + *args: object, + ) -> typing.NoReturn: # pragma: nocover pytest.fail("Should not be called") client = self._client_connection(callback=client_callback, data=None) - server = self._server_connection(callback=server_callback, data=None) + server = self._server_connection(callback=server_callback, data=None) # type: ignore[arg-type] with pytest.raises(TypeError): handshake_in_memory(client, server) diff --git a/tox.ini b/tox.ini index 54521dfa..50105fa8 100644 --- a/tox.ini +++ b/tox.ini @@ -47,7 +47,7 @@ extras = deps = mypy commands = - mypy src/ tests/conftest.py tests/test_crypto.py tests/test_debug.py tests/test_rand.py tests/test_util.py tests/util.py + mypy src/ tests/ [testenv:check-manifest] deps =