diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index 755f4bc..41ad1ee 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -144,6 +144,17 @@ class BinOpRule: ast.BitOr: BinOpRule("", r" \mathbin{|} ", ""), } +# Typeset for BinOp of sets. +_SET_BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = { + **_BIN_OP_RULES, + ast.Sub: BinOpRule( + "", r" \setminus ", "", operand_right=BinOperandRule(force=True) + ), + ast.BitAnd: BinOpRule("", r" \cap ", ""), + ast.BitXor: BinOpRule("", r" \mathbin{\triangle} ", ""), + ast.BitOr: BinOpRule("", r" \cup ", ""), +} + _UNARY_OPS: dict[type[ast.unaryop], str] = { ast.Invert: r"\mathord{\sim} ", ast.UAdd: "+", # Explicitly adds the $+$ operator. @@ -164,6 +175,15 @@ class BinOpRule: ast.NotIn: r"\notin", } +# Typeset for Compare of sets. +_SET_COMPARE_OPS: dict[type[ast.cmpop], str] = { + **_COMPARE_OPS, + ast.Gt: r"\supset", + ast.GtE: r"\supseteq", + ast.Lt: r"\subset", + ast.LtE: r"\subseteq", +} + _BOOL_OPS: dict[type[ast.boolop], str] = { ast.And: r"\land", ast.Or: r"\lor", @@ -181,12 +201,16 @@ class FunctionCodegen(ast.NodeVisitor): _use_raw_function_name: bool _use_signature: bool + _bin_op_rules: dict[type[ast.operator], BinOpRule] + _compare_ops: dict[type[ast.cmpop], str] + def __init__( self, *, use_math_symbols: bool = False, use_raw_function_name: bool = False, use_signature: bool = True, + use_set_symbols: bool = False, ) -> None: """Initializer. @@ -197,6 +221,7 @@ def __init__( or convert it to subscript. use_signature: Whether to add the function signature before the expression or not. + use_set_symbols: Whether to use set symbols or not. """ self._math_symbol_converter = math_symbols.MathSymbolConverter( enabled=use_math_symbols @@ -204,6 +229,9 @@ def __init__( self._use_raw_function_name = use_raw_function_name self._use_signature = use_signature + self._bin_op_rules = _SET_BIN_OP_RULES if use_set_symbols else _BIN_OP_RULES + self._compare_ops = _SET_COMPARE_OPS if use_set_symbols else _COMPARE_OPS + def generic_visit(self, node: ast.AST) -> str: raise exceptions.LatexifyNotSupportedError( f"Unsupported AST: {type(node).__name__}" @@ -445,7 +473,7 @@ def _wrap_binop_operand( def visit_BinOp(self, node: ast.BinOp) -> str: """Visit a BinOp node.""" prec = _get_precedence(node) - rule = _BIN_OP_RULES[type(node.op)] + rule = self._bin_op_rules[type(node.op)] lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left) rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right) return f"{rule.latex_left}{lhs}{rule.latex_middle}{rhs}{rule.latex_right}" @@ -459,7 +487,7 @@ def visit_Compare(self, node: ast.Compare) -> str: """Visit a compare node.""" parent_prec = _get_precedence(node) lhs = self._wrap_operand(node.left, parent_prec) - ops = [_COMPARE_OPS[type(x)] for x in node.ops] + ops = [self._compare_ops[type(x)] for x in node.ops] rhs = [self._wrap_operand(x, parent_prec) for x in node.comparators] ops_rhs = [f" {o} {r}" for o, r in zip(ops, rhs)] return "{" + lhs + "".join(ops_rhs) + "}" diff --git a/src/latexify/codegen/function_codegen_test.py b/src/latexify/codegen/function_codegen_test.py index 42d246c..923125d 100644 --- a/src/latexify/codegen/function_codegen_test.py +++ b/src/latexify/codegen/function_codegen_test.py @@ -563,3 +563,33 @@ def test_visit_subscript(code: str, latex: str) -> None: tree = ast.parse(code).body[0].value assert isinstance(tree, ast.Subscript) assert function_codegen.FunctionCodegen().visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("a - b", r"a \setminus b"), + ("a & b", r"a \cap b"), + ("a ^ b", r"a \mathbin{\triangle} b"), + ("a | b", r"a \cup b"), + ], +) +def test_use_set_symbols_binop(code: str, latex: str) -> None: + tree = ast.parse(code).body[0].value + assert isinstance(tree, ast.BinOp) + assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("a < b", r"{a \subset b}"), + ("a <= b", r"{a \subseteq b}"), + ("a > b", r"{a \supset b}"), + ("a >= b", r"{a \supseteq b}"), + ], +) +def test_use_set_symbols_compare(code: str, latex: str) -> None: + tree = ast.parse(code).body[0].value + assert isinstance(tree, ast.Compare) + assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex diff --git a/src/latexify/frontend.py b/src/latexify/frontend.py index 0a978dc..3edaada 100644 --- a/src/latexify/frontend.py +++ b/src/latexify/frontend.py @@ -20,6 +20,7 @@ def get_latex( use_math_symbols: bool = False, use_raw_function_name: bool = False, use_signature: bool = True, + use_set_symbols: bool = False, ) -> str: """Obtains LaTeX description from the function's source. @@ -38,6 +39,7 @@ def get_latex( or convert it to subscript. use_signature: Whether to add the function signature before the expression or not. + use_set_symbols: Whether to use set symbols or not. Returns: Generatee LaTeX description. @@ -59,6 +61,7 @@ def get_latex( use_math_symbols=use_math_symbols, use_raw_function_name=use_raw_function_name, use_signature=use_signature, + use_set_symbols=use_set_symbols, ).visit(tree) diff --git a/src/latexify/frontend_test.py b/src/latexify/frontend_test.py index d716347..8c21e05 100644 --- a/src/latexify/frontend_test.py +++ b/src/latexify/frontend_test.py @@ -69,6 +69,18 @@ def f(x): assert frontend.get_latex(f, use_signature=True) == latex_with_flag +def test_get_latex_use_set_symbols() -> None: + def f(x, y): + return x & y + + latex_without_flag = r"\mathrm{f}(x, y) = x \mathbin{\&} y" + latex_with_flag = r"\mathrm{f}(x, y) = x \cap y" + + assert frontend.get_latex(f) == latex_without_flag + assert frontend.get_latex(f, use_set_symbols=False) == latex_without_flag + assert frontend.get_latex(f, use_set_symbols=True) == latex_with_flag + + def test_function() -> None: def f(x): return x