Skip to content

Commit

Permalink
Addressing review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
westonpace committed Apr 8, 2024
1 parent 1937f2e commit 8321e64
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 75 deletions.
56 changes: 42 additions & 14 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -435,23 +435,44 @@ struct UserDefinedLiteralToArrow {
Status Visit(const IntegerType& type) {
google::protobuf::UInt64Value value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid("Failed to unpack user defined integer literal to uint64");
return Status::Invalid(
"Failed to unpack user defined integer literal to UInt64Value");
}
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value()));
return Status::OK();
}
Status Visit(const Time32Type& type) {
google::protobuf::Int32Value value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid(
"Failed to unpack user defined time32 literal to Int32Value");
}
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value()));
return Status::OK();
}
Status Visit(const Time64Type& type) {
google::protobuf::Int64Value value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid(
"Failed to unpack user defined time64 literal to Int64Value");
}
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value()));
return Status::OK();
}
Status Visit(const Date64Type& type) {
google::protobuf::UInt64Value value;
google::protobuf::Int64Value value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid("Failed to unpack user defined date64 literal to uint64");
return Status::Invalid(
"Failed to unpack user defined date64 literal to Int64Value");
}
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value()));
return Status::OK();
}
Status Visit(const HalfFloatType& type) {
google::protobuf::UInt32Value value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid("Failed to unpack user defined half_float literal to bytes");
return Status::Invalid(
"Failed to unpack user defined half_float literal to UInt32Value");
}
uint16_t half_float_value = value.value();
ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), half_float_value));
Expand All @@ -461,7 +482,7 @@ struct UserDefinedLiteralToArrow {
google::protobuf::StringValue value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid(
"Failed to unpack user defined large_string literal to string");
"Failed to unpack user defined large_string literal to StringValue");
}
ARROW_ASSIGN_OR_RAISE(scalar_,
MakeScalar(type.GetSharedPtr(), std::string(value.value())));
Expand All @@ -471,7 +492,7 @@ struct UserDefinedLiteralToArrow {
google::protobuf::BytesValue value;
if (!user_defined_->value().UnpackTo(&value)) {
return Status::Invalid(
"Failed to unpack user defined large_binary literal to bytes");
"Failed to unpack user defined large_binary literal to BytesValue");
}
ARROW_ASSIGN_OR_RAISE(scalar_,
MakeScalar(type.GetSharedPtr(), std::string(value.value())));
Expand Down Expand Up @@ -531,12 +552,12 @@ Result<Datum> FromProto(const substrait::Expression::Literal& lit,
ARROW_UNSUPPRESS_DEPRECATION_WARNING
case substrait::Expression::Literal::kPrecisionTimestamp: {
// https://github.com/substrait-io/substrait/issues/611
// TODO(weston) don't break, return precision timestamp
// TODO(GH-40741) don't break, return precision timestamp
break;
}
case substrait::Expression::Literal::kPrecisionTimestampTz: {
// https://github.com/substrait-io/substrait/issues/611
// TODO(weston) don't break, return precision timestamp
// TODO(GH-40741) don't break, return precision timestamp
break;
}
case substrait::Expression::Literal::kDate:
Expand Down Expand Up @@ -868,7 +889,7 @@ struct ScalarToProtoImpl {

Status Visit(const Date32Scalar& s) { return Primitive(&Lit::set_date, s); }
Status Visit(const Date64Scalar& s) {
google::protobuf::UInt64Value value;
google::protobuf::Int64Value value;
value.set_value(s.value);
return EncodeUserDefined(*s.type, value);
}
Expand All @@ -888,7 +909,7 @@ struct ScalarToProtoImpl {
micros = s.value;
break;
case TimeUnit::NANO:
// TODO(weston): can support nanos when
// TODO(GH-40741): can support nanos when
// https://github.com/substrait-io/substrait/issues/611 is resolved
return NotImplemented(s);
default:
Expand All @@ -912,12 +933,19 @@ struct ScalarToProtoImpl {
}

// Need to support parameterized UDTs
Status Visit(const Time32Scalar& s) { return NotImplemented(s); }
Status Visit(const Time32Scalar& s) {
google::protobuf::Int32Value value;
value.set_value(s.value);
return EncodeUserDefined(*s.type, value);
}
Status Visit(const Time64Scalar& s) {
if (checked_cast<const Time64Type&>(*s.type).unit() != TimeUnit::MICRO) {
return NotImplemented(s);
if (checked_cast<const Time64Type&>(*s.type).unit() == TimeUnit::MICRO) {
return Primitive(&Lit::set_time, s);
} else {
google::protobuf::Int64Value value;
value.set_value(s.value);
return EncodeUserDefined(*s.type, value);
}
return Primitive(&Lit::set_time, s);
}

Status Visit(const MonthIntervalScalar& s) { return NotImplemented(s); }
Expand Down
10 changes: 3 additions & 7 deletions cpp/src/arrow/engine/substrait/extension_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,6 @@ Result<uint32_t> ExtensionSet::EncodeType(const DataType& type) {
return Status::KeyError("type ", type.ToString(), " not found in the registry");
}

Result<uint32_t> ExtensionSet::EncodeTypeId(Id type_id) {
RETURN_NOT_OK(this->AddUri(type_id));
auto [it, success] = types_map_.emplace(type_id, static_cast<uint32_t>(types_map_.size()));
return it->second;
}

Result<Id> ExtensionSet::DecodeFunction(uint32_t anchor) const {
if (functions_.find(anchor) == functions_.end() || functions_.at(anchor).empty()) {
return Status::Invalid("User defined function reference ", anchor,
Expand Down Expand Up @@ -1041,7 +1035,7 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
};

// The type (variation) mappings listed below need to be kept in sync
// with the YAML at substrait/format/extension_types.yaml manually;
// with the YAML at format/substrait/extension_types.yaml manually;
// see ARROW-15535.
for (TypeName e : {
TypeName{uint8(), "u8"},
Expand All @@ -1051,6 +1045,8 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
TypeName{float16(), "fp16"},
TypeName{large_utf8(), "large_string"},
TypeName{large_binary(), "large_binary"},
TypeName{time32(TimeUnit::SECOND), "time_seconds"},
TypeName{time32(TimeUnit::MILLI), "time_millis"},
TypeName{date64(), "date_millis"},
TypeName{time64(TimeUnit::NANO), "time_nanos"},
}) {
Expand Down
7 changes: 0 additions & 7 deletions cpp/src/arrow/engine/substrait/extension_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,13 +412,6 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
/// \return An anchor that can be used to refer to the type within a plan
Result<uint32_t> EncodeType(const DataType& type);

/// \brief Lookup the anchor for a given type alias
///
/// Similar to \see EncodeType but this is used for cases where the data type is either
/// parameterized or custom in some way (e.g. we use this for Time64::Nanos). We need
/// to use the Id directly since we can't have registered the type with the registry.
Result<uint32_t> EncodeTypeId(Id type_id);

/// \brief Return a function id given an anchor
///
/// This is used when converting a Substrait plan to an Arrow execution plan.
Expand Down
15 changes: 13 additions & 2 deletions cpp/src/arrow/engine/substrait/serde_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ TEST(Substrait, SupportedExtensionTypes) {
large_utf8(),
large_binary(),
date64(),
time32(TimeUnit::SECOND),
time32(TimeUnit::MILLI),
time64(TimeUnit::NANO),
}) {
auto anchor = ext_set.num_types();
Expand Down Expand Up @@ -474,8 +476,6 @@ TEST(Substrait, NoEquivalentSubstraitType) {
dense_union({field("i8", int8()), field("f32", float32())}),
fixed_size_list(float16(), 3),
duration(TimeUnit::MICRO),
time32(TimeUnit::MILLI),
time32(TimeUnit::SECOND),
large_list(utf8()),
}) {
ARROW_SCOPED_TRACE(type->ToString());
Expand Down Expand Up @@ -619,6 +619,17 @@ TEST(Substrait, ArrowSpecificLiterals) {
CheckArrowSpecificLiteral(UInt32Scalar(7));
CheckArrowSpecificLiteral(UInt64Scalar(7));
CheckArrowSpecificLiteral(Date64Scalar(86400000));
CheckArrowSpecificLiteral(Time64Scalar(7, TimeUnit::NANO));
CheckArrowSpecificLiteral(Time32Scalar(7, TimeUnit::SECOND));
CheckArrowSpecificLiteral(Time32Scalar(7, TimeUnit::MILLI));
CheckArrowSpecificLiteral(Time32Scalar(7, TimeUnit::SECOND));
// We serialize as a signed integer, which doesn't make sense for Time scalars but
// Arrow supports it so we might as well round-trip it.
CheckArrowSpecificLiteral(Time32Scalar(-7, TimeUnit::MILLI));
CheckArrowSpecificLiteral(Time32Scalar(-7, TimeUnit::SECOND));
CheckArrowSpecificLiteral(Time64Scalar(-7, TimeUnit::NANO));
// Negative date scalars DO make sense and we should make sure they work
CheckArrowSpecificLiteral(Date64Scalar(-86400000));
CheckArrowSpecificLiteral(HalfFloatScalar(0));
CheckArrowSpecificLiteral(LargeStringScalar("hello"));
CheckArrowSpecificLiteral(LargeBinaryScalar("hello"));
Expand Down
18 changes: 6 additions & 12 deletions cpp/src/arrow/engine/substrait/type_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,7 @@ struct DataTypeToProtoImpl {
return Status::OK();
}

Status Visit(const Time32Type& t) {
// TODO(weston)
// Unsupported for the same reason we don't support parameterized types
// which is that the extension registry only supports encoding one type
// per type id
return NotImplemented(t);
}
Status Visit(const Time32Type& t) { return EncodeUserDefined(t); }
Status Visit(const Time64Type& t) {
if (t.unit() == TimeUnit::MICRO) {
return SetWith(&substrait::Type::set_allocated_time);
Expand All @@ -369,7 +363,7 @@ struct DataTypeToProtoImpl {
dec->set_scale(t.scale());
return Status::OK();
}
// TODO(weston) support parameterized UDT
// TODO(GH-40740) support parameterized UDT
Status Visit(const Decimal256Type& t) { return NotImplemented(t); }

Status Visit(const ListType& t) {
Expand All @@ -389,7 +383,7 @@ struct DataTypeToProtoImpl {
return Status::OK();
}

// TODO(weston) support parameterized UDT
// TODO(GH-40740) support parameterized UDT
Status Visit(const LargeListViewType& t) { return NotImplemented(t); }

Status Visit(const StructType& t) {
Expand Down Expand Up @@ -455,13 +449,13 @@ struct DataTypeToProtoImpl {
return NotImplemented(t);
}

// TODO(weston) support parameterized UDT
// TODO(GH-40740) support parameterized UDT
Status Visit(const FixedSizeListType& t) { return NotImplemented(t); }
// TODO(weston) support parameterized UDT
// TODO(GH-40740) support parameterized UDT
Status Visit(const DurationType& t) { return NotImplemented(t); }
Status Visit(const LargeStringType& t) { return EncodeUserDefined(t); }
Status Visit(const LargeBinaryType& t) { return EncodeUserDefined(t); }
// TODO(weston) support parameterized UDT
// TODO(GH-40740) support parameterized UDT
Status Visit(const LargeListType& t) { return NotImplemented(t); }
Status Visit(const MonthDayNanoIntervalType& t) { return EncodeUserDefined(t); }

Expand Down
36 changes: 17 additions & 19 deletions format/substrait/extension_types.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,21 @@
# - interval<months: i32, days: i32, nanos: i64>
# - arrow::ExtensionTypes
#
# Note that not all of these are currently implemented. In particular, these
# extension types are currently not parameterizable in Substrait, which means
# among other things that we can't declare dictionary type here at all since
# we'd have to declare a different dictionary type for all encoded types
# (but that is an infinite space). Similarly, we would have to declare a
# timestamp variation for all possible timezone strings.
# These types fall into several categories of behavior:

# Certain Arrow data types are, from Substrait's point of view, encodings.
# These include dictionary, the view types (e.g. binary view, list view),
# and REE.
#
# These types are not logically distinct from the type they are encoding.
# Specifically:
# * There is no column of the decoded type that cannot be represented
# as a column the encoded type and vice versa.
# Specifically, the types meet the following criteria:
# * There is no value in the decoded type that cannot be represented
# as a value in the encoded type and vice versa.
# * Functions have the same meaning when applied to the encoded type
#
#
# Note: if two types have a different range (e.g. string and large_string) then
# they do not satisfy the above criteria and are not encodings.
#
# These types will never have a Substrait equivalent. In the Substrait point
# of view these are execution details.

Expand All @@ -70,7 +68,6 @@
# large_list
# fixed_size_list
# duration
# time32 - not technically a parameterized type, but unsupported for similar reasons

# Other types are not encodings, but are not first-class in Substrait. These
# types are often similar to existing Substrait types but define a different range
Expand Down Expand Up @@ -99,15 +96,15 @@ types:
months: i32
days: i32
nanos: i64
# All signed integer literals are encoded as user defined literals with
# All unsigned integer literals are encoded as user defined literals with
# a google.protobuf.UInt64Value message.
- name: i8
- name: u8
structure: {}
- name: i16
- name: u16
structure: {}
- name: i32
- name: u32
structure: {}
- name: i64
- name: u64
structure: {}
# fp16 literals are encoded as user defined literals with
# a google.protobuf.UInt32Value message where the lower 16 bits are
Expand All @@ -119,11 +116,12 @@ types:
# different range of values, it is an extension type.
#
# date64 literals are encoded as user defined literals with
# a google.protobuf.UInt64Value message.
# a google.protobuf.Int64Value message.
- name: date_millis
structure: {}
# We cannot generate `time` today for reasons similar to parameterized
# UDTs.
# time literals are encoded as user defined literals with
# a google.protobuf.Int32Value message (for time_seconds/time_millis)
# or a google.protobuf.Int64Value message (for time_nanos).
- name: time_seconds
structure: {}
- name: time_millis
Expand Down
43 changes: 29 additions & 14 deletions python/pyarrow/tests/test_substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,28 +945,43 @@ def test_serializing_expressions(expr):


def test_arrow_specific_types():
fields = {
"time_seconds": (pa.time32("s"), 0),
"time_millis": (pa.time32("ms"), 0),
"time_nanos": (pa.time64("ns"), 0),
"date_millis": (pa.date64(), 0),
"large_string": (pa.large_string(), "test_string"),
"large_binary": (pa.large_binary(), b"test_string"),
}
schema = pa.schema([pa.field(name, typ) for name, (typ, _) in fields.items()])

def check_round_trip(expr):
buf = pa.substrait.serialize_expressions([expr], ["test_expr"], schema)
returned = pa.substrait.deserialize_expressions(buf)
assert schema == returned.schema

for name, (typ, val) in fields.items():
check_round_trip(pc.field(name) == pa.scalar(val, type=typ))


def test_arrow_one_way_types():
schema = pa.schema(
[
pa.field("time_nanos", pa.time64("ns")),
pa.field("date_millis", pa.date64()),
pa.field("large_string", pa.large_string()),
pa.field("large_binary", pa.large_binary()),
pa.field("binary_view", pa.binary_view()),
pa.field("string_view", pa.string_view()),
]
)
alt_schema = pa.schema(
[pa.field("binary_view", pa.binary()), pa.field("string_view", pa.string())]
)

def check_round_trip(expr):
def check_one_way(expr):
buf = pa.substrait.serialize_expressions([expr], ["test_expr"], schema)
returned = pa.substrait.deserialize_expressions(buf)
assert schema == returned.schema
assert alt_schema == returned.schema

check_round_trip(pc.field("large_string") == "test_string")
check_round_trip(pc.field("large_binary") == "test_string")
# Arrow-cpp supports round tripping these types but pyarrow doesn't support
# constructing literals of these types
#
# So best we can do is verify field references work
check_round_trip(pc.field("time_nanos"))
check_round_trip(pc.field("date_millis"))
check_one_way(pc.is_null(pc.field("binary_view")))
check_one_way(pc.is_null(pc.field("string_view")))


def test_invalid_expression_ser_des():
Expand Down

0 comments on commit 8321e64

Please sign in to comment.