Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to sttp client 3. Use capabilities, replace Nothing with Any as the no-streaming specification #753

Merged
merged 6 commits into from
Sep 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ val booksListingRoute: Route = booksListing.toRoute((bookListingLogic _).tupled)
// Convert to sttp Request

import sttp.tapir.client.sttp._
import sttp.client._
import sttp.client3._

val booksListingRequest: Request[DecodeResult[Either[String, List[Book]]], Nothing] = booksListing
.toSttpRequest(uri"http://localhost:8080")
Expand Down
24 changes: 14 additions & 10 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ lazy val core: ProjectMatrix = (projectMatrix in file("core"))
libraryDependencies ++= Seq(
"com.propensive" %% "magnolia" % "0.17.0",
"org.scala-lang" % "scala-reflect" % scalaVersion.value,
"com.softwaremill.sttp.model" %% "core" % "1.1.4",
"com.softwaremill.sttp.model" %% "core" % Versions.sttpModel,
"com.softwaremill.sttp.shared" %% "core" % Versions.sttpShared,
scalaTest % Test,
scalaCheck % Test,
scalaTestPlusScalaCheck % Test,
Expand Down Expand Up @@ -202,7 +203,8 @@ lazy val zio: ProjectMatrix = (projectMatrix in file("integrations/zio"))
libraryDependencies ++= Seq(
"dev.zio" %% "zio" % Versions.zio,
"dev.zio" %% "zio-streams" % Versions.zio,
scalaTest % Test
scalaTest % Test,
"com.softwaremill.sttp.shared" %% "zio" % Versions.sttpShared
)
)
.jvmPlatform(scalaVersions = allScalaVersions)
Expand Down Expand Up @@ -404,7 +406,7 @@ lazy val serverTests: ProjectMatrix = (projectMatrix in file("server/tests"))
.settings(
name := "tapir-server-tests",
libraryDependencies ++= Seq(
"com.softwaremill.sttp.client" %% "async-http-client-backend-cats" % Versions.sttp
"com.softwaremill.sttp.client3" %% "async-http-client-backend-cats" % Versions.sttp
)
)
.dependsOn(tests)
Expand All @@ -416,7 +418,8 @@ lazy val akkaHttpServer: ProjectMatrix = (projectMatrix in file("server/akka-htt
name := "tapir-akka-http-server",
libraryDependencies ++= Seq(
"com.typesafe.akka" %% "akka-http" % Versions.akkaHttp,
"com.typesafe.akka" %% "akka-stream" % Versions.akkaStreams
"com.typesafe.akka" %% "akka-stream" % Versions.akkaStreams,
"com.softwaremill.sttp.shared" %% "akka" % Versions.sttpShared
)
)
.jvmPlatform(scalaVersions = allScalaVersions)
Expand All @@ -427,7 +430,8 @@ lazy val http4sServer: ProjectMatrix = (projectMatrix in file("server/http4s-ser
.settings(
name := "tapir-http4s-server",
libraryDependencies ++= Seq(
"org.http4s" %% "http4s-blaze-server" % Versions.http4s
"org.http4s" %% "http4s-blaze-server" % Versions.http4s,
"com.softwaremill.sttp.shared" %% "fs2" % Versions.sttpShared
)
)
.jvmPlatform(scalaVersions = allScalaVersions)
Expand Down Expand Up @@ -532,8 +536,8 @@ lazy val sttpClient: ProjectMatrix = (projectMatrix in file("client/sttp-client"
.settings(
name := "tapir-sttp-client",
libraryDependencies ++= Seq(
"com.softwaremill.sttp.client" %% "core" % Versions.sttp,
"com.softwaremill.sttp.client" %% "async-http-client-backend-fs2" % Versions.sttp % Test
"com.softwaremill.sttp.client3" %% "core" % Versions.sttp,
"com.softwaremill.sttp.client3" %% "async-http-client-backend-fs2" % Versions.sttp % Test
)
)
.jvmPlatform(scalaVersions = allScalaVersions)
Expand All @@ -549,7 +553,7 @@ lazy val examples: ProjectMatrix = (projectMatrix in file("examples"))
"dev.zio" %% "zio-interop-cats" % Versions.zioInteropCats,
"org.typelevel" %% "cats-effect" % Versions.catsEffect,
"org.http4s" %% "http4s-dsl" % Versions.http4s,
"com.softwaremill.sttp.client" %% "async-http-client-backend-zio" % Versions.sttp
"com.softwaremill.sttp.client3" %% "async-http-client-backend-zio" % Versions.sttp
),
libraryDependencies ++= loggerDependencies,
publishArtifact := false
Expand All @@ -562,13 +566,13 @@ lazy val playground: ProjectMatrix = (projectMatrix in file("playground"))
.settings(
name := "tapir-playground",
libraryDependencies ++= Seq(
"com.softwaremill.sttp.client" %% "akka-http-backend" % Versions.sttp,
"com.softwaremill.sttp.client3" %% "akka-http-backend" % Versions.sttp,
"dev.zio" %% "zio" % Versions.zio,
"dev.zio" %% "zio-interop-cats" % Versions.zioInteropCats,
"org.typelevel" %% "cats-effect" % Versions.catsEffect,
"io.swagger" % "swagger-annotations" % "1.6.2",
"io.circe" %% "circe-generic-extras" % "0.13.0",
"com.softwaremill.sttp.client" %% "akka-http-backend" % Versions.sttp
"com.softwaremill.sttp.client3" %% "akka-http-backend" % Versions.sttp
),
libraryDependencies ++= loggerDependencies,
publishArtifact := false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@ package sttp.tapir.client.sttp
import java.io.ByteArrayInputStream
import java.nio.ByteBuffer

import sttp.client._
import sttp.capabilities.Streams
import sttp.client3._
import sttp.model.Uri.PathSegment
import sttp.model.{HeaderNames, Method, Part, Uri}
import sttp.tapir.Codec.PlainCodec
import sttp.tapir._
import sttp.tapir.internal._

class EndpointToSttpClient(clientOptions: SttpClientOptions) {
def toSttpRequestUnsafe[I, E, O, S](e: Endpoint[I, E, O, S], baseUri: Uri): I => Request[Either[E, O], S] = { params =>
def toSttpRequestUnsafe[I, E, O, R](e: Endpoint[I, E, O, R], baseUri: Uri): I => Request[Either[E, O], R] = { params =>
toSttpRequest(e, baseUri)(params).mapResponse(getOrThrow)
}

def toSttpRequest[S, O, E, I](e: Endpoint[I, E, O, S], baseUri: Uri): I => Request[DecodeResult[Either[E, O]], S] = { params =>
def toSttpRequest[R, O, E, I](e: Endpoint[I, E, O, R], baseUri: Uri): I => Request[DecodeResult[Either[E, O]], R] = { params =>
val (uri, req1) =
setInputParams(
e.input,
Expand All @@ -26,14 +27,10 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {

val req2 = req1.copy[Identity, Any, Any](method = sttp.model.Method(e.input.method.getOrElse(Method.GET).method), uri = uri)

val responseAs = fromMetadata { meta =>
val output = if (meta.isSuccess) e.output else e.errorOutput
if (output == EndpointOutput.Void()) {
throw new IllegalStateException(s"Got response: $meta, cannot map to a void output of: $e.")
}

responseAsFromOutputs(meta, output)
}.mapWithMetadata { (body, meta) =>
val responseAs = fromMetadata(
responseAsFromOutputs(e.errorOutput),
ConditionalResponseAs(_.isSuccess, responseAsFromOutputs(e.output))
).mapWithMetadata { (body, meta) =>
val output = if (meta.isSuccess) e.output else e.errorOutput
val params = getOutputParams(output, body, meta)
params.map(_.asAny).map(p => if (meta.isSuccess) Right(p) else Left(p))
Expand All @@ -43,21 +40,21 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {
case other => other
}

req2.response(responseAs).asInstanceOf[Request[DecodeResult[Either[E, O]], S]]
req2.response(responseAs).asInstanceOf[Request[DecodeResult[Either[E, O]], R]]
}

private def getOutputParams(output: EndpointOutput[_], body: Any, meta: ResponseMetadata): DecodeResult[Params] = {
output match {
case s: EndpointOutput.Single[_] =>
(s match {
case EndpointIO.Body(_, codec, _) => codec.decode(body)
case EndpointIO.StreamBodyWrapper(StreamingEndpointIO.Body(codec, _, _)) => codec.decode(body)
case EndpointIO.Header(name, codec, _) => codec.decode(meta.headers(name).toList)
case EndpointIO.Headers(codec, _) => codec.decode(meta.headers.toList)
case EndpointOutput.StatusCode(_, codec, _) => codec.decode(meta.code)
case EndpointOutput.FixedStatusCode(_, codec, _) => codec.decode(())
case EndpointIO.FixedHeader(_, codec, _) => codec.decode(())
case EndpointIO.Empty(codec, _) => codec.decode(())
case EndpointIO.Body(_, codec, _) => codec.decode(body)
case EndpointIO.StreamBodyWrapper(StreamingEndpointIO.Body(_, codec, _, _)) => codec.decode(body)
case EndpointIO.Header(name, codec, _) => codec.decode(meta.headers(name).toList)
case EndpointIO.Headers(codec, _) => codec.decode(meta.headers.toList)
case EndpointOutput.StatusCode(_, codec, _) => codec.decode(meta.code)
case EndpointOutput.FixedStatusCode(_, codec, _) => codec.decode(())
case EndpointIO.FixedHeader(_, codec, _) => codec.decode(())
case EndpointIO.Empty(codec, _) => codec.decode(())
case EndpointOutput.OneOf(mappings, codec) =>
mappings
.find(mapping => mapping.statusCode.isEmpty || mapping.statusCode.contains(meta.code)) match {
Expand Down Expand Up @@ -126,8 +123,8 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {
case EndpointIO.Body(bodyType, codec, _) =>
val req2 = setBody(value, bodyType, codec, req)
(uri, req2)
case EndpointIO.StreamBodyWrapper(_) =>
val req2 = req.streamBody(value)
case EndpointIO.StreamBodyWrapper(StreamingEndpointIO.Body(streams, _, _, _)) =>
val req2 = req.streamBody(streams)(value.asInstanceOf[streams.BinaryStream])
(uri, req2)
case EndpointIO.Header(name, codec, _) =>
val req2 = codec
Expand All @@ -136,10 +133,9 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {
(uri, req2)
case EndpointIO.Headers(codec, _) =>
val headers = codec.encode(value)
val req2 = headers.foldLeft(req) {
case (r, h) =>
val replaceExisting = HeaderNames.ContentType.equalsIgnoreCase(h.name) || HeaderNames.ContentLength.equalsIgnoreCase(h.name)
r.header(h, replaceExisting)
val req2 = headers.foldLeft(req) { case (r, h) =>
val replaceExisting = HeaderNames.ContentType.equalsIgnoreCase(h.name) || HeaderNames.ContentLength.equalsIgnoreCase(h.name)
r.header(h, replaceExisting)
}
(uri, req2)
case EndpointIO.FixedHeader(h, _, _) =>
Expand Down Expand Up @@ -193,7 +189,7 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {
case RawBodyType.InputStreamBody => req.body(encoded)
case RawBodyType.FileBody => req.body(encoded)
case m: RawBodyType.MultipartBody =>
val parts: Seq[Part[BasicRequestBody]] = (encoded: Seq[RawPart]).flatMap { p =>
val parts: Seq[Part[RequestBody[Any]]] = (encoded: Seq[RawPart]).flatMap { p =>
m.partType(p.name).map { partType =>
// copying the name & body
val sttpPart1 = partToSttpPart(p.asInstanceOf[Part[Any]], partType.asInstanceOf[RawBodyType[Any]])
Expand All @@ -212,7 +208,7 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {
req2.contentType(codec.format.mediaType)
}

private def partToSttpPart[R](p: Part[R], bodyType: RawBodyType[R]): Part[BasicRequestBody] =
private def partToSttpPart[R](p: Part[R], bodyType: RawBodyType[R]): Part[RequestBody[Any]] =
bodyType match {
case RawBodyType.StringBody(charset) => multipart(p.name, p.body, charset.toString)
case RawBodyType.ByteArrayBody => multipart(p.name, p.body)
Expand All @@ -222,30 +218,32 @@ class EndpointToSttpClient(clientOptions: SttpClientOptions) {
case RawBodyType.MultipartBody(_, _) => throw new IllegalArgumentException("Nested multipart bodies aren't supported")
}

private def responseAsFromOutputs(meta: ResponseMetadata, out: EndpointOutput[_]): ResponseAs[Any, Any] = {
if (bodyIsStream(out)) asStreamAlways[Any]
else {
out.bodyType
.map {
case RawBodyType.StringBody(charset) => asStringAlways(charset.name())
case RawBodyType.ByteArrayBody => asByteArrayAlways
case RawBodyType.ByteBufferBody => asByteArrayAlways.map(ByteBuffer.wrap)
case RawBodyType.InputStreamBody => asByteArrayAlways.map(new ByteArrayInputStream(_))
case RawBodyType.FileBody => asFileAlways(clientOptions.createFile(meta))
case RawBodyType.MultipartBody(_, _) => throw new IllegalArgumentException("Multipart bodies aren't supported in responses")
}
.getOrElse(ignore)
}.asInstanceOf[ResponseAs[Any, Any]]
private def responseAsFromOutputs(out: EndpointOutput[_]): ResponseAs[Any, Any] = {
(bodyIsStream(out) match {
case Some(streams) => asStreamAlwaysUnsafe(streams)
case None => {
out.bodyType
.map {
case RawBodyType.StringBody(charset) => asStringAlways(charset.name())
case RawBodyType.ByteArrayBody => asByteArrayAlways
case RawBodyType.ByteBufferBody => asByteArrayAlways.map(ByteBuffer.wrap)
case RawBodyType.InputStreamBody => asByteArrayAlways.map(new ByteArrayInputStream(_))
case RawBodyType.FileBody => asFileAlways(clientOptions.createFile())
case RawBodyType.MultipartBody(_, _) => throw new IllegalArgumentException("Multipart bodies aren't supported in responses")
}
.getOrElse(ignore)
}
}).asInstanceOf[ResponseAs[Any, Any]]
}

private def bodyIsStream[I](out: EndpointOutput[I]): Boolean = {
private def bodyIsStream[I](out: EndpointOutput[I]): Option[Streams[_]] = {
out match {
case _: EndpointIO.StreamBodyWrapper[_, _] => true
case EndpointIO.Pair(left, right, _, _) => List(left, right).exists(i => bodyIsStream(i))
case EndpointOutput.Pair(left, right, _, _) => List(left, right).exists(i => bodyIsStream(i))
case EndpointIO.MappedPair(wrapped, _) => bodyIsStream(wrapped)
case EndpointOutput.MappedPair(wrapped, _) => bodyIsStream(wrapped)
case _ => false
case EndpointIO.StreamBodyWrapper(StreamingEndpointIO.Body(streams, _, _, _)) => Some(streams)
case EndpointIO.Pair(left, right, _, _) => bodyIsStream(left).orElse(bodyIsStream(right))
case EndpointOutput.Pair(left, right, _, _) => bodyIsStream(left).orElse(bodyIsStream(right))
case EndpointIO.MappedPair(wrapped, _) => bodyIsStream(wrapped)
case EndpointOutput.MappedPair(wrapped, _) => bodyIsStream(wrapped)
case _ => None
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ package sttp.tapir.client.sttp

import java.io.File

import sttp.client.ResponseMetadata
import sttp.tapir.Defaults

case class SttpClientOptions(createFile: ResponseMetadata => File)
case class SttpClientOptions(createFile: () => File)

object SttpClientOptions {
implicit val default: SttpClientOptions = SttpClientOptions(_ => Defaults.createTempFile())
implicit val default: SttpClientOptions = SttpClientOptions(Defaults.createTempFile)
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package sttp.tapir.client.sttp

import sttp.client.Request
import sttp.client3.Request
import sttp.model.Uri
import sttp.tapir.{DecodeResult, Endpoint}

trait TapirSttpClient {
implicit class RichEndpoint[I, E, O, S](e: Endpoint[I, E, O, S]) {
implicit class RichEndpoint[I, E, O, R](e: Endpoint[I, E, O, R]) {

/**
* Interprets the endpoint as a client call, using the given `baseUri` as the starting point to create the target
Expand All @@ -18,7 +18,7 @@ trait TapirSttpClient {
*
* @throws IllegalArgumentException when response parsing fails
*/
def toSttpRequestUnsafe(baseUri: Uri)(implicit clientOptions: SttpClientOptions): I => Request[Either[E, O], S] =
def toSttpRequestUnsafe(baseUri: Uri)(implicit clientOptions: SttpClientOptions): I => Request[Either[E, O], R] =
new EndpointToSttpClient(clientOptions).toSttpRequestUnsafe(e, baseUri)

/**
Expand All @@ -30,7 +30,7 @@ trait TapirSttpClient {
* which can be sent using any sttp backend. The response will then contain the decoded error or success values
* (note that this can be the body enriched with data from headers/status code).
*/
def toSttpRequest(baseUri: Uri)(implicit clientOptions: SttpClientOptions): I => Request[DecodeResult[Either[E, O]], S] =
def toSttpRequest(baseUri: Uri)(implicit clientOptions: SttpClientOptions): I => Request[DecodeResult[Either[E, O]], R] =
new EndpointToSttpClient(clientOptions).toSttpRequest(e, baseUri)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package sttp.tapir.client.sttp
import java.io.File

import sttp.tapir._
import sttp.client._
import sttp.client3._
import sttp.model.{Header, HeaderNames, MediaType, Part}
import sttp.tapir.tests.FruitData
import org.scalatest.funsuite.AnyFunSuite
Expand All @@ -21,7 +21,7 @@ class SttpClientRequestTests extends AnyFunSuite with Matchers {
.apply(FruitData(Part("image", testFile, contentType = Some(MediaType.ImageJpeg))))

// then
val part = sttpClientRequest.body.asInstanceOf[MultipartBody].parts.head
val part = sttpClientRequest.body.asInstanceOf[MultipartBody[Any]].parts.head
part.headers.filter(_.is(HeaderNames.ContentType)) shouldBe List(Header.contentType(MediaType.ImageJpeg))
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
package sttp.tapir.client.sttp

import cats.effect.{ContextShift, IO}
import cats.effect.{Blocker, ContextShift, IO}
import sttp.capabilities.fs2.Fs2Streams
import sttp.tapir.{DecodeResult, Endpoint}
import sttp.tapir.client.tests.ClientTests
import sttp.client._
import sttp.client.asynchttpclient.fs2.AsyncHttpClientFs2Backend
import sttp.client3._
import sttp.client3.asynchttpclient.fs2.AsyncHttpClientFs2Backend

import scala.concurrent.ExecutionContext

class SttpClientTests extends ClientTests[fs2.Stream[IO, Byte]] {
class SttpClientTests extends ClientTests[Fs2Streams[IO]](Fs2Streams[IO]) {
private implicit val cs: ContextShift[IO] = IO.contextShift(ExecutionContext.Implicits.global)
private implicit val backend: SttpBackend[IO, fs2.Stream[IO, Byte], NothingT] = AsyncHttpClientFs2Backend[IO]().unsafeRunSync()
private val backend: SttpBackend[IO, Fs2Streams[IO]] =
AsyncHttpClientFs2Backend[IO](Blocker.liftExecutionContext(ExecutionContext.Implicits.global)).unsafeRunSync()

override def mkStream(s: String): fs2.Stream[IO, Byte] = fs2.Stream.emits(s.getBytes("utf-8"))
override def rmStream(s: fs2.Stream[IO, Byte]): String =
Expand All @@ -19,16 +21,16 @@ class SttpClientTests extends ClientTests[fs2.Stream[IO, Byte]] {
.foldMonoid
.unsafeRunSync()

override def send[I, E, O, FN[_]](e: Endpoint[I, E, O, fs2.Stream[IO, Byte]], port: Port, args: I): IO[Either[E, O]] = {
e.toSttpRequestUnsafe(uri"http://localhost:$port").apply(args).send().map(_.body)
override def send[I, E, O, FN[_]](e: Endpoint[I, E, O, Fs2Streams[IO]], port: Port, args: I): IO[Either[E, O]] = {
e.toSttpRequestUnsafe(uri"http://localhost:$port").apply(args).send(backend).map(_.body)
}

override def safeSend[I, E, O, FN[_]](
e: Endpoint[I, E, O, fs2.Stream[IO, Byte]],
e: Endpoint[I, E, O, Fs2Streams[IO]],
port: Port,
args: I
): IO[DecodeResult[Either[E, O]]] = {
e.toSttpRequest(uri"http://localhost:$port").apply(args).send().map(_.body)
e.toSttpRequest(uri"http://localhost:$port").apply(args).send(backend).map(_.body)
}

override protected def afterAll(): Unit = {
Expand Down
Loading