diff --git a/src/integration_tests/regression_test.py b/src/integration_tests/regression_test.py index 89c59e0..19ceac9 100644 --- a/src/integration_tests/regression_test.py +++ b/src/integration_tests/regression_test.py @@ -279,3 +279,23 @@ def solve(a, b): r"a - b \mathclose{}\right) - a b" ) _check_function(solve, latex) + + +def test_docstring_allowed() -> None: + def solve(x): + """The identity function.""" + return x + + latex = r"\mathrm{solve}(x) = x" + _check_function(solve, latex) + + +def test_multiple_constants_allowed() -> None: + def solve(x): + """The identity function.""" + 123 + True + return x + + latex = r"\mathrm{solve}(x) = x" + _check_function(solve, latex) diff --git a/src/latexify/ast_utils.py b/src/latexify/ast_utils.py index 2d2e6c9..56de6d1 100644 --- a/src/latexify/ast_utils.py +++ b/src/latexify/ast_utils.py @@ -41,6 +41,24 @@ def make_constant(value: Any) -> ast.expr: raise ValueError(f"Unsupported type to generate Constant: {type(value).__name__}") +def is_constant(node: ast.AST) -> bool: + """Checks if the node is a constant. + + Args: + node: The node to examine. + + Returns: + True if the node is a constant, False otherwise. + """ + if sys.version_info.minor < 8: + return isinstance( + node, + (ast.Bytes, ast.Constant, ast.Ellipsis, ast.NameConstant, ast.Num, ast.Str), + ) + else: + return isinstance(node, ast.Constant) + + def extract_int_or_none(node: ast.expr) -> int | None: """Extracts int constant from the given Constant node. diff --git a/src/latexify/ast_utils_test.py b/src/latexify/ast_utils_test.py index ab3352a..bcf0e5b 100644 --- a/src/latexify/ast_utils_test.py +++ b/src/latexify/ast_utils_test.py @@ -59,6 +59,37 @@ def test_make_constant_invalid() -> None: ast_utils.make_constant(object()) +@test_utils.require_at_most(7) +@pytest.mark.parametrize( + "value,expected", + [ + (ast.Bytes(s=b"foo"), True), + (ast.Constant("bar"), True), + (ast.Ellipsis(), True), + (ast.NameConstant(value=None), True), + (ast.Num(n=123), True), + (ast.Str(s="baz"), True), + (ast.Expr(value=ast.Num(456)), False), + (ast.Global("qux"), False), + ], +) +def test_is_constant_legacy(value: ast.AST, expected: bool) -> None: + assert ast_utils.is_constant(value) is expected + + +@test_utils.require_at_least(8) +@pytest.mark.parametrize( + "value,expected", + [ + (ast.Constant("foo"), True), + (ast.Expr(value=ast.Constant(123)), False), + (ast.Global("bar"), False), + ], +) +def test_is_constant(value: ast.AST, expected: bool) -> None: + assert ast_utils.is_constant(value) is expected + + def test_extract_int_or_none() -> None: assert ast_utils.extract_int_or_none(ast_utils.make_constant(-123)) == -123 assert ast_utils.extract_int_or_none(ast_utils.make_constant(0)) == 0 diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index f25c9ad..33f8df0 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -7,7 +7,7 @@ import sys from typing import Any -from latexify import analyzers, constants, exceptions, math_symbols +from latexify import analyzers, ast_utils, constants, exceptions, math_symbols # Precedences of operators for BoolOp, BinOp, UnaryOp, and Compare nodes. # Note that this value affects only the appearance of surrounding parentheses for each @@ -253,6 +253,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str: # Assignment statements (if any): x = ... for child in node.body[:-1]: + if isinstance(child, ast.Expr) and ast_utils.is_constant(child.value): + continue + if not isinstance(child, ast.Assign): raise exceptions.LatexifyNotSupportedError( "Codegen supports only Assign nodes in multiline functions, " diff --git a/src/latexify/codegen/function_codegen_test.py b/src/latexify/codegen/function_codegen_test.py index 30f4436..793ab62 100644 --- a/src/latexify/codegen/function_codegen_test.py +++ b/src/latexify/codegen/function_codegen_test.py @@ -3,6 +3,7 @@ from __future__ import annotations import ast +import textwrap import pytest @@ -22,17 +23,16 @@ class UnknownNode(ast.AST): def test_visit_functiondef_use_signature() -> None: - tree = ast.FunctionDef( - name="f", - args=ast.arguments( - args=[ast.arg(arg="x")], - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ), - body=[ast.Return(value=ast.Name(id="x", ctx=ast.Load()))], - decorator_list=[], - ) + tree = ast.parse( + textwrap.dedent( + """ + def f(x): + return x + """ + ) + ).body[0] + assert isinstance(tree, ast.FunctionDef) + latex_without_flag = "x" latex_with_flag = r"\mathrm{f}(x) = x" assert FunctionCodegen().visit(tree) == latex_with_flag @@ -40,6 +40,40 @@ def test_visit_functiondef_use_signature() -> None: assert FunctionCodegen(use_signature=True).visit(tree) == latex_with_flag +def test_visit_functiondef_ignore_docstring() -> None: + tree = ast.parse( + textwrap.dedent( + """ + def f(x): + '''docstring''' + return x + """ + ) + ).body[0] + assert isinstance(tree, ast.FunctionDef) + + latex = r"\mathrm{f}(x) = x" + assert FunctionCodegen().visit(tree) == latex + + +def test_visit_functiondef_ignore_multiple_constants() -> None: + tree = ast.parse( + textwrap.dedent( + """ + def f(x): + '''docstring''' + 3 + True + return x + """ + ) + ).body[0] + assert isinstance(tree, ast.FunctionDef) + + latex = r"\mathrm{f}(x) = x" + assert FunctionCodegen().visit(tree) == latex + + @pytest.mark.parametrize( "code,latex", [