From 3eb76abf9fdce5f903de1a7f05b8afc8694fa0ce Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Fri, 13 Dec 2019 20:25:00 -0500 Subject: [PATCH] =?UTF-8?q?feat(transport):=20Add=20`remote=5Faddr`=20to?= =?UTF-8?q?=20`Request`=20on=20the=20server=20si=E2=80=A6=20(#186)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/src/helloworld/server.rs | 2 +- examples/src/uds/server.rs | 52 +++++++++++++++++++++-- tonic/src/request.rs | 25 +++++++++++ tonic/src/transport/server/conn.rs | 30 ++++++++++++++ tonic/src/transport/server/incoming.rs | 12 +++--- tonic/src/transport/server/mod.rs | 35 ++++++++++++---- tonic/src/transport/service/io.rs | 57 ++++++++++++++++++++++++-- tonic/src/transport/service/mod.rs | 2 +- tonic/src/transport/service/tls.rs | 4 +- 9 files changed, 193 insertions(+), 26 deletions(-) create mode 100644 tonic/src/transport/server/conn.rs diff --git a/examples/src/helloworld/server.rs b/examples/src/helloworld/server.rs index 23a3454e4..3a3051c66 100644 --- a/examples/src/helloworld/server.rs +++ b/examples/src/helloworld/server.rs @@ -18,7 +18,7 @@ impl Greeter for MyGreeter { &self, request: Request, ) -> Result, Status> { - println!("Got a request: {:?}", request); + println!("Got a request from {:?}", request.remote_addr()); let reply = hello_world::HelloReply { message: format!("Hello {}!", request.into_inner().name).into(), diff --git a/examples/src/uds/server.rs b/examples/src/uds/server.rs index 36258ab49..103205511 100644 --- a/examples/src/uds/server.rs +++ b/examples/src/uds/server.rs @@ -1,6 +1,17 @@ -use std::path::Path; -use tokio::net::UnixListener; -use tonic::{transport::Server, Request, Response, Status}; +use futures::stream::TryStreamExt; +use std::{ + path::Path, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::UnixListener, +}; +use tonic::{ + transport::{server::Connected, Server}, + Request, Response, Status, +}; pub mod hello_world { tonic::include_proto!("helloworld"); @@ -41,8 +52,41 @@ async fn main() -> Result<(), Box> { Server::builder() .add_service(GreeterServer::new(greeter)) - .serve_with_incoming(uds.incoming()) + .serve_with_incoming(uds.incoming().map_ok(UnixStream)) .await?; Ok(()) } + +#[derive(Debug)] +struct UnixStream(tokio::net::UnixStream); + +impl Connected for UnixStream {} + +impl AsyncRead for UnixStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl AsyncWrite for UnixStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} diff --git a/tonic/src/request.rs b/tonic/src/request.rs index 27af35a25..489c2a228 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -1,11 +1,19 @@ use crate::metadata::MetadataMap; use futures_core::Stream; +use http::Extensions; +use std::net::SocketAddr; /// A gRPC request and metadata from an RPC call. #[derive(Debug)] pub struct Request { metadata: MetadataMap, message: T, + extensions: Extensions, +} + +#[derive(Clone)] +pub(crate) struct ConnectionInfo { + pub(crate) remote_addr: Option, } /// Trait implemented by RPC request types. @@ -102,6 +110,7 @@ impl Request { Request { metadata: MetadataMap::new(), message, + extensions: Extensions::default(), } } @@ -134,6 +143,7 @@ impl Request { Request { metadata: MetadataMap::from_headers(parts.headers), message, + extensions: parts.extensions, } } @@ -150,6 +160,7 @@ impl Request { *request.method_mut() = http::Method::POST; *request.uri_mut() = uri; *request.headers_mut() = self.metadata.into_sanitized_headers(); + *request.extensions_mut() = self.extensions; request } @@ -164,8 +175,22 @@ impl Request { Request { metadata: self.metadata, message, + extensions: Extensions::default(), } } + + /// Get the remote address of this connection. + /// + /// This will return `None` if the `IO` type used + /// does not implement `Connected`. This currently, + /// only works on the server side. + pub fn remote_addr(&self) -> Option { + self.get::()?.remote_addr + } + + pub(crate) fn get(&self) -> Option<&I> { + self.extensions.get::() + } } impl IntoRequest for T { diff --git a/tonic/src/transport/server/conn.rs b/tonic/src/transport/server/conn.rs new file mode 100644 index 000000000..f50df1075 --- /dev/null +++ b/tonic/src/transport/server/conn.rs @@ -0,0 +1,30 @@ +use hyper::server::conn::AddrStream; +use std::net::SocketAddr; +#[cfg(feature = "tls")] +use tokio_rustls::TlsStream; + +/// Trait that connected IO resources implement. +/// +/// The goal for this trait is to allow users to implement +/// custom IO types that can still provide the same connection +/// metadata. +pub trait Connected { + /// Return the remote address this IO resource is connected too. + fn remote_addr(&self) -> Option { + None + } +} + +impl Connected for AddrStream { + fn remote_addr(&self) -> Option { + Some(self.remote_addr()) + } +} + +#[cfg(feature = "tls")] +impl Connected for TlsStream { + fn remote_addr(&self) -> Option { + let (inner, _) = self.get_ref(); + inner.remote_addr() + } +} diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index a6c78ce89..5232d6bdf 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -1,5 +1,5 @@ -use super::Server; -use crate::transport::service::BoxedIo; +use super::{Connected, Server}; +use crate::transport::service::ServerIo; use futures_core::Stream; use futures_util::stream::TryStreamExt; use hyper::server::{ @@ -20,9 +20,9 @@ use tracing::error; pub(crate) fn tcp_incoming( incoming: impl Stream>, server: Server, -) -> impl Stream> +) -> impl Stream> where - IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into, { async_stream::try_stream! { @@ -39,12 +39,12 @@ where continue }, }; - yield BoxedIo::new(io); + yield ServerIo::new(io); continue; } } - yield BoxedIo::new(stream); + yield ServerIo::new(stream); } } } diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index e1bd03b8f..913fb03b3 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -1,9 +1,11 @@ //! Server implementation and builder. +mod conn; mod incoming; #[cfg(feature = "tls")] mod tls; +pub use conn::Connected; #[cfg(feature = "tls")] pub use tls::ServerTlsConfig; @@ -12,8 +14,8 @@ use super::service::TlsAcceptor; use incoming::TcpIncoming; -use super::service::{layer_fn, Or, Routes, ServiceBuilderExt}; -use crate::body::BoxBody; +use super::service::{layer_fn, Or, Routes, ServerIo, ServiceBuilderExt}; +use crate::{body::BoxBody, request::ConnectionInfo}; use futures_core::Stream; use futures_util::{ future::{self, MapErr}, @@ -252,7 +254,7 @@ impl Server { S::Future: Send + 'static, S::Error: Into + Send, I: Stream>, - IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into, F: Future, { @@ -390,7 +392,7 @@ where pub async fn serve_with_incoming(self, incoming: I) -> Result<(), super::Error> where I: Stream>, - IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into, { self.server @@ -412,6 +414,7 @@ impl fmt::Debug for Server { struct Svc { inner: S, span: Option, + conn_info: ConnectionInfo, } impl Service> for Svc @@ -427,13 +430,15 @@ where self.inner.poll_ready(cx).map_err(Into::into) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { let span = if let Some(trace_interceptor) = &self.span { trace_interceptor(req.headers()) } else { tracing::Span::none() }; + req.extensions_mut().insert(self.conn_info.clone()); + self.inner.call(req).instrument(span).map_err(|e| e.into()) } } @@ -452,7 +457,7 @@ struct MakeSvc { span: Option, } -impl Service for MakeSvc +impl Service<&ServerIo> for MakeSvc where S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, @@ -467,7 +472,11 @@ where Ok(()).into() } - fn call(&mut self, _: T) -> Self::Future { + fn call(&mut self, io: &ServerIo) -> Self::Future { + let conn_info = crate::request::ConnectionInfo { + remote_addr: io.remote_addr(), + }; + let interceptor = self.interceptor.clone(); let svc = self.inner.clone(); let concurrency_limit = self.concurrency_limit; @@ -481,10 +490,18 @@ where .service(svc); let svc = if let Some(interceptor) = interceptor { - let layered = interceptor.layer(BoxService::new(Svc { inner: svc, span })); + let layered = interceptor.layer(BoxService::new(Svc { + inner: svc, + span, + conn_info, + })); BoxService::new(layered) } else { - BoxService::new(Svc { inner: svc, span }) + BoxService::new(Svc { + inner: svc, + span, + conn_info, + }) }; Ok(svc) diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index e001f580f..a72bc1d4f 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,5 +1,7 @@ -use hyper::client::connect::{Connected, Connection}; +use crate::transport::server::Connected; +use hyper::client::connect::{Connected as HyperConnected, Connection}; use std::io; +use std::net::SocketAddr; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite}; @@ -20,11 +22,13 @@ impl BoxedIo { } impl Connection for BoxedIo { - fn connected(&self) -> Connected { - Connected::new() + fn connected(&self) -> HyperConnected { + HyperConnected::new() } } +impl Connected for BoxedIo {} + impl AsyncRead for BoxedIo { fn poll_read( mut self: Pin<&mut Self>, @@ -52,3 +56,50 @@ impl AsyncWrite for BoxedIo { Pin::new(&mut self.0).poll_shutdown(cx) } } + +pub(in crate::transport) trait ConnectedIo: Io + Connected {} + +impl ConnectedIo for T where T: Io + Connected {} + +pub(crate) struct ServerIo(Pin>); + +impl ServerIo { + pub(in crate::transport) fn new(io: I) -> Self { + ServerIo(Box::pin(io)) + } +} + +impl Connected for ServerIo { + fn remote_addr(&self) -> Option { + let io = &*self.0; + io.remote_addr() + } +} + +impl AsyncRead for ServerIo { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl AsyncWrite for ServerIo { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 0819f02da..40800033e 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -14,7 +14,7 @@ pub(crate) use self::add_origin::AddOrigin; pub(crate) use self::connection::Connection; pub(crate) use self::connector::connector; pub(crate) use self::discover::ServiceList; -pub(crate) use self::io::BoxedIo; +pub(crate) use self::io::ServerIo; pub(crate) use self::layer::{layer_fn, ServiceBuilderExt}; pub(crate) use self::router::{Or, Routes}; #[cfg(feature = "tls")] diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 876ef1631..f6b79868b 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -1,5 +1,5 @@ use super::io::BoxedIo; -use crate::transport::{Certificate, Identity}; +use crate::transport::{server::Connected, Certificate, Identity}; #[cfg(feature = "tls-roots")] use rustls_native_certs; use std::{fmt, sync::Arc}; @@ -159,7 +159,7 @@ impl TlsAcceptor { pub(crate) async fn accept(&self, io: IO) -> Result where - IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, + IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, { let io = { let acceptor = RustlsAcceptor::from(self.inner.clone());