Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unpin numba #23162

Merged
merged 16 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@
"librosa",
"nltk",
"natten>=0.14.6",
"numba<0.57.0", # Can be removed once unpinned.
"numpy>=1.17",
"onnxconverter-common",
"onnxruntime-tools>=1.4.2",
Expand Down Expand Up @@ -286,8 +285,7 @@ def run(self):
extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]

extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
# numba can be removed here once unpinned
extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm", "numba")
extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm")
# `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
extras["speech"] = deps_list("torchaudio") + extras["audio"]
extras["torch-speech"] = deps_list("torchaudio") + extras["audio"]
Expand Down
1 change: 0 additions & 1 deletion src/transformers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
"librosa": "librosa",
"nltk": "nltk",
"natten": "natten>=0.14.6",
"numba": "numba<0.57.0",
"numpy": "numpy>=1.17",
"onnxconverter-common": "onnxconverter-common",
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
Expand Down
10 changes: 9 additions & 1 deletion src/transformers/feature_extraction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,15 @@ def as_tensor(value):
as_tensor = jnp.array
is_tensor = is_jax_tensor
else:
as_tensor = np.asarray

def as_tensor(value, dtype=None):
if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
value_lens = [len(val) for val in value]
if len(set(value_lens)) > 1 and dtype is None:
# we have a ragged list so handle explicitly
value = as_tensor([np.asarray(val) for val in value], dtype=object)
return np.asarray(value, dtype=dtype)

is_tensor = is_numpy_array

# Do the tensor conversion in batch
Expand Down
10 changes: 9 additions & 1 deletion src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,15 @@ def convert_to_tensors(
as_tensor = jnp.array
is_tensor = is_jax_tensor
else:
as_tensor = np.asarray

def as_tensor(value, dtype=None):
if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
value_lens = [len(val) for val in value]
if len(set(value_lens)) > 1 and dtype is None:
# we have a ragged list so handle explicitly
value = as_tensor([np.asarray(val) for val in value], dtype=object)
return np.asarray(value, dtype=dtype)

is_tensor = is_numpy_array

# Do the tensor conversion in batch
Expand Down
2 changes: 1 addition & 1 deletion tests/models/realm/test_modeling_realm.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def test_training(self):
b"This is the fourth record.",
b"This is the fifth record.",
],
dtype=np.object,
dtype=object,
)
retriever = RealmRetriever(block_records, tokenizer)
model = RealmForOpenQA(openqa_config, retriever)
Expand Down
6 changes: 3 additions & 3 deletions tests/models/realm/test_retrieval_realm.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_dummy_block_records(self):
b"This is the fifth record",
b"This is a longer longer longer record",
],
dtype=np.object,
dtype=object,
)
return block_records

Expand All @@ -116,7 +116,7 @@ def test_retrieve(self):
retriever = self.get_dummy_retriever()
tokenizer = retriever.tokenizer

retrieved_block_ids = np.array([0, 3], dtype=np.long)
retrieved_block_ids = np.array([0, 3], dtype="long")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np.long is also deprecated in favour of dtype="long"

question_input_ids = tokenizer(["Test question"]).input_ids
answer_ids = tokenizer(
["the fourth"],
Expand Down Expand Up @@ -151,7 +151,7 @@ def test_block_has_answer(self):
retriever = self.get_dummy_retriever()
tokenizer = retriever.tokenizer

retrieved_block_ids = np.array([0, 3, 5], dtype=np.long)
retrieved_block_ids = np.array([0, 3, 5], dtype="long")
question_input_ids = tokenizer(["Test question"]).input_ids
answer_ids = tokenizer(
["the fourth", "longer longer"],
Expand Down