diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala index d62875e736..9bc8e45f87 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala @@ -168,19 +168,15 @@ trait ZioHttpInterpreter[R] { } val statusCode = resp.code.code - ZIO.succeed( - Response( - status = Status.fromInt(statusCode), - headers = ZioHttpHeaders(allHeaders), - body = body - .map { - case ZioStreamHttpResponseBody(stream, Some(contentLength)) => Body.fromStream(stream, contentLength) - case ZioStreamHttpResponseBody(stream, None) => Body.fromStreamChunked(stream) - case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk) - } - .getOrElse(Body.empty) - ) - ) + body + .map { + case ZioStreamHttpResponseBody(stream, Some(contentLength)) => ZIO.succeed(Body.fromStream(stream, contentLength)) + case ZioStreamHttpResponseBody(stream, None) => ZIO.succeed(Body.fromStreamChunked(stream)) + case ZioMultipartHttpResponseBody(formFields) => Body.fromMultipartFormUUID(Form(Chunk.fromIterable(formFields))) + case ZioRawHttpResponseBody(chunk, _) => ZIO.succeed(Body.fromChunk(chunk)) + } + .getOrElse(ZIO.succeed(Body.empty)) + .map(zioBody => Response(status = Status.fromInt(statusCode), headers = ZioHttpHeaders(allHeaders), body = zioBody)) } private def sttpToZioHttpHeader(hl: (String, Seq[SttpHeader])): Seq[ZioHttpHeader] = { diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala index fe9e503324..c5828cad2d 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpRequestBody.scala @@ -2,12 +2,19 @@ package sttp.tapir.server.ziohttp import sttp.capabilities import sttp.capabilities.zio.ZioStreams +import sttp.model.Part +import sttp.model.Part.FileNameDispositionParam +import sttp.tapir.FileRange +import sttp.tapir.InputStreamRange +import sttp.tapir.RawBodyType import sttp.tapir.model.ServerRequest -import sttp.tapir.server.interpreter.{RawValue, RequestBody} -import sttp.tapir.{FileRange, InputStreamRange, RawBodyType} -import zio.http.Request -import zio.stream.{Stream, ZSink, ZStream} +import sttp.tapir.server.interpreter.RawValue +import sttp.tapir.server.interpreter.RequestBody import zio.{RIO, Task, ZIO} +import zio.http.{FormField, Request, StreamingForm} +import zio.http.FormField.StreamingBinary +import zio.stream.ZSink +import zio.stream.ZStream import java.io.ByteArrayInputStream import java.nio.ByteBuffer @@ -15,10 +22,17 @@ import java.nio.ByteBuffer class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends RequestBody[RIO[R, *], ZioStreams] { override val streams: capabilities.Streams[ZioStreams] = ZioStreams - override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] = { + override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], maxBytes: Option[Long]): Task[RawValue[RAW]] = + toRaw(serverRequest, zStream(serverRequest), bodyType, maxBytes) - def asByteArray: Task[Array[Byte]] = - (toStream(serverRequest, maxBytes).asInstanceOf[ZStream[Any, Throwable, Byte]]).runCollect.map(_.toArray) + private def toRaw[RAW]( + serverRequest: ServerRequest, + stream: ZStream[Any, Throwable, Byte], + bodyType: RawBodyType[RAW], + maxBytes: Option[Long] + ): Task[RawValue[RAW]] = { + val limitedStream = limitedZStream(stream, maxBytes) + val asByteArray = limitedStream.runCollect.map(_.toArray) bodyType match { case RawBodyType.StringBody(defaultCharset) => asByteArray.map(new String(_, defaultCharset)).map(RawValue(_)) @@ -26,23 +40,58 @@ class ZioHttpRequestBody[R](serverOptions: ZioHttpServerOptions[R]) extends Requ case RawBodyType.ByteBufferBody => asByteArray.map(bytes => ByteBuffer.wrap(bytes)).map(RawValue(_)) case RawBodyType.InputStreamBody => asByteArray.map(new ByteArrayInputStream(_)).map(RawValue(_)) case RawBodyType.InputStreamRangeBody => - asByteArray.map(bytes => new InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_)) + asByteArray.map(bytes => InputStreamRange(() => new ByteArrayInputStream(bytes))).map(RawValue(_)) case RawBodyType.FileBody => for { file <- serverOptions.createFile(serverRequest) - _ <- (toStream(serverRequest, maxBytes).asInstanceOf[ZStream[Any, Throwable, Byte]]).run(ZSink.fromFile(file)).map(_ => ()) + _ <- limitedStream.run(ZSink.fromFile(file)).unit } yield RawValue(FileRange(file), Seq(FileRange(file))) - case RawBodyType.MultipartBody(_, _) => ZIO.fail(new UnsupportedOperationException("Multipart is not supported")) + case m: RawBodyType.MultipartBody => handleMultipartBody(serverRequest, m, limitedStream) } } - override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = { - val inputStream = stream(serverRequest) - maxBytes.map(ZioStreams.limitBytes(inputStream, _)).getOrElse(inputStream).asInstanceOf[streams.BinaryStream] + private def handleMultipartBody[RAW]( + serverRequest: ServerRequest, + bodyType: RawBodyType.MultipartBody, + limitedStream: ZStream[Any, Throwable, Byte] + ): Task[RawValue[RAW]] = { + zRequest(serverRequest).body.contentType.flatMap(_.boundary) match { + case Some(boundary) => + StreamingForm(limitedStream, boundary).fields + .flatMap(field => ZStream.fromIterable(bodyType.partType(field.name).map((field, _)))) + .mapZIO { case (field, bodyType) => toRawPart(serverRequest, field, bodyType) } + .runCollect + .map(RawValue.fromParts(_).asInstanceOf[RawValue[RAW]]) + case None => + ZIO.fail( + new IllegalStateException("Cannot decode body as streaming multipart/form-data without a known boundary") + ) + } + } + + private def toRawPart[A](serverRequest: ServerRequest, field: FormField, bodyType: RawBodyType[A]): Task[Part[A]] = { + val fieldsStream = field match { + case StreamingBinary(_, _, _, _, s) => s + case _ => ZStream.fromIterableZIO(field.asChunk) + } + toRaw(serverRequest, fieldsStream, bodyType, None) + .map(raw => + Part( + field.name, + raw.value, + otherDispositionParams = field.filename.map(name => Map(FileNameDispositionParam -> name)).getOrElse(Map.empty) + ).contentType(field.contentType.fullType) + ) } - private def stream(serverRequest: ServerRequest): Stream[Throwable, Byte] = - zioHttpRequest(serverRequest).body.asStream + override def toStream(serverRequest: ServerRequest, maxBytes: Option[Long]): streams.BinaryStream = + limitedZStream(zStream(serverRequest), maxBytes).asInstanceOf[streams.BinaryStream] + + private def zRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request] + + private def limitedZStream(stream: ZStream[Any, Throwable, Byte], maxBytes: Option[Long]) = { + maxBytes.map(ZioStreams.limitBytes(stream, _)).getOrElse(stream) + } - private def zioHttpRequest(serverRequest: ServerRequest) = serverRequest.underlying.asInstanceOf[Request] + private def zStream(serverRequest: ServerRequest) = zRequest(serverRequest).body.asStream } diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpResponseBody.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpResponseBody.scala index d673cac40e..e7644ae236 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpResponseBody.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpResponseBody.scala @@ -2,6 +2,7 @@ package sttp.tapir.server.ziohttp import zio.stream.ZStream import zio.Chunk +import zio.http.FormField sealed trait ZioHttpResponseBody { def contentLength: Option[Long] @@ -10,3 +11,7 @@ sealed trait ZioHttpResponseBody { case class ZioStreamHttpResponseBody(stream: ZStream[Any, Throwable, Byte], contentLength: Option[Long]) extends ZioHttpResponseBody case class ZioRawHttpResponseBody(bytes: Chunk[Byte], contentLength: Option[Long]) extends ZioHttpResponseBody + +case class ZioMultipartHttpResponseBody(formFields: List[FormField]) extends ZioHttpResponseBody { + override def contentLength: Option[Long] = None +} diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpToResponseBody.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpToResponseBody.scala index af412f40fb..223230f020 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpToResponseBody.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpToResponseBody.scala @@ -2,11 +2,15 @@ package sttp.tapir.server.ziohttp import sttp.capabilities.zio.ZioStreams import sttp.model.HasHeaders +import sttp.model.Part import sttp.tapir.server.interpreter.ToResponseBody -import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} +import sttp.tapir.{CodecFormat, RawBodyType, RawPart, WebSocketBodyOutput} import zio.Chunk +import zio.http.FormField +import zio.http.MediaType import zio.stream.ZStream +import java.io.InputStream import java.nio.ByteBuffer import java.nio.charset.Charset @@ -74,6 +78,59 @@ class ZioHttpToResponseBody extends ToResponseBody[ZioResponseBody, ZioStreams] } } .getOrElse(ZioStreamHttpResponseBody(ZStream.fromPath(tapirFile.file.toPath), Some(tapirFile.file.length))) - case RawBodyType.MultipartBody(_, _) => throw new UnsupportedOperationException("Multipart is not supported") + case m @ RawBodyType.MultipartBody(_, _) => + val formFields = (r: Seq[RawPart]).flatMap { part => + m.partType(part.name).map { partType => + toFormField(partType.asInstanceOf[RawBodyType[Any]], part) + } + }.toList + ZioMultipartHttpResponseBody(formFields) + } + + private def toFormField[R](bodyType: RawBodyType[R], part: Part[R]): FormField = { + val mediaType: Option[MediaType] = part.contentType.flatMap(MediaType.forContentType) + bodyType match { + case RawBodyType.StringBody(_) => + FormField.Text(part.name, part.body, mediaType.getOrElse(MediaType.text.plain), part.fileName) + case RawBodyType.ByteArrayBody => + FormField.Binary( + part.name, + Chunk.fromArray(part.body), + mediaType.getOrElse(MediaType.application.`octet-stream`), + filename = part.fileName + ) + case RawBodyType.ByteBufferBody => + val array: Array[Byte] = new Array[Byte](part.body.remaining) + part.body.get(array) + FormField.Binary( + part.name, + Chunk.fromArray(array), + mediaType.getOrElse(MediaType.application.`octet-stream`), + filename = part.fileName + ) + case RawBodyType.FileBody => + FormField.streamingBinaryField( + part.name, + ZStream.fromFile(part.body.file).orDie, + mediaType.getOrElse(MediaType.application.`octet-stream`), + filename = part.fileName + ) + case RawBodyType.InputStreamBody => + FormField.streamingBinaryField( + part.name, + ZStream.fromInputStream(part.body).orDie, + mediaType.getOrElse(MediaType.application.`octet-stream`), + filename = part.fileName + ) + case RawBodyType.InputStreamRangeBody => + FormField.streamingBinaryField( + part.name, + ZStream.fromInputStream(part.body.inputStream()).orDie, + mediaType.getOrElse(MediaType.application.`octet-stream`), + filename = part.fileName + ) + case _: RawBodyType.MultipartBody => + throw new UnsupportedOperationException("Nested multipart messages are not supported.") } + } } diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index 6a079c2cd8..eb64d68777 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -304,11 +304,10 @@ class ZioHttpServerTest extends TestSuite { interpreter, backend, basic = false, - staticContent = true, multipart = false, - file = true, options = false ).tests() ++ + new ServerMultipartTests(createServerTest, partOtherHeaderSupport = false).tests() ++ new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) ++ new ZioHttpCompositionTest(createServerTest).tests() ++ new ServerWebSocketTests(