Skip to content

Commit

Permalink
Add security support for codegen (#3094)
Browse files Browse the repository at this point in the history
  • Loading branch information
markussammallahti authored Aug 10, 2023
1 parent 5c73922 commit 05f9b2b
Show file tree
Hide file tree
Showing 9 changed files with 327 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ object BasicGenerator {

private[codegen] def imports: String =
"""import sttp.tapir._
|import sttp.tapir.model._
|import sttp.tapir.json.circe._
|import sttp.tapir.generic.auto._
|import io.circe.generic.auto._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@ package sttp.tapir.codegen

import sttp.tapir.codegen.BasicGenerator.{indent, mapSchemaSimpleTypeToType}
import sttp.tapir.codegen.openapi.models.OpenapiModels.{OpenapiDocument, OpenapiParameter, OpenapiPath, OpenapiRequestBody, OpenapiResponse}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType
import sttp.tapir.codegen.openapi.models.{OpenapiComponent, OpenapiSchemaType, OpenapiSecuritySchemeType}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{OpenapiSchemaArray, OpenapiSchemaSimpleType}

class EndpointGenerator {

private[codegen] def allEndpoints: String = "generatedEndpoints"

def endpointDefs(doc: OpenapiDocument): String = {
val ps = Option(doc.components).flatten.map(_.parameters) getOrElse Map.empty
val ge = doc.paths.flatMap(generatedEndpoints(ps))
val components = Option(doc.components).flatten
val ge = doc.paths.flatMap(generatedEndpoints(components))
val definitions = ge
.map { case (name, definition) =>
s"""|lazy val $name =
Expand All @@ -27,12 +27,16 @@ class EndpointGenerator {
|""".stripMargin
}

private[codegen] def generatedEndpoints(parameters: Map[String, OpenapiParameter])(p: OpenapiPath): Seq[(String, String)] = {
private[codegen] def generatedEndpoints(components: Option[OpenapiComponent])(p: OpenapiPath): Seq[(String, String)] = {
val parameters = components.map(_.parameters).getOrElse(Map.empty)
val securitySchemes = components.map(_.securitySchemes).getOrElse(Map.empty)

p.methods.map(_.withResolvedParentParameters(parameters, p.parameters)).map { m =>
val definition =
s"""|endpoint
| .${m.methodType}
| ${urlMapper(p.url, m.resolvedParameters)}
|${indent(2)(security(securitySchemes, m.security))}
|${indent(2)(ins(m.resolvedParameters, m.requestBody))}
|${indent(2)(outs(m.responses))}
|${indent(2)(tags(m.tags))}
Expand Down Expand Up @@ -71,6 +75,29 @@ class EndpointGenerator {
".in((" + inPath.mkString(" / ") + "))"
}

private def security(securitySchemes: Map[String, OpenapiSecuritySchemeType], security: Seq[Seq[String]]) = {
if (security.size > 1 || security.exists(_.size > 1))
throw new NotImplementedError("We can handle only single security entry!")

security.headOption
.flatMap(_.headOption)
.fold("") { schemeName =>
securitySchemes.get(schemeName) match {
case Some(OpenapiSecuritySchemeType.OpenapiSecuritySchemeBearerType) =>
".securityIn(auth.bearer[String]())"

case Some(OpenapiSecuritySchemeType.OpenapiSecuritySchemeBasicType) =>
".securityIn(auth.basic[UsernamePassword]())"

case Some(OpenapiSecuritySchemeType.OpenapiSecuritySchemeApiKeyType(in, name)) =>
s""".securityIn(auth.apiKey($in[String]("$name")))"""

case None =>
throw new Error(s"Unknown security scheme $schemeName!")
}
}
}

private def ins(parameters: Seq[OpenapiParameter], requestBody: Option[OpenapiRequestBody]): String = {
// .in(query[Limit]("limit").description("Maximum number of books to retrieve"))
// .in(header[AuthToken]("X-Auth-Token"))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package sttp.tapir.codegen.openapi.models

import cats.syntax.either._

import OpenapiModels.OpenapiParameter

case class OpenapiComponent(
schemas: Map[String, OpenapiSchemaType],
securitySchemes: Map[String, OpenapiSecuritySchemeType] = Map.empty,
parameters: Map[String, OpenapiParameter] = Map.empty
)

Expand All @@ -14,10 +13,15 @@ object OpenapiComponent {

implicit val OpenapiComponentDecoder: Decoder[OpenapiComponent] = { (c: HCursor) =>
for {
schemas <- c.downField("schemas").as[Map[String, OpenapiSchemaType]]
parameters <- c.downField("parameters").as[Option[Map[String, OpenapiParameter]]].map(_.getOrElse(Map.empty))
schemas <- c.getOrElse[Map[String, OpenapiSchemaType]]("schemas")(Map.empty)
securitySchemes <- c.getOrElse[Map[String, OpenapiSecuritySchemeType]]("securitySchemes")(Map.empty)
parameters <- c.getOrElse[Map[String, OpenapiParameter]]("parameters")(Map.empty)
} yield {
OpenapiComponent(schemas, parameters.map { case (k, v) => s"#/components/parameters/$k" -> v })
OpenapiComponent(
schemas,
securitySchemes,
parameters.map { case (k, v) => s"#/components/parameters/$k" -> v }
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ object OpenapiModels {
parameters: Seq[Resolvable[OpenapiParameter]],
responses: Seq[OpenapiResponse],
requestBody: Option[OpenapiRequestBody],
security: Seq[Seq[String]] = Nil,
summary: Option[String] = None,
tags: Option[Seq[String]] = None,
operationId: Option[String] = None
Expand Down Expand Up @@ -167,17 +168,24 @@ object OpenapiModels {
}
implicit val PartialOpenapiPathMethodDecoder: Decoder[OpenapiPathMethod] = { (c: HCursor) =>
for {
parameters <- c
.downField("parameters")
.as[Option[Seq[Resolvable[OpenapiParameter]]]]
.map(_.getOrElse(Nil))
responses <- c.downField("responses").as[Seq[OpenapiResponse]]
requestBody <- c.downField("requestBody").as[Option[OpenapiRequestBody]]
summary <- c.downField("summary").as[Option[String]]
tags <- c.downField("tags").as[Option[Seq[String]]]
operationId <- c.downField("operationId").as[Option[String]]
parameters <- c.getOrElse[Seq[Resolvable[OpenapiParameter]]]("parameters")(Nil)
responses <- c.get[Seq[OpenapiResponse]]("responses")
requestBody <- c.get[Option[OpenapiRequestBody]]("requestBody")
security <- c.getOrElse[Seq[Map[String, Seq[String]]]]("security")(Nil)
summary <- c.get[Option[String]]("summary")
tags <- c.get[Option[Seq[String]]]("tags")
operationId <- c.get[Option[String]]("operationId")
} yield {
OpenapiPathMethod("--partial--", parameters, responses, requestBody, summary, tags, operationId)
OpenapiPathMethod(
"--partial--",
parameters,
responses,
requestBody,
security.map(_.keys.toSeq),
summary,
tags,
operationId
)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package sttp.tapir.codegen.openapi.models

sealed trait OpenapiSecuritySchemeType

object OpenapiSecuritySchemeType {
case object OpenapiSecuritySchemeBearerType extends OpenapiSecuritySchemeType
case object OpenapiSecuritySchemeBasicType extends OpenapiSecuritySchemeType
case class OpenapiSecuritySchemeApiKeyType(in: String, name: String) extends OpenapiSecuritySchemeType

import io.circe._
import cats.implicits._

private implicit val BearerTypeDecoder: Decoder[OpenapiSecuritySchemeBearerType.type] = { (c: HCursor) =>
for {
_ <- c.get[String]("type").ensure(DecodingFailure("Given type is not http!", c.history))(_ == "http")
_ <- c.get[String]("scheme").ensure(DecodingFailure("Given scheme is not bearer!", c.history))(_ == "bearer")
} yield {
OpenapiSecuritySchemeBearerType
}
}

private implicit val BasicTypeDecoder: Decoder[OpenapiSecuritySchemeBasicType.type] = { (c: HCursor) =>
for {
_ <- c.get[String]("type").ensure(DecodingFailure("Given type is not http!", c.history))(_ == "http")
_ <- c.get[String]("scheme").ensure(DecodingFailure("Given scheme is not basic!", c.history))(_ == "basic")
} yield {
OpenapiSecuritySchemeBasicType
}
}

private val ApiKeyInOptions = List("header", "query", "cookie")

private implicit val ApiKeyDecoder: Decoder[OpenapiSecuritySchemeApiKeyType] = { (c: HCursor) =>
for {
_ <- c.get[String]("type").ensure(DecodingFailure("Given type is not apiKey!", c.history))(_ == "apiKey")
in <- c.get[String]("in").ensure(DecodingFailure("Invalid apiKey in value!", c.history))(ApiKeyInOptions.contains)
name <- c.get[String]("name")
} yield {
OpenapiSecuritySchemeApiKeyType(in, name)
}
}

implicit val OpenapiSecuritySchemeTypeDecoder: Decoder[OpenapiSecuritySchemeType] =
List[Decoder[OpenapiSecuritySchemeType]](
Decoder[OpenapiSecuritySchemeBearerType.type].widen,
Decoder[OpenapiSecuritySchemeBasicType.type].widen,
Decoder[OpenapiSecuritySchemeApiKeyType].widen
).reduceLeft(_ or _)
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package sttp.tapir.codegen

import sttp.tapir.codegen.openapi.models.OpenapiComponent
import sttp.tapir.codegen.openapi.models.OpenapiModels.{
OpenapiDocument,
OpenapiParameter,
Expand All @@ -9,6 +10,11 @@ import sttp.tapir.codegen.openapi.models.OpenapiModels.{
OpenapiResponseContent,
Resolved
}
import sttp.tapir.codegen.openapi.models.OpenapiSecuritySchemeType.{
OpenapiSecuritySchemeBearerType,
OpenapiSecuritySchemeBasicType,
OpenapiSecuritySchemeApiKeyType
}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{OpenapiSchemaArray, OpenapiSchemaString}
import sttp.tapir.codegen.testutils.CompileCheckTestBase

Expand Down Expand Up @@ -47,4 +53,74 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
generatedCode shouldCompile ()
}

it should "generete endpoints defs with security" in {
val doc = OpenapiDocument(
"",
null,
Seq(
OpenapiPath(
"test",
Seq(
OpenapiPathMethod(
methodType = "get",
parameters = Seq(),
responses = Seq(),
requestBody = None,
security = Seq(Seq("httpBearer")),
summary = None,
tags = None
),
OpenapiPathMethod(
methodType = "post",
parameters = Seq(),
responses = Seq(),
requestBody = None,
security = Seq(Seq("httpBasic")),
summary = None,
tags = None
),
OpenapiPathMethod(
methodType = "put",
parameters = Seq(),
responses = Seq(),
requestBody = None,
security = Seq(Seq("apiKeyHeader")),
summary = None,
tags = None
),
OpenapiPathMethod(
methodType = "patch",
parameters = Seq(),
responses = Seq(),
requestBody = None,
security = Seq(Seq("apiKeyCookie")),
summary = None,
tags = None
),
OpenapiPathMethod(
methodType = "delete",
parameters = Seq(),
responses = Seq(),
requestBody = None,
security = Seq(Seq("apiKeyQuery")),
summary = None,
tags = None
)
)
)
),
Some(OpenapiComponent(
Map(),
Map(
"httpBearer" -> OpenapiSecuritySchemeBearerType,
"httpBasic" -> OpenapiSecuritySchemeBasicType,
"apiKeyHeader" -> OpenapiSecuritySchemeApiKeyType("header", "X-API-KEY"),
"apiKeyCookie" -> OpenapiSecuritySchemeApiKeyType("cookie", "api_key"),
"apiKeyQuery" -> OpenapiSecuritySchemeApiKeyType("query", "api-key")
)
))
)
BasicGenerator.imports ++
new EndpointGenerator().endpointDefs(doc) shouldCompile()
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package sttp.tapir.codegen

import sttp.tapir.codegen.openapi.models.{OpenapiComponent, OpenapiSchemaType}
import sttp.tapir.codegen.openapi.models.OpenapiComponent
import sttp.tapir.codegen.openapi.models.OpenapiModels.{
OpenapiDocument,
OpenapiInfo,
Expand All @@ -16,7 +16,6 @@ import sttp.tapir.codegen.openapi.models.OpenapiModels.{
}
import sttp.tapir.codegen.openapi.models.OpenapiSchemaType.{
OpenapiSchemaArray,
OpenapiSchemaDouble,
OpenapiSchemaInt,
OpenapiSchemaObject,
OpenapiSchemaRef,
Expand Down Expand Up @@ -227,10 +226,11 @@ object TestHelpers {
),
Some(
OpenapiComponent(
Map(
schemas = Map(
"Book" -> OpenapiSchemaObject(Map("title" -> OpenapiSchemaString(false)), Seq("title"), false)
),
Map(
securitySchemes = Map.empty,
parameters = Map(
"#/components/parameters/offset" ->
OpenapiParameter("offset", "query", true, Some("Offset at which to start fetching books"), OpenapiSchemaInt(false)),
"#/components/parameters/year" -> OpenapiParameter("year", "path", true, None, OpenapiSchemaInt(false))
Expand Down Expand Up @@ -440,4 +440,74 @@ object TestHelpers {
),
None
)

val simpleSecurityYaml =
"""
|openapi: 3.1.0
|info:
| title: hello
| version: '1.0'
|paths:
| /hello:
| get:
| security:
| - basicAuth: []
| responses: {}
""".stripMargin

val simpleSecurityDocs = OpenapiDocument(
"3.1.0",
OpenapiInfo("hello", "1.0"),
Seq(
OpenapiPath(
url = "/hello",
methods = Seq(
OpenapiPathMethod(
methodType = "get",
parameters = Seq(),
responses = Seq(),
requestBody = None,
security = Seq(Seq("basicAuth"))
)
)
)
),
None
)

val complexSecurityYaml =
"""
|openapi: 3.1.0
|info:
| title: hello
| version: '1.0'
|paths:
| /hello:
| get:
| security:
| - bearerAuth: []
| - basicAuth: []
| apiKeyAuth: []
| responses: {}
""".stripMargin

val complexSecurityDocs = OpenapiDocument(
"3.1.0",
OpenapiInfo("hello", "1.0"),
Seq(
OpenapiPath(
url = "/hello",
methods = Seq(
OpenapiPathMethod(
methodType = "get",
parameters = Seq(),
responses = Seq(),
requestBody = None,
security = Seq(Seq("bearerAuth"), Seq("basicAuth", "apiKeyAuth"))
)
)
)
),
None
)
}
Loading

0 comments on commit 05f9b2b

Please sign in to comment.