Skip to content

Commit

Permalink
Support forward translation of fpmath metadata (#2266)
Browse files Browse the repository at this point in the history
The Clang frontend uses the `fpmath` metadata to represent FP accuracy in some cases, right now for OpenCL and possibly SYCL in the near future. The long term plan is to expect the frontend to add the `fpbuiltin-max-error` attribute, but since `fpmath` is generated today and currently dropped, let's support it using the existing `SPV_INTEL_fp_max_error` extension.

Upon reverse translation, we generate the same `fpbuiltin-max-error` attribute/metadata that we would generate if the frontend had added `fpbuiltin-max-error` originally, so anybody using the reverse translation result only needs to handle one case.

Signed-off-by: Sarnie, Nick <[email protected]>
  • Loading branch information
sarnex authored Dec 19, 2023
1 parent b891a51 commit e63fdb8
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 34 deletions.
1 change: 0 additions & 1 deletion lib/SPIRV/SPIRVRegularizeLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,6 @@ bool SPIRVRegularizeLLVMBase::regularize() {

// Remove metadata not supported by SPIRV
static const char *MDs[] = {
"fpmath",
"tbaa",
"range",
};
Expand Down
69 changes: 36 additions & 33 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,31 @@ static bool recursiveType(const StructType *ST, const Type *Ty) {
return Run(Ty);
}

// Add decoration if needed
void addFPBuiltinDecoration(SPIRVModule *BM, Instruction *Inst,
SPIRVInstruction *I) {
if (!BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fp_max_error))
return;
auto *II = dyn_cast_or_null<IntrinsicInst>(Inst);
if (II && II->getCalledFunction()->getName().starts_with("llvm.fpbuiltin")) {
// Add a new decoration for llvm.builtin intrinsics, if needed
if (II->getAttributes().hasFnAttr("fpbuiltin-max-error")) {
double F = 0.0;
II->getAttributes()
.getFnAttr("fpbuiltin-max-error")
.getValueAsString()
.getAsDouble(F);
I->addDecorate(DecorationFPMaxErrorDecorationINTEL,
convertFloatToSPIRVWord(F));
}
} else if (auto *MD = Inst->getMetadata("fpmath")) {
auto *MDVal = mdconst::dyn_extract<ConstantFP>(MD->getOperand(0));
double ValAsDouble = MDVal->getValue().convertToFloat();
I->addDecorate(DecorationFPMaxErrorDecorationINTEL,
convertFloatToSPIRVWord(ValAsDouble));
}
}

SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
LLVMToSPIRVTypeMap::iterator Loc = TypeMap.find(T);
if (Loc != TypeMap.end())
Expand Down Expand Up @@ -2830,6 +2855,8 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {
transMemAliasingINTELDecorations(Inst, BV);
if (auto *IDecoMD = Inst->getMetadata(SPIRV_MD_DECORATIONS))
transMetadataDecorations(IDecoMD, BV);
if (BV->isInst())
addFPBuiltinDecoration(BM, Inst, static_cast<SPIRVInstruction *>(BV));
}

if (auto *CI = dyn_cast<CallInst>(V)) {
Expand Down Expand Up @@ -3626,26 +3653,6 @@ bool LLVMToSPIRVBase::isKnownIntrinsic(Intrinsic::ID Id) {
}
}

// Add decoration if needed
SPIRVInstruction *addFPBuiltinDecoration(SPIRVModule *BM, IntrinsicInst *II,
SPIRVInstruction *I) {
const bool AllowFPMaxError =
BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fp_max_error);
assert(II->getCalledFunction()->getName().starts_with("llvm.fpbuiltin"));
// Add a new decoration for llvm.builtin intrinsics, if needed
if (AllowFPMaxError)
if (II->getAttributes().hasFnAttr("fpbuiltin-max-error")) {
double F = 0.0;
II->getAttributes()
.getFnAttr("fpbuiltin-max-error")
.getValueAsString()
.getAsDouble(F);
I->addDecorate(DecorationFPMaxErrorDecorationINTEL,
convertFloatToSPIRVWord(F));
}
return I;
}

// Performs mapping of LLVM IR rounding mode to SPIR-V rounding mode
// Value *V is metadata <rounding mode> argument of
// llvm.experimental.constrained.* intrinsics
Expand Down Expand Up @@ -4758,10 +4765,9 @@ SPIRVValue *LLVMToSPIRVBase::transFPBuiltinIntrinsicInst(IntrinsicInst *II,
.Case("fdiv", OpFDiv)
.Case("frem", OpFRem)
.Default(OpUndef);
auto *BI = BM->addBinaryInst(BinOp, transType(II->getType()),
transValue(II->getArgOperand(0), BB),
transValue(II->getArgOperand(1), BB), BB);
return addFPBuiltinDecoration(BM, II, BI);
return BM->addBinaryInst(BinOp, transType(II->getType()),
transValue(II->getArgOperand(0), BB),
transValue(II->getArgOperand(1), BB), BB);
}
case FPBuiltinType::EXT_1OPS: {
if (!checkTypeForSPIRVExtendedInstLowering(II, BM))
Expand Down Expand Up @@ -4795,9 +4801,8 @@ SPIRVValue *LLVMToSPIRVBase::transFPBuiltinIntrinsicInst(IntrinsicInst *II,
.Case("erfc", OpenCLLIB::Erfc)
.Default(SPIRVWORD_MAX);
assert(ExtOp != SPIRVWORD_MAX);
auto *BI = BM->addExtInst(STy, BM->getExtInstSetId(SPIRVEIS_OpenCL), ExtOp,
Ops, BB);
return addFPBuiltinDecoration(BM, II, BI);
return BM->addExtInst(STy, BM->getExtInstSetId(SPIRVEIS_OpenCL), ExtOp, Ops,
BB);
}
case FPBuiltinType::EXT_2OPS: {
if (!checkTypeForSPIRVExtendedInstLowering(II, BM))
Expand All @@ -4812,9 +4817,8 @@ SPIRVValue *LLVMToSPIRVBase::transFPBuiltinIntrinsicInst(IntrinsicInst *II,
.Case("ldexp", OpenCLLIB::Ldexp)
.Default(SPIRVWORD_MAX);
assert(ExtOp != SPIRVWORD_MAX);
auto *BI = BM->addExtInst(STy, BM->getExtInstSetId(SPIRVEIS_OpenCL), ExtOp,
Ops, BB);
return addFPBuiltinDecoration(BM, II, BI);
return BM->addExtInst(STy, BM->getExtInstSetId(SPIRVEIS_OpenCL), ExtOp, Ops,
BB);
}
case FPBuiltinType::EXT_3OPS: {
if (!checkTypeForSPIRVExtendedInstLowering(II, BM))
Expand All @@ -4827,9 +4831,8 @@ SPIRVValue *LLVMToSPIRVBase::transFPBuiltinIntrinsicInst(IntrinsicInst *II,
.Case("sincos", OpenCLLIB::Sincos)
.Default(SPIRVWORD_MAX);
assert(ExtOp != SPIRVWORD_MAX);
auto *BI = BM->addExtInst(STy, BM->getExtInstSetId(SPIRVEIS_OpenCL), ExtOp,
Ops, BB);
return addFPBuiltinDecoration(BM, II, BI);
return BM->addExtInst(STy, BM->getExtInstSetId(SPIRVEIS_OpenCL), ExtOp, Ops,
BB);
}
default:
return nullptr;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
; Confirm that we handle fpmath metadata correctly

; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_fp_max_error -o %t.spv
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV

; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM

; CHECK-SPIRV: Capability FPMaxErrorINTEL
; CHECK-SPIRV: Extension "SPV_INTEL_fp_max_error"
; CHECK-SPIRV: ExtInstImport [[#OCLEXTID:]] "OpenCL.std"

; CHECK-SPIRV: Name [[#CalleeName:]] "callee"
; CHECK-SPIRV: Name [[#F3:]] "f3"
; CHECK-SPIRV: Decorate [[#F3]] FPMaxErrorDecorationINTEL 1075838976
; CHECK-SPIRV: Decorate [[#Callee:]] FPMaxErrorDecorationINTEL 1065353216

; CHECK-SPIRV: TypeFloat [[#FloatTy:]] 32

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

define float @callee(float %f1, float %f2) {
entry:
ret float %f1
}

define void @test_fp_max_error_decoration(float %f1, float %f2) {
entry:
; CHECK-LLVM: fdiv float %f1, %f2, !fpbuiltin-max-error ![[#ME0:]]
%f3 = fdiv float %f1, %f2, !fpmath !0

; CHECK-LLVM: call {{.*}} float @callee(float %f1, float %f2) #[[#ATTR0:]]
; CHECK-SPIRV: FunctionCall [[#FloatTy]] [[#Callee]] [[#CalleeName]]
call float @callee(float %f1, float %f2), !fpmath !1
ret void
}

; CHECK-LLVM: attributes #[[#ATTR0]] = {{{.*}}"fpbuiltin-max-error"="1.000000"{{.*}}}

; CHECK-LLVM: ![[#ME0]] = !{!"2.500000"}
!0 = !{float 2.500000e+00}
!1 = !{float 1.000000e+00}

0 comments on commit e63fdb8

Please sign in to comment.