Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move basic arithmetic to mathics.eval.arithmetic #789

Merged
merged 3 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 3 additions & 212 deletions mathics/builtin/arithfns/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,12 @@

"""


import mpmath
import sympy

from mathics.builtin.arithmetic import _MPMathFunction, create_infix
from mathics.builtin.base import BinaryOperator, Builtin, PrefixOperator, SympyFunction
from mathics.core.atoms import (
Complex,
Integer,
Integer0,
Integer1,
Integer2,
Integer3,
Integer310,
IntegerM1,
Expand All @@ -37,33 +31,28 @@
A_READ_PROTECTED,
)
from mathics.core.convert.expression import to_expression
from mathics.core.convert.mpmath import from_mpmath
from mathics.core.convert.sympy import from_sympy
from mathics.core.expression import Expression
from mathics.core.list import ListExpression
from mathics.core.number import dps, min_prec
from mathics.core.symbols import (
Symbol,
SymbolDivide,
SymbolHoldForm,
SymbolNull,
SymbolPlus,
SymbolPower,
SymbolTimes,
)
from mathics.core.systemsymbols import (
SymbolAccuracy,
SymbolBlank,
SymbolComplexInfinity,
SymbolDirectedInfinity,
SymbolIndeterminate,
SymbolInfinity,
SymbolInfix,
SymbolLeft,
SymbolMinus,
SymbolPattern,
SymbolSequence,
)
from mathics.eval.arithmetic import eval_Plus, eval_Times
from mathics.eval.nevaluator import eval_N
from mathics.eval.numerify import numerify

Expand Down Expand Up @@ -394,98 +383,8 @@ def is_negative(value) -> bool:

def eval(self, items, evaluation):
"Plus[items___]"

items_tuple = numerify(items, evaluation).get_sequence()
elements = []
last_item = last_count = None

prec = min_prec(*items_tuple)
is_machine_precision = any(item.is_machine_precision() for item in items_tuple)
numbers = []

def append_last():
if last_item is not None:
if last_count == 1:
elements.append(last_item)
else:
if last_item.has_form("Times", None):
elements.append(
Expression(
SymbolTimes, from_sympy(last_count), *last_item.elements
)
)
else:
elements.append(
Expression(SymbolTimes, from_sympy(last_count), last_item)
)

for item in items_tuple:
if isinstance(item, Number):
numbers.append(item)
else:
count = rest = None
if item.has_form("Times", None):
for element in item.elements:
if isinstance(element, Number):
count = element.to_sympy()
rest = item.get_mutable_elements()
rest.remove(element)
if len(rest) == 1:
rest = rest[0]
else:
rest.sort()
rest = Expression(SymbolTimes, *rest)
break
if count is None:
count = sympy.Integer(1)
rest = item
if last_item is not None and last_item == rest:
last_count = last_count + count
else:
append_last()
last_item = rest
last_count = count
append_last()
if numbers:
# TODO: reorganize de conditions to avoid compute unnecesary
# quantities. In particular, is we check mathine_precision,
# we do not need to evaluate prec.
if prec is not None:
if is_machine_precision:
numbers = [item.to_mpmath() for item in numbers]
number = mpmath.fsum(numbers)
number = from_mpmath(number)
else:
# TODO: If there are Complex numbers in `numbers`,
# and we are not working in machine precision, compute the sum of the real and imaginary
# parts separately, to preserve precision. For example,
# 1.`2 + 1.`3 I should produce
# Complex[1.`2, 1.`3]
# but with this implementation returns
# Complex[1.`2, 1.`2]
#
# TODO: if the precision are not equal for each number,
# we should estimate the result precision by computing the sum of individual errors
# prec = sum(abs(n.value) * 2**(-n.value._prec) for n in number if n.value._prec is not None)/sum(abs(n))
with mpmath.workprec(prec):
numbers = [item.to_mpmath() for item in numbers]
number = mpmath.fsum(numbers)
number = from_mpmath(number, precision=prec)
else:
number = from_sympy(sum(item.to_sympy() for item in numbers))
else:
number = Integer0

if not number.sameQ(Integer0):
elements.insert(0, number)

if not elements:
return Integer0
elif len(elements) == 1:
return elements[0]
else:
elements.sort()
return Expression(SymbolPlus, *elements)
return eval_Plus(*items_tuple)


class Power(BinaryOperator, _MPMathFunction):
Expand Down Expand Up @@ -940,112 +839,4 @@ def format_outputform(self, items, evaluation):
def eval(self, items, evaluation):
"Times[items___]"
items = numerify(items, evaluation).get_sequence()
elements = []
numbers = []
infinity_factor = False
# These quantities only have sense if there are numeric terms.
# Also, prec is only needed if is_machine_precision is not True.
prec = min_prec(*items)
is_machine_precision = any(item.is_machine_precision() for item in items)

# find numbers and simplify Times -> Power
for item in items:
if isinstance(item, Number):
numbers.append(item)
elif elements and item == elements[-1]:
elements[-1] = Expression(SymbolPower, elements[-1], Integer2)
elif (
elements
and item.has_form("Power", 2)
and elements[-1].has_form("Power", 2)
and item.elements[0].sameQ(elements[-1].elements[0])
):
elements[-1] = Expression(
SymbolPower,
elements[-1].elements[0],
Expression(SymbolPlus, item.elements[1], elements[-1].elements[1]),
)
elif (
elements
and item.has_form("Power", 2)
and item.elements[0].sameQ(elements[-1])
):
elements[-1] = Expression(
SymbolPower,
elements[-1],
Expression(SymbolPlus, item.elements[1], Integer1),
)
elif (
elements
and elements[-1].has_form("Power", 2)
and elements[-1].elements[0].sameQ(item)
):
elements[-1] = Expression(
SymbolPower,
item,
Expression(SymbolPlus, Integer1, elements[-1].elements[1]),
)
elif item.get_head().sameQ(SymbolDirectedInfinity):
infinity_factor = True
if len(item.elements) > 0:
direction = item.elements[0]
if isinstance(direction, Number):
numbers.append(direction)
else:
elements.append(direction)
elif item.sameQ(SymbolInfinity) or item.sameQ(SymbolComplexInfinity):
infinity_factor = True
else:
elements.append(item)

if numbers:
if prec is not None:
if is_machine_precision:
numbers = [item.to_mpmath() for item in numbers]
number = mpmath.fprod(numbers)
number = from_mpmath(number)
else:
with mpmath.workprec(prec):
numbers = [item.to_mpmath() for item in numbers]
number = mpmath.fprod(numbers)
number = from_mpmath(number, precision=prec)
else:
number = sympy.Mul(*[item.to_sympy() for item in numbers])
number = from_sympy(number)
else:
number = Integer1

if number.sameQ(Integer1):
number = None
elif number.is_zero:
if infinity_factor:
return SymbolIndeterminate
return number
elif (
number.sameQ(IntegerM1) and elements and elements[0].has_form("Plus", None)
):
elements[0] = Expression(
elements[0].get_head(),
*[
Expression(SymbolTimes, IntegerM1, element)
for element in elements[0].elements
],
)
number = None

if number is not None:
elements.insert(0, number)

if not elements:
if infinity_factor:
return SymbolComplexInfinity
return Integer1

if len(elements) == 1:
ret = elements[0]
else:
ret = Expression(SymbolTimes, *elements)
if infinity_factor:
return Expression(SymbolDirectedInfinity, ret)
else:
return ret
return eval_Times(*items)
71 changes: 9 additions & 62 deletions mathics/builtin/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

Basic arithmetic functions, including complex number arithmetic.
"""
from typing import Callable, Optional

from mathics.eval.numerify import numerify

Expand Down Expand Up @@ -48,17 +47,11 @@
A_PROTECTED,
)
from mathics.core.convert.expression import to_expression
from mathics.core.convert.mpmath import from_mpmath
from mathics.core.convert.sympy import SympyExpression, from_sympy, sympy_symbol_prefix
from mathics.core.element import ElementsProperties
from mathics.core.expression import Expression
from mathics.core.list import ListExpression
from mathics.core.number import (
FP_MANTISA_BINARY_DIGITS,
SpecialValueError,
dps,
min_prec,
)
from mathics.core.number import dps, min_prec
from mathics.core.symbols import (
Atom,
Symbol,
Expand All @@ -72,48 +65,17 @@
)
from mathics.core.systemsymbols import (
SymbolAnd,
SymbolComplexInfinity,
SymbolDirectedInfinity,
SymbolInfix,
SymbolPiecewise,
SymbolPossibleZeroQ,
SymbolTable,
SymbolUndefined,
)
from mathics.eval.arithmetic import eval_mpmath_function
from mathics.eval.nevaluator import eval_N


# @lru_cache(maxsize=4096)
def call_mpmath(
mpmath_function: Callable, mpmath_args: tuple, prec: Optional[int] = None
):
"""
calls the mpmath_function with mpmath_args parms
if prec=None, use floating point arithmetic.
Otherwise, work with prec bits of precision.
"""
# TODO: rocky, please help me with the annotations
# in the signature of this function.
if prec is None:
prec = FP_MANTISA_BINARY_DIGITS
with mpmath.workprec(prec):
try:
result_mp = mpmath_function(*mpmath_args)
if prec != FP_MANTISA_BINARY_DIGITS:
return from_mpmath(result_mp, prec)
return from_mpmath(result_mp)
except ValueError as exc:
text = str(exc)
if text == "gamma function pole":
return SymbolComplexInfinity
else:
raise
except ZeroDivisionError:
return
except SpecialValueError as exc:
return Symbol(exc.name)


class _MPMathFunction(SympyFunction):

# These below attributes are the default attributes:
Expand All @@ -140,8 +102,6 @@ def eval(self, z, evaluation):
"%(name)s[z__]"

args = numerify(z, evaluation).get_sequence()
mpmath_function = self.get_mpmath_function(tuple(args))
result = None

# if no arguments are inexact attempt to use sympy
if all(not x.is_inexact() for x in args):
Expand All @@ -150,35 +110,22 @@ def eval(self, z, evaluation):
result = from_sympy(result)
# evaluate elements to convert e.g. Plus[2, I] -> Complex[2, 1]
return result.evaluate_elements(evaluation)
elif mpmath_function is None:
return

if not all(isinstance(arg, Number) for arg in args):
return

if any(arg.is_machine_precision() for arg in args):
# if any argument has machine precision then the entire calculation
# is done with machine precision.
float_args = [
arg.round().get_float_value(permit_complex=True) for arg in args
]
if None in float_args:
return
mpmath_function = self.get_mpmath_function(tuple(args))
if mpmath_function is None:
return

result = call_mpmath(mpmath_function, tuple(float_args))
if any(arg.is_machine_precision() for arg in args):
prec = None
else:
prec = min_prec(*args)
d = dps(prec)
args = [eval_N(arg, evaluation, Integer(d)) for arg in args]

with mpmath.workprec(prec):
# to_mpmath seems to require that the precision is set from outside
mpmath_args = [x.to_mpmath() for x in args]
if None in mpmath_args:
return
args = [arg.round(d) for arg in args]

result = call_mpmath(mpmath_function, tuple(mpmath_args), prec)
return result
return eval_mpmath_function(mpmath_function, *args, prec=prec)


class _MPMathMultiFunction(_MPMathFunction):
Expand Down
Loading