Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Downsizing sample for some language in speech to text #7835

Merged
merged 11 commits into from
May 12, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,31 @@ describe('Predictions convert provider test', () => {
predictionsProvider.convert(validSpeechToTextInput)
).rejects.toMatch('region not configured for transcription');
});
// Verifies that convert() rejects with a "languageCode not configured" error
// when the transcription config provides a region but no `defaults.language`.
// NOTE(review): presumably `validSpeechToTextInput` (defined outside this view)
// does not carry a language either — confirm against the fixture.
test('Error languageCode not configured ', () => {
SebSchwartz marked this conversation as resolved.
Show resolved Hide resolved
// Replace the static serializer with a stub that returns a fixed string.
AmazonAIConvertPredictionsProvider.serializeDataFromTranscribe = jest.fn(
() => {
return 'Hello how are you';
}
);

const predictionsProvider = new AmazonAIConvertPredictionsProvider();
// Transcription config without a `defaults.language` entry.
const speechGenOptions = {
transcription: {
region: 'us-west-2',
proxy: false,
},
};
predictionsProvider.configure(speechGenOptions);
// Make the next Credentials.get() call resolve with the `credentials` fixture.
jest.spyOn(Credentials, 'get').mockImplementationOnce(() => {
return Promise.resolve(credentials);
});

// Return the assertion promise so the test runner waits for the rejection.
return expect(
predictionsProvider.convert(validSpeechToTextInput)
).rejects.toMatch(
'languageCode not configured or provided for transcription'
);
});
test('Happy case ', () => {
AmazonAIConvertPredictionsProvider.serializeDataFromTranscribe = jest.fn(
() => {
Expand Down Expand Up @@ -247,5 +272,37 @@ describe('Predictions convert provider test', () => {
},
} as SpeechToTextOutput);
});
/**
 * Verifies that when the configured default language is 'fr-FR',
 * convert() calls downsampleBuffer with an outputSampleRate of 8000.
 */
test('Downsized Happy case ', async () => {
  // Replace the static serializer with a stub that returns a fixed string.
  AmazonAIConvertPredictionsProvider.serializeDataFromTranscribe = jest.fn(
    () => {
      return 'Bonjour, comment vas tu?';
    }
  );
  // downsampleBuffer is private, hence the `any` cast. jest.spyOn without a
  // mock implementation still calls through to the real method.
  const downsampleBufferSpyon = jest.spyOn(
    AmazonAIConvertPredictionsProvider.prototype as any,
    'downsampleBuffer'
  );

  try {
    const predictionsProvider = new AmazonAIConvertPredictionsProvider();
    const speechGenOptions = {
      transcription: {
        region: 'us-west-2',
        proxy: false,
        defaults: {
          language: 'fr-FR',
        },
      },
    };
    predictionsProvider.configure(speechGenOptions);
    // Make the next Credentials.get() call resolve with the `credentials` fixture.
    jest.spyOn(Credentials, 'get').mockImplementationOnce(() => {
      return Promise.resolve(credentials);
    });

    await predictionsProvider.convert(validSpeechToTextInput);
    expect(downsampleBufferSpyon).toBeCalledWith(
      expect.objectContaining({ outputSampleRate: 8000 })
    );
  } finally {
    // mockRestore (not mockClear) removes the spy from the shared prototype,
    // and `finally` guarantees the cleanup even when convert() rejects or the
    // assertion above throws, so the spy cannot leak into other tests.
    downsampleBufferSpyon.mockRestore();
  }
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import { fromUtf8, toUtf8 } from '@aws-sdk/util-utf8-node';
const logger = new Logger('AmazonAIConvertPredictionsProvider');
const eventBuilder = new EventStreamMarshaller(toUtf8, fromUtf8);

// Language codes for which this provider uses an 8000 Hz sample rate instead of
// the 16000 Hz default: audio is downsampled to 8000 Hz before being sent, and
// the streaming URL is built with `sample-rate=8000`.
const LANGUAGES_CODE_IN_8KHZ = ['fr-FR', 'en-AU', 'en-GB', 'fr-CA'];

export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictionsProvider {
private translateClient: TranslateClient;
private pollyClient: PollyClient;
Expand Down Expand Up @@ -182,6 +184,7 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
const fullText = await this.sendDataToTranscribe({
connection,
raw: source.bytes,
languageCode: language,
});
return {
transcription: {
Expand All @@ -206,9 +209,7 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
const transcribeMessage = eventBuilder.unmarshall(
Buffer.from(message.data)
);
const transcribeMessageJson = JSON.parse(
toUtf8(transcribeMessage.body)
);
const transcribeMessageJson = JSON.parse(toUtf8(transcribeMessage.body));
if (transcribeMessage.headers[':message-type'].value === 'exception') {
logger.debug(
'exception',
Expand Down Expand Up @@ -244,7 +245,11 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
return decodedMessage;
}

private sendDataToTranscribe({ connection, raw }): Promise<string> {
private sendDataToTranscribe({
connection,
raw,
languageCode,
}): Promise<string> {
return new Promise((res, rej) => {
let fullText = '';
connection.onmessage = message => {
Expand Down Expand Up @@ -276,8 +281,11 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
if (Array.isArray(raw)) {
for (let i = 0; i < raw.length - 1023; i += 1024) {
const data = raw.slice(i, i + 1024);
this.sendEncodedDataToTranscribe(connection, data);
this.sendEncodedDataToTranscribe(connection, data, languageCode);
}
} else {
// If Buffer
this.sendEncodedDataToTranscribe(connection, raw, languageCode);
}

// sending end frame
Expand All @@ -287,8 +295,13 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
});
}

private sendEncodedDataToTranscribe(connection, data) {
const downsampledBuffer = this.downsampleBuffer({ buffer: data });
private sendEncodedDataToTranscribe(connection, data, languageCode) {
const downsampledBuffer = this.downsampleBuffer({
buffer: data,
outputSampleRate: LANGUAGES_CODE_IN_8KHZ.includes(languageCode)
? 8000
: 16000,
});
const pcmEncodedBuffer = this.pcmEncode(downsampledBuffer);
const audioEventMessage = this.getAudioEventMessage(
Buffer.from(pcmEncodedBuffer)
Expand Down Expand Up @@ -327,14 +340,13 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
}

private inputSampleRate = 44100;
private outputSampleRate = 16000;

private downsampleBuffer({ buffer }) {
if (this.outputSampleRate === this.inputSampleRate) {
private downsampleBuffer({ buffer, outputSampleRate = 16000 }) {
if (outputSampleRate === this.inputSampleRate) {
return buffer;
}

const sampleRateRatio = this.inputSampleRate / this.outputSampleRate;
const sampleRateRatio = this.inputSampleRate / outputSampleRate;
const newLength = Math.round(buffer.length / sampleRateRatio);
const result = new Float32Array(newLength);
let offsetResult = 0;
Expand Down Expand Up @@ -399,7 +411,9 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
`wss://transcribestreaming.${region}.amazonaws.com:8443`,
'/stream-transcription-websocket?',
`media-encoding=pcm&`,
`sample-rate=16000&`,
`sample-rate=${
LANGUAGES_CODE_IN_8KHZ.includes(languageCode) ? '8000' : '16000'
}&`,
`language-code=${languageCode}`,
].join('');

Expand Down