Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arm64/Sve: Implement SVE Math *Multiply* APIs #102007

Merged
merged 30 commits into from
May 11, 2024

Conversation

kunalspathak
Copy link
Member

@kunalspathak kunalspathak commented May 8, 2024

  • FusedMultiplyAdd
  • FusedMultiplyAddBySelectedScalar
  • FusedMultiplyAddNegated
  • FusedMultiplySubtract
  • FusedMultiplySubtractBySelectedScalar
  • FusedMultiplySubtractNegated
  • MultiplyAdd
  • MultiplySubtract
  • MultiplyBySelectedScalar

All tests are passing: https://gist.github.com/kunalspathak/511565b3fe4d830dec509d867b8e36b0
Edit: Updated to include MultiplyAdd and MultiplySubtract

Contributes to #99957

@kunalspathak kunalspathak requested a review from TIHan May 8, 2024 06:08
@kunalspathak kunalspathak added the arm-sve Work related to arm64 SVE/SVE2 support label May 8, 2024
@kunalspathak
Copy link
Member Author

@dotnet/arm64-contrib

@kunalspathak kunalspathak changed the title Arm64/Sve: Implement SVE Math *Fused* APIs Arm64/Sve: Implement SVE Math Fused* APIs May 8, 2024
@@ -1815,6 +1816,22 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
}
}

if ((intrin.id == NI_Sve_FusedMultiplyAddBySelectedScalar) ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do these this require special code here?

Copy link
Member Author

@kunalspathak kunalspathak May 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because as per FMLA (indexed), Zm has to be in lower vector registers.

image

We have similar code for AdvSimd too and most likely, if I see more patterns in future, I will combine this code with it.

if ((intrin.category == HW_Category_SIMDByIndexedElement) && (genTypeSize(intrin.baseType) == 2))
{
// Some "Advanced SIMD scalar x indexed element" and "Advanced SIMD vector x indexed element" instructions (e.g.
// "MLA (by element)") have encoding that restricts what registers that can be used for the indexed element when
// the element size is H (i.e. 2 bytes).
assert(intrin.op2 != nullptr);
if ((intrin.op4 != nullptr) || ((intrin.op3 != nullptr) && !hasImmediateOperand))
{
if (isRMW)
{
srcCount += BuildDelayFreeUses(intrin.op2, nullptr);
srcCount += BuildDelayFreeUses(intrin.op3, nullptr, RBM_ASIMD_INDEXED_H_ELEMENT_ALLOWED_REGS);
}
else
{
srcCount += BuildOperandUses(intrin.op2);
srcCount += BuildOperandUses(intrin.op3, RBM_ASIMD_INDEXED_H_ELEMENT_ALLOWED_REGS);
}
if (intrin.op4 != nullptr)
{
assert(hasImmediateOperand);
assert(varTypeIsIntegral(intrin.op4));
srcCount += BuildOperandUses(intrin.op4);
}
}

@@ -46,6 +46,12 @@ HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask32Bit,
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask64Bit, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateWhileLessThanOrEqualMask8Bit, -1, 2, false, {INS_invalid, INS_sve_whilele, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, Divide, -1, 2, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sdiv, INS_sve_udiv, INS_sve_sdiv, INS_sve_udiv, INS_sve_fdiv, INS_sve_fdiv}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, FusedMultiplyAdd, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmla, INS_sve_fmla}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are always using FMLA for these. Will there be cases where FMAD might be more optimal based on register usage? If so, raise an issue to track it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, I am just preferencing op1 as a targetPrefUse, in other words telling LSRA to use op1 as the targetReg and mark the registers for other operands as delayFree. With that, using FMLA will always be optimal. @tannergooding - please correct if I missed anything here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, that sounds reasonable.
There might be scenarios where FMAD is still optimal - those where op2 is never reused in the C#, but op1 is reused. Using FMLA would avoid having to mov op1 into a temp.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would definitely expect us to have some logic around picking FMLA vs FMAD.

The x64 logic is even more complex because it has to handle the RMW consideration (should the tgtPrefUse be the addend or multiplier), but it also needs to consider which memory operand should be contained (since it supports embedded loads). That logic is here: https://github.com/dotnet/runtime/blob/main/src/coreclr/jit/lowerxarch.cpp#L9823 and you'll note that it uses the node->GetResultOpNumForRmwIntrinsic to determine which of op1, op2, or op3 is both an input and output or otherwise which is last use. It uses this to ensure the right containment choices are being made.

x64 then repeats this logic again in LSRA to actually set the tgtPrefUse: https://github.com/dotnet/runtime/blob/main/src/coreclr/jit/lsraxarch.cpp#L2432 and then again in codegen to pick which instruction form it should use: https://github.com/dotnet/runtime/blob/main/src/coreclr/jit/hwintrinsiccodegenxarch.cpp#L2947

I expect that Arm64 just needs to mirror the LSRA and codegen logic (ignoring any bits relevant to containment) and picking FMLA vs FMAD (rather than 231 vs 213, respectively)

@jkotas jkotas added the area-CodeGen-coreclr CLR JIT compiler in src/coreclr/src/jit and related components such as SuperPMI label May 8, 2024
Comment on lines 595 to 611
// If the instruction just has "predicated" version, then move the "embMaskOp1Reg"
// into targetReg. Next, do the predicated operation on the targetReg and last,
// use "sel" to select the active lanes based on mask, and set inactive lanes
// to falseReg.

assert(HWIntrinsicInfo::IsEmbeddedMaskedOperation(intrinEmbMask.id));

if (targetReg != embMaskOp1Reg)
{
GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, embMaskOp1Reg);
}

GetEmitter()->emitIns_R_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg,
embMaskOp3Reg, opt);

GetEmitter()->emitIns_R_R_R_R(INS_sve_sel, emitSize, targetReg, maskReg, targetReg, falseReg,
opt, INS_SCALABLE_OPTS_UNPREDICATED);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an assumption being made about the instruction being RMW here?

FMLA encodes 4 registers (Zda, Pg, Zn, and Zm) where Zda is both the source and destination and the operation is functionally similar to Zda += (Zn * Zm) (with only a single rounding operation).

Given some Zda = ConditionalSelect(Pg, FusedMultiplyAdd(Zda, Zn, Zm), Zda) it can then be encoded as simply:

fmla Zda, Pg/M, Zn, Zm

Given some Zda = ConditionalSelect(Pg, FusedMultiplyAdd(Zda, Zn, Zm), merge) it can then be encoded as simply:

movprfx Zda, Pg/M, merge
fmla Zda, Pg/M, Zn, Zm

Given some Zda = ConditionalSelect(Pg, FusedMultiplyAdd(Zda, Zn, Zm), Zero) it can then be encoded as simply:

movprfx Zda, Pg/Z, Zda
fmla Zda, Pg/M, Zn, Zm

There are then similar versions possible using fmad when the multiplier is the source and destination (op2Reg == tgtReg or op3Reg == tgtReg).


We should actually never need sel for this case, but only need complex generation if tgtReg is unique from all input registers (including the merge) and we're merging with a non-zero value, such as dest = ConditionalSelect(Pg, FusedMultiplyAdd(Zda, Zn, Zm), merge):

mov dest, Zda
movprfx dest, Pg/M, merge
fmla dest, Pg/M, Zn, Zm

This ends up being different from the other fallbacks that do use sel specifically because it's RMW and requires predication (that is there is no fmla (unpredicated)).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main goal of using ins (unpredicated); sel in the other case is because it allows a 2 instruction sequence as the worst case.

In this case, we at worst need a 3 instruction sequence due to the required predication on the instruction. Thus, it becomes better to use mov; movprfx (predicated); ins (predicated) instead as it can allow mov to be elided by the register renamer.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

such as dest = ConditionalSelect(Pg, FusedMultiplyAdd(Zda, Zn, Zm), merge):

For the similar reasoning mentioned in #100743 (comment) (where we should only movprfx the inactive lanes from merge -> dest, the code should be:

mov dest, Zda
fmla dest, Pg/M, Zn, Zm
sel dest, Pg/M, dest, merge

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I misinterpreted the value of Pg/M as AllTrue. Spoke to @tannergooding offline and we would like to generate:

sel dest, Pg/M, Zda, merge
fmla dest, Pg/M, Zn, Zm

@kunalspathak kunalspathak changed the title Arm64/Sve: Implement SVE Math Fused* APIs Arm64/Sve: Implement SVE Math *Multiply* APIs May 10, 2024
` Sve.ConditionalSelect(op1, Sve.MultiplyBySelectedScalar(op1, op2, 0), op3);` was failing
because we were trying to check if `MultiplyBySelectedScalar` is contained and we hit the assert
because it is not containable. Added the check.
Also updated *SelectedScalar* tests for ConditionalSelect

if (intrin.op3->IsVectorZero())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be asserting that intrin.op3 is contained?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added `

@kunalspathak kunalspathak merged commit 34e65b9 into dotnet:main May 11, 2024
167 checks passed
@kunalspathak kunalspathak deleted the sve_math6 branch May 11, 2024 01:39
Ruihan-Yin pushed a commit to Ruihan-Yin/runtime that referenced this pull request May 30, 2024
* Add *Fused* APIs

* fix an assert in morph

* Map APIs to instructions

* Add test cases

* handle fused* instructions

* jit format

* Added MultiplyAdd/MultiplySubtract

* Add mapping of API to instruction

* Add test cases

* Handle mov Z, Z instruction

* Reuse GetResultOpNumForRmwIntrinsic() for arm64

* Reuse HW_Flag_FmaIntrinsic for arm64

* Mark FMA APIs as HW_Flag_FmaIntrinsic

* Handle FMA in LSRA and codegen

* Remove the SpecialCodeGen flag from selectedScalar

* address some more scenarios

* jit format

* Add MultiplyBySelectedScalar

* Map the API to the instruction

* fix a bug where *Indexed API used with ConditionalSelect were failing

` Sve.ConditionalSelect(op1, Sve.MultiplyBySelectedScalar(op1, op2, 0), op3);` was failing
because we were trying to check if `MultiplyBySelectedScalar` is contained and we hit the assert
because it is not containable. Added the check.

* unpredicated movprfx should not send opt

* Add the missing flags for Subtract/Multiply

* Added tests for MultiplyBySelectedScalar

Also updated *SelectedScalar* tests for ConditionalSelect

* fixes to test cases

* fix the parameter for selectedScalar test

* jit format

* Contain(op3) of CndSel if op1 is AllTrueMask

* Handle FMA properly

* added assert
@github-actions github-actions bot locked and limited conversation to collaborators Jun 10, 2024
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
area-CodeGen-coreclr CLR JIT compiler in src/coreclr/src/jit and related components such as SuperPMI arm-sve Work related to arm64 SVE/SVE2 support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants