Skip to content

Commit

Permalink
add default and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jmayclin committed Nov 25, 2024
1 parent 3369c7d commit ce0ea6c
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 71 deletions.
199 changes: 141 additions & 58 deletions bindings/rust/s2n-tls/src/cert_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl CertificateChain<'_> {
// count to allow for the "reference" held by the s2n_connection.
let clone_to_increment_refcount = Arc::clone(&handle);
std::mem::forget(clone_to_increment_refcount);
// handle & owning struct
// `handle` + owning struct = 2
debug_assert_eq!(Arc::strong_count(&handle), 2);

CertificateChain {
Expand Down Expand Up @@ -189,97 +189,180 @@ unsafe impl Send for Certificate<'_> {}
mod tests {
use crate::{
config,
security::{Policy, DEFAULT_TLS13},
testing::{CertKeyPair, InsecureAcceptAllCertificatesHandler, TestPair},
security::DEFAULT_TLS13,
testing::{InsecureAcceptAllCertificatesHandler, SniTestCerts, TestPair},
};

use super::*;

#[test]
fn ref_counts() -> Result<(), crate::error::Error> {
let cert = CertKeyPair::default();
fn reference_count_increment() -> Result<(), crate::error::Error> {
let alligator_cert = SniTestCerts::AlligatorRsa.get().into_certificate_chain();

let chain = CertificateChain::load_pems(cert.cert(), cert.key())?;
assert_eq!(Arc::strong_count(&chain.ptr), 1);
// cert on a single config
{
let mut server = config::Builder::new();
server.add_to_store(alligator_cert.clone())?;

let mut list = Vec::new();
for _ in 0..10 {
list.push(chain.clone());
// after being added, the reference count should have increased
assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2);
}
assert_eq!(Arc::strong_count(&chain.ptr), 1 + 10);
drop(list);
assert_eq!(Arc::strong_count(&chain.ptr), 1);

// after the config goes out of scope and is dropped, the ref count should
// decrement
assert_eq!(Arc::strong_count(&alligator_cert.ptr), 1);
Ok(())
}

#[test]
fn sanity_check() -> Result<(), crate::error::Error> {
let cert = CertKeyPair::default();

fn cert_is_dropped() {
let weak_ref;
{
let mut server = config::Builder::new();
server.set_security_policy(&DEFAULT_TLS13)?;
server.load_pem(cert.cert(), cert.key())?;

let mut client = config::Builder::new();
client.set_security_policy(&DEFAULT_TLS13)?;
client.set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?;
client.trust_pem(cert.cert())?;
let cert = SniTestCerts::AlligatorEcdsa.get().into_certificate_chain();
weak_ref = Arc::downgrade(&cert.ptr);
assert_eq!(Arc::strong_count(&cert.ptr), 1);
}
assert_eq!(weak_ref.strong_count(), 0);
assert!(weak_ref.upgrade().is_none());
}

let mut pair = TestPair::from_configs(&client.build()?, &server.build()?);
/// Create a test pair using SNI certs
/// * `certs`: takes references to already created cert chains. This is useful
/// to assert on expected reference counts.
/// * `types`: Used to find the CA paths for the client configs
fn sni_test_pair(
certs: Vec<CertificateChain<'static>>,
defaults: Option<Vec<CertificateChain<'static>>>,
types: &[SniTestCerts],
) -> Result<TestPair, crate::error::Error> {
let mut server_config = config::Builder::new();
server_config
.with_system_certs(false)?
.set_security_policy(&DEFAULT_TLS13)?;
for cert in certs.into_iter() {
server_config.add_to_store(cert)?;
}
if let Some(defaults) = defaults {
server_config.set_default_cert_chain_and_key(defaults)?;
}

pair.handshake().unwrap();
let mut client_config = config::Builder::new();
client_config
.with_system_certs(false)?
.set_security_policy(&DEFAULT_TLS13)?
.set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?;
for tipe in types {
client_config.trust_pem(tipe.get().cert())?;
}
Ok(TestPair::from_configs(
&client_config.build()?,
&server_config.build()?,
))
}

Ok(())
/// This is a useful (but inefficient) test utility to check if CertificateChain
/// structs are equal. It does this by comparing the serialized `der` representation.
fn cert_chains_are_equal<'a, 'b>(
this: &CertificateChain<'a>,
that: &CertificateChain<'b>,
) -> bool {
let this: Vec<Vec<u8>> = this
.iter()
.map(|cert| cert.unwrap().der().unwrap().to_owned())
.collect();
let that: Vec<Vec<u8>> = that
.iter()
.map(|cert| cert.unwrap().der().unwrap().to_owned())
.collect();
this == that
}

// a cert can be successfully shared across multiple configs
#[test]
fn config_drop() -> Result<(), crate::error::Error> {
let cert = CertKeyPair::default();
fn shared_certs() -> Result<(), crate::error::Error> {
let test_key_pair = SniTestCerts::AlligatorRsa.get();
let cert = test_key_pair.into_certificate_chain();

let chain = CertificateChain::load_pems(cert.cert(), cert.key())?;
let mut test_pair_1 =
sni_test_pair(vec![cert.clone()], None, &[SniTestCerts::AlligatorRsa])?;
let mut test_pair_2 =
sni_test_pair(vec![cert.clone()], None, &[SniTestCerts::AlligatorRsa])?;

// cert on a single config
assert_eq!(Arc::strong_count(&cert.ptr), 3);

assert!(test_pair_1.handshake().is_ok());
assert!(test_pair_2.handshake().is_ok());

assert_eq!(Arc::strong_count(&cert.ptr), 3);

drop(test_pair_1);
assert_eq!(Arc::strong_count(&cert.ptr), 2);
drop(test_pair_2);
assert_eq!(Arc::strong_count(&cert.ptr), 1);
Ok(())
}

#[test]
fn default_effects() -> Result<(), crate::error::Error> {
let alligator_cert = SniTestCerts::AlligatorRsa.get().into_certificate_chain();
let beaver_cert = SniTestCerts::BeaverRsa.get().into_certificate_chain();

// when no default is explicitly set, the first loaded cert is the default
{
let mut server = config::Builder::new();
server.set_security_policy(&DEFAULT_TLS13)?;
server.add_to_store(chain.clone())?;
server.set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?;
server.trust_pem(cert.cert())?;
let mut test_pair = sni_test_pair(
vec![alligator_cert.clone(), beaver_cert.clone()],
None,
&[SniTestCerts::AlligatorRsa, SniTestCerts::BeaverRsa],
)?;

// after being added, the reference count should have increased
assert_eq!(Arc::strong_count(&chain.ptr), 2);
assert!(test_pair.handshake().is_ok());

let mut pair = TestPair::from_config(&server.build()?);
assert!(pair.handshake().is_ok());
assert!(cert_chains_are_equal(
&alligator_cert,
&test_pair.client.peer_cert_chain().unwrap()
));

assert_eq!(Arc::strong_count(&chain.ptr), 2);
assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2);
assert_eq!(Arc::strong_count(&beaver_cert.ptr), 2);
}
// after the config goes out of scope and is dropped, the ref count should
// decrement
assert_eq!(Arc::strong_count(&chain.ptr), 1);
{

// cert on a single config

// set an explicit default
{
let mut server = config::Builder::new();
server.set_security_policy(&DEFAULT_TLS13)?;
server.add_to_store(chain.clone())?;
server.set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?;
server.trust_pem(cert.cert())?;
let mut test_pair = sni_test_pair(
vec![alligator_cert.clone(), beaver_cert.clone()],
Some(vec![beaver_cert.clone()]),
&[SniTestCerts::AlligatorRsa, SniTestCerts::BeaverRsa],
)?;

// after being added, the reference count should have increased
assert_eq!(Arc::strong_count(&chain.ptr), 2);
assert!(test_pair.handshake().is_ok());

let mut pair = TestPair::from_config(&server.build()?);
assert!(pair.handshake().is_ok());
assert!(cert_chains_are_equal(
&beaver_cert,
&test_pair.client.peer_cert_chain().unwrap()
));

assert_eq!(Arc::strong_count(&chain.ptr), 2);
}
assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2);
assert_eq!(Arc::strong_count(&beaver_cert.ptr), 3);
}

// set a default without adding it to the store
{
let mut test_pair = sni_test_pair(
vec![alligator_cert.clone()],
Some(vec![beaver_cert.clone()]),
&[SniTestCerts::AlligatorRsa, SniTestCerts::BeaverRsa],
)?;

assert!(test_pair.handshake().is_ok());

assert!(cert_chains_are_equal(
&beaver_cert,
&test_pair.client.peer_cert_chain().unwrap()
));

assert_eq!(Arc::strong_count(&alligator_cert.ptr), 2);
assert_eq!(Arc::strong_count(&beaver_cert.ptr), 2);
}

Ok(())
}
Expand Down
48 changes: 44 additions & 4 deletions bindings/rust/s2n-tls/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
callbacks::*,
cert_chain::CertificateChain,
enums::*,
error::{Error, Fallible},
error::{Error, ErrorType, Fallible},
security,
};
use core::{convert::TryInto, ptr::NonNull};
Expand All @@ -16,7 +16,11 @@ use std::{
ffi::{c_void, CString},
path::Path,
pin::Pin,
sync::atomic::{AtomicUsize, Ordering},
ptr,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
task::Poll,
time::{Duration, SystemTime},
};
Expand Down Expand Up @@ -294,7 +298,15 @@ impl Builder {
}

pub fn add_to_store(&mut self, chain: CertificateChain<'static>) -> Result<&mut Self, Error> {
// TODO: should we hold the extra reference before or after loading the cert?
// Out of an abudance of caution, we hold a reference to the CertificateChain
// regardless of whether add_to_store fails or succeeds. We have limited
// visibility into the failure modes, so this behavior ensures that _if_
// the C library held the reference despite the failure, it would continue
// to be valid memory.
self.context_mut()
.application_owned_certs
.push(chain.clone());

unsafe {
s2n_config_add_cert_chain_and_key_to_store(
self.as_mut_ptr(),
Expand All @@ -305,7 +317,35 @@ impl Builder {
.into_result()
}?;

self.context_mut().application_owned_certs.push(chain);
Ok(self)
}

/// Set the default cert for a particular auth type. Auth types are
/// - RSA
/// - ECDSA
/// - RSA-PSS
/// Repeated calls to this function will overwrite previous defaults.
pub fn set_default_cert_chain_and_key(
&mut self,
chains: Vec<CertificateChain<'static>>,
) -> Result<&mut Self, Error> {
self.context_mut()
.application_owned_certs
.extend(chains.clone());

let raw_certs: Vec<*mut s2n_cert_chain_and_key> = chains
.into_iter()
.map(|cert| cert.as_ptr() as *mut _)
.collect();

unsafe {
s2n_config_set_cert_chain_and_key_defaults(
self.as_mut_ptr(),
raw_certs.as_ptr() as *mut _,
raw_certs.len() as u32,
);
}

Ok(self)
}

Expand Down
Loading

0 comments on commit ce0ea6c

Please sign in to comment.