Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate SVE for 80bit load/stores when possible #4166

Merged
merged 5 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions FEXCore/Scripts/json_ir_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def is_ssa_type(type):
if (type == "SSA" or
type == "GPR" or
type == "GPRPair" or
type == "FPR"):
type == "FPR" or
type == "PRED"):
return True
return False

Expand Down Expand Up @@ -150,8 +151,8 @@ def parse_ops(ops):
RHS += f", {DType}:$Out{Name}"
else:
# Single anonymous destination
if LHS not in ["SSA", "GPR", "GPRPair", "FPR"]:
ExitError(f"Unknown destination class type {LHS}. Needs to be one of SSA, GPR, GPRPair, FPR")
if LHS not in ["SSA", "GPR", "GPRPair", "FPR", "PRED"]:
ExitError(f"Unknown destination class type {LHS}. Needs to be one of SSA, GPR, GPRPair, FPR, PRED")

OpDef.HasDest = True
OpDef.DestType = LHS
Expand Down Expand Up @@ -221,7 +222,8 @@ def parse_ops(ops):
if (OpArg.IsSSA and
(OpArg.Type == "GPR" or
OpArg.Type == "GPRPair" or
OpArg.Type == "FPR")):
OpArg.Type == "FPR" or
OpArg.Type == "PRED")):
OpDef.EmitValidation.append(f"GetOpRegClass({ArgName}) == InvalidClass || WalkFindRegClass({ArgName}) == {OpArg.Type}Class")

OpArg.Name = ArgName
Expand Down
21 changes: 21 additions & 0 deletions FEXCore/Source/Interface/Core/ArchHelpers/Arm64Emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ namespace x64 {
ARMEmitter::Reg::r24, ARMEmitter::Reg::r25, ARMEmitter::Reg::r30, ARMEmitter::Reg::r18,
};

// p6 and p7 registers are used as temporaries no not added here for RA
// See PREF_TMP_16B and PREF_TMP_32B
// p0-p1 are also used in the jit as temps.
// Also p8-p15 cannot be used can only encode p0-p7, so we're left with p2-p5.
constexpr std::array<ARMEmitter::PRegister, 4> PR = {ARMEmitter::PReg::p2, ARMEmitter::PReg::p3, ARMEmitter::PReg::p4, ARMEmitter::PReg::p5};

constexpr unsigned RAPairs = 6;

// All are caller saved
Expand Down Expand Up @@ -103,6 +109,12 @@ namespace x64 {
ARMEmitter::Reg::r16, ARMEmitter::Reg::r17, ARMEmitter::Reg::r30,
};

// p6 and p7 registers are used as temporaries no not added here for RA
// See PREF_TMP_16B and PREF_TMP_32B
// p0-p1 are also used in the jit as temps.
// Also p8-p15 cannot be used can only encode p0-p7, so we're left with p2-p5.
constexpr std::array<ARMEmitter::PRegister, 4> PR = {ARMEmitter::PReg::p2, ARMEmitter::PReg::p3, ARMEmitter::PReg::p4, ARMEmitter::PReg::p5};

constexpr unsigned RAPairs = 6;

constexpr std::array<ARMEmitter::VRegister, 16> SRAFPR = {
Expand Down Expand Up @@ -234,6 +246,12 @@ namespace x32 {

constexpr unsigned RAPairs = 12;

// p6 and p7 registers are used as temporaries no not added here for RA
// See PREF_TMP_16B and PREF_TMP_32B
// p0-p1 are also used in the jit as temps.
// Also p8-p15 cannot be used can only encode p0-p7, so we're left with p2-p5.
constexpr std::array<ARMEmitter::PRegister, 4> PR = {ARMEmitter::PReg::p2, ARMEmitter::PReg::p3, ARMEmitter::PReg::p4, ARMEmitter::PReg::p5};

// All are caller saved
constexpr std::array<ARMEmitter::VRegister, 8> SRAFPR = {
ARMEmitter::VReg::v16, ARMEmitter::VReg::v17, ARMEmitter::VReg::v18, ARMEmitter::VReg::v19,
Expand Down Expand Up @@ -357,6 +375,7 @@ Arm64Emitter::Arm64Emitter(FEXCore::Context::ContextImpl* ctx, void* EmissionPtr
GeneralRegisters = x64::RA;
StaticFPRegisters = x64::SRAFPR;
GeneralFPRegisters = x64::RAFPR;
PredicateRegisters = x64::PR;
PairRegisters = x64::RAPairs;
#ifdef _M_ARM_64EC
ConfiguredDynamicRegisterBase = std::span(x64::RA.begin(), 7);
Expand All @@ -370,6 +389,8 @@ Arm64Emitter::Arm64Emitter(FEXCore::Context::ContextImpl* ctx, void* EmissionPtr

StaticFPRegisters = x32::SRAFPR;
GeneralFPRegisters = x32::RAFPR;

PredicateRegisters = x32::PR;
}
}

Expand Down
1 change: 1 addition & 0 deletions FEXCore/Source/Interface/Core/ArchHelpers/Arm64Emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class Arm64Emitter : public ARMEmitter::Emitter {
std::span<const ARMEmitter::Register> ConfiguredDynamicRegisterBase {};
std::span<const ARMEmitter::Register> StaticRegisters {};
std::span<const ARMEmitter::Register> GeneralRegisters {};
std::span<const ARMEmitter::PRegister> PredicateRegisters {};
std::span<const ARMEmitter::VRegister> StaticFPRegisters {};
std::span<const ARMEmitter::VRegister> GeneralFPRegisters {};
uint32_t PairRegisters = 0;
Expand Down
1 change: 1 addition & 0 deletions FEXCore/Source/Interface/Core/JIT/JIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ Arm64JITCore::Arm64JITCore(FEXCore::Context::ContextImpl* ctx, FEXCore::Core::In
RAPass->AddRegisters(FEXCore::IR::GPRFixedClass, StaticRegisters.size());
RAPass->AddRegisters(FEXCore::IR::FPRClass, GeneralFPRegisters.size());
RAPass->AddRegisters(FEXCore::IR::FPRFixedClass, StaticFPRegisters.size());
RAPass->AddRegisters(FEXCore::IR::PREDClass, PredicateRegisters.size());
RAPass->PairRegs = PairRegisters;

{
Expand Down
13 changes: 13 additions & 0 deletions FEXCore/Source/Interface/Core/JIT/JITClass.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,19 @@ class Arm64JITCore final : public CPUBackend, public Arm64Emitter {
FEX_UNREACHABLE;
}

[[nodiscard]]
ARMEmitter::PRegister GetPReg(IR::NodeID Node) const {
const auto Reg = GetPhys(Node);

LOGMAN_THROW_AA_FMT(Reg.Class == IR::PREDClass.Val, "Unexpected Class: {}", Reg.Class);

if (Reg.Class == IR::PREDClass.Val) {
return PredicateRegisters[Reg.Reg];
}

FEX_UNREACHABLE;
}

[[nodiscard]]
FEXCore::IR::RegisterClassType GetRegClass(IR::NodeID Node) const;

Expand Down
70 changes: 70 additions & 0 deletions FEXCore/Source/Interface/Core/JIT/MemoryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ tags: backend|arm64
#include "Interface/Context/Context.h"
#include "Interface/Core/CPUID.h"
#include "Interface/Core/JIT/JITClass.h"
#include "Interface/IR/IR.h"
#include <FEXCore/Utils/CompilerDefs.h>
#include <FEXCore/Utils/MathUtils.h>

Expand Down Expand Up @@ -1551,6 +1552,75 @@ DEF_OP(StoreMem) {
}
}

DEF_OP(InitPredicate) {
const auto Op = IROp->C<IR::IROp_InitPredicate>();
const auto OpSize = IROp->Size;
ptrue(ConvertSubRegSize16(OpSize), GetPReg(Node), static_cast<ARMEmitter::PredicatePattern>(Op->Pattern));
}

DEF_OP(StoreMemPredicate) {
const auto Op = IROp->C<IR::IROp_StoreMemPredicate>();
const auto Predicate = GetPReg(Op->Mask.ID());

const auto RegData = GetVReg(Op->Value.ID());
const auto MemReg = GetReg(Op->Addr.ID());

LOGMAN_THROW_A_FMT(HostSupportsSVE128 || HostSupportsSVE256, "StoreMemPredicate needs SVE support");

const auto MemDst = ARMEmitter::SVEMemOperand(MemReg.X(), 0);

switch (IROp->ElementSize) {
case IR::OpSize::i8Bit: {
st1b<ARMEmitter::SubRegSize::i8Bit>(RegData.Z(), Predicate, MemDst);
break;
}
case IR::OpSize::i16Bit: {
st1h<ARMEmitter::SubRegSize::i16Bit>(RegData.Z(), Predicate, MemDst);
break;
}
case IR::OpSize::i32Bit: {
st1w<ARMEmitter::SubRegSize::i32Bit>(RegData.Z(), Predicate, MemDst);
break;
}
case IR::OpSize::i64Bit: {
st1d(RegData.Z(), Predicate, MemDst);
break;
}
default: LOGMAN_MSG_A_FMT("Unhandled {} element size: {}", __func__, IROp->ElementSize); break;
}
}

DEF_OP(LoadMemPredicate) {
const auto Op = IROp->C<IR::IROp_LoadMemPredicate>();
const auto Dst = GetVReg(Node);
const auto Predicate = GetPReg(Op->Mask.ID());
const auto MemReg = GetReg(Op->Addr.ID());

LOGMAN_THROW_A_FMT(HostSupportsSVE128 || HostSupportsSVE256, "LoadMemPredicate needs SVE support");

const auto MemDst = ARMEmitter::SVEMemOperand(MemReg.X(), 0);

switch (IROp->ElementSize) {
case IR::OpSize::i8Bit: {
ld1b<ARMEmitter::SubRegSize::i8Bit>(Dst.Z(), Predicate.Zeroing(), MemDst);
break;
}
case IR::OpSize::i16Bit: {
ld1h<ARMEmitter::SubRegSize::i16Bit>(Dst.Z(), Predicate.Zeroing(), MemDst);
break;
}
case IR::OpSize::i32Bit: {
ld1w<ARMEmitter::SubRegSize::i32Bit>(Dst.Z(), Predicate.Zeroing(), MemDst);
break;
}
case IR::OpSize::i64Bit: {
ld1d(Dst.Z(), Predicate.Zeroing(), MemDst);
break;
}
default: LOGMAN_MSG_A_FMT("Unhandled {} element size: {}", __func__, IROp->ElementSize); break;
}
}

DEF_OP(StoreMemPair) {
const auto Op = IROp->C<IR::IROp_StoreMemPair>();
const auto OpSize = IROp->Size;
Expand Down
28 changes: 19 additions & 9 deletions FEXCore/Source/Interface/Core/OpcodeDispatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ desc: Handles x86/64 ops to IR, no-pf opt, local-flags opt
$end_info$
*/

#include "FEXCore/Core/HostFeatures.h"
#include "FEXCore/Utils/Telemetry.h"
#include "Interface/Context/Context.h"
#include "Interface/Core/OpcodeDispatcher.h"
Expand Down Expand Up @@ -4309,10 +4310,15 @@ Ref OpDispatchBuilder::LoadSource_WithOpSize(RegisterClassType Class, const X86T
if ((IsOperandMem(Operand, true) && LoadData) || ForceLoad) {
if (OpSize == OpSize::f80Bit) {
Ref MemSrc = LoadEffectiveAddress(A, true);

// For X87 extended doubles, Split the load.
auto Res = _LoadMem(Class, OpSize::i64Bit, MemSrc, Align == OpSize::iInvalid ? OpSize : Align);
return _VLoadVectorElement(OpSize::i128Bit, OpSize::i16Bit, Res, 4, _Add(OpSize::i64Bit, MemSrc, _InlineConstant(8)));
if (CTX->HostFeatures.SupportsSVE128 || CTX->HostFeatures.SupportsSVE256) {
// Using SVE we can load this with a single instruction.
auto PReg = InitPredicateCached(OpSize::i16Bit, ARMEmitter::PredicatePattern::SVE_VL5);
return _LoadMemPredicate(OpSize::i128Bit, OpSize::i16Bit, PReg, MemSrc);
} else {
// For X87 extended doubles, Split the load.
auto Res = _LoadMem(Class, OpSize::i64Bit, MemSrc, Align == OpSize::iInvalid ? OpSize : Align);
return _VLoadVectorElement(OpSize::i128Bit, OpSize::i16Bit, Res, 4, _Add(OpSize::i64Bit, MemSrc, _InlineConstant(8)));
}
}

return _LoadMemAutoTSO(Class, OpSize, A, Align == OpSize::iInvalid ? OpSize : Align);
Expand Down Expand Up @@ -4439,11 +4445,15 @@ void OpDispatchBuilder::StoreResult_WithOpSize(FEXCore::IR::RegisterClassType Cl

if (OpSize == OpSize::f80Bit) {
Ref MemStoreDst = LoadEffectiveAddress(A, true);

// For X87 extended doubles, split before storing
_StoreMem(FPRClass, OpSize::i64Bit, MemStoreDst, Src, Align);
auto Upper = _VExtractToGPR(OpSize::i128Bit, OpSize::i64Bit, Src, 1);
_StoreMem(GPRClass, OpSize::i16Bit, Upper, MemStoreDst, _Constant(8), std::min(Align, OpSize::i64Bit), MEM_OFFSET_SXTX, 1);
if (CTX->HostFeatures.SupportsSVE128 || CTX->HostFeatures.SupportsSVE256) {
auto PReg = InitPredicateCached(OpSize::i16Bit, ARMEmitter::PredicatePattern::SVE_VL5);
_StoreMemPredicate(OpSize::i128Bit, OpSize::i16Bit, Src, PReg, MemStoreDst);
} else {
// For X87 extended doubles, split before storing
_StoreMem(FPRClass, OpSize::i64Bit, MemStoreDst, Src, Align);
auto Upper = _VExtractToGPR(OpSize::i128Bit, OpSize::i64Bit, Src, 1);
_StoreMem(GPRClass, OpSize::i16Bit, Upper, MemStoreDst, _Constant(8), std::min(Align, OpSize::i64Bit), MEM_OFFSET_SXTX, 1);
}
} else {
_StoreMemAutoTSO(Class, OpSize, A, Src, Align == OpSize::iInvalid ? OpSize : Align);
}
Expand Down
3 changes: 3 additions & 0 deletions FEXCore/Source/Interface/Core/OpcodeDispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ class OpDispatchBuilder final : public IREmitter {

// Need to clear any named constants that were cached.
ClearCachedNamedConstants();

// Clear predicate cache for x87 ldst
ResetInitPredicateCache();
}

IRPair<IROp_Jump> Jump() {
Expand Down
23 changes: 21 additions & 2 deletions FEXCore/Source/Interface/IR/IR.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
" SSA = untyped",
" GPR = GPR class type",
" FPR = FPR class type",
" PRED = Predicate register class type",
"Declaring the SSA types correctly will allow validation passes to ensure the op is getting passed correct arguments",
"",
"Arguments must always follow a particular order. <Type>:<Prefix><Name>",
Expand Down Expand Up @@ -83,6 +84,7 @@
"constexpr FEXCore::IR::RegisterClassType GPRFixedClass {1}",
"constexpr FEXCore::IR::RegisterClassType FPRClass {2}",
"constexpr FEXCore::IR::RegisterClassType FPRFixedClass {3}",
"constexpr FEXCore::IR::RegisterClassType PREDClass {4}",
"constexpr FEXCore::IR::RegisterClassType ComplexClass {5}",
"constexpr FEXCore::IR::RegisterClassType InvalidClass {7}",
"",
Expand Down Expand Up @@ -148,6 +150,7 @@
"SSA": "OrderedNode*",
"GPR": "OrderedNode*",
"FPR": "OrderedNode*",
"PRED": "OrderedNode*",
"FenceType": "FenceType",
"RegisterClass": "RegisterClassType",
"CondClass": "CondClassType",
Expand Down Expand Up @@ -560,11 +563,27 @@
"HasSideEffects": true,
"DestSize": "Size",
"EmitValidation": [
"WalkFindRegClass($Value1) == $Class",
"WalkFindRegClass($Value2) == $Class"
"WalkFindRegClass($Value1) == $Class"
]
},

"PRED = InitPredicate OpSize:#Size, u8:$Pattern": {
"Desc": ["Initialize predicate register from Pattern"],
"DestSize": "Size"
},

"StoreMemPredicate OpSize:#RegisterSize, OpSize:#ElementSize, FPR:$Value, PRED:$Mask, GPR:$Addr": {
"Desc": [ "Stores a value to memory using SVE predicate mask." ],
"DestSize": "RegisterSize",
"HasSideEffects": true,
"NumElements": "IR::NumElements(RegisterSize, ElementSize)"
},
"FPR = LoadMemPredicate OpSize:#RegisterSize, OpSize:#ElementSize, PRED:$Mask, GPR:$Addr": {
"Desc": [ "Loads a value to memory using SVE predicate mask." ],
"DestSize": "RegisterSize",
"NumElements": "IR::NumElements(RegisterSize, ElementSize)"
},

"SSA = LoadMemTSO RegisterClass:$Class, OpSize:#Size, GPR:$Addr, GPR:$Offset, OpSize:$Align, MemOffsetType:$OffsetType, u8:$OffsetScale": {
"Desc": ["Does a x86 TSO compatible load from memory. Offset must be Invalid()."
],
Expand Down
3 changes: 3 additions & 0 deletions FEXCore/Source/Interface/IR/IRDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ static void PrintArg(fextl::stringstream* out, [[maybe_unused]] const IRListView
*out << "FPR";
} else if (Arg == FPRFixedClass.Val) {
*out << "FPRFixed";
} else if (Arg == PREDClass.Val) {
*out << "PRED";
} else {
*out << "Unknown Registerclass " << Arg;
}
Expand All @@ -98,6 +100,7 @@ static void PrintArg(fextl::stringstream* out, const IRListView* IR, OrderedNode
case FEXCore::IR::GPRFixedClass.Val: *out << "(GPRFixed"; break;
case FEXCore::IR::FPRClass.Val: *out << "(FPR"; break;
case FEXCore::IR::FPRFixedClass.Val: *out << "(FPRFixed"; break;
case FEXCore::IR::PREDClass.Val: *out << "(PRED"; break;
case FEXCore::IR::ComplexClass.Val: *out << "(Complex"; break;
case FEXCore::IR::InvalidClass.Val: *out << "(Invalid"; break;
default: *out << "(Unknown"; break;
Expand Down
1 change: 1 addition & 0 deletions FEXCore/Source/Interface/IR/IREmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ FEXCore::IR::RegisterClassType IREmitter::WalkFindRegClass(Ref Node) {
case FPRClass:
case GPRFixedClass:
case FPRFixedClass:
case PREDClass:
case InvalidClass: return Class;
default: break;
}
Expand Down
Loading
Loading