-
Notifications
You must be signed in to change notification settings - Fork 12.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Auto merge of #55073 - alexcrichton:demote-simd, r=<try>
rustc: Fix (again) simd vectors by-val in ABI The issue of passing around SIMD types as values between functions has seen [quite a lot] of [discussion], and although we thought [we fixed it][quite a lot] it [wasn't]! This PR is a change to rustc to, again, try to fix this issue. The fundamental problem here remains the same, if a SIMD vector argument is passed by-value in LLVM's function type, then if the caller and callee disagree on target features a miscompile happens. We solve this by never passing SIMD vectors by-value, but LLVM will still thwart us with its argument promotion pass to promote by-ref SIMD arguments to by-val SIMD arguments. This commit is an attempt to thwart LLVM thwarting us. We, just before codegen, will take yet another look at the LLVM module and demote any by-value SIMD arguments we see. This is a very manual attempt by us to ensure the codegen for a module keeps working, and it unfortunately is likely producing suboptimal code, even in release mode. The saving grace for this, in theory, is that if SIMD types are passed by-value across a boundary in release mode it's pretty unlikely to be performance sensitive (as it's already doing a load/store, and otherwise perf-sensitive bits should be inlined). The implementation here is basically a big wad of C++. It was largely copied from LLVM's own argument promotion pass, only doing the reverse. In local testing this... Closes #50154 Closes #52636 Closes #54583 Closes #55059 [quite a lot]: #47743 [discussion]: #44367 [wasn't]: #50154
- Loading branch information
Showing
10 changed files
with
330 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT | ||
// file at the top-level directory of this distribution and at | ||
// http://rust-lang.org/COPYRIGHT. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
#include <vector> | ||
#include <set> | ||
|
||
#include "rustllvm.h" | ||
|
||
#include "llvm/IR/CallSite.h" | ||
#include "llvm/IR/Module.h" | ||
|
||
using namespace llvm; | ||
|
||
static std::vector<Function*> | ||
GetFunctionsWithSimdArgs(Module *M) { | ||
std::vector<Function*> Ret; | ||
|
||
for (auto &F : M->functions()) { | ||
// Skip all intrinsic calls as these are always tightly controlled to "work | ||
// correctly", so no need to fixup any of these. | ||
if (F.isIntrinsic()) | ||
continue; | ||
|
||
// We know that we started out with a `Module` that has what we want, so | ||
// we're just trying to undo specifically the work of the | ||
// `ArgumentPromotion` pass. That only runs in a select few circumstances, | ||
// so make sure that we don't get anything surprising. For example, make | ||
// sure we don't actually return a vector type because rustc shouldn't ever | ||
// generate this and nor should passes make this happen. | ||
assert(!F->getReturnType()->isVectorTy()); | ||
|
||
// If any argument to this function is a by-value vector type, then that's | ||
// bad! The compiler didn't generate any functions that looked like this, | ||
// and we try to rely on LLVM to not do this! Argument promotion may, | ||
// however, promote arguments from behind references. In any case, figure | ||
// out if we're interested in demoting this argument. | ||
bool anyVector = false; | ||
for (auto &Arg : F.args()) | ||
anyVector = anyVector || Arg.getType()->isVectorTy(); | ||
if (anyVector) | ||
Ret.push_back(&F); | ||
} | ||
|
||
return Ret; | ||
} | ||
|
||
extern "C" void | ||
LLVMRustDemoteSimdArguments(LLVMModuleRef Mod) { | ||
Module *M = unwrap(Mod); | ||
|
||
auto Functions = GetFunctionsWithSimdArgs(M); | ||
|
||
for (auto F : Functions) { | ||
// The argument promotion pass in LLVM should only run on functions that | ||
// have local linkage. We're modifying function signatures here, so make | ||
// sure such a desctructive change doesn't affect the public ABI. | ||
assert(F->hasLocalLinkage()); | ||
|
||
// Build up our list of new parameters and new argument attributes. | ||
// We're only changing those arguments which are vector types. | ||
SmallVector<Type*, 8> Params; | ||
SmallVector<AttributeSet, 8> ArgAttrVec; | ||
auto PAL = F->getAttributes(); | ||
for (auto &Arg : F->args()) { | ||
auto *Ty = Arg.getType(); | ||
if (Ty->isVectorTy()) { | ||
Params.push_back(PointerType::get(Ty, 0)); | ||
ArgAttrVec.push_back(AttributeSet()); | ||
} else { | ||
Params.push_back(Ty); | ||
ArgAttrVec.push_back(PAL.getParamAttributes(Arg.getArgNo())); | ||
} | ||
} | ||
|
||
// Replace `F` with a new function with our new signature. I'm... not really | ||
// sure how this works, but this is all the steps `ArgumentPromotion` does | ||
// to replace a signature as well. | ||
assert(!F->isVarArg()); // ArgumentPromotion should skip these fns | ||
FunctionType *NFTy = FunctionType::get(F->getReturnType(), Params, false); | ||
Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName()); | ||
NF->copyAttributesFrom(F); | ||
NF->setSubprogram(F->getSubprogram()); | ||
F->setSubprogram(nullptr); | ||
NF->setAttributes(AttributeList::get(F->getContext(), | ||
PAL.getFnAttributes(), | ||
PAL.getRetAttributes(), | ||
ArgAttrVec)); | ||
ArgAttrVec.clear(); | ||
F->getParent()->getFunctionList().insert(F->getIterator(), NF); | ||
NF->takeName(F); | ||
|
||
// Iterate over all invocations of `F`, updating all `call` instructions to | ||
// store immediate vector types in a local `alloc` instead of a by-value | ||
// vector. | ||
// | ||
// Like before, much of this is copied from the `ArgumentPromotion` pass in | ||
// LLVM. | ||
SmallVector<Value*, 16> Args; | ||
while (!F->use_empty()) { | ||
CallSite CS(F->user_back()); | ||
assert(CS.getCalledFunction() == F); | ||
Instruction *Call = CS.getInstruction(); | ||
const AttributeList &CallPAL = CS.getAttributes(); | ||
|
||
// Loop over the operands, inserting an `alloca` and a store for any | ||
// argument we're demoting to be by reference | ||
// | ||
// FIXME: we probably want to figure out an LLVM pass to run and clean up | ||
// this function and instructions we're generating, we should in theory | ||
// only generate a maximum number of `alloca` instructions rather than | ||
// one-per-variable unconditionally. | ||
CallSite::arg_iterator AI = CS.arg_begin(); | ||
size_t ArgNo = 0; | ||
for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; | ||
++I, ++AI, ++ArgNo) { | ||
if (I->getType()->isVectorTy()) { | ||
AllocaInst *AllocA = new AllocaInst(I->getType(), 0, nullptr, "", Call); | ||
new StoreInst(*AI, AllocA, Call); | ||
Args.push_back(AllocA); | ||
ArgAttrVec.push_back(AttributeSet()); | ||
} else { | ||
Args.push_back(*AI); | ||
ArgAttrVec.push_back(CallPAL.getParamAttributes(ArgNo)); | ||
} | ||
} | ||
assert(AI == CS.arg_end()); | ||
|
||
// Create a new call instructions which we'll use to replace the old call | ||
// instruction, copying over as many attributes and such as possible. | ||
SmallVector<OperandBundleDef, 1> OpBundles; | ||
CS.getOperandBundlesAsDefs(OpBundles); | ||
|
||
CallSite NewCS; | ||
if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) { | ||
InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), | ||
Args, OpBundles, "", Call); | ||
} else { | ||
auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", Call); | ||
NewCall->setTailCallKind(cast<CallInst>(Call)->getTailCallKind()); | ||
NewCS = NewCall; | ||
} | ||
NewCS.setCallingConv(CS.getCallingConv()); | ||
NewCS.setAttributes( | ||
AttributeList::get(F->getContext(), CallPAL.getFnAttributes(), | ||
CallPAL.getRetAttributes(), ArgAttrVec)); | ||
NewCS->setDebugLoc(Call->getDebugLoc()); | ||
Args.clear(); | ||
ArgAttrVec.clear(); | ||
Call->replaceAllUsesWith(NewCS.getInstruction()); | ||
NewCS->takeName(Call); | ||
Call->eraseFromParent(); | ||
} | ||
|
||
// Splice the body of the old function right into the new function. | ||
NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList()); | ||
|
||
// Update our new function to replace all uses of the by-value argument with | ||
// loads of the pointer argument we've generated. | ||
// | ||
// FIXME: we probably want to only generate one load instruction per | ||
// function? Or maybe run an LLVM pass to clean up this function? | ||
for (Function::arg_iterator I = F->arg_begin(), | ||
E = F->arg_end(), | ||
I2 = NF->arg_begin(); | ||
I != E; | ||
++I, ++I2) { | ||
if (I->getType()->isVectorTy()) { | ||
I->replaceAllUsesWith(new LoadInst(&*I2, "", &NF->begin()->front())); | ||
} else { | ||
I->replaceAllUsesWith(&*I2); | ||
} | ||
I2->takeName(&*I); | ||
} | ||
|
||
// Delete all references to the old function, it should be entirely dead | ||
// now. | ||
M->getFunctionList().remove(F); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
13 changes: 13 additions & 0 deletions
13
src/test/run-make/simd-argument-promotion-thwarted/Makefile
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
-include ../../run-make-fulldeps/tools.mk | ||
|
||
# These tests use x86_64 SIMD intrinsics, so they are only built and run when
# targeting x86_64 Linux; on every other target `all` is a no-op.
#
# Each test is compiled at opt-level=3 and the resulting binary is then
# executed; a failing assertion in the binary fails the recipe.
ifeq ($(TARGET),x86_64-unknown-linux-gnu)
all:
	$(RUSTC) t1.rs -C opt-level=3
	$(TMPDIR)/t1
	$(RUSTC) t2.rs -C opt-level=3
	$(TMPDIR)/t2
	$(RUSTC) t3.rs -C opt-level=3
	$(TMPDIR)/t3
else
all:
endif
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
use std::arch::x86_64; | ||
|
||
// Round-trips 32 bytes through a 256-bit SIMD value and checks that the | ||
// bytes come back unchanged. | ||
// | ||
// The exact shape of this program is deliberate: it is a regression test | ||
// built at opt-level=3, so avoid restructuring it. | ||
fn main() { | ||
// Bail out (successfully) when the CPU running the test lacks AVX2. | ||
if !is_x86_feature_detected!("avx2") { | ||
return println!("AVX2 is not supported on this machine/build."); | ||
} | ||
let load_bytes: [u8; 32] = [0x0f; 32]; | ||
let lb_ptr = load_bytes.as_ptr(); | ||
// Read the 32-byte array into a `__m256i`; the pointer comes from a live | ||
// 32-byte array, matching the 256-bit width of the vector. | ||
let reg_load = unsafe { | ||
x86_64::_mm256_loadu_si256( | ||
lb_ptr as *const x86_64::__m256i | ||
) | ||
}; | ||
// Debug-print the vector, which hands the SIMD value to formatting code. | ||
println!("{:?}", reg_load); | ||
let mut store_bytes: [u8; 32] = [0; 32]; | ||
let sb_ptr = store_bytes.as_mut_ptr(); | ||
// Write the vector back out into the zeroed 32-byte destination array. | ||
unsafe { | ||
x86_64::_mm256_storeu_si256(sb_ptr as *mut x86_64::__m256i, reg_load); | ||
} | ||
// The store must reproduce exactly the bytes that were loaded. | ||
assert_eq!(load_bytes, store_bytes); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
use std::arch::x86_64::*; | ||
|
||
// Multiplies a vector of four 2.0 lanes by itself and checks every lane of | ||
// the product is 4.0. | ||
// | ||
// The exact shape of this program is deliberate: it is a regression test | ||
// built at opt-level=3, so avoid restructuring it. | ||
fn main() { | ||
// Bail out (successfully) when the CPU running the test lacks AVX. | ||
if !is_x86_feature_detected!("avx") { | ||
return println!("AVX is not supported on this machine/build."); | ||
} | ||
unsafe { | ||
let f = _mm256_set_pd(2.0, 2.0, 2.0, 2.0); | ||
let r = _mm256_mul_pd(f, f); | ||
|
||
// Reinterpret the `__m256d` as four `f64` lanes through a union so the | ||
// lanes can be compared against a plain array. | ||
union A { a: __m256d, b: [f64; 4] } | ||
assert_eq!(A { a: r }.b, [4.0, 4.0, 4.0, 4.0]); | ||
} | ||
} |
Oops, something went wrong.