diff --git a/object_store/src/aws/checksum.rs b/object_store/src/aws/checksum.rs new file mode 100644 index 000000000000..ae35f0612456 --- /dev/null +++ b/object_store/src/aws/checksum.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use ring::digest::{self, digest as ring_digest}; + +#[allow(non_camel_case_types)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// Enum representing checksum algorithm supported by S3. +pub enum Checksum { + /// SHA-256 algorithm. + SHA256, +} + +impl Checksum { + pub(super) fn digest(&self, bytes: &[u8]) -> Vec { + match self { + Self::SHA256 => ring_digest(&digest::SHA256, bytes).as_ref().to_owned(), + } + } + + pub(super) fn header_name(&self) -> &'static str { + match self { + Self::SHA256 => "x-amz-checksum-sha256", + } + } +} + +impl TryFrom<&String> for Checksum { + type Error = (); + + fn try_from(value: &String) -> Result { + match value.as_str() { + "sha256" => Ok(Self::SHA256), + _ => Err(()), + } + } +} diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index 0b0f883b7e51..bd58d09676aa 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::aws::checksum::Checksum; use crate::aws::credential::{AwsCredential, CredentialExt, CredentialProvider}; use crate::aws::STRICT_PATH_ENCODE_SET; use crate::client::pagination::stream_paginated; @@ -26,6 +27,8 @@ use crate::{ BoxStream, ClientOptions, ListResult, MultipartId, ObjectMeta, Path, Result, RetryConfig, StreamExt, }; +use base64::prelude::BASE64_STANDARD; +use base64::Engine; use bytes::{Buf, Bytes}; use chrono::{DateTime, Utc}; use percent_encoding::{utf8_percent_encode, PercentEncode}; @@ -205,6 +208,7 @@ pub struct S3Config { pub retry_config: RetryConfig, pub client_options: ClientOptions, pub sign_payload: bool, + pub checksum: Option, } impl S3Config { @@ -262,6 +266,7 @@ impl S3Client { &self.config.region, "s3", self.config.sign_payload, + None, ) .send_retry(&self.config.retry_config) .await @@ -281,10 +286,19 @@ impl S3Client { ) -> Result { let credential = self.get_credential().await?; let url = self.config.path_url(path); - let mut builder = self.client.request(Method::PUT, url); + let mut payload_sha256 = None; + if let Some(bytes) = bytes { - builder = builder.body(bytes) + if let Some(checksum) = self.config().checksum { + let digest = checksum.digest(&bytes); + builder = builder + .header(checksum.header_name(), BASE64_STANDARD.encode(&digest)); + if checksum == Checksum::SHA256 { + payload_sha256 = Some(digest); + } + } + builder = builder.body(bytes); } if let Some(value) = self.config().client_options.get_content_type(path) { @@ -298,6 +312,7 @@ impl S3Client { &self.config.region, "s3", self.config.sign_payload, + payload_sha256, ) .send_retry(&self.config.retry_config) .await @@ -325,6 +340,7 @@ impl S3Client { &self.config.region, "s3", self.config.sign_payload, + None, ) .send_retry(&self.config.retry_config) .await @@ -349,6 +365,7 @@ impl S3Client { &self.config.region, "s3", self.config.sign_payload, + None, ) .send_retry(&self.config.retry_config) .await @@ -395,6 +412,7 @@ impl S3Client { &self.config.region, "s3", self.config.sign_payload, + None, ) .send_retry(&self.config.retry_config) .await @@ -438,6 +456,7 @@ impl S3Client { &self.config.region, "s3", self.config.sign_payload, + None, ) .send_retry(&self.config.retry_config) .await @@ -482,6 +501,7 @@ impl S3Client { &self.config.region, "s3", self.config.sign_payload, + None, ) .send_retry(&self.config.retry_config) .await diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs index 05f2c535bfdc..183e8434650b 100644 --- a/object_store/src/aws/credential.rs +++ b/object_store/src/aws/credential.rs @@ -84,7 +84,7 @@ const AUTH_HEADER: &str = "authorization"; const ALL_HEADERS: &[&str; 4] = &[DATE_HEADER, HASH_HEADER, TOKEN_HEADER, AUTH_HEADER]; impl<'a> RequestSigner<'a> { - fn sign(&self, request: &mut Request) { + fn sign(&self, request: &mut Request, pre_calculated_digest: Option>) { if let Some(ref token) = self.credential.token { let token_val = HeaderValue::from_str(token).unwrap(); request.headers_mut().insert(TOKEN_HEADER, token_val); @@ -101,9 +101,13 @@ impl<'a> RequestSigner<'a> { request.headers_mut().insert(DATE_HEADER, date_val); let digest = if self.sign_payload { - match request.body() { - None => EMPTY_SHA256_HASH.to_string(), - Some(body) => hex_digest(body.as_bytes().unwrap()), + if let Some(digest) = pre_calculated_digest { + hex_encode(&digest) + } else { + match request.body() { + None => EMPTY_SHA256_HASH.to_string(), + Some(body) => hex_digest(body.as_bytes().unwrap()), + } } } else { UNSIGNED_PAYLOAD_LITERAL.to_string() @@ -165,6 +169,7 @@ pub trait CredentialExt { region: &str, service: &str, sign_payload: bool, + payload_sha256: Option>, ) -> Self; } @@ -175,6 +180,7 @@ impl CredentialExt for RequestBuilder { region: &str, service: &str, sign_payload: bool, + payload_sha256: Option>, ) -> Self { // Hack around lack of access to underlying request // https://github.com/seanmonstar/reqwest/issues/1212 @@ -193,7 +199,7 @@ impl CredentialExt for RequestBuilder { sign_payload, }; - signer.sign(&mut request); + signer.sign(&mut request, payload_sha256); for header in ALL_HEADERS { if let Some(val) = request.headers_mut().remove(*header) { @@ -627,7 +633,7 @@ mod tests { sign_payload: true, }; - signer.sign(&mut request); + signer.sign(&mut request, None); assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=a3c787a7ed37f7fdfbfd2d7056a3d7c9d85e6d52a2bfbec73793c0be6e7862d4") } @@ -665,7 +671,7 @@ mod tests { sign_payload: false, }; - signer.sign(&mut request); + signer.sign(&mut request, None); assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=653c3d8ea261fd826207df58bc2bb69fbb5003e9eb3c0ef06e4a51f2a81d8699") } @@ -702,7 +708,7 @@ mod tests { sign_payload: true, }; - signer.sign(&mut request); + signer.sign(&mut request, None); assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=H20ABqCkLZID4rLe/20220809/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=9ebf2f92872066c99ac94e573b4e1b80f4dbb8a32b1e8e23178318746e7d1b4d") } diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index c724886cf0e6..7d10f3728238 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -47,6 +47,7 @@ use tokio::io::AsyncWrite; use tracing::info; use url::Url; +pub use crate::aws::checksum::Checksum; use crate::aws::client::{S3Client, S3Config}; use crate::aws::credential::{ AwsCredential, CredentialProvider, InstanceCredentialProvider, @@ -59,6 +60,7 @@ use crate::{ Result, RetryConfig, StreamExt, }; +mod checksum; mod client; mod credential; @@ -101,6 +103,9 @@ enum Error { source: std::num::ParseIntError, }, + #[snafu(display("Invalid Checksum algorithm"))] + InvalidChecksumAlgorithm, + #[snafu(display("Missing region"))] MissingRegion, @@ -386,6 +391,7 @@ pub struct AmazonS3Builder { imdsv1_fallback: bool, virtual_hosted_style_request: bool, unsigned_payload: bool, + checksum_algorithm: Option, metadata_endpoint: Option, profile: Option, client_options: ClientOptions, @@ -514,6 +520,11 @@ pub enum AmazonS3ConfigKey { /// - `unsigned_payload` UnsignedPayload, + /// Set the checksum algorithm for this client + /// + /// See [`AmazonS3Builder::with_checksum_algorithm`] + Checksum, + /// Set the instance metadata endpoint /// /// See [`AmazonS3Builder::with_metadata_endpoint`] for details. @@ -546,6 +557,7 @@ impl AsRef for AmazonS3ConfigKey { Self::MetadataEndpoint => "aws_metadata_endpoint", Self::Profile => "aws_profile", Self::UnsignedPayload => "aws_unsigned_payload", + Self::Checksum => "aws_checksum_algorithm", } } } @@ -575,6 +587,7 @@ impl FromStr for AmazonS3ConfigKey { "aws_imdsv1_fallback" | "imdsv1_fallback" => Ok(Self::ImdsV1Fallback), "aws_metadata_endpoint" | "metadata_endpoint" => Ok(Self::MetadataEndpoint), "aws_unsigned_payload" | "unsigned_payload" => Ok(Self::UnsignedPayload), + "aws_checksum_algorithm" | "checksum_algorithm" => Ok(Self::Checksum), _ => Err(Error::UnknownConfigurationKey { key: s.into() }.into()), } } @@ -694,6 +707,11 @@ impl AmazonS3Builder { AmazonS3ConfigKey::UnsignedPayload => { self.unsigned_payload = str_is_truthy(&value.into()) } + AmazonS3ConfigKey::Checksum => { + let algorithm = Checksum::try_from(&value.into()) + .map_err(|_| Error::InvalidChecksumAlgorithm)?; + self.checksum_algorithm = Some(algorithm) + } }; Ok(self) } @@ -846,6 +864,14 @@ impl AmazonS3Builder { self } + /// Sets the [checksum algorithm] which has to be used for object integrity check during upload. + /// + /// [checksum algorithm]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html + pub fn with_checksum_algorithm(mut self, checksum_algorithm: Checksum) -> Self { + self.checksum_algorithm = Some(checksum_algorithm); + self + } + /// Set the [instance metadata endpoint](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html), /// used primarily within AWS EC2. /// @@ -992,6 +1018,7 @@ impl AmazonS3Builder { retry_config: self.retry_config, client_options: self.client_options, sign_payload: !self.unsigned_payload, + checksum: self.checksum_algorithm, }; let client = Arc::new(S3Client::new(config)?); @@ -1151,6 +1178,7 @@ mod tests { &container_creds_relative_uri, ); env::set_var("AWS_UNSIGNED_PAYLOAD", "true"); + env::set_var("AWS_CHECKSUM_ALGORITHM", "sha256"); let builder = AmazonS3Builder::from_env(); assert_eq!(builder.access_key_id.unwrap(), aws_access_key_id.as_str()); @@ -1164,6 +1192,7 @@ mod tests { assert_eq!(builder.token.unwrap(), aws_session_token); let metadata_uri = format!("{METADATA_ENDPOINT}{container_creds_relative_uri}"); assert_eq!(builder.metadata_endpoint.unwrap(), metadata_uri); + assert_eq!(builder.checksum_algorithm.unwrap(), Checksum::SHA256); assert!(builder.unsigned_payload); } @@ -1181,6 +1210,7 @@ mod tests { ("aws_endpoint", aws_endpoint.clone()), ("aws_session_token", aws_session_token.clone()), ("aws_unsigned_payload", "true".to_string()), + ("aws_checksum_algorithm", "sha256".to_string()), ]); let builder = AmazonS3Builder::new() @@ -1193,6 +1223,7 @@ mod tests { assert_eq!(builder.region.unwrap(), aws_default_region); assert_eq!(builder.endpoint.unwrap(), aws_endpoint); assert_eq!(builder.token.unwrap(), aws_session_token); + assert_eq!(builder.checksum_algorithm.unwrap(), Checksum::SHA256); assert!(builder.unsigned_payload); } @@ -1256,6 +1287,12 @@ mod tests { let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://")); let integration = config.build().unwrap(); put_get_delete_list_opts(&integration, is_local).await; + + // run integration test with checksum set to sha256 + let config = maybe_skip_integration!().with_checksum_algorithm(Checksum::SHA256); + let is_local = matches!(&config.endpoint, Some(e) if e.starts_with("http://")); + let integration = config.build().unwrap(); + put_get_delete_list_opts(&integration, is_local).await; } #[tokio::test]