-
Notifications
You must be signed in to change notification settings - Fork 387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add some comprehensions and set operations #59
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,6 @@ | |
from latexify import math_symbols | ||
from latexify import node_visitor_base | ||
|
||
|
||
class LatexifyVisitor(node_visitor_base.NodeVisitorBase): | ||
"""Latexify AST visitor.""" | ||
|
||
|
@@ -44,13 +43,13 @@ def __init__( | |
|
||
self.assign_var = {} | ||
|
||
def generic_visit(self, node, action): | ||
def generic_visit(self, node, _): | ||
return str(node) | ||
|
||
def visit_Module(self, node, action): # pylint: disable=invalid-name | ||
def visit_Module(self, node, _): # pylint: disable=invalid-name | ||
return self.visit(node.body[0], "multi_lines") | ||
|
||
def visit_FunctionDef(self, node, action): # pylint: disable=invalid-name | ||
def visit_FunctionDef(self, node, _): # pylint: disable=invalid-name | ||
name_str = str(node.name) | ||
if self._use_raw_function_name: | ||
name_str = name_str.replace(r"_", r"\_") | ||
|
@@ -93,33 +92,33 @@ def visit_FunctionDef_multi_lines(self, node): | |
) | ||
|
||
def visit_FunctionDef_in_line(self, node): | ||
name_str, arg_strs, assign_vars, body_str = self.visit_FunctionDef(node, None) | ||
_, _, assign_vars, body_str = self.visit_FunctionDef(node, None) | ||
return "".join(assign_vars) + body_str | ||
|
||
def visit_Assign(self, node, action): | ||
def visit_Assign(self, node, _): | ||
var = self.visit(node.value) | ||
if self._reduce_assignments: | ||
self.assign_var[node.targets[0].id] = rf"\left( {var} \right)" | ||
return None | ||
else: | ||
return rf"{node.targets[0].id} \triangleq {var} \\ " | ||
|
||
def visit_Return(self, node, action): # pylint: disable=invalid-name | ||
def visit_Return(self, node, _): # pylint: disable=invalid-name | ||
return self.visit(node.value) | ||
|
||
def visit_Tuple(self, node, action): # pylint: disable=invalid-name | ||
def visit_Tuple(self, node, _): # pylint: disable=invalid-name | ||
elts = [self.visit(i) for i in node.elts] | ||
return r"\left( " + r"\space,\space ".join(elts) + r"\right) " | ||
|
||
def visit_List(self, node, action): # pylint: disable=invalid-name | ||
def visit_List(self, node, _): # pylint: disable=invalid-name | ||
elts = [self.visit(i) for i in node.elts] | ||
return r"\left[ " + r"\space,\space ".join(elts) + r"\right] " | ||
|
||
def visit_Set(self, node, action): # pylint: disable=invalid-name | ||
def visit_Set(self, node, _): # pylint: disable=invalid-name | ||
elts = [self.visit(i) for i in node.elts] | ||
return r"\left\{ " + r"\space,\space ".join(elts) + r"\right\} " | ||
|
||
def visit_Call(self, node, action): # pylint: disable=invalid-name | ||
def visit_Call(self, node, _): # pylint: disable=invalid-name | ||
"""Visit a call node.""" | ||
|
||
def _decorated_lstr_and_arg(node, callee_str, lstr): | ||
|
@@ -167,26 +166,26 @@ def _decorated_lstr_and_arg(node, callee_str, lstr): | |
lstr, arg_str = _decorated_lstr_and_arg(node, callee_str, lstr) | ||
return lstr + arg_str + rstr | ||
|
||
def visit_Attribute(self, node, action): # pylint: disable=invalid-name | ||
def visit_Attribute(self, node, _): # pylint: disable=invalid-name | ||
vstr = self.visit(node.value) | ||
astr = str(node.attr) | ||
return vstr + "." + astr | ||
|
||
def visit_Name(self, node, action): # pylint: disable=invalid-name | ||
def visit_Name(self, node, _): # pylint: disable=invalid-name | ||
if self._reduce_assignments and node.id in self.assign_var.keys(): | ||
return self.assign_var[node.id] | ||
|
||
return self._math_symbol_converter.convert(str(node.id)) | ||
|
||
def visit_Constant(self, node, action): # pylint: disable=invalid-name | ||
def visit_Constant(self, node, _): # pylint: disable=invalid-name | ||
# for python >= 3.8 | ||
return str(node.n) | ||
|
||
def visit_Num(self, node, action): # pylint: disable=invalid-name | ||
def visit_Num(self, node, _): # pylint: disable=invalid-name | ||
# for python < 3.8 | ||
return str(node.n) | ||
|
||
def visit_UnaryOp(self, node, action): # pylint: disable=invalid-name | ||
def visit_UnaryOp(self, node, _): # pylint: disable=invalid-name | ||
"""Visit a unary op node.""" | ||
|
||
def _wrap(child): | ||
|
@@ -207,7 +206,7 @@ def _wrap(child): | |
return reprs[type(node.op)]() | ||
return r"\mathrm{unknown\_uniop}(" + self.visit(node.operand) + ")" | ||
|
||
def visit_BinOp(self, node, action): # pylint: disable=invalid-name | ||
def visit_BinOp(self, node, _): # pylint: disable=invalid-name | ||
"""Visit a binary op node.""" | ||
priority = constants.BIN_OP_PRIORITY | ||
|
||
|
@@ -225,6 +224,18 @@ def _wrap(child): | |
|
||
lhs = node.left | ||
rhs = node.right | ||
|
||
left_type, right_type = type(node.left), type(node.right) | ||
if left_type == right_type == ast.Set: | ||
set_reprs = { | ||
ast.BitOr: (lambda: _unwrap(lhs) + r" \cup " + _unwrap(rhs)), | ||
ast.BitAnd: (lambda: _unwrap(lhs) + r" \cap " + _unwrap(rhs)), | ||
ast.Sub: (lambda: _unwrap(lhs) + r" \setminus " + _unwrap(rhs)), | ||
ast.BitXor: (lambda: _unwrap(lhs) + r" \triangle " + _unwrap(rhs)), | ||
} | ||
if type(node.op) in set_reprs: | ||
return set_reprs[type(node.op)]() | ||
|
||
reprs = { | ||
ast.Add: (lambda: _wrap(lhs) + " + " + _wrap(rhs)), | ||
ast.Sub: (lambda: _wrap(lhs) + " - " + _wrap(rhs)), | ||
|
@@ -246,7 +257,7 @@ def _wrap(child): | |
return reprs[type(node.op)]() | ||
return r"\mathrm{unknown\_binop}(" + _unwrap(lhs) + ", " + _unwrap(rhs) + ")" | ||
|
||
def visit_Compare(self, node, action): # pylint: disable=invalid-name | ||
def visit_Compare(self, node, _): # pylint: disable=invalid-name | ||
"""Visit a compare node.""" | ||
lstr = self.visit(node.left) | ||
rstr = self.visit(node.comparators[0]) | ||
|
@@ -265,10 +276,14 @@ def visit_Compare(self, node, action): # pylint: disable=invalid-name | |
return lstr + r"\ne " + rstr | ||
if isinstance(node.ops[0], ast.Is): | ||
return lstr + r"\equiv" + rstr | ||
if isinstance(node.ops[0], ast.In): | ||
return lstr + r"\in " + rstr | ||
if isinstance(node.ops[0], ast.NotIn): | ||
return lstr + r"\notin " + rstr | ||
|
||
return r"\mathrm{unknown\_comparator}(" + lstr + ", " + rstr + ")" | ||
|
||
def visit_BoolOp(self, node, action): # pylint: disable=invalid-name | ||
def visit_BoolOp(self, node, _): # pylint: disable=invalid-name | ||
logic_operator = ( | ||
r"\lor " | ||
if isinstance(node.op, ast.Or) | ||
|
@@ -287,7 +302,7 @@ def visit_BoolOp(self, node, action): # pylint: disable=invalid-name | |
+ r"\right)" | ||
) | ||
|
||
def visit_If(self, node, action): # pylint: disable=invalid-name | ||
def visit_If(self, node, _): # pylint: disable=invalid-name | ||
"""Visit an if node.""" | ||
latex = r"\left\{ \begin{array}{ll} " | ||
|
||
|
@@ -300,6 +315,25 @@ def visit_If(self, node, action): # pylint: disable=invalid-name | |
latex += self.visit(node) | ||
return latex + r", & \mathrm{otherwise} \end{array} \right." | ||
|
||
def visit_SetComp(self, node, _): # pylint: disable=invalid-name | ||
result = eval(compile(ast.Expression(node), '', 'eval')) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Same for other parts. |
||
|
||
sorted_result = sorted(list(result)) | ||
|
||
return f"{{{str(sorted_result)[1:-1]}}}" | ||
|
||
def visit_ListComp(self, node, _): # pylint: disable=invalid-name | ||
result = eval(compile(ast.Expression(node), '', 'eval')) | ||
|
||
sorted_result = sorted(result) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Especially for lists, this is not what the users wanted because lists are order-sensitive structure. |
||
|
||
return str(sorted_result) | ||
|
||
def visit_DictComp(self, node, _): # pylint: disable=invalid-name | ||
result = eval(compile(ast.Expression(node), '', 'eval')) | ||
|
||
return str(result) | ||
|
||
def visit_GeneratorExp_set_bounds(self, node): # pylint: disable=invalid-name | ||
action = constants.actions.SET_BOUNDS | ||
output = self.visit(node.elt) | ||
|
@@ -309,7 +343,7 @@ def visit_GeneratorExp_set_bounds(self, node): # pylint: disable=invalid-name | |
if len(comprehensions) == 1: | ||
return output, comprehensions[0] | ||
raise TypeError( | ||
"visit_GeneratorExp_sum() supports a single for clause" | ||
"visit_GeneratorExp_set_bounds() supports a single for clause" | ||
"but {} were given".format(len(comprehensions)) | ||
) | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,68 @@ | ||||||
# Copyright 2020 Google LLC | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This file looks to contain only end-to-end tests. It may be good to either:
|
||||||
# | ||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
# you may not use this file except in compliance with the License. | ||||||
# You may obtain a copy of the License at | ||||||
# | ||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||
# | ||||||
# Unless required by applicable law or agreed to in writing, software | ||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
"""Tests for core.""" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
import pytest | ||||||
|
||||||
from latexify import with_latex | ||||||
|
||||||
@with_latex | ||||||
def list_comp(): | ||||||
return [x for x in range(1, 5) if x > 2] | ||||||
|
||||||
def test_list_comps(): | ||||||
assert str(list_comp) == r"\mathrm{list_comp}() \triangleq [3, 4]" | ||||||
|
||||||
@with_latex | ||||||
def set_comp(): | ||||||
return {x for x in range(1, 5) if x > 2} | ||||||
|
||||||
def test_set_comps(): | ||||||
assert str(set_comp) == r"\mathrm{set_comp}() \triangleq {3, 4}" | ||||||
|
||||||
@with_latex | ||||||
def dict_comp(): | ||||||
return {x: x + 1 for x in range(1, 5) if x > 2} | ||||||
|
||||||
def test_dict_comps(): | ||||||
assert str(dict_comp) == r"\mathrm{dict_comp}() \triangleq {3: 4, 4: 5}" | ||||||
|
||||||
@with_latex | ||||||
def set_or(): | ||||||
return {1} | {2, 3} | ||||||
|
||||||
def test_set_ops(): | ||||||
assert str(set_or) == r"\mathrm{set_or}() \triangleq \left\{ 1\right\} \cup \left\{ 2\space,\space 3\right\} " | ||||||
|
||||||
@with_latex | ||||||
def set_union(): | ||||||
return {1} & {1, 2} | ||||||
|
||||||
def test_set_union(): | ||||||
assert str(set_union) == r"\mathrm{set_union}() \triangleq \left\{ 1\right\} \cap \left\{ 1\space,\space 2\right\} " | ||||||
|
||||||
@with_latex | ||||||
def set_xor(): | ||||||
return {1} ^ {1, 2} | ||||||
|
||||||
def test_set_xor(): | ||||||
assert str(set_xor) == r"\mathrm{set_xor}() \triangleq \left\{ 1\right\} \triangle \left\{ 1\space,\space 2\right\} " | ||||||
|
||||||
@with_latex | ||||||
def set_sub(): | ||||||
return {1, 2} - {1} | ||||||
|
||||||
def test_set_sub(): | ||||||
print(set_sub) | ||||||
assert str(set_sub) == r"\mathrm{set_sub}() \triangleq \left\{ 1\space,\space 2\right\} \setminus \left\{ 1\right\} " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is OK to go into the block if only either lhs or rhs is a Set.