diff --git a/llvm-spirv/lib/SPIRV/SPIRVReader.cpp b/llvm-spirv/lib/SPIRV/SPIRVReader.cpp index 81f213813c845..dfdd8f417c8c1 100644 --- a/llvm-spirv/lib/SPIRV/SPIRVReader.cpp +++ b/llvm-spirv/lib/SPIRV/SPIRVReader.cpp @@ -2183,6 +2183,10 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F, auto *AC = static_cast(BV); auto *Base = transValue(AC->getBase(), F, BB); SPIRVType *BaseSPVTy = AC->getBase()->getType(); + if (BaseSPVTy->isTypePointer() && + BaseSPVTy->getPointerElementType()->isTypeCooperativeMatrixKHR()) { + return mapValue(BV, transSPIRVBuiltinFromInst(AC, BB)); + } Type *BaseTy = BaseSPVTy->isTypeVector() ? transType( diff --git a/llvm-spirv/test/extensions/KHR/SPV_KHR_cooperative_matrix/access_store.ll b/llvm-spirv/test/extensions/KHR/SPV_KHR_cooperative_matrix/access_store.ll index ddd20fa5ca5ed..62f3f8407ca47 100644 --- a/llvm-spirv/test/extensions/KHR/SPV_KHR_cooperative_matrix/access_store.ll +++ b/llvm-spirv/test/extensions/KHR/SPV_KHR_cooperative_matrix/access_store.ll @@ -3,9 +3,8 @@ ; RUN: llvm-spirv %t.spv -to-text -o %t.spt ; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV -; TODO: come up with an approach and implement reverse translation -; R/UN: llvm-spirv -r %t.spv -o %t.rev.bc -; R/UN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM +; RUN: llvm-spirv -r %t.spv -o %t.rev.bc +; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM ; CHECK-SPIRV: TypeInt [[#TypeInt:]] 32 0 ; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const0:]] 0 @@ -15,12 +14,20 @@ ; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const42:]] 42 ; CHECK-SPIRV: TypeCooperativeMatrixKHR [[#TypeMatrix:]] [[#TypeInt]] [[#Const3]] [[#Const12]] [[#Const12]] [[#Const0]] -; CHECK-SPIRV: TypePointer [[#Type:]] 7 [[#TypeInt]] +; CHECK-SPIRV: TypePointer [[#TypeMatrixPtr:]] 7 [[#TypeMatrix]] +; CHECK-SPIRV: TypePointer [[#TypeIntPtr:]] 7 [[#TypeInt]] +; CHECK-SPIRV: Variable [[#TypeMatrixPtr]] [[#VarMatrixPtr:]] 7 ; CHECK-SPIRV: CompositeConstruct [[#TypeMatrix]] [[#Composite:]] [[#Const0]] -; CHECK-SPIRV: AccessChain [[#Type]] [[#Res:]] [[#Composite]] [[#Const1]] +; CHECK-SPIRV: Store [[#VarMatrixPtr]] [[#Composite]] +; CHECK-SPIRV: AccessChain [[#TypeIntPtr]] [[#Res:]] [[#VarMatrixPtr]] [[#Const1]] ; CHECK-SPIRV: Store [[#Res]] [[#Const42]] +; CHECK-LLVM: %0 = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) +; CHECK-LLVM: %Obj = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) @_Z26__spirv_CompositeConstructi(i32 0) +; CHECK-LLVM: store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) %Obj, ptr %0 +; CHECK-LLVM: %call = call spir_func ptr @_Z19__spirv_AccessChainPPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_0i(ptr %0, i32 1) +; CHECK-LLVM: store i32 42, ptr %call target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" target triple = "spir64-unknown-unknown" @@ -28,15 +35,17 @@ target triple = "spir64-unknown-unknown" ; Function Attrs: mustprogress uwtable define dso_local void @_Z3fooi(i32 noundef %idx) local_unnamed_addr #0 { entry: + %0 = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0), align 8 %Obj = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) @_Z26__spirv_CompositeConstruct(i32 noundef 0) #4 - %call = call noundef ptr @_Z19__spirv_AccessChainP6Matrixii(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) %Obj, i32 noundef 1) + store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) %Obj, ptr %0, align 8 + %call = call noundef ptr @_Z19__spirv_AccessChainP6Matrixii(ptr %0, i32 noundef 1) call void @_Z13__spirv_StorePii(ptr noundef %call, i32 noundef 42) ret void } declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) @_Z26__spirv_CompositeConstruct(i32 noundef) local_unnamed_addr #2 -declare noundef ptr @_Z19__spirv_AccessChainP6Matrixii(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) noundef, i32 noundef) local_unnamed_addr #2 +declare noundef ptr @_Z19__spirv_AccessChainP6Matrixii(ptr noundef, i32 noundef) local_unnamed_addr #2 declare void @_Z13__spirv_StorePii(ptr noundef, i32 noundef) local_unnamed_addr #2