Skip to content

Commit

Permalink
Allow using streaming bodies in oneOfBody
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Apr 21, 2022
1 parent 4343ced commit c3f45cb
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ abstract class ClientOutputParams {
.flatMap(MediaType.parse(_).toOption)
.flatMap(ct => variants.find(v => ct.matches(v.range)))
.getOrElse(variants.head)
body2.flatMap(decode(bodyVariant.body.codec, _))
body2.flatMap(decode(bodyVariant.codec, _))
case EndpointIO.StreamBodyWrapper(StreamBodyIO(_, codec, _, _, _)) => decode(codec, body)
case EndpointOutput.WebSocketBodyWrapper(o) => decodeWebSocketBody(o, body)
case EndpointIO.Header(name, codec, _) => codec.decode(meta.headers(name).toList)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,14 @@ private[sttp] class EndpointToSttpClient[R](clientOptions: SttpClientOptions, ws
case EndpointIO.Body(bodyType, codec, _) =>
val req2 = setBody(value, bodyType, codec, req)
(uri, req2)
case EndpointIO.OneOfBody(variants, _) => setInputParams(variants.head.body, params, uri, req)
case EndpointIO.OneOfBody(EndpointIO.OneOfBodyVariant(_, Left(body)) :: _, _) => setInputParams(body, params, uri, req)
case EndpointIO.OneOfBody(
EndpointIO.OneOfBodyVariant(_, Right(EndpointIO.StreamBodyWrapper(StreamBodyIO(streams, _, _, _, _)))) :: _,
_
) =>
val req2 = req.streamBody(streams)(value.asInstanceOf[streams.BinaryStream])
(uri, req2)
case EndpointIO.OneOfBody(Nil, _) => throw new RuntimeException("One of body without variants")
case EndpointIO.StreamBodyWrapper(StreamBodyIO(streams, _, _, _, _)) =>
val req2 = req.streamBody(streams)(value.asInstanceOf[streams.BinaryStream])
(uri, req2)
Expand Down Expand Up @@ -206,8 +213,9 @@ private[sttp] class EndpointToSttpClient[R](clientOptions: SttpClientOptions, ws
}

private def bodyIsStream[I](out: EndpointOutput[I]): Option[Streams[_]] = {
out.traverseOutputs { case EndpointIO.StreamBodyWrapper(StreamBodyIO(streams, _, _, _, _)) =>
Vector(streams)
out.traverseOutputs {
case EndpointIO.StreamBodyWrapper(StreamBodyIO(streams, _, _, _, _)) => Vector(streams)
case EndpointIO.OneOfBody(variants, _) => variants.flatMap(_.body.toOption).map(_.wrapped.streams).toVector
}.headOption
}

Expand Down
13 changes: 9 additions & 4 deletions core/src/main/scala/sttp/tapir/EndpointIO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package sttp.tapir

import sttp.capabilities.Streams
import sttp.model.headers.WWWAuthenticateChallenge
import sttp.model.{ContentTypeRange, Method}
import sttp.model.{ContentTypeRange, MediaType, Method}
import sttp.tapir.CodecFormat.TextPlain
import sttp.tapir.EndpointIO.{Example, Info}
import sttp.tapir.RawBodyType._
Expand Down Expand Up @@ -439,13 +439,18 @@ object EndpointIO {
override def show: String = wrapped.show
}

case class OneOfBodyVariant[O](range: ContentTypeRange, body: Body[_, O])
case class OneOfBodyVariant[O](range: ContentTypeRange, body: Either[Body[_, O], StreamBodyWrapper[_, O]]) {
def show: String = body.fold(_.show, _.show)
def mediaTypeWithCharset: MediaType = body.fold(_.mediaTypeWithCharset, _.mediaTypeWithCharset)
def codec: Codec[_, O, _ <: CodecFormat] = body.fold(_.codec, _.codec)
}
case class OneOfBody[O, T](variants: List[OneOfBodyVariant[O]], mapping: Mapping[O, T]) extends Basic[T] {
override private[tapir] type ThisType[X] = OneOfBody[O, X]
override def show: String = showOneOf(variants.map { variant =>
val prefix =
if (ContentTypeRange.exactNoCharset(variant.body.codec.format.mediaType) == variant.range) "" else s"${variant.range} -> "
prefix + variant.body.show
if (ContentTypeRange.exactNoCharset(variant.codec.format.mediaType) == variant.range) ""
else s"${variant.range} -> "
prefix + variant.show
})
override def map[U](m: Mapping[T, U]): OneOfBody[O, U] = copy[O, U](mapping = mapping.map(m))
}
Expand Down
22 changes: 20 additions & 2 deletions core/src/main/scala/sttp/tapir/Tapir.scala
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,14 @@ trait Tapir extends TapirExtensions with TapirComputedInputs with TapirStaticCon
*/
def oneOfBody[T](first: EndpointIO.Body[_, T], others: EndpointIO.Body[_, T]*): EndpointIO.OneOfBody[T, T] =
EndpointIO.OneOfBody[T, T](
(first +: others.toList).map(b => EndpointIO.OneOfBodyVariant(ContentTypeRange.exactNoCharset(b.codec.format.mediaType), b)),
(first +: others.toList).map(b => EndpointIO.OneOfBodyVariant(ContentTypeRange.exactNoCharset(b.codec.format.mediaType), Left(b))),
Mapping.id
)

/** Streaming variant of [[oneOfBody]]. */
def oneOfBody[T](first: EndpointIO.StreamBodyWrapper[_, T], others: EndpointIO.StreamBodyWrapper[_, T]*): EndpointIO.OneOfBody[T, T] =
EndpointIO.OneOfBody[T, T](
(first +: others.toList).map(b => EndpointIO.OneOfBodyVariant(ContentTypeRange.exactNoCharset(b.codec.format.mediaType), Right(b))),
Mapping.id
)

Expand All @@ -411,7 +418,18 @@ trait Tapir extends TapirExtensions with TapirComputedInputs with TapirStaticCon
first: (ContentTypeRange, EndpointIO.Body[_, T]),
others: (ContentTypeRange, EndpointIO.Body[_, T])*
): EndpointIO.OneOfBody[T, T] =
EndpointIO.OneOfBody[T, T]((first +: others.toList).map { case (r, b) => EndpointIO.OneOfBodyVariant(r, b) }, Mapping.id)
EndpointIO.OneOfBody[T, T]((first +: others.toList).map { case (r, b) => EndpointIO.OneOfBodyVariant(r, Left(b)) }, Mapping.id)

/** Streaming variant of [[oneOfBody]].
*
* Allows explicitly specifying the content type range, for which each body will be used, instead of defaulting to the exact media type
* as specified by the body's codec. This is only used when choosing which body to decode.
*/
def oneOfBody[T](
first: (ContentTypeRange, EndpointIO.StreamBodyWrapper[_, T]),
others: (ContentTypeRange, EndpointIO.StreamBodyWrapper[_, T])*
): EndpointIO.OneOfBody[T, T] =
EndpointIO.OneOfBody[T, T]((first +: others.toList).map { case (r, b) => EndpointIO.OneOfBodyVariant(r, Right(b)) }, Mapping.id)

private val emptyIO: EndpointIO.Empty[Unit] = EndpointIO.Empty(Codec.idPlain(), EndpointIO.Info.empty)

Expand Down
8 changes: 4 additions & 4 deletions core/src/main/scala/sttp/tapir/internal/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ package object internal {
case EndpointIO.MappedPair(wrapped, _) => wrapped.asBasicOutputsList
case _: EndpointOutput.Void[_] => List(Vector.empty)
case s: EndpointOutput.OneOf[_, _] => s.variants.flatMap(_.output.asBasicOutputsList)
case EndpointIO.OneOfBody(variants, _) => variants.flatMap(_.body.asBasicOutputsList)
case EndpointIO.OneOfBody(variants, _) => variants.flatMap(_.body.fold(_.asBasicOutputsList, _.asBasicOutputsList))
case e: EndpointIO.Empty[_] => if (hasMetaData(e)) List(Vector(e)) else List(Vector.empty)
case b: EndpointOutput.Basic[_] => List(Vector(b))
}
Expand All @@ -169,13 +169,13 @@ package object internal {
def bodyType: Option[RawBodyType[_]] = {
traverseOutputs[RawBodyType[_]] {
case b: EndpointIO.Body[_, _] => Vector(b.bodyType)
case EndpointIO.OneOfBody(variants, _) => variants.map(_.body.bodyType).toVector
case EndpointIO.OneOfBody(variants, _) => variants.flatMap(_.body.fold(body => Some(body.bodyType), _.bodyType)).toVector
}.headOption
}

def supportedMediaTypes: Vector[MediaType] = traverseOutputs {
case b: EndpointIO.Body[_, _] => Vector(b.mediaTypeWithCharset)
case EndpointIO.OneOfBody(variants, _) => variants.map(_.body.mediaTypeWithCharset).toVector
case EndpointIO.OneOfBody(variants, _) => variants.map(_.mediaTypeWithCharset).toVector
case b: EndpointIO.StreamBodyWrapper[_, _] => Vector(b.mediaTypeWithCharset)
}

Expand Down Expand Up @@ -214,7 +214,7 @@ package object internal {
}

implicit class RichOneOfBody[O, T](body: EndpointIO.OneOfBody[O, T]) {
def chooseBodyToDecode(contentType: Option[MediaType]): Option[EndpointIO.Body[_, O]] = {
def chooseBodyToDecode(contentType: Option[MediaType]): Option[Either[EndpointIO.Body[_, O], EndpointIO.StreamBodyWrapper[_, O]]] = {
contentType match {
case Some(ct) => body.variants.find { case EndpointIO.OneOfBodyVariant(range, _) => ct.matches(range) }
case None => Some(body.variants.head)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package sttp.tapir.server.interpreter

import sttp.model._
import sttp.tapir.EndpointIO.{Body, OneOfBodyVariant}
import sttp.tapir.EndpointIO.{Body, OneOfBodyVariant, StreamBodyWrapper}
import sttp.tapir.EndpointOutput.OneOfVariant
import sttp.tapir.internal.{Params, ParamsAsAny, SplitParams, _}
import sttp.tapir.{Codec, CodecFormat, EndpointIO, EndpointOutput, Mapping, StreamBodyIO, WebSocketBodyOutput}
Expand Down Expand Up @@ -74,9 +74,9 @@ class EncodeOutputs[B, S](rawToResponseBody: ToResponseBody[B, S], acceptsConten
}
}

private def chooseOneOfVariant(variants: List[OneOfBodyVariant[_]]): Body[_, _] = {
val mediaTypeToBody = variants.map(v => v.body.mediaTypeWithCharset -> v.body)
chooseBestVariant[Body[_, _]](mediaTypeToBody).getOrElse(variants.head.body)
private def chooseOneOfVariant(variants: List[OneOfBodyVariant[_]]): EndpointIO.Atom[_] = {
val mediaTypeToBody = variants.map(v => v.mediaTypeWithCharset -> v.body)
chooseBestVariant[Either[Body[_, _], StreamBodyWrapper[_, _]]](mediaTypeToBody).getOrElse(variants.head.body).fold(identity, identity)
}

private def chooseOneOfVariant(variants: Seq[OneOfVariant[_]]): OneOfVariant[_] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import sttp.tapir.model.ServerRequest
import sttp.tapir.server.{model, _}
import sttp.tapir.server.interceptor._
import sttp.tapir.server.model.{ServerResponse, ValuedEndpointOutput}
import sttp.tapir.{Codec, DecodeResult, EndpointIO, EndpointInput, StreamBodyIO, TapirFile}
import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, TapirFile}

class ServerInterpreter[R, F[_], B, S](
serverEndpoints: ServerRequest => List[ServerEndpoint[R, F]],
Expand Down Expand Up @@ -142,21 +142,26 @@ class ServerInterpreter[R, F[_], B, S](
values.bodyInputWithIndex match {
case Some((Left(oneOfBodyInput), _)) =>
oneOfBodyInput.chooseBodyToDecode(request.contentTypeParsed) match {
case Some(body) => decodeBody(request, values, body)
case None => unsupportedInputMediaTypeResponse(request, oneOfBodyInput)
case Some(Left(body)) => decodeBody(request, values, body)
case Some(Right(body: EndpointIO.StreamBodyWrapper[Any, Any])) => decodeStreamingBody(request, values, body)
case None => unsupportedInputMediaTypeResponse(request, oneOfBodyInput)
}

case Some((Right(bodyInput @ EndpointIO.StreamBodyWrapper(StreamBodyIO(_, codec: Codec[Any, Any, _], _, _, _))), _)) =>
(codec.decode(requestBody.toStream(request)) match {
case DecodeResult.Value(bodyV) => values.setBodyInputValue(bodyV)
case failure: DecodeResult.Failure => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult
}).unit

case None => (values: DecodeBasicInputsResult).unit
case Some((Right(bodyInput: EndpointIO.StreamBodyWrapper[Any, Any]), _)) => decodeStreamingBody(request, values, bodyInput)
case None => (values: DecodeBasicInputsResult).unit
}
case failure: DecodeBasicInputsResult.Failure => (failure: DecodeBasicInputsResult).unit
}

private def decodeStreamingBody(
request: ServerRequest,
values: DecodeBasicInputsResult.Values,
bodyInput: EndpointIO.StreamBodyWrapper[Any, Any]
): F[DecodeBasicInputsResult] =
(bodyInput.codec.decode(requestBody.toStream(request)) match {
case DecodeResult.Value(bodyV) => values.setBodyInputValue(bodyV)
case failure: DecodeResult.Failure => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult
}).unit

private def decodeBody[RAW, T](
request: ServerRequest,
values: DecodeBasicInputsResult.Values,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,48 @@ import sttp.client3.{Request, StreamBody}
import sttp.model._
import sttp.tapir.internal.RichOneOfBody
import sttp.tapir.server.interpreter.{DecodeBasicInputs, DecodeBasicInputsResult, DecodeInputsContext, RawValue}
import sttp.tapir.{Codec, DecodeResult, EndpointIO, EndpointInput, RawBodyType, StreamBodyIO}
import sttp.tapir.{DecodeResult, EndpointIO, EndpointInput, RawBodyType}

import java.io.ByteArrayInputStream
import java.nio.ByteBuffer

private[stub] object SttpRequestDecoder {
def apply(request: Request[_, _], input: EndpointInput[_]): DecodeBasicInputsResult = {
DecodeBasicInputs(input, DecodeInputsContext(new SttpRequest(request)))._1 match {
DecodeBasicInputs(input, DecodeInputsContext(SttpRequest(request)))._1 match {
case values: DecodeBasicInputsResult.Values =>
def decodeBody[RAW, T](bodyInput: EndpointIO.Body[RAW, T]): DecodeBasicInputsResult = {
bodyInput.codec.decode(rawBody(request, bodyInput)) match {
case DecodeResult.Value(bodyV) => values.setBodyInputValue(bodyV)
case failure: DecodeResult.Failure => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult
}
}

def decodeStreamingBody(bodyInput: EndpointIO.StreamBodyWrapper[Any, Any]): DecodeBasicInputsResult = {
val value = request.body match {
case StreamBody(s) => RawValue(s)
case _ => throw new IllegalArgumentException("Raw body provided while endpoint accepts stream body")
}
bodyInput.wrapped.codec.decode(value) match {
case DecodeResult.Value(bodyV) => values.setBodyInputValue(bodyV)
case failure: DecodeResult.Failure => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult
}
}

values.bodyInputWithIndex match {
case Some((Left(oneOfBodyInput), _)) =>
def run[RAW, T](bodyInput: EndpointIO.Body[RAW, T]): DecodeBasicInputsResult = {
bodyInput.codec.decode(rawBody(request, bodyInput)) match {
case DecodeResult.Value(bodyV) => values.setBodyInputValue(bodyV)
case failure: DecodeResult.Failure => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult
}
}

val requestContentType: Option[String] = request.contentType
oneOfBodyInput.chooseBodyToDecode(requestContentType.flatMap(MediaType.parse(_).toOption)) match {
case Some(body) => run(body)
case Some(Left(body)) => decodeBody(body)
case Some(Right(body: EndpointIO.StreamBodyWrapper[Any, Any])) => decodeStreamingBody(body)
case None =>
DecodeBasicInputsResult.Failure(
oneOfBodyInput,
DecodeResult.Mismatch(oneOfBodyInput.show, requestContentType.getOrElse(""))
): DecodeBasicInputsResult
}

case Some((Right(bodyInput @ EndpointIO.StreamBodyWrapper(StreamBodyIO(_, codec: Codec[Any, Any, _], _, _, _))), _)) =>
val value = request.body match {
case StreamBody(s) => RawValue(s)
case _ => throw new IllegalArgumentException("Raw body provided while endpoint accepts stream body")
}
codec.decode(value) match {
case DecodeResult.Value(bodyV) => values.setBodyInputValue(bodyV)
case failure: DecodeResult.Failure => DecodeBasicInputsResult.Failure(bodyInput, failure): DecodeBasicInputsResult
}
case None => values
case Some((Right(bodyInput: EndpointIO.StreamBodyWrapper[Any, Any]), _)) => decodeStreamingBody(bodyInput)
case None => values
}
case failure: DecodeBasicInputsResult.Failure => failure
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package sttp.tapir.server.stub

import sttp.client3.Response
import sttp.model.{ContentTypeRange, HasHeaders, Headers, StatusCode}
import sttp.tapir.internal.{NoStreams, ParamsAsAny}
import sttp.tapir.internal.ParamsAsAny
import sttp.tapir.server.interpreter.{EncodeOutputs, OutputValues, ToResponseBody}
import sttp.tapir.{CodecFormat, EndpointOutput, RawBodyType, WebSocketBodyOutput}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import sttp.model.{Header, HeaderNames, MediaType}
import sttp.monad.MonadError
import sttp.tapir.tests.Test
import sttp.tapir.tests.Streaming.{
in_stream_out_either_json_xml_stream,
in_stream_out_stream,
in_stream_out_stream_with_content_length,
in_string_stream_out_either_stream_string,
Expand Down Expand Up @@ -80,6 +81,26 @@ class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServ
.body(penPineapple)
.send(backend)
.map(_.body shouldBe Right("was not left"))
},
testServer(in_stream_out_either_json_xml_stream(streams)) { s => pureResult(s.asRight[Unit]) } { (backend, baseUri) =>
basicRequest
.post(uri"$baseUri")
.body(penPineapple)
.header(Header.accept(MediaType.ApplicationXml, MediaType.ApplicationJson))
.send(backend)
.map { r =>
r.contentType shouldBe Some(MediaType.ApplicationXml.toString())
r.body shouldBe Right(penPineapple)
} >>
basicRequest
.post(uri"$baseUri")
.body(penPineapple)
.header(Header.accept(MediaType.ApplicationJson, MediaType.ApplicationXml))
.send(backend)
.map { r =>
r.contentType shouldBe Some(MediaType.ApplicationJson.toString())
r.body shouldBe Right(penPineapple)
}
}
)
}
Expand Down
12 changes: 12 additions & 0 deletions tests/src/main/scala/sttp/tapir/tests/Streaming.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,16 @@ object Streaming {
)
)
}

def in_stream_out_either_json_xml_stream[S](
s: Streams[S]
): PublicEndpoint[s.BinaryStream, Unit, s.BinaryStream, S] = {
def textStream(format: CodecFormat) = streamTextBody(s)(format, None)

endpoint.post
.in(textStream(CodecFormat.TextPlain()))
.out(
oneOfBody(textStream(CodecFormat.Json()).toEndpointIO, textStream(CodecFormat.Xml()).toEndpointIO)
)
}
}

0 comments on commit c3f45cb

Please sign in to comment.