Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify ObjectStore configuration pattern #4189

Merged
merged 1 commit into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions object_store/src/aws/checksum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use ring::digest::{self, digest as ring_digest};
use std::str::FromStr;

#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
Expand Down Expand Up @@ -47,13 +48,21 @@ impl std::fmt::Display for Checksum {
}
}

impl TryFrom<&String> for Checksum {
type Error = ();
impl FromStr for Checksum {
type Err = ();

fn try_from(value: &String) -> Result<Self, Self::Error> {
match value.to_lowercase().as_str() {
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"sha256" => Ok(Self::SHA256),
_ => Err(()),
}
}
}

impl TryFrom<&String> for Checksum {
type Error = ();

fn try_from(value: &String) -> Result<Self, Self::Error> {
value.parse()
}
}
155 changes: 59 additions & 96 deletions object_store/src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ pub struct AmazonS3Builder {
/// When set to true, unsigned payload option has to be used
unsigned_payload: bool,
/// Checksum algorithm which has to be used for object integrity check during upload
checksum_algorithm: Option<Checksum>,
checksum_algorithm: Option<String>,
/// Metadata endpoint, see <https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html>
metadata_endpoint: Option<String>,
/// Profile name, see <https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-profiles.html>
Expand All @@ -434,30 +434,17 @@ pub struct AmazonS3Builder {

/// Configuration keys for [`AmazonS3Builder`]
///
/// Configuration via keys can be dome via the [`try_with_option`](AmazonS3Builder::try_with_option)
/// or [`with_options`](AmazonS3Builder::try_with_options) methods on the builder.
/// Configuration via keys can be done via [`AmazonS3Builder::with_config`]
///
/// # Example
/// ```
/// use std::collections::HashMap;
/// use object_store::aws::{AmazonS3Builder, AmazonS3ConfigKey};
///
/// let options = HashMap::from([
/// ("aws_access_key_id", "my-access-key-id"),
/// ("aws_secret_access_key", "my-secret-access-key"),
/// ]);
/// let typed_options = vec![
/// (AmazonS3ConfigKey::DefaultRegion, "my-default-region"),
/// ];
/// let aws = AmazonS3Builder::new()
/// .try_with_options(options)
/// .unwrap()
/// .try_with_options(typed_options)
/// .unwrap()
/// .try_with_option(AmazonS3ConfigKey::Region, "my-region")
/// .unwrap();
/// # use object_store::aws::{AmazonS3Builder, AmazonS3ConfigKey};
/// let builder = AmazonS3Builder::new()
/// .with_config("aws_access_key_id".parse().unwrap(), "my-access-key-id")
/// .with_config(AmazonS3ConfigKey::DefaultRegion, "my-default-region");
/// ```
#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Serialize, Deserialize)]
#[non_exhaustive]
pub enum AmazonS3ConfigKey {
/// AWS Access Key
///
Expand Down Expand Up @@ -662,7 +649,7 @@ impl AmazonS3Builder {
if let Ok(config_key) =
AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase())
{
builder = builder.try_with_option(config_key, value).unwrap();
builder = builder.with_config(config_key, value);
}
}
}
Expand Down Expand Up @@ -710,14 +697,12 @@ impl AmazonS3Builder {
}

/// Set an option on the builder via a key - value pair.
///
/// This method will return an `UnknownConfigKey` error if key cannot be parsed into [`AmazonS3ConfigKey`].
pub fn try_with_option(
pub fn with_config(
mut self,
key: impl AsRef<str>,
key: AmazonS3ConfigKey,
value: impl Into<String>,
) -> Result<Self> {
match AmazonS3ConfigKey::from_str(key.as_ref())? {
) -> Self {
match key {
AmazonS3ConfigKey::AccessKeyId => self.access_key_id = Some(value.into()),
AmazonS3ConfigKey::SecretAccessKey => {
self.secret_access_key = Some(value.into())
Expand All @@ -742,18 +727,28 @@ impl AmazonS3Builder {
AmazonS3ConfigKey::UnsignedPayload => {
self.unsigned_payload = str_is_truthy(&value.into())
}
AmazonS3ConfigKey::Checksum => {
let algorithm = Checksum::try_from(&value.into())
.map_err(|_| Error::InvalidChecksumAlgorithm)?;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was problematic as this error would get unwrapped in from_env

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fixed properly in #4192

self.checksum_algorithm = Some(algorithm)
}
AmazonS3ConfigKey::Checksum => self.checksum_algorithm = Some(value.into()),
};
Ok(self)
self
}

/// Set an option on the builder via a key - value pair.
///
/// This method will return an `UnknownConfigKey` error if key cannot be parsed into [`AmazonS3ConfigKey`].
#[deprecated(note = "Use with_config")]
pub fn try_with_option(
self,
key: impl AsRef<str>,
value: impl Into<String>,
) -> Result<Self> {
Ok(self.with_config(key.as_ref().parse()?, value))
}

/// Hydrate builder from key value pairs
///
/// This method will return an `UnknownConfigKey` error if any key cannot be parsed into [`AmazonS3ConfigKey`].
#[deprecated(note = "Use with_config")]
#[allow(deprecated)]
pub fn try_with_options<
I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>,
>(
Expand Down Expand Up @@ -794,7 +789,7 @@ impl AmazonS3Builder {
AmazonS3ConfigKey::MetadataEndpoint => self.metadata_endpoint.clone(),
AmazonS3ConfigKey::Profile => self.profile.clone(),
AmazonS3ConfigKey::UnsignedPayload => Some(self.unsigned_payload.to_string()),
AmazonS3ConfigKey::Checksum => self.checksum_algorithm.map(|v| v.to_string()),
AmazonS3ConfigKey::Checksum => self.checksum_algorithm.clone(),
}
}

Expand Down Expand Up @@ -935,7 +930,8 @@ impl AmazonS3Builder {
///
/// [checksum algorithm]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html
pub fn with_checksum_algorithm(mut self, checksum_algorithm: Checksum) -> Self {
self.checksum_algorithm = Some(checksum_algorithm);
// Convert to String to enable deferred parsing of config
self.checksum_algorithm = Some(checksum_algorithm.to_string());
self
}

Expand Down Expand Up @@ -988,6 +984,11 @@ impl AmazonS3Builder {

let bucket = self.bucket_name.context(MissingBucketNameSnafu)?;
let region = self.region.context(MissingRegionSnafu)?;
let checksum = self
.checksum_algorithm
.map(|c| c.parse())
.transpose()
.map_err(|_| Error::InvalidChecksumAlgorithm)?;

let credentials = match (self.access_key_id, self.secret_access_key, self.token) {
(Some(key_id), Some(secret_key), token) => {
Expand Down Expand Up @@ -1085,7 +1086,7 @@ impl AmazonS3Builder {
retry_config: self.retry_config,
client_options: self.client_options,
sign_payload: !self.unsigned_payload,
checksum: self.checksum_algorithm,
checksum,
};

let client = Arc::new(S3Client::new(config)?);
Expand Down Expand Up @@ -1259,7 +1260,10 @@ mod tests {
assert_eq!(builder.token.unwrap(), aws_session_token);
let metadata_uri = format!("{METADATA_ENDPOINT}{container_creds_relative_uri}");
assert_eq!(builder.metadata_endpoint.unwrap(), metadata_uri);
assert_eq!(builder.checksum_algorithm.unwrap(), Checksum::SHA256);
assert_eq!(
builder.checksum_algorithm.unwrap(),
Checksum::SHA256.to_string()
);
assert!(builder.unsigned_payload);
}

Expand All @@ -1280,46 +1284,22 @@ mod tests {
("aws_checksum_algorithm", "sha256".to_string()),
]);

let builder = AmazonS3Builder::new()
.try_with_options(&options)
.unwrap()
.try_with_option("aws_secret_access_key", "new-secret-key")
.unwrap();
assert_eq!(builder.access_key_id.unwrap(), aws_access_key_id.as_str());
assert_eq!(builder.secret_access_key.unwrap(), "new-secret-key");
assert_eq!(builder.region.unwrap(), aws_default_region);
assert_eq!(builder.endpoint.unwrap(), aws_endpoint);
assert_eq!(builder.token.unwrap(), aws_session_token);
assert_eq!(builder.checksum_algorithm.unwrap(), Checksum::SHA256);
assert!(builder.unsigned_payload);
}

#[test]
fn s3_test_config_from_typed_map() {
let aws_access_key_id = "object_store:fake_access_key_id".to_string();
let aws_secret_access_key = "object_store:fake_secret_key".to_string();
let aws_default_region = "object_store:fake_default_region".to_string();
let aws_endpoint = "object_store:fake_endpoint".to_string();
let aws_session_token = "object_store:fake_session_token".to_string();
let options = HashMap::from([
(AmazonS3ConfigKey::AccessKeyId, aws_access_key_id.clone()),
(AmazonS3ConfigKey::SecretAccessKey, aws_secret_access_key),
(AmazonS3ConfigKey::DefaultRegion, aws_default_region.clone()),
(AmazonS3ConfigKey::Endpoint, aws_endpoint.clone()),
(AmazonS3ConfigKey::Token, aws_session_token.clone()),
(AmazonS3ConfigKey::UnsignedPayload, "true".to_string()),
]);
let builder = options
.into_iter()
.fold(AmazonS3Builder::new(), |builder, (key, value)| {
builder.with_config(key.parse().unwrap(), value)
})
.with_config(AmazonS3ConfigKey::SecretAccessKey, "new-secret-key");

let builder = AmazonS3Builder::new()
.try_with_options(&options)
.unwrap()
.try_with_option(AmazonS3ConfigKey::SecretAccessKey, "new-secret-key")
.unwrap();
assert_eq!(builder.access_key_id.unwrap(), aws_access_key_id.as_str());
assert_eq!(builder.secret_access_key.unwrap(), "new-secret-key");
assert_eq!(builder.region.unwrap(), aws_default_region);
assert_eq!(builder.endpoint.unwrap(), aws_endpoint);
assert_eq!(builder.token.unwrap(), aws_session_token);
assert_eq!(
builder.checksum_algorithm.unwrap(),
Checksum::SHA256.to_string()
);
assert!(builder.unsigned_payload);
}

Expand All @@ -1330,19 +1310,15 @@ mod tests {
let aws_default_region = "object_store:fake_default_region".to_string();
let aws_endpoint = "object_store:fake_endpoint".to_string();
let aws_session_token = "object_store:fake_session_token".to_string();
let options = HashMap::from([
(AmazonS3ConfigKey::AccessKeyId, aws_access_key_id.clone()),
(
AmazonS3ConfigKey::SecretAccessKey,
aws_secret_access_key.clone(),
),
(AmazonS3ConfigKey::DefaultRegion, aws_default_region.clone()),
(AmazonS3ConfigKey::Endpoint, aws_endpoint.clone()),
(AmazonS3ConfigKey::Token, aws_session_token.clone()),
(AmazonS3ConfigKey::UnsignedPayload, "true".to_string()),
]);

let builder = AmazonS3Builder::new().try_with_options(&options).unwrap();
let builder = AmazonS3Builder::new()
.with_config(AmazonS3ConfigKey::AccessKeyId, &aws_access_key_id)
.with_config(AmazonS3ConfigKey::SecretAccessKey, &aws_secret_access_key)
.with_config(AmazonS3ConfigKey::DefaultRegion, &aws_default_region)
.with_config(AmazonS3ConfigKey::Endpoint, &aws_endpoint)
.with_config(AmazonS3ConfigKey::Token, &aws_session_token)
.with_config(AmazonS3ConfigKey::UnsignedPayload, "true");

assert_eq!(
builder
.get_config_value(&AmazonS3ConfigKey::AccessKeyId)
Expand Down Expand Up @@ -1379,19 +1355,6 @@ mod tests {
);
}

#[test]
fn s3_test_config_fallible_options() {
let aws_access_key_id = "object_store:fake_access_key_id".to_string();
let aws_secret_access_key = "object_store:fake_secret_key".to_string();
let options = HashMap::from([
("aws_access_key_id", aws_access_key_id),
("invalid-key", aws_secret_access_key),
]);

let builder = AmazonS3Builder::new().try_with_options(&options);
assert!(builder.is_err());
}

#[tokio::test]
async fn s3_test() {
let config = maybe_skip_integration!();
Expand Down
Loading