Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(bindings/bench): make harness own IO #4847

Merged
merged 10 commits into from
Dec 16, 2024
17 changes: 5 additions & 12 deletions bindings/rust/bench/benches/handshake.rs
Original file line number Diff line number Diff line change
@@ -6,8 +6,8 @@ use bench::OpenSslConnection;
#[cfg(feature = "rustls")]
use bench::RustlsConnection;
use bench::{
harness::TlsBenchConfig, CipherSuite, ConnectedBuffer, CryptoConfig, HandshakeType, KXGroup,
Mode, S2NConnection, SigType, TlsConnPair, TlsConnection, PROFILER_FREQUENCY,
harness::TlsBenchConfig, CipherSuite, CryptoConfig, HandshakeType, KXGroup, Mode,
S2NConnection, SigType, TlsConnPair, TlsConnection, PROFILER_FREQUENCY,
};
use criterion::{
criterion_group, criterion_main, measurement::WallTime, BatchSize, BenchmarkGroup, Criterion,
@@ -35,16 +35,9 @@ fn bench_handshake_for_library<T>(
bench_group.bench_function(T::name(), |b| {
b.iter_batched_ref(
|| -> Result<TlsConnPair<T, T>, Box<dyn Error>> {
if let (Ok(client_config), Ok(server_config)) =
(client_config.as_ref(), server_config.as_ref())
{
let connected_buffer = ConnectedBuffer::default();
let client =
T::new_from_config(client_config, connected_buffer.clone_inverse())?;
let server = T::new_from_config(server_config, connected_buffer)?;
Ok(TlsConnPair::wrap(client, server))
} else {
Err("invalid configs".into())
match (client_config.as_ref(), server_config.as_ref()) {
(Ok(c_conf), Ok(s_conf)) => Ok(TlsConnPair::from_configs(c_conf, s_conf)),
_ => Err("invalid configs".into()),
}
},
|conn_pair| {
29 changes: 13 additions & 16 deletions bindings/rust/bench/benches/throughput.rs
Original file line number Diff line number Diff line change
@@ -6,8 +6,8 @@ use bench::OpenSslConnection;
#[cfg(feature = "rustls")]
use bench::RustlsConnection;
use bench::{
harness::TlsBenchConfig, CipherSuite, ConnectedBuffer, CryptoConfig, HandshakeType, KXGroup,
Mode, S2NConnection, SigType, TlsConnPair, TlsConnection, PROFILER_FREQUENCY,
harness::TlsBenchConfig, CipherSuite, CryptoConfig, HandshakeType, KXGroup, Mode,
S2NConnection, SigType, TlsConnPair, TlsConnection, PROFILER_FREQUENCY,
};
use criterion::{
criterion_group, criterion_main, measurement::WallTime, BatchSize, BenchmarkGroup, Criterion,
@@ -26,24 +26,21 @@ fn bench_throughput_for_library<T>(
T::Config: TlsBenchConfig,
{
let crypto_config = CryptoConfig::new(cipher_suite, KXGroup::default(), SigType::default());
let client_config = T::Config::make_config(Mode::Client, crypto_config, HandshakeType::default());
let server_config = T::Config::make_config(Mode::Server, crypto_config, HandshakeType::default());
let client_config =
T::Config::make_config(Mode::Client, crypto_config, HandshakeType::default());
let server_config =
T::Config::make_config(Mode::Server, crypto_config, HandshakeType::default());

bench_group.bench_function(T::name(), |b| {
b.iter_batched_ref(
|| -> Result<TlsConnPair<T, T>, Box<dyn Error>> {
if let (Ok(client_config), Ok(server_config)) =
(client_config.as_ref(), server_config.as_ref())
{
let connected_buffer = ConnectedBuffer::default();
let client =
T::new_from_config(client_config, connected_buffer.clone_inverse())?;
let server = T::new_from_config(server_config, connected_buffer)?;
let mut conn_pair = TlsConnPair::wrap(client, server);
conn_pair.handshake()?;
Ok(conn_pair)
} else {
Err("invalid configs".into())
match (client_config.as_ref(), server_config.as_ref()) {
(Ok(c_conf), Ok(s_conf)) => {
let mut pair = TlsConnPair::<T, T>::from_configs(c_conf, s_conf);
pair.handshake()?;
Ok(pair)
}
_ => Err("invalid configs".into()),
}
},
|conn_pair| {
2 changes: 1 addition & 1 deletion bindings/rust/bench/src/bin/memory.rs
Original file line number Diff line number Diff line change
@@ -108,7 +108,7 @@ fn memory_bench<T: TlsConnection>(opt: &Opt) -> Result<(), Box<dyn Error>> {
)?;
conn_pair = TlsConnPair::wrap(client_conn, server_conn);
} else {
conn_pair = TlsConnPair::<T, T>::new(
conn_pair = TlsConnPair::<T, T>::from_configs(
CryptoConfig::default(),
HandshakeType::default(),
buffers.pop().unwrap(),
goatgoose marked this conversation as resolved.
Show resolved Hide resolved
67 changes: 67 additions & 0 deletions bindings/rust/bench/src/harness/io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use std::{cell::RefCell, collections::VecDeque, io::ErrorKind, pin::Pin, rc::Rc};
jmayclin marked this conversation as resolved.
Show resolved Hide resolved

pub type LocalDataBuffer = RefCell<VecDeque<u8>>;

#[derive(Debug)]
pub struct TestPairIO {
/// a data buffer that the server writes to and the client reads from
pub server_tx_stream: Pin<Rc<LocalDataBuffer>>,
/// a data buffer that the client writes to and the server reads from
pub client_tx_stream: Pin<Rc<LocalDataBuffer>>,
}

impl TestPairIO {
pub fn client_view(&self) -> ViewIO {
ViewIO {
send_ctx: self.client_tx_stream.clone(),
recv_ctx: self.server_tx_stream.clone(),
}
}

pub fn server_view(&self) -> ViewIO {
ViewIO {
send_ctx: self.server_tx_stream.clone(),
recv_ctx: self.client_tx_stream.clone(),
}
}
}

/// A "view" of the IO.
///
/// This view is client/server specific, and notably implements the read and write
/// traits.
///
// This struct is used by Openssl and Rustls which both rely on a "stream" abstraction
// which implements read and write. This is not used by s2n-tls, which relies on
// lower level callbacks.
pub struct ViewIO {
maddeleine marked this conversation as resolved.
Show resolved Hide resolved
pub send_ctx: Pin<Rc<LocalDataBuffer>>,
pub recv_ctx: Pin<Rc<LocalDataBuffer>>,
}

impl std::io::Read for ViewIO {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let res = self.recv_ctx.borrow_mut().read(buf);
if let Ok(0) = res {
// We are "faking" a TcpStream, where a read of length 0 indicates
// EoF. That is incorrect for this scenario. Instead we return WouldBlock
// to indicate that there is simply no more data to be read.
Err(std::io::Error::new(ErrorKind::WouldBlock, "blocking"))
} else {
res
}
}
}

impl std::io::Write for ViewIO {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.send_ctx.borrow_mut().write(buf)
}

fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

mod io;
pub use io::{LocalDataBuffer, ViewIO};

use io::TestPairIO;
use std::{
cell::RefCell,
collections::VecDeque,
error::Error,
fmt::Debug,
fs::read_to_string,
io::{ErrorKind, Read, Write},
rc::Rc,
};
use strum::EnumIter;
@@ -161,10 +162,7 @@ pub trait TlsConnection: Sized {
fn name() -> String;

/// Make connection from existing config and buffer
fn new_from_config(
config: &Self::Config,
connected_buffer: ConnectedBuffer,
) -> Result<Self, Box<dyn Error>>;
fn new_from_config(config: &Self::Config, io: ViewIO) -> Result<Self, Box<dyn Error>>;

/// Run one handshake step: receive msgs from other connection, process, and send new msgs
fn handshake(&mut self) -> Result<(), Box<dyn Error>>;
@@ -184,29 +182,13 @@ pub trait TlsConnection: Sized {

/// Read application data from ConnectedBuffer
fn recv(&mut self, data: &mut [u8]) -> Result<(), Box<dyn Error>>;

/// Shrink buffers owned by the connection
fn shrink_connection_buffers(&mut self);

/// Clear and shrink buffers used for IO with another connection
fn shrink_connected_buffer(&mut self);

/// Get reference to internal connected buffer
fn connected_buffer(&self) -> &ConnectedBuffer;
}

/// A TlsConnPair owns the client and server tls connections along with the IO buffers.
pub struct TlsConnPair<C: TlsConnection, S: TlsConnection> {
client: C,
server: S,
}

impl<C: TlsConnection, S: TlsConnection> TlsConnPair<C, S> {
pub fn new(client_config: &C::Config, server_config: &S::Config) -> TlsConnPair<C, S> {
let connected_buffer = ConnectedBuffer::default();
let client = C::new_from_config(&client_config, connected_buffer.clone_inverse()).unwrap();
let server = S::new_from_config(&server_config, connected_buffer).unwrap();
Self { client, server }
}
pub client: C,
pub server: S,
pub io: TestPairIO,
}

impl<C, S> Default for TlsConnPair<C, S>
@@ -242,7 +224,7 @@ where

// handshake the client and server connections. This will result in
// session ticket getting stored in client_config
let mut pair = TlsConnPair::<C, S>::new(&client_config, &server_config);
let mut pair = TlsConnPair::<C, S>::from_configs(&client_config, &server_config);
pair.handshake()?;
// NewSessionTicket messages are part of the application data and sent
// after the handshake is complete, so we must trigger an additional
@@ -255,10 +237,13 @@ where
// on the connection. This results in the session ticket in
// client_config (from the previous handshake) getting set on the
// client connection.
return Ok(TlsConnPair::<C, S>::new(&client_config, &server_config));
return Ok(TlsConnPair::<C, S>::from_configs(
&client_config,
&server_config,
));
}

Ok(TlsConnPair::<C, S>::new(
Ok(TlsConnPair::<C, S>::from_configs(
&C::Config::make_config(Mode::Client, crypto_config, handshake_type).unwrap(),
&S::Config::make_config(Mode::Server, crypto_config, handshake_type).unwrap(),
))
@@ -270,13 +255,14 @@ where
C: TlsConnection,
S: TlsConnection,
{
/// Wrap two TlsConnections into a TlsConnPair
pub fn wrap(client: C, server: S) -> Self {
assert!(
client.connected_buffer() == &server.connected_buffer().clone_inverse(),
"connected buffers don't match"
);
Self { client, server }
pub fn from_configs(client_config: &C::Config, server_config: &S::Config) -> Self {
let io = TestPairIO {
server_tx_stream: Rc::pin(Default::default()),
client_tx_stream: Rc::pin(Default::default()),
};
let client = C::new_from_config(client_config, io.client_view()).unwrap();
let server = S::new_from_config(server_config, io.server_view()).unwrap();
Self { client, server, io }
}

/// Take back ownership of individual connections in the TlsConnPair
@@ -325,93 +311,6 @@ where

Ok(())
}

/// Shrink buffers owned by the connections
pub fn shrink_connection_buffers(&mut self) {
maddeleine marked this conversation as resolved.
Show resolved Hide resolved
self.client.shrink_connection_buffers();
self.server.shrink_connection_buffers();
}

/// Clear and shrink buffers used for IO between the connections
pub fn shrink_connected_buffers(&mut self) {
self.client.shrink_connected_buffer();
self.server.shrink_connected_buffer();
}
}

/// Wrapper of two shared buffers to pass as stream
/// This wrapper `read()`s into one buffer and `write()`s to another
/// `Rc<RefCell<VecDeque<u8>>>` allows sharing of references to the buffers for two connections
#[derive(Clone, Eq)]
pub struct ConnectedBuffer {
recv: Rc<RefCell<VecDeque<u8>>>,
send: Rc<RefCell<VecDeque<u8>>>,
}

impl PartialEq for ConnectedBuffer {
/// ConnectedBuffers are equal if and only if they point to the same VecDeques
fn eq(&self, other: &ConnectedBuffer) -> bool {
Rc::ptr_eq(&self.recv, &other.recv) && Rc::ptr_eq(&self.send, &other.send)
}
}

impl ConnectedBuffer {
/// Make a new struct with new internal buffers
pub fn new() -> Self {
let recv = Rc::new(RefCell::new(VecDeque::new()));
let send = Rc::new(RefCell::new(VecDeque::new()));

// prevent (potentially slow) resizing of buffers for small data transfers,
// like with handshake
recv.borrow_mut().reserve(10000);
send.borrow_mut().reserve(10000);
goatgoose marked this conversation as resolved.
Show resolved Hide resolved

Self { recv, send }
}

/// Makes a new ConnectedBuffer that shares internal buffers but swapped,
/// ex. `write()` writes to the buffer that the inverse `read()`s from
pub fn clone_inverse(&self) -> Self {
Self {
recv: self.send.clone(),
send: self.recv.clone(),
}
}

/// Clears and shrinks buffers
pub fn shrink(&mut self) {
self.recv.borrow_mut().clear();
self.recv.borrow_mut().shrink_to_fit();
self.send.borrow_mut().clear();
self.send.borrow_mut().shrink_to_fit();
}
}

impl Read for ConnectedBuffer {
fn read(&mut self, dest: &mut [u8]) -> Result<usize, std::io::Error> {
let res = self.recv.borrow_mut().read(dest);
match res {
// rustls expects WouldBlock on read of length 0
Ok(0) => Err(std::io::Error::new(ErrorKind::WouldBlock, "blocking")),
Ok(len) => Ok(len),
Err(err) => Err(err),
}
}
}

impl Write for ConnectedBuffer {
fn write(&mut self, src: &[u8]) -> Result<usize, std::io::Error> {
self.send.borrow_mut().write(src)
}
fn flush(&mut self) -> Result<(), std::io::Error> {
Ok(()) // data already available to destination
}
}

impl Default for ConnectedBuffer {
fn default() -> Self {
Self::new()
}
}

#[cfg(test)]
4 changes: 2 additions & 2 deletions bindings/rust/bench/src/lib.rs
Original file line number Diff line number Diff line change
@@ -14,8 +14,8 @@ pub use crate::openssl::OpenSslConnection;
pub use crate::rustls::RustlsConnection;
pub use crate::{
harness::{
get_cert_path, CipherSuite, ConnectedBuffer, CryptoConfig, HandshakeType, KXGroup, Mode,
PemType, SigType, TlsConnPair, TlsConnection,
get_cert_path, CipherSuite, CryptoConfig, HandshakeType, KXGroup, Mode, PemType, SigType,
TlsConnPair, TlsConnection,
},
s2n_tls::S2NConnection,
};
Loading