From 05f9b2b4a67d8622cb8642e5f55d442ffc328cae Mon Sep 17 00:00:00 2001 From: Markus Sammallahti <6243605+markussammallahti@users.noreply.github.com> Date: Thu, 10 Aug 2023 17:30:54 +0300 Subject: [PATCH] Add security support for codegen (#3094) --- .../sttp/tapir/codegen/BasicGenerator.scala | 1 + .../tapir/codegen/EndpointGenerator.scala | 35 ++++++++- .../openapi/models/OpenapiComponent.scala | 14 ++-- .../openapi/models/OpenapiModels.scala | 28 ++++--- .../models/OpenapiSecuritySchemeType.scala | 49 ++++++++++++ .../tapir/codegen/EndpointGeneratorSpec.scala | 76 ++++++++++++++++++ .../sttp/tapir/codegen/TestHelpers.scala | 78 ++++++++++++++++++- .../codegen/models/ModelParserSpec.scala | 22 ++++++ .../codegen/models/SchemaParserSpec.scala | 47 +++++++++++ 9 files changed, 327 insertions(+), 23 deletions(-) create mode 100644 openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiSecuritySchemeType.scala diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala index 14c828ef68..1b9515f887 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala @@ -40,6 +40,7 @@ object BasicGenerator { private[codegen] def imports: String = """import sttp.tapir._ + |import sttp.tapir.model._ |import sttp.tapir.json.circe._ |import sttp.tapir.generic.auto._ |import io.circe.generic.auto._ diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala index 9417699a43..3504ab127a 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala @@ -2,7 +2,7 @@ package sttp.tapir.codegen import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType} import sttp.tapir.codegen.openapi.models.OpenapiModels.{OpenapiDocument, OpenapiParameter, OpenapiPath, OpenapiRequestBody, OpenapiResponse} -import sttp.tapir.codegen.openapi.models.OpenapiSchemaType +import sttp.tapir.codegen.openapi.models.{OpenapiComponent, OpenapiSchemaType, OpenapiSecuritySchemeType} import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{OpenapiSchemaArray, OpenapiSchemaSimpleType} class EndpointGenerator { @@ -10,8 +10,8 @@ class EndpointGenerator { private[codegen] def allEndpoints: String = "generatedEndpoints" def endpointDefs(doc: OpenapiDocument): String = { - val ps = Option(doc.components).flatten.map(_.parameters) getOrElse Map.empty - val ge = doc.paths.flatMap(generatedEndpoints(ps)) + val components = Option(doc.components).flatten + val ge = doc.paths.flatMap(generatedEndpoints(components)) val definitions = ge .map { case (name, definition) => s"""|lazy val $name = @@ -27,12 +27,16 @@ class EndpointGenerator { |""".stripMargin } - private[codegen] def generatedEndpoints(parameters: Map[String, OpenapiParameter])(p: OpenapiPath): Seq[(String, String)] = { + private[codegen] def generatedEndpoints(components: Option[OpenapiComponent])(p: OpenapiPath): Seq[(String, String)] = { + val parameters = components.map(_.parameters).getOrElse(Map.empty) + val securitySchemes = components.map(_.securitySchemes).getOrElse(Map.empty) + p.methods.map(_.withResolvedParentParameters(parameters, p.parameters)).map { m => val definition = s"""|endpoint | .${m.methodType} | ${urlMapper(p.url, m.resolvedParameters)} + |${indent(2)(security(securitySchemes, m.security))} |${indent(2)(ins(m.resolvedParameters, m.requestBody))} |${indent(2)(outs(m.responses))} |${indent(2)(tags(m.tags))} @@ -71,6 +75,29 @@ class EndpointGenerator { ".in((" + inPath.mkString(" / ") + "))" } + private def security(securitySchemes: Map[String, OpenapiSecuritySchemeType], security: Seq[Seq[String]]) = { + if (security.size > 1 || security.exists(_.size > 1)) + throw new NotImplementedError("We can handle only single security entry!") + + security.headOption + .flatMap(_.headOption) + .fold("") { schemeName => + securitySchemes.get(schemeName) match { + case Some(OpenapiSecuritySchemeType.OpenapiSecuritySchemeBearerType) => + ".securityIn(auth.bearer[String]())" + + case Some(OpenapiSecuritySchemeType.OpenapiSecuritySchemeBasicType) => + ".securityIn(auth.basic[UsernamePassword]())" + + case Some(OpenapiSecuritySchemeType.OpenapiSecuritySchemeApiKeyType(in, name)) => + s""".securityIn(auth.apiKey($in[String]("$name")))""" + + case None => + throw new Error(s"Unknown security scheme $schemeName!") + } + } + } + private def ins(parameters: Seq[OpenapiParameter], requestBody: Option[OpenapiRequestBody]): String = { // .in(query[Limit]("limit").description("Maximum number of books to retrieve")) // .in(header[AuthToken]("X-Auth-Token")) diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiComponent.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiComponent.scala index b5258d4264..b1711dcbe0 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiComponent.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiComponent.scala @@ -1,11 +1,10 @@ package sttp.tapir.codegen.openapi.models -import cats.syntax.either._ - import OpenapiModels.OpenapiParameter case class OpenapiComponent( schemas: Map[String, OpenapiSchemaType], + securitySchemes: Map[String, OpenapiSecuritySchemeType] = Map.empty, parameters: Map[String, OpenapiParameter] = Map.empty ) @@ -14,10 +13,15 @@ object OpenapiComponent { implicit val OpenapiComponentDecoder: Decoder[OpenapiComponent] = { (c: HCursor) => for { - schemas <- c.downField("schemas").as[Map[String, OpenapiSchemaType]] - parameters <- c.downField("parameters").as[Option[Map[String, OpenapiParameter]]].map(_.getOrElse(Map.empty)) + schemas <- c.getOrElse[Map[String, OpenapiSchemaType]]("schemas")(Map.empty) + securitySchemes <- c.getOrElse[Map[String, OpenapiSecuritySchemeType]]("securitySchemes")(Map.empty) + parameters <- c.getOrElse[Map[String, OpenapiParameter]]("parameters")(Map.empty) } yield { - OpenapiComponent(schemas, parameters.map { case (k, v) => s"#/components/parameters/$k" -> v }) + OpenapiComponent( + schemas, + securitySchemes, + parameters.map { case (k, v) => s"#/components/parameters/$k" -> v } + ) } } } diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiModels.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiModels.scala index 88fadc8455..c1298c0541 100644 --- a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiModels.scala +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiModels.scala @@ -43,6 +43,7 @@ object OpenapiModels { parameters: Seq[Resolvable[OpenapiParameter]], responses: Seq[OpenapiResponse], requestBody: Option[OpenapiRequestBody], + security: Seq[Seq[String]] = Nil, summary: Option[String] = None, tags: Option[Seq[String]] = None, operationId: Option[String] = None @@ -167,17 +168,24 @@ object OpenapiModels { } implicit val PartialOpenapiPathMethodDecoder: Decoder[OpenapiPathMethod] = { (c: HCursor) => for { - parameters <- c - .downField("parameters") - .as[Option[Seq[Resolvable[OpenapiParameter]]]] - .map(_.getOrElse(Nil)) - responses <- c.downField("responses").as[Seq[OpenapiResponse]] - requestBody <- c.downField("requestBody").as[Option[OpenapiRequestBody]] - summary <- c.downField("summary").as[Option[String]] - tags <- c.downField("tags").as[Option[Seq[String]]] - operationId <- c.downField("operationId").as[Option[String]] + parameters <- c.getOrElse[Seq[Resolvable[OpenapiParameter]]]("parameters")(Nil) + responses <- c.get[Seq[OpenapiResponse]]("responses") + requestBody <- c.get[Option[OpenapiRequestBody]]("requestBody") + security <- c.getOrElse[Seq[Map[String, Seq[String]]]]("security")(Nil) + summary <- c.get[Option[String]]("summary") + tags <- c.get[Option[Seq[String]]]("tags") + operationId <- c.get[Option[String]]("operationId") } yield { - OpenapiPathMethod("--partial--", parameters, responses, requestBody, summary, tags, operationId) + OpenapiPathMethod( + "--partial--", + parameters, + responses, + requestBody, + security.map(_.keys.toSeq), + summary, + tags, + operationId + ) } } diff --git a/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiSecuritySchemeType.scala b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiSecuritySchemeType.scala new file mode 100644 index 0000000000..ef991cc487 --- /dev/null +++ b/openapi-codegen/core/src/main/scala/sttp/tapir/codegen/openapi/models/OpenapiSecuritySchemeType.scala @@ -0,0 +1,49 @@ +package sttp.tapir.codegen.openapi.models + +sealed trait OpenapiSecuritySchemeType + +object OpenapiSecuritySchemeType { + case object OpenapiSecuritySchemeBearerType extends OpenapiSecuritySchemeType + case object OpenapiSecuritySchemeBasicType extends OpenapiSecuritySchemeType + case class OpenapiSecuritySchemeApiKeyType(in: String, name: String) extends OpenapiSecuritySchemeType + + import io.circe._ + import cats.implicits._ + + private implicit val BearerTypeDecoder: Decoder[OpenapiSecuritySchemeBearerType.type] = { (c: HCursor) => + for { + _ <- c.get[String]("type").ensure(DecodingFailure("Given type is not http!", c.history))(_ == "http") + _ <- c.get[String]("scheme").ensure(DecodingFailure("Given scheme is not bearer!", c.history))(_ == "bearer") + } yield { + OpenapiSecuritySchemeBearerType + } + } + + private implicit val BasicTypeDecoder: Decoder[OpenapiSecuritySchemeBasicType.type] = { (c: HCursor) => + for { + _ <- c.get[String]("type").ensure(DecodingFailure("Given type is not http!", c.history))(_ == "http") + _ <- c.get[String]("scheme").ensure(DecodingFailure("Given scheme is not basic!", c.history))(_ == "basic") + } yield { + OpenapiSecuritySchemeBasicType + } + } + + private val ApiKeyInOptions = List("header", "query", "cookie") + + private implicit val ApiKeyDecoder: Decoder[OpenapiSecuritySchemeApiKeyType] = { (c: HCursor) => + for { + _ <- c.get[String]("type").ensure(DecodingFailure("Given type is not apiKey!", c.history))(_ == "apiKey") + in <- c.get[String]("in").ensure(DecodingFailure("Invalid apiKey in value!", c.history))(ApiKeyInOptions.contains) + name <- c.get[String]("name") + } yield { + OpenapiSecuritySchemeApiKeyType(in, name) + } + } + + implicit val OpenapiSecuritySchemeTypeDecoder: Decoder[OpenapiSecuritySchemeType] = + List[Decoder[OpenapiSecuritySchemeType]]( + Decoder[OpenapiSecuritySchemeBearerType.type].widen, + Decoder[OpenapiSecuritySchemeBasicType.type].widen, + Decoder[OpenapiSecuritySchemeApiKeyType].widen + ).reduceLeft(_ or _) +} diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala index fc0be736ce..2b49655f6c 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala @@ -1,5 +1,6 @@ package sttp.tapir.codegen +import sttp.tapir.codegen.openapi.models.OpenapiComponent import sttp.tapir.codegen.openapi.models.OpenapiModels.{ OpenapiDocument, OpenapiParameter, @@ -9,6 +10,11 @@ import sttp.tapir.codegen.openapi.models.OpenapiModels.{ OpenapiResponseContent, Resolved } +import sttp.tapir.codegen.openapi.models.OpenapiSecuritySchemeType.{ + OpenapiSecuritySchemeBearerType, + OpenapiSecuritySchemeBasicType, + OpenapiSecuritySchemeApiKeyType +} import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{OpenapiSchemaArray, OpenapiSchemaString} import sttp.tapir.codegen.testutils.CompileCheckTestBase @@ -47,4 +53,74 @@ class EndpointGeneratorSpec extends CompileCheckTestBase { generatedCode shouldCompile () } + it should "generete endpoints defs with security" in { + val doc = OpenapiDocument( + "", + null, + Seq( + OpenapiPath( + "test", + Seq( + OpenapiPathMethod( + methodType = "get", + parameters = Seq(), + responses = Seq(), + requestBody = None, + security = Seq(Seq("httpBearer")), + summary = None, + tags = None + ), + OpenapiPathMethod( + methodType = "post", + parameters = Seq(), + responses = Seq(), + requestBody = None, + security = Seq(Seq("httpBasic")), + summary = None, + tags = None + ), + OpenapiPathMethod( + methodType = "put", + parameters = Seq(), + responses = Seq(), + requestBody = None, + security = Seq(Seq("apiKeyHeader")), + summary = None, + tags = None + ), + OpenapiPathMethod( + methodType = "patch", + parameters = Seq(), + responses = Seq(), + requestBody = None, + security = Seq(Seq("apiKeyCookie")), + summary = None, + tags = None + ), + OpenapiPathMethod( + methodType = "delete", + parameters = Seq(), + responses = Seq(), + requestBody = None, + security = Seq(Seq("apiKeyQuery")), + summary = None, + tags = None + ) + ) + ) + ), + Some(OpenapiComponent( + Map(), + Map( + "httpBearer" -> OpenapiSecuritySchemeBearerType, + "httpBasic" -> OpenapiSecuritySchemeBasicType, + "apiKeyHeader" -> OpenapiSecuritySchemeApiKeyType("header", "X-API-KEY"), + "apiKeyCookie" -> OpenapiSecuritySchemeApiKeyType("cookie", "api_key"), + "apiKeyQuery" -> OpenapiSecuritySchemeApiKeyType("query", "api-key") + ) + )) + ) + BasicGenerator.imports ++ + new EndpointGenerator().endpointDefs(doc) shouldCompile() + } } diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/TestHelpers.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/TestHelpers.scala index d11d215c1d..c829dcb49d 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/TestHelpers.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/TestHelpers.scala @@ -1,6 +1,6 @@ package sttp.tapir.codegen -import sttp.tapir.codegen.openapi.models.{OpenapiComponent, OpenapiSchemaType} +import sttp.tapir.codegen.openapi.models.OpenapiComponent import sttp.tapir.codegen.openapi.models.OpenapiModels.{ OpenapiDocument, OpenapiInfo, @@ -16,7 +16,6 @@ import sttp.tapir.codegen.openapi.models.OpenapiModels.{ } import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{ OpenapiSchemaArray, - OpenapiSchemaDouble, OpenapiSchemaInt, OpenapiSchemaObject, OpenapiSchemaRef, @@ -227,10 +226,11 @@ object TestHelpers { ), Some( OpenapiComponent( - Map( + schemas = Map( "Book" -> OpenapiSchemaObject(Map("title" -> OpenapiSchemaString(false)), Seq("title"), false) ), - Map( + securitySchemes = Map.empty, + parameters = Map( "#/components/parameters/offset" -> OpenapiParameter("offset", "query", true, Some("Offset at which to start fetching books"), OpenapiSchemaInt(false)), "#/components/parameters/year" -> OpenapiParameter("year", "path", true, None, OpenapiSchemaInt(false)) @@ -440,4 +440,74 @@ object TestHelpers { ), None ) + + val simpleSecurityYaml = + """ + |openapi: 3.1.0 + |info: + | title: hello + | version: '1.0' + |paths: + | /hello: + | get: + | security: + | - basicAuth: [] + | responses: {} + """.stripMargin + + val simpleSecurityDocs = OpenapiDocument( + "3.1.0", + OpenapiInfo("hello", "1.0"), + Seq( + OpenapiPath( + url = "/hello", + methods = Seq( + OpenapiPathMethod( + methodType = "get", + parameters = Seq(), + responses = Seq(), + requestBody = None, + security = Seq(Seq("basicAuth")) + ) + ) + ) + ), + None + ) + + val complexSecurityYaml = + """ + |openapi: 3.1.0 + |info: + | title: hello + | version: '1.0' + |paths: + | /hello: + | get: + | security: + | - bearerAuth: [] + | - basicAuth: [] + | apiKeyAuth: [] + | responses: {} + """.stripMargin + + val complexSecurityDocs = OpenapiDocument( + "3.1.0", + OpenapiInfo("hello", "1.0"), + Seq( + OpenapiPath( + url = "/hello", + methods = Seq( + OpenapiPathMethod( + methodType = "get", + parameters = Seq(), + responses = Seq(), + requestBody = None, + security = Seq(Seq("bearerAuth"), Seq("basicAuth", "apiKeyAuth")) + ) + ) + ) + ), + None + ) } diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/models/ModelParserSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/models/ModelParserSpec.scala index b4910a8b32..da796f0c1f 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/models/ModelParserSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/models/ModelParserSpec.scala @@ -93,6 +93,28 @@ class ModelParserSpec extends AnyFlatSpec with Matchers with Checkers { )) } + it should "parse endpoint with single security entry" in { + val res = parser + .parse(TestHelpers.simpleSecurityYaml) + .leftMap(err => err: Error) + .flatMap(_.as[OpenapiDocument]) + + res shouldBe (Right( + TestHelpers.simpleSecurityDocs + )) + } + + it should "parse endpoint with complex security entry" in { + val res = parser + .parse(TestHelpers.complexSecurityYaml) + .leftMap(err => err: Error) + .flatMap(_.as[OpenapiDocument]) + + res shouldBe (Right( + TestHelpers.complexSecurityDocs + )) + } + it should "parse uuids" in { val yaml = """ diff --git a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/models/SchemaParserSpec.scala b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/models/SchemaParserSpec.scala index e86344ffbd..a535b3f9ac 100644 --- a/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/models/SchemaParserSpec.scala +++ b/openapi-codegen/core/src/test/scala/sttp/tapir/codegen/models/SchemaParserSpec.scala @@ -8,6 +8,11 @@ import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{ OpenapiSchemaObject, OpenapiSchemaString } +import sttp.tapir.codegen.openapi.models.OpenapiSecuritySchemeType.{ + OpenapiSecuritySchemeBearerType, + OpenapiSecuritySchemeBasicType, + OpenapiSecuritySchemeApiKeyType +} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers import org.scalatestplus.scalacheck.Checkers @@ -82,6 +87,48 @@ class SchemaParserSpec extends AnyFlatSpec with Matchers with Checkers { ) } + it should "parse security schemes" in { + val yaml = + """ + |securitySchemes: + | httpAuthBearer: + | type: http + | scheme: bearer + | httpAuthBasic: + | type: http + | scheme: basic + | apiKeyHeader: + | type: apiKey + | in: header + | name: X-API-KEY + | apiKeyCookie: + | type: apiKey + | in: cookie + | name: api_key + | apiKeyQuery: + | type: apiKey + | in: query + | name: api-key""".stripMargin + + val res = parser + .parse(yaml) + .leftMap(err => err: Error) + .flatMap(_.as[OpenapiComponent]) + + res shouldBe Right( + OpenapiComponent( + Map(), + Map( + "httpAuthBearer" -> OpenapiSecuritySchemeBearerType, + "httpAuthBasic" -> OpenapiSecuritySchemeBasicType, + "apiKeyHeader" -> OpenapiSecuritySchemeApiKeyType("header", "X-API-KEY"), + "apiKeyCookie" -> OpenapiSecuritySchemeApiKeyType("cookie", "api_key"), + "apiKeyQuery" -> OpenapiSecuritySchemeApiKeyType("query", "api-key") + ) + ) + ) + } + it should "parse basic-response (array) yaml" in { // https://swagger.io/docs/specification/basic-structure/ val yaml = """application/json: