Skip to content

Commit

Permalink
add set ops
Browse files Browse the repository at this point in the history
  • Loading branch information
Takashiidobe committed Oct 21, 2022
1 parent 1421ddb commit 61cfc88
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 24 deletions.
76 changes: 55 additions & 21 deletions src/latexify/latexify_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from latexify import math_symbols
from latexify import node_visitor_base


class LatexifyVisitor(node_visitor_base.NodeVisitorBase):
"""Latexify AST visitor."""

Expand Down Expand Up @@ -44,13 +43,13 @@ def __init__(

self.assign_var = {}

def generic_visit(self, node, action):
def generic_visit(self, node, _):
return str(node)

def visit_Module(self, node, action): # pylint: disable=invalid-name
def visit_Module(self, node, _): # pylint: disable=invalid-name
return self.visit(node.body[0], "multi_lines")

def visit_FunctionDef(self, node, action): # pylint: disable=invalid-name
def visit_FunctionDef(self, node, _): # pylint: disable=invalid-name
name_str = str(node.name)
if self._use_raw_function_name:
name_str = name_str.replace(r"_", r"\_")
Expand Down Expand Up @@ -93,33 +92,33 @@ def visit_FunctionDef_multi_lines(self, node):
)

def visit_FunctionDef_in_line(self, node):
name_str, arg_strs, assign_vars, body_str = self.visit_FunctionDef(node, None)
_, _, assign_vars, body_str = self.visit_FunctionDef(node, None)
return "".join(assign_vars) + body_str

def visit_Assign(self, node, action):
def visit_Assign(self, node, _):
var = self.visit(node.value)
if self._reduce_assignments:
self.assign_var[node.targets[0].id] = rf"\left( {var} \right)"
return None
else:
return rf"{node.targets[0].id} \triangleq {var} \\ "

def visit_Return(self, node, action): # pylint: disable=invalid-name
def visit_Return(self, node, _): # pylint: disable=invalid-name
return self.visit(node.value)

def visit_Tuple(self, node, action): # pylint: disable=invalid-name
def visit_Tuple(self, node, _): # pylint: disable=invalid-name
elts = [self.visit(i) for i in node.elts]
return r"\left( " + r"\space,\space ".join(elts) + r"\right) "

def visit_List(self, node, action): # pylint: disable=invalid-name
def visit_List(self, node, _): # pylint: disable=invalid-name
elts = [self.visit(i) for i in node.elts]
return r"\left[ " + r"\space,\space ".join(elts) + r"\right] "

def visit_Set(self, node, action): # pylint: disable=invalid-name
def visit_Set(self, node, _): # pylint: disable=invalid-name
elts = [self.visit(i) for i in node.elts]
return r"\left\{ " + r"\space,\space ".join(elts) + r"\right\} "

def visit_Call(self, node, action): # pylint: disable=invalid-name
def visit_Call(self, node, _): # pylint: disable=invalid-name
"""Visit a call node."""

def _decorated_lstr_and_arg(node, callee_str, lstr):
Expand Down Expand Up @@ -167,26 +166,26 @@ def _decorated_lstr_and_arg(node, callee_str, lstr):
lstr, arg_str = _decorated_lstr_and_arg(node, callee_str, lstr)
return lstr + arg_str + rstr

def visit_Attribute(self, node, action): # pylint: disable=invalid-name
def visit_Attribute(self, node, _): # pylint: disable=invalid-name
vstr = self.visit(node.value)
astr = str(node.attr)
return vstr + "." + astr

def visit_Name(self, node, action): # pylint: disable=invalid-name
def visit_Name(self, node, _): # pylint: disable=invalid-name
if self._reduce_assignments and node.id in self.assign_var.keys():
return self.assign_var[node.id]

return self._math_symbol_converter.convert(str(node.id))

def visit_Constant(self, node, action): # pylint: disable=invalid-name
def visit_Constant(self, node, _): # pylint: disable=invalid-name
# for python >= 3.8
return str(node.n)

def visit_Num(self, node, action): # pylint: disable=invalid-name
def visit_Num(self, node, _): # pylint: disable=invalid-name
# for python < 3.8
return str(node.n)

def visit_UnaryOp(self, node, action): # pylint: disable=invalid-name
def visit_UnaryOp(self, node, _): # pylint: disable=invalid-name
"""Visit a unary op node."""

def _wrap(child):
Expand All @@ -207,7 +206,7 @@ def _wrap(child):
return reprs[type(node.op)]()
return r"\mathrm{unknown\_uniop}(" + self.visit(node.operand) + ")"

def visit_BinOp(self, node, action): # pylint: disable=invalid-name
def visit_BinOp(self, node, _): # pylint: disable=invalid-name
"""Visit a binary op node."""
priority = constants.BIN_OP_PRIORITY

Expand All @@ -225,6 +224,18 @@ def _wrap(child):

lhs = node.left
rhs = node.right

left_type, right_type = type(node.left), type(node.right)
if left_type == right_type == ast.Set:
set_reprs = {
ast.BitOr: (lambda: _unwrap(lhs) + r" \cup " + _unwrap(rhs)),
ast.BitAnd: (lambda: _unwrap(lhs) + r" \cap " + _unwrap(rhs)),
ast.Sub: (lambda: _unwrap(lhs) + r" \setminus " + _unwrap(rhs)),
ast.BitXor: (lambda: _unwrap(lhs) + r" \triangle " + _unwrap(rhs)),
}
if type(node.op) in set_reprs:
return set_reprs[type(node.op)]()

reprs = {
ast.Add: (lambda: _wrap(lhs) + " + " + _wrap(rhs)),
ast.Sub: (lambda: _wrap(lhs) + " - " + _wrap(rhs)),
Expand All @@ -246,7 +257,7 @@ def _wrap(child):
return reprs[type(node.op)]()
return r"\mathrm{unknown\_binop}(" + _unwrap(lhs) + ", " + _unwrap(rhs) + ")"

def visit_Compare(self, node, action): # pylint: disable=invalid-name
def visit_Compare(self, node, _): # pylint: disable=invalid-name
"""Visit a compare node."""
lstr = self.visit(node.left)
rstr = self.visit(node.comparators[0])
Expand All @@ -265,10 +276,14 @@ def visit_Compare(self, node, action): # pylint: disable=invalid-name
return lstr + r"\ne " + rstr
if isinstance(node.ops[0], ast.Is):
return lstr + r"\equiv" + rstr
if isinstance(node.ops[0], ast.In):
return lstr + r"\in " + rstr
if isinstance(node.ops[0], ast.NotIn):
return lstr + r"\notin " + rstr

return r"\mathrm{unknown\_comparator}(" + lstr + ", " + rstr + ")"

def visit_BoolOp(self, node, action): # pylint: disable=invalid-name
def visit_BoolOp(self, node, _): # pylint: disable=invalid-name
logic_operator = (
r"\lor "
if isinstance(node.op, ast.Or)
Expand All @@ -287,7 +302,7 @@ def visit_BoolOp(self, node, action): # pylint: disable=invalid-name
+ r"\right)"
)

def visit_If(self, node, action): # pylint: disable=invalid-name
def visit_If(self, node, _): # pylint: disable=invalid-name
"""Visit an if node."""
latex = r"\left\{ \begin{array}{ll} "

Expand All @@ -300,6 +315,25 @@ def visit_If(self, node, action): # pylint: disable=invalid-name
latex += self.visit(node)
return latex + r", & \mathrm{otherwise} \end{array} \right."

def visit_SetComp(self, node, _): # pylint: disable=invalid-name
result = eval(compile(ast.Expression(node), '', 'eval'))

sorted_result = sorted(list(result))

return f"{{{str(sorted_result)[1:-1]}}}"

def visit_ListComp(self, node, _): # pylint: disable=invalid-name
result = eval(compile(ast.Expression(node), '', 'eval'))

sorted_result = sorted(result)

return str(sorted_result)

def visit_DictComp(self, node, _): # pylint: disable=invalid-name
result = eval(compile(ast.Expression(node), '', 'eval'))

return str(result)

def visit_GeneratorExp_set_bounds(self, node): # pylint: disable=invalid-name
action = constants.actions.SET_BOUNDS
output = self.visit(node.elt)
Expand All @@ -309,7 +343,7 @@ def visit_GeneratorExp_set_bounds(self, node): # pylint: disable=invalid-name
if len(comprehensions) == 1:
return output, comprehensions[0]
raise TypeError(
"visit_GeneratorExp_sum() supports a single for clause"
"visit_GeneratorExp_set_bounds() supports a single for clause"
"but {} were given".format(len(comprehensions))
)

Expand Down
68 changes: 68 additions & 0 deletions src/latexify/latexify_visitor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for core."""

import pytest

from latexify import with_latex

@with_latex
def list_comp():
return [x for x in range(1, 5) if x > 2]

def test_list_comps():
assert str(list_comp) == r"\mathrm{list_comp}() \triangleq [3, 4]"

@with_latex
def set_comp():
return {x for x in range(1, 5) if x > 2}

def test_set_comps():
assert str(set_comp) == r"\mathrm{set_comp}() \triangleq {3, 4}"

@with_latex
def dict_comp():
return {x: x + 1 for x in range(1, 5) if x > 2}

def test_dict_comps():
assert str(dict_comp) == r"\mathrm{dict_comp}() \triangleq {3: 4, 4: 5}"

@with_latex
def set_or():
return {1} | {2, 3}

def test_set_ops():
assert str(set_or) == r"\mathrm{set_or}() \triangleq \left\{ 1\right\} \cup \left\{ 2\space,\space 3\right\} "

@with_latex
def set_union():
return {1} & {1, 2}

def test_set_union():
assert str(set_union) == r"\mathrm{set_union}() \triangleq \left\{ 1\right\} \cap \left\{ 1\space,\space 2\right\} "

@with_latex
def set_xor():
return {1} ^ {1, 2}

def test_set_xor():
assert str(set_xor) == r"\mathrm{set_xor}() \triangleq \left\{ 1\right\} \triangle \left\{ 1\space,\space 2\right\} "

@with_latex
def set_sub():
return {1, 2} - {1}

def test_set_sub():
print(set_sub)
assert str(set_sub) == r"\mathrm{set_sub}() \triangleq \left\{ 1\space,\space 2\right\} \setminus \left\{ 1\right\} "
5 changes: 2 additions & 3 deletions src/latexify/node_visitor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# This is very scratchy and supports only limited portion of Python functions.
"""Definition of NodeVisitorBase class."""


class NodeVisitorBase:
"""Base class of LaTeXify's AST visitor.
Expand Down Expand Up @@ -46,7 +45,7 @@ class NodeVisitorBase:
visitor.visit(node, '123abc')
"""

def visit(self, node, action: str = None):
def visit(self, node, action: str | None = None):
"""Visits a node with specified action.
Args:
Expand All @@ -69,6 +68,6 @@ def visit(self, node, action: str = None):

raise AttributeError("{} is not callable.".format(method))

def generic_visit(self, node, action):
def generic_visit(self, *_):
"""Visitor method for all nodes without specific visitors."""
raise NotImplementedError("LatexifyVisitorBase.generic_visit")

0 comments on commit 61cfc88

Please sign in to comment.