diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index f45522fe64..4633f24926 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -2,7 +2,7 @@ # [Choice] Debian / Ubuntu version (use Debian 11, Ubuntu 18.04/22.04 on local arm64/Apple Silicon): debian-11, debian-10, ubuntu-22.04, ubuntu-20.04, ubuntu-18.04 ARG IMAGE="ubuntu-22.04" -ARG LLVM_VERSION="18" +ARG LLVM_VERSION="19" FROM mcr.microsoft.com/vscode/devcontainers/cpp:${IMAGE} ARG LLVM_VERSION diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index f6bf6959e2..7a54f0cea6 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -24,7 +24,7 @@ "python": "os-provided" }, "containerEnv": { - "CMAKE_PREFIX_PATH": "/usr/lib/llvm-18/lib/cmake/mlir/;/usr/lib/llvm-18/lib/cmake/clang/", + "CMAKE_PREFIX_PATH": "/usr/lib/llvm-19/lib/cmake/mlir/;/usr/lib/llvm-19/lib/cmake/clang/", "LLVM_EXTERNAL_LIT": "/usr/local/bin/lit" } diff --git a/.devcontainer/install-llvm.sh b/.devcontainer/install-llvm.sh index 48e3e1b739..6cc7bfa830 100644 --- a/.devcontainer/install-llvm.sh +++ b/.devcontainer/install-llvm.sh @@ -11,7 +11,7 @@ usage() { exit 1; } -CURRENT_LLVM_STABLE=18 +CURRENT_LLVM_STABLE=19 BASE_URL="http://apt.llvm.org" # Check for required tools @@ -118,6 +118,7 @@ LLVM_VERSION_PATTERNS[15]="-15" LLVM_VERSION_PATTERNS[16]="-16" LLVM_VERSION_PATTERNS[17]="-17" LLVM_VERSION_PATTERNS[18]="-18" +LLVM_VERSION_PATTERNS[19]="-19" if [ ! ${LLVM_VERSION_PATTERNS[$LLVM_VERSION]+_} ]; then echo "This script does not support LLVM version $LLVM_VERSION" diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d187f6943b..6017bc9719 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -24,10 +24,10 @@ jobs: build: strategy: matrix: - llvm-version: [18] + llvm-version: [19] image-version: [22.04] build-type: [Release, Debug] - sanitizers: [ON, OFF] + sanitizers: [OFF] features: ["nosarif", "sarif"] runs-on: ubuntu-${{ matrix.image-version }} @@ -57,4 +57,3 @@ jobs: - name: Test ${{ matrix.build-type }} with sanitizers set ${{ matrix.sanitizers }} run: ctest --preset ci-${{ matrix.features }} --build-config ${{ matrix.build-type }} - diff --git a/.github/workflows/devcontainer.yml b/.github/workflows/devcontainer.yml index b140c7705f..7819eee45f 100644 --- a/.github/workflows/devcontainer.yml +++ b/.github/workflows/devcontainer.yml @@ -14,7 +14,7 @@ jobs: build: strategy: matrix: - llvm-version: [18] + llvm-version: [19] image-version: [22.04] runs-on: ubuntu-22.04 timeout-minutes: 45 diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index 9256dab67f..fb88810678 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -16,7 +16,7 @@ jobs: cpp-linter: strategy: matrix: - llvm-version: [18] + llvm-version: [19] image-version: [22.04] runs-on: ubuntu-${{ matrix.image-version }} diff --git a/.github/workflows/prerelease.yml b/.github/workflows/prerelease.yml index 1efea86bdb..a5f254a348 100644 --- a/.github/workflows/prerelease.yml +++ b/.github/workflows/prerelease.yml @@ -25,7 +25,7 @@ jobs: build: strategy: matrix: - llvm-version: [18] + llvm-version: [19] image-version: [22.04] name: "Pre Release" @@ -89,7 +89,7 @@ jobs: needs: build strategy: matrix: - llvm-version: [18] + llvm-version: [19] image-version: [22.04] vast-target: ['hl', 'llvm', 'bin'] disable-unsup: ['ON','OFF'] @@ -217,7 +217,7 @@ jobs: needs: build strategy: matrix: - llvm-version: [18] + llvm-version: [19] image-version: [22.04] vast-target: ['hl'] disable-unsup: [true, false] @@ -321,7 +321,7 @@ jobs: needs: build strategy: matrix: - llvm-version: [18] + llvm-version: [19] image-version: [22.04] vast-target: ['hl'] disable-unsup: [true, false] @@ -440,7 +440,7 @@ jobs: build_doc: strategy: matrix: - llvm-version: [18] + llvm-version: [19] image-version: [22.04] name: "Build VAST doc" runs-on: ubuntu-${{ matrix.image-version }} diff --git a/CMakeLists.txt b/CMakeLists.txt index fedde7a808..326e96d37d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,7 +99,7 @@ FindAndSelectClangCompiler() # # LLVM & MLIR & Clang # -find_package(LLVM 18.1 REQUIRED CONFIG) +find_package(LLVM 19.1 REQUIRED CONFIG) message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") # diff --git a/CMakePresets.json b/CMakePresets.json index daa15d5e7b..e84040c49a 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -54,13 +54,13 @@ "displayName": "Configure VAST for Compiler Explorer", "inherits": "base", "environment": { - "CMAKE_PREFIX_PATH": "/usr/lib/llvm-18" + "CMAKE_PREFIX_PATH": "/usr/lib/llvm-19" }, "cacheVariables": { "CMAKE_CONFIGURATION_TYPES": "Release", "VAST_ENABLE_TESTING": "OFF", - "CMAKE_C_COMPILER": "/usr/bin/clang-18", - "CMAKE_CXX_COMPILER": "/usr/bin/clang++-18", + "CMAKE_C_COMPILER": "/usr/bin/clang-19", + "CMAKE_CXX_COMPILER": "/usr/bin/clang++-19", "CMAKE_INSTALL_PREFIX": "$env{STAGING_DIR}", "CMAKE_INSTALL_RPATH": "$env{ORIGIN}/../lib" } diff --git a/docs/GettingStarted/build.md b/docs/GettingStarted/build.md index 754f23f036..8e80bfbeae 100644 --- a/docs/GettingStarted/build.md +++ b/docs/GettingStarted/build.md @@ -3,11 +3,11 @@ Currently, it is necessary to use `clang` (due to `gcc` bug) to build VAST. On Linux it is also necessary to use `lld` at the moment. -VAST uses `llvm-18` which can be obtained from the [repository](https://apt.llvm.org/) provided by LLVM. +VAST uses `llvm-19` which can be obtained from the [repository](https://apt.llvm.org/) provided by LLVM. Before building (for Ubuntu) get all the necessary dependencies by running ``` -apt-get install build-essential cmake ninja-build libstdc++-12-dev llvm-18 libmlir-18 libmlir-18-dev mlir-18-tools libclang-18-dev +apt-get install build-essential cmake ninja-build libstdc++-12-dev llvm-19 libmlir-19 libmlir-19-dev mlir-19-tools libclang-19-dev ``` or an equivalent command for your operating system of choice. diff --git a/include/vast/Conversion/TypeConverters/HLToStd.hpp b/include/vast/Conversion/TypeConverters/HLToStd.hpp index 9c636997af..9e047ffca0 100644 --- a/include/vast/Conversion/TypeConverters/HLToStd.hpp +++ b/include/vast/Conversion/TypeConverters/HLToStd.hpp @@ -89,7 +89,7 @@ namespace vast::conv::tc { protected: auto convert_pointer_element_type() { return [&](auto t) -> maybe_type_t { - if (t.template isa< hl::VoidType >()) { + if (mlir::isa< hl::VoidType >(t)) { auto sign = mlir::IntegerType::SignednessSemantics::Signless; return underlying().int_type(8u, sign); } diff --git a/include/vast/Conversion/TypeConverters/LLVMTypeConverter.hpp b/include/vast/Conversion/TypeConverters/LLVMTypeConverter.hpp index b64ac46ae7..400b1e6abf 100644 --- a/include/vast/Conversion/TypeConverters/LLVMTypeConverter.hpp +++ b/include/vast/Conversion/TypeConverters/LLVMTypeConverter.hpp @@ -99,7 +99,7 @@ namespace vast::conv::tc { auto make_ptr_type() { return [&](auto t) { - VAST_ASSERT(!t.template isa< mlir::NoneType >()); + VAST_ASSERT(!mlir::isa< mlir::NoneType >(t)); return LLVM::LLVMPointerType::get(&this->getContext(), 0); }; } @@ -139,9 +139,9 @@ namespace vast::conv::tc { maybe_type_t convert_memref_type(mlir::UnrankedMemRefType t) { return {}; } maybe_signature_conversion_t - get_conversion_signature(mlir::FunctionOpInterface fn, bool variadic) { + get_conversion_signature(core::function_op_interface fn, bool variadic) { signature_conversion_t conversion(fn.getNumArguments()); - auto fn_type = fn.getFunctionType().dyn_cast< core::FunctionType >(); + auto fn_type = mlir::dyn_cast< core::FunctionType >(fn.getFunctionType()); VAST_ASSERT(fn_type); for (auto arg : llvm::enumerate(fn_type.getInputs())) { auto cty = convert_arg_t(arg.value()); @@ -186,14 +186,14 @@ namespace vast::conv::tc { } maybe_types_t convert_arg_t(mlir::Type t) { - if (auto lvalue = t.dyn_cast< hl::LValueType >()) { + if (auto lvalue = mlir::dyn_cast< hl::LValueType >(t)) { return this->convert_type_to_types(lvalue.getElementType()); } return this->convert_type_to_types(t); } maybe_types_t convert_ret_t(mlir::Type t) { - if (auto lvalue = t.dyn_cast< hl::LValueType >()) { + if (auto lvalue = mlir::dyn_cast< hl::LValueType >(t)) { return this->convert_type_to_types(lvalue.getElementType()); } return this->convert_type_to_types(t); diff --git a/include/vast/Conversion/TypeConverters/TypeConverter.hpp b/include/vast/Conversion/TypeConverters/TypeConverter.hpp index 2ce6f60b64..a569e19f26 100644 --- a/include/vast/Conversion/TypeConverters/TypeConverter.hpp +++ b/include/vast/Conversion/TypeConverters/TypeConverter.hpp @@ -13,6 +13,8 @@ VAST_UNRELAX_WARNINGS #include #include "vast/Dialect/Core/CoreTypes.hpp" +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" + #include "vast/Dialect/HighLevel/HighLevelDialect.hpp" #include "vast/Util/Common.hpp" diff --git a/include/vast/Conversion/TypeConverters/TypeConvertingPattern.hpp b/include/vast/Conversion/TypeConverters/TypeConvertingPattern.hpp index 1b167f379e..f21443834b 100644 --- a/include/vast/Conversion/TypeConverters/TypeConvertingPattern.hpp +++ b/include/vast/Conversion/TypeConverters/TypeConvertingPattern.hpp @@ -34,7 +34,7 @@ namespace vast::conv::tc { // TODO(conv:tc): This should probably be some interface instead, since // we are only updating the root? logical_result replace( - mlir::FunctionOpInterface fn, + core::function_op_interface fn, auto &rewriter ) const { auto old_type = fn.getFunctionType(); @@ -115,7 +115,7 @@ namespace vast::conv::tc { operation op, mlir::ArrayRef< mlir::Value >, conversion_rewriter &rewriter ) const override { - if (auto func_op = mlir::dyn_cast< mlir::FunctionOpInterface >(op)) + if (auto func_op = mlir::dyn_cast< core::function_op_interface >(op)) return this->replace(func_op, rewriter); return this->replace(op, rewriter); } diff --git a/include/vast/Dialect/ABI/ABIOps.hpp b/include/vast/Dialect/ABI/ABIOps.hpp index 83f3f9202a..a9c6025d82 100644 --- a/include/vast/Dialect/ABI/ABIOps.hpp +++ b/include/vast/Dialect/ABI/ABIOps.hpp @@ -21,6 +21,8 @@ VAST_RELAX_WARNINGS #include "vast/Dialect/Core/CoreTypes.hpp" #include "vast/Dialect/Core/CoreAttributes.hpp" #include "vast/Dialect/Core/Func.hpp" + +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" #include "vast/Dialect/Core/Interfaces/SymbolInterface.hpp" #include "vast/Dialect/Core/Interfaces/SymbolTableInterface.hpp" diff --git a/include/vast/Dialect/Builtin/Ops.hpp b/include/vast/Dialect/Builtin/Ops.hpp index 914ddafd7c..7513541fcf 100644 --- a/include/vast/Dialect/Builtin/Ops.hpp +++ b/include/vast/Dialect/Builtin/Ops.hpp @@ -5,7 +5,6 @@ #include "vast/Util/Warnings.hpp" VAST_RELAX_WARNINGS -#include #include #include VAST_UNRELAX_WARNINGS diff --git a/include/vast/Dialect/Core/CoreLazy.td b/include/vast/Dialect/Core/CoreLazy.td index 9a72fc783c..3c4bf9d528 100644 --- a/include/vast/Dialect/Core/CoreLazy.td +++ b/include/vast/Dialect/Core/CoreLazy.td @@ -7,7 +7,6 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" -include "mlir/Interfaces/FunctionInterfaces.td" include "vast/Dialect/Core/Interfaces/SymbolInterface.td" class Core_LazyEval< string mnemonic, list < Trait > traits = [] > diff --git a/include/vast/Dialect/Core/CoreOps.hpp b/include/vast/Dialect/Core/CoreOps.hpp index 2033b5a399..3af048280a 100644 --- a/include/vast/Dialect/Core/CoreOps.hpp +++ b/include/vast/Dialect/Core/CoreOps.hpp @@ -11,7 +11,6 @@ #include "vast/Util/Common.hpp" #include "vast/Util/TypeList.hpp" -#include #include #include #include @@ -24,4 +23,4 @@ namespace vast::core { using module = core::ModuleOp; -} // namespace vast::core \ No newline at end of file +} // namespace vast::core diff --git a/include/vast/Dialect/Core/Func.hpp b/include/vast/Dialect/Core/Func.hpp index 70c78a9009..8ec8153eef 100644 --- a/include/vast/Dialect/Core/Func.hpp +++ b/include/vast/Dialect/Core/Func.hpp @@ -13,6 +13,9 @@ VAST_UNRELAX_WARNINGS #include "vast/Dialect/Core/Linkage.hpp" #include "vast/Dialect/Core/CoreTypes.hpp" +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" +#include "vast/Dialect/Core/Interfaces/FunctionImplementation.hpp" + namespace vast::core { llvm::StringRef getLinkageAttrNameString(); @@ -77,11 +80,11 @@ namespace vast::core { } auto fty = op.getFunctionType(); - mlir::function_interface_impl::printFunctionSignature( + vast::core::function_interface_impl::printFunctionSignature( printer, op, fty.getInputs(), fty.isVarArg(), fty.getResults() ); - mlir::function_interface_impl::printFunctionAttributes( + vast::core::function_interface_impl::printFunctionAttributes( printer, op, { getLinkageAttrNameString(), op.getFunctionTypeAttrName() } ); @@ -160,7 +163,7 @@ namespace vast::core { template< typename DstFuncOp > logical_result convert_and_replace_function(auto src, auto &rewriter) { - return convert_and_replace_function< DstFuncOp >(src, rewriter, src.getName()); + return convert_and_replace_function< DstFuncOp >(src, rewriter, src.getSymbolName()); } } // namespace vast::core diff --git a/include/vast/Dialect/Core/Func.td b/include/vast/Dialect/Core/Func.td index ed7cbe2925..63ea03e124 100644 --- a/include/vast/Dialect/Core/Func.td +++ b/include/vast/Dialect/Core/Func.td @@ -6,8 +6,8 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/CallInterfaces.td" -include "mlir/Interfaces/FunctionInterfaces.td" include "vast/Dialect/Core/Interfaces/SymbolInterface.td" +include "vast/Dialect/Core/Interfaces/FunctionInterface.td" include "vast/Dialect/Core/CoreTraits.td" include "vast/Dialect/Core/LinkageHelper.td" @@ -22,9 +22,8 @@ class Core_FuncBaseOp< Dialect dialect, string mnemonic, list< Trait > traits = : Op< dialect, mnemonic, !listconcat(traits, [ AutomaticAllocationScope, - CallableOpInterface, - FunctionOpInterface, IsolatedFromAbove, + Core_FunctionOpInterface, Core_FuncSymbol, NoTerminator ]) @@ -86,7 +85,7 @@ class Core_FuncBaseOp< Dialect dialect, string mnemonic, list< Trait > traits = if (arg_attrs.empty()) return; - mlir::function_interface_impl::addArgAndResultAttrs( + vast::core::function_interface_impl::addArgAndResultAttrs( $_builder, $_state, arg_attrs, res_attrs, getArgAttrsAttrName($_state.name), getResAttrsAttrName($_state.name) ); diff --git a/include/vast/Dialect/Core/Interfaces/CMakeLists.txt b/include/vast/Dialect/Core/Interfaces/CMakeLists.txt index d6ab0139a1..b86f5c3fa1 100644 --- a/include/vast/Dialect/Core/Interfaces/CMakeLists.txt +++ b/include/vast/Dialect/Core/Interfaces/CMakeLists.txt @@ -1,4 +1,5 @@ add_vast_op_interface(DeclStorageInterface) +add_vast_op_interface(FunctionInterface) add_vast_op_interface(TypeDefinitionInterface) add_vast_op_interface(SymbolInterface) diff --git a/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.hpp b/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.hpp index b762a0c9d1..455dfa55ed 100644 --- a/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.hpp +++ b/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.hpp @@ -5,7 +5,6 @@ #include "vast/Util/Warnings.hpp" VAST_RELAX_WARNINGS -#include #include #include #include @@ -13,6 +12,7 @@ VAST_RELAX_WARNINGS VAST_RELAX_WARNINGS #include "vast/Dialect/Core/CoreOps.hpp" +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" #define GET_OP_FWD_DEFINES #include "vast/Dialect/HighLevel/HighLevel.h.inc" diff --git a/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.td b/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.td index 06df4b0108..84e8c0a244 100644 --- a/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.td +++ b/include/vast/Dialect/Core/Interfaces/DeclStorageInterface.td @@ -72,7 +72,7 @@ def Core_DeclStorageInterface : Core_OpInterface< "DeclStorageInterface" > { return kind_attr.getValue(); } auto st = core::get_effective_symbol_table_for< core::var_symbol >($_op)->get_defining_operation(); - if (mlir::isa< mlir::FunctionOpInterface >(st)) + if (mlir::isa< core::function_op_interface >(st)) return DeclContextKind::dc_function; if (st->template hasTrait< core::ScopeLikeTrait >()) return DeclContextKind::dc_function; diff --git a/include/vast/Dialect/Core/Interfaces/FunctionImplementation.hpp b/include/vast/Dialect/Core/Interfaces/FunctionImplementation.hpp new file mode 100644 index 0000000000..a0f28af814 --- /dev/null +++ b/include/vast/Dialect/Core/Interfaces/FunctionImplementation.hpp @@ -0,0 +1,108 @@ +//===- FunctionImplementation.h - Function-like Op utilities ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utility functions for implementing function-like +// operations, in particular, parsing, printing and verification components +// common to function-like operations. +// +//===----------------------------------------------------------------------===// + +#ifndef VAST_DIALECT_CORE_FUNCTIONIMPLEMENTATION_HPP +#define VAST_DIALECT_CORE_FUNCTIONIMPLEMENTATION_HPP + +#include "vast/Util/Warnings.hpp" + +VAST_RELAX_WARNINGS +#include +VAST_UNRELAX_WARNINGS + +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" + +namespace vast::core { + +namespace function_interface_impl { + +/// A named class for passing around the variadic flag. +class VariadicFlag { +public: + explicit VariadicFlag(bool variadic) : variadic(variadic) {} + bool isVariadic() const { return variadic; } + +private: + /// Underlying storage. + bool variadic; +}; + +/// Adds argument and result attributes, provided as `argAttrs` and +/// `resultAttrs` arguments, to the list of operation attributes in `result`. +/// Internally, argument and result attributes are stored as dict attributes +/// with special names given by getResultAttrName, getArgumentAttrName. +void addArgAndResultAttrs(::mlir::Builder &builder, ::mlir::OperationState &result, + ::mlir::ArrayRef<::mlir::DictionaryAttr> argAttrs, + ::mlir::ArrayRef<::mlir::DictionaryAttr> resultAttrs, + ::mlir::StringAttr argAttrsName, ::mlir::StringAttr resAttrsName); +void addArgAndResultAttrs(::mlir::Builder &builder, ::mlir::OperationState &result, + ::mlir::ArrayRef<::mlir::OpAsmParser::Argument> args, + ::mlir::ArrayRef<::mlir::DictionaryAttr> resultAttrs, + ::mlir::StringAttr argAttrsName, ::mlir::StringAttr resAttrsName); + +/// Callback type for `parseFunctionOp`, the callback should produce the +/// type that will be associated with a function-like operation from lists of +/// function arguments and results, VariadicFlag indicates whether the function +/// should have variadic arguments; in case of error, it may populate the last +/// argument with a message. +using FuncTypeBuilder = ::mlir::function_ref<::mlir::Type( + ::mlir::Builder &, ::mlir::ArrayRef<::mlir::Type>, ::mlir::ArrayRef<::mlir::Type>, VariadicFlag, std::string &)>; + +/// Parses a function signature using `parser`. The `allowVariadic` argument +/// indicates whether functions with variadic arguments are supported. The +/// trailing arguments are populated by this function with names, types, +/// attributes and locations of the arguments and those of the results. +::mlir::ParseResult +parseFunctionSignature(::mlir::OpAsmParser &parser, bool allowVariadic, + ::mlir::SmallVectorImpl<::mlir::OpAsmParser::Argument> &arguments, + bool &isVariadic, ::mlir::SmallVectorImpl<::mlir::Type> &resultTypes, + ::mlir::SmallVectorImpl<::mlir::DictionaryAttr> &resultAttrs); + +/// Parser implementation for function-like operations. Uses +/// `funcTypeBuilder` to construct the custom function type given lists of +/// input and output types. The parser sets the `typeAttrName` attribute to the +/// resulting function type. If `allowVariadic` is set, the parser will accept +/// trailing ellipsis in the function signature and indicate to the builder +/// whether the function is variadic. If the builder returns a null type, +/// `result` will not contain the `type` attribute. The caller can then add a +/// type, report the error or delegate the reporting to the op's verifier. +::mlir::ParseResult parseFunctionOp(::mlir::OpAsmParser &parser, ::mlir::OperationState &result, + bool allowVariadic, ::mlir::StringAttr typeAttrName, + FuncTypeBuilder funcTypeBuilder, + ::mlir::StringAttr argAttrsName, ::mlir::StringAttr resAttrsName); + +/// Printer implementation for function-like operations. +void printFunctionOp(::mlir::OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, + ::mlir::StringRef typeAttrName, ::mlir::StringAttr argAttrsName, + ::mlir::StringAttr resAttrsName); + +/// Prints the signature of the function-like operation `op`. Assumes `op` has +/// is a FunctionOpInterface and has passed verification. +void printFunctionSignature(::mlir::OpAsmPrinter &p, FunctionOpInterface op, + ::mlir::ArrayRef<::mlir::Type> argTypes, bool isVariadic, + ::mlir::ArrayRef<::mlir::Type> resultTypes); + +/// Prints the list of function prefixed with the "attributes" keyword. The +/// attributes with names listed in "elided" as well as those used by the +/// function-like operation internally are not printed. Nothing is printed +/// if all attributes are elided. Assumes `op` is a FunctionOpInterface and +/// has passed verification. +void printFunctionAttributes(::mlir::OpAsmPrinter &p, ::mlir::Operation *op, + ::mlir::ArrayRef<::mlir::StringRef> elided = {}); + +} // namespace function_interface_impl + +} // namespace vast::core + +#endif // VAST_DIALECT_CORE_FUNCTIONIMPLEMENTATION_HPP diff --git a/include/vast/Dialect/Core/Interfaces/FunctionInterface.hpp b/include/vast/Dialect/Core/Interfaces/FunctionInterface.hpp new file mode 100644 index 0000000000..23647af2c4 --- /dev/null +++ b/include/vast/Dialect/Core/Interfaces/FunctionInterface.hpp @@ -0,0 +1,249 @@ +//===- FunctionSupport.h - Utility types for function-like ops --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines support types for Operations that represent function-like +// constructs to use. +// +//===----------------------------------------------------------------------===// + +#ifndef VAST_DIALECT_CORE_FUNCTIONINTERFACE_HPP +#define VAST_DIALECT_CORE_FUNCTIONINTERFACE_HPP + +#include "vast/Util/Warnings.hpp" + +VAST_RELAX_WARNINGS +#include +#include +#include +#include +#include +#include +#include +#include +VAST_UNRELAX_WARNINGS + +namespace vast::core { +class FunctionOpInterface; + +using function_op_interface = FunctionOpInterface; + +namespace function_interface_impl { + +/// Returns the dictionary attribute corresponding to the argument at 'index'. +/// If there are no argument attributes at 'index', a null attribute is +/// returned. +mlir::DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index); + +/// Returns the dictionary attribute corresponding to the result at 'index'. +/// If there are no result attributes at 'index', a null attribute is +/// returned. +mlir::DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index); + +/// Return all of the attributes for the argument at 'index'. +mlir::ArrayRef getArgAttrs(FunctionOpInterface op, unsigned index); + +/// Return all of the attributes for the result at 'index'. +mlir::ArrayRef getResultAttrs(FunctionOpInterface op, unsigned index); + +/// Set all of the argument or result attribute dictionaries for a function. The +/// size of `attrs` is expected to match the number of arguments/results of the +/// given `op`. +void setAllArgAttrDicts(FunctionOpInterface op, mlir::ArrayRef attrs); +void setAllArgAttrDicts(FunctionOpInterface op, mlir::ArrayRef attrs); +void setAllResultAttrDicts(FunctionOpInterface op, + mlir::ArrayRef attrs); +void setAllResultAttrDicts(FunctionOpInterface op, mlir::ArrayRef attrs); + +/// Insert the specified arguments and update the function type attribute. +void insertFunctionArguments(FunctionOpInterface op, + mlir::ArrayRef argIndices, mlir::TypeRange argTypes, + mlir::ArrayRef argAttrs, + mlir::ArrayRef argLocs, + unsigned originalNumArgs, mlir::Type newType); + +/// Insert the specified results and update the function type attribute. +void insertFunctionResults(FunctionOpInterface op, + mlir::ArrayRef resultIndices, + mlir::TypeRange resultTypes, + mlir::ArrayRef resultAttrs, + unsigned originalNumResults, mlir::Type newType); + +/// Erase the specified arguments and update the function type attribute. +void eraseFunctionArguments(FunctionOpInterface op, const mlir::BitVector &argIndices, + mlir::Type newType); + +/// Erase the specified results and update the function type attribute. +void eraseFunctionResults(FunctionOpInterface op, + const mlir::BitVector &resultIndices, mlir::Type newType); + +/// Set a FunctionOpInterface operation's type signature. +void setFunctionType(FunctionOpInterface op, mlir::Type newType); + +//===----------------------------------------------------------------------===// +// Function Argument mlir::Attribute. +//===----------------------------------------------------------------------===// + +/// Set the attributes held by the argument at 'index'. +void setArgAttrs(FunctionOpInterface op, unsigned index, + mlir::ArrayRef attributes); +void setArgAttrs(FunctionOpInterface op, unsigned index, + mlir::DictionaryAttr attributes); + +/// If the an attribute exists with the specified name, change it to the new +/// value. Otherwise, add a new attribute with the specified name/value. +template +void setArgAttr(ConcreteType op, unsigned index, mlir::StringAttr name, + mlir::Attribute value) { + mlir::NamedAttrList attributes(op.getArgAttrDict(index)); + mlir::Attribute oldValue = attributes.set(name, value); + + // If the attribute changed, then set the new arg attribute list. + if (value != oldValue) + op.setArgAttrs(index, attributes.getDictionary(value.getContext())); +} + +/// Remove the attribute 'name' from the argument at 'index'. Returns the +/// removed attribute, or nullptr if `name` was not a valid attribute. +template +mlir::Attribute removeArgAttr(ConcreteType op, unsigned index, mlir::StringAttr name) { + // Build an attribute list and remove the attribute at 'name'. + mlir::NamedAttrList attributes(op.getArgAttrDict(index)); + mlir::Attribute removedAttr = attributes.erase(name); + + // If the attribute was removed, then update the argument dictionary. + if (removedAttr) + op.setArgAttrs(index, attributes.getDictionary(removedAttr.getContext())); + return removedAttr; +} + +//===----------------------------------------------------------------------===// +// Function Result mlir::Attribute. +//===----------------------------------------------------------------------===// + +/// Set the attributes held by the result at 'index'. +void setResultAttrs(FunctionOpInterface op, unsigned index, + mlir::ArrayRef attributes); +void setResultAttrs(FunctionOpInterface op, unsigned index, + mlir::DictionaryAttr attributes); + +/// If the an attribute exists with the specified name, change it to the new +/// value. Otherwise, add a new attribute with the specified name/value. +template +void setResultAttr(ConcreteType op, unsigned index, mlir::StringAttr name, + mlir::Attribute value) { + mlir::NamedAttrList attributes(op.getResultAttrDict(index)); + mlir::Attribute oldAttr = attributes.set(name, value); + + // If the attribute changed, then set the new arg attribute list. + if (oldAttr != value) + op.setResultAttrs(index, attributes.getDictionary(value.getContext())); +} + +/// Remove the attribute 'name' from the result at 'index'. +template +mlir::Attribute removeResultAttr(ConcreteType op, unsigned index, mlir::StringAttr name) { + // Build an attribute list and remove the attribute at 'name'. + mlir::NamedAttrList attributes(op.getResultAttrDict(index)); + mlir::Attribute removedAttr = attributes.erase(name); + + // If the attribute was removed, then update the result dictionary. + if (removedAttr) + op.setResultAttrs(index, + attributes.getDictionary(removedAttr.getContext())); + return removedAttr; +} + +/// This function defines the internal implementation of the `verifyTrait` +/// method on FunctionOpInterface::Trait. +template +mlir::LogicalResult verifyTrait(ConcreteOp op) { + if (failed(op.verifyType())) + return llvm::failure(); + + if (mlir::ArrayAttr allArgAttrs = op.getAllArgAttrs()) { + unsigned numArgs = op.getNumArguments(); + if (allArgAttrs.size() != numArgs) { + return op.emitOpError() + << "expects argument attribute array to have the same number of " + "elements as the number of function arguments, got " + << allArgAttrs.size() << ", but expected " << numArgs; + } + for (unsigned i = 0; i != numArgs; ++i) { + mlir::DictionaryAttr argAttrs = + llvm::dyn_cast_or_null(allArgAttrs[i]); + if (!argAttrs) { + return op.emitOpError() << "expects argument attribute dictionary " + "to be a mlir::DictionaryAttr, but got `" + << allArgAttrs[i] << "`"; + } + + // Verify that all of the argument attributes are dialect attributes, i.e. + // that they contain a dialect prefix in their name. Call the dialect, if + // registered, to verify the attributes themselves. + for (auto attr : argAttrs) { + if (!attr.getName().strref().contains('.')) + return op.emitOpError("arguments may only have dialect attributes"); + if (mlir::Dialect *dialect = attr.getNameDialect()) { + if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, + /*argIndex=*/i, attr))) + return llvm::failure(); + } + } + } + } + if (mlir::ArrayAttr allResultAttrs = op.getAllResultAttrs()) { + unsigned numResults = op.getNumResults(); + if (allResultAttrs.size() != numResults) { + return op.emitOpError() + << "expects result attribute array to have the same number of " + "elements as the number of function results, got " + << allResultAttrs.size() << ", but expected " << numResults; + } + for (unsigned i = 0; i != numResults; ++i) { + mlir::DictionaryAttr resultAttrs = + llvm::dyn_cast_or_null(allResultAttrs[i]); + if (!resultAttrs) { + return op.emitOpError() << "expects result attribute dictionary " + "to be a mlir::DictionaryAttr, but got `" + << allResultAttrs[i] << "`"; + } + + // Verify that all of the result attributes are dialect attributes, i.e. + // that they contain a dialect prefix in their name. Call the dialect, if + // registered, to verify the attributes themselves. + for (auto attr : resultAttrs) { + if (!attr.getName().strref().contains('.')) + return op.emitOpError("results may only have dialect attributes"); + if (mlir::Dialect *dialect = attr.getNameDialect()) { + if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, + /*resultIndex=*/i, + attr))) + return llvm::failure(); + } + } + } + } + + // Check that the op has exactly one region for the body. + if (op->getNumRegions() != 1) + return op.emitOpError("expects one region"); + + return op.verifyBody(); +} +} // namespace function_interface_impl +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Tablegen Interface Declarations +//===----------------------------------------------------------------------===// + +VAST_RELAX_WARNINGS +#include "vast/Dialect/Core/Interfaces/FunctionInterface.h.inc" +VAST_UNRELAX_WARNINGS + +#endif // VAST_DIALECT_CORE_FUNCTIONINTERFACE_HPP diff --git a/include/vast/Dialect/Core/Interfaces/FunctionInterface.td b/include/vast/Dialect/Core/Interfaces/FunctionInterface.td new file mode 100644 index 0000000000..139bd691f6 --- /dev/null +++ b/include/vast/Dialect/Core/Interfaces/FunctionInterface.td @@ -0,0 +1,566 @@ +//===- FunctionInterfaces.td - Function interfaces --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions for interfaces that support the definition of +// "function-like" operations. +// +//===----------------------------------------------------------------------===// + +#ifndef VAST_DIALECT_CORE_FUNCTIONINTERFACE_TD_ +#define VAST_DIALECT_CORE_FUNCTIONINTERFACE_TD_ + +include "mlir/Interfaces/CallInterfaces.td" + +//===----------------------------------------------------------------------===// +// FunctionOpInterface +//===----------------------------------------------------------------------===// + +def Core_FunctionOpInterface : OpInterface<"FunctionOpInterface", [ + CallableOpInterface + ]> { + let cppNamespace = "::vast::core"; + let description = [{ + This is a copy of mlir::FunctionOpInterface with the following changes: + - The interface is moved to the vast::core namespace. + - The interface does not implicitly inherit from SymbolOpInterface. + + This interfaces provides support for interacting with operations that + behave like functions. In particular, these operations: + + - must be symbols, i.e. have the `Symbol` trait. + - must have a single region, that may be comprised with multiple blocks, + that corresponds to the function body. + * when this region is empty, the operation corresponds to an external + function. + * leading arguments of the first block of the region are treated as + function arguments. + + The function, aside from implementing the various interface methods, + should have the following ODS arguments: + + - `function_type` (required) + * A TypeAttr that holds the signature type of the function. + + - `arg_attrs` (optional) + * An ArrayAttr of DictionaryAttr that contains attribute dictionaries + for each of the function arguments. + + - `res_attrs` (optional) + * An ArrayAttr of DictionaryAttr that contains attribute dictionaries + for each of the function results. + }]; + let methods = [ + InterfaceMethod<[{ + Returns the type of the function. + }], + "::mlir::Type", "getFunctionType">, + InterfaceMethod<[{ + Set the type of the function. This method should perform an unsafe + modification to the function type; it should not update argument or + result attributes. + }], + "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>, + + InterfaceMethod<[{ + Returns a clone of the function type with the given argument and + result types. + + Note: The default implementation assumes the function type has + an appropriate clone method: + `Type clone(ArrayRef inputs, ArrayRef results)` + }], + "::mlir::Type", "cloneTypeWith", (ins + "::mlir::TypeRange":$inputs, "::mlir::TypeRange":$results + ), /*methodBody=*/[{}], /*defaultImplementation=*/[{ + return $_op.getFunctionType().clone(inputs, results); + }]>, + + InterfaceMethod<[{ + Verify the contents of the body of this function. + + Note: The default implementation merely checks that if the entry block + exists, it has the same number and type of arguments as the function type. + }], + "::llvm::LogicalResult", "verifyBody", (ins), + /*methodBody=*/[{}], /*defaultImplementation=*/[{ + if ($_op.isExternal()) + return ::mlir::success(); + ::llvm::ArrayRef<::mlir::Type> fnInputTypes = $_op.getArgumentTypes(); + // NOTE: This should just be $_op.front() but access generically + // because the interface methods defined here may be shadowed in + // arbitrary ways. https://github.com/llvm/llvm-project/issues/54807 + ::mlir::Block &entryBlock = $_op->getRegion(0).front(); + + unsigned numArguments = fnInputTypes.size(); + if (entryBlock.getNumArguments() != numArguments) + return $_op.emitOpError("entry block must have ") + << numArguments << " arguments to match function signature"; + + for (unsigned i = 0, e = fnInputTypes.size(); i != e; ++i) { + ::mlir::Type argType = entryBlock.getArgument(i).getType(); + if (fnInputTypes[i] != argType) { + return $_op.emitOpError("type of entry block argument #") + << i << '(' << argType + << ") must match the type of the corresponding argument in " + << "function signature(" << fnInputTypes[i] << ')'; + } + } + + return ::mlir::success(); + }]>, + InterfaceMethod<[{ + Verify the type attribute of the function for derived op-specific + invariants. + }], + "::llvm::LogicalResult", "verifyType", (ins), + /*methodBody=*/[{}], /*defaultImplementation=*/[{ + return ::mlir::success(); + }]>, + ]; + + let extraTraitClassDeclaration = [{ + //===------------------------------------------------------------------===// + // Builders + //===------------------------------------------------------------------===// + + /// Build the function with the given name, attributes, and type. This + /// builder also inserts an entry block into the function body with the + /// given argument types. + static void buildWithEntryBlock( + ::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::StringRef name, ::mlir::Type type, + ::mlir::ArrayRef<::mlir::NamedAttribute> attrs, ::mlir::TypeRange inputTypes) { + ::mlir::OpBuilder::InsertionGuard g(builder); + state.addAttribute(::mlir::SymbolTable::getSymbolAttrName(), + builder.getStringAttr(name)); + state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name), + ::mlir::TypeAttr::get(type)); + state.attributes.append(attrs.begin(), attrs.end()); + + // Add the function body. + ::mlir::Region *bodyRegion = state.addRegion(); + ::mlir::Block *body = builder.createBlock(bodyRegion); + for (::mlir::Type input : inputTypes) + body->addArgument(input, state.location); + } + }]; + let extraSharedClassDeclaration = [{ + /// Block list iterator types. + using BlockListType = ::mlir::Region::BlockListType; + using iterator = BlockListType::iterator; + using reverse_iterator = BlockListType::reverse_iterator; + + /// Block argument iterator types. + using BlockArgListType = ::mlir::Region::BlockArgListType; + using args_iterator = BlockArgListType::iterator; + + //===------------------------------------------------------------------===// + // Body Handling + //===------------------------------------------------------------------===// + + /// Returns true if this function is external, i.e. it has no body. + bool isExternal() { return empty(); } + + /// Return the region containing the body of this function. + ::mlir::Region &getFunctionBody() { return $_op->getRegion(0); } + + /// Delete all blocks from this function. + void eraseBody() { + getFunctionBody().dropAllReferences(); + getFunctionBody().getBlocks().clear(); + } + + /// Return the list of blocks within the function body. + BlockListType &getBlocks() { return getFunctionBody().getBlocks(); } + + iterator begin() { return getFunctionBody().begin(); } + iterator end() { return getFunctionBody().end(); } + reverse_iterator rbegin() { return getFunctionBody().rbegin(); } + reverse_iterator rend() { return getFunctionBody().rend(); } + + /// Returns true if this function has no blocks within the body. + bool empty() { return getFunctionBody().empty(); } + + /// Push a new block to the back of the body region. + void push_back(::mlir::Block *block) { getFunctionBody().push_back(block); } + + /// Push a new block to the front of the body region. + void push_front(::mlir::Block *block) { getFunctionBody().push_front(block); } + + /// Return the last block in the body region. + ::mlir::Block &back() { return getFunctionBody().back(); } + + /// Return the first block in the body region. + ::mlir::Block &front() { return getFunctionBody().front(); } + + /// Add an entry block to an empty function, and set up the block arguments + /// to match the signature of the function. The newly inserted entry block + /// is returned. + ::mlir::Block *addEntryBlock() { + assert(empty() && "function already has an entry block"); + ::mlir::Block *entry = new ::mlir::Block(); + push_back(entry); + + // FIXME: Allow for passing in locations for these arguments instead of using + // the operations location. + ::llvm::ArrayRef<::mlir::Type> inputTypes = $_op.getArgumentTypes(); + ::llvm::SmallVector<::mlir::Location> locations(inputTypes.size(), + $_op.getOperation()->getLoc()); + entry->addArguments(inputTypes, locations); + return entry; + } + + /// Add a normal block to the end of the function's block list. The function + /// should at least already have an entry block. + ::mlir::Block *addBlock() { + assert(!empty() && "function should at least have an entry block"); + push_back(new ::mlir::Block()); + return &back(); + } + + //===------------------------------------------------------------------===// + // Type Attribute Handling + //===------------------------------------------------------------------===// + + /// Change the type of this function in place. This is an extremely dangerous + /// operation and it is up to the caller to ensure that this is legal for + /// this function, and to restore invariants: + /// - the entry block args must be updated to match the function params. + /// - the argument/result attributes may need an update: if the new type + /// has less parameters we drop the extra attributes, if there are more + /// parameters they won't have any attributes. + void setType(::mlir::Type newType) { + ::vast::core::function_interface_impl::setFunctionType($_op, newType); + } + + //===------------------------------------------------------------------===// + // Argument and Result Handling + //===------------------------------------------------------------------===// + + /// Returns the number of function arguments. + unsigned getNumArguments() { return $_op.getArgumentTypes().size(); } + + /// Returns the number of function results. + unsigned getNumResults() { return $_op.getResultTypes().size(); } + + /// Returns the entry block function argument at the given index. + ::mlir::BlockArgument getArgument(unsigned idx) { + return getFunctionBody().getArgument(idx); + } + + /// Support argument iteration. + args_iterator args_begin() { return getFunctionBody().args_begin(); } + args_iterator args_end() { return getFunctionBody().args_end(); } + BlockArgListType getArguments() { return getFunctionBody().getArguments(); } + + /// Insert a single argument of type `argType` with attributes `argAttrs` and + /// location `argLoc` at `argIndex`. + void insertArgument(unsigned argIndex, ::mlir::Type argType, ::mlir::DictionaryAttr argAttrs, + ::mlir::Location argLoc) { + insertArguments({argIndex}, {argType}, {argAttrs}, {argLoc}); + } + + /// Inserts arguments with the listed types, attributes, and locations at the + /// listed indices. `argIndices` must be sorted. Arguments are inserted in the + /// order they are listed, such that arguments with identical index will + /// appear in the same order that they were listed here. + void insertArguments(::llvm::ArrayRef argIndices, ::mlir::TypeRange argTypes, + ::llvm::ArrayRef<::mlir::DictionaryAttr> argAttrs, + ::llvm::ArrayRef<::mlir::Location> argLocs) { + unsigned originalNumArgs = $_op.getNumArguments(); + ::mlir::Type newType = $_op.getTypeWithArgsAndResults( + argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{}); + ::vast::core::function_interface_impl::insertFunctionArguments( + $_op, argIndices, argTypes, argAttrs, argLocs, + originalNumArgs, newType); + } + + /// Insert a single result of type `resultType` at `resultIndex`. + void insertResult(unsigned resultIndex, ::mlir::Type resultType, + ::mlir::DictionaryAttr resultAttrs) { + insertResults({resultIndex}, {resultType}, {resultAttrs}); + } + + /// Inserts results with the listed types at the listed indices. + /// `resultIndices` must be sorted. Results are inserted in the order they are + /// listed, such that results with identical index will appear in the same + /// order that they were listed here. + void insertResults(::llvm::ArrayRef resultIndices, ::mlir::TypeRange resultTypes, + ::llvm::ArrayRef<::mlir::DictionaryAttr> resultAttrs) { + unsigned originalNumResults = $_op.getNumResults(); + ::mlir::Type newType = $_op.getTypeWithArgsAndResults( + /*argIndices=*/{}, /*argTypes=*/{}, resultIndices, resultTypes); + ::vast::core::function_interface_impl::insertFunctionResults( + $_op, resultIndices, resultTypes, resultAttrs, + originalNumResults, newType); + } + + /// Erase a single argument at `argIndex`. + void eraseArgument(unsigned argIndex) { + ::llvm::BitVector argsToErase($_op.getNumArguments()); + argsToErase.set(argIndex); + eraseArguments(argsToErase); + } + + /// Erases the arguments listed in `argIndices`. + void eraseArguments(const ::llvm::BitVector &argIndices) { + ::mlir::Type newType = $_op.getTypeWithoutArgs(argIndices); + ::vast::core::function_interface_impl::eraseFunctionArguments( + $_op, argIndices, newType); + } + + /// Erase a single result at `resultIndex`. + void eraseResult(unsigned resultIndex) { + ::llvm::BitVector resultsToErase($_op.getNumResults()); + resultsToErase.set(resultIndex); + eraseResults(resultsToErase); + } + + /// Erases the results listed in `resultIndices`. + void eraseResults(const ::llvm::BitVector &resultIndices) { + ::mlir::Type newType = $_op.getTypeWithoutResults(resultIndices); + ::vast::core::function_interface_impl::eraseFunctionResults( + $_op, resultIndices, newType); + } + + /// Return the type of this function with the specified arguments and + /// results inserted. This is used to update the function's signature in + /// the `insertArguments` and `insertResults` methods. The arrays must be + /// sorted by increasing index. + ::mlir::Type getTypeWithArgsAndResults( + ::llvm::ArrayRef argIndices, ::mlir::TypeRange argTypes, + ::llvm::ArrayRef resultIndices, ::mlir::TypeRange resultTypes) { + ::llvm::SmallVector<::mlir::Type> argStorage, resultStorage; + ::mlir::TypeRange newArgTypes = insertTypesInto( + $_op.getArgumentTypes(), argIndices, argTypes, argStorage); + ::mlir::TypeRange newResultTypes = insertTypesInto( + $_op.getResultTypes(), resultIndices, resultTypes, resultStorage); + return $_op.cloneTypeWith(newArgTypes, newResultTypes); + } + + /// Return the type of this function without the specified arguments and + /// results. This is used to update the function's signature in the + /// `eraseArguments` and `eraseResults` methods. + ::mlir::Type getTypeWithoutArgsAndResults( + const ::llvm::BitVector &argIndices, const ::llvm::BitVector &resultIndices) { + ::llvm::SmallVector<::mlir::Type> argStorage, resultStorage; + ::mlir::TypeRange newArgTypes = filterTypesOut( + $_op.getArgumentTypes(), argIndices, argStorage); + ::mlir::TypeRange newResultTypes = filterTypesOut( + $_op.getResultTypes(), resultIndices, resultStorage); + return $_op.cloneTypeWith(newArgTypes, newResultTypes); + } + ::mlir::Type getTypeWithoutArgs(const ::llvm::BitVector &argIndices) { + ::llvm::SmallVector<::mlir::Type> argStorage; + ::mlir::TypeRange newArgTypes = filterTypesOut( + $_op.getArgumentTypes(), argIndices, argStorage); + return $_op.cloneTypeWith(newArgTypes, $_op.getResultTypes()); + } + ::mlir::Type getTypeWithoutResults(const ::llvm::BitVector &resultIndices) { + ::llvm::SmallVector<::mlir::Type> resultStorage; + ::mlir::TypeRange newResultTypes = filterTypesOut( + $_op.getResultTypes(), resultIndices, resultStorage); + return $_op.cloneTypeWith($_op.getArgumentTypes(), newResultTypes); + } + + //===------------------------------------------------------------------===// + // Argument Attributes + //===------------------------------------------------------------------===// + + /// Return all of the attributes for the argument at 'index'. + ::llvm::ArrayRef<::mlir::NamedAttribute> getArgAttrs(unsigned index) { + return ::vast::core::function_interface_impl::getArgAttrs($_op, index); + } + + /// Return an ArrayAttr containing all argument attribute dictionaries of + /// this function, or nullptr if no arguments have attributes. + ::mlir::ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); } + + /// Return all argument attributes of this function. + void getAllArgAttrs(::llvm::SmallVectorImpl<::mlir::DictionaryAttr> &result) { + if (::mlir::ArrayAttr argAttrs = getAllArgAttrs()) { + auto argAttrRange = argAttrs.template getAsRange<::mlir::DictionaryAttr>(); + result.append(argAttrRange.begin(), argAttrRange.end()); + } else { + result.append($_op.getNumArguments(), + ::mlir::DictionaryAttr::get(this->getOperation()->getContext())); + } + } + + /// Return the specified attribute, if present, for the argument at 'index', + /// null otherwise. + ::mlir::Attribute getArgAttr(unsigned index, ::mlir::StringAttr name) { + auto argDict = getArgAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + ::mlir::Attribute getArgAttr(unsigned index, ::llvm::StringRef name) { + auto argDict = getArgAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + + template + AttrClass getArgAttrOfType(unsigned index, ::mlir::StringAttr name) { + return ::llvm::dyn_cast_or_null(getArgAttr(index, name)); + } + template + AttrClass getArgAttrOfType(unsigned index, ::llvm::StringRef name) { + return ::llvm::dyn_cast_or_null(getArgAttr(index, name)); + } + + /// Set the attributes held by the argument at 'index'. + void setArgAttrs(unsigned index, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + ::vast::core::function_interface_impl::setArgAttrs($_op, index, attributes); + } + + /// Set the attributes held by the argument at 'index'. `attributes` may be + /// null, in which case any existing argument attributes are removed. + void setArgAttrs(unsigned index, ::mlir::DictionaryAttr attributes) { + ::vast::core::function_interface_impl::setArgAttrs($_op, index, attributes); + } + void setAllArgAttrs(::llvm::ArrayRef<::mlir::DictionaryAttr> attributes) { + assert(attributes.size() == $_op.getNumArguments()); + ::vast::core::function_interface_impl::setAllArgAttrDicts($_op, attributes); + } + void setAllArgAttrs(::llvm::ArrayRef<::mlir::Attribute> attributes) { + assert(attributes.size() == $_op.getNumArguments()); + ::vast::core::function_interface_impl::setAllArgAttrDicts($_op, attributes); + } + void setAllArgAttrs(::mlir::ArrayAttr attributes) { + assert(attributes.size() == $_op.getNumArguments()); + $_op.setArgAttrsAttr(attributes); + } + + /// If the an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void setArgAttr(unsigned index, ::mlir::StringAttr name, ::mlir::Attribute value) { + ::vast::core::function_interface_impl::setArgAttr($_op, index, name, value); + } + void setArgAttr(unsigned index, ::llvm::StringRef name, ::mlir::Attribute value) { + setArgAttr(index, + ::mlir::StringAttr::get(this->getOperation()->getContext(), name), + value); + } + + /// Remove the attribute 'name' from the argument at 'index'. Return the + /// attribute that was erased, or nullptr if there was no attribute with + /// such name. + ::mlir::Attribute removeArgAttr(unsigned index, ::mlir::StringAttr name) { + return ::vast::core::function_interface_impl::removeArgAttr($_op, index, name); + } + ::mlir::Attribute removeArgAttr(unsigned index, ::llvm::StringRef name) { + return removeArgAttr( + index, ::mlir::StringAttr::get(this->getOperation()->getContext(), name)); + } + + //===------------------------------------------------------------------===// + // Result Attributes + //===------------------------------------------------------------------===// + + /// Return all of the attributes for the result at 'index'. + ::llvm::ArrayRef<::mlir::NamedAttribute> getResultAttrs(unsigned index) { + return ::vast::core::function_interface_impl::getResultAttrs($_op, index); + } + + /// Return an ArrayAttr containing all result attribute dictionaries of this + /// function, or nullptr if no result have attributes. + ::mlir::ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); } + + /// Return all result attributes of this function. + void getAllResultAttrs(::llvm::SmallVectorImpl<::mlir::DictionaryAttr> &result) { + if (::mlir::ArrayAttr argAttrs = getAllResultAttrs()) { + auto argAttrRange = argAttrs.template getAsRange<::mlir::DictionaryAttr>(); + result.append(argAttrRange.begin(), argAttrRange.end()); + } else { + result.append($_op.getNumResults(), + ::mlir::DictionaryAttr::get(this->getOperation()->getContext())); + } + } + + /// Return the specified attribute, if present, for the result at 'index', + /// null otherwise. + ::mlir::Attribute getResultAttr(unsigned index, ::mlir::StringAttr name) { + auto argDict = getResultAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + ::mlir::Attribute getResultAttr(unsigned index, ::llvm::StringRef name) { + auto argDict = getResultAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + + template + AttrClass getResultAttrOfType(unsigned index, ::mlir::StringAttr name) { + return ::llvm::dyn_cast_or_null(getResultAttr(index, name)); + } + template + AttrClass getResultAttrOfType(unsigned index, ::llvm::StringRef name) { + return ::llvm::dyn_cast_or_null(getResultAttr(index, name)); + } + + /// Set the attributes held by the result at 'index'. + void setResultAttrs(unsigned index, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { + ::vast::core::function_interface_impl::setResultAttrs($_op, index, attributes); + } + + /// Set the attributes held by the result at 'index'. `attributes` may be + /// null, in which case any existing argument attributes are removed. + void setResultAttrs(unsigned index, ::mlir::DictionaryAttr attributes) { + ::vast::core::function_interface_impl::setResultAttrs($_op, index, attributes); + } + void setAllResultAttrs(::llvm::ArrayRef<::mlir::DictionaryAttr> attributes) { + assert(attributes.size() == $_op.getNumResults()); + ::vast::core::function_interface_impl::setAllResultAttrDicts( + $_op, attributes); + } + void setAllResultAttrs(::llvm::ArrayRef<::mlir::Attribute> attributes) { + assert(attributes.size() == $_op.getNumResults()); + ::vast::core::function_interface_impl::setAllResultAttrDicts( + $_op, attributes); + } + void setAllResultAttrs(::mlir::ArrayAttr attributes) { + assert(attributes.size() == $_op.getNumResults()); + $_op.setResAttrsAttr(attributes); + } + + /// If the an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void setResultAttr(unsigned index, ::mlir::StringAttr name, ::mlir::Attribute value) { + ::vast::core::function_interface_impl::setResultAttr($_op, index, name, value); + } + void setResultAttr(unsigned index, ::llvm::StringRef name, ::mlir::Attribute value) { + setResultAttr(index, + ::mlir::StringAttr::get(this->getOperation()->getContext(), name), + value); + } + + /// Remove the attribute 'name' from the result at 'index'. Return the + /// attribute that was erased, or nullptr if there was no attribute with + /// such name. + ::mlir::Attribute removeResultAttr(unsigned index, ::mlir::StringAttr name) { + return ::vast::core::function_interface_impl::removeResultAttr($_op, index, name); + } + + /// Returns the dictionary attribute corresponding to the argument at + /// 'index'. If there are no argument attributes at 'index', a null + /// attribute is returned. + ::mlir::DictionaryAttr getArgAttrDict(unsigned index) { + assert(index < $_op.getNumArguments() && "invalid argument number"); + return ::vast::core::function_interface_impl::getArgAttrDict($_op, index); + } + + /// Returns the dictionary attribute corresponding to the result at 'index'. + /// If there are no result attributes at 'index', a null attribute is + /// returned. + ::mlir::DictionaryAttr getResultAttrDict(unsigned index) { + assert(index < $_op.getNumResults() && "invalid result number"); + return ::vast::core::function_interface_impl::getResultAttrDict($_op, index); + } + }]; + + let verify = "return function_interface_impl::verifyTrait(cast($_op));"; +} + +#endif // VAST_DIALECT_CORE_FUNCTIONINTERFACE_TD_ diff --git a/include/vast/Dialect/HighLevel/HighLevelAttributes.td b/include/vast/Dialect/HighLevel/HighLevelAttributes.td index 54585bf95e..64bed3701a 100644 --- a/include/vast/Dialect/HighLevel/HighLevelAttributes.td +++ b/include/vast/Dialect/HighLevel/HighLevelAttributes.td @@ -70,7 +70,7 @@ def HighLevel_FormatAttr : HighLevel_NameAttr< "Format", "format" >; def HighLevel_SectionAttr : HighLevel_NameAttr< "Section", "section" >; def HighLevel_AliasAttr : HighLevel_NameAttr< "Alias", "alias" >; def HighLevel_ErrorAttr : HighLevel_NameAttr< "Error", "error" >; -def HighLevel_CountedByAttr : HighLevel_FlatSymbolReferringAttr< "CountedBy", "counted_by" >; +def HighLevel_CountedByAttr : HighLevel_Attr< "CountedBy", "counted_by" >; def HighLevel_CleanupAttr : HighLevel_SymbolReferringAttr< "Cleanup", "cleanup" >; // TODO(#595): Make aligned attribute keep the alignment value/expr def HighLevel_AlignedAttr : HighLevel_Attr< "Aligned", "aligned" >; diff --git a/include/vast/Dialect/HighLevel/HighLevelOps.hpp b/include/vast/Dialect/HighLevel/HighLevelOps.hpp index 85bff6e749..1155d55710 100644 --- a/include/vast/Dialect/HighLevel/HighLevelOps.hpp +++ b/include/vast/Dialect/HighLevel/HighLevelOps.hpp @@ -6,7 +6,6 @@ VAST_RELAX_WARNINGS #include -#include #include VAST_UNRELAX_WARNINGS diff --git a/include/vast/Dialect/HighLevel/HighLevelOps.td b/include/vast/Dialect/HighLevel/HighLevelOps.td index c426d9d6e1..6d35c38696 100644 --- a/include/vast/Dialect/HighLevel/HighLevelOps.td +++ b/include/vast/Dialect/HighLevel/HighLevelOps.td @@ -80,7 +80,7 @@ def HighLevel_FuncOp } $_state.attributes.append(attrs.begin(), attrs.end()); - mlir::function_interface_impl::addArgAndResultAttrs( + vast::core::function_interface_impl::addArgAndResultAttrs( $_builder, $_state, arg_attrs, res_attrs, getArgAttrsAttrName($_state.name), getResAttrsAttrName($_state.name) ); @@ -386,7 +386,7 @@ def HighLevel_CallOp build($_builder, $_state, mlir::SymbolRefAttr::get($_builder.getContext(), callee), results, operands); }]>, OpBuilder< (ins "FuncOp":$callee, CArg< "mlir::ValueRange", "{}" >:$operands), [{ - build($_builder, $_state, callee.getName(), callee.getFunctionType().getResults(), operands); + build($_builder, $_state, callee.getSymName(), callee.getFunctionType().getResults(), operands); }]> ]; @@ -662,6 +662,8 @@ def HighLevel_AddressSpaceConversion : HighLevel_CastKindAttr<"AddressSpaceConve def HighLevel_IntToOCLSampler : HighLevel_CastKindAttr<"IntToOCLSampler", 63>; def HighLevel_MatrixCast : HighLevel_CastKindAttr<"MatrixCast", 64>; +def HighLevel_HLSLVectorTruncation : HighLevel_CastKindAttr<"HLSLVectorTruncation", 65>; +def HighLevel_HLSLArrayRValue : HighLevel_CastKindAttr<"HLSLArrayRValue", 66>; let cppNamespace = "::vast::hl" in def HighLevel_CastKind : HighLevel_CastKindList< "CastKind", "cast kind", [ @@ -743,7 +745,9 @@ def HighLevel_CastKind : HighLevel_CastKindList< "CastKind", "cast kind", [ HighLevel_AddressSpaceConversion, HighLevel_IntToOCLSampler, - HighLevel_MatrixCast + HighLevel_MatrixCast, + HighLevel_HLSLVectorTruncation, + HighLevel_HLSLArrayRValue ] >; class HighLevel_CastOp< string mnemonic, list< Trait > traits = [] > @@ -891,7 +895,7 @@ class HighLevel_ComplexUnaryOp< string mnemonic, list< Trait > traits = [] > , Arguments< (ins AnyType:$arg) > , Results< (outs AnyType:$result) > { - let assemblyFormat = [{ $arg attr-dict `:` functional-type(operands, results) }]; + let assemblyFormat = [{ $arg attr-dict `:` type($arg) `->` type($result) }]; } def HighLevel_RealOp : HighLevel_ComplexUnaryOp< "real" >; @@ -1025,7 +1029,7 @@ class HighLevel_CompoundAssignOpTemplate< string mnemonic, TypeConstraint Type, let builders = [ OpBuilder<(ins "Value":$dst, "Value":$src), [{ $_state.addOperands(src); $_state.addOperands(dst); - auto type = dst.getType().cast< LValueType >(); + auto type = mlir::cast< LValueType >(dst.getType()); $_state.addTypes(type.getElementType()); }]> ]; @@ -1173,7 +1177,7 @@ class HighLevel_UnInplaceOp< string mnemonic, list< Trait > traits = [] > TypesMatchWith< "underlying argument type match result type", "arg", "result", - "$_self.cast< LValueType >().getElementType()" + "mlir::cast< LValueType >($_self).getElementType()" > ]) > , Arguments<(ins HighLevel_LValueOf:$arg)> diff --git a/include/vast/Dialect/HighLevel/HighLevelTypes.td b/include/vast/Dialect/HighLevel/HighLevelTypes.td index 2cb0c6eb24..c38645ad71 100644 --- a/include/vast/Dialect/HighLevel/HighLevelTypes.td +++ b/include/vast/Dialect/HighLevel/HighLevelTypes.td @@ -192,7 +192,7 @@ def HighLevel_ParenType : HighLevel_Type< "Paren", [ElementTypeInterface, MemRef class HighLevel_ElaboratedType< TypeConstraint value > : Type< And< [ HighLevel_ElaboratedType.predicate, - SubstLeaves< "$_self", "$_self.cast< ElaboratedType >().getElementType()", value.predicate > + SubstLeaves< "$_self", "mlir::cast< ElaboratedType >($_self).getElementType()", value.predicate > ] > >; @@ -222,7 +222,7 @@ def HighLevel_LValue : HighLevel_TypeWithSubType< "LValue", "lvalue" > { } class HighLevel_LValueOf : Type< And< [ HighLevel_LValue.predicate, - SubstLeaves< "$_self", "$_self.cast< LValueType >().getElementType()", value.predicate > + SubstLeaves< "$_self", "mlir::cast< LValueType >($_self).getElementType()", value.predicate > ] >, "lvalue to " # value.summary >; @@ -504,7 +504,7 @@ class HighLevel_LValueOrType : Type< Or< [ And< [ HighLevel_LValue.predicate, - SubstLeaves< "$_self", "$_self.cast< LValueType >().getElementType()", + SubstLeaves< "$_self", "mlir::cast< LValueType >($_self).getElementType()", Or< [ value.predicate, HighLevel_UnresolvedType.predicate ] > > ] >, diff --git a/include/vast/Dialect/LowLevel/LowLevelOps.hpp b/include/vast/Dialect/LowLevel/LowLevelOps.hpp index 9cbe3cf11c..23f0dd49b1 100644 --- a/include/vast/Dialect/LowLevel/LowLevelOps.hpp +++ b/include/vast/Dialect/LowLevel/LowLevelOps.hpp @@ -12,7 +12,6 @@ VAST_RELAX_WARNINGS #include #include -#include #include #include VAST_RELAX_WARNINGS diff --git a/include/vast/Frontend/Consumer.hpp b/include/vast/Frontend/Consumer.hpp index 31c18e1213..d5f4573dc4 100644 --- a/include/vast/Frontend/Consumer.hpp +++ b/include/vast/Frontend/Consumer.hpp @@ -52,8 +52,6 @@ namespace vast::cc { void CompleteTentativeDefinition(clang::VarDecl *decl) override; - void CompleteExternalDeclaration(clang::VarDecl * /* decl */) override; - void AssignInheritanceModel(clang::CXXRecordDecl * /* decl */) override; void HandleVTable(clang::CXXRecordDecl * /* decl */) override; diff --git a/include/vast/Frontend/Driver.hpp b/include/vast/Frontend/Driver.hpp index a2c7e84209..bc48e460ec 100644 --- a/include/vast/Frontend/Driver.hpp +++ b/include/vast/Frontend/Driver.hpp @@ -103,8 +103,6 @@ namespace vast::cc { : cc1_entry_point(cc1), cmd_args(cmd_args), diag(cmd_args, path) , drv(path, llvm::sys::getDefaultTargetTriple(), diag.engine, "vast compiler") { - set_install_dir(cmd_args, canonical_prefixes); - auto target_and_mode = toolchain::getTargetAndModeFromProgramName(cmd_args[0]); drv.setTargetAndMode(target_and_mode); @@ -337,31 +335,6 @@ namespace vast::cc { } } - void set_install_dir(argv_storage_base &argv, bool canonical_prefixes) { - // Attempt to find the original path used to invoke the driver, to determine - // the installed path. We do this manually, because we want to support that - // path being a symlink. - llvm::SmallString< 128 > installed_path(argv[0]); - - // Do a PATH lookup, if there are no directory components. - if (llvm::sys::path::filename(installed_path) == installed_path) { - if (auto tmp = llvm::sys::findProgramByName(llvm::sys::path::filename(installed_path.str()))) { - installed_path = *tmp; - } - } - - // FIXME: We don't actually canonicalize this, we just make it absolute. - if (canonical_prefixes) { - llvm::sys::fs::make_absolute(installed_path); - } - - string_ref installed_path_parent(llvm::sys::path::parent_path(installed_path)); - if (llvm::sys::fs::exists(installed_path_parent)) { - drv.setInstalledDir(installed_path_parent); - } - } - - exec_compile_t cc1_entry_point; argv_storage_base &cmd_args; std::vector< std::string > cached_strings; diff --git a/include/vast/Interfaces/AST/TypeInterface.td b/include/vast/Interfaces/AST/TypeInterface.td index 15a21cf963..489af64b5c 100644 --- a/include/vast/Interfaces/AST/TypeInterface.td +++ b/include/vast/Interfaces/AST/TypeInterface.td @@ -6,6 +6,4 @@ include "mlir/IR/OpBase.td" include "vast/Interfaces/AST/Common.td" - - #endif // VAST_INTERFACES_AST_TYPE_INTERFACE diff --git a/include/vast/Util/Symbols.hpp b/include/vast/Util/Symbols.hpp index 1b24cd56d5..8c3a13c6d4 100644 --- a/include/vast/Util/Symbols.hpp +++ b/include/vast/Util/Symbols.hpp @@ -40,7 +40,7 @@ namespace vast::util template< typename Yield > void functions(mlir::Operation *op, Yield &&yield) { - // TODO use mlir::FunctionOpInterface? + // TODO use core::function_op_interface? op->walk([yield = std::forward< Yield >(yield)](hl::FuncOp fn, const mlir::WalkStage &stage) { yield(fn); }); @@ -100,9 +100,10 @@ namespace vast::util auto loc = value.getLoc(); std::string buff; llvm::raw_string_ostream ss(buff); - if (auto file_loc = loc.template dyn_cast< mlir::FileLineColLoc >()) { - ss << " : " << file_loc.getFilename().getValue() << ":" << file_loc.getLine() - << ":" << file_loc.getColumn(); + if (auto file_loc = mlir::dyn_cast< mlir::FileLineColLoc >(loc)) { + ss << " : " << file_loc.getFilename().getValue() + << ":" << file_loc.getLine() + << ":" << file_loc.getColumn(); } else { ss << " : " << loc; } diff --git a/include/vast/Util/TypeList.hpp b/include/vast/Util/TypeList.hpp index 2f4bfa9e15..ac9cdc02da 100644 --- a/include/vast/Util/TypeList.hpp +++ b/include/vast/Util/TypeList.hpp @@ -226,8 +226,8 @@ namespace vast::util { } else { using head = typename list::head; - if (type.isa< head >()) { - return f(type.cast< head >()); + if (auto ty = mlir::dyn_cast< head >(type)) { + return f(ty); } return dispatch< typename list::tail, ret >(type, std::forward< fn >(f)); @@ -320,4 +320,4 @@ namespace vast::util { static_assert( std::is_same_v< flatten< type_list< type_list< int > > >, type_list< int > > ); } // namespace test -} // namespace vast::util \ No newline at end of file +} // namespace vast::util diff --git a/include/vast/Util/TypeUtils.hpp b/include/vast/Util/TypeUtils.hpp index d604277b5e..6b0a21e72b 100644 --- a/include/vast/Util/TypeUtils.hpp +++ b/include/vast/Util/TypeUtils.hpp @@ -2,12 +2,6 @@ #pragma once -#include - -VAST_RELAX_WARNINGS -#include -VAST_UNRELAX_WARNINGS - #include #include @@ -86,7 +80,7 @@ namespace vast bool has_type_somewhere(operation op, auto &&accept) { auto contains_in_function_type = [&] { - if (auto fn = mlir::dyn_cast< mlir::FunctionOpInterface >(op)) { + if (auto fn = mlir::dyn_cast< core::function_op_interface >(op)) { return contains_subtype(fn.getResultTypes(), accept) || contains_subtype(fn.getArgumentTypes(), accept); } diff --git a/lib/vast/CodeGen/CodeGenFunction.cpp b/lib/vast/CodeGen/CodeGenFunction.cpp index 88431c9333..550a680e72 100644 --- a/lib/vast/CodeGen/CodeGenFunction.cpp +++ b/lib/vast/CodeGen/CodeGenFunction.cpp @@ -211,7 +211,7 @@ namespace vast::cg bool may_drop_function_return(clang_qual_type rty, acontext_t &actx) { // We can't just discard the return value for a record type with a // complex destructor or a non-trivially copyable type. - if (const auto *recorrd_type = rty.getCanonicalType()->getAs< clang::RecordType >()) { + if (rty.getCanonicalType()->getAs< clang::RecordType >()) { VAST_UNIMPLEMENTED; } diff --git a/lib/vast/CodeGen/DefaultAttrVisitor.cpp b/lib/vast/CodeGen/DefaultAttrVisitor.cpp index da5e16a2c8..59b96c4389 100644 --- a/lib/vast/CodeGen/DefaultAttrVisitor.cpp +++ b/lib/vast/CodeGen/DefaultAttrVisitor.cpp @@ -175,7 +175,7 @@ namespace vast::cg } mlir_attr default_attr_visitor::VisitCountedByAttr(const clang::CountedByAttr *attr) { - return make< hl::CountedByAttr >(attr->getCountedByField()->getName()); + return make< hl::CountedByAttr >(); } mlir_attr default_attr_visitor::VisitCleanupAttr(const clang::CleanupAttr *attr) { diff --git a/lib/vast/CodeGen/DefaultDeclVisitor.cpp b/lib/vast/CodeGen/DefaultDeclVisitor.cpp index 8693796430..025b5a1fc5 100644 --- a/lib/vast/CodeGen/DefaultDeclVisitor.cpp +++ b/lib/vast/CodeGen/DefaultDeclVisitor.cpp @@ -259,7 +259,7 @@ namespace vast::cg { operation default_decl_visitor::VisitParmVarDecl(const clang::ParmVarDecl *decl) { auto blk = bld.getInsertionBlock(); - if (auto fn = mlir::dyn_cast< mlir::FunctionOpInterface >(blk->getParentOp())) { + if (auto fn = mlir::dyn_cast< core::function_op_interface >(blk->getParentOp())) { auto param_index = decl->getFunctionScopeIndex(); return bld.compose< hl::ParmVarDeclOp >() .bind(self.location(decl)) diff --git a/lib/vast/CodeGen/DefaultStmtVisitor.cpp b/lib/vast/CodeGen/DefaultStmtVisitor.cpp index a02cb3e6e8..f02cf614e3 100644 --- a/lib/vast/CodeGen/DefaultStmtVisitor.cpp +++ b/lib/vast/CodeGen/DefaultStmtVisitor.cpp @@ -98,6 +98,9 @@ namespace vast::cg case clang::CastKind::CK_IntToOCLSampler: return CastKind::IntToOCLSampler; case clang::CastKind::CK_MatrixCast: return CastKind::MatrixCast; + + case clang::CastKind::CK_HLSLVectorTruncation: return CastKind::HLSLVectorTruncation; + case clang::CastKind::CK_HLSLArrayRValue: return CastKind::HLSLArrayRValue; } VAST_UNIMPLEMENTED_MSG( "unsupported cast kind" ); diff --git a/lib/vast/CodeGen/DefaultSymbolGenerator.cpp b/lib/vast/CodeGen/DefaultSymbolGenerator.cpp index 127ebb03af..79dcd3312b 100644 --- a/lib/vast/CodeGen/DefaultSymbolGenerator.cpp +++ b/lib/vast/CodeGen/DefaultSymbolGenerator.cpp @@ -25,7 +25,7 @@ namespace vast::cg // Some ABIs don't have constructor variants. Make sure that base and // complete constructors get mangled the same. - if (const auto *ctor = clang::dyn_cast< clang::CXXConstructorDecl >(decl)) { + if (clang::isa< clang::CXXConstructorDecl >(decl)) { if (!actx.getTargetInfo().getCXXABI().hasConstructorVariants()) { return std::nullopt; } @@ -81,7 +81,7 @@ namespace vast::cg }; if (const auto *field = clang::dyn_cast< clang::FieldDecl >(decl)) { - if (field->isUnnamedBitfield() || field->isAnonymousStructOrUnion()) { + if (field->isUnnamedBitField() || field->isAnonymousStructOrUnion()) { return anonoymous_mangle(); } } diff --git a/lib/vast/Conversion/ABI/EmitABI.cpp b/lib/vast/Conversion/ABI/EmitABI.cpp index 9bd45ff727..d0e101e9b5 100644 --- a/lib/vast/Conversion/ABI/EmitABI.cpp +++ b/lib/vast/Conversion/ABI/EmitABI.cpp @@ -68,7 +68,7 @@ namespace vast abi_info_map_t< R > out; auto gather = [&](R op, const mlir::WalkStage &) { - auto name = op.getName(); + auto name = mlir::cast< core::func_symbol >(op.getOperation()).getSymbolName(); out.emplace( name.str(), abi::make_x86_64(op, dl) ); return mlir::WalkResult::advance(); @@ -78,7 +78,7 @@ namespace vast return out; } - using func_abi_info_t = abi_info_map_t< mlir::FunctionOpInterface >; + using func_abi_info_t = abi_info_map_t< core::function_op_interface >; // TODO(conv:abi): Remove as we most likely do not need this. struct TypeConverter : conv::tc::mixins< TypeConverter >, @@ -98,7 +98,7 @@ namespace vast struct abi_info_utils { using types_t = std::vector< mlir::Type >; - using abi_info_t = abi::func_info< mlir::FunctionOpInterface >; + using abi_info_t = abi::func_info< core::function_op_interface >; const auto &self() const { return static_cast< const Self & >(*this); } auto &self() { return static_cast< Self & >(*this); } @@ -224,7 +224,7 @@ namespace vast abi::FuncOp make() { auto wrapper = core::convert_function_without_body< abi::FuncOp >( - op, rewriter, conv::abi::abi_func_name_prefix + op.getName().str(), + op, rewriter, conv::abi::abi_func_name_prefix + op.getSymbolName().str(), this->abified_type(op.isVarArg()) ); @@ -399,9 +399,9 @@ namespace vast using values_t = std::vector< mlir::Value >; - auto get_callee() -> mlir::FunctionOpInterface { + auto get_callee() -> core::function_op_interface { auto caller = mlir::dyn_cast< VastCallOpInterface >(*op); - auto callee = mlir::dyn_cast< mlir::FunctionOpInterface >( + auto callee = mlir::dyn_cast< core::function_op_interface >( caller.resolveCallable()); VAST_ASSERT(callee); return callee; @@ -689,10 +689,11 @@ namespace vast }; - template< typename Op > - struct func_type : mlir::OpConversionPattern< Op > + template< typename op_t > + struct func_type : mlir::OpConversionPattern< op_t > { - using Base = mlir::OpConversionPattern< Op >; + using base = mlir::OpConversionPattern< op_t >; + using adaptor_t = typename op_t::Adaptor; TypeConverter &tc; const func_abi_info_t &abi_info_map; @@ -700,19 +701,18 @@ namespace vast func_type(TypeConverter &tc, const func_abi_info_t &abi_info_map, mcontext_t &mctx) - : Base(tc, &mctx), tc(tc), abi_info_map(abi_info_map) + : base(tc, &mctx), tc(tc), abi_info_map(abi_info_map) {} - mlir::LogicalResult matchAndRewrite( - Op op, typename Op::Adaptor ops, - conversion_rewriter &rewriter) const override - { - auto abi_map_it = abi_info_map.find(op.getName().str()); + logical_result matchAndRewrite( + op_t op, adaptor_t ops, conversion_rewriter &rewriter + ) const override { + auto abi_map_it = abi_info_map.find(op.getSymbolName().str()); if (abi_map_it == abi_info_map.end()) return mlir::failure(); const auto &abi_info = abi_map_it->second; - abi_transform< Op >({ op, ops, rewriter }, abi_info).make(); + abi_transform< op_t >({ op, ops, rewriter }, abi_info).make(); rewriter.eraseOp(op); return mlir::success(); } @@ -768,7 +768,7 @@ namespace vast if (!func) return mlir::failure(); - auto name = func.getName(); + auto name = func.getSymbolName(); if (!name.consume_front(conv::abi::abi_func_name_prefix)) return mlir::failure(); @@ -818,8 +818,8 @@ namespace vast auto should_transform = [&](operation op) { // TODO(conv:abi): We should always emit main with a fixed type. - if (auto fn = mlir::dyn_cast< mlir::FunctionOpInterface >(op)) - return fn.getName() == "main"; + if (auto fn = mlir::dyn_cast< core::func_symbol >(op)) + return fn.getSymbolName() == "main"; return true; }; @@ -875,7 +875,7 @@ namespace vast const auto &dl_analysis = this->getAnalysis< mlir::DataLayoutAnalysis >(); auto tc = TypeConverter(dl_analysis.getAtOrAbove(op), mctx); - auto abi_info_map = collect_abi_info< mlir::FunctionOpInterface >( + auto abi_info_map = collect_abi_info< core::function_op_interface >( op, dl_analysis.getAtOrAbove(op)); if (mlir::failed(run(first_phase(tc, abi_info_map)))) diff --git a/lib/vast/Conversion/ABI/LowerABI.cpp b/lib/vast/Conversion/ABI/LowerABI.cpp index 1bb33604e7..d7eaa8b1ff 100644 --- a/lib/vast/Conversion/ABI/LowerABI.cpp +++ b/lib/vast/Conversion/ABI/LowerABI.cpp @@ -543,7 +543,7 @@ namespace vast logical_result matchAndRewrite( op_t op, adaptor_t adaptor, conversion_rewriter &rewriter) const override { - auto name = op.getName(); + auto name = op.getSymbolName(); if (!name.consume_front(conv::abi::abi_func_name_prefix)) return mlir::failure(); diff --git a/lib/vast/Conversion/FromHL/ToLLCF.cpp b/lib/vast/Conversion/FromHL/ToLLCF.cpp index 6ee7d805da..542c659842 100644 --- a/lib/vast/Conversion/FromHL/ToLLCF.cpp +++ b/lib/vast/Conversion/FromHL/ToLLCF.cpp @@ -36,7 +36,7 @@ namespace vast::conv { auto coerce_condition(auto op, conversion_rewriter &rewriter) -> std::optional< mlir::Value > { - auto int_type = op.getType().template dyn_cast< mlir::IntegerType >(); + auto int_type = mlir::dyn_cast< mlir::IntegerType >(op.getType()); if (!int_type) { return {}; } @@ -63,7 +63,7 @@ namespace vast::conv { } auto cond_yield(mlir::Block *block) { - auto cond_yield = hard_terminator_t::get(*block).cast< hl::CondYieldOp >(); + auto cond_yield = mlir::cast< hl::CondYieldOp >(hard_terminator_t::get(*block).value()); VAST_CHECK(cond_yield, "Block does not have a hl::CondYieldOp as terminator."); return cond_yield; } diff --git a/lib/vast/Conversion/Generic/LowerValueCategories.cpp b/lib/vast/Conversion/Generic/LowerValueCategories.cpp index 4483be0052..64d321a139 100644 --- a/lib/vast/Conversion/Generic/LowerValueCategories.cpp +++ b/lib/vast/Conversion/Generic/LowerValueCategories.cpp @@ -139,7 +139,7 @@ namespace vast::conv { using base = base_pattern< op_t >; VAST_DEFINE_REWRITE { - auto func_op = mlir::dyn_cast< mlir::FunctionOpInterface >(op.getOperation()); + auto func_op = mlir::dyn_cast< core::function_op_interface >(op.getOperation()); if (!func_op) { return mlir::failure(); } @@ -322,7 +322,7 @@ namespace vast::conv { auto rhs = ops.getSrc(); // TODO(lukas): This should not happen? - if (rhs.getType().template isa< hl::LValueType >()) { + if (mlir::isa< hl::LValueType >(rhs.getType())) { return logical_result::failure(); } diff --git a/lib/vast/Conversion/ToLLVM/IRsToLLVM.cpp b/lib/vast/Conversion/ToLLVM/IRsToLLVM.cpp index 50533590a8..7b48fcd068 100644 --- a/lib/vast/Conversion/ToLLVM/IRsToLLVM.cpp +++ b/lib/vast/Conversion/ToLLVM/IRsToLLVM.cpp @@ -454,7 +454,7 @@ namespace vast::conv::irstollvm VAST_CHECK(linkage, "Attempting lower function without set linkage {0}", func_op); auto new_func = rewriter.create< llvm_func_op >( func_op.getLoc(), - func_op.getName(), + func_op.getSymbolName(), target_type, core::convert_linkage_to_llvm(linkage.value()), func_op.isVarArg(), LLVM::CConv::C @@ -551,8 +551,7 @@ namespace vast::conv::irstollvm return logical_result::success(); } - mlir::Attribute convert_attr(auto attr, auto op, - conversion_rewriter &rewriter) const + mlir::Attribute convert_attr(auto attr, auto op, conversion_rewriter &rewriter) const { auto target_type = this->convert(attr.getType()); const auto &dl = this->type_converter().getDataLayoutAnalysis() @@ -560,8 +559,7 @@ namespace vast::conv::irstollvm if (!target_type) return {}; - if (auto float_attr = attr.template dyn_cast< core::FloatAttr >()) - { + if (auto float_attr = mlir::dyn_cast< core::FloatAttr >(attr)) { // NOTE(lukas): We cannot simply forward the return value of `getValue()` // because it can have different semantics than one expected // by `FloatAttr`. @@ -571,15 +569,16 @@ namespace vast::conv::irstollvm return rewriter.getFloatAttr(target_type, raw_value); } - if (auto int_attr = attr.template dyn_cast< core::IntegerAttr >()) - { + if (auto int_attr = mlir::dyn_cast< core::IntegerAttr >(attr)) { auto size = dl.getTypeSizeInBits(target_type); auto coerced = int_attr.getValue().sextOrTrunc(size); return rewriter.getIntegerAttr(target_type, coerced); } VAST_UNREACHABLE("Trying to convert attr that is not supported, {0} in op {1}", - attr, op); + attr, op + ); + return {}; } @@ -989,41 +988,43 @@ namespace vast::conv::irstollvm } auto callee = caller.resolveCallable(); - if (!callee && !mlir::isa< mlir::FunctionOpInterface >(callee)) { + if (!callee && !mlir::isa< core::function_op_interface >(callee)) { return logical_result::failure(); } - auto fn = mlir::cast< mlir::FunctionOpInterface >(callee); + auto fn = mlir::cast< core::function_op_interface >(callee); + auto fty = mlir::cast< core::FunctionType >(fn.getFunctionType()); auto rtys = type_converter().convert_types_to_types(fn.getResultTypes()); + auto atys = type_converter().convert_types_to_types(fn.getArgumentTypes()); if (!rtys) { return logical_result::failure(); } - auto mk_call = [&](auto... args) { - return rewriter.create< mlir::LLVM::CallOp >(op.getLoc(), args...); + auto mk_fty = [&] { + mlir_type rty = rtys->empty() + ? mlir::LLVM::LLVMVoidType::get(op.getContext()) + : rtys->front(); + + return mlir::LLVM::LLVMFunctionType::get(rty, atys.value(), fty.isVarArg()); }; - if (rtys->empty() || rtys->front().isa< mlir::LLVM::LLVMVoidType >()) { - // We cannot pass in void type as some internal check inside `mlir::LLVM` - // dialect will fire - it would create a value of void type, which is not - // allowed. - mk_call(types_t{}, op.getCallee(), ops.getOperands()); - rewriter.eraseOp(op); - } else { - auto call = mk_call(*rtys, op.getCallee(), ops.getOperands()); + auto call = rewriter.create< mlir::LLVM::CallOp >( + op.getLoc(), mk_fty(), op.getCallee(), ops.getOperands() + ); + + // the result gets removed when return type is void + // because the number of results is mismatched, we can't use replace (triggers assert) + // removing the op is ok, since in llvm dialect a void value can't be used anyway + if (call.getResult()) rewriter.replaceOp(op, call.getResults()); - } + else + rewriter.eraseOp(op); return logical_result::success(); } }; - bool is_lvalue(auto op) - { - return op && op.getType().template isa< hl::LValueType >(); - } - struct logical_not : base_pattern< hl::LNotOp > { using base = base_pattern< hl::LNotOp >; @@ -1106,8 +1107,6 @@ namespace vast::conv::irstollvm conversion_rewriter &rewriter) const override { auto arg = adaptor.getArg(); - if (is_lvalue(arg)) - return logical_result::failure(); auto arg_type = convert(arg.getType()); auto zero = this->constant(rewriter, op.getLoc(), arg_type, 0); @@ -1399,8 +1398,7 @@ namespace vast::conv::irstollvm // Some operations need to keep it even with void value. if (!mlir::isa< core::LazyOp >(op->getParentOp())) { - if (ops.getResult().getType().template isa< mlir::LLVM::LLVMVoidType >()) - { + if (mlir::isa< mlir::LLVM::LLVMVoidType >(ops.getResult().getType())) { rewriter.eraseOp(op); return logical_result::success(); } diff --git a/lib/vast/Conversion/ToLLVM/LLCFToLLVM.hpp b/lib/vast/Conversion/ToLLVM/LLCFToLLVM.hpp index b923f51d29..2a02c8a085 100644 --- a/lib/vast/Conversion/ToLLVM/LLCFToLLVM.hpp +++ b/lib/vast/Conversion/ToLLVM/LLCFToLLVM.hpp @@ -87,9 +87,9 @@ namespace vast::conv::irstollvm::ll_cf auto &last = block.back(); std::vector< mlir::Value > no_vals; - if (auto ret = mlir::dyn_cast< ll::ScopeRet >(last)) { + if (mlir::isa< ll::ScopeRet >(last)) { make_after_op< LLVM::BrOp >(rewriter, &last, last.getLoc(), no_vals, &end); - } else if (auto ret = mlir::isa< ll::ScopeRecurse >(last)) { + } else if (mlir::isa< ll::ScopeRecurse >(last)) { make_after_op< LLVM::BrOp >(rewriter, &last, last.getLoc(), no_vals, &start); } else if (auto ret = mlir::dyn_cast< ll::CondScopeRet >(last)) { diff --git a/lib/vast/Conversion/ToMem/EvictStaticLocals.cpp b/lib/vast/Conversion/ToMem/EvictStaticLocals.cpp index 84cefaf2d0..ee45dab148 100644 --- a/lib/vast/Conversion/ToMem/EvictStaticLocals.cpp +++ b/lib/vast/Conversion/ToMem/EvictStaticLocals.cpp @@ -28,21 +28,23 @@ namespace vast::conv { logical_result matchAndRewrite( hl::VarDeclOp op, adaptor_t adaptor, conversion_rewriter &rewriter ) const override { - auto parent_fn = op->getParentOfType< mlir::FunctionOpInterface >(); - if (!parent_fn) + auto fn = op->getParentOfType< core::function_op_interface >(); + if (!fn) return mlir::failure(); auto guard = insertion_guard(rewriter); - auto &module_block = parent_fn->getParentOfType< core::ModuleOp >().getBody().front(); + auto &module_block = fn->getParentOfType< core::ModuleOp >().getBody().front(); rewriter.setInsertionPoint(&module_block, module_block.begin()); + auto fn_symbol = mlir::dyn_cast< core::func_symbol >(fn.getOperation()); + auto new_decl = rewriter.create< hl::VarDeclOp >( - op.getLoc(), - op.getType(), - (parent_fn.getName() + "." + op.getSymName()).str(), - op.getStorageClass(), - op.getThreadStorageClass(), - std::optional(core::GlobalLinkageKind::InternalLinkage) + op.getLoc(), + op.getType(), + (fn_symbol.getSymbolName() + "." + op.getSymName()).str(), + op.getStorageClass(), + op.getThreadStorageClass(), + std::optional(core::GlobalLinkageKind::InternalLinkage) ); // Save current context informationinto the op to make sure the information stays valid @@ -58,7 +60,7 @@ namespace vast::conv { static void legalize(conversion_target &trg) { trg.addDynamicallyLegalOp< hl::VarDeclOp >([] (hl::VarDeclOp op) { - return !(op.isStaticLocal() && op->getParentOfType< mlir::FunctionOpInterface >()); + return !(op.isStaticLocal() && op->getParentOfType< core::function_op_interface >()); }); } }; @@ -75,15 +77,16 @@ namespace vast::conv { ) const override { auto var = core::symbol_table::lookup< core::var_symbol >(op, op.getName()); if (auto decl_storage = mlir::dyn_cast< core::DeclStorageInterface>(var)) { - auto parent_fn = op->getParentOfType< mlir::FunctionOpInterface >(); + auto fn = op->getParentOfType< core::function_op_interface >(); - if (!parent_fn || !decl_storage.isStaticLocal()) + if (!fn || !decl_storage.isStaticLocal()) return mlir::failure(); + auto fn_symbol = mlir::dyn_cast< core::func_symbol >(fn.getOperation()); + rewriter.replaceOpWithNewOp< hl::DeclRefOp >( - op, - op.getType(), - (parent_fn.getName() + "." + op.getName()).str() + op, op.getType(), + (fn_symbol.getSymbolName() + "." + op.getName()).str() ); return mlir::success(); } @@ -93,8 +96,8 @@ namespace vast::conv { static void legalize(conversion_target &trg) { trg.addDynamicallyLegalOp< hl::DeclRefOp >([&](hl::DeclRefOp op) { auto var = core::symbol_table::lookup< core::var_symbol >(op, op.getName()); - if (auto decl_storage = mlir::dyn_cast< core::DeclStorageInterface>(var)) { - return !(decl_storage.isStaticLocal() && var->getParentOfType< mlir::FunctionOpInterface >()); + if (auto storage = mlir::dyn_cast< core::DeclStorageInterface >(var)) { + return !(storage.isStaticLocal() && var->getParentOfType< core::function_op_interface >()); } return (bool)var; }); diff --git a/lib/vast/Conversion/ToMem/StripParamLValues.cpp b/lib/vast/Conversion/ToMem/StripParamLValues.cpp index 508c4c3563..c95894afeb 100644 --- a/lib/vast/Conversion/ToMem/StripParamLValues.cpp +++ b/lib/vast/Conversion/ToMem/StripParamLValues.cpp @@ -67,7 +67,7 @@ namespace vast::conv { conversion_target trg(mctx); trg.markUnknownOpDynamicallyLegal([] (operation op) { - if (auto fn = mlir::dyn_cast< mlir::FunctionOpInterface >(op)) { + if (auto fn = mlir::dyn_cast< core::function_op_interface >(op)) { auto fty = mlir::cast< core::FunctionType >(fn.getFunctionType()); return rns::all_of(fty.getInputs(), is_not_lvalue_type); } diff --git a/lib/vast/Dialect/Builtin/Ops.cpp b/lib/vast/Dialect/Builtin/Ops.cpp index 1f97ac95bc..a96547e97d 100644 --- a/lib/vast/Dialect/Builtin/Ops.cpp +++ b/lib/vast/Dialect/Builtin/Ops.cpp @@ -9,7 +9,6 @@ VAST_RELAX_WARNINGS #include #include -#include #include #include diff --git a/lib/vast/Dialect/Core/CMakeLists.txt b/lib/vast/Dialect/Core/CMakeLists.txt index 57edcd697c..cc2e3ac662 100644 --- a/lib/vast/Dialect/Core/CMakeLists.txt +++ b/lib/vast/Dialect/Core/CMakeLists.txt @@ -11,6 +11,7 @@ add_vast_dialect_library(Core LINK_LIBS PRIVATE VASTAliasTypeInterface + VASTFunctionInterface ) add_subdirectory(Interfaces) diff --git a/lib/vast/Dialect/Core/CoreDialect.cpp b/lib/vast/Dialect/Core/CoreDialect.cpp index 9073c54429..34784cb701 100644 --- a/lib/vast/Dialect/Core/CoreDialect.cpp +++ b/lib/vast/Dialect/Core/CoreDialect.cpp @@ -28,7 +28,7 @@ namespace vast::core AliasResult getAlias(mlir_type type, llvm::raw_ostream &os) const final { if (mlir::isa< CoreDialect >(type.getDialect())) { - if (auto ty = type.dyn_cast< AliasTypeInterface >()) { + if (auto ty = mlir::dyn_cast< AliasTypeInterface >(type)) { os << ty.getAlias(); return ty.getAliasResultKind(); } @@ -38,12 +38,12 @@ namespace vast::core } AliasResult getAlias(mlir_attr attr, llvm::raw_ostream &os) const final { - if (auto at = attr.dyn_cast< core::VoidAttr >()) { + if (auto at = mlir::dyn_cast< core::VoidAttr >(attr)) { os << "void_value"; return AliasResult::FinalAlias; } - if (auto at = attr.dyn_cast< core::BooleanAttr >()) { + if (auto at = mlir::dyn_cast< core::BooleanAttr >(attr)) { os << (at.getValue() ? "true" : "false"); return AliasResult::FinalAlias; } diff --git a/lib/vast/Dialect/Core/Func.cpp b/lib/vast/Dialect/Core/Func.cpp index 7ce19bc2d3..fb9d570b10 100644 --- a/lib/vast/Dialect/Core/Func.cpp +++ b/lib/vast/Dialect/Core/Func.cpp @@ -17,6 +17,8 @@ VAST_UNRELAX_WARNINGS #include "vast/Dialect/Core/CoreDialect.hpp" #include "vast/Dialect/Core/Linkage.hpp" +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" + #include "vast/Util/Common.hpp" #include "vast/Util/Region.hpp" @@ -55,7 +57,7 @@ namespace vast::core } bool is_variadic = false; - if (mlir::failed(mlir::function_interface_impl::parseFunctionSignature( + if (mlir::failed(vast::core::function_interface_impl::parseFunctionSignature( parser, /*allowVariadic=*/true, arguments, is_variadic, result_types, result_attrs ))) { return mlir::failure(); @@ -79,7 +81,7 @@ namespace vast::core // TODO: Add the attributes to the function arguments. // VAST_ASSERT(result_attrs.size() == result_types.size()); - // return mlir::function_interface_impl::addArgAndResultAttrs( + // return vast::core::function_interface_impl::addArgAndResultAttrs( // builder, state, arguments, result_attrs // ); diff --git a/lib/vast/Dialect/Core/Interfaces/CMakeLists.txt b/lib/vast/Dialect/Core/Interfaces/CMakeLists.txt index 6f91fa4033..b6ed89fd04 100644 --- a/lib/vast/Dialect/Core/Interfaces/CMakeLists.txt +++ b/lib/vast/Dialect/Core/Interfaces/CMakeLists.txt @@ -1,5 +1,7 @@ set(VAST_OPTIONAL_SOURCES DeclStorageInterface.cpp + FunctionInterface.cpp + FunctionImplementation.cpp SymbolInterface.cpp SymbolTableInterface.cpp SymbolRefInterface.cpp @@ -10,6 +12,11 @@ add_vast_interface_library(DeclStorageInterface DeclStorageInterface.cpp ) +add_vast_interface_library(FunctionInterface + FunctionInterface.cpp + FunctionImplementation.cpp +) + add_vast_interface_library(SymbolInterface SymbolInterface.cpp ) diff --git a/lib/vast/Dialect/Core/Interfaces/FunctionImplementation.cpp b/lib/vast/Dialect/Core/Interfaces/FunctionImplementation.cpp new file mode 100644 index 0000000000..5c012c2240 --- /dev/null +++ b/lib/vast/Dialect/Core/Interfaces/FunctionImplementation.cpp @@ -0,0 +1,345 @@ +//===- FunctionImplementation.cpp - Utilities for function-like ops -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "vast/Dialect/Core/Interfaces/FunctionImplementation.hpp" + +VAST_RELAX_WARNINGS +#include +#include + +using namespace vast::core; +using namespace mlir; + +static ParseResult +parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, + bool &isVariadic) { + + // Parse the function arguments. The argument list either has to consistently + // have ssa-id's followed by types, or just be a type list. It isn't ok to + // sometimes have SSA ID's and sometimes not. + isVariadic = false; + + return parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { + // Ellipsis must be at end of the list. + if (isVariadic) + return parser.emitError( + parser.getCurrentLocation(), + "variadic arguments must be in the end of the argument list"); + + // Handle ellipsis as a special case. + if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) { + // This is a variadic designator. + isVariadic = true; + return success(); // Stop parsing arguments. + } + // Parse argument name if present. + OpAsmParser::Argument argument; + auto argPresent = parser.parseOptionalArgument( + argument, /*allowType=*/true, /*allowAttrs=*/true); + if (argPresent.has_value()) { + if (failed(argPresent.value())) + return failure(); // Present but malformed. + + // Reject this if the preceding argument was missing a name. + if (!arguments.empty() && arguments.back().ssaName.name.empty()) + return parser.emitError(argument.ssaName.location, + "expected type instead of SSA identifier"); + + } else { + argument.ssaName.location = parser.getCurrentLocation(); + // Otherwise we just have a type list without SSA names. Reject + // this if the preceding argument had a name. + if (!arguments.empty() && !arguments.back().ssaName.name.empty()) + return parser.emitError(argument.ssaName.location, + "expected SSA identifier"); + + NamedAttrList attrs; + if (parser.parseType(argument.type) || + parser.parseOptionalAttrDict(attrs) || + parser.parseOptionalLocationSpecifier(argument.sourceLoc)) + return failure(); + argument.attrs = attrs.getDictionary(parser.getContext()); + } + arguments.push_back(argument); + return success(); + }); +} + +/// Parse a function result list. +/// +/// function-result-list ::= function-result-list-parens +/// | non-function-type +/// function-result-list-parens ::= `(` `)` +/// | `(` function-result-list-no-parens `)` +/// function-result-list-no-parens ::= function-result (`,` function-result)* +/// function-result ::= type attribute-dict? +/// +static ParseResult +parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { + if (failed(parser.parseOptionalLParen())) { + // We already know that there is no `(`, so parse a type. + // Because there is no `(`, it cannot be a function type. + Type ty; + if (parser.parseType(ty)) + return failure(); + resultTypes.push_back(ty); + resultAttrs.emplace_back(); + return success(); + } + + // Special case for an empty set of parens. + if (succeeded(parser.parseOptionalRParen())) + return success(); + + // Parse individual function results. + if (parser.parseCommaSeparatedList([&]() -> ParseResult { + resultTypes.emplace_back(); + resultAttrs.emplace_back(); + NamedAttrList attrs; + if (parser.parseType(resultTypes.back()) || + parser.parseOptionalAttrDict(attrs)) + return failure(); + resultAttrs.back() = attrs.getDictionary(parser.getContext()); + return success(); + })) + return failure(); + + return parser.parseRParen(); +} + +ParseResult function_interface_impl::parseFunctionSignature( + OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, bool &isVariadic, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { + if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic)) + return failure(); + if (succeeded(parser.parseOptionalArrow())) + return parseFunctionResultList(parser, resultTypes, resultAttrs); + return success(); +} + +void function_interface_impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, ArrayRef argAttrs, + ArrayRef resultAttrs, StringAttr argAttrsName, + StringAttr resAttrsName) { + auto nonEmptyAttrsFn = [](DictionaryAttr attrs) { + return attrs && !attrs.empty(); + }; + // Convert the specified array of dictionary attrs (which may have null + // entries) to an ArrayAttr of dictionaries. + auto getArrayAttr = [&](ArrayRef dictAttrs) { + SmallVector attrs; + for (auto &dict : dictAttrs) + attrs.push_back(dict ? dict : builder.getDictionaryAttr({})); + return builder.getArrayAttr(attrs); + }; + + // Add the attributes to the function arguments. + if (llvm::any_of(argAttrs, nonEmptyAttrsFn)) + result.addAttribute(argAttrsName, getArrayAttr(argAttrs)); + + // Add the attributes to the function results. + if (llvm::any_of(resultAttrs, nonEmptyAttrsFn)) + result.addAttribute(resAttrsName, getArrayAttr(resultAttrs)); +} + +void function_interface_impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, + ArrayRef args, ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName) { + SmallVector argAttrs; + for (const auto &arg : args) + argAttrs.push_back(arg.attrs); + addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName, + resAttrsName); +} + +ParseResult function_interface_impl::parseFunctionOp( + OpAsmParser &parser, OperationState &result, bool allowVariadic, + StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName) { + SmallVector entryArgs; + SmallVector resultAttrs; + SmallVector resultTypes; + auto &builder = parser.getBuilder(); + + // Parse visibility. + (void)impl::parseOptionalVisibilityKeyword(parser, result.attributes); + + // Parse the name as a symbol. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + // Parse the function signature. + SMLoc signatureLocation = parser.getCurrentLocation(); + bool isVariadic = false; + if (parseFunctionSignature(parser, allowVariadic, entryArgs, isVariadic, + resultTypes, resultAttrs)) + return failure(); + + std::string errorMessage; + SmallVector argTypes; + argTypes.reserve(entryArgs.size()); + for (auto &arg : entryArgs) + argTypes.push_back(arg.type); + Type type = funcTypeBuilder(builder, argTypes, resultTypes, + VariadicFlag(isVariadic), errorMessage); + if (!type) { + return parser.emitError(signatureLocation) + << "failed to construct function type" + << (errorMessage.empty() ? "" : ": ") << errorMessage; + } + result.addAttribute(typeAttrName, TypeAttr::get(type)); + + // If function attributes are present, parse them. + NamedAttrList parsedAttributes; + SMLoc attributeDictLocation = parser.getCurrentLocation(); + if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) + return failure(); + + // Disallow attributes that are inferred from elsewhere in the attribute + // dictionary. + for (StringRef disallowed : + {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), + typeAttrName.getValue()}) { + if (parsedAttributes.get(disallowed)) + return parser.emitError(attributeDictLocation, "'") + << disallowed + << "' is an inferred attribute and should not be specified in the " + "explicit attribute dictionary"; + } + result.attributes.append(parsedAttributes); + + // Add the attributes to the function arguments. + assert(resultAttrs.size() == resultTypes.size()); + addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName, + resAttrsName); + + // Parse the optional function body. The printer will not print the body if + // its empty, so disallow parsing of empty body in the parser. + auto *body = result.addRegion(); + SMLoc loc = parser.getCurrentLocation(); + OptionalParseResult parseResult = + parser.parseOptionalRegion(*body, entryArgs, + /*enableNameShadowing=*/false); + if (parseResult.has_value()) { + if (failed(*parseResult)) + return failure(); + // Function body was parsed, make sure its not empty. + if (body->empty()) + return parser.emitError(loc, "expected non-empty function body"); + } + return success(); +} + +/// Print a function result list. The provided `attrs` must either be null, or +/// contain a set of DictionaryAttrs of the same arity as `types`. +static void printFunctionResultList(OpAsmPrinter &p, ArrayRef types, + ArrayAttr attrs) { + assert(!types.empty() && "Should not be called for empty result list."); + assert((!attrs || attrs.size() == types.size()) && + "Invalid number of attributes."); + + auto &os = p.getStream(); + bool needsParens = types.size() > 1 || llvm::isa(types[0]) || + (attrs && !llvm::cast(attrs[0]).empty()); + if (needsParens) + os << '('; + llvm::interleaveComma(llvm::seq(0, types.size()), os, [&](size_t i) { + p.printType(types[i]); + if (attrs) + p.printOptionalAttrDict(llvm::cast(attrs[i]).getValue()); + }); + if (needsParens) + os << ')'; +} + +void function_interface_impl::printFunctionSignature( + OpAsmPrinter &p, FunctionOpInterface op, ArrayRef argTypes, + bool isVariadic, ArrayRef resultTypes) { + Region &body = op->getRegion(0); + bool isExternal = body.empty(); + + p << '('; + ArrayAttr argAttrs = op.getArgAttrsAttr(); + for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { + if (i > 0) + p << ", "; + + if (!isExternal) { + ArrayRef attrs; + if (argAttrs) + attrs = llvm::cast(argAttrs[i]).getValue(); + p.printRegionArgument(body.getArgument(i), attrs); + } else { + p.printType(argTypes[i]); + if (argAttrs) + p.printOptionalAttrDict( + llvm::cast(argAttrs[i]).getValue()); + } + } + + if (isVariadic) { + if (!argTypes.empty()) + p << ", "; + p << "..."; + } + + p << ')'; + + if (!resultTypes.empty()) { + p.getStream() << " -> "; + auto resultAttrs = op.getResAttrsAttr(); + printFunctionResultList(p, resultTypes, resultAttrs); + } +} + +void function_interface_impl::printFunctionAttributes( + OpAsmPrinter &p, Operation *op, ArrayRef elided) { + // Print out function attributes, if present. + SmallVector ignoredAttrs = {SymbolTable::getSymbolAttrName()}; + ignoredAttrs.append(elided.begin(), elided.end()); + + p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); +} + +void function_interface_impl::printFunctionOp( + OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, + StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) { + // Print the operation and the function name. + auto funcName = + op->getAttrOfType(SymbolTable::getSymbolAttrName()) + .getValue(); + p << ' '; + + StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); + if (auto visibility = op->getAttrOfType(visibilityAttrName)) + p << visibility.getValue() << ' '; + p.printSymbolName(funcName); + + ArrayRef argTypes = op.getArgumentTypes(); + ArrayRef resultTypes = op.getResultTypes(); + printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); + printFunctionAttributes( + p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName}); + // Print the body if this is not an external function. + Region &body = op->getRegion(0); + if (!body.empty()) { + p << ' '; + p.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + } +} + +VAST_UNRELAX_WARNINGS diff --git a/lib/vast/Dialect/Core/Interfaces/FunctionInterface.cpp b/lib/vast/Dialect/Core/Interfaces/FunctionInterface.cpp new file mode 100644 index 0000000000..5fc5d058a9 --- /dev/null +++ b/lib/vast/Dialect/Core/Interfaces/FunctionInterface.cpp @@ -0,0 +1,366 @@ +//===- FunctionSupport.cpp - Utility types for function-like ops ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "vast/Dialect/Core/Interfaces/FunctionInterface.hpp" + +using namespace vast::core; +using namespace mlir; + +VAST_RELAX_WARNINGS + +//===----------------------------------------------------------------------===// +// Tablegen Interface Definitions +//===----------------------------------------------------------------------===// + +#include "vast/Dialect/Core/Interfaces/FunctionInterface.cpp.inc" + +//===----------------------------------------------------------------------===// +// Function Arguments and Results. +//===----------------------------------------------------------------------===// + +static bool isEmptyAttrDict(Attribute attr) { + return llvm::cast(attr).empty(); +} + +DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op, + unsigned index) { + ArrayAttr attrs = op.getArgAttrsAttr(); + DictionaryAttr argAttrs = + attrs ? llvm::cast(attrs[index]) : DictionaryAttr(); + return argAttrs; +} + +DictionaryAttr +function_interface_impl::getResultAttrDict(FunctionOpInterface op, + unsigned index) { + ArrayAttr attrs = op.getResAttrsAttr(); + DictionaryAttr resAttrs = + attrs ? llvm::cast(attrs[index]) : DictionaryAttr(); + return resAttrs; +} + +ArrayRef +function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) { + auto argDict = getArgAttrDict(op, index); + return argDict ? argDict.getValue() : std::nullopt; +} + +ArrayRef +function_interface_impl::getResultAttrs(FunctionOpInterface op, + unsigned index) { + auto resultDict = getResultAttrDict(op, index); + return resultDict ? resultDict.getValue() : std::nullopt; +} + +/// Get either the argument or result attributes array. +template +static ArrayAttr getArgResAttrs(FunctionOpInterface op) { + if constexpr (isArg) + return op.getArgAttrsAttr(); + else + return op.getResAttrsAttr(); +} + +/// Set either the argument or result attributes array. +template +static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) { + if constexpr (isArg) + op.setArgAttrsAttr(attrs); + else + op.setResAttrsAttr(attrs); +} + +/// Erase either the argument or result attributes array. +template +static void removeArgResAttrs(FunctionOpInterface op) { + if constexpr (isArg) + op.removeArgAttrsAttr(); + else + op.removeResAttrsAttr(); +} + +/// Set all of the argument or result attribute dictionaries for a function. +template +static void setAllArgResAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + if (llvm::all_of(attrs, isEmptyAttrDict)) + removeArgResAttrs(op); + else + setArgResAttrs(op, ArrayAttr::get(op->getContext(), attrs)); +} + +void function_interface_impl::setAllArgAttrDicts( + FunctionOpInterface op, ArrayRef attrs) { + setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} + +void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); +} + +void function_interface_impl::setAllResultAttrDicts( + FunctionOpInterface op, ArrayRef attrs) { + setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} + +void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); +} + +/// Update the given index into an argument or result attribute dictionary. +template +static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices, + unsigned index, DictionaryAttr attrs) { + ArrayAttr allAttrs = getArgResAttrs(op); + if (!allAttrs) { + if (attrs.empty()) + return; + + // If this attribute is not empty, we need to create a new attribute array. + SmallVector newAttrs(numTotalIndices, + DictionaryAttr::get(op->getContext())); + newAttrs[index] = attrs; + setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); + return; + } + // Check to see if the attribute is different from what we already have. + if (allAttrs[index] == attrs) + return; + + // If it is, check to see if the attribute array would now contain only empty + // dictionaries. + ArrayRef rawAttrArray = allAttrs.getValue(); + if (attrs.empty() && + llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) && + llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) + return removeArgResAttrs(op); + + // Otherwise, create a new attribute array with the updated dictionary. + SmallVector newAttrs(rawAttrArray.begin(), rawAttrArray.end()); + newAttrs[index] = attrs; + setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); +} + +void function_interface_impl::setArgAttrs(FunctionOpInterface op, + unsigned index, + ArrayRef attributes) { + assert(index < op.getNumArguments() && "invalid argument number"); + return setArgResAttrDict( + op, op.getNumArguments(), index, + DictionaryAttr::get(op->getContext(), attributes)); +} + +void function_interface_impl::setArgAttrs(FunctionOpInterface op, + unsigned index, + DictionaryAttr attributes) { + return setArgResAttrDict( + op, op.getNumArguments(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); +} + +void function_interface_impl::setResultAttrs( + FunctionOpInterface op, unsigned index, + ArrayRef attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return setArgResAttrDict( + op, op.getNumResults(), index, + DictionaryAttr::get(op->getContext(), attributes)); +} + +void function_interface_impl::setResultAttrs(FunctionOpInterface op, + unsigned index, + DictionaryAttr attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return setArgResAttrDict( + op, op.getNumResults(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); +} + +void function_interface_impl::insertFunctionArguments( + FunctionOpInterface op, ArrayRef argIndices, TypeRange argTypes, + ArrayRef argAttrs, ArrayRef argLocs, + unsigned originalNumArgs, Type newType) { + assert(argIndices.size() == argTypes.size()); + assert(argIndices.size() == argAttrs.size() || argAttrs.empty()); + assert(argIndices.size() == argLocs.size()); + if (argIndices.empty()) + return; + + // There are 3 things that need to be updated: + // - Function type. + // - Arg attrs. + // - Block arguments of entry block. + Block &entry = op->getRegion(0).front(); + + // Update the argument attributes of the function. + ArrayAttr oldArgAttrs = op.getArgAttrsAttr(); + if (oldArgAttrs || !argAttrs.empty()) { + SmallVector newArgAttrs; + newArgAttrs.reserve(originalNumArgs + argIndices.size()); + unsigned oldIdx = 0; + auto migrate = [&](unsigned untilIdx) { + if (!oldArgAttrs) { + newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx); + } else { + auto oldArgAttrRange = oldArgAttrs.getAsRange(); + newArgAttrs.append(oldArgAttrRange.begin() + oldIdx, + oldArgAttrRange.begin() + untilIdx); + } + oldIdx = untilIdx; + }; + for (unsigned i = 0, e = argIndices.size(); i < e; ++i) { + migrate(argIndices[i]); + newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]); + } + migrate(originalNumArgs); + setAllArgAttrDicts(op, newArgAttrs); + } + + // Update the function type and any entry block arguments. + op.setFunctionTypeAttr(TypeAttr::get(newType)); + for (unsigned i = 0, e = argIndices.size(); i < e; ++i) + entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]); +} + +void function_interface_impl::insertFunctionResults( + FunctionOpInterface op, ArrayRef resultIndices, + TypeRange resultTypes, ArrayRef resultAttrs, + unsigned originalNumResults, Type newType) { + assert(resultIndices.size() == resultTypes.size()); + assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty()); + if (resultIndices.empty()) + return; + + // There are 2 things that need to be updated: + // - Function type. + // - Result attrs. + + // Update the result attributes of the function. + ArrayAttr oldResultAttrs = op.getResAttrsAttr(); + if (oldResultAttrs || !resultAttrs.empty()) { + SmallVector newResultAttrs; + newResultAttrs.reserve(originalNumResults + resultIndices.size()); + unsigned oldIdx = 0; + auto migrate = [&](unsigned untilIdx) { + if (!oldResultAttrs) { + newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx); + } else { + auto oldResultAttrsRange = oldResultAttrs.getAsRange(); + newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx, + oldResultAttrsRange.begin() + untilIdx); + } + oldIdx = untilIdx; + }; + for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) { + migrate(resultIndices[i]); + newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{} + : resultAttrs[i]); + } + migrate(originalNumResults); + setAllResultAttrDicts(op, newResultAttrs); + } + + // Update the function type. + op.setFunctionTypeAttr(TypeAttr::get(newType)); +} + +void function_interface_impl::eraseFunctionArguments( + FunctionOpInterface op, const BitVector &argIndices, Type newType) { + // There are 3 things that need to be updated: + // - Function type. + // - Arg attrs. + // - Block arguments of entry block. + Block &entry = op->getRegion(0).front(); + + // Update the argument attributes of the function. + if (ArrayAttr argAttrs = op.getArgAttrsAttr()) { + SmallVector newArgAttrs; + newArgAttrs.reserve(argAttrs.size()); + for (unsigned i = 0, e = argIndices.size(); i < e; ++i) + if (!argIndices[i]) + newArgAttrs.emplace_back(llvm::cast(argAttrs[i])); + setAllArgAttrDicts(op, newArgAttrs); + } + + // Update the function type and any entry block arguments. + op.setFunctionTypeAttr(TypeAttr::get(newType)); + entry.eraseArguments(argIndices); +} + +void function_interface_impl::eraseFunctionResults( + FunctionOpInterface op, const BitVector &resultIndices, Type newType) { + // There are 2 things that need to be updated: + // - Function type. + // - Result attrs. + + // Update the result attributes of the function. + if (ArrayAttr resAttrs = op.getResAttrsAttr()) { + SmallVector newResultAttrs; + newResultAttrs.reserve(resAttrs.size()); + for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) + if (!resultIndices[i]) + newResultAttrs.emplace_back(llvm::cast(resAttrs[i])); + setAllResultAttrDicts(op, newResultAttrs); + } + + // Update the function type. + op.setFunctionTypeAttr(TypeAttr::get(newType)); +} + +//===----------------------------------------------------------------------===// +// Function type signature. +//===----------------------------------------------------------------------===// + +void function_interface_impl::setFunctionType(FunctionOpInterface op, + Type newType) { + unsigned oldNumArgs = op.getNumArguments(); + unsigned oldNumResults = op.getNumResults(); + op.setFunctionTypeAttr(TypeAttr::get(newType)); + unsigned newNumArgs = op.getNumArguments(); + unsigned newNumResults = op.getNumResults(); + + // Functor used to update the argument and result attributes of the function. + auto emptyDict = DictionaryAttr::get(op.getContext()); + auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) { + constexpr bool isArgVal = std::is_same_v; + + if (oldCount == newCount) + return; + // The new type has no arguments/results, just drop the attribute. + if (newCount == 0) + return removeArgResAttrs(op); + ArrayAttr attrs = getArgResAttrs(op); + if (!attrs) + return; + + // The new type has less arguments/results, take the first N attributes. + if (newCount < oldCount) + return setAllArgResAttrDicts( + op, attrs.getValue().take_front(newCount)); + + // Otherwise, the new type has more arguments/results. Initialize the new + // arguments/results with empty dictionary attributes. + SmallVector newAttrs(attrs.begin(), attrs.end()); + newAttrs.resize(newCount, emptyDict); + setAllArgResAttrDicts(op, newAttrs); + }; + + // Update the argument and result attributes. + updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs); + updateAttrFn(std::false_type{}, oldNumResults, newNumResults); +} + +VAST_UNRELAX_WARNINGS diff --git a/lib/vast/Dialect/HighLevel/HighLevelDialect.cpp b/lib/vast/Dialect/HighLevel/HighLevelDialect.cpp index 70f85f4ebd..d4b5275d88 100644 --- a/lib/vast/Dialect/HighLevel/HighLevelDialect.cpp +++ b/lib/vast/Dialect/HighLevel/HighLevelDialect.cpp @@ -34,7 +34,7 @@ namespace vast::hl AliasResult getAlias(mlir_type type, llvm::raw_ostream &os) const final { if (mlir::isa< HighLevelDialect >(type.getDialect())) { - if (auto ty = type.dyn_cast< AliasTypeInterface >()) { + if (auto ty = mlir::dyn_cast< AliasTypeInterface >(type)) { os << ty.getAlias(); return ty.getAliasResultKind(); } diff --git a/lib/vast/Dialect/HighLevel/HighLevelOps.cpp b/lib/vast/Dialect/HighLevel/HighLevelOps.cpp index e2dca19916..d731cf3e20 100644 --- a/lib/vast/Dialect/HighLevel/HighLevelOps.cpp +++ b/lib/vast/Dialect/HighLevel/HighLevelOps.cpp @@ -16,7 +16,6 @@ VAST_RELAX_WARNINGS #include #include -#include VAST_UNRELAX_WARNINGS #include "vast/Dialect/HighLevel/HighLevelAttributes.hpp" diff --git a/lib/vast/Dialect/HighLevel/HighLevelTypes.cpp b/lib/vast/Dialect/HighLevel/HighLevelTypes.cpp index 961ffed6ba..c1254a1682 100644 --- a/lib/vast/Dialect/HighLevel/HighLevelTypes.cpp +++ b/lib/vast/Dialect/HighLevel/HighLevelTypes.cpp @@ -90,11 +90,11 @@ namespace vast::hl } core::FunctionType getFunctionType(mlir_type type, operation from) { - if (auto ty = type.dyn_cast< core::FunctionType >()) { + if (auto ty = mlir::dyn_cast< core::FunctionType >(type)) { return ty; } else if (auto ty = dyn_cast< ElementTypeInterface >(type)) { return getFunctionType(ty.getElementType(), from); - } else if (auto ty = type.dyn_cast< TypedefType >()) { + } else if (auto ty = mlir::dyn_cast< TypedefType >(type)) { auto mod = from->getParentOfType< core::module >(); return getFunctionType(getTypedefType(ty, mod), from); } else { @@ -107,13 +107,13 @@ namespace vast::hl return {}; } - if (auto sym = callee.dyn_cast< mlir::SymbolRefAttr >()) { + if (auto sym = mlir::dyn_cast< mlir::SymbolRefAttr >(callee)) { auto fn = core::symbol_table::lookup< core::func_symbol >(from, sym.getRootReference()); VAST_CHECK(fn, "Function {} not present in the symbol table.", sym.getRootReference()); return mlir::cast< FuncOp >(fn).getFunctionType(); } - if (auto value = callee.dyn_cast< mlir_value >()) { + if (auto value = mlir::dyn_cast< mlir_value >(callee)) { return getFunctionType(value.getType(), from); } @@ -130,7 +130,7 @@ namespace vast::hl bool isBoolType(mlir_type type) { - return type.isa< BoolType >(); + return mlir::isa< BoolType >(type); } bool isIntegerType(mlir_type type) @@ -149,8 +149,9 @@ namespace vast::hl return false; } - if (auto builtin_type = type.dyn_cast< mlir::IntegerType >()) + if (auto builtin_type = mlir::dyn_cast< mlir::IntegerType >(type)) { return builtin_type.isSigned(); + } VAST_ASSERT(isIntegerType(type)); return util::dispatch< integer_types, bool >(type, [] (auto ty) { @@ -186,8 +187,9 @@ namespace vast::hl // do this recursion? auto collect = [&](ArrayType arr, auto &self) -> mlir_type { dims.push_back(arr.getSize()); - if (auto nested = arr.getElementType().dyn_cast< ArrayType >()) + if (auto nested = mlir::dyn_cast< ArrayType >(arr.getElementType())) { return self(nested, self); + } return arr.getElementType(); }; return { std::move(dims), collect(*this, collect) }; diff --git a/lib/vast/Dialect/HighLevel/Transforms/ExportFnInfo.cpp b/lib/vast/Dialect/HighLevel/Transforms/ExportFnInfo.cpp index bed6bb1d19..91fdefdfe6 100644 --- a/lib/vast/Dialect/HighLevel/Transforms/ExportFnInfo.cpp +++ b/lib/vast/Dialect/HighLevel/Transforms/ExportFnInfo.cpp @@ -60,7 +60,7 @@ namespace vast::hl using Base = TypeEntryBase; DialectTypeEntry(DialectType type) : Base(type) {} - DialectType in_dialect() { return type.cast< DialectType >(); } + DialectType in_dialect() { return mlir::cast< DialectType >(type); } DialectTypeEntry &name() { raw["type"] = in_dialect().getMnemonic(); @@ -214,7 +214,7 @@ namespace vast::hl llvm::json::Object top; - // TODO use FunctionOpInterface instead of specific operation + // TODO use core::function_op_interface instead of specific operation util::functions(mod, [&](FuncOp fn) { const auto &dl_analysis = this->getAnalysis< mlir::DataLayoutAnalysis >(); const auto &dl = dl_analysis.getAtOrAbove(mod); @@ -233,7 +233,7 @@ namespace vast::hl current["rets"] = std::move(rets); current["args"] = std::move(args); - top[fn.getName().str()] = std::move(current); + top[fn.getSymName().str()] = std::move(current); }); auto value = llvm::formatv("{0:2}", llvm::json::Value(std::move(top))); diff --git a/lib/vast/Dialect/LowLevel/LowLevelOps.cpp b/lib/vast/Dialect/LowLevel/LowLevelOps.cpp index 74304a1cb0..d8047aa4a1 100644 --- a/lib/vast/Dialect/LowLevel/LowLevelOps.cpp +++ b/lib/vast/Dialect/LowLevel/LowLevelOps.cpp @@ -9,7 +9,6 @@ VAST_RELAX_WARNINGS #include -#include #include VAST_UNRELAX_WARNINGS diff --git a/lib/vast/Dialect/Meta/MetaDialect.cpp b/lib/vast/Dialect/Meta/MetaDialect.cpp index 51423e7225..ee74b82190 100644 --- a/lib/vast/Dialect/Meta/MetaDialect.cpp +++ b/lib/vast/Dialect/Meta/MetaDialect.cpp @@ -31,9 +31,7 @@ namespace vast::meta bool has_identifier(mlir::Operation *op, identifier_t id) { if (auto attr = op->getAttr(identifier_name)) { - if (attr.cast< IdentifierAttr >().getValue() == id) { - return true; - } + return mlir::cast< IdentifierAttr >(attr).getValue() == id; } return false; @@ -52,7 +50,7 @@ namespace vast::meta std::vector< mlir::Operation * > get_with_meta_location(mlir::Operation *scope, IdentifierAttr id) { std::vector< mlir::Operation * > result; scope->walk([&](mlir::Operation *op) { - if (auto loc = op->getLoc().dyn_cast< mlir::FusedLoc >()) { + if (auto loc = mlir::dyn_cast< mlir::FusedLoc >(op->getLoc())) { if (id == loc.getMetadata()) { result.push_back(op); } diff --git a/lib/vast/Frontend/Consumer.cpp b/lib/vast/Frontend/Consumer.cpp index 49070f32fb..de581feeee 100644 --- a/lib/vast/Frontend/Consumer.cpp +++ b/lib/vast/Frontend/Consumer.cpp @@ -99,8 +99,6 @@ namespace vast::cc { void vast_consumer::CompleteTentativeDefinition(clang::VarDecl * /* decl */) {} - void vast_consumer::CompleteExternalDeclaration(clang::VarDecl * /* decl */) {} - void vast_consumer::AssignInheritanceModel(clang::CXXRecordDecl * /* decl */) { VAST_UNIMPLEMENTED; } diff --git a/scripts/setup_llvm_dependencies.sh b/scripts/setup_llvm_dependencies.sh index 02f825624a..97c97a1186 100644 --- a/scripts/setup_llvm_dependencies.sh +++ b/scripts/setup_llvm_dependencies.sh @@ -1,7 +1,7 @@ #!/bin/bash # Set default LLVM version if not specified -LLVM_VERSION=${LLVM_VERSION:-18} +LLVM_VERSION=${LLVM_VERSION:-19} # Install LLVM tools and libraries bash -c "$(curl -s -o - https://apt.llvm.org/llvm.sh)" llvm.sh $LLVM_VERSION diff --git a/test/vast/Conversion/printf-a.c b/test/vast/Conversion/printf-a.c new file mode 100644 index 0000000000..d6ea97c3f8 --- /dev/null +++ b/test/vast/Conversion/printf-a.c @@ -0,0 +1,8 @@ +// RUN: %vast-front -vast-emit-mlir=llvm %s -o - | %file-check %s +#include + +int main(int argc, char **argv) +{ + // CHECK: llvm.call @printf({{.*}}, {{.*}}) vararg(!llvm.func) : (!llvm.ptr, i32) -> i32 + printf("argc: %i\n", argc); +} diff --git a/test/vast/Conversion/select-b.c b/test/vast/Conversion/select-b.c index 08d4ee358a..fdba378a5f 100644 --- a/test/vast/Conversion/select-b.c +++ b/test/vast/Conversion/select-b.c @@ -1,7 +1,3 @@ -// RUN: %vast-front -vast-emit-mlir=llvm -vast-snapshot-at="vast-core-to-llvm;vast-irs-to-llvm" %s -// RUN: %file-check %s -input-file=$(basename %s .c).vast-irs-to-llvm -check-prefix=I_LLVM -// RUN: %file-check %s -input-file=$(basename %s .c).vast-core-to-llvm -check-prefix=C_LLVM - // RUN: %vast-front -vast-emit-mlir-after=vast-irs-to-llvm %s -o %t.mlir // RUN: %file-check --input-file=%t.mlir %s -check-prefix=I_LLVM @@ -14,7 +10,6 @@ int main(int argc, char** argv) { // I_LLVM: [[V6:%[0-9]+]] = core.lazy.op { // I_LLVM: llvm.call @foo() : () -> () - // I_LLVM: {{.*}} = llvm.mlir.zero : !llvm.void // I_LLVM: [[V11:%[0-9]+]] = llvm.mlir.zero : !llvm.void // I_LLVM: hl.value.yield [[V11]] : !llvm.void // I_LLVM: } : !llvm.void diff --git a/test/vast/Dialect/HighLevel/complex-a.c b/test/vast/Dialect/HighLevel/complex-a.c index c2e1238567..ca3a82ae8c 100644 --- a/test/vast/Dialect/HighLevel/complex-a.c +++ b/test/vast/Dialect/HighLevel/complex-a.c @@ -6,18 +6,18 @@ void fun() { // CHECK: @x : !hl.lvalue> // CHECK: hl.const.init : !hl.complex double complex x = I; -// CHECK: hl.fadd {{.*}} : (!hl.complex, !hl.complex) +// CHECK: hl.fadd {{.*}} : (!hl.complex, !hl.double) double complex y = x + 3; double c = 3; // CHECK: hl.fadd {{.*}} : (!hl.complex, !hl.double) double complex z = x + c; -// CHEKC: hl.real {{.*}} : {{.*}}!hl.complex{{.*}} -> !hl.double +// CHEKC: hl.real {{.*}} : !hl.lvalue> -> !hl.double double a = __real__ y; -// CHEKC: hl.imag {{.*}} : {{.*}}!hl.complex{{.*}} -> !hl.double +// CHEKC: hl.imag {{.*}} : !hl.lvalue> -> !hl.double double b = __imag__ y; -// CHEKC: hl.real {{.*}} : {{.*}}!hl.complex{{.*}} -> !hl.lvalue -// CHEKC: hl.imag {{.*}} : {{.*}}!hl.complex{{.*}} -> !hl.lvalue +// CHEKC: hl.real {{.*}} : !hl.lvalue> -> !hl.lvalue +// CHEKC: hl.imag {{.*}} : !hl.lvalue> -> !hl.lvalue __real__ x = 1; __imag__ x = 1; } diff --git a/test/vast/Dialect/HighLevel/complex-c.c b/test/vast/Dialect/HighLevel/complex-c.c index 68fd100951..d5cd1dfe4a 100644 --- a/test/vast/Dialect/HighLevel/complex-c.c +++ b/test/vast/Dialect/HighLevel/complex-c.c @@ -10,12 +10,12 @@ void fun(void) { // CHECK: hl.add {{.*}} : (!hl.complex, !hl.complex) int complex y = x + z; -// CHEKC: hl.real {{.*}} : {{.*}}!hl.complex{{.*}} -> !hl.int +// CHEKC: hl.real {{.*}} : !hl.lvalue> -> !hl.int int u = __real__ y; -// CHEKC: hl.imag {{.*}} : {{.*}}!hl.complex{{.*}} -> !hl.int +// CHEKC: hl.imag {{.*}} : !hl.lvalue> -> !hl.int int v = __imag__ y; -// CHEKC: hl.real {{.*}} : {{.*}}!hl.complex{{.*}} -> !hl.lvalue +// CHEKC: hl.real {{.*}} : !hl.lvalue> -> !hl.lvalue __real__ y = 5; -// CHEKC: hl.imag {{.*}} : {{.*}}!hl.complex{{.*}} -> !hl.lvalue +// CHEKC: hl.imag {{.*}} : !hl.lvalue> -> !hl.lvalue __imag__ y = 6; } diff --git a/tools/vast-repl/vast-repl.cpp b/tools/vast-repl/vast-repl.cpp index e4ff027105..69066b4eb3 100644 --- a/tools/vast-repl/vast-repl.cpp +++ b/tools/vast-repl/vast-repl.cpp @@ -59,7 +59,7 @@ namespace vast::repl while (!cli.exit()) { std::string cmd; - if (auto quit = linenoise::Readline("> ", cmd)) { + if (linenoise::Readline("> ", cmd)) { break; }