Skip to content

Commit

Permalink
feat(transport): Make transport server and channel independent
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto committed Feb 9, 2024
1 parent 1fe8cb0 commit d258fb8
Show file tree
Hide file tree
Showing 15 changed files with 136 additions and 110 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ jobs:
with:
tool: protoc@${{ env.PROTOC_VERSION }}
- uses: Swatinem/rust-cache@v2
- run: cargo hack udeps --workspace --each-feature ${{ matrix.option }}
- run: cargo hack udeps --workspace --exclude-features tls --each-feature ${{ matrix.option }}
- run: cargo udeps --package tonic --features tls,transport
- run: cargo udeps --package tonic --features tls,server
- run: cargo udeps --package tonic --features tls,channel

check:
runs-on: ${{ matrix.os }}
Expand Down
19 changes: 12 additions & 7 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,29 @@ version = "0.11.0"
codegen = ["dep:async-trait"]
gzip = ["dep:flate2"]
zstd = ["dep:zstd"]
default = ["channel", "codegen", "prost"]
default = ["channel", "codegen", "prost", "server"]
prost = ["dep:prost"]
tls = ["dep:rustls-pki-types", "dep:rustls-pemfile", "transport", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"]
tls = ["dep:rustls-pki-types", "dep:rustls-pemfile", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"]
tls-roots = ["tls-roots-common", "dep:rustls-native-certs"]
tls-roots-common = ["tls", "channel"]
tls-webpki-roots = ["tls-roots-common", "dep:webpki-roots"]
transport = [
"dep:tower", "tower?/util", "tower?/limit",
"dep:tokio", "tokio?/time",
"dep:hyper",
]
server = [
"transport",
"dep:async-stream",
"tokio?/net",
"dep:axum",
"dep:h2",
"dep:hyper", "hyper?/server",
"dep:tokio", "tokio?/net", "tokio?/time",
"dep:tower", "tower?/util", "tower?/limit",
"hyper?/server",
]
channel = [
"transport",
"dep:hyper", "hyper?/client",
"dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/load", "tower?/make",
"hyper?/client",
"tower?/balance", "tower?/buffer", "tower?/discover", "tower?/load", "tower?/make",
"dep:hyper-timeout",
]

Expand Down
20 changes: 10 additions & 10 deletions tonic/src/request.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::metadata::{MetadataMap, MetadataValue};
#[cfg(feature = "transport")]
#[cfg(feature = "server")]
use crate::transport::server::TcpConnectInfo;
#[cfg(feature = "tls")]
#[cfg(all(feature = "server", feature = "tls"))]
use crate::transport::{server::TlsConnectInfo, Certificate};
use crate::Extensions;
#[cfg(feature = "transport")]
#[cfg(feature = "server")]
use std::net::SocketAddr;
#[cfg(feature = "tls")]
#[cfg(all(feature = "server", feature = "tls"))]
use std::sync::Arc;
use std::time::Duration;
use tokio_stream::Stream;
Expand Down Expand Up @@ -209,8 +209,8 @@ impl<T> Request<T> {
/// This will return `None` if the `IO` type used
/// does not implement `Connected` or when using a unix domain socket.
/// This currently only works on the server side.
#[cfg(feature = "transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "transport")))]
#[cfg(feature = "server")]
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
pub fn local_addr(&self) -> Option<SocketAddr> {
let addr = self
.extensions()
Expand All @@ -232,8 +232,8 @@ impl<T> Request<T> {
/// This will return `None` if the `IO` type used
/// does not implement `Connected` or when using a unix domain socket.
/// This currently only works on the server side.
#[cfg(feature = "transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "transport")))]
#[cfg(feature = "server")]
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
pub fn remote_addr(&self) -> Option<SocketAddr> {
let addr = self
.extensions()
Expand All @@ -256,8 +256,8 @@ impl<T> Request<T> {
/// and is mostly used for mTLS. This currently only returns
/// `Some` on the server side of the `transport` server with
/// TLS enabled connections.
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
#[cfg(all(feature = "server", feature = "tls"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "server", feature = "tls"))))]
pub fn peer_certs(&self) -> Option<Arc<Vec<Certificate>>> {
self.extensions()
.get::<TlsConnectInfo<TcpConnectInfo>>()
Expand Down
18 changes: 9 additions & 9 deletions tonic/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ impl Status {
Err(err) => err,
};

#[cfg(feature = "transport")]
#[cfg(feature = "server")]
let err = match err.downcast::<h2::Error>() {
Ok(h2) => {
return Ok(Status::from_h2_error(h2));
Expand All @@ -359,7 +359,7 @@ impl Status {
}

// FIXME: bubble this into `transport` and expose generic http2 reasons.
#[cfg(feature = "transport")]
#[cfg(feature = "server")]
fn from_h2_error(err: Box<h2::Error>) -> Status {
let code = Self::code_from_h2(&err);

Expand All @@ -368,7 +368,7 @@ impl Status {
status
}

#[cfg(feature = "transport")]
#[cfg(feature = "server")]
fn code_from_h2(err: &h2::Error) -> Code {
// See https://github.com/grpc/grpc/blob/3977c30/doc/PROTOCOL-HTTP2.md#errors
match err.reason() {
Expand All @@ -388,7 +388,7 @@ impl Status {
}
}

#[cfg(feature = "transport")]
#[cfg(feature = "server")]
fn to_h2_error(&self) -> h2::Error {
// conservatively transform to h2 error codes...
let reason = match self.code {
Expand All @@ -404,7 +404,7 @@ impl Status {
///
/// Returns Some if there's a way to handle the error, or None if the information from this
/// hyper error, but perhaps not its source, should be ignored.
#[cfg(feature = "transport")]
#[cfg(feature = "server")]
fn from_hyper_error(err: &hyper::Error) -> Option<Status> {
// is_timeout results from hyper's keep-alive logic
// (https://docs.rs/hyper/0.14.11/src/hyper/error.rs.html#192-194). Per the grpc spec
Expand Down Expand Up @@ -614,12 +614,12 @@ fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option<Status> {
});
}

#[cfg(feature = "transport")]
#[cfg(any(feature = "server", feature = "channel"))]
if let Some(timeout) = err.downcast_ref::<crate::transport::TimeoutExpired>() {
return Some(Status::cancelled(timeout.to_string()));
}

#[cfg(feature = "transport")]
#[cfg(feature = "server")]
if let Some(hyper) = err
.downcast_ref::<hyper::Error>()
.and_then(Status::from_hyper_error)
Expand Down Expand Up @@ -666,14 +666,14 @@ fn invalid_header_value_byte<Error: fmt::Display>(err: Error) -> Status {
)
}

#[cfg(feature = "transport")]
#[cfg(feature = "server")]
impl From<h2::Error> for Status {
fn from(err: h2::Error) -> Self {
Status::from_h2_error(Box::new(err))
}
}

#[cfg(feature = "transport")]
#[cfg(feature = "server")]
impl From<Status> for h2::Error {
fn from(status: Status) -> Self {
status.to_h2_error()
Expand Down
6 changes: 5 additions & 1 deletion tonic/src/transport/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct ErrorImpl {

#[derive(Debug)]
pub(crate) enum Kind {
#[allow(unused)]
Transport,
#[cfg(feature = "channel")]
InvalidUri,
Expand All @@ -22,17 +23,20 @@ pub(crate) enum Kind {
}

impl Error {
pub(crate) fn new(kind: Kind) -> Self {
#[cfg(any(feature = "server", feature = "channel"))]
fn new(kind: Kind) -> Self {
Self {
inner: ErrorImpl { kind, source: None },
}
}

#[cfg(any(feature = "server", feature = "channel"))]
pub(crate) fn with(mut self, source: impl Into<Source>) -> Self {
self.inner.source = Some(source.into());
self
}

#[cfg(any(feature = "server", feature = "channel"))]
pub(crate) fn from_source(source: impl Into<crate::Error>) -> Self {
Error::new(Kind::Transport).with(source)
}
Expand Down
9 changes: 7 additions & 2 deletions tonic/src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
#[cfg(feature = "channel")]
pub mod channel;
#[cfg(feature = "server")]
pub mod server;

mod error;
Expand All @@ -102,12 +103,16 @@ mod tls;
pub use self::channel::{Channel, Endpoint};
pub use self::error::Error;
#[doc(inline)]
#[cfg(feature = "server")]
pub use self::server::Server;
#[doc(inline)]
#[cfg(any(feature = "server", feature = "channel"))]
pub use self::service::grpc_timeout::TimeoutExpired;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub use self::tls::Certificate;
#[cfg(feature = "server")]
#[cfg_attr(docsrs, doc(cfg(feature = "server")))]
pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter};
pub use hyper::{Body, Uri};

Expand All @@ -117,8 +122,8 @@ pub(crate) use self::channel::service::executor::Executor;
#[cfg(all(feature = "channel", feature = "tls"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "channel", feature = "tls"))))]
pub use self::channel::ClientTlsConfig;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
#[cfg(all(feature = "server", feature = "tls"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "server", feature = "tls"))))]
pub use self::server::ServerTlsConfig;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
Expand Down
2 changes: 1 addition & 1 deletion tonic/src/transport/server/incoming.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::service::ServerIo;
use super::{Connected, Server};
use crate::transport::service::ServerIo;
use hyper::server::{
accept::Accept,
conn::{AddrIncoming, AddrStream},
Expand Down
9 changes: 5 additions & 4 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
mod conn;
mod incoming;
mod recover_error;
mod service;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
#[cfg(unix)]
mod unix;

pub use super::service::Routes;
pub use super::service::RoutesBuilder;
pub use self::service::{Routes, RoutesBuilder};

pub use conn::{Connected, TcpConnectInfo};
#[cfg(feature = "tls")]
Expand All @@ -20,7 +20,7 @@ pub use tls::ServerTlsConfig;
pub use conn::TlsConnectInfo;

#[cfg(feature = "tls")]
use super::service::TlsAcceptor;
use self::service::tls::TlsAcceptor;

#[cfg(unix)]
pub use unix::UdsConnectInfo;
Expand All @@ -34,7 +34,8 @@ pub(crate) use tokio_rustls::server::TlsStream;
use crate::transport::Error;

use self::recover_error::RecoverError;
use super::service::{GrpcTimeout, ServerIo};
use self::service::ServerIo;
use super::service::GrpcTimeout;
use crate::body::BoxBody;
use crate::server::NamedService;
use bytes::Bytes;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::transport::server::Connected;
use std::io;
use std::io::IoSlice;
use std::pin::Pin;
Expand All @@ -7,6 +6,8 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[cfg(feature = "tls")]
use tokio_rustls::server::TlsStream;

use super::super::Connected;

pub(crate) enum ServerIo<IO> {
Io(IO),
#[cfg(feature = "tls")]
Expand Down
8 changes: 8 additions & 0 deletions tonic/src/transport/server/service/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
pub(crate) mod io;
pub(crate) use self::io::ServerIo;

mod router;
pub use self::router::{Routes, RoutesBuilder};

#[cfg(feature = "tls")]
pub(crate) mod tls;
File renamed without changes.
65 changes: 65 additions & 0 deletions tonic/src/transport/server/service/tls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use std::io::Cursor;
use std::{fmt, sync::Arc};

use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{
rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig},
TlsAcceptor as RustlsAcceptor,
};

use crate::transport::server::Connected;
use crate::transport::server::TlsStream;
use crate::transport::service::tls::{add_certs_from_pem, load_identity, ALPN_H2};
use crate::transport::tls::{Certificate, Identity};

#[derive(Clone)]
pub(crate) struct TlsAcceptor {
inner: Arc<ServerConfig>,
}

impl TlsAcceptor {
pub(crate) fn new(
identity: Identity,
client_ca_root: Option<Certificate>,
client_auth_optional: bool,
) -> Result<Self, crate::Error> {
let builder = ServerConfig::builder();

let builder = match client_ca_root {
None => builder.with_no_client_auth(),
Some(cert) => {
let mut roots = RootCertStore::empty();
add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
let verifier = if client_auth_optional {
WebPkiClientVerifier::builder(roots.into()).allow_unauthenticated()
} else {
WebPkiClientVerifier::builder(roots.into())
}
.build()?;
builder.with_client_cert_verifier(verifier)
}
};

let (cert, key) = load_identity(identity)?;
let mut config = builder.with_single_cert(cert, key)?;

config.alpn_protocols.push(ALPN_H2.into());
Ok(Self {
inner: Arc::new(config),
})
}

pub(crate) async fn accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, crate::Error>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
{
let acceptor = RustlsAcceptor::from(self.inner.clone());
acceptor.accept(io).await.map_err(Into::into)
}
}

impl fmt::Debug for TlsAcceptor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TlsAcceptor").finish()
}
}
7 changes: 3 additions & 4 deletions tonic/src/transport/server/tls.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use crate::transport::{
service::TlsAcceptor,
tls::{Certificate, Identity},
};
use std::fmt;

use crate::transport::server::service::tls::TlsAcceptor;
use crate::transport::tls::{Certificate, Identity};

/// Configures TLS settings for servers.
#[derive(Clone, Default)]
pub struct ServerTlsConfig {
Expand Down
Loading

0 comments on commit d258fb8

Please sign in to comment.