Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some comprehensions and set operations #59

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 55 additions & 21 deletions src/latexify/latexify_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from latexify import math_symbols
from latexify import node_visitor_base


class LatexifyVisitor(node_visitor_base.NodeVisitorBase):
"""Latexify AST visitor."""

Expand Down Expand Up @@ -44,13 +43,13 @@ def __init__(

self.assign_var = {}

def generic_visit(self, node, action):
def generic_visit(self, node, _):
return str(node)

def visit_Module(self, node, action): # pylint: disable=invalid-name
def visit_Module(self, node, _): # pylint: disable=invalid-name
return self.visit(node.body[0], "multi_lines")

def visit_FunctionDef(self, node, action): # pylint: disable=invalid-name
def visit_FunctionDef(self, node, _): # pylint: disable=invalid-name
name_str = str(node.name)
if self._use_raw_function_name:
name_str = name_str.replace(r"_", r"\_")
Expand Down Expand Up @@ -93,33 +92,33 @@ def visit_FunctionDef_multi_lines(self, node):
)

def visit_FunctionDef_in_line(self, node):
name_str, arg_strs, assign_vars, body_str = self.visit_FunctionDef(node, None)
_, _, assign_vars, body_str = self.visit_FunctionDef(node, None)
return "".join(assign_vars) + body_str

def visit_Assign(self, node, action):
def visit_Assign(self, node, _):
var = self.visit(node.value)
if self._reduce_assignments:
self.assign_var[node.targets[0].id] = rf"\left( {var} \right)"
return None
else:
return rf"{node.targets[0].id} \triangleq {var} \\ "

def visit_Return(self, node, action): # pylint: disable=invalid-name
def visit_Return(self, node, _): # pylint: disable=invalid-name
return self.visit(node.value)

def visit_Tuple(self, node, action): # pylint: disable=invalid-name
def visit_Tuple(self, node, _): # pylint: disable=invalid-name
elts = [self.visit(i) for i in node.elts]
return r"\left( " + r"\space,\space ".join(elts) + r"\right) "

def visit_List(self, node, action): # pylint: disable=invalid-name
def visit_List(self, node, _): # pylint: disable=invalid-name
elts = [self.visit(i) for i in node.elts]
return r"\left[ " + r"\space,\space ".join(elts) + r"\right] "

def visit_Set(self, node, action): # pylint: disable=invalid-name
def visit_Set(self, node, _): # pylint: disable=invalid-name
elts = [self.visit(i) for i in node.elts]
return r"\left\{ " + r"\space,\space ".join(elts) + r"\right\} "

def visit_Call(self, node, action): # pylint: disable=invalid-name
def visit_Call(self, node, _): # pylint: disable=invalid-name
"""Visit a call node."""

def _decorated_lstr_and_arg(node, callee_str, lstr):
Expand Down Expand Up @@ -167,26 +166,26 @@ def _decorated_lstr_and_arg(node, callee_str, lstr):
lstr, arg_str = _decorated_lstr_and_arg(node, callee_str, lstr)
return lstr + arg_str + rstr

def visit_Attribute(self, node, action): # pylint: disable=invalid-name
def visit_Attribute(self, node, _): # pylint: disable=invalid-name
vstr = self.visit(node.value)
astr = str(node.attr)
return vstr + "." + astr

def visit_Name(self, node, action): # pylint: disable=invalid-name
def visit_Name(self, node, _): # pylint: disable=invalid-name
if self._reduce_assignments and node.id in self.assign_var.keys():
return self.assign_var[node.id]

return self._math_symbol_converter.convert(str(node.id))

def visit_Constant(self, node, action): # pylint: disable=invalid-name
def visit_Constant(self, node, _): # pylint: disable=invalid-name
# for python >= 3.8
return str(node.n)

def visit_Num(self, node, action): # pylint: disable=invalid-name
def visit_Num(self, node, _): # pylint: disable=invalid-name
# for python < 3.8
return str(node.n)

def visit_UnaryOp(self, node, action): # pylint: disable=invalid-name
def visit_UnaryOp(self, node, _): # pylint: disable=invalid-name
"""Visit a unary op node."""

def _wrap(child):
Expand All @@ -207,7 +206,7 @@ def _wrap(child):
return reprs[type(node.op)]()
return r"\mathrm{unknown\_uniop}(" + self.visit(node.operand) + ")"

def visit_BinOp(self, node, action): # pylint: disable=invalid-name
def visit_BinOp(self, node, _): # pylint: disable=invalid-name
"""Visit a binary op node."""
priority = constants.BIN_OP_PRIORITY

Expand All @@ -225,6 +224,18 @@ def _wrap(child):

lhs = node.left
rhs = node.right

left_type, right_type = type(node.left), type(node.right)
if left_type == right_type == ast.Set:
Comment on lines +228 to +229
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is OK to go into the block if only either lhs or rhs is a Set.

Suggested change
left_type, right_type = type(node.left), type(node.right)
if left_type == right_type == ast.Set:
if isinstance(left_type, ast.Set) or isinstance(right_type, ast.Set):

set_reprs = {
ast.BitOr: (lambda: _unwrap(lhs) + r" \cup " + _unwrap(rhs)),
ast.BitAnd: (lambda: _unwrap(lhs) + r" \cap " + _unwrap(rhs)),
ast.Sub: (lambda: _unwrap(lhs) + r" \setminus " + _unwrap(rhs)),
ast.BitXor: (lambda: _unwrap(lhs) + r" \triangle " + _unwrap(rhs)),
}
if type(node.op) in set_reprs:
return set_reprs[type(node.op)]()

reprs = {
ast.Add: (lambda: _wrap(lhs) + " + " + _wrap(rhs)),
ast.Sub: (lambda: _wrap(lhs) + " - " + _wrap(rhs)),
Expand All @@ -246,7 +257,7 @@ def _wrap(child):
return reprs[type(node.op)]()
return r"\mathrm{unknown\_binop}(" + _unwrap(lhs) + ", " + _unwrap(rhs) + ")"

def visit_Compare(self, node, action): # pylint: disable=invalid-name
def visit_Compare(self, node, _): # pylint: disable=invalid-name
"""Visit a compare node."""
lstr = self.visit(node.left)
rstr = self.visit(node.comparators[0])
Expand All @@ -265,10 +276,14 @@ def visit_Compare(self, node, action): # pylint: disable=invalid-name
return lstr + r"\ne " + rstr
if isinstance(node.ops[0], ast.Is):
return lstr + r"\equiv" + rstr
if isinstance(node.ops[0], ast.In):
return lstr + r"\in " + rstr
if isinstance(node.ops[0], ast.NotIn):
return lstr + r"\notin " + rstr

return r"\mathrm{unknown\_comparator}(" + lstr + ", " + rstr + ")"

def visit_BoolOp(self, node, action): # pylint: disable=invalid-name
def visit_BoolOp(self, node, _): # pylint: disable=invalid-name
logic_operator = (
r"\lor "
if isinstance(node.op, ast.Or)
Expand All @@ -287,7 +302,7 @@ def visit_BoolOp(self, node, action): # pylint: disable=invalid-name
+ r"\right)"
)

def visit_If(self, node, action): # pylint: disable=invalid-name
def visit_If(self, node, _): # pylint: disable=invalid-name
"""Visit an if node."""
latex = r"\left\{ \begin{array}{ll} "

Expand All @@ -300,6 +315,25 @@ def visit_If(self, node, action): # pylint: disable=invalid-name
latex += self.visit(node)
return latex + r", & \mathrm{otherwise} \end{array} \right."

def visit_SetComp(self, node, _): # pylint: disable=invalid-name
result = eval(compile(ast.Expression(node), '', 'eval'))
Copy link
Collaborator

@odashi odashi Oct 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eval is not allowed due to its vulnerability. Since latexify accepts almost any code, this invocation may cause an arbitrary code execution. it may be better to adopt ast.literal_eval (it also allows a raw syntax tree).

Same for other parts.


sorted_result = sorted(list(result))

return f"{{{str(sorted_result)[1:-1]}}}"

def visit_ListComp(self, node, _): # pylint: disable=invalid-name
result = eval(compile(ast.Expression(node), '', 'eval'))

sorted_result = sorted(result)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Especially for lists, this is not what the users wanted because lists are order-sensitive structure.


return str(sorted_result)

def visit_DictComp(self, node, _): # pylint: disable=invalid-name
result = eval(compile(ast.Expression(node), '', 'eval'))

return str(result)

def visit_GeneratorExp_set_bounds(self, node): # pylint: disable=invalid-name
action = constants.actions.SET_BOUNDS
output = self.visit(node.elt)
Expand All @@ -309,7 +343,7 @@ def visit_GeneratorExp_set_bounds(self, node): # pylint: disable=invalid-name
if len(comprehensions) == 1:
return output, comprehensions[0]
raise TypeError(
"visit_GeneratorExp_sum() supports a single for clause"
"visit_GeneratorExp_set_bounds() supports a single for clause"
"but {} were given".format(len(comprehensions))
)

Expand Down
68 changes: 68 additions & 0 deletions src/latexify/latexify_visitor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2020 Google LLC
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file looks to contain only end-to-end tests. It may be good to either:

  • Move it into src/integration_tests
  • Rewrite tests to directly call LatexifyVisitor.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for core."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Tests for core."""
"""Tests for latexify.latexify_visitor."""


import pytest

from latexify import with_latex

@with_latex
def list_comp():
return [x for x in range(1, 5) if x > 2]

def test_list_comps():
assert str(list_comp) == r"\mathrm{list_comp}() \triangleq [3, 4]"

@with_latex
def set_comp():
return {x for x in range(1, 5) if x > 2}

def test_set_comps():
assert str(set_comp) == r"\mathrm{set_comp}() \triangleq {3, 4}"

@with_latex
def dict_comp():
return {x: x + 1 for x in range(1, 5) if x > 2}

def test_dict_comps():
assert str(dict_comp) == r"\mathrm{dict_comp}() \triangleq {3: 4, 4: 5}"

@with_latex
def set_or():
return {1} | {2, 3}

def test_set_ops():
assert str(set_or) == r"\mathrm{set_or}() \triangleq \left\{ 1\right\} \cup \left\{ 2\space,\space 3\right\} "

@with_latex
def set_union():
return {1} & {1, 2}

def test_set_union():
assert str(set_union) == r"\mathrm{set_union}() \triangleq \left\{ 1\right\} \cap \left\{ 1\space,\space 2\right\} "

@with_latex
def set_xor():
return {1} ^ {1, 2}

def test_set_xor():
assert str(set_xor) == r"\mathrm{set_xor}() \triangleq \left\{ 1\right\} \triangle \left\{ 1\space,\space 2\right\} "

@with_latex
def set_sub():
return {1, 2} - {1}

def test_set_sub():
print(set_sub)
assert str(set_sub) == r"\mathrm{set_sub}() \triangleq \left\{ 1\space,\space 2\right\} \setminus \left\{ 1\right\} "
5 changes: 2 additions & 3 deletions src/latexify/node_visitor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# This is very scratchy and supports only limited portion of Python functions.
"""Definition of NodeVisitorBase class."""


class NodeVisitorBase:
"""Base class of LaTeXify's AST visitor.

Expand Down Expand Up @@ -46,7 +45,7 @@ class NodeVisitorBase:
visitor.visit(node, '123abc')
"""

def visit(self, node, action: str = None):
def visit(self, node, action: str | None = None):
"""Visits a node with specified action.

Args:
Expand All @@ -69,6 +68,6 @@ def visit(self, node, action: str = None):

raise AttributeError("{} is not callable.".format(method))

def generic_visit(self, node, action):
def generic_visit(self, *_):
"""Visitor method for all nodes without specific visitors."""
raise NotImplementedError("LatexifyVisitorBase.generic_visit")