Skip to content

Commit

Permalink
feat: add TLS support
Browse files Browse the repository at this point in the history
Closes #14.
  • Loading branch information
kezhuw committed Mar 19, 2024
1 parent 45db940 commit 0b225d9
Show file tree
Hide file tree
Showing 8 changed files with 453 additions and 76 deletions.
6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ hashbrown = "0.12.0"
hashlink = "0.8.0"
either = "1.9.0"
uuid = { version = "1.4.1", features = ["v4"] }
rustls = "0.23.2"
rustls-pemfile = "2"
webpki-roots = "0.26.1"
derive-where = "1.2.7"

[dev-dependencies]
test-log = "0.2.12"
Expand All @@ -40,3 +44,5 @@ testcontainers = { git = "https://github.com/kezhuw/testcontainers-rs.git", bran
assertor = "0.0.2"
assert_matches = "1.5.0"
tempfile = "3.6.0"
maplit = "1.0.2"
rcgen = { version = "0.12.1", features = ["default", "x509-parser"] }
77 changes: 72 additions & 5 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ use std::borrow::Cow;
use std::fmt::Write as _;
use std::future::Future;
use std::mem::ManuallyDrop;
use std::sync::Arc;
use std::time::Duration;

use const_format::formatcp;
use either::{Either, Left, Right};
use ignore_result::Ignore;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::{ClientConfig, RootCertStore};
use thiserror::Error;
use tokio::sync::{mpsc, watch};

Expand Down Expand Up @@ -1528,6 +1531,9 @@ pub(crate) struct Version(u32, u32, u32);
/// Builder for [Client] with more options than [Client::connect].
#[derive(Clone, Debug)]
pub struct ClientBuilder {
tls: bool,
trusted_certs: RootCertStore,
client_certs: Option<(Vec<CertificateDer<'static>>, Arc<PrivateKeyDer<'static>>)>,
authes: Vec<AuthPacket>,
version: Version,
session: Option<(SessionId, Vec<u8>)>,
Expand All @@ -1540,6 +1546,9 @@ pub struct ClientBuilder {
impl ClientBuilder {
fn new() -> Self {
Self {
tls: false,
trusted_certs: RootCertStore::empty(),
client_certs: None,
authes: Default::default(),
version: Version(u32::MAX, u32::MAX, u32::MAX),
session: None,
Expand Down Expand Up @@ -1584,6 +1593,43 @@ impl ClientBuilder {
self
}

/// Assumes tls for server in connection string if no protocol specified individually.
/// See [Self::connect] for syntax to specify protocol individually.
pub fn assume_tls(&mut self) -> &mut Self {
self.tls = true;
self
}

/// Trusts certificates signed by given ca certificates.
pub fn trust_ca_pem_certs(&mut self, certs: &str) -> Result<&mut Self> {
for r in rustls_pemfile::certs(&mut certs.as_bytes()) {
let cert = match r {
Ok(cert) => cert,
Err(err) => return Err(Error::other(format!("fail to read cert {}", err), err)),
};
if let Err(err) = self.trusted_certs.add(cert) {
return Err(Error::other(format!("fail to add cert {}", err), err));
}
}
Ok(self)
}

/// Identifies client itself to server with given cert chain and private key.
pub fn use_client_pem_cert(&mut self, cert: &str, key: &str) -> Result<&mut Self> {
let r: std::result::Result<Vec<_>, _> = rustls_pemfile::certs(&mut cert.as_bytes()).collect();
let certs = match r {
Err(err) => return Err(Error::other(format!("fail to read cert {}", err), err)),
Ok(certs) => certs,
};
let key = match rustls_pemfile::private_key(&mut key.as_bytes()) {
Err(err) => return Err(Error::other(format!("fail to read client private key {err}"), err)),
Ok(None) => return Err(Error::BadArguments(&"no client private key")),
Ok(Some(key)) => key,
};
self.client_certs = Some((certs, Arc::new(key)));
Ok(self)
}

/// Specifies client assumed server version of ZooKeeper cluster.
///
/// Client will issue server compatible protocol to avoid [Error::Unimplemented] for some
Expand All @@ -1606,13 +1652,17 @@ impl ClientBuilder {

/// Connects to ZooKeeper cluster.
///
/// Parameter `cluster` specifies connection string to ZooKeeper cluster. It has same syntax as
/// Java client except that you can specifies protocol for server individually. For example,
/// `tcp://server1,tcp+tls://server2:port,server3`. This claims that `server1` uses plaintext
/// protocol, `server2` uses tls encrypted protocol while `server3` uses tls if
/// [Self::assume_tls] is specified or plaintext otherwise.
///
/// # Notable errors
/// * [Error::NoHosts] if no host is available
/// * [Error::SessionExpired] if specified session expired
pub async fn connect(&mut self, cluster: &str) -> Result<Client> {
let (hosts, chroot) = util::parse_connect_string(cluster)?;
let mut buf = Vec::with_capacity(4096);
let mut connecting_depot = Depot::for_connecting();
let (hosts, chroot) = util::parse_connect_string(cluster, self.tls)?;
if let Some((id, password)) = &self.session {
if id.0 == 0 {
return Err(Error::BadArguments(&"session id must not be 0"));
Expand All @@ -1628,22 +1678,39 @@ impl ClientBuilder {
} else if self.connection_timeout < Duration::ZERO {
return Err(Error::BadArguments(&"connection timeout must not be negative"));
}
self.trusted_certs.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let tls_config = if let Some((certs, private_key)) = self.client_certs.take() {
match ClientConfig::builder()
.with_root_certificates(std::mem::replace(&mut self.trusted_certs, RootCertStore::empty()))
.with_client_auth_cert(certs, Arc::try_unwrap(private_key).unwrap_or_else(|k| k.clone_key()))
{
Ok(config) => config,
Err(err) => return Err(Error::other(format!("invalid client private key {err}"), err)),
}
} else {
ClientConfig::builder()
.with_root_certificates(std::mem::replace(&mut self.trusted_certs, RootCertStore::empty()))
.with_no_client_auth()
};
let (mut session, state_receiver) = Session::new(
self.session.take(),
&self.authes,
self.readonly,
self.detached,
tls_config,
self.session_timeout,
self.connection_timeout,
);
let mut hosts_iter = hosts.iter().copied();
let sock = session.start(&mut hosts_iter, &mut buf, &mut connecting_depot).await?;
let mut buf = Vec::with_capacity(4096);
let mut connecting_depot = Depot::for_connecting();
let conn = session.start(&mut hosts_iter, &mut buf, &mut connecting_depot).await?;
let (sender, receiver) = mpsc::unbounded_channel();
let servers = hosts.into_iter().map(|addr| addr.to_value()).collect();
let session_info = (session.session_id, session.session_password.clone());
let session_timeout = session.session_timeout;
tokio::spawn(async move {
session.serve(servers, sock, buf, connecting_depot, receiver).await;
session.serve(servers, conn, buf, connecting_depot, receiver).await;
});
let client =
Client::new(chroot.to_owned(), self.version, session_info, session_timeout, sender, state_receiver);
Expand Down
19 changes: 19 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::sync::Arc;

use derive_where::derive_where;
use static_assertions::assert_impl_all;
use thiserror::Error;

Expand Down Expand Up @@ -82,6 +85,18 @@ pub enum Error {

#[error("runtime condition mismatch")]
RuntimeInconsistent,

#[error(transparent)]
Other(OtherError),
}

#[derive(Error, Clone, Debug)]
#[derive_where(Eq, PartialEq)]
#[error("{message}")]
pub struct OtherError {
message: Arc<String>,
#[derive_where(skip(EqHashOrd))]
source: Option<Arc<dyn std::error::Error + Send + Sync + 'static>>,
}

impl Error {
Expand Down Expand Up @@ -111,6 +126,10 @@ impl Error {
_ => false,
}
}

pub(crate) fn other(message: impl Into<String>, source: impl std::error::Error + Send + Sync + 'static) -> Self {
Self::Other(OtherError { message: Arc::new(message.into()), source: Some(Arc::new(source)) })
}
}

assert_impl_all!(Error: Send, Sync);
Expand Down
128 changes: 128 additions & 0 deletions src/session/connection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use std::io::{self, Read, Write};
use std::sync::Arc;

use rustls::pki_types::ServerName;
use rustls::{ClientConfig, ClientConnection};
use tokio::net::TcpStream;

use crate::error::Error;

pub struct Connection {
tls: Option<ClientConnection>,
stream: TcpStream,
}

struct WrappingStream<'a> {
stream: &'a TcpStream,
}

impl io::Read for WrappingStream<'_> {
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
self.stream.try_read_buf(&mut buf)
}
}

impl io::Write for WrappingStream<'_> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.stream.try_write(buf)
}

fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
self.stream.try_write_vectored(bufs)
}

fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}

impl Connection {
pub fn new_raw(stream: TcpStream) -> Self {
Self { tls: None, stream }
}

pub fn new_tls(host: &str, config: Arc<ClientConfig>, stream: TcpStream) -> Result<Self, Error> {
let name = match ServerName::try_from(host) {
Err(_) => return Err(Error::BadArguments(&"invalid server dns name")),
Ok(name) => name.to_owned(),
};
let client = match ClientConnection::new(config, name) {
Err(err) => return Err(Error::other(format!("fail to create tls client for host({host}): {err}"), err)),
Ok(client) => client,
};
Ok(Self { tls: Some(client), stream })
}

pub fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
let Some(client) = self.tls.as_mut() else {
return self.stream.try_write_vectored(bufs);
};
let n = client.writer().write_vectored(bufs)?;
let mut stream = WrappingStream { stream: &self.stream };
client.write_tls(&mut stream)?;
Ok(n)
}

pub fn read_buf(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
let Some(client) = self.tls.as_mut() else {
return self.stream.try_read_buf(buf);
};
let mut stream = WrappingStream { stream: &self.stream };
let mut read_bytes = 0;
loop {
match client.read_tls(&mut stream) {
// We may have plaintext to return though tcp stream has been closed.
// If not, read_bytes should be zero.
Ok(0) => break,
Ok(_) => {},
Err(err) => match err.kind() {
// backpressure: tls buffer is full, let's process_new_packets.
io::ErrorKind::Other => {},
io::ErrorKind::WouldBlock if read_bytes == 0 => {
return Err(err);
},
_ => break,
},
}
let state = client.process_new_packets().map_err(io::Error::other)?;
let n = state.plaintext_bytes_to_read();
buf.reserve(n);
let slice = unsafe { &mut std::slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.len() + n)[buf.len()..] };
client.reader().read_exact(slice).unwrap();
unsafe { buf.set_len(buf.len() + n) };
read_bytes += n;
}
Ok(read_bytes)
}

pub async fn readable(&self) -> io::Result<()> {
let Some(client) = self.tls.as_ref() else {
return self.stream.readable().await;
};
if client.wants_read() {
self.stream.readable().await
} else {
// plaintext data are available for read
std::future::ready(Ok(())).await
}
}

pub async fn writable(&self) -> io::Result<()> {
self.stream.writable().await
}

pub fn wants_write(&self) -> bool {
self.tls.as_ref().map(|tls| tls.wants_write()).unwrap_or(false)
}

pub fn flush(&mut self) -> io::Result<()> {
let Some(client) = self.tls.as_mut() else {
return Ok(());
};
let mut stream = WrappingStream { stream: &self.stream };
while client.wants_write() {
client.write_tls(&mut stream)?;
}
Ok(())
}
}
16 changes: 13 additions & 3 deletions src/session/depot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::io::{self, IoSlice};

use hashbrown::HashMap;
use strum::IntoEnumIterator;
use tokio::net::TcpStream;

use super::connection::Connection;
use super::request::{MarshalledRequest, OpStat, Operation, SessionOperation, StateResponser};
use super::types::WatchMode;
use super::xid::Xid;
Expand Down Expand Up @@ -205,8 +205,18 @@ impl Depot {
.any(|mode| self.watching_paths.contains_key(&(path, mode)))
}

pub fn write_operations(&mut self, sock: &TcpStream, session_id: SessionId) -> Result<(), Error> {
let result = sock.try_write_vectored(self.writing_slices.as_slice());
pub fn write_operations(&mut self, conn: &mut Connection, session_id: SessionId) -> Result<(), Error> {
if !self.has_pending_writes() {
if let Err(err) = conn.flush() {
if err.kind() == io::ErrorKind::WouldBlock {
return Ok(());
}
log::debug!("ZooKeeper session {} write failed {}", session_id, err);
return Err(Error::ConnectionLoss);
}
return Ok(());
}
let result = conn.write_vectored(self.writing_slices.as_slice());
let mut written_bytes = match result {
Err(err) => {
if err.kind() == io::ErrorKind::WouldBlock {
Expand Down
Loading

0 comments on commit 0b225d9

Please sign in to comment.