Skip to content

Commit

Permalink
fix(@aws-amplify/predictions): Downsizing sample for some language in…
Browse files Browse the repository at this point in the history
… speech to text (#7835)

* fix: fix typo in SignIn.vue (#6921)

Forget your password? the text has a typo

* Downsizing sample for some language in speech to text

* adapting test and sendEncodedDataToTranscribe

Co-authored-by: sksabircn <[email protected]>
Co-authored-by: William Lee <[email protected]>
  • Loading branch information
3 people authored May 12, 2021
1 parent c48ad83 commit bc5a21f
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,31 @@ describe('Predictions convert provider test', () => {
predictionsProvider.convert(validSpeechToTextInput)
).rejects.toMatch('region not configured for transcription');
});
test('Error languageCode not configured ', () => {
AmazonAIConvertPredictionsProvider.serializeDataFromTranscribe = jest.fn(
() => {
return 'Hello how are you';
}
);

const predictionsProvider = new AmazonAIConvertPredictionsProvider();
const speechGenOptions = {
transcription: {
region: 'us-west-2',
proxy: false,
},
};
predictionsProvider.configure(speechGenOptions);
jest.spyOn(Credentials, 'get').mockImplementationOnce(() => {
return Promise.resolve(credentials);
});

return expect(
predictionsProvider.convert(validSpeechToTextInput)
).rejects.toMatch(
'languageCode not configured or provided for transcription'
);
});
test('Happy case ', () => {
AmazonAIConvertPredictionsProvider.serializeDataFromTranscribe = jest.fn(
() => {
Expand Down Expand Up @@ -247,5 +272,37 @@ describe('Predictions convert provider test', () => {
},
} as SpeechToTextOutput);
});
test('Downsized Happy case ', async () => {
AmazonAIConvertPredictionsProvider.serializeDataFromTranscribe = jest.fn(
() => {
return 'Bonjour, comment vas tu?';
}
);
const downsampleBufferSpyon = jest.spyOn(
AmazonAIConvertPredictionsProvider.prototype as any,
'downsampleBuffer'
);

const predictionsProvider = new AmazonAIConvertPredictionsProvider();
const speechGenOptions = {
transcription: {
region: 'us-west-2',
proxy: false,
defaults: {
language: 'fr-FR',
},
},
};
predictionsProvider.configure(speechGenOptions);
jest.spyOn(Credentials, 'get').mockImplementationOnce(() => {
return Promise.resolve(credentials);
});

await predictionsProvider.convert(validSpeechToTextInput);
expect(downsampleBufferSpyon).toBeCalledWith(
expect.objectContaining({ outputSampleRate: 8000 })
);
downsampleBufferSpyon.mockClear();
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import { fromUtf8, toUtf8 } from '@aws-sdk/util-utf8-node';
const logger = new Logger('AmazonAIConvertPredictionsProvider');
const eventBuilder = new EventStreamMarshaller(toUtf8, fromUtf8);

const LANGUAGES_CODE_IN_8KHZ = ['fr-FR', 'en-AU', 'en-GB', 'fr-CA'];

export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictionsProvider {
private translateClient: TranslateClient;
private pollyClient: PollyClient;
Expand Down Expand Up @@ -182,6 +184,7 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
const fullText = await this.sendDataToTranscribe({
connection,
raw: source.bytes,
languageCode: language,
});
return {
transcription: {
Expand All @@ -206,9 +209,7 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
const transcribeMessage = eventBuilder.unmarshall(
Buffer.from(message.data)
);
const transcribeMessageJson = JSON.parse(
toUtf8(transcribeMessage.body)
);
const transcribeMessageJson = JSON.parse(toUtf8(transcribeMessage.body));
if (transcribeMessage.headers[':message-type'].value === 'exception') {
logger.debug(
'exception',
Expand Down Expand Up @@ -244,7 +245,11 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
return decodedMessage;
}

private sendDataToTranscribe({ connection, raw }): Promise<string> {
private sendDataToTranscribe({
connection,
raw,
languageCode,
}): Promise<string> {
return new Promise((res, rej) => {
let fullText = '';
connection.onmessage = message => {
Expand Down Expand Up @@ -276,8 +281,11 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
if (Array.isArray(raw)) {
for (let i = 0; i < raw.length - 1023; i += 1024) {
const data = raw.slice(i, i + 1024);
this.sendEncodedDataToTranscribe(connection, data);
this.sendEncodedDataToTranscribe(connection, data, languageCode);
}
} else {
// If Buffer
this.sendEncodedDataToTranscribe(connection, raw, languageCode);
}

// sending end frame
Expand All @@ -287,8 +295,13 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
});
}

private sendEncodedDataToTranscribe(connection, data) {
const downsampledBuffer = this.downsampleBuffer({ buffer: data });
private sendEncodedDataToTranscribe(connection, data, languageCode) {
const downsampledBuffer = this.downsampleBuffer({
buffer: data,
outputSampleRate: LANGUAGES_CODE_IN_8KHZ.includes(languageCode)
? 8000
: 16000,
});
const pcmEncodedBuffer = this.pcmEncode(downsampledBuffer);
const audioEventMessage = this.getAudioEventMessage(
Buffer.from(pcmEncodedBuffer)
Expand Down Expand Up @@ -327,14 +340,13 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
}

private inputSampleRate = 44100;
private outputSampleRate = 16000;

private downsampleBuffer({ buffer }) {
if (this.outputSampleRate === this.inputSampleRate) {
private downsampleBuffer({ buffer, outputSampleRate = 16000 }) {
if (outputSampleRate === this.inputSampleRate) {
return buffer;
}

const sampleRateRatio = this.inputSampleRate / this.outputSampleRate;
const sampleRateRatio = this.inputSampleRate / outputSampleRate;
const newLength = Math.round(buffer.length / sampleRateRatio);
const result = new Float32Array(newLength);
let offsetResult = 0;
Expand Down Expand Up @@ -399,7 +411,9 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
`wss://transcribestreaming.${region}.amazonaws.com:8443`,
'/stream-transcription-websocket?',
`media-encoding=pcm&`,
`sample-rate=16000&`,
`sample-rate=${
LANGUAGES_CODE_IN_8KHZ.includes(languageCode) ? '8000' : '16000'
}&`,
`language-code=${languageCode}`,
].join('');

Expand Down

0 comments on commit bc5a21f

Please sign in to comment.