Skip to content

Commit

Permalink
Validates embedded errors for multipart upload. (#2057)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbiscigl authored Aug 25, 2022
1 parent 0f9d322 commit 1614bce
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 5 deletions.
20 changes: 19 additions & 1 deletion aws-cpp-sdk-core-tests/aws/client/AWSClientTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,18 @@ class AWSClientTestSuite : public ::testing::Test
}

void QueueMockResponse(HttpResponseCode code, const HeaderValueCollection& headers)
{
QueueMockResponse(code, headers, "ss");
}

void QueueMockResponse(HttpResponseCode code, const HeaderValueCollection& headers, const Aws::String& body)
{
auto httpRequest = CreateHttpRequest(URI("http://www.uri.com/path/to/res"),
HttpMethod::HTTP_GET, Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
httpRequest->SetResolvedRemoteHost("127.0.0.1");
auto httpResponse = Aws::MakeShared<StandardHttpResponse>(ALLOCATION_TAG, httpRequest);
httpResponse->SetResponseCode(code);
httpResponse->GetResponseBody() << "";
httpResponse->GetResponseBody() << body;
for(auto&& header : headers)
{
httpResponse->AddHeader(header.first, header.second);
Expand Down Expand Up @@ -550,6 +555,19 @@ TEST_F(AWSClientTestSuite, TestRecursionDetection)
}
}

TEST_F(AWSClientTestSuite, TestErrorInBodyOfResponse)
{
HeaderValueCollection responseHeaders;
AmazonWebServiceRequestMock request;
QueueMockResponse(HttpResponseCode::OK, responseHeaders, "<Error><Code>SomeException</Code><Message>TestErrorInBodyOfResponse</Message></Error>");
auto outcome = client->MakeRequest(request);

ASSERT_TRUE(!outcome.IsSuccess());
ASSERT_EQ(outcome.GetError().GetErrorType(), CoreErrors::SLOW_DOWN);
ASSERT_EQ(outcome.GetError().GetMessage(), "TestErrorInBodyOfResponse");
ASSERT_EQ(outcome.GetError().GetExceptionName(), "TestErrorInBodyOfResponse");
}

TEST(AWSClientTest, TestBuildHttpRequestWithHeadersOnly)
{
HeaderValueCollection headerValues;
Expand Down
10 changes: 10 additions & 0 deletions aws-cpp-sdk-core/include/aws/core/AmazonWebServiceRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <aws/core/utils/memory/stl/AWSStreamFwd.h>
#include <aws/core/utils/stream/ResponseStream.h>
#include <aws/core/auth/AWSAuthSigner.h>
#include <aws/core/client/CoreErrors.h>

namespace Aws
{
Expand Down Expand Up @@ -75,6 +76,15 @@ namespace Aws
*/
virtual bool SignBody() const { return true; }

/**
* Defaults to false, if a derived class returns true it indicates that the body has an embedded error.
*/
virtual bool HasEmbeddedError(Aws::IOStream& body, const Aws::Http::HeaderValueCollection& header) const {
(void) body;
(void) header;
return false;
}

/**
* Defaults to false, if this is set to true, it supports chunked transfer encoding.
*/
Expand Down
4 changes: 1 addition & 3 deletions aws-cpp-sdk-core/source/client/AWSClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ HttpResponseOutcome AWSClient::AttemptOneRequest(const std::shared_ptr<HttpReque
}
}

if (DoesResponseGenerateError(httpResponse))
if (DoesResponseGenerateError(httpResponse) || request.HasEmbeddedError(httpResponse->GetResponseBody(), httpResponse->GetHeaders()))
{
AWS_LOGSTREAM_DEBUG(AWS_CLIENT_LOG_TAG, "Request returned error. Attempting to generate appropriate error codes from response");
auto error = BuildAWSError(httpResponse);
Expand Down Expand Up @@ -1297,8 +1297,6 @@ AWSError<CoreErrors> AWSXMLClient::BuildAWSError(const std::shared_ptr<Http::Htt
}
else
{
assert(httpResponse->GetResponseCode() != HttpResponseCode::OK);

// When trying to build an AWS Error from a response which is an FStream, we need to rewind the
// file pointer back to the beginning in order to correctly read the input using the XML string iterator
if ((httpResponse->GetResponseBody().tellp() > 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace Model

Aws::Http::HeaderValueCollection GetRequestSpecificHeaders() const override;

bool HasEmbeddedError(IOStream &body, const Http::HeaderValueCollection &header) const override;

/**
* <p>Name of the bucket to which the multipart upload was initiated.</p> <p>When
Expand Down
22 changes: 22 additions & 0 deletions aws-cpp-sdk-s3-crt/source/model/CompleteMultipartUploadRequest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,28 @@ CompleteMultipartUploadRequest::CompleteMultipartUploadRequest() :
{
}

bool CompleteMultipartUploadRequest::HasEmbeddedError(Aws::IOStream &body,
const Aws::Http::HeaderValueCollection &header) const
{
// Header is unused
(void) header;

auto readPointer = body.tellg();
XmlDocument doc = XmlDocument::CreateFromXmlStream(body);

if (!doc.WasParseSuccessful()) {
body.seekg(readPointer);
return false;
}

if (doc.GetRootElement().GetName() == "Error") {
body.seekg(readPointer);
return true;
}
body.seekg(readPointer);
return false;
}

Aws::String CompleteMultipartUploadRequest::SerializePayload() const
{
XmlDocument payloadDoc = XmlDocument::CreateWithRootNode("CompleteMultipartUpload");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace Model

Aws::Http::HeaderValueCollection GetRequestSpecificHeaders() const override;

bool HasEmbeddedError(IOStream &body, const Http::HeaderValueCollection &header) const override;

/**
* <p>Name of the bucket to which the multipart upload was initiated.</p> <p>When
Expand Down
22 changes: 22 additions & 0 deletions aws-cpp-sdk-s3/source/model/CompleteMultipartUploadRequest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,28 @@ CompleteMultipartUploadRequest::CompleteMultipartUploadRequest() :
{
}

bool CompleteMultipartUploadRequest::HasEmbeddedError(Aws::IOStream &body,
const Aws::Http::HeaderValueCollection &header) const
{
// Header is unused
(void) header;

auto readPointer = body.tellg();
XmlDocument doc = XmlDocument::CreateFromXmlStream(body);

if (!doc.WasParseSuccessful()) {
body.seekg(readPointer);
return false;
}

if (doc.GetRootElement().GetName() == "Error") {
body.seekg(readPointer);
return true;
}
body.seekg(readPointer);
return false;
}

Aws::String CompleteMultipartUploadRequest::SerializePayload() const
{
XmlDocument payloadDoc = XmlDocument::CreateWithRootNode("CompleteMultipartUpload");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public class Shape {
private boolean sensitive;
private boolean hasPreSignedUrl;
private boolean document;
private boolean hasEmbeddedErrors = false;

public boolean isMap() {
return "map".equals(type.toLowerCase());
Expand Down Expand Up @@ -89,6 +90,14 @@ public boolean isDocument() {
return "structure".equals(type.toLowerCase()) && document;
}

public boolean hasEmbeddedErrors() {
return this.hasEmbeddedErrors;
}

public void setEmbeddedErrors(boolean hasEmbeddedErrors) {
this.hasEmbeddedErrors = hasEmbeddedErrors;
}

public boolean isPrimitive() {
return !isMap() && !isList() && !isStructure() && !isString() && !isEnum() && !isBlob() && !isTimeStamp();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
import com.amazonaws.util.awsclientgenerator.domainmodels.codegeneration.cpp.CppShapeInformation;
import com.amazonaws.util.awsclientgenerator.domainmodels.codegeneration.cpp.CppViewHelper;
import com.amazonaws.util.awsclientgenerator.generators.cpp.RestXmlCppClientGenerator;
import com.google.common.collect.ImmutableSet;
import org.apache.velocity.Template;
import org.apache.velocity.VelocityContext;

import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

Expand All @@ -28,6 +28,7 @@ public class S3RestXmlCppClientGenerator extends RestXmlCppClientGenerator {
private static Set<String> opsThatDoNotSupportArnEndpoint = new HashSet<>();
private static Set<String> opsThatDoNotSupportFutureInS3CRT = new HashSet<>();
private static Set<String> bucketLocationConstraints = new HashSet<>();
private Set<String> functionsWithEmbeddedErrors = ImmutableSet.of("CompleteMultipartUploadRequest");

static {
opsThatDoNotSupportVirtualAddressing.add("CreateBucket");
Expand Down Expand Up @@ -124,6 +125,11 @@ public SdkFileEntry[] generateSourceFiles(ServiceModel serviceModel) throws Exce
replicationStatus.getEnumValues().set(indexOfComplete, "COMPLETED");
}

// Some S3 operations have embedded errors, and we need to search for errors in the response.
serviceModel.getShapes().values().stream()
.filter(shape -> functionsWithEmbeddedErrors.contains(shape.getName()))
.forEach(shape -> shape.setEmbeddedErrors(true));

// Customized Log Information
Shape logTagKeyShape = new Shape();
logTagKeyShape.setName("customizedAccessLogTagKey");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ namespace Model
Aws::Http::HeaderValueCollection GetRequestSpecificHeaders() const override;

#end
#if($shape.hasEmbeddedErrors())
bool HasEmbeddedError(IOStream &body, const Http::HeaderValueCollection &header) const override;
#end
#if($operation.requestAlgorithmMember)
Aws::String GetChecksumAlgorithmName() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,30 @@ using namespace Aws::Http;
${typeInfo.className}::${typeInfo.className}()$initializers
{
}
#if($shape.hasEmbeddedErrors())

bool CompleteMultipartUploadRequest::HasEmbeddedError(Aws::IOStream &body,
const Aws::Http::HeaderValueCollection &header) const
{
// Header is unused
(void) header;

auto readPointer = body.tellg();
XmlDocument doc = XmlDocument::CreateFromXmlStream(body);

if (!doc.WasParseSuccessful()) {
body.seekg(readPointer);
return false;
}

if (doc.GetRootElement().GetName() == "Error") {
body.seekg(readPointer);
return true;
}
body.seekg(readPointer);
return false;
}
#end

Aws::String ${typeInfo.className}::SerializePayload() const
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ class AmazonWebServiceRequestMock : public Aws::AmazonWebServiceRequest
bool ShouldComputeContentMd5() const override { return m_shouldComputeMd5; }
void SetComputeContentMd5(bool value) { m_shouldComputeMd5 = value; }
virtual const char* GetServiceRequestName() const override { return "AmazonWebServiceRequestMock"; }
virtual bool HasEmbeddedError(Aws::IOStream& body, const Aws::Http::HeaderValueCollection& header) const override {
(void) header;
std::stringstream ss;
ss << body.rdbuf();
auto bodyString = ss.str();
return bodyString.find("TestErrorInBodyOfResponse") != std::string::npos;
}

private:
std::shared_ptr<Aws::IOStream> m_body;
Expand Down Expand Up @@ -119,6 +126,10 @@ class MockAWSClient : Aws::Client::AWSClient
std::shared_ptr<CountedRetryStrategy> m_countedRetryStrategy;
Aws::Client::AWSError<Aws::Client::CoreErrors> BuildAWSError(const std::shared_ptr<Aws::Http::HttpResponse>& response) const override
{
if (response->GetResponseCode() == Aws::Http::HttpResponseCode::OK)
{
return { Aws::Client::CoreErrors::SLOW_DOWN, "TestErrorInBodyOfResponse", "TestErrorInBodyOfResponse", false };
}
Aws::Client::AWSError<Aws::Client::CoreErrors> error;
if (response->HasClientError())
{
Expand Down

0 comments on commit 1614bce

Please sign in to comment.