From 63d6fe1e87821eb1cd75e5e9945664b8a6b9cf02 Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Tue, 16 Aug 2022 16:47:01 -0700 Subject: [PATCH] Use tower-http layer to validate ACCEPT Signed-off-by: Daniele Ahmed --- .../ServerHttpBoundProtocolGenerator.kt | 22 ------------------- .../examples/pokemon-service/src/main.rs | 3 ++- 2 files changed, 2 insertions(+), 23 deletions(-) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index 088ed67d9dc..464df06deb9 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -164,26 +164,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val operationName = symbolProvider.toSymbol(operationShape).name val inputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" - val verifyResponseContentType = writable { - httpBindingResolver.responseContentType(operationShape)?.also { contentType -> - rustTemplate( - """ - if let Some(headers) = req.headers() { - if let Some(accept) = headers.get(#{http}::header::ACCEPT) { - if accept != "$contentType" { - return Err(#{RuntimeError} { - protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()}, - kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::NotAcceptable, - }) - } - } - } - """, - *codegenScope, - ) - } - } - // Implement `from_request` trait for input types. rustTemplate( """ @@ -197,7 +177,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator( B::Data: Send, #{RequestRejection} : From<::Error> { - #{verify_response_content_type:W} #{parse_request}(req) .await .map($inputName) @@ -213,7 +192,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator( *codegenScope, "I" to inputSymbol, "parse_request" to serverParseRequest(operationShape), - "verify_response_content_type" to verifyResponseContentType, ) // Implement `into_response` for output types. diff --git a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/main.rs b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/main.rs index e1413c68557..57921bf335d 100644 --- a/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/main.rs +++ b/rust-runtime/aws-smithy-http-server/examples/pokemon-service/src/main.rs @@ -14,7 +14,7 @@ use pokemon_service::{ }; use pokemon_service_server_sdk::operation_registry::OperationRegistryBuilder; use tower::ServiceBuilder; -use tower_http::trace::TraceLayer; +use tower_http::{trace::TraceLayer, validate_request::ValidateRequestHeaderLayer}; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] @@ -52,6 +52,7 @@ pub async fn main() { let app = app.layer( ServiceBuilder::new() .layer(TraceLayer::new_for_http()) + .layer(ValidateRequestHeaderLayer::accept("application/json")) .layer(AddExtensionLayer::new(shared_state)), );