Skip to content

Commit

Permalink
Add option for changing the comprehension accumulator variable used b…
Browse files Browse the repository at this point in the history
…y standard macros.

PiperOrigin-RevId: 708381185
  • Loading branch information
jnthntatum authored and copybara-github committed Dec 20, 2024
1 parent 72871f2 commit 8ce99ed
Show file tree
Hide file tree
Showing 8 changed files with 305 additions and 17 deletions.
11 changes: 7 additions & 4 deletions common/expr_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ class ExprFactory {
return expr;
}

Expr NewAccuIdent(ExprId id) {
return NewIdent(id, kAccumulatorVariableName);
}
absl::string_view AccuVarName() { return accu_var_; }

Expr NewAccuIdent(ExprId id) { return NewIdent(id, AccuVarName()); }

template <typename Operand, typename Field,
typename = std::enable_if_t<IsExprLike<Operand>::value>,
Expand Down Expand Up @@ -356,7 +356,10 @@ class ExprFactory {
friend class MacroExprFactory;
friend class ParserMacroExprFactory;

ExprFactory() = default;
ExprFactory() : accu_var_(kAccumulatorVariableName) {}
explicit ExprFactory(absl::string_view accu_var) : accu_var_(accu_var) {}

std::string accu_var_;
};

} // namespace cel
Expand Down
1 change: 1 addition & 0 deletions parser/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ cc_library(
"//base/ast_internal:expr",
"//common:ast",
"//common:constant",
"//common:expr",
"//common:expr_factory",
"//common:operators",
"//common:source",
Expand Down
12 changes: 6 additions & 6 deletions parser/macro.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ absl::optional<Expr> ExpandAllMacro(MacroExprFactory& factory, Expr& target,
std::move(args[1]));
auto result = factory.NewAccuIdent();
return factory.NewComprehension(args[0].ident_expr().name(),
std::move(target), kAccumulatorVariableName,
std::move(target), factory.AccuVarName(),
std::move(init), std::move(condition),
std::move(step), std::move(result));
}
Expand Down Expand Up @@ -136,7 +136,7 @@ absl::optional<Expr> ExpandExistsMacro(MacroExprFactory& factory, Expr& target,
std::move(args[1]));
auto result = factory.NewAccuIdent();
return factory.NewComprehension(args[0].ident_expr().name(),
std::move(target), kAccumulatorVariableName,
std::move(target), factory.AccuVarName(),
std::move(init), std::move(condition),
std::move(step), std::move(result));
}
Expand Down Expand Up @@ -172,7 +172,7 @@ absl::optional<Expr> ExpandExistsOneMacro(MacroExprFactory& factory,
auto result = factory.NewCall(CelOperator::EQUALS, factory.NewAccuIdent(),
factory.NewIntConst(1));
return factory.NewComprehension(args[0].ident_expr().name(),
std::move(target), kAccumulatorVariableName,
std::move(target), factory.AccuVarName(),
std::move(init), std::move(condition),
std::move(step), std::move(result));
}
Expand Down Expand Up @@ -204,7 +204,7 @@ absl::optional<Expr> ExpandMap2Macro(MacroExprFactory& factory, Expr& target,
CelOperator::ADD, factory.NewAccuIdent(),
factory.NewList(factory.NewListElement(std::move(args[1]))));
return factory.NewComprehension(args[0].ident_expr().name(),
std::move(target), kAccumulatorVariableName,
std::move(target), factory.AccuVarName(),
std::move(init), std::move(condition),
std::move(step), factory.NewAccuIdent());
}
Expand Down Expand Up @@ -237,7 +237,7 @@ absl::optional<Expr> ExpandMap3Macro(MacroExprFactory& factory, Expr& target,
step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[1]),
std::move(step), factory.NewAccuIdent());
return factory.NewComprehension(args[0].ident_expr().name(),
std::move(target), kAccumulatorVariableName,
std::move(target), factory.AccuVarName(),
std::move(init), std::move(condition),
std::move(step), factory.NewAccuIdent());
}
Expand Down Expand Up @@ -272,7 +272,7 @@ absl::optional<Expr> ExpandFilterMacro(MacroExprFactory& factory, Expr& target,
step = factory.NewCall(CelOperator::CONDITIONAL, std::move(args[1]),
std::move(step), factory.NewAccuIdent());
return factory.NewComprehension(std::move(name), std::move(target),
kAccumulatorVariableName, std::move(init),
factory.AccuVarName(), std::move(init),
std::move(condition), std::move(step),
factory.NewAccuIdent());
}
Expand Down
6 changes: 5 additions & 1 deletion parser/macro_expr_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ class MacroExprFactory : protected ExprFactory {
return NewIdent(NextId(), std::move(name));
}

absl::string_view AccuVarName() { return ExprFactory::AccuVarName(); }

ABSL_MUST_USE_RESULT Expr NewAccuIdent() { return NewAccuIdent(NextId()); }

template <typename Operand, typename Field,
Expand Down Expand Up @@ -282,6 +284,7 @@ class MacroExprFactory : protected ExprFactory {
const Expr& expr, absl::string_view message) = 0;

protected:
using ExprFactory::AccuVarName;
using ExprFactory::NewAccuIdent;
using ExprFactory::NewBoolConst;
using ExprFactory::NewBytesConst;
Expand Down Expand Up @@ -316,7 +319,8 @@ class MacroExprFactory : protected ExprFactory {
friend class ParserMacroExprFactory;
friend class TestMacroExprFactory;

MacroExprFactory() : ExprFactory() {}
explicit MacroExprFactory(absl::string_view accu_var)
: ExprFactory(accu_var) {}
};

} // namespace cel
Expand Down
2 changes: 1 addition & 1 deletion parser/macro_expr_factory_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace cel {

class TestMacroExprFactory final : public MacroExprFactory {
public:
TestMacroExprFactory() : MacroExprFactory() {}
TestMacroExprFactory() : MacroExprFactory(kAccumulatorVariableName) {}

ExprId id() const { return id_; }

Expand Down
3 changes: 3 additions & 0 deletions parser/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ struct ParserOptions final {

// Disable standard macros (has, all, exists, exists_one, filter, map).
bool disable_standard_macros = false;

// Enable hidden accumulator variable '@result' for builtin comprehensions.
bool enable_hidden_accumulator_var = false;
};

} // namespace cel
Expand Down
20 changes: 15 additions & 5 deletions parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
#include "base/ast_internal/expr.h"
#include "common/ast.h"
#include "common/constant.h"
#include "common/expr.h"
#include "common/expr_factory.h"
#include "common/operators.h"
#include "common/source.h"
Expand Down Expand Up @@ -84,6 +85,8 @@ namespace cel {

namespace {

constexpr const char kHiddenAccumulatorVariableName[] = "@result";

std::any ExprPtrToAny(std::unique_ptr<Expr>&& expr) {
return std::make_any<Expr*>(expr.release());
}
Expand Down Expand Up @@ -158,8 +161,9 @@ SourceRange SourceRangeFromParserRuleContext(

class ParserMacroExprFactory final : public MacroExprFactory {
public:
explicit ParserMacroExprFactory(const cel::Source& source)
: MacroExprFactory(), source_(source) {}
explicit ParserMacroExprFactory(const cel::Source& source,
absl::string_view accu_var)
: MacroExprFactory(accu_var), source_(source) {}

void BeginMacro(SourceRange macro_position) {
macro_position_ = macro_position;
Expand Down Expand Up @@ -601,6 +605,7 @@ class ParserVisitor final : public CelBaseVisitor,
public antlr4::BaseErrorListener {
public:
ParserVisitor(const cel::Source& source, int max_recursion_depth,
absl::string_view accu_var,
const cel::MacroRegistry& macro_registry,
bool add_macro_calls = false,
bool enable_optional_syntax = false);
Expand Down Expand Up @@ -704,11 +709,12 @@ class ParserVisitor final : public CelBaseVisitor,

ParserVisitor::ParserVisitor(const cel::Source& source,
const int max_recursion_depth,
absl::string_view accu_var,
const cel::MacroRegistry& macro_registry,
const bool add_macro_calls,
bool enable_optional_syntax)
: source_(source),
factory_(source_),
factory_(source_, accu_var),
macro_registry_(macro_registry),
recursion_depth_(0),
max_recursion_depth_(max_recursion_depth),
Expand Down Expand Up @@ -1617,8 +1623,12 @@ absl::StatusOr<ParseResult> ParseImpl(const cel::Source& source,
CommonTokenStream tokens(&lexer);
CelParser parser(&tokens);
ExprRecursionListener listener(options.max_recursion_depth);
ParserVisitor visitor(source, options.max_recursion_depth, registry,
options.add_macro_calls,
absl::string_view accu_var = cel::kAccumulatorVariableName;
if (options.enable_hidden_accumulator_var) {
accu_var = cel::kHiddenAccumulatorVariableName;
}
ParserVisitor visitor(source, options.max_recursion_depth, accu_var,
registry, options.add_macro_calls,
options.enable_optional_syntax);

lexer.removeErrorListeners();
Expand Down
Loading

0 comments on commit 8ce99ed

Please sign in to comment.