Skip to content

Commit

Permalink
Expose AVX512F embedded rounding intrinsics. (#97415)
Browse files Browse the repository at this point in the history
* Expose embedded rounding related scalar intrinsic APIs

* Expose embedded rounding related arithmetic intrinsic APIs

* Ensure the new APIs are properly lowered

* Bug fixes

* Expose embedded rounding casting APIs

* Expose arithmetic embedded rounding unit tests

* Add a test template for embedded rounding APIs, this will be enough to cover all the binary APIs including vector and scalar operations.

* Add template for unary ops

* Expose all the embedded rounding unit tests generated by the templates

* Expose embedded rounding casting APIs unit tests

* Expose handwritten unit tests for embedded rounding APIs with special input arg lists.

* Bug fixes:
1. ConvertToVector256Int32/UInt32 use special code gen path, adding a fallback path when embedded rounding is activated and the control byte is not constant.

* Bug fix:
Fix wrong data type in the API definition.

* formatting

* Update API documents for embedded rounding APIs.

* resolve conflicts with #97569

* formatting

* bug fix and remove un-needed SAE related intrinsics

* resolve comments:
1. update the arg lists for genHWIntrinsic_R_RM

* resolve comments:
Add jumptable fallback to non-table driven embedded rounding intrinsics.

* resolve comments:
1. remove some redundant checks on embedded rounding intrinsics

* Bug fix:
pass the correct operand GenTree node, when emitting the fallback for embedded rounding intrinsics.

* formatting

* revert an unexpected change.

* 1.Resolve comments:
2. Added FMA intrinsics with embedded rounding and unit tests.

* Expose the rest of embedded rounding APIs

* formatting

* Ensure the control byte local is assigned to the correct register.
  • Loading branch information
Ruihan-Yin authored Feb 15, 2024
1 parent f2d5b2f commit aeecdb8
Show file tree
Hide file tree
Showing 23 changed files with 3,737 additions and 147 deletions.
20 changes: 16 additions & 4 deletions src/coreclr/jit/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -976,13 +976,23 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
#ifdef FEATURE_HW_INTRINSICS
void genHWIntrinsic(GenTreeHWIntrinsic* node);
#if defined(TARGET_XARCH)
void genHWIntrinsic_R_RM(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, regNumber reg, GenTree* rmOp);
void genHWIntrinsic_R_RM(GenTreeHWIntrinsic* node,
instruction ins,
emitAttr attr,
regNumber reg,
GenTree* rmOp,
insOpts instOptions = INS_OPTS_NONE);
void genHWIntrinsic_R_RM_I(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, int8_t ival);
void genHWIntrinsic_R_R_RM(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, insOpts instOptions);
void genHWIntrinsic_R_R_RM_I(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, int8_t ival);
void genHWIntrinsic_R_R_RM_R(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr);
void genHWIntrinsic_R_R_R_RM(
instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, GenTree* op3);
void genHWIntrinsic_R_R_R_RM(instruction ins,
emitAttr attr,
regNumber targetReg,
regNumber op1Reg,
regNumber op2Reg,
GenTree* op3,
insOpts instOptions = INS_OPTS_NONE);
void genHWIntrinsic_R_R_R_RM_I(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, int8_t ival);

void genBaseIntrinsic(GenTreeHWIntrinsic* node);
Expand All @@ -994,7 +1004,7 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
void genAvxFamilyIntrinsic(GenTreeHWIntrinsic* node, insOpts instOptions);
void genAESIntrinsic(GenTreeHWIntrinsic* node);
void genBMI1OrBMI2Intrinsic(GenTreeHWIntrinsic* node, insOpts instOptions);
void genFMAIntrinsic(GenTreeHWIntrinsic* node);
void genFMAIntrinsic(GenTreeHWIntrinsic* node, insOpts instOptions);
void genPermuteVar2x(GenTreeHWIntrinsic* node);
void genLZCNTIntrinsic(GenTreeHWIntrinsic* node);
void genPCLMULQDQIntrinsic(GenTreeHWIntrinsic* node);
Expand All @@ -1008,6 +1018,8 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
regNumber baseReg,
regNumber offsReg,
HWIntrinsicSwitchCaseBody emitSwCase);

void genNonTableDrivenHWIntrinsicsJumpTableFallback(GenTreeHWIntrinsic* node, GenTree* lastOp);
#endif // defined(TARGET_XARCH)

#ifdef TARGET_ARM64
Expand Down
28 changes: 24 additions & 4 deletions src/coreclr/jit/emitxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6565,7 +6565,7 @@ void emitter::emitIns_Mov(instruction ins, emitAttr attr, regNumber dstReg, regN
* Add an instruction with two register operands.
*/

void emitter::emitIns_R_R(instruction ins, emitAttr attr, regNumber reg1, regNumber reg2)
void emitter::emitIns_R_R(instruction ins, emitAttr attr, regNumber reg1, regNumber reg2, insOpts instOptions)
{
if (IsMovInstruction(ins))
{
Expand All @@ -6587,6 +6587,13 @@ void emitter::emitIns_R_R(instruction ins, emitAttr attr, regNumber reg1, regNum
id->idReg1(reg1);
id->idReg2(reg2);

if ((instOptions & INS_OPTS_EVEX_b_MASK) != INS_OPTS_NONE)
{
// if EVEX.b needs to be set in this path, then it should be embedded rounding.
assert(UseEvexEncoding());
id->idSetEvexbContext(instOptions);
}

UNATIVE_OFFSET sz = emitInsSizeRR(id);
id->idCodeSize(sz);

Expand Down Expand Up @@ -8545,20 +8552,32 @@ void emitter::emitIns_SIMD_R_R_R_C(instruction ins,
// op1Reg -- The register of the first operand
// op2Reg -- The register of the second operand
// op3Reg -- The register of the third operand
// instOptions - The options that modify how the instruction is generated
//
void emitter::emitIns_SIMD_R_R_R_R(
instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, regNumber op3Reg)
void emitter::emitIns_SIMD_R_R_R_R(instruction ins,
emitAttr attr,
regNumber targetReg,
regNumber op1Reg,
regNumber op2Reg,
regNumber op3Reg,
insOpts instOptions)
{
if (IsFMAInstruction(ins) || IsPermuteVar2xInstruction(ins) || IsAVXVNNIInstruction(ins))
{
assert(UseSimdEncoding());

if (instOptions != INS_OPTS_NONE)
{
// insOpts is currently available only in EVEX encoding.
assert(UseEvexEncoding());
}

// Ensure we aren't overwriting op2 or op3
assert((op2Reg != targetReg) || (op1Reg == targetReg));
assert((op3Reg != targetReg) || (op1Reg == targetReg));

emitIns_Mov(INS_movaps, attr, targetReg, op1Reg, /* canSkip */ true);
emitIns_R_R_R(ins, attr, targetReg, op2Reg, op3Reg);
emitIns_R_R_R(ins, attr, targetReg, op2Reg, op3Reg, instOptions);
}
else if (UseSimdEncoding())
{
Expand Down Expand Up @@ -11659,6 +11678,7 @@ void emitter::emitDispIns(
default:
{
printf("%s, %s", emitRegName(id->idReg1(), attr), emitRegName(id->idReg2(), attr));
emitDispEmbRounding(id);
break;
}
}
Expand Down
11 changes: 8 additions & 3 deletions src/coreclr/jit/emitxarch.h
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ void emitIns_R_I(instruction ins,

void emitIns_Mov(instruction ins, emitAttr attr, regNumber dstReg, regNumber srgReg, bool canSkip);

void emitIns_R_R(instruction ins, emitAttr attr, regNumber reg1, regNumber reg2);
void emitIns_R_R(instruction ins, emitAttr attr, regNumber reg1, regNumber reg2, insOpts instOptions = INS_OPTS_NONE);

void emitIns_R_R_I(instruction ins, emitAttr attr, regNumber reg1, regNumber reg2, int ival);

Expand Down Expand Up @@ -839,8 +839,13 @@ void emitIns_SIMD_R_R_R_C(instruction ins,
regNumber op2Reg,
CORINFO_FIELD_HANDLE fldHnd,
int offs);
void emitIns_SIMD_R_R_R_R(
instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, regNumber op3Reg);
void emitIns_SIMD_R_R_R_R(instruction ins,
emitAttr attr,
regNumber targetReg,
regNumber op1Reg,
regNumber op2Reg,
regNumber op3Reg,
insOpts instOptions = INS_OPTS_NONE);
void emitIns_SIMD_R_R_R_S(
instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, int varx, int offs);

Expand Down
89 changes: 89 additions & 0 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26484,6 +26484,95 @@ bool GenTreeHWIntrinsic::OperIsBitwiseHWIntrinsic() const
return Oper == GT_AND || Oper == GT_OR || Oper == GT_XOR || Oper == GT_AND_NOT;
}

//------------------------------------------------------------------------
// OperIsEmbRoundingEnabled: Check whether this HWIntrinsic node is using the
//    embedded rounding form of its intrinsic.
//
// Return Value:
//    true if the intrinsic is embedded rounding compatible and either only
//    exists in an embedded rounding form, or has exactly the operand count
//    listed for it below; false otherwise (always false on non-xarch targets).
//
// Notes:
//    The operand counts below are presumably those of the overloads that take
//    the extra rounding control operand -- confirm against the API definitions.
//
bool GenTreeHWIntrinsic::OperIsEmbRoundingEnabled() const
{
#if !defined(TARGET_XARCH)
    return false;
#else  // TARGET_XARCH
    const NamedIntrinsic id = GetHWIntrinsicId();

    if (!HWIntrinsicInfo::IsEmbRoundingCompatible(id))
    {
        return false;
    }

    const size_t operandCount  = GetOperandCount();
    size_t       expectedCount = 0;

    switch (id)
    {
        // These intrinsics have no non-rounding implementation, so any
        // operand count qualifies.
        case NI_AVX512F_AddScalar:
        case NI_AVX512F_DivideScalar:
        case NI_AVX512F_MultiplyScalar:
        case NI_AVX512F_SubtractScalar:
        case NI_AVX512F_SqrtScalar:
            return true;

        // Fused multiply family: four operands select the rounding form.
        case NI_AVX512F_FusedMultiplyAdd:
        case NI_AVX512F_FusedMultiplyAddScalar:
        case NI_AVX512F_FusedMultiplyAddNegated:
        case NI_AVX512F_FusedMultiplyAddNegatedScalar:
        case NI_AVX512F_FusedMultiplyAddSubtract:
        case NI_AVX512F_FusedMultiplySubtract:
        case NI_AVX512F_FusedMultiplySubtractAdd:
        case NI_AVX512F_FusedMultiplySubtractNegated:
        case NI_AVX512F_FusedMultiplySubtractNegatedScalar:
        case NI_AVX512F_FusedMultiplySubtractScalar:
            expectedCount = 4;
            break;

        // Arithmetic, scale and scalar-to-vector conversions: three operands
        // select the rounding form.
        case NI_AVX512F_Add:
        case NI_AVX512F_Divide:
        case NI_AVX512F_Multiply:
        case NI_AVX512F_Subtract:
        case NI_AVX512F_Scale:
        case NI_AVX512F_ScaleScalar:
        case NI_AVX512F_ConvertScalarToVector128Single:
#if defined(TARGET_AMD64)
        case NI_AVX512F_X64_ConvertScalarToVector128Double:
        case NI_AVX512F_X64_ConvertScalarToVector128Single:
#endif // TARGET_AMD64
            expectedCount = 3;
            break;

        // Sqrt and the remaining conversions: two operands select the
        // rounding form.
        case NI_AVX512F_Sqrt:
        case NI_AVX512F_ConvertToInt32:
        case NI_AVX512F_ConvertToUInt32:
        case NI_AVX512F_ConvertToVector256Int32:
        case NI_AVX512F_ConvertToVector256Single:
        case NI_AVX512F_ConvertToVector256UInt32:
        case NI_AVX512F_ConvertToVector512Single:
        case NI_AVX512F_ConvertToVector512UInt32:
        case NI_AVX512F_ConvertToVector512Int32:
#if defined(TARGET_AMD64)
        case NI_AVX512F_X64_ConvertToInt64:
        case NI_AVX512F_X64_ConvertToUInt64:
#endif // TARGET_AMD64
        case NI_AVX512DQ_ConvertToVector256Single:
        case NI_AVX512DQ_ConvertToVector512Double:
        case NI_AVX512DQ_ConvertToVector512Int64:
        case NI_AVX512DQ_ConvertToVector512UInt64:
            expectedCount = 2;
            break;

        // Every embedded rounding compatible intrinsic must be listed above.
        default:
            unreached();
    }

    return operandCount == expectedCount;
#endif // TARGET_XARCH
}

//------------------------------------------------------------------------------
// OperRequiresAsgFlag : Check whether the operation requires GTF_ASG flag regardless
// of the children's flags.
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/gentree.h
Original file line number Diff line number Diff line change
Expand Up @@ -6387,6 +6387,7 @@ struct GenTreeHWIntrinsic : public GenTreeJitIntrinsic
bool OperIsBroadcastScalar() const;
bool OperIsCreateScalarUnsafe() const;
bool OperIsBitwiseHWIntrinsic() const;
bool OperIsEmbRoundingEnabled() const;

bool OperRequiresAsgFlag() const;
bool OperRequiresCallFlag() const;
Expand Down
15 changes: 0 additions & 15 deletions src/coreclr/jit/hwintrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -607,21 +607,6 @@ struct HWIntrinsicInfo
HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_EmbMaskingIncompatible) == 0;
}

static size_t EmbRoundingArgPos(NamedIntrinsic id)
{
    // Reports the position at which the embedded rounding control argument
    // is expected for the given intrinsic. Callers must only pass ids that
    // are embedded rounding compatible.
    assert(IsEmbRoundingCompatible(id));

    if (id == NI_AVX512F_Add)
    {
        return 3;
    }

    // No other intrinsic has a known position here.
    unreached();
}
#endif // TARGET_XARCH

static bool CanBenefitFromConstantProp(NamedIntrinsic id)
Expand Down
Loading

0 comments on commit aeecdb8

Please sign in to comment.