Skip to content

Commit

Permalink
Fix more than 100 mypy errors in test_ssl.py (#1395)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex authored Jan 7, 2025
1 parent ee017b2 commit 38888ab
Showing 1 changed file with 40 additions and 22 deletions.
62 changes: 40 additions & 22 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import select
import sys
import time
import typing
import uuid
from errno import (
EAFNOSUPPORT,
Expand Down Expand Up @@ -156,7 +157,7 @@
"""


def socket_any_family():
def socket_any_family() -> socket:
try:
return socket(AF_INET)
except OSError as e:
Expand All @@ -165,7 +166,7 @@ def socket_any_family():
raise


def loopback_address(socket):
def loopback_address(socket: socket) -> str:
if socket.family == AF_INET:
return "127.0.0.1"
else:
Expand Down Expand Up @@ -194,7 +195,7 @@ def verify_cb(conn, cert, errnum, depth, ok):
return ok


def socket_pair():
def socket_pair() -> tuple[socket, socket]:
"""
Establish and return a pair of network sockets connected to each other.
"""
Expand Down Expand Up @@ -225,7 +226,7 @@ def socket_pair():
return (server, client)


def handshake(client, server):
def handshake(client: Connection, server: Connection) -> None:
conns = [client, server]
while conns:
for conn in conns:
Expand Down Expand Up @@ -322,13 +323,17 @@ def _create_certificate_chain():
]


def loopback_client_factory(socket, version=SSLv23_METHOD):
def loopback_client_factory(
socket: socket, version: int = SSLv23_METHOD
) -> Connection:
client = Connection(Context(version), socket)
client.set_connect_state()
return client


def loopback_server_factory(socket, version=SSLv23_METHOD):
def loopback_server_factory(
socket: socket | None, version: int = SSLv23_METHOD
) -> Connection:
ctx = Context(version)
ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
Expand All @@ -337,7 +342,10 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
return server


def loopback(server_factory=None, client_factory=None):
def loopback(
server_factory: typing.Callable[[socket], Connection] | None = None,
client_factory: typing.Callable[[socket], Connection] | None = None,
) -> tuple[Connection, Connection]:
"""
Create a connected socket pair and force two connected SSL sockets
to talk to each other via memory BIOs.
Expand All @@ -348,17 +356,19 @@ def loopback(server_factory=None, client_factory=None):
client_factory = loopback_client_factory

(server, client) = socket_pair()
server = server_factory(server)
client = client_factory(client)
tls_server = server_factory(server)
tls_client = client_factory(client)

handshake(client, server)
handshake(tls_client, tls_server)

server.setblocking(True)
client.setblocking(True)
return server, client
tls_server.setblocking(True)
tls_client.setblocking(True)
return tls_server, tls_client


def interact_in_memory(client_conn, server_conn):
def interact_in_memory(
client_conn: Connection, server_conn: Connection
) -> None:
"""
Try to read application bytes from each of the two `Connection` objects.
Copy bytes back and forth between their send/receive buffers for as long
Expand Down Expand Up @@ -404,7 +414,9 @@ def interact_in_memory(client_conn, server_conn):
write.bio_write(dirty)


def handshake_in_memory(client_conn, server_conn):
def handshake_in_memory(
client_conn: Connection, server_conn: Connection
) -> None:
"""
Perform the TLS handshake between two `Connection` instances connected to
each other via memory BIOs.
Expand Down Expand Up @@ -620,7 +632,7 @@ def test_method(self) -> None:
Context(meth)

with pytest.raises(TypeError):
Context("")
Context("") # type: ignore[arg-type]
with pytest.raises(ValueError):
Context(13)

Expand Down Expand Up @@ -690,11 +702,11 @@ def test_use_certificate_file_wrong_args(self) -> None:
"""
ctx = Context(SSLv23_METHOD)
with pytest.raises(TypeError):
ctx.use_certificate_file(object(), FILETYPE_PEM)
ctx.use_certificate_file(object(), FILETYPE_PEM) # type: ignore[arg-type]
with pytest.raises(TypeError):
ctx.use_certificate_file(b"somefile", object())
ctx.use_certificate_file(b"somefile", object()) # type: ignore[arg-type]
with pytest.raises(TypeError):
ctx.use_certificate_file(object(), FILETYPE_PEM)
ctx.use_certificate_file(object(), FILETYPE_PEM) # type: ignore[arg-type]

def test_use_certificate_file_missing(self, tmpfile) -> None:
"""
Expand Down Expand Up @@ -1070,7 +1082,7 @@ def _load_verify_locations_test(self, *args):
# connection will fail.
clientContext.set_verify(
VERIFY_PEER,
lambda conn, cert, errno, depth, preverify_ok: preverify_ok,
lambda conn, cert, errno, depth, preverify_ok: bool(preverify_ok),
)

clientSSL = Connection(clientContext, client)
Expand All @@ -1094,6 +1106,7 @@ def _load_verify_locations_test(self, *args):
handshake(clientSSL, serverSSL)

cert = clientSSL.get_peer_certificate()
assert cert is not None
assert cert.get_subject().CN == "Testing Root CA"

cryptography_cert = clientSSL.get_peer_certificate(
Expand Down Expand Up @@ -1228,6 +1241,7 @@ def test_fallback_default_verify_paths(self, monkeypatch) -> None:
)
context.set_default_verify_paths()
store = context.get_cert_store()
assert store is not None
sk_obj = _lib.X509_STORE_get0_objects(store._store)
assert sk_obj != _ffi.NULL
num = _lib.sk_X509_OBJECT_num(sk_obj)
Expand Down Expand Up @@ -1323,7 +1337,9 @@ def test_add_extra_chain_cert_invalid_cert(self) -> None:
with pytest.raises(TypeError):
context.add_extra_chain_cert(object())

def _handshake_test(self, serverContext, clientContext):
def _handshake_test(
self, serverContext: Context, clientContext: Context
) -> None:
"""
Verify that a client and server created with the given contexts can
successfully handshake and communicate.
Expand Down Expand Up @@ -2691,12 +2707,14 @@ def test_get_verified_chain(self) -> None:
interact_in_memory(client, server)

chain = client.get_verified_chain()
assert chain is not None
assert len(chain) == 3
assert "Server Certificate" == chain[0].get_subject().CN
assert "Intermediate Certificate" == chain[1].get_subject().CN
assert "Authority Certificate" == chain[2].get_subject().CN

cryptography_chain = client.get_verified_chain(as_cryptography=True)
assert cryptography_chain is not None
assert len(cryptography_chain) == 3
assert (
cryptography_chain[0].subject.rfc4514_string()
Expand Down Expand Up @@ -4509,7 +4527,7 @@ def pump_membio(label, source, sink):
sink.bio_write(chunk)
return True

def pump():
def pump() -> None:
# Raises if there was no data to pump, to avoid infinite loops if
# we aren't making progress.
assert pump_membio("s -> c", s, c) or pump_membio("c -> s", c, s)
Expand Down

0 comments on commit 38888ab

Please sign in to comment.