Skip to content

Commit

Permalink
fix mpmath precision handling (#781)
Browse files Browse the repository at this point in the history
This PR fixes another issue with precision. In particular, this issue
was the reason of the hangings in doctest in #778.
  • Loading branch information
mmatera authored Feb 7, 2023
1 parent 9303b3b commit 12f6aa3
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 69 deletions.
84 changes: 42 additions & 42 deletions mathics/builtin/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Basic arithmetic functions, including complex number arithmetic.
"""
from typing import Callable, Optional

from mathics.eval.numerify import numerify

Expand Down Expand Up @@ -52,7 +53,12 @@
from mathics.core.element import ElementsProperties
from mathics.core.expression import Expression
from mathics.core.list import ListExpression
from mathics.core.number import SpecialValueError, dps, min_prec
from mathics.core.number import (
FP_MANTISA_BINARY_DIGITS,
SpecialValueError,
dps,
min_prec,
)
from mathics.core.symbols import (
Atom,
Symbol,
Expand All @@ -68,9 +74,7 @@
SymbolAnd,
SymbolComplexInfinity,
SymbolDirectedInfinity,
SymbolIndeterminate,
SymbolInfix,
SymbolOverflow,
SymbolPiecewise,
SymbolPossibleZeroQ,
SymbolTable,
Expand All @@ -79,20 +83,35 @@
from mathics.eval.nevaluator import eval_N


@lru_cache(maxsize=4096)
def call_mpmath(mpmath_function, mpmath_args):
try:
return mpmath_function(*mpmath_args)
except ValueError as exc:
text = str(exc)
if text == "gamma function pole":
return SymbolComplexInfinity
else:
raise
except ZeroDivisionError:
return
except SpecialValueError as exc:
return Symbol(exc.name)
# @lru_cache(maxsize=4096)
def call_mpmath(
mpmath_function: Callable, mpmath_args: tuple, prec: Optional[int] = None
):
"""
calls the mpmath_function with mpmath_args parms
if prec=None, use floating point arithmetic.
Otherwise, work with prec bits of precision.
"""
# TODO: rocky, please help me with the annotations
# in the signature of this function.
if prec is None:
prec = FP_MANTISA_BINARY_DIGITS
with mpmath.workprec(prec):
try:
result_mp = mpmath_function(*mpmath_args)
if prec != FP_MANTISA_BINARY_DIGITS:
return from_mpmath(result_mp, prec)
return from_mpmath(result_mp)
except ValueError as exc:
text = str(exc)
if text == "gamma function pole":
return SymbolComplexInfinity
else:
raise
except ZeroDivisionError:
return
except SpecialValueError as exc:
return Symbol(exc.name)


class _MPMathFunction(SympyFunction):
Expand Down Expand Up @@ -147,38 +166,18 @@ def eval(self, z, evaluation):
return

result = call_mpmath(mpmath_function, tuple(float_args))

if isinstance(result, (mpmath.mpc, mpmath.mpf)):
if mpmath.isinf(result) and isinstance(result, mpmath.mpc):
result = SymbolComplexInfinity
elif mpmath.isinf(result) and result > 0:
result = Expression(SymbolDirectedInfinity, Integer1)
elif mpmath.isinf(result) and result < 0:
result = Expression(SymbolDirectedInfinity, IntegerM1)
elif mpmath.isnan(result):
result = SymbolIndeterminate
else:
# FIXME: replace try/except as a context manager
# like "with evaluation.from_mpmath()...
# which can be instrumented for
# or mpmath tracing and benchmarking on demand.
# Then use it on other places where mpmath appears.
try:
result = from_mpmath(result)
except OverflowError:
evaluation.message("General", "ovfl")
result = Expression(SymbolOverflow)
else:
prec = min_prec(*args)
d = dps(prec)
args = [eval_N(arg, evaluation, Integer(d)) for arg in args]

with mpmath.workprec(prec):
# to_mpmath seems to require that the precision is set from outside
mpmath_args = [x.to_mpmath() for x in args]
if None in mpmath_args:
return
result = call_mpmath(mpmath_function, tuple(mpmath_args))
if isinstance(result, (mpmath.mpc, mpmath.mpf)):
result = from_mpmath(result, precision=prec)

result = call_mpmath(mpmath_function, tuple(mpmath_args), prec)
return result


Expand Down Expand Up @@ -723,7 +722,7 @@ def to_sympy(self, expr, **kwargs):
return sympy.zoo


class I(Predefined):
class I_(Predefined):
"""
<url>:Imaginary unit:https://en.wikipedia.org/wiki/Imaginary_unit</url> \
(<url>:WMA:https://reference.wolfram.com/language/ref/I.html</url>)
Expand All @@ -739,6 +738,7 @@ class I(Predefined):
= 10
"""

name = "I"
summary_text = "imaginary unit"
python_equivalent = 1j

Expand Down
8 changes: 6 additions & 2 deletions mathics/builtin/numbers/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
MIN_MACHINE_NUMBER,
PrecisionValueError,
get_precision,
prec,
)
from mathics.core.symbols import Atom, Symbol, strip_context
from mathics.core.systemsymbols import SymbolIndeterminate
Expand All @@ -43,8 +44,11 @@ def mp_constant(fn: str, d=None) -> mpmath.mpf:
# ask for a certain number of digits, but the
# accuracy will be less than that. Figure out
# what's up and compensate somehow.
mpmath.mp.dps = int_d = int(d * 3.321928)
return getattr(mpmath, fn)(prec=int_d)

int_d = prec(d)
with mpmath.workprec(int_d):
result = str(getattr(mpmath, fn)(prec=int_d))
return result


def mp_convert_constant(obj, **kwargs):
Expand Down
15 changes: 1 addition & 14 deletions mathics/builtin/specialfns/gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,6 @@ def eval_with_z(self, z, a, b, evaluation):
return

result = call_mpmath(mpmath_function, tuple(float_args))
if isinstance(result, (mpmath.mpc, mpmath.mpf)):
if mpmath.isinf(result) and isinstance(result, mpmath.mpc):
result = SymbolComplexInfinity
elif mpmath.isinf(result) and result > 0:
result = Expression(SymbolDirectedInfinity, Integer1)
elif mpmath.isinf(result) and result < 0:
result = Expression(SymbolDirectedInfinity, Integer(-1))
elif mpmath.isnan(result):
result = SymbolIndeterminate
else:
result = from_mpmath(result)
else:
prec = min_prec(*args)
d = dps(prec)
Expand All @@ -131,9 +120,7 @@ def eval_with_z(self, z, a, b, evaluation):
mpmath_args = [x.to_mpmath() for x in args]
if None in mpmath_args:
return
result = call_mpmath(mpmath_function, tuple(mpmath_args))
if isinstance(result, (mpmath.mpc, mpmath.mpf)):
result = from_mpmath(result, precision=prec)
result = call_mpmath(mpmath_function, tuple(mpmath_args), prec)
return result


Expand Down
21 changes: 14 additions & 7 deletions mathics/core/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,18 @@ def to_sympy(self, **kwargs):
def to_python(self, *args, **kwargs):
return self.value

def round(self, d=None) -> Union["MachineReal", "PrecisionReal"]:
def round(self, d: Optional[int] = None) -> Union["MachineReal", "PrecisionReal"]:
"""
Produce a Real approximation of ``self`` with decimal precision ``d``.
If ``d`` is ``None``, and self.value fits in a float,
returns a ``MachineReal`` number.
Is the low-level equivalent to ``N[self, d]``.
"""
if d is None:
d = self.value.bit_length()
if d <= FP_MANTISA_BINARY_DIGITS:
return MachineReal(float(self.value))
else:
# FP_MANTISA_BINARY_DIGITS / log_2(10) + 1
d = MACHINE_PRECISION_VALUE
return PrecisionReal(sympy.Float(self.value, d))

Expand Down Expand Up @@ -441,7 +446,10 @@ def make_boxes(self, form):
def is_zero(self) -> bool:
return self.value == 0.0

def round(self, d=None) -> "MachineReal":
def round(self, d: Optional[int] = None) -> "MachineReal":
"""
Produce a Real approximation of ``self`` with decimal precision ``d``.
"""
return self

def sameQ(self, other) -> bool:
Expand Down Expand Up @@ -540,12 +548,11 @@ def make_boxes(self, form):
self, dps(self.get_precision()), None, None, _number_form_options
)

def round(self, d=None) -> Union[MachineReal, "PrecisionReal"]:
def round(self, d: Optional[int] = None) -> Union[MachineReal, "PrecisionReal"]:
if d is None:
return MachineReal(float(self.value))
else:
d = min(dps(self.get_precision()), d)
return PrecisionReal(self.value.n(d))
_prec = min(prec(d), self.value._prec)
return PrecisionReal(sympy.Float(self.value, precision=_prec))

def sameQ(self, other) -> bool:
"""Mathics SameQ for PrecisionReal"""
Expand Down
11 changes: 10 additions & 1 deletion mathics/core/convert/mpmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,17 @@
import mpmath
import sympy

from mathics.core.atoms import Complex, MachineReal, MachineReal0, PrecisionReal
from mathics.core.atoms import (
Complex,
Integer1,
IntegerM1,
MachineReal,
MachineReal0,
PrecisionReal,
)
from mathics.core.expression import Expression
from mathics.core.symbols import Atom
from mathics.core.systemsymbols import SymbolDirectedInfinity


@lru_cache(maxsize=1024)
Expand Down
5 changes: 2 additions & 3 deletions mathics/eval/nevaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@

import sympy

from mathics.core.atoms import Number
from mathics.core.atoms import Number, PrecisionReal
from mathics.core.attributes import A_N_HOLD_ALL, A_N_HOLD_FIRST, A_N_HOLD_REST
from mathics.core.convert.sympy import from_sympy
from mathics.core.element import BaseElement
from mathics.core.evaluation import Evaluation
from mathics.core.expression import Expression
from mathics.core.number import PrecisionValueError, get_precision
from mathics.core.number import PrecisionValueError, dps, get_precision
from mathics.core.symbols import Atom
from mathics.core.systemsymbols import SymbolMachinePrecision, SymbolN

Expand Down Expand Up @@ -50,7 +50,6 @@ def eval_NValues(
stored in ``evaluation.definitions``.
If ``prec`` can not be evaluated as a number, returns None, otherwise, returns an expression.
"""

# The first step is to determine the precision goal
try:
# Here ``get_precision`` is called with ``show_messages``
Expand Down
35 changes: 35 additions & 0 deletions test/builtin/atomic/test_numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,38 @@ def test_accuracy(str_expr, str_expected):
)
def test_precision(str_expr, str_expected):
check_evaluation(f"Precision[{str_expr}]", str_expected)


@pytest.mark.parametrize(
("str_expr", "str_expected", "msg"),
[
(None, None, None),
("N[Sqrt[2], 41]//Precision", "41.", "first round sqrt[2`41]"),
("N[Sqrt[2], 40]//Precision", "40.", "first round sqrt[2`40]"),
("N[Sqrt[2], 41]//Precision", "41.", "second round sqrt[2`41]"),
("N[Sqrt[2], 40]//Precision", "40.", "second round sqrt[2`40]"),
(
"N[Sqrt[2], 41]",
'"1.4142135623730950488016887242096980785697"',
"third round sqrt[2`41]",
),
(
"Precision/@Table[N[Pi,p],{p, {5, 100, MachinePrecision, 20}}]",
"{5., 100., MachinePrecision, 20.}",
None,
),
(
"Precision/@Table[N[Sin[1],p],{p, {5, 100, MachinePrecision, 20}}]",
"{5., 100., MachinePrecision, 20.}",
None,
),
("N[Sqrt[2], 40]", '"1.414213562373095048801688724209698078570"', None),
("N[Sqrt[2], 4]", '"1.414"', None),
("N[Pi, 40]", '"3.141592653589793238462643383279502884197"', None),
("N[Pi, 4]", '"3.142"', None),
("N[Pi, 41]", '"3.1415926535897932384626433832795028841972"', None),
("N[Sqrt[2], 41]", '"1.4142135623730950488016887242096980785697"', None),
],
)
def test_change_prec(str_expr, str_expected, msg):
check_evaluation(str_expr, str_expected, failure_message=msg)

0 comments on commit 12f6aa3

Please sign in to comment.