Skip to content

Commit

Permalink
[MLIR][OpenMP] Verify loop wrapper properties of omp.parallel (#88722)
Browse files Browse the repository at this point in the history
This patch extends verification of the `omp.parallel` operation to check
it is correctly defined when taking a loop wrapper role.

In OpenMP, a PARALLEL construct can be either a (potenially combined)
block construct or a loop construct, when appearing as part of a
composite construct. This is currently the case for the DISTRIBUTE
PARALLEL DO/FOR and DISTRIBUTE PARALLEL DO/FOR SIMD exclusively.

When used to represent the PARALLEL leaf of a composite construct, it
must follow the rules of a wrapper loop operation in MLIR, and this is
what this patch ensures. No additional restrictions are introduced for
PARALLEL block constructs.
  • Loading branch information
skatrak authored Apr 19, 2024
1 parent 9dbf3e2 commit 5e5b8c4
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 1 deletion.
16 changes: 16 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,22 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
}

LogicalResult ParallelOp::verify() {
// Check that it is a valid loop wrapper if it's taking that role.
if (isa<DistributeOp>((*this)->getParentOp())) {
if (!isWrapper())
return emitOpError() << "must take a loop wrapper role if nested inside "
"of 'omp.distribute'";

if (LoopWrapperInterface nested = getNestedWrapper()) {
// Check for the allowed leaf constructs that may appear in a composite
// construct directly after PARALLEL.
if (!isa<WsloopOp>(nested))
return emitError() << "only supported nested wrapper is 'omp.wsloop'";
} else {
return emitOpError() << "must not wrap an 'omp.loop_nest' directly";
}
}

if (getAllocateVars().size() != getAllocatorsVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
Expand Down
52 changes: 52 additions & 0 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,58 @@ func.func @unknown_clause() {

// -----

func.func @not_wrapper() {
omp.distribute {
// expected-error@+1 {{op must take a loop wrapper role if nested inside of 'omp.distribute'}}
omp.parallel {
%0 = arith.constant 0 : i32
omp.terminator
}
omp.terminator
}

return
}

// -----

func.func @invalid_nested_wrapper(%lb : index, %ub : index, %step : index) {
omp.distribute {
// expected-error@+1 {{only supported nested wrapper is 'omp.wsloop'}}
omp.parallel {
omp.simd {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
}
omp.terminator
}
omp.terminator
}
omp.terminator
}

return
}

// -----

func.func @no_nested_wrapper(%lb : index, %ub : index, %step : index) {
omp.distribute {
// expected-error@+1 {{op must not wrap an 'omp.loop_nest' directly}}
omp.parallel {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
}
omp.terminator
}
omp.terminator
}

return
}

// -----

func.func @if_once(%n : i1) {
// expected-error@+1 {{`if` clause can appear at most once in the expansion of the oilist directive}}
omp.parallel if(%n : i1) if(%n : i1) {
Expand Down
20 changes: 19 additions & 1 deletion mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func.func @omp_terminator() -> () {
omp.terminator
}

func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32) -> () {
func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32, %idx : index) -> () {
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%if_cond, %num_threads, %data_var, %data_var) ({

Expand Down Expand Up @@ -85,6 +85,24 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
omp.terminator
}) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (memref<i32>, memref<i32>) -> ()

// CHECK: omp.distribute
omp.distribute {
// CHECK-NEXT: omp.parallel
omp.parallel {
// CHECK-NEXT: omp.wsloop
// TODO Remove induction variables from omp.wsloop.
omp.wsloop for (%iv) : index = (%idx) to (%idx) step (%idx) {
// CHECK-NEXT: omp.loop_nest
omp.loop_nest (%iv2) : index = (%idx) to (%idx) step (%idx) {
omp.yield
}
omp.terminator
}
omp.terminator
}
omp.terminator
}

return
}

Expand Down

0 comments on commit 5e5b8c4

Please sign in to comment.