From 214dbcde598debd5553cb6c54aaf8df8c3f13438 Mon Sep 17 00:00:00 2001 From: odashi Date: Thu, 3 Nov 2022 17:19:30 +0000 Subject: [PATCH 1/3] add some utils to analyze range. --- src/integration_tests/regression_test.py | 5 +- src/latexify/analyzers.py | 61 +++++++ src/latexify/analyzers_test.py | 152 ++++++++++++++++++ src/latexify/ast_utils.py | 88 ++++++++++ src/latexify/ast_utils_test.py | 118 ++++++++++++++ src/latexify/codegen/function_codegen.py | 46 +++--- src/latexify/codegen/function_codegen_test.py | 2 +- src/latexify/test_utils.py | 16 +- .../transformers/assignment_reducer_test.py | 12 +- 9 files changed, 454 insertions(+), 46 deletions(-) create mode 100644 src/latexify/analyzers.py create mode 100644 src/latexify/analyzers_test.py create mode 100644 src/latexify/ast_utils.py create mode 100644 src/latexify/ast_utils_test.py diff --git a/src/integration_tests/regression_test.py b/src/integration_tests/regression_test.py index 13228b4..74975d7 100644 --- a/src/integration_tests/regression_test.py +++ b/src/integration_tests/regression_test.py @@ -89,7 +89,7 @@ def sum_with_limit(n): return sum(i**2 for i in range(n)) latex = ( - r"\mathrm{sum_with_limit}(n) = \sum_{i = 0}^{{n - 1}} \left({i^{{2}}}\right)" + r"\mathrm{sum_with_limit}(n) = \sum_{i = {0}}^{{n - 1}} \left({i^{{2}}}\right)" ) _check_function(sum_with_limit, latex) @@ -109,7 +109,8 @@ def prod_with_limit(n): return math.prod(i**2 for i in range(n)) latex = ( - r"\mathrm{prod_with_limit}(n) = \prod_{i = 0}^{{n - 1}} \left({i^{{2}}}\right)" + r"\mathrm{prod_with_limit}(n) = " + r"\prod_{i = {0}}^{{n - 1}} \left({i^{{2}}}\right)" ) _check_function(prod_with_limit, latex) diff --git a/src/latexify/analyzers.py b/src/latexify/analyzers.py new file mode 100644 index 0000000..56b4a2b --- /dev/null +++ b/src/latexify/analyzers.py @@ -0,0 +1,61 @@ +"""Analyzer functions for specific subtrees.""" + +from __future__ import annotations + +import ast +import dataclasses + +from latexify import ast_utils, exceptions + + +@dataclasses.dataclass(frozen=True, eq=False) +class RangeInfo: + """Information of the range function.""" + + # Argument subtrees. These areguments could be shallow copies of the original + # subtree. + start: ast.expr + stop: ast.expr + step: ast.expr + + # Integer representation of each argument, when it is possible. + start_int: int | None + stop_int: int | None + step_int: int | None + + +def analyze_range(node: ast.Call) -> RangeInfo: + """Obtains RangeInfo from a Call subtree. + + Args: + node: Subtree to be analyzed. + + Returns: + RangeInfo extracted from `node`. + """ + if not ( + isinstance(node.func, ast.Name) + and node.func.id == "range" + and 1 <= len(node.args) <= 3 + ): + raise exceptions.LatexifySyntaxError("Unsupported AST for analyze_range.") + + num_args = len(node.args) + + if num_args == 1: + start = ast_utils.make_constant(0) + stop = node.args[0] + step = ast_utils.make_constant(1) + else: + start = node.args[0] + stop = node.args[1] + step = node.args[2] if num_args == 3 else ast_utils.make_constant(1) + + return RangeInfo( + start=start, + stop=stop, + step=step, + start_int=ast_utils.extract_int_or_none(start), + stop_int=ast_utils.extract_int_or_none(stop), + step_int=ast_utils.extract_int_or_none(step), + ) diff --git a/src/latexify/analyzers_test.py b/src/latexify/analyzers_test.py new file mode 100644 index 0000000..b536ebc --- /dev/null +++ b/src/latexify/analyzers_test.py @@ -0,0 +1,152 @@ +"""Tests for latexify.analyzers.""" + +from __future__ import annotations + +import ast + +import pytest + +from latexify import analyzers, exceptions, test_utils + + +@test_utils.require_at_least(8) +@pytest.mark.parametrize( + "code,start,stop,step,start_int,stop_int,step_int", + [ + ( + "range(x)", + ast.Constant(value=0), + ast.Name(id="x", ctx=ast.Load()), + ast.Constant(value=1), + 0, + None, + 1, + ), + ( + "range(123)", + ast.Constant(value=0), + ast.Constant(value=123), + ast.Constant(value=1), + 0, + 123, + 1, + ), + ( + "range(x, y)", + ast.Name(id="x", ctx=ast.Load()), + ast.Name(id="y", ctx=ast.Load()), + ast.Constant(value=1), + None, + None, + 1, + ), + ( + "range(123, y)", + ast.Constant(value=123), + ast.Name(id="y", ctx=ast.Load()), + ast.Constant(value=1), + 123, + None, + 1, + ), + ( + "range(x, 123)", + ast.Name(id="x", ctx=ast.Load()), + ast.Constant(value=123), + ast.Constant(value=1), + None, + 123, + 1, + ), + ( + "range(x, y, z)", + ast.Name(id="x", ctx=ast.Load()), + ast.Name(id="y", ctx=ast.Load()), + ast.Name(id="z", ctx=ast.Load()), + None, + None, + None, + ), + ( + "range(123, y, z)", + ast.Constant(value=123), + ast.Name(id="y", ctx=ast.Load()), + ast.Name(id="z", ctx=ast.Load()), + 123, + None, + None, + ), + ( + "range(x, 123, z)", + ast.Name(id="x", ctx=ast.Load()), + ast.Constant(value=123), + ast.Name(id="z", ctx=ast.Load()), + None, + 123, + None, + ), + ( + "range(x, y, 123)", + ast.Name(id="x", ctx=ast.Load()), + ast.Name(id="y", ctx=ast.Load()), + ast.Constant(value=123), + None, + None, + 123, + ), + ], +) +def test_analyze_range( + code: str, + start: ast.expr, + stop: ast.expr, + step: ast.expr | None, + start_int: int | None, + stop_int: int | None, + step_int: int | None, +) -> None: + node = ast.parse(code).body[0].value + assert isinstance(node, ast.Call) + + info = analyzers.analyze_range(node) + + test_utils.assert_ast_equal(observed=info.start, expected=start) + test_utils.assert_ast_equal(observed=info.stop, expected=stop) + if step is not None: + test_utils.assert_ast_equal(observed=info.step, expected=step) + else: + assert info.step is None + + def check_int(observed: int | None, expected: int | None) -> None: + if expected is not None: + assert observed == expected + else: + assert observed is None + + check_int(observed=info.start_int, expected=start_int) + check_int(observed=info.stop_int, expected=stop_int) + check_int(observed=info.step_int, expected=step_int) + + +@pytest.mark.parametrize( + "code", + [ + # Not a direct call + "__builtins__.range(x)", + 'getattr(__builtins__, "range")(x)', + # Unsupported functions + "f(x)", + "iter(range(x))", + # Range with invalid arguments + "range()", + "range(x, y, z, w)", + ], +) +def test_analyze_range_invalid(code: str) -> None: + node = ast.parse(code).body[0].value + assert isinstance(node, ast.Call) + + with pytest.raises( + exceptions.LatexifySyntaxError, match=r"^Unsupported AST for analyze_range\.$" + ): + analyzers.analyze_range(node) diff --git a/src/latexify/ast_utils.py b/src/latexify/ast_utils.py new file mode 100644 index 0000000..2d2e6c9 --- /dev/null +++ b/src/latexify/ast_utils.py @@ -0,0 +1,88 @@ +"""Utilities to generate AST nodes.""" + +from __future__ import annotations + +import ast +import sys +from typing import Any + + +def make_constant(value: Any) -> ast.expr: + """Generates a new Constant node. + + Args: + value: Value of the node. + + Returns: + Generated ast.Constant or its equivalent. + + Raises: + ValueError: Unsupported value type. + """ + if sys.version_info.minor < 8: + if value is None or value is False or value is True: + return ast.NameConstant(value=value) + if value is ...: + return ast.Ellipsis() + if isinstance(value, (int, float, complex)): + return ast.Num(n=value) + if isinstance(value, str): + return ast.Str(s=value) + if isinstance(value, bytes): + return ast.Bytes(s=value) + else: + if ( + value is None + or value is ... + or isinstance(value, (bool, int, float, complex, str, bytes)) + ): + return ast.Constant(value=value) + + raise ValueError(f"Unsupported type to generate Constant: {type(value).__name__}") + + +def extract_int_or_none(node: ast.expr) -> int | None: + """Extracts int constant from the given Constant node. + + Args: + node: ast.Constant or its equivalent representing an int value. + + Returns: + Extracted int value, or None if extraction failed. + """ + if sys.version_info.minor < 8: + if ( + isinstance(node, ast.Num) + and isinstance(node.n, int) + and not isinstance(node.n, bool) + ): + return node.n + else: + if ( + isinstance(node, ast.Constant) + and isinstance(node.value, int) + and not isinstance(node.n, bool) + ): + return node.value + + return None + + +def extract_int(node: ast.expr) -> int: + """Extracts int constant from the given Constant node. + + Args: + node: ast.Constant or its equivalent representing an int value. + + Returns: + Extracted int value. + + Raises: + ValueError: Not a subtree containing an int value. + """ + value = extract_int_or_none(node) + + if value is None: + raise ValueError(f"Unsupported node to extract int: {type(node).__name__}") + + return value diff --git a/src/latexify/ast_utils_test.py b/src/latexify/ast_utils_test.py new file mode 100644 index 0000000..f01d182 --- /dev/null +++ b/src/latexify/ast_utils_test.py @@ -0,0 +1,118 @@ +"""Tests for latexify.ast_utils.""" + +from __future__ import annotations + +import ast +from typing import Any + +import pytest + +from latexify import ast_utils +from latexify import test_utils + + +@test_utils.require_at_most(7) +@pytest.mark.parametrize( + "value,expected", + [ + (None, ast.NameConstant(value=None)), + (False, ast.NameConstant(value=False)), + (True, ast.NameConstant(value=True)), + (..., ast.Ellipsis()), + (123, ast.Num(n=123)), + (4.5, ast.Num(n=4.5)), + (6 + 7j, ast.Num(n=6 + 7j)), + ("foo", ast.Str(s="foo")), + (b"bar", ast.Bytes(s=b"bar")), + ], +) +def test_make_constant_legacy(value: Any, expected: ast.Constant) -> None: + test_utils.assert_ast_equal( + observed=ast_utils.make_constant(value), + expected=expected, + ) + + +@test_utils.require_at_least(8) +@pytest.mark.parametrize( + "value,expected", + [ + (None, ast.Constant(value=None)), + (False, ast.Constant(value=False)), + (True, ast.Constant(value=True)), + (..., ast.Constant(value=...)), + (123, ast.Constant(value=123)), + (4.5, ast.Constant(value=4.5)), + (6 + 7j, ast.Constant(value=6 + 7j)), + ("foo", ast.Constant(value="foo")), + (b"bar", ast.Constant(value=b"bar")), + ], +) +def test_make_constant(value: Any, expected: ast.Constant) -> None: + test_utils.assert_ast_equal( + observed=ast_utils.make_constant(value), + expected=expected, + ) + + +def test_make_constant_invalid() -> None: + with pytest.raises(ValueError, match=r"^Unsupported type to generate"): + ast_utils.make_constant(object()) + + +def test_extract_int_or_none() -> None: + assert ast_utils.extract_int_or_none(ast_utils.make_constant(-123)) == -123 + assert ast_utils.extract_int_or_none(ast_utils.make_constant(0)) == 0 + assert ast_utils.extract_int_or_none(ast_utils.make_constant(123)) == 123 + + +def test_extract_int_or_none_invalid() -> None: + # Not a subtree. + assert ast_utils.extract_int_or_none(123) is None + + # Not a direct Constant node. + assert ( + ast_utils.extract_int_or_none(ast.Expr(value=ast_utils.make_constant(123))) + is None + ) + + # Not a Constant node with int. + assert ast_utils.extract_int_or_none(ast_utils.make_constant(None)) is None + assert ast_utils.extract_int_or_none(ast_utils.make_constant(True)) is None + assert ast_utils.extract_int_or_none(ast_utils.make_constant(...)) is None + assert ast_utils.extract_int_or_none(ast_utils.make_constant(123.0)) is None + assert ast_utils.extract_int_or_none(ast_utils.make_constant(4 + 5j)) is None + assert ast_utils.extract_int_or_none(ast_utils.make_constant("123")) is None + assert ast_utils.extract_int_or_none(ast_utils.make_constant(b"123")) is None + + +def test_extract_int() -> None: + assert ast_utils.extract_int(ast_utils.make_constant(-123)) == -123 + assert ast_utils.extract_int(ast_utils.make_constant(0)) == 0 + assert ast_utils.extract_int(ast_utils.make_constant(123)) == 123 + + +def test_extract_int_invalid() -> None: + # Not a subtree. + with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): + ast_utils.extract_int(123) + + # Not a direct Constant node. + with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): + ast_utils.extract_int(ast.Expr(value=ast_utils.make_constant(123))) + + # Not a Constant node with int. + with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): + ast_utils.extract_int(ast_utils.make_constant(None)) + with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): + ast_utils.extract_int(ast_utils.make_constant(True)) + with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): + ast_utils.extract_int(ast_utils.make_constant(...)) + with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): + ast_utils.extract_int(ast_utils.make_constant(123.0)) + with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): + ast_utils.extract_int(ast_utils.make_constant(4 + 5j)) + with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): + ast_utils.extract_int(ast_utils.make_constant("123")) + with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): + ast_utils.extract_int(ast_utils.make_constant(b"123")) diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index ba6fc41..09ac8cf 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -5,6 +5,7 @@ import ast from typing import Any, ClassVar +from latexify import analyzers from latexify import constants from latexify import math_symbols from latexify import exceptions @@ -321,34 +322,35 @@ def _get_range_info(self, node: ast.GeneratorExp) -> tuple[str, str, str, str]: comp = node.generators[0] - if not ( - isinstance(comp.iter, ast.Call) - and isinstance(comp.iter.func, ast.Name) - and comp.iter.func.id == "range" - and 1 <= len(comp.iter.args) <= 2 - and not comp.ifs - ): - raise exceptions.LatexifySyntaxError( - "Comprehension with range contains unsupported syntax." + if not isinstance(comp.iter, ast.Call) or comp.ifs: + raise exceptions.LatexifyNotSupportedError("Unsupported comprehension.") + + # May cause LatexifyError + range_info = analyzers.analyze_range(comp.iter) + + if ( + # Only accepts ascending order with step size 1. + range_info.step_int != 1 + or ( + range_info.start_int is not None + and range_info.stop_int is not None + and range_info.start_int >= range_info.stop_int ) + ): + raise exceptions.LatexifyNotSupportedError("Unsupported comprehension.") elt_str = self.visit(node.elt) target_str = self.visit(comp.target) - args_str = [self.visit(arg) for arg in comp.iter.args] - if len(args_str) == 1: - lower_str = "0" - upper_plus_1 = args_str[0] + if range_info.start_int is None: + lower_str = self.visit(range_info.start) + else: + lower_str = f"{{{range_info.start_int}}}" + + if range_info.stop_int is None: + upper_str = "{" + self.visit(range_info.stop) + " - 1}" else: - lower_str = args_str[0] - upper_plus_1 = args_str[1] - - # Upper bound of range is exclusive. Try to numerically subtract it by 1. - try: - upper_plus_1_unwrapped = upper_plus_1[1:-1] # Remove { and } - upper_str = str(int(upper_plus_1_unwrapped) - 1) - except ValueError: - upper_str = "{" + upper_plus_1 + " - 1}" + upper_str = f"{{{range_info.stop_int - 1}}}" # Used for: \sum_{target_str = lower_str}^{upper_str} elt_str return elt_str, target_str, lower_str, upper_str diff --git a/src/latexify/codegen/function_codegen_test.py b/src/latexify/codegen/function_codegen_test.py index 78d87df..51873c7 100644 --- a/src/latexify/codegen/function_codegen_test.py +++ b/src/latexify/codegen/function_codegen_test.py @@ -1,4 +1,4 @@ -"""Tests for latexify.latexify_visitor.""" +"""Tests for latexify.codegen.function_codegen.""" from __future__ import annotations diff --git a/src/latexify/test_utils.py b/src/latexify/test_utils.py index 0f6651f..63ea1ec 100644 --- a/src/latexify/test_utils.py +++ b/src/latexify/test_utils.py @@ -87,6 +87,7 @@ def ast_equal(observed: ast.AST, expected: ast.AST) -> bool: for co, ce in zip(vo, ve) ) else: + assert type(vo) is type(ve) assert vo == ve except (AssertionError, AttributeError): @@ -121,18 +122,3 @@ def assert_ast_equal(observed: ast.AST, expected: ast.AST) -> None: observed={ast.dump(observed)} expected={ast.dump(expected)} """ - - -def make_num(value: int) -> ast.expr: - """Helper function to generate a node for number. - - Args: - value: The value of the node. - - Returns: - Generated AST. - """ - if sys.version_info.minor < 8: - return ast.Num(n=value) - else: - return ast.Constant(value=value) diff --git a/src/latexify/transformers/assignment_reducer_test.py b/src/latexify/transformers/assignment_reducer_test.py index 0cc4d09..84288f0 100644 --- a/src/latexify/transformers/assignment_reducer_test.py +++ b/src/latexify/transformers/assignment_reducer_test.py @@ -3,7 +3,7 @@ from __future__ import annotations import ast -from latexify import parser, test_utils +from latexify import ast_utils, parser, test_utils from latexify.transformers.assignment_reducer import AssignmentReducer @@ -54,7 +54,7 @@ def f(x): expected = _make_ast( [ - ast.Return(value=test_utils.make_num(0)), + ast.Return(value=ast_utils.make_constant(0)), ] ) transformed = AssignmentReducer().visit(parser.parse_function(f)) @@ -70,7 +70,7 @@ def f(x): [ ast.Return( value=ast.BinOp( - left=test_utils.make_num(2), + left=ast_utils.make_constant(2), op=ast.Mult(), right=ast.Name(id="x", ctx=ast.Load()), ) @@ -91,10 +91,10 @@ def f(x): [ ast.Return( value=ast.BinOp( - left=test_utils.make_num(3), + left=ast_utils.make_constant(3), op=ast.Add(), right=ast.BinOp( - left=test_utils.make_num(2), + left=ast_utils.make_constant(2), op=ast.Mult(), right=ast.Name(id="x", ctx=ast.Load()), ), @@ -116,7 +116,7 @@ def f(x): [ ast.Return( value=ast.BinOp( - left=test_utils.make_num(3), + left=ast_utils.make_constant(3), op=ast.Add(), right=ast.Name(id="x", ctx=ast.Load()), ) From ceee173a8eac0f7f3ce42f049440d4aa5ab03897 Mon Sep 17 00:00:00 2001 From: odashi Date: Fri, 4 Nov 2022 10:45:15 +0000 Subject: [PATCH 2/3] Support some forms in sum/prod --- src/latexify/analyzers.py | 3 + src/latexify/codegen/function_codegen.py | 101 +++++++++++++----- src/latexify/codegen/function_codegen_test.py | 49 +++++++++ 3 files changed, 127 insertions(+), 26 deletions(-) diff --git a/src/latexify/analyzers.py b/src/latexify/analyzers.py index 56b4a2b..a55f68e 100644 --- a/src/latexify/analyzers.py +++ b/src/latexify/analyzers.py @@ -32,6 +32,9 @@ def analyze_range(node: ast.Call) -> RangeInfo: Returns: RangeInfo extracted from `node`. + + Raises: + LatexifySyntaxError: Analysis failed. """ if not ( isinstance(node.func, ast.Name) diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index 09ac8cf..4b569d7 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -138,9 +138,8 @@ def visit_Call(self, node: ast.Call) -> str: ) if func_str in ("sum", "prod") and isinstance(node.args[0], ast.GeneratorExp): - # Special treatment for sum/prod(x for x in range(a, b)) - elt, tgt, lo, up = self._get_range_info(node.args[0]) - return rf"\{func_str}_{{{tgt} = {lo}}}^{{{up}}} \left({{{elt}}}\right)" + elt, lower, upper = self._get_sum_prod_info(node.args[0]) + return rf"\{func_str}_{{{lower}}}^{{{upper}}} \left({{{elt}}}\right)" arg_strs = [self.visit(arg) for arg in node.args] return lstr + ", ".join(arg_strs) + rstr @@ -311,22 +310,29 @@ def visit_If(self, node: ast.If) -> str: latex += self.visit(node) return latex + r", & \mathrm{otherwise} \end{array} \right." - def _get_range_info(self, node: ast.GeneratorExp) -> tuple[str, str, str, str]: - """Processor for (x for x in range(a, b))""" + def _get_sum_prod_range(self, node: ast.comprehension) -> tuple[str, str] | None: + """Helper to process range(...) for sum and prod functions. - # TODO(odashi): This could be supported. - if len(node.generators) != 1: - raise exceptions.LatexifyNotSupportedError( - "Multi-clause comprehension is not supported." - ) - - comp = node.generators[0] + Args: + node: comprehension node to be analyzed. - if not isinstance(comp.iter, ast.Call) or comp.ifs: - raise exceptions.LatexifyNotSupportedError("Unsupported comprehension.") + Returns: + Tuple of following strings: + - lower_rhs + - upper + which are used in _get_sum_prod_info, or None if the analysis failed. + """ + if not ( + isinstance(node.iter, ast.Call) + and isinstance(node.iter.func, ast.Name) + and node.iter.func.id == "range" + ): + return None - # May cause LatexifyError - range_info = analyzers.analyze_range(comp.iter) + try: + range_info = analyzers.analyze_range(node.iter) + except exceptions.LatexifyError: + return None if ( # Only accepts ascending order with step size 1. @@ -337,23 +343,66 @@ def _get_range_info(self, node: ast.GeneratorExp) -> tuple[str, str, str, str]: and range_info.start_int >= range_info.stop_int ) ): - raise exceptions.LatexifyNotSupportedError("Unsupported comprehension.") - - elt_str = self.visit(node.elt) - target_str = self.visit(comp.target) + return None if range_info.start_int is None: - lower_str = self.visit(range_info.start) + lower_rhs = self.visit(range_info.start) else: - lower_str = f"{{{range_info.start_int}}}" + lower_rhs = f"{{{range_info.start_int}}}" if range_info.stop_int is None: - upper_str = "{" + self.visit(range_info.stop) + " - 1}" + upper = "{" + self.visit(range_info.stop) + " - 1}" + else: + upper = f"{{{range_info.stop_int - 1}}}" + + return lower_rhs, upper + + def _get_sum_prod_info(self, node: ast.GeneratorExp) -> tuple[str, str, str]: + r"""Process GeneratorExp for sum and prod functions. + + Args: + node: GeneratorExp node to be analyzed. + + Returns: + Tuple of following strings: + - elt + - lower + - upper + which are used to represent sum/prod operators as follows: + "\sum_{lower}^{upper} {elt}" + + Raises: + LateixfyError: Unsupported AST is given. + """ + + # TODO(odashi): This could be supported. + if len(node.generators) != 1: + raise exceptions.LatexifyNotSupportedError( + "Multi-clause comprehension is not supported." + ) + + comp = node.generators[0] + + # TODO(odashi): This could be supported. + if comp.ifs: + raise exceptions.LatexifyNotSupportedError( + "If-clause in comprehension is not supported." + ) + + elt = self.visit(node.elt) + target = self.visit(comp.target) + + range_args = self._get_sum_prod_range(comp) + + if range_args is not None: + lower_rhs, upper = range_args + lower = f"{target} = {lower_rhs}" else: - upper_str = f"{{{range_info.stop_int - 1}}}" + lower_rhs = self.visit(comp.iter) + lower = rf"{target} \in {lower_rhs}" + upper = "" - # Used for: \sum_{target_str = lower_str}^{upper_str} elt_str - return elt_str, target_str, lower_str, upper_str + return elt, lower, upper # Until 3.8 def visit_Index(self, node: ast.Index) -> str: diff --git a/src/latexify/codegen/function_codegen_test.py b/src/latexify/codegen/function_codegen_test.py index 51873c7..47da113 100644 --- a/src/latexify/codegen/function_codegen_test.py +++ b/src/latexify/codegen/function_codegen_test.py @@ -41,6 +41,55 @@ def test_visit_functiondef_use_signature() -> None: assert FunctionCodegen(use_signature=True).visit(tree) == latex_with_flag +@pytest.mark.parametrize( + "src_suffix,dest_suffix", + [ + ("(x)", r" \left({x}\right)"), + ("([1, 2])", r" \left({\left[ {1}\space,\space {2}\right] }\right)"), + ("({1, 2})", r" \left({\left\{ {1}\space,\space {2}\right\} }\right)"), + ("(f(x))", r" \left({\mathrm{f}\left(x\right)}\right)"), + ("(i for i in x)", r"_{i \in x}^{} \left({i}\right)"), + ( + "(i for i in [1, 2])", + r"_{i \in \left[ {1}\space,\space {2}\right] }^{} \left({i}\right)", + ), + ( + "(i for i in {1, 2})", + r"_{i \in \left\{ {1}\space,\space {2}\right\} }^{} \left({i}\right)", + ), + ("(i for i in f(x))", r"_{i \in \mathrm{f}\left(x\right)}^{} \left({i}\right)"), + ("(i for i in range(n))", r"_{i = {0}}^{{n - 1}} \left({i}\right)"), + ("(i for i in range(3))", r"_{i = {0}}^{{2}} \left({i}\right)"), + ("(i for i in range(n, m))", r"_{i = n}^{{m - 1}} \left({i}\right)"), + ("(i for i in range(1, m))", r"_{i = {1}}^{{m - 1}} \left({i}\right)"), + ("(i for i in range(n, 3))", r"_{i = n}^{{2}} \left({i}\right)"), + ( + "(i for i in range(n, m, k))", + r"_{i \in \mathrm{range}\left(n, m, k\right)}^{} \left({i}\right)", + ), + ], +) +def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None: + for src_fn, dest_fn in [("sum", r"\sum"), ("math.prod", r"\prod")]: + node = ast.parse(src_fn + src_suffix).body[0].value + assert isinstance(node, ast.Call) + assert FunctionCodegen().visit(node) == dest_fn + dest_suffix + + +def test_visit_call_sum_prod_multi_comprehension() -> None: + for fn_name in ["sum", "math.prod"]: + node = ast.parse(f"{fn_name}(i for y in x for i in y)").body[0].value + with pytest.raises(exceptions.LatexifyNotSupportedError, match="^Multi-clause"): + FunctionCodegen().visit(node) + + +def test_visit_call_sum_prod_with_if() -> None: + for fn_name in ["sum", "math.prod"]: + node = ast.parse(f"{fn_name}(i for y in x if y == 0)").body[0].value + with pytest.raises(exceptions.LatexifyNotSupportedError, match="^If-clause"): + FunctionCodegen().visit(node) + + @pytest.mark.parametrize( "code,latex", [ From 688fa175c761949ee28d8b62bc9e77661d1634a4 Mon Sep 17 00:00:00 2001 From: odashi Date: Fri, 4 Nov 2022 12:32:14 +0000 Subject: [PATCH 3/3] support multi-clause comprehension in sum and prod --- src/latexify/analyzers_test.py | 2 +- src/latexify/codegen/function_codegen.py | 59 ++++++++++--------- src/latexify/codegen/function_codegen_test.py | 35 +++++++++-- 3 files changed, 61 insertions(+), 35 deletions(-) diff --git a/src/latexify/analyzers_test.py b/src/latexify/analyzers_test.py index b536ebc..e3d86f7 100644 --- a/src/latexify/analyzers_test.py +++ b/src/latexify/analyzers_test.py @@ -100,7 +100,7 @@ def test_analyze_range( code: str, start: ast.expr, stop: ast.expr, - step: ast.expr | None, + step: ast.expr, start_int: int | None, stop_int: int | None, step_int: int | None, diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index 4b569d7..291ac58 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -138,8 +138,9 @@ def visit_Call(self, node: ast.Call) -> str: ) if func_str in ("sum", "prod") and isinstance(node.args[0], ast.GeneratorExp): - elt, lower, upper = self._get_sum_prod_info(node.args[0]) - return rf"\{func_str}_{{{lower}}}^{{{upper}}} \left({{{elt}}}\right)" + elt, scripts = self._get_sum_prod_info(node.args[0]) + scripts_str = [rf"\{func_str}_{{{lo}}}^{{{up}}}" for lo, up in scripts] + return " ".join(scripts_str) + rf" \left({{{elt}}}\right)" arg_strs = [self.visit(arg) for arg in node.args] return lstr + ", ".join(arg_strs) + rstr @@ -357,7 +358,9 @@ def _get_sum_prod_range(self, node: ast.comprehension) -> tuple[str, str] | None return lower_rhs, upper - def _get_sum_prod_info(self, node: ast.GeneratorExp) -> tuple[str, str, str]: + def _get_sum_prod_info( + self, node: ast.GeneratorExp + ) -> tuple[str, list[tuple[str, str]]]: r"""Process GeneratorExp for sum and prod functions. Args: @@ -366,43 +369,41 @@ def _get_sum_prod_info(self, node: ast.GeneratorExp) -> tuple[str, str, str]: Returns: Tuple of following strings: - elt - - lower - - upper + - scripts which are used to represent sum/prod operators as follows: - "\sum_{lower}^{upper} {elt}" + \sum_{scripts[0][0]}^{scripts[0][1]} + \sum_{scripts[1][0]}^{scripts[1][1]} + ... + {elt} Raises: LateixfyError: Unsupported AST is given. """ + elt = self.visit(node.elt) - # TODO(odashi): This could be supported. - if len(node.generators) != 1: - raise exceptions.LatexifyNotSupportedError( - "Multi-clause comprehension is not supported." - ) - - comp = node.generators[0] + scripts: list[tuple[str, str]] = [] - # TODO(odashi): This could be supported. - if comp.ifs: - raise exceptions.LatexifyNotSupportedError( - "If-clause in comprehension is not supported." - ) + for comp in node.generators: + # TODO(odashi): This could be supported. + if comp.ifs: + raise exceptions.LatexifyNotSupportedError( + "If-clause in comprehension is not supported." + ) - elt = self.visit(node.elt) - target = self.visit(comp.target) + target = self.visit(comp.target) + range_args = self._get_sum_prod_range(comp) - range_args = self._get_sum_prod_range(comp) + if range_args is not None: + lower_rhs, upper = range_args + lower = f"{target} = {lower_rhs}" + else: + lower_rhs = self.visit(comp.iter) + lower = rf"{target} \in {lower_rhs}" + upper = "" - if range_args is not None: - lower_rhs, upper = range_args - lower = f"{target} = {lower_rhs}" - else: - lower_rhs = self.visit(comp.iter) - lower = rf"{target} \in {lower_rhs}" - upper = "" + scripts.append((lower, upper)) - return elt, lower, upper + return elt, scripts # Until 3.8 def visit_Index(self, node: ast.Index) -> str: diff --git a/src/latexify/codegen/function_codegen_test.py b/src/latexify/codegen/function_codegen_test.py index 47da113..402c634 100644 --- a/src/latexify/codegen/function_codegen_test.py +++ b/src/latexify/codegen/function_codegen_test.py @@ -44,10 +44,12 @@ def test_visit_functiondef_use_signature() -> None: @pytest.mark.parametrize( "src_suffix,dest_suffix", [ + # No comprehension ("(x)", r" \left({x}\right)"), ("([1, 2])", r" \left({\left[ {1}\space,\space {2}\right] }\right)"), ("({1, 2})", r" \left({\left\{ {1}\space,\space {2}\right\} }\right)"), ("(f(x))", r" \left({\mathrm{f}\left(x\right)}\right)"), + # Single comprehension ("(i for i in x)", r"_{i \in x}^{} \left({i}\right)"), ( "(i for i in [1, 2])", @@ -76,11 +78,34 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None: assert FunctionCodegen().visit(node) == dest_fn + dest_suffix -def test_visit_call_sum_prod_multi_comprehension() -> None: - for fn_name in ["sum", "math.prod"]: - node = ast.parse(f"{fn_name}(i for y in x for i in y)").body[0].value - with pytest.raises(exceptions.LatexifyNotSupportedError, match="^Multi-clause"): - FunctionCodegen().visit(node) +@pytest.mark.parametrize( + "code,latex", + [ + # 2 clauses + ( + "sum(i for y in x for i in y)", + r"\sum_{y \in x}^{} \sum_{i \in y}^{} \left({i}\right)", + ), + ( + "sum(i for y in x for z in y for i in z)", + r"\sum_{y \in x}^{} \sum_{z \in y}^{} \sum_{i \in z}^{} \left({i}\right)", + ), + # 3 clauses + ( + "math.prod(i for y in x for i in y)", + r"\prod_{y \in x}^{} \prod_{i \in y}^{} \left({i}\right)", + ), + ( + "math.prod(i for y in x for z in y for i in z)", + r"\prod_{y \in x}^{} \prod_{z \in y}^{} \prod_{i \in z}^{} " + r"\left({i}\right)", + ), + ], +) +def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> None: + node = ast.parse(code).body[0].value + assert isinstance(node, ast.Call) + assert FunctionCodegen().visit(node) == latex def test_visit_call_sum_prod_with_if() -> None: