Skip to content

Commit

Permalink
generate attributes on endpoints, based on any parsed specification e…
Browse files Browse the repository at this point in the history
…xtensions
  • Loading branch information
hughsimpson committed Mar 15, 2024
1 parent aa40a56 commit cde47d7
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
package sttp.tapir.codegen
import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType}
import sttp.tapir.codegen.openapi.models.OpenapiModels.{OpenapiDocument, OpenapiParameter, OpenapiPath, OpenapiRequestBody, OpenapiResponse}
import sttp.tapir.codegen.openapi.models.OpenapiModels.{
OpenapiDocument,
OpenapiParameter,
OpenapiPath,
OpenapiRequestBody,
OpenapiResponse,
SpecificationExtensionValue
}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaArray,
OpenapiSchemaBinary,
Expand Down Expand Up @@ -60,6 +67,18 @@ class EndpointGenerator {
.map(_.withResolvedParentParameters(parameters, p.parameters))
.map { m =>
implicit val location: Location = Location(p.url, m.methodType)

val attributeString = {
val pathAttributes = attributes(p.specificationExtensions)
val operationAttributes = attributes(m.specificationExtensions)
(pathAttributes, operationAttributes) match {
case (None, None) => ""
case (Some(atts), None) => indent(2)(atts)
case (None, Some(atts)) => indent(2)(atts)
case (Some(pathAtts), Some(operationAtts)) => indent(2)(pathAtts + "\n" + operationAtts)
}
}

val definition =
s"""|endpoint
| .${m.methodType}
Expand All @@ -68,7 +87,8 @@ class EndpointGenerator {
|${indent(2)(ins(m.resolvedParameters, m.requestBody))}
|${indent(2)(outs(m.responses))}
|${indent(2)(tags(m.tags))}
|""".stripMargin
|$attributeString
|""".stripMargin.linesIterator.filterNot(_.trim.isEmpty).mkString("\n")

val name = m.operationId
.getOrElse(m.methodType + p.url.capitalize)
Expand Down Expand Up @@ -166,6 +186,11 @@ class EndpointGenerator {
openapiTags.map(_.distinct.mkString(".tags(List(\"", "\", \"", "\"))")).mkString
}

private def attributes(atts: Map[String, SpecificationExtensionValue]): Option[String] = if (atts.nonEmpty) Some {
atts.map { case (k, v) => s""".attribute[${v.tpe}](new AttributeKey[${v.tpe}]("${k}"), ${v.render})""" }.mkString("\n")
}
else None

// treats redirects as ok
private val okStatus = """([23]\d\d)""".r
private val errorStatus = """([45]\d\d)""".r
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,42 @@ import cats.syntax.either._
import OpenapiSchemaType.OpenapiSchemaRef
// https://swagger.io/specification/
object OpenapiModels {
sealed trait SpecificationExtensionValue
case object SpecificationExtensionValueNull extends SpecificationExtensionValue
case class SpecificationExtensionValueBoolean(value: Boolean) extends SpecificationExtensionValue
case class SpecificationExtensionValueLong(value: Long) extends SpecificationExtensionValue
case class SpecificationExtensionValueDouble(value: Double) extends SpecificationExtensionValue
case class SpecificationExtensionValueString(value: String) extends SpecificationExtensionValue
case class SpecificationExtensionValueList(values: Seq[SpecificationExtensionValue]) extends SpecificationExtensionValue
case class SpecificationExtensionValueMap(kvs: Map[String, SpecificationExtensionValue]) extends SpecificationExtensionValue
sealed trait SpecificationExtensionValue {
def tpe: String
def render: String
def value: Any
}
case object SpecificationExtensionValueNull extends SpecificationExtensionValue {
val tpe = "Null"
val render = "null"
val value = null
}
case class SpecificationExtensionValueBoolean(value: Boolean) extends SpecificationExtensionValue {
val render = value.toString
val tpe = "Boolean"
}
case class SpecificationExtensionValueLong(value: Long) extends SpecificationExtensionValue {
val render = s"${value}L"
val tpe = "Long"
}
case class SpecificationExtensionValueDouble(value: Double) extends SpecificationExtensionValue {
val render = s"${value}d"
val tpe = "Double"
}
case class SpecificationExtensionValueString(value: String) extends SpecificationExtensionValue {
val render = '"' +: value :+ '"'
val tpe = "String"
}
case class SpecificationExtensionValueList(values: Seq[SpecificationExtensionValue]) extends SpecificationExtensionValue {
val render = s"Vector(${values.map(_.render).mkString(", ")})"
def tpe = values.map(_.tpe).distinct match { case single +: Nil => s"Seq[$single]"; case _ => "Seq[Any]" }
def value = values.map(_.value)
}
case class SpecificationExtensionValueMap(kvs: Map[String, SpecificationExtensionValue]) extends SpecificationExtensionValue {
val render = s"Map(${kvs.map { case (k, v) => s""""$k" -> ${v.render}""" }.mkString(", ")})"
def tpe = kvs.values.map(_.tpe).toSeq.distinct match { case single +: Nil => s"Map[String, $single]"; case _ => "Map[String, Any]" }
def value = kvs.map { case (k, v) => k -> v.value }
}

sealed trait Resolvable[T] {
def resolve(input: Map[String, T]): T
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,29 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
)
generatedCode shouldCompile ()
}

it should "generate attributes for specification extensions on path and operation objects" in {
val doc = TestHelpers.specificationExtensionDocs
val generatedCode = BasicGenerator.generateObjects(
doc,
"sttp.tapir.generated",
"TapirGeneratedEndpoints",
targetScala3 = false,
useHeadTagForObjectNames = false
)("TapirGeneratedEndpoints")
generatedCode should include(
""".attribute[String](new AttributeKey[String]("custom-string-extension-on-path"), "foobar")"""
)
generatedCode should include(
""".attribute[String](new AttributeKey[String]("custom-string-extension-on-operation"), "bazquux")"""
)
generatedCode should include(
""".attribute[Seq[String]](new AttributeKey[Seq[String]]("custom-list-extension-on-operation"), Vector("baz", "quux"))"""
)
generatedCode should include(
""".attribute[Map[String, Any]](new AttributeKey[Map[String, Any]]("custom-map-extension-on-path"), Map("bazkey" -> "bazval", "quuxkey" -> Vector("quux1", "quux2"))"""
)
generatedCode shouldCompile ()
}

}

0 comments on commit cde47d7

Please sign in to comment.