Skip to content

Commit

Permalink
1.Resolve comments:
Browse files Browse the repository at this point in the history
2. Added FMA intrinsics with embedded rounding and unit tests.
  • Loading branch information
Ruihan-Yin committed Feb 13, 2024
1 parent 7de90b2 commit d0c805c
Show file tree
Hide file tree
Showing 13 changed files with 804 additions and 49 deletions.
8 changes: 3 additions & 5 deletions src/coreclr/jit/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -976,15 +976,13 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
#ifdef FEATURE_HW_INTRINSICS
void genHWIntrinsic(GenTreeHWIntrinsic* node);
#if defined(TARGET_XARCH)
void genHWIntrinsic_R_RM(
GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, regNumber reg, GenTree* rmOp, insOpts instOptions);
void genHWIntrinsic_R_RM(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, regNumber reg, GenTree* rmOp);
void genHWIntrinsic_R_RM(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, regNumber reg, GenTree* rmOp, insOpts instOptions = INS_OPTS_NONE);
void genHWIntrinsic_R_RM_I(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, int8_t ival);
void genHWIntrinsic_R_R_RM(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, insOpts instOptions);
void genHWIntrinsic_R_R_RM_I(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, int8_t ival);
void genHWIntrinsic_R_R_RM_R(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr);
void genHWIntrinsic_R_R_R_RM(
instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, GenTree* op3);
instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, GenTree* op3, insOpts instOptions = INS_OPTS_NONE);
void genHWIntrinsic_R_R_R_RM_I(GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, int8_t ival);

void genBaseIntrinsic(GenTreeHWIntrinsic* node);
Expand All @@ -996,7 +994,7 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
void genAvxFamilyIntrinsic(GenTreeHWIntrinsic* node, insOpts instOptions);
void genAESIntrinsic(GenTreeHWIntrinsic* node);
void genBMI1OrBMI2Intrinsic(GenTreeHWIntrinsic* node, insOpts instOptions);
void genFMAIntrinsic(GenTreeHWIntrinsic* node);
void genFMAIntrinsic(GenTreeHWIntrinsic* node, insOpts instOptions);
void genPermuteVar2x(GenTreeHWIntrinsic* node);
void genLZCNTIntrinsic(GenTreeHWIntrinsic* node);
void genPCLMULQDQIntrinsic(GenTreeHWIntrinsic* node);
Expand Down
11 changes: 9 additions & 2 deletions src/coreclr/jit/emitxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8552,20 +8552,27 @@ void emitter::emitIns_SIMD_R_R_R_C(instruction ins,
// op1Reg -- The register of the first operand
// op2Reg -- The register of the second operand
// op3Reg -- The register of the second operand
// instOptions - The options that modify how the instruction is generated
//
void emitter::emitIns_SIMD_R_R_R_R(
instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, regNumber op3Reg)
instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, regNumber op3Reg, insOpts instOptions)
{
if (IsFMAInstruction(ins) || IsPermuteVar2xInstruction(ins) || IsAVXVNNIInstruction(ins))
{
assert(UseSimdEncoding());

if(instOptions != INS_OPTS_NONE)
{
// insOpts is currently available only in EVEX encoding.
assert(UseEvexEncoding());
}

// Ensure we aren't overwriting op2 or op3
assert((op2Reg != targetReg) || (op1Reg == targetReg));
assert((op3Reg != targetReg) || (op1Reg == targetReg));

emitIns_Mov(INS_movaps, attr, targetReg, op1Reg, /* canSkip */ true);
emitIns_R_R_R(ins, attr, targetReg, op2Reg, op3Reg);
emitIns_R_R_R(ins, attr, targetReg, op2Reg, op3Reg, instOptions);
}
else if (UseSimdEncoding())
{
Expand Down
2 changes: 1 addition & 1 deletion src/coreclr/jit/emitxarch.h
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ void emitIns_SIMD_R_R_R_C(instruction ins,
CORINFO_FIELD_HANDLE fldHnd,
int offs);
void emitIns_SIMD_R_R_R_R(
instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, regNumber op3Reg);
instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, regNumber op3Reg, insOpts instOptions = INS_OPTS_NONE);
void emitIns_SIMD_R_R_R_S(
instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, int varx, int offs);

Expand Down
15 changes: 14 additions & 1 deletion src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26478,14 +26478,27 @@ bool GenTreeHWIntrinsic::OperIsEmbRoundingEnabled() const
return true;
}

case NI_AVX512F_FusedMultiplyAdd:
case NI_AVX512F_FusedMultiplyAddScalar:
case NI_AVX512F_FusedMultiplyAddNegated:
case NI_AVX512F_FusedMultiplyAddNegatedScalar:
case NI_AVX512F_FusedMultiplyAddSubtract:
case NI_AVX512F_FusedMultiplySubtract:
case NI_AVX512F_FusedMultiplySubtractAdd:
case NI_AVX512F_FusedMultiplySubtractNegated:
case NI_AVX512F_FusedMultiplySubtractNegatedScalar:
case NI_AVX512F_FusedMultiplySubtractScalar:
{
return numArgs == 4;
}

case NI_AVX512F_Add:
case NI_AVX512F_Divide:
case NI_AVX512F_Multiply:
case NI_AVX512F_Subtract:

case NI_AVX512F_Scale:

case NI_AVX512F_ConvertScalarToVector128Double:
case NI_AVX512F_ConvertScalarToVector128Single:
#if defined(TARGET_AMD64)
case NI_AVX512F_X64_ConvertScalarToVector128Double:
Expand Down
84 changes: 55 additions & 29 deletions src/coreclr/jit/hwintrinsiccodegenxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,13 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
break;
}

case 4:
{
numArgs = 3;
node->ResetHWIntrinsicId(intrinsicId, compiler, node->Op(1), node->Op(2), node->Op(3));
break;
}

default:
{
unreached();
Expand Down Expand Up @@ -738,7 +745,7 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
genBMI1OrBMI2Intrinsic(node, instOptions);
break;
case InstructionSet_FMA:
genFMAIntrinsic(node);
genFMAIntrinsic(node, instOptions);
break;
case InstructionSet_LZCNT:
case InstructionSet_LZCNT_X64:
Expand Down Expand Up @@ -779,39 +786,16 @@ void CodeGen::genHWIntrinsic_R_RM(
emitter* emit = GetEmitter();
OperandDesc rmOpDesc = genOperandDesc(rmOp);

assert(reg != REG_NA);

if ((instOptions & INS_OPTS_EVEX_b_MASK) != 0)
if (((instOptions & INS_OPTS_EVEX_b_MASK) != 0) && (rmOpDesc.GetKind() == OperandKind::Reg))
{
// As embedded rounding only appies in R_R case, we can skip other checks for different paths.
assert(rmOpDesc.GetKind() == OperandKind::Reg);
regNumber op1Reg = rmOp->GetRegNum();
assert(op1Reg != REG_NA);

emit->emitIns_R_R(ins, attr, reg, op1Reg, instOptions);
return;
}

genHWIntrinsic_R_RM(node, ins, attr, reg, rmOp);
}

//------------------------------------------------------------------------
// genHWIntrinsic_R_RM: Generates code for a hardware intrinsic node that takes a
// register operand and a register/memory operand.
//
// Arguments:
// node - The hardware intrinsic node
// ins - The instruction being generated
// attr - The emit attribute for the instruction being generated
// reg - The register
// rmOp - The register/memory operand node
//
void CodeGen::genHWIntrinsic_R_RM(
GenTreeHWIntrinsic* node, instruction ins, emitAttr attr, regNumber reg, GenTree* rmOp)
{
emitter* emit = GetEmitter();
OperandDesc rmOpDesc = genOperandDesc(rmOp);

if (rmOpDesc.IsContained())
{
assert(HWIntrinsicInfo::SupportsContainment(node->GetHWIntrinsicId()));
Expand Down Expand Up @@ -1137,9 +1121,10 @@ void CodeGen::genHWIntrinsic_R_R_RM_R(GenTreeHWIntrinsic* node, instruction ins,
// op1Reg - The register of the first operand
// op2Reg - The register of the second operand
// op3 - The third operand
// instOptions - The options that modify how the instruction is generated
//
void CodeGen::genHWIntrinsic_R_R_R_RM(
instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, GenTree* op3)
instruction ins, emitAttr attr, regNumber targetReg, regNumber op1Reg, regNumber op2Reg, GenTree* op3, insOpts instOptions)
{
assert(targetReg != REG_NA);
assert(op1Reg != REG_NA);
Expand All @@ -1148,6 +1133,16 @@ void CodeGen::genHWIntrinsic_R_R_R_RM(
emitter* emit = GetEmitter();
OperandDesc op3Desc = genOperandDesc(op3);

if(((instOptions & INS_OPTS_EVEX_b_MASK) != 0 ) && (op3Desc.GetKind() == OperandKind::Reg))
{
// As embedded rounding only appies in R_R case, we can skip other checks for different paths.
regNumber op3Reg = op3->GetRegNum();
assert(op3Reg != REG_NA);

emit->emitIns_SIMD_R_R_R_R(ins, attr, targetReg, op1Reg, op2Reg, op3Desc.GetReg(), instOptions);
return;
}

switch (op3Desc.GetKind())
{
case OperandKind::ClsVar:
Expand Down Expand Up @@ -1409,6 +1404,37 @@ void CodeGen::genNonTableDrivenHWIntrinsicsJumpTableFallback(GenTreeHWIntrinsic*
genHWIntrinsicJumpTableFallback(intrinsicId, lastOp->GetRegNum(), baseReg, offsReg, emitSwCase);
break;
}

case NI_AVX512F_FusedMultiplyAdd:
case NI_AVX512F_FusedMultiplyAddScalar:
case NI_AVX512F_FusedMultiplyAddNegated:
case NI_AVX512F_FusedMultiplyAddNegatedScalar:
case NI_AVX512F_FusedMultiplyAddSubtract:
case NI_AVX512F_FusedMultiplySubtract:
case NI_AVX512F_FusedMultiplySubtractAdd:
case NI_AVX512F_FusedMultiplySubtractNegated:
case NI_AVX512F_FusedMultiplySubtractNegatedScalar:
case NI_AVX512F_FusedMultiplySubtractScalar:
{
// For FMA intrinsics, since it is not possible to get any contained operand in this case: embedded rounding is limited in register-to-register form, and the control byte is dynamic, we don't need to do any swap.
assert(HWIntrinsicInfo::IsFmaIntrinsic(intrinsicId));

GenTree* op1 = node->Op(1);
GenTree* op2 = node->Op(2);
GenTree* op3 = node->Op(3);

regNumber op1Reg = op1->GetRegNum();
regNumber op2Reg = op2->GetRegNum();

auto emitSwCase = [&](int8_t i) {
insOpts newInstOptions = AddEmbRoundingMode(instOptions, i);
genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, op1Reg, op2Reg, op3, newInstOptions);
};
regNumber baseReg = node->ExtractTempReg();
regNumber offsReg = node->GetSingleTempReg();
genHWIntrinsicJumpTableFallback(intrinsicId, lastOp->GetRegNum(), baseReg, offsReg, emitSwCase);
break;
}

default:
unreached();
Expand Down Expand Up @@ -2119,7 +2145,7 @@ void CodeGen::genAvxFamilyIntrinsic(GenTreeHWIntrinsic* node, insOpts instOption

if (HWIntrinsicInfo::IsFmaIntrinsic(intrinsicId))
{
genFMAIntrinsic(node);
genFMAIntrinsic(node, instOptions);
return;
}

Expand Down Expand Up @@ -2906,7 +2932,7 @@ void CodeGen::genBMI1OrBMI2Intrinsic(GenTreeHWIntrinsic* node, insOpts instOptio
// Arguments:
// node - The hardware intrinsic node
//
void CodeGen::genFMAIntrinsic(GenTreeHWIntrinsic* node)
void CodeGen::genFMAIntrinsic(GenTreeHWIntrinsic* node, insOpts instOptions)
{
NamedIntrinsic intrinsicId = node->GetHWIntrinsicId();
assert(HWIntrinsicInfo::IsFmaIntrinsic(intrinsicId));
Expand Down Expand Up @@ -3013,7 +3039,7 @@ void CodeGen::genFMAIntrinsic(GenTreeHWIntrinsic* node)
}

assert(ins != INS_invalid);
genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, emitOp1->GetRegNum(), emitOp2->GetRegNum(), emitOp3);
genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, emitOp1->GetRegNum(), emitOp2->GetRegNum(), emitOp3, instOptions);
genProduceReg(node);
}

Expand Down
Loading

0 comments on commit d0c805c

Please sign in to comment.