Skip to content

Commit

Permalink
GH-43187: [C++] Support basic is_in predicate simplification (#43761)
Browse files Browse the repository at this point in the history
### Rationale for this change

Prior to #43256, this PR adds a basic implementation that does a linear scan filter over the value set on each guarantee. This isolates the correctness/semantics of `is_in` predicate simplification from the binary search performance optimization.

### What changes are included in this PR?

`SimplifyWithGuarantee` now handles `is_in` expressions.

### Are these changes tested?

A new unit test was added to arrow-compute-expression-test testing this change.

### Are there any user-facing changes?

No.
* GitHub Issue: #43187

Lead-authored-by: Larry Wang <[email protected]>
Co-authored-by: larry98 <[email protected]>
Co-authored-by: Benjamin Kietzman <[email protected]>
Signed-off-by: Benjamin Kietzman <[email protected]>
  • Loading branch information
larry98 and bkietz authored Sep 10, 2024
1 parent a87a8e0 commit 44b72d5
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 0 deletions.
73 changes: 73 additions & 0 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <unordered_set>

#include "arrow/chunked_array.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/compute/expression_internal.h"
Expand Down Expand Up @@ -1242,6 +1243,72 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}

/// Simplify an `is_in` call against an inequality guarantee.
///
/// We avoid the complexity of fully simplifying EQUAL comparisons to true
/// literals (e.g., 'x is_in [1, 2, 3]' given the guarantee 'x = 2') due to
/// potential complications with null matching behavior. This is ok for the
/// predicate pushdown use case because the overall aim is to simplify to an
/// unsatisfiable expression.
///
/// \pre `is_in_call` is a call to the `is_in` function
/// \return a simplified expression, or nullopt if no simplification occurred
static Result<std::optional<Expression>> SimplifyIsIn(
const Inequality& guarantee, const Expression::Call* is_in_call) {
DCHECK_EQ(is_in_call->function_name, "is_in");

auto options = checked_pointer_cast<SetLookupOptions>(is_in_call->options);

const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]);
if (!lhs.field_ref()) return std::nullopt;
if (*lhs.field_ref() != guarantee.target) return std::nullopt;

FilterOptions::NullSelectionBehavior null_selection;
switch (options->null_matching_behavior) {
case SetLookupOptions::MATCH:
null_selection =
guarantee.nullable ? FilterOptions::EMIT_NULL : FilterOptions::DROP;
break;
case SetLookupOptions::SKIP:
null_selection = FilterOptions::DROP;
break;
case SetLookupOptions::EMIT_NULL:
if (guarantee.nullable) return std::nullopt;
null_selection = FilterOptions::DROP;
break;
case SetLookupOptions::INCONCLUSIVE:
if (guarantee.nullable) return std::nullopt;
ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(options->value_set));
ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null));
if (any_null.scalar_as<BooleanScalar>().value) return std::nullopt;
null_selection = FilterOptions::DROP;
break;
}

std::string func_name = Comparison::GetName(guarantee.cmp);
DCHECK_NE(func_name, "na");
std::vector<Datum> args{options->value_set, guarantee.bound};
ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args));
FilterOptions filter_options(null_selection);
ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set,
Filter(options->value_set, filter_mask, filter_options));

if (simplified_value_set.length() == 0) return literal(false);
if (simplified_value_set.length() == options->value_set.length()) return std::nullopt;

ExecContext exec_context;
Expression::Call simplified_call;
simplified_call.function_name = "is_in";
simplified_call.arguments = is_in_call->arguments;
simplified_call.options = std::make_shared<SetLookupOptions>(
simplified_value_set, options->null_matching_behavior);
ARROW_ASSIGN_OR_RAISE(
Expression simplified_expr,
BindNonRecursive(std::move(simplified_call),
/*insert_implicit_casts=*/false, &exec_context));
return simplified_expr;
}

/// \brief Simplify the given expression given this inequality as a guarantee.
Result<Expression> Simplify(Expression expr) {
const auto& guarantee = *this;
Expand All @@ -1258,6 +1325,12 @@ struct Inequality {
return call->function_name == "is_valid" ? literal(true) : literal(false);
}

if (call->function_name == "is_in") {
ARROW_ASSIGN_OR_RAISE(std::optional<Expression> result,
SimplifyIsIn(guarantee, call));
return result.value_or(expr);
}

auto cmp = Comparison::Get(expr);
if (!cmp) return expr;

Expand Down
173 changes: 173 additions & 0 deletions cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "arrow/array/builder_primitive.h"
#include "arrow/compute/expression_internal.h"
#include "arrow/compute/function_internal.h"
#include "arrow/compute/registry.h"
Expand Down Expand Up @@ -1616,6 +1617,144 @@ TEST(Expression, SimplifyWithComparisonAndNullableCaveat) {
true_unless_null(field_ref("i32")))); // not satisfiable, will drop row group
}

TEST(Expression, SimplifyIsIn) {
auto is_in = [](Expression field, std::shared_ptr<DataType> value_set_type,
std::string json_array,
SetLookupOptions::NullMatchingBehavior null_matching_behavior) {
SetLookupOptions options{ArrayFromJSON(value_set_type, json_array),
null_matching_behavior};
return call("is_in", {field}, options);
};

for (SetLookupOptions::NullMatchingBehavior null_matching : {
SetLookupOptions::MATCH,
SetLookupOptions::SKIP,
SetLookupOptions::EMIT_NULL,
SetLookupOptions::INCONCLUSIVE,
}) {
Simplify{is_in(field_ref("i32"), int32(), "[]", null_matching)}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(false);

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(equal(field_ref("i32"), literal(6)))
.Expect(false);

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("i32"), literal(3)))
.Expect(is_in(field_ref("i32"), int32(), "[5,7,9]", null_matching));

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("i32"), literal(9)))
.Expect(false);

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(less_equal(field_ref("i32"), literal(0)))
.Expect(false);

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("i32"), literal(0)))
.ExpectUnchanged();

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(less_equal(field_ref("i32"), literal(9)))
.ExpectUnchanged();

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(and_(less_equal(field_ref("i32"), literal(7)),
greater(field_ref("i32"), literal(4))))
.Expect(is_in(field_ref("i32"), int32(), "[5,7]", null_matching));

Simplify{is_in(field_ref("u32"), int8(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("u32"), literal(3)))
.Expect(is_in(field_ref("u32"), int8(), "[5,7,9]", null_matching));

Simplify{is_in(field_ref("u32"), int64(), "[1,3,5,7,9]", null_matching)}
.WithGuarantee(greater(field_ref("u32"), literal(3)))
.Expect(is_in(field_ref("u32"), int64(), "[5,7,9]", null_matching));
}

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::MATCH),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::MATCH));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::MATCH),
}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::MATCH));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::MATCH),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3,null]", SetLookupOptions::MATCH));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::SKIP),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::SKIP),
}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::SKIP),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::EMIT_NULL),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.ExpectUnchanged();

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::EMIT_NULL),
}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::EMIT_NULL));

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::EMIT_NULL),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.ExpectUnchanged();

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::INCONCLUSIVE),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.ExpectUnchanged();

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::INCONCLUSIVE),
}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.ExpectUnchanged();

Simplify{
is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::INCONCLUSIVE),
}
.WithGuarantee(
or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32"))))
.ExpectUnchanged();
}

TEST(Expression, SimplifyThenExecute) {
auto filter =
or_({equal(field_ref("f32"), literal(0)),
Expand Down Expand Up @@ -1643,6 +1782,40 @@ TEST(Expression, SimplifyThenExecute) {
AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true);
}

TEST(Expression, SimplifyIsInThenExecute) {
auto input = RecordBatchFromJSON(kBoringSchema, R"([
{"i64": 2, "i32": 5},
{"i64": 5, "i32": 6},
{"i64": 3, "i32": 6},
{"i64": 3, "i32": 5},
{"i64": 4, "i32": 5},
{"i64": 2, "i32": 7},
{"i64": 5, "i32": 5}
])");

std::vector<Expression> guarantees{greater(field_ref("i64"), literal(1)),
greater_equal(field_ref("i32"), literal(5)),
less_equal(field_ref("i64"), literal(5))};

for (const Expression& guarantee : guarantees) {
auto filter =
call("is_in", {guarantee.call()->arguments[0]},
compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true});
ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema));
ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(filter, guarantee));

Datum evaluated, simplified_evaluated;
ExpectExecute(filter, input, &evaluated);
ExpectExecute(simplified, input, &simplified_evaluated);
if (simplified_evaluated.is_scalar()) {
ASSERT_OK_AND_ASSIGN(
simplified_evaluated,
MakeArrayFromScalar(*simplified_evaluated.scalar(), evaluated.length()));
}
AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true);
}
}

TEST(Expression, Filter) {
auto ExpectFilter = [](Expression filter, std::string batch_json) {
ASSERT_OK_AND_ASSIGN(auto s, kBoringSchema->AddField(0, field("in", boolean())));
Expand Down

0 comments on commit 44b72d5

Please sign in to comment.