Skip to content

Commit

Permalink
Revert "codegen: explicitly handle Float16 intrinsics (#45249)" (#45627)
Browse files Browse the repository at this point in the history
This reverts commit eb82f18.
  • Loading branch information
KristofferC authored Jun 13, 2022
1 parent 43df1f4 commit 3a2eb39
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 284 deletions.
6 changes: 3 additions & 3 deletions src/APInt-C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) {
void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
double Val;
if (numbits == 16)
Val = julia__gnu_h2f_ieee(*(uint16_t*)pa);
Val = __gnu_h2f_ieee(*(uint16_t*)pa);
else if (numbits == 32)
Val = *(float*)pa;
else if (numbits == 64)
Expand Down Expand Up @@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(true);
}
if (onumbits == 16)
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
*(uint16_t*)pr = __gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand All @@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(false);
}
if (onumbits == 16)
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
*(uint16_t*)pr = __gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand Down
6 changes: 6 additions & 0 deletions src/julia.expmap
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@
environ;
__progname;

/* compiler run-time intrinsics */
__gnu_h2f_ieee;
__extendhfsf2;
__gnu_f2h_ieee;
__truncdfhf2;

local:
*;
};
14 changes: 2 additions & 12 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1522,18 +1522,8 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
#define JL_GC_ASSERT_LIVE(x) (void)(x)
#endif

JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) JL_NOTSAFEPOINT;
float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;

#ifdef __cplusplus
}
Expand Down
276 changes: 44 additions & 232 deletions src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

#include "support/dtypes.h"

#include <llvm/Pass.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/PassManager.h>
Expand All @@ -29,193 +28,15 @@ using namespace llvm;

namespace {

inline AttributeSet getFnAttrs(const AttributeList &Attrs)
{
#if JL_LLVM_VERSION >= 140000
return Attrs.getFnAttrs();
#else
return Attrs.getFnAttributes();
#endif
}

inline AttributeSet getRetAttrs(const AttributeList &Attrs)
{
#if JL_LLVM_VERSION >= 140000
return Attrs.getRetAttrs();
#else
return Attrs.getRetAttributes();
#endif
}

static Instruction *replaceIntrinsicWith(IntrinsicInst *call, Type *RetTy, ArrayRef<Value*> args)
{
Intrinsic::ID ID = call->getIntrinsicID();
assert(ID);
auto oldfType = call->getFunctionType();
auto nargs = oldfType->getNumParams();
assert(args.size() > nargs);
SmallVector<Type*, 8> argTys(nargs);
for (unsigned i = 0; i < nargs; i++)
argTys[i] = args[i]->getType();
auto newfType = FunctionType::get(RetTy, argTys, oldfType->isVarArg());

// Accumulate an array of overloaded types for the given intrinsic
// and compute the new name mangling schema
SmallVector<Type*, 4> overloadTys;
{
SmallVector<Intrinsic::IITDescriptor, 8> Table;
getIntrinsicInfoTableEntries(ID, Table);
ArrayRef<Intrinsic::IITDescriptor> TableRef = Table;
auto res = Intrinsic::matchIntrinsicSignature(newfType, TableRef, overloadTys);
assert(res == Intrinsic::MatchIntrinsicTypes_Match);
(void)res;
bool matchvararg = !Intrinsic::matchIntrinsicVarArg(newfType->isVarArg(), TableRef);
assert(matchvararg);
(void)matchvararg;
}
auto newF = Intrinsic::getDeclaration(call->getModule(), ID, overloadTys);
assert(newF->getFunctionType() == newfType);
newF->setCallingConv(call->getCallingConv());
assert(args.back() == call->getCalledFunction());
auto newCall = CallInst::Create(newF, args.drop_back(), "", call);
newCall->setTailCallKind(call->getTailCallKind());
auto old_attrs = call->getAttributes();
newCall->setAttributes(AttributeList::get(call->getContext(), getFnAttrs(old_attrs),
getRetAttrs(old_attrs), {})); // drop parameter attributes
return newCall;
}


static Value* CreateFPCast(Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder)
{
Type *SrcTy = V->getType();
Type *RetTy = DestTy;
if (auto *VC = dyn_cast<Constant>(V)) {
// The input IR often has things of the form
// fcmp olt half %0, 0xH7C00
// and we would like to avoid turning that constant into a call here
// if we can simply constant fold it to the new type.
VC = ConstantExpr::getCast(opcode, VC, DestTy, true);
if (VC)
return VC;
}
assert(SrcTy->isVectorTy() == DestTy->isVectorTy());
if (SrcTy->isVectorTy()) {
unsigned NumElems = cast<FixedVectorType>(SrcTy)->getNumElements();
assert(cast<FixedVectorType>(DestTy)->getNumElements() == NumElems && "Mismatched cast");
Value *NewV = UndefValue::get(DestTy);
RetTy = RetTy->getScalarType();
for (unsigned i = 0; i < NumElems; ++i) {
Value *I = builder.getInt32(i);
Value *Vi = builder.CreateExtractElement(V, I);
Vi = CreateFPCast(opcode, Vi, RetTy, builder);
NewV = builder.CreateInsertElement(NewV, Vi, I);
}
return NewV;
}
auto &M = *builder.GetInsertBlock()->getModule();
auto &ctx = M.getContext();
// Pick the Function to call in the Julia runtime
StringRef Name;
switch (opcode) {
case Instruction::FPExt:
// this is exact, so we only need one conversion
assert(SrcTy->isHalfTy());
Name = "julia__gnu_h2f_ieee";
RetTy = Type::getFloatTy(ctx);
break;
case Instruction::FPTrunc:
assert(DestTy->isHalfTy());
if (SrcTy->isFloatTy())
Name = "julia__gnu_f2h_ieee";
else if (SrcTy->isDoubleTy())
Name = "julia__truncdfhf2";
break;
// All F16 fit exactly in Int32 (-65504 to 65504)
case Instruction::FPToSI: JL_FALLTHROUGH;
case Instruction::FPToUI:
assert(SrcTy->isHalfTy());
Name = "julia__gnu_h2f_ieee";
RetTy = Type::getFloatTy(ctx);
break;
case Instruction::SIToFP: JL_FALLTHROUGH;
case Instruction::UIToFP:
assert(DestTy->isHalfTy());
Name = "julia__gnu_f2h_ieee";
SrcTy = Type::getFloatTy(ctx);
break;
default:
errs() << Instruction::getOpcodeName(opcode) << ' ';
V->getType()->print(errs());
errs() << " to ";
DestTy->print(errs());
errs() << " is an ";
llvm_unreachable("invalid cast");
}
if (Name.empty()) {
errs() << Instruction::getOpcodeName(opcode) << ' ';
V->getType()->print(errs());
errs() << " to ";
DestTy->print(errs());
errs() << " is an ";
llvm_unreachable("illegal cast");
}
// Coerce the source to the required size and type
auto T_int16 = Type::getInt16Ty(ctx);
if (SrcTy->isHalfTy())
SrcTy = T_int16;
if (opcode == Instruction::SIToFP)
V = builder.CreateSIToFP(V, SrcTy);
else if (opcode == Instruction::UIToFP)
V = builder.CreateUIToFP(V, SrcTy);
else
V = builder.CreateBitCast(V, SrcTy);
// Call our intrinsic
if (RetTy->isHalfTy())
RetTy = T_int16;
auto FT = FunctionType::get(RetTy, {SrcTy}, false);
FunctionCallee F = M.getOrInsertFunction(Name, FT);
Value *I = builder.CreateCall(F, {V});
// Coerce the result to the expected type
if (opcode == Instruction::FPToSI)
I = builder.CreateFPToSI(I, DestTy);
else if (opcode == Instruction::FPToUI)
I = builder.CreateFPToUI(I, DestTy);
else if (opcode == Instruction::FPExt)
I = builder.CreateFPCast(I, DestTy);
else
I = builder.CreateBitCast(I, DestTy);
return I;
}

static bool demoteFloat16(Function &F)
{
auto &ctx = F.getContext();
auto T_float16 = Type::getHalfTy(ctx);
auto T_float32 = Type::getFloatTy(ctx);

SmallVector<Instruction *, 0> erase;
for (auto &BB : F) {
for (auto &I : BB) {
// extend Float16 operands to Float32
bool Float16 = I.getType()->getScalarType()->isHalfTy();
for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->getScalarType()->isHalfTy())
Float16 = true;
}
if (!Float16)
continue;

if (auto CI = dyn_cast<CastInst>(&I)) {
if (CI->getOpcode() != Instruction::BitCast) { // aka !CI->isNoopCast(DL)
IRBuilder<> builder(&I);
Value *NewI = CreateFPCast(CI->getOpcode(), I.getOperand(0), I.getType(), builder);
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
}
continue;
}

switch (I.getOpcode()) {
case Instruction::FNeg:
case Instruction::FAdd:
Expand All @@ -226,9 +47,6 @@ static bool demoteFloat16(Function &F)
case Instruction::FCmp:
break;
default:
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I))
if (intrinsic->getIntrinsicID())
break;
continue;
}

Expand All @@ -240,67 +58,61 @@ static bool demoteFloat16(Function &F)
IRBuilder<> builder(&I);

// extend Float16 operands to Float32
// XXX: Calls to llvm.fma.f16 may need to go to f64 to be correct?
bool OperandsChanged = false;
SmallVector<Value *, 2> Operands(I.getNumOperands());
for (size_t i = 0; i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->getScalarType()->isHalfTy()) {
Op = CreateFPCast(Instruction::FPExt, Op, Op->getType()->getWithNewType(T_float32), builder);
if (Op->getType() == T_float16) {
Op = builder.CreateFPExt(Op, T_float32);
OperandsChanged = true;
}
Operands[i] = (Op);
}

// recreate the instruction if any operands changed,
// truncating the result back to Float16
Value *NewI;
switch (I.getOpcode()) {
case Instruction::FNeg:
assert(Operands.size() == 1);
NewI = builder.CreateFNeg(Operands[0]);
break;
case Instruction::FAdd:
assert(Operands.size() == 2);
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
break;
case Instruction::FSub:
assert(Operands.size() == 2);
NewI = builder.CreateFSub(Operands[0], Operands[1]);
break;
case Instruction::FMul:
assert(Operands.size() == 2);
NewI = builder.CreateFMul(Operands[0], Operands[1]);
break;
case Instruction::FDiv:
assert(Operands.size() == 2);
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
break;
case Instruction::FRem:
assert(Operands.size() == 2);
NewI = builder.CreateFRem(Operands[0], Operands[1]);
break;
case Instruction::FCmp:
assert(Operands.size() == 2);
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
Operands[0], Operands[1]);
break;
default:
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I)) {
// XXX: this is not correct in general
// some obvious failures include llvm.convert.to.fp16.*, llvm.vp.*to*, llvm.experimental.constrained.*to*, llvm.masked.*
Type *RetTy = I.getType();
if (RetTy->getScalarType()->isHalfTy())
RetTy = RetTy->getWithNewType(T_float32);
NewI = replaceIntrinsicWith(intrinsic, RetTy, Operands);
if (OperandsChanged) {
Value *NewI;
switch (I.getOpcode()) {
case Instruction::FNeg:
assert(Operands.size() == 1);
NewI = builder.CreateFNeg(Operands[0]);
break;
case Instruction::FAdd:
assert(Operands.size() == 2);
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
break;
case Instruction::FSub:
assert(Operands.size() == 2);
NewI = builder.CreateFSub(Operands[0], Operands[1]);
break;
case Instruction::FMul:
assert(Operands.size() == 2);
NewI = builder.CreateFMul(Operands[0], Operands[1]);
break;
case Instruction::FDiv:
assert(Operands.size() == 2);
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
break;
case Instruction::FRem:
assert(Operands.size() == 2);
NewI = builder.CreateFRem(Operands[0], Operands[1]);
break;
case Instruction::FCmp:
assert(Operands.size() == 2);
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
Operands[0], Operands[1]);
break;
default:
abort();
}
abort();
cast<Instruction>(NewI)->copyMetadata(I);
cast<Instruction>(NewI)->copyFastMathFlags(&I);
if (NewI->getType() != I.getType())
NewI = builder.CreateFPTrunc(NewI, I.getType());
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
}
cast<Instruction>(NewI)->copyMetadata(I);
cast<Instruction>(NewI)->copyFastMathFlags(&I);
if (NewI->getType() != I.getType())
NewI = CreateFPCast(Instruction::FPTrunc, NewI, I.getType(), builder);
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
}
}

Expand Down
Loading

0 comments on commit 3a2eb39

Please sign in to comment.