Add multilingual transcription + translation for whisper models (#87, #95) (#133)

* Align `.generate()` return type with python library

* Add multilingual transcription + translation for whisper models (#87, #95)

* Include `return_timestamps` in calculation of `forced_decoder_ids`

* Only return non-null `forced_decoder_ids`

* Allow user to specify task in any case

* Only set `forced_decoder_ids` when non-empty

* Implement `SuppressTokensAtBeginLogitsProcessor`
xenova authored Jun 9, 2023
1 parent f5f78c4 commit 8625f4a
Showing 7 changed files with 427 additions and 157 deletions.
69 changes: 48 additions & 21 deletions src/models.js
@@ -59,6 +59,7 @@ import {
ForceTokensLogitsProcessor,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
WhisperTimeStampLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
@@ -389,6 +390,13 @@ async function seq2seq_forward(self, model_inputs, {
function seq2seqStartBeams(self, inputTokenIds, numOutputTokens, requires_attention_mask = true) {
let beams = [];
let beamId = 0;

// decoder_input_ids == output_token_ids
let decoder_input_ids = self.config.decoder_start_token_id;
if (!Array.isArray(decoder_input_ids)) {
decoder_input_ids = [decoder_input_ids];
}

for (let tokens of inputTokenIds) {
// TODO: Improve
// Currently, just add back batch dimension.
@@ -401,8 +409,7 @@ function seq2seqStartBeams(self, inputTokenIds, numOutputTokens, requires_attent
encoder_outputs: null,
past_key_values: null,

// decoder_input_ids == output_token_ids
output_token_ids: [self.config.decoder_start_token_id],
output_token_ids: decoder_input_ids,
done: false,
score: 0,
id: beamId++ // assign unique id to beams
@@ -652,7 +659,7 @@ export class PreTrainedModel extends Callable {

/**
* @param {GenerationConfig} generation_config
* @param {number} input_ids_seq_length
* @param {number} input_ids_seq_length The starting sequence length for the input ids.
* @returns {LogitsProcessorList}
*/
_get_logits_processor(
@@ -749,14 +756,17 @@ export class PreTrainedModel extends Callable {
// processors.push(new SuppressTokensLogitsProcessor(generation_config.suppress_tokens));
// }

// if (generation_config.begin_suppress_tokens !== null) {
// let begin_index = input_ids_seq_length;
// begin_index = (input_ids_seq_length > 1 || generation_config.forced_bos_token_id === null) ? begin_index : begin_index + 1;
// if (generation_config.forced_decoder_ids !== null) {
// begin_index += generation_config.forced_decoder_ids[generation_config.forced_decoder_ids.length - 1][0];
// }
// processors.push(new SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index));
// }
if (generation_config.begin_suppress_tokens !== null) {
let begin_index = (input_ids_seq_length > 1 || generation_config.forced_bos_token_id === null)
? input_ids_seq_length
: input_ids_seq_length + 1;

if (generation_config.forced_decoder_ids !== null) {
// generation starts after the last token that is forced
begin_index += generation_config.forced_decoder_ids[generation_config.forced_decoder_ids.length - 1][0];
}
processors.push(new SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index));
}

if (generation_config.forced_decoder_ids !== null) {
processors.push(new ForceTokensLogitsProcessor(generation_config.forced_decoder_ids));
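
The processor's implementation lands in one of the changed files not shown on this page. As a rough sketch of its semantics, mirroring the Python library: once the decoder holds exactly `begin_index` tokens (the first sampling step that is not forced), the `begin_suppress_tokens` are masked out.

```javascript
// Sketch only — the real class extends the library's logits-processor base class.
class SuppressTokensAtBeginLogitsProcessor {
    constructor(begin_suppress_tokens, begin_index) {
        this.begin_suppress_tokens = begin_suppress_tokens;
        this.begin_index = begin_index;
    }
    _call(input_ids, logits) {
        if (input_ids.length === this.begin_index) {
            // At the first free generation step, ban the suppressed tokens.
            for (const token_id of this.begin_suppress_tokens) {
                logits.data[token_id] = -Infinity;
            }
        }
        return logits;
    }
}
```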
@@ -809,7 +819,7 @@ export class PreTrainedModel extends Callable {
* @param {Object|null} logits_processor An optional logits processor to use. If null, a new LogitsProcessorList instance will be created.
* @param {Object} options options
* @param {Object} [options.inputs_attention_mask=null] An optional attention mask for the inputs.
* @returns {Promise<Array>} An array of generated output sequences, where each sequence is an array of token IDs.
* @returns {Promise<number[][]>} An array of generated output sequences, where each sequence is an array of token IDs.
* @throws {Error} Throws an error if the inputs array is empty.
*/
async generate(
@@ -825,22 +835,32 @@
throw Error(`\`inputs\` must be a Tensor, TypedArray, or Array, but is "${inputs.constructor.name}".`);
}

if (inputs.length === 0) {
throw Error("Must supply a non-empty array of input token ids.")
let input_ids_seq_length;

// Prepare `input_ids` which will be used for auto-regressive generation
// TODO: Update to align with HF transformers' implementation
if (this.config.is_encoder_decoder) {
// Generating from the encoder outputs
input_ids_seq_length = 0;

} else {
input_ids_seq_length = inputs instanceof Tensor ? inputs.dims[0] : inputs.length;

// decoder-only
if (input_ids_seq_length === 0) {
throw Error("Must supply a non-empty array of input token ids.")
}
}

// Update generation config with defaults
generation_config = this._get_generation_config(generation_config);

logits_processor = logits_processor ?? new LogitsProcessorList()

// TODO Update generation config
// this.generation_config

// Update logits processor
logits_processor = this._get_logits_processor(
generation_config,
inputs.length,
input_ids_seq_length,
logits_processor
)

@@ -850,6 +870,8 @@
let numOutputTokens = 1;
const maxOutputTokens = numOutputTokens + (generation_config.max_new_tokens ?? Infinity);

// Only use max length if max_new_tokens is not provided
const useMaxLength = Number.isInteger(generation_config.max_length) && (generation_config.max_new_tokens ?? null) === null;
let sampler = Sampler.getSampler(generation_config);

// @ts-ignore
@@ -859,11 +881,16 @@
let newest_beams = [];
for (let beam of beams) {
if (beam.done) {
// TODO add length penalty (for ending early)
// Add this beam back into the pool
newest_beams.push(beam);
continue
}
if (useMaxLength && beam.output_token_ids.length >= generation_config.max_length) {
// Set this beam to done and add it back into the pool
beam.done = true;
newest_beams.push(beam);
continue
}

// @ts-ignore
let output = await this.runBeam(beam);
@@ -899,7 +926,7 @@
// Next, we get the best beams, per ID
newest_beams = this.groupBeams(newest_beams).map(
group => group
.sort((a, b) => b.score - a.score) // sort based on score
.sort((a, b) => b.score - a.score) // sort by score
.slice(0, generation_config.num_beams) // remove outside beam width
);

@@ -922,7 +949,7 @@
return [batch[0].output_token_ids];
}
}
)
).flat(); // Flatten across batches (depth=1)
}

/**
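Since `.generate()` now flattens across batches itself, it resolves to `number[][]` for all model types, and callers no longer need their own `.flat()`. A minimal usage sketch, assuming `model` and `tokenizer` are already loaded:

```javascript
// Sketch: one number[] of generated token ids per input sequence.
let outputTokenIds = await model.generate(input_ids, { max_new_tokens: 50 });
let texts = tokenizer.batch_decode(outputTokenIds, { skip_special_tokens: true });
```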
122 changes: 91 additions & 31 deletions src/pipelines.js
@@ -41,6 +41,7 @@ import {
Callable,
isString,
dispatchCallback,
pop,
} from './utils/core.js';
import {
softmax,
@@ -51,7 +52,6 @@
read_audio
} from './utils/audio.js';
import {
Tensor,
mean_pooling,
} from './utils/tensor.js';
import { RawImage } from './utils/image.js';
@@ -400,7 +400,7 @@ export class Text2TextGenerationPipeline extends Pipeline {
input_ids = this.tokenizer(texts, tokenizer_options).input_ids;
}

let outputTokenIds = (await this.model.generate(input_ids, generate_kwargs)).flat();
let outputTokenIds = await this.model.generate(input_ids, generate_kwargs);

/**
* @type {any[]}
@@ -461,26 +461,23 @@ export class TextGenerationPipeline extends Pipeline {
let input_ids = inputs.input_ids;
let attention_mask = inputs.attention_mask;

/**
* @type {any[]}
*/
let outputTokenIds = await this.model.generate(input_ids, generate_kwargs, null, {
inputs_attention_mask: attention_mask
});

let toReturn = outputTokenIds.map((outTokens, i) => {
let startText = texts[i].trim();
let decoded = this.tokenizer.batch_decode(outTokens, {
skip_special_tokens: true,
}).map(x => {
return {
generated_text: startText + x
}
});

return decoded
const trimmedTexts = texts.map(x => x.trim());
const decoded = this.tokenizer.batch_decode(outputTokenIds, {
skip_special_tokens: true,
});
const toReturn = Array.from({ length: texts.length }, _ => []);
for (let i = 0; i < decoded.length; ++i) {
const textIndex = Math.floor(i / outputTokenIds.length * trimmedTexts.length);
let startText = trimmedTexts[textIndex];

toReturn[textIndex].push({
generated_text: startText + decoded[i]
});
}
return (stringInput && toReturn.length === 1) ? toReturn[0] : toReturn;
}
}
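
The regrouping above relies on `batch_decode` returning sequences in input order: with `n` input texts and `m` generated sequences, output `i` is assigned to input `Math.floor(i / m * n)`. A small worked sketch of that mapping (numbers illustrative):

```javascript
// Sketch: 2 inputs, 4 generated sequences (e.g. num_return_sequences = 2).
const numTexts = 2;
const numOutputs = 4;
for (let i = 0; i < numOutputs; ++i) {
    console.log(`output ${i} -> input ${Math.floor(i / numOutputs * numTexts)}`);
}
// output 0 -> input 0
// output 1 -> input 0
// output 2 -> input 1
// output 3 -> input 1
```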
@@ -691,6 +688,44 @@ export class FeatureExtractionPipeline extends Pipeline {

/**
* Pipeline that aims at extracting spoken text contained within some audio.
*
* **Example:** Transcribe English.
* ```javascript
* let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
* let transcriber = await pipeline('automatic-speech-recognition', 'Xenova/whisper-tiny.en');
* let output = await transcriber(url);
* // { text: " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country." }
* ```
*
* **Example:** Transcribe English w/ timestamps.
* ```javascript
* let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
* let transcriber = await pipeline('automatic-speech-recognition', 'Xenova/whisper-tiny.en');
* let output = await transcriber(url, { return_timestamps: true });
* // {
* // text: " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country."
* // chunks: [
* // { timestamp: [0, 8], text: " And so my fellow Americans ask not what your country can do for you" }
* // { timestamp: [8, 11], text: " ask what you can do for your country." }
* // ]
* // }
* ```
*
* **Example:** Transcribe French.
* ```javascript
* let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/french-audio.mp3';
* let transcriber = await pipeline('automatic-speech-recognition', 'Xenova/whisper-small');
* let output = await transcriber(url, { language: 'french', task: 'transcribe' });
* // { text: " J'adore, j'aime, je n'aime pas, je déteste." }
* ```
*
* **Example:** Translate French to English.
* ```javascript
* let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/french-audio.mp3';
* let transcriber = await pipeline('automatic-speech-recognition', 'Xenova/whisper-small');
* let output = await transcriber(url, { language: 'french', task: 'translate' });
* // { text: " I love, I like, I don't like, I hate." }
* ```
* @extends Pipeline
*/
export class AutomaticSpeechRecognitionPipeline extends Pipeline {
@@ -711,7 +746,7 @@ export class AutomaticSpeechRecognitionPipeline extends Pipeline {
* Preprocesses the input audio for the AutomaticSpeechRecognitionPipeline.
* @param {any} audio The audio to be preprocessed.
* @param {number} sampling_rate The sampling rate of the audio.
* @returns {Promise<string | ArrayBuffer>} A promise that resolves to the preprocessed audio data.
* @returns {Promise<Float32Array>} A promise that resolves to the preprocessed audio data.
* @private
*/
async _preprocess(audio, sampling_rate) {
@@ -722,16 +757,28 @@
return audio;
}

/**
* @typedef {import('./utils/tensor.js').Tensor} Tensor
* @typedef {{stride: number[], input_features: Tensor, is_last: boolean, tokens?: number[]}} Chunk
*
* @callback ChunkCallback
* @param {Chunk} chunk The chunk to process.
*/

/**
* Asynchronously processes audio and generates text transcription using the model.
* @param {Array} audio The audio to be transcribed. Can be a single Float32Array or an array of Float32Arrays.
* @param {Float32Array|Float32Array[]} audio The audio to be transcribed. Can be a single Float32Array or an array of Float32Arrays.
* @param {Object} [kwargs={}] Optional arguments.
* @param {boolean} [kwargs.return_timestamps] Whether to return timestamps or not. Default is false.
* @param {boolean} [kwargs.return_timestamps] Whether to return timestamps or not. Default is `false`.
* @param {number} [kwargs.chunk_length_s] The length of audio chunks to process in seconds. Default is 0 (no chunking).
* @param {number} [kwargs.stride_length_s] The length of overlap between consecutive audio chunks in seconds. If not provided, defaults to chunk_length_s / 6.
* @param {function} [kwargs.chunk_callback] Callback function to be called with each chunk processed.
* @param {boolean} [kwargs.force_full_sequences] Whether to force outputting full sequences or not. Default is false.
* @returns {Promise<Object>} A Promise that resolves to an object containing the transcription text and optionally timestamps if return_timestamps is true.
* @param {number} [kwargs.stride_length_s] The length of overlap between consecutive audio chunks in seconds. If not provided, defaults to `chunk_length_s / 6`.
* @param {ChunkCallback} [kwargs.chunk_callback] Callback function to be called with each chunk processed.
* @param {boolean} [kwargs.force_full_sequences] Whether to force outputting full sequences or not. Default is `false`.
* @param {string} [kwargs.language] The source language. Default is `null`, meaning it should be auto-detected. Use this to potentially improve performance if the source language is known.
* @param {string} [kwargs.task] The task to perform. Default is `null`, meaning it should be auto-detected.
* @param {number[][]} [kwargs.forced_decoder_ids] A list of pairs of integers which indicates a mapping from generation indices to token indices
* that will be forced before sampling. For example, [[1, 123]] means the second generated token will always be a token of index 123.
* @returns {Promise<Object>} A Promise that resolves to an object containing the transcription text and optionally timestamps if `return_timestamps` is `true`.
*/
async _call(audio, kwargs = {}) {
let return_timestamps = kwargs.return_timestamps ?? false;
@@ -740,13 +787,25 @@ export class AutomaticSpeechRecognitionPipeline extends Pipeline {
let chunk_callback = kwargs.chunk_callback ?? null;
let force_full_sequences = kwargs.force_full_sequences ?? false;

// TODO
// task = 'transcribe',
// language = 'en',
let language = pop(kwargs, 'language', null);
let task = pop(kwargs, 'task', null);

if (language || task || return_timestamps) {
if (kwargs.forced_decoder_ids) {
throw new Error("Cannot specify `language`/`task`/`return_timestamps` and `forced_decoder_ids` at the same time.")
}
// @ts-ignore
let decoder_prompt_ids = this.tokenizer.get_decoder_prompt_ids({ language, task, no_timestamps: !return_timestamps })

let single = !Array.isArray(audio)
if(decoder_prompt_ids.length > 0){
kwargs.forced_decoder_ids = decoder_prompt_ids;
}
}

let single = !Array.isArray(audio);
if (single) {
audio = [audio]
// @ts-ignore
audio = [audio];
}

const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
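
`get_decoder_prompt_ids` converts the high-level `language`/`task`/`return_timestamps` options into Whisper's `forced_decoder_ids` format: pairs of `[generation index, token id]`. A sketch of its output — the token ids below are illustrative of the multilingual vocabulary, not taken from this diff:

```javascript
// Sketch: positions 1-3 follow the <|startoftranscript|> token.
let decoder_prompt_ids = tokenizer.get_decoder_prompt_ids({
    language: 'french',
    task: 'transcribe',
    no_timestamps: true,
});
// e.g. [[1, 50265], [2, 50359], [3, 50363]]
//      (<|fr|>, <|transcribe|>, <|notimestamps|>)
```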
@@ -756,7 +815,7 @@
for (let aud of audio) {
aud = await this._preprocess(aud, sampling_rate)

/** @type {any[]} */
/** @type {Chunk[]} */
let chunks = [];
if (chunk_length_s > 0) {
if (stride_length_s === null) {
@@ -806,7 +865,7 @@
let data = await this.model.generate(chunk.input_features, kwargs);

// Get top beam
chunk.tokens = data[0].flat()
chunk.tokens = data[0];

// convert stride to seconds
chunk.stride = chunk.stride.map(x => x / sampling_rate);
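
The chunking itself follows the usual overlap scheme: each window is `chunk_length_s` of audio, neighbouring windows overlap by `stride_length_s` on each side (defaulting to `chunk_length_s / 6`, per the docs above), and strides are converted from samples to seconds before the decoded chunks are merged. A sketch of the window arithmetic, assuming 16 kHz audio and illustrative parameters:

```javascript
// Sketch: compute the windows fed to the model for a 60-second clip.
const samplingRate = 16000;
const audio = new Float32Array(60 * samplingRate);
const chunk = 30 * samplingRate; // chunk_length_s = 30
const stride = 5 * samplingRate; // stride_length_s = chunk_length_s / 6
const jump = chunk - 2 * stride; // new audio covered per window
for (let offset = 0; offset < audio.length; offset += jump) {
    // The first/last `stride` samples of each window overlap its neighbours
    // and are trimmed when the decoded chunks are merged.
    console.log([offset / samplingRate, Math.min(offset + chunk, audio.length) / samplingRate]);
}
// [0, 30], [20, 50], [40, 60]
```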
@@ -863,7 +922,7 @@ export class ImageToTextPipeline extends Pipeline {
let toReturn = [];
for (let batch of pixel_values) {
batch.dims = [1, ...batch.dims]
let output = (await this.model.generate(batch, generate_kwargs)).flat();
let output = await this.model.generate(batch, generate_kwargs);
let decoded = this.tokenizer.batch_decode(output, {
skip_special_tokens: true,
}).map(x => {
@@ -1367,6 +1426,7 @@ const TASK_ALIASES = {
"sentiment-analysis": "text-classification",
"ner": "token-classification",
"vqa": "visual-question-answering",
"asr": "automatic-speech-recognition",

// Add for backwards compatibility
"embeddings": "feature-extraction",
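With the new alias in place, the speech pipeline can be constructed with the short task name:

```javascript
// 'asr' now resolves to 'automatic-speech-recognition'.
let transcriber = await pipeline('asr', 'Xenova/whisper-tiny.en');
```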
2 changes: 1 addition & 1 deletion src/processors.js
@@ -7,7 +7,7 @@
* import { AutoProcessor, read_audio } from '@xenova/transformers';
*
* let processor = await AutoProcessor.from_pretrained('openai/whisper-tiny.en');
* let audio = await read_audio('https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac');
* let audio = await read_audio('https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac', 16000);
* let { input_features } = await processor(audio);
* // Tensor {
* // data: Float32Array(240000) [0.4752984642982483, 0.5597258806228638, 0.56434166431427, ...],
(Diffs for the remaining 4 changed files are not shown.)
