diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 4e12713..4e0b1e5 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -222,3 +222,5 @@ RUN(NAME function_01.cpp LABELS gcc llvm NOFAST) RUN(NAME nbody_01.cpp LABELS gcc llvm NOFAST) RUN(NAME nbody_02.cpp LABELS gcc llvm NOFAST) + +RUN(NAME enum_01.cpp LABELS gcc llvm NOFAST) diff --git a/integration_tests/enum_01.cpp b/integration_tests/enum_01.cpp new file mode 100644 index 0000000..281e24f --- /dev/null +++ b/integration_tests/enum_01.cpp @@ -0,0 +1,38 @@ +#include + +enum Color { + RED = 1, + GREEN = 2, + BLUE = 3 +}; + +void test_color_enum() { + std::cout << RED << " " << GREEN << " " << BLUE << std::endl; + if( RED != 1 ) { + exit(2); + } + if( GREEN != 2 ) { + exit(2); + } + if( BLUE != 3 ) { + exit(2); + } +} + +void test_selected_color(enum Color selected_color) { + enum Color color; + color = selected_color; + if( color != RED ) { + exit(2); + } + std::cout << color << std::endl; +} + +int main() { + +test_color_enum(); +test_selected_color(RED); + +return 0; + +} diff --git a/src/lc/clang_ast_to_asr.cpp b/src/lc/clang_ast_to_asr.cpp index 3df3053..2cbe383 100644 --- a/src/lc/clang_ast_to_asr.cpp +++ b/src/lc/clang_ast_to_asr.cpp @@ -161,6 +161,7 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor> struct2member_inits; + std::map> scope2enums; explicit ClangASTtoASRVisitor(clang::ASTContext *Context_, Allocator& al_, ASR::asr_t*& tu_): @@ -328,6 +329,13 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitorisVoidType() ) { // do nothing + } else if( clang_type->isEnumeralType() ) { + std::string enum_name = qual_type.getAsString().erase(0, 5); + ASR::symbol_t* enum_sym = current_scope->resolve_symbol(enum_name); + if( !enum_sym ) { + throw std::runtime_error(enum_name + " not found in current scope."); + } + type = ASRUtils::TYPE(ASR::make_Enum_t(al, l, enum_sym)); } else if( clang_type->isCharType() ) { type = ASRUtils::TYPE(ASR::make_Character_t(al, l, 1, -1, nullptr)); } else if( clang_type->isBooleanType() ) { @@ -538,6 +546,84 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitorgetName().str(); + if( current_scope->get_symbol(enum_name) ) { + throw std::runtime_error(enum_name + std::string(" is already defined.")); + } + + SymbolTable* parent_scope = current_scope; + current_scope = al.make_new(parent_scope); + Vec field_names; field_names.reserve(al, 1); + for( auto enum_const_itr = x->enumerator_begin(); + enum_const_itr != x->enumerator_end(); enum_const_itr++ ) { + clang::EnumConstantDecl* enum_const = *enum_const_itr; + std::string enum_const_name = enum_const->getNameAsString(); + field_names.push_back(al, s2c(al, enum_const_name)); + TraverseStmt(enum_const->getInitExpr()); + ASR::expr_t* init_expr = ASRUtils::EXPR(tmp.get()); + ASR::symbol_t* v = ASR::down_cast(ASR::make_Variable_t(al, Lloc(x), + current_scope, s2c(al, enum_const_name), nullptr, 0, ASR::intentType::Local, + init_expr, init_expr, ASR::storage_typeType::Default, ASRUtils::expr_type(init_expr), + nullptr, ASR::abiType::Source, ASR::accessType::Public, ASR::presenceType::Required, false)); + current_scope->add_symbol(enum_const_name, v); + } + + ASR::enumtypeType enum_value_type = ASR::enumtypeType::NonInteger; + ASR::ttype_t* common_type = ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4)); + int8_t IntegerConsecutiveFromZero = 1; + int8_t IntegerNotUnique = 0; + int8_t IntegerUnique = 1; + std::map value2count; + for( auto sym: current_scope->get_scope() ) { + ASR::Variable_t* member_var = ASR::down_cast(sym.second); + ASR::expr_t* value = ASRUtils::expr_value(member_var->m_symbolic_value); + int64_t value_int64 = -1; + ASRUtils::extract_value(value, value_int64); + if( value2count.find(value_int64) == value2count.end() ) { + value2count[value_int64] = 0; + } + value2count[value_int64] += 1; + } + int64_t prev = -1; + for( auto itr: value2count ) { + if( itr.second > 1 ) { + IntegerNotUnique = 1; + IntegerUnique = 0; + IntegerConsecutiveFromZero = 0; + break ; + } + if( itr.first - prev != 1 ) { + IntegerConsecutiveFromZero = 0; + } + prev = itr.first; + } + if( IntegerConsecutiveFromZero ) { + if( value2count.find(0) == value2count.end() ) { + IntegerConsecutiveFromZero = 0; + IntegerUnique = 1; + } else { + IntegerUnique = 0; + } + } + LCOMPILERS_ASSERT(IntegerConsecutiveFromZero + IntegerNotUnique + IntegerUnique == 1); + if( IntegerConsecutiveFromZero ) { + enum_value_type = ASR::enumtypeType::IntegerConsecutiveFromZero; + } else if( IntegerNotUnique ) { + enum_value_type = ASR::enumtypeType::IntegerNotUnique; + } else if( IntegerUnique ) { + enum_value_type = ASR::enumtypeType::IntegerUnique; + } + + ASR::symbol_t* enum_t = ASR::down_cast(ASR::make_EnumType_t(al, Lloc(x), + current_scope, s2c(al, enum_name), nullptr, 0, field_names.p, field_names.size(), + ASR::abiType::Source, ASR::accessType::Public, enum_value_type, common_type, nullptr)); + parent_scope->add_symbol(enum_name, enum_t); + current_scope = parent_scope; + scope2enums[current_scope].push_back(enum_t); + return true; + } + bool TraverseParmVarDecl(clang::ParmVarDecl* x) { std::string name = x->getName().str(); if( name == "" ) { @@ -1593,35 +1679,36 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor(*left_type) ) { + left_type = ASR::down_cast( + ASR::down_cast(left_type)->m_enum_type)->m_type; + } + if( ASR::is_a(*right_type) ) { + right_type = ASR::down_cast( + ASR::down_cast(right_type)->m_enum_type)->m_type; + } cast_helper(lhs, rhs, false); ASRUtils::make_ArrayBroadcast_t_util(al, loc, lhs, rhs); ASR::dimension_t* m_dims; - size_t n_dims = ASRUtils::extract_dimensions_from_ttype( - ASRUtils::expr_type(lhs), m_dims); + size_t n_dims = ASRUtils::extract_dimensions_from_ttype(left_type, m_dims); ASR::ttype_t* result_type = ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)); if( n_dims > 0 ) { result_type = ASRUtils::make_Array_t_util(al, loc, result_type, m_dims, n_dims); } - if( ASRUtils::is_integer(*ASRUtils::expr_type(lhs)) && - ASRUtils::is_integer(*ASRUtils::expr_type(rhs)) ) { - tmp = ASR::make_IntegerCompare_t(al, loc, lhs, - cmpop_type, rhs, result_type, nullptr); - } else if( ASRUtils::is_real(*ASRUtils::expr_type(lhs)) && - ASRUtils::is_real(*ASRUtils::expr_type(rhs)) ) { - tmp = ASR::make_RealCompare_t(al, loc, lhs, - cmpop_type, rhs, result_type, nullptr); - } else if( ASRUtils::is_logical(*ASRUtils::expr_type(lhs)) && - ASRUtils::is_logical(*ASRUtils::expr_type(rhs)) ) { - tmp = ASR::make_LogicalCompare_t(al, loc, lhs, - cmpop_type, rhs, result_type, nullptr); - } else if( ASRUtils::is_complex(*ASRUtils::expr_type(lhs)) && - ASRUtils::is_complex(*ASRUtils::expr_type(rhs)) ) { - tmp = ASR::make_ComplexCompare_t(al, loc, lhs, - cmpop_type, rhs, ASRUtils::expr_type(lhs), nullptr); + if( ASRUtils::is_integer(*left_type) && ASRUtils::is_integer(*right_type) ) { + tmp = ASR::make_IntegerCompare_t(al, loc, lhs, cmpop_type, rhs, result_type, nullptr); + } else if( ASRUtils::is_real(*left_type) && ASRUtils::is_real(*right_type) ) { + tmp = ASR::make_RealCompare_t(al, loc, lhs, cmpop_type, rhs, result_type, nullptr); + } else if( ASRUtils::is_logical(*left_type) && ASRUtils::is_logical(*right_type) ) { + tmp = ASR::make_LogicalCompare_t(al, loc, lhs, cmpop_type, rhs, result_type, nullptr); + } else if( ASRUtils::is_complex(*left_type) && ASRUtils::is_complex(*right_type) ) { + tmp = ASR::make_ComplexCompare_t(al, loc, lhs, cmpop_type, rhs, left_type, nullptr); } else { throw std::runtime_error("Only integer, real and complex types are supported so " - "far for comparison operator, found: " + ASRUtils::type_to_str(ASRUtils::expr_type(lhs)) - + " and " + ASRUtils::type_to_str(ASRUtils::expr_type(rhs))); + "far for comparison operator, found: " + ASRUtils::type_to_str(left_type) + + " and " + ASRUtils::type_to_str(right_type)); } } @@ -1926,6 +2013,29 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitorgetType()->isEnumeralType() ) { + SymbolTable* scope = current_scope; + while( scope ) { + for( auto itr = scope2enums[scope].begin(); itr != scope2enums[scope].end(); itr++ ) { + std::string mangled_name = current_scope->get_unique_name( + name + "@" + ASRUtils::symbol_name(*itr)); + ASR::EnumType_t* enumtype_t = ASR::down_cast(*itr); + ASR::symbol_t* enum_member_orig = enumtype_t->m_symtab->resolve_symbol(name); + if( enum_member_orig == nullptr ) { + continue ; + } + ASR::symbol_t* enum_member = ASR::down_cast(ASR::make_ExternalSymbol_t( + al, Lloc(x), current_scope, s2c(al, mangled_name), enum_member_orig, + ASRUtils::symbol_name(*itr), nullptr, 0, s2c(al, name), ASR::accessType::Public)); + current_scope->add_symbol(mangled_name, enum_member); + tmp = ASR::make_EnumValue_t(al, Lloc(x), ASRUtils::EXPR(ASR::make_Var_t(al, Lloc(x), enum_member)), + ASRUtils::TYPE(ASR::make_Enum_t(al, Lloc(x), *itr)), enumtype_t->m_type, + ASRUtils::expr_value(ASR::down_cast(enum_member_orig)->m_symbolic_value)); + return true; + } + scope = scope->parent; + } + } throw std::runtime_error("Symbol " + name + " not found in current scope."); } tmp = ASR::make_Var_t(al, Lloc(x), sym); diff --git a/src/libasr/asr_utils.h b/src/libasr/asr_utils.h index 39f6cf1..b5c3b56 100644 --- a/src/libasr/asr_utils.h +++ b/src/libasr/asr_utils.h @@ -482,6 +482,9 @@ static inline std::string type_to_str(const ASR::ttype_t *t) case ASR::ttypeType::Struct: { return ASRUtils::symbol_name(ASR::down_cast(t)->m_derived_type); } + case ASR::ttypeType::Enum: { + return ASRUtils::symbol_name(ASR::down_cast(t)->m_enum_type); + } case ASR::ttypeType::Class: { return ASRUtils::symbol_name(ASR::down_cast(t)->m_class_type); }