Skip to content

Commit

Permalink
Implement multipart support for jdkhttp-server
Browse files Browse the repository at this point in the history
  • Loading branch information
jnatten committed Aug 24, 2023
1 parent 23cc7c9 commit 33a1dbd
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 8 deletions.
4 changes: 3 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -1376,7 +1376,9 @@ lazy val jdkhttpServer: ProjectMatrix = (projectMatrix in file("server/jdkhttp-s
.settings(commonJvmSettings)
.settings(
name := "tapir-jdkhttp-server",
libraryDependencies ++= loggerDependencies
libraryDependencies ++= Seq(
"org.apache.httpcomponents" % "httpmime" % "4.5.14"
) ++ loggerDependencies
)
.jvmPlatform(scalaVersions = List(scala2_13, scala3))
.dependsOn(serverCore, serverTests % Test)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,27 @@ package internal

import com.sun.net.httpserver.HttpExchange
import sttp.capabilities
import sttp.model.Part
import sttp.tapir.capabilities.NoStreams
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.interpreter.{RawValue, RequestBody}
import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile}
import sttp.tapir.server.jdkhttp.internal.ParsedMultiPart.parseMultipartBody
import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, RawPart, TapirFile}

import java.io.InputStream
import java.io._
import java.nio.ByteBuffer
import java.nio.file.{Files, StandardCopyOption}

private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile) extends RequestBody[Id, NoStreams] {
override val streams: capabilities.Streams[NoStreams] = NoStreams

override def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW]): RawValue[RAW] = {
def asInputStream: InputStream = jdkHttpRequest(serverRequest).getRequestBody
val request = jdkHttpRequest(serverRequest)
toRaw(serverRequest, bodyType, request.getRequestBody)
}

private def toRaw[RAW](serverRequest: ServerRequest, bodyType: RawBodyType[RAW], body: InputStream): RawValue[RAW] = {
def asInputStream: InputStream = body
def asByteArray: Array[Byte] = asInputStream.readAllBytes()

bodyType match {
Expand All @@ -29,10 +36,42 @@ private[jdkhttp] class JdkHttpRequestBody(createFile: ServerRequest => TapirFile
val file = createFile(serverRequest)
Files.copy(asInputStream, file.toPath, StandardCopyOption.REPLACE_EXISTING)
RawValue(FileRange(file), Seq(FileRange(file)))
case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException("MultipartBody is not supported")
case m: RawBodyType.MultipartBody => RawValue.fromParts(multiPartRequestToRawBody(serverRequest, m))
}
}

private val boundaryPrefix = "boundary="
private def extractBoundary(request: HttpExchange): String = {
Option(request.getRequestHeaders.getFirst("Content-Type"))
.flatMap(
_.split(";")
.find(_.trim().startsWith(boundaryPrefix))
.map(line => s"--${line.trim().substring(boundaryPrefix.length)}")
)
.getOrElse(throw new IllegalArgumentException("Unable to extract multipart boundary from multipart request"))
}

private def multiPartRequestToRawBody(request: ServerRequest, m: RawBodyType.MultipartBody): Seq[RawPart] = {
val httpExchange = jdkHttpRequest(request)
val boundary = extractBoundary(httpExchange)

parseMultipartBody(httpExchange, boundary).flatMap(parsedPart =>
parsedPart.getName.flatMap(name =>
m.partType(name)
.map(partType => {
val bodyInputStream = new ByteArrayInputStream(parsedPart.body)
val bodyRawValue = toRaw(request, partType, bodyInputStream)
Part(
name,
bodyRawValue.value,
otherDispositionParams = parsedPart.getDispositionParams - "name",
headers = parsedPart.fileItemHeaders
)
})
)
)
}

override def toStream(serverRequest: ServerRequest): streams.BinaryStream = throw new UnsupportedOperationException(
"Streaming is not supported"
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package sttp.tapir.server.jdkhttp
package internal

import org.apache.http.entity.ContentType
import org.apache.http.entity.mime.content._
import org.apache.http.entity.mime.{FormBodyPart, FormBodyPartBuilder, MultipartEntityBuilder}
import sttp.capabilities
import sttp.model.HasHeaders
import sttp.model.{HasHeaders, Part}
import sttp.tapir.capabilities.NoStreams
import sttp.tapir.server.interpreter.ToResponseBody
import sttp.tapir.{CodecFormat, FileRange, InputStreamRange, RawBodyType, WebSocketBodyOutput}
Expand Down Expand Up @@ -49,7 +52,56 @@ private[jdkhttp] class JdkHttpToResponseBody extends ToResponseBody[JdkHttpRespo
.getOrElse {
(base, Some(tapirFile.file.length()))
}
case _: RawBodyType.MultipartBody => throw new UnsupportedOperationException("MultipartBody is not supported")
case m: RawBodyType.MultipartBody =>
val entity = MultipartEntityBuilder.create()
v.flatMap(rawPartToFormBodyPart(m, _)).foreach { (formBodyPart: FormBodyPart) => entity.addPart(formBodyPart) }
val builtEntity = entity.build()
val inputStream: InputStream = builtEntity.getContent
(inputStream, Some(builtEntity.getContentLength))
}
}

private def rawPartToFormBodyPart[R](m: RawBodyType.MultipartBody, part: Part[R]): Option[FormBodyPart] = {
m.partType(part.name).map { partType =>
val builder = FormBodyPartBuilder
.create(
part.name,
rawValueToContentBody(partType.asInstanceOf[RawBodyType[Any]], part.asInstanceOf[Part[Any]], part.body)
)

part.headers.foreach(header => builder.addField(header.name, header.value))

builder.build()
}
}

private def rawValueToContentBody[CF <: CodecFormat, R](
bodyType: RawBodyType[R],
part: Part[R],
r: R
): ContentBody = {
val contentType: String = part.header("content-type").getOrElse("text/plain")

bodyType match {
case RawBodyType.StringBody(_) =>
new StringBody(r.toString, ContentType.parse(contentType))
case RawBodyType.ByteArrayBody =>
new ByteArrayBody(r, ContentType.create(contentType), part.fileName.get)
case RawBodyType.ByteBufferBody =>
val array: Array[Byte] = new Array[Byte](r.remaining)
r.get(array)
new ByteArrayBody(array, ContentType.create(contentType), part.fileName.get)
case RawBodyType.FileBody =>
part.fileName match {
case Some(filename) => new FileBody(r.file, ContentType.create(contentType), filename)
case None => new FileBody(r.file, ContentType.create(contentType))
}
case RawBodyType.InputStreamRangeBody =>
new InputStreamBody(r.inputStream(), ContentType.create(contentType), part.fileName.get)
case RawBodyType.InputStreamBody =>
new InputStreamBody(r, ContentType.create(contentType), part.fileName.get)
case _: RawBodyType.MultipartBody =>
throw new UnsupportedOperationException("Nested multipart messages are not supported.")
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package sttp.tapir.server.jdkhttp.internal

import com.sun.net.httpserver.HttpExchange
import sttp.model.Header
import java.io.{BufferedReader, InputStreamReader}

case class ParsedMultiPart(headers: Map[String, Seq[String]], body: Array[Byte]) {
def getHeader(headerName: String): Option[String] = headers.get(headerName).flatMap(_.headOption)
def fileItemHeaders: Seq[Header] = headers.toSeq.flatMap { case (name, values) => values.map(value => Header(name, value)) }

def getDispositionParams: Map[String, String] = {
val headerValue = getHeader("content-disposition")
headerValue
.map(
_.split(";")
.map(_.trim)
.tail
.map(_.split("="))
.map(array => array(0) -> array(1))
.toMap
)
.getOrElse(Map.empty)
}

def getName: Option[String] =
headers
.getOrElse("content-disposition", Seq.empty)
.headOption
.flatMap(
_.split(";")
.find(_.trim.startsWith("name"))
.map(_.split("=")(1).trim)
.map(_.replaceAll("^\"|\"$", ""))
)

def addHeader(l: String): ParsedMultiPart = {
val (name, value) = l.splitAt(l.indexOf(":"))
val headerName = name.trim.toLowerCase
val headerValue = value.stripPrefix(":").trim
val newHeaderEntry = (headerName -> (this.headers.getOrElse(headerName, Seq.empty) :+ headerValue))
this.copy(headers = headers + newHeaderEntry)
}

}

object ParsedMultiPart {
def empty: ParsedMultiPart = new ParsedMultiPart(Map.empty, Array.empty)

sealed trait ParseState
case object Default extends ParseState
case object AfterBoundary extends ParseState
case object AfterHeaderSpace extends ParseState

private case class ParseData(
currentPart: ParsedMultiPart,
completedParts: List[ParsedMultiPart],
parseState: ParseState
) {
def changeState(state: ParseState): ParseData = this.copy(parseState = state)
def addHeader(header: String): ParseData = this.copy(currentPart = currentPart.addHeader(header))
def addBody(body: Array[Byte]): ParseData = this.copy(currentPart = currentPart.copy(body = currentPart.body ++ body))
def completePart(): ParseData = this.currentPart.getName match {
case Some(_) =>
this.copy(
completedParts = completedParts :+ currentPart,
currentPart = empty,
parseState = AfterBoundary
)
case None => changeState(AfterBoundary)
}
}

def parseMultipartBody(httpExchange: HttpExchange, boundary: String): Seq[ParsedMultiPart] = {
val reader = new BufferedReader(new InputStreamReader(httpExchange.getRequestBody))
val initialParseState: ParseData = ParseData(empty, List.empty, Default)
Iterator
.continually(reader.readLine())
.takeWhile(_ != null)
.foldLeft(initialParseState) { case (state, line) =>
state.parseState match {
case Default if line.startsWith(boundary) => state.changeState(AfterBoundary)
case Default => state
case AfterBoundary if line.trim.isEmpty => state.changeState(AfterHeaderSpace)
case AfterBoundary => state.addHeader(line)
case AfterHeaderSpace if !line.startsWith(boundary) => state.addBody(line.getBytes())
case AfterHeaderSpace => state.completePart()
}
}
.completedParts
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class JdkHttpServerTest extends TestSuite with EitherValues {
val createServerTest = new DefaultCreateServerTest(backend, interpreter)

new ServerBasicTests(createServerTest, interpreter, invulnerableToUnsanitizedHeaders = false).tests() ++
new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false).tests()
new AllServerTests(createServerTest, interpreter, backend, basic = false).tests()
})
}
}

0 comments on commit 33a1dbd

Please sign in to comment.