Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Entire refactoring of BinOp/UnaryOp/Compare/BoolOp #92

Merged
merged 1 commit into from
Nov 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 157 additions & 70 deletions src/latexify/codegen/function_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,56 +12,143 @@
from latexify import exceptions


@dataclasses.dataclass(frozen=True)
class BinOpRule:
"""Syntax rules for binary operators."""
# Operator precedences used for BoolOp, BinOp, UnaryOp, and Compare nodes.
# A larger number means the operator binds more tightly.
# These numbers only decide where surrounding parentheses are rendered for each
# expression; the AST itself is never affected.
# See also:
# https://docs.python.org/3/reference/expressions.html#operator-precedence
_PRECEDENCES: dict[type[ast.AST], int] = {
    **dict.fromkeys((ast.Pow,), 120),
    **dict.fromkeys((ast.UAdd, ast.USub, ast.Invert), 110),
    **dict.fromkeys((ast.Mult, ast.MatMult, ast.Div, ast.FloorDiv, ast.Mod), 100),
    **dict.fromkeys((ast.Add, ast.Sub), 90),
    **dict.fromkeys((ast.LShift, ast.RShift), 80),
    ast.BitAnd: 70,
    ast.BitXor: 60,
    ast.BitOr: 50,
    # Every comparison operator shares a single precedence level.
    **dict.fromkeys(
        (
            ast.In,
            ast.NotIn,
            ast.Is,
            ast.IsNot,
            ast.Lt,
            ast.LtE,
            ast.Gt,
            ast.GtE,
            ast.NotEq,
            ast.Eq,
        ),
        40,
    ),
    # NOTE(odashi):
    # We assume that the `not` operator has the same precedence as the other unary
    # operators `+`, `-` and `~`, because its LaTeX counterpart $\lnot$ looks as if
    # it had a high precedence. Following Python strictly, the value would be 30.
    ast.Not: 110,
    ast.And: 20,
    ast.Or: 10,
}

# Precedence of this operator.
# See also: https://docs.python.org/3/reference/expressions.html
precedence: int

# Left/middle/right syntaxes to wrap operands.
latex_left: str
latex_middle: str
latex_right: str
def _get_precedence(node: ast.AST) -> int:
    """Obtains the precedence of the subtree.

    Args:
        node: Subtree to investigate.

    Returns:
        If `node` is a subtree with some operator, returns the precedence of the
        operator. Otherwise, returns a number sufficiently larger than every
        operator precedence.
    """
    if isinstance(node, (ast.BoolOp, ast.BinOp, ast.UnaryOp)):
        return _PRECEDENCES[type(node.op)]

    if isinstance(node, ast.Compare):
        # All comparison operators share the same precedence, so it is enough to
        # check only the first operator.
        return _PRECEDENCES[type(node.ops[0])]

    # `node` is not one of the operator nodes handled above.
    return 1_000_000


@dataclasses.dataclass(frozen=True)
class BinOperandRule:
"""Syntax rules for operands of BinOp."""

# Whether to require wrapping operands by parentheses according to the precedence.
wrap_left: bool = True
wrap_right: bool = True
wrap: bool = True

# Whether to require wrapping operands by parentheses if the operand has the same
# precedence as this operator.
# This is used to control the behavior of non-associative operators.
force_left: bool = False
force_right: bool = False
force: bool = False


@dataclasses.dataclass(frozen=True)
class BinOpRule:
    """Syntax rules for BinOp."""

    # Left/middle/right syntaxes to wrap operands. The rendered expression is
    # `latex_left + lhs + latex_middle + rhs + latex_right`.
    latex_left: str
    latex_middle: str
    latex_right: str

    # Operand rules for the left-hand side and right-hand side operands.
    operand_left: BinOperandRule = dataclasses.field(default_factory=BinOperandRule)
    operand_right: BinOperandRule = dataclasses.field(default_factory=BinOperandRule)

    # Whether to assume the resulting syntax is wrapped by some bracket operators.
    # If True, the parent operator can avoid wrapping this operator by parentheses.
    is_wrapped: bool = False


_BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = {
ast.Pow: BinOpRule(70, "", "^{", "}", wrap_right=False, force_left=True),
ast.Mult: BinOpRule(60, "", " ", ""),
ast.MatMult: BinOpRule(60, "", " ", ""),
ast.Div: BinOpRule(60, r"\frac{", "}{", "}", wrap_left=False, wrap_right=False),
ast.Pow: BinOpRule(
"",
"^{",
"}",
operand_left=BinOperandRule(force=True),
operand_right=BinOperandRule(wrap=False),
),
ast.Mult: BinOpRule("", " ", ""),
ast.MatMult: BinOpRule("", " ", ""),
ast.Div: BinOpRule(
r"\frac{",
"}{",
"}",
operand_left=BinOperandRule(wrap=False),
operand_right=BinOperandRule(wrap=False),
),
ast.FloorDiv: BinOpRule(
60,
r"\left\lfloor\frac{",
"}{",
r"}\right\rfloor",
wrap_left=False,
wrap_right=False,
operand_left=BinOperandRule(wrap=False),
operand_right=BinOperandRule(wrap=False),
is_wrapped=True,
),
ast.Mod: BinOpRule(60, "", r" \mathbin{\%} ", "", force_right=True),
ast.Add: BinOpRule(50, "", " + ", ""),
ast.Sub: BinOpRule(50, "", " - ", "", force_right=True),
ast.LShift: BinOpRule(40, "", r" \ll ", "", force_right=True),
ast.RShift: BinOpRule(40, "", r" \gg ", "", force_right=True),
ast.BitAnd: BinOpRule(30, "", r" \mathbin{\&} ", ""),
ast.BitXor: BinOpRule(20, "", r" \oplus ", ""),
ast.BitOr: BinOpRule(10, "", r" \mathbin{|} ", ""),
ast.Mod: BinOpRule(
"", r" \mathbin{\%} ", "", operand_right=BinOperandRule(force=True)
),
ast.Add: BinOpRule("", " + ", ""),
ast.Sub: BinOpRule("", " - ", "", operand_right=BinOperandRule(force=True)),
ast.LShift: BinOpRule("", r" \ll ", "", operand_right=BinOperandRule(force=True)),
ast.RShift: BinOpRule("", r" \gg ", "", operand_right=BinOperandRule(force=True)),
ast.BitAnd: BinOpRule("", r" \mathbin{\&} ", ""),
ast.BitXor: BinOpRule("", r" \oplus ", ""),
ast.BitOr: BinOpRule("", r" \mathbin{|} ", ""),
}

# LaTeX prefix emitted in front of the operand for each unary operator.
_UNARY_OPS: dict[type[ast.unaryop], str] = {
    ast.Invert: r"\mathord{\sim} ",
    # The unary plus is rendered explicitly as the $+$ operator.
    ast.UAdd: "+",
    ast.USub: "-",
    ast.Not: r"\lnot ",
}

_COMPARE_OPS: dict[type[ast.cmpop], str] = {
Expand Down Expand Up @@ -304,83 +391,83 @@ def visit_NameConstant(self, node: ast.NameConstant) -> str:
def visit_Ellipsis(self, node: ast.Ellipsis) -> str:
    """Visit an Ellipsis node.

    Args:
        node: The Ellipsis node. Its contents are not inspected.

    Returns:
        Whatever `self._convert_constant` produces for the `...` constant.
    """
    return self._convert_constant(...)

def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
"""Visit a unary op node."""

def _wrap(child):
latex = self.visit(child)
if isinstance(child, ast.BinOp) and isinstance(
child.op, (ast.Add, ast.Sub)
):
return r"\left(" + latex + r"\right)"
return latex
def _wrap_operand(self, child: ast.expr, parent_prec: int) -> str:
"""Wraps the operand subtree with parentheses.

reprs = {
ast.UAdd: (lambda: _wrap(node.operand)),
ast.USub: (lambda: "-" + _wrap(node.operand)),
ast.Not: (lambda: r"\lnot\left(" + _wrap(node.operand) + r"\right)"),
}
Args:
child: Operand subtree.
parent_prec: Precedence of the parent operator.

if type(node.op) in reprs:
return reprs[type(node.op)]()
return r"\mathrm{unknown\_uniop}(" + self.visit(node.operand) + ")"
Returns:
LaTeX form of `child`, with or without surrounding parentheses.
"""
latex = self.visit(child)
if _get_precedence(child) >= parent_prec:
return latex
return rf"\left( {latex} \right)"

def _wrap_binop_operand(
self,
child: ast.expr,
parent_rule: BinOpRule,
is_left: bool,
parent_prec: int,
operand_rule: BinOperandRule,
) -> str:
"""Wraps the given LaTeX with parenthesis.
"""Wraps the operand subtree of BinOp with parentheses.

Args:
child: Child subtree.
parent_rule: Syntax rule of the parent operator.
is_left: Position of the `child` in the parent operator:
- True: `child` is the left-hand side operand.
- False: `child` is the right-hand side operand.
child: Operand subtree.
parent_prec: Precedence of the parent operator.
operand_rule: Syntax rule of this operand.

Returns:
LaTeX form of the `child`, with or without a surrounding parenthesis.
LaTeX form of the `child`, with or without surrounding parentheses.
"""
latex = self.visit(child)
if not operand_rule.wrap:
return self.visit(child)

if not isinstance(child, ast.BinOp):
return latex
return self._wrap_operand(child, parent_prec)

child_rule = _BIN_OP_RULES[type(child.op)]
wrap = parent_rule.wrap_left if is_left else parent_rule.wrap_right
latex = self.visit(child)

if not wrap or child_rule.is_wrapped:
if _BIN_OP_RULES[type(child.op)].is_wrapped:
return latex

child_prec = child_rule.precedence
parent_prec = parent_rule.precedence
force = parent_rule.force_left if is_left else parent_rule.force_right
child_prec = _get_precedence(child)

if child_prec > parent_prec or (child_prec == parent_prec and not force):
if child_prec > parent_prec or (
child_prec == parent_prec and not operand_rule.force
):
return latex

return rf"\left( {latex} \right)"

def visit_BinOp(self, node: ast.BinOp) -> str:
"""Visit a BinOp node."""
prec = _get_precedence(node)
rule = _BIN_OP_RULES[type(node.op)]
lhs = self._wrap_binop_operand(node.left, rule, is_left=True)
rhs = self._wrap_binop_operand(node.right, rule, is_left=False)
lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left)
rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right)
return f"{rule.latex_left}{lhs}{rule.latex_middle}{rhs}{rule.latex_right}"

def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
    """Visit a unary op node.

    Args:
        node: UnaryOp node to convert.

    Returns:
        The LaTeX prefix of the operator followed by the LaTeX form of the
        operand, as produced by `self._wrap_operand` for this operator's
        precedence.
    """
    operand_latex = self._wrap_operand(node.operand, _get_precedence(node))
    return _UNARY_OPS[type(node.op)] + operand_latex

def visit_Compare(self, node: ast.Compare) -> str:
"""Visit a compare node."""
lhs = self.visit(node.left)
parent_prec = _get_precedence(node)
lhs = self._wrap_operand(node.left, parent_prec)
ops = [_COMPARE_OPS[type(x)] for x in node.ops]
rhs = [self.visit(x) for x in node.comparators]
rhs = [self._wrap_operand(x, parent_prec) for x in node.comparators]
ops_rhs = [f" {o} {r}" for o, r in zip(ops, rhs)]
return "{" + lhs + "".join(ops_rhs) + "}"

def visit_BoolOp(self, node: ast.BoolOp) -> str:
"""Visit a BoolOp node."""
values = [rf"\left( {self.visit(x)} \right)" for x in node.values]
parent_prec = _get_precedence(node)
values = [self._wrap_operand(x, parent_prec) for x in node.values]
op = f" {_BOOL_OPS[type(node.op)]} "
return "{" + op.join(values) + "}"

Expand Down
Loading