Skip to content

Commit

Permalink
Implement IN and NOT IN (#1344)
Browse files Browse the repository at this point in the history
  • Loading branch information
joka921 authored Jun 28, 2024
1 parent 18e2fd6 commit 797f325
Show file tree
Hide file tree
Showing 9 changed files with 367 additions and 57 deletions.
169 changes: 131 additions & 38 deletions src/engine/sparqlExpressions/RelationalExpressions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "engine/sparqlExpressions/LangExpression.h"
#include "engine/sparqlExpressions/LiteralExpression.h"
#include "engine/sparqlExpressions/NaryExpression.h"
#include "engine/sparqlExpressions/RelationalExpressionHelpers.h"
#include "engine/sparqlExpressions/SparqlExpressionGenerators.h"
#include "util/LambdaHelpers.h"
Expand All @@ -15,6 +16,10 @@ using namespace sparqlExpression;

namespace {

constexpr int reductionFactorEquals = 1000;
constexpr int reductionFactorNotEquals = 1;
constexpr int reductionFactorDefault = 50;

using valueIdComparators::Comparison;

// Several concepts used to choose the proper evaluation methods for different
Expand All @@ -29,8 +34,8 @@ using valueIdComparators::Comparison;
// First the `idGenerator` for constants (string, int, double). It yields the
// same ID `targetSize` many times.
template <SingleExpressionResult S>
requires isConstantResult<S>
auto idGenerator(S value, size_t targetSize, const EvaluationContext* context)
requires isConstantResult<S> auto idGenerator(const S& value, size_t targetSize,
const EvaluationContext* context)
-> cppcoro::generator<const decltype(makeValueId(value, context))> {
auto id = makeValueId(value, context);
for (size_t i = 0; i < targetSize; ++i) {
Expand All @@ -42,8 +47,8 @@ auto idGenerator(S value, size_t targetSize, const EvaluationContext* context)
// equal to `targetSize` and the yields the corresponding ID for each of the
// elements in the vector.
template <SingleExpressionResult S>
requires isVectorResult<S>
auto idGenerator(S values, size_t targetSize, const EvaluationContext* context)
requires isVectorResult<S> auto idGenerator(const S& values, size_t targetSize,
const EvaluationContext* context)
-> cppcoro::generator<decltype(makeValueId(values[0], context))> {
AD_CONTRACT_CHECK(targetSize == values.size());
for (const auto& el : values) {
Expand All @@ -52,12 +57,12 @@ auto idGenerator(S values, size_t targetSize, const EvaluationContext* context)
}
}

// For the `Variable` class, the generator from the `sparqlExpressions` module
// already yields the `ValueIds`.
auto idGenerator(Variable variable, size_t targetSize,
const EvaluationContext* context) {
return sparqlExpression::detail::makeGenerator(std::move(variable),
targetSize, context);
// For the `Variable` and `SetOfIntervals` class, the generator from the
// `sparqlExpressions` module already yields the `ValueIds`.
template <ad_utility::SimilarToAny<Variable, ad_utility::SetOfIntervals> S>
auto idGenerator(S input, size_t targetSize, const EvaluationContext* context) {
return sparqlExpression::detail::makeGenerator(std::move(input), targetSize,
context);
}

// Return a pair of generators that generate the values from `value1` and
Expand All @@ -67,16 +72,16 @@ auto idGenerator(Variable variable, size_t targetSize,
// both inputs. Else the "plain" generators from `sparqlExpression::detail` are
// returned. These simply yield the values unchanged.
template <SingleExpressionResult S1, SingleExpressionResult S2>
auto getGenerators(S1 value1, S2 value2, size_t targetSize,
auto getGenerators(S1&& value1, S2&& value2, size_t targetSize,
const EvaluationContext* context) {
if constexpr (StoresValueId<S1> || StoresValueId<S2>) {
return std::pair{idGenerator(std::move(value1), targetSize, context),
idGenerator(std::move(value2), targetSize, context)};
return std::pair{idGenerator(AD_FWD(value1), targetSize, context),
idGenerator(AD_FWD(value2), targetSize, context)};
} else {
return std::pair{sparqlExpression::detail::makeGenerator(
std::move(value1), targetSize, context),
AD_FWD(value1), targetSize, context),
sparqlExpression::detail::makeGenerator(
std::move(value2), targetSize, context)};
AD_FWD(value2), targetSize, context)};
}
}

Expand Down Expand Up @@ -128,7 +133,7 @@ ad_utility::SetOfIntervals evaluateWithBinarySearch(
// supported and not always false.
template <Comparison Comp, SingleExpressionResult S1, SingleExpressionResult S2>
requires AreComparable<S1, S2> ExpressionResult evaluateRelationalExpression(
S1 value1, S2 value2, const EvaluationContext* context) {
S1&& value1, S2&& value2, const EvaluationContext* context) {
auto resultSize =
sparqlExpression::detail::getResultSize(*context, value1, value2);
constexpr static bool resultIsConstant =
Expand Down Expand Up @@ -171,7 +176,7 @@ requires AreComparable<S1, S2> ExpressionResult evaluateRelationalExpression(
}

auto [generatorA, generatorB] =
getGenerators(std::move(value1), std::move(value2), resultSize, context);
getGenerators(AD_FWD(value1), AD_FWD(value2), resultSize, context);
auto itA = generatorA.begin();
auto itB = generatorB.begin();

Expand Down Expand Up @@ -226,9 +231,9 @@ Id evaluateRelationalExpression(const A&, const B&, const EvaluationContext*) {
template <Comparison Comp, SingleExpressionResult A, SingleExpressionResult B>
requires(!AreComparable<A, B> && AreComparable<B, A>)
ExpressionResult evaluateRelationalExpression(
A a, B b, const EvaluationContext* context) {
A&& a, B&& b, const EvaluationContext* context) {
return evaluateRelationalExpression<getComparisonForSwappedArguments(Comp)>(
std::move(b), std::move(a), context);
AD_FWD(b), AD_FWD(a), context);
}

} // namespace
Expand Down Expand Up @@ -304,40 +309,128 @@ RelationalExpression<Comp>::getLanguageFilterExpression() const {
return getLangFilterData(child2, child1);
}
}
namespace {
// _____________________________________________________________________________
SparqlExpression::Estimates getEstimatesForFilterExpressionImpl(
uint64_t inputSizeEstimate, uint64_t reductionFactor, const auto& children,
const std::optional<Variable>& firstSortedVariable) {
AD_CORRECTNESS_CHECK(children.size() >= 1);
// For the binary expressions `=` `<=`, etc., we have exactly two children, so
// the following line is a noop. For the `IN` expression we expect to have
// more results if we have more arguments on the right side that can possibly
// match, so we reduce the `reductionFactor`.
reductionFactor /= children.size() - 1;
auto sizeEstimate = inputSizeEstimate / reductionFactor;

// By default, we have to linearly scan over the input and write the output.
size_t costEstimate = inputSizeEstimate + sizeEstimate;

// Returns true iff `left` is a variable by which the input is sorted, and
// `right` is a constant.
auto canBeEvaluatedWithBinarySearch =
[&firstSortedVariable](const SparqlExpression::Ptr& left,
const SparqlExpression::Ptr& right) {
auto varPtr = dynamic_cast<const VariableExpression*>(left.get());
return varPtr && varPtr->value() == firstSortedVariable &&
right->isConstantExpression();
};
// TODO<joka921> This check has to be more complex once we support proper
// filtering on the `LocalVocab`.
// Check iff all the pairs `(children[0], someOtherChild)` can be evaluated
// using binary search.
if (std::ranges::all_of(children | std::views::drop(1),
[&lhs = children.at(0),
&canBeEvaluatedWithBinarySearch](const auto& child) {
// The implementation automatically chooses the
// cheaper direction, so we can do the same when
// estimating the cost.
return canBeEvaluatedWithBinarySearch(lhs, child) ||
canBeEvaluatedWithBinarySearch(child, lhs);
})) {
// When evaluating via binary search, the only significant cost that occurs
// is that of writing the output.
costEstimate = sizeEstimate;
}
return {sizeEstimate, costEstimate};
}
} // namespace

// _____________________________________________________________________________
template <Comparison comp>
SparqlExpression::Estimates
RelationalExpression<comp>::getEstimatesForFilterExpression(
uint64_t inputSizeEstimate,
[[maybe_unused]] const std::optional<Variable>& firstSortedVariable) const {
size_t sizeEstimate = 0;
const std::optional<Variable>& firstSortedVariable) const {
uint64_t reductionFactor = 0;

if (comp == valueIdComparators::Comparison::EQ) {
sizeEstimate = inputSizeEstimate / 1000;
reductionFactor = reductionFactorEquals;
} else if (comp == valueIdComparators::Comparison::NE) {
sizeEstimate = inputSizeEstimate;
reductionFactor = reductionFactorNotEquals;
} else {
sizeEstimate = inputSizeEstimate / 50;
reductionFactor = reductionFactorDefault;
}

size_t costEstimate = sizeEstimate;
return getEstimatesForFilterExpressionImpl(inputSizeEstimate, reductionFactor,
children_, firstSortedVariable);
}

auto canBeEvaluatedWithBinarySearch = [&firstSortedVariable](
const Ptr& left, const Ptr& right) {
auto varPtr = dynamic_cast<const VariableExpression*>(left.get());
if (!varPtr || varPtr->value() != firstSortedVariable) {
return false;
// _____________________________________________________________________________
ExpressionResult InExpression::evaluate(
sparqlExpression::EvaluationContext* context) const {
auto lhs = children_.at(0)->evaluate(context);
ExpressionResult result{ad_utility::SetOfIntervals{}};
bool firstChild = true;
for (const auto& child : children_ | std::views::drop(1)) {
auto rhs = child->evaluate(context);
auto evaluateEqualsExpression = [context](const auto& a,
auto b) -> ExpressionResult {
return evaluateRelationalExpression<Comparison::EQ>(a, std::move(b),
context);
};
auto subRes = std::visit(evaluateEqualsExpression, lhs, std::move(rhs));
if (firstChild) {
firstChild = false;
result = std::move(subRes);
continue;
}
return right->isConstantExpression();
};
// TODO We could implement early stopping for entries which are already
// true, This could be especially beneficial if some of the `==` expressions
// are more expensive than others. Same goes for the `logical or` and
// `logical and` expressions.
auto expressionForSubRes =
std::make_unique<SingleUseExpression>(std::move(subRes));
auto expressionForPreviousResult =
std::make_unique<SingleUseExpression>(std::move(result));
result = makeOrExpression(std::move(expressionForSubRes),
std::move(expressionForPreviousResult))
->evaluate(context);
}
return result;
}

// TODO<joka921> This check has to be more complex once we support proper
// filtering on the `LocalVocab`.
if (canBeEvaluatedWithBinarySearch(children_[0], children_[1]) ||
canBeEvaluatedWithBinarySearch(children_[1], children_[0])) {
costEstimate = 0;
// _____________________________________________________________________________
std::span<SparqlExpression::Ptr> InExpression::childrenImpl() {
return children_;
}

// _____________________________________________________________________________
string InExpression::getCacheKey(const VariableToColumnMap& varColMap) const {
std::stringstream result;
result << "IN Expression with (";
for (const auto& child : children_) {
result << ' ' << child->getCacheKey(varColMap);
}
return {sizeEstimate, costEstimate};
result << ')';
return std::move(result).str();
}

// _____________________________________________________________________________
auto InExpression::getEstimatesForFilterExpression(
uint64_t inputSizeEstimate,
const std::optional<Variable>& firstSortedVariable) const -> Estimates {
return getEstimatesForFilterExpressionImpl(
inputSizeEstimate, reductionFactorEquals, children_, firstSortedVariable);
}

// Explicit instantiations
Expand Down
34 changes: 34 additions & 0 deletions src/engine/sparqlExpressions/RelationalExpressions.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,38 @@ class RelationalExpression : public SparqlExpression {
std::span<SparqlExpression::Ptr> childrenImpl() override;
};

// Implementation of the `IN` expression
class InExpression : public SparqlExpression {
public:
using Children = std::vector<SparqlExpression::Ptr>;

private:
// The first child implicitly is the left hand side.
Children children_;

public:
// Construct from the two children.
explicit InExpression(SparqlExpression::Ptr lhs, Children children) {
children_.reserve(children.size() + 1);
children_.push_back(std::move(lhs));
std::ranges::move(children, std::back_inserter(children_));
}

ExpressionResult evaluate(EvaluationContext* context) const override;

[[nodiscard]] string getCacheKey(
const VariableToColumnMap& varColMap) const override;

// These expressions are typically used inside `FILTER` clauses, so we need
// proper estimates.
Estimates getEstimatesForFilterExpression(
uint64_t inputSizeEstimate,
const std::optional<Variable>& firstSortedVariable) const override;

private:
std::span<SparqlExpression::Ptr> childrenImpl() override;
};

} // namespace sparqlExpression::relational

namespace sparqlExpression {
Expand All @@ -64,4 +96,6 @@ using GreaterThanExpression =
relational::RelationalExpression<valueIdComparators::Comparison::GT>;
using GreaterEqualExpression =
relational::RelationalExpression<valueIdComparators::Comparison::GE>;

using InExpression = relational::InExpression;
} // namespace sparqlExpression
4 changes: 1 addition & 3 deletions src/engine/sparqlExpressions/SparqlExpressionGenerators.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ template <SingleExpressionResult Input, typename Transformation = std::identity>
auto makeGenerator(Input&& input, size_t numItems, const EvaluationContext* context,
Transformation transformation = {}) {
if constexpr (ad_utility::isSimilar<::Variable, Input>) {
std::span<const ValueId> inputWithVariableResolved{
getIdsFromVariable(std::forward<Input>(input), context)};
return resultGenerator(inputWithVariableResolved, numItems, transformation);
return resultGenerator(getIdsFromVariable(AD_FWD(input), context), numItems, transformation);
} else {
return resultGenerator(AD_FWD(input), numItems, transformation);
}
Expand Down
11 changes: 10 additions & 1 deletion src/parser/sparqlParser/SparqlQleverVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1722,7 +1722,16 @@ ExpressionPtr Visitor::visit(Parser::RelationalExpressionContext* ctx) {
auto children = visitVector(ctx->numericExpression());

if (ctx->expressionList()) {
reportNotSupported(ctx, "IN or NOT IN in an expression is ");
auto lhs = visitVector(ctx->numericExpression());
AD_CORRECTNESS_CHECK(lhs.size() == 1);
auto expressions = visit(ctx->expressionList());
auto inExpression = std::make_unique<InExpression>(std::move(lhs.at(0)),
std::move(expressions));
if (ctx->notToken) {
return makeUnaryNegateExpression(std::move(inExpression));
} else {
return inExpression;
}
}
AD_CONTRACT_CHECK(children.size() == 1 || children.size() == 2);
if (children.size() == 1) {
Expand Down
2 changes: 1 addition & 1 deletion src/parser/sparqlParser/generated/SparqlAutomatic.g4
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ valueLogical
;

relationalExpression
: numericExpression ( '=' numericExpression | '!=' numericExpression | '<' numericExpression | '>' numericExpression | '<=' numericExpression | '>=' numericExpression | IN expressionList | NOT IN expressionList)?
: numericExpression ( '=' numericExpression | '!=' numericExpression | '<' numericExpression | '>' numericExpression | '<=' numericExpression | '>=' numericExpression | IN expressionList | (notToken = NOT) IN expressionList)?
;

numericExpression
Expand Down
3 changes: 2 additions & 1 deletion src/parser/sparqlParser/generated/SparqlAutomaticParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13100,7 +13100,8 @@ SparqlAutomaticParser::relationalExpression() {

case SparqlAutomaticParser::NOT: {
setState(1159);
match(SparqlAutomaticParser::NOT);
antlrcpp::downCast<RelationalExpressionContext*>(_localctx)->notToken =
match(SparqlAutomaticParser::NOT);
setState(1160);
match(SparqlAutomaticParser::IN);
setState(1161);
Expand Down
1 change: 1 addition & 0 deletions src/parser/sparqlParser/generated/SparqlAutomaticParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -2510,6 +2510,7 @@ class SparqlAutomaticParser : public antlr4::Parser {

class RelationalExpressionContext : public antlr4::ParserRuleContext {
public:
antlr4::Token* notToken = nullptr;
RelationalExpressionContext(antlr4::ParserRuleContext* parent,
size_t invokingState);
virtual size_t getRuleIndex() const override;
Expand Down
Loading

0 comments on commit 797f325

Please sign in to comment.