Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codegen: add streaming support for application/octet-stream contents #3966

Merged
merged 8 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions generated-doc/out/generator/sbt-openapi-codegen.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ openapiJsonSerdeLib circe The j
openapiValidateNonDiscriminatedOneOfs true Whether to fail if variants of a oneOf without a discriminator cannot be disambiguated.
openapiMaxSchemasPerFile 400 Maximum number of schemas to generate in a single file (tweak if hitting javac class size limits).
openapiAdditionalPackages Nil Additional packageName/swaggerFile pairs for generating from multiple schemas
openapiStreamingImplementation fs2 Backend capability to assume for streaming content. Supports akka, fs2, pekko and zio.
===================================== ==================================== ==================================================================================================
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ object GenScala {
private val jsonLibOpt: Opts[Option[String]] =
Opts.option[String]("jsonLib", "Json library to use for serdes", "j").orNone

private val streamingImplementationOpt: Opts[Option[String]] =
Opts.option[String]("streamingImplementation", "Capability to use for binary streams", "s").orNone

private val destDirOpt: Opts[File] =
Opts
.option[String]("destdir", "Destination directory", "d")
Expand All @@ -84,7 +87,8 @@ object GenScala {
headTagForNamesOpt,
jsonLibOpt,
validateNonDiscriminatedOneOfsOpt,
maxSchemasPerFileOpt
maxSchemasPerFileOpt,
streamingImplementationOpt
)
.mapN {
case (
Expand All @@ -96,7 +100,8 @@ object GenScala {
headTagForNames,
jsonLib,
validateNonDiscriminatedOneOfs,
maxSchemasPerFile
maxSchemasPerFile,
streamingImplementation
) =>
val objectName = maybeObjectName.getOrElse(DefaultObjectName)

Expand All @@ -109,6 +114,7 @@ object GenScala {
targetScala3,
headTagForNames,
jsonLib.getOrElse("circe"),
streamingImplementation.getOrElse("fs2"),
validateNonDiscriminatedOneOfs,
maxSchemasPerFile.getOrElse(400)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ object JsonSerdeLib extends Enumeration {
val Circe, Jsoniter, Zio = Value
type JsonSerdeLib = Value
}
object StreamingImplementation extends Enumeration {
val Akka, FS2, Pekko, Zio = Value
type StreamingImplementation = Value
}

object BasicGenerator {

Expand All @@ -34,6 +38,7 @@ object BasicGenerator {
targetScala3: Boolean,
useHeadTagForObjectNames: Boolean,
jsonSerdeLib: String,
streamingImplementation: String,
validateNonDiscriminatedOneOfs: Boolean,
maxSchemasPerFile: Int
): Map[String, String] = {
Expand All @@ -47,9 +52,20 @@ object BasicGenerator {
)
JsonSerdeLib.Circe
}
val normalisedStreamingImplementation = streamingImplementation.toLowerCase match {
case "akka" => StreamingImplementation.Akka
case "fs2" => StreamingImplementation.FS2
case "pekko" => StreamingImplementation.Pekko
case "zio" => StreamingImplementation.Zio
case _ =>
System.err.println(
s"!!! Unrecognised value $streamingImplementation for streaming impl -- should be one of akka, fs2, pekko or zio. Defaulting to fs2 !!!"
)
StreamingImplementation.FS2
}

val EndpointDefs(endpointsByTag, queryOrPathParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams) =
endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames, targetScala3, normalisedJsonLib)
endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames, targetScala3, normalisedJsonLib, normalisedStreamingImplementation)
val GeneratedClassDefinitions(classDefns, jsonSerdes, schemas) =
classGenerator
.classDefs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package sttp.tapir.codegen
import io.circe.Json
import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType, strippedToCamelCase}
import sttp.tapir.codegen.JsonSerdeLib.JsonSerdeLib
import sttp.tapir.codegen.StreamingImplementation
import sttp.tapir.codegen.StreamingImplementation.StreamingImplementation
import sttp.tapir.codegen.openapi.models.OpenapiModels.{OpenapiDocument, OpenapiParameter, OpenapiPath, OpenapiRequestBody, OpenapiResponse}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaAny,
Expand All @@ -10,7 +12,8 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaEnum,
OpenapiSchemaMap,
OpenapiSchemaRef,
OpenapiSchemaSimpleType
OpenapiSchemaSimpleType,
OpenapiSchemaString
}
import sttp.tapir.codegen.openapi.models.{OpenapiComponent, OpenapiSchemaType, OpenapiSecuritySchemeType, SpecificationExtensionRenderer}
import sttp.tapir.codegen.util.JavaEscape
Expand Down Expand Up @@ -55,12 +58,13 @@ class EndpointGenerator {
doc: OpenapiDocument,
useHeadTagForObjectNames: Boolean,
targetScala3: Boolean,
jsonSerdeLib: JsonSerdeLib
jsonSerdeLib: JsonSerdeLib,
streamingImplementation: StreamingImplementation
): EndpointDefs = {
val components = Option(doc.components).flatten
val GeneratedEndpoints(endpointsByFile, queryOrPathParamRefs, jsonParamRefs, definesEnumQueryParam) =
doc.paths
.map(generatedEndpoints(components, useHeadTagForObjectNames, targetScala3, jsonSerdeLib))
.map(generatedEndpoints(components, useHeadTagForObjectNames, targetScala3, jsonSerdeLib, streamingImplementation))
.foldLeft(GeneratedEndpoints(Nil, Set.empty, Set.empty, false))(_ merge _)
val endpointDecls = endpointsByFile.map { case GeneratedEndpointsForFile(k, ge) =>
val definitions = ge
Expand All @@ -84,7 +88,8 @@ class EndpointGenerator {
components: Option[OpenapiComponent],
useHeadTagForObjectNames: Boolean,
targetScala3: Boolean,
jsonSerdeLib: JsonSerdeLib
jsonSerdeLib: JsonSerdeLib,
streamingImplementation: StreamingImplementation
)(p: OpenapiPath): GeneratedEndpoints = {
val parameters = components.map(_.parameters).getOrElse(Map.empty)
val securitySchemes = components.map(_.securitySchemes).getOrElse(Map.empty)
Expand All @@ -106,14 +111,15 @@ class EndpointGenerator {
}

val name = strippedToCamelCase(m.operationId.getOrElse(m.methodType + p.url.capitalize))
val (inParams, maybeLocalEnums) = ins(m.resolvedParameters, m.requestBody, name, targetScala3, jsonSerdeLib)
val (inParams, maybeLocalEnums) =
ins(m.resolvedParameters, m.requestBody, name, targetScala3, jsonSerdeLib, streamingImplementation)
val definition =
s"""|endpoint
| .${m.methodType}
| ${urlMapper(p.url, m.resolvedParameters)}
|${indent(2)(security(securitySchemes, m.security))}
|${indent(2)(inParams)}
|${indent(2)(outs(m.responses))}
|${indent(2)(outs(m.responses, streamingImplementation))}
|${indent(2)(tags(m.tags))}
|$attributeString
|""".stripMargin.linesIterator.filterNot(_.trim.isEmpty).mkString("\n")
Expand Down Expand Up @@ -211,7 +217,8 @@ class EndpointGenerator {
requestBody: Option[OpenapiRequestBody],
endpointName: String,
targetScala3: Boolean,
jsonSerdeLib: JsonSerdeLib
jsonSerdeLib: JsonSerdeLib,
streamingImplementation: StreamingImplementation
)(implicit location: Location): (String, Option[String]) = {
def getEnumParamDefn(param: OpenapiParameter, e: OpenapiSchemaEnum, isArray: Boolean) = {
val enumName = endpointName.capitalize + strippedToCamelCase(param.name).capitalize
Expand Down Expand Up @@ -267,7 +274,7 @@ class EndpointGenerator {
val rqBody = requestBody.flatMap { b =>
if (b.content.isEmpty) None
else if (b.content.size != 1) bail(s"We can handle only one requestBody content! Saw ${b.content.map(_.contentType)}")
else Some(s".in(${contentTypeMapper(b.content.head.contentType, b.content.head.schema, b.required)})")
else Some(s".in(${contentTypeMapper(b.content.head.contentType, b.content.head.schema, streamingImplementation, b.required)})")
}

(params ++ rqBody).mkString("\n") -> maybeEnumDefns.foldLeft(Option.empty[String]) {
Expand Down Expand Up @@ -298,7 +305,7 @@ class EndpointGenerator {
// treats redirects as ok
private val okStatus = """([23]\d\d)""".r
private val errorStatus = """([45]\d\d)""".r
private def outs(responses: Seq[OpenapiResponse])(implicit location: Location) = {
private def outs(responses: Seq[OpenapiResponse], streamingImplementation: StreamingImplementation)(implicit location: Location) = {
// .errorOut(stringBody)
// .out(jsonBody[List[Book]])

Expand All @@ -315,13 +322,13 @@ class EndpointGenerator {
case content +: Nil =>
resp.code match {
case "200" =>
s".out(${contentTypeMapper(content.contentType, content.schema)}$d)"
s".out(${contentTypeMapper(content.contentType, content.schema, streamingImplementation)}$d)"
case okStatus(s) =>
s".out(${contentTypeMapper(content.contentType, content.schema)}$d.and(statusCode(sttp.model.StatusCode($s))))"
s".out(${contentTypeMapper(content.contentType, content.schema, streamingImplementation)}$d.and(statusCode(sttp.model.StatusCode($s))))"
case "default" =>
s".errorOut(${contentTypeMapper(content.contentType, content.schema)}$d)"
s".errorOut(${contentTypeMapper(content.contentType, content.schema, streamingImplementation)}$d)"
case errorStatus(s) =>
s".errorOut(${contentTypeMapper(content.contentType, content.schema)}$d.and(statusCode(sttp.model.StatusCode($s))))"
s".errorOut(${contentTypeMapper(content.contentType, content.schema, streamingImplementation)}$d.and(statusCode(sttp.model.StatusCode($s))))"
case x =>
bail(s"Statuscode mapping is incomplete! Cannot handle $x")
}
Expand All @@ -333,7 +340,12 @@ class EndpointGenerator {
.mkString("\n")
}

private def contentTypeMapper(contentType: String, schema: OpenapiSchemaType, required: Boolean = true)(implicit location: Location) = {
private def contentTypeMapper(
contentType: String,
schema: OpenapiSchemaType,
streamingImplementation: StreamingImplementation,
required: Boolean = true
)(implicit location: Location) = {
contentType match {
case "text/plain" =>
"stringBody"
Expand Down Expand Up @@ -362,6 +374,31 @@ class EndpointGenerator {
s"multipartBody[$t]"
case x => bail(s"$contentType only supports schema ref or binary. Found $x")
}
case "application/octet-stream" =>
val capability = streamingImplementation match {
case StreamingImplementation.Akka => "sttp.capabilities.akka.AkkaStreams"
case StreamingImplementation.FS2 => "sttp.capabilities.fs2.Fs2Streams[cats.effect.IO]"
case StreamingImplementation.Pekko => "sttp.capabilities.pekko.PekkoStreams"
case StreamingImplementation.Zio => "sttp.capabilities.zio.ZioStreams"
}
schema match {
case _: OpenapiSchemaString =>
s"streamTextBody($capability)(CodecFormat.OctetStream())"
case schema =>
val outT = schema match {
case st: OpenapiSchemaSimpleType =>
val (t, _) = mapSchemaSimpleTypeToType(st)
t
case OpenapiSchemaArray(st: OpenapiSchemaSimpleType, _) =>
val (t, _) = mapSchemaSimpleTypeToType(st)
s"List[$t]"
case OpenapiSchemaMap(st: OpenapiSchemaSimpleType, _) =>
val (t, _) = mapSchemaSimpleTypeToType(st)
s"Map[String, $t]"
case x => bail(s"Can't create this param as output (found $x)")
}
s"streamBody($capability)(Schema.binary[$outT], CodecFormat.OctetStream())"
}

case x => bail(s"Not all content types supported! Found $x")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class BasicGeneratorSpec extends CompileCheckTestBase {
useHeadTagForObjectNames = useHeadTagForObjectNames,
jsonSerdeLib = jsonSerdeLib,
validateNonDiscriminatedOneOfs = true,
maxSchemasPerFile = 400
maxSchemasPerFile = 400,
streamingImplementation = "fs2"
)
}
def gen(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,13 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
case Left(value) => throw new Exception(value)
case Right(doc) =>
new EndpointGenerator()
.endpointDefs(doc, useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe)
.endpointDefs(
doc,
useHeadTagForObjectNames = false,
targetScala3 = false,
jsonSerdeLib = JsonSerdeLib.Circe,
streamingImplementation = StreamingImplementation.FS2
)
.endpointDecls(None)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,13 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
)
val generatedCode = BasicGenerator.imports(JsonSerdeLib.Circe) ++
new EndpointGenerator()
.endpointDefs(doc, useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe)
.endpointDefs(
doc,
useHeadTagForObjectNames = false,
targetScala3 = false,
jsonSerdeLib = JsonSerdeLib.Circe,
streamingImplementation = StreamingImplementation.FS2
)
.endpointDecls(None)
generatedCode should include("val getTestAsdId =")
generatedCode should include(""".in(query[Option[String]]("fgh-id"))""")
Expand Down Expand Up @@ -142,7 +148,13 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
)
BasicGenerator.imports(JsonSerdeLib.Circe) ++
new EndpointGenerator()
.endpointDefs(doc, useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe)
.endpointDefs(
doc,
useHeadTagForObjectNames = false,
targetScala3 = false,
jsonSerdeLib = JsonSerdeLib.Circe,
streamingImplementation = StreamingImplementation.FS2
)
.endpointDecls(None) shouldCompile ()
}

Expand Down Expand Up @@ -188,7 +200,13 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
)
val generatedCode = BasicGenerator.imports(JsonSerdeLib.Circe) ++
new EndpointGenerator()
.endpointDefs(doc, useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe)
.endpointDefs(
doc,
useHeadTagForObjectNames = false,
targetScala3 = false,
jsonSerdeLib = JsonSerdeLib.Circe,
streamingImplementation = StreamingImplementation.FS2
)
.endpointDecls(None)
generatedCode should include(
""".out(stringBody.description("Processing").and(statusCode(sttp.model.StatusCode(202))))"""
Expand Down Expand Up @@ -253,7 +271,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
useHeadTagForObjectNames = false,
jsonSerdeLib = "circe",
validateNonDiscriminatedOneOfs = true,
maxSchemasPerFile = 400
maxSchemasPerFile = 400,
streamingImplementation = "fs2"
)("TapirGeneratedEndpoints")
generatedCode should include(
"""file: sttp.model.Part[java.io.File]"""
Expand All @@ -274,7 +293,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
useHeadTagForObjectNames = false,
jsonSerdeLib = "circe",
validateNonDiscriminatedOneOfs = true,
maxSchemasPerFile = 400
maxSchemasPerFile = 400,
streamingImplementation = "fs2"
)("TapirGeneratedEndpoints")
generatedCode shouldCompile ()
val expectedAttrDecls = Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@ package sttp.tapir.sbt

import sbt._

case class OpenApiConfiguration(
swaggerFile: File,
packageName: String,
objectName: String,
useHeadTagForObjectName: Boolean,
jsonSerdeLib: String,
streamingImplementation: String,
validateNonDiscriminatedOneOfs: Boolean,
maxSchemasPerFile: Int,
additionalPackages: List[(String, File)]
)

trait OpenapiCodegenKeys {
lazy val openapiSwaggerFile = settingKey[File]("The swagger file with the api definitions.")
lazy val openapiPackage = settingKey[String]("The name for the generated package.")
Expand All @@ -13,7 +25,10 @@ trait OpenapiCodegenKeys {
lazy val openapiValidateNonDiscriminatedOneOfs =
settingKey[Boolean]("Whether to fail if variants of a oneOf without a discriminator cannot be disambiguated..")
lazy val openapiMaxSchemasPerFile = settingKey[Int]("Maximum number of schemas to generate for a single file")
lazy val openapiAdditionalPackages = taskKey[List[(String, File)]]("Addition package -> spec mappings to generate.")
lazy val openapiAdditionalPackages = settingKey[List[(String, File)]]("Addition package -> spec mappings to generate.")
lazy val openapiStreamingImplementation = settingKey[String]("Implementation for streamTextBody. Supports: akka, fs2, pekko, zio.")
lazy val openapiOpenApiConfiguration =
settingKey[OpenApiConfiguration]("Aggregation of other settings. Manually set value will be disregarded.")

lazy val generateTapirDefinitions = taskKey[Unit]("The task that generates tapir definitions based on the input swagger file.")
}
Expand Down
Loading
Loading