Skip to content

Commit

Permalink
Codegen: Add new useHeadTagForObjectNames flag to permit splitting …
Browse files Browse the repository at this point in the history
…generated endpoint objects by tag (#3594)
  • Loading branch information
hughsimpson authored Mar 14, 2024
1 parent a0721ec commit e611d21
Show file tree
Hide file tree
Showing 10 changed files with 200 additions and 85 deletions.
42 changes: 35 additions & 7 deletions doc/generator/sbt-openapi-codegen.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ defined case-classes and endpoint definitions.
The generator currently supports these settings, you can override them in the `build.sbt`;

```eval_rst
=================== ==================================== ===========================================
setting default value description
=================== ==================================== ===========================================
openapiSwaggerFile baseDirectory.value / "swagger.yaml" The swagger file with the api definitions.
openapiPackage sttp.tapir.generated The name for the generated package.
openapiObject TapirGeneratedEndpoints The name for the generated object.
=================== ==================================== ===========================================
=============================== ==================================== =====================================================================
setting default value description
=============================== ==================================== =====================================================================
openapiSwaggerFile baseDirectory.value / "swagger.yaml" The swagger file with the api definitions.
openapiPackage sttp.tapir.generated The name for the generated package.
openapiObject TapirGeneratedEndpoints The name for the generated object.
openapiUseHeadTagForObjectName false If true, put endpoints in separate files based on first declared tag.
=============================== ==================================== =====================================================================
```

The general usage is;
Expand All @@ -54,6 +55,33 @@ import sttp.tapir.docs.openapi._
val docs = TapirGeneratedEndpoints.generatedEndpoints.toOpenAPI("My Bookshop", "1.0")
```

### Output files

To expand on the `openapiUseHeadTagForObjectName` setting a little more, suppose we have the following endpoints:
```yaml
paths:
/foo:
get:
tags:
- Baz
- Foo
put:
tags: []
/bar:
get:
tags:
- Baz
- Bar
```
In this case 'head' tag for `GET /foo` and `GET /bar` would be 'Baz', and `PUT /foo` has no tags (and thus no 'head' tag).

If `openapiUseHeadTagForObjectName = false` (assuming default settings for the other flags) then all endpoint definitions
will be output to the `TapirGeneratedEndpoints.scala` file, which will contain a single `object TapirGeneratedEndpoints`.

If `openapiUseHeadTagForObjectName = true`, then the `GET /foo` and `GET /bar` endpoints would be output to a
`Baz.scala` file, containing a single `object Baz` with those endpoint definitions; the `PUT /foo` endpoint, by dint of
having no tags, would be output to the `TapirGeneratedEndpoints` file, along with any schema and parameter definitions.

### Limitations

Currently, the generated code depends on `"io.circe" %% "circe-generic"`. In the future probably we will make the encoder/decoder json lib configurable (PRs welcome).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ object GenScala {
)
.orNone

private val targetScala3Opt: Opts[Boolean] =
Opts.flag("scala3", "Whether to generate Scala 3 code", "3").orFalse

private val headTagForNamesOpt: Opts[Boolean] =
Opts.flag("headTagForNames", "Whether to group generated endpoints by first declared tag", "t").orFalse

private val destDirOpt: Opts[File] =
Opts
.option[String]("destdir", "Destination directory", "d")
Expand All @@ -53,22 +59,25 @@ object GenScala {
}

val cmd: Command[IO[ExitCode]] = Command("genscala", "Generate Scala classes", helpFlag = true) {
(fileOpt, packageNameOpt, destDirOpt, objectNameOpt).mapN { case (file, packageName, destDir, maybeObjectName) =>
val objectName = maybeObjectName.getOrElse(DefaultObjectName)

def generateCode(doc: OpenapiDocument): IO[Unit] = for {
content <- IO.pure(BasicGenerator.generateObjects(doc, packageName, objectName, false))
destFile <- writeGeneratedFile(destDir, objectName, content)
_ <- IO.println(s"Generated endpoints written to: $destFile")
} yield ()

for {
parsed <- readFile(file).map(YamlParser.parseFile)
exitCode <- parsed match {
case Left(err) => IO.println(s"Invalid YAML file: ${err.getMessage}").as(ExitCode.Error)
case Right(doc) => generateCode(doc).as(ExitCode.Success)
}
} yield exitCode
(fileOpt, packageNameOpt, destDirOpt, objectNameOpt, targetScala3Opt, headTagForNamesOpt).mapN {
case (file, packageName, destDir, maybeObjectName, targetScala3, headTagForNames) =>
val objectName = maybeObjectName.getOrElse(DefaultObjectName)

def generateCode(doc: OpenapiDocument): IO[Unit] = for {
contents <- IO.pure(
BasicGenerator.generateObjects(doc, packageName, objectName, targetScala3, headTagForNames)
)
destFiles <- contents.toVector.traverse{ case (fileName, content) => writeGeneratedFile(destDir, fileName, content) }
_ <- IO.println(s"Generated endpoints written to: ${destFiles.mkString(", ")}")
} yield ()

for {
parsed <- readFile(file).map(YamlParser.parseFile)
exitCode <- parsed match {
case Left(err) => IO.println(s"Invalid YAML file: ${err.getMessage}").as(ExitCode.Error)
case Right(doc) => generateCode(doc).as(ExitCode.Success)
}
} yield exitCode
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,35 @@ object BasicGenerator {
val classGenerator = new ClassDefinitionGenerator()
val endpointGenerator = new EndpointGenerator()

def generateObjects(doc: OpenapiDocument, packagePath: String, objName: String, targetScala3: Boolean): String = {
def generateObjects(
doc: OpenapiDocument,
packagePath: String,
objName: String,
targetScala3: Boolean,
useHeadTagForObjectNames: Boolean
): Map[String, String] = {
val enumImport =
if (!targetScala3 && doc.components.toSeq.flatMap(_.schemas).exists(_._2.isInstanceOf[OpenapiSchemaEnum])) "\n import enumeratum._"
else ""
s"""|

val endpointsByTag = endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames)
val taggedObjs = endpointsByTag.collect {
case (Some(headTag), body) if body.nonEmpty =>
val taggedObj =
s"""package $packagePath
|
|import $objName._
|
|object $headTag {
|
|${indent(2)(imports)}
|
|${indent(2)(body)}
|
|}""".stripMargin
headTag -> taggedObj
}
val mainObj = s"""|
|package $packagePath
|
|object $objName {
Expand All @@ -35,10 +59,11 @@ object BasicGenerator {
|
|${indent(2)(classGenerator.classDefs(doc, targetScala3).getOrElse(""))}
|
|${indent(2)(endpointGenerator.endpointDefs(doc))}
|${indent(2)(endpointsByTag.getOrElse(None, ""))}
|
|}
|""".stripMargin
taggedObjs + (objName -> mainObj)
}

private[codegen] def imports: String =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,49 +19,60 @@ class EndpointGenerator {

private[codegen] def allEndpoints: String = "generatedEndpoints"

def endpointDefs(doc: OpenapiDocument): String = {
def endpointDefs(doc: OpenapiDocument, useHeadTagForObjectNames: Boolean): Map[Option[String], String] = {
val components = Option(doc.components).flatten
val ge = doc.paths.flatMap(generatedEndpoints(components))
val definitions = ge
.map { case (name, definition) =>
s"""|lazy val $name =
val geMap =
doc.paths.flatMap(generatedEndpoints(components, useHeadTagForObjectNames)).groupBy(_._1).mapValues(_.map(_._2).reduce(_ ++ _))
geMap.mapValues { ge =>
val definitions = ge
.map { case (name, definition) =>
s"""|lazy val $name =
|${indent(2)(definition)}
|""".stripMargin
}
.mkString("\n")
val allEP = s"lazy val $allEndpoints = List(${ge.map(_._1).mkString(", ")})"

s"""|$definitions
|
|$allEP
|""".stripMargin
}
.mkString("\n")
val allEP = s"lazy val $allEndpoints = List(${ge.map(_._1).mkString(", ")})"

s"""|$definitions
|
|$allEP
|""".stripMargin
}.toMap
}

private[codegen] def generatedEndpoints(components: Option[OpenapiComponent])(p: OpenapiPath): Seq[(String, String)] = {
private[codegen] def generatedEndpoints(components: Option[OpenapiComponent], useHeadTagForObjectNames: Boolean)(
p: OpenapiPath
): Seq[(Option[String], Seq[(String, String)])] = {
val parameters = components.map(_.parameters).getOrElse(Map.empty)
val securitySchemes = components.map(_.securitySchemes).getOrElse(Map.empty)

p.methods.map(_.withResolvedParentParameters(parameters, p.parameters)).map { m =>
implicit val location: Location = Location(p.url, m.methodType)
val definition =
s"""|endpoint
| .${m.methodType}
| ${urlMapper(p.url, m.resolvedParameters)}
|${indent(2)(security(securitySchemes, m.security))}
|${indent(2)(ins(m.resolvedParameters, m.requestBody))}
|${indent(2)(outs(m.responses))}
|${indent(2)(tags(m.tags))}
|""".stripMargin

val name = m.operationId
.getOrElse(m.methodType + p.url.capitalize)
.split("[^0-9a-zA-Z$_]")
.filter(_.nonEmpty)
.zipWithIndex
.map { case (part, 0) => part; case (part, _) => part.capitalize }
.mkString
(name, definition)
}
p.methods
.map(_.withResolvedParentParameters(parameters, p.parameters))
.map { m =>
implicit val location: Location = Location(p.url, m.methodType)
val definition =
s"""|endpoint
| .${m.methodType}
| ${urlMapper(p.url, m.resolvedParameters)}
|${indent(2)(security(securitySchemes, m.security))}
|${indent(2)(ins(m.resolvedParameters, m.requestBody))}
|${indent(2)(outs(m.responses))}
|${indent(2)(tags(m.tags))}
|""".stripMargin

val name = m.operationId
.getOrElse(m.methodType + p.url.capitalize)
.split("[^0-9a-zA-Z$_]")
.filter(_.nonEmpty)
.zipWithIndex
.map { case (part, 0) => part; case (part, _) => part.capitalize }
.mkString
val maybeTargetFileName = if (useHeadTagForObjectNames) m.tags.flatMap(_.headOption) else None
(maybeTargetFileName, (name, definition))
}
.groupBy(_._1)
.toSeq
.map { case (maybeTargetFileName, defns) => maybeTargetFileName -> defns.map(_._2) }
}

private def urlMapper(url: String, parameters: Seq[OpenapiParameter])(implicit location: Location): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,30 @@ class BasicGeneratorSpec extends CompileCheckTestBase {
TestHelpers.myBookshopDoc,
"sttp.tapir.generated",
"TapirGeneratedEndpoints",
targetScala3 = false
) shouldCompile ()
targetScala3 = false,
useHeadTagForObjectNames = false
)("TapirGeneratedEndpoints") shouldCompile ()
}

it should "split outputs by tag if useHeadTagForObjectNames = true" in {
val generated = BasicGenerator.generateObjects(
TestHelpers.myBookshopDoc,
"sttp.tapir.generated",
"TapirGeneratedEndpoints",
targetScala3 = false,
useHeadTagForObjectNames = true
)
val schemas = generated("TapirGeneratedEndpoints")
val endpoints = generated("Bookshop")
// schema file on its own should compile
schemas shouldCompile ()
// schema file should contain no endpoint definitions
schemas.linesIterator.count(_.matches("""^\s*endpoint""")) shouldEqual 0
// Bookshop file should contain all endpoint definitions
endpoints.linesIterator.count(_.matches("""^\s*endpoint""")) shouldEqual 3
// endpoint file depends on schema file. For simplicity of testing, just strip the package declaration from the
// endpoint file, and concat the two, before testing for compilation
(schemas + "\n" + (endpoints.linesIterator.filterNot(_ startsWith "package").mkString("\n"))) shouldCompile ()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {

val res: String = parserRes match {
case Left(value) => throw new Exception(value)
case Right(doc) => new EndpointGenerator().endpointDefs(doc)
case Right(doc) => new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None)
}

val compileUnit =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
),
null
)
val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc)
val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None)
generatedCode should include("val getTestAsdId =")
generatedCode shouldCompile ()
}
Expand Down Expand Up @@ -131,7 +131,7 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
)
)
BasicGenerator.imports ++
new EndpointGenerator().endpointDefs(doc) shouldCompile ()
new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None) shouldCompile ()
}

it should "handle status codes" in {
Expand Down Expand Up @@ -174,7 +174,7 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
),
null
)
val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc)
val generatedCode = BasicGenerator.imports ++ new EndpointGenerator().endpointDefs(doc, useHeadTagForObjectNames = false)(None)
generatedCode should include(
""".out(stringBody.description("Processing").and(statusCode(sttp.model.StatusCode(202))))"""
) // status code with body
Expand Down Expand Up @@ -230,7 +230,13 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
)
)
)
val generatedCode = BasicGenerator.generateObjects(doc, "sttp.tapir.generated", "TapirGeneratedEndpoints", targetScala3 = false)
val generatedCode = BasicGenerator.generateObjects(
doc,
"sttp.tapir.generated",
"TapirGeneratedEndpoints",
targetScala3 = false,
useHeadTagForObjectNames = false
)("TapirGeneratedEndpoints")
generatedCode should include(
"""file: sttp.model.Part[java.io.File]"""
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ trait OpenapiCodegenKeys {
lazy val openapiSwaggerFile = settingKey[File]("The swagger file with the api definitions.")
lazy val openapiPackage = settingKey[String]("The name for the generated package.")
lazy val openapiObject = settingKey[String]("The name for the generated object.")
lazy val openapiUseHeadTagForObjectName = settingKey[Boolean](
"If true, any tagged endpoints will be defined in an object with a name based on the first tag, instead of on the default generated object."
)

lazy val generateTapirDefinitions = taskKey[Unit]("The task that generates tapir definitions based on the input swagger file.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ object OpenapiCodegenPlugin extends AutoPlugin {
def openapiCodegenScopedSettings(conf: Configuration): Seq[Setting[_]] = inConfig(conf)(
Seq(
generateTapirDefinitions := codegen.value,
sourceGenerators += (codegen.taskValue).map(_.map(_.toPath.toFile))
sourceGenerators += (codegen.taskValue).map(_.flatMap(_.map(_.toPath.toFile)))
)
)

def openapiCodegenDefaultSettings: Seq[Setting[_]] = Seq(
openapiSwaggerFile := baseDirectory.value / "swagger.yaml",
openapiPackage := "sttp.tapir.generated",
openapiObject := "TapirGeneratedEndpoints"
openapiObject := "TapirGeneratedEndpoints",
openapiUseHeadTagForObjectName := false
)

private def codegen = Def.task {
Expand All @@ -35,6 +36,7 @@ object OpenapiCodegenPlugin extends AutoPlugin {
openapiSwaggerFile,
openapiPackage,
openapiObject,
openapiUseHeadTagForObjectName,
sourceManaged,
streams,
scalaVersion
Expand All @@ -43,11 +45,12 @@ object OpenapiCodegenPlugin extends AutoPlugin {
swaggerFile: File,
packageName: String,
objectName: String,
useHeadTagForObjectName: Boolean,
srcDir: File,
taskStreams: TaskStreams,
sv: String
) =>
OpenapiCodegenTask(swaggerFile, packageName, objectName, srcDir, taskStreams.cacheDirectory, sv.startsWith("3")).file
OpenapiCodegenTask(swaggerFile, packageName, objectName, useHeadTagForObjectName, srcDir, taskStreams.cacheDirectory, sv.startsWith("3")).file
}) map (Seq(_))).value
}
}
Loading

0 comments on commit e611d21

Please sign in to comment.