Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.

Commit

Permalink
Adding tests and fixing codegen for the Arm 'Aes' and 'ArmBase' hwintrinsics (#27151)
Browse files Browse the repository at this point in the history

* Adding arm hwintrinsic tests for AdvSimd.LoadVector64 and AdvSimd.LoadVector128

* Adding arm hwintrinsic tests for Aes.Decrypt, Aes.Encrypt, Aes.InverseMixColumns, and Aes.MixColumns

* Fixing compSetProcessor to support the Arm AES instruction set

* Adding arm hwintrinsic tests for ArmBase.LeadingZeroCount, ArmBase.Arm64.LeadingSignCount, and ArmBase.Arm64.LeadingZeroCount

* Improving the arm hwintrinsic test generator

* Regenerating the arm hwintrinsic tests

* Fixing the arm hwintrinsic codegen to support scalar and aes intrinsics

* Applying formatting patch.

* Don't pass in opts to INS_mov

* Ensure the arm Aes.Decrypt and Aes.Encrypt intrinsics set tgtPrefUse for op1 and mark op2 as delay free
  • Loading branch information
tannergooding authored Oct 16, 2019
1 parent 9020ffa commit 643fda7
Show file tree
Hide file tree
Showing 88 changed files with 9,212 additions and 1,065 deletions.
5 changes: 5 additions & 0 deletions src/jit/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2352,6 +2352,11 @@ void Compiler::compSetProcessor()
opts.setSupportedISA(InstructionSet_Vector128);
}

if (jitFlags.IsSet(JitFlags::JIT_FLAG_HAS_ARM64_AES) && JitConfig.EnableArm64Aes())
{
opts.setSupportedISA(InstructionSet_Aes);
}

if (jitFlags.IsSet(JitFlags::JIT_FLAG_HAS_ARM64_ATOMICS) && JitConfig.EnableArm64Atomics())
{
opts.setSupportedISA(InstructionSet_Atomics);
Expand Down
149 changes: 115 additions & 34 deletions src/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
// genIsTableDrivenHWIntrinsic: decide whether an intrinsic is emitted by the generic
// numArgs-driven path in CodeGen::genHWIntrinsic (true) or must be routed to
// CodeGen::genSpecialIntrinsic (false).
//
// Arguments:
//    intrinsicId - the intrinsic being generated
//    category    - the category already looked up for intrinsicId by the caller
//
// Return Value:
//    true only when both the category check and the per-intrinsic flag check pass.
//
// NOTE(review): this listing appears to be a rendered diff without +/- markers, so
// tableDrivenCategory and tableDrivenFlag are each declared twice below (presumably the
// removed and the added version of each) - confirm against the actual file.
static bool genIsTableDrivenHWIntrinsic(NamedIntrinsic intrinsicId, HWIntrinsicCategory category)
{
// TODO-Arm64-Cleanup - move more categories to the table-driven framework
// First form: only the Scalar and Helper categories are excluded, and the flag is always true.
const bool tableDrivenCategory = (category != HW_Category_Scalar) && (category != HW_Category_Helper);
const bool tableDrivenFlag = true;
// Second form: the Special category is excluded as well.
const bool tableDrivenCategory =
(category != HW_Category_Special) && (category != HW_Category_Scalar) && (category != HW_Category_Helper);
// Second form: intrinsics that generate multiple instructions or that request special
// codegen are also excluded, regardless of category.
const bool tableDrivenFlag =
!HWIntrinsicInfo::GeneratesMultipleIns(intrinsicId) && !HWIntrinsicInfo::HasSpecialCodegen(intrinsicId);
return tableDrivenCategory && tableDrivenFlag;
}

Expand All @@ -42,105 +44,184 @@ static bool genIsTableDrivenHWIntrinsic(NamedIntrinsic intrinsicId, HWIntrinsicC
// genHWIntrinsic: emit code for a hardware intrinsic node.
//
// Intrinsics accepted by genIsTableDrivenHWIntrinsic are emitted here by switching on the
// number of arguments (1, 2 or 3); every other intrinsic is handed to genSpecialIntrinsic.
//
// Arguments:
//    node - the GenTreeHWIntrinsic node to generate code for
//
// NOTE(review): this listing appears to be a rendered diff without +/- markers. Many
// statements therefore occur twice in slightly different forms (e.g. simdSize vs. emitSize,
// GetEmitter()-> vs. emit->, genConsumeRegs per operand vs. one genConsumeOperands call),
// presumably the removed and the added side of the change - confirm against the actual file
// before relying on any single line.
void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
{
NamedIntrinsic intrinsicId = node->gtHWIntrinsicId;
InstructionSet isa = HWIntrinsicInfo::lookupIsa(intrinsicId);
HWIntrinsicCategory category = HWIntrinsicInfo::lookupCategory(intrinsicId);
int ival = HWIntrinsicInfo::lookupIval(intrinsicId);
int numArgs = HWIntrinsicInfo::lookupNumArgs(node);

assert(HWIntrinsicInfo::RequiresCodegen(intrinsicId));

if (genIsTableDrivenHWIntrinsic(intrinsicId, category))
{
// Same lookups as above, repeated inside the table-driven branch.
InstructionSet isa = HWIntrinsicInfo::lookupIsa(intrinsicId);
int ival = HWIntrinsicInfo::lookupIval(intrinsicId);
int numArgs = HWIntrinsicInfo::lookupNumArgs(node);

assert(numArgs >= 0);

GenTree* op1 = node->gtGetOp1();
GenTree* op2 = node->gtGetOp2();
regNumber targetReg = node->GetRegNum();
var_types targetType = node->TypeGet();
var_types baseType = node->gtSIMDBaseType;

regNumber op1Reg = REG_NA;
regNumber op2Reg = REG_NA;
emitter* emit = GetEmitter();

assert(numArgs >= 0);
// The instruction is looked up from the intrinsic id and the node's SIMD base type.
instruction ins = HWIntrinsicInfo::lookupIns(intrinsicId, baseType);
assert(ins != INS_invalid);
emitAttr simdSize = EA_ATTR(node->gtSIMDSize);
insOpts opt = INS_OPTS_NONE;

regNumber op1Reg = REG_NA;
regNumber op2Reg = REG_NA;
emitter* emit = GetEmitter();
emitAttr emitSize = EA_ATTR(node->gtSIMDSize);
insOpts opt = INS_OPTS_NONE;

// SIMDScalar intrinsics size the instruction from the base type and keep INS_OPTS_NONE;
// all other categories keep the node's SIMD size and take their insOpts from
// genGetSimdInsOpt.
if (category == HW_Category_SIMDScalar)
{
simdSize = emitActualTypeSize(baseType);
emitSize = emitActualTypeSize(baseType);
}
else
{
opt = genGetSimdInsOpt(simdSize, baseType);
opt = genGetSimdInsOpt(emitSize, baseType);
}

assert(simdSize != 0);
assert(emitSize != 0);
genConsumeOperands(node);

switch (numArgs)
{
// One operand: ins targetReg, op1Reg.
case 1:
{
genConsumeRegs(op1);
assert(op1 != nullptr);
assert(op2 == nullptr);

op1Reg = op1->GetRegNum();
GetEmitter()->emitIns_R_R(ins, simdSize, targetReg, op1Reg, opt);
emit->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt);
break;
}

// Two operands: ins targetReg, op1Reg, op2Reg.
case 2:
{
genConsumeRegs(op1);
genConsumeRegs(op2);
assert(op1 != nullptr);
assert(op2 != nullptr);

op1Reg = op1->GetRegNum();
op2Reg = op2->GetRegNum();

GetEmitter()->emitIns_R_R_R(ins, simdSize, targetReg, op1Reg, op2Reg, opt);
emit->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
break;
}

// Three operands: they arrive as an argument list reached through op1 (op2 is null),
// and are unpacked one at a time with Current()/Rest().
case 3:
{
assert(op1 != nullptr);
assert(op2 == nullptr);

GenTreeArgList* argList = op1->AsArgList();
op1 = argList->Current();
genConsumeRegs(op1);
op1Reg = op1->GetRegNum();
op1Reg = op1->GetRegNum();

argList = argList->Rest();
op2 = argList->Current();
genConsumeRegs(op2);
op2Reg = op2->GetRegNum();
op2Reg = op2->GetRegNum();

argList = argList->Rest();
GenTree* op3 = argList->Current();
genConsumeRegs(op3);
argList = argList->Rest();
GenTree* op3 = argList->Current();
regNumber op3Reg = op3->GetRegNum();

// The first operand is copied into the target register when they differ, and the
// instruction is then emitted as ins targetReg, op2Reg, op3Reg. The mov is emitted
// without insOpts in both forms; only the second form passes opt to ins.
if (targetReg != op1Reg)
{
GetEmitter()->emitIns_R_R(INS_mov, simdSize, targetReg, op1Reg);
emit->emitIns_R_R(INS_mov, emitSize, targetReg, op1Reg);
}
GetEmitter()->emitIns_R_R_R(ins, simdSize, targetReg, op2Reg, op3Reg);
emit->emitIns_R_R_R(ins, emitSize, targetReg, op2Reg, op3Reg, opt);
break;
}

default:
{
unreached();
break;
}
}
genProduceReg(node);
return;
}

// Not table-driven: delegate to genSpecialIntrinsic. The call appears both unconditionally
// and inside an else branch here (see the NOTE(review) in the function header).
genSpecialIntrinsic(node);
else
{
genSpecialIntrinsic(node);
}
}

// genSpecialIntrinsic: emit code for hardware intrinsics that genHWIntrinsic does not handle
// through its table-driven path.
//
// The switch below handles NI_Aes_Decrypt / NI_Aes_Encrypt and the ArmBase
// LeadingZeroCount / Arm64 LeadingSignCount / Arm64 LeadingZeroCount intrinsics; any other
// intrinsic id reaches unreached().
//
// Arguments:
//    node - the GenTreeHWIntrinsic node to generate code for
//
// NOTE(review): this listing appears to be a rendered diff without +/- markers. The
// unreached() on the first line of the body is presumably the removed previous body of this
// function rather than live code - confirm against the actual file.
void CodeGen::genSpecialIntrinsic(GenTreeHWIntrinsic* node)
{
unreached();
NamedIntrinsic intrinsicId = node->gtHWIntrinsicId;
HWIntrinsicCategory category = HWIntrinsicInfo::lookupCategory(intrinsicId);

assert(HWIntrinsicInfo::RequiresCodegen(intrinsicId));

// NOTE(review): isa, ival and targetType are not referenced again in the visible body.
InstructionSet isa = HWIntrinsicInfo::lookupIsa(intrinsicId);
int ival = HWIntrinsicInfo::lookupIval(intrinsicId);
int numArgs = HWIntrinsicInfo::lookupNumArgs(node);

assert(numArgs >= 0);

GenTree* op1 = node->gtGetOp1();
GenTree* op2 = node->gtGetOp2();
regNumber targetReg = node->GetRegNum();
var_types targetType = node->TypeGet();
// Scalar intrinsics take their base type from the first operand's type instead of the
// node's SIMD base type.
// NOTE(review): op1 is dereferenced here before the per-case (op1 != nullptr) asserts run.
var_types baseType = (category == HW_Category_Scalar) ? op1->TypeGet() : node->gtSIMDBaseType;

instruction ins = HWIntrinsicInfo::lookupIns(intrinsicId, baseType);
assert(ins != INS_invalid);

regNumber op1Reg = REG_NA;
regNumber op2Reg = REG_NA;
emitter* emit = GetEmitter();
emitAttr emitSize = EA_ATTR(node->gtSIMDSize);
insOpts opt = INS_OPTS_NONE;

// SIMDScalar and Scalar intrinsics size the instruction from the base type and keep
// INS_OPTS_NONE; all other categories keep the node's SIMD size and take their insOpts
// from genGetSimdInsOpt.
if ((category == HW_Category_SIMDScalar) || (category == HW_Category_Scalar))
{
emitSize = emitActualTypeSize(baseType);
}
else
{
opt = genGetSimdInsOpt(emitSize, baseType);
}

genConsumeOperands(node);

switch (intrinsicId)
{
// Two operands, emitted as a two-register instruction on (targetReg, op2Reg): op1 is first
// copied into targetReg when it is not already there. The mov is emitted without insOpts.
// NOTE(review): if targetReg could equal op2Reg, the mov would overwrite op2 before it is
// read; the BuildHWIntrinsic change shown for these two intrinsics (tgtPrefUse for op1,
// delay-free uses for op2) presumably rules that out - confirm.
case NI_Aes_Decrypt:
case NI_Aes_Encrypt:
{
assert(op1 != nullptr);
assert(op2 != nullptr);

op1Reg = op1->GetRegNum();
op2Reg = op2->GetRegNum();

if (op1Reg != targetReg)
{
emit->emitIns_R_R(INS_mov, emitSize, targetReg, op1Reg);
}
emit->emitIns_R_R(ins, emitSize, targetReg, op2Reg, opt);
break;
}

// Single operand: ins targetReg, op1Reg, emitted without insOpts.
case NI_ArmBase_LeadingZeroCount:
case NI_ArmBase_Arm64_LeadingSignCount:
case NI_ArmBase_Arm64_LeadingZeroCount:
{
assert(op1 != nullptr);
assert(op2 == nullptr);

op1Reg = op1->GetRegNum();
emit->emitIns_R_R(ins, emitSize, targetReg, op1Reg);
break;
}

default:
{
unreached();
}
}

genProduceReg(node);
}

#endif // FEATURE_HW_INTRINSICS
4 changes: 2 additions & 2 deletions src/jit/hwintrinsiclistarm64.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ HARDWARE_INTRINSIC(AdvSimd_Arm64, Add, -
// {TYP_BYTE, TYP_UBYTE, TYP_SHORT, TYP_USHORT, TYP_INT, TYP_UINT, TYP_LONG, TYP_ULONG, TYP_FLOAT, TYP_DOUBLE}
// ***************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************
// AES Intrinsics
HARDWARE_INTRINSIC(Aes, Decrypt, -1, 16, 2, {INS_invalid, INS_aesd, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SimpleSIMD, HW_Flag_NoContainment)
HARDWARE_INTRINSIC(Aes, Encrypt, -1, 16, 2, {INS_invalid, INS_aese, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SimpleSIMD, HW_Flag_NoContainment)
HARDWARE_INTRINSIC(Aes, Decrypt, -1, 16, 2, {INS_invalid, INS_aesd, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SimpleSIMD, HW_Flag_NoContainment|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Aes, Encrypt, -1, 16, 2, {INS_invalid, INS_aese, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SimpleSIMD, HW_Flag_NoContainment|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Aes, InverseMixColumns, -1, 16, 1, {INS_invalid, INS_aesimc, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SimpleSIMD, HW_Flag_NoContainment)
HARDWARE_INTRINSIC(Aes, MixColumns, -1, 16, 1, {INS_invalid, INS_aesmc, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SimpleSIMD, HW_Flag_NoContainment)

Expand Down
14 changes: 14 additions & 0 deletions src/jit/lsraarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,20 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree)
// must be handled within the case.
switch (intrinsicId)
{
case NI_Aes_Decrypt:
case NI_Aes_Encrypt:
{
assert((numArgs == 2) && (op1 != nullptr) && (op2 != nullptr));

buildUses = false;

tgtPrefUse = BuildUse(op1);
srcCount += 1;
srcCount += BuildDelayFreeUses(op2);

break;
}

case NI_Sha1_HashUpdateChoose:
case NI_Sha1_HashUpdateMajority:
case NI_Sha1_HashUpdateParity:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ namespace JIT.HardwareIntrinsics.Arm
{
public static partial class Program
{
private static void AbsDouble()
private static void Abs_Vector128_Double()
{
var test = new SimpleUnaryOpTest__AbsDouble();
var test = new SimpleUnaryOpTest__Abs_Vector128_Double();

if (test.IsSupported)
{
Expand Down Expand Up @@ -110,7 +110,7 @@ private static void AbsDouble()
}
}

public sealed unsafe class SimpleUnaryOpTest__AbsDouble
public sealed unsafe class SimpleUnaryOpTest__Abs_Vector128_Double
{
private struct DataTable
{
Expand Down Expand Up @@ -171,15 +171,15 @@ public static TestStruct Create()
return testStruct;
}

public void RunStructFldScenario(SimpleUnaryOpTest__AbsDouble testClass)
public void RunStructFldScenario(SimpleUnaryOpTest__Abs_Vector128_Double testClass)
{
var result = AdvSimd.Arm64.Abs(_fld1);

Unsafe.Write(testClass._dataTable.outArrayPtr, result);
testClass.ValidateResult(_fld1, testClass._dataTable.outArrayPtr);
}

public void RunStructFldScenario_Load(SimpleUnaryOpTest__AbsDouble testClass)
public void RunStructFldScenario_Load(SimpleUnaryOpTest__Abs_Vector128_Double testClass)
{
fixed (Vector128<Double>* pFld1 = &_fld1)
{
Expand All @@ -206,13 +206,13 @@ public void RunStructFldScenario_Load(SimpleUnaryOpTest__AbsDouble testClass)

private DataTable _dataTable;

static SimpleUnaryOpTest__AbsDouble()
static SimpleUnaryOpTest__Abs_Vector128_Double()
{
for (var i = 0; i < Op1ElementCount; i++) { _data1[i] = -TestLibrary.Generator.GetDouble(); }
Unsafe.CopyBlockUnaligned(ref Unsafe.As<Vector128<Double>, byte>(ref _clsVar1), ref Unsafe.As<Double, byte>(ref _data1[0]), (uint)Unsafe.SizeOf<Vector128<Double>>());
}

public SimpleUnaryOpTest__AbsDouble()
public SimpleUnaryOpTest__Abs_Vector128_Double()
{
Succeeded = true;

Expand Down Expand Up @@ -330,7 +330,7 @@ public void RunClassLclFldScenario()
{
TestLibrary.TestFramework.BeginScenario(nameof(RunClassLclFldScenario));

var test = new SimpleUnaryOpTest__AbsDouble();
var test = new SimpleUnaryOpTest__Abs_Vector128_Double();
var result = AdvSimd.Arm64.Abs(test._fld1);

Unsafe.Write(_dataTable.outArrayPtr, result);
Expand All @@ -341,7 +341,7 @@ public void RunClassLclFldScenario_Load()
{
TestLibrary.TestFramework.BeginScenario(nameof(RunClassLclFldScenario_Load));

var test = new SimpleUnaryOpTest__AbsDouble();
var test = new SimpleUnaryOpTest__Abs_Vector128_Double();

fixed (Vector128<Double>* pFld1 = &test._fld1)
{
Expand Down
Loading

0 comments on commit 643fda7

Please sign in to comment.