Skip to content

Commit

Permalink
GH-21761: [Python] accept pyarrow scalars in array constructor (#36162)
Browse files Browse the repository at this point in the history
### Rationale for this change

Currently, `pyarrow.array` doesn't accept a list of pyarrow Scalars; this PR adds a check to allow that.

* Closes: #21761

Lead-authored-by: AlenkaF <[email protected]>
Co-authored-by: Alenka Frim <[email protected]>
Co-authored-by: Joris Van den Bossche <[email protected]>
Signed-off-by: Joris Van den Bossche <[email protected]>
  • Loading branch information
AlenkaF and jorisvandenbossche authored Jul 5, 2023
1 parent 947a446 commit b116b8a
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 0 deletions.
25 changes: 25 additions & 0 deletions python/pyarrow/src/arrow/python/inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <utility>
#include <vector>

#include "arrow/scalar.h"
#include "arrow/status.h"
#include "arrow/util/decimal.h"
#include "arrow/util/logging.h"
Expand Down Expand Up @@ -340,6 +341,7 @@ class TypeInferrer {
decimal_count_(0),
list_count_(0),
struct_count_(0),
arrow_scalar_count_(0),
numpy_dtype_count_(0),
interval_count_(0),
max_decimal_metadata_(std::numeric_limits<int32_t>::min(),
Expand Down Expand Up @@ -391,6 +393,8 @@ class TypeInferrer {
} else if (PyUnicode_Check(obj)) {
++unicode_count_;
*keep_going = make_unions_;
} else if (arrow::py::is_scalar(obj)) {
RETURN_NOT_OK(VisitArrowScalar(obj, keep_going));
} else if (PyArray_CheckAnyScalarExact(obj)) {
RETURN_NOT_OK(VisitDType(PyArray_DescrFromScalar(obj), keep_going));
} else if (PySet_Check(obj) || (Py_TYPE(obj) == &PyDictValues_Type)) {
Expand Down Expand Up @@ -455,6 +459,12 @@ class TypeInferrer {

RETURN_NOT_OK(Validate());

if (arrow_scalar_count_ > 0 && arrow_scalar_count_ + none_count_ != total_count_) {
return Status::Invalid(
"pyarrow scalars cannot be mixed "
"with other Python scalar values currently");
}

if (numpy_dtype_count_ > 0) {
// All NumPy scalars and Nones/nulls
if (numpy_dtype_count_ + none_count_ == total_count_) {
Expand Down Expand Up @@ -534,6 +544,8 @@ class TypeInferrer {
*out = utf8();
} else if (interval_count_) {
*out = month_day_nano_interval();
} else if (arrow_scalar_count_) {
*out = scalar_type_;
} else {
*out = null();
}
Expand All @@ -560,6 +572,17 @@ class TypeInferrer {
return Status::OK();
}

/// \brief Account for one pyarrow Scalar encountered during type inference.
///
/// Unwraps `obj` into an Arrow scalar, rejects it when its type differs from
/// the type stored by an earlier call, then stores its type in `scalar_type_`
/// and increments `arrow_scalar_count_`.
///
/// \param[in] obj Python object to unwrap with arrow::py::unwrap_scalar
/// \param[in] keep_going neither read nor written by this method
/// \return Status::OK() on success; otherwise the unwrap failure or the
///         status produced by internal::InvalidValue
Status VisitArrowScalar(PyObject* obj, bool* keep_going /* unused */) {
  ARROW_ASSIGN_OR_RAISE(auto unwrapped, arrow::py::unwrap_scalar(obj));
  const bool has_previous_scalar = arrow_scalar_count_ > 0;
  if (has_previous_scalar) {
    // Every scalar in the sequence has to carry the same Arrow type.
    if (*unwrapped->type != *scalar_type_) {
      return internal::InvalidValue(obj, "cannot mix scalars with different types");
    }
  }
  // The type is stored on every call, so the most recent scalar's type is kept.
  scalar_type_ = unwrapped->type;
  ++arrow_scalar_count_;
  return Status::OK();
}

Status VisitDType(PyArray_Descr* dtype, bool* keep_going) {
// Continue visiting dtypes for now.
// TODO(wesm): devise approach for unions
Expand Down Expand Up @@ -675,10 +698,12 @@ class TypeInferrer {
int64_t decimal_count_;
int64_t list_count_;
int64_t struct_count_;
int64_t arrow_scalar_count_;
int64_t numpy_dtype_count_;
int64_t interval_count_;
std::unique_ptr<TypeInferrer> list_inferrer_;
std::map<std::string, TypeInferrer> struct_inferrers_;
std::shared_ptr<DataType> scalar_type_;

// If we observe a strongly-typed value in e.g. a NumPy array, we can store
// it here to skip the type counting logic above
Expand Down
39 changes: 39 additions & 0 deletions python/pyarrow/src/arrow/python/python_to_arrow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <vector>

#include "arrow/array.h"
#include "arrow/array/builder_base.h"
#include "arrow/array/builder_binary.h"
#include "arrow/array/builder_decimal.h"
#include "arrow/array/builder_dict.h"
Expand All @@ -36,6 +37,7 @@
#include "arrow/array/builder_time.h"
#include "arrow/chunked_array.h"
#include "arrow/result.h"
#include "arrow/scalar.h"
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
Expand Down Expand Up @@ -599,6 +601,15 @@ class PyPrimitiveConverter<T, enable_if_null<T>>
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
return this->primitive_builder_->AppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
if (scalar->is_valid) {
return Status::Invalid("Cannot append scalar of type ", scalar->type->ToString(),
" to builder for type null");
} else {
return this->primitive_builder_->AppendNull();
}
} else {
ARROW_ASSIGN_OR_RAISE(
auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
Expand All @@ -620,6 +631,10 @@ class PyPrimitiveConverter<
// rely on the Unsafe builder API which improves the performance.
if (PyValue::IsNull(this->options_, value)) {
this->primitive_builder_->UnsafeAppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
} else {
ARROW_ASSIGN_OR_RAISE(
auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
Expand All @@ -637,6 +652,10 @@ class PyPrimitiveConverter<
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
this->primitive_builder_->UnsafeAppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
} else {
ARROW_ASSIGN_OR_RAISE(
auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
Expand All @@ -659,6 +678,10 @@ class PyPrimitiveConverter<T, enable_if_t<std::is_same<T, FixedSizeBinaryType>::
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
this->primitive_builder_->UnsafeAppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->primitive_type_, this->options_, value, view_));
Expand All @@ -681,6 +704,10 @@ class PyPrimitiveConverter<T, enable_if_base_binary<T>>
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
this->primitive_builder_->UnsafeAppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->primitive_type_, this->options_, value, view_));
Expand Down Expand Up @@ -721,6 +748,10 @@ class PyDictionaryConverter<U, enable_if_has_c_type<U>>
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
return this->value_builder_->AppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->value_builder_->AppendScalar(*scalar, 1);
} else {
ARROW_ASSIGN_OR_RAISE(auto converted,
PyValue::Convert(this->value_type_, this->options_, value));
Expand All @@ -736,6 +767,10 @@ class PyDictionaryConverter<U, enable_if_has_string_view<U>>
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
return this->value_builder_->AppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->value_builder_->AppendScalar(*scalar, 1);
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->value_type_, this->options_, value, view_));
Expand Down Expand Up @@ -884,6 +919,10 @@ class PyStructConverter : public StructConverter<PyConverter, PyConverterTrait>
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
return this->struct_builder_->AppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->struct_builder_->AppendScalar(*scalar);
}
switch (input_kind_) {
case InputKind::DICT:
Expand Down
138 changes: 138 additions & 0 deletions python/pyarrow/tests/test_convert_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2363,3 +2363,141 @@ def test_array_from_pylist_offset_overflow():
assert isinstance(arr, pa.ChunkedArray)
assert len(arr) == 2**31
assert len(arr.chunks) > 1


@parametrize_with_collections_types
@pytest.mark.parametrize(('data', 'scalar_data', 'value_type'), [
    ([True, False, None], [pa.scalar(True), pa.scalar(False), None], pa.bool_()),
    (
        [1, 2, None],
        [pa.scalar(1), pa.scalar(2), pa.scalar(None, pa.int64())],
        pa.int64()
    ),
    ([1, None, None], [pa.scalar(1), None, pa.scalar(None, pa.int64())], pa.int64()),
    ([None, None], [pa.scalar(None), pa.scalar(None)], pa.null()),
    ([1., 2., None], [pa.scalar(1.), pa.scalar(2.), None], pa.float64()),
    # A fixed date is used instead of date.today(): today() would be evaluated
    # separately for the expected data and for the scalar, so the two values
    # could differ if the calls straddle midnight.
    (
        [None, datetime.date(2014, 1, 1)],
        [None, pa.scalar(datetime.date(2014, 1, 1))],
        pa.date32()
    ),
    (
        [None, datetime.date(2014, 1, 1)],
        [None, pa.scalar(datetime.date(2014, 1, 1), pa.date64())],
        pa.date64()
    ),
    (
        [datetime.time(1, 1, 1), None],
        [pa.scalar(datetime.time(1, 1, 1)), None],
        pa.time64('us')
    ),
    (
        [datetime.timedelta(seconds=10)],
        [pa.scalar(datetime.timedelta(seconds=10))],
        pa.duration('us')
    ),
    (
        [None, datetime.datetime(2014, 1, 1)],
        [None, pa.scalar(datetime.datetime(2014, 1, 1))],
        pa.timestamp('us')
    ),
    (
        [pa.MonthDayNano([1, -1, -10100])],
        [pa.scalar(pa.MonthDayNano([1, -1, -10100]))],
        pa.month_day_nano_interval()
    ),
    (["a", "b"], [pa.scalar("a"), pa.scalar("b")], pa.string()),
    ([b"a", b"b"], [pa.scalar(b"a"), pa.scalar(b"b")], pa.binary()),
    (
        [b"a", b"b"],
        [pa.scalar(b"a", pa.binary(1)), pa.scalar(b"b", pa.binary(1))],
        pa.binary(1)
    ),
    ([[1, 2, 3]], [pa.scalar([1, 2, 3])], pa.list_(pa.int64())),
    ([["a", "b"]], [pa.scalar(["a", "b"])], pa.list_(pa.string())),
    (
        [1, 2, None],
        [pa.scalar(1, type=pa.int8()), pa.scalar(2, type=pa.int8()), None],
        pa.int8()
    ),
    ([1, None], [pa.scalar(1.0, type=pa.int32()), None], pa.int32()),
    (
        ["aaa", "bbb"],
        [pa.scalar("aaa", type=pa.binary(3)), pa.scalar("bbb", type=pa.binary(3))],
        pa.binary(3)),
    ([b"a"], [pa.scalar("a", type=pa.large_binary())], pa.large_binary()),
    (["a"], [pa.scalar("a", type=pa.large_string())], pa.large_string()),
    (
        ["a"],
        [pa.scalar("a", type=pa.dictionary(pa.int64(), pa.string()))],
        pa.dictionary(pa.int64(), pa.string())
    ),
    (
        ["a", "b"],
        [pa.scalar("a", pa.dictionary(pa.int64(), pa.string())),
         pa.scalar("b", pa.dictionary(pa.int64(), pa.string()))],
        pa.dictionary(pa.int64(), pa.string())
    ),
    (
        [1],
        [pa.scalar(1, type=pa.dictionary(pa.int64(), pa.int32()))],
        pa.dictionary(pa.int64(), pa.int32())
    ),
    (
        [(1, 2)],
        [pa.scalar([('a', 1), ('b', 2)], type=pa.struct(
            [('a', pa.int8()), ('b', pa.int8())]))],
        pa.struct([('a', pa.int8()), ('b', pa.int8())])
    ),
    (
        [(1, 'bar')],
        [pa.scalar([('a', 1), ('b', 'bar')], type=pa.struct(
            [('a', pa.int8()), ('b', pa.string())]))],
        pa.struct([('a', pa.int8()), ('b', pa.string())])
    )
])
def test_array_accepts_pyarrow_scalar(seq, data, scalar_data, value_type):
    """pa.array() built from pyarrow scalars matches one built from Python values.

    ``data`` holds plain Python values, ``scalar_data`` the same values wrapped
    as pyarrow scalars (or None), and ``value_type`` the expected Arrow type.
    The result is checked both with the type inferred and with it passed
    explicitly.
    """
    # Build the collection once and reuse it for both conversions.
    sequence = seq(scalar_data)
    if isinstance(sequence, set):
        pytest.skip("The elements in the set get reordered.")
    expect = pa.array(data, type=value_type)

    # Type inferred from the scalars.
    result = pa.array(sequence)
    assert expect.equals(result)

    # Type given explicitly.
    result = pa.array(sequence, type=value_type)
    assert expect.equals(result)


@parametrize_with_collections_types
def test_array_accepts_pyarrow_scalar_errors(seq):
    """pa.array() raises ArrowInvalid for unsupported pyarrow scalar inputs.

    Covers scalars of differing types, scalars mixed with Python or NumPy
    values, and scalars that do not match an explicitly requested type.
    """
    # pyarrow scalars whose types disagree with each other.
    differing_types = seq([pa.scalar(1), pa.scalar("a"), pa.scalar(3.0)])
    with pytest.raises(pa.ArrowInvalid,
                       match="cannot mix scalars with different types"):
        pa.array(differing_types)

    # pyarrow scalars combined with plain Python or NumPy values, in either order.
    mixed_values_pattern = ("pyarrow scalars cannot be mixed with other "
                            "Python scalar values currently")
    mixed_inputs = (
        [1, pa.scalar("a"), None],
        [np.float16("0.1"), pa.scalar("a"), None],
        [pa.scalar("a"), np.float16("0.1"), None],
    )
    for values in mixed_inputs:
        with pytest.raises(pa.ArrowInvalid, match=mixed_values_pattern):
            pa.array(seq(values))

    # A scalar whose type conflicts with the explicitly requested array type.
    explicit_type_cases = (
        (pa.scalar("a"), pa.int32(),
         "Cannot append scalar of type string "
         "to builder for type int32"),
        (pa.scalar(1), pa.null(),
         "Cannot append scalar of type int64 "
         "to builder for type null"),
    )
    for scalar, requested_type, pattern in explicit_type_cases:
        with pytest.raises(pa.ArrowInvalid, match=pattern):
            pa.array([scalar], type=requested_type)

0 comments on commit b116b8a

Please sign in to comment.