Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1.9] Add Float16 to supported x86 processors #52349

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/features_x86.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,13 @@ JL_FEATURE_DEF(enqcmd, 32 * 3 + 29, 0)
// EAX=7,ECX=0: EDX
// JL_FEATURE_DEF(avx5124vnniw, 32 * 4 + 2, ?????)
// JL_FEATURE_DEF(avx5124fmaps, 32 * 4 + 3, ?????)
JL_FEATURE_DEF(uintr, 32 * 4 + 5, 140000)
JL_FEATURE_DEF(avx512vp2intersect, 32 * 4 + 8, 0)
JL_FEATURE_DEF(serialize, 32 * 4 + 14, 110000)
JL_FEATURE_DEF(tsxldtrk, 32 * 4 + 16, 110000)
JL_FEATURE_DEF(pconfig, 32 * 4 + 18, 0)
JL_FEATURE_DEF_NAME(amx_bf16, 32 * 4 + 22, 110000, "amx-bf16")
JL_FEATURE_DEF(avx512fp16, 32 * 4 + 23, 140000)
JL_FEATURE_DEF_NAME(amx_tile, 32 * 4 + 24, 110000, "amx-tile")
JL_FEATURE_DEF_NAME(amx_int8, 32 * 4 + 25, 110000, "amx-int8")

Expand Down
17 changes: 2 additions & 15 deletions src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,35 +47,22 @@ INST_STATISTIC(FCmp);

extern JuliaOJIT *jl_ExecutionEngine;

Optional<bool> always_have_fp16() {
#if defined(_CPU_X86_) || defined(_CPU_X86_64_)
// x86 doesn't support fp16
// TODO: update for sapphire rapids when it comes out
return false;
#else
return {};
#endif
}

namespace {

bool have_fp16(Function &caller) {
auto unconditional = always_have_fp16();
if (unconditional.hasValue())
return unconditional.getValue();

Attribute FSAttr = caller.getFnAttribute("target-features");
StringRef FS =
FSAttr.isValid() ? FSAttr.getValueAsString() : jl_ExecutionEngine->getTargetFeatureString();
#if defined(_CPU_AARCH64_)
if (FS.find("+fp16fml") != llvm::StringRef::npos || FS.find("+fullfp16") != llvm::StringRef::npos){
return true;
}
#else
#elif defined(_CPU_X86_64_)
if (FS.find("+avx512fp16") != llvm::StringRef::npos){
return true;
}
#endif
(void)FS;
return false;
}

Expand Down
13 changes: 5 additions & 8 deletions src/llvm-multiversioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ using namespace llvm;

extern Optional<bool> always_have_fma(Function&);

extern Optional<bool> always_have_fp16();

void replaceUsesWithLoad(Function &F, function_ref<GlobalVariable *(Instruction &I)> should_replace, MDNode *tbaa_const);

namespace {
Expand Down Expand Up @@ -490,13 +488,12 @@ uint32_t CloneCtx::collect_func_info(Function &F)
flag |= JL_TARGET_CLONE_MATH;
}
}
if(!always_have_fp16().hasValue()){
for (size_t i = 0; i < I.getNumOperands(); i++) {
if(I.getOperand(i)->getType()->isHalfTy()){
flag |= JL_TARGET_CLONE_FLOAT16;
}
// Check for BFloat16 when they are added to julia can be done here

for (size_t i = 0; i < I.getNumOperands(); i++) {
if(I.getOperand(i)->getType()->isHalfTy()){
flag |= JL_TARGET_CLONE_FLOAT16;
}
// Check for BFloat16 when they are added to julia can be done here
}
if (has_veccall && (flag & JL_TARGET_CLONE_SIMD) && (flag & JL_TARGET_CLONE_MATH) &&
(flag & JL_TARGET_CLONE_CPU) && (flag & JL_TARGET_CLONE_FLOAT16)) {
Expand Down
18 changes: 14 additions & 4 deletions src/processor_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ static constexpr FeatureDep deps[] = {
{avx512vnni, avx512f},
{avx512vp2intersect, avx512f},
{avx512vpopcntdq, avx512f},
{avx512fp16, avx512bw},
{avx512fp16, avx512dq},
{avx512fp16, avx512vl},
{amx_int8, amx_tile},
{amx_bf16, amx_tile},
{sse4a, sse3},
Expand Down Expand Up @@ -208,8 +211,8 @@ constexpr auto tigerlake = icelake | get_feature_masks(avx512vp2intersect, movdi
constexpr auto alderlake = skylake | get_feature_masks(clwb, sha, waitpkg, shstk, gfni, vaes, vpclmulqdq, pconfig,
rdpid, movdiri, pku, movdir64b, serialize, ptwrite, avxvnni);
constexpr auto sapphirerapids = icelake_server |
get_feature_masks(amx_tile, amx_int8, amx_bf16, avx512bf16, serialize, cldemote, waitpkg,
ptwrite, tsxldtrk, enqcmd, shstk, avx512vp2intersect, movdiri, movdir64b);
get_feature_masks(amx_tile, amx_int8, amx_bf16, avx512bf16, avx512fp16, serialize, cldemote, waitpkg,
avxvnni, uintr, ptwrite, tsxldtrk, enqcmd, shstk, avx512vp2intersect, movdiri, movdir64b);

constexpr auto k8_sse3 = get_feature_masks(sse3, cx16);
constexpr auto amdfam10 = k8_sse3 | get_feature_masks(sse4a, lzcnt, popcnt, sahf);
Expand Down Expand Up @@ -933,10 +936,10 @@ static void ensure_jit_target(bool imaging)
Feature::avx512pf, Feature::avx512er,
Feature::avx512cd, Feature::avx512bw,
Feature::avx512vl, Feature::avx512vbmi,
Feature::avx512vpopcntdq,
Feature::avx512vpopcntdq, Feature::avxvnni,
Feature::avx512vbmi2, Feature::avx512vnni,
Feature::avx512bitalg, Feature::avx512bf16,
Feature::avx512vp2intersect};
Feature::avx512vp2intersect, Feature::avx512fp16};
for (auto fe: clone_math) {
if (!test_nbit(features0, fe) && test_nbit(t.en.features, fe)) {
t.en.flags |= JL_TARGET_CLONE_MATH;
Expand All @@ -949,6 +952,13 @@ static void ensure_jit_target(bool imaging)
break;
}
}
static constexpr uint32_t clone_fp16[] = {Feature::avx512fp16};
for (auto fe: clone_fp16) {
if (!test_nbit(features0, fe) && test_nbit(t.en.features, fe)) {
t.en.flags |= JL_TARGET_CLONE_FLOAT16;
break;
}
}
}
}

Expand Down
46 changes: 43 additions & 3 deletions test/llvmpasses/float16.ll
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: -p
; RUN: opt -enable-new-pm=0 -load libjulia-codegen%shlibext -DemoteFloat16 -S %s | FileCheck %s
; RUN: opt -enable-new-pm=1 --load-pass-plugin=libjulia-codegen%shlibext -passes='DemoteFloat16' -S %s | FileCheck %s
; RUN: opt -enable-new-pm=0 -load libjulia-codegen%shlibext -DemoteFloat16 -S %s | FileCheck %s
; RUN: opt -enable-new-pm=1 --load-pass-plugin=libjulia-codegen%shlibext -passes='DemoteFloat16' -S %s | FileCheck %s

define half @demotehalf_test(half %a, half %b) {
define half @demotehalf_test(half %a, half %b) #0 {
top:
; CHECK-LABEL: @demotehalf_test(
; CHECK-NEXT: top:
; CHECK-NEXT: %0 = fpext half %a to float
Expand Down Expand Up @@ -44,6 +45,42 @@ define half @demotehalf_test(half %a, half %b) {
; CHECK-NEXT: %36 = fadd float %34, %35
; CHECK-NEXT: %37 = fptrunc float %36 to half
; CHECK-NEXT: ret half %37
;
%0 = fadd half %a, %b
%1 = fadd half %0, %b
%2 = fadd half %1, %b
%3 = fmul half %2, %b
%4 = fdiv half %3, %b
%5 = insertelement <2 x half> undef, half %a, i32 0
%6 = insertelement <2 x half> %5, half %b, i32 1
%7 = insertelement <2 x half> undef, half %b, i32 0
%8 = insertelement <2 x half> %7, half %b, i32 1
%9 = fadd <2 x half> %6, %8
%10 = extractelement <2 x half> %9, i32 0
%11 = extractelement <2 x half> %9, i32 1
%12 = fadd half %10, %11
%13 = fadd half %12, %4
ret half %13
}

define half @native_half_test(half %a, half %b) #1 {
; CHECK-LABEL: @native_half_test(
; CHECK-NEXT top:
; CHECK-NEXT %0 = fadd half %a, %b
; CHECK-NEXT %1 = fadd half %0, %b
; CHECK-NEXT %2 = fadd half %1, %b
; CHECK-NEXT %3 = fmul half %2, %b
; CHECK-NEXT %4 = fdiv half %3, %b
; CHECK-NEXT %5 = insertelement <2 x half> undef, half %a, i32 0
; CHECK-NEXT %6 = insertelement <2 x half> %5, half %b, i32 1
; CHECK-NEXT %7 = insertelement <2 x half> undef, half %b, i32 0
; CHECK-NEXT %8 = insertelement <2 x half> %7, half %b, i32 1
; CHECK-NEXT %9 = fadd <2 x half> %6, %8
; CHECK-NEXT %10 = extractelement <2 x half> %9, i32 0
; CHECK-NEXT %11 = extractelement <2 x half> %9, i32 1
; CHECK-NEXT %12 = fadd half %10, %11
; CHECK-NEXT %13 = fadd half %12, %4
; CHECK-NEXT ret half %13
;
top:
%0 = fadd half %a, %b
Expand All @@ -62,3 +99,6 @@ top:
%13 = fadd half %12, %4
ret half %13
}

attributes #0 = { "target-features"="-avx512fp16" }
attributes #1 = { "target-features"="+avx512fp16" }