Skip to content

Commit

Permalink
Reverse translation of access + load/store operations for cooperative…
Browse files Browse the repository at this point in the history
… matrix (#2165)

Implement translation via SPIR-V friendly calls, as:

the LLVM instructions are not capable to accept target extension types;
cooperative matrix is an opaque object and accessing elements is implementation defined, hence we can't use GEP to which AccessChain naturally maps, since GEP has a different meaning.
As for now some BE would need to recognize and define what to do with a call to __spirv_AccessChain(matrix, index). Better option is to map such SPIR-V to an intrinsic or define an appropriate type in LLVM (hence defining rules for GEP and other instructions) , but it's off the table now.

Original commit:
KhronosGroup/SPIRV-LLVM-Translator@deb6ee9
  • Loading branch information
vmaksimo authored and sys-ce-bb committed Oct 5, 2023
1 parent 8b36ee8 commit d42ec13
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
4 changes: 4 additions & 0 deletions llvm-spirv/lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2183,6 +2183,10 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
auto *AC = static_cast<SPIRVAccessChainBase *>(BV);
auto *Base = transValue(AC->getBase(), F, BB);
SPIRVType *BaseSPVTy = AC->getBase()->getType();
if (BaseSPVTy->isTypePointer() &&
BaseSPVTy->getPointerElementType()->isTypeCooperativeMatrixKHR()) {
return mapValue(BV, transSPIRVBuiltinFromInst(AC, BB));
}
Type *BaseTy =
BaseSPVTy->isTypeVector()
? transType(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV

; TODO: come up with an approach and implement reverse translation
; R/UN: llvm-spirv -r %t.spv -o %t.rev.bc
; R/UN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM

; CHECK-SPIRV: TypeInt [[#TypeInt:]] 32 0
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const0:]] 0
Expand All @@ -15,28 +14,38 @@
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const42:]] 42

; CHECK-SPIRV: TypeCooperativeMatrixKHR [[#TypeMatrix:]] [[#TypeInt]] [[#Const3]] [[#Const12]] [[#Const12]] [[#Const0]]
; CHECK-SPIRV: TypePointer [[#Type:]] 7 [[#TypeInt]]
; CHECK-SPIRV: TypePointer [[#TypeMatrixPtr:]] 7 [[#TypeMatrix]]
; CHECK-SPIRV: TypePointer [[#TypeIntPtr:]] 7 [[#TypeInt]]

; CHECK-SPIRV: Variable [[#TypeMatrixPtr]] [[#VarMatrixPtr:]] 7
; CHECK-SPIRV: CompositeConstruct [[#TypeMatrix]] [[#Composite:]] [[#Const0]]
; CHECK-SPIRV: AccessChain [[#Type]] [[#Res:]] [[#Composite]] [[#Const1]]
; CHECK-SPIRV: Store [[#VarMatrixPtr]] [[#Composite]]
; CHECK-SPIRV: AccessChain [[#TypeIntPtr]] [[#Res:]] [[#VarMatrixPtr]] [[#Const1]]
; CHECK-SPIRV: Store [[#Res]] [[#Const42]]

; CHECK-LLVM: %0 = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0)
; CHECK-LLVM: %Obj = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) @_Z26__spirv_CompositeConstructi(i32 0)
; CHECK-LLVM: store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) %Obj, ptr %0
; CHECK-LLVM: %call = call spir_func ptr @_Z19__spirv_AccessChainPPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_0i(ptr %0, i32 1)
; CHECK-LLVM: store i32 42, ptr %call

target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "spir64-unknown-unknown"

; Function Attrs: mustprogress uwtable
define dso_local void @_Z3fooi(i32 noundef %idx) local_unnamed_addr #0 {
entry:
%0 = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0), align 8
%Obj = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) @_Z26__spirv_CompositeConstruct(i32 noundef 0) #4
%call = call noundef ptr @_Z19__spirv_AccessChainP6Matrixii(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) %Obj, i32 noundef 1)
store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) %Obj, ptr %0, align 8
%call = call noundef ptr @_Z19__spirv_AccessChainP6Matrixii(ptr %0, i32 noundef 1)
call void @_Z13__spirv_StorePii(ptr noundef %call, i32 noundef 42)
ret void
}

declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) @_Z26__spirv_CompositeConstruct(i32 noundef) local_unnamed_addr #2

declare noundef ptr @_Z19__spirv_AccessChainP6Matrixii(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) noundef, i32 noundef) local_unnamed_addr #2
declare noundef ptr @_Z19__spirv_AccessChainP6Matrixii(ptr noundef, i32 noundef) local_unnamed_addr #2

declare void @_Z13__spirv_StorePii(ptr noundef, i32 noundef) local_unnamed_addr #2

Expand Down

0 comments on commit d42ec13

Please sign in to comment.