Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for checksum algorithms in AWS #3873

Merged
merged 4 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions object_store/src/aws/checksum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use ring::digest::{self, digest as ring_digest};

/// Checksum algorithms supported by S3.
// NOTE(review): the lint allowance below carries no explanation in this file;
// presumably it anticipates acronym-style variant names — confirm it is needed.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Checksum {
    /// The SHA-256 algorithm.
    SHA256,
}

impl Checksum {
    /// Computes the digest of `bytes` with this algorithm.
    ///
    /// Returns the raw digest bytes as an owned vector.
    pub(super) fn digest(&self, bytes: &[u8]) -> Vec<u8> {
        // Pick the `ring` algorithm first so the hashing call is written once.
        let algorithm = match *self {
            Self::SHA256 => &digest::SHA256,
        };
        ring_digest(algorithm, bytes).as_ref().to_vec()
    }

    /// Returns the name of the HTTP header associated with this algorithm.
    pub(super) fn header_name(&self) -> &'static str {
        match *self {
            Self::SHA256 => "x-amz-checksum-sha256",
        }
    }
}

impl TryFrom<&String> for Checksum {
    type Error = ();

    /// Parses a checksum algorithm from its textual name.
    ///
    /// The comparison ignores ASCII case, so `"sha256"`, `"SHA256"` and
    /// `"Sha256"` all yield [`Checksum::SHA256`].
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when `value` does not name a known algorithm.
    fn try_from(value: &String) -> Result<Self, Self::Error> {
        // `eq_ignore_ascii_case` avoids allocating a lowercased copy of the input.
        if value.eq_ignore_ascii_case("sha256") {
            Ok(Self::SHA256)
        } else {
            Err(())
        }
    }
}
24 changes: 22 additions & 2 deletions object_store/src/aws/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::aws::checksum::Checksum;
use crate::aws::credential::{AwsCredential, CredentialExt, CredentialProvider};
use crate::aws::STRICT_PATH_ENCODE_SET;
use crate::client::pagination::stream_paginated;
Expand All @@ -26,6 +27,8 @@ use crate::{
BoxStream, ClientOptions, ListResult, MultipartId, ObjectMeta, Path, Result,
RetryConfig, StreamExt,
};
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bytes::{Buf, Bytes};
use chrono::{DateTime, Utc};
use percent_encoding::{utf8_percent_encode, PercentEncode};
Expand Down Expand Up @@ -205,6 +208,7 @@ pub struct S3Config {
pub retry_config: RetryConfig,
pub client_options: ClientOptions,
pub sign_payload: bool,
pub checksum: Option<Checksum>,
}

impl S3Config {
Expand Down Expand Up @@ -262,6 +266,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
None,
)
.send_retry(&self.config.retry_config)
.await
Expand All @@ -281,10 +286,19 @@ impl S3Client {
) -> Result<Response> {
let credential = self.get_credential().await?;
let url = self.config.path_url(path);

let mut builder = self.client.request(Method::PUT, url);
let mut payload_sha256 = None;

if let Some(bytes) = bytes {
builder = builder.body(bytes)
if let Some(checksum) = self.config().checksum {
let digest = checksum.digest(&bytes);
builder = builder
.header(checksum.header_name(), BASE64_STANDARD.encode(&digest));
if checksum == Checksum::SHA256 {
payload_sha256 = Some(digest);
}
}
builder = builder.body(bytes);
}

if let Some(value) = self.config().client_options.get_content_type(path) {
Expand All @@ -298,6 +312,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
payload_sha256,
)
.send_retry(&self.config.retry_config)
.await
Expand Down Expand Up @@ -325,6 +340,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
None,
)
.send_retry(&self.config.retry_config)
.await
Expand All @@ -349,6 +365,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
None,
)
.send_retry(&self.config.retry_config)
.await
Expand Down Expand Up @@ -395,6 +412,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
None,
)
.send_retry(&self.config.retry_config)
.await
Expand Down Expand Up @@ -438,6 +456,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
None,
)
.send_retry(&self.config.retry_config)
.await
Expand Down Expand Up @@ -482,6 +501,7 @@ impl S3Client {
&self.config.region,
"s3",
self.config.sign_payload,
None,
)
.send_retry(&self.config.retry_config)
.await
Expand Down
22 changes: 14 additions & 8 deletions object_store/src/aws/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ const AUTH_HEADER: &str = "authorization";
const ALL_HEADERS: &[&str; 4] = &[DATE_HEADER, HASH_HEADER, TOKEN_HEADER, AUTH_HEADER];

impl<'a> RequestSigner<'a> {
fn sign(&self, request: &mut Request) {
fn sign(&self, request: &mut Request, pre_calculated_digest: Option<Vec<u8>>) {
if let Some(ref token) = self.credential.token {
let token_val = HeaderValue::from_str(token).unwrap();
request.headers_mut().insert(TOKEN_HEADER, token_val);
Expand All @@ -101,9 +101,13 @@ impl<'a> RequestSigner<'a> {
request.headers_mut().insert(DATE_HEADER, date_val);

let digest = if self.sign_payload {
match request.body() {
None => EMPTY_SHA256_HASH.to_string(),
Some(body) => hex_digest(body.as_bytes().unwrap()),
if let Some(digest) = pre_calculated_digest {
hex_encode(&digest)
} else {
match request.body() {
None => EMPTY_SHA256_HASH.to_string(),
Some(body) => hex_digest(body.as_bytes().unwrap()),
}
}
} else {
UNSIGNED_PAYLOAD_LITERAL.to_string()
Expand Down Expand Up @@ -165,6 +169,7 @@ pub trait CredentialExt {
region: &str,
service: &str,
sign_payload: bool,
payload_sha256: Option<Vec<u8>>,
) -> Self;
}

Expand All @@ -175,6 +180,7 @@ impl CredentialExt for RequestBuilder {
region: &str,
service: &str,
sign_payload: bool,
payload_sha256: Option<Vec<u8>>,
) -> Self {
// Hack around lack of access to underlying request
// https://github.com/seanmonstar/reqwest/issues/1212
Expand All @@ -193,7 +199,7 @@ impl CredentialExt for RequestBuilder {
sign_payload,
};

signer.sign(&mut request);
signer.sign(&mut request, payload_sha256);

for header in ALL_HEADERS {
if let Some(val) = request.headers_mut().remove(*header) {
Expand Down Expand Up @@ -627,7 +633,7 @@ mod tests {
sign_payload: true,
};

signer.sign(&mut request);
signer.sign(&mut request, None);
assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=a3c787a7ed37f7fdfbfd2d7056a3d7c9d85e6d52a2bfbec73793c0be6e7862d4")
}

Expand Down Expand Up @@ -665,7 +671,7 @@ mod tests {
sign_payload: false,
};

signer.sign(&mut request);
signer.sign(&mut request, None);
assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=653c3d8ea261fd826207df58bc2bb69fbb5003e9eb3c0ef06e4a51f2a81d8699")
}

Expand Down Expand Up @@ -702,7 +708,7 @@ mod tests {
sign_payload: true,
};

signer.sign(&mut request);
signer.sign(&mut request, None);
assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=H20ABqCkLZID4rLe/20220809/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=9ebf2f92872066c99ac94e573b4e1b80f4dbb8a32b1e8e23178318746e7d1b4d")
}

Expand Down
37 changes: 37 additions & 0 deletions object_store/src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ use tokio::io::AsyncWrite;
use tracing::info;
use url::Url;

pub use crate::aws::checksum::Checksum;
use crate::aws::client::{S3Client, S3Config};
use crate::aws::credential::{
AwsCredential, CredentialProvider, InstanceCredentialProvider,
Expand All @@ -59,6 +60,7 @@ use crate::{
Result, RetryConfig, StreamExt,
};

mod checksum;
mod client;
mod credential;

Expand Down Expand Up @@ -101,6 +103,9 @@ enum Error {
source: std::num::ParseIntError,
},

#[snafu(display("Invalid Checksum algorithm"))]
InvalidChecksumAlgorithm,

#[snafu(display("Missing region"))]
MissingRegion,

Expand Down Expand Up @@ -386,6 +391,7 @@ pub struct AmazonS3Builder {
imdsv1_fallback: bool,
virtual_hosted_style_request: bool,
unsigned_payload: bool,
checksum_algorithm: Option<Checksum>,
metadata_endpoint: Option<String>,
profile: Option<String>,
client_options: ClientOptions,
Expand Down Expand Up @@ -514,6 +520,11 @@ pub enum AmazonS3ConfigKey {
/// - `unsigned_payload`
UnsignedPayload,

/// Set the checksum algorithm for this client
///
/// See [`AmazonS3Builder::with_checksum_algorithm`]
Checksum,

/// Set the instance metadata endpoint
///
/// See [`AmazonS3Builder::with_metadata_endpoint`] for details.
Expand Down Expand Up @@ -546,6 +557,7 @@ impl AsRef<str> for AmazonS3ConfigKey {
Self::MetadataEndpoint => "aws_metadata_endpoint",
Self::Profile => "aws_profile",
Self::UnsignedPayload => "aws_unsigned_payload",
Self::Checksum => "aws_checksum_algorithm",
}
}
}
Expand Down Expand Up @@ -575,6 +587,7 @@ impl FromStr for AmazonS3ConfigKey {
"aws_imdsv1_fallback" | "imdsv1_fallback" => Ok(Self::ImdsV1Fallback),
"aws_metadata_endpoint" | "metadata_endpoint" => Ok(Self::MetadataEndpoint),
"aws_unsigned_payload" | "unsigned_payload" => Ok(Self::UnsignedPayload),
"aws_checksum_algorithm" | "checksum_algorithm" => Ok(Self::Checksum),
_ => Err(Error::UnknownConfigurationKey { key: s.into() }.into()),
}
}
Expand Down Expand Up @@ -694,6 +707,11 @@ impl AmazonS3Builder {
AmazonS3ConfigKey::UnsignedPayload => {
self.unsigned_payload = str_is_truthy(&value.into())
}
AmazonS3ConfigKey::Checksum => {
let algorithm = Checksum::try_from(&value.into())
.map_err(|_| Error::InvalidChecksumAlgorithm)?;
self.checksum_algorithm = Some(algorithm)
}
};
Ok(self)
}
Expand Down Expand Up @@ -846,6 +864,14 @@ impl AmazonS3Builder {
self
}

/// Sets the [checksum algorithm] which has to be used for object integrity check during upload.
///
/// [checksum algorithm]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html
pub fn with_checksum_algorithm(mut self, checksum_algorithm: Checksum) -> Self {
self.checksum_algorithm = Some(checksum_algorithm);
self
}

/// Set the [instance metadata endpoint](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html),
/// used primarily within AWS EC2.
///
Expand Down Expand Up @@ -992,6 +1018,7 @@ impl AmazonS3Builder {
retry_config: self.retry_config,
client_options: self.client_options,
sign_payload: !self.unsigned_payload,
checksum: self.checksum_algorithm,
};

let client = Arc::new(S3Client::new(config)?);
Expand Down Expand Up @@ -1151,6 +1178,7 @@ mod tests {
&container_creds_relative_uri,
);
env::set_var("AWS_UNSIGNED_PAYLOAD", "true");
env::set_var("AWS_CHECKSUM_ALGORITHM", "sha256");

let builder = AmazonS3Builder::from_env();
assert_eq!(builder.access_key_id.unwrap(), aws_access_key_id.as_str());
Expand All @@ -1164,6 +1192,7 @@ mod tests {
assert_eq!(builder.token.unwrap(), aws_session_token);
let metadata_uri = format!("{METADATA_ENDPOINT}{container_creds_relative_uri}");
assert_eq!(builder.metadata_endpoint.unwrap(), metadata_uri);
assert_eq!(builder.checksum_algorithm.unwrap(), Checksum::SHA256);
assert!(builder.unsigned_payload);
}

Expand All @@ -1181,6 +1210,7 @@ mod tests {
("aws_endpoint", aws_endpoint.clone()),
("aws_session_token", aws_session_token.clone()),
("aws_unsigned_payload", "true".to_string()),
("aws_checksum_algorithm", "sha256".to_string()),
]);

let builder = AmazonS3Builder::new()
Expand All @@ -1193,6 +1223,7 @@ mod tests {
assert_eq!(builder.region.unwrap(), aws_default_region);
assert_eq!(builder.endpoint.unwrap(), aws_endpoint);
assert_eq!(builder.token.unwrap(), aws_session_token);
assert_eq!(builder.checksum_algorithm.unwrap(), Checksum::SHA256);
assert!(builder.unsigned_payload);
}

Expand Down Expand Up @@ -1256,6 +1287,12 @@ mod tests {
let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://"));
let integration = config.build().unwrap();
put_get_delete_list_opts(&integration, is_local).await;

// run integration test with checksum set to sha256
let config = maybe_skip_integration!().with_checksum_algorithm(Checksum::SHA256);
let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://"));
let integration = config.build().unwrap();
put_get_delete_list_opts(&integration, is_local).await;
}

#[tokio::test]
Expand Down