Skip to content

Commit

Permalink
make the key a fully qualified type; check that types are consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
hughsimpson committed Mar 18, 2024
1 parent d84eefe commit bee2ad3
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,29 @@ object BasicGenerator {
|}""".stripMargin
headTag -> taggedObj
}

val maybeSpecificationExtensionKeys = doc.paths
.flatMap { p =>
p.specificationExtensions.toSeq ++ p.methods.flatMap(_.specificationExtensions.toSeq)
}
.groupBy(_._1)
.map { case (keyName, pairs) =>
val values = pairs.map(_._2)
val distinctTypes = values.map(_.tpe).distinct
if (distinctTypes.size != 1)
throw new IllegalArgumentException(
s"specification extensions with the same key are expected to all have the same type. Found $distinctTypes for $keyName"
)
val `type` = distinctTypes.head
val name = strippedToCamelCase(keyName)
val uncapitalisedName = name.head.toLower + name.tail
val capitalisedName = name.head.toUpper + name.tail
s"""type ${capitalisedName}Extension = ${`type`}
|val ${uncapitalisedName}ExtensionKey = new sttp.tapir.AttributeKey[${capitalisedName}Extension]("$packagePath.$objName.${capitalisedName}Extension")
|""".stripMargin
}
.mkString("\n")

val mainObj = s"""|
|package $packagePath
|
Expand All @@ -70,8 +93,9 @@ object BasicGenerator {
|
|${indent(2)(classGenerator.classDefs(doc, targetScala3, queryParamRefs, normalisedJsonLib, jsonParamRefs).getOrElse(""))}
|
|${indent(2)(endpointsByTag.getOrElse(None, ""))}
|${indent(2)(maybeSpecificationExtensionKeys)}
|
|${indent(2)(endpointsByTag.getOrElse(None, ""))}
|}
|""".stripMargin
taggedObjs + (objName -> mainObj)
Expand Down Expand Up @@ -127,4 +151,11 @@ object BasicGenerator {
case x => throw new NotImplementedError(s"Not all simple types supported! Found $x")
}
}

def strippedToCamelCase(string: String): String = string
.split("[^0-9a-zA-Z$_]")
.filter(_.nonEmpty)
.zipWithIndex
.map { case (part, 0) => part; case (part, _) => part.capitalize }
.mkString
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package sttp.tapir.codegen
import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType}
import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType, strippedToCamelCase}
import sttp.tapir.codegen.openapi.models.OpenapiModels.{
OpenapiDocument,
OpenapiParameter,
Expand Down Expand Up @@ -98,13 +98,7 @@ class EndpointGenerator {
|$attributeString
|""".stripMargin.linesIterator.filterNot(_.trim.isEmpty).mkString("\n")

val name = m.operationId
.getOrElse(m.methodType + p.url.capitalize)
.split("[^0-9a-zA-Z$_]")
.filter(_.nonEmpty)
.zipWithIndex
.map { case (part, 0) => part; case (part, _) => part.capitalize }
.mkString
val name = strippedToCamelCase(m.operationId.getOrElse(m.methodType + p.url.capitalize))
val maybeTargetFileName = if (useHeadTagForObjectNames) m.tags.flatMap(_.headOption) else None
val queryParamRefs = m.resolvedParameters
.collect { case queryParam: OpenapiParameter if queryParam.in == "query" => queryParam.schema }
Expand Down Expand Up @@ -217,7 +211,13 @@ class EndpointGenerator {
}

private def attributes(atts: Map[String, SpecificationExtensionValue]): Option[String] = if (atts.nonEmpty) Some {
atts.map { case (k, v) => s""".attribute[${v.tpe}](new AttributeKey[${v.tpe}]("${k}"), ${v.render})""" }.mkString("\n")
atts
.map { case (k, v) =>
val camelCaseK = strippedToCamelCase(k)
val uncapitalisedName = camelCaseK.head.toLower + camelCaseK.tail
s""".attribute(${uncapitalisedName}ExtensionKey, ${v.render})"""
}
.mkString("\n")
}
else None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,19 +259,23 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
useHeadTagForObjectNames = false,
jsonSerdeLib = "circe"
)("TapirGeneratedEndpoints")
generatedCode shouldCompile ()
generatedCode should include(
""".attribute[String](new AttributeKey[String]("custom-string-extension-on-path"), "foobar")"""
""".attribute(customStringExtensionOnPathExtensionKey, "foobar")"""
)
generatedCode should include(
""".attribute[String](new AttributeKey[String]("custom-string-extension-on-operation"), "bazquux")"""
""".attribute(customStringExtensionOnOperationExtensionKey, "bazquux")"""
)
generatedCode should include(
""".attribute[Seq[String]](new AttributeKey[Seq[String]]("custom-list-extension-on-operation"), Vector("baz", "quux"))"""
""".attribute(customListExtensionOnOperationExtensionKey, Vector("baz", "quux"))"""
)
generatedCode should include(
""".attribute[Map[String, Any]](new AttributeKey[Map[String, Any]]("custom-map-extension-on-path"), Map("bazkey" -> "bazval", "quuxkey" -> Vector("quux1", "quux2"))"""
""".attribute(customMapExtensionOnPathExtensionKey, Map("bazkey" -> "bazval", "quuxkey" -> Vector("quux1", "quux2"))"""
)
generatedCode should include("""type CustomMapExtensionOnOperationExtension = Map[String, Any]""")
generatedCode should include(
"""val customMapExtensionOnOperationExtensionKey = new sttp.tapir.AttributeKey[CustomMapExtensionOnOperationExtension]("sttp.tapir.generated.TapirGeneratedEndpoints.CustomMapExtensionOnOperationExtension")""".stripMargin
)
generatedCode shouldCompile ()
}

}

0 comments on commit bee2ad3

Please sign in to comment.