Skip to content

Commit

Permalink
Collapse localizer's AccessTokenStrategy into DrsCredentials
Browse files Browse the repository at this point in the history
  • Loading branch information
jgainerdewar committed Nov 19, 2022
1 parent 57eb1d5 commit f34a54c
Show file tree
Hide file tree
Showing 14 changed files with 58 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import com.azure.core.credential.TokenRequestContext
import com.azure.core.management.AzureEnvironment
import com.azure.core.management.profile.AzureProfile
import com.azure.identity.DefaultAzureCredentialBuilder
import com.google.auth.oauth2.{AccessToken, OAuth2Credentials}
import com.google.auth.oauth2.{AccessToken, GoogleCredentials, OAuth2Credentials}
import com.typesafe.config.Config
import common.validation.ErrorOr.ErrorOr
import net.ceedubs.ficus.Ficus._
Expand All @@ -23,7 +23,11 @@ sealed trait DrsCredentials {
def getAccessToken: ErrorOr[String]
}

case class GoogleDrsCredentials(credentials: OAuth2Credentials, acceptableTTL: Duration) extends DrsCredentials {
/**
* Strategy for obtaining an access token from an existing OAuth credential. This class
* is designed for use within the Cromwell engine.
*/
case class GoogleOauthDrsCredentials(credentials: OAuth2Credentials, acceptableTTL: Duration) extends DrsCredentials {
//Based on method from GoogleRegistry
def getAccessToken: ErrorOr[String] = {
def accessTokenTTLIsAcceptable(accessToken: AccessToken): Boolean = {
Expand All @@ -43,11 +47,34 @@ case class GoogleDrsCredentials(credentials: OAuth2Credentials, acceptableTTL: D
}
}

object GoogleDrsCredentials {
def apply(credentials: OAuth2Credentials, config: Config): GoogleDrsCredentials =
GoogleDrsCredentials(credentials, config.as[FiniteDuration]("access-token-acceptable-ttl"))
object GoogleOauthDrsCredentials {
def apply(credentials: OAuth2Credentials, config: Config): GoogleOauthDrsCredentials =
GoogleOauthDrsCredentials(credentials, config.as[FiniteDuration]("access-token-acceptable-ttl"))
}


/**
* Strategy for obtaining an access token from Google Application Default credentials that are assumed to already exist
* in the environment. This class is designed for use by standalone executables running in environments
* that have direct access to a Google identity (ex. CromwellDrsLocalizer).
*/
case object GoogleAppDefaultTokenStrategy extends DrsCredentials {
private final val UserInfoEmailScope = "https://www.googleapis.com/auth/userinfo.email"
private final val UserInfoProfileScope = "https://www.googleapis.com/auth/userinfo.profile"

def getAccessToken: ErrorOr[String] = {
Try {
val scopedCredentials = GoogleCredentials.getApplicationDefault().createScoped(UserInfoEmailScope, UserInfoProfileScope)
scopedCredentials.refreshAccessToken().getTokenValue
} match {
case Success(null) => "null token value attempting to refresh access token".invalidNel
case Success(value) => value.validNel
case Failure(e) => s"Failed to refresh access token: ${e.getMessage}".invalidNel
}
}
}


case class AzureDrsCredentials(identityClientId: Option[String]) extends DrsCredentials {

final val tokenAcquisitionTimeout = new jDuration(30, TimeUnit.SECONDS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class DrsCloudNioFileProviderSpec extends AnyFlatSpecLike with CromwellTimeoutSp
val fileSystemProvider = new MockDrsCloudNioFileSystemProvider(config = config)
fileSystemProvider.drsConfig.marthaUrl should be("https://from.config")
fileSystemProvider.drsCredentials match {
case GoogleDrsCredentials(_, ttl) => ttl should be(1.minute)
case GoogleOauthDrsCredentials(_, ttl) => ttl should be(1.minute)
case error => fail(s"Expected GoogleDrsCredentials, found $error")
}
fileSystemProvider.fileProvider should be(a[DrsCloudNioFileProvider])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class MockDrsCloudNioFileSystemProvider(config: Config = mockConfig,
),
mockResolver: Option[EngineDrsPathResolver] = None,
)
extends DrsCloudNioFileSystemProvider(config, GoogleDrsCredentials(NoCredentials.getInstance, config), drsReadInterpreter) {
extends DrsCloudNioFileSystemProvider(config, GoogleOauthDrsCredentials(NoCredentials.getInstance, config), drsReadInterpreter) {

override lazy val drsPathResolver: EngineDrsPathResolver = {
mockResolver getOrElse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class MockEngineDrsPathResolver(drsConfig: DrsConfig = MockDrsPaths.mockDrsConfi
httpClientBuilderOverride: Option[HttpClientBuilder] = None,
accessTokenAcceptableTTL: Duration = Duration.Inf,
)
extends EngineDrsPathResolver(drsConfig, GoogleDrsCredentials(NoCredentials.getInstance, accessTokenAcceptableTTL)) {
extends EngineDrsPathResolver(drsConfig, GoogleOauthDrsCredentials(NoCredentials.getInstance, accessTokenAcceptableTTL)) {

override protected lazy val httpClientBuilder: HttpClientBuilder =
httpClientBuilderOverride getOrElse MockSugar.mock[HttpClientBuilder]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package drs.localizer

import cloud.nio.impl.drs.{DrsConfig, DrsPathResolver}
import cloud.nio.impl.drs.{DrsConfig, DrsCredentials, DrsPathResolver}
import common.validation.ErrorOr.ErrorOr
import drs.localizer.accesstokens.AccessTokenStrategy


class DrsLocalizerDrsPathResolver(drsConfig: DrsConfig, accessTokenStrategy: AccessTokenStrategy) extends DrsPathResolver(drsConfig) {
override def getAccessToken: ErrorOr[String] = accessTokenStrategy.getAccessToken()
class DrsLocalizerDrsPathResolver(drsConfig: DrsConfig, drsCredentials: DrsCredentials) extends DrsPathResolver(drsConfig) {
override def getAccessToken: ErrorOr[String] = drsCredentials.getAccessToken
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ import cats.data.NonEmptyList
import cats.effect.{ExitCode, IO, IOApp}
import cats.implicits._
import cloud.nio.impl.drs.DrsPathResolver.{FatalRetryDisposition, RegularRetryDisposition}
import cloud.nio.impl.drs.{AccessUrl, DrsConfig, DrsPathResolver, MarthaField}
import cloud.nio.impl.drs._
import cloud.nio.spi.{CloudNioBackoff, CloudNioSimpleExponentialBackoff}
import com.typesafe.scalalogging.StrictLogging
import drs.localizer.CommandLineParser.AccessTokenStrategy.{Azure, Google}
import drs.localizer.accesstokens.{AccessTokenStrategy, AzureAccessTokenStrategy, GoogleAccessTokenStrategy}
import drs.localizer.downloaders.AccessUrlDownloader.Hashes
import drs.localizer.downloaders._
import org.apache.commons.csv.{CSVFormat, CSVParser}
Expand All @@ -29,8 +28,8 @@ object DrsLocalizerMain extends IOApp with StrictLogging {
val localize: Option[IO[ExitCode]] = for {
pa <- parsedArgs
run <- pa.accessTokenStrategy.collect {
case Azure => runLocalizer(pa, AzureAccessTokenStrategy(pa))
case Google => runLocalizer(pa, GoogleAccessTokenStrategy)
case Azure => runLocalizer(pa, AzureDrsCredentials(pa.azureIdentityClientId))
case Google => runLocalizer(pa, GoogleAppDefaultTokenStrategy)
}
} yield run

Expand All @@ -55,39 +54,39 @@ object DrsLocalizerMain extends IOApp with StrictLogging {
IO.pure(ExitCode.Error)
}

def runLocalizer(commandLineArguments: CommandLineArguments, accessTokenStrategy: AccessTokenStrategy): IO[ExitCode] = {
def runLocalizer(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials): IO[ExitCode] = {
commandLineArguments.manifestPath match {
case Some(manifestPath) =>
val manifestFile = new File(manifestPath)
val csvParser = CSVParser.parse(manifestFile, Charset.defaultCharset(), CSVFormat.DEFAULT)
val exitCodes: IO[List[ExitCode]] = csvParser.asScala.map(record => {
val drsObject = record.get(0)
val containerPath = record.get(1)
localizeFile(commandLineArguments, accessTokenStrategy, drsObject, containerPath)
localizeFile(commandLineArguments, drsCredentials, drsObject, containerPath)
}).toList.sequence
exitCodes.map(_.find(_ != ExitCode.Success).getOrElse(ExitCode.Success))
case None =>
val drsObject = commandLineArguments.drsObject.get
val containerPath = commandLineArguments.containerPath.get
localizeFile(commandLineArguments, accessTokenStrategy, drsObject, containerPath)
localizeFile(commandLineArguments, drsCredentials, drsObject, containerPath)
}
}

private def localizeFile(commandLineArguments: CommandLineArguments, accessTokenStrategy: AccessTokenStrategy, drsObject: String, containerPath: String) = {
new DrsLocalizerMain(drsObject, containerPath, accessTokenStrategy, commandLineArguments.googleRequesterPaysProject).
private def localizeFile(commandLineArguments: CommandLineArguments, drsCredentials: DrsCredentials, drsObject: String, containerPath: String) = {
new DrsLocalizerMain(drsObject, containerPath, drsCredentials, commandLineArguments.googleRequesterPaysProject).
resolveAndDownloadWithRetries(downloadRetries = 3, checksumRetries = 1, defaultDownloaderFactory, Option(defaultBackoff)).map(_.exitCode)
}
}

class DrsLocalizerMain(drsUrl: String,
downloadLoc: String,
accessTokenStrategy: AccessTokenStrategy,
drsCredentials: DrsCredentials,
requesterPaysProjectIdOption: Option[String]) extends StrictLogging {

def getDrsPathResolver: IO[DrsLocalizerDrsPathResolver] = {
IO {
val drsConfig = DrsConfig.fromEnv(sys.env)
new DrsLocalizerDrsPathResolver(drsConfig, accessTokenStrategy)
new DrsLocalizerDrsPathResolver(drsConfig, drsCredentials)
}
}

Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import cloud.nio.impl.drs.DrsPathResolver.FatalRetryDisposition
import cloud.nio.impl.drs.{AccessUrl, DrsConfig, MarthaField, MarthaResponse}
import common.assertion.CromwellTimeoutSpec
import drs.localizer.MockDrsLocalizerDrsPathResolver.{FakeAccessTokenStrategy, FakeHashes}
import drs.localizer.accesstokens.AccessTokenStrategy
import drs.localizer.accesstokens.DrsCredentials
import drs.localizer.downloaders.AccessUrlDownloader.Hashes
import drs.localizer.downloaders._
import org.scalatest.flatspec.AnyFlatSpec
Expand Down Expand Up @@ -341,5 +341,5 @@ class MockDrsLocalizerDrsPathResolver(drsConfig: DrsConfig) extends

object MockDrsLocalizerDrsPathResolver {
val FakeHashes: Option[Map[String, String]] = Option(Map("md5" -> "abc123", "crc32c" -> "34fd67"))
val FakeAccessTokenStrategy: AccessTokenStrategy = () => "testing code: do not call me".invalidNel
val FakeAccessTokenStrategy: DrsCredentials = () => "testing code: do not call me".invalidNel
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package cromwell.filesystems.drs

import akka.actor.ActorSystem
import cats.data.Validated.{Invalid, Valid}
import cloud.nio.impl.drs.{AzureDrsCredentials, DrsCloudNioFileSystemProvider, GoogleDrsCredentials}
import cloud.nio.impl.drs.{AzureDrsCredentials, DrsCloudNioFileSystemProvider, GoogleOauthDrsCredentials}
import com.google.api.services.oauth2.Oauth2Scopes
import com.typesafe.config.Config
import cromwell.cloudsupport.gcp.GoogleConfiguration
Expand Down Expand Up @@ -39,7 +39,7 @@ class DrsPathBuilderFactory(globalConfig: Config, instanceConfig: Config, single
case googleAuthScheme => googleConfiguration.auth(googleAuthScheme) match {
case Valid(auth) => (
Option(auth),
GoogleDrsCredentials(auth.credentials(options.get(_).get, marthaScopes), singletonConfig.config)
GoogleOauthDrsCredentials(auth.credentials(options.get(_).get, marthaScopes), singletonConfig.config)
)
case Invalid(error) => throw new RuntimeException(s"Error while instantiating DRS path builder factory. Errors: ${error.toString}")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package cromwell.filesystems.drs

import cloud.nio.impl.drs.DrsCloudNioFileProvider.DrsReadInterpreter
import cloud.nio.impl.drs.{DrsCloudNioFileSystemProvider, GoogleDrsCredentials}
import cloud.nio.impl.drs.{DrsCloudNioFileSystemProvider, GoogleOauthDrsCredentials}
import com.google.cloud.NoCredentials
import com.typesafe.config.{Config, ConfigFactory}
import cromwell.core.TestKitSuite
Expand Down Expand Up @@ -387,7 +387,7 @@ class DrsPathBuilderSpec extends TestKitSuite with AnyFlatSpecLike with Matchers
private lazy val fakeCredentials = NoCredentials.getInstance

private lazy val drsPathBuilder = DrsPathBuilder(
new DrsCloudNioFileSystemProvider(marthaConfig, GoogleDrsCredentials(fakeCredentials, 1.minutes), drsReadInterpreter),
new DrsCloudNioFileSystemProvider(marthaConfig, GoogleOauthDrsCredentials(fakeCredentials, 1.minutes), drsReadInterpreter),
None,
)
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package cromwell.backend.google.pipelines.v2alpha1

import cloud.nio.impl.drs.DrsCloudNioFileProvider.DrsReadInterpreter
import cloud.nio.impl.drs.{DrsCloudNioFileSystemProvider, GoogleDrsCredentials}
import cloud.nio.impl.drs.{DrsCloudNioFileSystemProvider, GoogleOauthDrsCredentials}
import com.google.cloud.NoCredentials
import com.typesafe.config.{Config, ConfigFactory}
import common.assertion.CromwellTimeoutSpec
Expand Down Expand Up @@ -39,7 +39,7 @@ class PipelinesConversionsSpec extends AnyFlatSpec with CromwellTimeoutSpec with
it should "create a DRS input parameter" in {

val drsPathBuilder = DrsPathBuilder(
new DrsCloudNioFileSystemProvider(marthaConfig, GoogleDrsCredentials(fakeCredentials, 1.minutes), drsReadInterpreter),
new DrsCloudNioFileSystemProvider(marthaConfig, GoogleOauthDrsCredentials(fakeCredentials, 1.minutes), drsReadInterpreter),
None,
)
val drsPath = drsPathBuilder.build("drs://drs.example.org/aaaabbbb-cccc-dddd-eeee-abcd0000dcba").get
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package cromwell.backend.google.pipelines.v2beta
import java.nio.file.Paths
import cats.data.NonEmptyList
import cloud.nio.impl.drs.DrsCloudNioFileProvider.DrsReadInterpreter
import cloud.nio.impl.drs.{DrsCloudNioFileSystemProvider, GoogleDrsCredentials}
import cloud.nio.impl.drs.{DrsCloudNioFileSystemProvider, GoogleOauthDrsCredentials}
import com.google.cloud.NoCredentials
import com.typesafe.config.{Config, ConfigFactory}
import common.assertion.CromwellTimeoutSpec
Expand Down Expand Up @@ -68,7 +68,7 @@ class PipelinesApiAsyncBackendJobExecutionActorSpec extends AnyFlatSpec with Cro
throw new UnsupportedOperationException("PipelinesApiAsyncBackendJobExecutionActorSpec doesn't need to use drs read interpreter.")

DrsPathBuilder(
new DrsCloudNioFileSystemProvider(marthaConfig, GoogleDrsCredentials(fakeCredentials, 1.minutes), drsReadInterpreter),
new DrsCloudNioFileSystemProvider(marthaConfig, GoogleOauthDrsCredentials(fakeCredentials, 1.minutes), drsReadInterpreter),
None,
)
}
Expand Down

0 comments on commit f34a54c

Please sign in to comment.