From bc5a21fcff0a326d79ddf41651a8997165d89da3 Mon Sep 17 00:00:00 2001 From: Sebastien Schwartz Date: Wed, 12 May 2021 20:17:36 +0200 Subject: [PATCH] fix(@aws-amplify/predictions): Downsizing sample for some language in speech to text (#7835) * fix: fix typo in SignIn.vue (#6921) Forget your password? the text has a typo * Downsizing sample for some language in speech to text * adapting test and sendEncodedDataToTranscribe Co-authored-by: sksabircn Co-authored-by: William Lee <43682783+wlee221@users.noreply.github.com> --- ...SAIConvertPredictionsProvider-unit-test.ts | 57 +++++++++++++++++++ .../AmazonAIConvertPredictionsProvider.ts | 38 +++++++++---- 2 files changed, 83 insertions(+), 12 deletions(-) diff --git a/packages/predictions/__tests__/Providers/AWSAIConvertPredictionsProvider-unit-test.ts b/packages/predictions/__tests__/Providers/AWSAIConvertPredictionsProvider-unit-test.ts index f9a509956ef..7fe16fd02e8 100644 --- a/packages/predictions/__tests__/Providers/AWSAIConvertPredictionsProvider-unit-test.ts +++ b/packages/predictions/__tests__/Providers/AWSAIConvertPredictionsProvider-unit-test.ts @@ -217,6 +217,31 @@ describe('Predictions convert provider test', () => { predictionsProvider.convert(validSpeechToTextInput) ).rejects.toMatch('region not configured for transcription'); }); + test('Error languageCode not configured ', () => { + AmazonAIConvertPredictionsProvider.serializeDataFromTranscribe = jest.fn( + () => { + return 'Hello how are you'; + } + ); + + const predictionsProvider = new AmazonAIConvertPredictionsProvider(); + const speechGenOptions = { + transcription: { + region: 'us-west-2', + proxy: false, + }, + }; + predictionsProvider.configure(speechGenOptions); + jest.spyOn(Credentials, 'get').mockImplementationOnce(() => { + return Promise.resolve(credentials); + }); + + return expect( + predictionsProvider.convert(validSpeechToTextInput) + ).rejects.toMatch( + 'languageCode not configured or provided for transcription' + ); + }); test('Happy case ', () => { AmazonAIConvertPredictionsProvider.serializeDataFromTranscribe = jest.fn( () => { @@ -247,5 +272,37 @@ describe('Predictions convert provider test', () => { }, } as SpeechToTextOutput); }); + test('Downsized Happy case ', async () => { + AmazonAIConvertPredictionsProvider.serializeDataFromTranscribe = jest.fn( + () => { + return 'Bonjour, comment vas tu?'; + } + ); + const downsampleBufferSpyon = jest.spyOn( + AmazonAIConvertPredictionsProvider.prototype as any, + 'downsampleBuffer' + ); + + const predictionsProvider = new AmazonAIConvertPredictionsProvider(); + const speechGenOptions = { + transcription: { + region: 'us-west-2', + proxy: false, + defaults: { + language: 'fr-FR', + }, + }, + }; + predictionsProvider.configure(speechGenOptions); + jest.spyOn(Credentials, 'get').mockImplementationOnce(() => { + return Promise.resolve(credentials); + }); + + await predictionsProvider.convert(validSpeechToTextInput); + expect(downsampleBufferSpyon).toBeCalledWith( + expect.objectContaining({ outputSampleRate: 8000 }) + ); + downsampleBufferSpyon.mockClear(); + }); }); }); diff --git a/packages/predictions/src/Providers/AmazonAIConvertPredictionsProvider.ts b/packages/predictions/src/Providers/AmazonAIConvertPredictionsProvider.ts index f4ef5f803e8..686b8fe8816 100644 --- a/packages/predictions/src/Providers/AmazonAIConvertPredictionsProvider.ts +++ b/packages/predictions/src/Providers/AmazonAIConvertPredictionsProvider.ts @@ -28,6 +28,8 @@ import { fromUtf8, toUtf8 } from '@aws-sdk/util-utf8-node'; const logger = new Logger('AmazonAIConvertPredictionsProvider'); const eventBuilder = new EventStreamMarshaller(toUtf8, fromUtf8); +const LANGUAGES_CODE_IN_8KHZ = ['fr-FR', 'en-AU', 'en-GB', 'fr-CA']; + export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictionsProvider { private translateClient: TranslateClient; private pollyClient: PollyClient; @@ -182,6 +184,7 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio const fullText = await this.sendDataToTranscribe({ connection, raw: source.bytes, + languageCode: language, }); return { transcription: { @@ -206,9 +209,7 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio const transcribeMessage = eventBuilder.unmarshall( Buffer.from(message.data) ); - const transcribeMessageJson = JSON.parse( - toUtf8(transcribeMessage.body) - ); + const transcribeMessageJson = JSON.parse(toUtf8(transcribeMessage.body)); if (transcribeMessage.headers[':message-type'].value === 'exception') { logger.debug( 'exception', @@ -244,7 +245,11 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio return decodedMessage; } - private sendDataToTranscribe({ connection, raw }): Promise { + private sendDataToTranscribe({ + connection, + raw, + languageCode, + }): Promise { return new Promise((res, rej) => { let fullText = ''; connection.onmessage = message => { @@ -276,8 +281,11 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio if (Array.isArray(raw)) { for (let i = 0; i < raw.length - 1023; i += 1024) { const data = raw.slice(i, i + 1024); - this.sendEncodedDataToTranscribe(connection, data); + this.sendEncodedDataToTranscribe(connection, data, languageCode); } + } else { + // If Buffer + this.sendEncodedDataToTranscribe(connection, raw, languageCode); } // sending end frame @@ -287,8 +295,13 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio }); } - private sendEncodedDataToTranscribe(connection, data) { - const downsampledBuffer = this.downsampleBuffer({ buffer: data }); + private sendEncodedDataToTranscribe(connection, data, languageCode) { + const downsampledBuffer = this.downsampleBuffer({ + buffer: data, + outputSampleRate: LANGUAGES_CODE_IN_8KHZ.includes(languageCode) + ? 8000 + : 16000, + }); const pcmEncodedBuffer = this.pcmEncode(downsampledBuffer); const audioEventMessage = this.getAudioEventMessage( Buffer.from(pcmEncodedBuffer) @@ -327,14 +340,13 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio } private inputSampleRate = 44100; - private outputSampleRate = 16000; - private downsampleBuffer({ buffer }) { - if (this.outputSampleRate === this.inputSampleRate) { + private downsampleBuffer({ buffer, outputSampleRate = 16000 }) { + if (outputSampleRate === this.inputSampleRate) { return buffer; } - const sampleRateRatio = this.inputSampleRate / this.outputSampleRate; + const sampleRateRatio = this.inputSampleRate / outputSampleRate; const newLength = Math.round(buffer.length / sampleRateRatio); const result = new Float32Array(newLength); let offsetResult = 0; @@ -399,7 +411,9 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio `wss://transcribestreaming.${region}.amazonaws.com:8443`, '/stream-transcription-websocket?', `media-encoding=pcm&`, - `sample-rate=16000&`, + `sample-rate=${ + LANGUAGES_CODE_IN_8KHZ.includes(languageCode) ? '8000' : '16000' + }&`, `language-code=${languageCode}`, ].join('');