Skip to content

Commit

Permalink
better type resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
hughsimpson committed Mar 20, 2024
1 parent 5d83824 commit faf1012
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package sttp.tapir.codegen
import sttp.tapir.codegen.openapi.models.OpenapiModels.OpenapiDocument
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaAny,
OpenapiSchemaBoolean,
OpenapiSchemaBinary,
OpenapiSchemaBoolean,
OpenapiSchemaDateTime,
OpenapiSchemaDouble,
OpenapiSchemaFloat,
Expand All @@ -15,6 +15,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaString,
OpenapiSchemaUUID
}
import sttp.tapir.codegen.openapi.models.SpecificationExtensionRenderer

object JsonSerdeLib extends Enumeration {
val Circe, Jsoniter = Value
Expand Down Expand Up @@ -69,12 +70,7 @@ object BasicGenerator {
.groupBy(_._1)
.map { case (keyName, pairs) =>
val values = pairs.map(_._2)
val distinctTypes = values.map(_.tpe).distinct
if (distinctTypes.size != 1)
throw new IllegalArgumentException(
s"specification extensions with the same key are expected to all have the same type. Found $distinctTypes for $keyName"
)
val `type` = distinctTypes.head
val `type` = SpecificationExtensionRenderer.renderCombinedType(values)
val name = strippedToCamelCase(keyName)
val uncapitalisedName = name.head.toLower + name.tail
val capitalisedName = name.head.toUpper + name.tail
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package sttp.tapir.codegen

import io.circe.Json
import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType}
import sttp.tapir.codegen.openapi.models.OpenapiModels.OpenapiDocument
import sttp.tapir.codegen.openapi.models.{OpenapiSchemaType, Renderer}
import sttp.tapir.codegen.openapi.models.{OpenapiSchemaType, DefaultValueRenderer}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType._

import scala.annotation.tailrec
Expand Down Expand Up @@ -267,7 +266,8 @@ class ClassDefinitionGenerator {
val tpe = mapSchemaTypeToType(name, key, obj.required.contains(key), schemaType, isJson)
val fixedKey = fixKey(key)
val optional = schemaType.nullable || !obj.required.contains(key)
val maybeExplicitDefault = maybeDefault.map(" = " + Renderer.render(allModels = allSchemas, thisType = schemaType, optional)(_))
val maybeExplicitDefault =
maybeDefault.map(" = " + DefaultValueRenderer.render(allModels = allSchemas, thisType = schemaType, optional)(_))
val default = maybeExplicitDefault getOrElse (if (optional) " = None" else "")
s"$fixedKey: $tpe$default"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
package sttp.tapir.codegen
import io.circe.Json
import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType, strippedToCamelCase}
import sttp.tapir.codegen.openapi.models.OpenapiModels.{
OpenapiDocument,
OpenapiParameter,
OpenapiPath,
OpenapiRequestBody,
OpenapiResponse,
SpecificationExtensionValue
}
import sttp.tapir.codegen.openapi.models.OpenapiModels.{OpenapiDocument, OpenapiParameter, OpenapiPath, OpenapiRequestBody, OpenapiResponse}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaAny,
OpenapiSchemaArray,
OpenapiSchemaBinary,
OpenapiSchemaRef,
OpenapiSchemaAny,
OpenapiSchemaSimpleType
}
import sttp.tapir.codegen.openapi.models.{OpenapiComponent, OpenapiSchemaType, OpenapiSecuritySchemeType}
import sttp.tapir.codegen.openapi.models.{OpenapiComponent, OpenapiSchemaType, OpenapiSecuritySchemeType, SpecificationExtensionRenderer}
import sttp.tapir.codegen.util.JavaEscape

case class Location(path: String, method: String) {
Expand Down Expand Up @@ -210,12 +204,12 @@ class EndpointGenerator {
openapiTags.map(_.distinct.mkString(".tags(List(\"", "\", \"", "\"))")).mkString
}

private def attributes(atts: Map[String, SpecificationExtensionValue]): Option[String] = if (atts.nonEmpty) Some {
private def attributes(atts: Map[String, Json]): Option[String] = if (atts.nonEmpty) Some {
atts
.map { case (k, v) =>
val camelCaseK = strippedToCamelCase(k)
val uncapitalisedName = camelCaseK.head.toLower + camelCaseK.tail
s""".attribute(${uncapitalisedName}ExtensionKey, ${v.render})"""
s""".attribute[${camelCaseK.capitalize}Extension](${uncapitalisedName}ExtensionKey, ${SpecificationExtensionRenderer.renderValue(v)})"""
}
.mkString("\n")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaUUID
}

object Renderer {
object DefaultValueRenderer {
private def lookup(allModels: Map[String, OpenapiSchemaType], ref: OpenapiSchemaRef): OpenapiSchemaType = allModels(
ref.name.stripPrefix("#/components/schemas/")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,10 @@ package sttp.tapir.codegen.openapi.models

import cats.implicits.toTraverseOps
import cats.syntax.either._

import OpenapiSchemaType.OpenapiSchemaRef
import io.circe.Json
// https://swagger.io/specification/
object OpenapiModels {
sealed trait SpecificationExtensionValue {
def tpe: String
def render: String
def value: Any
}
case object SpecificationExtensionValueNull extends SpecificationExtensionValue {
val tpe = "Null"
val render = "null"
val value = null
}
case class SpecificationExtensionValueBoolean(value: Boolean) extends SpecificationExtensionValue {
val render = value.toString
val tpe = "Boolean"
}
case class SpecificationExtensionValueLong(value: Long) extends SpecificationExtensionValue {
val render = s"${value}L"
val tpe = "Long"
}
case class SpecificationExtensionValueDouble(value: Double) extends SpecificationExtensionValue {
val render = s"${value}d"
val tpe = "Double"
}
case class SpecificationExtensionValueString(value: String) extends SpecificationExtensionValue {
val render = '"' +: value :+ '"'
val tpe = "String"
}
case class SpecificationExtensionValueList(values: Seq[SpecificationExtensionValue]) extends SpecificationExtensionValue {
val render = s"Vector(${values.map(_.render).mkString(", ")})"
def tpe = values.map(_.tpe).distinct match { case single +: Nil => s"Seq[$single]"; case _ => "Seq[Any]" }
def value = values.map(_.value)
}
case class SpecificationExtensionValueMap(kvs: Map[String, SpecificationExtensionValue]) extends SpecificationExtensionValue {
val render = s"Map(${kvs.map { case (k, v) => s""""$k" -> ${v.render}""" }.mkString(", ")})"
def tpe = kvs.values.map(_.tpe).toSeq.distinct match { case single +: Nil => s"Map[String, $single]"; case _ => "Map[String, Any]" }
def value = kvs.map { case (k, v) => k -> v.value }
}

sealed trait Resolvable[T] {
def resolve(input: Map[String, T]): T
Expand Down Expand Up @@ -72,7 +36,7 @@ object OpenapiModels {
url: String,
methods: Seq[OpenapiPathMethod],
parameters: Seq[Resolvable[OpenapiParameter]] = Nil,
specificationExtensions: Map[String, SpecificationExtensionValue] = Map.empty
specificationExtensions: Map[String, Json] = Map.empty
)

case class OpenapiPathMethod(
Expand All @@ -84,7 +48,7 @@ object OpenapiModels {
summary: Option[String] = None,
tags: Option[Seq[String]] = None,
operationId: Option[String] = None,
specificationExtensions: Map[String, SpecificationExtensionValue] = Map.empty
specificationExtensions: Map[String, Json] = Map.empty
) {
def resolvedParameters: Seq[OpenapiParameter] = parameters.collect { case Resolved(t) => t }
def withResolvedParentParameters(
Expand Down Expand Up @@ -205,19 +169,6 @@ object OpenapiModels {
c.as[T].map(Resolved(_)).orElse(c.as[OpenapiSchemaRef].map(r => Ref(r.name)))
}

def decodeSpecificationExtensionValue(json: Json): SpecificationExtensionValue =
json.fold(
SpecificationExtensionValueNull,
SpecificationExtensionValueBoolean.apply,
n => n.toLong.map(SpecificationExtensionValueLong.apply).getOrElse(SpecificationExtensionValueDouble(n.toDouble)),
SpecificationExtensionValueString.apply,
arr => SpecificationExtensionValueList(arr.map(decodeSpecificationExtensionValue)),
obj => SpecificationExtensionValueMap(obj.toMap.map { case (k, v) => k -> decodeSpecificationExtensionValue(v) })
)
implicit val SpecificationExtensionValueDecoder: Decoder[SpecificationExtensionValue] = { (c: HCursor) =>
Right(decodeSpecificationExtensionValue(c.value))
}

implicit val PartialOpenapiPathMethodDecoder: Decoder[OpenapiPathMethod] = { (c: HCursor) =>
for {
parameters <- c.getOrElse[Seq[Resolvable[OpenapiParameter]]]("parameters")(Nil)
Expand All @@ -229,7 +180,7 @@ object OpenapiModels {
operationId <- c.get[Option[String]]("operationId")
specificationExtensionKeys = c.keys.toSeq.flatMap(_.filter(_.startsWith("x-")))
specificationExtensions = specificationExtensionKeys
.flatMap(key => c.downField(key).as[SpecificationExtensionValue].toOption.map(key.stripPrefix("x-") -> _))
.flatMap(key => c.downField(key).as[Option[Json]].toOption.flatten.map(key.stripPrefix("x-") -> _))
.toMap
} yield {
OpenapiPathMethod(
Expand All @@ -256,7 +207,7 @@ object OpenapiModels {
.traverse(method => c.downField(method).as[Option[OpenapiPathMethod]].map(_.map(_.copy(methodType = method))))
specificationExtensionKeys = c.keys.toSeq.flatMap(_.filter(_.startsWith("x-")))
specificationExtensions = specificationExtensionKeys
.flatMap(key => c.downField(key).as[SpecificationExtensionValue].toOption.map(key.stripPrefix("x-") -> _))
.flatMap(key => c.downField(key).as[Option[Json]].toOption.flatten.map(key.stripPrefix("x-") -> _))
.toMap
} yield OpenapiPath("--partial--", methods.flatten, parameters, specificationExtensions)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package sttp.tapir.codegen.openapi.models

import io.circe.Json

object SpecificationExtensionRenderer {

def renderCombinedType(jsons: Seq[Json]): String = {
// permit nulls for any type, but specify type as null if every value is null
val nonNull = jsons.filterNot(_.isNull)
if (jsons.isEmpty) "Nothing"
else if (nonNull.isEmpty) "Null"
else {
val groupedByBaseType = nonNull.groupBy(j =>
if (j.isBoolean) "Boolean"
else if (j.isNumber) "Number"
else if (j.isString) "String"
else if (j.isArray) "Array"
else if (j.isObject) "Object"
else throw new IllegalStateException("json must be one of boolean, number, string, array or object")
)
// Cannot resolve types if totally different...
if (groupedByBaseType.size > 1) "Any"
else
groupedByBaseType.head match {
case (t @ ("Boolean" | "String"), _) => t
case ("Number", vs) => if (vs.forall(_.asNumber.flatMap(_.toLong).isDefined)) "Long" else "Double"
case ("Array", vs) =>
val t = renderCombinedType(vs.flatMap(_.asArray).flatten)
s"Seq[$t]"
case ("Object", kvs) =>
val t = renderCombinedType(kvs.flatMap(_.asObject).flatMap(_.toMap.values))
s"Map[String, $t]"
case (x, _) => throw new IllegalStateException(s"No such group $x")
}
}
}

def renderValue(json: Json): String = json.fold(
"null",
bool => bool.toString,
n => n.toLong.map(l => s"${l}L") getOrElse s"${n.toDouble}d", // the long repr is fine even if type expanded to Double
s => '"' +: s :+ '"',
arr => if (arr.isEmpty) "Vector.empty" else s"Vector(${arr.map(renderValue).mkString(", ")})",
obj =>
if (obj.isEmpty) "Map.empty[String, Nothing]"
else s"Map(${obj.toMap.map { case (k, v) => s""""$k" -> ${renderValue(v)}""" }.mkString(", ")})"
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,22 +261,27 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
jsonSerdeLib = "circe"
)("TapirGeneratedEndpoints")
generatedCode shouldCompile ()
generatedCode should include(
""".attribute(customStringExtensionOnPathExtensionKey, "foobar")"""
)
generatedCode should include(
""".attribute(customStringExtensionOnOperationExtensionKey, "bazquux")"""
)
generatedCode should include(
""".attribute(customListExtensionOnOperationExtensionKey, Vector("baz", "quux"))"""
)
generatedCode should include(
""".attribute(customMapExtensionOnPathExtensionKey, Map("bazkey" -> "bazval", "quuxkey" -> Vector("quux1", "quux2"))"""
val expectedAttrDecls = Seq(
""".attribute[CustomStringExtensionOnPathExtension](customStringExtensionOnPathExtensionKey, "another string")""",
""".attribute[CustomStringExtensionOnOperationExtension](customStringExtensionOnOperationExtensionKey, "bazquux")""",
""".attribute[CustomListExtensionOnOperationExtension](customListExtensionOnOperationExtensionKey, Vector("baz", "quux"))""",
""".attribute[CustomMapExtensionOnPathExtension](customMapExtensionOnPathExtensionKey, Map("bazkey" -> "bazval", "quuxkey" -> Vector("quux1", "quux2"))"""
)
generatedCode should include("""type CustomMapExtensionOnOperationExtension = Map[String, Any]""")
expectedAttrDecls foreach (decl => generatedCode should include(decl))
generatedCode should include(
"""val customMapExtensionOnOperationExtensionKey = new sttp.tapir.AttributeKey[CustomMapExtensionOnOperationExtension]("sttp.tapir.generated.TapirGeneratedEndpoints.CustomMapExtensionOnOperationExtension")""".stripMargin
)
val expectedKeyDeclarations = Seq(
"""type CustomMapExtensionOnOperationExtension = Map[String, Any]""",
"""type CustomListExtensionOnPathAnyTypeExtension = Seq[Any]""",
"""type CustomMapExtensionOnPathSingleValueTypeExtension = Map[String, String]""",
"""type CustomListExtensionOnOperationExtension = Seq[String]""",
"""type CustomStringExtensionOnPathAnyTypeExtension = Any""",
"""type CustomStringExtensionOnPathDoubleTypeExtension = Double""",
"""type CustomListExtensionOnPathExtension = Seq[String]""",
"""type CustomStringExtensionOnPathExtension = String"""
)
expectedKeyDeclarations foreach (decl => generatedCode should include(decl))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -772,10 +772,14 @@ object TestHelpers {
| version: '1.0'
|paths:
| /hello:
| x-custom-string-extension-on-path: foobar
| x-custom-string-extension-on-path-any-type: foobar
| x-custom-string-extension-on-path-double-type: 123
| x-custom-string-extension-on-path: null
| x-custom-list-extension-on-path:
| - foo
| - bar
| x-custom-list-extension-on-path-any-type:
| - string
| x-custom-map-extension-on-path:
| bazkey: bazval
| quuxkey:
Expand All @@ -784,6 +788,16 @@ object TestHelpers {
| post:
| responses: {}
| /goodbye:
| x-custom-string-extension-on-path-any-type: 123
| x-custom-string-extension-on-path-double-type: 123.456
| x-custom-string-extension-on-path: another string
| x-custom-list-extension-on-path: []
| x-custom-list-extension-on-path-any-type:
| - 123
| x-custom-map-extension-on-path: {}
| x-custom-map-extension-on-path-single-value-type:
| bazkey: bazval
| quuxkey: quuxval
| delete:
| x-custom-string-extension-on-operation: bazquux
| x-custom-list-extension-on-operation:
Expand All @@ -804,15 +818,17 @@ object TestHelpers {
url = "/hello",
methods = Seq(OpenapiPathMethod(methodType = "post", parameters = Seq(), responses = Seq(), requestBody = None)),
specificationExtensions = Map(
"custom-string-extension-on-path" -> SpecificationExtensionValueString("foobar"),
"custom-list-extension-on-path" -> SpecificationExtensionValueList(
Vector(SpecificationExtensionValueString("foo"), SpecificationExtensionValueString("bar"))
"custom-string-extension-on-path-any-type" -> Json.fromString("foobar"),
"custom-string-extension-on-path-double-type" -> Json.fromLong(123L),
"custom-list-extension-on-path" -> Json.fromValues(
Vector(Json.fromString("foo"), Json.fromString("bar"))
),
"custom-map-extension-on-path" -> SpecificationExtensionValueMap(
"custom-list-extension-on-path-any-type" -> Json.arr(Json.fromString("string")),
"custom-map-extension-on-path" -> Json.fromFields(
Map(
"bazkey" -> SpecificationExtensionValueString("bazval"),
"quuxkey" -> SpecificationExtensionValueList(
Vector(SpecificationExtensionValueString("quux1"), SpecificationExtensionValueString("quux2"))
"bazkey" -> Json.fromString("bazval"),
"quuxkey" -> Json.fromValues(
Vector(Json.fromString("quux1"), Json.fromString("quux2"))
)
)
)
Expand All @@ -827,20 +843,34 @@ object TestHelpers {
responses = Seq(),
requestBody = None,
specificationExtensions = Map(
"custom-string-extension-on-operation" -> SpecificationExtensionValueString("bazquux"),
"custom-list-extension-on-operation" -> SpecificationExtensionValueList(
Vector(SpecificationExtensionValueString("baz"), SpecificationExtensionValueString("quux"))
"custom-string-extension-on-operation" -> Json.fromString("bazquux"),
"custom-list-extension-on-operation" -> Json.fromValues(
Vector(Json.fromString("baz"), Json.fromString("quux"))
),
"custom-map-extension-on-operation" -> SpecificationExtensionValueMap(
"custom-map-extension-on-operation" -> Json.fromFields(
Map(
"bazkey" -> SpecificationExtensionValueString("bazval"),
"quuxkey" -> SpecificationExtensionValueList(
Vector(SpecificationExtensionValueString("quux1"), SpecificationExtensionValueString("quux2"))
"bazkey" -> Json.fromString("bazval"),
"quuxkey" -> Json.fromValues(
Vector(Json.fromString("quux1"), Json.fromString("quux2"))
)
)
)
)
)
),
specificationExtensions = Map(
"custom-string-extension-on-path-any-type" -> Json.fromLong(123L),
"custom-string-extension-on-path-double-type" -> Json.fromDouble(123.456d).get,
"custom-string-extension-on-path" -> Json.fromString("another string"),
"custom-list-extension-on-path" -> Json.arr(),
"custom-list-extension-on-path-any-type" -> Json.arr(Json.fromLong(123L)),
"custom-map-extension-on-path" -> Json.fromFields(Map.empty),
"custom-map-extension-on-path-single-value-type" -> Json.fromFields(
Map(
"bazkey" -> Json.fromString("bazval"),
"quuxkey" -> Json.fromString("quuxval")
)
)
)
)
),
Expand Down

0 comments on commit faf1012

Please sign in to comment.