diff --git a/packages/storage/__tests__/providers/AWSS3Provider-unit-test.ts b/packages/storage/__tests__/providers/AWSS3Provider-unit-test.ts index af1fbd4dd47..af1694f8fef 100644 --- a/packages/storage/__tests__/providers/AWSS3Provider-unit-test.ts +++ b/packages/storage/__tests__/providers/AWSS3Provider-unit-test.ts @@ -16,6 +16,7 @@ import * as formatURL from '@aws-sdk/util-format-url'; import { S3Client, ListObjectsCommand } from '@aws-sdk/client-s3'; import { S3RequestPresigner } from '@aws-sdk/s3-request-presigner'; import * as events from 'events'; + import { S3CopySource, S3CopyDestination } from '../../src/types'; /** * NOTE - These test cases use Hub.dispatch but they should @@ -229,6 +230,77 @@ describe('StorageProvider test', () => { }); }); + test('get object with download and progress tracker', async () => { + jest.spyOn(Credentials, 'get').mockImplementationOnce(() => { + return Promise.resolve(credentials); + }); + const mockCallback = jest.fn(); + const mockRemoveAllListeners = jest.fn(); + const mockEventEmitter = { + emit: jest.fn(), + on: jest.fn(), + removeAllListeners: mockRemoveAllListeners, + }; + jest + .spyOn(events, 'EventEmitter') + .mockImplementationOnce(() => mockEventEmitter); + const downloadOptionsWithProgressCallback = Object.assign({}, options, { + download: true, + progressCallback: mockCallback, + }); + const storage = new StorageProvider(); + storage.configure(downloadOptionsWithProgressCallback); + const spyon = jest + .spyOn(S3Client.prototype, 'send') + .mockImplementationOnce(async params => { + return { Body: [1, 2] }; + }); + expect(await storage.get('key', { download: true })).toEqual({ + Body: [1, 2], + }); + expect(mockEventEmitter.on).toBeCalledWith( + 'sendDownloadProgress', + expect.any(Function) + ); + // Get the anonymous function called by the emitter + const emitterOnFn = mockEventEmitter.on.mock.calls[0][1]; + // Manully invoke it for testing + emitterOnFn('arg'); + expect(mockCallback).toBeCalledWith('arg'); + expect(mockRemoveAllListeners).toHaveBeenCalled(); + }); + + test('get object with incorrect progressCallback type', async () => { + jest.spyOn(Credentials, 'get').mockImplementationOnce(() => { + return Promise.resolve(credentials); + }); + const loggerSpy = jest.spyOn(Logger.prototype, '_log'); + const mockEventEmitter = { + emit: jest.fn(), + on: jest.fn(), + removeAllListeners: jest.fn(), + }; + jest + .spyOn(events, 'EventEmitter') + .mockImplementationOnce(() => mockEventEmitter); + const downloadOptionsWithProgressCallback = Object.assign({}, options); + const storage = new StorageProvider(); + storage.configure(downloadOptionsWithProgressCallback); + jest + .spyOn(S3Client.prototype, 'send') + .mockImplementationOnce(async params => { + return { Body: [1, 2] }; + }); + await storage.get('key', { + download: true, + progressCallback: 'this is not a function', + }); + expect(loggerSpy).toHaveBeenCalledWith( + 'WARN', + 'progressCallback should be a function, not a string' + ); + }); + test('get object with download with failure', async () => { jest.spyOn(Credentials, 'get').mockImplementationOnce(() => { return new Promise((res, rej) => { @@ -604,7 +676,7 @@ describe('StorageProvider test', () => { progressCallback: mockCallback, }); expect(mockEventEmitter.on).toBeCalledWith( - 'sendProgress', + 'sendUploadProgress', expect.any(Function) ); const emitterOnFn = mockEventEmitter.on.mock.calls[0][1]; diff --git a/packages/storage/__tests__/providers/AWSS3ProviderManagedUpload-unit-test.ts b/packages/storage/__tests__/providers/AWSS3ProviderManagedUpload-unit-test.ts index 3a4f85dcaeb..26a0ae9f9f7 100644 --- a/packages/storage/__tests__/providers/AWSS3ProviderManagedUpload-unit-test.ts +++ b/packages/storage/__tests__/providers/AWSS3ProviderManagedUpload-unit-test.ts @@ -10,10 +10,7 @@ * CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -import { - AWSS3ProviderManagedUpload, - Part, -} from '../../src/providers/AWSS3ProviderManagedUpload'; +import { AWSS3ProviderManagedUpload, Part } from '../../src/providers/AWSS3ProviderManagedUpload'; import { S3Client, PutObjectCommand, @@ -23,7 +20,7 @@ import { AbortMultipartUploadCommand, ListPartsCommand, } from '@aws-sdk/client-s3'; -import { Logger } from "@aws-amplify/core"; +import { Logger } from '@aws-amplify/core'; import * as events from 'events'; import * as sinon from 'sinon'; @@ -65,7 +62,7 @@ class TestClass extends AWSS3ProviderManagedUpload { await super.uploadParts(uploadId, parts); // Now trigger some notifications from the event listeners for (const part of parts) { - part.emitter.emit('sendProgress', { + part.emitter.emit('sendUploadProgress', { // Assume that the notification is send when 100% of part is uploaded loaded: (part.bodyPart as string).length, }); @@ -80,64 +77,39 @@ afterEach(() => { describe('single part upload tests', () => { test('upload a string as body', async () => { - const putObjectSpyOn = jest - .spyOn(S3Client.prototype, 'send') - .mockImplementation(command => { - if (command instanceof PutObjectCommand) - return Promise.resolve(command.input.Key); - }); - const uploader = new AWSS3ProviderManagedUpload( - testParams, - testOpts, - new events.EventEmitter() - ); + const putObjectSpyOn = jest.spyOn(S3Client.prototype, 'send').mockImplementation(command => { + if (command instanceof PutObjectCommand) return Promise.resolve(command.input.Key); + }); + const uploader = new AWSS3ProviderManagedUpload(testParams, testOpts, new events.EventEmitter()); const data = await uploader.upload(); expect(data).toBe(testParams.Key); expect(putObjectSpyOn.mock.calls[0][0].input).toStrictEqual(testParams); }); test('upload a javascript object as body', async () => { - const putObjectSpyOn = jest - .spyOn(S3Client.prototype, 'send') - .mockImplementation(command => { - if (command instanceof PutObjectCommand) - return Promise.resolve(command.input.Key); - }); + const putObjectSpyOn = jest.spyOn(S3Client.prototype, 'send').mockImplementation(command => { + if (command instanceof PutObjectCommand) return Promise.resolve(command.input.Key); + }); const objectBody = { key1: 'value1', key2: 'value2' }; const testParamsWithObjectBody: any = Object.assign({}, testParams); testParamsWithObjectBody.Body = objectBody; - const uploader = new AWSS3ProviderManagedUpload( - testParamsWithObjectBody, - testOpts, - new events.EventEmitter() - ); + const uploader = new AWSS3ProviderManagedUpload(testParamsWithObjectBody, testOpts, new events.EventEmitter()); const data = await uploader.upload(); expect(data).toBe(testParamsWithObjectBody.Key); - expect(putObjectSpyOn.mock.calls[0][0].input).toStrictEqual( - testParamsWithObjectBody - ); + expect(putObjectSpyOn.mock.calls[0][0].input).toStrictEqual(testParamsWithObjectBody); }); test('upload a file as body', async () => { - const putObjectSpyOn = jest - .spyOn(S3Client.prototype, 'send') - .mockImplementation(command => { - if (command instanceof PutObjectCommand) - return Promise.resolve(command.input.Key); - }); + const putObjectSpyOn = jest.spyOn(S3Client.prototype, 'send').mockImplementation(command => { + if (command instanceof PutObjectCommand) return Promise.resolve(command.input.Key); + }); const file = new File(['TestFileContent'], 'testFileName'); const testParamsWithFileBody: any = Object.assign({}, testParams); testParamsWithFileBody.Body = file; - const uploader = new AWSS3ProviderManagedUpload( - testParamsWithFileBody, - testOpts, - new events.EventEmitter() - ); + const uploader = new AWSS3ProviderManagedUpload(testParamsWithFileBody, testOpts, new events.EventEmitter()); const data = await uploader.upload(); expect(data).toBe(testParamsWithFileBody.Key); - expect(putObjectSpyOn.mock.calls[0][0].input).toStrictEqual( - testParamsWithFileBody - ); + expect(putObjectSpyOn.mock.calls[0][0].input).toStrictEqual(testParamsWithFileBody); }); }); @@ -146,22 +118,20 @@ describe('multi part upload tests', () => { // setup event handling const emitter = new events.EventEmitter(); const eventSpy = sinon.spy(); - emitter.on('sendProgress', eventSpy); + emitter.on('sendUploadProgress', eventSpy); // Setup Spy for S3 service calls - const s3ServiceCallSpy = jest - .spyOn(S3Client.prototype, 'send') - .mockImplementation(async command => { - if (command instanceof CreateMultipartUploadCommand) { - return Promise.resolve({ UploadId: testUploadId }); - } else if (command instanceof UploadPartCommand) { - return Promise.resolve({ - ETag: 'test_etag_' + command.input.PartNumber, - }); - } else if (command instanceof CompleteMultipartUploadCommand) { - return Promise.resolve({ Key: testParams.Key }); - } - }); + const s3ServiceCallSpy = jest.spyOn(S3Client.prototype, 'send').mockImplementation(async command => { + if (command instanceof CreateMultipartUploadCommand) { + return Promise.resolve({ UploadId: testUploadId }); + } else if (command instanceof UploadPartCommand) { + return Promise.resolve({ + ETag: 'test_etag_' + command.input.PartNumber, + }); + } else if (command instanceof CompleteMultipartUploadCommand) { + return Promise.resolve({ Key: testParams.Key }); + } + }); // Now make calls const uploader = new TestClass(testParams, testOpts, emitter); @@ -230,36 +200,34 @@ describe('multi part upload tests', () => { // setup event handling const emitter = new events.EventEmitter(); const eventSpy = sinon.spy(); - emitter.on('sendProgress', eventSpy); + emitter.on('sendUploadProgress', eventSpy); // Setup Spy for S3 service calls and introduce a service failure - const s3ServiceCallSpy = jest - .spyOn(S3Client.prototype, 'send') - .mockImplementation(async command => { - if (command instanceof CreateMultipartUploadCommand) { - return Promise.resolve({ UploadId: testUploadId }); - } else if (command instanceof UploadPartCommand) { - let promise = null; - if (command.input.PartNumber === 2) { - promise = new Promise((resolve, reject) => { - setTimeout(() => { - reject(new Error('Part 2 just going to fail in 100ms')); - }, 100); - }); - } else { - promise = new Promise((resolve, reject) => { - setTimeout(() => { - resolve({ - ETag: 'test_etag_' + command.input.PartNumber, - }); - }, 200); - }); - } - return promise; - } else if (command instanceof CompleteMultipartUploadCommand) { - return Promise.resolve({ Key: testParams.key }); + const s3ServiceCallSpy = jest.spyOn(S3Client.prototype, 'send').mockImplementation(async command => { + if (command instanceof CreateMultipartUploadCommand) { + return Promise.resolve({ UploadId: testUploadId }); + } else if (command instanceof UploadPartCommand) { + let promise = null; + if (command.input.PartNumber === 2) { + promise = new Promise((resolve, reject) => { + setTimeout(() => { + reject(new Error('Part 2 just going to fail in 100ms')); + }, 100); + }); + } else { + promise = new Promise((resolve, reject) => { + setTimeout(() => { + resolve({ + ETag: 'test_etag_' + command.input.PartNumber, + }); + }, 200); + }); } - }); + return promise; + } else if (command instanceof CompleteMultipartUploadCommand) { + return Promise.resolve({ Key: testParams.key }); + } + }); // Now make calls const uploader = new TestClass(testParams, testOpts, emitter); @@ -341,14 +309,8 @@ describe('multi part upload tests', () => { return Promise.resolve(); } }); - const uploader = new TestClass( - testParams, - testOpts, - new events.EventEmitter() - ); - await expect(uploader.upload()).rejects.toThrow( - 'Upload was cancelled. Multi Part upload clean up failed' - ); + const uploader = new TestClass(testParams, testOpts, new events.EventEmitter()); + await expect(uploader.upload()).rejects.toThrow('Upload was cancelled. Multi Part upload clean up failed'); }); test('error case: finish multipart upload failed', async () => { @@ -364,16 +326,12 @@ describe('multi part upload tests', () => { } }); const loggerSpy = jest.spyOn(Logger.prototype, '_log'); - const uploader = new TestClass( - testParams, - testOpts, - new events.EventEmitter() - ); + const uploader = new TestClass(testParams, testOpts, new events.EventEmitter()); await uploader.upload(); expect(loggerSpy).toHaveBeenCalledWith( 'ERROR', 'error happened while finishing the upload. Cancelling the multipart upload', 'error' - ) + ); }); }); diff --git a/packages/storage/__tests__/providers/axios-http-handler.test.ts b/packages/storage/__tests__/providers/axios-http-handler.test.ts index 709090b104f..a65978b2c8f 100644 --- a/packages/storage/__tests__/providers/axios-http-handler.test.ts +++ b/packages/storage/__tests__/providers/axios-http-handler.test.ts @@ -1,4 +1,5 @@ -import axios from 'axios'; +import axios, { CancelTokenSource } from 'axios'; +import * as events from 'events'; import { AxiosHttpHandler, @@ -123,11 +124,19 @@ describe('AxiosHttpHandler', () => { responseType: 'blob', url: 'http://localhost:3000/', onUploadProgress: expect.any(Function), + onDownloadProgress: expect.any(Function), }); // Invoke the request's onUploadProgress function manually lastCall.onUploadProgress({ loaded: 10, total: 100 }); - expect(mockEmit).toHaveBeenLastCalledWith('sendProgress', { + expect(mockEmit).toHaveBeenLastCalledWith('sendUploadProgress', { + loaded: 10, + total: 100, + }); + + // Invoke the request's onDownloadProgress function manually + lastCall.onDownloadProgress({ loaded: 10, total: 100 }); + expect(mockEmit).toHaveBeenLastCalledWith('sendDownloadProgress', { loaded: 10, total: 100, }); diff --git a/packages/storage/src/providers/AWSS3Provider.ts b/packages/storage/src/providers/AWSS3Provider.ts index 096a647d804..0e0f6b19021 100644 --- a/packages/storage/src/providers/AWSS3Provider.ts +++ b/packages/storage/src/providers/AWSS3Provider.ts @@ -23,6 +23,11 @@ import { import { formatUrl } from '@aws-sdk/util-format-url'; import { createRequest } from '@aws-sdk/util-create-request'; import { S3RequestPresigner } from '@aws-sdk/s3-request-presigner'; +import { + AxiosHttpHandler, + SEND_DOWNLOAD_PROGRESS_EVENT, + SEND_UPLOAD_PROGRESS_EVENT, +} from './axios-http-handler'; import { StorageOptions, StorageProvider, @@ -32,7 +37,6 @@ import { S3CopyDestination, } from '../types'; import { StorageErrorStrings } from '../common/StorageErrorStrings'; -import { AxiosHttpHandler } from './axios-http-handler'; import { AWSS3ProviderManagedUpload } from './AWSS3ProviderManagedUpload'; import * as events from 'events'; @@ -244,10 +248,12 @@ export class AWSS3Provider implements StorageProvider { contentType, expires, track, + progressCallback, } = opt; const prefix = this._prefix(opt); const final_key = prefix + key; - const s3 = this._createNewS3Client(opt); + const emitter = new events.EventEmitter(); + const s3 = this._createNewS3Client(opt, emitter); logger.debug('get ' + key + ' from ' + final_key); const params: any = { @@ -265,7 +271,20 @@ export class AWSS3Provider implements StorageProvider { if (download === true) { const getObjectCommand = new GetObjectCommand(params); try { + if (progressCallback) { + if (typeof progressCallback === 'function') { + emitter.on(SEND_DOWNLOAD_PROGRESS_EVENT, progress => { + progressCallback(progress); + }); + } else { + logger.warn( + 'progressCallback should be a function, not a ' + + typeof progressCallback + ); + } + } const response = await s3.send(getObjectCommand); + emitter.removeAllListeners(SEND_DOWNLOAD_PROGRESS_EVENT); dispatchStorageEvent( track, 'download', @@ -386,7 +405,7 @@ export class AWSS3Provider implements StorageProvider { try { if (progressCallback) { if (typeof progressCallback === 'function') { - emitter.on('sendProgress', progress => { + emitter.on(SEND_UPLOAD_PROGRESS_EVENT, progress => { progressCallback(progress); }); } else { diff --git a/packages/storage/src/providers/AWSS3ProviderManagedUpload.ts b/packages/storage/src/providers/AWSS3ProviderManagedUpload.ts index 36601e5ff67..fa229144f9b 100644 --- a/packages/storage/src/providers/AWSS3ProviderManagedUpload.ts +++ b/packages/storage/src/providers/AWSS3ProviderManagedUpload.ts @@ -28,7 +28,7 @@ import { AbortMultipartUploadCommand, CompletedPart, } from '@aws-sdk/client-s3'; -import { AxiosHttpHandler, SEND_PROGRESS_EVENT } from './axios-http-handler'; +import { AxiosHttpHandler, SEND_UPLOAD_PROGRESS_EVENT, SEND_DOWNLOAD_PROGRESS_EVENT } from './axios-http-handler'; import * as events from 'events'; const logger = new Logger('AWSS3ProviderManagedUpload'); @@ -267,11 +267,12 @@ export class AWSS3ProviderManagedUpload { } private removeEventListener(part: Part) { - part.emitter.removeAllListeners(SEND_PROGRESS_EVENT); + part.emitter.removeAllListeners(SEND_UPLOAD_PROGRESS_EVENT); + part.emitter.removeAllListeners(SEND_DOWNLOAD_PROGRESS_EVENT); } private setupEventListener(part: Part) { - part.emitter.on(SEND_PROGRESS_EVENT, progress => { + part.emitter.on(SEND_UPLOAD_PROGRESS_EVENT, progress => { this.progressChanged( part.partNumber, progress.loaded - part._lastUploadedBytes @@ -282,7 +283,7 @@ export class AWSS3ProviderManagedUpload { private progressChanged(partNumber: number, incrementalUpdate: number) { this.bytesUploaded += incrementalUpdate; - this.emitter.emit(SEND_PROGRESS_EVENT, { + this.emitter.emit(SEND_UPLOAD_PROGRESS_EVENT, { loaded: this.bytesUploaded, total: this.totalBytesToUpload, part: partNumber, diff --git a/packages/storage/src/providers/axios-http-handler.ts b/packages/storage/src/providers/axios-http-handler.ts index 30f20e51cb4..a838375e8e1 100644 --- a/packages/storage/src/providers/axios-http-handler.ts +++ b/packages/storage/src/providers/axios-http-handler.ts @@ -22,9 +22,11 @@ import axios, { } from 'axios'; import { ConsoleLogger as Logger, Platform } from '@aws-amplify/core'; import { FetchHttpHandlerOptions } from '@aws-sdk/fetch-http-handler'; +import * as events from 'events'; const logger = new Logger('axios-http-handler'); -export const SEND_PROGRESS_EVENT = 'sendProgress'; +export const SEND_UPLOAD_PROGRESS_EVENT = 'sendUploadProgress'; +export const SEND_DOWNLOAD_PROGRESS_EVENT = 'sendDownloadProgress'; function isBlob(body: any): body is Blob { return typeof Blob !== 'undefined' && body instanceof Blob; @@ -60,7 +62,7 @@ export const reactNativeRequestTransformer: AxiosTransformer[] = [ export class AxiosHttpHandler implements HttpHandler { constructor( private readonly httpOptions: FetchHttpHandlerOptions = {}, - private readonly emitter?: any, + private readonly emitter?: events.EventEmitter, private readonly cancelTokenSource?: CancelTokenSource ) {} @@ -122,7 +124,11 @@ export class AxiosHttpHandler implements HttpHandler { } if (emitter) { axiosRequest.onUploadProgress = function(event) { - emitter.emit(SEND_PROGRESS_EVENT, event); + emitter.emit(SEND_UPLOAD_PROGRESS_EVENT, event); + logger.debug(event); + }; + axiosRequest.onDownloadProgress = function(event) { + emitter.emit(SEND_DOWNLOAD_PROGRESS_EVENT, event); logger.debug(event); }; }