Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support set operations. #94

Merged
merged 1 commit into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions src/latexify/codegen/function_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ class BinOpRule:
ast.BitOr: BinOpRule("", r" \mathbin{|} ", ""),
}

# Typeset for BinOp of sets.
_SET_BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = {
**_BIN_OP_RULES,
ast.Sub: BinOpRule(
"", r" \setminus ", "", operand_right=BinOperandRule(force=True)
),
ast.BitAnd: BinOpRule("", r" \cap ", ""),
ast.BitXor: BinOpRule("", r" \mathbin{\triangle} ", ""),
ast.BitOr: BinOpRule("", r" \cup ", ""),
}

_UNARY_OPS: dict[type[ast.unaryop], str] = {
ast.Invert: r"\mathord{\sim} ",
ast.UAdd: "+", # Explicitly adds the $+$ operator.
Expand All @@ -164,6 +175,15 @@ class BinOpRule:
ast.NotIn: r"\notin",
}

# Typeset for Compare of sets.
_SET_COMPARE_OPS: dict[type[ast.cmpop], str] = {
**_COMPARE_OPS,
ast.Gt: r"\supset",
ast.GtE: r"\supseteq",
ast.Lt: r"\subset",
ast.LtE: r"\subseteq",
}

_BOOL_OPS: dict[type[ast.boolop], str] = {
ast.And: r"\land",
ast.Or: r"\lor",
Expand All @@ -181,12 +201,16 @@ class FunctionCodegen(ast.NodeVisitor):
_use_raw_function_name: bool
_use_signature: bool

_bin_op_rules: dict[type[ast.operator], BinOpRule]
_compare_ops: dict[type[ast.cmpop], str]

def __init__(
self,
*,
use_math_symbols: bool = False,
use_raw_function_name: bool = False,
use_signature: bool = True,
use_set_symbols: bool = False,
) -> None:
"""Initializer.

Expand All @@ -197,13 +221,17 @@ def __init__(
or convert it to subscript.
use_signature: Whether to add the function signature before the expression
or not.
use_set_symbols: Whether to use set symbols or not.
"""
self._math_symbol_converter = math_symbols.MathSymbolConverter(
enabled=use_math_symbols
)
self._use_raw_function_name = use_raw_function_name
self._use_signature = use_signature

self._bin_op_rules = _SET_BIN_OP_RULES if use_set_symbols else _BIN_OP_RULES
self._compare_ops = _SET_COMPARE_OPS if use_set_symbols else _COMPARE_OPS

def generic_visit(self, node: ast.AST) -> str:
raise exceptions.LatexifyNotSupportedError(
f"Unsupported AST: {type(node).__name__}"
Expand Down Expand Up @@ -445,7 +473,7 @@ def _wrap_binop_operand(
def visit_BinOp(self, node: ast.BinOp) -> str:
"""Visit a BinOp node."""
prec = _get_precedence(node)
rule = _BIN_OP_RULES[type(node.op)]
rule = self._bin_op_rules[type(node.op)]
lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left)
rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right)
return f"{rule.latex_left}{lhs}{rule.latex_middle}{rhs}{rule.latex_right}"
Expand All @@ -459,7 +487,7 @@ def visit_Compare(self, node: ast.Compare) -> str:
"""Visit a compare node."""
parent_prec = _get_precedence(node)
lhs = self._wrap_operand(node.left, parent_prec)
ops = [_COMPARE_OPS[type(x)] for x in node.ops]
ops = [self._compare_ops[type(x)] for x in node.ops]
rhs = [self._wrap_operand(x, parent_prec) for x in node.comparators]
ops_rhs = [f" {o} {r}" for o, r in zip(ops, rhs)]
return "{" + lhs + "".join(ops_rhs) + "}"
Expand Down
30 changes: 30 additions & 0 deletions src/latexify/codegen/function_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,33 @@ def test_visit_subscript(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
assert isinstance(tree, ast.Subscript)
assert function_codegen.FunctionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
"code,latex",
[
("a - b", r"a \setminus b"),
("a & b", r"a \cap b"),
("a ^ b", r"a \mathbin{\triangle} b"),
("a | b", r"a \cup b"),
],
)
def test_use_set_symbols_binop(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
assert isinstance(tree, ast.BinOp)
assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex


@pytest.mark.parametrize(
"code,latex",
[
("a < b", r"{a \subset b}"),
("a <= b", r"{a \subseteq b}"),
("a > b", r"{a \supset b}"),
("a >= b", r"{a \supseteq b}"),
],
)
def test_use_set_symbols_compare(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
assert isinstance(tree, ast.Compare)
assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex
3 changes: 3 additions & 0 deletions src/latexify/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def get_latex(
use_math_symbols: bool = False,
use_raw_function_name: bool = False,
use_signature: bool = True,
use_set_symbols: bool = False,
) -> str:
"""Obtains LaTeX description from the function's source.

Expand All @@ -38,6 +39,7 @@ def get_latex(
or convert it to subscript.
use_signature: Whether to add the function signature before the expression or
not.
use_set_symbols: Whether to use set symbols or not.

Returns:
Generatee LaTeX description.
Expand All @@ -59,6 +61,7 @@ def get_latex(
use_math_symbols=use_math_symbols,
use_raw_function_name=use_raw_function_name,
use_signature=use_signature,
use_set_symbols=use_set_symbols,
).visit(tree)


Expand Down
12 changes: 12 additions & 0 deletions src/latexify/frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ def f(x):
assert frontend.get_latex(f, use_signature=True) == latex_with_flag


def test_get_latex_use_set_symbols() -> None:
def f(x, y):
return x & y

latex_without_flag = r"\mathrm{f}(x, y) = x \mathbin{\&} y"
latex_with_flag = r"\mathrm{f}(x, y) = x \cap y"

assert frontend.get_latex(f) == latex_without_flag
assert frontend.get_latex(f, use_set_symbols=False) == latex_without_flag
assert frontend.get_latex(f, use_set_symbols=True) == latex_with_flag


def test_function() -> None:
def f(x):
return x
Expand Down