Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WX-833 Real Azure DRS Credentials #6952

Merged
merged 9 commits into from
Nov 22, 2022
Original file line number Diff line number Diff line change
@@ -1,25 +1,32 @@
package cloud.nio.impl.drs

import cats.syntax.validated._
import com.azure.core.credential.TokenRequestContext
import com.azure.core.management.AzureEnvironment
import com.azure.core.management.profile.AzureProfile
import com.azure.identity.DefaultAzureCredentialBuilder
import com.azure.security.keyvault.secrets.{SecretClient, SecretClientBuilder}
import com.google.auth.oauth2.{AccessToken, OAuth2Credentials}
import com.google.auth.oauth2.{AccessToken, GoogleCredentials, OAuth2Credentials}
import com.typesafe.config.Config
import common.validation.ErrorOr
import common.validation.ErrorOr.ErrorOr
import net.ceedubs.ficus.Ficus._

import scala.concurrent.duration._
import scala.jdk.DurationConverters._
import scala.util.{Failure, Success, Try}

/**
* This trait allows us to abstract away different token attainment strategies
* for different cloud environments.
**/
sealed trait DrsCredentials {
trait DrsCredentials {
def getAccessToken: ErrorOr[String]
}

case class GoogleDrsCredentials(credentials: OAuth2Credentials, acceptableTTL: Duration) extends DrsCredentials {
/**
* Strategy for obtaining an access token from an existing OAuth credential. This class
* is designed for use within the Cromwell engine.
*/
case class GoogleOauthDrsCredentials(credentials: OAuth2Credentials, acceptableTTL: Duration) extends DrsCredentials {
//Based on method from GoogleRegistry
def getAccessToken: ErrorOr[String] = {
def accessTokenTTLIsAcceptable(accessToken: AccessToken): Boolean = {
Expand All @@ -39,25 +46,68 @@ case class GoogleDrsCredentials(credentials: OAuth2Credentials, acceptableTTL: D
}
}

object GoogleDrsCredentials {
def apply(credentials: OAuth2Credentials, config: Config): GoogleDrsCredentials =
GoogleDrsCredentials(credentials, config.as[FiniteDuration]("access-token-acceptable-ttl"))
object GoogleOauthDrsCredentials {
def apply(credentials: OAuth2Credentials, config: Config): GoogleOauthDrsCredentials =
GoogleOauthDrsCredentials(credentials, config.as[FiniteDuration]("access-token-acceptable-ttl"))
}


/**
* Strategy for obtaining an access token from Google Application Default credentials that are assumed to already exist
* in the environment. This class is designed for use by standalone executables running in environments
* that have direct access to a Google identity (ex. CromwellDrsLocalizer).
*/
case object GoogleAppDefaultTokenStrategy extends DrsCredentials {
private final val UserInfoEmailScope = "https://www.googleapis.com/auth/userinfo.email"
private final val UserInfoProfileScope = "https://www.googleapis.com/auth/userinfo.profile"

def getAccessToken: ErrorOr[String] = {
Try {
val scopedCredentials = GoogleCredentials.getApplicationDefault().createScoped(UserInfoEmailScope, UserInfoProfileScope)
scopedCredentials.refreshAccessToken().getTokenValue
} match {
case Success(null) => "null token value attempting to refresh access token".invalidNel
case Success(value) => value.validNel
case Failure(e) => s"Failed to refresh access token: ${e.getMessage}".invalidNel
}
}
}

case class AzureDrsCredentials(identityClientId: Option[String], vaultName: String, secretName: String) extends DrsCredentials {
/**
* Strategy for obtaining an access token in an environment with available Azure identity.
* If you need to disambiguate among multiple active user-assigned managed identities, pass
* in the client id of the identity that should be used.
*/
case class AzureDrsCredentials(identityClientId: Option[String]) extends DrsCredentials {

final val tokenAcquisitionTimeout = 30.seconds

lazy val secretClient: ErrorOr[SecretClient] = ErrorOr {
val defaultCreds = identityClientId.map(identityId =>
new DefaultAzureCredentialBuilder().managedIdentityClientId(identityId)
).getOrElse(
new DefaultAzureCredentialBuilder()
).build()
val azureProfile = new AzureProfile(AzureEnvironment.AZURE)
val tokenScope = "https://management.azure.com/.default"

new SecretClientBuilder()
.vaultUrl(s"https://${vaultName}.vault.azure.net")
.credential(defaultCreds)
.buildClient()
def tokenRequestContext: TokenRequestContext = {
val trc = new TokenRequestContext()
trc.addScopes(tokenScope)
trc
}

def getAccessToken: ErrorOr[String] = secretClient.map(_.getSecret(secretName).getValue)
def defaultCredentialBuilder: DefaultAzureCredentialBuilder =
new DefaultAzureCredentialBuilder()
.authorityHost(azureProfile.getEnvironment.getActiveDirectoryEndpoint)

def getAccessToken: ErrorOr[String] = {
val credentials = identityClientId.foldLeft(defaultCredentialBuilder) {
(builder, clientId) => builder.managedIdentityClientId(clientId)
}.build()

Try(
credentials
.getToken(tokenRequestContext)
.block(tokenAcquisitionTimeout.toJava)
) match {
case Success(null) => "null token value attempting to obtain access token".invalidNel
case Success(token) => token.getToken.validNel
case Failure(error) => s"Failed to refresh access token: ${error.getMessage}".invalidNel
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class DrsCloudNioFileProviderSpec extends AnyFlatSpecLike with CromwellTimeoutSp
val fileSystemProvider = new MockDrsCloudNioFileSystemProvider(config = config)
fileSystemProvider.drsConfig.marthaUrl should be("https://from.config")
fileSystemProvider.drsCredentials match {
case GoogleDrsCredentials(_, ttl) => ttl should be(1.minute)
case GoogleOauthDrsCredentials(_, ttl) => ttl should be(1.minute)
case error => fail(s"Expected GoogleDrsCredentials, found $error")
}
fileSystemProvider.fileProvider should be(a[DrsCloudNioFileProvider])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class MockDrsCloudNioFileSystemProvider(config: Config = mockConfig,
),
mockResolver: Option[EngineDrsPathResolver] = None,
)
extends DrsCloudNioFileSystemProvider(config, GoogleDrsCredentials(NoCredentials.getInstance, config), drsReadInterpreter) {
extends DrsCloudNioFileSystemProvider(config, GoogleOauthDrsCredentials(NoCredentials.getInstance, config), drsReadInterpreter) {

override lazy val drsPathResolver: EngineDrsPathResolver = {
mockResolver getOrElse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class MockEngineDrsPathResolver(drsConfig: DrsConfig = MockDrsPaths.mockDrsConfi
httpClientBuilderOverride: Option[HttpClientBuilder] = None,
accessTokenAcceptableTTL: Duration = Duration.Inf,
)
extends EngineDrsPathResolver(drsConfig, GoogleDrsCredentials(NoCredentials.getInstance, accessTokenAcceptableTTL)) {
extends EngineDrsPathResolver(drsConfig, GoogleOauthDrsCredentials(NoCredentials.getInstance, accessTokenAcceptableTTL)) {

override protected lazy val httpClientBuilder: HttpClientBuilder =
httpClientBuilderOverride getOrElse MockSugar.mock[HttpClientBuilder]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@ class CommandLineParser extends scopt.OptionParser[CommandLineArguments](Usage)
opt[String]('t', "access-token-strategy").text(s"Access token strategy, must be one of '$Azure' or '$Google' (default '$Google')").
action((s, c) =>
c.copy(accessTokenStrategy = Option(s.toLowerCase())))
opt[String]('v', "vault-name").text("Azure vault name").
action((s, c) =>
c.copy(azureVaultName = Option(s)))
opt[String]('s', "secret-name").text("Azure secret name").
action((s, c) =>
c.copy(azureSecretName = Option(s)))
opt[String]('i', "identity-client-id").text("Azure identity client id").
action((s, c) =>
c.copy(azureIdentityClientId = Option(s)))
Expand All @@ -54,11 +48,10 @@ class CommandLineParser extends scopt.OptionParser[CommandLineArguments](Usage)
c.accessTokenStrategy match {
case Some(Azure) if c.googleRequesterPaysProject.nonEmpty =>
Left(s"Requester pays project is only valid with access token strategy '$Google'")
case Some(Azure) if List(c.azureVaultName, c.azureSecretName).exists(_.isEmpty) =>
Left(s"Both vault name and secret name must be specified for access token strategy $Azure")
case Some(Google) if c.azureIdentityClientId.nonEmpty =>
Left(s"Identity client id is only valid with access token strategy '$Azure'")
case Some(Azure) => Right(())
case Some(Google) if List(c.azureSecretName, c.azureVaultName, c.azureIdentityClientId).forall(_.isEmpty) => Right(())
case Some(Google) => Left(s"One or more specified options are only valid with access token strategy '$Azure'")
case Some(Google) => Right(())
case Some(huh) => Left(s"Unrecognized access token strategy '$huh'")
case None => Left("Programmer error, access token strategy should not be None")
}
Expand Down Expand Up @@ -100,8 +93,6 @@ case class CommandLineArguments(accessTokenStrategy: Option[String] = Option(Goo
drsObject: Option[String] = None,
containerPath: Option[String] = None,
googleRequesterPaysProject: Option[String] = None,
azureVaultName: Option[String] = None,
azureSecretName: Option[String] = None,
azureIdentityClientId: Option[String] = None,
manifestPath: Option[String] = None,
googleRequesterPaysProjectConflict: Boolean = false)
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package drs.localizer

import cloud.nio.impl.drs.{DrsConfig, DrsPathResolver}
import cloud.nio.impl.drs.{DrsConfig, DrsCredentials, DrsPathResolver}
import common.validation.ErrorOr.ErrorOr
import drs.localizer.accesstokens.AccessTokenStrategy


class DrsLocalizerDrsPathResolver(drsConfig: DrsConfig, accessTokenStrategy: AccessTokenStrategy) extends DrsPathResolver(drsConfig) {
override def getAccessToken: ErrorOr[String] = accessTokenStrategy.getAccessToken()
class DrsLocalizerDrsPathResolver(drsConfig: DrsConfig, drsCredentials: DrsCredentials) extends DrsPathResolver(drsConfig) {
override def getAccessToken: ErrorOr[String] = drsCredentials.getAccessToken
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ import cats.data.NonEmptyList
import cats.effect.{ExitCode, IO, IOApp}
import cats.implicits._
import cloud.nio.impl.drs.DrsPathResolver.{FatalRetryDisposition, RegularRetryDisposition}
import cloud.nio.impl.drs.{AccessUrl, DrsConfig, DrsPathResolver, MarthaField}
import cloud.nio.impl.drs._
import cloud.nio.spi.{CloudNioBackoff, CloudNioSimpleExponentialBackoff}
import com.typesafe.scalalogging.StrictLogging
import drs.localizer.CommandLineParser.AccessTokenStrategy.{Azure, Google}
import drs.localizer.accesstokens.{AccessTokenStrategy, AzureB2CAccessTokenStrategy, GoogleAccessTokenStrategy}
import drs.localizer.downloaders.AccessUrlDownloader.Hashes
import drs.localizer.downloaders._
import org.apache.commons.csv.{CSVFormat, CSVParser}
Expand All @@ -29,8 +28,8 @@ object DrsLocalizerMain extends IOApp with StrictLogging {
val localize: Option[IO[ExitCode]] = for {
pa <- parsedArgs
run <- pa.accessTokenStrategy.collect {
case Azure => runLocalizer(pa, AzureB2CAccessTokenStrategy(pa))
case Google => runLocalizer(pa, GoogleAccessTokenStrategy)
case Azure => runLocalizer(pa, AzureDrsCredentials(pa.azureIdentityClientId))
case Google => runLocalizer(pa, GoogleAppDefaultTokenStrategy)
}
} yield run

Expand All @@ -55,39 +54,39 @@ object DrsLocalizerMain extends IOApp with StrictLogging {
IO.pure(ExitCode.Error)
}

def runLocalizer(commandLineArguments: CommandLineArguments, accessTokenStrategy: AccessTokenStrategy): IO[ExitCode] = {
def runLocalizer(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials): IO[ExitCode] = {
commandLineArguments.manifestPath match {
case Some(manifestPath) =>
val manifestFile = new File(manifestPath)
val csvParser = CSVParser.parse(manifestFile, Charset.defaultCharset(), CSVFormat.DEFAULT)
val exitCodes: IO[List[ExitCode]] = csvParser.asScala.map(record => {
val drsObject = record.get(0)
val containerPath = record.get(1)
localizeFile(commandLineArguments, accessTokenStrategy, drsObject, containerPath)
localizeFile(commandLineArguments, drsCredentials, drsObject, containerPath)
}).toList.sequence
exitCodes.map(_.find(_ != ExitCode.Success).getOrElse(ExitCode.Success))
case None =>
val drsObject = commandLineArguments.drsObject.get
val containerPath = commandLineArguments.containerPath.get
localizeFile(commandLineArguments, accessTokenStrategy, drsObject, containerPath)
localizeFile(commandLineArguments, drsCredentials, drsObject, containerPath)
}
}

private def localizeFile(commandLineArguments: CommandLineArguments, accessTokenStrategy: AccessTokenStrategy, drsObject: String, containerPath: String) = {
new DrsLocalizerMain(drsObject, containerPath, accessTokenStrategy, commandLineArguments.googleRequesterPaysProject).
private def localizeFile(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials, drsObject: String, containerPath: String) = {
new DrsLocalizerMain(drsObject, containerPath, drsCredentials, commandLineArguments.googleRequesterPaysProject).
resolveAndDownloadWithRetries(downloadRetries = 3, checksumRetries = 1, defaultDownloaderFactory, Option(defaultBackoff)).map(_.exitCode)
}
}

class DrsLocalizerMain(drsUrl: String,
downloadLoc: String,
accessTokenStrategy: AccessTokenStrategy,
drsCredentials: DrsCredentials,
requesterPaysProjectIdOption: Option[String]) extends StrictLogging {

def getDrsPathResolver: IO[DrsLocalizerDrsPathResolver] = {
IO {
val drsConfig = DrsConfig.fromEnv(sys.env)
new DrsLocalizerDrsPathResolver(drsConfig, accessTokenStrategy)
new DrsLocalizerDrsPathResolver(drsConfig, drsCredentials)
}
}

Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

Loading