Skip to content
This repository has been archived by the owner on Nov 1, 2020. It is now read-only.

Wasm: bring add and sub overflow operations in line with mul #8284

Merged
merged 9 commits into from
Aug 27, 2020
215 changes: 81 additions & 134 deletions src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2782,15 +2782,15 @@ private bool ImportIntrinsicCall(MethodDesc method, MethodDesc runtimeDetermined
// then
builder.PositionAtEnd(notFatBranch);
ExceptionRegion currentTryRegion = GetCurrentTryRegion();
LLVMValueRef notFatReturn = CallOrInvoke(fromLandingPad, builder, currentTryRegion, fn, llvmArgs, ref nextInstrBlock);
LLVMValueRef notFatReturn = CallOrInvoke(fromLandingPad, builder, currentTryRegion, fn, llvmArgs.ToArray(), ref nextInstrBlock);
builder.BuildBr(endifBlock);

// else
builder.PositionAtEnd(fatBranch);
var fnWithDict = builder.BuildCast(LLVMOpcode.LLVMBitCast, fn, LLVMTypeRef.CreatePointer(GetLLVMSignatureForMethod(runtimeDeterminedMethod.Signature, true), 0), "fnWithDict");
var dictDereffed = builder.BuildLoad(builder.BuildLoad( dict, "l1"), "l2");
llvmArgs.Insert(needsReturnSlot ? 2 : 1, dictDereffed);
LLVMValueRef fatReturn = CallOrInvoke(fromLandingPad, builder, currentTryRegion, fnWithDict, llvmArgs, ref nextInstrBlock);
LLVMValueRef fatReturn = CallOrInvoke(fromLandingPad, builder, currentTryRegion, fnWithDict, llvmArgs.ToArray(), ref nextInstrBlock);
builder.BuildBr(endifBlock);

// endif
Expand All @@ -2806,7 +2806,7 @@ private bool ImportIntrinsicCall(MethodDesc method, MethodDesc runtimeDetermined
}
else
{
llvmReturn = CallOrInvoke(fromLandingPad, builder, GetCurrentTryRegion(), fn, llvmArgs, ref nextInstrBlock);
llvmReturn = CallOrInvoke(fromLandingPad, builder, GetCurrentTryRegion(), fn, llvmArgs.ToArray(), ref nextInstrBlock);
}

if (!returnType.IsVoid)
Expand All @@ -2829,23 +2829,21 @@ private bool ImportIntrinsicCall(MethodDesc method, MethodDesc runtimeDetermined
}

LLVMValueRef CallOrInvoke(bool fromLandingPad, LLVMBuilderRef builder, ExceptionRegion currentTryRegion,
LLVMValueRef fn, List<LLVMValueRef> llvmArgs, ref LLVMBasicBlockRef nextInstrBlock)
LLVMValueRef fn, LLVMValueRef[] llvmArgs, ref LLVMBasicBlockRef nextInstrBlock)
{
LLVMValueRef retVal;
if (currentTryRegion == null || fromLandingPad) // not handling exceptions that occur in the LLVM landing pad determining the EH handler
{
retVal = builder.BuildCall(fn, llvmArgs.ToArray(), string.Empty);
retVal = builder.BuildCall(fn, llvmArgs, string.Empty);
}
else
{
nextInstrBlock = _currentFunclet.AppendBasicBlock(String.Format("Try{0:X}", _currentOffset));

retVal = builder.BuildInvoke(fn, llvmArgs.ToArray(),
retVal = builder.BuildInvoke(fn, llvmArgs,
nextInstrBlock, GetOrCreateLandingPad(currentTryRegion), string.Empty);

_curBasicBlock = nextInstrBlock;
_currentBasicBlock.LLVMBlocks.Add(_curBasicBlock);
_currentBasicBlock.LastInternalBlock = _curBasicBlock;
AddInternalBasicBlock(nextInstrBlock);
builder.PositionAtEnd(_curBasicBlock);
}
return retVal;
Expand Down Expand Up @@ -3854,58 +3852,66 @@ private void ImportBinaryOperation(ILOpcode opcode)
Debug.Assert(CanPerformSignedOverflowOperations(op1.Kind));
if (Is32BitStackValue(op1.Kind))
{
BuildAddOverflowChecksForSize(ref AddOvf32Function, left, right, LLVMTypeRef.Int32, BuildConstInt32(int.MaxValue), BuildConstInt32(int.MinValue), true);
result = BuildArithmeticOperationWithOverflowCheck(left, right, "sadd", LLVMTypeRef.Int32);
}
else
{
BuildAddOverflowChecksForSize(ref AddOvf64Function, left, right, LLVMTypeRef.Int64, BuildConstInt64(long.MaxValue), BuildConstInt64(long.MinValue), true);
result = BuildArithmeticOperationWithOverflowCheck(left, right, "sadd", LLVMTypeRef.Int64);
}

result = _builder.BuildAdd(left, right, "add");
break;
case ILOpcode.add_ovf_un:
Debug.Assert(CanPerformUnsignedOverflowOperations(op1.Kind));
if (Is32BitStackValue(op1.Kind))
{
BuildAddOverflowChecksForSize(ref AddOvfUn32Function, left, right, LLVMTypeRef.Int32, BuildConstUInt32(uint.MaxValue), BuildConstInt32(0), false);
result = BuildArithmeticOperationWithOverflowCheck(left, right, "uadd", LLVMTypeRef.Int32);
}
else
{
BuildAddOverflowChecksForSize(ref AddOvfUn64Function, left, right, LLVMTypeRef.Int64, BuildConstUInt64(ulong.MaxValue), BuildConstInt64(0), false);
result = BuildArithmeticOperationWithOverflowCheck(left, right, "uadd", LLVMTypeRef.Int64);
}

result = _builder.BuildAdd(left, right, "add");
break;
case ILOpcode.sub_ovf:
Debug.Assert(CanPerformSignedOverflowOperations(op1.Kind));
if (Is32BitStackValue(op1.Kind))
{
BuildSubOverflowChecksForSize(ref SubOvf32Function, left, right, LLVMTypeRef.Int32, BuildConstInt32(int.MaxValue), BuildConstInt32(int.MinValue), true);
result = BuildArithmeticOperationWithOverflowCheck(left, right, "ssub", LLVMTypeRef.Int32);
}
else
{
BuildSubOverflowChecksForSize(ref SubOvf64Function, left, right, LLVMTypeRef.Int64, BuildConstInt64(long.MaxValue), BuildConstInt64(long.MinValue), true);
result = BuildArithmeticOperationWithOverflowCheck(left, right, "ssub", LLVMTypeRef.Int64);
}

result = _builder.BuildSub(left, right, "sub");
break;
case ILOpcode.sub_ovf_un:
Debug.Assert(CanPerformUnsignedOverflowOperations(op1.Kind));
if (Is32BitStackValue(op1.Kind))
{
BuildSubOverflowChecksForSize(ref SubOvfUn32Function, left, right, LLVMTypeRef.Int32, BuildConstUInt32(uint.MaxValue), BuildConstInt32(0), false);
result = BuildArithmeticOperationWithOverflowCheck(left, right, "usub", LLVMTypeRef.Int32);
}
else
{
BuildSubOverflowChecksForSize(ref SubOvfUn64Function, left, right, LLVMTypeRef.Int64, BuildConstUInt64(ulong.MaxValue), BuildConstInt64(0), false);
result = BuildArithmeticOperationWithOverflowCheck(left, right, "usub", LLVMTypeRef.Int64);
}

result = _builder.BuildSub(left, right, "sub");
break;
// TODO: Overflow checks
case ILOpcode.mul_ovf:
case ILOpcode.mul_ovf_un:
result = _builder.BuildMul(left, right, "mul");
Debug.Assert(CanPerformUnsignedOverflowOperations(op1.Kind));
if (Is32BitStackValue(op1.Kind))
{
result = BuildArithmeticOperationWithOverflowCheck(left, right, "umul", LLVMTypeRef.Int32);
}
else
{
result = BuildArithmeticOperationWithOverflowCheck(left, right, "umul", LLVMTypeRef.Int64);
}
break;
case ILOpcode.mul_ovf:
if (Is32BitStackValue(op1.Kind))
{
result = BuildArithmeticOperationWithOverflowCheck(left, right, "smul", LLVMTypeRef.Int32);
}
else
{
result = BuildArithmeticOperationWithOverflowCheck(left, right, "smul", LLVMTypeRef.Int64);
}
break;

default:
Expand All @@ -3922,6 +3928,32 @@ private void ImportBinaryOperation(ILOpcode opcode)
PushExpression(kind, "binop", result, type);
}

LLVMValueRef BuildArithmeticOperationWithOverflowCheck(LLVMValueRef left, LLVMValueRef right, string mulOp, LLVMTypeRef intType)
{
LLVMValueRef mulFunction = GetOrCreateLLVMFunction("llvm." + mulOp + ".with.overflow." + (intType == LLVMTypeRef.Int32 ? "i32" : "i64"), LLVMTypeRef.CreateFunction(
LLVMTypeRef.CreateStruct(new[] { intType, LLVMTypeRef.Int1}, false), new[] { intType, intType }));
LLVMValueRef mulRes = _builder.BuildCall(mulFunction, new[] {left, right});
var overflow = _builder.BuildExtractValue(mulRes, 1);
LLVMBasicBlockRef overflowBlock = _currentFunclet.AppendBasicBlock("ovf");
LLVMBasicBlockRef noOverflowBlock = _currentFunclet.AppendBasicBlock("no_ovf");
_builder.BuildCondBr(overflow, overflowBlock, noOverflowBlock);

_builder.PositionAtEnd(overflowBlock);
CallOrInvokeThrowException(_builder, "ThrowHelpers", "ThrowOverflowException");

_builder.PositionAtEnd(noOverflowBlock);
LLVMValueRef result = _builder.BuildExtractValue(mulRes, 0);
AddInternalBasicBlock(noOverflowBlock);
return result;
}

void AddInternalBasicBlock(LLVMBasicBlockRef basicBlock)
{
_curBasicBlock = basicBlock;
_currentBasicBlock.LLVMBlocks.Add(_curBasicBlock);
_currentBasicBlock.LastInternalBlock = _curBasicBlock;
}

bool CanPerformSignedOverflowOperations(StackValueKind kind)
{
return kind == StackValueKind.Int32 || kind == StackValueKind.Int64;
Expand All @@ -3938,108 +3970,6 @@ bool Is32BitStackValue(StackValueKind kind)
return kind == StackValueKind.Int32 || kind == StackValueKind.ByRef || kind == StackValueKind.ObjRef || kind == StackValueKind.NativeInt;
}

LLVMValueRef StartOverflowCheckFunction(LLVMTypeRef sizeTypeRef, bool signed,
string throwFuncName, out LLVMValueRef leftOp, out LLVMValueRef rightOp, out LLVMBuilderRef builder, out LLVMBasicBlockRef elseBlock,
out LLVMBasicBlockRef ovfBlock, out LLVMBasicBlockRef noOvfBlock)
{
LLVMValueRef llvmCheckFunction = Module.AddFunction(throwFuncName,
LLVMTypeRef.CreateFunction(LLVMTypeRef.Void,
new LLVMTypeRef[] { LLVMTypeRef.CreatePointer(LLVMTypeRef.Int8, 0), sizeTypeRef, sizeTypeRef }, false));
leftOp = llvmCheckFunction.GetParam(1);
rightOp = llvmCheckFunction.GetParam(2);
builder = Context.CreateBuilder();
var block = llvmCheckFunction.AppendBasicBlock("Block");
builder.PositionAtEnd(block);
elseBlock = default;
if (signed) // signed ops need a separate test for the right side being negative
{
var gtZeroCmp = builder.BuildICmp(LLVMIntPredicate.LLVMIntSGT, rightOp,
LLVMValueRef.CreateConstInt(sizeTypeRef, 0, false));
LLVMBasicBlockRef thenBlock = llvmCheckFunction.AppendBasicBlock("posOvfBlock");
elseBlock = llvmCheckFunction.AppendBasicBlock("negOvfBlock");
builder.BuildCondBr(gtZeroCmp, thenBlock, elseBlock);
builder.PositionAtEnd(thenBlock);
}
ovfBlock = llvmCheckFunction.AppendBasicBlock("ovfBlock");
noOvfBlock = llvmCheckFunction.AppendBasicBlock("noOvfBlock");
return llvmCheckFunction;
}

void BuildAddOverflowChecksForSize(ref LLVMValueRef llvmCheckFunction, LLVMValueRef left, LLVMValueRef right, LLVMTypeRef sizeTypeRef, LLVMValueRef maxValue, LLVMValueRef minValue, bool signed)
{
if (llvmCheckFunction.Handle == IntPtr.Zero)
{
// create function name for each of the 4 combinations signed/unsigned, 32/64 bit
string throwFuncName = "corert.throwovf" + (signed ? "add" : "unadd") + (sizeTypeRef.IntWidth == 32 ? "32" : "64");
llvmCheckFunction = StartOverflowCheckFunction(sizeTypeRef, signed, throwFuncName, out LLVMValueRef leftOp, out LLVMValueRef rightOp, out LLVMBuilderRef builder, out LLVMBasicBlockRef elseBlock,
out LLVMBasicBlockRef ovfBlock, out LLVMBasicBlockRef noOvfBlock);
// a > int.MaxValue - b
BuildOverflowCheck(builder, leftOp, signed ? LLVMIntPredicate.LLVMIntSGT : LLVMIntPredicate.LLVMIntUGT, maxValue, rightOp, ovfBlock, noOvfBlock, LLVMOpcode.LLVMSub);

builder.PositionAtEnd(ovfBlock);

ThrowException(builder, "ThrowHelpers", "ThrowOverflowException", llvmCheckFunction);

builder.PositionAtEnd(noOvfBlock);
LLVMBasicBlockRef opBlock = llvmCheckFunction.AppendBasicBlock("opBlock");
builder.BuildBr(opBlock);

if (signed)
{
builder.PositionAtEnd(elseBlock);
// a < int.MinValue - b
BuildOverflowCheck(builder, leftOp, LLVMIntPredicate.LLVMIntSLT, minValue, rightOp, ovfBlock, noOvfBlock, LLVMOpcode.LLVMSub);
}
builder.PositionAtEnd(opBlock);
builder.BuildRetVoid();
}

LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new List<LLVMValueRef> { GetShadowStack(), left, right }, ref nextInstrBlock);
}

void BuildSubOverflowChecksForSize(ref LLVMValueRef llvmCheckFunction, LLVMValueRef left, LLVMValueRef right, LLVMTypeRef sizeTypeRef, LLVMValueRef maxValue, LLVMValueRef minValue, bool signed)
{
if (llvmCheckFunction.Handle == IntPtr.Zero)
{
// create function name for each of the 4 combinations signed/unsigned, 32/64 bit
string throwFuncName = "corert.throwovf" + (signed ? "sub" : "unsub") + (sizeTypeRef.IntWidth == 32 ? "32" : "64");
llvmCheckFunction = StartOverflowCheckFunction(sizeTypeRef, signed, throwFuncName, out LLVMValueRef leftOp, out LLVMValueRef rightOp, out LLVMBuilderRef builder, out LLVMBasicBlockRef elseBlock,
out LLVMBasicBlockRef ovfBlock, out LLVMBasicBlockRef noOvfBlock);
// a < Min + b is overflow for unsigned
BuildOverflowCheck(builder, leftOp, signed ? LLVMIntPredicate.LLVMIntSLT : LLVMIntPredicate.LLVMIntULT, minValue, rightOp, ovfBlock, noOvfBlock, LLVMOpcode.LLVMAdd);

builder.PositionAtEnd(ovfBlock);

ThrowException(builder, "ThrowHelpers", "ThrowOverflowException", llvmCheckFunction);

builder.PositionAtEnd(noOvfBlock);
LLVMBasicBlockRef opBlock = llvmCheckFunction.AppendBasicBlock("opBlock");
builder.BuildBr(opBlock);

if (signed)
{
builder.PositionAtEnd(elseBlock);
// a - b overflows when b is negative if a > max + b
BuildOverflowCheck(builder, rightOp, LLVMIntPredicate.LLVMIntSGT, maxValue, leftOp, ovfBlock, noOvfBlock, LLVMOpcode.LLVMAdd);
}
builder.PositionAtEnd(opBlock);
builder.BuildRetVoid();
}

LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new List<LLVMValueRef> { GetShadowStack(), left, right }, ref nextInstrBlock);
}

private void BuildOverflowCheck(LLVMBuilderRef builder, LLVMValueRef compOperand, LLVMIntPredicate predicate,
LLVMValueRef limitValueRef, LLVMValueRef limitOperand, LLVMBasicBlockRef ovfBlock,
LLVMBasicBlockRef noOvfBlock, LLVMOpcode opCode)
{
LLVMValueRef sub = builder.BuildBinOp(opCode, limitValueRef, limitOperand);
LLVMValueRef ovfTest = builder.BuildICmp(predicate, compOperand, sub);
builder.BuildCondBr(ovfTest, ovfBlock, noOvfBlock);
}

private TypeDesc WidenBytesAndShorts(TypeDesc type)
{
switch (type.Category)
Expand Down Expand Up @@ -4542,7 +4472,7 @@ private void ThrowIfNull(LLVMValueRef entry)
}

LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, _builder, GetCurrentTryRegion(), NullRefFunction, new List<LLVMValueRef> { GetShadowStack(), entry }, ref nextInstrBlock);
CallOrInvoke(false, _builder, GetCurrentTryRegion(), NullRefFunction, new LLVMValueRef[] { GetShadowStack(), entry }, ref nextInstrBlock);
}

private void ThrowCkFinite(LLVMValueRef value, int size, ref LLVMValueRef llvmCheckFunction)
Expand Down Expand Up @@ -4585,16 +4515,33 @@ private void ThrowCkFinite(LLVMValueRef value, int size, ref LLVMValueRef llvmCh
}

LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new List<LLVMValueRef> { GetShadowStack(), value }, ref nextInstrBlock);
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new LLVMValueRef[] { GetShadowStack(), value }, ref nextInstrBlock);
}

private void ThrowException(LLVMBuilderRef builder, string helperClass, string helperMethodName, LLVMValueRef throwingFunction)
{
LLVMValueRef fn = GetHelperLlvmMethod(helperClass, helperMethodName);
builder.BuildCall(fn, new LLVMValueRef[] {throwingFunction.GetParam(0) }, string.Empty);
builder.BuildUnreachable();
}

/// <summary>
/// Calls or invokes the call to throwing the exception so it can be caught in the caller
/// </summary>
private void CallOrInvokeThrowException(LLVMBuilderRef builder, string helperClass, string helperMethodName)
{
LLVMValueRef fn = GetHelperLlvmMethod(helperClass, helperMethodName);
LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, builder, GetCurrentTryRegion(), fn, new LLVMValueRef[] {GetShadowStack()}, ref nextInstrBlock);
builder.BuildUnreachable();
}

LLVMValueRef GetHelperLlvmMethod(string helperClass, string helperMethodName)
{
MetadataType helperType = _compilation.TypeSystemContext.SystemModule.GetKnownType("Internal.Runtime.CompilerHelpers", helperClass);
MethodDesc helperMethod = helperType.GetKnownMethod(helperMethodName, null);
LLVMValueRef fn = LLVMFunctionForMethod(helperMethod, helperMethod, null, false, null, null, out bool hasHiddenParam, out LLVMValueRef dictPtrPtrStore, out LLVMValueRef fatFunctionPtr);
builder.BuildCall(fn, new LLVMValueRef[] {throwingFunction.GetParam(0) }, string.Empty);
builder.BuildUnreachable();
return fn;
}

private LLVMValueRef GetInstanceFieldAddress(StackEntry objectEntry, FieldDesc field)
Expand Down
Loading