Skip to content

Commit

Permalink
Refactor of blobfilesystemManager and tests covering its functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
kraefrei committed Aug 31, 2022
1 parent 8f00cc1 commit a0e1606
Show file tree
Hide file tree
Showing 4 changed files with 296 additions and 142 deletions.
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
package cromwell.filesystems.blob

import com.azure.core.credential.AzureSasCredential
import com.google.common.net.UrlEscapers
import cromwell.core.path.{NioPath, Path, PathBuilder}
import cromwell.filesystems.blob.BlobPathBuilder._

import java.net.{MalformedURLException, URI}
import java.time.Instant
import java.time.temporal.TemporalAmount
import scala.language.postfixOps
import scala.util.{Failure, Try}
import scala.util.{Failure, Success, Try}

object BlobPathBuilder {

sealed trait BlobPathValidation
case class ValidBlobPath(path: String) extends BlobPathValidation
case class UnparsableBlobPath(errorMessage: Throwable) extends BlobPathValidation

def invalidBlobPathMessage(container: String, endpoint: String) = s"Malformed Blob URL for this builder. Expecting a URL for a container $container and endpoint $endpoint"
def invalidBlobPathMessage(container: BlobContainerName, endpoint: EndpointURL) = s"Malformed Blob URL for this builder. Expecting a URL for a container $container and endpoint $endpoint"
def parseURI(string: String): URI = URI.create(UrlEscapers.urlFragmentEscaper().escape(string))
def parseStorageAccount(uri: URI): Option[String] = uri.getHost.split("\\.").find(_.nonEmpty)
def parseStorageAccount(uri: URI): Try[StorageAccountName] = uri.getHost.split("\\.").find(_.nonEmpty).map(StorageAccountName(_)).fold[Try[StorageAccountName]](Failure(new Exception("bad")))(Success(_))

/**
* Validates a that a path from a string is a valid BlobPath of the format:
Expand All @@ -39,13 +36,13 @@ object BlobPathBuilder {
*
* If the configured container and storage account do not match, the string is considered unparsable
*/
def validateBlobPath(string: String, container: String, endpoint: String): BlobPathValidation = {
def validateBlobPath(string: String, container: BlobContainerName, endpoint: EndpointURL): BlobPathValidation = {
Try {
val uri = parseURI(string)
val storageAccount = parseStorageAccount(parseURI(endpoint))
val hasContainer = uri.getPath.split("/").find(_.nonEmpty).contains(container)
val hasEndpoint = storageAccount.exists(parseStorageAccount(uri).contains(_))
if (hasContainer && storageAccount.isDefined && hasEndpoint) {
val storageAccount = parseStorageAccount(parseURI(endpoint.value))
val hasContainer = uri.getPath.split("/").find(_.nonEmpty).contains(container.value)
val hasEndpoint = storageAccount.toOption.exists(parseStorageAccount(uri).toOption.contains(_))
if (hasContainer && storageAccount.isSuccess && hasEndpoint) {
ValidBlobPath(uri.getPath.replaceFirst("/" + container, ""))
} else {
UnparsableBlobPath(new MalformedURLException(invalidBlobPathMessage(container, endpoint)))
Expand All @@ -54,39 +51,31 @@ object BlobPathBuilder {
}
}

class BlobPathBuilder(fsm: FileSystemManager, container: String, endpoint: String) extends PathBuilder {
class BlobPathBuilder(container: BlobContainerName, endpoint: EndpointURL)(private val fsm: BlobFileSystemManager) extends PathBuilder {

def build(string: String): Try[BlobPath] = {
validateBlobPath(string, container, endpoint) match {
case ValidBlobPath(path) => Try(BlobPath(path, endpoint, container, fsm))
case ValidBlobPath(path) => Try(BlobPath(path, endpoint, container)(fsm))
case UnparsableBlobPath(errorMessage: Throwable) => Failure(errorMessage)
}
}
override def name: String = "Azure Blob Storage"
}

case class BlobPath private[blob](pathString: String, endpoint: String, container: String, fsm: FileSystemManager) extends Path {
//var token = blobTokenGenerator.getAccessToken
//var expiry = token.getSignature.split("&").filter(_.startsWith("se")).headOption.map(_.replaceFirst("se=",""))
override def nioPath: NioPath = findNioPath(path = pathString, endpoint, container)
case class BlobPath private[blob](pathString: String, endpoint: EndpointURL, container: BlobContainerName)(private val fsm: BlobFileSystemManager) extends Path {
override def nioPath: NioPath = findNioPath(pathString)

override protected def newPath(nioPath: NioPath): Path = BlobPath(nioPath.toString, endpoint, container, fsm)
override protected def newPath(nioPath: NioPath): Path = BlobPath(nioPath.toString, endpoint, container)(fsm)

override def pathAsString: String = List(endpoint, container, nioPath.toString).mkString("/")

override def pathWithoutScheme: String = parseURI(endpoint).getHost + "/" + container + "/" + nioPath.toString
override def pathWithoutScheme: String = parseURI(endpoint.value).getHost + "/" + container + "/" + nioPath.toString

def findNioPath(path: String, endpoint: String, container: String): NioPath = (for {
private def findNioPath(path: String): NioPath = (for {
fileSystem <- fsm.retrieveFilesystem()
nioPath = fileSystem.getPath(path)
} yield nioPath).get
}

case class TokenExpiration(token: AzureSasCredential, buffer: TemporalAmount) {
val expiry = for {
expiryString <- token.getSignature.split("&").find(_.startsWith("se")).map(_.replaceFirst("se=","")).map(_.replace("%3A", ":"))
instant = Instant.parse(expiryString)
} yield instant

def hasTokenExpired: Boolean = expiry.exists(_.isAfter(Instant.now.plus(buffer)))
} yield nioPath) match {
case Success(value) => value
case Failure(exception) => throw exception
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import com.azure.core.management.AzureEnvironment
import com.azure.core.management.profile.AzureProfile
import com.azure.identity.DefaultAzureCredentialBuilder
import com.azure.resourcemanager.AzureResourceManager
import com.azure.storage.blob.BlobContainerClientBuilder
import com.azure.resourcemanager.storage.models.{StorageAccount, StorageAccountKey}
import com.azure.storage.blob.nio.AzureFileSystem
import com.azure.storage.blob.sas.{BlobContainerSasPermission, BlobServiceSasSignatureValues}
import com.azure.storage.blob.{BlobContainerClient, BlobContainerClientBuilder}
import com.azure.storage.common.StorageSharedKeyCredential
import com.typesafe.config.Config
import cromwell.core.WorkflowOptions
Expand All @@ -18,80 +19,98 @@ import net.ceedubs.ficus.Ficus._
import java.net.URI
import java.nio.file.{FileSystem, FileSystemNotFoundException, FileSystems}
import java.time.temporal.ChronoUnit
import java.time.{Duration, OffsetDateTime}
import java.time.{Duration, Instant, OffsetDateTime}
import scala.concurrent.{ExecutionContext, Future}
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Try}
import scala.util.{Failure, Success, Try}

final case class BlobFileSystemConfig(config: Config)
final case class BlobPathBuilderFactory(globalConfig: Config, instanceConfig: Config, singletonConfig: BlobFileSystemConfig) extends PathBuilderFactory {
val sasToken: String = instanceConfig.as[String]("sas-token")
val container: String = instanceConfig.as[String]("store")
val endpoint: String = instanceConfig.as[String]("endpoint")
val workspaceId: Option[String] = instanceConfig.as[Option[String]]("workspace-id")
val workspaceManagerURL: Option[String] = singletonConfig.config.as[Option[String]]("workspace-manager-url")
val container = BlobContainerName(instanceConfig.as[String]("store"))
val endpoint = EndpointURL(instanceConfig.as[String]("endpoint"))
val workspaceId: Option[WorkspaceId] = instanceConfig.as[Option[String]]("workspace-id").map(WorkspaceId(_))
val expiryBufferMinutes: Long = instanceConfig.as[Option[Long]]("expiry-buffer-minutes").getOrElse(10)
val workspaceManagerURL: Option[WorkspaceManagerURL] = singletonConfig.config.as[Option[String]]("workspace-manager-url").map(WorkspaceManagerURL(_))

val fsm = FileSystemManager(container, endpoint, 10, workspaceId, workspaceManagerURL)
val blobTokenGenerator: BlobTokenGenerator = BlobTokenGenerator.createBlobTokenGenerator(
container, endpoint, workspaceId, workspaceManagerURL)
val fsm = BlobFileSystemManager(container, endpoint, expiryBufferMinutes, blobTokenGenerator)

override def withOptions(options: WorkflowOptions)(implicit as: ActorSystem, ec: ExecutionContext): Future[BlobPathBuilder] = {
Future {
new BlobPathBuilder(fsm, container, endpoint)
new BlobPathBuilder(container, endpoint)(fsm)
}
}
}

case class FileSystemManager(container: String,
endpoint: String,
preemptionMinutes: Long,
workspaceId: Option[String] = None,
workspaceManagerURL: Option[String] = None) {
final case class BlobContainerName(value: String) {override def toString: String = value}
final case class StorageAccountName(value: String) {override def toString: String = value}
final case class EndpointURL(value: String) {override def toString: String = value}
final case class WorkspaceId(value: String) {override def toString: String = value}
final case class WorkspaceManagerURL(value: String) {override def toString: String = value}

var expiry: Option[TokenExpiration] = None
val blobTokenGenerator: BlobTokenGenerator = BlobTokenGenerator.createBlobTokenGenerator(
container, endpoint, workspaceId, workspaceManagerURL)
object BlobFileSystemManager {
def parseTokenExpiry(token: AzureSasCredential): Option[Instant] = for {
expiryString <- token.getSignature.split("&").find(_.startsWith("se")).map(_.replaceFirst("se=","")).map(_.replace("%3A", ":"))
instant = Instant.parse(expiryString)
} yield instant

def buildConfigMap(credential: AzureSasCredential, container: String): Map[String, Object] = {
def buildConfigMap(credential: AzureSasCredential, container: BlobContainerName): Map[String, Object] = {
Map((AzureFileSystem.AZURE_STORAGE_SAS_TOKEN_CREDENTIAL, credential),
(AzureFileSystem.AZURE_STORAGE_FILE_STORES, container),
(AzureFileSystem.AZURE_STORAGE_FILE_STORES, container.value),
(AzureFileSystem.AZURE_STORAGE_SKIP_INITIAL_CONTAINER_CHECK, java.lang.Boolean.TRUE))
}

def uri = new URI("azb://?endpoint=" + endpoint)

def hasTokenExpired(tokenExpiry: Instant, buffer: Duration) = Instant.now.plus(buffer).isAfter(tokenExpiry)
def uri(endpoint: EndpointURL) = new URI("azb://?endpoint=" + endpoint)
}
case class BlobFileSystemManager(container: BlobContainerName,
endpoint: EndpointURL,
expiryBufferMinutes: Long,
blobTokenGenerator: BlobTokenGenerator,
fileSystemAPI: FileSystemAPI = FileSystemAPI(),
initialExpiration: Option[Instant] = None) {
private var expiry: Option[Instant] = initialExpiration
val buffer: Duration = Duration.of(expiryBufferMinutes, ChronoUnit.MINUTES)

def getExpiry() = expiry
def uri = BlobFileSystemManager.uri(endpoint)
def hasTokenExpired: Boolean = expiry.exists(BlobFileSystemManager.hasTokenExpired(_, buffer))
def retrieveFilesystem(): Try[FileSystem] = {
synchronized {
expiry.map(_.hasTokenExpired) match {
case Some(false) => Try(FileSystems.getFileSystem(uri)) recoverWith {
(hasTokenExpired, expiry) match {
case (false, Some(_)) => fileSystemAPI.getFileSystem(uri) recoverWith {
// If no filesystem already exists, this will create a new connection, with the provided configs
case _: FileSystemNotFoundException => blobTokenGenerator.generateAccessToken.flatMap(generateFilesystem(uri, container, _))
}
// If the token has expired, OR there is no token record, try to close the FS and regenerate
case _ => {
closeFileSystem(uri)
blobTokenGenerator.generateAccessToken.flatMap(generateFilesystem(uri, container, _))
closeFileSystem(uri)
blobTokenGenerator.generateAccessToken.flatMap(generateFilesystem(uri, container, _))
}
}
}
}

def generateFilesystem(uri: URI, container: String, token: AzureSasCredential): Try[FileSystem] = {
expiry = Some(TokenExpiration(token, Duration.of(preemptionMinutes, ChronoUnit.MINUTES)))
Try(FileSystems.newFileSystem(uri, buildConfigMap(token, container).asJava))
private def generateFilesystem(uri: URI, container: BlobContainerName, token: AzureSasCredential): Try[FileSystem] = {
expiry = BlobFileSystemManager.parseTokenExpiry(token)
Try(fileSystemAPI.newFileSystem(uri, BlobFileSystemManager.buildConfigMap(token, container)))
}

def closeFileSystem(uri: URI): Try[Unit] = Try(FileSystems.getFileSystem(uri)).map(_.close)
private def closeFileSystem(uri: URI): Try[Unit] = fileSystemAPI.getFileSystem(uri).map(_.close)
}

sealed trait BlobTokenGenerator {
def generateAccessToken: Try[AzureSasCredential]
case class FileSystemAPI() {
def getFileSystem(uri: URI): Try[FileSystem] = Try(FileSystems.getFileSystem(uri))
def newFileSystem(uri: URI, config: Map[String, Object]): FileSystem = FileSystems.newFileSystem(uri, config.asJava)
}

sealed trait BlobTokenGenerator {def generateAccessToken: Try[AzureSasCredential]}
object BlobTokenGenerator {
def createBlobTokenGenerator(container: String, endpoint: String): BlobTokenGenerator = {
def createBlobTokenGenerator(container: BlobContainerName, endpoint: EndpointURL): BlobTokenGenerator = {
createBlobTokenGenerator(container, endpoint, None, None)
}
def createBlobTokenGenerator(container: String, endpoint: String, workspaceId: Option[String], workspaceManagerURL: Option[String]): BlobTokenGenerator = {
(container: String, endpoint: String, workspaceId, workspaceManagerURL) match {
def createBlobTokenGenerator(container: BlobContainerName, endpoint: EndpointURL, workspaceId: Option[WorkspaceId], workspaceManagerURL: Option[WorkspaceManagerURL]): BlobTokenGenerator = {
(container: BlobContainerName, endpoint: EndpointURL, workspaceId, workspaceManagerURL) match {
case (container, endpoint, None, None) =>
NativeBlobTokenGenerator(container, endpoint)
case (container, endpoint, Some(workspaceId), Some(workspaceManagerURL)) =>
Expand All @@ -102,58 +121,43 @@ object BlobTokenGenerator {
}
}

case class WSMBlobTokenGenerator(container: String, endpoint: String, workspaceId: String, workspaceManagerURL: String) extends BlobTokenGenerator {
case class WSMBlobTokenGenerator(container: BlobContainerName, endpoint: EndpointURL, workspaceId: WorkspaceId, workspaceManagerURL: WorkspaceManagerURL) extends BlobTokenGenerator {
def generateAccessToken: Try[AzureSasCredential] = Failure(new NotImplementedError)
}

case class NativeBlobTokenGenerator(container: String, endpoint: String) extends BlobTokenGenerator {
def generateAccessToken: Try[AzureSasCredential] = {
val storageAccountName = BlobPathBuilder.parseStorageAccount(BlobPathBuilder.parseURI(endpoint)) match {
case Some(storageAccountName) => storageAccountName
case _ => throw new Exception("Storage account could not be parsed from endpoint")
}
case class NativeBlobTokenGenerator(container: BlobContainerName, endpoint: EndpointURL) extends BlobTokenGenerator {

val profile = new AzureProfile(AzureEnvironment.AZURE)
val azureCredential = new DefaultAzureCredentialBuilder()
.authorityHost(profile.getEnvironment.getActiveDirectoryEndpoint)
private val azureProfile = new AzureProfile(AzureEnvironment.AZURE)
private def azureCredentialBuilder = new DefaultAzureCredentialBuilder()
.authorityHost(azureProfile.getEnvironment.getActiveDirectoryEndpoint)
.build
val azure = AzureResourceManager.authenticate(azureCredential, profile).withDefaultSubscription()

val storageAccounts = azure.storageAccounts()
val storageAccount = storageAccounts
.list()
.asScala
.find(_.name == storageAccountName)

val storageAccountKeys = storageAccount match {
case Some(value) => value.getKeys.asScala.map(_.value())
case _ => throw new Exception("Storage Account not found")
}

val storageAccountKey = storageAccountKeys.headOption match {
case Some(value) => value
case _ => throw new Exception("Storage Account has no keys")
}

val keyCredential = new StorageSharedKeyCredential(
storageAccountName,
storageAccountKey
)
val blobContainerClient = new BlobContainerClientBuilder()
.credential(keyCredential)
.endpoint(endpoint)
.containerName(container)
.buildClient()

val blobContainerSasPermission = new BlobContainerSasPermission()
.setReadPermission(true)
.setCreatePermission(true)
.setListPermission(true)
val blobServiceSasSignatureValues = new BlobServiceSasSignatureValues(
OffsetDateTime.now.plusDays(1),
blobContainerSasPermission
)

Try(new AzureSasCredential(blobContainerClient.generateSas(blobServiceSasSignatureValues)))
private def azure = AzureResourceManager.authenticate(azureCredentialBuilder, azureProfile).withSubscription("62b22893-6bc1-46d9-8a90-806bb3cce3c9")

private def findAzureStorageAccount(name: StorageAccountName) = azure.storageAccounts.list.asScala.find(_.name.equals(name.value))
.fold[Try[StorageAccount]](Failure(new Exception("Azure Storage Account not found")))(Success(_))
private def buildBlobContainerClient(credential: StorageSharedKeyCredential, endpoint: EndpointURL, container: BlobContainerName): BlobContainerClient = {
new BlobContainerClientBuilder()
.credential(credential)
.endpoint(endpoint.value)
.containerName(container.value)
.buildClient()
}
private val bcsp = new BlobContainerSasPermission()
.setReadPermission(true)
.setCreatePermission(true)
.setListPermission(true)


def generateAccessToken: Try[AzureSasCredential] = for {
configuredAccount <- BlobPathBuilder.parseStorageAccount(BlobPathBuilder.parseURI(endpoint.value))
azureAccount <- findAzureStorageAccount(configuredAccount)
keys = azureAccount.getKeys.asScala
key <- keys.headOption.fold[Try[StorageAccountKey]](Failure(new Exception("Storage account has no keys")))(Success(_))
first = key.value
sskc = new StorageSharedKeyCredential(configuredAccount.value, first)
bcc = buildBlobContainerClient(sskc, endpoint, container)
bsssv = new BlobServiceSasSignatureValues(OffsetDateTime.now.plusDays(1), bcsp)
asc = new AzureSasCredential(bcc.generateSas(bsssv))
} yield asc
}

Loading

0 comments on commit a0e1606

Please sign in to comment.