-
-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Constant fold more unary and binary expressions #15202
Changes from 13 commits
33bc95a
c80e95e
ced8a10
54a892b
c057484
3b3c4c7
3bae7d9
3670d0d
8d24649
0655190
85a35c4
7fb8f8c
7542d99
c32dc3b
be9e755
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,11 +8,21 @@ | |
from typing import Union | ||
from typing_extensions import Final | ||
|
||
from mypy.nodes import Expression, FloatExpr, IntExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var | ||
from mypy.nodes import ( | ||
ComplexExpr, | ||
Expression, | ||
FloatExpr, | ||
IntExpr, | ||
NameExpr, | ||
OpExpr, | ||
StrExpr, | ||
UnaryExpr, | ||
Var, | ||
) | ||
|
||
# All possible result types of constant folding | ||
ConstantValue = Union[int, bool, float, str] | ||
CONST_TYPES: Final = (int, bool, float, str) | ||
ConstantValue = Union[int, bool, float, complex, str] | ||
CONST_TYPES: Final = (int, bool, float, complex, str) | ||
|
||
|
||
def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | None: | ||
|
@@ -39,6 +49,8 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non | |
return expr.value | ||
if isinstance(expr, FloatExpr): | ||
return expr.value | ||
if isinstance(expr, ComplexExpr): | ||
return expr.value | ||
elif isinstance(expr, NameExpr): | ||
if expr.name == "True": | ||
return True | ||
|
@@ -56,26 +68,60 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non | |
elif isinstance(expr, OpExpr): | ||
left = constant_fold_expr(expr.left, cur_mod_id) | ||
right = constant_fold_expr(expr.right, cur_mod_id) | ||
if isinstance(left, int) and isinstance(right, int): | ||
return constant_fold_binary_int_op(expr.op, left, right) | ||
elif isinstance(left, str) and isinstance(right, str): | ||
return constant_fold_binary_str_op(expr.op, left, right) | ||
if left is not None and right is not None: | ||
return constant_fold_binary_op(expr.op, left, right) | ||
elif isinstance(expr, UnaryExpr): | ||
value = constant_fold_expr(expr.expr, cur_mod_id) | ||
if isinstance(value, int): | ||
return constant_fold_unary_int_op(expr.op, value) | ||
if isinstance(value, float): | ||
return constant_fold_unary_float_op(expr.op, value) | ||
if value is not None: | ||
return constant_fold_unary_op(expr.op, value) | ||
return None | ||
|
||
|
||
def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None: | ||
def constant_fold_binary_op( | ||
op: str, left: ConstantValue, right: ConstantValue | ||
) -> ConstantValue | None: | ||
if isinstance(left, int) and isinstance(right, int): | ||
return constant_fold_binary_int_op(op, left, right) | ||
|
||
# Float and mixed int/float arithmetic. | ||
if isinstance(left, float) and isinstance(right, float): | ||
return constant_fold_binary_float_op(op, left, right) | ||
elif isinstance(left, float) and isinstance(right, int): | ||
return constant_fold_binary_float_op(op, left, right) | ||
elif isinstance(left, int) and isinstance(right, float): | ||
return constant_fold_binary_float_op(op, left, right) | ||
|
||
# String concatenation and multiplication. | ||
if op == "+" and isinstance(left, str) and isinstance(right, str): | ||
return left + right | ||
elif op == "*" and isinstance(left, str) and isinstance(right, int): | ||
return left * right | ||
elif op == "*" and isinstance(left, int) and isinstance(right, str): | ||
return left * right | ||
|
||
# Complex construction. | ||
if op == "+" and isinstance(left, (int, float)) and isinstance(right, complex): | ||
return left + right | ||
elif op == "+" and isinstance(left, complex) and isinstance(right, (int, float)): | ||
return left + right | ||
elif op == "-" and isinstance(left, (int, float)) and isinstance(right, complex): | ||
return left - right | ||
elif op == "-" and isinstance(left, complex) and isinstance(right, (int, float)): | ||
return left - right | ||
|
||
return None | ||
|
||
|
||
def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | float | None: | ||
if op == "+": | ||
return left + right | ||
if op == "-": | ||
return left - right | ||
elif op == "*": | ||
return left * right | ||
elif op == "/": | ||
if right != 0: | ||
return left / right | ||
elif op == "//": | ||
if right != 0: | ||
return left // right | ||
|
@@ -102,25 +148,41 @@ def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None: | |
return None | ||
|
||
|
||
def constant_fold_unary_int_op(op: str, value: int) -> int | None: | ||
if op == "-": | ||
return -value | ||
elif op == "~": | ||
return ~value | ||
elif op == "+": | ||
return value | ||
def constant_fold_binary_float_op(op: str, left: int | float, right: int | float) -> float | None: | ||
assert not (isinstance(left, int) and isinstance(right, int)), (op, left, right) | ||
if op == "+": | ||
return left + right | ||
elif op == "-": | ||
return left - right | ||
elif op == "*": | ||
return left * right | ||
elif op == "/": | ||
if right != 0: | ||
return left / right | ||
elif op == "//": | ||
if right != 0: | ||
return left // right | ||
elif op == "%": | ||
if right != 0: | ||
return left % right | ||
elif op == "**": | ||
if (left < 0 and right >= 1 or right == 0) or (left >= 0 and right >= 0): | ||
try: | ||
ret = left**right | ||
except OverflowError: | ||
return None | ||
Comment on lines
+170
to
+173
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Digging through the CPython source, only power ops can raise an OverflowError. |
||
else: | ||
assert isinstance(ret, float) | ||
return ret | ||
|
||
return None | ||
|
||
|
||
def constant_fold_unary_float_op(op: str, value: float) -> float | None: | ||
if op == "-": | ||
def constant_fold_unary_op(op: str, value: ConstantValue) -> int | float | None: | ||
if op == "-" and isinstance(value, (int, float)): | ||
return -value | ||
elif op == "+": | ||
elif op == "~" and isinstance(value, int): | ||
return ~value | ||
elif op == "+" and isinstance(value, (int, float)): | ||
return value | ||
return None | ||
|
||
|
||
def constant_fold_binary_str_op(op: str, left: str, right: str) -> str | None: | ||
if op == "+": | ||
return left + right | ||
return None |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3356,7 +3356,7 @@ def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Typ | |
return None | ||
|
||
value = constant_fold_expr(rvalue, self.cur_mod_id) | ||
if value is None: | ||
if value is None or isinstance(value, complex): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have no idea whether complex literals make any sense in mypy. PTAL. |
||
return None | ||
|
||
if isinstance(value, bool): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,13 +13,10 @@ | |
from typing import Union | ||
from typing_extensions import Final | ||
|
||
from mypy.constant_fold import ( | ||
constant_fold_binary_int_op, | ||
constant_fold_binary_str_op, | ||
constant_fold_unary_float_op, | ||
constant_fold_unary_int_op, | ||
) | ||
from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op | ||
from mypy.nodes import ( | ||
BytesExpr, | ||
ComplexExpr, | ||
Expression, | ||
FloatExpr, | ||
IntExpr, | ||
|
@@ -31,10 +28,11 @@ | |
Var, | ||
) | ||
from mypyc.irbuild.builder import IRBuilder | ||
from mypyc.irbuild.util import bytes_from_str | ||
|
||
# All possible result types of constant folding | ||
ConstantValue = Union[int, str, float] | ||
CONST_TYPES: Final = (int, str, float) | ||
ConstantValue = Union[int, float, complex, str, bytes] | ||
CONST_TYPES: Final = (int, float, complex, str, bytes) | ||
|
||
|
||
def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None: | ||
|
@@ -44,35 +42,55 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | | |
""" | ||
if isinstance(expr, IntExpr): | ||
return expr.value | ||
if isinstance(expr, FloatExpr): | ||
return expr.value | ||
if isinstance(expr, StrExpr): | ||
return expr.value | ||
if isinstance(expr, FloatExpr): | ||
if isinstance(expr, BytesExpr): | ||
return bytes_from_str(expr.value) | ||
if isinstance(expr, ComplexExpr): | ||
return expr.value | ||
elif isinstance(expr, NameExpr): | ||
node = expr.node | ||
if isinstance(node, Var) and node.is_final: | ||
value = node.final_value | ||
if isinstance(value, (CONST_TYPES)): | ||
return value | ||
final_value = node.final_value | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
if isinstance(final_value, (CONST_TYPES)): | ||
return final_value | ||
elif isinstance(expr, MemberExpr): | ||
final = builder.get_final_ref(expr) | ||
if final is not None: | ||
fn, final_var, native = final | ||
if final_var.is_final: | ||
value = final_var.final_value | ||
if isinstance(value, (CONST_TYPES)): | ||
return value | ||
final_value = final_var.final_value | ||
if isinstance(final_value, (CONST_TYPES)): | ||
return final_value | ||
elif isinstance(expr, OpExpr): | ||
left = constant_fold_expr(builder, expr.left) | ||
right = constant_fold_expr(builder, expr.right) | ||
if isinstance(left, int) and isinstance(right, int): | ||
return constant_fold_binary_int_op(expr.op, left, right) | ||
elif isinstance(left, str) and isinstance(right, str): | ||
return constant_fold_binary_str_op(expr.op, left, right) | ||
if left is not None and right is not None: | ||
return constant_fold_binary_op_extended(expr.op, left, right) | ||
elif isinstance(expr, UnaryExpr): | ||
value = constant_fold_expr(builder, expr.expr) | ||
if isinstance(value, int): | ||
return constant_fold_unary_int_op(expr.op, value) | ||
if isinstance(value, float): | ||
return constant_fold_unary_float_op(expr.op, value) | ||
if value is not None and not isinstance(value, bytes): | ||
return constant_fold_unary_op(expr.op, value) | ||
return None | ||
|
||
|
||
def constant_fold_binary_op_extended( | ||
op: str, left: ConstantValue, right: ConstantValue | ||
) -> ConstantValue | None: | ||
"""Like mypy's constant_fold_binary_op(), but includes bytes support. | ||
|
||
mypy cannot use constant folded bytes easily so it's simpler to only support them in mypyc. | ||
""" | ||
if not isinstance(left, bytes) and not isinstance(right, bytes): | ||
return constant_fold_binary_op(op, left, right) | ||
|
||
if op == "+" and isinstance(left, bytes) and isinstance(right, bytes): | ||
return left + right | ||
elif op == "*" and isinstance(left, bytes) and isinstance(right, int): | ||
return left * right | ||
elif op == "*" and isinstance(left, int) and isinstance(right, bytes): | ||
return left * right | ||
|
||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could return a complex number, but the return type doesn't include
complex
(not sure if supporting complex results is worth it):I used this fragment to look for other interesting cases, maybe it's helpful here: