diff --git a/Cargo.lock b/Cargo.lock index 693a1690f6..d61f95b38b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1077,6 +1077,7 @@ dependencies = [ "linkerd-app-outbound", "linkerd-error", "linkerd-opencensus", + "linkerd-tonic-stream", "rangemap", "regex", "thiserror", @@ -1215,6 +1216,7 @@ dependencies = [ "linkerd-meshtls-rustls", "linkerd-proxy-client-policy", "linkerd-proxy-server-policy", + "linkerd-tonic-stream", "linkerd-tonic-watch", "linkerd-tracing", "linkerd2-proxy-api", @@ -1286,6 +1288,7 @@ dependencies = [ "linkerd-proxy-client-policy", "linkerd-retry", "linkerd-stack", + "linkerd-tonic-stream", "linkerd-tonic-watch", "linkerd-tracing", "linkerd2-proxy-api", @@ -1671,7 +1674,6 @@ dependencies = [ name = "linkerd-proxy-api-resolve" version = "0.1.0" dependencies = [ - "async-stream", "futures", "http", "http-body", @@ -1680,6 +1682,7 @@ dependencies = [ "linkerd-proxy-core", "linkerd-stack", "linkerd-tls", + "linkerd-tonic-stream", "linkerd2-proxy-api", "pin-project", "prost", @@ -1980,6 +1983,7 @@ dependencies = [ "linkerd-http-box", "linkerd-proxy-api-resolve", "linkerd-stack", + "linkerd-tonic-stream", "linkerd-tonic-watch", "linkerd2-proxy-api", "once_cell", @@ -2078,6 +2082,21 @@ dependencies = [ name = "linkerd-tls-test-util" version = "0.1.0" +[[package]] +name = "linkerd-tonic-stream" +version = "0.1.0" +dependencies = [ + "futures", + "linkerd-stack", + "linkerd-tracing", + "pin-project", + "tokio", + "tokio-stream", + "tokio-test", + "tonic", + "tracing", +] + [[package]] name = "linkerd-tonic-watch" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 5b23ede593..bb3eb79313 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ members = [ "linkerd/stack/metrics", "linkerd/stack/tracing", "linkerd/system", + "linkerd/tonic-stream", "linkerd/tonic-watch", "linkerd/tls", "linkerd/tls/test-util", diff --git a/linkerd/app/Cargo.toml b/linkerd/app/Cargo.toml index d5a9f64700..af443530a3 100644 --- a/linkerd/app/Cargo.toml +++ b/linkerd/app/Cargo.toml @@ -25,6 +25,7 @@ linkerd-app-inbound = { path = "./inbound" } linkerd-app-outbound = { path = "./outbound" } linkerd-error = { path = "../error" } linkerd-opencensus = { path = "../opencensus" } +linkerd-tonic-stream = { path = "../tonic-stream" } rangemap = "1" regex = "1" thiserror = "1" diff --git a/linkerd/app/inbound/Cargo.toml b/linkerd/app/inbound/Cargo.toml index 64b5651eef..215f10f9fc 100644 --- a/linkerd/app/inbound/Cargo.toml +++ b/linkerd/app/inbound/Cargo.toml @@ -28,6 +28,7 @@ linkerd-idle-cache = { path = "../../idle-cache" } linkerd-meshtls = { path = "../../meshtls", optional = true } linkerd-meshtls-rustls = { path = "../../meshtls/rustls", optional = true } linkerd-proxy-client-policy = { path = "../../proxy/client-policy" } +linkerd-tonic-stream = { path = "../../tonic-stream" } linkerd-tonic-watch = { path = "../../tonic-watch" } linkerd2-proxy-api = { version = "0.12", features = ["inbound"] } once_cell = "1" diff --git a/linkerd/app/inbound/src/policy/api.rs b/linkerd/app/inbound/src/policy/api.rs index 87e4638997..51cc47c8db 100644 --- a/linkerd/app/inbound/src/policy/api.rs +++ b/linkerd/app/inbound/src/policy/api.rs @@ -9,13 +9,16 @@ use linkerd_app_core::{ Error, Recover, Result, }; use linkerd_proxy_server_policy::ServerPolicy; +use linkerd_tonic_stream::{LimitReceiveFuture, ReceiveLimits}; use linkerd_tonic_watch::StreamWatch; -use std::{sync::Arc, time}; +use std::sync::Arc; +use tokio::time; #[derive(Clone, Debug)] pub(super) struct Api { workload: Arc, - detect_timeout: 
time::Duration, + limits: ReceiveLimits, + default_detect_timeout: time::Duration, client: Client, } @@ -34,10 +37,16 @@ where S::ResponseBody: http::HttpBody + Default + Send + 'static, { - pub(super) fn new(workload: Arc, detect_timeout: time::Duration, client: S) -> Self { + pub(super) fn new( + workload: Arc, + limits: ReceiveLimits, + default_detect_timeout: time::Duration, + client: S, + ) -> Self { Self { workload, - detect_timeout, + limits, + default_detect_timeout, client: Client::new(client), } } @@ -72,26 +81,28 @@ where port: port.into(), workload: self.workload.as_ref().to_owned(), }; - let detect_timeout = self.detect_timeout; + + let detect_timeout = self.default_detect_timeout; + let limits = self.limits; let mut client = self.client.clone(); Box::pin(async move { - let rsp = client.watch_port(tonic::Request::new(req)).await?; - Ok(rsp.map(|updates| { - updates - .map_ok(move |up| { - // If the server returned an invalid server policy, we - // default to using an invalid policy that causes all - // requests to report an internal error. - let policy = ServerPolicy::try_from(up).unwrap_or_else(|error| { - tracing::warn!(%error, "Server misconfigured"); - INVALID_POLICY - .get_or_init(|| ServerPolicy::invalid(detect_timeout)) - .clone() - }); - tracing::debug!(?policy); - policy - }) - .boxed() + let rsp = LimitReceiveFuture::new(limits, client.watch_port(tonic::Request::new(req))) + .await?; + Ok(rsp.map(move |s| { + s.map_ok(move |up| { + // If the server returned an invalid server policy, we + // default to using an invalid policy that causes all + // requests to report an internal error. + let policy = ServerPolicy::try_from(up).unwrap_or_else(|error| { + tracing::warn!(%error, "Server misconfigured"); + INVALID_POLICY + .get_or_init(|| ServerPolicy::invalid(detect_timeout)) + .clone() + }); + tracing::debug!(?policy); + policy + }) + .boxed() })) }) } diff --git a/linkerd/app/inbound/src/policy/config.rs b/linkerd/app/inbound/src/policy/config.rs index 12f8ea1ebe..4249132386 100644 --- a/linkerd/app/inbound/src/policy/config.rs +++ b/linkerd/app/inbound/src/policy/config.rs @@ -1,5 +1,6 @@ use super::{api::Api, DefaultPolicy, GetPolicy, Protocol, ServerPolicy, Store}; use linkerd_app_core::{exp_backoff::ExponentialBackoff, proxy::http, Error}; +use linkerd_tonic_stream::ReceiveLimits; use rangemap::RangeInclusiveSet; use std::{ collections::{HashMap, HashSet}, @@ -36,6 +37,7 @@ impl Config { workload: Arc, client: C, backoff: ExponentialBackoff, + limits: ReceiveLimits, ) -> impl GetPolicy + Clone + Send + Sync + 'static where C: tonic::client::GrpcService, @@ -66,7 +68,7 @@ impl Config { }) => timeout, _ => Duration::from_secs(10), }; - Api::new(workload, detect_timeout, client).into_watch(backoff) + Api::new(workload, limits, detect_timeout, client).into_watch(backoff) }; Store::spawn_discover(default, cache_max_idle_age, watch, ports, opaque_ports) } diff --git a/linkerd/app/inbound/src/server.rs b/linkerd/app/inbound/src/server.rs index 6b45ced8ef..5dd58ad1d3 100644 --- a/linkerd/app/inbound/src/server.rs +++ b/linkerd/app/inbound/src/server.rs @@ -7,6 +7,7 @@ use linkerd_app_core::{ transport::{self, addrs::*}, Error, }; +use linkerd_tonic_stream::ReceiveLimits; use std::{fmt::Debug, sync::Arc}; use tracing::debug_span; @@ -23,6 +24,7 @@ impl Inbound<()> { workload: Arc, client: C, backoff: ExponentialBackoff, + limits: ReceiveLimits, ) -> impl policy::GetPolicy + Clone + Send + Sync + 'static where C: tonic::client::GrpcService, @@ -31,7 +33,10 @@ impl Inbound<()> 
{ C::ResponseBody: Default + Send + 'static, C::Future: Send, { - self.config.policy.clone().build(workload, client, backoff) + self.config + .policy + .clone() + .build(workload, client, backoff, limits) } pub fn mk( diff --git a/linkerd/app/outbound/Cargo.toml b/linkerd/app/outbound/Cargo.toml index 278eeace03..debf270578 100644 --- a/linkerd/app/outbound/Cargo.toml +++ b/linkerd/app/outbound/Cargo.toml @@ -33,6 +33,7 @@ linkerd-proxy-client-policy = { path = "../../proxy/client-policy", features = [ "proto", ] } linkerd-retry = { path = "../../retry" } +linkerd-tonic-stream = { path = "../../tonic-stream" } linkerd-tonic-watch = { path = "../../tonic-watch" } once_cell = "1" parking_lot = "0.12" diff --git a/linkerd/app/outbound/src/lib.rs b/linkerd/app/outbound/src/lib.rs index e3f7e54328..59fe1eaa68 100644 --- a/linkerd/app/outbound/src/lib.rs +++ b/linkerd/app/outbound/src/lib.rs @@ -25,6 +25,7 @@ use linkerd_app_core::{ transport::addrs::*, AddrMatch, Error, ProxyRuntime, }; +use linkerd_tonic_stream::ReceiveLimits; use std::{ collections::{HashMap, HashSet}, fmt::Debug, @@ -141,6 +142,7 @@ impl Outbound<()> { workload: Arc, client: C, backoff: ExponentialBackoff, + limits: ReceiveLimits, ) -> impl policy::GetPolicy where C: tonic::client::GrpcService, @@ -149,7 +151,7 @@ impl Outbound<()> { C::ResponseBody: Default + Send + 'static, C::Future: Send, { - policy::Api::new(workload, Duration::from_secs(10), client) + policy::Api::new(workload, limits, Duration::from_secs(10), client) .into_watch(backoff) .map_result(|response| match response { Err(e) => Err(e.into()), diff --git a/linkerd/app/outbound/src/policy/api.rs b/linkerd/app/outbound/src/policy/api.rs index d4020677c8..e5d1f60e0c 100644 --- a/linkerd/app/outbound/src/policy/api.rs +++ b/linkerd/app/outbound/src/policy/api.rs @@ -9,13 +9,16 @@ use linkerd_app_core::{ Addr, Error, Recover, Result, }; use linkerd_proxy_client_policy::ClientPolicy; +use linkerd_tonic_stream::{LimitReceiveFuture, ReceiveLimits}; use linkerd_tonic_watch::StreamWatch; -use std::{sync::Arc, time}; +use std::sync::Arc; +use tokio::time; #[derive(Clone, Debug)] pub(crate) struct Api { workload: Arc, - detect_timeout: time::Duration, + limits: ReceiveLimits, + default_detect_timeout: time::Duration, client: Client, } @@ -33,10 +36,16 @@ where S::ResponseBody: http::HttpBody + Default + Send + 'static, { - pub(crate) fn new(workload: Arc, detect_timeout: time::Duration, client: S) -> Self { + pub(crate) fn new( + workload: Arc, + limits: ReceiveLimits, + default_detect_timeout: time::Duration, + client: S, + ) -> Self { Self { workload, - detect_timeout, + limits, + default_detect_timeout, client: Client::new(client), } } @@ -77,26 +86,28 @@ where target: Some(target), } }; - let detect_timeout = self.detect_timeout; + + let detect_timeout = self.default_detect_timeout; + let limits = self.limits; let mut client = self.client.clone(); Box::pin(async move { - let rsp = client.watch(tonic::Request::new(req)).await?; - Ok(rsp.map(|updates| { - updates - .map_ok(move |up| { - // If the server returned an invalid client policy, we - // default to using an invalid policy that causes all - // requests to report an internal error. 
- let policy = ClientPolicy::try_from(up).unwrap_or_else(|error| { - tracing::warn!(%error, "Client policy misconfigured"); - INVALID_POLICY - .get_or_init(|| ClientPolicy::invalid(detect_timeout)) - .clone() - }); - tracing::debug!(?policy); - policy - }) - .boxed() + let rsp = + LimitReceiveFuture::new(limits, client.watch(tonic::Request::new(req))).await?; + Ok(rsp.map(move |s| { + s.map_ok(move |up| { + // If the server returned an invalid client policy, we + // default to using an invalid policy that causes all + // requests to report an internal error. + let policy = ClientPolicy::try_from(up).unwrap_or_else(|error| { + tracing::warn!(%error, "Client policy misconfigured"); + INVALID_POLICY + .get_or_init(|| ClientPolicy::invalid(detect_timeout)) + .clone() + }); + tracing::debug!(?policy); + policy + }) + .boxed() })) }) } diff --git a/linkerd/app/src/dst.rs b/linkerd/app/src/dst.rs index ac075af810..7cf8155d2a 100644 --- a/linkerd/app/src/dst.rs +++ b/linkerd/app/src/dst.rs @@ -8,11 +8,13 @@ use linkerd_app_core::{ svc::{self, NewService, ServiceExt}, Error, Recover, }; +use linkerd_tonic_stream::ReceiveLimits; #[derive(Clone, Debug)] pub struct Config { pub control: control::Config, pub context: String, + pub limits: ReceiveLimits, } /// Handles to destination service clients. @@ -58,13 +60,20 @@ impl Config { .new_service(()) .map_err(Error::from); - let profiles = - profiles::Client::new_recover_default(backoff, svc.clone(), self.context.clone()); + let profiles = profiles::Client::new_recover_default( + backoff, + svc.clone(), + self.context.clone(), + self.limits, + ); Ok(Dst { addr, profiles, - resolve: recover::Resolve::new(backoff, api::Resolve::new(svc, self.context)), + resolve: recover::Resolve::new( + backoff, + api::Resolve::new(svc, self.context, self.limits), + ), }) } } diff --git a/linkerd/app/src/env.rs b/linkerd/app/src/env.rs index 2303c52e47..e8cef2026a 100644 --- a/linkerd/app/src/env.rs +++ b/linkerd/app/src/env.rs @@ -8,6 +8,7 @@ use linkerd_app_core::{ transport::{Keepalive, ListenAddr}, Addr, AddrMatch, Conditional, IpNet, }; +use linkerd_tonic_stream::ReceiveLimits; use rangemap::RangeInclusiveSet; use std::{ collections::{HashMap, HashSet}, @@ -368,6 +369,8 @@ pub fn parse_config(strings: &S) -> Result let metrics_retain_idle = parse(strings, ENV_METRICS_RETAIN_IDLE, parse_duration); + let control_receive_limits = mk_control_receive_limits(strings)?; + // DNS let resolv_conf_path = strings.get(ENV_RESOLV_CONF); @@ -666,6 +669,11 @@ pub fn parse_config(strings: &S) -> Result } else { outbound.http_request_queue.failfast_timeout }; + let limits = addr + .addr + .is_loopback() + .then(ReceiveLimits::default) + .unwrap_or(control_receive_limits); super::dst::Config { context: dst_token?.unwrap_or_default(), control: ControlConfig { @@ -676,6 +684,7 @@ pub fn parse_config(strings: &S) -> Result failfast_timeout, }, }, + limits, } }; @@ -692,6 +701,12 @@ pub fn parse_config(strings: &S) -> Result EnvError::InvalidEnvVar })?; + let limits = addr + .addr + .is_loopback() + .then(ReceiveLimits::default) + .unwrap_or(control_receive_limits); + let control = { let connect = if addr.addr.is_loopback() { inbound.proxy.connect.clone() @@ -707,7 +722,12 @@ pub fn parse_config(strings: &S) -> Result }, } }; - policy::Config { control, workload } + + policy::Config { + control, + workload, + limits, + } }; let admin = super::admin::Config { @@ -848,6 +868,35 @@ fn convert_attributes_string_to_map(attributes: String) -> HashMap Result { + const ENV_INIT: &str = 
"LINKERD2_PROXY_CONTROL_STREAM_INITIAL_TIMEOUT"; + const ENV_IDLE: &str = "LINKERD2_PROXY_CONTROL_STREAM_IDLE_TIMEOUT"; + const ENV_LIFE: &str = "LINKERD2_PROXY_CONTROL_STREAM_LIFETIME"; + + let initial = parse(env, ENV_INIT, parse_duration_opt)?.flatten(); + let idle = parse(env, ENV_IDLE, parse_duration_opt)?.flatten(); + let lifetime = parse(env, ENV_LIFE, parse_duration_opt)?.flatten(); + + if initial.unwrap_or(Duration::ZERO) > idle.unwrap_or(Duration::MAX) { + error!("{ENV_INIT} must be less than {ENV_IDLE}"); + return Err(EnvError::InvalidEnvVar); + } + if initial.unwrap_or(Duration::ZERO) > lifetime.unwrap_or(Duration::MAX) { + error!("{ENV_INIT} must be less than {ENV_LIFE}"); + return Err(EnvError::InvalidEnvVar); + } + if idle.unwrap_or(Duration::ZERO) > lifetime.unwrap_or(Duration::MAX) { + error!("{ENV_IDLE} must be less than {ENV_LIFE}"); + return Err(EnvError::InvalidEnvVar); + } + + Ok(ReceiveLimits { + initial, + idle, + lifetime, + }) +} + // === impl Env === impl Strings for Env { @@ -908,6 +957,13 @@ where s.parse().map_err(Into::into) } +fn parse_duration_opt(s: &str) -> Result, ParseError> { + if s.is_empty() { + return Ok(None); + } + parse_duration(s).map(Some) +} + fn parse_duration(s: &str) -> Result { use regex::Regex; @@ -1479,4 +1535,45 @@ mod tests { assert!(dbg!(parse_port_range_set("69420")).is_err()); assert!(dbg!(parse_port_range_set("1-69420")).is_err()); } + + #[test] + fn control_stream_limits() { + impl Strings for HashMap<&'static str, &'static str> { + fn get(&self, key: &str) -> Result, EnvError> { + Ok(self.get(key).map(ToString::to_string)) + } + } + + let mut env = HashMap::default(); + env.insert("LINKERD2_PROXY_CONTROL_STREAM_INITIAL_TIMEOUT", "1s"); + env.insert("LINKERD2_PROXY_CONTROL_STREAM_IDLE_TIMEOUT", "2s"); + env.insert("LINKERD2_PROXY_CONTROL_STREAM_LIFETIME", "3s"); + let limits = mk_control_receive_limits(&env).unwrap(); + assert_eq!(limits.initial, Some(Duration::from_secs(1))); + assert_eq!(limits.idle, Some(Duration::from_secs(2))); + assert_eq!(limits.lifetime, Some(Duration::from_secs(3))); + + env.insert("LINKERD2_PROXY_CONTROL_STREAM_INITIAL_TIMEOUT", ""); + env.insert("LINKERD2_PROXY_CONTROL_STREAM_IDLE_TIMEOUT", ""); + env.insert("LINKERD2_PROXY_CONTROL_STREAM_LIFETIME", ""); + let limits = mk_control_receive_limits(&env).unwrap(); + assert_eq!(limits.initial, None); + assert_eq!(limits.idle, None); + assert_eq!(limits.lifetime, None); + + env.insert("LINKERD2_PROXY_CONTROL_STREAM_INITIAL_TIMEOUT", "3s"); + env.insert("LINKERD2_PROXY_CONTROL_STREAM_IDLE_TIMEOUT", "1s"); + env.insert("LINKERD2_PROXY_CONTROL_STREAM_LIFETIME", ""); + assert!(mk_control_receive_limits(&env).is_err()); + + env.insert("LINKERD2_PROXY_CONTROL_STREAM_INITIAL_TIMEOUT", "3s"); + env.insert("LINKERD2_PROXY_CONTROL_STREAM_IDLE_TIMEOUT", ""); + env.insert("LINKERD2_PROXY_CONTROL_STREAM_LIFETIME", "1s"); + assert!(mk_control_receive_limits(&env).is_err()); + + env.insert("LINKERD2_PROXY_CONTROL_STREAM_INITIAL_TIMEOUT", ""); + env.insert("LINKERD2_PROXY_CONTROL_STREAM_IDLE_TIMEOUT", "3s"); + env.insert("LINKERD2_PROXY_CONTROL_STREAM_LIFETIME", "1s"); + assert!(mk_control_receive_limits(&env).is_err()); + } } diff --git a/linkerd/app/src/lib.rs b/linkerd/app/src/lib.rs index 7b91fd8a82..e7021114fa 100644 --- a/linkerd/app/src/lib.rs +++ b/linkerd/app/src/lib.rs @@ -198,12 +198,14 @@ impl Config { policies.workload.clone(), policies.client.clone(), policies.backoff, + policies.limits, ); let outbound_policies = outbound.build_policies( 
policies.workload.clone(), policies.client.clone(), policies.backoff, + policies.limits, ); let dst_addr = dst.addr.clone(); diff --git a/linkerd/app/src/policy.rs b/linkerd/app/src/policy.rs index ee258f96c9..29d6f43290 100644 --- a/linkerd/app/src/policy.rs +++ b/linkerd/app/src/policy.rs @@ -7,6 +7,7 @@ use linkerd_app_core::{ svc::{self, NewService, ServiceExt}, Error, }; +use linkerd_tonic_stream::ReceiveLimits; use std::sync::Arc; @@ -14,6 +15,7 @@ use std::sync::Arc; pub struct Config { pub control: control::Config, pub workload: String, + pub limits: ReceiveLimits, } /// Handles to policy service clients. @@ -28,6 +30,8 @@ pub struct Policy { pub workload: Arc, pub backoff: ExponentialBackoff, + + pub limits: ReceiveLimits, } // === impl Config === @@ -64,6 +68,7 @@ impl Config { client, workload, backoff, + limits: self.limits, }) } } diff --git a/linkerd/proxy/api-resolve/Cargo.toml b/linkerd/proxy/api-resolve/Cargo.toml index 29012a30fa..0ecde8cba5 100644 --- a/linkerd/proxy/api-resolve/Cargo.toml +++ b/linkerd/proxy/api-resolve/Cargo.toml @@ -10,13 +10,13 @@ Implements the Resolve trait using the proxy's gRPC API """ [dependencies] -async-stream = "0.3" futures = { version = "0.3", default-features = false } linkerd-addr = { path = "../../addr" } linkerd-error = { path = "../../error" } linkerd2-proxy-api = { version = "0.12", features = ["destination"] } linkerd-proxy-core = { path = "../core" } linkerd-stack = { path = "../../stack" } +linkerd-tonic-stream = { path = "../../tonic-stream" } linkerd-tls = { path = "../../tls" } http = "0.2" http-body = "0.4" diff --git a/linkerd/proxy/api-resolve/src/resolve.rs b/linkerd/proxy/api-resolve/src/resolve.rs index 16cf891f90..1a23b8292d 100644 --- a/linkerd/proxy/api-resolve/src/resolve.rs +++ b/linkerd/proxy/api-resolve/src/resolve.rs @@ -1,15 +1,10 @@ -use crate::{ - api::destination as api, - core::resolve::{self, Update}, - metadata::Metadata, - pb, ConcreteAddr, -}; +use crate::{api::destination as api, core::resolve::Update, metadata::Metadata, pb, ConcreteAddr}; use api::destination_client::DestinationClient; -use async_stream::try_stream; use futures::prelude::*; use http_body::Body; use linkerd_error::Error; use linkerd_stack::Param; +use linkerd_tonic_stream::{LimitReceiveFuture, ReceiveLimits}; use std::pin::Pin; use std::task::{Context, Poll}; use tonic::{self as grpc, body::BoxBody, client::GrpcService}; @@ -18,8 +13,9 @@ use tracing::{debug, info, trace}; #[derive(Clone)] pub struct Resolve { - service: DestinationClient, + client: DestinationClient, context_token: String, + limits: ReceiveLimits, } // === impl Resolve === @@ -32,10 +28,11 @@ where ::Error: Into + Send, S::Future: Send, { - pub fn new(svc: S, context_token: String) -> Self { + pub fn new(svc: S, context_token: String, limits: ReceiveLimits) -> Self { Self { - service: DestinationClient::new(svc), + client: DestinationClient::new(svc), context_token, + limits, } } } @@ -72,61 +69,56 @@ where context_token: self.context_token.clone(), ..Default::default() }; - let mut client = self.service.clone(); + + let limits = self.limits; + let mut client = self.client.clone(); Box::pin(async move { - // Wait for the server to respond once before returning a stream. This let's us eagerly - // detect errors (like InvalidArgument). 
- let rsp = client.get(grpc::Request::new(req)).await?; + let rsp = LimitReceiveFuture::new(limits, client.get(grpc::Request::new(req))).await?; trace!(metadata = ?rsp.metadata()); - let stream: UpdatesStream = Box::pin(resolution(rsp.into_inner())); - Ok(stream) + Ok(rsp + .into_inner() + .try_filter_map(|up| futures::future::ok::<_, _>(mk_update(up))) + .boxed()) }) } } -fn resolution( - mut stream: tonic::Streaming, -) -> impl Stream, grpc::Status>> { - try_stream! { - while let Some(update) = stream.next().await { - match update?.update { - Some(api::update::Update::Add(api::WeightedAddrSet { - addrs, - metric_labels, - })) => { - let addr_metas = addrs - .into_iter() - .filter_map(|addr| pb::to_addr_meta(addr, &metric_labels)) - .collect::>(); - if !addr_metas.is_empty() { - debug!(endpoints = %addr_metas.len(), "Add"); - yield Update::Add(addr_metas); - } - } - - Some(api::update::Update::Remove(api::AddrSet { addrs })) => { - let sock_addrs = addrs - .into_iter() - .filter_map(pb::to_sock_addr) - .collect::>(); - if !sock_addrs.is_empty() { - debug!(endpoints = %sock_addrs.len(), "Remove"); - yield Update::Remove(sock_addrs); - } - } - - Some(api::update::Update::NoEndpoints(api::NoEndpoints { exists })) => { - info!("No endpoints"); - let update = if exists { - Update::Reset(Vec::new()) - } else { - Update::DoesNotExist - }; - yield update; - } +fn mk_update(up: api::Update) -> Option> { + match up.update? { + api::update::Update::Add(api::WeightedAddrSet { + addrs, + metric_labels, + }) => { + let addr_metas = addrs + .into_iter() + .filter_map(|addr| pb::to_addr_meta(addr, &metric_labels)) + .collect::>(); + if !addr_metas.is_empty() { + debug!(endpoints = %addr_metas.len(), "Add"); + return Some(Update::Add(addr_metas)); + } + } - None => {} // continue + api::update::Update::Remove(api::AddrSet { addrs }) => { + let sock_addrs = addrs + .into_iter() + .filter_map(pb::to_sock_addr) + .collect::>(); + if !sock_addrs.is_empty() { + debug!(endpoints = %sock_addrs.len(), "Remove"); + return Some(Update::Remove(sock_addrs)); } } + + api::update::Update::NoEndpoints(api::NoEndpoints { exists }) => { + info!("No endpoints"); + return Some(if exists { + Update::Reset(Vec::new()) + } else { + Update::DoesNotExist + }); + } } + + None } diff --git a/linkerd/service-profiles/Cargo.toml b/linkerd/service-profiles/Cargo.toml index 6c6b6d84d6..2242fc516b 100644 --- a/linkerd/service-profiles/Cargo.toml +++ b/linkerd/service-profiles/Cargo.toml @@ -20,6 +20,7 @@ linkerd-error = { path = "../error" } linkerd-http-box = { path = "../http-box" } linkerd-proxy-api-resolve = { path = "../proxy/api-resolve" } linkerd-stack = { path = "../stack" } +linkerd-tonic-stream = { path = "../tonic-stream" } linkerd-tonic-watch = { path = "../tonic-watch" } linkerd2-proxy-api = { version = "0.12", features = ["destination"] } once_cell = "1.17" diff --git a/linkerd/service-profiles/src/client.rs b/linkerd/service-profiles/src/client.rs index 0aeb71c62b..b360890b16 100644 --- a/linkerd/service-profiles/src/client.rs +++ b/linkerd/service-profiles/src/client.rs @@ -4,6 +4,7 @@ use http_body::Body; use linkerd2_proxy_api::destination::{self as api, destination_client::DestinationClient}; use linkerd_error::{Infallible, Recover}; use linkerd_stack::{Param, Service}; +use linkerd_tonic_stream::{LimitReceiveFuture, ReceiveLimits}; use linkerd_tonic_watch::StreamWatch; use std::{ sync::Arc, @@ -23,6 +24,7 @@ pub struct Client { struct Inner { client: DestinationClient, context_token: Arc, + limits: ReceiveLimits, } 
// === impl Client === @@ -38,9 +40,14 @@ where R: Recover + Send + Clone + 'static, R::Backoff: Unpin + Send, { - pub fn new(recover: R, inner: S, context_token: impl Into>) -> Self { + pub fn new( + recover: R, + inner: S, + context_token: impl Into>, + limits: ReceiveLimits, + ) -> Self { Self { - watch: StreamWatch::new(recover, Inner::new(context_token.into(), inner)), + watch: StreamWatch::new(recover, Inner::new(context_token.into(), limits, inner)), } } @@ -48,8 +55,9 @@ where recover: R, inner: S, context_token: impl Into>, + limits: ReceiveLimits, ) -> RecoverDefault { - RecoverDefault::new(Self::new(recover, inner, context_token)) + RecoverDefault::new(Self::new(recover, inner, context_token, limits)) } } @@ -109,9 +117,10 @@ where Into> + Send, S::Future: Send, { - fn new(context_token: Arc, inner: S) -> Self { + fn new(context_token: Arc, limits: ReceiveLimits, inner: S) -> Self { Self { context_token, + limits, client: DestinationClient::new(inner), } } @@ -142,12 +151,15 @@ where ..Default::default() }; + // TODO(ver): Record metrics on requests/errors/etc per addr. let mut client = self.client.clone(); + let limits = self.limits; + let port = addr.port(); Box::pin(async move { - let rsp = client.get_profile(req).await?; - Ok(rsp.map(|s| { - Box::pin(s.map_ok(move |p| proto::convert_profile(p, addr.port()))) as InnerStream - })) + // Limit the amount of time we spend waiting for the first + // profile update. + let rsp = LimitReceiveFuture::new(limits, client.get_profile(req)).await?; + Ok(rsp.map(move |rsp| rsp.map_ok(move |p| proto::convert_profile(p, port)).boxed())) }) } } diff --git a/linkerd/tonic-stream/Cargo.toml b/linkerd/tonic-stream/Cargo.toml new file mode 100644 index 0000000000..1a0ffc7655 --- /dev/null +++ b/linkerd/tonic-stream/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "linkerd-tonic-stream" +version = "0.1.0" +authors = ["Linkerd Developers "] +license = "Apache-2.0" +edition = "2021" +publish = false + +[dependencies] +futures = { version = "0.3", default-features = false } +linkerd-stack = { path = "../stack" } +pin-project = "1" +tonic = { version = "0.10", default-features = false } +tokio = { version = "1", features = ["time"] } +tracing = "0.1" + +[dev-dependencies] +tokio = { version = "1", features = ["macros"] } +tokio-test = "0.4" +tokio-stream = { version = "0.1", features = ["sync"] } +linkerd-tracing = { path = "../tracing" } diff --git a/linkerd/tonic-stream/src/lib.rs b/linkerd/tonic-stream/src/lib.rs new file mode 100644 index 0000000000..ded24ad0a1 --- /dev/null +++ b/linkerd/tonic-stream/src/lib.rs @@ -0,0 +1,334 @@ +use futures::FutureExt; +use linkerd_stack::{layer, ExtractParam, NewService, Service}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::time; + +#[derive(Clone, Copy, Debug, Default)] +pub struct ReceiveLimits { + /// Bounds the amount of time until a gRPC stream's initial item is + /// received. + pub initial: Option, + + /// Bounds the amount of time between received gRPC stream items. + pub idle: Option, + + /// Bounds the total lifetime of a gRPC stream. 
+ pub lifetime: Option, +} + +#[derive(Clone, Debug)] +pub struct NewLimitReceive { + inner: N, + params: P, +} + +#[derive(Clone, Debug)] +pub struct LimitReceive { + inner: S, + limits: ReceiveLimits, +} + +#[pin_project::pin_project] +#[derive(Debug)] +pub struct LimitReceiveFuture { + #[pin] + inner: F, + limits: ReceiveLimits, + recv: Option>>, +} + +#[pin_project::pin_project] +#[derive(Debug)] +pub struct LimitReceiveStream { + #[pin] + inner: S, + + #[pin] + lifetime: time::Sleep, + + recv: Pin>, + recv_timeout: Option, + recv_init: bool, +} + +// === impl NewLimitReceive === + +impl NewLimitReceive { + pub fn layer_via(params: P) -> impl layer::Layer + Clone { + layer::mk(move |inner| Self { + inner, + params: params.clone(), + }) + } +} + +impl NewService for NewLimitReceive +where + N: NewService, + P: ExtractParam, +{ + type Service = LimitReceive; + + fn new_service(&self, target: T) -> Self::Service { + let limits = self.params.extract_param(&target); + let inner = self.inner.new_service(target); + LimitReceive { inner, limits } + } +} + +// === impl LimitReceive ===` + +impl LimitReceive { + pub fn new(limits: ReceiveLimits, inner: S) -> Self + where + S: Service, Response = tonic::Response, Error = tonic::Status>, + { + Self { inner, limits } + } +} + +impl Service> for LimitReceive +where + S: Service, Response = tonic::Response, Error = tonic::Status>, +{ + type Response = tonic::Response>; + type Error = tonic::Status; + type Future = LimitReceiveFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: tonic::Request) -> Self::Future { + let inner = self.inner.call(req); + LimitReceiveFuture::new(self.limits, inner) + } +} + +// === impl LimitReceiveFuture === + +impl LimitReceiveFuture +where + F: Future, tonic::Status>>, +{ + pub fn new(limits: ReceiveLimits, inner: F) -> Self { + let recv = Some(Box::pin(time::sleep( + limits + .initial + .or(limits.idle) + .unwrap_or(time::Duration::MAX), + ))); + Self { + inner, + limits, + recv, + } + } +} + +impl Future for LimitReceiveFuture +where + F: Future, tonic::Status>>, +{ + type Output = Result>, tonic::Status>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + let rsp = if let Poll::Ready(res) = this.inner.poll(cx) { + res? + } else { + futures::ready!(this.recv.as_mut().unwrap().poll_unpin(cx)); + return Poll::Ready(Err(tonic::Status::deadline_exceeded( + "initial item not received within timeout", + ))); + }; + + let rsp = rsp.map(|inner| LimitReceiveStream { + inner, + recv: this.recv.take().unwrap(), + recv_timeout: this.limits.idle, + recv_init: true, + lifetime: time::sleep(this.limits.lifetime.unwrap_or(time::Duration::MAX)), + }); + Poll::Ready(Ok(rsp)) + } +} + +// === impl ReceiveStream === + +impl futures::Stream for LimitReceiveStream +where + S: futures::TryStream, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + // If the lifetime has expired, end the stream. + if this.lifetime.poll(cx).is_ready() { + return Poll::Ready(None); + } + + // Every time the inner stream yields an item, reset the receive + // timeout. 
+ if let Poll::Ready(res) = this.inner.try_poll_next(cx) { + *this.recv_init = false; + if let Some(timeout) = this.recv_timeout { + if let Some(t) = time::Instant::now().checked_add(*timeout) { + this.recv.as_mut().reset(t); + } else { + tracing::error!("Receive timeout overflowed; ignoring") + } + } + return Poll::Ready(res); + } + + // If the receive timeout has expired, end the stream. + if this.recv.poll_unpin(cx).is_ready() { + return Poll::Ready(this.recv_init.then(|| { + Err(tonic::Status::deadline_exceeded( + "Initial update not received within timeout", + )) + })); + } + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::prelude::*; + use linkerd_stack::ServiceExt; + use tokio_stream::wrappers::ReceiverStream; + + /// Tests that the initial timeout bounds the response from the inner + /// service. + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn init_timeout_response() { + let _trace = linkerd_tracing::test::trace_init(); + + let limits = ReceiveLimits { + initial: Some(time::Duration::from_millis(1)), + ..Default::default() + }; + let svc = linkerd_stack::service_fn(|_: tonic::Request<()>| { + futures::future::pending::< + tonic::Result>>>, + >() + }); + let svc = LimitReceive::new(limits, svc); + + let status = svc.oneshot(tonic::Request::new(())).await.unwrap_err(); + assert_eq!(status.code(), tonic::Code::DeadlineExceeded); + } + + /// Tests that the initial timeout bounds the first item received from the + /// inner response stream. + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn init_timeout_recv() { + let _trace = linkerd_tracing::test::trace_init(); + + let limits = ReceiveLimits { + initial: Some(time::Duration::from_millis(1)), + ..Default::default() + }; + let svc = linkerd_stack::service_fn(|_: tonic::Request<()>| { + futures::future::ok::<_, tonic::Status>(tonic::Response::new( + futures::stream::pending::>(), + )) + }); + let svc = LimitReceive::new(limits, svc); + + let rsp = svc + .oneshot(tonic::Request::new(())) + .await + .unwrap() + .into_inner(); + tokio::pin!(rsp); + let status = rsp.try_next().await.unwrap_err(); + assert_eq!(status.code(), tonic::Code::DeadlineExceeded); + } + + /// Tests that the receive timeout bounds idleness after the initial update. + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn recv_timeout() { + let _trace = linkerd_tracing::test::trace_init(); + + let limits = ReceiveLimits { + initial: Some(time::Duration::from_millis(1)), + idle: Some(time::Duration::from_millis(1)), + ..Default::default() + }; + let (tx, rx) = tokio::sync::mpsc::channel::>(2); + let svc = { + let mut rx = Some(rx); + linkerd_stack::service_fn(move |_: tonic::Request<()>| { + futures::future::ok::<_, tonic::Status>(tonic::Response::new(ReceiverStream::new( + rx.take().unwrap(), + ))) + }) + }; + let svc = LimitReceive::new(limits, svc); + + let rsp = svc + .oneshot(tonic::Request::new(())) + .await + .unwrap() + .into_inner(); + tokio::pin!(rsp); + + tx.send(Ok(())).await.unwrap(); + assert!(rsp.try_next().await.is_ok()); + + let res = rsp.try_next().await.expect("stream should not error"); + assert_eq!(res, None); + } + + /// Tests that the lifetime bounds the total duration of the stream. 
+ #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn lifetime() { + let _trace = linkerd_tracing::test::trace_init(); + + let limits = ReceiveLimits { + lifetime: Some(time::Duration::from_millis(10)), + ..Default::default() + }; + let (tx, rx) = tokio::sync::mpsc::channel::>(2); + let svc = { + let mut rx = Some(rx); + linkerd_stack::service_fn(move |_: tonic::Request<()>| { + futures::future::ok::<_, tonic::Status>(tonic::Response::new(ReceiverStream::new( + rx.take().unwrap(), + ))) + }) + }; + let svc = LimitReceive::new(limits, svc); + + let rsp = svc + .oneshot(tonic::Request::new(())) + .await + .unwrap() + .into_inner(); + tokio::pin!(rsp); + + tx.send(Ok(())).await.unwrap(); + time::sleep(time::Duration::from_millis(5)).await; + assert!(rsp.try_next().await.is_ok()); + + tx.send(Ok(())).await.unwrap(); + time::sleep(time::Duration::from_millis(9)).await; + assert!(rsp.try_next().await.is_ok()); + + tx.send(Ok(())).await.unwrap(); + time::sleep(time::Duration::from_millis(10)).await; + assert!(rsp.try_next().await.is_ok()); + } +}
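
A minimal usage sketch of the new `linkerd-tonic-stream` API, mirroring the crate's unit tests rather than the real policy and destination clients: the channel-backed service and unit updates below are stand-ins, and the example assumes tokio's `rt` and `macros` features in addition to the crate's own dependencies. The durations are the kind of values `mk_control_receive_limits` would build from `LINKERD2_PROXY_CONTROL_STREAM_INITIAL_TIMEOUT`, `LINKERD2_PROXY_CONTROL_STREAM_IDLE_TIMEOUT`, and `LINKERD2_PROXY_CONTROL_STREAM_LIFETIME`; when the control-plane address is loopback, `env.rs` falls back to `ReceiveLimits::default()` (no limits).

```rust
// Sketch only: stand-in service and unit updates, not the real controller client.
use futures::prelude::*;
use linkerd_stack::ServiceExt;
use linkerd_tonic_stream::{LimitReceive, ReceiveLimits};
use tokio::time;
use tokio_stream::wrappers::ReceiverStream;

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), tonic::Status> {
    // Roughly what the proxy would build from
    //   LINKERD2_PROXY_CONTROL_STREAM_INITIAL_TIMEOUT=10s
    //   LINKERD2_PROXY_CONTROL_STREAM_IDLE_TIMEOUT=1m
    //   LINKERD2_PROXY_CONTROL_STREAM_LIFETIME=1h
    // Unset or empty variables leave the corresponding limit as `None`.
    let limits = ReceiveLimits {
        initial: Some(time::Duration::from_secs(10)),
        idle: Some(time::Duration::from_secs(60)),
        lifetime: Some(time::Duration::from_secs(3600)),
    };

    // Stand-in for a controller client: responds immediately with a
    // channel-backed stream of unit "updates".
    let (tx, rx) = tokio::sync::mpsc::channel::<tonic::Result<()>>(2);
    let svc = {
        let mut rx = Some(rx);
        linkerd_stack::service_fn(move |_: tonic::Request<()>| {
            futures::future::ok::<_, tonic::Status>(tonic::Response::new(ReceiverStream::new(
                rx.take().unwrap(),
            )))
        })
    };

    // The wrapper enforces `initial` on the response and first item, `idle`
    // between items, and `lifetime` on the stream as a whole.
    let svc = LimitReceive::new(limits, svc);
    let rsp = svc.oneshot(tonic::Request::new(())).await?.into_inner();
    tokio::pin!(rsp);

    tx.send(Ok(())).await.expect("receiver held open");
    let first = rsp.try_next().await?;
    assert!(first.is_some(), "first update arrives within `initial`");

    // If a limit elapses instead: missing the initial deadline surfaces as a
    // `DeadlineExceeded` status, while idle or lifetime expiry ends the
    // stream (`Ok(None)`), letting the surrounding watch re-establish it.
    Ok(())
}
```

Keeping the limits in a plain `ReceiveLimits` value, rather than baking them into any one watch, is what lets the same wrapper bound the inbound policy, outbound policy, service-profile, and destination streams without changing their response types.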