Skip to content

Commit

Permalink
Use our own cache for AWS tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
mwylde committed Jan 8, 2025
1 parent 6cf294c commit b581ce5
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 47 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ datafusion-functions-window = {git = 'https://github.com/ArroyoSystems/arrow-dat

datafusion-functions-json = {git = 'https://github.com/ArroyoSystems/datafusion-functions-json', branch = 'datafusion_43'}

# object_store = { git = 'http://github.com/ArroyoSystems/arrow-rs', branch = 'object_store_0.11.1/arroyo' }
object_store = { git = 'http://github.com/ArroyoSystems/arrow-rs', branch = 'public_token_cache' }
object_store = { git = 'http://github.com/ArroyoSystems/arrow-rs', branch = 'object_store_0.11.1/arroyo' }

cornucopia_async = { git = "https://github.com/ArroyoSystems/cornucopia", branch = "sqlite" }
cornucopia = { git = "https://github.com/ArroyoSystems/cornucopia", branch = "sqlite" }
Expand Down
7 changes: 1 addition & 6 deletions crates/arroyo-connectors/src/filesystem/sink/delta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,9 @@ pub(crate) async fn load_or_create_table(
};

let mut delta = DeltaTableBuilder::from_uri(&url)
.with_storage_backend(
backing_store,
Url::parse(&storage_provider.canonical_url())?,
)
.with_storage_backend(backing_store, Url::parse(storage_provider.canonical_url())?)
.build()?;

println!("Table uri = {}", delta.table_uri());

if delta.verify_deltatable_existence().await? {
delta.load().await?;
Ok(delta)
Expand Down
137 changes: 99 additions & 38 deletions crates/arroyo-storage/src/aws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,22 @@ use crate::StorageError;
use aws_config::timeout::TimeoutConfig;
use aws_config::{BehaviorVersion, SdkConfig};
use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
use object_store::{aws::AwsCredential, CredentialProvider, TemporaryToken, TokenCache};
use std::error::Error;
use std::sync::Arc;
use std::time::{Duration, Instant};
use object_store::{aws::AwsCredential, CredentialProvider};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant, SystemTime};
use tokio::sync::OnceCell;
use tokio::task::JoinHandle;
use tracing::info;

const EXPIRATION_BUFFER: Duration = Duration::from_secs(5 * 60);

type TemporaryToken = (Arc<AwsCredential>, Option<SystemTime>, Instant);

#[derive(Clone)]
pub struct ArroyoCredentialProvider {
cache: TokenCache<Arc<AwsCredential>>,
cache: Arc<tokio::sync::Mutex<Option<TemporaryToken>>>,
provider: SharedCredentialsProvider,
refresh_task: Arc<Mutex<Option<JoinHandle<()>>>>,
}

impl std::fmt::Debug for ArroyoCredentialProvider {
Expand All @@ -21,9 +27,10 @@ impl std::fmt::Debug for ArroyoCredentialProvider {
}

static AWS_CONFIG: OnceCell<Arc<SdkConfig>> = OnceCell::const_new();
static CREDENTIAL_PROVIDER: OnceCell<ArroyoCredentialProvider> = OnceCell::const_new();

async fn get_config<'a>() -> &'a SdkConfig {
&*AWS_CONFIG
AWS_CONFIG
.get_or_init(|| async {
Arc::new(
aws_config::defaults(BehaviorVersion::latest())
Expand All @@ -42,21 +49,28 @@ async fn get_config<'a>() -> &'a SdkConfig {

impl ArroyoCredentialProvider {
pub async fn try_new() -> Result<Self, StorageError> {
let config = get_config().await;

let credentials = config
.credentials_provider()
.ok_or_else(|| {
StorageError::CredentialsError(
"Unable to load S3 credentials from environment".to_string(),
)
})?
.clone();

Ok(Self {
cache: Default::default(),
provider: credentials,
})
Ok(CREDENTIAL_PROVIDER
.get_or_try_init(|| async {
let config = get_config().await;

let credentials = config
.credentials_provider()
.ok_or_else(|| {
StorageError::CredentialsError(
"Unable to load S3 credentials from environment".to_string(),
)
})?
.clone();

info!("Creating credential provider");
Ok::<Self, StorageError>(Self {
cache: Default::default(),
refresh_task: Default::default(),
provider: credentials,
})
})
.await?
.clone())
}

pub async fn default_region() -> Option<String> {
Expand All @@ -66,40 +80,87 @@ impl ArroyoCredentialProvider {

async fn get_token(
provider: &SharedCredentialsProvider,
) -> Result<TemporaryToken<Arc<AwsCredential>>, Box<dyn Error + Send + Sync>> {
info!("Getting credentials");
) -> Result<(Arc<AwsCredential>, Option<SystemTime>, Instant), object_store::Error> {
info!("fetching new AWS token");
let creds = provider
.provide_credentials()
.await
.map_err(|e| object_store::Error::Generic {
store: "S3",
source: Box::new(e),
})?;
info!("Got credentials = {:?}", creds);
let expiry = creds
.expiry()
.map(|exp| Instant::now() + exp.elapsed().unwrap_or_default());
Ok(TemporaryToken {
token: Arc::new(AwsCredential {
Ok((
Arc::new(AwsCredential {
key_id: creds.access_key_id().to_string(),
secret_key: creds.secret_access_key().to_string(),
token: creds.session_token().map(ToString::to_string),
}),
expiry,
})
creds.expiry(),
Instant::now(),
))
}

#[async_trait::async_trait]
impl CredentialProvider for ArroyoCredentialProvider {
type Credential = AwsCredential;

async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
self.cache
.get_or_insert_with(|| get_token(&self.provider))
.await
.map_err(|e| object_store::Error::Generic {
store: "S3",
source: e,
})
let token = self.cache.lock().await.clone();
match token {
Some((token, Some(expiration), last_refreshed)) => {
let expires_in = expiration
.duration_since(SystemTime::now())
.unwrap_or_default();
if expires_in < Duration::from_millis(100) {
info!("AWS token has expired, immediately refreshing");
let lock = self.cache.try_lock();

let token = get_token(&self.provider).await?;

if let Ok(mut lock) = lock {
*lock = Some(token.clone());
}
return Ok(token.0);
}

if expires_in < EXPIRATION_BUFFER
&& last_refreshed.elapsed() > Duration::from_millis(100)
{
let refresh_lock = self.refresh_task.try_lock();
if let Ok(mut task) = refresh_lock {
if task.is_some() && !task.as_ref().unwrap().is_finished() {
// the task is working on refreshing, let it do its job
return Ok(token);
}

// else we need to start a refresh task
let our_provider = self.provider.clone();
let our_lock = self.cache.clone();
*task = Some(tokio::spawn(async move {
let token = get_token(&our_provider)
.await
.unwrap_or_else(|e| panic!("Failed to refresh AWS token: {:?}", e));

let mut lock = our_lock.lock().await;
*lock = Some(token);
}));
}
}

Ok(token)
}
Some((token, None, _)) => Ok(token),
None => {
// get the initial token
let mut cache = self.cache.lock().await;
if let Some((token, _, _)) = &*cache {
return Ok(token.clone());
}

let token = get_token(&self.provider).await?;
*cache = Some(token.clone());
Ok(token.0)
}
}
}
}

0 comments on commit b581ce5

Please sign in to comment.