Skip to content

Commit

Permalink
Return PutResult with an ETag from ObjectStore::put (#4934) (#4944)
Browse files Browse the repository at this point in the history
* Return ETag from ObjectStore::put (#4934)

* Further tests

* Clippy

* Review feedback
  • Loading branch information
tustvold authored Oct 19, 2023
1 parent 51ac6fe commit 4cca029
Show file tree
Hide file tree
Showing 15 changed files with 169 additions and 122 deletions.
12 changes: 10 additions & 2 deletions object_store/src/aws/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::aws::{
AwsCredentialProvider, S3CopyIfNotExists, STORE, STRICT_PATH_ENCODE_SET,
};
use crate::client::get::GetClient;
use crate::client::header::get_etag;
use crate::client::list::ListClient;
use crate::client::list_response::ListResponse;
use crate::client::retry::RetryExt;
Expand Down Expand Up @@ -122,6 +123,11 @@ pub(crate) enum Error {

#[snafu(display("Got invalid multipart response: {}", source))]
InvalidMultipartResponse { source: quick_xml::de::DeError },

#[snafu(display("Unable to extract metadata from headers: {}", source))]
Metadata {
source: crate::client::header::Error,
},
}

impl From<Error> for crate::Error {
Expand Down Expand Up @@ -243,12 +249,14 @@ impl S3Client {
}

/// Make an S3 PUT request <https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html>
///
/// Returns the ETag
pub async fn put_request<T: Serialize + ?Sized + Sync>(
&self,
path: &Path,
bytes: Bytes,
query: &T,
) -> Result<Response> {
) -> Result<String> {
let credential = self.get_credential().await?;
let url = self.config.path_url(path);
let mut builder = self.client.request(Method::PUT, url);
Expand Down Expand Up @@ -287,7 +295,7 @@ impl S3Client {
path: path.as_ref(),
})?;

Ok(response)
Ok(get_etag(response.headers()).context(MetadataSnafu)?)
}

/// Make an S3 Delete request <https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObject.html>
Expand Down
25 changes: 6 additions & 19 deletions object_store/src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ use crate::multipart::{PartId, PutPart, WriteMultiPart};
use crate::signer::Signer;
use crate::{
ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta,
ObjectStore, Path, Result, RetryConfig,
ObjectStore, Path, PutResult, Result, RetryConfig,
};

mod checksum;
Expand Down Expand Up @@ -109,12 +109,6 @@ enum Error {
#[snafu(display("Missing SecretAccessKey"))]
MissingSecretAccessKey,

#[snafu(display("ETag Header missing from response"))]
MissingEtag,

#[snafu(display("Received header containing non-ASCII data"))]
BadHeader { source: reqwest::header::ToStrError },

#[snafu(display("Unable parse source url. Url: {}, Error: {}", url, source))]
UnableToParseUrl {
source: url::ParseError,
Expand Down Expand Up @@ -273,9 +267,9 @@ impl Signer for AmazonS3 {

#[async_trait]
impl ObjectStore for AmazonS3 {
async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
self.client.put_request(location, bytes, &()).await?;
Ok(())
async fn put(&self, location: &Path, bytes: Bytes) -> Result<PutResult> {
let e_tag = self.client.put_request(location, bytes, &()).await?;
Ok(PutResult { e_tag: Some(e_tag) })
}

async fn put_multipart(
Expand Down Expand Up @@ -365,10 +359,9 @@ struct S3MultiPartUpload {
#[async_trait]
impl PutPart for S3MultiPartUpload {
async fn put_part(&self, buf: Vec<u8>, part_idx: usize) -> Result<PartId> {
use reqwest::header::ETAG;
let part = (part_idx + 1).to_string();

let response = self
let content_id = self
.client
.put_request(
&self.location,
Expand All @@ -377,13 +370,7 @@ impl PutPart for S3MultiPartUpload {
)
.await?;

let etag = response.headers().get(ETAG).context(MissingEtagSnafu)?;

let etag = etag.to_str().context(BadHeaderSnafu)?;

Ok(PartId {
content_id: etag.to_string(),
})
Ok(PartId { content_id })
}

async fn complete(&self, completed_parts: Vec<PartId>) -> Result<()> {
Expand Down
20 changes: 11 additions & 9 deletions object_store/src/azure/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::{
multipart::{PartId, PutPart, WriteMultiPart},
path::Path,
ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta,
ObjectStore, Result, RetryConfig,
ObjectStore, PutResult, Result, RetryConfig,
};
use async_trait::async_trait;
use base64::prelude::BASE64_STANDARD;
Expand Down Expand Up @@ -62,6 +62,7 @@ mod credential;
/// [`CredentialProvider`] for [`MicrosoftAzure`]
pub type AzureCredentialProvider =
Arc<dyn CredentialProvider<Credential = AzureCredential>>;
use crate::client::header::get_etag;
pub use credential::AzureCredential;

const STORE: &str = "MicrosoftAzure";
Expand All @@ -81,9 +82,6 @@ const MSI_ENDPOINT_ENV_KEY: &str = "IDENTITY_ENDPOINT";
#[derive(Debug, Snafu)]
#[allow(missing_docs)]
enum Error {
#[snafu(display("Received header containing non-ASCII data"))]
BadHeader { source: reqwest::header::ToStrError },

#[snafu(display("Unable parse source url. Url: {}, Error: {}", url, source))]
UnableToParseUrl {
source: url::ParseError,
Expand Down Expand Up @@ -126,8 +124,10 @@ enum Error {
#[snafu(display("Configuration key: '{}' is not known.", key))]
UnknownConfigurationKey { key: String },

#[snafu(display("ETag Header missing from response"))]
MissingEtag,
#[snafu(display("Unable to extract metadata from headers: {}", source))]
Metadata {
source: crate::client::header::Error,
},
}

impl From<Error> for super::Error {
Expand Down Expand Up @@ -170,11 +170,13 @@ impl std::fmt::Display for MicrosoftAzure {

#[async_trait]
impl ObjectStore for MicrosoftAzure {
async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
self.client
async fn put(&self, location: &Path, bytes: Bytes) -> Result<PutResult> {
let response = self
.client
.put_request(location, Some(bytes), false, &())
.await?;
Ok(())
let e_tag = Some(get_etag(response.headers()).context(MetadataSnafu)?);
Ok(PutResult { e_tag })
}

async fn put_multipart(
Expand Down
3 changes: 2 additions & 1 deletion object_store/src/chunked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use tokio::io::AsyncWrite;
use crate::path::Path;
use crate::{
GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore,
PutResult,
};
use crate::{MultipartId, Result};

Expand Down Expand Up @@ -62,7 +63,7 @@ impl Display for ChunkedStore {

#[async_trait]
impl ObjectStore for ChunkedStore {
async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
async fn put(&self, location: &Path, bytes: Bytes) -> Result<PutResult> {
self.inner.put(location, bytes).await
}

Expand Down
17 changes: 10 additions & 7 deletions object_store/src/client/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ pub enum Error {
},
}

/// Extracts an etag from the provided [`HeaderMap`]
pub fn get_etag(headers: &HeaderMap) -> Result<String, Error> {
let e_tag = headers.get(ETAG).ok_or(Error::MissingEtag)?;
Ok(e_tag.to_str().context(BadHeaderSnafu)?.to_string())
}

/// Extracts [`ObjectMeta`] from the provided [`HeaderMap`]
pub fn header_meta(
location: &Path,
Expand All @@ -81,13 +87,10 @@ pub fn header_meta(
None => Utc.timestamp_nanos(0),
};

let e_tag = match headers.get(ETAG) {
Some(e_tag) => {
let e_tag = e_tag.to_str().context(BadHeaderSnafu)?;
Some(e_tag.to_string())
}
None if cfg.etag_required => return Err(Error::MissingEtag),
None => None,
let e_tag = match get_etag(headers) {
Ok(e_tag) => Some(e_tag),
Err(Error::MissingEtag) if !cfg.etag_required => None,
Err(e) => return Err(e),
};

let content_length = headers
Expand Down
87 changes: 35 additions & 52 deletions object_store/src/gcp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ use crate::{
multipart::{PartId, PutPart, WriteMultiPart},
path::{Path, DELIMITER},
ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta,
ObjectStore, Result, RetryConfig,
ObjectStore, PutResult, Result, RetryConfig,
};

use credential::{InstanceCredentialProvider, ServiceAccountCredentials};
Expand All @@ -65,6 +65,7 @@ const STORE: &str = "GCS";

/// [`CredentialProvider`] for [`GoogleCloudStorage`]
pub type GcpCredentialProvider = Arc<dyn CredentialProvider<Credential = GcpCredential>>;
use crate::client::header::get_etag;
use crate::gcp::credential::{ApplicationDefaultCredentials, DEFAULT_GCS_BASE_URL};
pub use credential::GcpCredential;

Expand Down Expand Up @@ -155,11 +156,10 @@ enum Error {
#[snafu(display("Configuration key: '{}' is not known.", key))]
UnknownConfigurationKey { key: String },

#[snafu(display("ETag Header missing from response"))]
MissingEtag,

#[snafu(display("Received header containing non-ASCII data"))]
BadHeader { source: header::ToStrError },
#[snafu(display("Unable to extract metadata from headers: {}", source))]
Metadata {
source: crate::client::header::Error,
},
}

impl From<Error> for super::Error {
Expand Down Expand Up @@ -247,7 +247,14 @@ impl GoogleCloudStorageClient {
}

/// Perform a put request <https://cloud.google.com/storage/docs/xml-api/put-object-upload>
async fn put_request(&self, path: &Path, payload: Bytes) -> Result<()> {
///
/// Returns the new ETag
async fn put_request<T: Serialize + ?Sized + Sync>(
&self,
path: &Path,
payload: Bytes,
query: &T,
) -> Result<String> {
let credential = self.get_credential().await?;
let url = self.object_url(path);

Expand All @@ -256,8 +263,10 @@ impl GoogleCloudStorageClient {
.get_content_type(path)
.unwrap_or("application/octet-stream");

self.client
let response = self
.client
.request(Method::PUT, url)
.query(query)
.bearer_auth(&credential.bearer)
.header(header::CONTENT_TYPE, content_type)
.header(header::CONTENT_LENGTH, payload.len())
Expand All @@ -268,7 +277,7 @@ impl GoogleCloudStorageClient {
path: path.as_ref(),
})?;

Ok(())
Ok(get_etag(response.headers()).context(MetadataSnafu)?)
}

/// Initiate a multi-part upload <https://cloud.google.com/storage/docs/xml-api/post-object-multipart>
Expand Down Expand Up @@ -469,7 +478,7 @@ impl ListClient for GoogleCloudStorageClient {

struct GCSMultipartUpload {
client: Arc<GoogleCloudStorageClient>,
encoded_path: String,
path: Path,
multipart_id: MultipartId,
}

Expand All @@ -478,49 +487,25 @@ impl PutPart for GCSMultipartUpload {
/// Upload an object part <https://cloud.google.com/storage/docs/xml-api/put-object-multipart>
async fn put_part(&self, buf: Vec<u8>, part_idx: usize) -> Result<PartId> {
let upload_id = self.multipart_id.clone();
let url = format!(
"{}/{}/{}",
self.client.base_url, self.client.bucket_name_encoded, self.encoded_path
);

let credential = self.client.get_credential().await?;

let response = self
let content_id = self
.client
.client
.request(Method::PUT, &url)
.bearer_auth(&credential.bearer)
.query(&[
("partNumber", format!("{}", part_idx + 1)),
("uploadId", upload_id),
])
.header(header::CONTENT_TYPE, "application/octet-stream")
.header(header::CONTENT_LENGTH, format!("{}", buf.len()))
.body(buf)
.send_retry(&self.client.retry_config)
.await
.context(PutRequestSnafu {
path: &self.encoded_path,
})?;

let content_id = response
.headers()
.get("ETag")
.context(MissingEtagSnafu)?
.to_str()
.context(BadHeaderSnafu)?
.to_string();
.put_request(
&self.path,
buf.into(),
&[
("partNumber", format!("{}", part_idx + 1)),
("uploadId", upload_id),
],
)
.await?;

Ok(PartId { content_id })
}

/// Complete a multipart upload <https://cloud.google.com/storage/docs/xml-api/post-object-complete>
async fn complete(&self, completed_parts: Vec<PartId>) -> Result<()> {
let upload_id = self.multipart_id.clone();
let url = format!(
"{}/{}/{}",
self.client.base_url, self.client.bucket_name_encoded, self.encoded_path
);
let url = self.client.object_url(&self.path);

let parts = completed_parts
.into_iter()
Expand Down Expand Up @@ -550,7 +535,7 @@ impl PutPart for GCSMultipartUpload {
.send_retry(&self.client.retry_config)
.await
.context(PostRequestSnafu {
path: &self.encoded_path,
path: self.path.as_ref(),
})?;

Ok(())
Expand All @@ -559,8 +544,9 @@ impl PutPart for GCSMultipartUpload {

#[async_trait]
impl ObjectStore for GoogleCloudStorage {
async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
self.client.put_request(location, bytes).await
async fn put(&self, location: &Path, bytes: Bytes) -> Result<PutResult> {
let e_tag = self.client.put_request(location, bytes, &()).await?;
Ok(PutResult { e_tag: Some(e_tag) })
}

async fn put_multipart(
Expand All @@ -569,12 +555,9 @@ impl ObjectStore for GoogleCloudStorage {
) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)> {
let upload_id = self.client.multipart_initiate(location).await?;

let encoded_path =
percent_encode(location.to_string().as_bytes(), NON_ALPHANUMERIC).to_string();

let inner = GCSMultipartUpload {
client: Arc::clone(&self.client),
encoded_path,
path: location.clone(),
multipart_id: upload_id.clone(),
};

Expand Down
Loading

0 comments on commit 4cca029

Please sign in to comment.