Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lifetimes to async traits that take args by reference #3061

Merged
merged 6 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions aws/rust-runtime/aws-config/src/ecs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use crate::http_credential_provider::HttpCredentialProvider;
use crate::provider_config::ProviderConfig;
use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
use aws_smithy_http::endpoint::apply_endpoint;
use aws_smithy_runtime_api::client::dns::{DnsResolver, ResolveDnsError, SharedDnsResolver};
use aws_smithy_runtime_api::client::dns::{ResolveDns, ResolveDnsError, SharedDnsResolver};
use aws_smithy_runtime_api::client::http::HttpConnectorSettings;
use aws_smithy_runtime_api::shared::IntoShared;
use aws_smithy_types::error::display::DisplayErrorContext;
Expand Down Expand Up @@ -272,9 +272,9 @@ impl Builder {

/// Override the DNS resolver used to validate URIs
///
/// URIs must refer to loopback addresses. The [`DnsResolver`](aws_smithy_runtime_api::client::dns::DnsResolver)
/// is used to retrieve IP addresses for a given domain.
pub fn dns(mut self, dns: impl DnsResolver + 'static) -> Self {
/// URIs must refer to loopback addresses. The [`ResolveDns`](aws_smithy_runtime_api::client::dns::ResolveDns)
/// implementation is used to retrieve IP addresses for a given domain.
pub fn dns(mut self, dns: impl ResolveDns + 'static) -> Self {
self.dns = Some(dns.into_shared());
self
}
Expand Down Expand Up @@ -399,7 +399,7 @@ async fn validate_full_uri(
Ok(addr) => addr.is_loopback(),
Err(_domain_name) => {
let dns = dns.ok_or(InvalidFullUriErrorKind::NoDnsResolver)?;
dns.resolve_dns(host.to_owned())
dns.resolve_dns(host)
.await
.map_err(|err| InvalidFullUriErrorKind::DnsLookupFailed(ResolveDnsError::new(err)))?
.iter()
Expand Down Expand Up @@ -751,16 +751,22 @@ mod test {
}
}

impl DnsResolver for TestDns {
fn resolve_dns(&self, name: String) -> DnsFuture {
DnsFuture::ready(Ok(self.addrs.get(&name).unwrap_or(&self.fallback).clone()))
impl ResolveDns for TestDns {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 for verb + noun convention.

fn resolve_dns<'a>(&'a self, name: &'a str) -> DnsFuture<'a>
where
Self: 'a,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we say &'a self, then why do we also need the where bound? What's the difference between those two bounds?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I initially added it because the async-trait crate adds it, but I can't think of a good reason to have it for our use-case, so I've removed it in 591e76f.

{
DnsFuture::ready(Ok(self.addrs.get(name).unwrap_or(&self.fallback).clone()))
}
}

#[derive(Debug)]
struct NeverDns;
impl DnsResolver for NeverDns {
fn resolve_dns(&self, _name: String) -> DnsFuture {
impl ResolveDns for NeverDns {
fn resolve_dns<'a>(&'a self, _name: &'a str) -> DnsFuture<'a>
where
Self: 'a,
{
DnsFuture::new(async {
Never::new().await;
unreachable!()
Expand Down
10 changes: 6 additions & 4 deletions aws/rust-runtime/aws-config/src/imds/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,11 +524,13 @@ struct ImdsEndpointResolver {
}

impl EndpointResolver for ImdsEndpointResolver {
fn resolve_endpoint(&self, _: &EndpointResolverParams) -> EndpointFuture {
let this = self.clone();
fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a>
where
Self: 'a,
{
EndpointFuture::new(async move {
this.endpoint_source
.endpoint(this.mode_override)
self.endpoint_source
.endpoint(self.mode_override.clone())
.await
.map(|uri| Endpoint::builder().url(uri.to_string()).build())
.map_err(|err| err.into())
Expand Down
19 changes: 9 additions & 10 deletions aws/rust-runtime/aws-config/src/imds/client/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,23 +192,22 @@ fn parse_token_response(response: &HttpResponse, now: SystemTime) -> Result<Toke
}

impl IdentityResolver for TokenResolver {
fn resolve_identity(&self, _config_bag: &ConfigBag) -> IdentityFuture {
let this = self.clone();
IdentityFuture::new(async move {
let preloaded_token = this
fn resolve_identity<'a>(&'a self, _config_bag: &'a ConfigBag) -> IdentityFuture<'a>
where
Self: 'a,
{
IdentityFuture::new(async {
let preloaded_token = self
.inner
.cache
.yield_or_clear_if_expired(this.inner.time_source.now())
.yield_or_clear_if_expired(self.inner.time_source.now())
.await;
let token = match preloaded_token {
Some(token) => Ok(token),
None => {
this.inner
self.inner
.cache
.get_or_load(|| {
let this = this.clone();
async move { this.get_token().await }
})
.get_or_load(|| async { self.get_token().await })
.await
}
}?;
Expand Down
4 changes: 3 additions & 1 deletion aws/rust-runtime/aws-inlineable/src/endpoint_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ use tokio::sync::oneshot::{Receiver, Sender};
/// Endpoint reloader
#[must_use]
pub struct ReloadEndpoint {
loader: Box<dyn Fn() -> BoxFuture<(Endpoint, SystemTime), ResolveEndpointError> + Send + Sync>,
loader: Box<
dyn Fn() -> BoxFuture<'static, (Endpoint, SystemTime), ResolveEndpointError> + Send + Sync,
>,
endpoint: Arc<Mutex<Option<ExpiringEndpoint>>>,
error: Arc<Mutex<Option<ResolveEndpointError>>>,
rx: Receiver<()>,
Expand Down
5 changes: 4 additions & 1 deletion aws/rust-runtime/aws-runtime/src/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ pub mod credentials {
}

impl IdentityResolver for CredentialsIdentityResolver {
fn resolve_identity(&self, _config_bag: &ConfigBag) -> IdentityFuture {
fn resolve_identity<'a>(&'a self, _config_bag: &'a ConfigBag) -> IdentityFuture<'a>
where
Self: 'a,
{
let cache = self.credentials_cache.clone();
IdentityFuture::new(async move {
let credentials = cache.as_ref().provide_cached_credentials().await?;
Expand Down
2 changes: 1 addition & 1 deletion rust-runtime/aws-smithy-async/src/future/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ pub mod rendezvous;
pub mod timeout;

/// A boxed future that outputs a `Result<T, E>`.
pub type BoxFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;
pub type BoxFuture<'a, T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>;
42 changes: 34 additions & 8 deletions rust-runtime/aws-smithy-runtime-api/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,54 @@
* SPDX-License-Identifier: Apache-2.0
*/

/// Declares a new-type for a future that is returned from an async trait (prior to stable async trait).
///
/// To declare a future with a static lifetime:
/// ```ignore
/// new_type_future! {
/// doc = "some rustdoc for the future's struct",
/// pub struct NameOfFuture<'static, OutputType, ErrorType>;
/// }
/// ```
///
/// To declare a future with a non-static lifetime:
/// ```ignore
/// new_type_future! {
/// doc = "some rustdoc for the future's struct",
/// pub struct NameOfFuture<'a, OutputType, ErrorType>;
/// }
/// ```
macro_rules! new_type_future {
(
doc = $type_docs:literal,
pub struct $future_name:ident<$output:ty, $err:ty>,
#[doc = $type_docs:literal]
pub struct $future_name:ident<'static, $output:ty, $err:ty>;
) => {
new_type_future!(@internal, $type_docs, $future_name, $output, $err, 'static,);
};
(
#[doc = $type_docs:literal]
pub struct $future_name:ident<$lifetime:lifetime, $output:ty, $err:ty>;
) => {
new_type_future!(@internal, $type_docs, $future_name, $output, $err, $lifetime, <$lifetime>);
};
(@internal, $type_docs:literal, $future_name:ident, $output:ty, $err:ty, $lifetime:lifetime, $($decl_lifetime:tt)*) => {
pin_project_lite::pin_project! {
#[allow(clippy::type_complexity)]
#[doc = $type_docs]
pub struct $future_name {
pub struct $future_name$($decl_lifetime)* {
#[pin]
inner: aws_smithy_async::future::now_or_later::NowOrLater<
Result<$output, $err>,
aws_smithy_async::future::BoxFuture<$output, $err>
aws_smithy_async::future::BoxFuture<$lifetime, $output, $err>
>,
}
}

impl $future_name {
impl$($decl_lifetime)* $future_name$($decl_lifetime)* {
#[doc = concat!("Create a new `", stringify!($future_name), "` with the given future.")]
pub fn new<F>(future: F) -> Self
where
F: std::future::Future<Output = Result<$output, $err>> + Send + 'static,
F: std::future::Future<Output = Result<$output, $err>> + Send + $lifetime,
{
Self {
inner: aws_smithy_async::future::now_or_later::NowOrLater::new(Box::pin(future)),
Expand All @@ -38,7 +64,7 @@ macro_rules! new_type_future {
")]
pub fn new_boxed(
future: std::pin::Pin<
Box<dyn std::future::Future<Output = Result<$output, $err>> + Send>,
Box<dyn std::future::Future<Output = Result<$output, $err>> + Send + $lifetime>,
>,
) -> Self {
Self {
Expand All @@ -54,7 +80,7 @@ macro_rules! new_type_future {
}
}

impl std::future::Future for $future_name {
impl$($decl_lifetime)* std::future::Future for $future_name$($decl_lifetime)* {
type Output = Result<$output, $err>;

fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
Expand Down
55 changes: 17 additions & 38 deletions rust-runtime/aws-smithy-runtime-api/src/client/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,10 @@

use crate::box_error::BoxError;
use crate::impl_shared_conversions;
use aws_smithy_async::future::now_or_later::NowOrLater;
use std::error::Error as StdError;
use std::fmt;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

/// Error that occurs when failing to perform a DNS lookup.
#[derive(Debug)]
Expand Down Expand Up @@ -43,57 +39,40 @@ impl StdError for ResolveDnsError {
}
}

type BoxFuture<T> = aws_smithy_async::future::BoxFuture<T, ResolveDnsError>;

/// New-type for the future returned by the [`DnsResolver`] trait.
pub struct DnsFuture(NowOrLater<Result<Vec<IpAddr>, ResolveDnsError>, BoxFuture<Vec<IpAddr>>>);
impl DnsFuture {
/// Create a new `DnsFuture`
pub fn new(
future: impl Future<Output = Result<Vec<IpAddr>, ResolveDnsError>> + Send + 'static,
) -> Self {
Self(NowOrLater::new(Box::pin(future)))
}

/// Create a `DnsFuture` that is immediately ready
pub fn ready(result: Result<Vec<IpAddr>, ResolveDnsError>) -> Self {
Self(NowOrLater::ready(result))
}
}
impl Future for DnsFuture {
type Output = Result<Vec<IpAddr>, ResolveDnsError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut();
let inner = Pin::new(&mut this.0);
Future::poll(inner, cx)
}
new_type_future! {
#[doc = "New-type for the future returned by the [`ResolveDns`] trait."]
pub struct DnsFuture<'a, Vec<IpAddr>, ResolveDnsError>;
}

/// Trait for resolving domain names
pub trait DnsResolver: fmt::Debug + Send + Sync {
pub trait ResolveDns: fmt::Debug + Send + Sync {
/// Asynchronously resolve the given domain name
fn resolve_dns(&self, name: String) -> DnsFuture;
fn resolve_dns<'a>(&'a self, name: &'a str) -> DnsFuture<'a>
where
Self: 'a;
}

/// Shared DNS resolver
/// Shared instance of [`ResolveDns`].
#[derive(Clone, Debug)]
pub struct SharedDnsResolver(Arc<dyn DnsResolver>);
pub struct SharedDnsResolver(Arc<dyn ResolveDns>);

impl SharedDnsResolver {
/// Create a new `SharedDnsResolver`.
pub fn new(resolver: impl DnsResolver + 'static) -> Self {
pub fn new(resolver: impl ResolveDns + 'static) -> Self {
Self(Arc::new(resolver))
}
}

impl DnsResolver for SharedDnsResolver {
fn resolve_dns(&self, name: String) -> DnsFuture {
impl ResolveDns for SharedDnsResolver {
fn resolve_dns<'a>(&'a self, name: &'a str) -> DnsFuture<'a>
where
Self: 'a,
{
self.0.resolve_dns(name)
}
}

impl_shared_conversions!(convert SharedDnsResolver from DnsResolver using SharedDnsResolver::new);
impl_shared_conversions!(convert SharedDnsResolver from ResolveDns using SharedDnsResolver::new);

#[cfg(test)]
mod tests {
Expand All @@ -102,6 +81,6 @@ mod tests {
#[test]
fn check_send() {
fn is_send<T: Send>() {}
is_send::<DnsFuture>();
is_send::<DnsFuture<'_>>();
}
}
13 changes: 9 additions & 4 deletions rust-runtime/aws-smithy-runtime-api/src/client/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use std::fmt;
use std::sync::Arc;

new_type_future! {
doc = "Future for [`EndpointResolver::resolve_endpoint`].",
pub struct EndpointFuture<Endpoint, BoxError>,
#[doc = "Future for [`EndpointResolver::resolve_endpoint`]."]
pub struct EndpointFuture<'a, Endpoint, BoxError>;
}

/// Parameters originating from the Smithy endpoint ruleset required for endpoint resolution.
Expand Down Expand Up @@ -45,7 +45,9 @@ impl Storable for EndpointResolverParams {
/// Configurable endpoint resolver implementation.
pub trait EndpointResolver: Send + Sync + fmt::Debug {
/// Asynchronously resolves an endpoint to use from the given endpoint parameters.
fn resolve_endpoint(&self, params: &EndpointResolverParams) -> EndpointFuture;
fn resolve_endpoint<'a>(&'a self, params: &'a EndpointResolverParams) -> EndpointFuture<'a>
where
Self: 'a;
}

/// Shared endpoint resolver.
Expand All @@ -62,7 +64,10 @@ impl SharedEndpointResolver {
}

impl EndpointResolver for SharedEndpointResolver {
fn resolve_endpoint(&self, params: &EndpointResolverParams) -> EndpointFuture {
fn resolve_endpoint<'a>(&'a self, params: &'a EndpointResolverParams) -> EndpointFuture<'a>
where
Self: 'a,
{
self.0.resolve_endpoint(params)
}
}
Expand Down
4 changes: 2 additions & 2 deletions rust-runtime/aws-smithy-runtime-api/src/client/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ use std::sync::Arc;
use std::time::Duration;

new_type_future! {
doc = "Future for [`HttpConnector::call`].",
pub struct HttpConnectorFuture<HttpResponse, ConnectorError>,
#[doc = "Future for [`HttpConnector::call`]."]
pub struct HttpConnectorFuture<'static, HttpResponse, ConnectorError>;
}

/// Trait with a `call` function that asynchronously converts a request into a response.
Expand Down
13 changes: 9 additions & 4 deletions rust-runtime/aws-smithy-runtime-api/src/client/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use std::time::SystemTime;
pub mod http;

new_type_future! {
doc = "Future for [`IdentityResolver::resolve_identity`].",
pub struct IdentityFuture<Identity, BoxError>,
#[doc = "Future for [`IdentityResolver::resolve_identity`]."]
pub struct IdentityFuture<'a, Identity, BoxError>;
}

/// Resolver for identities.
Expand All @@ -34,7 +34,9 @@ new_type_future! {
/// There is no fallback to other auth schemes in the absence of an identity.
pub trait IdentityResolver: Send + Sync + Debug {
/// Asynchronously resolves an identity for a request using the given config.
fn resolve_identity(&self, config_bag: &ConfigBag) -> IdentityFuture;
fn resolve_identity<'a>(&'a self, config_bag: &'a ConfigBag) -> IdentityFuture<'a>
where
Self: 'a;
}

/// Container for a shared identity resolver.
Expand All @@ -49,7 +51,10 @@ impl SharedIdentityResolver {
}

impl IdentityResolver for SharedIdentityResolver {
fn resolve_identity(&self, config_bag: &ConfigBag) -> IdentityFuture {
fn resolve_identity<'a>(&'a self, config_bag: &'a ConfigBag) -> IdentityFuture<'a>
where
Self: 'a,
{
self.0.resolve_identity(config_bag)
}
}
Expand Down
Loading
Loading