From 9b2255034980968419df7f33f9595cb84e2742c1 Mon Sep 17 00:00:00 2001 From: Trivikram Kamat <16024985+trivikr@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:19:06 +0000 Subject: [PATCH 1/3] chore(middleware-flexible-checksums): perform checksum calculation and validation by default --- .../src/configuration.ts | 13 +++ .../flexibleChecksumsInputMiddleware.spec.ts | 103 ++++++++++++++++++ .../src/flexibleChecksumsInputMiddleware.ts | 82 ++++++++++++++ .../src/flexibleChecksumsMiddleware.spec.ts | 63 +++++++++-- .../src/flexibleChecksumsMiddleware.ts | 18 ++- .../getChecksumAlgorithmForRequest.spec.ts | 68 +++++++++--- .../src/getChecksumAlgorithmForRequest.ts | 25 +++-- .../src/getFlexibleChecksumsPlugin.ts | 10 ++ 8 files changed, 347 insertions(+), 35 deletions(-) create mode 100644 packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.spec.ts create mode 100644 packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.ts diff --git a/packages/middleware-flexible-checksums/src/configuration.ts b/packages/middleware-flexible-checksums/src/configuration.ts index e92871b491b5..a6a066c74924 100644 --- a/packages/middleware-flexible-checksums/src/configuration.ts +++ b/packages/middleware-flexible-checksums/src/configuration.ts @@ -4,10 +4,13 @@ import { Encoder, GetAwsChunkedEncodingStream, HashConstructor, + Provider, StreamCollector, StreamHasher, } from "@smithy/types"; +import { RequestChecksumCalculation, ResponseChecksumValidation } from "./constants"; + export interface PreviouslyResolved { /** * The function that will be used to convert binary data to a base64-encoded string. @@ -31,6 +34,16 @@ export interface PreviouslyResolved { */ md5: ChecksumConstructor | HashConstructor; + /** + * Determines when a checksum will be calculated for request payloads + */ + requestChecksumCalculation: Provider; + + /** + * Determines when a checksum will be calculated for response payloads + */ + responseChecksumValidation: Provider; + /** * A constructor for a class implementing the {@link Hash} interface that computes SHA1 hashes. * @internal diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.spec.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.spec.ts new file mode 100644 index 000000000000..7f715114d18c --- /dev/null +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.spec.ts @@ -0,0 +1,103 @@ +import { setFeature } from "@aws-sdk/core"; +import { afterEach, describe, expect, test as it, vi } from "vitest"; + +import { PreviouslyResolved } from "./configuration"; +import { DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation, ResponseChecksumValidation } from "./constants"; +import { flexibleChecksumsInputMiddleware } from "./flexibleChecksumsInputMiddleware"; + +vi.mock("@aws-sdk/core"); + +describe(flexibleChecksumsInputMiddleware.name, () => { + const mockNext = vi.fn(); + const mockRequestValidationModeMember = "mockRequestValidationModeMember"; + + const mockConfig = { + requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_SUPPORTED), + responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_SUPPORTED), + } as PreviouslyResolved; + + afterEach(() => { + expect(mockNext).toHaveBeenCalledTimes(1); + vi.clearAllMocks(); + }); + + describe("sets input.requestValidationModeMember", () => { + it("when requestValidationModeMember is defined and responseChecksumValidation is supported", async () => { + const mockMiddlewareConfigWithMockRequestValidationModeMember = { + requestValidationModeMember: mockRequestValidationModeMember, + }; + const handler = flexibleChecksumsInputMiddleware( + mockConfig, + mockMiddlewareConfigWithMockRequestValidationModeMember + )(mockNext, {}); + await handler({ input: {} }); + expect(mockNext).toHaveBeenCalledWith({ input: { [mockRequestValidationModeMember]: "ENABLED" } }); + }); + }); + + describe("leaves input.requestValidationModeMember", () => { + const mockArgs = { input: {} }; + + it("when requestValidationModeMember is not defined", async () => { + const handler = flexibleChecksumsInputMiddleware(mockConfig, {})(mockNext, {}); + await handler(mockArgs); + expect(mockNext).toHaveBeenCalledWith(mockArgs); + }); + + it("when responseChecksumValidation is required", async () => { + const mockConfigResWhenRequired = { + ...mockConfig, + responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_REQUIRED), + } as PreviouslyResolved; + + const handler = flexibleChecksumsInputMiddleware(mockConfigResWhenRequired, {})(mockNext, {}); + await handler(mockArgs); + + expect(mockNext).toHaveBeenCalledWith(mockArgs); + }); + }); + + describe("set feature", () => { + it.each([ + [ + "FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED", + "a", + "requestChecksumCalculation", + RequestChecksumCalculation.WHEN_REQUIRED, + ], + [ + "FLEXIBLE_CHECKSUMS_REQ_WHEN_SUPPORTED", + "Z", + "requestChecksumCalculation", + RequestChecksumCalculation.WHEN_SUPPORTED, + ], + [ + "FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED", + "c", + "responseChecksumValidation", + ResponseChecksumValidation.WHEN_REQUIRED, + ], + [ + "FLEXIBLE_CHECKSUMS_RES_WHEN_SUPPORTED", + "b", + "responseChecksumValidation", + ResponseChecksumValidation.WHEN_SUPPORTED, + ], + ])("logs %s:%s when %s=%s", async (feature, value, configKey, configValue) => { + const mockConfigOverride = { + ...mockConfig, + [configKey]: () => Promise.resolve(configValue), + } as PreviouslyResolved; + + const handler = flexibleChecksumsInputMiddleware(mockConfigOverride, {})(mockNext, {}); + await handler({ input: {} }); + + expect(setFeature).toHaveBeenCalledTimes(2); + if (configKey === "requestChecksumCalculation") { + expect(setFeature).toHaveBeenNthCalledWith(1, expect.anything(), feature, value); + } else { + expect(setFeature).toHaveBeenNthCalledWith(2, expect.anything(), feature, value); + } + }); + }); +}); diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.ts new file mode 100644 index 000000000000..0dcb7c94cba9 --- /dev/null +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsInputMiddleware.ts @@ -0,0 +1,82 @@ +import { setFeature } from "@aws-sdk/core"; +import { + HandlerExecutionContext, + MetadataBearer, + RelativeMiddlewareOptions, + SerializeHandler, + SerializeHandlerArguments, + SerializeHandlerOutput, + SerializeMiddleware, +} from "@smithy/types"; + +import { PreviouslyResolved } from "./configuration"; +import { RequestChecksumCalculation, ResponseChecksumValidation } from "./constants"; + +export interface FlexibleChecksumsInputMiddlewareConfig { + /** + * Defines a top-level operation input member used to opt-in to best-effort validation + * of a checksum returned in the HTTP response of the operation. + */ + requestValidationModeMember?: string; +} + +/** + * @internal + */ +export const flexibleChecksumsInputMiddlewareOptions: RelativeMiddlewareOptions = { + name: "flexibleChecksumsInputMiddleware", + toMiddleware: "serializerMiddleware", + relation: "before", + tags: ["BODY_CHECKSUM"], + override: true, +}; + +/** + * @internal + * + * The input counterpart to the flexibleChecksumsMiddleware. + */ +export const flexibleChecksumsInputMiddleware = + ( + config: PreviouslyResolved, + middlewareConfig: FlexibleChecksumsInputMiddlewareConfig + ): SerializeMiddleware => + ( + next: SerializeHandler, + context: HandlerExecutionContext + ): SerializeHandler => + async (args: SerializeHandlerArguments): Promise> => { + const input = args.input; + const { requestValidationModeMember } = middlewareConfig; + + const requestChecksumCalculation = await config.requestChecksumCalculation(); + const responseChecksumValidation = await config.responseChecksumValidation(); + + switch (requestChecksumCalculation) { + case RequestChecksumCalculation.WHEN_REQUIRED: + setFeature(context, "FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED", "a"); + break; + case RequestChecksumCalculation.WHEN_SUPPORTED: + setFeature(context, "FLEXIBLE_CHECKSUMS_REQ_WHEN_SUPPORTED", "Z"); + break; + } + + switch (responseChecksumValidation) { + case ResponseChecksumValidation.WHEN_REQUIRED: + setFeature(context, "FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED", "c"); + break; + case ResponseChecksumValidation.WHEN_SUPPORTED: + setFeature(context, "FLEXIBLE_CHECKSUMS_RES_WHEN_SUPPORTED", "b"); + break; + } + + // The value for input member to opt-in to best-effort validation of a checksum returned in the HTTP response is not set. + if (requestValidationModeMember && !input[requestValidationModeMember]) { + // Set requestValidationModeMember as ENABLED only if response checksum validation is supported. + if (responseChecksumValidation === ResponseChecksumValidation.WHEN_SUPPORTED) { + input[requestValidationModeMember] = "ENABLED"; + } + } + + return next(args); + }; diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts index 3d173e584117..27afd7773cce 100644 --- a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.spec.ts @@ -3,7 +3,7 @@ import { BuildHandlerArguments } from "@smithy/types"; import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest"; import { PreviouslyResolved } from "./configuration"; -import { ChecksumAlgorithm } from "./constants"; +import { ChecksumAlgorithm, DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation } from "./constants"; import { flexibleChecksumsMiddleware } from "./flexibleChecksumsMiddleware"; import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest"; import { getChecksumLocationName } from "./getChecksumLocationName"; @@ -13,6 +13,7 @@ import { isStreaming } from "./isStreaming"; import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction"; import { stringHasher } from "./stringHasher"; +vi.mock("@aws-sdk/core"); vi.mock("@smithy/protocol-http"); vi.mock("./getChecksumAlgorithmForRequest"); vi.mock("./getChecksumLocationName"); @@ -28,10 +29,14 @@ describe(flexibleChecksumsMiddleware.name, () => { const mockChecksum = "mockChecksum"; const mockChecksumAlgorithmFunction = vi.fn(); const mockChecksumLocationName = "mock-checksum-location-name"; + const mockRequestAlgorithmMember = "mockRequestAlgorithmMember"; + const mockRequestAlgorithmMemberHttpHeader = "mock-request-algorithm-member-http-header"; const mockInput = {}; - const mockConfig = {} as PreviouslyResolved; - const mockMiddlewareConfig = { requestChecksumRequired: false }; + const mockConfig = { + requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_REQUIRED), + } as PreviouslyResolved; + const mockMiddlewareConfig = { input: mockInput, requestChecksumRequired: false }; const mockBody = { body: "mockRequestBody" }; const mockHeaders = { "content-length": 100, "content-encoding": "gzip" }; @@ -41,9 +46,8 @@ describe(flexibleChecksumsMiddleware.name, () => { beforeEach(() => { mockNext.mockResolvedValueOnce(mockResult); - const { isInstance } = HttpRequest; - (isInstance as unknown as any).mockReturnValue(true); - vi.mocked(getChecksumAlgorithmForRequest).mockReturnValue(ChecksumAlgorithm.MD5); + vi.mocked(HttpRequest.isInstance).mockReturnValue(true); + vi.mocked(getChecksumAlgorithmForRequest).mockReturnValue(ChecksumAlgorithm.CRC32); vi.mocked(getChecksumLocationName).mockReturnValue(mockChecksumLocationName); vi.mocked(hasHeader).mockReturnValue(true); vi.mocked(hasHeaderWithPrefix).mockReturnValue(false); @@ -58,8 +62,7 @@ describe(flexibleChecksumsMiddleware.name, () => { describe("skips", () => { it("if not an instance of HttpRequest", async () => { - const { isInstance } = HttpRequest; - (isInstance as unknown as any).mockReturnValue(false); + vi.mocked(HttpRequest.isInstance).mockReturnValue(false); const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {}); await handler(mockArgs); expect(getChecksumAlgorithmForRequest).not.toHaveBeenCalled(); @@ -77,7 +80,7 @@ describe(flexibleChecksumsMiddleware.name, () => { expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1); }); - it("if header is already present", async () => { + it("skip if header is already present", async () => { const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {}); vi.mocked(hasHeaderWithPrefix).mockReturnValue(true); @@ -94,11 +97,53 @@ describe(flexibleChecksumsMiddleware.name, () => { describe("adds checksum in the request header", () => { afterEach(() => { + expect(HttpRequest.isInstance).toHaveBeenCalledTimes(1); + expect(hasHeaderWithPrefix).toHaveBeenCalledTimes(1); expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1); expect(getChecksumLocationName).toHaveBeenCalledTimes(1); expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1); }); + describe("if input.requestAlgorithmMember can be set", () => { + describe("input[requestAlgorithmMember] is not defined and", () => { + const mockMwConfigWithReqAlgoMember = { + ...mockMiddlewareConfig, + requestAlgorithmMember: { + name: mockRequestAlgorithmMember, + httpHeader: mockRequestAlgorithmMemberHttpHeader, + }, + }; + + it("requestChecksumCalculation is supported", async () => { + const handler = flexibleChecksumsMiddleware( + { + ...mockConfig, + requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_SUPPORTED), + }, + mockMwConfigWithReqAlgoMember + )(mockNext, {}); + await handler(mockArgs); + expect(mockNext.mock.calls[0][0].input[mockRequestAlgorithmMember]).toEqual(DEFAULT_CHECKSUM_ALGORITHM); + expect(mockNext.mock.calls[0][0].request.headers[mockRequestAlgorithmMemberHttpHeader]).toEqual( + DEFAULT_CHECKSUM_ALGORITHM + ); + }); + + it("requestChecksumRequired is set to true", async () => { + const handler = flexibleChecksumsMiddleware(mockConfig, { + ...mockMwConfigWithReqAlgoMember, + requestChecksumRequired: true, + })(mockNext, {}); + + await handler(mockArgs); + expect(mockNext.mock.calls[0][0].input[mockRequestAlgorithmMember]).toEqual(DEFAULT_CHECKSUM_ALGORITHM); + expect(mockNext.mock.calls[0][0].request.headers[mockRequestAlgorithmMemberHttpHeader]).toEqual( + DEFAULT_CHECKSUM_ALGORITHM + ); + }); + }); + }); + it("for streaming body", async () => { vi.mocked(isStreaming).mockReturnValue(true); const mockUpdatedBody = { body: "mockUpdatedBody" }; diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts index 2ed8e66f8c39..8872adde5d93 100644 --- a/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsMiddleware.ts @@ -11,7 +11,7 @@ import { } from "@smithy/types"; import { PreviouslyResolved } from "./configuration"; -import { ChecksumAlgorithm } from "./constants"; +import { ChecksumAlgorithm, DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation } from "./constants"; import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest"; import { getChecksumLocationName } from "./getChecksumLocationName"; import { hasHeader } from "./hasHeader"; @@ -73,10 +73,26 @@ export const flexibleChecksumsMiddleware = const { body: requestBody, headers } = request; const { base64Encoder, streamHasher } = config; const { requestChecksumRequired, requestAlgorithmMember } = middlewareConfig; + const requestChecksumCalculation = await config.requestChecksumCalculation(); + + const requestAlgorithmMemberName = requestAlgorithmMember?.name; + const requestAlgorithmMemberHttpHeader = requestAlgorithmMember?.httpHeader; + // The value for input member to configure flexible checksum is not set. + if (requestAlgorithmMemberName && !input[requestAlgorithmMemberName]) { + // Set requestAlgorithmMember as default checksum algorithm only if request checksum calculation is supported + // or request checksum is required. + if (requestChecksumCalculation === RequestChecksumCalculation.WHEN_SUPPORTED || requestChecksumRequired) { + input[requestAlgorithmMemberName] = DEFAULT_CHECKSUM_ALGORITHM; + if (requestAlgorithmMemberHttpHeader) { + headers[requestAlgorithmMemberHttpHeader] = DEFAULT_CHECKSUM_ALGORITHM; + } + } + } const checksumAlgorithm = getChecksumAlgorithmForRequest(input, { requestChecksumRequired, requestAlgorithmMember: requestAlgorithmMember?.name, + requestChecksumCalculation, }); let updatedBody = requestBody; let updatedHeaders = headers; diff --git a/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.spec.ts b/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.spec.ts index 1981789e9b95..7ac3ce5438d7 100644 --- a/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.spec.ts +++ b/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.spec.ts @@ -1,6 +1,6 @@ import { describe, expect, test as it } from "vitest"; -import { ChecksumAlgorithm, DEFAULT_CHECKSUM_ALGORITHM } from "./constants"; +import { DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation } from "./constants"; import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest"; import { CLIENT_SUPPORTED_ALGORITHMS } from "./types"; @@ -8,36 +8,64 @@ describe(getChecksumAlgorithmForRequest.name, () => { const mockRequestAlgorithmMember = "mockRequestAlgorithmMember"; describe("when requestAlgorithmMember is not provided", () => { - it(`returns ${DEFAULT_CHECKSUM_ALGORITHM} if requestChecksumRequired is set`, () => { - expect(getChecksumAlgorithmForRequest({}, { requestChecksumRequired: true })).toEqual(DEFAULT_CHECKSUM_ALGORITHM); - }); + describe(`when requestChecksumCalculation is '${RequestChecksumCalculation.WHEN_REQUIRED}'`, () => { + const mockOptions = { requestChecksumCalculation: RequestChecksumCalculation.WHEN_REQUIRED }; + + it(`returns ${DEFAULT_CHECKSUM_ALGORITHM} if requestChecksumRequired is set`, () => { + expect(getChecksumAlgorithmForRequest({}, { ...mockOptions, requestChecksumRequired: true })).toEqual( + DEFAULT_CHECKSUM_ALGORITHM + ); + }); - it("returns undefined if requestChecksumRequired is false", () => { - expect(getChecksumAlgorithmForRequest({}, { requestChecksumRequired: false })).toBeUndefined(); + it("returns undefined if requestChecksumRequired is false", () => { + expect(getChecksumAlgorithmForRequest({}, { ...mockOptions, requestChecksumRequired: false })).toBeUndefined(); + }); }); - }); - describe("when requestAlgorithmMember is not set in input", () => { - const mockOptions = { requestAlgorithmMember: mockRequestAlgorithmMember }; + describe(`when requestChecksumCalculation is '${RequestChecksumCalculation.WHEN_SUPPORTED}'`, () => { + const mockOptions = { requestChecksumCalculation: RequestChecksumCalculation.WHEN_SUPPORTED }; - it(`returns ${DEFAULT_CHECKSUM_ALGORITHM} if requestChecksumRequired is set`, () => { - expect(getChecksumAlgorithmForRequest({}, { ...mockOptions, requestChecksumRequired: true })).toEqual( - DEFAULT_CHECKSUM_ALGORITHM - ); + it(`returns ${DEFAULT_CHECKSUM_ALGORITHM} if requestChecksumRequired is set`, () => { + expect(getChecksumAlgorithmForRequest({}, { ...mockOptions, requestChecksumRequired: true })).toEqual( + DEFAULT_CHECKSUM_ALGORITHM + ); + }); + + it(`returns ${DEFAULT_CHECKSUM_ALGORITHM} if requestChecksumRequired is false`, () => { + expect(getChecksumAlgorithmForRequest({}, { ...mockOptions, requestChecksumRequired: false })).toEqual( + DEFAULT_CHECKSUM_ALGORITHM + ); + }); }); + }); - it("returns undefined if requestChecksumRequired is false", () => { - expect(getChecksumAlgorithmForRequest({}, { ...mockOptions, requestChecksumRequired: false })).toBeUndefined(); + describe("returns undefined if input[requestAlgorithmMember] is not set", () => { + describe.each([true, false])("when requestChecksumRequired='%s'", (requestChecksumRequired) => { + it.each([RequestChecksumCalculation.WHEN_SUPPORTED, RequestChecksumCalculation.WHEN_REQUIRED])( + "when requestChecksumCalculation='%s'", + (requestChecksumCalculation) => { + const mockOptions = { + requestChecksumRequired, + requestChecksumCalculation, + requestAlgorithmMember: mockRequestAlgorithmMember, + }; + expect(getChecksumAlgorithmForRequest({}, mockOptions)).toBeUndefined(); + } + ); }); }); it("throws error if input[requestAlgorithmMember] if not supported by client", () => { const unsupportedAlgo = "unsupportedAlgo"; const mockInput = { [mockRequestAlgorithmMember]: unsupportedAlgo }; - const mockOptions = { requestChecksumRequired: true, requestAlgorithmMember: mockRequestAlgorithmMember }; + const mockOptions = { + requestChecksumRequired: true, + requestAlgorithmMember: mockRequestAlgorithmMember, + requestChecksumCalculation: RequestChecksumCalculation.WHEN_REQUIRED, + }; expect(() => { getChecksumAlgorithmForRequest(mockInput, mockOptions); - }).toThrowError( + }).toThrow( `The checksum algorithm "${unsupportedAlgo}" is not supported by the client.` + ` Select one of ${CLIENT_SUPPORTED_ALGORITHMS}.` ); @@ -46,7 +74,11 @@ describe(getChecksumAlgorithmForRequest.name, () => { describe("returns input[requestAlgorithmMember] if supported by client", () => { it.each(CLIENT_SUPPORTED_ALGORITHMS)("Supported algorithm: %s", (supportedAlgorithm) => { const mockInput = { [mockRequestAlgorithmMember]: supportedAlgorithm }; - const mockOptions = { requestChecksumRequired: true, requestAlgorithmMember: mockRequestAlgorithmMember }; + const mockOptions = { + requestChecksumRequired: true, + requestAlgorithmMember: mockRequestAlgorithmMember, + requestChecksumCalculation: RequestChecksumCalculation.WHEN_REQUIRED, + }; expect(getChecksumAlgorithmForRequest(mockInput, mockOptions)).toEqual(supportedAlgorithm); }); }); diff --git a/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.ts b/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.ts index dc79a1a0c307..809b6714b24b 100644 --- a/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.ts +++ b/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.ts @@ -1,4 +1,4 @@ -import { ChecksumAlgorithm, DEFAULT_CHECKSUM_ALGORITHM } from "./constants"; +import { ChecksumAlgorithm, DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation } from "./constants"; import { CLIENT_SUPPORTED_ALGORITHMS } from "./types"; export interface GetChecksumAlgorithmForRequestOptions { @@ -11,6 +11,11 @@ export interface GetChecksumAlgorithmForRequestOptions { * Defines a top-level operation input member that is used to configure request checksum behavior. */ requestAlgorithmMember?: string; + + /** + * Determines when a checksum will be calculated for request payloads + */ + requestChecksumCalculation: RequestChecksumCalculation; } /** @@ -20,13 +25,19 @@ export interface GetChecksumAlgorithmForRequestOptions { */ export const getChecksumAlgorithmForRequest = ( input: any, - { requestChecksumRequired, requestAlgorithmMember }: GetChecksumAlgorithmForRequestOptions + { requestChecksumRequired, requestAlgorithmMember, requestChecksumCalculation }: GetChecksumAlgorithmForRequestOptions ): ChecksumAlgorithm | undefined => { - // Either the Operation input member that is used to configure request checksum behavior is not set, or - // the value for input member to configure flexible checksum is not set. - if (!requestAlgorithmMember || !input[requestAlgorithmMember]) { - // Select an algorithm only if request checksum is required. - return requestChecksumRequired ? DEFAULT_CHECKSUM_ALGORITHM : undefined; + // The Operation input member that is used to configure request checksum behavior is not set. + if (!requestAlgorithmMember) { + // Select an algorithm only if request checksum calculation is supported + // or request checksum is required. + return requestChecksumCalculation === RequestChecksumCalculation.WHEN_SUPPORTED || requestChecksumRequired + ? DEFAULT_CHECKSUM_ALGORITHM + : undefined; + } + + if (!input[requestAlgorithmMember]) { + return undefined; } const checksumAlgorithm = input[requestAlgorithmMember]; diff --git a/packages/middleware-flexible-checksums/src/getFlexibleChecksumsPlugin.ts b/packages/middleware-flexible-checksums/src/getFlexibleChecksumsPlugin.ts index 94dd3ecea9b0..0d6898ea5ea7 100644 --- a/packages/middleware-flexible-checksums/src/getFlexibleChecksumsPlugin.ts +++ b/packages/middleware-flexible-checksums/src/getFlexibleChecksumsPlugin.ts @@ -1,6 +1,11 @@ import { Pluggable } from "@smithy/types"; import { PreviouslyResolved } from "./configuration"; +import { + flexibleChecksumsInputMiddleware, + FlexibleChecksumsInputMiddlewareConfig, + flexibleChecksumsInputMiddlewareOptions, +} from "./flexibleChecksumsInputMiddleware"; import { flexibleChecksumsMiddleware, flexibleChecksumsMiddlewareOptions, @@ -14,6 +19,7 @@ import { export interface FlexibleChecksumsMiddlewareConfig extends FlexibleChecksumsRequestMiddlewareConfig, + FlexibleChecksumsInputMiddlewareConfig, FlexibleChecksumsResponseMiddlewareConfig {} export const getFlexibleChecksumsPlugin = ( @@ -22,6 +28,10 @@ export const getFlexibleChecksumsPlugin = ( ): Pluggable => ({ applyToStack: (clientStack) => { clientStack.add(flexibleChecksumsMiddleware(config, middlewareConfig), flexibleChecksumsMiddlewareOptions); + clientStack.addRelativeTo( + flexibleChecksumsInputMiddleware(config, middlewareConfig), + flexibleChecksumsInputMiddlewareOptions + ); clientStack.addRelativeTo( flexibleChecksumsResponseMiddleware(config, middlewareConfig), flexibleChecksumsResponseMiddlewareOptions From 4c0514bbc93ff4b2e955ebb02bd4a2d11eac3c4d Mon Sep 17 00:00:00 2001 From: Trivikram Kamat <16024985+trivikr@users.noreply.github.com> Date: Thu, 19 Dec 2024 18:17:06 +0000 Subject: [PATCH 2/3] test: change toThrow to toThrowError --- .../src/getChecksumAlgorithmForRequest.spec.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.spec.ts b/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.spec.ts index 7ac3ce5438d7..ce2942dec6bc 100644 --- a/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.spec.ts +++ b/packages/middleware-flexible-checksums/src/getChecksumAlgorithmForRequest.spec.ts @@ -65,7 +65,7 @@ describe(getChecksumAlgorithmForRequest.name, () => { }; expect(() => { getChecksumAlgorithmForRequest(mockInput, mockOptions); - }).toThrow( + }).toThrowError( `The checksum algorithm "${unsupportedAlgo}" is not supported by the client.` + ` Select one of ${CLIENT_SUPPORTED_ALGORITHMS}.` ); From c450f57a2573ce51820a0c3c4c6fca73d79db0bf Mon Sep 17 00:00:00 2001 From: Trivikram Kamat <16024985+trivikr@users.noreply.github.com> Date: Fri, 20 Dec 2024 02:51:45 +0000 Subject: [PATCH 3/3] test: update middleware-flexible-checksums.integ.spec.ts --- ...iddleware-flexible-checksums.integ.spec.ts | 233 ++++++++++-------- 1 file changed, 132 insertions(+), 101 deletions(-) diff --git a/packages/middleware-flexible-checksums/src/middleware-flexible-checksums.integ.spec.ts b/packages/middleware-flexible-checksums/src/middleware-flexible-checksums.integ.spec.ts index 1512a76b53f5..ba9a9ef4ceb9 100644 --- a/packages/middleware-flexible-checksums/src/middleware-flexible-checksums.integ.spec.ts +++ b/packages/middleware-flexible-checksums/src/middleware-flexible-checksums.integ.spec.ts @@ -4,6 +4,7 @@ import { Readable, Transform } from "stream"; import { describe, expect, test as it } from "vitest"; import { requireRequestsFrom } from "../../../private/aws-util-test/src"; +import { DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation, ResponseChecksumValidation } from "./constants"; describe("middleware-flexible-checksums", () => { const logger = { @@ -14,7 +15,7 @@ describe("middleware-flexible-checksums", () => { error() {}, }; - const testCases: [string, ChecksumAlgorithm, string][] = [ + const testCases: [string, ChecksumAlgorithm | undefined, string][] = [ ["", ChecksumAlgorithm.CRC32, "AAAAAA=="], ["abc", ChecksumAlgorithm.CRC32, "NSRBwg=="], ["Hello world", ChecksumAlgorithm.CRC32, "i9aeUg=="], @@ -30,118 +31,146 @@ describe("middleware-flexible-checksums", () => { ["", ChecksumAlgorithm.SHA256, "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU="], ["abc", ChecksumAlgorithm.SHA256, "ungWv48Bz+pBQUDeXa4iI7ADYaOWF3qctBD/YfIAFa0="], ["Hello world", ChecksumAlgorithm.SHA256, "ZOyIygCyaOW6GjVnihtTFtIS9PNmskdyMlNKiuyjfzw="], + + // Choose default checksum algorithm when explicily not provided. + ["", undefined, "AAAAAA=="], + ["abc", undefined, "NSRBwg=="], + ["Hello world", undefined, "i9aeUg=="], ]; describe(S3.name, () => { - const client = new S3({ region: "us-west-2", logger }); - describe("putObject", () => { - testCases.forEach(([body, checksumAlgorithm, checksumValue]) => { - const checksumHeader = `x-amz-checksum-${checksumAlgorithm.toLowerCase()}`; - - it(`sets ${checksumHeader}="${checksumValue}"" for checksum="${checksumAlgorithm}"`, async () => { - requireRequestsFrom(client).toMatch({ - method: "PUT", - hostname: "s3.us-west-2.amazonaws.com", - protocol: "https:", - path: "/b/k", - headers: { - "content-type": "application/octet-stream", - ...(body.length - ? { - "content-length": body.length.toString(), - Expect: "100-continue", - } - : {}), - "x-amz-sdk-checksum-algorithm": checksumAlgorithm, - [checksumHeader]: checksumValue, - host: "s3.us-west-2.amazonaws.com", - "x-amz-user-agent": /./, - "user-agent": /./, - "amz-sdk-invocation-id": /./, - "amz-sdk-request": /./, - "x-amz-date": /./, - "x-amz-content-sha256": /./, - authorization: /./, - }, - query: { - "x-id": "PutObject", - }, + describe.each([undefined, RequestChecksumCalculation.WHEN_SUPPORTED, RequestChecksumCalculation.WHEN_REQUIRED])( + `when requestChecksumCalculation='%s'`, + (requestChecksumCalculation) => { + testCases.forEach(([body, checksumAlgorithm, checksumValue]) => { + const client = new S3({ region: "us-west-2", logger, requestChecksumCalculation }); + const checksumHeader = `x-amz-checksum-${(checksumAlgorithm ?? DEFAULT_CHECKSUM_ALGORITHM).toLowerCase()}`; + + it(`tests ${checksumHeader}="${checksumValue}"" for checksum="${checksumAlgorithm}"`, async () => { + requireRequestsFrom(client).toMatch({ + method: "PUT", + hostname: "s3.us-west-2.amazonaws.com", + protocol: "https:", + path: "/b/k", + headers: { + "content-type": "application/octet-stream", + ...(body.length + ? { + "content-length": body.length.toString(), + Expect: "100-continue", + } + : {}), + ...(requestChecksumCalculation === RequestChecksumCalculation.WHEN_REQUIRED && + checksumAlgorithm === undefined + ? {} + : { + "x-amz-sdk-checksum-algorithm": checksumAlgorithm, + [checksumHeader]: checksumValue, + }), + host: "s3.us-west-2.amazonaws.com", + "x-amz-user-agent": /./, + "user-agent": /./, + "amz-sdk-invocation-id": /./, + "amz-sdk-request": /./, + "x-amz-date": /./, + "x-amz-content-sha256": /./, + authorization: /./, + }, + query: { + "x-id": "PutObject", + }, + }); + + await client.putObject({ + Bucket: "b", + Key: "k", + Body: body, + ChecksumAlgorithm: checksumAlgorithm as ChecksumAlgorithm, + }); + + expect.hasAssertions(); + }); }); - - await client.putObject({ - Bucket: "b", - Key: "k", - Body: body, - ChecksumAlgorithm: checksumAlgorithm, - }); - - expect.hasAssertions(); - }); - }); + } + ); }); describe("getObject", () => { - testCases.forEach(([body, checksumAlgorithm, checksumValue]) => { - const checksumHeader = `x-amz-checksum-${checksumAlgorithm.toLowerCase()}`; - - it(`validates ${checksumHeader}="${checksumValue}"" set for checksum="${checksumAlgorithm}"`, async () => { - const client = new S3({ - region: "us-west-2", - logger, - requestHandler: new (class implements HttpHandler { - async handle(request: HttpRequest): Promise { - expect(request).toMatchObject({ - method: "GET", - hostname: "s3.us-west-2.amazonaws.com", - protocol: "https:", - path: "/b/k", - headers: { - "x-amz-checksum-mode": "ENABLED", - host: "s3.us-west-2.amazonaws.com", - "x-amz-user-agent": /./, - "user-agent": /./, - "amz-sdk-invocation-id": /./, - "amz-sdk-request": /./, - "x-amz-date": /./, - "x-amz-content-sha256": /./, - authorization: /./, - }, - query: { - "x-id": "GetObject", - }, - }); - return { - response: new HttpResponse({ - statusCode: 200, - headers: { - "content-type": "application/octet-stream", - "content-length": body.length.toString(), - [checksumHeader]: checksumValue, - }, - body: Readable.from([body]), - }), - }; - } - updateHttpClientConfig(key: never, value: never): void {} - httpHandlerConfigs() { - return {}; - } - })(), - }); - - const response = await client.getObject({ - Bucket: "b", - Key: "k", - ChecksumMode: "ENABLED", + describe.each([undefined, ResponseChecksumValidation.WHEN_SUPPORTED, ResponseChecksumValidation.WHEN_REQUIRED])( + `when responseChecksumValidation='%s'`, + (responseChecksumValidation) => { + testCases.forEach(([body, checksumAlgorithm, checksumValue]) => { + const checksumHeader = `x-amz-checksum-${(checksumAlgorithm ?? DEFAULT_CHECKSUM_ALGORITHM).toLowerCase()}`; + + it(`validates ${checksumHeader}="${checksumValue}"" for checksum="${checksumAlgorithm}"`, async () => { + const client = new S3({ + region: "us-west-2", + logger, + requestHandler: new (class implements HttpHandler { + async handle(request: HttpRequest): Promise { + expect(request).toMatchObject({ + method: "GET", + hostname: "s3.us-west-2.amazonaws.com", + protocol: "https:", + path: "/b/k", + headers: { + ...(responseChecksumValidation === ResponseChecksumValidation.WHEN_REQUIRED && + !checksumAlgorithm + ? {} + : { + "x-amz-checksum-mode": "ENABLED", + }), + host: "s3.us-west-2.amazonaws.com", + "x-amz-user-agent": /./, + "user-agent": /./, + "amz-sdk-invocation-id": /./, + "amz-sdk-request": /./, + "x-amz-date": /./, + "x-amz-content-sha256": /./, + authorization: /./, + }, + query: { + "x-id": "GetObject", + }, + }); + return { + response: new HttpResponse({ + statusCode: 200, + headers: { + "content-type": "application/octet-stream", + "content-length": body.length.toString(), + [checksumHeader]: checksumValue, + }, + body: Readable.from([body]), + }), + }; + } + updateHttpClientConfig(key: never, value: never): void {} + httpHandlerConfigs() { + return {}; + } + })(), + responseChecksumValidation, + }); + + const response = await client.getObject({ + Bucket: "b", + Key: "k", + // Do not pass ChecksumMode if algorithm is not explicitly defined. It'll be set by SDK. + ChecksumMode: checksumAlgorithm ? "ENABLED" : undefined, + }); + + await expect(response.Body?.transformToString()).resolves.toEqual(body); + }); }); - - await expect(response.Body?.transformToString()).resolves.toEqual(body); - }); - }); + } + ); }); it("should not set binary file content length", async () => { + const client = new S3({ region: "us-west-2", logger }); + requireRequestsFrom(client).toMatch({ method: "PUT", hostname: "s3.us-west-2.amazonaws.com", @@ -182,6 +211,8 @@ describe("middleware-flexible-checksums", () => { ["CRC32C", "V"], ].forEach(([algo, id]) => { it(`should feature-detect checksum ${algo}=${id}`, async () => { + const client = new S3({ region: "us-west-2", logger }); + requireRequestsFrom(client).toMatch({ headers: { "user-agent": new RegExp(`(.*?) m\/(.*?)${id}(.*?)$`),