From c1d91c7c7e3dfa5f6c698c8591b2900d27203ab5 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Thu, 20 Apr 2023 10:36:28 +0100 Subject: [PATCH] Black everything. --- .../workflows/update_xdsl_pyodide_build.py | 5 +- bench/parser/bench_lexer.py | 29 +- bench/parser/bench_parser.py | 53 +- docs/Toy/toy/dialect.py | 157 ++- docs/Toy/toy/interpreter.py | 100 +- docs/Toy/toy/ir_gen.py | 119 +- docs/Toy/toy/lexer.py | 21 +- docs/Toy/toy/location.py | 4 +- docs/Toy/toy/parser.py | 174 +-- docs/Toy/toy/tests/test_interpreter.py | 26 +- docs/Toy/toy/tests/test_ir_gen.py | 23 +- docs/Toy/toy/tests/test_parser.py | 226 +++- docs/Toy/toy/toy_ast.py | 114 +- setup.py | 2 +- tests/conftest.py | 32 +- tests/dialects/test_affine.py | 55 +- tests/dialects/test_arith.py | 105 +- tests/dialects/test_builtin.py | 106 +- tests/dialects/test_func.py | 32 +- tests/dialects/test_gpu.py | 82 +- tests/dialects/test_irdl.py | 26 +- tests/dialects/test_llvm.py | 73 +- tests/dialects/test_memref.py | 83 +- tests/dialects/test_mpi.py | 9 +- tests/dialects/test_mpi_lowering.py | 166 ++- tests/dialects/test_pdl.py | 76 +- tests/dialects/test_scf.py | 18 +- tests/dialects/test_stencil.py | 3 +- tests/dialects/test_vector.py | 364 +++--- tests/filecheck/frontend/dialects/arith.py | 3 - tests/filecheck/frontend/dialects/builtin.py | 14 +- tests/filecheck/frontend/dialects/invalid.py | 5 - tests/filecheck/frontend/programs/invalid.py | 10 - tests/interpreters/test_pdl_interpreter.py | 95 +- .../immutable_ir/test_immutable_ir.py | 46 +- tests/test_attribute_definition.py | 110 +- tests/test_frontend_op_inserter.py | 13 +- tests/test_frontend_op_resolver.py | 20 +- tests/test_frontend_python_code_check.py | 56 +- tests/test_frontend_type_conversion.py | 11 +- tests/test_immutable_list.py | 2 +- tests/test_interpreter.py | 3 +- tests/test_ir.py | 126 +- tests/test_is_satisfying_hint.py | 22 +- tests/test_lexer.py | 216 +-- tests/test_mlctx.py | 30 +- tests/test_mlir_printer.py | 30 +- tests/test_op_builder.py | 11 +- tests/test_operation_builder.py | 101 +- tests/test_operation_definition.py | 38 +- tests/test_parser.py | 402 +++--- tests/test_parser_error.py | 28 +- tests/test_pattern_rewriter.py | 438 ++++--- tests/test_printer.py | 134 +- tests/test_pyrdl.py | 73 +- tests/test_rewriter.py | 104 +- tests/test_ssa_value.py | 19 +- tests/test_traits.py | 29 +- tests/xdsl_opt/test_xdsl_opt.py | 64 +- versioneer.py | 127 +- xdsl/_version.py | 44 +- xdsl/builder.py | 36 +- xdsl/dialects/affine.py | 59 +- xdsl/dialects/arith.py | 377 +++--- xdsl/dialects/builtin.py | 446 ++++--- xdsl/dialects/cf.py | 25 +- xdsl/dialects/cmath.py | 53 +- xdsl/dialects/experimental/math.py | 439 ++++--- xdsl/dialects/experimental/stencil.py | 243 ++-- xdsl/dialects/func.py | 92 +- xdsl/dialects/gpu.py | 282 ++-- xdsl/dialects/irdl.py | 28 +- xdsl/dialects/llvm.py | 274 ++-- xdsl/dialects/memref.py | 297 +++-- xdsl/dialects/mpi.py | 206 +-- xdsl/dialects/pdl.py | 266 ++-- xdsl/dialects/scf.py | 90 +- xdsl/dialects/test.py | 12 +- xdsl/dialects/vector.py | 119 +- xdsl/frontend/block.py | 11 +- xdsl/frontend/code_generation.py | 137 +- xdsl/frontend/const.py | 8 +- xdsl/frontend/context.py | 9 +- xdsl/frontend/dialects/builtin.py | 57 +- xdsl/frontend/exception.py | 8 +- xdsl/frontend/op_inserter.py | 3 +- xdsl/frontend/op_resolver.py | 34 +- xdsl/frontend/passes/desymref.py | 71 +- xdsl/frontend/program.py | 9 +- xdsl/frontend/python_code_check.py | 146 ++- xdsl/frontend/symref.py | 15 +- xdsl/frontend/type_conversion.py | 60 +- xdsl/interpreter.py | 118 +- xdsl/interpreters/experimental/pdl.py | 156 ++- xdsl/ir.py | 395 +++--- xdsl/irdl.py | 639 +++++---- xdsl/irdl_mlir_printer.py | 89 +- xdsl/parser.py | 1153 ++++++++++------- xdsl/passes.py | 3 +- xdsl/pattern_rewriter.py | 152 +-- xdsl/printer.py | 325 +++-- xdsl/rewriter.py | 33 +- .../immutable_ir/immutable_ir.py | 224 ++-- .../experimental/Apply1DMPIToStencil.py | 170 ++- .../experimental/ConvertStencilToLLMLIR.py | 215 +-- .../experimental/StencilShapeInference.py | 80 +- .../experimental/stencil_global_to_local.py | 206 +-- xdsl/transforms/lower_mpi.py | 500 +++---- xdsl/utils/deprecation.py | 8 +- xdsl/utils/diagnostic.py | 10 +- xdsl/utils/exceptions.py | 19 +- xdsl/utils/hints.py | 16 +- xdsl/utils/immutable_list.py | 5 +- xdsl/utils/lexer.py | 198 +-- xdsl/utils/test_value.py | 3 +- xdsl/xdsl_opt_main.py | 149 ++- 116 files changed, 7703 insertions(+), 5782 deletions(-) diff --git a/.github/workflows/update_xdsl_pyodide_build.py b/.github/workflows/update_xdsl_pyodide_build.py index e6df8b645e..3318ab32a0 100755 --- a/.github/workflows/update_xdsl_pyodide_build.py +++ b/.github/workflows/update_xdsl_pyodide_build.py @@ -25,9 +25,6 @@ sha256_hash.update(byte_block) # Make it build the local xDSL, not the PyPi release. The pyodide build still requires the SHA256 sum. -yaml_doc["source"] = { - "url": f"file://{xdsl_sdist}", - "sha256": sha256_hash.hexdigest() -} +yaml_doc["source"] = {"url": f"file://{xdsl_sdist}", "sha256": sha256_hash.hexdigest()} with open(meta_yaml_path, "w") as f: yaml.dump(yaml_doc, f) diff --git a/bench/parser/bench_lexer.py b/bench/parser/bench_lexer.py index 76b5a47c9c..4d8492c29e 100644 --- a/bench/parser/bench_lexer.py +++ b/bench/parser/bench_lexer.py @@ -29,8 +29,9 @@ def run_on_files(file_names: Iterable[str]): try: contents = open(file_name, "r").read() input = Input(contents, file_name) - file_time = timeit.timeit(lambda: lex_file(input), - number=args.num_iterations) + file_time = timeit.timeit( + lambda: lex_file(input), number=args.num_iterations + ) total_time += file_time / args.num_iterations print("Time taken: " + str(file_time)) except Exception as e: @@ -45,20 +46,22 @@ def run_on_files(file_names: Iterable[str]): arg_parser.add_argument( "root_directory", type=str, - help="Path to the root directory containing MLIR files.") - arg_parser.add_argument("--num_iterations", - type=int, - required=False, - default=1, - help="Number of times to lex each file.") - arg_parser.add_argument("--profile", - action="store_true", - help="Enable profiling metrics.") + help="Path to the root directory containing MLIR files.", + ) + arg_parser.add_argument( + "--num_iterations", + type=int, + required=False, + default=1, + help="Number of times to lex each file.", + ) + arg_parser.add_argument( + "--profile", action="store_true", help="Enable profiling metrics." + ) args = arg_parser.parse_args() - file_names = list( - glob.iglob(args.root_directory + "/**/*.mlir", recursive=True)) + file_names = list(glob.iglob(args.root_directory + "/**/*.mlir", recursive=True)) print("Found " + str(len(file_names)) + " files to lex.") if args.profile: diff --git a/bench/parser/bench_parser.py b/bench/parser/bench_parser.py index 85284127ee..054782e747 100644 --- a/bench/parser/bench_parser.py +++ b/bench/parser/bench_parser.py @@ -54,19 +54,20 @@ def run_on_files(file_names: Iterable[str], mlir_path: str, ctx: MLContext): # Parse each sub-file separately. for sub_contents in splitted_contents: - # First, parse the file with MLIR to check that it is valid, and # print it back in generic form. - res = subprocess.run([ - mlir_path, - "--allow-unregistered-dialect", - "-mlir-print-op-generic", - "-mlir-print-local-scope", - ], - input=sub_contents, - text=True, - capture_output=True, - timeout=60) + res = subprocess.run( + [ + mlir_path, + "--allow-unregistered-dialect", + "-mlir-print-op-generic", + "-mlir-print-local-scope", + ], + input=sub_contents, + text=True, + capture_output=True, + timeout=60, + ) if res.returncode != 0: continue n_total_files += 1 @@ -78,7 +79,8 @@ def run_on_files(file_names: Iterable[str], mlir_path: str, ctx: MLContext): try: file_time = timeit.timeit( lambda: parse_file(generic_sub_contents, ctx), - number=args.num_iterations) + number=args.num_iterations, + ) total_time += file_time / args.num_iterations print("Time taken: " + str(file_time)) n_parsed_files += 1 @@ -96,27 +98,30 @@ def run_on_files(file_names: Iterable[str], mlir_path: str, ctx: MLContext): arg_parser.add_argument( "root_directory", type=str, - help="Path to the root directory containing MLIR files.") + help="Path to the root directory containing MLIR files.", + ) arg_parser.add_argument("--mlir-path", type=str, help="Path to mlir-opt.") - arg_parser.add_argument("--num_iterations", - type=int, - required=False, - default=1, - help="Number of times to parse each file.") - arg_parser.add_argument("--profile", - action="store_true", - help="Enable profiling metrics.") + arg_parser.add_argument( + "--num_iterations", + type=int, + required=False, + default=1, + help="Number of times to parse each file.", + ) + arg_parser.add_argument( + "--profile", action="store_true", help="Enable profiling metrics." + ) arg_parser.add_argument( "--timeout", type=int, required=False, default=60, - help="Timeout for processing each sub-program with MLIR. (in seconds)") + help="Timeout for processing each sub-program with MLIR. (in seconds)", + ) args = arg_parser.parse_args() - file_names = list( - glob.iglob(args.root_directory + "/**/*.mlir", recursive=True)) + file_names = list(glob.iglob(args.root_directory + "/**/*.mlir", recursive=True)) print("Found " + str(len(file_names)) + " files to parse.") ctx = MLContext() diff --git a/docs/Toy/toy/dialect.py b/docs/Toy/toy/dialect.py index 0e79e74303..1256185846 100644 --- a/docs/Toy/toy/dialect.py +++ b/docs/Toy/toy/dialect.py @@ -6,12 +6,28 @@ from typing import Annotated, TypeAlias, cast -from xdsl.ir import (Dialect, SSAValue, Attribute, Block, Region, OpResult) -from xdsl.dialects.builtin import (Float64Type, FunctionType, SymbolRefAttr, - TensorType, UnrankedTensorType, f64, - DenseIntOrFPElementsAttr, StringAttr) -from xdsl.irdl import (OpAttr, Operand, OptOpAttr, OptOperand, VarOpResult, - VarOperand, irdl_op_definition, AnyAttr, IRDLOperation) +from xdsl.ir import Dialect, SSAValue, Attribute, Block, Region, OpResult +from xdsl.dialects.builtin import ( + Float64Type, + FunctionType, + SymbolRefAttr, + TensorType, + UnrankedTensorType, + f64, + DenseIntOrFPElementsAttr, + StringAttr, +) +from xdsl.irdl import ( + OpAttr, + Operand, + OptOpAttr, + OptOperand, + VarOpResult, + VarOperand, + irdl_op_definition, + AnyAttr, + IRDLOperation, +) from xdsl.utils.exceptions import VerifyException from xdsl.utils.hints import isa @@ -32,13 +48,13 @@ class ConstantOp(IRDLOperation): : tensor<2x3xf64> ``` """ + name: str = "toy.constant" value: OpAttr[DenseIntOrFPElementsAttr] res: Annotated[OpResult, TensorTypeF64] def __init__(self, value: DenseIntOrFPElementsAttr): - super().__init__(result_types=[value.type], - attributes={"value": value}) + super().__init__(result_types=[value.type], attributes={"value": value}) @staticmethod def from_list(data: list[float], shape: list[int]) -> ConstantOp: @@ -49,7 +65,8 @@ def verify_(self) -> None: if not self.res.typ == self.value.type: raise VerifyException( "Expected value and result types to be equal: " - f"{self.res.typ}, {self.value.type}") + f"{self.res.typ}, {self.value.type}" + ) def get_type(self) -> TensorTypeF64: # Constant cannot be unranked @@ -68,7 +85,8 @@ class AddOp(IRDLOperation): The "add" operation performs element-wise addition between two tensors. The shapes of the tensor operands are expected to match. """ - name: str = 'toy.add' + + name: str = "toy.add" lhs: Annotated[Operand, AnyTensorTypeF64] rhs: Annotated[Operand, AnyTensorTypeF64] res: Annotated[OpResult, AnyTensorTypeF64] @@ -92,7 +110,8 @@ def verify_(self): else: if shape != arg.typ.shape: raise VerifyException( - "Expected AddOp args to have the same shape") + "Expected AddOp args to have the same shape" + ) @irdl_op_definition @@ -112,18 +131,16 @@ class FuncOp(IRDLOperation): } ``` """ - name: str = 'toy.func' + + name: str = "toy.func" body: Region sym_name: OpAttr[StringAttr] function_type: OpAttr[FunctionType] sym_visibility: OptOpAttr[StringAttr] - def __init__(self, - name: str, - ftype: FunctionType, - region: Region, - /, - private: bool = False): + def __init__( + self, name: str, ftype: FunctionType, region: Region, /, private: bool = False + ): attributes: dict[str, Attribute] = { "sym_name": StringAttr(name), "function_type": ftype, @@ -134,17 +151,21 @@ def __init__(self, return super().__init__(attributes=attributes, regions=[region]) @staticmethod - def from_callable(name: str, - input_types: list[Attribute], - return_types: list[Attribute], - func: Block.BlockCallback, - /, - private: bool = False): + def from_callable( + name: str, + input_types: list[Attribute], + return_types: list[Attribute], + func: Block.BlockCallback, + /, + private: bool = False, + ): ftype = FunctionType.from_lists(input_types, return_types) - return FuncOp(name, - ftype, - Region([Block.from_callable(input_types, func)]), - private=private) + return FuncOp( + name, + ftype, + Region([Block.from_callable(input_types, func)]), + private=private, + ) def verify_(self): # Check that the returned value matches the type of the function @@ -159,8 +180,7 @@ def verify_(self): last_op = block.last_op if not isinstance(last_op, ReturnOp): - raise VerifyException( - "Expected last op of FuncOp to be a ReturnOp") + raise VerifyException("Expected last op of FuncOp to be a ReturnOp") operand = last_op.input operand_typ = None if operand is None else operand.typ @@ -172,13 +192,15 @@ def verify_(self): return_typ = return_typs[0] else: raise VerifyException( - "Expected return type of func to have 0 or 1 values") + "Expected return type of func to have 0 or 1 values" + ) else: return_typ = None if operand_typ != return_typ: raise VerifyException( - "Expected return value to match return type of function") + "Expected return value to match return type of function" + ) @irdl_op_definition @@ -190,15 +212,20 @@ class GenericCallOp(IRDLOperation): # Note: naming this results triggers an ArgumentError res: Annotated[VarOpResult, AnyTensorTypeF64] - def __init__(self, callee: str | SymbolRefAttr, - operands: list[SSAValue | OpResult], - return_types: list[Attribute]): + def __init__( + self, + callee: str | SymbolRefAttr, + operands: list[SSAValue | OpResult], + return_types: list[Attribute], + ): if isinstance(callee, str): callee = SymbolRefAttr(callee) - return super().__init__(operands=[operands], - result_types=[return_types], - attributes={"callee": callee}) + return super().__init__( + operands=[operands], + result_types=[return_types], + attributes={"callee": callee}, + ) @irdl_op_definition @@ -207,7 +234,8 @@ class MulOp(IRDLOperation): The "mul" operation performs element-wise multiplication between two tensors. The shapes of the tensor operands are expected to match. """ - name: str = 'toy.mul' + + name: str = "toy.mul" lhs: Annotated[Operand, AnyTensorTypeF64] rhs: Annotated[Operand, AnyTensorTypeF64] res: Annotated[OpResult, AnyTensorTypeF64] @@ -231,7 +259,8 @@ def verify_(self): else: if shape != arg.typ.shape: raise VerifyException( - "Expected MulOp args to have the same shape") + "Expected MulOp args to have the same shape" + ) @irdl_op_definition @@ -240,7 +269,8 @@ class PrintOp(IRDLOperation): The "print" builtin operation prints a given input tensor, and produces no results. """ - name: str = 'toy.print' + + name: str = "toy.print" input: Annotated[Operand, AnyAttr()] def __init__(self, input: SSAValue): @@ -262,7 +292,8 @@ class ReturnOp(IRDLOperation): } ``` """ - name: str = 'toy.return' + + name: str = "toy.return" input: Annotated[OptOperand, AnyTensorTypeF64] def __init__(self, input: SSAValue | None = None): @@ -279,7 +310,8 @@ class ReshapeOp(IRDLOperation): %0 = toy.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64> ``` """ - name: str = 'toy.reshape' + + name: str = "toy.reshape" arg: Annotated[Operand, AnyTensorTypeF64] # We expect that the reshape operation returns a statically shaped tensor. res: Annotated[OpResult, TensorTypeF64] @@ -287,7 +319,7 @@ class ReshapeOp(IRDLOperation): def __init__(self, arg: SSAValue, shape: list[int]): if not isa(arg.typ, AnyTensorTypeF64): raise ValueError( - f'Unexpected arg of type {arg.typ} passed to ReshapeOp, expected {AnyTensorTypeF64}' + f"Unexpected arg of type {arg.typ} passed to ReshapeOp, expected {AnyTensorTypeF64}" ) element_type = arg.typ.element_type t = TensorTypeF64.from_type_and_list(element_type, shape) @@ -297,7 +329,7 @@ def __init__(self, arg: SSAValue, shape: list[int]): def from_input_and_type(arg: SSAValue, t: TensorTypeF64) -> ReshapeOp: if not isa(arg.typ, AnyTensorTypeF64): raise ValueError( - f'Unexpected arg of type {arg.typ} passed to ReshapeOp, expected {AnyTensorTypeF64}' + f"Unexpected arg of type {arg.typ} passed to ReshapeOp, expected {AnyTensorTypeF64}" ) return ReshapeOp.create(result_types=[t], operands=[arg]) @@ -305,13 +337,12 @@ def verify_(self): result_type = self.res.typ assert isa(result_type, TensorTypeF64) if not len(result_type.shape.data): - raise VerifyException( - 'Reshape operation result shape should be defined') + raise VerifyException("Reshape operation result shape should be defined") @irdl_op_definition class TransposeOp(IRDLOperation): - name: str = 'toy.transpose' + name: str = "toy.transpose" arguments: Annotated[Operand, AnyTensorTypeF64] res: Annotated[OpResult, AnyTensorTypeF64] @@ -320,25 +351,29 @@ def __init__(self, input: SSAValue): if isa(input.typ, TensorTypeF64): element_type = input.typ.element_type output_type = TensorType.from_type_and_list( - element_type, list(reversed(input.typ.get_shape()))) + element_type, list(reversed(input.typ.get_shape())) + ) else: if not isa(input.typ, UnrankedTensorTypeF64): raise ValueError( - f'Unexpected arg of type {input.typ} passed to TransposeOp, expected {TensorTypeF64 | UnrankedTensorTypeF64}' + f"Unexpected arg of type {input.typ} passed to TransposeOp, expected {TensorTypeF64 | UnrankedTensorTypeF64}" ) output_type = input.typ super().__init__(operands=[input], result_types=[output_type]) -Toy = Dialect([ - ConstantOp, - AddOp, - FuncOp, - GenericCallOp, - PrintOp, - MulOp, - ReturnOp, - ReshapeOp, - TransposeOp, -], []) +Toy = Dialect( + [ + ConstantOp, + AddOp, + FuncOp, + GenericCallOp, + PrintOp, + MulOp, + ReturnOp, + ReshapeOp, + TransposeOp, + ], + [], +) diff --git a/docs/Toy/toy/interpreter.py b/docs/Toy/toy/interpreter.py index de5b44db1f..a021a49408 100644 --- a/docs/Toy/toy/interpreter.py +++ b/docs/Toy/toy/interpreter.py @@ -7,8 +7,7 @@ from dataclasses import dataclass from xdsl.dialects.builtin import TensorType, VectorType, ModuleOp -from xdsl.interpreter import (Interpreter, InterpreterFunctions, - register_impls, impl) +from xdsl.interpreter import Interpreter, InterpreterFunctions, register_impls, impl from xdsl.utils.exceptions import InterpretationError from . import dialect as toy @@ -20,45 +19,45 @@ class Tensor: shape: list[int] def __format__(self, __format_spec: str) -> str: - prod_shapes: list[int] = list( - accumulate(reversed(self.shape), operator.mul)) + prod_shapes: list[int] = list(accumulate(reversed(self.shape), operator.mul)) assert prod_shapes[-1] == len(self.data) - result = '[' * len(self.shape) + result = "[" * len(self.shape) for i, d in enumerate(self.data): if i: n = sum(not i % p for p in prod_shapes) - result += ']' * n - result += ', ' - result += '[' * n - result += f'{d}' + result += "]" * n + result += ", " + result += "[" * n + result += f"{d}" - result += ']' * len(self.shape) + result += "]" * len(self.shape) return result @register_impls class ToyFunctions(InterpreterFunctions): - - def run_toy_func(self, interpreter: Interpreter, name: str, - args: tuple[Any, ...]) -> tuple[Any, ...]: + def run_toy_func( + self, interpreter: Interpreter, name: str, args: tuple[Any, ...] + ) -> tuple[Any, ...]: for op in interpreter.module.ops: if isinstance(op, toy.FuncOp) and op.sym_name.data == name: return self.run_func(interpreter, op, args) - raise InterpretationError( - f'Could not find toy function with name: {name}') + raise InterpretationError(f"Could not find toy function with name: {name}") @impl(toy.PrintOp) - def run_print(self, interpreter: Interpreter, op: toy.PrintOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: - interpreter.print(f'{args[0]}') + def run_print( + self, interpreter: Interpreter, op: toy.PrintOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + interpreter.print(f"{args[0]}") return () @impl(toy.FuncOp) - def run_func(self, interpreter: Interpreter, op: toy.FuncOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: - interpreter.push_scope(f'ctx_{op.sym_name.data}') + def run_func( + self, interpreter: Interpreter, op: toy.FuncOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + interpreter.push_scope(f"ctx_{op.sym_name.data}") block = op.body.blocks[0] interpreter.set_values(zip(block.args, args)) for body_op in block.ops: @@ -70,60 +69,67 @@ def run_func(self, interpreter: Interpreter, op: toy.FuncOp, return results @impl(toy.ConstantOp) - def run_const(self, interpreter: Interpreter, op: toy.ConstantOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: + def run_const( + self, interpreter: Interpreter, op: toy.ConstantOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: assert not len(args) data = op.get_data() shape = op.get_shape() result = Tensor(data, shape) - return result, + return (result,) @impl(toy.ReshapeOp) - def run_reshape(self, interpreter: Interpreter, op: toy.ReshapeOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: - arg, = args + def run_reshape( + self, interpreter: Interpreter, op: toy.ReshapeOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + (arg,) = args assert isinstance(arg, Tensor) result_typ = op.results[0].typ assert isinstance(result_typ, VectorType | TensorType) new_shape = result_typ.get_shape() - return Tensor(arg.data, new_shape), + return (Tensor(arg.data, new_shape),) @impl(toy.AddOp) - def run_add(self, interpreter: Interpreter, op: toy.AddOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: + def run_add( + self, interpreter: Interpreter, op: toy.AddOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: lhs, rhs = args assert isinstance(lhs, Tensor) assert isinstance(rhs, Tensor) assert lhs.shape == rhs.shape - return Tensor([l + r for l, r in zip(lhs.data, rhs.data)], lhs.shape), + return (Tensor([l + r for l, r in zip(lhs.data, rhs.data)], lhs.shape),) @impl(toy.MulOp) - def run_mul(self, interpreter: Interpreter, op: toy.MulOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: + def run_mul( + self, interpreter: Interpreter, op: toy.MulOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: lhs, rhs = args assert isinstance(lhs, Tensor) assert isinstance(rhs, Tensor) assert lhs.shape == rhs.shape - return Tensor([l * r for l, r in zip(lhs.data, rhs.data)], lhs.shape), + return (Tensor([l * r for l, r in zip(lhs.data, rhs.data)], lhs.shape),) @impl(toy.ReturnOp) - def run_return(self, interpreter: Interpreter, op: toy.ReturnOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: + def run_return( + self, interpreter: Interpreter, op: toy.ReturnOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: assert len(args) < 2 return () @impl(toy.GenericCallOp) - def run_generic_call(self, interpreter: Interpreter, op: toy.GenericCallOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: + def run_generic_call( + self, interpreter: Interpreter, op: toy.GenericCallOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: return self.run_toy_func(interpreter, op.callee.string_value(), args) @impl(toy.TransposeOp) - def run_transpose(self, interpreter: Interpreter, op: toy.TransposeOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: - arg, = args + def run_transpose( + self, interpreter: Interpreter, op: toy.TransposeOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + (arg,) = args assert isinstance(arg, Tensor) assert len(arg.shape) == 2 @@ -131,15 +137,15 @@ def run_transpose(self, interpreter: Interpreter, op: toy.TransposeOp, rows = arg.shape[1] new_data = [ - arg.data[row * cols + col] for col in range(cols) - for row in range(rows) + arg.data[row * cols + col] for col in range(cols) for row in range(rows) ] result = Tensor(new_data, arg.shape[::-1]) - return result, + return (result,) @impl(ModuleOp) - def run_module(self, interpreter: Interpreter, op: ModuleOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: - return self.run_toy_func(interpreter, 'main', args) + def run_module( + self, interpreter: Interpreter, op: ModuleOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: + return self.run_toy_func(interpreter, "main", args) diff --git a/docs/Toy/toy/ir_gen.py b/docs/Toy/toy/ir_gen.py index 49c419c47e..ade029680c 100644 --- a/docs/Toy/toy/ir_gen.py +++ b/docs/Toy/toy/ir_gen.py @@ -8,14 +8,35 @@ from xdsl.builder import Builder from toy.location import Location -from toy.toy_ast import (LiteralExprAST, ModuleAST, NumberExprAST, - PrototypeAST, VariableExprAST, VarDeclExprAST, - ReturnExprAST, PrintExprAST, FunctionAST, ExprAST, - CallExprAST, BinaryExprAST) - -from .dialect import (TensorTypeF64, UnrankedTensorTypeF64, AddOp, MulOp, - FuncOp, FunctionType, ReturnOp, ConstantOp, - GenericCallOp, TransposeOp, ReshapeOp, PrintOp) +from toy.toy_ast import ( + LiteralExprAST, + ModuleAST, + NumberExprAST, + PrototypeAST, + VariableExprAST, + VarDeclExprAST, + ReturnExprAST, + PrintExprAST, + FunctionAST, + ExprAST, + CallExprAST, + BinaryExprAST, +) + +from .dialect import ( + TensorTypeF64, + UnrankedTensorTypeF64, + AddOp, + MulOp, + FuncOp, + FunctionType, + ReturnOp, + ConstantOp, + GenericCallOp, + TransposeOp, + ReshapeOp, + PrintOp, +) class IRGenError(Exception): @@ -24,7 +45,7 @@ class IRGenError(Exception): @dataclass class ScopedSymbolTable: - 'A mapping from variable names to SSAValues, append-only' + "A mapping from variable names to SSAValues, append-only" table: dict[str, SSAValue] = field(default_factory=dict) def __contains__(self, __o: object) -> bool: @@ -35,8 +56,7 @@ def __getitem__(self, __key: str) -> SSAValue: def __setitem__(self, __key: str, __value: SSAValue) -> None: if __key in self: - raise AssertionError( - f'Cannot add value for key {__key} in scope {self}') + raise AssertionError(f"Cannot add value for key {__key} in scope {self}") self.table[__key] = __value @@ -86,13 +106,13 @@ def ir_gen_module(self, module_ast: ModuleAST) -> ModuleOp: try: self.module.verify() except Exception: - print('module verification error') + print("module verification error") raise return self.module def loc(self, loc: Location): - 'Helper conversion for a Toy AST location to an MLIR location.' + "Helper conversion for a Toy AST location to an MLIR location." # TODO: Need location support in xDSL # return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line, loc.col); pass @@ -107,9 +127,8 @@ def declare(self, var: str, value: SSAValue) -> bool: self.symbol_table[var] = value return True - def get_type(self, - shape: list[int]) -> TensorTypeF64 | UnrankedTensorTypeF64: - 'Build a tensor type from a list of shape dimensions.' + def get_type(self, shape: list[int]) -> TensorTypeF64 | UnrankedTensorTypeF64: + "Build a tensor type from a list of shape dimensions." # If the shape is empty, then this type is unranked. if len(shape): return TensorType.from_type_and_list(f64, shape) @@ -125,11 +144,12 @@ def ir_gen_proto(self, proto_ast: PrototypeAST) -> FuncOp: # This is a generic function, the return type will be inferred later. # Arguments type are uniformly unranked tensors. func_type = FunctionType.from_lists( - [self.get_type([])] * len(proto_ast.args), [self.get_type([])]) + [self.get_type([])] * len(proto_ast.args), [self.get_type([])] + ) return self.builder.insert(FuncOp(proto_ast.name, func_type, Region())) def ir_gen_function(self, function_ast: FunctionAST) -> FuncOp: - 'Emit a new function and add it to the MLIR module.' + "Emit a new function and add it to the MLIR module." # keep builder for later parent_builder = self.builder @@ -140,9 +160,11 @@ def ir_gen_function(self, function_ast: FunctionAST) -> FuncOp: proto_args = function_ast.proto.args # Create the block for the current function - block = Block(arg_types=[ - UnrankedTensorType.from_type(f64) for _ in range(len(proto_args)) - ]) + block = Block( + arg_types=[ + UnrankedTensorType.from_type(f64) for _ in range(len(proto_args)) + ] + ) self.builder = Builder(block) # Declare all the function arguments in the symbol table. @@ -166,29 +188,25 @@ def ir_gen_function(self, function_ast: FunctionAST) -> FuncOp: if return_op is None: self.builder.insert(ReturnOp()) - input_types = [ - self.get_type([]) for _ in range(len(function_ast.proto.args)) - ] + input_types = [self.get_type([]) for _ in range(len(function_ast.proto.args))] func_type = FunctionType.from_lists(input_types, return_types) # main should be public, all the others private - private = function_ast.proto.name != 'main' + private = function_ast.proto.name != "main" # clean up self.symbol_table = None self.builder = parent_builder func = self.builder.insert( - FuncOp(function_ast.proto.name, - func_type, - Region(block), - private=private)) + FuncOp(function_ast.proto.name, func_type, Region(block), private=private) + ) return func def ir_gen_binary_expr(self, binop: BinaryExprAST) -> SSAValue: - 'Emit a binary operation' + "Emit a binary operation" # First emit the operations for each side of the operation before emitting # the operation itself. For example if the expression is `a + foo(a)` @@ -208,12 +226,12 @@ def ir_gen_binary_expr(self, binop: BinaryExprAST) -> SSAValue: # Derive the operation name from the binary operator. At the moment we only # support '+' and '*'. - if binop.op == '+': + if binop.op == "+": op = self.builder.insert(AddOp(lhs, rhs)) - elif binop.op == '*': + elif binop.op == "*": op = self.builder.insert(MulOp(lhs, rhs)) else: - self.error(f'Unsupported binary operation `{binop.op}`') + self.error(f"Unsupported binary operation `{binop.op}`") return op.res @@ -227,10 +245,10 @@ def ir_gen_variable_expr(self, expr: VariableExprAST) -> SSAValue: variable = self.symbol_table[expr.name] return variable except Exception as e: - self.error(f'error: unknown variable `{expr.name}`', e) + self.error(f"error: unknown variable `{expr.name}`", e) def ir_gen_return_expr(self, ret: ReturnExprAST): - 'Emit a return operation. This will return failure if any generation fails.' + "Emit a return operation. This will return failure if any generation fails." # location = self.loc(binop.loc) @@ -287,8 +305,10 @@ def collect_data(self, expr: ExprAST) -> list[float]: elif isinstance(expr, NumberExprAST): return [expr.val] else: - self.error(f'Unsupported expr ({expr}) of type ({type(expr)}), ' - 'expected literal or number expr') + self.error( + f"Unsupported expr ({expr}) of type ({type(expr)}), " + "expected literal or number expr" + ) def ir_gen_call_expr(self, call: CallExprAST) -> SSAValue: """ @@ -304,10 +324,12 @@ def ir_gen_call_expr(self, call: CallExprAST) -> SSAValue: # Builtin calls have their custom operation, meaning this is a # straightforward emission. - if callee == 'transpose': + if callee == "transpose": if len(operands) != 1: - self.error("MLIR codegen encountered an error: toy.transpose " - "does not accept multiple arguments") + self.error( + "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments" + ) op = self.builder.insert(TransposeOp(operands[0])) return op.res @@ -315,8 +337,8 @@ def ir_gen_call_expr(self, call: CallExprAST) -> SSAValue: # user-defined functions are mapped to a custom call that takes the callee # name as an attribute. op = self.builder.insert( - GenericCallOp(callee, operands, - [UnrankedTensorTypeF64.from_type(f64)])) + GenericCallOp(callee, operands, [UnrankedTensorTypeF64.from_type(f64)]) + ) return op.res[0] @@ -329,13 +351,13 @@ def ir_gen_print_expr(self, call: PrintExprAST): self.builder.insert(PrintOp(arg)) def ir_gen_number_expr(self, num: NumberExprAST) -> SSAValue: - 'Emit a constant for a single number' + "Emit a constant for a single number" constant_op = self.builder.insert(ConstantOp.from_list([num.val], [])) return constant_op.res def ir_gen_expr(self, expr: ExprAST) -> SSAValue: - 'Dispatch codegen for the right expression subclass using RTTI.' + "Dispatch codegen for the right expression subclass using RTTI." if isinstance(expr, BinaryExprAST): return self.ir_gen_binary_expr(expr) @@ -348,9 +370,7 @@ def ir_gen_expr(self, expr: ExprAST) -> SSAValue: if isinstance(expr, NumberExprAST): return self.ir_gen_number_expr(expr) else: - self.error( - f"MLIR codegen encountered an unhandled expr kind '{expr.kind}'" - ) + self.error(f"MLIR codegen encountered an unhandled expr kind '{expr.kind}'") def ir_gen_var_decl_expr(self, vardecl: VarDeclExprAST) -> SSAValue: """ @@ -366,8 +386,7 @@ def ir_gen_var_decl_expr(self, vardecl: VarDeclExprAST) -> SSAValue: # with specific shape, we emit a "reshape" operation. It will get # optimized out later as needed. if len(vardecl.varType.shape): - reshape_op = self.builder.insert( - ReshapeOp(value, vardecl.varType.shape)) + reshape_op = self.builder.insert(ReshapeOp(value, vardecl.varType.shape)) value = reshape_op.res @@ -377,7 +396,7 @@ def ir_gen_var_decl_expr(self, vardecl: VarDeclExprAST) -> SSAValue: return value def ir_gen_expr_list(self, exprs: Iterable[ExprAST]) -> None: - 'Codegen a list of expressions, raise error if one of them hit an error.' + "Codegen a list of expressions, raise error if one of them hit an error." assert self.symbol_table is not None for expr in exprs: diff --git a/docs/Toy/toy/lexer.py b/docs/Toy/toy/lexer.py index ef1cfc113f..891785279c 100644 --- a/docs/Toy/toy/lexer.py +++ b/docs/Toy/toy/lexer.py @@ -54,19 +54,19 @@ class EOFToken(Token): pass -IDENTIFIER_CHARS = re.compile(r'[\w]|[\d]|_') -OPERATOR_CHARS = set('+-*/') -SPECIAL_CHARS = set('<>}{(),;=[]') +IDENTIFIER_CHARS = re.compile(r"[\w]|[\d]|_") +OPERATOR_CHARS = set("+-*/") +SPECIAL_CHARS = set("<>}{(),;=[]") def tokenize(file: Path, program: str | None = None): tokens: List[Token] = [] if program is None: - with open(file, 'r') as f: + with open(file, "r") as f: program = f.read() - text = '' + text = "" row = col = 1 def flush(): @@ -83,7 +83,7 @@ def flush(): else: tokens.append(IdentifierToken(file, row, true_col, text)) - text = '' + text = "" for row, line in enumerate(program.splitlines()): # 1-indexed @@ -91,7 +91,7 @@ def flush(): for col, char in enumerate(line): # 1-indexed col += 1 - if char == '#': + if char == "#": # Comment break @@ -101,7 +101,7 @@ def flush(): flush() - if char == ' ': + if char == " ": continue if char in OPERATOR_CHARS: @@ -111,12 +111,11 @@ def flush(): tokens.append(SpecialToken(file, row, col, char)) continue - raise AssertionError( - f'unhandled char {char} at ({row}, {col}) in \n{line}') + raise AssertionError(f"unhandled char {char} at ({row}, {col}) in \n{line}") col += 1 flush() - tokens.append(EOFToken(file, row, col, '')) + tokens.append(EOFToken(file, row, col, "")) return tokens diff --git a/docs/Toy/toy/location.py b/docs/Toy/toy/location.py index 2981c4c3da..8281f83d1b 100644 --- a/docs/Toy/toy/location.py +++ b/docs/Toy/toy/location.py @@ -4,10 +4,10 @@ @dataclass class Location: - 'Structure definition a location in a file.' + "Structure definition a location in a file." file: Path line: int col: int def __repr__(self): - return f'{self.file}:{self.line}:{self.col}' + return f"{self.file}:{self.line}:{self.col}" diff --git a/docs/Toy/toy/parser.py b/docs/Toy/toy/parser.py index 32aa9ae9bc..3b06fecddc 100644 --- a/docs/Toy/toy/parser.py +++ b/docs/Toy/toy/parser.py @@ -1,43 +1,61 @@ from pathlib import Path from typing import List, NoReturn, TypeVar, cast -from .lexer import (OperatorToken, Token, tokenize, NumberToken, EOFToken, - IdentifierToken) -from .toy_ast import (BinaryExprAST, ExprAST, NumberExprAST, FunctionAST, - ModuleAST, LiteralExprAST, VarDeclExprAST, VarType, - VariableExprAST, ReturnExprAST, PrintExprAST, - CallExprAST, PrototypeAST) +from .lexer import ( + OperatorToken, + Token, + tokenize, + NumberToken, + EOFToken, + IdentifierToken, +) +from .toy_ast import ( + BinaryExprAST, + ExprAST, + NumberExprAST, + FunctionAST, + ModuleAST, + LiteralExprAST, + VarDeclExprAST, + VarType, + VariableExprAST, + ReturnExprAST, + PrintExprAST, + CallExprAST, + PrototypeAST, +) class ParseError(Exception): - - def __init__(self, - token: Token, - expected: str | type[Token], - context: str = '', - line: str = ''): + def __init__( + self, + token: Token, + expected: str | type[Token], + context: str = "", + line: str = "", + ): loc = token.loc - message = f'Parse error ({loc.line}, {loc.col}): expected ' + message = f"Parse error ({loc.line}, {loc.col}): expected " if isinstance(expected, str): message += expected else: message += expected.name() if len(context): - message += ' ' + context + message += " " + context message += f" but has Token '{token.text}'\n" if len(line): - message += line + '\n' + message += line + "\n" super().__init__(message) pass -TokenT = TypeVar('TokenT', bound=Token) +TokenT = TypeVar("TokenT", bound=Token) class Parser: @@ -59,9 +77,9 @@ def getToken(self): def getTokenPrecedence(self) -> int: """Returns precedence if the current token is a binary operation, -1 otherwise""" PRECEDENCE = { - '-': 20, - '+': 20, - '*': 40, + "-": 20, + "+": 20, + "*": 40, } op = self.getToken().text @@ -76,9 +94,9 @@ def peek(self, pattern: str | type[Token] | None = None) -> None: raises ParseError otherwise """ token = self.getToken() - tokenType, text = (None, - pattern) if isinstance(pattern, str) else (pattern, - None) + tokenType, text = ( + (None, pattern) if isinstance(pattern, str) else (pattern, None) + ) if tokenType is not None: if type(token) is not tokenType: self.parseError(tokenType) @@ -140,11 +158,11 @@ def parseReturn(self): Parse a return statement. return :== return ; | return expr ; """ - returnToken = self.pop_pattern('return') + returnToken = self.pop_pattern("return") expr = None # Return takes an optional argument - if not self.check(';'): + if not self.check(";"): expr = self.parseExpression() return ReturnExprAST(returnToken.loc, expr) @@ -163,7 +181,7 @@ def parseTensorLiteralExpr(self): tensorLiteral ::= [ literalList ] | number literalList ::= tensorLiteral | tensorLiteral, literalList """ - openBracket = self.pop_pattern('[') + openBracket = self.pop_pattern("[") # Hold the list of values at this nesting level. values: List[LiteralExprAST | NumberExprAST] = [] @@ -172,21 +190,21 @@ def parseTensorLiteralExpr(self): while True: # We can have either another nested array or a number literal. - if self.check('['): + if self.check("["): values.append(self.parseTensorLiteralExpr()) else: if not self.check(NumberToken): - self.parseError(' or [', 'in literal expression') + self.parseError(" or [", "in literal expression") values.append(self.parseNumberExpr()) # End of this list on ']' - if self.check(']'): + if self.check("]"): break # Elements are separated by a comma. - self.pop_pattern(',') + self.pop_pattern(",") - self.pop_pattern(']') + self.pop_pattern("]") # Fill in the dimensions now. First the current nesting level: dims.append(len(values)) @@ -196,15 +214,17 @@ def parseTensorLiteralExpr(self): if any(type(val) is LiteralExprAST for val in values): allTensors = all(type(val) is LiteralExprAST for val in values) if not allTensors: - self.parseError('uniform well-nested dimensions', - 'inside literal expression') + self.parseError( + "uniform well-nested dimensions", "inside literal expression" + ) tensor_values = cast(List[LiteralExprAST], values) first = tensor_values[0].dims allEqual = all(val.dims == first for val in tensor_values) if not allEqual: - self.parseError('uniform well-nested dimensions', - 'inside literal expression') + self.parseError( + "uniform well-nested dimensions", "inside literal expression" + ) dims += first @@ -212,9 +232,9 @@ def parseTensorLiteralExpr(self): def parseParenExpr(self) -> ExprAST: "parenexpr ::= '(' expression ')'" - self.pop_pattern('(') + self.pop_pattern("(") v = self.parseExpression() - self.pop_pattern(')') + self.pop_pattern(")") return v def parseIdentifierExpr(self): @@ -224,24 +244,24 @@ def parseIdentifierExpr(self): ::= identifier '(' expression ')' """ name = self.pop_token(IdentifierToken) - if not self.check('('): + if not self.check("("): # Simple variable ref. return VariableExprAST(name.loc, name.text) # This is a function call. - self.pop_pattern('(') + self.pop_pattern("(") args: List[ExprAST] = [] while True: args.append(self.parseExpression()) - if self.check(')'): + if self.check(")"): break - self.pop_pattern(',') - self.pop_pattern(')') + self.pop_pattern(",") + self.pop_pattern(")") - if name.text == 'print': + if name.text == "print": # It can be a builtin call to print if len(args) != 1: - self.parseError('', 'as argument to print()') + self.parseError("", "as argument to print()") return PrintExprAST(name.loc, args[0]) @@ -260,16 +280,16 @@ def parsePrimary(self) -> ExprAST | None: return self.parseIdentifierExpr() elif isinstance(current, NumberToken): return self.parseNumberExpr() - elif current.text == '(': + elif current.text == "(": return self.parseParenExpr() - elif current.text == '[': + elif current.text == "[": return self.parseTensorLiteralExpr() - elif current.text == ';': + elif current.text == ";": return None - elif current.text == '}': + elif current.text == "}": return None else: - self.parseError('expression or one of `;`, `}`') + self.parseError("expression or one of `;`, `}`") def parsePrimaryNotNone(self) -> ExprAST: """ @@ -284,12 +304,12 @@ def parsePrimaryNotNone(self) -> ExprAST: return self.parseIdentifierExpr() elif isinstance(current, NumberToken): return self.parseNumberExpr() - elif current.text == '(': + elif current.text == "(": return self.parseParenExpr() - elif current.text == '[': + elif current.text == "[": return self.parseTensorLiteralExpr() else: - self.parseError('expression') + self.parseError("expression") def parseBinOpRHS(self, exprPrec: int, lhs: ExprAST) -> ExprAST: """ @@ -314,7 +334,7 @@ def parseBinOpRHS(self, exprPrec: int, lhs: ExprAST) -> ExprAST: rhs = self.parsePrimary() if rhs is None: - self.parseError('expression', 'to complete binary operator') + self.parseError("expression", "to complete binary operator") # If BinOp binds less tightly with rhs than the operator after rhs, let # the pending operator take rhs as its lhs. @@ -335,15 +355,15 @@ def parseType(self): type ::= < shape_list > shape_list ::= num | num , shape_list """ - self.pop_pattern('<') + self.pop_pattern("<") shape: List[int] = [] - while (token := self.pop_token(NumberToken)): + while token := self.pop_token(NumberToken): shape.append(int(token.value)) - if self.check('>'): + if self.check(">"): self.pop() break - self.pop_pattern(',') + self.pop_pattern(",") return VarType(shape) @@ -354,16 +374,16 @@ def parseDeclaration(self): initializer. decl ::= var identifier [ type ] = expr """ - var = self.pop_pattern('var') + var = self.pop_pattern("var") name = self.pop_token(IdentifierToken).text # Type is optional, it can be inferred - if self.check('<'): + if self.check("<"): varType = self.parseType() else: varType = VarType([]) - self.pop_pattern('=') + self.pop_pattern("=") expr = self.parseExpression() return VarDeclExprAST(var.loc, name, varType, expr) @@ -377,18 +397,18 @@ def parseBlock(self) -> tuple[ExprAST, ...]: expression_list ::= block_expr ; expression_list block_expr ::= decl | "return" | expr """ - self.pop_pattern('{') + self.pop_pattern("{") exprList: List[ExprAST] = [] # Ignore empty expressions: swallow sequences of semicolons. - while self.check(';'): - self.pop_pattern(';') + while self.check(";"): + self.pop_pattern(";") - while not self.check('}'): - if self.check('var'): + while not self.check("}"): + if self.check("var"): # Variable declaration exprList.append(self.parseDeclaration()) - elif self.check('return'): + elif self.check("return"): # Return statement exprList.append(self.parseReturn()) else: @@ -396,13 +416,13 @@ def parseBlock(self) -> tuple[ExprAST, ...]: exprList.append(self.parseExpression()) # Ensure that elements are separated by a semicolon. - self.pop_pattern(';') + self.pop_pattern(";") # Ignore empty expressions: swallow sequences of semicolons. - while self.check(';'): - self.pop_pattern(';') + while self.check(";"): + self.pop_pattern(";") - self.pop_pattern('}') + self.pop_pattern("}") return tuple(exprList) @@ -411,20 +431,20 @@ def parsePrototype(self): prototype ::= def id '(' decl_list ')' decl_list ::= identifier | identifier, decl_list """ - defToken = self.pop_pattern('def') + defToken = self.pop_pattern("def") fnName = self.pop_token(IdentifierToken).text - self.pop_pattern('(') + self.pop_pattern("(") args: List[VariableExprAST] = [] - if not self.check(')'): + if not self.check(")"): while True: arg = self.pop_token(IdentifierToken) args.append(VariableExprAST(arg.loc, arg.text)) - if not self.check(','): + if not self.check(","): break - self.pop_pattern(',') + self.pop_pattern(",") - self.pop_pattern(')') + self.pop_pattern(")") return PrototypeAST(defToken.loc, fnName, args) def parseDefinition(self): @@ -438,9 +458,7 @@ def parseDefinition(self): block = self.parseBlock() return FunctionAST(proto.loc, proto, block) - def parseError(self, - expected: str | type[Token], - context: str = '') -> NoReturn: + def parseError(self, expected: str | type[Token], context: str = "") -> NoReturn: """ Helper function to signal errors while parsing, it takes an argument indicating the expected token and another argument giving more context. diff --git a/docs/Toy/toy/tests/test_interpreter.py b/docs/Toy/toy/tests/test_interpreter.py index b54ca209d8..2d13e340c9 100644 --- a/docs/Toy/toy/tests/test_interpreter.py +++ b/docs/Toy/toy/tests/test_interpreter.py @@ -9,12 +9,12 @@ def test_tensor_printing(): - assert f'{Tensor([], [0])}' == '[]' - assert f'{Tensor([1], [1])}' == '[1]' - assert f'{Tensor([1], [1, 1])}' == '[[1]]' - assert f'{Tensor([1], [1, 1, 1])}' == '[[[1]]]' - assert f'{Tensor([1, 2, 3, 4, 5, 6], [2, 3])}' == '[[1, 2, 3], [4, 5, 6]]' - assert f'{Tensor([1, 2, 3, 4, 5, 6], [3, 2])}' == '[[1, 2], [3, 4], [5, 6]]' + assert f"{Tensor([], [0])}" == "[]" + assert f"{Tensor([1], [1])}" == "[1]" + assert f"{Tensor([1], [1, 1])}" == "[[1]]" + assert f"{Tensor([1], [1, 1, 1])}" == "[[[1]]]" + assert f"{Tensor([1, 2, 3, 4, 5, 6], [2, 3])}" == "[[1, 2, 3], [4, 5, 6]]" + assert f"{Tensor([1, 2, 3, 4, 5, 6], [3, 2])}" == "[[1, 2], [3, 4], [5, 6]]" def test_interpreter(): @@ -23,11 +23,10 @@ def test_interpreter(): interpreter = Interpreter(module_op, file=stream) interpreter.register_implementations(ToyFunctions()) interpreter.run_module() - assert '[[1.0, 9.0], [25.0, 4.0], [16.0, 36.0]]\n' == stream.getvalue() + assert "[[1.0, 9.0], [25.0, 4.0], [16.0, 36.0]]\n" == stream.getvalue() def build_module() -> ModuleOp: - unrankedf64TensorType = td.UnrankedTensorType.from_type(f64) def func_body(*args: BlockArgument) -> list[Operation]: @@ -44,18 +43,19 @@ def main_body(*args: BlockArgument) -> list[Operation]: m1 = td.ConstantOp.from_list([1, 2, 3, 4, 5, 6], [6]) m2 = td.ReshapeOp(m1.results[0], [2, 3]) [b] = m2.results - m3 = td.GenericCallOp('multiply_transpose', [a, b], - [unrankedf64TensorType]) + m3 = td.GenericCallOp("multiply_transpose", [a, b], [unrankedf64TensorType]) [c] = m3.results m4 = td.PrintOp(c) m5 = td.ReturnOp() return [m0, m1, m2, m3, m4, m5] multiply_transpose = td.FuncOp.from_callable( - 'multiply_transpose', [unrankedf64TensorType, unrankedf64TensorType], + "multiply_transpose", + [unrankedf64TensorType, unrankedf64TensorType], [unrankedf64TensorType], func_body, - private=True) - main = td.FuncOp.from_callable('main', [], [], main_body, private=False) + private=True, + ) + main = td.FuncOp.from_callable("main", [], [], main_body, private=False) return ModuleOp([multiply_transpose, main]) diff --git a/docs/Toy/toy/tests/test_ir_gen.py b/docs/Toy/toy/tests/test_ir_gen.py index cb43b6f0e2..03f8d6c889 100644 --- a/docs/Toy/toy/tests/test_ir_gen.py +++ b/docs/Toy/toy/tests/test_ir_gen.py @@ -11,9 +11,9 @@ def test_convert_ast(): - ast_toy = Path('docs/Toy/examples/ast.toy') + ast_toy = Path("docs/Toy/examples/ast.toy") - with open(ast_toy, 'r') as f: + with open(ast_toy, "r") as f: parser = Parser(ast_toy, f.read()) module_ast = parser.parseModule() @@ -28,8 +28,8 @@ def module_op(): unrankedf64TensorType = toy.UnrankedTensorType.from_type(f64) multiply_transpose_type = FunctionType.from_lists( - [unrankedf64TensorType, unrankedf64TensorType], - [unrankedf64TensorType]) + [unrankedf64TensorType, unrankedf64TensorType], [unrankedf64TensorType] + ) @Builder.implicit_region(multiply_transpose_type.inputs) def multiply_transpose(args: tuple[BlockArgument, ...]) -> None: @@ -40,8 +40,9 @@ def multiply_transpose(args: tuple[BlockArgument, ...]) -> None: toy.ReturnOp(prod) def call_multiply_transpose(a: SSAValue, b: SSAValue) -> OpResult: - return toy.GenericCallOp("multiply_transpose", [a, b], - [unrankedf64TensorType]).res[0] + return toy.GenericCallOp( + "multiply_transpose", [a, b], [unrankedf64TensorType] + ).res[0] main_type = FunctionType.from_lists([], []) @@ -57,10 +58,12 @@ def main() -> None: call_multiply_transpose(a_t, c) toy.ReturnOp() - toy.FuncOp("multiply_transpose", - multiply_transpose_type, - multiply_transpose, - private=True) + toy.FuncOp( + "multiply_transpose", + multiply_transpose_type, + multiply_transpose, + private=True, + ) toy.FuncOp("main", main_type, main) assert module_op.is_structurally_equivalent(generated_module_op) diff --git a/docs/Toy/toy/tests/test_parser.py b/docs/Toy/toy/tests/test_parser.py index ffafbdf94b..c2f6c84fd6 100644 --- a/docs/Toy/toy/tests/test_parser.py +++ b/docs/Toy/toy/tests/test_parser.py @@ -3,14 +3,26 @@ import pytest from ..parser import Parser, ParseError -from ..toy_ast import ModuleAST, FunctionAST, PrototypeAST, VariableExprAST, ReturnExprAST, BinaryExprAST, CallExprAST, VarDeclExprAST, VarType, LiteralExprAST, NumberExprAST +from ..toy_ast import ( + ModuleAST, + FunctionAST, + PrototypeAST, + VariableExprAST, + ReturnExprAST, + BinaryExprAST, + CallExprAST, + VarDeclExprAST, + VarType, + LiteralExprAST, + NumberExprAST, +) from ..location import Location def test_parse_ast(): - ast_toy = Path('docs/Toy/examples/ast.toy') + ast_toy = Path("docs/Toy/examples/ast.toy") - with open(ast_toy, 'r') as f: + with open(ast_toy, "r") as f: parser = Parser(ast_toy, f.read()) parsed_module_ast = parser.parseModule() @@ -18,72 +30,148 @@ def test_parse_ast(): def loc(line: int, col: int) -> Location: return Location(ast_toy, line, col) - module_ast = ModuleAST(( - FunctionAST( - loc(4, 1), - PrototypeAST(loc(4, 1), 'multiply_transpose', [ - VariableExprAST(loc(4, 24), 'a'), - VariableExprAST(loc(4, 27), 'b'), - ]), (ReturnExprAST( - loc(5, 3), - BinaryExprAST( - loc(5, 25), '*', - CallExprAST(loc(5, 10), 'transpose', - [VariableExprAST(loc(5, 20), 'a')]), - CallExprAST(loc(5, 25), 'transpose', - [VariableExprAST(loc(5, 35), 'b')]))), )), - FunctionAST(loc(8, 1), PrototypeAST(loc(8, 1), 'main', []), ( - VarDeclExprAST( - loc(11, 3), 'a', VarType([]), - LiteralExprAST(loc(11, 11), [ - LiteralExprAST(loc(11, 12), [ - NumberExprAST(loc(11, 13), 1.0), - NumberExprAST(loc(11, 16), 2.0), - NumberExprAST(loc(11, 19), 3.0) - ], [3]), - LiteralExprAST(loc(11, 23), [ - NumberExprAST(loc(11, 24), 4.0), - NumberExprAST(loc(11, 27), 5.0), - NumberExprAST(loc(11, 30), 6.0), - ], [3]) - ], [2, 3])), - VarDeclExprAST( - loc(15, 3), 'b', VarType([2, 3]), - LiteralExprAST(loc(15, 17), [ - NumberExprAST(loc(15, 18), 1.0), - NumberExprAST(loc(15, 21), 2.0), - NumberExprAST(loc(15, 24), 3.0), - NumberExprAST(loc(15, 27), 4.0), - NumberExprAST(loc(15, 30), 5.0), - NumberExprAST(loc(15, 33), 6.0), - ], [6])), - VarDeclExprAST( - loc(19, 3), 'c', VarType([]), - CallExprAST(loc(19, 11), 'multiply_transpose', [ - VariableExprAST(loc(19, 30), 'a'), - VariableExprAST(loc(19, 33), 'b'), - ])), - VarDeclExprAST( - loc(22, 3), 'd', VarType([]), - CallExprAST(loc(22, 11), 'multiply_transpose', [ - VariableExprAST(loc(22, 30), 'b'), - VariableExprAST(loc(22, 33), 'a'), - ])), - VarDeclExprAST( - loc(25, 3), 'e', VarType([]), - CallExprAST(loc(25, 11), 'multiply_transpose', [ - VariableExprAST(loc(25, 30), 'b'), - VariableExprAST(loc(25, 33), 'c'), - ])), - VarDeclExprAST( - loc(28, 3), 'f', VarType([]), - CallExprAST(loc(28, 11), 'multiply_transpose', [ - CallExprAST(loc(28, 30), 'transpose', - [VariableExprAST(loc(28, 40), 'a')]), - VariableExprAST(loc(28, 44), 'c'), - ])), - )), - )) + module_ast = ModuleAST( + ( + FunctionAST( + loc(4, 1), + PrototypeAST( + loc(4, 1), + "multiply_transpose", + [ + VariableExprAST(loc(4, 24), "a"), + VariableExprAST(loc(4, 27), "b"), + ], + ), + ( + ReturnExprAST( + loc(5, 3), + BinaryExprAST( + loc(5, 25), + "*", + CallExprAST( + loc(5, 10), + "transpose", + [VariableExprAST(loc(5, 20), "a")], + ), + CallExprAST( + loc(5, 25), + "transpose", + [VariableExprAST(loc(5, 35), "b")], + ), + ), + ), + ), + ), + FunctionAST( + loc(8, 1), + PrototypeAST(loc(8, 1), "main", []), + ( + VarDeclExprAST( + loc(11, 3), + "a", + VarType([]), + LiteralExprAST( + loc(11, 11), + [ + LiteralExprAST( + loc(11, 12), + [ + NumberExprAST(loc(11, 13), 1.0), + NumberExprAST(loc(11, 16), 2.0), + NumberExprAST(loc(11, 19), 3.0), + ], + [3], + ), + LiteralExprAST( + loc(11, 23), + [ + NumberExprAST(loc(11, 24), 4.0), + NumberExprAST(loc(11, 27), 5.0), + NumberExprAST(loc(11, 30), 6.0), + ], + [3], + ), + ], + [2, 3], + ), + ), + VarDeclExprAST( + loc(15, 3), + "b", + VarType([2, 3]), + LiteralExprAST( + loc(15, 17), + [ + NumberExprAST(loc(15, 18), 1.0), + NumberExprAST(loc(15, 21), 2.0), + NumberExprAST(loc(15, 24), 3.0), + NumberExprAST(loc(15, 27), 4.0), + NumberExprAST(loc(15, 30), 5.0), + NumberExprAST(loc(15, 33), 6.0), + ], + [6], + ), + ), + VarDeclExprAST( + loc(19, 3), + "c", + VarType([]), + CallExprAST( + loc(19, 11), + "multiply_transpose", + [ + VariableExprAST(loc(19, 30), "a"), + VariableExprAST(loc(19, 33), "b"), + ], + ), + ), + VarDeclExprAST( + loc(22, 3), + "d", + VarType([]), + CallExprAST( + loc(22, 11), + "multiply_transpose", + [ + VariableExprAST(loc(22, 30), "b"), + VariableExprAST(loc(22, 33), "a"), + ], + ), + ), + VarDeclExprAST( + loc(25, 3), + "e", + VarType([]), + CallExprAST( + loc(25, 11), + "multiply_transpose", + [ + VariableExprAST(loc(25, 30), "b"), + VariableExprAST(loc(25, 33), "c"), + ], + ), + ), + VarDeclExprAST( + loc(28, 3), + "f", + VarType([]), + CallExprAST( + loc(28, 11), + "multiply_transpose", + [ + CallExprAST( + loc(28, 30), + "transpose", + [VariableExprAST(loc(28, 40), "a")], + ), + VariableExprAST(loc(28, 44), "c"), + ], + ), + ), + ), + ), + ) + ) assert parsed_module_ast == module_ast diff --git a/docs/Toy/toy/toy_ast.py b/docs/Toy/toy/toy_ast.py index 5586f4dd9d..91cb5abb9a 100644 --- a/docs/Toy/toy/toy_ast.py +++ b/docs/Toy/toy/toy_ast.py @@ -11,7 +11,7 @@ @dataclass class VarType: - 'A variable type with shape information.' + "A variable type with shape information." shape: List[int] @@ -27,28 +27,33 @@ class ExprASTKind(Enum): @dataclass() -class Dumper(): +class Dumper: lines: List[str] indentation: int = 0 def append(self, prefix: str, line: str): - self.lines.append(' ' * self.indentation * INDENT + prefix + line) - - def append_list(self, prefix: str, open_paren: str, - exprs: Iterable[ExprAST | FunctionAST], close_paren: str, - block: Callable[[Dumper, ExprAST | FunctionAST], None]): + self.lines.append(" " * self.indentation * INDENT + prefix + line) + + def append_list( + self, + prefix: str, + open_paren: str, + exprs: Iterable[ExprAST | FunctionAST], + close_paren: str, + block: Callable[[Dumper, ExprAST | FunctionAST], None], + ): self.append(prefix, open_paren) child = self.child() for expr in exprs: block(child, expr) - self.append('', close_paren) + self.append("", close_paren) def child(self): return Dumper(self.lines, self.indentation + 1) @property def message(self): - return '\n'.join(self.lines) + return "\n".join(self.lines) @dataclass @@ -61,20 +66,20 @@ def __init__(self, loc: Location): @property def kind(self) -> ExprASTKind: - raise AssertionError(f'ExprAST kind not defined for {type(self)}') + raise AssertionError(f"ExprAST kind not defined for {type(self)}") def inner_dump(self, prefix: str, dumper: Dumper): dumper.append(prefix, self.__class__.__name__) def dump(self): dumper = Dumper([]) - self.inner_dump('', dumper) + self.inner_dump("", dumper) return dumper.message @dataclass class VarDeclExprAST(ExprAST): - 'Expression class for defining a variable.' + "Expression class for defining a variable." name: str varType: VarType expr: ExprAST @@ -84,15 +89,15 @@ def kind(self): return ExprASTKind.Expr_VarDecl def inner_dump(self, prefix: str, dumper: Dumper): - dims_str = ', '.join(f'{int(dim)}' for dim in self.varType.shape) - dumper.append('VarDecl ', f'{self.name}<{dims_str}> {self.loc}') + dims_str = ", ".join(f"{int(dim)}" for dim in self.varType.shape) + dumper.append("VarDecl ", f"{self.name}<{dims_str}> {self.loc}") child = dumper.child() - self.expr.inner_dump('', child) + self.expr.inner_dump("", child) @dataclass class ReturnExprAST(ExprAST): - 'Expression class for a return operator.' + "Expression class for a return operator." expr: Optional[ExprAST] @property @@ -100,10 +105,10 @@ def kind(self): return ExprASTKind.Expr_Return def inner_dump(self, prefix: str, dumper: Dumper): - dumper.append(prefix, 'Return') + dumper.append(prefix, "Return") if self.expr is not None: child = dumper.child() - self.expr.inner_dump('', child) + self.expr.inner_dump("", child) @dataclass @@ -116,12 +121,12 @@ def kind(self): return ExprASTKind.Expr_Num def inner_dump(self, prefix: str, dumper: Dumper): - dumper.append(prefix, ' {:.6e}'.format(self.val)) + dumper.append(prefix, " {:.6e}".format(self.val)) @dataclass class LiteralExprAST(ExprAST): - 'Expression class for a literal value.' + "Expression class for a literal value." values: List[LiteralExprAST | NumberExprAST] dims: List[int] @@ -130,14 +135,15 @@ def kind(self): return ExprASTKind.Expr_Literal def __dump(self) -> str: - dims_str = ', '.join(f'{int(dim)}' for dim in self.dims) - vals_str = ','.join( + dims_str = ", ".join(f"{int(dim)}" for dim in self.dims) + vals_str = ",".join( val.__dump() if isinstance(val, LiteralExprAST) else val.dump() - for val in self.values) - return f' <{dims_str}>[{vals_str}]' + for val in self.values + ) + return f" <{dims_str}>[{vals_str}]" def inner_dump(self, prefix: str, dumper: Dumper): - dumper.append('Literal:', self.__dump() + f' {self.loc}') + dumper.append("Literal:", self.__dump() + f" {self.loc}") def iter_flattened_values(self) -> Generator[float, None, None]: for value in self.values: @@ -160,12 +166,12 @@ def kind(self): return ExprASTKind.Expr_Var def inner_dump(self, prefix: str, dumper: Dumper): - dumper.append('var: ', f'{self.name} {self.loc}') + dumper.append("var: ", f"{self.name} {self.loc}") @dataclass class BinaryExprAST(ExprAST): - 'Expression class for a binary operator.' + "Expression class for a binary operator." op: str lhs: ExprAST rhs: ExprAST @@ -177,13 +183,13 @@ def kind(self): def inner_dump(self, prefix: str, dumper: Dumper): dumper.append(prefix, f"BinOp: {self.op} {self.loc}") child = dumper.child() - self.lhs.inner_dump('', child) - self.rhs.inner_dump('', child) + self.lhs.inner_dump("", child) + self.rhs.inner_dump("", child) @dataclass class CallExprAST(ExprAST): - 'Expression class for function calls.' + "Expression class for function calls." callee: str args: List[ExprAST] @@ -192,14 +198,18 @@ def kind(self): return ExprASTKind.Expr_Call def inner_dump(self, prefix: str, dumper: Dumper): - dumper.append_list(prefix, f"Call '{self.callee}' [ {self.loc}", - self.args, ']', - lambda dd, arg: arg.inner_dump('', dd)) + dumper.append_list( + prefix, + f"Call '{self.callee}' [ {self.loc}", + self.args, + "]", + lambda dd, arg: arg.inner_dump("", dd), + ) @dataclass class PrintExprAST(ExprAST): - 'Expression class for builtin print calls.' + "Expression class for builtin print calls." arg: ExprAST @property @@ -209,7 +219,7 @@ def kind(self): def inner_dump(self, prefix: str, dumper: Dumper): super().inner_dump(prefix, dumper) child = dumper.child() - self.arg.inner_dump('arg: ', child) + self.arg.inner_dump("arg: ", child) @dataclass @@ -219,51 +229,57 @@ class PrototypeAST: name, and its argument names (thus implicitly the number of arguments the function takes). """ + loc: Location name: str args: List[VariableExprAST] def dump(self): dumper = Dumper([]) - self.inner_dump('', dumper) + self.inner_dump("", dumper) return dumper.message def inner_dump(self, prefix: str, dumper: Dumper): - dumper.append('', f"Proto '{self.name}' {self.loc}'") - dumper.append('Params: ', - f'[{", ".join(arg.name for arg in self.args)}]') + dumper.append("", f"Proto '{self.name}' {self.loc}'") + dumper.append("Params: ", f'[{", ".join(arg.name for arg in self.args)}]') @dataclass class FunctionAST: - 'This class represents a function definition itself.' + "This class represents a function definition itself." loc: Location proto: PrototypeAST body: tuple[ExprAST, ...] def dump(self): dumper = Dumper([]) - self.inner_dump('', dumper) + self.inner_dump("", dumper) return dumper.message def inner_dump(self, prefix: str, dumper: Dumper): - dumper.append(prefix, 'Function ') + dumper.append(prefix, "Function ") child = dumper.child() - self.proto.inner_dump('proto: ', child) - child.append_list('Block ', '{', self.body, '} // Block', - lambda dd, stmt: stmt.inner_dump('', dd)) + self.proto.inner_dump("proto: ", child) + child.append_list( + "Block ", + "{", + self.body, + "} // Block", + lambda dd, stmt: stmt.inner_dump("", dd), + ) @dataclass class ModuleAST: - 'This class represents a list of functions to be processed together' + "This class represents a list of functions to be processed together" funcs: tuple[FunctionAST, ...] def dump(self): dumper = Dumper([]) - self.inner_dump('', dumper) + self.inner_dump("", dumper) return dumper.message def inner_dump(self, prefix: str, dumper: Dumper): - dumper.append_list(prefix, 'Module:', self.funcs, '', - lambda dd, func: func.inner_dump('', dd)) + dumper.append_list( + prefix, "Module:", self.funcs, "", lambda dd, func: func.inner_dump("", dd) + ) diff --git a/setup.py b/setup.py index b47efc477e..26a8850080 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ cmdclass=versioneer.get_cmdclass(), description="xDSL", long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", entry_points={"console_scripts": ["xdsl-opt = xdsl.tools.xdsl_opt:main"]}, project_urls={ "Source Code": "https://github.com/xdslproject/xdsl", diff --git a/tests/conftest.py b/tests/conftest.py index d1c2978311..ca9ba65e21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,11 +5,13 @@ from xdsl.utils.diagnostic import Diagnostic -def assert_print_op(operation: Operation, - expected: str, - diagnostic: Diagnostic | None, - print_generic_format: bool = False, - target: Printer.Target | None = None): +def assert_print_op( + operation: Operation, + expected: str, + diagnostic: Diagnostic | None, + print_generic_format: bool = False, + target: Printer.Target | None = None, +): """ Utility function that helps to check the printing of an operation compared to some string @@ -44,15 +46,19 @@ def assert_print_op(operation: Operation, if diagnostic is None: diagnostic = Diagnostic() if target is None: - printer = Printer(stream=file, - print_generic_format=print_generic_format, - diagnostic=diagnostic, - target=Printer.Target.XDSL) + printer = Printer( + stream=file, + print_generic_format=print_generic_format, + diagnostic=diagnostic, + target=Printer.Target.XDSL, + ) else: - printer = Printer(stream=file, - print_generic_format=print_generic_format, - diagnostic=diagnostic, - target=target) + printer = Printer( + stream=file, + print_generic_format=print_generic_format, + diagnostic=diagnostic, + target=target, + ) printer.print(operation) assert file.getvalue().strip() == expected.strip() diff --git a/tests/dialects/test_affine.py b/tests/dialects/test_affine.py index d804f346ec..bc8f55c052 100644 --- a/tests/dialects/test_affine.py +++ b/tests/dialects/test_affine.py @@ -15,52 +15,61 @@ def test_for_mismatch_operands_results_counts(): attributes: dict[str, Attribute] = { "lower_bound": IntegerAttr.from_index_int_value(0), "upper_bound": IntegerAttr.from_index_int_value(5), - "step": IntegerAttr.from_index_int_value(1) + "step": IntegerAttr.from_index_int_value(1), } - f = For.create(operands=[], - regions=[Region()], - attributes=attributes, - result_types=[IndexType()]) + f = For.create( + operands=[], + regions=[Region()], + attributes=attributes, + result_types=[IndexType()], + ) with pytest.raises(Exception) as e: f.verify() - assert e.value.args[ - 0] == "Expected the same amount of operands and results" + assert e.value.args[0] == "Expected the same amount of operands and results" def test_for_mismatch_operands_results_types(): attributes: dict[str, Attribute] = { "lower_bound": IntegerAttr.from_index_int_value(0), "upper_bound": IntegerAttr.from_index_int_value(5), - "step": IntegerAttr.from_index_int_value(1) + "step": IntegerAttr.from_index_int_value(1), } - b = Block(arg_types=(IntegerType(32), )) + b = Block(arg_types=(IntegerType(32),)) inp = b.args[0] - f = For.create(operands=[inp], - regions=[Region()], - attributes=attributes, - result_types=[IndexType()]) + f = For.create( + operands=[inp], + regions=[Region()], + attributes=attributes, + result_types=[IndexType()], + ) with pytest.raises(Exception) as e: f.verify() - assert e.value.args[ - 0] == "Expected all operands and result pairs to have matching types" + assert ( + e.value.args[0] + == "Expected all operands and result pairs to have matching types" + ) def test_for_mismatch_blockargs(): attributes: dict[str, Attribute] = { "lower_bound": IntegerAttr.from_index_int_value(0), "upper_bound": IntegerAttr.from_index_int_value(5), - "step": IntegerAttr.from_index_int_value(1) + "step": IntegerAttr.from_index_int_value(1), } - b = Block(arg_types=(IndexType(), )) + b = Block(arg_types=(IndexType(),)) inp = b.args[0] - f = For.create(operands=[inp], - regions=[Region(Block.from_callable([], lambda *args: []))], - attributes=attributes, - result_types=[IndexType()]) + f = For.create( + operands=[inp], + regions=[Region(Block.from_callable([], lambda *args: []))], + attributes=attributes, + result_types=[IndexType()], + ) with pytest.raises(Exception) as e: f.verify() - assert e.value.args[ - 0] == "Expected BlockArguments to have the same types as the operands" + assert ( + e.value.args[0] + == "Expected BlockArguments to have the same types as the operands" + ) def test_yield(): diff --git a/tests/dialects/test_arith.py b/tests/dialects/test_arith.py index 7f055feb18..0b9096ca7f 100644 --- a/tests/dialects/test_arith.py +++ b/tests/dialects/test_arith.py @@ -1,12 +1,50 @@ import pytest -from xdsl.dialects.arith import (Addi, Constant, DivUI, DivSI, Subi, - FloorDivSI, CeilDivSI, CeilDivUI, RemUI, - RemSI, MinUI, MinSI, MaxUI, MaxSI, AndI, OrI, - XOrI, ShLI, ShRUI, ShRSI, Cmpi, Addf, Subf, - Mulf, Divf, Maxf, Minf, IndexCastOp, FPToSIOp, - SIToFPOp, ExtFOp, TruncFOp, Cmpf, Negf) -from xdsl.dialects.builtin import i32, i64, f32, f64, IndexType, IntegerType, Float32Type +from xdsl.dialects.arith import ( + Addi, + Constant, + DivUI, + DivSI, + Subi, + FloorDivSI, + CeilDivSI, + CeilDivUI, + RemUI, + RemSI, + MinUI, + MinSI, + MaxUI, + MaxSI, + AndI, + OrI, + XOrI, + ShLI, + ShRUI, + ShRSI, + Cmpi, + Addf, + Subf, + Mulf, + Divf, + Maxf, + Minf, + IndexCastOp, + FPToSIOp, + SIToFPOp, + ExtFOp, + TruncFOp, + Cmpf, + Negf, +) +from xdsl.dialects.builtin import ( + i32, + i64, + f32, + f64, + IndexType, + IntegerType, + Float32Type, +) from xdsl.utils.exceptions import VerifyException @@ -17,9 +55,25 @@ class Test_integer_arith_construction: @pytest.mark.parametrize( "func", [ - Addi, Subi, DivUI, DivSI, FloorDivSI, CeilDivSI, CeilDivUI, RemUI, - RemSI, MinUI, MinSI, MaxUI, MaxSI, AndI, OrI, XOrI, ShLI, ShRUI, - ShRSI + Addi, + Subi, + DivUI, + DivSI, + FloorDivSI, + CeilDivSI, + CeilDivUI, + RemUI, + RemSI, + MinUI, + MinSI, + MaxUI, + MaxSI, + AndI, + OrI, + XOrI, + ShLI, + ShRUI, + ShRSI, ], ) def test_arith_ops(self, func): @@ -39,7 +93,6 @@ def test_Cmpi_from_mnemonic(self, input): class Test_float_arith_construction: - a = Constant.from_float_and_width(1.1, f32) b = Constant.from_float_and_width(2.2, f32) @@ -100,8 +153,22 @@ def test_cmpf_from_mnemonic(): a = Constant.from_float_and_width(1.0, f64) b = Constant.from_float_and_width(2.0, f64) operations = [ - "false", "oeq", "ogt", "oge", "olt", "ole", "one", "ord", "ueq", "ugt", - "uge", "ult", "ule", "une", "uno", "true" + "false", + "oeq", + "ogt", + "oge", + "olt", + "ole", + "one", + "ord", + "ueq", + "ugt", + "uge", + "ult", + "ule", + "une", + "uno", + "true", ] cmpf_ops = [None] * len(operations) @@ -131,8 +198,10 @@ def test_cmpf_missmatch_type(): with pytest.raises(TypeError) as e: cmpf_op = Cmpf.get(a, b, 1) - assert e.value.args[ - 0] == "Comparison operands must have same type, but provided f32 and f64" + assert ( + e.value.args[0] + == "Comparison operands must have same type, but provided f32 and f64" + ) def test_cmpi_missmatch_type(): @@ -141,8 +210,10 @@ def test_cmpi_missmatch_type(): with pytest.raises(TypeError) as e: cmpi_op = Cmpi.get(a, b, 1) - assert e.value.args[ - 0] == "Comparison operands must have same type, but provided i32 and i64" + assert ( + e.value.args[0] + == "Comparison operands must have same type, but provided i32 and i64" + ) def test_cmpf_incorrect_comparison(): diff --git a/tests/dialects/test_builtin.py b/tests/dialects/test_builtin.py index 4575a9574b..d4b27ee719 100644 --- a/tests/dialects/test_builtin.py +++ b/tests/dialects/test_builtin.py @@ -2,10 +2,22 @@ import pytest from xdsl.dialects.builtin import ( - ComplexType, DenseArrayBase, DenseIntOrFPElementsAttr, NoneAttr, - StridedLayoutAttr, i32, f32, FloatAttr, ArrayAttr, IntAttr, FloatData, - SymbolRefAttr, VectorBaseTypeConstraint, VectorRankConstraint, - VectorBaseTypeAndRankConstraint) + ComplexType, + DenseArrayBase, + DenseIntOrFPElementsAttr, + NoneAttr, + StridedLayoutAttr, + i32, + f32, + FloatAttr, + ArrayAttr, + IntAttr, + FloatData, + SymbolRefAttr, + VectorBaseTypeConstraint, + VectorRankConstraint, + VectorBaseTypeAndRankConstraint, +) from xdsl.dialects.builtin import i32, i64, VectorType, UnrealizedConversionCastOp from xdsl.dialects.arith import Constant from xdsl.dialects.memref import MemRefType @@ -45,20 +57,25 @@ def test_DenseArrayBase_verifier_failure(): with pytest.raises(VerifyException) as err: DenseArrayBase([f32, ArrayAttr([IntAttr(0)])]) - assert err.value.args[0] == ("dense array of float element type " - "should only contain floats") + assert err.value.args[0] == ( + "dense array of float element type " "should only contain floats" + ) with pytest.raises(VerifyException) as err: DenseArrayBase([i32, ArrayAttr([FloatData(0.0)])]) - assert err.value.args[0] == ("dense array of integer element type " - "should only contain integers") + assert err.value.args[0] == ( + "dense array of integer element type " "should only contain integers" + ) -@pytest.mark.parametrize('ref,expected', ( - (SymbolRefAttr('test'), 'test'), - (SymbolRefAttr('test', ["2"]), 'test.2'), - (SymbolRefAttr('test', ["2", "3"]), 'test.2.3'), -)) +@pytest.mark.parametrize( + "ref,expected", + ( + (SymbolRefAttr("test"), "test"), + (SymbolRefAttr("test", ["2"]), "test.2"), + (SymbolRefAttr("test", ["2", "3"]), "test.2.3"), + ), +) def test_SymbolRefAttr_string_value(ref: SymbolRefAttr, expected: str): assert ref.string_value() == expected @@ -73,15 +90,17 @@ def test_array_len_and_iter_attr(): assert tuple(arr) == arr.data -@pytest.mark.parametrize('attr, dims, num_scalable_dims', ( - (i32, [1, 2], 0), - (i32, [1, 2], 1), - (i32, [1, 1, 3], 0), - (i64, [1, 1, 3], 2), - (i64, [], 0), -)) -def test_vector_constructor(attr: Attribute, dims: list[int], - num_scalable_dims: int): +@pytest.mark.parametrize( + "attr, dims, num_scalable_dims", + ( + (i32, [1, 2], 0), + (i32, [1, 2], 1), + (i32, [1, 1, 3], 0), + (i64, [1, 1, 3], 2), + (i64, [], 0), + ), +) +def test_vector_constructor(attr: Attribute, dims: list[int], num_scalable_dims: int): vec = VectorType.from_element_type_and_shape(attr, dims, num_scalable_dims) assert vec.get_num_dims() == len(dims) @@ -89,11 +108,14 @@ def test_vector_constructor(attr: Attribute, dims: list[int], assert vec.get_shape() == dims -@pytest.mark.parametrize('dims, num_scalable_dims', ( - ([], 1), - ([1, 2], 3), - ([1], 2), -)) +@pytest.mark.parametrize( + "dims, num_scalable_dims", + ( + ([], 1), + ([1, 2], 3), + ([1], 2), + ), +) def test_vector_verifier_fail(dims: list[int], num_scalable_dims: int): with pytest.raises(VerifyException): VectorType.from_element_type_and_shape(i32, dims, num_scalable_dims) @@ -197,24 +219,28 @@ def test_unrealized_conversion_cast(): conv_op1 = UnrealizedConversionCastOp.get([i64_constant.results[0]], [f32]) conv_op2 = UnrealizedConversionCastOp.get([f32_constant.results[0]], [i32]) - assert (conv_op1.inputs[0].typ == i64) - assert (conv_op1.outputs[0].typ == f32) + assert conv_op1.inputs[0].typ == i64 + assert conv_op1.outputs[0].typ == f32 - assert (conv_op2.inputs[0].typ == f32) - assert (conv_op2.outputs[0].typ == i32) + assert conv_op2.inputs[0].typ == f32 + assert conv_op2.outputs[0].typ == i32 @pytest.mark.parametrize( "strides, offset, expected_strides, expected_offset", - [([2], None, ArrayAttr([IntAttr(2)]), NoneAttr()), - ([None], 2, ArrayAttr([NoneAttr()]), IntAttr(2)), - ([IntAttr(2)], NoneAttr(), ArrayAttr([IntAttr(2)]), NoneAttr()), - ([NoneAttr()], IntAttr(2), ArrayAttr([NoneAttr()]), IntAttr(2))]) -def test_strided_constructor(strides: ArrayAttr[IntAttr | NoneAttr] - | Sequence[int | None | IntAttr | NoneAttr], - offset: int | None | IntAttr | NoneAttr, - expected_strides: ArrayAttr[IntAttr | NoneAttr], - expected_offset: IntAttr | NoneAttr): + [ + ([2], None, ArrayAttr([IntAttr(2)]), NoneAttr()), + ([None], 2, ArrayAttr([NoneAttr()]), IntAttr(2)), + ([IntAttr(2)], NoneAttr(), ArrayAttr([IntAttr(2)]), NoneAttr()), + ([NoneAttr()], IntAttr(2), ArrayAttr([NoneAttr()]), IntAttr(2)), + ], +) +def test_strided_constructor( + strides: ArrayAttr[IntAttr | NoneAttr] | Sequence[int | None | IntAttr | NoneAttr], + offset: int | None | IntAttr | NoneAttr, + expected_strides: ArrayAttr[IntAttr | NoneAttr], + expected_offset: IntAttr | NoneAttr, +): strided = StridedLayoutAttr(strides, offset) assert strided.strides == expected_strides assert strided.offset == expected_offset diff --git a/tests/dialects/test_func.py b/tests/dialects/test_func.py index 9300cad0fc..a9409643d1 100644 --- a/tests/dialects/test_func.py +++ b/tests/dialects/test_func.py @@ -85,7 +85,8 @@ def test_wrong_blockarg_types(): assert e.value.args[0] == ( "Expected entry block arguments to have the same " - "types as the function input types") + "types as the function input types" + ) def test_func_rewriting_helpers(): @@ -93,8 +94,9 @@ def test_func_rewriting_helpers(): test replace_argument_type and update_function_type (implicitly) :return: """ - func = FuncOp.from_callable('test', [i32, i32, i32], [], - lambda *args: [Return.get()]) + func = FuncOp.from_callable( + "test", [i32, i32, i32], [], lambda *args: [Return.get()] + ) func.replace_argument_type(2, i64) assert func.function_type.inputs.data[2] is i64 @@ -116,7 +118,7 @@ def test_func_rewriting_helpers(): with pytest.raises(IndexError): func.replace_argument_type(-4, i64) - decl = FuncOp.external('external_func', [], []) + decl = FuncOp.external("external_func", [], []) assert decl.is_declaration with pytest.raises(AssertionError): @@ -126,10 +128,11 @@ def test_func_rewriting_helpers(): def test_func_get_return_op(): # pyright complains about lambda arg types unknown # honestly don't know how to fix - func_w_ret = FuncOp.from_callable('test', [i32, i32, i32], [i32], - lambda *args: [Return.get(args[1])]) + func_w_ret = FuncOp.from_callable( + "test", [i32, i32, i32], [i32], lambda *args: [Return.get(args[1])] + ) - func = FuncOp.from_callable('test', [i32, i32, i32], [], lambda *args: []) + func = FuncOp.from_callable("test", [i32, i32, i32], [], lambda *args: []) assert func_w_ret.get_return_op() is not None assert func.get_return_op() is None @@ -161,9 +164,9 @@ def test_call(): # Create a func0 that gets the block args as arguments, returns the resulting # type of c and has the region as body - func0 = FuncOp.from_region("func0", - [block0.args[0].typ, block0.args[1].typ], - [c.result.typ], region) + func0 = FuncOp.from_region( + "func0", [block0.args[0].typ, block0.args[1].typ], [c.result.typ], region + ) # Create a call for this function, passing a, b as args # and returning the type of the return @@ -172,8 +175,7 @@ def test_call(): # Wrap all in a ModuleOp mod = ModuleOp([func0, a, b, call0]) - expected = \ - """ + expected = """ builtin.module() { func.func() ["sym_name" = "func0", "function_type" = !fun<[!i32, !i32], [!i32]>, "sym_visibility" = "private"] { ^0(%0 : !i32, %1 : !i32): @@ -210,8 +212,7 @@ def test_call_II(): # Create a func0 that gets the block args as arguments, returns the resulting # type of c and has the region as body - func0 = FuncOp.from_region("func1", [block0.args[0].typ], [c.result.typ], - region) + func0 = FuncOp.from_region("func1", [block0.args[0].typ], [c.result.typ], region) # Create a call for this function, passing a, b as args # and returning the type of the return @@ -220,8 +221,7 @@ def test_call_II(): # Wrap all in a ModuleOp mod = ModuleOp([func0, a, call0]) - expected = \ - """ + expected = """ builtin.module() { func.func() ["sym_name" = "func1", "function_type" = !fun<[!i32], [!i32]>, "sym_visibility" = "private"] { ^0(%0 : !i32): diff --git a/tests/dialects/test_gpu.py b/tests/dialects/test_gpu.py index a9b2d0cc51..441f925ce5 100644 --- a/tests/dialects/test_gpu.py +++ b/tests/dialects/test_gpu.py @@ -1,11 +1,31 @@ from xdsl.builder import Builder from xdsl.dialects import builtin, arith, memref from xdsl.dialects.gpu import ( - AllocOp, AllReduceOp, AllReduceOperationAttr, AsyncTokenType, BarrierOp, - BlockDimOp, BlockIdOp, DeallocOp, GlobalIdOp, GridDimOp, HostRegisterOp, - LaneIdOp, LaunchOp, MemcpyOp, ModuleEndOp, ModuleOp, DimensionAttr, - NumSubgroupsOp, SetDefaultDeviceOp, SubgroupIdOp, SubgroupSizeOp, - TerminatorOp, ThreadIdOp, YieldOp) + AllocOp, + AllReduceOp, + AllReduceOperationAttr, + AsyncTokenType, + BarrierOp, + BlockDimOp, + BlockIdOp, + DeallocOp, + GlobalIdOp, + GridDimOp, + HostRegisterOp, + LaneIdOp, + LaunchOp, + MemcpyOp, + ModuleEndOp, + ModuleOp, + DimensionAttr, + NumSubgroupsOp, + SetDefaultDeviceOp, + SubgroupIdOp, + SubgroupSizeOp, + TerminatorOp, + ThreadIdOp, + YieldOp, +) from xdsl.ir import Block, Operation, Region, SSAValue @@ -17,7 +37,8 @@ def test_dimension(): def test_alloc(): typ = memref.MemRefType.from_element_type_and_shape( - builtin.Float32Type(), [10, 10, 10]) + builtin.Float32Type(), [10, 10, 10] + ) alloc = AllocOp.get(typ, is_async=True) assert isinstance(alloc, AllocOp) @@ -29,15 +50,18 @@ def test_alloc(): assert alloc.hostShared is None dyntyp = memref.MemRefType.from_element_type_and_shape( - builtin.Float32Type(), [-1, -1, -1]) + builtin.Float32Type(), [-1, -1, -1] + ) ten = arith.Constant.from_int_and_width(10, builtin.IndexType()) dynamic_sizes = [ten, ten, ten] token = alloc.asyncToken - full_alloc = AllocOp.get(return_type=dyntyp, - dynamic_sizes=dynamic_sizes, - host_shared=True, - async_dependencies=[token]) + full_alloc = AllocOp.get( + return_type=dyntyp, + dynamic_sizes=dynamic_sizes, + host_shared=True, + async_dependencies=[token], + ) assert isinstance(full_alloc, AllocOp) assert full_alloc.result.typ is dyntyp @@ -110,16 +134,16 @@ def test_block_id(): def test_dealloc(): - typ = memref.MemRefType.from_element_type_and_shape( - builtin.Float32Type(), [10, 10, 10]) + builtin.Float32Type(), [10, 10, 10] + ) alloc = AllocOp.get(typ, is_async=True) assert alloc.asyncToken is not None # For pyright - dealloc = DeallocOp.get(buffer=alloc.result, - async_dependencies=[alloc.asyncToken], - is_async=True) + dealloc = DeallocOp.get( + buffer=alloc.result, async_dependencies=[alloc.asyncToken], is_async=True + ) assert dealloc.asyncToken is not None assert isinstance(dealloc.asyncToken.typ, AsyncTokenType) @@ -171,13 +195,12 @@ def test_grid_dim(): def test_host_register(): - memref_type = memref.MemRefType.from_element_type_and_shape( - builtin.i32, [10, 10]) + memref_type = memref.MemRefType.from_element_type_and_shape(builtin.i32, [10, 10]) ref = memref.Alloca.get(memref_type, 0) unranked = memref.Cast.build( - operands=[ref], - result_types=[memref.UnrankedMemrefType.from_type(builtin.i32)]) + operands=[ref], result_types=[memref.UnrankedMemrefType.from_type(builtin.i32)] + ) register = HostRegisterOp.from_memref(unranked) @@ -214,8 +237,7 @@ def test_launch(): body2 = Region() - nd_launch = LaunchOp.get(body2, gridSize, blockSize, True, - asyncDependencies, ten) + nd_launch = LaunchOp.get(body2, gridSize, blockSize, True, asyncDependencies, ten) assert isinstance(launch, LaunchOp) assert nd_launch.body is body2 @@ -232,13 +254,13 @@ def test_launch(): def test_memcpy(): - typ = memref.MemRefType.from_element_type_and_shape( - builtin.Float32Type(), [10, 10, 10]) + builtin.Float32Type(), [10, 10, 10] + ) host_alloc = memref.Alloc.get(builtin.Float32Type(), 0, [10, 10, 10]) alloc = AllocOp.get(typ, is_async=True) - assert alloc.asyncToken is not None #for Pyright + assert alloc.asyncToken is not None # for Pyright memcpy = MemcpyOp.get(host_alloc, alloc.result, [alloc.asyncToken]) @@ -248,9 +270,9 @@ def test_memcpy(): assert memcpy.asyncDependencies == tuple([alloc.asyncToken]) assert memcpy.asyncToken is None - memcpy2 = MemcpyOp.get(alloc.result, - host_alloc.memref, [alloc.asyncToken], - is_async=True) + memcpy2 = MemcpyOp.get( + alloc.result, host_alloc.memref, [alloc.asyncToken], is_async=True + ) assert isinstance(memcpy2, MemcpyOp) assert memcpy2.src is alloc.result @@ -303,9 +325,9 @@ def test_terminator(): def test_yield(): - operands: list[SSAValue | Operation] = [ - o for o in [ + o + for o in [ arith.Constant.from_int_and_width(42, builtin.i32), arith.Constant.from_int_and_width(19, builtin.IndexType()), arith.Constant.from_int_and_width(84, builtin.i64), diff --git a/tests/dialects/test_irdl.py b/tests/dialects/test_irdl.py index 6dffcab7ba..1987e8b27c 100644 --- a/tests/dialects/test_irdl.py +++ b/tests/dialects/test_irdl.py @@ -8,17 +8,14 @@ def test_dialect_accessors(): Create a dialect with some operations and types, and check that we can retrieve the list of operations, or the list of types. """ - type1 = TypeOp.create(attributes={"name": StringAttr("type1")}, - regions=[Region()]) - type2 = TypeOp.create(attributes={"name": StringAttr("type2")}, - regions=[Region()]) - op1 = OperationOp.create(attributes={"name": StringAttr("op1")}, - regions=[Region()]) - op2 = OperationOp.create(attributes={"name": StringAttr("op2")}, - regions=[Region()]) + type1 = TypeOp.create(attributes={"name": StringAttr("type1")}, regions=[Region()]) + type2 = TypeOp.create(attributes={"name": StringAttr("type2")}, regions=[Region()]) + op1 = OperationOp.create(attributes={"name": StringAttr("op1")}, regions=[Region()]) + op2 = OperationOp.create(attributes={"name": StringAttr("op2")}, regions=[Region()]) dialect = DialectOp.create( attributes={"name": StringAttr("test")}, - regions=[Region([Block([type1, type2, op1, op2])])]) + regions=[Region([Block([type1, type2, op1, op2])])], + ) assert dialect.get_op_defs() == [op1, op2] assert dialect.get_type_defs() == [type1, type2] @@ -34,15 +31,18 @@ def test_operation_accessors(): results = ResultsOp.create(attributes={"params": AnyArrayAttr([])}) # Check it on an operation that has operands and results - op = OperationOp.create(attributes={"name": StringAttr("op")}, - regions=[Region([Block([operands, results])])]) + op = OperationOp.create( + attributes={"name": StringAttr("op")}, + regions=[Region([Block([operands, results])])], + ) assert op.get_operands() is operands assert op.get_results() is results # Check it on an operation that has no operands and results - empty_op = OperationOp.create(attributes={"name": StringAttr("op")}, - regions=[Region([Block()])]) + empty_op = OperationOp.create( + attributes={"name": StringAttr("op")}, regions=[Region([Block()])] + ) assert empty_op.get_operands() is None assert empty_op.get_results() is None diff --git a/tests/dialects/test_llvm.py b/tests/dialects/test_llvm.py index 834ab3cab8..432cac03d6 100644 --- a/tests/dialects/test_llvm.py +++ b/tests/dialects/test_llvm.py @@ -5,19 +5,19 @@ def test_llvm_pointer_ops(): - module = builtin.ModuleOp([ - idx := arith.Constant.from_int_and_width(0, 64), - ptr := llvm.AllocaOp.get(idx, builtin.i32), - val := llvm.LoadOp.get(ptr), - nullptr := llvm.NullOp.get(), - alloc_ptr := llvm.AllocaOp.get(idx, elem_type=builtin.IndexType()), - llvm.LoadOp.get(alloc_ptr), - store := llvm.StoreOp.get(val, - ptr, - alignment=32, - volatile=True, - nontemporal=True), - ]) + module = builtin.ModuleOp( + [ + idx := arith.Constant.from_int_and_width(0, 64), + ptr := llvm.AllocaOp.get(idx, builtin.i32), + val := llvm.LoadOp.get(ptr), + nullptr := llvm.NullOp.get(), + alloc_ptr := llvm.AllocaOp.get(idx, elem_type=builtin.IndexType()), + llvm.LoadOp.get(alloc_ptr), + store := llvm.StoreOp.get( + val, ptr, alignment=32, volatile=True, nontemporal=True + ), + ] + ) module.verify() @@ -27,10 +27,10 @@ def test_llvm_pointer_ops(): assert ptr.res.typ.type == builtin.i32 assert isinstance(ptr.res.typ.addr_space, builtin.NoneAttr) - assert 'volatile_' in store.attributes - assert 'nontemporal' in store.attributes - assert 'alignment' in store.attributes - assert 'ordering' in store.attributes + assert "volatile_" in store.attributes + assert "nontemporal" in store.attributes + assert "alignment" in store.attributes + assert "ordering" in store.attributes assert isinstance(nullptr.nullptr.typ, llvm.LLVMPointerType) assert isinstance(nullptr.nullptr.typ.type, builtin.NoneAttr) @@ -54,12 +54,12 @@ def test_llvm_pointer_type(): assert llvm.LLVMPointerType.typed(builtin.i64).is_typed() assert llvm.LLVMPointerType.typed(builtin.i64).type is builtin.i64 assert isinstance( - llvm.LLVMPointerType.typed(builtin.i64).addr_space, builtin.NoneAttr) + llvm.LLVMPointerType.typed(builtin.i64).addr_space, builtin.NoneAttr + ) assert not llvm.LLVMPointerType.opaque().is_typed() assert isinstance(llvm.LLVMPointerType.opaque().type, builtin.NoneAttr) - assert isinstance(llvm.LLVMPointerType.opaque().addr_space, - builtin.NoneAttr) + assert isinstance(llvm.LLVMPointerType.opaque().addr_space, builtin.NoneAttr) def test_llvm_getelementptr_op_invalid_construction(): @@ -69,8 +69,7 @@ def test_llvm_getelementptr_op_invalid_construction(): # check that passing an opaque pointer to GEP without a pointee type fails with pytest.raises(ValueError): - llvm.GEPOp.get(opaque_ptr, llvm.LLVMPointerType.typed(builtin.i32), - [1]) + llvm.GEPOp.get(opaque_ptr, llvm.LLVMPointerType.typed(builtin.i32), [1]) # check that non-pointer arguments fail with pytest.raises(ValueError): @@ -78,7 +77,7 @@ def test_llvm_getelementptr_op_invalid_construction(): # check that non-pointer result types fail with pytest.raises(ValueError): - llvm.GEPOp.get(ptr, builtin.i32, [1]) #type: ignore + llvm.GEPOp.get(ptr, builtin.i32, [1]) # type: ignore def test_llvm_getelementptr_op(): @@ -90,18 +89,18 @@ def test_llvm_getelementptr_op(): # check that construction with static-only offsets and inbounds attr works: gep1 = llvm.GEPOp.get(ptr, ptr_typ, [1], inbounds=True) - assert 'inbounds' in gep1.attributes + assert "inbounds" in gep1.attributes assert gep1.result.typ == ptr_typ assert gep1.ptr == ptr.res - assert 'elem_type' not in gep1.attributes + assert "elem_type" not in gep1.attributes assert len(gep1.rawConstantIndices.data) == 1 assert len(gep1.ssa_indices) == 0 # check that construction with opaque pointer works: gep2 = llvm.GEPOp.get(opaque_ptr, ptr_typ, [1], pointee_type=builtin.i32) - assert 'elem_type' in gep2.attributes - assert 'inbounds' not in gep2.attributes + assert "elem_type" in gep2.attributes + assert "inbounds" not in gep2.attributes assert gep2.result.typ == ptr_typ assert len(gep1.rawConstantIndices.data) == 1 assert len(gep1.ssa_indices) == 0 @@ -134,15 +133,17 @@ def test_linkage_attr_unknown_str(): def test_global_op(): - global_op = llvm.GlobalOp.get(builtin.i32, - "testsymbol", - "internal", - 10, - True, - value=builtin.IntegerAttr(76, 32), - alignment=8, - unnamed_addr=0, - section="test") + global_op = llvm.GlobalOp.get( + builtin.i32, + "testsymbol", + "internal", + 10, + True, + value=builtin.IntegerAttr(76, 32), + alignment=8, + unnamed_addr=0, + section="test", + ) assert global_op.global_type == builtin.i32 assert isinstance(global_op.sym_name, builtin.StringAttr) diff --git a/tests/dialects/test_memref.py b/tests/dialects/test_memref.py index c160205f98..6999cdbca2 100644 --- a/tests/dialects/test_memref.py +++ b/tests/dialects/test_memref.py @@ -1,10 +1,29 @@ from xdsl.ir import OpResult, Block from xdsl.dialects.arith import Constant -from xdsl.dialects.builtin import StridedLayoutAttr, i32, i64, IntegerType, IndexType, ArrayAttr, DenseArrayBase, IntegerAttr, IntAttr -from xdsl.dialects.memref import (Alloc, Alloca, Dealloc, Dealloca, MemRefType, - Load, Store, UnrankedMemrefType, - ExtractAlignedPointerAsIndexOp, Subview, - Cast) +from xdsl.dialects.builtin import ( + StridedLayoutAttr, + i32, + i64, + IntegerType, + IndexType, + ArrayAttr, + DenseArrayBase, + IntegerAttr, + IntAttr, +) +from xdsl.dialects.memref import ( + Alloc, + Alloca, + Dealloc, + Dealloca, + MemRefType, + Load, + Store, + UnrankedMemrefType, + ExtractAlignedPointerAsIndexOp, + Subview, + Cast, +) from xdsl.dialects import builtin, memref, func, arith, scf from xdsl.utils.hints import isa from xdsl.utils.test_value import TestSSAValue @@ -150,8 +169,10 @@ def test_memref_ExtractAlignedPointerAsIndexOp(): def test_memref_matmul_verify(): memref_f64_rank2 = memref.MemRefType.from_element_type_and_shape( - builtin.f64, [-1, -1]) + builtin.f64, [-1, -1] + ) + # fmt: off module = builtin.ModuleOp([ func.FuncOp.from_callable( 'matmul', @@ -189,7 +210,8 @@ def test_memref_matmul_verify(): func.Return.get(out) ] ) - ]) # yapf: disable + ]) + # fmt: on # check that it verifies correctly module.verify() @@ -201,38 +223,36 @@ def test_memref_subview(): res_memref_type = MemRefType.from_element_type_and_shape(i32, [1, 1]) - offset_arg1 = Constant.from_attr(IntegerAttr.from_int_and_width(0, 64), - i64) - offset_arg2 = Constant.from_attr(IntegerAttr.from_int_and_width(0, 64), - i64) + offset_arg1 = Constant.from_attr(IntegerAttr.from_int_and_width(0, 64), i64) + offset_arg2 = Constant.from_attr(IntegerAttr.from_int_and_width(0, 64), i64) size_arg1 = Constant.from_attr(IntegerAttr.from_int_and_width(1, 64), i64) size_arg2 = Constant.from_attr(IntegerAttr.from_int_and_width(1, 64), i64) - stride_arg1 = Constant.from_attr(IntegerAttr.from_int_and_width(1, 64), - i64) - stride_arg2 = Constant.from_attr(IntegerAttr.from_int_and_width(1, 64), - i64) + stride_arg1 = Constant.from_attr(IntegerAttr.from_int_and_width(1, 64), i64) + stride_arg2 = Constant.from_attr(IntegerAttr.from_int_and_width(1, 64), i64) - operand_segment_sizes = ArrayAttr( - [IntAttr(1), IntAttr(2), - IntAttr(2), IntAttr(2)]) + operand_segment_sizes = ArrayAttr([IntAttr(1), IntAttr(2), IntAttr(2), IntAttr(2)]) static_offsets = DenseArrayBase.from_list(i64, [0, 0]) static_sizes = DenseArrayBase.from_list(i64, [1, 1]) static_strides = DenseArrayBase.from_list(i64, [1, 1]) - subview = Subview.build(operands=[ - memref_ssa_value, [offset_arg1, offset_arg2], [size_arg1, size_arg2], - [stride_arg1, stride_arg2] - ], - attributes={ - "operand_segment_sizes": operand_segment_sizes, - "static_offsets": static_offsets, - "static_sizes": static_sizes, - "static_strides": static_strides - }, - result_types=[res_memref_type]) + subview = Subview.build( + operands=[ + memref_ssa_value, + [offset_arg1, offset_arg2], + [size_arg1, size_arg2], + [stride_arg1, stride_arg2], + ], + attributes={ + "operand_segment_sizes": operand_segment_sizes, + "static_offsets": static_offsets, + "static_sizes": static_sizes, + "static_strides": static_strides, + }, + result_types=[res_memref_type], + ) assert subview.source is memref_ssa_value assert subview.offsets == (offset_arg1.result, offset_arg2.result) @@ -249,8 +269,9 @@ def test_memref_subview_constant_parameters(): shape: list[int] = [10, 10, 10] alloc = Alloc.get(element_type, 8, list(shape)) - subview = Subview.from_static_parameters(alloc, element_type, shape, - [2, 2, 2], [2, 2, 2], [3, 3, 3]) + subview = Subview.from_static_parameters( + alloc, element_type, shape, [2, 2, 2], [2, 2, 2], [3, 3, 3] + ) assert isinstance(subview, Subview) assert isinstance(subview.result.typ, MemRefType) diff --git a/tests/dialects/test_mpi.py b/tests/dialects/test_mpi.py index 71040de389..7fd50e470c 100644 --- a/tests/dialects/test_mpi.py +++ b/tests/dialects/test_mpi.py @@ -13,15 +13,12 @@ def test_mpi_baseop(): req_vec = mpi.AllocateTypeOp.get(mpi.RequestType, dest) req_obj = mpi.VectorGetOp.get(req_vec, dest) tag = Constant.from_int_and_width(1, i32) - send = mpi.Isend.get(unwrap.ptr, unwrap.len, unwrap.typ, dest, tag, - req_obj) + send = mpi.Isend.get(unwrap.ptr, unwrap.len, unwrap.typ, dest, tag, req_obj) wait = mpi.Wait.get(send.request, ignore_status=False) - recv = mpi.Irecv.get(unwrap.ptr, unwrap.len, unwrap.typ, dest, tag, - req_obj) + recv = mpi.Irecv.get(unwrap.ptr, unwrap.len, unwrap.typ, dest, tag, req_obj) test_res = mpi.Test.get(recv.request) assert wait.status is not None - source = mpi.GetStatusField.get(wait.status, - mpi.StatusTypeField.MPI_SOURCE) + source = mpi.GetStatusField.get(wait.status, mpi.StatusTypeField.MPI_SOURCE) assert unwrap.ref == alloc0.memref assert send.buffer == unwrap.ptr diff --git a/tests/dialects/test_mpi_lowering.py b/tests/dialects/test_mpi_lowering.py index ede67270be..f81e31d958 100644 --- a/tests/dialects/test_mpi_lowering.py +++ b/tests/dialects/test_mpi_lowering.py @@ -8,11 +8,9 @@ info = lower_mpi.MpiLibraryInfo() -def extract_func_call(ops: list[Operation], - name: str = 'MPI_') -> func.Call | None: +def extract_func_call(ops: list[Operation], name: str = "MPI_") -> func.Call | None: for op in ops: - if (isinstance(op, func.Call) - and op.callee.string_value().startswith(name)): + if isinstance(op, func.Call) and op.callee.string_value().startswith(name): return op @@ -52,7 +50,7 @@ def test_lower_mpi_init(): assert isinstance(call, func.Call) assert isinstance(nullop, llvm.NullOp) - assert call.callee.string_value() == 'MPI_Init' + assert call.callee.string_value() == "MPI_Init" assert len(call.arguments) == 2 assert all(arg == nullop.nullptr for arg in call.arguments) @@ -63,37 +61,38 @@ def test_lower_mpi_finalize(): assert len(result) == 0 assert len(ops) == 1 - call, = ops + (call,) = ops assert isinstance(call, func.Call) - assert call.callee.string_value() == 'MPI_Finalize' + assert call.callee.string_value() == "MPI_Finalize" assert len(call.arguments) == 0 def test_lower_mpi_wait_no_status(): - request, = CreateTestValsOp.get(mpi.RequestType()).results + (request,) = CreateTestValsOp.get(mpi.RequestType()).results ops, result = lower_mpi.LowerMpiWait(info).lower(mpi.Wait.get(request)) assert len(result) == 0 call = extract_func_call(ops) assert call is not None - assert call.callee.string_value() == 'MPI_Wait' + assert call.callee.string_value() == "MPI_Wait" assert len(call.arguments) == 2 def test_lower_mpi_wait_with_status(): - request, = CreateTestValsOp.get(mpi.RequestType()).results + (request,) = CreateTestValsOp.get(mpi.RequestType()).results ops, result = lower_mpi.LowerMpiWait(info).lower( - mpi.Wait.get(request, ignore_status=False)) + mpi.Wait.get(request, ignore_status=False) + ) assert len(result) == 1 assert result[0] is not None assert isinstance(result[0].typ, llvm.LLVMPointerType) call = extract_func_call(ops) assert call is not None - assert call.callee.string_value() == 'MPI_Wait' + assert call.callee.string_value() == "MPI_Wait" assert len(call.arguments) == 2 assert isinstance(call.arguments[1], OpResult) assert isinstance(call.arguments[1].op, llvm.AllocaOp) @@ -110,7 +109,7 @@ def test_lower_mpi_comm_rank(): # int MPI_Comm_rank(MPI_Comm comm, int *rank) check_emitted_function_signature( ops, - 'MPI_Comm_rank', + "MPI_Comm_rank", (None, llvm.LLVMPointerType), ) @@ -126,17 +125,19 @@ def test_lower_mpi_comm_size(): # int MPI_Comm_size(MPI_Comm comm, int *size) check_emitted_function_signature( ops, - 'MPI_Comm_size', + "MPI_Comm_size", (None, llvm.LLVMPointerType), ) def test_lower_mpi_send(): buff, size, dtype, dest, tag = CreateTestValsOp.get( - llvm.LLVMPointerType.typed(i32), i32, mpi.DataType(), i32, i32).results + llvm.LLVMPointerType.typed(i32), i32, mpi.DataType(), i32, i32 + ).results ops, result = lower_mpi.LowerMpiSend(info).lower( - mpi.Send.get(buff, size, dtype, dest, tag)) + mpi.Send.get(buff, size, dtype, dest, tag) + ) """ Check for function with signature like: int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest, @@ -148,18 +149,19 @@ def test_lower_mpi_send(): check_emitted_function_signature( ops, - 'MPI_Send', + "MPI_Send", (llvm.LLVMPointerType, type(i32), None, type(i32), type(i32), None), ) def test_lower_mpi_isend(): ptr, count, dtype, dest, tag, req = CreateTestValsOp.get( - llvm.LLVMPointerType.opaque(), i32, mpi.DataType(), i32, i32, - mpi.RequestType()).results + llvm.LLVMPointerType.opaque(), i32, mpi.DataType(), i32, i32, mpi.RequestType() + ).results ops, result = lower_mpi.LowerMpiIsend(info).lower( - mpi.Isend.get(ptr, count, dtype, dest, tag, req)) + mpi.Isend.get(ptr, count, dtype, dest, tag, req) + ) """ Check for function with signature like: int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest, @@ -171,94 +173,129 @@ def test_lower_mpi_isend(): check_emitted_function_signature( ops, - 'MPI_Isend', - (llvm.LLVMPointerType, type(i32), None, type(i32), type(i32), None, - mpi.RequestType), + "MPI_Isend", + ( + llvm.LLVMPointerType, + type(i32), + None, + type(i32), + type(i32), + None, + mpi.RequestType, + ), ) def test_lower_mpi_recv_no_status(): buff, count, dtype, source, tag = CreateTestValsOp.get( - llvm.LLVMPointerType.typed(i32), i32, mpi.DataType(), i32, i32).results + llvm.LLVMPointerType.typed(i32), i32, mpi.DataType(), i32, i32 + ).results """ int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Status *status) """ ops, result = lower_mpi.LowerMpiRecv(info).lower( - mpi.Recv.get(buff, count, dtype, source, tag, ignore_status=True)) + mpi.Recv.get(buff, count, dtype, source, tag, ignore_status=True) + ) assert len(result) == 0 check_emitted_function_signature( ops, - 'MPI_Recv', - (llvm.LLVMPointerType, type(i32), None, type(i32), type(i32), None, - llvm.LLVMPointerType), + "MPI_Recv", + ( + llvm.LLVMPointerType, + type(i32), + None, + type(i32), + type(i32), + None, + llvm.LLVMPointerType, + ), ) def test_lower_mpi_recv_with_status(): buff, count, dtype, source, tag = CreateTestValsOp.get( - llvm.LLVMPointerType.typed(i32), i32, mpi.DataType(), i32, i32).results + llvm.LLVMPointerType.typed(i32), i32, mpi.DataType(), i32, i32 + ).results """ int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Status *status) """ ops, result = lower_mpi.LowerMpiRecv(info).lower( - mpi.Recv.get(buff, count, dtype, source, tag, ignore_status=False)) + mpi.Recv.get(buff, count, dtype, source, tag, ignore_status=False) + ) assert len(result) == 1 check_emitted_function_signature( ops, - 'MPI_Recv', - (llvm.LLVMPointerType, type(i32), None, type(i32), type(i32), None, - llvm.LLVMPointerType), + "MPI_Recv", + ( + llvm.LLVMPointerType, + type(i32), + None, + type(i32), + type(i32), + None, + llvm.LLVMPointerType, + ), ) def test_lower_mpi_irecv(): ptr, count, dtype, source, tag, req = CreateTestValsOp.get( - llvm.LLVMPointerType.opaque(), i32, mpi.DataType(), i32, i32, - mpi.RequestType()).results + llvm.LLVMPointerType.opaque(), i32, mpi.DataType(), i32, i32, mpi.RequestType() + ).results """ int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request) """ ops, result = lower_mpi.LowerMpiIrecv(info).lower( - mpi.Irecv.get(ptr, count, dtype, source, tag, req)) + mpi.Irecv.get(ptr, count, dtype, source, tag, req) + ) # recv has no results assert len(result) == 0 check_emitted_function_signature( ops, - 'MPI_Irecv', - (llvm.LLVMPointerType, type(i32), None, type(i32), type(i32), None, - mpi.RequestType), + "MPI_Irecv", + ( + llvm.LLVMPointerType, + type(i32), + None, + type(i32), + type(i32), + None, + mpi.RequestType, + ), ) def test_lower_mpi_reduce(): ptr, count, dtype, root = CreateTestValsOp.get( - llvm.LLVMPointerType.opaque(), i32, mpi.DataType(), i32).results + llvm.LLVMPointerType.opaque(), i32, mpi.DataType(), i32 + ).results """ int MPI_Reduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm) """ ops, result = lower_mpi.LowerMpiReduce(info).lower( - mpi.Reduce.get(ptr, ptr, count, dtype, mpi.MpiOp.MPI_SUM, root)) + mpi.Reduce.get(ptr, ptr, count, dtype, mpi.MpiOp.MPI_SUM, root) + ) # reduce has no results assert len(result) == 0 check_emitted_function_signature( ops, - 'MPI_Reduce', + "MPI_Reduce", ( llvm.LLVMPointerType, llvm.LLVMPointerType, @@ -272,22 +309,24 @@ def test_lower_mpi_reduce(): def test_lower_mpi_all_reduce(): - ptr, count, dtype = CreateTestValsOp.get(llvm.LLVMPointerType.opaque(), - i32, mpi.DataType()).results + ptr, count, dtype = CreateTestValsOp.get( + llvm.LLVMPointerType.opaque(), i32, mpi.DataType() + ).results """ int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) """ ops, result = lower_mpi.LowerMpiAllreduce(info).lower( - mpi.Allreduce.get(ptr, ptr, count, dtype, mpi.MpiOp.MPI_SUM)) + mpi.Allreduce.get(ptr, ptr, count, dtype, mpi.MpiOp.MPI_SUM) + ) # allreduce has no results assert len(result) == 0 check_emitted_function_signature( ops, - 'MPI_Allreduce', + "MPI_Allreduce", ( llvm.LLVMPointerType, llvm.LLVMPointerType, @@ -301,21 +340,23 @@ def test_lower_mpi_all_reduce(): def test_lower_mpi_bcast(): ptr, count, dtype, root = CreateTestValsOp.get( - llvm.LLVMPointerType.opaque(), i32, mpi.DataType(), i32).results + llvm.LLVMPointerType.opaque(), i32, mpi.DataType(), i32 + ).results """ int MPI_Bcast(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm comm) """ ops, result = lower_mpi.LowerMpiBcast(info).lower( - mpi.Bcast.get(ptr, count, dtype, root)) + mpi.Bcast.get(ptr, count, dtype, root) + ) # bcast has no results assert len(result) == 0 check_emitted_function_signature( ops, - 'MPI_Bcast', + "MPI_Bcast", ( llvm.LLVMPointerType, type(i32), @@ -327,7 +368,7 @@ def test_lower_mpi_bcast(): def test_lower_mpi_allocate(): - count, = CreateTestValsOp.get(i32).results + (count,) = CreateTestValsOp.get(i32).results op = mpi.AllocateTypeOp.get(mpi.RequestType, count) ops, res = lower_mpi.LowerMpiAllocateType(info).lower(op) @@ -340,15 +381,16 @@ def test_lower_mpi_allocate(): def test_lower_mpi_vec_get(): - mod = builtin.ModuleOp([ - count := CreateTestValsOp.get(i32), - vec := mpi.AllocateTypeOp.get(mpi.RequestType, count), - get := mpi.VectorGetOp.get(vec, count), - ]) + mod = builtin.ModuleOp( + [ + count := CreateTestValsOp.get(i32), + vec := mpi.AllocateTypeOp.get(mpi.RequestType, count), + get := mpi.VectorGetOp.get(vec, count), + ] + ) # we have to apply this rewrite to that the argument type of the `get` # becomes an llvm.ptr - PatternRewriteWalker( - lower_mpi.LowerMpiAllocateType(info)).rewrite_module(mod) + PatternRewriteWalker(lower_mpi.LowerMpiAllocateType(info)).rewrite_module(mod) ops, res = lower_mpi.LowerMpiVectorGet(info).lower(get) @@ -376,6 +418,7 @@ def test_mpi_type_conversion(): lowering = lower_mpi.LowerMpiRecv(info) from xdsl.dialects.builtin import f64, f32, IntegerType, i32, i64, Signedness + u64 = IntegerType(64, Signedness.UNSIGNED) u32 = IntegerType(32, Signedness.UNSIGNED) @@ -389,12 +432,11 @@ def test_mpi_type_conversion(): ] for width in (8, 16): - for sign in (Signedness.UNSIGNED, Signedness.SIGNLESS, - Signedness.SIGNED): - sign_str = 'UNSIGNED_' if sign == Signedness.UNSIGNED else '' - name = 'CHAR' if width == 8 else 'SHORT' + for sign in (Signedness.UNSIGNED, Signedness.SIGNLESS, Signedness.SIGNED): + sign_str = "UNSIGNED_" if sign == Signedness.UNSIGNED else "" + name = "CHAR" if width == 8 else "SHORT" typ = IntegerType(width, sign) - checks.append((typ, getattr(info, f'MPI_{sign_str}{name}'))) + checks.append((typ, getattr(info, f"MPI_{sign_str}{name}"))) for type, target in checks: # we test a private member function here, so we need to tell pyright that that's okay diff --git a/tests/dialects/test_pdl.py b/tests/dialects/test_pdl.py index 51c8a9a37a..3873d523d0 100644 --- a/tests/dialects/test_pdl.py +++ b/tests/dialects/test_pdl.py @@ -10,59 +10,60 @@ value_type = pdl.ValueType() operation_type = pdl.OperationType() -block = Block(arg_types=[ - type_type, - attribute_type, - value_type, - operation_type, -]) +block = Block( + arg_types=[ + type_type, + attribute_type, + value_type, + operation_type, + ] +) type_val, attr_val, val_val, op_val = block.args def test_build_anc(): - anc = pdl.ApplyNativeConstraintOp.get('anc', [type_val]) + anc = pdl.ApplyNativeConstraintOp.get("anc", [type_val]) - assert anc.attributes['name'] == StringAttr('anc') - assert anc.args == (type_val, ) + assert anc.attributes["name"] == StringAttr("anc") + assert anc.args == (type_val,) def test_build_anr(): - anr = pdl.ApplyNativeRewriteOp.get('anr', [type_val], [attribute_type]) + anr = pdl.ApplyNativeRewriteOp.get("anr", [type_val], [attribute_type]) - assert anr.attributes['name'] == StringAttr('anr') - assert anr.args == (type_val, ) + assert anr.attributes["name"] == StringAttr("anr") + assert anr.args == (type_val,) assert len(anr.results) == 1 assert [r.typ for r in anr.results] == [attribute_type] def test_build_rewrite(): - r = pdl.RewriteOp.get(StringAttr('r'), - root=None, - external_args=[type_val, attr_val], - body=None) + r = pdl.RewriteOp.get( + StringAttr("r"), root=None, external_args=[type_val, attr_val], body=None + ) - assert r.attributes['name'] == StringAttr('r') + assert r.attributes["name"] == StringAttr("r") assert r.externalArgs == (type_val, attr_val) assert len(r.results) == 0 def test_build_operation_replace(): - operation = pdl.OperationOp.get(opName=StringAttr('operation'), - attributeValueNames=ArrayAttr( - [StringAttr('name')]), - operandValues=[val_val], - attributeValues=[attr_val], - typeValues=[type_val]) - - assert operation.opName == StringAttr('operation') - assert operation.attributeValueNames == ArrayAttr([StringAttr('name')]) - assert operation.operandValues == (val_val, ) - assert operation.attributeValues == (attr_val, ) - assert operation.typeValues == (type_val, ) - - replace = pdl.ReplaceOp.get(opValue=op_val, - replOperation=operation.results[0]) + operation = pdl.OperationOp.get( + opName=StringAttr("operation"), + attributeValueNames=ArrayAttr([StringAttr("name")]), + operandValues=[val_val], + attributeValues=[attr_val], + typeValues=[type_val], + ) + + assert operation.opName == StringAttr("operation") + assert operation.attributeValueNames == ArrayAttr([StringAttr("name")]) + assert operation.operandValues == (val_val,) + assert operation.attributeValues == (attr_val,) + assert operation.typeValues == (type_val,) + + replace = pdl.ReplaceOp.get(opValue=op_val, replOperation=operation.results[0]) replace.verify() assert replace.opValue == op_val @@ -74,22 +75,21 @@ def test_build_operation_replace(): assert replace.opValue == op_val assert replace.replOperation == None - assert replace.replValues == (val_val, ) + assert replace.replValues == (val_val,) with pytest.raises(VerifyException): replace = pdl.ReplaceOp.get(opValue=op_val) replace.verify() with pytest.raises(VerifyException): - replace = pdl.ReplaceOp.get(opValue=op_val, - replOperation=operation.results[0], - replValues=[val_val]) + replace = pdl.ReplaceOp.get( + opValue=op_val, replOperation=operation.results[0], replValues=[val_val] + ) replace.verify() def test_build_result(): - res = pdl.ResultOp.get(IntegerAttr.from_int_and_width(1, 32), - parent=op_val) + res = pdl.ResultOp.get(IntegerAttr.from_int_and_width(1, 32), parent=op_val) assert res.index == IntegerAttr.from_int_and_width(1, 32) assert res.parent_ == op_val diff --git a/tests/dialects/test_scf.py b/tests/dialects/test_scf.py index 8604587e99..b69fc507b7 100644 --- a/tests/dialects/test_scf.py +++ b/tests/dialects/test_scf.py @@ -58,11 +58,17 @@ def test_parallel(): def test_empty_else(): # create if without an else block: - m = ModuleOp([ - t := Constant.from_int_and_width(1, 1), - If.get(t, [], [ - Yield.get(), - ]), - ]) + m = ModuleOp( + [ + t := Constant.from_int_and_width(1, 1), + If.get( + t, + [], + [ + Yield.get(), + ], + ), + ] + ) assert len(cast(If, m.ops[1]).false_region.blocks) == 0 diff --git a/tests/dialects/test_stencil.py b/tests/dialects/test_stencil.py index 563881484d..ba668cccc9 100644 --- a/tests/dialects/test_stencil.py +++ b/tests/dialects/test_stencil.py @@ -35,8 +35,7 @@ def test_stencil_return_multiple_ResultType(): result_type_val2 = TestSSAValue(ResultType.from_type(f32)) result_type_val3 = TestSSAValue(ResultType.from_type(f32)) - return_op = ReturnOp.get( - [result_type_val1, result_type_val2, result_type_val3]) + return_op = ReturnOp.get([result_type_val1, result_type_val2, result_type_val3]) assert return_op.arg[0] is result_type_val1 assert return_op.arg[1] is result_type_val2 diff --git a/tests/dialects/test_vector.py b/tests/dialects/test_vector.py index ff5c5b464d..3bf83921a7 100644 --- a/tests/dialects/test_vector.py +++ b/tests/dialects/test_vector.py @@ -2,27 +2,42 @@ from typing import List -from xdsl.dialects.builtin import IntegerAttr, i1, i32, i64, IntegerType, IndexType, VectorType +from xdsl.dialects.builtin import ( + IntegerAttr, + i1, + i32, + i64, + IntegerType, + IndexType, + VectorType, +) from xdsl.dialects.memref import MemRefType, AnyIntegerAttr -from xdsl.dialects.vector import Broadcast, Load, Maskedload, Maskedstore, Store, FMA, Print, Createmask +from xdsl.dialects.vector import ( + Broadcast, + Load, + Maskedload, + Maskedstore, + Store, + FMA, + Print, + Createmask, +) from xdsl.ir import OpResult from xdsl.irdl import Attribute from xdsl.utils.test_value import TestSSAValue def get_MemRef_SSAVal_from_element_type_and_shape( - referenced_type: Attribute, - shape: List[int | AnyIntegerAttr]) -> TestSSAValue: - memref_type = MemRefType.from_element_type_and_shape( - referenced_type, shape) + referenced_type: Attribute, shape: List[int | AnyIntegerAttr] +) -> TestSSAValue: + memref_type = MemRefType.from_element_type_and_shape(referenced_type, shape) return TestSSAValue(memref_type) def get_Vector_SSAVal_from_element_type_and_shape( - referenced_type: Attribute, - shape: List[int | IntegerAttr[IndexType]]) -> TestSSAValue: - vector_type = VectorType.from_element_type_and_shape( - referenced_type, shape) + referenced_type: Attribute, shape: List[int | IntegerAttr[IndexType]] +) -> TestSSAValue: + vector_type = VectorType.from_element_type_and_shape(referenced_type, shape) return TestSSAValue(vector_type) @@ -61,8 +76,7 @@ def test_vector_load_i32(): def test_vector_load_i32_with_dimensions(): - memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape( - i32, [2, 3]) + memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [2, 3]) index1 = TestSSAValue(IndexType()) index2 = TestSSAValue(IndexType()) load = Load.get(memref_ssa_value, [index1, index2]) @@ -76,21 +90,20 @@ def test_vector_load_i32_with_dimensions(): def test_vector_load_verify_type_matching(): res_vector_type = VectorType.from_element_type_and_shape(i64, [1]) - memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape( - i32, [4, 5]) + memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [4, 5]) - load = Load.build(operands=[memref_ssa_value, []], - result_types=[res_vector_type]) + load = Load.build(operands=[memref_ssa_value, []], result_types=[res_vector_type]) with pytest.raises(Exception) as exc_info: load.verify() - assert exc_info.value.args[ - 0] == "MemRef element type should match the Vector element type." + assert ( + exc_info.value.args[0] + == "MemRef element type should match the Vector element type." + ) def test_vector_load_verify_indexing_exception(): - memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape( - i32, [2, 3]) + memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [2, 3]) load = Load.get(memref_ssa_value, []) @@ -111,10 +124,8 @@ def test_vector_store_i32(): def test_vector_store_i32_with_dimensions(): - vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i32, [2, 3]) - memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape( - i32, [4, 5]) + vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i32, [2, 3]) + memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [4, 5]) index1 = TestSSAValue(IndexType()) index2 = TestSSAValue(IndexType()) @@ -127,24 +138,22 @@ def test_vector_store_i32_with_dimensions(): def test_vector_store_verify_type_matching(): - vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i64, [2, 3]) - memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape( - i32, [4, 5]) + vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i64, [2, 3]) + memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [4, 5]) store = Store.get(vector_ssa_value, memref_ssa_value, []) with pytest.raises(Exception) as exc_info: store.verify() - assert exc_info.value.args[ - 0] == "MemRef element type should match the Vector element type." + assert ( + exc_info.value.args[0] + == "MemRef element type should match the Vector element type." + ) def test_vector_store_verify_indexing_exception(): - vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i32, [2, 3]) - memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape( - i32, [4, 5]) + vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i32, [2, 3]) + memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [4, 5]) store = Store.get(vector_ssa_value, memref_ssa_value, []) @@ -166,13 +175,14 @@ def test_vector_broadcast_verify_type_matching(): index1 = TestSSAValue(IndexType()) res_vector_type = VectorType.from_element_type_and_shape(i64, [1]) - broadcast = Broadcast.build(operands=[index1], - result_types=[res_vector_type]) + broadcast = Broadcast.build(operands=[index1], result_types=[res_vector_type]) with pytest.raises(Exception) as exc_info: broadcast.verify() - assert exc_info.value.args[ - 0] == "Source operand and result vector must have the same element type." + assert ( + exc_info.value.args[0] + == "Source operand and result vector must have the same element type." + ) def test_vector_fma(): @@ -182,8 +192,7 @@ def test_vector_fma(): rhs_vector_ssa_value = TestSSAValue(i32_vector_type) acc_vector_ssa_value = TestSSAValue(i32_vector_type) - fma = FMA.get(lhs_vector_ssa_value, rhs_vector_ssa_value, - acc_vector_ssa_value) + fma = FMA.get(lhs_vector_ssa_value, rhs_vector_ssa_value, acc_vector_ssa_value) assert type(fma.results[0]) is OpResult assert type(fma.results[0].typ) is VectorType @@ -199,8 +208,7 @@ def test_vector_fma_with_dimensions(): rhs_vector_ssa_value = TestSSAValue(i32_vector_type) acc_vector_ssa_value = TestSSAValue(i32_vector_type) - fma = FMA.get(lhs_vector_ssa_value, rhs_vector_ssa_value, - acc_vector_ssa_value) + fma = FMA.get(lhs_vector_ssa_value, rhs_vector_ssa_value, acc_vector_ssa_value) assert type(fma.results[0]) is OpResult assert type(fma.results[0].typ) is VectorType @@ -212,123 +220,127 @@ def test_vector_fma_with_dimensions(): def test_vector_fma_verify_res_lhs_type_matching(): i64_vector_type = VectorType.from_element_type_and_shape(i64, [1]) - i32_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i32, [1]) - i64_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i64, [1]) + i32_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i32, [1]) + i64_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i64, [1]) - fma = FMA.build(operands=[ - i32_vector_ssa_value, i64_vector_ssa_value, i64_vector_ssa_value - ], - result_types=[i64_vector_type]) + fma = FMA.build( + operands=[i32_vector_ssa_value, i64_vector_ssa_value, i64_vector_ssa_value], + result_types=[i64_vector_type], + ) with pytest.raises(Exception) as exc_info: fma.verify() - assert exc_info.value.args[ - 0] == "Result vector type must match with all source vectors. Found different types for result vector and lhs vector." + assert ( + exc_info.value.args[0] + == "Result vector type must match with all source vectors. Found different types for result vector and lhs vector." + ) def test_vector_fma_verify_res_rhs_type_matching(): i64_vector_type = VectorType.from_element_type_and_shape(i64, [1]) - i32_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i32, [1]) - i64_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i64, [1]) + i32_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i32, [1]) + i64_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i64, [1]) - fma = FMA.build(operands=[ - i64_vector_ssa_value, i32_vector_ssa_value, i64_vector_ssa_value - ], - result_types=[i64_vector_type]) + fma = FMA.build( + operands=[i64_vector_ssa_value, i32_vector_ssa_value, i64_vector_ssa_value], + result_types=[i64_vector_type], + ) with pytest.raises(Exception) as exc_info: fma.verify() - assert exc_info.value.args[ - 0] == "Result vector type must match with all source vectors. Found different types for result vector and rhs vector." + assert ( + exc_info.value.args[0] + == "Result vector type must match with all source vectors. Found different types for result vector and rhs vector." + ) def test_vector_fma_verify_res_acc_type_matching(): i64_vector_type = VectorType.from_element_type_and_shape(i64, [1]) - i32_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i32, [1]) - i64_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i64, [1]) + i32_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i32, [1]) + i64_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i64, [1]) - fma = FMA.build(operands=[ - i64_vector_ssa_value, i64_vector_ssa_value, i32_vector_ssa_value - ], - result_types=[i64_vector_type]) + fma = FMA.build( + operands=[i64_vector_ssa_value, i64_vector_ssa_value, i32_vector_ssa_value], + result_types=[i64_vector_type], + ) with pytest.raises(Exception) as exc_info: fma.verify() - assert exc_info.value.args[ - 0] == "Result vector type must match with all source vectors. Found different types for result vector and acc vector." + assert ( + exc_info.value.args[0] + == "Result vector type must match with all source vectors. Found different types for result vector and acc vector." + ) def test_vector_fma_verify_res_lhs_shape_matching(): i32_vector_type2 = VectorType.from_element_type_and_shape(i32, [4, 5]) - vector_ssa_value1 = get_Vector_SSAVal_from_element_type_and_shape( - i32, [2, 3]) - vector_ssa_value2 = get_Vector_SSAVal_from_element_type_and_shape( - i32, [4, 5]) + vector_ssa_value1 = get_Vector_SSAVal_from_element_type_and_shape(i32, [2, 3]) + vector_ssa_value2 = get_Vector_SSAVal_from_element_type_and_shape(i32, [4, 5]) fma = FMA.build( operands=[vector_ssa_value1, vector_ssa_value2, vector_ssa_value2], - result_types=[i32_vector_type2]) + result_types=[i32_vector_type2], + ) with pytest.raises(Exception) as exc_info: fma.verify() - assert exc_info.value.args[ - 0] == "Result vector shape must match with all source vector shapes. Found different shapes for result vector and lhs vector." + assert ( + exc_info.value.args[0] + == "Result vector shape must match with all source vector shapes. Found different shapes for result vector and lhs vector." + ) def test_vector_fma_verify_res_rhs_shape_matching(): i32_vector_type2 = VectorType.from_element_type_and_shape(i32, [4, 5]) - vector_ssa_value1 = get_Vector_SSAVal_from_element_type_and_shape( - i32, [2, 3]) - vector_ssa_value2 = get_Vector_SSAVal_from_element_type_and_shape( - i32, [4, 5]) + vector_ssa_value1 = get_Vector_SSAVal_from_element_type_and_shape(i32, [2, 3]) + vector_ssa_value2 = get_Vector_SSAVal_from_element_type_and_shape(i32, [4, 5]) fma = FMA.build( operands=[vector_ssa_value2, vector_ssa_value1, vector_ssa_value2], - result_types=[i32_vector_type2]) + result_types=[i32_vector_type2], + ) with pytest.raises(Exception) as exc_info: fma.verify() - assert exc_info.value.args[ - 0] == "Result vector shape must match with all source vector shapes. Found different shapes for result vector and rhs vector." + assert ( + exc_info.value.args[0] + == "Result vector shape must match with all source vector shapes. Found different shapes for result vector and rhs vector." + ) def test_vector_fma_verify_res_acc_shape_matching(): i32_vector_type2 = VectorType.from_element_type_and_shape(i32, [4, 5]) - vector_ssa_value1 = get_Vector_SSAVal_from_element_type_and_shape( - i32, [2, 3]) - vector_ssa_value2 = get_Vector_SSAVal_from_element_type_and_shape( - i32, [4, 5]) + vector_ssa_value1 = get_Vector_SSAVal_from_element_type_and_shape(i32, [2, 3]) + vector_ssa_value2 = get_Vector_SSAVal_from_element_type_and_shape(i32, [4, 5]) fma = FMA.build( operands=[vector_ssa_value2, vector_ssa_value2, vector_ssa_value1], - result_types=[i32_vector_type2]) + result_types=[i32_vector_type2], + ) with pytest.raises(Exception) as exc_info: fma.verify() - assert exc_info.value.args[ - 0] == "Result vector shape must match with all source vector shapes. Found different shapes for result vector and acc vector." + assert ( + exc_info.value.args[0] + == "Result vector shape must match with all source vector shapes. Found different shapes for result vector and acc vector." + ) def test_vector_masked_load(): memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [1]) - mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i1, [1]) + mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i1, [1]) passthrough_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i32, [1]) + i32, [1] + ) - maskedload = Maskedload.get(memref_ssa_value, [], mask_vector_ssa_value, - passthrough_vector_ssa_value) + maskedload = Maskedload.get( + memref_ssa_value, [], mask_vector_ssa_value, passthrough_vector_ssa_value + ) assert type(maskedload.results[0]) is OpResult assert type(maskedload.results[0].typ) is VectorType @@ -336,19 +348,21 @@ def test_vector_masked_load(): def test_vector_masked_load_with_dimensions(): - memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape( - i32, [4, 5]) - mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i1, [1]) + memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [4, 5]) + mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i1, [1]) passthrough_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i32, [1]) + i32, [1] + ) index1 = TestSSAValue(IndexType()) index2 = TestSSAValue(IndexType()) - maskedload = Maskedload.get(memref_ssa_value, [index1, index2], - mask_vector_ssa_value, - passthrough_vector_ssa_value) + maskedload = Maskedload.get( + memref_ssa_value, + [index1, index2], + mask_vector_ssa_value, + passthrough_vector_ssa_value, + ) assert type(maskedload.results[0]) is OpResult assert type(maskedload.results[0].typ) is VectorType @@ -358,72 +372,84 @@ def test_vector_masked_load_with_dimensions(): def test_vector_masked_load_verify_memref_res_type_matching(): memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [1]) - mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i1, [1]) + mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i1, [1]) passthrough_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i32, [1]) + i32, [1] + ) i64_res_vector_type = VectorType.from_element_type_and_shape(i64, [1]) - maskedload = Maskedload.build(operands=[ - memref_ssa_value, [], mask_vector_ssa_value, - passthrough_vector_ssa_value - ], - result_types=[i64_res_vector_type]) + maskedload = Maskedload.build( + operands=[ + memref_ssa_value, + [], + mask_vector_ssa_value, + passthrough_vector_ssa_value, + ], + result_types=[i64_res_vector_type], + ) with pytest.raises(Exception) as exc_info: maskedload.verify() - assert exc_info.value.args[ - 0] == "MemRef element type should match the result vector and passthrough vector element type. Found different element types for memref and result." + assert ( + exc_info.value.args[0] + == "MemRef element type should match the result vector and passthrough vector element type. Found different element types for memref and result." + ) def test_vector_masked_load_verify_memref_passthrough_type_matching(): memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [1]) - mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i1, [1]) + mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i1, [1]) passthrough_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i64, [1]) + i64, [1] + ) i64_res_vector_type = VectorType.from_element_type_and_shape(i32, [1]) - maskedload = Maskedload.build(operands=[ - memref_ssa_value, [], mask_vector_ssa_value, - passthrough_vector_ssa_value - ], - result_types=[i64_res_vector_type]) + maskedload = Maskedload.build( + operands=[ + memref_ssa_value, + [], + mask_vector_ssa_value, + passthrough_vector_ssa_value, + ], + result_types=[i64_res_vector_type], + ) with pytest.raises(Exception) as exc_info: maskedload.verify() - assert exc_info.value.args[ - 0] == "MemRef element type should match the result vector and passthrough vector element type. Found different element types for memref and passthrough." + assert ( + exc_info.value.args[0] + == "MemRef element type should match the result vector and passthrough vector element type. Found different element types for memref and passthrough." + ) def test_vector_masked_load_verify_indexing_exception(): - memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape( - i32, [4, 5]) - mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i1, [2]) + memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [4, 5]) + mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i1, [2]) passthrough_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i32, [1]) + i32, [1] + ) - maskedload = Maskedload.get(memref_ssa_value, [], mask_vector_ssa_value, - passthrough_vector_ssa_value) + maskedload = Maskedload.get( + memref_ssa_value, [], mask_vector_ssa_value, passthrough_vector_ssa_value + ) with pytest.raises(Exception) as exc_info: maskedload.verify() - assert exc_info.value.args[ - 0] == "Expected an index for each memref dimension." + assert exc_info.value.args[0] == "Expected an index for each memref dimension." def test_vector_masked_store(): memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [1]) - mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i1, [1]) + mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i1, [1]) value_to_store_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i32, [1]) + i32, [1] + ) - maskedstore = Maskedstore.get(memref_ssa_value, [], mask_vector_ssa_value, - value_to_store_vector_ssa_value) + maskedstore = Maskedstore.get( + memref_ssa_value, [], mask_vector_ssa_value, value_to_store_vector_ssa_value + ) assert maskedstore.memref is memref_ssa_value assert maskedstore.mask is mask_vector_ssa_value @@ -432,19 +458,21 @@ def test_vector_masked_store(): def test_vector_masked_store_with_dimensions(): - memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape( - i32, [4, 5]) - mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i1, [1]) + memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [4, 5]) + mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i1, [1]) value_to_store_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i32, [1]) + i32, [1] + ) index1 = TestSSAValue(IndexType()) index2 = TestSSAValue(IndexType()) - maskedstore = Maskedstore.get(memref_ssa_value, [index1, index2], - mask_vector_ssa_value, - value_to_store_vector_ssa_value) + maskedstore = Maskedstore.get( + memref_ssa_value, + [index1, index2], + mask_vector_ssa_value, + value_to_store_vector_ssa_value, + ) assert maskedstore.memref is memref_ssa_value assert maskedstore.mask is mask_vector_ssa_value @@ -455,36 +483,37 @@ def test_vector_masked_store_with_dimensions(): def test_vector_masked_store_verify_memref_value_to_store_type_matching(): memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [1]) - mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i1, [1]) + mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i1, [1]) value_to_store_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i64, [1]) + i64, [1] + ) - maskedstore = Maskedstore.get(memref_ssa_value, [], mask_vector_ssa_value, - value_to_store_vector_ssa_value) + maskedstore = Maskedstore.get( + memref_ssa_value, [], mask_vector_ssa_value, value_to_store_vector_ssa_value + ) with pytest.raises(Exception) as exc_info: maskedstore.verify() assert exc_info.value.args[0] == ( "MemRef element type should match the stored vector type. " - "Obtained types were i32 and i64.") + "Obtained types were i32 and i64." + ) def test_vector_masked_store_verify_indexing_exception(): - memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape( - i32, [4, 5]) - mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i1, [2]) + memref_ssa_value = get_MemRef_SSAVal_from_element_type_and_shape(i32, [4, 5]) + mask_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape(i1, [2]) value_to_store_vector_ssa_value = get_Vector_SSAVal_from_element_type_and_shape( - i32, [1]) + i32, [1] + ) - maskedstore = Maskedstore.get(memref_ssa_value, [], mask_vector_ssa_value, - value_to_store_vector_ssa_value) + maskedstore = Maskedstore.get( + memref_ssa_value, [], mask_vector_ssa_value, value_to_store_vector_ssa_value + ) with pytest.raises(Exception) as exc_info: maskedstore.verify() - assert exc_info.value.args[ - 0] == "Expected an index for each memref dimension." + assert exc_info.value.args[0] == "Expected an index for each memref dimension." def test_vector_print(): @@ -518,10 +547,11 @@ def test_vector_create_mask_with_dimensions(): def test_vector_create_mask_verify_indexing_exception(): mask_vector_type = VectorType.from_element_type_and_shape(i1, [2, 3]) - create_mask = Createmask.build(operands=[[]], - result_types=[mask_vector_type]) + create_mask = Createmask.build(operands=[[]], result_types=[mask_vector_type]) with pytest.raises(Exception) as exc_info: create_mask.verify() - assert exc_info.value.args[ - 0] == "Expected an operand value for each dimension of resultant mask." + assert ( + exc_info.value.args[0] + == "Expected an operand value for each dimension of resultant mask." + ) diff --git a/tests/filecheck/frontend/dialects/arith.py b/tests/filecheck/frontend/dialects/arith.py index af3d89c047..3e3c5c3798 100644 --- a/tests/filecheck/frontend/dialects/arith.py +++ b/tests/filecheck/frontend/dialects/arith.py @@ -7,7 +7,6 @@ p = FrontendProgram() with CodeContext(p): - # CHECK: arith.addi(%{{.*}} : !i32, %{{.*}} : !i32) def test_addi_overload(a: i32, b: i32) -> i32: return a + b @@ -98,7 +97,6 @@ def test_mulf_overload_f64(a: f64, b: f64) -> f64: try: with CodeContext(p): - # CHECK: Binary operation 'FloorDiv' is not supported by type '_Float64' which does not overload '__floordiv__'. def test_missing_floordiv_overload_f64(a: f64, b: f64) -> f64: # We expect the type error here, since the function doesn't exist on f64 @@ -111,7 +109,6 @@ def test_missing_floordiv_overload_f64(a: f64, b: f64) -> f64: try: with CodeContext(p): - # CHECK: Comparison operation 'In' is not supported by type '_Float64' which does not overload '__contains__'. def test_missing_contains_overload_f64(a: f64, b: f64) -> f64: # We expect the type error here, since the function doesn't exist on f64 diff --git a/tests/filecheck/frontend/dialects/builtin.py b/tests/filecheck/frontend/dialects/builtin.py index 8638f377d7..4afd33bbf2 100644 --- a/tests/filecheck/frontend/dialects/builtin.py +++ b/tests/filecheck/frontend/dialects/builtin.py @@ -3,7 +3,19 @@ from typing import Literal, Tuple from xdsl.frontend.program import FrontendProgram from xdsl.frontend.context import CodeContext -from xdsl.frontend.dialects.builtin import i1, i32, i64, ui32, ui64, si32, si64, index, f16, f32, f64 +from xdsl.frontend.dialects.builtin import ( + i1, + i32, + i64, + ui32, + ui64, + si32, + si64, + index, + f16, + f32, + f64, +) p = FrontendProgram() with CodeContext(p): diff --git a/tests/filecheck/frontend/dialects/invalid.py b/tests/filecheck/frontend/dialects/invalid.py index 0f72cf1813..936a6b4e19 100644 --- a/tests/filecheck/frontend/dialects/invalid.py +++ b/tests/filecheck/frontend/dialects/invalid.py @@ -11,7 +11,6 @@ try: with CodeContext(p): - # CHECK: Expected non-zero number of return types in function 'test_no_return_type', but got 0. def test_no_return_type(a: i32) -> i32: return @@ -23,7 +22,6 @@ def test_no_return_type(a: i32) -> i32: try: with CodeContext(p): - # CHECK: Type signature and the type of the return value do not match at position 0: expected i32, got i64. def test_wrong_return_type(a: i32, b: i64) -> i32: return b @@ -35,7 +33,6 @@ def test_wrong_return_type(a: i32, b: i64) -> i32: try: with CodeContext(p): - # CHECK: Expected no return types in function 'test_wrong_return_type'. def test_wrong_return_type(a: i32): return a @@ -47,7 +44,6 @@ def test_wrong_return_type(a: i32): try: with CodeContext(p): - # CHECK: Expected the same types for binary operation 'Add', but got i32 and i64. def bin_op_type_mismatch(a: i32, b: i64) -> i32: return a + b @@ -59,7 +55,6 @@ def bin_op_type_mismatch(a: i32, b: i64) -> i32: try: with CodeContext(p): - # CHECK: Expected the same types for comparison operator 'Lt', but got i32 and i64. def cmp_op_type_mismatch(a: i32, b: i64) -> i1: return a < b diff --git a/tests/filecheck/frontend/programs/invalid.py b/tests/filecheck/frontend/programs/invalid.py index 53b66848a0..ac194f6577 100644 --- a/tests/filecheck/frontend/programs/invalid.py +++ b/tests/filecheck/frontend/programs/invalid.py @@ -53,7 +53,6 @@ def foo(): with CodeContext(p): def foo(): - # CHECK-NEXT: Cannot have an inner function 'bar' inside the function 'foo'. def bar(): return @@ -69,7 +68,6 @@ def bar(): with CodeContext(p): def foo(): - @block def bb1(): # CHECK-NEXT: Cannot have an inner function 'foo' inside the block 'bb1'. @@ -87,10 +85,8 @@ def foo(): with CodeContext(p): def foo(): - @block def bb0(): - # CHECK-NEXT: Cannot have a nested block 'bb1' inside the block 'bb0'. @block def bb1(): @@ -109,7 +105,6 @@ def bb1(): with CodeContext(p): def foo(): - @block def bb0(): bb0() @@ -238,7 +233,6 @@ def bb0(): try: with CodeContext(p): - c: Const[i32] = 23 def foo(): @@ -253,7 +247,6 @@ def foo(): try: with CodeContext(p): - c: Const[i32] = 23 def foo(): @@ -268,7 +261,6 @@ def foo(): try: with CodeContext(p): - c: Const[i32] = 23 # CHECK-NEXT: Constant 'c' is already defined and cannot be used as a function/block argument name. @@ -282,11 +274,9 @@ def foo(c: i32): try: with CodeContext(p): - e: Const[i32] = 23 def foo(): - @block def bb0(): # CHECK-NEXT: Constant 'e' is already defined and cannot be assigned to. diff --git a/tests/interpreters/test_pdl_interpreter.py b/tests/interpreters/test_pdl_interpreter.py index 8ad29c62a1..9fcf4fbdd9 100644 --- a/tests/interpreters/test_pdl_interpreter.py +++ b/tests/interpreters/test_pdl_interpreter.py @@ -4,16 +4,18 @@ from xdsl.ir import MLContext, OpResult from xdsl.dialects import arith, func, pdl from xdsl.dialects.builtin import ArrayAttr, IntegerAttr, ModuleOp, StringAttr -from xdsl.pattern_rewriter import (PatternRewriter, RewritePattern, - op_type_rewrite_pattern, - PatternRewriteWalker) +from xdsl.pattern_rewriter import ( + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, + PatternRewriteWalker, +) from xdsl.interpreter import Interpreter from xdsl.interpreters.experimental.pdl import PDLFunctions class SwapInputs(RewritePattern): - @op_type_rewrite_pattern def match_and_rewrite(self, op: arith.Addi, rewriter: PatternRewriter, /): if not isinstance(op.lhs, OpResult): @@ -28,8 +30,9 @@ def test_rewrite_swap_inputs_python(): input_module = swap_arguments_input() output_module = swap_arguments_output() - PatternRewriteWalker(SwapInputs(), - apply_recursively=False).rewrite_module(input_module) + PatternRewriteWalker(SwapInputs(), apply_recursively=False).rewrite_module( + input_module + ) assert input_module.is_structurally_equivalent(output_module) @@ -53,14 +56,11 @@ def test_rewrite_swap_inputs_pdl(): def swap_arguments_input(): - @ModuleOp @Builder.implicit_region def ir_module(): - @Builder.implicit_region def impl(): - x = arith.Constant.from_int_and_width(4, 32).result y = arith.Constant.from_int_and_width(2, 32).result z = arith.Constant.from_int_and_width(1, 32).result @@ -68,20 +68,17 @@ def impl(): x_y_z = arith.Addi.get(x_y, z).result func.Return.get(x_y_z) - func.FuncOp.from_region('impl', [], [], impl) + func.FuncOp.from_region("impl", [], [], impl) return ir_module def swap_arguments_output(): - @ModuleOp @Builder.implicit_region def ir_module(): - @Builder.implicit_region def impl(): - x = arith.Constant.from_int_and_width(4, 32).result y = arith.Constant.from_int_and_width(2, 32).result z = arith.Constant.from_int_and_width(1, 32).result @@ -89,7 +86,7 @@ def impl(): z_x_y = arith.Addi.get(z, x_y).result func.Return.get(z_x_y) - func.FuncOp.from_region('impl', [], [], impl) + func.FuncOp.from_region("impl", [], [], impl) return ir_module @@ -103,27 +100,27 @@ def pattern_region(): y = pdl.OperandOp.get().value typ = pdl.TypeOp.get().result - x_y_op = pdl.OperationOp.get(StringAttr("arith.addi"), - operandValues=[x, y], - typeValues=[typ]).op - x_y = pdl.ResultOp.get(IntegerAttr.from_int_and_width(0, 32), - parent=x_y_op).val + x_y_op = pdl.OperationOp.get( + StringAttr("arith.addi"), operandValues=[x, y], typeValues=[typ] + ).op + x_y = pdl.ResultOp.get(IntegerAttr.from_int_and_width(0, 32), parent=x_y_op).val z = pdl.OperandOp.get().value - x_y_z_op = pdl.OperationOp.get(opName=StringAttr("arith.addi"), - operandValues=[x_y, z], - typeValues=[typ]).op + x_y_z_op = pdl.OperationOp.get( + opName=StringAttr("arith.addi"), operandValues=[x_y, z], typeValues=[typ] + ).op @Builder.implicit_region def rewrite_region(): - z_x_y_op = pdl.OperationOp.get(StringAttr("arith.addi"), - operandValues=[z, x_y], - typeValues=[typ]).op + z_x_y_op = pdl.OperationOp.get( + StringAttr("arith.addi"), operandValues=[z, x_y], typeValues=[typ] + ).op pdl.ReplaceOp.get(x_y_z_op, z_x_y_op) pdl.RewriteOp.get(None, x_y_z_op, [], rewrite_region) - pattern = pdl.PatternOp.get(IntegerAttr.from_int_and_width(2, 16), None, - pattern_region) + pattern = pdl.PatternOp.get( + IntegerAttr.from_int_and_width(2, 16), None, pattern_region + ) pdl_module = ModuleOp([pattern]) @@ -131,7 +128,6 @@ def rewrite_region(): class AddZero(RewritePattern): - @op_type_rewrite_pattern def match_and_rewrite(self, op: arith.Addi, rewriter: PatternRewriter, /): if not isinstance(op.rhs, OpResult): @@ -150,8 +146,9 @@ def test_rewrite_add_zero_python(): input_module = add_zero_input() output_module = add_zero_output() - PatternRewriteWalker(AddZero(), - apply_recursively=False).rewrite_module(input_module) + PatternRewriteWalker(AddZero(), apply_recursively=False).rewrite_module( + input_module + ) assert input_module.is_structurally_equivalent(output_module) @@ -178,11 +175,9 @@ def test_rewrite_add_zero_pdl(): def add_zero_input(): - @ModuleOp @Builder.implicit_region def ir_module(): - @Builder.implicit_region def impl(): x = arith.Constant.from_int_and_width(4, 32) @@ -190,24 +185,22 @@ def impl(): z = arith.Addi.get(x, y) func.Return.get(z) - func.FuncOp.from_region('impl', [], [], impl) + func.FuncOp.from_region("impl", [], [], impl) return ir_module def add_zero_output(): - @ModuleOp @Builder.implicit_region def ir_module(): - @Builder.implicit_region def impl(): x = arith.Constant.from_int_and_width(4, 32) _y = arith.Constant.from_int_and_width(0, 32) func.Return.get(x) - func.FuncOp.from_region('impl', [], [], impl) + func.FuncOp.from_region("impl", [], [], impl) return ir_module @@ -224,20 +217,21 @@ def pattern_region(): lhs = pdl.OperandOp.get().results[0] # Constant 0: i32 - zero = pdl.AttributeOp.get(value=IntegerAttr.from_int_and_width(0, 32), - valueType=pdl_i32).results[0] - rhs_op = pdl.OperationOp.get(opName=StringAttr("arith.constant"), - attributeValueNames=ArrayAttr( - [StringAttr("value")]), - attributeValues=[zero], - typeValues=[pdl_i32]).op - rhs = pdl.ResultOp.get(IntegerAttr.from_int_and_width(0, 32), - parent=rhs_op).val + zero = pdl.AttributeOp.get( + value=IntegerAttr.from_int_and_width(0, 32), valueType=pdl_i32 + ).results[0] + rhs_op = pdl.OperationOp.get( + opName=StringAttr("arith.constant"), + attributeValueNames=ArrayAttr([StringAttr("value")]), + attributeValues=[zero], + typeValues=[pdl_i32], + ).op + rhs = pdl.ResultOp.get(IntegerAttr.from_int_and_width(0, 32), parent=rhs_op).val # LHS + 0 - sum = pdl.OperationOp.get(StringAttr("arith.addi"), - operandValues=[lhs, rhs], - typeValues=[pdl_i32]).op + sum = pdl.OperationOp.get( + StringAttr("arith.addi"), operandValues=[lhs, rhs], typeValues=[pdl_i32] + ).op @Builder.implicit_region def rewrite_region(): @@ -245,8 +239,9 @@ def rewrite_region(): pdl.RewriteOp.get(None, sum, [], rewrite_region) - pattern = pdl.PatternOp.get(IntegerAttr.from_int_and_width(2, 16), None, - pattern_region) + pattern = pdl.PatternOp.get( + IntegerAttr.from_int_and_width(2, 16), None, pattern_region + ) pdl_module = ModuleOp([pattern]) diff --git a/tests/rewriting/composable_rewriting/immutable_ir/test_immutable_ir.py b/tests/rewriting/composable_rewriting/immutable_ir/test_immutable_ir.py index 51daa9ab40..d2f595cba1 100644 --- a/tests/rewriting/composable_rewriting/immutable_ir/test_immutable_ir.py +++ b/tests/rewriting/composable_rewriting/immutable_ir/test_immutable_ir.py @@ -6,44 +6,39 @@ from xdsl.dialects.func import Func from xdsl.dialects.arith import Arith from xdsl.dialects.cf import Cf -from xdsl.rewriting.composable_rewriting.immutable_ir.immutable_ir import get_immutable_copy # noqa +from xdsl.rewriting.composable_rewriting.immutable_ir.immutable_ir import ( + get_immutable_copy, +) # noqa -program_region = \ -"""builtin.module() { +program_region = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 1 : !i32] } """ -program_region_2 = \ -"""builtin.module() { +program_region_2 = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 2 : !i32] } """ -program_region_2_diff_name = \ -"""builtin.module() { +program_region_2_diff_name = """builtin.module() { %cst : !i32 = arith.constant() ["value" = 2 : !i32] } """ -program_region_2_diff_type = \ -"""builtin.module() { +program_region_2_diff_type = """builtin.module() { %0 : !i64 = arith.constant() ["value" = 2 : !i64] } """ -program_add = \ -"""builtin.module() { +program_add = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 1 : !i32] %1 : !i32 = arith.constant() ["value" = 2 : !i32] %2 : !i32 = arith.addi(%0 : !i32, %1 : !i32) } """ -program_add_2 = \ -"""builtin.module() { +program_add_2 = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 1 : !i32] %1 : !i32 = arith.constant() ["value" = 2 : !i32] %2 : !i32 = arith.addi(%1 : !i32, %0 : !i32) } """ -program_func = \ -"""builtin.module() { +program_func = """builtin.module() { func.func() ["sym_name" = "test", "type" = !fun<[!i32, !i32], [!i32]>, "sym_visibility" = "private"] { ^0(%0 : !i32, %1 : !i32): %2 : !i32 = arith.addi(%0 : !i32, %1 : !i32) @@ -51,8 +46,7 @@ } } """ -program_successors = \ -"""builtin.module() { +program_successors = """builtin.module() { func.func() ["sym_name" = "unconditional_br", "function_type" = !fun<[], []>, "sym_visibility" = "private"] { ^0: cf.br() (^1) @@ -63,11 +57,19 @@ """ -@pytest.mark.parametrize("program_str", [(program_region), (program_region_2), - (program_region_2_diff_type), - (program_region_2_diff_name), - (program_add), (program_add_2), - (program_func), (program_successors)]) +@pytest.mark.parametrize( + "program_str", + [ + (program_region), + (program_region_2), + (program_region_2_diff_type), + (program_region_2_diff_name), + (program_add), + (program_add_2), + (program_func), + (program_successors), + ], +) def test_immutable_ir(program_str: str): ctx = MLContext() ctx.register_dialect(Builtin) diff --git a/tests/test_attribute_definition.py b/tests/test_attribute_definition.py index 1226c14fe8..8e1018766a 100644 --- a/tests/test_attribute_definition.py +++ b/tests/test_attribute_definition.py @@ -10,9 +10,16 @@ import pytest from xdsl.ir import Attribute, Data, ParametrizedAttribute -from xdsl.irdl import (AttrConstraint, GenericData, ParameterDef, - irdl_attr_definition, irdl_to_attr_constraint, AnyAttr, - BaseAttr, ParamAttrDef) +from xdsl.irdl import ( + AttrConstraint, + GenericData, + ParameterDef, + irdl_attr_definition, + irdl_to_attr_constraint, + AnyAttr, + BaseAttr, + ParamAttrDef, +) from xdsl.parser import BaseParser from xdsl.printer import Printer from xdsl.utils.exceptions import PyRDLAttrDefinitionError, VerifyException @@ -25,12 +32,13 @@ @irdl_attr_definition class BoolData(Data[bool]): """An attribute holding a boolean value.""" + name = "bool" @staticmethod def parse_parameter(parser: BaseParser) -> bool: - val = parser.tokenizer.next_token_of_pattern('(True|False)') - if val is None or val.text not in ('True', 'False'): + val = parser.tokenizer.next_token_of_pattern("(True|False)") + if val is None or val.text not in ("True", "False"): parser.raise_error("Expected True or False literal") if val.text == "True": return True @@ -43,6 +51,7 @@ def print_parameter(self, printer: Printer): @irdl_attr_definition class IntData(Data[int]): """An attribute holding an integer value.""" + name = "int" @staticmethod @@ -56,6 +65,7 @@ def print_parameter(self, printer: Printer): @irdl_attr_definition class StringData(Data[str]): """An attribute holding a string value.""" + name = "str" @staticmethod @@ -82,8 +92,9 @@ def test_simple_data_verifier_failure(): """ with pytest.raises(VerifyException) as e: BoolData(2) # type: ignore - assert e.value.args[0] == ("bool data attribute expected type " - ", but given.") + assert e.value.args[0] == ( + "bool data attribute expected type " ", but given." + ) class IntListMissingVerifierData(Data[list[int]]): @@ -92,6 +103,7 @@ class IntListMissingVerifierData(Data[list[int]]): The definition should fail, since no verifier is provided, and the Data type parameter is not a class. """ + name = "missing_verifier_data" @staticmethod @@ -111,9 +123,9 @@ def test_data_with_non_class_param_missing_verifier_failure(): # Python 3.10 and 3.11 have different error messages assert e.value.args[0] in [ - 'In IntListMissingVerifierData definition: ' + "In IntListMissingVerifierData definition: " 'Cannot infer "verify" method. Type parameter of Data has type GenericAlias.', - 'In IntListMissingVerifierData definition: ' + "In IntListMissingVerifierData definition: " 'Cannot infer "verify" method. Type parameter of Data is not a class.', ] @@ -123,6 +135,7 @@ class IntListData(Data[list[int]]): """ An attribute holding a list of integers. """ + name = "int_list" @staticmethod @@ -139,8 +152,7 @@ def verify(self) -> None: raise VerifyException("int_list data should hold a list.") for elem in self.data: if not isinstance(elem, int): - raise VerifyException( - "int_list list elements should be integers.") + raise VerifyException("int_list list elements should be integers.") def test_non_class_data(): @@ -233,14 +245,13 @@ def test_union_constraint_fail(): class PositiveIntConstr(AttrConstraint): - def verify(self, attr: Attribute) -> None: if not isinstance(attr, IntData): raise VerifyException( - f"Expected {IntData.name} attribute, but got {attr.name}.") + f"Expected {IntData.name} attribute, but got {attr.name}." + ) if attr.data <= 0: - raise VerifyException( - f"Expected positive integer, got {attr.data}.") + raise VerifyException(f"Expected positive integer, got {attr.data}.") @irdl_attr_definition @@ -350,8 +361,7 @@ def test_nested_generic_constraint(): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue( - ) == "#nested_param_wrapper<#int_or_bool_generic<#int<42>>>" + assert stream.getvalue() == "#nested_param_wrapper<#int_or_bool_generic<#int<42>>>" def test_nested_generic_constraint_fail(): @@ -368,8 +378,7 @@ def test_nested_generic_constraint_fail(): class NestedParamConstrAttr(ParametrizedAttribute): name = "nested_param_constr" - param: ParameterDef[NestedParamWrapperAttr[Annotated[IntData, - PositiveIntConstr()]]] + param: ParameterDef[NestedParamWrapperAttr[Annotated[IntData, PositiveIntConstr()]]] def test_nested_param_attr_constraint(): @@ -377,12 +386,15 @@ def test_nested_param_attr_constraint(): Test the verifier of a nested parametric constraint. """ attr = NestedParamConstrAttr( - [NestedParamWrapperAttr([ParamWrapperAttr([IntData(42)])])]) + [NestedParamWrapperAttr([ParamWrapperAttr([IntData(42)])])] + ) stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue( - ) == "#nested_param_constr<#nested_param_wrapper<#int_or_bool_generic<#int<42>>>>" + assert ( + stream.getvalue() + == "#nested_param_constr<#nested_param_wrapper<#int_or_bool_generic<#int<42>>>>" + ) def test_nested_param_attr_constraint_fail(): @@ -391,7 +403,8 @@ def test_nested_param_attr_constraint_fail(): """ with pytest.raises(Exception) as e: NestedParamConstrAttr( - [NestedParamWrapperAttr([ParamWrapperAttr([IntData(-42)])])]) + [NestedParamWrapperAttr([ParamWrapperAttr([IntData(-42)])])] + ) assert e.value.args[0] == "Expected positive integer, got -42." @@ -433,7 +446,8 @@ def test_data_with_generic_missing_generic_data_failure(): assert e.value.args[0] == ( "Generic `Data` type 'missing_genericdata' cannot be converted to " "an attribute constraint. Consider making it inherit from " - "`GenericData` instead of `Data`.") + "`GenericData` instead of `Data`." + ) A = TypeVar("A", bound=Attribute) @@ -445,6 +459,7 @@ class DataListAttr(AttrConstraint): A constraint that enforces that the elements of a ListData all respect a constraint. """ + elem_constr: AttrConstraint def verify(self, attr: Attribute) -> None: @@ -480,19 +495,20 @@ def verify(self) -> None: raise VerifyException( f"Wrong type given to attribute {self.name}: got" f" {type(self.data)}, but expected list of" - " attributes.") + " attributes." + ) for idx, val in enumerate(self.data): if not isinstance(val, Attribute): raise VerifyException( f"{self.name} data expects attribute list, but element " - f"{idx} is of type {type(val)}.") + f"{idx} is of type {type(val)}." + ) AnyListData: TypeAlias = ListData[Attribute] class Test_generic_data_verifier: - def test_generic_data_verifier(self): """ Test that a GenericData can be created. @@ -501,8 +517,7 @@ def test_generic_data_verifier(self): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue( - ) == "#list<[#bool, #list<[#bool]>]>" + assert stream.getvalue() == "#list<[#bool, #list<[#bool]>]>" def test_generic_data_verifier_fail(self): """ @@ -510,8 +525,10 @@ def test_generic_data_verifier_fail(self): """ with pytest.raises(VerifyException) as e: ListData([0]) # type: ignore - assert e.value.args[0] == ("list data expects attribute list, but" - " element 0 is of type .") + assert e.value.args[0] == ( + "list data expects attribute list, but" + " element 0 is of type ." + ) def test_generic_data_verifier_fail_II(self): """ @@ -521,7 +538,8 @@ def test_generic_data_verifier_fail_II(self): ListData((0)) # type: ignore assert e.value.args[0] == ( "Wrong type given to attribute list: " - "got , but expected list of attributes.") + "got , but expected list of attributes." + ) @irdl_attr_definition @@ -539,8 +557,7 @@ def test_generic_data_wrapper_verifier(): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue( - ) == "#list_wrapper<#list<[#bool, #bool]>>" + assert stream.getvalue() == "#list_wrapper<#list<[#bool, #bool]>>" def test_generic_data_wrapper_verifier_failure(): @@ -549,11 +566,8 @@ def test_generic_data_wrapper_verifier_failure(): the verifier when constraints are not satisfied. """ with pytest.raises(VerifyException) as e: - ListDataWrapper( - [ListData([BoolData(True), - ListData([BoolData(False)])])]) - assert e.value.args[ - 0] == "#list<[#bool]> should be of base attribute bool" + ListDataWrapper([ListData([BoolData(True), ListData([BoolData(False)])])]) + assert e.value.args[0] == "#list<[#bool]> should be of base attribute bool" @irdl_attr_definition @@ -568,12 +582,15 @@ def test_generic_data_no_generics_wrapper_verifier(): Test that GenericType can be used in constraints without a parameter. """ attr = ListDataNoGenericsWrapper( - [ListData([BoolData(True), ListData([BoolData(False)])])]) + [ListData([BoolData(True), ListData([BoolData(False)])])] + ) stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue( - ) == "#list_no_generics_wrapper<#list<[#bool, #list<[#bool]>]>>" + assert ( + stream.getvalue() + == "#list_no_generics_wrapper<#list<[#bool, #list<[#bool]>]>>" + ) ################################################################################ @@ -597,8 +614,8 @@ def test_irdl_definition(): """Test that we can get the IRDL definition of a parametrized attribute.""" assert ParamAttrDefAttr.irdl_definition == ParamAttrDef( - "test.param_attr_def_attr", [("arg1", AnyAttr()), - ("arg2", BaseAttr(BoolData))]) + "test.param_attr_def_attr", [("arg1", AnyAttr()), ("arg2", BaseAttr(BoolData))] + ) class InvalidTypedFieldTestAttr(ParametrizedAttribute): @@ -637,8 +654,9 @@ def __init__(self, param: int | str): elif isinstance(param, str): super().__init__([StringData(param)]) else: - raise TypeError("Expected `int` or `str` type in " - "OveriddenInitAttr constructor") + raise TypeError( + "Expected `int` or `str` type in " "OveriddenInitAttr constructor" + ) def test_generic_constructor(): diff --git a/tests/test_frontend_op_inserter.py b/tests/test_frontend_op_inserter.py index 2b1a146675..a110f4535c 100644 --- a/tests/test_frontend_op_inserter.py +++ b/tests/test_frontend_op_inserter.py @@ -20,8 +20,10 @@ def test_raises_exception_on_op_with_no_regions(): op_with_no_region = Constant.from_int_and_width(1, i32) with pytest.raises(FrontendProgramException) as err: inserter.set_insertion_point_from_op(op_with_no_region) - assert err.value.msg == ("Trying to set the insertion point for operation" - " 'arith.constant' with no regions.") + assert err.value.msg == ( + "Trying to set the insertion point for operation" + " 'arith.constant' with no regions." + ) def test_raises_exception_on_op_with_no_blocks(): @@ -31,7 +33,8 @@ def test_raises_exception_on_op_with_no_blocks(): inserter.set_insertion_point_from_op(op_with_no_region) assert err.value.msg == ( "Trying to set the insertion point for operation" - " 'affine.for' with no blocks in its last region.") + " 'affine.for' with no blocks in its last region." + ) def test_raises_exception_on_op_with_no_blocks_II(): @@ -40,8 +43,8 @@ def test_raises_exception_on_op_with_no_blocks_II(): with pytest.raises(FrontendProgramException) as err: inserter.set_insertion_point_from_region(empty_region) assert err.value.msg == ( - "Trying to set the insertion point from the region without" - " blocks.") + "Trying to set the insertion point from the region without" " blocks." + ) def test_inserts_ops(): diff --git a/tests/test_frontend_op_resolver.py b/tests/test_frontend_op_resolver.py index e74e5b8995..58f4b80f9e 100644 --- a/tests/test_frontend_op_resolver.py +++ b/tests/test_frontend_op_resolver.py @@ -2,8 +2,15 @@ import xdsl.frontend.dialects.builtin as builtin from xdsl.dialects.arith import Addi, Constant, Mulf -from xdsl.dialects.builtin import (Float32Type, Float64Type, IntegerType, i32, - i64, f32, f64) +from xdsl.dialects.builtin import ( + Float32Type, + Float64Type, + IntegerType, + i32, + i64, + f32, + f64, +) from xdsl.frontend.exception import FrontendProgramException from xdsl.frontend.op_resolver import OpResolver @@ -11,13 +18,18 @@ def test_raises_exception_on_unknown_op(): with pytest.raises(FrontendProgramException) as err: _ = OpResolver.resolve_op("xdsl.frontend.dialects.arith", "unknown") - assert err.value.msg == "Internal failure: operation 'unknown' does not exist in module 'xdsl.frontend.dialects.arith'." + assert ( + err.value.msg + == "Internal failure: operation 'unknown' does not exist in module 'xdsl.frontend.dialects.arith'." + ) def test_raises_exception_on_unknown_overload(): with pytest.raises(FrontendProgramException) as err: _ = OpResolver.resolve_op_overload("__unknown__", builtin._Integer) - assert err.value.msg == "Internal failure: '_Integer' does not overload '__unknown__'." + assert ( + err.value.msg == "Internal failure: '_Integer' does not overload '__unknown__'." + ) def test_resolves_ops(): diff --git a/tests/test_frontend_python_code_check.py b/tests/test_frontend_python_code_check.py index 6ee7ecdf2e..dd84f69d38 100644 --- a/tests/test_frontend_python_code_check.py +++ b/tests/test_frontend_python_code_check.py @@ -6,9 +6,7 @@ def test_const_correctly_evaluated_I(): - - src = \ -""" + src = """ a: Const[i32] = 2 ** 5 x = a """ @@ -18,9 +16,7 @@ def test_const_correctly_evaluated_I(): def test_const_correctly_evaluated_II(): - - src = \ -""" + src = """ a: Const[i32] = 4 x: i64 = a + 2 """ @@ -30,9 +26,7 @@ def test_const_correctly_evaluated_II(): def test_const_correctly_evaluated_III(): - - src = \ -""" + src = """ a: Const[i32] = 4 b: Const[i32] = len([1, 2, 3, 4]) x: Const[i32] = a + b @@ -44,9 +38,7 @@ def test_const_correctly_evaluated_III(): def test_const_correctly_evaluated_IV(): - - src = \ -""" + src = """ a: Const[i32] = 4 def foo(y: i32): x: i32 = a + y @@ -57,9 +49,7 @@ def foo(y: i32): def test_const_correctly_evaluated_V(): - - src = \ -""" + src = """ a: Const[i32] = 4 b: Const[i32] = 4 def foo(y: i32): @@ -72,8 +62,7 @@ def foo(y: i32): def test_raises_exception_on_assignemnt_to_const_I(): - src = \ -""" + src = """ a: Const[i32] = 2 ** 5 a = 34 """ @@ -84,8 +73,7 @@ def test_raises_exception_on_assignemnt_to_const_I(): def test_raises_exception_on_assignemnt_to_const_II(): - src = \ -""" + src = """ x: Const[i32] = 100 def foo(): x: i32 = 2 @@ -98,8 +86,7 @@ def foo(): def test_raises_exception_on_assignemnt_to_const_III(): - src = \ -""" + src = """ y: Const[i32] = 100 @block def bb0(): @@ -114,8 +101,7 @@ def bb0(): def test_raises_exception_on_assignemnt_to_const_IV(): - src = \ -""" + src = """ z: Const[i32] = 100 def foo(x: i32): @block @@ -126,12 +112,14 @@ def bb0(z: i32): stmts = ast.parse(src).body with pytest.raises(CodeGenerationException) as err: CheckAndInlineConstants.run(stmts, __file__) - assert err.value.msg == "Constant 'z' is already defined and cannot be used as a function/block argument name." + assert ( + err.value.msg + == "Constant 'z' is already defined and cannot be used as a function/block argument name." + ) def test_raises_exception_on_duplicate_const(): - src = \ -""" + src = """ z: Const[i32] = 100 z: Const[i32] = 2 """ @@ -142,22 +130,26 @@ def test_raises_exception_on_duplicate_const(): def test_raises_exception_on_evaluation_error_I(): - src = \ -""" + src = """ z: Const[i32] = 23 / 0 """ stmts = ast.parse(src).body with pytest.raises(CodeGenerationException) as err: CheckAndInlineConstants.run(stmts, __file__) - assert err.value.msg == "Non-constant expression cannot be assigned to constant variable 'z' or cannot be evaluated." + assert ( + err.value.msg + == "Non-constant expression cannot be assigned to constant variable 'z' or cannot be evaluated." + ) def test_raises_exception_on_evaluation_error_II(): - src = \ -""" + src = """ a: Const[i32] = x + 12 """ stmts = ast.parse(src).body with pytest.raises(CodeGenerationException) as err: CheckAndInlineConstants.run(stmts, __file__) - assert err.value.msg == "Non-constant expression cannot be assigned to constant variable 'a' or cannot be evaluated." + assert ( + err.value.msg + == "Non-constant expression cannot be assigned to constant variable 'a' or cannot be evaluated." + ) diff --git a/tests/test_frontend_type_conversion.py b/tests/test_frontend_type_conversion.py index f6d14eb0a8..ad13217767 100644 --- a/tests/test_frontend_type_conversion.py +++ b/tests/test_frontend_type_conversion.py @@ -18,7 +18,6 @@ class A(ParametrizedAttribute): class _A(_FrontendType): - @staticmethod def to_xdsl() -> Callable[..., Any]: return A @@ -41,7 +40,6 @@ class D(ParametrizedAttribute): class _D(Generic[T], _FrontendType): - @staticmethod def to_xdsl() -> Callable[..., Any]: return D @@ -70,7 +68,9 @@ def test_raises_exception_on_non_frontend_type_I(): with pytest.raises(CodeGenerationException) as err: type_converter.convert_type_hint(type_hint) - assert err.value.msg == "Unknown type hint for type 'b' inside 'ast.Name' expression." + assert ( + err.value.msg == "Unknown type hint for type 'b' inside 'ast.Name' expression." + ) def test_raises_exception_on_non_frontend_type_II(): @@ -88,7 +88,10 @@ def test_raises_exception_on_nontrivial_generics(): with pytest.raises(CodeGenerationException) as err: type_converter.convert_type_hint(type_hint) - assert err.value.msg == "Expected 1 type argument for generic type 'd12', got 2 type arguments instead." + assert ( + err.value.msg + == "Expected 1 type argument for generic type 'd12', got 2 type arguments instead." + ) def test_type_conversion_caches_type(): diff --git a/tests/test_immutable_list.py b/tests/test_immutable_list.py index 561cf39111..8b41877357 100644 --- a/tests/test_immutable_list.py +++ b/tests/test_immutable_list.py @@ -163,4 +163,4 @@ def test_eq(): i, j, k = 1, 2, 3 list0: IList[int] = IList([i, j, k]) list1: IList[int] = IList([i, j, k]) - assert list0 == list1 \ No newline at end of file + assert list0 == list1 diff --git a/tests/test_interpreter.py b/tests/test_interpreter.py index c1137dc3ad..72aaa8d680 100644 --- a/tests/test_interpreter.py +++ b/tests/test_interpreter.py @@ -7,7 +7,6 @@ def test_import_functions(): - @dataclass class A(InterpreterFunctions): pass @@ -24,4 +23,4 @@ class B(InterpreterFunctions): with pytest.raises(ValueError) as e: i.register_implementations(A()) - assert e.value.args[0] == 'Use `@register_impls` on class A' + assert e.value.args[0] == "Use `@register_impls` on class A" diff --git a/tests/test_ir.py b/tests/test_ir.py index 047c7a407a..1e2057e247 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -141,42 +141,35 @@ def test_op_clone_with_regions(): ##################### Testing is_structurally_equal ##################### -program_region = \ -"""builtin.module() { +program_region = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 1 : !i32] } """ -program_region_2 = \ -"""builtin.module() { +program_region_2 = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 2 : !i32] } """ -program_region_2_diff_name = \ -"""builtin.module() { +program_region_2_diff_name = """builtin.module() { %cst : !i32 = arith.constant() ["value" = 2 : !i32] } """ -program_region_2_diff_type = \ -"""builtin.module() { +program_region_2_diff_type = """builtin.module() { %0 : !i64 = arith.constant() ["value" = 2 : !i64] } """ -program_add = \ -"""builtin.module() { +program_add = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 1 : !i32] %1 : !i32 = arith.constant() ["value" = 2 : !i32] %2 : !i32 = arith.addi(%0 : !i32, %1 : !i32) } """ -program_add_2 = \ -"""builtin.module() { +program_add_2 = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 1 : !i32] %1 : !i32 = arith.constant() ["value" = 2 : !i32] %2 : !i32 = arith.addi(%1 : !i32, %0 : !i32) } """ -program_func = \ -"""builtin.module() { +program_func = """builtin.module() { func.func() ["sym_name" = "test", "type" = !fun<[!i32, !i32], [!i32]>, "sym_visibility" = "private"] { ^0(%0 : !i32, %1 : !i32): %2 : !i32 = arith.addi(%0 : !i32, %1 : !i32) @@ -184,8 +177,7 @@ def test_op_clone_with_regions(): } } """ -program_successors = \ -""" +program_successors = """ func.func() ["sym_name" = "unconditional_br", "function_type" = !fun<[], []>, "sym_visibility" = "private"] { ^0: cf.br() (^1) @@ -197,20 +189,23 @@ def test_op_clone_with_regions(): @pytest.mark.parametrize( "args, expected_result", - [([program_region, program_region], True), - ([program_region_2, program_region_2], True), - ([program_region_2_diff_type, program_region_2_diff_type], True), - ([program_region_2_diff_name, program_region_2_diff_name], True), - ([program_region, program_region_2], False), - ([program_region_2, program_region_2_diff_type], False), - ([program_region_2, program_region_2_diff_name], True), - ([program_add, program_add], True), - ([program_add_2, program_add_2], True), - ([program_add, program_add_2], False), - ([program_func, program_func], True), - ([program_successors, program_successors], True), - ([program_successors, program_func], False), - ([program_successors, program_add], False)]) + [ + ([program_region, program_region], True), + ([program_region_2, program_region_2], True), + ([program_region_2_diff_type, program_region_2_diff_type], True), + ([program_region_2_diff_name, program_region_2_diff_name], True), + ([program_region, program_region_2], False), + ([program_region_2, program_region_2_diff_type], False), + ([program_region_2, program_region_2_diff_name], True), + ([program_add, program_add], True), + ([program_add_2, program_add_2], True), + ([program_add, program_add_2], False), + ([program_func, program_func], True), + ([program_successors, program_successors], True), + ([program_successors, program_func], False), + ([program_successors, program_add], False), + ], +) def test_is_structurally_equivalent(args: list[str], expected_result: bool): ctx = MLContext() ctx.register_dialect(Builtin) @@ -228,8 +223,7 @@ def test_is_structurally_equivalent(args: list[str], expected_result: bool): def test_is_structurally_equivalent_incompatible_ir_nodes(): - program_func = \ - """builtin.module() { + program_func = """builtin.module() { func.func() ["sym_name" = "test", "type" = !fun<[!i32, !i32], [!i32]>, "sym_visibility" = "private"] { ^0(%0 : !i32, %1 : !i32): %2 : !i32 = arith.addi(%0 : !i32, %1 : !i32) @@ -252,47 +246,61 @@ def test_is_structurally_equivalent_incompatible_ir_nodes(): assert isinstance(program, ModuleOp) assert program.is_structurally_equivalent(program.regions[0]) == False - assert program.is_structurally_equivalent( - program.regions[0].blocks[0]) == False + assert program.is_structurally_equivalent(program.regions[0].blocks[0]) == False assert program.regions[0].is_structurally_equivalent(program) == False - assert program.regions[0].blocks[0].is_structurally_equivalent( - program) == False - assert program.ops[0].regions[0].blocks[0].ops[ - 0].is_structurally_equivalent( - program.ops[0].regions[0].blocks[0].ops[1]) == False - assert program.ops[0].regions[0].blocks[0].is_structurally_equivalent( - program.ops[0].regions[0].blocks[1]) == False + assert program.regions[0].blocks[0].is_structurally_equivalent(program) == False + assert ( + program.ops[0] + .regions[0] + .blocks[0] + .ops[0] + .is_structurally_equivalent(program.ops[0].regions[0].blocks[0].ops[1]) + == False + ) + assert ( + program.ops[0] + .regions[0] + .blocks[0] + .is_structurally_equivalent(program.ops[0].regions[0].blocks[1]) + == False + ) def test_descriptions(): a = Constant.from_int_and_width(1, 32) - assert str(a.value) == '1 : i32' - assert f'{a.value}' == '1 : i32' + assert str(a.value) == "1 : i32" + assert f"{a.value}" == "1 : i32" assert str(a) == '%0 : !i32 = arith.constant() ["value" = 1 : !i32]' - assert f'{a}' == 'Constant(%0 : !i32 = arith.constant() ["value" = 1 : !i32])' + assert f"{a}" == 'Constant(%0 : !i32 = arith.constant() ["value" = 1 : !i32])' m = ModuleOp([a]) - assert str(m) == '''\ + assert ( + str(m) + == """\ builtin.module() { %0 : !i32 = arith.constant() ["value" = 1 : !i32] -}''' +}""" + ) - assert f'{m}' == '''\ + assert ( + f"{m}" + == """\ ModuleOp( \tbuiltin.module() { \t %0 : !i32 = arith.constant() ["value" = 1 : !i32] \t} -)''' +)""" + ) # ToDo: Create this op without IRDL itself, since it tests fine grained # stuff which is supposed to be used with IRDL or PDL. @irdl_op_definition class CustomOpWithMultipleRegions(IRDLOperation): - name = 'test.custom_op_with_multiple_regions' + name = "test.custom_op_with_multiple_regions" region: VarRegion @@ -308,7 +316,8 @@ def test_region_index_fetch(): region3 = Region([Block([d])]) op = CustomOpWithMultipleRegions.build( - regions=[[region0, region1, region2, region3]]) + regions=[[region0, region1, region2, region3]] + ) assert op.get_region_index(region0) == 0 assert op.get_region_index(region1) == 1 @@ -340,8 +349,7 @@ def test_detach_region(): region1 = Region([Block([b])]) region2 = Region([Block([c])]) - op = CustomOpWithMultipleRegions.build( - regions=[[region0, region1, region2]]) + op = CustomOpWithMultipleRegions.build(regions=[[region0, region1, region2]]) assert op.detach_region(1) == region1 assert op.detach_region(region0) == region0 @@ -351,7 +359,7 @@ def test_detach_region(): @irdl_op_definition class CustomVerify(IRDLOperation): - name = 'test.custom_verify_op' + name = "test.custom_verify_op" val: Annotated[Operand, i64] @staticmethod @@ -359,7 +367,7 @@ def get(val: SSAValue): return CustomVerify.build(operands=[val]) def verify_(self): - raise Exception('Custom Verification Check') + raise Exception("Custom Verification Check") def test_op_custom_verify_is_called(): @@ -367,7 +375,7 @@ def test_op_custom_verify_is_called(): b = CustomVerify.get(a.result) with pytest.raises(Exception) as e: b.verify() - assert e.value.args[0] == 'Custom Verification Check' + assert e.value.args[0] == "Custom Verification Check" def test_op_custom_verify_is_done_last(): @@ -376,9 +384,11 @@ def test_op_custom_verify_is_done_last(): b = CustomVerify.get(a.result) with pytest.raises(Exception) as e: b.verify() - assert e.value.args[0] != 'Custom Verification Check' - assert e.value.args[0] == \ - 'test.custom_verify_op operation does not verify\n\ntest.custom_verify_op(% : !i32)\n\n' + assert e.value.args[0] != "Custom Verification Check" + assert ( + e.value.args[0] + == "test.custom_verify_op operation does not verify\n\ntest.custom_verify_op(% : !i32)\n\n" + ) def test_replace_operand(): diff --git a/tests/test_is_satisfying_hint.py b/tests/test_is_satisfying_hint.py index 811618b6fb..1a8b087bb2 100644 --- a/tests/test_is_satisfying_hint.py +++ b/tests/test_is_satisfying_hint.py @@ -4,8 +4,14 @@ from xdsl.irdl import ParameterDef, irdl_attr_definition from xdsl.utils.hints import isa -from xdsl.dialects.builtin import (ArrayAttr, IndexType, IntAttr, FloatData, - IntegerAttr, IntegerType) +from xdsl.dialects.builtin import ( + ArrayAttr, + IndexType, + IntAttr, + FloatData, + IntegerAttr, + IntegerType, +) class Class1: @@ -151,9 +157,9 @@ def test_tuple_hint_correct(): """ Test that tuple hints work correcly on non-empty tuples of the right type. """ - assert isa((42, ), tuple[int]) + assert isa((42,), tuple[int]) assert isa((0, 3, 5), tuple[int, int, int]) - assert isa((False, ), tuple[bool]) + assert isa((False,), tuple[bool]) assert isa((True, False), tuple[bool, ...]) assert isa((True, 1, "test"), tuple[bool, int, str]) assert isa((Class1(), SubClass1()), tuple[Class1, ...]) @@ -177,20 +183,20 @@ def test_tuple_hint_failure(): """ Test that tuple hints work correcly on non-empty tuples of the wrong type. """ - assert not isa((0, ), tuple[bool]) + assert not isa((0,), tuple[bool]) assert not isa((0, True), tuple[bool, bool]) assert not isa((0, True), tuple[int]) assert not isa((True, 0), tuple[bool, bool]) assert not isa((True, False, True, 0), tuple[bool, ...]) - assert not isa((Class2(), ), tuple[Class1]) + assert not isa((Class2(),), tuple[Class1]) def test_tuple_hint_nested(): """ Test that we can check nested tuple hints. """ - assert isa(((), ), tuple[tuple[int, ...]]) - assert isa(((0, ), ), tuple[tuple[int]]) + assert isa(((),), tuple[tuple[int, ...]]) + assert isa(((0,),), tuple[tuple[int]]) assert isa((0, (1, 2)), tuple[int, tuple[int, int]]) assert isa(((0, 1), (2, 3), (4, 5)), tuple[tuple[int, int], ...]) assert not isa(((0, 1), (2, 3), (4, "5")), tuple[tuple[int, int], ...]) diff --git a/tests/test_lexer.py b/tests/test_lexer.py index 5e16dc4ab6..6e1994e43d 100644 --- a/tests/test_lexer.py +++ b/tests/test_lexer.py @@ -11,9 +11,9 @@ def get_token(input: str) -> Token: return token -def assert_single_token(input: str, - expected_kind: Token.Kind, - expected_text: str | None = None): +def assert_single_token( + input: str, expected_kind: Token.Kind, expected_text: str | None = None +): if expected_text is None: expected_text = input @@ -30,165 +30,201 @@ def assert_token_fail(input: str): lexer.lex() -@pytest.mark.parametrize('text,kind', [('->', Token.Kind.ARROW), - (':', Token.Kind.COLON), - (',', Token.Kind.COMMA), - ('...', Token.Kind.ELLIPSIS), - ('=', Token.Kind.EQUAL), - ('>', Token.Kind.GREATER), - ('{', Token.Kind.L_BRACE), - ('(', Token.Kind.L_PAREN), - ('[', Token.Kind.L_SQUARE), - ('<', Token.Kind.LESS), - ('-', Token.Kind.MINUS), - ('+', Token.Kind.PLUS), - ('?', Token.Kind.QUESTION), - ('}', Token.Kind.R_BRACE), - (')', Token.Kind.R_PAREN), - (']', Token.Kind.R_SQUARE), - ('*', Token.Kind.STAR), - ('|', Token.Kind.VERTICAL_BAR), - ('{-#', Token.Kind.FILE_METADATA_BEGIN), - ('#-}', Token.Kind.FILE_METADATA_END)]) +@pytest.mark.parametrize( + "text,kind", + [ + ("->", Token.Kind.ARROW), + (":", Token.Kind.COLON), + (",", Token.Kind.COMMA), + ("...", Token.Kind.ELLIPSIS), + ("=", Token.Kind.EQUAL), + (">", Token.Kind.GREATER), + ("{", Token.Kind.L_BRACE), + ("(", Token.Kind.L_PAREN), + ("[", Token.Kind.L_SQUARE), + ("<", Token.Kind.LESS), + ("-", Token.Kind.MINUS), + ("+", Token.Kind.PLUS), + ("?", Token.Kind.QUESTION), + ("}", Token.Kind.R_BRACE), + (")", Token.Kind.R_PAREN), + ("]", Token.Kind.R_SQUARE), + ("*", Token.Kind.STAR), + ("|", Token.Kind.VERTICAL_BAR), + ("{-#", Token.Kind.FILE_METADATA_BEGIN), + ("#-}", Token.Kind.FILE_METADATA_END), + ], +) def test_punctuation(text: str, kind: Token.Kind): assert_single_token(text, kind) -@pytest.mark.parametrize('text', ['.', '&', '/']) +@pytest.mark.parametrize("text", [".", "&", "/"]) def test_punctuation_fail(text: str): assert_token_fail(text) @pytest.mark.parametrize( - 'text', ['""', '"@"', '"foo"', '"\\""', '"\\n"', '"\\\\"', '"\\t"']) + "text", ['""', '"@"', '"foo"', '"\\""', '"\\n"', '"\\\\"', '"\\t"'] +) def test_str_literal(text: str): assert_single_token(text, Token.Kind.STRING_LIT) -@pytest.mark.parametrize('text', - ['"', '"\\"', '"\\a"', '"\n"', '"\v"', '"\f"']) +@pytest.mark.parametrize("text", ['"', '"\\"', '"\\a"', '"\n"', '"\v"', '"\f"']) def test_str_literal_fail(text: str): assert_token_fail(text) @pytest.mark.parametrize( - 'text', - ['a', 'A', '_', 'a_', 'a1', 'a1_', 'a1_2', 'a1_2_3', 'a$_.', 'a$_.1']) + "text", ["a", "A", "_", "a_", "a1", "a1_", "a1_2", "a1_2_3", "a$_.", "a$_.1"] +) def test_bare_ident(text: str): - '''bare-id ::= (letter|[_]) (letter|digit|[_$.])*''' + """bare-id ::= (letter|[_]) (letter|digit|[_$.])*""" assert_single_token(text, Token.Kind.BARE_IDENT) -@pytest.mark.parametrize('text', [ - '@a', '@A', '@_', '@a_', '@a1', '@a1_', '@a1_2', '@a1_2_3', '@a$_.', - '@a$_.1', '@""', '@"@"', '@"foo"', '@"\\""', '@"\\n"', '@"\\\\"', '@"\\t"' -]) +@pytest.mark.parametrize( + "text", + [ + "@a", + "@A", + "@_", + "@a_", + "@a1", + "@a1_", + "@a1_2", + "@a1_2_3", + "@a$_.", + "@a$_.1", + '@""', + '@"@"', + '@"foo"', + '@"\\""', + '@"\\n"', + '@"\\\\"', + '@"\\t"', + ], +) def test_at_ident(text: str): - '''at-ident ::= `@` (bare-id | string-literal)''' + """at-ident ::= `@` (bare-id | string-literal)""" assert_single_token(text, Token.Kind.AT_IDENT) -@pytest.mark.parametrize('text', [ - '@', '@"', '@"\\"', '@"\\a"', '@"\n"', '@"\v"', '@"\f"', '@ "a"', '@ f', - '@$' -]) +@pytest.mark.parametrize( + "text", + ["@", '@"', '@"\\"', '@"\\a"', '@"\n"', '@"\v"', '@"\f"', '@ "a"', "@ f", "@$"], +) def test_at_ident_fail(text: str): - '''at-ident ::= `@` (bare-id | string-literal)''' + """at-ident ::= `@` (bare-id | string-literal)""" assert_token_fail(text) @pytest.mark.parametrize( - 'text', - ['0', '1234', 'a', 'S', '$', '_', '.', '-', 'e_.$-324', 'e5$-e_', 'foo']) + "text", ["0", "1234", "a", "S", "$", "_", ".", "-", "e_.$-324", "e5$-e_", "foo"] +) def test_prefixed_ident(text: str): - '''hash-ident ::= `#` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)''' - '''percent-ident ::= `%` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)''' - '''caret-ident ::= `^` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)''' - '''exclamation-ident ::= `!` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)''' - assert_single_token('#' + text, Token.Kind.HASH_IDENT) - assert_single_token('%' + text, Token.Kind.PERCENT_IDENT) - assert_single_token('^' + text, Token.Kind.CARET_IDENT) - assert_single_token('!' + text, Token.Kind.EXCLAMATION_IDENT) + """hash-ident ::= `#` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)""" + """percent-ident ::= `%` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)""" + """caret-ident ::= `^` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)""" + """exclamation-ident ::= `!` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)""" + assert_single_token("#" + text, Token.Kind.HASH_IDENT) + assert_single_token("%" + text, Token.Kind.PERCENT_IDENT) + assert_single_token("^" + text, Token.Kind.CARET_IDENT) + assert_single_token("!" + text, Token.Kind.EXCLAMATION_IDENT) -@pytest.mark.parametrize('text', ['+', '""', '#', '%', '^', '!', '\n', '']) +@pytest.mark.parametrize("text", ["+", '""', "#", "%", "^", "!", "\n", ""]) def test_prefixed_ident_fail(text: str): - ''' + """ hash-ident ::= `#` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*) percent-ident ::= `%` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*) caret-ident ::= `^` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*) exclamation-ident ::= `!` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*) - ''' - assert_token_fail('#' + text) - assert_token_fail('%' + text) - assert_token_fail('^' + text) - assert_token_fail('!' + text) + """ + assert_token_fail("#" + text) + assert_token_fail("%" + text) + assert_token_fail("^" + text) + assert_token_fail("!" + text) -@pytest.mark.parametrize('text,expected', [('0x0', '0'), ('0e', '0'), - ('0$', '0'), ('0_', '0'), - ('0-', '0'), ('0.', '0')]) +@pytest.mark.parametrize( + "text,expected", + [("0x0", "0"), ("0e", "0"), ("0$", "0"), ("0_", "0"), ("0-", "0"), ("0.", "0")], +) def test_prefixed_ident_split(text: str, expected: str): - '''Check that the prefixed identifier is split at the right character.''' - assert_single_token('#' + text, Token.Kind.HASH_IDENT, '#' + expected) - assert_single_token('%' + text, Token.Kind.PERCENT_IDENT, '%' + expected) - assert_single_token('^' + text, Token.Kind.CARET_IDENT, '^' + expected) - assert_single_token('!' + text, Token.Kind.EXCLAMATION_IDENT, - '!' + expected) + """Check that the prefixed identifier is split at the right character.""" + assert_single_token("#" + text, Token.Kind.HASH_IDENT, "#" + expected) + assert_single_token("%" + text, Token.Kind.PERCENT_IDENT, "%" + expected) + assert_single_token("^" + text, Token.Kind.CARET_IDENT, "^" + expected) + assert_single_token("!" + text, Token.Kind.EXCLAMATION_IDENT, "!" + expected) -@pytest.mark.parametrize('text', - ['0', '01', '123456789', '99', '0x1234', '0xabcdef']) +@pytest.mark.parametrize("text", ["0", "01", "123456789", "99", "0x1234", "0xabcdef"]) def test_integer_literal(text: str): assert_single_token(text, Token.Kind.INTEGER_LIT) -@pytest.mark.parametrize('text,expected', [('0a', '0'), ('0xg', '0'), - ('0xfg', '0xf'), ('0xf.', '0xf')]) +@pytest.mark.parametrize( + "text,expected", [("0a", "0"), ("0xg", "0"), ("0xfg", "0xf"), ("0xf.", "0xf")] +) def test_integer_literal_split(text: str, expected: str): assert_single_token(text, Token.Kind.INTEGER_LIT, expected) -@pytest.mark.parametrize('text', [ - '0.', '1.', '0.2', '38.1243', '92.54e43', '92.5E43', '43.3e-54', '32.E+25' -]) +@pytest.mark.parametrize( + "text", ["0.", "1.", "0.2", "38.1243", "92.54e43", "92.5E43", "43.3e-54", "32.E+25"] +) def test_float_literal(text: str): assert_single_token(text, Token.Kind.FLOAT_LIT) -@pytest.mark.parametrize('text,expected', [('3.9e', '3.9'), ('4.5e+', '4.5'), - ('5.8e-', '5.8')]) +@pytest.mark.parametrize( + "text,expected", [("3.9e", "3.9"), ("4.5e+", "4.5"), ("5.8e-", "5.8")] +) def test_float_literal_split(text: str, expected: str): assert_single_token(text, Token.Kind.FLOAT_LIT, expected) -@pytest.mark.parametrize('text', - ['0', ' 0', ' 0', '\n0', '\t0', '// Comment\n0']) +@pytest.mark.parametrize("text", ["0", " 0", " 0", "\n0", "\t0", "// Comment\n0"]) def test_whitespace_skip(text: str): - assert_single_token(text, Token.Kind.INTEGER_LIT, '0') + assert_single_token(text, Token.Kind.INTEGER_LIT, "0") -@pytest.mark.parametrize('text', ['', ' ', '\n\n', '// Comment\n']) +@pytest.mark.parametrize("text", ["", " ", "\n\n", "// Comment\n"]) def test_eof(text: str): - assert_single_token(text, Token.Kind.EOF, '') + assert_single_token(text, Token.Kind.EOF, "") -@pytest.mark.parametrize('text, expected', [('0', 0), ('010', 10), - ('123456789', 123456789), - ('0x1234', 4660), - ('0xabcdef23', 2882400035)]) +@pytest.mark.parametrize( + "text, expected", + [ + ("0", 0), + ("010", 10), + ("123456789", 123456789), + ("0x1234", 4660), + ("0xabcdef23", 2882400035), + ], +) def test_token_get_int_value(text: str, expected: int): token = get_token(text) assert token.kind == Token.Kind.INTEGER_LIT assert token.get_int_value() == expected -@pytest.mark.parametrize('text, expected', [('0.', 0.0), ('1.', 1.0), - ('0.2', 0.2), ('38.1243', 38.1243), - ('92.54e43', 92.54e43), - ('92.5E43', 92.5E43), - ('43.3e-54', 43.3e-54), - ('32.E+25', 32.E+25)]) +@pytest.mark.parametrize( + "text, expected", + [ + ("0.", 0.0), + ("1.", 1.0), + ("0.2", 0.2), + ("38.1243", 38.1243), + ("92.54e43", 92.54e43), + ("92.5E43", 92.5e43), + ("43.3e-54", 43.3e-54), + ("32.E+25", 32.0e25), + ], +) def test_token_get_float_value(text: str, expected: float): token = get_token(text) assert token.kind == Token.Kind.FLOAT_LIT diff --git a/tests/test_mlctx.py b/tests/test_mlctx.py index 0bc3e8fb04..870546e54a 100644 --- a/tests/test_mlctx.py +++ b/tests/test_mlctx.py @@ -50,8 +50,7 @@ def test_get_op_unregistered(): assert issubclass(op, UnregisteredOp) assert ctx.get_op("dummy", allow_unregistered=True) == DummyOp - assert issubclass(ctx.get_op("dummy2", allow_unregistered=True), - UnregisteredOp) + assert issubclass(ctx.get_op("dummy2", allow_unregistered=True), UnregisteredOp) def test_get_attr(): @@ -76,22 +75,29 @@ def test_get_attr_unregistered(is_type: bool): ctx = MLContext() ctx.register_attr(DummyAttr) - assert ctx.get_optional_attr( - "dummy_attr", - allow_unregistered=True, - create_unregistered_as_type=is_type) == DummyAttr + assert ( + ctx.get_optional_attr( + "dummy_attr", allow_unregistered=True, create_unregistered_as_type=is_type + ) + == DummyAttr + ) attr = ctx.get_optional_attr("dummy_attr2", allow_unregistered=True) assert attr is not None assert issubclass(attr, UnregisteredAttr) if is_type: assert issubclass(attr, TypeAttribute) - assert ctx.get_attr("dummy_attr", - allow_unregistered=True, - create_unregistered_as_type=is_type) == DummyAttr + assert ( + ctx.get_attr( + "dummy_attr", allow_unregistered=True, create_unregistered_as_type=is_type + ) + == DummyAttr + ) assert issubclass( - ctx.get_attr("dummy_attr2", - allow_unregistered=True, - create_unregistered_as_type=is_type), UnregisteredAttr) + ctx.get_attr( + "dummy_attr2", allow_unregistered=True, create_unregistered_as_type=is_type + ), + UnregisteredAttr, + ) if is_type: assert issubclass(attr, TypeAttribute) diff --git a/tests/test_mlir_printer.py b/tests/test_mlir_printer.py index 1f8bfe150b..e9a934ea89 100644 --- a/tests/test_mlir_printer.py +++ b/tests/test_mlir_printer.py @@ -1,10 +1,23 @@ import re from io import StringIO from typing import Annotated -from xdsl.ir import (Attribute, Data, MLContext, TypeAttribute, - ParametrizedAttribute, Region) -from xdsl.irdl import (AnyAttr, ParameterDef, VarOpResult, VarOperand, - irdl_attr_definition, irdl_op_definition, IRDLOperation) +from xdsl.ir import ( + Attribute, + Data, + MLContext, + TypeAttribute, + ParametrizedAttribute, + Region, +) +from xdsl.irdl import ( + AnyAttr, + ParameterDef, + VarOpResult, + VarOperand, + irdl_attr_definition, + irdl_op_definition, + IRDLOperation, +) from xdsl.parser import BaseParser, XDSLParser from xdsl.printer import Printer @@ -12,6 +25,7 @@ @irdl_op_definition class ModuleOp(IRDLOperation): """Module operation. Redefined to not depend on the builtin dialect.""" + name = "module" region: Region @@ -19,6 +33,7 @@ class ModuleOp(IRDLOperation): @irdl_op_definition class AnyOp(IRDLOperation): """Operation only used for testing.""" + name = "any" op: Annotated[VarOperand, AnyAttr()] res: Annotated[VarOpResult, AnyAttr()] @@ -27,6 +42,7 @@ class AnyOp(IRDLOperation): @irdl_attr_definition class DataAttr(Data[int]): """Attribute only used for testing.""" + name = "data_attr" @staticmethod @@ -40,6 +56,7 @@ def print_parameter(self, printer: Printer) -> None: @irdl_attr_definition class DataType(Data[int], TypeAttribute): """Attribute only used for testing.""" + name = "data_type" @staticmethod @@ -95,9 +112,8 @@ def print_as_mlir_and_compare(test_prog: str, expected: str): printer.print_op(module) # Remove all whitespace from the expected string. - regex = re.compile(r'[^\S]+') - assert (regex.sub("", res.getvalue()).strip() == \ - regex.sub("", expected).strip()) + regex = re.compile(r"[^\S]+") + assert regex.sub("", res.getvalue()).strip() == regex.sub("", expected).strip() def test_empty_op(): diff --git a/tests/test_op_builder.py b/tests/test_op_builder.py index b6c4d7b904..4ae7850ed8 100644 --- a/tests/test_op_builder.py +++ b/tests/test_op_builder.py @@ -24,7 +24,6 @@ def test_builder(): def test_build_region(): - one = IntAttr(1) two = IntAttr(2) @@ -52,7 +51,6 @@ def region(b: Builder): def test_build_callable_region(): - one = IntAttr(1) two = IntAttr(2) @@ -87,7 +85,6 @@ def region(b: Builder, args: tuple[BlockArgument, ...]): def test_build_implicit_region(): - one = IntAttr(1) two = IntAttr(2) @@ -112,7 +109,6 @@ def region(): def test_build_implicit_callable_region(): - one = IntAttr(1) two = IntAttr(2) @@ -144,7 +140,6 @@ def region(args: tuple[BlockArgument, ...]): def test_build_nested_implicit_region(): - one = IntAttr(1) two = IntAttr(2) @@ -186,7 +181,6 @@ def region_1(): def test_build_implicit_region_fail(): - one = IntAttr(1) two = IntAttr(2) three = IntAttr(3) @@ -205,8 +199,9 @@ def region_2(b: Builder): b.insert(Constant.from_int_and_width(three, i32)) assert e.value.args[0] == ( - 'Cannot insert operation explicitly when an implicit' - ' builder exists.') + "Cannot insert operation explicitly when an implicit" + " builder exists." + ) y.add_region(region_2) diff --git a/tests/test_operation_builder.py b/tests/test_operation_builder.py index 53920efe72..c901122878 100644 --- a/tests/test_operation_builder.py +++ b/tests/test_operation_builder.py @@ -9,10 +9,24 @@ from xdsl.ir import Block, OpResult, Region from xdsl.irdl import ( - AttrSizedRegionSegments, OptOpResult, OptOperand, OptRegion, - OptSingleBlockRegion, Operand, SingleBlockRegion, VarOpResult, VarRegion, - VarSingleBlockRegion, irdl_op_definition, AttrSizedResultSegments, - VarOperand, AttrSizedOperandSegments, OpAttr, OptOpAttr, IRDLOperation) + AttrSizedRegionSegments, + OptOpResult, + OptOperand, + OptRegion, + OptSingleBlockRegion, + Operand, + SingleBlockRegion, + VarOpResult, + VarRegion, + VarSingleBlockRegion, + irdl_op_definition, + AttrSizedResultSegments, + VarOperand, + AttrSizedOperandSegments, + OpAttr, + OptOpAttr, + IRDLOperation, +) # ____ _ _ # | _ \ ___ ___ _ _| | |_ @@ -71,9 +85,7 @@ class VarResultOp(IRDLOperation): def test_var_result_builder(): op = VarResultOp.build(result_types=[[StringAttr("0"), StringAttr("1")]]) op.verify() - assert [res.typ - for res in op.results] == [StringAttr("0"), - StringAttr("1")] + assert [res.typ for res in op.results] == [StringAttr("0"), StringAttr("1")] @irdl_op_definition @@ -87,36 +99,41 @@ class TwoVarResultOp(IRDLOperation): def test_two_var_result_builder(): op = TwoVarResultOp.build( - result_types=[[StringAttr("0"), StringAttr("1")], - [StringAttr("2"), StringAttr("3")]]) + result_types=[ + [StringAttr("0"), StringAttr("1")], + [StringAttr("2"), StringAttr("3")], + ] + ) op.verify() assert [res.typ for res in op.results] == [ StringAttr("0"), StringAttr("1"), StringAttr("2"), - StringAttr("3") + StringAttr("3"), ] assert op.attributes[ - AttrSizedResultSegments.attribute_name] == DenseArrayBase.from_list( - i32, [2, 2]) + AttrSizedResultSegments.attribute_name + ] == DenseArrayBase.from_list(i32, [2, 2]) def test_two_var_result_builder2(): - op = TwoVarResultOp.build(result_types=[[StringAttr( - "0")], [StringAttr("1"), - StringAttr("2"), - StringAttr("3")]]) + op = TwoVarResultOp.build( + result_types=[ + [StringAttr("0")], + [StringAttr("1"), StringAttr("2"), StringAttr("3")], + ] + ) op.verify() assert [res.typ for res in op.results] == [ StringAttr("0"), StringAttr("1"), StringAttr("2"), - StringAttr("3") + StringAttr("3"), ] assert op.attributes[ - AttrSizedResultSegments.attribute_name] == DenseArrayBase.from_list( - i32, [1, 3]) + AttrSizedResultSegments.attribute_name + ] == DenseArrayBase.from_list(i32, [1, 3]) @irdl_op_definition @@ -131,21 +148,24 @@ class MixedResultOp(IRDLOperation): def test_var_mixed_builder(): op = MixedResultOp.build( - result_types=[[StringAttr("0"), StringAttr("1")], - StringAttr("2"), [StringAttr("3"), - StringAttr("4")]]) + result_types=[ + [StringAttr("0"), StringAttr("1")], + StringAttr("2"), + [StringAttr("3"), StringAttr("4")], + ] + ) op.verify() assert [res.typ for res in op.results] == [ StringAttr("0"), StringAttr("1"), StringAttr("2"), StringAttr("3"), - StringAttr("4") + StringAttr("4"), ] assert op.attributes[ - AttrSizedResultSegments.attribute_name] == DenseArrayBase.from_list( - i32, [2, 1, 2]) + AttrSizedResultSegments.attribute_name + ] == DenseArrayBase.from_list(i32, [2, 1, 2]) # ___ _ @@ -168,14 +188,14 @@ def test_operand_builder_operation(): op1 = ResultOp.build(result_types=[StringAttr("0")]) op2 = OperandOp.build(operands=[op1]) op2.verify() - assert op2.operands == (op1.res, ) + assert op2.operands == (op1.res,) def test_operand_builder_value(): op1 = ResultOp.build(result_types=[StringAttr("0")]) op2 = OperandOp.build(operands=[op1.res]) op2.verify() - assert op2.operands == (op1.res, ) + assert op2.operands == (op1.res,) def test_operand_builder_exception(): @@ -235,8 +255,8 @@ def test_two_var_operand_builder(): op2.verify() assert op2.operands == (op1.res, op1.res, op1.res, op1.res) assert op2.attributes[ - AttrSizedOperandSegments.attribute_name] == DenseArrayBase.from_list( - i32, [2, 2]) + AttrSizedOperandSegments.attribute_name + ] == DenseArrayBase.from_list(i32, [2, 2]) def test_two_var_operand_builder2(): @@ -245,8 +265,8 @@ def test_two_var_operand_builder2(): op2.verify() assert op2.operands == (op1.res, op1.res, op1.res, op1.res) assert op2.attributes[ - AttrSizedOperandSegments.attribute_name] == DenseArrayBase.from_list( - i32, [1, 3]) + AttrSizedOperandSegments.attribute_name + ] == DenseArrayBase.from_list(i32, [1, 3]) # _ _ _ _ _ _ @@ -270,10 +290,7 @@ def test_attr_op(): def test_attr_new_attr_op(): - op = AttrOp.build(attributes={ - "attr": StringAttr("0"), - "new_attr": StringAttr("1") - }) + op = AttrOp.build(attributes={"attr": StringAttr("0"), "new_attr": StringAttr("1")}) op.verify() assert op.attr == StringAttr("0") assert op.attributes["new_attr"] == StringAttr("1") @@ -436,13 +453,12 @@ def test_two_var_region_builder(): region2 = Region() region3 = Region() region4 = Region() - op2 = TwoVarRegionOp.build( - regions=[[region1, region2], [region3, region4]]) + op2 = TwoVarRegionOp.build(regions=[[region1, region2], [region3, region4]]) op2.verify() assert op2.regions == [region1, region2, region3, region4] assert op2.attributes[ - AttrSizedRegionSegments.attribute_name] == DenseArrayBase.from_list( - i32, [2, 2]) + AttrSizedRegionSegments.attribute_name + ] == DenseArrayBase.from_list(i32, [2, 2]) def test_two_var_operand_builder3(): @@ -450,13 +466,12 @@ def test_two_var_operand_builder3(): region2 = Region() region3 = Region() region4 = Region() - op2 = TwoVarRegionOp.build( - regions=[[region1], [region2, region3, region4]]) + op2 = TwoVarRegionOp.build(regions=[[region1], [region2, region3, region4]]) op2.verify() assert op2.regions == [region1, region2, region3, region4] assert op2.attributes[ - AttrSizedRegionSegments.attribute_name] == DenseArrayBase.from_list( - i32, [1, 3]) + AttrSizedRegionSegments.attribute_name + ] == DenseArrayBase.from_list(i32, [1, 3]) # __ __ _ diff --git a/tests/test_operation_definition.py b/tests/test_operation_definition.py index b0cfd60e47..46ecc698e2 100644 --- a/tests/test_operation_definition.py +++ b/tests/test_operation_definition.py @@ -5,12 +5,28 @@ from xdsl.dialects.builtin import IntAttr, StringAttr, i32 from xdsl.ir import Attribute, OpResult, Region -from xdsl.irdl import (AttrSizedOperandSegments, AttrSizedRegionSegments, - AttrSizedResultSegments, Operand, OptOpAttr, - OptOpResult, OptOperand, OptRegion, VarOpResult, - VarOperand, VarRegion, irdl_op_definition, OperandDef, - ResultDef, AttributeDef, AnyAttr, OpDef, RegionDef, - OpAttr, IRDLOperation) +from xdsl.irdl import ( + AttrSizedOperandSegments, + AttrSizedRegionSegments, + AttrSizedResultSegments, + Operand, + OptOpAttr, + OptOpResult, + OptOperand, + OptRegion, + VarOpResult, + VarOperand, + VarRegion, + irdl_op_definition, + OperandDef, + ResultDef, + AttributeDef, + AnyAttr, + OpDef, + RegionDef, + OpAttr, + IRDLOperation, +) from xdsl.utils.exceptions import PyRDLOpDefinitionError, VerifyException ################################################################################ @@ -39,7 +55,8 @@ def test_get_definition(): operands=[("operand", OperandDef(AnyAttr()))], results=[("result", ResultDef(AnyAttr()))], attributes={"attr": AttributeDef(AnyAttr())}, - regions=[("region", RegionDef())]) + regions=[("region", RegionDef())], + ) ################################################################################ @@ -192,10 +209,9 @@ class AttributeOp(IRDLOperation): def test_attribute_accessors(): """Test accessors for attributes.""" - op = AttributeOp.create(attributes={ - "attr": StringAttr("test"), - "opt_attr": StringAttr("opt_test") - }) + op = AttributeOp.create( + attributes={"attr": StringAttr("test"), "opt_attr": StringAttr("opt_test")} + ) assert op.attr is op.attributes["attr"] assert op.opt_attr is op.attributes["opt_attr"] diff --git a/tests/test_parser.py b/tests/test_parser.py index 9b5f3c9904..fe31481f7c 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -2,9 +2,15 @@ from io import StringIO -from xdsl.dialects.builtin import (IntAttr, DictionaryAttr, StringAttr, - ArrayAttr, Builtin, SymbolRefAttr) -from xdsl.ir import (MLContext, Attribute, Region, ParametrizedAttribute) +from xdsl.dialects.builtin import ( + IntAttr, + DictionaryAttr, + StringAttr, + ArrayAttr, + Builtin, + SymbolRefAttr, +) +from xdsl.ir import MLContext, Attribute, Region, ParametrizedAttribute from xdsl.irdl import irdl_attr_definition, irdl_op_definition, IRDLOperation from xdsl.parser import BaseParser, XDSLParser, MLIRParser from xdsl.printer import Printer @@ -14,24 +20,30 @@ # pyright: reportPrivateUsage=false -@pytest.mark.parametrize("input,expected", [("0, 1, 1", [0, 1, 1]), - ("1, 0, 1", [1, 0, 1]), - ("1, 1, 0", [1, 1, 0])]) +@pytest.mark.parametrize( + "input,expected", + [("0, 1, 1", [0, 1, 1]), ("1, 0, 1", [1, 0, 1]), ("1, 1, 0", [1, 1, 0])], +) def test_int_list_parser(input: str, expected: list[int]): ctx = MLContext() parser = XDSLParser(ctx, input) - int_list = parser.parse_list_of(parser.try_parse_integer_literal, '') + int_list = parser.parse_list_of(parser.try_parse_integer_literal, "") assert [int(span.text) for span in int_list] == expected -@pytest.mark.parametrize('data', [ - dict(a=IntAttr(1), b=IntAttr(2), c=IntAttr(3)), - dict(a=StringAttr('hello'), - b=IntAttr(2), - c=ArrayAttr([IntAttr(2), StringAttr('world')])), - dict(), -]) +@pytest.mark.parametrize( + "data", + [ + dict(a=IntAttr(1), b=IntAttr(2), c=IntAttr(3)), + dict( + a=StringAttr("hello"), + b=IntAttr(2), + c=ArrayAttr([IntAttr(2), StringAttr("world")]), + ), + dict(), + ], +) def test_dictionary_attr(data: dict[str, Attribute]): attr = DictionaryAttr(data) @@ -50,7 +62,7 @@ def test_dictionary_attr(data: dict[str, Attribute]): @irdl_attr_definition class DummyAttr(ParametrizedAttribute): - name = 'dummy.attr' + name = "dummy.attr" def test_parsing(): @@ -68,11 +80,14 @@ def test_parsing(): assert r == DummyAttr() -@pytest.mark.parametrize("ref,expected", [ - ("@foo", SymbolRefAttr("foo")), - ("@foo::@bar", SymbolRefAttr("foo", ["bar"])), - ("@foo::@bar::@baz", SymbolRefAttr("foo", ["bar", "baz"])), -]) +@pytest.mark.parametrize( + "ref,expected", + [ + ("@foo", SymbolRefAttr("foo")), + ("@foo::@bar", SymbolRefAttr("foo", ["bar"])), + ("@foo::@bar::@baz", SymbolRefAttr("foo", ["bar", "baz"])), + ], +) def test_symref(ref: str, expected: Attribute | None): """ Test that symbol references are correctly parsed. @@ -136,122 +151,141 @@ def test_parse_block_name(): parser = XDSLParser(ctx, block_str) block = parser.parse_block() - assert block.args[0].name == 'name' + assert block.args[0].name == "name" assert block.args[1].name is None -@pytest.mark.parametrize("delimiter,open_bracket,close_bracket", [ - (BaseParser.Delimiter.NONE, '', ''), - (BaseParser.Delimiter.PAREN, '(', ')'), - (BaseParser.Delimiter.SQUARE, '[', ']'), - (BaseParser.Delimiter.BRACES, '{', '}'), - (BaseParser.Delimiter.ANGLE, '<', '>'), -]) -def test_parse_comma_separated_list(delimiter: BaseParser.Delimiter, - open_bracket: str, close_bracket: str): +@pytest.mark.parametrize( + "delimiter,open_bracket,close_bracket", + [ + (BaseParser.Delimiter.NONE, "", ""), + (BaseParser.Delimiter.PAREN, "(", ")"), + (BaseParser.Delimiter.SQUARE, "[", "]"), + (BaseParser.Delimiter.BRACES, "{", "}"), + (BaseParser.Delimiter.ANGLE, "<", ">"), + ], +) +def test_parse_comma_separated_list( + delimiter: BaseParser.Delimiter, open_bracket: str, close_bracket: str +): input = open_bracket + "2, 4, 5" + close_bracket parser = XDSLParser(MLContext(), input) - res = parser.parse_comma_separated_list(delimiter, - parser.parse_int_literal, - ' in test') + res = parser.parse_comma_separated_list( + delimiter, parser.parse_int_literal, " in test" + ) assert res == [2, 4, 5] -@pytest.mark.parametrize("delimiter,open_bracket,close_bracket", [ - (BaseParser.Delimiter.PAREN, '(', ')'), - (BaseParser.Delimiter.SQUARE, '[', ']'), - (BaseParser.Delimiter.BRACES, '{', '}'), - (BaseParser.Delimiter.ANGLE, '<', '>'), -]) -def test_parse_comma_separated_list_empty(delimiter: BaseParser.Delimiter, - open_bracket: str, - close_bracket: str): +@pytest.mark.parametrize( + "delimiter,open_bracket,close_bracket", + [ + (BaseParser.Delimiter.PAREN, "(", ")"), + (BaseParser.Delimiter.SQUARE, "[", "]"), + (BaseParser.Delimiter.BRACES, "{", "}"), + (BaseParser.Delimiter.ANGLE, "<", ">"), + ], +) +def test_parse_comma_separated_list_empty( + delimiter: BaseParser.Delimiter, open_bracket: str, close_bracket: str +): input = open_bracket + close_bracket parser = XDSLParser(MLContext(), input) - res = parser.parse_comma_separated_list(delimiter, - parser.parse_int_literal, - ' in test') + res = parser.parse_comma_separated_list( + delimiter, parser.parse_int_literal, " in test" + ) assert res == [] def test_parse_comma_separated_list_none_delimiter_empty(): - parser = XDSLParser(MLContext(), 'o') + parser = XDSLParser(MLContext(), "o") with pytest.raises(ParseError): - parser.parse_comma_separated_list(BaseParser.Delimiter.NONE, - parser.parse_int_literal, ' in test') + parser.parse_comma_separated_list( + BaseParser.Delimiter.NONE, parser.parse_int_literal, " in test" + ) -@pytest.mark.parametrize("delimiter,open_bracket,close_bracket", - [(BaseParser.Delimiter.PAREN, '(', ')'), - (BaseParser.Delimiter.SQUARE, '[', ']'), - (BaseParser.Delimiter.BRACES, '{', '}'), - (BaseParser.Delimiter.ANGLE, '<', '>')]) +@pytest.mark.parametrize( + "delimiter,open_bracket,close_bracket", + [ + (BaseParser.Delimiter.PAREN, "(", ")"), + (BaseParser.Delimiter.SQUARE, "[", "]"), + (BaseParser.Delimiter.BRACES, "{", "}"), + (BaseParser.Delimiter.ANGLE, "<", ">"), + ], +) def test_parse_comma_separated_list_error_element( - delimiter: BaseParser.Delimiter, open_bracket: str, - close_bracket: str): + delimiter: BaseParser.Delimiter, open_bracket: str, close_bracket: str +): input = open_bracket + "o" + close_bracket parser = XDSLParser(MLContext(), input) with pytest.raises(ParseError) as e: - parser.parse_comma_separated_list(delimiter, parser.parse_int_literal, - ' in test') - assert e.value.span.text == 'o' + parser.parse_comma_separated_list( + delimiter, parser.parse_int_literal, " in test" + ) + assert e.value.span.text == "o" assert e.value.msg == "Expected integer literal here" -@pytest.mark.parametrize("delimiter,open_bracket,close_bracket", - [(BaseParser.Delimiter.PAREN, '(', ')'), - (BaseParser.Delimiter.SQUARE, '[', ']'), - (BaseParser.Delimiter.BRACES, '{', '}'), - (BaseParser.Delimiter.ANGLE, '<', '>')]) +@pytest.mark.parametrize( + "delimiter,open_bracket,close_bracket", + [ + (BaseParser.Delimiter.PAREN, "(", ")"), + (BaseParser.Delimiter.SQUARE, "[", "]"), + (BaseParser.Delimiter.BRACES, "{", "}"), + (BaseParser.Delimiter.ANGLE, "<", ">"), + ], +) def test_parse_comma_separated_list_error_delimiters( - delimiter: BaseParser.Delimiter, open_bracket: str, - close_bracket: str): + delimiter: BaseParser.Delimiter, open_bracket: str, close_bracket: str +): input = open_bracket + "2, 4 5" parser = XDSLParser(MLContext(), input) with pytest.raises(ParseError) as e: - parser.parse_comma_separated_list(delimiter, parser.parse_int_literal, - ' in test') - assert e.value.span.text == '5' + parser.parse_comma_separated_list( + delimiter, parser.parse_int_literal, " in test" + ) + assert e.value.span.text == "5" assert e.value.msg == "Expected '" + close_bracket + "' in test" @pytest.mark.parametrize( - 'punctuation', - list(Token.Kind.get_punctuation_spelling_to_kind_dict().values())) + "punctuation", list(Token.Kind.get_punctuation_spelling_to_kind_dict().values()) +) def test_is_punctuation_true(punctuation: Token.Kind): assert punctuation.is_punctuation() @pytest.mark.parametrize( - 'punctuation', - [Token.Kind.BARE_IDENT, Token.Kind.EOF, Token.Kind.INTEGER_LIT]) + "punctuation", [Token.Kind.BARE_IDENT, Token.Kind.EOF, Token.Kind.INTEGER_LIT] +) def test_is_punctuation_false(punctuation: Token.Kind): assert not punctuation.is_punctuation() @pytest.mark.parametrize( - 'punctuation', - list(Token.Kind.get_punctuation_spelling_to_kind_dict().values())) + "punctuation", list(Token.Kind.get_punctuation_spelling_to_kind_dict().values()) +) def test_is_spelling_of_punctuation_true(punctuation: Token.Kind): assert Token.Kind.is_spelling_of_punctuation(punctuation.value) -@pytest.mark.parametrize('punctuation', ['>-', 'o', '4', '$', '_', '@']) +@pytest.mark.parametrize("punctuation", [">-", "o", "4", "$", "_", "@"]) def test_is_spelling_of_punctuation_false(punctuation: str): assert not Token.Kind.is_spelling_of_punctuation(punctuation) @pytest.mark.parametrize( - 'punctuation', - list(Token.Kind.get_punctuation_spelling_to_kind_dict().values())) + "punctuation", list(Token.Kind.get_punctuation_spelling_to_kind_dict().values()) +) def test_get_punctuation_kind(punctuation: Token.Kind): - assert punctuation.get_punctuation_kind_from_spelling( - punctuation.value) == punctuation + assert ( + punctuation.get_punctuation_kind_from_spelling(punctuation.value) == punctuation + ) @pytest.mark.parametrize( - "punctuation", - list(Token.Kind.get_punctuation_spelling_to_kind_dict().keys())) + "punctuation", list(Token.Kind.get_punctuation_spelling_to_kind_dict().keys()) +) def test_parse_punctuation(punctuation: Token.PunctuationSpelling): parser = XDSLParser(MLContext(), punctuation) @@ -263,20 +297,20 @@ def test_parse_punctuation(punctuation: Token.PunctuationSpelling): @pytest.mark.parametrize( - "punctuation", - list(Token.Kind.get_punctuation_spelling_to_kind_dict().keys())) + "punctuation", list(Token.Kind.get_punctuation_spelling_to_kind_dict().keys()) +) def test_parse_punctuation_fail(punctuation: Token.PunctuationSpelling): - parser = XDSLParser(MLContext(), 'e +') + parser = XDSLParser(MLContext(), "e +") parser._synchronize_lexer_and_tokenizer() with pytest.raises(ParseError) as e: - parser.parse_punctuation(punctuation, ' in test') - assert e.value.span.text == 'e' + parser.parse_punctuation(punctuation, " in test") + assert e.value.span.text == "e" assert e.value.msg == "Expected '" + punctuation + "' in test" @pytest.mark.parametrize( - "punctuation", - list(Token.Kind.get_punctuation_spelling_to_kind_dict().keys())) + "punctuation", list(Token.Kind.get_punctuation_spelling_to_kind_dict().keys()) +) def test_parse_optional_punctuation(punctuation: Token.PunctuationSpelling): parser = XDSLParser(MLContext(), punctuation) parser._synchronize_lexer_and_tokenizer() @@ -287,21 +321,23 @@ def test_parse_optional_punctuation(punctuation: Token.PunctuationSpelling): @pytest.mark.parametrize( - "punctuation", - list(Token.Kind.get_punctuation_spelling_to_kind_dict().keys())) -def test_parse_optional_punctuation_fail( - punctuation: Token.PunctuationSpelling): - parser = XDSLParser(MLContext(), 'e +') + "punctuation", list(Token.Kind.get_punctuation_spelling_to_kind_dict().keys()) +) +def test_parse_optional_punctuation_fail(punctuation: Token.PunctuationSpelling): + parser = XDSLParser(MLContext(), "e +") parser._synchronize_lexer_and_tokenizer() assert parser.parse_optional_punctuation(punctuation) is None -@pytest.mark.parametrize("text, expected_value", [ - ("true", True), - ("false", False), - ("True", None), - ("False", None), -]) +@pytest.mark.parametrize( + "text, expected_value", + [ + ("true", True), + ("false", False), + ("True", None), + ("False", None), + ], +) def test_parse_boolean(text: str, expected_value: bool | None): parser = MLIRParser(MLContext(), text) assert parser.parse_optional_boolean() == expected_value @@ -314,87 +350,106 @@ def test_parse_boolean(text: str, expected_value: bool | None): assert parser.parse_boolean() == expected_value -@pytest.mark.parametrize("text, expected_value, allow_boolean, allow_negative", - [ - ("42", 42, False, False), - ("42", 42, True, False), - ("42", 42, False, True), - ("42", 42, True, True), - ("-1", None, False, False), - ("-1", None, True, False), - ("-1", -1, False, True), - ("-1", -1, True, True), - ("true", None, False, False), - ("true", 1, True, False), - ("true", None, False, True), - ("true", 1, True, True), - ("false", None, False, False), - ("false", 0, True, False), - ("false", None, False, True), - ("false", 0, True, True), - ("True", None, True, True), - ("False", None, True, True), - ("0x1a", 26, False, False), - ("0x1a", 26, True, False), - ("0x1a", 26, False, True), - ("0x1a", 26, True, True), - ("-0x1a", None, False, False), - ("-0x1a", None, True, False), - ("-0x1a", -26, False, True), - ("-0x1a", -26, True, True), - ]) -def test_parse_int(text: str, expected_value: int | None, allow_boolean: bool, - allow_negative: bool): +@pytest.mark.parametrize( + "text, expected_value, allow_boolean, allow_negative", + [ + ("42", 42, False, False), + ("42", 42, True, False), + ("42", 42, False, True), + ("42", 42, True, True), + ("-1", None, False, False), + ("-1", None, True, False), + ("-1", -1, False, True), + ("-1", -1, True, True), + ("true", None, False, False), + ("true", 1, True, False), + ("true", None, False, True), + ("true", 1, True, True), + ("false", None, False, False), + ("false", 0, True, False), + ("false", None, False, True), + ("false", 0, True, True), + ("True", None, True, True), + ("False", None, True, True), + ("0x1a", 26, False, False), + ("0x1a", 26, True, False), + ("0x1a", 26, False, True), + ("0x1a", 26, True, True), + ("-0x1a", None, False, False), + ("-0x1a", None, True, False), + ("-0x1a", -26, False, True), + ("-0x1a", -26, True, True), + ], +) +def test_parse_int( + text: str, expected_value: int | None, allow_boolean: bool, allow_negative: bool +): parser = MLIRParser(MLContext(), text) - assert parser.parse_optional_integer( - allow_boolean=allow_boolean, - allow_negative=allow_negative) == expected_value + assert ( + parser.parse_optional_integer( + allow_boolean=allow_boolean, allow_negative=allow_negative + ) + == expected_value + ) parser = MLIRParser(MLContext(), text) if expected_value is None: with pytest.raises(ParseError): - parser.parse_integer(allow_boolean=allow_boolean, - allow_negative=allow_negative) + parser.parse_integer( + allow_boolean=allow_boolean, allow_negative=allow_negative + ) else: - assert parser.parse_integer( - allow_boolean=allow_boolean, - allow_negative=allow_negative) == expected_value + assert ( + parser.parse_integer( + allow_boolean=allow_boolean, allow_negative=allow_negative + ) + == expected_value + ) -@pytest.mark.parametrize("text, allow_boolean, allow_negative", - [("-false", False, True), ("-false", True, True), - ("-true", False, True), ("-true", True, True), - ("-k", True, True), ("-(", False, True)]) -def test_parse_optional_int_error(text: str, allow_boolean: bool, - allow_negative: bool): +@pytest.mark.parametrize( + "text, allow_boolean, allow_negative", + [ + ("-false", False, True), + ("-false", True, True), + ("-true", False, True), + ("-true", True, True), + ("-k", True, True), + ("-(", False, True), + ], +) +def test_parse_optional_int_error(text: str, allow_boolean: bool, allow_negative: bool): """Test that parsing a negative without an integer after raise an error.""" parser = MLIRParser(MLContext(), text) with pytest.raises(ParseError): - parser.parse_optional_integer(allow_boolean=allow_boolean, - allow_negative=allow_negative) + parser.parse_optional_integer( + allow_boolean=allow_boolean, allow_negative=allow_negative + ) parser = MLIRParser(MLContext(), text) with pytest.raises(ParseError): - parser.parse_integer(allow_boolean=allow_boolean, - allow_negative=allow_negative) - - -@pytest.mark.parametrize("text, expected_value", [ - ("42", 42), - ("-1", -1), - ("true", None), - ("false", None), - ("0x1a", 26), - ("-0x1a", -26), - ('0.', 0.0), - ('1.', 1.0), - ('0.2', 0.2), - ('38.1243', 38.1243), - ('92.54e43', 92.54e43), - ('92.5E43', 92.5E43), - ('43.3e-54', 43.3e-54), - ('32.E+25', 32.E+25), -]) + parser.parse_integer(allow_boolean=allow_boolean, allow_negative=allow_negative) + + +@pytest.mark.parametrize( + "text, expected_value", + [ + ("42", 42), + ("-1", -1), + ("true", None), + ("false", None), + ("0x1a", 26), + ("-0x1a", -26), + ("0.", 0.0), + ("1.", 1.0), + ("0.2", 0.2), + ("38.1243", 38.1243), + ("92.54e43", 92.54e43), + ("92.5E43", 92.5e43), + ("43.3e-54", 43.3e-54), + ("32.E+25", 32.0e25), + ], +) def test_parse_number(text: str, expected_value: int | float | None): parser = MLIRParser(MLContext(), text) assert parser.parse_optional_number() == expected_value @@ -407,12 +462,15 @@ def test_parse_number(text: str, expected_value: int | float | None): assert parser.parse_number() == expected_value -@pytest.mark.parametrize("text", [ - ("-false"), - ("-true"), - ("-k"), - ("-("), -]) +@pytest.mark.parametrize( + "text", + [ + ("-false"), + ("-true"), + ("-k"), + ("-("), + ], +) def test_parse_number_error(text: str): """ Test that parsing a negative without an diff --git a/tests/test_parser_error.py b/tests/test_parser_error.py index f34f0e11d4..03089d23c8 100644 --- a/tests/test_parser_error.py +++ b/tests/test_parser_error.py @@ -3,7 +3,13 @@ from pytest import raises from xdsl.ir import MLContext -from xdsl.irdl import AnyAttr, irdl_op_definition, IRDLOperation, VarOperand, VarOpResult +from xdsl.irdl import ( + AnyAttr, + irdl_op_definition, + IRDLOperation, + VarOperand, + VarOpResult, +) from xdsl.parser import XDSLParser from xdsl.utils.exceptions import ParseError @@ -33,7 +39,8 @@ def check_error(prog: str, line: int, column: int, message: str): break else: assert False, "'{}' not found in an error message {}!".format( - message, e.value.args) + message, e.value.args + ) def test_parser_missing_equal(): @@ -41,14 +48,14 @@ def test_parser_missing_equal(): ctx = MLContext() ctx.register_op(UnkownOp) - prog = \ -""" + prog = """ unknown() { %0 : !i32 unknown() } """ - check_error(prog, 3, 12, - "Operation definitions expect an `=` after op-result-list!") + check_error( + prog, 3, 12, "Operation definitions expect an `=` after op-result-list!" + ) def test_parser_redefined_value(): @@ -56,8 +63,7 @@ def test_parser_redefined_value(): ctx = MLContext() ctx.register_op(UnkownOp) - prog = \ -""" + prog = """ unknown() { %val : !i32 = unknown() %val : !i32 = unknown() @@ -71,8 +77,7 @@ def test_parser_missing_operation_name(): ctx = MLContext() ctx.register_op(UnkownOp) - prog = \ -""" + prog = """ unknown() { %val : !i32 = } @@ -85,8 +90,7 @@ def test_parser_malformed_type(): ctx = MLContext() ctx.register_op(UnkownOp) - prog = \ -""" + prog = """ unknown() { %val : i32 = unknown() } diff --git a/tests/test_pattern_rewriter.py b/tests/test_pattern_rewriter.py index a20846f34b..cb6cd4c5c1 100644 --- a/tests/test_pattern_rewriter.py +++ b/tests/test_pattern_rewriter.py @@ -4,16 +4,19 @@ from xdsl.dialects.builtin import i32, i64, Builtin, IntegerAttr, ModuleOp from xdsl.dialects.scf import If, Scf from xdsl.ir import Block, MLContext, Region, Operation -from xdsl.pattern_rewriter import (PatternRewriteWalker, - op_type_rewrite_pattern, RewritePattern, - PatternRewriter, AnonymousRewritePattern, - GreedyRewritePatternApplier) +from xdsl.pattern_rewriter import ( + PatternRewriteWalker, + op_type_rewrite_pattern, + RewritePattern, + PatternRewriter, + AnonymousRewritePattern, + GreedyRewritePatternApplier, +) from xdsl.parser import Parser from xdsl.utils.hints import isa -def rewrite_and_compare(prog: str, expected_prog: str, - walker: PatternRewriteWalker): +def rewrite_and_compare(prog: str, expected_prog: str, walker: PatternRewriteWalker): ctx = MLContext() ctx.register_dialect(Builtin) ctx.register_dialect(Arith) @@ -30,96 +33,87 @@ def rewrite_and_compare(prog: str, expected_prog: str, def test_non_recursive_rewrite(): """Test a simple non-recursive rewrite""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 43 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" class RewriteConst(RewritePattern): - def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter): if isinstance(op, Constant): new_constant = Constant.from_int_and_width(43, i32) rewriter.replace_matched_op([new_constant]) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(RewriteConst(), apply_recursively=False)) + prog, expected, PatternRewriteWalker(RewriteConst(), apply_recursively=False) + ) def test_non_recursive_rewrite_reversed(): """Test a simple non-recursive rewrite with reverse walk order.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 43 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" class RewriteConst(RewritePattern): - def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter): if isinstance(op, Constant): new_constant = Constant.from_int_and_width(43, i32) rewriter.replace_matched_op([new_constant]) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(RewriteConst(), - apply_recursively=False, - walk_reverse=True)) + prog, + expected, + PatternRewriteWalker( + RewriteConst(), apply_recursively=False, walk_reverse=True + ), + ) def test_op_type_rewrite_pattern_method_decorator(): """Test op_type_rewrite_pattern decorator on methods.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 43 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" class RewriteConst(RewritePattern): - @op_type_rewrite_pattern def match_and_rewrite(self, op: Constant, rewriter: PatternRewriter): rewriter.replace_matched_op(Constant.from_int_and_width(43, i32)) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(RewriteConst(), apply_recursively=False)) + prog, expected, PatternRewriteWalker(RewriteConst(), apply_recursively=False) + ) def test_op_type_rewrite_pattern_static_decorator(): """Test op_type_rewrite_pattern decorator on static functions.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 43 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" @@ -129,23 +123,24 @@ def match_and_rewrite(op: Constant, rewriter: PatternRewriter): rewriter.replace_matched_op(Constant.from_int_and_width(43, i32)) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_op_type_rewrite_pattern_union_type(): """Test op_type_rewrite_pattern decorator on static functions.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 0 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) %2 : !i32 = "test"(%0 : !i32, %1 : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.constant() ["value" = 42 : !i32] %2 : !i32 = "test"(%0 : !i32, %1 : !i32) @@ -156,21 +151,22 @@ def match_and_rewrite(op: Constant | Addi, rewriter: PatternRewriter): rewriter.replace_matched_op(Constant.from_int_and_width(42, i32)) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_recursive_rewriter(): """Test recursive walks on operations created by rewrites.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 1 : !i64] %1 : !i32 = arith.constant() ["value" = 1 : !i64] %2 : !i32 = arith.addi(%0 : !i32, %1 : !i32) @@ -190,28 +186,29 @@ def match_and_rewrite(op: Constant, rewriter: PatternRewriter): if val == 0 or val == 1: return None constant_op = Constant.from_attr( - IntegerAttr.from_int_and_width(val - 1, 64), i32) - constant_one = Constant.from_attr( - IntegerAttr.from_int_and_width(1, 64), i32) + IntegerAttr.from_int_and_width(val - 1, 64), i32 + ) + constant_one = Constant.from_attr(IntegerAttr.from_int_and_width(1, 64), i32) add_op = Addi.get(constant_op, constant_one) rewriter.replace_matched_op([constant_op, constant_one, add_op]) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=True)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=True + ), + ) def test_recursive_rewriter_reversed(): """Test recursive walks on operations created by rewrites, in reverse walk order.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 1 : !i64] %1 : !i32 = arith.constant() ["value" = 1 : !i64] %2 : !i32 = arith.addi(%0 : !i32, %1 : !i32) @@ -231,30 +228,32 @@ def match_and_rewrite(op: Constant, rewriter: PatternRewriter): if val == 0 or val == 1: return None constant_op = Constant.from_attr( - IntegerAttr.from_int_and_width(val - 1, 64), i32) - constant_one = Constant.from_attr( - IntegerAttr.from_int_and_width(1, 64), i32) + IntegerAttr.from_int_and_width(val - 1, 64), i32 + ) + constant_one = Constant.from_attr(IntegerAttr.from_int_and_width(1, 64), i32) add_op = Addi.get(constant_op, constant_one) rewriter.replace_matched_op([constant_op, constant_one, add_op]) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=True, - walk_reverse=True)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), + apply_recursively=True, + walk_reverse=True, + ), + ) def test_greedy_rewrite_pattern_applier(): """Test GreedyRewritePatternApplier.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 43 : !i32] %1 : !i32 = arith.muli(%0 : !i32, %0 : !i32) }""" @@ -268,24 +267,28 @@ def addi_rewrite(op: Addi, rewriter: PatternRewriter): rewriter.replace_matched_op([Muli.get(op.lhs, op.rhs)]) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(GreedyRewritePatternApplier([ - AnonymousRewritePattern(constant_rewrite), - AnonymousRewritePattern(addi_rewrite) - ]), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + GreedyRewritePatternApplier( + [ + AnonymousRewritePattern(constant_rewrite), + AnonymousRewritePattern(addi_rewrite), + ] + ), + apply_recursively=False, + ), + ) def test_insert_op_before_matched_op(): """Test rewrites where operations are inserted before the matched operation.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.constant() ["value" = 5 : !i32] }""" @@ -296,21 +299,22 @@ def match_and_rewrite(cst: Constant, rewriter: PatternRewriter): rewriter.insert_op_before_matched_op(new_cst) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_insert_op_at_pos(): """Test rewrites where operations are inserted with a given position.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.constant() ["value" = 5 : !i32] }""" @@ -321,9 +325,12 @@ def match_and_rewrite(mod: ModuleOp, rewriter: PatternRewriter): rewriter.insert_op_at_pos(new_cst, mod.regions[0].blocks[0], 0) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_insert_op_at_pos_negative(): @@ -339,8 +346,9 @@ def test_insert_op_at_pos_negative(): def match_and_rewrite(mod: ModuleOp, rewriter: PatternRewriter): rewriter.insert_op_at_pos(to_be_inserted, mod.regions[0].blocks[0], -1) - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False).rewrite_module(prog) + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ).rewrite_module(prog) assert to_be_inserted in prog.ops assert prog.ops.index(to_be_inserted) == 1 @@ -349,13 +357,11 @@ def match_and_rewrite(mod: ModuleOp, rewriter: PatternRewriter): def test_insert_op_before(): """Test rewrites where operations are inserted before a given operation.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.constant() ["value" = 5 : !i32] }""" @@ -366,21 +372,22 @@ def match_and_rewrite(mod: ModuleOp, rewriter: PatternRewriter): rewriter.insert_op_before(new_cst, mod.ops[0]) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_insert_op_after(): """Test rewrites where operations are inserted after a given operation.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] %1 : !i32 = arith.constant() ["value" = 42 : !i32] }""" @@ -391,21 +398,22 @@ def match_and_rewrite(mod: ModuleOp, rewriter: PatternRewriter): rewriter.insert_op_after(new_cst, mod.ops[0]) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_insert_op_after_matched_op(): """Test rewrites where operations are inserted after a given operation.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] %1 : !i32 = arith.constant() ["value" = 42 : !i32] }""" @@ -416,21 +424,22 @@ def match_and_rewrite(cst: Constant, rewriter: PatternRewriter): rewriter.insert_op_after_matched_op(new_cst) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_insert_op_after_matched_op_reversed(): """Test rewrites where operations are inserted after a given operation.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] %1 : !i32 = arith.constant() ["value" = 42 : !i32] }""" @@ -441,22 +450,24 @@ def match_and_rewrite(cst: Constant, rewriter: PatternRewriter): rewriter.insert_op_after_matched_op(new_cst) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False, - walk_reverse=True)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), + apply_recursively=False, + walk_reverse=True, + ), + ) def test_operation_deletion(): """Test rewrites where SSA values are deleted.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { }""" @op_type_rewrite_pattern @@ -464,8 +475,8 @@ def match_and_rewrite(op: Constant, rewriter: PatternRewriter): rewriter.erase_matched_op() rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite))) + prog, expected, PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite)) + ) def test_operation_deletion_reversed(): @@ -474,14 +485,12 @@ def test_operation_deletion_reversed(): They have to be deleted in order for the rewrite to not fail. """ - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { }""" def match_and_rewrite(op: Operation, rewriter: PatternRewriter): @@ -489,9 +498,12 @@ def match_and_rewrite(op: Operation, rewriter: PatternRewriter): rewriter.erase_matched_op() rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - walk_reverse=True)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), walk_reverse=True + ), + ) def test_operation_deletion_failure(): @@ -501,8 +513,7 @@ def test_operation_deletion_failure(): ctx.register_dialect(Builtin) ctx.register_dialect(Arith) - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" @@ -526,13 +537,11 @@ def match_and_rewrite(op: Constant, rewriter: PatternRewriter): def test_delete_inner_op(): """Test rewrites where an operation inside a region of the matched op is deleted.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { }""" @op_type_rewrite_pattern @@ -540,20 +549,18 @@ def match_and_rewrite(op: ModuleOp, rewriter: PatternRewriter): rewriter.erase_op(op.ops[0]) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite))) + prog, expected, PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite)) + ) def test_replace_inner_op(): """Test rewrites where an operation inside a region of the matched op is deleted.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] }""" @@ -562,15 +569,14 @@ def match_and_rewrite(op: ModuleOp, rewriter: PatternRewriter): rewriter.replace_op(op.ops[0], [Constant.from_int_and_width(42, i32)]) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite))) + prog, expected, PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite)) + ) def test_block_argument_type_change(): """Test the modification of a block argument type.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { ^0(%1 : !i32): @@ -578,8 +584,7 @@ def test_block_argument_type_change(): } {} }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { ^0(%1 : !i64): @@ -590,28 +595,28 @@ def test_block_argument_type_change(): @op_type_rewrite_pattern def match_and_rewrite(op: If, rewriter: PatternRewriter): - rewriter.modify_block_argument_type(op.true_region.blocks[0].args[0], - i64) + rewriter.modify_block_argument_type(op.true_region.blocks[0].args[0], i64) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_block_argument_erasure(): """Test the erasure of a block argument.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { ^0(%1 : !i32): } {} }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { } { @@ -623,24 +628,25 @@ def match_and_rewrite(op: If, rewriter: PatternRewriter): rewriter.erase_block_argument(op.true_region.blocks[0].args[0]) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_block_argument_insertion(): """Test the insertion of a block argument.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { } { } }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { ^0(%1 : !i32): @@ -653,16 +659,18 @@ def match_and_rewrite(op: If, rewriter: PatternRewriter): rewriter.insert_block_argument(op.true_region.blocks[0], 0, i32) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_inline_block_at_pos(): """Test the inlining of a block at a certain position.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { scf.if(%0 : !i1) { @@ -673,8 +681,7 @@ def test_inline_block_at_pos(): } }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] @@ -690,27 +697,24 @@ def match_and_rewrite(op: If, rewriter: PatternRewriter): first_op = op.true_region.blocks[0].first_op if isinstance(first_op, If): inner_if_block = first_op.true_region.blocks[0] - rewriter.inline_block_at_pos(inner_if_block, - op.true_region.blocks[0], 0) + rewriter.inline_block_at_pos(inner_if_block, op.true_region.blocks[0], 0) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite))) + prog, expected, PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite)) + ) def test_inline_block_before_matched_op(): """Test the inlining of a block before the matched operation.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] } {} }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] %1 : !i32 = arith.constant() ["value" = 2 : !i32] scf.if(%0 : !i1) { @@ -723,16 +727,18 @@ def match_and_rewrite(op: If, rewriter: PatternRewriter): rewriter.inline_block_before_matched_op(op.true_region.blocks[0]) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_inline_block_before(): """Test the inlining of a block before an operation.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { scf.if(%0 : !i1) { @@ -742,8 +748,7 @@ def test_inline_block_before(): } }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] @@ -762,16 +767,18 @@ def match_and_rewrite(op: If, rewriter: PatternRewriter): rewriter.inline_block_before(inner_if_block, first_op) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_inline_block_at_before_when_op_is_matched_op(): """Test the inlining of a block before an operation, being the matched one.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] @@ -779,8 +786,7 @@ def test_inline_block_at_before_when_op_is_matched_op(): } }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] %1 : !i32 = arith.constant() ["value" = 2 : !i32] scf.if(%0 : !i1) { @@ -793,16 +799,18 @@ def match_and_rewrite(op: If, rewriter: PatternRewriter): rewriter.inline_block_before(op.true_region.blocks[0], op) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_inline_block_after(): """Test the inlining of a block after an operation.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { scf.if(%0 : !i1) { @@ -813,8 +821,7 @@ def test_inline_block_after(): } }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { scf.if(%0 : !i1) { @@ -833,24 +840,25 @@ def match_and_rewrite(op: If, rewriter: PatternRewriter): rewriter.inline_block_after(inner_if_block, first_op) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) def test_move_region_contents_to_new_regions(): """Test moving a region outside of a region.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] } }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { } @@ -864,12 +872,14 @@ def test_move_region_contents_to_new_regions(): def match_and_rewrite(op: ModuleOp, rewriter: PatternRewriter): old_if = op.ops[1] assert isinstance(old_if, If) - new_region = rewriter.move_region_contents_to_new_regions( - old_if.regions[0]) + new_region = rewriter.move_region_contents_to_new_regions(old_if.regions[0]) new_if = If.get(old_if.cond, [], new_region, Region([Block()])) rewriter.insert_op_after(new_if, op.ops[1]) rewrite_and_compare( - prog, expected, - PatternRewriteWalker(AnonymousRewritePattern(match_and_rewrite), - apply_recursively=False)) + prog, + expected, + PatternRewriteWalker( + AnonymousRewritePattern(match_and_rewrite), apply_recursively=False + ), + ) diff --git a/tests/test_printer.py b/tests/test_printer.py index 12bbb7bdd8..506c4b61a9 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -7,9 +7,22 @@ from xdsl.dialects.arith import Arith, Addi, Constant from xdsl.dialects.builtin import Builtin, IntAttr, ModuleOp, IntegerType, UnitAttr, i32 from xdsl.dialects.func import Func -from xdsl.ir import Attribute, MLContext, OpResult, Operation, ParametrizedAttribute, Block -from xdsl.irdl import (OptOpAttr, ParameterDef, irdl_attr_definition, - irdl_op_definition, IRDLOperation, Operand) +from xdsl.ir import ( + Attribute, + MLContext, + OpResult, + Operation, + ParametrizedAttribute, + Block, +) +from xdsl.irdl import ( + OptOpAttr, + ParameterDef, + irdl_attr_definition, + irdl_op_definition, + IRDLOperation, + Operand, +) from xdsl.parser import Parser, BaseParser, XDSLParser from xdsl.printer import Printer from xdsl.utils.diagnostic import Diagnostic @@ -27,8 +40,7 @@ def test_simple_forgotten_op(): add.verify() - expected = \ -""" + expected = """ %0 : !i32 = arith.addi(% : !i32, % : !i32) -----------------------^^^^^^^^^^---------------------------------------------------------------- | ERROR: SSAValue is not part of the IR, are you sure all operations are added before their uses? @@ -52,8 +64,7 @@ def test_forgotten_op_non_fail(): mod = ModuleOp([add, add2]) mod.verify() - expected = \ -""" + expected = """ builtin.module() { %0 : !i32 = arith.addi(% : !i32, % : !i32) -----------------------^^^^^^^^^^---------------------------------------------------------------- @@ -79,8 +90,7 @@ class UnitAttrOp(IRDLOperation): def test_unit_attr(): """Test that a UnitAttr can be defined and printed""" - expected = \ -""" + expected = """ unit_attr_op() ["parallelize"] """ @@ -92,14 +102,12 @@ def test_unit_attr(): def test_added_unit_attr(): """Test that a UnitAttr can be added to an op, even if its not defined as a field.""" - expected = \ -""" + expected = """ unit_attr_op() ["parallelize", "vectorize"] """ - unitop = UnitAttrOp.build(attributes={ - "parallelize": UnitAttr([]), - "vectorize": UnitAttr([]) - }) + unitop = UnitAttrOp.build( + attributes={"parallelize": UnitAttr([]), "vectorize": UnitAttr([])} + ) assert_print_op(unitop, expected, None) @@ -115,14 +123,12 @@ def test_added_unit_attr(): def test_op_message(): """Test that an operation message can be printed.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" - expected = \ -""" + expected = """ builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] ^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -147,14 +153,12 @@ def test_op_message(): def test_two_different_op_messages(): """Test that an operation message can be printed.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] ^^^^^^^^^^^^^^^^^^^^^^^^^^ | Test message 1 @@ -181,14 +185,12 @@ def test_two_different_op_messages(): def test_two_same_op_messages(): """Test that an operation message can be printed.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] ^^^^^^^^^^^^^^^^^^^^^^^^^^ | Test message 1 @@ -215,14 +217,12 @@ def test_two_same_op_messages(): def test_op_message_with_region(): """Test that an operation message can be printed on an operation with a region.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" - expected = \ -"""\ + expected = """\ builtin.module() { ^^^^^^^^^^^^^^ | Test @@ -249,14 +249,12 @@ def test_op_message_with_region_and_overflow(): Test that an operation message can be printed on an operation with a region, where the message is bigger than the operation. """ - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" - expected = \ -"""\ + expected = """\ builtin.module() { ^^^^^^^^^^^^^^----- | Test long message @@ -282,8 +280,7 @@ def test_diagnostic(): Test that an operation message can be printed on an operation with a region, where the message is bigger than the operation. """ - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" @@ -315,14 +312,12 @@ def test_print_custom_name(): """ Test that an SSAValue, that is a name and not a number, reserves that name """ - prog = \ - """builtin.module() { + prog = """builtin.module() { %i : !i32 = arith.constant() ["value" = 42 : !i32] %213 : !i32 = arith.addi(%i : !i32, %i : !i32) }""" - expected = \ -"""\ + expected = """\ builtin.module() { %i : !i32 = arith.constant() ["value" = 42 : !i32] %0 : !i32 = arith.addi(%i : !i32, %i : !i32) @@ -365,15 +360,14 @@ class PlusCustomFormatOp(IRDLOperation): res: Annotated[OpResult, IntegerType] @classmethod - def parse(cls, result_types: List[Attribute], - parser: BaseParser) -> PlusCustomFormatOp: - lhs = parser.parse_operand('Expected SSA Value name here!') - parser.parse_characters("+", - "Malformed operation format, expected `+`!") - rhs = parser.parse_operand('Expected SSA Value name here!') + def parse( + cls, result_types: List[Attribute], parser: BaseParser + ) -> PlusCustomFormatOp: + lhs = parser.parse_operand("Expected SSA Value name here!") + parser.parse_characters("+", "Malformed operation format, expected `+`!") + rhs = parser.parse_operand("Expected SSA Value name here!") - return PlusCustomFormatOp.create(operands=[lhs, rhs], - result_types=result_types) + return PlusCustomFormatOp.create(operands=[lhs, rhs], result_types=result_types) def print(self, printer: Printer): printer.print(" ", self.lhs, " + ", self.rhs) @@ -383,14 +377,12 @@ def test_generic_format(): """ Test that we can use generic formats in operations. """ - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = "test.add"(%0: !i32, %0: !i32) }""" - expected = \ -"""\ + expected = """\ builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = test.add %0 + %0 @@ -411,14 +403,12 @@ def test_custom_format(): """ Test that we can use custom formats in operations. """ - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = test.add %0 + %0 }""" - expected = \ -"""\ + expected = """\ builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = test.add %0 + %0 @@ -439,14 +429,12 @@ def test_custom_format_II(): """ Test that we can print using generic formats. """ - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = test.add %0 + %0 }""" - expected = \ -"""\ + expected = """\ "builtin.module"() { %0 : !i32 = "arith.constant"() ["value" = 42 : !i32] %1 : !i32 = "test.add"(%0 : !i32, %0 : !i32) @@ -472,8 +460,7 @@ class CustomFormatAttr(ParametrizedAttribute): @staticmethod def parse_parameters(parser: BaseParser) -> list[Attribute]: parser.parse_char("<") - value = parser.tokenizer.next_token_of_pattern( - re.compile('(zero|one)')) + value = parser.tokenizer.next_token_of_pattern(re.compile("(zero|one)")) if value and value.text == "zero": parser.parse_char(">") return [IntAttr(0)] @@ -495,13 +482,11 @@ def test_custom_format_attr(): """ Test that we can parse and print attributes using custom formats. """ - prog = \ - """builtin.module() { + prog = """builtin.module() { any() ["attr" = !custom] }""" - expected = \ -"""\ + expected = """\ builtin.module() { any() ["attr" = !custom] }""" @@ -521,13 +506,11 @@ def test_parse_generic_format_attr_II(): """ Test that we can parse attributes using generic formats. """ - prog = \ - """builtin.module() { + prog = """builtin.module() { any() ["attr" = !custom] }""" - expected = \ -"""\ + expected = """\ "builtin.module"() { "any"() ["attr" = !"custom">] }""" @@ -547,13 +530,11 @@ def test_parse_generic_format_attr_III(): """ Test that we can parse attributes using generic formats. """ - prog = \ - """builtin.module() { + prog = """builtin.module() { any() ["attr" = !custom] }""" - expected = \ -"""\ + expected = """\ "builtin.module"() { "any"() ["attr" = !"custom">] }""" @@ -573,8 +554,7 @@ def test_foo_string(): """ Fail attribute in purpose. """ - prog = \ - """builtin.module() { + prog = """builtin.module() { any() ["attr" = !"string"<"foo">] }""" diff --git a/tests/test_pyrdl.py b/tests/test_pyrdl.py index f265fe472b..5819a53c3c 100644 --- a/tests/test_pyrdl.py +++ b/tests/test_pyrdl.py @@ -4,9 +4,17 @@ from dataclasses import dataclass from xdsl.ir import Attribute, Data, ParametrizedAttribute -from xdsl.irdl import (AllOf, AnyAttr, AnyOf, AttrConstraint, BaseAttr, - EqAttrConstraint, ParamAttrConstraint, ParameterDef, - irdl_attr_definition) +from xdsl.irdl import ( + AllOf, + AnyAttr, + AnyOf, + AttrConstraint, + BaseAttr, + EqAttrConstraint, + ParamAttrConstraint, + ParameterDef, + irdl_attr_definition, +) from xdsl.parser import BaseParser from xdsl.printer import Printer from xdsl.utils.exceptions import VerifyException @@ -15,6 +23,7 @@ @irdl_attr_definition class BoolData(Data[bool]): """An attribute holding a boolean value.""" + name = "bool" @staticmethod @@ -28,6 +37,7 @@ def print_parameter(self, printer: Printer): @irdl_attr_definition class IntData(Data[int]): """An attribute holding an integer value.""" + name = "int" @staticmethod @@ -41,6 +51,7 @@ def print_parameter(self, printer: Printer): @irdl_attr_definition class DoubleParamAttr(ParametrizedAttribute): """An attribute with two unbounded attribute parameters.""" + name = "param" param1: ParameterDef[Attribute] @@ -64,8 +75,7 @@ def test_eq_attr_verify_wrong_parameters_fail(): eq_true_constraint = EqAttrConstraint(bool_true) with pytest.raises(VerifyException) as e: eq_true_constraint.verify(bool_false) - assert e.value.args[0] == ( - f"Expected attribute {bool_true} but got {bool_false}") + assert e.value.args[0] == (f"Expected attribute {bool_true} but got {bool_false}") def test_eq_attr_verify_wrong_base_fail(): @@ -78,8 +88,7 @@ def test_eq_attr_verify_wrong_base_fail(): eq_true_constraint = EqAttrConstraint(bool_true) with pytest.raises(VerifyException) as e: eq_true_constraint.verify(int_zero) - assert e.value.args[0] == ( - f"Expected attribute {bool_true} but got {int_zero}") + assert e.value.args[0] == (f"Expected attribute {bool_true} but got {int_zero}") def test_base_attr_verify(): @@ -102,7 +111,8 @@ def test_base_attr_verify_wrong_base_fail(): with pytest.raises(VerifyException) as e: eq_true_constraint.verify(int_zero) assert e.value.args[0] == ( - f"{int_zero} should be of base attribute {BoolData.name}") + f"{int_zero} should be of base attribute {BoolData.name}" + ) def test_any_attr_verify(): @@ -119,11 +129,9 @@ class LessThan(AttrConstraint): def verify(self, attr: Attribute) -> None: if not isinstance(attr, IntData): - raise VerifyException( - f"{attr} should be of base attribute {IntData.name}") + raise VerifyException(f"{attr} should be of base attribute {IntData.name}") if attr.data >= self.bound: - raise VerifyException( - f"{attr} should hold a value less than {self.bound}") + raise VerifyException(f"{attr} should hold a value less than {self.bound}") @dataclass @@ -132,11 +140,11 @@ class GreaterThan(AttrConstraint): def verify(self, attr: Attribute) -> None: if not isinstance(attr, IntData): - raise VerifyException( - f"{attr} should be of base attribute {IntData.name}") + raise VerifyException(f"{attr} should be of base attribute {IntData.name}") if attr.data <= self.bound: raise VerifyException( - f"{attr} should hold a value greater than {self.bound}") + f"{attr} should hold a value greater than {self.bound}" + ) def test_anyof_verify(): @@ -194,28 +202,29 @@ def test_allof_verify_fail(): with pytest.raises(VerifyException) as e: constraint.verify(IntData(0)) - assert e.value.args[ - 0] == f"{IntData(0)} should hold a value greater than 0" + assert e.value.args[0] == f"{IntData(0)} should hold a value greater than 0" def test_allof_verify_multiple_failures(): """ - Check that an AllOf constraint provides verification info for all related constraints + Check that an AllOf constraint provides verification info for all related constraints even when one of them fails. """ constraint = AllOf([LessThan(5), GreaterThan(8)]) with pytest.raises(VerifyException) as e: constraint.verify(IntData(7)) - assert e.value.args[ - 0] == f"The following constraints were not satisfied:\n{IntData(7)} should hold a value less than 5\n{IntData(7)} should hold a value greater than 8" + assert ( + e.value.args[0] + == f"The following constraints were not satisfied:\n{IntData(7)} should hold a value less than 5\n{IntData(7)} should hold a value greater than 8" + ) def test_param_attr_verify(): bool_true = BoolData(True) constraint = ParamAttrConstraint( - DoubleParamAttr, [EqAttrConstraint(bool_true), - BaseAttr(IntData)]) + DoubleParamAttr, [EqAttrConstraint(bool_true), BaseAttr(IntData)] + ) constraint.verify(DoubleParamAttr([bool_true, IntData(0)])) constraint.verify(DoubleParamAttr([bool_true, IntData(42)])) @@ -223,18 +232,18 @@ def test_param_attr_verify(): def test_param_attr_verify_base_fail(): bool_true = BoolData(True) constraint = ParamAttrConstraint( - DoubleParamAttr, [EqAttrConstraint(bool_true), - BaseAttr(IntData)]) + DoubleParamAttr, [EqAttrConstraint(bool_true), BaseAttr(IntData)] + ) with pytest.raises(VerifyException) as e: constraint.verify(bool_true) assert e.value.args[0] == ( - f"{bool_true} should be of base attribute {DoubleParamAttr.name}") + f"{bool_true} should be of base attribute {DoubleParamAttr.name}" + ) def test_param_attr_verify_params_num_params_fail(): bool_true = BoolData(True) - constraint = ParamAttrConstraint(DoubleParamAttr, - [EqAttrConstraint(bool_true)]) + constraint = ParamAttrConstraint(DoubleParamAttr, [EqAttrConstraint(bool_true)]) attr = DoubleParamAttr([bool_true, IntData(0)]) with pytest.raises(VerifyException) as e: constraint.verify(attr) @@ -245,15 +254,15 @@ def test_param_attr_verify_params_fail(): bool_true = BoolData(True) bool_false = BoolData(False) constraint = ParamAttrConstraint( - DoubleParamAttr, [EqAttrConstraint(bool_true), - BaseAttr(IntData)]) + DoubleParamAttr, [EqAttrConstraint(bool_true), BaseAttr(IntData)] + ) with pytest.raises(VerifyException) as e: constraint.verify(DoubleParamAttr([bool_true, bool_false])) assert e.value.args[0] == ( - f"{bool_false} should be of base attribute {IntData.name}") + f"{bool_false} should be of base attribute {IntData.name}" + ) with pytest.raises(VerifyException) as e: constraint.verify(DoubleParamAttr([bool_false, IntData(0)])) - assert e.value.args[0] == ( - f"Expected attribute {bool_true} but got {bool_false}") + assert e.value.args[0] == (f"Expected attribute {bool_true} but got {bool_false}") diff --git a/tests/test_rewriter.py b/tests/test_rewriter.py index 86d73959d5..f798ee17eb 100644 --- a/tests/test_rewriter.py +++ b/tests/test_rewriter.py @@ -10,8 +10,9 @@ from xdsl.rewriter import Rewriter -def rewrite_and_compare(prog: str, expected_prog: str, - transformation: Callable[[ModuleOp, Rewriter], None]): +def rewrite_and_compare( + prog: str, expected_prog: str, transformation: Callable[[ModuleOp, Rewriter], None] +): ctx = MLContext() ctx.register_dialect(Builtin) ctx.register_dialect(Arith) @@ -30,13 +31,11 @@ def rewrite_and_compare(prog: str, expected_prog: str, def test_operation_deletion(): """Test rewrites where SSA values are deleted.""" - prog = \ -"""builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 5 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { }""" def transformation(module: ModuleOp, rewriter: Rewriter) -> None: @@ -48,14 +47,12 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: # Test an operation replacement def test_replace_op_one_op(): - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 43 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" @@ -70,14 +67,12 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: # Test an operation replacement with multiple ops def test_replace_op_multiple_op(): - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 2 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 1 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) %2 : !i32 = arith.addi(%1 : !i32, %1 : !i32) @@ -95,15 +90,13 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: # Test an operation replacement with manually specified results def test_replace_op_new_results(): - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 2 : !i32] %1 : !i32 = arith.addi(%0 : !i32, %0 : !i32) %2 : !i32 = arith.muli(%1 : !i32, %1 : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 2 : !i32] %1 : !i32 = arith.muli(%0 : !i32, %0 : !i32) }""" @@ -119,16 +112,14 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: def test_inline_block_at_pos(): """Test the inlining of a block at a certain position.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] } }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] %1 : !i32 = arith.constant() ["value" = 2 : !i32] scf.if(%0 : !i1) { @@ -147,16 +138,14 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: def test_inline_block_before(): """Test the inlining of a block before an operation.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] } }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] %1 : !i32 = arith.constant() ["value" = 2 : !i32] scf.if(%0 : !i1) { @@ -174,16 +163,14 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: def test_inline_block_after(): """Test the inlining of a block after an operation.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] scf.if(%0 : !i1) { %1 : !i32 = arith.constant() ["value" = 2 : !i32] } }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] %1 : !i32 = arith.constant() ["value" = 2 : !i32] scf.if(%0 : !i1) { @@ -202,13 +189,11 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: def test_insert_block(): """Test the insertion of a block in a region.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { ^0: ^1: %0 : !i1 = arith.constant() ["value" = true] @@ -222,13 +207,11 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: def test_insert_block2(): """Test the insertion of a block in a region.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { ^0: %0 : !i1 = arith.constant() ["value" = true] ^1: @@ -242,20 +225,17 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: def test_insert_block_before(): """Test the insertion of a block before another block.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { ^0: ^1: %0 : !i1 = arith.constant() ["value" = true] }""" - def insert_empty_block_before(module: ModuleOp, - rewriter: Rewriter) -> None: + def insert_empty_block_before(module: ModuleOp, rewriter: Rewriter) -> None: rewriter.insert_block_before(Block(), module.regions[0].blocks[0]) rewrite_and_compare(prog, expected, insert_empty_block_before) @@ -263,13 +243,11 @@ def insert_empty_block_before(module: ModuleOp, def test_insert_block_after(): """Test the insertion of a block after another block.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i1 = arith.constant() ["value" = true] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { ^0: %0 : !i1 = arith.constant() ["value" = true] ^1: @@ -283,13 +261,11 @@ def insert_empty_block_after(module: ModuleOp, rewriter: Rewriter) -> None: def test_insert_op_before(): """Test the insertion of an operation before another operation.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 43 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i64 = arith.constant() ["value" = 34 : !i64] %1 : !i32 = arith.constant() ["value" = 43 : !i32] }""" @@ -305,13 +281,11 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: def test_insert_op_after(): """Test the insertion of an operation after another operation.""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 43 : !i32] }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %0 : !i32 = arith.constant() ["value" = 43 : !i32] %1 : !i64 = arith.constant() ["value" = 34 : !i64] }""" @@ -327,14 +301,12 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: def test_preserve_naming_single_op(): """Test the preservation of names of SSAValues""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %i : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%i : !i32, %i : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %i : !i32 = arith.constant() ["value" = 1 : !i32] %0 : !i32 = arith.addi(%i : !i32, %i : !i32) }""" @@ -350,14 +322,12 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: def test_preserve_naming_multiple_ops(): """Test the preservation of names of SSAValues for transformations to multiple ops""" - prog = \ - """builtin.module() { + prog = """builtin.module() { %i : !i32 = arith.constant() ["value" = 42 : !i32] %1 : !i32 = arith.addi(%i : !i32, %i : !i32) }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { %i : !i32 = arith.constant() ["value" = 1 : !i32] %i_1 : !i32 = arith.addi(%i : !i32, %i : !i32) %0 : !i32 = arith.addi(%i_1 : !i32, %i_1 : !i32) @@ -375,13 +345,11 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None: def test_no_result_rewriter(): """Test rewriter on ops without results""" - prog = \ - """builtin.module() { + prog = """builtin.module() { func.return() }""" - expected = \ -"""builtin.module() { + expected = """builtin.module() { scf.yield() }""" diff --git a/tests/test_ssa_value.py b/tests/test_ssa_value.py index 8c0236dcbd..0bfafa0de2 100644 --- a/tests/test_ssa_value.py +++ b/tests/test_ssa_value.py @@ -38,13 +38,16 @@ def test_var_mixed_builder(): _ = SSAValue.get(op) -@pytest.mark.parametrize("name", [ - "test", - "-2", - "test_123", - "kebab-case-name", - None, -]) +@pytest.mark.parametrize( + "name", + [ + "test", + "-2", + "test_123", + "kebab-case-name", + None, + ], +) def test_ssa_value_name_hints(name: str | None): r""" As per the MLIR language reference, legal SSA value names must conform to @@ -64,7 +67,7 @@ def test_ssa_value_name_hints(name: str | None): assert val.name == name -@pytest.mark.parametrize("name", ['&', '#', '%2', '"', '::', '42']) +@pytest.mark.parametrize("name", ["&", "#", "%2", '"', "::", "42"]) def test_invalid_ssa_vals(name: str): """ This test tests invalid name hints that raise an error, because diff --git a/tests/test_traits.py b/tests/test_traits.py index 25fa6b371c..0ad9b5cae9 100644 --- a/tests/test_traits.py +++ b/tests/test_traits.py @@ -33,8 +33,10 @@ def verify(self, op: Operation): assert isinstance(op.results[0].typ, IntegerType) assert isinstance(op.operands[0].typ, IntegerType) if op.results[0].typ.width.data >= op.operands[0].typ.width.data: - raise VerifyException("Operation has a result bitwidth greater " - "or equal to the operand bitwidth.") + raise VerifyException( + "Operation has a result bitwidth greater " + "or equal to the operand bitwidth." + ) @dataclass(frozen=True) @@ -58,8 +60,9 @@ def verify(self, op: Operation): sum_bitwidth += result.typ.width.data if sum_bitwidth >= self.max_sum: - raise VerifyException("Operation has a bitwidth sum " - f"greater or equal to {self.max_sum}.") + raise VerifyException( + "Operation has a bitwidth sum " f"greater or equal to {self.max_sum}." + ) @irdl_op_definition @@ -86,9 +89,7 @@ def test_get_traits_of_type(): Test the `get_traits_of_type` `Operation` method on a simple operation definition. """ - assert TestOp.get_traits_of_type(LargerOperandTrait) == [ - LargerOperandTrait() - ] + assert TestOp.get_traits_of_type(LargerOperandTrait) == [LargerOperandTrait()] assert TestOp.get_traits_of_type(LargerResultTrait) == [] assert TestOp.get_traits_of_type(BitwidthSumLessThanTrait) == [ BitwidthSumLessThanTrait(64) @@ -105,14 +106,16 @@ def test_verifier(): op = TestOp.create(operands=[operand1], result_types=[i32]) with pytest.raises(VerifyException) as e: op.verify() - assert e.value.args[0] == ("Operation has a result bitwidth greater" - " or equal to the operand bitwidth.") + assert e.value.args[0] == ( + "Operation has a result bitwidth greater" " or equal to the operand bitwidth." + ) op = TestOp.create(operands=[operand64], result_types=[i32]) with pytest.raises(VerifyException) as e: op.verify() - assert e.value.args[0] == ("Operation has a bitwidth sum " - "greater or equal to 64.") + assert e.value.args[0] == ( + "Operation has a bitwidth sum " "greater or equal to 64." + ) op = TestOp.create(operands=[operand32], result_types=[i1]) op.verify() @@ -144,8 +147,8 @@ def test_trait_inheritance(): Check that traits are correctly inherited from parent classes. """ assert TestCopyOp.traits == frozenset( - [LargerOperandTrait(), - BitwidthSumLessThanTrait(64)]) + [LargerOperandTrait(), BitwidthSumLessThanTrait(64)] + ) @irdl_op_definition diff --git a/tests/xdsl_opt/test_xdsl_opt.py b/tests/xdsl_opt/test_xdsl_opt.py index cfa926ddd1..010e805945 100644 --- a/tests/xdsl_opt/test_xdsl_opt.py +++ b/tests/xdsl_opt/test_xdsl_opt.py @@ -11,37 +11,39 @@ def test_opt(): opt = xDSLOptMain(args=[]) - assert list(opt.available_frontends.keys()) == ['xdsl', 'mlir'] - assert list(opt.available_targets.keys()) == ['xdsl', 'irdl', 'mlir'] + assert list(opt.available_frontends.keys()) == ["xdsl", "mlir"] + assert list(opt.available_targets.keys()) == ["xdsl", "irdl", "mlir"] assert list(opt.available_passes.keys()) == [ - 'lower-mpi', - 'convert-stencil-to-ll-mlir', - 'convert-stencil-to-gpu', - 'stencil-shape-inference', - 'stencil-to-local-2d-horizontal', - 'frontend-desymrefy', + "lower-mpi", + "convert-stencil-to-ll-mlir", + "convert-stencil-to-gpu", + "stencil-shape-inference", + "stencil-to-local-2d-horizontal", + "frontend-desymrefy", ] def test_empty_program(): - filename = 'tests/xdsl_opt/empty_program.xdsl' + filename = "tests/xdsl_opt/empty_program.xdsl" opt = xDSLOptMain(args=[filename]) f = StringIO("") with redirect_stdout(f): opt.run() - with open(filename, 'r') as file: + with open(filename, "r") as file: expected = file.read() assert f.getvalue().strip() == expected.strip() @pytest.mark.parametrize( "args, expected_error", - [(['tests/xdsl_opt/not_module.xdsl'], "Expected ModuleOp at top level!"), - (['tests/xdsl_opt/not_module.mlir'], "Expected ModuleOp at top level!"), - (['tests/xdsl_opt/empty_program.wrong' - ], "Unrecognized file extension 'wrong'")]) + [ + (["tests/xdsl_opt/not_module.xdsl"], "Expected ModuleOp at top level!"), + (["tests/xdsl_opt/not_module.mlir"], "Expected ModuleOp at top level!"), + (["tests/xdsl_opt/empty_program.wrong"], "Unrecognized file extension 'wrong'"), + ], +) def test_error_on_run(args, expected_error): opt = xDSLOptMain(args=args) @@ -53,8 +55,13 @@ def test_error_on_run(args, expected_error): @pytest.mark.parametrize( "args, expected_error", - [(['tests/xdsl_opt/empty_program.xdsl', '-p', 'wrong' - ], "Unrecognized pass: wrong")]) + [ + ( + ["tests/xdsl_opt/empty_program.xdsl", "-p", "wrong"], + "Unrecognized pass: wrong", + ) + ], +) def test_error_on_construction(args, expected_error): with pytest.raises(Exception) as e: opt = xDSLOptMain(args=args) @@ -63,7 +70,7 @@ def test_error_on_construction(args, expected_error): def test_wrong_target(): - filename = 'tests/xdsl_opt/empty_program.xdsl' + filename = "tests/xdsl_opt/empty_program.xdsl" opt = xDSLOptMain(args=[filename]) opt.args.target = "wrong" @@ -74,31 +81,28 @@ def test_wrong_target(): def test_print_to_file(): - filename_in = 'tests/xdsl_opt/empty_program.xdsl' - filename_out = 'tests/xdsl_opt/empty_program.out' + filename_in = "tests/xdsl_opt/empty_program.xdsl" + filename_out = "tests/xdsl_opt/empty_program.out" - opt = xDSLOptMain(args=[filename_in, '-o', filename_out]) + opt = xDSLOptMain(args=[filename_in, "-o", filename_out]) opt.run() - with open(filename_in, 'r') as file: + with open(filename_in, "r") as file: inp = file.read() - with open(filename_out, 'r') as file: + with open(filename_out, "r") as file: expected = file.read() assert inp.strip() == expected.strip() def test_operation_deletion(): - filename_in = 'tests/xdsl_opt/constant_program.xdsl' - filename_out = 'tests/xdsl_opt/empty_program.xdsl' + filename_in = "tests/xdsl_opt/constant_program.xdsl" + filename_out = "tests/xdsl_opt/empty_program.xdsl" class xDSLOptMainPass(xDSLOptMain): - def register_all_passes(self): - class RemoveConstantPass(ModulePass): - - name = 'remove-constant' + name = "remove-constant" def apply(self, ctx: MLContext, op: builtin.ModuleOp): if isinstance(op, builtin.ModuleOp): @@ -106,12 +110,12 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp): self.register_pass(RemoveConstantPass) - opt = xDSLOptMainPass(args=[filename_in, '-p', 'remove-constant']) + opt = xDSLOptMainPass(args=[filename_in, "-p", "remove-constant"]) f = StringIO("") with redirect_stdout(f): opt.run() - with open(filename_out, 'r') as file: + with open(filename_out, "r") as file: expected = file.read() assert f.getvalue().strip() == expected.strip() diff --git a/versioneer.py b/versioneer.py index 1b9e84ef32..6d2d220007 100644 --- a/versioneer.py +++ b/versioneer.py @@ -306,11 +306,13 @@ def get_root(): setup_py = os.path.join(root, "setup.py") versioneer_py = os.path.join(root, "versioneer.py") if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") + err = ( + "Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND')." + ) raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools @@ -323,8 +325,10 @@ def get_root(): me_dir = os.path.normcase(os.path.splitext(me)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" % - (os.path.dirname(me), versioneer_py)) + print( + "Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(me), versioneer_py) + ) except NameError: pass return root @@ -382,12 +386,7 @@ def decorate(f): return decorate -def run_command(commands, - args, - cwd=None, - verbose=False, - hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None @@ -413,7 +412,7 @@ def run_command(commands, return None, None else: if verbose: - print("unable to find command, tried %s" % (commands, )) + print("unable to find command, tried %s" % (commands,)) return None, None stdout = p.communicate()[0].strip() if sys.version_info[0] >= 3: @@ -426,7 +425,9 @@ def run_command(commands, return stdout, p.returncode -LONG_VERSION_PY["git"] = ''' +LONG_VERSION_PY[ + "git" +] = ''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -1001,7 +1002,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1018,7 +1019,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] if verbose: print("picking %s" % r) return { @@ -1052,9 +1053,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], - cwd=root, - hide_stderr=True) + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1097,7 +1096,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX @@ -1106,8 +1105,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? - pieces[ - "error"] = "unable to parse git-describe output: '%s'" % describe_out + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -1121,7 +1119,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): tag_prefix, ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -1132,13 +1130,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces @@ -1195,7 +1193,7 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return { - "version": dirname[len(parentdir_prefix):], + "version": dirname[len(parentdir_prefix) :], "full-revisionid": None, "dirty": False, "error": None, @@ -1206,8 +1204,10 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -1236,11 +1236,13 @@ def versions_from_file(filename): contents = f.read() except EnvironmentError: raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) @@ -1249,10 +1251,7 @@ def versions_from_file(filename): def write_to_version_file(filename, versions): """Write the given version number to the given _version.py file.""" os.unlink(filename) - contents = json.dumps(versions, - sort_keys=True, - indent=1, - separators=(",", ": ")) + contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) @@ -1453,8 +1452,9 @@ def get_versions(verbose=False): handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or cfg.verbose - assert (cfg.versionfile_source - is not None), "please set versioneer.versionfile_source" + assert ( + cfg.versionfile_source is not None + ), "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) @@ -1588,7 +1588,6 @@ def run(self): from distutils.command.build_py import build_py as _build_py class cmd_build_py(_build_py): - def run(self): root = get_root() cfg = get_config_from_root(root) @@ -1597,8 +1596,7 @@ def run(self): # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) @@ -1615,7 +1613,6 @@ def run(self): # ... class cmd_build_exe(_build_exe): - def run(self): root = get_root() cfg = get_config_from_root(root) @@ -1629,13 +1626,15 @@ def run(self): with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] f.write( - LONG % { + LONG + % { "DOLLAR": "$", "STYLE": cfg.style, "TAG_PREFIX": cfg.tag_prefix, "PARENTDIR_PREFIX": cfg.parentdir_prefix, "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + } + ) cmds["build_exe"] = cmd_build_exe del cmds["build_py"] @@ -1647,7 +1646,6 @@ def run(self): from py2exe.build_exe import py2exe as _py2exe # py2 class cmd_py2exe(_py2exe): - def run(self): root = get_root() cfg = get_config_from_root(root) @@ -1661,13 +1659,15 @@ def run(self): with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] f.write( - LONG % { + LONG + % { "DOLLAR": "$", "STYLE": cfg.style, "TAG_PREFIX": cfg.tag_prefix, "PARENTDIR_PREFIX": cfg.parentdir_prefix, "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + } + ) cmds["py2exe"] = cmd_py2exe @@ -1678,7 +1678,6 @@ def run(self): from distutils.command.sdist import sdist as _sdist class cmd_sdist(_sdist): - def run(self): versions = get_versions() self._versioneer_generated_versions = versions @@ -1696,8 +1695,9 @@ def make_release_tree(self, base_dir, files): # updated value target_versionfile = os.path.join(base_dir, cfg.versionfile_source) print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) + write_to_version_file( + target_versionfile, self._versioneer_generated_versions + ) cmds["sdist"] = cmd_sdist @@ -1754,13 +1754,12 @@ def do_setup(): try: cfg = get_config_from_root(root) except ( - EnvironmentError, - configparser.NoSectionError, - configparser.NoOptionError, + EnvironmentError, + configparser.NoSectionError, + configparser.NoOptionError, ) as e: if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) + print("Adding sample versioneer config to setup.cfg", file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: f.write(SAMPLE_CONFIG) print(CONFIG_ERROR, file=sys.stderr) @@ -1770,13 +1769,15 @@ def do_setup(): with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] f.write( - LONG % { + LONG + % { "DOLLAR": "$", "STYLE": cfg.style, "TAG_PREFIX": cfg.tag_prefix, "PARENTDIR_PREFIX": cfg.parentdir_prefix, "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + } + ) ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") if os.path.exists(ipy): @@ -1820,8 +1821,10 @@ def do_setup(): else: print(" 'versioneer.py' already in MANIFEST.in") if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) + print( + " appending versionfile_source ('%s') to MANIFEST.in" + % cfg.versionfile_source + ) with open(manifest_in, "a") as f: f.write("include %s\n" % cfg.versionfile_source) else: diff --git a/xdsl/_version.py b/xdsl/_version.py index fd27ae7314..24bc1dccac 100644 --- a/xdsl/_version.py +++ b/xdsl/_version.py @@ -32,6 +32,7 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str style: str tag_prefix: str @@ -75,12 +76,7 @@ def decorate(f): return decorate -def run_command(commands, - args, - cwd=None, - verbose=False, - hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None @@ -106,7 +102,7 @@ def run_command(commands, return None, None else: if verbose: - print("unable to find command, tried %s" % (commands, )) + print("unable to find command, tried %s" % (commands,)) return None, None stdout = p.communicate()[0].strip() if sys.version_info[0] >= 3: @@ -132,7 +128,7 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return { - "version": dirname[len(parentdir_prefix):], + "version": dirname[len(parentdir_prefix) :], "full-revisionid": None, "dirty": False, "error": None, @@ -143,8 +139,10 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -200,7 +198,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -217,7 +215,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] if verbose: print("picking %s" % r) return { @@ -251,9 +249,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], - cwd=root, - hide_stderr=True) + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -296,7 +292,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX @@ -305,8 +301,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? - pieces[ - "error"] = "unable to parse git-describe output: '%s'" % describe_out + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -320,7 +315,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): tag_prefix, ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -331,13 +326,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces @@ -528,8 +523,7 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass diff --git a/xdsl/builder.py b/xdsl/builder.py index 8a5dda1d72..b01e46befd 100644 --- a/xdsl/builder.py +++ b/xdsl/builder.py @@ -16,7 +16,7 @@ class Builder: """ A helper class to construct IRs, by keeping track of where to insert an operation. Currently the insertion point is always at the end of the block. - In the future will closely follow the API of `OpBuilder` in MLIR, inserting + In the future will closely follow the API of `OpBuilder` in MLIR, inserting at arbitrary locations. https://mlir.llvm.org/doxygen/classmlir_1_1OpBuilder.html @@ -36,8 +36,8 @@ def insert(self, op: OperationInvT) -> OperationInvT: if implicit_builder is not None and implicit_builder is not self: raise ValueError( - 'Cannot insert operation explicitly when an implicit ' - 'builder exists.') + "Cannot insert operation explicitly when an implicit " "builder exists." + ) self.block.add_op(op) return op @@ -54,7 +54,7 @@ def _region_no_args(func: Callable[[Builder], None]) -> Region: @staticmethod def _region_args( - input_types: Sequence[Attribute] | ArrayAttr[Attribute] + input_types: Sequence[Attribute] | ArrayAttr[Attribute], ) -> Callable[[_CallableRegionFuncType], Region]: """ Decorator for constructing a single-block region, containing the implementation of a @@ -78,7 +78,7 @@ def wrapper(func: _CallableRegionFuncType) -> Region: @overload @staticmethod def region( - input: Sequence[Attribute] | ArrayAttr[Attribute] + input: Sequence[Attribute] | ArrayAttr[Attribute], ) -> Callable[[_CallableRegionFuncType], Region]: """ Annotation used to construct a Region tuple from a function. @@ -107,8 +107,7 @@ def region(input: Callable[[Builder], None]) -> Region: @staticmethod def region( - input: Sequence[Attribute] | ArrayAttr[Attribute] - | Callable[[Builder], None] + input: Sequence[Attribute] | ArrayAttr[Attribute] | Callable[[Builder], None] ) -> Callable[[_CallableRegionFuncType], Region] | Region: if isinstance(input, Callable): return Builder._region_no_args(input) @@ -130,7 +129,7 @@ def _implicit_region_no_args(func: Callable[[], None]) -> Region: @staticmethod def _implicit_region_args( - input_types: Sequence[Attribute] | ArrayAttr[Attribute] + input_types: Sequence[Attribute] | ArrayAttr[Attribute], ) -> Callable[[_CallableImplicitRegionFuncType], Region]: """ Decorator for constructing a single-block region, containing the implementation of a @@ -155,7 +154,7 @@ def wrapper(func: _CallableImplicitRegionFuncType) -> Region: @overload @staticmethod def implicit_region( - input: Sequence[Attribute] | ArrayAttr[Attribute] + input: Sequence[Attribute] | ArrayAttr[Attribute], ) -> Callable[[_CallableImplicitRegionFuncType], Region]: """ Annotation used to construct a Region tuple from a function. @@ -184,8 +183,7 @@ def implicit_region(input: Callable[[], None]) -> Region: @staticmethod def implicit_region( - input: Sequence[Attribute] | ArrayAttr[Attribute] - | Callable[[], None] + input: Sequence[Attribute] | ArrayAttr[Attribute] | Callable[[], None] ) -> Callable[[_CallableImplicitRegionFuncType], Region] | Region: if isinstance(input, Callable): return Builder._implicit_region_no_args(input) @@ -202,6 +200,7 @@ class _ImplicitBuilderStack(threading.local): Stores the stack of implicit builders for use in @Builder.implicit_region, empty by default. There is a stack per thread, guaranteed by inheriting from `threading.local`. """ + stack: list[Builder] = field(default_factory=list) def push(self, builder: Builder) -> None: @@ -231,9 +230,12 @@ class _ImplicitBuilder(contextlib.AbstractContextManager[None]): def __enter__(self) -> None: type(self)._stack.push(self.builder) - def __exit__(self, __exc_type: type[BaseException] | None, - __exc_value: BaseException | None, - __traceback: TracebackType | None) -> bool | None: + def __exit__( + self, + __exc_type: type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> bool | None: type(self)._stack.pop(self.builder) @classmethod @@ -242,9 +244,9 @@ def get(cls) -> Builder | None: _CallableRegionFuncType: TypeAlias = Callable[ - [Builder, tuple[BlockArgument, ...]], None] -_CallableImplicitRegionFuncType: TypeAlias = Callable[ - [tuple[BlockArgument, ...]], None] + [Builder, tuple[BlockArgument, ...]], None +] +_CallableImplicitRegionFuncType: TypeAlias = Callable[[tuple[BlockArgument, ...]], None] def _op_init_callback(op: Operation): diff --git a/xdsl/dialects/affine.py b/xdsl/dialects/affine.py index 70b752c9c4..0942513a69 100644 --- a/xdsl/dialects/affine.py +++ b/xdsl/dialects/affine.py @@ -4,8 +4,14 @@ from xdsl.dialects.builtin import AnyIntegerAttr, IndexType, IntegerAttr from xdsl.ir import Attribute, Operation, SSAValue, Block, Region, Dialect -from xdsl.irdl import (OpAttr, VarOpResult, irdl_op_definition, VarOperand, - AnyAttr, IRDLOperation) +from xdsl.irdl import ( + OpAttr, + VarOpResult, + irdl_op_definition, + VarOperand, + AnyAttr, + IRDLOperation, +) @irdl_op_definition @@ -28,7 +34,7 @@ def verify_(self) -> None: raise Exception("Expected the same amount of operands and results") operand_types = [SSAValue.get(op).typ for op in self.operands] - if (operand_types != [res.typ for res in self.results]): + if operand_types != [res.typ for res in self.results]: raise Exception( "Expected all operands and result pairs to have matching types" ) @@ -42,11 +48,13 @@ def verify_(self) -> None: ) @staticmethod - def from_region(operands: list[Operation | SSAValue], - lower_bound: int | AnyIntegerAttr, - upper_bound: int | AnyIntegerAttr, - region: Region, - step: int | AnyIntegerAttr = 1) -> For: + def from_region( + operands: list[Operation | SSAValue], + lower_bound: int | AnyIntegerAttr, + upper_bound: int | AnyIntegerAttr, + region: Region, + step: int | AnyIntegerAttr = 1, + ) -> For: if isinstance(lower_bound, int): lower_bound = IntegerAttr.from_index_int_value(lower_bound) if isinstance(upper_bound, int): @@ -59,21 +67,29 @@ def from_region(operands: list[Operation | SSAValue], "upper_bound": upper_bound, "step": step, } - return For.build(operands=[[operand for operand in operands]], - result_types=[result_types], - attributes=attributes, - regions=[region]) + return For.build( + operands=[[operand for operand in operands]], + result_types=[result_types], + attributes=attributes, + regions=[region], + ) @staticmethod - def from_callable(operands: list[Operation | SSAValue], - lower_bound: int | AnyIntegerAttr, - upper_bound: int | AnyIntegerAttr, - body: Block.BlockCallback, - step: int | AnyIntegerAttr = 1) -> For: + def from_callable( + operands: list[Operation | SSAValue], + lower_bound: int | AnyIntegerAttr, + upper_bound: int | AnyIntegerAttr, + body: Block.BlockCallback, + step: int | AnyIntegerAttr = 1, + ) -> For: arg_types = [IndexType()] + [SSAValue.get(op).typ for op in operands] - return For.from_region(operands, lower_bound, upper_bound, - Region(Block.from_callable(arg_types, body)), - step) + return For.from_region( + operands, + lower_bound, + upper_bound, + Region(Block.from_callable(arg_types, body)), + step, + ) @irdl_op_definition @@ -83,8 +99,7 @@ class Yield(IRDLOperation): @staticmethod def get(*operands: SSAValue | Operation) -> Yield: - return Yield.create( - operands=[SSAValue.get(operand) for operand in operands]) + return Yield.create(operands=[SSAValue.get(operand) for operand in operands]) Affine = Dialect([For, Yield], []) diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index bd66395c63..53fbb50d06 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -5,13 +5,31 @@ from enum import Enum from typing import Annotated, TypeVar, Union, Set, Optional -from xdsl.dialects.builtin import (ContainerOf, Float16Type, Float64Type, - IndexType, IntAttr, IntegerType, - Float32Type, IntegerAttr, FloatAttr, - Attribute, AnyFloat, AnyIntegerAttr) +from xdsl.dialects.builtin import ( + ContainerOf, + Float16Type, + Float64Type, + IndexType, + IntAttr, + IntegerType, + Float32Type, + IntegerAttr, + FloatAttr, + Attribute, + AnyFloat, + AnyIntegerAttr, +) from xdsl.ir import Operation, SSAValue, Dialect, OpResult, Data -from xdsl.irdl import (AnyOf, irdl_op_definition, OpAttr, AnyAttr, Operand, - irdl_attr_definition, OptOpAttr, IRDLOperation) +from xdsl.irdl import ( + AnyOf, + irdl_op_definition, + OpAttr, + AnyAttr, + Operand, + irdl_attr_definition, + OptOpAttr, + IRDLOperation, +) from xdsl.parser import BaseParser from xdsl.printer import Printer from xdsl.utils.exceptions import VerifyException @@ -19,7 +37,7 @@ signlessIntegerLike = ContainerOf(AnyOf([IntegerType, IndexType])) floatingPointLike = ContainerOf(AnyOf([Float16Type, Float32Type, Float64Type])) -_FloatTypeT = TypeVar('_FloatTypeT', bound=AnyFloat) +_FloatTypeT = TypeVar("_FloatTypeT", bound=AnyFloat) class FastMathFlag(Enum): @@ -63,10 +81,10 @@ class FastMathFlagsAttr(Data[FastMathFlags]): @staticmethod def parse_parameter(parser: BaseParser) -> FastMathFlags: - flags = parser.parse_list_of(lambda: FastMathFlags.try_parse(parser), - "Expected fast math flags") - result = functools.reduce(FastMathFlags.__or__, flags, - FastMathFlags(set())) + flags = parser.parse_list_of( + lambda: FastMathFlags.try_parse(parser), "Expected fast math flags" + ) + result = functools.reduce(FastMathFlags.__or__, flags, FastMathFlags(set())) return result def print_parameter(self, printer: Printer): @@ -77,8 +95,7 @@ def print_parameter(self, printer: Printer): printer.print("fast") else: # make sure we emit flags in a consistent order - printer.print(",".join(flag.value for flag in FastMathFlag - if flag in data)) + printer.print(",".join(flag.value for flag in FastMathFlag if flag in data)) @staticmethod def from_flags(flags: FastMathFlags): @@ -96,18 +113,20 @@ def from_attr(attr: Attribute, typ: Attribute) -> Constant: return Constant.create(result_types=[typ], attributes={"value": attr}) @staticmethod - def from_int_and_width(val: int | IntAttr, - typ: int | IntegerType | IndexType) -> Constant: + def from_int_and_width( + val: int | IntAttr, typ: int | IntegerType | IndexType + ) -> Constant: if isinstance(typ, int): typ = IntegerType(typ) return Constant.create( - result_types=[typ], - attributes={"value": IntegerAttr.from_params(val, typ)}) + result_types=[typ], attributes={"value": IntegerAttr.from_params(val, typ)} + ) # To add tests for this constructor @staticmethod - def from_float_and_width(val: float | FloatAttr[_FloatTypeT], - typ: _FloatTypeT) -> Constant: + def from_float_and_width( + val: float | FloatAttr[_FloatTypeT], typ: _FloatTypeT + ) -> Constant: if isinstance(val, float): val = FloatAttr(val, typ) return Constant.create(result_types=[typ], attributes={"value": val}) @@ -120,12 +139,9 @@ class BinaryOperation(IRDLOperation): # TODO replace with trait def verify_(self) -> None: if len(self.operands) != 2 or len(self.results) != 1: - raise VerifyException( - "Binary operation expects 2 operands and 1 result.") - if not (self.operands[0].typ == self.operands[1].typ == - self.results[0].typ): - raise VerifyException( - "expect all input and result types to be equal") + raise VerifyException("Binary operation expects 2 operands and 1 result.") + if not (self.operands[0].typ == self.operands[1].typ == self.results[0].typ): + raise VerifyException("expect all input and result types to be equal") def __hash__(self) -> int: return id(self) @@ -139,11 +155,11 @@ class Addi(BinaryOperation): result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> Addi: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> Addi: operand1 = SSAValue.get(operand1) - return Addi.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return Addi.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -154,11 +170,11 @@ class Muli(BinaryOperation): result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> Muli: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> Muli: operand1 = SSAValue.get(operand1) - return Muli.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return Muli.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -169,11 +185,11 @@ class Subi(BinaryOperation): result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> Subi: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> Subi: operand1 = SSAValue.get(operand1) - return Subi.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return Subi.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -183,17 +199,18 @@ class DivUI(BinaryOperation): the most significant, i.e. for `i16` given two's complement representation, `6 / -2 = 6 / (2^16 - 2) = 0`. """ + name: str = "arith.divui" lhs: Annotated[Operand, signlessIntegerLike] rhs: Annotated[Operand, signlessIntegerLike] result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> DivUI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> DivUI: operand1 = SSAValue.get(operand1) - return DivUI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return DivUI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -202,17 +219,18 @@ class DivSI(BinaryOperation): Signed integer division. Rounds towards zero. Treats the leading bit as sign, i.e. `6 / -2 = -3`. """ + name: str = "arith.divsi" lhs: Annotated[Operand, signlessIntegerLike] rhs: Annotated[Operand, signlessIntegerLike] result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> DivSI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> DivSI: operand1 = SSAValue.get(operand1) - return DivSI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return DivSI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -220,17 +238,20 @@ class FloorDivSI(BinaryOperation): """ Signed floor integer division. Rounds towards negative infinity i.e. `5 / -2 = -3`. """ + name: str = "arith.floordivsi" lhs: Annotated[Operand, signlessIntegerLike] rhs: Annotated[Operand, signlessIntegerLike] result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> FloorDivSI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> FloorDivSI: operand1 = SSAValue.get(operand1) - return FloorDivSI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return FloorDivSI.build( + operands=[operand1, operand2], result_types=[operand1.typ] + ) @irdl_op_definition @@ -241,11 +262,13 @@ class CeilDivSI(BinaryOperation): result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> CeilDivSI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> CeilDivSI: operand1 = SSAValue.get(operand1) - return CeilDivSI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return CeilDivSI.build( + operands=[operand1, operand2], result_types=[operand1.typ] + ) @irdl_op_definition @@ -256,11 +279,13 @@ class CeilDivUI(BinaryOperation): result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> CeilDivUI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> CeilDivUI: operand1 = SSAValue.get(operand1) - return CeilDivUI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return CeilDivUI.build( + operands=[operand1, operand2], result_types=[operand1.typ] + ) @irdl_op_definition @@ -271,11 +296,11 @@ class RemUI(BinaryOperation): result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> RemUI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> RemUI: operand1 = SSAValue.get(operand1) - return RemUI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return RemUI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -286,11 +311,11 @@ class RemSI(BinaryOperation): result: Annotated[OpResult, IntegerType] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> RemSI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> RemSI: operand1 = SSAValue.get(operand1) - return RemSI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return RemSI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -301,11 +326,11 @@ class MinUI(BinaryOperation): result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> MinUI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> MinUI: operand1 = SSAValue.get(operand1) - return MinUI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return MinUI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -316,11 +341,11 @@ class MaxUI(BinaryOperation): result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> MaxUI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> MaxUI: operand1 = SSAValue.get(operand1) - return MaxUI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return MaxUI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -331,11 +356,11 @@ class MinSI(BinaryOperation): result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> MinSI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> MinSI: operand1 = SSAValue.get(operand1) - return MinSI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return MinSI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -346,11 +371,11 @@ class MaxSI(BinaryOperation): result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> MaxSI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> MaxSI: operand1 = SSAValue.get(operand1) - return MaxSI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return MaxSI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -361,11 +386,11 @@ class AndI(BinaryOperation): result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> AndI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> AndI: operand1 = SSAValue.get(operand1) - return AndI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return AndI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -376,11 +401,11 @@ class OrI(BinaryOperation): result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> OrI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> OrI: operand1 = SSAValue.get(operand1) - return OrI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return OrI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -391,11 +416,11 @@ class XOrI(BinaryOperation): result: Annotated[OpResult, signlessIntegerLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> XOrI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> XOrI: operand1 = SSAValue.get(operand1) - return XOrI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return XOrI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -404,6 +429,7 @@ class ShLI(IRDLOperation): The `shli` operation shifts an integer value to the left by a variable amount. The low order bits are filled with zeros. """ + name: str = "arith.shli" lhs: Annotated[Operand, IntegerType] rhs: Annotated[Operand, IntegerType] @@ -412,15 +438,14 @@ class ShLI(IRDLOperation): # TODO replace with trait def verify_(self) -> None: if self.lhs.typ != self.rhs.typ or self.rhs.typ != self.result.typ: - raise VerifyException( - "expect all input and output types to be equal") + raise VerifyException("expect all input and output types to be equal") @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> ShLI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> ShLI: operand1 = SSAValue.get(operand1) - return ShLI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return ShLI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -430,6 +455,7 @@ class ShRUI(IRDLOperation): amount. The integer is interpreted as unsigned. The high order bits are always filled with zeros. """ + name: str = "arith.shrui" lhs: Annotated[Operand, signlessIntegerLike] rhs: Annotated[Operand, signlessIntegerLike] @@ -438,15 +464,14 @@ class ShRUI(IRDLOperation): # TODO replace with trait def verify_(self) -> None: if self.lhs.typ != self.rhs.typ or self.rhs.typ != self.result.typ: - raise VerifyException( - "expect all input and output types to be equal") + raise VerifyException("expect all input and output types to be equal") @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> ShRUI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> ShRUI: operand1 = SSAValue.get(operand1) - return ShRUI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return ShRUI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -457,6 +482,7 @@ class ShRSI(IRDLOperation): output are filled with copies of the most-significant bit of the shifted value (which means that the sign of the value is preserved). """ + name: str = "arith.shrsi" lhs: Annotated[Operand, IntegerType] rhs: Annotated[Operand, IntegerType] @@ -465,19 +491,18 @@ class ShRSI(IRDLOperation): # TODO replace with trait def verify_(self) -> None: if self.lhs.typ != self.rhs.typ or self.rhs.typ != self.result.typ: - raise VerifyException( - "expect all input and output types to be equal") + raise VerifyException("expect all input and output types to be equal") @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> ShRSI: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> ShRSI: operand1 = SSAValue.get(operand1) - return ShRSI.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return ShRSI.build(operands=[operand1, operand2], result_types=[operand1.typ]) @dataclass -class ComparisonOperation(): +class ComparisonOperation: """ A generic comparison operation, operation definitions inherit this class. @@ -498,7 +523,8 @@ class ComparisonOperation(): @staticmethod def _get_comparison_predicate( - mnemonic: str, comparison_operations: dict[str, int]) -> int: + mnemonic: str, comparison_operations: dict[str, int] + ) -> int: if mnemonic in comparison_operations: return comparison_operations[mnemonic] else: @@ -507,8 +533,10 @@ def _get_comparison_predicate( @staticmethod def _validate_operand_types(operand1: SSAValue, operand2: SSAValue): if operand1.typ != operand2.typ: - raise TypeError(f"Comparison operands must have same type, but " - f"provided {operand1.typ} and {operand2.typ}") + raise TypeError( + f"Comparison operands must have same type, but " + f"provided {operand1.typ} and {operand2.typ}" + ) @irdl_op_definition @@ -539,6 +567,7 @@ class Cmpi(IRDLOperation, ComparisonOperation): %x = "arith.cmpi"(%lhs, %rhs) {predicate = 0 : i64} : (vector<4xi64>, vector<4xi64>) -> vector<4xi1> """ + name: str = "arith.cmpi" predicate: OpAttr[AnyIntegerAttr] lhs: Annotated[Operand, IntegerType] @@ -546,8 +575,11 @@ class Cmpi(IRDLOperation, ComparisonOperation): result: Annotated[OpResult, IntegerType(1)] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue], arg: int | str) -> Cmpi: + def get( + operand1: Union[Operation, SSAValue], + operand2: Union[Operation, SSAValue], + arg: int | str, + ) -> Cmpi: operand1 = SSAValue.get(operand1) operand2 = SSAValue.get(operand2) Cmpi._validate_operand_types(operand1, operand2) @@ -563,15 +595,15 @@ def get(operand1: Union[Operation, SSAValue], "ult": 6, "ule": 7, "ugt": 8, - "uge": 9 + "uge": 9, } - arg = Cmpi._get_comparison_predicate(arg, - cmpi_comparison_operations) + arg = Cmpi._get_comparison_predicate(arg, cmpi_comparison_operations) return Cmpi.build( operands=[operand1, operand2], result_types=[IntegerType(1)], - attributes={"predicate": IntegerAttr.from_int_and_width(arg, 64)}) + attributes={"predicate": IntegerAttr.from_int_and_width(arg, 64)}, + ) @irdl_op_definition @@ -599,6 +631,7 @@ class Cmpf(IRDLOperation, ComparisonOperation): %r2 = arith.cmpf ult, %0, %1 : tensor<42x42xf64> %r3 = "arith.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1 """ + name: str = "arith.cmpf" predicate: OpAttr[AnyIntegerAttr] lhs: Annotated[Operand, floatingPointLike] @@ -606,8 +639,9 @@ class Cmpf(IRDLOperation, ComparisonOperation): result: Annotated[OpResult, IntegerType(1)] @staticmethod - def get(operand1: SSAValue | Operation, operand2: SSAValue | Operation, - arg: int | str) -> Cmpf: + def get( + operand1: SSAValue | Operation, operand2: SSAValue | Operation, arg: int | str + ) -> Cmpf: operand1 = SSAValue.get(operand1) operand2 = SSAValue.get(operand2) @@ -630,15 +664,15 @@ def get(operand1: SSAValue | Operation, operand2: SSAValue | Operation, "ule": 12, "une": 13, "uno": 14, - "true": 15 + "true": 15, } - arg = Cmpf._get_comparison_predicate(arg, - cmpf_comparison_operations) + arg = Cmpf._get_comparison_predicate(arg, cmpf_comparison_operations) return Cmpf.build( operands=[operand1, operand2], result_types=[IntegerType(1)], - attributes={"predicate": IntegerAttr.from_int_and_width(arg, 64)}) + attributes={"predicate": IntegerAttr.from_int_and_width(arg, 64)}, + ) @irdl_op_definition @@ -649,6 +683,7 @@ class Select(IRDLOperation): the second operand is chosen, otherwise the third operand is chosen. The second and the third operand must have the same type. """ + name: str = "arith.select" cond: Annotated[Operand, IntegerType(1)] # should be unsigned lhs: Annotated[Operand, Attribute] @@ -660,16 +695,18 @@ def verify_(self) -> None: if self.cond.typ != IntegerType(1): raise VerifyException("Condition has to be of type !i1") if self.lhs.typ != self.rhs.typ or self.rhs.typ != self.result.typ: - raise VerifyException( - "expect all input and output types to be equal") + raise VerifyException("expect all input and output types to be equal") @staticmethod - def get(operand1: Union[Operation, SSAValue], operand2: Union[Operation, - SSAValue], - operand3: Union[Operation, SSAValue]) -> Select: + def get( + operand1: Union[Operation, SSAValue], + operand2: Union[Operation, SSAValue], + operand3: Union[Operation, SSAValue], + ) -> Select: operand2 = SSAValue.get(operand2) - return Select.build(operands=[operand1, operand2, operand3], - result_types=[operand2.typ]) + return Select.build( + operands=[operand1, operand2, operand3], result_types=[operand2.typ] + ) @irdl_op_definition @@ -680,11 +717,11 @@ class Addf(BinaryOperation): result: Annotated[OpResult, floatingPointLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> Addf: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> Addf: operand1 = SSAValue.get(operand1) - return Addf.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return Addf.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -695,11 +732,11 @@ class Subf(BinaryOperation): result: Annotated[OpResult, floatingPointLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> Subf: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> Subf: operand1 = SSAValue.get(operand1) - return Subf.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return Subf.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -710,11 +747,11 @@ class Mulf(BinaryOperation): result: Annotated[OpResult, floatingPointLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> Mulf: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> Mulf: operand1 = SSAValue.get(operand1) - return Mulf.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return Mulf.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -725,11 +762,11 @@ class Divf(BinaryOperation): result: Annotated[OpResult, floatingPointLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> Divf: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> Divf: operand1 = SSAValue.get(operand1) - return Divf.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return Divf.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -740,13 +777,15 @@ class Negf(IRDLOperation): result: Annotated[OpResult, floatingPointLike] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> Negf: - + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> Negf: operand = SSAValue.get(operand) - return Negf.build(attributes={"fastmath": fastmath}, - operands=[operand], - result_types=[operand.typ]) + return Negf.build( + attributes={"fastmath": fastmath}, + operands=[operand], + result_types=[operand.typ], + ) @irdl_op_definition @@ -757,11 +796,11 @@ class Maxf(BinaryOperation): result: Annotated[OpResult, floatingPointLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> Maxf: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> Maxf: operand1 = SSAValue.get(operand1) - return Maxf.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return Maxf.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -772,11 +811,11 @@ class Minf(BinaryOperation): result: Annotated[OpResult, floatingPointLike] @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> Minf: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> Minf: operand1 = SSAValue.get(operand1) - return Minf.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return Minf.build(operands=[operand1, operand2], result_types=[operand1.typ]) @irdl_op_definition @@ -843,7 +882,6 @@ def get(op: SSAValue | Operation, target_typ: AnyFloat): Arith = Dialect( [ Constant, - # Integer-like Addi, Subi, @@ -859,33 +897,27 @@ def get(op: SSAValue | Operation, target_typ: AnyFloat): MaxSI, MinUI, MaxUI, - # Float-like Addf, Subf, Mulf, Divf, Negf, - # Comparison/Condition Cmpi, Cmpf, Select, - # Logical AndI, OrI, XOrI, - # Shift ShLI, ShRUI, ShRSI, - # Min/Max Minf, Maxf, - # Casts IndexCastOp, FPToSIOp, @@ -895,4 +927,5 @@ def get(op: SSAValue | Operation, target_typ: AnyFloat): ], [ FastMathFlagsAttr, - ]) + ], +) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index d7bd1fc2c7..aac7f478bf 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -3,19 +3,53 @@ from dataclasses import dataclass from enum import Enum -from typing import (Iterable, TypeAlias, List, cast, Type, Sequence, - TYPE_CHECKING, Any, TypeVar, overload, Iterator) - -from xdsl.ir import (Block, Data, TypeAttribute, ParametrizedAttribute, - Operation, Region, Attribute, Dialect, SSAValue, - AttributeCovT, AttributeInvT) - -from xdsl.irdl import (AllOf, OpAttr, VarOpResult, VarOperand, VarRegion, - irdl_attr_definition, attr_constr_coercion, - irdl_data_definition, irdl_to_attr_constraint, - irdl_op_definition, ParameterDef, SingleBlockRegion, - Generic, GenericData, AttrConstraint, AnyAttr, - IRDLOperation) +from typing import ( + Iterable, + TypeAlias, + List, + cast, + Type, + Sequence, + TYPE_CHECKING, + Any, + TypeVar, + overload, + Iterator, +) + +from xdsl.ir import ( + Block, + Data, + TypeAttribute, + ParametrizedAttribute, + Operation, + Region, + Attribute, + Dialect, + SSAValue, + AttributeCovT, + AttributeInvT, +) + +from xdsl.irdl import ( + AllOf, + OpAttr, + VarOpResult, + VarOperand, + VarRegion, + irdl_attr_definition, + attr_constr_coercion, + irdl_data_definition, + irdl_to_attr_constraint, + irdl_op_definition, + ParameterDef, + SingleBlockRegion, + Generic, + GenericData, + AttrConstraint, + AnyAttr, + IRDLOperation, +) from xdsl.utils.deprecation import deprecated_constructor from xdsl.utils.exceptions import VerifyException @@ -28,6 +62,7 @@ @irdl_attr_definition class NoneAttr(ParametrizedAttribute): """An attribute representing the absence of an attribute.""" + name: str = "none" @@ -37,6 +72,7 @@ class ArrayOfConstraint(AttrConstraint): A constraint that enforces an ArrayData whose elements all satisfy the elem_constr. """ + elem_constr: AttrConstraint def __init__(self, constr: Attribute | Type[Attribute] | AttrConstraint): @@ -50,8 +86,7 @@ def verify(self, attr: Attribute) -> None: @irdl_attr_definition -class ArrayAttr(GenericData[tuple[AttributeCovT, ...]], - Iterable[AttributeCovT]): +class ArrayAttr(GenericData[tuple[AttributeCovT, ...]], Iterable[AttributeCovT]): name: str = "array" def __init__(self, param: Iterable[AttributeCovT]) -> None: @@ -60,8 +95,7 @@ def __init__(self, param: Iterable[AttributeCovT]) -> None: @staticmethod def parse_parameter(parser: BaseParser) -> tuple[AttributeCovT]: parser.parse_char("[") - data = parser.parse_list_of(parser.try_parse_attribute, - "Expected attribute") + data = parser.parse_list_of(parser.try_parse_attribute, "Expected attribute") parser.parse_char("]") # the type system can't ensure that the elements are of type _ArrayAttrT result = cast(tuple[AttributeCovT], tuple(data)) @@ -78,20 +112,24 @@ def generic_constraint_coercion(args: tuple[Any]) -> AttrConstraint: return ArrayOfConstraint(irdl_to_attr_constraint(args[0])) if len(args) == 0: return ArrayOfConstraint(AnyAttr()) - raise TypeError(f"Attribute ArrayAttr expects at most type" - f" parameter, but {len(args)} were given") + raise TypeError( + f"Attribute ArrayAttr expects at most type" + f" parameter, but {len(args)} were given" + ) def verify(self) -> None: if not isinstance(self.data, tuple): raise VerifyException( f"Wrong type given to attribute {self.name}: got" f" {type(self.data)}, but expected list of" - " attributes") + " attributes" + ) for idx, val in enumerate(self.data): if not isinstance(val, Attribute): raise VerifyException( f"{self.name} data expects attribute list, but {idx} " - f"element is of type {type(val)}") + f"element is of type {type(val)}" + ) @staticmethod @deprecated_constructor @@ -161,13 +199,14 @@ class SymbolRefAttr(ParametrizedAttribute): def __init__( self, root: str | StringAttr, - nested: list[str] | list[StringAttr] | ArrayAttr[StringAttr] = [] + nested: list[str] | list[StringAttr] | ArrayAttr[StringAttr] = [], ) -> None: if isinstance(root, str): root = StringAttr(root) if isinstance(nested, list): nested = ArrayAttr( - [StringAttr(x) if isinstance(x, str) else x for x in nested]) + [StringAttr(x) if isinstance(x, str) else x for x in nested] + ) super().__init__([root, nested]) @staticmethod @@ -177,14 +216,15 @@ def from_str(root: str, nested: List[str] = []) -> SymbolRefAttr: @staticmethod @deprecated_constructor - def from_string_attr(root: StringAttr, - nested: List[StringAttr] = []) -> SymbolRefAttr: + def from_string_attr( + root: StringAttr, nested: List[StringAttr] = [] + ) -> SymbolRefAttr: return SymbolRefAttr(root, nested) def string_value(self): root = self.root_reference.data for ref in self.nested_references.data: - root += '.' + ref.data + root += "." + ref.data return root @@ -198,7 +238,7 @@ def parse_parameter(parser: BaseParser) -> int: return data def print_parameter(self, printer: Printer) -> None: - printer.print_string(f'{self.data}') + printer.print_string(f"{self.data}") @staticmethod @deprecated_constructor @@ -222,8 +262,9 @@ class SignednessAttr(Data[Signedness]): @staticmethod def parse_parameter(parser: BaseParser) -> Signedness: - value = parser.expect(parser.try_parse_bare_id, - "Expected `signless`, `signed`, or `unsigned`.") + value = parser.expect( + parser.try_parse_bare_id, "Expected `signless`, `signed`, or `unsigned`." + ) if value.text == "signless": return Signedness.SIGNLESS elif value.text == "signed": @@ -256,9 +297,9 @@ class IntegerType(ParametrizedAttribute, TypeAttribute): signedness: ParameterDef[SignednessAttr] def __init__( - self, - data: int | IntAttr, - signedness: Signedness | SignednessAttr = Signedness.SIGNLESS + self, + data: int | IntAttr, + signedness: Signedness | SignednessAttr = Signedness.SIGNLESS, ) -> None: if isinstance(data, int): data = IntAttr(data) @@ -269,8 +310,8 @@ def __init__( @staticmethod @deprecated_constructor def from_width( - width: int, - signedness: Signedness = Signedness.SIGNLESS) -> IntegerType: + width: int, signedness: Signedness = Signedness.SIGNLESS + ) -> IntegerType: return IntegerType(width, signedness) @@ -289,11 +330,10 @@ class IndexType(ParametrizedAttribute): name = "index" -_IntegerAttrTyp = TypeVar("_IntegerAttrTyp", - bound=IntegerType | IndexType, - covariant=True) -_IntegerAttrTypInv = TypeVar("_IntegerAttrTypInv", - bound=IntegerType | IndexType) +_IntegerAttrTyp = TypeVar( + "_IntegerAttrTyp", bound=IntegerType | IndexType, covariant=True +) +_IntegerAttrTypInv = TypeVar("_IntegerAttrTypInv", bound=IntegerType | IndexType) @irdl_attr_definition @@ -303,17 +343,20 @@ class IntegerAttr(Generic[_IntegerAttrTyp], ParametrizedAttribute): typ: ParameterDef[_IntegerAttrTyp] @overload - def __init__(self: IntegerAttr[_IntegerAttrTyp], value: int | IntAttr, - typ: _IntegerAttrTyp) -> None: + def __init__( + self: IntegerAttr[_IntegerAttrTyp], value: int | IntAttr, typ: _IntegerAttrTyp + ) -> None: ... @overload - def __init__(self: IntegerAttr[IntegerType], value: int | IntAttr, - typ: int) -> None: + def __init__( + self: IntegerAttr[IntegerType], value: int | IntAttr, typ: int + ) -> None: ... - def __init__(self, value: int | IntAttr, - typ: int | IntegerType | IndexType) -> None: + def __init__( + self, value: int | IntAttr, typ: int | IntegerType | IndexType + ) -> None: if isinstance(value, int): value = IntAttr(value) if isinstance(typ, int): @@ -368,8 +411,9 @@ class Float128Type(ParametrizedAttribute, TypeAttribute): name: str = "f128" -AnyFloat: TypeAlias = (BFloat16Type | Float16Type | Float32Type | Float64Type - | Float80Type | Float128Type) +AnyFloat: TypeAlias = ( + BFloat16Type | Float16Type | Float32Type | Float64Type | Float80Type | Float128Type +) @irdl_attr_definition @@ -378,12 +422,11 @@ class FloatData(Data[float]): @staticmethod def parse_parameter(parser: BaseParser) -> float: - span = parser.expect(parser.try_parse_float_literal, - "Expect float literal") + span = parser.expect(parser.try_parse_float_literal, "Expect float literal") return float(span.text) def print_parameter(self, printer: Printer) -> None: - printer.print_string(f'{self.data}') + printer.print_string(f"{self.data}") @staticmethod @deprecated_constructor @@ -411,8 +454,9 @@ def __init__(self, data: float | FloatData, type: _FloatAttrTyp) -> None: def __init__(self, data: float | FloatData, type: int) -> None: ... - def __init__(self, data: float | FloatData, - type: int | _FloatAttrTyp | AnyFloat) -> None: + def __init__( + self, data: float | FloatData, type: int | _FloatAttrTyp | AnyFloat + ) -> None: if isinstance(data, float): data = FloatData(data) if isinstance(type, int): @@ -432,8 +476,7 @@ def __init__(self, data: float | FloatData, @staticmethod @deprecated_constructor - def from_value(value: float, - type: _FloatAttrTypInv) -> FloatAttr[_FloatAttrTypInv]: + def from_value(value: float, type: _FloatAttrTypInv) -> FloatAttr[_FloatAttrTypInv]: return FloatAttr(FloatData.from_float(value), type) @staticmethod @@ -464,8 +507,9 @@ def parse_parameter(parser: BaseParser) -> dict[str, Attribute]: def print_parameter(self, printer: Printer) -> None: printer.print_string("{") - printer.print_dictionary(self.data, printer.print_string_literal, - printer.print_attribute) + printer.print_dictionary( + self.data, printer.print_string_literal, printer.print_attribute + ) printer.print_string("}") @staticmethod @@ -477,16 +521,19 @@ def verify(self) -> None: raise VerifyException( f"Wrong type given to attribute {self.name}: got" f" {type(self.data)}, but expected dictionary of" - " attributes") + " attributes" + ) for key, val in self.data.items(): if not isinstance(key, str): raise VerifyException( f"{self.name} key expects str, but {key} " - f"element is of type {type(key)}") + f"element is of type {type(key)}" + ) if not isinstance(val, Attribute): raise VerifyException( f"{self.name} key expects attribute, but {val} " - f"element is of type {type(val)}") + f"element is of type {type(val)}" + ) @staticmethod @deprecated_constructor @@ -500,7 +547,8 @@ def from_dict(data: dict[str | StringAttr, Attribute]) -> DictionaryAttr: if not isinstance(k, str): raise TypeError( f"DictionaryAttr.from_dict expects keys to" - f" be of type str or StringAttr, but {type(k)} provided") + f" be of type str or StringAttr, but {type(k)} provided" + ) to_add_data[k] = v return DictionaryAttr(to_add_data) @@ -543,12 +591,14 @@ def verify(self): if self.get_num_scalable_dims() < 0: raise VerifyException( f"Number of scalable dimensions {self.get_num_dims()} cannot" - " be negative") + " be negative" + ) if self.get_num_scalable_dims() > self.get_num_dims(): raise VerifyException( f"Number of scalable dimensions {self.get_num_scalable_dims()}" " cannot be larger than number of dimensions" - f" {self.get_num_dims()}") + f" {self.get_num_dims()}" + ) @staticmethod def from_element_type_and_shape( @@ -558,20 +608,27 @@ def from_element_type_and_shape( ) -> VectorType[AttributeInvT]: if isinstance(num_scalable_dims, int): num_scalable_dims = IntAttr(num_scalable_dims) - return VectorType([ - ArrayAttr([ - IntegerAttr[IntegerType].from_index_int_value(d) if isinstance( - d, int) else d for d in shape - ]), - referenced_type, - num_scalable_dims, - ]) + return VectorType( + [ + ArrayAttr( + [ + IntegerAttr[IntegerType].from_index_int_value(d) + if isinstance(d, int) + else d + for d in shape + ] + ), + referenced_type, + num_scalable_dims, + ] + ) @staticmethod def from_params( referenced_type: AttributeInvT, shape: ArrayAttr[IntegerAttr[IntegerType]] = ArrayAttr( - [IntegerAttr.from_int_and_width(1, 64)]), + [IntegerAttr.from_int_and_width(1, 64)] + ), num_scalable_dims: IntAttr = IntAttr(0), ) -> VectorType[AttributeInvT]: return VectorType([shape, referenced_type, num_scalable_dims]) @@ -598,23 +655,30 @@ def get_shape(self) -> List[int]: def from_type_and_list( referenced_type: AttributeInvT, shape: Sequence[int | IntegerAttr[IndexType]] | None = None, - encoding: Attribute = NoneAttr() + encoding: Attribute = NoneAttr(), ) -> TensorType[AttributeInvT]: if shape is None: shape = [1] - return TensorType([ - ArrayAttr([ - IntegerAttr[IntegerType].from_index_int_value(d) if isinstance( - d, int) else d for d in shape - ]), referenced_type, encoding - ]) + return TensorType( + [ + ArrayAttr( + [ + IntegerAttr[IntegerType].from_index_int_value(d) + if isinstance(d, int) + else d + for d in shape + ] + ), + referenced_type, + encoding, + ] + ) @staticmethod def from_params( referenced_type: AttributeInvT, - shape: AnyArrayAttr = AnyArrayAttr( - [IntegerAttr.from_int_and_width(1, 64)]), - encoding: Attribute = NoneAttr() + shape: AnyArrayAttr = AnyArrayAttr([IntegerAttr.from_int_and_width(1, 64)]), + encoding: Attribute = NoneAttr(), ) -> TensorType[AttributeInvT]: return TensorType([shape, referenced_type, encoding]) @@ -623,16 +687,13 @@ def from_params( @irdl_attr_definition -class UnrankedTensorType(Generic[AttributeCovT], ParametrizedAttribute, - TypeAttribute): +class UnrankedTensorType(Generic[AttributeCovT], ParametrizedAttribute, TypeAttribute): name = "unranked_tensor" element_type: ParameterDef[AttributeCovT] @staticmethod - def from_type( - referenced_type: AttributeInvT - ) -> UnrankedTensorType[AttributeInvT]: + def from_type(referenced_type: AttributeInvT) -> UnrankedTensorType[AttributeInvT]: return UnrankedTensorType([referenced_type]) @@ -642,11 +703,12 @@ def from_type( @dataclass(init=False) class ContainerOf(AttrConstraint): """A type constraint that can be nested once in a vector or a tensor.""" + elem_constr: AttrConstraint def __init__( - self, - elem_constr: Attribute | type[Attribute] | AttrConstraint) -> None: + self, elem_constr: Attribute | type[Attribute] | AttrConstraint + ) -> None: self.elem_constr = attr_constr_coercion(elem_constr) def verify(self, attr: Attribute) -> None: @@ -656,12 +718,15 @@ def verify(self, attr: Attribute) -> None: self.elem_constr.verify(attr) -VectorOrTensorOf: TypeAlias = (VectorType[AttributeCovT] - | TensorType[AttributeCovT] - | UnrankedTensorType[AttributeCovT]) +VectorOrTensorOf: TypeAlias = ( + VectorType[AttributeCovT] + | TensorType[AttributeCovT] + | UnrankedTensorType[AttributeCovT] +) -RankedVectorOrTensorOf: TypeAlias = (VectorType[AttributeCovT] - | TensorType[AttributeCovT]) +RankedVectorOrTensorOf: TypeAlias = ( + VectorType[AttributeCovT] | TensorType[AttributeCovT] +) @dataclass @@ -713,19 +778,23 @@ class VectorBaseTypeAndRankConstraint(AttrConstraint): """The expected vector rank.""" def verify(self, attr: Attribute) -> None: - constraint = AllOf([ - VectorBaseTypeConstraint(self.expected_type), - VectorRankConstraint(self.expected_rank) - ]) + constraint = AllOf( + [ + VectorBaseTypeConstraint(self.expected_type), + VectorRankConstraint(self.expected_rank), + ] + ) constraint.verify(attr) @irdl_attr_definition class DenseIntOrFPElementsAttr(ParametrizedAttribute): name = "dense" - type: ParameterDef[RankedVectorOrTensorOf[IntegerType] - | RankedVectorOrTensorOf[IndexType] - | RankedVectorOrTensorOf[AnyFloat]] + type: ParameterDef[ + RankedVectorOrTensorOf[IntegerType] + | RankedVectorOrTensorOf[IndexType] + | RankedVectorOrTensorOf[AnyFloat] + ] data: ParameterDef[ArrayAttr[AnyIntegerAttr] | ArrayAttr[AnyFloatAttr]] # The type stores the shape data @@ -754,12 +823,11 @@ def shape_is_complete(self) -> bool: @staticmethod def create_dense_index( type: RankedVectorOrTensorOf[IndexType], - data: Sequence[int] | Sequence[IntegerAttr[IndexType]] + data: Sequence[int] | Sequence[IntegerAttr[IndexType]], ) -> DenseIntOrFPElementsAttr: if len(data) and isinstance(data[0], int): attr_list = [ - IntegerAttr.from_index_int_value(d) - for d in cast(Sequence[int], data) + IntegerAttr.from_index_int_value(d) for d in cast(Sequence[int], data) ] else: attr_list = cast(Sequence[IntegerAttr[IndexType]], data) @@ -769,7 +837,7 @@ def create_dense_index( @staticmethod def create_dense_int( type: RankedVectorOrTensorOf[IntegerType], - data: Sequence[int] | Sequence[IntegerAttr[IntegerType]] + data: Sequence[int] | Sequence[IntegerAttr[IntegerType]], ) -> DenseIntOrFPElementsAttr: if len(data) and isinstance(data[0], int): attr_list = [ @@ -784,7 +852,7 @@ def create_dense_int( @staticmethod def create_dense_float( type: RankedVectorOrTensorOf[AnyFloat], - data: Sequence[int | float] | Sequence[AnyFloatAttr] + data: Sequence[int | float] | Sequence[AnyFloatAttr], ) -> DenseIntOrFPElementsAttr: if len(data) and isinstance(data[0], int | float): attr_list = [ @@ -801,7 +869,8 @@ def create_dense_float( def from_list( type: RankedVectorOrTensorOf[AnyFloat | IntegerType | IndexType], data: Sequence[int] - | Sequence[IntegerAttr[IndexType]] | Sequence[IntegerAttr[IntegerType]] + | Sequence[IntegerAttr[IndexType]] + | Sequence[IntegerAttr[IntegerType]], ) -> DenseIntOrFPElementsAttr: ... @@ -809,54 +878,48 @@ def from_list( @staticmethod def from_list( type: RankedVectorOrTensorOf[AnyFloat | IntegerType | IndexType], - data: Sequence[int | float] | Sequence[AnyFloatAttr] + data: Sequence[int | float] | Sequence[AnyFloatAttr], ) -> DenseIntOrFPElementsAttr: ... @staticmethod def from_list( type: RankedVectorOrTensorOf[AnyFloat | IntegerType | IndexType], - data: Sequence[int | float] - | Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr] + data: Sequence[int | float] | Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr], ) -> DenseIntOrFPElementsAttr: if isinstance(type.element_type, IntegerType): new_type = cast(RankedVectorOrTensorOf[IntegerType], type) - new_data = cast(Sequence[int] | Sequence[IntegerAttr[IntegerType]], - data) - return DenseIntOrFPElementsAttr.create_dense_int( - new_type, new_data) + new_data = cast(Sequence[int] | Sequence[IntegerAttr[IntegerType]], data) + return DenseIntOrFPElementsAttr.create_dense_int(new_type, new_data) elif isinstance(type.element_type, IndexType): new_type = cast(RankedVectorOrTensorOf[IndexType], type) - new_data = cast(Sequence[int] | Sequence[IntegerAttr[IndexType]], - data) - return DenseIntOrFPElementsAttr.create_dense_index( - new_type, new_data) + new_data = cast(Sequence[int] | Sequence[IntegerAttr[IndexType]], data) + return DenseIntOrFPElementsAttr.create_dense_index(new_type, new_data) elif isinstance(type.element_type, AnyFloat): new_type = cast(RankedVectorOrTensorOf[AnyFloat], type) - new_data = cast( - Sequence[int | float] | Sequence[FloatAttr[AnyFloat]], data) - return DenseIntOrFPElementsAttr.create_dense_float( - new_type, new_data) + new_data = cast(Sequence[int | float] | Sequence[FloatAttr[AnyFloat]], data) + return DenseIntOrFPElementsAttr.create_dense_float(new_type, new_data) else: raise TypeError(f"Unsupported element type {type.element_type}") @staticmethod def vector_from_list( - data: List[int] | List[float], - typ: IntegerType | IndexType | AnyFloat + data: List[int] | List[float], typ: IntegerType | IndexType | AnyFloat ) -> DenseIntOrFPElementsAttr: t = VectorType.from_element_type_and_shape(typ, [len(data)]) return DenseIntOrFPElementsAttr.from_list(t, data) @staticmethod - def tensor_from_list(data: List[int] | List[float] - | List[IntegerAttr[IndexType]] - | List[IntegerAttr[IntegerType]] - | List[AnyFloatAttr], - typ: IntegerType | IndexType | AnyFloat, - shape: List[int] = []) -> DenseIntOrFPElementsAttr: - t = AnyTensorType.from_type_and_list( - typ, shape if len(shape) else [len(data)]) + def tensor_from_list( + data: List[int] + | List[float] + | List[IntegerAttr[IndexType]] + | List[IntegerAttr[IntegerType]] + | List[AnyFloatAttr], + typ: IntegerType | IndexType | AnyFloat, + shape: List[int] = [], + ) -> DenseIntOrFPElementsAttr: + t = AnyTensorType.from_type_and_list(typ, shape if len(shape) else [len(data)]) return DenseIntOrFPElementsAttr.from_list(t, data) @@ -870,8 +933,7 @@ class DenseResourceAttr(ParametrizedAttribute): type: ParameterDef[Attribute] @staticmethod - def from_params(handle: str | StringAttr, - type: Attribute) -> DenseResourceAttr: + def from_params(handle: str | StringAttr, type: Attribute) -> DenseResourceAttr: if isinstance(handle, str): handle = StringAttr(handle) return DenseResourceAttr([handle, type]) @@ -890,17 +952,20 @@ def verify(self): if isinstance(d, FloatData): raise VerifyException( "dense array of integer element type " - "should only contain integers") + "should only contain integers" + ) else: for d in self.data.data: if isinstance(d, IntAttr): - raise VerifyException("dense array of float element type " - "should only contain floats") + raise VerifyException( + "dense array of float element type " + "should only contain floats" + ) @staticmethod def create_dense_int_or_index( - typ: IntegerType | IndexType, - data: Sequence[int] | Sequence[IntAttr]) -> DenseArrayBase: + typ: IntegerType | IndexType, data: Sequence[int] | Sequence[IntAttr] + ) -> DenseArrayBase: if len(data) and isinstance(data[0], int): attr_list = [IntAttr(d) for d in cast(Sequence[int], data)] else: @@ -910,12 +975,10 @@ def create_dense_int_or_index( @staticmethod def create_dense_float( - typ: AnyFloat, data: Sequence[int | float] | Sequence[FloatData] + typ: AnyFloat, data: Sequence[int | float] | Sequence[FloatData] ) -> DenseArrayBase: if len(data) and isinstance(data[0], int | float): - attr_list = [ - FloatData(float(d)) for d in cast(Sequence[int | float], data) - ] + attr_list = [FloatData(float(d)) for d in cast(Sequence[int | float], data)] else: attr_list = cast(Sequence[FloatData], data) @@ -923,21 +986,25 @@ def create_dense_float( @overload @staticmethod - def from_list(type: IntegerType | IndexType, data: Sequence[int] - | Sequence[IntAttr]) -> DenseArrayBase: + def from_list( + type: IntegerType | IndexType, data: Sequence[int] | Sequence[IntAttr] + ) -> DenseArrayBase: ... @overload @staticmethod def from_list( - type: Attribute, data: Sequence[int | float] - | Sequence[FloatData]) -> DenseArrayBase: + type: Attribute, data: Sequence[int | float] | Sequence[FloatData] + ) -> DenseArrayBase: ... @staticmethod def from_list( - type: Attribute, data: Sequence[int] | Sequence[int | float] - | Sequence[IntAttr] | Sequence[FloatData] + type: Attribute, + data: Sequence[int] + | Sequence[int | float] + | Sequence[IntAttr] + | Sequence[FloatData], ) -> DenseArrayBase: if isinstance(type, IndexType | IntegerType): _data = cast(Sequence[int] | Sequence[IntAttr], data) @@ -968,13 +1035,15 @@ class FunctionType(ParametrizedAttribute, TypeAttribute): outputs: ParameterDef[ArrayAttr[Attribute]] @staticmethod - def from_lists(inputs: Sequence[Attribute], - outputs: Sequence[Attribute]) -> FunctionType: + def from_lists( + inputs: Sequence[Attribute], outputs: Sequence[Attribute] + ) -> FunctionType: return FunctionType([ArrayAttr(inputs), ArrayAttr(outputs)]) @staticmethod - def from_attrs(inputs: ArrayAttr[Attribute], - outputs: ArrayAttr[Attribute]) -> FunctionType: + def from_attrs( + inputs: ArrayAttr[Attribute], outputs: ArrayAttr[Attribute] + ) -> FunctionType: return FunctionType([inputs, outputs]) @@ -987,8 +1056,7 @@ class OpaqueAttr(ParametrizedAttribute): type: ParameterDef[Attribute] @staticmethod - def from_strings(name: str, value: str, - type: Attribute = NoneAttr()) -> OpaqueAttr: + def from_strings(name: str, value: str, type: Attribute = NoneAttr()) -> OpaqueAttr: return OpaqueAttr([StringAttr(name), StringAttr(value), type]) @@ -1002,15 +1070,18 @@ class StridedLayoutAttr(ParametrizedAttribute): `NoneAttr`, and we do not restrict offsets and strides to 64-bits integers. """ + name: str = "strided" strides: ParameterDef[ArrayAttr[IntAttr | NoneAttr]] offset: ParameterDef[IntAttr | NoneAttr] - def __init__(self, - strides: ArrayAttr[IntAttr | NoneAttr] - | Sequence[int | None | IntAttr | NoneAttr], - offset: int | None | IntAttr | NoneAttr = 0) -> None: + def __init__( + self, + strides: ArrayAttr[IntAttr | NoneAttr] + | Sequence[int | None | IntAttr | NoneAttr], + offset: int | None | IntAttr | NoneAttr = 0, + ) -> None: if not isinstance(strides, ArrayAttr): strides_values: list[IntAttr | NoneAttr] = [] for stride in strides: @@ -1038,8 +1109,7 @@ class UnrealizedConversionCastOp(IRDLOperation): outputs: VarOpResult @staticmethod - def get(inputs: Sequence[SSAValue | Operation], - result_type: Sequence[Attribute]): + def get(inputs: Sequence[SSAValue | Operation], result_type: Sequence[Attribute]): return UnrealizedConversionCastOp.build( operands=[inputs], result_types=[result_type], @@ -1053,6 +1123,7 @@ class UnregisteredOp(IRDLOperation, ABC): Each unregistered op is registered as a subclass of `UnregisteredOp`, and op with different names have distinct subclasses. """ + name: str = "builtin.unregistered" op_name__: OpAttr[StringAttr] @@ -1073,17 +1144,19 @@ def with_name(cls, name: str) -> type[Operation]: """ class UnregisteredOpWithName(UnregisteredOp): - @classmethod - def create(cls, - operands: Sequence[SSAValue] | None = None, - result_types: Sequence[Attribute] | None = None, - attributes: dict[str, Attribute] | None = None, - successors: Sequence[Block] | None = None, - regions: Sequence[Region] | None = None): - op = super().create(operands, result_types, attributes, - successors, regions) - op.attributes['op_name__'] = StringAttr(name) + def create( + cls, + operands: Sequence[SSAValue] | None = None, + result_types: Sequence[Attribute] | None = None, + attributes: dict[str, Attribute] | None = None, + successors: Sequence[Block] | None = None, + regions: Sequence[Region] | None = None, + ): + op = super().create( + operands, result_types, attributes, successors, regions + ) + op.attributes["op_name__"] = StringAttr(name) return op return irdl_op_definition(UnregisteredOpWithName) @@ -1103,6 +1176,7 @@ class UnregisteredAttr(ParametrizedAttribute, ABC): which is exactly the content parsed from the textual representation. """ + name: str = "builtin.unregistered" attr_name: ParameterDef[StringAttr] @@ -1112,8 +1186,12 @@ class UnregisteredAttr(ParametrizedAttribute, ABC): This parameter is non-null is the attribute is a type, and null otherwise. """ - def __init__(self, attr_name: str | StringAttr, is_type: bool | IntAttr, - value: str | StringAttr): + def __init__( + self, + attr_name: str | StringAttr, + is_type: bool | IntAttr, + value: str | StringAttr, + ): if isinstance(attr_name, str): attr_name = StringAttr(attr_name) if isinstance(is_type, bool): @@ -1123,8 +1201,7 @@ def __init__(self, attr_name: str | StringAttr, is_type: bool | IntAttr, super().__init__([attr_name, is_type, value]) @classmethod - def with_name_and_type(cls, name: str, - is_type: bool) -> type[UnregisteredAttr]: + def with_name_and_type(cls, name: str, is_type: bool) -> type[UnregisteredAttr]: """ Return a new unregistered attribute type given a name and a boolean indicating if the attribute can be a type. @@ -1133,24 +1210,18 @@ def with_name_and_type(cls, name: str, """ class UnregisteredAttrWithName(UnregisteredAttr): - def verify(self): if self.attr_name.data != name: - raise VerifyException( - "Unregistered attribute name mismatch") + raise VerifyException("Unregistered attribute name mismatch") if self.is_type.data != int(is_type): - raise VerifyException( - "Unregistered attribute is_type mismatch") + raise VerifyException("Unregistered attribute is_type mismatch") class UnregisteredAttrTypeWithName(UnregisteredAttr, TypeAttribute): - def verify(self): if self.attr_name.data != name: - raise VerifyException( - "Unregistered attribute name mismatch") + raise VerifyException("Unregistered attribute name mismatch") if self.is_type.data != int(is_type): - raise VerifyException( - "Unregistered attribute is_type mismatch") + raise VerifyException("Unregistered attribute is_type mismatch") if is_type: return UnregisteredAttrWithName @@ -1192,7 +1263,6 @@ def ops(self) -> List[Operation]: ], [ UnregisteredAttr, - # Attributes StringAttr, SymbolRefAttr, @@ -1207,7 +1277,6 @@ def ops(self) -> List[Operation]: FloatData, NoneAttr, OpaqueAttr, - # Types ComplexType, FunctionType, @@ -1224,5 +1293,6 @@ def ops(self) -> List[Operation]: IndexType, VectorType, TensorType, - UnrankedTensorType - ]) + UnrankedTensorType, + ], +) diff --git a/xdsl/dialects/cf.py b/xdsl/dialects/cf.py index 649b7ff0d3..303e786bf5 100644 --- a/xdsl/dialects/cf.py +++ b/xdsl/dialects/cf.py @@ -4,8 +4,14 @@ from xdsl.dialects.builtin import IntegerType from xdsl.ir import SSAValue, Operation, Block, Dialect -from xdsl.irdl import (irdl_op_definition, VarOperand, AnyAttr, Operand, - AttrSizedOperandSegments, IRDLOperation) +from xdsl.irdl import ( + irdl_op_definition, + VarOperand, + AnyAttr, + Operand, + AttrSizedOperandSegments, + IRDLOperation, +) @irdl_op_definition @@ -30,11 +36,16 @@ class ConditionalBranch(IRDLOperation): irdl_options = [AttrSizedOperandSegments()] @staticmethod - def get(cond: Union[Operation, SSAValue], then_block: Block, - then_ops: List[Union[Operation, SSAValue]], else_block: Block, - else_ops: List[Union[Operation, SSAValue]]) -> ConditionalBranch: - return ConditionalBranch.build(operands=[cond, then_ops, else_ops], - successors=[then_block, else_block]) + def get( + cond: Union[Operation, SSAValue], + then_block: Block, + then_ops: List[Union[Operation, SSAValue]], + else_block: Block, + else_ops: List[Union[Operation, SSAValue]], + ) -> ConditionalBranch: + return ConditionalBranch.build( + operands=[cond, then_ops, else_ops], successors=[then_block, else_block] + ) Cf = Dialect([Branch, ConditionalBranch], []) diff --git a/xdsl/dialects/cmath.py b/xdsl/dialects/cmath.py index 820918662f..efcb59fd44 100644 --- a/xdsl/dialects/cmath.py +++ b/xdsl/dialects/cmath.py @@ -2,11 +2,24 @@ from typing import Annotated, Union from xdsl.dialects.builtin import Float32Type, Float64Type -from xdsl.ir import (TypeAttribute, ParametrizedAttribute, Operation, Dialect, - OpResult, SSAValue) -from xdsl.irdl import (irdl_op_definition, irdl_attr_definition, Operand, - ParameterDef, ParamAttrConstraint, AnyOf, - VerifyException, IRDLOperation) +from xdsl.ir import ( + TypeAttribute, + ParametrizedAttribute, + Operation, + Dialect, + OpResult, + SSAValue, +) +from xdsl.irdl import ( + irdl_op_definition, + irdl_attr_definition, + Operand, + ParameterDef, + ParamAttrConstraint, + AnyOf, + VerifyException, + IRDLOperation, +) @irdl_attr_definition @@ -20,8 +33,8 @@ class Norm(IRDLOperation): name: str = "cmath.norm" op: Annotated[ - Operand, - ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])])] + Operand, ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])]) + ] res: Annotated[OpResult, AnyOf([Float32Type, Float64Type])] # TODO replace with trait @@ -29,8 +42,7 @@ def verify_(self) -> None: if not isinstance(self.op.typ, ComplexType): raise VerifyException("Expected complex type") if self.op.typ.data != self.res.typ: - raise VerifyException( - "expect all input and output types to be equal") + raise VerifyException("expect all input and output types to be equal") @irdl_op_definition @@ -38,27 +50,26 @@ class Mul(IRDLOperation): name: str = "cmath.mul" lhs: Annotated[ - Operand, - ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])])] + Operand, ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])]) + ] rhs: Annotated[ - Operand, - ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])])] + Operand, ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])]) + ] result: Annotated[ - OpResult, - ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])])] + OpResult, ParamAttrConstraint(ComplexType, [AnyOf([Float32Type, Float64Type])]) + ] # TODO replace with trait def verify_(self) -> None: if self.lhs.typ != self.rhs.typ and self.rhs.typ != self.result.typ: - raise VerifyException( - "expect all input and output types to be equal") + raise VerifyException("expect all input and output types to be equal") @staticmethod - def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> Mul: + def get( + operand1: Union[Operation, SSAValue], operand2: Union[Operation, SSAValue] + ) -> Mul: operand1 = SSAValue.get(operand1) - return Mul.build(operands=[operand1, operand2], - result_types=[operand1.typ]) + return Mul.build(operands=[operand1, operand2], result_types=[operand1.typ]) CMath = Dialect([Norm, Mul], [ComplexType]) diff --git a/xdsl/dialects/experimental/math.py b/xdsl/dialects/experimental/math.py index b565bec96f..ba78dba391 100644 --- a/xdsl/dialects/experimental/math.py +++ b/xdsl/dialects/experimental/math.py @@ -20,18 +20,22 @@ class AbsFOp(IRDLOperation): // Scalar absolute value. %a = math.absf %b : f64 """ + name: str = "math.absf" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> AbsFOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> AbsFOp: operand = SSAValue.get(operand) - return AbsFOp.build(attributes={"fastmath": fastmath}, - operands=[operand], - result_types=[operand.typ]) + return AbsFOp.build( + attributes={"fastmath": fastmath}, + operands=[operand], + result_types=[operand.typ], + ) @irdl_op_definition @@ -46,6 +50,7 @@ class AbsIOp(IRDLOperation): // Scalar absolute value. %a = math.absi %b : i64 """ + name: str = "math.absi" operand: Annotated[Operand, IntegerType] result: Annotated[OpResult, IntegerType] @@ -78,6 +83,7 @@ class Atan2Op(IRDLOperation): // Scalar variant. %a = math.atan2 %b, %c : f32 """ + name: str = "math.atan2" fastmath: OptOpAttr[FastMathFlagsAttr] lhs: Annotated[Operand, AnyFloat] @@ -85,16 +91,18 @@ class Atan2Op(IRDLOperation): result: Annotated[OpResult, AnyFloat] @staticmethod - def get(lhs: Union[Operation, SSAValue], - rhs: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> Atan2Op: + def get( + lhs: Union[Operation, SSAValue], + rhs: Union[Operation, SSAValue], + fastmath: FastMathFlagsAttr | None = None, + ) -> Atan2Op: attributes = {"fastmath": fastmath} lhs = SSAValue.get(lhs) rhs = SSAValue.get(rhs) - return Atan2Op.build(attributes=attributes, - operands=[lhs, rhs], - result_types=[lhs.typ]) + return Atan2Op.build( + attributes=attributes, operands=[lhs, rhs], result_types=[lhs.typ] + ) @irdl_op_definition @@ -112,18 +120,22 @@ class AtanOp(IRDLOperation): // Arcus tangent of scalar value. %a = math.atan %b : f64 """ + name: str = "math.atan" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> AtanOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> AtanOp: operand = SSAValue.get(operand) - return AtanOp.build(attributes={"fastmath": fastmath}, - operands=[operand], - result_types=[operand.typ]) + return AtanOp.build( + attributes={"fastmath": fastmath}, + operands=[operand], + result_types=[operand.typ], + ) @irdl_op_definition @@ -140,20 +152,22 @@ class CbrtOp(IRDLOperation): Note: This op is not equivalent to powf(..., 1/3.0). """ + name: str = "math.cbrt" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> CbrtOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> CbrtOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return CbrtOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return CbrtOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -171,18 +185,22 @@ class CeilOp(IRDLOperation): // Scalar ceiling value. %a = math.ceil %b : f64 """ + name: str = "math.ceil" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> CeilOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> CeilOp: operand = SSAValue.get(operand) - return CeilOp.build(attributes={"fastmath": fastmath}, - operands=[operand], - result_types=[operand.typ]) + return CeilOp.build( + attributes={"fastmath": fastmath}, + operands=[operand], + result_types=[operand.typ], + ) @irdl_op_definition @@ -201,6 +219,7 @@ class CopySignOp(IRDLOperation): // Scalar copysign value. %a = math.copysign %b, %c : f64 """ + name: str = "math.copysign" fastmath: OptOpAttr[FastMathFlagsAttr] lhs: Annotated[Operand, AnyFloat] @@ -208,16 +227,18 @@ class CopySignOp(IRDLOperation): result: Annotated[OpResult, AnyFloat] @staticmethod - def get(lhs: Union[Operation, SSAValue], - rhs: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> CopySignOp: + def get( + lhs: Union[Operation, SSAValue], + rhs: Union[Operation, SSAValue], + fastmath: FastMathFlagsAttr | None = None, + ) -> CopySignOp: attributes = {"fastmath": fastmath} lhs = SSAValue.get(lhs) rhs = SSAValue.get(rhs) - return CopySignOp.build(attributes=attributes, - operands=[lhs, rhs], - result_types=[lhs.typ]) + return CopySignOp.build( + attributes=attributes, operands=[lhs, rhs], result_types=[lhs.typ] + ) @irdl_op_definition @@ -235,20 +256,22 @@ class CosOp(IRDLOperation): // Scalar cosine value. %a = math.cos %b : f64 """ + name: str = "math.cos" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> CosOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> CosOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return CosOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return CosOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -262,6 +285,7 @@ class CountLeadingZerosOp(IRDLOperation): // Scalar ctlz function value. %a = math.ctlz %b : i32 """ + name: str = "math.ctlz" operand: Annotated[Operand, IntegerType] result: Annotated[OpResult, IntegerType] @@ -269,8 +293,7 @@ class CountLeadingZerosOp(IRDLOperation): @staticmethod def get(operand: Union[Operation, SSAValue]) -> CountLeadingZerosOp: operand = SSAValue.get(operand) - return CountLeadingZerosOp.build(operands=[operand], - result_types=[operand.typ]) + return CountLeadingZerosOp.build(operands=[operand], result_types=[operand.typ]) @irdl_op_definition @@ -284,6 +307,7 @@ class CountTrailingZerosOp(IRDLOperation): // Scalar cttz function value. %a = math.cttz %b : i32 """ + name: str = "math.cttz" operand: Annotated[Operand, IntegerType] result: Annotated[OpResult, IntegerType] @@ -291,8 +315,9 @@ class CountTrailingZerosOp(IRDLOperation): @staticmethod def get(operand: Union[Operation, SSAValue]) -> CountTrailingZerosOp: operand = SSAValue.get(operand) - return CountTrailingZerosOp.build(operands=[operand], - result_types=[operand.typ]) + return CountTrailingZerosOp.build( + operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -306,6 +331,7 @@ class CtPopOp(IRDLOperation): // Scalar ctpop function value. %a = math.ctpop %b : i32 """ + name: str = "math.ctpop" operand: Annotated[Operand, IntegerType] result: Annotated[OpResult, IntegerType] @@ -331,20 +357,22 @@ class ErfOp(IRDLOperation): // Scalar error function value. %a = math.erf %b : f64 """ + name: str = "math.erf" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> ErfOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> ErfOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return ErfOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return ErfOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -362,20 +390,22 @@ class Exp2Op(IRDLOperation): // Scalar natural exponential. %a = math.exp2 %b : f64 """ + name: str = "math.exp2" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> Exp2Op: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> Exp2Op: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return Exp2Op.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return Exp2Op.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -395,20 +425,22 @@ class ExpM1Op(IRDLOperation): // Scalar natural exponential minus 1. %a = math.expm1 %b : f64 """ + name: str = "math.expm1" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> ExpM1Op: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> ExpM1Op: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return ExpM1Op.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return ExpM1Op.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -426,20 +458,22 @@ class ExpOp(IRDLOperation): // Scalar natural exponential. %a = math.exp %b : f64 """ + name: str = "math.exp" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> ExpOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> ExpOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return ExpOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return ExpOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -465,6 +499,7 @@ class FPowIOp(IRDLOperation): // Scalar exponentiation. %a = math.fpowi %base, %power : f64, i32 """ + name: str = "math.fpowi" fastmath: OptOpAttr[FastMathFlagsAttr] lhs: Annotated[Operand, AnyFloat] @@ -472,16 +507,18 @@ class FPowIOp(IRDLOperation): result: Annotated[OpResult, AnyFloat] @staticmethod - def get(lhs: Union[Operation, SSAValue], - rhs: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> FPowIOp: + def get( + lhs: Union[Operation, SSAValue], + rhs: Union[Operation, SSAValue], + fastmath: FastMathFlagsAttr | None = None, + ) -> FPowIOp: attributes = {"fastmath": fastmath} lhs = SSAValue.get(lhs) rhs = SSAValue.get(rhs) - return FPowIOp.build(attributes=attributes, - operands=[lhs, rhs], - result_types=[lhs.typ]) + return FPowIOp.build( + attributes=attributes, operands=[lhs, rhs], result_types=[lhs.typ] + ) @irdl_op_definition @@ -499,20 +536,22 @@ class FloorOp(IRDLOperation): // Scalar floor value. %a = math.floor %b : f64 """ + name: str = "math.floor" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> FloorOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> FloorOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return FloorOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return FloorOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -535,6 +574,7 @@ class FmaOp(IRDLOperation): particular case of lowering to LLVM, this is guaranteed to lower to the `llvm.fma.*` intrinsic. """ + name: str = "math.fma" fastmath: OptOpAttr[FastMathFlagsAttr] a: Annotated[Operand, AnyFloat] @@ -543,18 +583,20 @@ class FmaOp(IRDLOperation): result: Annotated[OpResult, AnyFloat] @staticmethod - def get(a: Union[Operation, SSAValue], - b: Union[Operation, SSAValue], - c: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> FmaOp: + def get( + a: Union[Operation, SSAValue], + b: Union[Operation, SSAValue], + c: Union[Operation, SSAValue], + fastmath: FastMathFlagsAttr | None = None, + ) -> FmaOp: attributes = {"fastmath": fastmath} a = SSAValue.get(a) b = SSAValue.get(b) c = SSAValue.get(c) - return FmaOp.build(attributes=attributes, - operands=[a, b, c], - result_types=[a.typ]) + return FmaOp.build( + attributes=attributes, operands=[a, b, c], result_types=[a.typ] + ) @irdl_op_definition @@ -571,14 +613,16 @@ class IPowIOp(IRDLOperation): // Scalar signed integer exponentiation. %a = math.ipowi %b, %c : i32 """ + name: str = "math.ipowi" lhs: Annotated[Operand, IntegerType] rhs: Annotated[Operand, IntegerType] result: Annotated[OpResult, IntegerType] @staticmethod - def get(lhs: Union[Operation, SSAValue], rhs: Union[Operation, - SSAValue]) -> IPowIOp: + def get( + lhs: Union[Operation, SSAValue], rhs: Union[Operation, SSAValue] + ) -> IPowIOp: lhs = SSAValue.get(lhs) rhs = SSAValue.get(rhs) return IPowIOp.build(operands=[lhs, rhs], result_types=[lhs.typ]) @@ -596,20 +640,22 @@ class Log10Op(IRDLOperation): // Scalar log10 operation. %y = math.log10 %x : f64 """ + name: str = "math.log10" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> Log10Op: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> Log10Op: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return Log10Op.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return Log10Op.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -626,20 +672,22 @@ class Log1pOp(IRDLOperation): // Scalar log1p operation. %y = math.log1p %x : f64 """ + name: str = "math.log1p" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> Log1pOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> Log1pOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return Log1pOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return Log1pOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -654,20 +702,22 @@ class Log2Op(IRDLOperation): // Scalar log2 operation. %y = math.log2 %x : f64 """ + name: str = "math.log2" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> Log2Op: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> Log2Op: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return Log2Op.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return Log2Op.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -682,20 +732,22 @@ class LogOp(IRDLOperation): // Scalar log operation. %y = math.log %x : f64 """ + name: str = "math.log" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> LogOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> LogOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return LogOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return LogOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -713,6 +765,7 @@ class PowFOp(IRDLOperation): // Scalar exponentiation. %a = math.powf %b, %c : f64 """ + name: str = "math.powf" fastmath: OptOpAttr[FastMathFlagsAttr] lhs: Annotated[Operand, AnyFloat] @@ -720,16 +773,18 @@ class PowFOp(IRDLOperation): result: Annotated[OpResult, AnyFloat] @staticmethod - def get(lhs: Union[Operation, SSAValue], - rhs: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> PowFOp: + def get( + lhs: Union[Operation, SSAValue], + rhs: Union[Operation, SSAValue], + fastmath: FastMathFlagsAttr | None = None, + ) -> PowFOp: attributes = {"fastmath": fastmath} lhs = SSAValue.get(lhs) rhs = SSAValue.get(rhs) - return PowFOp.build(attributes=attributes, - operands=[lhs, rhs], - result_types=[lhs.typ]) + return PowFOp.build( + attributes=attributes, operands=[lhs, rhs], result_types=[lhs.typ] + ) @irdl_op_definition @@ -750,20 +805,22 @@ class RoundEvenOp(IRDLOperation): // Scalar round operation. %a = math.roundeven %b : f64 """ + name: str = "math.roundeven" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> RoundEvenOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> RoundEvenOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return RoundEvenOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return RoundEvenOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -784,20 +841,22 @@ class RoundOp(IRDLOperation): // Scalar round operation. %a = math.round %b : f64 """ + name: str = "math.round" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> RoundOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> RoundOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return RoundOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return RoundOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -811,20 +870,22 @@ class RsqrtOp(IRDLOperation): // Scalar reciprocal square root value. %a = math.rsqrt %b : f64 """ + name: str = "math.rsqrt" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> RsqrtOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> RsqrtOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return RsqrtOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return RsqrtOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -842,20 +903,22 @@ class SinOp(IRDLOperation): // Scalar sine value. %a = math.sin %b : f64 """ + name: str = "math.sin" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> SinOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> SinOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return SinOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return SinOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -869,20 +932,22 @@ class SqrtOp(IRDLOperation): // Scalar square root value. %a = math.sqrt %b : f64 """ + name: str = "math.sqrt" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> SqrtOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> SqrtOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return SqrtOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return SqrtOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -897,20 +962,22 @@ class TanOp(IRDLOperation): // Scalar tangent value. %a = math.tan %b : f64 """ + name: str = "math.tan" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> TanOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> TanOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return TanOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return TanOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -925,20 +992,22 @@ class TanhOp(IRDLOperation): // Scalar hyperbolic tangent value. %a = math.tanh %b : f64 """ + name: str = "math.tanh" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> TanhOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> TanhOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return TanhOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) + return TanhOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) @irdl_op_definition @@ -958,53 +1027,57 @@ class TruncOp(IRDLOperation): // Scalar trunc operation. %a = math.trunc %b : f64 """ + name: str = "math.trunc" fastmath: OptOpAttr[FastMathFlagsAttr] operand: Annotated[Operand, AnyFloat] result: Annotated[OpResult, AnyFloat] @staticmethod - def get(operand: Union[Operation, SSAValue], - fastmath: FastMathFlagsAttr | None = None) -> TruncOp: + def get( + operand: Union[Operation, SSAValue], fastmath: FastMathFlagsAttr | None = None + ) -> TruncOp: attributes = {"fastmath": fastmath} operand = SSAValue.get(operand) - return TruncOp.build(attributes=attributes, - operands=[operand], - result_types=[operand.typ]) - - -Math = Dialect([ - AbsFOp, - AbsIOp, - Atan2Op, - AtanOp, - CbrtOp, - CeilOp, - CopySignOp, - CosOp, - CountLeadingZerosOp, - CountTrailingZerosOp, - CtPopOp, - ErfOp, - Exp2Op, - ExpM1Op, - ExpOp, - FPowIOp, - FloorOp, - FmaOp, - IPowIOp, - Log10Op, - Log1pOp, - Log2Op, - LogOp, - PowFOp, - RoundEvenOp, - RoundOp, - RsqrtOp, - SinOp, - SqrtOp, - TanOp, - TanhOp, - TruncOp, -]) + return TruncOp.build( + attributes=attributes, operands=[operand], result_types=[operand.typ] + ) + + +Math = Dialect( + [ + AbsFOp, + AbsIOp, + Atan2Op, + AtanOp, + CbrtOp, + CeilOp, + CopySignOp, + CosOp, + CountLeadingZerosOp, + CountTrailingZerosOp, + CtPopOp, + ErfOp, + Exp2Op, + ExpM1Op, + ExpOp, + FPowIOp, + FloorOp, + FmaOp, + IPowIOp, + Log10Op, + Log1pOp, + Log2Op, + LogOp, + PowFOp, + RoundEvenOp, + RoundOp, + RsqrtOp, + SinOp, + SqrtOp, + TanOp, + TanhOp, + TruncOp, + ] +) diff --git a/xdsl/dialects/experimental/stencil.py b/xdsl/dialects/experimental/stencil.py index a2e840a6d3..8aa5e8cb6e 100644 --- a/xdsl/dialects/experimental/stencil.py +++ b/xdsl/dialects/experimental/stencil.py @@ -5,17 +5,41 @@ from xdsl.dialects import builtin from xdsl.dialects import memref -from xdsl.dialects.builtin import (AnyIntegerAttr, IntegerAttr, - ParametrizedAttribute, ArrayAttr, f32, f64, - IntegerType, IntAttr, AnyFloat) +from xdsl.dialects.builtin import ( + AnyIntegerAttr, + IntegerAttr, + ParametrizedAttribute, + ArrayAttr, + f32, + f64, + IntegerType, + IntAttr, + AnyFloat, +) from xdsl.ir import Operation, Dialect, TypeAttribute from xdsl.ir import SSAValue -from xdsl.irdl import (irdl_attr_definition, irdl_op_definition, ParameterDef, - AttrConstraint, Attribute, Region, VerifyException, - Generic, AnyOf, Annotated, Operand, OpAttr, OpResult, - VarOperand, VarOpResult, OptOpAttr, - AttrSizedOperandSegments, Block, IRDLOperation) +from xdsl.irdl import ( + irdl_attr_definition, + irdl_op_definition, + ParameterDef, + AttrConstraint, + Attribute, + Region, + VerifyException, + Generic, + AnyOf, + Annotated, + Operand, + OpAttr, + OpResult, + VarOperand, + VarOpResult, + OptOpAttr, + AttrSizedOperandSegments, + Block, + IRDLOperation, +) from xdsl.utils.hints import isa @@ -26,7 +50,8 @@ class IntOrUnknown(AttrConstraint): def verify(self, attr: Attribute) -> None: if not isinstance(attr, ArrayAttr): raise VerifyException( - f"Expected {ArrayAttr} attribute, but got {attr.name}.") + f"Expected {ArrayAttr} attribute, but got {attr.name}." + ) attr = cast(ArrayAttr[Any], attr) if len(attr.data) != self.length: @@ -39,17 +64,17 @@ def verify(self, attr: Attribute) -> None: @irdl_attr_definition -class FieldType(Generic[_FieldTypeElement], ParametrizedAttribute, - TypeAttribute): +class FieldType(Generic[_FieldTypeElement], ParametrizedAttribute, TypeAttribute): name = "stencil.field" shape: ParameterDef[ArrayAttr[AnyIntegerAttr]] element_type: ParameterDef[_FieldTypeElement] @staticmethod - def from_shape(shape: ArrayAttr[AnyIntegerAttr] | Sequence[AnyIntegerAttr] - | Sequence[int], - typ: _FieldTypeElement) -> FieldType[_FieldTypeElement]: + def from_shape( + shape: ArrayAttr[AnyIntegerAttr] | Sequence[AnyIntegerAttr] | Sequence[int], + typ: _FieldTypeElement, + ) -> FieldType[_FieldTypeElement]: assert len(shape) > 0 if isinstance(shape, ArrayAttr): @@ -63,21 +88,22 @@ def from_shape(shape: ArrayAttr[AnyIntegerAttr] | Sequence[AnyIntegerAttr] return FieldType([ArrayAttr(shape), typ]) # type: ignore shape = cast(list[int], shape) return FieldType( - [ArrayAttr([IntegerAttr[IntegerType](d, 64) for d in shape]), typ]) + [ArrayAttr([IntegerAttr[IntegerType](d, 64) for d in shape]), typ] + ) @irdl_attr_definition -class TempType(Generic[_FieldTypeElement], ParametrizedAttribute, - TypeAttribute): +class TempType(Generic[_FieldTypeElement], ParametrizedAttribute, TypeAttribute): name = "stencil.temp" shape: ParameterDef[ArrayAttr[AnyIntegerAttr]] element_type: ParameterDef[_FieldTypeElement] @staticmethod - def from_shape(shape: ArrayAttr[AnyIntegerAttr] | Sequence[AnyIntegerAttr] - | Sequence[int], - typ: _FieldTypeElement) -> TempType[_FieldTypeElement]: + def from_shape( + shape: ArrayAttr[AnyIntegerAttr] | Sequence[AnyIntegerAttr] | Sequence[int], + typ: _FieldTypeElement, + ) -> TempType[_FieldTypeElement]: assert len(shape) > 0 if isinstance(shape, ArrayAttr): @@ -91,7 +117,8 @@ def from_shape(shape: ArrayAttr[AnyIntegerAttr] | Sequence[AnyIntegerAttr] return TempType([ArrayAttr(shape), typ]) # type: ignore shape = cast(list[int], shape) return TempType( - [ArrayAttr([IntegerAttr[IntegerType](d, 64) for d in shape]), typ]) + [ArrayAttr([IntegerAttr[IntegerType](d, 64) for d in shape]), typ] + ) def __repr__(self): repr: str = "stencil.Temp<[" @@ -118,7 +145,8 @@ class ArrayLength(AttrConstraint): def verify(self, attr: Attribute) -> None: if not isinstance(attr, ArrayAttr): raise VerifyException( - f"Expected {ArrayAttr} attribute, but got {attr.name}.") + f"Expected {ArrayAttr} attribute, but got {attr.name}." + ) attr = cast(ArrayAttr[Any], attr) if len(attr.data) != self.length: raise VerifyException( @@ -148,10 +176,20 @@ def verify(self) -> None: @staticmethod def get(*indices: int | IntegerAttr[IntegerType]): - return IndexAttr([ - ArrayAttr([(IntegerAttr[IntegerType](idx, 64) if isinstance( - idx, int) else idx) for idx in indices]) - ]) + return IndexAttr( + [ + ArrayAttr( + [ + ( + IntegerAttr[IntegerType](idx, 64) + if isinstance(idx, int) + else idx + ) + for idx in indices + ] + ) + ] + ) @staticmethod def size_from_bounds(lb: IndexAttr, ub: IndexAttr) -> list[int]: @@ -164,8 +202,7 @@ def size_from_bounds(lb: IndexAttr, ub: IndexAttr) -> list[int]: # on Attributes? Author's opinion is a clear yes :P def __neg__(self) -> IndexAttr: integer_attrs: list[Attribute] = [ - IntegerAttr(-e.value.data, IntegerType(64)) - for e in self.array.data + IntegerAttr(-e.value.data, IntegerType(64)) for e in self.array.data ] return IndexAttr([ArrayAttr(integer_attrs)]) @@ -218,6 +255,7 @@ class CastOp(IRDLOperation): Example: %0 = stencil.cast %in ([-3, -3, 0] : [67, 67, 60]) : (!stencil.field) -> !stencil.field<70x70x60xf64> # noqa """ + name: str = "stencil.cast" field: Annotated[Operand, FieldType] lb: OpAttr[IndexAttr] @@ -225,14 +263,15 @@ class CastOp(IRDLOperation): result: Annotated[OpResult, FieldType] @staticmethod - def get(field: SSAValue | Operation, lb: IndexAttr, ub: IndexAttr, - res_type: FieldType[_FieldTypeElement]) -> CastOp: + def get( + field: SSAValue | Operation, + lb: IndexAttr, + ub: IndexAttr, + res_type: FieldType[_FieldTypeElement], + ) -> CastOp: return CastOp.build( operands=[field], - attributes={ - "lb": lb, - "ub": ub - }, + attributes={"lb": lb, "ub": ub}, result_types=[res_type], ) @@ -246,13 +285,16 @@ class ExternalLoadOp(IRDLOperation): Example: %0 = stencil.external_load %in : (!fir.array<128x128xf64>) -> !stencil.field<128x128xf64> # noqa """ + name: str = "stencil.external_load" field: Annotated[Operand, Attribute] result: Annotated[OpResult, FieldType | memref.MemRefType] @staticmethod - def get(arg: SSAValue | Operation, - res_type: FieldType[Attribute] | memref.MemRefType[Attribute]): + def get( + arg: SSAValue | Operation, + res_type: FieldType[Attribute] | memref.MemRefType[Attribute], + ): return ExternalLoadOp.build(operands=[arg], result_types=[res_type]) @@ -264,6 +306,7 @@ class ExternalStoreOp(IRDLOperation): Example: stencil.store %temp to %field : !stencil.field<128x128xf64> to !fir.array<128x128xf64> # noqa """ + name: str = "stencil.external_store" temp: Annotated[Operand, FieldType] field: Annotated[Operand, Attribute] @@ -279,6 +322,7 @@ class IndexOp(IRDLOperation): Example: %0 = stencil.index 0 [-1, 0, 0] : index """ + name: str = "stencil.index" dim: OpAttr[IntegerType] offset: OpAttr[IndexAttr] @@ -294,6 +338,7 @@ class AccessOp(IRDLOperation): Example: %0 = stencil.access %temp [-1, 0, 0] : !stencil.temp -> f64 """ + name: str = "stencil.access" temp: Annotated[Operand, TempType] offset: OpAttr[IndexAttr] @@ -308,11 +353,13 @@ def get(temp: SSAValue | Operation, offset: Sequence[int]): return AccessOp.build( operands=[temp], attributes={ - 'offset': - IndexAttr([ - ArrayAttr(IntegerAttr[IntegerType](value, 64) - for value in offset), - ]), + "offset": IndexAttr( + [ + ArrayAttr( + IntegerAttr[IntegerType](value, 64) for value in offset + ), + ] + ), }, result_types=[temp_type.element_type], ) @@ -329,6 +376,7 @@ class DynAccessOp(IRDLOperation): Example: %0 = stencil.dyn_access %temp (%i, %j, %k) in [-1, -1, -1] : [1, 1, 1] : !stencil.temp -> f64 """ + name: str = "stencil.dyn_access" temp: Annotated[Operand, TempType] offset: OpAttr[IndexAttr] @@ -345,6 +393,7 @@ class LoadOp(IRDLOperation): Example: %0 = stencil.load %field : (!stencil.field<70x70x60xf64>) -> !stencil.temp """ + name: str = "stencil.load" field: Annotated[Operand, FieldType] lb: OptOpAttr[IndexAttr] @@ -360,9 +409,11 @@ def get(field: SSAValue | Operation): return LoadOp.build( operands=[field], result_types=[ - TempType[Attribute].from_shape([-1] * len(field_t.shape.data), - field_t.element_type) - ]) + TempType[Attribute].from_shape( + [-1] * len(field_t.shape.data), field_t.element_type + ) + ], + ) @irdl_op_definition @@ -373,6 +424,7 @@ class BufferOp(IRDLOperation): Example: %0 = stencil.buffer %buffered : (!stencil.temp) -> !stencil.temp """ + name: str = "stencil.buffer" temp: Annotated[Operand, TempType] lb: OpAttr[IndexAttr] @@ -388,6 +440,7 @@ class StoreOp(IRDLOperation): Example: stencil.store %temp to %field ([0,0,0] : [64,64,60]) : !stencil.temp to !stencil.field<70x70x60xf64> """ + name: str = "stencil.store" temp: Annotated[Operand, TempType] field: Annotated[Operand, FieldType] @@ -395,13 +448,13 @@ class StoreOp(IRDLOperation): ub: OpAttr[IndexAttr] @staticmethod - def get(temp: SSAValue | Operation, field: SSAValue | Operation, - lb: IndexAttr, ub: IndexAttr): - return StoreOp.build(operands=[temp, field], - attributes={ - 'lb': lb, - 'ub': ub - }) + def get( + temp: SSAValue | Operation, + field: SSAValue | Operation, + lb: IndexAttr, + ub: IndexAttr, + ): + return StoreOp.build(operands=[temp, field], attributes={"lb": lb, "ub": ub}) @irdl_op_definition @@ -416,6 +469,7 @@ class ApplyOp(IRDLOperation): ... } """ + name: str = "stencil.apply" args: Annotated[VarOperand, Attribute] lb: OptOpAttr[IndexAttr] @@ -424,11 +478,13 @@ class ApplyOp(IRDLOperation): res: Annotated[VarOpResult, TempType] @staticmethod - def get(args: Sequence[SSAValue] | Sequence[Operation], - body: Block, - lb: IndexAttr | None = None, - ub: IndexAttr | None = None, - result_count: int = 1): + def get( + args: Sequence[SSAValue] | Sequence[Operation], + body: Block, + lb: IndexAttr | None = None, + ub: IndexAttr | None = None, + result_count: int = 1, + ): assert len(args) > 0 field_t = SSAValue.get(args[0]).typ assert isinstance(field_t, TempType) @@ -442,14 +498,17 @@ def get(args: Sequence[SSAValue] | Sequence[Operation], if ub is not None: attributes["ub"] = ub - return ApplyOp.build(operands=[list(args)], - attributes=attributes, - regions=[Region(body)], - result_types=[[ - TempType.from_shape([-1] * result_rank, - field_t.element_type) - for _ in range(result_count) - ]]) + return ApplyOp.build( + operands=[list(args)], + attributes=attributes, + regions=[Region(body)], + result_types=[ + [ + TempType.from_shape([-1] * result_rank, field_t.element_type) + for _ in range(result_count) + ] + ], + ) @irdl_op_definition @@ -461,6 +520,7 @@ class StoreResultOp(IRDLOperation): stencil.store_result %0 : !stencil.result stencil.store_result : !stencil.result """ + name: str = "stencil.store_result" args: Annotated[VarOperand, Attribute] res: Annotated[OpResult, ResultType] @@ -480,6 +540,7 @@ class ReturnOp(IRDLOperation): Examples: stencil.return %0 : !stencil.result """ + name: str = "stencil.return" arg: Annotated[VarOperand, ResultType | AnyFloat] @@ -500,10 +561,11 @@ class CombineOp(IRDLOperation): Example: %result = stencil.combine 2 at 11 lower = (%0 : !stencil.temp) upper = (%1 : !stencil.temp) lowerext = (%2 : !stencil.temp): !stencil.temp, !stencil.temp """ + name: str = "stencil.combine" dim: Annotated[ - Operand, - IntegerType] # TODO: how to use the ArrayLength constraint here? 0 <= dim <= 2 + Operand, IntegerType + ] # TODO: how to use the ArrayLength constraint here? 0 <= dim <= 2 index: Annotated[Operand, IntegerType] lower: Annotated[VarOperand, TempType] @@ -536,26 +598,29 @@ def get(input_stencil: SSAValue | Operation): return HaloSwapOp.build(operands=[input_stencil]) -Stencil = Dialect([ - CastOp, - ExternalLoadOp, - ExternalStoreOp, - IndexOp, - AccessOp, - DynAccessOp, - LoadOp, - BufferOp, - StoreOp, - ApplyOp, - StoreResultOp, - ReturnOp, - CombineOp, - HaloSwapOp, -], [ - FieldType, - TempType, - ResultType, - ElementType, - IndexAttr, - LoopAttr, -]) +Stencil = Dialect( + [ + CastOp, + ExternalLoadOp, + ExternalStoreOp, + IndexOp, + AccessOp, + DynAccessOp, + LoadOp, + BufferOp, + StoreOp, + ApplyOp, + StoreResultOp, + ReturnOp, + CombineOp, + HaloSwapOp, + ], + [ + FieldType, + TempType, + ResultType, + ElementType, + IndexAttr, + LoopAttr, + ], +) diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index 218be05a11..1dff0d6487 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -2,9 +2,24 @@ from typing import Annotated, List, Union from xdsl.dialects.builtin import StringAttr, FunctionType, SymbolRefAttr -from xdsl.ir import SSAValue, Operation, Block, Region, Attribute, Dialect, BlockArgument -from xdsl.irdl import (VarOpResult, irdl_op_definition, VarOperand, AnyAttr, - OpAttr, OptOpAttr, IRDLOperation) +from xdsl.ir import ( + SSAValue, + Operation, + Block, + Region, + Attribute, + Dialect, + BlockArgument, +) +from xdsl.irdl import ( + VarOpResult, + irdl_op_definition, + VarOperand, + AnyAttr, + OpAttr, + OptOpAttr, + IRDLOperation, +) from xdsl.utils.exceptions import VerifyException @@ -27,49 +42,58 @@ def verify_(self) -> None: if self.function_type.inputs.data != tuple(block_arg_types): raise VerifyException( "Expected entry block arguments to have the same types as the function " - "input types") + "input types" + ) @staticmethod - def from_callable(name: str, input_types: List[Attribute], - return_types: List[Attribute], - func: Block.BlockCallback) -> FuncOp: + def from_callable( + name: str, + input_types: List[Attribute], + return_types: List[Attribute], + func: Block.BlockCallback, + ) -> FuncOp: type_attr = FunctionType.from_lists(input_types, return_types) attributes: dict[str, Attribute] = { "sym_name": StringAttr(name), "function_type": type_attr, - "sym_visibility": StringAttr("private") + "sym_visibility": StringAttr("private"), } op = FuncOp.build( attributes=attributes, - regions=[Region(Block.from_callable(input_types, func))]) + regions=[Region(Block.from_callable(input_types, func))], + ) return op @staticmethod - def external(name: str, input_types: List[Attribute], - return_types: List[Attribute]) -> FuncOp: + def external( + name: str, input_types: List[Attribute], return_types: List[Attribute] + ) -> FuncOp: type_attr = FunctionType.from_lists(input_types, return_types) attributes: dict[str, Attribute] = { "sym_name": StringAttr(name), "function_type": type_attr, - "sym_visibility": StringAttr("private") + "sym_visibility": StringAttr("private"), } op = FuncOp.build(attributes=attributes, regions=[Region([Block()])]) return op @staticmethod - def from_region(name: str, input_types: List[Attribute], - return_types: List[Attribute], region: Region) -> FuncOp: + def from_region( + name: str, + input_types: List[Attribute], + return_types: List[Attribute], + region: Region, + ) -> FuncOp: type_attr = FunctionType.from_lists(input_types, return_types) attributes: dict[str, Attribute] = { "sym_name": StringAttr(name), "function_type": type_attr, - "sym_visibility": StringAttr("private") + "sym_visibility": StringAttr("private"), } op = FuncOp.build(attributes=attributes, regions=[region]) return op - def replace_argument_type(self, arg: int | BlockArgument, - new_type: Attribute): + def replace_argument_type(self, arg: int | BlockArgument, new_type: Attribute): """ Replaces the type of the argument specified by arg (either the index of the arg, or the BlockArgument object itself) with new_type. This also takes care of updating @@ -79,12 +103,14 @@ def replace_argument_type(self, arg: int | BlockArgument, try: arg = self.body.blocks[0].args[arg] except IndexError: - raise IndexError("Block {} does not have argument #{}".format( - self.body.blocks[0], arg)) + raise IndexError( + "Block {} does not have argument #{}".format( + self.body.blocks[0], arg + ) + ) if arg not in self.args: - raise ValueError( - "Arg {} does not belong to this function".format(arg)) + raise ValueError("Arg {} does not belong to this function".format(arg)) arg.typ = new_type self.update_function_type() @@ -95,14 +121,16 @@ def update_function_type(self): block argument types or return statement arguments. """ # Refuse to work with external function definitions, as they don't have block args - assert not self.is_declaration, "update_function_type does not work with function declarations!" + assert ( + not self.is_declaration + ), "update_function_type does not work with function declarations!" return_op = self.get_return_op() return_type: tuple[Attribute] = self.function_type.outputs.data if return_op is not None: return_type = tuple(arg.typ for arg in return_op.operands) - self.attributes['function_type'] = FunctionType.from_lists( + self.attributes["function_type"] = FunctionType.from_lists( [arg.typ for arg in self.args], return_type, ) @@ -124,7 +152,9 @@ def args(self) -> tuple[BlockArgument, ...]: """ A helper to quickly get access to the block arguments of the function """ - assert not self.is_declaration, "Function declarations don't have BlockArguments!" + assert ( + not self.is_declaration + ), "Function declarations don't have BlockArguments!" return self.body.blocks[0].args @property @@ -146,14 +176,16 @@ class Call(IRDLOperation): # TODO how do we verify that the types are correct? @staticmethod - def get(callee: Union[str, SymbolRefAttr], ops: List[Union[SSAValue, - Operation]], - return_types: List[Attribute]) -> Call: + def get( + callee: Union[str, SymbolRefAttr], + ops: List[Union[SSAValue, Operation]], + return_types: List[Attribute], + ) -> Call: if isinstance(callee, str): callee = SymbolRefAttr(callee) - return Call.build(operands=[ops], - result_types=[return_types], - attributes={"callee": callee}) + return Call.build( + operands=[ops], result_types=[return_types], attributes={"callee": callee} + ) @irdl_op_definition diff --git a/xdsl/dialects/gpu.py b/xdsl/dialects/gpu.py index bed9de7c3d..9cda0b2459 100644 --- a/xdsl/dialects/gpu.py +++ b/xdsl/dialects/gpu.py @@ -1,12 +1,30 @@ from __future__ import annotations from typing import Annotated, Generic, Sequence, TypeVar -from xdsl.ir import (Attribute, TypeAttribute, OpResult, Operation, Dialect, - ParametrizedAttribute, Region, SSAValue) -from xdsl.irdl import (AttrSizedOperandSegments, Operand, OptOpAttr, - OptOpResult, OptOperand, ParameterDef, VarOperand, - irdl_op_definition, irdl_attr_definition, - SingleBlockRegion, OpAttr, IRDLOperation) +from xdsl.ir import ( + Attribute, + TypeAttribute, + OpResult, + Operation, + Dialect, + ParametrizedAttribute, + Region, + SSAValue, +) +from xdsl.irdl import ( + AttrSizedOperandSegments, + Operand, + OptOpAttr, + OptOpResult, + OptOperand, + ParameterDef, + VarOperand, + irdl_op_definition, + irdl_attr_definition, + SingleBlockRegion, + OpAttr, + IRDLOperation, +) from xdsl.dialects.builtin import IndexType, StringAttr, SymbolRefAttr, UnitAttr, i32 from xdsl.dialects import memref from xdsl.parser import BaseParser @@ -39,9 +57,7 @@ def print_parameters(self, printer: Printer) -> None: printer.print(f"dim {self.param.data}") -T = TypeVar('T', - bound=_AllReduceOperationAttr | _DimensionAttr, - covariant=True) +T = TypeVar("T", bound=_AllReduceOperationAttr | _DimensionAttr, covariant=True) @irdl_attr_definition @@ -54,7 +70,7 @@ class _GPUAttr(ParametrizedAttribute, Generic[T]): def parse_parameters(parser: BaseParser) -> list[Attribute]: parser.parse_characters( "<", - f"Expected <. gpu attributes currently have the #gpu syntax." + f"Expected <. gpu attributes currently have the #gpu syntax.", ) ntok = parser.tokenizer.next_token() @@ -64,31 +80,31 @@ def parse_parameters(parser: BaseParser) -> list[Attribute]: if vtok.text not in ["x", "y", "z"]: parser.raise_error( f"Unexpected dim {vtok.text}. A gpu dim can only be x, y, or z", - vtok) + vtok, + ) elif ntok.text == "all_reduce_op": attrtype = _AllReduceOperationAttr vtok = parser.tokenizer.next_token() - if vtok.text not in [ - "add", "and", "max", "min", "mul", "or", "xor" - ]: + if vtok.text not in ["add", "and", "max", "min", "mul", "or", "xor"]: parser.raise_error( f"Unexpected op {vtok.text}. A gpu all_reduce_op can only be add, " - "and, max, min, mul, or, or xor ", vtok) + "and, max, min, mul, or, or xor ", + vtok, + ) else: parser.raise_error( - f"Unexpected token {ntok.text}. Expected dim or all_reduce_op", - ntok) + f"Unexpected token {ntok.text}. Expected dim or all_reduce_op", ntok + ) parser.parse_characters( ">", - f"Expected >. gpu attributes currently have the #gpu syntax." + f"Expected >. gpu attributes currently have the #gpu syntax.", ) return [attrtype([StringAttr(vtok.text)])] @staticmethod def from_op(value: str) -> AllReduceOperationAttr: - return AllReduceOperationAttr( - [_AllReduceOperationAttr([StringAttr(value)])]) + return AllReduceOperationAttr([_AllReduceOperationAttr([StringAttr(value)])]) @property def data(self) -> str: @@ -131,29 +147,32 @@ def verify_(self) -> None: if ndyn != ndyn_typ: raise VerifyException( f"Expected {ndyn_typ} dynamic sizes, got {ndyn}. All " - "dynamic sizes need to be set in the alloc operation.") + "dynamic sizes need to be set in the alloc operation." + ) @staticmethod - def get(return_type: memref.MemRefType[_Element], - dynamic_sizes: Sequence[SSAValue | Operation] - | None = None, - host_shared: bool = False, - async_dependencies: Sequence[SSAValue | Operation] | None = None, - is_async: bool = False) -> AllocOp: + def get( + return_type: memref.MemRefType[_Element], + dynamic_sizes: Sequence[SSAValue | Operation] | None = None, + host_shared: bool = False, + async_dependencies: Sequence[SSAValue | Operation] | None = None, + is_async: bool = False, + ) -> AllocOp: token_return = [AsyncTokenType()] if is_async else [] - dynamic_sizes_vals: list[SSAValue] = [ - SSAValue.get(e) for e in dynamic_sizes - ] if dynamic_sizes else [] - async_dependencies_vals: list[SSAValue] = [ - SSAValue.get(e) for e in async_dependencies - ] if async_dependencies else [] - attributes: dict[str, Attribute] = { - "hostShared": UnitAttr() - } if host_shared else {} + dynamic_sizes_vals: list[SSAValue] = ( + [SSAValue.get(e) for e in dynamic_sizes] if dynamic_sizes else [] + ) + async_dependencies_vals: list[SSAValue] = ( + [SSAValue.get(e) for e in async_dependencies] if async_dependencies else [] + ) + attributes: dict[str, Attribute] = ( + {"hostShared": UnitAttr()} if host_shared else {} + ) return AllocOp.build( operands=[async_dependencies_vals, dynamic_sizes_vals, []], result_types=[return_type, token_return], - attributes=attributes) + attributes=attributes, + ) @irdl_op_definition @@ -166,27 +185,31 @@ class AllReduceOp(IRDLOperation): body: Region @staticmethod - def from_op(op: AllReduceOperationAttr, - operand: SSAValue | Operation, - uniform: UnitAttr | None = None): - - return AllReduceOp.build(operands=[operand], - result_types=[SSAValue.get(operand).typ], - attributes={ - "op": op, - "uniform": uniform, - }, - regions=[Region()]) + def from_op( + op: AllReduceOperationAttr, + operand: SSAValue | Operation, + uniform: UnitAttr | None = None, + ): + return AllReduceOp.build( + operands=[operand], + result_types=[SSAValue.get(operand).typ], + attributes={ + "op": op, + "uniform": uniform, + }, + regions=[Region()], + ) @staticmethod - def from_body(body: Region, - operand: SSAValue | Operation, - uniform: UnitAttr | None = None): + def from_body( + body: Region, operand: SSAValue | Operation, uniform: UnitAttr | None = None + ): return AllReduceOp.build( operands=[operand], result_types=[SSAValue.get(operand).typ], attributes={"uniform": uniform} if uniform is not None else {}, - regions=[body]) + regions=[body], + ) def verify_(self) -> None: if self.result.typ != self.operand.typ: @@ -201,7 +224,8 @@ def verify_(self) -> None: if op_attr: raise VerifyException( f"gpu.all_reduce can't have both a non-empty region and an op " - "attribute.") + "attribute." + ) else: raise VerifyException( f"gpu.all_reduce need either a non empty body or an op attribute." @@ -213,7 +237,8 @@ def verify_(self) -> None: raise VerifyException( f"Expected {[str(t) for t in [self.result.typ, self.operand.typ]]}, " f"got {[str(t) for t in args_types]}. A gpu.all_reduce's body must " - "have two arguments matching the result type.") + "have two arguments matching the result type." + ) @irdl_op_definition @@ -233,8 +258,9 @@ class BlockDimOp(IRDLOperation): @staticmethod def get(dim: DimensionAttr) -> BlockDimOp: - return BlockDimOp.build(result_types=[IndexType()], - attributes={"dimension": dim}) + return BlockDimOp.build( + result_types=[IndexType()], attributes={"dimension": dim} + ) @irdl_op_definition @@ -245,8 +271,9 @@ class BlockIdOp(IRDLOperation): @staticmethod def get(dim: DimensionAttr) -> BlockIdOp: - return BlockIdOp.build(result_types=[IndexType()], - attributes={"dimension": dim}) + return BlockIdOp.build( + result_types=[IndexType()], attributes={"dimension": dim} + ) @irdl_op_definition @@ -261,12 +288,15 @@ class DeallocOp(IRDLOperation): asyncToken: Annotated[OptOpResult, AsyncTokenType] @staticmethod - def get(buffer: SSAValue | Operation, - async_dependencies: Sequence[SSAValue | Operation] | None = None, - is_async: bool = False) -> DeallocOp: + def get( + buffer: SSAValue | Operation, + async_dependencies: Sequence[SSAValue | Operation] | None = None, + is_async: bool = False, + ) -> DeallocOp: return DeallocOp.build( operands=[async_dependencies, buffer], - result_types=[[AsyncTokenType()] if is_async else []]) + result_types=[[AsyncTokenType()] if is_async else []], + ) @irdl_op_definition @@ -282,19 +312,23 @@ class MemcpyOp(IRDLOperation): asyncToken: Annotated[OptOpResult, AsyncTokenType] @staticmethod - def get(source: SSAValue | Operation, - destination: SSAValue | Operation, - async_dependencies: Sequence[SSAValue | Operation] | None = None, - is_async: bool = False) -> MemcpyOp: + def get( + source: SSAValue | Operation, + destination: SSAValue | Operation, + async_dependencies: Sequence[SSAValue | Operation] | None = None, + is_async: bool = False, + ) -> MemcpyOp: return MemcpyOp.build( operands=[async_dependencies, source, destination], - result_types=[[AsyncTokenType()] if is_async else []]) + result_types=[[AsyncTokenType()] if is_async else []], + ) def verify_(self) -> None: if self.src.typ != self.dst.typ: raise VerifyException( f"Expected {self.src.typ}, got {self.dst.typ}. gpu.memcpy source and " - "destination types must match.") + "destination types must match." + ) @irdl_op_definition @@ -310,8 +344,9 @@ def get(name: SymbolRefAttr, ops: list[Operation]) -> ModuleOp: return op def verify_(self): - if (len(self.body.ops) == 0 - or not isinstance(self.body.block.last_op, ModuleEndOp)): + if len(self.body.ops) == 0 or not isinstance( + self.body.block.last_op, ModuleEndOp + ): raise VerifyException("gpu.module must end with gpu.module_end") @@ -323,8 +358,9 @@ class GlobalIdOp(IRDLOperation): @staticmethod def get(dim: DimensionAttr) -> GlobalIdOp: - return GlobalIdOp.build(result_types=[IndexType()], - attributes={"dimension": dim}) + return GlobalIdOp.build( + result_types=[IndexType()], attributes={"dimension": dim} + ) @irdl_op_definition @@ -335,8 +371,9 @@ class GridDimOp(IRDLOperation): @staticmethod def get(dim: DimensionAttr) -> GridDimOp: - return GridDimOp.build(result_types=[IndexType()], - attributes={"dimension": dim}) + return GridDimOp.build( + result_types=[IndexType()], attributes={"dimension": dim} + ) @irdl_op_definition @@ -350,6 +387,7 @@ class HostRegisterOp(IRDLOperation): afterwards. Writes from the device are guaranteed to be visible on the host after synchronizing with the device kernel completion. """ + name = "gpu.host_register" value: Annotated[Operand, memref.UnrankedMemrefType] @@ -391,29 +429,33 @@ def get( blockSize: Sequence[SSAValue | Operation], async_launch: bool = False, asyncDependencies: Sequence[SSAValue | Operation] | None = None, - dynamicSharedMemorySize: SSAValue | Operation | None = None + dynamicSharedMemorySize: SSAValue | Operation | None = None, ) -> LaunchOp: if len(gridSize) != 3: - raise ValueError( - f"LaunchOp must have 3 gridSizes, got {len(gridSize)}") + raise ValueError(f"LaunchOp must have 3 gridSizes, got {len(gridSize)}") if len(blockSize) != 3: - raise ValueError( - f"LaunchOp must have 3 blockSizes, got {len(blockSize)}") - operands = [[] if asyncDependencies is None else - [SSAValue.get(a) for a in asyncDependencies]] + raise ValueError(f"LaunchOp must have 3 blockSizes, got {len(blockSize)}") + operands = [ + [] + if asyncDependencies is None + else [SSAValue.get(a) for a in asyncDependencies] + ] operands += [SSAValue.get(gs) for gs in gridSize] operands += [SSAValue.get(bs) for bs in blockSize] - operands += [[] if dynamicSharedMemorySize is None else - [SSAValue.get(dynamicSharedMemorySize)]] + operands += [ + [] + if dynamicSharedMemorySize is None + else [SSAValue.get(dynamicSharedMemorySize)] + ] return LaunchOp.build( operands=operands, result_types=[[AsyncTokenType()] if async_launch else []], - regions=[body]) + regions=[body], + ) def verify_(self) -> None: - if len(self.body.blocks) == 0 or all(b.is_empty - for b in self.body.blocks): + if len(self.body.blocks) == 0 or all(b.is_empty for b in self.body.blocks): raise VerifyException("gpu.launch requires a non-empty body.") body_args = self.body.blocks[0].args args_type = [a.typ for a in body_args] @@ -488,10 +530,12 @@ def verify_(self) -> None: if block is not None: if self is not block.last_op: raise VerifyException( - "A gpu.terminator must terminate its parent block") + "A gpu.terminator must terminate its parent block" + ) if op is not None and not isinstance(op, LaunchOp): raise VerifyException( - "gpu.terminator is only meant to terminate gpu.launch") + "gpu.terminator is only meant to terminate gpu.launch" + ) @irdl_op_definition @@ -502,8 +546,9 @@ class ThreadIdOp(IRDLOperation): @staticmethod def get(dim: DimensionAttr) -> ThreadIdOp: - return ThreadIdOp.build(result_types=[IndexType()], - attributes={"dimension": dim}) + return ThreadIdOp.build( + result_types=[IndexType()], attributes={"dimension": dim} + ) @irdl_op_definition @@ -520,41 +565,44 @@ def verify_(self) -> None: op = self.parent_op() if block is not None: if self is not block.last_op: - raise VerifyException( - "A gpu.yield must terminate its parent block") + raise VerifyException("A gpu.yield must terminate its parent block") if op is not None: yield_type = [o.typ for o in self.values] result_type = [r.typ for r in op.results] if yield_type != result_type: raise VerifyException( f"Expected {[str(t) for t in result_type]}, got {[str(t) for t in yield_type]}. The gpu.yield values " - "types must match its enclosing operation result types.") + "types must match its enclosing operation result types." + ) -#_GPUAttr has to be registered instead of DimensionAttr and AllReduceOperationAttr here. +# _GPUAttr has to be registered instead of DimensionAttr and AllReduceOperationAttr here. # This is a hack to fit MLIR's syntax in xDSL's way of parsing attributes, without making GPU builtin. # Hopefully MLIR will parse it in a more xDSL-friendly way soon, so all that can be factored in proper xDSL # atrributes. -GPU = Dialect([ - AllocOp, - AllReduceOp, - BarrierOp, - BlockDimOp, - BlockIdOp, - DeallocOp, - GlobalIdOp, - GridDimOp, - HostRegisterOp, - LaneIdOp, - LaunchOp, - MemcpyOp, - ModuleOp, - ModuleEndOp, - NumSubgroupsOp, - SetDefaultDeviceOp, - SubgroupIdOp, - SubgroupSizeOp, - TerminatorOp, - ThreadIdOp, - YieldOp, -], [_GPUAttr]) +GPU = Dialect( + [ + AllocOp, + AllReduceOp, + BarrierOp, + BlockDimOp, + BlockIdOp, + DeallocOp, + GlobalIdOp, + GridDimOp, + HostRegisterOp, + LaneIdOp, + LaunchOp, + MemcpyOp, + ModuleOp, + ModuleEndOp, + NumSubgroupsOp, + SetDefaultDeviceOp, + SubgroupIdOp, + SubgroupSizeOp, + TerminatorOp, + ThreadIdOp, + YieldOp, + ], + [_GPUAttr], +) diff --git a/xdsl/dialects/irdl.py b/xdsl/dialects/irdl.py index 6061b30fd4..eb0b0e2503 100644 --- a/xdsl/dialects/irdl.py +++ b/xdsl/dialects/irdl.py @@ -2,16 +2,21 @@ from typing import cast from xdsl.dialects.builtin import AnyArrayAttr, ArrayAttr, StringAttr -from xdsl.ir import (ParametrizedAttribute, Attribute, Dialect) -from xdsl.irdl import (ParameterDef, irdl_op_definition, irdl_attr_definition, - SingleBlockRegion, OpAttr, IRDLOperation) +from xdsl.ir import ParametrizedAttribute, Attribute, Dialect +from xdsl.irdl import ( + ParameterDef, + irdl_op_definition, + irdl_attr_definition, + SingleBlockRegion, + OpAttr, + IRDLOperation, +) from xdsl.parser import BaseParser from xdsl.printer import Printer @irdl_attr_definition class EqTypeConstraintAttr(ParametrizedAttribute): - name = "irdl.equality_type_constraint" type: ParameterDef[Attribute] @@ -69,8 +74,7 @@ def parse_parameters(parser: BaseParser) -> list[Attribute]: return [StringAttr(type_name), params_constraints] def print_parameters(self, printer: Printer) -> None: - printer.print("<\"", self.type_name.data, "\" : ", - self.params_constraints, ">") + printer.print('<"', self.type_name.data, '" : ', self.params_constraints, ">") @irdl_op_definition @@ -78,6 +82,7 @@ class DialectOp(IRDLOperation): """ Define a new dialect """ + name = "irdl.dialect" body: SingleBlockRegion @@ -105,6 +110,7 @@ class ParametersOp(IRDLOperation): """ Define the parameters of a type/attribute definition """ + name = "irdl.parameters" params: OpAttr[AnyArrayAttr] @@ -114,6 +120,7 @@ class TypeOp(IRDLOperation): """ Defines new types belonging to previously defined dialect """ + name = "irdl.type" body: SingleBlockRegion @@ -134,6 +141,7 @@ class ConstraintVarsOp(IRDLOperation): Define constraint variables that can be used in the current region """ + name = "irdl.constraint_vars" constraints: OpAttr[Attribute] @@ -143,6 +151,7 @@ class OperandsOp(IRDLOperation): """ Define the operands of a parent operation """ + name = "irdl.operands" params: OpAttr[Attribute] @@ -152,6 +161,7 @@ class ResultsOp(IRDLOperation): """ Define results of parent operation """ + name = "irdl.results" params: OpAttr[Attribute] @@ -161,6 +171,7 @@ class OperationOp(IRDLOperation): """ Define a new operation belonging to previously defined dialect """ + name = "irdl.operation" body: SingleBlockRegion @@ -207,5 +218,6 @@ def get_results(self) -> ResultsOp | None: TypeParamsConstraintAttr, NamedTypeConstraintAttr, DynTypeBaseConstraintAttr, - DynTypeParamsConstraintAttr - ]) + DynTypeParamsConstraintAttr, + ], +) diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index 0a38dd81be..5e94297432 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -1,15 +1,43 @@ from __future__ import annotations from typing import TYPE_CHECKING, Annotated -from xdsl.dialects.builtin import (StringAttr, ArrayAttr, DenseArrayBase, - IntAttr, NoneAttr, IntegerType, IntegerAttr, - AnyIntegerAttr, IndexType, UnitAttr, i32, - i64, SymbolRefAttr) -from xdsl.ir import (TypeAttribute, ParametrizedAttribute, Attribute, Dialect, - OpResult, Operation, SSAValue, Region) -from xdsl.irdl import (OpAttr, Operand, ParameterDef, AnyAttr, Block, - irdl_attr_definition, irdl_op_definition, VarOperand, - OptOpAttr, IRDLOperation) +from xdsl.dialects.builtin import ( + StringAttr, + ArrayAttr, + DenseArrayBase, + IntAttr, + NoneAttr, + IntegerType, + IntegerAttr, + AnyIntegerAttr, + IndexType, + UnitAttr, + i32, + i64, + SymbolRefAttr, +) +from xdsl.ir import ( + Block, + TypeAttribute, + ParametrizedAttribute, + Attribute, + Dialect, + OpResult, + Operation, + SSAValue, + Region, +) +from xdsl.irdl import ( + OpAttr, + Operand, + ParameterDef, + AnyAttr, + irdl_attr_definition, + irdl_op_definition, + VarOperand, + OptOpAttr, + IRDLOperation, +) from xdsl.utils.exceptions import VerifyException if TYPE_CHECKING: @@ -51,9 +79,9 @@ def parse_parameters(parser: BaseParser) -> list[Attribute]: parser.parse_characters("<(", "LLVM Struct must start with `<(`") params = parser.parse_list_of( parser.try_parse_type, - "Malformed LLVM struct, expected attribute definition here!") - parser.parse_characters( - ")>", "Unexpected input, expected end of LLVM struct!") + "Malformed LLVM struct, expected attribute definition here!", + ) + parser.parse_characters(")>", "Unexpected input, expected end of LLVM struct!") return [StringAttr(""), ArrayAttr(params)] @@ -78,20 +106,18 @@ def print_parameters(self, printer: Printer) -> None: @staticmethod def parse_parameters(parser: BaseParser) -> list[Attribute]: - if not parser.tokenizer.starts_with('<'): + if not parser.tokenizer.starts_with("<"): return [NoneAttr(), NoneAttr()] - parser.parse_characters('<', "llvm.ptr parameters expected") + parser.parse_characters("<", "llvm.ptr parameters expected") type = parser.try_parse_type() if type is None: - parser.raise_error( - "Expected first parameter of llvm.ptr to be a type!") - if not parser.tokenizer.starts_with(','): - parser.parse_characters('>', - "End of llvm.ptr parameters expected!") + parser.raise_error("Expected first parameter of llvm.ptr to be a type!") + if not parser.tokenizer.starts_with(","): + parser.parse_characters(">", "End of llvm.ptr parameters expected!") return [type, NoneAttr()] - parser.parse_characters(',', "llvm.ptr args must be separated by `,`") + parser.parse_characters(",", "llvm.ptr args must be separated by `,`") addr_space = parser.parse_int_literal() - parser.parse_characters('>', "End of llvm.ptr parameters expected!") + parser.parse_characters(">", "End of llvm.ptr parameters expected!") return [type, IntegerAttr.from_params(addr_space, IndexType())] @staticmethod @@ -122,20 +148,20 @@ def print_parameters(self, printer: Printer) -> None: @staticmethod def parse_parameters(parser: BaseParser) -> list[Attribute]: - if not parser.tokenizer.starts_with('<'): + if not parser.tokenizer.starts_with("<"): return [NoneAttr(), NoneAttr()] - parser.parse_characters('<', "llvm.array parameters expected") + parser.parse_characters("<", "llvm.array parameters expected") size = IntAttr(parser.parse_int_literal()) - if not parser.tokenizer.starts_with('x'): - parser.parse_characters('>', "End of llvm.array type expected!") + if not parser.tokenizer.starts_with("x"): + parser.parse_characters(">", "End of llvm.array type expected!") return [size, NoneAttr()] parser.parse_characters( - 'x', "llvm.array size and type must be separated by `x`") + "x", "llvm.array size and type must be separated by `x`" + ) type = parser.try_parse_type() if type is None: - parser.raise_error( - "Expected second parameter of llvm.array to be a type!") - parser.parse_characters('>', "End of llvm.array parameters expected!") + parser.raise_error("Expected second parameter of llvm.array to be a type!") + parser.parse_characters(">", "End of llvm.array parameters expected!") return [size, type] @staticmethod @@ -163,7 +189,7 @@ def print_parameters(self, printer: Printer) -> None: @staticmethod def parse_parameters(parser: BaseParser) -> list[Attribute]: - parser.parse_characters('<', "llvm.linkage parameter expected") + parser.parse_characters("<", "llvm.linkage parameter expected") # The linkage string is output from xDSL as a string (and accepted by MLIR as such) # however it is always output from MLIR without quotes. Therefore need to determine # whether this is a string or not and slightly change how we parse based upon that @@ -173,18 +199,25 @@ def parse_parameters(parser: BaseParser) -> list[Attribute]: else: linkage_str = parser.tokenizer.next_token().text linkage = StringAttr(linkage_str) - parser.parse_characters('>', "End of llvm.linkage parameter expected!") + parser.parse_characters(">", "End of llvm.linkage parameter expected!") return [linkage] def verify(self): allowed_linkage = [ - "private", "internal", "available_externally", "linkonce", "weak", - "common", "appending", "extern_weak", "linkonce_odr", "weak_odr", - "external" + "private", + "internal", + "available_externally", + "linkonce", + "weak", + "common", + "appending", + "extern_weak", + "linkonce_odr", + "weak_odr", + "external", ] if self.linkage.data not in allowed_linkage: - raise VerifyException( - f"Specified linkage '{self.linkage.data}' is unknown") + raise VerifyException(f"Specified linkage '{self.linkage.data}' is unknown") @irdl_op_definition @@ -200,16 +233,16 @@ class GEPOp(IRDLOperation): @staticmethod def get( - ptr: SSAValue | Operation, - result_type: LLVMPointerType = LLVMPointerType.opaque(), - indices: list[int] | - None = None, # Here we are assuming the indices follow the MLIR standard (min int where the SSA value should be used) - ssa_indices: list[SSAValue | Operation] | None = None, - inbounds: bool = False, - pointee_type: Attribute | None = None): - + ptr: SSAValue | Operation, + result_type: LLVMPointerType = LLVMPointerType.opaque(), + indices: list[int] + | None = None, # Here we are assuming the indices follow the MLIR standard (min int where the SSA value should be used) + ssa_indices: list[SSAValue | Operation] | None = None, + inbounds: bool = False, + pointee_type: Attribute | None = None, + ): if indices is None: - raise ValueError('llvm.getelementptr must have indices passed.') + raise ValueError("llvm.getelementptr must have indices passed.") indices_attr = DenseArrayBase.create_dense_int_or_index(i32, indices) @@ -222,29 +255,28 @@ def get( ptr_type = ptr_val.typ if not isinstance(result_type, LLVMPointerType): - raise ValueError('Result type must be a pointer.') + raise ValueError("Result type must be a pointer.") if not isinstance(ptr_type, LLVMPointerType): - raise ValueError('Input must be a pointer') + raise ValueError("Input must be a pointer") if not ptr_type.is_typed(): if pointee_type == None: - raise ValueError( - 'Opaque types must have a pointee type passed') + raise ValueError("Opaque types must have a pointee type passed") attrs: dict[str, Attribute] = { - 'rawConstantIndices': indices_attr, + "rawConstantIndices": indices_attr, } if not ptr_type.is_typed(): - attrs['elem_type'] = result_type + attrs["elem_type"] = result_type if inbounds: - attrs['inbounds'] = UnitAttr() + attrs["inbounds"] = UnitAttr() - return GEPOp.build(operands=[ptr, ssa_indices], - result_types=[result_type], - attributes=attrs) + return GEPOp.build( + operands=[ptr, ssa_indices], result_types=[result_type], attributes=attrs + ) @irdl_op_definition @@ -258,22 +290,24 @@ class AllocaOp(IRDLOperation): res: OpResult @staticmethod - def get(size: SSAValue | Operation, - elem_type: Attribute, - alignment: int = 32, - as_untyped_ptr: bool = False): + def get( + size: SSAValue | Operation, + elem_type: Attribute, + alignment: int = 32, + as_untyped_ptr: bool = False, + ): attrs: dict[str, Attribute] = { - 'alignment': IntegerAttr.from_int_and_width(alignment, 64) + "alignment": IntegerAttr.from_int_and_width(alignment, 64) } if as_untyped_ptr: ptr_type = LLVMPointerType.opaque() - attrs['elem_type'] = elem_type + attrs["elem_type"] = elem_type else: ptr_type = LLVMPointerType.typed(elem_type) - return AllocaOp.build(operands=[size], - attributes=attrs, - result_types=[ptr_type]) + return AllocaOp.build( + operands=[size], attributes=attrs, result_types=[ptr_type] + ) @irdl_op_definition @@ -342,22 +376,24 @@ class StoreOp(IRDLOperation): nontemporal: OptOpAttr[UnitAttr] @staticmethod - def get(value: SSAValue | Operation, - ptr: SSAValue | Operation, - alignment: int | None = None, - ordering: int = 0, - volatile: bool = False, - nontemporal: bool = False): + def get( + value: SSAValue | Operation, + ptr: SSAValue | Operation, + alignment: int | None = None, + ordering: int = 0, + volatile: bool = False, + nontemporal: bool = False, + ): attrs: dict[str, Attribute] = { - 'ordering': IntegerAttr(ordering, i64), + "ordering": IntegerAttr(ordering, i64), } if alignment is not None: - attrs['alignment'] = IntegerAttr[IntegerType](alignment, i64) + attrs["alignment"] = IntegerAttr[IntegerType](alignment, i64) if volatile: - attrs['volatile_'] = UnitAttr() + attrs["volatile_"] = UnitAttr() if nontemporal: - attrs['nontemporal'] = UnitAttr() + attrs["nontemporal"] = UnitAttr() return StoreOp.build( operands=[value, ptr], @@ -429,18 +465,19 @@ class GlobalOp(IRDLOperation): body: Region @staticmethod - def get(global_type: Attribute, - sym_name: str | StringAttr, - linkage: str | LinkageAttr, - addr_space: int, - constant: bool | None = None, - dso_local: bool | None = None, - thread_local_: bool | None = None, - value: Attribute | None = None, - alignment: int | None = None, - unnamed_addr: int | None = None, - section: str | StringAttr | None = None): - + def get( + global_type: Attribute, + sym_name: str | StringAttr, + linkage: str | LinkageAttr, + addr_space: int, + constant: bool | None = None, + dso_local: bool | None = None, + thread_local_: bool | None = None, + value: Attribute | None = None, + alignment: int | None = None, + unnamed_addr: int | None = None, + section: str | StringAttr | None = None, + ): if isinstance(sym_name, str): sym_name = StringAttr(sym_name) @@ -448,34 +485,34 @@ def get(global_type: Attribute, linkage = LinkageAttr(linkage) attrs: dict[str, Attribute] = { - 'global_type': global_type, - 'sym_name': sym_name, - 'linkage': linkage, - 'addr_space': IntegerAttr(addr_space, 32) + "global_type": global_type, + "sym_name": sym_name, + "linkage": linkage, + "addr_space": IntegerAttr(addr_space, 32), } if constant is not None and constant: - attrs['constant'] = UnitAttr() + attrs["constant"] = UnitAttr() if dso_local is not None and dso_local: - attrs['dso_local'] = UnitAttr() + attrs["dso_local"] = UnitAttr() if thread_local_ is not None and thread_local_: - attrs['thread_local_'] = UnitAttr() + attrs["thread_local_"] = UnitAttr() if value is not None: - attrs['value'] = value + attrs["value"] = value if alignment is not None: - attrs['alignment'] = IntegerAttr(alignment, 64) + attrs["alignment"] = IntegerAttr(alignment, 64) if unnamed_addr is not None: - attrs['unnamed_addr'] = IntegerAttr(unnamed_addr, 64) + attrs["unnamed_addr"] = IntegerAttr(unnamed_addr, 64) if section is not None: if isinstance(section, str): section = StringAttr(section) - attrs['section'] = section + attrs["section"] = section return GlobalOp.build(attributes=attrs, regions=[Region([Block()])]) @@ -488,27 +525,32 @@ class AddressOfOp(IRDLOperation): result: Annotated[OpResult, LLVMPointerType] @staticmethod - def get(global_name: str | StringAttr | SymbolRefAttr, - result_type: LLVMPointerType): + def get( + global_name: str | StringAttr | SymbolRefAttr, result_type: LLVMPointerType + ): if isinstance(global_name, str): global_name = StringAttr(global_name) if isinstance(global_name, StringAttr): global_name = SymbolRefAttr(global_name) - return AddressOfOp.build(attributes={'global_name': global_name}, - result_types=[result_type]) - - -LLVM = Dialect([ - LLVMExtractValue, - LLVMInsertValue, - LLVMMLIRUndef, - AllocaOp, - GEPOp, - IntToPtrOp, - NullOp, - LoadOp, - StoreOp, - GlobalOp, - AddressOfOp, -], [LLVMStructType, LLVMPointerType, LLVMArrayType, LinkageAttr]) + return AddressOfOp.build( + attributes={"global_name": global_name}, result_types=[result_type] + ) + + +LLVM = Dialect( + [ + LLVMExtractValue, + LLVMInsertValue, + LLVMMLIRUndef, + AllocaOp, + GEPOp, + IntToPtrOp, + NullOp, + LoadOp, + StoreOp, + GlobalOp, + AddressOfOp, + ], + [LLVMStructType, LLVMPointerType, LLVMArrayType, LinkageAttr], +) diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index a2964847d1..8f042bbeff 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -1,17 +1,51 @@ from __future__ import annotations -from typing import (TYPE_CHECKING, Annotated, Sequence, TypeVar, Optional, - List, TypeAlias, cast) - -from xdsl.dialects.builtin import (AnyIntegerAttr, DenseIntOrFPElementsAttr, - IntegerAttr, DenseArrayBase, IndexType, - StridedLayoutAttr, ArrayAttr, NoneAttr, - SymbolRefAttr, i64, StringAttr, UnitAttr) -from xdsl.ir import (TypeAttribute, Operation, SSAValue, ParametrizedAttribute, - Dialect, OpResult) -from xdsl.irdl import (irdl_attr_definition, irdl_op_definition, ParameterDef, - Generic, Attribute, AnyAttr, Operand, VarOperand, - AttrSizedOperandSegments, OpAttr, IRDLOperation) +from typing import ( + TYPE_CHECKING, + Annotated, + Sequence, + TypeVar, + Optional, + List, + TypeAlias, + cast, +) + +from xdsl.dialects.builtin import ( + AnyIntegerAttr, + DenseIntOrFPElementsAttr, + IntegerAttr, + DenseArrayBase, + IndexType, + StridedLayoutAttr, + ArrayAttr, + NoneAttr, + SymbolRefAttr, + i64, + StringAttr, + UnitAttr, +) +from xdsl.ir import ( + TypeAttribute, + Operation, + SSAValue, + ParametrizedAttribute, + Dialect, + OpResult, +) +from xdsl.irdl import ( + irdl_attr_definition, + irdl_op_definition, + ParameterDef, + Generic, + Attribute, + AnyAttr, + Operand, + VarOperand, + AttrSizedOperandSegments, + OpAttr, + IRDLOperation, +) from xdsl.utils.exceptions import VerifyException if TYPE_CHECKING: @@ -22,8 +56,7 @@ @irdl_attr_definition -class MemRefType(Generic[_MemRefTypeElement], ParametrizedAttribute, - TypeAttribute): +class MemRefType(Generic[_MemRefTypeElement], ParametrizedAttribute, TypeAttribute): name = "memref" shape: ParameterDef[ArrayAttr[AnyIntegerAttr]] @@ -42,64 +75,71 @@ def from_element_type_and_shape( referenced_type: _MemRefTypeElement, shape: Sequence[int | AnyIntegerAttr], layout: Attribute = NoneAttr(), - memory_space: Attribute = NoneAttr() + memory_space: Attribute = NoneAttr(), ) -> MemRefType[_MemRefTypeElement]: - return MemRefType([ - ArrayAttr[AnyIntegerAttr]([ - d if isinstance(d, IntegerAttr) else - IntegerAttr.from_index_int_value(d) for d in shape - ]), referenced_type, layout, memory_space - ]) + return MemRefType( + [ + ArrayAttr[AnyIntegerAttr]( + [ + d + if isinstance(d, IntegerAttr) + else IntegerAttr.from_index_int_value(d) + for d in shape + ] + ), + referenced_type, + layout, + memory_space, + ] + ) @staticmethod def from_params( referenced_type: _MemRefTypeElement, shape: ArrayAttr[AnyIntegerAttr] = ArrayAttr( - [IntegerAttr.from_int_and_width(1, 64)]), + [IntegerAttr.from_int_and_width(1, 64)] + ), layout: Attribute = NoneAttr(), - memory_space: Attribute = NoneAttr() + memory_space: Attribute = NoneAttr(), ) -> MemRefType[_MemRefTypeElement]: return MemRefType([shape, referenced_type, layout, memory_space]) @staticmethod def parse_parameters(parser: BaseParser) -> list[Attribute]: - parser._synchronize_lexer_and_tokenizer( # pyright: ignore[reportPrivateUsage] - ) - parser.parse_punctuation('<', ' in memref attribute') + parser._synchronize_lexer_and_tokenizer() # pyright: ignore[reportPrivateUsage] + parser.parse_punctuation("<", " in memref attribute") shape = parser.parse_attribute() - parser.parse_punctuation(',', - ' between shape and element type parameters') + parser.parse_punctuation(",", " between shape and element type parameters") type = parser.parse_attribute() # If we have a layout or a memory space, parse both of them. - if parser.parse_optional_punctuation(',') is None: - parser.parse_punctuation('>', ' at end of memref attribute') + if parser.parse_optional_punctuation(",") is None: + parser.parse_punctuation(">", " at end of memref attribute") return [shape, type, NoneAttr(), NoneAttr()] layout = parser.parse_attribute() - parser.parse_punctuation(',', ' between layout and memory space') + parser.parse_punctuation(",", " between layout and memory space") memory_space = parser.parse_attribute() - parser.parse_punctuation('>', ' at end of memref attribute') - parser._synchronize_lexer_and_tokenizer( # pyright: ignore[reportPrivateUsage] - ) + parser.parse_punctuation(">", " at end of memref attribute") + parser._synchronize_lexer_and_tokenizer() # pyright: ignore[reportPrivateUsage] return [shape, type, layout, memory_space] def print_parameters(self, printer: Printer) -> None: - printer.print('<', self.shape, ', ', self.element_type) + printer.print("<", self.shape, ", ", self.element_type) if self.layout != NoneAttr() or self.memory_space != NoneAttr(): - printer.print(', ', self.layout, ', ', self.memory_space) - printer.print('>') + printer.print(", ", self.layout, ", ", self.memory_space) + printer.print(">") -_UnrankedMemrefTypeElems = TypeVar("_UnrankedMemrefTypeElems", - bound=Attribute, - covariant=True) -_UnrankedMemrefTypeElemsInit = TypeVar("_UnrankedMemrefTypeElemsInit", - bound=Attribute) +_UnrankedMemrefTypeElems = TypeVar( + "_UnrankedMemrefTypeElems", bound=Attribute, covariant=True +) +_UnrankedMemrefTypeElemsInit = TypeVar("_UnrankedMemrefTypeElemsInit", bound=Attribute) @irdl_attr_definition -class UnrankedMemrefType(Generic[_UnrankedMemrefTypeElems], - ParametrizedAttribute, TypeAttribute): +class UnrankedMemrefType( + Generic[_UnrankedMemrefTypeElems], ParametrizedAttribute, TypeAttribute +): name = "unranked_memref" element_type: ParameterDef[_UnrankedMemrefTypeElems] @@ -134,20 +174,17 @@ def verify_(self): memref_typ = cast(MemRefType[Attribute], self.memref.typ) if memref_typ.element_type != self.res.typ: - raise Exception( - "expected return type to match the MemRef element type") + raise Exception("expected return type to match the MemRef element type") if self.memref.typ.get_num_dims() != len(self.indices): raise Exception("expected an index for each dimension") @staticmethod - def get(ref: SSAValue | Operation, - indices: Sequence[SSAValue | Operation]) -> Load: + def get(ref: SSAValue | Operation, indices: Sequence[SSAValue | Operation]) -> Load: ssa_value = SSAValue.get(ref) typ = ssa_value.typ typ = cast(MemRefType[Attribute], typ) - return Load.build(operands=[ref, indices], - result_types=[typ.element_type]) + return Load.build(operands=[ref, indices], result_types=[typ.element_type]) @irdl_op_definition @@ -164,15 +201,17 @@ def verify_(self): memref_typ = cast(MemRefType[Attribute], self.memref.typ) if memref_typ.element_type != self.value.typ: - raise Exception( - "Expected value type to match the MemRef element type") + raise Exception("Expected value type to match the MemRef element type") if self.memref.typ.get_num_dims() != len(self.indices): raise Exception("Expected an index for each dimension") @staticmethod - def get(value: Operation | SSAValue, ref: Operation | SSAValue, - indices: Sequence[Operation | SSAValue]) -> Store: + def get( + value: Operation | SSAValue, + ref: Operation | SSAValue, + indices: Sequence[Operation | SSAValue], + ) -> Store: return Store.build(operands=[value, ref, indices]) @@ -191,20 +230,18 @@ class Alloc(IRDLOperation): irdl_options = [AttrSizedOperandSegments()] @staticmethod - def get(return_type: Attribute, - alignment: int, - shape: Optional[List[int | AnyIntegerAttr]] = None) -> Alloc: + def get( + return_type: Attribute, + alignment: int, + shape: Optional[List[int | AnyIntegerAttr]] = None, + ) -> Alloc: if shape is None: shape = [1] - return Alloc.build(operands=[[], []], - result_types=[ - MemRefType.from_element_type_and_shape( - return_type, shape) - ], - attributes={ - "alignment": - IntegerAttr.from_int_and_width(alignment, 64) - }) + return Alloc.build( + operands=[[], []], + result_types=[MemRefType.from_element_type_and_shape(return_type, shape)], + attributes={"alignment": IntegerAttr.from_int_and_width(alignment, 64)}, + ) @irdl_op_definition @@ -222,32 +259,29 @@ class Alloca(IRDLOperation): irdl_options = [AttrSizedOperandSegments()] @staticmethod - def get(return_type: Attribute, - alignment: int, - shape: Optional[List[int | AnyIntegerAttr]] = None, - dynamic_sizes: list[SSAValue | Operation] | None = None) -> Alloca: + def get( + return_type: Attribute, + alignment: int, + shape: Optional[List[int | AnyIntegerAttr]] = None, + dynamic_sizes: list[SSAValue | Operation] | None = None, + ) -> Alloca: if shape is None: shape = [1] if dynamic_sizes is None: dynamic_sizes = [] - return Alloca.build(operands=[dynamic_sizes, []], - result_types=[ - MemRefType.from_element_type_and_shape( - return_type, shape) - ], - attributes={ - "alignment": - IntegerAttr.from_int_and_width(alignment, 64) - }) + return Alloca.build( + operands=[dynamic_sizes, []], + result_types=[MemRefType.from_element_type_and_shape(return_type, shape)], + attributes={"alignment": IntegerAttr.from_int_and_width(alignment, 64)}, + ) @irdl_op_definition class Dealloc(IRDLOperation): name = "memref.dealloc" - memref: Annotated[Operand, - MemRefType[Attribute] | UnrankedMemrefType[Attribute]] + memref: Annotated[Operand, MemRefType[Attribute] | UnrankedMemrefType[Attribute]] @staticmethod def get(operand: Operation | SSAValue) -> Dealloc: @@ -270,17 +304,17 @@ class GetGlobal(IRDLOperation): memref: Annotated[OpResult, MemRefType[Attribute]] def verify_(self) -> None: - if 'name' not in self.attributes: + if "name" not in self.attributes: raise VerifyException("GetGlobal requires a 'name' attribute") - if not isinstance(self.attributes['name'], SymbolRefAttr): - raise VerifyException( - "expected 'name' attribute to be a SymbolRefAttr") + if not isinstance(self.attributes["name"], SymbolRefAttr): + raise VerifyException("expected 'name' attribute to be a SymbolRefAttr") @staticmethod def get(name: str, return_type: Attribute) -> GetGlobal: - return GetGlobal.build(result_types=[return_type], - attributes={"name": SymbolRefAttr.build(name)}) + return GetGlobal.build( + result_types=[return_type], attributes={"name": SymbolRefAttr.build(name)} + ) # TODO how to verify the types, as the global might be defined in another # compilation unit @@ -299,40 +333,42 @@ def verify_(self) -> None: if not isinstance(self.type, MemRefType): raise Exception("Global expects a MemRefType") - if not isinstance(self.initial_value, - UnitAttr | DenseIntOrFPElementsAttr): - raise Exception("Global initial value is expected to be a " - "dense type or an unit attribute") + if not isinstance(self.initial_value, UnitAttr | DenseIntOrFPElementsAttr): + raise Exception( + "Global initial value is expected to be a " + "dense type or an unit attribute" + ) @staticmethod def get( sym_name: StringAttr, typ: Attribute, initial_value: Attribute, - sym_visibility: StringAttr = StringAttr("private") + sym_visibility: StringAttr = StringAttr("private"), ) -> Global: return Global.build( attributes={ "sym_name": sym_name, "type": typ, "initial_value": initial_value, - "sym_visibility": sym_visibility - }) + "sym_visibility": sym_visibility, + } + ) @irdl_op_definition class Dim(IRDLOperation): name = "memref.dim" - source: Annotated[Operand, - MemRefType[Attribute] | UnrankedMemrefType[Attribute]] + source: Annotated[Operand, MemRefType[Attribute] | UnrankedMemrefType[Attribute]] index: Annotated[Operand, IndexType] result: Annotated[OpResult, IndexType] @staticmethod - def from_source_and_index(source: SSAValue | Operation, - index: SSAValue | Operation): + def from_source_and_index( + source: SSAValue | Operation, index: SSAValue | Operation + ): return Dim.build(operands=[source, index], result_types=[IndexType()]) @@ -359,8 +395,9 @@ class ExtractAlignedPointerAsIndexOp(IRDLOperation): @staticmethod def get(source: SSAValue | Operation): - return ExtractAlignedPointerAsIndexOp.build(operands=[source], - result_types=[IndexType()]) + return ExtractAlignedPointerAsIndexOp.build( + operands=[source], result_types=[IndexType()] + ) @irdl_op_definition @@ -387,15 +424,15 @@ def from_static_parameters( sizes: Sequence[int], strides: Sequence[int], ) -> Subview: - source = SSAValue.get(source) layout_strides = [1] for input_size in reversed(source_shape[1:]): layout_strides.insert(0, layout_strides[0] * input_size) - layout_offset = sum(stride * offset - for stride, offset in zip(layout_strides, offsets)) + layout_offset = sum( + stride * offset for stride, offset in zip(layout_strides, offsets) + ) for i in range(len(layout_strides)): layout_strides[i] *= strides[i] @@ -408,16 +445,15 @@ def from_static_parameters( layout, ) - return Subview.build(operands=[source, [], [], []], - result_types=[return_typ], - attributes={ - "static_offsets": - DenseArrayBase.from_list(i64, offsets), - "static_sizes": - DenseArrayBase.from_list(i64, sizes), - "static_strides": - DenseArrayBase.from_list(i64, strides) - }) + return Subview.build( + operands=[source, [], [], []], + result_types=[return_typ], + attributes={ + "static_offsets": DenseArrayBase.from_list(i64, offsets), + "static_sizes": DenseArrayBase.from_list(i64, sizes), + "static_strides": DenseArrayBase.from_list(i64, strides), + }, + ) @irdl_op_definition @@ -428,21 +464,26 @@ class Cast(IRDLOperation): dest: Annotated[OpResult, MemRefType | UnrankedMemrefType] @staticmethod - def get(source: SSAValue | Operation, - type: MemRefType[Attribute] | UnrankedMemrefType[Attribute]): + def get( + source: SSAValue | Operation, + type: MemRefType[Attribute] | UnrankedMemrefType[Attribute], + ): return Cast.build(operands=[source], result_types=[type]) -MemRef = Dialect([ - Load, - Store, - Alloc, - Alloca, - Dealloc, - GetGlobal, - Global, - Dim, - ExtractAlignedPointerAsIndexOp, - Subview, - Cast, -], [MemRefType, UnrankedMemrefType]) +MemRef = Dialect( + [ + Load, + Store, + Alloc, + Alloca, + Dealloc, + GetGlobal, + Global, + Dim, + ExtractAlignedPointerAsIndexOp, + Subview, + Cast, + ], + [MemRefType, UnrankedMemrefType], +) diff --git a/xdsl/dialects/mpi.py b/xdsl/dialects/mpi.py index a563b7e2d6..873d484489 100644 --- a/xdsl/dialects/mpi.py +++ b/xdsl/dialects/mpi.py @@ -6,14 +6,29 @@ from xdsl.utils.hints import isa from xdsl.dialects import llvm -from xdsl.dialects.builtin import (IntegerType, Signedness, StringAttr, - AnyFloat, i32) +from xdsl.dialects.builtin import IntegerType, Signedness, StringAttr, AnyFloat, i32 from xdsl.dialects.memref import MemRefType -from xdsl.ir import (Operation, Attribute, SSAValue, OpResult, - ParametrizedAttribute, Dialect, TypeAttribute) -from xdsl.irdl import (Operand, Annotated, irdl_op_definition, - irdl_attr_definition, OpAttr, OptOpResult, ParameterDef, - OptOperand, OptOpAttr, IRDLOperation) +from xdsl.ir import ( + Operation, + Attribute, + SSAValue, + OpResult, + ParametrizedAttribute, + Dialect, + TypeAttribute, +) +from xdsl.irdl import ( + Operand, + Annotated, + irdl_op_definition, + irdl_attr_definition, + OpAttr, + OptOpResult, + ParameterDef, + OptOperand, + OptOpAttr, + IRDLOperation, +) t_bool: IntegerType = IntegerType(1, Signedness.SIGNLESS) @@ -27,7 +42,8 @@ class OperationType(ParametrizedAttribute, TypeAttribute): They are used by the reduction MPI functions """ - name = 'mpi.operation' + + name = "mpi.operation" op_str: ParameterDef[StringAttr] @@ -36,6 +52,7 @@ class MpiOp: """ A collection of MPI_Op types used for """ + MPI_MAX = OperationType([StringAttr("MPI_MAX")]) MPI_MIN = OperationType([StringAttr("MPI_MIN")]) MPI_SUM = OperationType([StringAttr("MPI_SUM")]) @@ -59,7 +76,8 @@ class RequestType(ParametrizedAttribute, TypeAttribute): They are used by the asynchronous MPI functions """ - name = 'mpi.request' + + name = "mpi.request" @irdl_attr_definition @@ -69,7 +87,8 @@ class StatusType(ParametrizedAttribute, TypeAttribute): It's a struct containing status information for requests. """ - name = 'mpi.status' + + name = "mpi.status" @irdl_attr_definition @@ -77,11 +96,12 @@ class DataType(ParametrizedAttribute, TypeAttribute): """ This type represents MPI_Datatype """ - name = 'mpi.datatype' + + name = "mpi.datatype" VectorWrappable = RequestType | StatusType | DataType -_VectorT = TypeVar('_VectorT', bound=VectorWrappable) +_VectorT = TypeVar("_VectorT", bound=VectorWrappable) @irdl_attr_definition @@ -89,7 +109,8 @@ class VectorType(Generic[_VectorT], ParametrizedAttribute, TypeAttribute): """ This type holds multiple MPI types """ - name = 'mpi.vector' + + name = "mpi.vector" wrapped_type: ParameterDef[_VectorT] @staticmethod @@ -101,15 +122,17 @@ class StatusTypeField(Enum): """ This enum lists all fields in the MPI_Status struct """ - MPI_SOURCE = 'MPI_SOURCE' - MPI_TAG = 'MPI_TAG' - MPI_ERROR = 'MPI_ERROR' + + MPI_SOURCE = "MPI_SOURCE" + MPI_TAG = "MPI_TAG" + MPI_ERROR = "MPI_ERROR" class MPIBaseOp(IRDLOperation, ABC): """ Base class for MPI Operations """ + pass @@ -203,9 +226,9 @@ def get( datatype: SSAValue | Operation, operationtype: OperationType, ): - - operands_to_add: Sequence[SSAValue | Operation - | Sequence[SSAValue | Operation]] = [] + operands_to_add: Sequence[ + SSAValue | Operation | Sequence[SSAValue | Operation] + ] = [] if send_buffer is None: operands_to_add = [recv_buffer, count, datatype, []] else: @@ -285,7 +308,7 @@ class Isend(MPIBaseOp): to MPI_COMM_WORLD """ - name = 'mpi.isend' + name = "mpi.isend" buffer: Annotated[Operand, Attribute] count: Annotated[Operand, i32] @@ -333,7 +356,7 @@ class Send(MPIBaseOp): to MPI_COMM_WORLD """ - name = 'mpi.send' + name = "mpi.send" buffer: Annotated[Operand, Attribute] count: Annotated[Operand, i32] @@ -342,11 +365,16 @@ class Send(MPIBaseOp): tag: Annotated[Operand, i32] @staticmethod - def get(buffer: SSAValue | Operation, count: SSAValue | Operation, - datatype: SSAValue | Operation, dest: SSAValue | Operation, - tag: SSAValue | Operation) -> Send: - return Send.build(operands=[buffer, count, datatype, dest, tag], - result_types=[]) + def get( + buffer: SSAValue | Operation, + count: SSAValue | Operation, + datatype: SSAValue | Operation, + dest: SSAValue | Operation, + tag: SSAValue | Operation, + ) -> Send: + return Send.build( + operands=[buffer, count, datatype, dest, tag], result_types=[] + ) @irdl_op_definition @@ -434,15 +462,18 @@ class Recv(MPIBaseOp): status: Annotated[OptOpResult, StatusType] @staticmethod - def get(buffer: SSAValue | Operation, - count: SSAValue | Operation, - datatype: SSAValue | Operation, - source: SSAValue | Operation, - tag: SSAValue | Operation, - ignore_status: bool = True): + def get( + buffer: SSAValue | Operation, + count: SSAValue | Operation, + datatype: SSAValue | Operation, + source: SSAValue | Operation, + tag: SSAValue | Operation, + ignore_status: bool = True, + ): return Recv.build( operands=[buffer, count, datatype, source, tag], - result_types=[[]] if ignore_status else [[StatusType()]]) + result_types=[[]] if ignore_status else [[StatusType()]], + ) @irdl_op_definition @@ -469,8 +500,7 @@ class Test(MPIBaseOp): @staticmethod def get(request: Operand): - return Test.build(operands=[request], - result_types=[t_bool, StatusType()]) + return Test.build(operands=[request], result_types=[t_bool, StatusType()]) @irdl_op_definition @@ -529,8 +559,7 @@ def get(requests: Operand, count: Operand, ignore_status: bool = True): if ignore_status: result_types = [[]] - return Waitall.build(operands=[requests, count], - result_types=result_types) + return Waitall.build(operands=[requests, count], result_types=result_types) @irdl_op_definition @@ -545,6 +574,7 @@ class GetStatusField(MPIBaseOp): All fields are of type int. """ + name = "mpi.status.get" status: Annotated[Operand, StatusType] @@ -557,8 +587,9 @@ class GetStatusField(MPIBaseOp): def get(status_obj: Operand, field: StatusTypeField): return GetStatusField.build( operands=[status_obj], - attributes={'field': StringAttr(field.value)}, - result_types=[i32]) + attributes={"field": StringAttr(field.value)}, + result_types=[i32], + ) @irdl_op_definition @@ -568,6 +599,7 @@ class CommRank(MPIBaseOp): Currently limited to COMM_WORLD """ + name = "mpi.comm.rank" rank: Annotated[OpResult, i32] @@ -584,6 +616,7 @@ class CommSize(MPIBaseOp): Currently limited to COMM_WORLD """ + name = "mpi.comm.size" size: Annotated[OpResult, i32] @@ -598,6 +631,7 @@ class Init(MPIBaseOp): """ This represents a bare MPI_Init call with both args being nullptr """ + name = "mpi.init" @@ -606,6 +640,7 @@ class Finalize(MPIBaseOp): """ This represents an MPI_Finalize call with both args being nullptr """ + name = "mpi.finalize" @@ -616,6 +651,7 @@ class UnwrapMemrefOp(MPIBaseOp): It takes any MemRef as input, and returns an llvm.ptr, element count and MPI_Datatype. """ + name = "mpi.unwrap_memref" ref: Annotated[Operand, MemRefType[AnyNumericType]] @@ -630,12 +666,10 @@ def get(ref: SSAValue | Operation) -> UnwrapMemrefOp: assert isinstance(ssa_val.typ, MemRefType) elem_typ = cast(MemRefType[AnyNumericType], ssa_val.typ).element_type - return UnwrapMemrefOp.build(operands=[ref], - result_types=[ - llvm.LLVMPointerType.typed(elem_typ), - i32, - DataType() - ]) + return UnwrapMemrefOp.build( + operands=[ref], + result_types=[llvm.LLVMPointerType.typed(elem_typ), i32, DataType()], + ) @irdl_op_definition @@ -650,6 +684,7 @@ class GetDtypeOp(MPIBaseOp): to get the magic constant. See `_MPIToLLVMRewriteBase._translate_to_mpi_type` docstring for more detail on which types are supported. """ + name = "mpi.get_dtype" dtype: OpAttr[Attribute] @@ -658,8 +693,7 @@ class GetDtypeOp(MPIBaseOp): @staticmethod def get(typ: Attribute): - return GetDtypeOp.build(result_types=[DataType()], - attributes={'dtype': typ}) + return GetDtypeOp.build(result_types=[DataType()], attributes={"dtype": typ}) @irdl_op_definition @@ -673,6 +707,7 @@ class AllocateTypeOp(MPIBaseOp): number of elements and an optional bindc_name which contains the name of the variable that this is allocating """ + name = "mpi.allocate" bindc_name: OptOpAttr[StringAttr] @@ -687,12 +722,14 @@ def get( count: SSAValue | Operation, bindc_name: StringAttr | None = None, ) -> AllocateTypeOp: - return AllocateTypeOp.build(result_types=[VectorType.of(dtype)], - attributes={ - "dtype": dtype(), - "bindc_name": bindc_name, - }, - operands=[count]) + return AllocateTypeOp.build( + result_types=[VectorType.of(dtype)], + attributes={ + "dtype": dtype(), + "bindc_name": bindc_name, + }, + operands=[count], + ) @irdl_op_definition @@ -701,6 +738,7 @@ class VectorGetOp(MPIBaseOp): This op will retrieve an element of an MPI vector, it accepts the vector as an argument and the element index """ + name = "mpi.vector_get" vect: Annotated[Operand, VectorType] @@ -709,13 +747,13 @@ class VectorGetOp(MPIBaseOp): result: Annotated[OpResult, VectorWrappable] @staticmethod - def get(vect: SSAValue | Operation, - element: SSAValue | Operation) -> VectorGetOp: + def get(vect: SSAValue | Operation, element: SSAValue | Operation) -> VectorGetOp: ssa_val = SSAValue.get(vect) assert isa(ssa_val.typ, VectorType[VectorWrappable]) - return VectorGetOp.build(result_types=[ssa_val.typ.wrapped_type], - operands=[vect, element]) + return VectorGetOp.build( + result_types=[ssa_val.typ.wrapped_type], operands=[vect, element] + ) @irdl_op_definition @@ -727,6 +765,7 @@ class NullRequestOp(MPIBaseOp): Due to restrictions in the current MPI dialect, we can't return a new request object here. That will be fixed soon though! """ + name = "mpi.request_null" request: Annotated[Operand, RequestType] @@ -736,28 +775,31 @@ def get(req: SSAValue | Operation): return NullRequestOp.build(operands=[req]) -MPI = Dialect([ - Isend, - Irecv, - Test, - Recv, - Send, - Reduce, - Allreduce, - Bcast, - Wait, - GetStatusField, - Init, - Finalize, - CommRank, - UnwrapMemrefOp, - GetDtypeOp, - AllocateTypeOp, - VectorGetOp, -], [ - OperationType, - RequestType, - StatusType, - DataType, - VectorType, -]) +MPI = Dialect( + [ + Isend, + Irecv, + Test, + Recv, + Send, + Reduce, + Allreduce, + Bcast, + Wait, + GetStatusField, + Init, + Finalize, + CommRank, + UnwrapMemrefOp, + GetDtypeOp, + AllocateTypeOp, + VectorGetOp, + ], + [ + OperationType, + RequestType, + StatusType, + DataType, + VectorType, + ], +) diff --git a/xdsl/dialects/pdl.py b/xdsl/dialects/pdl.py index 95cab82511..cc82122750 100644 --- a/xdsl/dialects/pdl.py +++ b/xdsl/dialects/pdl.py @@ -2,14 +2,31 @@ from typing import Annotated, Generic, Sequence, TypeVar -from xdsl.dialects.builtin import (ArrayAttr, IntegerAttr, IntegerType, - StringAttr) -from xdsl.ir import (Attribute, Block, Dialect, TypeAttribute, OpResult, - ParametrizedAttribute, Region, SSAValue) -from xdsl.irdl import (AttrSizedOperandSegments, OpAttr, Operand, OptOpAttr, - OptOperand, OptRegion, ParameterDef, VarOpResult, - VarOperand, irdl_attr_definition, irdl_op_definition, - IRDLOperation) +from xdsl.dialects.builtin import ArrayAttr, IntegerAttr, IntegerType, StringAttr +from xdsl.ir import ( + Attribute, + Block, + Dialect, + TypeAttribute, + OpResult, + ParametrizedAttribute, + Region, + SSAValue, +) +from xdsl.irdl import ( + AttrSizedOperandSegments, + OpAttr, + Operand, + OptOpAttr, + OptOperand, + OptRegion, + ParameterDef, + VarOpResult, + VarOperand, + irdl_attr_definition, + irdl_op_definition, + IRDLOperation, +) from xdsl.utils.exceptions import VerifyException from xdsl.utils.hints import isa @@ -36,9 +53,11 @@ class ValueType(ParametrizedAttribute, TypeAttribute): AnyPDLType = AttributeType | OperationType | TypeType | ValueType -_RangeT = TypeVar("_RangeT", - bound=AttributeType | OperationType | TypeType | ValueType, - covariant=True) +_RangeT = TypeVar( + "_RangeT", + bound=AttributeType | OperationType | TypeType | ValueType, + covariant=True, +) @irdl_attr_definition @@ -52,26 +71,24 @@ class ApplyNativeConstraintOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdlapply_native_constraint-mlirpdlapplynativeconstraintop """ + name: str = "pdl.apply_native_constraint" # https://github.com/xdslproject/xdsl/issues/98 # name: OpAttr[StringAttr] args: Annotated[VarOperand, AnyPDLType] def verify_(self) -> None: - if 'name' not in self.attributes: - raise VerifyException( - "ApplyNativeConstraintOp requires a 'name' attribute") + if "name" not in self.attributes: + raise VerifyException("ApplyNativeConstraintOp requires a 'name' attribute") - if not isinstance(self.attributes['name'], StringAttr): - raise VerifyException( - "expected 'name' attribute to be a StringAttr") + if not isinstance(self.attributes["name"], StringAttr): + raise VerifyException("expected 'name' attribute to be a StringAttr") @staticmethod def get(name: str, args: Sequence[SSAValue]) -> ApplyNativeConstraintOp: return ApplyNativeConstraintOp.build( - result_types=[], - operands=[args], - attributes={"name": StringAttr(name)}) + result_types=[], operands=[args], attributes={"name": StringAttr(name)} + ) @irdl_op_definition @@ -79,6 +96,7 @@ class ApplyNativeRewriteOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdlapply_native_rewrite-mlirpdlapplynativerewriteop """ + name: str = "pdl.apply_native_rewrite" # https://github.com/xdslproject/xdsl/issues/98 # name: OpAttr[StringAttr] @@ -86,22 +104,21 @@ class ApplyNativeRewriteOp(IRDLOperation): res: Annotated[VarOpResult, AnyPDLType] def verify_(self) -> None: - if 'name' not in self.attributes: - raise VerifyException( - "ApplyNativeRewriteOp requires a 'name' attribute") + if "name" not in self.attributes: + raise VerifyException("ApplyNativeRewriteOp requires a 'name' attribute") - if not isinstance(self.attributes['name'], StringAttr): - raise VerifyException( - "expected 'name' attribute to be a StringAttr") + if not isinstance(self.attributes["name"], StringAttr): + raise VerifyException("expected 'name' attribute to be a StringAttr") @staticmethod - def get(name: str, args: Sequence[SSAValue], - result_types: Sequence[Attribute]) -> ApplyNativeRewriteOp: - + def get( + name: str, args: Sequence[SSAValue], result_types: Sequence[Attribute] + ) -> ApplyNativeRewriteOp: return ApplyNativeRewriteOp.build( result_types=[result_types], operands=[args], - attributes={"name": StringAttr(name)}) + attributes={"name": StringAttr(name)}, + ) @irdl_op_definition @@ -109,26 +126,28 @@ class AttributeOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdlattribute-mlirpdlattributeop """ + name: str = "pdl.attribute" value: OptOpAttr[Attribute] valueType: Annotated[OptOperand, TypeType] output: Annotated[OpResult, AttributeType] @staticmethod - def get(value: Attribute | None = None, - valueType: SSAValue | None = None) -> AttributeOp: + def get( + value: Attribute | None = None, valueType: SSAValue | None = None + ) -> AttributeOp: attributes: dict[str, Attribute] = {} if value is not None: - attributes['value'] = value + attributes["value"] = value if valueType is None: value_type = [] else: value_type = [valueType] - return AttributeOp.build(operands=[value_type], - attributes=attributes, - result_types=[AttributeType()]) + return AttributeOp.build( + operands=[value_type], attributes=attributes, result_types=[AttributeType()] + ) @irdl_op_definition @@ -136,6 +155,7 @@ class EraseOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdlerase-mlirpdleraseop """ + name: str = "pdl.erase" opValue: Annotated[Operand, OperationType] @@ -149,6 +169,7 @@ class OperandOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdloperand-mlirpdloperandop """ + name: str = "pdl.operand" valueType: Annotated[OptOperand, TypeType] value: Annotated[OpResult, ValueType] @@ -159,8 +180,7 @@ def get(valueType: SSAValue | None = None) -> OperandOp: value_type = [] else: value_type = [valueType] - return OperandOp.build(operands=[value_type], - result_types=[ValueType()]) + return OperandOp.build(operands=[value_type], result_types=[ValueType()]) @irdl_op_definition @@ -168,6 +188,7 @@ class OperandsOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdloperands-mlirpdloperandsop """ + name: str = "pdl.operands" valueType: Annotated[Operand, RangeType[TypeType]] value: Annotated[OpResult, RangeType[ValueType]] @@ -178,6 +199,7 @@ class OperationOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdloperation-mlirpdloperationop """ + name: str = "pdl.operation" opName: OptOpAttr[StringAttr] attributeValueNames: OpAttr[ArrayAttr[StringAttr]] @@ -190,11 +212,13 @@ class OperationOp(IRDLOperation): irdl_options = [AttrSizedOperandSegments()] @staticmethod - def get(opName: StringAttr | None, - attributeValueNames: ArrayAttr[StringAttr] | None = None, - operandValues: Sequence[SSAValue] | None = None, - attributeValues: Sequence[SSAValue] | None = None, - typeValues: Sequence[SSAValue] | None = None): + def get( + opName: StringAttr | None, + attributeValueNames: ArrayAttr[StringAttr] | None = None, + operandValues: Sequence[SSAValue] | None = None, + attributeValues: Sequence[SSAValue] | None = None, + typeValues: Sequence[SSAValue] | None = None, + ): if attributeValueNames is None: attributeValueNames = ArrayAttr([]) if operandValues is None: @@ -207,10 +231,8 @@ def get(opName: StringAttr | None, return OperationOp.build( operands=[operandValues, attributeValues, typeValues], result_types=[OperationType()], - attributes={ - "attributeValueNames": attributeValueNames, - "opName": opName - }) + attributes={"attributeValueNames": attributeValueNames, "opName": opName}, + ) @irdl_op_definition @@ -218,25 +240,31 @@ class PatternOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdlpattern-mlirpdlpatternop """ + name: str = "pdl.pattern" benefit: OpAttr[IntegerAttr[IntegerType]] sym_name: OptOpAttr[StringAttr] body: Region @staticmethod - def get(benefit: IntegerAttr[IntegerType], sym_name: StringAttr | None, - body: Region) -> PatternOp: - return PatternOp.build(attributes={ - "benefit": benefit, - "sym_name": sym_name, - }, - regions=[body], - result_types=[]) + def get( + benefit: IntegerAttr[IntegerType], sym_name: StringAttr | None, body: Region + ) -> PatternOp: + return PatternOp.build( + attributes={ + "benefit": benefit, + "sym_name": sym_name, + }, + regions=[body], + result_types=[], + ) @staticmethod - def from_callable(benefit: IntegerAttr[IntegerType], - sym_name: StringAttr | None, - callable: Block.BlockCallback) -> PatternOp: + def from_callable( + benefit: IntegerAttr[IntegerType], + sym_name: StringAttr | None, + callable: Block.BlockCallback, + ) -> PatternOp: block = Block.from_callable([], callable) region = Region(block) return PatternOp.get(benefit, sym_name, region) @@ -247,12 +275,12 @@ class RangeOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdlrange-mlirpdlrangeop """ + name: str = "pdl.range" arguments: Annotated[VarOperand, AnyPDLType | RangeType[AnyPDLType]] result: Annotated[OpResult, RangeType[AnyPDLType]] def verify_(self) -> None: - def get_type_or_elem_type(arg: SSAValue) -> Attribute: if isa(arg.typ, RangeType[AnyPDLType]): return arg.typ.elementType @@ -267,7 +295,8 @@ def get_type_or_elem_type(arg: SSAValue) -> Attribute: raise VerifyException( f"All arguments must have the same type or be an array \ of the corresponding element type. First element type:\ - {elem_type}, current element type: {cur_elem_type}") + {elem_type}, current element type: {cur_elem_type}" + ) @irdl_op_definition @@ -284,6 +313,7 @@ class ReplaceOp(IRDLOperation): * a set of `Value`s (`replValues` should be populated) - The operation will be replaced with these values. """ + name: str = "pdl.replace" opValue: Annotated[Operand, OperationType] replOperation: Annotated[OptOperand, OperationType] @@ -292,9 +322,11 @@ class ReplaceOp(IRDLOperation): irdl_options = [AttrSizedOperandSegments()] @staticmethod - def get(opValue: SSAValue, - replOperation: SSAValue | None = None, - replValues: Sequence[SSAValue] | None = None) -> ReplaceOp: + def get( + opValue: SSAValue, + replOperation: SSAValue | None = None, + replValues: Sequence[SSAValue] | None = None, + ) -> ReplaceOp: operands: list[SSAValue | Sequence[SSAValue]] = [opValue] if replOperation is None: operands.append([]) @@ -308,13 +340,16 @@ def get(opValue: SSAValue, def verify_(self) -> None: if self.replOperation is None: if not len(self.replValues): - raise VerifyException("Exactly one of `replOperation` or " - "`replValues` must be set in `ReplaceOp`" - ", both are empty") + raise VerifyException( + "Exactly one of `replOperation` or " + "`replValues` must be set in `ReplaceOp`" + ", both are empty" + ) elif len(self.replValues): raise VerifyException( "Exactly one of `replOperation` or `replValues` must be set in " - "`ReplaceOp`, both are set") + "`ReplaceOp`, both are set" + ) @irdl_op_definition @@ -322,6 +357,7 @@ class ResultOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdlresult-mlirpdlresultop """ + name: str = "pdl.result" index: OpAttr[IntegerAttr[IntegerType]] parent_: Annotated[Operand, OperationType] @@ -329,9 +365,9 @@ class ResultOp(IRDLOperation): @staticmethod def get(index: IntegerAttr[IntegerType], parent: SSAValue) -> ResultOp: - return ResultOp.build(operands=[parent], - attributes={'index': index}, - result_types=[ValueType()]) + return ResultOp.build( + operands=[parent], attributes={"index": index}, result_types=[ValueType()] + ) @irdl_op_definition @@ -339,6 +375,7 @@ class ResultsOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdlresults-mlirpdlresultsop """ + name: str = "pdl.results" index: OpAttr[IntegerAttr[IntegerType]] parent_: Annotated[Operand, OperationType] @@ -350,6 +387,7 @@ class RewriteOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdlrewrite-mlirpdlrewriteop """ + name: str = "pdl.rewrite" root: Annotated[OptOperand, OperationType] # name of external rewriter function @@ -363,15 +401,17 @@ class RewriteOp(IRDLOperation): irdl_options = [AttrSizedOperandSegments()] def verify_(self) -> None: - if 'name' in self.attributes: - if not isinstance(self.attributes['name'], StringAttr): + if "name" in self.attributes: + if not isinstance(self.attributes["name"], StringAttr): raise Exception("expected 'name' attribute to be a StringAttr") @staticmethod - def get(name: StringAttr | None, root: SSAValue | None, - external_args: Sequence[SSAValue], - body: Region | None) -> RewriteOp: - + def get( + name: StringAttr | None, + root: SSAValue | None, + external_args: Sequence[SSAValue], + body: Region | None, + ) -> RewriteOp: operands: list[SSAValue | Sequence[SSAValue]] = [] if root is not None: operands.append([root]) @@ -387,17 +427,19 @@ def get(name: StringAttr | None, root: SSAValue | None, attributes: dict[str, Attribute] = {} if name is not None: - attributes['name'] = name + attributes["name"] = name - return RewriteOp.build(result_types=[], - operands=operands, - attributes=attributes, - regions=regions) + return RewriteOp.build( + result_types=[], operands=operands, attributes=attributes, regions=regions + ) @staticmethod - def from_callable(name: StringAttr | None, root: SSAValue | None, - external_args: Sequence[SSAValue], - body: Block.BlockCallback) -> RewriteOp: + def from_callable( + name: StringAttr | None, + root: SSAValue | None, + external_args: Sequence[SSAValue], + body: Block.BlockCallback, + ) -> RewriteOp: block = Block.from_callable([], body) region = Region(block) return RewriteOp.get(name, root, external_args, region) @@ -408,14 +450,16 @@ class TypeOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdltype-mlirpdltypeop """ + name: str = "pdl.type" constantType: OptOpAttr[Attribute] result: Annotated[OpResult, TypeType] @staticmethod def get(constantType: TypeType | None = None) -> TypeOp: - return TypeOp.build(attributes={"constantType": constantType}, - result_types=[TypeType()]) + return TypeOp.build( + attributes={"constantType": constantType}, result_types=[TypeType()] + ) @irdl_op_definition @@ -423,31 +467,35 @@ class TypesOp(IRDLOperation): """ https://mlir.llvm.org/docs/Dialects/PDLOps/#pdltypes-mlirpdltypesop """ + name: str = "pdl.types" constantTypes: Annotated[OptOperand, ArrayAttr[TypeType]] result: Annotated[OpResult, RangeType[TypeType]] -PDL = Dialect([ - ApplyNativeConstraintOp, - ApplyNativeRewriteOp, - AttributeOp, - OperandOp, - EraseOp, - OperandsOp, - OperationOp, - PatternOp, - RangeOp, - ReplaceOp, - ResultOp, - ResultsOp, - RewriteOp, - TypeOp, - TypesOp, -], [ - AttributeType, - OperationType, - TypeType, - ValueType, - RangeType, -]) +PDL = Dialect( + [ + ApplyNativeConstraintOp, + ApplyNativeRewriteOp, + AttributeOp, + OperandOp, + EraseOp, + OperandsOp, + OperationOp, + PatternOp, + RangeOp, + ReplaceOp, + ResultOp, + ResultsOp, + RewriteOp, + TypeOp, + TypesOp, + ], + [ + AttributeType, + OperationType, + TypeType, + ValueType, + RangeType, + ], +) diff --git a/xdsl/dialects/scf.py b/xdsl/dialects/scf.py index c91d16f579..15fc3355e6 100644 --- a/xdsl/dialects/scf.py +++ b/xdsl/dialects/scf.py @@ -4,9 +4,16 @@ from xdsl.dialects.builtin import IndexType, IntegerType from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue -from xdsl.irdl import (AnyAttr, AttrSizedOperandSegments, Operand, - SingleBlockRegion, VarOperand, VarOpResult, - irdl_op_definition, IRDLOperation) +from xdsl.irdl import ( + AnyAttr, + AttrSizedOperandSegments, + Operand, + SingleBlockRegion, + VarOperand, + VarOpResult, + irdl_op_definition, + IRDLOperation, +) from xdsl.utils.exceptions import VerifyException @@ -25,10 +32,8 @@ def get( cond: SSAValue | Operation, return_types: Sequence[Attribute], true_region: Region | Sequence[Block] | Sequence[Operation], - false_region: Region | Sequence[Block] | Sequence[Operation] - | None = None + false_region: Region | Sequence[Block] | Sequence[Operation] | None = None, ) -> If: - if false_region is None: false_region = Region() @@ -46,8 +51,7 @@ class Yield(IRDLOperation): @staticmethod def get(*operands: SSAValue | Operation) -> Yield: - return Yield.create( - operands=[SSAValue.get(operand) for operand in operands]) + return Yield.create(operands=[SSAValue.get(operand) for operand in operands]) @irdl_op_definition @@ -57,10 +61,8 @@ class Condition(IRDLOperation): arguments: Annotated[VarOperand, AnyAttr()] @staticmethod - def get(cond: SSAValue | Operation, - *output_ops: SSAValue | Operation) -> Condition: - return Condition.build( - operands=[cond, [output for output in output_ops]]) + def get(cond: SSAValue | Operation, *output_ops: SSAValue | Operation) -> Condition: + return Condition.build(operands=[cond, [output for output in output_ops]]) @irdl_op_definition @@ -82,27 +84,31 @@ def verify_(self): raise VerifyException( f"Wrong number of block arguments, expected {len(self.iter_args)+1}, got " f"{len(self.body.block.args)}. The body must have the induction " - f"variable and loop-carried variables as arguments.") + f"variable and loop-carried variables as arguments." + ) for idx, arg in enumerate(self.iter_args): if self.body.block.args[idx + 1].typ != arg.typ: raise VerifyException( f"Block arguments with wrong type, expected {arg.typ}, " f"got {self.body.block.args[idx].typ}. Arguments after the " - f"induction variable must match the carried variables.") + f"induction variable must match the carried variables." + ) if len(self.iter_args) > 0: - if (len(self.body.ops) == 0 - or not isinstance(self.body.block.last_op, Yield)): + if len(self.body.ops) == 0 or not isinstance( + self.body.block.last_op, Yield + ): raise VerifyException( "The scf.for's body does not end with a scf.yield. A scf.for loop " "with loop-carried variables must yield their values at the end of " - "its body.") - if (len(self.body.ops) > 0 - and isinstance(self.body.block.last_op, Yield)): + "its body." + ) + if len(self.body.ops) > 0 and isinstance(self.body.block.last_op, Yield): yieldop = self.body.block.last_op if len(yieldop.arguments) != len(self.iter_args): raise VerifyException( f"Expected {len(self.iter_args)} args, got {len(yieldop.arguments)}. " - f"The scf.for must yield its carried variables.") + f"The scf.for must yield its carried variables." + ) for idx, arg in enumerate(yieldop.arguments): if self.iter_args[idx].typ != arg.typ: raise VerifyException( @@ -148,25 +154,31 @@ def get( steps: Sequence[SSAValue | Operation], body: Region | list[Block] | list[Operation], ): - return ParallelOp.build(operands=[lowerBounds, upperBounds, steps, []], - regions=[body], - result_types=[[]]) + return ParallelOp.build( + operands=[lowerBounds, upperBounds, steps, []], + regions=[body], + result_types=[[]], + ) def verify_(self) -> None: - if len(self.lowerBound) != len(self.upperBound) or len( - self.lowerBound) != len(self.step): + if len(self.lowerBound) != len(self.upperBound) or len(self.lowerBound) != len( + self.step + ): raise VerifyException( "Expected the same number of lower bounds, upper " "bounds, and steps for scf.parallel. Got " f"{len(self.lowerBound)}, {len(self.upperBound)} and " - f"{len(self.step)}.") + f"{len(self.step)}." + ) body_args = self.body.block.args if len(self.body.blocks) != 0 else () if len(self.lowerBound) != len(body_args) or not all( - [isinstance(a.typ, IndexType) for a in body_args]): + [isinstance(a.typ, IndexType) for a in body_args] + ): raise VerifyException( f"Expected {len(self.lowerBound)} index-typed region arguments, got " f"{[str(a.typ) for a in body_args]}. scf.parallel's body must have an index " - "argument for each induction variable. ") + "argument for each induction variable. " + ) if len(self.initVals) != 0 or len(self.res) != 0: raise VerifyException( "scf.parallel loop-carried variables and reduction are not implemented yet." @@ -188,22 +200,26 @@ def verify_(self): if self.before_region.block.args[idx].typ != arg.typ: raise Exception( f"Block arguments with wrong type, expected {arg.typ}, " - f"got {self.before_region.block.args[idx].typ}") + f"got {self.before_region.block.args[idx].typ}" + ) for idx, res in enumerate(self.res): if self.after_region.block.args[idx].typ != res.typ: raise Exception( f"Block arguments with wrong type, expected {res.typ}, " - f"got {self.after_region.block.args[idx].typ}") + f"got {self.after_region.block.args[idx].typ}" + ) @staticmethod - def get(operands: List[SSAValue | Operation], - result_types: List[Attribute], - before: Region | List[Operation] | List[Block], - after: Region | List[Operation] | List[Block]) -> While: - op = While.build(operands=operands, - result_types=result_types, - regions=[before, after]) + def get( + operands: List[SSAValue | Operation], + result_types: List[Attribute], + before: Region | List[Operation] | List[Block], + after: Region | List[Operation] | List[Block], + ) -> While: + op = While.build( + operands=operands, result_types=result_types, regions=[before, after] + ) return op diff --git a/xdsl/dialects/test.py b/xdsl/dialects/test.py index 46d10b029c..ca1a0e4c46 100644 --- a/xdsl/dialects/test.py +++ b/xdsl/dialects/test.py @@ -1,8 +1,14 @@ from __future__ import annotations from xdsl.ir import Data, Dialect, TypeAttribute -from xdsl.irdl import (VarOpResult, VarOperand, VarRegion, - irdl_attr_definition, irdl_op_definition, IRDLOperation) +from xdsl.irdl import ( + VarOpResult, + VarOperand, + VarRegion, + irdl_attr_definition, + irdl_op_definition, + IRDLOperation, +) from xdsl.parser import BaseParser from xdsl.printer import Printer @@ -15,6 +21,7 @@ class TestOp(IRDLOperation): on other dialects (i.e. dependencies that only come from the structure of the test rather than the actual dialect). """ + name: str = "test.op" res: VarOpResult @@ -29,6 +36,7 @@ class TestType(Data[str], TypeAttribute): used. This allows reducing the artificial dependencies on attributes from other dialects. """ + name: str = "test.type" @staticmethod diff --git a/xdsl/dialects/vector.py b/xdsl/dialects/vector.py index 79163136e0..a4baccee70 100644 --- a/xdsl/dialects/vector.py +++ b/xdsl/dialects/vector.py @@ -1,10 +1,14 @@ from __future__ import annotations from typing import Annotated, List -from xdsl.dialects.builtin import (IndexType, VectorType, i1, - VectorRankConstraint, - VectorBaseTypeConstraint, - VectorBaseTypeAndRankConstraint) +from xdsl.dialects.builtin import ( + IndexType, + VectorType, + i1, + VectorRankConstraint, + VectorBaseTypeConstraint, + VectorBaseTypeAndRankConstraint, +) from xdsl.dialects.memref import MemRefType from xdsl.ir import Attribute, Operation, SSAValue, Dialect, OpResult from xdsl.irdl import AnyAttr, irdl_op_definition, Operand, VarOperand, IRDLOperation @@ -25,22 +29,23 @@ def verify_(self): if self.memref.typ.element_type != self.res.typ.element_type: raise VerifyException( - "MemRef element type should match the Vector element type.") + "MemRef element type should match the Vector element type." + ) if self.memref.typ.get_num_dims() != len(self.indices): raise VerifyException("Expected an index for each dimension.") @staticmethod - def get(ref: SSAValue | Operation, - indices: List[SSAValue | Operation]) -> Load: + def get(ref: SSAValue | Operation, indices: List[SSAValue | Operation]) -> Load: ref = SSAValue.get(ref) assert assert_isa(ref.typ, MemRefType[Attribute]) - return Load.build(operands=[ref, indices], - result_types=[ - VectorType.from_element_type_and_shape( - ref.typ.element_type, [1]) - ]) + return Load.build( + operands=[ref, indices], + result_types=[ + VectorType.from_element_type_and_shape(ref.typ.element_type, [1]) + ], + ) @irdl_op_definition @@ -56,14 +61,18 @@ def verify_(self): if self.memref.typ.element_type != self.vector.typ.element_type: raise VerifyException( - "MemRef element type should match the Vector element type.") + "MemRef element type should match the Vector element type." + ) if self.memref.typ.get_num_dims() != len(self.indices): raise VerifyException("Expected an index for each dimension.") @staticmethod - def get(vector: Operation | SSAValue, ref: Operation | SSAValue, - indices: List[Operation | SSAValue]) -> Store: + def get( + vector: Operation | SSAValue, + ref: Operation | SSAValue, + indices: List[Operation | SSAValue], + ) -> Store: return Store.build(operands=[vector, ref, indices]) @@ -83,11 +92,12 @@ def verify_(self): @staticmethod def get(source: Operation | SSAValue) -> Broadcast: - return Broadcast.build(operands=[source], - result_types=[ - VectorType.from_element_type_and_shape( - SSAValue.get(source).typ, [1]) - ]) + return Broadcast.build( + operands=[source], + result_types=[ + VectorType.from_element_type_and_shape(SSAValue.get(source).typ, [1]) + ], + ) @irdl_op_definition @@ -136,16 +146,18 @@ def verify_(self): ) @staticmethod - def get(lhs: Operation | SSAValue, rhs: Operation | SSAValue, - acc: Operation | SSAValue) -> FMA: + def get( + lhs: Operation | SSAValue, rhs: Operation | SSAValue, acc: Operation | SSAValue + ) -> FMA: lhs = SSAValue.get(lhs) assert assert_isa(lhs.typ, VectorType[Attribute]) - return FMA.build(operands=[lhs, rhs, acc], - result_types=[ - VectorType.from_element_type_and_shape( - lhs.typ.element_type, [1]) - ]) + return FMA.build( + operands=[lhs, rhs, acc], + result_types=[ + VectorType.from_element_type_and_shape(lhs.typ.element_type, [1]) + ], + ) @irdl_op_definition @@ -182,21 +194,24 @@ def verify_(self): ) if memref_typ.get_num_dims() != len(self.indices): - raise VerifyException( - "Expected an index for each memref dimension.") + raise VerifyException("Expected an index for each memref dimension.") @staticmethod - def get(memref: SSAValue | Operation, indices: List[SSAValue | Operation], - mask: SSAValue | Operation, - passthrough: SSAValue | Operation) -> Maskedload: + def get( + memref: SSAValue | Operation, + indices: List[SSAValue | Operation], + mask: SSAValue | Operation, + passthrough: SSAValue | Operation, + ) -> Maskedload: memref = SSAValue.get(memref) assert assert_isa(memref.typ, MemRefType[Attribute]) - return Maskedload.build(operands=[memref, indices, mask, passthrough], - result_types=[ - VectorType.from_element_type_and_shape( - memref.typ.element_type, [1]) - ]) + return Maskedload.build( + operands=[memref, indices, mask, passthrough], + result_types=[ + VectorType.from_element_type_and_shape(memref.typ.element_type, [1]) + ], + ) @irdl_op_definition @@ -221,19 +236,24 @@ def verify_(self): if memref_element_type != value_to_store_typ.element_type: raise VerifyException( "MemRef element type should match the stored vector type. " - "Obtained types were " + str(memref_element_type) + " and " + - str(value_to_store_typ.element_type) + ".") + "Obtained types were " + + str(memref_element_type) + + " and " + + str(value_to_store_typ.element_type) + + "." + ) if memref_typ.get_num_dims() != len(self.indices): - raise VerifyException( - "Expected an index for each memref dimension.") + raise VerifyException("Expected an index for each memref dimension.") @staticmethod - def get(memref: SSAValue | Operation, indices: List[SSAValue | Operation], - mask: SSAValue | Operation, - value_to_store: SSAValue | Operation) -> Maskedstore: - return Maskedstore.build( - operands=[memref, indices, mask, value_to_store]) + def get( + memref: SSAValue | Operation, + indices: List[SSAValue | Operation], + mask: SSAValue | Operation, + value_to_store: SSAValue | Operation, + ) -> Maskedstore: + return Maskedstore.build(operands=[memref, indices, mask, value_to_store]) @irdl_op_definition @@ -263,9 +283,10 @@ def verify_(self): def get(mask_operands: list[Operation | SSAValue]) -> Createmask: return Createmask.build( operands=[mask_operands], - result_types=[VectorType.from_element_type_and_shape(i1, [1])]) + result_types=[VectorType.from_element_type_and_shape(i1, [1])], + ) Vector = Dialect( - [Load, Store, Broadcast, FMA, Maskedload, Maskedstore, Print, Createmask], - []) + [Load, Store, Broadcast, FMA, Maskedload, Maskedstore, Print, Createmask], [] +) diff --git a/xdsl/frontend/block.py b/xdsl/frontend/block.py index a460e9f551..7414279750 100644 --- a/xdsl/frontend/block.py +++ b/xdsl/frontend/block.py @@ -11,11 +11,11 @@ def foo(a: int) -> int: def bb0(x: int): y: int = x + 2 bb1(y) - + @block def bb1(z: int): return z - + # Entry-point. bb0(a) ``` @@ -28,5 +28,8 @@ def decorate(*params: Any): def is_block(node: ast.FunctionDef) -> bool: - return len(node.decorator_list) == 1 and isinstance( - name := node.decorator_list[0], ast.Name) and name.id == "block" + return ( + len(node.decorator_list) == 1 + and isinstance(name := node.decorator_list[0], ast.Name) + and name.id == "block" + ) diff --git a/xdsl/frontend/code_generation.py b/xdsl/frontend/code_generation.py index bf838b9e15..178bc411de 100644 --- a/xdsl/frontend/code_generation.py +++ b/xdsl/frontend/code_generation.py @@ -14,11 +14,10 @@ @dataclass class CodeGeneration: - @staticmethod - def run_with_type_converter(type_converter: TypeConverter, - stmts: List[ast.stmt], - file: str | None) -> builtin.ModuleOp: + def run_with_type_converter( + type_converter: TypeConverter, stmts: List[ast.stmt], file: str | None + ) -> builtin.ModuleOp: """Generates xDSL code and returns it encapsulated into a single module.""" module = builtin.ModuleOp([]) visitor = CodegGenerationVisitor(type_converter, module, file) @@ -52,8 +51,9 @@ class CodegGenerationVisitor(ast.NodeVisitor): file: str | None """Path of the file containing the program being processed.""" - def __init__(self, type_converter: TypeConverter, module: builtin.ModuleOp, - file: str | None) -> None: + def __init__( + self, type_converter: TypeConverter, module: builtin.ModuleOp, file: str | None + ) -> None: self.type_converter = type_converter self.globals = type_converter.globals self.file = file @@ -65,8 +65,11 @@ def get_symbol(self, node: ast.Name) -> Attribute: assert self.symbol_table is not None if node.id not in self.symbol_table: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, - f"Symbol '{node.id}' is not defined.") + self.file, + node.lineno, + node.col_offset, + f"Symbol '{node.id}' is not defined.", + ) return self.symbol_table[node.id] def visit(self, node: ast.AST) -> None: @@ -74,8 +77,11 @@ def visit(self, node: ast.AST) -> None: def generic_visit(self, node: ast.AST) -> None: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, - f"Unsupported Python AST node {str(node)}") + self.file, + node.lineno, + node.col_offset, + f"Unsupported Python AST node {str(node)}", + ) def visit_AnnAssign(self, node: ast.AnnAssign) -> None: # TODO: Implement assignemnt in the next patch. @@ -102,13 +108,16 @@ def visit_BinOp(self, node: ast.BinOp): "BitOr": "__or__", "BitXor": "__xor__", "BitAnd": "__and__", - "MatMult": "__matmul__" + "MatMult": "__matmul__", } if op_name not in python_AST_operator_to_python_overload: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, - f"Unexpected binary operation {op_name}.") + self.file, + node.lineno, + node.col_offset, + f"Unexpected binary operation {op_name}.", + ) # Check that the types of the operands are the same. # This is a (temporary?) restriction over Python for implementation simplicity. @@ -120,34 +129,40 @@ def visit_BinOp(self, node: ast.BinOp): lhs = self.inserter.get_operand() if lhs.typ != rhs.typ: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, + self.file, + node.lineno, + node.col_offset, f"Expected the same types for binary operation '{op_name}', " - f"but got {lhs.typ} and {rhs.typ}.") + f"but got {lhs.typ} and {rhs.typ}.", + ) # Look-up what is the frontend type we deal with to resolve the binary # operation. - frontend_type = self.type_converter.xdsl_to_frontend_type_map[ - lhs.typ.__class__] + frontend_type = self.type_converter.xdsl_to_frontend_type_map[lhs.typ.__class__] overload_name = python_AST_operator_to_python_overload[op_name] try: - op = OpResolver.resolve_op_overload(overload_name, - frontend_type)(lhs, rhs) + op = OpResolver.resolve_op_overload(overload_name, frontend_type)(lhs, rhs) self.inserter.insert_op(op) except FrontendProgramException: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, + self.file, + node.lineno, + node.col_offset, f"Binary operation '{op_name}' " f"is not supported by type '{frontend_type.__name__}' " - f"which does not overload '{overload_name}'.") + f"which does not overload '{overload_name}'.", + ) def visit_Compare(self, node: ast.Compare): # Allow a single comparison only. if len(node.comparators) != 1 or len(node.ops) != 1: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, - "Expected a single comparator, but found " - f"{len(node.comparators)}.") + self.file, + node.lineno, + node.col_offset, + "Expected a single comparator, but found " f"{len(node.comparators)}.", + ) comp = node.comparators[0] op_name: str = node.ops[0].__class__.__name__ @@ -160,7 +175,7 @@ def visit_Compare(self, node: ast.Compare): "LtE": "__le__", "NotEq": "__ne__", "In": "__contains__", - "NotIn": "__contains__" + "NotIn": "__contains__", } # Table with currently unsupported Python AST cmpops. @@ -173,8 +188,11 @@ def visit_Compare(self, node: ast.Compare): if op_name in unsupported_python_AST_cmpop: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, - f"Unsupported comparison operation '{op_name}'.") + self.file, + node.lineno, + node.col_offset, + f"Unsupported comparison operation '{op_name}'.", + ) # Check that the types of the operands are the same. # This is a (temporary?) restriction over Python for implementation simplicity. @@ -186,23 +204,28 @@ def visit_Compare(self, node: ast.Compare): lhs = self.inserter.get_operand() if lhs.typ != rhs.typ: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, + self.file, + node.lineno, + node.col_offset, f"Expected the same types for comparison operator '{op_name}'," - f" but got {lhs.typ} and {rhs.typ}.") + f" but got {lhs.typ} and {rhs.typ}.", + ) # Resolve the comparison operation to an xdsl operation class python_op = python_AST_cmpop_to_python_overload[op_name] - frontend_type = self.type_converter.xdsl_to_frontend_type_map[ - lhs.typ.__class__] + frontend_type = self.type_converter.xdsl_to_frontend_type_map[lhs.typ.__class__] try: op = OpResolver.resolve_op_overload(python_op, frontend_type) except FrontendProgramException: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, + self.file, + node.lineno, + node.col_offset, f"Comparison operation '{op_name}' " f"is not supported by type '{frontend_type.__name__}' " - f"which does not overload '{python_op}'.") + f"which does not overload '{python_op}'.", + ) # Create the comparison operation (including any potential negations) if op_name == "In": @@ -216,7 +239,7 @@ def visit_Compare(self, node: ast.Compare): "GtE": "sge", "Lt": "slt", "LtE": "sle", - "NotEq": "ne" + "NotEq": "ne", } mnemonic = python_AST_cmpop_to_mnemonic[op_name] op = op(lhs, rhs, mnemonic) @@ -224,7 +247,6 @@ def visit_Compare(self, node: ast.Compare): self.inserter.insert_op(op) def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - # Set the symbol table. assert self.symbol_table is None self.symbol_table = dict() @@ -233,8 +255,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: argument_types: List[Attribute] = [] for i, arg in enumerate(node.args.args): if arg.annotation is None: - raise CodeGenerationException(self.file, arg.lineno, - arg.col_offset, f"") + raise CodeGenerationException( + self.file, arg.lineno, arg.col_offset, f"" + ) xdsl_type = self.type_converter.convert_type_hint(arg.annotation) argument_types.append(xdsl_type) @@ -246,8 +269,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # Create a function operation. entry_block = Block() body_region = Region(entry_block) - func_op = func.FuncOp.from_region(node.name, argument_types, - return_types, body_region) + func_op = func.FuncOp.from_region( + node.name, argument_types, return_types, body_region + ) self.inserter.insert_op(func_op) self.inserter.set_insertion_point_from_block(entry_block) @@ -286,8 +310,11 @@ def visit_Pass(self, node: ast.Pass) -> None: if len(return_types) != 0: function_name = parent_op.sym_name.data raise CodeGenerationException( - self.file, node.lineno, node.col_offset, - f"Expected '{function_name}' to return a type.") + self.file, + node.lineno, + node.col_offset, + f"Expected '{function_name}' to return a type.", + ) self.inserter.insert_op(func.Return.get()) def visit_Return(self, node: ast.Return) -> None: @@ -304,9 +331,12 @@ def visit_Return(self, node: ast.Return) -> None: parent_op = self.inserter.insertion_point.parent_op() if not isinstance(parent_op, func.FuncOp): raise CodeGenerationException( - self.file, node.lineno, node.col_offset, + self.file, + node.lineno, + node.col_offset, "Return statement should be placed only at the end of the " - "function body.") + "function body.", + ) func_name = parent_op.sym_name.data func_return_types = parent_op.function_type.outputs.data @@ -315,9 +345,12 @@ def visit_Return(self, node: ast.Return) -> None: # Return nothing, check function signature matches. if len(func_return_types) != 0: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, + self.file, + node.lineno, + node.col_offset, f"Expected non-zero number of return types in function " - f"'{func_name}', but got 0.") + f"'{func_name}', but got 0.", + ) self.inserter.insert_op(func.Return.get()) else: # Return some type, check function signature matches as well. @@ -327,15 +360,21 @@ def visit_Return(self, node: ast.Return) -> None: if len(func_return_types) == 0: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, - f"Expected no return types in function '{func_name}'.") + self.file, + node.lineno, + node.col_offset, + f"Expected no return types in function '{func_name}'.", + ) for i in range(len(operands)): if func_return_types[i] != operands[i].typ: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, + self.file, + node.lineno, + node.col_offset, f"Type signature and the type of the return value do " f"not match at position {i}: expected {func_return_types[i]}," - f" got {operands[i].typ}.") + f" got {operands[i].typ}.", + ) self.inserter.insert_op(func.Return.get(*operands)) diff --git a/xdsl/frontend/const.py b/xdsl/frontend/const.py index b3161b4bf9..d574202b48 100644 --- a/xdsl/frontend/const.py +++ b/xdsl/frontend/const.py @@ -27,10 +27,14 @@ def foo(a: i64): # compile-time error, cannot reuse the name e: Const[i16] = b + 2 # i16 constant, equal to 5 ``` """ + pass def is_constant(node: ast.expr) -> bool: """Returns `True` if the AST node is a Const type.""" - return isinstance(node, ast.Subscript) and isinstance( - node.value, ast.Name) and node.value.id == Const.__name__ + return ( + isinstance(node, ast.Subscript) + and isinstance(node.value, ast.Name) + and node.value.id == Const.__name__ + ) diff --git a/xdsl/frontend/context.py b/xdsl/frontend/context.py index a9e94c0fe1..77b8b3ddda 100644 --- a/xdsl/frontend/context.py +++ b/xdsl/frontend/context.py @@ -3,7 +3,7 @@ from contextlib import AbstractContextManager from dataclasses import dataclass from inspect import getsource -from sys import _getframe #type: ignore +from sys import _getframe # type: ignore from typing import Any from xdsl.frontend.program import FrontendProgram @@ -36,9 +36,10 @@ def __enter__(self) -> None: # Find where the program starts. for node in ast.walk(python_ast): - if isinstance(node, ast.With) and \ - node.lineno == frame.f_lineno - frame.f_code.co_firstlineno + 1: - + if ( + isinstance(node, ast.With) + and node.lineno == frame.f_lineno - frame.f_code.co_firstlineno + 1 + ): # Found the program AST. Store it for later compilation or # execution. self.program.stmts = node.body diff --git a/xdsl/frontend/dialects/builtin.py b/xdsl/frontend/dialects/builtin.py index 74833d27f0..13cf243cc8 100644 --- a/xdsl/frontend/dialects/builtin.py +++ b/xdsl/frontend/dialects/builtin.py @@ -33,71 +33,75 @@ def to_xdsl() -> Callable[..., Any]: return builtin.IntegerType def __add__( - self, - other: _Integer[_Width, - _Signedness]) -> _Integer[_Width, _Signedness]: + self, other: _Integer[_Width, _Signedness] + ) -> _Integer[_Width, _Signedness]: from xdsl.frontend.dialects.arith import addi + return addi(self, other) # type: ignore def __and__( - self, - other: _Integer[_Width, - _Signedness]) -> _Integer[_Width, _Signedness]: + self, other: _Integer[_Width, _Signedness] + ) -> _Integer[_Width, _Signedness]: from xdsl.frontend.dialects.arith import andi + return andi(self, other) # type: ignore def __lshift__( - self, - other: _Integer[_Width, - _Signedness]) -> _Integer[_Width, _Signedness]: + self, other: _Integer[_Width, _Signedness] + ) -> _Integer[_Width, _Signedness]: from xdsl.frontend.dialects.arith import shli + return shli(self, other) # type: ignore def __mul__( - self, - other: _Integer[_Width, - _Signedness]) -> _Integer[_Width, _Signedness]: + self, other: _Integer[_Width, _Signedness] + ) -> _Integer[_Width, _Signedness]: from xdsl.frontend.dialects.arith import muli + return muli(self, other) # type: ignore def __rshift__( - self, - other: _Integer[_Width, - _Signedness]) -> _Integer[_Width, _Signedness]: + self, other: _Integer[_Width, _Signedness] + ) -> _Integer[_Width, _Signedness]: from xdsl.frontend.dialects.arith import shrsi + return shrsi(self, other) # type: ignore def __sub__( - self, - other: _Integer[_Width, - _Signedness]) -> _Integer[_Width, _Signedness]: + self, other: _Integer[_Width, _Signedness] + ) -> _Integer[_Width, _Signedness]: from xdsl.frontend.dialects.arith import subi + return subi(self, other) # type: ignore - def __eq__( # type: ignore - self, other: _Integer[_Width, _Signedness]) -> i1: + def __eq__(self, other: _Integer[_Width, _Signedness]) -> i1: # type: ignore from xdsl.frontend.dialects.arith import cmpi + return cmpi(self, other, "eq") # type: ignore def __ge__(self, other: _Integer[_Width, _Signedness]) -> i1: from xdsl.frontend.dialects.arith import cmpi + return cmpi(self, other, "sge") # type: ignore def __gt__(self, other: _Integer[_Width, _Signedness]) -> i1: from xdsl.frontend.dialects.arith import cmpi + return cmpi(self, other, "sgt") # type: ignore def __le__(self, other: _Integer[_Width, _Signedness]) -> i1: from xdsl.frontend.dialects.arith import cmpi + return cmpi(self, other, "sle") # type: ignore def __lt__(self, other: _Integer[_Width, _Signedness]) -> i1: from xdsl.frontend.dialects.arith import cmpi + return cmpi(self, other, "slt") # type: ignore - def __ne__( # type: ignore - self, other: _Integer[_Width, _Signedness]) -> i1: + def __ne__(self, other: _Integer[_Width, _Signedness]) -> i1: # type: ignore from xdsl.frontend.dialects.arith import cmpi + return cmpi(self, other, "ne") # type: ignore @@ -141,14 +145,17 @@ def to_xdsl() -> Callable[..., Any]: def __add__(self, other: f16) -> f16: from xdsl.frontend.dialects.arith import addf + return addf(self, other) def __sub__(self, other: f16) -> f16: from xdsl.frontend.dialects.arith import subf + return subf(self, other) def __mul__(self, other: f16) -> f16: from xdsl.frontend.dialects.arith import mulf + return mulf(self, other) @@ -164,14 +171,17 @@ def to_xdsl() -> Callable[..., Any]: def __add__(self, other: f32) -> f32: from xdsl.frontend.dialects.arith import addf + return addf(self, other) def __sub__(self, other: f32) -> f32: from xdsl.frontend.dialects.arith import subf + return subf(self, other) def __mul__(self, other: f32) -> f32: from xdsl.frontend.dialects.arith import mulf + return mulf(self, other) @@ -187,14 +197,17 @@ def to_xdsl() -> Callable[..., Any]: def __add__(self, other: f64) -> f64: from xdsl.frontend.dialects.arith import addf + return addf(self, other) def __sub__(self, other: f64) -> f64: from xdsl.frontend.dialects.arith import subf + return subf(self, other) def __mul__(self, other: f64) -> f64: from xdsl.frontend.dialects.arith import mulf + return mulf(self, other) diff --git a/xdsl/frontend/exception.py b/xdsl/frontend/exception.py index 46edc2f8ff..5f81646706 100644 --- a/xdsl/frontend/exception.py +++ b/xdsl/frontend/exception.py @@ -41,8 +41,10 @@ def __init__( self.col = col def __str__(self) -> str: - str = 'Code generation exception at ' + str = "Code generation exception at " if self.file: - return str + f'"{self.file}", line {self.line} column {self.col}: {self.msg}' + return ( + str + f'"{self.file}", line {self.line} column {self.col}: {self.msg}' + ) else: - return str + f'line {self.line} column {self.col}: {self.msg}' + return str + f"line {self.line} column {self.col}: {self.msg}" diff --git a/xdsl/frontend/op_inserter.py b/xdsl/frontend/op_inserter.py index 76569d6d72..780cbcf584 100644 --- a/xdsl/frontend/op_inserter.py +++ b/xdsl/frontend/op_inserter.py @@ -30,7 +30,8 @@ def get_operand(self) -> SSAValue: """ if len(self.stack) == 0: raise FrontendProgramException( - "Trying to get an operand from an empty stack.") + "Trying to get an operand from an empty stack." + ) return self.stack.pop() def insert_op(self, op: Operation) -> None: diff --git a/xdsl/frontend/op_resolver.py b/xdsl/frontend/op_resolver.py index bdfc5bab46..ee500bc61c 100644 --- a/xdsl/frontend/op_resolver.py +++ b/xdsl/frontend/op_resolver.py @@ -16,25 +16,26 @@ class OpResolver: """ @staticmethod - def resolve_op(module_name: str, - func_name: str) -> Callable[..., Operation]: + def resolve_op(module_name: str, func_name: str) -> Callable[..., Operation]: module = importlib.import_module(module_name) resolver_name = "resolve_" + func_name if not hasattr(module, resolver_name): raise FrontendProgramException( f"Internal failure: operation '{func_name}' does not exist " - f"in module '{module_name}'.") + f"in module '{module_name}'." + ) return getattr(module, resolver_name)() @staticmethod def resolve_op_overload( - python_op: str, - frontend_type: Type[_FrontendType]) -> Callable[..., Operation]: + python_op: str, frontend_type: Type[_FrontendType] + ) -> Callable[..., Operation]: # First, get overloaded function. if not hasattr(frontend_type, python_op): raise FrontendProgramException( f"Internal failure: '{frontend_type.__name__}' does not " - f"overload '{python_op}'.") + f"overload '{python_op}'." + ) overload = getattr(frontend_type, python_op) # Inspect overloaded function to extract what it maps to. By our @@ -45,19 +46,22 @@ def resolve_op_overload( # return F(...) python_ast = ast.parse(inspect.getsource(overload).strip()) if not isinstance(python_ast, ast.Module) or not isinstance( - python_ast.body[0], ast.FunctionDef): + python_ast.body[0], ast.FunctionDef + ): raise FrontendProgramException( f"Internal failure while resolving '{python_op}'. Function AST" - " for resolution not found.") + " for resolution not found." + ) func_ast = python_ast.body[0] - if len(func_ast.body) != 2 or not isinstance( - func_ast.body[0], ast.ImportFrom) or not isinstance( - func_ast.body[1], ast.Return) or not isinstance( - func_ast.body[1].value, ast.Call) or not isinstance( - func_ast.body[1].value.func, ast.Name): - msg = \ - f""" + if ( + len(func_ast.body) != 2 + or not isinstance(func_ast.body[0], ast.ImportFrom) + or not isinstance(func_ast.body[1], ast.Return) + or not isinstance(func_ast.body[1].value, ast.Call) + or not isinstance(func_ast.body[1].value.func, ast.Name) + ): + msg = f""" Internal failure while resolving '{python_op}'. Function AST for resolution is not correct, instead it should be: def __overload__(...): from Dialect import Operation diff --git a/xdsl/frontend/passes/desymref.py b/xdsl/frontend/passes/desymref.py index f87ff4b5f5..7463de3f74 100644 --- a/xdsl/frontend/passes/desymref.py +++ b/xdsl/frontend/passes/desymref.py @@ -124,8 +124,9 @@ def get_symbols(block: Block) -> set[str]: return symbols -def lower_positional_bound(writes: list[symref.Update], - read: symref.Fetch) -> Operation | None: +def lower_positional_bound( + writes: list[symref.Update], read: symref.Fetch +) -> Operation | None: """ Returns a nearest write preceeding the `read`. If there is no such write, `None` is returned. @@ -210,7 +211,8 @@ def prepare_region(self, region: Region): # 2. Insertion of entry/exit blocks to ensure dominance. raise FrontendProgramException( f"Running desymrefier on region with {num_blocks} > 1 blocks is " - "not supported.") + "not supported." + ) def prepare_block(self, block: Block): """Prepares a block for promotion.""" @@ -226,37 +228,48 @@ def prepare_block(self, block: Block): symbols = get_symbols(block) for symbol in symbols: num_reads = sum( - isinstance(op, symref.Fetch) for op in block.ops - if get_symbol(op) == symbol) + isinstance(op, symref.Fetch) + for op in block.ops + if get_symbol(op) == symbol + ) num_writes = sum( - isinstance(op, symref.Update) for op in block.ops - if get_symbol(op) == symbol) + isinstance(op, symref.Update) + for op in block.ops + if get_symbol(op) == symbol + ) if num_reads > 1 or num_writes > 1: raise FrontendProgramException( f"Block {block} not ready for promotion: found {num_reads}" - f" reads and {num_writes} writes.") + f" reads and {num_writes} writes." + ) def prune_definitions(self, block: Block): """Removes all symbol definitions and their uses from the block.""" # Find all symbol definitions in this block. If no definitions # found, terminate. - while len(definitions := - [op for op in block.ops - if isinstance(op, symref.Declare)]) > 0: - + while ( + len( + definitions := [ + op for op in block.ops if isinstance(op, symref.Declare) + ] + ) + > 0 + ): # Otherwise, some definitions are still alive. for definition in definitions: symbol = get_symbol(definition) # Find all reads and writes for this symbol. reads = [ - op for op in block.ops if isinstance(op, symref.Fetch) - and get_symbol(op) == symbol + op + for op in block.ops + if isinstance(op, symref.Fetch) and get_symbol(op) == symbol ] writes = [ - op for op in block.ops if isinstance(op, symref.Update) - and get_symbol(op) == symbol + op + for op in block.ops + if isinstance(op, symref.Update) and get_symbol(op) == symbol ] # Symbol is never read, so remove its definition and any writes. @@ -284,8 +297,9 @@ def prune_definitions(self, block: Block): Rewriter.replace_op(read, [], [write.operands[0]]) def _prune_unused_reads(self, block: Block): - is_unused_read: Callable[[Operation], bool] = lambda op: isinstance( - op, symref.Fetch) and len(op.results[0].uses) == 0 + is_unused_read: Callable[[Operation], bool] = ( + lambda op: isinstance(op, symref.Fetch) and len(op.results[0].uses) == 0 + ) unused_reads = [op for op in block.ops if is_unused_read(op)] for read in unused_reads: Rewriter.erase_op(read) @@ -298,20 +312,24 @@ def prune_uses_without_definitions(self, block: Block): self._prune_unused_reads(block) # Find all symbols that are still in use in this block. - symbol_worklist: set[str] = set(symbol - for symbol in get_symbols(block) - if symbol not in prepared_symbols) + symbol_worklist: set[str] = set( + symbol + for symbol in get_symbols(block) + if symbol not in prepared_symbols + ) if len(symbol_worklist) == 0: return for symbol in symbol_worklist: reads = [ - op for op in block.ops if isinstance(op, symref.Fetch) - and get_symbol(op) == symbol + op + for op in block.ops + if isinstance(op, symref.Fetch) and get_symbol(op) == symbol ] writes = [ - op for op in block.ops if isinstance(op, symref.Update) - and get_symbol(op) == symbol + op + for op in block.ops + if isinstance(op, symref.Update) and get_symbol(op) == symbol ] assert len(reads) > 0 or len(writes) > 0 @@ -350,8 +368,7 @@ def prune_uses_without_definitions(self, block: Block): class DesymrefyPass(ModulePass): - - name = 'frontend-desymrefy' + name = "frontend-desymrefy" def apply(self, ctx: MLContext, op: builtin.ModuleOp): Desymrefier().desymrefy(op) diff --git a/xdsl/frontend/program.py b/xdsl/frontend/program.py index ad8620dba2..b6210e8346 100644 --- a/xdsl/frontend/program.py +++ b/xdsl/frontend/program.py @@ -33,8 +33,7 @@ class FrontendProgram: def _check_can_compile(self): if self.stmts is None or self.globals is None: - msg = \ - """ + msg = """ Cannot compile program without the code context. Try to use: p = FrontendProgram() with CodeContext(p): @@ -52,7 +51,8 @@ def compile(self, desymref: bool = True) -> None: type_converter = TypeConverter(self.globals) self.xdsl_program = CodeGeneration.run_with_type_converter( - type_converter, self.stmts, self.file) + type_converter, self.stmts, self.file + ) self.xdsl_program.verify() # Optionally run desymrefication pass to produce actual SSA. @@ -67,8 +67,7 @@ def desymref(self) -> None: def _check_can_print(self): if self.xdsl_program is None: - msg = \ - """ + msg = """ Cannot print the program IR without compiling it first. Make sure to use: p = FrontendProgram() with CodeContext(p): diff --git a/xdsl/frontend/python_code_check.py b/xdsl/frontend/python_code_check.py index 26d3a84929..9a1f598abe 100644 --- a/xdsl/frontend/python_code_check.py +++ b/xdsl/frontend/python_code_check.py @@ -8,7 +8,6 @@ @dataclass class PythonCodeCheck: - @staticmethod def run(stmts: List[ast.stmt], file: str | None) -> None: """ @@ -54,10 +53,10 @@ def run(stmts: List[ast.stmt], file: str | None) -> None: @dataclass class CheckStructure: - @staticmethod - def run_with_scope(single_scope: bool, stmts: List[ast.stmt], - file: str | None) -> None: + def run_with_scope( + single_scope: bool, stmts: List[ast.stmt], file: str | None + ) -> None: if single_scope: visitor = SingleScopeVisitor(file) else: @@ -68,7 +67,6 @@ def run_with_scope(single_scope: bool, stmts: List[ast.stmt], @dataclass class SingleScopeVisitor(ast.NodeVisitor): - file: str | None = field(default=None) """File path for error reporting.""" @@ -83,27 +81,35 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: if node.name in self.block_names: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, - f"Block '{node.name}' is already defined.") + self.file, + node.lineno, + node.col_offset, + f"Block '{node.name}' is already defined.", + ) self.block_names.add(node.name) for stmt in node.body: if isinstance(stmt, ast.FunctionDef): if is_block(stmt): raise CodeGenerationException( - self.file, stmt.lineno, stmt.col_offset, + self.file, + stmt.lineno, + stmt.col_offset, f"Cannot have a nested block '{stmt.name}' inside the " - f"block '{node.name}'.") + f"block '{node.name}'.", + ) else: raise CodeGenerationException( - self.file, stmt.lineno, stmt.col_offset, + self.file, + stmt.lineno, + stmt.col_offset, f"Cannot have an inner function '{stmt.name}' inside " - f"the block '{node.name}'.") + f"the block '{node.name}'.", + ) @dataclass class MultipleScopeVisitor(ast.NodeVisitor): - file: str | None = field(default=None) function_and_block_names: Dict[str, Set[str]] = field(default_factory=dict) @@ -117,8 +123,11 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: if node.name in self.function_and_block_names: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, - f"Function '{node.name}' is already defined.") + self.file, + node.lineno, + node.col_offset, + f"Function '{node.name}' is already defined.", + ) self.function_and_block_names[node.name] = set() # Functions cannot have inner functions but can have blocks inside @@ -127,29 +136,41 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: if isinstance(stmt, ast.FunctionDef): if not is_block(stmt): raise CodeGenerationException( - self.file, stmt.lineno, stmt.col_offset, + self.file, + stmt.lineno, + stmt.col_offset, f"Cannot have an inner function '{stmt.name}' inside " - f"the function '{node.name}'.") + f"the function '{node.name}'.", + ) else: if stmt.name in self.function_and_block_names[node.name]: raise CodeGenerationException( - self.file, stmt.lineno, stmt.col_offset, + self.file, + stmt.lineno, + stmt.col_offset, f"Block '{stmt.name}' is already defined in " - f"function '{node.name}'.") + f"function '{node.name}'.", + ) self.function_and_block_names[node.name].add(stmt.name) for inner in stmt.body: if isinstance(inner, ast.FunctionDef): if is_block(inner): raise CodeGenerationException( - self.file, inner.lineno, inner.col_offset, + self.file, + inner.lineno, + inner.col_offset, f"Cannot have a nested block '{inner.name}'" - f" inside the block '{stmt.name}'.") + f" inside the block '{stmt.name}'.", + ) else: raise CodeGenerationException( - self.file, inner.lineno, inner.col_offset, + self.file, + inner.lineno, + inner.col_offset, f"Cannot have an inner function '{inner.name}'" - f" inside the block '{stmt.name}'.") + f" inside the block '{stmt.name}'.", + ) @dataclass @@ -178,33 +199,40 @@ def run(stmts: List[ast.stmt], file: str | None) -> None: CheckAndInlineConstants.run_with_variables(stmts, set(), file) @staticmethod - def run_with_variables(stmts: List[ast.stmt], defined_variables: Set[str], - file: str | None) -> None: + def run_with_variables( + stmts: List[ast.stmt], defined_variables: Set[str], file: str | None + ) -> None: for i, stmt in enumerate(stmts): # This variable (`a = ...`) can be redefined as a constant, and so # we have to keep track of these to raise an exception. - if isinstance(stmt, ast.Assign) and len( - stmt.targets) == 1 and isinstance(stmt.targets[0], - ast.Name): + if ( + isinstance(stmt, ast.Assign) + and len(stmt.targets) == 1 + and isinstance(stmt.targets[0], ast.Name) + ): defined_variables.add(stmt.targets[0].id) continue # Similarly, this case (`a: i32 = ...`) can also be redefined as a # constant. - if isinstance(stmt, ast.AnnAssign) and isinstance( - stmt.target, - ast.Name) and not is_constant(stmt.annotation): + if ( + isinstance(stmt, ast.AnnAssign) + and isinstance(stmt.target, ast.Name) + and not is_constant(stmt.annotation) + ): defined_variables.add(stmt.target.id) continue # This is a constant. - if isinstance(stmt, ast.AnnAssign) and is_constant( - stmt.annotation): + if isinstance(stmt, ast.AnnAssign) and is_constant(stmt.annotation): if not isinstance(stmt.target, ast.Name): raise CodeGenerationException( - file, stmt.lineno, stmt.col_offset, + file, + stmt.lineno, + stmt.col_offset, f"All constant expressions have to be assigned to " - "'ast.Name' nodes.") + "'ast.Name' nodes.", + ) name = stmt.target.id try: @@ -214,24 +242,30 @@ def run_with_variables(stmts: List[ast.stmt], defined_variables: Set[str], # TODO: This error message can be improved by matching exact # exceptions returned by `eval` call. raise CodeGenerationException( - file, stmt.lineno, stmt.col_offset, + file, + stmt.lineno, + stmt.col_offset, f"Non-constant expression cannot be assigned to " - f"constant variable '{name}' or cannot be evaluated.") + f"constant variable '{name}' or cannot be evaluated.", + ) # For now, support primitive types only and add a guard to abort # in other cases. if not isinstance(value, int) and not isinstance(value, float): raise CodeGenerationException( - file, stmt.lineno, stmt.col_offset, + file, + stmt.lineno, + stmt.col_offset, f"Constant '{name}' has evaluated type '{type(value)}' " - "which is not supported.") + "which is not supported.", + ) # TODO: We should typecheck the value against the type. This can # get tricky since ints can overflow, etc. For example, `a: # Const[i16] = 100000000` should give an error. new_node = ast.Constant(value) inliner = ConstantInliner(name, new_node, file) - for candidate in stmts[(i + 1):]: + for candidate in stmts[(i + 1) :]: inliner.visit(candidate) # Ideally, we can prune this AST node now, but it is easier just @@ -243,10 +277,10 @@ def run_with_variables(stmts: List[ast.stmt], defined_variables: Set[str], # this then all constants above `i` must have been already inlined. # Hence, it is sufficient to check the function body only. if isinstance(stmt, ast.FunctionDef): - new_defined_variables = set( - [arg.arg for arg in stmt.args.args]) + new_defined_variables = set([arg.arg for arg in stmt.args.args]) CheckAndInlineConstants.run_with_variables( - stmt.body, new_defined_variables, file) + stmt.body, new_defined_variables, file + ) @dataclass @@ -269,20 +303,29 @@ class ConstantInliner(ast.NodeTransformer): """Path to the file containing the program.""" def visit_Assign(self, node: ast.Assign) -> ast.Assign: - if len(node.targets) == 1 and isinstance( - node.targets[0], ast.Name) and node.targets[0].id == self.name: + if ( + len(node.targets) == 1 + and isinstance(node.targets[0], ast.Name) + and node.targets[0].id == self.name + ): raise CodeGenerationException( - self.file, node.lineno, node.col_offset, + self.file, + node.lineno, + node.col_offset, f"Constant '{self.name}' is already defined and cannot be " - "assigned to.") + "assigned to.", + ) node.value = self.visit(node.value) return node def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign: if isinstance(node.target, ast.Name) and node.target.id == self.name: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, - f"Constant '{self.name}' is already defined.") + self.file, + node.lineno, + node.col_offset, + f"Constant '{self.name}' is already defined.", + ) assert node.value is not None node.value = self.visit(node.value) return node @@ -291,9 +334,12 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: for arg in node.args.args: if arg.arg == self.name: raise CodeGenerationException( - self.file, node.lineno, node.col_offset, + self.file, + node.lineno, + node.col_offset, f"Constant '{self.name}' is already defined and cannot be " - "used as a function/block argument name.") + "used as a function/block argument name.", + ) for stmt in node.body: self.visit(stmt) return node diff --git a/xdsl/frontend/symref.py b/xdsl/frontend/symref.py index 2ff625aac7..ca4011a3ac 100644 --- a/xdsl/frontend/symref.py +++ b/xdsl/frontend/symref.py @@ -1,7 +1,14 @@ from __future__ import annotations from typing import Annotated from xdsl.ir import Attribute, Dialect, OpResult, SSAValue -from xdsl.irdl import Operand, irdl_op_definition, OpAttr, AnyAttr, Operation, IRDLOperation +from xdsl.irdl import ( + Operand, + irdl_op_definition, + OpAttr, + AnyAttr, + Operation, + IRDLOperation, +) from xdsl.dialects.builtin import StringAttr, SymbolRefAttr @@ -27,8 +34,7 @@ class Fetch(IRDLOperation): def get(symbol: str | SymbolRefAttr, result_type: Attribute) -> Fetch: if isinstance(symbol, str): symbol = SymbolRefAttr(symbol) - return Fetch.build(attributes={"symbol": symbol}, - result_types=[result_type]) + return Fetch.build(attributes={"symbol": symbol}, result_types=[result_type]) @irdl_op_definition @@ -38,8 +44,7 @@ class Update(IRDLOperation): symbol: OpAttr[SymbolRefAttr] @staticmethod - def get(symbol: str | SymbolRefAttr, - value: Operation | SSAValue) -> Update: + def get(symbol: str | SymbolRefAttr, value: Operation | SSAValue) -> Update: if isinstance(symbol, str): symbol = SymbolRefAttr(symbol) return Update.build(operands=[value], attributes={"symbol": symbol}) diff --git a/xdsl/frontend/type_conversion.py b/xdsl/frontend/type_conversion.py index 2ae7a80259..4a479b1da8 100644 --- a/xdsl/frontend/type_conversion.py +++ b/xdsl/frontend/type_conversion.py @@ -25,15 +25,14 @@ class TypeConverter: annotation without explicitly constructing it. """ - name_to_xdsl_type_map: Dict[TypeName, - Attribute] = field(default_factory=dict) + name_to_xdsl_type_map: Dict[TypeName, Attribute] = field(default_factory=dict) """ Map to cache xDSL types created so far to avoid repeated conversions. """ - xdsl_to_frontend_type_map: Dict[Type[Attribute], - Type[_FrontendType]] = field( - default_factory=dict) + xdsl_to_frontend_type_map: Dict[Type[Attribute], Type[_FrontendType]] = field( + default_factory=dict + ) """ Map to lookup frontend types based on xDSL type. Useful if we want to see what overloaded Python operations does this xDSL type support. @@ -47,8 +46,12 @@ def __post_init__(self) -> None: index = frontend_builtin._Index # type: ignore self._cache_type(index, xdsl_builtin.IndexType(), "index") - def _cache_type(self, frontend_type: Type[_FrontendType], - xdsl_type: Attribute, type_name: TypeName) -> None: + def _cache_type( + self, + frontend_type: Type[_FrontendType], + xdsl_type: Attribute, + type_name: TypeName, + ) -> None: """Records frontend and corresponding xDSL types in cache.""" if type_name not in self.name_to_xdsl_type_map: self.name_to_xdsl_type_map[type_name] = xdsl_type @@ -64,9 +67,12 @@ def _convert_name(self, type_hint: ast.Name) -> Attribute: # Otherwise, it must be some frontend type, and we can look up its class # using the imports. if type_name not in self.globals: - raise CodeGenerationException(self.file, type_hint.lineno, - type_hint.col_offset, - f"Unknown type hint '{type_name}'.") + raise CodeGenerationException( + self.file, + type_hint.lineno, + type_hint.col_offset, + f"Unknown type hint '{type_name}'.", + ) type_class = self.globals[type_name] # First, type can be generic, e.g. `class _Integer(Generic[_W, _S])`. @@ -74,13 +80,14 @@ def _convert_name(self, type_hint: ast.Name) -> Attribute: generic_type_arguments = type_class.__args__ arguments_for_constructor: list[Any] = [] for type_argument in generic_type_arguments: - # Convert Literal[...] to concrete values. materialized_arguments = type_argument.__args__ if len(materialized_arguments) != 1: raise CodeGenerationException( - self.file, type_hint.lineno, type_hint.col_offset, - f"Expected 1 type argument for generic type '{type_name}', got {len(materialized_arguments)} type arguments instead." + self.file, + type_hint.lineno, + type_hint.col_offset, + f"Expected 1 type argument for generic type '{type_name}', got {len(materialized_arguments)} type arguments instead.", ) arguments_for_constructor.append(materialized_arguments[0]) continue @@ -93,8 +100,11 @@ def _convert_name(self, type_hint: ast.Name) -> Attribute: # If this is not a subclass of FrontendType, then abort. raise CodeGenerationException( - self.file, type_hint.lineno, type_hint.col_offset, - f"'{type_name}' is not a frontend type.") + self.file, + type_hint.lineno, + type_hint.col_offset, + f"'{type_name}' is not a frontend type.", + ) # Otherwise, type can be a simple non-generic frontend type, e.g. `class # _Index(FrontendType)`. @@ -104,8 +114,10 @@ def _convert_name(self, type_hint: ast.Name) -> Attribute: return xdsl_type raise CodeGenerationException( - self.file, type_hint.lineno, type_hint.col_offset, - f"Unknown type hint for type '{type_name}' inside 'ast.Name' expression." + self.file, + type_hint.lineno, + type_hint.col_offset, + f"Unknown type hint for type '{type_name}' inside 'ast.Name' expression.", ) def convert_type_hint(self, type_hint: ast.expr) -> Attribute: @@ -118,8 +130,11 @@ def convert_type_hint(self, type_hint: ast.expr) -> Attribute: # `Foo[Literal[2]]``. Support this in the future patches. if isinstance(type_hint, ast.Subscript): raise CodeGenerationException( - self.file, type_hint.lineno, type_hint.col_offset, - f"Converting subscript type hints is not supported.") + self.file, + type_hint.lineno, + type_hint.col_offset, + f"Converting subscript type hints is not supported.", + ) # Type hint can also be a TypeAlias. For example, one can define # `foo = Foo[Literal[2]]`. This case also handles standard Python types, like @@ -128,5 +143,8 @@ def convert_type_hint(self, type_hint: ast.expr) -> Attribute: return self._convert_name(type_hint) raise CodeGenerationException( - self.file, type_hint.lineno, type_hint.col_offset, - f"Unknown type hint AST node '{type_hint}'.") + self.file, + type_hint.lineno, + type_hint.col_offset, + f"Unknown type hint AST node '{type_hint}'.", + ) diff --git a/xdsl/interpreter.py b/xdsl/interpreter.py index 4a043d97fc..9d256a4cc7 100644 --- a/xdsl/interpreter.py +++ b/xdsl/interpreter.py @@ -1,8 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import (IO, Any, Callable, Generator, Iterable, TypeAlias, TypeVar, - ParamSpec) +from typing import IO, Any, Callable, Generator, Iterable, TypeAlias, TypeVar, ParamSpec from xdsl.dialects.builtin import ModuleOp from xdsl.ir import OperationInvT, SSAValue, Operation @@ -21,18 +20,18 @@ class InterpreterFunctions: class ArithFunctions(InterpreterFunctions): @impl(arith.Addi) - def run_addi(self, interpreter: Interpreter, op: arith.Addi, + def run_addi(self, interpreter: Interpreter, op: arith.Addi, args: tuple[Any, ...]) -> tuple[Any, ...]: lhs, rhs = args - return lhs + rhs, + return lhs + rhs, ``` - The interpreter will take care of fetching the Python values associated - with the operand SSAValues, and setting the return values to the + The interpreter will take care of fetching the Python values associated + with the operand SSAValues, and setting the return values to the appropriate OpResults. To override the definition of an operation implementation, subclass the - class to override, and redefine the functions, annotating them with + class to override, and redefine the functions, annotating them with `@impl`. ``` python @@ -40,45 +39,43 @@ class to override, and redefine the functions, annotating them with class DebugArithFunctions(ArithFunctions): @impl(arith.Addi) - def run_addi(self, interpreter: Interpreter, op: arith.Addi, + def run_addi(self, interpreter: Interpreter, op: arith.Addi, args: tuple[Any, ...]) -> tuple[Any, ...]: lhs, rhs = args print(lhs, rhs, lhs + rhs) - return lhs + rhs, + return lhs + rhs, ``` """ @classmethod def _impls( - cls - ) -> Iterable[tuple[type[Operation], OpImpl[InterpreterFunctions, - Operation]]]: + cls, + ) -> Iterable[tuple[type[Operation], OpImpl[InterpreterFunctions, Operation]]]: try: impl_dict = getattr(cls, _IMPL_DICT) return impl_dict.items() except AttributeError as e: - raise ValueError( - f'Use `@register_impls` on class {cls.__name__}') from e + raise ValueError(f"Use `@register_impls` on class {cls.__name__}") from e -_FT = TypeVar('_FT', bound=InterpreterFunctions) +_FT = TypeVar("_FT", bound=InterpreterFunctions) -_IMPL_OP_TYPE = '__impl_op_type' -_IMPL_DICT = '__impl_dict' +_IMPL_OP_TYPE = "__impl_op_type" +_IMPL_DICT = "__impl_dict" -P = ParamSpec('P') +P = ParamSpec("P") def impl( - op_type: type[OperationInvT] + op_type: type[OperationInvT], ) -> Callable[[OpImpl[_FT, OperationInvT]], OpImpl[_FT, OperationInvT]]: """ Marks the Python implementation of an xDSL `Operation` instance, to be used - by an `Interpreter`. The Interpreter will fetch the Python values + by an `Interpreter`. The Interpreter will fetch the Python values associated with the operands from the current environment, and pass them as - the `args` parameter. The returned values are assigned to the `results` + the `args` parameter. The returned values are assigned to the `results` values. - + See `InterpreterFunctions` """ @@ -92,9 +89,9 @@ def annot(func: OpImpl[_FT, OperationInvT]) -> OpImpl[_FT, OperationInvT]: def register_impls(ft: type[_FT]) -> type[_FT]: """ Enumerates the methods on a given class, and registers the ones marked with - `@impl` in a way that an `Interpreter` instance can find them for dynamic + `@impl` in a way that an `Interpreter` instance can find them for dynamic dispatch during interpretation. - + See `InterpreterFunctions` """ impl_dict: _ImplDict = {} @@ -121,10 +118,10 @@ class _InterpreterFunctionImpls: so we keep a `(Functions, OpImpl)` tuple for every Operation type. """ - _impl_dict: dict[type[Operation], - tuple[InterpreterFunctions, - OpImpl[InterpreterFunctions, - Operation]]] = field(default_factory=dict) + _impl_dict: dict[ + type[Operation], + tuple[InterpreterFunctions, OpImpl[InterpreterFunctions, Operation]], + ] = field(default_factory=dict) def register_from(self, ft: InterpreterFunctions, /, override: bool): impls = ft._impls() # pyright: ignore[reportPrivateUsage] @@ -132,15 +129,18 @@ def register_from(self, ft: InterpreterFunctions, /, override: bool): if op_type in self._impl_dict and not override: raise ValueError( "Attempting to register implementation for op of type " - f"{op_type}, but type already registered") + f"{op_type}, but type already registered" + ) self._impl_dict[op_type] = (ft, impl) - def run(self, interpreter: Interpreter, op: Operation, - args: tuple[Any, ...]) -> tuple[Any, ...]: + def run( + self, interpreter: Interpreter, op: Operation, args: tuple[Any, ...] + ) -> tuple[Any, ...]: if type(op) not in self._impl_dict: raise InterpretationError( - f'Could not find interpretation function for op {op.name}') + f"Could not find interpretation function for op {op.name}" + ) ft, impl = self._impl_dict[type(op)] return impl(ft, interpreter, op, args) @@ -166,7 +166,7 @@ def __getitem__(self, key: SSAValue) -> Any: return self.env[key] if self.parent is not None: return self.parent[key] - raise InterpretationError(f'Could not find value for {key} in {self}') + raise InterpretationError(f"Could not find value for {key} in {self}") def __setitem__(self, key: SSAValue, value: Any): """ @@ -175,8 +175,9 @@ def __setitem__(self, key: SSAValue, value: Any): """ if key in self.env: raise InterpretationError( - f'Attempting to register SSAValue {value} for name {key}' - f', but value with that name already exists in {self}') + f"Attempting to register SSAValue {value} for name {key}" + f", but value with that name already exists in {self}" + ) self.env[key] = value def stack(self) -> Generator[InterpreterContext, None, None]: @@ -188,24 +189,24 @@ def stack(self) -> Generator[InterpreterContext, None, None]: yield self def __format__(self, __format_spec: str) -> str: - return '/'.join(c.name for c in self.stack()) + return "/".join(c.name for c in self.stack()) @dataclass class Interpreter: """ - An extensible interpreter, initialised with a Module to interpret. The - implementation for each Operation subclass should be provided via a - `InterpretationFunctions` instance. Interpretations can be overridden, and + An extensible interpreter, initialised with a Module to interpret. The + implementation for each Operation subclass should be provided via a + `InterpretationFunctions` instance. Interpretations can be overridden, and the override must be specified explicitly, by passing `override=True` to the `register_functions` method. """ module: ModuleOp - _impls: _InterpreterFunctionImpls = field( - default_factory=_InterpreterFunctionImpls) + _impls: _InterpreterFunctionImpls = field(default_factory=_InterpreterFunctionImpls) _ctx: InterpreterContext = field( - default_factory=lambda: InterpreterContext(name='root')) + default_factory=lambda: InterpreterContext(name="root") + ) file: IO[str] | None = field(default=None) def get_values(self, values: Iterable[SSAValue]) -> tuple[Any, ...]: @@ -223,7 +224,7 @@ def set_values(self, pairs: Iterable[tuple[SSAValue, Any]]): for ssa_value, result_value in pairs: self._ctx[ssa_value] = result_value - def push_scope(self, name: str = 'unknown') -> None: + def push_scope(self, name: str = "unknown") -> None: """ Create new scope in current environment, with optional custom `name`. """ @@ -231,23 +232,22 @@ def push_scope(self, name: str = 'unknown') -> None: def pop_scope(self) -> None: """ - Discard the current scope, and all the values registered in it. Sets + Discard the current scope, and all the values registered in it. Sets parent scope of current scope to new current scope. Raises InterpretationError if current scope is root scope. """ if self._ctx.parent is None: - raise InterpretationError('Attempting to pop root env') + raise InterpretationError("Attempting to pop root env") self._ctx = self._ctx.parent - def register_implementations(self, - impls: InterpreterFunctions, - /, - override: bool = False) -> None: + def register_implementations( + self, impls: InterpreterFunctions, /, override: bool = False + ) -> None: """ - Register implementations for operations defined in given - `InterpreterFunctions` object. Raise InterpretationError if an - operation already has an implementation registered, unless override is + Register implementations for operations defined in given + `InterpreterFunctions` object. Raise InterpretationError if an + operation already has an implementation registered, unless override is set to True. """ self._impls.register_from(impls, override=override) @@ -255,13 +255,14 @@ def register_implementations(self, def run(self, op: Operation): """ Fetches the implemetation for the given op, passes it the Python values - associated with the SSA operands, and assigns the results to the + associated with the SSA operands, and assigns the results to the operation's results. """ inputs = self.get_values(op.operands) results = self._impls.run(self, op, inputs) self.interpreter_assert( - len(op.results) == len(results), 'Incorrect number of results') + len(op.results) == len(results), "Incorrect number of results" + ) self.set_values(zip(op.results, results)) def run_module(self): @@ -275,12 +276,11 @@ def print(self, *args: Any, **kwargs: Any): def interpreter_assert(self, condition: bool, message: str | None = None): """Raise InterpretationError if condition is not satisfied.""" if not condition: - raise InterpretationError( - f'AssertionError: ({self._ctx})({message})') + raise InterpretationError(f"AssertionError: ({self._ctx})({message})") OpImpl: TypeAlias = Callable[ - [_FT, Interpreter, OperationInvT, tuple[Any, ...]], tuple[Any, ...]] + [_FT, Interpreter, OperationInvT, tuple[Any, ...]], tuple[Any, ...] +] -_ImplDict: TypeAlias = dict[type[Operation], OpImpl[InterpreterFunctions, - Operation]] +_ImplDict: TypeAlias = dict[type[Operation], OpImpl[InterpreterFunctions, Operation]] diff --git a/xdsl/interpreters/experimental/pdl.py b/xdsl/interpreters/experimental/pdl.py index a75fa1cbac..fa5623e775 100644 --- a/xdsl/interpreters/experimental/pdl.py +++ b/xdsl/interpreters/experimental/pdl.py @@ -1,14 +1,15 @@ -from dataclasses import (dataclass, field) +from dataclasses import dataclass, field from typing import Any -from xdsl.ir import (Attribute, MLContext, TypeAttribute, OpResult, Operation, - SSAValue) +from xdsl.ir import Attribute, MLContext, TypeAttribute, OpResult, Operation, SSAValue from xdsl.dialects import pdl from xdsl.dialects.builtin import IntegerAttr, IntegerType, ModuleOp -from xdsl.pattern_rewriter import (PatternRewriter, PatternRewriteWalker, - AnonymousRewritePattern) -from xdsl.interpreter import (Interpreter, InterpreterFunctions, - register_impls, impl) +from xdsl.pattern_rewriter import ( + PatternRewriter, + PatternRewriteWalker, + AnonymousRewritePattern, +) +from xdsl.interpreter import Interpreter, InterpreterFunctions, register_impls, impl from xdsl.utils.exceptions import InterpretationError from xdsl.utils.hints import isa @@ -20,15 +21,18 @@ class PDLMatcher: interpretation. A new instance is created per operation being checked against. """ - matching_context: dict[SSAValue, Operation | Attribute - | SSAValue] = field(default_factory=dict) + + matching_context: dict[SSAValue, Operation | Attribute | SSAValue] = field( + default_factory=dict + ) """ For each SSAValue that is an OpResult of an operation in the PDL dialect, the corresponding xDSL object. """ - def match_operand(self, ssa_val: SSAValue, pdl_op: pdl.OperandOp, - xdsl_val: SSAValue): + def match_operand( + self, ssa_val: SSAValue, pdl_op: pdl.OperandOp, xdsl_val: SSAValue + ): if ssa_val in self.matching_context: return True @@ -36,16 +40,16 @@ def match_operand(self, ssa_val: SSAValue, pdl_op: pdl.OperandOp, assert isinstance(pdl_op.valueType, OpResult) assert isinstance(pdl_op.valueType.op, pdl.TypeOp) - if not self.match_type(pdl_op.valueType, pdl_op.valueType.op, - xdsl_val.typ): + if not self.match_type(pdl_op.valueType, pdl_op.valueType.op, xdsl_val.typ): return False self.matching_context[ssa_val] = xdsl_val return True - def match_result(self, ssa_val: SSAValue, pdl_op: pdl.ResultOp, - xdsl_operand: SSAValue): + def match_result( + self, ssa_val: SSAValue, pdl_op: pdl.ResultOp, xdsl_operand: SSAValue + ): if ssa_val in self.matching_context: return self.matching_context[ssa_val] == xdsl_operand @@ -58,8 +62,7 @@ def match_result(self, ssa_val: SSAValue, pdl_op: pdl.ResultOp, xdsl_op = xdsl_operand.op - if not self.match_operation(root_pdl_op_value, root_pdl_op_value.op, - xdsl_op): + if not self.match_operation(root_pdl_op_value, root_pdl_op_value.op, xdsl_op): return False original_op = root_pdl_op_value.op @@ -73,8 +76,7 @@ def match_result(self, ssa_val: SSAValue, pdl_op: pdl.ResultOp, return True - def match_type(self, ssa_val: SSAValue, pdl_op: pdl.TypeOp, - xdsl_attr: Attribute): + def match_type(self, ssa_val: SSAValue, pdl_op: pdl.TypeOp, xdsl_attr: Attribute): if ssa_val in self.matching_context: return self.matching_context[ssa_val] == xdsl_attr @@ -82,8 +84,13 @@ def match_type(self, ssa_val: SSAValue, pdl_op: pdl.TypeOp, return True - def match_attribute(self, ssa_val: SSAValue, pdl_op: pdl.AttributeOp, - attr_name: str, xdsl_attr: Attribute): + def match_attribute( + self, + ssa_val: SSAValue, + pdl_op: pdl.AttributeOp, + attr_name: str, + xdsl_attr: Attribute, + ): if ssa_val in self.matching_context: return self.matching_context[ssa_val] == xdsl_attr @@ -96,19 +103,21 @@ def match_attribute(self, ssa_val: SSAValue, pdl_op: pdl.AttributeOp, assert isinstance(pdl_op.valueType.op, pdl.TypeOp) assert isa( - xdsl_attr, - IntegerAttr[IntegerType]), 'Only handle integer types for now' + xdsl_attr, IntegerAttr[IntegerType] + ), "Only handle integer types for now" - if not self.match_type(pdl_op.valueType, pdl_op.valueType.op, - xdsl_attr.typ): + if not self.match_type( + pdl_op.valueType, pdl_op.valueType.op, xdsl_attr.typ + ): return False self.matching_context[ssa_val] = xdsl_attr return True - def match_operation(self, ssa_val: SSAValue, pdl_op: pdl.OperationOp, - xdsl_op: Operation) -> bool: + def match_operation( + self, ssa_val: SSAValue, pdl_op: pdl.OperationOp, xdsl_op: Operation + ) -> bool: if ssa_val in self.matching_context: return self.matching_context[ssa_val] == xdsl_op @@ -116,9 +125,7 @@ def match_operation(self, ssa_val: SSAValue, pdl_op: pdl.OperationOp, if xdsl_op.name != pdl_op.opName.data: return False - attribute_value_names = [ - avn.data for avn in pdl_op.attributeValueNames.data - ] + attribute_value_names = [avn.data for avn in pdl_op.attributeValueNames.data] for avn, av in zip(attribute_value_names, pdl_op.attributeValues): assert isinstance(av, OpResult) @@ -126,8 +133,7 @@ def match_operation(self, ssa_val: SSAValue, pdl_op: pdl.OperationOp, if avn not in xdsl_op.attributes: return False - if not self.match_attribute(av, av.op, avn, - xdsl_op.attributes[avn]): + if not self.match_attribute(av, av.op, avn, xdsl_op.attributes[avn]): return False pdl_operands = pdl_op.operandValues @@ -140,12 +146,10 @@ def match_operation(self, ssa_val: SSAValue, pdl_op: pdl.OperationOp, assert isinstance(pdl_operand, OpResult) assert isinstance(pdl_operand.op, pdl.OperandOp | pdl.ResultOp) if isinstance(pdl_operand.op, pdl.OperandOp): - if not self.match_operand(pdl_operand, pdl_operand.op, - xdsl_operand): + if not self.match_operand(pdl_operand, pdl_operand.op, xdsl_operand): return False elif isinstance(pdl_operand.op, pdl.ResultOp): - if not self.match_result(pdl_operand, pdl_operand.op, - xdsl_operand): + if not self.match_result(pdl_operand, pdl_operand.op, xdsl_operand): return False pdl_results = pdl_op.typeValues @@ -177,6 +181,7 @@ class PDLFunctions(InterpreterFunctions): to the corresponding PDL SSA values, and runs the rewrite operations one by one. The implementations in this class are for the RHS of the rewrite. """ + ctx: MLContext module: ModuleOp _rewriter: PatternRewriter | None = field(default=None) @@ -191,47 +196,52 @@ def rewriter(self, rewriter: PatternRewriter): self._rewriter = rewriter @impl(pdl.PatternOp) - def run_pattern(self, interpreter: Interpreter, op: pdl.PatternOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: + def run_pattern( + self, interpreter: Interpreter, op: pdl.PatternOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: ops = op.regions[0].ops if not len(ops): - raise InterpretationError('No ops in pattern') + raise InterpretationError("No ops in pattern") if not isinstance(ops[-1], pdl.RewriteOp): raise InterpretationError( - 'Expected pdl.pattern to be terminated by pdl.rewrite') + "Expected pdl.pattern to be terminated by pdl.rewrite" + ) for r_op in ops[:-1]: # in forward pass, the Python value is the SSA value itself if len(r_op.results) != 1: - raise InterpretationError('PDL ops must have one result') + raise InterpretationError("PDL ops must have one result") result = r_op.results[0] - interpreter.set_values(((result, r_op), )) + interpreter.set_values(((result, r_op),)) interpreter.run(ops[-1]) return () @impl(pdl.RewriteOp) - def run_rewrite(self, interpreter: Interpreter, - pdl_rewrite_op: pdl.RewriteOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: - + def run_rewrite( + self, + interpreter: Interpreter, + pdl_rewrite_op: pdl.RewriteOp, + args: tuple[Any, ...], + ) -> tuple[Any, ...]: input_module = self.module def rewrite(xdsl_op: Operation, rewriter: PatternRewriter) -> None: - pdl_op_val = pdl_rewrite_op.root - assert pdl_op_val is not None, 'TODO: handle None root op in pdl.RewriteOp' - assert pdl_rewrite_op.body is not None, 'TODO: handle None body op in pdl.RewriteOp' + assert pdl_op_val is not None, "TODO: handle None root op in pdl.RewriteOp" + assert ( + pdl_rewrite_op.body is not None + ), "TODO: handle None body op in pdl.RewriteOp" - pdl_op, = interpreter.get_values((pdl_op_val, )) + (pdl_op,) = interpreter.get_values((pdl_op_val,)) assert isinstance(pdl_op, pdl.OperationOp) matcher = PDLMatcher() if not matcher.match_operation(pdl_op_val, pdl_op, xdsl_op): return - interpreter.push_scope('rewrite') + interpreter.push_scope("rewrite") interpreter.set_values(matcher.matching_context.items()) self.rewriter = rewriter @@ -242,25 +252,26 @@ def rewrite(xdsl_op: Operation, rewriter: PatternRewriter) -> None: rewriter = AnonymousRewritePattern(rewrite) - PatternRewriteWalker( - rewriter, apply_recursively=False).rewrite_module(input_module) + PatternRewriteWalker(rewriter, apply_recursively=False).rewrite_module( + input_module + ) return () @impl(pdl.OperationOp) - def run_operation(self, interpreter: Interpreter, op: pdl.OperationOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: + def run_operation( + self, interpreter: Interpreter, op: pdl.OperationOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: assert op.opName is not None op_name = op.opName.data op_type = self.ctx.get_optional_op(op_name) if op_type is None: raise InterpretationError( - f'Could not find op type for name {op_name} in context') + f"Could not find op type for name {op_name} in context" + ) - attribute_value_names = [ - avn.data for avn in op.attributeValueNames.data - ] + attribute_value_names = [avn.data for avn in op.attributeValueNames.data] # How to deal with operand_segment_sizes? # operand_values, attribute_values, type_values = args @@ -281,35 +292,36 @@ def run_operation(self, interpreter: Interpreter, op: pdl.OperationOp, attributes = dict(zip(attribute_value_names, attribute_values)) - result_op = op_type.create(operands=operand_values, - result_types=type_values, - attributes=attributes) + result_op = op_type.create( + operands=operand_values, result_types=type_values, attributes=attributes + ) - return result_op, + return (result_op,) @impl(pdl.ReplaceOp) - def run_replace(self, interpreter: Interpreter, op: pdl.ReplaceOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: + def run_replace( + self, interpreter: Interpreter, op: pdl.ReplaceOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: rewriter = self.rewriter - old, = interpreter.get_values((op.opValue, )) + (old,) = interpreter.get_values((op.opValue,)) if op.replOperation is not None: - new_op, = interpreter.get_values((op.replOperation, )) + (new_op,) = interpreter.get_values((op.replOperation,)) rewriter.replace_op(old, new_op) elif len(op.replValues): new_vals = interpreter.get_values(op.replValues) rewriter.replace_op(old, new_ops=[], new_results=list(new_vals)) else: - assert False, 'Unexpected ReplaceOp' + assert False, "Unexpected ReplaceOp" return () @impl(ModuleOp) - def run_module(self, interpreter: Interpreter, op: ModuleOp, - args: tuple[Any, ...]) -> tuple[Any, ...]: + def run_module( + self, interpreter: Interpreter, op: ModuleOp, args: tuple[Any, ...] + ) -> tuple[Any, ...]: ops = op.ops if len(ops) != 1 or not isinstance(ops[0], pdl.PatternOp): - raise InterpretationError( - 'Expected single pattern op in pdl module') + raise InterpretationError("Expected single pattern op in pdl module") return self.run_pattern(interpreter, ops[0], args) diff --git a/xdsl/ir.py b/xdsl/ir.py index 2d6831360a..9015af81a8 100644 --- a/xdsl/ir.py +++ b/xdsl/ir.py @@ -6,8 +6,19 @@ from dataclasses import dataclass, field from io import StringIO from itertools import chain -from typing import (TYPE_CHECKING, Any, Callable, Generic, Iterable, Protocol, - Sequence, TypeVar, cast, Iterator, ClassVar) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Iterable, + Protocol, + Sequence, + TypeVar, + cast, + Iterator, + ClassVar, +) from xdsl.utils.deprecation import deprecated # Used for cyclic dependencies in type hints @@ -17,18 +28,19 @@ from xdsl.irdl import ParamAttrDef from xdsl.utils.lexer import Span -OpT = TypeVar('OpT', bound='Operation') +OpT = TypeVar("OpT", bound="Operation") @dataclass class Dialect: """Contains the operations and attributes of a specific dialect""" - _operations: list[type[Operation]] = field(default_factory=list, - init=True, - repr=True) - _attributes: list[type[Attribute]] = field(default_factory=list, - init=True, - repr=True) + + _operations: list[type[Operation]] = field( + default_factory=list, init=True, repr=True + ) + _attributes: list[type[Attribute]] = field( + default_factory=list, init=True, repr=True + ) @property def operations(self) -> Iterator[type[Operation]]: @@ -42,7 +54,8 @@ def __call__(self, ctx: MLContext) -> None: print( "Calling a dialect in order to register it is deprecated " "and will soon be removed.", - file=sys.stderr) + file=sys.stderr, + ) # TODO; Remove this function in a future release. assert isinstance(ctx, MLContext) ctx.register_dialect(self) @@ -51,6 +64,7 @@ def __call__(self, ctx: MLContext) -> None: @dataclass class MLContext: """Contains structures for operations/attributes registration.""" + _registeredOps: dict[str, type[Operation]] = field(default_factory=dict) _registeredAttrs: dict[str, type[Attribute]] = field(default_factory=dict) @@ -71,14 +85,12 @@ def register_op(self, op: type[Operation]) -> None: def register_attr(self, attr: type[Attribute]) -> None: """Register an attribute definition. Attribute names should be unique.""" if attr.name in self._registeredAttrs: - raise Exception( - f"Attribute {attr.name} has already been registered") + raise Exception(f"Attribute {attr.name} has already been registered") self._registeredAttrs[attr.name] = attr def get_optional_op( - self, - name: str, - allow_unregistered: bool = False) -> type[Operation] | None: + self, name: str, allow_unregistered: bool = False + ) -> type[Operation] | None: """ Get an operation class from its name if it exists. If the operation is not registered, return None unless @@ -88,14 +100,13 @@ def get_optional_op( return self._registeredOps[name] if allow_unregistered: from xdsl.dialects.builtin import UnregisteredOp + op_type = UnregisteredOp.with_name(name) self._registeredOps[name] = op_type return op_type return None - def get_op(self, - name: str, - allow_unregistered: bool = False) -> type[Operation]: + def get_op(self, name: str, allow_unregistered: bool = False) -> type[Operation]: """ Get an operation class from its name. If the operation is not registered, raise an exception unless @@ -106,10 +117,10 @@ def get_op(self, raise Exception(f"Operation {name} is not registered") def get_optional_attr( - self, - name: str, - allow_unregistered: bool = False, - create_unregistered_as_type: bool = False + self, + name: str, + allow_unregistered: bool = False, + create_unregistered_as_type: bool = False, ) -> type[Attribute] | None: """ Get an attribute class from its name if it exists. @@ -123,17 +134,21 @@ def get_optional_attr( return self._registeredAttrs[name] if allow_unregistered: from xdsl.dialects.builtin import UnregisteredAttr + attr_type = UnregisteredAttr.with_name_and_type( - name, create_unregistered_as_type) + name, create_unregistered_as_type + ) self._registeredAttrs[name] = attr_type return attr_type return None - def get_attr(self, - name: str, - allow_unregistered: bool = False, - create_unregistered_as_type: bool = False) -> type[Attribute]: + def get_attr( + self, + name: str, + allow_unregistered: bool = False, + create_unregistered_as_type: bool = False, + ) -> type[Attribute]: """ Get an attribute class from its name. If the attribute is not registered, raise an exception unless @@ -142,8 +157,9 @@ def get_attr(self, additional flag is required to create an UnregisterAttr that is also a type. """ - if attr_type := self.get_optional_attr(name, allow_unregistered, - create_unregistered_as_type): + if attr_type := self.get_optional_attr( + name, allow_unregistered, create_unregistered_as_type + ): return attr_type raise Exception(f"Attribute {name} is not registered") @@ -174,8 +190,7 @@ class SSAValue(ABC): _name: str | None = field(init=False, default=None) - _name_regex: ClassVar[re.Pattern[str]] = re.compile( - r'([A-Za-z_$.-][\w$.-]*)') + _name_regex: ClassVar[re.Pattern[str]] = re.compile(r"([A-Za-z_$.-][\w$.-]*)") @property @abstractmethod @@ -213,10 +228,10 @@ def get(arg: SSAValue | Operation) -> SSAValue: if isinstance(arg, Operation): if len(arg.results) == 1: return arg.results[0] - raise ValueError( - "SSAValue.build: expected operation with a single result.") + raise ValueError("SSAValue.build: expected operation with a single result.") raise TypeError( - f"Expected SSAValue or Operation for SSAValue.get, but got {arg}") + f"Expected SSAValue or Operation for SSAValue.get, but got {arg}" + ) def add_use(self, use: Use): """Add a new use of the value.""" @@ -245,7 +260,8 @@ def erase(self, safe_erase: bool = True) -> None: if safe_erase and len(self.uses) != 0: raise Exception( "Attempting to delete SSA value that still has uses of result " - f"of operation:\n{self.owner}") + f"of operation:\n{self.owner}" + ) self.replace_by(ErasedSSAValue(self.typ, self)) @@ -341,7 +357,7 @@ def __post_init__(self): ) -A = TypeVar('A', bound='Attribute') +A = TypeVar("A", bound="Attribute") class Attribute(ABC): @@ -350,6 +366,7 @@ class Attribute(ABC): Attributes are used to represent SSA variable types, and can be attached on operations to give extra information. """ + name: str = field(default="", init=False) """The attribute name should be a static field in the attribute classes.""" @@ -373,6 +390,7 @@ def verify(self) -> None: def __str__(self) -> str: from xdsl.printer import Printer + res = StringIO() printer = Printer(stream=res) printer.print_attribute(self) @@ -390,6 +408,7 @@ def __str__(self) -> str: @dataclass(frozen=True) class Data(Generic[DataElement], Attribute, ABC): """An attribute represented by a Python structure.""" + data: DataElement @classmethod @@ -428,6 +447,7 @@ def print_parameter(self, printer: Printer) -> None: @dataclass(frozen=True) class ParametrizedAttribute(Attribute): """An attribute parametrized by other attributes.""" + parameters: list[Attribute] = field(default_factory=list) @classmethod @@ -472,7 +492,6 @@ def irdl_definition(cls) -> ParamAttrDef: @dataclass class IRNode(ABC): - parent: IRNode | None def is_ancestor(self, op: IRNode) -> bool: @@ -492,7 +511,7 @@ def get_toplevel_object(self) -> IRNode: def is_structurally_equivalent( self, other: IRNode, - context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None + context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None, ) -> bool: """Check if two IR nodes are structurally equivalent.""" ... @@ -507,7 +526,7 @@ def __hash__(self) -> int: @dataclass(frozen=True) -class OpTrait(): +class OpTrait: """ A trait attached to an operation definition. Traits can be used to define operation invariants, or to specify @@ -583,16 +602,17 @@ def operands(self, new: list[SSAValue] | tuple[SSAValue, ...]): self._operands = new def __post_init__(self): - assert (self.name != "") - assert (isinstance(self.name, str)) - - def __init__(self, - operands: Sequence[SSAValue] | None = None, - result_types: Sequence[Attribute] | None = None, - attributes: dict[str, Attribute] | None = None, - successors: Sequence[Block] | None = None, - regions: Sequence[Region] | None = None) -> None: + assert self.name != "" + assert isinstance(self.name, str) + def __init__( + self, + operands: Sequence[SSAValue] | None = None, + result_types: Sequence[Attribute] | None = None, + attributes: dict[str, Attribute] | None = None, + successors: Sequence[Block] | None = None, + regions: Sequence[Region] | None = None, + ) -> None: if operands is None: operands = [] if result_types is None: @@ -620,19 +640,19 @@ def __init__(self, self.__post_init__() @classmethod - def create(cls: type[OpT], - operands: Sequence[SSAValue] | None = None, - result_types: Sequence[Attribute] | None = None, - attributes: dict[str, Attribute] | None = None, - successors: Sequence[Block] | None = None, - regions: Sequence[Region] | None = None) -> OpT: + def create( + cls: type[OpT], + operands: Sequence[SSAValue] | None = None, + result_types: Sequence[Attribute] | None = None, + attributes: dict[str, Attribute] | None = None, + successors: Sequence[Block] | None = None, + regions: Sequence[Region] | None = None, + ) -> OpT: op = cls.__new__(cls) - Operation.__init__(op, operands, result_types, attributes, successors, - regions) + Operation.__init__(op, operands, result_types, attributes, successors, regions) return op - def replace_operand(self, operand: int | SSAValue, - new_operand: SSAValue) -> None: + def replace_operand(self, operand: int | SSAValue, new_operand: SSAValue) -> None: """ Replace an operand with another operand. @@ -642,20 +662,24 @@ def replace_operand(self, operand: int | SSAValue, try: operand_idx = self._operands.index(operand) except ValueError as err: - raise ValueError("{} is not an operand of {}.".format( - operand, self)) from err + raise ValueError( + "{} is not an operand of {}.".format(operand, self) + ) from err else: operand_idx = operand - self.operands = list(self._operands[:operand_idx]) + [ - new_operand - ] + list(self._operands[operand_idx + 1:]) + self.operands = ( + list(self._operands[:operand_idx]) + + [new_operand] + + list(self._operands[operand_idx + 1 :]) + ) def add_region(self, region: Region) -> None: """Add an unattached region to the operation.""" if region.parent: raise Exception( - "Cannot add region that is already attached on an operation.") + "Cannot add region that is already attached on an operation." + ) self.regions.append(region) region.parent = self @@ -666,7 +690,9 @@ def get_region_index(self, region: Region) -> int: for idx, curr_region in enumerate(self.regions): if curr_region is region: return idx - assert False, "The IR is corrupted. Operation seems to be the region's parent but still doesn't have the region attached to it." + assert ( + False + ), "The IR is corrupted. Operation seems to be the region's parent but still doesn't have the region attached to it." def detach_region(self, region: int | Region) -> Region: """ @@ -679,8 +705,7 @@ def detach_region(self, region: int | Region) -> Region: region_idx = region region = self.regions[region_idx] region.parent = None - self.regions = self.regions[:region_idx] + self.regions[region_idx + - 1:] + self.regions = self.regions[:region_idx] + self.regions[region_idx + 1 :] return region def drop_all_references(self) -> None: @@ -717,20 +742,22 @@ def verify(self, verify_nested_ops: bool = True) -> None: def verify_(self) -> None: pass - _OperationType = TypeVar('_OperationType', bound='Operation') + _OperationType = TypeVar("_OperationType", bound="Operation") @classmethod - def parse(cls: type[_OperationType], result_types: list[Attribute], - parser: BaseParser) -> _OperationType: + def parse( + cls: type[_OperationType], result_types: list[Attribute], parser: BaseParser + ) -> _OperationType: return parser.parse_op_with_default_format(cls, result_types) def print(self, printer: Printer): return printer.print_op_with_default_format(self) def clone_without_regions( - self: OpT, - value_mapper: dict[SSAValue, SSAValue] | None = None, - block_mapper: dict[Block, Block] | None = None) -> OpT: + self: OpT, + value_mapper: dict[SSAValue, SSAValue] | None = None, + block_mapper: dict[Block, Block] | None = None, + ) -> OpT: """Clone an operation, with empty regions instead.""" if value_mapper is None: value_mapper = {} @@ -742,22 +769,27 @@ def clone_without_regions( ] result_types = [res.typ for res in self.results] attributes = self.attributes.copy() - successors = [(block_mapper[successor] - if successor in block_mapper else successor) - for successor in self.successors] + successors = [ + (block_mapper[successor] if successor in block_mapper else successor) + for successor in self.successors + ] regions = [Region() for _ in self.regions] - cloned_op = self.create(operands=operands, - result_types=result_types, - attributes=attributes, - successors=successors, - regions=regions) + cloned_op = self.create( + operands=operands, + result_types=result_types, + attributes=attributes, + successors=successors, + regions=regions, + ) for idx, result in enumerate(cloned_op.results): value_mapper[self.results[idx]] = result return cloned_op - def clone(self: OpT, - value_mapper: dict[SSAValue, SSAValue] | None = None, - block_mapper: dict[Block, Block] | None = None) -> OpT: + def clone( + self: OpT, + value_mapper: dict[SSAValue, SSAValue] | None = None, + block_mapper: dict[Block, Block] | None = None, + ) -> OpT: """Clone an operation with all its regions and operations in them.""" if value_mapper is None: value_mapper = {} @@ -782,15 +814,14 @@ def get_traits_of_type(cls, trait_type: type[OpTrait]) -> list[OpTrait]: """ return [t for t in cls.traits if isinstance(t, trait_type)] - def erase(self, - safe_erase: bool = True, - drop_references: bool = True) -> None: + def erase(self, safe_erase: bool = True, drop_references: bool = True) -> None: """ Erase the operation, and remove all its references to other operations. If safe_erase is specified, check that the operation results are not used. """ - assert self.parent is None, "Operation with parents should first be detached " + \ - "before erasure." + assert self.parent is None, ( + "Operation with parents should first be detached " + "before erasure." + ) if drop_references: self.drop_all_references() for result in self.results: @@ -805,7 +836,7 @@ def detach(self): def is_structurally_equivalent( self, other: IRNode, - context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None + context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None, ) -> bool: """ Check if two operations are structurally equivalent. @@ -819,26 +850,30 @@ def is_structurally_equivalent( return False if self.name != other.name: return False - if len(self.operands) != len(other.operands) or \ - len(self.results) != len(other.results) or \ - len(self.regions) != len(other.regions) or \ - len(self.successors) != len(other.successors) or \ - self.attributes != other.attributes: + if ( + len(self.operands) != len(other.operands) + or len(self.results) != len(other.results) + or len(self.regions) != len(other.regions) + or len(self.successors) != len(other.successors) + or self.attributes != other.attributes + ): return False - if self.parent and other.parent and context.get( - self.parent) != other.parent: + if self.parent and other.parent and context.get(self.parent) != other.parent: return False if not all( - context.get(operand) == other_operand for operand, - other_operand in zip(self.operands, other.operands)): + context.get(operand) == other_operand + for operand, other_operand in zip(self.operands, other.operands) + ): return False if not all( - context.get(successor) == other_successor for successor, - other_successor in zip(self.successors, other.successors)): + context.get(successor) == other_successor + for successor, other_successor in zip(self.successors, other.successors) + ): return False if not all( - region.is_structurally_equivalent(other_region, context) - for region, other_region in zip(self.regions, other.regions)): + region.is_structurally_equivalent(other_region, context) + for region, other_region in zip(self.regions, other.regions) + ): return False # Add results of this operation to the context for result, other_result in zip(self.results, other.results): @@ -854,6 +889,7 @@ def __hash__(self) -> int: def __str__(self) -> str: from xdsl.printer import Printer + res = StringIO() printer = Printer(stream=res, target=Printer.Target.XDSL) printer.print_op(self) @@ -861,15 +897,15 @@ def __str__(self) -> str: def __format__(self, __format_spec: str) -> str: desc = str(self) - if '\n' in desc: + if "\n" in desc: # Description is multi-line, indent each line - desc = '\n'.join('\t' + line for line in desc.splitlines()) + desc = "\n".join("\t" + line for line in desc.splitlines()) # Add newline before and after - desc = f'\n{desc}\n' - return f'{self.__class__.__qualname__}({desc})' + desc = f"\n{desc}\n" + return f"{self.__class__.__qualname__}({desc})" -OperationInvT = TypeVar('OperationInvT', bound=Operation) +OperationInvT = TypeVar("OperationInvT", bound=Operation) @dataclass(init=False) @@ -887,17 +923,19 @@ class Block(IRNode): parent: Region | None """Parent region containing the block.""" - def __init__(self, - ops: Iterable[Operation] = (), - *, - arg_types: Iterable[Attribute] = (), - parent: Region | None = None, - declared_at: Span | None = None): + def __init__( + self, + ops: Iterable[Operation] = (), + *, + arg_types: Iterable[Attribute] = (), + parent: Region | None = None, + declared_at: Span | None = None, + ): super().__init__(self) self.declared_at = declared_at self._args = tuple( - BlockArgument(typ, self, index) - for index, typ in enumerate(arg_types)) + BlockArgument(typ, self, index) for index, typ in enumerate(arg_types) + ) self.ops = [] self.parent = parent @@ -920,29 +958,27 @@ def args(self) -> tuple[BlockArgument, ...]: """Returns the block arguments.""" return self._args - @deprecated('Please use Block(arg_types=arg_types)') + @deprecated("Please use Block(arg_types=arg_types)") @staticmethod def from_arg_types(arg_types: Sequence[Attribute]) -> Block: b = Block() b._args = tuple( - BlockArgument(typ, b, index) - for index, typ in enumerate(arg_types)) + BlockArgument(typ, b, index) for index, typ in enumerate(arg_types) + ) return b - @deprecated('Please use Block(ops, arg_types=arg_types)') + @deprecated("Please use Block(ops, arg_types=arg_types)") @staticmethod - def from_ops(ops: list[Operation], - arg_types: list[Attribute] | None = None): + def from_ops(ops: list[Operation], arg_types: list[Attribute] | None = None): b = Block() if arg_types: b._args = tuple( - BlockArgument(typ, b, index) - for index, typ in enumerate(arg_types)) + BlockArgument(typ, b, index) for index, typ in enumerate(arg_types) + ) b.add_ops(ops) return b class BlockCallback(Protocol): - def __call__(self, *args: BlockArgument) -> list[Operation]: ... @@ -962,8 +998,7 @@ def insert_arg(self, typ: Attribute, index: int) -> BlockArgument: new_arg = BlockArgument(typ, self, index) for arg in self._args[index:]: arg.index += 1 - self._args = tuple( - chain(self._args[:index], [new_arg], self._args[index:])) + self._args = tuple(chain(self._args[:index], [new_arg], self._args[index:])) return new_arg def erase_arg(self, arg: BlockArgument, safe_erase: bool = True) -> None: @@ -973,12 +1008,10 @@ def erase_arg(self, arg: BlockArgument, safe_erase: bool = True) -> None: If safe_erase is False, replace the block argument uses with an ErasedSSAVAlue. """ if arg.block is not self: - raise Exception( - "Attempting to delete an argument of the wrong block") - for block_arg in self._args[arg.index + 1:]: + raise Exception("Attempting to delete an argument of the wrong block") + for block_arg in self._args[arg.index + 1 :]: block_arg.index -= 1 - self._args = tuple( - chain(self._args[:arg.index], self._args[arg.index + 1:])) + self._args = tuple(chain(self._args[: arg.index], self._args[arg.index + 1 :])) arg.erase(safe_erase=safe_erase) def _attach_op(self, operation: Operation) -> None: @@ -1024,20 +1057,21 @@ def add_ops(self, ops: Iterable[Operation]) -> None: for op in ops: self.add_op(op) - def insert_ops_before(self, ops: Sequence[Operation], - existing_op: Operation) -> None: + def insert_ops_before( + self, ops: Sequence[Operation], existing_op: Operation + ) -> None: index = self.get_operation_index(existing_op) self.insert_op(ops, index) - def insert_ops_after(self, ops: Sequence[Operation], - existing_op: Operation) -> None: + def insert_ops_after( + self, ops: Sequence[Operation], existing_op: Operation + ) -> None: index = self.get_operation_index(existing_op) self.insert_op(list(ops), index + 1) - def insert_op(self, - ops: Operation | Sequence[Operation], - index: int, - name: str | None = None) -> None: + def insert_op( + self, ops: Operation | Sequence[Operation], index: int, name: str | None = None + ) -> None: """ Insert one or multiple operations at a given index in the block. The operations should not be attached to another block. @@ -1050,7 +1084,8 @@ def insert_op(self, if index > len(self.ops): raise ValueError( f"Can't insert operation in index {index} in a block with " - f"{len(self.ops)} operations.") + f"{len(self.ops)} operations." + ) if isinstance(ops, Operation): ops = [ops] elif not isinstance(ops, list): @@ -1085,7 +1120,7 @@ def detach_op(self, op: int | Operation) -> Operation: if op.parent is not self: raise Exception("Cannot detach operation from a different block.") op.parent = None - self.ops = self.ops[:op_idx] + self.ops[op_idx + 1:] + self.ops = self.ops[:op_idx] + self.ops[op_idx + 1 :] return op def erase_op(self, op: int | Operation, safe_erase: bool = True) -> None: @@ -1124,8 +1159,9 @@ def erase(self, safe_erase: bool = True) -> None: If safe_erase is specified, check that no operation results are used outside the block. """ - assert self.parent is None, "Blocks with parents should first be detached " + \ - "before erasure." + assert self.parent is None, ( + "Blocks with parents should first be detached " + "before erasure." + ) self.drop_all_references() for op in self.ops: op.erase(safe_erase=safe_erase, drop_references=False) @@ -1133,7 +1169,7 @@ def erase(self, safe_erase: bool = True) -> None: def is_structurally_equivalent( self, other: IRNode, - context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None + context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None, ) -> bool: """ Check if two blocks are structurally equivalent. @@ -1145,8 +1181,7 @@ def is_structurally_equivalent( context = {} if not isinstance(other, Block): return False - if len(self.args) != len(other.args) or \ - len(self.ops) != len(other.ops): + if len(self.args) != len(other.args) or len(self.ops) != len(other.ops): return False for arg, other_arg in zip(self.args, other.args): if arg.typ != other_arg.typ: @@ -1155,8 +1190,9 @@ def is_structurally_equivalent( # Add self to the context so Operations can check for identical parents context[self] = other if not all( - op.is_structurally_equivalent(other_op, context) - for op, other_op in zip(self.ops, other.ops)): + op.is_structurally_equivalent(other_op, context) + for op, other_op in zip(self.ops, other.ops) + ): return False return True @@ -1178,14 +1214,14 @@ class Region(IRNode): parent: Operation | None = field(default=None, repr=False) """Operation containing the region.""" - def __init__(self, - blocks: Block | Iterable[Block] = (), - parent: Operation | None = None): + def __init__( + self, blocks: Block | Iterable[Block] = (), parent: Operation | None = None + ): super().__init__(self) self.parent = parent self.blocks = [] if isinstance(blocks, Block): - blocks = (blocks, ) + blocks = (blocks,) for block in blocks: self.add_block(block) @@ -1202,16 +1238,16 @@ def __repr__(self) -> str: return f"Region(num_blocks={len(self.blocks)})" @staticmethod - @deprecated('Please use Region([Block(ops)])') + @deprecated("Please use Region([Block(ops)])") def from_operation_list(ops: list[Operation]) -> Region: return Region([Block(ops)]) - @deprecated('Please use Region(blocks, parent=None)') + @deprecated("Please use Region(blocks, parent=None)") @staticmethod def from_block_list(blocks: list[Block]) -> Region: return Region(blocks) - @deprecated('Please use Region(blocks) or Region(Block(ops))') + @deprecated("Please use Region(blocks) or Region(Block(ops))") @staticmethod def get(arg: Region | Sequence[Block] | Sequence[Operation]) -> Region: if isinstance(arg, Region): @@ -1234,7 +1270,8 @@ def ops(self) -> list[Operation]: if len(self.blocks) != 1: raise ValueError( "'ops' property of Region class is only available " - "for single-block regions.") + "for single-block regions." + ) return self.block.ops @property @@ -1249,8 +1286,10 @@ def op(self) -> Operation: last_op = block.last_op if first_op is last_op and first_op is not None: return first_op - raise ValueError("'op' property of Region class is only available " - "for single-operation single-block regions.") + raise ValueError( + "'op' property of Region class is only available " + "for single-operation single-block regions." + ) @property def block(self) -> Block: @@ -1261,17 +1300,18 @@ def block(self) -> Block: if len(self.blocks) != 1: raise ValueError( "'block' property of Region class is only available " - "for single-block regions.") + "for single-block regions." + ) return self.blocks[0] def _attach_block(self, block: Block) -> None: """Attach a block to the region, and check that it has no parents.""" if block.parent: raise ValueError( - "Can't add to a region a block already attached to a region.") + "Can't add to a region a block already attached to a region." + ) if block.is_ancestor(self): - raise ValueError( - "Can't add a block to a region contained in the block.") + raise ValueError("Can't add a block to a region contained in the block.") block.parent = self def add_block(self, block: Block) -> None: @@ -1287,7 +1327,8 @@ def insert_block(self, blocks: Block | list[Block], index: int) -> None: if index < 0 or index > len(self.blocks): raise ValueError( f"Can't insert block in index {index} in a block with " - f"{len(self.blocks)} blocks.") + f"{len(self.blocks)} blocks." + ) if not isinstance(blocks, list): blocks = [blocks] for block in blocks: @@ -1314,7 +1355,7 @@ def detach_block(self, block: int | Block) -> Block: block_idx = block block = self.blocks[block_idx] block.parent = None - self.blocks = self.blocks[:block_idx] + self.blocks[block_idx + 1:] + self.blocks = self.blocks[:block_idx] + self.blocks[block_idx + 1 :] return block def erase_block(self, block: int | Block, safe_erase: bool = True) -> None: @@ -1325,11 +1366,13 @@ def erase_block(self, block: int | Block, safe_erase: bool = True) -> None: block = self.detach_block(block) block.erase(safe_erase=safe_erase) - def clone_into(self, - dest: Region, - insert_index: int | None = None, - value_mapper: dict[SSAValue, SSAValue] | None = None, - block_mapper: dict[Block, Block] | None = None): + def clone_into( + self, + dest: Region, + insert_index: int | None = None, + value_mapper: dict[SSAValue, SSAValue] | None = None, + block_mapper: dict[Block, Block] | None = None, + ): """ Clone all block of this region into `dest` to position `insert_index` """ @@ -1378,8 +1421,9 @@ def erase(self) -> None: """ Erase the region, and remove all its references to other operations. """ - assert self.parent, "Regions with parents should first be " + \ - "detached before erasure." + assert self.parent, ( + "Regions with parents should first be " + "detached before erasure." + ) self.drop_all_references() def move_blocks(self, region: Region) -> None: @@ -1394,7 +1438,7 @@ def move_blocks(self, region: Region) -> None: def is_structurally_equivalent( self, other: IRNode, - context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None + context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None, ) -> bool: """ Check if two regions are structurally equivalent. @@ -1413,7 +1457,8 @@ def is_structurally_equivalent( for block, other_block in zip(self.blocks, other.blocks): context[block] = other_block if not all( - block.is_structurally_equivalent(other_block, context) - for block, other_block in zip(self.blocks, other.blocks)): + block.is_structurally_equivalent(other_block, context) + for block, other_block in zip(self.blocks, other.blocks) + ): return False return True diff --git a/xdsl/irdl.py b/xdsl/irdl.py index 04e94c8fc9..1edf2ff33a 100644 --- a/xdsl/irdl.py +++ b/xdsl/irdl.py @@ -5,16 +5,41 @@ from enum import Enum from functools import reduce from inspect import isclass -from typing import (Annotated, Any, Generic, Literal, Mapping, Sequence, - TypeAlias, TypeVar, Union, cast, get_args, get_origin, - get_type_hints, overload) +from typing import ( + Annotated, + Any, + Generic, + Literal, + Mapping, + Sequence, + TypeAlias, + TypeVar, + Union, + cast, + get_args, + get_origin, + get_type_hints, + overload, +) from types import UnionType, GenericAlias, FunctionType -from xdsl.ir import (Attribute, Block, Data, OpResult, OpTrait, Operation, - ParametrizedAttribute, Region, SSAValue) +from xdsl.ir import ( + Attribute, + Block, + Data, + OpResult, + OpTrait, + Operation, + ParametrizedAttribute, + Region, + SSAValue, +) from xdsl.utils.diagnostic import Diagnostic -from xdsl.utils.exceptions import (PyRDLAttrDefinitionError, - PyRDLOpDefinitionError, VerifyException) +from xdsl.utils.exceptions import ( + PyRDLAttrDefinitionError, + PyRDLOpDefinitionError, + VerifyException, +) from xdsl.utils.hints import PropertyType # pyright: reportMissingParameterType=false, reportUnknownParameterType=false @@ -63,8 +88,7 @@ class EqAttrConstraint(AttrConstraint): def verify(self, attr: Attribute) -> None: if attr != self.attr: - raise VerifyException( - f"Expected attribute {self.attr} but got {attr}") + raise VerifyException(f"Expected attribute {self.attr} but got {attr}") @dataclass @@ -77,11 +101,13 @@ class BaseAttr(AttrConstraint): def verify(self, attr: Attribute) -> None: if not isinstance(attr, self.attr): raise VerifyException( - f"{attr} should be of base attribute {self.attr.name}") + f"{attr} should be of base attribute {self.attr.name}" + ) -def attr_constr_coercion(attr: (Attribute | type[Attribute] - | AttrConstraint)) -> AttrConstraint: +def attr_constr_coercion( + attr: (Attribute | type[Attribute] | AttrConstraint), +) -> AttrConstraint: """ Attributes are coerced into EqAttrConstraints, and Attribute types are coerced into BaseAttr. @@ -110,11 +136,10 @@ class AnyOf(AttrConstraint): attr_constrs: list[AttrConstraint] """The list of constraints that are checked.""" - def __init__(self, attr_constrs: Sequence[Attribute | type[Attribute] - | AttrConstraint]): - self.attr_constrs = [ - attr_constr_coercion(constr) for constr in attr_constrs - ] + def __init__( + self, attr_constrs: Sequence[Attribute | type[Attribute] | AttrConstraint] + ): + self.attr_constrs = [attr_constr_coercion(constr) for constr in attr_constrs] def verify(self, attr: Attribute) -> None: for attr_constr in self.attr_constrs: @@ -163,22 +188,24 @@ class ParamAttrConstraint(AttrConstraint): param_constrs: list[AttrConstraint] """The attribute parameter constraints""" - def __init__(self, base_attr: type[ParametrizedAttribute], - param_constrs: Sequence[(Attribute | type[Attribute] - | AttrConstraint)]): + def __init__( + self, + base_attr: type[ParametrizedAttribute], + param_constrs: Sequence[(Attribute | type[Attribute] | AttrConstraint)], + ): self.base_attr = base_attr - self.param_constrs = [ - attr_constr_coercion(constr) for constr in param_constrs - ] + self.param_constrs = [attr_constr_coercion(constr) for constr in param_constrs] def verify(self, attr: Attribute) -> None: if not isinstance(attr, self.base_attr): raise VerifyException( - f"{attr} should be of base attribute {self.base_attr.name}") + f"{attr} should be of base attribute {self.base_attr.name}" + ) if len(self.param_constrs) != len(attr.parameters): raise VerifyException( f"{len(self.param_constrs)} parameters expected, " - f"but got {len(attr.parameters)}") + f"but got {len(attr.parameters)}" + ) for idx, param_constr in enumerate(self.param_constrs): param_constr.verify(attr.parameters[idx]) @@ -187,7 +214,7 @@ def irdl_to_attr_constraint( irdl: Any, *, allow_type_var: bool = False, - type_var_mapping: dict[TypeVar, AttrConstraint] | None = None + type_var_mapping: dict[TypeVar, AttrConstraint] | None = None, ) -> AttrConstraint: if isinstance(irdl, AttrConstraint): return irdl @@ -205,9 +232,12 @@ def irdl_to_attr_constraint( if isinstance(arg, IRDLAnnotations): continue constraints.append( - irdl_to_attr_constraint(arg, - allow_type_var=allow_type_var, - type_var_mapping=type_var_mapping)) + irdl_to_attr_constraint( + arg, + allow_type_var=allow_type_var, + type_var_mapping=type_var_mapping, + ) + ) if len(constraints) > 1: return AllOf(constraints) return constraints[0] @@ -219,9 +249,11 @@ def irdl_to_attr_constraint( # Attribute class case # This is a coercion for an `BaseAttr`. - if isclass(irdl) \ - and not isinstance(irdl, GenericAlias) \ - and issubclass(irdl, Attribute): + if ( + isclass(irdl) + and not isinstance(irdl, GenericAlias) + and issubclass(irdl, Attribute) + ): return BaseAttr(irdl) # Type variable case @@ -233,8 +265,7 @@ def irdl_to_attr_constraint( if irdl in type_var_mapping: return type_var_mapping[irdl] if irdl.__bound__ is None: - raise Exception("Type variables used in IRDL are expected to" - " be bound.") + raise Exception("Type variables used in IRDL are expected to" " be bound.") # We do not allow nested type variables. return irdl_to_attr_constraint(irdl.__bound__) @@ -242,19 +273,21 @@ def irdl_to_attr_constraint( # GenericData case if isclass(origin) and issubclass(origin, GenericData): - return AllOf([ - BaseAttr(origin), - origin.generic_constraint_coercion(get_args(irdl)) - ]) + return AllOf( + [BaseAttr(origin), origin.generic_constraint_coercion(get_args(irdl))] + ) # Generic ParametrizedAttributes case # We translate it to constraints over the attribute parameters. - if isclass(origin) and issubclass( - origin, ParametrizedAttribute) and issubclass(origin, Generic): + if ( + isclass(origin) + and issubclass(origin, ParametrizedAttribute) + and issubclass(origin, Generic) + ): args = [ - irdl_to_attr_constraint(arg, - allow_type_var=allow_type_var, - type_var_mapping=type_var_mapping) + irdl_to_attr_constraint( + arg, allow_type_var=allow_type_var, type_var_mapping=type_var_mapping + ) for arg in get_args(irdl) ] generic_args = () @@ -266,24 +299,24 @@ def irdl_to_attr_constraint( generic_args = get_args(parent) break else: - raise Exception( - f"Cannot parametrized non-generic {origin.name} attribute.") + raise Exception(f"Cannot parametrized non-generic {origin.name} attribute.") # Check that we have the right number of parameters if len(args) != len(generic_args): - raise Exception(f"{origin.name} expects {len(generic_args)}" - f" parameters, got {len(args)}.") + raise Exception( + f"{origin.name} expects {len(generic_args)}" + f" parameters, got {len(args)}." + ) type_var_mapping = { - parameter: arg - for parameter, arg in zip(generic_args, args) + parameter: arg for parameter, arg in zip(generic_args, args) } origin_parameters = irdl_param_attr_get_param_type_hints(origin) origin_constraints = [ - irdl_to_attr_constraint(param, - allow_type_var=True, - type_var_mapping=type_var_mapping) + irdl_to_attr_constraint( + param, allow_type_var=True, type_var_mapping=type_var_mapping + ) for _, param in origin_parameters ] return ParamAttrConstraint(origin, origin_constraints) @@ -298,9 +331,12 @@ def irdl_to_attr_constraint( if isinstance(arg, IRDLAnnotations): continue constraints.append( - irdl_to_attr_constraint(arg, - allow_type_var=allow_type_var, - type_var_mapping=type_var_mapping)) + irdl_to_attr_constraint( + arg, + allow_type_var=allow_type_var, + type_var_mapping=type_var_mapping, + ) + ) if len(constraints) > 1: return AnyOf(constraints) return constraints[0] @@ -310,7 +346,8 @@ def irdl_to_attr_constraint( raise ValueError( f"Generic `Data` type '{origin.name}' cannot be converted to " "an attribute constraint. Consider making it inherit from " - "`GenericData` instead of `Data`.") + "`GenericData` instead of `Data`." + ) raise ValueError(f"Unexpected irdl constraint: {irdl}") @@ -322,24 +359,25 @@ def irdl_to_attr_constraint( # \___/| .__/ \___|_| \__,_|\__|_|\___/|_| |_| # |_| -_OpT = TypeVar('_OpT', bound='IRDLOperation') +_OpT = TypeVar("_OpT", bound="IRDLOperation") class IRDLOperation(Operation): - def __init__( self: IRDLOperation, - operands: Sequence[SSAValue | Operation - | Sequence[SSAValue | Operation] | None] - | None = None, - result_types: Sequence[Attribute | Sequence[Attribute]] + operands: Sequence[SSAValue | Operation | Sequence[SSAValue | Operation] | None] | None = None, + result_types: Sequence[Attribute | Sequence[Attribute]] | None = None, attributes: Mapping[str, Attribute | None] | None = None, successors: Sequence[Block] | None = None, - regions: Sequence[Region | Sequence[Operation] | Sequence[Block] - | Sequence[Region | Sequence[Operation] - | Sequence[Block]]] - | None = None): + regions: Sequence[ + Region + | Sequence[Operation] + | Sequence[Block] + | Sequence[Region | Sequence[Operation] | Sequence[Block]] + ] + | None = None, + ): if operands is None: operands = [] if result_types is None: @@ -350,32 +388,42 @@ def __init__( successors = [] if regions is None: regions = [] - irdl_op_init(self, self.irdl_definition, operands, result_types, - attributes, successors, regions) + irdl_op_init( + self, + self.irdl_definition, + operands, + result_types, + attributes, + successors, + regions, + ) @classmethod def build( cls: type[_OpT], - operands: Sequence[SSAValue | Operation - | Sequence[SSAValue | Operation] | None] - | None = None, - result_types: Sequence[Attribute | Sequence[Attribute]] + operands: Sequence[SSAValue | Operation | Sequence[SSAValue | Operation] | None] | None = None, + result_types: Sequence[Attribute | Sequence[Attribute]] | None = None, attributes: Mapping[str, Attribute | None] | None = None, successors: Sequence[Block] | None = None, - regions: Sequence[Region | Sequence[Operation] | Sequence[Block] - | Sequence[Region | Sequence[Operation] - | Sequence[Block]]] - | None = None + regions: Sequence[ + Region + | Sequence[Operation] + | Sequence[Block] + | Sequence[Region | Sequence[Operation] | Sequence[Block]] + ] + | None = None, ) -> _OpT: """Create a new operation using builders.""" op = cls.__new__(cls) - IRDLOperation.__init__(op, - operands=operands, - result_types=result_types, - attributes=attributes, - successors=successors, - regions=regions) + IRDLOperation.__init__( + op, + operands=operands, + result_types=result_types, + attributes=attributes, + successors=successors, + regions=regions, + ) return op @classmethod @@ -388,6 +436,7 @@ def irdl_definition(cls) -> OpDef: @dataclass class IRDLOption(ABC): """Additional option used in IRDL.""" + ... @@ -427,18 +476,21 @@ class AttrSizedRegionSegments(IRDLOption): @dataclass class OperandOrResultDef(ABC): """An operand or a result definition. Should not be used directly.""" + ... @dataclass class VariadicDef(OperandOrResultDef): """A variadic operand or result definition. Should not be used directly.""" + ... @dataclass class OptionalDef(VariadicDef): """An optional operand or result definition. Should not be used directly.""" + ... @@ -533,12 +585,13 @@ class OptSingleBlockRegionDef(RegionDef, OptionalDef): """An IRDL optional region definition that expects exactly one block.""" -SingleBlockRegion: TypeAlias = Annotated[ - Region, IRDLAnnotations.SingleBlockRegionAnnot] +SingleBlockRegion: TypeAlias = Annotated[Region, IRDLAnnotations.SingleBlockRegionAnnot] VarSingleBlockRegion: TypeAlias = Annotated[ - list[Region], IRDLAnnotations.SingleBlockRegionAnnot] + list[Region], IRDLAnnotations.SingleBlockRegionAnnot +] OptSingleBlockRegion: TypeAlias = Annotated[ - Region | None, IRDLAnnotations.SingleBlockRegionAnnot] + Region | None, IRDLAnnotations.SingleBlockRegionAnnot +] @dataclass(init=False) @@ -563,13 +616,13 @@ def __init__(self, typ: Attribute | type[Attribute] | AttrConstraint): _OpAttrT = TypeVar("_OpAttrT", bound=Attribute) OpAttr: TypeAlias = Annotated[_OpAttrT, IRDLAnnotations.AttributeDefAnnot] -OptOpAttr: TypeAlias = Annotated[_OpAttrT | None, - IRDLAnnotations.OptAttributeDefAnnot] +OptOpAttr: TypeAlias = Annotated[_OpAttrT | None, IRDLAnnotations.OptAttributeDefAnnot] @dataclass(kw_only=True) class OpDef: """The internal IRDL definition of an operation.""" + name: str = field(kw_only=False) operands: list[tuple[str, OperandDef]] = field(default_factory=list) results: list[tuple[str, ResultDef]] = field(default_factory=list) @@ -602,7 +655,8 @@ def wrong_field_exception(field_name: str) -> PyRDLOpDefinitionError: "Annotated[Operand, ], results with " "Annotated[OpResult, ], regions with " "Region, and attributes with " - "OpAttr[]") + "OpAttr[]" + ) # Check that all fields of the operation definition are either already # in Operation, or are class functions or methods. @@ -612,8 +666,8 @@ def wrong_field_exception(field_name: str) -> PyRDLOpDefinitionError: if field_name in ["irdl_options", "traits"]: continue if isinstance( - value, - (FunctionType, PropertyType, classmethod, staticmethod)): + value, (FunctionType, PropertyType, classmethod, staticmethod) + ): continue raise wrong_field_exception(field_name) @@ -621,12 +675,12 @@ def wrong_field_exception(field_name: str) -> PyRDLOpDefinitionError: raise Exception( f"pyrdl operation definition '{pyrdl_def.__name__}' does not " "define the operation name. The operation name is defined by " - "adding a 'name' field.") + "adding a 'name' field." + ) op_def = OpDef(clsdict["name"]) for field_name, field_type in type_hints.items(): - if field_name in get_type_hints(Operation).keys(): continue @@ -637,16 +691,15 @@ def wrong_field_exception(field_name: str) -> PyRDLOpDefinitionError: origin: Any | None = cast(Any | None, get_origin(field_type)) args: tuple[Any, ...] if origin is None: - args = (field_type, ) + args = (field_type,) elif origin == Annotated: args = get_args(field_type) else: - args = (field_type, ) + args = (field_type,) args = cast(tuple[Any, ...], args) # Get attribute constraints from a list of pyrdl constraints - def get_constraint( - pyrdl_constrs: tuple[Any, ...]) -> AttrConstraint: + def get_constraint(pyrdl_constrs: tuple[Any, ...]) -> AttrConstraint: constraints = [ irdl_to_attr_constraint(pyrdl_constr) for pyrdl_constr in pyrdl_constrs @@ -699,30 +752,24 @@ def get_constraint( op_def.attributes[field_name] = AttributeDef(constraint) elif IRDLAnnotations.OptAttributeDefAnnot in args: assert get_origin(args[0]) in [UnionType, Union] - args = (reduce(lambda x, y: x | y, - get_args(args[0])[:-1]), *args[1:]) + args = (reduce(lambda x, y: x | y, get_args(args[0])[:-1]), *args[1:]) constraint = get_constraint(args) op_def.attributes[field_name] = OptAttributeDef(constraint) # Region annotation elif args[0] == Region: - if (len(args) > 1 - and args[1] == IRDLAnnotations.SingleBlockRegionAnnot): + if len(args) > 1 and args[1] == IRDLAnnotations.SingleBlockRegionAnnot: op_def.regions.append((field_name, SingleBlockRegionDef())) else: op_def.regions.append((field_name, RegionDef())) elif args[0] == VarRegion: - if (len(args) > 1 - and args[1] == IRDLAnnotations.SingleBlockRegionAnnot): - op_def.regions.append( - (field_name, VarSingleBlockRegionDef())) + if len(args) > 1 and args[1] == IRDLAnnotations.SingleBlockRegionAnnot: + op_def.regions.append((field_name, VarSingleBlockRegionDef())) else: op_def.regions.append((field_name, VarRegionDef())) elif args[0] == OptRegion: - if (len(args) > 1 - and args[1] == IRDLAnnotations.SingleBlockRegionAnnot): - op_def.regions.append( - (field_name, OptSingleBlockRegionDef())) + if len(args) > 1 and args[1] == IRDLAnnotations.SingleBlockRegionAnnot: + op_def.regions.append((field_name, OptSingleBlockRegionDef())) else: op_def.regions.append((field_name, OptRegionDef())) else: @@ -734,7 +781,8 @@ def get_constraint( raise Exception( f"pyrdl operation definition '{pyrdl_def.__name__}' " f"has a 'traits' field of type {type(traits)}, but " - "it should be of type frozenset.") + "it should be of type frozenset." + ) traits = cast(frozenset[OpTrait], traits) op_def.traits = traits @@ -770,6 +818,7 @@ class VarIRConstruct(Enum): An enum representing the part of an IR that may be variadic. This contains operands, results, and regions. """ + OPERAND = 1 RESULT = 2 REGION = 3 @@ -788,8 +837,11 @@ def get_construct_name(construct: VarIRConstruct) -> str: def get_construct_defs( op_def: OpDef, construct: VarIRConstruct -) -> list[tuple[str, OperandDef]] | list[tuple[str, ResultDef]] | list[tuple[ - str, RegionDef]]: +) -> ( + list[tuple[str, OperandDef]] + | list[tuple[str, ResultDef]] + | list[tuple[str, RegionDef]] +): """Get the definitions of this type in an operation definition.""" if construct == VarIRConstruct.OPERAND: return op_def.operands @@ -804,10 +856,10 @@ def get_op_constructs( op: Operation, construct: VarIRConstruct ) -> tuple[SSAValue, ...] | list[OpResult] | list[Region]: """ - Get the list of arguments of the type in an operation. - For example, if the argument type is an operand, get the list of - operands. - """ + Get the list of arguments of the type in an operation. + For example, if the argument type is an operand, get the list of + operands. + """ if construct == VarIRConstruct.OPERAND: return op.operands if construct == VarIRConstruct.RESULT: @@ -818,7 +870,7 @@ def get_op_constructs( def get_attr_size_option( - construct: VarIRConstruct + construct: VarIRConstruct, ) -> AttrSizedOperandSegments | AttrSizedResultSegments | AttrSizedRegionSegments: """Get the AttrSized option for this type.""" if construct == VarIRConstruct.OPERAND: @@ -830,12 +882,12 @@ def get_attr_size_option( assert False, "Unknown VarIRConstruct value" -def get_variadic_sizes_from_attr(op: Operation, - defs: Sequence[tuple[str, - OperandDef | ResultDef - | RegionDef]], - construct: VarIRConstruct, - size_attribute_name: str) -> list[int]: +def get_variadic_sizes_from_attr( + op: Operation, + defs: Sequence[tuple[str, OperandDef | ResultDef | RegionDef]], + construct: VarIRConstruct, + size_attribute_name: str, +) -> list[int]: """ Get the sizes of the variadic definitions from the corresponding attribute. @@ -850,28 +902,31 @@ def get_variadic_sizes_from_attr(op: Operation, ) attribute = op.attributes[size_attribute_name] if not isinstance(attribute, DenseArrayBase): - raise VerifyException(f"{size_attribute_name} attribute is expected " - "to be a DenseArrayBase.") + raise VerifyException( + f"{size_attribute_name} attribute is expected " "to be a DenseArrayBase." + ) if attribute.elt_type != i32: raise VerifyException( f"{size_attribute_name} attribute is expected to " - "be a DenseArrayBase of i32") - def_sizes = cast(list[int], - [size_attr.data for size_attr in attribute.data.data]) + "be a DenseArrayBase of i32" + ) + def_sizes = cast(list[int], [size_attr.data for size_attr in attribute.data.data]) if len(def_sizes) != len(defs): raise VerifyException( f"expected {len(defs)} values in " - f"{size_attribute_name}, but got {len(def_sizes)}") + f"{size_attribute_name}, but got {len(def_sizes)}" + ) variadic_sizes = list[int]() - for ((arg_name, arg_def), arg_size) in zip(defs, def_sizes): + for (arg_name, arg_def), arg_size in zip(defs, def_sizes): if isinstance(arg_def, OptionalDef) and arg_size > 1: raise VerifyException( f"optional {get_construct_name(construct)} {arg_name} is expected to " f"be of size 0 or 1 in {size_attribute_name}, but got " - f"{arg_size}") + f"{arg_size}" + ) if not isinstance(arg_def, VariadicDef) and arg_size != 1: raise VerifyException( @@ -885,8 +940,9 @@ def get_variadic_sizes_from_attr(op: Operation, return variadic_sizes -def get_variadic_sizes(op: Operation, op_def: OpDef, - construct: VarIRConstruct) -> list[int]: +def get_variadic_sizes( + op: Operation, op_def: OpDef, construct: VarIRConstruct +) -> list[int]: """Get variadic sizes of operands or results.""" defs = get_construct_defs(op_def, construct) @@ -894,28 +950,35 @@ def get_variadic_sizes(op: Operation, op_def: OpDef, def_type_name = get_construct_name(construct) attribute_option = get_attr_size_option(construct) - variadic_defs = [(arg_name, arg_def) for arg_name, arg_def in defs - if isinstance(arg_def, VariadicDef)] + variadic_defs = [ + (arg_name, arg_def) + for arg_name, arg_def in defs + if isinstance(arg_def, VariadicDef) + ] # If the size is in the attributes, fetch it if attribute_option in op_def.options: - return get_variadic_sizes_from_attr(op, defs, construct, - attribute_option.attribute_name) + return get_variadic_sizes_from_attr( + op, defs, construct, attribute_option.attribute_name + ) # If there are no variadics arguments, # we just check that we have the right number of arguments if len(variadic_defs) == 0: if len(args) != len(defs): raise VerifyException( - f"Expected {len(defs)} {def_type_name}, but got {len(args)}") + f"Expected {len(defs)} {def_type_name}, but got {len(args)}" + ) return [] # If there is a single variadic argument, # we can get its size from the number of arguments. if len(variadic_defs) == 1: if len(args) - len(defs) + 1 < 0: - raise VerifyException(f"Expected at least {len(defs) - 1} " - f"{def_type_name}s, got {len(defs)}") + raise VerifyException( + f"Expected at least {len(defs) - 1} " + f"{def_type_name}s, got {len(defs)}" + ) return [len(args) - len(defs) + 1] # Unreachable, all cases should have been handled. @@ -925,10 +988,12 @@ def get_variadic_sizes(op: Operation, op_def: OpDef, def get_operand_result_or_region( - op: Operation, op_def: OpDef, arg_def_idx: int, previous_var_args: int, - construct: VarIRConstruct -) -> None | SSAValue | tuple[SSAValue, - ...] | list[OpResult] | Region | list[Region]: + op: Operation, + op_def: OpDef, + arg_def_idx: int, + previous_var_args: int, + construct: VarIRConstruct, +) -> None | SSAValue | tuple[SSAValue, ...] | list[OpResult] | Region | list[Region]: """ Get an operand, result, or region. In the case of a variadic definition, return a list of elements. @@ -945,8 +1010,9 @@ def get_operand_result_or_region( variadic_sizes = get_variadic_sizes(op, op_def, construct) - begin_arg = arg_def_idx - previous_var_args + sum( - variadic_sizes[:previous_var_args]) + begin_arg = ( + arg_def_idx - previous_var_args + sum(variadic_sizes[:previous_var_args]) + ) if isinstance(defs[arg_def_idx][1], OptionalDef): arg_size = variadic_sizes[previous_var_args] if arg_size == 0: @@ -955,13 +1021,14 @@ def get_operand_result_or_region( return args[begin_arg] if isinstance(defs[arg_def_idx][1], VariadicDef): arg_size = variadic_sizes[previous_var_args] - return args[begin_arg:begin_arg + arg_size] + return args[begin_arg : begin_arg + arg_size] else: return args[begin_arg] -def irdl_op_verify_arg_list(op: Operation, op_def: OpDef, - construct: VarIRConstruct) -> None: +def irdl_op_verify_arg_list( + op: Operation, op_def: OpDef, construct: VarIRConstruct +) -> None: """Verify the argument list of an operation.""" arg_sizes = get_variadic_sizes(op, op_def, construct) arg_idx = 0 @@ -971,22 +1038,26 @@ def irdl_op_verify_arg_list(op: Operation, op_def: OpDef, def verify_arg(arg: Any, arg_def: Any, arg_idx: int) -> None: """Verify a single argument.""" try: - if construct == VarIRConstruct.OPERAND or construct == VarIRConstruct.RESULT: + if ( + construct == VarIRConstruct.OPERAND + or construct == VarIRConstruct.RESULT + ): arg_def.constr.verify(arg.typ) elif construct == VarIRConstruct.REGION: - if isinstance(arg_def, SingleBlockRegionDef) and len( - arg.blocks) != 1: - raise VerifyException("expected a single block, but got " - f"{len(arg.blocks)} blocks") + if isinstance(arg_def, SingleBlockRegionDef) and len(arg.blocks) != 1: + raise VerifyException( + "expected a single block, but got " f"{len(arg.blocks)} blocks" + ) else: assert False, "Unknown ArgType value" except Exception as e: error( - op, f"{get_construct_name(construct)} at position " - f"{arg_idx} does not verify!\n{e}") + op, + f"{get_construct_name(construct)} at position " + f"{arg_idx} does not verify!\n{e}", + ) - for def_idx, (_, - arg_def) in enumerate(get_construct_defs(op_def, construct)): + for def_idx, (_, arg_def) in enumerate(get_construct_defs(op_def, construct)): if isinstance(arg_def, VariadicDef): for _ in range(arg_sizes[var_idx]): verify_arg(args[arg_idx], arg_def, def_idx) @@ -998,43 +1069,51 @@ def verify_arg(arg: Any, arg_def: Any, arg_idx: int) -> None: @overload -def irdl_build_arg_list(construct: Literal[VarIRConstruct.OPERAND], - args: Sequence[SSAValue | Sequence[SSAValue] | None], - arg_defs: Sequence[tuple[str, OperandDef]], - error_prefix: str) -> tuple[list[SSAValue], list[int]]: +def irdl_build_arg_list( + construct: Literal[VarIRConstruct.OPERAND], + args: Sequence[SSAValue | Sequence[SSAValue] | None], + arg_defs: Sequence[tuple[str, OperandDef]], + error_prefix: str, +) -> tuple[list[SSAValue], list[int]]: ... @overload def irdl_build_arg_list( - construct: Literal[VarIRConstruct.RESULT], - args: Sequence[Attribute | Sequence[Attribute] | None], - arg_defs: Sequence[tuple[str, ResultDef]], - error_prefix: str) -> tuple[list[Attribute], list[int]]: + construct: Literal[VarIRConstruct.RESULT], + args: Sequence[Attribute | Sequence[Attribute] | None], + arg_defs: Sequence[tuple[str, ResultDef]], + error_prefix: str, +) -> tuple[list[Attribute], list[int]]: ... @overload -def irdl_build_arg_list(construct: Literal[VarIRConstruct.REGION], - args: Sequence[Region | Sequence[Region] | None], - arg_defs: Sequence[tuple[str, RegionDef]], - error_prefix: str) -> tuple[list[Region], list[int]]: +def irdl_build_arg_list( + construct: Literal[VarIRConstruct.REGION], + args: Sequence[Region | Sequence[Region] | None], + arg_defs: Sequence[tuple[str, RegionDef]], + error_prefix: str, +) -> tuple[list[Region], list[int]]: ... -_T = TypeVar('_T') +_T = TypeVar("_T") -def irdl_build_arg_list(construct: VarIRConstruct, - args: Sequence[_T | Sequence[_T] | None], - arg_defs: Sequence[tuple[str, Any]], - error_prefix: str = "") -> tuple[list[_T], list[int]]: +def irdl_build_arg_list( + construct: VarIRConstruct, + args: Sequence[_T | Sequence[_T] | None], + arg_defs: Sequence[tuple[str, Any]], + error_prefix: str = "", +) -> tuple[list[_T], list[int]]: """Build a list of arguments (operands, results, regions)""" if len(args) != len(arg_defs): raise ValueError( f"Expected {len(arg_defs)} {get_construct_name(construct)}, " - f"but got {len(args)}") + f"but got {len(args)}" + ) res = list[_T]() arg_sizes = list[int]() @@ -1043,24 +1122,24 @@ def irdl_build_arg_list(construct: VarIRConstruct, if arg is None: if not isinstance(arg_def, OptionalDef): raise ValueError( - error_prefix + - f"passed None to a non-optional {construct} {arg_idx} '{arg_name}'" + error_prefix + + f"passed None to a non-optional {construct} {arg_idx} '{arg_name}'" ) elif isinstance(arg, Sequence): if not isinstance(arg_def, VariadicDef): raise ValueError( - error_prefix + - f"passed Sequence to non-variadic {construct} {arg_idx} '{arg_name}'" + error_prefix + + f"passed Sequence to non-variadic {construct} {arg_idx} '{arg_name}'" ) arg = cast(Sequence[_T], arg) # Check we have at most one argument for optional defintions. if isinstance(arg_def, OptionalDef) and len(arg) > 1: raise ValueError( - error_prefix + - f"optional {construct} {arg_idx} '{arg_name}' " + error_prefix + f"optional {construct} {arg_idx} '{arg_name}' " "expects a list of size at most 1 or None, but " - f"got a list of size {len(arg)}") + f"got a list of size {len(arg)}" + ) res.extend(arg) arg_sizes.append(len(arg)) @@ -1074,7 +1153,7 @@ def irdl_build_arg_list(construct: VarIRConstruct, def irdl_build_operations_arg( - operand: _OperandArg | Sequence[_OperandArg] | None + operand: _OperandArg | Sequence[_OperandArg] | None, ) -> SSAValue | list[SSAValue]: if operand is None: return [] @@ -1103,8 +1182,9 @@ def irdl_build_region_arg(r: _RegionArg) -> Region: return Region(cast(Sequence[Block], r)) -def irdl_build_regions_arg(r: _RegionArg | Sequence[_RegionArg] - | None) -> Region | list[Region]: +def irdl_build_regions_arg( + r: _RegionArg | Sequence[_RegionArg] | None, +) -> Region | list[Region]: if r is None: return [] elif isinstance(r, Region): @@ -1118,30 +1198,30 @@ def irdl_build_regions_arg(r: _RegionArg | Sequence[_RegionArg] blocks = cast(Sequence[Block], r) return Region(blocks) else: - return [ - irdl_build_region_arg(_r) for _r in cast(Sequence[_RegionArg], r) - ] + return [irdl_build_region_arg(_r) for _r in cast(Sequence[_RegionArg], r)] def irdl_op_init( self: IRDLOperation, op_def: OpDef, - operands: Sequence[SSAValue | Operation - | Sequence[SSAValue | Operation] - | None], + operands: Sequence[SSAValue | Operation | Sequence[SSAValue | Operation] | None], res_types: Sequence[Attribute | Sequence[Attribute] | None], attributes: Mapping[str, Attribute | None], successors: Sequence[Block], - regions: Sequence[Region | Sequence[Operation] | Sequence[Block] - | Sequence[Region | Sequence[Operation] - | Sequence[Block]] | None], + regions: Sequence[ + Region + | Sequence[Operation] + | Sequence[Block] + | Sequence[Region | Sequence[Operation] | Sequence[Block]] + | None + ], ): """Builder for an irdl operation.""" # We need irdl to define DenseArrayBase, but here we need # DenseArrayBase. # So we have a circular dependency that we solve by importing in this function. - from xdsl.dialects.builtin import (DenseArrayBase, i32) + from xdsl.dialects.builtin import DenseArrayBase, i32 error_prefix = f"Error in {op_def.name} builder: " @@ -1151,60 +1231,67 @@ def irdl_op_init( # Build the operands built_operands, operand_sizes = irdl_build_arg_list( - VarIRConstruct.OPERAND, operands_arg, op_def.operands, error_prefix) + VarIRConstruct.OPERAND, operands_arg, op_def.operands, error_prefix + ) # Build the results built_res_types, result_sizes = irdl_build_arg_list( - VarIRConstruct.RESULT, res_types, op_def.results, error_prefix) + VarIRConstruct.RESULT, res_types, op_def.results, error_prefix + ) # Build the regions - built_regions, region_sizes = irdl_build_arg_list(VarIRConstruct.REGION, - regions_arg, - op_def.regions, - error_prefix) + built_regions, region_sizes = irdl_build_arg_list( + VarIRConstruct.REGION, regions_arg, op_def.regions, error_prefix + ) built_attributes = dict[str, Attribute]() for attr_name, attr in attributes.items(): if attr is None: continue if not isinstance(attr, Attribute): - raise ValueError(error_prefix + - f"{attr_name} is expected to be an " - f"attribute, but got {type(attr)}.") + raise ValueError( + error_prefix + f"{attr_name} is expected to be an " + f"attribute, but got {type(attr)}." + ) built_attributes[attr_name] = attr # Take care of variadic operand and result segment sizes. if AttrSizedOperandSegments() in op_def.options: - built_attributes[AttrSizedOperandSegments.attribute_name] =\ - DenseArrayBase.from_list(i32, operand_sizes) + built_attributes[ + AttrSizedOperandSegments.attribute_name + ] = DenseArrayBase.from_list(i32, operand_sizes) if AttrSizedResultSegments() in op_def.options: - built_attributes[AttrSizedResultSegments.attribute_name] =\ - DenseArrayBase.from_list(i32, result_sizes) + built_attributes[ + AttrSizedResultSegments.attribute_name + ] = DenseArrayBase.from_list(i32, result_sizes) if AttrSizedRegionSegments() in op_def.options: - built_attributes[AttrSizedRegionSegments.attribute_name] =\ - DenseArrayBase.from_list(i32, region_sizes) - - Operation.__init__(self, - operands=built_operands, - result_types=built_res_types, - attributes=built_attributes, - successors=successors, - regions=built_regions) + built_attributes[ + AttrSizedRegionSegments.attribute_name + ] = DenseArrayBase.from_list(i32, region_sizes) + + Operation.__init__( + self, + operands=built_operands, + result_types=built_res_types, + attributes=built_attributes, + successors=successors, + regions=built_regions, + ) -def irdl_op_arg_definition(new_attrs: dict[str, Any], - construct: VarIRConstruct, op_def: OpDef) -> None: +def irdl_op_arg_definition( + new_attrs: dict[str, Any], construct: VarIRConstruct, op_def: OpDef +) -> None: previous_variadics = 0 defs = get_construct_defs(op_def, construct) for arg_idx, (arg_name, arg_def) in enumerate(defs): - def fun(self: Any, - idx: int = arg_idx, - previous_vars: int = previous_variadics): - return get_operand_result_or_region(self, op_def, idx, - previous_vars, construct) + def fun(self: Any, idx: int = arg_idx, previous_vars: int = previous_variadics): + return get_operand_result_or_region( + self, op_def, idx, previous_vars, construct + ) new_attrs[arg_name] = property(fun) if isinstance(arg_def, VariadicDef): @@ -1218,7 +1305,8 @@ def fun(self: Any, raise Exception( "Operation defines more than two variadic " f"{get_construct_name(construct)}s, but do not define the " - f"{arg_size_option_name} PyRDL option.") + f"{arg_size_option_name} PyRDL option." + ) def irdl_op_definition(cls: type[_OpT]) -> type[_OpT]: @@ -1246,7 +1334,6 @@ def irdl_op_definition(cls: type[_OpT]) -> type[_OpT]: irdl_op_arg_definition(new_attrs, VarIRConstruct.REGION, op_def) def optional_attribute_field(attribute_name: str): - def field_getter(self: _OpT): return self.attributes.get(attribute_name, None) @@ -1259,7 +1346,6 @@ def field_setter(self: _OpT, value: Attribute | None): return property(field_getter, field_setter) def attribute_field(attribute_name: str): - def field_getter(self: _OpT): return self.attributes[attribute_name] @@ -1270,8 +1356,7 @@ def field_setter(self: _OpT, value: Attribute): for attribute_name, attr_def in op_def.attributes.items(): if isinstance(attr_def, OptAttributeDef): - new_attrs[attribute_name] = optional_attribute_field( - attribute_name) + new_attrs[attribute_name] = optional_attribute_field(attribute_name) else: new_attrs[attribute_name] = attribute_field(attribute_name) @@ -1284,7 +1369,7 @@ def irdl_definition(cls: type[_OpT]): new_attrs["irdl_definition"] = irdl_definition - custom_verify = getattr(cls, 'verify_') + custom_verify = getattr(cls, "verify_") def verify_(self: _OpT): op_def.verify(self) @@ -1292,10 +1377,9 @@ def verify_(self: _OpT): new_attrs["verify_"] = verify_ - return type(cls.__name__, cls.__mro__, { - **cls.__dict__, - **new_attrs - }) # type: ignore + return type( + cls.__name__, cls.__mro__, {**cls.__dict__, **new_attrs} + ) # type: ignore # ____ _ @@ -1335,7 +1419,7 @@ def irdl_data_verify(data: Data[_DT], typ: type[_DT]) -> None: ) -T = TypeVar('T', bound=Data[Any]) +T = TypeVar("T", bound=Data[Any]) def irdl_data_definition(cls: type[T]) -> type[T]: @@ -1343,7 +1427,6 @@ def irdl_data_definition(cls: type[T]) -> type[T]: new_attrs = dict[str, Any]() def verify(expected_type: type[Any]): - def impl(self: T): return irdl_data_verify(self, expected_type) @@ -1356,28 +1439,35 @@ def impl(self: T): if get_origin(parent) != Data: continue if len(get_args(parent)) != 1: - raise Exception(f"In {cls.__name__} definition: Data expects " - "a single type parameter") + raise Exception( + f"In {cls.__name__} definition: Data expects " + "a single type parameter" + ) expected_type = get_args(parent)[0] if not isclass(expected_type): - raise Exception(f'In {cls.__name__} definition: Cannot infer ' - f'"verify" method. Type parameter of Data is ' - f'not a class.') + raise Exception( + f"In {cls.__name__} definition: Cannot infer " + f'"verify" method. Type parameter of Data is ' + f"not a class." + ) if isinstance(expected_type, GenericAlias): - raise Exception(f'In {cls.__name__} definition: Cannot infer ' - f'"verify" method. Type parameter of Data has ' - f'type GenericAlias.') + raise Exception( + f"In {cls.__name__} definition: Cannot infer " + f'"verify" method. Type parameter of Data has ' + f"type GenericAlias." + ) new_attrs["verify"] = verify(expected_type) break else: - raise Exception(f'Missing method "verify" in {cls.__name__} data ' - 'attribute definition: the "verify" method cannot ' - 'be automatically derived for this definition.') + raise Exception( + f'Missing method "verify" in {cls.__name__} data ' + 'attribute definition: the "verify" method cannot ' + "be automatically derived for this definition." + ) - return dataclass(frozen=True)(type(cls.__name__, (cls, ), { - **cls.__dict__, - **new_attrs - })) # type: ignore + return dataclass(frozen=True)( + type(cls.__name__, (cls,), {**cls.__dict__, **new_attrs}) + ) # type: ignore # ____ _ _ _ @@ -1392,12 +1482,10 @@ def impl(self: T): ParameterDef: TypeAlias = Annotated[_A, IRDLAnnotations.ParamDefAnnot] -def irdl_param_attr_get_param_type_hints( - cls: type[_PAttrT]) -> list[tuple[str, Any]]: +def irdl_param_attr_get_param_type_hints(cls: type[_PAttrT]) -> list[tuple[str, Any]]: """Get the type hints of an IRDL parameter definitions.""" res = list[tuple[str, Any]]() - for field_name, field_type in get_type_hints(cls, - include_extras=True).items(): + for field_name, field_type in get_type_hints(cls, include_extras=True).items(): if field_name == "name" or field_name == "parameters": continue @@ -1405,9 +1493,10 @@ def irdl_param_attr_get_param_type_hints( args = get_args(field_type) if origin != Annotated or IRDLAnnotations.ParamDefAnnot not in args: raise PyRDLAttrDefinitionError( - f"In attribute {cls.__name__} definition: Parameter " + - f"definition {field_name} should be defined with " + - f"type `ParameterDef[]`, got type {field_type}.") + f"In attribute {cls.__name__} definition: Parameter " + + f"definition {field_name} should be defined with " + + f"type `ParameterDef[]`, got type {field_type}." + ) res.append((field_name, field_type)) return res @@ -1416,6 +1505,7 @@ def irdl_param_attr_get_param_type_hints( @dataclass class ParamAttrDef: """The IRDL definition of a parametrized attribute.""" + name: str parameters: list[tuple[str, AttrConstraint]] @@ -1439,17 +1529,19 @@ def from_pyrdl(pyrdl_def: type[ParametrizedAttribute]) -> ParamAttrDef: if field_name in attrdict: continue if isinstance( - value, - (FunctionType, PropertyType, classmethod, staticmethod)): + value, (FunctionType, PropertyType, classmethod, staticmethod) + ): continue raise PyRDLAttrDefinitionError( - f"{field_name} is not a parameter definition.") + f"{field_name} is not a parameter definition." + ) if "name" not in clsdict: raise Exception( f"pyrdl attribute definition '{pyrdl_def.__name__}' does not " "define the attribute name. The attribute name is defined by " - "adding a 'name' field.") + "adding a 'name' field." + ) name = clsdict["name"] @@ -1457,8 +1549,7 @@ def from_pyrdl(pyrdl_def: type[ParametrizedAttribute]) -> ParamAttrDef: parameters = list[tuple[str, AttrConstraint]]() for param_name, param_type in param_hints: - constraint = irdl_to_attr_constraint(param_type, - allow_type_var=True) + constraint = irdl_to_attr_constraint(param_type, allow_type_var=True) parameters.append((param_name, constraint)) return ParamAttrDef(name, parameters) @@ -1470,13 +1561,14 @@ def verify(self, attr: ParametrizedAttribute): raise VerifyException( f"In {self.name} attribute verifier: " f"{len(self.parameters)} parameters expected, got " - f"{len(attr.parameters)}") + f"{len(attr.parameters)}" + ) for param, (_, param_def) in zip(attr.parameters, self.parameters): param_def.verify(param) -_PAttrT = TypeVar('_PAttrT', bound=ParametrizedAttribute) +_PAttrT = TypeVar("_PAttrT", bound=ParametrizedAttribute) def irdl_param_attr_definition(cls: type[_PAttrT]) -> type[_PAttrT]: @@ -1493,7 +1585,6 @@ def irdl_param_attr_definition(cls: type[_PAttrT]) -> type[_PAttrT]: new_fields = dict[str, Any]() def param_name_field(idx: int): - @property def field(self: _PAttrT): return self.parameters[idx] @@ -1510,13 +1601,12 @@ def irdl_definition(cls: type[_PAttrT]): new_fields["irdl_definition"] = irdl_definition - return dataclass(frozen=True, init=False)(type(cls.__name__, (cls, ), { - **cls.__dict__, - **new_fields - })) # type: ignore + return dataclass(frozen=True, init=False)( + type(cls.__name__, (cls,), {**cls.__dict__, **new_fields}) + ) # type: ignore -_AttrT = TypeVar('_AttrT', bound=Attribute) +_AttrT = TypeVar("_AttrT", bound=Attribute) def irdl_attr_definition(cls: type[_AttrT]) -> type[_AttrT]: @@ -1526,4 +1616,5 @@ def irdl_attr_definition(cls: type[_AttrT]) -> type[_AttrT]: return irdl_data_definition(cls) # type: ignore raise Exception( f"Class {cls.__name__} should either be a subclass of 'Data' or " - "'ParametrizedAttribute'") + "'ParametrizedAttribute'" + ) diff --git a/xdsl/irdl_mlir_printer.py b/xdsl/irdl_mlir_printer.py index 98689bd4c2..e46ed7efd7 100644 --- a/xdsl/irdl_mlir_printer.py +++ b/xdsl/irdl_mlir_printer.py @@ -6,11 +6,21 @@ from xdsl.irdl import AttrConstraint from xdsl.printer import Printer from xdsl.dialects.irdl import ( - ConstraintVarsOp, DialectOp, ParametersOp, TypeOp, OperandsOp, ResultsOp, - OperationOp, EqTypeConstraintAttr, AnyTypeConstraintAttr, - AnyOfTypeConstraintAttr, DynTypeBaseConstraintAttr, - DynTypeParamsConstraintAttr, TypeParamsConstraintAttr, - NamedTypeConstraintAttr) + ConstraintVarsOp, + DialectOp, + ParametersOp, + TypeOp, + OperandsOp, + ResultsOp, + OperationOp, + EqTypeConstraintAttr, + AnyTypeConstraintAttr, + AnyOfTypeConstraintAttr, + DynTypeBaseConstraintAttr, + DynTypeParamsConstraintAttr, + TypeParamsConstraintAttr, + NamedTypeConstraintAttr, +) # TODO: remove from "ignore" list in pyproject.toml # https://github.com/xdslproject/xdsl/issues/572 @@ -18,7 +28,6 @@ @dataclass(frozen=True, eq=False) class IRDLPrinter: - stream: IO[str] def _print(self, s: str, end: str | None = None): @@ -26,25 +35,34 @@ def _print(self, s: str, end: str | None = None): def print_module(self, module: ModuleOp): module.walk(lambda op: self.ensure_op_is_irdl_op(op)) - self._print('module {') - module.walk(lambda di: IRDLPrinter.print_dialect_definition(self, di) - if isinstance(di, DialectOp) else None) - self._print('}') + self._print("module {") + module.walk( + lambda di: IRDLPrinter.print_dialect_definition(self, di) + if isinstance(di, DialectOp) + else None + ) + self._print("}") def ensure_op_is_irdl_op(self, op: Operation): if not isinstance( - op, (OperationOp | ResultsOp | OperandsOp, ConstraintVarsOp - | TypeOp | ParametersOp | DialectOp | ModuleOp)): + op, + ( + OperationOp | ResultsOp | OperandsOp, + ConstraintVarsOp | TypeOp | ParametersOp | DialectOp | ModuleOp, + ), + ): raise Exception(f"Operation {op.name} is not an operation in IRDL") def print_type_definition(self, type: TypeOp): self._print(f" {TypeOp.name} {type.type_name.data} {{") - type.walk(lambda param: self.print_parameters_definition(param) - if isinstance(param, ParametersOp) else None) + type.walk( + lambda param: self.print_parameters_definition(param) + if isinstance(param, ParametersOp) + else None + ) self._print(" }") def print_attr_constraint(self, f: AttrConstraint | Attribute): - if isinstance(f, AnyOfTypeConstraintAttr): self._print("AnyOf", end="<") for i in range(len(f.params.data) - 1): @@ -61,10 +79,10 @@ def print_attr_constraint(self, f: AttrConstraint | Attribute): self._print("Any", end="") elif isinstance(f, DynTypeBaseConstraintAttr): - self._print(f.type_name.data, end='') + self._print(f.type_name.data, end="") elif isinstance(f, DynTypeParamsConstraintAttr): - self._print(f.type_name.data, end='') + self._print(f.type_name.data, end="") self._print("<", end="") for i in range(len(f.params.data)): self.print_attr_constraint(f.params.data[i]) @@ -74,7 +92,7 @@ def print_attr_constraint(self, f: AttrConstraint | Attribute): self._print(f"{f.type_name}") elif isinstance(f, NamedTypeConstraintAttr): - self._print(f.type_name.data, end='') + self._print(f.type_name.data, end="") self._print(":", end=" ") self.print_attr_constraint(f.params_constraints) @@ -90,20 +108,23 @@ def print_parameters_definition(self, param_op: ParametersOp): self._print(")") def print_operation_definition(self, operation: OperationOp): - self._print( - f" {OperationOp.name} {operation.attributes['name'].data} {{") + self._print(f" {OperationOp.name} {operation.attributes['name'].data} {{") # Checking for existence of operands operand_list = [] - operation.walk(lambda operand_def: operand_list.append(operand_def) - if isinstance(operand_def, OperandsOp) else None) + operation.walk( + lambda operand_def: operand_list.append(operand_def) + if isinstance(operand_def, OperandsOp) + else None + ) if operand_list: self.print_operand_definition(operand_list) # Checking for existence of results result_list = [] - operation.walk(lambda res: result_list.append(res) - if isinstance(res, ResultsOp) else None) + operation.walk( + lambda res: result_list.append(res) if isinstance(res, ResultsOp) else None + ) if result_list: self.print_result_definition(result_list) @@ -115,7 +136,7 @@ def print_result_definition(self, res_list: list[ResultsOp]): for i in range(len(res_list)): for result in res_list[i].params.data: self.print_attr_constraint(result) - self._print(", ", end='') if i != len(res_list) - 1 else None + self._print(", ", end="") if i != len(res_list) - 1 else None self._print(")") def print_operand_definition(self, op_list: list[OperandsOp]): @@ -124,15 +145,21 @@ def print_operand_definition(self, op_list: list[OperandsOp]): for i in range(len(op_list)): for ops in op_list[i].params.data: self.print_attr_constraint(ops) - self._print(", ", end='') if i != len(op_list) - 1 else None + self._print(", ", end="") if i != len(op_list) - 1 else None self._print(")") def print_dialect_definition(self, di: DialectOp): self._print(f" {DialectOp.name} {di.dialect_name.data} {{") - di.walk(lambda type: self.print_type_definition(type) - if isinstance(type, TypeOp) else None) - - di.walk(lambda op: self.print_operation_definition(op) - if isinstance(op, OperationOp) else None) + di.walk( + lambda type: self.print_type_definition(type) + if isinstance(type, TypeOp) + else None + ) + + di.walk( + lambda op: self.print_operation_definition(op) + if isinstance(op, OperationOp) + else None + ) self._print(" }}") diff --git a/xdsl/parser.py b/xdsl/parser.py index e9c6a192c1..7c0a46bc21 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -18,16 +18,56 @@ from xdsl.utils.lexer import Input, Lexer, Span, StringLiteral, Token from xdsl.dialects.memref import AnyIntegerAttr, MemRefType, UnrankedMemrefType from xdsl.dialects.builtin import ( - AnyArrayAttr, AnyFloat, AnyFloatAttr, AnyTensorType, AnyUnrankedTensorType, - AnyVectorType, BFloat16Type, DenseResourceAttr, DictionaryAttr, - Float16Type, Float32Type, Float64Type, Float80Type, Float128Type, - FloatAttr, FunctionType, IndexType, IntegerType, Signedness, StringAttr, - IntegerAttr, ArrayAttr, TensorType, UnrankedTensorType, UnregisteredAttr, - RankedVectorOrTensorOf, VectorType, SymbolRefAttr, DenseArrayBase, - DenseIntOrFPElementsAttr, OpaqueAttr, NoneAttr, ModuleOp, UnitAttr, i64, - StridedLayoutAttr, ComplexType) -from xdsl.ir import (SSAValue, Block, Callable, Attribute, Operation, Region, - BlockArgument, MLContext, ParametrizedAttribute, Data) + AnyArrayAttr, + AnyFloat, + AnyFloatAttr, + AnyTensorType, + AnyUnrankedTensorType, + AnyVectorType, + BFloat16Type, + DenseResourceAttr, + DictionaryAttr, + Float16Type, + Float32Type, + Float64Type, + Float80Type, + Float128Type, + FloatAttr, + FunctionType, + IndexType, + IntegerType, + Signedness, + StringAttr, + IntegerAttr, + ArrayAttr, + TensorType, + UnrankedTensorType, + UnregisteredAttr, + RankedVectorOrTensorOf, + VectorType, + SymbolRefAttr, + DenseArrayBase, + DenseIntOrFPElementsAttr, + OpaqueAttr, + NoneAttr, + ModuleOp, + UnitAttr, + i64, + StridedLayoutAttr, + ComplexType, +) +from xdsl.ir import ( + SSAValue, + Block, + Callable, + Attribute, + Operation, + Region, + BlockArgument, + MLContext, + ParametrizedAttribute, + Data, +) from xdsl.utils.hints import isa @@ -52,6 +92,7 @@ class BacktrackingHistory: It's parent will be the next error message (not a known attribute). Some errors happen in named regions (e.g. "parsing of operation") """ + error: ParseError parent: BacktrackingHistory | None region_name: str | None @@ -67,8 +108,9 @@ def print_unroll(self, file: IO[str] = sys.stderr): self.parent.print_unroll(file) def print(self, file: IO[str] = sys.stderr): - print("Parsing of {} failed:".format(self.region_name or ""), - file=file) + print( + "Parsing of {} failed:".format(self.region_name or ""), file=file + ) self.error.print_pretty(file=file) @functools.cache @@ -127,17 +169,41 @@ class Tokenizer: The position in the input. Points to the first unconsumed character. """ - _break_on: tuple[str, - ...] = ('.', '%', ' ', '(', ')', '[', ']', '{', '}', '<', - '>', ':', '=', '@', '?', '|', '->', '-', '//', - '\n', '\t', '#', '"', "'", ',', '!', '+', '*') + _break_on: tuple[str, ...] = ( + ".", + "%", + " ", + "(", + ")", + "[", + "]", + "{", + "}", + "<", + ">", + ":", + "=", + "@", + "?", + "|", + "->", + "-", + "//", + "\n", + "\t", + "#", + '"', + "'", + ",", + "!", + "+", + "*", + ) """ characters the tokenizer should break on """ - history: BacktrackingHistory | None = field(init=False, - default=None, - repr=False) + history: BacktrackingHistory | None = field(init=False, default=None, repr=False) last_token: Span | None = field(init=False, default=None, repr=False) @@ -158,8 +224,9 @@ def resume_from(self, save: save_t): """ self.pos = save - def _history_entry_from_exception(self, ex: Exception, region: str | None, - pos: int) -> BacktrackingHistory: + def _history_entry_from_exception( + self, ex: Exception, region: str | None, pos: int + ) -> BacktrackingHistory: """ Given an exception generated inside a backtracking attempt, generate a BacktrackingHistory object with the relevant information in it. @@ -198,8 +265,9 @@ def _history_entry_from_exception(self, ex: Exception, region: str | None, traceback.print_exception(ex, file=sys.stderr) return BacktrackingHistory( - ParseError(self.last_token, "Unexpected exception: {}".format(ex), - self.history), + ParseError( + self.last_token, "Unexpected exception: {}".format(ex), self.history + ), self.history, region, pos, @@ -226,9 +294,9 @@ def next_token(self, peek: bool = False) -> Span: self.last_token = span return span - def next_token_of_pattern(self, - pattern: re.Pattern[str] | str, - peek: bool = False) -> Span | None: + def next_token_of_pattern( + self, pattern: re.Pattern[str] | str, peek: bool = False + ) -> Span | None: """ Return a span that matched the pattern, or nothing. You can choose not to consume the span. @@ -277,7 +345,8 @@ def _find_token_end(self, start: int | None = None) -> int: filter( lambda x: x >= 0, (self.input.content.find(part, i) for part in self._break_on), - )) + ) + ) # Make sure that we break at some point break_pos.append(self.input.len) return min(break_pos) @@ -345,15 +414,30 @@ class ParserCommons: boolean_literal = re.compile(r"(true|false)") # A list of names that are builtin types _builtin_type_names = ( - r"[su]?i\d+", "bf16", r"f\d+", "tensor", "vector", "memref", "complex", - "opaque", "tuple", "index", "dense" + r"[su]?i\d+", + "bf16", + r"f\d+", + "tensor", + "vector", + "memref", + "complex", + "opaque", + "tuple", + "index", + "dense" # TODO: add all the Float8E4M3FNType, Float8E5M2Type, and BFloat16Type ) - builtin_attr_names = ('dense', 'opaque', 'affine_set', 'affine_map', - 'array', 'dense_resource', 'sparse') + builtin_attr_names = ( + "dense", + "opaque", + "affine_set", + "affine_map", + "array", + "dense_resource", + "sparse", + ) builtin_type = re.compile("(({}))".format(")|(".join(_builtin_type_names))) - builtin_type_xdsl = re.compile("!(({}))".format( - ")|(".join(_builtin_type_names))) + builtin_type_xdsl = re.compile("!(({}))".format(")|(".join(_builtin_type_names))) double_colon = re.compile("::") comma = re.compile(",") @@ -401,11 +485,13 @@ class BaseParser(ABC): allow_unregistered_dialect: bool - def __init__(self, - ctx: MLContext, - input: str, - name: str = '', - allow_unregistered_dialect: bool = False): + def __init__( + self, + ctx: MLContext, + input: str, + name: str = "", + allow_unregistered_dialect: bool = False, + ): self.tokenizer = Tokenizer(Input(input, name)) self.lexer = Lexer(Input(input, name)) self._current_token = self.lexer.lex() @@ -450,7 +536,10 @@ def backtracking(self, region_name: str | None = None): # This is because we are only interested in the last "cascade" of failures. # If a backtracking() completes without failure, # something has been parsed (we assume) - if self.tokenizer.pos > starting_position and self.tokenizer.history is not None: + if ( + self.tokenizer.pos > starting_position + and self.tokenizer.history is not None + ): self.tokenizer.history = None except Exception as ex: how_far_we_got = self.tokenizer.pos @@ -458,9 +547,10 @@ def backtracking(self, region_name: str | None = None): # If we have no error history, start recording! if not self.tokenizer.history: self.tokenizer.history = ( - self.tokenizer. - _history_entry_from_exception( # type: ignore - ex, region_name, how_far_we_got)) + self.tokenizer._history_entry_from_exception( # type: ignore + ex, region_name, how_far_we_got + ) + ) # If we got further than on previous attempts elif how_far_we_got > self.tokenizer.history.get_farthest_point(): @@ -468,16 +558,18 @@ def backtracking(self, region_name: str | None = None): self.tokenizer.history = None # Generate new history entry, self.tokenizer.history = ( - self.tokenizer. - _history_entry_from_exception( # type: ignore - ex, region_name, how_far_we_got)) + self.tokenizer._history_entry_from_exception( # type: ignore + ex, region_name, how_far_we_got + ) + ) # Otherwise, add to exception, if we are in a named region elif region_name is not None and how_far_we_got - starting_position > 0: self.tokenizer.history = ( - self.tokenizer. - _history_entry_from_exception( # type: ignore - ex, region_name, how_far_we_got)) + self.tokenizer._history_entry_from_exception( # type: ignore + ex, region_name, how_far_we_got + ) + ) self.resume_from(save) @@ -536,7 +628,8 @@ def _parse_token(self, expected_kind: Token.Kind, error_msg: str) -> Token: return current_token def _parse_optional_token_in( - self, expected_kinds: Iterable[Token.Kind]) -> Token | None: + self, expected_kinds: Iterable[Token.Kind] + ) -> Token | None: """Parse one of the expected tokens if present, and returns it.""" if self._current_token.kind not in expected_kinds: return None @@ -552,8 +645,9 @@ def parse_module(self) -> ModuleOp: return op else: self.tokenizer.pos = 0 - self.raise_error("Expected ModuleOp at top level!", - self.tokenizer.next_token()) + self.raise_error( + "Expected ModuleOp at top level!", self.tokenizer.next_token() + ) def _get_block_from_name(self, block_name: Span) -> Block: """ @@ -572,8 +666,7 @@ def parse_block(self) -> Block: if block_id is None: block = Block(declared_at=self.tokenizer.last_token) - elif self.forward_block_references.pop(block_id.text, - None) is not None: + elif self.forward_block_references.pop(block_id.text, None) is not None: block = self.blocks[block_id.text] block.declared_at = block_id else: @@ -594,7 +687,7 @@ def parse_block(self) -> Block: for i, (name, type) in enumerate(args): arg = BlockArgument(type, block, i) - self.ssa_values[name.text[1:]] = (arg, ) + self.ssa_values[name.text[1:]] = (arg,) # store ssa val name if valid if SSAValue.is_valid_name(name.text[1:]): arg.name = name.text[1:] @@ -608,7 +701,8 @@ def parse_block(self) -> Block: return block def _parse_optional_block_label( - self) -> tuple[Span | None, list[tuple[Span, Attribute]]]: + self, + ) -> tuple[Span | None, list[tuple[Span, Attribute]]]: """ A block label consists of block-id ( `(` block-arg `,` ... `)` )? """ @@ -616,54 +710,57 @@ def _parse_optional_block_label( arg_list = list[tuple[Span, Attribute]]() if block_id is not None: - if self.tokenizer.starts_with('('): + if self.tokenizer.starts_with("("): arg_list = self._parse_block_arg_list() - self.parse_characters(':', 'Block label must end in a `:`!') + self.parse_characters(":", "Block label must end in a `:`!") return block_id, arg_list def _parse_block_arg_list(self) -> list[tuple[Span, Attribute]]: - self.parse_characters('(', 'Block arguments must start with `(`') + self.parse_characters("(", "Block arguments must start with `(`") - args = self.parse_list_of(self.try_parse_value_id_and_type, - "Expected value-id and type here!") + args = self.parse_list_of( + self.try_parse_value_id_and_type, "Expected value-id and type here!" + ) - self.parse_characters(')', 'Expected closing of block arguments!') + self.parse_characters(")", "Expected closing of block arguments!") return args def try_parse_single_reference(self) -> Span | None: - with self.backtracking('part of a reference'): - self.parse_characters('@', "references must start with `@`") + with self.backtracking("part of a reference"): + self.parse_characters("@", "references must start with `@`") if (reference := self.try_parse_string_literal()) is not None: return reference if (reference := self.try_parse_suffix_id()) is not None: return reference self.raise_error( - "References must conform to `@` (string-literal | suffix-id)") + "References must conform to `@` (string-literal | suffix-id)" + ) def parse_reference(self) -> list[Span]: return self.parse_list_of( self.try_parse_single_reference, - 'Expected reference here in the format of `@` (suffix-id | string-literal)', + "Expected reference here in the format of `@` (suffix-id | string-literal)", ParserCommons.double_colon, - allow_empty=False) + allow_empty=False, + ) class Delimiter(Enum): """ Supported delimiters when parsing lists. """ + PAREN = auto() ANGLE = auto() SQUARE = auto() BRACES = auto() NONE = auto() - def parse_comma_separated_list(self, - delimiter: Delimiter, - parse: Callable[[], T_], - context_msg: str = '') -> list[T_]: + def parse_comma_separated_list( + self, delimiter: Delimiter, parse: Callable[[], T_], context_msg: str = "" + ) -> list[T_]: """ Parses greedily a list of elements separated by commas, and delimited by the specified delimiter. The parsing stops when the delimiter is @@ -684,8 +781,7 @@ def parse_comma_separated_list(self, self._synchronize_lexer_and_tokenizer() return [] elif delimiter == self.Delimiter.SQUARE: - self._parse_token(Token.Kind.L_SQUARE, - "Expected '['" + context_msg) + self._parse_token(Token.Kind.L_SQUARE, "Expected '['" + context_msg) if self._parse_optional_token(Token.Kind.R_SQUARE) is not None: self._synchronize_lexer_and_tokenizer() return [] @@ -712,8 +808,7 @@ def parse_comma_separated_list(self, elif delimiter == self.Delimiter.ANGLE: self._parse_token(Token.Kind.GREATER, "Expected '>'" + context_msg) elif delimiter == self.Delimiter.SQUARE: - self._parse_token(Token.Kind.R_SQUARE, - "Expected ']'" + context_msg) + self._parse_token(Token.Kind.R_SQUARE, "Expected ']'" + context_msg) elif delimiter == self.Delimiter.BRACES: self._parse_token(Token.Kind.R_BRACE, "Expected '}'" + context_msg) else: @@ -722,11 +817,13 @@ def parse_comma_separated_list(self, self._synchronize_lexer_and_tokenizer() return elems - def parse_list_of(self, - try_parse: Callable[[], T_ | None], - error_msg: str, - separator_pattern: re.Pattern[str] = ParserCommons.comma, - allow_empty: bool = True) -> list[T_]: + def parse_list_of( + self, + try_parse: Callable[[], T_ | None], + error_msg: str, + separator_pattern: re.Pattern[str] = ParserCommons.comma, + allow_empty: bool = True, + ) -> list[T_]: """ This is a greedy list-parser. It accepts input only in these cases: @@ -752,16 +849,18 @@ def parse_list_of(self, items = [first_item] - while (match := self.tokenizer.next_token_of_pattern(separator_pattern) - ) is not None: + while ( + match := self.tokenizer.next_token_of_pattern(separator_pattern) + ) is not None: next_item = try_parse() if next_item is None: # If the separator is emtpy, we are good here - if separator_pattern.pattern == '': + if separator_pattern.pattern == "": return items - self.raise_error(error_msg + - ' because was able to match next separator {}' - .format(match.text)) + self.raise_error( + error_msg + + " because was able to match next separator {}".format(match.text) + ) items.append(next_item) return items @@ -772,26 +871,28 @@ def parse_optional_boolean(self) -> bool | None: """ self._synchronize_lexer_and_tokenizer() if self._current_token.kind == Token.Kind.BARE_IDENT: - if self._current_token.text == 'true': + if self._current_token.text == "true": self._consume_token(Token.Kind.BARE_IDENT) self._synchronize_lexer_and_tokenizer() return True - elif self._current_token.text == 'false': + elif self._current_token.text == "false": self._consume_token(Token.Kind.BARE_IDENT) self._synchronize_lexer_and_tokenizer() return False return None - def parse_boolean(self, context_msg: str = '') -> bool: + def parse_boolean(self, context_msg: str = "") -> bool: """ Parse a boolean with the format `true` or `false`. """ - return self.expect(lambda: self.parse_optional_boolean(), - 'Expected boolean literal' + context_msg) + return self.expect( + lambda: self.parse_optional_boolean(), + "Expected boolean literal" + context_msg, + ) - def parse_optional_integer(self, - allow_boolean: bool = True, - allow_negative: bool = True) -> int | None: + def parse_optional_integer( + self, allow_boolean: bool = True, allow_negative: bool = True + ) -> int | None: """ Parse an (possible negative) integer. The integer can either be decimal or hexadecimal. @@ -806,12 +907,10 @@ def parse_optional_integer(self, # Parse negative numbers if required is_negative = False if allow_negative: - is_negative = self._parse_optional_token( - Token.Kind.MINUS) is not None + is_negative = self._parse_optional_token(Token.Kind.MINUS) is not None # Parse the actual number - if (int_token := - self._parse_optional_token(Token.Kind.INTEGER_LIT)) is None: + if (int_token := self._parse_optional_token(Token.Kind.INTEGER_LIT)) is None: if is_negative: self.raise_error("Expected integer literal after '-'") self._synchronize_lexer_and_tokenizer() @@ -831,13 +930,14 @@ def parse_optional_number(self) -> int | float | None: is_negative = self._parse_optional_token(Token.Kind.MINUS) is not None - if (value := - self.parse_optional_integer(allow_boolean=False, - allow_negative=False)) is not None: + if ( + value := self.parse_optional_integer( + allow_boolean=False, allow_negative=False + ) + ) is not None: return -value if is_negative else value - if (value := - self._parse_optional_token(Token.Kind.FLOAT_LIT)) is not None: + if (value := self._parse_optional_token(Token.Kind.FLOAT_LIT)) is not None: value = value.get_float_value() return -value if is_negative else value @@ -845,17 +945,21 @@ def parse_optional_number(self) -> int | float | None: self.raise_error("Expected integer or float literal after '-'") return None - def parse_number(self, context_msg: str = '') -> int | float: + def parse_number(self, context_msg: str = "") -> int | float: """ Parse an integer or float literal. """ - return self.expect(lambda: self.parse_optional_number(), - 'Expected integer or float literal' + context_msg) + return self.expect( + lambda: self.parse_optional_number(), + "Expected integer or float literal" + context_msg, + ) - def parse_integer(self, - allow_boolean: bool = True, - allow_negative: bool = True, - context_msg: str = '') -> int: + def parse_integer( + self, + allow_boolean: bool = True, + allow_negative: bool = True, + context_msg: str = "", + ) -> int: """ Parse an (possible negative) integer. The integer can either be decimal or hexadecimal. @@ -864,23 +968,22 @@ def parse_integer(self, return self.expect( lambda: self.parse_optional_integer(allow_boolean, allow_negative), - 'Expected integer literal' + context_msg) + "Expected integer literal" + context_msg, + ) def try_parse_integer_literal(self) -> Span | None: - return self.tokenizer.next_token_of_pattern( - ParserCommons.integer_literal) + return self.tokenizer.next_token_of_pattern(ParserCommons.integer_literal) def try_parse_decimal_literal(self) -> Span | None: - return self.tokenizer.next_token_of_pattern( - ParserCommons.decimal_literal) + return self.tokenizer.next_token_of_pattern(ParserCommons.decimal_literal) def try_parse_string_literal(self) -> StringLiteral | None: return StringLiteral.from_span( - self.tokenizer.next_token_of_pattern(ParserCommons.string_literal)) + self.tokenizer.next_token_of_pattern(ParserCommons.string_literal) + ) def try_parse_float_literal(self) -> Span | None: - return self.tokenizer.next_token_of_pattern( - ParserCommons.float_literal) + return self.tokenizer.next_token_of_pattern(ParserCommons.float_literal) def try_parse_bare_id(self) -> Span | None: return self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) @@ -888,7 +991,7 @@ def try_parse_bare_id(self) -> Span | None: def try_parse_value_id(self) -> Span | None: return self.tokenizer.next_token_of_pattern(ParserCommons.value_id) - _decimal_integer_regex = re.compile(r'[0-9]+') + _decimal_integer_regex = re.compile(r"[0-9]+") def parse_optional_operand(self) -> SSAValue | None: """ @@ -903,23 +1006,23 @@ def parse_optional_operand(self) -> SSAValue | None: index = 0 index_token = self._parse_optional_token(Token.Kind.HASH_IDENT) if index_token is not None: - if re.fullmatch(self._decimal_integer_regex, - index_token.text[1:]) is None: - self.raise_error('Expected integer as SSA value tuple index', - index_token.span) + if re.fullmatch(self._decimal_integer_regex, index_token.text[1:]) is None: + self.raise_error( + "Expected integer as SSA value tuple index", index_token.span + ) index = int(index_token.text[1:], 10) if name not in self.ssa_values.keys(): - self.raise_error('SSA value used before assignment', - name_token.span) + self.raise_error("SSA value used before assignment", name_token.span) tuple_size = len(self.ssa_values[name]) if index >= tuple_size: - assert index_token is not None, 'Fatal error in SSA value parsing' + assert index_token is not None, "Fatal error in SSA value parsing" self.raise_error( - 'SSA value tuple index out of bounds. ' - f'Tuple is of size {tuple_size} but tried to access element {index}.', - index_token.span) + "SSA value tuple index out of bounds. " + f"Tuple is of size {tuple_size} but tried to access element {index}.", + index_token.span, + ) self._synchronize_lexer_and_tokenizer() return self.ssa_values[name][index] @@ -935,8 +1038,7 @@ def try_parse_block_id(self) -> Span | None: return self.tokenizer.next_token_of_pattern(ParserCommons.block_id) def try_parse_boolean_literal(self) -> Span | None: - return self.tokenizer.next_token_of_pattern( - ParserCommons.boolean_literal) + return self.tokenizer.next_token_of_pattern(ParserCommons.boolean_literal) def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: with self.backtracking("value id and type"): @@ -945,8 +1047,7 @@ def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: if value_id is None: self.raise_error("Invalid value-id format!") - self.parse_characters(':', - 'Expected expression (value-id `:` type)') + self.parse_characters(":", "Expected expression (value-id `:` type)") type = self.try_parse_type() @@ -955,7 +1056,7 @@ def try_parse_value_id_and_type(self) -> tuple[Span, Attribute] | None: return value_id, type def try_parse_type(self) -> Attribute | None: - if self.tokenizer.starts_with('!'): + if self.tokenizer.starts_with("!"): return self.try_parse_dialect_type() else: return self.try_parse_builtin_type() @@ -964,43 +1065,42 @@ def try_parse_dialect_type_or_attribute(self) -> Attribute | None: """ Parse a type or an attribute. """ - kind = self.tokenizer.next_token_of_pattern(re.compile('[!#]'), - peek=True) + kind = self.tokenizer.next_token_of_pattern(re.compile("[!#]"), peek=True) if kind is None: return None with self.backtracking("dialect attribute or type"): self.tokenizer.consume_peeked(kind) - if kind.text == '!': - return self._parse_dialect_type_or_attribute_inner('type') + if kind.text == "!": + return self._parse_dialect_type_or_attribute_inner("type") else: - return self._parse_dialect_type_or_attribute_inner('attribute') + return self._parse_dialect_type_or_attribute_inner("attribute") def try_parse_dialect_type(self): """ Parse a dialect type (something prefixed by `!`, defined by a dialect) """ - if not self.tokenizer.starts_with('!'): + if not self.tokenizer.starts_with("!"): return None with self.backtracking("dialect type"): - self.parse_characters('!', "Dialect type must start with a `!`") - return self._parse_dialect_type_or_attribute_inner('type') + self.parse_characters("!", "Dialect type must start with a `!`") + return self._parse_dialect_type_or_attribute_inner("type") def try_parse_dialect_attr(self): """ Parse a dialect attribute (something prefixed by `#`, defined by a dialect) """ - if not self.tokenizer.starts_with('#'): + if not self.tokenizer.starts_with("#"): return None with self.backtracking("dialect attribute"): - self.parse_characters('#', - "Dialect attribute must start with a `#`") - return self._parse_dialect_type_or_attribute_inner('attribute') + self.parse_characters("#", "Dialect attribute must start with a `#`") + return self._parse_dialect_type_or_attribute_inner("attribute") def _parse_dialect_type_or_attribute_inner( - self, kind: Literal['attribute'] | Literal['type']) -> Attribute: - is_type = kind == 'type' + self, kind: Literal["attribute"] | Literal["type"] + ) -> Attribute: + is_type = kind == "type" type_name = self.tokenizer.next_token_of_pattern(ParserCommons.bare_id) if type_name is None: @@ -1009,11 +1109,12 @@ def _parse_dialect_type_or_attribute_inner( type_def = self.ctx.get_optional_attr( type_name.text, self.allow_unregistered_dialect, - create_unregistered_as_type=is_type) + create_unregistered_as_type=is_type, + ) if type_def is None: self.raise_error( - "'{}' is not a known attribute!".format(type_name.text), - type_name) + "'{}' is not a known attribute!".format(type_name.text), type_name + ) # Pass the task of parsing parameters on to the attribute/type definition if issubclass(type_def, UnregisteredAttr): @@ -1027,7 +1128,8 @@ def _parse_dialect_type_or_attribute_inner( self.parse_characters("<", "This attribute must be parametrized!") param: Any = type_def.parse_parameter(self) self.parse_characters( - ">", "Invalid attribute parametrization, expected `>`!") + ">", "Invalid attribute parametrization, expected `>`!" + ) return cast(Data[Any], type_def(param)) assert False, "Attributes are either ParametrizedAttribute or Data." @@ -1049,30 +1151,32 @@ def _parse_unregistered_attr_body(self) -> str: Token.Kind.GREATER: Token.Kind.LESS, Token.Kind.R_PAREN: Token.Kind.L_PAREN, Token.Kind.R_SQUARE: Token.Kind.L_SQUARE, - Token.Kind.R_BRACE: Token.Kind.L_BRACE + Token.Kind.R_BRACE: Token.Kind.L_BRACE, } parentheses_names = { - Token.Kind.GREATER: '`>`', - Token.Kind.R_PAREN: '`)`', - Token.Kind.R_SQUARE: '`]`', - Token.Kind.R_BRACE: '`}`' + Token.Kind.GREATER: "`>`", + Token.Kind.R_PAREN: "`)`", + Token.Kind.R_SQUARE: "`]`", + Token.Kind.R_BRACE: "`}`", } while True: # Opening a new parenthesis - if (token := self._parse_optional_token_in( - parentheses.values())) is not None: + if ( + token := self._parse_optional_token_in(parentheses.values()) + ) is not None: symbols_stack.append(token.kind) continue # Closing a parenthesis - if (token := self._parse_optional_token_in( - parentheses.keys())) is not None: + if (token := self._parse_optional_token_in(parentheses.keys())) is not None: closing = parentheses[token.kind] if symbols_stack[-1] != closing: self.raise_error( "Mismatched {} in attribute body!".format( - parentheses_names[token.kind]), - self._current_token.span) + parentheses_names[token.kind] + ), + self._current_token.span, + ) symbols_stack.pop() if len(symbols_stack) == 0: end_pos = token.span.end @@ -1082,7 +1186,8 @@ def _parse_unregistered_attr_body(self) -> str: # Checking for unexpected EOF if self._parse_optional_token(Token.Kind.EOF) is not None: self.raise_error( - "Unexpected end of file before closing of attribute body!") + "Unexpected end of file before closing of attribute body!" + ) # Other tokens self._consume_token() @@ -1099,16 +1204,14 @@ def try_parse_builtin_type(self) -> Attribute | None: """ raise NotImplementedError("Subclasses must implement this method!") - def _parse_builtin_parametrized_type(self, - name: Span) -> ParametrizedAttribute: + def _parse_builtin_parametrized_type(self, name: Span) -> ParametrizedAttribute: """ This function is called after we parse the name of a parameterized type such as vector. """ def unimplemented() -> ParametrizedAttribute: - raise ParseError(name, - "Builtin {} not supported yet!".format(name.text)) + raise ParseError(name, "Builtin {} not supported yet!".format(name.text)) builtin_parsers: dict[str, Callable[[], ParametrizedAttribute]] = { "vector": self.parse_vector_attrs, @@ -1135,16 +1238,21 @@ def _parse_shape_dimension(self, allow_dynamic: bool = True) -> int: `0x10` will be split into `0` and `x10`. Optionally allows to not parse `?` as -1. """ - if self._current_token.kind not in (Token.Kind.INTEGER_LIT, - Token.Kind.QUESTION): + if self._current_token.kind not in ( + Token.Kind.INTEGER_LIT, + Token.Kind.QUESTION, + ): if allow_dynamic: self.raise_error( "Expected either integer literal or '?' in shape dimension, " - f"got {self._current_token.kind.name}!") - self.raise_error("Expected integer literal in shape dimension, " - f"got {self._current_token.kind.name}!") + f"got {self._current_token.kind.name}!" + ) + self.raise_error( + "Expected integer literal in shape dimension, " + f"got {self._current_token.kind.name}!" + ) - if self.parse_optional_punctuation('?') is not None: + if self.parse_optional_punctuation("?") is not None: if allow_dynamic: return -1 self.raise_error("Unexpected dynamic dimension!") @@ -1152,7 +1260,7 @@ def _parse_shape_dimension(self, allow_dynamic: bool = True) -> int: # If the integer literal starts with `0x`, this is decomposed into # `0` and `x`. int_token = self._consume_token(Token.Kind.INTEGER_LIT) - if int_token.text[:2] == '0x': + if int_token.text[:2] == "0x": self.resume_from(int_token.span.start + 1) return 0 @@ -1165,12 +1273,15 @@ def _parse_shape_delimiter(self) -> None: into 'x' and '1'. """ if self._current_token.kind != Token.Kind.BARE_IDENT: - self.raise_error("Expected 'x' in shape delimiter, got " - f"{self._current_token.kind.name}") + self.raise_error( + "Expected 'x' in shape delimiter, got " + f"{self._current_token.kind.name}" + ) - if self._current_token.text[0] != 'x': - self.raise_error("Expected 'x' in shape delimiter, got " - f"{self._current_token.text}") + if self._current_token.text[0] != "x": + self.raise_error( + "Expected 'x' in shape delimiter, got " f"{self._current_token.text}" + ) # Move the lexer to the position after 'x'. self.resume_from(self._current_token.span.start + 1) @@ -1184,14 +1295,13 @@ def _parse_ranked_shape(self) -> tuple[list[int], Attribute]: """ self._synchronize_lexer_and_tokenizer() dims: list[int] = [] - while self._current_token.kind in (Token.Kind.INTEGER_LIT, - Token.Kind.QUESTION): + while self._current_token.kind in (Token.Kind.INTEGER_LIT, Token.Kind.QUESTION): dim = self._parse_shape_dimension() dims.append(dim) self._parse_shape_delimiter() self._synchronize_lexer_and_tokenizer() - type = self.expect(self.try_parse_type, 'Expected shape type.') + type = self.expect(self.try_parse_type, "Expected shape type.") return dims, type def _parse_shape(self) -> tuple[list[int] | None, Attribute]: @@ -1206,9 +1316,9 @@ def _parse_shape(self) -> tuple[list[int] | None, Attribute]: each dimension is also required to be non-negative. """ self._synchronize_lexer_and_tokenizer() - if self.parse_optional_punctuation('*') is not None: + if self.parse_optional_punctuation("*") is not None: self._parse_shape_delimiter() - type = self.expect(self.try_parse_type, 'Expected shape type.') + type = self.expect(self.try_parse_type, "Expected shape type.") self._synchronize_lexer_and_tokenizer() return None, type res = self._parse_ranked_shape() @@ -1224,12 +1334,13 @@ def parse_complex_attrs(self) -> ComplexType: return ComplexType(element_type) def parse_memref_attrs( - self) -> MemRefType[Attribute] | UnrankedMemrefType[Attribute]: + self, + ) -> MemRefType[Attribute] | UnrankedMemrefType[Attribute]: shape, type = self._parse_shape() # Unranked case if shape is None: - if self.parse_optional_punctuation(',') is None: + if self.parse_optional_punctuation(",") is None: self._synchronize_lexer_and_tokenizer() return UnrankedMemrefType.from_type(type) self._synchronize_lexer_and_tokenizer() @@ -1237,7 +1348,7 @@ def parse_memref_attrs( self._synchronize_lexer_and_tokenizer() return UnrankedMemrefType.from_type(type, memory_space) - if self.parse_optional_punctuation(',') is None: + if self.parse_optional_punctuation(",") is None: return MemRefType.from_element_type_and_shape(type, shape) self._synchronize_lexer_and_tokenizer() @@ -1246,12 +1357,13 @@ def parse_memref_attrs( # If there is both a memory space and a layout, we know that the # layout is the second one - if self.parse_optional_punctuation(',') is not None: + if self.parse_optional_punctuation(",") is not None: self._synchronize_lexer_and_tokenizer() memory_space = self.parse_attribute() self._synchronize_lexer_and_tokenizer() return MemRefType.from_element_type_and_shape( - type, shape, memory_or_layout, memory_space) + type, shape, memory_or_layout, memory_space + ) # Otherwise, there is a single argument, so we check based on the # attribute type. If we don't know, we return an error. @@ -1261,28 +1373,30 @@ def parse_memref_attrs( # If the argument is an integer, it is a memory space if isa(memory_or_layout, AnyIntegerAttr): return MemRefType.from_element_type_and_shape( - type, shape, memory_space=memory_or_layout) + type, shape, memory_space=memory_or_layout + ) # We only accept strided layouts and affine_maps - if (isa(memory_or_layout, StridedLayoutAttr) - or (isinstance(memory_or_layout, UnregisteredAttr) - and memory_or_layout.attr_name.data == "affine_map")): + if isa(memory_or_layout, StridedLayoutAttr) or ( + isinstance(memory_or_layout, UnregisteredAttr) + and memory_or_layout.attr_name.data == "affine_map" + ): return MemRefType.from_element_type_and_shape( - type, shape, layout=memory_or_layout) - self.raise_error("Cannot decide if the given attribute " - "is a layout or a memory space!") - - def try_parse_numerical_dims(self, - accept_closing_bracket: bool = False, - lower_bound: int = 1) -> Iterable[int]: - while (shape_arg := - self._try_parse_shape_element(lower_bound)) is not None: + type, shape, layout=memory_or_layout + ) + self.raise_error( + "Cannot decide if the given attribute " "is a layout or a memory space!" + ) + + def try_parse_numerical_dims( + self, accept_closing_bracket: bool = False, lower_bound: int = 1 + ) -> Iterable[int]: + while (shape_arg := self._try_parse_shape_element(lower_bound)) is not None: yield shape_arg # Look out for the closing bracket for scalable vector dims if accept_closing_bracket and self.tokenizer.starts_with("]"): break - self.parse_characters("x", - "Unexpected end of dimension parameters!") + self.parse_characters("x", "Unexpected end of dimension parameters!") def parse_vector_attrs(self) -> AnyVectorType: self._synchronize_lexer_and_tokenizer() @@ -1295,13 +1409,12 @@ def parse_vector_attrs(self) -> AnyVectorType: self._parse_shape_delimiter() # Then, parse the scalable dimensions, if any - if self.parse_optional_punctuation('[') is not None: - + if self.parse_optional_punctuation("[") is not None: # Parse the scalable dimensions dims.append(self._parse_shape_dimension(allow_dynamic=False)) num_scalable_dims += 1 - while self.parse_optional_punctuation(']') is None: + while self.parse_optional_punctuation("]") is None: self._parse_shape_delimiter() dims.append(self._parse_shape_dimension(allow_dynamic=False)) num_scalable_dims += 1 @@ -1315,19 +1428,18 @@ def parse_vector_attrs(self) -> AnyVectorType: self.raise_error("Expected the vector element types!") self._synchronize_lexer_and_tokenizer() - return VectorType.from_element_type_and_shape(type, dims, - num_scalable_dims) + return VectorType.from_element_type_and_shape(type, dims, num_scalable_dims) def parse_tensor_attrs(self) -> AnyTensorType | AnyUnrankedTensorType: shape, type = self._parse_shape() if shape is None: - if self.parse_optional_punctuation(',') is not None: + if self.parse_optional_punctuation(",") is not None: self.raise_error("Unranked tensors don't have an encoding!") return UnrankedTensorType.from_type(type) self._synchronize_lexer_and_tokenizer() - if self.parse_optional_punctuation(',') is not None: + if self.parse_optional_punctuation(",") is not None: self._synchronize_lexer_and_tokenizer() encoding = self.parse_attribute() return TensorType.from_type_and_list(type, shape, encoding) @@ -1348,28 +1460,25 @@ def _try_parse_shape_element(self, lower_bound: int = 1) -> int | None: # TODO: this is ugly, it's a raise inside a try_ type function, which # should instead just give up raise ParseError( - int_lit, - "Shape element literal cannot be negative or zero!") + int_lit, "Shape element literal cannot be negative or zero!" + ) return value - if self.tokenizer.next_token_of_pattern('?') is not None: + if self.tokenizer.next_token_of_pattern("?") is not None: return -1 return None def _parse_type_params(self) -> list[Attribute]: # Consume opening bracket - self.parse_characters('<', 'Type must be parameterized!') + self.parse_characters("<", "Type must be parameterized!") - params = self.parse_list_of(self.try_parse_type, - 'Expected a type here!') + params = self.parse_list_of(self.try_parse_type, "Expected a type here!") - self.parse_characters('>', - 'Expected end of type parameterization here!') + self.parse_characters(">", "Expected end of type parameterization here!") return params - def expect(self, try_parse: Callable[[], T_ | None], - error_message: str) -> T_: + def expect(self, try_parse: Callable[[], T_ | None], error_message: str) -> T_: """ Used to force completion of a try_parse function. Will throw a parse error if it can't. @@ -1379,9 +1488,7 @@ def expect(self, try_parse: Callable[[], T_ | None], self.raise_error(error_message) return res - def raise_error(self, - msg: str, - at_position: Span | None = None) -> NoReturn: + def raise_error(self, msg: str, at_position: Span | None = None) -> NoReturn: """ Helper for raising exceptions, provides as much context as possible to them. @@ -1402,8 +1509,7 @@ def parse_characters(self, text: str, msg: str) -> Span: return match @abstractmethod - def _parse_op_result_list( - self) -> list[tuple[Span, int, Attribute | None]]: + def _parse_op_result_list(self) -> list[tuple[Span, int, Attribute | None]]: raise NotImplementedError() def try_parse_operation(self) -> Operation | None: @@ -1419,8 +1525,8 @@ def parse_operation(self) -> Operation: ret_types = [result[2] for result in results] if len(results) > 0: self.parse_characters( - '=', - 'Operation definitions expect an `=` after op-result-list!') + "=", "Operation definitions expect an `=` after op-result-list!" + ) # Check for custom op format op_name = self.try_parse_bare_id() @@ -1437,10 +1543,10 @@ def parse_operation(self) -> Operation: if op_name is None: self.raise_error( "Expected an operation name here, either a bare-id, or a string " - "literal!") + "literal!" + ) - args, successors, attrs, regions, func_type = self.parse_operation_details( - ) + args, successors, attrs, regions, func_type = self.parse_operation_details() if any(res_type is None for res_type in ret_types): assert func_type is not None @@ -1449,19 +1555,22 @@ def parse_operation(self) -> Operation: op_type = self._get_op_by_name(op_name) - op = op_type.create(operands=args, - result_types=ret_types, - attributes=attrs, - successors=[ - self._get_block_from_name(block_name) - for block_name in successors - ], - regions=regions) + op = op_type.create( + operands=args, + result_types=ret_types, + attributes=attrs, + successors=[ + self._get_block_from_name(block_name) for block_name in successors + ], + regions=regions, + ) expected_results = sum(r[1] for r in results) if len(op.results) != expected_results: - self.raise_error(f'Operation has {len(op.results)} results, ' - f'but were given {expected_results} to bind.') + self.raise_error( + f"Operation has {len(op.results)} results, " + f"but were given {expected_results} to bind." + ) # Register the result SSA value names in the parser res_idx = 0 @@ -1469,9 +1578,11 @@ def parse_operation(self) -> Operation: ssa_val_name = res_span.text[1:] # Removing the leading '%' if ssa_val_name in self.ssa_values: self.raise_error( - f"SSA value %{ssa_val_name} is already defined", res_span) - self.ssa_values[ssa_val_name] = tuple(op.results[res_idx:res_idx + - res_size]) + f"SSA value %{ssa_val_name} is already defined", res_span + ) + self.ssa_values[ssa_val_name] = tuple( + op.results[res_idx : res_idx + res_size] + ) res_idx += res_size # Carry over `ssa_val_name` for non-numeric names: if SSAValue.is_valid_name(ssa_val_name): @@ -1487,12 +1598,13 @@ def _get_op_by_name(self, span: Span) -> type[Operation]: op_name = span.text op_type = self.ctx.get_optional_op( - op_name, allow_unregistered=self.allow_unregistered_dialect) + op_name, allow_unregistered=self.allow_unregistered_dialect + ) if op_type is not None: return op_type - self.raise_error(f'Unknown operation {op_name}!', span) + self.raise_error(f"Unknown operation {op_name}!", span) def parse_region(self) -> Region: old_ssa_values = self.ssa_values.copy() @@ -1515,19 +1627,28 @@ def parse_region(self) -> Region: while self.tokenizer.starts_with("^"): region.add_block(self.parse_block()) - end = self.parse_characters( - "}", "Reached end of region, expected `}`!") + end = self.parse_characters("}", "Reached end of region, expected `}`!") if len(self.forward_block_references) > 0: raise MultipleSpansParseError( end, - "Region ends with missing block declarations for block(s) {}!" - .format(', '.join(self.forward_block_references.keys())), - 'The following block references are dangling:', - [(span, "Reference to block \"{}\" without implementation!" - .format(span.text)) for span in itertools.chain( - *self.forward_block_references.values())], - self.tokenizer.history) + "Region ends with missing block declarations for block(s) {}!".format( + ", ".join(self.forward_block_references.keys()) + ), + "The following block references are dangling:", + [ + ( + span, + 'Reference to block "{}" without implementation!'.format( + span.text + ), + ) + for span in itertools.chain( + *self.forward_block_references.values() + ) + ], + self.tokenizer.history, + ) return region finally: @@ -1555,11 +1676,12 @@ def _parse_attribute_entry(self) -> tuple[Span, Attribute]: "Expected bare-id or string-literal here as part of attribute entry!" ) - if not self.tokenizer.starts_with('='): + if not self.tokenizer.starts_with("="): return name, UnitAttr() self.parse_characters( - "=", "Attribute entries must be of format name `=` attribute!") + "=", "Attribute entries must be of format name `=` attribute!" + ) return name, self.parse_attribute() @@ -1582,10 +1704,11 @@ def _parse_attribute_type(self) -> Attribute: Parses `:` type and returns the type """ self.parse_characters( - ":", "Expected attribute type definition here ( `:` type )") + ":", "Expected attribute type definition here ( `:` type )" + ) return self.expect( - self.try_parse_type, - "Expected attribute type definition here ( `:` type )") + self.try_parse_type, "Expected attribute type definition here ( `:` type )" + ) def try_parse_builtin_attr(self) -> Attribute | None: """ @@ -1598,15 +1721,17 @@ def try_parse_builtin_attr(self) -> Attribute | None: return self.try_parse_builtin_arr_attr() elif next_token.text == "@": return self.try_parse_ref_attr() - elif next_token.text == '{': + elif next_token.text == "{": return self.try_parse_builtin_dict_attr() - elif next_token.text == '(': + elif next_token.text == "(": return self.try_parse_function_type() elif next_token.text in ParserCommons.builtin_attr_names: return self.try_parse_builtin_named_attr() - attrs = (self.parse_optional_builtin_int_or_float_attr, - self.try_parse_builtin_type) + attrs = ( + self.parse_optional_builtin_int_or_float_attr, + self.try_parse_builtin_type, + ) for attr_parser in attrs: if (val := attr_parser()) is not None: @@ -1620,12 +1745,11 @@ def try_parse_builtin_attr(self) -> Attribute | None: return None - def _parse_int_or_question(self, - context_msg: str = "") -> int | Literal['?']: + def _parse_int_or_question(self, context_msg: str = "") -> int | Literal["?"]: """Parse either an integer literal, or a '?'.""" self._synchronize_lexer_and_tokenizer() if self._parse_optional_token(Token.Kind.QUESTION) is not None: - return '?' + return "?" if (v := self.parse_optional_integer(allow_boolean=False)) is not None: return v self.raise_error("Expected an integer literal or `?`" + context_msg) @@ -1641,8 +1765,10 @@ def parse_keyword(self, keyword: str, context_msg: str = "") -> str: def parse_optional_keyword(self, keyword: str) -> str | None: """Parse a specific identifier if it is present""" - if (self._current_token.kind == Token.Kind.BARE_IDENT - and self._current_token.text == keyword): + if ( + self._current_token.kind == Token.Kind.BARE_IDENT + and self._current_token.text == keyword + ): self._consume_token(Token.Kind.BARE_IDENT) return keyword return None @@ -1661,12 +1787,13 @@ def parse_strided_layout_attr(self) -> Attribute: strides = self.parse_comma_separated_list( self.Delimiter.SQUARE, lambda: self._parse_int_or_question(" in stride list"), - " in stride list") + " in stride list", + ) # Pyright widen `Literal['?']` to `str` for some reasons - strides = cast(list[int | Literal['?']], strides) + strides = cast(list[int | Literal["?"]], strides) # Convert to the attribute expected input - strides = [None if stride == '?' else stride for stride in strides] + strides = [None if stride == "?" else stride for stride in strides] # Case without offset if self._parse_optional_token(Token.Kind.GREATER) is not None: @@ -1674,26 +1801,25 @@ def parse_strided_layout_attr(self) -> Attribute: # Parse the optional offset self._parse_token( - Token.Kind.COMMA, - "Expected end of strided attribute or ',' for offset.") + Token.Kind.COMMA, "Expected end of strided attribute or ',' for offset." + ) self.parse_keyword("offset", " after comma") self._parse_token(Token.Kind.COLON, "Expected ':' after 'offset'") offset = self._parse_int_or_question(" in stride offset") - self._parse_token(Token.Kind.GREATER, - "Expected '>' in end of stride attribute") - return StridedLayoutAttr(strides, None if offset == '?' else offset) + self._parse_token(Token.Kind.GREATER, "Expected '>' in end of stride attribute") + return StridedLayoutAttr(strides, None if offset == "?" else offset) def try_parse_builtin_named_attr(self) -> Attribute | None: name = self.tokenizer.next_token(peek=True) with self.backtracking("Builtin attribute {}".format(name.text)): self.tokenizer.consume_peeked(name) parsers = { - 'dense': self._parse_builtin_dense_attr, - 'opaque': self._parse_builtin_opaque_attr, - 'dense_resource': self._parse_builtin_dense_resource_attr, - 'array': self._parse_builtin_array_attr, - 'affine_map': self._parse_builtin_affine_attr, - 'affine_set': self._parse_builtin_affine_attr, + "dense": self._parse_builtin_dense_attr, + "opaque": self._parse_builtin_opaque_attr, + "dense_resource": self._parse_builtin_dense_resource_attr, + "array": self._parse_builtin_array_attr, + "affine_map": self._parse_builtin_affine_attr, + "affine_set": self._parse_builtin_affine_attr, } def not_implemented(_name: Span): @@ -1701,10 +1827,9 @@ def not_implemented(_name: Span): return parsers.get(name.text, not_implemented)(name) - def _parse_builtin_dense_attr(self, - _name: Span) -> DenseIntOrFPElementsAttr: + def _parse_builtin_dense_attr(self, _name: Span) -> DenseIntOrFPElementsAttr: self._synchronize_lexer_and_tokenizer() - self.parse_punctuation('<', ' in dense attribute') + self.parse_punctuation("<", " in dense attribute") # The flatten list of elements values: list[BaseParser._TensorLiteralElement] @@ -1713,48 +1838,49 @@ def _parse_builtin_dense_attr(self, # If it is `[]`, then this is a splat attribute, meaning it has the same # value everywhere. shape: list[int] | None - if self._current_token.text == '>': + if self._current_token.text == ">": values, shape = [], None else: values, shape = self._parse_tensor_literal() - self.parse_punctuation('>', ' in dense attribute') + self.parse_punctuation(">", " in dense attribute") # Parse the dense type. - self.parse_punctuation(':', ' in dense attribute') + self.parse_punctuation(":", " in dense attribute") self._synchronize_lexer_and_tokenizer() - type = self.expect(self.try_parse_type, - 'Dense attribute must be typed!') + type = self.expect(self.try_parse_type, "Dense attribute must be typed!") self._synchronize_lexer_and_tokenizer() # Check that the type is correct. if not isa( - type, RankedVectorOrTensorOf[IntegerType] - | RankedVectorOrTensorOf[IndexType] - | RankedVectorOrTensorOf[AnyFloat]): - self.raise_error('Expected vector or tensor type of ' - 'integer, index, or float type') + type, + RankedVectorOrTensorOf[IntegerType] + | RankedVectorOrTensorOf[IndexType] + | RankedVectorOrTensorOf[AnyFloat], + ): + self.raise_error( + "Expected vector or tensor type of " "integer, index, or float type" + ) # Check that the shape matches the data when given a shaped data. type_shape = [dim.value.data for dim in type.shape.data] num_values = math.prod(type_shape) if shape is None and num_values != 0: - self.raise_error('Expected at least one element in the ' - 'dense literal, but got None') + self.raise_error( + "Expected at least one element in the " "dense literal, but got None" + ) if shape is not None and shape != [] and type_shape != shape: self.raise_error( - f'Shape mismatch in dense literal. Expected {type_shape} ' - f'shape from the type, but got {shape} shape.') + f"Shape mismatch in dense literal. Expected {type_shape} " + f"shape from the type, but got {shape} shape." + ) if any(dim == -1 for dim in type_shape): - self.raise_error( - f'Dense literal attribute should have a static shape.') + self.raise_error(f"Dense literal attribute should have a static shape.") element_type = type.element_type # Convert list of elements to a list of values. if shape != []: - data_values = [ - value.to_type(self, element_type) for value in values - ] + data_values = [value.to_type(self, element_type) for value in values] else: assert len(values) == 1, "Fatal error in parser" data_values = [values[0].to_type(self, element_type)] * num_values @@ -1763,35 +1889,38 @@ def _parse_builtin_dense_attr(self, def _parse_builtin_opaque_attr(self, _name: Span): self.parse_characters("<", "Opaque attribute must be parametrized") - str_lit_list = self.parse_list_of(self.try_parse_string_literal, - 'Expected opaque attr here!') + str_lit_list = self.parse_list_of( + self.try_parse_string_literal, "Expected opaque attr here!" + ) if len(str_lit_list) != 2: - self.raise_error('Opaque expects 2 string literal parameters!') + self.raise_error("Opaque expects 2 string literal parameters!") self.parse_characters( - ">", "Unexpected parameters for opaque attr, expected `>`!") + ">", "Unexpected parameters for opaque attr, expected `>`!" + ) type = NoneAttr() - if self.tokenizer.starts_with(':'): + if self.tokenizer.starts_with(":"): self.parse_characters(":", "opaque attribute must be typed!") - type = self.expect(self.try_parse_type, - "opaque attribute must be typed!") + type = self.expect(self.try_parse_type, "opaque attribute must be typed!") - return OpaqueAttr.from_strings(*(span.string_contents - for span in str_lit_list), - type=type) + return OpaqueAttr.from_strings( + *(span.string_contents for span in str_lit_list), type=type + ) - def _parse_builtin_dense_resource_attr(self, - _name: Span) -> DenseResourceAttr: - err_msg = ("Malformed dense_resource attribute, format must be " - "(`dense_resource` `<` resource-handle `>`)") + def _parse_builtin_dense_resource_attr(self, _name: Span) -> DenseResourceAttr: + err_msg = ( + "Malformed dense_resource attribute, format must be " + "(`dense_resource` `<` resource-handle `>`)" + ) self.parse_characters("<", err_msg) resource_handle = self.expect(self.try_parse_bare_id, err_msg) self.parse_characters(">", err_msg) self.parse_characters(":", err_msg) - type = self.expect(self.try_parse_type, - "Dense resource attribute must be typed!") + type = self.expect( + self.try_parse_type, "Dense resource attribute must be typed!" + ) return DenseResourceAttr.from_params(resource_handle.text, type) def _parse_builtin_array_attr(self, name: Span) -> DenseArrayBase | None: @@ -1804,8 +1933,9 @@ def _parse_builtin_array_attr(self, name: Span) -> DenseArrayBase | None: if not isinstance(element_type, IntegerType | AnyFloat): raise ParseError( - name, "dense array element type must be an " - "integer or floating point type") + name, + "dense array element type must be an " "integer or floating point type", + ) # Empty array if self.try_parse_characters(">"): @@ -1820,8 +1950,9 @@ def try_parse_dense_array_value() -> int | float | None: return int(v.text) return None - values = self.parse_list_of(try_parse_dense_array_value, - "Expected tensor literal here!") + values = self.parse_list_of( + try_parse_dense_array_value, "Expected tensor literal here!" + ) self.parse_characters(">", err_msg) return DenseArrayBase.from_list(element_type, values) @@ -1833,10 +1964,10 @@ def _parse_builtin_affine_attr(self, name: Span) -> UnregisteredAttr: attr_def = self.ctx.get_optional_attr( name.text, allow_unregistered=self.allow_unregistered_dialect, - create_unregistered_as_type=False) + create_unregistered_as_type=False, + ) if attr_def is None: - self.raise_error(f"Unknown {name.text} attribute", - at_position=name) + self.raise_error(f"Unknown {name.text} attribute", at_position=name) assert issubclass( attr_def, UnregisteredAttr ), f"{name.text} was registered, but should be reserved for builtin" @@ -1847,7 +1978,7 @@ def _parse_builtin_affine_attr(self, name: Span) -> UnregisteredAttr: self._synchronize_lexer_and_tokenizer() start_pos = self._current_token.span.start end_pos = start_pos - self.parse_punctuation('<', f' in {name.text} attribute') + self.parse_punctuation("<", f" in {name.text} attribute") # Loop until we see the closing `>`. while True: @@ -1855,8 +1986,7 @@ def _parse_builtin_affine_attr(self, name: Span) -> UnregisteredAttr: # Check for early EOF. if token.kind == Token.Kind.EOF: - self.raise_error( - f"Expected '>' in end of {name.text} attribute") + self.raise_error(f"Expected '>' in end of {name.text} attribute") # Check for closing `>`. if token.kind == Token.Kind.GREATER: @@ -1867,7 +1997,7 @@ def _parse_builtin_affine_attr(self, name: Span) -> UnregisteredAttr: self._consume_token() contents = self.lexer.input.slice(start_pos, end_pos) - assert contents is not None, 'Fatal error in parser' + assert contents is not None, "Fatal error in parser" self._synchronize_lexer_and_tokenizer() return attr_def(name.text, False, contents) @@ -1881,6 +2011,7 @@ class _TensorLiteralElement: This class is used to parse a tensor literal before the tensor literal type is known """ + is_negative: bool value: int | float | bool """ @@ -1890,24 +2021,27 @@ class _TensorLiteralElement: """ span: Span - def to_int(self, - parser: BaseParser, - allow_negative: bool = True, - allow_booleans: bool = True) -> int: + def to_int( + self, + parser: BaseParser, + allow_negative: bool = True, + allow_booleans: bool = True, + ) -> int: """ Convert the element to an int value, possibly disallowing negative values. Raises an error if the type is compatible. """ if self.is_negative and not allow_negative: - parser.raise_error('Expected non-negative integer values', - at_position=self.span) + parser.raise_error( + "Expected non-negative integer values", at_position=self.span + ) if isinstance(self.value, bool) and not allow_booleans: parser.raise_error( - 'Boolean values are only allowed for i1 types', - at_position=self.span) + "Boolean values are only allowed for i1 types", + at_position=self.span, + ) if not isinstance(self.value, bool | int): - parser.raise_error('Expected integer value', - at_position=self.span) + parser.raise_error("Expected integer value", at_position=self.span) if self.is_negative: return -int(self.value) return int(self.value) @@ -1918,26 +2052,24 @@ def to_float(self, parser: BaseParser) -> float: is compatible. """ if not isinstance(self.value, int | float): - parser.raise_error('Expected float value', - at_position=self.span) + parser.raise_error("Expected float value", at_position=self.span) if self.is_negative: return -float(self.value) return float(self.value) - def to_type(self, parser: BaseParser, - type: AnyFloat | IntegerType | IndexType): + def to_type(self, parser: BaseParser, type: AnyFloat | IntegerType | IndexType): if isinstance(type, AnyFloat): return self.to_float(parser) elif isinstance(type, IntegerType): - return self.to_int(parser, type.signedness.data - != Signedness.UNSIGNED, - type.width.data == 1) + return self.to_int( + parser, + type.signedness.data != Signedness.UNSIGNED, + type.width.data == 1, + ) elif isinstance(type, IndexType): - return self.to_int(parser, - allow_negative=True, - allow_booleans=False) + return self.to_int(parser, allow_negative=True, allow_booleans=False) else: - assert False, 'fatal error in parser' + assert False, "fatal error in parser" def _parse_tensor_literal_element(self) -> _TensorLiteralElement: """ @@ -1945,10 +2077,10 @@ def _parse_tensor_literal_element(self) -> _TensorLiteralElement: literal, or a float literal. """ # boolean case - if self._current_token.text == 'true': + if self._current_token.text == "true": token = self._consume_token(Token.Kind.BARE_IDENT) return self._TensorLiteralElement(False, True, token.span) - if self._current_token.text == 'false': + if self._current_token.text == "false": token = self._consume_token(Token.Kind.BARE_IDENT) return self._TensorLiteralElement(False, False, token.span) @@ -1965,15 +2097,15 @@ def _parse_tensor_literal_element(self) -> _TensorLiteralElement: token = self._consume_token(Token.Kind.INTEGER_LIT) value = token.get_int_value() else: - self.raise_error( - "Expected either a float, integer, or complex literal") + self.raise_error("Expected either a float, integer, or complex literal") if is_negative: value = -value return self._TensorLiteralElement(is_negative, value, token.span) def _parse_tensor_literal( - self) -> tuple[list[BaseParser._TensorLiteralElement], list[int]]: + self, + ) -> tuple[list[BaseParser._TensorLiteralElement], list[int]]: """ Parse a tensor literal, and returns its flatten data and its shape. @@ -1981,14 +2113,16 @@ def _parse_tensor_literal( the data, and [2, 3] for the shape. """ if self._current_token.kind == Token.Kind.L_SQUARE: - res = self.parse_comma_separated_list(self.Delimiter.SQUARE, - self._parse_tensor_literal) + res = self.parse_comma_separated_list( + self.Delimiter.SQUARE, self._parse_tensor_literal + ) if len(res) == 0: return [], [0] sub_literal_shape = res[0][1] if any(r[1] != sub_literal_shape for r in res): self.raise_error( - "Tensor literal has inconsistent ranks between elements") + "Tensor literal has inconsistent ranks between elements" + ) shape = [len(res)] + sub_literal_shape values = [elem for sub_list in res for elem in sub_list[0]] return values, shape @@ -2009,18 +2143,18 @@ def try_parse_int_or_float(): return float(literal.text) if (literal := self.try_parse_integer_literal()) is not None: return int(literal.text) - self.raise_error('Expected int or float literal here!') + self.raise_error("Expected int or float literal here!") - if not self.tokenizer.starts_with('['): + if not self.tokenizer.starts_with("["): yield try_parse_int_or_float() return - self.parse_characters('[', '') - while not self.tokenizer.starts_with(']'): + self.parse_characters("[", "") + while not self.tokenizer.starts_with("]"): yield from self._parse_builtin_dense_attr_args() - if self.tokenizer.next_token_of_pattern(',') is None: + if self.tokenizer.next_token_of_pattern(",") is None: break - self.parse_characters(']', '') + self.parse_characters("]", "") def try_parse_ref_attr(self) -> SymbolRefAttr | None: if not self.tokenizer.starts_with("@"): @@ -2031,12 +2165,14 @@ def try_parse_ref_attr(self) -> SymbolRefAttr | None: if len(refs) >= 1: return SymbolRefAttr( StringAttr(refs[0].text), - ArrayAttr([StringAttr(ref.text) for ref in refs[1:]])) + ArrayAttr([StringAttr(ref.text) for ref in refs[1:]]), + ) else: return None def parse_optional_builtin_int_or_float_attr( - self) -> AnyIntegerAttr | AnyFloatAttr | None: + self, + ) -> AnyIntegerAttr | AnyFloatAttr | None: bool = self.try_parse_builtin_boolean_attr() if bool is not None: return bool @@ -2063,14 +2199,14 @@ def parse_optional_builtin_int_or_float_attr( if isinstance(type, IntegerType | IndexType): if isinstance(value, float): - self.raise_error( - 'Floating point value is not valid for integer type.') + self.raise_error("Floating point value is not valid for integer type.") return IntegerAttr(value, type) - self.raise_error('Invalid type given for integer or float attribute.') + self.raise_error("Invalid type given for integer or float attribute.") def try_parse_builtin_boolean_attr( - self) -> IntegerAttr[IntegerType | IndexType] | None: + self, + ) -> IntegerAttr[IntegerType | IndexType] | None: self._synchronize_lexer_and_tokenizer() if (value := self.parse_optional_boolean()) is not None: self._synchronize_lexer_and_tokenizer() @@ -2092,10 +2228,12 @@ def try_parse_builtin_arr_attr(self) -> AnyArrayAttr | None: return None with self.backtracking("array literal"): self.parse_characters("[", "Array literals must start with `[`") - attrs = self.parse_list_of(self.try_parse_attribute, - "Expected array entry!") + attrs = self.parse_list_of( + self.try_parse_attribute, "Expected array entry!" + ) self.parse_characters( - "]", "Malformed array contents (expected end of array here!") + "]", "Malformed array contents (expected end of array here!" + ) return ArrayAttr(attrs) @abstractmethod @@ -2107,21 +2245,24 @@ def parse_optional_dictionary_attr_dict(self) -> dict[str, Attribute]: return dict() self.parse_characters( - "{", "Attribute dictionary must be enclosed in curly brackets") + "{", "Attribute dictionary must be enclosed in curly brackets" + ) attrs = [] - if not self.tokenizer.starts_with('}'): - attrs = self.parse_list_of(self._parse_attribute_entry, - "Expected attribute entry") + if not self.tokenizer.starts_with("}"): + attrs = self.parse_list_of( + self._parse_attribute_entry, "Expected attribute entry" + ) self.parse_characters( - "}", "Attribute dictionary must be enclosed in curly brackets") + "}", "Attribute dictionary must be enclosed in curly brackets" + ) return self._attr_dict_from_tuple_list(attrs) def _attr_dict_from_tuple_list( - self, tuple_list: list[tuple[Span, - Attribute]]) -> dict[str, Attribute]: + self, tuple_list: list[tuple[Span, Attribute]] + ) -> dict[str, Attribute]: """ Convert a list of tuples (Span, Attribute) to a dictionary. This function converts the span to a string, trimming quotes from string literals @@ -2149,21 +2290,19 @@ def parse_function_type(self) -> FunctionType: Uses type-or-type-list-parens internally """ - self.parse_characters( - "(", "First group of function args must start with a `(`") + self.parse_characters("(", "First group of function args must start with a `(`") - args: list[Attribute] = self.parse_list_of(self.try_parse_type, - "Expected type here!") + args: list[Attribute] = self.parse_list_of( + self.try_parse_type, "Expected type here!" + ) self.parse_characters( - ")", - "Malformed function type, expected closing brackets of argument types!" + ")", "Malformed function type, expected closing brackets of argument types!" ) self.parse_characters("->", "Malformed function type, expected `->`!") - return FunctionType.from_lists(args, - self._parse_type_or_type_list_parens()) + return FunctionType.from_lists(args, self._parse_type_or_type_list_parens()) def _parse_type_or_type_list_parens(self) -> list[Attribute]: """ @@ -2174,13 +2313,12 @@ def _parse_type_or_type_list_parens(self) -> list[Attribute]: type-list-no-parens ::= type (`,` type)* """ if self.tokenizer.next_token_of_pattern("(") is not None: - args = self.parse_list_of(self.try_parse_type, - "Expected type here!") + args = self.parse_list_of(self.try_parse_type, "Expected type here!") self.parse_characters(")", "Unclosed function type argument list!") else: arg = self.expect( self.try_parse_type, - "Function type must either be single type or list of types in parentheses" + "Function type must either be single type or list of types in parentheses", ) args = [arg] return args @@ -2212,8 +2350,7 @@ def _parse_builtin_type_with_name(self, name: Span): "u": Signedness.UNSIGNED, "i": Signedness.SIGNLESS, } - return IntegerType(int(re_match.group(1)), - signedness[name.text[0]]) + return IntegerType(int(re_match.group(1)), signedness[name.text[0]]) if name.text == "bf16": return BFloat16Type() @@ -2228,8 +2365,7 @@ def _parse_builtin_type_with_name(self, name: Span): 128: Float128Type, }.get(width, None) if type is None: - self.raise_error( - "Unsupported floating point width: {}".format(width)) + self.raise_error("Unsupported floating point width: {}".format(width)) return type() return self._parse_builtin_parametrized_type(name) @@ -2237,8 +2373,13 @@ def _parse_builtin_type_with_name(self, name: Span): @abstractmethod def parse_operation_details( self, - ) -> tuple[list[SSAValue], list[Span], dict[str, Attribute], list[Region], - FunctionType | None]: + ) -> tuple[ + list[SSAValue], + list[Span], + dict[str, Attribute], + list[Region], + FunctionType | None, + ]: """ Must return a tuple consisting of: - a list of arguments to the operation @@ -2275,29 +2416,28 @@ def parse_op_with_default_format( """ # TODO: remove this function and restructure custom op / irdl parsing assert isinstance(self, XDSLParser) - args, successors, attributes, regions, _ = self.parse_operation_details( + args, successors, attributes, regions, _ = self.parse_operation_details() + + return op_type.create( + operands=args, + result_types=result_types, + attributes=attributes, + successors=[self._get_block_from_name(span) for span in successors], + regions=regions, ) - return op_type.create(operands=args, - result_types=result_types, - attributes=attributes, - successors=[ - self._get_block_from_name(span) - for span in successors - ], - regions=regions) - - def parse_paramattr_parameters(self, - skip_white_space: bool = True - ) -> list[Attribute]: - opening_brackets = self.tokenizer.next_token_of_pattern('<') + def parse_paramattr_parameters( + self, skip_white_space: bool = True + ) -> list[Attribute]: + opening_brackets = self.tokenizer.next_token_of_pattern("<") if opening_brackets is None: return [] - res = self.parse_list_of(self.try_parse_attribute, - 'Expected another attribute here!') + res = self.parse_list_of( + self.try_parse_attribute, "Expected another attribute here!" + ) - if self.tokenizer.next_token_of_pattern('>') is None: + if self.tokenizer.next_token_of_pattern(">") is None: self.raise_error( "Malformed parameter list, expected either another parameter or `>`!" ) @@ -2308,16 +2448,19 @@ def parse_char(self, text: str): self.parse_characters(text, "Expected '{}' here!".format(text)) def parse_str_literal(self) -> str: - return self.expect(self.try_parse_string_literal, - 'Malformed string literal!').string_contents + return self.expect( + self.try_parse_string_literal, "Malformed string literal!" + ).string_contents def parse_op(self) -> Operation: return self.parse_operation() def parse_int_literal(self) -> int: return int( - self.expect(self.try_parse_integer_literal, - 'Expected integer literal here').text) + self.expect( + self.try_parse_integer_literal, "Expected integer literal here" + ).text + ) def try_parse_builtin_dict_attr(self): param = DictionaryAttr.parse_parameter(self) @@ -2333,18 +2476,18 @@ def parse_optional_punctuation( self._synchronize_lexer_and_tokenizer() # This check is only necessary to catch errors made by users that # are not using pyright. - assert Token.Kind.is_spelling_of_punctuation(punctuation), \ - "'parse_optional_punctuation' must be " \ - "called with a valid punctuation" + assert Token.Kind.is_spelling_of_punctuation(punctuation), ( + "'parse_optional_punctuation' must be " "called with a valid punctuation" + ) kind = Token.Kind.get_punctuation_kind_from_spelling(punctuation) if self._parse_optional_token(kind) is not None: self._synchronize_lexer_and_tokenizer() return punctuation return None - def parse_punctuation(self, - punctuation: Token.PunctuationSpelling, - context_msg: str = '') -> Token.PunctuationSpelling: + def parse_punctuation( + self, punctuation: Token.PunctuationSpelling, context_msg: str = "" + ) -> Token.PunctuationSpelling: """ Parse a punctuation. Punctuations are defined by `Token.PunctuationSpelling`. @@ -2352,8 +2495,9 @@ def parse_punctuation(self, self._synchronize_lexer_and_tokenizer() # This check is only necessary to catch errors made by users that # are not using pyright. - assert Token.Kind.is_spelling_of_punctuation(punctuation), \ - "'parse_punctuation' must be called with a valid punctuation" + assert Token.Kind.is_spelling_of_punctuation( + punctuation + ), "'parse_punctuation' must be called with a valid punctuation" kind = Token.Kind.get_punctuation_kind_from_spelling(punctuation) self._parse_token(kind, f"Expected '{punctuation}'" + context_msg) self._synchronize_lexer_and_tokenizer() @@ -2361,7 +2505,6 @@ def parse_punctuation(self, class MLIRParser(BaseParser): - def try_parse_builtin_type(self) -> Attribute | None: """ parse a builtin-type like i32, index, vector etc. @@ -2370,11 +2513,10 @@ def try_parse_builtin_type(self) -> Attribute | None: # Check the function type separately, it is the only # case of a type starting with a symbol next_token = self.tokenizer.next_token(peek=True) - if next_token.text == '(': + if next_token.text == "(": return self.try_parse_function_type() - name = self.tokenizer.next_token_of_pattern( - ParserCommons.builtin_type) + name = self.tokenizer.next_token_of_pattern(ParserCommons.builtin_type) if name is None: self.raise_error("Expected builtin name!") @@ -2414,22 +2556,23 @@ def parse_attribute(self) -> Attribute: return builtin_val def _parse_op_result(self) -> tuple[Span, int, Attribute | None]: - value_token = self._parse_token(Token.Kind.PERCENT_IDENT, - 'Expected result SSA value!') + value_token = self._parse_token( + Token.Kind.PERCENT_IDENT, "Expected result SSA value!" + ) if self._parse_optional_token(Token.Kind.COLON) is None: return (value_token.span, 1, None) - size_token = self._parse_token(Token.Kind.INTEGER_LIT, - 'Expected SSA value tuple size') + size_token = self._parse_token( + Token.Kind.INTEGER_LIT, "Expected SSA value tuple size" + ) size = size_token.get_int_value() return (value_token.span, size, None) - def _parse_op_result_list( - self) -> list[tuple[Span, int, Attribute | None]]: + def _parse_op_result_list(self) -> list[tuple[Span, int, Attribute | None]]: self._synchronize_lexer_and_tokenizer() - res = self.parse_comma_separated_list(self.Delimiter.NONE, - self._parse_op_result, - ' in operation result list') + res = self.parse_comma_separated_list( + self.Delimiter.NONE, self._parse_op_result, " in operation result list" + ) self._synchronize_lexer_and_tokenizer() return res @@ -2438,8 +2581,13 @@ def parse_optional_attr_dict(self) -> dict[str, Attribute]: def parse_operation_details( self, - ) -> tuple[list[SSAValue], list[Span], dict[str, Attribute], list[Region], - FunctionType | None]: + ) -> tuple[ + list[SSAValue], + list[Span], + dict[str, Attribute], + list[Region], + FunctionType | None, + ]: args = self._parse_op_args_list() succ = self._parse_optional_successor_list() @@ -2452,8 +2600,7 @@ def parse_operation_details( attrs = self.parse_optional_attr_dict() self.parse_characters( - ":", - "MLIR Operation definitions must end in a function type signature!" + ":", "MLIR Operation definitions must end in a function type signature!" ) func_type = self.parse_function_type() @@ -2462,19 +2609,17 @@ def parse_operation_details( def _parse_optional_successor_list(self) -> list[Span]: if not self.tokenizer.starts_with("["): return [] - self.parse_characters("[", - "Successor list is enclosed in square brackets") - successors = self.parse_list_of(self.try_parse_block_id, - "Expected a block-id", - allow_empty=False) - self.parse_characters("]", - "Successor list is enclosed in square brackets") + self.parse_characters("[", "Successor list is enclosed in square brackets") + successors = self.parse_list_of( + self.try_parse_block_id, "Expected a block-id", allow_empty=False + ) + self.parse_characters("]", "Successor list is enclosed in square brackets") return successors def _parse_op_args_list(self) -> list[SSAValue]: - return self.parse_comma_separated_list(self.Delimiter.PAREN, - self.parse_operand, - ' in operation argument list') + return self.parse_comma_separated_list( + self.Delimiter.PAREN, self.parse_operand, " in operation argument list" + ) def parse_region_list(self) -> list[Region]: """ @@ -2483,12 +2628,13 @@ def parse_region_list(self) -> list[Region]: regions: list[Region] = [] while not self.tokenizer.is_eof() and self.tokenizer.starts_with("{"): regions.append(self.parse_region()) - if self.tokenizer.starts_with(','): + if self.tokenizer.starts_with(","): self.parse_characters( - ',', - msg='This error should never be printed, please open ' - 'an issue at github.com/xdslproject/xdsl') - if not self.tokenizer.starts_with('{'): + ",", + msg="This error should never be printed, please open " + "an issue at github.com/xdslproject/xdsl", + ) + if not self.tokenizer.starts_with("{"): self.raise_error( "Expected next region (because of `,` after region end)!" ) @@ -2496,14 +2642,12 @@ def parse_region_list(self) -> list[Region]: class XDSLParser(BaseParser): - def try_parse_builtin_type(self) -> Attribute | None: """ parse a builtin-type like i32, index, vector etc. """ with self.backtracking("builtin type"): - name = self.tokenizer.next_token_of_pattern( - ParserCommons.builtin_type_xdsl) + name = self.tokenizer.next_token_of_pattern(ParserCommons.builtin_type_xdsl) if name is None: self.raise_error("Expected builtin name!") # xDSL builtin types have a '!' prefix, we strip that out here @@ -2533,8 +2677,7 @@ def parse_attribute(self) -> Attribute: return value - def _parse_op_result_list( - self) -> list[tuple[Span, int, Attribute | None]]: + def _parse_op_result_list(self) -> list[tuple[Span, int, Attribute | None]]: if not self.tokenizer.starts_with("%"): return [] results = self.parse_list_of( @@ -2563,22 +2706,28 @@ def parse_optional_attr_dict(self) -> dict[str, Attribute]: return dict() self.parse_characters( - "[", - "xDSL Attribute dictionary must be enclosed in square brackets") + "[", "xDSL Attribute dictionary must be enclosed in square brackets" + ) - attrs = self.parse_list_of(self._parse_attribute_entry, - "Expected attribute entry") + attrs = self.parse_list_of( + self._parse_attribute_entry, "Expected attribute entry" + ) self.parse_characters( - "]", - "xDSL Attribute dictionary must be enclosed in square brackets") + "]", "xDSL Attribute dictionary must be enclosed in square brackets" + ) return self._attr_dict_from_tuple_list(attrs) def parse_operation_details( self, - ) -> tuple[list[SSAValue], list[Span], dict[str, Attribute], list[Region], - FunctionType | None]: + ) -> tuple[ + list[SSAValue], + list[Span], + dict[str, Attribute], + list[Region], + FunctionType | None, + ]: """ Must return a tuple consisting of: - a list of arguments to the operation @@ -2599,17 +2748,16 @@ def parse_operation_details( def _parse_optional_successor_list(self) -> list[Span]: if not self.tokenizer.starts_with("("): return [] - self.parse_characters("(", - "Successor list is enclosed in round brackets") - successors = self.parse_list_of(self.try_parse_block_id, - "Expected a block-id", - allow_empty=False) - self.parse_characters(")", - "Successor list is enclosed in round brackets") + self.parse_characters("(", "Successor list is enclosed in round brackets") + successors = self.parse_list_of( + self.try_parse_block_id, "Expected a block-id", allow_empty=False + ) + self.parse_characters(")", "Successor list is enclosed in round brackets") return successors def _parse_dialect_type_or_attribute_inner( - self, kind: Literal['attribute'] | Literal['type']) -> Attribute: + self, kind: Literal["attribute"] | Literal["type"] + ) -> Attribute: if self.tokenizer.starts_with('"'): name = self.try_parse_string_literal() if name is None: @@ -2625,21 +2773,23 @@ def _parse_generic_attribute_args(self, name: StringLiteral): self.raise_error("Unknown attribute name!", name) if not issubclass(attr, ParametrizedAttribute): self.raise_error("Expected ParametrizedAttribute name here!", name) - self.parse_characters('<', - 'Expected generic attribute arguments here!') - args = self.parse_list_of(self.try_parse_attribute, - 'Unexpected end of attribute list!') + self.parse_characters("<", "Expected generic attribute arguments here!") + args = self.parse_list_of( + self.try_parse_attribute, "Unexpected end of attribute list!" + ) self.parse_characters( - '>', 'Malformed attribute arguments, reached end of args list!') + ">", "Malformed attribute arguments, reached end of args list!" + ) return attr(args) def _parse_op_args_list(self) -> list[SSAValue]: + self.parse_characters("(", "Operation args list must be enclosed by brackets!") + args = self.parse_list_of( + self.try_parse_value_id_and_type, "Expected another bare-id here" + ) self.parse_characters( - "(", "Operation args list must be enclosed by brackets!") - args = self.parse_list_of(self.try_parse_value_id_and_type, - "Expected another bare-id here") - self.parse_characters( - ")", "Operation args list must be closed by a closing bracket") + ")", "Operation args list must be closed by a closing bracket" + ) return [self.ssa_values[arg.text[1:]][0] for arg, _ in args] def try_parse_type(self) -> Attribute | None: @@ -2654,16 +2804,15 @@ class Source(Enum): MLIR = 2 -def Parser(ctx: MLContext, - prog: str, - source: Source = Source.XDSL, - filename: str = '', - allow_unregistered_dialect: bool = False) -> BaseParser: - selected_parser = { - Source.XDSL: XDSLParser, - Source.MLIR: MLIRParser - }[source] +def Parser( + ctx: MLContext, + prog: str, + source: Source = Source.XDSL, + filename: str = "", + allow_unregistered_dialect: bool = False, +) -> BaseParser: + selected_parser = {Source.XDSL: XDSLParser, Source.MLIR: MLIRParser}[source] return selected_parser(ctx, prog, filename, allow_unregistered_dialect) -setattr(Parser, 'Source', Source) +setattr(Parser, "Source", Source) diff --git a/xdsl/passes.py b/xdsl/passes.py index e9bd5a98f6..2471cddbe3 100644 --- a/xdsl/passes.py +++ b/xdsl/passes.py @@ -8,9 +8,10 @@ class ModulePass(ABC): A Pass is a named rewrite pass over an IR module. All passes are expected to leave the IR in a valid state after application. - That is, the IR verifies. In turn, all passes can expect the IR they are + That is, the IR verifies. In turn, all passes can expect the IR they are applied to to be in a valid state. """ + name: str @abstractmethod diff --git a/xdsl/pattern_rewriter.py b/xdsl/pattern_rewriter.py index b544959f85..58c35fbd7d 100644 --- a/xdsl/pattern_rewriter.py +++ b/xdsl/pattern_rewriter.py @@ -7,8 +7,7 @@ from typing import Callable, TypeVar, Union, get_args, get_origin from xdsl.dialects.builtin import ModuleOp -from xdsl.ir import (Operation, Region, Block, BlockArgument, Attribute, - SSAValue) +from xdsl.ir import Operation, Region, Block, BlockArgument, Attribute, SSAValue from xdsl.rewriter import Rewriter @@ -19,18 +18,17 @@ class PatternRewriter: Once an operation is matched, this rewriter is used to apply modification to the operation and its children. """ + current_operation: Operation """The matched operation.""" has_erased_matched_operation: bool = field(default=False, init=False) """Was the matched operation erased.""" - added_operations_before: list[Operation] = field(default_factory=list, - init=False) + added_operations_before: list[Operation] = field(default_factory=list, init=False) """The operations added directly before the matched operation.""" - added_operations_after: list[Operation] = field(default_factory=list, - init=False) + added_operations_after: list[Operation] = field(default_factory=list, init=False) """The operations added directly after the matched operation.""" has_done_action: bool = field(default=False, init=False) @@ -59,8 +57,7 @@ def _can_modify_region(self, region: Region) -> bool: def insert_op_before_matched_op(self, op: (Operation | list[Operation])): """Insert operations before the matched operation.""" if self.current_operation.parent is None: - raise Exception( - "Cannot insert an operation before a toplevel operation.") + raise Exception("Cannot insert an operation before a toplevel operation.") self.has_done_action = True block = self.current_operation.parent op = op if isinstance(op, list) else [op] @@ -72,8 +69,7 @@ def insert_op_before_matched_op(self, op: (Operation | list[Operation])): def insert_op_after_matched_op(self, op: (Operation | list[Operation])): """Insert operations after the matched operation.""" if self.current_operation.parent is None: - raise Exception( - "Cannot insert an operation after a toplevel operation.") + raise Exception("Cannot insert an operation after a toplevel operation.") self.has_done_action = True block = self.current_operation.parent op = op if isinstance(op, list) else [op] @@ -82,8 +78,7 @@ def insert_op_after_matched_op(self, op: (Operation | list[Operation])): block.insert_ops_after(op, self.current_operation) self.added_operations_after += op - def insert_op_at_pos(self, op: Operation | list[Operation], block: Block, - pos: int): + def insert_op_at_pos(self, op: Operation | list[Operation], block: Block, pos: int): """Insert operations in a block contained in the matched operation.""" if not self._can_modify_block(block): raise Exception("Cannot insert operations in block.") @@ -93,12 +88,10 @@ def insert_op_at_pos(self, op: Operation | list[Operation], block: Block, return block.insert_op(op, pos) - def insert_op_before(self, op: Operation | list[Operation], - target_op: Operation): + def insert_op_before(self, op: Operation | list[Operation], target_op: Operation): """Insert operations before an operation contained in the matched operation.""" if target_op.parent is None: - raise Exception( - "Cannot insert operations before toplevel operation.") + raise Exception("Cannot insert operations before toplevel operation.") target_block = target_op.parent if not self._can_modify_block(target_block): raise Exception("Cannot insert operations in this block.") @@ -108,12 +101,10 @@ def insert_op_before(self, op: Operation | list[Operation], return target_block.insert_ops_before(op, target_op) - def insert_op_after(self, op: Operation | list[Operation], - target_op: Operation): + def insert_op_after(self, op: Operation | list[Operation], target_op: Operation): """Insert operations after an operation contained in the matched operation.""" if target_op.parent is None: - raise Exception( - "Cannot insert operations after toplevel operation.") + raise Exception("Cannot insert operations after toplevel operation.") target_block = target_op.parent if not self._can_modify_block(target_block): raise Exception("Cannot insert operations in this block.") @@ -145,13 +136,16 @@ def erase_op(self, op: Operation, safe_erase: bool = True): if not self._can_modify_op(op): raise Exception( "PatternRewriter can only erase operations that are the matched operation" - ", or that are contained in the matched operation.") + ", or that are contained in the matched operation." + ) Rewriter.erase_op(op, safe_erase=safe_erase) - def replace_matched_op(self, - new_ops: Operation | list[Operation], - new_results: list[SSAValue | None] | None = None, - safe_erase: bool = True): + def replace_matched_op( + self, + new_ops: Operation | list[Operation], + new_results: list[SSAValue | None] | None = None, + safe_erase: bool = True, + ): """ Replace the matched operation with new operations. Also, optionally specify SSA values to replace the operation results. @@ -162,17 +156,18 @@ def replace_matched_op(self, if not isinstance(new_ops, list): new_ops = [new_ops] self.has_erased_matched_operation = True - Rewriter.replace_op(self.current_operation, - new_ops, - new_results, - safe_erase=safe_erase) + Rewriter.replace_op( + self.current_operation, new_ops, new_results, safe_erase=safe_erase + ) self.added_operations_before += new_ops - def replace_op(self, - op: Operation, - new_ops: Operation | list[Operation], - new_results: list[SSAValue | None] | None = None, - safe_erase: bool = True): + def replace_op( + self, + op: Operation, + new_ops: Operation | list[Operation], + new_results: list[SSAValue | None] | None = None, + safe_erase: bool = True, + ): """ Replace an operation with new operations. The operation should be a child of the matched operation. @@ -186,11 +181,11 @@ def replace_op(self, if not self._can_modify_op(op): raise Exception( "PatternRewriter can only replace operations that are the matched " - "operation, or that are contained in the matched operation.") + "operation, or that are contained in the matched operation." + ) Rewriter.replace_op(op, new_ops, new_results, safe_erase=safe_erase) - def modify_block_argument_type(self, arg: BlockArgument, - new_type: Attribute): + def modify_block_argument_type(self, arg: BlockArgument, new_type: Attribute): """ Modify the type of a block argument. The block should be contained in the matched operation. @@ -202,8 +197,9 @@ def modify_block_argument_type(self, arg: BlockArgument, self.has_done_action = True arg.typ = new_type - def insert_block_argument(self, block: Block, index: int, - typ: Attribute) -> BlockArgument: + def insert_block_argument( + self, block: Block, index: int, typ: Attribute + ) -> BlockArgument: """ Insert a new block argument. The block should be contained in the matched operation. @@ -215,9 +211,7 @@ def insert_block_argument(self, block: Block, index: int, self.has_done_action = True return block.insert_arg(typ, index) - def erase_block_argument(self, - arg: BlockArgument, - safe_erase: bool = True) -> None: + def erase_block_argument(self, arg: BlockArgument, safe_erase: bool = True) -> None: """ Erase a new block argument. The block should be contained in the matched operation. @@ -238,8 +232,9 @@ def inline_block_at_pos(self, block: Block, target_block: Block, pos: int): should be child of the matched operation. """ self.has_done_action = True - if not self._can_modify_block( - target_block) or not self._can_modify_block(block): + if not self._can_modify_block(target_block) or not self._can_modify_block( + block + ): raise Exception( "Cannot modify blocks that are not contained in the matched operation." ) @@ -276,7 +271,8 @@ def inline_block_before(self, block: Block, op: Operation): if not self._can_modify_op(op): raise Exception( "Cannot move block elsewhere than before the matched operation," - " or before an operation child") + " or before an operation child" + ) Rewriter.inline_block_before(block, op) def inline_block_after(self, block: Block, op: Operation): @@ -290,7 +286,8 @@ def inline_block_after(self, block: Block, op: Operation): if op is self.current_operation: return self.inline_block_before_matched_op(block) if not self._can_modify_block(block) or ( - op.parent and not self._can_modify_block(op.parent)): + op.parent and not self._can_modify_block(op.parent) + ): raise Exception( "Cannot move blocks that are not contained in the matched operation." ) @@ -330,27 +327,27 @@ class AnonymousRewritePattern(RewritePattern): """ A rewrite pattern encoded by an anonymous function. """ + func: Callable[[RewritePattern, Operation, PatternRewriter], None] def __init__( self, func: Callable[[RewritePattern, Operation, PatternRewriter], None] - | Callable[[Operation, PatternRewriter], None]): - params = [ - param for param in inspect.signature(func).parameters.values() - ] + | Callable[[Operation, PatternRewriter], None], + ): + params = [param for param in inspect.signature(func).parameters.values()] if len(params) == 2: - def new_func(self: RewritePattern, op: Operation, - rewriter: PatternRewriter): + def new_func( + self: RewritePattern, op: Operation, rewriter: PatternRewriter + ): func(op, rewriter) # type: ignore self.func = new_func else: self.func = func # type: ignore - def match_and_rewrite(self, op: Operation, - rewriter: PatternRewriter) -> None: + def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None: self.func(self, op, rewriter) @@ -372,34 +369,38 @@ def op_type_rewrite_pattern( if len(params) < 2: raise Exception( "op_type_rewrite_pattern expects the decorated function to " - "have two non-self arguments.") + "have two non-self arguments." + ) is_method = params[0].name == "self" if is_method: if len(params) != 3: raise Exception( "op_type_rewrite_pattern expects the decorated method to " - "have two non-self arguments.") + "have two non-self arguments." + ) else: if len(params) != 2: raise Exception( "op_type_rewrite_pattern expects the decorated function to " - "have two arguments.") + "have two arguments." + ) expected_type: type[_OperationT] = params[-2].annotation - expected_types = (expected_type, ) + expected_types = (expected_type,) if get_origin(expected_type) in [Union, UnionType]: expected_types = get_args(expected_type) if not all(issubclass(t, Operation) for t in expected_types): raise Exception( "op_type_rewrite_pattern expects the first non-self argument " "type hint to be an `Operation` subclass or a union of `Operation` " - "subclasses.") + "subclasses." + ) if not is_method: def op_type_rewrite_pattern_static_wrapper( - self: RewritePattern, op: Operation, - rewriter: PatternRewriter) -> None: + self: RewritePattern, op: Operation, rewriter: PatternRewriter + ) -> None: if not isinstance(op, expected_type): return None func(op, rewriter) # type: ignore @@ -407,8 +408,8 @@ def op_type_rewrite_pattern_static_wrapper( return op_type_rewrite_pattern_static_wrapper def op_type_rewrite_pattern_method_wrapper( - self: _RewritePatternT, op: Operation, - rewriter: PatternRewriter) -> None: + self: _RewritePatternT, op: Operation, rewriter: PatternRewriter + ) -> None: if not isinstance(op, expected_type): return None func(self, op, rewriter) # type: ignore @@ -426,8 +427,7 @@ class GreedyRewritePatternApplier(RewritePattern): rewrite_patterns: list[RewritePattern] """The list of rewrites to apply in order.""" - def match_and_rewrite(self, op: Operation, - rewriter: PatternRewriter) -> None: + def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None: for pattern in self.rewrite_patterns: pattern.match_and_rewrite(op, rewriter) if rewriter.has_done_action: @@ -482,10 +482,13 @@ def _rewrite_op(self, op: Operation) -> int: if rewriter.has_done_action: # If we produce new operations, we rewrite them recursively if requested if self.apply_recursively: - return (len(rewriter.added_operations_before) + - len(rewriter.added_operations_after) - - int(rewriter.has_erased_matched_operation) - if self.walk_reverse else 0) + return ( + len(rewriter.added_operations_before) + + len(rewriter.added_operations_after) + - int(rewriter.has_erased_matched_operation) + if self.walk_reverse + else 0 + ) # Else, we rewrite only their regions if they are supposed to be # rewritten after else: @@ -496,10 +499,13 @@ def _rewrite_op(self, op: Operation) -> int: self._rewrite_op_regions(op) for new_op in rewriter.added_operations_after: self._rewrite_op_regions(new_op) - return -1 if self.walk_reverse else len( - rewriter.added_operations_before) + len( - rewriter.added_operations_after) + int( - not rewriter.has_erased_matched_operation) + return ( + -1 + if self.walk_reverse + else len(rewriter.added_operations_before) + + len(rewriter.added_operations_after) + + int(not rewriter.has_erased_matched_operation) + ) # Otherwise, we only rewrite the regions of the operation if needed if not self.walk_regions_first: diff --git a/xdsl/printer.py b/xdsl/printer.py index d0a6d5e62a..6856d878ff 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -6,25 +6,62 @@ from typing import Iterable, Sequence, TypeVar, Any, Dict, Optional, List, cast from xdsl.dialects.memref import AnyUnrankedMemrefType, MemRefType, UnrankedMemrefType -from xdsl.ir import (BlockArgument, TypeAttribute, SSAValue, Block, Callable, - Attribute, Region, Operation, Data, ParametrizedAttribute) +from xdsl.ir import ( + BlockArgument, + TypeAttribute, + SSAValue, + Block, + Callable, + Attribute, + Region, + Operation, + Data, + ParametrizedAttribute, +) from xdsl.utils.diagnostic import Diagnostic from xdsl.dialects.builtin import ( - AnyIntegerAttr, AnyFloatAttr, AnyUnrankedTensorType, AnyVectorType, - BFloat16Type, ComplexType, DenseArrayBase, DenseIntOrFPElementsAttr, - DenseResourceAttr, Float128Type, Float16Type, Float32Type, Float64Type, - Float80Type, FloatAttr, FloatData, IndexType, IntegerType, NoneAttr, - OpaqueAttr, Signedness, StridedLayoutAttr, StringAttr, SymbolRefAttr, - IntegerAttr, ArrayAttr, IntAttr, TensorType, UnitAttr, FunctionType, - UnrankedTensorType, UnregisteredAttr, UnregisteredOp, VectorType, - DictionaryAttr) + AnyIntegerAttr, + AnyFloatAttr, + AnyUnrankedTensorType, + AnyVectorType, + BFloat16Type, + ComplexType, + DenseArrayBase, + DenseIntOrFPElementsAttr, + DenseResourceAttr, + Float128Type, + Float16Type, + Float32Type, + Float64Type, + Float80Type, + FloatAttr, + FloatData, + IndexType, + IntegerType, + NoneAttr, + OpaqueAttr, + Signedness, + StridedLayoutAttr, + StringAttr, + SymbolRefAttr, + IntegerAttr, + ArrayAttr, + IntAttr, + TensorType, + UnitAttr, + FunctionType, + UnrankedTensorType, + UnregisteredAttr, + UnregisteredOp, + VectorType, + DictionaryAttr, +) indentNumSpaces = 2 @dataclass(eq=False, repr=False) class Printer: - class Target(Enum): XDSL = 1 MLIR = 2 @@ -45,8 +82,9 @@ class Target(Enum): _next_valid_block_id: int = field(default=0, init=False) _current_line: int = field(default=0, init=False) _current_column: int = field(default=0, init=False) - _next_line_callback: List[Callable[[], None]] = field(default_factory=list, - init=False) + _next_line_callback: List[Callable[[], None]] = field( + default_factory=list, init=False + ) def print(self, *argv: Any) -> None: for arg in argv: @@ -74,16 +112,15 @@ def print(self, *argv: Any) -> None: self.print_string(text) def print_string(self, text: str) -> None: - lines = text.split('\n') + lines = text.split("\n") if len(lines) != 1: self._current_line += len(lines) - 1 self._current_column = len(lines[-1]) else: self._current_column += len(lines[-1]) - print(text, end='', file=self.stream) + print(text, end="", file=self.stream) - def _add_message_on_next_line(self, message: str, begin_pos: int, - end_pos: int): + def _add_message_on_next_line(self, message: str, begin_pos: int, end_pos: int): """Add a message that will be displayed on the next line.""" def callback(indent: int = self._indent): @@ -91,11 +128,9 @@ def callback(indent: int = self._indent): self._next_line_callback.append(callback) - def _print_message(self, - message: str, - begin_pos: int, - end_pos: int, - indent: int | None = None): + def _print_message( + self, message: str, begin_pos: int, end_pos: int, indent: int | None = None + ): """ Print a message. This is expected to be called at the beginning of a new line and to create a new @@ -106,9 +141,11 @@ def _print_message(self, indent_size = indent * indentNumSpaces self.print(" " * indent_size) message_end_pos = max(map(len, message.split("\n"))) + indent_size + 2 - first_line = (begin_pos - indent_size) * "-" + ( - end_pos - begin_pos) * "^" + (max(message_end_pos, end_pos) - - end_pos) * "-" + first_line = ( + (begin_pos - indent_size) * "-" + + (end_pos - begin_pos) * "^" + + (max(message_end_pos, end_pos) - end_pos) * "-" + ) self.print(first_line) self._print_new_line(indent=indent, print_message=False) for message_line in message.split("\n"): @@ -118,24 +155,25 @@ def _print_message(self, self.print("-" * (max(message_end_pos, end_pos) - indent_size)) self._print_new_line(indent=0, print_message=False) - T = TypeVar('T') - K = TypeVar('K') - V = TypeVar('V') + T = TypeVar("T") + K = TypeVar("K") + V = TypeVar("V") - def print_list(self, - elems: Iterable[T], - print_fn: Callable[[T], None], - delimiter: str = ", ") -> None: + def print_list( + self, elems: Iterable[T], print_fn: Callable[[T], None], delimiter: str = ", " + ) -> None: for i, elem in enumerate(elems): if i: self.print(delimiter) print_fn(elem) - def print_dictionary(self, - elems: dict[K, V], - print_key: Callable[[K], None], - print_value: Callable[[V], None], - delimiter: str = ", ") -> None: + def print_dictionary( + self, + elems: dict[K, V], + print_key: Callable[[K], None], + print_value: Callable[[V], None], + delimiter: str = ", ", + ) -> None: for i, (key, value) in enumerate(elems.items()): if i: self.print(delimiter) @@ -143,9 +181,9 @@ def print_dictionary(self, self.print("=") print_value(value) - def _print_new_line(self, - indent: int | None = None, - print_message: bool = True) -> None: + def _print_new_line( + self, indent: int | None = None, print_message: bool = True + ) -> None: indent = self._indent if indent is None else indent self.print("\n") if print_message: @@ -209,7 +247,10 @@ def print_ssa_value(self, value: SSAValue) -> None: end_pos = self._current_column self._add_message_on_next_line( "ERROR: SSAValue is not part of the IR, are you sure all operations " - "are added before their uses?", begin_pos, end_pos) + "are added before their uses?", + begin_pos, + end_pos, + ) def _print_operand(self, operand: SSAValue) -> None: self.print_ssa_value(operand) @@ -226,7 +267,7 @@ def print_block_name(self, block: Block) -> None: def print_block(self, block: Block, print_block_name: bool = True) -> None: if not isinstance(block, Block): - raise TypeError('Expected a Block; got %s' % type(block).__name__) + raise TypeError("Expected a Block; got %s" % type(block).__name__) print_block_args = len(block.args) > 0 if print_block_args or print_block_name: @@ -258,8 +299,7 @@ def _print_block_arg(self, arg: BlockArgument) -> None: def print_region(self, region: Region) -> None: if not isinstance(region, Region): - raise TypeError('Expected a Region; got %s' % - type(region).__name__) + raise TypeError("Expected a Region; got %s" % type(region).__name__) print_block_name = len(region.blocks) != 1 @@ -294,9 +334,8 @@ def _print_operands(self, operands: tuple[SSAValue, ...]) -> None: self.print(")") def print_paramattr_parameters( - self, - params: list[Attribute], - always_print_brackets: bool = False) -> None: + self, params: list[Attribute], always_print_brackets: bool = False + ) -> None: if len(params) == 0 and not always_print_brackets: return self.print("<") @@ -324,22 +363,22 @@ def print_attribute(self, attribute: Attribute) -> None: if self.target == self.Target.MLIR: if isinstance(attribute, BFloat16Type): - self.print('bf16') + self.print("bf16") return if isinstance(attribute, Float16Type): - self.print('f16') + self.print("f16") return if isinstance(attribute, Float32Type): - self.print('f32') + self.print("f32") return if isinstance(attribute, Float64Type): - self.print('f64') + self.print("f64") return if isinstance(attribute, Float80Type): - self.print('f80') + self.print("f80") return if isinstance(attribute, Float128Type): - self.print('f128') + self.print("f128") return if isinstance(attribute, StringAttr): @@ -347,23 +386,22 @@ def print_attribute(self, attribute: Attribute) -> None: return if isinstance(attribute, SymbolRefAttr): - self.print(f'@{attribute.root_reference.data}') + self.print(f"@{attribute.root_reference.data}") for ref in attribute.nested_references.data: - self.print(f'::@{ref.data}') + self.print(f"::@{ref.data}") return if isinstance(attribute, IntegerAttr): attribute = cast(AnyIntegerAttr, attribute) # boolean shorthands - if (isinstance((typ := attribute.typ), IntegerType) - and typ.width.data == 1): + if isinstance((typ := attribute.typ), IntegerType) and typ.width.data == 1: self.print("false" if attribute.value.data == 0 else "true") return width = attribute.parameters[0] typ = attribute.parameters[1] - assert (isinstance(width, IntAttr)) + assert isinstance(width, IntAttr) self.print(width.data) self.print(" : ") self.print_attribute(typ) @@ -371,8 +409,9 @@ def print_attribute(self, attribute: Attribute) -> None: if isinstance(attribute, FloatAttr): value = attribute.value - typ = cast(FloatAttr[Float16Type | Float32Type | Float64Type], - attribute).type + typ = cast( + FloatAttr[Float16Type | Float32Type | Float64Type], attribute + ).type self.print(value.data) self.print(" : ") self.print_attribute(typ) @@ -385,9 +424,7 @@ def print_attribute(self, attribute: Attribute) -> None: if isinstance(attribute, ArrayAttr): self.print_string("[") - self.print_list( - attribute.data, # type: ignore - self.print_attribute) + self.print_list(attribute.data, self.print_attribute) # type: ignore self.print_string("]") return @@ -404,14 +441,14 @@ def print_attribute(self, attribute: Attribute) -> None: if isinstance(attribute, DictionaryAttr): self.print_string("{") - self.print_dictionary(attribute.data, self.print_string_literal, - self.print_attribute) + self.print_dictionary( + attribute.data, self.print_string_literal, self.print_attribute + ) self.print_string("}") return # Function types have an alias in MLIR, but not in xDSL - if (isinstance(attribute, FunctionType) - and self.target == self.Target.MLIR): + if isinstance(attribute, FunctionType) and self.target == self.Target.MLIR: self.print("(") self.print_list(attribute.inputs.data, self.print_attribute) self.print(") -> ") @@ -425,35 +462,39 @@ def print_attribute(self, attribute: Attribute) -> None: return # Dense element types have an alias in MLIR, but not in xDSL - if (isinstance(attribute, DenseIntOrFPElementsAttr) - and self.target == self.Target.MLIR): + if ( + isinstance(attribute, DenseIntOrFPElementsAttr) + and self.target == self.Target.MLIR + ): def print_one_elem(val: Attribute): if isinstance(val, IntegerAttr | FloatAttr): self.print(val.value.data) else: - raise Exception("unexpected attribute type " - "in DenseIntOrFPElementsAttr: " - f"{type(val)}") - - def print_dense_list(array: Sequence[AnyIntegerAttr] - | Sequence[AnyFloatAttr], shape: List[int]): - - self.print('[') + raise Exception( + "unexpected attribute type " + "in DenseIntOrFPElementsAttr: " + f"{type(val)}" + ) + + def print_dense_list( + array: Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr], + shape: List[int], + ): + self.print("[") if len(shape) > 1: k = len(array) // shape[0] self.print_list( - (array[i:i + k] for i in range(0, len(array), k)), - lambda subarray: print_dense_list(subarray, shape[1:])) + (array[i : i + k] for i in range(0, len(array), k)), + lambda subarray: print_dense_list(subarray, shape[1:]), + ) else: self.print_list(array, print_one_elem) - self.print(']') + self.print("]") self.print("dense<") data = attribute.data.data - shape = attribute.shape if attribute.shape_is_complete else [ - len(data) - ] + shape = attribute.shape if attribute.shape_is_complete else [len(data)] assert shape is not None, "If shape is complete, then it cannot be None" if len(data) == 0: pass @@ -472,26 +513,27 @@ def print_dense_list(array: Sequence[AnyIntegerAttr] return # tensor types have an alias in MLIR, but not in xDSL - if ((isinstance(attribute, TensorType)) - and self.target == self.Target.MLIR): + if (isinstance(attribute, TensorType)) and self.target == self.Target.MLIR: attribute = cast(AnyVectorType, attribute) self.print("tensor<") self.print_list( - attribute.shape.data, lambda x: self.print(x.value.data) - if x.value.data != -1 else self.print("?"), "x") + attribute.shape.data, + lambda x: self.print(x.value.data) + if x.value.data != -1 + else self.print("?"), + "x", + ) if len(attribute.shape.data) != 0: self.print("x") self.print(attribute.element_type) - if isinstance(attribute, - TensorType) and attribute.encoding != NoneAttr(): + if isinstance(attribute, TensorType) and attribute.encoding != NoneAttr(): self.print(", ") self.print(attribute.encoding) self.print(">") return # vector types have an alias in MLIR, but not in xDSL - if (isinstance(attribute, VectorType) - and self.target == self.Target.MLIR): + if isinstance(attribute, VectorType) and self.target == self.Target.MLIR: attribute = cast(AnyVectorType, attribute) shape = attribute.get_shape() @@ -500,30 +542,29 @@ def print_dense_list(array: Sequence[AnyIntegerAttr] static_dimensions = shape scalable_dimensions = [] else: - static_dimensions = shape[:-attribute.get_num_scalable_dims()] - scalable_dimensions = shape[-attribute.get_num_scalable_dims( - ):] + static_dimensions = shape[: -attribute.get_num_scalable_dims()] + scalable_dimensions = shape[-attribute.get_num_scalable_dims() :] - self.print('vector<') + self.print("vector<") if len(static_dimensions) != 0: - self.print_list(static_dimensions, lambda x: self.print(x), - 'x') - self.print('x') + self.print_list(static_dimensions, lambda x: self.print(x), "x") + self.print("x") if len(scalable_dimensions) != 0: - self.print('[') - self.print_list(scalable_dimensions, lambda x: self.print(x), - 'x') - self.print(']') - self.print('x') + self.print("[") + self.print_list(scalable_dimensions, lambda x: self.print(x), "x") + self.print("]") + self.print("x") self.print(attribute.element_type) - self.print('>') + self.print(">") return # Unranked tensors have an alias in MLIR, but not in xDSL - if (isinstance(attribute, UnrankedTensorType) - and self.target == self.Target.MLIR): + if ( + isinstance(attribute, UnrankedTensorType) + and self.target == self.Target.MLIR + ): attribute = cast(AnyUnrankedTensorType, attribute) self.print("tensor<*x") self.print(attribute.element_type) @@ -535,27 +576,29 @@ def print_dense_list(array: Sequence[AnyIntegerAttr] self.print("strided<[") def print_int_or_question(value: IntAttr | NoneAttr) -> None: - self.print(value.data if isinstance(value, IntAttr) else '?') + self.print(value.data if isinstance(value, IntAttr) else "?") - self.print_list(attribute.strides.data, print_int_or_question, - ', ') - self.print(']') + self.print_list(attribute.strides.data, print_int_or_question, ", ") + self.print("]") if attribute.offset == IntAttr(0): - self.print('>') + self.print(">") return - self.print(', offset: ') + self.print(", offset: ") print_int_or_question(attribute.offset) - self.print('>') + self.print(">") return # memref types have an alias in MLIR, but not in xDSL - if (isinstance(attribute, MemRefType) - and self.target == self.Target.MLIR): + if isinstance(attribute, MemRefType) and self.target == self.Target.MLIR: attribute = cast(MemRefType[Attribute], attribute) self.print("memref<") self.print_list( - attribute.shape.data, lambda x: self.print(x.value.data) - if x.value.data != -1 else self.print("?"), "x") + attribute.shape.data, + lambda x: self.print(x.value.data) + if x.value.data != -1 + else self.print("?"), + "x", + ) self.print("x", attribute.element_type) if not isinstance(attribute.layout, NoneAttr): self.print(", ", attribute.layout) @@ -565,8 +608,10 @@ def print_int_or_question(value: IntAttr | NoneAttr) -> None: return # Unranked tensors have an alias in MLIR, but not in xDSL - if (isinstance(attribute, UnrankedMemrefType) - and self.target == self.Target.MLIR): + if ( + isinstance(attribute, UnrankedMemrefType) + and self.target == self.Target.MLIR + ): attribute = cast(AnyUnrankedMemrefType, attribute) self.print("memref<*x") self.print(attribute.element_type) @@ -576,14 +621,12 @@ def print_int_or_question(value: IntAttr | NoneAttr) -> None: return # IndexType has an alias in MLIR, but not in xDSL - if (isinstance(attribute, IndexType) - and self.target == self.Target.MLIR): + if isinstance(attribute, IndexType) and self.target == self.Target.MLIR: self.print("index") return # opaque attributes have an alias in MLIR, but not in xDSL - if (isinstance(attribute, OpaqueAttr) - and self.target == self.Target.MLIR): + if isinstance(attribute, OpaqueAttr) and self.target == self.Target.MLIR: self.print("opaque<", attribute.ident, ", ", attribute.value, ">") if not isinstance(attribute.type, NoneAttr): self.print(" : ", attribute.type) @@ -591,8 +634,8 @@ def print_int_or_question(value: IntAttr | NoneAttr) -> None: if isinstance(attribute, UnregisteredAttr): # Do not print `!` or `#` for unregistered builtin attributes - if attribute.attr_name.data not in ['affine_map', 'affine_set']: - self.print('!' if attribute.is_type.data else '#') + if attribute.attr_name.data not in ["affine_map", "affine_set"]: + self.print("!" if attribute.is_type.data else "#") self.print(attribute.attr_name.data, attribute.value.data) return @@ -614,24 +657,25 @@ def print_int_or_question(value: IntAttr | NoneAttr) -> None: return if isinstance(attribute, Data): - self.print(f'!{attribute.name}<') + self.print(f"!{attribute.name}<") attribute = cast(Data[Any], attribute) attribute.print_parameter(self) self.print(">") return assert isinstance( - attribute, - ParametrizedAttribute), f'{attribute}: {type(attribute)}' + attribute, ParametrizedAttribute + ), f"{attribute}: {type(attribute)}" # Print parametrized attribute with default formatting if self.target == self.Target.XDSL and self.print_generic_format: self.print(f'!"{attribute.name}"') - self.print_paramattr_parameters(attribute.parameters, - always_print_brackets=True) + self.print_paramattr_parameters( + attribute.parameters, always_print_brackets=True + ) return - self.print(f'!{attribute.name}') + self.print(f"!{attribute.name}") attribute.print_parameters(self) def print_successors(self, successors: List[Block]): @@ -643,9 +687,9 @@ def print_successors(self, successors: List[Block]): def _print_attr_string(self, attr_tuple: tuple[str, Attribute]) -> None: if isinstance(attr_tuple[1], UnitAttr): - self.print(f"\"{attr_tuple[0]}\"") + self.print(f'"{attr_tuple[0]}"') else: - self.print(f"\"{attr_tuple[0]}\" = ") + self.print(f'"{attr_tuple[0]}" = ') self.print_attribute(attr_tuple[1]) def _print_op_attributes(self, attributes: Dict[str, Attribute]) -> None: @@ -675,8 +719,9 @@ def print_op_with_default_format(self, op: Operation) -> None: # Print the operation type if self.target == self.Target.MLIR: self.print(" : (") - self.print_list(op.operands, - lambda operand: self.print_attribute(operand.typ)) + self.print_list( + op.operands, lambda operand: self.print_attribute(operand.typ) + ) self.print(") -> ") if len(op.results) == 0: self.print("()") @@ -690,14 +735,13 @@ def print_op_with_default_format(self, op: Operation) -> None: else: self.print("(") self.print_list( - op.results, - lambda result: self.print_attribute(result.typ)) + op.results, lambda result: self.print_attribute(result.typ) + ) self.print(")") def print_op(self, op: Operation) -> None: if not isinstance(op, Operation): - raise TypeError('Expected an Operation; got %s' % - type(op).__name__) + raise TypeError("Expected an Operation; got %s" % type(op).__name__) begin_op_pos = self._current_column self._print_results(op) if isinstance(op, UnregisteredOp): @@ -709,8 +753,7 @@ def print_op(self, op: Operation) -> None: end_op_pos = self._current_column if op in self.diagnostic.op_messages: for message in self.diagnostic.op_messages[op]: - self._add_message_on_next_line(message, begin_op_pos, - end_op_pos) + self._add_message_on_next_line(message, begin_op_pos, end_op_pos) if isinstance(op, UnregisteredOp): op_name = op.op_name del op.attributes["op_name__"] diff --git a/xdsl/rewriter.py b/xdsl/rewriter.py index d2a3b7828b..67dd1a6e73 100644 --- a/xdsl/rewriter.py +++ b/xdsl/rewriter.py @@ -5,7 +5,6 @@ class Rewriter: - @staticmethod def erase_op(op: Operation, safe_erase: bool = True): """ @@ -21,11 +20,11 @@ def erase_op(op: Operation, safe_erase: bool = True): @staticmethod def replace_op( - op: Operation, - new_ops: Operation | List[Operation], - new_results: Optional[List[Optional[SSAValue]] - | List[OpResult]] = None, # noqa - safe_erase: bool = True): + op: Operation, + new_ops: Operation | List[Operation], + new_results: Optional[List[Optional[SSAValue]] | List[OpResult]] = None, # noqa + safe_erase: bool = True, + ): """ Replace an operation with multiple new ones. If new_results is specified, map the results of the deleted operations with these @@ -73,11 +72,11 @@ def inline_block_at_pos(block: Block, target_block: Block, pos: int): raise Exception("Cannot inline a block in a child block.") for op in block.ops: for operand in op.operands: - if isinstance(operand, - BlockArgument) and operand.block is block: + if isinstance(operand, BlockArgument) and operand.block is block: raise Exception( "Cannot inline block which has operations using " - "the block arguments.") + "the block arguments." + ) ops = block.ops.copy() for op in ops: op.detach() @@ -91,8 +90,7 @@ def inline_block_before(block: Block, op: Operation): The block operations should not use the block arguments. """ if op.parent is None: - raise Exception( - "Cannot inline a block before a toplevel operation") + raise Exception("Cannot inline a block before a toplevel operation") op_block = op.parent op_pos = op_block.get_operation_index(op) Rewriter.inline_block_at_pos(block, op_block, op_pos) @@ -105,8 +103,7 @@ def inline_block_after(block: Block, op: Operation): The block operations should not use the block arguments. """ if op.parent is None: - raise Exception( - "Cannot inline a block before a toplevel operation") + raise Exception("Cannot inline a block before a toplevel operation") op_block = op.parent op_pos = op_block.get_operation_index(op) Rewriter.inline_block_at_pos(block, op_block, op_pos + 1) @@ -145,17 +142,15 @@ def insert_block_before(block: Block | List[Block], target: Block): def insert_op_after(op: Operation, new_op: Operation): """Inserts a new operation after another operation.""" if op.parent is None: - raise Exception( - "Cannot insert an operation after a toplevel operation") - op.parent.insert_ops_after((new_op, ), op) + raise Exception("Cannot insert an operation after a toplevel operation") + op.parent.insert_ops_after((new_op,), op) @staticmethod def insert_op_before(op: Operation, new_op: Operation): """Inserts a new operation before another operation.""" if op.parent is None: - raise Exception( - "Cannot insert an operation before a toplevel operation") - op.parent.insert_ops_before((new_op, ), op) + raise Exception("Cannot insert an operation before a toplevel operation") + op.parent.insert_ops_before((new_op,), op) @staticmethod def move_region_contents_to_new_regions(region: Region) -> Region: diff --git a/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py b/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py index 6b27259396..804dbfd447 100644 --- a/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py +++ b/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py @@ -3,7 +3,15 @@ from dataclasses import dataclass from typing import Sequence, TypeGuard, Any from immutabledict import immutabledict -from xdsl.ir import Attribute, Block, BlockArgument, OpResult, Operation, Region, SSAValue +from xdsl.ir import ( + Attribute, + Block, + BlockArgument, + OpResult, + Operation, + Region, + SSAValue, +) from xdsl.utils.exceptions import InvalidIRException from xdsl.utils.immutable_list import IList @@ -14,6 +22,7 @@ class ISSAValue(ABC): Represents an immutable SSA variable. An immutable SSA variable is either an operation result or a basic block argument. """ + typ: Attribute users: IList[IOperation] @@ -36,6 +45,7 @@ def _remove_user(self, op: IOperation): @dataclass(frozen=True) class IOpResult(ISSAValue): """Represents an immutable SSA variable defined by an operation result.""" + op: IOperation index: int @@ -51,6 +61,7 @@ def __eq__(self, __o: IOpResult) -> bool: @dataclass(frozen=True) class IBlockArg(ISSAValue): """Represents an immutable SSA variable defined by a basic block.""" + block: IBlock index: int @@ -61,9 +72,7 @@ def __eq__(self, __o: IBlockArg) -> bool: return self is __o def __repr__(self) -> str: - return "BlockArg(type:" + self.typ.name + ( - "attached" - if self.block is not None else "unattached") + ")" # type: ignore + return "BlockArg(type:" + self.typ.name + ("attached") + ")" # type: ignore @dataclass(frozen=True) @@ -88,7 +97,8 @@ def block(self) -> IBlock: if len(self.blocks) != 1: raise ValueError( "'block' property of IRegion class is only available " - "for single-block regions.") + "for single-block regions." + ) return self.blocks[0] @property @@ -100,7 +110,8 @@ def ops(self) -> IList[IOperation]: if len(self.blocks) != 1: raise ValueError( "'ops' property of IRegion class is only available " - "for single-block regions.") + "for single-block regions." + ) return self.block.ops def __init__(self, blocks: Sequence[IBlock]): @@ -129,7 +140,8 @@ def from_mutable( if blocks[0].parent is None: raise InvalidIRException( "Cannot create an IRegion from a mutable Block " - "that is not attached to a Region.") + "that is not attached to a Region." + ) # adding dummy block mappings so that ops have a successor to reference # when the actual block is created all successor references will be moved @@ -138,8 +150,7 @@ def from_mutable( block_map[block] = IBlock([], []) immutable_blocks = [ - IBlock.from_mutable(block, value_map, block_map) - for block in blocks + IBlock.from_mutable(block, value_map, block_map) for block in blocks ] region = IRegion(immutable_blocks) @@ -153,15 +164,22 @@ def from_mutable( dummy_index = op.successors.index(dummy_block) # replace dummy successor with actual successor object.__setattr__( - op, "successors", - IList(op.successors[:dummy_index] + [imm_block] + - op.successors[dummy_index + 1:])) + op, + "successors", + IList( + op.successors[:dummy_index] + + [imm_block] + + op.successors[dummy_index + 1 :] + ), + ) return region - def to_mutable(self, - value_mapping: dict[ISSAValue, SSAValue] | None = None, - block_mapping: dict[IBlock, Block] | None = None) -> Region: + def to_mutable( + self, + value_mapping: dict[ISSAValue, SSAValue] | None = None, + block_mapping: dict[IBlock, Block] | None = None, + ) -> Region: """ Returns a mutable region that is a copy of this immutable region. The value_mapping and block_mapping are used to map already known correspondings @@ -175,19 +193,18 @@ def to_mutable(self, # All mutable blocks have to be initialized first so that ops can # refer to them in their successor lists. for block in self.blocks: - mutable_blocks.append(mutable_block := Block( - arg_types=block.arg_types)) + mutable_blocks.append(mutable_block := Block(arg_types=block.arg_types)) block_mapping[block] = mutable_block for block in self.blocks: # This will use the already created Block and populate it - block.to_mutable(value_mapping=value_mapping, - block_mapping=block_mapping) + block.to_mutable(value_mapping=value_mapping, block_mapping=block_mapping) return Region(mutable_blocks) @dataclass(frozen=True) class IBlock: """An immutable block contains a list of immutable operations. IBlocks are contained in IRegions.""" + args: IList[IBlockArg] ops: IList[IOperation] @@ -197,26 +214,27 @@ def arg_types(self) -> list[Attribute]: return frozen_arg_types def __hash__(self) -> int: - return (id(self)) + return id(self) def __eq__(self, __o: object) -> bool: return self is __o def __repr__(self) -> str: - return "block of" + str(len( - self.ops)) + " operations with args: " + str(self.args) + return ( + "block of" + str(len(self.ops)) + " operations with args: " + str(self.args) + ) def __post_init__(self): for arg in self.args: object.__setattr__(arg, "block", self) - def __init__(self, args: Sequence[Attribute] | Sequence[IBlockArg], - ops: Sequence[IOperation]): + def __init__( + self, args: Sequence[Attribute] | Sequence[IBlockArg], ops: Sequence[IOperation] + ): """Creates a new immutable block.""" # Type Guards: - def is_iblock_arg_seq( - list: Sequence[Any]) -> TypeGuard[Sequence[IBlockArg]]: + def is_iblock_arg_seq(list: Sequence[Any]) -> TypeGuard[Sequence[IBlockArg]]: return all([isinstance(elem, IBlockArg) for elem in list]) def is_type_seq(list: Sequence[Any]) -> TypeGuard[Sequence[Attribute]]: @@ -224,8 +242,7 @@ def is_type_seq(list: Sequence[Any]) -> TypeGuard[Sequence[Attribute]]: if is_type_seq(args): block_args: Sequence[IBlockArg] = [ - IBlockArg(type, IList([]), self, idx) - for idx, type in enumerate(args) + IBlockArg(type, IList([]), self, idx) for idx, type in enumerate(args) ] elif is_iblock_arg_seq(args): block_args: Sequence[IBlockArg] = args @@ -248,7 +265,7 @@ def from_mutable( block_map: dict[Block, IBlock] | None = None, ) -> IBlock: """ - Creates an immutable block from a mutable block. + Creates an immutable block from a mutable block. The value_map and block_map are used to map already known correspondings of mutable values to immutable values and mutable blocks to immutable blocks. """ @@ -262,25 +279,25 @@ def from_mutable( # The IBlock that will house this IBlockArg is not constructed yet. # After construction the block field will be set by the IBlock. immutable_arg = IBlockArg( - arg.typ, - IList([]), - None, # type: ignore - arg.index) + arg.typ, IList([]), None, arg.index # type: ignore + ) args.append(immutable_arg) value_map[arg] = immutable_arg immutable_ops = [ - IOperation.from_mutable(op, - value_map=value_map, - block_map=block_map, - existing_operands=None) for op in block.ops + IOperation.from_mutable( + op, value_map=value_map, block_map=block_map, existing_operands=None + ) + for op in block.ops ] return IBlock(args, immutable_ops) - def to_mutable(self, - value_mapping: dict[ISSAValue, SSAValue] | None = None, - block_mapping: dict[IBlock, Block] | None = None) -> Block: + def to_mutable( + self, + value_mapping: dict[ISSAValue, SSAValue] | None = None, + block_mapping: dict[IBlock, Block] | None = None, + ) -> Block: """ Returns a mutable block that is a copy of this immutable block. The value_mapping and block_mapping are used to map already known correspondings @@ -302,8 +319,10 @@ def to_mutable(self, for immutable_op in self.ops: mutable_block.add_op( - immutable_op.to_mutable(value_mapping=value_mapping, - block_mapping=block_mapping)) + immutable_op.to_mutable( + value_mapping=value_mapping, block_mapping=block_mapping + ) + ) return mutable_block @@ -315,8 +334,7 @@ def get_immutable_copy(op: Operation) -> IOperation: class IOperation: """Represents an immutable operation.""" - __match_args__ = ("op_type", "operands", "results", "successors", - "regions") + __match_args__ = ("op_type", "operands", "results", "successors", "regions") name: str op_type: type[Operation] attributes: immutabledict[str, Attribute] @@ -325,12 +343,16 @@ class IOperation: successors: IList[IBlock] regions: IList[IRegion] - def __init__(self, name: str, op_type: type[Operation], - attributes: immutabledict[str, Attribute], - operands: Sequence[ISSAValue], - result_types: Sequence[Attribute], - successors: Sequence[IBlock], - regions: Sequence[IRegion]) -> None: + def __init__( + self, + name: str, + op_type: type[Operation], + attributes: immutabledict[str, Attribute], + operands: Sequence[ISSAValue], + result_types: Sequence[Attribute], + successors: Sequence[IBlock], + regions: Sequence[IRegion], + ) -> None: object.__setattr__(self, "name", name) object.__setattr__(self, "op_type", op_type) object.__setattr__(self, "attributes", attributes) @@ -338,11 +360,15 @@ def __init__(self, name: str, op_type: type[Operation], for operand in operands: operand._add_user(self) # type: ignore object.__setattr__( - self, "results", - IList([ - IOpResult(type, IList([]), self, idx) - for idx, type in enumerate(result_types) - ])) + self, + "results", + IList( + [ + IOpResult(type, IList([]), self, idx) + for idx, type in enumerate(result_types) + ] + ), + ) object.__setattr__(self, "successors", IList(successors)) object.__setattr__(self, "regions", IList(regions)) @@ -352,13 +378,19 @@ def __init__(self, name: str, op_type: type[Operation], self.regions.freeze() @classmethod - def get(cls, name: str, op_type: type[Operation], - operands: Sequence[ISSAValue], result_types: Sequence[Attribute], - attributes: immutabledict[str, - Attribute], successors: Sequence[IBlock], - regions: Sequence[IRegion]) -> IOperation: - return cls(name, op_type, attributes, operands, result_types, - successors, regions) + def get( + cls, + name: str, + op_type: type[Operation], + operands: Sequence[ISSAValue], + result_types: Sequence[Attribute], + attributes: immutabledict[str, Attribute], + successors: Sequence[IBlock], + regions: Sequence[IRegion], + ) -> IOperation: + return cls( + name, op_type, attributes, operands, result_types, successors, regions + ) def __hash__(self) -> int: return hash(id(self)) @@ -375,7 +407,8 @@ def result(self) -> IOpResult: if len(self.results) != 1: raise ValueError( "'result' property of IOperation class is only available " - "for IOperations with exactly one result.") + "for IOperations with exactly one result." + ) return self.results[0] @property @@ -387,7 +420,8 @@ def region(self) -> IRegion: if len(self.regions) != 1: raise ValueError( "'region' property of IOperation class is only available " - "for IOperations with exactly one region.") + "for IOperations with exactly one region." + ) return self.regions[0] @property @@ -395,9 +429,10 @@ def result_types(self) -> list[Attribute]: return [result.typ for result in self.results] def to_mutable( - self, - value_mapping: dict[ISSAValue, SSAValue] | None = None, - block_mapping: dict[IBlock, Block] | None = None) -> Operation: + self, + value_mapping: dict[ISSAValue, SSAValue] | None = None, + block_mapping: dict[IBlock, Block] | None = None, + ) -> Operation: """ Returns a mutable operation that is a copy of this immutable operation. The value_mapping and block_mapping are used to map already known correspondings @@ -416,11 +451,7 @@ def to_mutable( print(f"ERROR: op {self.name} uses SSAValue before definition") # Continuing to enable printing the IR including missing # operands for investigation - mutable_operands.append( - OpResult( - operand.typ, - None, # type: ignore - 0)) + mutable_operands.append(OpResult(operand.typ, None, 0)) # type: ignore mutable_successors: list[Block] = [] for successor in self.successors: @@ -428,20 +459,24 @@ def to_mutable( mutable_successors.append(block_mapping[successor]) else: raise InvalidIRException( - "Invalid IR: Block is not defined in the current region") + "Invalid IR: Block is not defined in the current region" + ) mutable_regions: list[Region] = [] for region in self.regions: mutable_regions.append( - region.to_mutable(value_mapping=value_mapping, - block_mapping=block_mapping)) + region.to_mutable( + value_mapping=value_mapping, block_mapping=block_mapping + ) + ) new_op: Operation = self.op_type.create( operands=mutable_operands, result_types=[result.typ for result in self.results], attributes=dict(self.attributes), successors=mutable_successors, - regions=mutable_regions) + regions=mutable_regions, + ) # Add the results of this operation to the value mapping # so other operations can use them as operands. @@ -453,11 +488,11 @@ def to_mutable( @classmethod def from_mutable( - cls, - op: Operation, - value_map: dict[SSAValue, ISSAValue] | None = None, - block_map: dict[Block, IBlock] | None = None, - existing_operands: Sequence[ISSAValue] | None = None + cls, + op: Operation, + value_map: dict[SSAValue, ISSAValue] | None = None, + block_map: dict[Block, IBlock] | None = None, + existing_operands: Sequence[ISSAValue] | None = None, ) -> IOperation: """ Returns an immutable operation that is a copy of the given mutable operation. @@ -483,8 +518,8 @@ def from_mutable( elif isinstance(operand, BlockArgument): if operand not in value_map: raise Exception( - "Block argument expected in mapping for op: " + - op.name) + "Block argument expected in mapping for op: " + op.name + ) operands.append(value_map[operand]) else: raise Exception( @@ -493,8 +528,7 @@ def from_mutable( else: operands.extend(existing_operands) - attributes: immutabledict[str, - Attribute] = immutabledict(op.attributes) + attributes: immutabledict[str, Attribute] = immutabledict(op.attributes) successors: list[IBlock] = [] for successor in op.successors: @@ -503,16 +537,22 @@ def from_mutable( else: raise Exception( "Successor not defined in current region, `from_mutable`\ - probably has to be called on the parent operation.") + probably has to be called on the parent operation." + ) regions: list[IRegion] = [] for region in op.regions: - regions.append( - IRegion.from_mutable(region.blocks, value_map, block_map)) - - immutable_op = IOperation.get(op.name, op_type, operands, - [result.typ for result in op.results], - attributes, successors, regions) + regions.append(IRegion.from_mutable(region.blocks, value_map, block_map)) + + immutable_op = IOperation.get( + op.name, + op_type, + operands, + [result.typ for result in op.results], + attributes, + successors, + regions, + ) for idx, result in enumerate(op.results): value_map[result] = immutable_op.results[idx] diff --git a/xdsl/transforms/experimental/Apply1DMPIToStencil.py b/xdsl/transforms/experimental/Apply1DMPIToStencil.py index 47382a707e..6d2ba9a668 100644 --- a/xdsl/transforms/experimental/Apply1DMPIToStencil.py +++ b/xdsl/transforms/experimental/Apply1DMPIToStencil.py @@ -2,10 +2,13 @@ from xdsl.ir import Operation, MLContext, Block, TypeAttribute from xdsl.irdl import Operand from xdsl.utils.hints import isa -from xdsl.pattern_rewriter import (RewritePattern, PatternRewriter, - op_type_rewrite_pattern, - PatternRewriteWalker, - GreedyRewritePatternApplier) +from xdsl.pattern_rewriter import ( + RewritePattern, + PatternRewriter, + op_type_rewrite_pattern, + PatternRewriteWalker, + GreedyRewritePatternApplier, +) from xdsl.dialects import builtin, llvm, arith, mpi, memref, scf from xdsl.dialects.experimental import stencil @@ -13,13 +16,14 @@ class ApplyMPIToExternalLoad(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: stencil.ExternalLoadOp, - rewriter: PatternRewriter, /): + def match_and_rewrite( + self, op: stencil.ExternalLoadOp, rewriter: PatternRewriter, / + ): assert isa(op.field.typ, memref.MemRefType[AnyNumericType]) memref_type: memref.MemRefType[AnyNumericType] = op.field.typ - if len(memref_type.shape) <= 1: return + if len(memref_type.shape) <= 1: + return mpi_operations: List[Operation] = [] # Rank and size @@ -36,19 +40,16 @@ def match_and_rewrite(self, op: stencil.ExternalLoadOp, three = arith.Constant.from_int_and_width(3, 32) four = arith.Constant.from_int_and_width(4, 32) eight_i64 = arith.Constant.from_int_and_width(8, 64) - mpi_operations += [ - zero, one, one_i64, two, two_i64, three, four, eight_i64 - ] + mpi_operations += [zero, one, one_i64, two, two_i64, three, four, eight_i64] # The underlying datatype we use in communications and size in dimension zero element_type: TypeAttribute = memref_type.element_type datatype_op = mpi.GetDtypeOp.get(element_type) int_attr: builtin.IntegerAttr[builtin.IndexType] = builtin.IntegerAttr( - 0, builtin.IndexType()) - dim_zero_const = arith.Constant.from_attr(int_attr, - builtin.IndexType()) - dim_zero_size_op = memref.Dim.from_source_and_index( - op.field, dim_zero_const) + 0, builtin.IndexType() + ) + dim_zero_const = arith.Constant.from_attr(int_attr, builtin.IndexType()) + dim_zero_size_op = memref.Dim.from_source_and_index(op.field, dim_zero_const) dim_zero_i32_op = arith.IndexCastOp.get(dim_zero_size_op, builtin.i32) dim_zero_i64_op = arith.IndexCastOp.get(dim_zero_size_op, builtin.i64) @@ -56,8 +57,13 @@ def match_and_rewrite(self, op: stencil.ExternalLoadOp, index_memref_i64 = arith.IndexCastOp.get(index_memref, builtin.i64) mpi_operations += [ - datatype_op, dim_zero_const, dim_zero_size_op, dim_zero_i32_op, - dim_zero_i64_op, index_memref, index_memref_i64 + datatype_op, + dim_zero_const, + dim_zero_size_op, + dim_zero_i32_op, + dim_zero_i64_op, + index_memref, + index_memref_i64, ] # Four request handles, one for send, one for recv @@ -75,11 +81,13 @@ def match_and_rewrite(self, op: stencil.ExternalLoadOp, alloc_lookup_op_one = mpi.VectorGetOp.get(alloc_request_op, one) alloc_lookup_op_two = mpi.VectorGetOp.get(alloc_request_op, two) alloc_lookup_op_three = mpi.VectorGetOp.get(alloc_request_op, three) - mpi_request_null = arith.Constant.from_int_and_width( - 0x2c000000, builtin.i32) + mpi_request_null = arith.Constant.from_int_and_width(0x2C000000, builtin.i32) mpi_operations += [ - alloc_lookup_op_zero, alloc_lookup_op_one, alloc_lookup_op_two, - alloc_lookup_op_three, mpi_request_null + alloc_lookup_op_zero, + alloc_lookup_op_one, + alloc_lookup_op_two, + alloc_lookup_op_three, + mpi_request_null, ] # Send and recv my first row of data to rank -1 @@ -89,27 +97,50 @@ def match_and_rewrite(self, op: stencil.ExternalLoadOp, added_ptr = arith.Addi.get(index_memref_i64, add_offset) send_ptr = llvm.IntToPtrOp.get(added_ptr) - mpi_send_top_op = mpi.Isend.get(send_ptr, dim_zero_i32_op, datatype_op, - rank_m1_op, zero, alloc_lookup_op_zero) + mpi_send_top_op = mpi.Isend.get( + send_ptr, + dim_zero_i32_op, + datatype_op, + rank_m1_op, + zero, + alloc_lookup_op_zero, + ) recv_ptr = llvm.IntToPtrOp.get(index_memref_i64) - mpi_recv_top_op = mpi.Irecv.get(recv_ptr, dim_zero_i32_op, datatype_op, - rank_m1_op, zero, alloc_lookup_op_one) + mpi_recv_top_op = mpi.Irecv.get( + recv_ptr, + dim_zero_i32_op, + datatype_op, + rank_m1_op, + zero, + alloc_lookup_op_one, + ) # Else set empty request handles zero_conv = builtin.UnrealizedConversionCastOp.get( - [alloc_lookup_op_zero], [llvm.LLVMPointerType.typed(builtin.i32)]) + [alloc_lookup_op_zero], [llvm.LLVMPointerType.typed(builtin.i32)] + ) null_req_zero = llvm.StoreOp.get(mpi_request_null, zero_conv) one_conv = builtin.UnrealizedConversionCastOp.get( - [alloc_lookup_op_one], [llvm.LLVMPointerType.typed(builtin.i32)]) + [alloc_lookup_op_one], [llvm.LLVMPointerType.typed(builtin.i32)] + ) null_req_one = llvm.StoreOp.get(mpi_request_null, one_conv) - top_halo_exhange = scf.If.get(compare_top_op, [], [ - rank_m1_op, add_offset, added_ptr, send_ptr, mpi_send_top_op, - recv_ptr, mpi_recv_top_op, - scf.Yield.get() - ], [zero_conv, null_req_zero, one_conv, null_req_one, - scf.Yield.get()]) + top_halo_exhange = scf.If.get( + compare_top_op, + [], + [ + rank_m1_op, + add_offset, + added_ptr, + send_ptr, + mpi_send_top_op, + recv_ptr, + mpi_recv_top_op, + scf.Yield.get(), + ], + [zero_conv, null_req_zero, one_conv, null_req_one, scf.Yield.get()], + ) mpi_operations += [top_halo_exhange] # Send and recv my last row of data to rank +1 @@ -122,9 +153,14 @@ def match_and_rewrite(self, op: stencil.ExternalLoadOp, added_ptr_b_send = arith.Addi.get(index_memref_i64, add_offset_b_send) ptr_b_send = llvm.IntToPtrOp.get(added_ptr_b_send) - mpi_send_bottom_op = mpi.Isend.get(ptr_b_send, dim_zero_i32_op, - datatype_op, rank_p1_op, zero, - alloc_lookup_op_two) + mpi_send_bottom_op = mpi.Isend.get( + ptr_b_send, + dim_zero_i32_op, + datatype_op, + rank_p1_op, + zero, + alloc_lookup_op_two, + ) # Now do the recv @@ -134,28 +170,46 @@ def match_and_rewrite(self, op: stencil.ExternalLoadOp, added_ptr_b_recv = arith.Addi.get(index_memref_i64, add_offset_b_recv) ptr_b_recv = llvm.IntToPtrOp.get(added_ptr_b_recv) - mpi_recv_bottom_op = mpi.Irecv.get(ptr_b_recv, dim_zero_i32_op, - datatype_op, rank_p1_op, zero, - alloc_lookup_op_three) + mpi_recv_bottom_op = mpi.Irecv.get( + ptr_b_recv, + dim_zero_i32_op, + datatype_op, + rank_p1_op, + zero, + alloc_lookup_op_three, + ) # Else set empty request handles two_conv = builtin.UnrealizedConversionCastOp.get( - [alloc_lookup_op_two], [llvm.LLVMPointerType.typed(builtin.i32)]) + [alloc_lookup_op_two], [llvm.LLVMPointerType.typed(builtin.i32)] + ) null_req_two = llvm.StoreOp.get(mpi_request_null, two_conv) three_conv = builtin.UnrealizedConversionCastOp.get( - [alloc_lookup_op_three], [llvm.LLVMPointerType.typed(builtin.i32)]) + [alloc_lookup_op_three], [llvm.LLVMPointerType.typed(builtin.i32)] + ) null_req_three = llvm.StoreOp.get(mpi_request_null, three_conv) - bottom_halo_exhange = scf.If.get(compare_bottom_op, [], [ - rank_p1_op, col_row_b_send, element_b_send, add_offset_b_send, - added_ptr_b_send, ptr_b_send, mpi_send_bottom_op, col_row_b_recv, - element_b_recv, add_offset_b_recv, added_ptr_b_recv, ptr_b_recv, - mpi_recv_bottom_op, - scf.Yield.get() - ], [ - two_conv, null_req_two, three_conv, null_req_three, - scf.Yield.get() - ]) + bottom_halo_exhange = scf.If.get( + compare_bottom_op, + [], + [ + rank_p1_op, + col_row_b_send, + element_b_send, + add_offset_b_send, + added_ptr_b_send, + ptr_b_send, + mpi_send_bottom_op, + col_row_b_recv, + element_b_recv, + add_offset_b_recv, + added_ptr_b_recv, + ptr_b_recv, + mpi_recv_bottom_op, + scf.Yield.get(), + ], + [two_conv, null_req_two, three_conv, null_req_three, scf.Yield.get()], + ) mpi_operations += [bottom_halo_exhange] req_ops: Operand = alloc_request_op.results[0] @@ -172,8 +226,12 @@ def match_and_rewrite(self, op: stencil.ExternalLoadOp, def Apply1DMpi(ctx: MLContext, module: builtin.ModuleOp): applyMPI = ApplyMPIToExternalLoad() - walker1 = PatternRewriteWalker(GreedyRewritePatternApplier([ - applyMPI, - ]), - apply_recursively=True) + walker1 = PatternRewriteWalker( + GreedyRewritePatternApplier( + [ + applyMPI, + ] + ), + apply_recursively=True, + ) walker1.rewrite_module(module) diff --git a/xdsl/transforms/experimental/ConvertStencilToLLMLIR.py b/xdsl/transforms/experimental/ConvertStencilToLLMLIR.py index ffdf6de485..cecf836a6d 100644 --- a/xdsl/transforms/experimental/ConvertStencilToLLMLIR.py +++ b/xdsl/transforms/experimental/ConvertStencilToLLMLIR.py @@ -3,9 +3,13 @@ from warnings import warn -from xdsl.pattern_rewriter import (PatternRewriter, PatternRewriteWalker, - RewritePattern, GreedyRewritePatternApplier, - op_type_rewrite_pattern) +from xdsl.pattern_rewriter import ( + PatternRewriter, + PatternRewriteWalker, + RewritePattern, + GreedyRewritePatternApplier, + op_type_rewrite_pattern, +) from xdsl.ir import Block, MLContext, Operation, Region from xdsl.irdl import Attribute from xdsl.dialects.builtin import FunctionType @@ -13,18 +17,29 @@ from xdsl.dialects.memref import MemRefType from xdsl.dialects import memref, arith, scf, builtin, gpu -from xdsl.dialects.experimental.stencil import (AccessOp, ApplyOp, CastOp, - FieldType, IndexAttr, LoadOp, - ReturnOp, StoreOp, TempType, - ExternalLoadOp, - ExternalStoreOp) +from xdsl.dialects.experimental.stencil import ( + AccessOp, + ApplyOp, + CastOp, + FieldType, + IndexAttr, + LoadOp, + ReturnOp, + StoreOp, + TempType, + ExternalLoadOp, + ExternalStoreOp, +) from xdsl.passes import ModulePass from xdsl.utils.exceptions import VerifyException from xdsl.utils.hints import isa -from xdsl.transforms.experimental.stencil_global_to_local import LowerHaloExchangeToMpi, HorizontalSlices2D, \ - MpiLoopInvariantCodeMotion +from xdsl.transforms.experimental.stencil_global_to_local import ( + LowerHaloExchangeToMpi, + HorizontalSlices2D, + MpiLoopInvariantCodeMotion, +) _TypeElement = TypeVar("_TypeElement", bound=Attribute) @@ -32,17 +47,16 @@ def GetMemRefFromField( - input_type: FieldType[_TypeElement] | TempType[_TypeElement] + input_type: FieldType[_TypeElement] | TempType[_TypeElement], ) -> MemRefType[_TypeElement]: dims = [i.value.data for i in input_type.shape.data] - return MemRefType.from_element_type_and_shape(input_type.element_type, - dims) + return MemRefType.from_element_type_and_shape(input_type.element_type, dims) -def GetMemRefFromFieldWithLBAndUB(memref_element_type: _TypeElement, - lb: IndexAttr, - ub: IndexAttr) -> MemRefType[_TypeElement]: +def GetMemRefFromFieldWithLBAndUB( + memref_element_type: _TypeElement, lb: IndexAttr, ub: IndexAttr +) -> MemRefType[_TypeElement]: # lb and ub defines the minimum and maximum coordinates of the resulting memref, # so its shape is simply ub - lb, computed here. dims = IndexAttr.size_from_bounds(lb, ub) @@ -52,34 +66,31 @@ def GetMemRefFromFieldWithLBAndUB(memref_element_type: _TypeElement, @dataclass class CastOpToMemref(RewritePattern): - gpu: bool = False @op_type_rewrite_pattern def match_and_rewrite(self, op: CastOp, rewriter: PatternRewriter, /): + assert isa(op.field.typ, FieldType[Attribute] | memref.MemRefType[Attribute]) - assert isa(op.field.typ, - FieldType[Attribute] | memref.MemRefType[Attribute]) - - result_typ = GetMemRefFromFieldWithLBAndUB(op.field.typ.element_type, - op.lb, op.ub) + result_typ = GetMemRefFromFieldWithLBAndUB( + op.field.typ.element_type, op.lb, op.ub + ) cast = memref.Cast.get(op.field, result_typ) if self.gpu: unranked = memref.Cast.get( cast.dest, - memref.UnrankedMemrefType.from_type(op.field.typ.element_type)) + memref.UnrankedMemrefType.from_type(op.field.typ.element_type), + ) register = gpu.HostRegisterOp.from_memref(unranked.dest) rewriter.insert_op_after_matched_op([unranked, register]) rewriter.replace_matched_op(cast) class StoreOpCleanup(RewritePattern): - @op_type_rewrite_pattern def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter, /): - rewriter.erase_matched_op() pass @@ -103,7 +114,6 @@ def collectBlockArguments(number: int, block: Block): @dataclass class ReturnOpToMemref(RewritePattern): - return_target: dict[ReturnOp, list[CastOp | memref.Subview | None]] @op_type_rewrite_pattern @@ -128,16 +138,15 @@ def match_and_rewrite(self, op: ReturnOp, rewriter: PatternRewriter, /): args = collectBlockArguments(dims, block) - store_list.append(memref.Store.get(op.arg[j], subview.result, - args)) + store_list.append(memref.Store.get(op.arg[j], subview.result, args)) rewriter.replace_matched_op([*store_list]) def verify_load_bounds(cast: CastOp, load: LoadOp): - - if ([i.value.data for i in IndexAttr.min(cast.lb, load.lb).array.data] - != [i.value.data for i in cast.lb.array.data]): # noqa + if [i.value.data for i in IndexAttr.min(cast.lb, load.lb).array.data] != [ + i.value.data for i in cast.lb.array.data + ]: # noqa raise VerifyException( "The stencil computation requires a field with lower bound at least " f"{load.lb}, got {cast.lb}, min: {IndexAttr.min(cast.lb, load.lb)}" @@ -145,7 +154,6 @@ def verify_load_bounds(cast: CastOp, load: LoadOp): class LoadOpToMemref(RewritePattern): - @op_type_rewrite_pattern def match_and_rewrite(self, op: LoadOp, rewriter: PatternRewriter, /): cast = op.field.owner @@ -164,13 +172,13 @@ def match_and_rewrite(self, op: LoadOp, rewriter: PatternRewriter, /): strides = [1] * len(sizes) subview = memref.Subview.from_static_parameters( - cast.result, element_type, shape, offsets, sizes, strides) + cast.result, element_type, shape, offsets, sizes, strides + ) rewriter.replace_matched_op(subview) def prepare_apply_body(op: ApplyOp, rewriter: PatternRewriter): - assert (op.lb is not None) and (op.ub is not None) # First replace all current arguments by their definition @@ -190,10 +198,8 @@ def prepare_apply_body(op: ApplyOp, rewriter: PatternRewriter): class ApplyOpToParallel(RewritePattern): - @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): - assert (op.lb is not None) and (op.ub is not None) body = prepare_apply_body(op, rewriter) @@ -205,8 +211,7 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): zero = arith.Constant.from_int_and_width(0, builtin.IndexType()) one = arith.Constant.from_int_and_width(1, builtin.IndexType()) upperBounds = [ - arith.Constant.from_int_and_width(x, builtin.IndexType()) - for x in dims + arith.Constant.from_int_and_width(x, builtin.IndexType()) for x in dims ] # Generate an outer parallel loop as well as two inner sequential @@ -214,19 +219,20 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): # kernel itself is not slowed down by the OpenMP runtime. current_region = body for i in range(1, dim): - for_op = scf.For.get(lb=zero, - ub=upperBounds[-i], - step=one, - iter_args=[], - body=current_region) - block = Block(ops=[for_op, scf.Yield.get()], - arg_types=[builtin.IndexType()]) + for_op = scf.For.get( + lb=zero, ub=upperBounds[-i], step=one, iter_args=[], body=current_region + ) + block = Block( + ops=[for_op, scf.Yield.get()], arg_types=[builtin.IndexType()] + ) current_region = Region(block) - p = scf.ParallelOp.get(lowerBounds=[zero], - upperBounds=[upperBounds[0]], - steps=[one], - body=current_region) + p = scf.ParallelOp.get( + lowerBounds=[zero], + upperBounds=[upperBounds[0]], + steps=[one], + body=current_region, + ) # Replace with the loop and necessary constants. rewriter.insert_op_before_matched_op([zero, one, *upperBounds, p]) @@ -234,10 +240,8 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): class AccessOpToMemref(RewritePattern): - @op_type_rewrite_pattern def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter, /): - load = op.temp.owner assert isinstance(load, LoadOp) assert load.lb is not None @@ -248,26 +252,21 @@ def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter, /): memref_offset = (op.offset - load.lb).array.data off_const_ops = [ - arith.Constant.from_int_and_width(x.value.data, - builtin.IndexType()) + arith.Constant.from_int_and_width(x.value.data, builtin.IndexType()) for x in memref_offset ] args = collectBlockArguments(len(memref_offset), block) - off_sum_ops = [ - arith.Addi.get(i, x) for i, x in zip(args, off_const_ops) - ] + off_sum_ops = [arith.Addi.get(i, x) for i, x in zip(args, off_const_ops)] load = memref.Load.get(load, off_sum_ops) - rewriter.replace_matched_op([*off_const_ops, *off_sum_ops, load], - [load.res]) + rewriter.replace_matched_op([*off_const_ops, *off_sum_ops, load], [load.res]) @dataclass class StencilTypeConversionFuncOp(RewritePattern): - return_targets: dict[ReturnOp, list[CastOp | memref.Subview | None]] @op_type_rewrite_pattern @@ -282,7 +281,8 @@ def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /): inputs.append(arg.typ) op.attributes["function_type"] = FunctionType.from_lists( - inputs, list(op.function_type.outputs.data)) + inputs, list(op.function_type.outputs.data) + ) stores: list[StoreOp] = [] op.walk(lambda o: stores.append(o) if isinstance(o, StoreOp) else None) @@ -296,8 +296,13 @@ def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /): offsets = [i.value.data for i in (store.lb - cast.lb).array.data] sizes = [i.value.data for i in (store.ub - store.lb).array.data] subview = memref.Subview.from_static_parameters( - new_cast.result, cast.result.typ.element_type, source_shape, - offsets, sizes, [1] * len(sizes)) + new_cast.result, + cast.result.typ.element_type, + source_shape, + offsets, + sizes, + [1] * len(sizes), + ) rewriter.replace_op(cast, [new_cast, subview]) for r, c in self.return_targets.items(): @@ -307,10 +312,8 @@ def match_and_rewrite(self, op: FuncOp, rewriter: PatternRewriter, /): class TrivialExternalLoadOpCleanup(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: ExternalLoadOp, rewriter: PatternRewriter, - /): + def match_and_rewrite(self, op: ExternalLoadOp, rewriter: PatternRewriter, /): assert isa(op.result.typ, FieldType[Attribute]) op.result.typ = GetMemRefFromField(op.result.typ) @@ -320,10 +323,8 @@ def match_and_rewrite(self, op: ExternalLoadOp, rewriter: PatternRewriter, class TrivialExternalStoreOpCleanup(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: ExternalStoreOp, rewriter: PatternRewriter, - /): + def match_and_rewrite(self, op: ExternalStoreOp, rewriter: PatternRewriter, /): rewriter.erase_matched_op() @@ -340,7 +341,8 @@ def map_returns(op: Operation) -> None: return_targets[op] = [] for res in list(apply.res): store = [ - use.operation for use in list(res.uses) + use.operation + for use in list(res.uses) if isinstance(use.operation, StoreOp) ] @@ -359,57 +361,60 @@ def map_returns(op: Operation) -> None: return return_targets -def StencilConversion(return_targets: dict[ReturnOp, - list[CastOp | memref.Subview - | None]], gpu: bool): +def StencilConversion( + return_targets: dict[ReturnOp, list[CastOp | memref.Subview | None]], gpu: bool +): """ List of rewrite passes for stencil """ - return GreedyRewritePatternApplier([ - ApplyOpToParallel(), - StencilTypeConversionFuncOp(return_targets), - CastOpToMemref(gpu), - LoadOpToMemref(), - AccessOpToMemref(), - ReturnOpToMemref(return_targets), - StoreOpCleanup(), - TrivialExternalLoadOpCleanup(), - TrivialExternalStoreOpCleanup() - ]) + return GreedyRewritePatternApplier( + [ + ApplyOpToParallel(), + StencilTypeConversionFuncOp(return_targets), + CastOpToMemref(gpu), + LoadOpToMemref(), + AccessOpToMemref(), + ReturnOpToMemref(return_targets), + StoreOpCleanup(), + TrivialExternalLoadOpCleanup(), + TrivialExternalStoreOpCleanup(), + ] + ) class ConvertStencilToGPUPass(ModulePass): - - name = 'convert-stencil-to-gpu' + name = "convert-stencil-to-gpu" def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: return_targets = return_target_analysis(op) - the_one_pass = PatternRewriteWalker(GreedyRewritePatternApplier( - [StencilConversion(return_targets, gpu=True)]), - apply_recursively=False, - walk_reverse=True) + the_one_pass = PatternRewriteWalker( + GreedyRewritePatternApplier([StencilConversion(return_targets, gpu=True)]), + apply_recursively=False, + walk_reverse=True, + ) the_one_pass.rewrite_module(op) class ConvertStencilToLLMLIRPass(ModulePass): - - name = 'convert-stencil-to-ll-mlir' + name = "convert-stencil-to-ll-mlir" def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: - - return_targets: dict[ReturnOp, - list[CastOp - | memref.Subview - | None]] = return_target_analysis(op) - - the_one_pass = PatternRewriteWalker(GreedyRewritePatternApplier( - [StencilConversion(return_targets, gpu=False)]), - apply_recursively=False, - walk_reverse=True) + return_targets: dict[ + ReturnOp, list[CastOp | memref.Subview | None] + ] = return_target_analysis(op) + + the_one_pass = PatternRewriteWalker( + GreedyRewritePatternApplier([StencilConversion(return_targets, gpu=False)]), + apply_recursively=False, + walk_reverse=True, + ) the_one_pass.rewrite_module(op) PatternRewriteWalker( - GreedyRewritePatternApplier([ - LowerHaloExchangeToMpi(HorizontalSlices2D(2)), - MpiLoopInvariantCodeMotion(), - ])).rewrite_module(op) + GreedyRewritePatternApplier( + [ + LowerHaloExchangeToMpi(HorizontalSlices2D(2)), + MpiLoopInvariantCodeMotion(), + ] + ) + ).rewrite_module(op) diff --git a/xdsl/transforms/experimental/StencilShapeInference.py b/xdsl/transforms/experimental/StencilShapeInference.py index c37b6cdc84..64c83340b2 100644 --- a/xdsl/transforms/experimental/StencilShapeInference.py +++ b/xdsl/transforms/experimental/StencilShapeInference.py @@ -1,17 +1,31 @@ from typing import Iterable, TypeVar from xdsl.dialects import builtin -from xdsl.dialects.experimental.stencil import AccessOp, ApplyOp, CastOp, HaloSwapOp, IndexAttr, LoadOp, StoreOp, TempType +from xdsl.dialects.experimental.stencil import ( + AccessOp, + ApplyOp, + CastOp, + HaloSwapOp, + IndexAttr, + LoadOp, + StoreOp, + TempType, +) from xdsl.ir import Attribute, BlockArgument, MLContext, Operation, SSAValue from xdsl.passes import ModulePass -from xdsl.pattern_rewriter import GreedyRewritePatternApplier, PatternRewriteWalker, PatternRewriter, RewritePattern, op_type_rewrite_pattern +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriteWalker, + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, +) from xdsl.transforms.experimental.ConvertStencilToLLMLIR import verify_load_bounds from xdsl.utils.hints import isa -_OpT = TypeVar('_OpT', bound=Operation) +_OpT = TypeVar("_OpT", bound=Operation) -def all_matching_uses(op_res: Iterable[SSAValue], - typ: type[_OpT]) -> Iterable[_OpT]: +def all_matching_uses(op_res: Iterable[SSAValue], typ: type[_OpT]) -> Iterable[_OpT]: for res in op_res: for use in res.uses: if isinstance(use.operation, typ): @@ -39,7 +53,6 @@ def infer_core_size(op: LoadOp) -> tuple[IndexAttr, IndexAttr]: class LoadOpShapeInference(RewritePattern): - @op_type_rewrite_pattern def match_and_rewrite(self, op: LoadOp, rewriter: PatternRewriter, /): cast = op.field.owner @@ -59,23 +72,19 @@ def match_and_rewrite(self, op: LoadOp, rewriter: PatternRewriter, /): class StoreOpShapeInference(RewritePattern): - @op_type_rewrite_pattern def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter, /): - owner = op.temp.owner assert isinstance(owner, ApplyOp | LoadOp) - owner.attributes['lb'] = IndexAttr.min(op.lb, owner.lb) - owner.attributes['ub'] = IndexAttr.max(op.ub, owner.ub) + owner.attributes["lb"] = IndexAttr.min(op.lb, owner.lb) + owner.attributes["ub"] = IndexAttr.max(op.ub, owner.ub) class ApplyOpShapeInference(RewritePattern): - @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): - def access_shape_infer_walk(access: Operation) -> None: assert (op.lb is not None) and (op.ub is not None) if not isinstance(access, AccessOp): @@ -85,10 +94,12 @@ def access_shape_infer_walk(access: Operation) -> None: assert isinstance(temp_owner, LoadOp | ApplyOp) - temp_owner.attributes['lb'] = IndexAttr.min( - op.lb + access.offset, temp_owner.lb) - temp_owner.attributes['ub'] = IndexAttr.max( - op.ub + access.offset, temp_owner.ub) + temp_owner.attributes["lb"] = IndexAttr.min( + op.lb + access.offset, temp_owner.lb + ) + temp_owner.attributes["ub"] = IndexAttr.max( + op.ub + access.offset, temp_owner.ub + ) op.walk(access_shape_infer_walk) @@ -97,40 +108,39 @@ def access_shape_infer_walk(access: Operation) -> None: for result in op.results: assert isa(result.typ, TempType[Attribute]) result.typ = TempType.from_shape( - IndexAttr.size_from_bounds(op.lb, op.ub), - result.typ.element_type) + IndexAttr.size_from_bounds(op.lb, op.ub), result.typ.element_type + ) class HaloOpShapeInference(RewritePattern): - @op_type_rewrite_pattern def match_and_rewrite(self, op: HaloSwapOp, rewriter: PatternRewriter, /): assert isinstance(op.input_stencil.owner, LoadOp) load = op.input_stencil.owner halo_lb, halo_ub = infer_core_size(load) - op.attributes['core_lb'] = halo_lb - op.attributes['core_ub'] = halo_ub + op.attributes["core_lb"] = halo_lb + op.attributes["core_ub"] = halo_ub assert load.lb is not None assert load.ub is not None - op.attributes['buff_lb'] = load.lb - op.attributes['buff_ub'] = load.ub + op.attributes["buff_lb"] = load.lb + op.attributes["buff_ub"] = load.ub -ShapeInference = GreedyRewritePatternApplier([ - ApplyOpShapeInference(), - LoadOpShapeInference(), - StoreOpShapeInference(), - HaloOpShapeInference(), -]) +ShapeInference = GreedyRewritePatternApplier( + [ + ApplyOpShapeInference(), + LoadOpShapeInference(), + StoreOpShapeInference(), + HaloOpShapeInference(), + ] +) class StencilShapeInferencePass(ModulePass): - - name = 'stencil-shape-inference' + name = "stencil-shape-inference" def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: - - inference_walker = PatternRewriteWalker(ShapeInference, - apply_recursively=False, - walk_reverse=True) + inference_walker = PatternRewriteWalker( + ShapeInference, apply_recursively=False, walk_reverse=True + ) inference_walker.rewrite_module(op) diff --git a/xdsl/transforms/experimental/stencil_global_to_local.py b/xdsl/transforms/experimental/stencil_global_to_local.py index 7419331627..2caf673a4e 100644 --- a/xdsl/transforms/experimental/stencil_global_to_local.py +++ b/xdsl/transforms/experimental/stencil_global_to_local.py @@ -5,15 +5,19 @@ from xdsl.passes import ModulePass from xdsl.utils.hints import isa -from xdsl.pattern_rewriter import (PatternRewriter, PatternRewriteWalker, - RewritePattern, GreedyRewritePatternApplier, - op_type_rewrite_pattern) +from xdsl.pattern_rewriter import ( + PatternRewriter, + PatternRewriteWalker, + RewritePattern, + GreedyRewritePatternApplier, + op_type_rewrite_pattern, +) from xdsl.ir import MLContext, Operation, SSAValue, Block, Region, OpResult from xdsl.irdl import Attribute from xdsl.dialects import builtin, mpi, memref, arith, scf, func from xdsl.dialects.experimental import stencil -_T = TypeVar('_T', bound=Attribute) +_T = TypeVar("_T", bound=Attribute) @dataclass @@ -48,6 +52,7 @@ class HaloExchangeDef: This data will be exchanged with the node of rank (my_rank -1) """ + offset: tuple[int, ...] size: tuple[int, ...] source_offset: tuple[int, ...] @@ -61,7 +66,7 @@ def elem_count(self) -> int: def dim(self) -> int: return len(self.offset) - def source_area(self) -> 'HaloExchangeDef': + def source_area(self) -> "HaloExchangeDef": """ Since a HaloExchangeDef by default specifies the area to receive into, this method returns the area that should be read from. @@ -69,8 +74,8 @@ def source_area(self) -> 'HaloExchangeDef': # we set source_offset to all zeor, so that repeated calls to source_area never return the dest area return HaloExchangeDef( offset=tuple( - val + offs - for val, offs in zip(self.offset, self.source_offset)), + val + offs for val, offs in zip(self.offset, self.source_offset) + ), size=self.size, source_offset=tuple(0 for _ in range(len(self.source_offset))), neighbor=self.neighbor, @@ -129,10 +134,18 @@ class DimsHelper: DIM_Z: ClassVar[int] = 2 def __init__(self, op: stencil.HaloSwapOp): - assert op.buff_lb is not None, "HaloSwapOp must be lowered after shape inference!" - assert op.buff_ub is not None, "HaloSwapOp must be lowered after shape inference!" - assert op.core_lb is not None, "HaloSwapOp must be lowered after shape inference!" - assert op.core_ub is not None, "HaloSwapOp must be lowered after shape inference!" + assert ( + op.buff_lb is not None + ), "HaloSwapOp must be lowered after shape inference!" + assert ( + op.buff_ub is not None + ), "HaloSwapOp must be lowered after shape inference!" + assert ( + op.core_lb is not None + ), "HaloSwapOp must be lowered after shape inference!" + assert ( + op.core_ub is not None + ), "HaloSwapOp must be lowered after shape inference!" # translate everything to "memref" coordinates buff_lb = (op.buff_lb - op.buff_lb).as_tuple() @@ -140,8 +153,9 @@ def __init__(self, op: stencil.HaloSwapOp): core_lb = (op.core_lb - op.buff_lb).as_tuple() core_ub = (op.core_ub - op.buff_lb).as_tuple() - assert len(buff_lb) == len(buff_ub) == len(core_lb) == len(core_ub), \ - "Expected all args to be of the same length!" + assert ( + len(buff_lb) == len(buff_ub) == len(core_lb) == len(core_ub) + ), "Expected all args to be of the same length!" self.dims = len(buff_lb) self.buff_lb = buff_lb @@ -186,17 +200,13 @@ def halo_size(self, dim: int, at_end: bool = False): @dataclass class DomainDecompositionStrategy(ABC): - @abstractmethod def calc_resize(self, shape: tuple[int]) -> tuple[int]: - raise NotImplementedError( - "SlicingStrategy must implement calc_resize!") + raise NotImplementedError("SlicingStrategy must implement calc_resize!") @abstractmethod - def halo_exchange_defs(self, - dims: DimsHelper) -> Iterable[HaloExchangeDef]: - raise NotImplementedError( - "SlicingStrategy must implement halo_exchange_defs!") + def halo_exchange_defs(self, dims: DimsHelper) -> Iterable[HaloExchangeDef]: + raise NotImplementedError("SlicingStrategy must implement halo_exchange_defs!") @abstractmethod def comm_count(self) -> int: @@ -215,15 +225,14 @@ def comm_count(self) -> int: def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]: # slice on the y-axis - assert len(shape) == 2, \ - "HorizontalSlices2D only works on 2d fields!" - assert shape[1] % self.slices == 0, \ - "HorizontalSlices2D expects second dim to be divisible by number of slices!" + assert len(shape) == 2, "HorizontalSlices2D only works on 2d fields!" + assert ( + shape[1] % self.slices == 0 + ), "HorizontalSlices2D expects second dim to be divisible by number of slices!" return shape[0], shape[1] // self.slices - def halo_exchange_defs(self, - dims: DimsHelper) -> Iterable[HaloExchangeDef]: + def halo_exchange_defs(self, dims: DimsHelper) -> Iterable[HaloExchangeDef]: # upper halo exchange: yield HaloExchangeDef( offset=( @@ -263,12 +272,13 @@ class ChangeStoreOpSizes(RewritePattern): strategy: DomainDecompositionStrategy @op_type_rewrite_pattern - def match_and_rewrite(self, op: stencil.StoreOp, rewriter: PatternRewriter, - /): - assert all(integer_attr.value.data == 0 - for integer_attr in op.lb.array.data), "lb must be 0" + def match_and_rewrite(self, op: stencil.StoreOp, rewriter: PatternRewriter, /): + assert all( + integer_attr.value.data == 0 for integer_attr in op.lb.array.data + ), "lb must be 0" shape: tuple[int, ...] = tuple( - (integer_attr.value.data for integer_attr in op.ub.array.data)) + (integer_attr.value.data for integer_attr in op.ub.array.data) + ) new_shape = self.strategy.calc_resize(shape) op.ub = stencil.IndexAttr.get(*new_shape) @@ -278,25 +288,24 @@ class AddHaloExchangeOps(RewritePattern): """ This rewrite adds a `stencil.halo_exchange` before each `stencil.load` op """ + strategy: DomainDecompositionStrategy @op_type_rewrite_pattern - def match_and_rewrite(self, op: stencil.LoadOp, rewriter: PatternRewriter, - /): + def match_and_rewrite(self, op: stencil.LoadOp, rewriter: PatternRewriter, /): swap_op = stencil.HaloSwapOp.get(op.res) rewriter.insert_op_after_matched_op(swap_op) class GlobalStencilToLocalStencil2DHorizontal(ModulePass): - - name = 'stencil-to-local-2d-horizontal' + name = "stencil-to-local-2d-horizontal" def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: strategy = HorizontalSlices2D(2) gpra = GreedyRewritePatternApplier( - [ChangeStoreOpSizes(strategy), - AddHaloExchangeOps(strategy)]) + [ChangeStoreOpSizes(strategy), AddHaloExchangeOps(strategy)] + ) PatternRewriteWalker(gpra, apply_recursively=False).rewrite_module(op) @@ -306,8 +315,7 @@ class LowerHaloExchangeToMpi(RewritePattern): strategy: DomainDecompositionStrategy @op_type_rewrite_pattern - def match_and_rewrite(self, op: stencil.HaloSwapOp, - rewriter: PatternRewriter, /): + def match_and_rewrite(self, op: stencil.HaloSwapOp, rewriter: PatternRewriter, /): exchanges = list(self.strategy.halo_exchange_defs(DimsHelper(op))) assert isa(op.input_stencil.typ, memref.MemRefType[Attribute]) rewriter.replace_matched_op( @@ -317,21 +325,24 @@ def match_and_rewrite(self, op: stencil.HaloSwapOp, exchanges, op.input_stencil.typ.element_type, self.strategy, - )), + ) + ), [], ) def generate_mpi_calls_for( - source: SSAValue, exchanges: list[HaloExchangeDef], dtype: Attribute, - strat: DomainDecompositionStrategy) -> Iterable[Operation]: + source: SSAValue, + exchanges: list[HaloExchangeDef], + dtype: Attribute, + strat: DomainDecompositionStrategy, +) -> Iterable[Operation]: # call mpi init (this will be hoisted to function level) init = mpi.Init() # allocate request array # we need two request objects per exchange # one for the send, one for the recv - req_cnt = arith.Constant.from_int_and_width( - len(exchanges) * 2, builtin.i32) + req_cnt = arith.Constant.from_int_and_width(len(exchanges) * 2, builtin.i32) reqs = mpi.AllocateTypeOp.get(mpi.RequestType, req_cnt) # get comm rank rank = mpi.CommRank.get() @@ -344,8 +355,7 @@ def generate_mpi_calls_for( recv_buffers: list[tuple[HaloExchangeDef, memref.Alloc, SSAValue]] = [] for i, ex in enumerate(exchanges): - neighbor_offset = arith.Constant.from_int_and_width( - ex.neighbor, builtin.i32) + neighbor_offset = arith.Constant.from_int_and_width(ex.neighbor, builtin.i32) neighbor_rank = arith.Addi.get(rank, neighbor_offset) yield from (neighbor_offset, neighbor_rank) @@ -356,8 +366,9 @@ def generate_mpi_calls_for( # boundary condition: bound = arith.Constant.from_int_and_width( - 0 if ex.neighbor < 0 else strat.comm_count(), builtin.i32) - comparison = 'slt' if ex.neighbor < 0 else 'sge' + 0 if ex.neighbor < 0 else strat.comm_count(), builtin.i32 + ) + comparison = "slt" if ex.neighbor < 0 else "sge" cond_val = arith.Cmpi.get(neighbor_rank, bound, comparison) yield from (bound, cond_val) @@ -366,8 +377,7 @@ def generate_mpi_calls_for( # get two unique indices cst_i = arith.Constant.from_int_and_width(i, builtin.i32) - cst_in = arith.Constant.from_int_and_width(i + len(exchanges), - builtin.i32) + cst_in = arith.Constant.from_int_and_width(i + len(exchanges), builtin.i32) yield from (cst_i, cst_in) # from these indices, get request objects req_send = mpi.VectorGetOp.get(reqs, cst_i) @@ -376,23 +386,34 @@ def generate_mpi_calls_for( def then(): # copy source area to outbound buffer - yield from generate_memcpy(source, ex.source_area(), - alloc_outbound.memref) + yield from generate_memcpy(source, ex.source_area(), alloc_outbound.memref) # get ptr, count, dtype unwrap_out = mpi.UnwrapMemrefOp.get(alloc_outbound) yield unwrap_out # isend call - yield mpi.Isend.get(unwrap_out.ptr, unwrap_out.len, unwrap_out.typ, - neighbor_rank, tag, req_send) + yield mpi.Isend.get( + unwrap_out.ptr, + unwrap_out.len, + unwrap_out.typ, + neighbor_rank, + tag, + req_send, + ) # get ptr for receive buffer unwrap_in = mpi.UnwrapMemrefOp.get(alloc_inbound) yield unwrap_in # Irecv call - yield mpi.Irecv.get(unwrap_in.ptr, unwrap_in.len, unwrap_in.typ, - neighbor_rank, tag, req_recv) + yield mpi.Irecv.get( + unwrap_in.ptr, + unwrap_in.len, + unwrap_in.typ, + neighbor_rank, + tag, + req_recv, + ) yield scf.Yield.get() def else_(): @@ -417,24 +438,28 @@ def else_(): yield scf.If.get( cond_val, [], - Region([ - Block( - list( - generate_memcpy( - source, - ex.source_area(), - buffer.memref, - reverse=True, - )) + [scf.Yield.get()]) - ]), + Region( + [ + Block( + list( + generate_memcpy( + source, + ex.source_area(), + buffer.memref, + reverse=True, + ) + ) + + [scf.Yield.get()] + ) + ] + ), Region([Block([scf.Yield.get()])]), ) -def generate_memcpy(source: SSAValue, - ex: HaloExchangeDef, - dest: SSAValue, - reverse: bool = False) -> list[Operation]: +def generate_memcpy( + source: SSAValue, ex: HaloExchangeDef, dest: SSAValue, reverse: bool = False +) -> list[Operation]: """ This function generates a memcpy routine to copy over the parts specified by the `ex` from `source` into `dest`. @@ -459,13 +484,13 @@ def generate_memcpy(source: SSAValue, unroll_inner = False # enable to get verbose information on what buffers are exchanged: - #print("Generating{} memcpy from buff[{}:{},{}:{}]{}temp[{}:{}]".format( + # print("Generating{} memcpy from buff[{}:{},{}:{}]{}temp[{}:{}]".format( # " unrolled" if unrolled else "", # ex.offset[0], ex.offset[0] + ex.size[0], # ex.offset[1], ex.offset[1] + ex.size[1], # '<-' if reverse else '->', # 0, ex.elem_count - #)) + # )) # only generate indices if we actually want to unroll if unroll_inner: @@ -521,18 +546,23 @@ def inner(j: SSAValue): x_len, cst1, [], - [Block.from_callable([builtin.IndexType()], inner)] # type: ignore + [Block.from_callable([builtin.IndexType()], inner)], # type: ignore ) yield scf.Yield.get() - loop_body: Callable[[SSAValue], Iterable[ - Operation]] = loop_body_unrolled if unroll_inner else loop_body_with_for + loop_body: Callable[[SSAValue], Iterable[Operation]] = ( + loop_body_unrolled if unroll_inner else loop_body_with_for + ) # TODO: make type annotations here aware that they can work with generators! - loop = scf.For.get(cst0, y_len, cst1, [], - Block.from_callable([builtin.IndexType()], - loop_body)) # type: ignore + loop = scf.For.get( + cst0, + y_len, + cst1, + [], + Block.from_callable([builtin.IndexType()], loop_body), # type: ignore + ) return [ x0, @@ -552,6 +582,7 @@ class MpiLoopInvariantCodeMotion(RewritePattern): and mpi.unwrap_memref ops and moves them "up" until it hits a func.func, and then places them *before* the op they appear in. """ + seen_ops: set[Operation] has_init: set[func.FuncOp] @@ -560,16 +591,24 @@ def __init__(self): self.has_init = set() @op_type_rewrite_pattern - def match_and_rewrite(self, op: memref.Alloc | mpi.CommRank - | mpi.AllocateTypeOp | mpi.UnwrapMemrefOp | mpi.Init, - rewriter: PatternRewriter, /): + def match_and_rewrite( + self, + op: memref.Alloc + | mpi.CommRank + | mpi.AllocateTypeOp + | mpi.UnwrapMemrefOp + | mpi.Init, + rewriter: PatternRewriter, + /, + ): if op in self.seen_ops: return self.seen_ops.add(op) # memref unwraps can always be moved to their allocation if isinstance(op, mpi.UnwrapMemrefOp) and isinstance( - op.ref.owner, memref.Alloc): + op.ref.owner, memref.Alloc + ): op.detach() rewriter.insert_op_after(op, op.ref.owner) return @@ -602,8 +641,7 @@ def match_and_rewrite(self, op: memref.Alloc | mpi.CommRank # add a finalize() call to the end of the function parent.regions[0].blocks[-1].insert_op( mpi.Finalize(), - len(parent.regions[0].blocks[-1].ops) - - 1, # insert before return + len(parent.regions[0].blocks[-1].ops) - 1, # insert before return ) ops = list(collect_args_recursive(op)) diff --git a/xdsl/transforms/lower_mpi.py b/xdsl/transforms/lower_mpi.py index d3834e932f..440bd46372 100644 --- a/xdsl/transforms/lower_mpi.py +++ b/xdsl/transforms/lower_mpi.py @@ -8,10 +8,13 @@ from xdsl.dialects.memref import MemRefType from xdsl.ir import Operation, SSAValue, OpResult, Attribute, MLContext -from xdsl.pattern_rewriter import (RewritePattern, PatternRewriter, - op_type_rewrite_pattern, - PatternRewriteWalker, - GreedyRewritePatternApplier) +from xdsl.pattern_rewriter import ( + RewritePattern, + PatternRewriter, + op_type_rewrite_pattern, + PatternRewriteWalker, + GreedyRewritePatternApplier, +) from xdsl.dialects import mpi, llvm, func, memref, arith, builtin @@ -36,23 +39,23 @@ class MpiLibraryInfo: # MPI_Datatype MPI_Datatype_size: int = 4 - MPI_CHAR: int = 0x4c000101 - MPI_SIGNED_CHAR: int = 0x4c000118 - MPI_UNSIGNED_CHAR: int = 0x4c000102 - MPI_BYTE: int = 0x4c00010d - MPI_WCHAR: int = 0x4c00040e - MPI_SHORT: int = 0x4c000203 - MPI_UNSIGNED_SHORT: int = 0x4c000204 - MPI_INT: int = 0x4c000405 - MPI_UNSIGNED: int = 0x4c000406 - MPI_LONG: int = 0x4c000807 - MPI_UNSIGNED_LONG: int = 0x4c000808 - MPI_FLOAT: int = 0x4c00040a - MPI_DOUBLE: int = 0x4c00080b - MPI_LONG_DOUBLE: int = 0x4c00100c - MPI_LONG_LONG_INT: int = 0x4c000809 - MPI_UNSIGNED_LONG_LONG: int = 0x4c000819 - MPI_LONG_LONG: int = 0x4c000809 + MPI_CHAR: int = 0x4C000101 + MPI_SIGNED_CHAR: int = 0x4C000118 + MPI_UNSIGNED_CHAR: int = 0x4C000102 + MPI_BYTE: int = 0x4C00010D + MPI_WCHAR: int = 0x4C00040E + MPI_SHORT: int = 0x4C000203 + MPI_UNSIGNED_SHORT: int = 0x4C000204 + MPI_INT: int = 0x4C000405 + MPI_UNSIGNED: int = 0x4C000406 + MPI_LONG: int = 0x4C000807 + MPI_UNSIGNED_LONG: int = 0x4C000808 + MPI_FLOAT: int = 0x4C00040A + MPI_DOUBLE: int = 0x4C00080B + MPI_LONG_DOUBLE: int = 0x4C00100C + MPI_LONG_LONG_INT: int = 0x4C000809 + MPI_UNSIGNED_LONG_LONG: int = 0x4C000819 + MPI_LONG_LONG: int = 0x4C000809 # MPI_Op MPI_Op_size: int = 4 @@ -65,11 +68,11 @@ class MpiLibraryInfo: MPI_LOR: int = 0x58000007 MPI_BOR: int = 0x58000008 MPI_LXOR: int = 0x58000009 - MPI_BXOR: int = 0x5800000a - MPI_MINLOC: int = 0x5800000b - MPI_MAXLOC: int = 0x5800000c - MPI_REPLACE: int = 0x5800000d - MPI_NO_OP: int = 0x5800000e + MPI_BXOR: int = 0x5800000A + MPI_MINLOC: int = 0x5800000B + MPI_MAXLOC: int = 0x5800000C + MPI_REPLACE: int = 0x5800000D + MPI_NO_OP: int = 0x5800000E # MPI_Comm MPI_Comm_size: int = 4 @@ -78,21 +81,25 @@ class MpiLibraryInfo: # MPI_Request MPI_Request_size: int = 4 - MPI_REQUEST_NULL = 0x2c000000 + MPI_REQUEST_NULL = 0x2C000000 # MPI_Status MPI_Status_size: int = 20 MPI_STATUS_IGNORE: int = 0x00000001 MPI_STATUSES_IGNORE: int = 0x00000001 - MPI_Status_field_MPI_SOURCE: int = 8 # offset of field MPI_SOURCE in struct MPI_Status + MPI_Status_field_MPI_SOURCE: int = ( + 8 # offset of field MPI_SOURCE in struct MPI_Status + ) MPI_Status_field_MPI_TAG: int = 12 # offset of field MPI_TAG in struct MPI_Status - MPI_Status_field_MPI_ERROR: int = 16 # offset of field MPI_ERROR in struct MPI_Status + MPI_Status_field_MPI_ERROR: int = ( + 16 # offset of field MPI_ERROR in struct MPI_Status + ) # In place MPI All reduce MPI_IN_PLACE: int = -1 -_RewriteT = TypeVar('_RewriteT', bound=mpi.MPIBaseOp) +_RewriteT = TypeVar("_RewriteT", bound=mpi.MPIBaseOp) @dataclass @@ -107,19 +114,19 @@ class _MPIToLLVMRewriteBase(RewritePattern, ABC): """ MPI_SYMBOL_NAMES = { - 'mpi.init': 'MPI_Init', - 'mpi.finalize': 'MPI_Finalize', - 'mpi.irecv': 'MPI_Irecv', - 'mpi.isend': 'MPI_Isend', - 'mpi.wait': 'MPI_Wait', - 'mpi.waitall': 'MPI_Waitall', - 'mpi.comm.rank': 'MPI_Comm_rank', - 'mpi.comm.size': 'MPI_Comm_size', - 'mpi.recv': 'MPI_Recv', - 'mpi.send': 'MPI_Send', - "mpi.reduce": 'MPI_Reduce', - "mpi.allreduce": 'MPI_Allreduce', - "mpi.bcast": 'MPI_Bcast', + "mpi.init": "MPI_Init", + "mpi.finalize": "MPI_Finalize", + "mpi.irecv": "MPI_Irecv", + "mpi.isend": "MPI_Isend", + "mpi.wait": "MPI_Wait", + "mpi.waitall": "MPI_Waitall", + "mpi.comm.rank": "MPI_Comm_rank", + "mpi.comm.size": "MPI_Comm_size", + "mpi.recv": "MPI_Recv", + "mpi.send": "MPI_Send", + "mpi.reduce": "MPI_Reduce", + "mpi.allreduce": "MPI_Allreduce", + "mpi.bcast": "MPI_Bcast", } """ Translation table for mpi operation names to their MPI library function names @@ -131,8 +138,9 @@ class _MPIToLLVMRewriteBase(RewritePattern, ABC): """ # Helpers - def _get_mpi_dtype_size(self, mpi_dialect_dtype: mpi.RequestType - | mpi.StatusType | mpi.DataType): + def _get_mpi_dtype_size( + self, mpi_dialect_dtype: mpi.RequestType | mpi.StatusType | mpi.DataType + ): """ This function retrieves the data size of a provided MPI type object """ @@ -145,7 +153,9 @@ def _get_mpi_dtype_size(self, mpi_dialect_dtype: mpi.RequestType else: raise ValueError( "MPI internal type size lookup: Unsupported type: {}".format( - mpi_dialect_dtype)) + mpi_dialect_dtype + ) + ) def _emit_mpi_status_objs( self, number_to_output: int @@ -159,23 +169,33 @@ def _emit_mpi_status_objs( magic value for MPI_STATUS_NONE. """ if number_to_output == 0: - return [ - lit1 := arith.Constant.from_int_and_width(1, builtin.i64), - res := llvm.IntToPtrOp.get(lit1), - ], [], res + return ( + [ + lit1 := arith.Constant.from_int_and_width(1, builtin.i64), + res := llvm.IntToPtrOp.get(lit1), + ], + [], + res, + ) else: - return [ - lit1 := - arith.Constant.from_int_and_width(number_to_output, - builtin.i64), - res := llvm.AllocaOp.get(lit1, - builtin.IntegerType( - 8 * self.info.MPI_Status_size), - as_untyped_ptr=True), - ], [res.res], res + return ( + [ + lit1 := arith.Constant.from_int_and_width( + number_to_output, builtin.i64 + ), + res := llvm.AllocaOp.get( + lit1, + builtin.IntegerType(8 * self.info.MPI_Status_size), + as_untyped_ptr=True, + ), + ], + [res.res], + res, + ) def _emit_memref_counts( - self, ssa_val: SSAValue) -> tuple[list[Operation], OpResult]: + self, ssa_val: SSAValue + ) -> tuple[list[Operation], OpResult]: """ This takes in an SSA Value holding a memref, and creates operations to calculate the number of elements in the memref. @@ -188,22 +208,21 @@ def _emit_memref_counts( # Note: we only allow MemRef, not UnrankedMemref! # TODO: handle -1 in sizes if not all(dim.value.data >= 0 for dim in ssa_val.typ.shape.data): - raise RuntimeError( - "MPI lowering does not support unknown-size memrefs!") + raise RuntimeError("MPI lowering does not support unknown-size memrefs!") size = sum(dim.value.data for dim in ssa_val.typ.shape.data) literal = arith.Constant.from_int_and_width(size, i32) return [literal], literal.result - def _emit_mpi_operation_load(self, - op_attr: mpi.OperationType) -> Operation: + def _emit_mpi_operation_load(self, op_attr: mpi.OperationType) -> Operation: """ This emits an instruction loading the correct magic MPI value for the operation into an SSA Value. """ return arith.Constant.from_int_and_width( - self._translate_to_mpi_op(op_attr), i32) + self._translate_to_mpi_op(op_attr), i32 + ) def _translate_to_mpi_op(self, op_attr: mpi.OperationType) -> int: """ @@ -221,7 +240,8 @@ def _emit_mpi_type_load(self, type_attr: Attribute) -> Operation: xDSL type of into an SSA Value. """ return arith.Constant.from_int_and_width( - self._translate_to_mpi_type(type_attr), i32) + self._translate_to_mpi_type(type_attr), i32 + ) def _translate_to_mpi_type(self, typ: Attribute) -> int: """ @@ -263,10 +283,11 @@ def _translate_to_mpi_type(self, typ: Attribute) -> int: if width == 64: return self.info.MPI_LONG_LONG_INT raise ValueError( - "MPI Datatype Conversion: Unsupported integer bitwidth: {}". - format(width)) - raise ValueError( - "MPI Datatype Conversion: Unsupported type {}".format(typ)) + "MPI Datatype Conversion: Unsupported integer bitwidth: {}".format( + width + ) + ) + raise ValueError("MPI Datatype Conversion: Unsupported type {}".format(typ)) def _mpi_name(self, op: mpi.MPIBaseOp) -> str: """ @@ -274,12 +295,13 @@ def _mpi_name(self, op: mpi.MPIBaseOp) -> str: """ if op.name not in self.MPI_SYMBOL_NAMES: raise RuntimeError( - "Lowering of MPI Operations failed, missing lowering for {}!". - format(op.name)) + "Lowering of MPI Operations failed, missing lowering for {}!".format( + op.name + ) + ) return self.MPI_SYMBOL_NAMES[op.name] - def _memref_get_llvm_ptr( - self, ref: SSAValue) -> tuple[list[Operation], Operation]: + def _memref_get_llvm_ptr(self, ref: SSAValue) -> tuple[list[Operation], Operation]: """ Converts an SSA Value holding a reference to a memref to llvm.ptr @@ -300,13 +322,11 @@ def _memref_get_llvm_ptr( class LowerMpiInit(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern def match_and_rewrite(self, op: mpi.Init, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) - def lower(self, - op: mpi.Init) -> tuple[list[Operation], list[SSAValue | None]]: + def lower(self, op: mpi.Init) -> tuple[list[Operation], list[SSAValue | None]]: """ We currently don't model any argument passing to `MPI_Init()` and pass two nullptrs. """ @@ -317,15 +337,11 @@ def lower(self, class LowerMpiFinalize(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: mpi.Finalize, rewriter: PatternRewriter, - /): + def match_and_rewrite(self, op: mpi.Finalize, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) - def lower( - self, - op: mpi.Finalize) -> tuple[list[Operation], list[SSAValue | None]]: + def lower(self, op: mpi.Finalize) -> tuple[list[Operation], list[SSAValue | None]]: """ Relatively straight forward lowering of mpi.finalize operation. """ @@ -335,13 +351,11 @@ def lower( class LowerMpiWait(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern def match_and_rewrite(self, op: mpi.Wait, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) - def lower(self, - op: mpi.Wait) -> tuple[list[Operation], list[SSAValue | None]]: + def lower(self, op: mpi.Wait) -> tuple[list[Operation], list[SSAValue | None]]: """ Relatively straight forward lowering of mpi.wait operation. """ @@ -353,14 +367,11 @@ def lower(self, class LowerMpiWaitall(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern def match_and_rewrite(self, op: mpi.Waitall, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) - def lower( - self, - op: mpi.Waitall) -> tuple[list[Operation], list[SSAValue | None]]: + def lower(self, op: mpi.Waitall) -> tuple[list[Operation], list[SSAValue | None]]: """ Relatively straight forward lowering of mpi.waitall operation. """ @@ -368,44 +379,47 @@ def lower( ops, new_results, res = self._emit_mpi_status_objs(len(op.results)) return [ *ops, - func.Call.get(self._mpi_name(op), [op.count, op.requests, res], - [i32]), + func.Call.get(self._mpi_name(op), [op.count, op.requests, res], [i32]), ], new_results class LowerMpiReduce(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern def match_and_rewrite(self, op: mpi.Reduce, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) - def lower(self, - op: mpi.Reduce) -> tuple[list[Operation], list[SSAValue | None]]: + def lower(self, op: mpi.Reduce) -> tuple[list[Operation], list[SSAValue | None]]: """ Lowers the MPI Reduce operation """ return [ - comm_global := - arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), + comm_global := arith.Constant.from_int_and_width( + self.info.MPI_COMM_WORLD, i32 + ), mpi_op := self._emit_mpi_operation_load(op.operationtype), - func.Call.get(self._mpi_name(op), [ - op.send_buffer, op.recv_buffer, op.count, op.datatype, mpi_op, - op.root, comm_global - ], []), + func.Call.get( + self._mpi_name(op), + [ + op.send_buffer, + op.recv_buffer, + op.count, + op.datatype, + mpi_op, + op.root, + comm_global, + ], + [], + ), ], [] class LowerMpiAllreduce(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: mpi.Allreduce, rewriter: PatternRewriter, - /): + def match_and_rewrite(self, op: mpi.Allreduce, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) - def lower( - self, op: mpi.Allreduce - ) -> tuple[list[Operation], list[SSAValue | None]]: + def lower(self, op: mpi.Allreduce) -> tuple[list[Operation], list[SSAValue | None]]: """ Lowers the MPI Allreduce operation """ @@ -413,8 +427,7 @@ def lower( # Send buffer is optional (if not provided then call using MPI_IN_PLACE) has_send_buffer = op.send_buffer is not None - comm_global = arith.Constant.from_int_and_width( - self.info.MPI_COMM_WORLD, i32) + comm_global = arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32) mpi_op = self._emit_mpi_operation_load(op.operationtype) operations = [comm_global, mpi_op] @@ -425,47 +438,55 @@ def lower( send_buffer_op = op.send_buffer else: send_buffer_op = arith.Constant.from_int_and_width( - self.info.MPI_IN_PLACE, i64) + self.info.MPI_IN_PLACE, i64 + ) operations.append(send_buffer_op) return [ *operations, - func.Call.get(self._mpi_name(op), [ - send_buffer_op, op.recv_buffer, op.count, op.datatype, mpi_op, - comm_global - ], []), + func.Call.get( + self._mpi_name(op), + [ + send_buffer_op, + op.recv_buffer, + op.count, + op.datatype, + mpi_op, + comm_global, + ], + [], + ), ], [] class LowerMpiBcast(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern def match_and_rewrite(self, op: mpi.Bcast, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) - def lower(self, - op: mpi.Bcast) -> tuple[list[Operation], list[SSAValue | None]]: + def lower(self, op: mpi.Bcast) -> tuple[list[Operation], list[SSAValue | None]]: """ Lowers the MPI Bcast operation """ return [ - comm_global := - arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), + comm_global := arith.Constant.from_int_and_width( + self.info.MPI_COMM_WORLD, i32 + ), func.Call.get( self._mpi_name(op), - [op.buffer, op.count, op.datatype, op.root, comm_global], []), + [op.buffer, op.count, op.datatype, op.root, comm_global], + [], + ), ], [] class LowerMpiIsend(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern def match_and_rewrite(self, op: mpi.Isend, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) - def lower(self, - op: mpi.Isend) -> tuple[list[Operation], list[SSAValue | None]]: + def lower(self, op: mpi.Isend) -> tuple[list[Operation], list[SSAValue | None]]: """ This method lowers mpi.isend @@ -474,23 +495,31 @@ def lower(self, """ return [ - comm_global := - arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), - func.Call.get(self._mpi_name(op), [ - op.buffer, op.count, op.datatype, op.dest, op.tag, comm_global, - op.request - ], [i32]), + comm_global := arith.Constant.from_int_and_width( + self.info.MPI_COMM_WORLD, i32 + ), + func.Call.get( + self._mpi_name(op), + [ + op.buffer, + op.count, + op.datatype, + op.dest, + op.tag, + comm_global, + op.request, + ], + [i32], + ), ], [] class LowerMpiIrecv(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern def match_and_rewrite(self, op: mpi.Irecv, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) - def lower(self, - op: mpi.Irecv) -> tuple[list[Operation], list[SSAValue | None]]: + def lower(self, op: mpi.Irecv) -> tuple[list[Operation], list[SSAValue | None]]: """ This method lowers mpi.irecv operations @@ -499,23 +528,31 @@ def lower(self, """ return [ - comm_global := - arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), - func.Call.get(self._mpi_name(op), [ - op.buffer, op.count, op.datatype, op.source, op.tag, - comm_global, op.request - ], [i32]), + comm_global := arith.Constant.from_int_and_width( + self.info.MPI_COMM_WORLD, i32 + ), + func.Call.get( + self._mpi_name(op), + [ + op.buffer, + op.count, + op.datatype, + op.source, + op.tag, + comm_global, + op.request, + ], + [i32], + ), ], [] class LowerMpiSend(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern def match_and_rewrite(self, op: mpi.Send, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) - def lower(self, - op: mpi.Send) -> tuple[list[Operation], list[SSAValue | None]]: + def lower(self, op: mpi.Send) -> tuple[list[Operation], list[SSAValue | None]]: """ This method lowers mpi.send operations @@ -526,22 +563,23 @@ def lower(self, """ return [ - comm_global := - arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), - func.Call.get(self._mpi_name(op), [ - op.buffer, op.count, op.datatype, op.dest, op.tag, comm_global - ], [i32]), + comm_global := arith.Constant.from_int_and_width( + self.info.MPI_COMM_WORLD, i32 + ), + func.Call.get( + self._mpi_name(op), + [op.buffer, op.count, op.datatype, op.dest, op.tag, comm_global], + [i32], + ), ], [] class LowerMpiRecv(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern def match_and_rewrite(self, op: mpi.Recv, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) - def lower(self, - op: mpi.Recv) -> tuple[list[Operation], list[SSAValue | None]]: + def lower(self, op: mpi.Recv) -> tuple[list[Operation], list[SSAValue | None]]: """ This method lowers mpi.recv operations @@ -552,24 +590,33 @@ def lower(self, """ mpi_status_ops, new_results, status = self._emit_mpi_status_objs( - len(op.results)) + len(op.results) + ) return [ *mpi_status_ops, - comm_global := - arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), - func.Call.get(self._mpi_name(op), [ - op.buffer, op.count, op.datatype, op.source, op.tag, - comm_global, status - ], [i32]), + comm_global := arith.Constant.from_int_and_width( + self.info.MPI_COMM_WORLD, i32 + ), + func.Call.get( + self._mpi_name(op), + [ + op.buffer, + op.count, + op.datatype, + op.source, + op.tag, + comm_global, + status, + ], + [i32], + ), ], new_results class LowerMpiUnwrapMemrefOp(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: mpi.UnwrapMemrefOp, - rewriter: PatternRewriter, /): + def match_and_rewrite(self, op: mpi.UnwrapMemrefOp, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) def lower( @@ -578,8 +625,7 @@ def lower( count_ops, count_ssa_val = self._emit_memref_counts(op.ref) extract_ptr_ops, ptr = self._memref_get_llvm_ptr(op.ref) - elem_typ = cast(MemRefType[mpi.AnyNumericType], - op.ref.typ).element_type + elem_typ = cast(MemRefType[mpi.AnyNumericType], op.ref.typ).element_type return [ *extract_ptr_ops, @@ -589,14 +635,12 @@ def lower( class LowerMpiGetDtype(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: mpi.GetDtypeOp, rewriter: PatternRewriter, - /): + def match_and_rewrite(self, op: mpi.GetDtypeOp, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) def lower( - self, op: mpi.GetDtypeOp + self, op: mpi.GetDtypeOp ) -> tuple[list[Operation], list[SSAValue | None]]: return [ typ := self._emit_mpi_type_load(op.dtype), @@ -604,10 +648,8 @@ def lower( class LowerMpiAllocateType(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: mpi.AllocateTypeOp, - rewriter: PatternRewriter, /): + def match_and_rewrite(self, op: mpi.AllocateTypeOp, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) def lower( @@ -618,21 +660,19 @@ def lower( """ datatype_size = self._get_mpi_dtype_size(op.dtype) return [ - request - := llvm.AllocaOp.get(op.count, - builtin.IntegerType(8 * datatype_size)), + request := llvm.AllocaOp.get( + op.count, builtin.IntegerType(8 * datatype_size) + ), ], [request.results[0]] class LowerMpiVectorGet(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: mpi.VectorGetOp, rewriter: PatternRewriter, - /): + def match_and_rewrite(self, op: mpi.VectorGetOp, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) def lower( - self, op: mpi.VectorGetOp + self, op: mpi.VectorGetOp ) -> tuple[list[Operation], list[SSAValue | None]]: """ This lowers the get array at index MPI operation in the dialect. Converts @@ -645,34 +685,31 @@ def lower( datatype_size = self._get_mpi_dtype_size(op.result.typ) return [ - ptr_int := llvm.PtrToIntOp.get(op.vect, i64), lit1 := - arith.Constant.from_int_and_width(datatype_size, 64), idx_cast1 := - arith.IndexCastOp.get(op.element, IndexType()), idx_cast2 := - arith.IndexCastOp.get(idx_cast1, - i64), mul := arith.Muli.get(lit1, idx_cast2), - add := arith.Addi.get(mul, ptr_int), out_ptr := - llvm.IntToPtrOp.get(add, op.vect.typ.type) + ptr_int := llvm.PtrToIntOp.get(op.vect, i64), + lit1 := arith.Constant.from_int_and_width(datatype_size, 64), + idx_cast1 := arith.IndexCastOp.get(op.element, IndexType()), + idx_cast2 := arith.IndexCastOp.get(idx_cast1, i64), + mul := arith.Muli.get(lit1, idx_cast2), + add := arith.Addi.get(mul, ptr_int), + out_ptr := llvm.IntToPtrOp.get(add, op.vect.typ.type), ], [out_ptr.results[0]] class LowerMpiCommRank(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: mpi.CommRank, rewriter: PatternRewriter, - /): + def match_and_rewrite(self, op: mpi.CommRank, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) - def lower( - self, - op: mpi.CommRank) -> tuple[list[Operation], list[SSAValue | None]]: + def lower(self, op: mpi.CommRank) -> tuple[list[Operation], list[SSAValue | None]]: """ This method lowers mpi.comm.rank operation int MPI_Comm_rank(MPI_Comm comm, int *rank) """ return [ - comm_global := - arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), + comm_global := arith.Constant.from_int_and_width( + self.info.MPI_COMM_WORLD, i32 + ), lit1 := arith.Constant.from_int_and_width(1, 64), int_ptr := llvm.AllocaOp.get(lit1, i32), func.Call.get(self._mpi_name(op), [comm_global, int_ptr], [i32]), @@ -681,23 +718,20 @@ def lower( class LowerMpiCommSize(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: mpi.CommSize, rewriter: PatternRewriter, - /): + def match_and_rewrite(self, op: mpi.CommSize, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) - def lower( - self, - op: mpi.CommSize) -> tuple[list[Operation], list[SSAValue | None]]: + def lower(self, op: mpi.CommSize) -> tuple[list[Operation], list[SSAValue | None]]: """ This method lowers mpi.comm.size operation int MPI_Comm_size(MPI_Comm comm, int *size) """ return [ - comm_global := - arith.Constant.from_int_and_width(self.info.MPI_COMM_WORLD, i32), + comm_global := arith.Constant.from_int_and_width( + self.info.MPI_COMM_WORLD, i32 + ), lit1 := arith.Constant.from_int_and_width(1, 64), int_ptr := llvm.AllocaOp.get(lit1, i32), func.Call.get(self._mpi_name(op), [comm_global, int_ptr], [i32]), @@ -715,14 +749,13 @@ class MpiAddExternalFuncDefs(RewritePattern): Make sure to apply this *in a separate pass after the lowerings*, otherwise this will match first and find no inserted MPI calls. """ + mpi_func_call_names = set(_MPIToLLVMRewriteBase.MPI_SYMBOL_NAMES.values()) @op_type_rewrite_pattern - def match_and_rewrite(self, module: builtin.ModuleOp, - rewriter: PatternRewriter, /): + def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, /): # collect all func calls to MPI functions - funcs_to_emit: dict[str, tuple[list[Attribute], - list[Attribute]]] = dict() + funcs_to_emit: dict[str, tuple[list[Attribute], list[Attribute]]] = dict() def walker(op: Operation): if not isinstance(op, func.Call): @@ -739,16 +772,16 @@ def walker(op: Operation): # for each func found, add a FuncOp to the top of the module. for name, types in funcs_to_emit.items(): arg, res = types - rewriter.insert_op_at_pos(func.FuncOp.external(name, arg, res), - module.body.block, - len(module.body.block.ops)) + rewriter.insert_op_at_pos( + func.FuncOp.external(name, arg, res), + module.body.block, + len(module.body.block.ops), + ) class LowerNullRequestOp(_MPIToLLVMRewriteBase): - @op_type_rewrite_pattern - def match_and_rewrite(self, op: mpi.NullRequestOp, - rewriter: PatternRewriter, /): + def match_and_rewrite(self, op: mpi.NullRequestOp, rewriter: PatternRewriter, /): rewriter.replace_matched_op(*self.lower(op)) def lower( @@ -761,43 +794,44 @@ def lower( """ assert isa(op.request.typ, llvm.LLVMPointerType) return [ - val := - arith.Constant.from_int_and_width(self.info.MPI_REQUEST_NULL, i32), + val := arith.Constant.from_int_and_width(self.info.MPI_REQUEST_NULL, i32), llvm.StoreOp.get(val, op.request), ], [] @dataclass class LowerMPIPass(ModulePass): - - name = 'lower-mpi' + name = "lower-mpi" # lower to func.call def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: - # TODO: how to get the lib info in here? lib_info = MpiLibraryInfo() - walker1 = PatternRewriteWalker(GreedyRewritePatternApplier([ - LowerMpiInit(lib_info), - LowerMpiFinalize(lib_info), - LowerMpiWait(lib_info), - LowerMpiWaitall(lib_info), - LowerMpiCommRank(lib_info), - LowerMpiCommSize(lib_info), - LowerMpiIsend(lib_info), - LowerMpiIrecv(lib_info), - LowerMpiSend(lib_info), - LowerMpiRecv(lib_info), - LowerMpiReduce(lib_info), - LowerMpiAllreduce(lib_info), - LowerMpiBcast(lib_info), - LowerMpiUnwrapMemrefOp(lib_info), - LowerMpiGetDtype(lib_info), - LowerMpiAllocateType(lib_info), - LowerMpiVectorGet(lib_info), - ]), - apply_recursively=True) + walker1 = PatternRewriteWalker( + GreedyRewritePatternApplier( + [ + LowerMpiInit(lib_info), + LowerMpiFinalize(lib_info), + LowerMpiWait(lib_info), + LowerMpiWaitall(lib_info), + LowerMpiCommRank(lib_info), + LowerMpiCommSize(lib_info), + LowerMpiIsend(lib_info), + LowerMpiIrecv(lib_info), + LowerMpiSend(lib_info), + LowerMpiRecv(lib_info), + LowerMpiReduce(lib_info), + LowerMpiAllreduce(lib_info), + LowerMpiBcast(lib_info), + LowerMpiUnwrapMemrefOp(lib_info), + LowerMpiGetDtype(lib_info), + LowerMpiAllocateType(lib_info), + LowerMpiVectorGet(lib_info), + ] + ), + apply_recursively=True, + ) # add func.func to declare external functions walker2 = PatternRewriteWalker(MpiAddExternalFuncDefs()) diff --git a/xdsl/utils/deprecation.py b/xdsl/utils/deprecation.py index 3756abc1a6..7a47ef3a27 100644 --- a/xdsl/utils/deprecation.py +++ b/xdsl/utils/deprecation.py @@ -5,15 +5,14 @@ # this simple use case. # If we ever need more features from the library, we can switch to it. -_T = TypeVar('_T') -_P = ParamSpec('_P') +_T = TypeVar("_T") +_P = ParamSpec("_P") def deprecated(reason: str): """Deprecate the use of a method, and provide a warning message.""" def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]: - def new_func(*args: _P.args, **kwargs: _P.kwargs) -> _T: warnings.warn( f'Call to deprecated method {str(func).split(" ")[1]}: {reason}' @@ -27,5 +26,4 @@ def new_func(*args: _P.args, **kwargs: _P.kwargs) -> _T: def deprecated_constructor(func: Callable[_P, _T]) -> Callable[_P, _T]: # TOFIX: improve printing - return deprecated(f'{"use the constructor (`ClassName(...)`) instead."}')( - func) + return deprecated(f'{"use the constructor (`ClassName(...)`) instead."}')(func) diff --git a/xdsl/utils/diagnostic.py b/xdsl/utils/diagnostic.py index 276958f18b..d05e11ed61 100644 --- a/xdsl/utils/diagnostic.py +++ b/xdsl/utils/diagnostic.py @@ -15,12 +15,14 @@ def add_message(self, op: Operation, message: str) -> None: self.op_messages.setdefault(op, []).append(message) def raise_exception( - self, - message: str, - ir: IRNode, - exception_type: type[Exception] = DiagnosticException) -> None: + self, + message: str, + ir: IRNode, + exception_type: type[Exception] = DiagnosticException, + ) -> None: """Raise an exception, that will also print all messages in the IR.""" from xdsl.printer import Printer + f = StringIO() p = Printer(stream=f, diagnostic=self, target=Printer.Target.XDSL) toplevel = ir.get_toplevel_object() diff --git a/xdsl/utils/exceptions.py b/xdsl/utils/exceptions.py index f87e3049ba..2f6abb3f51 100644 --- a/xdsl/utils/exceptions.py +++ b/xdsl/utils/exceptions.py @@ -43,6 +43,7 @@ class InterpretationError(Exception): """ An error that can be raised during interpretation, or Interpreter setup. """ + pass @@ -52,12 +53,15 @@ class BuilderNotFoundException(Exception): Exception raised when no builders are found for a given attribute type and a given tuple of arguments. """ + attribute: type[Attribute] args: tuple[Any] def __str__(self) -> str: - return f"No builder found for attribute {self.attribute} with " \ - f"arguments {self.args}" + return ( + f"No builder found for attribute {self.attribute} with " + f"arguments {self.args}" + ) class DeferredExceptionMessage: @@ -107,12 +111,11 @@ def __contains__(self, item: str): class ParseError(Exception): span: Span msg: str - history: 'BacktrackingHistory' | None + history: "BacktrackingHistory" | None - def __init__(self, - span: Span, - msg: str, - history: 'BacktrackingHistory' | None = None): + def __init__( + self, span: Span, msg: str, history: "BacktrackingHistory" | None = None + ): super().__init__(DeferredExceptionMessage(lambda: repr(self))) self.span = span self.msg = msg @@ -144,7 +147,7 @@ def __init__( msg: str, ref_text: str, refs: list[tuple[Span, str | None]], - history: 'BacktrackingHistory' | None = None, + history: "BacktrackingHistory" | None = None, ): super(MultipleSpansParseError, self).__init__(span, msg, history) self.refs = refs diff --git a/xdsl/utils/hints.py b/xdsl/utils/hints.py index 2ca45871a3..57fca81ba9 100644 --- a/xdsl/utils/hints.py +++ b/xdsl/utils/hints.py @@ -1,7 +1,6 @@ from inspect import isclass from types import UnionType -from typing import (Annotated, Any, TypeGuard, TypeVar, Union, cast, get_args, - get_origin) +from typing import Annotated, Any, TypeGuard, TypeVar, Union, cast, get_args, get_origin from xdsl.ir import ParametrizedAttribute from xdsl.utils.exceptions import VerifyException @@ -28,7 +27,7 @@ def isa(arg: Any, hint: type[_T]) -> TypeGuard[_T]: if not isinstance(arg, list): return False arg_list: list[Any] = cast(list[Any], arg) - elem_hint, = get_args(hint) + (elem_hint,) = get_args(hint) return all(isa(elem, elem_hint) for elem in arg_list) if origin is tuple: @@ -40,7 +39,8 @@ def isa(arg: Any, hint: type[_T]) -> TypeGuard[_T]: return all(isa(elem, elem_hints[0]) for elem in arg_tuple) else: return len(elem_hints) == len(arg_tuple) and all( - isa(elem, hint) for elem, hint in zip(arg_tuple, elem_hints)) + isa(elem, hint) for elem, hint in zip(arg_tuple, elem_hints) + ) if origin is dict: if not isinstance(arg, dict): @@ -49,14 +49,15 @@ def isa(arg: Any, hint: type[_T]) -> TypeGuard[_T]: key_hint, value_hint = get_args(hint) return all( isa(key, key_hint) and isa(value, value_hint) - for key, value in arg_dict.items()) + for key, value in arg_dict.items() + ) if origin in [Union, UnionType]: return any(isa(arg, union_arg) for union_arg in get_args(hint)) from xdsl.irdl import GenericData, irdl_to_attr_constraint - if (origin is not None) and issubclass( - origin, GenericData | ParametrizedAttribute): + + if (origin is not None) and issubclass(origin, GenericData | ParametrizedAttribute): constraint = irdl_to_attr_constraint(hint) try: constraint.verify(arg) @@ -86,7 +87,6 @@ def assert_isa(arg: Any, hint: type[_T]) -> TypeGuard[_T]: class _Class: - @property def property(self): pass diff --git a/xdsl/utils/immutable_list.py b/xdsl/utils/immutable_list.py index ea1aff0350..0238ea0897 100644 --- a/xdsl/utils/immutable_list.py +++ b/xdsl/utils/immutable_list.py @@ -1,14 +1,15 @@ from __future__ import annotations from typing import TypeVar, Iterable, SupportsIndex, List -_T = TypeVar('_T') +_T = TypeVar("_T") class IList(List[_T]): """ - A list that can be frozen. Once frozen, it can not be modified. + A list that can be frozen. Once frozen, it can not be modified. In comparison to FrozenList this supports pattern matching. """ + _frozen: bool = False def freeze(self): diff --git a/xdsl/utils/lexer.py b/xdsl/utils/lexer.py index 14d139f7fd..63ab46cc48 100644 --- a/xdsl/utils/lexer.py +++ b/xdsl/utils/lexer.py @@ -17,24 +17,24 @@ class Input: """ Used to keep track of the input when parsing. """ + content: str = field(repr=False) name: str len: int = field(init=False, repr=False) def __post_init__(self): - object.__setattr__(self, 'len', len(self.content)) + object.__setattr__(self, "len", len(self.content)) def __len__(self): return self.len - def get_lines_containing(self, - span: Span) -> tuple[list[str], int, int] | None: + def get_lines_containing(self, span: Span) -> tuple[list[str], int, int] | None: # A pointer to the start of the first line start = 0 line_no = 0 source = self.content while True: - next_start = source.find('\n', start) + next_start = source.find("\n", start) line_no += 1 # Handle eof if next_start == -1: @@ -49,10 +49,10 @@ def get_lines_containing(self, if next_start >= span.end: return [source[start:next_start]], start, line_no while next_start < span.end: - next_start = source.find('\n', next_start + 1) + next_start = source.find("\n", next_start + 1) if next_start == -1: next_start = span.end - return source[start:next_start].split('\n'), start, line_no + return source[start:next_start].split("\n"), start, line_no def at(self, i: int) -> str | None: if i >= self.len: @@ -93,7 +93,7 @@ def len(self): @property def text(self): - return self.input.content[self.start:self.end] + return self.input.content[self.start : self.end] def get_line_col(self) -> tuple[int, int]: info = self.input.get_lines_containing(self) @@ -116,16 +116,17 @@ def print_with_context(self, msg: str | None = None) -> str: offset = self.start - offset_of_first_line remaining_len = max(self.len, 1) capture = StringIO() - print("{}:{}:{}".format(self.input.name, line_no, offset), - file=capture) + print("{}:{}:{}".format(self.input.name, line_no, offset), file=capture) for line in lines: print(line, file=capture) if remaining_len < 0: continue len_on_this_line = min(remaining_len, len(line) - offset) remaining_len -= len_on_this_line - print("{}{}".format(" " * offset, "^" * max(len_on_this_line, 1)), - file=capture) + print( + "{}{}".format(" " * offset, "^" * max(len_on_this_line, 1)), + file=capture, + ) if msg is not None: print("{}{}".format(" " * offset, msg), file=capture) msg = None @@ -135,13 +136,13 @@ def print_with_context(self, msg: str | None = None) -> str: return capture.getvalue() def __repr__(self): - return "{}[{}:{}](text='{}')".format(self.__class__.__name__, - self.start, self.end, self.text) + return "{}[{}:{}](text='{}')".format( + self.__class__.__name__, self.start, self.end, self.text + ) @dataclass(frozen=True, repr=False) class StringLiteral(Span): - def __post_init__(self): if len(self) < 2 or self.text[0] != '"' or self.text[-1] != '"': raise ParseError(self, "Invalid string literal!") @@ -165,7 +166,6 @@ def string_contents(self): @dataclass class Token: - class Kind(Enum): # Markers EOF = object() @@ -237,30 +237,48 @@ def get_punctuation_spelling_to_kind_dict() -> dict[str, Token.Kind]: } def is_punctuation(self) -> bool: - punctuation_dict = Token.Kind.get_punctuation_spelling_to_kind_dict( - ) + punctuation_dict = Token.Kind.get_punctuation_spelling_to_kind_dict() return self in punctuation_dict.values() @staticmethod def is_spelling_of_punctuation( - spelling: str) -> TypeGuard[Token.PunctuationSpelling]: - punctuation_dict = \ - Token.Kind.get_punctuation_spelling_to_kind_dict() + spelling: str, + ) -> TypeGuard[Token.PunctuationSpelling]: + punctuation_dict = Token.Kind.get_punctuation_spelling_to_kind_dict() return spelling in punctuation_dict.keys() @staticmethod def get_punctuation_kind_from_spelling( - spelling: Token.PunctuationSpelling) -> Token.Kind: - assert Token.Kind.is_spelling_of_punctuation(spelling), \ - "Kind.get_punctuation_kind_from_spelling: spelling is not a " \ - "valid punctuation spelling!" + spelling: Token.PunctuationSpelling, + ) -> Token.Kind: + assert Token.Kind.is_spelling_of_punctuation(spelling), ( + "Kind.get_punctuation_kind_from_spelling: spelling is not a " + "valid punctuation spelling!" + ) return Token.Kind.get_punctuation_spelling_to_kind_dict()[spelling] - PunctuationSpelling: ClassVar[TypeAlias] = Literal["->", ":", ",", "...", - "=", ">", "{", "(", "[", - "<", "-", "+", "?", "}", - ")", "]", "*", "|", - "{-#", "#-}"] + PunctuationSpelling: ClassVar[TypeAlias] = Literal[ + "->", + ":", + ",", + "...", + "=", + ">", + "{", + "(", + "[", + "<", + "-", + "+", + "?", + "}", + ")", + "]", + "*", + "|", + "{-#", + "#-}", + ] kind: Kind @@ -274,19 +292,19 @@ def text(self): def get_int_value(self): """ Translate the token text into an integer value. - This function will raise an exception if the token is not an integer + This function will raise an exception if the token is not an integer literal. """ if self.kind != Token.Kind.INTEGER_LIT: raise ValueError("Token is not an integer literal!") - if self.text[:2] in ['0x', '0X']: + if self.text[:2] in ["0x", "0X"]: return int(self.text, 16) return int(self.text, 10) def get_float_value(self): """ Translate the token text into a float value. - This function will raise an exception if the token is not a float + This function will raise an exception if the token is not a float literal. """ if self.kind != Token.Kind.FLOAT_LIT: @@ -373,33 +391,32 @@ def lex(self) -> Token: return self._form_token(Token.Kind.EOF, start_pos) # bare identifier - if current_char.isalpha() or current_char == '_': + if current_char.isalpha() or current_char == "_": return self._lex_bare_identifier(start_pos) # single-char punctuation that are not part of a multi-char token single_char_punctuation = { - ':': Token.Kind.COLON, - ',': Token.Kind.COMMA, - '(': Token.Kind.L_PAREN, - ')': Token.Kind.R_PAREN, - '}': Token.Kind.R_BRACE, - '[': Token.Kind.L_SQUARE, - ']': Token.Kind.R_SQUARE, - '<': Token.Kind.LESS, - '>': Token.Kind.GREATER, - '=': Token.Kind.EQUAL, - '+': Token.Kind.PLUS, - '*': Token.Kind.STAR, - '?': Token.Kind.QUESTION, - '|': Token.Kind.VERTICAL_BAR + ":": Token.Kind.COLON, + ",": Token.Kind.COMMA, + "(": Token.Kind.L_PAREN, + ")": Token.Kind.R_PAREN, + "}": Token.Kind.R_BRACE, + "[": Token.Kind.L_SQUARE, + "]": Token.Kind.R_SQUARE, + "<": Token.Kind.LESS, + ">": Token.Kind.GREATER, + "=": Token.Kind.EQUAL, + "+": Token.Kind.PLUS, + "*": Token.Kind.STAR, + "?": Token.Kind.QUESTION, + "|": Token.Kind.VERTICAL_BAR, } if current_char in single_char_punctuation: - return self._form_token(single_char_punctuation[current_char], - start_pos) + return self._form_token(single_char_punctuation[current_char], start_pos) # '...' - if current_char == '.': - if (self._get_chars(2) != '..'): + if current_char == ".": + if self._get_chars(2) != "..": raise ParseError( Span(start_pos, start_pos + 1, self.input), "Expected three consecutive '.' for an ellipsis", @@ -407,31 +424,30 @@ def lex(self) -> Token: return self._form_token(Token.Kind.ELLIPSIS, start_pos) # '-' and '->' - if current_char == '-': - if self._peek_chars() == '>': + if current_char == "-": + if self._peek_chars() == ">": self._consume_chars() return self._form_token(Token.Kind.ARROW, start_pos) return self._form_token(Token.Kind.MINUS, start_pos) # '{' and '{-#' - if current_char == '{': - if (self._peek_chars(2) == '-#'): + if current_char == "{": + if self._peek_chars(2) == "-#": self._consume_chars(2) - return self._form_token(Token.Kind.FILE_METADATA_BEGIN, - start_pos) + return self._form_token(Token.Kind.FILE_METADATA_BEGIN, start_pos) return self._form_token(Token.Kind.L_BRACE, start_pos) # '#-}' - if (current_char == '#' and self._peek_chars(2) == '-}'): + if current_char == "#" and self._peek_chars(2) == "-}": self._consume_chars(2) return self._form_token(Token.Kind.FILE_METADATA_END, start_pos) # '@' and at-identifier - if current_char == '@': + if current_char == "@": return self._lex_at_ident(start_pos) # '#', '!', '^', '%' identifiers - if current_char in ['#', '!', '^', '%']: + if current_char in ["#", "!", "^", "%"]: return self._lex_prefixed_ident(start_pos) if current_char == '"': @@ -442,7 +458,7 @@ def lex(self) -> Token: raise ParseError( Span(start_pos, start_pos + 1, self.input), - 'Unexpected character: {}'.format(current_char), + "Unexpected character: {}".format(current_char), ) _bare_identifier_suffix_regex = re.compile(r"[a-zA-Z0-9_$.]*") @@ -468,11 +484,13 @@ def _lex_at_ident(self, start_pos: int) -> Token: current_char = self._get_chars() if current_char is None: - raise ParseError(Span(start_pos, start_pos + 1, self.input), - "Unexpected end of file after @.") + raise ParseError( + Span(start_pos, start_pos + 1, self.input), + "Unexpected end of file after @.", + ) # bare identifier case - if current_char.isalpha() or current_char == '_': + if current_char.isalpha() or current_char == "_": token = self._lex_bare_identifier(start_pos) return self._form_token(Token.Kind.AT_IDENT, token.span.start) @@ -483,7 +501,8 @@ def _lex_at_ident(self, start_pos: int) -> Token: raise ParseError( Span(start_pos, self.pos, self.input), - "@ identifier expected to start with letter, '_', or '\"'.") + "@ identifier expected to start with letter, '_', or '\"'.", + ) _suffix_id = re.compile(r"([0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*)") @@ -496,29 +515,35 @@ def _lex_prefixed_ident(self, start_pos: int) -> Token: caret-ident ::= `^` suffix-id exclamation-ident ::= `!` suffix-id ``` - with + with ``` suffix-id = (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*) ``` The first character is expected to have already been parsed. """ - assert self.pos != 0, "First prefixed identifier character must have been parsed" + assert ( + self.pos != 0 + ), "First prefixed identifier character must have been parsed" first_char = self.input.at(self.pos - 1) - if first_char == '#': + if first_char == "#": kind = Token.Kind.HASH_IDENT - elif first_char == '!': + elif first_char == "!": kind = Token.Kind.EXCLAMATION_IDENT - elif first_char == '^': + elif first_char == "^": kind = Token.Kind.CARET_IDENT else: - assert first_char == '%', "First prefixed identifier character must have been parsed correctly" + assert ( + first_char == "%" + ), "First prefixed identifier character must have been parsed correctly" kind = Token.Kind.PERCENT_IDENT match = self._consume_regex(self._suffix_id) if match is None: - raise ParseError(Span(start_pos, self.pos, self.input), - "Expected suffix identifier after {first_char}") + raise ParseError( + Span(start_pos, self.pos, self.input), + "Expected suffix identifier after {first_char}", + ) return self._form_token(kind, start_pos) @@ -539,22 +564,26 @@ def _lex_string_literal(self, start_pos: int) -> Token: return self._form_token(Token.Kind.STRING_LIT, start_pos) # newline character in string literal (not allowed) - if current_char in ['\n', '\v', '\f']: + if current_char in ["\n", "\v", "\f"]: raise ParseError( Span(start_pos, self.pos, self.input), - "Newline character not allowed in string literal.") + "Newline character not allowed in string literal.", + ) # escape character # TODO: handle unicode escape - if current_char == '\\': + if current_char == "\\": escaped_char = self._get_chars() - if escaped_char not in ['"', '\\', 'n', 't']: + if escaped_char not in ['"', "\\", "n", "t"]: raise ParseError( StringLiteral(self.pos - 1, self.pos, self.input), - "Unknown escape in string literal.") + "Unknown escape in string literal.", + ) - raise ParseError(Span(start_pos, self.pos, self.input), - "End of file reached before closing string literal.") + raise ParseError( + Span(start_pos, self.pos, self.input), + "End of file reached before closing string literal.", + ) _hexdigits_star_regex = re.compile(r"[0-9a-fA-F]*") _digits_star_regex = re.compile(r"[0-9]*") @@ -570,9 +599,12 @@ def _lex_number(self, start_pos: int) -> Token: # Hexadecimal case, we only parse it if we see the first '0x' characters, # and then a first digit. # Otherwise, a string like '0xi32' would not be parsed correctly. - if (first_digit == '0' and self._peek_chars() == 'x' - and self._is_in_bounds(2) - and cast(str, self.input.at(self.pos + 1)) in hexdigits): + if ( + first_digit == "0" + and self._peek_chars() == "x" + and self._is_in_bounds(2) + and cast(str, self.input.at(self.pos + 1)) in hexdigits + ): self._consume_chars(2) self._consume_regex(self._hexdigits_star_regex) return self._form_token(Token.Kind.INTEGER_LIT, start_pos) diff --git a/xdsl/utils/test_value.py b/xdsl/utils/test_value.py index 3e121cd585..44aa4c9bf0 100644 --- a/xdsl/utils/test_value.py +++ b/xdsl/utils/test_value.py @@ -2,7 +2,6 @@ class TestSSAValue(SSAValue): - @property def owner(self) -> Operation | Block: - assert False, 'Attempting to get the owner of a `TestSSAValue`' + assert False, "Attempting to get the owner of a `TestSSAValue`" diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index 8a469e728a..6c4738dec8 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -30,9 +30,14 @@ from xdsl.frontend.passes.desymref import DesymrefyPass from xdsl.transforms.lower_mpi import LowerMPIPass -from xdsl.transforms.experimental.ConvertStencilToLLMLIR import ConvertStencilToGPUPass, ConvertStencilToLLMLIRPass +from xdsl.transforms.experimental.ConvertStencilToLLMLIR import ( + ConvertStencilToGPUPass, + ConvertStencilToLLMLIRPass, +) from xdsl.transforms.experimental.StencilShapeInference import StencilShapeInferencePass -from xdsl.transforms.experimental.stencil_global_to_local import GlobalStencilToLocalStencil2DHorizontal +from xdsl.transforms.experimental.stencil_global_to_local import ( + GlobalStencilToLocalStencil2DHorizontal, +) from xdsl.irdl_mlir_printer import IRDLPrinter from xdsl.utils.exceptions import DiagnosticException @@ -68,9 +73,11 @@ class xDSLOptMain: pipeline: List[ModulePass] """ The pass-pipeline to be applied. """ - def __init__(self, - description: str = 'xDSL modular optimizer driver', - args: Sequence[str] | None = None): + def __init__( + self, + description: str = "xDSL modular optimizer driver", + args: Sequence[str] | None = None, + ): self.available_frontends = {} self.available_passes = {} self.available_targets = {} @@ -119,19 +126,20 @@ def register_all_arguments(self, arg_parser: argparse.ArgumentParser): Add other/additional arguments by overloading this function. """ - arg_parser.add_argument("input_file", - type=str, - nargs="?", - help="path to input file") + arg_parser.add_argument( + "input_file", type=str, nargs="?", help="path to input file" + ) targets = [name for name in self.available_targets] - arg_parser.add_argument("-t", - "--target", - type=str, - required=False, - choices=targets, - help="target", - default="xdsl") + arg_parser.add_argument( + "-t", + "--target", + type=str, + required=False, + choices=targets, + help="target", + default="xdsl", + ) frontends = [name for name in self.available_frontends] arg_parser.add_argument( @@ -142,48 +150,53 @@ def register_all_arguments(self, arg_parser: argparse.ArgumentParser): choices=frontends, help="Frontend to be used for the input. If not set, " "the xdsl frontend or the one for the file extension " - "is used.") + "is used.", + ) - arg_parser.add_argument("--disable-verify", - default=False, - action='store_true') - arg_parser.add_argument("-o", - "--output-file", - type=str, - required=False, - help="path to output file") + arg_parser.add_argument("--disable-verify", default=False, action="store_true") + arg_parser.add_argument( + "-o", "--output-file", type=str, required=False, help="path to output file" + ) pass_names = ",".join([name for name in self.available_passes]) - arg_parser.add_argument("-p", - "--passes", - required=False, - help="Delimited list of passes." - f" Available passes are: {pass_names}", - type=str, - default="") - - arg_parser.add_argument("--print-between-passes", - default=False, - action='store_true', - help="Print the IR between each pass") - - arg_parser.add_argument("--verify-diagnostics", - default=False, - action='store_true', - help="Prints the content of a triggered " - "verifier exception and exits with code 0") - - arg_parser.add_argument("--parsing-diagnostics", - default=False, - action='store_true', - help="Prints the content of a triggered " - "parsing exception and exits with code 0") + arg_parser.add_argument( + "-p", + "--passes", + required=False, + help="Delimited list of passes." f" Available passes are: {pass_names}", + type=str, + default="", + ) + + arg_parser.add_argument( + "--print-between-passes", + default=False, + action="store_true", + help="Print the IR between each pass", + ) + + arg_parser.add_argument( + "--verify-diagnostics", + default=False, + action="store_true", + help="Prints the content of a triggered " + "verifier exception and exits with code 0", + ) + + arg_parser.add_argument( + "--parsing-diagnostics", + default=False, + action="store_true", + help="Prints the content of a triggered " + "parsing exception and exits with code 0", + ) arg_parser.add_argument( "--allow-unregistered-dialect", default=False, - action='store_true', - help="Allow the parsing of unregistered dialects.") + action="store_true", + help="Allow the parsing of unregistered dialects.", + ) def register_all_dialects(self): """ @@ -219,16 +232,22 @@ def register_all_frontends(self): def parse_xdsl(io: IO[str]): return XDSLParser( - self.ctx, io.read(), self.get_input_name(), - self.args.allow_unregistered_dialect).parse_module() + self.ctx, + io.read(), + self.get_input_name(), + self.args.allow_unregistered_dialect, + ).parse_module() def parse_mlir(io: IO[str]): return MLIRParser( - self.ctx, io.read(), self.get_input_name(), - self.args.allow_unregistered_dialect).parse_module() + self.ctx, + io.read(), + self.get_input_name(), + self.args.allow_unregistered_dialect, + ).parse_module() - self.available_frontends['xdsl'] = parse_xdsl - self.available_frontends['mlir'] = parse_mlir + self.available_frontends["xdsl"] = parse_xdsl + self.available_frontends["mlir"] = parse_mlir def register_pass(self, opPass: Type[ModulePass]): self.available_passes[opPass.name] = opPass @@ -267,9 +286,9 @@ def _output_irdl(prog: ModuleOp, output: IO[str]): irdl_to_mlir = IRDLPrinter(stream=output) irdl_to_mlir.print_module(prog) - self.available_targets['xdsl'] = _output_xdsl - self.available_targets['irdl'] = _output_irdl - self.available_targets['mlir'] = _output_mlir + self.available_targets["xdsl"] = _output_xdsl + self.available_targets["irdl"] = _output_irdl + self.available_targets["mlir"] = _output_mlir def setup_pipeline(self): """ @@ -277,9 +296,7 @@ def setup_pipeline(self): Failes, if not all passes are registered. """ - pipeline = [ - str(item) for item in self.args.passes.split(',') if len(item) > 0 - ] + pipeline = [str(item) for item in self.args.passes.split(",") if len(item) > 0] for p in pipeline: if p not in self.available_passes: @@ -295,7 +312,7 @@ def parse_input(self) -> ModuleOp: """ if self.args.input_file is None: f = sys.stdin - file_extension = 'xdsl' + file_extension = "xdsl" else: f = open(self.args.input_file) _, file_extension = os.path.splitext(self.args.input_file) @@ -345,8 +362,8 @@ def print_to_output_stream(self, contents: str): if self.args.output_file is None: print(contents) else: - with open(self.args.output_file, 'w') as output_stream: + with open(self.args.output_file, "w") as output_stream: output_stream.write(contents) def get_input_name(self): - return self.args.input_file or 'stdin' + return self.args.input_file or "stdin"