Skip to content

Commit

Permalink
support list/set comprehension. (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yusuke Oda authored Nov 11, 2022
1 parent 071a12d commit 18542ad
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 17 deletions.
47 changes: 33 additions & 14 deletions src/latexify/codegen/function_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,37 @@ def visit_Set(self, node: ast.Set) -> str:
elts = [self.visit(i) for i in node.elts]
return r"\left\{ " + r"\space,\space ".join(elts) + r"\right\} "

def visit_ListComp(self, node: ast.ListComp) -> str:
generators = [self.visit(comp) for comp in node.generators]
return (
r"\left[ "
+ self.visit(node.elt)
+ r" \mid "
+ ", ".join(generators)
+ r" \right]"
)

def visit_SetComp(self, node: ast.SetComp) -> str:
generators = [self.visit(comp) for comp in node.generators]
return (
r"\left\{ "
+ self.visit(node.elt)
+ r" \mid "
+ ", ".join(generators)
+ r" \right\}"
)

def visit_comprehension(self, node: ast.comprehension) -> str:
target = rf"{self.visit(node.target)} \in {self.visit(node.iter)}"

if not node.ifs:
# Returns the source without parenthesis.
return target

conds = [target] + [self.visit(cond) for cond in node.ifs]
wrapped = [r"\left( " + s + r" \right)" for s in conds]
return r" \land ".join(wrapped)

def visit_Call(self, node: ast.Call) -> str:
"""Visit a call node."""
# Function signature (possibly an expression).
Expand Down Expand Up @@ -384,28 +415,16 @@ def _get_sum_prod_info(
scripts: list[tuple[str, str]] = []

for comp in node.generators:
target = self.visit(comp.target)
range_args = self._get_sum_prod_range(comp)

if range_args is not None and not comp.ifs:
target = self.visit(comp.target)
lower_rhs, upper = range_args
lower = f"{target} = {lower_rhs}"
else:
lower_rhs = self.visit(comp.iter)
lower_in = rf"{target} \in {lower_rhs}"
lower = self.visit(comp) # Use a usual comprehension form.
upper = ""

if comp.ifs:
conds = [lower_in] + [self.visit(cond) for cond in comp.ifs]
conds_wrapped = [r"\left(" + cond + r"\right)" for cond in conds]
lower = r" \land ".join(conds_wrapped)
# TODO(odashi):
# Following form may be prettier, but requires amsmath.
# It would be good if we have an option to switch the behavior.
# lower = r"\substack{" + r" \\ ".join(lowers) + "}"
else:
lower = lower_in

scripts.append((lower, upper))

return elt, scripts
Expand Down
97 changes: 94 additions & 3 deletions src/latexify/codegen/function_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,96 @@ def test_visit_functiondef_use_signature() -> None:
assert FunctionCodegen(use_signature=True).visit(tree) == latex_with_flag


@pytest.mark.parametrize(
"code,latex",
[
("[i for i in n]", r"\left[ i \mid i \in n \right]"),
(
"[i for i in n if i > 0]",
r"\left[ i \mid"
r" \left( i \in n \right)"
r" \land \left( {i > {0}} \right)"
r" \right]",
),
(
"[i for i in n if i > 0 if f(i)]",
r"\left[ i \mid"
r" \left( i \in n \right)"
r" \land \left( {i > {0}} \right)"
r" \land \left( \mathrm{f}\left(i\right) \right)"
r" \right]",
),
("[i for k in n for i in k]", r"\left[ i \mid k \in n, i \in k" r" \right]"),
(
"[i for k in n for i in k if i > 0]",
r"\left[ i \mid"
r" k \in n,"
r" \left( i \in k \right)"
r" \land \left( {i > {0}} \right)"
r" \right]",
),
(
"[i for k in n if f(k) for i in k if i > 0]",
r"\left[ i \mid"
r" \left( k \in n \right)"
r" \land \left( \mathrm{f}\left(k\right) \right),"
r" \left( i \in k \right)"
r" \land \left( {i > {0}} \right)"
r" \right]",
),
],
)
def test_visit_listcomp(code: str, latex: str) -> None:
node = ast.parse(code).body[0].value
assert isinstance(node, ast.ListComp)
assert FunctionCodegen().visit(node) == latex


@pytest.mark.parametrize(
"code,latex",
[
("{i for i in n}", r"\left\{ i \mid i \in n \right\}"),
(
"{i for i in n if i > 0}",
r"\left\{ i \mid"
r" \left( i \in n \right)"
r" \land \left( {i > {0}} \right)"
r" \right\}",
),
(
"{i for i in n if i > 0 if f(i)}",
r"\left\{ i \mid"
r" \left( i \in n \right)"
r" \land \left( {i > {0}} \right)"
r" \land \left( \mathrm{f}\left(i\right) \right)"
r" \right\}",
),
("{i for k in n for i in k}", r"\left\{ i \mid k \in n, i \in k" r" \right\}"),
(
"{i for k in n for i in k if i > 0}",
r"\left\{ i \mid"
r" k \in n,"
r" \left( i \in k \right)"
r" \land \left( {i > {0}} \right)"
r" \right\}",
),
(
"{i for k in n if f(k) for i in k if i > 0}",
r"\left\{ i \mid"
r" \left( k \in n \right)"
r" \land \left( \mathrm{f}\left(k\right) \right),"
r" \left( i \in k \right)"
r" \land \left( {i > {0}} \right)"
r" \right\}",
),
],
)
def test_visit_setcomp(code: str, latex: str) -> None:
node = ast.parse(code).body[0].value
assert isinstance(node, ast.SetComp)
assert FunctionCodegen().visit(node) == latex


@pytest.mark.parametrize(
"src_suffix,dest_suffix",
[
Expand Down Expand Up @@ -113,12 +203,13 @@ def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> No
[
(
"(i for i in x if i < y)",
r"_{\left(i \in x\right) \land \left({i < y}\right)}^{} \left({i}\right)",
r"_{\left( i \in x \right) \land \left( {i < y} \right)}^{} "
r"\left({i}\right)",
),
(
"(i for i in x if i < y if f(i))",
r"_{\left(i \in x\right) \land \left({i < y}\right)"
r" \land \left(\mathrm{f}\left(i\right)\right)}^{}"
r"_{\left( i \in x \right) \land \left( {i < y} \right)"
r" \land \left( \mathrm{f}\left(i\right) \right)}^{}"
r" \left({i}\right)",
),
],
Expand Down

0 comments on commit 18542ad

Please sign in to comment.