Skip to content
This repository has been archived by the owner on Nov 1, 2020. It is now read-only.

Wasm: add support for overflow checks on signed and unsigned ints multiply #8259

Merged
merged 7 commits into from
Aug 26, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 109 additions & 18 deletions src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2782,15 +2782,15 @@ private bool ImportIntrinsicCall(MethodDesc method, MethodDesc runtimeDetermined
// then
builder.PositionAtEnd(notFatBranch);
ExceptionRegion currentTryRegion = GetCurrentTryRegion();
LLVMValueRef notFatReturn = CallOrInvoke(fromLandingPad, builder, currentTryRegion, fn, llvmArgs, ref nextInstrBlock);
LLVMValueRef notFatReturn = CallOrInvoke(fromLandingPad, builder, currentTryRegion, fn, llvmArgs.ToArray(), ref nextInstrBlock);
builder.BuildBr(endifBlock);

// else
builder.PositionAtEnd(fatBranch);
var fnWithDict = builder.BuildCast(LLVMOpcode.LLVMBitCast, fn, LLVMTypeRef.CreatePointer(GetLLVMSignatureForMethod(runtimeDeterminedMethod.Signature, true), 0), "fnWithDict");
var dictDereffed = builder.BuildLoad(builder.BuildLoad( dict, "l1"), "l2");
llvmArgs.Insert(needsReturnSlot ? 2 : 1, dictDereffed);
LLVMValueRef fatReturn = CallOrInvoke(fromLandingPad, builder, currentTryRegion, fnWithDict, llvmArgs, ref nextInstrBlock);
LLVMValueRef fatReturn = CallOrInvoke(fromLandingPad, builder, currentTryRegion, fnWithDict, llvmArgs.ToArray(), ref nextInstrBlock);
builder.BuildBr(endifBlock);

// endif
Expand All @@ -2806,7 +2806,7 @@ private bool ImportIntrinsicCall(MethodDesc method, MethodDesc runtimeDetermined
}
else
{
llvmReturn = CallOrInvoke(fromLandingPad, builder, GetCurrentTryRegion(), fn, llvmArgs, ref nextInstrBlock);
llvmReturn = CallOrInvoke(fromLandingPad, builder, GetCurrentTryRegion(), fn, llvmArgs.ToArray(), ref nextInstrBlock);
}

if (!returnType.IsVoid)
Expand All @@ -2829,23 +2829,21 @@ private bool ImportIntrinsicCall(MethodDesc method, MethodDesc runtimeDetermined
}

LLVMValueRef CallOrInvoke(bool fromLandingPad, LLVMBuilderRef builder, ExceptionRegion currentTryRegion,
LLVMValueRef fn, List<LLVMValueRef> llvmArgs, ref LLVMBasicBlockRef nextInstrBlock)
LLVMValueRef fn, LLVMValueRef[] llvmArgs, ref LLVMBasicBlockRef nextInstrBlock)
{
LLVMValueRef retVal;
if (currentTryRegion == null || fromLandingPad) // not handling exceptions that occur in the LLVM landing pad determining the EH handler
{
retVal = builder.BuildCall(fn, llvmArgs.ToArray(), string.Empty);
retVal = builder.BuildCall(fn, llvmArgs, string.Empty);
}
else
{
nextInstrBlock = _currentFunclet.AppendBasicBlock(String.Format("Try{0:X}", _currentOffset));

retVal = builder.BuildInvoke(fn, llvmArgs.ToArray(),
retVal = builder.BuildInvoke(fn, llvmArgs,
nextInstrBlock, GetOrCreateLandingPad(currentTryRegion), string.Empty);

_curBasicBlock = nextInstrBlock;
_currentBasicBlock.LLVMBlocks.Add(_curBasicBlock);
_currentBasicBlock.LastInternalBlock = _curBasicBlock;
AddInternalBasicBlock(nextInstrBlock);
builder.PositionAtEnd(_curBasicBlock);
}
return retVal;
Expand Down Expand Up @@ -3902,10 +3900,26 @@ private void ImportBinaryOperation(ILOpcode opcode)

result = _builder.BuildSub(left, right, "sub");
break;
// TODO: Overflow checks
case ILOpcode.mul_ovf:
case ILOpcode.mul_ovf_un:
result = _builder.BuildMul(left, right, "mul");
Debug.Assert(CanPerformUnsignedOverflowOperations(op1.Kind));
if (Is32BitStackValue(op1.Kind))
{
result = BuildMulOverflowCheck(left, right, "umul", LLVMTypeRef.Int32);
}
else
{
result = BuildMulOverflowCheck(left, right, "umul", LLVMTypeRef.Int64);
}
break;
case ILOpcode.mul_ovf:
if (Is32BitStackValue(op1.Kind))
{
result = BuildMulOverflowCheck(left, right, "smul", LLVMTypeRef.Int32);
}
else
{
result = BuildMulOverflowCheck(left, right, "smul", LLVMTypeRef.Int64);
}
break;

default:
Expand All @@ -3922,6 +3936,32 @@ private void ImportBinaryOperation(ILOpcode opcode)
PushExpression(kind, "binop", result, type);
}

LLVMValueRef BuildMulOverflowCheck(LLVMValueRef left, LLVMValueRef right, string mulOp, LLVMTypeRef intType)
{
LLVMValueRef mulFunction = GetOrCreateLLVMFunction("llvm." + mulOp + ".with.overflow." + (intType == LLVMTypeRef.Int32 ? "i32" : "i64"), LLVMTypeRef.CreateFunction(
LLVMTypeRef.CreateStruct(new[] { intType, LLVMTypeRef.Int1}, false), new[] { intType, intType }));
LLVMValueRef mulRes = _builder.BuildCall(mulFunction, new[] {left, right});
var overflow = _builder.BuildExtractValue(mulRes, 1);
LLVMBasicBlockRef overflowBlock = _currentFunclet.AppendBasicBlock("ovf");
LLVMBasicBlockRef noOverflowBlock = _currentFunclet.AppendBasicBlock("no_ovf");
_builder.BuildCondBr(overflow, overflowBlock, noOverflowBlock);

_builder.PositionAtEnd(overflowBlock);
CallOrInvokeThrowException(_builder, "ThrowHelpers", "ThrowOverflowException");

_builder.PositionAtEnd(noOverflowBlock);
LLVMValueRef result = _builder.BuildExtractValue(mulRes, 0);
AddInternalBasicBlock(noOverflowBlock);
return result;
}

void AddInternalBasicBlock(LLVMBasicBlockRef basicBlock)
{
_curBasicBlock = basicBlock;
_currentBasicBlock.LLVMBlocks.Add(_curBasicBlock);
_currentBasicBlock.LastInternalBlock = _curBasicBlock;
}

bool CanPerformSignedOverflowOperations(StackValueKind kind)
{
return kind == StackValueKind.Int32 || kind == StackValueKind.Int64;
Expand Down Expand Up @@ -3995,7 +4035,7 @@ void BuildAddOverflowChecksForSize(ref LLVMValueRef llvmCheckFunction, LLVMValue
}

LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new List<LLVMValueRef> { GetShadowStack(), left, right }, ref nextInstrBlock);
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new LLVMValueRef[] { GetShadowStack(), left, right }, ref nextInstrBlock);
}

void BuildSubOverflowChecksForSize(ref LLVMValueRef llvmCheckFunction, LLVMValueRef left, LLVMValueRef right, LLVMTypeRef sizeTypeRef, LLVMValueRef maxValue, LLVMValueRef minValue, bool signed)
Expand Down Expand Up @@ -4028,7 +4068,41 @@ void BuildSubOverflowChecksForSize(ref LLVMValueRef llvmCheckFunction, LLVMValue
}

LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new List<LLVMValueRef> { GetShadowStack(), left, right }, ref nextInstrBlock);
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new LLVMValueRef[] { GetShadowStack(), left, right }, ref nextInstrBlock);
}

private LLVMValueRef CallLlvmAbs(LLVMBuilderRef builder, LLVMValueRef operand, LLVMTypeRef sizeTypeRef)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, removed

{
// TODO: There is an LLVM intrinsic for this but couldn't make it work (llvm.abs.i32 or llvm.abs.i64)
bool is32 = sizeTypeRef.IntWidth == 32;
LLVMValueRef llvmAbsFunction = is32 ? Abs32Function : Abs64Function;
if (llvmAbsFunction.Handle == IntPtr.Zero)
{
llvmAbsFunction = Module.AddFunction("corert.abs" + sizeTypeRef.IntWidth,
LLVMTypeRef.CreateFunction(sizeTypeRef, new LLVMTypeRef[] { sizeTypeRef }, false));
LLVMValueRef opParam = llvmAbsFunction.GetParam(0);
LLVMBuilderRef absBuilder = Context.CreateBuilder();
var block = llvmAbsFunction.AppendBasicBlock("Block");
absBuilder.PositionAtEnd(block);
LLVMBasicBlockRef thenBlock = llvmAbsFunction.AppendBasicBlock("negBlock");
LLVMBasicBlockRef elseBlock = llvmAbsFunction.AppendBasicBlock("posBlock");
var ltZeroCmp = absBuilder.BuildICmp(LLVMIntPredicate.LLVMIntSLT, opParam, LLVMValueRef.CreateConstInt(sizeTypeRef, 0));
absBuilder.BuildCondBr(ltZeroCmp, thenBlock, elseBlock);
absBuilder.PositionAtEnd(thenBlock);
var negate = absBuilder.BuildNeg(opParam);
absBuilder.BuildRet(negate);
absBuilder.PositionAtEnd(elseBlock);
absBuilder.BuildRet(opParam);
if (is32)
{
Abs32Function = llvmAbsFunction;
}
else
{
Abs64Function = llvmAbsFunction;
}
}
return builder.BuildCall(llvmAbsFunction, new[] {operand}, "abs");
}

private void BuildOverflowCheck(LLVMBuilderRef builder, LLVMValueRef compOperand, LLVMIntPredicate predicate,
Expand Down Expand Up @@ -4534,7 +4608,7 @@ private void ThrowIfNull(LLVMValueRef entry)
}

LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, _builder, GetCurrentTryRegion(), NullRefFunction, new List<LLVMValueRef> { GetShadowStack(), entry }, ref nextInstrBlock);
CallOrInvoke(false, _builder, GetCurrentTryRegion(), NullRefFunction, new LLVMValueRef[] { GetShadowStack(), entry }, ref nextInstrBlock);
}

private void ThrowCkFinite(LLVMValueRef value, int size, ref LLVMValueRef llvmCheckFunction)
Expand Down Expand Up @@ -4577,16 +4651,33 @@ private void ThrowCkFinite(LLVMValueRef value, int size, ref LLVMValueRef llvmCh
}

LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new List<LLVMValueRef> { GetShadowStack(), value }, ref nextInstrBlock);
CallOrInvoke(false, _builder, GetCurrentTryRegion(), llvmCheckFunction, new LLVMValueRef[] { GetShadowStack(), value }, ref nextInstrBlock);
}

private void ThrowException(LLVMBuilderRef builder, string helperClass, string helperMethodName, LLVMValueRef throwingFunction)
{
LLVMValueRef fn = GetHelperLlvmMethod(helperClass, helperMethodName);
builder.BuildCall(fn, new LLVMValueRef[] {throwingFunction.GetParam(0) }, string.Empty);
builder.BuildUnreachable();
}

/// <summary>
/// Calls or invokes the call to throwing the exception so it can be caught in the caller
/// </summary>
private void CallOrInvokeThrowException(LLVMBuilderRef builder, string helperClass, string helperMethodName)
{
LLVMValueRef fn = GetHelperLlvmMethod(helperClass, helperMethodName);
LLVMBasicBlockRef nextInstrBlock = default;
CallOrInvoke(false, builder, GetCurrentTryRegion(), fn, new LLVMValueRef[] {GetShadowStack()}, ref nextInstrBlock);
builder.BuildUnreachable();
}

LLVMValueRef GetHelperLlvmMethod(string helperClass, string helperMethodName)
{
MetadataType helperType = _compilation.TypeSystemContext.SystemModule.GetKnownType("Internal.Runtime.CompilerHelpers", helperClass);
MethodDesc helperMethod = helperType.GetKnownMethod(helperMethodName, null);
LLVMValueRef fn = LLVMFunctionForMethod(helperMethod, helperMethod, null, false, null, null, out bool hasHiddenParam, out LLVMValueRef dictPtrPtrStore, out LLVMValueRef fatFunctionPtr);
builder.BuildCall(fn, new LLVMValueRef[] {throwingFunction.GetParam(0) }, string.Empty);
builder.BuildUnreachable();
return fn;
}

private LLVMValueRef GetInstanceFieldAddress(StackEntry objectEntry, FieldDesc field)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ public static void CompileMethod(WebAssemblyCodegenCompilation compilation, WebA
static LLVMValueRef SubOvfUn32Function = default(LLVMValueRef);
static LLVMValueRef SubOvf64Function = default(LLVMValueRef);
static LLVMValueRef SubOvfUn64Function = default(LLVMValueRef);
static LLVMValueRef Abs32Function = default(LLVMValueRef);
static LLVMValueRef Abs64Function = default(LLVMValueRef);
public static LLVMValueRef GxxPersonality = default(LLVMValueRef);
public static LLVMTypeRef GxxPersonalityType = default(LLVMTypeRef);

Expand Down
Loading