Skip to content

Commit

Permalink
add support for copy=True as well
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche committed Apr 11, 2024
1 parent 5b87fd3 commit 0e2b402
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
7 changes: 6 additions & 1 deletion python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1530,10 +1530,15 @@ cdef class Array(_PandasConvertible):
# values is already a numpy array at this point, but calling np.array(..)
# again to handle the `dtype` keyword with a no-copy guarantee
return np.array(values, dtype=dtype, copy=False)

values = self.to_numpy(zero_copy_only=False)
if copy is True and is_primitive(self.type.id) and self.null_count == 0:
# to_numpy did not yet make a copy
return np.array(values, dtype=dtype, copy=True)

if dtype is None:
return values
return values.astype(dtype)
return np.asarray(values, dtype=dtype)

def to_numpy(self, zero_copy_only=True, writable=False):
"""
Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
c_string ToString()

c_bool is_primitive(Type type)
c_bool is_numeric(Type type)

cdef cppclass CArrayData" arrow::ArrayData":
shared_ptr[CDataType] type
Expand Down
9 changes: 9 additions & 0 deletions python/pyarrow/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3339,6 +3339,15 @@ def test_numpy_array_protocol():
with pytest.raises(ValueError):
np.array(arr, dtype="float64", copy=False)

# copy=True -> not yet passed by numpy, so we have to call this directly to test
arr = pa.array([1, 2, 3])
result = arr.__array__(copy=True)
assert result.flags.writeable

arr = pa.array([1, 2, 3])
result = arr.__array__(dtype=np.dtype("float64"), copy=True)
assert result.dtype == "float64"


def test_array_protocol():

Expand Down

0 comments on commit 0e2b402

Please sign in to comment.