diff --git a/src/ast/UserDefinedAggregator.cpp b/src/ast/UserDefinedAggregator.cpp index ba44e0f7a86..d924a47bbca 100644 --- a/src/ast/UserDefinedAggregator.cpp +++ b/src/ast/UserDefinedAggregator.cpp @@ -34,10 +34,9 @@ Node::NodeVec UserDefinedAggregator::getChildren() const { } void UserDefinedAggregator::print(std::ostream& os) const { - os << "@" << name; - os << " init: " << *initValue; + os << "@@" << name << " " << *initValue; if (targetExpression) { - os << " " << *targetExpression; + os << ", " << *targetExpression; } os << " : { " << join(body) << " }"; } diff --git a/src/ram/Aggregate.h b/src/ram/Aggregate.h index 5c967277348..d59145d26e9 100644 --- a/src/ram/Aggregate.h +++ b/src/ram/Aggregate.h @@ -63,6 +63,7 @@ class Aggregate : public RelationOperation, public AbstractAggregate { RelationOperation::apply(map); condition = map(std::move(condition)); expression = map(std::move(expression)); + function->apply(map); } static bool classof(const Node* n) { diff --git a/src/ram/Aggregator.h b/src/ram/Aggregator.h index 89608c2fb1f..7514a4a7e08 100644 --- a/src/ram/Aggregator.h +++ b/src/ram/Aggregator.h @@ -45,6 +45,8 @@ class Aggregator { return {}; } + virtual void apply(const NodeMapper&) {} + /** * @brief Create a cloning (i.e. deep copy) of this node */ diff --git a/src/ram/IndexAggregate.h b/src/ram/IndexAggregate.h index b9fb90049cf..3b7d6c80a42 100644 --- a/src/ram/IndexAggregate.h +++ b/src/ram/IndexAggregate.h @@ -66,6 +66,7 @@ class IndexAggregate : public IndexOperation, public AbstractAggregate { IndexOperation::apply(map); condition = map(std::move(condition)); expression = map(std::move(expression)); + function->apply(map); } static bool classof(const Node* n) { diff --git a/src/ram/UserDefinedAggregator.h b/src/ram/UserDefinedAggregator.h index 4c46d873afe..ce83442aeb9 100644 --- a/src/ram/UserDefinedAggregator.h +++ b/src/ram/UserDefinedAggregator.h @@ -77,6 +77,10 @@ class UserDefinedAggregator : public Aggregator { os << name << " INIT " << *initValue << " "; } + void apply(const NodeMapper& map) override { + initValue = map(std::move(initValue)); + } + protected: /** Aggregation function */ const std::string name; @@ -92,4 +96,4 @@ class UserDefinedAggregator : public Aggregator { /** Stateful */ const bool stateful; }; -} // namespace souffle::ram \ No newline at end of file +} // namespace souffle::ram