From 24cc2c1765fe4534692e493208c403cbdfc46205 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Fri, 25 Oct 2024 08:26:28 -0400 Subject: [PATCH] Refactor two tests that rely on deprecated APIs (#1375) --- tests/test_crypto.py | 81 +++++++++++++++++++------------ tests/test_ssl.py | 112 ++++++++++++++++++++++++++----------------- 2 files changed, 118 insertions(+), 75 deletions(-) diff --git a/tests/test_crypto.py b/tests/test_crypto.py index 91cca320..c1db014c 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -2088,35 +2088,46 @@ def test_digest(self): ) ) - def _extcert(self, pkey, extensions): - cert = X509() - # Certificates with extensions must be X.509v3, which is encoded with a - # version of two. - cert.set_version(2) - cert.set_pubkey(pkey) - cert.get_subject().commonName = "Unit Tests" - cert.get_issuer().commonName = "Unit Tests" - when = datetime.now().strftime("%Y%m%d%H%M%SZ").encode("ascii") - cert.set_notBefore(when) - cert.set_notAfter(when) - - cert.add_extensions(extensions) - cert.sign(pkey, "sha256") - return load_certificate( - FILETYPE_PEM, dump_certificate(FILETYPE_PEM, cert) + def _extcert(self, key, extensions): + subject = x509.Name( + [x509.NameAttribute(x509.NameOID.COMMON_NAME, "Unit Tests")] ) + when = datetime.now() + builder = ( + x509.CertificateBuilder() + .public_key(key.public_key()) + .subject_name(subject) + .issuer_name(subject) + .not_valid_before(when) + .not_valid_after(when) + .serial_number(1) + ) + for i, ext in enumerate(extensions): + builder = builder.add_extension(ext, critical=i % 2 == 0) + + return X509.from_cryptography(builder.sign(key, hashes.SHA256())) def test_extension_count(self): """ `X509.get_extension_count` returns the number of extensions that are present in the certificate. """ - pkey = load_privatekey(FILETYPE_PEM, client_key_pem) - ca = X509Extension(b"basicConstraints", True, b"CA:FALSE") - key = X509Extension(b"keyUsage", True, b"digitalSignature") - subjectAltName = X509Extension( - b"subjectAltName", True, b"DNS:example.com" + pkey = load_privatekey( + FILETYPE_PEM, client_key_pem + ).to_cryptography_key() + ca = x509.BasicConstraints(ca=False, path_length=None) + key = x509.KeyUsage( + digital_signature=True, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, ) + san = x509.SubjectAlternativeName([x509.DNSName("example.com")]) # Try a certificate with no extensions at all. c = self._extcert(pkey, []) @@ -2127,7 +2138,7 @@ def test_extension_count(self): assert c.get_extension_count() == 1 # And a certificate with several - c = self._extcert(pkey, [ca, key, subjectAltName]) + c = self._extcert(pkey, [ca, key, san]) assert c.get_extension_count() == 3 def test_get_extension(self): @@ -2135,14 +2146,24 @@ def test_get_extension(self): `X509.get_extension` takes an integer and returns an `X509Extension` corresponding to the extension at that index. """ - pkey = load_privatekey(FILETYPE_PEM, client_key_pem) - ca = X509Extension(b"basicConstraints", True, b"CA:FALSE") - key = X509Extension(b"keyUsage", True, b"digitalSignature") - subjectAltName = X509Extension( - b"subjectAltName", False, b"DNS:example.com" + pkey = load_privatekey( + FILETYPE_PEM, client_key_pem + ).to_cryptography_key() + ca = x509.BasicConstraints(ca=False, path_length=None) + key = x509.KeyUsage( + digital_signature=True, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, ) + san = x509.SubjectAlternativeName([x509.DNSName("example.com")]) - cert = self._extcert(pkey, [ca, key, subjectAltName]) + cert = self._extcert(pkey, [ca, key, san]) ext = cert.get_extension(0) assert isinstance(ext, X509Extension) @@ -2151,12 +2172,12 @@ def test_get_extension(self): ext = cert.get_extension(1) assert isinstance(ext, X509Extension) - assert ext.get_critical() + assert not ext.get_critical() assert ext.get_short_name() == b"keyUsage" ext = cert.get_extension(2) assert isinstance(ext, X509Extension) - assert not ext.get_critical() + assert ext.get_critical() assert ext.get_short_name() == b"subjectAltName" with pytest.raises(IndexError): diff --git a/tests/test_ssl.py b/tests/test_ssl.py index f6cb4455..b009cd6f 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -56,10 +56,6 @@ load_certificate, load_privatekey, ) - -with pytest.warns(DeprecationWarning): - from OpenSSL.crypto import X509Extension - from OpenSSL.SSL import ( DTLS_METHOD, MODE_RELEASE_BUFFERS, @@ -248,55 +244,81 @@ def _create_certificate_chain(): 2. A new intermediate certificate signed by cacert (icert) 3. A new server certificate signed by icert (scert) """ - caext = X509Extension(b"basicConstraints", False, b"CA:true") - not_after_date = datetime.date.today() + datetime.timedelta(days=365) - not_after = not_after_date.strftime("%Y%m%d%H%M%SZ").encode("ascii") + not_before = datetime.datetime(2000, 1, 1, 0, 0, 0) + not_after = datetime.datetime.now() + datetime.timedelta(days=365) # Step 1 - cakey = PKey() - cakey.generate_key(TYPE_RSA, 2048) - cacert = X509() - cacert.set_version(2) - cacert.get_subject().commonName = "Authority Certificate" - cacert.set_issuer(cacert.get_subject()) - cacert.set_pubkey(cakey) - cacert.set_notBefore(b"20000101000000Z") - cacert.set_notAfter(not_after) - cacert.add_extensions([caext]) - cacert.set_serial_number(0) - cacert.sign(cakey, "sha256") + cakey = rsa.generate_private_key(key_size=2048, public_exponent=65537) + casubject = x509.Name( + [x509.NameAttribute(x509.NameOID.COMMON_NAME, "Authority Certificate")] + ) + cacert = ( + x509.CertificateBuilder() + .subject_name(casubject) + .issuer_name(casubject) + .public_key(cakey.public_key()) + .not_valid_before(not_before) + .not_valid_after(not_after) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), critical=False + ) + .serial_number(1) + .sign(cakey, hashes.SHA256()) + ) # Step 2 - ikey = PKey() - ikey.generate_key(TYPE_RSA, 2048) - icert = X509() - icert.set_version(2) - icert.get_subject().commonName = "Intermediate Certificate" - icert.set_issuer(cacert.get_subject()) - icert.set_pubkey(ikey) - icert.set_notBefore(b"20000101000000Z") - icert.set_notAfter(not_after) - icert.add_extensions([caext]) - icert.set_serial_number(0) - icert.sign(cakey, "sha256") + ikey = rsa.generate_private_key(key_size=2048, public_exponent=65537) + icert = ( + x509.CertificateBuilder() + .subject_name( + x509.Name( + [ + x509.NameAttribute( + x509.NameOID.COMMON_NAME, "Intermediate Certificate" + ) + ] + ) + ) + .issuer_name(cacert.subject) + .public_key(ikey.public_key()) + .not_valid_before(not_before) + .not_valid_after(not_after) + .add_extension( + x509.BasicConstraints(ca=True, path_length=None), critical=False + ) + .serial_number(1) + .sign(cakey, hashes.SHA256()) + ) # Step 3 - skey = PKey() - skey.generate_key(TYPE_RSA, 2048) - scert = X509() - scert.set_version(2) - scert.get_subject().commonName = "Server Certificate" - scert.set_issuer(icert.get_subject()) - scert.set_pubkey(skey) - scert.set_notBefore(b"20000101000000Z") - scert.set_notAfter(not_after) - scert.add_extensions( - [X509Extension(b"basicConstraints", True, b"CA:false")] + skey = rsa.generate_private_key(key_size=2048, public_exponent=65537) + scert = ( + x509.CertificateBuilder() + .subject_name( + x509.Name( + [ + x509.NameAttribute( + x509.NameOID.COMMON_NAME, "Server Certificate" + ) + ] + ) + ) + .issuer_name(icert.subject) + .public_key(skey.public_key()) + .not_valid_before(not_before) + .not_valid_after(not_after) + .add_extension( + x509.BasicConstraints(ca=False, path_length=None), critical=True + ) + .serial_number(1) + .sign(ikey, hashes.SHA256()) ) - scert.set_serial_number(0) - scert.sign(ikey, "sha256") - return [(cakey, cacert), (ikey, icert), (skey, scert)] + return [ + (PKey.from_cryptography_key(cakey), X509.from_cryptography(cacert)), + (PKey.from_cryptography_key(ikey), X509.from_cryptography(icert)), + (PKey.from_cryptography_key(skey), X509.from_cryptography(scert)), + ] def loopback_client_factory(socket, version=SSLv23_METHOD):