Skip to content

Commit

Permalink
Created ConnectionInitializer
Browse files Browse the repository at this point in the history
  • Loading branch information
maddeleine committed Oct 17, 2023
1 parent c3d8af6 commit 2df4e9a
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 31 deletions.
5 changes: 0 additions & 5 deletions bindings/rust/s2n-tls/src/callbacks/session_ticket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@ pub trait SessionTicketCallback: Send + Sync + 'static {
fn on_session_ticket(&self, connection: &mut Connection, session_ticket: &SessionTicket);
}

// A trait to give session tickets to new TLS connections
pub trait SessionTicketProvider: 'static + Send + Sync {
fn provide_session_ticket(&self) -> Option<Vec<u8>>;
}

pub struct SessionTicket(s2n_session_ticket);

impl SessionTicket {
Expand Down
99 changes: 86 additions & 13 deletions bindings/rust/s2n-tls/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ use s2n_tls_sys::*;
use std::{
ffi::{c_void, CString},
path::Path,
pin::Pin,
sync::atomic::{AtomicUsize, Ordering},
task::Poll,
time::{Duration, SystemTime},
};

Expand Down Expand Up @@ -489,6 +491,17 @@ impl Builder {
Ok(self)
}

pub fn set_connection_initializer<T: 'static + ConnectionInitializer>(
&mut self,
handler: T,
) -> Result<&mut Self, Error> {
// Store callback in config context
let handler = Box::new(handler);
let context = self.config.context_mut();
context.connection_initializer = Some(handler);
Ok(self)
}

/// Sets a custom callback which provides access to session tickets when they arrive
pub fn set_session_ticket_callback<T: 'static + SessionTicketCallback>(
&mut self,
Expand Down Expand Up @@ -524,17 +537,6 @@ impl Builder {
Ok(self)
}

pub fn set_session_ticket_provider<T: 'static + SessionTicketProvider>(
&mut self,
handler: T,
) -> Result<&mut Self, Error> {
// Store callback in context
let handler = Box::new(handler);
let context = self.config.context_mut();
context.session_ticket_provider = Some(handler);
Ok(self)
}

/// Set a callback function triggered by operations requiring the private key.
///
/// See https://github.com/aws/s2n-tls/blob/main/docs/USAGE-GUIDE.md#private-key-operation-related-calls
Expand Down Expand Up @@ -744,7 +746,7 @@ pub(crate) struct Context {
pub(crate) private_key_callback: Option<Box<dyn PrivateKeyCallback>>,
pub(crate) verify_host_callback: Option<Box<dyn VerifyHostNameCallback>>,
pub(crate) session_ticket_callback: Option<Box<dyn SessionTicketCallback>>,
pub(crate) session_ticket_provider: Option<Box<dyn SessionTicketProvider>>,
pub(crate) connection_initializer: Option<Box<dyn ConnectionInitializer>>,
pub(crate) wall_clock: Option<Box<dyn WallClock>>,
pub(crate) monotonic_clock: Option<Box<dyn MonotonicClock>>,
}
Expand All @@ -761,9 +763,80 @@ impl Default for Context {
private_key_callback: None,
verify_host_callback: None,
session_ticket_callback: None,
session_ticket_provider: None,
connection_initializer: None,
wall_clock: None,
monotonic_clock: None,
}
}
}

/// A trait executed before a new connection negotiates TLS.
///
/// Used for any dynamic configuration of the connection.
/// Use in conjunction with
/// [config::Builder::set_connection_initializer](`crate::config::Builder::set_connection_initializer()`).
pub trait ConnectionInitializer: 'static + Send + Sync {
/// The application can return an `Ok(None)` to resolve the callback
/// synchronously or return an `Ok(Some(ConnectionFuture))` if it wants to
/// run some asynchronous task before resolving the callback.
///
fn initialize_connection(
&self,
connection: &mut crate::connection::Connection,
) -> ConnectionFutureResult;
}

impl<A: ConnectionInitializer, B: ConnectionInitializer> ConnectionInitializer for (A, B) {
fn initialize_connection(
&self,
connection: &mut crate::connection::Connection,
) -> ConnectionFutureResult {
let a = self.0.initialize_connection(connection)?;
let b = self.1.initialize_connection(connection)?;
match (a, b) {
(None, None) => Ok(None),
(None, Some(fut)) => Ok(Some(fut)),
(Some(fut), None) => Ok(Some(fut)),
(Some(fut_a), Some(fut_b)) => Ok(Some(Box::pin(ConcurrentConnectionFuture::new([
fut_a, fut_b,
])))),
}
}
}

struct ConcurrentConnectionFuture<const N: usize> {
futures: [Option<Pin<Box<dyn ConnectionFuture>>>; N],
}

impl<const N: usize> ConcurrentConnectionFuture<N> {
fn new(futures: [Pin<Box<dyn ConnectionFuture>>; N]) -> Self {
let futures = futures.map(Some);
Self { futures }
}
}

impl<const N: usize> ConnectionFuture for ConcurrentConnectionFuture<N> {
fn poll(
mut self: std::pin::Pin<&mut Self>,
connection: &mut crate::connection::Connection,
ctx: &mut core::task::Context,
) -> std::task::Poll<Result<(), Error>> {
let mut is_pending = false;
for container in self.futures.iter_mut() {
if let Some(future) = container.as_mut() {
match future.as_mut().poll(connection, ctx) {
Poll::Ready(result) => {
result?;
*container = None;
}
Poll::Pending => is_pending = true,
}
}
}
if is_pending {
Poll::Pending
} else {
Poll::Ready(Ok(()))
}
}
}
17 changes: 10 additions & 7 deletions bindings/rust/s2n-tls/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,6 @@ impl Connection {
"s2n_connection_set_config was successful"
};

let context = config.context();
if let Some(callback) = &context.session_ticket_provider {
if let Some(ticket) = callback.provide_session_ticket() {
self.set_session_ticket(&ticket)?;
}
}

// Setting the config on the connection creates one additional reference to the config
// so do not drop so prevent Rust from calling `drop()` at the end of this function.
mem::forget(config);
Expand Down Expand Up @@ -427,6 +420,14 @@ impl Connection {
/// any other callbacks) until the blocking async task reports completion.
pub fn poll_negotiate(&mut self) -> Poll<Result<&mut Self, Error>> {
let mut blocked = s2n_blocked_status::NOT_BLOCKED;
if !core::mem::replace(&mut self.context_mut().connection_initialized, true) {
if let Some(config) = self.config() {
if let Some(callback) = config.context().connection_initializer.as_ref() {
let future = callback.initialize_connection(self);
AsyncCallback::trigger(future, self);
}
}
}

loop {
// check if an async task exists and poll it to completion
Expand Down Expand Up @@ -849,6 +850,7 @@ struct Context {
waker: Option<Waker>,
async_callback: Option<AsyncCallback>,
verify_host_callback: Option<Box<dyn VerifyHostNameCallback>>,
connection_initialized: bool,
}

impl Context {
Expand All @@ -858,6 +860,7 @@ impl Context {
waker: None,
async_callback: None,
verify_host_callback: None,
connection_initialized: false,
}
}
}
Expand Down
26 changes: 20 additions & 6 deletions bindings/rust/s2n-tls/src/testing/resumption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
#[cfg(test)]
mod tests {
use crate::{
callbacks::{SessionTicket, SessionTicketCallback, SessionTicketProvider},
callbacks::{SessionTicket, SessionTicketCallback},
config::ConnectionInitializer,
connection,
testing::{s2n_tls::*, *},
};
use futures_test::task::noop_waker;
use std::{error::Error, sync::Mutex, time::SystemTime};

#[derive(Default, Clone)]
Expand All @@ -32,9 +34,15 @@ mod tests {
}
}

impl SessionTicketProvider for SessionTicketHandler {
fn provide_session_ticket(&self) -> Option<Vec<u8>> {
(*self.stored_ticket).lock().unwrap().clone()
impl ConnectionInitializer for SessionTicketHandler {
fn initialize_connection(
&self,
connection: &mut crate::connection::Connection,
) -> crate::callbacks::ConnectionFutureResult {
if let Some(ticket) = (*self.stored_ticket).lock().unwrap().as_deref() {
connection.set_session_ticket(ticket)?;
}
Ok(None)
}
}

Expand Down Expand Up @@ -63,7 +71,7 @@ mod tests {
.set_session_ticket_callback(handler.clone())?
.trust_pem(keypair.cert())?
.set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?
.set_session_ticket_provider(handler.clone())?;
.set_connection_initializer(handler.clone())?;
let client_config = client_config_builder.build()?;

// create and configure a server connection
Expand All @@ -74,7 +82,10 @@ mod tests {

// create a client connection
let mut client = connection::Connection::new_client();

// Client needs a waker due to its use of an async callback
client
.set_waker(Some(&noop_waker()))?
.set_config(client_config.clone())
.expect("Unable to set client config");

Expand All @@ -101,6 +112,7 @@ mod tests {
let mut client = connection::Connection::new_client();

client
.set_waker(Some(&noop_waker()))?
.set_config(client_config)
.expect("Unable to set client config");

Expand Down Expand Up @@ -135,7 +147,7 @@ mod tests {
client_config_builder
.enable_session_tickets(true)?
.set_session_ticket_callback(handler.clone())?
.set_session_ticket_provider(handler.clone())?
.set_connection_initializer(handler.clone())?
.trust_pem(keypair.cert())?
.set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?
.set_security_policy(&security::DEFAULT_TLS13)?;
Expand All @@ -150,6 +162,7 @@ mod tests {
// create a client connection
let mut client = connection::Connection::new_client();
client
.set_waker(Some(&noop_waker()))?
.set_config(client_config.clone())
.expect("Unable to set client config");

Expand Down Expand Up @@ -180,6 +193,7 @@ mod tests {
// create a client connection with a resumption ticket
let mut client = connection::Connection::new_client();
client
.set_waker(Some(&noop_waker()))?
.set_config(client_config)
.expect("Unable to set client config");

Expand Down

0 comments on commit 2df4e9a

Please sign in to comment.