From 239fbd4e1790257e7c919b26870f23a23b0d39ef Mon Sep 17 00:00:00 2001 From: bwlodarcz Date: Tue, 30 Jan 2024 13:29:06 +0100 Subject: [PATCH] Fix SPIRVRegularizeLLVMBase::regularize fix for shl i1 and lshr i1 (#2288) The translator failed the V->user_empty() assertion in the regularize function when the result of shl i1 or lshr i1 is used. E.g. %2 = shl i1 %0 %1 store %2, ptr addrspace(1) @G.1, align 1 The shl i1 instruction is converted to lshr i32, whose arithmetic has the same behavior. --- lib/SPIRV/SPIRVRegularizeLLVM.cpp | 81 +++++++++++++++++-------------- test/lshr_shl_i1_regularize.ll | 57 ++++++++++++++++++++++ 2 files changed, 102 insertions(+), 36 deletions(-) create mode 100644 test/lshr_shl_i1_regularize.ll diff --git a/lib/SPIRV/SPIRVRegularizeLLVM.cpp b/lib/SPIRV/SPIRVRegularizeLLVM.cpp index 822479e22a..13b45ca6ad 100644 --- a/lib/SPIRV/SPIRVRegularizeLLVM.cpp +++ b/lib/SPIRV/SPIRVRegularizeLLVM.cpp @@ -322,31 +322,6 @@ void SPIRVRegularizeLLVMBase::expandSYCLTypeUsing(Module *M) { expandVIDWithSYCLTypeByValComp(F); } -Value *SPIRVRegularizeLLVMBase::extendBitInstBoolArg(Instruction *II) { - IRBuilder<> Builder(II); - auto *ArgTy = II->getOperand(0)->getType(); - Type *NewArgType = nullptr; - if (ArgTy->isIntegerTy()) { - NewArgType = Builder.getInt32Ty(); - } else if (ArgTy->isVectorTy() && - cast(ArgTy)->getElementType()->isIntegerTy()) { - unsigned NumElements = cast(ArgTy)->getNumElements(); - NewArgType = VectorType::get(Builder.getInt32Ty(), NumElements, false); - } else { - llvm_unreachable("Unexpected type"); - } - auto *NewBase = Builder.CreateZExt(II->getOperand(0), NewArgType); - auto *NewShift = Builder.CreateZExt(II->getOperand(1), NewArgType); - switch (II->getOpcode()) { - case Instruction::LShr: - return Builder.CreateLShr(NewBase, NewShift); - case Instruction::Shl: - return Builder.CreateShl(NewBase, NewShift); - default: - return II; - } -} - bool SPIRVRegularizeLLVMBase::runRegularizeLLVM(Module &Module) { M = &Module; Ctx = 
&M->getContext(); @@ -393,19 +368,53 @@ bool SPIRVRegularizeLLVMBase::regularize() { } } - // Translator treats i1 as boolean, but bit instructions take - // a scalar/vector integers, so we have to extend such arguments - if (II.isLogicalShift() && - II.getOperand(0)->getType()->isIntOrIntVectorTy(1)) { - auto *NewInst = extendBitInstBoolArg(&II); - for (auto *U : II.users()) { - if (cast(U)->getOpcode() == Instruction::ZExt) { - U->dropAllReferences(); - U->replaceAllUsesWith(NewInst); - ToErase.push_back(cast(U)); + if (II.isLogicalShift()) { + // The translator treats i1 as boolean, but bit instructions take + // scalar/vector integers, so we have to extend such arguments. + // shl i1 %a, %b and lshr i1 %a, %b are both converted to: + // %0 = select i1 %a, i32 1, i32 0 + // %1 = select i1 %b, i32 1, i32 0 + // %2 = lshr i32 %0, %1 + // If any user other than a zext depends on the result, also emit: + // %3 = icmp ne i32 %2, 0 + // which converts it back to i1; %3 then replaces the original result + // in those dependent instructions. 
+ if (II.getOperand(0)->getType()->isIntOrIntVectorTy(1)) { + IRBuilder<> Builder(&II); + Value *CmpNEInst = nullptr; + Constant *ConstZero = ConstantInt::get(Builder.getInt32Ty(), 0); + Constant *ConstOne = ConstantInt::get(Builder.getInt32Ty(), 1); + if (auto *VecTy = + dyn_cast(II.getOperand(0)->getType())) { + const unsigned NumElements = VecTy->getNumElements(); + ConstZero = ConstantVector::getSplat( + ElementCount::getFixed(NumElements), ConstZero); + ConstOne = ConstantVector::getSplat( + ElementCount::getFixed(NumElements), ConstOne); + } + Value *ExtendedBase = + Builder.CreateSelect(II.getOperand(0), ConstOne, ConstZero); + Value *ExtendedShift = + Builder.CreateSelect(II.getOperand(1), ConstOne, ConstZero); + Value *ExtendedShiftedVal = + Builder.CreateLShr(ExtendedBase, ExtendedShift); + SmallVector Users(II.users()); + for (User *U : Users) { + if (auto *UI = dyn_cast(U)) { + if (UI->getOpcode() == Instruction::ZExt) { + UI->dropAllReferences(); + UI->replaceAllUsesWith(ExtendedShiftedVal); + ToErase.push_back(UI); + continue; + } + } + if (!CmpNEInst) { + CmpNEInst = Builder.CreateICmpNE(ExtendedShiftedVal, ConstZero); + } + U->replaceUsesOfWith(&II, CmpNEInst); } + ToErase.push_back(&II); } - ToErase.push_back(&II); } // Remove optimization info not supported by SPIRV diff --git a/test/lshr_shl_i1_regularize.ll b/test/lshr_shl_i1_regularize.ll new file mode 100644 index 0000000000..116234646b --- /dev/null +++ b/test/lshr_shl_i1_regularize.ll @@ -0,0 +1,57 @@ +; RUN: llvm-as %s -o %t.bc +; RUN: llvm-spirv -s %t.bc -o %t.reg.bc +; RUN: llvm-dis %t.reg.bc -o - | FileCheck --check-prefix=CHECK-LLVM %s + +target triple = "spir64-unknown-unknown" + +@G.0 = addrspace(1) global i1 false +@G.1 = addrspace(1) global i1 true +@G.2 = addrspace(1) global <8 x i1> + +define spir_func void @test_lshr_i1(i1 %a, i1 %b) { +entry: + %0 = lshr i1 %a, %b +; CHECK-LLVM: [[AI32_0:%[0-9]+]] = select i1 %a, i32 1, i32 0 +; CHECK-LLVM: [[BI32_0:%[0-9]+]] = select i1 %b, i32 
1, i32 0 +; CHECK-LLVM: [[LSHRI32_0:%[0-9]+]] = lshr i32 [[AI32_0]], [[BI32_0]] +; CHECK-LLVM: [[TRUNC_0:%[0-9]+]] = icmp ne i32 [[LSHRI32_0]], 0 + %1 = zext i1 %0 to i32 + %2 = zext i1 %0 to i32 +; CHECK-LLVM-NOT: zext +; CHECK-LLVM-NOT: select + store i1 %0, ptr addrspace(1) @G.0, align 1 +; CHECK-LLVM: store i1 [[TRUNC_0]], ptr addrspace(1) @G.0, align 1 + ret void +} + +define spir_func void @test_shl_i1(i1 %a, i1 %b) { +entry: + %0 = shl i1 %a, %b +; CHECK-LLVM: [[AI32_1:%[0-9]+]] = select i1 %a, i32 1, i32 0 +; CHECK-LLVM: [[BI32_1:%[0-9]+]] = select i1 %b, i32 1, i32 0 +; CHECK-LLVM: [[LSHR32_1:%[0-9]+]] = lshr i32 [[AI32_1]], [[BI32_1]] +; CHECK-LLVM: [[TRUNC_1:%[0-9]+]] = icmp ne i32 [[LSHR32_1]], 0 + %1 = zext i1 %0 to i32 + %2 = zext i1 %0 to i32 +; CHECK-LLVM-NOT: zext +; CHECK-LLVM-NOT: select + store i1 %0, ptr addrspace(1) @G.1, align 1 +; CHECK-LLVM: store i1 [[TRUNC_1]], ptr addrspace(1) @G.1, align 1 + ret void +} + +define spir_func void @test_shl_vec_i1(<8 x i1> %a, <8 x i1> %b) { +entry: + %0 = shl <8 x i1> %a, %b +; CHECK-LLVM: [[AI32_2:%[0-9]+]] = select <8 x i1> %a, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, <8 x i32> zeroinitializer +; CHECK-LLVM: [[BI32_2:%[0-9]+]] = select <8 x i1> %b, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, <8 x i32> zeroinitializer +; CHECK-LLVM: [[LSHR32_2:%[0-9]+]] = lshr <8 x i32> [[AI32_2]], [[BI32_2]] +; CHECK-LLVM: [[TRUNC_2:%[0-9]+]] = icmp ne <8 x i32> [[LSHR32_2]], zeroinitializer + %1 = zext <8 x i1> %0 to <8 x i32> + %2 = zext <8 x i1> %0 to <8 x i32> +; CHECK-LLVM-NOT: zext +; CHECK-LLVM-NOT: select + store <8 x i1> %0, ptr addrspace(1) @G.2, align 1 +; CHECK-LLVM: store <8 x i1> [[TRUNC_2]], ptr addrspace(1) @G.2, align 1 + ret void +} \ No newline at end of file