diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala index 239a0f77b8..bb1f64f8ca 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala @@ -1,6 +1,13 @@ package sttp.tapir.codegen import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType} -import sttp.tapir.codegen.openapi.models.OpenapiModels.{OpenapiDocument, OpenapiParameter, OpenapiPath, OpenapiRequestBody, OpenapiResponse} +import sttp.tapir.codegen.openapi.models.OpenapiModels.{ + OpenapiDocument, + OpenapiParameter, + OpenapiPath, + OpenapiRequestBody, + OpenapiResponse, + SpecificationExtensionValue +} import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{ OpenapiSchemaArray, OpenapiSchemaBinary, @@ -68,6 +75,18 @@ class EndpointGenerator { .map(_.withResolvedParentParameters(parameters, p.parameters)) .map { m => implicit val location: Location = Location(p.url, m.methodType) + + val attributeString = { + val pathAttributes = attributes(p.specificationExtensions) + val operationAttributes = attributes(m.specificationExtensions) + (pathAttributes, operationAttributes) match { + case (None, None) => "" + case (Some(atts), None) => indent(2)(atts) + case (None, Some(atts)) => indent(2)(atts) + case (Some(pathAtts), Some(operationAtts)) => indent(2)(pathAtts + "\n" + operationAtts) + } + } + val definition = s"""|endpoint | .${m.methodType} @@ -76,6 +95,7 @@ class EndpointGenerator { |${indent(2)(ins(m.resolvedParameters, m.requestBody))} |${indent(2)(outs(m.responses))} |${indent(2)(tags(m.tags))} + |$attributeString |""".stripMargin.linesIterator.filterNot(_.trim.isEmpty).mkString("\n") val name = m.operationId @@ -196,6 +216,11 @@ class EndpointGenerator { openapiTags.map(_.distinct.mkString(".tags(List(\"", "\", \"", "\"))")).mkString } + private def attributes(atts: Map[String, SpecificationExtensionValue]): Option[String] = if (atts.nonEmpty) Some { + atts.map { case (k, v) => s""".attribute[${v.tpe}](new AttributeKey[${v.tpe}]("${k}"), ${v.render})""" }.mkString("\n") + } + else None + // treats redirects as ok private val okStatus = """([23]\d\d)""".r private val errorStatus = """([45]\d\d)""".r diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiModels.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiModels.scala index 95ce6062d3..83dbc678a3 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiModels.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiModels.scala @@ -6,14 +6,42 @@ import cats.syntax.either._ import OpenapiSchemaType.OpenapiSchemaRef // https://swagger.io/specification/ object OpenapiModels { - sealed trait SpecificationExtensionValue - case object SpecificationExtensionValueNull extends SpecificationExtensionValue - case class SpecificationExtensionValueBoolean(value: Boolean) extends SpecificationExtensionValue - case class SpecificationExtensionValueLong(value: Long) extends SpecificationExtensionValue - case class SpecificationExtensionValueDouble(value: Double) extends SpecificationExtensionValue - case class SpecificationExtensionValueString(value: String) extends SpecificationExtensionValue - case class SpecificationExtensionValueList(values: Seq[SpecificationExtensionValue]) extends SpecificationExtensionValue - case class SpecificationExtensionValueMap(kvs: Map[String, SpecificationExtensionValue]) extends SpecificationExtensionValue + sealed trait SpecificationExtensionValue { + def tpe: String + def render: String + def value: Any + } + case object SpecificationExtensionValueNull extends SpecificationExtensionValue { + val tpe = "Null" + val render = "null" + val value = null + } + case class SpecificationExtensionValueBoolean(value: Boolean) extends SpecificationExtensionValue { + val render = value.toString + val tpe = "Boolean" + } + case class SpecificationExtensionValueLong(value: Long) extends SpecificationExtensionValue { + val render = s"${value}L" + val tpe = "Long" + } + case class SpecificationExtensionValueDouble(value: Double) extends SpecificationExtensionValue { + val render = s"${value}d" + val tpe = "Double" + } + case class SpecificationExtensionValueString(value: String) extends SpecificationExtensionValue { + val render = '"' +: value :+ '"' + val tpe = "String" + } + case class SpecificationExtensionValueList(values: Seq[SpecificationExtensionValue]) extends SpecificationExtensionValue { + val render = s"Vector(${values.map(_.render).mkString(", ")})" + def tpe = values.map(_.tpe).distinct match { case single +: Nil => s"Seq[$single]"; case _ => "Seq[Any]" } + def value = values.map(_.value) + } + case class SpecificationExtensionValueMap(kvs: Map[String, SpecificationExtensionValue]) extends SpecificationExtensionValue { + val render = s"Map(${kvs.map { case (k, v) => s""""$k" -> ${v.render}""" }.mkString(", ")})" + def tpe = kvs.values.map(_.tpe).toSeq.distinct match { case single +: Nil => s"Map[String, $single]"; case _ => "Map[String, Any]" } + def value = kvs.map { case (k, v) => k -> v.value } + } sealed trait Resolvable[T] { def resolve(input: Map[String, T]): T diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala index e5bab6ecf6..f858a14fb7 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala @@ -249,4 +249,29 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { generatedCode shouldCompile () } + it should "generate attributes for specification extensions on path and operation objects" in { + val doc = TestHelpers.specificationExtensionDocs + val generatedCode = BasicGenerator.generateObjects( + doc, + "sttp.tapir.generated", + "TapirGeneratedEndpoints", + targetScala3 = false, + useHeadTagForObjectNames = false, + jsonSerdeLib = "circe" + )("TapirGeneratedEndpoints") + generatedCode should include( + """.attribute[String](new AttributeKey[String]("custom-string-extension-on-path"), "foobar")""" + ) + generatedCode should include( + """.attribute[String](new AttributeKey[String]("custom-string-extension-on-operation"), "bazquux")""" + ) + generatedCode should include( + """.attribute[Seq[String]](new AttributeKey[Seq[String]]("custom-list-extension-on-operation"), Vector("baz", "quux"))""" + ) + generatedCode should include( + """.attribute[Map[String, Any]](new AttributeKey[Map[String, Any]]("custom-map-extension-on-path"), Map("bazkey" -> "bazval", "quuxkey" -> Vector("quux1", "quux2"))""" + ) + generatedCode shouldCompile () + } + }