Skip to content

Commit

Permalink
move basic arithmetic to mathics.eval.arithmetic (#789)
Browse files Browse the repository at this point in the history
This PR starts to move the context independent arithmetic to a separate
module. Interestingly, just by moving around a little piece of code, the
doctest time seems to be reduced in another 10 seconds in my machine, it
is, a 13%.
  • Loading branch information
mmatera authored Feb 15, 2023
1 parent dcfe909 commit ab5c2f6
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 279 deletions.
215 changes: 3 additions & 212 deletions mathics/builtin/arithfns/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,12 @@
"""


import mpmath
import sympy

from mathics.builtin.arithmetic import _MPMathFunction, create_infix
from mathics.builtin.base import BinaryOperator, Builtin, PrefixOperator, SympyFunction
from mathics.core.atoms import (
Complex,
Integer,
Integer0,
Integer1,
Integer2,
Integer3,
Integer310,
IntegerM1,
Expand All @@ -37,33 +31,28 @@
A_READ_PROTECTED,
)
from mathics.core.convert.expression import to_expression
from mathics.core.convert.mpmath import from_mpmath
from mathics.core.convert.sympy import from_sympy
from mathics.core.expression import Expression
from mathics.core.list import ListExpression
from mathics.core.number import dps, min_prec
from mathics.core.symbols import (
Symbol,
SymbolDivide,
SymbolHoldForm,
SymbolNull,
SymbolPlus,
SymbolPower,
SymbolTimes,
)
from mathics.core.systemsymbols import (
SymbolAccuracy,
SymbolBlank,
SymbolComplexInfinity,
SymbolDirectedInfinity,
SymbolIndeterminate,
SymbolInfinity,
SymbolInfix,
SymbolLeft,
SymbolMinus,
SymbolPattern,
SymbolSequence,
)
from mathics.eval.arithmetic import eval_Plus, eval_Times
from mathics.eval.nevaluator import eval_N
from mathics.eval.numerify import numerify

Expand Down Expand Up @@ -394,98 +383,8 @@ def is_negative(value) -> bool:

def eval(self, items, evaluation):
"Plus[items___]"

items_tuple = numerify(items, evaluation).get_sequence()
elements = []
last_item = last_count = None

prec = min_prec(*items_tuple)
is_machine_precision = any(item.is_machine_precision() for item in items_tuple)
numbers = []

def append_last():
if last_item is not None:
if last_count == 1:
elements.append(last_item)
else:
if last_item.has_form("Times", None):
elements.append(
Expression(
SymbolTimes, from_sympy(last_count), *last_item.elements
)
)
else:
elements.append(
Expression(SymbolTimes, from_sympy(last_count), last_item)
)

for item in items_tuple:
if isinstance(item, Number):
numbers.append(item)
else:
count = rest = None
if item.has_form("Times", None):
for element in item.elements:
if isinstance(element, Number):
count = element.to_sympy()
rest = item.get_mutable_elements()
rest.remove(element)
if len(rest) == 1:
rest = rest[0]
else:
rest.sort()
rest = Expression(SymbolTimes, *rest)
break
if count is None:
count = sympy.Integer(1)
rest = item
if last_item is not None and last_item == rest:
last_count = last_count + count
else:
append_last()
last_item = rest
last_count = count
append_last()
if numbers:
# TODO: reorganize de conditions to avoid compute unnecesary
# quantities. In particular, is we check mathine_precision,
# we do not need to evaluate prec.
if prec is not None:
if is_machine_precision:
numbers = [item.to_mpmath() for item in numbers]
number = mpmath.fsum(numbers)
number = from_mpmath(number)
else:
# TODO: If there are Complex numbers in `numbers`,
# and we are not working in machine precision, compute the sum of the real and imaginary
# parts separately, to preserve precision. For example,
# 1.`2 + 1.`3 I should produce
# Complex[1.`2, 1.`3]
# but with this implementation returns
# Complex[1.`2, 1.`2]
#
# TODO: if the precision are not equal for each number,
# we should estimate the result precision by computing the sum of individual errors
# prec = sum(abs(n.value) * 2**(-n.value._prec) for n in number if n.value._prec is not None)/sum(abs(n))
with mpmath.workprec(prec):
numbers = [item.to_mpmath() for item in numbers]
number = mpmath.fsum(numbers)
number = from_mpmath(number, precision=prec)
else:
number = from_sympy(sum(item.to_sympy() for item in numbers))
else:
number = Integer0

if not number.sameQ(Integer0):
elements.insert(0, number)

if not elements:
return Integer0
elif len(elements) == 1:
return elements[0]
else:
elements.sort()
return Expression(SymbolPlus, *elements)
return eval_Plus(*items_tuple)


class Power(BinaryOperator, _MPMathFunction):
Expand Down Expand Up @@ -940,112 +839,4 @@ def format_outputform(self, items, evaluation):
def eval(self, items, evaluation):
"Times[items___]"
items = numerify(items, evaluation).get_sequence()
elements = []
numbers = []
infinity_factor = False
# These quantities only have sense if there are numeric terms.
# Also, prec is only needed if is_machine_precision is not True.
prec = min_prec(*items)
is_machine_precision = any(item.is_machine_precision() for item in items)

# find numbers and simplify Times -> Power
for item in items:
if isinstance(item, Number):
numbers.append(item)
elif elements and item == elements[-1]:
elements[-1] = Expression(SymbolPower, elements[-1], Integer2)
elif (
elements
and item.has_form("Power", 2)
and elements[-1].has_form("Power", 2)
and item.elements[0].sameQ(elements[-1].elements[0])
):
elements[-1] = Expression(
SymbolPower,
elements[-1].elements[0],
Expression(SymbolPlus, item.elements[1], elements[-1].elements[1]),
)
elif (
elements
and item.has_form("Power", 2)
and item.elements[0].sameQ(elements[-1])
):
elements[-1] = Expression(
SymbolPower,
elements[-1],
Expression(SymbolPlus, item.elements[1], Integer1),
)
elif (
elements
and elements[-1].has_form("Power", 2)
and elements[-1].elements[0].sameQ(item)
):
elements[-1] = Expression(
SymbolPower,
item,
Expression(SymbolPlus, Integer1, elements[-1].elements[1]),
)
elif item.get_head().sameQ(SymbolDirectedInfinity):
infinity_factor = True
if len(item.elements) > 0:
direction = item.elements[0]
if isinstance(direction, Number):
numbers.append(direction)
else:
elements.append(direction)
elif item.sameQ(SymbolInfinity) or item.sameQ(SymbolComplexInfinity):
infinity_factor = True
else:
elements.append(item)

if numbers:
if prec is not None:
if is_machine_precision:
numbers = [item.to_mpmath() for item in numbers]
number = mpmath.fprod(numbers)
number = from_mpmath(number)
else:
with mpmath.workprec(prec):
numbers = [item.to_mpmath() for item in numbers]
number = mpmath.fprod(numbers)
number = from_mpmath(number, precision=prec)
else:
number = sympy.Mul(*[item.to_sympy() for item in numbers])
number = from_sympy(number)
else:
number = Integer1

if number.sameQ(Integer1):
number = None
elif number.is_zero:
if infinity_factor:
return SymbolIndeterminate
return number
elif (
number.sameQ(IntegerM1) and elements and elements[0].has_form("Plus", None)
):
elements[0] = Expression(
elements[0].get_head(),
*[
Expression(SymbolTimes, IntegerM1, element)
for element in elements[0].elements
],
)
number = None

if number is not None:
elements.insert(0, number)

if not elements:
if infinity_factor:
return SymbolComplexInfinity
return Integer1

if len(elements) == 1:
ret = elements[0]
else:
ret = Expression(SymbolTimes, *elements)
if infinity_factor:
return Expression(SymbolDirectedInfinity, ret)
else:
return ret
return eval_Times(*items)
71 changes: 9 additions & 62 deletions mathics/builtin/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
Basic arithmetic functions, including complex number arithmetic.
"""
from typing import Callable, Optional

from mathics.eval.numerify import numerify

Expand Down Expand Up @@ -48,17 +47,11 @@
A_PROTECTED,
)
from mathics.core.convert.expression import to_expression
from mathics.core.convert.mpmath import from_mpmath
from mathics.core.convert.sympy import SympyExpression, from_sympy, sympy_symbol_prefix
from mathics.core.element import ElementsProperties
from mathics.core.expression import Expression
from mathics.core.list import ListExpression
from mathics.core.number import (
FP_MANTISA_BINARY_DIGITS,
SpecialValueError,
dps,
min_prec,
)
from mathics.core.number import dps, min_prec
from mathics.core.symbols import (
Atom,
Symbol,
Expand All @@ -72,48 +65,17 @@
)
from mathics.core.systemsymbols import (
SymbolAnd,
SymbolComplexInfinity,
SymbolDirectedInfinity,
SymbolInfix,
SymbolPiecewise,
SymbolPossibleZeroQ,
SymbolTable,
SymbolUndefined,
)
from mathics.eval.arithmetic import eval_mpmath_function
from mathics.eval.nevaluator import eval_N


# @lru_cache(maxsize=4096)
def call_mpmath(
mpmath_function: Callable, mpmath_args: tuple, prec: Optional[int] = None
):
"""
calls the mpmath_function with mpmath_args parms
if prec=None, use floating point arithmetic.
Otherwise, work with prec bits of precision.
"""
# TODO: rocky, please help me with the annotations
# in the signature of this function.
if prec is None:
prec = FP_MANTISA_BINARY_DIGITS
with mpmath.workprec(prec):
try:
result_mp = mpmath_function(*mpmath_args)
if prec != FP_MANTISA_BINARY_DIGITS:
return from_mpmath(result_mp, prec)
return from_mpmath(result_mp)
except ValueError as exc:
text = str(exc)
if text == "gamma function pole":
return SymbolComplexInfinity
else:
raise
except ZeroDivisionError:
return
except SpecialValueError as exc:
return Symbol(exc.name)


class _MPMathFunction(SympyFunction):

# These below attributes are the default attributes:
Expand All @@ -140,8 +102,6 @@ def eval(self, z, evaluation):
"%(name)s[z__]"

args = numerify(z, evaluation).get_sequence()
mpmath_function = self.get_mpmath_function(tuple(args))
result = None

# if no arguments are inexact attempt to use sympy
if all(not x.is_inexact() for x in args):
Expand All @@ -150,35 +110,22 @@ def eval(self, z, evaluation):
result = from_sympy(result)
# evaluate elements to convert e.g. Plus[2, I] -> Complex[2, 1]
return result.evaluate_elements(evaluation)
elif mpmath_function is None:
return

if not all(isinstance(arg, Number) for arg in args):
return

if any(arg.is_machine_precision() for arg in args):
# if any argument has machine precision then the entire calculation
# is done with machine precision.
float_args = [
arg.round().get_float_value(permit_complex=True) for arg in args
]
if None in float_args:
return
mpmath_function = self.get_mpmath_function(tuple(args))
if mpmath_function is None:
return

result = call_mpmath(mpmath_function, tuple(float_args))
if any(arg.is_machine_precision() for arg in args):
prec = None
else:
prec = min_prec(*args)
d = dps(prec)
args = [eval_N(arg, evaluation, Integer(d)) for arg in args]

with mpmath.workprec(prec):
# to_mpmath seems to require that the precision is set from outside
mpmath_args = [x.to_mpmath() for x in args]
if None in mpmath_args:
return
args = [arg.round(d) for arg in args]

result = call_mpmath(mpmath_function, tuple(mpmath_args), prec)
return result
return eval_mpmath_function(mpmath_function, *args, prec=prec)


class _MPMathMultiFunction(_MPMathFunction):
Expand Down
Loading

0 comments on commit ab5c2f6

Please sign in to comment.