Skip to content

Commit

Permalink
[mypyc] Fix signed integer comparison (#9163)
Browse files Browse the repository at this point in the history
358522e 
generates inline comparison between short ints, explicit conversion to signed is missing, 
though, causing negative cases to fail.

This PR adds explicit type casts (although the name truncate here is a little bit misleading).

This PR will fix microbenchmark `int_list`.
  • Loading branch information
TH3CHARLie authored Jul 20, 2020
1 parent 28829fb commit 4cf246f
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 38 deletions.
31 changes: 29 additions & 2 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
NAMESPACE_TYPE, NAMESPACE_MODULE, RaiseStandardError, CallC, LoadGlobal, Truncate,
BinaryIntOp
)
from mypyc.ir.rtypes import RType, RTuple
from mypyc.ir.rtypes import (
RType, RTuple, is_tagged, is_int32_rprimitive, is_int64_rprimitive
)
from mypyc.ir.func_ir import FuncIR, FuncDecl, FUNC_STATICMETHOD, FUNC_CLASSMETHOD
from mypyc.ir.class_ir import ClassIR

Expand Down Expand Up @@ -438,7 +440,18 @@ def visit_binary_int_op(self, op: BinaryIntOp) -> None:
dest = self.reg(op)
lhs = self.reg(op.lhs)
rhs = self.reg(op.rhs)
self.emit_line('%s = %s %s %s;' % (dest, lhs, op.op_str[op.op], rhs))
lhs_cast = ""
rhs_cast = ""
signed_op = {BinaryIntOp.SLT, BinaryIntOp.SGT, BinaryIntOp.SLE, BinaryIntOp.SGE}
unsigned_op = {BinaryIntOp.ULT, BinaryIntOp.UGT, BinaryIntOp.ULE, BinaryIntOp.UGE}
if op.op in signed_op:
lhs_cast = self.emit_signed_int_cast(op.lhs.type)
rhs_cast = self.emit_signed_int_cast(op.rhs.type)
elif op.op in unsigned_op:
lhs_cast = self.emit_unsigned_int_cast(op.lhs.type)
rhs_cast = self.emit_unsigned_int_cast(op.rhs.type)
self.emit_line('%s = %s%s %s %s%s;' % (dest, lhs_cast, lhs,
op.op_str[op.op], rhs_cast, rhs))

# Helpers

Expand Down Expand Up @@ -482,3 +495,17 @@ def emit_traceback(self, op: Branch) -> None:
globals_static))
if DEBUG_ERRORS:
self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')

def emit_signed_int_cast(self, type: RType) -> str:
if is_tagged(type):
return '(Py_ssize_t)'
else:
return ''

def emit_unsigned_int_cast(self, type: RType) -> str:
if is_int32_rprimitive(type):
return '(uint32_t)'
elif is_int64_rprimitive(type):
return '(uint64_t)'
else:
return ''
34 changes: 25 additions & 9 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,12 +1270,17 @@ class BinaryIntOp(RegisterOp):
DIV = 3 # type: Final
MOD = 4 # type: Final
# logical
# S for signed and U for unsigned
EQ = 100 # type: Final
NEQ = 101 # type: Final
LT = 102 # type: Final
GT = 103 # type: Final
LEQ = 104 # type: Final
GEQ = 105 # type: Final
SLT = 102 # type: Final
SGT = 103 # type: Final
SLE = 104 # type: Final
SGE = 105 # type: Final
ULT = 106 # type: Final
UGT = 107 # type: Final
ULE = 108 # type: Final
UGE = 109 # type: Final
# bitwise
AND = 200 # type: Final
OR = 201 # type: Final
Expand All @@ -1291,10 +1296,14 @@ class BinaryIntOp(RegisterOp):
MOD: '%',
EQ: '==',
NEQ: '!=',
LT: '<',
GT: '>',
LEQ: '<=',
GEQ: '>=',
SLT: '<',
SGT: '>',
SLE: '<=',
SGE: '>=',
ULT: '<',
UGT: '>',
ULE: '<=',
UGE: '>=',
AND: '&',
OR: '|',
XOR: '^',
Expand All @@ -1313,7 +1322,14 @@ def sources(self) -> List[Value]:
return [self.lhs, self.rhs]

def to_str(self, env: Environment) -> str:
return env.format('%r = %r %s %r', self, self.lhs, self.op_str[self.op], self.rhs)
if self.op in (self.SLT, self.SGT, self.SLE, self.SGE):
sign_format = " :: signed"
elif self.op in (self.ULT, self.UGT, self.ULE, self.UGE):
sign_format = " :: unsigned"
else:
sign_format = ""
return env.format('%r = %r %s %r%s', self, self.lhs,
self.op_str[self.op], self.rhs, sign_format)

def accept(self, visitor: 'OpVisitor[T]') -> T:
return visitor.visit_binary_int_op(self)
Expand Down
4 changes: 4 additions & 0 deletions mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ def __repr__(self) -> str:
is_refcounted=True) # type: Final


def is_tagged(rtype: RType) -> bool:
return rtype is int_rprimitive or rtype is short_int_rprimitive


def is_int_rprimitive(rtype: RType) -> bool:
return rtype is int_rprimitive

Expand Down
2 changes: 1 addition & 1 deletion mypyc/primitives/int_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,5 +173,5 @@ def int_unary_op(name: str, c_function_name: str) -> CFunctionDescription:
int_logical_op_mapping = {
'==': IntLogicalOpDescrption(BinaryIntOp.EQ, int_equal_, False, False),
'!=': IntLogicalOpDescrption(BinaryIntOp.NEQ, int_equal_, True, False),
'<': IntLogicalOpDescrption(BinaryIntOp.LT, int_less_than_, False, False)
'<': IntLogicalOpDescrption(BinaryIntOp.SLT, int_less_than_, False, False)
} # type: Dict[str, IntLogicalOpDescrption]
4 changes: 2 additions & 2 deletions mypyc/test-data/analysis.test
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ L1:
r9 = r4 & r8
if r9 goto L2 else goto L3 :: bool
L2:
r10 = a < a
r10 = a < a :: signed
r0 = r10
goto L4
L3:
Expand All @@ -413,7 +413,7 @@ L6:
r21 = r16 & r20
if r21 goto L7 else goto L8 :: bool
L7:
r22 = a < a
r22 = a < a :: signed
r12 = r22
goto L9
L8:
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/exceptions.test
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ L1:
r11 = r6 & r10
if r11 goto L2 else goto L3 :: bool
L2:
r12 = i < l
r12 = i < l :: signed
r2 = r12
goto L4
L3:
Expand Down
24 changes: 12 additions & 12 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ L0:
r9 = r4 & r8
if r9 goto L1 else goto L2 :: bool
L1:
r10 = x < y
r10 = x < y :: signed
r0 = r10
goto L3
L2:
Expand Down Expand Up @@ -152,7 +152,7 @@ L0:
r9 = r4 & r8
if r9 goto L1 else goto L2 :: bool
L1:
r10 = x < y
r10 = x < y :: signed
r0 = r10
goto L3
L2:
Expand Down Expand Up @@ -198,7 +198,7 @@ L0:
r9 = r4 & r8
if r9 goto L1 else goto L2 :: bool
L1:
r10 = x < y
r10 = x < y :: signed
r0 = r10
goto L3
L2:
Expand Down Expand Up @@ -269,7 +269,7 @@ L0:
r9 = r4 & r8
if r9 goto L1 else goto L2 :: bool
L1:
r10 = x < y
r10 = x < y :: signed
r0 = r10
goto L3
L2:
Expand Down Expand Up @@ -338,7 +338,7 @@ L0:
r9 = r4 & r8
if r9 goto L1 else goto L2 :: bool
L1:
r10 = x < y
r10 = x < y :: signed
r0 = r10
goto L3
L2:
Expand Down Expand Up @@ -378,7 +378,7 @@ L0:
r9 = r4 & r8
if r9 goto L1 else goto L2 :: bool
L1:
r10 = x < y
r10 = x < y :: signed
r0 = r10
goto L3
L2:
Expand Down Expand Up @@ -493,7 +493,7 @@ L0:
r9 = r4 & r8
if r9 goto L1 else goto L2 :: bool
L1:
r10 = x < y
r10 = x < y :: signed
r0 = r10
goto L3
L2:
Expand Down Expand Up @@ -2041,7 +2041,7 @@ L0:
r9 = r8
L1:
r10 = len r7 :: list
r11 = r9 < r10
r11 = r9 < r10 :: signed
if r11 goto L2 else goto L8 :: bool
L2:
r12 = r7[r9] :: unsafe list
Expand Down Expand Up @@ -2105,7 +2105,7 @@ L0:
r9 = r8
L1:
r10 = len r7 :: list
r11 = r9 < r10
r11 = r9 < r10 :: signed
if r11 goto L2 else goto L8 :: bool
L2:
r12 = r7[r9] :: unsafe list
Expand Down Expand Up @@ -2166,7 +2166,7 @@ L0:
r1 = r0
L1:
r2 = len l :: list
r3 = r1 < r2
r3 = r1 < r2 :: signed
if r3 goto L2 else goto L4 :: bool
L2:
r4 = l[r1] :: unsafe list
Expand All @@ -2188,7 +2188,7 @@ L4:
r13 = r12
L5:
r14 = len l :: list
r15 = r13 < r14
r15 = r13 < r14 :: signed
if r15 goto L6 else goto L8 :: bool
L6:
r16 = l[r13] :: unsafe list
Expand Down Expand Up @@ -2750,7 +2750,7 @@ L0:
r12 = r7 & r11
if r12 goto L1 else goto L2 :: bool
L1:
r13 = r0 < r1
r13 = r0 < r1 :: signed
r3 = r13
goto L3
L2:
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/irbuild-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ L0:
r2 = r0
i = r2
L1:
r3 = r2 < r1
r3 = r2 < r1 :: signed
if r3 goto L2 else goto L4 :: bool
L2:
r4 = CPyList_GetItem(l, i)
Expand Down
16 changes: 8 additions & 8 deletions mypyc/test-data/irbuild-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ L0:
r3 = r1
i = r3
L1:
r4 = r3 < r2
r4 = r3 < r2 :: signed
if r4 goto L2 else goto L4 :: bool
L2:
r5 = CPyTagged_Add(x, i)
Expand Down Expand Up @@ -107,7 +107,7 @@ L0:
r2 = r0
n = r2
L1:
r3 = r2 < r1
r3 = r2 < r1 :: signed
if r3 goto L2 else goto L4 :: bool
L2:
goto L4
Expand Down Expand Up @@ -197,7 +197,7 @@ L0:
r2 = r0
n = r2
L1:
r3 = r2 < r1
r3 = r2 < r1 :: signed
if r3 goto L2 else goto L4 :: bool
L2:
L3:
Expand Down Expand Up @@ -271,7 +271,7 @@ L0:
r2 = r1
L1:
r3 = len ls :: list
r4 = r2 < r3
r4 = r2 < r3 :: signed
if r4 goto L2 else goto L4 :: bool
L2:
r5 = ls[r2] :: unsafe list
Expand Down Expand Up @@ -859,7 +859,7 @@ L0:
r3 = r2
L1:
r4 = len a :: list
r5 = r3 < r4
r5 = r3 < r4 :: signed
if r5 goto L2 else goto L4 :: bool
L2:
r6 = a[r3] :: unsafe list
Expand Down Expand Up @@ -942,7 +942,7 @@ L0:
r2 = iter b :: object
L1:
r3 = len a :: list
r4 = r1 < r3
r4 = r1 < r3 :: signed
if r4 goto L2 else goto L7 :: bool
L2:
r5 = next r2 :: object
Expand Down Expand Up @@ -997,10 +997,10 @@ L1:
if is_error(r6) goto L6 else goto L2
L2:
r7 = len b :: list
r8 = r2 < r7
r8 = r2 < r7 :: signed
if r8 goto L3 else goto L6 :: bool
L3:
r9 = r5 < r4
r9 = r5 < r4 :: signed
if r9 goto L4 else goto L6 :: bool
L4:
r10 = unbox(bool, r6)
Expand Down
27 changes: 25 additions & 2 deletions mypyc/test/test_emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from mypyc.ir.ops import (
Environment, BasicBlock, Goto, Return, LoadInt, Assign, IncRef, DecRef, Branch,
Call, Unbox, Box, TupleGet, GetAttr, PrimitiveOp, RegisterOp,
SetAttr, Op, Value, CallC
SetAttr, Op, Value, CallC, BinaryIntOp
)
from mypyc.ir.rtypes import (
RTuple, RInstance, int_rprimitive, bool_rprimitive, list_rprimitive,
dict_rprimitive, object_rprimitive, c_int_rprimitive
dict_rprimitive, object_rprimitive, c_int_rprimitive, short_int_rprimitive, int32_rprimitive,
int64_rprimitive
)
from mypyc.ir.func_ir import FuncIR, FuncDecl, RuntimeArg, FuncSignature
from mypyc.ir.class_ir import ClassIR
Expand Down Expand Up @@ -44,6 +45,12 @@ def setUp(self) -> None:
self.o2 = self.env.add_local(Var('o2'), object_rprimitive)
self.d = self.env.add_local(Var('d'), dict_rprimitive)
self.b = self.env.add_local(Var('b'), bool_rprimitive)
self.s1 = self.env.add_local(Var('s1'), short_int_rprimitive)
self.s2 = self.env.add_local(Var('s2'), short_int_rprimitive)
self.i32 = self.env.add_local(Var('i32'), int32_rprimitive)
self.i32_1 = self.env.add_local(Var('i32_1'), int32_rprimitive)
self.i64 = self.env.add_local(Var('i64'), int64_rprimitive)
self.i64_1 = self.env.add_local(Var('i64_1'), int64_rprimitive)
self.t = self.env.add_local(Var('t'), RTuple([int_rprimitive, bool_rprimitive]))
self.tt = self.env.add_local(
Var('tt'),
Expand Down Expand Up @@ -245,6 +252,22 @@ def test_dict_contains(self) -> None:
'in', self.b, self.o, self.d,
"""cpy_r_r0 = PyDict_Contains(cpy_r_d, cpy_r_o);""")

def test_binary_int_op(self) -> None:
# signed
self.assert_emit(BinaryIntOp(bool_rprimitive, self.s1, self.s2, BinaryIntOp.SLT, 1),
"""cpy_r_r0 = (Py_ssize_t)cpy_r_s1 < (Py_ssize_t)cpy_r_s2;""")
self.assert_emit(BinaryIntOp(bool_rprimitive, self.i32, self.i32_1, BinaryIntOp.SLT, 1),
"""cpy_r_r00 = cpy_r_i32 < cpy_r_i32_1;""")
self.assert_emit(BinaryIntOp(bool_rprimitive, self.i64, self.i64_1, BinaryIntOp.SLT, 1),
"""cpy_r_r01 = cpy_r_i64 < cpy_r_i64_1;""")
# unsigned
self.assert_emit(BinaryIntOp(bool_rprimitive, self.s1, self.s2, BinaryIntOp.ULT, 1),
"""cpy_r_r02 = cpy_r_s1 < cpy_r_s2;""")
self.assert_emit(BinaryIntOp(bool_rprimitive, self.i32, self.i32_1, BinaryIntOp.ULT, 1),
"""cpy_r_r03 = (uint32_t)cpy_r_i32 < (uint32_t)cpy_r_i32_1;""")
self.assert_emit(BinaryIntOp(bool_rprimitive, self.i64, self.i64_1, BinaryIntOp.ULT, 1),
"""cpy_r_r04 = (uint64_t)cpy_r_i64 < (uint64_t)cpy_r_i64_1;""")

def assert_emit(self, op: Op, expected: str) -> None:
self.emitter.fragments = []
self.declarations.fragments = []
Expand Down

0 comments on commit 4cf246f

Please sign in to comment.