Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding download progress tracker for Storage.get #8295

Merged
merged 18 commits into from
Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
* and limitations under the License.
*/
import StorageProvider from '../../src/providers/AWSS3Provider';
import { Hub, Credentials } from '@aws-amplify/core';
import { Hub, Credentials, Logger } from '@aws-amplify/core';
import * as formatURL from '@aws-sdk/util-format-url';
import { S3Client, ListObjectsCommand } from '@aws-sdk/client-s3';
import { S3RequestPresigner } from '@aws-sdk/s3-request-presigner';
import * as events from 'events';

/**
* NOTE - These test cases use Hub.dispatch but they should
Expand Down Expand Up @@ -240,6 +241,75 @@ describe('StorageProvider test', () => {
});
});

test('get object with download and progress tracker', async () => {
jest.spyOn(Credentials, 'get').mockImplementationOnce(() => {
return Promise.resolve(credentials);
});
const mockCallback = jest.fn();
const mockEventEmitter = {
emit: jest.fn(),
on: jest.fn(),
};
jest
.spyOn(events, 'EventEmitter')
.mockImplementationOnce(() => mockEventEmitter);
const downloadOptionsWithProgressCallback = Object.assign({}, options, {
download: true,
progressCallback: mockCallback,
});
const storage = new StorageProvider();
storage.configure(downloadOptionsWithProgressCallback);
const spyon = jest
.spyOn(S3Client.prototype, 'send')
.mockImplementationOnce(async params => {
return { Body: [1, 2] };
});
expect(await storage.get('key', { download: true })).toEqual({
Body: [1, 2],
});
expect(mockEventEmitter.on).toBeCalledWith(
'sendDownloadProgress',
expect.any(Function)
);
// Get the anonymous function called by the emitter
const emitterOnFn = mockEventEmitter.on.mock.calls[0][1];
// Manully invoke it for testing
emitterOnFn('arg');
expect(mockCallback).toBeCalledWith('arg');
});

test('get object with incorrect progressCallback type', async () => {
jest.spyOn(Credentials, 'get').mockImplementationOnce(() => {
return Promise.resolve(credentials);
});
const loggerSpy = jest.spyOn(Logger.prototype, '_log');
const mockEventEmitter = {
emit: jest.fn(),
on: jest.fn(),
};
jest
.spyOn(events, 'EventEmitter')
.mockImplementationOnce(() => mockEventEmitter);
const downloadOptionsWithProgressCallback = Object.assign({}, options);
const storage = new StorageProvider();
storage.configure(downloadOptionsWithProgressCallback);
jest
.spyOn(S3Client.prototype, 'send')
.mockImplementationOnce(async params => {
return { Body: [1, 2] };
});
await storage.get('key', {
download: true,
progressCallback: 'this is not a function',
});
const emitterOnFn = mockEventEmitter.on.mock.calls[0][1];
emitterOnFn('arg');
expect(loggerSpy).toHaveBeenCalledWith(
'WARN',
'progressCallback should be a function, not a string'
);
});

test('get object with download with failure', async () => {
jest.spyOn(Credentials, 'get').mockImplementationOnce(() => {
return new Promise((res, rej) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ describe('multi part upload tests', () => {
// setup event handling
const emitter = new events.EventEmitter();
const eventSpy = sinon.spy();
emitter.on('sendProgress', eventSpy);
emitter.on('sendUploadProgress', eventSpy);

/** Extend our test class such that minPartSize is reasonable
* and we can mock emit the progress events
Expand All @@ -140,7 +140,7 @@ describe('multi part upload tests', () => {
await super.uploadParts(uploadId, parts);
// Now trigger some notifications from the event listeners
for (const part of parts) {
part.emitter.emit('sendProgress', {
part.emitter.emit('sendUploadProgress', {
// Assume that the notification is sent when 100% of part is uploaded
loaded: part.bodyPart.length,
});
Expand Down Expand Up @@ -230,7 +230,7 @@ describe('multi part upload tests', () => {
// setup event handling
const emitter = new events.EventEmitter();
const eventSpy = sinon.spy();
emitter.on('sendProgress', eventSpy);
emitter.on('sendUploadProgress', eventSpy);

/** Extend our test class such that minPartSize is reasonable
* and we can mock emit the progress events
Expand All @@ -242,7 +242,7 @@ describe('multi part upload tests', () => {
await super.uploadParts(uploadId, parts);
// Now trigger some notifications from the event listeners
for (const part of parts) {
part.emitter.emit('sendProgress', {
part.emitter.emit('sendUploadProgress', {
// Assume that the notification is send when 100% of part is uploaded
loaded: part.bodyPart.length,
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import axios from 'axios';
import axios, { CancelTokenSource } from 'axios';
import * as events from 'events';

import { AxiosHttpHandler } from '../../src/providers/axios-http-handler';

Expand Down Expand Up @@ -55,5 +56,55 @@ describe('AxiosHttpHandler', () => {
url: 'http://localhost:3000/',
});
});

it('should attach cancelToken to the request', async () => {
const mockCancelToken = jest.fn().mockImplementationOnce(() => ({
token: 'token',
}));
const handler = new AxiosHttpHandler({}, null, mockCancelToken());
await handler.handle(request, options);

expect(axios.request).toHaveBeenLastCalledWith({
headers: {},
method: 'get',
responseType: 'blob',
url: 'http://localhost:3000/',
cancelToken: 'token',
});
});

it('should track upload or download progress if emitter is present', async () => {
const mockEmit = jest.fn();
const mockEmitter = jest.fn().mockImplementationOnce(() => ({
emit: mockEmit,
}));
const handler = new AxiosHttpHandler({}, mockEmitter());
await handler.handle(request, options);
const lastCall =
axios.request.mock.calls[axios.request.mock.calls.length - 1][0];

expect(lastCall).toStrictEqual({
headers: {},
method: 'get',
responseType: 'blob',
url: 'http://localhost:3000/',
onUploadProgress: expect.any(Function),
onDownloadProgress: expect.any(Function),
});

// Invoke the request's onUploadProgress function manually
lastCall.onUploadProgress({ loaded: 10, total: 100 });
expect(mockEmit).toHaveBeenLastCalledWith('sendUploadProgress', {
loaded: 10,
total: 100,
});

// Invoke the request's onDownloadProgress function manually
lastCall.onDownloadProgress({ loaded: 10, total: 100 });
expect(mockEmit).toHaveBeenLastCalledWith('sendDownloadProgress', {
loaded: 10,
total: 100,
});
});
});
});
26 changes: 22 additions & 4 deletions packages/storage/src/providers/AWSS3Provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ import { formatUrl } from '@aws-sdk/util-format-url';
import { createRequest } from '@aws-sdk/util-create-request';
import { S3RequestPresigner } from '@aws-sdk/s3-request-presigner';
import { StorageOptions, StorageProvider } from '../types';
import { AxiosHttpHandler } from './axios-http-handler';
import {
AxiosHttpHandler,
SEND_DOWNLOAD_PROGRESS_EVENT,
SEND_UPLOAD_PROGRESS_EVENT,
} from './axios-http-handler';
import { AWSS3ProviderManagedUpload } from './AWSS3ProviderManagedUpload';
import * as events from 'events';

Expand Down Expand Up @@ -138,10 +142,12 @@ export class AWSS3Provider implements StorageProvider {
contentType,
expires,
track,
progressCallback,
} = opt;
const prefix = this._prefix(opt);
const final_key = prefix + key;
const s3 = this._createNewS3Client(opt);
const emitter = new events.EventEmitter();
const s3 = this._createNewS3Client(opt, emitter);
logger.debug('get ' + key + ' from ' + final_key);

const params: any = {
Expand All @@ -160,6 +166,18 @@ export class AWSS3Provider implements StorageProvider {
if (download === true) {
const getObjectCommand = new GetObjectCommand(params);
try {
emitter.on(SEND_DOWNLOAD_PROGRESS_EVENT, progress => {
if (progressCallback) {
if (typeof progressCallback === 'function') {
progressCallback(progress);
} else {
logger.warn(
'progressCallback should be a function, not a ' +
typeof progressCallback
);
}
}
jamesaucode marked this conversation as resolved.
Show resolved Hide resolved
});
const response = await s3.send(getObjectCommand);
dispatchStorageEvent(
track,
Expand Down Expand Up @@ -298,7 +316,7 @@ export class AWSS3Provider implements StorageProvider {
}

try {
emitter.on('sendProgress', progress => {
emitter.on(SEND_UPLOAD_PROGRESS_EVENT, progress => {
if (progressCallback) {
if (typeof progressCallback === 'function') {
progressCallback(progress);
Expand Down Expand Up @@ -502,7 +520,7 @@ export class AWSS3Provider implements StorageProvider {
/**
* @private creates an S3 client with new V3 aws sdk
*/
private _createNewS3Client(config, emitter?) {
private _createNewS3Client(config, emitter?: events.EventEmitter) {
const {
region,
credentials,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import {
ListPartsCommand,
AbortMultipartUploadCommand,
} from '@aws-sdk/client-s3';
import { AxiosHttpHandler, SEND_PROGRESS_EVENT } from './axios-http-handler';
import { AxiosHttpHandler, SEND_UPLOAD_PROGRESS_EVENT } from './axios-http-handler';
import * as events from 'events';
import { streamCollector } from '@aws-sdk/fetch-http-handler';

Expand Down Expand Up @@ -268,7 +268,7 @@ export class AWSS3ProviderManagedUpload {
}

private setupEventListener(part: Part) {
part.emitter.on(SEND_PROGRESS_EVENT, progress => {
part.emitter.on(SEND_UPLOAD_PROGRESS_EVENT, progress => {
this.progressChanged(
part.partNumber,
progress.loaded - part._lastUploadedBytes
Expand All @@ -279,7 +279,7 @@ export class AWSS3ProviderManagedUpload {

private progressChanged(partNumber: number, incrementalUpdate: number) {
this.bytesUploaded += incrementalUpdate;
this.emitter.emit(SEND_PROGRESS_EVENT, {
this.emitter.emit(SEND_UPLOAD_PROGRESS_EVENT, {
loaded: this.bytesUploaded,
total: this.totalBytesToUpload,
part: partNumber,
Expand Down
12 changes: 9 additions & 3 deletions packages/storage/src/providers/axios-http-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@ import { buildQueryString } from '@aws-sdk/querystring-builder';
import axios, { AxiosRequestConfig, Method, CancelTokenSource } from 'axios';
import { ConsoleLogger as Logger } from '@aws-amplify/core';
import { FetchHttpHandlerOptions } from '@aws-sdk/fetch-http-handler';
import * as events from 'events';

const logger = new Logger('axios-http-handler');
export const SEND_PROGRESS_EVENT = 'sendProgress';
export const SEND_UPLOAD_PROGRESS_EVENT = 'sendUploadProgress';
export const SEND_DOWNLOAD_PROGRESS_EVENT = 'sendDownloadProgress';

export class AxiosHttpHandler implements HttpHandler {
constructor(
private readonly httpOptions: FetchHttpHandlerOptions = {},
private readonly emitter?: any,
private readonly emitter?: events.EventEmitter,
private readonly cancelTokenSource?: CancelTokenSource
) {}

Expand Down Expand Up @@ -86,7 +88,11 @@ export class AxiosHttpHandler implements HttpHandler {
}
if (emitter) {
axiosRequest.onUploadProgress = function(event) {
emitter.emit(SEND_PROGRESS_EVENT, event);
emitter.emit(SEND_UPLOAD_PROGRESS_EVENT, event);
logger.debug(event);
};
axiosRequest.onDownloadProgress = function(event) {
emitter.emit(SEND_DOWNLOAD_PROGRESS_EVENT, event);
logger.debug(event);
};
}
Expand Down