Skip to content

Commit

Permalink
Share more of the TYP_MASK handling and support rewriting TYP_MASK op…
Browse files Browse the repository at this point in the history
…erands in rationalization (#103288)

* Share more of the TYP_MASK handling and support rewriting TYP_MASK operands in rationalization

* Ensure we pass in TYP_MASK, not the simdType

* Apply formatting patch

* Fix copy/paste error, pass in clsHnd for the argument

* Ensure that we normalize sigType before inserting the CvtMaskToVectorNode

* Ensure that we get the vector node on Arm64 (ConvertVectorToMask has 2 ops)
  • Loading branch information
tannergooding authored Jun 12, 2024
1 parent 4078743 commit 96be3e2
Show file tree
Hide file tree
Showing 9 changed files with 292 additions and 282 deletions.
14 changes: 8 additions & 6 deletions src/coreclr/jit/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -3185,6 +3185,10 @@ class Compiler
GenTree* gtNewSimdAbsNode(
var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize);

#if defined(TARGET_ARM64)
GenTree* gtNewSimdAllTrueMaskNode(CorInfoType simdBaseJitType, unsigned simdSize);
#endif

GenTree* gtNewSimdBinOpNode(genTreeOps op,
var_types type,
GenTree* op1,
Expand Down Expand Up @@ -3223,6 +3227,8 @@ class Compiler
CorInfoType simdBaseJitType,
unsigned simdSize);

GenTree* gtNewSimdCvtMaskToVectorNode(var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize);

GenTree* gtNewSimdCvtNode(var_types type,
GenTree* op1,
CorInfoType simdTargetBaseJitType,
Expand All @@ -3235,6 +3241,8 @@ class Compiler
CorInfoType simdSourceBaseJitType,
unsigned simdSize);

GenTree* gtNewSimdCvtVectorToMaskNode(var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize);

GenTree* gtNewSimdCreateBroadcastNode(
var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize);

Expand Down Expand Up @@ -3516,12 +3524,6 @@ class Compiler

GenTreeIndir* gtNewMethodTableLookup(GenTree* obj);

#if defined(TARGET_ARM64)
GenTree* gtNewSimdConvertVectorToMaskNode(var_types type, GenTree* node, CorInfoType simdBaseJitType, unsigned simdSize);
GenTree* gtNewSimdConvertMaskToVectorNode(GenTreeHWIntrinsic* node, var_types type);
GenTree* gtNewSimdAllTrueMaskNode(CorInfoType simdBaseJitType, unsigned simdSize);
#endif

//------------------------------------------------------------------------
// Other GenTree functions

Expand Down
84 changes: 77 additions & 7 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21530,6 +21530,35 @@ GenTree* Compiler::gtNewSimdCeilNode(var_types type, GenTree* op1, CorInfoType s
return gtNewSimdHWIntrinsicNode(type, op1, intrinsic, simdBaseJitType, simdSize);
}

//------------------------------------------------------------------------
// gtNewSimdCvtMaskToVectorNode: Convert a HW instrinsic mask node to a vector
//
// Arguments:
// type -- The type of the node to convert to
// op1 -- The node to convert
// simdBaseJitType -- the base jit type of the converted node
// simdSize -- the simd size of the converted node
//
// Return Value:
// The node converted to the given type
//
GenTree* Compiler::gtNewSimdCvtMaskToVectorNode(var_types type,
GenTree* op1,
CorInfoType simdBaseJitType,
unsigned simdSize)
{
assert(varTypeIsMask(op1));
assert(varTypeIsSIMD(type));

#if defined(TARGET_XARCH)
return gtNewSimdHWIntrinsicNode(type, op1, NI_EVEX_ConvertMaskToVector, simdBaseJitType, simdSize);
#elif defined(TARGET_ARM64)
return gtNewSimdHWIntrinsicNode(type, op1, NI_Sve_ConvertMaskToVector, simdBaseJitType, simdSize);
#else
#error Unsupported platform
#endif // !TARGET_XARCH && !TARGET_ARM64
}

GenTree* Compiler::gtNewSimdCvtNode(var_types type,
GenTree* op1,
CorInfoType simdTargetBaseJitType,
Expand Down Expand Up @@ -21892,6 +21921,37 @@ GenTree* Compiler::gtNewSimdCvtNativeNode(var_types type,
return gtNewSimdHWIntrinsicNode(type, op1, hwIntrinsicID, simdSourceBaseJitType, simdSize);
}

//------------------------------------------------------------------------
// gtNewSimdCvtVectorToMaskNode: Convert a HW instrinsic vector node to a mask
//
// Arguments:
// type -- The type of the mask to produce.
// op1 -- The node to convert
// simdBaseJitType -- the base jit type of the converted node
// simdSize -- the simd size of the converted node
//
// Return Value:
// The node converted to the a mask type
//
GenTree* Compiler::gtNewSimdCvtVectorToMaskNode(var_types type,
GenTree* op1,
CorInfoType simdBaseJitType,
unsigned simdSize)
{
assert(varTypeIsMask(type));
assert(varTypeIsSIMD(op1));

#if defined(TARGET_XARCH)
return gtNewSimdHWIntrinsicNode(TYP_MASK, op1, NI_EVEX_ConvertVectorToMask, simdBaseJitType, simdSize);
#elif defined(TARGET_ARM64)
// We use cmpne which requires an embedded mask.
GenTree* trueMask = gtNewSimdAllTrueMaskNode(simdBaseJitType, simdSize);
return gtNewSimdHWIntrinsicNode(TYP_MASK, trueMask, op1, NI_Sve_ConvertVectorToMask, simdBaseJitType, simdSize);
#else
#error Unsupported platform
#endif // !TARGET_XARCH && !TARGET_ARM64
}

GenTree* Compiler::gtNewSimdCmpOpNode(
genTreeOps op, var_types type, GenTree* op1, GenTree* op2, CorInfoType simdBaseJitType, unsigned simdSize)
{
Expand Down Expand Up @@ -22569,19 +22629,15 @@ GenTree* Compiler::gtNewSimdCmpOpNode(

assert(intrinsic != NI_Illegal);

#if defined(TARGET_XARCH)
if (needsConvertMaskToVector)
{
GenTree* retNode = gtNewSimdHWIntrinsicNode(TYP_MASK, op1, op2, intrinsic, simdBaseJitType, simdSize);
return gtNewSimdHWIntrinsicNode(type, retNode, NI_EVEX_ConvertMaskToVector, simdBaseJitType, simdSize);
return gtNewSimdCvtMaskToVectorNode(type, retNode, simdBaseJitType, simdSize);
}
else
{
return gtNewSimdHWIntrinsicNode(type, op1, op2, intrinsic, simdBaseJitType, simdSize);
}
#else
return gtNewSimdHWIntrinsicNode(type, op1, op2, intrinsic, simdBaseJitType, simdSize);
#endif
}

GenTree* Compiler::gtNewSimdCmpOpAllNode(
Expand Down Expand Up @@ -27157,6 +27213,20 @@ bool GenTreeHWIntrinsic::OperIsCreateScalarUnsafe() const
}
}

//------------------------------------------------------------------------
// OperIsBitwiseHWIntrinsic: Is the operation a bitwise logic operation.
//
// Arguments:
// oper -- The operation to check
//
// Return Value:
// Whether oper is a bitwise logic intrinsic node.
//
bool GenTreeHWIntrinsic::OperIsBitwiseHWIntrinsic(genTreeOps oper)
{
return (oper == GT_AND) || (oper == GT_AND_NOT) || (oper == GT_OR) || (oper == GT_XOR);
}

//------------------------------------------------------------------------
// OperIsBitwiseHWIntrinsic: Is this HWIntrinsic a bitwise logic intrinsic node.
//
Expand All @@ -27165,8 +27235,8 @@ bool GenTreeHWIntrinsic::OperIsCreateScalarUnsafe() const
//
bool GenTreeHWIntrinsic::OperIsBitwiseHWIntrinsic() const
{
genTreeOps Oper = HWOperGet();
return Oper == GT_AND || Oper == GT_OR || Oper == GT_XOR || Oper == GT_AND_NOT;
genTreeOps oper = HWOperGet();
return OperIsBitwiseHWIntrinsic(oper);
}

//------------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/jit/gentree.h
Original file line number Diff line number Diff line change
Expand Up @@ -6505,6 +6505,8 @@ struct GenTreeHWIntrinsic : public GenTreeJitIntrinsic
}
#endif

static bool OperIsBitwiseHWIntrinsic(genTreeOps oper);

bool OperIsMemoryLoad(GenTree** pAddr = nullptr) const;
bool OperIsMemoryStore(GenTree** pAddr = nullptr) const;
bool OperIsMemoryLoadOrStore() const;
Expand Down
12 changes: 9 additions & 3 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1658,14 +1658,14 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
if (!varTypeIsMask(op2))
{
retNode->AsHWIntrinsic()->Op(2) =
gtNewSimdConvertVectorToMaskNode(retType, op2, simdBaseJitType, simdSize);
gtNewSimdCvtVectorToMaskNode(TYP_MASK, op2, simdBaseJitType, simdSize);
}
}

if (!varTypeIsMask(op1))
{
// Op1 input is a vector. HWInstrinsic requires a mask.
retNode->AsHWIntrinsic()->Op(1) = gtNewSimdConvertVectorToMaskNode(retType, op1, simdBaseJitType, simdSize);
retNode->AsHWIntrinsic()->Op(1) = gtNewSimdCvtVectorToMaskNode(TYP_MASK, op1, simdBaseJitType, simdSize);
}

if (HWIntrinsicInfo::IsMultiReg(intrinsic))
Expand All @@ -1682,7 +1682,13 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
// HWInstrinsic returns a mask, but all returns must be vectors, so convert mask to vector.
assert(HWIntrinsicInfo::ReturnsPerElementMask(intrinsic));
assert(nodeRetType == TYP_MASK);
retNode = gtNewSimdConvertMaskToVectorNode(retNode->AsHWIntrinsic(), retType);

GenTreeHWIntrinsic* op = retNode->AsHWIntrinsic();

CorInfoType simdBaseJitType = op->GetSimdBaseJitType();
unsigned simdSize = op->GetSimdSize();

retNode = gtNewSimdCvtMaskToVectorNode(retType, op, simdBaseJitType, simdSize);
}
#endif // defined(TARGET_ARM64)

Expand Down
44 changes: 1 addition & 43 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2625,7 +2625,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
// HWInstrinsic requires a mask for op2
if (!varTypeIsMask(op2))
{
op2 = gtNewSimdConvertVectorToMaskNode(retType, op2, simdBaseJitType, simdSize);
op2 = gtNewSimdCvtVectorToMaskNode(TYP_MASK, op2, simdBaseJitType, simdSize);
}

retNode = gtNewSimdHWIntrinsicNode(retType, op1, op2, intrinsic, simdBaseJitType, simdSize);
Expand All @@ -2646,48 +2646,6 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
return retNode;
}

//------------------------------------------------------------------------
// gtNewSimdConvertMaskToVectorNode: Convert a HW instrinsic vector node to a mask
//
// Arguments:
// node -- The node to convert
// simdBaseJitType -- the base jit type of the converted node
// simdSize -- the simd size of the converted node
//
// Return Value:
// The node converted to the a mask type
//
GenTree* Compiler::gtNewSimdConvertVectorToMaskNode(var_types type,
GenTree* node,
CorInfoType simdBaseJitType,
unsigned simdSize)
{
assert(varTypeIsSIMD(node));

// ConvertVectorToMask uses cmpne which requires an embedded mask.
GenTree* trueMask = gtNewSimdAllTrueMaskNode(simdBaseJitType, simdSize);
return gtNewSimdHWIntrinsicNode(TYP_MASK, trueMask, node, NI_Sve_ConvertVectorToMask, simdBaseJitType, simdSize);
}

//------------------------------------------------------------------------
// gtNewSimdConvertMaskToVectorNode: Convert a HW instrinsic mask node to a vector
//
// Arguments:
// node -- The node to convert
// type -- The type of the node to convert to
//
// Return Value:
// The node converted to the given type
//
GenTree* Compiler::gtNewSimdConvertMaskToVectorNode(GenTreeHWIntrinsic* node, var_types type)
{
assert(varTypeIsMask(node));
assert(varTypeIsSIMD(type));

return gtNewSimdHWIntrinsicNode(type, node, NI_Sve_ConvertMaskToVector, node->GetSimdBaseJitType(),
node->GetSimdSize());
}

//------------------------------------------------------------------------
// gtNewSimdEmbeddedMaskNode: Create an embedded mask
//
Expand Down
Loading

0 comments on commit 96be3e2

Please sign in to comment.