Skip to content

Commit

Permalink
Merge pull request #2631 from softwaremill/on-decode-failure-attribute
Browse files Browse the repository at this point in the history
Add an attribute to customise default decode failure handling for individual inputs/outputs
  • Loading branch information
adamw authored Dec 15, 2022
2 parents 92393ff + cbbb0a3 commit 4e3c004
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 20 deletions.
38 changes: 35 additions & 3 deletions doc/server/errors.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,41 @@ an error or return a "no match", create error messages and create the response.
swapped, e.g. to return responses in a different format (other than plain text), or customise the error messages.

The default decode failure handler also has the option to return a `400 Bad Request`, instead of a no-match (ultimately
leading to a `404 Not Found`), when the "shape" of the path matches (that is, the number of segments in the request
and endpoint's paths are the same), but when decoding some part of the path ends in an error. See the
`badRequestOnPathErrorIfPathShapeMatches` in `ServerDefaults`.
leading to a `404 Not Found`), when the "shape" of the path matches (that is, the constant parts and number of segments
in the request and endpoint's paths are the same), but when decoding some part of the path ends in an error. See the
scaladoc for `DefaultDecodeFailureHandler.default` and parameters of `DefaultDecodeFailureHandler.response`. For example:

```scala mdoc:compile-only
import sttp.tapir._
import sttp.tapir.server.akkahttp.AkkaHttpServerOptions
import sttp.tapir.server.interceptor.decodefailure.DefaultDecodeFailureHandler
import scala.concurrent.ExecutionContext.Implicits.global

val myDecodeFailureHandler = DefaultDecodeFailureHandler.default.copy(
respond = DefaultDecodeFailureHandler.respond(
_,
badRequestOnPathErrorIfPathShapeMatches = true,
badRequestOnPathInvalidIfPathShapeMatches = true
)
)

val myServerOptions: AkkaHttpServerOptions = AkkaHttpServerOptions
.customiseInterceptors
.decodeFailureHandler(myDecodeFailureHandler)
.options
```

Moreover, when using the `DefaultDecodeFailureHandler`, decode failure handling can be overriden on a per-input/output
basis, by setting an attribute. For example:

```scala mdoc:compile-only
import sttp.tapir._
// bringing into scope the onDecodeFailureBadRequest extension method
import sttp.tapir.server.interceptor.decodefailure.DefaultDecodeFailureHandler.OnDecodeFailure._

// by default, when the customer_id is not an int, the next endpoint would be tried; here, we always return a bad request
endpoint.in("customer" / path[Int]("customer_id").onDecodeFailureBadRequest)
```

## Customising how error messages are rendered

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,21 @@ object DefaultDecodeFailureHandler {
* The error messages contain information about the source of the decode error, and optionally the validation error detail that caused
* the failure.
*
* The default decode failure handler can be customised by providing alternate functions for deciding whether a response should be sent,
* creating the error message and creating the response.
*
* Furthermore, how decode failures are handled can be adjusted globally by changing the flags passed to [[respond]]. By default, if the
* shape of the path for an endpoint matches the request, but decoding a path capture causes an error (e.g. a `path[Int]("amount")`
* cannot be parsed), the next endpoint is tried. However, if there's a validation error (e.g. a `path[Kind]("kind")`, where `Kind` is an
* enum, and a value outside the enumeration values is provided), a 400 response is sent.
*
* Finally, behavior can be adjusted per-endpoint-input, by setting an attribute. Import the [[OnDecodeFailure]] object and use the
* [[OnDecodeFailure.RichEndpointTransput.onDecodeFailureBadRequest]] and
* [[OnDecodeFailure.RichEndpointTransput.onDecodeFailureNextEndpoint]] extension methods.
*
* This is only used for failures that occur when decoding inputs, not for exceptions that happen when the server logic is invoked.
* Exceptions can be either handled by the server logic, and converted to an error output value. Uncaught exceptions can be handled using
* the [[sttp.tapir.server.interceptor.exception.ExceptionInterceptor]].
*/
val default: DefaultDecodeFailureHandler = DefaultDecodeFailureHandler(
respond(_, badRequestOnPathErrorIfPathShapeMatches = false, badRequestOnPathInvalidIfPathShapeMatches = true),
Expand Down Expand Up @@ -95,35 +109,38 @@ object DefaultDecodeFailureHandler {
badRequestOnPathInvalidIfPathShapeMatches: Boolean
): Option[(StatusCode, List[Header])] = {
failingInput(ctx) match {
case _: EndpointInput.Query[_] => Some(onlyStatus(StatusCode.BadRequest))
case _: EndpointInput.QueryParams[_] => Some(onlyStatus(StatusCode.BadRequest))
case _: EndpointInput.Cookie[_] => Some(onlyStatus(StatusCode.BadRequest))
case i: EndpointTransput.Atom[_] if i.attribute(OnDecodeFailure.key).contains(OnDecodeFailureAttribute(true)) => respondBadRequest
case i: EndpointTransput.Atom[_] if i.attribute(OnDecodeFailure.key).contains(OnDecodeFailureAttribute(false)) => None
case _: EndpointInput.Query[_] => respondBadRequest
case _: EndpointInput.QueryParams[_] => respondBadRequest
case _: EndpointInput.Cookie[_] => respondBadRequest
case h: EndpointIO.Header[_] if ctx.failure.isInstanceOf[DecodeResult.Mismatch] && h.name == HeaderNames.ContentType =>
Some(onlyStatus(StatusCode.UnsupportedMediaType))
case _: EndpointIO.Header[_] => Some(onlyStatus(StatusCode.BadRequest))
respondUnsupportedMediaType
case _: EndpointIO.Header[_] => respondBadRequest
case fh: EndpointIO.FixedHeader[_] if ctx.failure.isInstanceOf[DecodeResult.Mismatch] && fh.h.name == HeaderNames.ContentType =>
Some(onlyStatus(StatusCode.UnsupportedMediaType))
case _: EndpointIO.FixedHeader[_] => Some(onlyStatus(StatusCode.BadRequest))
case _: EndpointIO.Headers[_] => Some(onlyStatus(StatusCode.BadRequest))
case _: EndpointIO.Body[_, _] => Some(onlyStatus(StatusCode.BadRequest))
case _: EndpointIO.OneOfBody[_, _] if ctx.failure.isInstanceOf[DecodeResult.Mismatch] =>
Some(onlyStatus(StatusCode.UnsupportedMediaType))
case _: EndpointIO.StreamBodyWrapper[_, _] => Some(onlyStatus(StatusCode.BadRequest))
respondUnsupportedMediaType
case _: EndpointIO.FixedHeader[_] => respondBadRequest
case _: EndpointIO.Headers[_] => respondBadRequest
case _: EndpointIO.Body[_, _] => respondBadRequest
case _: EndpointIO.OneOfBody[_, _] if ctx.failure.isInstanceOf[DecodeResult.Mismatch] => respondUnsupportedMediaType
case _: EndpointIO.StreamBodyWrapper[_, _] => respondBadRequest
// we assume that the only decode failure that might happen during path segment decoding is an error
// a non-standard path decoder might return Missing/Multiple/Mismatch, but that would be indistinguishable from
// a path shape mismatch
case _: EndpointInput.PathCapture[_]
if (badRequestOnPathErrorIfPathShapeMatches && ctx.failure.isInstanceOf[DecodeResult.Error]) ||
(badRequestOnPathInvalidIfPathShapeMatches && ctx.failure.isInstanceOf[DecodeResult.InvalidValue]) =>
Some(onlyStatus(StatusCode.BadRequest))
respondBadRequest
// if the failing input contains an authentication input (potentially nested), sending its challenge
case FirstAuth(a) => Some((StatusCode.Unauthorized, Header.wwwAuthenticate(a.challenge)))
// other basic endpoints - the request doesn't match, but not returning a response (trying other endpoints)
case _: EndpointInput.Basic[_] => None
// all other inputs (tuples, mapped) - responding with bad request
case _ => Some(onlyStatus(StatusCode.BadRequest))
case _ => respondBadRequest
}
}
private val respondBadRequest = Some(onlyStatus(StatusCode.BadRequest))
private val respondUnsupportedMediaType = Some(onlyStatus(StatusCode.UnsupportedMediaType))

def respondNotFoundIfHasAuth(
ctx: DecodeFailureContext,
Expand Down Expand Up @@ -285,4 +302,15 @@ object DefaultDecodeFailureHandler {
case _ => v
}
}

private[decodefailure] case class OnDecodeFailureAttribute(value: Boolean) extends AnyVal

object OnDecodeFailure {
private[decodefailure] val key: AttributeKey[OnDecodeFailureAttribute] = AttributeKey[OnDecodeFailureAttribute]

implicit class RichEndpointTransput[ET <: EndpointTransput.Atom[_]](val et: ET) extends AnyVal {
def onDecodeFailureBadRequest: ET = et.attribute(key, OnDecodeFailureAttribute(true)).asInstanceOf[ET]
def onDecodeFailureNextEndpoint: ET = et.attribute(key, OnDecodeFailureAttribute(false)).asInstanceOf[ET]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE](
methodMatchingTests() ++
pathMatchingTests() ++
pathMatchingMultipleEndpoints() ++
pathShapeMatchingTests() ++
customiseDecodeFailureHandlerTests() ++
serverSecurityLogicTests() ++
(if (inputStreamSupport) inputStreamTests() else Nil) ++
exceptionTests()
Expand Down Expand Up @@ -593,10 +593,10 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE](
}
)

def pathShapeMatchingTests(): List[Test] = List(
def customiseDecodeFailureHandlerTests(): List[Test] = List(
testServer(
in_path_fixed_capture_fixed_capture,
"Returns 400 if path 'shape' matches, but failed to parse a path parameter",
"Returns 400 if path 'shape' matches, but failed to parse a path parameter, using a custom decode failure handler",
_.decodeFailureHandler(decodeFailureHandlerBadRequestOnPathFailure)
)(_ => pureResult(Either.right[Unit, Unit](()))) { (backend, baseUri) =>
basicRequest.get(uri"$baseUri/customer/asd/orders/2").send(backend).map { response =>
Expand All @@ -615,6 +615,39 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE](
.get(uri"$baseUri/customer/asd/orders/2/xyz")
.send(backend)
.map(response => response.code shouldBe StatusCode.NotFound)
}, {
import DefaultDecodeFailureHandler.OnDecodeFailure._
testServer(
endpoint.get.in("customer" / path[Int]("customer_id").onDecodeFailureBadRequest),
"Returns 400 if path 'shape' matches, but failed to parse a path parameter, using .badRequestOnDecodeFailure"
)(_ => pureResult(Either.right[Unit, Unit](()))) { (backend, baseUri) =>
basicRequest.get(uri"$baseUri/customer/asd").send(backend).map { response =>
response.body shouldBe Left("Invalid value for: path parameter customer_id")
response.code shouldBe StatusCode.BadRequest
}
}
}, {
import DefaultDecodeFailureHandler.OnDecodeFailure._
testServer(
"Tries next endpoint if path 'shape' matches, but validation fails, using .onDecodeFailureNextEndpoint",
NonEmptyList.of(
route(
List(
endpoint.get
.in("customer" / path[Int]("customer_id").validate(Validator.min(10)).onDecodeFailureNextEndpoint)
.out(stringBody)
.serverLogic[F]((_: Int) => pureResult("e1".asRight[Unit])),
endpoint.get
.in("customer" / path[String]("customer_id"))
.out(stringBody)
.serverLogic[F]((_: String) => pureResult("e2".asRight[Unit]))
)
)
)
) { (backend, baseUri) =>
basicStringRequest.get(uri"$baseUri/customer/20").send(backend).map(_.body shouldBe "e1") >>
basicStringRequest.get(uri"$baseUri/customer/2").send(backend).map(_.body shouldBe "e2")
}
}
)

Expand Down

0 comments on commit 4e3c004

Please sign in to comment.