Skip to content

Commit

Permalink
codegen: Support default values in schema objects (#3614)
Browse files Browse the repository at this point in the history
  • Loading branch information
hughsimpson authored Mar 20, 2024
1 parent 171fb66 commit 3880806
Show file tree
Hide file tree
Showing 10 changed files with 475 additions and 73 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
package sttp.tapir.codegen

import io.circe.Json
import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType}
import sttp.tapir.codegen.openapi.models.OpenapiModels.OpenapiDocument
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType
import sttp.tapir.codegen.openapi.models.{OpenapiSchemaType, Renderer}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType._

import scala.annotation.tailrec

class ClassDefinitionGenerator {
val jsoniterDefaultConfig = "com.github.plokhotnyuk.jsoniter_scala.macros.CodecMakerConfig.withAllowRecursiveTypes(true).withDiscriminatorFieldName(scala.None)"
val jsoniterDefaultConfig =
"com.github.plokhotnyuk.jsoniter_scala.macros.CodecMakerConfig.withAllowRecursiveTypes(true).withDiscriminatorFieldName(scala.None)"

def classDefs(
doc: OpenapiDocument,
Expand Down Expand Up @@ -59,7 +61,7 @@ class ClassDefinitionGenerator {
val defns = doc.components
.map(_.schemas.flatMap {
case (name, obj: OpenapiSchemaObject) =>
generateClass(name, obj, jsonSerdeLib, allTransitiveJsonParamRefs)
generateClass(allSchemas, name, obj, jsonSerdeLib, allTransitiveJsonParamRefs)
case (name, obj: OpenapiSchemaEnum) =>
generateEnum(name, obj, targetScala3, queryParamRefs, jsonSerdeLib, allTransitiveJsonParamRefs)
case (name, OpenapiSchemaMap(valueSchema, _)) => generateMap(name, valueSchema, jsonSerdeLib, allTransitiveJsonParamRefs)
Expand Down Expand Up @@ -139,7 +141,7 @@ class ClassDefinitionGenerator {
case OpenapiSchemaObject(properties, _, _) if properties.isEmpty => None
case OpenapiSchemaObject(properties, required, nullable) =>
val propToCheck = properties.head
val (propToCheckName, propToCheckType) = propToCheck
val (propToCheckName, OpenapiSchemaField(propToCheckType, _)) = propToCheck
val objectWithoutHeadField = OpenapiSchemaObject(properties - propToCheckName, required, nullable)
Some((propToCheckType, checked, objectWithoutHeadField +: tail))
case _ => None
Expand Down Expand Up @@ -236,6 +238,7 @@ class ClassDefinitionGenerator {
}

private[codegen] def generateClass(
allSchemas: Map[String, OpenapiSchemaType],
name: String,
obj: OpenapiSchemaObject,
jsonSerdeLib: JsonSerdeLib.JsonSerdeLib,
Expand All @@ -245,25 +248,28 @@ class ClassDefinitionGenerator {
def rec(name: String, obj: OpenapiSchemaObject, acc: List[String]): Seq[String] = {
val innerClasses = obj.properties
.collect {
case (propName, st: OpenapiSchemaObject) =>
case (propName, OpenapiSchemaField(st: OpenapiSchemaObject, _)) =>
val newName = addName(name, propName)
rec(newName, st, Nil)

case (propName, OpenapiSchemaMap(st: OpenapiSchemaObject, _)) =>
case (propName, OpenapiSchemaField(OpenapiSchemaMap(st: OpenapiSchemaObject, _), _)) =>
val newName = addName(addName(name, propName), "item")
rec(newName, st, Nil)

case (propName, OpenapiSchemaArray(st: OpenapiSchemaObject, _)) =>
case (propName, OpenapiSchemaField(OpenapiSchemaArray(st: OpenapiSchemaObject, _), _)) =>
val newName = addName(addName(name, propName), "item")
rec(newName, st, Nil)
}
.flatten
.toList

val properties = obj.properties.map { case (key, schemaType) =>
val properties = obj.properties.map { case (key, OpenapiSchemaField(schemaType, maybeDefault)) =>
val tpe = mapSchemaTypeToType(name, key, obj.required.contains(key), schemaType, isJson)
val fixedKey = fixKey(key)
s"$fixedKey: $tpe"
val optional = schemaType.nullable || !obj.required.contains(key)
val maybeExplicitDefault = maybeDefault.map(" = " + Renderer.render(allModels = allSchemas, thisType = schemaType, optional)(_))
val default = maybeExplicitDefault getOrElse (if (optional) " = None" else "")
s"$fixedKey: $tpe$default"
}

val uncapitalisedName = name.head.toLower +: name.tail
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ object OpenapiModels {
implicit def ResolvableDecoder[T: Decoder]: Decoder[Resolvable[T]] = { (c: HCursor) =>
c.as[T].map(Resolved(_)).orElse(c.as[OpenapiSchemaRef].map(r => Ref(r.name)))
}

implicit val PartialOpenapiPathMethodDecoder: Decoder[OpenapiPathMethod] = { (c: HCursor) =>
for {
parameters <- c.getOrElse[Seq[Resolvable[OpenapiParameter]]]("parameters")(Nil)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package sttp.tapir.codegen.openapi.models

import io.circe.Json

sealed trait OpenapiSchemaType {
def nullable: Boolean
}
Expand Down Expand Up @@ -106,9 +108,13 @@ object OpenapiSchemaType {
nullable: Boolean
) extends OpenapiSchemaType

case class OpenapiSchemaField(
`type`: OpenapiSchemaType,
default: Option[Json]
)
// no readOnly/writeOnly, minProperties/maxProperties support
case class OpenapiSchemaObject(
properties: Map[String, OpenapiSchemaType],
properties: Map[String, OpenapiSchemaField],
required: Seq[String],
nullable: Boolean
) extends OpenapiSchemaType
Expand Down Expand Up @@ -253,14 +259,21 @@ object OpenapiSchemaType {
} yield OpenapiSchemaEnum(tpe, items, nb.getOrElse(false))
}

implicit val SchemaTypeWithDefaultDecoder: Decoder[(OpenapiSchemaType, Option[Json])] = { (c: HCursor) =>
for {
schemaType <- c.as[OpenapiSchemaType]
maybeDefault <- c.downField("default").as[Option[Json]]
} yield (schemaType, maybeDefault)
}
implicit val OpenapiSchemaObjectDecoder: Decoder[OpenapiSchemaObject] = { (c: HCursor) =>
for {
_ <- c.downField("type").as[String].ensure(DecodingFailure("Given type is not object!", c.history))(v => v == "object")
f <- c.downField("properties").as[Option[Map[String, OpenapiSchemaType]]]
fieldsWithDefaults <- c.downField("properties").as[Option[Map[String, (OpenapiSchemaType, Option[Json])]]]
r <- c.downField("required").as[Option[Seq[String]]]
nb <- c.downField("nullable").as[Option[Boolean]]
fields = fieldsWithDefaults.getOrElse(Map.empty).map { case (k, (f, d)) => k -> OpenapiSchemaField(f, d) }
} yield {
OpenapiSchemaObject(f.getOrElse(Map.empty), r.getOrElse(Seq.empty), nb.getOrElse(false))
OpenapiSchemaObject(fields, r.getOrElse(Seq.empty), nb.getOrElse(false))
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package sttp.tapir.codegen.openapi.models

import io.circe.Json
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaArray,
OpenapiSchemaBinary,
OpenapiSchemaBoolean,
OpenapiSchemaDateTime,
OpenapiSchemaDouble,
OpenapiSchemaEnum,
OpenapiSchemaFloat,
OpenapiSchemaInt,
OpenapiSchemaLong,
OpenapiSchemaMap,
OpenapiSchemaObject,
OpenapiSchemaRef,
OpenapiSchemaString,
OpenapiSchemaUUID
}

object Renderer {
private def lookup(allModels: Map[String, OpenapiSchemaType], ref: OpenapiSchemaRef): OpenapiSchemaType = allModels(
ref.name.stripPrefix("#/components/schemas/")
)

private def renderStringWithName(
value: String
)(allModels: Map[String, OpenapiSchemaType], thisType: OpenapiSchemaType, name: String): String =
thisType match {
case ref: OpenapiSchemaRef =>
renderStringWithName(value)(allModels, lookup(allModels, ref), ref.name.stripPrefix("#/components/schemas/"))
case OpenapiSchemaString(_) => '"' +: value :+ '"'
case OpenapiSchemaEnum(_, _, _) => s"$name.$value"
case OpenapiSchemaDateTime(_) => s"""java.time.Instant.parse("$value")"""
case OpenapiSchemaBinary(_) => s""""$value".getBytes("utf-8")"""
case OpenapiSchemaUUID(_) => s"""java.util.UUID.fromString("$value")"""
case other => throw new IllegalArgumentException(s"Cannot render a string as type ${other.getClass.getName}")
}
private def renderMapWithName(
kvs: Map[String, Json]
)(allModels: Map[String, OpenapiSchemaType], thisType: OpenapiSchemaType, name: String): String = {
def errorForKey(k: String): Nothing = throw new IllegalArgumentException(
s"Cannot find property $k in schema $name when constructing default value"
)
thisType match {
case ref: OpenapiSchemaRef => renderMapWithName(kvs)(allModels, lookup(allModels, ref), ref.name.stripPrefix("#/components/schemas/"))
case OpenapiSchemaMap(types, _) =>
s"Map(${kvs.map { case (k, v) => s""""$k" -> ${render(allModels, types, isOptional = false)(v)}""" }.mkString(", ")})"
case OpenapiSchemaObject(properties, required, _) =>
val kvsWithProps = kvs.map { case (k, v) => (k, (v, properties.get(k).getOrElse(errorForKey(k)))) }
s"$name(${kvsWithProps
.map { case (k, (v, p)) => s"""$k = ${render(allModels, p.`type`, p.`type`.nullable || !required.contains(k))(v)}""" }
.mkString(", ")})"
case other => throw new IllegalArgumentException(s"Cannot render a map as type ${other.getClass.getName}")
}
}

def render(allModels: Map[String, OpenapiSchemaType], thisType: OpenapiSchemaType, isOptional: Boolean)(json: Json): String =
if (json == Json.Null) {
if (isOptional) "None" else "null"
} else {
def fail(tpe: String, schemaType: OpenapiSchemaType, reason: Option[String] = None): Nothing =
throw new IllegalArgumentException(
s"Cannot render a $tpe as type ${schemaType.getClass.getName}.${reason.map(" " + _).getOrElse("")}"
)
val base: String = json.fold[String](
"null",
jsBool =>
thisType match {
case ref: OpenapiSchemaRef => render(allModels, lookup(allModels, ref), isOptional = false)(json)
case OpenapiSchemaBoolean(_) => jsBool.toString
case other => fail("boolean", other)
},
jsonNumber =>
thisType match {
case ref: OpenapiSchemaRef => render(allModels, lookup(allModels, ref), isOptional = false)(json)
case l @ OpenapiSchemaLong(_) => s"${jsonNumber.toLong.getOrElse(fail("number", l, Some(s"$jsonNumber is not a long")))}L"
case i @ OpenapiSchemaInt(_) => jsonNumber.toInt.getOrElse(fail("number", i, Some(s"$jsonNumber is not an int"))).toString
case OpenapiSchemaFloat(_) => s"${jsonNumber.toFloat}f"
case OpenapiSchemaDouble(_) => s"${jsonNumber.toDouble}d"
case other => fail("number", other)
},
jsonString =>
thisType match {
case ref: OpenapiSchemaRef =>
renderStringWithName(jsonString)(allModels, lookup(allModels, ref), ref.name.stripPrefix("#/components/schemas/"))
case OpenapiSchemaString(_) => '"' +: jsonString :+ '"'
case OpenapiSchemaDateTime(_) => s"""java.time.Instant.parse("$jsonString")"""
case OpenapiSchemaBinary(_) => s""""$jsonString".getBytes("utf-8")"""
case OpenapiSchemaUUID(_) => s"""java.util.UUID.fromString("$jsonString")"""
// case OpenapiSchemaEnum(_, _, _) => // inline enum definitions are not currently supported, so let it throw
case other => fail("string", other)
},
jsonArray =>
thisType match {
case ref: OpenapiSchemaRef => render(allModels, lookup(allModels, ref), isOptional = false)(json)
case OpenapiSchemaArray(items, _) => s"Vector(${jsonArray.map(render(allModels, items, isOptional = false)).mkString(", ")})"
case other => fail("list", other)
},
jsonObject =>
thisType match {
case ref: OpenapiSchemaRef =>
renderMapWithName(jsonObject.toMap)(allModels, lookup(allModels, ref), ref.name.stripPrefix("#/components/schemas/"))
case OpenapiSchemaMap(types, _) =>
s"Map(${jsonObject.toMap.map { case (k, v) => s""""$k" -> ${render(allModels, types, isOptional = false)(v)}""" }.mkString(", ")})"
case other => fail("map", other)
}
)
if (isOptional) s"Some($base)" else base
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,30 @@ class BasicGeneratorSpec extends CompileCheckTestBase {
)("TapirGeneratedEndpoints") shouldCompile ()
}

it should s"compile endpoints with default params using ${jsonSerdeLib} serdes" in {
val genWithParams = BasicGenerator.generateObjects(
TestHelpers.withDefaultsDocs,
"sttp.tapir.generated",
"TapirGeneratedEndpoints",
targetScala3 = false,
useHeadTagForObjectNames = false,
jsonSerdeLib = jsonSerdeLib
)("TapirGeneratedEndpoints")

val expectedDefaultDeclarations = Seq(
"""f1: String = "default string"""",
"""f2: Option[Int] = Some(1977)""",
"""g1: Option[java.util.UUID] = Some(java.util.UUID.fromString("default string"))""",
"""g2: Float = 1977.0f""",
"""g3: Option[AnEnum] = Some(AnEnum.v1)""",
"""g4: Option[Seq[AnEnum]] = Some(Vector(AnEnum.v1, AnEnum.v2, AnEnum.v3))""",
"""sub: Option[SubObject] = Some(SubObject(subsub = SubSubObject(value = "hi there", value2 = Some(java.util.UUID.fromString("ac8113ed-6105-4f65-a393-e88be2c5d585")))))"""
)
expectedDefaultDeclarations foreach (decln => genWithParams should include(decln))

genWithParams shouldCompile ()
}

}
Seq("circe", "jsoniter") foreach testJsonLib

Expand Down
Loading

0 comments on commit 3880806

Please sign in to comment.