Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multi-clause comprehension in sum and prod. #77

Merged
merged 4 commits into from
Nov 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/latexify/analyzers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_analyze_range(
code: str,
start: ast.expr,
stop: ast.expr,
step: ast.expr | None,
step: ast.expr,
start_int: int | None,
stop_int: int | None,
step_int: int | None,
Expand Down
59 changes: 30 additions & 29 deletions src/latexify/codegen/function_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,9 @@ def visit_Call(self, node: ast.Call) -> str:
)

if func_str in ("sum", "prod") and isinstance(node.args[0], ast.GeneratorExp):
elt, lower, upper = self._get_sum_prod_info(node.args[0])
return rf"\{func_str}_{{{lower}}}^{{{upper}}} \left({{{elt}}}\right)"
elt, scripts = self._get_sum_prod_info(node.args[0])
scripts_str = [rf"\{func_str}_{{{lo}}}^{{{up}}}" for lo, up in scripts]
return " ".join(scripts_str) + rf" \left({{{elt}}}\right)"

arg_strs = [self.visit(arg) for arg in node.args]
return lstr + ", ".join(arg_strs) + rstr
Expand Down Expand Up @@ -357,7 +358,9 @@ def _get_sum_prod_range(self, node: ast.comprehension) -> tuple[str, str] | None

return lower_rhs, upper

def _get_sum_prod_info(self, node: ast.GeneratorExp) -> tuple[str, str, str]:
def _get_sum_prod_info(
self, node: ast.GeneratorExp
) -> tuple[str, list[tuple[str, str]]]:
r"""Process GeneratorExp for sum and prod functions.

Args:
Expand All @@ -366,43 +369,41 @@ def _get_sum_prod_info(self, node: ast.GeneratorExp) -> tuple[str, str, str]:
Returns:
Tuple of following strings:
- elt
- lower
- upper
- scripts
which are used to represent sum/prod operators as follows:
"\sum_{lower}^{upper} {elt}"
\sum_{scripts[0][0]}^{scripts[0][1]}
\sum_{scripts[1][0]}^{scripts[1][1]}
...
{elt}

Raises:
LateixfyError: Unsupported AST is given.
"""
elt = self.visit(node.elt)

# TODO(odashi): This could be supported.
if len(node.generators) != 1:
raise exceptions.LatexifyNotSupportedError(
"Multi-clause comprehension is not supported."
)

comp = node.generators[0]
scripts: list[tuple[str, str]] = []

# TODO(odashi): This could be supported.
if comp.ifs:
raise exceptions.LatexifyNotSupportedError(
"If-clause in comprehension is not supported."
)
for comp in node.generators:
# TODO(odashi): This could be supported.
if comp.ifs:
raise exceptions.LatexifyNotSupportedError(
"If-clause in comprehension is not supported."
)

elt = self.visit(node.elt)
target = self.visit(comp.target)
target = self.visit(comp.target)
range_args = self._get_sum_prod_range(comp)

range_args = self._get_sum_prod_range(comp)
if range_args is not None:
lower_rhs, upper = range_args
lower = f"{target} = {lower_rhs}"
else:
lower_rhs = self.visit(comp.iter)
lower = rf"{target} \in {lower_rhs}"
upper = ""

if range_args is not None:
lower_rhs, upper = range_args
lower = f"{target} = {lower_rhs}"
else:
lower_rhs = self.visit(comp.iter)
lower = rf"{target} \in {lower_rhs}"
upper = ""
scripts.append((lower, upper))

return elt, lower, upper
return elt, scripts

# Until 3.8
def visit_Index(self, node: ast.Index) -> str:
Expand Down
35 changes: 30 additions & 5 deletions src/latexify/codegen/function_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ def test_visit_functiondef_use_signature() -> None:
@pytest.mark.parametrize(
"src_suffix,dest_suffix",
[
# No comprehension
("(x)", r" \left({x}\right)"),
("([1, 2])", r" \left({\left[ {1}\space,\space {2}\right] }\right)"),
("({1, 2})", r" \left({\left\{ {1}\space,\space {2}\right\} }\right)"),
("(f(x))", r" \left({\mathrm{f}\left(x\right)}\right)"),
# Single comprehension
("(i for i in x)", r"_{i \in x}^{} \left({i}\right)"),
(
"(i for i in [1, 2])",
Expand Down Expand Up @@ -76,11 +78,34 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
assert FunctionCodegen().visit(node) == dest_fn + dest_suffix


def test_visit_call_sum_prod_multi_comprehension() -> None:
for fn_name in ["sum", "math.prod"]:
node = ast.parse(f"{fn_name}(i for y in x for i in y)").body[0].value
with pytest.raises(exceptions.LatexifyNotSupportedError, match="^Multi-clause"):
FunctionCodegen().visit(node)
@pytest.mark.parametrize(
"code,latex",
[
# 2 clauses
(
"sum(i for y in x for i in y)",
r"\sum_{y \in x}^{} \sum_{i \in y}^{} \left({i}\right)",
),
(
"sum(i for y in x for z in y for i in z)",
r"\sum_{y \in x}^{} \sum_{z \in y}^{} \sum_{i \in z}^{} \left({i}\right)",
),
# 3 clauses
(
"math.prod(i for y in x for i in y)",
r"\prod_{y \in x}^{} \prod_{i \in y}^{} \left({i}\right)",
),
(
"math.prod(i for y in x for z in y for i in z)",
r"\prod_{y \in x}^{} \prod_{z \in y}^{} \prod_{i \in z}^{} "
r"\left({i}\right)",
),
],
)
def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> None:
node = ast.parse(code).body[0].value
assert isinstance(node, ast.Call)
assert FunctionCodegen().visit(node) == latex


def test_visit_call_sum_prod_with_if() -> None:
Expand Down