Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Codegen: Add new useHeadTagForObjectNames flag to permit splitting generated endpoint objects by tag #3594

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions doc/generator/sbt-openapi-codegen.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ defined case-classes and endpoint definitions.
The generator currently supports these settings, you can override them in the `build.sbt`;

```eval_rst
=================== ==================================== ===========================================
setting default value description
=================== ==================================== ===========================================
openapiSwaggerFile baseDirectory.value / "swagger.yaml" The swagger file with the api definitions.
openapiPackage sttp.tapir.generated The name for the generated package.
openapiObject TapirGeneratedEndpoints The name for the generated object.
=================== ==================================== ===========================================
=============================== ==================================== =====================================================================
setting default value description
=============================== ==================================== =====================================================================
openapiSwaggerFile baseDirectory.value / "swagger.yaml" The swagger file with the api definitions.
openapiPackage sttp.tapir.generated The name for the generated package.
openapiObject TapirGeneratedEndpoints The name for the generated object.
openapiUseHeadTagForObjectName false If true, put endpoints in separate files based on first declared tag.
=============================== ==================================== =====================================================================
```

The general usage is;
Expand All @@ -54,6 +55,33 @@ import sttp.tapir.docs.openapi._
val docs = TapirGeneratedEndpoints.generatedEndpoints.toOpenAPI("My Bookshop", "1.0")
```

### Output files

To expand on the `openapiUseHeadTagForObjectName` setting a little more, suppose we have the following endpoints:
```yaml
paths:
/foo:
get:
tags:
- Baz
- Foo
put:
tags: []
/bar:
get:
tags:
- Baz
- Bar
```
In this case 'head' tag for `GET /foo` and `GET /bar` would be 'Baz', and `PUT /foo` has no tags (and thus no 'head' tag).

If `openapiUseHeadTagForObjectName = false` (assuming default settings for the other flags) then all endpoint definitions
will be output to the `TapirGeneratedEndpoints.scala` file, which will contain a single `object TapirGeneratedEndpoints`.

If `openapiUseHeadTagForObjectName = true`, then the `GET /foo` and `GET /bar` endpoints would be output to a
`Baz.scala` file, containing a single `object Baz` with those endpoint definitions; the `PUT /foo` endpoint, by dint of
having no tags, would be output to the `TapirGeneratedEndpoints` file, along with any schema and parameter definitions.

### Limitations

Currently, the generated code depends on `"io.circe" %% "circe-generic"`. In the future probably we will make the encoder/decoder json lib configurable (PRs welcome).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ object GenScala {
)
.orNone

private val targetScala3Opt: Opts[Boolean] =
Opts.flag("scala3", "Whether to generate Scala 3 code", "3").orFalse

private val headTagForNamesOpt: Opts[Boolean] =
Opts.flag("headTagForNames", "Whether to group generated endpoints by first declared tag", "t").orFalse

private val destDirOpt: Opts[File] =
Opts
.option[String]("destdir", "Destination directory", "d")
Expand All @@ -53,22 +59,25 @@ object GenScala {
}

val cmd: Command[IO[ExitCode]] = Command("genscala", "Generate Scala classes", helpFlag = true) {
(fileOpt, packageNameOpt, destDirOpt, objectNameOpt).mapN { case (file, packageName, destDir, maybeObjectName) =>
val objectName = maybeObjectName.getOrElse(DefaultObjectName)

def generateCode(doc: OpenapiDocument): IO[Unit] = for {
content <- IO.pure(BasicGenerator.generateObjects(doc, packageName, objectName, false))
destFile <- writeGeneratedFile(destDir, objectName, content)
_ <- IO.println(s"Generated endpoints written to: $destFile")
} yield ()

for {
parsed <- readFile(file).map(YamlParser.parseFile)
exitCode <- parsed match {
case Left(err) => IO.println(s"Invalid YAML file: ${err.getMessage}").as(ExitCode.Error)
case Right(doc) => generateCode(doc).as(ExitCode.Success)
}
} yield exitCode
(fileOpt, packageNameOpt, destDirOpt, objectNameOpt, targetScala3Opt, headTagForNamesOpt).mapN {
case (file, packageName, destDir, maybeObjectName, targetScala3, headTagForNames) =>
val objectName = maybeObjectName.getOrElse(DefaultObjectName)

def generateCode(doc: OpenapiDocument): IO[Unit] = for {
contents <- IO.pure(
BasicGenerator.generateObjects(doc, packageName, objectName, targetScala3, headTagForNames)
)
destFiles <- contents.toVector.traverse{ case (fileName, content) => writeGeneratedFile(destDir, fileName, content) }
_ <- IO.println(s"Generated endpoints written to: ${destFiles.mkString(", ")}")
} yield ()

for {
parsed <- readFile(file).map(YamlParser.parseFile)
exitCode <- parsed match {
case Left(err) => IO.println(s"Invalid YAML file: ${err.getMessage}").as(ExitCode.Error)
case Right(doc) => generateCode(doc).as(ExitCode.Success)
}
} yield exitCode
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,35 @@ object BasicGenerator {
val classGenerator = new ClassDefinitionGenerator()
val endpointGenerator = new EndpointGenerator()

def generateObjects(doc: OpenapiDocument, packagePath: String, objName: String, targetScala3: Boolean): String = {
def generateObjects(
doc: OpenapiDocument,
packagePath: String,
objName: String,
targetScala3: Boolean,
useHeadTagForObjectNames: Boolean
): Map[String, String] = {
val enumImport =
if (!targetScala3 && doc.components.toSeq.flatMap(_.schemas).exists(_._2.isInstanceOf[OpenapiSchemaEnum])) "\n import enumeratum._"
else ""
s"""|

val endpointsByTag = endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames)
val taggedObjs = endpointsByTag.collect {
case (Some(headTag), body) if body.nonEmpty =>
val taggedObj =
s"""package $packagePath
|
|import $objName._
|
|object $headTag {
|
|${indent(2)(imports)}
|
|${indent(2)(body)}
|
|}""".stripMargin
headTag -> taggedObj
}
val mainObj = s"""|
|package $packagePath
|
|object $objName {
Expand All @@ -35,10 +59,11 @@ object BasicGenerator {
|
|${indent(2)(classGenerator.classDefs(doc, targetScala3).getOrElse(""))}
|
|${indent(2)(endpointGenerator.endpointDefs(doc))}
|${indent(2)(endpointsByTag.getOrElse(None, ""))}
|
|}
|""".stripMargin
taggedObjs + (objName -> mainObj)
}

private[codegen] def imports: String =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,49 +19,60 @@ class EndpointGenerator {

private[codegen] def allEndpoints: String = "generatedEndpoints"

def endpointDefs(doc: OpenapiDocument): String = {
def endpointDefs(doc: OpenapiDocument, useHeadTagForObjectNames: Boolean): Map[Option[String], String] = {
val components = Option(doc.components).flatten
val ge = doc.paths.flatMap(generatedEndpoints(components))
val definitions = ge
.map { case (name, definition) =>
s"""|lazy val $name =
val geMap =
doc.paths.flatMap(generatedEndpoints(components, useHeadTagForObjectNames)).groupBy(_._1).mapValues(_.map(_._2).reduce(_ ++ _))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would've been so much cleaner with groupMapReduce(_._1)(_._2)(_ ++ _) 😂 but, alas, scala 2.12...

geMap.mapValues { ge =>
val definitions = ge
.map { case (name, definition) =>
s"""|lazy val $name =
|${indent(2)(definition)}
|""".stripMargin
}
.mkString("\n")
val allEP = s"lazy val $allEndpoints = List(${ge.map(_._1).mkString(", ")})"

s"""|$definitions
|
|$allEP
|""".stripMargin
}
.mkString("\n")
val allEP = s"lazy val $allEndpoints = List(${ge.map(_._1).mkString(", ")})"

s"""|$definitions
|
|$allEP
|""".stripMargin
}.toMap
}

private[codegen] def generatedEndpoints(components: Option[OpenapiComponent])(p: OpenapiPath): Seq[(String, String)] = {
private[codegen] def generatedEndpoints(components: Option[OpenapiComponent], useHeadTagForObjectNames: Boolean)(
p: OpenapiPath
): Seq[(Option[String], Seq[(String, String)])] = {
val parameters = components.map(_.parameters).getOrElse(Map.empty)
val securitySchemes = components.map(_.securitySchemes).getOrElse(Map.empty)

p.methods.map(_.withResolvedParentParameters(parameters, p.parameters)).map { m =>
implicit val location: Location = Location(p.url, m.methodType)
val definition =
s"""|endpoint
| .${m.methodType}
| ${urlMapper(p.url, m.resolvedParameters)}
|${indent(2)(security(securitySchemes, m.security))}
|${indent(2)(ins(m.resolvedParameters, m.requestBody))}
|${indent(2)(outs(m.responses))}
|${indent(2)(tags(m.tags))}
|""".stripMargin

val name = m.operationId
.getOrElse(m.methodType + p.url.capitalize)
.split("[^0-9a-zA-Z$_]")
.filter(_.nonEmpty)
.zipWithIndex
.map { case (part, 0) => part; case (part, _) => part.capitalize }
.mkString
(name, definition)
}
p.methods
.map(_.withResolvedParentParameters(parameters, p.parameters))
.map { m =>
implicit val location: Location = Location(p.url, m.methodType)
val definition =
s"""|endpoint
| .${m.methodType}
| ${urlMapper(p.url, m.resolvedParameters)}
|${indent(2)(security(securitySchemes, m.security))}
|${indent(2)(ins(m.resolvedParameters, m.requestBody))}
|${indent(2)(outs(m.responses))}
|${indent(2)(tags(m.tags))}
|""".stripMargin

val name = m.operationId
.getOrElse(m.methodType + p.url.capitalize)
.split("[^0-9a-zA-Z$_]")
.filter(_.nonEmpty)
.zipWithIndex
.map { case (part, 0) => part; case (part, _) => part.capitalize }
.mkString
val maybeTargetFileName = if (useHeadTagForObjectNames) m.tags.flatMap(_.headOption) else None
(maybeTargetFileName, (name, definition))
}
.groupBy(_._1)
.toSeq
.map { case (maybeTargetFileName, defns) => maybeTargetFileName -> defns.map(_._2) }
}

private def urlMapper(url: String, parameters: Seq[OpenapiParameter])(implicit location: Location): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,30 @@ class BasicGeneratorSpec extends CompileCheckTestBase {
TestHelpers.myBookshopDoc,
"sttp.tapir.generated",
"TapirGeneratedEndpoints",
targetScala3 = false
) shouldCompile ()
targetScala3 = false,
useHeadTagForObjectNames = false
)("TapirGeneratedEndpoints") shouldCompile ()
}

it should "split outputs by tag if useHeadTagForObjectNames = true" in {
val generated = BasicGenerator.generateObjects(
TestHelpers.myBookshopDoc,
"sttp.tapir.generated",
"TapirGeneratedEndpoints",
targetScala3 = false,
useHeadTagForObjectNames = true
)
val schemas = generated("TapirGeneratedEndpoints")
val endpoints = generated("Bookshop")
// schema file on its own should compile
schemas shouldCompile ()
// schema file should contain no endpoint definitions
schemas.linesIterator.count(_.matches("""^\s*endpoint""")) shouldEqual 0
// Bookshop file should contain all endpoint definitions
endpoints.linesIterator.count(_.matches("""^\s*endpoint""")) shouldEqual 3
// endpoint file depends on schema file. For simplicity of testing, just strip the package declaration from the
// endpoint file, and concat the two, before testing for compilation
(schemas + "\n" + (endpoints.linesIterator.filterNot(_ startsWith "package").mkString("\n"))) shouldCompile ()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {

val res: String = parserRes match {
case Left(value) => throw new Exception(value)
case Right(doc) => new EndpointGenerator().endpointDefs(doc)
case Right(doc) => new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None)
}

val compileUnit =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
),
null
)
val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc)
val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None)
generatedCode should include("val getTestAsdId =")
generatedCode shouldCompile ()
}
Expand Down Expand Up @@ -131,7 +131,7 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
)
)
BasicGenerator.imports ++
new EndpointGenerator().endpointDefs(doc) shouldCompile ()
new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None) shouldCompile ()
}

it should "handle status codes" in {
Expand Down Expand Up @@ -174,7 +174,7 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
),
null
)
val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc)
val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None)
generatedCode should include(
""".out(stringBody.description("Processing").and(statusCode(sttp.model.StatusCode(202))))"""
) // status code with body
Expand Down Expand Up @@ -230,7 +230,13 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
)
)
)
val generatedCode = BasicGenerator.generateObjects(doc, "sttp.tapir.generated", "TapirGeneratedEndpoints", targetScala3 = false)
val generatedCode = BasicGenerator.generateObjects(
doc,
"sttp.tapir.generated",
"TapirGeneratedEndpoints",
targetScala3 = false,
useHeadTagForObjectNames = false
)("TapirGeneratedEndpoints")
generatedCode should include(
"""file: sttp.model.Part[java.io.File]"""
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ trait OpenapiCodegenKeys {
lazy val openapiSwaggerFile = settingKey[File]("The swagger file with the api definitions.")
lazy val openapiPackage = settingKey[String]("The name for the generated package.")
lazy val openapiObject = settingKey[String]("The name for the generated object.")
lazy val openapiUseHeadTagForObjectName = settingKey[Boolean](
"If true, any tagged endpoints will be defined in an object with a name based on the first tag, instead of on the default generated object."
)

lazy val generateTapirDefinitions = taskKey[Unit]("The task that generates tapir definitions based on the input swagger file.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ object OpenapiCodegenPlugin extends AutoPlugin {
def openapiCodegenScopedSettings(conf: Configuration): Seq[Setting[_]] = inConfig(conf)(
Seq(
generateTapirDefinitions := codegen.value,
sourceGenerators += (codegen.taskValue).map(_.map(_.toPath.toFile))
sourceGenerators += (codegen.taskValue).map(_.flatMap(_.map(_.toPath.toFile)))
)
)

def openapiCodegenDefaultSettings: Seq[Setting[_]] = Seq(
openapiSwaggerFile := baseDirectory.value / "swagger.yaml",
openapiPackage := "sttp.tapir.generated",
openapiObject := "TapirGeneratedEndpoints"
openapiObject := "TapirGeneratedEndpoints",
openapiUseHeadTagForObjectName := false
)

private def codegen = Def.task {
Expand All @@ -35,6 +36,7 @@ object OpenapiCodegenPlugin extends AutoPlugin {
openapiSwaggerFile,
openapiPackage,
openapiObject,
openapiUseHeadTagForObjectName,
sourceManaged,
streams,
scalaVersion
Expand All @@ -43,11 +45,12 @@ object OpenapiCodegenPlugin extends AutoPlugin {
swaggerFile: File,
packageName: String,
objectName: String,
useHeadTagForObjectName: Boolean,
srcDir: File,
taskStreams: TaskStreams,
sv: String
) =>
OpenapiCodegenTask(swaggerFile, packageName, objectName, srcDir, taskStreams.cacheDirectory, sv.startsWith("3")).file
OpenapiCodegenTask(swaggerFile, packageName, objectName, useHeadTagForObjectName, srcDir, taskStreams.cacheDirectory, sv.startsWith("3")).file
}) map (Seq(_))).value
}
}
Loading
Loading