From 17f46478df559c5bc3f66dcba88dccfd33e46128 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 7 Jul 2023 12:40:53 +0200 Subject: [PATCH 01/18] wip --- build.sbt | 10 +- .../cats/NettyCatsServerInterpreter.scala | 46 +++++-- .../netty/internal/NettyCatsRequestBody.scala | 48 +++++++ .../internal/NettyCatsToResponseBody.scala | 125 ++++++++++++++++++ .../sttp/tapir/server/netty/NettyConfig.scala | 2 + .../server/netty/NettyResponseContent.scala | 3 + .../netty/internal/NettyServerHandler.scala | 15 ++- .../netty/internal/NettyToResponseBody.scala | 8 +- 8 files changed, 242 insertions(+), 15 deletions(-) create mode 100644 server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala create mode 100644 server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala diff --git a/build.sbt b/build.sbt index c17c70412d..18bc8c39c2 100644 --- a/build.sbt +++ b/build.sbt @@ -1353,7 +1353,10 @@ lazy val nettyServer: ProjectMatrix = (projectMatrix in file("server/netty-serve .settings(commonJvmSettings) .settings( name := "tapir-netty-server", - libraryDependencies ++= Seq("io.netty" % "netty-all" % Versions.nettyAll) + libraryDependencies ++= Seq( + "io.netty" % "netty-all" % Versions.nettyAll, + "com.typesafe.netty" % "netty-reactive-streams-http" % "2.0.8" + ) ++ loggerDependencies, // needed because of https://github.com/coursier/coursier/issues/2016 useCoursier := false @@ -1362,7 +1365,10 @@ lazy val nettyServer: ProjectMatrix = (projectMatrix in file("server/netty-serve .dependsOn(serverCore, serverTests % Test) lazy val nettyServerCats: ProjectMatrix = nettyServerProject("cats", catsEffect) - .settings(libraryDependencies += "com.softwaremill.sttp.shared" %% "fs2" % Versions.sttpShared) + .settings(libraryDependencies ++= Seq( + "com.softwaremill.sttp.shared" %% "fs2" % Versions.sttpShared, + "co.fs2" %% "fs2-reactive-streams" % Versions.fs2 + )) lazy val nettyServerZio: ProjectMatrix = nettyServerProject("zio", zio) .settings(libraryDependencies += "dev.zio" %% "zio-interop-cats" % Versions.zioInteropCats) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala index 3e2d3d9072..fb2e01fed8 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala @@ -3,27 +3,57 @@ package sttp.tapir.server.netty.cats import cats.effect.Async import cats.effect.std.Dispatcher import sttp.monad.MonadError +import sttp.monad.syntax._ import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.netty.Route -import sttp.tapir.server.netty.internal.{NettyServerInterpreter, RunAsync} +import sttp.tapir.server.interpreter.BodyListener +import sttp.tapir.server.netty.internal.NettyBodyListener +import sttp.tapir.server.netty.NettyResponse +import sttp.tapir.server.interpreter.ServerInterpreter +import sttp.tapir.server.interpreter.FilterServerEndpoints +import sttp.tapir.server.netty.internal.NettyRequestBody +import sttp.tapir.server.netty.internal.NettyToResponseBody +import sttp.tapir.server.interceptor.reject.RejectInterceptor +import sttp.tapir.server.netty.NettyServerRequest +import sttp.tapir.server.interceptor.RequestResult +import sttp.capabilities.fs2.Fs2Streams +import sttp.tapir.server.netty.internal.RunAsync +import sttp.tapir.server.netty.internal._ trait NettyCatsServerInterpreter[F[_]] { implicit def async: Async[F] def nettyServerOptions: NettyCatsServerOptions[F] - def toRoute(ses: List[ServerEndpoint[Any, F]]): Route[F] = { + def toRoute(ses: List[ServerEndpoint[Fs2Streams[F], F]]): Route[F] = { + implicit val monad: MonadError[F] = new CatsMonadError[F] val runAsync = new RunAsync[F] { override def apply[T](f: => F[T]): Unit = nettyServerOptions.dispatcher.unsafeRunAndForget(f) } - NettyServerInterpreter.toRoute( - ses, - nettyServerOptions.interceptors, - nettyServerOptions.createFile, - nettyServerOptions.deleteFile, - runAsync + implicit val bodyListener: BodyListener[F, NettyResponse] = new NettyBodyListener(runAsync) + + val interceptors = nettyServerOptions.interceptors + val createFile = nettyServerOptions.createFile + val deleteFile = nettyServerOptions.deleteFile + + val serverInterpreter = new ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]]( + FilterServerEndpoints(ses), + new NettyCatsRequestBody(createFile), + new NettyCatsToResponseBody(nettyServerOptions.dispatcher), + RejectInterceptor.disableWhenSingleEndpoint(interceptors, ses), + deleteFile ) + + val handler: Route[F] = { (request: NettyServerRequest) => + serverInterpreter(request) + .map { + case RequestResult.Response(response) => Some(response) + case RequestResult.Failure(_) => None + } + } + + handler } } diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala new file mode 100644 index 0000000000..967e9db5ae --- /dev/null +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala @@ -0,0 +1,48 @@ +package sttp.tapir.server.netty.internal + +import cats.effect.{Async, Sync} +import cats.syntax.all._ +import io.netty.buffer.{ByteBufInputStream, ByteBufUtil} +import io.netty.handler.codec.http.FullHttpRequest +import sttp.capabilities.fs2.Fs2Streams +import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} +import sttp.tapir.model.ServerRequest +import sttp.tapir.server.interpreter.{RawValue, RequestBody} + +import java.nio.ByteBuffer +import java.nio.file.Files + +private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit val monad: Async[F]) + extends RequestBody[F, Fs2Streams[F]] { + + val streamChunkSize = 8192 + override val streams: Fs2Streams[F] = Fs2Streams[F] + + override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { + + /** [[ByteBufUtil.getBytes(io.netty.buffer.ByteBuf)]] copies buffer without affecting reader index of the original. */ + def requestContentAsByteArray = ByteBufUtil.getBytes(nettyRequest(serverRequest).content()) + + bodyType match { + case RawBodyType.StringBody(charset) => monad.delay(RawValue(nettyRequest(serverRequest).content().toString(charset))) + case RawBodyType.ByteArrayBody => monad.delay(RawValue(requestContentAsByteArray)) + case RawBodyType.ByteBufferBody => monad.delay(RawValue(ByteBuffer.wrap(requestContentAsByteArray))) + case RawBodyType.InputStreamBody => monad.delay(RawValue(new ByteBufInputStream(nettyRequest(serverRequest).content()))) + case RawBodyType.InputStreamRangeBody => + monad.delay(RawValue(InputStreamRange(() => new ByteBufInputStream(nettyRequest(serverRequest).content())))) + case RawBodyType.FileBody => + createFile(serverRequest) + .map(file => { + Files.write(file.toPath, requestContentAsByteArray) + RawValue(FileRange(file), Seq(FileRange(file))) + }) + case _: RawBodyType.MultipartBody => ??? + } + } + + override def toStream(serverRequest: ServerRequest): streams.BinaryStream = { + fs2.io.readInputStream(Sync[F].delay(new ByteBufInputStream(nettyRequest(serverRequest).content())), streamChunkSize) + } + + private def nettyRequest(serverRequest: ServerRequest): FullHttpRequest = serverRequest.underlying.asInstanceOf[FullHttpRequest] +} diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala new file mode 100644 index 0000000000..b5f1ee00ea --- /dev/null +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala @@ -0,0 +1,125 @@ +package sttp.tapir.server.netty.internal + +import cats.effect.kernel.Async +import cats.effect.std.Dispatcher +import fs2.interop.reactivestreams._ +import io.netty.buffer.Unpooled +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.stream.{ChunkedFile, ChunkedStream} +import sttp.model.HasHeaders +import sttp.tapir.capabilities.NoStreams +import sttp.tapir.server.interpreter.ToResponseBody +import sttp.tapir.server.netty.NettyResponse +import sttp.tapir.server.netty.NettyResponseContent.{ + ByteBufNettyResponseContent, + ChunkedFileNettyResponseContent, + ChunkedStreamNettyResponseContent +} +import sttp.tapir.{CodecFormat, FileRange, InputStreamRange, RawBodyType, WebSocketBodyOutput} + +import java.io.{InputStream, RandomAccessFile} +import java.nio.ByteBuffer +import java.nio.charset.Charset +import sttp.capabilities.fs2.Fs2Streams +import sttp.tapir.server.netty.NettyResponse +import sttp.tapir.server.netty.NettyResponseContent +import org.reactivestreams.Publisher +import io.netty.buffer.ByteBuf +import org.reactivestreams.Subscriber +import fs2.Chunk +import io.netty.handler.codec.http.HttpContent +import io.netty.handler.codec.http.DefaultHttpContent + +private[netty] class RangedChunkedStream(raw: InputStream, length: Long) extends ChunkedStream(raw) { + + override def isEndOfInput(): Boolean = + super.isEndOfInput || transferredBytes == length +} + +class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends ToResponseBody[NettyResponse, Fs2Streams[F]] { + override val streams: Fs2Streams[F] = Fs2Streams[F] + + override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { + bodyType match { + case RawBodyType.StringBody(charset) => + val bytes = v.asInstanceOf[String].getBytes(charset) + (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(bytes)) + + case RawBodyType.ByteArrayBody => + val bytes = v.asInstanceOf[Array[Byte]] + (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(bytes)) + + case RawBodyType.ByteBufferBody => + val byteBuffer = v.asInstanceOf[ByteBuffer] + (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(byteBuffer)) + + case RawBodyType.InputStreamBody => + (ctx: ChannelHandlerContext) => ChunkedStreamNettyResponseContent(ctx.newPromise(), wrap(v)) + + case RawBodyType.InputStreamRangeBody => + (ctx: ChannelHandlerContext) => ChunkedStreamNettyResponseContent(ctx.newPromise(), wrap(v)) + + case RawBodyType.FileBody => + (ctx: ChannelHandlerContext) => ChunkedFileNettyResponseContent(ctx.newPromise(), wrap(v)) + + case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException + } + } + + private def wrap(streamRange: InputStreamRange): ChunkedStream = { + streamRange.range + .map(r => new RangedChunkedStream(streamRange.inputStreamFromRangeStart(), r.contentLength)) + .getOrElse(new ChunkedStream(streamRange.inputStream())) + } + + private def wrap(content: InputStream): ChunkedStream = { + new ChunkedStream(content) + } + + private def wrap(content: FileRange): ChunkedFile = { + val file = content.file + val maybeRange = for { + range <- content.range + start <- range.start + end <- range.end + } yield (start, end + NettyToResponseBody.IncludingLastOffset) + + maybeRange match { + case Some((start, end)) => { + val randomAccessFile = new RandomAccessFile(file, NettyToResponseBody.ReadOnlyAccessMode) + new ChunkedFile(randomAccessFile, start, end - start, NettyToResponseBody.DefaultChunkSize) + } + case None => new ChunkedFile(file) + } + } + + def fs2StreamToPublisher(stream: streams.BinaryStream): Publisher[HttpContent] = { + // Deprecated constructor, but the proposed one does roughly the same, forcing a dedicated + // dispatcher, which results in a Resource[], which is hard to afford here + StreamUnicastPublisher( + stream.chunks + .map { chunk => + val bytes: Chunk.ArraySlice[Byte] = chunk.compact + new DefaultHttpContent(Unpooled.wrappedBuffer(bytes.values, bytes.offset, bytes.length)) + }, + dispatcher + ) + } + + override def fromStreamValue( + v: streams.BinaryStream, + headers: HasHeaders, + format: CodecFormat, + charset: Option[Charset] + ): NettyResponse = + (ctx: ChannelHandlerContext) => { + new NettyResponseContent.ReactivePublisherNettyResponseContent( + ctx.newPromise(), + fs2StreamToPublisher(v)) + } + + override def fromWebSocketPipe[REQ, RESP]( + pipe: streams.Pipe[REQ, RESP], + o: WebSocketBodyOutput[streams.Pipe[REQ, RESP], REQ, RESP, _, Fs2Streams[F]] + ): NettyResponse = throw new UnsupportedOperationException +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index b65e072d0b..6339c63544 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -1,5 +1,6 @@ package sttp.tapir.server.netty +import com.typesafe.netty.http.HttpStreamsServerHandler import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollServerSocketChannel} import io.netty.channel.kqueue.{KQueue, KQueueEventLoopGroup, KQueueServerSocketChannel} import io.netty.channel.nio.NioEventLoopGroup @@ -121,6 +122,7 @@ object NettyConfig { pipeline.addLast(new HttpServerCodec()) pipeline.addLast(new HttpObjectAggregator(cfg.maxContentLength)) pipeline.addLast(new ChunkedWriteHandler()) + pipeline.addLast("serverStreamsHandler", new HttpStreamsServerHandler()) pipeline.addLast(handler) if (cfg.addLoggingHandler) pipeline.addLast(new LoggingHandler()) () diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala index dc79aebae2..d3adfe3480 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyResponseContent.scala @@ -2,7 +2,9 @@ package sttp.tapir.server.netty import io.netty.buffer.ByteBuf import io.netty.channel.ChannelPromise +import io.netty.handler.codec.http.HttpContent import io.netty.handler.stream.{ChunkedFile, ChunkedStream} +import org.reactivestreams.Publisher sealed trait NettyResponseContent { def channelPromise: ChannelPromise @@ -13,4 +15,5 @@ object NettyResponseContent { final case class ChunkedStreamNettyResponseContent(channelPromise: ChannelPromise, chunkedStream: ChunkedStream) extends NettyResponseContent final case class ChunkedFileNettyResponseContent(channelPromise: ChannelPromise, chunkedFile: ChunkedFile) extends NettyResponseContent + final case class ReactivePublisherNettyResponseContent(channelPromise: ChannelPromise, publisher: Publisher[HttpContent]) extends NettyResponseContent } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 7842f6b3c6..1e42ddb856 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -5,17 +5,19 @@ import io.netty.buffer.{ByteBuf, Unpooled} import io.netty.channel._ import io.netty.handler.codec.http._ import io.netty.handler.stream.{ChunkedFile, ChunkedStream} +import org.reactivestreams.Publisher import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.tapir.server.model.ServerResponse -import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} import sttp.tapir.server.netty.NettyResponseContent.{ ByteBufNettyResponseContent, ChunkedFileNettyResponseContent, ChunkedStreamNettyResponseContent } +import sttp.tapir.server.netty.{NettyResponse, NettyResponseContent, NettyServerRequest, Route} import scala.collection.JavaConverters._ +import com.typesafe.netty.http.DefaultStreamedHttpResponse class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit)(implicit me: MonadError[F]) extends SimpleChannelInboundHandler[FullHttpRequest] { @@ -94,6 +96,15 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) // HttpChunkedInput will write the end marker (LastHttpContent) for us. ctx.writeAndFlush(new HttpChunkedInput(chunkedFile), channelPromise).closeIfNeeded(req) }, + reactiveStreamHandler = (channelPromise, publisher) => { + val resHeader: DefaultStreamedHttpResponse = new DefaultStreamedHttpResponse( + req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), publisher) + + resHeader.setHeadersFrom(serverResponse) + resHeader.handleContentLengthAndChunkedHeaders(None) + resHeader.handleCloseAndKeepAliveHeaders(req) + + }, noBodyHandler = () => { val res = new DefaultFullHttpResponse( req.protocolVersion(), @@ -115,6 +126,7 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) byteBufHandler: (ChannelPromise, ByteBuf) => Unit, chunkedStreamHandler: (ChannelPromise, ChunkedStream) => Unit, chunkedFileHandler: (ChannelPromise, ChunkedFile) => Unit, + reactiveStreamHandler: (ChannelPromise, Publisher[HttpContent]) => Unit, noBodyHandler: () => Unit ): Unit = { r.body match { @@ -125,6 +137,7 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) case r: ByteBufNettyResponseContent => byteBufHandler(r.channelPromise, r.byteBuf) case r: ChunkedStreamNettyResponseContent => chunkedStreamHandler(r.channelPromise, r.chunkedStream) case r: ChunkedFileNettyResponseContent => chunkedFileHandler(r.channelPromise, r.chunkedFile) + case r: NettyResponseContent.ReactivePublisherNettyResponseContent => reactiveStreamHandler(r.channelPromise, r.publisher) } } case None => noBodyHandler() diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala index 8a060c56cc..dd39ebef31 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyToResponseBody.scala @@ -95,8 +95,8 @@ class NettyToResponseBody extends ToResponseBody[NettyResponse, NoStreams] { ): NettyResponse = throw new UnsupportedOperationException } -object NettyToResponseBody { - private val DefaultChunkSize = 8192 - private val IncludingLastOffset = 1 - private val ReadOnlyAccessMode = "r" +private[internal] object NettyToResponseBody { + val DefaultChunkSize = 8192 + val IncludingLastOffset = 1 + val ReadOnlyAccessMode = "r" } From 2f35d51847ed56eae1e1a12875222f7631c71cc6 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Fri, 14 Jul 2023 23:29:48 +0200 Subject: [PATCH 02/18] wip --- .../netty/internal/NettyCatsRequestBody.scala | 53 +++-- .../internal/NettyCatsToResponseBody.scala | 12 +- .../netty/cats/NettyCatsServerTest.scala | 64 +++++- .../cats/NettyCatsTestServerInterpreter.scala | 5 +- .../sttp/tapir/server/netty/NettyConfig.scala | 6 +- .../server/netty/NettyServerRequest.scala | 12 +- .../netty/internal/NettyRequestBody.scala | 4 + .../netty/internal/NettyServerHandler.scala | 206 ++++++++++++++---- .../tapir/server/tests/CreateServerTest.scala | 1 + .../tapir/server/tests/ServerFilesTests.scala | 3 +- .../server/tests/ServerStreamingTests.scala | 6 + .../scala/sttp/tapir/tests/Streaming.scala | 5 + 12 files changed, 307 insertions(+), 70 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala index 967e9db5ae..e398360fdd 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala @@ -8,9 +8,15 @@ import sttp.capabilities.fs2.Fs2Streams import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} import sttp.tapir.model.ServerRequest import sttp.tapir.server.interpreter.{RawValue, RequestBody} - +import java.io.ByteArrayInputStream import java.nio.ByteBuffer -import java.nio.file.Files +import com.typesafe.netty.http.StreamedHttpRequest +import com.typesafe.netty.http.DefaultStreamedHttpRequest +import fs2.interop.reactivestreams.StreamSubscriber +import io.netty.handler.codec.http.HttpContent +import fs2.Chunk +import fs2.io.file.Files +import fs2.io.file.Path private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit val monad: Async[F]) extends RequestBody[F, Fs2Streams[F]] { @@ -20,29 +26,46 @@ private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[T override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { - /** [[ByteBufUtil.getBytes(io.netty.buffer.ByteBuf)]] copies buffer without affecting reader index of the original. */ - def requestContentAsByteArray = ByteBufUtil.getBytes(nettyRequest(serverRequest).content()) - bodyType match { - case RawBodyType.StringBody(charset) => monad.delay(RawValue(nettyRequest(serverRequest).content().toString(charset))) - case RawBodyType.ByteArrayBody => monad.delay(RawValue(requestContentAsByteArray)) - case RawBodyType.ByteBufferBody => monad.delay(RawValue(ByteBuffer.wrap(requestContentAsByteArray))) - case RawBodyType.InputStreamBody => monad.delay(RawValue(new ByteBufInputStream(nettyRequest(serverRequest).content()))) + case RawBodyType.StringBody(charset) => nettyRequestBytes(serverRequest).map(bs => RawValue(new String(bs, charset))) + case RawBodyType.ByteArrayBody => + nettyRequestBytes(serverRequest).map(RawValue(_)) + case RawBodyType.ByteBufferBody => + nettyRequestBytes(serverRequest).map(bs => RawValue(ByteBuffer.wrap(bs))) + case RawBodyType.InputStreamBody => + nettyRequestBytes(serverRequest).map(bs => RawValue(new ByteArrayInputStream(bs))) case RawBodyType.InputStreamRangeBody => - monad.delay(RawValue(InputStreamRange(() => new ByteBufInputStream(nettyRequest(serverRequest).content())))) + nettyRequestBytes(serverRequest).map(bs => RawValue(InputStreamRange(() => new ByteArrayInputStream(bs)))) case RawBodyType.FileBody => createFile(serverRequest) - .map(file => { - Files.write(file.toPath, requestContentAsByteArray) - RawValue(FileRange(file), Seq(FileRange(file))) + .flatMap(tapirFile => { + toStream(serverRequest) + .through( + Files[F](Files.forAsync[F]).writeAll(Path.fromNioPath(tapirFile.toPath)) + ) + .compile + .drain + .map(_ => RawValue(FileRange(tapirFile), Seq(FileRange(tapirFile)))) }) case _: RawBodyType.MultipartBody => ??? } } override def toStream(serverRequest: ServerRequest): streams.BinaryStream = { - fs2.io.readInputStream(Sync[F].delay(new ByteBufInputStream(nettyRequest(serverRequest).content())), streamChunkSize) + + val nettyRequest = serverRequest.underlying.asInstanceOf[StreamedHttpRequest] + + fs2.Stream + .eval(StreamSubscriber[F, HttpContent](NettyRequestBody.bufferSize)) + .flatMap(s => s.sub.stream(Sync[F].delay(nettyRequest.subscribe(s)))) + .flatMap(httpContent => fs2.Stream.chunk(Chunk.byteBuffer(httpContent.content.nioBuffer()))) + + // fs2.io.readInputStream(Sync[F].delay(new ByteBufInputStream(nettyRequest(serverRequest).content())), streamChunkSize) } - private def nettyRequest(serverRequest: ServerRequest): FullHttpRequest = serverRequest.underlying.asInstanceOf[FullHttpRequest] + private def nettyRequestBytes(serverRequest: ServerRequest): F[Array[Byte]] = serverRequest.underlying match { + case req: FullHttpRequest => monad.delay(ByteBufUtil.getBytes(req.content())) + case req: StreamedHttpRequest => toStream(serverRequest).compile.to(Chunk).map(_.toArray[Byte]) + case other => monad.raiseError(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) + } } diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala index b5f1ee00ea..5ad206b7f3 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala @@ -40,6 +40,7 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends To override val streams: Fs2Streams[F] = Fs2Streams[F] override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { + println(">>>>>>>>>>>>>>>>>>>>>>> Using fromRawValue") bodyType match { case RawBodyType.StringBody(charset) => val bytes = v.asInstanceOf[String].getBytes(charset) @@ -60,6 +61,7 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends To (ctx: ChannelHandlerContext) => ChunkedStreamNettyResponseContent(ctx.newPromise(), wrap(v)) case RawBodyType.FileBody => + println(s"Returning FileBody, headers = $headers") (ctx: ChannelHandlerContext) => ChunkedFileNettyResponseContent(ctx.newPromise(), wrap(v)) case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException @@ -89,11 +91,14 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends To val randomAccessFile = new RandomAccessFile(file, NettyToResponseBody.ReadOnlyAccessMode) new ChunkedFile(randomAccessFile, start, end - start, NettyToResponseBody.DefaultChunkSize) } - case None => new ChunkedFile(file) + case None => + println(s">>>>>>>>>>>>>>>>>>>>>>>>>> no range, $file") + new ChunkedFile(file) } } def fs2StreamToPublisher(stream: streams.BinaryStream): Publisher[HttpContent] = { + println(">>>>>>>>>>>>> handling stream and creating a publisher") // Deprecated constructor, but the proposed one does roughly the same, forcing a dedicated // dispatcher, which results in a Resource[], which is hard to afford here StreamUnicastPublisher( @@ -113,9 +118,8 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends To charset: Option[Charset] ): NettyResponse = (ctx: ChannelHandlerContext) => { - new NettyResponseContent.ReactivePublisherNettyResponseContent( - ctx.newPromise(), - fs2StreamToPublisher(v)) + println(">>>>>>>> Creating reactive stream from response") + new NettyResponseContent.ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(v)) } override def fromWebSocketPipe[REQ, RESP]( diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index 648fd8a69a..988cdf5c3c 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -1,15 +1,30 @@ package sttp.tapir.server.netty.cats +import cats.data.NonEmptyList +import cats.effect.unsafe.implicits.global import cats.effect.{IO, Resource} +import cats.syntax.all._ +import com.typesafe.scalalogging.StrictLogging import io.netty.channel.nio.NioEventLoopGroup import org.scalatest.EitherValues +import org.scalatest.matchers.should.Matchers +import sttp.capabilities.fs2.Fs2Streams +import sttp.client3._ import sttp.monad.MonadError +import sttp.tapir._ import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.server.netty.internal.FutureUtil import sttp.tapir.server.tests._ import sttp.tapir.tests.{Test, TestSuite} -class NettyCatsServerTest extends TestSuite with EitherValues { +import java.nio.charset.StandardCharsets +import scala.concurrent.duration._ +import fs2.io.file.Files +import java.nio.file.Paths +import fs2.io.file.Path +import scala.util.Random + +class NettyCatsServerTest extends TestSuite with EitherValues with StrictLogging with Matchers { override def tests: Resource[IO, List[Test]] = backendResource.flatMap { backend => Resource @@ -20,7 +35,52 @@ class NettyCatsServerTest extends TestSuite with EitherValues { val interpreter = new NettyCatsTestServerInterpreter(eventLoopGroup, dispatcher) val createServerTest = new DefaultCreateServerTest(backend, interpreter) - val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() + val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() + // + // val tests = new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() + + val s = Fs2Streams[IO] + val streamBody = streamTextBody(s)(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8)) + + val responseStr = "This is response text" + val responseText = responseStr.getBytes().toList + // val responseAsStream = Files[IO].readAll(Path.fromNioPath(Paths.get("/home/kc/LICENSE"))) + val responseAsStream: fs2.Stream[IO, Byte] = fs2.Stream.emits(responseText).covary[IO] + val testRoute = endpoint.post.in("kc").in(streamBody).out(streamBody).serverLogic[IO] { (inputStream) => + val sink = Files[IO].writeAll(Path.fromNioPath(Paths.get(s"./out-${Random.nextInt()}"))) + // val s2 = inputStream.through(sink).compile + // s2.drain.unsafeRunSync() + + IO.delay(Right(inputStream.map(_.toChar).map(_.toString).map(_.toUpperCase).map(_.head).map(_.toByte))) + // IO.delay(Right(fs2.Stream.emits(inputStream.getBytes().toList))) + } + + val route = interpreter.route(testRoute) + val rs = NonEmptyList.one(route) + val resources = for { + port <- interpreter.server(rs).onError { case e: Exception => + Resource.eval(IO(logger.error(s"Starting server failed because of ${e.getMessage}"))) + } + _ <- Resource.eval(IO(logger.info(s"Bound server on port: $port"))) + } yield port + + val requestStr = "Pen pineapple apple pen streaming req text" + val requestBytes = requestStr.getBytes().toList + val requestAsStream: fs2.Stream[IO, Byte] = fs2.Stream.emits(requestBytes).covary[IO] + val tests2 = List(Test("work!") { + resources + .use { port => + val baseUri = uri"http://localhost:$port" + // IO.sleep(30.seconds) >> + basicRequest + .post(uri"$baseUri/kc") + .streamBody(Fs2Streams[IO])(requestAsStream) + // .body("requestStr") + .send(backend) + .map(_.body shouldBe Right(responseStr)) + } + .unsafeToFuture() + }) IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala index f41d7dcbdb..0eb281c97a 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala @@ -8,10 +8,11 @@ import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.netty.{NettyConfig, Route} import sttp.tapir.server.tests.TestServerInterpreter import sttp.tapir.tests.Port +import sttp.capabilities.fs2.Fs2Streams class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatcher: Dispatcher[IO]) - extends TestServerInterpreter[IO, Any, NettyCatsServerOptions[IO], Route[IO]] { - override def route(es: List[ServerEndpoint[Any, IO]], interceptors: Interceptors): Route[IO] = { + extends TestServerInterpreter[IO, Fs2Streams[IO], NettyCatsServerOptions[IO], Route[IO]] { + override def route(es: List[ServerEndpoint[Fs2Streams[IO], IO]], interceptors: Interceptors): Route[IO] = { val serverOptions: NettyCatsServerOptions[IO] = interceptors( NettyCatsServerOptions.customiseInterceptors[IO](dispatcher) ).options diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index 6339c63544..a82b793ae3 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -14,6 +14,8 @@ import io.netty.handler.timeout.ReadTimeoutException import sttp.tapir.server.netty.NettyConfig.EventLoopConfig import scala.concurrent.duration._ +import com.typesafe.netty.http.HttpStreamsClientHandler +import sttp.tapir.server.netty.internal.NettyStreamingHandler /** Netty configuration, used by [[NettyFutureServer]] and other server implementations to configure the networking layer, the Netty * processing pipeline, and start & stop the server. @@ -119,10 +121,10 @@ object NettyConfig { def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) + println("--------------------------------------- default init pipeline") pipeline.addLast(new HttpServerCodec()) - pipeline.addLast(new HttpObjectAggregator(cfg.maxContentLength)) - pipeline.addLast(new ChunkedWriteHandler()) pipeline.addLast("serverStreamsHandler", new HttpStreamsServerHandler()) + pipeline.addLast(new ChunkedWriteHandler()) pipeline.addLast(handler) if (cfg.addLoggingHandler) pipeline.addLast(new LoggingHandler()) () diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyServerRequest.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyServerRequest.scala index bda70cd592..7b9be99466 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyServerRequest.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyServerRequest.scala @@ -2,13 +2,14 @@ package sttp.tapir.server.netty import scala.collection.JavaConverters._ import scala.collection.immutable.Seq -import io.netty.handler.codec.http.{FullHttpRequest, QueryStringDecoder} +import io.netty.handler.codec.http.{HttpRequest, QueryStringDecoder} import sttp.model.{Header, Method, QueryParams, Uri} import sttp.tapir.{AttributeKey, AttributeMap} import sttp.tapir.model.{ConnectionInfo, ServerRequest} import sttp.tapir.server.netty.internal.RichNettyHttpHeaders +import io.netty.handler.codec.http.FullHttpRequest -case class NettyServerRequest(req: FullHttpRequest, attributes: AttributeMap = AttributeMap.Empty) extends ServerRequest { +case class NettyServerRequest(req: HttpRequest, attributes: AttributeMap = AttributeMap.Empty) extends ServerRequest { override lazy val protocol: String = req.protocolVersion().text() override lazy val connectionInfo: ConnectionInfo = ConnectionInfo.NoInfo override lazy val underlying: Any = req @@ -25,9 +26,12 @@ case class NettyServerRequest(req: FullHttpRequest, attributes: AttributeMap = A override lazy val method: Method = Method.unsafeApply(req.method().name()) override lazy val uri: Uri = Uri.unsafeParse(req.uri()) override lazy val pathSegments: List[String] = uri.pathSegments.segments.map(_.v).filter(_.nonEmpty).toList - override lazy val headers: Seq[Header] = req.headers().toHeaderSeq ::: req.trailingHeaders().toHeaderSeq + override lazy val headers: Seq[Header] = req.headers().toHeaderSeq ::: (req match { + case full: FullHttpRequest => full.trailingHeaders().toHeaderSeq + case _ => List.empty + }) override def attribute[T](k: AttributeKey[T]): Option[T] = attributes.get(k) override def attribute[T](k: AttributeKey[T], v: T): NettyServerRequest = copy(attributes = attributes.put(k, v)) override def withUnderlying(underlying: Any): ServerRequest = - NettyServerRequest(req = underlying.asInstanceOf[FullHttpRequest], attributes) + NettyServerRequest(req = underlying.asInstanceOf[HttpRequest], attributes) } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala index e32c06fc46..4a9ab12c73 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala @@ -46,3 +46,7 @@ class NettyRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit private def nettyRequest(serverRequest: ServerRequest): FullHttpRequest = serverRequest.underlying.asInstanceOf[FullHttpRequest] } + +object NettyRequestBody { + private[internal] val bufferSize = 8192 +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 1e42ddb856..ce5bf08ce4 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -18,60 +18,106 @@ import sttp.tapir.server.netty.{NettyResponse, NettyResponseContent, NettyServer import scala.collection.JavaConverters._ import com.typesafe.netty.http.DefaultStreamedHttpResponse +import com.typesafe.netty.http.DelegateStreamedHttpRequest +import com.typesafe.netty.http.StreamedHttpRequest + +class NettyStreamingHandler[F[_]]() extends SimpleChannelInboundHandler[StreamedHttpRequest] { + + override def channelRead0(ctx: ChannelHandlerContext, request: StreamedHttpRequest): Unit = { + println(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Handling STREAMING request") + } +} class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit)(implicit me: MonadError[F]) - extends SimpleChannelInboundHandler[FullHttpRequest] { + extends SimpleChannelInboundHandler[HttpRequest] { private val logger = Logger[NettyServerHandler[F]] - override def channelRead0(ctx: ChannelHandlerContext, request: FullHttpRequest): Unit = { + override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { if (HttpUtil.is100ContinueExpected(request)) { ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE)) () } else { - val req = request.retain() + request match { + case full: FullHttpRequest => + val req = full.retain() + + unsafeRunAsync { () => + route(NettyServerRequest(req)) + .map { + case Some(response) => response + case None => ServerResponse.notFound + } + .flatMap((serverResponse: ServerResponse[NettyResponse]) => + // in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that + // we get the 500 response, instead of dropping the request + try handleResponse(ctx, req, serverResponse).unit + catch { + case e: Exception => me.error[Unit](e) + } + ) + .handleError { case ex: Exception => + logger.error("Error while processing the request", ex) + // send 500 + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + res.handleCloseAndKeepAliveHeaders(req) + + ctx.writeAndFlush(res).closeIfNeeded(req) + me.unit(()) + } + .ensure(me.eval(req.release())) + } // exceptions should be handled + case req: StreamedHttpRequest => + println(">>>>>>>>>>>>>>>>>>>>>>>>>>> Last handler receives a StreamedHttpRequest") + unsafeRunAsync { () => + route(NettyServerRequest(req)) + .map { + case Some(response) => response + case None => ServerResponse.notFound + } + .flatMap((serverResponse: ServerResponse[NettyResponse]) => + // in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that + // we get the 500 response, instead of dropping the request + try handleResponseStreamingReq(ctx, req, serverResponse).unit + catch { + case e: Exception => me.error[Unit](e) + } + ) + .handleError { case ex: Exception => + logger.error("Error while processing the request", ex) + // send 500 + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + + ctx.writeAndFlush(res) + me.unit(()) + } + } // exceptions should be handled + case _ => throw new UnsupportedOperationException(s"Unexpected Netty request type: ${request.getClass.getName}") + } - unsafeRunAsync { () => - route(NettyServerRequest(req)) - .map { - case Some(response) => response - case None => ServerResponse.notFound - } - .flatMap((serverResponse: ServerResponse[NettyResponse]) => - // in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that - // we get the 500 response, instead of dropping the request - try handleResponse(ctx, req, serverResponse).unit - catch { - case e: Exception => me.error[Unit](e) - } - ) - .handleError { case ex: Exception => - logger.error("Error while processing the request", ex) - // send 500 - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) - res.handleCloseAndKeepAliveHeaders(req) - - ctx.writeAndFlush(res).closeIfNeeded(req) - me.unit(()) - } - .ensure(me.eval(req.release())) - } // exceptions should be handled () } } - private def handleResponse(ctx: ChannelHandlerContext, req: FullHttpRequest, serverResponse: ServerResponse[NettyResponse]): Unit = + private def handleResponseStreamingReq( + ctx: ChannelHandlerContext, + req: StreamedHttpRequest, + serverResponse: ServerResponse[NettyResponse] + ): Unit = serverResponse.handle( ctx = ctx, byteBufHandler = (channelPromise, byteBuf) => { + println(s"Response code = ${serverResponse.code.code}") val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) + println(s">>>>>>>>>>>>>> byte buf handler, rc = ${serverResponse.code.code}") res.setHeadersFrom(serverResponse) res.handleContentLengthAndChunkedHeaders(Option(byteBuf.readableBytes())) - res.handleCloseAndKeepAliveHeaders(req) ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) + }, chunkedStreamHandler = (channelPromise, chunkedStream) => { val resHeader: DefaultHttpResponse = @@ -79,10 +125,10 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) resHeader.setHeadersFrom(serverResponse) resHeader.handleContentLengthAndChunkedHeaders(None) - resHeader.handleCloseAndKeepAliveHeaders(req) ctx.write(resHeader) ctx.writeAndFlush(new HttpChunkedInput(chunkedStream), channelPromise).closeIfNeeded(req) + }, chunkedFileHandler = (channelPromise, chunkedFile) => { val resHeader: DefaultHttpResponse = @@ -92,19 +138,98 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) resHeader.handleContentLengthAndChunkedHeaders(Option(chunkedFile.length())) resHeader.handleCloseAndKeepAliveHeaders(req) + ctx.write(resHeader) // HttpChunkedInput will write the end marker (LastHttpContent) for us. ctx.writeAndFlush(new HttpChunkedInput(chunkedFile), channelPromise).closeIfNeeded(req) + }, reactiveStreamHandler = (channelPromise, publisher) => { - val resHeader: DefaultStreamedHttpResponse = new DefaultStreamedHttpResponse( - req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), publisher) + println(">>>>>>>>>>>>>>>>>>>> stream handler") + val res: DefaultStreamedHttpResponse = + new DefaultStreamedHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), publisher) + + println(s">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Flushing $res") + res.setHeadersFrom(serverResponse) + res.handleCloseAndKeepAliveHeaders(req) + res.handleContentLengthAndChunkedHeaders(None) + ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) + + + }, + noBodyHandler = () => { + val res = new DefaultFullHttpResponse( + req.protocolVersion(), + HttpResponseStatus.valueOf(serverResponse.code.code), + Unpooled.EMPTY_BUFFER + ) + + res.handleCloseAndKeepAliveHeaders(req) + res.setHeadersFrom(serverResponse) + res.handleContentLengthAndChunkedHeaders(Option(Unpooled.EMPTY_BUFFER.readableBytes())) + + ctx.writeAndFlush(res).closeIfNeeded(req) + + } + ) + private def handleResponse(ctx: ChannelHandlerContext, req: FullHttpRequest, serverResponse: ServerResponse[NettyResponse]): Unit = + serverResponse.handle( + ctx = ctx, + byteBufHandler = (channelPromise, byteBuf) => { + val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) + + println(s">>>>>>>>>>>>>> byte buf handler, rc = ${serverResponse.code.code}") + println(s">>>>>>>>>>>>>> byte buf handler, req = ${req}") + res.setHeadersFrom(serverResponse) + res.handleContentLengthAndChunkedHeaders(Option(byteBuf.readableBytes())) + res.handleCloseAndKeepAliveHeaders(req) + + ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) + }, + chunkedStreamHandler = (channelPromise, chunkedStream) => { + val resHeader: DefaultHttpResponse = + new DefaultHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code)) resHeader.setHeadersFrom(serverResponse) resHeader.handleContentLengthAndChunkedHeaders(None) resHeader.handleCloseAndKeepAliveHeaders(req) - }, + ctx.write(resHeader) + ctx.writeAndFlush(new HttpChunkedInput(chunkedStream), channelPromise).closeIfNeeded(req) + }, + chunkedFileHandler = (channelPromise, chunkedFile) => { + println(s">>>>>>>>>>>>>>>>>>>>>>> Handling chunked file with FullHttpRequest, ${req}") + println(s">>>>>>>>>>>>>>>>>>>> $chunkedFile") + val resHeader: DefaultHttpResponse = + new DefaultHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code)) + + resHeader.setHeadersFrom(serverResponse) + println(s">>>>>>>>>>>>>>>> response headers: ${serverResponse.headers}") + resHeader.handleContentLengthAndChunkedHeaders(Option(chunkedFile.length())) + resHeader.handleCloseAndKeepAliveHeaders(req) + println(">>>>>>>> writing resp header") + ctx.write(resHeader) + println(">>>>>>>> flushing") + // HttpChunkedInput will write the end marker (LastHttpContent) for us. + ctx.writeAndFlush(new HttpChunkedInput(chunkedFile), channelPromise).closeIfNeeded(req) + println(">>>>>>>> flushed") + }, + reactiveStreamHandler = (channelPromise, publisher) => { + println(">>>>>>>>>>>>>>>>>>>> stream handler") + val res: DefaultStreamedHttpResponse = + new DefaultStreamedHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), publisher) + + // val v = "TEST RESPONSE" + // val bytes = v.asInstanceOf[String].getBytes("UTF-8") + // val byteBuf = Unpooled.wrappedBuffer(bytes) + // val res2 = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) + println(s">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Flushing $res") + res.setHeadersFrom(serverResponse) + res.handleContentLengthAndChunkedHeaders(None) + res.handleCloseAndKeepAliveHeaders(req) + ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) + + }, noBodyHandler = () => { val res = new DefaultFullHttpResponse( req.protocolVersion(), @@ -134,9 +259,9 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) val values = function(ctx) values match { - case r: ByteBufNettyResponseContent => byteBufHandler(r.channelPromise, r.byteBuf) - case r: ChunkedStreamNettyResponseContent => chunkedStreamHandler(r.channelPromise, r.chunkedStream) - case r: ChunkedFileNettyResponseContent => chunkedFileHandler(r.channelPromise, r.chunkedFile) + case r: ByteBufNettyResponseContent => byteBufHandler(r.channelPromise, r.byteBuf) + case r: ChunkedStreamNettyResponseContent => chunkedStreamHandler(r.channelPromise, r.chunkedStream) + case r: ChunkedFileNettyResponseContent => chunkedFileHandler(r.channelPromise, r.chunkedFile) case r: NettyResponseContent.ReactivePublisherNettyResponseContent => reactiveStreamHandler(r.channelPromise, r.publisher) } } @@ -158,11 +283,12 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) val lengthKnownAndShouldBeSet = !m.headers().contains(HttpHeaderNames.CONTENT_LENGTH) && length.nonEmpty val lengthUnknownAndChunkedShouldBeUsed = !m.headers().contains(HttpHeaderNames.CONTENT_LENGTH) && length.isEmpty + println(s"l1 = $lengthKnownAndShouldBeSet, l2 = $lengthUnknownAndChunkedShouldBeUsed") if (lengthKnownAndShouldBeSet) { length.map { l => m.headers().set(HttpHeaderNames.CONTENT_LENGTH, l) } } if (lengthUnknownAndChunkedShouldBeUsed) { m.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) } } - def handleCloseAndKeepAliveHeaders(request: FullHttpRequest): Unit = { + def handleCloseAndKeepAliveHeaders(request: HttpRequest): Unit = { if (!HttpUtil.isKeepAlive(request)) m.headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) else if (request.protocolVersion.equals(HttpVersion.HTTP_1_0)) @@ -171,7 +297,7 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) } private implicit class RichChannelFuture(val cf: ChannelFuture) { - def closeIfNeeded(request: FullHttpRequest): Unit = { + def closeIfNeeded(request: HttpRequest): Unit = { if (!HttpUtil.isKeepAlive(request)) { cf.addListener(ChannelFutureListener.CLOSE) } diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala index 950d018d67..ddbf2854b2 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala @@ -79,6 +79,7 @@ class DefaultCreateServerTest[F[_], +R, OPTIONS, ROUTE]( _ <- Resource.eval(IO(logger.info(s"Bound server on port: $port"))) } yield port + import scala.concurrent.duration._ Test(name)( resources .use { port => diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerFilesTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerFilesTests.scala index d67690407f..ee428b4f6a 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerFilesTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerFilesTests.scala @@ -218,8 +218,9 @@ class ServerFilesTests[F[_], OPTIONS, ROUTE]( }, Test("should return 200 status code for whole file") { withTestFilesDirectory { testDir => + import scala.concurrent.duration._ serveRoute(staticFilesGetServerEndpoint[F]("test")(testDir.getAbsolutePath)) - .use { port => get(port, List("test", "f2")).map(_.code shouldBe StatusCode.Ok) } + .use { port => get(port, List("test", "f2")).timeout(10.seconds).map(_.code shouldBe StatusCode.Ok) } .unsafeToFuture() } }, diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala index 3369f6c009..c0f3601e59 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala @@ -14,6 +14,9 @@ import sttp.tapir.tests.Streaming.{ in_string_stream_out_either_stream_string, out_custom_content_type_stream_body } +import scala.concurrent.duration._ +import java.io.ByteArrayInputStream +import java.io.InputStream class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServerTest[F, S, OPTIONS, ROUTE], streams: Streams[S])(implicit m: MonadError[F] @@ -33,6 +36,7 @@ class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServ )((in: (Long, streams.BinaryStream)) => pureResult(in.asRight[Unit])) { (backend, baseUri) => { basicRequest + .readTimeout(3.seconds) .post(uri"$baseUri/api/echo") .contentLength(penPineapple.length.toLong) .body(penPineapple) @@ -52,6 +56,7 @@ class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServ } { (backend, baseUri) => basicRequest .post(uri"$baseUri?kind=-1") + .readTimeout(3.seconds) .body(penPineapple) .send(backend) .map { r => @@ -85,6 +90,7 @@ class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServ testServer(in_stream_out_either_json_xml_stream(streams)) { s => pureResult(s.asRight[Unit]) } { (backend, baseUri) => basicRequest .post(uri"$baseUri") + .readTimeout(3.seconds) .body(penPineapple) .header(Header.accept(MediaType.ApplicationXml, MediaType.ApplicationJson)) .send(backend) diff --git a/tests/src/main/scala/sttp/tapir/tests/Streaming.scala b/tests/src/main/scala/sttp/tapir/tests/Streaming.scala index 91b5a59c18..33d53aedb2 100644 --- a/tests/src/main/scala/sttp/tapir/tests/Streaming.scala +++ b/tests/src/main/scala/sttp/tapir/tests/Streaming.scala @@ -7,6 +7,11 @@ import sttp.tapir._ import java.nio.charset.StandardCharsets object Streaming { + def in_string_out_string_stream[S](s: Streams[S]) = { + val sb = streamTextBody(s)(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8)) + endpoint.post.in("api" / "echo").in(stringBody).out(sb) + } + def in_stream_out_stream[S](s: Streams[S]): PublicEndpoint[s.BinaryStream, Unit, s.BinaryStream, S] = { val sb = streamTextBody(s)(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8)) endpoint.post.in("api" / "echo").in(sb).out(sb) From c04918b85088bfc77307e08964aac74c87d02ee1 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Sat, 15 Jul 2023 02:17:42 +0200 Subject: [PATCH 03/18] Streaming for Netty Cats server --- .../server/netty/cats/NettyCatsServer.scala | 4 +- .../internal/NettyCatsToResponseBody.scala | 82 ++++++---------- .../netty/cats/NettyCatsServerTest.scala | 59 +---------- .../cats/NettyCatsTestServerInterpreter.scala | 2 +- .../sttp/tapir/server/netty/NettyConfig.scala | 16 ++- .../netty/internal/NettyServerHandler.scala | 98 +------------------ 6 files changed, 50 insertions(+), 211 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 39f0ba1826..152fba1328 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -81,9 +81,9 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty object NettyCatsServer { def apply[F[_]: Async](dispatcher: Dispatcher[F]): NettyCatsServer[F] = - NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), NettyConfig.default) + NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), NettyConfig.defaultWithStreaming) def apply[F[_]: Async](options: NettyCatsServerOptions[F]): NettyCatsServer[F] = - NettyCatsServer(Vector.empty, options, NettyConfig.default) + NettyCatsServer(Vector.empty, options, NettyConfig.defaultWithStreaming) def apply[F[_]: Async](dispatcher: Dispatcher[F], config: NettyConfig): NettyCatsServer[F] = NettyCatsServer(Vector.empty, NettyCatsServerOptions.default(dispatcher), config) def apply[F[_]: Async](options: NettyCatsServerOptions[F], config: NettyConfig): NettyCatsServer[F] = diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala index 5ad206b7f3..ed4a2b14a8 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala @@ -1,34 +1,25 @@ package sttp.tapir.server.netty.internal -import cats.effect.kernel.Async +import cats.effect.kernel.{Async, Sync} import cats.effect.std.Dispatcher +import fs2.Chunk import fs2.interop.reactivestreams._ +import fs2.io.file.{Files, Flags, Path} import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandlerContext -import io.netty.handler.stream.{ChunkedFile, ChunkedStream} +import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} +import io.netty.handler.stream.ChunkedStream +import org.reactivestreams.Publisher +import sttp.capabilities.fs2.Fs2Streams import sttp.model.HasHeaders -import sttp.tapir.capabilities.NoStreams import sttp.tapir.server.interpreter.ToResponseBody -import sttp.tapir.server.netty.NettyResponse -import sttp.tapir.server.netty.NettyResponseContent.{ - ByteBufNettyResponseContent, - ChunkedFileNettyResponseContent, - ChunkedStreamNettyResponseContent -} -import sttp.tapir.{CodecFormat, FileRange, InputStreamRange, RawBodyType, WebSocketBodyOutput} +import sttp.tapir.server.netty.NettyResponseContent.ByteBufNettyResponseContent +import sttp.tapir.server.netty.{NettyResponse, NettyResponseContent} +import sttp.tapir.{CodecFormat, InputStreamRange, RawBodyType, WebSocketBodyOutput} -import java.io.{InputStream, RandomAccessFile} +import java.io.InputStream import java.nio.ByteBuffer import java.nio.charset.Charset -import sttp.capabilities.fs2.Fs2Streams -import sttp.tapir.server.netty.NettyResponse -import sttp.tapir.server.netty.NettyResponseContent -import org.reactivestreams.Publisher -import io.netty.buffer.ByteBuf -import org.reactivestreams.Subscriber -import fs2.Chunk -import io.netty.handler.codec.http.HttpContent -import io.netty.handler.codec.http.DefaultHttpContent private[netty] class RangedChunkedStream(raw: InputStream, length: Long) extends ChunkedStream(raw) { @@ -40,7 +31,6 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends To override val streams: Fs2Streams[F] = Fs2Streams[F] override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { - println(">>>>>>>>>>>>>>>>>>>>>>> Using fromRawValue") bodyType match { case RawBodyType.StringBody(charset) => val bytes = v.asInstanceOf[String].getBytes(charset) @@ -55,50 +45,43 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends To (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(byteBuffer)) case RawBodyType.InputStreamBody => - (ctx: ChannelHandlerContext) => ChunkedStreamNettyResponseContent(ctx.newPromise(), wrap(v)) + val stream = inputStreamToFs2(() => v) + (ctx: ChannelHandlerContext) => + new NettyResponseContent.ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) case RawBodyType.InputStreamRangeBody => - (ctx: ChannelHandlerContext) => ChunkedStreamNettyResponseContent(ctx.newPromise(), wrap(v)) + val stream = v.range + .map(range => inputStreamToFs2(v.inputStreamFromRangeStart).take(range.contentLength)) + .getOrElse(inputStreamToFs2(v.inputStream)) + (ctx: ChannelHandlerContext) => + new NettyResponseContent.ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) case RawBodyType.FileBody => - println(s"Returning FileBody, headers = $headers") - (ctx: ChannelHandlerContext) => ChunkedFileNettyResponseContent(ctx.newPromise(), wrap(v)) + val tapirFile = v + val path = Path.fromNioPath(tapirFile.file.toPath) + val stream = tapirFile.range + .flatMap(r => r.startAndEnd.map(s => Files[F](Files.forAsync[F]).readRange(path, r.contentLength.toInt, s._1, s._2))) + .getOrElse(Files[F](Files.forAsync[F]).readAll(path, NettyToResponseBody.DefaultChunkSize, Flags.Read)) + + (ctx: ChannelHandlerContext) => + new NettyResponseContent.ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException } } + private def inputStreamToFs2(inputStream: () => InputStream) = + fs2.io.readInputStream( + Sync[F].blocking(inputStream()), + NettyToResponseBody.DefaultChunkSize + ) private def wrap(streamRange: InputStreamRange): ChunkedStream = { streamRange.range .map(r => new RangedChunkedStream(streamRange.inputStreamFromRangeStart(), r.contentLength)) .getOrElse(new ChunkedStream(streamRange.inputStream())) } - private def wrap(content: InputStream): ChunkedStream = { - new ChunkedStream(content) - } - - private def wrap(content: FileRange): ChunkedFile = { - val file = content.file - val maybeRange = for { - range <- content.range - start <- range.start - end <- range.end - } yield (start, end + NettyToResponseBody.IncludingLastOffset) - - maybeRange match { - case Some((start, end)) => { - val randomAccessFile = new RandomAccessFile(file, NettyToResponseBody.ReadOnlyAccessMode) - new ChunkedFile(randomAccessFile, start, end - start, NettyToResponseBody.DefaultChunkSize) - } - case None => - println(s">>>>>>>>>>>>>>>>>>>>>>>>>> no range, $file") - new ChunkedFile(file) - } - } - def fs2StreamToPublisher(stream: streams.BinaryStream): Publisher[HttpContent] = { - println(">>>>>>>>>>>>> handling stream and creating a publisher") // Deprecated constructor, but the proposed one does roughly the same, forcing a dedicated // dispatcher, which results in a Resource[], which is hard to afford here StreamUnicastPublisher( @@ -118,7 +101,6 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends To charset: Option[Charset] ): NettyResponse = (ctx: ChannelHandlerContext) => { - println(">>>>>>>> Creating reactive stream from response") new NettyResponseContent.ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(v)) } diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index 988cdf5c3c..5c4a81acd2 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -1,29 +1,17 @@ package sttp.tapir.server.netty.cats -import cats.data.NonEmptyList -import cats.effect.unsafe.implicits.global import cats.effect.{IO, Resource} -import cats.syntax.all._ import com.typesafe.scalalogging.StrictLogging import io.netty.channel.nio.NioEventLoopGroup import org.scalatest.EitherValues import org.scalatest.matchers.should.Matchers import sttp.capabilities.fs2.Fs2Streams -import sttp.client3._ import sttp.monad.MonadError -import sttp.tapir._ import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.server.netty.internal.FutureUtil import sttp.tapir.server.tests._ import sttp.tapir.tests.{Test, TestSuite} -import java.nio.charset.StandardCharsets -import scala.concurrent.duration._ -import fs2.io.file.Files -import java.nio.file.Paths -import fs2.io.file.Path -import scala.util.Random - class NettyCatsServerTest extends TestSuite with EitherValues with StrictLogging with Matchers { override def tests: Resource[IO, List[Test]] = backendResource.flatMap { backend => @@ -35,52 +23,7 @@ class NettyCatsServerTest extends TestSuite with EitherValues with StrictLogging val interpreter = new NettyCatsTestServerInterpreter(eventLoopGroup, dispatcher) val createServerTest = new DefaultCreateServerTest(backend, interpreter) - val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() - // - // val tests = new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() - - val s = Fs2Streams[IO] - val streamBody = streamTextBody(s)(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8)) - - val responseStr = "This is response text" - val responseText = responseStr.getBytes().toList - // val responseAsStream = Files[IO].readAll(Path.fromNioPath(Paths.get("/home/kc/LICENSE"))) - val responseAsStream: fs2.Stream[IO, Byte] = fs2.Stream.emits(responseText).covary[IO] - val testRoute = endpoint.post.in("kc").in(streamBody).out(streamBody).serverLogic[IO] { (inputStream) => - val sink = Files[IO].writeAll(Path.fromNioPath(Paths.get(s"./out-${Random.nextInt()}"))) - // val s2 = inputStream.through(sink).compile - // s2.drain.unsafeRunSync() - - IO.delay(Right(inputStream.map(_.toChar).map(_.toString).map(_.toUpperCase).map(_.head).map(_.toByte))) - // IO.delay(Right(fs2.Stream.emits(inputStream.getBytes().toList))) - } - - val route = interpreter.route(testRoute) - val rs = NonEmptyList.one(route) - val resources = for { - port <- interpreter.server(rs).onError { case e: Exception => - Resource.eval(IO(logger.error(s"Starting server failed because of ${e.getMessage}"))) - } - _ <- Resource.eval(IO(logger.info(s"Bound server on port: $port"))) - } yield port - - val requestStr = "Pen pineapple apple pen streaming req text" - val requestBytes = requestStr.getBytes().toList - val requestAsStream: fs2.Stream[IO, Byte] = fs2.Stream.emits(requestBytes).covary[IO] - val tests2 = List(Test("work!") { - resources - .use { port => - val baseUri = uri"http://localhost:$port" - // IO.sleep(30.seconds) >> - basicRequest - .post(uri"$baseUri/kc") - .streamBody(Fs2Streams[IO])(requestAsStream) - // .body("requestStr") - .send(backend) - .map(_.body shouldBe Right(responseStr)) - } - .unsafeToFuture() - }) + val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++ new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala index 0eb281c97a..63968fa1ce 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala @@ -20,7 +20,7 @@ class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatch } override def server(routes: NonEmptyList[Route[IO]]): Resource[IO, Port] = { - val config = NettyConfig.default.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose + val config = NettyConfig.defaultWithStreaming.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose val options = NettyCatsServerOptions.default[IO](dispatcher) val bind: IO[NettyCatsServerBinding[IO]] = NettyCatsServer(options, config).addRoutes(routes.toList).start() diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index a82b793ae3..e0733cccfe 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -14,8 +14,6 @@ import io.netty.handler.timeout.ReadTimeoutException import sttp.tapir.server.netty.NettyConfig.EventLoopConfig import scala.concurrent.duration._ -import com.typesafe.netty.http.HttpStreamsClientHandler -import sttp.tapir.server.netty.internal.NettyStreamingHandler /** Netty configuration, used by [[NettyFutureServer]] and other server implementations to configure the networking layer, the Netty * processing pipeline, and start & stop the server. @@ -98,6 +96,7 @@ case class NettyConfig( def eventLoopGroup(elg: EventLoopGroup): NettyConfig = copy(eventLoopConfig = EventLoopConfig.useExisting(elg)) def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f) + } object NettyConfig { @@ -121,15 +120,24 @@ object NettyConfig { def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) - println("--------------------------------------- default init pipeline") pipeline.addLast(new HttpServerCodec()) - pipeline.addLast("serverStreamsHandler", new HttpStreamsServerHandler()) + pipeline.addLast(new HttpObjectAggregator(cfg.maxContentLength)) pipeline.addLast(new ChunkedWriteHandler()) pipeline.addLast(handler) + () + } + + def streamingPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { + cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) + pipeline.addLast(new HttpServerCodec()) + pipeline.addLast(new HttpStreamsServerHandler()) + pipeline.addLast(handler) if (cfg.addLoggingHandler) pipeline.addLast(new LoggingHandler()) () } + def defaultWithStreaming: NettyConfig = default.copy(initPipeline = cfg => streamingPipeline(cfg)(_, _)) + case class EventLoopConfig(initEventLoopGroup: () => EventLoopGroup, serverChannel: Class[_ <: ServerChannel]) object EventLoopConfig { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index ce5bf08ce4..09025dfa08 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -21,13 +21,6 @@ import com.typesafe.netty.http.DefaultStreamedHttpResponse import com.typesafe.netty.http.DelegateStreamedHttpRequest import com.typesafe.netty.http.StreamedHttpRequest -class NettyStreamingHandler[F[_]]() extends SimpleChannelInboundHandler[StreamedHttpRequest] { - - override def channelRead0(ctx: ChannelHandlerContext, request: StreamedHttpRequest): Unit = { - println(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Handling STREAMING request") - } -} - class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit)(implicit me: MonadError[F]) extends SimpleChannelInboundHandler[HttpRequest] { @@ -69,7 +62,6 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) .ensure(me.eval(req.release())) } // exceptions should be handled case req: StreamedHttpRequest => - println(">>>>>>>>>>>>>>>>>>>>>>>>>>> Last handler receives a StreamedHttpRequest") unsafeRunAsync { () => route(NettyServerRequest(req)) .map { @@ -79,7 +71,7 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) .flatMap((serverResponse: ServerResponse[NettyResponse]) => // in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that // we get the 500 response, instead of dropping the request - try handleResponseStreamingReq(ctx, req, serverResponse).unit + try handleResponse(ctx, req, serverResponse).unit catch { case e: Exception => me.error[Unit](e) } @@ -101,85 +93,12 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) } } - private def handleResponseStreamingReq( - ctx: ChannelHandlerContext, - req: StreamedHttpRequest, - serverResponse: ServerResponse[NettyResponse] - ): Unit = - serverResponse.handle( - ctx = ctx, - byteBufHandler = (channelPromise, byteBuf) => { - println(s"Response code = ${serverResponse.code.code}") - val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) - - println(s">>>>>>>>>>>>>> byte buf handler, rc = ${serverResponse.code.code}") - res.setHeadersFrom(serverResponse) - res.handleContentLengthAndChunkedHeaders(Option(byteBuf.readableBytes())) - - ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) - - }, - chunkedStreamHandler = (channelPromise, chunkedStream) => { - val resHeader: DefaultHttpResponse = - new DefaultHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code)) - - resHeader.setHeadersFrom(serverResponse) - resHeader.handleContentLengthAndChunkedHeaders(None) - - ctx.write(resHeader) - ctx.writeAndFlush(new HttpChunkedInput(chunkedStream), channelPromise).closeIfNeeded(req) - - }, - chunkedFileHandler = (channelPromise, chunkedFile) => { - val resHeader: DefaultHttpResponse = - new DefaultHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code)) - - resHeader.setHeadersFrom(serverResponse) - resHeader.handleContentLengthAndChunkedHeaders(Option(chunkedFile.length())) - resHeader.handleCloseAndKeepAliveHeaders(req) - - - ctx.write(resHeader) - // HttpChunkedInput will write the end marker (LastHttpContent) for us. - ctx.writeAndFlush(new HttpChunkedInput(chunkedFile), channelPromise).closeIfNeeded(req) - - }, - reactiveStreamHandler = (channelPromise, publisher) => { - println(">>>>>>>>>>>>>>>>>>>> stream handler") - val res: DefaultStreamedHttpResponse = - new DefaultStreamedHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), publisher) - - println(s">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Flushing $res") - res.setHeadersFrom(serverResponse) - res.handleCloseAndKeepAliveHeaders(req) - res.handleContentLengthAndChunkedHeaders(None) - ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) - - - }, - noBodyHandler = () => { - val res = new DefaultFullHttpResponse( - req.protocolVersion(), - HttpResponseStatus.valueOf(serverResponse.code.code), - Unpooled.EMPTY_BUFFER - ) - - res.handleCloseAndKeepAliveHeaders(req) - res.setHeadersFrom(serverResponse) - res.handleContentLengthAndChunkedHeaders(Option(Unpooled.EMPTY_BUFFER.readableBytes())) - - ctx.writeAndFlush(res).closeIfNeeded(req) - - } - ) - private def handleResponse(ctx: ChannelHandlerContext, req: FullHttpRequest, serverResponse: ServerResponse[NettyResponse]): Unit = + private def handleResponse(ctx: ChannelHandlerContext, req: HttpRequest, serverResponse: ServerResponse[NettyResponse]): Unit = serverResponse.handle( ctx = ctx, byteBufHandler = (channelPromise, byteBuf) => { val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) - println(s">>>>>>>>>>>>>> byte buf handler, rc = ${serverResponse.code.code}") - println(s">>>>>>>>>>>>>> byte buf handler, req = ${req}") res.setHeadersFrom(serverResponse) res.handleContentLengthAndChunkedHeaders(Option(byteBuf.readableBytes())) res.handleCloseAndKeepAliveHeaders(req) @@ -198,32 +117,20 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) ctx.writeAndFlush(new HttpChunkedInput(chunkedStream), channelPromise).closeIfNeeded(req) }, chunkedFileHandler = (channelPromise, chunkedFile) => { - println(s">>>>>>>>>>>>>>>>>>>>>>> Handling chunked file with FullHttpRequest, ${req}") - println(s">>>>>>>>>>>>>>>>>>>> $chunkedFile") val resHeader: DefaultHttpResponse = new DefaultHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code)) resHeader.setHeadersFrom(serverResponse) - println(s">>>>>>>>>>>>>>>> response headers: ${serverResponse.headers}") resHeader.handleContentLengthAndChunkedHeaders(Option(chunkedFile.length())) resHeader.handleCloseAndKeepAliveHeaders(req) - println(">>>>>>>> writing resp header") ctx.write(resHeader) - println(">>>>>>>> flushing") // HttpChunkedInput will write the end marker (LastHttpContent) for us. ctx.writeAndFlush(new HttpChunkedInput(chunkedFile), channelPromise).closeIfNeeded(req) - println(">>>>>>>> flushed") }, reactiveStreamHandler = (channelPromise, publisher) => { - println(">>>>>>>>>>>>>>>>>>>> stream handler") val res: DefaultStreamedHttpResponse = new DefaultStreamedHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), publisher) - // val v = "TEST RESPONSE" - // val bytes = v.asInstanceOf[String].getBytes("UTF-8") - // val byteBuf = Unpooled.wrappedBuffer(bytes) - // val res2 = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) - println(s">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Flushing $res") res.setHeadersFrom(serverResponse) res.handleContentLengthAndChunkedHeaders(None) res.handleCloseAndKeepAliveHeaders(req) @@ -283,7 +190,6 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) val lengthKnownAndShouldBeSet = !m.headers().contains(HttpHeaderNames.CONTENT_LENGTH) && length.nonEmpty val lengthUnknownAndChunkedShouldBeUsed = !m.headers().contains(HttpHeaderNames.CONTENT_LENGTH) && length.isEmpty - println(s"l1 = $lengthKnownAndShouldBeSet, l2 = $lengthUnknownAndChunkedShouldBeUsed") if (lengthKnownAndShouldBeSet) { length.map { l => m.headers().set(HttpHeaderNames.CONTENT_LENGTH, l) } } if (lengthUnknownAndChunkedShouldBeUsed) { m.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) } } From 305253b8d2e19d22224fdd627dcbe919dafd9283 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Sat, 15 Jul 2023 02:48:19 +0200 Subject: [PATCH 04/18] Adjust addEndpoint to accept streaming endpoints --- .../sttp/tapir/server/netty/cats/NettyCatsServer.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 152fba1328..7566146933 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -15,15 +15,16 @@ import sttp.tapir.server.netty.{NettyConfig, Route} import java.net.{InetSocketAddress, SocketAddress} import java.nio.file.{Path, Paths} import java.util.UUID +import sttp.capabilities.fs2.Fs2Streams case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: NettyCatsServerOptions[F], config: NettyConfig) { - def addEndpoint(se: ServerEndpoint[Any, F]): NettyCatsServer[F] = addEndpoints(List(se)) - def addEndpoint(se: ServerEndpoint[Any, F], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = + def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F]): NettyCatsServer[F] = addEndpoints(List(se)) + def addEndpoint(se: ServerEndpoint[Fs2Streams[F], F], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addEndpoints(List(se), overrideOptions) - def addEndpoints(ses: List[ServerEndpoint[Any, F]]): NettyCatsServer[F] = addRoute( + def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]]): NettyCatsServer[F] = addRoute( NettyCatsServerInterpreter(options).toRoute(ses) ) - def addEndpoints(ses: List[ServerEndpoint[Any, F]], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addRoute( + def addEndpoints(ses: List[ServerEndpoint[Fs2Streams[F], F]], overrideOptions: NettyCatsServerOptions[F]): NettyCatsServer[F] = addRoute( NettyCatsServerInterpreter(overrideOptions).toRoute(ses) ) From 24e64e747b92440fe3df34decdae748c5890f6c5 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Sat, 15 Jul 2023 02:58:47 +0200 Subject: [PATCH 05/18] Mention streaming in docs --- doc/server/netty.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/doc/server/netty.md b/doc/server/netty.md index f253265f87..eb6cc831f3 100644 --- a/doc/server/netty.md +++ b/doc/server/netty.md @@ -15,11 +15,11 @@ To expose an endpoint using a [Netty](https://netty.io)-based server, first add Then, use: -* `NettyFutureServer().addEndpoints` to expose `Future`-based server endpoints. -* `NettyCatsServer().addEndpoints` to expose `F`-based server endpoints, where `F` is any cats-effect supported effect. -* `NettyZioServer().addEndpoints` to expose `ZIO`-based server endpoints, where `R` represents ZIO requirements supported effect. +- `NettyFutureServer().addEndpoints` to expose `Future`-based server endpoints. +- `NettyCatsServer().addEndpoints` to expose `F`-based server endpoints, where `F` is any cats-effect supported effect. [Streaming](../endpoint/streaming.md) request and response body is supported with fs2. +- `NettyZioServer().addEndpoints` to expose `ZIO`-based server endpoints, where `R` represents ZIO requirements supported effect. -These methods require a single, or a list of `ServerEndpoint`s, which can be created by adding [server logic](logic.md) +These methods require a single, or a list of `ServerEndpoint`s, which can be created by adding [server logic](logic.md) to an endpoint. For example: @@ -36,7 +36,7 @@ val helloWorld = endpoint .out(stringBody) .serverLogic(name => Future.successful[Either[Unit, String]](Right(s"Hello, $name!"))) -val binding: Future[NettyFutureServerBinding] = +val binding: Future[NettyFutureServerBinding] = NettyFutureServer().addEndpoint(helloWorld).start() ``` @@ -83,4 +83,4 @@ val serverBinding: Future[NettyFutureDomainSocketBinding] = Future.successful[Either[Unit, String]](Right(s"Hello, $name!"))) ) .startUsingDomainSocket(Paths.get(System.getProperty("java.io.tmpdir"), "hello")) -``` \ No newline at end of file +``` From 94b00ee7461879f575254ae13b73e2061906cd42 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Sat, 15 Jul 2023 03:09:34 +0200 Subject: [PATCH 06/18] Minor cleanup --- .../cats/NettyCatsServerInterpreter.scala | 18 ++++---------- .../netty/internal/NettyCatsRequestBody.scala | 24 +++++++------------ .../internal/NettyCatsToResponseBody.scala | 9 ++----- .../netty/internal/NettyRequestBody.scala | 1 - .../netty/internal/NettyServerHandler.scala | 4 +--- 5 files changed, 17 insertions(+), 39 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala index fb2e01fed8..92830b8652 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala @@ -2,24 +2,16 @@ package sttp.tapir.server.netty.cats import cats.effect.Async import cats.effect.std.Dispatcher +import sttp.capabilities.fs2.Fs2Streams import sttp.monad.MonadError import sttp.monad.syntax._ import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.netty.Route -import sttp.tapir.server.interpreter.BodyListener -import sttp.tapir.server.netty.internal.NettyBodyListener -import sttp.tapir.server.netty.NettyResponse -import sttp.tapir.server.interpreter.ServerInterpreter -import sttp.tapir.server.interpreter.FilterServerEndpoints -import sttp.tapir.server.netty.internal.NettyRequestBody -import sttp.tapir.server.netty.internal.NettyToResponseBody -import sttp.tapir.server.interceptor.reject.RejectInterceptor -import sttp.tapir.server.netty.NettyServerRequest import sttp.tapir.server.interceptor.RequestResult -import sttp.capabilities.fs2.Fs2Streams -import sttp.tapir.server.netty.internal.RunAsync -import sttp.tapir.server.netty.internal._ +import sttp.tapir.server.interceptor.reject.RejectInterceptor +import sttp.tapir.server.interpreter.{BodyListener, FilterServerEndpoints, ServerInterpreter} +import sttp.tapir.server.netty.internal.{NettyBodyListener, RunAsync, _} +import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} trait NettyCatsServerInterpreter[F[_]] { implicit def async: Async[F] diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala index e398360fdd..f6f3f9d601 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala @@ -2,21 +2,19 @@ package sttp.tapir.server.netty.internal import cats.effect.{Async, Sync} import cats.syntax.all._ -import io.netty.buffer.{ByteBufInputStream, ByteBufUtil} -import io.netty.handler.codec.http.FullHttpRequest +import com.typesafe.netty.http.StreamedHttpRequest +import fs2.Chunk +import fs2.interop.reactivestreams.StreamSubscriber +import fs2.io.file.{Files, Path} +import io.netty.buffer.ByteBufUtil +import io.netty.handler.codec.http.{FullHttpRequest, HttpContent} import sttp.capabilities.fs2.Fs2Streams -import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} import sttp.tapir.model.ServerRequest import sttp.tapir.server.interpreter.{RawValue, RequestBody} +import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} + import java.io.ByteArrayInputStream import java.nio.ByteBuffer -import com.typesafe.netty.http.StreamedHttpRequest -import com.typesafe.netty.http.DefaultStreamedHttpRequest -import fs2.interop.reactivestreams.StreamSubscriber -import io.netty.handler.codec.http.HttpContent -import fs2.Chunk -import fs2.io.file.Files -import fs2.io.file.Path private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit val monad: Async[F]) extends RequestBody[F, Fs2Streams[F]] { @@ -52,20 +50,16 @@ private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[T } override def toStream(serverRequest: ServerRequest): streams.BinaryStream = { - val nettyRequest = serverRequest.underlying.asInstanceOf[StreamedHttpRequest] - fs2.Stream .eval(StreamSubscriber[F, HttpContent](NettyRequestBody.bufferSize)) .flatMap(s => s.sub.stream(Sync[F].delay(nettyRequest.subscribe(s)))) .flatMap(httpContent => fs2.Stream.chunk(Chunk.byteBuffer(httpContent.content.nioBuffer()))) - - // fs2.io.readInputStream(Sync[F].delay(new ByteBufInputStream(nettyRequest(serverRequest).content())), streamChunkSize) } private def nettyRequestBytes(serverRequest: ServerRequest): F[Array[Byte]] = serverRequest.underlying match { case req: FullHttpRequest => monad.delay(ByteBufUtil.getBytes(req.content())) - case req: StreamedHttpRequest => toStream(serverRequest).compile.to(Chunk).map(_.toArray[Byte]) + case _: StreamedHttpRequest => toStream(serverRequest).compile.to(Chunk).map(_.toArray[Byte]) case other => monad.raiseError(new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass().getName()}")) } } diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala index ed4a2b14a8..275fe12105 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala @@ -15,7 +15,7 @@ import sttp.model.HasHeaders import sttp.tapir.server.interpreter.ToResponseBody import sttp.tapir.server.netty.NettyResponseContent.ByteBufNettyResponseContent import sttp.tapir.server.netty.{NettyResponse, NettyResponseContent} -import sttp.tapir.{CodecFormat, InputStreamRange, RawBodyType, WebSocketBodyOutput} +import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} import java.io.InputStream import java.nio.ByteBuffer @@ -75,13 +75,8 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends To Sync[F].blocking(inputStream()), NettyToResponseBody.DefaultChunkSize ) - private def wrap(streamRange: InputStreamRange): ChunkedStream = { - streamRange.range - .map(r => new RangedChunkedStream(streamRange.inputStreamFromRangeStart(), r.contentLength)) - .getOrElse(new ChunkedStream(streamRange.inputStream())) - } - def fs2StreamToPublisher(stream: streams.BinaryStream): Publisher[HttpContent] = { + private def fs2StreamToPublisher(stream: streams.BinaryStream): Publisher[HttpContent] = { // Deprecated constructor, but the proposed one does roughly the same, forcing a dedicated // dispatcher, which results in a Resource[], which is hard to afford here StreamUnicastPublisher( diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala index 4a9ab12c73..2c46031306 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala @@ -9,7 +9,6 @@ import sttp.tapir.model.ServerRequest import sttp.monad.syntax._ import sttp.tapir.capabilities.NoStreams import sttp.tapir.server.interpreter.{RawValue, RequestBody} -import sttp.tapir.server.netty.NettyServerRequest import java.nio.ByteBuffer import java.nio.file.Files diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 09025dfa08..d877e5cdaa 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -1,5 +1,6 @@ package sttp.tapir.server.netty.internal +import com.typesafe.netty.http.{DefaultStreamedHttpResponse, StreamedHttpRequest} import com.typesafe.scalalogging.Logger import io.netty.buffer.{ByteBuf, Unpooled} import io.netty.channel._ @@ -17,9 +18,6 @@ import sttp.tapir.server.netty.NettyResponseContent.{ import sttp.tapir.server.netty.{NettyResponse, NettyResponseContent, NettyServerRequest, Route} import scala.collection.JavaConverters._ -import com.typesafe.netty.http.DefaultStreamedHttpResponse -import com.typesafe.netty.http.DelegateStreamedHttpRequest -import com.typesafe.netty.http.StreamedHttpRequest class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit)(implicit me: MonadError[F]) extends SimpleChannelInboundHandler[HttpRequest] { From 739d6f17babc0b0d94f852727c6c9a395e76d310 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Sat, 15 Jul 2023 03:13:32 +0200 Subject: [PATCH 07/18] Additional cleanup in tests --- .../scala/sttp/tapir/server/tests/CreateServerTest.scala | 1 - .../scala/sttp/tapir/server/tests/ServerFilesTests.scala | 3 +-- .../sttp/tapir/server/tests/ServerStreamingTests.scala | 6 ------ tests/src/main/scala/sttp/tapir/tests/Streaming.scala | 5 ----- 4 files changed, 1 insertion(+), 14 deletions(-) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala index ddbf2854b2..950d018d67 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/CreateServerTest.scala @@ -79,7 +79,6 @@ class DefaultCreateServerTest[F[_], +R, OPTIONS, ROUTE]( _ <- Resource.eval(IO(logger.info(s"Bound server on port: $port"))) } yield port - import scala.concurrent.duration._ Test(name)( resources .use { port => diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerFilesTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerFilesTests.scala index ee428b4f6a..d67690407f 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerFilesTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerFilesTests.scala @@ -218,9 +218,8 @@ class ServerFilesTests[F[_], OPTIONS, ROUTE]( }, Test("should return 200 status code for whole file") { withTestFilesDirectory { testDir => - import scala.concurrent.duration._ serveRoute(staticFilesGetServerEndpoint[F]("test")(testDir.getAbsolutePath)) - .use { port => get(port, List("test", "f2")).timeout(10.seconds).map(_.code shouldBe StatusCode.Ok) } + .use { port => get(port, List("test", "f2")).map(_.code shouldBe StatusCode.Ok) } .unsafeToFuture() } }, diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala index c0f3601e59..3369f6c009 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala @@ -14,9 +14,6 @@ import sttp.tapir.tests.Streaming.{ in_string_stream_out_either_stream_string, out_custom_content_type_stream_body } -import scala.concurrent.duration._ -import java.io.ByteArrayInputStream -import java.io.InputStream class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServerTest[F, S, OPTIONS, ROUTE], streams: Streams[S])(implicit m: MonadError[F] @@ -36,7 +33,6 @@ class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServ )((in: (Long, streams.BinaryStream)) => pureResult(in.asRight[Unit])) { (backend, baseUri) => { basicRequest - .readTimeout(3.seconds) .post(uri"$baseUri/api/echo") .contentLength(penPineapple.length.toLong) .body(penPineapple) @@ -56,7 +52,6 @@ class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServ } { (backend, baseUri) => basicRequest .post(uri"$baseUri?kind=-1") - .readTimeout(3.seconds) .body(penPineapple) .send(backend) .map { r => @@ -90,7 +85,6 @@ class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServ testServer(in_stream_out_either_json_xml_stream(streams)) { s => pureResult(s.asRight[Unit]) } { (backend, baseUri) => basicRequest .post(uri"$baseUri") - .readTimeout(3.seconds) .body(penPineapple) .header(Header.accept(MediaType.ApplicationXml, MediaType.ApplicationJson)) .send(backend) diff --git a/tests/src/main/scala/sttp/tapir/tests/Streaming.scala b/tests/src/main/scala/sttp/tapir/tests/Streaming.scala index 33d53aedb2..91b5a59c18 100644 --- a/tests/src/main/scala/sttp/tapir/tests/Streaming.scala +++ b/tests/src/main/scala/sttp/tapir/tests/Streaming.scala @@ -7,11 +7,6 @@ import sttp.tapir._ import java.nio.charset.StandardCharsets object Streaming { - def in_string_out_string_stream[S](s: Streams[S]) = { - val sb = streamTextBody(s)(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8)) - endpoint.post.in("api" / "echo").in(stringBody).out(sb) - } - def in_stream_out_stream[S](s: Streams[S]): PublicEndpoint[s.BinaryStream, Unit, s.BinaryStream, S] = { val sb = streamTextBody(s)(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8)) endpoint.post.in("api" / "echo").in(sb).out(sb) From dcfff1695c7bc803c17f88f8d89f084c07cc9a2a Mon Sep 17 00:00:00 2001 From: kciesielski Date: Sat, 15 Jul 2023 03:15:20 +0200 Subject: [PATCH 08/18] Make netty-reactive-streams version a val --- build.sbt | 2 +- project/Versions.scala | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index 18bc8c39c2..255ab4f138 100644 --- a/build.sbt +++ b/build.sbt @@ -1355,7 +1355,7 @@ lazy val nettyServer: ProjectMatrix = (projectMatrix in file("server/netty-serve name := "tapir-netty-server", libraryDependencies ++= Seq( "io.netty" % "netty-all" % Versions.nettyAll, - "com.typesafe.netty" % "netty-reactive-streams-http" % "2.0.8" + "com.typesafe.netty" % "netty-reactive-streams-http" % Versions.nettyReactiveStreams ) ++ loggerDependencies, // needed because of https://github.com/coursier/coursier/issues/2016 diff --git a/project/Versions.scala b/project/Versions.scala index 967e43e267..409d05b6a7 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -19,6 +19,7 @@ object Versions { val finatra = "22.12.0" val catbird = "21.12.0" val json4s = "4.0.6" + val nettyReactiveStreams = "2.0.8" val sprayJson = "1.3.6" val scalaCheck = "1.17.0" val scalaTest = "3.2.16" From 0e9f04fc43f4da507f4890d52937ed2161363c72 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Sat, 15 Jul 2023 03:17:00 +0200 Subject: [PATCH 09/18] Remove unused val --- .../sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala index f6f3f9d601..330c343f00 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala @@ -19,7 +19,6 @@ import java.nio.ByteBuffer private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit val monad: Async[F]) extends RequestBody[F, Fs2Streams[F]] { - val streamChunkSize = 8192 override val streams: Fs2Streams[F] = Fs2Streams[F] override def toRaw[R](serverRequest: ServerRequest, bodyType: RawBodyType[R]): F[RawValue[R]] = { From 47a109761120819b9ff49d9ade2abbe338ac209c Mon Sep 17 00:00:00 2001 From: kciesielski Date: Sat, 15 Jul 2023 03:22:58 +0200 Subject: [PATCH 10/18] Add an example --- .../streaming/StreamingNettyFs2Server.scala | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 examples/src/main/scala/sttp/tapir/examples/streaming/StreamingNettyFs2Server.scala diff --git a/examples/src/main/scala/sttp/tapir/examples/streaming/StreamingNettyFs2Server.scala b/examples/src/main/scala/sttp/tapir/examples/streaming/StreamingNettyFs2Server.scala new file mode 100644 index 0000000000..1386ae94ec --- /dev/null +++ b/examples/src/main/scala/sttp/tapir/examples/streaming/StreamingNettyFs2Server.scala @@ -0,0 +1,71 @@ +package sttp.tapir.examples.streaming + +import cats.effect.{ExitCode, IO, IOApp} +import cats.implicits._ +import fs2.{Chunk, Stream} +import sttp.capabilities.fs2.Fs2Streams +import sttp.client3._ +import sttp.model.HeaderNames +import sttp.tapir._ +import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.netty.cats.{NettyCatsServer, NettyCatsServerBinding} + +import java.nio.charset.StandardCharsets +import scala.concurrent.duration._ + +object StreamingNettyFs2Server extends IOApp { + // corresponds to: GET /receive?name=... + // We need to provide both the schema of the value (for documentation), as well as the format (media type) of the + // body. Here, the schema is a `string` (set by `streamTextBody`) and the media type is `text/plain`. + val streamingEndpoint: PublicEndpoint[Unit, Unit, (Long, Stream[IO, Byte]), Fs2Streams[IO]] = + endpoint.get + .in("receive") + .out(header[Long](HeaderNames.ContentLength)) + .out(streamTextBody(Fs2Streams[IO])(CodecFormat.TextPlain(), Some(StandardCharsets.UTF_8))) + + val serverEndpoint: ServerEndpoint[Fs2Streams[IO], IO] = streamingEndpoint + .serverLogicSuccess { _ => + val size = 100L + Stream + .emit(List[Char]('a', 'b', 'c', 'd')) + .repeat + .flatMap(list => Stream.chunk(Chunk.seq(list))) + .metered[IO](100.millis) + .take(size) + .covary[IO] + .map(_.toByte) + .pure[IO] + .map(s => (size, s)) + } + + private val declaredPort = 9090 + private val declaredHost = "localhost" + + override def run(args: List[String]): IO[ExitCode] = { + // starting the server + NettyCatsServer + .io() + .use { server => + + val effect: IO[NettyCatsServerBinding[IO]] = server + .port(declaredPort) + .host(declaredHost) + .addEndpoint(serverEndpoint) + .start() + + effect.map { binding => + + val port = binding.port + val host = binding.hostName + println(s"Server started at port = ${binding.port}") + + val backend: SttpBackend[Identity, Any] = HttpURLConnectionBackend() + val result: String = basicRequest.response(asStringAlways).get(uri"http://$declaredHost:$declaredPort/receive").send(backend).body + println("Got result: " + result) + + assert(result == "abcd" * 25) + } + .as(ExitCode.Success) + } + } +} From 70863aa9401c20c11b22b3a98daf7f91cde6a907 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Sat, 15 Jul 2023 03:23:10 +0200 Subject: [PATCH 11/18] More cleanup --- .../tapir/server/netty/internal/NettyCatsRequestBody.scala | 2 +- .../server/netty/internal/NettyCatsToResponseBody.scala | 7 ------- .../sttp/tapir/server/netty/cats/NettyCatsServerTest.scala | 7 +++---- .../main/scala/sttp/tapir/server/netty/NettyConfig.scala | 1 - .../tapir/server/netty/internal/NettyRequestBody.scala | 4 ++-- 5 files changed, 6 insertions(+), 15 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala index 330c343f00..59ab71939d 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsRequestBody.scala @@ -51,7 +51,7 @@ private[netty] class NettyCatsRequestBody[F[_]](createFile: ServerRequest => F[T override def toStream(serverRequest: ServerRequest): streams.BinaryStream = { val nettyRequest = serverRequest.underlying.asInstanceOf[StreamedHttpRequest] fs2.Stream - .eval(StreamSubscriber[F, HttpContent](NettyRequestBody.bufferSize)) + .eval(StreamSubscriber[F, HttpContent](NettyRequestBody.DefaultChunkSize)) .flatMap(s => s.sub.stream(Sync[F].delay(nettyRequest.subscribe(s)))) .flatMap(httpContent => fs2.Stream.chunk(Chunk.byteBuffer(httpContent.content.nioBuffer()))) } diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala index 275fe12105..90abd69ac8 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala @@ -8,7 +8,6 @@ import fs2.io.file.{Files, Flags, Path} import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandlerContext import io.netty.handler.codec.http.{DefaultHttpContent, HttpContent} -import io.netty.handler.stream.ChunkedStream import org.reactivestreams.Publisher import sttp.capabilities.fs2.Fs2Streams import sttp.model.HasHeaders @@ -21,12 +20,6 @@ import java.io.InputStream import java.nio.ByteBuffer import java.nio.charset.Charset -private[netty] class RangedChunkedStream(raw: InputStream, length: Long) extends ChunkedStream(raw) { - - override def isEndOfInput(): Boolean = - super.isEndOfInput || transferredBytes == length -} - class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends ToResponseBody[NettyResponse, Fs2Streams[F]] { override val streams: Fs2Streams[F] = Fs2Streams[F] diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index 5c4a81acd2..53ac2b4c9b 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -1,10 +1,8 @@ package sttp.tapir.server.netty.cats import cats.effect.{IO, Resource} -import com.typesafe.scalalogging.StrictLogging import io.netty.channel.nio.NioEventLoopGroup import org.scalatest.EitherValues -import org.scalatest.matchers.should.Matchers import sttp.capabilities.fs2.Fs2Streams import sttp.monad.MonadError import sttp.tapir.integ.cats.effect.CatsMonadError @@ -12,7 +10,7 @@ import sttp.tapir.server.netty.internal.FutureUtil import sttp.tapir.server.tests._ import sttp.tapir.tests.{Test, TestSuite} -class NettyCatsServerTest extends TestSuite with EitherValues with StrictLogging with Matchers { +class NettyCatsServerTest extends TestSuite with EitherValues { override def tests: Resource[IO, List[Test]] = backendResource.flatMap { backend => Resource @@ -23,7 +21,8 @@ class NettyCatsServerTest extends TestSuite with EitherValues with StrictLogging val interpreter = new NettyCatsTestServerInterpreter(eventLoopGroup, dispatcher) val createServerTest = new DefaultCreateServerTest(backend, interpreter) - val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++ new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() + val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false) + .tests() ++ new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index e0733cccfe..99224c8723 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -96,7 +96,6 @@ case class NettyConfig( def eventLoopGroup(elg: EventLoopGroup): NettyConfig = copy(eventLoopConfig = EventLoopConfig.useExisting(elg)) def initPipeline(f: NettyConfig => (ChannelPipeline, ChannelHandler) => Unit): NettyConfig = copy(initPipeline = f) - } object NettyConfig { diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala index 2c46031306..6c9d251676 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala @@ -46,6 +46,6 @@ class NettyRequestBody[F[_]](createFile: ServerRequest => F[TapirFile])(implicit private def nettyRequest(serverRequest: ServerRequest): FullHttpRequest = serverRequest.underlying.asInstanceOf[FullHttpRequest] } -object NettyRequestBody { - private[internal] val bufferSize = 8192 +private[internal] object NettyRequestBody { + val DefaultChunkSize = 8192 } From aa92c093daf6ae3d8afb085e3f9d7e2f845b30d2 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 18 Jul 2023 10:11:57 +0200 Subject: [PATCH 12/18] Handle maxContentLength --- .../server/netty/cats/NettyCatsServer.scala | 2 +- .../cats/NettyCatsServerInterpreter.scala | 2 +- .../internal/NettyCatsToResponseBody.scala | 26 +++++++- .../netty/cats/NettyCatsServerTest.scala | 9 ++- .../cats/NettyCatsTestServerInterpreter.scala | 10 +++- .../sttp/tapir/server/netty/NettyConfig.scala | 10 ++-- .../server/netty/NettyFutureServer.scala | 2 +- .../netty/internal/NettyServerHandler.scala | 59 ++++++++++++++++--- .../server/netty/zio/NettyZioServer.scala | 3 +- .../tapir/server/tests/AllServerTests.scala | 5 +- .../tapir/server/tests/ServerBasicTests.scala | 10 +++- 11 files changed, 115 insertions(+), 23 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala index 7566146933..604a13e57d 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServer.scala @@ -61,7 +61,7 @@ case class NettyCatsServer[F[_]: Async](routes: Vector[Route[F]], options: Netty val channelFuture = NettyBootstrap( config, - new NettyServerHandler(route, (f: () => F[Unit]) => options.dispatcher.unsafeToFuture(f())), + new NettyServerHandler(route, (f: () => F[Unit]) => options.dispatcher.unsafeToFuture(f()), config.maxContentLength), eventLoopGroup, socketOverride ) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala index 92830b8652..0c97bc3798 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala @@ -32,7 +32,7 @@ trait NettyCatsServerInterpreter[F[_]] { val serverInterpreter = new ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]]( FilterServerEndpoints(ses), new NettyCatsRequestBody(createFile), - new NettyCatsToResponseBody(nettyServerOptions.dispatcher), + new NettyCatsToResponseBody(nettyServerOptions.dispatcher, Some(10000)), // TODO pass from NettyCatsServer? RejectInterceptor.disableWhenSingleEndpoint(interceptors, ses), deleteFile ) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala index 90abd69ac8..9bdfee74df 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala @@ -19,8 +19,10 @@ import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} import java.io.InputStream import java.nio.ByteBuffer import java.nio.charset.Charset +import fs2.Pull -class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends ToResponseBody[NettyResponse, Fs2Streams[F]] { +class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F], maxContentLength: Option[Int]) + extends ToResponseBody[NettyResponse, Fs2Streams[F]] { override val streams: Fs2Streams[F] = Fs2Streams[F] override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { @@ -70,17 +72,37 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends To ) private def fs2StreamToPublisher(stream: streams.BinaryStream): Publisher[HttpContent] = { + val chunkedStream = stream.chunkLimit(NettyToResponseBody.DefaultChunkSize) + val streamToConvert = maxContentLength.map(enforceFs2MaxBytes(chunkedStream, _)).getOrElse(chunkedStream) // Deprecated constructor, but the proposed one does roughly the same, forcing a dedicated // dispatcher, which results in a Resource[], which is hard to afford here StreamUnicastPublisher( - stream.chunks + streamToConvert .map { chunk => val bytes: Chunk.ArraySlice[Byte] = chunk.compact + new DefaultHttpContent(Unpooled.wrappedBuffer(bytes.values, bytes.offset, bytes.length)) }, dispatcher ) } + private def enforceFs2MaxBytes(stream: fs2.Stream[F, Chunk[Byte]], maxBytes: Int): fs2.Stream[F, Chunk[Byte]] = { + def go(s: fs2.Stream[F, Chunk[Byte]], count: Long): Pull[F, Chunk[Byte], Unit] = { + s.pull.uncons.flatMap { + case Some((chunk, tail)) => + val chunkSize = chunk.size.toLong + val newCount = count + chunkSize + if (newCount > maxBytes) { + Pull.raiseError[F](new IllegalArgumentException(s"Body size limit $maxBytes exceeded")) + } else { + Pull.output(chunk) >> go(tail, newCount) + } + case None => + Pull.done + } + } + go(stream, 0L).stream + } override def fromStreamValue( v: streams.BinaryStream, diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index 53ac2b4c9b..1ec8c40b14 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -11,6 +11,7 @@ import sttp.tapir.server.tests._ import sttp.tapir.tests.{Test, TestSuite} class NettyCatsServerTest extends TestSuite with EitherValues { + override def tests: Resource[IO, List[Test]] = backendResource.flatMap { backend => Resource @@ -21,7 +22,13 @@ class NettyCatsServerTest extends TestSuite with EitherValues { val interpreter = new NettyCatsTestServerInterpreter(eventLoopGroup, dispatcher) val createServerTest = new DefaultCreateServerTest(backend, interpreter) - val tests = new AllServerTests(createServerTest, interpreter, backend, multipart = false) + val tests = new AllServerTests( + createServerTest, + interpreter, + backend, + multipart = false, + maxContentLength = Some(NettyCatsTestServerInterpreter.maxContentLength) + ) .tests() ++ new ServerStreamingTests(createServerTest, Fs2Streams[IO]).tests() IO.pure((tests, eventLoopGroup)) diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala index 63968fa1ce..b0847aedb4 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsTestServerInterpreter.scala @@ -20,7 +20,11 @@ class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatch } override def server(routes: NonEmptyList[Route[IO]]): Resource[IO, Port] = { - val config = NettyConfig.defaultWithStreaming.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose + val config = NettyConfig.defaultWithStreaming + .eventLoopGroup(eventLoopGroup) + .randomPort + .withDontShutdownEventLoopGroupOnClose + .maxContentLength(NettyCatsTestServerInterpreter.maxContentLength) val options = NettyCatsServerOptions.default[IO](dispatcher) val bind: IO[NettyCatsServerBinding[IO]] = NettyCatsServer(options, config).addRoutes(routes.toList).start() @@ -29,3 +33,7 @@ class NettyCatsTestServerInterpreter(eventLoopGroup: NioEventLoopGroup, dispatch .map(_.port) } } + +object NettyCatsTestServerInterpreter { + val maxContentLength = 10000 +} diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index 99224c8723..3eea0e9c48 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -49,7 +49,7 @@ case class NettyConfig( host: String, port: Int, shutdownEventLoopGroupOnClose: Boolean, - maxContentLength: Int, + maxContentLength: Option[Int], socketBacklog: Int, requestTimeout: Option[FiniteDuration], connectionTimeout: Option[FiniteDuration], @@ -70,8 +70,8 @@ case class NettyConfig( def withShutdownEventLoopGroupOnClose: NettyConfig = copy(shutdownEventLoopGroupOnClose = true) def withDontShutdownEventLoopGroupOnClose: NettyConfig = copy(shutdownEventLoopGroupOnClose = false) - def maxContentLength(m: Int): NettyConfig = copy(maxContentLength = m) - def noMaxContentLength: NettyConfig = copy(maxContentLength = Integer.MAX_VALUE) + def maxContentLength(m: Int): NettyConfig = copy(maxContentLength = Some(m)) + def noMaxContentLength: NettyConfig = copy(maxContentLength = None) def socketBacklog(s: Int): NettyConfig = copy(socketBacklog = s) @@ -109,7 +109,7 @@ object NettyConfig { connectionTimeout = Some(10.seconds), socketTimeout = Some(60.seconds), lingerTimeout = Some(60.seconds), - maxContentLength = Integer.MAX_VALUE, + maxContentLength = None, addLoggingHandler = false, sslContext = None, eventLoopConfig = EventLoopConfig.auto, @@ -120,7 +120,7 @@ object NettyConfig { def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) pipeline.addLast(new HttpServerCodec()) - pipeline.addLast(new HttpObjectAggregator(cfg.maxContentLength)) + pipeline.addLast(new HttpObjectAggregator(cfg.maxContentLength.getOrElse(Integer.MAX_VALUE))) pipeline.addLast(new ChunkedWriteHandler()) pipeline.addLast(handler) () diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index faa18cb069..b04ea2dd5c 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -55,7 +55,7 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe val route = Route.combine(routes) val channelFuture = - NettyBootstrap(config, new NettyServerHandler(route, (f: () => Future[Unit]) => f()), eventLoopGroup, socketOverride) + NettyBootstrap(config, new NettyServerHandler(route, (f: () => Future[Unit]) => f(), config.maxContentLength), eventLoopGroup, socketOverride) nettyChannelFutureToScala(channelFuture).map(ch => (ch.localAddress().asInstanceOf[SA], () => stop(ch, eventLoopGroup))) } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index d877e5cdaa..fa268a68db 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -16,13 +16,29 @@ import sttp.tapir.server.netty.NettyResponseContent.{ ChunkedStreamNettyResponseContent } import sttp.tapir.server.netty.{NettyResponse, NettyResponseContent, NettyServerRequest, Route} +import io.netty.handler.codec.http.HttpHeaderNames.CONNECTION; +import io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; import scala.collection.JavaConverters._ -class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit)(implicit me: MonadError[F]) +class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit, maxContentLength: Option[Int])(implicit me: MonadError[F]) extends SimpleChannelInboundHandler[HttpRequest] { private val logger = Logger[NettyServerHandler[F]] + + private val EntityTooLarge: FullHttpResponse = { + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER) + res.headers().set(CONTENT_LENGTH, 0) + res + } + + private val EntityTooLargeClose: FullHttpResponse = { + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER) + res.headers().set(CONTENT_LENGTH, 0) + res.headers().set(CONNECTION, HttpHeaderValues.CLOSE) + res + } + override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { if (HttpUtil.is100ContinueExpected(request)) { @@ -95,13 +111,16 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) serverResponse.handle( ctx = ctx, byteBufHandler = (channelPromise, byteBuf) => { - val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) - res.setHeadersFrom(serverResponse) - res.handleContentLengthAndChunkedHeaders(Option(byteBuf.readableBytes())) - res.handleCloseAndKeepAliveHeaders(req) - - ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) + if (maxContentLength.exists(_ < byteBuf.readableBytes)) + writeEntityTooLargeResponse(ctx, req) + else { + val res = new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), byteBuf) + res.setHeadersFrom(serverResponse) + res.handleContentLengthAndChunkedHeaders(Option(byteBuf.readableBytes())) + res.handleCloseAndKeepAliveHeaders(req) + ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) + } }, chunkedStreamHandler = (channelPromise, chunkedStream) => { val resHeader: DefaultHttpResponse = @@ -150,6 +169,32 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) } ) + private def writeEntityTooLargeResponse(ctx: ChannelHandlerContext, req: HttpRequest): Unit = { + + if (!HttpUtil.is100ContinueExpected(req) && !HttpUtil.isKeepAlive(req)) { + val future: ChannelFuture = ctx.writeAndFlush(EntityTooLargeClose.retainedDuplicate()) + future.addListener(new ChannelFutureListener() { + override def operationComplete(future: ChannelFuture) = { + if (!future.isSuccess()) { + logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) + } + ctx.close() + } + }) + } else { + ctx + .writeAndFlush(EntityTooLarge.retainedDuplicate()) + .addListener(new ChannelFutureListener() { + override def operationComplete(future: ChannelFuture) = { + if (!future.isSuccess()) { + logger.warn("Failed to send a 413 Request Entity Too Large.", future.cause()) + ctx.close() + } + } + }) + } + } + private implicit class RichServerNettyResponse(val r: ServerResponse[NettyResponse]) { def handle( ctx: ChannelHandlerContext, diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 91e139e056..8481e84715 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -66,7 +66,8 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: config, new NettyServerHandler[RIO[R, *]]( route, - (f: () => RIO[R, Unit]) => Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(f())) + (f: () => RIO[R, Unit]) => Unsafe.unsafe(implicit u => runtime.unsafe.runToFuture(f())), + config.maxContentLength ), eventLoopGroup, socketOverride diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala index 96eea0ee72..369f0cc9b6 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/AllServerTests.scala @@ -27,13 +27,14 @@ class AllServerTests[F[_], OPTIONS, ROUTE]( validation: Boolean = true, oneOfBody: Boolean = true, cors: Boolean = true, - options: Boolean = true + options: Boolean = true, + maxContentLength: Option[Int] = None )(implicit m: MonadError[F] ) { def tests(): List[Test] = (if (security) new ServerSecurityTests(createServerTest).tests() else Nil) ++ - (if (basic) new ServerBasicTests(createServerTest, serverInterpreter).tests() else Nil) ++ + (if (basic) new ServerBasicTests(createServerTest, serverInterpreter, maxContentLength = maxContentLength).tests() else Nil) ++ (if (contentNegotiation) new ServerContentNegotiationTests(createServerTest).tests() else Nil) ++ (if (file) new ServerFileTests(createServerTest).tests() else Nil) ++ (if (mapping) new ServerMappingTests(createServerTest).tests() else Nil) ++ diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index 5122c4fb8d..752ba36e9e 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -31,7 +31,8 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( inputStreamSupport: Boolean = true, supportsUrlEncodedPathSegments: Boolean = true, supportsMultipleSetCookieHeaders: Boolean = true, - invulnerableToUnsanitizedHeaders: Boolean = true + invulnerableToUnsanitizedHeaders: Boolean = true, + maxContentLength: Option[Int] = None )(implicit m: MonadError[F] ) { @@ -46,6 +47,7 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( customiseDecodeFailureHandlerTests() ++ serverSecurityLogicTests() ++ (if (inputStreamSupport) inputStreamTests() else Nil) ++ + (if (maxContentLength.nonEmpty) maxContentLengthTests() else Nil) ++ exceptionTests() def basicTests(): List[Test] = List( @@ -742,6 +744,12 @@ class ServerBasicTests[F[_], OPTIONS, ROUTE]( } ) + def maxContentLengthTests(): List[Test] = List( + testServer(in_string_out_string, "returns 413 on exceeded max content length")(_ => + pureResult(List.fill(maxContentLength.getOrElse(0) + 1)('x').mkString.asRight[Unit]) + ) { (backend, baseUri) => basicRequest.post(uri"$baseUri/api/echo").body("irrelevant").send(backend).map(_.code.code shouldBe 413) } + ) + def exceptionTests(): List[Test] = List( testServer(endpoint, "handle exceptions")(_ => throw new RuntimeException()) { (backend, baseUri) => basicRequest.get(uri"$baseUri").send(backend).map(_.code shouldBe StatusCode.InternalServerError) From 72b6d53d72003e438c0ae5c21fd853b0b7cebec4 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 18 Jul 2023 10:14:04 +0200 Subject: [PATCH 13/18] Remove maxContentLength handling for streams - This needs to go to a separate PR --- .../cats/NettyCatsServerInterpreter.scala | 2 +- .../internal/NettyCatsToResponseBody.scala | 23 ++----------------- 2 files changed, 3 insertions(+), 22 deletions(-) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala index 0c97bc3798..a6a796d96d 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala @@ -32,7 +32,7 @@ trait NettyCatsServerInterpreter[F[_]] { val serverInterpreter = new ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]]( FilterServerEndpoints(ses), new NettyCatsRequestBody(createFile), - new NettyCatsToResponseBody(nettyServerOptions.dispatcher, Some(10000)), // TODO pass from NettyCatsServer? + new NettyCatsToResponseBody(nettyServerOptions.dispatcher), RejectInterceptor.disableWhenSingleEndpoint(interceptors, ses), deleteFile ) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala index 9bdfee74df..eb7085284e 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import java.nio.charset.Charset import fs2.Pull -class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F], maxContentLength: Option[Int]) +class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) extends ToResponseBody[NettyResponse, Fs2Streams[F]] { override val streams: Fs2Streams[F] = Fs2Streams[F] @@ -72,12 +72,10 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F], maxContent ) private def fs2StreamToPublisher(stream: streams.BinaryStream): Publisher[HttpContent] = { - val chunkedStream = stream.chunkLimit(NettyToResponseBody.DefaultChunkSize) - val streamToConvert = maxContentLength.map(enforceFs2MaxBytes(chunkedStream, _)).getOrElse(chunkedStream) // Deprecated constructor, but the proposed one does roughly the same, forcing a dedicated // dispatcher, which results in a Resource[], which is hard to afford here StreamUnicastPublisher( - streamToConvert + stream.chunkLimit(NettyToResponseBody.DefaultChunkSize) .map { chunk => val bytes: Chunk.ArraySlice[Byte] = chunk.compact @@ -86,23 +84,6 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F], maxContent dispatcher ) } - private def enforceFs2MaxBytes(stream: fs2.Stream[F, Chunk[Byte]], maxBytes: Int): fs2.Stream[F, Chunk[Byte]] = { - def go(s: fs2.Stream[F, Chunk[Byte]], count: Long): Pull[F, Chunk[Byte], Unit] = { - s.pull.uncons.flatMap { - case Some((chunk, tail)) => - val chunkSize = chunk.size.toLong - val newCount = count + chunkSize - if (newCount > maxBytes) { - Pull.raiseError[F](new IllegalArgumentException(s"Body size limit $maxBytes exceeded")) - } else { - Pull.output(chunk) >> go(tail, newCount) - } - case None => - Pull.done - } - } - go(stream, 0L).stream - } override def fromStreamValue( v: streams.BinaryStream, From cec1cca18175b4882c265ef2aa6f44ebc7588c53 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 18 Jul 2023 10:36:52 +0200 Subject: [PATCH 14/18] Refactoring after review --- doc/server/netty.md | 2 +- .../streaming/StreamingNettyFs2Server.scala | 38 ++++---- .../cats/NettyCatsServerInterpreter.scala | 2 +- .../internal/NettyCatsToResponseBody.scala | 29 ++---- .../sttp/tapir/server/netty/NettyConfig.scala | 11 +-- .../server/netty/NettyFutureServer.scala | 11 ++- .../netty/internal/NettyServerHandler.scala | 96 ++++++++----------- .../tapir/server/tests/ServerBasicTests.scala | 2 +- 8 files changed, 85 insertions(+), 106 deletions(-) diff --git a/doc/server/netty.md b/doc/server/netty.md index eb6cc831f3..2db26b4bab 100644 --- a/doc/server/netty.md +++ b/doc/server/netty.md @@ -16,7 +16,7 @@ To expose an endpoint using a [Netty](https://netty.io)-based server, first add Then, use: - `NettyFutureServer().addEndpoints` to expose `Future`-based server endpoints. -- `NettyCatsServer().addEndpoints` to expose `F`-based server endpoints, where `F` is any cats-effect supported effect. [Streaming](../endpoint/streaming.md) request and response body is supported with fs2. +- `NettyCatsServer().addEndpoints` to expose `F`-based server endpoints, where `F` is any cats-effect supported effect. [Streaming](../endpoint/streaming.md) request and response bodies is supported with fs2. - `NettyZioServer().addEndpoints` to expose `ZIO`-based server endpoints, where `R` represents ZIO requirements supported effect. These methods require a single, or a list of `ServerEndpoint`s, which can be created by adding [server logic](logic.md) diff --git a/examples/src/main/scala/sttp/tapir/examples/streaming/StreamingNettyFs2Server.scala b/examples/src/main/scala/sttp/tapir/examples/streaming/StreamingNettyFs2Server.scala index 1386ae94ec..a7eb446d6a 100644 --- a/examples/src/main/scala/sttp/tapir/examples/streaming/StreamingNettyFs2Server.scala +++ b/examples/src/main/scala/sttp/tapir/examples/streaming/StreamingNettyFs2Server.scala @@ -43,29 +43,31 @@ object StreamingNettyFs2Server extends IOApp { override def run(args: List[String]): IO[ExitCode] = { // starting the server - NettyCatsServer - .io() - .use { server => + NettyCatsServer + .io() + .use { server => - val effect: IO[NettyCatsServerBinding[IO]] = server - .port(declaredPort) - .host(declaredHost) - .addEndpoint(serverEndpoint) - .start() + val startServer: IO[NettyCatsServerBinding[IO]] = server + .port(declaredPort) + .host(declaredHost) + .addEndpoint(serverEndpoint) + .start() - effect.map { binding => + startServer + .map { binding => - val port = binding.port - val host = binding.hostName - println(s"Server started at port = ${binding.port}") + val port = binding.port + val host = binding.hostName + println(s"Server started at port = ${binding.port}") - val backend: SttpBackend[Identity, Any] = HttpURLConnectionBackend() - val result: String = basicRequest.response(asStringAlways).get(uri"http://$declaredHost:$declaredPort/receive").send(backend).body - println("Got result: " + result) + val backend: SttpBackend[Identity, Any] = HttpURLConnectionBackend() + val result: String = + basicRequest.response(asStringAlways).get(uri"http://$declaredHost:$declaredPort/receive").send(backend).body + println("Got result: " + result) - assert(result == "abcd" * 25) + assert(result == "abcd" * 25) + } + .as(ExitCode.Success) } - .as(ExitCode.Success) - } } } diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala index a6a796d96d..6cffa531ce 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/NettyCatsServerInterpreter.scala @@ -32,7 +32,7 @@ trait NettyCatsServerInterpreter[F[_]] { val serverInterpreter = new ServerInterpreter[Fs2Streams[F], F, NettyResponse, Fs2Streams[F]]( FilterServerEndpoints(ses), new NettyCatsRequestBody(createFile), - new NettyCatsToResponseBody(nettyServerOptions.dispatcher), + new NettyCatsToResponseBody(nettyServerOptions.dispatcher, delegate = new NettyToResponseBody), RejectInterceptor.disableWhenSingleEndpoint(interceptors, ses), deleteFile ) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala index eb7085284e..bcf754da4e 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/internal/NettyCatsToResponseBody.scala @@ -12,44 +12,31 @@ import org.reactivestreams.Publisher import sttp.capabilities.fs2.Fs2Streams import sttp.model.HasHeaders import sttp.tapir.server.interpreter.ToResponseBody -import sttp.tapir.server.netty.NettyResponseContent.ByteBufNettyResponseContent -import sttp.tapir.server.netty.{NettyResponse, NettyResponseContent} +import sttp.tapir.server.netty.NettyResponse +import sttp.tapir.server.netty.NettyResponseContent._ import sttp.tapir.{CodecFormat, RawBodyType, WebSocketBodyOutput} import java.io.InputStream -import java.nio.ByteBuffer import java.nio.charset.Charset -import fs2.Pull -class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) +class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F], delegate: NettyToResponseBody) extends ToResponseBody[NettyResponse, Fs2Streams[F]] { override val streams: Fs2Streams[F] = Fs2Streams[F] override def fromRawValue[R](v: R, headers: HasHeaders, format: CodecFormat, bodyType: RawBodyType[R]): NettyResponse = { bodyType match { - case RawBodyType.StringBody(charset) => - val bytes = v.asInstanceOf[String].getBytes(charset) - (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(bytes)) - - case RawBodyType.ByteArrayBody => - val bytes = v.asInstanceOf[Array[Byte]] - (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(bytes)) - - case RawBodyType.ByteBufferBody => - val byteBuffer = v.asInstanceOf[ByteBuffer] - (ctx: ChannelHandlerContext) => ByteBufNettyResponseContent(ctx.newPromise(), Unpooled.wrappedBuffer(byteBuffer)) case RawBodyType.InputStreamBody => val stream = inputStreamToFs2(() => v) (ctx: ChannelHandlerContext) => - new NettyResponseContent.ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) + new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) case RawBodyType.InputStreamRangeBody => val stream = v.range .map(range => inputStreamToFs2(v.inputStreamFromRangeStart).take(range.contentLength)) .getOrElse(inputStreamToFs2(v.inputStream)) (ctx: ChannelHandlerContext) => - new NettyResponseContent.ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) + new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) case RawBodyType.FileBody => val tapirFile = v @@ -59,9 +46,11 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) .getOrElse(Files[F](Files.forAsync[F]).readAll(path, NettyToResponseBody.DefaultChunkSize, Flags.Read)) (ctx: ChannelHandlerContext) => - new NettyResponseContent.ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) + new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(stream)) case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException + + case _ => delegate.fromRawValue(v, headers, format, bodyType) } } @@ -92,7 +81,7 @@ class NettyCatsToResponseBody[F[_]: Async](dispatcher: Dispatcher[F]) charset: Option[Charset] ): NettyResponse = (ctx: ChannelHandlerContext) => { - new NettyResponseContent.ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(v)) + new ReactivePublisherNettyResponseContent(ctx.newPromise(), fs2StreamToPublisher(v)) } override def fromWebSocketPipe[REQ, RESP]( diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala index 3eea0e9c48..3d5c6af29d 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyConfig.scala @@ -10,7 +10,6 @@ import io.netty.handler.codec.http.{HttpObjectAggregator, HttpServerCodec} import io.netty.handler.logging.LoggingHandler import io.netty.handler.ssl.SslContext import io.netty.handler.stream.ChunkedWriteHandler -import io.netty.handler.timeout.ReadTimeoutException import sttp.tapir.server.netty.NettyConfig.EventLoopConfig import scala.concurrent.duration._ @@ -99,7 +98,7 @@ case class NettyConfig( } object NettyConfig { - def default: NettyConfig = NettyConfig( + def defaultNoStreaming: NettyConfig = NettyConfig( host = "localhost", port = 8080, shutdownEventLoopGroupOnClose = true, @@ -114,10 +113,10 @@ object NettyConfig { sslContext = None, eventLoopConfig = EventLoopConfig.auto, socketConfig = NettySocketConfig.default, - initPipeline = cfg => defaultInitPipeline(cfg)(_, _) + initPipeline = cfg => defaultInitPipelineNoStreaming(cfg)(_, _) ) - def defaultInitPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { + def defaultInitPipelineNoStreaming(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) pipeline.addLast(new HttpServerCodec()) pipeline.addLast(new HttpObjectAggregator(cfg.maxContentLength.getOrElse(Integer.MAX_VALUE))) @@ -126,7 +125,7 @@ object NettyConfig { () } - def streamingPipeline(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { + def defaultInitPipelineStreaming(cfg: NettyConfig)(pipeline: ChannelPipeline, handler: ChannelHandler): Unit = { cfg.sslContext.foreach(s => pipeline.addLast(s.newHandler(pipeline.channel().alloc()))) pipeline.addLast(new HttpServerCodec()) pipeline.addLast(new HttpStreamsServerHandler()) @@ -135,7 +134,7 @@ object NettyConfig { () } - def defaultWithStreaming: NettyConfig = default.copy(initPipeline = cfg => streamingPipeline(cfg)(_, _)) + def defaultWithStreaming: NettyConfig = defaultNoStreaming.copy(initPipeline = cfg => defaultInitPipelineStreaming(cfg)(_, _)) case class EventLoopConfig(initEventLoopGroup: () => EventLoopGroup, serverChannel: Class[_ <: ServerChannel]) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala index b04ea2dd5c..b53483a0a9 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/NettyFutureServer.scala @@ -55,7 +55,12 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe val route = Route.combine(routes) val channelFuture = - NettyBootstrap(config, new NettyServerHandler(route, (f: () => Future[Unit]) => f(), config.maxContentLength), eventLoopGroup, socketOverride) + NettyBootstrap( + config, + new NettyServerHandler(route, (f: () => Future[Unit]) => f(), config.maxContentLength), + eventLoopGroup, + socketOverride + ) nettyChannelFutureToScala(channelFuture).map(ch => (ch.localAddress().asInstanceOf[SA], () => stop(ch, eventLoopGroup))) } @@ -71,10 +76,10 @@ case class NettyFutureServer(routes: Vector[FutureRoute], options: NettyFutureSe object NettyFutureServer { def apply()(implicit ec: ExecutionContext): NettyFutureServer = - NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, NettyConfig.default) + NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, NettyConfig.defaultNoStreaming) def apply(serverOptions: NettyFutureServerOptions)(implicit ec: ExecutionContext): NettyFutureServer = - NettyFutureServer(Vector.empty, serverOptions, NettyConfig.default) + NettyFutureServer(Vector.empty, serverOptions, NettyConfig.defaultNoStreaming) def apply(config: NettyConfig)(implicit ec: ExecutionContext): NettyFutureServer = NettyFutureServer(Vector.empty, NettyFutureServerOptions.default, config) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index fa268a68db..953634710e 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -4,6 +4,7 @@ import com.typesafe.netty.http.{DefaultStreamedHttpResponse, StreamedHttpRequest import com.typesafe.scalalogging.Logger import io.netty.buffer.{ByteBuf, Unpooled} import io.netty.channel._ +import io.netty.handler.codec.http.HttpHeaderNames.{CONNECTION, CONTENT_LENGTH} import io.netty.handler.codec.http._ import io.netty.handler.stream.{ChunkedFile, ChunkedStream} import org.reactivestreams.Publisher @@ -13,19 +14,19 @@ import sttp.tapir.server.model.ServerResponse import sttp.tapir.server.netty.NettyResponseContent.{ ByteBufNettyResponseContent, ChunkedFileNettyResponseContent, - ChunkedStreamNettyResponseContent + ChunkedStreamNettyResponseContent, + ReactivePublisherNettyResponseContent } -import sttp.tapir.server.netty.{NettyResponse, NettyResponseContent, NettyServerRequest, Route} -import io.netty.handler.codec.http.HttpHeaderNames.CONNECTION; -import io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; +import sttp.tapir.server.netty.{NettyResponse, NettyServerRequest, Route} import scala.collection.JavaConverters._ -class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit, maxContentLength: Option[Int])(implicit me: MonadError[F]) - extends SimpleChannelInboundHandler[HttpRequest] { +class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) => Unit, maxContentLength: Option[Int])(implicit + me: MonadError[F] +) extends SimpleChannelInboundHandler[HttpRequest] { private val logger = Logger[NettyServerHandler[F]] - + private val EntityTooLarge: FullHttpResponse = { val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER) res.headers().set(CONTENT_LENGTH, 0) @@ -39,8 +40,35 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) res } - override def channelRead0(ctx: ChannelHandlerContext, request: HttpRequest): Unit = { + + def runRoute(req: HttpRequest) = { + + route(NettyServerRequest(req)) + .map { + case Some(response) => response + case None => ServerResponse.notFound + } + .flatMap((serverResponse: ServerResponse[NettyResponse]) => + // in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that + // we get the 500 response, instead of dropping the request + try handleResponse(ctx, req, serverResponse).unit + catch { + case e: Exception => me.error[Unit](e) + } + ) + .handleError { case ex: Exception => + logger.error("Error while processing the request", ex) + // send 500 + val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) + res.handleCloseAndKeepAliveHeaders(req) + + ctx.writeAndFlush(res).closeIfNeeded(req) + me.unit(()) + } + } + if (HttpUtil.is100ContinueExpected(request)) { ctx.writeAndFlush(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE)) () @@ -48,58 +76,14 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) request match { case full: FullHttpRequest => val req = full.retain() - unsafeRunAsync { () => - route(NettyServerRequest(req)) - .map { - case Some(response) => response - case None => ServerResponse.notFound - } - .flatMap((serverResponse: ServerResponse[NettyResponse]) => - // in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that - // we get the 500 response, instead of dropping the request - try handleResponse(ctx, req, serverResponse).unit - catch { - case e: Exception => me.error[Unit](e) - } - ) - .handleError { case ex: Exception => - logger.error("Error while processing the request", ex) - // send 500 - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) - res.handleCloseAndKeepAliveHeaders(req) - - ctx.writeAndFlush(res).closeIfNeeded(req) - me.unit(()) - } + runRoute(req) .ensure(me.eval(req.release())) } // exceptions should be handled case req: StreamedHttpRequest => unsafeRunAsync { () => - route(NettyServerRequest(req)) - .map { - case Some(response) => response - case None => ServerResponse.notFound - } - .flatMap((serverResponse: ServerResponse[NettyResponse]) => - // in ZIO, exceptions thrown in .map become defects - instead, we want them represented as errors so that - // we get the 500 response, instead of dropping the request - try handleResponse(ctx, req, serverResponse).unit - catch { - case e: Exception => me.error[Unit](e) - } - ) - .handleError { case ex: Exception => - logger.error("Error while processing the request", ex) - // send 500 - val res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.INTERNAL_SERVER_ERROR) - res.headers().set(HttpHeaderNames.CONTENT_LENGTH, 0) - - ctx.writeAndFlush(res) - me.unit(()) - } - } // exceptions should be handled + runRoute(req) + } case _ => throw new UnsupportedOperationException(s"Unexpected Netty request type: ${request.getClass.getName}") } @@ -212,7 +196,7 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) case r: ByteBufNettyResponseContent => byteBufHandler(r.channelPromise, r.byteBuf) case r: ChunkedStreamNettyResponseContent => chunkedStreamHandler(r.channelPromise, r.chunkedStream) case r: ChunkedFileNettyResponseContent => chunkedFileHandler(r.channelPromise, r.chunkedFile) - case r: NettyResponseContent.ReactivePublisherNettyResponseContent => reactiveStreamHandler(r.channelPromise, r.publisher) + case r: ReactivePublisherNettyResponseContent => reactiveStreamHandler(r.channelPromise, r.publisher) } } case None => noBodyHandler() diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala index 752ba36e9e..e82f8aa2dd 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerBasicTests.scala @@ -15,7 +15,7 @@ import sttp.tapir.codec.enumeratum.TapirCodecEnumeratum import sttp.tapir.generic.auto._ import sttp.tapir.json.circe._ import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.interceptor.decodefailure.{DecodeFailureHandler, DefaultDecodeFailureHandler} +import sttp.tapir.server.interceptor.decodefailure.DefaultDecodeFailureHandler import sttp.tapir.tests.Basic._ import sttp.tapir.tests.TestUtil._ import sttp.tapir.tests._ From 2712d6a9652bea90b5a004106ae53c07b2fdd5ae Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 18 Jul 2023 13:31:51 +0200 Subject: [PATCH 15/18] Let Netty Handle content-length and close handlers for streaming --- .../sttp/tapir/server/netty/internal/NettyServerHandler.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index 953634710e..645521965f 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -133,7 +133,6 @@ class NettyServerHandler[F[_]](route: Route[F], unsafeRunAsync: (() => F[Unit]) new DefaultStreamedHttpResponse(req.protocolVersion(), HttpResponseStatus.valueOf(serverResponse.code.code), publisher) res.setHeadersFrom(serverResponse) - res.handleContentLengthAndChunkedHeaders(None) res.handleCloseAndKeepAliveHeaders(req) ctx.writeAndFlush(res, channelPromise).closeIfNeeded(req) From 0064b12cdc8f36b7d39f0477411edea4c4b334d3 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 18 Jul 2023 13:32:18 +0200 Subject: [PATCH 16/18] Flag for checking setting content-length header in streaming responses --- .../server/tests/ServerStreamingTests.scala | 170 +++++++++--------- 1 file changed, 89 insertions(+), 81 deletions(-) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala index 3369f6c009..c69681ff0b 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala @@ -4,9 +4,8 @@ import cats.syntax.all._ import org.scalatest.matchers.should.Matchers._ import sttp.capabilities.Streams import sttp.client3._ -import sttp.model.{Header, HeaderNames, MediaType} +import sttp.model.{Header, MediaType} import sttp.monad.MonadError -import sttp.tapir.tests.Test import sttp.tapir.tests.Streaming.{ in_stream_out_either_json_xml_stream, in_stream_out_stream, @@ -14,8 +13,13 @@ import sttp.tapir.tests.Streaming.{ in_string_stream_out_either_stream_string, out_custom_content_type_stream_body } +import sttp.tapir.tests.Test -class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServerTest[F, S, OPTIONS, ROUTE], streams: Streams[S])(implicit +class ServerStreamingTests[F[_], S, OPTIONS, ROUTE]( + createServerTest: CreateServerTest[F, S, OPTIONS, ROUTE], + streams: Streams[S], + streamingWithContentLengthSupport: Boolean = true +)(implicit m: MonadError[F] ) { @@ -23,85 +27,89 @@ class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServ import createServerTest._ val penPineapple = "pen pineapple apple pen" - - List( - testServer(in_stream_out_stream(streams))((s: streams.BinaryStream) => pureResult(s.asRight[Unit])) { (backend, baseUri) => - basicRequest.post(uri"$baseUri/api/echo").body(penPineapple).send(backend).map(_.body shouldBe Right(penPineapple)) - }, - testServer( - in_stream_out_stream_with_content_length(streams) - )((in: (Long, streams.BinaryStream)) => pureResult(in.asRight[Unit])) { (backend, baseUri) => - { - basicRequest - .post(uri"$baseUri/api/echo") - .contentLength(penPineapple.length.toLong) - .body(penPineapple) - .send(backend) - .map { response => - response.body shouldBe Right(penPineapple) - if (response.headers.contains(Header(HeaderNames.TransferEncoding, "chunked"))) { - response.contentLength shouldBe None - } else { - response.contentLength shouldBe Some(penPineapple.length) + if (streamingWithContentLengthSupport) + List( + testServer( + in_stream_out_stream_with_content_length(streams) + )((in: (Long, streams.BinaryStream)) => pureResult(in.asRight[Unit])) { (backend, baseUri) => + { + basicRequest + .post(uri"$baseUri/api/echo") + .contentLength(penPineapple.length.toLong) + .body(penPineapple) + .send(backend) + .map { response => + response.body shouldBe Right(penPineapple) + if (streamingWithContentLengthSupport) { + response.contentLength shouldBe Some(penPineapple.length) + } else { + response.contentLength shouldBe None + } } - } + } } - }, - testServer(out_custom_content_type_stream_body(streams)) { case (k, s) => - pureResult((if (k < 0) (MediaType.ApplicationJson.toString(), s) else (MediaType.ApplicationXml.toString(), s)).asRight[Unit]) - } { (backend, baseUri) => - basicRequest - .post(uri"$baseUri?kind=-1") - .body(penPineapple) - .send(backend) - .map { r => - r.body shouldBe Right(penPineapple) - r.contentType shouldBe Some(MediaType.ApplicationJson.toString()) - } >> - basicRequest - .post(uri"$baseUri?kind=1") - .body(penPineapple) - .send(backend) - .map { r => - r.body shouldBe Right(penPineapple) - r.contentType shouldBe Some(MediaType.ApplicationXml.toString()) - } - }, - testServer(in_string_stream_out_either_stream_string(streams)) { - case ("left", s) => pureResult((Left(s): Either[streams.BinaryStream, String]).asRight[Unit]) - case _ => pureResult((Right("was not left"): Either[streams.BinaryStream, String]).asRight[Unit]) - } { (backend, baseUri) => - basicRequest - .post(uri"$baseUri?which=left") - .body(penPineapple) - .send(backend) - .map(_.body shouldBe Right(penPineapple)) >> - basicRequest - .post(uri"$baseUri?which=right") - .body(penPineapple) - .send(backend) - .map(_.body shouldBe Right("was not left")) - }, - testServer(in_stream_out_either_json_xml_stream(streams)) { s => pureResult(s.asRight[Unit]) } { (backend, baseUri) => - basicRequest - .post(uri"$baseUri") - .body(penPineapple) - .header(Header.accept(MediaType.ApplicationXml, MediaType.ApplicationJson)) - .send(backend) - .map { r => - r.contentType shouldBe Some(MediaType.ApplicationXml.toString()) - r.body shouldBe Right(penPineapple) - } >> - basicRequest - .post(uri"$baseUri") - .body(penPineapple) - .header(Header.accept(MediaType.ApplicationJson, MediaType.ApplicationXml)) - .send(backend) - .map { r => - r.contentType shouldBe Some(MediaType.ApplicationJson.toString()) - r.body shouldBe Right(penPineapple) - } - } - ) + ) + else + Nil ++ + List( + testServer(in_stream_out_stream(streams))((s: streams.BinaryStream) => pureResult(s.asRight[Unit])) { (backend, baseUri) => + basicRequest.post(uri"$baseUri/api/echo").body(penPineapple).send(backend).map(_.body shouldBe Right(penPineapple)) + }, + testServer(out_custom_content_type_stream_body(streams)) { case (k, s) => + pureResult((if (k < 0) (MediaType.ApplicationJson.toString(), s) else (MediaType.ApplicationXml.toString(), s)).asRight[Unit]) + } { (backend, baseUri) => + basicRequest + .post(uri"$baseUri?kind=-1") + .body(penPineapple) + .send(backend) + .map { r => + r.body shouldBe Right(penPineapple) + r.contentType shouldBe Some(MediaType.ApplicationJson.toString()) + } >> + basicRequest + .post(uri"$baseUri?kind=1") + .body(penPineapple) + .send(backend) + .map { r => + r.body shouldBe Right(penPineapple) + r.contentType shouldBe Some(MediaType.ApplicationXml.toString()) + } + }, + testServer(in_string_stream_out_either_stream_string(streams)) { + case ("left", s) => pureResult((Left(s): Either[streams.BinaryStream, String]).asRight[Unit]) + case _ => pureResult((Right("was not left"): Either[streams.BinaryStream, String]).asRight[Unit]) + } { (backend, baseUri) => + basicRequest + .post(uri"$baseUri?which=left") + .body(penPineapple) + .send(backend) + .map(_.body shouldBe Right(penPineapple)) >> + basicRequest + .post(uri"$baseUri?which=right") + .body(penPineapple) + .send(backend) + .map(_.body shouldBe Right("was not left")) + }, + testServer(in_stream_out_either_json_xml_stream(streams)) { s => pureResult(s.asRight[Unit]) } { (backend, baseUri) => + basicRequest + .post(uri"$baseUri") + .body(penPineapple) + .header(Header.accept(MediaType.ApplicationXml, MediaType.ApplicationJson)) + .send(backend) + .map { r => + r.contentType shouldBe Some(MediaType.ApplicationXml.toString()) + r.body shouldBe Right(penPineapple) + } >> + basicRequest + .post(uri"$baseUri") + .body(penPineapple) + .header(Header.accept(MediaType.ApplicationJson, MediaType.ApplicationXml)) + .send(backend) + .map { r => + r.contentType shouldBe Some(MediaType.ApplicationJson.toString()) + r.body shouldBe Right(penPineapple) + } + } + ) } } From fd8383f2a03e1bf86e49080c2c7f606bab3bd3de Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 18 Jul 2023 13:51:23 +0200 Subject: [PATCH 17/18] Update overlooked references to NettyConfig.default --- doc/server/netty.md | 2 +- .../tapir/server/netty/NettyFutureTestServerInterpreter.scala | 2 +- .../scala/sttp/tapir/server/netty/zio/NettyZioServer.scala | 4 ++-- .../server/netty/zio/NettyZioTestServerInterpreter.scala | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/server/netty.md b/doc/server/netty.md index 2db26b4bab..4c714cd145 100644 --- a/doc/server/netty.md +++ b/doc/server/netty.md @@ -60,7 +60,7 @@ NettyFutureServer().port(9090).addEndpoints(???) NettyFutureServer(NettyFutureServerOptions.customiseInterceptors.serverLog(None).options) // customise Netty config -NettyFutureServer(NettyConfig.default.socketBacklog(256)) +NettyFutureServer(NettyConfig.defaultNoStreaming.socketBacklog(256)) ``` ## Domain socket support diff --git a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala index 15b9b9ab30..37f2f2481a 100644 --- a/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala +++ b/server/netty-server/src/test/scala/sttp/tapir/server/netty/NettyFutureTestServerInterpreter.scala @@ -18,7 +18,7 @@ class NettyFutureTestServerInterpreter(eventLoopGroup: NioEventLoopGroup)(implic } override def server(routes: NonEmptyList[FutureRoute]): Resource[IO, Port] = { - val config = NettyConfig.default.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose + val config = NettyConfig.defaultNoStreaming.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose val options = NettyFutureServerOptions.default val bind = IO.fromFuture(IO.delay(NettyFutureServer(options, config).addRoutes(routes.toList).start())) diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala index 8481e84715..cfaa64cf81 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/NettyZioServer.scala @@ -93,8 +93,8 @@ case class NettyZioServer[R](routes: Vector[RIO[R, Route[RIO[R, *]]]], options: } object NettyZioServer { - def apply[R](): NettyZioServer[R] = NettyZioServer(Vector.empty, NettyZioServerOptions.default[R], NettyConfig.default) - def apply[R](options: NettyZioServerOptions[R]): NettyZioServer[R] = NettyZioServer(Vector.empty, options, NettyConfig.default) + def apply[R](): NettyZioServer[R] = NettyZioServer(Vector.empty, NettyZioServerOptions.default[R], NettyConfig.defaultNoStreaming) + def apply[R](options: NettyZioServerOptions[R]): NettyZioServer[R] = NettyZioServer(Vector.empty, options, NettyConfig.defaultNoStreaming) def apply[R](config: NettyConfig): NettyZioServer[R] = NettyZioServer(Vector.empty, NettyZioServerOptions.default[R], config) def apply[R](options: NettyZioServerOptions[R], config: NettyConfig): NettyZioServer[R] = NettyZioServer(Vector.empty, options, config) } diff --git a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala index 629bec2ca8..1c66df45a4 100644 --- a/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala +++ b/server/netty-server/zio/src/test/scala/sttp/tapir/server/netty/zio/NettyZioTestServerInterpreter.scala @@ -21,7 +21,7 @@ class NettyZioTestServerInterpreter[R](eventLoopGroup: NioEventLoopGroup) } override def server(routes: NonEmptyList[Task[Route[Task]]]): Resource[IO, Port] = { - val config = NettyConfig.default.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose + val config = NettyConfig.defaultNoStreaming.eventLoopGroup(eventLoopGroup).randomPort.withDontShutdownEventLoopGroupOnClose val options = NettyZioServerOptions.default[R] val runtime: Runtime[R] = Runtime.default.asInstanceOf[Runtime[R]] From 7dcdc6db47ca47fea302c7aba2fe0b0de7732bb1 Mon Sep 17 00:00:00 2001 From: kciesielski Date: Tue, 18 Jul 2023 15:35:03 +0200 Subject: [PATCH 18/18] Always expect content-length in streaming response if set by user --- .../server/tests/ServerStreamingTests.scala | 168 ++++++++---------- 1 file changed, 78 insertions(+), 90 deletions(-) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala index c69681ff0b..13ee47524a 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerStreamingTests.scala @@ -4,8 +4,9 @@ import cats.syntax.all._ import org.scalatest.matchers.should.Matchers._ import sttp.capabilities.Streams import sttp.client3._ -import sttp.model.{Header, MediaType} +import sttp.model.{Header, HeaderNames, MediaType} import sttp.monad.MonadError +import sttp.tapir.tests.Test import sttp.tapir.tests.Streaming.{ in_stream_out_either_json_xml_stream, in_stream_out_stream, @@ -13,13 +14,8 @@ import sttp.tapir.tests.Streaming.{ in_string_stream_out_either_stream_string, out_custom_content_type_stream_body } -import sttp.tapir.tests.Test -class ServerStreamingTests[F[_], S, OPTIONS, ROUTE]( - createServerTest: CreateServerTest[F, S, OPTIONS, ROUTE], - streams: Streams[S], - streamingWithContentLengthSupport: Boolean = true -)(implicit +class ServerStreamingTests[F[_], S, OPTIONS, ROUTE](createServerTest: CreateServerTest[F, S, OPTIONS, ROUTE], streams: Streams[S])(implicit m: MonadError[F] ) { @@ -27,89 +23,81 @@ class ServerStreamingTests[F[_], S, OPTIONS, ROUTE]( import createServerTest._ val penPineapple = "pen pineapple apple pen" - if (streamingWithContentLengthSupport) - List( - testServer( - in_stream_out_stream_with_content_length(streams) - )((in: (Long, streams.BinaryStream)) => pureResult(in.asRight[Unit])) { (backend, baseUri) => - { - basicRequest - .post(uri"$baseUri/api/echo") - .contentLength(penPineapple.length.toLong) - .body(penPineapple) - .send(backend) - .map { response => - response.body shouldBe Right(penPineapple) - if (streamingWithContentLengthSupport) { - response.contentLength shouldBe Some(penPineapple.length) - } else { - response.contentLength shouldBe None - } - } - } + + List( + testServer(in_stream_out_stream(streams))((s: streams.BinaryStream) => pureResult(s.asRight[Unit])) { (backend, baseUri) => + basicRequest.post(uri"$baseUri/api/echo").body(penPineapple).send(backend).map(_.body shouldBe Right(penPineapple)) + }, + testServer( + in_stream_out_stream_with_content_length(streams) + )((in: (Long, streams.BinaryStream)) => pureResult(in.asRight[Unit])) { (backend, baseUri) => + { + basicRequest + .post(uri"$baseUri/api/echo") + .contentLength(penPineapple.length.toLong) + .body(penPineapple) + .send(backend) + .map { response => + response.body shouldBe Right(penPineapple) + response.contentLength shouldBe Some(penPineapple.length) + } } - ) - else - Nil ++ - List( - testServer(in_stream_out_stream(streams))((s: streams.BinaryStream) => pureResult(s.asRight[Unit])) { (backend, baseUri) => - basicRequest.post(uri"$baseUri/api/echo").body(penPineapple).send(backend).map(_.body shouldBe Right(penPineapple)) - }, - testServer(out_custom_content_type_stream_body(streams)) { case (k, s) => - pureResult((if (k < 0) (MediaType.ApplicationJson.toString(), s) else (MediaType.ApplicationXml.toString(), s)).asRight[Unit]) - } { (backend, baseUri) => - basicRequest - .post(uri"$baseUri?kind=-1") - .body(penPineapple) - .send(backend) - .map { r => - r.body shouldBe Right(penPineapple) - r.contentType shouldBe Some(MediaType.ApplicationJson.toString()) - } >> - basicRequest - .post(uri"$baseUri?kind=1") - .body(penPineapple) - .send(backend) - .map { r => - r.body shouldBe Right(penPineapple) - r.contentType shouldBe Some(MediaType.ApplicationXml.toString()) - } - }, - testServer(in_string_stream_out_either_stream_string(streams)) { - case ("left", s) => pureResult((Left(s): Either[streams.BinaryStream, String]).asRight[Unit]) - case _ => pureResult((Right("was not left"): Either[streams.BinaryStream, String]).asRight[Unit]) - } { (backend, baseUri) => - basicRequest - .post(uri"$baseUri?which=left") - .body(penPineapple) - .send(backend) - .map(_.body shouldBe Right(penPineapple)) >> - basicRequest - .post(uri"$baseUri?which=right") - .body(penPineapple) - .send(backend) - .map(_.body shouldBe Right("was not left")) - }, - testServer(in_stream_out_either_json_xml_stream(streams)) { s => pureResult(s.asRight[Unit]) } { (backend, baseUri) => - basicRequest - .post(uri"$baseUri") - .body(penPineapple) - .header(Header.accept(MediaType.ApplicationXml, MediaType.ApplicationJson)) - .send(backend) - .map { r => - r.contentType shouldBe Some(MediaType.ApplicationXml.toString()) - r.body shouldBe Right(penPineapple) - } >> - basicRequest - .post(uri"$baseUri") - .body(penPineapple) - .header(Header.accept(MediaType.ApplicationJson, MediaType.ApplicationXml)) - .send(backend) - .map { r => - r.contentType shouldBe Some(MediaType.ApplicationJson.toString()) - r.body shouldBe Right(penPineapple) - } - } - ) + }, + testServer(out_custom_content_type_stream_body(streams)) { case (k, s) => + pureResult((if (k < 0) (MediaType.ApplicationJson.toString(), s) else (MediaType.ApplicationXml.toString(), s)).asRight[Unit]) + } { (backend, baseUri) => + basicRequest + .post(uri"$baseUri?kind=-1") + .body(penPineapple) + .send(backend) + .map { r => + r.body shouldBe Right(penPineapple) + r.contentType shouldBe Some(MediaType.ApplicationJson.toString()) + } >> + basicRequest + .post(uri"$baseUri?kind=1") + .body(penPineapple) + .send(backend) + .map { r => + r.body shouldBe Right(penPineapple) + r.contentType shouldBe Some(MediaType.ApplicationXml.toString()) + } + }, + testServer(in_string_stream_out_either_stream_string(streams)) { + case ("left", s) => pureResult((Left(s): Either[streams.BinaryStream, String]).asRight[Unit]) + case _ => pureResult((Right("was not left"): Either[streams.BinaryStream, String]).asRight[Unit]) + } { (backend, baseUri) => + basicRequest + .post(uri"$baseUri?which=left") + .body(penPineapple) + .send(backend) + .map(_.body shouldBe Right(penPineapple)) >> + basicRequest + .post(uri"$baseUri?which=right") + .body(penPineapple) + .send(backend) + .map(_.body shouldBe Right("was not left")) + }, + testServer(in_stream_out_either_json_xml_stream(streams)) { s => pureResult(s.asRight[Unit]) } { (backend, baseUri) => + basicRequest + .post(uri"$baseUri") + .body(penPineapple) + .header(Header.accept(MediaType.ApplicationXml, MediaType.ApplicationJson)) + .send(backend) + .map { r => + r.contentType shouldBe Some(MediaType.ApplicationXml.toString()) + r.body shouldBe Right(penPineapple) + } >> + basicRequest + .post(uri"$baseUri") + .body(penPineapple) + .header(Header.accept(MediaType.ApplicationJson, MediaType.ApplicationXml)) + .send(backend) + .map { r => + r.contentType shouldBe Some(MediaType.ApplicationJson.toString()) + r.body shouldBe Right(penPineapple) + } + } + ) } }