Skip to content

Commit

Permalink
SWDEV-197801 - Fix device lambda compilation on Windows.
Browse files Browse the repository at this point in the history
[HIP] Enhance lambda support on MSVC platform.

- MSVC uses a different C++ ABI. Its different mangle numbering causes
  lambdas to be identified differently than under the Itanium C++ ABI. As a
  result, the device kernel name mangled during host compilation differs
  from the one produced by device compilation.
- This patch fixes the aforementioned issue by
  + Add device mangling number for each lambda.
  + Pair Itanium C++ mangle numbering context with Microsoft C++ mangle
    numbering context to assign correct device-side lambda numbers.
  + During mangling, use device lambda number if the mangle context is
    a device context.
  + Revise the test with MSVC support.

Change-Id: Id1ab307cdad010f9bec38cdd779c9fbb042158e2
  • Loading branch information
mhbliao committed Aug 23, 2019
1 parent 661cb83 commit a03e606
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 4 deletions.
13 changes: 13 additions & 0 deletions include/clang/AST/DeclCXX.h
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,9 @@ class CXXRecordDecl : public RecordDecl {
/// mangling in the Itanium C++ ABI.
unsigned ManglingNumber = 0;

/// The device side name mangling number.
unsigned DeviceManglingNumber = 0;

/// The mangling number is enforced to ensure ODR naming.
// FIXME: Save bit from `NumCaptures` to minimize `LambdaDefinitionData`.
bool ForcedNumbering = false;
Expand Down Expand Up @@ -1956,6 +1959,16 @@ class CXXRecordDecl : public RecordDecl {
getLambdaData().ContextDecl = ContextDecl;
}

/// Set the device side mangling number.
void setDeviceLambdaManglingNumber(unsigned Num) {
getLambdaData().DeviceManglingNumber = Num;
}

/// Retrieve the device side mangling number stored in this record's lambda
/// data. Asserts that this record is a lambda closure type.
unsigned getDeviceLambdaManglingNumber() const {
assert(isLambda() && "Not a lambda closure type!");
return getLambdaData().DeviceManglingNumber;
}

/// Returns the inheritance model used for this record.
MSInheritanceAttr::Spelling getMSInheritanceModel() const;

Expand Down
3 changes: 3 additions & 0 deletions include/clang/AST/Mangle.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ class MangleContext {
virtual bool shouldMangleCXXName(const NamedDecl *D) = 0;
virtual bool shouldMangleStringLiteral(const StringLiteral *SL) = 0;

/// Whether this context mangles names for the device side. The base
/// implementation always answers false; subclasses may override.
virtual bool isDeviceMangleContext() const { return false; }
/// Mark (or unmark) this context as a device mangle context. The base
/// implementation ignores the request.
virtual void setDeviceMangleContext(bool) {}

// FIXME: consider replacing raw_ostream & with something like SmallString &.
void mangleName(const NamedDecl *D, raw_ostream &);
virtual void mangleCXXName(const NamedDecl *D, raw_ostream &) = 0;
Expand Down
10 changes: 10 additions & 0 deletions include/clang/AST/MangleNumberingContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "clang/Basic/LLVM.h"
#include "llvm/ADT/IntrusiveRefCntPtr.h"
#include "llvm/Support/ErrorHandling.h"

namespace clang {

Expand Down Expand Up @@ -52,6 +53,15 @@ class MangleNumberingContext {
/// this context.
virtual unsigned getManglingNumber(const TagDecl *TD,
unsigned MSLocalManglingNumber) = 0;

/// Whether this numbering context can also hand out device-side mangling
/// numbers through getDeviceManglingNumber(). Defaults to false.
virtual bool hasDeviceMangleNumberingContext() { return false; }

/// Retrieve the mangling number of a new lambda expression with the
/// given call operator within the device context.
///
/// The default implementation is unreachable; it is meant to be overridden
/// by contexts for which hasDeviceMangleNumberingContext() returns true.
virtual unsigned getDeviceManglingNumber(const CXXMethodDecl *) {
llvm_unreachable("There's no device context associated!");
}
};

} // end namespace clang
Expand Down
10 changes: 9 additions & 1 deletion lib/AST/ItaniumMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ class ItaniumMangleContextImpl : public ItaniumMangleContext {
llvm::DenseMap<DiscriminatorKeyTy, unsigned> Discriminator;
llvm::DenseMap<const NamedDecl*, unsigned> Uniquifier;

bool IsDevCtx = false;

public:
explicit ItaniumMangleContextImpl(ASTContext &Context,
DiagnosticsEngine &Diags)
Expand All @@ -134,6 +136,10 @@ class ItaniumMangleContextImpl : public ItaniumMangleContext {
bool shouldMangleStringLiteral(const StringLiteral *) override {
return false;
}

/// Report the device-context flag stored in IsDevCtx.
bool isDeviceMangleContext() const override { return IsDevCtx; }
/// Store the device-context flag in IsDevCtx.
void setDeviceMangleContext(bool IsDev) override { IsDevCtx = IsDev;}

void mangleCXXName(const NamedDecl *D, raw_ostream &) override;
void mangleThunk(const CXXMethodDecl *MD, const ThunkInfo &Thunk,
raw_ostream &) override;
Expand Down Expand Up @@ -1739,7 +1745,9 @@ void CXXNameMangler::mangleLambda(const CXXRecordDecl *Lambda) {
// (in lexical order) with that same <lambda-sig> and context.
//
// The AST keeps track of the number for us.
unsigned Number = Lambda->getLambdaManglingNumber();
unsigned Number = Context.isDeviceMangleContext()
? Lambda->getDeviceLambdaManglingNumber()
: Lambda->getLambdaManglingNumber();
assert(Number > 0 && "Lambda should be mangled as an unnamed class");
if (Number > 1)
mangleNumber(Number - 2);
Expand Down
169 changes: 169 additions & 0 deletions lib/AST/MicrosoftCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,138 @@

using namespace clang;

// Before revising the interface, clone of `ItaniumNumberingContext` from
// `lib/AST/ItaniumCXXABI.cpp`.
// {{{ BEGIN CLONE
namespace {

/// Itanium C++ ABI 5.1.2 names an anonymous union after the first named data
/// member reached by a pre-order, depth-first, declaration-order walk of the
/// union's data members. When every data member is unnamed the program has no
/// way to refer to the union, so there is nothing to mangle.
///
/// Returns the identifier of that first named data member for the anonymous
/// union variable \p VD, or nullptr if no named data member exists.
static const IdentifierInfo *findAnonymousUnionVarDeclName(const VarDecl& VD) {
  const RecordType *UnionTy = VD.getType()->getAs<RecordType>();
  assert(UnionTy && "type of VarDecl is expected to be RecordType.");
  assert(UnionTy->getDecl()->isUnion() && "RecordType is expected to be a union.");

  const FieldDecl *FirstNamed = UnionTy->getDecl()->findFirstNamedDataMember();
  return FirstNamed ? FirstNamed->getIdentifier() : nullptr;
}

/// The name of a decomposition declaration.
///
/// The name is the sequence of identifiers of the declaration's bindings;
/// begin()/end() iterate over those identifiers rather than over the
/// BindingDecls themselves.
struct DecompositionDeclName {
using BindingArray = ArrayRef<const BindingDecl*>;

/// Representative example of a set of bindings with these names.
BindingArray Bindings;

/// Iterators over the sequence of identifiers in the name.
///
/// Wraps a BindingArray iterator; dereferencing yields the IdentifierInfo of
/// the binding the wrapped iterator points at.
struct Iterator
: llvm::iterator_adaptor_base<Iterator, BindingArray::const_iterator,
std::random_access_iterator_tag,
const IdentifierInfo *> {
Iterator(BindingArray::const_iterator It) : iterator_adaptor_base(It) {}
const IdentifierInfo *operator*() const {
return (*this->I)->getIdentifier();
}
};
/// Iterator positioned at the identifier of the first binding.
Iterator begin() const { return Iterator(Bindings.begin()); }
/// Past-the-end iterator for the identifier sequence.
Iterator end() const { return Iterator(Bindings.end()); }
};
}

namespace llvm {
/// DenseMapInfo specialization that lets DecompositionDeclName be used as a
/// DenseMap key. Sentinel keys are borrowed from the ArrayRef specialization;
/// hashing and equality are based on the binding identifiers.
template<>
struct DenseMapInfo<DecompositionDeclName> {
using ArrayInfo = llvm::DenseMapInfo<ArrayRef<const BindingDecl*>>;
using IdentInfo = llvm::DenseMapInfo<const IdentifierInfo*>;
/// Empty key: wraps the ArrayRef empty key.
static DecompositionDeclName getEmptyKey() {
return {ArrayInfo::getEmptyKey()};
}
/// Tombstone key: wraps the ArrayRef tombstone key.
static DecompositionDeclName getTombstoneKey() {
return {ArrayInfo::getTombstoneKey()};
}
/// Hash the sequence of binding identifiers. Must not be called on a
/// sentinel key (asserted).
static unsigned getHashValue(DecompositionDeclName Key) {
assert(!isEqual(Key, getEmptyKey()) && !isEqual(Key, getTombstoneKey()));
return llvm::hash_combine_range(Key.begin(), Key.end());
}
/// Two names are equal when they have the same number of bindings and the
/// same identifiers in the same order.
static bool isEqual(DecompositionDeclName LHS, DecompositionDeclName RHS) {
// If LHS is a sentinel key, equality is decided solely by whether RHS is
// the same sentinel; no binding is dereferenced on these paths.
if (ArrayInfo::isEqual(LHS.Bindings, ArrayInfo::getEmptyKey()))
return ArrayInfo::isEqual(RHS.Bindings, ArrayInfo::getEmptyKey());
if (ArrayInfo::isEqual(LHS.Bindings, ArrayInfo::getTombstoneKey()))
return ArrayInfo::isEqual(RHS.Bindings, ArrayInfo::getTombstoneKey());
// Otherwise compare lengths first, then identifiers element by element.
return LHS.Bindings.size() == RHS.Bindings.size() &&
std::equal(LHS.begin(), LHS.end(), RHS.begin());
}
};
}

namespace {

/// Keeps track of the mangled names of lambda expressions and block
/// literals within a particular context.
class ItaniumNumberingContext : public MangleNumberingContext {
llvm::DenseMap<const Type *, unsigned> ManglingNumbers;
llvm::DenseMap<const IdentifierInfo *, unsigned> VarManglingNumbers;
llvm::DenseMap<const IdentifierInfo *, unsigned> TagManglingNumbers;
llvm::DenseMap<DecompositionDeclName, unsigned>
DecompsitionDeclManglingNumbers;

public:
unsigned getManglingNumber(const CXXMethodDecl *CallOperator) override {
const FunctionProtoType *Proto =
CallOperator->getType()->getAs<FunctionProtoType>();
ASTContext &Context = CallOperator->getASTContext();

FunctionProtoType::ExtProtoInfo EPI;
EPI.Variadic = Proto->isVariadic();
QualType Key =
Context.getFunctionType(Context.VoidTy, Proto->getParamTypes(), EPI);
Key = Context.getCanonicalType(Key);
return ++ManglingNumbers[Key->castAs<FunctionProtoType>()];
}

unsigned getManglingNumber(const BlockDecl *BD) override {
const Type *Ty = nullptr;
return ++ManglingNumbers[Ty];
}

unsigned getStaticLocalNumber(const VarDecl *VD) override {
return 0;
}

/// Variable decls are numbered by identifier.
unsigned getManglingNumber(const VarDecl *VD, unsigned) override {
if (auto *DD = dyn_cast<DecompositionDecl>(VD)) {
DecompositionDeclName Name{DD->bindings()};
return ++DecompsitionDeclManglingNumbers[Name];
}

const IdentifierInfo *Identifier = VD->getIdentifier();
if (!Identifier) {
// VarDecl without an identifier represents an anonymous union
// declaration.
Identifier = findAnonymousUnionVarDeclName(*VD);
}
return ++VarManglingNumbers[Identifier];
}

unsigned getManglingNumber(const TagDecl *TD, unsigned) override {
return ++TagManglingNumbers[TD->getIdentifier()];
}
};

} // End anonymous namespace
// END CLONE }}}

namespace {

/// Numbers things which need to correspond across multiple TUs.
Expand Down Expand Up @@ -63,6 +195,41 @@ class MicrosoftNumberingContext : public MangleNumberingContext {
}
};

/// Numbering context that pairs a Microsoft numbering context (used for all
/// regular, host-side queries) with an Itanium numbering context (used only
/// for device-side lambda numbers).
class MSHIPNumberingContext : public MangleNumberingContext {
// Answers every host-side numbering query.
MicrosoftNumberingContext HostCtx;
// Answers getDeviceManglingNumber() only.
ItaniumNumberingContext DeviceCtx;

public:

/// Host-side lambda number; forwarded to the Microsoft context.
unsigned getManglingNumber(const CXXMethodDecl *CallOperator) override {
return HostCtx.getManglingNumber(CallOperator);
}

/// Host-side block literal number; forwarded to the Microsoft context.
unsigned getManglingNumber(const BlockDecl *BD) override {
return HostCtx.getManglingNumber(BD);
}

/// Static local number; forwarded to the Microsoft context.
unsigned getStaticLocalNumber(const VarDecl *VD) override {
return HostCtx.getStaticLocalNumber(VD);
}

/// Host-side variable number; forwarded to the Microsoft context.
unsigned getManglingNumber(const VarDecl *VD,
unsigned MSLocalManglingNumber) override {
return HostCtx.getManglingNumber(VD, MSLocalManglingNumber);
}

/// Host-side tag number; forwarded to the Microsoft context.
unsigned getManglingNumber(const TagDecl *TD,
unsigned MSLocalManglingNumber) override {
return HostCtx.getManglingNumber(TD, MSLocalManglingNumber);
}

/// This context always carries a device numbering context.
bool hasDeviceMangleNumberingContext() override { return true; }

/// Device-side lambda number; obtained from the Itanium context's lambda
/// numbering for the given call operator.
unsigned getDeviceManglingNumber(const CXXMethodDecl *CallOperator) override {
return DeviceCtx.getManglingNumber(CallOperator);
}
};

class MicrosoftCXXABI : public CXXABI {
ASTContext &Context;
llvm::SmallDenseMap<CXXRecordDecl *, CXXConstructorDecl *> RecordToCopyCtor;
Expand Down Expand Up @@ -132,6 +299,8 @@ class MicrosoftCXXABI : public CXXABI {

/// Create the numbering context for this ABI. When the CUDA language option
/// is set, the paired host/device context is returned so that device-side
/// lambda numbers are tracked as well; otherwise the plain Microsoft
/// numbering context is used.
std::unique_ptr<MangleNumberingContext>
createMangleNumberingContext() const override {
  const bool NeedsDeviceNumbers = Context.getLangOpts().CUDA;
  if (!NeedsDeviceNumbers)
    return llvm::make_unique<MicrosoftNumberingContext>();
  return llvm::make_unique<MSHIPNumberingContext>();
}
};
Expand Down
4 changes: 4 additions & 0 deletions lib/CodeGen/CGCUDANV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ CGNVCUDARuntime::CGNVCUDARuntime(CodeGenModule &CGM)
CharPtrTy = llvm::PointerType::getUnqual(Types.ConvertType(Ctx.CharTy));
VoidPtrTy = cast<llvm::PointerType>(Types.ConvertType(Ctx.VoidPtrTy));
VoidPtrPtrTy = VoidPtrTy->getPointerTo();

DeviceMC->setDeviceMangleContext(
CGM.getContext().getTargetInfo().getCXXABI().isMicrosoft() &&
CGM.getContext().getAuxTargetInfo()->getCXXABI().isItaniumFamily());
}

llvm::FunctionCallee CGNVCUDARuntime::getSetupArgumentFn() const {
Expand Down
4 changes: 4 additions & 0 deletions lib/Sema/SemaLambda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,10 @@ CXXMethodDecl *Sema::startLambdaDefinition(
getLangOpts().CUDAForceLambdaODR, &Forced)) {
unsigned ManglingNumber = MCtx->getManglingNumber(Method);
Class->setLambdaMangling(ManglingNumber, ManglingContextDecl, Forced);
if (MCtx->hasDeviceMangleNumberingContext()) {
Class->setDeviceLambdaManglingNumber(
MCtx->getDeviceManglingNumber(Method));
}
}
}

Expand Down
24 changes: 21 additions & 3 deletions test/CodeGenCUDA/unnamed-types.cu
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
// RUN: %clang_cc1 -std=c++11 -x hip -triple x86_64-linux-gnu -fcuda-force-lambda-odr -emit-llvm %s -o - | FileCheck %s --check-prefix=HOST
// RUN: %clang_cc1 -std=c++11 -x hip -triple x86_64-linux-gnu -aux-triple amdgcn-amd-amdhsa -fcuda-force-lambda-odr -emit-llvm %s -o - | FileCheck %s --check-prefix=HOST
// RUN: %clang_cc1 -std=c++11 -x hip -triple x86_64-pc-windows-msvc -aux-triple amdgcn-amd-amdhsa -fcuda-force-lambda-odr -emit-llvm %s -o - | FileCheck %s --check-prefix=MSVC
// RUN: %clang_cc1 -std=c++11 -x hip -triple amdgcn-amd-amdhsa -fcuda-force-lambda-odr -fcuda-is-device -emit-llvm %s -o - | FileCheck %s --check-prefix=DEVICE

#include "Inputs/cuda.h"

// HOST: @0 = private unnamed_addr constant [43 x i8] c"_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_\00", align 1
// HOST: @1 = private unnamed_addr constant [60 x i8] c"_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_\00", align 1
// Check that, on MSVC, the same device kernel mangling name is generated.
// MSVC: @0 = private unnamed_addr constant [43 x i8] c"_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_\00", align 1
// MSVC: @1 = private unnamed_addr constant [60 x i8] c"_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_\00", align 1

__device__ float d0(float x) {
return [](float x) { return x + 2.f; }(x);
return [](float x) { return x + 1.f; }(x);
}

__device__ float d1(float x) {
Expand All @@ -19,6 +24,12 @@ __global__ void k0(float *p, F f) {
p[0] = f(p[0]) + d0(p[1]) + d1(p[2]);
}

// DEVICE: amdgpu_kernel void @_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_(
template <typename F0, typename F1, typename F2>
__global__ void k1(float *p, F0 f0, F1 f1, F2 f2) {
p[0] = f0(p[0]) + f1(p[1], p[2]) + f2(p[3]);
}

void f0(float *p) {
[](float *p) {
*p = 1.f;
Expand All @@ -27,8 +38,15 @@ void f0(float *p) {

void f1(float *p) {
[](float *p) {
k0<<<1,1>>>(p, [] __device__ (float x) { return x + 1.f; });
k0<<<1,1>>>(p, [] __device__ (float x) { return x + 3.f; });
}(p);
k1<<<1,1>>>(p,
[] __device__ (float x) { return x + 4.f; },
[] __device__ (float x, float y) { return x * y; },
[] __device__ (float x) { return x + 5.f; });
}
// HOST: @__hip_register_globals
// HOST: __hipRegisterFunction{{.*}}@{{(__device_stub_)?}}_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_{{.*}}@0
// HOST: __hipRegisterFunction{{.*}}@{{(__device_stub_)?}}_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_{{.*}}@1
// MSVC: __hipRegisterFunction{{.*}}@"{{(__device_stub_)?}}??$k0@V<lambda_1>@?0???R1?0??f1@@YAXPEAM@Z@QEBA@0@Z@@@YAXPEAMV<lambda_1>@?0???R0?0??f1@@YAX0@Z@QEBA@0@Z@@Z{{.*}}@0
// MSVC: __hipRegisterFunction{{.*}}@"{{(__device_stub_)?}}??$k1@V<lambda_2>@?0??f1@@YAXPEAM@Z@V<lambda_3>@?0??2@YAX0@Z@V<lambda_4>@?0??2@YAX0@Z@@@YAXPEAMV<lambda_2>@?0??f1@@YAX0@Z@V<lambda_3>@?0??1@YAX0@Z@V<lambda_4>@?0??1@YAX0@Z@@Z{{.*}}@1

0 comments on commit a03e606

Please sign in to comment.