Arm64/Sve: Implement SVE Math *Multiply* APIs (#102007)
* Add *Fused* APIs

* Fix an assert in morph

* Map APIs to instructions

* Add test cases

* Handle fused* instructions

* jit format

* Added MultiplyAdd/MultiplySubtract

* Add mapping of API to instruction

* Add test cases

* Handle mov Z, Z instruction

* Reuse GetResultOpNumForRmwIntrinsic() for arm64

* Reuse HW_Flag_FmaIntrinsic for arm64

* Mark FMA APIs as HW_Flag_FmaIntrinsic

* Handle FMA in LSRA and codegen

* Remove the SpecialCodeGen flag from selectedScalar

* Address some more scenarios

* jit format

* Add MultiplyBySelectedScalar

* Map the API to the instruction

* Fix a bug where *Indexed APIs used with ConditionalSelect were failing

`Sve.ConditionalSelect(op1, Sve.MultiplyBySelectedScalar(op1, op2, 0), op3);` was failing
because we were checking whether `MultiplyBySelectedScalar` was contained and hit an assert,
since it is not containable. Added the missing check (see the C# sketch after this list).

* Unpredicated movprfx should not pass opt

* Add the missing flags for Subtract/Multiply

* Added tests for MultiplyBySelectedScalar

Also updated *SelectedScalar* tests for ConditionalSelect

* Fixes to test cases

* Fix the parameter for the SelectedScalar test

* jit format

* Contain op3 of ConditionalSelect when op1 is AllTrueMask

* Handle FMA properly

* Added assert
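For reference, a minimal C# sketch of the API shapes this commit wires up. The method names (`FusedMultiplyAdd`, `MultiplyBySelectedScalar`, `ConditionalSelect`) come from the commit itself; the exact overloads, the operand order of `FusedMultiplyAdd`, and the `acc`/`mask` naming are assumptions, and real callers should guard with `Sve.IsSupported`.

```csharp
// Sketch only: assumes the .NET 9 System.Runtime.Intrinsics.Arm.Sve surface and that
// FusedMultiplyAdd takes (addend, left, right); guard real code with Sve.IsSupported.
using System.Numerics;
using System.Runtime.Intrinsics.Arm;

static class SveMultiplySketch
{
    // Unpredicated fused multiply-add: acc + (left * right), mapped to the SVE fmla/fmad family.
    public static Vector<float> Fma(Vector<float> acc, Vector<float> left, Vector<float> right)
        => Sve.FusedMultiplyAdd(acc, left, right);

    // Multiply every element of 'left' by one constant-indexed lane of 'right'.
    public static Vector<float> MulByLane(Vector<float> left, Vector<float> right)
        => Sve.MultiplyBySelectedScalar(left, right, 0);

    // The shape from the containment fix: ConditionalSelect wrapping a *BySelectedScalar call.
    public static Vector<float> Predicated(Vector<float> op1, Vector<float> op2, Vector<float> op3)
        => Sve.ConditionalSelect(op1, Sve.MultiplyBySelectedScalar(op1, op2, 0), op3);
}
```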
kunalspathak authored May 11, 2024
1 parent 31527d1 commit 34e65b9
Showing 18 changed files with 2,292 additions and 43 deletions.
16 changes: 14 additions & 2 deletions src/coreclr/jit/emitarm64.cpp
@@ -4250,9 +4250,11 @@ void emitter::emitIns_Mov(

case INS_sve_mov:
{
- if (isPredicateRegister(dstReg) && isPredicateRegister(srcReg))
+ // TODO-SVE: Remove check for insOptsNone() when predicate registers
+ // are present.
+ if (insOptsNone(opt) && isPredicateRegister(dstReg) && isPredicateRegister(srcReg))
{
- assert(insOptsNone(opt));
+ // assert(insOptsNone(opt));

opt = INS_OPTS_SCALABLE_B;
attr = EA_SCALABLE;
@@ -4263,6 +4265,16 @@
}
fmt = IF_SVE_CZ_4A_L;
}
+ else if (isVectorRegister(dstReg) && isVectorRegister(srcReg))
+ {
+ assert(insOptsScalable(opt));
+
+ if (IsRedundantMov(ins, size, dstReg, srcReg, canSkip))
+ {
+ return;
+ }
+ fmt = IF_SVE_AU_3A;
+ }
else
{
unreached();
34 changes: 31 additions & 3 deletions src/coreclr/jit/emitarm64sve.cpp
@@ -10374,7 +10374,6 @@ BYTE* emitter::emitOutput_InstrSve(BYTE* dst, instrDesc* id)
case IF_SVE_FN_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply long
case IF_SVE_FO_3A: // ...........mmmmm ......nnnnnddddd -- SVE integer matrix multiply accumulate
case IF_SVE_AT_3B: // ...........mmmmm ......nnnnnddddd -- SVE integer add/subtract vectors (unpredicated)
- case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
case IF_SVE_BD_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply vectors (unpredicated)
case IF_SVE_EF_3A: // ...........mmmmm ......nnnnnddddd -- SVE two-way dot product
case IF_SVE_EI_3A: // ...........mmmmm ......nnnnnddddd -- SVE mixed sign dot product
@@ -10396,6 +10395,17 @@ BYTE* emitter::emitOutput_InstrSve(BYTE* dst, instrDesc* id)
dst += emitOutput_Instr(dst, code);
break;

+ case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
+ code = emitInsCodeSve(ins, fmt);
+ code |= insEncodeReg_V<4, 0>(id->idReg1()); // ddddd
+ code |= insEncodeReg_V<9, 5>(id->idReg2()); // nnnnn
+ if (id->idIns() != INS_sve_mov)
+ {
+ code |= insEncodeReg_V<20, 16>(id->idReg3()); // mmmmm
+ }
+ dst += emitOutput_Instr(dst, code);
+ break;

case IF_SVE_AV_3A: // ...........mmmmm ......kkkkkddddd -- SVE2 bitwise ternary operations
code = emitInsCodeSve(ins, fmt);
code |= insEncodeReg_V<4, 0>(id->idReg1()); // ddddd
@@ -12882,7 +12892,6 @@ void emitter::emitInsSveSanityCheck(instrDesc* id)
case IF_SVE_FN_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply long
case IF_SVE_FO_3A: // ...........mmmmm ......nnnnnddddd -- SVE integer matrix multiply accumulate
case IF_SVE_AT_3B: // ...........mmmmm ......nnnnnddddd -- SVE integer add/subtract vectors (unpredicated)
- case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
case IF_SVE_BD_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply vectors (unpredicated)
case IF_SVE_EF_3A: // ...........mmmmm ......nnnnnddddd -- SVE two-way dot product
case IF_SVE_EI_3A: // ...........mmmmm ......nnnnnddddd -- SVE mixed sign dot product
@@ -12902,6 +12911,12 @@ void emitter::emitInsSveSanityCheck(instrDesc* id)
assert(isVectorRegister(id->idReg2())); // nnnnn/mmmmm
assert(isVectorRegister(id->idReg3())); // mmmmm/aaaaa
break;
+ case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
+ assert(insOptsScalable(id->idInsOpt()));
+ assert(isVectorRegister(id->idReg1())); // ddddd
+ assert(isVectorRegister(id->idReg2())); // nnnnn/mmmmm
+ assert((id->idIns() == INS_sve_mov) || isVectorRegister(id->idReg3())); // mmmmm/aaaaa
+ break;

case IF_SVE_HA_3A_F: // ...........mmmmm ......nnnnnddddd -- SVE BFloat16 floating-point dot product
case IF_SVE_EW_3A: // ...........mmmmm ......nnnnnddddd -- SVE2 multiply-add (checked pointer)
@@ -14526,7 +14541,6 @@ void emitter::emitDispInsSveHelp(instrDesc* id)
case IF_SVE_HD_3A_A: // ...........mmmmm ......nnnnnddddd -- SVE floating point matrix multiply accumulate
// <Zd>.D, <Zn>.D, <Zm>.D
case IF_SVE_AT_3B: // ...........mmmmm ......nnnnnddddd -- SVE integer add/subtract vectors (unpredicated)
- case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
// <Zd>.B, <Zn>.B, <Zm>.B
case IF_SVE_GF_3A: // ........xx.mmmmm ......nnnnnddddd -- SVE2 histogram generation (segment)
case IF_SVE_BD_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply vectors (unpredicated)
@@ -14541,6 +14555,20 @@ void emitter::emitDispInsSveHelp(instrDesc* id)
emitDispSveReg(id->idReg3(), id->idInsOpt(), false); // mmmmm/aaaaa
break;

+ // <Zd>.D, <Zn>.D, <Zm>.D
+ case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
+ emitDispSveReg(id->idReg1(), id->idInsOpt(), true); // ddddd
+ if (id->idIns() == INS_sve_mov)
+ {
+ emitDispSveReg(id->idReg2(), id->idInsOpt(), false); // nnnnn/mmmmm
+ }
+ else
+ {
+ emitDispSveReg(id->idReg2(), id->idInsOpt(), true); // nnnnn/mmmmm
+ emitDispSveReg(id->idReg3(), id->idInsOpt(), false); // mmmmm/aaaaa
+ }
+ break;

// <Zda>.D, <Zn>.D, <Zm>.D
case IF_SVE_EW_3A: // ...........mmmmm ......nnnnnddddd -- SVE2 multiply-add (checked pointer)
// <Zdn>.D, <Zm>.D, <Za>.D
6 changes: 5 additions & 1 deletion src/coreclr/jit/gentree.cpp
@@ -27955,7 +27955,7 @@ bool GenTreeLclVar::IsNeverNegative(Compiler* comp) const
return comp->lvaGetDesc(GetLclNum())->IsNeverNegative();
}

- #if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS)
+ #if (defined(TARGET_XARCH) || defined(TARGET_ARM64)) && defined(FEATURE_HW_INTRINSICS)
//------------------------------------------------------------------------
// GetResultOpNumForRmwIntrinsic: check if the result is written into one of the operands.
// In the case that none of the operand is overwritten, check if any of them is lastUse.
@@ -27966,7 +27966,11 @@ bool GenTreeLclVar::IsNeverNegative(Compiler* comp) const
//
unsigned GenTreeHWIntrinsic::GetResultOpNumForRmwIntrinsic(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3)
{
+ #if defined(TARGET_XARCH)
assert(HWIntrinsicInfo::IsFmaIntrinsic(gtHWIntrinsicId) || HWIntrinsicInfo::IsPermuteVar2x(gtHWIntrinsicId));
+ #elif defined(TARGET_ARM64)
+ assert(HWIntrinsicInfo::IsFmaIntrinsic(gtHWIntrinsicId));
+ #endif

if (use != nullptr && use->OperIs(GT_STORE_LCL_VAR))
{
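The `GetResultOpNumForRmwIntrinsic` change above matters because SVE fused multiply-add/subtract instructions are read-modify-write: the destination register is one of the sources, so the JIT checks whether the result feeds back into an operand, or whether an operand is at its last use and can be overwritten; otherwise it must produce the result in a separate register (e.g. via movprfx). Below is a hedged C# sketch of the two call shapes involved — the helper names, the use of `Sve.Add`, and the claim about which shape avoids a copy are illustrative assumptions, not taken from the commit.

```csharp
// Illustrative only: source patterns that exercise the RMW/FMA handling; names are made up.
using System.Numerics;
using System.Runtime.Intrinsics.Arm;

static class FmaRmwSketch
{
    // 'acc' is not used after the call, so its register is a candidate to receive the result.
    public static Vector<float> AccumulatorDies(Vector<float> acc, Vector<float> a, Vector<float> b)
        => Sve.FusedMultiplyAdd(acc, a, b);

    // All three operands stay live across the intrinsic, so the result cannot simply
    // clobber one of them and the JIT has to materialize it in a fresh register.
    public static Vector<float> AllOperandsLive(Vector<float> acc, Vector<float> a, Vector<float> b)
    {
        Vector<float> r = Sve.FusedMultiplyAdd(acc, a, b);
        return Sve.Add(r, Sve.Add(acc, Sve.Add(a, b)));
    }
}
```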
24 changes: 12 additions & 12 deletions src/coreclr/jit/hwintrinsic.h
@@ -216,27 +216,27 @@ enum HWIntrinsicFlag : unsigned int
// The intrinsic is an RMW intrinsic
HW_Flag_RmwIntrinsic = 0x1000000,

- // The intrinsic is a FusedMultiplyAdd intrinsic
- HW_Flag_FmaIntrinsic = 0x2000000,
-
// The intrinsic is a PermuteVar2x intrinsic
- HW_Flag_PermuteVar2x = 0x4000000,
+ HW_Flag_PermuteVar2x = 0x2000000,

// The intrinsic is an embedded broadcast compatible intrinsic
- HW_Flag_EmbBroadcastCompatible = 0x8000000,
+ HW_Flag_EmbBroadcastCompatible = 0x4000000,

// The intrinsic is an embedded rounding compatible intrinsic
- HW_Flag_EmbRoundingCompatible = 0x10000000,
+ HW_Flag_EmbRoundingCompatible = 0x8000000,

// The intrinsic is an embedded masking compatible intrinsic
- HW_Flag_EmbMaskingCompatible = 0x20000000,
+ HW_Flag_EmbMaskingCompatible = 0x10000000,
#elif defined(TARGET_ARM64)

// The intrinsic has an enum operand. Using this implies HW_Flag_HasImmediateOperand.
HW_Flag_HasEnumOperand = 0x1000000,

#endif // TARGET_XARCH

+ // The intrinsic is a FusedMultiplyAdd intrinsic
+ HW_Flag_FmaIntrinsic = 0x20000000,

HW_Flag_CanBenefitFromConstantProp = 0x80000000,
};

@@ -935,17 +935,17 @@ struct HWIntrinsicInfo
return (flags & HW_Flag_MaybeNoJmpTableIMM) != 0;
}

- #if defined(TARGET_XARCH)
- static bool IsRmwIntrinsic(NamedIntrinsic id)
+ static bool IsFmaIntrinsic(NamedIntrinsic id)
{
HWIntrinsicFlag flags = lookupFlags(id);
- return (flags & HW_Flag_RmwIntrinsic) != 0;
+ return (flags & HW_Flag_FmaIntrinsic) != 0;
}

- static bool IsFmaIntrinsic(NamedIntrinsic id)
+ #if defined(TARGET_XARCH)
+ static bool IsRmwIntrinsic(NamedIntrinsic id)
{
HWIntrinsicFlag flags = lookupFlags(id);
- return (flags & HW_Flag_FmaIntrinsic) != 0;
+ return (flags & HW_Flag_RmwIntrinsic) != 0;
}

static bool IsPermuteVar2x(NamedIntrinsic id)
2 changes: 2 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
@@ -277,6 +277,8 @@ void HWIntrinsicInfo::lookupImmBounds(
case NI_AdvSimd_Arm64_StoreSelectedScalarVector128x4:
case NI_AdvSimd_Arm64_DuplicateSelectedScalarToVector128:
case NI_AdvSimd_Arm64_InsertSelectedScalar:
+ case NI_Sve_FusedMultiplyAddBySelectedScalar:
+ case NI_Sve_FusedMultiplySubtractBySelectedScalar:
immUpperBound = Compiler::getSIMDVectorLength(simdSize, baseType) - 1;
break;

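The `lookupImmBounds` addition above caps the lane index for the new `*BySelectedScalar` fused intrinsics at `getSIMDVectorLength(simdSize, baseType) - 1`, i.e. the index must be a JIT-time constant selecting a valid lane for the base element type. A small C# sketch under that assumption — the parameter order, the helper name, and the `0..3` range for `float` are illustrative, not confirmed by the commit:

```csharp
// Sketch: the final argument is a constant lane index; for float it is assumed to stay
// within the element count implied by the base type (0..3 here).
using System.Numerics;
using System.Runtime.Intrinsics.Arm;

static class SelectedScalarBounds
{
    // acc + (a * b[3]) — maps to the indexed form of fmla per this commit.
    public static Vector<float> FmaByLane(Vector<float> acc, Vector<float> a, Vector<float> b)
        => Sve.FusedMultiplyAddBySelectedScalar(acc, a, b, 3);
}
```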
