Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constant fold more unary and binary expressions #15202

Merged
merged 15 commits into from
Jun 25, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 90 additions & 28 deletions mypy/constant_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,21 @@
from typing import Union
from typing_extensions import Final

from mypy.nodes import Expression, FloatExpr, IntExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var
from mypy.nodes import (
ComplexExpr,
Expression,
FloatExpr,
IntExpr,
NameExpr,
OpExpr,
StrExpr,
UnaryExpr,
Var,
)

# All possible result types of constant folding
ConstantValue = Union[int, bool, float, str]
CONST_TYPES: Final = (int, bool, float, str)
ConstantValue = Union[int, bool, float, complex, str]
CONST_TYPES: Final = (int, bool, float, complex, str)


def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | None:
Expand All @@ -39,6 +49,8 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non
return expr.value
if isinstance(expr, FloatExpr):
return expr.value
if isinstance(expr, ComplexExpr):
return expr.value
elif isinstance(expr, NameExpr):
if expr.name == "True":
return True
Expand All @@ -56,26 +68,60 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non
elif isinstance(expr, OpExpr):
left = constant_fold_expr(expr.left, cur_mod_id)
right = constant_fold_expr(expr.right, cur_mod_id)
if isinstance(left, int) and isinstance(right, int):
return constant_fold_binary_int_op(expr.op, left, right)
elif isinstance(left, str) and isinstance(right, str):
return constant_fold_binary_str_op(expr.op, left, right)
if left is not None and right is not None:
return constant_fold_binary_op(expr.op, left, right)
elif isinstance(expr, UnaryExpr):
value = constant_fold_expr(expr.expr, cur_mod_id)
if isinstance(value, int):
return constant_fold_unary_int_op(expr.op, value)
if isinstance(value, float):
return constant_fold_unary_float_op(expr.op, value)
if value is not None:
return constant_fold_unary_op(expr.op, value)
return None


def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None:
def constant_fold_binary_op(
op: str, left: ConstantValue, right: ConstantValue
) -> ConstantValue | None:
if isinstance(left, int) and isinstance(right, int):
return constant_fold_binary_int_op(op, left, right)

# Float and mixed int/float arithmetic.
if isinstance(left, float) and isinstance(right, float):
return constant_fold_binary_float_op(op, left, right)
elif isinstance(left, float) and isinstance(right, int):
return constant_fold_binary_float_op(op, left, right)
elif isinstance(left, int) and isinstance(right, float):
return constant_fold_binary_float_op(op, left, right)

# String concatenation and multiplication.
if op == "+" and isinstance(left, str) and isinstance(right, str):
return left + right
elif op == "*" and isinstance(left, str) and isinstance(right, int):
return left * right
elif op == "*" and isinstance(left, int) and isinstance(right, str):
return left * right

# Complex construction.
if op == "+" and isinstance(left, (int, float)) and isinstance(right, complex):
return left + right
elif op == "+" and isinstance(left, complex) and isinstance(right, (int, float)):
return left + right
elif op == "-" and isinstance(left, (int, float)) and isinstance(right, complex):
return left - right
elif op == "-" and isinstance(left, complex) and isinstance(right, (int, float)):
return left - right

return None


def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | float | None:
if op == "+":
return left + right
if op == "-":
return left - right
elif op == "*":
return left * right
elif op == "/":
if right != 0:
return left / right
elif op == "//":
if right != 0:
return left // right
Expand All @@ -102,25 +148,41 @@ def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None:
return None


def constant_fold_unary_int_op(op: str, value: int) -> int | None:
if op == "-":
return -value
elif op == "~":
return ~value
elif op == "+":
return value
def constant_fold_binary_float_op(op: str, left: int | float, right: int | float) -> float | None:
assert not (isinstance(left, int) and isinstance(right, int)), (op, left, right)
if op == "+":
return left + right
elif op == "-":
return left - right
elif op == "*":
return left * right
elif op == "/":
if right != 0:
return left / right
elif op == "//":
if right != 0:
return left // right
elif op == "%":
if right != 0:
return left % right
elif op == "**":
if (left < 0 and right >= 1 or right == 0) or (left >= 0 and right >= 0):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could return a complex number, but the return type doesn't include complex (not sure if supporting complex results is worth it):

>>> (-1.2) ** 1.5
(-2.414760036730213e-16-1.3145341380123985j)

I used this fragment to look for other interesting cases, maybe it's helpful here:

values = -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5
for x in values:
    for y in values:
        try:
            print(x, y, x**y)
        except Exception:
            print(f'error: {x} ** {y}')

try:
ret = left**right
except OverflowError:
return None
Comment on lines +170 to +173
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Digging through the CPython source, only power ops can raise an OverflowError.

else:
assert isinstance(ret, float)
return ret

return None


def constant_fold_unary_float_op(op: str, value: float) -> float | None:
if op == "-":
def constant_fold_unary_op(op: str, value: ConstantValue) -> int | float | None:
if op == "-" and isinstance(value, (int, float)):
return -value
elif op == "+":
elif op == "~" and isinstance(value, int):
return ~value
elif op == "+" and isinstance(value, (int, float)):
return value
return None


def constant_fold_binary_str_op(op: str, left: str, right: str) -> str | None:
if op == "+":
return left + right
return None
2 changes: 1 addition & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ def __init__(self, name: str, type: mypy.types.Type | None = None) -> None:
# If constant value is a simple literal,
# store the literal value (unboxed) for the benefit of
# tools like mypyc.
self.final_value: int | float | bool | str | None = None
self.final_value: int | float | complex | bool | str | None = None
# Where the value was set (only for class attributes)
self.final_unset_in_class = False
self.final_set_in_init = False
Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3356,7 +3356,7 @@ def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Typ
return None

value = constant_fold_expr(rvalue, self.cur_mod_id)
if value is None:
if value is None or isinstance(value, complex):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no idea whether complex literals make any sense in mypy. PTAL.

return None

if isinstance(value, bool):
Expand Down
12 changes: 6 additions & 6 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,25 +531,25 @@ def load_final_static(
error_msg=f'value for final name "{error_name}" was not set',
)

def load_final_literal_value(self, val: int | str | bytes | float | bool, line: int) -> Value:
"""Load value of a final name or class-level attribute."""
def load_literal_value(self, val: int | str | bytes | float | complex | bool) -> Value:
"""Load value of a final name, class-level attribute, or constant folded expression."""
if isinstance(val, bool):
if val:
return self.true()
else:
return self.false()
elif isinstance(val, int):
# TODO: take care of negative integer initializers
# (probably easier to fix this in mypy itself).
return self.builder.load_int(val)
elif isinstance(val, float):
return self.builder.load_float(val)
elif isinstance(val, str):
return self.builder.load_str(val)
elif isinstance(val, bytes):
return self.builder.load_bytes(val)
elif isinstance(val, complex):
return self.builder.load_complex(val)
else:
assert False, "Unsupported final literal value"
assert False, "Unsupported literal value"

def get_assignment_target(
self, lvalue: Lvalue, line: int = -1, *, for_read: bool = False
Expand Down Expand Up @@ -1008,7 +1008,7 @@ def emit_load_final(
line: line number where loading occurs
"""
if final_var.final_value is not None: # this is safe even for non-native names
return self.load_final_literal_value(final_var.final_value, line)
return self.load_literal_value(final_var.final_value)
elif native:
return self.load_final_static(fullname, self.mapper.type_to_rtype(typ), line, name)
else:
Expand Down
2 changes: 1 addition & 1 deletion mypyc/irbuild/callable_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


def setup_callable_class(builder: IRBuilder) -> None:
"""Generate an (incomplete) callable class representing function.
"""Generate an (incomplete) callable class representing a function.

This can be a nested function or a function within a non-extension
class. Also set up the 'self' variable for that class.
Expand Down
64 changes: 41 additions & 23 deletions mypyc/irbuild/constant_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,10 @@
from typing import Union
from typing_extensions import Final

from mypy.constant_fold import (
constant_fold_binary_int_op,
constant_fold_binary_str_op,
constant_fold_unary_float_op,
constant_fold_unary_int_op,
)
from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op
from mypy.nodes import (
BytesExpr,
ComplexExpr,
Expression,
FloatExpr,
IntExpr,
Expand All @@ -31,10 +28,11 @@
Var,
)
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.util import bytes_from_str

# All possible result types of constant folding
ConstantValue = Union[int, str, float]
CONST_TYPES: Final = (int, str, float)
ConstantValue = Union[int, float, complex, str, bytes]
CONST_TYPES: Final = (int, float, complex, str, bytes)


def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None:
Expand All @@ -44,35 +42,55 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue |
"""
if isinstance(expr, IntExpr):
return expr.value
if isinstance(expr, FloatExpr):
return expr.value
if isinstance(expr, StrExpr):
return expr.value
if isinstance(expr, FloatExpr):
if isinstance(expr, BytesExpr):
return bytes_from_str(expr.value)
if isinstance(expr, ComplexExpr):
return expr.value
elif isinstance(expr, NameExpr):
node = expr.node
if isinstance(node, Var) and node.is_final:
value = node.final_value
if isinstance(value, (CONST_TYPES)):
return value
final_value = node.final_value
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The value local name should support ConstantValue | None, but Var.final_value does not support bytes. To avoid causing type errors later in the function from this assignment implicitly setting value's type, these variables were renamed.

if isinstance(final_value, (CONST_TYPES)):
return final_value
elif isinstance(expr, MemberExpr):
final = builder.get_final_ref(expr)
if final is not None:
fn, final_var, native = final
if final_var.is_final:
value = final_var.final_value
if isinstance(value, (CONST_TYPES)):
return value
final_value = final_var.final_value
if isinstance(final_value, (CONST_TYPES)):
return final_value
elif isinstance(expr, OpExpr):
left = constant_fold_expr(builder, expr.left)
right = constant_fold_expr(builder, expr.right)
if isinstance(left, int) and isinstance(right, int):
return constant_fold_binary_int_op(expr.op, left, right)
elif isinstance(left, str) and isinstance(right, str):
return constant_fold_binary_str_op(expr.op, left, right)
if left is not None and right is not None:
return constant_fold_binary_op_extended(expr.op, left, right)
elif isinstance(expr, UnaryExpr):
value = constant_fold_expr(builder, expr.expr)
if isinstance(value, int):
return constant_fold_unary_int_op(expr.op, value)
if isinstance(value, float):
return constant_fold_unary_float_op(expr.op, value)
if value is not None and not isinstance(value, bytes):
return constant_fold_unary_op(expr.op, value)
return None


def constant_fold_binary_op_extended(
op: str, left: ConstantValue, right: ConstantValue
) -> ConstantValue | None:
"""Like mypy's constant_fold_binary_op(), but includes bytes support.

mypy cannot use constant folded bytes easily so it's simpler to only support them in mypyc.
"""
if not isinstance(left, bytes) and not isinstance(right, bytes):
return constant_fold_binary_op(op, left, right)

if op == "+" and isinstance(left, bytes) and isinstance(right, bytes):
return left + right
elif op == "*" and isinstance(left, bytes) and isinstance(right, int):
return left * right
elif op == "*" and isinstance(left, int) and isinstance(right, bytes):
return left * right

return None
15 changes: 2 additions & 13 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
Assign,
BasicBlock,
ComparisonOp,
Float,
Integer,
LoadAddress,
LoadLiteral,
Expand Down Expand Up @@ -91,7 +90,6 @@
tokenizer_printf_style,
)
from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization
from mypyc.irbuild.util import bytes_from_str
from mypyc.primitives.bytes_ops import bytes_slice_op
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, dict_set_item_op
from mypyc.primitives.generic_ops import iter_op
Expand Down Expand Up @@ -566,12 +564,8 @@ def try_constant_fold(builder: IRBuilder, expr: Expression) -> Value | None:
Return None otherwise.
"""
value = constant_fold_expr(builder, expr)
if isinstance(value, int):
return builder.load_int(value)
elif isinstance(value, str):
return builder.load_str(value)
elif isinstance(value, float):
return Float(value)
if value is not None:
return builder.load_literal_value(value)
return None


Expand Down Expand Up @@ -653,10 +647,6 @@ def set_literal_values(builder: IRBuilder, items: Sequence[Expression]) -> list[
values.append(True)
elif item.fullname == "builtins.False":
values.append(False)
elif isinstance(item, (BytesExpr, FloatExpr, ComplexExpr)):
# constant_fold_expr() doesn't handle these (yet?)
v = bytes_from_str(item.value) if isinstance(item, BytesExpr) else item.value
values.append(v)
elif isinstance(item, TupleExpr):
tuple_values = set_literal_values(builder, item.items)
if tuple_values is not None:
Expand All @@ -676,7 +666,6 @@ def precompute_set_literal(builder: IRBuilder, s: SetExpr) -> Value | None:
Supported items:
- Anything supported by irbuild.constant_fold.constant_fold_expr()
- None, True, and False
- Float, byte, and complex literals
- Tuple literals with only items listed above
"""
values = set_literal_values(builder, s.items)
Expand Down
Loading