diff --git a/src/client/mod.rs b/src/client/mod.rs index 0ff4d1d..b55a8dd 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -4,7 +4,6 @@ use std::borrow::Cow; use std::fmt::Write as _; use std::future::Future; use std::mem::ManuallyDrop; -use std::sync::Arc; use std::time::Duration; use const_format::formatcp; @@ -227,14 +226,20 @@ impl Client { /// Connects to ZooKeeper cluster. pub async fn connect(cluster: &str) -> Result { - Self::builder().connect(cluster).await + Self::connector().connect(cluster).await } /// Creates a builder with configurable options in connecting to ZooKeeper cluster. + #[deprecated(since = "0.7.0", note = "use Client::connector instead")] pub fn builder() -> ClientBuilder { ClientBuilder::new() } + /// Creates a builder with configurable options in connecting to ZooKeeper cluster. + pub fn connector() -> Connector { + Connector::new() + } + pub(crate) fn new( chroot: OwnedChroot, version: Version, @@ -1528,32 +1533,96 @@ impl Drop for OwnedLockClient { #[derive(Copy, Clone, Debug, PartialEq, PartialOrd)] pub(crate) struct Version(u32, u32, u32); -/// Builder for [Client] with more options than [Client::connect]. +/// Options for tls connection. +#[derive(Debug)] +pub struct TlsOptions { + identity: Option<(Vec>, PrivateKeyDer<'static>)>, + ca_certs: RootCertStore, +} + +impl Clone for TlsOptions { + fn clone(&self) -> Self { + Self { + identity: self.identity.as_ref().map(|id| (id.0.clone(), id.1.clone_key())), + ca_certs: self.ca_certs.clone(), + } + } +} + +impl Default for TlsOptions { + /// Tls options with well-known ca roots. + fn default() -> Self { + let mut options = Self::no_ca(); + options.ca_certs.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + options + } +} + +impl TlsOptions { + /// Tls options with no ca certificates. Use [TlsOptions::default] if well-known ca roots is + /// desirable. + pub fn no_ca() -> Self { + Self { ca_certs: RootCertStore::empty(), identity: None } + } + + /// Adds new ca certificates. + pub fn with_pem_ca_certs(mut self, certs: &str) -> Result { + for r in rustls_pemfile::certs(&mut certs.as_bytes()) { + let cert = match r { + Ok(cert) => cert, + Err(err) => return Err(Error::other(format!("fail to read cert {}", err), err)), + }; + if let Err(err) = self.ca_certs.add(cert) { + return Err(Error::other(format!("fail to add cert {}", err), err)); + } + } + Ok(self) + } + + /// Specifies client identity for server to authenticate. + pub fn with_pem_identity(mut self, cert: &str, key: &str) -> Result { + let r: std::result::Result, _> = rustls_pemfile::certs(&mut cert.as_bytes()).collect(); + let certs = match r { + Err(err) => return Err(Error::other(format!("fail to read cert {}", err), err)), + Ok(certs) => certs, + }; + let key = match rustls_pemfile::private_key(&mut key.as_bytes()) { + Err(err) => return Err(Error::other(format!("fail to read client private key {err}"), err)), + Ok(None) => return Err(Error::BadArguments(&"no client private key")), + Ok(Some(key)) => key, + }; + self.identity = Some((certs, key)); + Ok(self) + } + + fn take_roots(&mut self) -> RootCertStore { + std::mem::replace(&mut self.ca_certs, RootCertStore::empty()) + } +} + +/// A builder for [Client]. #[derive(Clone, Debug)] -pub struct ClientBuilder { - tls: bool, - trusted_certs: RootCertStore, - client_certs: Option<(Vec>, Arc>)>, +pub struct Connector { + tls: Option, authes: Vec, - version: Version, session: Option<(SessionId, Vec)>, readonly: bool, detached: bool, + server_version: Version, session_timeout: Duration, connection_timeout: Duration, } -impl ClientBuilder { +/// Builder for [Client] with more options than [Client::connect]. +impl Connector { fn new() -> Self { Self { - tls: false, - trusted_certs: RootCertStore::empty(), - client_certs: None, + tls: None, authes: Default::default(), - version: Version(u32::MAX, u32::MAX, u32::MAX), session: None, readonly: false, detached: false, + server_version: Version(u32::MAX, u32::MAX, u32::MAX), session_timeout: Duration::ZERO, connection_timeout: Duration::ZERO, } @@ -1562,7 +1631,7 @@ impl ClientBuilder { /// Specifies target session timeout to negotiate with ZooKeeper server. /// /// Defaults to 6s. - pub fn with_session_timeout(&mut self, timeout: Duration) -> &mut Self { + pub fn session_timeout(&mut self, timeout: Duration) -> &mut Self { self.session_timeout = timeout; self } @@ -1570,67 +1639,30 @@ impl ClientBuilder { /// Specifies idle timeout to conclude a connection as loss. /// /// Defaults to `2/5` of session timeout. - pub fn with_connection_timeout(&mut self, timeout: Duration) -> &mut Self { + pub fn connection_timeout(&mut self, timeout: Duration) -> &mut Self { self.connection_timeout = timeout; self } /// Specifies whether readonly server is allowed. - pub fn with_readonly(&mut self, readonly: bool) -> &mut ClientBuilder { + pub fn readonly(&mut self, readonly: bool) -> &mut Self { self.readonly = readonly; self } /// Specifies auth info for given authentication scheme. - pub fn with_auth(&mut self, scheme: String, auth: Vec) -> &mut ClientBuilder { + pub fn auth(&mut self, scheme: String, auth: Vec) -> &mut Self { self.authes.push(AuthPacket { scheme, auth }); self } /// Specifies session to reestablish. - pub fn with_session(&mut self, id: SessionId, password: Vec) -> &mut Self { + pub fn session(&mut self, id: SessionId, password: Vec) -> &mut Self { self.session = Some((id, password)); self } - /// Assumes tls for server in connection string if no protocol specified individually. - /// See [Self::connect] for syntax to specify protocol individually. - pub fn assume_tls(&mut self) -> &mut Self { - self.tls = true; - self - } - - /// Trusts certificates signed by given ca certificates. - pub fn trust_ca_pem_certs(&mut self, certs: &str) -> Result<&mut Self> { - for r in rustls_pemfile::certs(&mut certs.as_bytes()) { - let cert = match r { - Ok(cert) => cert, - Err(err) => return Err(Error::other(format!("fail to read cert {}", err), err)), - }; - if let Err(err) = self.trusted_certs.add(cert) { - return Err(Error::other(format!("fail to add cert {}", err), err)); - } - } - Ok(self) - } - - /// Identifies client itself to server with given cert chain and private key. - pub fn use_client_pem_cert(&mut self, cert: &str, key: &str) -> Result<&mut Self> { - let r: std::result::Result, _> = rustls_pemfile::certs(&mut cert.as_bytes()).collect(); - let certs = match r { - Err(err) => return Err(Error::other(format!("fail to read cert {}", err), err)), - Ok(certs) => certs, - }; - let key = match rustls_pemfile::private_key(&mut key.as_bytes()) { - Err(err) => return Err(Error::other(format!("fail to read client private key {err}"), err)), - Ok(None) => return Err(Error::BadArguments(&"no client private key")), - Ok(Some(key)) => key, - }; - self.client_certs = Some((certs, Arc::new(key))); - Ok(self) - } - - /// Specifies client assumed server version of ZooKeeper cluster. + /// Specifies target server version of ZooKeeper cluster. /// /// Client will issue server compatible protocol to avoid [Error::Unimplemented] for some /// operations. See [Client::create] for an example. @@ -1639,30 +1671,25 @@ impl ClientBuilder { /// /// [ZOOKEEPER-1381]: https://issues.apache.org/jira/browse/ZOOKEEPER-1381 /// [ZOOKEEPER-3762]: https://issues.apache.org/jira/browse/ZOOKEEPER-3762 - pub fn assume_server_version(&mut self, major: u32, minor: u32, patch: u32) -> &mut Self { - self.version = Version(major, minor, patch); + pub fn server_version(&mut self, major: u32, minor: u32, patch: u32) -> &mut Self { + self.server_version = Version(major, minor, patch); self } - /// Detaches creating session so it will not be closed after all client instances dropped. - pub fn detach(&mut self) -> &mut Self { + /// Detaches created session so it will not be closed after all client instances dropped. + pub fn detached(&mut self) -> &mut Self { self.detached = true; self } - /// Connects to ZooKeeper cluster. - /// - /// Parameter `cluster` specifies connection string to ZooKeeper cluster. It has same syntax as - /// Java client except that you can specifies protocol for server individually. For example, - /// `tcp://server1,tcp+tls://server2:port,server3`. This claims that `server1` uses plaintext - /// protocol, `server2` uses tls encrypted protocol while `server3` uses tls if - /// [Self::assume_tls] is specified or plaintext otherwise. - /// - /// # Notable errors - /// * [Error::NoHosts] if no host is available - /// * [Error::SessionExpired] if specified session expired - pub async fn connect(&mut self, cluster: &str) -> Result { - let (hosts, chroot) = util::parse_connect_string(cluster, self.tls)?; + /// Specifies tls options for connections to ZooKeeper. + pub fn tls(&mut self, options: TlsOptions) -> &mut Self { + self.tls = Some(options); + self + } + + async fn connect_internally(&mut self, secure: bool, cluster: &str) -> Result { + let (hosts, chroot) = util::parse_connect_string(cluster, secure)?; if let Some((id, password)) = &self.session { if id.0 == 0 { return Err(Error::BadArguments(&"session id must not be 0")); @@ -1678,19 +1705,15 @@ impl ClientBuilder { } else if self.connection_timeout < Duration::ZERO { return Err(Error::BadArguments(&"connection timeout must not be negative")); } - self.trusted_certs.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let tls_config = if let Some((certs, private_key)) = self.client_certs.take() { - match ClientConfig::builder() - .with_root_certificates(std::mem::replace(&mut self.trusted_certs, RootCertStore::empty())) - .with_client_auth_cert(certs, Arc::try_unwrap(private_key).unwrap_or_else(|k| k.clone_key())) - { + let mut tls_options = self.tls.take().unwrap_or_default(); + let tls_builder = ClientConfig::builder().with_root_certificates(tls_options.take_roots()); + let tls_config = if let Some((client_cert, client_key)) = tls_options.identity.take() { + match tls_builder.with_client_auth_cert(client_cert, client_key) { Ok(config) => config, Err(err) => return Err(Error::other(format!("invalid client private key {err}"), err)), } } else { - ClientConfig::builder() - .with_root_certificates(std::mem::replace(&mut self.trusted_certs, RootCertStore::empty())) - .with_no_client_auth() + tls_builder.with_no_client_auth() }; let (mut session, state_receiver) = Session::new( self.session.take(), @@ -1713,9 +1736,110 @@ impl ClientBuilder { session.serve(servers, conn, buf, connecting_depot, receiver).await; }); let client = - Client::new(chroot.to_owned(), self.version, session_info, session_timeout, sender, state_receiver); + Client::new(chroot.to_owned(), self.server_version, session_info, session_timeout, sender, state_receiver); Ok(client) } + + /// Connects to ZooKeeper cluster. + /// + /// Same to [Self::connect] except that `server1` will use tls encrypted protocol given + /// the connection string `server1,tcp://server2,tcp+tls://server3`. + pub async fn secure_connect(&mut self, cluster: &str) -> Result { + self.connect_internally(true, cluster).await + } + + /// Connects to ZooKeeper cluster. + /// + /// Parameter `cluster` specifies connection string to ZooKeeper cluster. It has same syntax as + /// Java client except that you can specifies protocol for server individually. For example, + /// `server1,tcp://server2,tcp+tls://server3`. This claims that `server1` and `server2` use + /// plaintext protocol, while `server3` uses tls encrypted protocol. + /// + /// # Notable errors + /// * [Error::NoHosts] if no host is available + /// * [Error::SessionExpired] if specified session expired + /// + /// # Notable behaviors + /// The state of this connector is undefined after connection attempt no matter whether it is + /// success or not. + pub async fn connect(&mut self, cluster: &str) -> Result { + self.connect_internally(false, cluster).await + } +} + +/// Builder for [Client] with more options than [Client::connect]. +#[derive(Clone, Debug)] +pub struct ClientBuilder { + connector: Connector, +} + +impl ClientBuilder { + fn new() -> Self { + Self { connector: Connector::new() } + } + + /// Specifies target session timeout to negotiate with ZooKeeper server. + /// + /// Defaults to 6s. + pub fn with_session_timeout(&mut self, timeout: Duration) -> &mut Self { + self.connector.session_timeout(timeout); + self + } + + /// Specifies idle timeout to conclude a connection as loss. + /// + /// Defaults to `2/5` of session timeout. + pub fn with_connection_timeout(&mut self, timeout: Duration) -> &mut Self { + self.connector.connection_timeout(timeout); + self + } + + /// Specifies whether readonly server is allowed. + pub fn with_readonly(&mut self, readonly: bool) -> &mut ClientBuilder { + self.connector.readonly = readonly; + self + } + + /// Specifies auth info for given authentication scheme. + pub fn with_auth(&mut self, scheme: String, auth: Vec) -> &mut ClientBuilder { + self.connector.auth(scheme, auth); + self + } + + /// Specifies session to reestablish. + pub fn with_session(&mut self, id: SessionId, password: Vec) -> &mut Self { + self.connector.session(id, password); + self + } + + /// Specifies client assumed server version of ZooKeeper cluster. + /// + /// Client will issue server compatible protocol to avoid [Error::Unimplemented] for some + /// operations. See [Client::create] for an example. + /// + /// See [ZOOKEEPER-1381][] and [ZOOKEEPER-3762][] for references. + /// + /// [ZOOKEEPER-1381]: https://issues.apache.org/jira/browse/ZOOKEEPER-1381 + /// [ZOOKEEPER-3762]: https://issues.apache.org/jira/browse/ZOOKEEPER-3762 + pub fn assume_server_version(&mut self, major: u32, minor: u32, patch: u32) -> &mut Self { + self.connector.server_version(major, minor, patch); + self + } + + /// Detaches creating session so it will not be closed after all client instances dropped. + pub fn detach(&mut self) -> &mut Self { + self.connector.detached(); + self + } + + /// Connects to ZooKeeper cluster. + /// + /// # Notable errors + /// * [Error::NoHosts] if no host is available + /// * [Error::SessionExpired] if specified session expired + pub async fn connect(&mut self, cluster: &str) -> Result { + self.connector.connect(cluster).await + } } trait MultiBuffer { diff --git a/tests/zookeeper.rs b/tests/zookeeper.rs index df18246..e561693 100644 --- a/tests/zookeeper.rs +++ b/tests/zookeeper.rs @@ -149,13 +149,13 @@ async fn test_connect_nohosts() { #[test_log::test(tokio::test)] async fn test_connect_session_expired() { let server = Server::new(); - let client = server.custom_client(None, |builder| builder.detach()).await.unwrap(); + let client = server.custom_client(None, |connector| connector.detached()).await.unwrap(); let timeout = client.session_timeout(); let (id, password) = client.into_session(); tokio::time::sleep(timeout * 2).await; - assert_that!(server.custom_client(None, |builder| builder.with_session(id, password)).await.unwrap_err()) + assert_that!(server.custom_client(None, |connector| connector.session(id, password)).await.unwrap_err()) .is_equal_to(zk::Error::SessionExpired); } @@ -322,52 +322,47 @@ serverCnxnFactory=org.apache.zookeeper.server.NettyServerCnxnFactory self.container.get_host_port(2181) } - pub fn url(&self) -> String { - let protocol = match (self.tls.is_some(), rand::random()) { - (true, _) => "tcp+tls://", - (false, true) => "tcp://", - (false, false) => "", + fn url(&self, chroot: Option<&str>) -> (String, bool) { + let (secure, explicit) = (rand::random(), rand::random()); + let protocol = match (self.tls.is_some(), secure, explicit) { + (true, false, _) | (true, _, true) => "tcp+tls://", + (false, true, _) | (false, _, true) => "tcp://", + (_, _, _) => "", }; - format!("{}127.0.0.1:{}", protocol, self.port()) + (format!("{}127.0.0.1:{}{}", protocol, self.port(), chroot.unwrap_or("")), secure) } - pub fn builder(&self) -> zk::ClientBuilder { + pub fn connector(&self) -> zk::Connector { + let mut connector = zk::Client::connector(); let Some(tls) = self.tls.as_ref() else { - return zk::Client::builder(); + return connector; }; - let mut builder = zk::Client::builder(); - builder.trust_ca_pem_certs(&tls.ca).unwrap().use_client_pem_cert(&tls.cert, &tls.key).unwrap(); - builder + let tls_options = zk::TlsOptions::default() + .with_pem_ca_certs(&tls.ca) + .unwrap() + .with_pem_identity(&tls.cert, &tls.key) + .unwrap(); + connector.tls(tls_options); + connector } pub async fn client(&self, chroot: Option<&str>) -> zk::Client { - self.custom_client(chroot, |builder| builder).await.unwrap() + self.custom_client(chroot, |connector| connector).await.unwrap() } pub async fn custom_client( &self, chroot: Option<&str>, - custom: impl FnOnce(&mut zk::ClientBuilder) -> &mut zk::ClientBuilder, + custom: impl FnOnce(&mut zk::Connector) -> &mut zk::Connector, ) -> Result { - let mut builder = self.builder(); - custom(&mut builder); - let chroot = chroot.unwrap_or(""); - let Some(tls) = self.tls.as_ref() else { - let url = self.url() + chroot; - return builder.connect(&url).await; - }; - let assume_tls: bool = rand::random(); - let protocol = if !assume_tls || rand::random() { "tcp+tls://" } else { "" }; - if assume_tls { - builder.assume_tls(); + let mut connector = self.connector(); + custom(&mut connector); + let (url, secure) = self.url(chroot); + if secure { + connector.secure_connect(&url).await + } else { + connector.connect(&url).await } - builder - .trust_ca_pem_certs(&tls.ca) - .unwrap() - .use_client_pem_cert(&tls.cert, &tls.key) - .unwrap() - .connect(&format!("{}127.0.0.1:{}{}", protocol, self.port(), chroot)) - .await } } @@ -771,7 +766,7 @@ async fn test_create_container() { async fn test_zookeeper34() { let server = Server::with_options(ServerOptions { tls: Some(false), tag: "3.4", ..Default::default() }); - let client = server.custom_client(None, |builder| builder.assume_server_version(3, 4, u32::MAX)).await.unwrap(); + let client = server.custom_client(None, |connector| connector.server_version(3, 4, u32::MAX)).await.unwrap(); let (stat, _sequence) = client.create("/a", b"a1", PERSISTENT_OPEN).await.unwrap(); assert_that!(stat.is_invalid()).is_true(); @@ -989,7 +984,7 @@ async fn test_auth() { assert!(authed_users.contains(&authed_user)); let authed_client = - server.custom_client(None, |builder| builder.with_auth(scheme.to_string(), auth.to_vec())).await.unwrap(); + server.custom_client(None, |connector| connector.auth(scheme.to_string(), auth.to_vec())).await.unwrap(); authed_client.auth(scheme.to_string(), auth.to_vec()).await.unwrap(); let authed_users = client.list_auth_users().await.unwrap(); @@ -1574,19 +1569,19 @@ async fn test_client_drop() { let (id, password) = client.into_session(); assert_eq!(zk::SessionState::Closed, state_watcher.changed().await); - server.custom_client(None, |builder| builder.with_session(id, password)).await.unwrap_err(); + server.custom_client(None, |connector| connector.session(id, password)).await.unwrap_err(); } #[test_log::test(tokio::test)] async fn test_client_detach() { let server = Server::new(); - let client = server.custom_client(None, |builder| builder.detach()).await.unwrap(); + let client = server.custom_client(None, |connector| connector.detached()).await.unwrap(); let mut state_watcher = client.state_watcher(); let (id, password) = client.into_session(); assert_eq!(zk::SessionState::Closed, state_watcher.changed().await); - server.custom_client(None, |builder| builder.with_session(id, password)).await.unwrap(); + server.custom_client(None, |connector| connector.session(id, password)).await.unwrap(); } fn generate_ca_cert() -> (Certificate, String) {