From aaae641bb2ce4c19b842501d91f1857c0538944a Mon Sep 17 00:00:00 2001 From: James Mayclin Date: Mon, 9 Dec 2024 11:11:39 -0800 Subject: [PATCH] feat(bindings): enable application owned certs (#4937) --- bindings/rust/s2n-tls/src/callbacks/pkey.rs | 50 ++- bindings/rust/s2n-tls/src/cert_chain.rs | 397 ++++++++++++++++++-- bindings/rust/s2n-tls/src/config.rs | 91 ++++- bindings/rust/s2n-tls/src/connection.rs | 11 +- bindings/rust/s2n-tls/src/renegotiate.rs | 15 +- bindings/rust/s2n-tls/src/testing.rs | 66 +++- tests/unit/s2n_certificate_test.c | 9 + 7 files changed, 589 insertions(+), 50 deletions(-) diff --git a/bindings/rust/s2n-tls/src/callbacks/pkey.rs b/bindings/rust/s2n-tls/src/callbacks/pkey.rs index bbf44d6e8cf..b946617fd8f 100644 --- a/bindings/rust/s2n-tls/src/callbacks/pkey.rs +++ b/bindings/rust/s2n-tls/src/callbacks/pkey.rs @@ -128,7 +128,7 @@ mod tests { testing::{self, *}, }; use core::task::{Poll, Waker}; - use futures_test::task::new_count_waker; + use futures_test::task::{new_count_waker, noop_waker}; use openssl::{ec::EcKey, ecdsa::EcdsaSig}; type Error = Box; @@ -350,4 +350,52 @@ mod tests { assert_test_error(err, ERROR); Ok(()) } + + /// pkey offload should also work with public certs created from + /// [CertificateChain::from_public_pems]. + #[test] + fn app_owned_public_cert() -> Result<(), Error> { + struct TestPkeyCallback; + impl PrivateKeyCallback for TestPkeyCallback { + fn handle_operation( + &self, + conn: &mut connection::Connection, + op: PrivateKeyOperation, + ) -> Result>>, error::Error> { + ecdsa_sign(op, conn, KEY)?; + Ok(None) + } + } + + let public_chain = { + let mut chain = crate::cert_chain::Builder::new()?; + chain.load_public_pem(CERT)?; + chain.build()? + }; + + let server_config = { + let mut config = config::Builder::new(); + config + .set_security_policy(&security::DEFAULT_TLS13)? + .load_chain(public_chain)? + .set_private_key_callback(TestPkeyCallback)?; + config.build()? + }; + + let client_config = { + let mut config = config::Builder::new(); + config + .set_security_policy(&security::DEFAULT_TLS13)? + .set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})? + .trust_pem(CERT)?; + config.build()? + }; + + let mut pair = TestPair::from_configs(&client_config, &server_config); + pair.server.set_waker(Some(&noop_waker()))?; + + assert!(pair.handshake().is_ok()); + + Ok(()) + } } diff --git a/bindings/rust/s2n-tls/src/cert_chain.rs b/bindings/rust/s2n-tls/src/cert_chain.rs index e95502c1fc9..189383cfb15 100644 --- a/bindings/rust/s2n-tls/src/cert_chain.rs +++ b/bindings/rust/s2n-tls/src/cert_chain.rs @@ -6,34 +6,160 @@ use s2n_tls_sys::*; use std::{ marker::PhantomData, ptr::{self, NonNull}, + sync::Arc, }; +/// Internal wrapper type used for a convenient drop implementation. +/// +/// [CertificateChain] is internally reference counted. The reference counted `T` +/// must have a drop implementation. +struct CertificateChainHandle { + cert: NonNull, + is_owned: bool, +} + +// # Safety +// +// s2n_cert_chain_and_key objects can be sent across threads. +unsafe impl Send for CertificateChainHandle {} +unsafe impl Sync for CertificateChainHandle {} + +impl CertificateChainHandle { + fn from_owned(cert: NonNull) -> Self { + Self { + cert, + is_owned: true, + } + } + + fn from_reference(cert: NonNull) -> Self { + Self { + cert, + is_owned: false, + } + } +} + +impl Drop for CertificateChainHandle { + fn drop(&mut self) { + // ignore failures since there's not much we can do about it + if self.is_owned { + unsafe { + let _ = s2n_cert_chain_and_key_free(self.cert.as_ptr()).into_result(); + } + } + } +} + +pub struct Builder { + cert: CertificateChain<'static>, +} + +impl Builder { + pub fn new() -> Result { + Ok(Self { + cert: CertificateChain::allocate_owned()?, + }) + } + + /// Corresponds to [s2n_cert_chain_and_key_load_pem_bytes] + /// + /// This can be used with [crate::config::Builder::load_chain] to share a + /// single cert across multiple configs. + pub fn load_pem(&mut self, chain: &[u8], key: &[u8]) -> Result<&mut Self, Error> { + unsafe { + // SAFETY: manual audit of load_pem_bytes shows that `chain_pem` and + // `private_key_pem` are not modified. + // https://github.com/aws/s2n-tls/issues/4140 + s2n_cert_chain_and_key_load_pem_bytes( + self.cert.as_mut_ptr(), + chain.as_ptr() as *mut _, + chain.len() as u32, + key.as_ptr() as *mut _, + key.len() as u32, + ) + .into_result() + }?; + + Ok(self) + } + + /// Corresponds to [s2n_cert_chain_and_key_load_public_pem_bytes]. + /// + /// This method is only used when performing private-key offloading. For standard + /// use-cases see [CertificateChain::from_pem]. + pub fn load_public_pem(&mut self, chain: &[u8]) -> Result<&mut Self, Error> { + unsafe { + // SAFETY: manual audit of load_public_pem_bytes shows that `chain_pem` + // is not modified + // https://github.com/aws/s2n-tls/issues/4140 + s2n_cert_chain_and_key_load_public_pem_bytes( + self.cert.as_mut_ptr(), + chain.as_ptr() as *mut _, + chain.len() as u32, + ) + .into_result() + }?; + + Ok(self) + } + + /// Corresponds to [s2n_cert_chain_and_key_set_ocsp_data]. + pub fn set_ocsp_data(&mut self, data: &[u8]) -> Result<&mut Self, Error> { + unsafe { + s2n_cert_chain_and_key_set_ocsp_data( + self.cert.as_mut_ptr(), + data.as_ptr(), + data.len() as u32, + ) + .into_result() + }?; + Ok(self) + } + + /// Return an immutable, internally-reference counted CertificateChain. + pub fn build(self) -> Result, Error> { + // This method is currently infalliable, but returning a result allows + // us to add validation in the future. + Ok(self.cert) + } +} + /// A CertificateChain represents a chain of X.509 certificates. +/// +/// Certificate chains are internally reference counted and are cheaply cloneable. +// +// SAFETY: it is important that no CertificateChain methods operate on mutable +// references. Because CertificateChains can be shared across threads, it is not +// safe to mutate CertificateChains. +#[derive(Clone)] pub struct CertificateChain<'a> { - ptr: NonNull, - is_owned: bool, + ptr: Arc, _lifetime: PhantomData<&'a s2n_cert_chain_and_key>, } impl CertificateChain<'_> { /// This allocates a new certificate chain from s2n. - pub(crate) fn new() -> Result, Error> { + pub(crate) fn allocate_owned() -> Result, Error> { + crate::init::init(); unsafe { let ptr = s2n_cert_chain_and_key_new().into_result()?; Ok(CertificateChain { - ptr, - is_owned: true, + ptr: Arc::new(CertificateChainHandle::from_owned(ptr)), _lifetime: PhantomData, }) } } + /// This is used to create a CertificateChain "reference" backed by memory + /// on some external struct, where the external struct has some lifetime `'a`. pub(crate) unsafe fn from_ptr_reference<'a>( ptr: NonNull, ) -> CertificateChain<'a> { + let handle = Arc::new(CertificateChainHandle::from_reference(ptr)); + CertificateChain { - ptr, - is_owned: false, + ptr: handle, _lifetime: PhantomData, } } @@ -54,8 +180,7 @@ impl CertificateChain<'_> { /// expensive API to call. pub fn len(&self) -> usize { let mut length: u32 = 0; - let res = - unsafe { s2n_cert_chain_get_length(self.ptr.as_ptr(), &mut length).into_result() }; + let res = unsafe { s2n_cert_chain_get_length(self.as_ptr(), &mut length).into_result() }; if res.is_err() { // Errors should only happen on empty chains (we guarantee that `ptr` is a valid chain). return 0; @@ -72,24 +197,16 @@ impl CertificateChain<'_> { self.len() == 0 } - pub(crate) fn as_mut_ptr(&mut self) -> NonNull { - self.ptr + /// SAFETY: Only one instance of `CertificateChain` may exist when this method + /// is called. s2n_cert_chain_and_key is not thread-safe, so it is not safe + /// to mutate the certificate chain if references are held across multiple threads. + pub(crate) unsafe fn as_mut_ptr(&mut self) -> *mut s2n_cert_chain_and_key { + debug_assert_eq!(Arc::strong_count(&self.ptr), 1); + self.ptr.cert.as_ptr() } -} - -// # Safety -// -// s2n_cert_chain_and_key objects can be sent across threads. -unsafe impl Send for CertificateChain<'_> {} -impl Drop for CertificateChain<'_> { - fn drop(&mut self) { - if self.is_owned { - // ignore failures since there's not much we can do about it - unsafe { - let _ = s2n_cert_chain_and_key_free(self.ptr.as_ptr()).into_result(); - } - } + pub(crate) fn as_ptr(&self) -> *const s2n_cert_chain_and_key { + self.ptr.cert.as_ptr() as *const _ } } @@ -112,7 +229,7 @@ impl<'a> Iterator for CertificateChainIter<'a> { let mut out = ptr::null_mut(); unsafe { if let Err(e) = - s2n_cert_chain_get_cert(self.chain.ptr.as_ptr(), &mut out, idx).into_result() + s2n_cert_chain_get_cert(self.chain.as_ptr(), &mut out, idx).into_result() { return Some(Err(e)); } @@ -152,3 +269,231 @@ impl Certificate<'_> { // // Certificates just reference data in the chain, so share the Send-ness of the chain. unsafe impl Send for Certificate<'_> {} + +#[cfg(test)] +mod tests { + use crate::{ + config, + error::{ErrorSource, ErrorType}, + security::DEFAULT_TLS13, + testing::{InsecureAcceptAllCertificatesHandler, SniTestCerts, TestPair}, + }; + + use super::*; + + /// Create a test pair using SNI certs + /// * `certs`: takes references to already created cert chains. This is useful + /// to assert on expected reference counts. + /// * `types`: Used to find the CA paths for the client configs + fn sni_test_pair( + certs: Vec>, + defaults: Option>>, + types: &[SniTestCerts], + ) -> Result { + let mut server_config = config::Builder::new(); + server_config + .with_system_certs(false)? + .set_security_policy(&DEFAULT_TLS13)?; + for cert in certs.into_iter() { + server_config.load_chain(cert)?; + } + if let Some(defaults) = defaults { + server_config.set_default_chains(defaults)?; + } + + let mut client_config = config::Builder::new(); + client_config + .with_system_certs(false)? + .set_security_policy(&DEFAULT_TLS13)? + .set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?; + for t in types { + client_config.trust_pem(t.get().cert())?; + } + Ok(TestPair::from_configs( + &client_config.build()?, + &server_config.build()?, + )) + } + + /// This is a useful (but inefficient) test utility to check if CertificateChain + /// structs are equal. It does this by comparing the serialized `der` representation. + fn cert_chains_are_equal(this: &CertificateChain<'_>, that: &CertificateChain<'_>) -> bool { + let this: Vec> = this + .iter() + .map(|cert| cert.unwrap().der().unwrap().to_owned()) + .collect(); + let that: Vec> = that + .iter() + .map(|cert| cert.unwrap().der().unwrap().to_owned()) + .collect(); + this == that + } + + #[test] + fn reference_count_increment() -> Result<(), crate::error::Error> { + let cert = SniTestCerts::AlligatorRsa.get().into_certificate_chain(); + assert_eq!(Arc::strong_count(&cert.ptr), 1); + + { + let mut server = config::Builder::new(); + server.load_chain(cert.clone())?; + + // after being added, the reference count should have increased + assert_eq!(Arc::strong_count(&cert.ptr), 2); + } + + // after the config goes out of scope and is dropped, the ref count should + // decrement + assert_eq!(Arc::strong_count(&cert.ptr), 1); + Ok(()) + } + + #[test] + fn cert_is_dropped() { + let weak_ref = { + let cert = SniTestCerts::AlligatorEcdsa.get().into_certificate_chain(); + assert_eq!(Arc::strong_count(&cert.ptr), 1); + Arc::downgrade(&cert.ptr) + }; + assert_eq!(weak_ref.strong_count(), 0); + assert!(weak_ref.upgrade().is_none()); + } + + // a cert can be successfully shared across multiple configs + #[test] + fn shared_certs() -> Result<(), crate::error::Error> { + let test_key_pair = SniTestCerts::AlligatorRsa.get(); + let cert = test_key_pair.into_certificate_chain(); + + let mut test_pair_1 = + sni_test_pair(vec![cert.clone()], None, &[SniTestCerts::AlligatorRsa])?; + let mut test_pair_2 = + sni_test_pair(vec![cert.clone()], None, &[SniTestCerts::AlligatorRsa])?; + + assert_eq!(Arc::strong_count(&cert.ptr), 3); + + assert!(test_pair_1.handshake().is_ok()); + assert!(test_pair_2.handshake().is_ok()); + + assert_eq!(Arc::strong_count(&cert.ptr), 3); + + drop(test_pair_1); + assert_eq!(Arc::strong_count(&cert.ptr), 2); + drop(test_pair_2); + assert_eq!(Arc::strong_count(&cert.ptr), 1); + Ok(()) + } + + #[test] + fn too_many_certs_in_default() -> Result<(), crate::error::Error> { + // 5 certs in the maximum allowed, 6 should error. + const FAILING_NUMBER: usize = 6; + let certs = vec![SniTestCerts::AlligatorRsa.get().into_certificate_chain(); FAILING_NUMBER]; + assert_eq!(Arc::strong_count(&certs[0].ptr), FAILING_NUMBER); + + let mut config = config::Builder::new(); + let err = config.set_default_chains(certs.clone()).err().unwrap(); + assert_eq!(err.kind(), ErrorType::UsageError); + assert_eq!(err.source(), ErrorSource::Bindings); + + // The config should not hold a reference when the error was detected + // in the bindings + assert_eq!(Arc::strong_count(&certs[0].ptr), FAILING_NUMBER); + + Ok(()) + } + + #[test] + fn default_selection() -> Result<(), crate::error::Error> { + let alligator_cert = SniTestCerts::AlligatorRsa.get().into_certificate_chain(); + let beaver_cert = SniTestCerts::BeaverRsa.get().into_certificate_chain(); + + // when no default is explicitly set, the first loaded cert is the default + { + let mut test_pair = sni_test_pair( + vec![alligator_cert.clone(), beaver_cert.clone()], + None, + &[SniTestCerts::AlligatorRsa, SniTestCerts::BeaverRsa], + )?; + + assert!(test_pair.handshake().is_ok()); + + assert!(cert_chains_are_equal( + &alligator_cert, + &test_pair.client.peer_cert_chain().unwrap() + )); + + assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2); + assert_eq!(Arc::strong_count(&beaver_cert.ptr), 2); + } + + // set an explicit default + { + let mut test_pair = sni_test_pair( + vec![alligator_cert.clone(), beaver_cert.clone()], + Some(vec![beaver_cert.clone()]), + &[SniTestCerts::AlligatorRsa, SniTestCerts::BeaverRsa], + )?; + + assert!(test_pair.handshake().is_ok()); + + assert!(cert_chains_are_equal( + &beaver_cert, + &test_pair.client.peer_cert_chain().unwrap() + )); + + assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2); + // beaver has an additional reference because it was used in multiple + // calls + assert_eq!(Arc::strong_count(&beaver_cert.ptr), 3); + } + + // set a default without adding it to the store + { + let mut test_pair = sni_test_pair( + vec![alligator_cert.clone()], + Some(vec![beaver_cert.clone()]), + &[SniTestCerts::AlligatorRsa, SniTestCerts::BeaverRsa], + )?; + + assert!(test_pair.handshake().is_ok()); + + assert!(cert_chains_are_equal( + &beaver_cert, + &test_pair.client.peer_cert_chain().unwrap() + )); + + assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2); + assert_eq!(Arc::strong_count(&beaver_cert.ptr), 2); + } + + Ok(()) + } + + #[test] + fn cert_ownership_error() -> Result<(), crate::error::Error> { + let application_owned_cert = SniTestCerts::AlligatorRsa.get().into_certificate_chain(); + let cert_for_lib = SniTestCerts::BeaverRsa.get(); + + let mut config = config::Builder::new(); + + // library owned certs can not be used with application owned certs + config.load_chain(application_owned_cert)?; + let err = config + .load_pem(cert_for_lib.cert(), cert_for_lib.key()) + .err() + .unwrap(); + + assert_eq!(err.kind(), ErrorType::UsageError); + assert_eq!(err.name(), "S2N_ERR_CERT_OWNERSHIP"); + + Ok(()) + } + + // ensure the certificates are send and sync + #[test] + fn certificate_send_sync_test() { + fn assert_send_sync() {} + assert_send_sync::>(); + } +} diff --git a/bindings/rust/s2n-tls/src/config.rs b/bindings/rust/s2n-tls/src/config.rs index c7e6353347e..81c0c4db0a4 100644 --- a/bindings/rust/s2n-tls/src/config.rs +++ b/bindings/rust/s2n-tls/src/config.rs @@ -5,8 +5,9 @@ use crate::renegotiate::RenegotiateCallback; use crate::{ callbacks::*, + cert_chain::CertificateChain, enums::*, - error::{Error, Fallible}, + error::{Error, ErrorType, Fallible}, security, }; use core::{convert::TryInto, ptr::NonNull}; @@ -277,6 +278,12 @@ impl Builder { Ok(self) } + /// Associate a `certificate` and corresponding `private_key` with a config. + /// Using this method, at most one certificate per auth type (ECDSA, RSA, RSA-PSS) + /// can be loaded. + /// + /// For more advanced cert use cases such as sharing certs across configs or + /// serving differents certs based on the client SNI, see [Builder::load_chain]. pub fn load_pem(&mut self, certificate: &[u8], private_key: &[u8]) -> Result<&mut Self, Error> { let certificate = CString::new(certificate).map_err(|_| Error::INVALID_INPUT)?; let private_key = CString::new(private_key).map_err(|_| Error::INVALID_INPUT)?; @@ -291,6 +298,77 @@ impl Builder { Ok(self) } + /// Corresponds to [s2n_config_add_cert_chain_and_key_to_store]. + pub fn load_chain(&mut self, chain: CertificateChain<'static>) -> Result<&mut Self, Error> { + // Out of an abudance of caution, we hold a reference to the CertificateChain + // regardless of whether add_to_store fails or succeeds. We have limited + // visibility into the failure modes, so this behavior ensures that _if_ + // the C library held the reference despite the failure, it would continue + // to be valid memory. + let result = unsafe { + s2n_config_add_cert_chain_and_key_to_store( + self.as_mut_ptr(), + // SAFETY: audit of add_to_store shows that the certificate chain + // is not mutated. https://github.com/aws/s2n-tls/issues/4140 + chain.as_ptr() as *mut _, + ) + .into_result() + }; + self.context_mut().application_owned_certs.push(chain); + result?; + + Ok(self) + } + + /// Corresponds to [s2n_config_set_cert_chain_and_key_defaults]. + pub fn set_default_chains>>( + &mut self, + chains: T, + ) -> Result<&mut Self, Error> { + // Must be equal to S2N_CERT_TYPE_COUNT in s2n_certificate.h. + const CHAINS_MAX_COUNT: usize = 3; + + let mut chain_arrays: [Option>; CHAINS_MAX_COUNT] = + [None, None, None]; + let mut pointer_array = [std::ptr::null_mut(); CHAINS_MAX_COUNT]; + let mut cert_chain_count = 0; + + for chain in chains.into_iter() { + if cert_chain_count >= CHAINS_MAX_COUNT { + return Err(Error::bindings( + ErrorType::UsageError, + "InvalidInput", + "A single default can be specified for RSA, ECDSA, + and RSA-PSS auth types, but more than 3 certs were supplied", + )); + } + + // SAFETY: manual inspection of set_defaults shows that certificates + // are not mutated. https://github.com/aws/s2n-tls/issues/4140 + pointer_array[cert_chain_count] = chain.as_ptr() as *mut _; + chain_arrays[cert_chain_count] = Some(chain); + + cert_chain_count += 1; + } + + let collected_chains = chain_arrays.into_iter().take(cert_chain_count).flatten(); + + self.context_mut() + .application_owned_certs + .extend(collected_chains); + + unsafe { + s2n_config_set_cert_chain_and_key_defaults( + self.as_mut_ptr(), + pointer_array.as_mut_ptr(), + cert_chain_count as u32, + ) + .into_result() + }?; + + Ok(self) + } + pub fn load_public_pem(&mut self, certificate: &[u8]) -> Result<&mut Self, Error> { let size: u32 = certificate .len() @@ -811,6 +889,16 @@ impl Default for Builder { pub(crate) struct Context { refcount: AtomicUsize, + /// This is a container for reference counts. + /// + /// In the bindings, application owned certificate chains are reference counted. + /// The C library is not aware of the reference counts, so a naive implementation + /// would result in certs being prematurely freed because the "reference" + /// held by the C library wouldn't be accounted for. + /// + /// Storing the CertificateChains in this Vec ensures that reference counts + /// behave as expected when stored in an s2n-tls config. + application_owned_certs: Vec>, pub(crate) client_hello_callback: Option>, pub(crate) private_key_callback: Option>, pub(crate) verify_host_callback: Option>, @@ -830,6 +918,7 @@ impl Default for Context { Self { refcount, + application_owned_certs: Vec::new(), client_hello_callback: None, private_key_callback: None, verify_host_callback: None, diff --git a/bindings/rust/s2n-tls/src/connection.rs b/bindings/rust/s2n-tls/src/connection.rs index e65225cf111..7db69da50bf 100644 --- a/bindings/rust/s2n-tls/src/connection.rs +++ b/bindings/rust/s2n-tls/src/connection.rs @@ -1098,13 +1098,10 @@ impl Connection { // chain, so the lifetime is independent of the connection. pub fn peer_cert_chain(&self) -> Result, Error> { unsafe { - let mut chain = CertificateChain::new()?; - s2n_connection_get_peer_cert_chain( - self.connection.as_ptr(), - chain.as_mut_ptr().as_ptr(), - ) - .into_result() - .map(|_| ())?; + let mut chain = CertificateChain::allocate_owned()?; + s2n_connection_get_peer_cert_chain(self.connection.as_ptr(), chain.as_mut_ptr()) + .into_result() + .map(|_| ())?; Ok(chain) } } diff --git a/bindings/rust/s2n-tls/src/renegotiate.rs b/bindings/rust/s2n-tls/src/renegotiate.rs index 493cf746a95..0a3cb841ec4 100644 --- a/bindings/rust/s2n-tls/src/renegotiate.rs +++ b/bindings/rust/s2n-tls/src/renegotiate.rs @@ -514,11 +514,12 @@ mod tests { // // openssl also requires a properly configured CA cert, which the // default TestPair does not include. - let certs_dir = concat!( - env!("CARGO_MANIFEST_DIR"), - "/../../../tests/pems/permutations/rsae_pkcs_4096_sha384/" + let certs = CertKeyPair::from_path( + "permutations/rsae_pkcs_4096_sha384/", + "server-chain", + "server-key", + "ca-cert", ); - let certs = CertKeyPair::from(certs_dir, "server-chain", "server-key", "ca-cert"); // Build the s2n-tls client. builder.load_pem(certs.cert(), certs.key())?; @@ -1021,9 +1022,11 @@ mod tests { // Perform the pkey operation with the selected cert / key pair. let op = this.op.take().unwrap(); let opt_ptr = op.as_ptr(); - let chain_ptr = conn.selected_cert().unwrap().as_mut_ptr().as_ptr(); + let chain_ptr = conn.selected_cert().unwrap().as_ptr(); unsafe { - let key_ptr = s2n_cert_chain_and_key_get_private_key(chain_ptr) + // SAFETY, mut cast: get_private_key does not modify the + // chain, and it is invalid to modify key through `key_ptr` + let key_ptr = s2n_cert_chain_and_key_get_private_key(chain_ptr as *mut _) .into_result()? .as_ptr(); s2n_async_pkey_op_perform(opt_ptr, key_ptr).into_result()?; diff --git a/bindings/rust/s2n-tls/src/testing.rs b/bindings/rust/s2n-tls/src/testing.rs index feaab38acba..4eb74fa396b 100644 --- a/bindings/rust/s2n-tls/src/testing.rs +++ b/bindings/rust/s2n-tls/src/testing.rs @@ -3,6 +3,7 @@ use crate::{ callbacks::VerifyHostNameCallback, + cert_chain::{self, CertificateChain}, config::{self, *}, connection, enums::{self, Blinding}, @@ -59,6 +60,26 @@ impl Default for Counter { } } +#[allow(non_camel_case_types)] +// allow non camel case types because the mixture of letters and numbers is easier +// to read with snake_case. +pub enum SniTestCerts { + AlligatorRsa, + AlligatorEcdsa, + BeaverRsa, +} + +impl SniTestCerts { + pub fn get(&self) -> CertKeyPair { + let prefix = match *self { + SniTestCerts::AlligatorRsa => "alligator_", + SniTestCerts::AlligatorEcdsa => "alligator_ecdsa_", + SniTestCerts::BeaverRsa => "beaver_", + }; + CertKeyPair::from_path(&format!("sni/{prefix}"), "cert", "key", "cert") + } +} + pub struct CertKeyPair { cert_path: String, key_path: String, @@ -69,19 +90,40 @@ pub struct CertKeyPair { impl Default for CertKeyPair { fn default() -> Self { - let prefix = concat!( - env!("CARGO_MANIFEST_DIR"), - "/../../../tests/pems/rsa_4096_sha512_client_" - ); - Self::from(prefix, "cert", "key", "cert") + Self::from_path("rsa_4096_sha512_client_", "cert", "key", "cert") } } impl CertKeyPair { - pub fn from(prefix: &str, chain: &str, key: &str, ca: &str) -> Self { - let cert_path = format!("{prefix}{chain}.pem"); - let key_path = format!("{prefix}{key}.pem"); - let ca_path = format!("{prefix}{ca}.pem"); + /// This is the directory holding all of the pems used for s2n-tls unit tests + const TEST_PEMS_PATH: &'static str = + concat!(env!("CARGO_MANIFEST_DIR"), "/../../../tests/pems/"); + + /// Create a test CertKeyPair + /// * `prefix`: The *relative* prefix from the s2n-tls/tests/pems/ folder. + /// * `chain`: The suffix indicate the full chain. + /// * `key`: The suffix indicate the private key. + /// * `ca`: The suffix indicating the CA. + /// + /// ### Example + /// Assuming the relevant files are at + /// - s2n-tls/tests/pems/permutations/rsae_pkcs_4096_sha384/server-chain.pem + /// - s2n-tls/tests/pems/permutations/rsae_pkcs_4096_sha384/server-key.pem + /// - s2n-tls/tests/pems/permutations/rsae_pkcs_4096_sha384/ca-cert.pem + /// + /// ```ignore + /// let cert = CertKeyPair::from( + /// "permutations/rsae_pkcs_4096_sha384/", + /// "server-chain", + /// "server-key", + /// "ca-cert" + /// ); + /// ``` + pub fn from_path(prefix: &str, chain: &str, key: &str, ca: &str) -> Self { + let cert_path = format!("{}{prefix}{chain}.pem", Self::TEST_PEMS_PATH); + println!("{:?}", cert_path); + let key_path = format!("{}{prefix}{key}.pem", Self::TEST_PEMS_PATH); + let ca_path = format!("{}{prefix}{ca}.pem", Self::TEST_PEMS_PATH); let cert = std::fs::read(&cert_path) .unwrap_or_else(|_| panic!("Failed to read cert at {cert_path}")); let key = @@ -95,6 +137,12 @@ impl CertKeyPair { } } + pub fn into_certificate_chain(&self) -> CertificateChain<'static> { + let mut chain = cert_chain::Builder::new().unwrap(); + chain.load_pem(&self.cert, &self.key).unwrap(); + chain.build().unwrap() + } + pub fn cert_path(&self) -> &str { &self.cert_path } diff --git a/tests/unit/s2n_certificate_test.c b/tests/unit/s2n_certificate_test.c index 48797611486..cdc555e23d3 100644 --- a/tests/unit/s2n_certificate_test.c +++ b/tests/unit/s2n_certificate_test.c @@ -1055,5 +1055,14 @@ int main(int argc, char **argv) EXPECT_EQUAL(len, 2); }; + /* cert type count is equal to 3 */ + { + /* The rust bindings have a constant - CHAINS_MAX_COUNT - which must be + * equal to S2N_CERT_TYPE_COUNT. If this test fails, CHAINS_MAX_COUNT in + * config.rs must be updated. + */ + EXPECT_EQUAL(S2N_CERT_TYPE_COUNT, 3); + } + END_TEST(); }