Skip to content

Commit

Permalink
[C++] Refactor and optimize SemanticContext
Browse files Browse the repository at this point in the history
  • Loading branch information
jcking committed Mar 29, 2022
1 parent 3ab65c4 commit fb7b12e
Show file tree
Hide file tree
Showing 16 changed files with 288 additions and 246 deletions.
7 changes: 4 additions & 3 deletions runtime/Cpp/runtime/src/FailedPredicateException.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "atn/PredicateTransition.h"
#include "atn/ATN.h"
#include "atn/ATNState.h"
#include "support/Casts.h"
#include "support/CPPUtils.h"

#include "FailedPredicateException.h"
Expand All @@ -27,9 +28,9 @@ FailedPredicateException::FailedPredicateException(Parser *recognizer, const std

atn::ATNState *s = recognizer->getInterpreter<atn::ATNSimulator>()->atn.states[recognizer->getState()];
const atn::Transition *transition = s->transitions[0].get();
if (is<const atn::PredicateTransition*>(transition)) {
_ruleIndex = static_cast<const atn::PredicateTransition *>(transition)->ruleIndex;
_predicateIndex = static_cast<const atn::PredicateTransition *>(transition)->predIndex;
if (transition->getTransitionType() == atn::TransitionType::PREDICATE) {
_ruleIndex = downCast<const atn::PredicateTransition&>(*transition).getRuleIndex();
_predicateIndex = downCast<const atn::PredicateTransition&>(*transition).getPredIndex();
} else {
_ruleIndex = 0;
_predicateIndex = 0;
Expand Down
6 changes: 3 additions & 3 deletions runtime/Cpp/runtime/src/ParserInterpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ void ParserInterpreter::visitState(atn::ATNState *p) {
case atn::TransitionType::PREDICATE:
{
const atn::PredicateTransition *predicateTransition = static_cast<const atn::PredicateTransition*>(transition);
if (!sempred(_ctx, predicateTransition->ruleIndex, predicateTransition->predIndex)) {
if (!sempred(_ctx, predicateTransition->getRuleIndex(), predicateTransition->getPredIndex())) {
throw FailedPredicateException(this);
}
}
Expand All @@ -214,8 +214,8 @@ void ParserInterpreter::visitState(atn::ATNState *p) {

case atn::TransitionType::PRECEDENCE:
{
if (!precpred(_ctx, static_cast<const atn::PrecedencePredicateTransition*>(transition)->precedence)) {
throw FailedPredicateException(this, "precpred(_ctx, " + std::to_string(static_cast<const atn::PrecedencePredicateTransition*>(transition)->precedence) + ")");
if (!precpred(_ctx, static_cast<const atn::PrecedencePredicateTransition*>(transition)->getPrecedence())) {
throw FailedPredicateException(this, "precpred(_ctx, " + std::to_string(static_cast<const atn::PrecedencePredicateTransition*>(transition)->getPrecedence()) + ")");
}
}
break;
Expand Down
1 change: 0 additions & 1 deletion runtime/Cpp/runtime/src/antlr4-runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
#include "atn/ATNSimulator.h"
#include "atn/ATNState.h"
#include "atn/ATNType.h"
#include "atn/AbstractPredicateTransition.h"
#include "atn/ActionTransition.h"
#include "atn/AmbiguityInfo.h"
#include "atn/ArrayPredictionContext.h"
Expand Down
11 changes: 0 additions & 11 deletions runtime/Cpp/runtime/src/atn/AbstractPredicateTransition.cpp

This file was deleted.

21 changes: 0 additions & 21 deletions runtime/Cpp/runtime/src/atn/AbstractPredicateTransition.h

This file was deleted.

1 change: 0 additions & 1 deletion runtime/Cpp/runtime/src/atn/LL1Analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "atn/Transition.h"
#include "atn/RuleTransition.h"
#include "atn/SingletonPredictionContext.h"
#include "atn/AbstractPredicateTransition.h"
#include "atn/WildcardTransition.h"
#include "atn/NotSetTransition.h"
#include "misc/IntervalSet.h"
Expand Down
4 changes: 2 additions & 2 deletions runtime/Cpp/runtime/src/atn/LexerATNSimulator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,11 +392,11 @@ Ref<LexerATNConfig> LexerATNSimulator::getEpsilonTarget(CharStream *input, const
const PredicateTransition *pt = static_cast<const PredicateTransition*>(t);

#if DEBUG_ATN == 1
std::cout << "EVAL rule " << pt->ruleIndex << ":" << pt->predIndex << std::endl;
std::cout << "EVAL rule " << pt->getRuleIndex() << ":" << pt->getPredIndex() << std::endl;
#endif

configs->hasSemanticContext = true;
if (evaluatePredicate(input, pt->ruleIndex, pt->predIndex, speculative)) {
if (evaluatePredicate(input, pt->getRuleIndex(), pt->getPredIndex(), speculative)) {
c = std::make_shared<LexerATNConfig>(*config, t->target);
}
break;
Expand Down
18 changes: 9 additions & 9 deletions runtime/Cpp/runtime/src/atn/ParserATNSimulator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1088,15 +1088,15 @@ Ref<ATNConfig> ParserATNSimulator::actionTransition(Ref<ATNConfig> const& config
Ref<ATNConfig> ParserATNSimulator::precedenceTransition(Ref<ATNConfig> const& config, const PrecedencePredicateTransition *pt,
bool collectPredicates, bool inContext, bool fullCtx) {
#if DEBUG_DFA == 1
std::cout << "PRED (collectPredicates=" << collectPredicates << ") " << pt->precedence << ">=_p" << ", ctx dependent=true" << std::endl;
std::cout << "PRED (collectPredicates=" << collectPredicates << ") " << pt->getPrecedence() << ">=_p" << ", ctx dependent=true" << std::endl;
if (parser != nullptr) {
std::cout << "context surrounding pred is " << Arrays::listToString(parser->getRuleInvocationStack(), ", ") << std::endl;
}
#endif

Ref<ATNConfig> c;
if (collectPredicates && inContext) {
Ref<SemanticContext::PrecedencePredicate> predicate = pt->getPredicate();
const auto &predicate = pt->getPredicate();

if (fullCtx) {
// In full context mode, we can evaluate predicates on-the-fly
Expand All @@ -1105,14 +1105,14 @@ Ref<ATNConfig> ParserATNSimulator::precedenceTransition(Ref<ATNConfig> const& co
// later during conflict resolution.
size_t currentPosition = _input->index();
_input->seek(_startIndex);
bool predSucceeds = evalSemanticContext(pt->getPredicate(), _outerContext, config->alt, fullCtx);
bool predSucceeds = evalSemanticContext(predicate, _outerContext, config->alt, fullCtx);
_input->seek(currentPosition);
if (predSucceeds) {
c = std::make_shared<ATNConfig>(*config, pt->target); // no pred context
}
} else {
Ref<const SemanticContext> newSemCtx = SemanticContext::And(config->semanticContext, predicate);
c = std::make_shared<ATNConfig>(*config, pt->target, newSemCtx);
c = std::make_shared<ATNConfig>(*config, pt->target, std::move(newSemCtx));
}
} else {
c = std::make_shared<ATNConfig>(*config, pt->target);
Expand All @@ -1128,30 +1128,30 @@ Ref<ATNConfig> ParserATNSimulator::precedenceTransition(Ref<ATNConfig> const& co
Ref<ATNConfig> ParserATNSimulator::predTransition(Ref<ATNConfig> const& config, const PredicateTransition *pt,
bool collectPredicates, bool inContext, bool fullCtx) {
#if DEBUG_DFA == 1
std::cout << "PRED (collectPredicates=" << collectPredicates << ") " << pt->ruleIndex << ":" << pt->predIndex << ", ctx dependent=" << pt->isCtxDependent << std::endl;
std::cout << "PRED (collectPredicates=" << collectPredicates << ") " << pt->getRuleIndex() << ":" << pt->getPredIndex() << ", ctx dependent=" << pt->isCtxDependent() << std::endl;
if (parser != nullptr) {
std::cout << "context surrounding pred is " << Arrays::listToString(parser->getRuleInvocationStack(), ", ") << std::endl;
}
#endif

Ref<ATNConfig> c = nullptr;
if (collectPredicates && (!pt->isCtxDependent || (pt->isCtxDependent && inContext))) {
Ref<SemanticContext::Predicate> predicate = pt->getPredicate();
if (collectPredicates && (!pt->isCtxDependent() || (pt->isCtxDependent() && inContext))) {
const auto &predicate = pt->getPredicate();
if (fullCtx) {
// In full context mode, we can evaluate predicates on-the-fly
// during closure, which dramatically reduces the size of
// the config sets. It also obviates the need to test predicates
// later during conflict resolution.
size_t currentPosition = _input->index();
_input->seek(_startIndex);
bool predSucceeds = evalSemanticContext(pt->getPredicate(), _outerContext, config->alt, fullCtx);
bool predSucceeds = evalSemanticContext(predicate, _outerContext, config->alt, fullCtx);
_input->seek(currentPosition);
if (predSucceeds) {
c = std::make_shared<ATNConfig>(*config, pt->target); // no pred context
}
} else {
Ref<const SemanticContext> newSemCtx = SemanticContext::And(config->semanticContext, predicate);
c = std::make_shared<ATNConfig>(*config, pt->target, newSemCtx);
c = std::make_shared<ATNConfig>(*config, pt->target, std::move(newSemCtx));
}
} else {
c = std::make_shared<ATNConfig>(*config, pt->target);
Expand Down
9 changes: 2 additions & 7 deletions runtime/Cpp/runtime/src/atn/PrecedencePredicateTransition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
using namespace antlr4::atn;

PrecedencePredicateTransition::PrecedencePredicateTransition(ATNState *target, int precedence)
: AbstractPredicateTransition(target), precedence(precedence) {
}
: Transition(target), _predicate(std::make_shared<SemanticContext::PrecedencePredicate>(precedence)) {}

TransitionType PrecedencePredicateTransition::getTransitionType() const {
return TransitionType::PRECEDENCE;
Expand All @@ -23,10 +22,6 @@ bool PrecedencePredicateTransition::matches(size_t /*symbol*/, size_t /*minVocab
return false;
}

Ref<SemanticContext::PrecedencePredicate> PrecedencePredicateTransition::getPredicate() const {
return std::make_shared<SemanticContext::PrecedencePredicate>(precedence);
}

std::string PrecedencePredicateTransition::toString() const {
return "PRECEDENCE " + Transition::toString() + " { precedence: " + std::to_string(precedence) + " }";
return "PRECEDENCE " + Transition::toString() + " { precedence: " + std::to_string(getPrecedence()) + " }";
}
21 changes: 12 additions & 9 deletions runtime/Cpp/runtime/src/atn/PrecedencePredicateTransition.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,27 @@

#pragma once

#include "atn/AbstractPredicateTransition.h"
#include "SemanticContext.h"
#include "atn/Transition.h"
#include "atn/SemanticContext.h"

namespace antlr4 {
namespace atn {

class ANTLR4CPP_PUBLIC PrecedencePredicateTransition final : public AbstractPredicateTransition {
class ANTLR4CPP_PUBLIC PrecedencePredicateTransition final : public Transition {
public:
const int precedence;

PrecedencePredicateTransition(ATNState *target, int precedence);

int getPrecedence() const { return _predicate->precedence; }

TransitionType getTransitionType() const override;
virtual bool isEpsilon() const override;
virtual bool matches(size_t symbol, size_t minVocabSymbol, size_t maxVocabSymbol) const override;
Ref<SemanticContext::PrecedencePredicate> getPredicate() const;
virtual std::string toString() const override;
bool isEpsilon() const override;
bool matches(size_t symbol, size_t minVocabSymbol, size_t maxVocabSymbol) const override;
std::string toString() const override;

const Ref<const SemanticContext::PrecedencePredicate>& getPredicate() const { return _predicate; }

private:
const std::shared_ptr<const SemanticContext::PrecedencePredicate> _predicate;
};

} // namespace atn
Expand Down
14 changes: 4 additions & 10 deletions runtime/Cpp/runtime/src/atn/PredicateTransition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

using namespace antlr4::atn;

PredicateTransition::PredicateTransition(ATNState *target, size_t ruleIndex, size_t predIndex, bool isCtxDependent) : AbstractPredicateTransition(target), ruleIndex(ruleIndex), predIndex(predIndex), isCtxDependent(isCtxDependent) {
}
PredicateTransition::PredicateTransition(ATNState *target, size_t ruleIndex, size_t predIndex, bool isCtxDependent)
: Transition(target), _predicate(std::make_shared<SemanticContext::Predicate>(ruleIndex, predIndex, isCtxDependent)) {}

TransitionType PredicateTransition::getTransitionType() const {
return TransitionType::PREDICATE;
Expand All @@ -22,13 +22,7 @@ bool PredicateTransition::matches(size_t /*symbol*/, size_t /*minVocabSymbol*/,
return false;
}

Ref<SemanticContext::Predicate> PredicateTransition::getPredicate() const {
return std::make_shared<SemanticContext::Predicate>(ruleIndex, predIndex, isCtxDependent);
}

std::string PredicateTransition::toString() const {
return "PREDICATE " + Transition::toString() + " { ruleIndex: " + std::to_string(ruleIndex) +
", predIndex: " + std::to_string(predIndex) + ", isCtxDependent: " + std::to_string(isCtxDependent) + " }";

// Generate and add a predicate context here?
return "PREDICATE " + Transition::toString() + " { ruleIndex: " + std::to_string(getRuleIndex()) +
", predIndex: " + std::to_string(getPredIndex()) + ", isCtxDependent: " + std::to_string(isCtxDependent()) + " }";
}
32 changes: 20 additions & 12 deletions runtime/Cpp/runtime/src/atn/PredicateTransition.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

#pragma once

#include "atn/AbstractPredicateTransition.h"
#include "SemanticContext.h"
#include "atn/Transition.h"
#include "atn/SemanticContext.h"

namespace antlr4 {
namespace atn {
Expand All @@ -16,23 +16,31 @@ namespace atn {
/// In the ATN, labels will always be exactly one predicate, but the DFA
/// may have to combine a bunch of them as it collects predicates from
/// multiple ATN configurations into a single DFA state.
class ANTLR4CPP_PUBLIC PredicateTransition final : public AbstractPredicateTransition {
class ANTLR4CPP_PUBLIC PredicateTransition final : public Transition {
public:
const size_t ruleIndex;
const size_t predIndex;
const bool isCtxDependent; // e.g., $i ref in pred

PredicateTransition(ATNState *target, size_t ruleIndex, size_t predIndex, bool isCtxDependent);

TransitionType getTransitionType() const override;
size_t getRuleIndex() const {
return _predicate->ruleIndex;
}

virtual bool isEpsilon() const override;
virtual bool matches(size_t symbol, size_t minVocabSymbol, size_t maxVocabSymbol) const override;
size_t getPredIndex() const {
return _predicate->predIndex;
}

Ref<SemanticContext::Predicate> getPredicate() const;
bool isCtxDependent() const {
return _predicate->isCtxDependent;
}

TransitionType getTransitionType() const override;
bool isEpsilon() const override;
bool matches(size_t symbol, size_t minVocabSymbol, size_t maxVocabSymbol) const override;
std::string toString() const override;

virtual std::string toString() const override;
const Ref<const SemanticContext::Predicate>& getPredicate() const { return _predicate; }

private:
const std::shared_ptr<const SemanticContext::Predicate> _predicate;
};

} // namespace atn
Expand Down
Loading

0 comments on commit fb7b12e

Please sign in to comment.