diff --git a/llvm/lib/SYCLLowerIR/LowerWGScope.cpp b/llvm/lib/SYCLLowerIR/LowerWGScope.cpp index c8db09f4772dd..c424f81a84e00 100644 --- a/llvm/lib/SYCLLowerIR/LowerWGScope.cpp +++ b/llvm/lib/SYCLLowerIR/LowerWGScope.cpp @@ -685,7 +685,7 @@ static void fixupPrivateMemoryPFWILambdaCaptures(CallInst *PFWICall) { // Go through "byval" parameters which are passed as AS(0) pointers // and: (1) create local shadows for them (2) and initialize them from the -// leader's copy and (3) replace usages with pointer to the shadow +// leader's copy and (3) materialize the value in the local variable before use // // Do the same for 'this' pointer which points to PFWG lamda object which is // allocated in the caller. Caller is a kernel function which is generated by @@ -707,7 +707,7 @@ static void sharePFWGPrivateObjects(Function &F, const Triple &TT) { BasicBlock *LeaderBB = EntryBB->splitBasicBlock(SplitPoint, "leader"); BasicBlock *MergeBB = LeaderBB->splitBasicBlock(&LeaderBB->front(), "merge"); - // 1) rewire the above basic blocks so that LeaderBB is executed only for the + // Rewire the above basic blocks so that LeaderBB is executed only for the // leader workitem guardBlockWithIsLeaderCheck(EntryBB, LeaderBB, MergeBB, EntryBB->back().getDebugLoc(), TT); @@ -719,28 +719,13 @@ static void sharePFWGPrivateObjects(Function &F, const Triple &TT) { IRBuilder<> Builder(Ctx); Builder.SetInsertPoint(&LeaderBB->front()); - // 2) create the shared copy - "shadow" - for current arg + // Create the shared copy - "shadow" - for current arg GlobalVariable *Shadow = nullptr; - Value *RepVal = nullptr; if (Arg.hasByValAttr()) { assert(Arg.getType()->getPointerAddressSpace() == asUInt(spirv::AddrSpace::Private)); T = Arg.getParamByValType(); Shadow = spirv::createWGLocalVariable(*F.getParent(), T, "ArgShadow"); - RepVal = Shadow; - if (TT.isNVPTX()) { - // For NVPTX target address space inference for kernel arguments and - // allocas is happening in the backend (NVPTXLowerArgs and - // NVPTXLowerAlloca passes). After the frontend these pointers are in - // LLVM default address space 0 which is the generic address space for - // NVPTX target. - assert(Arg.getType()->getPointerAddressSpace() == 0); - - // Cast a pointer in the shared address space to the generic address - // space. - RepVal = ConstantExpr::getPointerBitCastOrAddrSpaceCast(Shadow, - Arg.getType()); - } } // Process 'this' pointer which points to PFWG lambda object else if (Arg.getArgNo() == 0) { @@ -748,21 +733,20 @@ static void sharePFWGPrivateObjects(Function &F, const Triple &TT) { assert(PtrT && "Expected this pointer as the first argument"); T = PtrT->getElementType(); Shadow = spirv::createWGLocalVariable(*F.getParent(), T, "ArgShadow"); - RepVal = - Builder.CreatePointerBitCastOrAddrSpaceCast(Shadow, Arg.getType()); } - if (!Shadow || !RepVal) + if (!Shadow) continue; - // 3) replace argument with shadow in all uses - for (auto *U : Arg.users()) - U->replaceUsesOfWith(&Arg, RepVal); - copyBetweenPrivateAndShadow(&Arg, Shadow, Builder, true /*private->shadow*/); + // Materialize the value in the local variable before use + Builder.SetInsertPoint(&MergeBB->front()); + copyBetweenPrivateAndShadow(&Arg, Shadow, Builder, + false /*shadow->private*/); } - // 5) make sure workers use up-to-date shared values written by the leader + // Insert barrier to make sure workers use up-to-date shared values written by + // the leader spirv::genWGBarrier(MergeBB->front(), TT); } diff --git a/llvm/test/SYCLLowerIR/byval_arg.ll b/llvm/test/SYCLLowerIR/byval_arg.ll index 7be1b277c16b1..03c4bb2892f64 100644 --- a/llvm/test/SYCLLowerIR/byval_arg.ll +++ b/llvm/test/SYCLLowerIR/byval_arg.ll @@ -19,6 +19,8 @@ define internal spir_func void @wibble(%struct.baz* byval(%struct.baz) %arg1) !w ; CHECK-NEXT: br label [[MERGE]] ; CHECK: merge: ; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) +; CHECK-NEXT: [[TMP3:%.*]] = bitcast %struct.baz* [[ARG1]] to i8* +; CHECK-NEXT: call void @llvm.memcpy.p0i8.p3i8.i64(i8* [[TMP3]], i8 addrspace(3)* align 8 bitcast (%struct.baz addrspace(3)* @[[SHADOW]] to i8 addrspace(3)*), i64 8, i1 false) ; CHECK-NEXT: ret void ; ret void diff --git a/llvm/test/SYCLLowerIR/byval_arg_cast.ll b/llvm/test/SYCLLowerIR/byval_arg_cast.ll new file mode 100644 index 0000000000000..14699c070a673 --- /dev/null +++ b/llvm/test/SYCLLowerIR/byval_arg_cast.ll @@ -0,0 +1,50 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -LowerWGScope -S | FileCheck %s + +; Test to check that shadow local variable for byval argument is +; materialized before use. Otherwise invalid cast between address +; spaces is generated. + +%struct.widget = type { %struct.baz, %struct.baz, %struct.baz, %struct.spam } +%struct.baz = type { %struct.snork } +%struct.snork = type { [1 x i64] } +%struct.spam = type { %struct.snork } + + +declare dso_local spir_func void @zot(i8*) + +; CHECK: @[[SHADOW:[a-zA-Z0-9]+]] = internal unnamed_addr addrspace(3) global %struct.widget undef, align 16 + +; Function Attrs: inlinehint norecurse +define dso_local spir_func void @wombat(%struct.widget* byval(%struct.widget) align 8 %arg) align 2 !work_group_scope !1 { +; CHECK-LABEL: @wombat( +; CHECK-NEXT: bb: +; CHECK-NEXT: [[TMP0:%.*]] = load i64, i64 addrspace(1)* @__spirv_BuiltInLocalInvocationIndex, align 4 +; CHECK-NEXT: [[CMPZ1:%.*]] = icmp eq i64 [[TMP0]], 0 +; CHECK-NEXT: br i1 [[CMPZ1]], label [[LEADER:%.*]], label [[MERGE:%.*]] +; CHECK: leader: +; CHECK-NEXT: [[TMP1:%.*]] = bitcast %struct.widget* [[ARG:%.*]] to i8* +; CHECK-NEXT: call void @llvm.memcpy.p3i8.p0i8.i64(i8 addrspace(3)* align 16 bitcast (%struct.widget addrspace(3)* @[[SHADOW]] to i8 addrspace(3)*), i8* align 8 [[TMP1]], i64 32, i1 false) +; CHECK-NEXT: br label [[MERGE]] +; CHECK: merge: +; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) #0 +; CHECK-NEXT: [[TMP2:%.*]] = bitcast %struct.widget* [[ARG]] to i8* +; CHECK-NEXT: call void @llvm.memcpy.p0i8.p3i8.i64(i8* align 8 [[TMP2]], i8 addrspace(3)* align 16 bitcast (%struct.widget addrspace(3)* @[[SHADOW]] to i8 addrspace(3)*), i64 32, i1 false) +; CHECK-NEXT: [[TMP3:%.*]] = load i64, i64 addrspace(1)* @__spirv_BuiltInLocalInvocationIndex, align 4 +; CHECK-NEXT: [[CMPZ:%.*]] = icmp eq i64 [[TMP3]], 0 +; CHECK-NEXT: br i1 [[CMPZ]], label [[WG_LEADER:%.*]], label [[WG_CF:%.*]] +; CHECK: wg_leader: +; CHECK-NEXT: [[TMP:%.*]] = bitcast %struct.widget* [[ARG]] to i8* +; CHECK-NEXT: call void @zot(i8* [[TMP]]) +; CHECK-NEXT: br label [[WG_CF]] +; CHECK: wg_cf: +; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) #0 +; CHECK-NEXT: ret void +; +bb: + %tmp = bitcast %struct.widget* %arg to i8* + call void @zot(i8* %tmp) + ret void +} + +!1 = !{} diff --git a/llvm/test/SYCLLowerIR/cast_shadow.ll b/llvm/test/SYCLLowerIR/cast_shadow.ll index 1822f1c5e8148..56e8e9e0288bc 100644 --- a/llvm/test/SYCLLowerIR/cast_shadow.ll +++ b/llvm/test/SYCLLowerIR/cast_shadow.ll @@ -15,7 +15,10 @@ target triple = "nvptx64-nvidia-cuda-sycldevice" define internal void @wobble(%struct.baz* %arg, %struct.spam* byval(%struct.spam) %arg1) !work_group_scope !0 { ; CHECK: [[TMP10:%.*]] = bitcast %struct.spam* [[ARG1:%.*]] to i8* ; CHECK: call void @llvm.memcpy.p3i8.p0i8.i64(i8 addrspace(3)* align 16 bitcast (%struct.spam addrspace(3)* @[[SHADOW]] to i8 addrspace(3)*), i8* [[TMP10]], i64 32, i1 false) -; CHECK: call void @widget(%struct.spam* addrspacecast (%struct.spam addrspace(3)* @[[SHADOW]] to %struct.spam*), %struct.quux* byval(%struct.quux) [[TMP2:%.*]]) +; CHECK: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) #0 +; CHECK: [[TMP11:%.*]] = bitcast %struct.spam* %arg1 to i8* +; CHECK: call void @llvm.memcpy.p0i8.p3i8.i64(i8* [[TMP11:%.*]], i8 addrspace(3)* align 16 bitcast (%struct.spam addrspace(3)* @[[SHADOW]] to i8 +; CHECK: call void @widget(%struct.spam* %arg1, %struct.quux* byval(%struct.quux) [[TMP2:%.*]]) ; bb: %tmp = alloca %struct.baz* diff --git a/llvm/test/SYCLLowerIR/pfwg_and_pfwi.ll b/llvm/test/SYCLLowerIR/pfwg_and_pfwi.ll index 09ba788316dee..615d11190d4cc 100644 --- a/llvm/test/SYCLLowerIR/pfwg_and_pfwi.ll +++ b/llvm/test/SYCLLowerIR/pfwg_and_pfwi.ll @@ -35,13 +35,17 @@ define internal spir_func void @wibble(%struct.bar addrspace(4)* %arg, %struct.z ; CHECK-NEXT: br label [[MERGE]] ; CHECK: merge: ; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) #0 -; CHECK-NEXT: [[TMP3:%.*]] = load i64, i64 addrspace(1)* @__spirv_BuiltInLocalInvocationIndex -; CHECK-NEXT: [[CMPZ:%.*]] = icmp eq i64 [[TMP3]], 0 +; CHECK-NEXT: [[TMP3:%.*]] = bitcast %struct.zot* [[ARG1]] to i8* +; CHECK-NEXT: call void @llvm.memcpy.p0i8.p3i8.i64(i8* align 8 [[TMP3]], i8 addrspace(3)* align 16 bitcast (%struct.zot addrspace(3)* @[[GROUP_SHADOW]] to i8 addrspace(3)*), i64 96, i1 false) +; CHECK-NEXT: [[TMP4:%.*]] = bitcast [[STRUCT_BAR]] addrspace(4)* [[ARG]] to i8 addrspace(4)* +; CHECK-NEXT: call void @llvm.memcpy.p4i8.p3i8.i64(i8 addrspace(4)* align 8 [[TMP4]], i8 addrspace(3)* align 8 getelementptr inbounds (%struct.bar, [[STRUCT_BAR]] addrspace(3)* @[[PFWG_SHADOW]], i32 0, i32 0), i64 1, i1 false) +; CHECK-NEXT: [[TMP5:%.*]] = load i64, i64 addrspace(1)* @__spirv_BuiltInLocalInvocationIndex +; CHECK-NEXT: [[CMPZ:%.*]] = icmp eq i64 [[TMP5]], 0 ; CHECK-NEXT: br i1 [[CMPZ]], label [[WG_LEADER:%.*]], label [[WG_CF:%.*]] ; CHECK: wg_leader: -; CHECK-NEXT: store [[STRUCT_BAR]] addrspace(4)* addrspacecast (%struct.bar addrspace(3)* @[[PFWG_SHADOW]] to [[STRUCT_BAR]] addrspace(4)*), [[STRUCT_BAR]] addrspace(4)** [[TMP]], align 8 +; CHECK-NEXT: store [[STRUCT_BAR]] addrspace(4)* [[ARG]], [[STRUCT_BAR]] addrspace(4)** [[TMP]], align 8 ; CHECK-NEXT: [[TMP3:%.*]] = load [[STRUCT_BAR]] addrspace(4)*, [[STRUCT_BAR]] addrspace(4)** [[TMP]], align 8 -; CHECK-NEXT: [[TMP4:%.*]] = addrspacecast [[STRUCT_ZOT:%.*]] addrspace(3)* @[[GROUP_SHADOW]] to [[STRUCT_ZOT]] addrspace(4)* +; CHECK-NEXT: [[TMP4:%.*]] = addrspacecast %struct.zot* [[ARG1]] to [[STRUCT_ZOT:%.*]] addrspace(4)* ; CHECK-NEXT: store [[STRUCT_ZOT]] addrspace(4)* [[TMP4]], [[STRUCT_ZOT]] addrspace(4)* addrspace(3)* @[[GROUP_SHADOW_PTR]] ; CHECK-NEXT: br label [[WG_CF]] ; CHECK: wg_cf: