diff --git a/packages/middleware-flexible-checksums/src/configuration.ts b/packages/middleware-flexible-checksums/src/configuration.ts index 1a9eebe516da..a6a066c74924 100644 --- a/packages/middleware-flexible-checksums/src/configuration.ts +++ b/packages/middleware-flexible-checksums/src/configuration.ts @@ -9,7 +9,7 @@ import { StreamHasher, } from "@smithy/types"; -import { RequestChecksumCalculation } from "./constants"; +import { RequestChecksumCalculation, ResponseChecksumValidation } from "./constants"; export interface PreviouslyResolved { /** @@ -39,6 +39,11 @@ export interface PreviouslyResolved { */ requestChecksumCalculation: Provider; + /** + * Determines when a checksum will be calculated for response payloads + */ + responseChecksumValidation: Provider; + /** * A constructor for a class implementing the {@link Hash} interface that computes SHA1 hashes. * @internal diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsResponseMiddleware.spec.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsResponseMiddleware.spec.ts index 5dc42ae84836..d2c203dfb792 100644 --- a/packages/middleware-flexible-checksums/src/flexibleChecksumsResponseMiddleware.spec.ts +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsResponseMiddleware.spec.ts @@ -2,7 +2,7 @@ import { HttpRequest } from "@smithy/protocol-http"; import { DeserializeHandlerArguments } from "@smithy/types"; import { PreviouslyResolved } from "./configuration"; -import { ChecksumAlgorithm } from "./constants"; +import { ChecksumAlgorithm, ResponseChecksumValidation } from "./constants"; import { flexibleChecksumsResponseMiddleware } from "./flexibleChecksumsResponseMiddleware"; import { getChecksumLocationName } from "./getChecksumLocationName"; import { FlexibleChecksumsMiddlewareConfig } from "./getFlexibleChecksumsPlugin"; @@ -23,7 +23,9 @@ describe(flexibleChecksumsResponseMiddleware.name, () => { commandName: "mockCommandName", }; - const mockConfig = {} as PreviouslyResolved; + const mockConfig = { + responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_REQUIRED), + } as PreviouslyResolved; const mockRequestValidationModeMember = "ChecksumEnabled"; const mockResponseAlgorithms = [ChecksumAlgorithm.CRC32, ChecksumAlgorithm.CRC32C]; const mockMiddlewareConfig = { @@ -59,45 +61,40 @@ describe(flexibleChecksumsResponseMiddleware.name, () => { }); describe("skips", () => { - it("if not an instance of HttpRequest", async () => { - const { isInstance } = HttpRequest; - (isInstance as unknown as jest.Mock).mockReturnValue(false); - const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, mockContext); + it("if requestValidationModeMember is not defined", async () => { + const mockMwConfig = Object.assign({}, mockMiddlewareConfig) as FlexibleChecksumsMiddlewareConfig; + delete mockMwConfig.requestValidationModeMember; + const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMwConfig)(mockNext, mockContext); await handler(mockArgs); expect(validateChecksumFromResponse).not.toHaveBeenCalled(); + expect(mockNext).toHaveBeenCalledWith(mockArgs); }); - describe("response checksum", () => { - it("if requestValidationModeMember is not defined", async () => { - const mockMwConfig = Object.assign({}, mockMiddlewareConfig) as FlexibleChecksumsMiddlewareConfig; - delete mockMwConfig.requestValidationModeMember; - const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMwConfig)(mockNext, mockContext); - await handler(mockArgs); - expect(validateChecksumFromResponse).not.toHaveBeenCalled(); - }); + it("if requestValidationModeMember is not enabled in input", async () => { + const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, mockContext); - it("if requestValidationModeMember is not enabled in input", async () => { - const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, mockContext); - await handler({ ...mockArgs, input: {} }); - expect(validateChecksumFromResponse).not.toHaveBeenCalled(); - }); + const mockArgsWithoutEnabled = { ...mockArgs, input: {} }; + await handler(mockArgsWithoutEnabled); + expect(validateChecksumFromResponse).not.toHaveBeenCalled(); + expect(mockNext).toHaveBeenCalledWith(mockArgsWithoutEnabled); + }); - it("if checksum is for S3 whole-object multipart GET", async () => { - (isChecksumWithPartNumber as jest.Mock).mockReturnValue(true); - const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, { - clientName: "S3Client", - commandName: "GetObjectCommand", - }); - await handler(mockArgs); - expect(isChecksumWithPartNumber).toHaveBeenCalledTimes(1); - expect(isChecksumWithPartNumber).toHaveBeenCalledWith(mockChecksum); - expect(validateChecksumFromResponse).not.toHaveBeenCalled(); + it("if checksum is for S3 whole-object multipart GET", async () => { + (isChecksumWithPartNumber as jest.Mock).mockReturnValue(true); + const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, { + clientName: "S3Client", + commandName: "GetObjectCommand", }); + await handler(mockArgs); + expect(isChecksumWithPartNumber).toHaveBeenCalledTimes(1); + expect(isChecksumWithPartNumber).toHaveBeenCalledWith(mockChecksum); + expect(validateChecksumFromResponse).not.toHaveBeenCalled(); + expect(mockNext).toHaveBeenCalledWith(mockArgs); }); }); describe("validates checksum from response header", () => { - it("generic case", async () => { + it("if requestValidationModeMember is enabled in input", async () => { const handler = flexibleChecksumsResponseMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, mockContext); await handler(mockArgs); @@ -105,6 +102,25 @@ describe(flexibleChecksumsResponseMiddleware.name, () => { config: mockConfig, responseAlgorithms: mockResponseAlgorithms, }); + expect(mockNext).toHaveBeenCalledWith(mockArgs); + }); + + it(`if requestValidationModeMember is not enabled in input, but responseChecksumValidation returns ${ResponseChecksumValidation.WHEN_SUPPORTED}`, async () => { + const mockConfigWithResponseChecksumValidationSupported = { + ...mockConfig, + responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_SUPPORTED), + }; + const handler = flexibleChecksumsResponseMiddleware( + mockConfigWithResponseChecksumValidationSupported, + mockMiddlewareConfig + )(mockNext, mockContext); + + await handler({ ...mockArgs, input: {} }); + expect(validateChecksumFromResponse).toHaveBeenCalledWith(mockResult.response, { + config: mockConfigWithResponseChecksumValidationSupported, + responseAlgorithms: mockResponseAlgorithms, + }); + expect(mockNext).toHaveBeenCalledWith(mockArgs); }); it("if checksum is for S3 GET without part number", async () => { @@ -120,6 +136,7 @@ describe(flexibleChecksumsResponseMiddleware.name, () => { config: mockConfig, responseAlgorithms: mockResponseAlgorithms, }); + expect(mockNext).toHaveBeenCalledWith(mockArgs); }); }); }); diff --git a/packages/middleware-flexible-checksums/src/flexibleChecksumsResponseMiddleware.ts b/packages/middleware-flexible-checksums/src/flexibleChecksumsResponseMiddleware.ts index 27735dc5b169..0c28ed33d345 100644 --- a/packages/middleware-flexible-checksums/src/flexibleChecksumsResponseMiddleware.ts +++ b/packages/middleware-flexible-checksums/src/flexibleChecksumsResponseMiddleware.ts @@ -10,7 +10,7 @@ import { } from "@smithy/types"; import { PreviouslyResolved } from "./configuration"; -import { ChecksumAlgorithm } from "./constants"; +import { ChecksumAlgorithm, ResponseChecksumValidation } from "./constants"; import { getChecksumAlgorithmListForResponse } from "./getChecksumAlgorithmListForResponse"; import { getChecksumLocationName } from "./getChecksumLocationName"; import { isChecksumWithPartNumber } from "./isChecksumWithPartNumber"; @@ -37,8 +37,8 @@ export interface FlexibleChecksumsResponseMiddlewareConfig { */ export const flexibleChecksumsResponseMiddlewareOptions: RelativeMiddlewareOptions = { name: "flexibleChecksumsResponseMiddleware", - toMiddleware: "deserializerMiddleware", - relation: "after", + toMiddleware: "serializerMiddleware", + relation: "before", tags: ["BODY_CHECKSUM"], override: true, }; @@ -58,19 +58,25 @@ export const flexibleChecksumsResponseMiddleware = context: HandlerExecutionContext ): DeserializeHandler => async (args: DeserializeHandlerArguments): Promise> => { - if (!HttpRequest.isInstance(args.request)) { - return next(args); + const input = args.input; + const { requestValidationModeMember, responseAlgorithms } = middlewareConfig; + const responseChecksumValidation = await config.responseChecksumValidation(); + + const isResponseChecksumValidationNeeded = + requestValidationModeMember && + (input[requestValidationModeMember] === "ENABLED" || + responseChecksumValidation === ResponseChecksumValidation.WHEN_SUPPORTED); + + if (isResponseChecksumValidationNeeded) { + input[requestValidationModeMember] = "ENABLED"; } - const input = args.input; const result = await next(args); const response = result.response as HttpResponse; let collectedStream: Uint8Array | undefined = undefined; - const { requestValidationModeMember, responseAlgorithms } = middlewareConfig; - // @ts-ignore Element implicitly has an 'any' type for input[requestValidationModeMember] - if (requestValidationModeMember && input[requestValidationModeMember] === "ENABLED") { + if (isResponseChecksumValidationNeeded) { const { clientName, commandName } = context; const isS3WholeObjectMultipartGetResponseChecksum = clientName === "S3Client" &&