diff --git a/include/vast/Dialect/Parser/Ops.td b/include/vast/Dialect/Parser/Ops.td index 3718c8a86b..97e2ea3b42 100644 --- a/include/vast/Dialect/Parser/Ops.td +++ b/include/vast/Dialect/Parser/Ops.td @@ -100,4 +100,19 @@ def Parse_Ref }]; } +// TODO: Types should match +def Parse_Assign + : Parser_Op< "assign" > + , Arguments< (ins + Parser_AnyDataType:$value, + Parser_AnyDataType:$target + ) > +{ + let summary = "Assignment to referenced value."; + + let assemblyFormat = [{ + $value `to` $target attr-dict functional-type($value, $target) + }]; +} + #endif // VAST_DIALECT_PARSER_OPS diff --git a/include/vast/Util/Terminator.hpp b/include/vast/Util/Terminator.hpp index 79e215f9df..7990161905 100644 --- a/include/vast/Util/Terminator.hpp +++ b/include/vast/Util/Terminator.hpp @@ -6,104 +6,92 @@ VAST_RELAX_WARNINGS #include -#include #include +#include VAST_UNRELAX_WARNINGS #include #include #include -namespace vast -{ - namespace detail - { +namespace vast { + namespace detail { + template< typename self_t > struct terminator_base { - auto &self() { return static_cast< self_t & >( *this ); } - auto &self() const { return static_cast< const self_t & >( *this ); } + auto &self() { return static_cast< self_t & >(*this); } + + auto &self() const { return static_cast< const self_t & >(*this); } template< typename T > - T cast() const - { - if ( !self().has_value() ) + T cast() const { + if (!self().has_value()) { return {}; - return mlir::dyn_cast< T >( self().op_ptr() ); + } + return mlir::dyn_cast< T >(self().op_ptr()); } template< typename T > - T op() const - { + T op() const { auto out = self().template cast< T >(); - VAST_ASSERT( out ); + VAST_ASSERT(out); return out; } - template< typename ... Args > - bool is_one_of() - { - return self().has_value() && (mlir::isa< Args >( self().op_ptr() ) || ... ); + template< typename... Args > + bool is_one_of() { + return self().has_value() && (mlir::isa< Args >(self().op_ptr()) || ...); } - static bool has( mlir::Block &block ) - { - if ( std::distance( block.begin(), block.end() ) == 0 ) + static bool has(block_t &block) { + if (std::distance(block.begin(), block.end()) == 0) { return false; - return self_t::is( &block.back() ); + } + return self_t::is(&block.back()); } - static self_t get( mlir::Block &block ) - { - if ( !has( block ) ) + static self_t get(block_t &block) { + if (!has(block)) { return self_t{}; + } return self_t{ &block.back() }; } - }; - } // detail + } // namespace detail - static inline bool is_terminator( mlir::Operation *op ) - { + static inline bool is_terminator(operation op) { return op->hasTrait< mlir::OpTrait::IsTerminator >(); } - struct hard_terminator_t : std::optional< mlir::Operation * >, - detail::terminator_base< hard_terminator_t > + struct hard_terminator + : std::optional< operation > + , detail::terminator_base< hard_terminator > { - mlir::Operation *op_ptr() const { return **this; } - - static bool is( mlir::Operation *op ) - { - return is_terminator( op ); - } + operation op_ptr() const { return **this; } + static bool is(operation op) { return is_terminator(op); } }; - - struct any_terminator_t : std::optional< Operation * >, - detail::terminator_base< any_terminator_t > + struct any_terminator + : std::optional< operation > + , detail::terminator_base< any_terminator > { - mlir::Operation *op_ptr() const { return **this; } - - static bool is( mlir::Operation *op ) - { - return is_terminator( op ) || core::is_soft_terminator( op ); + operation op_ptr() const { return **this; } + static bool is(operation op) { + return is_terminator(op) || core::is_soft_terminator(op); } - }; template< typename op_t > - struct terminator_t : std::optional< Operation * >, - detail::terminator_base< terminator_t< op_t > > + struct terminator + : std::optional< operation > + , detail::terminator_base< terminator< op_t > > { - using impl = detail::terminator_base< terminator_t< op_t > >; + using impl = detail::terminator_base< terminator< op_t > >; - mlir::Operation *op_ptr() const { return **this; } + operation op_ptr() const { return **this; } - static bool is( mlir::Operation *op ) - { - return mlir::isa< op_t >( op ); - } + static bool is(operation op) { return mlir::isa< op_t >(op); } op_t op() { return this->impl::template op< op_t >(); } }; diff --git a/lib/vast/Conversion/FromHL/EmitLazyRegions.cpp b/lib/vast/Conversion/FromHL/EmitLazyRegions.cpp index 75e2c8f6d8..565fb15c19 100644 --- a/lib/vast/Conversion/FromHL/EmitLazyRegions.cpp +++ b/lib/vast/Conversion/FromHL/EmitLazyRegions.cpp @@ -39,7 +39,7 @@ namespace vast auto lazy_op = [&] { - if (!terminator_t< hl::ValueYieldOp >::get(side.front())) + if (!terminator< hl::ValueYieldOp >::get(side.front())) { return mk_lazy_op(mlir::NoneType::get(rewriter.getContext())); } @@ -93,7 +93,7 @@ namespace vast VAST_PATTERN_CHECK(conv::size(op.getCondRegion()) == 1, "Unsupported shape of cond region of hl::CondOp:\n{0}", op); - auto yield = terminator_t< hl::CondYieldOp >::get(cond_block); + auto yield = terminator< hl::CondYieldOp >::get(cond_block); VAST_PATTERN_CHECK(yield, "Was not able to retrieve cond yield, {0}.", op); rewriter.inlineBlockBefore(&cond_block, op, std::nullopt); diff --git a/lib/vast/Conversion/FromHL/ToLLCF.cpp b/lib/vast/Conversion/FromHL/ToLLCF.cpp index 542c659842..1c47e3a5a1 100644 --- a/lib/vast/Conversion/FromHL/ToLLCF.cpp +++ b/lib/vast/Conversion/FromHL/ToLLCF.cpp @@ -63,7 +63,7 @@ namespace vast::conv { } auto cond_yield(mlir::Block *block) { - auto cond_yield = mlir::cast< hl::CondYieldOp >(hard_terminator_t::get(*block).value()); + auto cond_yield = mlir::cast< hl::CondYieldOp >(hard_terminator::get(*block).value()); VAST_CHECK(cond_yield, "Block does not have a hl::CondYieldOp as terminator."); return cond_yield; } @@ -98,7 +98,7 @@ namespace vast::conv { namespace pattern { auto get_cond_yield(mlir::Block &block) { - return terminator_t< hl::CondYieldOp >::get(block).op(); + return terminator< hl::CondYieldOp >::get(block).op(); } // We do not use patterns, because for example `hl.continue` in for loop is kinda tricky @@ -214,7 +214,7 @@ namespace vast::conv { static logical_result tie(auto &&bld, auto loc, mlir::Block &from, mlir::Block &to) { - if (!empty(from) && any_terminator_t::get(from)) { + if (!empty(from) && any_terminator::get(from)) { return mlir::success(); } @@ -287,7 +287,7 @@ namespace vast::conv { // verification. // We are using any_terminator as it can have for example `hl.return` // or other soft terminator that will get eliminated in this pass. - if (!any_terminator_t::has(*tail_block)) { + if (!any_terminator::has(*tail_block)) { bld.guarded_at_end(tail_block, [&]() { bld->template create< ll::ScopeRet >(op.getLoc()); }); diff --git a/lib/vast/Conversion/Generic/LowerValueCategories.cpp b/lib/vast/Conversion/Generic/LowerValueCategories.cpp index e124b19b9d..bcf8013bf4 100644 --- a/lib/vast/Conversion/Generic/LowerValueCategories.cpp +++ b/lib/vast/Conversion/Generic/LowerValueCategories.cpp @@ -358,7 +358,7 @@ namespace vast::conv { if (!body) return logical_result::success(); - auto yield = terminator_t< yield_op_t >::get(*body); + auto yield = terminator< yield_op_t >::get(*body); VAST_PATTERN_CHECK(yield, "Expected yield in: {0}", op); rewriter.inlineBlockBefore(body, op); diff --git a/lib/vast/Conversion/Parser/ToParser.cpp b/lib/vast/Conversion/Parser/ToParser.cpp index 64eed12f5a..3c134d0b97 100644 --- a/lib/vast/Conversion/Parser/ToParser.cpp +++ b/lib/vast/Conversion/Parser/ToParser.cpp @@ -21,6 +21,7 @@ VAST_UNRELAX_WARNINGS #include "vast/Conversion/TypeConverters/TypeConvertingPattern.hpp" #include "vast/Util/Common.hpp" +#include "vast/Util/Terminator.hpp" #include "vast/Dialect/Parser/Ops.hpp" #include "vast/Dialect/Parser/Types.hpp" @@ -493,13 +494,97 @@ namespace vast::conv { } }; + struct ExprConversion + : parser_conversion_pattern_base< hl::ExprOp > + { + using op_t = hl::ExprOp; + using base = parser_conversion_pattern_base< op_t >; + using base::base; + + using adaptor_t = typename op_t::Adaptor; + + logical_result matchAndRewrite( + op_t op, adaptor_t adaptor, conversion_rewriter &rewriter + ) const override { + auto body = op.getBody(); + if (!body) { + return mlir::failure(); + } + + auto yield = terminator< hl::ValueYieldOp >::get(*body); + VAST_PATTERN_CHECK(yield, "Expected yield in: {0}", op); + + rewriter.inlineBlockBefore(body, op); + rewriter.replaceOp(op, yield.op().getResult()); + rewriter.eraseOp(yield.op()); + + return mlir::success(); + } + + static void legalize(parser_conversion_config &cfg) { + cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >(); + } + }; + + struct VarDeclConversion + : parser_conversion_pattern_base< hl::VarDeclOp > + { + using op_t = hl::VarDeclOp; + using base = parser_conversion_pattern_base< op_t >; + using base::base; + + using adaptor_t = typename op_t::Adaptor; + + logical_result matchAndRewrite( + op_t op, adaptor_t adaptor, conversion_rewriter &rewriter + ) const override { + auto maybe = to_mlir_type(data_type::maybedata, rewriter.getContext()); + /* auto decl = */ rewriter.create< pr::Decl >(op.getLoc(), op.getSymName(), maybe); + + if (auto &init_region = op.getInitializer(); !init_region.empty()) { + VAST_PATTERN_CHECK(init_region.getBlocks().size() == 1, "Expected single block in: {0}", op); + auto &init_block = init_region.back(); + auto yield = terminator< hl::ValueYieldOp >::get(init_block); + VAST_PATTERN_CHECK(yield, "Expected yield in: {0}", op); + + rewriter.inlineBlockBefore(&init_block, op); + rewriter.setInsertionPointAfter(op); + auto ref = rewriter.create< pr::Ref >(op.getLoc(), maybe, op.getSymName()); + auto value = rewriter.create< mlir::UnrealizedConversionCastOp >( + yield.op().getLoc(), maybe, yield.op().getResult() + ); + rewriter.create< pr::Assign >(yield.op().getLoc(), value.getResult(0), ref); + rewriter.eraseOp(yield.op()); + } + + rewriter.eraseOp(op); + return mlir::success(); + } + + static void legalize(parser_conversion_config &cfg) { + cfg.target.addLegalOp< pr::Decl, pr::Ref, pr::Assign >(); + cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >(); + } + }; + using operation_conversions = util::type_list< ToNoParse< hl::ConstantOp >, ToNoParse< hl::ImplicitCastOp >, - ToNoParse< hl::CmpOp>, ToNoParse< hl::FCmpOp >, + ToNoParse< hl::CmpOp >, ToNoParse< hl::FCmpOp >, + // Integer arithmetic + ToNoParse< hl::MulIOp >, + ToNoParse< hl::DivSOp >, ToNoParse< hl::DivUOp >, + ToNoParse< hl::RemSOp >, ToNoParse< hl::RemUOp >, + // Floating point arithmetic + ToNoParse< hl::AddFOp >, ToNoParse< hl::SubFOp >, + ToNoParse< hl::MulFOp >, ToNoParse< hl::DivFOp >, + ToNoParse< hl::RemFOp >, + // Other operations + ExprConversion, FuncConversion, ParamConversion, - // DeclRefConversion, + DeclRefConversion, + VarDeclConversion, ReturnConversion, CallConversion >; diff --git a/lib/vast/Conversion/ToLLVM/IRsToLLVM.cpp b/lib/vast/Conversion/ToLLVM/IRsToLLVM.cpp index 3f29451f76..c2649370cf 100644 --- a/lib/vast/Conversion/ToLLVM/IRsToLLVM.cpp +++ b/lib/vast/Conversion/ToLLVM/IRsToLLVM.cpp @@ -1347,7 +1347,7 @@ namespace vast::conv::irstollvm if (!body) return logical_result::success(); - auto yield = terminator_t< yield_op_t >::get(*body); + auto yield = terminator< yield_op_t >::get(*body); VAST_PATTERN_CHECK(yield, "Expected yield in: {0}", op); rewriter.inlineBlockBefore(body, op);