Skip to content

Commit

Permalink
[SYCL] fix MarkFunction ASTConsumer issue with delayed instantiations
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Lazarev <[email protected]>
Signed-off-by: Blower, Melanie <[email protected]>
  • Loading branch information
Blower, Melanie authored and vladimirlaz committed Mar 22, 2019
1 parent e878f1d commit 971fecd
Show file tree
Hide file tree
Showing 10 changed files with 209 additions and 110 deletions.
1 change: 1 addition & 0 deletions clang/include/clang/Analysis/CallGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class CallGraph : public RecursiveASTVisitor<CallGraph> {

bool shouldWalkTypesOfTypeLocs() const { return false; }
bool shouldVisitTemplateInstantiations() const { return true; }
bool shouldVisitImplicitCode() const { return true; }

private:
/// Add the given declaration to the call graph.
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -11146,6 +11146,7 @@ class Sema {
}

void ConstructSYCLKernel(FunctionDecl *KernelCallerFunc);
void MarkDevice(void);
};

/// RAII object that enters a new expression evaluation context.
Expand Down
24 changes: 24 additions & 0 deletions clang/lib/Analysis/CallGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,30 @@ class CGBuilder : public StmtVisitor<CGBuilder> {
VisitChildren(CE);
}

void VisitLambdaExpr(LambdaExpr *LE) {
if (CXXMethodDecl *MD = LE->getCallOperator())
G->VisitFunctionDecl(MD);
}

void VisitCXXNewExpr(CXXNewExpr *E) {
if (FunctionDecl *FD = E->getOperatorNew())
addCalledDecl(FD);
VisitChildren(E);
}

void VisitCXXConstructExpr(CXXConstructExpr *E) {
CXXConstructorDecl *Ctor = E->getConstructor();
if (FunctionDecl *Def = Ctor->getDefinition())
addCalledDecl(Def);
const auto *ConstructedType = Ctor->getParent();
if (ConstructedType->hasUserDeclaredDestructor()) {
CXXDestructorDecl *Dtor = ConstructedType->getDestructor();
if (FunctionDecl *Def = Dtor->getDefinition())
addCalledDecl(Def);
}
VisitChildren(E);
}

// Adds may-call edges for the ObjC message sends.
void VisitObjCMessageExpr(ObjCMessageExpr *ME) {
if (ObjCInterfaceDecl *IDecl = ME->getReceiverInterface()) {
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Sema/Sema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,9 @@ void Sema::ActOnEndOfTranslationUnit() {
if (getLangOpts().SYCLIsDevice && SyclIntHeader != nullptr) {
SyclIntHeader->emit(getLangOpts().SYCLIntHeader);
}
if (getLangOpts().SYCLIsDevice)
MarkDevice();


assert(LateParsedInstantiations.empty() &&
"end of TU template instantiation should not create more "
Expand Down
67 changes: 46 additions & 21 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,13 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
// instantiation as template functions. It means that
// all functions used by kernel have already been parsed and have
// definitions.
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet;
if (IsRecursive(Callee, Callee, VisitedSet))
if (RecursiveSet.count(Callee)) {
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict) <<
KernelCallRecursiveFunction;
SemaRef.Diag(Callee->getSourceRange().getBegin(),
diag::note_sycl_recursive_function_declared_here)
<< KernelCallRecursiveFunction;
}

if (const CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(Callee))
if (Method->isVirtual())
Expand All @@ -109,7 +112,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
if (FunctionDecl *Def = Callee->getDefinition()) {
if (!Def->hasAttr<SYCLDeviceAttr>()) {
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
this->TraverseStmt(Def->getBody());
SemaRef.AddSyclKernel(Def);
}
}
Expand All @@ -127,7 +129,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {

if (FunctionDecl *Def = Ctor->getDefinition()) {
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
this->TraverseStmt(Def->getBody());
SemaRef.AddSyclKernel(Def);
}

Expand All @@ -137,7 +138,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {

if (FunctionDecl *Def = Dtor->getDefinition()) {
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
this->TraverseStmt(Def->getBody());
SemaRef.AddSyclKernel(Def);
}
}
Expand Down Expand Up @@ -211,7 +211,6 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
} else if (FunctionDecl *Def = FD->getDefinition()) {
if (!Def->hasAttr<SYCLDeviceAttr>()) {
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
this->TraverseStmt(Def->getBody());
SemaRef.AddSyclKernel(Def);
}
}
Expand Down Expand Up @@ -257,33 +256,42 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {

// The call graph for this translation unit.
CallGraph SYCLCG;
private:
// The set of functions called by a kernel function.
llvm::SmallPtrSet<FunctionDecl *, 10> KernelSet;
// The set of recursive functions identified while building the
// kernel set, this is used for error diagnostics.
llvm::SmallPtrSet<FunctionDecl *, 10> RecursiveSet;
// Determines whether the function FD is recursive.
// CalleeNode is a function which is called either directly
// or indirectly from FD. If recursion is detected then create
// diagnostic notes on each function as the callstack is unwound.
bool IsRecursive(FunctionDecl *CalleeNode, FunctionDecl *FD,
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet) {
void CollectKernelSet(FunctionDecl *CalleeNode, FunctionDecl *FD,
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet) {
// We're currently checking CalleeNode on a different
// trace through the CallGraph, we avoid infinite recursion
// by using VisitedSet to keep track of this.
if (!VisitedSet.insert(CalleeNode).second)
return false;
// by using KernelSet to keep track of this.
if (!KernelSet.insert(CalleeNode).second)
// Previously seen, stop recursion.
return;
if (CallGraphNode *N = SYCLCG.getNode(CalleeNode)) {
for (const CallGraphNode *CI : *N) {
if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(CI->getDecl())) {
Callee = Callee->getCanonicalDecl();
if (Callee == FD)
return SemaRef.Diag(FD->getSourceRange().getBegin(),
diag::note_sycl_recursive_function_declared_here)
<< KernelCallRecursiveFunction;
else if (IsRecursive(Callee, FD, VisitedSet))
return true;
if (VisitedSet.count(Callee)) {
// There's a stack frame to visit this Callee above
// this invocation. Do not recurse here.
RecursiveSet.insert(Callee);
RecursiveSet.insert(CalleeNode);
} else {
VisitedSet.insert(Callee);
CollectKernelSet(Callee, FD, VisitedSet);
VisitedSet.erase(Callee);
}
}
}
}
return false;
}
private:

bool CheckSYCLType(QualType Ty, SourceRange Loc) {
if (Ty->isVariableArrayType()) {
Expand Down Expand Up @@ -770,13 +778,30 @@ void Sema::ConstructSYCLKernel(FunctionDecl *KernelCallerFunc) {
CreateSYCLKernelBody(*this, KernelCallerFunc, SYCLKernel);
SYCLKernel->setBody(SYCLKernelBody);
AddSyclKernel(SYCLKernel);
}

void Sema::MarkDevice(void) {
// Let's mark all called functions with SYCL Device attribute.
MarkDeviceFunction Marker(*this);
// Create the call graph so we can detect recursion and check the validity
// of new operator overrides. Add the kernel function itself in case
// it is recursive.
MarkDeviceFunction Marker(*this);
Marker.SYCLCG.addToCallGraph(getASTContext().getTranslationUnitDecl());
Marker.TraverseStmt(SYCLKernelBody);
for (Decl *D : SyclKernels()) {
if (auto SYCLKernel = dyn_cast<FunctionDecl>(D)) {
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet;
Marker.CollectKernelSet(SYCLKernel, SYCLKernel, VisitedSet);
}
}
for (const auto &elt : Marker.KernelSet) {
if (FunctionDecl *Def = elt->getDefinition()) {
if (!Def->hasAttr<SYCLDeviceAttr>()) {
Def->addAttr(SYCLDeviceAttr::CreateImplicit(Context));
AddSyclKernel(Def);
}
Marker.TraverseStmt(Def->getBody());
}
}
}

// -----------------------------------------------------------------------------
Expand Down
Loading

0 comments on commit 971fecd

Please sign in to comment.