Skip to content

Commit

Permalink
Merge branch 'main' into 1671
Browse files Browse the repository at this point in the history
  • Loading branch information
82marbag authored Nov 15, 2022
2 parents 427e6bd + b2528a1 commit 0a7f8a6
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,11 @@ class ServerServiceGeneratorV2(
#{SmithyHttpServer}::routing::IntoMakeService::new(self)
}
/// Converts [`$serviceName`] into a [`MakeService`](tower::make::MakeService) with [`ConnectInfo`](#{SmithyHttpServer}::routing::into_make_service_with_connect_info::ConnectInfo).
pub fn into_make_service_with_connect_info<C>(self) -> #{SmithyHttpServer}::routing::IntoMakeServiceWithConnectInfo<Self, C> {
#{SmithyHttpServer}::routing::IntoMakeServiceWithConnectInfo::new(self)
}
/// Applies a [`Layer`](#{Tower}::Layer) uniformly to all routes.
pub fn layer<L>(self, layer: &L) -> $serviceName<L::Service>
where
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -532,9 +532,11 @@ class ServerProtocolTestGenerator(
// corresponding Unicode code point. That is the "form feed" 0x0c character. When printing it,
// it gets written as "\f", which is an invalid Rust escape sequence: https://static.rust-lang.org/doc/master/reference.html#literals
// So we need to write the corresponding Rust Unicode escape sequence to make the program compile.
"#{SmithyHttpServer}::body::Body::from(#{Bytes}::from_static(${
body.replace("\u000c", "\\u{000c}").dq()
}.as_bytes()))"
//
// We also escape to avoid interactions with templating in the case where the body contains `#`.
val sanitizedBody = escape(body.replace("\u000c", "\\u{000c}")).dq()
"#{SmithyHttpServer}::body::Body::from(#{Bytes}::from_static($sanitizedBody.as_bytes()))"
} else {
"#{SmithyHttpServer}::body::Body::empty()"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ path = "src/bin/pokemon-service-tls.rs"
name = "pokemon-service-lambda"
path = "src/bin/pokemon-service-lambda.rs"

[[bin]]
name = "pokemon-service-connect-info"
path = "src/bin/pokemon-service-connect-info.rs"

[dependencies]
async-stream = "0.3"
clap = { version = "~3.2.1", features = ["derive"] }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

use clap::Parser;
use pokemon_service::{
capture_pokemon, check_health, do_nothing, get_pokemon_species, get_server_statistics, setup_tracing,
};

#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
/// Hyper server bind address.
#[clap(short, long, action, default_value = "127.0.0.1")]
address: String,
/// Hyper server bind port.
#[clap(short, long, action, default_value = "13734")]
port: u16,
}

/// Retrieves the user's storage. No authentication required for locals.
pub async fn get_storage_with_local_approved(
input: pokemon_service_server_sdk::input::GetStorageInput,
connect_info: aws_smithy_http_server::Extension<aws_smithy_http_server::routing::ConnectInfo<std::net::SocketAddr>>,
) -> Result<pokemon_service_server_sdk::output::GetStorageOutput, pokemon_service_server_sdk::error::GetStorageError> {
tracing::debug!("attempting to authenticate storage user");
let local = connect_info.0 .0.ip() == "127.0.0.1".parse::<std::net::IpAddr>().unwrap();

// We currently support Ash: he has nothing stored
if input.user == "ash" && input.passcode == "pikachu123" {
return Ok(pokemon_service_server_sdk::output::GetStorageOutput { collection: vec![] });
}
// We support trainers in our gym
if local {
tracing::info!("welcome back");
return Ok(pokemon_service_server_sdk::output::GetStorageOutput {
collection: vec![
String::from("bulbasaur"),
String::from("charmander"),
String::from("squirtle"),
],
});
}
tracing::debug!("authentication failed");
Err(pokemon_service_server_sdk::error::GetStorageError::NotAuthorized(
pokemon_service_server_sdk::error::NotAuthorized {},
))
}

#[tokio::main]
async fn main() {
let args = Args::parse();
setup_tracing();
let app = pokemon_service_server_sdk::service::PokemonService::builder()
.get_pokemon_species(get_pokemon_species)
.get_storage(get_storage_with_local_approved)
.get_server_statistics(get_server_statistics)
.capture_pokemon(capture_pokemon)
.do_nothing(do_nothing)
.check_health(check_health)
.build();

// Start the [`hyper::Server`].
let bind: std::net::SocketAddr = format!("{}:{}", args.address, args.port)
.parse()
.expect("unable to parse the server bind address and port");
let server = hyper::Server::bind(&bind).serve(app.into_make_service_with_connect_info::<std::net::SocketAddr>());

// Run forever-ish...
if let Err(err) = server.await {
eprintln!("server error: {}", err);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ pub async fn get_pokemon_species(
}
}

/// Retrieves the users storage.
/// Retrieves the user's storage.
pub async fn get_storage(
input: input::GetStorageInput,
_state: Extension<Arc<State>>,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

// This code was copied and then modified from Tokio's Axum.

/* Copyright (c) 2021 Tower Contributors
*
* Permission is hereby granted, free of charge, to any
* person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the
* Software without restriction, including without
* limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of
* the Software, and to permit persons to whom the Software
* is furnished to do so, subject to the following
* conditions:
*
* The above copyright notice and this permission notice
* shall be included in all copies or substantial portions
* of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
* ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
* TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
* PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
* SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
* OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
* IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/

use std::{
convert::Infallible,
fmt,
future::ready,
marker::PhantomData,
net::SocketAddr,
task::{Context, Poll},
};

use http::request::Parts;
use hyper::server::conn::AddrStream;
use tower::{Layer, Service};
use tower_http::add_extension::{AddExtension, AddExtensionLayer};

use crate::{request::FromParts, Extension};

/// A [`MakeService`] created from a router.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub struct IntoMakeServiceWithConnectInfo<S, C> {
inner: S,
_connect_info: PhantomData<fn() -> C>,
}

impl<S, C> IntoMakeServiceWithConnectInfo<S, C> {
pub fn new(svc: S) -> Self {
Self {
inner: svc,
_connect_info: PhantomData,
}
}
}

impl<S, C> fmt::Debug for IntoMakeServiceWithConnectInfo<S, C>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IntoMakeServiceWithConnectInfo")
.field("inner", &self.inner)
.finish()
}
}

impl<S, C> Clone for IntoMakeServiceWithConnectInfo<S, C>
where
S: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
_connect_info: PhantomData,
}
}
}

/// Trait that connected IO resources implement and use to produce information
/// about the connection.
///
/// The goal for this trait is to allow users to implement custom IO types that
/// can still provide the same connection metadata.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub trait Connected<T>: Clone {
/// Create type holding information about the connection.
fn connect_info(target: T) -> Self;
}

impl Connected<&AddrStream> for SocketAddr {
fn connect_info(target: &AddrStream) -> Self {
target.remote_addr()
}
}

impl<S, C, T> Service<T> for IntoMakeServiceWithConnectInfo<S, C>
where
S: Clone,
C: Connected<T>,
{
type Response = AddExtension<S, ConnectInfo<C>>;
type Error = Infallible;
type Future = ResponseFuture<S, C>;

#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}

fn call(&mut self, target: T) -> Self::Future {
let connect_info = ConnectInfo(C::connect_info(target));
let svc = AddExtensionLayer::new(connect_info).layer(self.inner.clone());
ResponseFuture::new(ready(Ok(svc)))
}
}

opaque_future! {
/// Response future for [`IntoMakeServiceWithConnectInfo`].
pub type ResponseFuture<S, C> =
std::future::Ready<Result<AddExtension<S, ConnectInfo<C>>, Infallible>>;
}

/// Extractor for getting connection information produced by a `Connected`.
///
/// Note this extractor requires you to use
/// [`Router::into_make_service_with_connect_info`] to run your app
/// otherwise it will fail at runtime.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
#[derive(Clone, Debug)]
pub struct ConnectInfo<T>(pub T);

impl<P, T> FromParts<P> for ConnectInfo<T>
where
T: Send + Sync + 'static,
{
type Rejection = <Extension<Self> as FromParts<P>>::Rejection;

fn from_parts(parts: &mut Parts) -> Result<Self, Self::Rejection> {
let Extension(connect_info) = <Extension<Self> as FromParts<P>>::from_parts(parts)?;
Ok(connect_info)
}
}
18 changes: 17 additions & 1 deletion rust-runtime/aws-smithy-http-server/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use tower_http::map_response_body::MapResponseBodyLayer;

mod future;
mod into_make_service;
mod into_make_service_with_connect_info;
mod lambda_handler;

#[doc(hidden)]
Expand All @@ -39,7 +40,10 @@ mod route;
pub(crate) mod tiny_map;

pub use self::lambda_handler::LambdaHandler;
pub use self::{future::RouterFuture, into_make_service::IntoMakeService, route::Route};
pub use self::{
future::RouterFuture, into_make_service::IntoMakeService, into_make_service_with_connect_info::ConnectInfo,
into_make_service_with_connect_info::IntoMakeServiceWithConnectInfo, route::Route,
};

/// The router is a [`tower::Service`] that routes incoming requests to other `Service`s
/// based on the request's URI and HTTP method or on some specific header setting the target operation.
Expand Down Expand Up @@ -116,6 +120,18 @@ where
IntoMakeService::new(self)
}

/// Convert this router into a [`MakeService`], that is a [`Service`] whose
/// response is another service, and provides a [`ConnectInfo`] object to service handlers.
///
/// This is useful when running your application with hyper's
/// [`Server`].
///
/// [`Server`]: hyper::server::Server
/// [`MakeService`]: tower::make::MakeService
pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
IntoMakeServiceWithConnectInfo::new(self)
}

/// Apply a [`tower::Layer`] to the router.
///
/// All requests to the router will be processed by the layer's
Expand Down

0 comments on commit 0a7f8a6

Please sign in to comment.