Skip to content

Commit

Permalink
Use tower-http layer to validate ACCEPT
Browse files Browse the repository at this point in the history
Signed-off-by: Daniele Ahmed <[email protected]>
  • Loading branch information
82marbag authored and Daniele Ahmed committed Aug 16, 2022
1 parent 7b2aef7 commit 63d6fe1
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,26 +164,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
val operationName = symbolProvider.toSymbol(operationShape).name
val inputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}"

val verifyResponseContentType = writable {
httpBindingResolver.responseContentType(operationShape)?.also { contentType ->
rustTemplate(
"""
if let Some(headers) = req.headers() {
if let Some(accept) = headers.get(#{http}::header::ACCEPT) {
if accept != "$contentType" {
return Err(#{RuntimeError} {
protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()},
kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::NotAcceptable,
})
}
}
}
""",
*codegenScope,
)
}
}

// Implement `from_request` trait for input types.
rustTemplate(
"""
Expand All @@ -197,7 +177,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
B::Data: Send,
#{RequestRejection} : From<<B as #{SmithyHttpServer}::body::HttpBody>::Error>
{
#{verify_response_content_type:W}
#{parse_request}(req)
.await
.map($inputName)
Expand All @@ -213,7 +192,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
*codegenScope,
"I" to inputSymbol,
"parse_request" to serverParseRequest(operationShape),
"verify_response_content_type" to verifyResponseContentType,
)

// Implement `into_response` for output types.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use pokemon_service::{
};
use pokemon_service_server_sdk::operation_registry::OperationRegistryBuilder;
use tower::ServiceBuilder;
use tower_http::trace::TraceLayer;
use tower_http::{trace::TraceLayer, validate_request::ValidateRequestHeaderLayer};

#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
Expand Down Expand Up @@ -52,6 +52,7 @@ pub async fn main() {
let app = app.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(ValidateRequestHeaderLayer::accept("application/json"))
.layer(AddExtensionLayer::new(shared_state)),
);

Expand Down

0 comments on commit 63d6fe1

Please sign in to comment.