Skip to content

Commit

Permalink
codegen: Semiauto schema derivation (#3671)
Browse files Browse the repository at this point in the history
hughsimpson authored Apr 19, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 1cf911d commit 20a526c
Showing 18 changed files with 518 additions and 79 deletions.
7 changes: 4 additions & 3 deletions doc/generator/sbt-openapi-codegen.md
Original file line number Diff line number Diff line change
@@ -35,16 +35,17 @@ defined case-classes and endpoint definitions.
The generator currently supports these settings, you can override them in the `build.sbt`;

```eval_rst
===================================== ==================================== =======================================================================================
===================================== ==================================== ==================================================================================================
setting default value description
===================================== ==================================== =======================================================================================
===================================== ==================================== ==================================================================================================
openapiSwaggerFile baseDirectory.value / "swagger.yaml" The swagger file with the api definitions.
openapiPackage sttp.tapir.generated The name for the generated package.
openapiObject TapirGeneratedEndpoints The name for the generated object.
openapiUseHeadTagForObjectName false If true, put endpoints in separate files based on first declared tag.
openapiJsonSerdeLib circe The json serde library to use.
openapiValidateNonDiscriminatedOneOfs true Whether to fail if variants of a oneOf without a discriminator cannot be disambiguated.
===================================== ==================================== =======================================================================================
openapiMaxSchemasPerFile 400 Maximum number of schemas to generate in a single file (tweak if hitting javac class size limits).
===================================== ==================================== ==================================================================================================
```

The general usage is;
Original file line number Diff line number Diff line change
@@ -54,6 +54,10 @@ object GenScala {
"v"
)
.orFalse
private val maxSchemasPerFileOpt: Opts[Option[Int]] =
Opts
.option[Int]("maxSchemasPerFile", "Maximum number of schemas to generate in a single file.", "m")
.orNone

private val jsonLibOpt: Opts[Option[String]] =
Opts.option[String]("jsonLib", "Json library to use for serdes", "j").orNone
@@ -71,8 +75,8 @@ object GenScala {
}

val cmd: Command[IO[ExitCode]] = Command("genscala", "Generate Scala classes", helpFlag = true) {
(fileOpt, packageNameOpt, destDirOpt, objectNameOpt, targetScala3Opt, headTagForNamesOpt, jsonLibOpt, validateNonDiscriminatedOneOfsOpt)
.mapN { case (file, packageName, destDir, maybeObjectName, targetScala3, headTagForNames, jsonLib, validateNonDiscriminatedOneOfs) =>
(fileOpt, packageNameOpt, destDirOpt, objectNameOpt, targetScala3Opt, headTagForNamesOpt, jsonLibOpt, validateNonDiscriminatedOneOfsOpt, maxSchemasPerFileOpt)
.mapN { case (file, packageName, destDir, maybeObjectName, targetScala3, headTagForNames, jsonLib, validateNonDiscriminatedOneOfs, maxSchemasPerFile) =>
val objectName = maybeObjectName.getOrElse(DefaultObjectName)

def generateCode(doc: OpenapiDocument): IO[Unit] = for {
@@ -84,7 +88,8 @@ object GenScala {
targetScala3,
headTagForNames,
jsonLib.getOrElse("circe"),
validateNonDiscriminatedOneOfs
validateNonDiscriminatedOneOfs,
maxSchemasPerFile.getOrElse(400)
)
)
destFiles <- contents.toVector.traverse { case (fileName, content) => writeGeneratedFile(destDir, fileName, content) }
Original file line number Diff line number Diff line change
@@ -34,7 +34,8 @@ object BasicGenerator {
targetScala3: Boolean,
useHeadTagForObjectNames: Boolean,
jsonSerdeLib: String,
validateNonDiscriminatedOneOfs: Boolean
validateNonDiscriminatedOneOfs: Boolean,
maxSchemasPerFile: Int
): Map[String, String] = {
val normalisedJsonLib = jsonSerdeLib.toLowerCase match {
case "circe" => JsonSerdeLib.Circe
@@ -47,7 +48,7 @@ object BasicGenerator {
}

val EndpointDefs(endpointsByTag, queryParamRefs, jsonParamRefs) = endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames)
val GeneratedClassDefinitions(classDefns, extras) =
val GeneratedClassDefinitions(classDefns, jsonSerdes, schemas) =
classGenerator
.classDefs(
doc = doc,
@@ -56,15 +57,19 @@ object BasicGenerator {
jsonSerdeLib = normalisedJsonLib,
jsonParamRefs = jsonParamRefs,
fullModelPath = s"$packagePath.$objName",
validateNonDiscriminatedOneOfs = validateNonDiscriminatedOneOfs
validateNonDiscriminatedOneOfs = validateNonDiscriminatedOneOfs,
maxSchemasPerFile = maxSchemasPerFile
)
.getOrElse(GeneratedClassDefinitions("", None))
val isSplit = extras.nonEmpty
val internalImports =
if (isSplit)
s"""import $packagePath.$objName._
|import ${objName}JsonSerdes._""".stripMargin
else s"import $objName._"
.getOrElse(GeneratedClassDefinitions("", None, Nil))
val hasJsonSerdes = jsonSerdes.nonEmpty

val maybeJsonImport = if (hasJsonSerdes) s"\nimport $packagePath.${objName}JsonSerdes._" else ""
val maybeSchemaImport =
if (schemas.size > 1) (1 to schemas.size).map(i => s"import ${objName}Schemas$i._").mkString("\n", "\n", "")
else if (schemas.size == 1) s"\nimport ${objName}Schemas._"
else ""
val internalImports = s"import $packagePath.$objName._$maybeJsonImport$maybeSchemaImport"

val taggedObjs = endpointsByTag.collect {
case (Some(headTag), body) if body.nonEmpty =>
val taggedObj =
@@ -81,14 +86,39 @@ object BasicGenerator {
|}""".stripMargin
headTag -> taggedObj
}
val extraObj = extras.map { body =>

val jsonSerdeObj = jsonSerdes.map { body =>
s"""package $packagePath
|
|object ${objName}JsonSerdes {
| import $packagePath.$objName._
| import sttp.tapir.generic.auto._
|${indent(2)(body)}
|}""".stripMargin
}

val schemaObjs = if (schemas.size > 1) schemas.zipWithIndex.map { case (body, idx) =>
val priorImports = (0 until idx).map { i => s"import $packagePath.${objName}Schemas${i + 1}._" }.mkString("\n")
val name = s"${objName}Schemas${idx + 1}"
name -> s"""package $packagePath
|
|object $name {
| import $packagePath.$objName._
| import sttp.tapir.generic.auto._
|${indent(2)(priorImports)}
|${indent(2)(body)}
|}""".stripMargin
}
else if (schemas.size == 1)
Seq(s"${objName}Schemas" -> s"""package $packagePath
|
|object ${objName}Schemas {
| import $packagePath.$objName._
| import sttp.tapir.generic.auto._
|${indent(2)(schemas.head)}
|}""".stripMargin)
else Nil

val endpointsInMain = endpointsByTag.getOrElse(None, "")

val maybeSpecificationExtensionKeys = doc.paths
@@ -100,21 +130,21 @@ object BasicGenerator {
val values = pairs.map(_._2)
val `type` = SpecificationExtensionRenderer.renderCombinedType(values)
val name = strippedToCamelCase(keyName)
val uncapitalisedName = name.head.toLower + name.tail
val capitalisedName = name.head.toUpper + name.tail
val uncapitalisedName = uncapitalise(name)
val capitalisedName = uncapitalisedName.capitalize
s"""type ${capitalisedName}Extension = ${`type`}
|val ${uncapitalisedName}ExtensionKey = new sttp.tapir.AttributeKey[${capitalisedName}Extension]("$packagePath.$objName.${capitalisedName}Extension")
|""".stripMargin
}
.mkString("\n")

val serdeImport = if (isSplit && endpointsInMain.nonEmpty) s"\nimport $packagePath.${objName}JsonSerdes._" else ""
val mainObj = s"""|
val extraImports = if (endpointsInMain.nonEmpty) s"$maybeJsonImport$maybeSchemaImport" else ""
val mainObj = s"""
|package $packagePath
|
|object $objName {
|
|${indent(2)(imports(normalisedJsonLib) + serdeImport)}
|${indent(2)(imports(normalisedJsonLib) + extraImports)}
|
|${indent(2)(classDefns)}
|
@@ -124,7 +154,7 @@ object BasicGenerator {
|
|}
|""".stripMargin
taggedObjs ++ extraObj.map(s"${objName}JsonSerdes" -> _) + (objName -> mainObj)
taggedObjs ++ jsonSerdeObj.map(s"${objName}JsonSerdes" -> _) ++ schemaObjs + (objName -> mainObj)
}

private[codegen] def imports(jsonSerdeLib: JsonSerdeLib.JsonSerdeLib): String = {
@@ -184,4 +214,6 @@ object BasicGenerator {
.zipWithIndex
.map { case (part, 0) => part; case (part, _) => part.capitalize }
.mkString

def uncapitalise(name: String): String = name.head.toLower +: name.tail
}
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType._

import scala.annotation.tailrec

case class GeneratedClassDefinitions(classRepr: String, serdeRepr: Option[String])
case class GeneratedClassDefinitions(classRepr: String, serdeRepr: Option[String], schemaRepr: Seq[String])

class ClassDefinitionGenerator {

@@ -18,7 +18,8 @@ class ClassDefinitionGenerator {
jsonSerdeLib: JsonSerdeLib.JsonSerdeLib = JsonSerdeLib.Circe,
jsonParamRefs: Set[String] = Set.empty,
fullModelPath: String = "",
validateNonDiscriminatedOneOfs: Boolean = true
validateNonDiscriminatedOneOfs: Boolean = true,
maxSchemasPerFile: Int = 400
): Option[GeneratedClassDefinitions] = {
val allSchemas: Map[String, OpenapiSchemaType] = doc.components.toSeq.flatMap(_.schemas).toMap
val allOneOfSchemas = allSchemas.collect { case (name, oneOf: OpenapiSchemaOneOf) => name -> oneOf }.toSeq
@@ -40,7 +41,8 @@ class ClassDefinitionGenerator {

val adtTypes = adtInheritanceMap.flatMap(_._2).toSeq.distinct.map(name => s"sealed trait $name").mkString("", "\n", "\n")
val enumQuerySerdeHelper = if (!generatesQueryParamEnums) "" else enumQuerySerdeHelperDefn(targetScala3)
val postDefns = JsonSerdeGenerator.serdeDefs(
val schemas = SchemaGenerator.generateSchemas(doc, allSchemas, fullModelPath, jsonSerdeLib, maxSchemasPerFile)
val jsonSerdes = JsonSerdeGenerator.serdeDefs(
doc,
jsonSerdeLib,
jsonParamRefs,
@@ -63,8 +65,8 @@ class ClassDefinitionGenerator {
val helpers = (enumQuerySerdeHelper + adtTypes).linesIterator
.filterNot(_.forall(_.isWhitespace))
.mkString("\n")
// Json serdes live in a separate file from the class defns
defns.map(helpers + "\n" + _).map(defStr => GeneratedClassDefinitions(defStr, postDefns))
// Json serdes & schemas live in separate files from the class defns
defns.map(helpers + "\n" + _).map(defStr => GeneratedClassDefinitions(defStr, jsonSerdes, schemas))
}

private def mkMapParentsByChild(allOneOfSchemas: Seq[(String, OpenapiSchemaOneOf)]): Map[String, Seq[String]] =
@@ -219,7 +221,7 @@ class ClassDefinitionGenerator {
| case ${obj.items.map(_.value).mkString(", ")}
|}""".stripMargin :: Nil
} else {
val uncapitalisedName = name.head.toLower +: name.tail
val uncapitalisedName = BasicGenerator.uncapitalise(name)
val members = obj.items.map { i => s"case object ${i.value} extends $name" }
val maybeCodecExtension = jsonSerdeLib match {
case _ if !jsonParamRefs.contains(name) && !queryParamRefs.contains(name) => ""
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaArray,
OpenapiSchemaBoolean,
OpenapiSchemaEnum,
OpenapiSchemaField,
OpenapiSchemaMap,
OpenapiSchemaNumericType,
OpenapiSchemaObject,
@@ -86,7 +87,7 @@ object JsonSerdeGenerator {
// if lhs has some required non-nullable fields with no default that rhs will never contain, then right cannot be mistaken for left
if ((requiredL.keySet -- anyR.keySet).nonEmpty) false
else {
// otherwise, if any required field on rhs can't look like the similarly-named field on lhs, then r can't look like l
// otherwise, if any field on rhs required by lhs can't look like the similarly-named field on lhs, then r can't look like l
val rForRequiredL = anyR.filter(requiredL.keySet contains _._1)
requiredL.forall { case (k, lhsV) => rCanLookLikeL(lhsV.`type`, rForRequiredL(k).`type`) }
}
@@ -118,8 +119,10 @@ object JsonSerdeGenerator {
// Enum serdes are generated at the declaration site
case (_, _: OpenapiSchemaEnum) => None
// We generate the serde if it's referenced in any json model
case (name, _: OpenapiSchemaObject | _: OpenapiSchemaMap) if allTransitiveJsonParamRefs.contains(name) =>
Some(genCirceNamedSerde(name))
case (name, schema: OpenapiSchemaObject) if allTransitiveJsonParamRefs.contains(name) =>
Some(genCirceObjectSerde(name, schema))
case (name, schema: OpenapiSchemaMap) if allTransitiveJsonParamRefs.contains(name) =>
Some(genCirceMapSerde(name, schema))
case (name, schema: OpenapiSchemaOneOf) if allTransitiveJsonParamRefs.contains(name) =>
Some(genCirceAdtSerde(allSchemas, schema, name, validateNonDiscriminatedOneOfs))
case (_, _: OpenapiSchemaObject | _: OpenapiSchemaMap | _: OpenapiSchemaEnum | _: OpenapiSchemaOneOf) => None
@@ -128,19 +131,36 @@ object JsonSerdeGenerator {
.map(_.mkString("\n"))
}

private def genCirceNamedSerde(name: String): String = {
val uncapitalisedName = name.head.toLower +: name.tail
s"""implicit lazy val ${uncapitalisedName}JsonDecoder: io.circe.Decoder[$name] = io.circe.generic.semiauto.deriveDecoder[$name]
private def genCirceObjectSerde(name: String, schema: OpenapiSchemaObject): String = {
val subs = schema.properties.collect {
case (k, OpenapiSchemaField(`type`: OpenapiSchemaObject, _)) => genCirceObjectSerde(s"$name${k.capitalize}", `type`)
case (k, OpenapiSchemaField(OpenapiSchemaArray(`type`: OpenapiSchemaObject, _), _)) =>
genCirceObjectSerde(s"$name${k.capitalize}Item", `type`)
case (k, OpenapiSchemaField(OpenapiSchemaMap(`type`: OpenapiSchemaObject, _), _)) =>
genCirceObjectSerde(s"$name${k.capitalize}Item", `type`)
} match {
case Nil => ""
case s => s.mkString("", "\n", "\n")
}
val uncapitalisedName = BasicGenerator.uncapitalise(name)
s"""${subs}implicit lazy val ${uncapitalisedName}JsonDecoder: io.circe.Decoder[$name] = io.circe.generic.semiauto.deriveDecoder[$name]
|implicit lazy val ${uncapitalisedName}JsonEncoder: io.circe.Encoder[$name] = io.circe.generic.semiauto.deriveEncoder[$name]""".stripMargin
}
private def genCirceMapSerde(name: String, schema: OpenapiSchemaMap): String = {
val subs = schema.items match {
case `type`: OpenapiSchemaObject => Some(genCirceObjectSerde(s"${name}ObjectsItem", `type`))
case _ => None
}
subs.fold("")("\n" + _)
}

private def genCirceAdtSerde(
allSchemas: Map[String, OpenapiSchemaType],
schema: OpenapiSchemaOneOf,
name: String,
validateNonDiscriminatedOneOfs: Boolean
): String = {
val uncapitalisedName = name.head.toLower +: name.tail
val uncapitalisedName = BasicGenerator.uncapitalise(name)

schema match {
case OpenapiSchemaOneOf(_, Some(discriminator)) =>
@@ -256,7 +276,7 @@ object JsonSerdeGenerator {
}

private def genJsoniterClassSerde(supertypes: Seq[OpenapiSchemaOneOf])(name: String): String = {
val uncapitalisedName = name.head.toLower +: name.tail
val uncapitalisedName = BasicGenerator.uncapitalise(name)
if (supertypes.exists(_.discriminator.isDefined))
throw new NotImplementedError(
s"A class cannot be used both in a oneOf with discriminator and at the top level when using jsoniter serdes at $name"
@@ -266,13 +286,13 @@ object JsonSerdeGenerator {
}

private def genJsoniterEnumSerde(name: String): String = {
val uncapitalisedName = name.head.toLower +: name.tail
val uncapitalisedName = BasicGenerator.uncapitalise(name)
s"""
|implicit lazy val ${uncapitalisedName}JsonCodec: $jsoniterPkgCore.JsonValueCodec[${name}] = $jsoniterPkgMacros.JsonCodecMaker.make($jsoniteEnumConfig.withDiscriminatorFieldName(scala.None))""".stripMargin
}

private def genJsoniterNamedSerde(name: String): String = {
val uncapitalisedName = name.head.toLower +: name.tail
val uncapitalisedName = BasicGenerator.uncapitalise(name)
s"""
|implicit lazy val ${uncapitalisedName}JsonCodec: $jsoniterPkgCore.JsonValueCodec[$name] = $jsoniterPkgMacros.JsonCodecMaker.make($jsoniterBaseConfig)""".stripMargin
}
@@ -285,7 +305,7 @@ object JsonSerdeGenerator {
validateNonDiscriminatedOneOfs: Boolean
): String = {
val fullPathPrefix = maybeFullModelPath.map(_ + ".").getOrElse("")
val uncapitalisedName = name.head.toLower +: name.tail
val uncapitalisedName = BasicGenerator.uncapitalise(name)
schema match {
case OpenapiSchemaOneOf(_, Some(discriminator)) =>
def subtypeNames = schema.types.map {
@@ -321,7 +341,7 @@ object JsonSerdeGenerator {
if (validateNonDiscriminatedOneOfs) checkForSoundness(allSchemas)(schema.types.map(_.asInstanceOf[OpenapiSchemaRef]))
val childNameAndSerde = schemas.collect { case ref: OpenapiSchemaRef =>
val name = ref.stripped
name -> s"${name.head.toLower +: name.tail}JsonCodec"
name -> s"${BasicGenerator.uncapitalise(name)}JsonCodec"
}
val childSerdes = childNameAndSerde.map(_._2)
val doDecode = childSerdes.mkString("List(\n ", ",\n ", ")\n") +
Loading

0 comments on commit 20a526c

Please sign in to comment.