forked from intel/llvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for JointMatrixINTEL target ext type (intel#1852)
The expected representation is: target("spirv.JointMatrixINTEL", %element_type%, %rows%, %cols%, %scope%, %use%, (optional) %element_type_interpretation%). TODO: figure out how to deal with the switch from the old API (Matrix has Layout) to the new API (Layout was removed). Depends on: intel#1799, intel#8343. Original commit: KhronosGroup/SPIRV-LLVM-Translator@ee03f5f
- Loading branch information
1 parent
cf0e151
commit 843c7f4
Showing
2 changed files
with
164 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
151 changes: 151 additions & 0 deletions
151
llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/opaque_joint_matrix.ll
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
; RUN: llvm-as < %s -o %t.bc | ||
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_joint_matrix -o %t.spv | ||
; RUN: llvm-spirv %t.spv -to-text -o %t.spt | ||
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV | ||
|
||
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc -opaque-pointers=0 | ||
; RUN: llvm-dis -opaque-pointers=0 < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM | ||
|
||
; CHECK-SPIRV-DAG: Capability JointMatrixINTEL | ||
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix" | ||
; CHECK-SPIRV-DAG: TypeInt [[#Int8Ty:]] 8 0 | ||
; CHECK-SPIRV-DAG: TypeInt [[#Int32Ty:]] 32 0 | ||
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const12:]] 12 | ||
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const3:]] 3 | ||
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const2:]] 2 | ||
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const0:]] 0 | ||
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const48:]] 48 | ||
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const1:]] 1 | ||
; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTy1:]] [[#Int32Ty]] [[#Const12]] [[#Const12]] [[#Const3]] [[#Const3]] [[#Const2]] | ||
; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTy2:]] [[#Int8Ty]] [[#Const12]] [[#Const48]] [[#Const0]] [[#Const3]] [[#Const0]] | ||
; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTy3:]] [[#Int8Ty]] [[#Const48]] [[#Const12]] [[#Const2]] [[#Const3]] [[#Const1]] | ||
|
||
; CHECK-LLVM-DAG: %spirv.JointMatrixINTEL._int_12_12_3_3_2 = type opaque | ||
; CHECK-LLVM-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3_0 = type opaque | ||
; CHECK-LLVM-DAG: %spirv.JointMatrixINTEL._char_48_12_2_3_1 = type opaque | ||
|
||
; Module-level preamble for the test input: source file name, data layout and | ||
; the spir64 target triple. | ||
; ModuleID = 'test-matrix-opaque.bc' | ||
source_filename = "matrix-int8-test.cpp" | ||
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" | ||
target triple = "spir64-unknown-unknown" | ||
|
||
; Struct types used as the pointee types of the kernel's byval arguments. | ||
; Both range and id wrap detail::array, which holds a [2 x i64]. | ||
%"class.sycl::_V1::range" = type { %"class.sycl::_V1::detail::array" } | ||
%"class.sycl::_V1::detail::array" = type { [2 x i64] } | ||
%"class.sycl::_V1::id" = type { %"class.sycl::_V1::detail::array" } | ||
|
||
; Comdat group referenced by the kernel definition below. | ||
$_ZTSZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6_EERS0_IT0_XT1_EXT2_EERS0_IS4_XT3_EXT4_EEENKUlRN4sycl3_V17handlerEE_clESC_E7imatrix = comdat any | ||
|
||
; External <3 x i64> builtin variables in addrspace(1); the kernel loads both | ||
; and extracts elements 0 and 1 from each. | ||
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 | ||
@__spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 | ||
|
||
; Kernel under test. It exercises three distinct | ||
; target("spirv.JointMatrixINTEL", ...) types, matching the three | ||
; TypeJointMatrixINTEL / opaque-struct CHECK lines at the top of the file: | ||
;   (i32, 12, 12, 3, 3, 2) - accumulator, built with __spirv_CompositeConstruct | ||
;   (i8, 12, 48, 0, 3, 0)  - first operand of __spirv_JointMatrixMadINTEL | ||
;   (i8, 48, 12, 2, 3, 1)  - second operand of __spirv_JointMatrixMadINTEL | ||
; NOTE(review): judging by the mangled callee names (MatrixUse / MatrixLayout / | ||
; Scope), the three trailing integers look like layout, scope, use - confirm | ||
; against the translator's type mapping. | ||
; Function Attrs: convergent norecurse | ||
define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6_EERS0_IT0_XT1_EXT2_EERS0_IS4_XT3_EXT4_EEENKUlRN4sycl3_V17handlerEE_clESC_E7imatrix(ptr addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accB5, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accB6, ptr addrspace(1) noundef align 4 %_arg_accC, i64 noundef %_arg_N, i64 noundef %_arg_K) local_unnamed_addr #0 comdat { | ||
entry: | ||
; Stack slots holding values of the i32 joint matrix type: the running | ||
; accumulator (%sub_c...) and a per-iteration temporary (%ref.tmp29...). | ||
%sub_c.sroa.0.i = alloca target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2), align 8 | ||
%ref.tmp29.sroa.0.i = alloca target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2), align 8 | ||
; Byte offset into %_arg_accB: id[0] * range[1] + id[1], read from the two | ||
; byval arguments. | ||
%agg.tmp15.sroa.0.sroa.2.0..sroa_idx = getelementptr inbounds %"class.sycl::_V1::range", ptr %_arg_accB5, i64 0, i32 0, i32 0, i64 1 | ||
%agg.tmp15.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp15.sroa.0.sroa.2.0..sroa_idx, align 8 | ||
%0 = getelementptr inbounds %"class.sycl::_V1::id", ptr %_arg_accB6, i64 0, i32 0, i32 0, i64 0 | ||
%agg.tmp16.sroa.0.sroa.0.0.copyload = load i64, ptr %0, align 8 | ||
%agg.tmp16.sroa.0.sroa.2.0..sroa_idx = getelementptr inbounds %"class.sycl::_V1::id", ptr %_arg_accB6, i64 0, i32 0, i32 0, i64 1 | ||
%agg.tmp16.sroa.0.sroa.2.0.copyload = load i64, ptr %agg.tmp16.sroa.0.sroa.2.0..sroa_idx, align 8 | ||
%mul.i4.i.i.i.i45 = mul i64 %agg.tmp16.sroa.0.sroa.0.0.copyload, %agg.tmp15.sroa.0.sroa.2.0.copyload | ||
%add.i6.i.i.i.i46 = add i64 %mul.i4.i.i.i.i45, %agg.tmp16.sroa.0.sroa.2.0.copyload | ||
%add.ptr.i47 = getelementptr inbounds i8, ptr addrspace(1) %_arg_accB, i64 %add.i6.i.i.i.i46 | ||
; Elements 1 and 0 of the global and local invocation id builtins. | ||
%1 = load <3 x i64>, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32 | ||
%2 = extractelement <3 x i64> %1, i64 1 | ||
%3 = extractelement <3 x i64> %1, i64 0 | ||
%4 = load <3 x i64>, ptr addrspace(1) @__spirv_BuiltInLocalInvocationId, align 32 | ||
%5 = extractelement <3 x i64> %4, i64 1 | ||
%6 = extractelement <3 x i64> %4, i64 0 | ||
; The four icmp results below have no users in this function; only the two | ||
; global-minus-local differences (%sub.i, %sub5.i) are consumed later. | ||
%cmp.i.i = icmp ult i64 %2, 2147483648 | ||
%cmp.i54.i = icmp ult i64 %3, 2147483648 | ||
%cmp.i56.i = icmp ult i64 %5, 2147483648 | ||
%sub.i = sub nsw i64 %2, %5 | ||
%cmp.i58.i = icmp ult i64 %6, 2147483648 | ||
%sub5.i = sub nsw i64 %3, %6 | ||
%sub_c.sroa.0.i.0.i.0..sroa_cast = bitcast ptr %sub_c.sroa.0.i to ptr | ||
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %sub_c.sroa.0.i.0.i.0..sroa_cast) | ||
; Construct the accumulator matrix from the scalar i32 0 and spill it to its | ||
; stack slot. | ||
%call.i.i = tail call spir_func noundef target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) @_Z26__spirv_CompositeConstructIiLm12ELm12ELN5__spv9MatrixUseE2ELNS0_12MatrixLayoutE3ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEES6_(i32 noundef 0) #4 | ||
store target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) %call.i.i, ptr %sub_c.sroa.0.i, align 8 | ||
; Loop-invariant address arithmetic; %div.i (= %_arg_K / 48) is the loop bound. | ||
%mul.i = mul nsw i64 %sub.i, 12 | ||
%div2452.i = lshr i64 %sub5.i, 4 | ||
%mul26.i = mul i64 %div2452.i, 48 | ||
%div.i = udiv i64 %_arg_K, 48 | ||
%mul11.i = mul i64 %mul.i, %_arg_K | ||
%add.ptr.i93.i = getelementptr inbounds i8, ptr addrspace(1) %_arg_accA, i64 %mul11.i | ||
%idx.neg.i.i104.i = sub i64 0, %add.i6.i.i.i.i46 | ||
%add.ptr.i.i105141.i = getelementptr i8, ptr addrspace(1) %add.ptr.i47, i64 %mul26.i | ||
%mul22.i = shl i64 %_arg_N, 2 | ||
%add.ptr.i108140.i = getelementptr i8, ptr addrspace(1) %add.ptr.i.i105141.i, i64 %idx.neg.i.i104.i | ||
; No-op ptr-to-ptr bitcasts; %7 and %8 alias the two allocas and are used for | ||
; the raw i64 copy at the end of the loop body. | ||
%ref.tmp29.sroa.0.i.0.i.0..sroa_cast = bitcast ptr %ref.tmp29.sroa.0.i to ptr | ||
%7 = bitcast ptr %ref.tmp29.sroa.0.i to ptr | ||
%8 = bitcast ptr %sub_c.sroa.0.i to ptr | ||
br label %for.cond.i | ||
|
||
; Loop header: k starts at 0 and the body runs while %div.i > k (unsigned). | ||
for.cond.i: ; preds = %for.body.i, %entry | ||
%k.0.i = phi i32 [ 0, %entry ], [ %add.i, %for.body.i ] | ||
%conv.i = zext i32 %k.0.i to i64 | ||
%cmp.i = icmp ugt i64 %div.i, %conv.i | ||
br i1 %cmp.i, label %for.body.i, label %_ZZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6_EERS0_IT0_XT1_EXT2_EERS0_IS4_XT3_EXT4_EEENKUlRN4sycl3_V17handlerEE_clESC_ENKUlNSA_7nd_itemILi2EEEE_clESF_.exit | ||
|
||
; Loop body: load one i8 12x48 matrix and one i8 48x12 matrix, multiply-add | ||
; them into the accumulator, and write the result back to the accumulator slot. | ||
for.body.i: ; preds = %for.cond.i | ||
%mul12.i = mul nsw i32 %k.0.i, 48 | ||
%conv13.i = zext i32 %mul12.i to i64 | ||
%add.ptr.i96.i = getelementptr inbounds i8, ptr addrspace(1) %add.ptr.i93.i, i64 %conv13.i | ||
%call.ascast.i66.i = addrspacecast ptr addrspace(1) %add.ptr.i96.i to ptr addrspace(4) | ||
; First load: pointer derived from %_arg_accA, second argument is %_arg_K. | ||
%call1.i.i = tail call spir_func noundef target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0) @_Z28__spirv_JointMatrixLoadINTELIaLm12ELm48ELN5__spv9MatrixUseE0ELNS0_12MatrixLayoutE0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEPS6_mS2_S4_i(ptr addrspace(4) noundef %call.ascast.i66.i, i64 noundef %_arg_K, i32 noundef 0, i32 noundef 3, i32 noundef 0) #4 | ||
%div20.i = mul nsw i32 %k.0.i, 12 | ||
%conv21.i = zext i32 %div20.i to i64 | ||
%mul23.i = mul i64 %mul22.i, %conv21.i | ||
%add.ptr.i111.i = getelementptr i8, ptr addrspace(1) %add.ptr.i108140.i, i64 %mul23.i | ||
%call.ascast.i72.i = addrspacecast ptr addrspace(1) %add.ptr.i111.i to ptr addrspace(4) | ||
; Second load: pointer derived from %_arg_accB, second argument is %_arg_N * 4. | ||
%call1.i73.i = tail call spir_func noundef target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1) @_Z28__spirv_JointMatrixLoadINTELIaLm48ELm12ELN5__spv9MatrixUseE1ELNS0_12MatrixLayoutE2ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEPS6_mS2_S4_i(ptr addrspace(4) noundef %call.ascast.i72.i, i64 noundef %mul22.i, i32 noundef 2, i32 noundef 3, i32 noundef 0) #4 | ||
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %ref.tmp29.sroa.0.i.0.i.0..sroa_cast) | ||
%sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0.125.i = load target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2), ptr %sub_c.sroa.0.i, align 8 | ||
%call.i77.i = tail call spir_func noundef target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) @_Z27__spirv_JointMatrixMadINTELIaiLm12ELm48ELm12ELN5__spv9MatrixUseE0ELS1_1ELS1_2ELNS0_12MatrixLayoutE0ELS2_2ELS2_3ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT9_EXT10_EXT6_EEEPNS5_IT_XT1_EXT2_EXT7_EXT10_EXT4_EEEPNS5_IS9_XT2_EXT3_EXT8_EXT10_EXT5_EEES8_S4_(target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0) noundef %call1.i.i, target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1) noundef %call1.i73.i, target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) noundef %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0.125.i, i32 noundef 3) #4 | ||
store target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) %call.i77.i, ptr %ref.tmp29.sroa.0.i, align 8 | ||
; Copy the temporary back into the accumulator slot as a raw i64 (via the | ||
; %7 / %8 aliases) rather than as a value of the matrix type. | ||
%ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0..i = load i64, ptr %7, align 8 | ||
store i64 %ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.i.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0.ref.tmp29.sroa.0.0..i, ptr %8, align 8 | ||
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %ref.tmp29.sroa.0.i.0.i.0..sroa_cast) | ||
%add.i = add nuw nsw i32 %k.0.i, 1 | ||
br label %for.cond.i | ||
|
||
; Loop exit: compute the destination inside %_arg_accC (indexed in i32 | ||
; elements), reload the accumulator and hand it to __spirv_JointMatrixStoreINTEL. | ||
_ZZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6_EERS0_IT0_XT1_EXT2_EERS0_IS4_XT3_EXT4_EEENKUlRN4sycl3_V17handlerEE_clESC_ENKUlNSA_7nd_itemILi2EEEE_clESF_.exit: ; preds = %for.cond.i | ||
%mul37.i = mul i64 %mul.i, %_arg_N | ||
%add.ptr.i.i = getelementptr inbounds i32, ptr addrspace(1) %_arg_accC, i64 %mul37.i | ||
%mul39.i = mul nuw i64 %div2452.i, 12 | ||
%add.ptr.i81.i = getelementptr inbounds i32, ptr addrspace(1) %add.ptr.i.i, i64 %mul39.i | ||
%call.ascast.i.i = addrspacecast ptr addrspace(1) %add.ptr.i81.i to ptr addrspace(4) | ||
%sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i = load target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2), ptr %sub_c.sroa.0.i, align 8 | ||
tail call spir_func void @_Z29__spirv_JointMatrixStoreINTELIiLm12ELm12ELN5__spv9MatrixUseE2ELNS0_12MatrixLayoutE3ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS5_XT0_EXT1_EXT3_EXT4_EXT2_EEEmS2_S4_i(ptr addrspace(4) noundef %call.ascast.i.i, target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) noundef %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i, i64 noundef %_arg_N, i32 noundef 0, i32 noundef 3, i32 noundef 0) #4 | ||
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %sub_c.sroa.0.i.0.i.0..sroa_cast) | ||
ret void | ||
} | ||
|
||
; External declarations (no bodies in this module) for the __spirv_* functions | ||
; the kernel calls, followed by the two lifetime intrinsics. Each signature | ||
; passes or returns target("spirv.JointMatrixINTEL", ...) values directly. | ||
; | ||
; Builds the i32 12x12 matrix from a single i32 argument. | ||
; Function Attrs: convergent | ||
declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) @_Z26__spirv_CompositeConstructIiLm12ELm12ELN5__spv9MatrixUseE2ELNS0_12MatrixLayoutE3ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEES6_(i32 noundef) local_unnamed_addr #2 | ||
|
||
; Load returning the i8 12x48 matrix type. | ||
; Function Attrs: convergent | ||
declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0) @_Z28__spirv_JointMatrixLoadINTELIaLm12ELm48ELN5__spv9MatrixUseE0ELNS0_12MatrixLayoutE0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEPS6_mS2_S4_i(ptr addrspace(4) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #2 | ||
|
||
; Load returning the i8 48x12 matrix type. | ||
; Function Attrs: convergent | ||
declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1) @_Z28__spirv_JointMatrixLoadINTELIaLm48ELm12ELN5__spv9MatrixUseE1ELNS0_12MatrixLayoutE2ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEPS6_mS2_S4_i(ptr addrspace(4) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #2 | ||
|
||
; Multiply-add: takes the two i8 matrices plus the i32 matrix and returns the | ||
; i32 matrix type. | ||
; Function Attrs: convergent | ||
declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) @_Z27__spirv_JointMatrixMadINTELIaiLm12ELm48ELm12ELN5__spv9MatrixUseE0ELS1_1ELS1_2ELNS0_12MatrixLayoutE0ELS2_2ELS2_3ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT9_EXT10_EXT6_EEEPNS5_IT_XT1_EXT2_EXT7_EXT10_EXT4_EEEPNS5_IS9_XT2_EXT3_EXT8_EXT10_EXT5_EEES8_S4_(target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0) noundef, target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1) noundef, target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) noundef, i32 noundef) local_unnamed_addr #2 | ||
|
||
; Store taking a destination pointer and the i32 matrix; returns void. | ||
; Function Attrs: convergent | ||
declare dso_local spir_func void @_Z29__spirv_JointMatrixStoreINTELIiLm12ELm12ELN5__spv9MatrixUseE2ELNS0_12MatrixLayoutE3ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS5_XT0_EXT1_EXT3_EXT4_EXT2_EEEmS2_S4_i(ptr addrspace(4) noundef, target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #2 | ||
|
||
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) | ||
declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #3 | ||
|
||
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) | ||
declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) #3 | ||
|
||
; Attribute groups. Usage in the IR above: #0 on the kernel definition, #2 on | ||
; the __spirv_* declarations, #3 on the lifetime intrinsics, #4 on the | ||
; __spirv_* call sites. #1 is not referenced anywhere in the visible IR. | ||
attributes #0 = { convergent norecurse "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="matrix-int8-test.cpp" "uniform-work-group-size"="true" } | ||
attributes #1 = { nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: readwrite) } | ||
attributes #2 = { convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } | ||
attributes #3 = { nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) } | ||
attributes #4 = { convergent } |