Skip to content

Commit

Permalink
WX-833 Real Azure DRS Credentials (#6952)
Browse files Browse the repository at this point in the history
* Remove B2C reference from name

* Get token for current user rather than getting from KeyVault

* Remove KeyVault config for engine

* Remove KeyVault config for DRSLocalizer

* Remove KeyVault dependency

* Remove KeyVault support from localizer repo template

* Cleaned up and working Azure token acquisition for engine

* Collapse localizer's AccessTokenStrategy into DrsCredentials

* Cleanup
  • Loading branch information
jgainerdewar authored Nov 22, 2022
1 parent 001957b commit e0c8dd0
Show file tree
Hide file tree
Showing 18 changed files with 109 additions and 201 deletions.
Original file line number Diff line number Diff line change
@@ -1,25 +1,32 @@
package cloud.nio.impl.drs

import cats.syntax.validated._
import com.azure.core.credential.TokenRequestContext
import com.azure.core.management.AzureEnvironment
import com.azure.core.management.profile.AzureProfile
import com.azure.identity.DefaultAzureCredentialBuilder
import com.azure.security.keyvault.secrets.{SecretClient, SecretClientBuilder}
import com.google.auth.oauth2.{AccessToken, OAuth2Credentials}
import com.google.auth.oauth2.{AccessToken, GoogleCredentials, OAuth2Credentials}
import com.typesafe.config.Config
import common.validation.ErrorOr
import common.validation.ErrorOr.ErrorOr
import net.ceedubs.ficus.Ficus._

import scala.concurrent.duration._
import scala.jdk.DurationConverters._
import scala.util.{Failure, Success, Try}

/**
* This trait allows us to abstract away different token attainment strategies
* for different cloud environments.
**/
sealed trait DrsCredentials {
trait DrsCredentials {
def getAccessToken: ErrorOr[String]
}

case class GoogleDrsCredentials(credentials: OAuth2Credentials, acceptableTTL: Duration) extends DrsCredentials {
/**
* Strategy for obtaining an access token from an existing OAuth credential. This class
* is designed for use within the Cromwell engine.
*/
case class GoogleOauthDrsCredentials(credentials: OAuth2Credentials, acceptableTTL: Duration) extends DrsCredentials {
//Based on method from GoogleRegistry
def getAccessToken: ErrorOr[String] = {
def accessTokenTTLIsAcceptable(accessToken: AccessToken): Boolean = {
Expand All @@ -39,25 +46,68 @@ case class GoogleDrsCredentials(credentials: OAuth2Credentials, acceptableTTL: D
}
}

object GoogleDrsCredentials {
def apply(credentials: OAuth2Credentials, config: Config): GoogleDrsCredentials =
GoogleDrsCredentials(credentials, config.as[FiniteDuration]("access-token-acceptable-ttl"))
object GoogleOauthDrsCredentials {
def apply(credentials: OAuth2Credentials, config: Config): GoogleOauthDrsCredentials =
GoogleOauthDrsCredentials(credentials, config.as[FiniteDuration]("access-token-acceptable-ttl"))
}


/**
* Strategy for obtaining an access token from Google Application Default credentials that are assumed to already exist
* in the environment. This class is designed for use by standalone executables running in environments
* that have direct access to a Google identity (ex. CromwellDrsLocalizer).
*/
case object GoogleAppDefaultTokenStrategy extends DrsCredentials {
private final val UserInfoEmailScope = "https://www.googleapis.com/auth/userinfo.email"
private final val UserInfoProfileScope = "https://www.googleapis.com/auth/userinfo.profile"

def getAccessToken: ErrorOr[String] = {
Try {
val scopedCredentials = GoogleCredentials.getApplicationDefault().createScoped(UserInfoEmailScope, UserInfoProfileScope)
scopedCredentials.refreshAccessToken().getTokenValue
} match {
case Success(null) => "null token value attempting to refresh access token".invalidNel
case Success(value) => value.validNel
case Failure(e) => s"Failed to refresh access token: ${e.getMessage}".invalidNel
}
}
}

case class AzureDrsCredentials(identityClientId: Option[String], vaultName: String, secretName: String) extends DrsCredentials {
/**
* Strategy for obtaining an access token in an environment with available Azure identity.
* If you need to disambiguate among multiple active user-assigned managed identities, pass
* in the client id of the identity that should be used.
*/
case class AzureDrsCredentials(identityClientId: Option[String]) extends DrsCredentials {

final val tokenAcquisitionTimeout = 30.seconds

lazy val secretClient: ErrorOr[SecretClient] = ErrorOr {
val defaultCreds = identityClientId.map(identityId =>
new DefaultAzureCredentialBuilder().managedIdentityClientId(identityId)
).getOrElse(
new DefaultAzureCredentialBuilder()
).build()
val azureProfile = new AzureProfile(AzureEnvironment.AZURE)
val tokenScope = "https://management.azure.com/.default"

new SecretClientBuilder()
.vaultUrl(s"https://${vaultName}.vault.azure.net")
.credential(defaultCreds)
.buildClient()
def tokenRequestContext: TokenRequestContext = {
val trc = new TokenRequestContext()
trc.addScopes(tokenScope)
trc
}

def getAccessToken: ErrorOr[String] = secretClient.map(_.getSecret(secretName).getValue)
def defaultCredentialBuilder: DefaultAzureCredentialBuilder =
new DefaultAzureCredentialBuilder()
.authorityHost(azureProfile.getEnvironment.getActiveDirectoryEndpoint)

def getAccessToken: ErrorOr[String] = {
val credentials = identityClientId.foldLeft(defaultCredentialBuilder) {
(builder, clientId) => builder.managedIdentityClientId(clientId)
}.build()

Try(
credentials
.getToken(tokenRequestContext)
.block(tokenAcquisitionTimeout.toJava)
) match {
case Success(null) => "null token value attempting to obtain access token".invalidNel
case Success(token) => token.getToken.validNel
case Failure(error) => s"Failed to refresh access token: ${error.getMessage}".invalidNel
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class DrsCloudNioFileProviderSpec extends AnyFlatSpecLike with CromwellTimeoutSp
val fileSystemProvider = new MockDrsCloudNioFileSystemProvider(config = config)
fileSystemProvider.drsConfig.marthaUrl should be("https://from.config")
fileSystemProvider.drsCredentials match {
case GoogleDrsCredentials(_, ttl) => ttl should be(1.minute)
case GoogleOauthDrsCredentials(_, ttl) => ttl should be(1.minute)
case error => fail(s"Expected GoogleDrsCredentials, found $error")
}
fileSystemProvider.fileProvider should be(a[DrsCloudNioFileProvider])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class MockDrsCloudNioFileSystemProvider(config: Config = mockConfig,
),
mockResolver: Option[EngineDrsPathResolver] = None,
)
extends DrsCloudNioFileSystemProvider(config, GoogleDrsCredentials(NoCredentials.getInstance, config), drsReadInterpreter) {
extends DrsCloudNioFileSystemProvider(config, GoogleOauthDrsCredentials(NoCredentials.getInstance, config), drsReadInterpreter) {

override lazy val drsPathResolver: EngineDrsPathResolver = {
mockResolver getOrElse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class MockEngineDrsPathResolver(drsConfig: DrsConfig = MockDrsPaths.mockDrsConfi
httpClientBuilderOverride: Option[HttpClientBuilder] = None,
accessTokenAcceptableTTL: Duration = Duration.Inf,
)
extends EngineDrsPathResolver(drsConfig, GoogleDrsCredentials(NoCredentials.getInstance, accessTokenAcceptableTTL)) {
extends EngineDrsPathResolver(drsConfig, GoogleOauthDrsCredentials(NoCredentials.getInstance, accessTokenAcceptableTTL)) {

override protected lazy val httpClientBuilder: HttpClientBuilder =
httpClientBuilderOverride getOrElse MockSugar.mock[HttpClientBuilder]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@ class CommandLineParser extends scopt.OptionParser[CommandLineArguments](Usage)
opt[String]('t', "access-token-strategy").text(s"Access token strategy, must be one of '$Azure' or '$Google' (default '$Google')").
action((s, c) =>
c.copy(accessTokenStrategy = Option(s.toLowerCase())))
opt[String]('v', "vault-name").text("Azure vault name").
action((s, c) =>
c.copy(azureVaultName = Option(s)))
opt[String]('s', "secret-name").text("Azure secret name").
action((s, c) =>
c.copy(azureSecretName = Option(s)))
opt[String]('i', "identity-client-id").text("Azure identity client id").
action((s, c) =>
c.copy(azureIdentityClientId = Option(s)))
Expand All @@ -54,11 +48,10 @@ class CommandLineParser extends scopt.OptionParser[CommandLineArguments](Usage)
c.accessTokenStrategy match {
case Some(Azure) if c.googleRequesterPaysProject.nonEmpty =>
Left(s"Requester pays project is only valid with access token strategy '$Google'")
case Some(Azure) if List(c.azureVaultName, c.azureSecretName).exists(_.isEmpty) =>
Left(s"Both vault name and secret name must be specified for access token strategy $Azure")
case Some(Google) if c.azureIdentityClientId.nonEmpty =>
Left(s"Identity client id is only valid with access token strategy '$Azure'")
case Some(Azure) => Right(())
case Some(Google) if List(c.azureSecretName, c.azureVaultName, c.azureIdentityClientId).forall(_.isEmpty) => Right(())
case Some(Google) => Left(s"One or more specified options are only valid with access token strategy '$Azure'")
case Some(Google) => Right(())
case Some(huh) => Left(s"Unrecognized access token strategy '$huh'")
case None => Left("Programmer error, access token strategy should not be None")
}
Expand Down Expand Up @@ -100,8 +93,6 @@ case class CommandLineArguments(accessTokenStrategy: Option[String] = Option(Goo
drsObject: Option[String] = None,
containerPath: Option[String] = None,
googleRequesterPaysProject: Option[String] = None,
azureVaultName: Option[String] = None,
azureSecretName: Option[String] = None,
azureIdentityClientId: Option[String] = None,
manifestPath: Option[String] = None,
googleRequesterPaysProjectConflict: Boolean = false)
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package drs.localizer

import cloud.nio.impl.drs.{DrsConfig, DrsPathResolver}
import cloud.nio.impl.drs.{DrsConfig, DrsCredentials, DrsPathResolver}
import common.validation.ErrorOr.ErrorOr
import drs.localizer.accesstokens.AccessTokenStrategy


class DrsLocalizerDrsPathResolver(drsConfig: DrsConfig, accessTokenStrategy: AccessTokenStrategy) extends DrsPathResolver(drsConfig) {
override def getAccessToken: ErrorOr[String] = accessTokenStrategy.getAccessToken()
class DrsLocalizerDrsPathResolver(drsConfig: DrsConfig, drsCredentials: DrsCredentials) extends DrsPathResolver(drsConfig) {
override def getAccessToken: ErrorOr[String] = drsCredentials.getAccessToken
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ import cats.data.NonEmptyList
import cats.effect.{ExitCode, IO, IOApp}
import cats.implicits._
import cloud.nio.impl.drs.DrsPathResolver.{FatalRetryDisposition, RegularRetryDisposition}
import cloud.nio.impl.drs.{AccessUrl, DrsConfig, DrsPathResolver, MarthaField}
import cloud.nio.impl.drs._
import cloud.nio.spi.{CloudNioBackoff, CloudNioSimpleExponentialBackoff}
import com.typesafe.scalalogging.StrictLogging
import drs.localizer.CommandLineParser.AccessTokenStrategy.{Azure, Google}
import drs.localizer.accesstokens.{AccessTokenStrategy, AzureB2CAccessTokenStrategy, GoogleAccessTokenStrategy}
import drs.localizer.downloaders.AccessUrlDownloader.Hashes
import drs.localizer.downloaders._
import org.apache.commons.csv.{CSVFormat, CSVParser}
Expand All @@ -29,8 +28,8 @@ object DrsLocalizerMain extends IOApp with StrictLogging {
val localize: Option[IO[ExitCode]] = for {
pa <- parsedArgs
run <- pa.accessTokenStrategy.collect {
case Azure => runLocalizer(pa, AzureB2CAccessTokenStrategy(pa))
case Google => runLocalizer(pa, GoogleAccessTokenStrategy)
case Azure => runLocalizer(pa, AzureDrsCredentials(pa.azureIdentityClientId))
case Google => runLocalizer(pa, GoogleAppDefaultTokenStrategy)
}
} yield run

Expand All @@ -55,39 +54,39 @@ object DrsLocalizerMain extends IOApp with StrictLogging {
IO.pure(ExitCode.Error)
}

def runLocalizer(commandLineArguments: CommandLineArguments, accessTokenStrategy: AccessTokenStrategy): IO[ExitCode] = {
def runLocalizer(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials): IO[ExitCode] = {
commandLineArguments.manifestPath match {
case Some(manifestPath) =>
val manifestFile = new File(manifestPath)
val csvParser = CSVParser.parse(manifestFile, Charset.defaultCharset(), CSVFormat.DEFAULT)
val exitCodes: IO[List[ExitCode]] = csvParser.asScala.map(record => {
val drsObject = record.get(0)
val containerPath = record.get(1)
localizeFile(commandLineArguments, accessTokenStrategy, drsObject, containerPath)
localizeFile(commandLineArguments, drsCredentials, drsObject, containerPath)
}).toList.sequence
exitCodes.map(_.find(_ != ExitCode.Success).getOrElse(ExitCode.Success))
case None =>
val drsObject = commandLineArguments.drsObject.get
val containerPath = commandLineArguments.containerPath.get
localizeFile(commandLineArguments, accessTokenStrategy, drsObject, containerPath)
localizeFile(commandLineArguments, drsCredentials, drsObject, containerPath)
}
}

private def localizeFile(commandLineArguments: CommandLineArguments, accessTokenStrategy: AccessTokenStrategy, drsObject: String, containerPath: String) = {
new DrsLocalizerMain(drsObject, containerPath, accessTokenStrategy, commandLineArguments.googleRequesterPaysProject).
private def localizeFile(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials, drsObject: String, containerPath: String) = {
new DrsLocalizerMain(drsObject, containerPath, drsCredentials, commandLineArguments.googleRequesterPaysProject).
resolveAndDownloadWithRetries(downloadRetries = 3, checksumRetries = 1, defaultDownloaderFactory, Option(defaultBackoff)).map(_.exitCode)
}
}

class DrsLocalizerMain(drsUrl: String,
downloadLoc: String,
accessTokenStrategy: AccessTokenStrategy,
drsCredentials: DrsCredentials,
requesterPaysProjectIdOption: Option[String]) extends StrictLogging {

def getDrsPathResolver: IO[DrsLocalizerDrsPathResolver] = {
IO {
val drsConfig = DrsConfig.fromEnv(sys.env)
new DrsLocalizerDrsPathResolver(drsConfig, accessTokenStrategy)
new DrsLocalizerDrsPathResolver(drsConfig, drsCredentials)
}
}

Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

Loading

0 comments on commit e0c8dd0

Please sign in to comment.