-
Notifications
You must be signed in to change notification settings - Fork 3.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GH-40695 [C++] Expand Substrait type support #40696
Changes from all commits
b3cc71f
c69490b
7d0195e
99ae09f
3c87d3a
ed145bc
10aedfd
ab8237a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,7 @@ | |
#include <vector> | ||
|
||
#include <google/protobuf/descriptor.h> | ||
#include <google/protobuf/wrappers.pb.h> | ||
|
||
#include "arrow/array/array_base.h" | ||
#include "arrow/array/array_nested.h" | ||
|
@@ -58,8 +59,10 @@ | |
#include "arrow/util/decimal.h" | ||
#include "arrow/util/endian.h" | ||
#include "arrow/util/logging.h" | ||
#include "arrow/util/macros.h" | ||
#include "arrow/util/small_vector.h" | ||
#include "arrow/util/string.h" | ||
#include "arrow/util/unreachable.h" | ||
#include "arrow/visit_scalar_inline.h" | ||
|
||
namespace arrow { | ||
|
@@ -71,6 +74,9 @@ namespace engine { | |
|
||
namespace { | ||
|
||
constexpr int64_t kMicrosPerSecond = 1000000; | ||
constexpr int64_t kMicrosPerMilli = 1000; | ||
|
||
Id NormalizeFunctionName(Id id) { | ||
// Substrait plans encode the types into the function name so it might look like | ||
// add:opt_i32_i32. We don't care about the :opt_i32_i32 so we just trim it | ||
|
@@ -421,6 +427,86 @@ Result<compute::Expression> FromProto(const substrait::Expression& expr, | |
expr.DebugString()); | ||
} | ||
|
||
namespace { | ||
struct UserDefinedLiteralToArrow { | ||
Status Visit(const DataType& type) { | ||
return Status::NotImplemented("User defined literals of type ", type); | ||
} | ||
Status Visit(const IntegerType& type) { | ||
google::protobuf::UInt64Value value; | ||
if (!user_defined_->value().UnpackTo(&value)) { | ||
return Status::Invalid( | ||
"Failed to unpack user defined integer literal to UInt64Value"); | ||
} | ||
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); | ||
return Status::OK(); | ||
} | ||
Status Visit(const Time32Type& type) { | ||
google::protobuf::Int32Value value; | ||
if (!user_defined_->value().UnpackTo(&value)) { | ||
return Status::Invalid( | ||
"Failed to unpack user defined time32 literal to Int32Value"); | ||
} | ||
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); | ||
return Status::OK(); | ||
} | ||
Status Visit(const Time64Type& type) { | ||
google::protobuf::Int64Value value; | ||
if (!user_defined_->value().UnpackTo(&value)) { | ||
return Status::Invalid( | ||
"Failed to unpack user defined time64 literal to Int64Value"); | ||
} | ||
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); | ||
return Status::OK(); | ||
} | ||
Status Visit(const Date64Type& type) { | ||
google::protobuf::Int64Value value; | ||
if (!user_defined_->value().UnpackTo(&value)) { | ||
return Status::Invalid( | ||
"Failed to unpack user defined date64 literal to Int64Value"); | ||
} | ||
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); | ||
return Status::OK(); | ||
} | ||
Status Visit(const HalfFloatType& type) { | ||
google::protobuf::UInt32Value value; | ||
if (!user_defined_->value().UnpackTo(&value)) { | ||
return Status::Invalid( | ||
"Failed to unpack user defined half_float literal to UInt32Value"); | ||
} | ||
uint16_t half_float_value = value.value(); | ||
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), half_float_value)); | ||
return Status::OK(); | ||
} | ||
Status Visit(const LargeStringType& type) { | ||
google::protobuf::StringValue value; | ||
if (!user_defined_->value().UnpackTo(&value)) { | ||
return Status::Invalid( | ||
"Failed to unpack user defined large_string literal to StringValue"); | ||
} | ||
ARROW_ASSIGN_OR_RAISE(scalar_, | ||
MakeScalar(type.GetSharedPtr(), std::string(value.value()))); | ||
return Status::OK(); | ||
} | ||
Status Visit(const LargeBinaryType& type) { | ||
google::protobuf::BytesValue value; | ||
if (!user_defined_->value().UnpackTo(&value)) { | ||
return Status::Invalid( | ||
"Failed to unpack user defined large_binary literal to BytesValue"); | ||
} | ||
ARROW_ASSIGN_OR_RAISE(scalar_, | ||
MakeScalar(type.GetSharedPtr(), std::string(value.value()))); | ||
return Status::OK(); | ||
} | ||
Status operator()(const DataType& type) { return VisitTypeInline(type, this); } | ||
|
||
std::shared_ptr<Scalar> scalar_; | ||
const substrait::Expression::Literal::UserDefined* user_defined_; | ||
const ExtensionSet* ext_set_; | ||
const ConversionOptions& conversion_options_; | ||
}; | ||
} // namespace | ||
|
||
Result<Datum> FromProto(const substrait::Expression::Literal& lit, | ||
const ExtensionSet& ext_set, | ||
const ConversionOptions& conversion_options) { | ||
|
@@ -455,14 +541,25 @@ Result<Datum> FromProto(const substrait::Expression::Literal& lit, | |
case substrait::Expression::Literal::kBinary: | ||
return Datum(BinaryScalar(lit.binary())); | ||
|
||
ARROW_SUPPRESS_DEPRECATION_WARNING | ||
case substrait::Expression::Literal::kTimestamp: | ||
return Datum( | ||
TimestampScalar(static_cast<int64_t>(lit.timestamp()), TimeUnit::MICRO)); | ||
|
||
case substrait::Expression::Literal::kTimestampTz: | ||
return Datum(TimestampScalar(static_cast<int64_t>(lit.timestamp_tz()), | ||
TimeUnit::MICRO, TimestampTzTimezoneString())); | ||
|
||
ARROW_UNSUPPRESS_DEPRECATION_WARNING | ||
case substrait::Expression::Literal::kPrecisionTimestamp: { | ||
// https://github.com/substrait-io/substrait/issues/611 | ||
// TODO(GH-40741) don't break, return precision timestamp | ||
break; | ||
} | ||
case substrait::Expression::Literal::kPrecisionTimestampTz: { | ||
// https://github.com/substrait-io/substrait/issues/611 | ||
// TODO(GH-40741) don't break, return precision timestamp | ||
break; | ||
} | ||
case substrait::Expression::Literal::kDate: | ||
return Datum(Date32Scalar(lit.date())); | ||
case substrait::Expression::Literal::kTime: | ||
|
@@ -674,18 +771,30 @@ Result<Datum> FromProto(const substrait::Expression::Literal& lit, | |
return Datum(MakeNullScalar(std::move(type_nullable.first))); | ||
} | ||
|
||
case substrait::Expression::Literal::kUserDefined: { | ||
const auto& user_defined = lit.user_defined(); | ||
ARROW_ASSIGN_OR_RAISE(auto type_record, | ||
ext_set.DecodeType(user_defined.type_reference())); | ||
UserDefinedLiteralToArrow visitor{nullptr, &user_defined, &ext_set, | ||
conversion_options}; | ||
ARROW_RETURN_NOT_OK((visitor)(*type_record.type)); | ||
return Datum(std::move(visitor.scalar_)); | ||
} | ||
|
||
case substrait::Expression::Literal::LITERAL_TYPE_NOT_SET: | ||
return Status::Invalid("substrait literal did not have any literal type set"); | ||
|
||
default: | ||
break; | ||
} | ||
|
||
return Status::NotImplemented("conversion to arrow::Datum from Substrait literal ", | ||
lit.DebugString()); | ||
return Status::NotImplemented("conversion to arrow::Datum from Substrait literal `", | ||
lit.DebugString(), "`"); | ||
} | ||
|
||
namespace { | ||
struct ScalarToProtoImpl { | ||
Status Visit(const NullScalar& s) { return NotImplemented(s); } | ||
|
||
struct ScalarToProtoImpl { | ||
using Lit = substrait::Expression::Literal; | ||
|
||
template <typename Arg, typename PrimitiveScalar> | ||
|
@@ -702,61 +811,141 @@ struct ScalarToProtoImpl { | |
return Status::OK(); | ||
} | ||
|
||
Status EncodeUserDefined(const DataType& data_type, | ||
const google::protobuf::Message& value) { | ||
ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(data_type)); | ||
auto user_defined = std::make_unique<Lit::UserDefined>(); | ||
user_defined->set_type_reference(anchor); | ||
auto value_any = std::make_unique<google::protobuf::Any>(); | ||
value_any->PackFrom(value); | ||
user_defined->set_allocated_value(value_any.release()); | ||
lit_->set_allocated_user_defined(user_defined.release()); | ||
Comment on lines
+817
to
+822
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can use destination-passing style to avoid extra allocations with protobuf types: auto *user_defined = lit_->mutable_user_defined();
user_defined->set_type_reference(anchor);
value.PackTo(user_defined->mutable_value()); A change in style for the whole file really, so not a blocker for this PR specifically. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
return Status::OK(); | ||
} | ||
|
||
Status Visit(const NullScalar& s) { | ||
ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(*s.type)); | ||
auto user_defined = std::make_unique<Lit::UserDefined>(); | ||
user_defined->set_type_reference(anchor); | ||
lit_->set_allocated_user_defined(user_defined.release()); | ||
return Status::OK(); | ||
} | ||
Status Visit(const BooleanScalar& s) { return Primitive(&Lit::set_boolean, s); } | ||
|
||
Status Visit(const Int8Scalar& s) { return Primitive(&Lit::set_i8, s); } | ||
Status Visit(const Int16Scalar& s) { return Primitive(&Lit::set_i16, s); } | ||
Status Visit(const Int32Scalar& s) { return Primitive(&Lit::set_i32, s); } | ||
Status Visit(const Int64Scalar& s) { return Primitive(&Lit::set_i64, s); } | ||
|
||
Status Visit(const UInt8Scalar& s) { return NotImplemented(s); } | ||
Status Visit(const UInt16Scalar& s) { return NotImplemented(s); } | ||
Status Visit(const UInt32Scalar& s) { return NotImplemented(s); } | ||
Status Visit(const UInt64Scalar& s) { return NotImplemented(s); } | ||
|
||
Status Visit(const HalfFloatScalar& s) { return NotImplemented(s); } | ||
Status Visit(const UInt8Scalar& s) { | ||
google::protobuf::UInt64Value value; | ||
value.set_value(s.value); | ||
return EncodeUserDefined(*s.type, value); | ||
} | ||
Status Visit(const UInt16Scalar& s) { | ||
google::protobuf::UInt64Value value; | ||
value.set_value(s.value); | ||
return EncodeUserDefined(*s.type, value); | ||
} | ||
Status Visit(const UInt32Scalar& s) { | ||
google::protobuf::UInt64Value value; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In protobuf all unsigned integers are encoded the same way and there is no uint16 or uint8 so I figured it would be simplest and most consistent to just use uint64 for everything. It's something of an arbitrary decision and, even if the receiver decodes it as a uint32, it will still work. |
||
value.set_value(s.value); | ||
return EncodeUserDefined(*s.type, value); | ||
} | ||
Status Visit(const UInt64Scalar& s) { | ||
google::protobuf::UInt64Value value; | ||
value.set_value(s.value); | ||
return EncodeUserDefined(*s.type, value); | ||
} | ||
Status Visit(const HalfFloatScalar& s) { | ||
google::protobuf::UInt32Value value; | ||
value.set_value(s.value); | ||
return EncodeUserDefined(*s.type, value); | ||
} | ||
Status Visit(const FloatScalar& s) { return Primitive(&Lit::set_fp32, s); } | ||
Status Visit(const DoubleScalar& s) { return Primitive(&Lit::set_fp64, s); } | ||
|
||
Status Visit(const StringScalar& s) { | ||
return FromBuffer([](Lit* lit, std::string&& s) { lit->set_string(std::move(s)); }, | ||
s); | ||
} | ||
Status Visit(const StringViewScalar& s) { | ||
return FromBuffer([](Lit* lit, std::string&& s) { lit->set_string(std::move(s)); }, | ||
s); | ||
} | ||
Status Visit(const BinaryScalar& s) { | ||
return FromBuffer([](Lit* lit, std::string&& s) { lit->set_binary(std::move(s)); }, | ||
s); | ||
} | ||
|
||
Status Visit(const BinaryViewScalar& s) { return NotImplemented(s); } | ||
Status Visit(const BinaryViewScalar& s) { | ||
return FromBuffer([](Lit* lit, std::string&& s) { lit->set_binary(std::move(s)); }, | ||
s); | ||
} | ||
|
||
Status Visit(const FixedSizeBinaryScalar& s) { | ||
return FromBuffer( | ||
[](Lit* lit, std::string&& s) { lit->set_fixed_binary(std::move(s)); }, s); | ||
} | ||
|
||
Status Visit(const Date32Scalar& s) { return Primitive(&Lit::set_date, s); } | ||
Status Visit(const Date64Scalar& s) { return NotImplemented(s); } | ||
Status Visit(const Date64Scalar& s) { | ||
google::protobuf::Int64Value value; | ||
value.set_value(s.value); | ||
return EncodeUserDefined(*s.type, value); | ||
} | ||
|
||
Status Visit(const TimestampScalar& s) { | ||
const auto& t = checked_cast<const TimestampType&>(*s.type); | ||
|
||
if (t.unit() != TimeUnit::MICRO) return NotImplemented(s); | ||
uint64_t micros; | ||
switch (t.unit()) { | ||
case TimeUnit::SECOND: | ||
micros = s.value * kMicrosPerSecond; | ||
break; | ||
case TimeUnit::MILLI: | ||
micros = s.value * kMicrosPerMilli; | ||
break; | ||
case TimeUnit::MICRO: | ||
micros = s.value; | ||
break; | ||
case TimeUnit::NANO: | ||
// TODO(GH-40741): can support nanos when | ||
// https://github.com/substrait-io/substrait/issues/611 is resolved | ||
return NotImplemented(s); | ||
default: | ||
return NotImplemented(s); | ||
} | ||
|
||
if (t.timezone() == "") return Primitive(&Lit::set_timestamp, s); | ||
// Remove these and use precision timestamp once | ||
// https://github.com/substrait-io/substrait/issues/611 is resolved | ||
ARROW_SUPPRESS_DEPRECATION_WARNING | ||
|
||
if (t.timezone() == TimestampTzTimezoneString()) { | ||
return Primitive(&Lit::set_timestamp_tz, s); | ||
if (t.timezone() == "") { | ||
lit_->set_timestamp(micros); | ||
} else { | ||
// Some loss of info here, Substrait doesn't store timezone | ||
// in field data | ||
lit_->set_timestamp_tz(micros); | ||
} | ||
ARROW_UNSUPPRESS_DEPRECATION_WARNING | ||
|
||
return NotImplemented(s); | ||
return Status::OK(); | ||
} | ||
|
||
Status Visit(const Time32Scalar& s) { return NotImplemented(s); } | ||
// Need to support parameterized UDTs | ||
Status Visit(const Time32Scalar& s) { | ||
google::protobuf::Int32Value value; | ||
value.set_value(s.value); | ||
return EncodeUserDefined(*s.type, value); | ||
} | ||
Status Visit(const Time64Scalar& s) { | ||
if (checked_cast<const Time64Type&>(*s.type).unit() != TimeUnit::MICRO) { | ||
return NotImplemented(s); | ||
if (checked_cast<const Time64Type&>(*s.type).unit() == TimeUnit::MICRO) { | ||
return Primitive(&Lit::set_time, s); | ||
} else { | ||
google::protobuf::Int64Value value; | ||
value.set_value(s.value); | ||
return EncodeUserDefined(*s.type, value); | ||
} | ||
return Primitive(&Lit::set_time, s); | ||
} | ||
|
||
Status Visit(const MonthIntervalScalar& s) { return NotImplemented(s); } | ||
|
@@ -778,9 +967,10 @@ struct ScalarToProtoImpl { | |
return Status::OK(); | ||
} | ||
|
||
// Need support for parameterized UDTs | ||
Status Visit(const Decimal256Scalar& s) { return NotImplemented(s); } | ||
|
||
Status Visit(const ListScalar& s) { | ||
Status Visit(const BaseListScalar& s) { | ||
if (s.value->length() == 0) { | ||
ARROW_ASSIGN_OR_RAISE(auto list_type, ToProto(*s.type, /*nullable=*/true, ext_set_, | ||
conversion_options_)); | ||
|
@@ -807,10 +997,6 @@ struct ScalarToProtoImpl { | |
return Status::OK(); | ||
} | ||
|
||
Status Visit(const ListViewScalar& s) { | ||
return Status::NotImplemented("list-view to proto"); | ||
} | ||
felipecrv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Status Visit(const LargeListViewScalar& s) { | ||
return Status::NotImplemented("list-view to proto"); | ||
} | ||
|
@@ -830,7 +1016,10 @@ struct ScalarToProtoImpl { | |
|
||
Status Visit(const SparseUnionScalar& s) { return NotImplemented(s); } | ||
Status Visit(const DenseUnionScalar& s) { return NotImplemented(s); } | ||
Status Visit(const DictionaryScalar& s) { return NotImplemented(s); } | ||
Status Visit(const DictionaryScalar& s) { | ||
ARROW_ASSIGN_OR_RAISE(auto encoded, s.GetEncodedValue()); | ||
return (*this)(*encoded); | ||
} | ||
|
||
Status Visit(const MapScalar& s) { | ||
if (s.value->length() == 0) { | ||
|
@@ -914,10 +1103,21 @@ struct ScalarToProtoImpl { | |
return NotImplemented(s); | ||
} | ||
|
||
// Need support for parameterized UDTs | ||
Status Visit(const FixedSizeListScalar& s) { return NotImplemented(s); } | ||
Status Visit(const DurationScalar& s) { return NotImplemented(s); } | ||
Status Visit(const LargeStringScalar& s) { return NotImplemented(s); } | ||
Status Visit(const LargeBinaryScalar& s) { return NotImplemented(s); } | ||
|
||
Status Visit(const LargeStringScalar& s) { | ||
google::protobuf::StringValue value; | ||
value.set_value(s.view().data(), s.view().size()); | ||
return EncodeUserDefined(*s.type, value); | ||
} | ||
Status Visit(const LargeBinaryScalar& s) { | ||
google::protobuf::BytesValue value; | ||
value.set_value(s.view().data(), s.view().size()); | ||
return EncodeUserDefined(*s.type, value); | ||
} | ||
// Need support for parameterized UDTs | ||
Status Visit(const LargeListScalar& s) { return NotImplemented(s); } | ||
Status Visit(const MonthDayNanoIntervalScalar& s) { return NotImplemented(s); } | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: a function to create this Status taking arguments like
"integer literal"
and"UInt64Value"
would reduce the string literal bloat in the binary because "Failed to unpack user defined " wouldn't have to be inlined so many times in the literals.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding
ARROW_PREDICT_FALSE
to the condition would also reduce the inlining of code in the block because it becomes cold code code for the compiler.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#41141