From 4447a50fb73068a438f0a332b138e73d824d5c2b Mon Sep 17 00:00:00 2001 From: Dmitry Sidorov Date: Mon, 17 Apr 2023 11:53:33 +0200 Subject: [PATCH] Revert "[SYCL] Represent JointMatrixINTEL type as extension type" (#9071) It appears to be, that mem2reg and SROA passes can't handle target extension type properly. It means, that with turned on optimizations alloca/load/store sequences of joint matrix types won't be eliminated. It results in a crash in IGC since it can't handle such case yet. Note, it means that matrix samples compiled with -O0 also don't work now. So we have to (temporary?) revert this patch. This reverts commit 6f8e45670a0b72d9b0f343c02927b709638491de. --- clang/lib/CodeGen/CodeGenTypes.cpp | 143 ++++++++----------- clang/lib/CodeGen/CodeGenTypes.h | 8 -- clang/test/CodeGenSYCL/matrix.cpp | 34 ++--- sycl/test/matrix/legacy/matrix-int8-test.cpp | 6 +- sycl/test/matrix/matrix-int8-test.cpp | 6 +- 5 files changed, 82 insertions(+), 115 deletions(-) diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp index ed063a36e5705..bf8e4006e1d69 100644 --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -51,6 +51,65 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD, StringRef suffix) { SmallString<256> TypeName; llvm::raw_svector_ostream OS(TypeName); + // If RD is spirv_JointMatrixINTEL type, mangle differently. + if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) { + if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") { + if (auto TemplateDecl = dyn_cast(RD)) { + ArrayRef TemplateArgs = + TemplateDecl->getTemplateArgs().asArray(); + OS << "spirv.JointMatrixINTEL."; + for (auto &TemplateArg : TemplateArgs) { + OS << "_"; + if (TemplateArg.getKind() == TemplateArgument::Type) { + llvm::Type *TTy = ConvertType(TemplateArg.getAsType()); + if (TTy->isIntegerTy()) { + switch (TTy->getIntegerBitWidth()) { + case 8: + OS << "char"; + break; + case 16: + OS << "short"; + break; + case 32: + OS << "int"; + break; + case 64: + OS << "long"; + break; + default: + OS << "i" << TTy->getIntegerBitWidth(); + break; + } + } else if (TTy->isHalfTy()) { + OS << "half"; + } else if (TTy->isFloatTy()) { + OS << "float"; + } else if (TTy->isDoubleTy()) { + OS << "double"; + } else if (TTy->isBFloatTy()) { + OS << "bfloat16"; + } else if (TTy->isStructTy()) { + StringRef LlvmTyName = TTy->getStructName(); + // Emit half/bfloat16/tf32 for sycl[::*]::{half,bfloat16,tf32} + if (LlvmTyName.startswith("class.sycl::") || + LlvmTyName.startswith("class.__sycl_internal::")) + LlvmTyName = LlvmTyName.rsplit("::").second; + if (LlvmTyName != "half" && LlvmTyName != "bfloat16" && + LlvmTyName != "tf32") + llvm_unreachable("Wrong matrix base type!"); + OS << LlvmTyName; + } else { + llvm_unreachable("Wrong matrix base type!"); + } + } else if (TemplateArg.getKind() == TemplateArgument::Integral) { + OS << TemplateArg.getAsIntegral(); + } + } + Ty->setName(OS.str()); + return; + } + } + } OS << RD->getKindName() << '.'; // FIXME: We probably want to make more tweaks to the printing policy. For @@ -401,78 +460,6 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) { return ResultType; } -template -llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy, - ArrayRef TemplateArgs, - const unsigned Val = 0) { - // TODO: we should actually have exactly 5 template parameters: 1 for - // type and 4 for type parameters. But in previous version of the SPIR-V - // spec we have Layout matrix type parameter, that was later removed. - // Once we update to the newest version of the spec - this should be updated. - assert((TemplateArgs.size() == 5 || TemplateArgs.size() == 6) && - "Wrong JointMatrixINTEL template parameters number"); - // This is required to represent optional 'Component Type Interpretation' - // parameter - using ParamsType = - typename std::conditional, - SmallVector>::type; - ParamsType Params; - if constexpr (NeedTypeInterpret) - Params = {0, 0, 0, 0, 0, Val}; - else - Params = {0, 0, 0, 0, 0}; - for (size_t I = 1; I != TemplateArgs.size(); ++I) { - assert(TemplateArgs[I].getKind() == TemplateArgument::Integral && - "Wrong JointMatrixINTEL template parameter"); - Params[I - 1] = TemplateArgs[I].getAsIntegral().getExtValue(); - } - return llvm::TargetExtType::get(CompTy->getContext(), - "spirv.JointMatrixINTEL", {CompTy}, Params); -} - -/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type -/// which is represented as a pointer to a structure to LLVM extension type -/// with the parameters that follow SPIR-V JointMatrixINTEL type. -/// The expected representation is: -/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%, -/// %use%, (optional) %element_type_interpretation%) -llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) { - auto *TemplateDecl = cast(RD); - ArrayRef TemplateArgs = - TemplateDecl->getTemplateArgs().asArray(); - assert(TemplateArgs[0].getKind() == TemplateArgument::Type && - "1st JointMatrixINTEL template parameter must be type"); - llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType()); - - // Per JointMatrixINTEL spec the type can have an optional - // 'Component Type Interpretation' parameter. We should emit it in case - // if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as - // matrix's components. Yet 'bfloat16' should be represented as 'int16' and - // 'tf32' as 'float' types. - if (CompTy->isStructTy()) { - StringRef LlvmTyName = CompTy->getStructName(); - // Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32} - if (LlvmTyName.startswith("class.sycl::") || - LlvmTyName.startswith("class.__sycl_internal::")) - LlvmTyName = LlvmTyName.rsplit("::").second; - if (LlvmTyName == "half") { - CompTy = llvm::Type::getHalfTy(getLLVMContext()); - return getJointMatrixINTELExtType(CompTy, TemplateArgs); - } else if (LlvmTyName == "tf32") { - CompTy = llvm::Type::getFloatTy(getLLVMContext()); - // 'tf32' interpretation is mapped to '0' - return getJointMatrixINTELExtType(CompTy, TemplateArgs, 0); - } else if (LlvmTyName == "bfloat16") { - CompTy = llvm::Type::getInt16Ty(getLLVMContext()); - // 'bfloat16' interpretation is mapped to '1' - return getJointMatrixINTELExtType(CompTy, TemplateArgs, 1); - } else { - llvm_unreachable("Wrong matrix base type!"); - } - } - return getJointMatrixINTELExtType(CompTy, TemplateArgs); -} - /// ConvertType - Convert the specified type to its LLVM form. llvm::Type *CodeGenTypes::ConvertType(QualType T) { T = Context.getCanonicalType(T); @@ -758,18 +745,6 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) { llvm::Type *PointeeType = ConvertTypeForMem(ETy); if (PointeeType->isVoidTy()) PointeeType = llvm::Type::getInt8Ty(getLLVMContext()); - if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) { - const Type *ClangETy = ETy.getTypePtrOrNull(); - if (ClangETy && ClangETy->isStructureOrClassType()) { - RecordDecl *RD = ClangETy->getAsCXXRecordDecl(); - if (RD && - RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") { - ResultType = ConvertSYCLJointMatrixINTELType(RD); - break; - } - } - } - unsigned AS = getTargetAddressSpace(ETy); ResultType = llvm::PointerType::get(PointeeType, AS); break; diff --git a/clang/lib/CodeGen/CodeGenTypes.h b/clang/lib/CodeGen/CodeGenTypes.h index 3f198b2a3de1a..e76fda95513f6 100644 --- a/clang/lib/CodeGen/CodeGenTypes.h +++ b/clang/lib/CodeGen/CodeGenTypes.h @@ -133,14 +133,6 @@ class CodeGenTypes { /// memory representation is usually i8 or i32, depending on the target. llvm::Type *ConvertTypeForMem(QualType T, bool ForBitField = false); - /// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type - /// which is represented as a pointer to a structure to LLVM extension type - /// with the parameters that follow SPIR-V JointMatrixINTEL type. - /// The expected representation is: - /// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%, - /// %use%, (optional) %element_type_interpretation%) - llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD); - /// GetFunctionType - Get the LLVM function type for \arg Info. llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info); diff --git a/clang/test/CodeGenSYCL/matrix.cpp b/clang/test/CodeGenSYCL/matrix.cpp index b2c0c51adba6e..69469811047fd 100644 --- a/clang/test/CodeGenSYCL/matrix.cpp +++ b/clang/test/CodeGenSYCL/matrix.cpp @@ -5,18 +5,18 @@ #include namespace __spv { - template + template struct __spirv_JointMatrixINTEL; } -// CHECK: @_Z2f1{{.*}}(target("spirv.JointMatrixINTEL", float, 5, 10, 0, 1, 0) -void f1(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1 +void f1(__spv::__spirv_JointMatrixINTEL *matrix) {} -// CHECK: @_Z2f2{{.*}}(target("spirv.JointMatrixINTEL", i64, 10, 2, 0, 0, 0) -void f2(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0 +void f2(__spv::__spirv_JointMatrixINTEL *matrix) {} -// CHECK: @_Z2f3{{.*}}(target("spirv.JointMatrixINTEL", i8, 10, 2, 0, 0, 0) -void f3(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0 +void f3(__spv::__spirv_JointMatrixINTEL *matrix) {} namespace sycl { class half {}; @@ -25,17 +25,17 @@ namespace sycl { } typedef sycl::half my_half; -// CHECK: @_Z2f4{{.*}}(target("spirv.JointMatrixINTEL", half, 10, 2, 0, 0, 0) -void f4(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0 +void f4(__spv::__spirv_JointMatrixINTEL *matrix) {} -// CHECK: @_Z2f5{{.*}}(target("spirv.JointMatrixINTEL", i16, 10, 2, 0, 0, 0, 1) -void f5(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0 +void f5(__spv::__spirv_JointMatrixINTEL *matrix) {} -// CHECK: @_Z2f6{{.*}}(target("spirv.JointMatrixINTEL", i128, 10, 2, 0, 0, 0) -void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0, 0> *matrix) {} +// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0 +void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {} -// CHECK: @_Z2f7{{.*}}(target("spirv.JointMatrixINTEL", float, 10, 2, 0, 0, 0, 0) -void f7(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._tf32_10_2_0_0 +void f7(__spv::__spirv_JointMatrixINTEL *matrix) {} -// CHECK: @_Z2f8{{.*}}(target("spirv.JointMatrixINTEL", double, 5, 10, 0, 1, 0) -void f8(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f8{{.*}}(%spirv.JointMatrixINTEL._double_5_10_0_1 +void f8(__spv::__spirv_JointMatrixINTEL *matrix) {} diff --git a/sycl/test/matrix/legacy/matrix-int8-test.cpp b/sycl/test/matrix/legacy/matrix-int8-test.cpp index 852c877b46fc4..a0c2edb62c2f1 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test.cpp @@ -1,8 +1,8 @@ // RUN: %clangxx -fsycl -fsycl-device-only -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 -S -emit-llvm -o - %s | FileCheck %s -// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0) -// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 0, 3, 0) -// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 3, 3, 0) +// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type opaque +// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type opaque +// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type opaque #include #include diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index 99f60423ca212..de8721bca3b09 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -1,8 +1,8 @@ // RUN: %clangxx -fsycl -fsycl-device-only -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -O2 -S -emit-llvm -o - %s | FileCheck %s -// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0) -// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) -// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1) +// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3_0 = type opaque +// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_3_3_2 = type opaque +// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_2_3_1 = type opaque #include #include