Skip to content

Commit

Permalink
GH-40695 [C++] Expand Substrait type support (#40696)
Browse files Browse the repository at this point in the history
### Rationale for this change

See #40695 

### What changes are included in this PR?

This PR does a few things:

 * Substrait is upgraded to the latest version
 * Support is added for the parameterized timestamp type (but not literals due to substrait-io/substrait#611).
 * Support is added for the following arrow-specific types:
   * fp16
   * date_millis
   * time_seconds
   * time_millis
   * time_nanos
   * large_string
   * large_binary

When adding support for the new timestamp types I also relaxed the restrictions on the time zone column.  Substrait puts time zone information in the function and not the type.  In other words, to print the "America/New York" value of a column of instants one would do something like `to_char(my_timestamp, "America/New York")` instead of `to_char(cast(my_timestamp, timestamp("nanos", "America/New York")`.

However, the current implementation makes it impossible to produce or consume a plan with `to_char(my_timestamp, "America/New York")` because it would reject the type because it has a non-UTC time zone.  With this latest change, we treat any non-empty timezone as a timezone_tz type.

In addition, I have enabled conversions from "encoded types" to their unencoded representation.  E.g. a type of `DICTIONARY<INT32>` will convert to `INT32`.  At a logical expression / plan perspective these encodings are irrelevant.  If anything, they may belong in a more physical plan representation.  Should a need for them arise we can dig into it more later.  However, I believe it is better to err on the side of generating "something" rather than failing in these cases.  I don't consider this last change critical and can back it out if need be.

### Are these changes tested?

Yes, I added new unit tests

### Are there any user-facing changes?

Yes, via the Substrait conversion.  These changes should be backwards compatible in that they only add functionality in places that previously reported "Not Supported".
* GitHub Issue: #40695

Lead-authored-by: Weston Pace <[email protected]>
Co-authored-by: Benjamin Kietzman <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
  • Loading branch information
westonpace and bkietz authored Apr 11, 2024
1 parent a5c39fe commit 7f64fff
Show file tree
Hide file tree
Showing 11 changed files with 649 additions and 97 deletions.
262 changes: 231 additions & 31 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <vector>

#include <google/protobuf/descriptor.h>
#include <google/protobuf/wrappers.pb.h>

#include "arrow/array/array_base.h"
#include "arrow/array/array_nested.h"
Expand All @@ -58,8 +59,10 @@
#include "arrow/util/decimal.h"
#include "arrow/util/endian.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
#include "arrow/util/small_vector.h"
#include "arrow/util/string.h"
#include "arrow/util/unreachable.h"
#include "arrow/visit_scalar_inline.h"

namespace arrow {
Expand All @@ -71,6 +74,9 @@ namespace engine {

namespace {

constexpr int64_t kMicrosPerSecond = 1000000;
constexpr int64_t kMicrosPerMilli = 1000;

Id NormalizeFunctionName(Id id) {
// Substrait plans encode the types into the function name so it might look like
// add:opt_i32_i32. We don't care about the :opt_i32_i32 so we just trim it
Expand Down Expand Up @@ -421,6 +427,86 @@ Result<compute::Expression> FromProto(const substrait::Expression& expr,
expr.DebugString());
}

namespace {
struct UserDefinedLiteralToArrow {
Status Visit(const DataType& type) {
return Status::NotImplemented("User defined literals of type ", type);
}
Status Visit(const IntegerType& type) {
google::protobuf::UInt64Value value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid(
"Failed to unpack user defined integer literal to UInt64Value");
}
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value()));
return Status::OK();
}
Status Visit(const Time32Type& type) {
google::protobuf::Int32Value value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid(
"Failed to unpack user defined time32 literal to Int32Value");
}
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value()));
return Status::OK();
}
Status Visit(const Time64Type& type) {
google::protobuf::Int64Value value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid(
"Failed to unpack user defined time64 literal to Int64Value");
}
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value()));
return Status::OK();
}
Status Visit(const Date64Type& type) {
google::protobuf::Int64Value value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid(
"Failed to unpack user defined date64 literal to Int64Value");
}
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value()));
return Status::OK();
}
Status Visit(const HalfFloatType& type) {
google::protobuf::UInt32Value value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid(
"Failed to unpack user defined half_float literal to UInt32Value");
}
uint16_t half_float_value = value.value();
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), half_float_value));
return Status::OK();
}
Status Visit(const LargeStringType& type) {
google::protobuf::StringValue value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid(
"Failed to unpack user defined large_string literal to StringValue");
}
ARROW_ASSIGN_OR_RAISE(scalar_,
MakeScalar(type.GetSharedPtr(), std::string(value.value())));
return Status::OK();
}
Status Visit(const LargeBinaryType& type) {
google::protobuf::BytesValue value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid(
"Failed to unpack user defined large_binary literal to BytesValue");
}
ARROW_ASSIGN_OR_RAISE(scalar_,
MakeScalar(type.GetSharedPtr(), std::string(value.value())));
return Status::OK();
}
Status operator()(const DataType& type) { return VisitTypeInline(type, this); }

std::shared_ptr<Scalar> scalar_;
const substrait::Expression::Literal::UserDefined* user_defined_;
const ExtensionSet* ext_set_;
const ConversionOptions& conversion_options_;
};
} // namespace

Result<Datum> FromProto(const substrait::Expression::Literal& lit,
const ExtensionSet& ext_set,
const ConversionOptions& conversion_options) {
Expand Down Expand Up @@ -455,14 +541,25 @@ Result<Datum> FromProto(const substrait::Expression::Literal& lit,
case substrait::Expression::Literal::kBinary:
return Datum(BinaryScalar(lit.binary()));

ARROW_SUPPRESS_DEPRECATION_WARNING
case substrait::Expression::Literal::kTimestamp:
return Datum(
TimestampScalar(static_cast<int64_t>(lit.timestamp()), TimeUnit::MICRO));

case substrait::Expression::Literal::kTimestampTz:
return Datum(TimestampScalar(static_cast<int64_t>(lit.timestamp_tz()),
TimeUnit::MICRO, TimestampTzTimezoneString()));

ARROW_UNSUPPRESS_DEPRECATION_WARNING
case substrait::Expression::Literal::kPrecisionTimestamp: {
// https://github.com/substrait-io/substrait/issues/611
// TODO(GH-40741) don't break, return precision timestamp
break;
}
case substrait::Expression::Literal::kPrecisionTimestampTz: {
// https://github.com/substrait-io/substrait/issues/611
// TODO(GH-40741) don't break, return precision timestamp
break;
}
case substrait::Expression::Literal::kDate:
return Datum(Date32Scalar(lit.date()));
case substrait::Expression::Literal::kTime:
Expand Down Expand Up @@ -674,18 +771,30 @@ Result<Datum> FromProto(const substrait::Expression::Literal& lit,
return Datum(MakeNullScalar(std::move(type_nullable.first)));
}

case substrait::Expression::Literal::kUserDefined: {
const auto& user_defined = lit.user_defined();
ARROW_ASSIGN_OR_RAISE(auto type_record,
ext_set.DecodeType(user_defined.type_reference()));
UserDefinedLiteralToArrow visitor{nullptr, &user_defined, &ext_set,
conversion_options};
ARROW_RETURN_NOT_OK((visitor)(*type_record.type));
return Datum(std::move(visitor.scalar_));
}

case substrait::Expression::Literal::LITERAL_TYPE_NOT_SET:
return Status::Invalid("substrait literal did not have any literal type set");

default:
break;
}

return Status::NotImplemented("conversion to arrow::Datum from Substrait literal ",
lit.DebugString());
return Status::NotImplemented("conversion to arrow::Datum from Substrait literal `",
lit.DebugString(), "`");
}

namespace {
struct ScalarToProtoImpl {
Status Visit(const NullScalar& s) { return NotImplemented(s); }

struct ScalarToProtoImpl {
using Lit = substrait::Expression::Literal;

template <typename Arg, typename PrimitiveScalar>
Expand All @@ -702,61 +811,141 @@ struct ScalarToProtoImpl {
return Status::OK();
}

Status EncodeUserDefined(const DataType& data_type,
const google::protobuf::Message& value) {
ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(data_type));
auto user_defined = std::make_unique<Lit::UserDefined>();
user_defined->set_type_reference(anchor);
auto value_any = std::make_unique<google::protobuf::Any>();
value_any->PackFrom(value);
user_defined->set_allocated_value(value_any.release());
lit_->set_allocated_user_defined(user_defined.release());
return Status::OK();
}

Status Visit(const NullScalar& s) {
ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(*s.type));
auto user_defined = std::make_unique<Lit::UserDefined>();
user_defined->set_type_reference(anchor);
lit_->set_allocated_user_defined(user_defined.release());
return Status::OK();
}
Status Visit(const BooleanScalar& s) { return Primitive(&Lit::set_boolean, s); }

Status Visit(const Int8Scalar& s) { return Primitive(&Lit::set_i8, s); }
Status Visit(const Int16Scalar& s) { return Primitive(&Lit::set_i16, s); }
Status Visit(const Int32Scalar& s) { return Primitive(&Lit::set_i32, s); }
Status Visit(const Int64Scalar& s) { return Primitive(&Lit::set_i64, s); }

Status Visit(const UInt8Scalar& s) { return NotImplemented(s); }
Status Visit(const UInt16Scalar& s) { return NotImplemented(s); }
Status Visit(const UInt32Scalar& s) { return NotImplemented(s); }
Status Visit(const UInt64Scalar& s) { return NotImplemented(s); }

Status Visit(const HalfFloatScalar& s) { return NotImplemented(s); }
Status Visit(const UInt8Scalar& s) {
google::protobuf::UInt64Value value;
value.set_value(s.value);
return EncodeUserDefined(*s.type, value);
}
Status Visit(const UInt16Scalar& s) {
google::protobuf::UInt64Value value;
value.set_value(s.value);
return EncodeUserDefined(*s.type, value);
}
Status Visit(const UInt32Scalar& s) {
google::protobuf::UInt64Value value;
value.set_value(s.value);
return EncodeUserDefined(*s.type, value);
}
Status Visit(const UInt64Scalar& s) {
google::protobuf::UInt64Value value;
value.set_value(s.value);
return EncodeUserDefined(*s.type, value);
}
Status Visit(const HalfFloatScalar& s) {
google::protobuf::UInt32Value value;
value.set_value(s.value);
return EncodeUserDefined(*s.type, value);
}
Status Visit(const FloatScalar& s) { return Primitive(&Lit::set_fp32, s); }
Status Visit(const DoubleScalar& s) { return Primitive(&Lit::set_fp64, s); }

Status Visit(const StringScalar& s) {
return FromBuffer([](Lit* lit, std::string&& s) { lit->set_string(std::move(s)); },
s);
}
Status Visit(const StringViewScalar& s) {
return FromBuffer([](Lit* lit, std::string&& s) { lit->set_string(std::move(s)); },
s);
}
Status Visit(const BinaryScalar& s) {
return FromBuffer([](Lit* lit, std::string&& s) { lit->set_binary(std::move(s)); },
s);
}

Status Visit(const BinaryViewScalar& s) { return NotImplemented(s); }
Status Visit(const BinaryViewScalar& s) {
return FromBuffer([](Lit* lit, std::string&& s) { lit->set_binary(std::move(s)); },
s);
}

Status Visit(const FixedSizeBinaryScalar& s) {
return FromBuffer(
[](Lit* lit, std::string&& s) { lit->set_fixed_binary(std::move(s)); }, s);
}

Status Visit(const Date32Scalar& s) { return Primitive(&Lit::set_date, s); }
Status Visit(const Date64Scalar& s) { return NotImplemented(s); }
Status Visit(const Date64Scalar& s) {
google::protobuf::Int64Value value;
value.set_value(s.value);
return EncodeUserDefined(*s.type, value);
}

Status Visit(const TimestampScalar& s) {
const auto& t = checked_cast<const TimestampType&>(*s.type);

if (t.unit() != TimeUnit::MICRO) return NotImplemented(s);
uint64_t micros;
switch (t.unit()) {
case TimeUnit::SECOND:
micros = s.value * kMicrosPerSecond;
break;
case TimeUnit::MILLI:
micros = s.value * kMicrosPerMilli;
break;
case TimeUnit::MICRO:
micros = s.value;
break;
case TimeUnit::NANO:
// TODO(GH-40741): can support nanos when
// https://github.com/substrait-io/substrait/issues/611 is resolved
return NotImplemented(s);
default:
return NotImplemented(s);
}

if (t.timezone() == "") return Primitive(&Lit::set_timestamp, s);
// Remove these and use precision timestamp once
// https://github.com/substrait-io/substrait/issues/611 is resolved
ARROW_SUPPRESS_DEPRECATION_WARNING

if (t.timezone() == TimestampTzTimezoneString()) {
return Primitive(&Lit::set_timestamp_tz, s);
if (t.timezone() == "") {
lit_->set_timestamp(micros);
} else {
// Some loss of info here, Substrait doesn't store timezone
// in field data
lit_->set_timestamp_tz(micros);
}
ARROW_UNSUPPRESS_DEPRECATION_WARNING

return NotImplemented(s);
return Status::OK();
}

Status Visit(const Time32Scalar& s) { return NotImplemented(s); }
// Need to support parameterized UDTs
Status Visit(const Time32Scalar& s) {
google::protobuf::Int32Value value;
value.set_value(s.value);
return EncodeUserDefined(*s.type, value);
}
Status Visit(const Time64Scalar& s) {
if (checked_cast<const Time64Type&>(*s.type).unit() != TimeUnit::MICRO) {
return NotImplemented(s);
if (checked_cast<const Time64Type&>(*s.type).unit() == TimeUnit::MICRO) {
return Primitive(&Lit::set_time, s);
} else {
google::protobuf::Int64Value value;
value.set_value(s.value);
return EncodeUserDefined(*s.type, value);
}
return Primitive(&Lit::set_time, s);
}

Status Visit(const MonthIntervalScalar& s) { return NotImplemented(s); }
Expand All @@ -778,9 +967,10 @@ struct ScalarToProtoImpl {
return Status::OK();
}

// Need support for parameterized UDTs
Status Visit(const Decimal256Scalar& s) { return NotImplemented(s); }

Status Visit(const ListScalar& s) {
Status Visit(const BaseListScalar& s) {
if (s.value->length() == 0) {
ARROW_ASSIGN_OR_RAISE(auto list_type, ToProto(*s.type, /*nullable=*/true, ext_set_,
conversion_options_));
Expand All @@ -807,10 +997,6 @@ struct ScalarToProtoImpl {
return Status::OK();
}

Status Visit(const ListViewScalar& s) {
return Status::NotImplemented("list-view to proto");
}

Status Visit(const LargeListViewScalar& s) {
return Status::NotImplemented("list-view to proto");
}
Expand All @@ -830,7 +1016,10 @@ struct ScalarToProtoImpl {

Status Visit(const SparseUnionScalar& s) { return NotImplemented(s); }
Status Visit(const DenseUnionScalar& s) { return NotImplemented(s); }
Status Visit(const DictionaryScalar& s) { return NotImplemented(s); }
Status Visit(const DictionaryScalar& s) {
ARROW_ASSIGN_OR_RAISE(auto encoded, s.GetEncodedValue());
return (*this)(*encoded);
}

Status Visit(const MapScalar& s) {
if (s.value->length() == 0) {
Expand Down Expand Up @@ -914,10 +1103,21 @@ struct ScalarToProtoImpl {
return NotImplemented(s);
}

// Need support for parameterized UDTs
Status Visit(const FixedSizeListScalar& s) { return NotImplemented(s); }
Status Visit(const DurationScalar& s) { return NotImplemented(s); }
Status Visit(const LargeStringScalar& s) { return NotImplemented(s); }
Status Visit(const LargeBinaryScalar& s) { return NotImplemented(s); }

Status Visit(const LargeStringScalar& s) {
google::protobuf::StringValue value;
value.set_value(s.view().data(), s.view().size());
return EncodeUserDefined(*s.type, value);
}
Status Visit(const LargeBinaryScalar& s) {
google::protobuf::BytesValue value;
value.set_value(s.view().data(), s.view().size());
return EncodeUserDefined(*s.type, value);
}
// Need support for parameterized UDTs
Status Visit(const LargeListScalar& s) { return NotImplemented(s); }
Status Visit(const MonthDayNanoIntervalScalar& s) { return NotImplemented(s); }

Expand Down
Loading

0 comments on commit 7f64fff

Please sign in to comment.