Skip to content

Commit

Permalink
Revert "[SYCL] Represent JointMatrixINTEL type as extension type" (in…
Browse files Browse the repository at this point in the history
…tel#9071)

It appears to be, that mem2reg and SROA passes can't handle target
extension type properly. It means, that with turned on optimizations
alloca/load/store sequences of joint matrix types won't be eliminated.
It results in a crash in IGC since it can't handle such case yet. Note,
it means that matrix samples compiled with -O0 also don't work now.

So we have to (temporary?) revert this patch.

This reverts commit 6f8e456.
  • Loading branch information
MrSidims authored Apr 17, 2023
1 parent 097d21c commit 4447a50
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 115 deletions.
143 changes: 59 additions & 84 deletions clang/lib/CodeGen/CodeGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,65 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD,
StringRef suffix) {
SmallString<256> TypeName;
llvm::raw_svector_ostream OS(TypeName);
// If RD is spirv_JointMatrixINTEL type, mangle differently.
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") {
if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();
OS << "spirv.JointMatrixINTEL.";
for (auto &TemplateArg : TemplateArgs) {
OS << "_";
if (TemplateArg.getKind() == TemplateArgument::Type) {
llvm::Type *TTy = ConvertType(TemplateArg.getAsType());
if (TTy->isIntegerTy()) {
switch (TTy->getIntegerBitWidth()) {
case 8:
OS << "char";
break;
case 16:
OS << "short";
break;
case 32:
OS << "int";
break;
case 64:
OS << "long";
break;
default:
OS << "i" << TTy->getIntegerBitWidth();
break;
}
} else if (TTy->isHalfTy()) {
OS << "half";
} else if (TTy->isFloatTy()) {
OS << "float";
} else if (TTy->isDoubleTy()) {
OS << "double";
} else if (TTy->isBFloatTy()) {
OS << "bfloat16";
} else if (TTy->isStructTy()) {
StringRef LlvmTyName = TTy->getStructName();
// Emit half/bfloat16/tf32 for sycl[::*]::{half,bfloat16,tf32}
if (LlvmTyName.startswith("class.sycl::") ||
LlvmTyName.startswith("class.__sycl_internal::"))
LlvmTyName = LlvmTyName.rsplit("::").second;
if (LlvmTyName != "half" && LlvmTyName != "bfloat16" &&
LlvmTyName != "tf32")
llvm_unreachable("Wrong matrix base type!");
OS << LlvmTyName;
} else {
llvm_unreachable("Wrong matrix base type!");
}
} else if (TemplateArg.getKind() == TemplateArgument::Integral) {
OS << TemplateArg.getAsIntegral();
}
}
Ty->setName(OS.str());
return;
}
}
}
OS << RD->getKindName() << '.';

// FIXME: We probably want to make more tweaks to the printing policy. For
Expand Down Expand Up @@ -401,78 +460,6 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
return ResultType;
}

template <bool NeedTypeInterpret = false>
llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
ArrayRef<TemplateArgument> TemplateArgs,
const unsigned Val = 0) {
// TODO: we should actually have exactly 5 template parameters: 1 for
// type and 4 for type parameters. But in previous version of the SPIR-V
// spec we have Layout matrix type parameter, that was later removed.
// Once we update to the newest version of the spec - this should be updated.
assert((TemplateArgs.size() == 5 || TemplateArgs.size() == 6) &&
"Wrong JointMatrixINTEL template parameters number");
// This is required to represent optional 'Component Type Interpretation'
// parameter
using ParamsType =
typename std::conditional<NeedTypeInterpret, SmallVector<unsigned, 6>,
SmallVector<unsigned, 5>>::type;
ParamsType Params;
if constexpr (NeedTypeInterpret)
Params = {0, 0, 0, 0, 0, Val};
else
Params = {0, 0, 0, 0, 0};
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
"Wrong JointMatrixINTEL template parameter");
Params[I - 1] = TemplateArgs[I].getAsIntegral().getExtValue();
}
return llvm::TargetExtType::get(CompTy->getContext(),
"spirv.JointMatrixINTEL", {CompTy}, Params);
}

/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
/// The expected representation is:
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
"1st JointMatrixINTEL template parameter must be type");
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());

// Per JointMatrixINTEL spec the type can have an optional
// 'Component Type Interpretation' parameter. We should emit it in case
// if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as
// matrix's components. Yet 'bfloat16' should be represented as 'int16' and
// 'tf32' as 'float' types.
if (CompTy->isStructTy()) {
StringRef LlvmTyName = CompTy->getStructName();
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
if (LlvmTyName.startswith("class.sycl::") ||
LlvmTyName.startswith("class.__sycl_internal::"))
LlvmTyName = LlvmTyName.rsplit("::").second;
if (LlvmTyName == "half") {
CompTy = llvm::Type::getHalfTy(getLLVMContext());
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
} else if (LlvmTyName == "tf32") {
CompTy = llvm::Type::getFloatTy(getLLVMContext());
// 'tf32' interpretation is mapped to '0'
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 0);
} else if (LlvmTyName == "bfloat16") {
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
// 'bfloat16' interpretation is mapped to '1'
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 1);
} else {
llvm_unreachable("Wrong matrix base type!");
}
}
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
}

/// ConvertType - Convert the specified type to its LLVM form.
llvm::Type *CodeGenTypes::ConvertType(QualType T) {
T = Context.getCanonicalType(T);
Expand Down Expand Up @@ -758,18 +745,6 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
llvm::Type *PointeeType = ConvertTypeForMem(ETy);
if (PointeeType->isVoidTy())
PointeeType = llvm::Type::getInt8Ty(getLLVMContext());
if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) {
const Type *ClangETy = ETy.getTypePtrOrNull();
if (ClangETy && ClangETy->isStructureOrClassType()) {
RecordDecl *RD = ClangETy->getAsCXXRecordDecl();
if (RD &&
RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") {
ResultType = ConvertSYCLJointMatrixINTELType(RD);
break;
}
}
}

unsigned AS = getTargetAddressSpace(ETy);
ResultType = llvm::PointerType::get(PointeeType, AS);
break;
Expand Down
8 changes: 0 additions & 8 deletions clang/lib/CodeGen/CodeGenTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,6 @@ class CodeGenTypes {
/// memory representation is usually i8 or i32, depending on the target.
llvm::Type *ConvertTypeForMem(QualType T, bool ForBitField = false);

/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
/// The expected representation is:
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);

/// GetFunctionType - Get the LLVM function type for \arg Info.
llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info);

Expand Down
34 changes: 17 additions & 17 deletions clang/test/CodeGenSYCL/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
#include <stdint.h>

namespace __spv {
template <typename T, size_t R, size_t C, uint32_t L, uint32_t S, uint32_t U>
template <typename T, size_t R, size_t C, uint32_t U, uint32_t S>
struct __spirv_JointMatrixINTEL;
}

// CHECK: @_Z2f1{{.*}}(target("spirv.JointMatrixINTEL", float, 5, 10, 0, 1, 0)
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1, 0> *matrix) {}
// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1
void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1> *matrix) {}

// CHECK: @_Z2f2{{.*}}(target("spirv.JointMatrixINTEL", i64, 10, 2, 0, 0, 0)
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0, 0> *matrix) {}
// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0
void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0> *matrix) {}

// CHECK: @_Z2f3{{.*}}(target("spirv.JointMatrixINTEL", i8, 10, 2, 0, 0, 0)
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0, 0> *matrix) {}
// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0
void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {}

namespace sycl {
class half {};
Expand All @@ -25,17 +25,17 @@ namespace sycl {
}
typedef sycl::half my_half;

// CHECK: @_Z2f4{{.*}}(target("spirv.JointMatrixINTEL", half, 10, 2, 0, 0, 0)
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0, 0> *matrix) {}
// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0
void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0> *matrix) {}

// CHECK: @_Z2f5{{.*}}(target("spirv.JointMatrixINTEL", i16, 10, 2, 0, 0, 0, 1)
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0, 0> *matrix) {}
// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0
void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {}

// CHECK: @_Z2f6{{.*}}(target("spirv.JointMatrixINTEL", i128, 10, 2, 0, 0, 0)
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0, 0> *matrix) {}
// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0
void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {}

// CHECK: @_Z2f7{{.*}}(target("spirv.JointMatrixINTEL", float, 10, 2, 0, 0, 0, 0)
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0, 0> *matrix) {}
// CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._tf32_10_2_0_0
void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0> *matrix) {}

// CHECK: @_Z2f8{{.*}}(target("spirv.JointMatrixINTEL", double, 5, 10, 0, 1, 0)
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1, 0> *matrix) {}
// CHECK: @_Z2f8{{.*}}(%spirv.JointMatrixINTEL._double_5_10_0_1
void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1> *matrix) {}
6 changes: 3 additions & 3 deletions sycl/test/matrix/legacy/matrix-int8-test.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: %clangxx -fsycl -fsycl-device-only -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 -S -emit-llvm -o - %s | FileCheck %s

// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0)
// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 0, 3, 0)
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 3, 3, 0)
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type opaque
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type opaque
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type opaque

#include <iostream>
#include <sycl/sycl.hpp>
Expand Down
6 changes: 3 additions & 3 deletions sycl/test/matrix/matrix-int8-test.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: %clangxx -fsycl -fsycl-device-only -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -O2 -S -emit-llvm -o - %s | FileCheck %s

// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0)
// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2)
// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1)
// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3_0 = type opaque
// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_3_3_2 = type opaque
// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_2_3_1 = type opaque

#include <iostream>
#include <sycl/sycl.hpp>
Expand Down

0 comments on commit 4447a50

Please sign in to comment.