Skip to content

Commit

Permalink
Unpin numba (#23162)
Browse files Browse the repository at this point in the history
* fix for ragged list

* unpin numba

* make style

* np.object -> object

* propagate changes to tokenizer as well

* np.long -> "long"

* revert tokenization changes

* check with tokenization changes

* list/tuple logic

* catch numpy

* catch else case

* clean up

* up

* better check

* trigger ci

* Empty commit to trigger CI
  • Loading branch information
sanchit-gandhi authored May 31, 2023
1 parent d99f11e commit 8f915c4
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 10 deletions.
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@
"librosa",
"nltk",
"natten>=0.14.6",
"numba<0.57.0", # Can be removed once unpinned.
"numpy>=1.17",
"onnxconverter-common",
"onnxruntime-tools>=1.4.2",
Expand Down Expand Up @@ -286,8 +285,7 @@ def run(self):
extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]

extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
# numba can be removed here once unpinned
extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm", "numba")
extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm")
# `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
extras["speech"] = deps_list("torchaudio") + extras["audio"]
extras["torch-speech"] = deps_list("torchaudio") + extras["audio"]
Expand Down
1 change: 0 additions & 1 deletion src/transformers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
"librosa": "librosa",
"nltk": "nltk",
"natten": "natten>=0.14.6",
"numba": "numba<0.57.0",
"numpy": "numpy>=1.17",
"onnxconverter-common": "onnxconverter-common",
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
Expand Down
10 changes: 9 additions & 1 deletion src/transformers/feature_extraction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,15 @@ def as_tensor(value):
as_tensor = jnp.array
is_tensor = is_jax_tensor
else:
as_tensor = np.asarray

def as_tensor(value, dtype=None):
if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
value_lens = [len(val) for val in value]
if len(set(value_lens)) > 1 and dtype is None:
# we have a ragged list so handle explicitly
value = as_tensor([np.asarray(val) for val in value], dtype=object)
return np.asarray(value, dtype=dtype)

is_tensor = is_numpy_array

# Do the tensor conversion in batch
Expand Down
10 changes: 9 additions & 1 deletion src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,15 @@ def convert_to_tensors(
as_tensor = jnp.array
is_tensor = is_jax_tensor
else:
as_tensor = np.asarray

def as_tensor(value, dtype=None):
if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
value_lens = [len(val) for val in value]
if len(set(value_lens)) > 1 and dtype is None:
# we have a ragged list so handle explicitly
value = as_tensor([np.asarray(val) for val in value], dtype=object)
return np.asarray(value, dtype=dtype)

is_tensor = is_numpy_array

# Do the tensor conversion in batch
Expand Down
2 changes: 1 addition & 1 deletion tests/models/realm/test_modeling_realm.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def test_training(self):
b"This is the fourth record.",
b"This is the fifth record.",
],
dtype=np.object,
dtype=object,
)
retriever = RealmRetriever(block_records, tokenizer)
model = RealmForOpenQA(openqa_config, retriever)
Expand Down
6 changes: 3 additions & 3 deletions tests/models/realm/test_retrieval_realm.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_dummy_block_records(self):
b"This is the fifth record",
b"This is a longer longer longer record",
],
dtype=np.object,
dtype=object,
)
return block_records

Expand All @@ -116,7 +116,7 @@ def test_retrieve(self):
retriever = self.get_dummy_retriever()
tokenizer = retriever.tokenizer

retrieved_block_ids = np.array([0, 3], dtype=np.long)
retrieved_block_ids = np.array([0, 3], dtype="long")
question_input_ids = tokenizer(["Test question"]).input_ids
answer_ids = tokenizer(
["the fourth"],
Expand Down Expand Up @@ -151,7 +151,7 @@ def test_block_has_answer(self):
retriever = self.get_dummy_retriever()
tokenizer = retriever.tokenizer

retrieved_block_ids = np.array([0, 3, 5], dtype=np.long)
retrieved_block_ids = np.array([0, 3, 5], dtype="long")
question_input_ids = tokenizer(["Test question"]).input_ids
answer_ids = tokenizer(
["the fourth", "longer longer"],
Expand Down

0 comments on commit 8f915c4

Please sign in to comment.