From 3b9d6a6262b62f30b8c9953f0da8e403be53216e Mon Sep 17 00:00:00 2001 From: Johan Andersson Date: Mon, 6 Jul 2020 15:18:10 +0200 Subject: [PATCH] fix(transport): Propagate errors in tls_config instead of unwrap/panic (#385) * Propagate errors in tls_config instead of unwrap Ran into `tls_connector` failing and causing our app to panic and shutdown as it seems there wasn't any way to avoid panicking in `tls_config`. So after talking to @LucioFranco briefly `tls_config` now returns a `Result` instead and propagates errors to the caller, where they can be handled. * Fix compile warning when tls feature is disabled --- examples/src/gcp/client.rs | 2 +- examples/src/tls/client.rs | 2 +- examples/src/tls/server.rs | 2 +- examples/src/tls_client_auth/client.rs | 2 +- examples/src/tls_client_auth/server.rs | 2 +- interop/src/bin/client.rs | 2 +- interop/src/bin/server.rs | 2 +- tonic/src/transport/channel/endpoint.rs | 12 ++++++++---- tonic/src/transport/mod.rs | 4 ++-- tonic/src/transport/server/mod.rs | 15 +++++++++++---- 10 files changed, 28 insertions(+), 17 deletions(-) diff --git a/examples/src/gcp/client.rs b/examples/src/gcp/client.rs index 15f97588d..1cf30fce9 100644 --- a/examples/src/gcp/client.rs +++ b/examples/src/gcp/client.rs @@ -32,7 +32,7 @@ async fn main() -> Result<(), Box> { .domain_name("pubsub.googleapis.com"); let channel = Channel::from_static(ENDPOINT) - .tls_config(tls_config) + .tls_config(tls_config)? .connect() .await?; diff --git a/examples/src/tls/client.rs b/examples/src/tls/client.rs index d2e78b694..451f41d58 100644 --- a/examples/src/tls/client.rs +++ b/examples/src/tls/client.rs @@ -15,7 +15,7 @@ async fn main() -> Result<(), Box> { .domain_name("example.com"); let channel = Channel::from_static("http://[::1]:50051") - .tls_config(tls) + .tls_config(tls)? .connect() .await?; diff --git a/examples/src/tls/server.rs b/examples/src/tls/server.rs index 0aa9bc3c8..fc50c57ab 100644 --- a/examples/src/tls/server.rs +++ b/examples/src/tls/server.rs @@ -60,7 +60,7 @@ async fn main() -> Result<(), Box> { let server = EchoServer::default(); Server::builder() - .tls_config(ServerTlsConfig::new().identity(identity)) + .tls_config(ServerTlsConfig::new().identity(identity))? .add_service(pb::echo_server::EchoServer::new(server)) .serve(addr) .await?; diff --git a/examples/src/tls_client_auth/client.rs b/examples/src/tls_client_auth/client.rs index bd41a2bce..97966bf38 100644 --- a/examples/src/tls_client_auth/client.rs +++ b/examples/src/tls_client_auth/client.rs @@ -19,7 +19,7 @@ async fn main() -> Result<(), Box> { .identity(client_identity); let channel = Channel::from_static("http://[::1]:50051") - .tls_config(tls) + .tls_config(tls)? .connect() .await?; diff --git a/examples/src/tls_client_auth/server.rs b/examples/src/tls_client_auth/server.rs index 43061aa98..2a719f6f4 100644 --- a/examples/src/tls_client_auth/server.rs +++ b/examples/src/tls_client_auth/server.rs @@ -68,7 +68,7 @@ async fn main() -> Result<(), Box> { .client_ca_root(client_ca_cert); Server::builder() - .tls_config(tls) + .tls_config(tls)? .add_service(pb::echo_server::EchoServer::new(server)) .serve(addr) .await?; diff --git a/interop/src/bin/client.rs b/interop/src/bin/client.rs index 307548e23..1f676faeb 100644 --- a/interop/src/bin/client.rs +++ b/interop/src/bin/client.rs @@ -38,7 +38,7 @@ async fn main() -> Result<(), Box> { ClientTlsConfig::new() .ca_certificate(ca) .domain_name("foo.test.google.fr"), - ); + )?; } let channel = endpoint.connect().await?; diff --git a/interop/src/bin/server.rs b/interop/src/bin/server.rs index 35fdb1921..b48987dad 100644 --- a/interop/src/bin/server.rs +++ b/interop/src/bin/server.rs @@ -24,7 +24,7 @@ async fn main() -> std::result::Result<(), Box> { let key = tokio::fs::read("interop/data/server1.key").await?; let identity = Identity::from_pem(cert, key); - builder = builder.tls_config(ServerTlsConfig::new().identity(identity)); + builder = builder.tls_config(ServerTlsConfig::new().identity(identity))?; } let test_service = server::TestServiceServer::new(server::TestService::default()); diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 346fa27d6..3934c2ea7 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -155,11 +155,15 @@ impl Endpoint { /// Configures TLS for the endpoint. #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] - pub fn tls_config(self, tls_config: ClientTlsConfig) -> Self { - Endpoint { - tls: Some(tls_config.tls_connector(self.uri.clone()).unwrap()), + pub fn tls_config(self, tls_config: ClientTlsConfig) -> Result { + Ok(Endpoint { + tls: Some( + tls_config + .tls_connector(self.uri.clone()) + .map_err(|e| Error::from_source(e))?, + ), ..self - } + }) } /// Set the value of `TCP_NODELAY` option for accepted connections. Enabled by default. diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 7bee18167..40bd7b110 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -30,7 +30,7 @@ //! let mut channel = Channel::from_static("https://example.com") //! .tls_config(ClientTlsConfig::new() //! .ca_certificate(Certificate::from_pem(&cert)) -//! .domain_name("example.com".to_string())) +//! .domain_name("example.com".to_string()))? //! .timeout(Duration::from_secs(5)) //! .rate_limit(5, Duration::from_secs(1)) //! .concurrency_limit(256) @@ -74,7 +74,7 @@ //! //! Server::builder() //! .tls_config(ServerTlsConfig::new() -//! .identity(Identity::from_pem(&cert, &key))) +//! .identity(Identity::from_pem(&cert, &key)))? //! .concurrency_limit_per_connection(256) //! .add_service(my_svc) //! .serve(addr) diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index ce84b7d42..7751029f4 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -18,6 +18,9 @@ use incoming::TcpIncoming; #[cfg(feature = "tls")] pub(crate) use incoming::TlsStream; +#[cfg(feature = "tls")] +use crate::transport::Error; + use super::service::{Or, Routes, ServerIo, ServiceBuilderExt}; use crate::{body::BoxBody, request::ConnectionInfo}; use futures_core::Stream; @@ -97,11 +100,15 @@ impl Server { /// Configure TLS for this server. #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] - pub fn tls_config(self, tls_config: ServerTlsConfig) -> Self { - Server { - tls: Some(tls_config.tls_acceptor().unwrap()), + pub fn tls_config(self, tls_config: ServerTlsConfig) -> Result { + Ok(Server { + tls: Some( + tls_config + .tls_acceptor() + .map_err(|e| Error::from_source(e))?, + ), ..self - } + }) } /// Set the concurrency limit applied to on requests inbound per connection.