From 4de826cf2eeeb901ae27fdf7ee12980a6de05882 Mon Sep 17 00:00:00 2001
From: Hugh Simpson <hsimpson@rzsoftware.com>
Date: Wed, 13 Mar 2024 10:47:27 +0000
Subject: [PATCH] parse specification extensions on paths and operations

---
 .../openapi/models/OpenapiModels.scala        | 40 +++++++-
 .../sttp/tapir/codegen/TestHelpers.scala      | 97 ++++++++++++++++---
 .../codegen/models/ModelParserSpec.scala      | 11 +++
 3 files changed, 131 insertions(+), 17 deletions(-)

diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiModels.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiModels.scala
index fc22b87521..95ce6062d3 100644
--- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiModels.scala
+++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiModels.scala
@@ -6,6 +6,14 @@ import cats.syntax.either._
 import OpenapiSchemaType.OpenapiSchemaRef
 // https://swagger.io/specification/
 object OpenapiModels {
+  sealed trait SpecificationExtensionValue
+  case object SpecificationExtensionValueNull extends SpecificationExtensionValue
+  case class SpecificationExtensionValueBoolean(value: Boolean) extends SpecificationExtensionValue
+  case class SpecificationExtensionValueLong(value: Long) extends SpecificationExtensionValue
+  case class SpecificationExtensionValueDouble(value: Double) extends SpecificationExtensionValue
+  case class SpecificationExtensionValueString(value: String) extends SpecificationExtensionValue
+  case class SpecificationExtensionValueList(values: Seq[SpecificationExtensionValue]) extends SpecificationExtensionValue
+  case class SpecificationExtensionValueMap(kvs: Map[String, SpecificationExtensionValue]) extends SpecificationExtensionValue
 
   sealed trait Resolvable[T] {
     def resolve(input: Map[String, T]): T
@@ -35,7 +43,8 @@ object OpenapiModels {
   case class OpenapiPath(
       url: String,
       methods: Seq[OpenapiPathMethod],
-      parameters: Seq[Resolvable[OpenapiParameter]] = Nil
+      parameters: Seq[Resolvable[OpenapiParameter]] = Nil,
+      specificationExtensions: Map[String, SpecificationExtensionValue] = Map.empty
   )
 
   case class OpenapiPathMethod(
@@ -46,7 +55,8 @@ object OpenapiModels {
       security: Seq[Seq[String]] = Nil,
       summary: Option[String] = None,
       tags: Option[Seq[String]] = None,
-      operationId: Option[String] = None
+      operationId: Option[String] = None,
+      specificationExtensions: Map[String, SpecificationExtensionValue] = Map.empty
   ) {
     def resolvedParameters: Seq[OpenapiParameter] = parameters.collect { case Resolved(t) => t }
     def withResolvedParentParameters(
@@ -166,6 +176,19 @@ object OpenapiModels {
   implicit def ResolvableDecoder[T: Decoder]: Decoder[Resolvable[T]] = { (c: HCursor) =>
     c.as[T].map(Resolved(_)).orElse(c.as[OpenapiSchemaRef].map(r => Ref(r.name)))
   }
+
+  def decodeSpecificationExtensionValue(json: Json): SpecificationExtensionValue =
+    json.fold(
+      SpecificationExtensionValueNull,
+      SpecificationExtensionValueBoolean.apply,
+      n => n.toLong.map(SpecificationExtensionValueLong.apply).getOrElse(SpecificationExtensionValueDouble(n.toDouble)),
+      SpecificationExtensionValueString.apply,
+      arr => SpecificationExtensionValueList(arr.map(decodeSpecificationExtensionValue)),
+      obj => SpecificationExtensionValueMap(obj.toMap.map { case (k, v) => k -> decodeSpecificationExtensionValue(v) })
+    )
+  implicit val SpecificationExtensionValueDecoder: Decoder[SpecificationExtensionValue] = { (c: HCursor) =>
+    Right(decodeSpecificationExtensionValue(c.value))
+  }
   implicit val PartialOpenapiPathMethodDecoder: Decoder[OpenapiPathMethod] = { (c: HCursor) =>
     for {
       parameters <- c.getOrElse[Seq[Resolvable[OpenapiParameter]]]("parameters")(Nil)
@@ -175,6 +198,10 @@ object OpenapiModels {
       summary <- c.get[Option[String]]("summary")
       tags <- c.get[Option[Seq[String]]]("tags")
       operationId <- c.get[Option[String]]("operationId")
+      specificationExtensionKeys = c.keys.toSeq.flatMap(_.filter(_.startsWith("x-")))
+      specificationExtensions = specificationExtensionKeys
+        .flatMap(key => c.downField(key).as[SpecificationExtensionValue].toOption.map(key.stripPrefix("x-") -> _))
+        .toMap
     } yield {
       OpenapiPathMethod(
         "--partial--",
@@ -184,7 +211,8 @@ object OpenapiModels {
         security.map(_.keys.toSeq),
         summary,
         tags,
-        operationId
+        operationId,
+        specificationExtensions
       )
     }
   }
@@ -197,7 +225,11 @@ object OpenapiModels {
         .map(_.getOrElse(Nil))
       methods <- List("get", "put", "post", "delete", "options", "head", "patch", "connect", "trace")
         .traverse(method => c.downField(method).as[Option[OpenapiPathMethod]].map(_.map(_.copy(methodType = method))))
-    } yield OpenapiPath("--partial--", methods.flatten, parameters)
+      specificationExtensionKeys = c.keys.toSeq.flatMap(_.filter(_.startsWith("x-")))
+      specificationExtensions = specificationExtensionKeys
+        .flatMap(key => c.downField(key).as[SpecificationExtensionValue].toOption.map(key.stripPrefix("x-") -> _))
+        .toMap
+    } yield OpenapiPath("--partial--", methods.flatten, parameters, specificationExtensions)
   }
 
   implicit val OpenapiPathsDecoder: Decoder[Seq[OpenapiPath]] = { (c: HCursor) =>
diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/TestHelpers.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/TestHelpers.scala
index e84024803a..7f44a7a8ec 100644
--- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/TestHelpers.scala
+++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/TestHelpers.scala
@@ -1,19 +1,7 @@
 package sttp.tapir.codegen
 
 import sttp.tapir.codegen.openapi.models.OpenapiComponent
-import sttp.tapir.codegen.openapi.models.OpenapiModels.{
-  OpenapiDocument,
-  OpenapiInfo,
-  OpenapiParameter,
-  OpenapiPath,
-  OpenapiPathMethod,
-  OpenapiRequestBody,
-  OpenapiRequestBodyContent,
-  OpenapiResponse,
-  OpenapiResponseContent,
-  Ref,
-  Resolved
-}
+import sttp.tapir.codegen.openapi.models.OpenapiModels._
 import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
   OpenapiSchemaArray,
   OpenapiSchemaConstantString,
@@ -577,4 +565,87 @@ object TestHelpers {
       )
     )
   )
+
+  val specificationExtensionYaml =
+    """
+     |openapi: 3.1.0
+     |info:
+     |  title: hello goodbye
+     |  version: '1.0'
+     |paths:
+     |  /hello:
+     |    x-custom-string-extension-on-path: foobar
+     |    x-custom-list-extension-on-path:
+     |      - foo
+     |      - bar
+     |    x-custom-map-extension-on-path:
+     |      bazkey: bazval
+     |      quuxkey:
+     |        - quux1
+     |        - quux2
+     |    post:
+     |      responses: {}
+     |  /goodbye:
+     |    delete:
+     |      x-custom-string-extension-on-operation: bazquux
+     |      x-custom-list-extension-on-operation:
+     |        - baz
+     |        - quux
+     |      x-custom-map-extension-on-operation:
+     |        bazkey: bazval
+     |        quuxkey:
+     |          - quux1
+     |          - quux2
+     |      responses: {}""".stripMargin
+
+  val specificationExtensionDocs = OpenapiDocument(
+    "3.1.0",
+    OpenapiInfo("hello goodbye", "1.0"),
+    Seq(
+      OpenapiPath(
+        url = "/hello",
+        methods = Seq(OpenapiPathMethod(methodType = "post", parameters = Seq(), responses = Seq(), requestBody = None)),
+        specificationExtensions = Map(
+          "custom-string-extension-on-path" -> SpecificationExtensionValueString("foobar"),
+          "custom-list-extension-on-path" -> SpecificationExtensionValueList(
+            Vector(SpecificationExtensionValueString("foo"), SpecificationExtensionValueString("bar"))
+          ),
+          "custom-map-extension-on-path" -> SpecificationExtensionValueMap(
+            Map(
+              "bazkey" -> SpecificationExtensionValueString("bazval"),
+              "quuxkey" -> SpecificationExtensionValueList(
+                Vector(SpecificationExtensionValueString("quux1"), SpecificationExtensionValueString("quux2"))
+              )
+            )
+          )
+        )
+      ),
+      OpenapiPath(
+        url = "/goodbye",
+        methods = Seq(
+          OpenapiPathMethod(
+            methodType = "delete",
+            parameters = Seq(),
+            responses = Seq(),
+            requestBody = None,
+            specificationExtensions = Map(
+              "custom-string-extension-on-operation" -> SpecificationExtensionValueString("bazquux"),
+              "custom-list-extension-on-operation" -> SpecificationExtensionValueList(
+                Vector(SpecificationExtensionValueString("baz"), SpecificationExtensionValueString("quux"))
+              ),
+              "custom-map-extension-on-operation" -> SpecificationExtensionValueMap(
+                Map(
+                  "bazkey" -> SpecificationExtensionValueString("bazval"),
+                  "quuxkey" -> SpecificationExtensionValueList(
+                    Vector(SpecificationExtensionValueString("quux1"), SpecificationExtensionValueString("quux2"))
+                  )
+                )
+              )
+            )
+          )
+        )
+      )
+    ),
+    None
+  )
 }
diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/models/ModelParserSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/models/ModelParserSpec.scala
index 1489ad50bb..fd1a325826 100644
--- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/models/ModelParserSpec.scala
+++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/models/ModelParserSpec.scala
@@ -167,4 +167,15 @@ class ModelParserSpec extends AnyFlatSpec with Matchers with Checkers {
       OpenapiSchemaEnum("string", Seq(OpenapiSchemaConstantString("paperback"), OpenapiSchemaConstantString("hardback")), false)
     )
   }
+
+  it should "parse endpoint with simple specification extensions" in {
+    val res = parser
+      .parse(TestHelpers.specificationExtensionYaml)
+      .leftMap(err => err: Error)
+      .flatMap(_.as[OpenapiDocument])
+
+    res shouldBe (Right(
+      TestHelpers.specificationExtensionDocs
+    ))
+  }
 }