diff --git a/tests/test_ssl.py b/tests/test_ssl.py index a30df67a..feb4d804 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -12,6 +12,7 @@ import select import sys import time +import typing import uuid from errno import ( EAFNOSUPPORT, @@ -156,7 +157,7 @@ """ -def socket_any_family(): +def socket_any_family() -> socket: try: return socket(AF_INET) except OSError as e: @@ -165,7 +166,7 @@ def socket_any_family(): raise -def loopback_address(socket): +def loopback_address(socket: socket) -> str: if socket.family == AF_INET: return "127.0.0.1" else: @@ -194,7 +195,7 @@ def verify_cb(conn, cert, errnum, depth, ok): return ok -def socket_pair(): +def socket_pair() -> tuple[socket, socket]: """ Establish and return a pair of network sockets connected to each other. """ @@ -225,7 +226,7 @@ def socket_pair(): return (server, client) -def handshake(client, server): +def handshake(client: Connection, server: Connection) -> None: conns = [client, server] while conns: for conn in conns: @@ -322,13 +323,17 @@ def _create_certificate_chain(): ] -def loopback_client_factory(socket, version=SSLv23_METHOD): +def loopback_client_factory( + socket: socket, version: int = SSLv23_METHOD +) -> Connection: client = Connection(Context(version), socket) client.set_connect_state() return client -def loopback_server_factory(socket, version=SSLv23_METHOD): +def loopback_server_factory( + socket: socket | None, version: int = SSLv23_METHOD +) -> Connection: ctx = Context(version) ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) @@ -337,7 +342,10 @@ def loopback_server_factory(socket, version=SSLv23_METHOD): return server -def loopback(server_factory=None, client_factory=None): +def loopback( + server_factory: typing.Callable[[socket], Connection] | None = None, + client_factory: typing.Callable[[socket], Connection] | None = None, +) -> tuple[Connection, Connection]: """ Create a connected socket pair and force two connected SSL sockets to talk to each other via memory BIOs. @@ -348,17 +356,19 @@ def loopback(server_factory=None, client_factory=None): client_factory = loopback_client_factory (server, client) = socket_pair() - server = server_factory(server) - client = client_factory(client) + tls_server = server_factory(server) + tls_client = client_factory(client) - handshake(client, server) + handshake(tls_client, tls_server) - server.setblocking(True) - client.setblocking(True) - return server, client + tls_server.setblocking(True) + tls_client.setblocking(True) + return tls_server, tls_client -def interact_in_memory(client_conn, server_conn): +def interact_in_memory( + client_conn: Connection, server_conn: Connection +) -> None: """ Try to read application bytes from each of the two `Connection` objects. Copy bytes back and forth between their send/receive buffers for as long @@ -404,7 +414,9 @@ def interact_in_memory(client_conn, server_conn): write.bio_write(dirty) -def handshake_in_memory(client_conn, server_conn): +def handshake_in_memory( + client_conn: Connection, server_conn: Connection +) -> None: """ Perform the TLS handshake between two `Connection` instances connected to each other via memory BIOs. @@ -620,7 +632,7 @@ def test_method(self) -> None: Context(meth) with pytest.raises(TypeError): - Context("") + Context("") # type: ignore[arg-type] with pytest.raises(ValueError): Context(13) @@ -690,11 +702,11 @@ def test_use_certificate_file_wrong_args(self) -> None: """ ctx = Context(SSLv23_METHOD) with pytest.raises(TypeError): - ctx.use_certificate_file(object(), FILETYPE_PEM) + ctx.use_certificate_file(object(), FILETYPE_PEM) # type: ignore[arg-type] with pytest.raises(TypeError): - ctx.use_certificate_file(b"somefile", object()) + ctx.use_certificate_file(b"somefile", object()) # type: ignore[arg-type] with pytest.raises(TypeError): - ctx.use_certificate_file(object(), FILETYPE_PEM) + ctx.use_certificate_file(object(), FILETYPE_PEM) # type: ignore[arg-type] def test_use_certificate_file_missing(self, tmpfile) -> None: """ @@ -1070,7 +1082,7 @@ def _load_verify_locations_test(self, *args): # connection will fail. clientContext.set_verify( VERIFY_PEER, - lambda conn, cert, errno, depth, preverify_ok: preverify_ok, + lambda conn, cert, errno, depth, preverify_ok: bool(preverify_ok), ) clientSSL = Connection(clientContext, client) @@ -1094,6 +1106,7 @@ def _load_verify_locations_test(self, *args): handshake(clientSSL, serverSSL) cert = clientSSL.get_peer_certificate() + assert cert is not None assert cert.get_subject().CN == "Testing Root CA" cryptography_cert = clientSSL.get_peer_certificate( @@ -1228,6 +1241,7 @@ def test_fallback_default_verify_paths(self, monkeypatch) -> None: ) context.set_default_verify_paths() store = context.get_cert_store() + assert store is not None sk_obj = _lib.X509_STORE_get0_objects(store._store) assert sk_obj != _ffi.NULL num = _lib.sk_X509_OBJECT_num(sk_obj) @@ -1323,7 +1337,9 @@ def test_add_extra_chain_cert_invalid_cert(self) -> None: with pytest.raises(TypeError): context.add_extra_chain_cert(object()) - def _handshake_test(self, serverContext, clientContext): + def _handshake_test( + self, serverContext: Context, clientContext: Context + ) -> None: """ Verify that a client and server created with the given contexts can successfully handshake and communicate. @@ -2691,12 +2707,14 @@ def test_get_verified_chain(self) -> None: interact_in_memory(client, server) chain = client.get_verified_chain() + assert chain is not None assert len(chain) == 3 assert "Server Certificate" == chain[0].get_subject().CN assert "Intermediate Certificate" == chain[1].get_subject().CN assert "Authority Certificate" == chain[2].get_subject().CN cryptography_chain = client.get_verified_chain(as_cryptography=True) + assert cryptography_chain is not None assert len(cryptography_chain) == 3 assert ( cryptography_chain[0].subject.rfc4514_string() @@ -4509,7 +4527,7 @@ def pump_membio(label, source, sink): sink.bio_write(chunk) return True - def pump(): + def pump() -> None: # Raises if there was no data to pump, to avoid infinite loops if # we aren't making progress. assert pump_membio("s -> c", s, c) or pump_membio("c -> s", c, s)