GH-21761: [Python] accept pyarrow scalars in array constructor #36162

Merged
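
A minimal sketch of the behavior this PR adds, based on the tests included below: pa.array() now accepts a sequence of pyarrow scalars, inferring the result type from the scalars themselves, and an explicitly requested type is honored by appending the unwrapped scalars to the corresponding builder.

import pyarrow as pa

# Type is inferred from the pyarrow scalars; None is treated as null.
arr = pa.array([pa.scalar(1), pa.scalar(2), None])
assert arr.type == pa.int64()
assert arr.to_pylist() == [1, 2, None]

# An explicit type may also be passed, as long as it matches the scalars.
arr = pa.array([pa.scalar("a"), pa.scalar("b")], type=pa.string())
assert arr.type == pa.string()
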
25 changes: 25 additions & 0 deletions python/pyarrow/src/arrow/python/inference.cc
@@ -27,6 +27,7 @@
#include <utility>
#include <vector>

#include "arrow/scalar.h"
#include "arrow/status.h"
#include "arrow/util/decimal.h"
#include "arrow/util/logging.h"
@@ -340,6 +341,7 @@ class TypeInferrer {
decimal_count_(0),
list_count_(0),
struct_count_(0),
arrow_scalar_count_(0),
numpy_dtype_count_(0),
interval_count_(0),
max_decimal_metadata_(std::numeric_limits<int32_t>::min(),
@@ -391,6 +393,8 @@
} else if (PyUnicode_Check(obj)) {
++unicode_count_;
*keep_going = make_unions_;
} else if (arrow::py::is_scalar(obj)) {
RETURN_NOT_OK(VisitArrowScalar(obj, keep_going));
} else if (PyArray_CheckAnyScalarExact(obj)) {
RETURN_NOT_OK(VisitDType(PyArray_DescrFromScalar(obj), keep_going));
} else if (PySet_Check(obj) || (Py_TYPE(obj) == &PyDictValues_Type)) {
@@ -455,6 +459,12 @@

RETURN_NOT_OK(Validate());

if (arrow_scalar_count_ > 0 && arrow_scalar_count_ + none_count_ != total_count_) {
return Status::Invalid(
"pyarrow scalars cannot be mixed "
"with other Python scalar values currently");
}

if (numpy_dtype_count_ > 0) {
// All NumPy scalars and Nones/nulls
if (numpy_dtype_count_ + none_count_ == total_count_) {
@@ -534,6 +544,8 @@
*out = utf8();
} else if (interval_count_) {
*out = month_day_nano_interval();
} else if (arrow_scalar_count_) {
*out = scalar_type_;
} else {
*out = null();
}
@@ -560,6 +572,17 @@
return Status::OK();
}

Status VisitArrowScalar(PyObject* obj, bool* keep_going /* unused */) {
ARROW_ASSIGN_OR_RAISE(auto scalar, arrow::py::unwrap_scalar(obj));
// Check that all the scalar types for the sequence are the same
if (arrow_scalar_count_ > 0 && *scalar->type != *scalar_type_) {
return internal::InvalidValue(obj, "cannot mix scalars with different types");
}
scalar_type_ = scalar->type;
++arrow_scalar_count_;
return Status::OK();
}

Status VisitDType(PyArray_Descr* dtype, bool* keep_going) {
// Continue visiting dtypes for now.
// TODO(wesm): devise approach for unions
@@ -675,10 +698,12 @@ class TypeInferrer {
int64_t decimal_count_;
int64_t list_count_;
int64_t struct_count_;
int64_t arrow_scalar_count_;
int64_t numpy_dtype_count_;
int64_t interval_count_;
std::unique_ptr<TypeInferrer> list_inferrer_;
std::map<std::string, TypeInferrer> struct_inferrers_;
std::shared_ptr<DataType> scalar_type_;

// If we observe a strongly-typed value in e.g. a NumPy array, we can store
// it here to skip the type counting logic above
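
The inference changes above count pyarrow scalars separately: all pyarrow scalars in a sequence must carry the same type, and they cannot currently be mixed with other Python scalar values (None/null is the only exception). A short sketch of the resulting behavior, mirroring the error messages introduced in this file:

import pyarrow as pa

# OK: all pyarrow scalars share the int64 type; nulls are allowed.
pa.array([pa.scalar(1), pa.scalar(None, pa.int64()), None])

# ArrowInvalid: "cannot mix scalars with different types"
# pa.array([pa.scalar(1), pa.scalar("a")])

# ArrowInvalid: "pyarrow scalars cannot be mixed with other Python scalar values currently"
# pa.array([1, pa.scalar(2)])
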
39 changes: 39 additions & 0 deletions python/pyarrow/src/arrow/python/python_to_arrow.cc
@@ -28,6 +28,7 @@
#include <vector>

#include "arrow/array.h"
#include "arrow/array/builder_base.h"
#include "arrow/array/builder_binary.h"
#include "arrow/array/builder_decimal.h"
#include "arrow/array/builder_dict.h"
@@ -36,6 +37,7 @@
#include "arrow/array/builder_time.h"
#include "arrow/chunked_array.h"
#include "arrow/result.h"
#include "arrow/scalar.h"
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
@@ -599,6 +601,15 @@ class PyPrimitiveConverter<T, enable_if_null<T>>
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
return this->primitive_builder_->AppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
if (scalar->is_valid) {
return Status::Invalid("Cannot append scalar of type ", scalar->type->ToString(),
" to builder for type null");
} else {
return this->primitive_builder_->AppendNull();
}
} else {
ARROW_ASSIGN_OR_RAISE(
auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
@@ -620,6 +631,10 @@ class PyPrimitiveConverter<
// rely on the Unsafe builder API which improves the performance.
if (PyValue::IsNull(this->options_, value)) {
this->primitive_builder_->UnsafeAppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
} else {
ARROW_ASSIGN_OR_RAISE(
auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
@@ -637,6 +652,10 @@
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
this->primitive_builder_->UnsafeAppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
} else {
ARROW_ASSIGN_OR_RAISE(
auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
@@ -659,6 +678,10 @@ class PyPrimitiveConverter<T, enable_if_t<std::is_same<T, FixedSizeBinaryType>::
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
this->primitive_builder_->UnsafeAppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->primitive_type_, this->options_, value, view_));
@@ -681,6 +704,10 @@ class PyPrimitiveConverter<T, enable_if_base_binary<T>>
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
this->primitive_builder_->UnsafeAppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->primitive_type_, this->options_, value, view_));
@@ -721,6 +748,10 @@ class PyDictionaryConverter<U, enable_if_has_c_type<U>>
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
return this->value_builder_->AppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->value_builder_->AppendScalar(*scalar, 1);
} else {
ARROW_ASSIGN_OR_RAISE(auto converted,
PyValue::Convert(this->value_type_, this->options_, value));
@@ -736,6 +767,10 @@ class PyDictionaryConverter<U, enable_if_has_string_view<U>>
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
return this->value_builder_->AppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->value_builder_->AppendScalar(*scalar, 1);
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->value_type_, this->options_, value, view_));
@@ -884,6 +919,10 @@ class PyStructConverter : public StructConverter<PyConverter, PyConverterTrait>
Status Append(PyObject* value) override {
if (PyValue::IsNull(this->options_, value)) {
return this->struct_builder_->AppendNull();
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->struct_builder_->AppendScalar(*scalar);
}
switch (input_kind_) {
case InputKind::DICT:
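
The converter changes above unwrap a pyarrow scalar and append it through the builder's AppendScalar path, so when an explicit type is requested the scalar must match it. A brief sketch, following the new tests below:

import pyarrow as pa

# Scalars are appended directly to the int64 builder.
pa.array([pa.scalar(1), None], type=pa.int64())

# ArrowInvalid: "Cannot append scalar of type string to builder for type int32"
# pa.array([pa.scalar("a")], type=pa.int32())
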
138 changes: 138 additions & 0 deletions python/pyarrow/tests/test_convert_builtin.py
@@ -2363,3 +2363,141 @@ def test_array_from_pylist_offset_overflow():
assert isinstance(arr, pa.ChunkedArray)
assert len(arr) == 2**31
assert len(arr.chunks) > 1


@parametrize_with_collections_types
@pytest.mark.parametrize(('data', 'scalar_data', 'value_type'), [
([True, False, None], [pa.scalar(True), pa.scalar(False), None], pa.bool_()),
(
[1, 2, None],
[pa.scalar(1), pa.scalar(2), pa.scalar(None, pa.int64())],
pa.int64()
),
([1, None, None], [pa.scalar(1), None, pa.scalar(None, pa.int64())], pa.int64()),
([None, None], [pa.scalar(None), pa.scalar(None)], pa.null()),
([1., 2., None], [pa.scalar(1.), pa.scalar(2.), None], pa.float64()),
(
[None, datetime.date.today()],
[None, pa.scalar(datetime.date.today())],
pa.date32()
),
(
[None, datetime.date.today()],
[None, pa.scalar(datetime.date.today(), pa.date64())],
pa.date64()
),
(
[datetime.time(1, 1, 1), None],
[pa.scalar(datetime.time(1, 1, 1)), None],
pa.time64('us')
),
(
[datetime.timedelta(seconds=10)],
[pa.scalar(datetime.timedelta(seconds=10))],
pa.duration('us')
),
(
[None, datetime.datetime(2014, 1, 1)],
[None, pa.scalar(datetime.datetime(2014, 1, 1))],
pa.timestamp('us')
),
(
[pa.MonthDayNano([1, -1, -10100])],
[pa.scalar(pa.MonthDayNano([1, -1, -10100]))],
pa.month_day_nano_interval()
),
(["a", "b"], [pa.scalar("a"), pa.scalar("b")], pa.string()),
([b"a", b"b"], [pa.scalar(b"a"), pa.scalar(b"b")], pa.binary()),
(
[b"a", b"b"],
[pa.scalar(b"a", pa.binary(1)), pa.scalar(b"b", pa.binary(1))],
pa.binary(1)
),
([[1, 2, 3]], [pa.scalar([1, 2, 3])], pa.list_(pa.int64())),
([["a", "b"]], [pa.scalar(["a", "b"])], pa.list_(pa.string())),
(
[1, 2, None],
[pa.scalar(1, type=pa.int8()), pa.scalar(2, type=pa.int8()), None],
pa.int8()
),
([1, None], [pa.scalar(1.0, type=pa.int32()), None], pa.int32()),
(
["aaa", "bbb"],
[pa.scalar("aaa", type=pa.binary(3)), pa.scalar("bbb", type=pa.binary(3))],
pa.binary(3)),
([b"a"], [pa.scalar("a", type=pa.large_binary())], pa.large_binary()),
(["a"], [pa.scalar("a", type=pa.large_string())], pa.large_string()),
(
["a"],
[pa.scalar("a", type=pa.dictionary(pa.int64(), pa.string()))],
pa.dictionary(pa.int64(), pa.string())
),
(
["a", "b"],
[pa.scalar("a", pa.dictionary(pa.int64(), pa.string())),
pa.scalar("b", pa.dictionary(pa.int64(), pa.string()))],
pa.dictionary(pa.int64(), pa.string())
),
(
[1],
[pa.scalar(1, type=pa.dictionary(pa.int64(), pa.int32()))],
pa.dictionary(pa.int64(), pa.int32())
),
(
[(1, 2)],
[pa.scalar([('a', 1), ('b', 2)], type=pa.struct(
[('a', pa.int8()), ('b', pa.int8())]))],
pa.struct([('a', pa.int8()), ('b', pa.int8())])
),
(
[(1, 'bar')],
[pa.scalar([('a', 1), ('b', 'bar')], type=pa.struct(
[('a', pa.int8()), ('b', pa.string())]))],
pa.struct([('a', pa.int8()), ('b', pa.string())])
)
])
def test_array_accepts_pyarrow_scalar(seq, data, scalar_data, value_type):
if type(seq(scalar_data)) == set:
pytest.skip("The elements in the set get reordered.")
expect = pa.array(data, type=value_type)
result = pa.array(seq(scalar_data))
assert expect.equals(result)

result = pa.array(seq(scalar_data), type=value_type)
assert expect.equals(result)


@parametrize_with_collections_types
def test_array_accepts_pyarrow_scalar_errors(seq):
sequence = seq([pa.scalar(1), pa.scalar("a"), pa.scalar(3.0)])
with pytest.raises(pa.ArrowInvalid,
match="cannot mix scalars with different types"):
pa.array(sequence)

sequence = seq([1, pa.scalar("a"), None])
with pytest.raises(pa.ArrowInvalid,
match="pyarrow scalars cannot be mixed with other "
"Python scalar values currently"):
pa.array(sequence)

sequence = seq([np.float16("0.1"), pa.scalar("a"), None])
with pytest.raises(pa.ArrowInvalid,
match="pyarrow scalars cannot be mixed with other "
"Python scalar values currently"):
pa.array(sequence)

sequence = seq([pa.scalar("a"), np.float16("0.1"), None])
with pytest.raises(pa.ArrowInvalid,
match="pyarrow scalars cannot be mixed with other "
"Python scalar values currently"):
pa.array(sequence)

with pytest.raises(pa.ArrowInvalid,
match="Cannot append scalar of type string "
"to builder for type int32"):
pa.array([pa.scalar("a")], type=pa.int32())

with pytest.raises(pa.ArrowInvalid,
match="Cannot append scalar of type int64 "
"to builder for type null"):
pa.array([pa.scalar(1)], type=pa.null())