Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configuration simplification #202

Merged
merged 4 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 86 additions & 116 deletions wtransport/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@

use crate::tls::build_native_cert_store;
use crate::tls::Identity;
use quinn::ClientConfig as QuicClientConfig;
use quinn::ServerConfig as QuicServerConfig;
use quinn::TransportConfig;
use std::fmt::Debug;
use std::fmt::Display;
Expand All @@ -49,8 +47,6 @@ use std::net::SocketAddr;
use std::net::SocketAddrV6;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;

/// Alias of [`crate::tls::rustls::ServerConfig`].
Expand All @@ -64,6 +60,16 @@ pub type TlsClientConfig = crate::tls::rustls::ClientConfig;
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub type QuicTransportConfig = crate::quinn::TransportConfig;

/// Alias of [`crate::quinn::ServerConfig`].
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub type QuicServerConfig = crate::quinn::ServerConfig;

/// Alias of [`crate::quinn::ClientConfig`].
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub type QuicClientConfig = crate::quinn::ClientConfig;

/// Configuration for IP address socket bind.
#[derive(Debug, Copy, Clone)]
pub enum IpBindConfig {
Expand Down Expand Up @@ -174,7 +180,10 @@ pub struct InvalidIdleTimeout;
/// - [`with_custom_transport`](ServerConfigBuilder::with_custom_transport): sets the QUIC
/// transport configuration manually (using default TLS).
/// - [`with_custom_tls_and_transport`](ServerConfigBuilder::with_custom_tls_and_transport): sets both
/// a custom TLS and QUIC configuration.
/// a custom TLS and QUIC transport configuration.
/// - [`build_with_quic_config`](ServerConfigBuilder::build_with_quic_config): directly builds
/// [`ServerConfig`] providing both TLS and QUIC transport configuration given by
/// [`quic_config`](QuicServerConfig).
///
/// #### Examples:
/// ```
Expand All @@ -201,7 +210,6 @@ pub struct InvalidIdleTimeout;
/// - [`max_idle_timeout`](ServerConfigBuilder::max_idle_timeout)
/// - [`keep_alive_interval`](ServerConfigBuilder::keep_alive_interval)
/// - [`allow_migration`](ServerConfigBuilder::allow_migration)
/// - [`enable_key_log`](ServerConfigBuilder::enable_key_log)
///
/// #### Examples:
/// ```
Expand All @@ -218,10 +226,11 @@ pub struct InvalidIdleTimeout;
/// .build();
/// # Ok(())
/// # }
#[derive(Clone, Debug)]
pub struct ServerConfig {
pub(crate) bind_address: SocketAddr,
pub(crate) dual_stack_config: Ipv6DualStackConfig,
pub(crate) quic_config: QuicServerConfig,
pub(crate) quic_config: quinn::ServerConfig,
}

impl ServerConfig {
Expand Down Expand Up @@ -472,6 +481,19 @@ impl ServerConfigBuilder<states::WantsIdentity> {
self.with(tls_config, quic_transport_config)
}

/// Directly builds [`ServerConfig`] skipping TLS and transport configuration.
///
/// Both TLS and transport configuration is given by [`quic_config`](QuicServerConfig).
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub fn build_with_quic_config(self, quic_config: QuicServerConfig) -> ServerConfig {
ServerConfig {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
quic_config,
}
}

fn with(
self,
tls_config: TlsServerConfig,
Expand All @@ -481,7 +503,6 @@ impl ServerConfigBuilder<states::WantsIdentity> {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
tls_config,
token_key: None,
transport_config,
migration: true,
})
Expand All @@ -500,11 +521,8 @@ impl ServerConfigBuilder<states::WantsTransportConfigServer> {
quinn::crypto::rustls::QuicServerConfig::try_from(self.0.tls_config)
.expect("CipherSuite::TLS13_AES_128_GCM_SHA256 missing"),
);
let mut quic_config = if let Some(token_key) = self.0.token_key {
QuicServerConfig::new(crypto, token_key)
} else {
QuicServerConfig::with_crypto(crypto)
};

let mut quic_config = quinn::ServerConfig::with_crypto(crypto);

quic_config.transport_config(Arc::new(self.0.transport_config));
quic_config.migration(self.0.migration);
Expand Down Expand Up @@ -557,26 +575,6 @@ impl ServerConfigBuilder<states::WantsTransportConfigServer> {
self.0.migration = value;
self
}

/// Use `Some` to use specific handshake token key instead of a random one.
///
/// Allows reloading the configuration without invalidating in-flight retry tokens.
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub fn token_key(mut self, value: Option<Arc<dyn quinn::crypto::HandshakeTokenKey>>) -> Self {
self.0.token_key = value;
self
}

/// Writes key material for debugging into file provided by `SSLKEYLOGFILE` environment variable.
///
/// Disabled by default.
#[cfg(feature = "dangerous-configuration")]
#[cfg_attr(docsrs, doc(cfg(feature = "dangerous-configuration")))]
pub fn enable_key_log(mut self) -> Self {
self.0.tls_config.key_log = Arc::new(rustls::KeyLogFile::new());
self
}
}

/// Client configuration.
Expand Down Expand Up @@ -635,7 +633,10 @@ impl ServerConfigBuilder<states::WantsTransportConfigServer> {
/// - [`with_custom_transport`](ClientConfigBuilder::with_custom_transport): sets the QUIC
/// transport configuration manually (using default TLS).
/// - [`with_custom_tls_and_transport`](ClientConfigBuilder::with_custom_tls_and_transport): sets both
/// a custom TLS and QUIC configuration.
/// a custom TLS and QUIC transport configuration.
/// - [`build_with_quic_config`](ClientConfigBuilder::build_with_quic_config): directly builds
/// [`ClientConfig`] providing both TLS and QUIC transport configuration given by
/// [`quic_config`](QuicClientConfig).
///
/// Only one of these options can be selected during the client configuration process.
///
Expand All @@ -661,7 +662,6 @@ impl ServerConfigBuilder<states::WantsTransportConfigServer> {
/// - [`max_idle_timeout`](ClientConfigBuilder::max_idle_timeout)
/// - [`keep_alive_interval`](ClientConfigBuilder::keep_alive_interval)
/// - [`dns_resolver`](ClientConfigBuilder::dns_resolver)
/// - [`enable_key_log`](ClientConfigBuilder::enable_key_log)
///
/// #### Examples:
/// ```
Expand All @@ -676,11 +676,12 @@ impl ServerConfigBuilder<states::WantsTransportConfigServer> {
/// .keep_alive_interval(Some(Duration::from_secs(3)))
/// .build();
/// ```
#[derive(Clone, Debug)]
pub struct ClientConfig {
pub(crate) bind_address: SocketAddr,
pub(crate) dual_stack_config: Ipv6DualStackConfig,
pub(crate) quic_config: QuicClientConfig,
pub(crate) dns_resolver: Box<dyn DnsResolver + Send + Sync + Unpin>,
pub(crate) quic_config: quinn::ClientConfig,
pub(crate) dns_resolver: Arc<dyn DnsResolver>,
}

impl ClientConfig {
Expand All @@ -691,6 +692,16 @@ impl ClientConfig {
ClientConfigBuilder::default()
}

/// Allows setting a custom [`DnsResolver`] for this configuration.
///
/// Default resolver is [`TokioDnsResolver`].
pub fn set_dns_resolver<R>(&mut self, dns_resolver: R)
where
R: DnsResolver + 'static,
{
self.dns_resolver = Arc::new(dns_resolver);
}

/// Returns a reference to the inner QUIC configuration.
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
Expand Down Expand Up @@ -959,6 +970,20 @@ impl ClientConfigBuilder<states::WantsRootStore> {
self.with(tls_config, quic_transport_config)
}

/// Directly builds [`ClientConfig`] skipping TLS and transport configuration.
///
/// Both TLS and transport configuration is given by [`quic_config`](QuicClientConfig).
#[cfg(feature = "quinn")]
#[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
pub fn build_with_quic_config(self, quic_config: QuicClientConfig) -> ClientConfig {
ClientConfig {
bind_address: self.0.bind_address,
dual_stack_config: self.0.dual_stack_config,
quic_config,
dns_resolver: Arc::<TokioDnsResolver>::default(),
}
}

fn with(
self,
tls_config: TlsClientConfig,
Expand All @@ -969,7 +994,7 @@ impl ClientConfigBuilder<states::WantsRootStore> {
dual_stack_config: self.0.dual_stack_config,
tls_config,
transport_config,
dns_resolver: Box::<TokioDnsResolver>::default(),
dns_resolver: Arc::<TokioDnsResolver>::default(),
})
}
}
Expand All @@ -984,7 +1009,8 @@ impl ClientConfigBuilder<states::WantsTransportConfigClient> {
pub fn build(self) -> ClientConfig {
let crypto = quinn::crypto::rustls::QuicClientConfig::try_from(self.0.tls_config)
.expect("CipherSuite::TLS13_AES_128_GCM_SHA256 missing");
let mut quic_config = QuicClientConfig::new(Arc::new(crypto));

let mut quic_config = quinn::ClientConfig::new(Arc::new(crypto));
quic_config.transport_config(Arc::new(self.0.transport_config));

ClientConfig {
Expand Down Expand Up @@ -1033,19 +1059,9 @@ impl ClientConfigBuilder<states::WantsTransportConfigClient> {
/// Default configuration uses [`TokioDnsResolver`].
pub fn dns_resolver<R>(mut self, dns_resolver: R) -> Self
where
R: DnsResolver + Send + Sync + Unpin + 'static,
R: DnsResolver + 'static,
{
self.0.dns_resolver = Box::new(dns_resolver);
self
}

/// Writes key material for debugging into file provided by `SSLKEYLOGFILE` environment variable.
///
/// Disabled by default.
#[cfg(feature = "dangerous-configuration")]
#[cfg_attr(docsrs, doc(cfg(feature = "dangerous-configuration")))]
pub fn enable_key_log(mut self) -> Self {
self.0.tls_config.key_log = Arc::new(rustls::KeyLogFile::new());
self.0.dns_resolver = Arc::new(dns_resolver);
self
}
}
Expand Down Expand Up @@ -1086,7 +1102,6 @@ pub mod states {
pub(super) bind_address: SocketAddr,
pub(super) dual_stack_config: Ipv6DualStackConfig,
pub(super) tls_config: TlsServerConfig,
pub(super) token_key: Option<Arc<dyn quinn::crypto::HandshakeTokenKey>>,
pub(super) transport_config: quinn::TransportConfig,
pub(super) migration: bool,
}
Expand All @@ -1097,85 +1112,40 @@ pub mod states {
pub(super) dual_stack_config: Ipv6DualStackConfig,
pub(super) tls_config: TlsClientConfig,
pub(super) transport_config: quinn::TransportConfig,
pub(super) dns_resolver: Box<dyn DnsResolver + Send + Sync + Unpin>,
}
}

/// A trait for asynchronously resolving domain names to IP addresses using DNS.
///
/// Utilities for working with `DnsResolver` values are provided by [`DnsResolverExt`].
pub trait DnsResolver {
/// Resolves a domain name to one IP address.
fn poll_resolve(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
host: &str,
) -> Poll<std::io::Result<Option<SocketAddr>>>;
}

/// Extension trait for [`DnsResolver`].
pub trait DnsResolverExt: DnsResolver {
/// Resolves a domain name to one IP address.
fn resolve(&mut self, host: &str) -> Resolve<Self>;
}

impl<T> DnsResolverExt for T
where
T: DnsResolver + ?Sized,
{
fn resolve(&mut self, host: &str) -> Resolve<Self> {
Resolve {
resolver: self,
host: host.to_string(),
}
pub(super) dns_resolver: Arc<dyn DnsResolver>,
}
}

/// Future resolving domain name.
///
/// See [`DnsResolverExt::resolve`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Resolve<'a, R>
where
R: ?Sized,
{
resolver: &'a mut R,
host: String,
}
/// See [`DnsResolver::resolve`].
pub trait DnsLookupFuture: Future<Output = std::io::Result<Option<SocketAddr>>> {}

impl<'a, R> Future for Resolve<'a, R>
where
R: DnsResolver + Unpin + ?Sized,
{
type Output = std::io::Result<Option<SocketAddr>>;
impl<F> DnsLookupFuture for F where F: Future<Output = std::io::Result<Option<SocketAddr>>> {}

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
DnsResolver::poll_resolve(Pin::new(this.resolver), cx, &this.host)
}
/// A trait for asynchronously resolving domain names to IP addresses using DNS.
pub trait DnsResolver: Debug {
/// Resolves a domain name to one IP address.
fn resolve(&self, host: &str) -> Pin<Box<dyn DnsLookupFuture>>;
}

/// A DNS resolver implementation using the *Tokio* asynchronous runtime.
///
/// Internally, it uses [`tokio::net::lookup_host`].
#[derive(Default)]
pub struct TokioDnsResolver {
#[allow(clippy::type_complexity)]
fut: Option<Pin<Box<dyn Future<Output = std::io::Result<Option<SocketAddr>>> + Send + Sync>>>,
}
pub struct TokioDnsResolver;

impl DnsResolver for TokioDnsResolver {
fn poll_resolve(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
host: &str,
) -> Poll<std::io::Result<Option<SocketAddr>>> {
let fut = self.fut.get_or_insert_with(|| {
let host = host.to_string();
Box::pin(async move { Ok(tokio::net::lookup_host(host).await?.next()) })
});

Future::poll(fut.as_mut(), cx)
fn resolve(&self, host: &str) -> Pin<Box<dyn DnsLookupFuture>> {
let host = host.to_string();

Box::pin(async move { Ok(tokio::net::lookup_host(host).await?.next()) })
}
}

impl Debug for TokioDnsResolver {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TokioDnsResolver").finish()
}
}

Expand Down
Loading
Loading