diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala index 84200a4845..cd5be2730a 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala @@ -61,6 +61,29 @@ object BasicGenerator { |}""".stripMargin headTag -> taggedObj } + + val maybeSpecificationExtensionKeys = doc.paths + .flatMap { p => + p.specificationExtensions.toSeq ++ p.methods.flatMap(_.specificationExtensions.toSeq) + } + .groupBy(_._1) + .map { case (keyName, pairs) => + val values = pairs.map(_._2) + val distinctTypes = values.map(_.tpe).distinct + if (distinctTypes.size != 1) + throw new IllegalArgumentException( + s"specification extensions with the same key are expected to all have the same type. Found $distinctTypes for $keyName" + ) + val `type` = distinctTypes.head + val name = strippedToCamelCase(keyName) + val uncapitalisedName = name.head.toLower + name.tail + val capitalisedName = name.head.toUpper + name.tail + s"""type ${capitalisedName}Extension = ${`type`} + |val ${uncapitalisedName}ExtensionKey = new sttp.tapir.AttributeKey[${capitalisedName}Extension]("$packagePath.$objName.${capitalisedName}Extension") + |""".stripMargin + } + .mkString("\n") + val mainObj = s"""| |package $packagePath | @@ -70,8 +93,9 @@ object BasicGenerator { | |${indent(2)(classGenerator.classDefs(doc, targetScala3, queryParamRefs, normalisedJsonLib, jsonParamRefs).getOrElse(""))} | - |${indent(2)(endpointsByTag.getOrElse(None, ""))} + |${indent(2)(maybeSpecificationExtensionKeys)} | + |${indent(2)(endpointsByTag.getOrElse(None, ""))} |} |""".stripMargin taggedObjs + (objName -> mainObj) @@ -127,4 +151,11 @@ object BasicGenerator { case x => throw new NotImplementedError(s"Not all simple types supported! Found $x") } } + + def strippedToCamelCase(string: String): String = string + .split("[^0-9a-zA-Z$_]") + .filter(_.nonEmpty) + .zipWithIndex + .map { case (part, 0) => part; case (part, _) => part.capitalize } + .mkString } diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala index bb1f64f8ca..c8c032b14b 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala @@ -1,5 +1,5 @@ package sttp.tapir.codegen -import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType} +import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType, strippedToCamelCase} import sttp.tapir.codegen.openapi.models.OpenapiModels.{ OpenapiDocument, OpenapiParameter, @@ -98,13 +98,7 @@ class EndpointGenerator { |$attributeString |""".stripMargin.linesIterator.filterNot(_.trim.isEmpty).mkString("\n") - val name = m.operationId - .getOrElse(m.methodType + p.url.capitalize) - .split("[^0-9a-zA-Z$_]") - .filter(_.nonEmpty) - .zipWithIndex - .map { case (part, 0) => part; case (part, _) => part.capitalize } - .mkString + val name = strippedToCamelCase(m.operationId.getOrElse(m.methodType + p.url.capitalize)) val maybeTargetFileName = if (useHeadTagForObjectNames) m.tags.flatMap(_.headOption) else None val queryParamRefs = m.resolvedParameters .collect { case queryParam: OpenapiParameter if queryParam.in == "query" => queryParam.schema } @@ -217,7 +211,13 @@ class EndpointGenerator { } private def attributes(atts: Map[String, SpecificationExtensionValue]): Option[String] = if (atts.nonEmpty) Some { - atts.map { case (k, v) => s""".attribute[${v.tpe}](new AttributeKey[${v.tpe}]("${k}"), ${v.render})""" }.mkString("\n") + atts + .map { case (k, v) => + val camelCaseK = strippedToCamelCase(k) + val uncapitalisedName = camelCaseK.head.toLower + camelCaseK.tail + s""".attribute(${uncapitalisedName}ExtensionKey, ${v.render})""" + } + .mkString("\n") } else None diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala index f858a14fb7..b00c584d1e 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala @@ -259,19 +259,23 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { useHeadTagForObjectNames = false, jsonSerdeLib = "circe" )("TapirGeneratedEndpoints") + generatedCode shouldCompile () generatedCode should include( - """.attribute[String](new AttributeKey[String]("custom-string-extension-on-path"), "foobar")""" + """.attribute(customStringExtensionOnPathExtensionKey, "foobar")""" ) generatedCode should include( - """.attribute[String](new AttributeKey[String]("custom-string-extension-on-operation"), "bazquux")""" + """.attribute(customStringExtensionOnOperationExtensionKey, "bazquux")""" ) generatedCode should include( - """.attribute[Seq[String]](new AttributeKey[Seq[String]]("custom-list-extension-on-operation"), Vector("baz", "quux"))""" + """.attribute(customListExtensionOnOperationExtensionKey, Vector("baz", "quux"))""" ) generatedCode should include( - """.attribute[Map[String, Any]](new AttributeKey[Map[String, Any]]("custom-map-extension-on-path"), Map("bazkey" -> "bazval", "quuxkey" -> Vector("quux1", "quux2"))""" + """.attribute(customMapExtensionOnPathExtensionKey, Map("bazkey" -> "bazval", "quuxkey" -> Vector("quux1", "quux2"))""" + ) + generatedCode should include("""type CustomMapExtensionOnOperationExtension = Map[String, Any]""") + generatedCode should include( + """val customMapExtensionOnOperationExtensionKey = new sttp.tapir.AttributeKey[CustomMapExtensionOnOperationExtension]("sttp.tapir.generated.TapirGeneratedEndpoints.CustomMapExtensionOnOperationExtension")""".stripMargin ) - generatedCode shouldCompile () } }