diff --git a/doc/generator/sbt-openapi-codegen.md b/doc/generator/sbt-openapi-codegen.md index 42ba97c918..c637ee0ed5 100644 --- a/doc/generator/sbt-openapi-codegen.md +++ b/doc/generator/sbt-openapi-codegen.md @@ -35,13 +35,14 @@ defined case-classes and endpoint definitions. The generator currently supports these settings, you can override them in the `build.sbt`; ```eval_rst -=================== ==================================== =========================================== -setting default value description -=================== ==================================== =========================================== -openapiSwaggerFile baseDirectory.value / "swagger.yaml" The swagger file with the api definitions. -openapiPackage sttp.tapir.generated The name for the generated package. -openapiObject TapirGeneratedEndpoints The name for the generated object. -=================== ==================================== =========================================== +=============================== ==================================== ===================================================================== +setting default value description +=============================== ==================================== ===================================================================== +openapiSwaggerFile baseDirectory.value / "swagger.yaml" The swagger file with the api definitions. +openapiPackage sttp.tapir.generated The name for the generated package. +openapiObject TapirGeneratedEndpoints The name for the generated object. +openapiUseHeadTagForObjectName false If true, put endpoints in separate files based on first declared tag. +=============================== ==================================== ===================================================================== ``` The general usage is; @@ -54,6 +55,33 @@ import sttp.tapir.docs.openapi._ val docs = TapirGeneratedEndpoints.generatedEndpoints.toOpenAPI("My Bookshop", "1.0") ``` +### Output files + +To expand on the `openapiUseHeadTagForObjectName` setting a little more, suppose we have the following endpoints: +```yaml +paths: + /foo: + get: + tags: + - Baz + - Foo + put: + tags: [] + /bar: + get: + tags: + - Baz + - Bar +``` +In this case 'head' tag for `GET /foo` and `GET /bar` would be 'Baz', and `PUT /foo` has no tags (and thus no 'head' tag). + +If `openapiUseHeadTagForObjectName = false` (assuming default settings for the other flags) then all endpoint definitions +will be output to the `TapirGeneratedEndpoints.scala` file, which will contain a single `object TapirGeneratedEndpoints`. + +If `openapiUseHeadTagForObjectName = true`, then the `GET /foo` and `GET /bar` endpoints would be output to a +`Baz.scala` file, containing a single `object Baz` with those endpoint definitions; the `PUT /foo` endpoint, by dint of +having no tags, would be output to the `TapirGeneratedEndpoints` file, along with any schema and parameter definitions. + ### Limitations Currently, the generated code depends on `"io.circe" %% "circe-generic"`. In the future probably we will make the encoder/decoder json lib configurable (PRs welcome). diff --git a/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala b/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala index df9894c8d0..a9c9f38fa9 100644 --- a/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala +++ b/openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala @@ -40,6 +40,12 @@ object GenScala { ) .orNone + private val targetScala3Opt: Opts[Boolean] = + Opts.flag("scala3", "Whether to generate Scala 3 code", "3").orFalse + + private val headTagForNamesOpt: Opts[Boolean] = + Opts.flag("headTagForNames", "Whether to group generated endpoints by first declared tag", "t").orFalse + private val destDirOpt: Opts[File] = Opts .option[String]("destdir", "Destination directory", "d") @@ -53,22 +59,25 @@ object GenScala { } val cmd: Command[IO[ExitCode]] = Command("genscala", "Generate Scala classes", helpFlag = true) { - (fileOpt, packageNameOpt, destDirOpt, objectNameOpt).mapN { case (file, packageName, destDir, maybeObjectName) => - val objectName = maybeObjectName.getOrElse(DefaultObjectName) - - def generateCode(doc: OpenapiDocument): IO[Unit] = for { - content <- IO.pure(BasicGenerator.generateObjects(doc, packageName, objectName, false)) - destFile <- writeGeneratedFile(destDir, objectName, content) - _ <- IO.println(s"Generated endpoints written to: $destFile") - } yield () - - for { - parsed <- readFile(file).map(YamlParser.parseFile) - exitCode <- parsed match { - case Left(err) => IO.println(s"Invalid YAML file: ${err.getMessage}").as(ExitCode.Error) - case Right(doc) => generateCode(doc).as(ExitCode.Success) - } - } yield exitCode + (fileOpt, packageNameOpt, destDirOpt, objectNameOpt, targetScala3Opt, headTagForNamesOpt).mapN { + case (file, packageName, destDir, maybeObjectName, targetScala3, headTagForNames) => + val objectName = maybeObjectName.getOrElse(DefaultObjectName) + + def generateCode(doc: OpenapiDocument): IO[Unit] = for { + contents <- IO.pure( + BasicGenerator.generateObjects(doc, packageName, objectName, targetScala3, headTagForNames) + ) + destFiles <- contents.toVector.traverse{ case (fileName, content) => writeGeneratedFile(destDir, fileName, content) } + _ <- IO.println(s"Generated endpoints written to: ${destFiles.mkString(", ")}") + } yield () + + for { + parsed <- readFile(file).map(YamlParser.parseFile) + exitCode <- parsed match { + case Left(err) => IO.println(s"Invalid YAML file: ${err.getMessage}").as(ExitCode.Error) + case Right(doc) => generateCode(doc).as(ExitCode.Success) + } + } yield exitCode } } diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala index e847ed9187..89fffd2664 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala @@ -22,11 +22,35 @@ object BasicGenerator { val classGenerator = new ClassDefinitionGenerator() val endpointGenerator = new EndpointGenerator() - def generateObjects(doc: OpenapiDocument, packagePath: String, objName: String, targetScala3: Boolean): String = { + def generateObjects( + doc: OpenapiDocument, + packagePath: String, + objName: String, + targetScala3: Boolean, + useHeadTagForObjectNames: Boolean + ): Map[String, String] = { val enumImport = if (!targetScala3 && doc.components.toSeq.flatMap(_.schemas).exists(_._2.isInstanceOf[OpenapiSchemaEnum])) "\n import enumeratum._" else "" - s"""| + + val endpointsByTag = endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames) + val taggedObjs = endpointsByTag.collect { + case (Some(headTag), body) if body.nonEmpty => + val taggedObj = + s"""package $packagePath + | + |import $objName._ + | + |object $headTag { + | + |${indent(2)(imports)} + | + |${indent(2)(body)} + | + |}""".stripMargin + headTag -> taggedObj + } + val mainObj = s"""| |package $packagePath | |object $objName { @@ -35,10 +59,11 @@ object BasicGenerator { | |${indent(2)(classGenerator.classDefs(doc, targetScala3).getOrElse(""))} | - |${indent(2)(endpointGenerator.endpointDefs(doc))} + |${indent(2)(endpointsByTag.getOrElse(None, ""))} | |} |""".stripMargin + taggedObjs + (objName -> mainObj) } private[codegen] def imports: String = diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala index 23a571b53d..4b9468ed6d 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala @@ -19,49 +19,60 @@ class EndpointGenerator { private[codegen] def allEndpoints: String = "generatedEndpoints" - def endpointDefs(doc: OpenapiDocument): String = { + def endpointDefs(doc: OpenapiDocument, useHeadTagForObjectNames: Boolean): Map[Option[String], String] = { val components = Option(doc.components).flatten - val ge = doc.paths.flatMap(generatedEndpoints(components)) - val definitions = ge - .map { case (name, definition) => - s"""|lazy val $name = + val geMap = + doc.paths.flatMap(generatedEndpoints(components, useHeadTagForObjectNames)).groupBy(_._1).mapValues(_.map(_._2).reduce(_ ++ _)) + geMap.mapValues { ge => + val definitions = ge + .map { case (name, definition) => + s"""|lazy val $name = |${indent(2)(definition)} |""".stripMargin - } - .mkString("\n") - val allEP = s"lazy val $allEndpoints = List(${ge.map(_._1).mkString(", ")})" - - s"""|$definitions - | - |$allEP - |""".stripMargin + } + .mkString("\n") + val allEP = s"lazy val $allEndpoints = List(${ge.map(_._1).mkString(", ")})" + + s"""|$definitions + | + |$allEP + |""".stripMargin + }.toMap } - private[codegen] def generatedEndpoints(components: Option[OpenapiComponent])(p: OpenapiPath): Seq[(String, String)] = { + private[codegen] def generatedEndpoints(components: Option[OpenapiComponent], useHeadTagForObjectNames: Boolean)( + p: OpenapiPath + ): Seq[(Option[String], Seq[(String, String)])] = { val parameters = components.map(_.parameters).getOrElse(Map.empty) val securitySchemes = components.map(_.securitySchemes).getOrElse(Map.empty) - p.methods.map(_.withResolvedParentParameters(parameters, p.parameters)).map { m => - implicit val location: Location = Location(p.url, m.methodType) - val definition = - s"""|endpoint - | .${m.methodType} - | ${urlMapper(p.url, m.resolvedParameters)} - |${indent(2)(security(securitySchemes, m.security))} - |${indent(2)(ins(m.resolvedParameters, m.requestBody))} - |${indent(2)(outs(m.responses))} - |${indent(2)(tags(m.tags))} - |""".stripMargin - - val name = m.operationId - .getOrElse(m.methodType + p.url.capitalize) - .split("[^0-9a-zA-Z$_]") - .filter(_.nonEmpty) - .zipWithIndex - .map { case (part, 0) => part; case (part, _) => part.capitalize } - .mkString - (name, definition) - } + p.methods + .map(_.withResolvedParentParameters(parameters, p.parameters)) + .map { m => + implicit val location: Location = Location(p.url, m.methodType) + val definition = + s"""|endpoint + | .${m.methodType} + | ${urlMapper(p.url, m.resolvedParameters)} + |${indent(2)(security(securitySchemes, m.security))} + |${indent(2)(ins(m.resolvedParameters, m.requestBody))} + |${indent(2)(outs(m.responses))} + |${indent(2)(tags(m.tags))} + |""".stripMargin + + val name = m.operationId + .getOrElse(m.methodType + p.url.capitalize) + .split("[^0-9a-zA-Z$_]") + .filter(_.nonEmpty) + .zipWithIndex + .map { case (part, 0) => part; case (part, _) => part.capitalize } + .mkString + val maybeTargetFileName = if (useHeadTagForObjectNames) m.tags.flatMap(_.headOption) else None + (maybeTargetFileName, (name, definition)) + } + .groupBy(_._1) + .toSeq + .map { case (maybeTargetFileName, defns) => maybeTargetFileName -> defns.map(_._2) } } private def urlMapper(url: String, parameters: Seq[OpenapiParameter])(implicit location: Location): String = { diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala index d41bde3f55..0d054b5cf1 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala @@ -9,8 +9,30 @@ class BasicGeneratorSpec extends CompileCheckTestBase { TestHelpers.myBookshopDoc, "sttp.tapir.generated", "TapirGeneratedEndpoints", - targetScala3 = false - ) shouldCompile () + targetScala3 = false, + useHeadTagForObjectNames = false + )("TapirGeneratedEndpoints") shouldCompile () + } + + it should "split outputs by tag if useHeadTagForObjectNames = true" in { + val generated = BasicGenerator.generateObjects( + TestHelpers.myBookshopDoc, + "sttp.tapir.generated", + "TapirGeneratedEndpoints", + targetScala3 = false, + useHeadTagForObjectNames = true + ) + val schemas = generated("TapirGeneratedEndpoints") + val endpoints = generated("Bookshop") + // schema file on its own should compile + schemas shouldCompile () + // schema file should contain no endpoint definitions + schemas.linesIterator.count(_.matches("""^\s*endpoint""")) shouldEqual 0 + // Bookshop file should contain all endpoint definitions + endpoints.linesIterator.count(_.matches("""^\s*endpoint""")) shouldEqual 3 + // endpoint file depends on schema file. For simplicity of testing, just strip the package declaration from the + // endpoint file, and concat the two, before testing for compilation + (schemas + "\n" + (endpoints.linesIterator.filterNot(_ startsWith "package").mkString("\n"))) shouldCompile () } } diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala index 27c37eea95..c3a6201f12 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala @@ -344,7 +344,7 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase { val res: String = parserRes match { case Left(value) => throw new Exception(value) - case Right(doc) => new EndpointGenerator().endpointDefs(doc) + case Right(doc) => new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None) } val compileUnit = diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala index 150e63e54d..2240f0a0b7 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala @@ -56,7 +56,7 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { ), null ) - val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc) + val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None) generatedCode should include("val getTestAsdId =") generatedCode shouldCompile () } @@ -131,7 +131,7 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { ) ) BasicGenerator.imports ++ - new EndpointGenerator().endpointDefs(doc) shouldCompile () + new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None) shouldCompile () } it should "handle status codes" in { @@ -174,7 +174,7 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { ), null ) - val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc) + val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None) generatedCode should include( """.out(stringBody.description("Processing").and(statusCode(sttp.model.StatusCode(202))))""" ) // status code with body @@ -230,7 +230,13 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { ) ) ) - val generatedCode = BasicGenerator.generateObjects(doc, "sttp.tapir.generated", "TapirGeneratedEndpoints", targetScala3 = false) + val generatedCode = BasicGenerator.generateObjects( + doc, + "sttp.tapir.generated", + "TapirGeneratedEndpoints", + targetScala3 = false, + useHeadTagForObjectNames = false + )("TapirGeneratedEndpoints") generatedCode should include( """file: sttp.model.Part[java.io.File]""" ) diff --git a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala index fa8a6d696e..46894eea8e 100644 --- a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala +++ b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala @@ -6,6 +6,9 @@ trait OpenapiCodegenKeys { lazy val openapiSwaggerFile = settingKey[File]("The swagger file with the api definitions.") lazy val openapiPackage = settingKey[String]("The name for the generated package.") lazy val openapiObject = settingKey[String]("The name for the generated object.") + lazy val openapiUseHeadTagForObjectName = settingKey[Boolean]( + "If true, any tagged endpoints will be defined in an object with a name based on the first tag, instead of on the default generated object." + ) lazy val generateTapirDefinitions = taskKey[Unit]("The task that generates tapir definitions based on the input swagger file.") } diff --git a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala index 745fa8e56e..da30743f86 100644 --- a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala +++ b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala @@ -18,14 +18,15 @@ object OpenapiCodegenPlugin extends AutoPlugin { def openapiCodegenScopedSettings(conf: Configuration): Seq[Setting[_]] = inConfig(conf)( Seq( generateTapirDefinitions := codegen.value, - sourceGenerators += (codegen.taskValue).map(_.map(_.toPath.toFile)) + sourceGenerators += (codegen.taskValue).map(_.flatMap(_.map(_.toPath.toFile))) ) ) def openapiCodegenDefaultSettings: Seq[Setting[_]] = Seq( openapiSwaggerFile := baseDirectory.value / "swagger.yaml", openapiPackage := "sttp.tapir.generated", - openapiObject := "TapirGeneratedEndpoints" + openapiObject := "TapirGeneratedEndpoints", + openapiUseHeadTagForObjectName := false ) private def codegen = Def.task { @@ -35,6 +36,7 @@ object OpenapiCodegenPlugin extends AutoPlugin { openapiSwaggerFile, openapiPackage, openapiObject, + openapiUseHeadTagForObjectName, sourceManaged, streams, scalaVersion @@ -43,11 +45,12 @@ object OpenapiCodegenPlugin extends AutoPlugin { swaggerFile: File, packageName: String, objectName: String, + useHeadTagForObjectName: Boolean, srcDir: File, taskStreams: TaskStreams, sv: String ) => - OpenapiCodegenTask(swaggerFile, packageName, objectName, srcDir, taskStreams.cacheDirectory, sv.startsWith("3")).file + OpenapiCodegenTask(swaggerFile, packageName, objectName, useHeadTagForObjectName, srcDir, taskStreams.cacheDirectory, sv.startsWith("3")).file }) map (Seq(_))).value } } diff --git a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala index 2817107974..736e3df81a 100644 --- a/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala +++ b/openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala @@ -9,39 +9,47 @@ case class OpenapiCodegenTask( inputYaml: File, packageName: String, objectName: String, + useHeadTagForObjectName: Boolean, dir: File, cacheDir: File, targetScala3: Boolean ) { - val tempFile = cacheDir / "sbt-openapi-codegen" / s"$objectName.scala" - val outFile = dir / "sbt-openapi-codegen" / s"$objectName.scala" + val tempDirectory = cacheDir / "sbt-openapi-codegen" + val outDirectory = dir / "sbt-openapi-codegen" - // 1. make the file under cache/sbt-tapircodegen. - // 2. compare its SHA1 against cache/sbtbuildinfo-inputs - def file: Task[File] = { - makeFile(tempFile) map { _ => - cachedCopyFile(hash(tempFile)) - outFile + // 1. make the files under cache/sbt-tapircodegen. + // 2. compare their SHA1 against cache/sbtbuildinfo-inputs + def file: Task[Seq[File]] = { + makeFiles(tempDirectory) map { files => + files.map { tempFile => + val outFile = outDirectory / tempFile.getName + cachedCopyFile(tempFile, outFile)(hash(tempFile)) + outFile + } } } - val cachedCopyFile = + def cachedCopyFile(tempFile: File, outFile: File) = inputChanged(cacheDir / "sbt-openapi-codegen-inputs") { (inChanged, _: HashFileInfo) => if (inChanged || !outFile.exists) { IO.copyFile(tempFile, outFile, preserveLastModified = true) - } // if + } } - def makeFile(file: File): Task[File] = { + def makeFiles(directory: File): Task[Seq[File]] = { task { val parsed = YamlParser .parseFile(IO.readLines(inputYaml).mkString("\n")) .left .map(d => new RuntimeException(_root_.io.circe.Error.showError.show(d))) - val lines = BasicGenerator.generateObjects(parsed.toTry.get, packageName, objectName, targetScala3).linesIterator.toSeq - IO.writeLines(file, lines, IO.utf8) - file + BasicGenerator.generateObjects(parsed.toTry.get, packageName, objectName, targetScala3, useHeadTagForObjectName).map { + case (objectName, fileBody) => + val file = directory / s"$objectName.scala" + val lines = fileBody.linesIterator.toSeq + IO.writeLines(file, lines, IO.utf8) + file + }.toSeq } } }