Skip to content

Commit

Permalink
Merge pull request #96 from czgdp1807/enum_01
Browse files Browse the repository at this point in the history
Ported ``integration_tests/enum_01.py`` from LPython and add support for ``enum`` in LC to compile it
  • Loading branch information
czgdp1807 authored Feb 29, 2024
2 parents 530e5e5 + e6ffb6f commit d77844d
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 20 deletions.
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,5 @@ RUN(NAME function_01.cpp LABELS gcc llvm NOFAST)

RUN(NAME nbody_01.cpp LABELS gcc llvm NOFAST)
RUN(NAME nbody_02.cpp LABELS gcc llvm NOFAST)

RUN(NAME enum_01.cpp LABELS gcc llvm NOFAST)
38 changes: 38 additions & 0 deletions integration_tests/enum_01.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include <iostream>

enum Color {
RED = 1,
GREEN = 2,
BLUE = 3
};

void test_color_enum() {
std::cout << RED << " " << GREEN << " " << BLUE << std::endl;
if( RED != 1 ) {
exit(2);
}
if( GREEN != 2 ) {
exit(2);
}
if( BLUE != 3 ) {
exit(2);
}
}

void test_selected_color(enum Color selected_color) {
enum Color color;
color = selected_color;
if( color != RED ) {
exit(2);
}
std::cout << color << std::endl;
}

int main() {

test_color_enum();
test_selected_color(RED);

return 0;

}
150 changes: 130 additions & 20 deletions src/lc/clang_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
OneTimeUseBool is_break_stmt_present;
bool enable_fall_through;
std::map<ASR::symbol_t*, std::map<std::string, ASR::expr_t*>> struct2member_inits;
std::map<SymbolTable*, std::vector<ASR::symbol_t*>> scope2enums;

explicit ClangASTtoASRVisitor(clang::ASTContext *Context_,
Allocator& al_, ASR::asr_t*& tu_):
Expand Down Expand Up @@ -328,6 +329,13 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
ASR::ttype_t* type = nullptr;
if (clang_type->isVoidType() ) {
// do nothing
} else if( clang_type->isEnumeralType() ) {
std::string enum_name = qual_type.getAsString().erase(0, 5);
ASR::symbol_t* enum_sym = current_scope->resolve_symbol(enum_name);
if( !enum_sym ) {
throw std::runtime_error(enum_name + " not found in current scope.");
}
type = ASRUtils::TYPE(ASR::make_Enum_t(al, l, enum_sym));
} else if( clang_type->isCharType() ) {
type = ASRUtils::TYPE(ASR::make_Character_t(al, l, 1, -1, nullptr));
} else if( clang_type->isBooleanType() ) {
Expand Down Expand Up @@ -538,6 +546,84 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
return true;
}

bool TraverseEnumDecl(clang::EnumDecl* x) {
std::string enum_name = x->getName().str();
if( current_scope->get_symbol(enum_name) ) {
throw std::runtime_error(enum_name + std::string(" is already defined."));
}

SymbolTable* parent_scope = current_scope;
current_scope = al.make_new<SymbolTable>(parent_scope);
Vec<char*> field_names; field_names.reserve(al, 1);
for( auto enum_const_itr = x->enumerator_begin();
enum_const_itr != x->enumerator_end(); enum_const_itr++ ) {
clang::EnumConstantDecl* enum_const = *enum_const_itr;
std::string enum_const_name = enum_const->getNameAsString();
field_names.push_back(al, s2c(al, enum_const_name));
TraverseStmt(enum_const->getInitExpr());
ASR::expr_t* init_expr = ASRUtils::EXPR(tmp.get());
ASR::symbol_t* v = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(al, Lloc(x),
current_scope, s2c(al, enum_const_name), nullptr, 0, ASR::intentType::Local,
init_expr, init_expr, ASR::storage_typeType::Default, ASRUtils::expr_type(init_expr),
nullptr, ASR::abiType::Source, ASR::accessType::Public, ASR::presenceType::Required, false));
current_scope->add_symbol(enum_const_name, v);
}

ASR::enumtypeType enum_value_type = ASR::enumtypeType::NonInteger;
ASR::ttype_t* common_type = ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4));
int8_t IntegerConsecutiveFromZero = 1;
int8_t IntegerNotUnique = 0;
int8_t IntegerUnique = 1;
std::map<int64_t, int64_t> value2count;
for( auto sym: current_scope->get_scope() ) {
ASR::Variable_t* member_var = ASR::down_cast<ASR::Variable_t>(sym.second);
ASR::expr_t* value = ASRUtils::expr_value(member_var->m_symbolic_value);
int64_t value_int64 = -1;
ASRUtils::extract_value(value, value_int64);
if( value2count.find(value_int64) == value2count.end() ) {
value2count[value_int64] = 0;
}
value2count[value_int64] += 1;
}
int64_t prev = -1;
for( auto itr: value2count ) {
if( itr.second > 1 ) {
IntegerNotUnique = 1;
IntegerUnique = 0;
IntegerConsecutiveFromZero = 0;
break ;
}
if( itr.first - prev != 1 ) {
IntegerConsecutiveFromZero = 0;
}
prev = itr.first;
}
if( IntegerConsecutiveFromZero ) {
if( value2count.find(0) == value2count.end() ) {
IntegerConsecutiveFromZero = 0;
IntegerUnique = 1;
} else {
IntegerUnique = 0;
}
}
LCOMPILERS_ASSERT(IntegerConsecutiveFromZero + IntegerNotUnique + IntegerUnique == 1);
if( IntegerConsecutiveFromZero ) {
enum_value_type = ASR::enumtypeType::IntegerConsecutiveFromZero;
} else if( IntegerNotUnique ) {
enum_value_type = ASR::enumtypeType::IntegerNotUnique;
} else if( IntegerUnique ) {
enum_value_type = ASR::enumtypeType::IntegerUnique;
}

ASR::symbol_t* enum_t = ASR::down_cast<ASR::symbol_t>(ASR::make_EnumType_t(al, Lloc(x),
current_scope, s2c(al, enum_name), nullptr, 0, field_names.p, field_names.size(),
ASR::abiType::Source, ASR::accessType::Public, enum_value_type, common_type, nullptr));
parent_scope->add_symbol(enum_name, enum_t);
current_scope = parent_scope;
scope2enums[current_scope].push_back(enum_t);
return true;
}

bool TraverseParmVarDecl(clang::ParmVarDecl* x) {
std::string name = x->getName().str();
if( name == "" ) {
Expand Down Expand Up @@ -1593,35 +1679,36 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit

void CreateCompareOp(ASR::expr_t* lhs, ASR::expr_t* rhs,
ASR::cmpopType cmpop_type, const Location& loc) {
ASR::ttype_t* left_type = ASRUtils::expr_type(lhs);
ASR::ttype_t* right_type = ASRUtils::expr_type(rhs);
if( ASR::is_a<ASR::Enum_t>(*left_type) ) {
left_type = ASR::down_cast<ASR::EnumType_t>(
ASR::down_cast<ASR::Enum_t>(left_type)->m_enum_type)->m_type;
}
if( ASR::is_a<ASR::Enum_t>(*right_type) ) {
right_type = ASR::down_cast<ASR::EnumType_t>(
ASR::down_cast<ASR::Enum_t>(right_type)->m_enum_type)->m_type;
}
cast_helper(lhs, rhs, false);
ASRUtils::make_ArrayBroadcast_t_util(al, loc, lhs, rhs);
ASR::dimension_t* m_dims;
size_t n_dims = ASRUtils::extract_dimensions_from_ttype(
ASRUtils::expr_type(lhs), m_dims);
size_t n_dims = ASRUtils::extract_dimensions_from_ttype(left_type, m_dims);
ASR::ttype_t* result_type = ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4));
if( n_dims > 0 ) {
result_type = ASRUtils::make_Array_t_util(al, loc, result_type, m_dims, n_dims);
}
if( ASRUtils::is_integer(*ASRUtils::expr_type(lhs)) &&
ASRUtils::is_integer(*ASRUtils::expr_type(rhs)) ) {
tmp = ASR::make_IntegerCompare_t(al, loc, lhs,
cmpop_type, rhs, result_type, nullptr);
} else if( ASRUtils::is_real(*ASRUtils::expr_type(lhs)) &&
ASRUtils::is_real(*ASRUtils::expr_type(rhs)) ) {
tmp = ASR::make_RealCompare_t(al, loc, lhs,
cmpop_type, rhs, result_type, nullptr);
} else if( ASRUtils::is_logical(*ASRUtils::expr_type(lhs)) &&
ASRUtils::is_logical(*ASRUtils::expr_type(rhs)) ) {
tmp = ASR::make_LogicalCompare_t(al, loc, lhs,
cmpop_type, rhs, result_type, nullptr);
} else if( ASRUtils::is_complex(*ASRUtils::expr_type(lhs)) &&
ASRUtils::is_complex(*ASRUtils::expr_type(rhs)) ) {
tmp = ASR::make_ComplexCompare_t(al, loc, lhs,
cmpop_type, rhs, ASRUtils::expr_type(lhs), nullptr);
if( ASRUtils::is_integer(*left_type) && ASRUtils::is_integer(*right_type) ) {
tmp = ASR::make_IntegerCompare_t(al, loc, lhs, cmpop_type, rhs, result_type, nullptr);
} else if( ASRUtils::is_real(*left_type) && ASRUtils::is_real(*right_type) ) {
tmp = ASR::make_RealCompare_t(al, loc, lhs, cmpop_type, rhs, result_type, nullptr);
} else if( ASRUtils::is_logical(*left_type) && ASRUtils::is_logical(*right_type) ) {
tmp = ASR::make_LogicalCompare_t(al, loc, lhs, cmpop_type, rhs, result_type, nullptr);
} else if( ASRUtils::is_complex(*left_type) && ASRUtils::is_complex(*right_type) ) {
tmp = ASR::make_ComplexCompare_t(al, loc, lhs, cmpop_type, rhs, left_type, nullptr);
} else {
throw std::runtime_error("Only integer, real and complex types are supported so "
"far for comparison operator, found: " + ASRUtils::type_to_str(ASRUtils::expr_type(lhs))
+ " and " + ASRUtils::type_to_str(ASRUtils::expr_type(rhs)));
"far for comparison operator, found: " + ASRUtils::type_to_str(left_type)
+ " and " + ASRUtils::type_to_str(right_type));
}
}

Expand Down Expand Up @@ -1926,6 +2013,29 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisit
}
}
if( sym == nullptr ) {
if( x->getType()->isEnumeralType() ) {
SymbolTable* scope = current_scope;
while( scope ) {
for( auto itr = scope2enums[scope].begin(); itr != scope2enums[scope].end(); itr++ ) {
std::string mangled_name = current_scope->get_unique_name(
name + "@" + ASRUtils::symbol_name(*itr));
ASR::EnumType_t* enumtype_t = ASR::down_cast<ASR::EnumType_t>(*itr);
ASR::symbol_t* enum_member_orig = enumtype_t->m_symtab->resolve_symbol(name);
if( enum_member_orig == nullptr ) {
continue ;
}
ASR::symbol_t* enum_member = ASR::down_cast<ASR::symbol_t>(ASR::make_ExternalSymbol_t(
al, Lloc(x), current_scope, s2c(al, mangled_name), enum_member_orig,
ASRUtils::symbol_name(*itr), nullptr, 0, s2c(al, name), ASR::accessType::Public));
current_scope->add_symbol(mangled_name, enum_member);
tmp = ASR::make_EnumValue_t(al, Lloc(x), ASRUtils::EXPR(ASR::make_Var_t(al, Lloc(x), enum_member)),
ASRUtils::TYPE(ASR::make_Enum_t(al, Lloc(x), *itr)), enumtype_t->m_type,
ASRUtils::expr_value(ASR::down_cast<ASR::Variable_t>(enum_member_orig)->m_symbolic_value));
return true;
}
scope = scope->parent;
}
}
throw std::runtime_error("Symbol " + name + " not found in current scope.");
}
tmp = ASR::make_Var_t(al, Lloc(x), sym);
Expand Down
3 changes: 3 additions & 0 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,9 @@ static inline std::string type_to_str(const ASR::ttype_t *t)
case ASR::ttypeType::Struct: {
return ASRUtils::symbol_name(ASR::down_cast<ASR::Struct_t>(t)->m_derived_type);
}
case ASR::ttypeType::Enum: {
return ASRUtils::symbol_name(ASR::down_cast<ASR::Enum_t>(t)->m_enum_type);
}
case ASR::ttypeType::Class: {
return ASRUtils::symbol_name(ASR::down_cast<ASR::Class_t>(t)->m_class_type);
}
Expand Down

0 comments on commit d77844d

Please sign in to comment.