diff --git a/generated-doc/out/generator/sbt-openapi-codegen.md b/generated-doc/out/generator/sbt-openapi-codegen.md index 75980962e9..f0bdae75b6 100644 --- a/generated-doc/out/generator/sbt-openapi-codegen.md +++ b/generated-doc/out/generator/sbt-openapi-codegen.md @@ -44,6 +44,7 @@ openapiJsonSerdeLib circe The j openapiValidateNonDiscriminatedOneOfs true Whether to fail if variants of a oneOf without a discriminator cannot be disambiguated. openapiMaxSchemasPerFile 400 Maximum number of schemas to generate in a single file (tweak if hitting javac class size limits). openapiAdditionalPackages Nil Additional packageName/swaggerFile pairs for generating from multiple schemas +openapiStreamingImplementation fs2 Backend capability to assume for streaming content. Supports akka, fs2, pekko and zio. ===================================== ==================================== ================================================================================================== ``` diff --git a/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala b/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala index ad41daa46a..361167d58e 100644 --- a/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala +++ b/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala @@ -62,6 +62,9 @@ object GenScala { private val jsonLibOpt: Opts[Option[String]] = Opts.option[String]("jsonLib", "Json library to use for serdes", "j").orNone + private val streamingImplementationOpt: Opts[Option[String]] = + Opts.option[String]("streamingImplementation", "Capability to use for binary streams", "s").orNone + private val destDirOpt: Opts[File] = Opts .option[String]("destdir", "Destination directory", "d") @@ -84,7 +87,8 @@ object GenScala { headTagForNamesOpt, jsonLibOpt, validateNonDiscriminatedOneOfsOpt, - maxSchemasPerFileOpt + maxSchemasPerFileOpt, + streamingImplementationOpt ) .mapN { case ( @@ -96,7 +100,8 @@ object GenScala { headTagForNames, jsonLib, validateNonDiscriminatedOneOfs, - maxSchemasPerFile + maxSchemasPerFile, + streamingImplementation ) => val objectName = maybeObjectName.getOrElse(DefaultObjectName) @@ -109,6 +114,7 @@ object GenScala { targetScala3, headTagForNames, jsonLib.getOrElse("circe"), + streamingImplementation.getOrElse("fs2"), validateNonDiscriminatedOneOfs, maxSchemasPerFile.getOrElse(400) ) diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala index f75afced78..b7d21b9cdd 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala @@ -21,6 +21,10 @@ object JsonSerdeLib extends Enumeration { val Circe, Jsoniter, Zio = Value type JsonSerdeLib = Value } +object StreamingImplementation extends Enumeration { + val Akka, FS2, Pekko, Zio = Value + type StreamingImplementation = Value +} object BasicGenerator { @@ -34,6 +38,7 @@ object BasicGenerator { targetScala3: Boolean, useHeadTagForObjectNames: Boolean, jsonSerdeLib: String, + streamingImplementation: String, validateNonDiscriminatedOneOfs: Boolean, maxSchemasPerFile: Int ): Map[String, String] = { @@ -47,9 +52,20 @@ object BasicGenerator { ) JsonSerdeLib.Circe } + val normalisedStreamingImplementation = streamingImplementation.toLowerCase match { + case "akka" => StreamingImplementation.Akka + case "fs2" => StreamingImplementation.FS2 + case "pekko" => StreamingImplementation.Pekko + case "zio" => StreamingImplementation.Zio + case _ => + System.err.println( + s"!!! Unrecognised value $streamingImplementation for streaming impl -- should be one of akka, fs2, pekko or zio. Defaulting to fs2 !!!" + ) + StreamingImplementation.FS2 + } val EndpointDefs(endpointsByTag, queryOrPathParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams) = - endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames, targetScala3, normalisedJsonLib) + endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames, targetScala3, normalisedJsonLib, normalisedStreamingImplementation) val GeneratedClassDefinitions(classDefns, jsonSerdes, schemas) = classGenerator .classDefs( diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala index 24f260f6f4..5f88c23bc0 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala @@ -2,6 +2,8 @@ package sttp.tapir.codegen import io.circe.Json import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType, strippedToCamelCase} import sttp.tapir.codegen.JsonSerdeLib.JsonSerdeLib +import sttp.tapir.codegen.StreamingImplementation +import sttp.tapir.codegen.StreamingImplementation.StreamingImplementation import sttp.tapir.codegen.openapi.models.OpenapiModels.{OpenapiDocument, OpenapiParameter, OpenapiPath, OpenapiRequestBody, OpenapiResponse} import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{ OpenapiSchemaAny, @@ -10,7 +12,8 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{ OpenapiSchemaEnum, OpenapiSchemaMap, OpenapiSchemaRef, - OpenapiSchemaSimpleType + OpenapiSchemaSimpleType, + OpenapiSchemaString } import sttp.tapir.codegen.openapi.models.{OpenapiComponent, OpenapiSchemaType, OpenapiSecuritySchemeType, SpecificationExtensionRenderer} import sttp.tapir.codegen.util.JavaEscape @@ -55,12 +58,13 @@ class EndpointGenerator { doc: OpenapiDocument, useHeadTagForObjectNames: Boolean, targetScala3: Boolean, - jsonSerdeLib: JsonSerdeLib + jsonSerdeLib: JsonSerdeLib, + streamingImplementation: StreamingImplementation ): EndpointDefs = { val components = Option(doc.components).flatten val GeneratedEndpoints(endpointsByFile, queryOrPathParamRefs, jsonParamRefs, definesEnumQueryParam) = doc.paths - .map(generatedEndpoints(components, useHeadTagForObjectNames, targetScala3, jsonSerdeLib)) + .map(generatedEndpoints(components, useHeadTagForObjectNames, targetScala3, jsonSerdeLib, streamingImplementation)) .foldLeft(GeneratedEndpoints(Nil, Set.empty, Set.empty, false))(_ merge _) val endpointDecls = endpointsByFile.map { case GeneratedEndpointsForFile(k, ge) => val definitions = ge @@ -84,7 +88,8 @@ class EndpointGenerator { components: Option[OpenapiComponent], useHeadTagForObjectNames: Boolean, targetScala3: Boolean, - jsonSerdeLib: JsonSerdeLib + jsonSerdeLib: JsonSerdeLib, + streamingImplementation: StreamingImplementation )(p: OpenapiPath): GeneratedEndpoints = { val parameters = components.map(_.parameters).getOrElse(Map.empty) val securitySchemes = components.map(_.securitySchemes).getOrElse(Map.empty) @@ -106,14 +111,15 @@ class EndpointGenerator { } val name = strippedToCamelCase(m.operationId.getOrElse(m.methodType + p.url.capitalize)) - val (inParams, maybeLocalEnums) = ins(m.resolvedParameters, m.requestBody, name, targetScala3, jsonSerdeLib) + val (inParams, maybeLocalEnums) = + ins(m.resolvedParameters, m.requestBody, name, targetScala3, jsonSerdeLib, streamingImplementation) val definition = s"""|endpoint | .${m.methodType} | ${urlMapper(p.url, m.resolvedParameters)} |${indent(2)(security(securitySchemes, m.security))} |${indent(2)(inParams)} - |${indent(2)(outs(m.responses))} + |${indent(2)(outs(m.responses, streamingImplementation))} |${indent(2)(tags(m.tags))} |$attributeString |""".stripMargin.linesIterator.filterNot(_.trim.isEmpty).mkString("\n") @@ -211,7 +217,8 @@ class EndpointGenerator { requestBody: Option[OpenapiRequestBody], endpointName: String, targetScala3: Boolean, - jsonSerdeLib: JsonSerdeLib + jsonSerdeLib: JsonSerdeLib, + streamingImplementation: StreamingImplementation )(implicit location: Location): (String, Option[String]) = { def getEnumParamDefn(param: OpenapiParameter, e: OpenapiSchemaEnum, isArray: Boolean) = { val enumName = endpointName.capitalize + strippedToCamelCase(param.name).capitalize @@ -267,7 +274,7 @@ class EndpointGenerator { val rqBody = requestBody.flatMap { b => if (b.content.isEmpty) None else if (b.content.size != 1) bail(s"We can handle only one requestBody content! Saw ${b.content.map(_.contentType)}") - else Some(s".in(${contentTypeMapper(b.content.head.contentType, b.content.head.schema, b.required)})") + else Some(s".in(${contentTypeMapper(b.content.head.contentType, b.content.head.schema, streamingImplementation, b.required)})") } (params ++ rqBody).mkString("\n") -> maybeEnumDefns.foldLeft(Option.empty[String]) { @@ -298,7 +305,7 @@ class EndpointGenerator { // treats redirects as ok private val okStatus = """([23]\d\d)""".r private val errorStatus = """([45]\d\d)""".r - private def outs(responses: Seq[OpenapiResponse])(implicit location: Location) = { + private def outs(responses: Seq[OpenapiResponse], streamingImplementation: StreamingImplementation)(implicit location: Location) = { // .errorOut(stringBody) // .out(jsonBody[List[Book]]) @@ -315,13 +322,13 @@ class EndpointGenerator { case content +: Nil => resp.code match { case "200" => - s".out(${contentTypeMapper(content.contentType, content.schema)}$d)" + s".out(${contentTypeMapper(content.contentType, content.schema, streamingImplementation)}$d)" case okStatus(s) => - s".out(${contentTypeMapper(content.contentType, content.schema)}$d.and(statusCode(sttp.model.StatusCode($s))))" + s".out(${contentTypeMapper(content.contentType, content.schema, streamingImplementation)}$d.and(statusCode(sttp.model.StatusCode($s))))" case "default" => - s".errorOut(${contentTypeMapper(content.contentType, content.schema)}$d)" + s".errorOut(${contentTypeMapper(content.contentType, content.schema, streamingImplementation)}$d)" case errorStatus(s) => - s".errorOut(${contentTypeMapper(content.contentType, content.schema)}$d.and(statusCode(sttp.model.StatusCode($s))))" + s".errorOut(${contentTypeMapper(content.contentType, content.schema, streamingImplementation)}$d.and(statusCode(sttp.model.StatusCode($s))))" case x => bail(s"Statuscode mapping is incomplete! Cannot handle $x") } @@ -333,7 +340,12 @@ class EndpointGenerator { .mkString("\n") } - private def contentTypeMapper(contentType: String, schema: OpenapiSchemaType, required: Boolean = true)(implicit location: Location) = { + private def contentTypeMapper( + contentType: String, + schema: OpenapiSchemaType, + streamingImplementation: StreamingImplementation, + required: Boolean = true + )(implicit location: Location) = { contentType match { case "text/plain" => "stringBody" @@ -362,6 +374,31 @@ class EndpointGenerator { s"multipartBody[$t]" case x => bail(s"$contentType only supports schema ref or binary. Found $x") } + case "application/octet-stream" => + val capability = streamingImplementation match { + case StreamingImplementation.Akka => "sttp.capabilities.akka.AkkaStreams" + case StreamingImplementation.FS2 => "sttp.capabilities.fs2.Fs2Streams[cats.effect.IO]" + case StreamingImplementation.Pekko => "sttp.capabilities.pekko.PekkoStreams" + case StreamingImplementation.Zio => "sttp.capabilities.zio.ZioStreams" + } + schema match { + case _: OpenapiSchemaString => + s"streamTextBody($capability)(CodecFormat.OctetStream())" + case schema => + val outT = schema match { + case st: OpenapiSchemaSimpleType => + val (t, _) = mapSchemaSimpleTypeToType(st) + t + case OpenapiSchemaArray(st: OpenapiSchemaSimpleType, _) => + val (t, _) = mapSchemaSimpleTypeToType(st) + s"List[$t]" + case OpenapiSchemaMap(st: OpenapiSchemaSimpleType, _) => + val (t, _) = mapSchemaSimpleTypeToType(st) + s"Map[String, $t]" + case x => bail(s"Can't create this param as output (found $x)") + } + s"streamBody($capability)(Schema.binary[$outT], CodecFormat.OctetStream())" + } case x => bail(s"Not all content types supported! Found $x") } diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala index 5f8ae2b027..a4f334339b 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala @@ -17,7 +17,8 @@ class BasicGeneratorSpec extends CompileCheckTestBase { useHeadTagForObjectNames = useHeadTagForObjectNames, jsonSerdeLib = jsonSerdeLib, validateNonDiscriminatedOneOfs = true, - maxSchemasPerFile = 400 + maxSchemasPerFile = 400, + streamingImplementation = "fs2" ) } def gen( diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala index 28f63ad0c4..b73060d8a7 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala @@ -391,7 +391,13 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase { case Left(value) => throw new Exception(value) case Right(doc) => new EndpointGenerator() - .endpointDefs(doc, useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe) + .endpointDefs( + doc, + useHeadTagForObjectNames = false, + targetScala3 = false, + jsonSerdeLib = JsonSerdeLib.Circe, + streamingImplementation = StreamingImplementation.FS2 + ) .endpointDecls(None) } diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala index 491a5eb894..453ea55dab 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala @@ -63,7 +63,13 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { ) val generatedCode = BasicGenerator.imports(JsonSerdeLib.Circe) ++ new EndpointGenerator() - .endpointDefs(doc, useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe) + .endpointDefs( + doc, + useHeadTagForObjectNames = false, + targetScala3 = false, + jsonSerdeLib = JsonSerdeLib.Circe, + streamingImplementation = StreamingImplementation.FS2 + ) .endpointDecls(None) generatedCode should include("val getTestAsdId =") generatedCode should include(""".in(query[Option[String]]("fgh-id"))""") @@ -142,7 +148,13 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { ) BasicGenerator.imports(JsonSerdeLib.Circe) ++ new EndpointGenerator() - .endpointDefs(doc, useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe) + .endpointDefs( + doc, + useHeadTagForObjectNames = false, + targetScala3 = false, + jsonSerdeLib = JsonSerdeLib.Circe, + streamingImplementation = StreamingImplementation.FS2 + ) .endpointDecls(None) shouldCompile () } @@ -188,7 +200,13 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { ) val generatedCode = BasicGenerator.imports(JsonSerdeLib.Circe) ++ new EndpointGenerator() - .endpointDefs(doc, useHeadTagForObjectNames = false, targetScala3 = false, jsonSerdeLib = JsonSerdeLib.Circe) + .endpointDefs( + doc, + useHeadTagForObjectNames = false, + targetScala3 = false, + jsonSerdeLib = JsonSerdeLib.Circe, + streamingImplementation = StreamingImplementation.FS2 + ) .endpointDecls(None) generatedCode should include( """.out(stringBody.description("Processing").and(statusCode(sttp.model.StatusCode(202))))""" @@ -253,7 +271,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { useHeadTagForObjectNames = false, jsonSerdeLib = "circe", validateNonDiscriminatedOneOfs = true, - maxSchemasPerFile = 400 + maxSchemasPerFile = 400, + streamingImplementation = "fs2" )("TapirGeneratedEndpoints") generatedCode should include( """file: sttp.model.Part[java.io.File]""" @@ -274,7 +293,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { useHeadTagForObjectNames = false, jsonSerdeLib = "circe", validateNonDiscriminatedOneOfs = true, - maxSchemasPerFile = 400 + maxSchemasPerFile = 400, + streamingImplementation = "fs2" )("TapirGeneratedEndpoints") generatedCode shouldCompile () val expectedAttrDecls = Seq( diff --git a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala index 312a261207..ba00d8a34e 100644 --- a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala +++ b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala @@ -2,6 +2,18 @@ package sttp.tapir.sbt import sbt._ +case class OpenApiConfiguration( + swaggerFile: File, + packageName: String, + objectName: String, + useHeadTagForObjectName: Boolean, + jsonSerdeLib: String, + streamingImplementation: String, + validateNonDiscriminatedOneOfs: Boolean, + maxSchemasPerFile: Int, + additionalPackages: List[(String, File)] +) + trait OpenapiCodegenKeys { lazy val openapiSwaggerFile = settingKey[File]("The swagger file with the api definitions.") lazy val openapiPackage = settingKey[String]("The name for the generated package.") @@ -13,7 +25,10 @@ trait OpenapiCodegenKeys { lazy val openapiValidateNonDiscriminatedOneOfs = settingKey[Boolean]("Whether to fail if variants of a oneOf without a discriminator cannot be disambiguated..") lazy val openapiMaxSchemasPerFile = settingKey[Int]("Maximum number of schemas to generate for a single file") - lazy val openapiAdditionalPackages = taskKey[List[(String, File)]]("Addition package -> spec mappings to generate.") + lazy val openapiAdditionalPackages = settingKey[List[(String, File)]]("Addition package -> spec mappings to generate.") + lazy val openapiStreamingImplementation = settingKey[String]("Implementation for streamTextBody. Supports: akka, fs2, pekko, zio.") + lazy val openapiOpenApiConfiguration = + settingKey[OpenApiConfiguration]("Aggregation of other settings. Manually set value will be disregarded.") lazy val generateTapirDefinitions = taskKey[Unit]("The task that generates tapir definitions based on the input swagger file.") } diff --git a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala index dfab718af3..5f086617e1 100644 --- a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala +++ b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala @@ -18,10 +18,22 @@ object OpenapiCodegenPlugin extends AutoPlugin { def openapiCodegenScopedSettings(conf: Configuration): Seq[Setting[_]] = inConfig(conf)( Seq( generateTapirDefinitions := codegen.value, - sourceGenerators += (codegen.taskValue).map(_.flatMap(_.map(_.toPath.toFile))) + sourceGenerators += (codegen.taskValue).map(_.map(_.toPath.toFile)) ) ) + def standardParamSetting = + openapiOpenApiConfiguration := OpenApiConfiguration( + openapiSwaggerFile.value, + openapiPackage.value, + openapiObject.value, + openapiUseHeadTagForObjectName.value, + openapiJsonSerdeLib.value, + openapiStreamingImplementation.value, + openapiValidateNonDiscriminatedOneOfs.value, + openapiMaxSchemasPerFile.value, + openapiAdditionalPackages.value + ) def openapiCodegenDefaultSettings: Seq[Setting[_]] = Seq( openapiSwaggerFile := baseDirectory.value / "swagger.yaml", openapiPackage := "sttp.tapir.generated", @@ -30,56 +42,46 @@ object OpenapiCodegenPlugin extends AutoPlugin { openapiJsonSerdeLib := "circe", openapiValidateNonDiscriminatedOneOfs := true, openapiMaxSchemasPerFile := 400, - openapiAdditionalPackages := Nil + openapiAdditionalPackages := Nil, + openapiStreamingImplementation := "fs2", + standardParamSetting ) - private def codegen = Def.task { - val log = sLog.value - log.info("Zipping file...") - ((( - openapiSwaggerFile, - openapiPackage, - openapiObject, - openapiUseHeadTagForObjectName, - openapiJsonSerdeLib, - openapiValidateNonDiscriminatedOneOfs, - openapiMaxSchemasPerFile, - openapiAdditionalPackages, - sourceManaged, - streams, - scalaVersion - ) flatMap { + private def codegen = + Def.task { + val log = sLog.value + log.info("Zipping file...") ( - swaggerFile: File, - packageName: String, - objectName: String, - useHeadTagForObjectName: Boolean, - jsonSerdeLib: String, - validateNonDiscriminatedOneOfs: Boolean, - maxSchemasPerFile: Int, - additionalPackages: List[(String, File)], - srcDir: File, - taskStreams: TaskStreams, - sv: String - ) => - def genTask(swaggerFile: File, packageName: String, directoryName: Option[String] = None) = - OpenapiCodegenTask( - swaggerFile, - packageName, - objectName, - useHeadTagForObjectName, - jsonSerdeLib, - validateNonDiscriminatedOneOfs, - maxSchemasPerFile, - srcDir, - taskStreams.cacheDirectory, - sv.startsWith("3"), - directoryName - ) - (genTask(swaggerFile, packageName).file +: additionalPackages.map { case (pkg, defns) => - genTask(defns, pkg, Some(pkg.replace('.', '/'))).file - }) - .reduceLeft((l, r) => l.flatMap(_l => r.map(_l ++ _))) - }) map (Seq(_))).value - } + openapiOpenApiConfiguration, + sourceManaged, + streams, + scalaVersion + ).flatMap { + ( + c: OpenApiConfiguration, + srcDir: File, + taskStreams: TaskStreams, + sv: String + ) => + def genTask(swaggerFile: File, packageName: String, directoryName: Option[String] = None) = + OpenapiCodegenTask( + swaggerFile, + packageName, + c.objectName, + c.useHeadTagForObjectName, + c.jsonSerdeLib, + c.streamingImplementation, + c.validateNonDiscriminatedOneOfs, + c.maxSchemasPerFile, + srcDir, + taskStreams.cacheDirectory, + sv.startsWith("3"), + directoryName + ) + (genTask(c.swaggerFile, c.packageName).file +: c.additionalPackages.map { case (pkg, defns) => + genTask(defns, pkg, Some(pkg.replace('.', '/'))).file + }) + .reduceLeft((l, r) => l.flatMap(_l => r.map(_l ++ _))) + }.value + } } diff --git a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala index 546081e95d..e689e3a09e 100644 --- a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala +++ b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala @@ -11,6 +11,7 @@ case class OpenapiCodegenTask( objectName: String, useHeadTagForObjectName: Boolean, jsonSerdeLib: String, + streamingImplementation: String, validateNonDiscriminatedOneOfs: Boolean, maxSchemasPerFile: Int, dir: File, @@ -56,6 +57,7 @@ case class OpenapiCodegenTask( targetScala3, useHeadTagForObjectName, jsonSerdeLib, + streamingImplementation, validateNonDiscriminatedOneOfs, maxSchemasPerFile ) diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/Expected.scala.txt b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/Expected.scala.txt index 17af38a6e6..4b2c350770 100644 --- a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/Expected.scala.txt +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/Expected.scala.txt @@ -117,6 +117,21 @@ object TapirGeneratedEndpoints { case object Baz extends AnEnum } + + + lazy val getBinaryTest = + endpoint + .get + .in(("binary" / "test")) + .out(streamBody(sttp.capabilities.pekko.PekkoStreams)(Schema.binary[Array[Byte]], CodecFormat.OctetStream()).description("Response CSV body")) + + lazy val postBinaryTest = + endpoint + .post + .in(("binary" / "test")) + .in(streamBody(sttp.capabilities.pekko.PekkoStreams)(Schema.binary[Array[Byte]], CodecFormat.OctetStream())) + .out(jsonBody[String].description("successful operation")) + lazy val putAdtTest = endpoint .put @@ -183,6 +198,6 @@ object TapirGeneratedEndpoints { } - lazy val generatedEndpoints = List(putAdtTest, postAdtTest, postInlineEnumTest) + lazy val generatedEndpoints = List(getBinaryTest, postBinaryTest, putAdtTest, postAdtTest, postInlineEnumTest) } diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/build.sbt b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/build.sbt index 5d3fceb614..bc1782b177 100644 --- a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/build.sbt +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/build.sbt @@ -2,12 +2,14 @@ lazy val root = (project in file(".")) .enablePlugins(OpenapiCodegenPlugin) .settings( scalaVersion := "2.13.14", - version := "0.1" + version := "0.1", + openapiStreamingImplementation := "pekko" ) libraryDependencies ++= Seq( "com.softwaremill.sttp.tapir" %% "tapir-json-circe" % "1.10.0", "com.softwaremill.sttp.tapir" %% "tapir-openapi-docs" % "1.10.0", + "com.softwaremill.sttp.tapir" %% "tapir-pekko-http-server" % "1.10.0", "com.softwaremill.sttp.apispec" %% "openapi-circe-yaml" % "0.8.0", "io.circe" %% "circe-generic" % "0.14.9", "com.beachape" %% "enumeratum" % "1.7.4", diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/src/test/scala/BinaryEndpoints.scala b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/src/test/scala/BinaryEndpoints.scala new file mode 100644 index 0000000000..f79c31d486 --- /dev/null +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/src/test/scala/BinaryEndpoints.scala @@ -0,0 +1,81 @@ +import org.apache.pekko.actor.ActorSystem +import org.apache.pekko.stream.Materializer +import org.apache.pekko.stream.scaladsl.Source +import org.apache.pekko.util.ByteString +import org.scalatest.freespec.AnyFreeSpec +import org.scalatest.matchers.should.Matchers +import sttp.capabilities.pekko.PekkoStreams +import sttp.client3.testing.SttpBackendStub +import sttp.client3.{Response, UriContext, asStringAlways} +import sttp.monad.FutureMonad +import sttp.tapir.generated.TapirGeneratedEndpoints +import sttp.tapir.server.stub.TapirStubInterpreter + +import scala.collection.mutable +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration.DurationInt +import scala.concurrent.{Await, Future} +import scala.util.Random + +class BinaryEndpoints extends AnyFreeSpec with Matchers { + private implicit val system: ActorSystem = ActorSystem() + private implicit val materializer: Materializer = Materializer.matFromSystem(system) + "binary endpoints work" in { + val respQueue = mutable.Queue.empty[String] + def iterator: Iterator[ByteString] = new Iterator[ByteString] { + private var linesToGo: Int = 100 + def hasNext: Boolean = linesToGo > 0 + def next(): ByteString = { + val nxt = Random.alphanumeric.take(40).mkString + linesToGo -= 1 + ByteString.fromString(nxt, "utf-8") + } + } + val route1 = + TapirGeneratedEndpoints.postBinaryTest.serverLogicSuccess[Future]({ source: Source[ByteString, Any] => + source + .map(bs => respQueue.append(bs.utf8String.reverse)) + .run() + .map(_ => "ok") + }) + val route2 = + TapirGeneratedEndpoints.getBinaryTest.serverLogicSuccess[Future]({ _ => + Future.successful(org.apache.pekko.stream.scaladsl.Source[ByteString]({ + respQueue.map(ByteString.fromString).toSeq + })) + }) + + val stub1 = TapirStubInterpreter(SttpBackendStub[Future, PekkoStreams](new FutureMonad())) + .whenServerEndpoint(route1) + .thenRunLogic() + .backend() + val stub2 = TapirStubInterpreter(SttpBackendStub[Future, PekkoStreams](new FutureMonad())) + .whenServerEndpoint(route2) + .thenRunLogic() + .backend() + + def genSomeLines: Source[ByteString, Any] = Source.fromIterator(() => iterator) + + def doPost = sttp.client3.basicRequest + .post(uri"http://test.com/binary/test") + .response(asStringAlways) + .streamBody(PekkoStreams)(genSomeLines) + .send(stub1) + .map { resp => + resp.code.code === 200 + resp.body === Right("ok") + } + + def doGet: Future[Response[Source[ByteString, Any]]] = sttp.client3.basicRequest + .get(uri"http://test.com/binary/test") + .response(sttp.client3.asStreamUnsafe(PekkoStreams).map { case Right(s) => s }) + .send[Future, PekkoStreams](stub2) + + Await.result(doPost, 5.seconds) + respQueue.size shouldEqual 100 + val orig = respQueue.toSeq + val b = Await.result(doGet.flatMap(_.body.runFold(Seq.empty[String])((l, a) => l :+ a.utf8String)), 5.seconds) + b.size shouldEqual 100 + b shouldEqual orig + } +} diff --git a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/swagger.yaml b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/swagger.yaml index a7f8ac04ac..c380f44bd7 100644 --- a/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/swagger.yaml +++ b/openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/swagger.yaml @@ -7,6 +7,34 @@ info: title: OneOf Json test for scala 2 tags: [ ] paths: + '/binary/test': + get: + responses: + "200": + description: "Response CSV body" + content: + application/octet-stream: + schema: + description: "csv file" + type: string + format: binary + post: + responses: + "200": + description: successful operation + content: + application/json: + schema: + type: string + requestBody: + required: true + description: Upload a csv + content: + application/octet-stream: + schema: + description: "csv file" + type: string + format: binary '/adt/test': post: responses: