From f8b2c46d49c8bc101d17ab7e0b5cf73f04e99478 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Wed, 13 Apr 2022 18:00:44 -0400 Subject: [PATCH] codegen: explicitly handle Float16 intrinsics Fixes #44829, until llvm fixes the support for these intrinsics itself --- base/float.jl | 4 +- src/APInt-C.cpp | 6 +- src/julia.expmap | 6 - src/julia_internal.h | 14 +- src/llvm-demote-float16.cpp | 268 ++++++++++++++++++++++++++++-------- src/runtime_intrinsics.c | 53 ++++--- 6 files changed, 264 insertions(+), 87 deletions(-) diff --git a/base/float.jl b/base/float.jl index 60850b7e02f64a..5a9c41f89692dd 100644 --- a/base/float.jl +++ b/base/float.jl @@ -344,8 +344,8 @@ function unsafe_trunc(::Type{Int128}, x::Float32) copysign(unsafe_trunc(UInt128,x) % Int128, x) end -unsafe_trunc(::Type{UInt128}, x::Float16) = unsafe_trunc(UInt128, Float32(x)) -unsafe_trunc(::Type{Int128}, x::Float16) = unsafe_trunc(Int128, Float32(x)) +unsafe_trunc(::Type{UInt128}, x::Float16) = fptoui(UInt128, x) +unsafe_trunc(::Type{Int128}, x::Float16) = fptosi(Int128, x) # matches convert methods # also determines floor, ceil, round diff --git a/src/APInt-C.cpp b/src/APInt-C.cpp index bc0a62e21dd3ef..f06d4362bf9588 100644 --- a/src/APInt-C.cpp +++ b/src/APInt-C.cpp @@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) { void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) { double Val; if (numbits == 16) - Val = __gnu_h2f_ieee(*(uint16_t*)pa); + Val = julia__gnu_h2f_ieee(*(uint16_t*)pa); else if (numbits == 32) Val = *(float*)pa; else if (numbits == 64) @@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar val = a.roundToDouble(true); } if (onumbits == 16) - *(uint16_t*)pr = __gnu_f2h_ieee(val); + *(uint16_t*)pr = julia__gnu_f2h_ieee(val); else if (onumbits == 32) *(float*)pr = val; else if (onumbits == 64) @@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar val = a.roundToDouble(false); } if (onumbits == 16) - *(uint16_t*)pr = __gnu_f2h_ieee(val); + *(uint16_t*)pr = julia__gnu_f2h_ieee(val); else if (onumbits == 32) *(float*)pr = val; else if (onumbits == 64) diff --git a/src/julia.expmap b/src/julia.expmap index 13de1b873f7c3f..6e373798102b2c 100644 --- a/src/julia.expmap +++ b/src/julia.expmap @@ -37,12 +37,6 @@ environ; __progname; - /* compiler run-time intrinsics */ - __gnu_h2f_ieee; - __extendhfsf2; - __gnu_f2h_ieee; - __truncdfhf2; - local: *; }; diff --git a/src/julia_internal.h b/src/julia_internal.h index 02130ef963198f..74a16d718d7cdb 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -1523,8 +1523,18 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT; #define JL_GC_ASSERT_LIVE(x) (void)(x) #endif -float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT; -uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT; +JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT; +JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT; +JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT; +//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT; +//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT; +//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT; +//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) JL_NOTSAFEPOINT; +//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) JL_NOTSAFEPOINT; +//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) JL_NOTSAFEPOINT; +//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) JL_NOTSAFEPOINT; +//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) JL_NOTSAFEPOINT; +//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) JL_NOTSAFEPOINT; #ifdef __cplusplus } diff --git a/src/llvm-demote-float16.cpp b/src/llvm-demote-float16.cpp index 300be27cf90793..29bc2adc6b9f7f 100644 --- a/src/llvm-demote-float16.cpp +++ b/src/llvm-demote-float16.cpp @@ -45,15 +45,169 @@ INST_STATISTIC(FCmp); namespace { +inline AttributeSet getFnAttrs(const AttributeList &Attrs) +{ +#if JL_LLVM_VERSION >= 140000 + return Attrs.getFnAttrs(); +#else + return Attrs.getFnAttributes(); +#endif +} + +inline AttributeSet getRetAttrs(const AttributeList &Attrs) +{ +#if JL_LLVM_VERSION >= 140000 + return Attrs.getRetAttrs(); +#else + return Attrs.getRetAttributes(); +#endif +} + +static Instruction *replaceIntrinsicWith(IntrinsicInst *call, Type *RetType, ArrayRef args) +{ + Intrinsic::ID ID = call->getIntrinsicID(); + assert(ID); + auto oldfType = call->getFunctionType(); + auto nargs = oldfType->getNumParams(); + assert(args.size() > nargs); + SmallVector argTys(nargs); + for (unsigned i = 0; i < nargs; i++) + argTys[i] = args[i]->getType(); + auto newfType = FunctionType::get(RetType, argTys, oldfType->isVarArg()); + + // Accumulate an array of overloaded types for the given intrinsic + // and compute the new name mangling schema + SmallVector overloadTys; + { + SmallVector Table; + getIntrinsicInfoTableEntries(ID, Table); + ArrayRef TableRef = Table; + auto res = Intrinsic::matchIntrinsicSignature(newfType, TableRef, overloadTys); + assert(res == Intrinsic::MatchIntrinsicTypes_Match); + (void)res; + bool matchvararg = !Intrinsic::matchIntrinsicVarArg(newfType->isVarArg(), TableRef); + assert(matchvararg); + (void)matchvararg; + } + auto newF = Intrinsic::getDeclaration(call->getModule(), ID, overloadTys); + assert(newF->getFunctionType() == newfType); + newF->setCallingConv(call->getCallingConv()); + assert(args.back() == call->getCalledFunction()); + auto newCall = CallInst::Create(newF, args.drop_back(), "", call); + newCall->setTailCallKind(call->getTailCallKind()); + auto old_attrs = call->getAttributes(); + newCall->setAttributes(AttributeList::get(call->getContext(), getFnAttrs(old_attrs), + getRetAttrs(old_attrs), {})); // drop parameter attributes + return newCall; +} + + +static Value* CreateFPCast(Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder) +{ + if (auto *VC = dyn_cast(V)) { + // The input IR often has things of the form + // fcmp olt half %0, 0xH7C00 + // and we would like to avoid turning that constant into a call here + // if we can simply constant fold it to the new type. + VC = ConstantExpr::getCast(opcode, VC, DestTy, true); + if (VC) + return VC; + } + auto &M = *builder.GetInsertBlock()->getModule(); + auto &ctx = M.getContext(); + Type *SrcTy = V->getType(); + Type *RetTy = DestTy; + // Pick the Function to call in the Julia runtime + StringRef Name; + switch (opcode) { + case Instruction::FPExt: + // this is exact, so we only need one conversion + assert(SrcTy->isHalfTy()); + Name = "julia__gnu_h2f_ieee"; + RetTy = Type::getFloatTy(ctx); + break; + case Instruction::FPTrunc: + assert(DestTy->isHalfTy()); + if (SrcTy->isFloatTy()) + Name = "julia__gnu_f2h_ieee"; + else if (SrcTy->isDoubleTy()) + Name = "julia__truncdfhf2"; + break; + // All F16 fit exactly in Int32 (-65504 to 65504) + case Instruction::FPToSI: JL_FALLTHROUGH; + case Instruction::FPToUI: + assert(SrcTy->isHalfTy()); + Name = "julia__gnu_h2f_ieee"; + RetTy = Type::getFloatTy(ctx); + break; + case Instruction::SIToFP: JL_FALLTHROUGH; + case Instruction::UIToFP: + assert(DestTy->isHalfTy()); + Name = "julia__gnu_f2h_ieee"; + SrcTy = Type::getFloatTy(ctx); + break; + default: + llvm_unreachable("invalid cast"); + } + if (Name.empty()) + llvm_unreachable("illegal cast"); + // Coerce the source to the required size and type + auto T_int16 = Type::getInt16Ty(ctx); + if (SrcTy->isHalfTy()) + SrcTy = T_int16; + if (opcode == Instruction::SIToFP) + V = builder.CreateSIToFP(V, SrcTy); + else if (opcode == Instruction::UIToFP) + V = builder.CreateUIToFP(V, SrcTy); + else + V = builder.CreateBitCast(V, SrcTy); + // Call our intrinsic + if (RetTy->isHalfTy()) + RetTy = T_int16; + auto FT = FunctionType::get(RetTy, {SrcTy}, false); + FunctionCallee F = M.getOrInsertFunction(Name, FT); + Value *I = builder.CreateCall(F, {V}); + // Coerce the result to the expected type + if (opcode == Instruction::FPToSI) + I = builder.CreateFPToSI(I, DestTy); + else if (opcode == Instruction::FPToUI) + I = builder.CreateFPToUI(I, DestTy); + else if (opcode == Instruction::FPExt) + I = builder.CreateFPCast(I, DestTy); + else + I = builder.CreateBitCast(I, DestTy); + return I; +} + static bool demoteFloat16(Function &F) { auto &ctx = F.getContext(); - auto T_float16 = Type::getHalfTy(ctx); auto T_float32 = Type::getFloatTy(ctx); SmallVector erase; for (auto &BB : F) { for (auto &I : BB) { + // extend Float16 operands to Float32 + bool Float16 = I.getType()->isHalfTy(); + for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) { + Value *Op = I.getOperand(i); + if (Op->getType()->isHalfTy()) + Float16 = true; + } + if (!Float16) + continue; + + if (auto CI = dyn_cast(&I)) { + if (CI->getOpcode() != Instruction::BitCast) { // aka !CI->isNoopCast(DL) + ++TotalChanged; + IRBuilder<> builder(&I); + Value *NewI = CreateFPCast(CI->getOpcode(), I.getOperand(0), I.getType(), builder); + I.replaceAllUsesWith(NewI); + erase.push_back(&I); + } + continue; + } + switch (I.getOpcode()) { case Instruction::FNeg: case Instruction::FAdd: @@ -64,6 +218,9 @@ static bool demoteFloat16(Function &F) case Instruction::FCmp: break; default: + if (auto intrinsic = dyn_cast(&I)) + if (intrinsic->getIntrinsicID()) + break; continue; } @@ -75,72 +232,75 @@ static bool demoteFloat16(Function &F) IRBuilder<> builder(&I); // extend Float16 operands to Float32 - bool OperandsChanged = false; SmallVector Operands(I.getNumOperands()); for (size_t i = 0; i < I.getNumOperands(); i++) { Value *Op = I.getOperand(i); - if (Op->getType() == T_float16) { + if (Op->getType()->isHalfTy()) { ++TotalExt; - Op = builder.CreateFPExt(Op, T_float32); - OperandsChanged = true; + Op = CreateFPCast(Instruction::FPExt, Op, T_float32, builder); } Operands[i] = (Op); } // recreate the instruction if any operands changed, // truncating the result back to Float16 - if (OperandsChanged) { - Value *NewI; - ++TotalChanged; - switch (I.getOpcode()) { - case Instruction::FNeg: - assert(Operands.size() == 1); - ++FNegChanged; - NewI = builder.CreateFNeg(Operands[0]); - break; - case Instruction::FAdd: - assert(Operands.size() == 2); - ++FAddChanged; - NewI = builder.CreateFAdd(Operands[0], Operands[1]); - break; - case Instruction::FSub: - assert(Operands.size() == 2); - ++FSubChanged; - NewI = builder.CreateFSub(Operands[0], Operands[1]); - break; - case Instruction::FMul: - assert(Operands.size() == 2); - ++FMulChanged; - NewI = builder.CreateFMul(Operands[0], Operands[1]); - break; - case Instruction::FDiv: - assert(Operands.size() == 2); - ++FDivChanged; - NewI = builder.CreateFDiv(Operands[0], Operands[1]); - break; - case Instruction::FRem: - assert(Operands.size() == 2); - ++FRemChanged; - NewI = builder.CreateFRem(Operands[0], Operands[1]); - break; - case Instruction::FCmp: - assert(Operands.size() == 2); - ++FCmpChanged; - NewI = builder.CreateFCmp(cast(&I)->getPredicate(), - Operands[0], Operands[1]); + Value *NewI; + ++TotalChanged; + switch (I.getOpcode()) { + case Instruction::FNeg: + assert(Operands.size() == 1); + ++FNegChanged; + NewI = builder.CreateFNeg(Operands[0]); + break; + case Instruction::FAdd: + assert(Operands.size() == 2); + ++FAddChanged; + NewI = builder.CreateFAdd(Operands[0], Operands[1]); + break; + case Instruction::FSub: + assert(Operands.size() == 2); + ++FSubChanged; + NewI = builder.CreateFSub(Operands[0], Operands[1]); + break; + case Instruction::FMul: + assert(Operands.size() == 2); + ++FMulChanged; + NewI = builder.CreateFMul(Operands[0], Operands[1]); + break; + case Instruction::FDiv: + assert(Operands.size() == 2); + ++FDivChanged; + NewI = builder.CreateFDiv(Operands[0], Operands[1]); + break; + case Instruction::FRem: + assert(Operands.size() == 2); + ++FRemChanged; + NewI = builder.CreateFRem(Operands[0], Operands[1]); + break; + case Instruction::FCmp: + assert(Operands.size() == 2); + ++FCmpChanged; + NewI = builder.CreateFCmp(cast(&I)->getPredicate(), + Operands[0], Operands[1]); + break; + default: + if (auto intrinsic = dyn_cast(&I)) { + Type *RetType = I.getType(); + if (RetType->isHalfTy()) + RetType = T_float32; + NewI = replaceIntrinsicWith(intrinsic, RetType, Operands); break; - default: - abort(); - } - cast(NewI)->copyMetadata(I); - cast(NewI)->copyFastMathFlags(&I); - if (NewI->getType() != I.getType()) { - ++TotalTrunc; - NewI = builder.CreateFPTrunc(NewI, I.getType()); } - I.replaceAllUsesWith(NewI); - erase.push_back(&I); + abort(); + } + cast(NewI)->copyMetadata(I); + cast(NewI)->copyFastMathFlags(&I); + if (NewI->getType() != I.getType()) { + ++TotalTrunc; + NewI = CreateFPCast(Instruction::FPTrunc, NewI, I.getType(), builder); } + I.replaceAllUsesWith(NewI); + erase.push_back(&I); } } diff --git a/src/runtime_intrinsics.c b/src/runtime_intrinsics.c index 89c9449e55920c..44108af865e3fb 100644 --- a/src/runtime_intrinsics.c +++ b/src/runtime_intrinsics.c @@ -188,22 +188,17 @@ static inline uint16_t float_to_half(float param) JL_NOTSAFEPOINT return h; } -JL_DLLEXPORT float __gnu_h2f_ieee(uint16_t param) +JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) { return half_to_float(param); } -JL_DLLEXPORT float __extendhfsf2(uint16_t param) -{ - return half_to_float(param); -} - -JL_DLLEXPORT uint16_t __gnu_f2h_ieee(float param) +JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) { return float_to_half(param); } -JL_DLLEXPORT uint16_t __truncdfhf2(double param) +JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) { float res = (float)param; uint32_t resi; @@ -225,6 +220,24 @@ JL_DLLEXPORT uint16_t __truncdfhf2(double param) return float_to_half(res); } +//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) { return (double)julia__gnu_h2f_ieee(n); } +//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) { return (int32_t)julia__gnu_h2f_ieee(n); } +//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) { return (int64_t)julia__gnu_h2f_ieee(n); } +//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) { return (uint32_t)julia__gnu_h2f_ieee(n); } +//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) { return (uint64_t)julia__gnu_h2f_ieee(n); } +//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) { return julia__gnu_f2h_ieee((float)n); } +//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) { return julia__gnu_f2h_ieee((float)n); } +//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) { return julia__gnu_f2h_ieee((float)n); } +//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) { return julia__gnu_f2h_ieee((float)n); } +//HANDLE_LIBCALL(F16, F128, __extendhftf2) +//HANDLE_LIBCALL(F16, F80, __extendhfxf2) +//HANDLE_LIBCALL(F80, F16, __truncxfhf2) +//HANDLE_LIBCALL(F128, F16, __trunctfhf2) +//HANDLE_LIBCALL(PPCF128, F16, __trunctfhf2) +//HANDLE_LIBCALL(F16, I128, __fixhfti) +//HANDLE_LIBCALL(F16, I128, __fixunshfti) +//HANDLE_LIBCALL(I128, F16, __floattihf) +//HANDLE_LIBCALL(I128, F16, __floatuntihf) #endif // run time version of bitcast intrinsic @@ -597,11 +610,11 @@ static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \ static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \ { \ uint16_t a = *(uint16_t*)pa; \ - float A = __gnu_h2f_ieee(a); \ + float A = julia__gnu_h2f_ieee(a); \ if (osize == 16) { \ float R; \ OP(&R, A); \ - *(uint16_t*)pr = __gnu_f2h_ieee(R); \ + *(uint16_t*)pr = julia__gnu_f2h_ieee(R); \ } else { \ OP((uint16_t*)pr, A); \ } \ @@ -625,11 +638,11 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pr) { \ uint16_t a = *(uint16_t*)pa; \ uint16_t b = *(uint16_t*)pb; \ - float A = __gnu_h2f_ieee(a); \ - float B = __gnu_h2f_ieee(b); \ + float A = julia__gnu_h2f_ieee(a); \ + float B = julia__gnu_h2f_ieee(b); \ runtime_nbits = 16; \ float R = OP(A, B); \ - *(uint16_t*)pr = __gnu_f2h_ieee(R); \ + *(uint16_t*)pr = julia__gnu_f2h_ieee(R); \ } // float or integer inputs, bool output @@ -650,8 +663,8 @@ static int jl_##name##16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEP { \ uint16_t a = *(uint16_t*)pa; \ uint16_t b = *(uint16_t*)pb; \ - float A = __gnu_h2f_ieee(a); \ - float B = __gnu_h2f_ieee(b); \ + float A = julia__gnu_h2f_ieee(a); \ + float B = julia__gnu_h2f_ieee(b); \ runtime_nbits = 16; \ return OP(A, B); \ } @@ -691,12 +704,12 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pc, uint16_t a = *(uint16_t*)pa; \ uint16_t b = *(uint16_t*)pb; \ uint16_t c = *(uint16_t*)pc; \ - float A = __gnu_h2f_ieee(a); \ - float B = __gnu_h2f_ieee(b); \ - float C = __gnu_h2f_ieee(c); \ + float A = julia__gnu_h2f_ieee(a); \ + float B = julia__gnu_h2f_ieee(b); \ + float C = julia__gnu_h2f_ieee(c); \ runtime_nbits = 16; \ float R = OP(A, B, C); \ - *(uint16_t*)pr = __gnu_f2h_ieee(R); \ + *(uint16_t*)pr = julia__gnu_f2h_ieee(R); \ } @@ -1367,7 +1380,7 @@ cvt_iintrinsic(LLVMFPtoUI, fptoui) if (!(osize < 8 * sizeof(a))) \ jl_error("fptrunc: output bitsize must be < input bitsize"); \ else if (osize == 16) \ - *(uint16_t*)pr = __gnu_f2h_ieee(a); \ + *(uint16_t*)pr = julia__gnu_f2h_ieee(a); \ else if (osize == 32) \ *(float*)pr = a; \ else if (osize == 64) \