Skip to content

Commit

Permalink
Merge pull request numba#5418 from guilhermeleobas/np_asfarray
Browse files Browse the repository at this point in the history
Add np.asfarray impl
  • Loading branch information
stuartarchibald authored Aug 27, 2020
2 parents 8390b49 + af17656 commit 0db523e
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ The following top-level functions are supported:
* :func:`numpy.array` (only the 2 first arguments)
* :func:`numpy.array_equal`
* :func:`numpy.asarray` (only the 2 first arguments)
* :func:`numpy.asfarray`
* :func:`numpy.asfortranarray` (only the first argument)
* :func:`numpy.atleast_1d`
* :func:`numpy.atleast_2d`
Expand Down
13 changes: 13 additions & 0 deletions numba/np/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -4070,6 +4070,19 @@ def impl(a, dtype=None):
return impl


@overload(np.asfarray)
def np_asfarray(a, dtype=np.float64):
dtype = as_dtype(dtype)
if not np.issubdtype(dtype, np.inexact):
dx = types.float64
else:
dx = dtype

def impl(a, dtype=np.float64):
return np.asarray(a, dx)
return impl


@overload(np.extract)
def np_extract(condition, arr):

Expand Down
26 changes: 26 additions & 0 deletions numba/tests/test_np_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ def asarray_kws(a, dtype):
return np.asarray(a, dtype=dtype)


def asfarray(a, dtype=np.float64):
return np.asfarray(a, dtype=dtype)


def extract(condition, arr):
return np.extract(condition, arr)

Expand Down Expand Up @@ -3208,6 +3212,28 @@ def make_unicode_list():
test_reject(make_nested_list_with_dict())
test_reject(make_unicode_list())

def test_asfarray(self):
def inputs():
yield np.array([1, 2, 3]), None
yield np.array([2, 3], dtype=np.float32), np.float32
yield np.array([2, 3], dtype=np.int8), np.int8
yield np.array([2, 3], dtype=np.int8), np.complex64
yield np.array([2, 3], dtype=np.int8), np.complex128

pyfunc = asfarray
cfunc = jit(nopython=True)(pyfunc)

for arr, dt in inputs():
if dt is None:
expected = pyfunc(arr)
got = cfunc(arr)
else:
expected = pyfunc(arr, dtype=dt)
got = cfunc(arr, dtype=dt)

self.assertPreciseEqual(expected, got)
self.assertTrue(np.issubdtype(got.dtype, np.inexact), got.dtype)

def test_repeat(self):
# np.repeat(a, repeats)
np_pyfunc = np_repeat
Expand Down

0 comments on commit 0db523e

Please sign in to comment.