From 61cfc885b0dbdc72218bd2726a1eb8761153c288 Mon Sep 17 00:00:00 2001 From: Takashi Idobe Date: Fri, 21 Oct 2022 16:56:20 -0400 Subject: [PATCH] add set ops --- src/latexify/latexify_visitor.py | 76 +++++++++++++++++++-------- src/latexify/latexify_visitor_test.py | 68 ++++++++++++++++++++++++ src/latexify/node_visitor_base.py | 5 +- 3 files changed, 125 insertions(+), 24 deletions(-) create mode 100644 src/latexify/latexify_visitor_test.py diff --git a/src/latexify/latexify_visitor.py b/src/latexify/latexify_visitor.py index 3094a3d..b5ed5c8 100644 --- a/src/latexify/latexify_visitor.py +++ b/src/latexify/latexify_visitor.py @@ -8,7 +8,6 @@ from latexify import math_symbols from latexify import node_visitor_base - class LatexifyVisitor(node_visitor_base.NodeVisitorBase): """Latexify AST visitor.""" @@ -44,13 +43,13 @@ def __init__( self.assign_var = {} - def generic_visit(self, node, action): + def generic_visit(self, node, _): return str(node) - def visit_Module(self, node, action): # pylint: disable=invalid-name + def visit_Module(self, node, _): # pylint: disable=invalid-name return self.visit(node.body[0], "multi_lines") - def visit_FunctionDef(self, node, action): # pylint: disable=invalid-name + def visit_FunctionDef(self, node, _): # pylint: disable=invalid-name name_str = str(node.name) if self._use_raw_function_name: name_str = name_str.replace(r"_", r"\_") @@ -93,10 +92,10 @@ def visit_FunctionDef_multi_lines(self, node): ) def visit_FunctionDef_in_line(self, node): - name_str, arg_strs, assign_vars, body_str = self.visit_FunctionDef(node, None) + _, _, assign_vars, body_str = self.visit_FunctionDef(node, None) return "".join(assign_vars) + body_str - def visit_Assign(self, node, action): + def visit_Assign(self, node, _): var = self.visit(node.value) if self._reduce_assignments: self.assign_var[node.targets[0].id] = rf"\left( {var} \right)" @@ -104,22 +103,22 @@ def visit_Assign(self, node, action): else: return rf"{node.targets[0].id} \triangleq {var} \\ " - def visit_Return(self, node, action): # pylint: disable=invalid-name + def visit_Return(self, node, _): # pylint: disable=invalid-name return self.visit(node.value) - def visit_Tuple(self, node, action): # pylint: disable=invalid-name + def visit_Tuple(self, node, _): # pylint: disable=invalid-name elts = [self.visit(i) for i in node.elts] return r"\left( " + r"\space,\space ".join(elts) + r"\right) " - def visit_List(self, node, action): # pylint: disable=invalid-name + def visit_List(self, node, _): # pylint: disable=invalid-name elts = [self.visit(i) for i in node.elts] return r"\left[ " + r"\space,\space ".join(elts) + r"\right] " - def visit_Set(self, node, action): # pylint: disable=invalid-name + def visit_Set(self, node, _): # pylint: disable=invalid-name elts = [self.visit(i) for i in node.elts] return r"\left\{ " + r"\space,\space ".join(elts) + r"\right\} " - def visit_Call(self, node, action): # pylint: disable=invalid-name + def visit_Call(self, node, _): # pylint: disable=invalid-name """Visit a call node.""" def _decorated_lstr_and_arg(node, callee_str, lstr): @@ -167,26 +166,26 @@ def _decorated_lstr_and_arg(node, callee_str, lstr): lstr, arg_str = _decorated_lstr_and_arg(node, callee_str, lstr) return lstr + arg_str + rstr - def visit_Attribute(self, node, action): # pylint: disable=invalid-name + def visit_Attribute(self, node, _): # pylint: disable=invalid-name vstr = self.visit(node.value) astr = str(node.attr) return vstr + "." + astr - def visit_Name(self, node, action): # pylint: disable=invalid-name + def visit_Name(self, node, _): # pylint: disable=invalid-name if self._reduce_assignments and node.id in self.assign_var.keys(): return self.assign_var[node.id] return self._math_symbol_converter.convert(str(node.id)) - def visit_Constant(self, node, action): # pylint: disable=invalid-name + def visit_Constant(self, node, _): # pylint: disable=invalid-name # for python >= 3.8 return str(node.n) - def visit_Num(self, node, action): # pylint: disable=invalid-name + def visit_Num(self, node, _): # pylint: disable=invalid-name # for python < 3.8 return str(node.n) - def visit_UnaryOp(self, node, action): # pylint: disable=invalid-name + def visit_UnaryOp(self, node, _): # pylint: disable=invalid-name """Visit a unary op node.""" def _wrap(child): @@ -207,7 +206,7 @@ def _wrap(child): return reprs[type(node.op)]() return r"\mathrm{unknown\_uniop}(" + self.visit(node.operand) + ")" - def visit_BinOp(self, node, action): # pylint: disable=invalid-name + def visit_BinOp(self, node, _): # pylint: disable=invalid-name """Visit a binary op node.""" priority = constants.BIN_OP_PRIORITY @@ -225,6 +224,18 @@ def _wrap(child): lhs = node.left rhs = node.right + + left_type, right_type = type(node.left), type(node.right) + if left_type == right_type == ast.Set: + set_reprs = { + ast.BitOr: (lambda: _unwrap(lhs) + r" \cup " + _unwrap(rhs)), + ast.BitAnd: (lambda: _unwrap(lhs) + r" \cap " + _unwrap(rhs)), + ast.Sub: (lambda: _unwrap(lhs) + r" \setminus " + _unwrap(rhs)), + ast.BitXor: (lambda: _unwrap(lhs) + r" \triangle " + _unwrap(rhs)), + } + if type(node.op) in set_reprs: + return set_reprs[type(node.op)]() + reprs = { ast.Add: (lambda: _wrap(lhs) + " + " + _wrap(rhs)), ast.Sub: (lambda: _wrap(lhs) + " - " + _wrap(rhs)), @@ -246,7 +257,7 @@ def _wrap(child): return reprs[type(node.op)]() return r"\mathrm{unknown\_binop}(" + _unwrap(lhs) + ", " + _unwrap(rhs) + ")" - def visit_Compare(self, node, action): # pylint: disable=invalid-name + def visit_Compare(self, node, _): # pylint: disable=invalid-name """Visit a compare node.""" lstr = self.visit(node.left) rstr = self.visit(node.comparators[0]) @@ -265,10 +276,14 @@ def visit_Compare(self, node, action): # pylint: disable=invalid-name return lstr + r"\ne " + rstr if isinstance(node.ops[0], ast.Is): return lstr + r"\equiv" + rstr + if isinstance(node.ops[0], ast.In): + return lstr + r"\in " + rstr + if isinstance(node.ops[0], ast.NotIn): + return lstr + r"\notin " + rstr return r"\mathrm{unknown\_comparator}(" + lstr + ", " + rstr + ")" - def visit_BoolOp(self, node, action): # pylint: disable=invalid-name + def visit_BoolOp(self, node, _): # pylint: disable=invalid-name logic_operator = ( r"\lor " if isinstance(node.op, ast.Or) @@ -287,7 +302,7 @@ def visit_BoolOp(self, node, action): # pylint: disable=invalid-name + r"\right)" ) - def visit_If(self, node, action): # pylint: disable=invalid-name + def visit_If(self, node, _): # pylint: disable=invalid-name """Visit an if node.""" latex = r"\left\{ \begin{array}{ll} " @@ -300,6 +315,25 @@ def visit_If(self, node, action): # pylint: disable=invalid-name latex += self.visit(node) return latex + r", & \mathrm{otherwise} \end{array} \right." + def visit_SetComp(self, node, _): # pylint: disable=invalid-name + result = eval(compile(ast.Expression(node), '', 'eval')) + + sorted_result = sorted(list(result)) + + return f"{{{str(sorted_result)[1:-1]}}}" + + def visit_ListComp(self, node, _): # pylint: disable=invalid-name + result = eval(compile(ast.Expression(node), '', 'eval')) + + sorted_result = sorted(result) + + return str(sorted_result) + + def visit_DictComp(self, node, _): # pylint: disable=invalid-name + result = eval(compile(ast.Expression(node), '', 'eval')) + + return str(result) + def visit_GeneratorExp_set_bounds(self, node): # pylint: disable=invalid-name action = constants.actions.SET_BOUNDS output = self.visit(node.elt) @@ -309,7 +343,7 @@ def visit_GeneratorExp_set_bounds(self, node): # pylint: disable=invalid-name if len(comprehensions) == 1: return output, comprehensions[0] raise TypeError( - "visit_GeneratorExp_sum() supports a single for clause" + "visit_GeneratorExp_set_bounds() supports a single for clause" "but {} were given".format(len(comprehensions)) ) diff --git a/src/latexify/latexify_visitor_test.py b/src/latexify/latexify_visitor_test.py new file mode 100644 index 0000000..ff72aac --- /dev/null +++ b/src/latexify/latexify_visitor_test.py @@ -0,0 +1,68 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for core.""" + +import pytest + +from latexify import with_latex + +@with_latex +def list_comp(): + return [x for x in range(1, 5) if x > 2] + +def test_list_comps(): + assert str(list_comp) == r"\mathrm{list_comp}() \triangleq [3, 4]" + +@with_latex +def set_comp(): + return {x for x in range(1, 5) if x > 2} + +def test_set_comps(): + assert str(set_comp) == r"\mathrm{set_comp}() \triangleq {3, 4}" + +@with_latex +def dict_comp(): + return {x: x + 1 for x in range(1, 5) if x > 2} + +def test_dict_comps(): + assert str(dict_comp) == r"\mathrm{dict_comp}() \triangleq {3: 4, 4: 5}" + +@with_latex +def set_or(): + return {1} | {2, 3} + +def test_set_ops(): + assert str(set_or) == r"\mathrm{set_or}() \triangleq \left\{ 1\right\} \cup \left\{ 2\space,\space 3\right\} " + +@with_latex +def set_union(): + return {1} & {1, 2} + +def test_set_union(): + assert str(set_union) == r"\mathrm{set_union}() \triangleq \left\{ 1\right\} \cap \left\{ 1\space,\space 2\right\} " + +@with_latex +def set_xor(): + return {1} ^ {1, 2} + +def test_set_xor(): + assert str(set_xor) == r"\mathrm{set_xor}() \triangleq \left\{ 1\right\} \triangle \left\{ 1\space,\space 2\right\} " + +@with_latex +def set_sub(): + return {1, 2} - {1} + +def test_set_sub(): + print(set_sub) + assert str(set_sub) == r"\mathrm{set_sub}() \triangleq \left\{ 1\space,\space 2\right\} \setminus \left\{ 1\right\} " diff --git a/src/latexify/node_visitor_base.py b/src/latexify/node_visitor_base.py index 7aa8948..f211ec2 100644 --- a/src/latexify/node_visitor_base.py +++ b/src/latexify/node_visitor_base.py @@ -15,7 +15,6 @@ # This is very scratchy and supports only limited portion of Python functions. """Definition of NodeVisitorBase class.""" - class NodeVisitorBase: """Base class of LaTeXify's AST visitor. @@ -46,7 +45,7 @@ class NodeVisitorBase: visitor.visit(node, '123abc') """ - def visit(self, node, action: str = None): + def visit(self, node, action: str | None = None): """Visits a node with specified action. Args: @@ -69,6 +68,6 @@ def visit(self, node, action: str = None): raise AttributeError("{} is not callable.".format(method)) - def generic_visit(self, node, action): + def generic_visit(self, *_): """Visitor method for all nodes without specific visitors.""" raise NotImplementedError("LatexifyVisitorBase.generic_visit")