Skip to content

Commit

Permalink
[script] replace: export table with persistant immutable data structure
Browse files Browse the repository at this point in the history
- add: Ast nodes now have a view of its environment.  Structural sharing should mitigate excessive memory usage.
  • Loading branch information
jd28 committed Nov 12, 2023
1 parent c4bf197 commit 97ba7e2
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 64 deletions.
3 changes: 3 additions & 0 deletions lib/nw/script/Ast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "Token.hpp"

#include <glm/glm.hpp>
#include <immer/map.hpp>

#include <limits>
#include <memory>
Expand All @@ -14,6 +15,7 @@ namespace nw::script {

struct Nss;
struct Ast;
struct Export;

struct FunctionDecl;
struct FunctionDefinition;
Expand Down Expand Up @@ -94,6 +96,7 @@ struct AstNode {

size_t type_id_ = invalid_type_id;
bool is_const_ = false;
immer::map<std::string, Export> env;
};

#define DEFINE_ACCEPT_VISITOR \
Expand Down
81 changes: 67 additions & 14 deletions lib/nw/script/AstResolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "Context.hpp"
#include "Nss.hpp"

#include <immer/map.hpp>

#include <unordered_map>
#include <vector>

Expand All @@ -23,26 +25,32 @@ struct AstResolver : BaseVisitor {
, ctx_{ctx}
{
CHECK_F(!!ctx_, "[script] invalid script context");
env_stack_.push_back({});
}

virtual ~AstResolver() = default;

using ScopeMap = std::unordered_map<std::string, ScopeDecl>;
using ScopeStack = std::vector<ScopeMap>;
using EnvStack = std::vector<immer::map<std::string, Export>>;

Nss* parent_ = nullptr;
Context* ctx_ = nullptr;
ScopeStack scope_stack_;
EnvStack env_stack_;
int loop_stack_ = 0;
int switch_stack_ = 0;
int func_def_stack_ = 0;

// == Resolver Helpers ====================================================
// ========================================================================

void begin_scope()
void begin_scope(bool global = false)
{
scope_stack_.push_back(ScopeMap{});
if (!global) {
env_stack_.push_back(env_stack_.back());
}
}

void declare(NssToken token, Declaration* decl, bool is_type = false)
Expand Down Expand Up @@ -93,16 +101,34 @@ struct AstResolver : BaseVisitor {
} else {
it->second.decl_ready = true;
}
auto& env = env_stack_.back();

auto sym = env.find(s);
Export temp;

if (sym) { temp = *sym; }
if (is_type) {
temp.type = it->second.struct_decl;
} else {
temp.decl = it->second.decl;
}
env_stack_.back() = env.set(s, temp);
}

void end_scope()
void end_scope(bool global = false)
{
scope_stack_.pop_back();
if (!global) { env_stack_.pop_back(); }
}

immer::map<std::string, Export> symbol_table() const
{
return env_stack_[0];
}

Declaration* locate(std::string_view token, Nss* script, bool is_type)
{
if (auto decl = script->locate_export(token, is_type)) {
if (auto decl = script->locate_export(std::string(token), is_type)) {
return decl;
} else {
for (auto& it : reverse(script->ast().includes)) {
Expand Down Expand Up @@ -144,29 +170,22 @@ struct AstResolver : BaseVisitor {
if (auto decl = locate(token, it, is_type)) { return decl; }
}

return ctx_->command_script_->locate_export(token, is_type);
return ctx_->command_script_->locate_export(s, is_type);
}

// == Visitor =============================================================
// ========================================================================

virtual void visit(Ast* script) override
{
begin_scope();
begin_scope(true);
for (const auto& decl : script->decls) {
decl->accept(this);
if (auto d = dynamic_cast<VarDecl*>(decl.get())) {
d->is_const_ = true; // All top level var decls are constant. Only thing that makes sense.
parent_->add_export(std::string(d->identifier.loc.view()), d);
} else if (auto d = dynamic_cast<StructDecl*>(decl.get())) {
parent_->add_export(std::string(d->type.struct_id.loc.view()), d);
} else if (auto d = dynamic_cast<FunctionDecl*>(decl.get())) {
parent_->add_export(std::string(d->identifier.loc.view()), d);
} else if (auto d = dynamic_cast<FunctionDefinition*>(decl.get())) {
parent_->add_export(std::string(d->decl_inline->identifier.loc.view()), d);
d->is_const_ = true; // All top level var decls are constant. Probably wrong??
}
}
end_scope();
end_scope(true);
}

// Decls
Expand Down Expand Up @@ -244,6 +263,7 @@ struct AstResolver : BaseVisitor {

virtual void visit(FunctionDecl* decl) override
{
decl->env = env_stack_.back();
// Check to see if there's been a function definition, if so got to match.
auto fd = resolve(decl->identifier.loc.view(), decl->identifier.loc, false);

Expand All @@ -265,6 +285,7 @@ struct AstResolver : BaseVisitor {

virtual void visit(FunctionDefinition* decl) override
{
decl->env = env_stack_.back();
++func_def_stack_;
// Check to see if there's been a function declaration, if so got to match.
auto fd = resolve(decl->decl_inline->identifier.loc.view(), decl->decl_inline->identifier.loc, false);
Expand Down Expand Up @@ -293,6 +314,7 @@ struct AstResolver : BaseVisitor {

virtual void visit(StructDecl* decl) override
{
decl->env = env_stack_.back();
declare(decl->type.struct_id, decl, true);
decl->type_id_ = ctx_->type_id(decl->type, true);
begin_scope();
Expand All @@ -306,6 +328,7 @@ struct AstResolver : BaseVisitor {

virtual void visit(VarDecl* decl) override
{
decl->env = env_stack_.back();
decl->is_const_ = decl->type.type_qualifier.type == NssTokenType::CONST_;
decl->type_id_ = ctx_->type_id(decl->type);

Expand Down Expand Up @@ -339,6 +362,7 @@ struct AstResolver : BaseVisitor {
// Expressions
virtual void visit(AssignExpression* expr) override
{
expr->env = env_stack_.back();
expr->lhs->accept(this);
expr->rhs->accept(this);

Expand All @@ -357,6 +381,7 @@ struct AstResolver : BaseVisitor {

virtual void visit(BinaryExpression* expr) override
{
expr->env = env_stack_.back();
expr->lhs->accept(this);
expr->rhs->accept(this);

Expand All @@ -376,6 +401,8 @@ struct AstResolver : BaseVisitor {

virtual void visit(CallExpression* expr) override
{
expr->env = env_stack_.back();

auto ve = dynamic_cast<VariableExpression*>(expr->expr.get());
if (!ve) {
// Parser already handles this case
Expand Down Expand Up @@ -434,6 +461,7 @@ struct AstResolver : BaseVisitor {

virtual void visit(ComparisonExpression* expr) override
{
expr->env = env_stack_.back();
expr->lhs->accept(this);
expr->rhs->accept(this);

Expand All @@ -452,6 +480,7 @@ struct AstResolver : BaseVisitor {

virtual void visit(ConditionalExpression* expr) override
{
expr->env = env_stack_.back();
expr->test->accept(this);
if (expr->test->type_id_ != ctx_->type_id("int")) {
ctx_->semantic_error(parent_,
Expand All @@ -476,6 +505,8 @@ struct AstResolver : BaseVisitor {

virtual void visit(DotExpression* expr) override
{
expr->env = env_stack_.back();

auto resolve_struct_member = [this](VariableExpression* var, StructDecl* str) {
for (const auto& it : str->decls) {
if (it->identifier.loc.view() == var->var.loc.view()) {
Expand Down Expand Up @@ -534,13 +565,15 @@ struct AstResolver : BaseVisitor {

virtual void visit(GroupingExpression* expr) override
{
expr->env = env_stack_.back();
expr->expr->accept(this);
expr->type_id_ = expr->expr->type_id_;
expr->is_const_ = expr->expr->is_const_;
}

virtual void visit(LiteralExpression* expr) override
{
expr->env = env_stack_.back();
expr->is_const_ = true;
if (expr->literal.type == NssTokenType::FLOAT_CONST) {
expr->type_id_ = ctx_->type_id("float");
Expand All @@ -556,12 +589,14 @@ struct AstResolver : BaseVisitor {

virtual void visit(LiteralVectorExpression* expr) override
{
expr->env = env_stack_.back();
expr->is_const_ = true;
expr->type_id_ = ctx_->type_id("vector");
}

virtual void visit(LogicalExpression* expr) override
{
expr->env = env_stack_.back();
expr->lhs->accept(this);
expr->rhs->accept(this);

Expand All @@ -575,20 +610,23 @@ struct AstResolver : BaseVisitor {

virtual void visit(PostfixExpression* expr) override
{
expr->env = env_stack_.back();
expr->lhs->accept(this);
expr->type_id_ = expr->lhs->type_id_;
expr->is_const_ = expr->lhs->is_const_;
}

virtual void visit(UnaryExpression* expr) override
{
expr->env = env_stack_.back();
expr->rhs->accept(this);
expr->type_id_ = expr->rhs->type_id_;
expr->is_const_ = expr->rhs->is_const_;
}

virtual void visit(VariableExpression* expr) override
{
expr->env = env_stack_.back();
auto decl = resolve(expr->var.loc.view(), expr->var.loc, false);
if (decl) {
expr->type_id_ = decl->type_id_;
Expand All @@ -603,6 +641,7 @@ struct AstResolver : BaseVisitor {
// Statements
virtual void visit(BlockStatement* stmt) override
{
stmt->env = env_stack_.back();
stmt->type_id_ = ctx_->type_id("void");
for (auto& s : stmt->nodes) {
s->accept(this);
Expand All @@ -611,6 +650,7 @@ struct AstResolver : BaseVisitor {

virtual void visit(DeclStatement* stmt) override
{
stmt->env = env_stack_.back();
size_t ti = invalid_type_id;
for (auto& s : stmt->decls) {
// types of all must be the same;
Expand All @@ -626,6 +666,7 @@ struct AstResolver : BaseVisitor {

virtual void visit(DoStatement* stmt) override
{
stmt->env = env_stack_.back();
++loop_stack_;
begin_scope();
stmt->block->accept(this);
Expand All @@ -644,16 +685,19 @@ struct AstResolver : BaseVisitor {

virtual void visit(EmptyStatement* stmt) override
{
stmt->env = env_stack_.back();
stmt->type_id_ = ctx_->type_id("void");
}

virtual void visit(ExprStatement* stmt) override
{
stmt->env = env_stack_.back();
stmt->expr->accept(this);
}

virtual void visit(IfStatement* stmt) override
{
stmt->env = env_stack_.back();
stmt->type_id_ = ctx_->type_id("void");
stmt->expr->accept(this);

Expand All @@ -676,6 +720,7 @@ struct AstResolver : BaseVisitor {

virtual void visit(ForStatement* stmt) override
{
stmt->env = env_stack_.back();
++loop_stack_;
begin_scope();

Expand Down Expand Up @@ -704,6 +749,8 @@ struct AstResolver : BaseVisitor {

virtual void visit(JumpStatement* stmt) override
{
stmt->env = env_stack_.back();

if (stmt->expr) {
stmt->expr->accept(this);
stmt->type_id_ = stmt->expr->type_id_;
Expand All @@ -725,6 +772,8 @@ struct AstResolver : BaseVisitor {

virtual void visit(LabelStatement* stmt) override
{
stmt->env = env_stack_.back();

if (stmt->type.type == NssTokenType::CASE && switch_stack_ == 0) {
ctx_->semantic_error(parent_, "case statement not within switch", stmt->type.loc);
}
Expand All @@ -745,6 +794,8 @@ struct AstResolver : BaseVisitor {

virtual void visit(SwitchStatement* stmt) override
{
stmt->env = env_stack_.back();

stmt->type_id_ = ctx_->type_id("void");
++switch_stack_;
stmt->target->accept(this);
Expand All @@ -766,6 +817,8 @@ struct AstResolver : BaseVisitor {

virtual void visit(WhileStatement* stmt) override
{
stmt->env = env_stack_.back();

stmt->type_id_ = ctx_->type_id("void");
++loop_stack_;

Expand Down
Loading

0 comments on commit 97ba7e2

Please sign in to comment.