Skip to content

Commit

Permalink
Add Float16 to supported x86 processors (#46499)
Browse files Browse the repository at this point in the history
* Add float16 multiversioning for x86

Co-authored-by: pchintalapudi <[email protected]>
Co-authored-by: Mosè Giordano <[email protected]>
  • Loading branch information
3 people authored and BioTurboNick committed Oct 14, 2023
1 parent bed2cd5 commit a639fab
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 30 deletions.
2 changes: 2 additions & 0 deletions src/features_x86.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,13 @@ JL_FEATURE_DEF(enqcmd, 32 * 3 + 29, 0)
// EAX=7,ECX=0: EDX
// JL_FEATURE_DEF(avx5124vnniw, 32 * 4 + 2, ?????)
// JL_FEATURE_DEF(avx5124fmaps, 32 * 4 + 3, ?????)
JL_FEATURE_DEF(uintr, 32 * 4 + 5, 140000)
JL_FEATURE_DEF(avx512vp2intersect, 32 * 4 + 8, 0)
JL_FEATURE_DEF(serialize, 32 * 4 + 14, 110000)
JL_FEATURE_DEF(tsxldtrk, 32 * 4 + 16, 110000)
JL_FEATURE_DEF(pconfig, 32 * 4 + 18, 0)
JL_FEATURE_DEF_NAME(amx_bf16, 32 * 4 + 22, 110000, "amx-bf16")
JL_FEATURE_DEF(avx512fp16, 32 * 4 + 23, 140000)
JL_FEATURE_DEF_NAME(amx_tile, 32 * 4 + 24, 110000, "amx-tile")
JL_FEATURE_DEF_NAME(amx_int8, 32 * 4 + 25, 110000, "amx-int8")

Expand Down
17 changes: 2 additions & 15 deletions src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,35 +47,22 @@ INST_STATISTIC(FCmp);

extern JuliaOJIT *jl_ExecutionEngine;

Optional<bool> always_have_fp16() {
#if defined(_CPU_X86_) || defined(_CPU_X86_64_)
// x86 doesn't support fp16
// TODO: update for sapphire rapids when it comes out
return false;
#else
return {};
#endif
}

namespace {

bool have_fp16(Function &caller) {
auto unconditional = always_have_fp16();
if (unconditional.hasValue())
return unconditional.getValue();

Attribute FSAttr = caller.getFnAttribute("target-features");
StringRef FS =
FSAttr.isValid() ? FSAttr.getValueAsString() : jl_ExecutionEngine->getTargetFeatureString();
#if defined(_CPU_AARCH64_)
if (FS.find("+fp16fml") != llvm::StringRef::npos || FS.find("+fullfp16") != llvm::StringRef::npos){
return true;
}
#else
#elif defined(_CPU_X86_64_)
if (FS.find("+avx512fp16") != llvm::StringRef::npos){
return true;
}
#endif
(void)FS;
return false;
}

Expand Down
13 changes: 5 additions & 8 deletions src/llvm-multiversioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ using namespace llvm;

extern Optional<bool> always_have_fma(Function&);

extern Optional<bool> always_have_fp16();

void replaceUsesWithLoad(Function &F, function_ref<GlobalVariable *(Instruction &I)> should_replace, MDNode *tbaa_const);

namespace {
Expand Down Expand Up @@ -490,13 +488,12 @@ uint32_t CloneCtx::collect_func_info(Function &F)
flag |= JL_TARGET_CLONE_MATH;
}
}
if(!always_have_fp16().hasValue()){
for (size_t i = 0; i < I.getNumOperands(); i++) {
if(I.getOperand(i)->getType()->isHalfTy()){
flag |= JL_TARGET_CLONE_FLOAT16;
}
// Check for BFloat16 when they are added to julia can be done here

for (size_t i = 0; i < I.getNumOperands(); i++) {
if(I.getOperand(i)->getType()->isHalfTy()){
flag |= JL_TARGET_CLONE_FLOAT16;
}
// Check for BFloat16 when they are added to julia can be done here
}
if (has_veccall && (flag & JL_TARGET_CLONE_SIMD) && (flag & JL_TARGET_CLONE_MATH) &&
(flag & JL_TARGET_CLONE_CPU) && (flag & JL_TARGET_CLONE_FLOAT16)) {
Expand Down
18 changes: 14 additions & 4 deletions src/processor_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ static constexpr FeatureDep deps[] = {
{avx512vnni, avx512f},
{avx512vp2intersect, avx512f},
{avx512vpopcntdq, avx512f},
{avx512fp16, avx512bw},
{avx512fp16, avx512dq},
{avx512fp16, avx512vl},
{amx_int8, amx_tile},
{amx_bf16, amx_tile},
{sse4a, sse3},
Expand Down Expand Up @@ -208,8 +211,8 @@ constexpr auto tigerlake = icelake | get_feature_masks(avx512vp2intersect, movdi
constexpr auto alderlake = skylake | get_feature_masks(clwb, sha, waitpkg, shstk, gfni, vaes, vpclmulqdq, pconfig,
rdpid, movdiri, pku, movdir64b, serialize, ptwrite, avxvnni);
constexpr auto sapphirerapids = icelake_server |
get_feature_masks(amx_tile, amx_int8, amx_bf16, avx512bf16, serialize, cldemote, waitpkg,
ptwrite, tsxldtrk, enqcmd, shstk, avx512vp2intersect, movdiri, movdir64b);
get_feature_masks(amx_tile, amx_int8, amx_bf16, avx512bf16, avx512fp16, serialize, cldemote, waitpkg,
avxvnni, uintr, ptwrite, tsxldtrk, enqcmd, shstk, avx512vp2intersect, movdiri, movdir64b);

constexpr auto k8_sse3 = get_feature_masks(sse3, cx16);
constexpr auto amdfam10 = k8_sse3 | get_feature_masks(sse4a, lzcnt, popcnt, sahf);
Expand Down Expand Up @@ -933,10 +936,10 @@ static void ensure_jit_target(bool imaging)
Feature::avx512pf, Feature::avx512er,
Feature::avx512cd, Feature::avx512bw,
Feature::avx512vl, Feature::avx512vbmi,
Feature::avx512vpopcntdq,
Feature::avx512vpopcntdq, Feature::avxvnni,
Feature::avx512vbmi2, Feature::avx512vnni,
Feature::avx512bitalg, Feature::avx512bf16,
Feature::avx512vp2intersect};
Feature::avx512vp2intersect, Feature::avx512fp16};
for (auto fe: clone_math) {
if (!test_nbit(features0, fe) && test_nbit(t.en.features, fe)) {
t.en.flags |= JL_TARGET_CLONE_MATH;
Expand All @@ -949,6 +952,13 @@ static void ensure_jit_target(bool imaging)
break;
}
}
static constexpr uint32_t clone_fp16[] = {Feature::avx512fp16};
for (auto fe: clone_fp16) {
if (!test_nbit(features0, fe) && test_nbit(t.en.features, fe)) {
t.en.flags |= JL_TARGET_CLONE_FLOAT16;
break;
}
}
}
}

Expand Down
46 changes: 43 additions & 3 deletions test/llvmpasses/float16.ll
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: -p
; RUN: opt -enable-new-pm=0 -load libjulia-codegen%shlibext -DemoteFloat16 -S %s | FileCheck %s
; RUN: opt -enable-new-pm=1 --load-pass-plugin=libjulia-codegen%shlibext -passes='DemoteFloat16' -S %s | FileCheck %s
; RUN: opt -enable-new-pm=0 -load libjulia-codegen%shlibext -DemoteFloat16 -S %s | FileCheck %s
; RUN: opt -enable-new-pm=1 --load-pass-plugin=libjulia-codegen%shlibext -passes='DemoteFloat16' -S %s | FileCheck %s

define half @demotehalf_test(half %a, half %b) {
define half @demotehalf_test(half %a, half %b) #0 {
top:
; CHECK-LABEL: @demotehalf_test(
; CHECK-NEXT: top:
; CHECK-NEXT: %0 = fpext half %a to float
Expand Down Expand Up @@ -44,6 +45,42 @@ define half @demotehalf_test(half %a, half %b) {
; CHECK-NEXT: %36 = fadd float %34, %35
; CHECK-NEXT: %37 = fptrunc float %36 to half
; CHECK-NEXT: ret half %37
;
%0 = fadd half %a, %b
%1 = fadd half %0, %b
%2 = fadd half %1, %b
%3 = fmul half %2, %b
%4 = fdiv half %3, %b
%5 = insertelement <2 x half> undef, half %a, i32 0
%6 = insertelement <2 x half> %5, half %b, i32 1
%7 = insertelement <2 x half> undef, half %b, i32 0
%8 = insertelement <2 x half> %7, half %b, i32 1
%9 = fadd <2 x half> %6, %8
%10 = extractelement <2 x half> %9, i32 0
%11 = extractelement <2 x half> %9, i32 1
%12 = fadd half %10, %11
%13 = fadd half %12, %4
ret half %13
}

define half @native_half_test(half %a, half %b) #1 {
; CHECK-LABEL: @native_half_test(
; CHECK-NEXT top:
; CHECK-NEXT %0 = fadd half %a, %b
; CHECK-NEXT %1 = fadd half %0, %b
; CHECK-NEXT %2 = fadd half %1, %b
; CHECK-NEXT %3 = fmul half %2, %b
; CHECK-NEXT %4 = fdiv half %3, %b
; CHECK-NEXT %5 = insertelement <2 x half> undef, half %a, i32 0
; CHECK-NEXT %6 = insertelement <2 x half> %5, half %b, i32 1
; CHECK-NEXT %7 = insertelement <2 x half> undef, half %b, i32 0
; CHECK-NEXT %8 = insertelement <2 x half> %7, half %b, i32 1
; CHECK-NEXT %9 = fadd <2 x half> %6, %8
; CHECK-NEXT %10 = extractelement <2 x half> %9, i32 0
; CHECK-NEXT %11 = extractelement <2 x half> %9, i32 1
; CHECK-NEXT %12 = fadd half %10, %11
; CHECK-NEXT %13 = fadd half %12, %4
; CHECK-NEXT ret half %13
;
top:
%0 = fadd half %a, %b
Expand All @@ -62,3 +99,6 @@ top:
%13 = fadd half %12, %4
ret half %13
}

attributes #0 = { "target-features"="-avx512fp16" }
attributes #1 = { "target-features"="+avx512fp16" }

0 comments on commit a639fab

Please sign in to comment.