Skip to content

Commit

Permalink
[bugfix] Return raw ZIO response body where applicable (#3047)
Browse files Browse the repository at this point in the history
  • Loading branch information
kciesielski authored Jul 20, 2023
1 parent 067f900 commit 9c3ad53
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@ import zio.stream.ZStream

import scala.util.{Failure, Success, Try}

class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZioHttpResponseBody] {
private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], ZioHttpResponseBody] {
override def onComplete(body: ZioHttpResponseBody)(cb: Try[Unit] => RIO[R, Unit]): RIO[R, ZioHttpResponseBody] =
ZIO
.environmentWith[R]
.environmentWithZIO[R]
.apply { r =>
val (stream, contentLength) = body
(
stream.onError(cause => cb(Failure(cause.squash)).orDie.provideEnvironment(r)) ++ ZStream
.fromZIO(cb(Success(())))
.provideEnvironment(r)
.drain,
contentLength
)
body match {
case ZioStreamHttpResponseBody(stream, contentLength) =>
ZIO.succeed(ZioStreamHttpResponseBody(
stream.onError(cause => cb(Failure(cause.squash)).orDie.provideEnvironment(r)) ++ ZStream
.fromZIO(cb(Success(())))
.provideEnvironment(r)
.drain,
contentLength
)
)
case raw: ZioRawHttpResponseBody => cb(Success(())).provideEnvironment(r).map(_ => raw)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ trait ZioHttpInterpreter[R] {
{
case RequestResult.Response(resp) =>
val baseHeaders = resp.headers.groupBy(_.name).flatMap(sttpToZioHttpHeader).toList
val allHeaders = resp.body match {
case Some((_, Some(contentLength))) if resp.contentLength.isEmpty =>
val allHeaders = resp.body.flatMap(_.contentLength) match {
case Some(contentLength) if resp.contentLength.isEmpty =>
ZioHttpHeader.ContentLength(contentLength) :: baseHeaders
case _ => baseHeaders
}
Expand All @@ -53,7 +53,12 @@ trait ZioHttpInterpreter[R] {
Response(
status = Status.fromInt(statusCode).getOrElse(Status.Custom(statusCode)),
headers = ZioHttpHeaders(allHeaders),
body = resp.body.map { case (stream, _) => Body.fromStream(stream) }.getOrElse(Body.empty)
body = resp.body
.map {
case ZioStreamHttpResponseBody(stream, _) => Body.fromStream(stream)
case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk)
}
.getOrElse(Body.empty)
)
)
case RequestResult.Failure(_) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package sttp.tapir.server.ziohttp

import zio.stream.ZStream
import zio.Chunk

private[ziohttp] sealed trait ZioHttpResponseBody {
def contentLength: Option[Long]
}

private[ziohttp] case class ZioStreamHttpResponseBody(stream: ZStream[Any, Throwable, Byte], contentLength: Option[Long])
extends ZioHttpResponseBody

private[ziohttp] case class ZioRawHttpResponseBody(bytes: Chunk[Byte], contentLength: Option[Long]) extends ZioHttpResponseBody
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput}
import zio.Chunk
import zio.stream.ZStream

import java.nio.ByteBuffer
import java.nio.charset.Charset

class ZioHttpToResponseBody extends ToResponseBody[ZioHttpResponseBody, ZioStreams] {
Expand All @@ -20,33 +21,37 @@ class ZioHttpToResponseBody extends ToResponseBody[ZioHttpResponseBody, ZioStrea
headers: HasHeaders,
format: CodecFormat,
charset: Option[Charset]
): ZioHttpResponseBody = (v, None)
): ZioHttpResponseBody = ZioStreamHttpResponseBody(v, None)

override def fromWebSocketPipe[REQ, RESP](
pipe: streams.Pipe[REQ, RESP],
o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, ZioStreams]
): ZioHttpResponseBody =
(ZStream.empty, None) // TODO
ZioStreamHttpResponseBody(ZStream.empty, None) // TODO

private def rawValueToEntity[R](bodyType: RawBodyType[R], r: R): ZioHttpResponseBody = {
bodyType match {
case RawBodyType.StringBody(charset) =>
val bytes = r.toString.getBytes(charset)
(ZStream.fromIterable(bytes), Some(bytes.length.toLong))
case RawBodyType.ByteArrayBody => (ZStream.fromChunk(Chunk.fromArray(r)), Some((r: Array[Byte]).length.toLong))
case RawBodyType.ByteBufferBody => (ZStream.fromChunk(Chunk.fromByteBuffer(r)), None)
case RawBodyType.InputStreamBody => (ZStream.fromInputStream(r), None)
ZioRawHttpResponseBody(Chunk.fromArray(bytes), Some(bytes.length.toLong))
case RawBodyType.ByteArrayBody =>
ZioRawHttpResponseBody(Chunk.fromArray(r), Some((r: Array[Byte]).length.toLong))
case RawBodyType.ByteBufferBody =>
val buffer: ByteBuffer = r
ZioRawHttpResponseBody(Chunk.fromByteBuffer(buffer), Some(buffer.remaining()))
case RawBodyType.InputStreamBody =>
ZioStreamHttpResponseBody(ZStream.fromInputStream(r), None)
case RawBodyType.InputStreamRangeBody =>
r.range
.map(range => (ZStream.fromInputStream(r.inputStreamFromRangeStart()).take(range.contentLength), Some(range.contentLength)))
.getOrElse((ZStream.fromInputStream(r.inputStream()), None))
.map(range => ZioStreamHttpResponseBody(ZStream.fromInputStream(r.inputStreamFromRangeStart()).take(range.contentLength), Some(range.contentLength)))
.getOrElse(ZioStreamHttpResponseBody(ZStream.fromInputStream(r.inputStream()), None))
case RawBodyType.FileBody =>
val tapirFile = r
tapirFile.range
.flatMap { r =>
r.startAndEnd.map { s =>
var count = 0L
(
ZioStreamHttpResponseBody(
ZStream
.fromPath(tapirFile.file.toPath)
.dropWhile(_ =>
Expand All @@ -58,7 +63,7 @@ class ZioHttpToResponseBody extends ToResponseBody[ZioHttpResponseBody, ZioStrea
)
}
}
.getOrElse((ZStream.fromPath(tapirFile.file.toPath), Some(tapirFile.file.length)))
.getOrElse(ZioStreamHttpResponseBody(ZStream.fromPath(tapirFile.file.toPath), Some(tapirFile.file.length)))
case RawBodyType.MultipartBody(_, _) => throw new UnsupportedOperationException("Multipart is not supported")
}
}
Expand Down

This file was deleted.

0 comments on commit 9c3ad53

Please sign in to comment.