From 41fd641cb0bed42cff73578ba9bd2ea6a9d3ab58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jeremy=20Lain=C3=A9?= Date: Thu, 16 Nov 2023 08:18:05 +0100 Subject: [PATCH] [tls] Add parsing and serialization for certificate requests In order to implement support for client certificates, we need to be able to parse and serialize CERTIFICATE_REQUEST messages. Co-authored-by: doronz88 --- src/aioquic/tls.py | 49 ++++++++++++++++++++++++++++++ tests/test_tls.py | 36 ++++++++++++++++++++++ tests/tls_certificate_request.bin | Bin 0 -> 28 bytes 3 files changed, 85 insertions(+) create mode 100644 tests/tls_certificate_request.bin diff --git a/src/aioquic/tls.py b/src/aioquic/tls.py index bdb4242f5..d9c552949 100644 --- a/src/aioquic/tls.py +++ b/src/aioquic/tls.py @@ -861,6 +861,55 @@ def push_certificate_entry(buf: Buffer, entry: CertificateEntry) -> None: ) +@dataclass +class CertificateRequest: + request_context: bytes = b"" + signature_algorithms: Optional[List[int]] = None + other_extensions: List[Tuple[int, bytes]] = field(default_factory=list) + + +def pull_certificate_request(buf: Buffer) -> CertificateRequest: + certificate_request = CertificateRequest() + + assert buf.pull_uint8() == HandshakeType.CERTIFICATE_REQUEST + with pull_block(buf, 3): + certificate_request.request_context = pull_opaque(buf, 1) + + def pull_extension() -> None: + extension_type = buf.pull_uint16() + extension_length = buf.pull_uint16() + if extension_type == ExtensionType.SIGNATURE_ALGORITHMS: + certificate_request.signature_algorithms = pull_list( + buf, 2, buf.pull_uint16 + ) + else: + certificate_request.other_extensions.append( + (extension_type, buf.pull_bytes(extension_length)) + ) + + pull_list(buf, 2, pull_extension) + + return certificate_request + + +def push_certificate_request( + buf: Buffer, certificate_request: CertificateRequest +) -> None: + buf.push_uint8(HandshakeType.CERTIFICATE_REQUEST) + with push_block(buf, 3): + push_opaque(buf, 1, certificate_request.request_context) + + with push_block(buf, 2): + with push_extension(buf, ExtensionType.SIGNATURE_ALGORITHMS): + push_list( + buf, 2, buf.push_uint16, certificate_request.signature_algorithms + ) + + for extension_type, extension_value in certificate_request.other_extensions: + with push_extension(buf, extension_type): + buf.push_bytes(extension_value) + + @dataclass class CertificateVerify: algorithm: int diff --git a/tests/test_tls.py b/tests/test_tls.py index 5cca0ddf7..cd663321f 100644 --- a/tests/test_tls.py +++ b/tests/test_tls.py @@ -9,6 +9,7 @@ from aioquic.quic.configuration import QuicConfiguration from aioquic.tls import ( Certificate, + CertificateRequest, CertificateVerify, ClientHello, Context, @@ -20,6 +21,7 @@ load_pem_x509_certificates, pull_block, pull_certificate, + pull_certificate_request, pull_certificate_verify, pull_client_hello, pull_encrypted_extensions, @@ -27,6 +29,7 @@ pull_new_session_ticket, pull_server_hello, push_certificate, + push_certificate_request, push_certificate_verify, push_client_hello, push_encrypted_extensions, @@ -1228,6 +1231,39 @@ def test_push_certificate(self): push_certificate(buf, certificate) self.assertEqual(buf.data, load("tls_certificate.bin")) + def test_pull_certificate_request(self): + buf = Buffer(data=load("tls_certificate_request.bin")) + certificate_request = pull_certificate_request(buf) + self.assertTrue(buf.eof()) + + self.assertEqual(certificate_request.request_context, b"") + self.assertEqual( + certificate_request.signature_algorithms, + [ + tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, + tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256, + tls.SignatureAlgorithm.RSA_PKCS1_SHA256, + tls.SignatureAlgorithm.RSA_PKCS1_SHA1, + ], + ) + self.assertEqual(certificate_request.other_extensions, [(12345, b"foo")]) + + def test_push_certificate_request(self): + certificate_request = CertificateRequest( + request_context=b"", + signature_algorithms=[ + tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256, + tls.SignatureAlgorithm.ECDSA_SECP256R1_SHA256, + tls.SignatureAlgorithm.RSA_PKCS1_SHA256, + tls.SignatureAlgorithm.RSA_PKCS1_SHA1, + ], + other_extensions=[(12345, b"foo")], + ) + + buf = Buffer(400) + push_certificate_request(buf, certificate_request) + self.assertEqual(buf.data, load("tls_certificate_request.bin")) + def test_pull_certificate_verify(self): buf = Buffer(data=load("tls_certificate_verify.bin")) verify = pull_certificate_verify(buf) diff --git a/tests/tls_certificate_request.bin b/tests/tls_certificate_request.bin new file mode 100644 index 0000000000000000000000000000000000000000..e65f2ea6554c048a0fc45ba8cc0fe84b68885e7b GIT binary patch literal 28 jcmd;OV31&75M|(H;9}t5U}0fqVPs-7uw-CP%g+Y@4sQY_ literal 0 HcmV?d00001