Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the bug in evaluation of expressions with elements of the form Unevaluated[elem] #628

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
209 changes: 119 additions & 90 deletions mathics/core/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
)
from mathics.core.convert.python import from_python
from mathics.core.convert.sympy import SympyExpression, sympy_symbol_prefix
from mathics.core.element import ElementsProperties, EvalMixin, ensure_context
from mathics.core.element import (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we having isort/black wars? isort should be followed by black. isort will respect black's formatting of isort'ed code. However running black and then isort, may cause isort to change the formatting when isort wants to reorder things.

BaseElement,
ElementsProperties,
EvalMixin,
ensure_context,
)
from mathics.core.evaluation import Evaluation
from mathics.core.interrupt import ReturnInterrupt
from mathics.core.structure import LinkedStructure
Expand Down Expand Up @@ -621,31 +626,30 @@ def flatten_with_respect_to_head(
if self._does_not_contain_symbol(head.get_name()):
return self
sub_level = level - 1
do_flatten = False
for element in self._elements:
if element.get_head().sameQ(head) and (
indx_to_flatten = []
for idx, element in enumerate(self._elements):
if (
not pattern_only or element.pattern_sequence
):
do_flatten = True
break
if do_flatten:
new_elements = []
for element in self._elements:
if element.get_head().sameQ(head) and (
not pattern_only or element.pattern_sequence
):
new_element = element.flatten_with_respect_to_head(
head, pattern_only, callback, level=sub_level
)
if callback is not None:
callback(new_element._elements, element)
new_elements.extend(new_element._elements)
else:
new_elements.append(element)
return to_expression_with_specialization(self._head, *new_elements)
else:
) and element.get_head().sameQ(head):
indx_to_flatten.append(idx)

if len(indx_to_flatten) == 0:
return self

new_elements = []
for idx, element in enumerate(self._elements):
if len(indx_to_flatten) > 0 and idx == indx_to_flatten[0]:
indx_to_flatten.pop(0)
new_element = element.flatten_with_respect_to_head(
head, pattern_only, callback, level=sub_level
)
if callback is not None:
callback(new_element._elements, element)
new_elements.extend(new_element._elements)
else:
new_elements.append(element)
return to_expression_with_specialization(self._head, *new_elements)

def get_atoms(self, include_heads=True):
"""Returns a list of atoms involved in the expression."""
# Comment @mmatera: maybe, what we really want here are the Symbol's
Expand Down Expand Up @@ -1029,36 +1033,99 @@ def rewrite_apply_eval_step(self, evaluation) -> Tuple["Expression", bool]:
See also https://mathics-development-guide.readthedocs.io/en/latest/extending/code-overview/evaluation.html#detailed-rewrite-apply-eval-process
"""

# Internal class and functions to handle ``Unevaluated`` elements
class UnevaluatedWrapper(BaseElement):
"""
This class is used to wrap the argument of
elements of the form ``Unevaluated[expr_]``,
to provide the right behaviour under
``flatten_with_respect_to_head``, sort and thread.
The wrapper is removed before step of looking and applying
rules.
If no rule is successfully applied, then the wrapper is converted
again into an expression of the form Unevaluated(expr_)
"""

def __init__(self, expr: BaseElement):
self.expr = expr
self.elements_properties = ElementsProperties(True, True, True)

def get_head(self):
return self.expr.get_head()

def __repr__(self):
return f"<<UnevaluatedWrapper[{self.expr.__repr__()}]>>"

def get_sort_key(self, pattern_sort=False):
# this ensures that when the elements of an expression are
# sorted, elements tagged follows the corresponding untagged
# elements.
return self.expr.get_sort_key(pattern_sort) + ("Unevaluated",)

@property
def is_literal(self):
return False

def flatten_with_respect_to_head(
self, head, pattern_only=False, callback=None, level=100
):
flatten_expr = self.expr.flatten_with_respect_to_head(
head=head, pattern_only=pattern_only, callback=callback, level=level
)
# distribute the tag over the elements.
marked_elements = tuple(
UnevaluatedWrapper(element) for element in flatten_expr._elements
)
return Expression(self.expr._head, *marked_elements)

def strip_unevaluated_wrapper(expr_with_wrappers):
items = (
element.expr if isinstance(element, UnevaluatedWrapper) else element
for element in expr_with_wrappers._elements
)
return Expression(expr_with_wrappers._head, *items)

def restore_unevaluated_from_wrapper(expr_with_wrappers):
items = (
Expression(SymbolUnevaluated, element.expr)
if isinstance(element, UnevaluatedWrapper)
else element
for element in expr_with_wrappers._elements
)
return Expression(expr_with_wrappers._head, *items)

# Step 1 : evaluate the Head and get its Attributes. These attributes, used later, include
# HoldFirst / HoldAll / HoldRest / HoldAllComplete.

# Note: self._head can be not just a symbol, but some arbitrary expression.
# This is what makes expressions in Mathics be M-expressions rather than
# S-expressions.
head = self._head.evaluate(evaluation)

attributes = head.get_attributes(evaluation.definitions)
contains_unevaluated_wrapper = False

if self.elements_properties is None:
self._build_elements_properties()

# @timeit
def eval_elements():
nonlocal recompute_properties
nonlocal recompute_properties, contains_unevaluated_wrapper

# @timeit
def eval_range(indices):
nonlocal recompute_properties
nonlocal recompute_properties, contains_unevaluated_wrapper
recompute_properties = False
for index in indices:
element = elements[index]
if not element.has_form("Unevaluated", 1):
if isinstance(element, EvalMixin):
new_value = element.evaluate(evaluation)
# We need id() because != by itself is too permissive
if id(element) != id(new_value):
recompute_properties = True
elements[index] = new_value
if element.has_form("System`Unevaluated", 1):
contains_unevaluated_wrapper = True
elements[index] = UnevaluatedWrapper(element._elements[0])
elif isinstance(element, EvalMixin):
new_value = element.evaluate(evaluation)
# We need id() because != by itself is too permissive
if id(element) != id(new_value):
recompute_properties = True
elements[index] = new_value

# @timeit
def rest_range(indices):
Expand All @@ -1075,6 +1142,9 @@ def rest_range(indices):
if id(new_value) != id(element):
elements[index] = new_value
recompute_properties = True
elif element.has_form("System`Unevaluated", 1):
contains_unevaluated_wrapper = True
elements[index] = UnevaluatedWrapper(element._elements[0])

if (A_HOLD_ALL | A_HOLD_ALL_COMPLETE) & attributes:
# eval_range(range(0, 0))
Expand Down Expand Up @@ -1125,51 +1195,8 @@ def rest_range(indices):
new._build_elements_properties()
elements = new._elements

# comment @mmatera: I think this is wrong now, because alters singletons... (see PR #58)
# The idea is to mark which elements was marked as "Unevaluated"
# Also, this consumes time for long lists, and is useful just for a very unfrequent
# expressions, involving `Unevaluated` elements.
# Notice also that this behaviour is broken when the argument of "Unevaluated" is a symbol (see comment and tests in test/test_unevaluate.py)

for element in elements:
element.unevaluated = False

# If HoldAllComplete Attribute (flag ``A_HOLD_ALL_COMPLETE``) is not set,
# and the expression has elements of the form `Unevaluated[element]`
# change them to `element` and set a flag `unevaluated=True`
# If the evaluation fails, use this flag to restore back the initial form
# Unevaluated[element]

# comment @mmatera:
# what we need here is some way to track which elements are marked as
# Unevaluated, that propagates by flatten, and at the end,
# to recover a list of positions that (eventually)
# must be marked again as Unevaluated.

if not A_HOLD_ALL_COMPLETE & attributes:
dirty_elements = None

for index, element in enumerate(elements):
if element.has_form("Unevaluated", 1):
if dirty_elements is None:
dirty_elements = list(elements)
dirty_elements[index] = element._elements[0]
dirty_elements[index].unevaluated = True

if dirty_elements:
new = Expression(head, *dirty_elements)
elements = dirty_elements
new._build_elements_properties()

# If the Attribute ``Flat`` (flag ``A_FLAT``) is set, calls
# flatten with a callback that set elements as unevaluated
# too.
def flatten_callback(new_elements, old):
for element in new_elements:
element.unevaluated = old.unevaluated

if A_FLAT & attributes:
new = new.flatten_with_respect_to_head(new._head, callback=flatten_callback)
new = new.flatten_with_respect_to_head(new._head)
if new.elements_properties is None:
new._build_elements_properties()

Expand Down Expand Up @@ -1201,12 +1228,21 @@ def flatten_callback(new_elements, old):
# threading. Still, we need to perform this rewrite to
# maintain correct semantic behavior.
if A_LISTABLE & attributes:
# TODO: Check how Unevaluated works here
done, threaded = new.thread(evaluation)
if done:
# if contains_unevaluated_wrapper:
# new = restore_unevaluated_from_wrapper(new)
# threaded = restore_unevaluated_from_wrapper(new)

if threaded.sameQ(new):
if contains_unevaluated_wrapper:
new = restore_unevaluated_from_wrapper(new)
new._timestamp_cache(evaluation)
return new, False
else:
if contains_unevaluated_wrapper:
threaded = restore_unevaluated_from_wrapper(threaded)
return threaded, True

# Step 6: Now,the next step is to look at the rules associated to
Expand Down Expand Up @@ -1236,6 +1272,10 @@ def flatten_callback(new_elements, old):
# in mathics results in `fish`, but in WL results in `5`. This special behaviour suggests
# that WMA process in a different way certain symbols.

if contains_unevaluated_wrapper:
wrapped_new = new
new = strip_unevaluated_wrapper(new)

def rules():
rules_names = set()
if not A_HOLD_ALL_COMPLETE & attributes:
Expand Down Expand Up @@ -1271,19 +1311,8 @@ def rules():
return result, True

# Step 7: If we are here, is because we didn't find any rule that matches with the expression.

dirty_elements = None

# Expression did not change, re-apply Unevaluated
for index, element in enumerate(new._elements):
if element.unevaluated:
if dirty_elements is None:
dirty_elements = list(new._elements)
dirty_elements[index] = Expression(SymbolUnevaluated, element)

if dirty_elements:
new = Expression(head)
new.elements = dirty_elements
if contains_unevaluated_wrapper:
new = restore_unevaluated_from_wrapper(wrapped_new)

# Step 8: Update the cache. Return the new compound Expression and indicate that no further evaluation is needed.
new._timestamp_cache(evaluation)
Expand Down
71 changes: 41 additions & 30 deletions test/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,55 +62,66 @@ def test_evaluation(str_expr: str, str_expected: str, message=""):
assert result == expected


# We skip this test because this behaviour is currently
# broken
@pytest.mark.parametrize(
"str_setup,str_expr,str_expected,message",
[
(
"F[x___Real]:=List[x]^2; a=.4;",
"F[x_, y_,z_]:={Length[x], Length[y], Length[z]}; a:={1,1}",
"F[Unevaluated[a],a,Unevaluated[a]]",
"{2,2,2}",
"evaluated as {Length[a], Length[{1,1}], Length[a]}",
),
(
"F[x_, y_,z_]:={HoldForm[x], HoldForm[y], HoldForm[z]}; a:={1,1}",
"F[Unevaluated[a],a,Unevaluated[a]]",
"{HoldForm[a], {1, 1}, HoldForm[a]}",
"evaluated as {HoldForm[a], HoldForm[{1,1}], HoldForm[a]}",
),
(
"ClearAll[a,F]; F[x_Symbol, y_Real, z_Symbol] := {x^2,y^2,z^2};a=4.;",
"F[Unevaluated[a], a, Unevaluated[a]]",
"F[Unevaluated[a], 0.4, Unevaluated[a]]",
None,
"{16., 16.,16.}",
(
"Here the definition matchec because the first and last parameters are"
"keep unevaluated."
),
),
(
"F[x___Real]:=List[x]^2; a=.4;",
"F[Unevaluated[b], b, Unevaluated[b]]",
"F[Unevaluated[b], b, Unevaluated[b]]",
"the second argument shouldn't be ``Unevaluated[b]``",
"ClearAll[a, b, F]; Attributes[F]=Flat;a=4.;",
"F[Unevaluated[a], a, F[b,1], Unevaluated[F[b,a]]]",
"F[Unevaluated[a], 4., b, 1, Unevaluated[b], Unevaluated[a]]",
"If F does not have a pattern that matches, keeps the unevaluated elements",
),
(
"G[x___Symbol]:=List[x]^2; a=.4;",
"G[Unevaluated[a], a, Unevaluated[a]]",
"F[Unevaluated[a], 0.4, Unevaluated[a]]",
None,
"ClearAll[a, b, F]; Attributes[F]={Orderless, Flat};a=4.;",
"F[Unevaluated[a], a, F[b,1], Unevaluated[F[b,a]]]",
"F[1, 4., Unevaluated[a], Unevaluated[a], b, Unevaluated[b]]",
"the same, with orderless. Unevaluated[expr] comes right after than expr.",
),
(
"G[x___Symbol]:=List[x]^2; a=.4;",
"G[Unevaluated[b], b, Unevaluated[b]]",
"F[Unevaluated[b], b, Unevaluated[b]]",
"the second argument shouldn't be ``Unevaluated[b]``",
"ClearAll[a, b, F,G]; Attributes[F]=Flat;a=4.;G[x_,y_]:=0",
"F[Unevaluated[a], a, G[b,1], Unevaluated[G[b,a]]]",
"F[Unevaluated[a], 4., 0, Unevaluated[G[b,a]]]",
"G is evaluated",
),
# (
# "ClearAll[a, b, F,G]; Attributes[F]=Flat;a=4.;F[x_,y__]:={x,y}",
# "F[Unevaluated[a], a, G[b,1], Unevaluated[G[b,a]]]",
# "{F[4.], 4., G[b, 1], G[b, 4.]}",
# "Since F is successfully evaluated, Unevaluated is removed.",
# ),
(
"a =.; F[a, x_Real, a] := List[x]^2;a=4.;",
"F[Unevaluated[a], a, Unevaluated[a]]",
"{16.}",
"Here, the second ``a`` is kept unevaluated because of a bug.",
"ClearAll[a, b, F,G]; Attributes[F]=HoldFirst;a=4.;F[x_,y__]:={Hold[x],Hold[y]}",
"F[Unevaluated[a], a, G[b,1], Unevaluated[F[b,a]]]",
"{Hold[a], Hold[4., G[b, 1], F[b, a]]}",
"Since F is successfully evaluated, Unevaluated is removed.",
),
],
)
@pytest.mark.skip(
reason="the right behaviour was broken since we start to use Symbol as singleton, to speedup comparisons."
)
def test_unevaluate(str_setup, str_expr, str_expected, message):
if str_setup:
evaluate(str_setup)
result = evaluate(str_expr)
expected = evaluate(str_expected)
if message:
assert result == expected, message
else:
assert result == expected
check_evaluation(str_expr, str_expected, message)


@pytest.mark.parametrize(
Expand Down