Skip to content

Commit

Permalink
[Clang][Sema][OpenMP] Allow num_teams to accept multiple expressions (
Browse files Browse the repository at this point in the history
llvm#99732)

By the OpenMP standard, `num_teams` clause can only accept one
expression (for now). In this patch, we extend it to allow to accept
multiple expressions when it is used with `target teams ompx_bare`
construct. This will allow to launch a multi-dim grid, same as CUDA/HIP.
  • Loading branch information
shiltian authored Aug 6, 2024
1 parent f0178d8 commit cee594c
Show file tree
Hide file tree
Showing 20 changed files with 211 additions and 68 deletions.
2 changes: 2 additions & 0 deletions clang/docs/OpenMPSupport.rst
Original file line number Diff line number Diff line change
Expand Up @@ -363,5 +363,7 @@ considered for standardization. Please post on the
| device extension | `'ompx_bare' clause on 'target teams' construct | :good:`prototyped` | #66844, #70612 |
| | <https://www.osti.gov/servlets/purl/2205717>`_ | | |
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
| device extension | Multi-dim `'num_teams' clause on 'target teams ompx_bare' construct | :good:`partial` | #99732, #101407 |
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+

.. _Discourse forums (Runtimes - OpenMP category): https://discourse.llvm.org/c/runtimes/openmp/35
3 changes: 3 additions & 0 deletions clang/docs/ReleaseNotes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ Improvements
^^^^^^^^^^^^
- Improve the handling of mapping array-section for struct containing nested structs with user defined mappers

- `num_teams` now accepts multiple expressions when it is used along in ``target teams ompx_bare`` construct.
This allows the target region to be launched with multi-dim grid on GPUs.

Additional Information
======================

Expand Down
79 changes: 48 additions & 31 deletions clang/include/clang/AST/OpenMPClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -6369,60 +6369,77 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
/// \endcode
/// In this example directive '#pragma omp teams' has clause 'num_teams'
/// with single expression 'n'.
class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
friend class OMPClauseReader;
///
/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
/// can accept up to three expressions.
///
/// \code
/// #pragma omp target teams ompx_bare num_teams(x, y, z)
/// \endcode
class OMPNumTeamsClause final
: public OMPVarListClause<OMPNumTeamsClause>,
public OMPClauseWithPreInit,
private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
friend OMPVarListClause;
friend TrailingObjects;

/// Location of '('.
SourceLocation LParenLoc;

/// NumTeams number.
Stmt *NumTeams = nullptr;
OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
: OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
N),
OMPClauseWithPreInit(this) {}

/// Set the NumTeams number.
///
/// \param E NumTeams number.
void setNumTeams(Expr *E) { NumTeams = E; }
/// Build an empty clause.
OMPNumTeamsClause(unsigned N)
: OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
SourceLocation(), SourceLocation(), N),
OMPClauseWithPreInit(this) {}

public:
/// Build 'num_teams' clause.
/// Creates clause with a list of variables \a VL.
///
/// \param E Expression associated with this clause.
/// \param HelperE Helper Expression associated with this clause.
/// \param CaptureRegion Innermost OpenMP region where expressions in this
/// clause must be captured.
/// \param C AST context.
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param EndLoc Ending location of the clause.
OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation EndLoc)
: OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
setPreInitStmt(HelperE, CaptureRegion);
}
/// \param VL List of references to the variables.
/// \param PreInit
static OMPNumTeamsClause *
Create(const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation EndLoc, ArrayRef<Expr *> VL, Stmt *PreInit);

/// Build an empty clause.
OMPNumTeamsClause()
: OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
SourceLocation()),
OMPClauseWithPreInit(this) {}
/// Creates an empty clause with \a N variables.
///
/// \param C AST context.
/// \param N The number of variables.
static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);

/// Sets the location of '('.
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }

/// Returns the location of '('.
SourceLocation getLParenLoc() const { return LParenLoc; }

/// Return NumTeams number.
Expr *getNumTeams() { return cast<Expr>(NumTeams); }
/// Return NumTeams expressions.
ArrayRef<Expr *> getNumTeams() { return getVarRefs(); }

/// Return NumTeams number.
Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
/// Return NumTeams expressions.
ArrayRef<Expr *> getNumTeams() const {
return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
}

child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
child_range children() {
return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
reinterpret_cast<Stmt **>(varlist_end()));
}

const_child_range children() const {
return const_child_range(&NumTeams, &NumTeams + 1);
auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
return const_child_range(Children.begin(), Children.end());
}

child_range used_children() {
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3828,8 +3828,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
OMPNumTeamsClause *C) {
TRY_TO(VisitOMPClauseList(C));
TRY_TO(VisitOMPClauseWithPreInit(C));
TRY_TO(TraverseStmt(C->getNumTeams()));
return true;
}

Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -11639,6 +11639,8 @@ def warn_omp_unterminated_declare_target : Warning<
InGroup<SourceUsesOpenMP>;
def err_ompx_bare_no_grid : Error<
"'ompx_bare' clauses requires explicit grid size via 'num_teams' and 'thread_limit' clauses">;
def err_omp_multi_expr_not_allowed: Error<"only one expression allowed in '%0' clause">;
def err_ompx_more_than_three_expr_not_allowed: Error<"at most three expressions are allowed in '%0' clause in 'target teams ompx_bare' construct">;
} // end of OpenMP category

let CategoryName = "Related Result Type Issue" in {
Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/Sema/SemaOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,8 @@ class SemaOpenMP : public SemaBase {
const OMPVarListLocTy &Locs, bool NoDiagnose = false,
ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
/// Called on well-formed 'num_teams' clause.
OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'thread_limit' clause.
Expand Down
26 changes: 23 additions & 3 deletions clang/lib/AST/OpenMPClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1755,6 +1755,24 @@ OMPContainsClause *OMPContainsClause::CreateEmpty(const ASTContext &C,
return new (Mem) OMPContainsClause(K);
}

OMPNumTeamsClause *OMPNumTeamsClause::Create(
const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc,
ArrayRef<Expr *> VL, Stmt *PreInit) {
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
OMPNumTeamsClause *Clause =
new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
Clause->setVarRefs(VL);
Clause->setPreInitStmt(PreInit, CaptureRegion);
return Clause;
}

OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
unsigned N) {
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
return new (Mem) OMPNumTeamsClause(N);
}

//===----------------------------------------------------------------------===//
// OpenMP clauses printing methods
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2055,9 +2073,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
}

void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
OS << "num_teams(";
Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
OS << ")";
if (!Node->varlist_empty()) {
OS << "num_teams";
VisitOMPClauseList(Node, '(');
OS << ")";
}
}

void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
Expand Down
3 changes: 1 addition & 2 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -857,9 +857,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
VisitOMPClauseList(C);
}
void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
VisitOMPClauseList(C);
VistOMPClauseWithPreInit(C);
if (C->getNumTeams())
Profiler->VisitStmt(C->getNumTeams());
}
void OMPClauseProfiler::VisitOMPThreadLimitClause(
const OMPThreadLimitClause *C) {
Expand Down
7 changes: 4 additions & 3 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6036,8 +6036,9 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
dyn_cast_or_null<OMPExecutableDirective>(ChildStmt)) {
if (isOpenMPTeamsDirective(NestedDir->getDirectiveKind())) {
if (NestedDir->hasClausesOfKind<OMPNumTeamsClause>()) {
const Expr *NumTeams =
NestedDir->getSingleClause<OMPNumTeamsClause>()->getNumTeams();
const Expr *NumTeams = NestedDir->getSingleClause<OMPNumTeamsClause>()
->getNumTeams()
.front();
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
if (auto Constant =
NumTeams->getIntegerConstantExpr(CGF.getContext()))
Expand All @@ -6062,7 +6063,7 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
case OMPD_target_teams_distribute_parallel_for_simd: {
if (D.hasClausesOfKind<OMPNumTeamsClause>()) {
const Expr *NumTeams =
D.getSingleClause<OMPNumTeamsClause>()->getNumTeams();
D.getSingleClause<OMPNumTeamsClause>()->getNumTeams().front();
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
if (auto Constant = NumTeams->getIntegerConstantExpr(CGF.getContext()))
MinTeamsVal = MaxTeamsVal = Constant->getExtValue();
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CodeGen/CGStmtOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6859,7 +6859,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
const auto *NT = S.getSingleClause<OMPNumTeamsClause>();
const auto *TL = S.getSingleClause<OMPThreadLimitClause>();
if (NT || TL) {
const Expr *NumTeams = NT ? NT->getNumTeams() : nullptr;
const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr;
const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr;

CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit,
Expand Down
8 changes: 7 additions & 1 deletion clang/lib/Parse/ParseOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3175,7 +3175,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_simdlen:
case OMPC_collapse:
case OMPC_ordered:
case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
Expand Down Expand Up @@ -3332,6 +3331,13 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
? ParseOpenMPSimpleClause(CKind, WrongDirective)
: ParseOpenMPClause(CKind, WrongDirective);
break;
case OMPC_num_teams:
if (!FirstClause) {
Diag(Tok, diag::err_omp_more_one_clause)
<< getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
ErrorFound = true;
}
[[clang::fallthrough]];
case OMPC_private:
case OMPC_firstprivate:
case OMPC_lastprivate:
Expand Down
Loading

0 comments on commit cee594c

Please sign in to comment.