diff --git a/thinc/backends/cupy_ops.py b/thinc/backends/cupy_ops.py index 366faf70a..1e1e5b92b 100644 --- a/thinc/backends/cupy_ops.py +++ b/thinc/backends/cupy_ops.py @@ -94,7 +94,7 @@ def asarray(self, data, dtype=None): elif is_mxnet_gpu_array(data): array = mxnet2xp(data) else: - array = self.xp.array(data) + array = self.xp.array(data, dtype=dtype) if dtype is not None: array = array.astype(dtype=dtype, copy=False) diff --git a/thinc/backends/numpy_ops.pyx b/thinc/backends/numpy_ops.pyx index 5ab4d0d8f..4ecad4271 100644 --- a/thinc/backends/numpy_ops.pyx +++ b/thinc/backends/numpy_ops.pyx @@ -72,7 +72,7 @@ class NumpyOps(Ops): elif hasattr(data, "get"): array = data.get() else: - array = self.xp.array(data) + array = self.xp.array(data, dtype=dtype) if dtype is not None: array = array.astype(dtype=dtype, copy=False) diff --git a/thinc/layers/strings2arrays.py b/thinc/layers/strings2arrays.py index 91a6b1a31..ed40b1e88 100644 --- a/thinc/layers/strings2arrays.py +++ b/thinc/layers/strings2arrays.py @@ -17,8 +17,10 @@ def strings2arrays() -> Model[InT, OutT]: def forward(model: Model[InT, OutT], Xs: InT, is_train: bool) -> Tuple[OutT, Callable]: - hashes = [[hash_unicode(word) for word in X] for X in Xs] - hash_arrays = [model.ops.asarray2i(h, dtype="uint64") for h in hashes] + hashes = model.ops.asarray2i( + [[hash_unicode(word) for word in X] for X in Xs], dtype="int32" + ) + hash_arrays = [model.ops.asarray1i(h, dtype="uint64") for h in hashes] arrays = [model.ops.reshape2i(array, -1, 1) for array in hash_arrays] def backprop(dX: OutT) -> InT: diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py index b867b14e4..9f03c0438 100644 --- a/thinc/tests/backends/test_ops.py +++ b/thinc/tests/backends/test_ops.py @@ -1597,3 +1597,10 @@ def test_custom_kernel_compilation(): assert compiled_kernel is not None assert compile_mmh() is not None + + +@pytest.mark.parametrize("ops", ALL_OPS) +def test_asarray_from_list_uint64(ops): + # list contains int values both above and below int64.max + uint64_list = [16, 11648197037703959513] + assert uint64_list == list(ops.asarray(uint64_list, dtype="uint64"))