diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index fcd9e43..755f4bc 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -12,28 +12,95 @@ from latexify import exceptions -@dataclasses.dataclass(frozen=True) -class BinOpRule: - """Syntax rules for binary operators.""" +# Precedences of operators for BoolOp, BinOp, UnaryOp, and Compare nodes. +# Note that this value affects only the appearance of surrounding parentheses for each +# expression, and does not affect the AST itself. +# See also: +# https://docs.python.org/3/reference/expressions.html#operator-precedence +_PRECEDENCES: dict[type[ast.AST], int] = { + ast.Pow: 120, + ast.UAdd: 110, + ast.USub: 110, + ast.Invert: 110, + ast.Mult: 100, + ast.MatMult: 100, + ast.Div: 100, + ast.FloorDiv: 100, + ast.Mod: 100, + ast.Add: 90, + ast.Sub: 90, + ast.LShift: 80, + ast.RShift: 80, + ast.BitAnd: 70, + ast.BitXor: 60, + ast.BitOr: 50, + ast.In: 40, + ast.NotIn: 40, + ast.Is: 40, + ast.IsNot: 40, + ast.Lt: 40, + ast.LtE: 40, + ast.Gt: 40, + ast.GtE: 40, + ast.NotEq: 40, + ast.Eq: 40, + # NOTE(odashi): + # We assume that the `not` operator has the same precedence as other unary + # operators `+`, `-` and `~`, because the LaTeX counterpart $\lnot$ appears to have a + # high precedence. + # ast.Not: 30, + ast.Not: 110, + ast.And: 20, + ast.Or: 10, +} - # Precedence of this operator. - # See also: https://docs.python.org/3/reference/expressions.html - precedence: int - # Left/middle/right syntaxes to wrap operands. - latex_left: str - latex_middle: str - latex_right: str +def _get_precedence(node: ast.AST) -> int: + """Obtains the precedence of the subtree. + + Args: + node: Subtree to investigate. + + Returns: + If `node` is a subtree with some operator, returns the precedence of the + operator. Otherwise, returns a number sufficiently larger than any other precedence. 
+ """ + if isinstance(node, (ast.BoolOp, ast.BinOp, ast.UnaryOp)): + return _PRECEDENCES[type(node.op)] + + if isinstance(node, ast.Compare): + # Compare operators have the same precedence. It is enough to check only the + # first operator. + return _PRECEDENCES[type(node.ops[0])] + + return 1_000_000 + + +@dataclasses.dataclass(frozen=True) +class BinOperandRule: + """Syntax rules for operands of BinOp.""" # Whether to require wrapping operands by parentheses according to the precedence. - wrap_left: bool = True - wrap_right: bool = True + wrap: bool = True # Whether to require wrapping operands by parentheses if the operand has the same # precedence with this operator. # This is used to control the behavior of non-associative operators. - force_left: bool = False - force_right: bool = False + force: bool = False + + +@dataclasses.dataclass(frozen=True) +class BinOpRule: + """Syntax rules for BinOp.""" + + # Left/middle/right syntaxes to wrap operands. + latex_left: str + latex_middle: str + latex_right: str + + # Operand rules. + operand_left: BinOperandRule = dataclasses.field(default_factory=BinOperandRule) + operand_right: BinOperandRule = dataclasses.field(default_factory=BinOperandRule) # Whether to assume the resulting syntax is wrapped by some bracket operators. # If True, the parent operator can avoid wrapping this operator by parentheses. 
@@ -41,27 +108,47 @@ class BinOpRule: _BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = { - ast.Pow: BinOpRule(70, "", "^{", "}", wrap_right=False, force_left=True), - ast.Mult: BinOpRule(60, "", " ", ""), - ast.MatMult: BinOpRule(60, "", " ", ""), - ast.Div: BinOpRule(60, r"\frac{", "}{", "}", wrap_left=False, wrap_right=False), + ast.Pow: BinOpRule( + "", + "^{", + "}", + operand_left=BinOperandRule(force=True), + operand_right=BinOperandRule(wrap=False), + ), + ast.Mult: BinOpRule("", " ", ""), + ast.MatMult: BinOpRule("", " ", ""), + ast.Div: BinOpRule( + r"\frac{", + "}{", + "}", + operand_left=BinOperandRule(wrap=False), + operand_right=BinOperandRule(wrap=False), + ), ast.FloorDiv: BinOpRule( - 60, r"\left\lfloor\frac{", "}{", r"}\right\rfloor", - wrap_left=False, - wrap_right=False, + operand_left=BinOperandRule(wrap=False), + operand_right=BinOperandRule(wrap=False), is_wrapped=True, ), - ast.Mod: BinOpRule(60, "", r" \mathbin{\%} ", "", force_right=True), - ast.Add: BinOpRule(50, "", " + ", ""), - ast.Sub: BinOpRule(50, "", " - ", "", force_right=True), - ast.LShift: BinOpRule(40, "", r" \ll ", "", force_right=True), - ast.RShift: BinOpRule(40, "", r" \gg ", "", force_right=True), - ast.BitAnd: BinOpRule(30, "", r" \mathbin{\&} ", ""), - ast.BitXor: BinOpRule(20, "", r" \oplus ", ""), - ast.BitOr: BinOpRule(10, "", r" \mathbin{|} ", ""), + ast.Mod: BinOpRule( + "", r" \mathbin{\%} ", "", operand_right=BinOperandRule(force=True) + ), + ast.Add: BinOpRule("", " + ", ""), + ast.Sub: BinOpRule("", " - ", "", operand_right=BinOperandRule(force=True)), + ast.LShift: BinOpRule("", r" \ll ", "", operand_right=BinOperandRule(force=True)), + ast.RShift: BinOpRule("", r" \gg ", "", operand_right=BinOperandRule(force=True)), + ast.BitAnd: BinOpRule("", r" \mathbin{\&} ", ""), + ast.BitXor: BinOpRule("", r" \oplus ", ""), + ast.BitOr: BinOpRule("", r" \mathbin{|} ", ""), +} + +_UNARY_OPS: dict[type[ast.unaryop], str] = { + ast.Invert: r"\mathord{\sim} ", + 
ast.UAdd: "+", # Explicitly adds the $+$ operator. + ast.USub: "-", + ast.Not: r"\lnot ", } _COMPARE_OPS: dict[type[ast.cmpop], str] = { @@ -304,83 +391,83 @@ def visit_NameConstant(self, node: ast.NameConstant) -> str: def visit_Ellipsis(self, node: ast.Ellipsis) -> str: return self._convert_constant(...) - def visit_UnaryOp(self, node: ast.UnaryOp) -> str: - """Visit a unary op node.""" - - def _wrap(child): - latex = self.visit(child) - if isinstance(child, ast.BinOp) and isinstance( - child.op, (ast.Add, ast.Sub) - ): - return r"\left(" + latex + r"\right)" - return latex + def _wrap_operand(self, child: ast.expr, parent_prec: int) -> str: + """Wraps the operand subtree with parentheses. - reprs = { - ast.UAdd: (lambda: _wrap(node.operand)), - ast.USub: (lambda: "-" + _wrap(node.operand)), - ast.Not: (lambda: r"\lnot\left(" + _wrap(node.operand) + r"\right)"), - } + Args: + child: Operand subtree. + parent_prec: Precedence of the parent operator. - if type(node.op) in reprs: - return reprs[type(node.op)]() - return r"\mathrm{unknown\_uniop}(" + self.visit(node.operand) + ")" + Returns: + LaTeX form of `child`, with or without surrounding parentheses. + """ + latex = self.visit(child) + if _get_precedence(child) >= parent_prec: + return latex + return rf"\left( {latex} \right)" def _wrap_binop_operand( self, child: ast.expr, - parent_rule: BinOpRule, - is_left: bool, + parent_prec: int, + operand_rule: BinOperandRule, ) -> str: - """Wraps the given LaTeX with parenthesis. + """Wraps the operand subtree of BinOp with parentheses. Args: - child: Child subtree. - parent_rule: Syntax rule of the parent operator. - is_left: Position of the `child` in the parent operator: - - True: `child` is the left-hand side operand. - - False: `child` is the right-hand side operand. + child: Operand subtree. + parent_prec: Precedence of the parent operator. + operand_rule: Syntax rule of this operand. Returns: - LaTeX form of the `child`, with or without a surrounding parenthesis. 
+ LaTeX form of the `child`, with or without surrounding parentheses. """ - latex = self.visit(child) + if not operand_rule.wrap: + return self.visit(child) if not isinstance(child, ast.BinOp): - return latex + return self._wrap_operand(child, parent_prec) - child_rule = _BIN_OP_RULES[type(child.op)] - wrap = parent_rule.wrap_left if is_left else parent_rule.wrap_right + latex = self.visit(child) - if not wrap or child_rule.is_wrapped: + if _BIN_OP_RULES[type(child.op)].is_wrapped: return latex - child_prec = child_rule.precedence - parent_prec = parent_rule.precedence - force = parent_rule.force_left if is_left else parent_rule.force_right + child_prec = _get_precedence(child) - if child_prec > parent_prec or (child_prec == parent_prec and not force): + if child_prec > parent_prec or ( + child_prec == parent_prec and not operand_rule.force + ): return latex return rf"\left( {latex} \right)" def visit_BinOp(self, node: ast.BinOp) -> str: """Visit a BinOp node.""" + prec = _get_precedence(node) rule = _BIN_OP_RULES[type(node.op)] - lhs = self._wrap_binop_operand(node.left, rule, is_left=True) - rhs = self._wrap_binop_operand(node.right, rule, is_left=False) + lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left) + rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right) return f"{rule.latex_left}{lhs}{rule.latex_middle}{rhs}{rule.latex_right}" + def visit_UnaryOp(self, node: ast.UnaryOp) -> str: + """Visit a unary op node.""" + latex = self._wrap_operand(node.operand, _get_precedence(node)) + return _UNARY_OPS[type(node.op)] + latex + def visit_Compare(self, node: ast.Compare) -> str: """Visit a compare node.""" - lhs = self.visit(node.left) + parent_prec = _get_precedence(node) + lhs = self._wrap_operand(node.left, parent_prec) ops = [_COMPARE_OPS[type(x)] for x in node.ops] - rhs = [self.visit(x) for x in node.comparators] + rhs = [self._wrap_operand(x, parent_prec) for x in node.comparators] ops_rhs = [f" {o} {r}" for o, r in zip(ops, 
rhs)] return "{" + lhs + "".join(ops_rhs) + "}" def visit_BoolOp(self, node: ast.BoolOp) -> str: """Visit a BoolOp node.""" - values = [rf"\left( {self.visit(x)} \right)" for x in node.values] + parent_prec = _get_precedence(node) + values = [self._wrap_operand(x, parent_prec) for x in node.values] op = f" {_BOOL_OPS[type(node.op)]} " return "{" + op.join(values) + "}" diff --git a/src/latexify/codegen/function_codegen_test.py b/src/latexify/codegen/function_codegen_test.py index c9fffd0..42d246c 100644 --- a/src/latexify/codegen/function_codegen_test.py +++ b/src/latexify/codegen/function_codegen_test.py @@ -326,6 +326,42 @@ def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None: ("(x | y) ^ z", r"\left( x \mathbin{|} y \right) \oplus z"), # is_wrapped ("(x // y)**z", r"\left\lfloor\frac{x}{y}\right\rfloor^{z}"), + # With Call + ("x**f(y)", r"x^{\mathrm{f}\left(y\right)}"), + ("f(x)**y", r"\mathrm{f}\left(x\right)^{y}"), + ("x * f(y)", r"x \mathrm{f}\left(y\right)"), + ("f(x) * y", r"\mathrm{f}\left(x\right) y"), + ("x / f(y)", r"\frac{x}{\mathrm{f}\left(y\right)}"), + ("f(x) / y", r"\frac{\mathrm{f}\left(x\right)}{y}"), + ("x + f(y)", r"x + \mathrm{f}\left(y\right)"), + ("f(x) + y", r"\mathrm{f}\left(x\right) + y"), + # With UnaryOp + ("x**-y", r"x^{-y}"), + ("(-x)**y", r"\left( -x \right)^{y}"), + ("x * -y", r"x -y"), # TODO(odashi): google/latexify_py#89 + ("-x * y", r"-x y"), + ("x / -y", r"\frac{x}{-y}"), + ("-x / y", r"\frac{-x}{y}"), + ("x + -y", r"x + -y"), + ("-x + y", r"-x + y"), + # With Compare + ("x**(y == z)", r"x^{{y = z}}"), + ("(x == y)**z", r"\left( {x = y} \right)^{z}"), + ("x * (y == z)", r"x \left( {y = z} \right)"), + ("(x == y) * z", r"\left( {x = y} \right) z"), + ("x / (y == z)", r"\frac{x}{{y = z}}"), + ("(x == y) / z", r"\frac{{x = y}}{z}"), + ("x + (y == z)", r"x + \left( {y = z} \right)"), + ("(x == y) + z", r"\left( {x = y} \right) + z"), + # With BoolOp + ("x**(y and z)", r"x^{{y \land z}}"), + ("(x and 
y)**z", r"\left( {x \land y} \right)^{z}"), + ("x * (y and z)", r"x \left( {y \land z} \right)"), + ("(x and y) * z", r"\left( {x \land y} \right) z"), + ("x / (y and z)", r"\frac{x}{{y \land z}}"), + ("(x and y) / z", r"\frac{{x \land y}}{z}"), + ("x + (y and z)", r"x + \left( {y \land z} \right)"), + ("(x and y) + z", r"\left( {x \land y} \right) + z"), ], ) def test_visit_binop(code: str, latex: str) -> None: @@ -334,6 +370,42 @@ def test_visit_binop(code: str, latex: str) -> None: assert function_codegen.FunctionCodegen().visit(tree) == latex +@pytest.mark.parametrize( + "code,latex", + [ + # With literals + ("+x", r"+x"), + ("-x", r"-x"), + ("~x", r"\mathord{\sim} x"), + ("not x", r"\lnot x"), + # With Call + ("+f(x)", r"+\mathrm{f}\left(x\right)"), + ("-f(x)", r"-\mathrm{f}\left(x\right)"), + ("~f(x)", r"\mathord{\sim} \mathrm{f}\left(x\right)"), + ("not f(x)", r"\lnot \mathrm{f}\left(x\right)"), + # With BinOp + ("+(x + y)", r"+\left( x + y \right)"), + ("-(x + y)", r"-\left( x + y \right)"), + ("~(x + y)", r"\mathord{\sim} \left( x + y \right)"), + ("not x + y", r"\lnot \left( x + y \right)"), + # With Compare + ("+(x == y)", r"+\left( {x = y} \right)"), + ("-(x == y)", r"-\left( {x = y} \right)"), + ("~(x == y)", r"\mathord{\sim} \left( {x = y} \right)"), + ("not x == y", r"\lnot \left( {x = y} \right)"), + # With BoolOp + ("+(x and y)", r"+\left( {x \land y} \right)"), + ("-(x and y)", r"-\left( {x \land y} \right)"), + ("~(x and y)", r"\mathord{\sim} \left( {x \land y} \right)"), + ("not (x and y)", r"\lnot \left( {x \land y} \right)"), + ], +) +def test_visit_unaryop(code: str, latex: str) -> None: + tree = ast.parse(code).body[0].value + assert isinstance(tree, ast.UnaryOp) + assert function_codegen.FunctionCodegen().visit(tree) == latex + + @pytest.mark.parametrize( "code,latex", [ @@ -366,6 +438,20 @@ def test_visit_binop(code: str, latex: str) -> None: ("a <= b == c", r"{a \le b = c}"), ("a <= b < c", r"{a \le b < c}"), ("a <= b <= c", r"{a \le b 
\le c}"), + # With Call + ("a == f(b)", r"{a = \mathrm{f}\left(b\right)}"), + ("f(a) == b", r"{\mathrm{f}\left(a\right) = b}"), + # With BinOp + ("a == b + c", r"{a = b + c}"), + ("a + b == c", r"{a + b = c}"), + # With UnaryOp + ("a == -b", r"{a = -b}"), + ("-a == b", r"{-a = b}"), + ("a == (not b)", r"{a = \lnot b}"), + ("(not a) == b", r"{\lnot a = b}"), + # With BoolOp + ("a == (b and c)", r"{a = \left( {b \land c} \right)}"), + ("(a and b) == c", r"{\left( {a \land b} \right) = c}"), ], ) def test_visit_compare(code: str, latex: str) -> None: @@ -377,16 +463,35 @@ def test_visit_compare(code: str, latex: str) -> None: @pytest.mark.parametrize( "code,latex", [ - ("a and b", r"{\left( a \right) \land \left( b \right)}"), - ( - "a and b and c", - r"{\left( a \right) \land \left( b \right) \land \left( c \right)}", - ), - ("a or b", r"{\left( a \right) \lor \left( b \right)}"), - ( - "a or b or c", - r"{\left( a \right) \lor \left( b \right) \lor \left( c \right)}", - ), + # With literals + ("a and b", r"{a \land b}"), + ("a and b and c", r"{a \land b \land c}"), + ("a or b", r"{a \lor b}"), + ("a or b or c", r"{a \lor b \lor c}"), + ("a or b and c", r"{a \lor {b \land c}}"), + ("(a or b) and c", r"{\left( {a \lor b} \right) \land c}"), + ("a and b or c", r"{{a \land b} \lor c}"), + ("a and (b or c)", r"{a \land \left( {b \lor c} \right)}"), + # With Call + ("a and f(b)", r"{a \land \mathrm{f}\left(b\right)}"), + ("f(a) and b", r"{\mathrm{f}\left(a\right) \land b}"), + ("a or f(b)", r"{a \lor \mathrm{f}\left(b\right)}"), + ("f(a) or b", r"{\mathrm{f}\left(a\right) \lor b}"), + # With BinOp + ("a and b + c", r"{a \land b + c}"), + ("a + b and c", r"{a + b \land c}"), + ("a or b + c", r"{a \lor b + c}"), + ("a + b or c", r"{a + b \lor c}"), + # With UnaryOp + ("a and not b", r"{a \land \lnot b}"), + ("not a and b", r"{\lnot a \land b}"), + ("a or not b", r"{a \lor \lnot b}"), + ("not a or b", r"{\lnot a \lor b}"), + # With Compare + ("a and b == c", r"{a \land {b = 
c}}"), + ("a == b and c", r"{{a = b} \land c}"), + ("a or b == c", r"{a \lor {b = c}}"), + ("a == b or c", r"{{a = b} \lor c}"), ], ) def test_visit_boolop(code: str, latex: str) -> None: