Skip to content

Commit

Permalink
feat(tonic): pass trace_fn the request rather than just the headers (
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn authored May 12, 2021
1 parent 2bf14e1 commit 7862a22
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use futures_util::{
future::{self, Either as FutureEither, MapErr},
TryFutureExt,
};
use http::{HeaderMap, Request, Response};
use http::{Request, Response};
use hyper::{server::accept, Body};
use std::{
fmt,
Expand All @@ -48,7 +48,7 @@ use tower::{limit::concurrency::ConcurrencyLimitLayer, util::Either, Service, Se
use tracing_futures::{Instrument, Instrumented};

type BoxService = tower::util::BoxService<Request<Body>, Response<BoxBody>, crate::Error>;
type TraceInterceptor = Arc<dyn Fn(&HeaderMap) -> tracing::Span + Send + Sync + 'static>;
type TraceInterceptor = Arc<dyn Fn(&http::Request<()>) -> tracing::Span + Send + Sync + 'static>;

const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20;

Expand Down Expand Up @@ -290,10 +290,10 @@ impl Server {
}
}

/// Intercept inbound headers and add a [`tracing::Span`] to each response future.
/// Intercept inbound requests and add a [`tracing::Span`] to each response future.
pub fn trace_fn<F>(self, f: F) -> Self
where
F: Fn(&HeaderMap) -> tracing::Span + Send + Sync + 'static,
F: Fn(&http::Request<()>) -> tracing::Span + Send + Sync + 'static,
{
Server {
trace_interceptor: Some(Arc::new(f)),
Expand Down Expand Up @@ -361,7 +361,7 @@ impl Server {
IE: Into<crate::Error>,
F: Future<Output = ()>,
{
let span = self.trace_interceptor.clone();
let trace_interceptor = self.trace_interceptor.clone();
let concurrency_limit = self.concurrency_limit;
let init_connection_window_size = self.init_connection_window_size;
let init_stream_window_size = self.init_stream_window_size;
Expand All @@ -381,7 +381,7 @@ impl Server {
inner: svc,
concurrency_limit,
timeout,
span,
trace_interceptor,
};

let server = hyper::Server::builder(incoming)
Expand Down Expand Up @@ -582,7 +582,7 @@ impl fmt::Debug for Server {

struct Svc<S> {
inner: S,
span: Option<TraceInterceptor>,
trace_interceptor: Option<TraceInterceptor>,
conn_info: ConnectionInfo,
}

Expand All @@ -602,8 +602,16 @@ where
}

fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let span = if let Some(trace_interceptor) = &self.span {
trace_interceptor(req.headers())
let span = if let Some(trace_interceptor) = &self.trace_interceptor {
let (parts, body) = req.into_parts();
let bodyless_request = Request::from_parts(parts, ());

let span = trace_interceptor(&bodyless_request);

let (parts, _) = bodyless_request.into_parts();
req = Request::from_parts(parts, body);

span
} else {
tracing::Span::none()
};
Expand All @@ -624,7 +632,7 @@ struct MakeSvc<S> {
concurrency_limit: Option<usize>,
timeout: Option<Duration>,
inner: S,
span: Option<TraceInterceptor>,
trace_interceptor: Option<TraceInterceptor>,
}

impl<S> Service<&ServerIo> for MakeSvc<S>
Expand All @@ -650,7 +658,7 @@ where
let svc = self.inner.clone();
let concurrency_limit = self.concurrency_limit;
let timeout = self.timeout;
let span = self.span.clone();
let trace_interceptor = self.trace_interceptor.clone();

Box::pin(async move {
let svc = ServiceBuilder::new()
Expand All @@ -661,7 +669,7 @@ where

let svc = BoxService::new(Svc {
inner: svc,
span,
trace_interceptor,
conn_info,
});

Expand Down

0 comments on commit 7862a22

Please sign in to comment.