diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index ede1cbd1a52e..58ee7145da8c 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -24,6 +24,7 @@ // //===----------------------------------------------------------------------===// #include +#include #include #if LLVM_VERSION_MAJOR >= 16 @@ -2217,7 +2218,7 @@ class EnzymeBase { #endif RemapFunction(F, Mapping, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); - TruncatedFunc->deleteBody(); + TruncatedFunc->eraseFromParent(); } return true; } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 7cb75a72932d..d5f011495dd0 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -42,6 +42,10 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/Support/ErrorHandling.h" #include +#include +#include +#include +#include #if LLVM_VERSION_MAJOR >= 16 #define private public @@ -1724,7 +1728,7 @@ void clearFunctionAttributes(Function *f) { } Attribute::AttrKind attrs[] = { #if LLVM_VERSION_MAJOR >= 17 - Attribute::NoFPClass, + Attribute::NoFPClass, #endif Attribute::NoUndef, Attribute::NonNull, @@ -2553,7 +2557,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( llvm::Attribute::AttrKind attrs[] = { #if LLVM_VERSION_MAJOR >= 17 - llvm::Attribute::NoFPClass, + llvm::Attribute::NoFPClass, #endif llvm::Attribute::NoAlias, llvm::Attribute::NoUndef, @@ -4923,6 +4927,8 @@ class TruncateUtils { Type *fromType; Type *toType; LLVMContext &ctx; + EnzymeLogic &Logic; + Value *UnknownLoc; private: std::string getOriginalFPRTName(std::string Name) { @@ -4946,7 +4952,7 @@ class TruncateUtils { ArgTypes.push_back(Arg->getType()); FunctionType *FnTy = FunctionType::get(RetTy, ArgTypes, /*is_vararg*/ false); - F = Function::Create(FnTy, Function::ExternalLinkage, MangledName, M); + F = Function::Create(FnTy, Function::WeakODRLinkage, MangledName, M); } if (F->isDeclaration()) { BasicBlock *Entry = BasicBlock::Create(F->getContext(), "entry", 
F); @@ -4955,6 +4961,9 @@ class TruncateUtils { ClonedI->setOperand(It, F->getArg(It)); auto Return = ReturnInst::Create(F->getContext(), ClonedI, Entry); ClonedI->insertBefore(Return); + F->setLinkage(GlobalValue::WeakODRLinkage); + // Clear invalidated debug metadata now that we defined the function + F->clearMetadata(); } } @@ -4975,22 +4984,40 @@ class TruncateUtils { CallInst *createFPRTGeneric(llvm::IRBuilderBase &B, std::string Name, const SmallVectorImpl &ArgsIn, - llvm::Type *RetTy) { + llvm::Type *RetTy, Value *LocStr) { SmallVector Args(ArgsIn.begin(), ArgsIn.end()); Args.push_back(B.getInt64(truncation.getTo().exponentWidth)); Args.push_back(B.getInt64(truncation.getTo().significandWidth)); Args.push_back(B.getInt64(truncation.getMode())); +#if LLVM_VERSION_MAJOR <= 14 + Args.push_back(B.CreateBitCast(LocStr, NullPtr->getType())); +#else + Args.push_back(LocStr); +#endif + auto FprtFunc = getFPRTFunc(Name, Args, RetTy); - return cast(B.CreateCall(FprtFunc, Args)); + // Explicitly assign a dbg location if it didn't exist, as the FPRT + // functions are inlineable and the backend fails if the callsite does not + // have dbg metadata + // TODO consider using InstrumentationIRBuilder + Function *ContainingF = B.GetInsertBlock()->getParent(); + if (!B.getCurrentDebugLocation() && ContainingF->getSubprogram()) + B.SetCurrentDebugLocation(DILocation::get(ContainingF->getContext(), 0, 0, + ContainingF->getSubprogram())); + auto *CI = cast(B.CreateCall(FprtFunc, Args)); + + return CI; } public: - TruncateUtils(FloatTruncation truncation, Module *M) - : truncation(truncation), M(M), ctx(M->getContext()) { + TruncateUtils(FloatTruncation truncation, Module *M, EnzymeLogic &Logic) + : truncation(truncation), M(M), ctx(M->getContext()), Logic(Logic) { fromType = truncation.getFromType(ctx); toType = truncation.getToType(ctx); if (fromType == toType) assert(truncation.isToFPRT()); + + UnknownLoc = getUniquedLocStr(nullptr); } Type *getFromType() { return fromType; } 
@@ -5001,23 +5028,54 @@ class TruncateUtils { assert(V->getType() == getFromType()); SmallVector Args; Args.push_back(V); - return createFPRTGeneric(B, "const", Args, getToType()); + return createFPRTGeneric(B, "const", Args, getToType(), UnknownLoc); } CallInst *createFPRTNewCall(llvm::IRBuilderBase &B, Value *V) { assert(V->getType() == getFromType()); SmallVector Args; Args.push_back(V); - return createFPRTGeneric(B, "new", Args, getToType()); + return createFPRTGeneric(B, "new", Args, getToType(), UnknownLoc); } CallInst *createFPRTGetCall(llvm::IRBuilderBase &B, Value *V) { SmallVector Args; Args.push_back(V); - return createFPRTGeneric(B, "get", Args, getToType()); + return createFPRTGeneric(B, "get", Args, getToType(), UnknownLoc); } CallInst *createFPRTDeleteCall(llvm::IRBuilderBase &B, Value *V) { SmallVector Args; Args.push_back(V); - return createFPRTGeneric(B, "delete", Args, B.getVoidTy()); + return createFPRTGeneric(B, "delete", Args, B.getVoidTy(), UnknownLoc); + } + // This will result in a unique string for each location, which means the + // runtime can check whether two operations are the same with a simple pointer + // comparison. However, we need LTO for this to be the case across different + // compilation units. 
+ GlobalValue *getUniquedLocStr(Instruction *I) { + std::string FileName = "unknown"; + unsigned LineNo = 0; + unsigned ColNo = 0; + + if (I) { + DILocation *DL = I->getDebugLoc(); + if (DL) { + FileName = DL->getFilename(); + LineNo = DL->getLine(); + ColNo = DL->getColumn(); + } + } + + auto Key = std::make_tuple(FileName, LineNo, ColNo); + auto It = Logic.UniqDebugLocStrs.find(Key); + + if (It != Logic.UniqDebugLocStrs.end()) + return It->second; + + std::string LocStr = + FileName + ":" + std::to_string(LineNo) + ":" + std::to_string(ColNo); + auto GV = createPrivateGlobalForString(*M, LocStr, true); + Logic.UniqDebugLocStrs[Key] = GV; + + return GV; } CallInst *createFPRTOpCall(llvm::IRBuilderBase &B, llvm::Instruction &I, llvm::Type *RetTy, @@ -5040,21 +5098,28 @@ class TruncateUtils { "Unexpected indirect call inst for conversion to FPRT"); } else if (auto CI = dyn_cast(&I)) { Name = "fcmp_" + std::string(CI->getPredicateName(CI->getPredicate())); + } else if (auto UO = dyn_cast(&I)) { + Name = "unaryop_" + std::string(UO->getOpcodeName()); } else { llvm_unreachable("Unexpected instruction for conversion to FPRT"); } createOriginalFPRTFunc(I, Name, ArgsIn, RetTy); - return createFPRTGeneric(B, Name, ArgsIn, RetTy); + return createFPRTGeneric(B, Name, ArgsIn, RetTy, getUniquedLocStr(&I)); } }; +// TODO we need to handle cases where constant aggregates are used and they +// contain constant fp's in them. +// +// e.g. store {0 : i64, 1.0: f64} %ptr +// +// Currently in mem mode the float will remain unconverted and we will likely +// crash somewhere. 
class TruncateGenerator : public llvm::InstVisitor, public TruncateUtils { private: ValueToValueMapTy &originalToNewFn; FloatTruncation truncation; - Function *oldFunc; - Function *newFunc; TruncateMode mode; EnzymeLogic &Logic; LLVMContext &ctx; @@ -5063,36 +5128,43 @@ class TruncateGenerator : public llvm::InstVisitor, TruncateGenerator(ValueToValueMapTy &originalToNewFn, FloatTruncation truncation, Function *oldFunc, Function *newFunc, EnzymeLogic &Logic) - : TruncateUtils(truncation, newFunc->getParent()), + : TruncateUtils(truncation, newFunc->getParent(), Logic), originalToNewFn(originalToNewFn), truncation(truncation), - oldFunc(oldFunc), newFunc(newFunc), mode(truncation.getMode()), - Logic(Logic), ctx(newFunc->getContext()) {} + mode(truncation.getMode()), Logic(Logic), ctx(newFunc->getContext()) {} - void checkHandled(llvm::Instruction &inst) { - // TODO - // if (all_of(inst.getOperandList(), - // [&](Use *use) { return use->get()->getType() == fromType; })) - // todo(inst); - } + void todo(llvm::Instruction &I) { + if (all_of(I.operands(), + [&](Use &U) { return U.get()->getType() != fromType; }) && + I.getType() != fromType) + return; - // TODO - void handleTrunc(); - void hendleIntToFloat(); - void handleFloatToInt(); + switch (mode) { + case TruncMemMode: + llvm::errs() << I << "\n"; + EmitFailure("FPEscaping", I.getDebugLoc(), &I, "FP value escapes!"); + break; + case TruncOpMode: + case TruncOpFullModuleMode: + EmitWarning( + "UnhandledTrunc", I, + "Operation not handled - it will be executed in the original way.", + I); + break; + default: + llvm_unreachable("Unknown trunc mode"); + } + } - void visitInstruction(llvm::Instruction &inst) { + void visitInstruction(llvm::Instruction &I) { using namespace llvm; - // TODO explicitly handle all instructions rather than using the catch all - // below - - switch (inst.getOpcode()) { + switch (I.getOpcode()) { // #include "InstructionDerivatives.inc" default: break; } - checkHandled(inst); + todo(I); } 
Value *truncate(IRBuilder<> &B, Value *v) { @@ -5119,17 +5191,24 @@ class TruncateGenerator : public llvm::InstVisitor, llvm_unreachable("Unknown trunc mode"); } - void todo(llvm::Instruction &I) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "cannot handle unknown instruction\n" << I; - if (CustomErrorHandler) { - IRBuilder<> Builder2(getNewFromOriginal(&I)); - CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoTruncate, - this, nullptr, wrap(&Builder2)); + void visitUnaryOperator(UnaryOperator &I) { + switch (I.getOpcode()) { + case UnaryOperator::FNeg: { + if (I.getOperand(0)->getType() != getFromType()) + return; + + auto newI = getNewFromOriginal(&I); + IRBuilder<> B(newI); + SmallVector Args = {newI->getOperand(0)}; + auto nres = createFPRTOpCall(B, I, newI->getType(), Args); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(nres); + newI->eraseFromParent(); return; - } else { - EmitFailure("NoTruncate", I.getDebugLoc(), &I, ss.str()); + } + default: + todo(I); return; } } @@ -5150,8 +5229,8 @@ class TruncateGenerator : public llvm::InstVisitor, auto truncRHS = truncate(B, RHS); SmallVector Args; - Args.push_back(LHS); - Args.push_back(RHS); + Args.push_back(truncLHS); + Args.push_back(truncRHS); Instruction *nres; if (truncation.isToFPRT()) nres = createFPRTOpCall(B, CI, B.getInt1Ty(), Args); @@ -5179,13 +5258,32 @@ class TruncateGenerator : public llvm::InstVisitor, SI.isVolatile(), SI.getOrdering(), SI.getSyncScopeID(), /*mask=*/nullptr); } + // TODO Is there a possibility we GEP a const and get a FP value? 
void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; } - void visitPHINode(llvm::PHINode &phi) { return; } void visitCastInst(llvm::CastInst &CI) { + // TODO Try to follow fps through trunc/exts switch (mode) { case TruncMemMode: { - if (CI.getSrcTy() == getFromType() || CI.getDestTy() == getFromType()) - todo(CI); + auto newI = getNewFromOriginal(&CI); + auto newSrc = newI->getOperand(0); + if (CI.getSrcTy() == getFromType()) { + IRBuilder<> B(newI); + if (isa(newSrc)) + return; + newI->setOperand(0, createFPRTGetCall(B, newSrc)); + EmitWarning("FPNoFollow", CI, "Will not follow FP through this cast.", + CI); + } else if (CI.getDestTy() == getFromType()) { + IRBuilder<> B(newI->getNextNode()); + EmitWarning("FPNoFollow", CI, "Will not follow FP through this cast.", + CI); + auto nres = createFPRTNewCall(B, newI); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceUsesWithIf(nres, + [&](Use &U) { return U.getUser() != nres; }); + originalToNewFn[const_cast(cast(&CI))] = nres; + } return; } case TruncOpMode: @@ -5196,6 +5294,8 @@ class TruncateGenerator : public llvm::InstVisitor, void visitSelectInst(llvm::SelectInst &SI) { switch (mode) { case TruncMemMode: { + if (SI.getType() != getFromType()) + return; auto newI = getNewFromOriginal(&SI); IRBuilder<> B(newI); auto newT = truncate(B, getNewFromOriginal(SI.getTrueValue())); @@ -5329,11 +5429,31 @@ class TruncateGenerator : public llvm::InstVisitor, newI->eraseFromParent(); return true; } + void visitIntrinsicInst(llvm::IntrinsicInst &II) { handleIntrinsic(II, II.getIntrinsicID()); } - void visitReturnInst(llvm::ReturnInst &I) { return; } + void visitReturnInst(llvm::ReturnInst &I) { + switch (mode) { + case TruncMemMode: { + if (I.getNumOperands() == 0) + return; + if (I.getReturnValue()->getType() != getFromType()) + return; + auto newI = cast(getNewFromOriginal(&I)); + IRBuilder<> B(newI); + if (isa(newI->getOperand(0))) + newI->setOperand(0, createFPRTConstCall(B, 
newI->getReturnValue())); + return; + } + case TruncOpMode: + case TruncOpFullModuleMode: + break; + default: + llvm_unreachable("Unknown trunc mode"); + } + } void visitBranchInst(llvm::BranchInst &I) { return; } void visitSwitchInst(llvm::SwitchInst &I) { return; } @@ -5348,6 +5468,23 @@ class TruncateGenerator : public llvm::InstVisitor, llvm::Value *orig_val, llvm::MaybeAlign prevalign, bool isVolatile, llvm::AtomicOrdering ordering, llvm::SyncScope::ID syncScope, llvm::Value *mask) { + switch (mode) { + case TruncMemMode: { + if (orig_val->getType() != getFromType()) + return; + if (!isa(orig_val)) + return; + auto newI = getNewFromOriginal(&I); + IRBuilder<> B(newI); + newI->setOperand(0, createFPRTConstCall(B, getNewFromOriginal(orig_val))); + return; + } + case TruncOpMode: + case TruncOpFullModuleMode: + break; + default: + llvm_unreachable("Unknown trunc mode"); + } return; } @@ -5435,17 +5572,55 @@ class TruncateGenerator : public llvm::InstVisitor, if (mode != TruncOpFullModuleMode) { RequestContext ctx(&CI, &BuilderZ); - auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand())); - newCall->setCalledOperand(val); + Function *Func = CI.getCalledFunction(); + if (Func && !Func->empty()) { + auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand())); + newCall->setCalledOperand(val); + } else { + switch (mode) { + case TruncMemMode: + EmitWarning("FPNoFollow", CI, + "Will not follow FP through this function call as the " + "definition is not available.", + CI); + break; + case TruncOpMode: + case TruncOpFullModuleMode: + EmitWarning("FPNoFollow", CI, + "Will not truncate flops in this function call as the " + "definition is not available.", + CI); + break; + default: + llvm_unreachable("Unknown trunc mode"); + } + } } return; } - void visitFPTruncInst(FPTruncInst &I) { return; } - void visitFPExtInst(FPExtInst &I) { return; } - void visitFPToUIInst(FPToUIInst &I) { return; } - void visitFPToSIInst(FPToSIInst &I) { return; } - void 
visitUIToFPInst(UIToFPInst &I) { return; } - void visitSIToFPInst(SIToFPInst &I) { return; } + void visitPHINode(llvm::PHINode &PN) { + switch (mode) { + case TruncMemMode: { + if (PN.getType() != getFromType()) + return; + auto NewPN = cast(getNewFromOriginal(&PN)); + IRBuilder<> B( + NewPN->getParent()->getParent()->getEntryBlock().getFirstNonPHI()); + for (unsigned It = 0; It < NewPN->getNumIncomingValues(); It++) { + if (isa(NewPN->getIncomingValue(It))) { + NewPN->setOperand( + It, createFPRTConstCall(B, NewPN->getIncomingValue(It))); + } + } + break; + } + case TruncOpMode: + case TruncOpFullModuleMode: + break; + default: + llvm_unreachable("Unknown trunc mode"); + } + } }; bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, @@ -5457,7 +5632,8 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, Value *converted = nullptr; auto truncation = FloatTruncation(from, to, TruncMemMode); - TruncateUtils TU(truncation, B.GetInsertBlock()->getParent()->getParent()); + TruncateUtils TU(truncation, B.GetInsertBlock()->getParent()->getParent(), + *this); if (isTruncate) converted = TU.createFPRTNewCall(B, v); else diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index dd9b877c5ade..6c3611c67a8e 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -31,6 +31,7 @@ #define ENZYME_LOGIC_H #include +#include #include #include @@ -413,9 +414,14 @@ struct FloatTruncation { std::string mangleFrom() const { return from.to_string(); } }; +typedef std::map, + llvm::GlobalValue *> + UniqDebugLocStrsTy; + class EnzymeLogic { public: PreProcessCache PPC; + UniqDebugLocStrsTy UniqDebugLocStrs; /// \p PostOpt is whether to perform basic /// optimization of the function after synthesis diff --git a/enzyme/Enzyme/Runtimes/FPRT/Trace.cpp b/enzyme/Enzyme/Runtimes/FPRT/Trace.cpp index e06f02508d59..3166b1087da2 100644 --- a/enzyme/Enzyme/Runtimes/FPRT/Trace.cpp +++ b/enzyme/Enzyme/Runtimes/FPRT/Trace.cpp 
@@ -25,76 +25,367 @@ // //===----------------------------------------------------------------------===// +#include +#include +#include +#include +#include #include +#include #include #include -#define __ENZYME_MPFR_ATTRIBUTES -#define __ENZYME_MPFR_ORIGINAL_ATTRIBUTES +#include +#include + +#define __ENZYME_MPFR_ATTRIBUTES __attribute__((weak)) +#define __ENZYME_MPFR_ORIGINAL_ATTRIBUTES __attribute__((weak)) + +#ifndef ENZYME_FPRT_TRACE_PRINT +#define ENZYME_FPRT_TRACE_PRINT 1 +#endif + +static constexpr unsigned fp_max_inputs = 3; +static constexpr std::array arg_names = {"x", "y", "z"}; +static_assert(arg_names.size() == fp_max_inputs); extern "C" { +typedef struct __enzyme_fp { +private: + double result; + unsigned char input_num; + const char *loc; + __enzyme_fp *inputs[fp_max_inputs]; + double derivatives[fp_max_inputs]; +#if ENZYME_FPRT_TRACE_PRINT + const char *name; +#endif + +public: + size_t id; + + double getDerivative(unsigned no) const { return derivatives[no]; } + void setDerivative(unsigned no, double d) { derivatives[no] = d; } + + __enzyme_fp *getInput(unsigned no) const { return inputs[no]; } + void setInput(unsigned no, __enzyme_fp *i) { inputs[no] = i; } + + unsigned char getInputNum() const { return input_num; } + void setInputNum(unsigned char i) { input_num = i; } + + double getResult() const { return result; } + void setResult(double r) { result = r; } + + const char *getLoc() const { return loc; } + void setLoc(const char *l) { loc = l; } + +#if ENZYME_FPRT_TRACE_PRINT + const char *getName() const { return name; } + void setName(const char *l) { name = l; } +#endif -typedef struct { - double v; } __enzyme_fp; +} -// TODO ultimately we probably want a linked list of arrays or something like -// that for this -static std::list<__enzyme_fp> FPs; +static void print_enzyme_fp_derivatives(std::ostream &out, + const __enzyme_fp *fp) { + auto seen = false; + for (unsigned i = 0; i < fp->getInputNum(); i++) { + if (seen) + out << ", "; + seen = 
true; + out << "d" << arg_names[i] << " = " << fp->getDerivative(i); + } +} +static void print_enzyme_fp_value(std::ostream &out, const __enzyme_fp *fp) { + out << "[" << fp << ": " << fp->getResult() << "]"; +} +static void print_enzyme_fp_function(std::ostream &out, const __enzyme_fp *fp) { + std::cerr << fp->getName() << "("; + bool seen = false; + for (unsigned i = 0; i < fp->getInputNum(); i++) { + if (seen) + std::cerr << ", "; + seen = true; + __enzyme_fp *fpinput = fp->getInput(i); + print_enzyme_fp_value(std::cerr, fpinput); + } + std::cerr << ")"; +} +static void print_enzyme_fp(std::ostream &out, const __enzyme_fp *fp) { + print_enzyme_fp_function(out, fp); + out << " -> "; + print_enzyme_fp_value(out, fp); + out << " "; + print_enzyme_fp_derivatives(out, fp); + out << " at " << fp->getLoc(); + out << std::endl; +} + +template +static void __enzyme_fprt_trace_no_res_flop(std::array inputs, + const char *name, const char *loc) { + __enzyme_fp fp; + fp.setInputNum(NumInputs); + fp.setLoc(loc); + for (unsigned i = 0; i < inputs.size(); i++) { + __enzyme_fp *inputfp = __enzyme_fprt_double_to_ptr(inputs[i]); + fp.setInput(i, inputfp); + } + +#if ENZYME_FPRT_TRACE_PRINT + fp.setName(name); + print_enzyme_fp_function(std::cerr, &fp); + std::cerr << " at " << loc << std::endl; +#endif +} + +namespace { +template class Derivative { +public: + __attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + return 0; + } +}; +template class Derivative { +public: + __attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + typedef double (*fty)(double); + return __enzyme_fwddiff((fty)fn, enzyme_dup, inputs[0], 1.0); + } +}; +template class Derivative { +public: + __attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + typedef double (*fty)(double); + return __enzyme_fwddiff((fty)fn, enzyme_dup, inputs[0], 1.0, + enzyme_const, inputs[1]); + } +}; +template class Derivative { +public: + 
__attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + typedef double (*fty)(double); + return __enzyme_fwddiff((fty)fn, enzyme_const, inputs[0], enzyme_dup, + inputs[1], 1.0); + } +}; +template class Derivative { +public: + __attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + typedef double (*fty)(double); + // clang-format off + return __enzyme_fwddiff((fty)fn, + enzyme_dup, inputs[0], 1.0, + enzyme_const, inputs[1], + enzyme_const, inputs[2] + ); + // clang-format on + } +}; +template class Derivative { +public: + __attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + typedef double (*fty)(double); + // clang-format off + return __enzyme_fwddiff((fty)fn, + enzyme_const, inputs[0], + enzyme_dup, inputs[1], 1.0, + enzyme_const, inputs[2] + ); + // clang-format on + } +}; +template class Derivative { +public: + __attribute__((always_inline)) static T get(void *fn, + std::array inputs) { + typedef double (*fty)(double); + // clang-format off + return __enzyme_fwddiff((fty)fn, + enzyme_const, inputs[0], + enzyme_const, inputs[1], + enzyme_dup, inputs[2], 1.0 + ); + // clang-format on + } +}; +} // namespace + +template +__attribute__((always_inline)) static void +__enzyme_fprt_trace_flop(std::array _inputs, T output_val, + __enzyme_fp *outfp, void *fn, const char *name, + const char *loc) { + std::array<__enzyme_fp *, NumInputs> inputs; + std::array input_vals; + for (unsigned i = 0; i < _inputs.size(); i++) { + __enzyme_fp *inputfp = __enzyme_fprt_double_to_ptr(_inputs[i]); + inputs[i] = inputfp; + input_vals[i] = inputfp->getResult(); + } -static bool __enzyme_fprt_is_mem_mode(int64_t mode) { return mode & 0b0001; } -static bool __enzyme_fprt_is_op_mode(int64_t mode) { return mode & 0b0010; } + outfp->setResult(output_val); + outfp->setInputNum(inputs.size()); + outfp->setLoc(loc); + for (unsigned i = 0; i < inputs.size(); i++) { + outfp->setInput(i, inputs[i]); + T d; + 
static_assert(inputs.size() <= fp_max_inputs); + if (i == 0) + d = Derivative::get(fn, input_vals); + else if (i == 1) + d = Derivative::get(fn, input_vals); + else if (i == 2) + d = Derivative::get(fn, input_vals); + else + llvm_unreachable("impossible"); + outfp->setDerivative(i, d); + } -static double __enzyme_fprt_ptr_to_double(__enzyme_fp *p) { - return *((double *)(&p)); +#if ENZYME_FPRT_TRACE_PRINT + outfp->setName(name); + print_enzyme_fp(std::cerr, outfp); +#endif } -static __enzyme_fp *__enzyme_fprt_double_to_ptr(double d) { - return *((__enzyme_fp **)(&d)); + +// TODO ultimately we probably want a linked list of arrays or something like +// that for this (std::list probably is that but we may want our own impl) +struct { + std::list<__enzyme_fp> all; + std::list<__enzyme_fp *> outputs; + std::list<__enzyme_fp *> inputs; + std::list<__enzyme_fp *> consts; + void clear() { + all.clear(); + outputs.clear(); + inputs.clear(); + } +} FPs; + +extern "C" { + +__enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, + int64_t significand, + int64_t mode, + const char *loc) { + size_t id = FPs.all.size(); + FPs.all.push_back({}); + __enzyme_fp *a = &FPs.all.back(); + a->id = id; + return a; } -__ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_get(double _a, int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, const char *loc) { __enzyme_fp *a = __enzyme_fprt_double_to_ptr(_a); - return a->v; + FPs.outputs.push_back(a); + __enzyme_fprt_trace_no_res_flop({_a}, "get", loc); + return a->getResult(); } -__ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_new(double _a, int64_t exponent, int64_t significand, - int64_t mode) { - FPs.push_back({_a}); - __enzyme_fp *a = &FPs.back(); - return __enzyme_fprt_ptr_to_double(a); + int64_t mode, const char *loc) { + __enzyme_fp *a = + __enzyme_fprt_64_52_new_intermediate(exponent, significand, mode, loc); + FPs.inputs.push_back(a); + __enzyme_fprt_trace_flop({}, _a, a, nullptr, "new", loc); + 
auto ret = __enzyme_fprt_ptr_to_double(a); + return ret; } -__ENZYME_MPFR_ATTRIBUTES -__enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, - int64_t significand, - int64_t mode) { - FPs.push_back({0}); - __enzyme_fp *a = &FPs.back(); - return a; +double __enzyme_fprt_64_52_const(double _a, int64_t exponent, + int64_t significand, int64_t mode, + const char *loc) { + // TODO This should really be called only once for an appearance in the code, + // currently it is called every time a flop uses a constant. + __enzyme_fp *a = + __enzyme_fprt_64_52_new_intermediate(exponent, significand, mode, loc); + FPs.consts.push_back(a); + __enzyme_fprt_trace_flop({}, _a, a, nullptr, "const", loc); + auto ret = __enzyme_fprt_ptr_to_double(a); + return ret; } -__ENZYME_MPFR_ATTRIBUTES void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, const char *loc) { // TODO + __enzyme_fprt_trace_no_res_flop({a}, "delete", loc); +} + +// Below sensitivity computation is taken frmo ADAPT +static double __enzyme_estimate_truncation_error(double a) { + return abs(a - (float)a); +} + +void __enzyme_fprt_delete_all() { + size_t size = FPs.all.size(); + size_t i = 0; + for (auto it = FPs.all.begin(); it != FPs.all.end(); i++, it++) { + // Do not truncate inputs + if (std::find(FPs.inputs.begin(), FPs.inputs.end(), &*it) != + FPs.inputs.end()) + continue; + // Or consts + if (std::find(FPs.consts.begin(), FPs.consts.end(), &*it) != + FPs.consts.end()) + continue; + + // Zero out all errors + // TODO is it faster to calloc each time or should we pre-allocate and + // memset? 
+ double *errors = (double *)std::calloc(size, sizeof(*errors)); + // Introduce truncation error into the current op + // TODO we can probably re-run the original operation in the truncated + // precision thus get the real error and not an estimation + errors[i] = __enzyme_estimate_truncation_error(it->getResult()); + + size_t j = i; + for (auto jt = it; jt != FPs.all.end(); j++, jt++) + for (unsigned char k = 0; k < jt->getInputNum(); k++) + errors[j] += abs(jt->getDerivative(k) * errors[jt->getInput(k)->id]); + +#if ENZYME_FPRT_TRACE_PRINT + std::cerr << "For instance "; + print_enzyme_fp_value(std::cerr, &*it); + std::cerr << " when truncated from double to float:" << std::endl; + + for (__enzyme_fp *output : FPs.outputs) { + std::cerr << " wrt output "; + print_enzyme_fp_value(std::cerr, output); + std::cerr << " at " << output->getLoc() + << ", sensitivity = " << errors[output->id] << std::endl; + } +#endif + } + FPs.clear(); } #define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ - RET __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, int64_t exponent, int64_t significand, int64_t mode); \ + RET __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(ARG1 a); \ __ENZYME_MPFR_ATTRIBUTES \ RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, int64_t exponent, int64_t significand, int64_t mode) { \ - RET res = \ - __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(a); \ - __enzyme_trace_flop({a}, ret, #LLVM_OP_NAME); \ - return res; \ + ARG1 a, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + auto originalfn = \ + __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME; \ + RET res = originalfn(__enzyme_fprt_double_to_ptr(a)->getResult()); \ + __enzyme_fp *intermediate = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + 
intermediate->setResult(res); \ + double ret = __enzyme_fprt_ptr_to_double(intermediate); \ + __enzyme_fprt_trace_flop({a}, res, intermediate, \ + (void *)originalfn, #LLVM_OP_NAME, loc); \ + return ret; \ } // TODO this is a bit sketchy if the user cast their float to int before calling @@ -107,53 +398,93 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, ARG2 b); \ __ENZYME_MPFR_ATTRIBUTES RET \ __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, ARG2 b, int64_t exponent, int64_t significand, \ - int64_t mode) { \ - RET res = \ - __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(a, b); \ - __enzyme_trace_flop({a, b}, ret, #LLVM_OP_NAME); \ - return res; \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + auto originalfn = \ + __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME; \ + RET res = originalfn(__enzyme_fprt_double_to_ptr(a)->getResult(), \ + __enzyme_fprt_double_to_ptr(b)->getResult()); \ + __enzyme_fp *intermediate = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + intermediate->setResult(res); \ + double ret = __enzyme_fprt_ptr_to_double(intermediate); \ + __enzyme_fprt_trace_flop({a}, res, intermediate, \ + (void *)originalfn, #LLVM_OP_NAME, loc); \ + return ret; \ } #define __ENZYME_MPFR_BIN(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ MPFR_SET_ARG2, ROUNDING_MODE) \ __ENZYME_MPFR_ORIGINAL_ATTRIBUTES \ - RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(ARG1 a, ARG2 b); \ + RET __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(ARG1 a, \ + ARG2 b); \ __ENZYME_MPFR_ATTRIBUTES \ RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode) { \ - RET res = \ - __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(a, b); \ - __enzyme_trace_flop({a, 
b}, ret, #LLVM_OP_NAME); \ - return res; \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + auto originalfn = \ + __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME; \ + RET res = originalfn(__enzyme_fprt_double_to_ptr(a)->getResult(), \ + __enzyme_fprt_double_to_ptr(b)->getResult()); \ + __enzyme_fp *intermediate = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + intermediate->setResult(res); \ + double ret = __enzyme_fprt_ptr_to_double(intermediate); \ + __enzyme_fprt_trace_flop({a, b}, res, intermediate, \ + (void *)originalfn, #LLVM_OP_NAME, loc); \ + return ret; \ } -#define __ENZYME_MPFR_FMULADD(LLVM_OP_NAME, FROM_TYPE, TYPE, MPFR_TYPE, \ - LLVM_TYPE, ROUNDING_MODE) \ +#define __ENZYME_MPFR_FMULADD(LLVM_OP_NAME, FROM_TYPE, TYPE, MPFR_TYPE, \ + LLVM_TYPE, ROUNDING_MODE) \ + __ENZYME_MPFR_ORIGINAL_ATTRIBUTES \ + TYPE __enzyme_fprt_original_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE( \ + TYPE a, TYPE b, TYPE c); \ + __ENZYME_MPFR_ATTRIBUTES \ + TYPE __enzyme_fprt_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE( \ + TYPE a, TYPE b, TYPE c, int64_t exponent, int64_t significand, \ + int64_t mode, const char *loc) { \ + auto originalfn = \ + __enzyme_fprt_original_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE; \ + TYPE res = originalfn(__enzyme_fprt_double_to_ptr(a)->getResult(), \ + __enzyme_fprt_double_to_ptr(b)->getResult(), \ + __enzyme_fprt_double_to_ptr(c)->getResult()); \ + __enzyme_fp *intermediate = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + intermediate->setResult(res); \ + double ret = __enzyme_fprt_ptr_to_double(intermediate); \ + __enzyme_fprt_trace_flop({a, b, c}, res, intermediate, \ + (void *)originalfn, #LLVM_OP_NAME, loc); \ + return ret; \ + } + +#define __ENZYME_MPFR_FCMP_IMPL(NAME, ORDERED, CMP, FROM_TYPE, TYPE, MPFR_GET, \ + ROUNDING_MODE) \ __ENZYME_MPFR_ORIGINAL_ATTRIBUTES \ - RET 
__enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME(TYPE a, TYPE b, \ - TYPE c); \ + bool __enzyme_fprt_original_##FROM_TYPE##_fcmp_##NAME(TYPE a, TYPE b); \ __ENZYME_MPFR_ATTRIBUTES \ - TYPE __enzyme_fprt_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE( \ - TYPE a, TYPE b, TYPE c, int64_t exponent, int64_t significand, \ - int64_t mode) { \ - RET res = __enzyme_fprt_original_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - a, b, c); \ - __enzyme_trace_flop({a, b, c}, ret, #LLVM_OP_NAME); \ + bool __enzyme_fprt_##FROM_TYPE##_fcmp_##NAME( \ + TYPE a, TYPE b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + bool res = __enzyme_fprt_original_##FROM_TYPE##_fcmp_##NAME( \ + __enzyme_fprt_double_to_ptr(a)->getResult(), \ + __enzyme_fprt_double_to_ptr(b)->getResult()); \ + __enzyme_fprt_trace_no_res_flop({a, b}, "fcmp_" #NAME, loc); \ return res; \ } __ENZYME_MPFR_ORIGINAL_ATTRIBUTES bool __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64(double a, int32_t tests); -__ENZYME_MPFR_ATTRIBUTES bool -__enzyme_fprt_64_52_intr_llvm_is_fpclass_f64(double a, int32_t tests) { - return __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64(a, tests); +__ENZYME_MPFR_ATTRIBUTES bool __enzyme_fprt_64_52_intr_llvm_is_fpclass_f64( + double a, int32_t tests, int64_t exponent, int64_t significand, + int64_t mode, const char *loc) { + __enzyme_fprt_trace_no_res_flop({a}, "llvm_is_fpclass_f64", loc); + return __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64( + __enzyme_fprt_double_to_ptr(a)->getResult(), tests); } -#include "enzyme/fprt/flops.def" +#include } // extern "C" - -#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ diff --git a/enzyme/include/enzyme/fprt/flops.def b/enzyme/include/enzyme/fprt/flops.def index a3c4d7fcbac6..62e4b48ba00d 100644 --- a/enzyme/include/enzyme/fprt/flops.def +++ b/enzyme/include/enzyme/fprt/flops.def @@ -1,3 +1,4 @@ +// -*- mode: c++ -*- #define __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ ROUNDING_MODE) \ 
@@ -112,6 +113,9 @@ __ENZYME_MPFR_SINGOP_DOUBLE_FLOAT(lgamma, lngamma); // TODO This is not accurate (I think we cast int to double) __ENZYME_MPFR_SINGOP_DOUBLE_FLOAT(nearbyint, rint); +__ENZYME_MPFR_SINGOP(unaryop, fneg, neg, 64_52, double, d, double, + d, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) + // Ternary operation __ENZYME_MPFR_FMULADD(llvm_fmuladd, 64_52, double, d, f64, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); diff --git a/enzyme/include/enzyme/fprt/fprt.h b/enzyme/include/enzyme/fprt/fprt.h new file mode 100644 index 000000000000..2796c76536ab --- /dev/null +++ b/enzyme/include/enzyme/fprt/fprt.h @@ -0,0 +1,56 @@ +#ifndef _ENZYME_FPRT_FPRT_H_ +#define _ENZYME_FPRT_FPRT_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// User-facing API +double __enzyme_fprt_64_52_get(double _a, int64_t exponent, int64_t significand, + int64_t mode, const char *loc); +double __enzyme_fprt_64_52_new(double _a, int64_t exponent, int64_t significand, + int64_t mode, const char *loc); +void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, + int64_t mode, const char *loc); +double __enzyme_truncate_mem_value_d(double, int, int); +float __enzyme_truncate_mem_value_f(float, int, int); +double __enzyme_expand_mem_value_d(double, int, int); +float __enzyme_expand_mem_value_f(float, int, int); +void __enzyme_fprt_delete_all(); + +// For internal use +struct __enzyme_fp; +__enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, + int64_t significand, + int64_t mode, + const char *loc); +double __enzyme_fprt_64_52_const(double _a, int64_t exponent, + int64_t significand, int64_t mode, + const char *loc); + +[[maybe_unused]] static bool __enzyme_fprt_is_mem_mode(int64_t mode) { + return mode & 0b0001; +} +[[maybe_unused]] static bool __enzyme_fprt_is_op_mode(int64_t mode) { + return mode & 0b0010; +} +[[maybe_unused]] static double __enzyme_fprt_idx_to_double(uint64_t p) { + return *((double *)(&p)); +} +[[maybe_unused]] static uint64_t 
__enzyme_fprt_double_to_idx(double d) { + return *((uint64_t *)(&d)); +} +[[maybe_unused]] static double __enzyme_fprt_ptr_to_double(__enzyme_fp *p) { + return *((double *)(&p)); +} +[[maybe_unused]] static __enzyme_fp *__enzyme_fprt_double_to_ptr(double d) { + return *((__enzyme_fp **)(&d)); +} + +#ifdef __cplusplus +} +#endif + +#endif // _ENZYME_FPRT_FPRT_H_ diff --git a/enzyme/include/enzyme/fprt/mpfr-test.h b/enzyme/include/enzyme/fprt/mpfr-test.h new file mode 100644 index 000000000000..5a48977d256c --- /dev/null +++ b/enzyme/include/enzyme/fprt/mpfr-test.h @@ -0,0 +1,271 @@ +//===- fprt/mpfr - MPFR wrappers ---------------------------------------===// +// +// Enzyme Project +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// If using this code in an academic setting, please cite the following: +// @incollection{enzymeNeurips, +// title = {Instead of Rewriting Foreign Code for Machine Learning, +// Automatically Synthesize Fast Gradients}, +// author = {Moses, William S. and Churavy, Valentin}, +// booktitle = {Advances in Neural Information Processing Systems 33}, +// year = {2020}, +// note = {To appear in}, +// } +// +//===----------------------------------------------------------------------===// +// +// This file contains easy to use wrappers around MPFR functions. 
+// +//===----------------------------------------------------------------------===// +#ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ +#define __ENZYME_RUNTIME_ENZYME_MPFR__ + +#include +#include +#include + +#include "fprt.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define __ENZYME_MPFR_ATTRIBUTES __attribute__((weak)) +#define __ENZYME_MPFR_ORIGINAL_ATTRIBUTES __attribute__((weak)) +#define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN + +typedef struct __enzyme_fp { + mpfr_t result; +} __enzyme_fp; + +__ENZYME_MPFR_ATTRIBUTES +double __enzyme_fprt_64_52_get(double _a, int64_t exponent, int64_t significand, + int64_t mode, const char *loc) { + printf("%p, %s\n", loc, loc); + __enzyme_fp *a = __enzyme_fprt_double_to_ptr(_a); + return mpfr_get_d(a->result, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); +} + +__ENZYME_MPFR_ATTRIBUTES +double __enzyme_fprt_64_52_new(double _a, int64_t exponent, int64_t significand, + int64_t mode, const char *loc) { + printf("%p, %s\n", loc, loc); + __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); + mpfr_init2(a->result, significand); + mpfr_set_d(a->result, _a, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); + return __enzyme_fprt_ptr_to_double(a); +} + +__ENZYME_MPFR_ATTRIBUTES +double __enzyme_fprt_64_52_const(double _a, int64_t exponent, + int64_t significand, int64_t mode, + const char *loc) { + printf("%p, %s\n", loc, loc); + // TODO This should really be called only once for an appearance in the code, + // currently it is called every time a flop uses a constant. 
+ return __enzyme_fprt_64_52_new(_a, exponent, significand, mode, loc); +} + +__ENZYME_MPFR_ATTRIBUTES +__enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, + int64_t significand, + int64_t mode, + const char *loc) { + printf("%p, %s\n", loc, loc); + __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); + mpfr_init2(a->result, significand); + return a; +} + +__ENZYME_MPFR_ATTRIBUTES +void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, + int64_t mode, const char *loc) { + printf("%p, %s\n", loc, loc); + free(__enzyme_fprt_double_to_ptr(a)); +} + +#define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ + ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + printf("%p, %s, %s\n", loc, #LLVM_OP_NAME, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mc); \ + return c; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->result, ma->result, ROUNDING_MODE); \ + return __enzyme_fprt_ptr_to_double(mc); \ + } else { \ + abort(); \ + } \ + } + +// TODO this is a bit sketchy if the user cast their float to int before calling +// this. 
We need to detect these patterns +#define __ENZYME_MPFR_BIN_INT(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, \ + FROM_TYPE, RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ + ARG2, ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + printf("%p, %s, %s\n", loc, #LLVM_OP_NAME, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, b, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mc); \ + return c; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->result, ma->result, b, ROUNDING_MODE); \ + return __enzyme_fprt_ptr_to_double(mc); \ + } else { \ + abort(); \ + } \ + } + +#define __ENZYME_MPFR_BIN(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ + MPFR_SET_ARG2, ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + printf("%p, %s, %s\n", loc, #LLVM_OP_NAME, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mb, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_SET_ARG2(mb, b, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + mpfr_clear(mc); \ + return c; \ + } else if 
(__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->result, ma->result, mb->result, \ + ROUNDING_MODE); \ + return __enzyme_fprt_ptr_to_double(mc); \ + } else { \ + abort(); \ + } \ + } + +#define __ENZYME_MPFR_FMULADD(LLVM_OP_NAME, FROM_TYPE, TYPE, MPFR_TYPE, \ + LLVM_TYPE, ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + TYPE __enzyme_fprt_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE( \ + TYPE a, TYPE b, TYPE c, int64_t exponent, int64_t significand, \ + int64_t mode, const char *loc) { \ + printf("%p, %s, %s\n", loc, #LLVM_OP_NAME, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mb, mc, mmul, madd; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_init2(mc, significand); \ + mpfr_init2(mmul, significand); \ + mpfr_init2(madd, significand); \ + mpfr_set_##MPFR_TYPE(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_TYPE(mb, b, ROUNDING_MODE); \ + mpfr_set_##MPFR_TYPE(mc, c, ROUNDING_MODE); \ + mpfr_mul(mmul, ma, mb, ROUNDING_MODE); \ + mpfr_add(madd, mmul, mc, ROUNDING_MODE); \ + TYPE res = mpfr_get_##MPFR_TYPE(madd, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + mpfr_clear(mc); \ + mpfr_clear(mmul); \ + mpfr_clear(madd); \ + return res; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ + __enzyme_fp *mc = __enzyme_fprt_double_to_ptr(c); \ + double mmul = __enzyme_fprt_##FROM_TYPE##_binop_fmul( \ + __enzyme_fprt_ptr_to_double(ma), __enzyme_fprt_ptr_to_double(mb), \ + exponent, significand, mode, loc); \ + double madd = __enzyme_fprt_##FROM_TYPE##_binop_fadd( \ + mmul, __enzyme_fprt_ptr_to_double(mc), exponent, significand, mode, \ + loc); \ + return madd; \ + } else { \ + abort(); \ + } \ + } 
+ +// TODO This does not currently make distinctions between ordered/unordered. +#define __ENZYME_MPFR_FCMP_IMPL(NAME, ORDERED, CMP, FROM_TYPE, TYPE, MPFR_GET, \ + ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + bool __enzyme_fprt_##FROM_TYPE##_fcmp_##NAME( \ + TYPE a, TYPE b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ + printf("%p, %s, %s\n", loc, "fcmp" #NAME, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mb; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_set_##MPFR_GET(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_GET(mb, b, ROUNDING_MODE); \ + int ret = mpfr_cmp(ma, mb); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + return ret CMP; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ + int ret = mpfr_cmp(ma->result, mb->result); \ + return ret CMP; \ + } else { \ + abort(); \ + } \ + } + +__ENZYME_MPFR_ORIGINAL_ATTRIBUTES +bool __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64(double a, + int32_t tests); +__ENZYME_MPFR_ATTRIBUTES bool __enzyme_fprt_64_52_intr_llvm_is_fpclass_f64( + double a, int32_t tests, int64_t exponent, int64_t significand, + int64_t mode, const char *loc) { + return __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64( + __enzyme_fprt_64_52_get(a, exponent, significand, mode, loc), tests); +} + +#include "flops.def" + +#ifdef __cplusplus +} +#endif + +#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ diff --git a/enzyme/include/enzyme/fprt/mpfr.h b/enzyme/include/enzyme/fprt/mpfr.h index a75cfbd84f15..58783f86242b 100644 --- a/enzyme/include/enzyme/fprt/mpfr.h +++ b/enzyme/include/enzyme/fprt/mpfr.h @@ -28,22 +28,12 @@ #include #include +#include "fprt.h" + #ifdef __cplusplus extern "C" { #endif -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// I dont think we intercept comparisons - we most definitely 
should. -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO - // TODO s // // (for MPFR ver. 2.1) @@ -73,66 +63,55 @@ extern "C" { // simulation: // [...] subnormal numbers are not implemented. // -// TODO maybe take debug info as parameter - then we can emit warnings or tie -// operations to source location -// // TODO we need to provide f32 versions, and also instrument the // truncation/expansion between f32/f64/etc -#define __ENZYME_MPFR_ATTRIBUTES __attribute__((weak)) -#define __ENZYME_MPFR_ORIGINAL_ATTRIBUTES __attribute__((weak)) +#define __ENZYME_MPFR_ATTRIBUTES __attribute__((weak)) __attribute__((used)) +#define __ENZYME_MPFR_ORIGINAL_ATTRIBUTES __attribute__((weak)) __attribute__((used)) #define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN -static bool __enzyme_fprt_is_mem_mode(int64_t mode) { return mode & 0b0001; } -static bool __enzyme_fprt_is_op_mode(int64_t mode) { return mode & 0b0010; } - -typedef struct { - mpfr_t v; +typedef struct __enzyme_fp { + mpfr_t result; } __enzyme_fp; -static double __enzyme_fprt_ptr_to_double(__enzyme_fp *p) { - return *((double *)(&p)); -} -static __enzyme_fp *__enzyme_fprt_double_to_ptr(double d) { - return *((__enzyme_fp **)(&d)); -} - __ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_get(double _a, int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, const char *loc) { __enzyme_fp *a = __enzyme_fprt_double_to_ptr(_a); - return mpfr_get_d(a->v, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); + return mpfr_get_d(a->result, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); } __ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_new(double _a, int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, const char *loc) { __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); - mpfr_init2(a->v, significand); - mpfr_set_d(a->v, _a, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); + mpfr_init2(a->result, significand); + mpfr_set_d(a->result, _a, 
__ENZYME_MPFR_DEFAULT_ROUNDING_MODE); return __enzyme_fprt_ptr_to_double(a); } __ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_const(double _a, int64_t exponent, - int64_t significand, int64_t mode) { + int64_t significand, int64_t mode, + const char *loc) { // TODO This should really be called only once for an appearance in the code, // currently it is called every time a flop uses a constant. - return __enzyme_fprt_64_52_new(_a, exponent, significand, mode); + return __enzyme_fprt_64_52_new(_a, exponent, significand, mode, loc); } __ENZYME_MPFR_ATTRIBUTES __enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, + const char *loc) { __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); - mpfr_init2(a->v, significand); + mpfr_init2(a->result, significand); return a; } __ENZYME_MPFR_ATTRIBUTES void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, const char *loc) { free(__enzyme_fprt_double_to_ptr(a)); } @@ -141,7 +120,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, int64_t exponent, int64_t significand, int64_t mode) { \ + ARG1 a, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mc; \ mpfr_init2(ma, significand); \ @@ -154,9 +134,9 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, return c; \ } else if (__enzyme_fprt_is_mem_mode(mode)) { \ __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ - __enzyme_fp *mc = \ - __enzyme_fprt_64_52_new_intermediate(exponent, significand, mode); \ - mpfr_##MPFR_FUNC_NAME(mc->v, ma->v, ROUNDING_MODE); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + 
mpfr_##MPFR_FUNC_NAME(mc->result, ma->result, ROUNDING_MODE); \ return __enzyme_fprt_ptr_to_double(mc); \ } else { \ abort(); \ @@ -170,7 +150,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, ARG2, ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode) { \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mc; \ mpfr_init2(ma, significand); \ @@ -183,9 +164,9 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, return c; \ } else if (__enzyme_fprt_is_mem_mode(mode)) { \ __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ - __enzyme_fp *mc = \ - __enzyme_fprt_64_52_new_intermediate(exponent, significand, mode); \ - mpfr_##MPFR_FUNC_NAME(mc->v, ma->v, b, ROUNDING_MODE); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->result, ma->result, b, ROUNDING_MODE); \ return __enzyme_fprt_ptr_to_double(mc); \ } else { \ abort(); \ @@ -197,7 +178,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, MPFR_SET_ARG2, ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode) { \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mb, mc; \ mpfr_init2(ma, significand); \ @@ -214,9 +196,10 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, } else if (__enzyme_fprt_is_mem_mode(mode)) { \ __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ - __enzyme_fp *mc = \ - 
__enzyme_fprt_64_52_new_intermediate(exponent, significand, mode); \ - mpfr_##MPFR_FUNC_NAME(mc->v, ma->v, mb->v, ROUNDING_MODE); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->result, ma->result, mb->result, \ + ROUNDING_MODE); \ return __enzyme_fprt_ptr_to_double(mc); \ } else { \ abort(); \ @@ -228,7 +211,7 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, __ENZYME_MPFR_ATTRIBUTES \ TYPE __enzyme_fprt_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE( \ TYPE a, TYPE b, TYPE c, int64_t exponent, int64_t significand, \ - int64_t mode) { \ + int64_t mode, const char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mb, mc, mmul, madd; \ mpfr_init2(ma, significand); \ @@ -254,9 +237,10 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, __enzyme_fp *mc = __enzyme_fprt_double_to_ptr(c); \ double mmul = __enzyme_fprt_##FROM_TYPE##_binop_fmul( \ __enzyme_fprt_ptr_to_double(ma), __enzyme_fprt_ptr_to_double(mb), \ - exponent, significand, mode); \ + exponent, significand, mode, loc); \ double madd = __enzyme_fprt_##FROM_TYPE##_binop_fadd( \ - mmul, __enzyme_fprt_ptr_to_double(mc), exponent, significand, mode); \ + mmul, __enzyme_fprt_ptr_to_double(mc), exponent, significand, mode, \ + loc); \ return madd; \ } else { \ abort(); \ @@ -268,7 +252,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ bool __enzyme_fprt_##FROM_TYPE##_fcmp_##NAME( \ - TYPE a, TYPE b, int64_t exponent, int64_t significand, int64_t mode) { \ + TYPE a, TYPE b, int64_t exponent, int64_t significand, int64_t mode, \ + const char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mb; \ mpfr_init2(ma, significand); \ @@ -282,7 +267,7 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, } else if (__enzyme_fprt_is_mem_mode(mode)) { 
\ __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ - int ret = mpfr_cmp(ma->v, mb->v); \ + int ret = mpfr_cmp(ma->result, mb->result); \ return ret CMP; \ } else { \ abort(); \ @@ -292,9 +277,11 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, __ENZYME_MPFR_ORIGINAL_ATTRIBUTES bool __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64(double a, int32_t tests); -__ENZYME_MPFR_ATTRIBUTES bool -__enzyme_fprt_64_52_intr_llvm_is_fpclass_f64(double a, int32_t tests) { - return __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64(a, tests); +__ENZYME_MPFR_ATTRIBUTES bool __enzyme_fprt_64_52_intr_llvm_is_fpclass_f64( + double a, int32_t tests, int64_t exponent, int64_t significand, + int64_t mode, const char *loc) { + return __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64( + __enzyme_fprt_64_52_get(a, exponent, significand, mode, loc), tests); } #include "flops.def" diff --git a/enzyme/test/Enzyme/ForwardMode/hypot.ll b/enzyme/test/Enzyme/ForwardMode/hypot.ll index 564bab7cd72d..623f4094f077 100644 --- a/enzyme/test/Enzyme/ForwardMode/hypot.ll +++ b/enzyme/test/Enzyme/ForwardMode/hypot.ll @@ -8,13 +8,26 @@ entry: ret double %call } +define double @tester2(double %x, double %y) { +entry: + %call = tail call double @__hypot_finite(double %x, double %y) + ret double %call +} + define double @test_derivative(double %x, double %y) { entry: %0 = tail call double (...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, double %x, double 1.000000e+00, double %y, double 1.000000e+00) ret double %0 } +define double @test_derivative2(double %x, double %y) { +entry: + %0 = tail call double (...) 
@__enzyme_fwddiff(double (double, double)* nonnull @tester2, double %x, double 1.000000e+00, double %y, double 1.000000e+00) + ret double %0 +} + declare double @hypot(double, double) +declare double @__hypot_finite(double, double) ; Function Attrs: nounwind declare double @__enzyme_fwddiff(...) diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll index d33c40d7de11..15140bdb5f75 100644 --- a/enzyme/test/Enzyme/Truncate/cmp.ll +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -29,7 +29,7 @@ entry: } ; CHECK: define internal i1 @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { -; CHECK-NEXT: %res = call i1 @__enzyme_fprt_64_52_fcmp_olt(double %x, double %y, i64 8, i64 23, i64 1) +; CHECK-NEXT: %res = call i1 @__enzyme_fprt_64_52_fcmp_olt(double %x, double %y, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret i1 %res ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/Truncate/const.ll b/enzyme/test/Enzyme/Truncate/const.ll index 25c5c5ee4c3b..b90b20615a93 100644 --- a/enzyme/test/Enzyme/Truncate/const.ll +++ b/enzyme/test/Enzyme/Truncate/const.ll @@ -23,12 +23,12 @@ entry: } ; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x) { -; CHECK-NEXT: %1 = call double @__enzyme_fprt_64_52_const(double 1.000000e+00, i64 8, i64 23, i64 1) -; CHECK-NEXT: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %x, double %1, i64 8, i64 23, i64 1) +; CHECK-NEXT: %1 = call double @__enzyme_fprt_64_52_const(double 1.000000e+00, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) +; CHECK-NEXT: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %x, double %1, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret double %res ; CHECK-NEXT: } ; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to11_7_f(double %x) { -; CHECK-NEXT: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %x, double 1.000000e+00, i64 3, i64 7, i64 2) +; CHECK-NEXT: %res = call double 
@__enzyme_fprt_64_52_binop_fadd(double %x, double 1.000000e+00, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret double %res ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll index 3e5fe36b3784..a5899e75c68c 100644 --- a/enzyme/test/Enzyme/Truncate/intrinsic.ll +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -42,28 +42,28 @@ entry: } ; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { -; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 8, i64 23, i64 1) -; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 8, i64 23, i64 1) -; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 8, i64 23, i64 1) -; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 8, i64 23, i64 1) +; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) +; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) +; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) +; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-DAG: call void @llvm.nvvm.barrier0() ; CHECK-DAG: ret double %res ; CHECK-DAG: } ; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y) { -; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 8, i64 23, i64 2) -; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 8, i64 23, i64 2) -; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 8, i64 23, i64 2) -; 
CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 8, i64 23, i64 2) +; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) ; CHECK-DAG: call void @llvm.nvvm.barrier0() ; CHECK-DAG: ret double %res ; CHECK-DAG: } ; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to11_7_f(double %x, double %y) { -; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 3, i64 7, i64 2) -; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 3, i64 7, i64 2) -; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 3, i64 7, i64 2) -; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 3, i64 7, i64 2) +; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) ; CHECK-DAG: call void @llvm.nvvm.barrier0() ; CHECK-DAG: ret double %res ; CHECK-DAG: } diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll index 747e268ae381..cd94c87aba46 100644 --- 
a/enzyme/test/Enzyme/Truncate/simple.ll +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -36,21 +36,21 @@ entry: ; CHECK: define internal void @__enzyme_done_truncate_mem_func_64_52to32_23_f(double* %x) { ; CHECK-DAG: %y = load double, double* %x, align 8 -; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 8, i64 23, i64 1) +; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-DAG: store double %m, double* %x, align 8 ; CHECK-DAG: ret void ; CHECK-DAG: } ; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to32_23_f(double* %x) { ; CHECK-DAG: %y = load double, double* %x, align 8 -; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 8, i64 23, i64 2) +; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) ; CHECK-DAG: store double %m, double* %x, align 8 ; CHECK-DAG: ret void ; CHECK-DAG: } ; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to11_7_f(double* %x) { ; CHECK-DAG: %y = load double, double* %x, align 8 -; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 3, i64 7, i64 2) +; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) ; CHECK-DAG: store double %m, double* %x, align 8 ; CHECK-DAG: ret void ; CHECK-DAG: } diff --git a/enzyme/test/Enzyme/Truncate/value.ll b/enzyme/test/Enzyme/Truncate/value.ll index fa79e93440bb..1722b4fc1efc 100644 --- a/enzyme/test/Enzyme/Truncate/value.ll +++ b/enzyme/test/Enzyme/Truncate/value.ll @@ -18,10 +18,10 @@ entry: ; CHECK: define double @expand_tester(double %a, double* %c) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call double @__enzyme_fprt_64_52_get(double %a, i64 8, i64 23, i64 1) +; CHECK-NEXT: %0 = call double @__enzyme_fprt_64_52_get(double %a, i64 8, i64 23, i64 1, 
{{.*}}i8{{.*}}) ; CHECK-NEXT: ret double %0 ; CHECK: define double @truncate_tester(double %a) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call double @__enzyme_fprt_64_52_new(double %a, i64 8, i64 23, i64 1) +; CHECK-NEXT: %0 = call double @__enzyme_fprt_64_52_new(double %a, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret double %0 diff --git a/enzyme/test/Integration/Truncate/simple.cpp b/enzyme/test/Integration/Truncate/simple.cpp index 635a2e3bc04c..dff19c9a1e45 100644 --- a/enzyme/test/Integration/Truncate/simple.cpp +++ b/enzyme/test/Integration/Truncate/simple.cpp @@ -5,6 +5,9 @@ // RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -DTRUNC_MEM -DTRUNC_OP -O2 %s -o %s.a.out %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -lm -lmpfr && %s.a.out ; fi // RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -g -DTRUNC_MEM -DTRUNC_OP -O2 %s -o %s.a.out %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -lm -lmpfr && %s.a.out ; fi +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -DTRUNC_MEM -DTRUNC_OP -O2 %s -o %s.a.out %newLoadClangEnzyme -include enzyme/fprt/mpfr-test.h -lm -lmpfr && %s.a.out ; fi +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -g -DTRUNC_MEM -DTRUNC_OP -O2 %s -o %s.a.out %newLoadClangEnzyme -include enzyme/fprt/mpfr-test.h -lm -lmpfr && %s.a.out ; fi + #include #include "../test_utils.h" @@ -27,21 +30,34 @@ double intrinsics(double a, double b) { double constt(double a, double b) { return 2; } +void const_store(double *a) { + *a = 2.0; +} +double phinode(double a, double b, int n) { + double sum = 0; + for (int i = 0; i < n; i++) { + sum += (exp(a + b) - exp(a)) / b; + b /= 10; + } + return sum; +} double compute(double *A, double *B, double *C, int n) { for (int i = 0; i < n; i++) { C[i] = A[i] * 2 + B[i] * sqrt(A[i]); } return C[0]; } +double intcast(int a) { + double d = (double) a; + return d / 3.14; +} typedef double (*fty)(double *, double *, double *, int); 
typedef double (*fty2)(double, double); -extern fty __enzyme_truncate_mem_func_2(...); -extern fty2 __enzyme_truncate_mem_func(...); -extern fty __enzyme_truncate_op_func_2(...); -extern fty2 __enzyme_truncate_op_func(...); +template <typename fty> fty *__enzyme_truncate_mem_func(fty *, int, int); +template <typename fty> fty *__enzyme_truncate_op_func(fty *, int, int); extern double __enzyme_truncate_mem_value(...); extern double __enzyme_expand_mem_value(...); @@ -89,16 +105,36 @@ int main() { double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(intrinsics, FROM, TO)(a, b), FROM, TO); APPROX_EQ(trunc, truth, 1e-5); } + { + double a = 2; + double b = 3; + double truth = constt(a, b); + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(constt, FROM, TO)(a, b), FROM, TO); + APPROX_EQ(trunc, truth, 1e-5); + } + { + double a = 2; + double b = 3; + double truth = phinode(a, b, 10); + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(phinode, FROM, TO)(a, b, 10), FROM, TO); + APPROX_EQ(trunc, truth, 20.0); + } + { + double truth = 0; + const_store(&truth); + double a = 0; + __enzyme_truncate_mem_func(const_store, FROM, TO)(&a); + a = __enzyme_expand_mem_value(a, FROM, TO); + APPROX_EQ(a, truth, 1e-5); + } + { + __enzyme_truncate_mem_func(intcast, FROM, TO)(64); + } #endif - // { - // double a = 2; - // double b = 3; - // double truth = intrinsics(a, b); - // a = __enzyme_truncate_mem_value(a, FROM, TO); - // b = __enzyme_truncate_mem_value(b, FROM, TO); - // double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(constt, FROM, TO)(a, b), FROM, TO); - // APPROX_EQ(trunc, truth, 1e-5); - // } #ifdef TRUNC_OP { @@ -120,7 +156,7 @@ int main() { // B[i] = __enzyme_truncate_mem_value(B[i], 64, 32); // } - __enzyme_truncate_op_func_2(compute, 64, 32)(A, B, 
C, N); + __enzyme_truncate_op_func(compute, 64, 32)(A, B, C, N); // for (int i = 0; i < N; i++) { // C[i] = __enzyme_expand_mem_value(C[i], 64, 32); diff --git a/enzyme/test/Integration/Truncate/truncate-all-header.h b/enzyme/test/Integration/Truncate/truncate-all-header.h new file mode 100644 index 000000000000..3fd9f0780365 --- /dev/null +++ b/enzyme/test/Integration/Truncate/truncate-all-header.h @@ -0,0 +1,15 @@ +#ifndef TRUNCATE_ALL_HEADER_H_ +#define TRUNCATE_ALL_HEADER_H_ + +#include + +#define N 6 + +#define floatty double + +__attribute__((noinline)) static +floatty intrinsics2(floatty a, floatty b) { + return sin(a) * cos(b); +} + +#endif // TRUNCATE_ALL_HEADER_H_ diff --git a/enzyme/test/Integration/Truncate/truncate-all.cpp b/enzyme/test/Integration/Truncate/truncate-all.cpp index d5038d4750cb..818c2c603cac 100644 --- a/enzyme/test/Integration/Truncate/truncate-all.cpp +++ b/enzyme/test/Integration/Truncate/truncate-all.cpp @@ -16,16 +16,27 @@ // RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -include enzyme/fprt/mpfr.h -O3 %s -o %s.a.out %newLoadClangEnzyme -mllvm --enzyme-truncate-all="11-52to3-7" -lmpfr -lm && %s.a.out | FileCheck --check-prefix TO_3_7 %s; fi // TO_3_7: 897581056.000000 -#include - +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -g -include enzyme/fprt/mpfr-test.h -O3 %s -o %s.a.out %newLoadClangEnzyme -mllvm --enzyme-truncate-all="11-52to3-7" -lmpfr -lm && %s.a.out | FileCheck --check-prefix CHECK-LOCS %s; fi +// CHECK-LOCS: 0x[[op1:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op1loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op2:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op2loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op3:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op3loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op4:[0-9a-f]*]], {{.*}}truncate-all-header.h:[[op4loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op5:[0-9a-f]*]], {{.*}}truncate-all-header.h:[[op5loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op6:[0-9a-f]*]], {{.*}}truncate-all-header.h:[[op6loc:.*]] +// 
CHECK-LOCS-NEXT: 0x[[op7:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op7loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op1]], {{.*}}truncate-all.cpp:[[op1loc]] +// CHECK-LOCS-NEXT: 0x[[op2]], {{.*}}truncate-all.cpp:[[op2loc]] +// CHECK-LOCS-NEXT: 0x[[op3]], {{.*}}truncate-all.cpp:[[op3loc]] +// CHECK-LOCS-NEXT: 0x[[op4]], {{.*}}truncate-all-header.h:[[op4loc]] +// CHECK-LOCS-NEXT: 0x[[op5]], {{.*}}truncate-all-header.h:[[op5loc]] +// CHECK-LOCS-NEXT: 0x[[op6]], {{.*}}truncate-all-header.h:[[op6loc]] +// CHECK-LOCS-NEXT: 0x[[op7]], {{.*}}truncate-all.cpp:[[op7loc]] + + +#include "truncate-all-header.h" #include "../test_utils.h" -#define N 10 - -#define floatty double - - __attribute__((noinline)) floatty simple_add(floatty a, floatty b) { return a + b; @@ -35,6 +46,13 @@ floatty intrinsics(floatty a, floatty b) { return sqrt(a) * pow(b, 2); } __attribute__((noinline)) +floatty compute2(floatty *A, floatty *B, floatty *C, int n) { + for (int i = 0; i < n; i++) { + C[i] = A[i] / 2 + intrinsics2(A[i], simple_add(B[i] * 10000, 0.000001)); + } + return C[0]; +} +__attribute__((noinline)) floatty compute(floatty *A, floatty *B, floatty *C, int n) { for (int i = 0; i < n; i++) { C[i] = A[i] / 2 + intrinsics(A[i], simple_add(B[i] * 10000, 0.000001)); @@ -52,6 +70,9 @@ int main() { B[i] = 1 + i % 3; } + compute2(A, B, C, N); + for (int i = 0; i < N; i++) + C[i] = 0; compute(A, B, C, N); printf("%f\n", C[5]); } diff --git a/enzyme/test/Integration/Truncate/warnings.cpp b/enzyme/test/Integration/Truncate/warnings.cpp new file mode 100644 index 000000000000..b63636440d3d --- /dev/null +++ b/enzyme/test/Integration/Truncate/warnings.cpp @@ -0,0 +1,62 @@ +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_MEM -O2 %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_MEM -O2 -g %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include 
enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi +// COM: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_OP -O2 %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi +// COM: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_OP -O2 -g %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi + +#include +#include + +#define FROM 64 +#define TO 32 + +double bithack(double a) { + return *((int64_t *)&a) + 1; // expected-remark {{Will not follow FP through this cast.}}, expected-remark {{Will not follow FP through this cast.}} +} +__attribute__((noinline)) +void print_d(double a) { + printf("%f\n", a); // expected-remark {{Will not follow FP through this function call as the definition is not available.}} +} +__attribute__((noinline)) +float truncf(double a) { + return (float)a; // expected-remark {{Will not follow FP through this cast.}} +} + +double intrinsics(double a, double b) { + return bithack(a) * truncf(b); // expected-remark {{Will not follow FP through this cast.}} +} + +typedef double (*fty)(double *, double *, double *, int); + +typedef double (*fty2)(double, double); + +template <typename fty> fty *__enzyme_truncate_mem_func(fty *, int, int); +extern fty __enzyme_truncate_op_func_2(...); +extern fty2 __enzyme_truncate_op_func(...); +extern double __enzyme_truncate_mem_value(...); +extern double __enzyme_expand_mem_value(...); + + +int main() { + #ifdef TRUNC_MEM + { + double a = 2; + double b = 3; + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(intrinsics, FROM, TO)(a, b), FROM, TO); + } + { + double a = 2; + a = __enzyme_truncate_mem_value(a, FROM, TO); + __enzyme_truncate_mem_func(print_d, FROM, TO)(a); + } + #endif + #ifdef TRUNC_OP + { + double a = 2; + double b = 3; + double trunc = 
__enzyme_truncate_op_func(intrinsics, FROM, TO)(a, b); + } + #endif + +} diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 2e5ee993f988..26619edcdcd4 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1787,10 +1787,17 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, bool prev = false; for (auto *nameI : *cast<ListInit>(pattern->getValueAsListInit("names"))) { - if (prev) - os << " ||\n "; - os << "funcName == " << cast<StringInit>(nameI)->getAsString() << ""; - prev = true; + auto nameIStr = cast<StringInit>(nameI)->getAsString(); + auto nameIStrFinite = "\"__" + + std::string(std::next(nameIStr.begin()), + std::prev(nameIStr.end())) + + "_finite\""; + for (auto nameIStrAll : {nameIStr, nameIStrFinite}) { + if (prev) + os << " ||\n "; + os << "funcName == " << nameIStrAll << ""; + prev = true; + } } origName = "call"; #if LLVM_VERSION_MAJOR >= 14