From aa38d6b5ae3ee783566282eafe243579cdd3684f Mon Sep 17 00:00:00 2001 From: xlauko Date: Tue, 8 Oct 2024 10:04:24 +0200 Subject: [PATCH] treewide: Introduce core function interface to avoid implicit symbol from mlir::FunctionOpInterface. --- .../TypeConverters/LLVMTypeConverter.hpp | 2 +- .../TypeConverters/TypeConverter.hpp | 2 + .../TypeConverters/TypeConvertingPattern.hpp | 4 +- include/vast/Dialect/ABI/ABIOps.hpp | 2 + include/vast/Dialect/Builtin/Ops.hpp | 1 - include/vast/Dialect/Core/CoreLazy.td | 1 - include/vast/Dialect/Core/CoreOps.hpp | 3 +- include/vast/Dialect/Core/Func.hpp | 9 +- include/vast/Dialect/Core/Func.td | 7 +- .../Dialect/Core/Interfaces/CMakeLists.txt | 1 + .../Core/Interfaces/DeclStorageInterface.hpp | 2 +- .../Core/Interfaces/DeclStorageInterface.td | 2 +- .../Interfaces/FunctionImplementation.hpp | 108 ++++ .../Core/Interfaces/FunctionInterface.hpp | 249 ++++++++ .../Core/Interfaces/FunctionInterface.td | 566 ++++++++++++++++++ .../vast/Dialect/HighLevel/HighLevelOps.hpp | 1 - .../vast/Dialect/HighLevel/HighLevelOps.td | 4 +- include/vast/Dialect/LowLevel/LowLevelOps.hpp | 1 - include/vast/Interfaces/AST/TypeInterface.td | 2 - include/vast/Util/Symbols.hpp | 2 +- include/vast/Util/TypeUtils.hpp | 8 +- lib/vast/CodeGen/DefaultDeclVisitor.cpp | 2 +- lib/vast/Conversion/ABI/EmitABI.cpp | 40 +- lib/vast/Conversion/ABI/LowerABI.cpp | 2 +- .../Generic/LowerValueCategories.cpp | 2 +- lib/vast/Conversion/ToLLVM/IRsToLLVM.cpp | 6 +- .../Conversion/ToMem/EvictStaticLocals.cpp | 37 +- .../Conversion/ToMem/StripParamLValues.cpp | 2 +- lib/vast/Dialect/Builtin/Ops.cpp | 1 - lib/vast/Dialect/Core/CMakeLists.txt | 1 + lib/vast/Dialect/Core/Func.cpp | 6 +- .../Dialect/Core/Interfaces/CMakeLists.txt | 7 + .../Interfaces/FunctionImplementation.cpp | 345 +++++++++++ .../Core/Interfaces/FunctionInterface.cpp | 366 +++++++++++ lib/vast/Dialect/HighLevel/HighLevelOps.cpp | 1 - .../HighLevel/Transforms/ExportFnInfo.cpp | 4 +- lib/vast/Dialect/LowLevel/LowLevelOps.cpp | 1 - 37 files changed, 1719 insertions(+), 81 deletions(-) create mode 100644 include/vast/Dialect/Core/Interfaces/FunctionImplementation.hpp create mode 100644 include/vast/Dialect/Core/Interfaces/FunctionInterface.hpp create mode 100644 include/vast/Dialect/Core/Interfaces/FunctionInterface.td create mode 100644 lib/vast/Dialect/Core/Interfaces/FunctionImplementation.cpp create mode 100644 lib/vast/Dialect/Core/Interfaces/FunctionInterface.cpp diff --git a/include/vast/Conversion/TypeConverters/LLVMTypeConverter.hpp b/include/vast/Conversion/TypeConverters/LLVMTypeConverter.hpp index a6fb49c3cf..400b1e6abf 100644 --- a/include/vast/Conversion/TypeConverters/LLVMTypeConverter.hpp +++ b/include/vast/Conversion/TypeConverters/LLVMTypeConverter.hpp @@ -139,7 +139,7 @@ namespace vast::conv::tc { maybe_type_t convert_memref_type(mlir::UnrankedMemRefType t) { return {}; } maybe_signature_conversion_t - get_conversion_signature(mlir::FunctionOpInterface fn, bool variadic) { + get_conversion_signature(core::function_op_interface fn, bool variadic) { signature_conversion_t conversion(fn.getNumArguments()); auto fn_type = mlir::dyn_cast< core::FunctionType >(fn.getFunctionType()); VAST_ASSERT(fn_type); diff --git a/include/vast/Conversion/TypeConverters/TypeConverter.hpp b/include/vast/Conversion/TypeConverters/TypeConverter.hpp index 2ce6f60b64..a569e19f26 100644 --- a/include/vast/Conversion/TypeConverters/TypeConverter.hpp +++ b/include/vast/Conversion/TypeConverters/TypeConverter.hpp @@ -13,6 +13,8 @@ VAST_UNRELAX_WARNINGS #include #include "vast/Dialect/Core/CoreTypes.hpp" +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" + #include "vast/Dialect/HighLevel/HighLevelDialect.hpp" #include "vast/Util/Common.hpp" diff --git a/include/vast/Conversion/TypeConverters/TypeConvertingPattern.hpp b/include/vast/Conversion/TypeConverters/TypeConvertingPattern.hpp index 1b167f379e..f21443834b 100644 --- a/include/vast/Conversion/TypeConverters/TypeConvertingPattern.hpp +++ b/include/vast/Conversion/TypeConverters/TypeConvertingPattern.hpp @@ -34,7 +34,7 @@ namespace vast::conv::tc { // TODO(conv:tc): This should probably be some interface instead, since // we are only updating the root? logical_result replace( - mlir::FunctionOpInterface fn, + core::function_op_interface fn, auto &rewriter ) const { auto old_type = fn.getFunctionType(); @@ -115,7 +115,7 @@ namespace vast::conv::tc { operation op, mlir::ArrayRef< mlir::Value >, conversion_rewriter &rewriter ) const override { - if (auto func_op = mlir::dyn_cast< mlir::FunctionOpInterface >(op)) + if (auto func_op = mlir::dyn_cast< core::function_op_interface >(op)) return this->replace(func_op, rewriter); return this->replace(op, rewriter); } diff --git a/include/vast/Dialect/ABI/ABIOps.hpp b/include/vast/Dialect/ABI/ABIOps.hpp index 83f3f9202a..a9c6025d82 100644 --- a/include/vast/Dialect/ABI/ABIOps.hpp +++ b/include/vast/Dialect/ABI/ABIOps.hpp @@ -21,6 +21,8 @@ VAST_RELAX_WARNINGS #include "vast/Dialect/Core/CoreTypes.hpp" #include "vast/Dialect/Core/CoreAttributes.hpp" #include "vast/Dialect/Core/Func.hpp" + +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" #include "vast/Dialect/Core/Interfaces/SymbolInterface.hpp" #include "vast/Dialect/Core/Interfaces/SymbolTableInterface.hpp" diff --git a/include/vast/Dialect/Builtin/Ops.hpp b/include/vast/Dialect/Builtin/Ops.hpp index 914ddafd7c..7513541fcf 100644 --- a/include/vast/Dialect/Builtin/Ops.hpp +++ b/include/vast/Dialect/Builtin/Ops.hpp @@ -5,7 +5,6 @@ #include "vast/Util/Warnings.hpp" VAST_RELAX_WARNINGS -#include #include #include VAST_UNRELAX_WARNINGS diff --git a/include/vast/Dialect/Core/CoreLazy.td b/include/vast/Dialect/Core/CoreLazy.td index 9a72fc783c..3c4bf9d528 100644 --- a/include/vast/Dialect/Core/CoreLazy.td +++ b/include/vast/Dialect/Core/CoreLazy.td @@ -7,7 +7,6 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" -include "mlir/Interfaces/FunctionInterfaces.td" include "vast/Dialect/Core/Interfaces/SymbolInterface.td" class Core_LazyEval< string mnemonic, list < Trait > traits = [] > diff --git a/include/vast/Dialect/Core/CoreOps.hpp b/include/vast/Dialect/Core/CoreOps.hpp index 2033b5a399..3af048280a 100644 --- a/include/vast/Dialect/Core/CoreOps.hpp +++ b/include/vast/Dialect/Core/CoreOps.hpp @@ -11,7 +11,6 @@ #include "vast/Util/Common.hpp" #include "vast/Util/TypeList.hpp" -#include #include #include #include @@ -24,4 +23,4 @@ namespace vast::core { using module = core::ModuleOp; -} // namespace vast::core \ No newline at end of file +} // namespace vast::core diff --git a/include/vast/Dialect/Core/Func.hpp b/include/vast/Dialect/Core/Func.hpp index 70c78a9009..8ec8153eef 100644 --- a/include/vast/Dialect/Core/Func.hpp +++ b/include/vast/Dialect/Core/Func.hpp @@ -13,6 +13,9 @@ VAST_UNRELAX_WARNINGS #include "vast/Dialect/Core/Linkage.hpp" #include "vast/Dialect/Core/CoreTypes.hpp" +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" +#include "vast/Dialect/Core/Interfaces/FunctionImplementation.hpp" + namespace vast::core { llvm::StringRef getLinkageAttrNameString(); @@ -77,11 +80,11 @@ namespace vast::core { } auto fty = op.getFunctionType(); - mlir::function_interface_impl::printFunctionSignature( + vast::core::function_interface_impl::printFunctionSignature( printer, op, fty.getInputs(), fty.isVarArg(), fty.getResults() ); - mlir::function_interface_impl::printFunctionAttributes( + vast::core::function_interface_impl::printFunctionAttributes( printer, op, { getLinkageAttrNameString(), op.getFunctionTypeAttrName() } ); @@ -160,7 +163,7 @@ namespace vast::core { template< typename DstFuncOp > logical_result convert_and_replace_function(auto src, auto &rewriter) { - return convert_and_replace_function< DstFuncOp >(src, rewriter, src.getName()); + return convert_and_replace_function< DstFuncOp >(src, rewriter, src.getSymbolName()); } } // namespace vast::core diff --git a/include/vast/Dialect/Core/Func.td b/include/vast/Dialect/Core/Func.td index ed7cbe2925..63ea03e124 100644 --- a/include/vast/Dialect/Core/Func.td +++ b/include/vast/Dialect/Core/Func.td @@ -6,8 +6,8 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/CallInterfaces.td" -include "mlir/Interfaces/FunctionInterfaces.td" include "vast/Dialect/Core/Interfaces/SymbolInterface.td" +include "vast/Dialect/Core/Interfaces/FunctionInterface.td" include "vast/Dialect/Core/CoreTraits.td" include "vast/Dialect/Core/LinkageHelper.td" @@ -22,9 +22,8 @@ class Core_FuncBaseOp< Dialect dialect, string mnemonic, list< Trait > traits = : Op< dialect, mnemonic, !listconcat(traits, [ AutomaticAllocationScope, - CallableOpInterface, - FunctionOpInterface, IsolatedFromAbove, + Core_FunctionOpInterface, Core_FuncSymbol, NoTerminator ]) @@ -86,7 +85,7 @@ class Core_FuncBaseOp< Dialect dialect, string mnemonic, list< Trait > traits = if (arg_attrs.empty()) return; - mlir::function_interface_impl::addArgAndResultAttrs( + vast::core::function_interface_impl::addArgAndResultAttrs( $_builder, $_state, arg_attrs, res_attrs, getArgAttrsAttrName($_state.name), getResAttrsAttrName($_state.name) ); diff --git a/include/vast/Dialect/Core/Interfaces/CMakeLists.txt b/include/vast/Dialect/Core/Interfaces/CMakeLists.txt index d6ab0139a1..b86f5c3fa1 100644 --- a/include/vast/Dialect/Core/Interfaces/CMakeLists.txt +++ b/include/vast/Dialect/Core/Interfaces/CMakeLists.txt @@ -1,4 +1,5 @@ add_vast_op_interface(DeclStorageInterface) +add_vast_op_interface(FunctionInterface) add_vast_op_interface(TypeDefinitionInterface) add_vast_op_interface(SymbolInterface) diff --git a/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.hpp b/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.hpp index b762a0c9d1..455dfa55ed 100644 --- a/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.hpp +++ b/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.hpp @@ -5,7 +5,6 @@ #include "vast/Util/Warnings.hpp" VAST_RELAX_WARNINGS -#include #include #include #include @@ -13,6 +12,7 @@ VAST_RELAX_WARNINGS VAST_RELAX_WARNINGS #include "vast/Dialect/Core/CoreOps.hpp" +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" #define GET_OP_FWD_DEFINES #include "vast/Dialect/HighLevel/HighLevel.h.inc" diff --git a/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.td b/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.td index 06df4b0108..84e8c0a244 100644 --- a/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.td +++ b/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.td @@ -72,7 +72,7 @@ def Core_DeclStorageInterface : Core_OpInterface< "DeclStorageInterface" > { return kind_attr.getValue(); } auto st = core::get_effective_symbol_table_for< core::var_symbol >($_op)->get_defining_operation(); - if (mlir::isa< mlir::FunctionOpInterface >(st)) + if (mlir::isa< core::function_op_interface >(st)) return DeclContextKind::dc_function; if (st->template hasTrait< core::ScopeLikeTrait >()) return DeclContextKind::dc_function; diff --git a/include/vast/Dialect/Core/Interfaces/FunctionImplementation.hpp b/include/vast/Dialect/Core/Interfaces/FunctionImplementation.hpp new file mode 100644 index 0000000000..a0f28af814 --- /dev/null +++ b/include/vast/Dialect/Core/Interfaces/FunctionImplementation.hpp @@ -0,0 +1,108 @@ +//===- FunctionImplementation.h - Function-like Op utilities ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utility functions for implementing function-like +// operations, in particular, parsing, printing and verification components +// common to function-like operations. +// +//===----------------------------------------------------------------------===// + +#ifndef VAST_DIALECT_CORE_FUNCTIONIMPLEMENTATION_HPP +#define VAST_DIALECT_CORE_FUNCTIONIMPLEMENTATION_HPP + +#include "vast/Util/Warnings.hpp" + +VAST_RELAX_WARNINGS +#include +VAST_UNRELAX_WARNINGS + +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" + +namespace vast::core { + +namespace function_interface_impl { + +/// A named class for passing around the variadic flag. +class VariadicFlag { +public: + explicit VariadicFlag(bool variadic) : variadic(variadic) {} + bool isVariadic() const { return variadic; } + +private: + /// Underlying storage. + bool variadic; +}; + +/// Adds argument and result attributes, provided as `argAttrs` and +/// `resultAttrs` arguments, to the list of operation attributes in `result`. +/// Internally, argument and result attributes are stored as dict attributes +/// with special names given by getResultAttrName, getArgumentAttrName. +void addArgAndResultAttrs(::mlir::Builder &builder, ::mlir::OperationState &result, + ::mlir::ArrayRef<::mlir::DictionaryAttr> argAttrs, + ::mlir::ArrayRef<::mlir::DictionaryAttr> resultAttrs, + ::mlir::StringAttr argAttrsName, ::mlir::StringAttr resAttrsName); +void addArgAndResultAttrs(::mlir::Builder &builder, ::mlir::OperationState &result, + ::mlir::ArrayRef<::mlir::OpAsmParser::Argument> args, + ::mlir::ArrayRef<::mlir::DictionaryAttr> resultAttrs, + ::mlir::StringAttr argAttrsName, ::mlir::StringAttr resAttrsName); + +/// Callback type for `parseFunctionOp`, the callback should produce the +/// type that will be associated with a function-like operation from lists of +/// function arguments and results, VariadicFlag indicates whether the function +/// should have variadic arguments; in case of error, it may populate the last +/// argument with a message. +using FuncTypeBuilder = ::mlir::function_ref<::mlir::Type( + ::mlir::Builder &, ::mlir::ArrayRef<::mlir::Type>, ::mlir::ArrayRef<::mlir::Type>, VariadicFlag, std::string &)>; + +/// Parses a function signature using `parser`. The `allowVariadic` argument +/// indicates whether functions with variadic arguments are supported. The +/// trailing arguments are populated by this function with names, types, +/// attributes and locations of the arguments and those of the results. +::mlir::ParseResult +parseFunctionSignature(::mlir::OpAsmParser &parser, bool allowVariadic, + ::mlir::SmallVectorImpl<::mlir::OpAsmParser::Argument> &arguments, + bool &isVariadic, ::mlir::SmallVectorImpl<::mlir::Type> &resultTypes, + ::mlir::SmallVectorImpl<::mlir::DictionaryAttr> &resultAttrs); + +/// Parser implementation for function-like operations. Uses +/// `funcTypeBuilder` to construct the custom function type given lists of +/// input and output types. The parser sets the `typeAttrName` attribute to the +/// resulting function type. If `allowVariadic` is set, the parser will accept +/// trailing ellipsis in the function signature and indicate to the builder +/// whether the function is variadic. If the builder returns a null type, +/// `result` will not contain the `type` attribute. The caller can then add a +/// type, report the error or delegate the reporting to the op's verifier. +::mlir::ParseResult parseFunctionOp(::mlir::OpAsmParser &parser, ::mlir::OperationState &result, + bool allowVariadic, ::mlir::StringAttr typeAttrName, + FuncTypeBuilder funcTypeBuilder, + ::mlir::StringAttr argAttrsName, ::mlir::StringAttr resAttrsName); + +/// Printer implementation for function-like operations. +void printFunctionOp(::mlir::OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, + ::mlir::StringRef typeAttrName, ::mlir::StringAttr argAttrsName, + ::mlir::StringAttr resAttrsName); + +/// Prints the signature of the function-like operation `op`. Assumes `op` has +/// is a FunctionOpInterface and has passed verification. +void printFunctionSignature(::mlir::OpAsmPrinter &p, FunctionOpInterface op, + ::mlir::ArrayRef<::mlir::Type> argTypes, bool isVariadic, + ::mlir::ArrayRef<::mlir::Type> resultTypes); + +/// Prints the list of function prefixed with the "attributes" keyword. The +/// attributes with names listed in "elided" as well as those used by the +/// function-like operation internally are not printed. Nothing is printed +/// if all attributes are elided. Assumes `op` is a FunctionOpInterface and +/// has passed verification. +void printFunctionAttributes(::mlir::OpAsmPrinter &p, ::mlir::Operation *op, + ::mlir::ArrayRef<::mlir::StringRef> elided = {}); + +} // namespace function_interface_impl + +} // namespace vast::core + +#endif // VAST_DIALECT_CORE_FUNCTIONIMPLEMENTATION_HPP diff --git a/include/vast/Dialect/Core/Interfaces/FunctionInterface.hpp b/include/vast/Dialect/Core/Interfaces/FunctionInterface.hpp new file mode 100644 index 0000000000..23647af2c4 --- /dev/null +++ b/include/vast/Dialect/Core/Interfaces/FunctionInterface.hpp @@ -0,0 +1,249 @@ +//===- FunctionSupport.h - Utility types for function-like ops --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines support types for Operations that represent function-like +// constructs to use. +// +//===----------------------------------------------------------------------===// + +#ifndef VAST_DIALECT_CORE_FUNCTIONINTERFACE_HPP +#define VAST_DIALECT_CORE_FUNCTIONINTERFACE_HPP + +#include "vast/Util/Warnings.hpp" + +VAST_RELAX_WARNINGS +#include +#include +#include +#include +#include +#include +#include +#include +VAST_UNRELAX_WARNINGS + +namespace vast::core { +class FunctionOpInterface; + +using function_op_interface = FunctionOpInterface; + +namespace function_interface_impl { + +/// Returns the dictionary attribute corresponding to the argument at 'index'. +/// If there are no argument attributes at 'index', a null attribute is +/// returned. +mlir::DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index); + +/// Returns the dictionary attribute corresponding to the result at 'index'. +/// If there are no result attributes at 'index', a null attribute is +/// returned. +mlir::DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index); + +/// Return all of the attributes for the argument at 'index'. +mlir::ArrayRef getArgAttrs(FunctionOpInterface op, unsigned index); + +/// Return all of the attributes for the result at 'index'. +mlir::ArrayRef getResultAttrs(FunctionOpInterface op, unsigned index); + +/// Set all of the argument or result attribute dictionaries for a function. The +/// size of `attrs` is expected to match the number of arguments/results of the +/// given `op`. +void setAllArgAttrDicts(FunctionOpInterface op, mlir::ArrayRef attrs); +void setAllArgAttrDicts(FunctionOpInterface op, mlir::ArrayRef attrs); +void setAllResultAttrDicts(FunctionOpInterface op, + mlir::ArrayRef attrs); +void setAllResultAttrDicts(FunctionOpInterface op, mlir::ArrayRef attrs); + +/// Insert the specified arguments and update the function type attribute. +void insertFunctionArguments(FunctionOpInterface op, + mlir::ArrayRef argIndices, mlir::TypeRange argTypes, + mlir::ArrayRef argAttrs, + mlir::ArrayRef argLocs, + unsigned originalNumArgs, mlir::Type newType); + +/// Insert the specified results and update the function type attribute. +void insertFunctionResults(FunctionOpInterface op, + mlir::ArrayRef resultIndices, + mlir::TypeRange resultTypes, + mlir::ArrayRef resultAttrs, + unsigned originalNumResults, mlir::Type newType); + +/// Erase the specified arguments and update the function type attribute. +void eraseFunctionArguments(FunctionOpInterface op, const mlir::BitVector &argIndices, + mlir::Type newType); + +/// Erase the specified results and update the function type attribute. +void eraseFunctionResults(FunctionOpInterface op, + const mlir::BitVector &resultIndices, mlir::Type newType); + +/// Set a FunctionOpInterface operation's type signature. +void setFunctionType(FunctionOpInterface op, mlir::Type newType); + +//===----------------------------------------------------------------------===// +// Function Argument mlir::Attribute. +//===----------------------------------------------------------------------===// + +/// Set the attributes held by the argument at 'index'. +void setArgAttrs(FunctionOpInterface op, unsigned index, + mlir::ArrayRef attributes); +void setArgAttrs(FunctionOpInterface op, unsigned index, + mlir::DictionaryAttr attributes); + +/// If the an attribute exists with the specified name, change it to the new +/// value. Otherwise, add a new attribute with the specified name/value. +template +void setArgAttr(ConcreteType op, unsigned index, mlir::StringAttr name, + mlir::Attribute value) { + mlir::NamedAttrList attributes(op.getArgAttrDict(index)); + mlir::Attribute oldValue = attributes.set(name, value); + + // If the attribute changed, then set the new arg attribute list. + if (value != oldValue) + op.setArgAttrs(index, attributes.getDictionary(value.getContext())); +} + +/// Remove the attribute 'name' from the argument at 'index'. Returns the +/// removed attribute, or nullptr if `name` was not a valid attribute. +template +mlir::Attribute removeArgAttr(ConcreteType op, unsigned index, mlir::StringAttr name) { + // Build an attribute list and remove the attribute at 'name'. + mlir::NamedAttrList attributes(op.getArgAttrDict(index)); + mlir::Attribute removedAttr = attributes.erase(name); + + // If the attribute was removed, then update the argument dictionary. + if (removedAttr) + op.setArgAttrs(index, attributes.getDictionary(removedAttr.getContext())); + return removedAttr; +} + +//===----------------------------------------------------------------------===// +// Function Result mlir::Attribute. +//===----------------------------------------------------------------------===// + +/// Set the attributes held by the result at 'index'. +void setResultAttrs(FunctionOpInterface op, unsigned index, + mlir::ArrayRef attributes); +void setResultAttrs(FunctionOpInterface op, unsigned index, + mlir::DictionaryAttr attributes); + +/// If the an attribute exists with the specified name, change it to the new +/// value. Otherwise, add a new attribute with the specified name/value. +template +void setResultAttr(ConcreteType op, unsigned index, mlir::StringAttr name, + mlir::Attribute value) { + mlir::NamedAttrList attributes(op.getResultAttrDict(index)); + mlir::Attribute oldAttr = attributes.set(name, value); + + // If the attribute changed, then set the new arg attribute list. + if (oldAttr != value) + op.setResultAttrs(index, attributes.getDictionary(value.getContext())); +} + +/// Remove the attribute 'name' from the result at 'index'. +template +mlir::Attribute removeResultAttr(ConcreteType op, unsigned index, mlir::StringAttr name) { + // Build an attribute list and remove the attribute at 'name'. + mlir::NamedAttrList attributes(op.getResultAttrDict(index)); + mlir::Attribute removedAttr = attributes.erase(name); + + // If the attribute was removed, then update the result dictionary. + if (removedAttr) + op.setResultAttrs(index, + attributes.getDictionary(removedAttr.getContext())); + return removedAttr; +} + +/// This function defines the internal implementation of the `verifyTrait` +/// method on FunctionOpInterface::Trait. +template +mlir::LogicalResult verifyTrait(ConcreteOp op) { + if (failed(op.verifyType())) + return llvm::failure(); + + if (mlir::ArrayAttr allArgAttrs = op.getAllArgAttrs()) { + unsigned numArgs = op.getNumArguments(); + if (allArgAttrs.size() != numArgs) { + return op.emitOpError() + << "expects argument attribute array to have the same number of " + "elements as the number of function arguments, got " + << allArgAttrs.size() << ", but expected " << numArgs; + } + for (unsigned i = 0; i != numArgs; ++i) { + mlir::DictionaryAttr argAttrs = + llvm::dyn_cast_or_null(allArgAttrs[i]); + if (!argAttrs) { + return op.emitOpError() << "expects argument attribute dictionary " + "to be a mlir::DictionaryAttr, but got `" + << allArgAttrs[i] << "`"; + } + + // Verify that all of the argument attributes are dialect attributes, i.e. + // that they contain a dialect prefix in their name. Call the dialect, if + // registered, to verify the attributes themselves. + for (auto attr : argAttrs) { + if (!attr.getName().strref().contains('.')) + return op.emitOpError("arguments may only have dialect attributes"); + if (mlir::Dialect *dialect = attr.getNameDialect()) { + if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, + /*argIndex=*/i, attr))) + return llvm::failure(); + } + } + } + } + if (mlir::ArrayAttr allResultAttrs = op.getAllResultAttrs()) { + unsigned numResults = op.getNumResults(); + if (allResultAttrs.size() != numResults) { + return op.emitOpError() + << "expects result attribute array to have the same number of " + "elements as the number of function results, got " + << allResultAttrs.size() << ", but expected " << numResults; + } + for (unsigned i = 0; i != numResults; ++i) { + mlir::DictionaryAttr resultAttrs = + llvm::dyn_cast_or_null(allResultAttrs[i]); + if (!resultAttrs) { + return op.emitOpError() << "expects result attribute dictionary " + "to be a mlir::DictionaryAttr, but got `" + << allResultAttrs[i] << "`"; + } + + // Verify that all of the result attributes are dialect attributes, i.e. + // that they contain a dialect prefix in their name. Call the dialect, if + // registered, to verify the attributes themselves. + for (auto attr : resultAttrs) { + if (!attr.getName().strref().contains('.')) + return op.emitOpError("results may only have dialect attributes"); + if (mlir::Dialect *dialect = attr.getNameDialect()) { + if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, + /*resultIndex=*/i, + attr))) + return llvm::failure(); + } + } + } + } + + // Check that the op has exactly one region for the body. + if (op->getNumRegions() != 1) + return op.emitOpError("expects one region"); + + return op.verifyBody(); +} +} // namespace function_interface_impl +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Tablegen Interface Declarations +//===----------------------------------------------------------------------===// + +VAST_RELAX_WARNINGS +#include "vast/Dialect/Core/Interfaces/FunctionInterface.h.inc" +VAST_UNRELAX_WARNINGS + +#endif // VAST_DIALECT_CORE_FUNCTIONINTERFACE_HPP diff --git a/include/vast/Dialect/Core/Interfaces/FunctionInterface.td b/include/vast/Dialect/Core/Interfaces/FunctionInterface.td new file mode 100644 index 0000000000..139bd691f6 --- /dev/null +++ b/include/vast/Dialect/Core/Interfaces/FunctionInterface.td @@ -0,0 +1,566 @@ +//===- FunctionInterfaces.td - Function interfaces --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions for interfaces that support the definition of +// "function-like" operations. +// +//===----------------------------------------------------------------------===// + +#ifndef VAST_DIALECT_CORE_FUNCTIONINTERFACE_TD_ +#define VAST_DIALECT_CORE_FUNCTIONINTERFACE_TD_ + +include "mlir/Interfaces/CallInterfaces.td" + +//===----------------------------------------------------------------------===// +// FunctionOpInterface +//===----------------------------------------------------------------------===// + +def Core_FunctionOpInterface : OpInterface<"FunctionOpInterface", [ + CallableOpInterface + ]> { + let cppNamespace = "::vast::core"; + let description = [{ + This is a copy of mlir::FunctionOpInterface with the following changes: + - The interface is moved to the vast::core namespace. + - The interface does not implicitly inherit from SymbolOpInterface. + + This interfaces provides support for interacting with operations that + behave like functions. In particular, these operations: + + - must be symbols, i.e. have the `Symbol` trait. + - must have a single region, that may be comprised with multiple blocks, + that corresponds to the function body. + * when this region is empty, the operation corresponds to an external + function. + * leading arguments of the first block of the region are treated as + function arguments. + + The function, aside from implementing the various interface methods, + should have the following ODS arguments: + + - `function_type` (required) + * A TypeAttr that holds the signature type of the function. + + - `arg_attrs` (optional) + * An ArrayAttr of DictionaryAttr that contains attribute dictionaries + for each of the function arguments. + + - `res_attrs` (optional) + * An ArrayAttr of DictionaryAttr that contains attribute dictionaries + for each of the function results. + }]; + let methods = [ + InterfaceMethod<[{ + Returns the type of the function. + }], + "::mlir::Type", "getFunctionType">, + InterfaceMethod<[{ + Set the type of the function. This method should perform an unsafe + modification to the function type; it should not update argument or + result attributes. + }], + "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>, + + InterfaceMethod<[{ + Returns a clone of the function type with the given argument and + result types. + + Note: The default implementation assumes the function type has + an appropriate clone method: + `Type clone(ArrayRef inputs, ArrayRef results)` + }], + "::mlir::Type", "cloneTypeWith", (ins + "::mlir::TypeRange":$inputs, "::mlir::TypeRange":$results + ), /*methodBody=*/[{}], /*defaultImplementation=*/[{ + return $_op.getFunctionType().clone(inputs, results); + }]>, + + InterfaceMethod<[{ + Verify the contents of the body of this function. + + Note: The default implementation merely checks that if the entry block + exists, it has the same number and type of arguments as the function type. + }], + "::llvm::LogicalResult", "verifyBody", (ins), + /*methodBody=*/[{}], /*defaultImplementation=*/[{ + if ($_op.isExternal()) + return ::mlir::success(); + ::llvm::ArrayRef<::mlir::Type> fnInputTypes = $_op.getArgumentTypes(); + // NOTE: This should just be $_op.front() but access generically + // because the interface methods defined here may be shadowed in + // arbitrary ways. https://github.com/llvm/llvm-project/issues/54807 + ::mlir::Block &entryBlock = $_op->getRegion(0).front(); + + unsigned numArguments = fnInputTypes.size(); + if (entryBlock.getNumArguments() != numArguments) + return $_op.emitOpError("entry block must have ") + << numArguments << " arguments to match function signature"; + + for (unsigned i = 0, e = fnInputTypes.size(); i != e; ++i) { + ::mlir::Type argType = entryBlock.getArgument(i).getType(); + if (fnInputTypes[i] != argType) { + return $_op.emitOpError("type of entry block argument #") + << i << '(' << argType + << ") must match the type of the corresponding argument in " + << "function signature(" << fnInputTypes[i] << ')'; + } + } + + return ::mlir::success(); + }]>, + InterfaceMethod<[{ + Verify the type attribute of the function for derived op-specific + invariants. + }], + "::llvm::LogicalResult", "verifyType", (ins), + /*methodBody=*/[{}], /*defaultImplementation=*/[{ + return ::mlir::success(); + }]>, + ]; + + let extraTraitClassDeclaration = [{ + //===------------------------------------------------------------------===// + // Builders + //===------------------------------------------------------------------===// + + /// Build the function with the given name, attributes, and type. This + /// builder also inserts an entry block into the function body with the + /// given argument types. + static void buildWithEntryBlock( + ::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::StringRef name, ::mlir::Type type, + ::mlir::ArrayRef<::mlir::NamedAttribute> attrs, ::mlir::TypeRange inputTypes) { + ::mlir::OpBuilder::InsertionGuard g(builder); + state.addAttribute(::mlir::SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name), + ::mlir::TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + + // Add the function body. + ::mlir::Region *bodyRegion = state.addRegion(); + ::mlir::Block *body = builder.createBlock(bodyRegion); + for (::mlir::Type input : inputTypes) + body->addArgument(input, state.location); + } + }]; + let extraSharedClassDeclaration = [{ + /// Block list iterator types. + using BlockListType = ::mlir::Region::BlockListType; + using iterator = BlockListType::iterator; + using reverse_iterator = BlockListType::reverse_iterator; + + /// Block argument iterator types. + using BlockArgListType = ::mlir::Region::BlockArgListType; + using args_iterator = BlockArgListType::iterator; + + //===------------------------------------------------------------------===// + // Body Handling + //===------------------------------------------------------------------===// + + /// Returns true if this function is external, i.e. it has no body. + bool isExternal() { return empty(); } + + /// Return the region containing the body of this function. + ::mlir::Region &getFunctionBody() { return $_op->getRegion(0); } + + /// Delete all blocks from this function. + void eraseBody() { + getFunctionBody().dropAllReferences(); + getFunctionBody().getBlocks().clear(); + } + + /// Return the list of blocks within the function body. + BlockListType &getBlocks() { return getFunctionBody().getBlocks(); } + + iterator begin() { return getFunctionBody().begin(); } + iterator end() { return getFunctionBody().end(); } + reverse_iterator rbegin() { return getFunctionBody().rbegin(); } + reverse_iterator rend() { return getFunctionBody().rend(); } + + /// Returns true if this function has no blocks within the body. + bool empty() { return getFunctionBody().empty(); } + + /// Push a new block to the back of the body region. + void push_back(::mlir::Block *block) { getFunctionBody().push_back(block); } + + /// Push a new block to the front of the body region. + void push_front(::mlir::Block *block) { getFunctionBody().push_front(block); } + + /// Return the last block in the body region. + ::mlir::Block &back() { return getFunctionBody().back(); } + + /// Return the first block in the body region. + ::mlir::Block &front() { return getFunctionBody().front(); } + + /// Add an entry block to an empty function, and set up the block arguments + /// to match the signature of the function. The newly inserted entry block + /// is returned. + ::mlir::Block *addEntryBlock() { + assert(empty() && "function already has an entry block"); + ::mlir::Block *entry = new ::mlir::Block(); + push_back(entry); + + // FIXME: Allow for passing in locations for these arguments instead of using + // the operations location. + ::llvm::ArrayRef<::mlir::Type> inputTypes = $_op.getArgumentTypes(); + ::llvm::SmallVector<::mlir::Location> locations(inputTypes.size(), + $_op.getOperation()->getLoc()); + entry->addArguments(inputTypes, locations); + return entry; + } + + /// Add a normal block to the end of the function's block list. The function + /// should at least already have an entry block. + ::mlir::Block *addBlock() { + assert(!empty() && "function should at least have an entry block"); + push_back(new ::mlir::Block()); + return &back(); + } + + //===------------------------------------------------------------------===// + // Type Attribute Handling + //===------------------------------------------------------------------===// + + /// Change the type of this function in place. This is an extremely dangerous + /// operation and it is up to the caller to ensure that this is legal for + /// this function, and to restore invariants: + /// - the entry block args must be updated to match the function params. + /// - the argument/result attributes may need an update: if the new type + /// has less parameters we drop the extra attributes, if there are more + /// parameters they won't have any attributes. + void setType(::mlir::Type newType) { + ::vast::core::function_interface_impl::setFunctionType($_op, newType); + } + + //===------------------------------------------------------------------===// + // Argument and Result Handling + //===------------------------------------------------------------------===// + + /// Returns the number of function arguments. + unsigned getNumArguments() { return $_op.getArgumentTypes().size(); } + + /// Returns the number of function results. + unsigned getNumResults() { return $_op.getResultTypes().size(); } + + /// Returns the entry block function argument at the given index. + ::mlir::BlockArgument getArgument(unsigned idx) { + return getFunctionBody().getArgument(idx); + } + + /// Support argument iteration. + args_iterator args_begin() { return getFunctionBody().args_begin(); } + args_iterator args_end() { return getFunctionBody().args_end(); } + BlockArgListType getArguments() { return getFunctionBody().getArguments(); } + + /// Insert a single argument of type `argType` with attributes `argAttrs` and + /// location `argLoc` at `argIndex`. + void insertArgument(unsigned argIndex, ::mlir::Type argType, ::mlir::DictionaryAttr argAttrs, + ::mlir::Location argLoc) { + insertArguments({argIndex}, {argType}, {argAttrs}, {argLoc}); + } + + /// Inserts arguments with the listed types, attributes, and locations at the + /// listed indices. `argIndices` must be sorted. Arguments are inserted in the + /// order they are listed, such that arguments with identical index will + /// appear in the same order that they were listed here. + void insertArguments(::llvm::ArrayRef argIndices, ::mlir::TypeRange argTypes, + ::llvm::ArrayRef<::mlir::DictionaryAttr> argAttrs, + ::llvm::ArrayRef<::mlir::Location> argLocs) { + unsigned originalNumArgs = $_op.getNumArguments(); + ::mlir::Type newType = $_op.getTypeWithArgsAndResults( + argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{}); + ::vast::core::function_interface_impl::insertFunctionArguments( + $_op, argIndices, argTypes, argAttrs, argLocs, + originalNumArgs, newType); + } + + /// Insert a single result of type `resultType` at `resultIndex`. + void insertResult(unsigned resultIndex, ::mlir::Type resultType, + ::mlir::DictionaryAttr resultAttrs) { + insertResults({resultIndex}, {resultType}, {resultAttrs}); + } + + /// Inserts results with the listed types at the listed indices. + /// `resultIndices` must be sorted. Results are inserted in the order they are + /// listed, such that results with identical index will appear in the same + /// order that they were listed here. + void insertResults(::llvm::ArrayRef resultIndices, ::mlir::TypeRange resultTypes, + ::llvm::ArrayRef<::mlir::DictionaryAttr> resultAttrs) { + unsigned originalNumResults = $_op.getNumResults(); + ::mlir::Type newType = $_op.getTypeWithArgsAndResults( + /*argIndices=*/{}, /*argTypes=*/{}, resultIndices, resultTypes); + ::vast::core::function_interface_impl::insertFunctionResults( + $_op, resultIndices, resultTypes, resultAttrs, + originalNumResults, newType); + } + + /// Erase a single argument at `argIndex`. + void eraseArgument(unsigned argIndex) { + ::llvm::BitVector argsToErase($_op.getNumArguments()); + argsToErase.set(argIndex); + eraseArguments(argsToErase); + } + + /// Erases the arguments listed in `argIndices`. + void eraseArguments(const ::llvm::BitVector &argIndices) { + ::mlir::Type newType = $_op.getTypeWithoutArgs(argIndices); + ::vast::core::function_interface_impl::eraseFunctionArguments( + $_op, argIndices, newType); + } + + /// Erase a single result at `resultIndex`. + void eraseResult(unsigned resultIndex) { + ::llvm::BitVector resultsToErase($_op.getNumResults()); + resultsToErase.set(resultIndex); + eraseResults(resultsToErase); + } + + /// Erases the results listed in `resultIndices`. + void eraseResults(const ::llvm::BitVector &resultIndices) { + ::mlir::Type newType = $_op.getTypeWithoutResults(resultIndices); + ::vast::core::function_interface_impl::eraseFunctionResults( + $_op, resultIndices, newType); + } + + /// Return the type of this function with the specified arguments and + /// results inserted. This is used to update the function's signature in + /// the `insertArguments` and `insertResults` methods. The arrays must be + /// sorted by increasing index. + ::mlir::Type getTypeWithArgsAndResults( + ::llvm::ArrayRef argIndices, ::mlir::TypeRange argTypes, + ::llvm::ArrayRef resultIndices, ::mlir::TypeRange resultTypes) { + ::llvm::SmallVector<::mlir::Type> argStorage, resultStorage; + ::mlir::TypeRange newArgTypes = insertTypesInto( + $_op.getArgumentTypes(), argIndices, argTypes, argStorage); + ::mlir::TypeRange newResultTypes = insertTypesInto( + $_op.getResultTypes(), resultIndices, resultTypes, resultStorage); + return $_op.cloneTypeWith(newArgTypes, newResultTypes); + } + + /// Return the type of this function without the specified arguments and + /// results. This is used to update the function's signature in the + /// `eraseArguments` and `eraseResults` methods. + ::mlir::Type getTypeWithoutArgsAndResults( + const ::llvm::BitVector &argIndices, const ::llvm::BitVector &resultIndices) { + ::llvm::SmallVector<::mlir::Type> argStorage, resultStorage; + ::mlir::TypeRange newArgTypes = filterTypesOut( + $_op.getArgumentTypes(), argIndices, argStorage); + ::mlir::TypeRange newResultTypes = filterTypesOut( + $_op.getResultTypes(), resultIndices, resultStorage); + return $_op.cloneTypeWith(newArgTypes, newResultTypes); + } + ::mlir::Type getTypeWithoutArgs(const ::llvm::BitVector &argIndices) { + ::llvm::SmallVector<::mlir::Type> argStorage; + ::mlir::TypeRange newArgTypes = filterTypesOut( + $_op.getArgumentTypes(), argIndices, argStorage); + return $_op.cloneTypeWith(newArgTypes, $_op.getResultTypes()); + } + ::mlir::Type getTypeWithoutResults(const ::llvm::BitVector &resultIndices) { + ::llvm::SmallVector<::mlir::Type> resultStorage; + ::mlir::TypeRange newResultTypes = filterTypesOut( + $_op.getResultTypes(), resultIndices, resultStorage); + return $_op.cloneTypeWith($_op.getArgumentTypes(), newResultTypes); + } + + //===------------------------------------------------------------------===// + // Argument Attributes + //===------------------------------------------------------------------===// + + /// Return all of the attributes for the argument at 'index'. + ::llvm::ArrayRef<::mlir::NamedAttribute> getArgAttrs(unsigned index) { + return ::vast::core::function_interface_impl::getArgAttrs($_op, index); + } + + /// Return an ArrayAttr containing all argument attribute dictionaries of + /// this function, or nullptr if no arguments have attributes. + ::mlir::ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); } + + /// Return all argument attributes of this function. + void getAllArgAttrs(::llvm::SmallVectorImpl<::mlir::DictionaryAttr> &result) { + if (::mlir::ArrayAttr argAttrs = getAllArgAttrs()) { + auto argAttrRange = argAttrs.template getAsRange<::mlir::DictionaryAttr>(); + result.append(argAttrRange.begin(), argAttrRange.end()); + } else { + result.append($_op.getNumArguments(), + ::mlir::DictionaryAttr::get(this->getOperation()->getContext())); + } + } + + /// Return the specified attribute, if present, for the argument at 'index', + /// null otherwise. + ::mlir::Attribute getArgAttr(unsigned index, ::mlir::StringAttr name) { + auto argDict = getArgAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + ::mlir::Attribute getArgAttr(unsigned index, ::llvm::StringRef name) { + auto argDict = getArgAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + + template + AttrClass getArgAttrOfType(unsigned index, ::mlir::StringAttr name) { + return ::llvm::dyn_cast_or_null(getArgAttr(index, name)); + } + template + AttrClass getArgAttrOfType(unsigned index, ::llvm::StringRef name) { + return ::llvm::dyn_cast_or_null(getArgAttr(index, name)); + } + + /// Set the attributes held by the argument at 'index'. + void setArgAttrs(unsigned index, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + ::vast::core::function_interface_impl::setArgAttrs($_op, index, attributes); + } + + /// Set the attributes held by the argument at 'index'. `attributes` may be + /// null, in which case any existing argument attributes are removed. + void setArgAttrs(unsigned index, ::mlir::DictionaryAttr attributes) { + ::vast::core::function_interface_impl::setArgAttrs($_op, index, attributes); + } + void setAllArgAttrs(::llvm::ArrayRef<::mlir::DictionaryAttr> attributes) { + assert(attributes.size() == $_op.getNumArguments()); + ::vast::core::function_interface_impl::setAllArgAttrDicts($_op, attributes); + } + void setAllArgAttrs(::llvm::ArrayRef<::mlir::Attribute> attributes) { + assert(attributes.size() == $_op.getNumArguments()); + ::vast::core::function_interface_impl::setAllArgAttrDicts($_op, attributes); + } + void setAllArgAttrs(::mlir::ArrayAttr attributes) { + assert(attributes.size() == $_op.getNumArguments()); + $_op.setArgAttrsAttr(attributes); + } + + /// If the an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void setArgAttr(unsigned index, ::mlir::StringAttr name, ::mlir::Attribute value) { + ::vast::core::function_interface_impl::setArgAttr($_op, index, name, value); + } + void setArgAttr(unsigned index, ::llvm::StringRef name, ::mlir::Attribute value) { + setArgAttr(index, + ::mlir::StringAttr::get(this->getOperation()->getContext(), name), + value); + } + + /// Remove the attribute 'name' from the argument at 'index'. Return the + /// attribute that was erased, or nullptr if there was no attribute with + /// such name. + ::mlir::Attribute removeArgAttr(unsigned index, ::mlir::StringAttr name) { + return ::vast::core::function_interface_impl::removeArgAttr($_op, index, name); + } + ::mlir::Attribute removeArgAttr(unsigned index, ::llvm::StringRef name) { + return removeArgAttr( + index, ::mlir::StringAttr::get(this->getOperation()->getContext(), name)); + } + + //===------------------------------------------------------------------===// + // Result Attributes + //===------------------------------------------------------------------===// + + /// Return all of the attributes for the result at 'index'. + ::llvm::ArrayRef<::mlir::NamedAttribute> getResultAttrs(unsigned index) { + return ::vast::core::function_interface_impl::getResultAttrs($_op, index); + } + + /// Return an ArrayAttr containing all result attribute dictionaries of this + /// function, or nullptr if no result have attributes. + ::mlir::ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); } + + /// Return all result attributes of this function. + void getAllResultAttrs(::llvm::SmallVectorImpl<::mlir::DictionaryAttr> &result) { + if (::mlir::ArrayAttr argAttrs = getAllResultAttrs()) { + auto argAttrRange = argAttrs.template getAsRange<::mlir::DictionaryAttr>(); + result.append(argAttrRange.begin(), argAttrRange.end()); + } else { + result.append($_op.getNumResults(), + ::mlir::DictionaryAttr::get(this->getOperation()->getContext())); + } + } + + /// Return the specified attribute, if present, for the result at 'index', + /// null otherwise. + ::mlir::Attribute getResultAttr(unsigned index, ::mlir::StringAttr name) { + auto argDict = getResultAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + ::mlir::Attribute getResultAttr(unsigned index, ::llvm::StringRef name) { + auto argDict = getResultAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + + template + AttrClass getResultAttrOfType(unsigned index, ::mlir::StringAttr name) { + return ::llvm::dyn_cast_or_null(getResultAttr(index, name)); + } + template + AttrClass getResultAttrOfType(unsigned index, ::llvm::StringRef name) { + return ::llvm::dyn_cast_or_null(getResultAttr(index, name)); + } + + /// Set the attributes held by the result at 'index'. + void setResultAttrs(unsigned index, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + ::vast::core::function_interface_impl::setResultAttrs($_op, index, attributes); + } + + /// Set the attributes held by the result at 'index'. `attributes` may be + /// null, in which case any existing argument attributes are removed. + void setResultAttrs(unsigned index, ::mlir::DictionaryAttr attributes) { + ::vast::core::function_interface_impl::setResultAttrs($_op, index, attributes); + } + void setAllResultAttrs(::llvm::ArrayRef<::mlir::DictionaryAttr> attributes) { + assert(attributes.size() == $_op.getNumResults()); + ::vast::core::function_interface_impl::setAllResultAttrDicts( + $_op, attributes); + } + void setAllResultAttrs(::llvm::ArrayRef<::mlir::Attribute> attributes) { + assert(attributes.size() == $_op.getNumResults()); + ::vast::core::function_interface_impl::setAllResultAttrDicts( + $_op, attributes); + } + void setAllResultAttrs(::mlir::ArrayAttr attributes) { + assert(attributes.size() == $_op.getNumResults()); + $_op.setResAttrsAttr(attributes); + } + + /// If the an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void setResultAttr(unsigned index, ::mlir::StringAttr name, ::mlir::Attribute value) { + ::vast::core::function_interface_impl::setResultAttr($_op, index, name, value); + } + void setResultAttr(unsigned index, ::llvm::StringRef name, ::mlir::Attribute value) { + setResultAttr(index, + ::mlir::StringAttr::get(this->getOperation()->getContext(), name), + value); + } + + /// Remove the attribute 'name' from the result at 'index'. Return the + /// attribute that was erased, or nullptr if there was no attribute with + /// such name. + ::mlir::Attribute removeResultAttr(unsigned index, ::mlir::StringAttr name) { + return ::vast::core::function_interface_impl::removeResultAttr($_op, index, name); + } + + /// Returns the dictionary attribute corresponding to the argument at + /// 'index'. If there are no argument attributes at 'index', a null + /// attribute is returned. + ::mlir::DictionaryAttr getArgAttrDict(unsigned index) { + assert(index < $_op.getNumArguments() && "invalid argument number"); + return ::vast::core::function_interface_impl::getArgAttrDict($_op, index); + } + + /// Returns the dictionary attribute corresponding to the result at 'index'. + /// If there are no result attributes at 'index', a null attribute is + /// returned. + ::mlir::DictionaryAttr getResultAttrDict(unsigned index) { + assert(index < $_op.getNumResults() && "invalid result number"); + return ::vast::core::function_interface_impl::getResultAttrDict($_op, index); + } + }]; + + let verify = "return function_interface_impl::verifyTrait(cast($_op));"; +} + +#endif // VAST_DIALECT_CORE_FUNCTIONINTERFACE_TD_ diff --git a/include/vast/Dialect/HighLevel/HighLevelOps.hpp b/include/vast/Dialect/HighLevel/HighLevelOps.hpp index 85bff6e749..1155d55710 100644 --- a/include/vast/Dialect/HighLevel/HighLevelOps.hpp +++ b/include/vast/Dialect/HighLevel/HighLevelOps.hpp @@ -6,7 +6,6 @@ VAST_RELAX_WARNINGS #include -#include #include VAST_UNRELAX_WARNINGS diff --git a/include/vast/Dialect/HighLevel/HighLevelOps.td b/include/vast/Dialect/HighLevel/HighLevelOps.td index f578f0eb17..c2b315d3b2 100644 --- a/include/vast/Dialect/HighLevel/HighLevelOps.td +++ b/include/vast/Dialect/HighLevel/HighLevelOps.td @@ -80,7 +80,7 @@ def HighLevel_FuncOp } $_state.attributes.append(attrs.begin(), attrs.end()); - mlir::function_interface_impl::addArgAndResultAttrs( + vast::core::function_interface_impl::addArgAndResultAttrs( $_builder, $_state, arg_attrs, res_attrs, getArgAttrsAttrName($_state.name), getResAttrsAttrName($_state.name) ); @@ -386,7 +386,7 @@ def HighLevel_CallOp build($_builder, $_state, mlir::SymbolRefAttr::get($_builder.getContext(), callee), results, operands); }]>, OpBuilder< (ins "FuncOp":$callee, CArg< "mlir::ValueRange", "{}" >:$operands), [{ - build($_builder, $_state, callee.getName(), callee.getFunctionType().getResults(), operands); + build($_builder, $_state, callee.getSymName(), callee.getFunctionType().getResults(), operands); }]> ]; diff --git a/include/vast/Dialect/LowLevel/LowLevelOps.hpp b/include/vast/Dialect/LowLevel/LowLevelOps.hpp index 9cbe3cf11c..23f0dd49b1 100644 --- a/include/vast/Dialect/LowLevel/LowLevelOps.hpp +++ b/include/vast/Dialect/LowLevel/LowLevelOps.hpp @@ -12,7 +12,6 @@ VAST_RELAX_WARNINGS #include #include -#include #include #include VAST_RELAX_WARNINGS diff --git a/include/vast/Interfaces/AST/TypeInterface.td b/include/vast/Interfaces/AST/TypeInterface.td index 15a21cf963..489af64b5c 100644 --- a/include/vast/Interfaces/AST/TypeInterface.td +++ b/include/vast/Interfaces/AST/TypeInterface.td @@ -6,6 +6,4 @@ include "mlir/IR/OpBase.td" include "vast/Interfaces/AST/Common.td" - - #endif // VAST_INTERFACES_AST_TYPE_INTERFACE diff --git a/include/vast/Util/Symbols.hpp b/include/vast/Util/Symbols.hpp index abf3edc2be..8c3a13c6d4 100644 --- a/include/vast/Util/Symbols.hpp +++ b/include/vast/Util/Symbols.hpp @@ -40,7 +40,7 @@ namespace vast::util template< typename Yield > void functions(mlir::Operation *op, Yield &&yield) { - // TODO use mlir::FunctionOpInterface? + // TODO use core::function_op_interface? op->walk([yield = std::forward< Yield >(yield)](hl::FuncOp fn, const mlir::WalkStage &stage) { yield(fn); }); diff --git a/include/vast/Util/TypeUtils.hpp b/include/vast/Util/TypeUtils.hpp index d604277b5e..6b0a21e72b 100644 --- a/include/vast/Util/TypeUtils.hpp +++ b/include/vast/Util/TypeUtils.hpp @@ -2,12 +2,6 @@ #pragma once -#include - -VAST_RELAX_WARNINGS -#include -VAST_UNRELAX_WARNINGS - #include #include @@ -86,7 +80,7 @@ namespace vast bool has_type_somewhere(operation op, auto &&accept) { auto contains_in_function_type = [&] { - if (auto fn = mlir::dyn_cast< mlir::FunctionOpInterface >(op)) { + if (auto fn = mlir::dyn_cast< core::function_op_interface >(op)) { return contains_subtype(fn.getResultTypes(), accept) || contains_subtype(fn.getArgumentTypes(), accept); } diff --git a/lib/vast/CodeGen/DefaultDeclVisitor.cpp b/lib/vast/CodeGen/DefaultDeclVisitor.cpp index f02cf00177..93af7f06bf 100644 --- a/lib/vast/CodeGen/DefaultDeclVisitor.cpp +++ b/lib/vast/CodeGen/DefaultDeclVisitor.cpp @@ -258,7 +258,7 @@ namespace vast::cg { operation default_decl_visitor::VisitParmVarDecl(const clang::ParmVarDecl *decl) { auto blk = bld.getInsertionBlock(); - if (auto fn = mlir::dyn_cast< mlir::FunctionOpInterface >(blk->getParentOp())) { + if (auto fn = mlir::dyn_cast< core::function_op_interface >(blk->getParentOp())) { auto param_index = decl->getFunctionScopeIndex(); return bld.compose< hl::ParmVarDeclOp >() .bind(self.location(decl)) diff --git a/lib/vast/Conversion/ABI/EmitABI.cpp b/lib/vast/Conversion/ABI/EmitABI.cpp index 9bd45ff727..d0e101e9b5 100644 --- a/lib/vast/Conversion/ABI/EmitABI.cpp +++ b/lib/vast/Conversion/ABI/EmitABI.cpp @@ -68,7 +68,7 @@ namespace vast abi_info_map_t< R > out; auto gather = [&](R op, const mlir::WalkStage &) { - auto name = op.getName(); + auto name = mlir::cast< core::func_symbol >(op.getOperation()).getSymbolName(); out.emplace( name.str(), abi::make_x86_64(op, dl) ); return mlir::WalkResult::advance(); @@ -78,7 +78,7 @@ namespace vast return out; } - using func_abi_info_t = abi_info_map_t< mlir::FunctionOpInterface >; + using func_abi_info_t = abi_info_map_t< core::function_op_interface >; // TODO(conv:abi): Remove as we most likely do not need this. struct TypeConverter : conv::tc::mixins< TypeConverter >, @@ -98,7 +98,7 @@ namespace vast struct abi_info_utils { using types_t = std::vector< mlir::Type >; - using abi_info_t = abi::func_info< mlir::FunctionOpInterface >; + using abi_info_t = abi::func_info< core::function_op_interface >; const auto &self() const { return static_cast< const Self & >(*this); } auto &self() { return static_cast< Self & >(*this); } @@ -224,7 +224,7 @@ namespace vast abi::FuncOp make() { auto wrapper = core::convert_function_without_body< abi::FuncOp >( - op, rewriter, conv::abi::abi_func_name_prefix + op.getName().str(), + op, rewriter, conv::abi::abi_func_name_prefix + op.getSymbolName().str(), this->abified_type(op.isVarArg()) ); @@ -399,9 +399,9 @@ namespace vast using values_t = std::vector< mlir::Value >; - auto get_callee() -> mlir::FunctionOpInterface { + auto get_callee() -> core::function_op_interface { auto caller = mlir::dyn_cast< VastCallOpInterface >(*op); - auto callee = mlir::dyn_cast< mlir::FunctionOpInterface >( + auto callee = mlir::dyn_cast< core::function_op_interface >( caller.resolveCallable()); VAST_ASSERT(callee); return callee; @@ -689,10 +689,11 @@ namespace vast }; - template< typename Op > - struct func_type : mlir::OpConversionPattern< Op > + template< typename op_t > + struct func_type : mlir::OpConversionPattern< op_t > { - using Base = mlir::OpConversionPattern< Op >; + using base = mlir::OpConversionPattern< op_t >; + using adaptor_t = typename op_t::Adaptor; TypeConverter &tc; const func_abi_info_t &abi_info_map; @@ -700,19 +701,18 @@ namespace vast func_type(TypeConverter &tc, const func_abi_info_t &abi_info_map, mcontext_t &mctx) - : Base(tc, &mctx), tc(tc), abi_info_map(abi_info_map) + : base(tc, &mctx), tc(tc), abi_info_map(abi_info_map) {} - mlir::LogicalResult matchAndRewrite( - Op op, typename Op::Adaptor ops, - conversion_rewriter &rewriter) const override - { - auto abi_map_it = abi_info_map.find(op.getName().str()); + logical_result matchAndRewrite( + op_t op, adaptor_t ops, conversion_rewriter &rewriter + ) const override { + auto abi_map_it = abi_info_map.find(op.getSymbolName().str()); if (abi_map_it == abi_info_map.end()) return mlir::failure(); const auto &abi_info = abi_map_it->second; - abi_transform< Op >({ op, ops, rewriter }, abi_info).make(); + abi_transform< op_t >({ op, ops, rewriter }, abi_info).make(); rewriter.eraseOp(op); return mlir::success(); } @@ -768,7 +768,7 @@ namespace vast if (!func) return mlir::failure(); - auto name = func.getName(); + auto name = func.getSymbolName(); if (!name.consume_front(conv::abi::abi_func_name_prefix)) return mlir::failure(); @@ -818,8 +818,8 @@ namespace vast auto should_transform = [&](operation op) { // TODO(conv:abi): We should always emit main with a fixed type. - if (auto fn = mlir::dyn_cast< mlir::FunctionOpInterface >(op)) - return fn.getName() == "main"; + if (auto fn = mlir::dyn_cast< core::func_symbol >(op)) + return fn.getSymbolName() == "main"; return true; }; @@ -875,7 +875,7 @@ namespace vast const auto &dl_analysis = this->getAnalysis< mlir::DataLayoutAnalysis >(); auto tc = TypeConverter(dl_analysis.getAtOrAbove(op), mctx); - auto abi_info_map = collect_abi_info< mlir::FunctionOpInterface >( + auto abi_info_map = collect_abi_info< core::function_op_interface >( op, dl_analysis.getAtOrAbove(op)); if (mlir::failed(run(first_phase(tc, abi_info_map)))) diff --git a/lib/vast/Conversion/ABI/LowerABI.cpp b/lib/vast/Conversion/ABI/LowerABI.cpp index 1bb33604e7..d7eaa8b1ff 100644 --- a/lib/vast/Conversion/ABI/LowerABI.cpp +++ b/lib/vast/Conversion/ABI/LowerABI.cpp @@ -543,7 +543,7 @@ namespace vast logical_result matchAndRewrite( op_t op, adaptor_t adaptor, conversion_rewriter &rewriter) const override { - auto name = op.getName(); + auto name = op.getSymbolName(); if (!name.consume_front(conv::abi::abi_func_name_prefix)) return mlir::failure(); diff --git a/lib/vast/Conversion/Generic/LowerValueCategories.cpp b/lib/vast/Conversion/Generic/LowerValueCategories.cpp index ebbae7b9e0..64d321a139 100644 --- a/lib/vast/Conversion/Generic/LowerValueCategories.cpp +++ b/lib/vast/Conversion/Generic/LowerValueCategories.cpp @@ -139,7 +139,7 @@ namespace vast::conv { using base = base_pattern< op_t >; VAST_DEFINE_REWRITE { - auto func_op = mlir::dyn_cast< mlir::FunctionOpInterface >(op.getOperation()); + auto func_op = mlir::dyn_cast< core::function_op_interface >(op.getOperation()); if (!func_op) { return mlir::failure(); } diff --git a/lib/vast/Conversion/ToLLVM/IRsToLLVM.cpp b/lib/vast/Conversion/ToLLVM/IRsToLLVM.cpp index bcfb8a2445..07c6596478 100644 --- a/lib/vast/Conversion/ToLLVM/IRsToLLVM.cpp +++ b/lib/vast/Conversion/ToLLVM/IRsToLLVM.cpp @@ -454,7 +454,7 @@ namespace vast::conv::irstollvm VAST_CHECK(linkage, "Attempting lower function without set linkage {0}", func_op); auto new_func = rewriter.create< llvm_func_op >( func_op.getLoc(), - func_op.getName(), + func_op.getSymbolName(), target_type, core::convert_linkage_to_llvm(linkage.value()), func_op.isVarArg(), LLVM::CConv::C @@ -988,11 +988,11 @@ namespace vast::conv::irstollvm } auto callee = caller.resolveCallable(); - if (!callee && !mlir::isa< mlir::FunctionOpInterface >(callee)) { + if (!callee && !mlir::isa< core::function_op_interface >(callee)) { return logical_result::failure(); } - auto fn = mlir::cast< mlir::FunctionOpInterface >(callee); + auto fn = mlir::cast< core::function_op_interface >(callee); auto rtys = type_converter().convert_types_to_types(fn.getResultTypes()); if (!rtys) { diff --git a/lib/vast/Conversion/ToMem/EvictStaticLocals.cpp b/lib/vast/Conversion/ToMem/EvictStaticLocals.cpp index 84cefaf2d0..ee45dab148 100644 --- a/lib/vast/Conversion/ToMem/EvictStaticLocals.cpp +++ b/lib/vast/Conversion/ToMem/EvictStaticLocals.cpp @@ -28,21 +28,23 @@ namespace vast::conv { logical_result matchAndRewrite( hl::VarDeclOp op, adaptor_t adaptor, conversion_rewriter &rewriter ) const override { - auto parent_fn = op->getParentOfType< mlir::FunctionOpInterface >(); - if (!parent_fn) + auto fn = op->getParentOfType< core::function_op_interface >(); + if (!fn) return mlir::failure(); auto guard = insertion_guard(rewriter); - auto &module_block = parent_fn->getParentOfType< core::ModuleOp >().getBody().front(); + auto &module_block = fn->getParentOfType< core::ModuleOp >().getBody().front(); rewriter.setInsertionPoint(&module_block, module_block.begin()); + auto fn_symbol = mlir::dyn_cast< core::func_symbol >(fn.getOperation()); + auto new_decl = rewriter.create< hl::VarDeclOp >( - op.getLoc(), - op.getType(), - (parent_fn.getName() + "." + op.getSymName()).str(), - op.getStorageClass(), - op.getThreadStorageClass(), - std::optional(core::GlobalLinkageKind::InternalLinkage) + op.getLoc(), + op.getType(), + (fn_symbol.getSymbolName() + "." + op.getSymName()).str(), + op.getStorageClass(), + op.getThreadStorageClass(), + std::optional(core::GlobalLinkageKind::InternalLinkage) ); // Save current context informationinto the op to make sure the information stays valid @@ -58,7 +60,7 @@ namespace vast::conv { static void legalize(conversion_target &trg) { trg.addDynamicallyLegalOp< hl::VarDeclOp >([] (hl::VarDeclOp op) { - return !(op.isStaticLocal() && op->getParentOfType< mlir::FunctionOpInterface >()); + return !(op.isStaticLocal() && op->getParentOfType< core::function_op_interface >()); }); } }; @@ -75,15 +77,16 @@ namespace vast::conv { ) const override { auto var = core::symbol_table::lookup< core::var_symbol >(op, op.getName()); if (auto decl_storage = mlir::dyn_cast< core::DeclStorageInterface>(var)) { - auto parent_fn = op->getParentOfType< mlir::FunctionOpInterface >(); + auto fn = op->getParentOfType< core::function_op_interface >(); - if (!parent_fn || !decl_storage.isStaticLocal()) + if (!fn || !decl_storage.isStaticLocal()) return mlir::failure(); + auto fn_symbol = mlir::dyn_cast< core::func_symbol >(fn.getOperation()); + rewriter.replaceOpWithNewOp< hl::DeclRefOp >( - op, - op.getType(), - (parent_fn.getName() + "." + op.getName()).str() + op, op.getType(), + (fn_symbol.getSymbolName() + "." + op.getName()).str() ); return mlir::success(); } @@ -93,8 +96,8 @@ namespace vast::conv { static void legalize(conversion_target &trg) { trg.addDynamicallyLegalOp< hl::DeclRefOp >([&](hl::DeclRefOp op) { auto var = core::symbol_table::lookup< core::var_symbol >(op, op.getName()); - if (auto decl_storage = mlir::dyn_cast< core::DeclStorageInterface>(var)) { - return !(decl_storage.isStaticLocal() && var->getParentOfType< mlir::FunctionOpInterface >()); + if (auto storage = mlir::dyn_cast< core::DeclStorageInterface >(var)) { + return !(storage.isStaticLocal() && var->getParentOfType< core::function_op_interface >()); } return (bool)var; }); diff --git a/lib/vast/Conversion/ToMem/StripParamLValues.cpp b/lib/vast/Conversion/ToMem/StripParamLValues.cpp index 508c4c3563..c95894afeb 100644 --- a/lib/vast/Conversion/ToMem/StripParamLValues.cpp +++ b/lib/vast/Conversion/ToMem/StripParamLValues.cpp @@ -67,7 +67,7 @@ namespace vast::conv { conversion_target trg(mctx); trg.markUnknownOpDynamicallyLegal([] (operation op) { - if (auto fn = mlir::dyn_cast< mlir::FunctionOpInterface >(op)) { + if (auto fn = mlir::dyn_cast< core::function_op_interface >(op)) { auto fty = mlir::cast< core::FunctionType >(fn.getFunctionType()); return rns::all_of(fty.getInputs(), is_not_lvalue_type); } diff --git a/lib/vast/Dialect/Builtin/Ops.cpp b/lib/vast/Dialect/Builtin/Ops.cpp index 1f97ac95bc..a96547e97d 100644 --- a/lib/vast/Dialect/Builtin/Ops.cpp +++ b/lib/vast/Dialect/Builtin/Ops.cpp @@ -9,7 +9,6 @@ VAST_RELAX_WARNINGS #include #include -#include #include #include diff --git a/lib/vast/Dialect/Core/CMakeLists.txt b/lib/vast/Dialect/Core/CMakeLists.txt index 57edcd697c..cc2e3ac662 100644 --- a/lib/vast/Dialect/Core/CMakeLists.txt +++ b/lib/vast/Dialect/Core/CMakeLists.txt @@ -11,6 +11,7 @@ add_vast_dialect_library(Core LINK_LIBS PRIVATE VASTAliasTypeInterface + VASTFunctionInterface ) add_subdirectory(Interfaces) diff --git a/lib/vast/Dialect/Core/Func.cpp b/lib/vast/Dialect/Core/Func.cpp index 7ce19bc2d3..fb9d570b10 100644 --- a/lib/vast/Dialect/Core/Func.cpp +++ b/lib/vast/Dialect/Core/Func.cpp @@ -17,6 +17,8 @@ VAST_UNRELAX_WARNINGS #include "vast/Dialect/Core/CoreDialect.hpp" #include "vast/Dialect/Core/Linkage.hpp" +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" + #include "vast/Util/Common.hpp" #include "vast/Util/Region.hpp" @@ -55,7 +57,7 @@ namespace vast::core } bool is_variadic = false; - if (mlir::failed(mlir::function_interface_impl::parseFunctionSignature( + if (mlir::failed(vast::core::function_interface_impl::parseFunctionSignature( parser, /*allowVariadic=*/true, arguments, is_variadic, result_types, result_attrs ))) { return mlir::failure(); @@ -79,7 +81,7 @@ namespace vast::core // TODO: Add the attributes to the function arguments. // VAST_ASSERT(result_attrs.size() == result_types.size()); - // return mlir::function_interface_impl::addArgAndResultAttrs( + // return vast::core::function_interface_impl::addArgAndResultAttrs( // builder, state, arguments, result_attrs // ); diff --git a/lib/vast/Dialect/Core/Interfaces/CMakeLists.txt b/lib/vast/Dialect/Core/Interfaces/CMakeLists.txt index 6f91fa4033..b6ed89fd04 100644 --- a/lib/vast/Dialect/Core/Interfaces/CMakeLists.txt +++ b/lib/vast/Dialect/Core/Interfaces/CMakeLists.txt @@ -1,5 +1,7 @@ set(VAST_OPTIONAL_SOURCES DeclStorageInterface.cpp + FunctionInterface.cpp + FunctionImplementation.cpp SymbolInterface.cpp SymbolTableInterface.cpp SymbolRefInterface.cpp @@ -10,6 +12,11 @@ add_vast_interface_library(DeclStorageInterface DeclStorageInterface.cpp ) +add_vast_interface_library(FunctionInterface + FunctionInterface.cpp + FunctionImplementation.cpp +) + add_vast_interface_library(SymbolInterface SymbolInterface.cpp ) diff --git a/lib/vast/Dialect/Core/Interfaces/FunctionImplementation.cpp b/lib/vast/Dialect/Core/Interfaces/FunctionImplementation.cpp new file mode 100644 index 0000000000..5c012c2240 --- /dev/null +++ b/lib/vast/Dialect/Core/Interfaces/FunctionImplementation.cpp @@ -0,0 +1,345 @@ +//===- FunctionImplementation.cpp - Utilities for function-like ops -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "vast/Dialect/Core/Interfaces/FunctionImplementation.hpp" + +VAST_RELAX_WARNINGS +#include +#include + +using namespace vast::core; +using namespace mlir; + +static ParseResult +parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, + bool &isVariadic) { + + // Parse the function arguments. The argument list either has to consistently + // have ssa-id's followed by types, or just be a type list. It isn't ok to + // sometimes have SSA ID's and sometimes not. + isVariadic = false; + + return parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { + // Ellipsis must be at end of the list. + if (isVariadic) + return parser.emitError( + parser.getCurrentLocation(), + "variadic arguments must be in the end of the argument list"); + + // Handle ellipsis as a special case. + if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) { + // This is a variadic designator. + isVariadic = true; + return success(); // Stop parsing arguments. + } + // Parse argument name if present. + OpAsmParser::Argument argument; + auto argPresent = parser.parseOptionalArgument( + argument, /*allowType=*/true, /*allowAttrs=*/true); + if (argPresent.has_value()) { + if (failed(argPresent.value())) + return failure(); // Present but malformed. + + // Reject this if the preceding argument was missing a name. + if (!arguments.empty() && arguments.back().ssaName.name.empty()) + return parser.emitError(argument.ssaName.location, + "expected type instead of SSA identifier"); + + } else { + argument.ssaName.location = parser.getCurrentLocation(); + // Otherwise we just have a type list without SSA names. Reject + // this if the preceding argument had a name. + if (!arguments.empty() && !arguments.back().ssaName.name.empty()) + return parser.emitError(argument.ssaName.location, + "expected SSA identifier"); + + NamedAttrList attrs; + if (parser.parseType(argument.type) || + parser.parseOptionalAttrDict(attrs) || + parser.parseOptionalLocationSpecifier(argument.sourceLoc)) + return failure(); + argument.attrs = attrs.getDictionary(parser.getContext()); + } + arguments.push_back(argument); + return success(); + }); +} + +/// Parse a function result list. +/// +/// function-result-list ::= function-result-list-parens +/// | non-function-type +/// function-result-list-parens ::= `(` `)` +/// | `(` function-result-list-no-parens `)` +/// function-result-list-no-parens ::= function-result (`,` function-result)* +/// function-result ::= type attribute-dict? +/// +static ParseResult +parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { + if (failed(parser.parseOptionalLParen())) { + // We already know that there is no `(`, so parse a type. + // Because there is no `(`, it cannot be a function type. + Type ty; + if (parser.parseType(ty)) + return failure(); + resultTypes.push_back(ty); + resultAttrs.emplace_back(); + return success(); + } + + // Special case for an empty set of parens. + if (succeeded(parser.parseOptionalRParen())) + return success(); + + // Parse individual function results. + if (parser.parseCommaSeparatedList([&]() -> ParseResult { + resultTypes.emplace_back(); + resultAttrs.emplace_back(); + NamedAttrList attrs; + if (parser.parseType(resultTypes.back()) || + parser.parseOptionalAttrDict(attrs)) + return failure(); + resultAttrs.back() = attrs.getDictionary(parser.getContext()); + return success(); + })) + return failure(); + + return parser.parseRParen(); +} + +ParseResult function_interface_impl::parseFunctionSignature( + OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, bool &isVariadic, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { + if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic)) + return failure(); + if (succeeded(parser.parseOptionalArrow())) + return parseFunctionResultList(parser, resultTypes, resultAttrs); + return success(); +} + +void function_interface_impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, ArrayRef argAttrs, + ArrayRef resultAttrs, StringAttr argAttrsName, + StringAttr resAttrsName) { + auto nonEmptyAttrsFn = [](DictionaryAttr attrs) { + return attrs && !attrs.empty(); + }; + // Convert the specified array of dictionary attrs (which may have null + // entries) to an ArrayAttr of dictionaries. + auto getArrayAttr = [&](ArrayRef dictAttrs) { + SmallVector attrs; + for (auto &dict : dictAttrs) + attrs.push_back(dict ? dict : builder.getDictionaryAttr({})); + return builder.getArrayAttr(attrs); + }; + + // Add the attributes to the function arguments. + if (llvm::any_of(argAttrs, nonEmptyAttrsFn)) + result.addAttribute(argAttrsName, getArrayAttr(argAttrs)); + + // Add the attributes to the function results. + if (llvm::any_of(resultAttrs, nonEmptyAttrsFn)) + result.addAttribute(resAttrsName, getArrayAttr(resultAttrs)); +} + +void function_interface_impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, + ArrayRef args, ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName) { + SmallVector argAttrs; + for (const auto &arg : args) + argAttrs.push_back(arg.attrs); + addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName, + resAttrsName); +} + +ParseResult function_interface_impl::parseFunctionOp( + OpAsmParser &parser, OperationState &result, bool allowVariadic, + StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName) { + SmallVector entryArgs; + SmallVector resultAttrs; + SmallVector resultTypes; + auto &builder = parser.getBuilder(); + + // Parse visibility. + (void)impl::parseOptionalVisibilityKeyword(parser, result.attributes); + + // Parse the name as a symbol. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + // Parse the function signature. + SMLoc signatureLocation = parser.getCurrentLocation(); + bool isVariadic = false; + if (parseFunctionSignature(parser, allowVariadic, entryArgs, isVariadic, + resultTypes, resultAttrs)) + return failure(); + + std::string errorMessage; + SmallVector argTypes; + argTypes.reserve(entryArgs.size()); + for (auto &arg : entryArgs) + argTypes.push_back(arg.type); + Type type = funcTypeBuilder(builder, argTypes, resultTypes, + VariadicFlag(isVariadic), errorMessage); + if (!type) { + return parser.emitError(signatureLocation) + << "failed to construct function type" + << (errorMessage.empty() ? "" : ": ") << errorMessage; + } + result.addAttribute(typeAttrName, TypeAttr::get(type)); + + // If function attributes are present, parse them. + NamedAttrList parsedAttributes; + SMLoc attributeDictLocation = parser.getCurrentLocation(); + if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) + return failure(); + + // Disallow attributes that are inferred from elsewhere in the attribute + // dictionary. + for (StringRef disallowed : + {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), + typeAttrName.getValue()}) { + if (parsedAttributes.get(disallowed)) + return parser.emitError(attributeDictLocation, "'") + << disallowed + << "' is an inferred attribute and should not be specified in the " + "explicit attribute dictionary"; + } + result.attributes.append(parsedAttributes); + + // Add the attributes to the function arguments. + assert(resultAttrs.size() == resultTypes.size()); + addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName, + resAttrsName); + + // Parse the optional function body. The printer will not print the body if + // its empty, so disallow parsing of empty body in the parser. + auto *body = result.addRegion(); + SMLoc loc = parser.getCurrentLocation(); + OptionalParseResult parseResult = + parser.parseOptionalRegion(*body, entryArgs, + /*enableNameShadowing=*/false); + if (parseResult.has_value()) { + if (failed(*parseResult)) + return failure(); + // Function body was parsed, make sure its not empty. + if (body->empty()) + return parser.emitError(loc, "expected non-empty function body"); + } + return success(); +} + +/// Print a function result list. The provided `attrs` must either be null, or +/// contain a set of DictionaryAttrs of the same arity as `types`. +static void printFunctionResultList(OpAsmPrinter &p, ArrayRef types, + ArrayAttr attrs) { + assert(!types.empty() && "Should not be called for empty result list."); + assert((!attrs || attrs.size() == types.size()) && + "Invalid number of attributes."); + + auto &os = p.getStream(); + bool needsParens = types.size() > 1 || llvm::isa(types[0]) || + (attrs && !llvm::cast(attrs[0]).empty()); + if (needsParens) + os << '('; + llvm::interleaveComma(llvm::seq(0, types.size()), os, [&](size_t i) { + p.printType(types[i]); + if (attrs) + p.printOptionalAttrDict(llvm::cast(attrs[i]).getValue()); + }); + if (needsParens) + os << ')'; +} + +void function_interface_impl::printFunctionSignature( + OpAsmPrinter &p, FunctionOpInterface op, ArrayRef argTypes, + bool isVariadic, ArrayRef resultTypes) { + Region &body = op->getRegion(0); + bool isExternal = body.empty(); + + p << '('; + ArrayAttr argAttrs = op.getArgAttrsAttr(); + for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { + if (i > 0) + p << ", "; + + if (!isExternal) { + ArrayRef attrs; + if (argAttrs) + attrs = llvm::cast(argAttrs[i]).getValue(); + p.printRegionArgument(body.getArgument(i), attrs); + } else { + p.printType(argTypes[i]); + if (argAttrs) + p.printOptionalAttrDict( + llvm::cast(argAttrs[i]).getValue()); + } + } + + if (isVariadic) { + if (!argTypes.empty()) + p << ", "; + p << "..."; + } + + p << ')'; + + if (!resultTypes.empty()) { + p.getStream() << " -> "; + auto resultAttrs = op.getResAttrsAttr(); + printFunctionResultList(p, resultTypes, resultAttrs); + } +} + +void function_interface_impl::printFunctionAttributes( + OpAsmPrinter &p, Operation *op, ArrayRef elided) { + // Print out function attributes, if present. + SmallVector ignoredAttrs = {SymbolTable::getSymbolAttrName()}; + ignoredAttrs.append(elided.begin(), elided.end()); + + p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); +} + +void function_interface_impl::printFunctionOp( + OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, + StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) { + // Print the operation and the function name. + auto funcName = + op->getAttrOfType(SymbolTable::getSymbolAttrName()) + .getValue(); + p << ' '; + + StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); + if (auto visibility = op->getAttrOfType(visibilityAttrName)) + p << visibility.getValue() << ' '; + p.printSymbolName(funcName); + + ArrayRef argTypes = op.getArgumentTypes(); + ArrayRef resultTypes = op.getResultTypes(); + printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); + printFunctionAttributes( + p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName}); + // Print the body if this is not an external function. + Region &body = op->getRegion(0); + if (!body.empty()) { + p << ' '; + p.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + } +} + +VAST_UNRELAX_WARNINGS diff --git a/lib/vast/Dialect/Core/Interfaces/FunctionInterface.cpp b/lib/vast/Dialect/Core/Interfaces/FunctionInterface.cpp new file mode 100644 index 0000000000..5fc5d058a9 --- /dev/null +++ b/lib/vast/Dialect/Core/Interfaces/FunctionInterface.cpp @@ -0,0 +1,366 @@ +//===- FunctionSupport.cpp - Utility types for function-like ops ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" + +using namespace vast::core; +using namespace mlir; + +VAST_RELAX_WARNINGS + +//===----------------------------------------------------------------------===// +// Tablegen Interface Definitions +//===----------------------------------------------------------------------===// + +#include "vast/Dialect/Core/Interfaces/FunctionInterface.cpp.inc" + +//===----------------------------------------------------------------------===// +// Function Arguments and Results. +//===----------------------------------------------------------------------===// + +static bool isEmptyAttrDict(Attribute attr) { + return llvm::cast(attr).empty(); +} + +DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op, + unsigned index) { + ArrayAttr attrs = op.getArgAttrsAttr(); + DictionaryAttr argAttrs = + attrs ? llvm::cast(attrs[index]) : DictionaryAttr(); + return argAttrs; +} + +DictionaryAttr +function_interface_impl::getResultAttrDict(FunctionOpInterface op, + unsigned index) { + ArrayAttr attrs = op.getResAttrsAttr(); + DictionaryAttr resAttrs = + attrs ? llvm::cast(attrs[index]) : DictionaryAttr(); + return resAttrs; +} + +ArrayRef +function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) { + auto argDict = getArgAttrDict(op, index); + return argDict ? argDict.getValue() : std::nullopt; +} + +ArrayRef +function_interface_impl::getResultAttrs(FunctionOpInterface op, + unsigned index) { + auto resultDict = getResultAttrDict(op, index); + return resultDict ? resultDict.getValue() : std::nullopt; +} + +/// Get either the argument or result attributes array. +template +static ArrayAttr getArgResAttrs(FunctionOpInterface op) { + if constexpr (isArg) + return op.getArgAttrsAttr(); + else + return op.getResAttrsAttr(); +} + +/// Set either the argument or result attributes array. +template +static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) { + if constexpr (isArg) + op.setArgAttrsAttr(attrs); + else + op.setResAttrsAttr(attrs); +} + +/// Erase either the argument or result attributes array. +template +static void removeArgResAttrs(FunctionOpInterface op) { + if constexpr (isArg) + op.removeArgAttrsAttr(); + else + op.removeResAttrsAttr(); +} + +/// Set all of the argument or result attribute dictionaries for a function. +template +static void setAllArgResAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + if (llvm::all_of(attrs, isEmptyAttrDict)) + removeArgResAttrs(op); + else + setArgResAttrs(op, ArrayAttr::get(op->getContext(), attrs)); +} + +void function_interface_impl::setAllArgAttrDicts( + FunctionOpInterface op, ArrayRef attrs) { + setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} + +void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); +} + +void function_interface_impl::setAllResultAttrDicts( + FunctionOpInterface op, ArrayRef attrs) { + setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} + +void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); +} + +/// Update the given index into an argument or result attribute dictionary. +template +static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices, + unsigned index, DictionaryAttr attrs) { + ArrayAttr allAttrs = getArgResAttrs(op); + if (!allAttrs) { + if (attrs.empty()) + return; + + // If this attribute is not empty, we need to create a new attribute array. + SmallVector newAttrs(numTotalIndices, + DictionaryAttr::get(op->getContext())); + newAttrs[index] = attrs; + setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); + return; + } + // Check to see if the attribute is different from what we already have. + if (allAttrs[index] == attrs) + return; + + // If it is, check to see if the attribute array would now contain only empty + // dictionaries. + ArrayRef rawAttrArray = allAttrs.getValue(); + if (attrs.empty() && + llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) && + llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) + return removeArgResAttrs(op); + + // Otherwise, create a new attribute array with the updated dictionary. + SmallVector newAttrs(rawAttrArray.begin(), rawAttrArray.end()); + newAttrs[index] = attrs; + setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); +} + +void function_interface_impl::setArgAttrs(FunctionOpInterface op, + unsigned index, + ArrayRef attributes) { + assert(index < op.getNumArguments() && "invalid argument number"); + return setArgResAttrDict( + op, op.getNumArguments(), index, + DictionaryAttr::get(op->getContext(), attributes)); +} + +void function_interface_impl::setArgAttrs(FunctionOpInterface op, + unsigned index, + DictionaryAttr attributes) { + return setArgResAttrDict( + op, op.getNumArguments(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); +} + +void function_interface_impl::setResultAttrs( + FunctionOpInterface op, unsigned index, + ArrayRef attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return setArgResAttrDict( + op, op.getNumResults(), index, + DictionaryAttr::get(op->getContext(), attributes)); +} + +void function_interface_impl::setResultAttrs(FunctionOpInterface op, + unsigned index, + DictionaryAttr attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return setArgResAttrDict( + op, op.getNumResults(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); +} + +void function_interface_impl::insertFunctionArguments( + FunctionOpInterface op, ArrayRef argIndices, TypeRange argTypes, + ArrayRef argAttrs, ArrayRef argLocs, + unsigned originalNumArgs, Type newType) { + assert(argIndices.size() == argTypes.size()); + assert(argIndices.size() == argAttrs.size() || argAttrs.empty()); + assert(argIndices.size() == argLocs.size()); + if (argIndices.empty()) + return; + + // There are 3 things that need to be updated: + // - Function type. + // - Arg attrs. + // - Block arguments of entry block. + Block &entry = op->getRegion(0).front(); + + // Update the argument attributes of the function. + ArrayAttr oldArgAttrs = op.getArgAttrsAttr(); + if (oldArgAttrs || !argAttrs.empty()) { + SmallVector newArgAttrs; + newArgAttrs.reserve(originalNumArgs + argIndices.size()); + unsigned oldIdx = 0; + auto migrate = [&](unsigned untilIdx) { + if (!oldArgAttrs) { + newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx); + } else { + auto oldArgAttrRange = oldArgAttrs.getAsRange(); + newArgAttrs.append(oldArgAttrRange.begin() + oldIdx, + oldArgAttrRange.begin() + untilIdx); + } + oldIdx = untilIdx; + }; + for (unsigned i = 0, e = argIndices.size(); i < e; ++i) { + migrate(argIndices[i]); + newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]); + } + migrate(originalNumArgs); + setAllArgAttrDicts(op, newArgAttrs); + } + + // Update the function type and any entry block arguments. + op.setFunctionTypeAttr(TypeAttr::get(newType)); + for (unsigned i = 0, e = argIndices.size(); i < e; ++i) + entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]); +} + +void function_interface_impl::insertFunctionResults( + FunctionOpInterface op, ArrayRef resultIndices, + TypeRange resultTypes, ArrayRef resultAttrs, + unsigned originalNumResults, Type newType) { + assert(resultIndices.size() == resultTypes.size()); + assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty()); + if (resultIndices.empty()) + return; + + // There are 2 things that need to be updated: + // - Function type. + // - Result attrs. + + // Update the result attributes of the function. + ArrayAttr oldResultAttrs = op.getResAttrsAttr(); + if (oldResultAttrs || !resultAttrs.empty()) { + SmallVector newResultAttrs; + newResultAttrs.reserve(originalNumResults + resultIndices.size()); + unsigned oldIdx = 0; + auto migrate = [&](unsigned untilIdx) { + if (!oldResultAttrs) { + newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx); + } else { + auto oldResultAttrsRange = oldResultAttrs.getAsRange(); + newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx, + oldResultAttrsRange.begin() + untilIdx); + } + oldIdx = untilIdx; + }; + for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) { + migrate(resultIndices[i]); + newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{} + : resultAttrs[i]); + } + migrate(originalNumResults); + setAllResultAttrDicts(op, newResultAttrs); + } + + // Update the function type. + op.setFunctionTypeAttr(TypeAttr::get(newType)); +} + +void function_interface_impl::eraseFunctionArguments( + FunctionOpInterface op, const BitVector &argIndices, Type newType) { + // There are 3 things that need to be updated: + // - Function type. + // - Arg attrs. + // - Block arguments of entry block. + Block &entry = op->getRegion(0).front(); + + // Update the argument attributes of the function. + if (ArrayAttr argAttrs = op.getArgAttrsAttr()) { + SmallVector newArgAttrs; + newArgAttrs.reserve(argAttrs.size()); + for (unsigned i = 0, e = argIndices.size(); i < e; ++i) + if (!argIndices[i]) + newArgAttrs.emplace_back(llvm::cast(argAttrs[i])); + setAllArgAttrDicts(op, newArgAttrs); + } + + // Update the function type and any entry block arguments. + op.setFunctionTypeAttr(TypeAttr::get(newType)); + entry.eraseArguments(argIndices); +} + +void function_interface_impl::eraseFunctionResults( + FunctionOpInterface op, const BitVector &resultIndices, Type newType) { + // There are 2 things that need to be updated: + // - Function type. + // - Result attrs. + + // Update the result attributes of the function. + if (ArrayAttr resAttrs = op.getResAttrsAttr()) { + SmallVector newResultAttrs; + newResultAttrs.reserve(resAttrs.size()); + for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) + if (!resultIndices[i]) + newResultAttrs.emplace_back(llvm::cast(resAttrs[i])); + setAllResultAttrDicts(op, newResultAttrs); + } + + // Update the function type. + op.setFunctionTypeAttr(TypeAttr::get(newType)); +} + +//===----------------------------------------------------------------------===// +// Function type signature. +//===----------------------------------------------------------------------===// + +void function_interface_impl::setFunctionType(FunctionOpInterface op, + Type newType) { + unsigned oldNumArgs = op.getNumArguments(); + unsigned oldNumResults = op.getNumResults(); + op.setFunctionTypeAttr(TypeAttr::get(newType)); + unsigned newNumArgs = op.getNumArguments(); + unsigned newNumResults = op.getNumResults(); + + // Functor used to update the argument and result attributes of the function. + auto emptyDict = DictionaryAttr::get(op.getContext()); + auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) { + constexpr bool isArgVal = std::is_same_v; + + if (oldCount == newCount) + return; + // The new type has no arguments/results, just drop the attribute. + if (newCount == 0) + return removeArgResAttrs(op); + ArrayAttr attrs = getArgResAttrs(op); + if (!attrs) + return; + + // The new type has less arguments/results, take the first N attributes. + if (newCount < oldCount) + return setAllArgResAttrDicts( + op, attrs.getValue().take_front(newCount)); + + // Otherwise, the new type has more arguments/results. Initialize the new + // arguments/results with empty dictionary attributes. + SmallVector newAttrs(attrs.begin(), attrs.end()); + newAttrs.resize(newCount, emptyDict); + setAllArgResAttrDicts(op, newAttrs); + }; + + // Update the argument and result attributes. + updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs); + updateAttrFn(std::false_type{}, oldNumResults, newNumResults); +} + +VAST_UNRELAX_WARNINGS diff --git a/lib/vast/Dialect/HighLevel/HighLevelOps.cpp b/lib/vast/Dialect/HighLevel/HighLevelOps.cpp index e2dca19916..d731cf3e20 100644 --- a/lib/vast/Dialect/HighLevel/HighLevelOps.cpp +++ b/lib/vast/Dialect/HighLevel/HighLevelOps.cpp @@ -16,7 +16,6 @@ VAST_RELAX_WARNINGS #include #include -#include VAST_UNRELAX_WARNINGS #include "vast/Dialect/HighLevel/HighLevelAttributes.hpp" diff --git a/lib/vast/Dialect/HighLevel/Transforms/ExportFnInfo.cpp b/lib/vast/Dialect/HighLevel/Transforms/ExportFnInfo.cpp index 05e0f5a778..91fdefdfe6 100644 --- a/lib/vast/Dialect/HighLevel/Transforms/ExportFnInfo.cpp +++ b/lib/vast/Dialect/HighLevel/Transforms/ExportFnInfo.cpp @@ -214,7 +214,7 @@ namespace vast::hl llvm::json::Object top; - // TODO use FunctionOpInterface instead of specific operation + // TODO use core::function_op_interface instead of specific operation util::functions(mod, [&](FuncOp fn) { const auto &dl_analysis = this->getAnalysis< mlir::DataLayoutAnalysis >(); const auto &dl = dl_analysis.getAtOrAbove(mod); @@ -233,7 +233,7 @@ namespace vast::hl current["rets"] = std::move(rets); current["args"] = std::move(args); - top[fn.getName().str()] = std::move(current); + top[fn.getSymName().str()] = std::move(current); }); auto value = llvm::formatv("{0:2}", llvm::json::Value(std::move(top))); diff --git a/lib/vast/Dialect/LowLevel/LowLevelOps.cpp b/lib/vast/Dialect/LowLevel/LowLevelOps.cpp index 74304a1cb0..d8047aa4a1 100644 --- a/lib/vast/Dialect/LowLevel/LowLevelOps.cpp +++ b/lib/vast/Dialect/LowLevel/LowLevelOps.cpp @@ -9,7 +9,6 @@ VAST_RELAX_WARNINGS #include -#include #include VAST_UNRELAX_WARNINGS