diff --git a/doc/server/errors.md b/doc/server/errors.md index 181b751ab8..8d947ffb4a 100644 --- a/doc/server/errors.md +++ b/doc/server/errors.md @@ -112,9 +112,41 @@ an error or return a "no match", create error messages and create the response. swapped, e.g. to return responses in a different format (other than plain text), or customise the error messages. The default decode failure handler also has the option to return a `400 Bad Request`, instead of a no-match (ultimately -leading to a `404 Not Found`), when the "shape" of the path matches (that is, the number of segments in the request -and endpoint's paths are the same), but when decoding some part of the path ends in an error. See the -`badRequestOnPathErrorIfPathShapeMatches` in `ServerDefaults`. +leading to a `404 Not Found`), when the "shape" of the path matches (that is, the constant parts and number of segments +in the request and endpoint's paths are the same), but when decoding some part of the path ends in an error. See the +scaladoc for `DefaultDecodeFailureHandler.default` and parameters of `DefaultDecodeFailureHandler.response`. For example: + +```scala mdoc:compile-only +import sttp.tapir._ +import sttp.tapir.server.akkahttp.AkkaHttpServerOptions +import sttp.tapir.server.interceptor.decodefailure.DefaultDecodeFailureHandler +import scala.concurrent.ExecutionContext.Implicits.global + +val myDecodeFailureHandler = DefaultDecodeFailureHandler.default.copy( + respond = DefaultDecodeFailureHandler.respond( + _, + badRequestOnPathErrorIfPathShapeMatches = true, + badRequestOnPathInvalidIfPathShapeMatches = true + ) +) + +val myServerOptions: AkkaHttpServerOptions = AkkaHttpServerOptions + .customiseInterceptors + .decodeFailureHandler(myDecodeFailureHandler) + .options +``` + +Moreover, when using the `DefaultDecodeFailureHandler`, decode failure handling can be overriden on a per-input/output +basis, by setting an attribute. For example: + +```scala mdoc:compile-only +import sttp.tapir._ +// bringing into scope the onDecodeFailureBadRequest extension method +import sttp.tapir.server.interceptor.decodefailure.DefaultDecodeFailureHandler.OnDecodeFailure._ + +// by default, when the customer_id is not an int, the next endpoint would be tried; here, we always return a bad request +endpoint.in("customer" / path[Int]("customer_id").onDecodeFailureBadRequest) +``` ## Customising how error messages are rendered diff --git a/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala b/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala index c20c201a61..c118e5c3ba 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interceptor/decodefailure/DecodeFailureHandler.scala @@ -61,7 +61,21 @@ object DefaultDecodeFailureHandler { * The error messages contain information about the source of the decode error, and optionally the validation error detail that caused * the failure. * + * The default decode failure handler can be customised by providing alternate functions for deciding whether a response should be sent, + * creating the error message and creating the response. + * + * Furthermore, how decode failures are handled can be adjusted globally by changing the flags passed to [[respond]]. By default, if the + * shape of the path for an endpoint matches the request, but decoding a path capture causes an error (e.g. a `path[Int]("amount")` + * cannot be parsed), the next endpoint is tried. However, if there's a validation error (e.g. a `path[Kind]("kind")`, where `Kind` is an + * enum, and a value outside the enumeration values is provided), a 400 response is sent. + * + * Finally, behavior can be adjusted per-endpoint-input, by setting an attribute. Import the [[OnDecodeFailure]] object and use the + * [[OnDecodeFailure.RichEndpointTransput.onDecodeFailureBadRequest]] and + * [[OnDecodeFailure.RichEndpointTransput.onDecodeFailureNextEndpoint]] extension methods. + * * This is only used for failures that occur when decoding inputs, not for exceptions that happen when the server logic is invoked. + * Exceptions can be either handled by the server logic, and converted to an error output value. Uncaught exceptions can be handled using + * the [[sttp.tapir.server.interceptor.exception.ExceptionInterceptor]]. */ val default: DefaultDecodeFailureHandler = DefaultDecodeFailureHandler( respond(_, badRequestOnPathErrorIfPathShapeMatches = false, badRequestOnPathInvalidIfPathShapeMatches = true), @@ -95,35 +109,38 @@ object DefaultDecodeFailureHandler { badRequestOnPathInvalidIfPathShapeMatches: Boolean ): Option[(StatusCode, List[Header])] = { failingInput(ctx) match { - case _: EndpointInput.Query[_] => Some(onlyStatus(StatusCode.BadRequest)) - case _: EndpointInput.QueryParams[_] => Some(onlyStatus(StatusCode.BadRequest)) - case _: EndpointInput.Cookie[_] => Some(onlyStatus(StatusCode.BadRequest)) + case i: EndpointTransput.Atom[_] if i.attribute(OnDecodeFailure.key).contains(OnDecodeFailureAttribute(true)) => respondBadRequest + case i: EndpointTransput.Atom[_] if i.attribute(OnDecodeFailure.key).contains(OnDecodeFailureAttribute(false)) => None + case _: EndpointInput.Query[_] => respondBadRequest + case _: EndpointInput.QueryParams[_] => respondBadRequest + case _: EndpointInput.Cookie[_] => respondBadRequest case h: EndpointIO.Header[_] if ctx.failure.isInstanceOf[DecodeResult.Mismatch] && h.name == HeaderNames.ContentType => - Some(onlyStatus(StatusCode.UnsupportedMediaType)) - case _: EndpointIO.Header[_] => Some(onlyStatus(StatusCode.BadRequest)) + respondUnsupportedMediaType + case _: EndpointIO.Header[_] => respondBadRequest case fh: EndpointIO.FixedHeader[_] if ctx.failure.isInstanceOf[DecodeResult.Mismatch] && fh.h.name == HeaderNames.ContentType => - Some(onlyStatus(StatusCode.UnsupportedMediaType)) - case _: EndpointIO.FixedHeader[_] => Some(onlyStatus(StatusCode.BadRequest)) - case _: EndpointIO.Headers[_] => Some(onlyStatus(StatusCode.BadRequest)) - case _: EndpointIO.Body[_, _] => Some(onlyStatus(StatusCode.BadRequest)) - case _: EndpointIO.OneOfBody[_, _] if ctx.failure.isInstanceOf[DecodeResult.Mismatch] => - Some(onlyStatus(StatusCode.UnsupportedMediaType)) - case _: EndpointIO.StreamBodyWrapper[_, _] => Some(onlyStatus(StatusCode.BadRequest)) + respondUnsupportedMediaType + case _: EndpointIO.FixedHeader[_] => respondBadRequest + case _: EndpointIO.Headers[_] => respondBadRequest + case _: EndpointIO.Body[_, _] => respondBadRequest + case _: EndpointIO.OneOfBody[_, _] if ctx.failure.isInstanceOf[DecodeResult.Mismatch] => respondUnsupportedMediaType + case _: EndpointIO.StreamBodyWrapper[_, _] => respondBadRequest // we assume that the only decode failure that might happen during path segment decoding is an error // a non-standard path decoder might return Missing/Multiple/Mismatch, but that would be indistinguishable from // a path shape mismatch case _: EndpointInput.PathCapture[_] if (badRequestOnPathErrorIfPathShapeMatches && ctx.failure.isInstanceOf[DecodeResult.Error]) || (badRequestOnPathInvalidIfPathShapeMatches && ctx.failure.isInstanceOf[DecodeResult.InvalidValue]) => - Some(onlyStatus(StatusCode.BadRequest)) + respondBadRequest // if the failing input contains an authentication input (potentially nested), sending its challenge case FirstAuth(a) => Some((StatusCode.Unauthorized, Header.wwwAuthenticate(a.challenge))) // other basic endpoints - the request doesn't match, but not returning a response (trying other endpoints) case _: EndpointInput.Basic[_] => None // all other inputs (tuples, mapped) - responding with bad request - case _ => Some(onlyStatus(StatusCode.BadRequest)) + case _ => respondBadRequest } } + private val respondBadRequest = Some(onlyStatus(StatusCode.BadRequest)) + private val respondUnsupportedMediaType = Some(onlyStatus(StatusCode.UnsupportedMediaType)) def respondNotFoundIfHasAuth( ctx: DecodeFailureContext, @@ -285,4 +302,15 @@ object DefaultDecodeFailureHandler { case _ => v } } + + private[decodefailure] case class OnDecodeFailureAttribute(value: Boolean) extends AnyVal + + object OnDecodeFailure { + private[decodefailure] val key: AttributeKey[OnDecodeFailureAttribute] = AttributeKey[OnDecodeFailureAttribute] + + implicit class RichEndpointTransput[ET <: EndpointTransput.Atom[_]](val et: ET) extends AnyVal { + def onDecodeFailureBadRequest: ET = et.attribute(key, OnDecodeFailureAttribute(true)).asInstanceOf[ET] + def onDecodeFailureNextEndpoint: ET = et.attribute(key, OnDecodeFailureAttribute(false)).asInstanceOf[ET] + } + } } diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index be802755f9..90dbe77a4c 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -40,7 +40,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( methodMatchingTests() ++ pathMatchingTests() ++ pathMatchingMultipleEndpoints() ++ - pathShapeMatchingTests() ++ + customiseDecodeFailureHandlerTests() ++ serverSecurityLogicTests() ++ (if (inputStreamSupport) inputStreamTests() else Nil) ++ exceptionTests() @@ -593,10 +593,10 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( } ) - def pathShapeMatchingTests(): List[Test] = List( + def customiseDecodeFailureHandlerTests(): List[Test] = List( testServer( in_path_fixed_capture_fixed_capture, - "Returns 400 if path 'shape' matches, but failed to parse a path parameter", + "Returns 400 if path 'shape' matches, but failed to parse a path parameter, using a custom decode failure handler", _.decodeFailureHandler(decodeFailureHandlerBadRequestOnPathFailure) )(_ => pureResult(Either.right[Unit, Unit](()))) { (backend, baseUri) => basicRequest.get(uri"$baseUri/customer/asd/orders/2").send(backend).map { response => @@ -615,6 +615,39 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( .get(uri"$baseUri/customer/asd/orders/2/xyz") .send(backend) .map(response => response.code shouldBe StatusCode.NotFound) + }, { + import DefaultDecodeFailureHandler.OnDecodeFailure._ + testServer( + endpoint.get.in("customer" / path[Int]("customer_id").onDecodeFailureBadRequest), + "Returns 400 if path 'shape' matches, but failed to parse a path parameter, using .badRequestOnDecodeFailure" + )(_ => pureResult(Either.right[Unit, Unit](()))) { (backend, baseUri) => + basicRequest.get(uri"$baseUri/customer/asd").send(backend).map { response => + response.body shouldBe Left("Invalid value for: path parameter customer_id") + response.code shouldBe StatusCode.BadRequest + } + } + }, { + import DefaultDecodeFailureHandler.OnDecodeFailure._ + testServer( + "Tries next endpoint if path 'shape' matches, but validation fails, using .onDecodeFailureNextEndpoint", + NonEmptyList.of( + route( + List( + endpoint.get + .in("customer" / path[Int]("customer_id").validate(Validator.min(10)).onDecodeFailureNextEndpoint) + .out(stringBody) + .serverLogic[F]((_: Int) => pureResult("e1".asRight[Unit])), + endpoint.get + .in("customer" / path[String]("customer_id")) + .out(stringBody) + .serverLogic[F]((_: String) => pureResult("e2".asRight[Unit])) + ) + ) + ) + ) { (backend, baseUri) => + basicStringRequest.get(uri"$baseUri/customer/20").send(backend).map(_.body shouldBe "e1") >> + basicStringRequest.get(uri"$baseUri/customer/2").send(backend).map(_.body shouldBe "e2") + } } )