diff --git a/tonic/src/transport/endpoint.rs b/tonic/src/transport/endpoint.rs index 6db484976..525e33145 100644 --- a/tonic/src/transport/endpoint.rs +++ b/tonic/src/transport/endpoint.rs @@ -29,6 +29,7 @@ pub struct Endpoint { Option>, pub(super) init_stream_window_size: Option, pub(super) init_connection_window_size: Option, + pub(super) tcp_keepalive: Option, } impl Endpoint { @@ -83,6 +84,21 @@ impl Endpoint { } } + /// Set whether TCP keepalive messages are enabled on accepted connections. + /// + /// If `None` is specified, keepalive is disabled, otherwise the duration + /// specified will be the time to remain idle before sending TCP keepalive + /// probes. + /// + /// Default is no keepalive (`None`) + /// + pub fn tcp_keepalive(self, tcp_keepalive: Option) -> Self { + Endpoint { + tcp_keepalive, + ..self + } + } + /// Apply a concurrency limit to each request. /// /// ``` @@ -174,6 +190,7 @@ impl From for Endpoint { interceptor_headers: None, init_stream_window_size: None, init_connection_window_size: None, + tcp_keepalive: None, } } } diff --git a/tonic/src/transport/server.rs b/tonic/src/transport/server.rs index f6794b0a9..179c5fbfc 100644 --- a/tonic/src/transport/server.rs +++ b/tonic/src/transport/server.rs @@ -14,6 +14,7 @@ use hyper::{ server::{accept::Accept, conn}, Body, }; +use std::time::Duration; use std::{ fmt, future::Future, @@ -54,6 +55,7 @@ pub struct Server { init_stream_window_size: Option, init_connection_window_size: Option, max_concurrent_streams: Option, + tcp_keepalive: Option, } /// A stack based `Service` router. @@ -147,6 +149,21 @@ impl Server { } } + /// Set whether TCP keepalive messages are enabled on accepted connections. + /// + /// If `None` is specified, keepalive is disabled, otherwise the duration + /// specified will be the time to remain idle before sending TCP keepalive + /// probes. + /// + /// Default is no keepalive (`None`) + /// + pub fn tcp_keepalive(self, tcp_keepalive: Option) -> Self { + Server { + tcp_keepalive, + ..self + } + } + /// Intercept the execution of gRPC methods. /// /// ``` @@ -204,11 +221,12 @@ impl Server { let init_connection_window_size = self.init_connection_window_size; let init_stream_window_size = self.init_stream_window_size; let max_concurrent_streams = self.max_concurrent_streams; + let tcp_keepalive = self.tcp_keepalive; // let timeout = self.timeout.clone(); let incoming = hyper::server::accept::from_stream::<_, _, crate::Error>( async_stream::try_stream! { - let mut tcp = TcpIncoming::bind(addr)?; + let mut tcp = TcpIncoming::bind(addr, tcp_keepalive)?; while let Some(stream) = tcp.try_next().await? { #[cfg(feature = "tls")] @@ -400,9 +418,10 @@ struct TcpIncoming { } impl TcpIncoming { - fn bind(addr: SocketAddr) -> Result { + fn bind(addr: SocketAddr, tcp_keepalive: Option) -> Result { let mut inner = conn::AddrIncoming::bind(&addr).map_err(Box::new)?; inner.set_nodelay(true); + inner.set_keepalive(tcp_keepalive); Ok(Self { inner }) } diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 47693b35e..10fcf3545 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -28,10 +28,10 @@ pub(crate) struct Connection { impl Connection { pub(crate) async fn new(endpoint: Endpoint) -> Result { #[cfg(feature = "tls")] - let connector = connector(endpoint.tls.clone()); + let connector = connector(endpoint.tls.clone(), endpoint.tcp_keepalive); #[cfg(not(feature = "tls"))] - let connector = connector(); + let connector = connector(endpoint.tcp_keepalive); let settings = Builder::new() .http2_initial_stream_window_size(endpoint.init_stream_window_size) diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/service/connector.rs index c02d17367..17e2d8d41 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/service/connector.rs @@ -6,20 +6,22 @@ use hyper::client::connect::HttpConnector; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +use std::time::Duration; use tower_make::MakeConnection; use tower_service::Service; #[cfg(not(feature = "tls"))] -pub(crate) fn connector() -> HttpConnector { +pub(crate) fn connector(tcp_keepalive: Option) -> HttpConnector { let mut http = HttpConnector::new(); http.enforce_http(false); http.set_nodelay(true); + http.set_keepalive(tcp_keepalive); http } #[cfg(feature = "tls")] -pub(crate) fn connector(tls: Option) -> Connector { - Connector::new(tls) +pub(crate) fn connector(tls: Option, tcp_keepalive: Option) -> Connector { + Connector::new(tls, tcp_keepalive) } pub(crate) struct Connector { @@ -30,11 +32,11 @@ pub(crate) struct Connector { impl Connector { #[cfg(feature = "tls")] - pub(crate) fn new(tls: Option) -> Self { + pub(crate) fn new(tls: Option, tcp_keepalive: Option) -> Self { let mut http = HttpConnector::new(); http.enforce_http(false); http.set_nodelay(true); - + http.set_keepalive(tcp_keepalive); Self { http, tls } } }