Skip to content

Commit

Permalink
[SYCL] Handle KernelName templated using type with enum template argu…
Browse files Browse the repository at this point in the history
…ment (#1780)

Add support to handle enums when KernelNameType is templated using a type which is in turn templated using enum.

Signed-off-by: Elizabeth Andrews <[email protected]>
  • Loading branch information
elizabethandrews authored Jun 1, 2020
1 parent 12d14e8 commit f9226d2
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 17 deletions.
7 changes: 5 additions & 2 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ class SYCLIntegrationHeader {
};

public:
SYCLIntegrationHeader(DiagnosticsEngine &Diag, bool UnnamedLambdaSupport);
SYCLIntegrationHeader(DiagnosticsEngine &Diag, bool UnnamedLambdaSupport,
Sema &S);

/// Emits contents of the header into given stream.
void emit(raw_ostream &Out);
Expand Down Expand Up @@ -424,6 +425,8 @@ class SYCLIntegrationHeader {

/// Whether header is generated with unnamed lambda support
bool UnnamedLambdaSupport;

Sema &S;
};

/// Keeps track of expected type during expression parsing. The type is tied to
Expand Down Expand Up @@ -12584,7 +12587,7 @@ class Sema final {
SYCLIntegrationHeader &getSyclIntegrationHeader() {
if (SyclIntHeader == nullptr)
SyclIntHeader = std::make_unique<SYCLIntegrationHeader>(
getDiagnostics(), getLangOpts().SYCLUnnamedLambda);
getDiagnostics(), getLangOpts().SYCLUnnamedLambda, *this);
return *SyclIntHeader.get();
}

Expand Down
33 changes: 19 additions & 14 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1866,6 +1866,9 @@ static void printArguments(ASTContext &Ctx, raw_ostream &ArgOS,
ArrayRef<TemplateArgument> Args,
const PrintingPolicy &P);

static std::string getKernelNameTypeString(QualType T, ASTContext &Ctx,
const PrintingPolicy &TypePolicy);

static void printArgument(ASTContext &Ctx, raw_ostream &ArgOS,
TemplateArgument Arg, const PrintingPolicy &P) {
switch (Arg.getKind()) {
Expand All @@ -1891,8 +1894,7 @@ static void printArgument(ASTContext &Ctx, raw_ostream &ArgOS,
TypePolicy.SuppressTypedefs = true;
TypePolicy.SuppressTagKeyword = true;
QualType T = Arg.getAsType();
QualType FullyQualifiedType = TypeName::getFullyQualifiedType(T, Ctx, true);
ArgOS << FullyQualifiedType.getAsString(TypePolicy);
ArgOS << getKernelNameTypeString(T, Ctx, TypePolicy);
break;
}
default:
Expand Down Expand Up @@ -1925,36 +1927,36 @@ static void printTemplateArguments(ASTContext &Ctx, raw_ostream &ArgOS,
ArgOS << ">";
}

static std::string getKernelNameTypeString(QualType T) {
static std::string getKernelNameTypeString(QualType T, ASTContext &Ctx,
const PrintingPolicy &TypePolicy) {

QualType FullyQualifiedType = TypeName::getFullyQualifiedType(T, Ctx, true);

const CXXRecordDecl *RD = T->getAsCXXRecordDecl();

if (!RD)
return getCPPTypeString(T);
return eraseAnonNamespace(FullyQualifiedType.getAsString(TypePolicy));

// If kernel name type is a template specialization with enum type
// template parameters, enumerators in name type string should be
// replaced with their underlying value since the enum definition
// is not visible in integration header.
if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
LangOptions LO;
PrintingPolicy P(LO);
P.SuppressTypedefs = true;
SmallString<64> Buf;
llvm::raw_svector_ostream ArgOS(Buf);

// Print template class name
TSD->printQualifiedName(ArgOS, P, /*WithGlobalNsPrefix*/ true);
TSD->printQualifiedName(ArgOS, TypePolicy, /*WithGlobalNsPrefix*/ true);

// Print template arguments substituting enumerators
ASTContext &Ctx = RD->getASTContext();
const TemplateArgumentList &Args = TSD->getTemplateArgs();
printTemplateArguments(Ctx, ArgOS, Args.asArray(), P);
printTemplateArguments(Ctx, ArgOS, Args.asArray(), TypePolicy);

return eraseAnonNamespace(ArgOS.str().str());
}

return getCPPTypeString(T);
return eraseAnonNamespace(FullyQualifiedType.getAsString(TypePolicy));
}

void SYCLIntegrationHeader::emit(raw_ostream &O) {
Expand Down Expand Up @@ -2073,9 +2075,11 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
O << "', '" << c;
O << "'> {\n";
} else {

LangOptions LO;
PrintingPolicy P(LO);
P.SuppressTypedefs = true;
O << "template <> struct KernelInfo<"
<< getKernelNameTypeString(K.NameType) << "> {\n";
<< getKernelNameTypeString(K.NameType, S.getASTContext(), P) << "> {\n";
}
O << " DLL_LOCAL\n";
O << " static constexpr const char* getName() { return \"" << K.Name
Expand Down Expand Up @@ -2144,8 +2148,9 @@ void SYCLIntegrationHeader::addSpecConstant(StringRef IDName, QualType IDType) {
}

SYCLIntegrationHeader::SYCLIntegrationHeader(DiagnosticsEngine &_Diag,
bool _UnnamedLambdaSupport)
: Diag(_Diag), UnnamedLambdaSupport(_UnnamedLambdaSupport) {}
bool _UnnamedLambdaSupport,
Sema &_S)
: Diag(_Diag), UnnamedLambdaSupport(_UnnamedLambdaSupport), S(_S) {}

// -----------------------------------------------------------------------------
// Utility class methods
Expand Down
31 changes: 30 additions & 1 deletion clang/test/CodeGenSYCL/kernelname-enum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,21 @@ class dummy_functor_7 {
void operator()() {}
};

namespace type_argument_template_enum {
enum class E : int {
A,
B,
C
};
}

template <typename T>
class T1 {};
template <type_argument_template_enum::E EnumValue>
class T2 {};
template <typename EnumType>
class T3 {};

int main() {

dummy_functor_1<no_namespace_int::val_1> f1;
Expand Down Expand Up @@ -124,6 +139,14 @@ int main() {
cgh.single_task(f8);
});

q.submit([&](cl::sycl::handler &cgh) {
cgh.single_task<T1<T2<type_argument_template_enum::E::A>>>([=]() {});
});

q.submit([&](cl::sycl::handler &cgh) {
cgh.single_task<T1<T3<type_argument_template_enum::E>>>([=]() {});
});

return 0;
}

Expand All @@ -145,7 +168,11 @@ int main() {
// CHECK: enum unscoped_enum : int;
// CHECK: template <unscoped_enum EnumType> class dummy_functor_6;
// CHECK: template <typename EnumType> class dummy_functor_7;

// CHECK: namespace type_argument_template_enum {
// CHECK-NEXT: enum class E : int;
// CHECK-NEXT: }
// CHECK: template <type_argument_template_enum::E EnumValue> class T2;
// CHECK: template <typename T> class T1;
// CHECK: Specializations of KernelInfo for kernel function types:
// CHECK: template <> struct KernelInfo<::dummy_functor_1<(no_namespace_int)0>>
// CHECK: template <> struct KernelInfo<::dummy_functor_2<(no_namespace_short)1>>
Expand All @@ -155,3 +182,5 @@ int main() {
// CHECK: template <> struct KernelInfo<::dummy_functor_6<(unscoped_enum)0>>
// CHECK: template <> struct KernelInfo<::dummy_functor_7<::no_namespace_int>>
// CHECK: template <> struct KernelInfo<::dummy_functor_7<::internal::namespace_short>>
// CHECK: template <> struct KernelInfo<::T1<::T2<(type_argument_template_enum::E)0>>>
// CHECK: template <> struct KernelInfo<::T1<::T3<::type_argument_template_enum::E>>>

0 comments on commit f9226d2

Please sign in to comment.