forked from rust-lang/rust
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rollup merge of rust-lang#55073 - alexcrichton:demote-simd, r=nagisa
rustc: Fix (again) simd vectors by-val in ABI The issue of passing around SIMD types as values between functions has seen [quite a lot] of [discussion], and although we thought [we fixed it][quite a lot] it [wasn't]! This PR is a change to rustc to, again, try to fix this issue. The fundamental problem here remains the same, if a SIMD vector argument is passed by-value in LLVM's function type, then if the caller and callee disagree on target features a miscompile happens. We solve this by never passing SIMD vectors by-value, but LLVM will still thwart us with its argument promotion pass to promote by-ref SIMD arguments to by-val SIMD arguments. This commit is an attempt to thwart LLVM thwarting us. We, just before codegen, will take yet another look at the LLVM module and demote any by-value SIMD arguments we see. This is a very manual attempt by us to ensure the codegen for a module keeps working, and it unfortunately is likely producing suboptimal code, even in release mode. The saving grace for this, in theory, is that if SIMD types are passed by-value across a boundary in release mode it's pretty unlikely to be performance sensitive (as it's already doing a load/store, and otherwise perf-sensitive bits should be inlined). The implementation here is basically a big wad of C++. It was largely copied from LLVM's own argument promotion pass, only doing the reverse. In local testing this... Closes rust-lang#50154 Closes rust-lang#52636 Closes rust-lang#54583 Closes rust-lang#55059 [quite a lot]: rust-lang#47743 [discussion]: rust-lang#44367 [wasn't]: rust-lang#50154
- Loading branch information
Showing
9 changed files
with
324 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT | ||
// file at the top-level directory of this distribution and at | ||
// http://rust-lang.org/COPYRIGHT. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
#include <vector> | ||
#include <set> | ||
|
||
#include "rustllvm.h" | ||
|
||
#include "llvm/IR/CallSite.h" | ||
#include "llvm/IR/Module.h" | ||
#include "llvm/ADT/STLExtras.h" | ||
|
||
using namespace llvm; | ||
|
||
static std::vector<Function*> | ||
GetFunctionsWithSimdArgs(Module *M) { | ||
std::vector<Function*> Ret; | ||
|
||
for (auto &F : M->functions()) { | ||
// Skip all intrinsic calls as these are always tightly controlled to "work | ||
// correctly", so no need to fixup any of these. | ||
if (F.isIntrinsic()) | ||
continue; | ||
|
||
// We're only interested in rustc-defined functions, not unstably-defined | ||
// imported SIMD ffi functions. | ||
if (F.isDeclaration()) | ||
continue; | ||
|
||
// Argument promotion only happens on internal functions, so skip demoting | ||
// arguments in external functions like FFI shims and such. | ||
if (!F.hasLocalLinkage()) | ||
continue; | ||
|
||
// If any argument to this function is a by-value vector type, then that's | ||
// bad! The compiler didn't generate any functions that looked like this, | ||
// and we try to rely on LLVM to not do this! Argument promotion may, | ||
// however, promote arguments from behind references. In any case, figure | ||
// out if we're interested in demoting this argument. | ||
if (any_of(F.args(), [](Argument &arg) { return arg.getType()->isVectorTy(); })) | ||
Ret.push_back(&F); | ||
} | ||
|
||
return Ret; | ||
} | ||
|
||
extern "C" void | ||
LLVMRustDemoteSimdArguments(LLVMModuleRef Mod) { | ||
Module *M = unwrap(Mod); | ||
|
||
auto Functions = GetFunctionsWithSimdArgs(M); | ||
|
||
for (auto F : Functions) { | ||
// Build up our list of new parameters and new argument attributes. | ||
// We're only changing those arguments which are vector types. | ||
SmallVector<Type*, 8> Params; | ||
SmallVector<AttributeSet, 8> ArgAttrVec; | ||
auto PAL = F->getAttributes(); | ||
for (auto &Arg : F->args()) { | ||
auto *Ty = Arg.getType(); | ||
if (Ty->isVectorTy()) { | ||
Params.push_back(PointerType::get(Ty, 0)); | ||
ArgAttrVec.push_back(AttributeSet()); | ||
} else { | ||
Params.push_back(Ty); | ||
ArgAttrVec.push_back(PAL.getParamAttributes(Arg.getArgNo())); | ||
} | ||
} | ||
|
||
// Replace `F` with a new function with our new signature. I'm... not really | ||
// sure how this works, but this is all the steps `ArgumentPromotion` does | ||
// to replace a signature as well. | ||
assert(!F->isVarArg()); // ArgumentPromotion should skip these fns | ||
FunctionType *NFTy = FunctionType::get(F->getReturnType(), Params, false); | ||
Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName()); | ||
NF->copyAttributesFrom(F); | ||
NF->setSubprogram(F->getSubprogram()); | ||
F->setSubprogram(nullptr); | ||
NF->setAttributes(AttributeList::get(F->getContext(), | ||
PAL.getFnAttributes(), | ||
PAL.getRetAttributes(), | ||
ArgAttrVec)); | ||
ArgAttrVec.clear(); | ||
F->getParent()->getFunctionList().insert(F->getIterator(), NF); | ||
NF->takeName(F); | ||
|
||
// Iterate over all invocations of `F`, updating all `call` instructions to | ||
// store immediate vector types in a local `alloc` instead of a by-value | ||
// vector. | ||
// | ||
// Like before, much of this is copied from the `ArgumentPromotion` pass in | ||
// LLVM. | ||
SmallVector<Value*, 16> Args; | ||
while (!F->use_empty()) { | ||
CallSite CS(F->user_back()); | ||
assert(CS.getCalledFunction() == F); | ||
Instruction *Call = CS.getInstruction(); | ||
const AttributeList &CallPAL = CS.getAttributes(); | ||
|
||
// Loop over the operands, inserting an `alloca` and a store for any | ||
// argument we're demoting to be by reference | ||
// | ||
// FIXME: we probably want to figure out an LLVM pass to run and clean up | ||
// this function and instructions we're generating, we should in theory | ||
// only generate a maximum number of `alloca` instructions rather than | ||
// one-per-variable unconditionally. | ||
CallSite::arg_iterator AI = CS.arg_begin(); | ||
size_t ArgNo = 0; | ||
for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; | ||
++I, ++AI, ++ArgNo) { | ||
if (I->getType()->isVectorTy()) { | ||
AllocaInst *AllocA = new AllocaInst(I->getType(), 0, nullptr, "", Call); | ||
new StoreInst(*AI, AllocA, Call); | ||
Args.push_back(AllocA); | ||
ArgAttrVec.push_back(AttributeSet()); | ||
} else { | ||
Args.push_back(*AI); | ||
ArgAttrVec.push_back(CallPAL.getParamAttributes(ArgNo)); | ||
} | ||
} | ||
assert(AI == CS.arg_end()); | ||
|
||
// Create a new call instructions which we'll use to replace the old call | ||
// instruction, copying over as many attributes and such as possible. | ||
SmallVector<OperandBundleDef, 1> OpBundles; | ||
CS.getOperandBundlesAsDefs(OpBundles); | ||
|
||
CallSite NewCS; | ||
if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) { | ||
InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), | ||
Args, OpBundles, "", Call); | ||
} else { | ||
auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", Call); | ||
NewCall->setTailCallKind(cast<CallInst>(Call)->getTailCallKind()); | ||
NewCS = NewCall; | ||
} | ||
NewCS.setCallingConv(CS.getCallingConv()); | ||
NewCS.setAttributes( | ||
AttributeList::get(F->getContext(), CallPAL.getFnAttributes(), | ||
CallPAL.getRetAttributes(), ArgAttrVec)); | ||
NewCS->setDebugLoc(Call->getDebugLoc()); | ||
Args.clear(); | ||
ArgAttrVec.clear(); | ||
Call->replaceAllUsesWith(NewCS.getInstruction()); | ||
NewCS->takeName(Call); | ||
Call->eraseFromParent(); | ||
} | ||
|
||
// Splice the body of the old function right into the new function. | ||
NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList()); | ||
|
||
// Update our new function to replace all uses of the by-value argument with | ||
// loads of the pointer argument we've generated. | ||
// | ||
// FIXME: we probably want to only generate one load instruction per | ||
// function? Or maybe run an LLVM pass to clean up this function? | ||
for (Function::arg_iterator I = F->arg_begin(), | ||
E = F->arg_end(), | ||
I2 = NF->arg_begin(); | ||
I != E; | ||
++I, ++I2) { | ||
if (I->getType()->isVectorTy()) { | ||
I->replaceAllUsesWith(new LoadInst(&*I2, "", &NF->begin()->front())); | ||
} else { | ||
I->replaceAllUsesWith(&*I2); | ||
} | ||
I2->takeName(&*I); | ||
} | ||
|
||
// Delete all references to the old function, it should be entirely dead | ||
// now. | ||
M->getFunctionList().remove(F); | ||
} | ||
} |
13 changes: 13 additions & 0 deletions
13
src/test/run-make/simd-argument-promotion-thwarted/Makefile
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
-include ../../run-make-fulldeps/tools.mk

# Only x86_64-unknown-linux-gnu builds and runs the test programs; on every
# other target `all` is an empty rule.
ifneq ($(TARGET),x86_64-unknown-linux-gnu)
all:
else
# Each program is compiled at opt-level=3 and then executed straight away.
all:
	$(RUSTC) t1.rs -C opt-level=3
	$(TMPDIR)/t1
	$(RUSTC) t2.rs -C opt-level=3
	$(TMPDIR)/t2
	$(RUSTC) t3.rs -C opt-level=3
	$(TMPDIR)/t3
endif
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
use std::arch::x86_64; | ||
|
||
/// Loads 32 bytes into a 256-bit integer vector, prints it, stores it back
/// into a second buffer and checks that both buffers hold the same bytes.
///
/// Does nothing except print a notice when AVX2 is not detected at runtime.
fn main() {
    if !is_x86_feature_detected!("avx2") {
        println!("AVX2 is not supported on this machine/build.");
        return;
    }

    let source = [0x0f_u8; 32];
    // Unaligned load: a `[u8; 32]` carries no 32-byte alignment guarantee.
    let loaded = unsafe {
        x86_64::_mm256_loadu_si256(source.as_ptr() as *const x86_64::__m256i)
    };
    println!("{:?}", loaded);

    let mut sink = [0_u8; 32];
    unsafe {
        x86_64::_mm256_storeu_si256(sink.as_mut_ptr() as *mut x86_64::__m256i, loaded);
    }

    assert_eq!(source, sink);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
use std::arch::x86_64::*; | ||
|
||
/// Squares a vector of four `2.0` doubles with AVX and checks that every lane
/// equals `4.0`.
///
/// Does nothing except print a notice when AVX is not detected at runtime.
fn main() {
    /// Reinterprets a 256-bit double vector as its four lanes.
    union Lanes {
        vector: __m256d,
        values: [f64; 4],
    }

    if !is_x86_feature_detected!("avx") {
        println!("AVX is not supported on this machine/build.");
        return;
    }

    unsafe {
        let twos = _mm256_set_pd(2.0, 2.0, 2.0, 2.0);
        let squared = _mm256_mul_pd(twos, twos);

        assert_eq!(Lanes { vector: squared }.values, [4.0, 4.0, 4.0, 4.0]);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
use std::arch::x86_64::*; | ||
|
||
/// Multiplies two 256-bit single-precision vectors via `_mm256_mul_ps`.
///
/// # Safety
///
/// Compiled with the `avx` target feature enabled; the caller must make sure
/// the running CPU supports AVX.
#[target_feature(enable = "avx")]
unsafe fn avx_mul(a: __m256, b: __m256) -> __m256 {
    let product = _mm256_mul_ps(a, b);
    product
}
|
||
/// Writes the eight lanes of `a` to `p` via the unaligned `_mm256_storeu_ps`.
///
/// # Safety
///
/// Compiled with the `avx` target feature enabled; the caller must make sure
/// the running CPU supports AVX and that `p` is valid for writing eight `f32`
/// values.
#[target_feature(enable = "avx")]
unsafe fn avx_store(p: *mut f32, a: __m256) {
    _mm256_storeu_ps(p, a);
}
|
||
/// Builds a 256-bit single-precision vector from eight scalars via
/// `_mm256_setr_ps`, passing them through in the order given.
///
/// # Safety
///
/// Compiled with the `avx` target feature enabled; the caller must make sure
/// the running CPU supports AVX.
#[target_feature(enable = "avx")]
unsafe fn avx_setr(a: f32, b: f32, c: f32, d: f32, e: f32, f: f32, g: f32, h: f32) -> __m256 {
    let packed = _mm256_setr_ps(a, b, c, d, e, f, g, h);
    packed
}
|
||
/// Builds a 256-bit single-precision vector from one scalar via
/// `_mm256_set1_ps`.
///
/// # Safety
///
/// Compiled with the `avx` target feature enabled; the caller must make sure
/// the running CPU supports AVX.
#[target_feature(enable = "avx")]
unsafe fn avx_set1(a: f32) -> __m256 {
    let splat = _mm256_set1_ps(a);
    splat
}
|
||
/// Newtype around a 256-bit single-precision vector, used so that the vector
/// crosses function boundaries wrapped in a struct.
struct Avx(__m256);
|
||
fn mul(a: Avx, b: Avx) -> Avx { | ||
unsafe { Avx(avx_mul(a.0, b.0)) } | ||
} | ||
|
||
fn set1(a: f32) -> Avx { | ||
unsafe { Avx(avx_set1(a)) } | ||
} | ||
|
||
fn setr(a: f32, b: f32, c: f32, d: f32, e: f32, f: f32, g: f32, h: f32) -> Avx { | ||
unsafe { Avx(avx_setr(a, b, c, d, e, f, g, h)) } | ||
} | ||
|
||
unsafe fn store(p: *mut f32, a: Avx) { | ||
avx_store(p, a.0); | ||
} | ||
|
||
fn main() { | ||
if !is_x86_feature_detected!("avx") { | ||
return println!("AVX is not supported on this machine/build."); | ||
} | ||
let mut result = [0.0f32; 8]; | ||
let a = mul(setr(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0), set1(0.25)); | ||
unsafe { | ||
store(result.as_mut_ptr(), a); | ||
} | ||
|
||
assert_eq!(result, [0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.50, 1.75]); | ||
} |