Skip to content

Commit

Permalink
fix prompt strip to support tensors and np arrays (#27818)
Browse files Browse the repository at this point in the history
* fix prompt strip to support tensors and np arrays

* framework agnostic

* change logic check before converting prompt into list

Co-authored-by: Sanchit Gandhi <[email protected]>

* adding _convert_to_list to tokenization_whisper_fast

* adding tests for prompt decoding

* adding comment

Co-authored-by: Sanchit Gandhi <[email protected]>

* adding comment

Co-authored-by: Sanchit Gandhi <[email protected]>

* revert minor

* make style formatting

* style formatting after update

* Update src/transformers/models/whisper/tokenization_whisper_fast.py

Co-authored-by: Sanchit Gandhi <[email protected]>

* fixing _strip_prompt to handle _decode_with_timestamps

* fix copies

---------

Co-authored-by: Sanchit Gandhi <[email protected]>
  • Loading branch information
AvivSham and sanchit-gandhi authored Jul 12, 2024
1 parent d1a1bcf commit 7f79a97
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 6 deletions.
23 changes: 20 additions & 3 deletions src/transformers/models/whisper/tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,9 +851,16 @@ def get_prompt_ids(self, text: str, return_tensors="np"):
batch_encoding.convert_to_tensors(tensor_type=return_tensors)
return batch_encoding["input_ids"]

@staticmethod
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
if not isinstance(token_ids, list):
token_ids = self._convert_to_list(token_ids)

# handle case of empty token_ids for decoding with timestamps.
# at this point token_ids is a list, so it is safe to use if not check.
if not token_ids:
return token_ids

has_prompt = token_ids[0] == prompt_token_id
if has_prompt:
if decoder_start_token_id in token_ids:
return token_ids[token_ids.index(decoder_start_token_id) :]
Expand All @@ -862,6 +869,16 @@ def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_toke

return token_ids

@staticmethod
def _convert_to_list(token_ids):
# convert type to ndarray if necessary
if "torch" in str(type(token_ids)) or "tensorflow" in str(type(token_ids)) and hasattr(token_ids, "numpy"):
token_ids = token_ids.numpy()
# now the token ids are either a numpy array, or a list of lists
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()
return token_ids


def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
"""
Expand Down
24 changes: 21 additions & 3 deletions src/transformers/models/whisper/tokenization_whisper_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,14 +582,32 @@ def get_prompt_ids(self, text: str, return_tensors="np"):
batch_encoding.convert_to_tensors(tensor_type=return_tensors)
return batch_encoding["input_ids"]

@staticmethod
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._strip_prompt
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
if not isinstance(token_ids, list):
token_ids = self._convert_to_list(token_ids)

# handle case of empty token_ids for decoding with timestamps.
# at this point token_ids is a list, so it is safe to use if not check.
if not token_ids:
return token_ids

has_prompt = token_ids[0] == prompt_token_id
if has_prompt:
if decoder_start_token_id in token_ids:
return token_ids[token_ids.index(decoder_start_token_id) :]
else:
return []

return token_ids

@staticmethod
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._convert_to_list
def _convert_to_list(token_ids):
# convert type to ndarray if necessary
if "torch" in str(type(token_ids)) or "tensorflow" in str(type(token_ids)) and hasattr(token_ids, "numpy"):
token_ids = token_ids.numpy()
# now the token ids are either a numpy array, or a list of lists
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()
return token_ids
35 changes: 35 additions & 0 deletions tests/models/whisper/test_tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import unittest

import numpy as np

from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
from transformers.testing_utils import slow
Expand Down Expand Up @@ -251,6 +253,39 @@ def test_fast_tokenizer_get_prompt_ids(self):

self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())

def test_tokenizer_decode_prompt(self):
prompt_text = "What does the fox say?"
input_text = "Hatee hatee hatee ho"

tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()

# encode prompt and input text using tokenizer
prompt_ids = tokenizer.get_prompt_ids(prompt_text, return_tensors="np")
input_ids = tokenizer(input_text, return_tensors="np").input_ids[0]
input_ids = np.hstack([prompt_ids, input_ids])

# encode using fast tokenizer
rust_prompt_ids = rust_tokenizer.get_prompt_ids(prompt_text, return_tensors="np")
rust_input_ids = rust_tokenizer(input_text, return_tensors="np").input_ids[0]
rust_input_ids = np.hstack([rust_prompt_ids, rust_input_ids])

# check with prompt in output
pred_text = tokenizer.decode(input_ids, skip_special_tokens=False)
rust_pred_text = rust_tokenizer.decode(rust_input_ids, skip_special_tokens=False)

# check correctness for both tokenizers
expected_text = f"<|startofprev|> {prompt_text}<|startoftranscript|><|notimestamps|>{input_text}<|endoftext|>"
self.assertEqual(pred_text.strip(), expected_text)
self.assertEqual(rust_pred_text.strip(), expected_text)

# check stripping prompt from output
pred_text = tokenizer.decode(input_ids, skip_special_tokens=True)
rust_pred_text = tokenizer.decode(input_ids, skip_special_tokens=True)

self.assertEqual(pred_text.strip(), input_text)
self.assertEqual(rust_pred_text.strip(), input_text)

def test_combine_tokens_into_words(self):
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
Expand Down

0 comments on commit 7f79a97

Please sign in to comment.