Skip to content

Commit

Permalink
Fix SPIRVRegularizeLLVMBase::regularize fix for shl i1 and lshr i1 (#…
Browse files Browse the repository at this point in the history
…2288)

The translator failed an assertion on V->user_empty() in the regularize function when the result of shl i1 or lshr i1 is used. E.g.

%2 = shl i1 %0 %1
store %2, ptr addrspace(1) @G.1, align 1

The shl i1 instruction is converted to lshr i32, which has the same arithmetic behavior.
  • Loading branch information
bwlodarcz authored Jan 30, 2024
1 parent e8b2018 commit 239fbd4
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 36 deletions.
81 changes: 45 additions & 36 deletions lib/SPIRV/SPIRVRegularizeLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,31 +322,6 @@ void SPIRVRegularizeLLVMBase::expandSYCLTypeUsing(Module *M) {
expandVIDWithSYCLTypeByValComp(F);
}

/// Re-emits the shift \p II on operands zero-extended to 32-bit integers.
///
/// Both operands are zero-extended to i32 (or to a fixed vector of i32 with
/// the same element count as the first operand) using a builder positioned at
/// \p II. For lshr/shl the matching shift on the widened operands is created
/// and returned. For any other opcode \p II itself is returned; the two
/// zero-extensions have already been requested from the builder by then.
Value *SPIRVRegularizeLLVMBase::extendBitInstBoolArg(Instruction *II) {
  IRBuilder<> Builder(II);
  Value *Base = II->getOperand(0);
  Value *Amount = II->getOperand(1);
  Type *const SrcTy = Base->getType();
  Type *const Int32Ty = Builder.getInt32Ty();

  // Choose the i32-based type that mirrors the shape of the first operand.
  Type *WideTy = nullptr;
  if (SrcTy->isIntegerTy()) {
    WideTy = Int32Ty;
  } else {
    const bool IsIntVector =
        SrcTy->isVectorTy() &&
        cast<VectorType>(SrcTy)->getElementType()->isIntegerTy();
    if (!IsIntVector)
      llvm_unreachable("Unexpected type");
    const unsigned Lanes = cast<FixedVectorType>(SrcTy)->getNumElements();
    WideTy = VectorType::get(Int32Ty, Lanes, false);
  }

  // Widen the shifted value first, then the shift amount.
  Value *WideBase = Builder.CreateZExt(Base, WideTy);
  Value *WideAmount = Builder.CreateZExt(Amount, WideTy);

  const unsigned Opcode = II->getOpcode();
  if (Opcode == Instruction::LShr)
    return Builder.CreateLShr(WideBase, WideAmount);
  if (Opcode == Instruction::Shl)
    return Builder.CreateShl(WideBase, WideAmount);
  return II;
}

bool SPIRVRegularizeLLVMBase::runRegularizeLLVM(Module &Module) {
M = &Module;
Ctx = &M->getContext();
Expand Down Expand Up @@ -393,19 +368,53 @@ bool SPIRVRegularizeLLVMBase::regularize() {
}
}

// Translator treats i1 as boolean, but bit instructions take
// a scalar/vector integers, so we have to extend such arguments
if (II.isLogicalShift() &&
II.getOperand(0)->getType()->isIntOrIntVectorTy(1)) {
auto *NewInst = extendBitInstBoolArg(&II);
for (auto *U : II.users()) {
if (cast<Instruction>(U)->getOpcode() == Instruction::ZExt) {
U->dropAllReferences();
U->replaceAllUsesWith(NewInst);
ToErase.push_back(cast<Instruction>(U));
if (II.isLogicalShift()) {
// Translator treats i1 as boolean, but bit instructions take
// a scalar/vector integers, so we have to extend such arguments.
// shl i1 %a %b and lshr i1 %a %b are now converted to:
// %0 = select i1 %a, i32 1, i32 0
// %1 = select i1 %b, i32 1, i32 0
// %2 = lshr i32 %0, %1
// If any instruction other than zext depends on the result:
// %3 = icmp ne i32 %2, 0
// converts it back to i1, and %3 replaces the original result in
// those dependent instructions.
if (II.getOperand(0)->getType()->isIntOrIntVectorTy(1)) {
IRBuilder<> Builder(&II);
Value *CmpNEInst = nullptr;
Constant *ConstZero = ConstantInt::get(Builder.getInt32Ty(), 0);
Constant *ConstOne = ConstantInt::get(Builder.getInt32Ty(), 1);
if (auto *VecTy =
dyn_cast<FixedVectorType>(II.getOperand(0)->getType())) {
const unsigned NumElements = VecTy->getNumElements();
ConstZero = ConstantVector::getSplat(
ElementCount::getFixed(NumElements), ConstZero);
ConstOne = ConstantVector::getSplat(
ElementCount::getFixed(NumElements), ConstOne);
}
Value *ExtendedBase =
Builder.CreateSelect(II.getOperand(0), ConstOne, ConstZero);
Value *ExtendedShift =
Builder.CreateSelect(II.getOperand(1), ConstOne, ConstZero);
Value *ExtendedShiftedVal =
Builder.CreateLShr(ExtendedBase, ExtendedShift);
SmallVector<User *, 8> Users(II.users());
for (User *U : Users) {
if (auto *UI = dyn_cast<Instruction>(U)) {
if (UI->getOpcode() == Instruction::ZExt) {
UI->dropAllReferences();
UI->replaceAllUsesWith(ExtendedShiftedVal);
ToErase.push_back(UI);
continue;
}
}
if (!CmpNEInst) {
CmpNEInst = Builder.CreateICmpNE(ExtendedShiftedVal, ConstZero);
}
U->replaceUsesOfWith(&II, CmpNEInst);
}
ToErase.push_back(&II);
}
ToErase.push_back(&II);
}

// Remove optimization info not supported by SPIRV
Expand Down
57 changes: 57 additions & 0 deletions test/lshr_shl_i1_regularize.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv -s %t.bc -o %t.reg.bc
; RUN: llvm-dis %t.reg.bc -o - | FileCheck --check-prefix=CHECK-LLVM %s

target triple = "spir64-unknown-unknown"

@G.0 = addrspace(1) global i1 false
@G.1 = addrspace(1) global i1 true
@G.2 = addrspace(1) global <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>

; Scalar lshr on i1 whose result has two zext users and one store user.
; Expected output: both operands widened with select, an i32 lshr, and an
; icmp ne that feeds the store in place of the original i1 result.
define spir_func void @test_lshr_i1(i1 %a, i1 %b) {
entry:
%0 = lshr i1 %a, %b
; CHECK-LLVM: [[AI32_0:%[0-9]+]] = select i1 %a, i32 1, i32 0
; CHECK-LLVM: [[BI32_0:%[0-9]+]] = select i1 %b, i32 1, i32 0
; CHECK-LLVM: [[LSHRI32_0:%[0-9]+]] = lshr i32 [[AI32_0]], [[BI32_0]]
; CHECK-LLVM: [[TRUNC_0:%[0-9]+]] = icmp ne i32 [[LSHRI32_0]], 0
%1 = zext i1 %0 to i32
%2 = zext i1 %0 to i32
; The negative directives need the trailing colon; without it FileCheck
; treats the line as an ordinary comment and the check never runs.
; CHECK-LLVM-NOT: zext
; CHECK-LLVM-NOT: select
store i1 %0, ptr addrspace(1) @G.0, align 1
; CHECK-LLVM: store i1 [[TRUNC_0]], ptr addrspace(1) @G.0, align 1
ret void
}

; Scalar shl on i1 whose result has two zext users and one store user.
; Expected output: both operands widened with select, an i32 lshr (not shl),
; and an icmp ne that feeds the store in place of the original i1 result.
define spir_func void @test_shl_i1(i1 %a, i1 %b) {
entry:
%0 = shl i1 %a, %b
; CHECK-LLVM: [[AI32_1:%[0-9]+]] = select i1 %a, i32 1, i32 0
; CHECK-LLVM: [[BI32_1:%[0-9]+]] = select i1 %b, i32 1, i32 0
; CHECK-LLVM: [[LSHR32_1:%[0-9]+]] = lshr i32 [[AI32_1]], [[BI32_1]]
; CHECK-LLVM: [[TRUNC_1:%[0-9]+]] = icmp ne i32 [[LSHR32_1]], 0
%1 = zext i1 %0 to i32
%2 = zext i1 %0 to i32
; No zext or further select may appear between the icmp and the store.
; CHECK-LLVM-NOT: zext
; CHECK-LLVM-NOT: select
store i1 %0, ptr addrspace(1) @G.1, align 1
; CHECK-LLVM: store i1 [[TRUNC_1]], ptr addrspace(1) @G.1, align 1
ret void
}

; Vector variant: shl on <8 x i1> with two zext users and one store user.
; Expected output: operands widened to <8 x i32> with select against splat
; constants, an <8 x i32> lshr, and an icmp ne against zeroinitializer that
; feeds the store.
define spir_func void @test_shl_vec_i1(<8 x i1> %a, <8 x i1> %b) {
entry:
%0 = shl <8 x i1> %a, %b
; CHECK-LLVM: [[AI32_2:%[0-9]+]] = select <8 x i1> %a, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, <8 x i32> zeroinitializer
; CHECK-LLVM: [[BI32_2:%[0-9]+]] = select <8 x i1> %b, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, <8 x i32> zeroinitializer
; CHECK-LLVM: [[LSHR32_2:%[0-9]+]] = lshr <8 x i32> [[AI32_2]], [[BI32_2]]
; CHECK-LLVM: [[TRUNC_2:%[0-9]+]] = icmp ne <8 x i32> [[LSHR32_2]], zeroinitializer
%1 = zext <8 x i1> %0 to <8 x i32>
%2 = zext <8 x i1> %0 to <8 x i32>
; No zext or further select may appear between the icmp and the store.
; CHECK-LLVM-NOT: zext
; CHECK-LLVM-NOT: select
store <8 x i1> %0, ptr addrspace(1) @G.2, align 1
; CHECK-LLVM: store <8 x i1> [[TRUNC_2]], ptr addrspace(1) @G.2, align 1
ret void
}

0 comments on commit 239fbd4

Please sign in to comment.