Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][openacc] Fix unstructured code in OpenACC region ops #66284

Merged
merged 4 commits into from
Sep 14, 2023

Conversation

clementval
Copy link
Contributor

For unstructured construct, the blocks are created in advance inside the function body. This causes issues when the unstructured construct is inside an OpenACC region operations. This patch adds the same fix than OpenMP lowering and re-create the blocks inside the op region.

Initial OpenMP fix: 29f167a

@clementval clementval requested review from a team as code owners September 13, 2023 20:16
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp openacc labels Sep 13, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 13, 2023

@llvm/pr-subscribers-flang-openmp

Changes For unstructured construct, the blocks are created in advance inside the function body. This causes issues when the unstructured construct is inside an OpenACC region operations. This patch adds the same fix than OpenMP lowering and re-create the blocks inside the op region.

Initial OpenMP fix: 29f167a

Full diff: https://github.com/llvm/llvm-project/pull/66284.diff

4 Files Affected:

  • (modified) flang/lib/Lower/DirectivesCommon.h (+23-1)
  • (modified) flang/lib/Lower/OpenACC.cpp (+48-23)
  • (modified) flang/lib/Lower/OpenMP.cpp (+1-24)
  • (added) flang/test/Lower/OpenACC/acc-unstructured.f90 (+86)
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 35825a20b4cf93f..59d46008bcca9e3 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -587,7 +587,29 @@ void genOmpAccAtomicCapture(Fortran::lower::AbstractConverter &converter,
   firOpBuilder.setInsertionPointToStart(&block);
 }
 
+/// Create empty blocks for the current region.
+/// These blocks replace blocks parented to an enclosing region.
+template <typename... TerminatorOps>
+void createEmptyRegionBlocks(fir::FirOpBuilder &builder,
+    std::list<Fortran::lower::pft::Evaluation> &evaluationList) {
+  mlir::Region *region = &builder.getRegion();
+  for (Fortran::lower::pft::Evaluation &eval : evaluationList) {
+    if (eval.block) {
+      if (eval.block->empty()) {
+        eval.block->erase();
+        eval.block = builder.createBlock(region);
+      } else {
+        [[maybe_unused]] mlir::Operation &terminatorOp = eval.block->back();
+        assert(mlir::isa<TerminatorOps...>(terminatorOp) &&
+            "expected terminator op");
+      }
+    }
+    if (!eval.isDirective() && eval.hasNestedEvaluations())
+      createEmptyRegionBlocks<TerminatorOps...>(builder, eval.getNestedEvaluations());
+  }
+}
+
 } // namespace lower
 } // namespace Fortran
 
-#endif // FORTRAN_LOWER_DIRECTIVES_COMMON_H
\ No newline at end of file
+#endif // FORTRAN_LOWER_DIRECTIVES_COMMON_H
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 732765c4def59cb..e798876130c1d28 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1252,14 +1252,15 @@ static void addOperand(llvm::SmallVectorImpl<mlir::Value> &operands,
 template <typename Op, typename Terminator>
 static Op
 createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
+               Fortran::lower::pft::Evaluation &eval,
                const llvm::SmallVectorImpl<mlir::Value> &operands,
-               const llvm::SmallVectorImpl<int32_t> &operandSegments) {
+               const llvm::SmallVectorImpl<int32_t> &operandSegments,
+               bool outerCombined = false) {
   llvm::ArrayRef<mlir::Type> argTy;
   Op op = builder.create<Op>(loc, argTy, operands);
   builder.createBlock(&op.getRegion());
   mlir::Block &block = op.getRegion().back();
   builder.setInsertionPointToStart(&block);
-  builder.create<Terminator>(loc);
 
   op->setAttr(Op::getOperandSegmentSizeAttr(),
               builder.getDenseI32ArrayAttr(operandSegments));
@@ -1267,6 +1268,13 @@ createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
   // Place the insertion point to the start of the first block.
   builder.setInsertionPointToStart(&block);
 
+  // If it is an unstructured region and is not the outer region of a combined
+  // construct, create empty blocks for all evaluations.
+  if (eval.lowerAsUnstructured() && !outerCombined)
+    Fortran::lower::createEmptyRegionBlocks<mlir::acc::TerminatorOp, mlir::acc::YieldOp>(builder, eval.getNestedEvaluations());
+
+  builder.create<Terminator>(loc);
+  builder.setInsertionPointToStart(&block);
   return op;
 }
 
@@ -1347,6 +1355,7 @@ static void genWaitClause(Fortran::lower::AbstractConverter &converter,
 static mlir::acc::LoopOp
 createLoopOp(Fortran::lower::AbstractConverter &converter,
              mlir::Location currentLocation,
+             Fortran::lower::pft::Evaluation &eval,
              Fortran::semantics::SemanticsContext &semanticsContext,
              Fortran::lower::StatementContext &stmtCtx,
              const Fortran::parser::AccClauseList &accClauseList) {
@@ -1455,7 +1464,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   addOperands(operands, operandSegments, cacheOperands);
 
   auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
-      builder, currentLocation, operands, operandSegments);
+      builder, currentLocation, eval, operands, operandSegments);
 
   if (hasGang)
     loopOp.setHasGangAttr(builder.getUnitAttr());
@@ -1504,6 +1513,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
 
 static void genACC(Fortran::lower::AbstractConverter &converter,
                    Fortran::semantics::SemanticsContext &semanticsContext,
+                   Fortran::lower::pft::Evaluation &eval,
                    const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
 
   const auto &beginLoopDirective =
@@ -1518,7 +1528,7 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
   if (loopDirective.v == llvm::acc::ACCD_loop) {
     const auto &accClauseList =
         std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
-    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   }
 }
@@ -1551,9 +1561,11 @@ template <typename Op>
 static Op
 createComputeOp(Fortran::lower::AbstractConverter &converter,
                 mlir::Location currentLocation,
+                Fortran::lower::pft::Evaluation &eval,
                 Fortran::semantics::SemanticsContext &semanticsContext,
                 Fortran::lower::StatementContext &stmtCtx,
-                const Fortran::parser::AccClauseList &accClauseList) {
+                const Fortran::parser::AccClauseList &accClauseList,
+                bool outerCombined = false) {
 
   // Parallel operation operands
   mlir::Value async;
@@ -1769,10 +1781,10 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
   Op computeOp;
   if constexpr (std::is_same_v<Op, mlir::acc::KernelsOp>)
     computeOp = createRegionOp<Op, mlir::acc::TerminatorOp>(
-        builder, currentLocation, operands, operandSegments);
+        builder, currentLocation, eval, operands, operandSegments, outerCombined);
   else
     computeOp = createRegionOp<Op, mlir::acc::YieldOp>(
-        builder, currentLocation, operands, operandSegments);
+        builder, currentLocation, eval, operands, operandSegments, outerCombined);
 
   if (addAsyncAttr)
     computeOp.setAsyncAttrAttr(builder.getUnitAttr());
@@ -1817,6 +1829,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
 
 static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
                          mlir::Location currentLocation,
+                         Fortran::lower::pft::Evaluation &eval,
                          Fortran::semantics::SemanticsContext &semanticsContext,
                          Fortran::lower::StatementContext &stmtCtx,
                          const Fortran::parser::AccClauseList &accClauseList) {
@@ -1942,7 +1955,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
     return;
 
   auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
-      builder, currentLocation, operands, operandSegments);
+      builder, currentLocation, eval, operands, operandSegments);
 
   dataOp.setAsyncAttr(addAsyncAttr);
   dataOp.setWaitAttr(addWaitAttr);
@@ -1971,6 +1984,7 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
 static void
 genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
                  mlir::Location currentLocation,
+                 Fortran::lower::pft::Evaluation &eval,
                  Fortran::semantics::SemanticsContext &semanticsContext,
                  Fortran::lower::StatementContext &stmtCtx,
                  const Fortran::parser::AccClauseList &accClauseList) {
@@ -2020,7 +2034,7 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
 
   auto hostDataOp =
       createRegionOp<mlir::acc::HostDataOp, mlir::acc::TerminatorOp>(
-          builder, currentLocation, operands, operandSegments);
+          builder, currentLocation, eval, operands, operandSegments);
 
   if (addIfPresentAttr)
     hostDataOp.setIfPresentAttr(builder.getUnitAttr());
@@ -2029,6 +2043,7 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
 static void
 genACC(Fortran::lower::AbstractConverter &converter,
        Fortran::semantics::SemanticsContext &semanticsContext,
+       Fortran::lower::pft::Evaluation &eval,
        const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
   const auto &beginBlockDirective =
       std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
@@ -2042,18 +2057,18 @@ genACC(Fortran::lower::AbstractConverter &converter,
 
   if (blockDirective.v == llvm::acc::ACCD_parallel) {
     createComputeOp<mlir::acc::ParallelOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
+        converter, currentLocation, eval, semanticsContext, stmtCtx, accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_data) {
-    genACCDataOp(converter, currentLocation, semanticsContext, stmtCtx,
+    genACCDataOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_serial) {
     createComputeOp<mlir::acc::SerialOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
+        converter, currentLocation, eval, semanticsContext, stmtCtx, accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_kernels) {
     createComputeOp<mlir::acc::KernelsOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
+        converter, currentLocation, eval, semanticsContext, stmtCtx, accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_host_data) {
-    genACCHostDataOp(converter, currentLocation, semanticsContext, stmtCtx,
+    genACCHostDataOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                      accClauseList);
   }
 }
@@ -2061,6 +2076,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
 static void
 genACC(Fortran::lower::AbstractConverter &converter,
        Fortran::semantics::SemanticsContext &semanticsContext,
+       Fortran::lower::pft::Evaluation &eval,
        const Fortran::parser::OpenACCCombinedConstruct &combinedConstruct) {
   const auto &beginCombinedDirective =
       std::get<Fortran::parser::AccBeginCombinedDirective>(combinedConstruct.t);
@@ -2075,18 +2091,18 @@ genACC(Fortran::lower::AbstractConverter &converter,
 
   if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
     createComputeOp<mlir::acc::KernelsOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
-    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+        converter, currentLocation, eval, semanticsContext, stmtCtx, accClauseList, /*outerCombined=*/true);
+    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   } else if (combinedDirective.v == llvm::acc::ACCD_parallel_loop) {
     createComputeOp<mlir::acc::ParallelOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
-    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+        converter, currentLocation, eval, semanticsContext, stmtCtx, accClauseList, /*outerCombined=*/true);
+    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   } else if (combinedDirective.v == llvm::acc::ACCD_serial_loop) {
     createComputeOp<mlir::acc::SerialOp>(
-        converter, currentLocation, semanticsContext, stmtCtx, accClauseList);
-    createLoopOp(converter, currentLocation, semanticsContext, stmtCtx,
+        converter, currentLocation, eval, semanticsContext, stmtCtx, accClauseList, /*outerCombined=*/true);
+    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
                  accClauseList);
   } else {
     llvm::report_fatal_error("Unknown combined construct encountered");
@@ -3169,14 +3185,14 @@ void Fortran::lower::genOpenACCConstruct(
   std::visit(
       common::visitors{
           [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
-            genACC(converter, semanticsContext, blockConstruct);
+            genACC(converter, semanticsContext, eval, blockConstruct);
           },
           [&](const Fortran::parser::OpenACCCombinedConstruct
                   &combinedConstruct) {
-            genACC(converter, semanticsContext, combinedConstruct);
+            genACC(converter, semanticsContext, eval, combinedConstruct);
           },
           [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
-            genACC(converter, semanticsContext, loopConstruct);
+            genACC(converter, semanticsContext, eval, loopConstruct);
           },
           [&](const Fortran::parser::OpenACCStandaloneConstruct
                   &standaloneConstruct) {
@@ -3274,3 +3290,12 @@ void Fortran::lower::attachDeclarePostDeallocAction(
                  /*preAlloc=*/{}, /*postAlloc=*/{}, /*preDealloc=*/{},
                  /*postDealloc=*/builder.getSymbolRefAttr(fctName.str())));
 }
+
+void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
+                                          mlir::Operation *op,
+                                          mlir::Location loc) {
+  if (mlir::isa<mlir::acc::ParallelOp, mlir::acc::LoopOp>(op))
+    builder.create<mlir::acc::YieldOp>(loc);
+  else
+    builder.create<mlir::acc::TerminatorOp>(loc);
+}
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index b960bb369dd4dd2..9f239f2921c00f9 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1981,29 +1981,6 @@ static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
   return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
 }
 
-/// Create empty blocks for the current region.
-/// These blocks replace blocks parented to an enclosing region.
-static void createEmptyRegionBlocks(
-    fir::FirOpBuilder &firOpBuilder,
-    std::list<Fortran::lower::pft::Evaluation> &evaluationList) {
-  mlir::Region *region = &firOpBuilder.getRegion();
-  for (Fortran::lower::pft::Evaluation &eval : evaluationList) {
-    if (eval.block) {
-      if (eval.block->empty()) {
-        eval.block->erase();
-        eval.block = firOpBuilder.createBlock(region);
-      } else {
-        [[maybe_unused]] mlir::Operation &terminatorOp = eval.block->back();
-        assert((mlir::isa<mlir::omp::TerminatorOp>(terminatorOp) ||
-                mlir::isa<mlir::omp::YieldOp>(terminatorOp)) &&
-               "expected terminator op");
-      }
-    }
-    if (!eval.isDirective() && eval.hasNestedEvaluations())
-      createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations());
-  }
-}
-
 static void resetBeforeTerminator(fir::FirOpBuilder &firOpBuilder,
                                   mlir::Operation *storeOp,
                                   mlir::Block &block) {
@@ -2092,7 +2069,7 @@ static void createBodyOfOp(
   // If it is an unstructured region and is not the outer region of a combined
   // construct, create empty blocks for all evaluations.
   if (eval.lowerAsUnstructured() && !outerCombined)
-    createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations());
+    Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp, mlir::omp::YieldOp>(firOpBuilder, eval.getNestedEvaluations());
 
   // Insert the terminator.
   if constexpr (std::is_same_v<Op, mlir::omp::WsLoopOp> ||
diff --git a/flang/test/Lower/OpenACC/acc-unstructured.f90 b/flang/test/Lower/OpenACC/acc-unstructured.f90
new file mode 100644
index 000000000000000..4ce474241a4bec7
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-unstructured.f90
@@ -0,0 +1,86 @@
+! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine test_unstructured1(a, b, c)
+  integer :: i, j, k
+  real :: a(:,:,:), b(:,:,:), c(:,:,:)
+
+  !$acc data copy(a, b, c)
+
+  !$acc kernels
+  a(:,:,:) = 0.0
+  !$acc end kernels
+
+  !$acc kernels
+  do i = 1, 10
+    do j = 1, 10
+      do k = 1, 10
+      end do
+    end do
+  end do
+  !$acc end kernels
+
+  do i = 1, 10
+    do j = 1, 10
+      do k = 1, 10
+      end do
+    end do
+
+    if (a(1,2,3) > 10) stop 'just to be unstructured'
+  end do
+
+  !$acc end data
+
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_unstructured1
+! CHECK: acc.data
+! CHECK: acc.kernels
+! CHECK: acc.kernels
+! CHECK: fir.call @_FortranAStopStatementText
+
+
+subroutine test_unstructured2(a, b, c)
+  integer :: i, j, k
+  real :: a(:,:,:), b(:,:,:), c(:,:,:)
+
+  !$acc parallel loop
+  do i = 1, 10
+    do j = 1, 10
+      do k = 1, 10
+        if (a(1,2,3) > 10) stop 'just to be unstructured'
+      end do
+    end do
+  end do
+
+! CHECK-LABEL: func.func @_QPtest_unstructured2
+! CHECK: acc.parallel
+! CHECK: acc.loop
+! CHECK: fir.call @_FortranAStopStatementText
+! CHECK: acc.yield
+! CHECK: acc.yield
+! CHECK: acc.yield
+
+end subroutine
+
+subroutine test_unstructured3(a, b, c)
+  integer :: i, j, k
+  real :: a(:,:,:), b(:,:,:), c(:,:,:)
+
+  !$acc parallel
+  do i = 1, 10
+    do j = 1, 10
+      do k = 1, 10
+        if (a(1,2,3) > 10) stop 'just to be unstructured'
+      end do
+    end do
+  end do
+  !$acc end parallel
+
+! CHECK-LABEL: func.func @_QPtest_unstructured3
+! CHECK: acc.parallel
+! CHECK: fir.call @_FortranAStopStatementText
+! CHECK: acc.yield
+! CHECK: acc.yield
+
+end subroutine

@vdonaldson
Copy link
Contributor

I had forgotten that there were multiple issues in unstructured code that the OpenMP folks have addressed over time.

This looks good to me.

@clementval clementval merged commit b1341e2 into llvm:main Sep 14, 2023
@clementval clementval deleted the acc_unstructured branch September 14, 2023 03:49
kstoimenov pushed a commit to kstoimenov/llvm-project that referenced this pull request Sep 14, 2023
)

For unstructured construct, the blocks are created in advance inside the
function body. This causes issues when the unstructured construct is
inside an OpenACC region operations. This patch adds the same fix than
OpenMP lowering and re-create the blocks inside the op region.

Initial OpenMP fix: 29f167a
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
)

For unstructured construct, the blocks are created in advance inside the
function body. This causes issues when the unstructured construct is
inside an OpenACC region operations. This patch adds the same fix than
OpenMP lowering and re-create the blocks inside the op region.

Initial OpenMP fix: 29f167a
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants