Skip to content

Commit

Permalink
Proposal to describe OAuth2 flows (#3490)
Browse files Browse the repository at this point in the history
Co-authored-by: adamw <[email protected]>
  • Loading branch information
leoniv and adamw authored Feb 2, 2024
1 parent a2e011b commit d71bb4c
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 21 deletions.
58 changes: 57 additions & 1 deletion core/src/main/scala/sttp/tapir/TapirAuth.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import sttp.model.headers.{AuthenticationScheme, WWWAuthenticateChallenge}
import sttp.tapir.CodecFormat.TextPlain
import sttp.tapir.EndpointInput.Auth

import scala.collection.immutable.ListMap
import scala.collection.immutable.{ListMap, Seq}

object TapirAuth {

Expand Down Expand Up @@ -58,6 +58,16 @@ object TapirAuth {
}

object oauth2 {
sealed trait OAuth2Flow
object OAuth2Flow {
case object AuthenticationCode extends OAuth2Flow
case object ClientCredentials extends OAuth2Flow
case object Implicit extends OAuth2Flow

val Attribute: AttributeKey[OAuth2Flow] = new AttributeKey[OAuth2Flow]("sttp.tapir.TapirAuth.oauth2.OAuth2Flow")
}

@deprecated("Use insted authorizationCodeFlow, clientCredentialsFlow or implicitFlow", "")
def authorizationCode(
authorizationUrl: Option[String] = None,
scopes: ListMap[String, String] = ListMap(),
Expand All @@ -73,6 +83,52 @@ object TapirAuth {
)
}

private def buildInput(
baseOAuth: EndpointInput.AuthType.OAuth2,
challenge: WWWAuthenticateChallenge,
flow: OAuth2Flow
): Auth[String, EndpointInput.AuthType.OAuth2] =
EndpointInput
.Auth(
header[String](HeaderNames.Authorization).map(stringPrefixWithSpace(AuthenticationScheme.Bearer.name)),
challenge,
baseOAuth: EndpointInput.AuthType.OAuth2,
EndpointInput.AuthInfo.Empty
)
.attribute(OAuth2Flow.Attribute, flow)

def authorizationCodeFlow(
authorizationUrl: String,
tokenUrl: String,
refreshUrl: Option[String] = None,
scopes: ListMap[String, String] = ListMap(),
challenge: WWWAuthenticateChallenge = WWWAuthenticateChallenge.bearer
): Auth[String, EndpointInput.AuthType.OAuth2] = buildInput(
EndpointInput.AuthType.OAuth2(Some(authorizationUrl), Some(tokenUrl), scopes, refreshUrl),
challenge,
OAuth2Flow.AuthenticationCode
)

def clientCredentialsFlow(
tokenUrl: String,
refreshUrl: Option[String] = None,
scopes: ListMap[String, String] = ListMap(),
challenge: WWWAuthenticateChallenge = WWWAuthenticateChallenge.bearer
): Auth[String, EndpointInput.AuthType.OAuth2] =
buildInput(
EndpointInput.AuthType.OAuth2(None, Some(tokenUrl), scopes, refreshUrl),
challenge,
OAuth2Flow.ClientCredentials
)

def implicitFlow(
authorizationUrl: String,
refreshUrl: Option[String] = None,
scopes: ListMap[String, String] = ListMap(),
challenge: WWWAuthenticateChallenge = WWWAuthenticateChallenge.bearer
): Auth[String, EndpointInput.AuthType.OAuth2] =
buildInput(EndpointInput.AuthType.OAuth2(Some(authorizationUrl), None, scopes, refreshUrl), challenge, OAuth2Flow.Implicit)

private def stringPrefixWithSpace(prefix: String) = Mapping.stringPrefixCaseInsensitive(prefix + " ")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import sttp.apispec.{OAuthFlow, OAuthFlows, SecurityScheme}
import sttp.tapir.internal._
import sttp.tapir.docs.apispec.DocsExtensionAttribute.RichEndpointAuth
import sttp.tapir.{AnyEndpoint, EndpointIO, EndpointInput}
import sttp.tapir.TapirAuth.oauth2.OAuth2Flow

import scala.annotation.tailrec

Expand Down Expand Up @@ -42,6 +43,14 @@ private[docs] object SecuritySchemesForEndpoints {
}
}

private def getOAuth2Flow(a: EndpointInput.AuthType.OAuth2, definedFlow: Option[OAuth2Flow]): OAuthFlows = definedFlow match {
case Some(OAuth2Flow.AuthenticationCode) =>
OAuthFlows(authorizationCode = Some(OAuthFlow(a.authorizationUrl, a.tokenUrl, a.refreshUrl, a.scopes)))
case Some(OAuth2Flow.ClientCredentials) => OAuthFlows(clientCredentials = Some(OAuthFlow(None, a.tokenUrl, a.refreshUrl, a.scopes)))
case Some(OAuth2Flow.Implicit) => OAuthFlows(`implicit` = Some(OAuthFlow(a.authorizationUrl, None, a.refreshUrl, a.scopes)))
case None => OAuthFlows(authorizationCode = Some(OAuthFlow(a.authorizationUrl, a.tokenUrl, a.refreshUrl, a.scopes)))
}

private def authToSecurityScheme(a: EndpointInput.Auth[_, _ <: EndpointInput.AuthType], apiKeyAuthTypeName: String): SecurityScheme = {
val extensions = DocsExtensions.fromIterable(a.docsExtensions)
a.authType match {
Expand All @@ -50,27 +59,27 @@ private[docs] object SecuritySchemesForEndpoints {
SecurityScheme(apiKeyAuthTypeName, a.info.description, Some(name), Some(in), None, a.info.bearerFormat, None, None, extensions)
case EndpointInput.AuthType.Http(scheme) =>
SecurityScheme("http", a.info.description, None, None, Some(scheme.toLowerCase()), a.info.bearerFormat, None, None, extensions)
case EndpointInput.AuthType.OAuth2(authorizationUrl, tokenUrl, scopes, refreshUrl) =>
case oauth2: EndpointInput.AuthType.OAuth2 =>
SecurityScheme(
"oauth2",
a.info.description,
None,
None,
None,
a.info.bearerFormat,
Some(OAuthFlows(authorizationCode = Some(OAuthFlow(authorizationUrl, tokenUrl, refreshUrl, scopes)))),
Some(getOAuth2Flow(oauth2, a.attribute(OAuth2Flow.Attribute))),
None,
extensions
)
case EndpointInput.AuthType.ScopedOAuth2(EndpointInput.AuthType.OAuth2(authorizationUrl, tokenUrl, scopes, refreshUrl), _) =>
case EndpointInput.AuthType.ScopedOAuth2(oauth2, _) =>
SecurityScheme(
"oauth2",
a.info.description,
None,
None,
None,
a.info.bearerFormat,
Some(OAuthFlows(authorizationCode = Some(OAuthFlow(authorizationUrl, tokenUrl, refreshUrl, scopes)))),
Some(getOAuth2Flow(oauth2, a.attribute(OAuth2Flow.Attribute))),
None,
extensions
)
Expand Down
28 changes: 24 additions & 4 deletions docs/openapi-docs/src/test/resources/security/expected_oauth2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ paths:
schema:
type: string
security:
- oauth2Auth:
- client
- oauth2Auth1:
- client
/api3/{p1}:
get:
operationId: getApi3P1
Expand All @@ -57,15 +57,35 @@ paths:
schema:
type: string
security:
- oauth2Auth:
- admin
- oauth2Auth2:
- admin
components:
securitySchemes:
oauth2Auth:
type: oauth2
flows:
authorizationCode:
authorizationUrl: https://example.com/auth
tokenUrl: https://example.com/token
refreshUrl: https://example.com/token/refresh
scopes:
client: scope for clients
admin: administration scope
oauth2Auth1:
type: oauth2
flows:
clientCredentials:
tokenUrl: https://example.com/token
refreshUrl: https://example.com/token/refresh
scopes:
client: scope for clients
admin: administration scope
oauth2Auth2:
type: oauth2
flows:
implicit:
authorizationUrl: https://example.com/auth
refreshUrl: https://example.com/token/refresh
scopes:
client: scope for clients
admin: administration scope
Original file line number Diff line number Diff line change
Expand Up @@ -121,26 +121,42 @@ class VerifyYamlSecurityTest extends AnyFunSuite with Matchers {

test("should support Oauth2") {
val expectedYaml = load("security/expected_oauth2.yml")
val oauth2 =
val authCodeFlow =
auth.oauth2
.authorizationCode(
Some("https://example.com/auth"),
.authorizationCodeFlow(
"https://example.com/auth",
"https://example.com/token",
Some("https://example.com/token/refresh"),
ListMap("client" -> "scope for clients", "admin" -> "administration scope")
)
val clientCredFlow =
auth.oauth2
.clientCredentialsFlow(
"https://example.com/token",
Some("https://example.com/token/refresh"),
ListMap("client" -> "scope for clients", "admin" -> "administration scope")
)
val implicitFlow =
auth.oauth2
.implicitFlow(
"https://example.com/auth",
Some("https://example.com/token/refresh"),
ListMap("client" -> "scope for clients", "admin" -> "administration scope")
)

val e1 =
endpoint
.securityIn(oauth2)
.securityIn(authCodeFlow)
.in("api1" / path[String])
.out(stringBody)
val e2 =
endpoint
.securityIn(oauth2.requiredScopes(Seq("client")))
.securityIn(clientCredFlow.requiredScopes(Seq("client")))
.in("api2" / path[String])
.out(stringBody)
val e3 =
endpoint
.securityIn(oauth2.requiredScopes(Seq("admin")))
.securityIn(implicitFlow.requiredScopes(Seq("admin")))
.in("api3" / path[String])
.out(stringBody)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ object SwaggerUIOAuth2PekkoServer extends App with RouteConcatenation {
val secureEndpoint: PartialServerEndpoint[String, String, Unit, Int, Unit, Any, Future] =
endpoint
.securityIn(
auth.oauth2.authorizationCode(
authorizationUrl = Some("http://localhost:8080/realms/myrealm/protocol/openid-connect/auth"),
tokenUrl = Some("http://localhost:8080/realms/myrealm/protocol/openid-connect/token")
auth.oauth2.authorizationCodeFlow(
authorizationUrl = "http://localhost:8080/realms/myrealm/protocol/openid-connect/auth",
tokenUrl = "http://localhost:8080/realms/myrealm/protocol/openid-connect/token"
)
)
.errorOut(plainBody[Int])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ object OAuth2GithubHttp4sServer extends IOApp {

case class AccessDetails(token: String)

val authorizationUrl = Some("https://github.com/login/oauth/authorize")
val accessTokenUrl = Some("https://github.com/login/oauth/access_token")
val authorizationUrl = "https://github.com/login/oauth/authorize"
val accessTokenUrl = "https://github.com/login/oauth/access_token"

val authOAuth2 = auth.oauth2.authorizationCode(authorizationUrl, ListMap.empty, accessTokenUrl)
val authOAuth2 = auth.oauth2.authorizationCodeFlow(authorizationUrl, accessTokenUrl)

// endpoint declarations
val login: PublicEndpoint[Unit, Unit, String, Any] =
Expand Down

0 comments on commit d71bb4c

Please sign in to comment.