From b3cc71f56bb9202bc4f303b4b5cc0ec1c612d0ce Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Wed, 20 Mar 2024 16:31:33 -0700 Subject: [PATCH 1/8] Update to latest substrait, add support for precision timestamp types. Add support for more arrow-specific types that were not previously supported --- .../engine/substrait/expression_internal.cc | 228 +++++++++++++++--- .../arrow/engine/substrait/extension_set.cc | 10 + .../arrow/engine/substrait/extension_set.h | 11 + .../arrow/engine/substrait/extension_types.cc | 31 +++ .../arrow/engine/substrait/extension_types.h | 10 + cpp/src/arrow/engine/substrait/serde_test.cc | 114 +++++++-- .../arrow/engine/substrait/type_internal.cc | 121 ++++++++-- cpp/thirdparty/versions.txt | 4 +- format/substrait/extension_types.yaml | 132 ++++++++-- python/pyarrow/tests/test_substrait.py | 25 ++ 10 files changed, 602 insertions(+), 84 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 5d892af9a394e..0ba129f3efd89 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -33,6 +33,7 @@ #include #include +#include #include "arrow/array/array_base.h" #include "arrow/array/array_nested.h" @@ -58,8 +59,10 @@ #include "arrow/util/decimal.h" #include "arrow/util/endian.h" #include "arrow/util/logging.h" +#include "arrow/util/macros.h" #include "arrow/util/small_vector.h" #include "arrow/util/string.h" +#include "arrow/util/unreachable.h" #include "arrow/visit_scalar_inline.h" namespace arrow { @@ -71,6 +74,9 @@ namespace engine { namespace { +constexpr int64_t kMicrosPerSecond = 1000000; +constexpr int64_t kMicrosPerMilli = 1000; + Id NormalizeFunctionName(Id id) { // Substrait plans encode the types into the function name so it might look like // add:opt_i32_i32. We don't care about the :opt_i32_i32 so we just trim it @@ -421,6 +427,65 @@ Result FromProto(const substrait::Expression& expr, expr.DebugString()); } +namespace { +struct UserDefinedLiteralToArrow { + Status Visit(const DataType& type) { + return Status::NotImplemented("User defined literals of type ", type); + } + Status Visit(const IntegerType& type) { + google::protobuf::UInt64Value value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid("Failed to unpack user defined integer literal to uint64"); + } + ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); + return Status::OK(); + } + Status Visit(const Date64Type& type) { + google::protobuf::UInt64Value value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid("Failed to unpack user defined date64 literal to uint64"); + } + ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); + return Status::OK(); + } + Status Visit(const HalfFloatType& type) { + google::protobuf::UInt32Value value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid("Failed to unpack user defined half_float literal to bytes"); + } + uint16_t half_float_value = value.value(); + ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), half_float_value)); + return Status::OK(); + } + Status Visit(const LargeStringType& type) { + google::protobuf::StringValue value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid( + "Failed to unpack user defined large_string literal to string"); + } + ARROW_ASSIGN_OR_RAISE(scalar_, + MakeScalar(type.GetSharedPtr(), std::string(value.value()))); + return Status::OK(); + } + Status Visit(const LargeBinaryType& type) { + google::protobuf::BytesValue value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid( + "Failed to unpack user defined large_binary literal to bytes"); + } + ARROW_ASSIGN_OR_RAISE(scalar_, + MakeScalar(type.GetSharedPtr(), std::string(value.value()))); + return Status::OK(); + } + Status operator()(const DataType& type) { return VisitTypeInline(type, this); } + + std::shared_ptr scalar_; + const substrait::Expression::Literal::UserDefined* user_defined_; + const ExtensionSet* ext_set_; + const ConversionOptions& conversion_options_; +}; +} // namespace + Result FromProto(const substrait::Expression::Literal& lit, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { @@ -455,6 +520,7 @@ Result FromProto(const substrait::Expression::Literal& lit, case substrait::Expression::Literal::kBinary: return Datum(BinaryScalar(lit.binary())); + ARROW_SUPPRESS_DEPRECATION_WARNING case substrait::Expression::Literal::kTimestamp: return Datum( TimestampScalar(static_cast(lit.timestamp()), TimeUnit::MICRO)); @@ -462,7 +528,17 @@ Result FromProto(const substrait::Expression::Literal& lit, case substrait::Expression::Literal::kTimestampTz: return Datum(TimestampScalar(static_cast(lit.timestamp_tz()), TimeUnit::MICRO, TimestampTzTimezoneString())); - + ARROW_UNSUPPRESS_DEPRECATION_WARNING + case substrait::Expression::Literal::kPrecisionTimestamp: { + // https://github.com/substrait-io/substrait/issues/611 + // TODO(weston) don't break, return precision timestamp + break; + } + case substrait::Expression::Literal::kPrecisionTimestampTz: { + // https://github.com/substrait-io/substrait/issues/611 + // TODO(weston) don't break, return precision timestamp + break; + } case substrait::Expression::Literal::kDate: return Datum(Date32Scalar(lit.date())); case substrait::Expression::Literal::kTime: @@ -674,18 +750,32 @@ Result FromProto(const substrait::Expression::Literal& lit, return Datum(MakeNullScalar(std::move(type_nullable.first))); } + case substrait::Expression::Literal::kUserDefined: { + const auto& user_defined = lit.user_defined(); + ARROW_ASSIGN_OR_RAISE(auto type_record, + ext_set.DecodeType(user_defined.type_reference())); + UserDefinedLiteralToArrow visitor{.scalar_ = nullptr, + .user_defined_ = &user_defined, + .ext_set_ = &ext_set, + .conversion_options_ = conversion_options}; + ARROW_RETURN_NOT_OK((visitor)(*type_record.type)); + return Datum(std::move(visitor.scalar_)); + } + + case substrait::Expression::Literal::LITERAL_TYPE_NOT_SET: + return Status::Invalid("substrait literal did not have any literal type set"); + default: break; } - return Status::NotImplemented("conversion to arrow::Datum from Substrait literal ", - lit.DebugString()); + return Status::NotImplemented("conversion to arrow::Datum from Substrait literal `", + lit.DebugString(), "`"); } namespace { -struct ScalarToProtoImpl { - Status Visit(const NullScalar& s) { return NotImplemented(s); } +struct ScalarToProtoImpl { using Lit = substrait::Expression::Literal; template @@ -702,6 +792,25 @@ struct ScalarToProtoImpl { return Status::OK(); } + Status EncodeUserDefined(const DataType& data_type, + const google::protobuf::Message& value) { + ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(data_type)); + auto user_defined = std::make_unique(); + user_defined->set_type_reference(anchor); + auto value_any = std::make_unique(); + value_any->PackFrom(value); + user_defined->set_allocated_value(value_any.release()); + lit_->set_allocated_user_defined(user_defined.release()); + return Status::OK(); + } + + Status Visit(const NullScalar& s) { + ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(*s.type)); + auto user_defined = std::make_unique(); + user_defined->set_type_reference(anchor); + lit_->set_allocated_user_defined(user_defined.release()); + return Status::OK(); + } Status Visit(const BooleanScalar& s) { return Primitive(&Lit::set_boolean, s); } Status Visit(const Int8Scalar& s) { return Primitive(&Lit::set_i8, s); } @@ -709,12 +818,31 @@ struct ScalarToProtoImpl { Status Visit(const Int32Scalar& s) { return Primitive(&Lit::set_i32, s); } Status Visit(const Int64Scalar& s) { return Primitive(&Lit::set_i64, s); } - Status Visit(const UInt8Scalar& s) { return NotImplemented(s); } - Status Visit(const UInt16Scalar& s) { return NotImplemented(s); } - Status Visit(const UInt32Scalar& s) { return NotImplemented(s); } - Status Visit(const UInt64Scalar& s) { return NotImplemented(s); } - - Status Visit(const HalfFloatScalar& s) { return NotImplemented(s); } + Status Visit(const UInt8Scalar& s) { + google::protobuf::UInt64Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } + Status Visit(const UInt16Scalar& s) { + google::protobuf::UInt64Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } + Status Visit(const UInt32Scalar& s) { + google::protobuf::UInt64Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } + Status Visit(const UInt64Scalar& s) { + google::protobuf::UInt64Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } + Status Visit(const HalfFloatScalar& s) { + google::protobuf::UInt32Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } Status Visit(const FloatScalar& s) { return Primitive(&Lit::set_fp32, s); } Status Visit(const DoubleScalar& s) { return Primitive(&Lit::set_fp64, s); } @@ -722,12 +850,18 @@ struct ScalarToProtoImpl { return FromBuffer([](Lit* lit, std::string&& s) { lit->set_string(std::move(s)); }, s); } + Status Visit(const StringViewScalar& s) { + return FromBuffer([](Lit* lit, std::string&& s) { lit->set_string(std::move(s)); }, + s); + } Status Visit(const BinaryScalar& s) { return FromBuffer([](Lit* lit, std::string&& s) { lit->set_binary(std::move(s)); }, s); } - - Status Visit(const BinaryViewScalar& s) { return NotImplemented(s); } + Status Visit(const BinaryViewScalar& s) { + return FromBuffer([](Lit* lit, std::string&& s) { lit->set_binary(std::move(s)); }, + s); + } Status Visit(const FixedSizeBinaryScalar& s) { return FromBuffer( @@ -735,22 +869,51 @@ struct ScalarToProtoImpl { } Status Visit(const Date32Scalar& s) { return Primitive(&Lit::set_date, s); } - Status Visit(const Date64Scalar& s) { return NotImplemented(s); } + Status Visit(const Date64Scalar& s) { + google::protobuf::UInt64Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } Status Visit(const TimestampScalar& s) { const auto& t = checked_cast(*s.type); - if (t.unit() != TimeUnit::MICRO) return NotImplemented(s); + uint64_t micros; + switch (t.unit()) { + case TimeUnit::SECOND: + micros = s.value * kMicrosPerSecond; + break; + case TimeUnit::MILLI: + micros = s.value * kMicrosPerMilli; + break; + case TimeUnit::MICRO: + micros = s.value; + break; + case TimeUnit::NANO: + // TODO(weston): can support nanos when + // https://github.com/substrait-io/substrait/issues/611 is resolved + return NotImplemented(s); + default: + return NotImplemented(s); + } - if (t.timezone() == "") return Primitive(&Lit::set_timestamp, s); + // Remove these and use precision timestamp once + // https://github.com/substrait-io/substrait/issues/611 is resolved + ARROW_SUPPRESS_DEPRECATION_WARNING - if (t.timezone() == TimestampTzTimezoneString()) { - return Primitive(&Lit::set_timestamp_tz, s); + if (t.timezone() == "") { + lit_->set_timestamp(micros); + } else { + // Some loss of info here, Substrait doesn't store timezone + // in field data + lit_->set_timestamp_tz(micros); } + ARROW_UNSUPPRESS_DEPRECATION_WARNING - return NotImplemented(s); + return Status::OK(); } + // Need to support parameterized UDTs Status Visit(const Time32Scalar& s) { return NotImplemented(s); } Status Visit(const Time64Scalar& s) { if (checked_cast(*s.type).unit() != TimeUnit::MICRO) { @@ -778,9 +941,10 @@ struct ScalarToProtoImpl { return Status::OK(); } + // Need support for parameterized UDTs Status Visit(const Decimal256Scalar& s) { return NotImplemented(s); } - Status Visit(const ListScalar& s) { + Status Visit(const BaseListScalar& s) { if (s.value->length() == 0) { ARROW_ASSIGN_OR_RAISE(auto list_type, ToProto(*s.type, /*nullable=*/true, ext_set_, conversion_options_)); @@ -807,10 +971,6 @@ struct ScalarToProtoImpl { return Status::OK(); } - Status Visit(const ListViewScalar& s) { - return Status::NotImplemented("list-view to proto"); - } - Status Visit(const LargeListViewScalar& s) { return Status::NotImplemented("list-view to proto"); } @@ -830,7 +990,10 @@ struct ScalarToProtoImpl { Status Visit(const SparseUnionScalar& s) { return NotImplemented(s); } Status Visit(const DenseUnionScalar& s) { return NotImplemented(s); } - Status Visit(const DictionaryScalar& s) { return NotImplemented(s); } + Status Visit(const DictionaryScalar& s) { + ARROW_ASSIGN_OR_RAISE(auto encoded, s.GetEncodedValue()); + return (*this)(*encoded); + } Status Visit(const MapScalar& s) { if (s.value->length() == 0) { @@ -914,10 +1077,21 @@ struct ScalarToProtoImpl { return NotImplemented(s); } + // Need support for parameterized UDTs Status Visit(const FixedSizeListScalar& s) { return NotImplemented(s); } Status Visit(const DurationScalar& s) { return NotImplemented(s); } - Status Visit(const LargeStringScalar& s) { return NotImplemented(s); } - Status Visit(const LargeBinaryScalar& s) { return NotImplemented(s); } + + Status Visit(const LargeStringScalar& s) { + google::protobuf::StringValue value; + value.set_value(s.view().data(), s.view().size()); + return EncodeUserDefined(*s.type, value); + } + Status Visit(const LargeBinaryScalar& s) { + google::protobuf::BytesValue value; + value.set_value(s.view().data(), s.view().size()); + return EncodeUserDefined(*s.type, value); + } + // Need support for parameterized UDTs Status Visit(const LargeListScalar& s) { return NotImplemented(s); } Status Visit(const MonthDayNanoIntervalScalar& s) { return NotImplemented(s); } diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index b0dd6aeffbcfa..5274071b83b59 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -314,6 +314,12 @@ Result ExtensionSet::EncodeType(const DataType& type) { return Status::KeyError("type ", type.ToString(), " not found in the registry"); } +Result ExtensionSet::EncodeTypeId(Id type_id) { + RETURN_NOT_OK(this->AddUri(type_id)); + auto it_success = types_map_.emplace(type_id, static_cast(types_map_.size())); + return it_success.first->second; +} + Result ExtensionSet::DecodeFunction(uint32_t anchor) const { if (functions_.find(anchor) == functions_.end() || functions_.at(anchor).empty()) { return Status::Invalid("User defined function reference ", anchor, @@ -1043,6 +1049,10 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { TypeName{uint32(), "u32"}, TypeName{uint64(), "u64"}, TypeName{float16(), "fp16"}, + TypeName{large_utf8(), "large_string"}, + TypeName{large_binary(), "large_binary"}, + TypeName{date64(), "date_millis"}, + TypeName{time64(TimeUnit::NANO), "time_nanos"}, }) { DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type))); } diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index 0a502960447e6..fdd512558fc53 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -295,6 +295,10 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { constexpr std::string_view kArrowExtTypesUri = "https://github.com/apache/arrow/blob/main/format/substrait/" "extension_types.yaml"; +// Extension types that don't match 1:1 with a data type (or the data type is +// parameterized) +constexpr std::string_view kTimeNanosTypeName = "time_nanos"; +constexpr Id kTimeNanosId = {kArrowExtTypesUri, kTimeNanosTypeName}; /// A default registry with all supported functions and data types registered /// @@ -408,6 +412,13 @@ class ARROW_ENGINE_EXPORT ExtensionSet { /// \return An anchor that can be used to refer to the type within a plan Result EncodeType(const DataType& type); + /// \brief Lookup the anchor for a given type alias + /// + /// Similar to \see EncodeType but this is used for cases where the data type is either + /// parameterized or custom in some way (e.g. we use this for Time64::Nanos). We need + /// to use the Id directly since we can't have registered the type with the registry. + Result EncodeTypeId(Id type_id); + /// \brief Return a function id given an anchor /// /// This is used when converting a Substrait plan to an Arrow execution plan. diff --git a/cpp/src/arrow/engine/substrait/extension_types.cc b/cpp/src/arrow/engine/substrait/extension_types.cc index fcc722e9d9410..f71b5f7185d00 100644 --- a/cpp/src/arrow/engine/substrait/extension_types.cc +++ b/cpp/src/arrow/engine/substrait/extension_types.cc @@ -22,6 +22,7 @@ #include #include "arrow/engine/simple_extension_type_internal.h" +#include "arrow/engine/substrait/type_internal.h" #include "arrow/result.h" #include "arrow/type_fwd.h" #include "arrow/util/reflection_internal.h" @@ -113,6 +114,36 @@ std::shared_ptr interval_year() { return IntervalYearType::Make({}); } std::shared_ptr interval_day() { return IntervalDayType::Make({}); } +Result> precision_timestamp(int precision) { + switch (precision) { + case 0: + return timestamp(TimeUnit::SECOND); + case 3: + return timestamp(TimeUnit::MILLI); + case 6: + return timestamp(TimeUnit::MICRO); + case 9: + return timestamp(TimeUnit::NANO); + default: + return Status::NotImplemented("Unrecognized timestamp precision (", precision, ")"); + } +} + +Result> precision_timestamp_tz(int precision) { + switch (precision) { + case 0: + return timestamp(TimeUnit::SECOND, TimestampTzTimezoneString()); + case 3: + return timestamp(TimeUnit::MILLI, TimestampTzTimezoneString()); + case 6: + return timestamp(TimeUnit::MICRO, TimestampTzTimezoneString()); + case 9: + return timestamp(TimeUnit::NANO, TimestampTzTimezoneString()); + default: + return Status::NotImplemented("Unrecognized timestamp precision (", precision, ")"); + } +} + bool UnwrapUuid(const DataType& t) { if (UuidType::GetIf(t)) { return true; diff --git a/cpp/src/arrow/engine/substrait/extension_types.h b/cpp/src/arrow/engine/substrait/extension_types.h index 28a4898a878d7..ae71ad83f7e54 100644 --- a/cpp/src/arrow/engine/substrait/extension_types.h +++ b/cpp/src/arrow/engine/substrait/extension_types.h @@ -56,6 +56,16 @@ std::shared_ptr interval_year(); ARROW_ENGINE_EXPORT std::shared_ptr interval_day(); +/// constructs the appropriate timestamp type given the precision +/// no time zone +ARROW_ENGINE_EXPORT +Result> precision_timestamp(int precision); + +/// constructs the appropriate timestamp type given the precision +/// and the UTC time zone +ARROW_ENGINE_EXPORT +Result> precision_timestamp_tz(int precision); + /// Return true if t is Uuid, otherwise false ARROW_ENGINE_EXPORT bool UnwrapUuid(const DataType&); diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 1e771ccdd25c2..209103bb1452d 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -54,6 +54,7 @@ #include "arrow/engine/substrait/options.h" #include "arrow/engine/substrait/serde.h" #include "arrow/engine/substrait/test_util.h" +#include "arrow/engine/substrait/type_internal.h" #include "arrow/engine/substrait/util.h" #include "arrow/filesystem/filesystem.h" #include "arrow/filesystem/localfs.h" @@ -299,6 +300,8 @@ TEST(Substrait, SupportedTypes) { map(utf8(), field("", utf8()), false)); } +// These types don't exist in Substrait. However, we have user defined types +// defined for them and they should be able to round-trip TEST(Substrait, SupportedExtensionTypes) { ExtensionSet ext_set; @@ -308,6 +311,10 @@ TEST(Substrait, SupportedExtensionTypes) { uint16(), uint32(), uint64(), + large_utf8(), + large_binary(), + date64(), + time64(TimeUnit::NANO), }) { auto anchor = ext_set.num_types(); @@ -332,6 +339,53 @@ TEST(Substrait, SupportedExtensionTypes) { } } +// Encodings are not considered distinct types in Substrait. The encoding information +// is lost during a round-trip. +TEST(Substrait, OneWayTypes) { + ExtensionSet ext_set; + + for (auto [source_type, return_type] : + {std::pair{binary_view(), binary()}, + {utf8_view(), utf8()}, + {dictionary(int32(), utf8()), utf8()}, + {run_end_encoded(int32(), int32()), int32()}, + {dictionary(int32(), dictionary(int32(), utf8())), utf8()}}) { + ASSERT_OK_AND_ASSIGN(auto substrait_type, SerializeType(*source_type, &ext_set, {})); + + ASSERT_OK_AND_ASSIGN(auto actual_return, + DeserializeType(*substrait_type, ext_set, {})); + + EXPECT_EQ(*actual_return, *return_type); + } +} + +// Substrait does not store the time zone as part of the type. That information is stored +// on the function instead. As a result, that information is lost on a round-trip. +TEST(Substrait, TimestampTypes) { + ExtensionSet ext_set; + + for (auto time_unit : + {TimeUnit::NANO, TimeUnit::MICRO, TimeUnit::MILLI, TimeUnit::SECOND}) { + for (auto time_zone : {"UTC", "America/New_York"}) { + auto input_type = timestamp(time_unit, time_zone); + ASSERT_OK_AND_ASSIGN(auto substrait_type, SerializeType(*input_type, &ext_set, {})); + + auto expected_return = timestamp(time_unit, TimestampTzTimezoneString()); + ASSERT_OK_AND_ASSIGN(auto actual_return, + DeserializeType(*substrait_type, ext_set, {})); + + EXPECT_EQ(*actual_return, *expected_return); + } + auto input_type = timestamp(time_unit); + ASSERT_OK_AND_ASSIGN(auto substrait_type, SerializeType(*input_type, &ext_set, {})); + + ASSERT_OK_AND_ASSIGN(auto actual_return, + DeserializeType(*substrait_type, ext_set, {})); + + EXPECT_EQ(*actual_return, *input_type); + } +} + TEST(Substrait, NamedStruct) { ExtensionSet ext_set; @@ -415,26 +469,13 @@ TEST(Substrait, NoEquivalentArrowType) { TEST(Substrait, NoEquivalentSubstraitType) { for (auto type : { - date64(), - timestamp(TimeUnit::SECOND), - timestamp(TimeUnit::NANO), - timestamp(TimeUnit::MICRO, "New York"), - time32(TimeUnit::SECOND), - time32(TimeUnit::MILLI), - time64(TimeUnit::NANO), - decimal256(76, 67), - sparse_union({field("i8", int8()), field("f32", float32())}), dense_union({field("i8", int8()), field("f32", float32())}), - dictionary(int32(), utf8()), - fixed_size_list(float16(), 3), - duration(TimeUnit::MICRO), - - large_utf8(), - large_binary(), + time32(TimeUnit::MILLI), + time32(TimeUnit::SECOND), large_list(utf8()), }) { ARROW_SCOPED_TRACE(type->ToString()); @@ -563,6 +604,45 @@ TEST(Substrait, SupportedLiterals) { } } +template +void CheckArrowSpecificLiteral(ScalarType scalar) { + compute::Expression lit = compute::literal(scalar); + ExtensionSet ext_set; + ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(lit, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set)); + ASSERT_EQ(lit, roundtripped); +} + +TEST(Substrait, ArrowSpecificLiterals) { + CheckArrowSpecificLiteral(UInt8Scalar(7)); + CheckArrowSpecificLiteral(UInt16Scalar(7)); + CheckArrowSpecificLiteral(UInt32Scalar(7)); + CheckArrowSpecificLiteral(UInt64Scalar(7)); + CheckArrowSpecificLiteral(Date64Scalar(86400000)); + CheckArrowSpecificLiteral(HalfFloatScalar(0)); + CheckArrowSpecificLiteral(LargeStringScalar("hello")); + CheckArrowSpecificLiteral(LargeBinaryScalar("hello")); + CheckArrowSpecificLiteral(MakeNullScalar(null())); +} + +template +void CheckOneWayLiteral(SourceScalarType source, DestScalarType expected) { + compute::Expression lit = compute::literal(source); + ExtensionSet ext_set; + ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(lit, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set)); + compute::Expression expected_lit = compute::literal(expected); + ASSERT_EQ(expected_lit, roundtripped); +} + +TEST(Substrait, OneWayLiterals) { + CheckOneWayLiteral(StringViewScalar("test"), StringScalar("test")); + CheckOneWayLiteral(BinaryViewScalar("test"), BinaryScalar("test")); + CheckOneWayLiteral(RunEndEncodedScalar(std::make_shared(7), + run_end_encoded(int16(), uint32())), + UInt32Scalar(7)); +} + TEST(Substrait, CannotDeserializeLiteral) { ExtensionSet ext_set; @@ -823,8 +903,8 @@ TEST(Substrait, Cast) { std::shared_ptr cast_opts = std::dynamic_pointer_cast(call_opts); ASSERT_TRUE(!!cast_opts); - // It is unclear whether a Substrait cast should be safe or not. In the meantime we are - // assuming it is unsafe based on the behavior of many SQL engines. + // It is unclear whether a Substrait cast should be safe or not. In the meantime we + // are assuming it is unsafe based on the behavior of many SQL engines. ASSERT_TRUE(cast_opts->allow_int_overflow); ASSERT_TRUE(cast_opts->allow_float_truncate); ASSERT_TRUE(cast_opts->allow_decimal_truncate); diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc index f4a2e6800eb49..71e0bdc1511ac 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.cc +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -34,6 +34,8 @@ #include "arrow/status.h" #include "arrow/type.h" #include "arrow/type_fwd.h" +#include "arrow/type_traits.h" +#include "arrow/util/unreachable.h" #include "arrow/visit_type_inline.h" namespace arrow { @@ -121,11 +123,24 @@ Result, bool>> FromProto( case substrait::Type::kBinary: return FromProtoImpl(type.binary()); + ARROW_SUPPRESS_DEPRECATION_WARNING case substrait::Type::kTimestamp: return FromProtoImpl(type.timestamp(), TimeUnit::MICRO); case substrait::Type::kTimestampTz: return FromProtoImpl(type.timestamp_tz(), TimeUnit::MICRO, TimestampTzTimezoneString()); + ARROW_UNSUPPRESS_DEPRECATION_WARNING + case substrait::Type::kPrecisionTimestamp: { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr ts_type, + precision_timestamp(type.precision_timestamp().precision())); + return std::make_pair(ts_type, IsNullable(type.precision_timestamp())); + } + case substrait::Type::kPrecisionTimestampTz: { + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr ts_type, + precision_timestamp_tz(type.precision_timestamp_tz().precision())); + return std::make_pair(ts_type, IsNullable(type.precision_timestamp_tz())); + } case substrait::Type::kDate: return FromProtoImpl(type.date()); @@ -263,7 +278,14 @@ struct DataTypeToProtoImpl { return SetWith(&substrait::Type::set_allocated_binary); } - Status Visit(const BinaryViewType& t) { return NotImplemented(t); } + // From Substrait's point of view the view types are encodings, and an execution detail, + // and not distinct from the non-view type. + Status Visit(const BinaryViewType& t) { + return SetWith(&substrait::Type::set_allocated_binary); + } + Status Visit(const StringViewType& t) { + return SetWith(&substrait::Type::set_allocated_string); + } Status Visit(const FixedSizeBinaryType& t) { SetWithThen(&substrait::Type::set_allocated_fixed_binary)->set_length(t.byte_width()); @@ -273,25 +295,66 @@ struct DataTypeToProtoImpl { Status Visit(const Date32Type& t) { return SetWith(&substrait::Type::set_allocated_date); } - Status Visit(const Date64Type& t) { return NotImplemented(t); } + Status Visit(const Date64Type& t) { return EncodeUserDefined(t); } Status Visit(const TimestampType& t) { - if (t.unit() != TimeUnit::MICRO) return NotImplemented(t); - if (t.timezone() == "") { - return SetWith(&substrait::Type::set_allocated_timestamp); - } - if (t.timezone() == TimestampTzTimezoneString()) { - return SetWith(&substrait::Type::set_allocated_timestamp_tz); + auto ts = SetWithThen(&substrait::Type::set_allocated_precision_timestamp); + switch (t.unit()) { + case TimeUnit::SECOND: + ts->set_precision(0); + break; + case TimeUnit::MILLI: + ts->set_precision(3); + break; + case TimeUnit::MICRO: + ts->set_precision(6); + break; + case TimeUnit::NANO: + ts->set_precision(9); + break; + default: + return NotImplemented(t); + } + } else { + // Note: The timezone information is discarded here. In Substrait the time zone + // information is part of the function and not part of the type. For example, to + // convert a timestamp to a string, the time zone is passed as an argument to the + // function. + auto ts = SetWithThen(&substrait::Type::set_allocated_precision_timestamp_tz); + switch (t.unit()) { + case TimeUnit::SECOND: + ts->set_precision(0); + break; + case TimeUnit::MILLI: + ts->set_precision(3); + break; + case TimeUnit::MICRO: + ts->set_precision(6); + break; + case TimeUnit::NANO: + ts->set_precision(9); + break; + default: + return NotImplemented(t); + } } + return Status::OK(); + } + Status Visit(const Time32Type& t) { + // TODO(weston) + // Unsupported for the same reason we don't support parameterized types + // which is that the extension registry only supports encoding one type + // per type id return NotImplemented(t); } - - Status Visit(const Time32Type& t) { return NotImplemented(t); } Status Visit(const Time64Type& t) { - if (t.unit() != TimeUnit::MICRO) return NotImplemented(t); - return SetWith(&substrait::Type::set_allocated_time); + if (t.unit() == TimeUnit::MICRO) { + return SetWith(&substrait::Type::set_allocated_time); + } else { + return EncodeUserDefined(t); + } } Status Visit(const MonthIntervalType& t) { return EncodeUserDefined(t); } @@ -303,6 +366,7 @@ struct DataTypeToProtoImpl { dec->set_scale(t.scale()); return Status::OK(); } + // TODO(weston) support parameterized UDT Status Visit(const Decimal256Type& t) { return NotImplemented(t); } Status Visit(const ListType& t) { @@ -313,8 +377,16 @@ struct DataTypeToProtoImpl { return Status::OK(); } - Status Visit(const ListViewType& t) { return NotImplemented(t); } + // From Substrait's point of view this is an encoding, and an implementation detail, + // and not distinct from the list type. + Status Visit(const ListViewType& t) { + ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*t.value_type(), t.value_field()->nullable(), + ext_set_, conversion_options_)); + SetWithThen(&substrait::Type::set_allocated_list)->set_allocated_type(type.release()); + return Status::OK(); + } + // TODO(weston) support parameterized UDT Status Visit(const LargeListViewType& t) { return NotImplemented(t); } Status Visit(const StructType& t) { @@ -335,8 +407,9 @@ struct DataTypeToProtoImpl { Status Visit(const SparseUnionType& t) { return NotImplemented(t); } Status Visit(const DenseUnionType& t) { return NotImplemented(t); } - Status Visit(const DictionaryType& t) { return NotImplemented(t); } - Status Visit(const RunEndEncodedType& t) { return NotImplemented(t); } + // The caller should have unwrapped the dictionary / RLE type + Status Visit(const DictionaryType& t) { Unreachable(); } + Status Visit(const RunEndEncodedType& t) { Unreachable(); } Status Visit(const MapType& t) { // FIXME assert default field names; custom ones won't roundtrip @@ -379,10 +452,13 @@ struct DataTypeToProtoImpl { return NotImplemented(t); } + // TODO(weston) support parameterized UDT Status Visit(const FixedSizeListType& t) { return NotImplemented(t); } + // TODO(weston) support parameterized UDT Status Visit(const DurationType& t) { return NotImplemented(t); } - Status Visit(const LargeStringType& t) { return NotImplemented(t); } - Status Visit(const LargeBinaryType& t) { return NotImplemented(t); } + Status Visit(const LargeStringType& t) { return EncodeUserDefined(t); } + Status Visit(const LargeBinaryType& t) { return EncodeUserDefined(t); } + // TODO(weston) support parameterized UDT Status Visit(const LargeListType& t) { return NotImplemented(t); } Status Visit(const MonthDayNanoIntervalType& t) { return EncodeUserDefined(t); } @@ -429,6 +505,17 @@ struct DataTypeToProtoImpl { Result> ToProto( const DataType& type, bool nullable, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { + // From Substrait's perspective the "dictionary type" is just an encoding. As a result, + // we lose that information on conversion and just convert the value type. + if (type.id() == Type::DICTIONARY) { + const auto& dict_type = internal::checked_cast(type); + return ToProto(*dict_type.value_type(), nullable, ext_set, conversion_options); + } + // Ditto for REE + if (type.id() == Type::RUN_END_ENCODED) { + const auto& ree_type = internal::checked_cast(type); + return ToProto(*ree_type.value_type(), nullable, ext_set, conversion_options); + } auto out = std::make_unique(); RETURN_NOT_OK( (DataTypeToProtoImpl{out.get(), nullable, ext_set, conversion_options})(type)); diff --git a/cpp/thirdparty/versions.txt b/cpp/thirdparty/versions.txt index 4093b0ec43efd..4983f3cee2c2d 100644 --- a/cpp/thirdparty/versions.txt +++ b/cpp/thirdparty/versions.txt @@ -103,8 +103,8 @@ ARROW_RE2_BUILD_VERSION=2022-06-01 ARROW_RE2_BUILD_SHA256_CHECKSUM=f89c61410a072e5cbcf8c27e3a778da7d6fd2f2b5b1445cd4f4508bee946ab0f ARROW_SNAPPY_BUILD_VERSION=1.1.10 ARROW_SNAPPY_BUILD_SHA256_CHECKSUM=49d831bffcc5f3d01482340fe5af59852ca2fe76c3e05df0e67203ebbe0f1d90 -ARROW_SUBSTRAIT_BUILD_VERSION=v0.27.0 -ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM=4ed375f69d972a57fdc5ec406c17003a111831d8640d3f1733eccd4b3ff45628 +ARROW_SUBSTRAIT_BUILD_VERSION=v0.44.0 +ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM=f989a862f694e7dbb695925ddb7c4ce06aa6c51aca945105c075139aed7e55a2 ARROW_S2N_TLS_BUILD_VERSION=v1.3.35 ARROW_S2N_TLS_BUILD_SHA256_CHECKSUM=9d32b26e6bfcc058d98248bf8fc231537e347395dd89cf62bb432b55c5da990d ARROW_THRIFT_BUILD_VERSION=0.16.0 diff --git a/format/substrait/extension_types.yaml b/format/substrait/extension_types.yaml index 888d6c94c8182..0a5e43c30d8f1 100644 --- a/format/substrait/extension_types.yaml +++ b/format/substrait/extension_types.yaml @@ -42,29 +42,48 @@ # (but that is an infinite space). Similarly, we would have to declare a # timestamp variation for all possible timezone strings. -type_variations: - - parent: i8 - name: u8 - description: an unsigned 8 bit integer - functions: SEPARATE - - parent: i16 - name: u16 - description: an unsigned 16 bit integer - functions: SEPARATE - - parent: i32 - name: u32 - description: an unsigned 32 bit integer - functions: SEPARATE - - parent: i64 - name: u64 - description: an unsigned 64 bit integer - functions: SEPARATE +# Certain Arrow data types are, from Substrait's point of view, encodings. +# These include dictionary, the view types (e.g. binary view, list view), +# and REE. +# +# These types are not logically distinct from the type they are encoding. +# Specifically: +# * There is no value in the decoded type that cannot be represented +# in the encoded type and vice versa. +# * Functions have the same meaning when applied to the encoded type +# +# These types will never have a Substrait equivalent. In the Substrait point +# of view these are execution details. + +# The following types are encodings: + +# binary_view +# list_view +# dictionary +# ree - - parent: i16 - name: fp16 - description: a 16 bit floating point number - functions: SEPARATE +# Arrow-cpp's Substrait serde does not yet handle parameterized UDFs. This means +# the following types are not yet supported but may be supported in the future. +# We define them below in case other implementations support them in the meantime. +# decimal256 +# large_list +# fixed_size_list +# duration +# time32 - not technically a parameterized type, but unsupported for similar reasons + +# Other types are not encodings, but are not first-class in Substrait. These +# types are often similar to existing Substrait types but define a different range +# of values. For example, unsigned integer types are very similar to their integer +# counterparts, but have a different range of values. These types are defined here +# as extension types. +# +# A full description of the types, along with their specified range, can be found +# in Schema.fbs +# +# Consumers should take care when supporting the below types. Should Substrait decide +# later to support these types, the consumer will need to make sure to continue supporting +# the extension type names as aliases for proper backwards compatibility. types: - name: "null" structure: {} @@ -80,3 +99,74 @@ types: months: i32 days: i32 nanos: i64 + # All signed integer literals are encoded as user defined literals with + # a google.protobuf.UInt64Value message. + - name: i8 + structure: {} + - name: i16 + structure: {} + - name: i32 + structure: {} + - name: i64 + structure: {} + # fp16 literals are encoded as user defined literals with + # a google.protobuf.UInt32Value message where the lower 16 bits are + # the fp16 value. + - name: fp16 + structure: {} + # 64-bit integers are big. Even though date64 stores ms and not days it + # can still represent about 50x more dates than date32. Since it has a + # different range of values, it is an extension type. + # + # date64 literals are encoded as user defined literals with + # a google.protobuf.UInt64Value message. + - name: date_millis + structure: {} + # We cannot generate `time` today for reasons similar to parameterized + # UDTs. + - name: time_seconds + structure: {} + - name: time_millis + structure: {} + - name: time_nanos + # Large string literals are encoded using a + # google.protobuf.StringValue message. + structure: {} + - name: large_string + structure: {} + # Large binary literals are encoded using a + # google.protobuf.BytesValue message. + - name: large_binary + structure: {} + # We cannot generate these today because they are parameterized UDTs and + # substrait-cpp does not yet support parameterized UDTs. + - name: decimal256 + structure: {} + parameters: + - name: precision + type: integer + min: 0 + max: 76 + - name: scale + type: integer + min: 0 + max: 76 + - name: large_list + structure: {} + parameters: + - name: value_type + type: dataType + - name: fixed size list + structure: {} + parameters: + - name: value_type + type: dataType + - name: dimension + type: integer + min: 0 + - name: duration + structure: {} + parameters: + - name: unit + type: string + diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index d4fbfb7406838..a30cd3071b0cd 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -944,6 +944,31 @@ def test_serializing_expressions(expr): assert "test_expr" in returned.expressions +def test_arrow_specific_types(): + schema = pa.schema( + [ + pa.field("time_nanos", pa.time64("ns")), + pa.field("date_millis", pa.date64()), + pa.field("large_string", pa.large_string()), + pa.field("large_binary", pa.large_binary()), + ] + ) + + def check_round_trip(expr): + buf = pa.substrait.serialize_expressions([expr], ["test_expr"], schema) + returned = pa.substrait.deserialize_expressions(buf) + assert schema == returned.schema + + check_round_trip(pc.field("large_string") == "test_string") + check_round_trip(pc.field("large_binary") == "test_string") + # Arrow-cpp supports round tripping these types but pyarrow doesn't support + # constructing literals of these types + # + # So best we can do is verify field references work + check_round_trip(pc.field("time_nanos")) + check_round_trip(pc.field("date_millis")) + + def test_invalid_expression_ser_des(): schema = pa.schema([ pa.field("x", pa.int32()), From c69490b934ebec865bd63ae9b78d25a21bb31dc9 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Wed, 20 Mar 2024 17:24:27 -0700 Subject: [PATCH 2/8] Bump up version constant. Remove designated initializer syntax --- cpp/src/arrow/engine/substrait/expression_internal.cc | 6 ++---- cpp/src/arrow/engine/substrait/util.h | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 0ba129f3efd89..ed452d46ab2be 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -754,10 +754,8 @@ Result FromProto(const substrait::Expression::Literal& lit, const auto& user_defined = lit.user_defined(); ARROW_ASSIGN_OR_RAISE(auto type_record, ext_set.DecodeType(user_defined.type_reference())); - UserDefinedLiteralToArrow visitor{.scalar_ = nullptr, - .user_defined_ = &user_defined, - .ext_set_ = &ext_set, - .conversion_options_ = conversion_options}; + UserDefinedLiteralToArrow visitor{nullptr, &user_defined, &ext_set, + conversion_options}; ARROW_RETURN_NOT_OK((visitor)(*type_record.type)); return Datum(std::move(visitor.scalar_)); } diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 5128ec44bff77..bef2a6c7e1823 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -70,7 +70,7 @@ ARROW_ENGINE_EXPORT const std::string& default_extension_types_uri(); // TODO(ARROW-18145) Populate these from cmake files constexpr uint32_t kSubstraitMajorVersion = 0; -constexpr uint32_t kSubstraitMinorVersion = 27; +constexpr uint32_t kSubstraitMinorVersion = 44; constexpr uint32_t kSubstraitPatchVersion = 0; constexpr uint32_t kSubstraitMinimumMajorVersion = 0; From 7d0195e8f8e894b7a0f8152254dcc5fb684dac0b Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Wed, 20 Mar 2024 18:56:02 -0700 Subject: [PATCH 3/8] Add include of checked_cast --- cpp/src/arrow/engine/substrait/type_internal.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc index 71e0bdc1511ac..700a68db9b12c 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.cc +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -35,9 +35,12 @@ #include "arrow/type.h" #include "arrow/type_fwd.h" #include "arrow/type_traits.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/unreachable.h" #include "arrow/visit_type_inline.h" +using arrow::internal::checked_cast; + namespace arrow { namespace engine { @@ -508,12 +511,12 @@ Result> ToProto( // From Substrait's perspective the "dictionary type" is just an encoding. As a result, // we lose that information on conversion and just convert the value type. if (type.id() == Type::DICTIONARY) { - const auto& dict_type = internal::checked_cast(type); + const auto& dict_type = checked_cast(type); return ToProto(*dict_type.value_type(), nullable, ext_set, conversion_options); } // Ditto for REE if (type.id() == Type::RUN_END_ENCODED) { - const auto& ree_type = internal::checked_cast(type); + const auto& ree_type = checked_cast(type); return ToProto(*ree_type.value_type(), nullable, ext_set, conversion_options); } auto out = std::make_unique(); From 99ae09f709ac5fdeeb90412d6b06ead7faf88a07 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Fri, 22 Mar 2024 04:37:42 -0700 Subject: [PATCH 4/8] Apply suggestions from code review Co-authored-by: Benjamin Kietzman --- cpp/src/arrow/engine/substrait/extension_set.cc | 4 ++-- format/substrait/extension_types.yaml | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 5274071b83b59..63d0e672dfadc 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -316,8 +316,8 @@ Result ExtensionSet::EncodeType(const DataType& type) { Result ExtensionSet::EncodeTypeId(Id type_id) { RETURN_NOT_OK(this->AddUri(type_id)); - auto it_success = types_map_.emplace(type_id, static_cast(types_map_.size())); - return it_success.first->second; + auto [it, success] = types_map_.emplace(type_id, static_cast(types_map_.size())); + return it->second; } Result ExtensionSet::DecodeFunction(uint32_t anchor) const { diff --git a/format/substrait/extension_types.yaml b/format/substrait/extension_types.yaml index 0a5e43c30d8f1..f61b08a4a82f2 100644 --- a/format/substrait/extension_types.yaml +++ b/format/substrait/extension_types.yaml @@ -48,8 +48,8 @@ # # These types are not logically distinct from the type they are encoding. # Specifically: -# * There is no value in the decoded type that cannot be represented -# in the encoded type and vice versa. +# * There is no column of the decoded type that cannot be represented +# as a column the encoded type and vice versa. # * Functions have the same meaning when applied to the encoded type # # These types will never have a Substrait equivalent. In the Substrait point @@ -62,7 +62,7 @@ # dictionary # ree -# Arrow-cpp's Substrait serde does not yet handle parameterized UDFs. This means +# Arrow-cpp's Substrait serde does not yet handle parameterized UDTs. This means # the following types are not yet supported but may be supported in the future. # We define them below in case other implementations support them in the meantime. @@ -129,9 +129,9 @@ types: - name: time_millis structure: {} - name: time_nanos + structure: {} # Large string literals are encoded using a # google.protobuf.StringValue message. - structure: {} - name: large_string structure: {} # Large binary literals are encoded using a @@ -156,7 +156,7 @@ types: parameters: - name: value_type type: dataType - - name: fixed size list + - name: fixed_size_list structure: {} parameters: - name: value_type From 3c87d3ad82389adee72f135c1f211aefd3f1ea56 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Fri, 22 Mar 2024 06:41:24 -0700 Subject: [PATCH 5/8] Addressing review comments --- .../engine/substrait/expression_internal.cc | 56 ++++++++++++++----- .../arrow/engine/substrait/extension_set.cc | 10 +--- .../arrow/engine/substrait/extension_set.h | 7 --- cpp/src/arrow/engine/substrait/serde_test.cc | 15 ++++- .../arrow/engine/substrait/type_internal.cc | 18 ++---- format/substrait/extension_types.yaml | 36 ++++++------ python/pyarrow/tests/test_substrait.py | 43 +++++++++----- 7 files changed, 110 insertions(+), 75 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index ed452d46ab2be..480cf30d3033f 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -435,15 +435,35 @@ struct UserDefinedLiteralToArrow { Status Visit(const IntegerType& type) { google::protobuf::UInt64Value value; if (!user_defined_->value().UnpackTo(&value)) { - return Status::Invalid("Failed to unpack user defined integer literal to uint64"); + return Status::Invalid( + "Failed to unpack user defined integer literal to UInt64Value"); + } + ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); + return Status::OK(); + } + Status Visit(const Time32Type& type) { + google::protobuf::Int32Value value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid( + "Failed to unpack user defined time32 literal to Int32Value"); + } + ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); + return Status::OK(); + } + Status Visit(const Time64Type& type) { + google::protobuf::Int64Value value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid( + "Failed to unpack user defined time64 literal to Int64Value"); } ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); return Status::OK(); } Status Visit(const Date64Type& type) { - google::protobuf::UInt64Value value; + google::protobuf::Int64Value value; if (!user_defined_->value().UnpackTo(&value)) { - return Status::Invalid("Failed to unpack user defined date64 literal to uint64"); + return Status::Invalid( + "Failed to unpack user defined date64 literal to Int64Value"); } ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); return Status::OK(); @@ -451,7 +471,8 @@ struct UserDefinedLiteralToArrow { Status Visit(const HalfFloatType& type) { google::protobuf::UInt32Value value; if (!user_defined_->value().UnpackTo(&value)) { - return Status::Invalid("Failed to unpack user defined half_float literal to bytes"); + return Status::Invalid( + "Failed to unpack user defined half_float literal to UInt32Value"); } uint16_t half_float_value = value.value(); ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), half_float_value)); @@ -461,7 +482,7 @@ struct UserDefinedLiteralToArrow { google::protobuf::StringValue value; if (!user_defined_->value().UnpackTo(&value)) { return Status::Invalid( - "Failed to unpack user defined large_string literal to string"); + "Failed to unpack user defined large_string literal to StringValue"); } ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), std::string(value.value()))); @@ -471,7 +492,7 @@ struct UserDefinedLiteralToArrow { google::protobuf::BytesValue value; if (!user_defined_->value().UnpackTo(&value)) { return Status::Invalid( - "Failed to unpack user defined large_binary literal to bytes"); + "Failed to unpack user defined large_binary literal to BytesValue"); } ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), std::string(value.value()))); @@ -531,12 +552,12 @@ Result FromProto(const substrait::Expression::Literal& lit, ARROW_UNSUPPRESS_DEPRECATION_WARNING case substrait::Expression::Literal::kPrecisionTimestamp: { // https://github.com/substrait-io/substrait/issues/611 - // TODO(weston) don't break, return precision timestamp + // TODO(GH-40741) don't break, return precision timestamp break; } case substrait::Expression::Literal::kPrecisionTimestampTz: { // https://github.com/substrait-io/substrait/issues/611 - // TODO(weston) don't break, return precision timestamp + // TODO(GH-40741) don't break, return precision timestamp break; } case substrait::Expression::Literal::kDate: @@ -868,7 +889,7 @@ struct ScalarToProtoImpl { Status Visit(const Date32Scalar& s) { return Primitive(&Lit::set_date, s); } Status Visit(const Date64Scalar& s) { - google::protobuf::UInt64Value value; + google::protobuf::Int64Value value; value.set_value(s.value); return EncodeUserDefined(*s.type, value); } @@ -888,7 +909,7 @@ struct ScalarToProtoImpl { micros = s.value; break; case TimeUnit::NANO: - // TODO(weston): can support nanos when + // TODO(GH-40741): can support nanos when // https://github.com/substrait-io/substrait/issues/611 is resolved return NotImplemented(s); default: @@ -912,12 +933,19 @@ struct ScalarToProtoImpl { } // Need to support parameterized UDTs - Status Visit(const Time32Scalar& s) { return NotImplemented(s); } + Status Visit(const Time32Scalar& s) { + google::protobuf::Int32Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } Status Visit(const Time64Scalar& s) { - if (checked_cast(*s.type).unit() != TimeUnit::MICRO) { - return NotImplemented(s); + if (checked_cast(*s.type).unit() == TimeUnit::MICRO) { + return Primitive(&Lit::set_time, s); + } else { + google::protobuf::Int64Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); } - return Primitive(&Lit::set_time, s); } Status Visit(const MonthIntervalScalar& s) { return NotImplemented(s); } diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 63d0e672dfadc..e955084dcdfbb 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -314,12 +314,6 @@ Result ExtensionSet::EncodeType(const DataType& type) { return Status::KeyError("type ", type.ToString(), " not found in the registry"); } -Result ExtensionSet::EncodeTypeId(Id type_id) { - RETURN_NOT_OK(this->AddUri(type_id)); - auto [it, success] = types_map_.emplace(type_id, static_cast(types_map_.size())); - return it->second; -} - Result ExtensionSet::DecodeFunction(uint32_t anchor) const { if (functions_.find(anchor) == functions_.end() || functions_.at(anchor).empty()) { return Status::Invalid("User defined function reference ", anchor, @@ -1041,7 +1035,7 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { }; // The type (variation) mappings listed below need to be kept in sync - // with the YAML at substrait/format/extension_types.yaml manually; + // with the YAML at format/substrait/extension_types.yaml manually; // see ARROW-15535. for (TypeName e : { TypeName{uint8(), "u8"}, @@ -1051,6 +1045,8 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { TypeName{float16(), "fp16"}, TypeName{large_utf8(), "large_string"}, TypeName{large_binary(), "large_binary"}, + TypeName{time32(TimeUnit::SECOND), "time_seconds"}, + TypeName{time32(TimeUnit::MILLI), "time_millis"}, TypeName{date64(), "date_millis"}, TypeName{time64(TimeUnit::NANO), "time_nanos"}, }) { diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index fdd512558fc53..c18e0cf77aae5 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -412,13 +412,6 @@ class ARROW_ENGINE_EXPORT ExtensionSet { /// \return An anchor that can be used to refer to the type within a plan Result EncodeType(const DataType& type); - /// \brief Lookup the anchor for a given type alias - /// - /// Similar to \see EncodeType but this is used for cases where the data type is either - /// parameterized or custom in some way (e.g. we use this for Time64::Nanos). We need - /// to use the Id directly since we can't have registered the type with the registry. - Result EncodeTypeId(Id type_id); - /// \brief Return a function id given an anchor /// /// This is used when converting a Substrait plan to an Arrow execution plan. diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 209103bb1452d..3e80192377937 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -314,6 +314,8 @@ TEST(Substrait, SupportedExtensionTypes) { large_utf8(), large_binary(), date64(), + time32(TimeUnit::SECOND), + time32(TimeUnit::MILLI), time64(TimeUnit::NANO), }) { auto anchor = ext_set.num_types(); @@ -474,8 +476,6 @@ TEST(Substrait, NoEquivalentSubstraitType) { dense_union({field("i8", int8()), field("f32", float32())}), fixed_size_list(float16(), 3), duration(TimeUnit::MICRO), - time32(TimeUnit::MILLI), - time32(TimeUnit::SECOND), large_list(utf8()), }) { ARROW_SCOPED_TRACE(type->ToString()); @@ -619,6 +619,17 @@ TEST(Substrait, ArrowSpecificLiterals) { CheckArrowSpecificLiteral(UInt32Scalar(7)); CheckArrowSpecificLiteral(UInt64Scalar(7)); CheckArrowSpecificLiteral(Date64Scalar(86400000)); + CheckArrowSpecificLiteral(Time64Scalar(7, TimeUnit::NANO)); + CheckArrowSpecificLiteral(Time32Scalar(7, TimeUnit::SECOND)); + CheckArrowSpecificLiteral(Time32Scalar(7, TimeUnit::MILLI)); + CheckArrowSpecificLiteral(Time32Scalar(7, TimeUnit::SECOND)); + // We serialize as a signed integer, which doesn't make sense for Time scalars but + // Arrow supports it so we might as well round-trip it. + CheckArrowSpecificLiteral(Time32Scalar(-7, TimeUnit::MILLI)); + CheckArrowSpecificLiteral(Time32Scalar(-7, TimeUnit::SECOND)); + CheckArrowSpecificLiteral(Time64Scalar(-7, TimeUnit::NANO)); + // Negative date scalars DO make sense and we should make sure they work + CheckArrowSpecificLiteral(Date64Scalar(-86400000)); CheckArrowSpecificLiteral(HalfFloatScalar(0)); CheckArrowSpecificLiteral(LargeStringScalar("hello")); CheckArrowSpecificLiteral(LargeBinaryScalar("hello")); diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc index 700a68db9b12c..1587112518426 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.cc +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -345,13 +345,7 @@ struct DataTypeToProtoImpl { return Status::OK(); } - Status Visit(const Time32Type& t) { - // TODO(weston) - // Unsupported for the same reason we don't support parameterized types - // which is that the extension registry only supports encoding one type - // per type id - return NotImplemented(t); - } + Status Visit(const Time32Type& t) { return EncodeUserDefined(t); } Status Visit(const Time64Type& t) { if (t.unit() == TimeUnit::MICRO) { return SetWith(&substrait::Type::set_allocated_time); @@ -369,7 +363,7 @@ struct DataTypeToProtoImpl { dec->set_scale(t.scale()); return Status::OK(); } - // TODO(weston) support parameterized UDT + // TODO(GH-40740) support parameterized UDT Status Visit(const Decimal256Type& t) { return NotImplemented(t); } Status Visit(const ListType& t) { @@ -389,7 +383,7 @@ struct DataTypeToProtoImpl { return Status::OK(); } - // TODO(weston) support parameterized UDT + // TODO(GH-40740) support parameterized UDT Status Visit(const LargeListViewType& t) { return NotImplemented(t); } Status Visit(const StructType& t) { @@ -455,13 +449,13 @@ struct DataTypeToProtoImpl { return NotImplemented(t); } - // TODO(weston) support parameterized UDT + // TODO(GH-40740) support parameterized UDT Status Visit(const FixedSizeListType& t) { return NotImplemented(t); } - // TODO(weston) support parameterized UDT + // TODO(GH-40740) support parameterized UDT Status Visit(const DurationType& t) { return NotImplemented(t); } Status Visit(const LargeStringType& t) { return EncodeUserDefined(t); } Status Visit(const LargeBinaryType& t) { return EncodeUserDefined(t); } - // TODO(weston) support parameterized UDT + // TODO(GH-40740) support parameterized UDT Status Visit(const LargeListType& t) { return NotImplemented(t); } Status Visit(const MonthDayNanoIntervalType& t) { return EncodeUserDefined(t); } diff --git a/format/substrait/extension_types.yaml b/format/substrait/extension_types.yaml index f61b08a4a82f2..0073da1acc1ed 100644 --- a/format/substrait/extension_types.yaml +++ b/format/substrait/extension_types.yaml @@ -35,23 +35,21 @@ # - interval # - arrow::ExtensionTypes # -# Note that not all of these are currently implemented. In particular, these -# extension types are currently not parameterizable in Substrait, which means -# among other things that we can't declare dictionary type here at all since -# we'd have to declare a different dictionary type for all encoded types -# (but that is an infinite space). Similarly, we would have to declare a -# timestamp variation for all possible timezone strings. +# These types fall into several categories of behavior: # Certain Arrow data types are, from Substrait's point of view, encodings. # These include dictionary, the view types (e.g. binary view, list view), # and REE. # # These types are not logically distinct from the type they are encoding. -# Specifically: -# * There is no column of the decoded type that cannot be represented -# as a column the encoded type and vice versa. +# Specifically, the types meet the following criteria: +# * There is no value in the decoded type that cannot be represented +# as a value in the encoded type and vice versa. # * Functions have the same meaning when applied to the encoded type -# +# +# Note: if two types have a different range (e.g. string and large_string) then +# they do not satisfy the above criteria and are not encodings. +# # These types will never have a Substrait equivalent. In the Substrait point # of view these are execution details. @@ -70,7 +68,6 @@ # large_list # fixed_size_list # duration -# time32 - not technically a parameterized type, but unsupported for similar reasons # Other types are not encodings, but are not first-class in Substrait. These # types are often similar to existing Substrait types but define a different range @@ -99,15 +96,15 @@ types: months: i32 days: i32 nanos: i64 - # All signed integer literals are encoded as user defined literals with + # All unsigned integer literals are encoded as user defined literals with # a google.protobuf.UInt64Value message. - - name: i8 + - name: u8 structure: {} - - name: i16 + - name: u16 structure: {} - - name: i32 + - name: u32 structure: {} - - name: i64 + - name: u64 structure: {} # fp16 literals are encoded as user defined literals with # a google.protobuf.UInt32Value message where the lower 16 bits are @@ -119,11 +116,12 @@ types: # different range of values, it is an extension type. # # date64 literals are encoded as user defined literals with - # a google.protobuf.UInt64Value message. + # a google.protobuf.Int64Value message. - name: date_millis structure: {} - # We cannot generate `time` today for reasons similar to parameterized - # UDTs. + # time literals are encoded as user defined literals with + # a google.protobuf.Int32Value message (for time_seconds/time_millis) + # or a google.protobuf.Int64Value message (for time_nanos). - name: time_seconds structure: {} - name: time_millis diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index a30cd3071b0cd..53e96f4540363 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -945,28 +945,43 @@ def test_serializing_expressions(expr): def test_arrow_specific_types(): + fields = { + "time_seconds": (pa.time32("s"), 0), + "time_millis": (pa.time32("ms"), 0), + "time_nanos": (pa.time64("ns"), 0), + "date_millis": (pa.date64(), 0), + "large_string": (pa.large_string(), "test_string"), + "large_binary": (pa.large_binary(), b"test_string"), + } + schema = pa.schema([pa.field(name, typ) for name, (typ, _) in fields.items()]) + + def check_round_trip(expr): + buf = pa.substrait.serialize_expressions([expr], ["test_expr"], schema) + returned = pa.substrait.deserialize_expressions(buf) + assert schema == returned.schema + + for name, (typ, val) in fields.items(): + check_round_trip(pc.field(name) == pa.scalar(val, type=typ)) + + +def test_arrow_one_way_types(): schema = pa.schema( [ - pa.field("time_nanos", pa.time64("ns")), - pa.field("date_millis", pa.date64()), - pa.field("large_string", pa.large_string()), - pa.field("large_binary", pa.large_binary()), + pa.field("binary_view", pa.binary_view()), + pa.field("string_view", pa.string_view()), ] ) + alt_schema = pa.schema( + [pa.field("binary_view", pa.binary()), pa.field("string_view", pa.string())] + ) - def check_round_trip(expr): + def check_one_way(expr): buf = pa.substrait.serialize_expressions([expr], ["test_expr"], schema) returned = pa.substrait.deserialize_expressions(buf) - assert schema == returned.schema + assert alt_schema == returned.schema - check_round_trip(pc.field("large_string") == "test_string") - check_round_trip(pc.field("large_binary") == "test_string") - # Arrow-cpp supports round tripping these types but pyarrow doesn't support - # constructing literals of these types - # - # So best we can do is verify field references work - check_round_trip(pc.field("time_nanos")) - check_round_trip(pc.field("date_millis")) + check_one_way(pc.is_null(pc.field("binary_view"))) + check_one_way(pc.is_null(pc.field("string_view"))) def test_invalid_expression_ser_des(): From ed145bc7de7a527b49ed72bd4326dd1ee5cbfc69 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Fri, 22 Mar 2024 07:20:07 -0700 Subject: [PATCH 6/8] Added dictionary and ree to the one_way test --- python/pyarrow/tests/test_substrait.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 53e96f4540363..8579eb1ab2fc3 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -969,20 +969,28 @@ def test_arrow_one_way_types(): [ pa.field("binary_view", pa.binary_view()), pa.field("string_view", pa.string_view()), + pa.field("dictionary", pa.dictionary(pa.int32(), pa.string())), + pa.field("ree", pa.run_end_encoded(pa.int32(), pa.string())), ] ) alt_schema = pa.schema( - [pa.field("binary_view", pa.binary()), pa.field("string_view", pa.string())] + [ + pa.field("binary_view", pa.binary()), + pa.field("string_view", pa.string()), + pa.field("dictionary", pa.string()), + pa.field("ree", pa.string()) + ] ) - def check_one_way(expr): + def check_one_way(field): + expr = pc.is_null(pc.field(field.name)) buf = pa.substrait.serialize_expressions([expr], ["test_expr"], schema) returned = pa.substrait.deserialize_expressions(buf) assert alt_schema == returned.schema - check_one_way(pc.is_null(pc.field("binary_view"))) - check_one_way(pc.is_null(pc.field("string_view"))) - + for field in schema: + check_one_way(field) + def test_invalid_expression_ser_des(): schema = pa.schema([ From 10aedfd7adbed3c35c56911e95e766fe19497a23 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Fri, 22 Mar 2024 07:26:57 -0700 Subject: [PATCH 7/8] Lint fix --- python/pyarrow/tests/test_substrait.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 8579eb1ab2fc3..40700e4741321 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -990,7 +990,7 @@ def check_one_way(field): for field in schema: check_one_way(field) - + def test_invalid_expression_ser_des(): schema = pa.schema([ From ab8237a6ec274dfe34bb1d17edd160ea93fc2ec9 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Wed, 10 Apr 2024 19:06:18 -0700 Subject: [PATCH 8/8] Reduce code duplication per PR review --- .../arrow/engine/substrait/type_internal.cc | 60 ++++++++----------- 1 file changed, 25 insertions(+), 35 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc index 1587112518426..5e7e364fe00c5 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.cc +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -300,49 +300,39 @@ struct DataTypeToProtoImpl { } Status Visit(const Date64Type& t) { return EncodeUserDefined(t); } + template + Status VisitTimestamp(const TimestampType& t, + void (substrait::Type::*set_allocated_sub)(Sub*)) { + auto ts = SetWithThen(set_allocated_sub); + switch (t.unit()) { + case TimeUnit::SECOND: + ts->set_precision(0); + break; + case TimeUnit::MILLI: + ts->set_precision(3); + break; + case TimeUnit::MICRO: + ts->set_precision(6); + break; + case TimeUnit::NANO: + ts->set_precision(9); + break; + default: + return NotImplemented(t); + } + return Status::OK(); + } + Status Visit(const TimestampType& t) { if (t.timezone() == "") { - auto ts = SetWithThen(&substrait::Type::set_allocated_precision_timestamp); - switch (t.unit()) { - case TimeUnit::SECOND: - ts->set_precision(0); - break; - case TimeUnit::MILLI: - ts->set_precision(3); - break; - case TimeUnit::MICRO: - ts->set_precision(6); - break; - case TimeUnit::NANO: - ts->set_precision(9); - break; - default: - return NotImplemented(t); - } + return VisitTimestamp(t, &substrait::Type::set_allocated_precision_timestamp); } else { // Note: The timezone information is discarded here. In Substrait the time zone // information is part of the function and not part of the type. For example, to // convert a timestamp to a string, the time zone is passed as an argument to the // function. - auto ts = SetWithThen(&substrait::Type::set_allocated_precision_timestamp_tz); - switch (t.unit()) { - case TimeUnit::SECOND: - ts->set_precision(0); - break; - case TimeUnit::MILLI: - ts->set_precision(3); - break; - case TimeUnit::MICRO: - ts->set_precision(6); - break; - case TimeUnit::NANO: - ts->set_precision(9); - break; - default: - return NotImplemented(t); - } + return VisitTimestamp(t, &substrait::Type::set_allocated_precision_timestamp_tz); } - return Status::OK(); } Status Visit(const Time32Type& t) { return EncodeUserDefined(t); }