Skip to content

Commit

Permalink
Use rustls instead of native-tls with openssl in SDK (#1340)
Browse files Browse the repository at this point in the history
  • Loading branch information
spetz authored Nov 15, 2024
1 parent eebc6a4 commit f75650e
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 95 deletions.
163 changes: 93 additions & 70 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions examples/src/shared/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ impl Args {
tcp_heartbeat_interval: self.tcp_heartbeat_interval.clone(),
tcp_tls_enabled: self.tcp_tls_enabled,
tcp_tls_domain: self.tcp_tls_domain.clone(),
tcp_tls_ca_file: None,
quic_client_address: self.quic_client_address.clone(),
quic_server_address: self.quic_server_address.clone(),
quic_server_name: self.quic_server_name.clone(),
Expand Down
14 changes: 8 additions & 6 deletions sdk/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "iggy"
version = "0.6.34"
version = "0.6.40"
description = "Iggy is the persistent message streaming platform written in Rust, supporting QUIC, TCP and HTTP transport protocols, capable of processing millions of messages per second."
edition = "2021"
license = "MIT"
Expand Down Expand Up @@ -40,22 +40,24 @@ humantime = "2.1.0"
keyring = { version = "3.2.0", optional = true, features = ["sync-secret-service", "vendored"] }
lazy_static = "1.4.0"
passterm = { version = "2.0.1", optional = true }
pem = { version = "3.0.4" }
quinn = { version = "0.11.5" }
regex = "1.10.4"
reqwest = { version = "0.12.7", default-features = false, features = ["json", "rustls-tls"] }
reqwest-middleware = { version = "0.3.2", features = ["json"] }
reqwest-retry = "0.6.1"
reqwest-middleware = { version = "0.3.3", features = ["json"] }
reqwest-retry = "0.7.0"
rustls = { version = "0.23.10", features = ["ring"] }
serde = { version = "1.0.210", features = ["derive", "rc"] }
serde_json = "1.0.127"
serde_with = { version = "3.8.1", features = ["base64"] }
strum = { version = "0.26.2", features = ["derive"] }
thiserror = "1.0.61"
tokio = { version = "1.40.0", features = ["full"] }
tokio-native-tls = "0.3.1"
thiserror = "2.0.3"
tokio = { version = "1.40.1", features = ["full"] }
tokio-rustls = { version = "0.26.0" }
toml = "0.8.14"
tracing = { version = "0.1.40" }
uuid = { version = "1.1.0", features = ["v7", "fast-rng", "zerocopy"] }
webpki-roots = { version = "0.26.6" }

[build-dependencies]
convert_case = "0.6.0"
Expand Down
4 changes: 4 additions & 0 deletions sdk/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ pub struct Args {
/// The optional TLS domain for the TCP transport
pub tcp_tls_domain: String,

/// The optional CA file for the TCP transport
pub tcp_tls_ca_file: Option<String>,

/// The optional client address for the QUIC transport
pub quic_client_address: String,

Expand Down Expand Up @@ -316,6 +319,7 @@ impl Default for Args {
tcp_heartbeat_interval: "5s".to_string(),
tcp_tls_enabled: false,
tcp_tls_domain: "localhost".to_string(),
tcp_tls_ca_file: None,
quic_client_address: "127.0.0.1:0".to_string(),
quic_server_address: "127.0.0.1:8080".to_string(),
quic_server_name: "localhost".to_string(),
Expand Down
1 change: 1 addition & 0 deletions sdk/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ impl From<ConnectionString> for TcpClientConfig {
auto_login: connection_string.auto_login,
tls_enabled: connection_string.options.tls_enabled,
tls_domain: connection_string.options.tls_domain,
tls_ca_file: None,
reconnection: connection_string.options.reconnection,
heartbeat_interval: connection_string.options.heartbeat_interval,
}
Expand Down
1 change: 1 addition & 0 deletions sdk/src/client_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ impl ClientProviderConfig {
server_address: args.tcp_server_address,
tls_enabled: args.tcp_tls_enabled,
tls_domain: args.tcp_tls_domain,
tls_ca_file: args.tcp_tls_ca_file,
heartbeat_interval: IggyDuration::from_str(&args.tcp_heartbeat_interval)
.unwrap(),
reconnection: TcpClientReconnectionConfig {
Expand Down
5 changes: 5 additions & 0 deletions sdk/src/clients/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ impl TcpClientBuilder {
self
}

pub fn with_tls_ca_file(mut self, tls_ca_file: String) -> Self {
self.config = self.config.with_tls_ca_file(tls_ca_file);
self
}

/// Builds the parent `IggyClient` with TCP configuration.
pub fn build(self) -> Result<IggyClient, IggyError> {
let client = TcpClient::create(Arc::new(self.config.build()))?;
Expand Down
8 changes: 8 additions & 0 deletions sdk/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ pub enum IggyError {
RequestError(#[from] reqwest::Error) = 62,
#[error("Client shutdown")]
ClientShutdown = 63,
#[error("Invalid TLS domain")]
InvalidTlsDomain = 64,
#[error("Invalid TLS certificate path")]
InvalidTlsCertificatePath = 65,
#[error("Invalid TLS certificate")]
InvalidTlsCertificate = 66,
#[error("Failed to add certificate")]
FailedToAddCertificate = 67,
#[error("Invalid encryption key")]
InvalidEncryptionKey = 70,
#[error("Cannot encrypt data")]
Expand Down
58 changes: 39 additions & 19 deletions sdk/src/tcp/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::utils::timestamp::IggyTimestamp;
use async_broadcast::{broadcast, Receiver, Sender};
use async_trait::async_trait;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use rustls::pki_types::{pem::PemObject, CertificateDer, ServerName};
use std::fmt::Debug;
use std::net::SocketAddr;
use std::str::FromStr;
Expand All @@ -21,8 +22,7 @@ use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::time::sleep;
use tokio_native_tls::native_tls::TlsConnector;
use tokio_native_tls::TlsStream;
use tokio_rustls::{TlsConnector, TlsStream};
use tracing::{error, info, trace, warn};

const REQUEST_INITIAL_BYTES_LENGTH: usize = 4;
Expand Down Expand Up @@ -462,24 +462,44 @@ impl TcpClient {
break;
}

let connector = tokio_native_tls::TlsConnector::from(
TlsConnector::builder().build().map_err(|error| {
error!("Failed to create a TLS connector: {error}");
IggyError::CannotEstablishConnection
})?,
);
let stream = tokio_native_tls::TlsConnector::connect(
&connector,
&self.config.tls_domain,
stream,
)
.await
.map_err(|error| {
error!("Failed to establish a TLS connection: {error}");
IggyError::CannotEstablishConnection
})?;
let mut root_cert_store = rustls::RootCertStore::empty();
if let Some(certificate_path) = &self.config.tls_ca_file {
for cert in CertificateDer::pem_file_iter(certificate_path).map_err(|error| {
error!("Failed to read the CA file: {certificate_path}. {error}",);
IggyError::InvalidTlsCertificatePath
})? {
let certificate = cert.map_err(|error| {
error!(
"Failed to read a certificate from the CA file: {certificate_path}. {error}",
);
IggyError::InvalidTlsCertificate
})?;
root_cert_store.add(certificate).map_err(|error| {
error!(
"Failed to add a certificate to the root certificate store. {error}",
);
IggyError::InvalidTlsCertificate
})?;
}
} else {
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
}

connection_stream = Box::new(TcpTlsConnectionStream::new(client_address, stream));
let config = rustls::ClientConfig::builder()
.with_root_certificates(root_cert_store)
.with_no_client_auth();
let connector = TlsConnector::from(Arc::new(config));
let stream = TcpStream::connect(client_address).await?;
let tls_domain = self.config.tls_domain.to_owned();
let domain = ServerName::try_from(tls_domain).map_err(|error| {
error!("Failed to create a server name from the domain. {error}",);
IggyError::InvalidTlsDomain
})?;
let stream = connector.connect(domain, stream).await?;
connection_stream = Box::new(TcpTlsConnectionStream::new(
client_address,
TlsStream::Client(stream),
));
break;
}

Expand Down
10 changes: 10 additions & 0 deletions sdk/src/tcp/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ pub struct TcpClientConfig {
pub tls_enabled: bool,
/// The domain to use for TLS when connecting to the server.
pub tls_domain: String,
/// The path to the CA file for TLS.
pub tls_ca_file: Option<String>,
/// Whether to automatically login user after establishing connection.
pub auto_login: AutoLogin,
/// Whether to automatically reconnect when disconnected.
Expand All @@ -33,6 +35,7 @@ impl Default for TcpClientConfig {
server_address: "127.0.0.1:8090".to_string(),
tls_enabled: false,
tls_domain: "localhost".to_string(),
tls_ca_file: None,
heartbeat_interval: IggyDuration::from_str("5s").unwrap(),
auto_login: AutoLogin::Disabled,
reconnection: TcpClientReconnectionConfig::default(),
Expand All @@ -58,6 +61,7 @@ impl Default for TcpClientReconnectionConfig {
/// - `reconnection`: Default is enabled unlimited retries and 1 second interval.
/// - `tls_enabled`: Default is false.
/// - `tls_domain`: Default is "localhost".
/// - `tls_ca_file`: Default is None.
#[derive(Debug, Default)]
pub struct TcpClientConfigBuilder {
config: TcpClientConfig,
Expand Down Expand Up @@ -109,6 +113,12 @@ impl TcpClientConfigBuilder {
self
}

/// Sets the path to the CA file for TLS.
pub fn with_tls_ca_file(mut self, tls_ca_file: String) -> Self {
self.config.tls_ca_file = Some(tls_ca_file);
self
}

/// Builds the TCP client configuration.
pub fn build(self) -> TcpClientConfig {
self.config
Expand Down

0 comments on commit f75650e

Please sign in to comment.