diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index f380926c4bce3fd..528a0d05b1011bd 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1344,6 +1344,22 @@ static LogicalResult verifyPrivateVarList(OpType &op) { } LogicalResult ParallelOp::verify() { + // Check that it is a valid loop wrapper if it's taking that role. + if (isa((*this)->getParentOp())) { + if (!isWrapper()) + return emitOpError() << "must take a loop wrapper role if nested inside " + "of 'omp.distribute'"; + + if (LoopWrapperInterface nested = getNestedWrapper()) { + // Check for the allowed leaf constructs that may appear in a composite + // construct directly after PARALLEL. + if (!isa(nested)) + return emitError() << "only supported nested wrapper is 'omp.wsloop'"; + } else { + return emitOpError() << "must not wrap an 'omp.loop_nest' directly"; + } + } + if (getAllocateVars().size() != getAllocatorsVars().size()) return emitError( "expected equal sizes for allocate and allocator variables"); diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 1f04f4570687093..2f24dce4233e48e 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -10,6 +10,58 @@ func.func @unknown_clause() { // ----- +func.func @not_wrapper() { + omp.distribute { + // expected-error@+1 {{op must take a loop wrapper role if nested inside of 'omp.distribute'}} + omp.parallel { + %0 = arith.constant 0 : i32 + omp.terminator + } + omp.terminator + } + + return +} + +// ----- + +func.func @invalid_nested_wrapper(%lb : index, %ub : index, %step : index) { + omp.distribute { + // expected-error@+1 {{only supported nested wrapper is 'omp.wsloop'}} + omp.parallel { + omp.simd { + omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) { + omp.yield + } + omp.terminator + } + omp.terminator + } + omp.terminator + } + + return +} + +// ----- + +func.func @no_nested_wrapper(%lb : index, %ub : index, %step : index) { + omp.distribute { + // expected-error@+1 {{op must not wrap an 'omp.loop_nest' directly}} + omp.parallel { + omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) { + omp.yield + } + omp.terminator + } + omp.terminator + } + + return +} + +// ----- + func.func @if_once(%n : i1) { // expected-error@+1 {{`if` clause can appear at most once in the expansion of the oilist directive}} omp.parallel if(%n : i1) if(%n : i1) { diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index e2ca12afc14bd64..c10fc88211c3671 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -51,7 +51,7 @@ func.func @omp_terminator() -> () { omp.terminator } -func.func @omp_parallel(%data_var : memref, %if_cond : i1, %num_threads : i32) -> () { +func.func @omp_parallel(%data_var : memref, %if_cond : i1, %num_threads : i32, %idx : index) -> () { // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref -> %{{.*}} : memref) "omp.parallel" (%if_cond, %num_threads, %data_var, %data_var) ({ @@ -85,6 +85,24 @@ func.func @omp_parallel(%data_var : memref, %if_cond : i1, %num_threads : i omp.terminator }) {operandSegmentSizes = array} : (memref, memref) -> () + // CHECK: omp.distribute + omp.distribute { + // CHECK-NEXT: omp.parallel + omp.parallel { + // CHECK-NEXT: omp.wsloop + // TODO Remove induction variables from omp.wsloop. + omp.wsloop for (%iv) : index = (%idx) to (%idx) step (%idx) { + // CHECK-NEXT: omp.loop_nest + omp.loop_nest (%iv2) : index = (%idx) to (%idx) step (%idx) { + omp.yield + } + omp.terminator + } + omp.terminator + } + omp.terminator + } + return }