Skip to content

Commit

Permalink
improve from_mpmath. adding pytests (#827)
Browse files Browse the repository at this point in the history
This PR improves the annotations and the implementation of
`from_mpmath`. Also, some pytests are included.
  • Loading branch information
mmatera authored Mar 26, 2023
1 parent 5ee43c2 commit c7af29e
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 13 deletions.
34 changes: 21 additions & 13 deletions mathics/core/convert/mpmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,30 @@

from mathics.core.atoms import (
Complex,
Integer0,
Integer1,
IntegerM1,
MachineReal,
MachineReal0,
PrecisionReal,
)
from mathics.core.element import BaseElement
from mathics.core.expression import Expression
from mathics.core.symbols import Atom
from mathics.core.systemsymbols import (
SymbolComplexInfinity,
SymbolDirectedInfinity,
SymbolIndeterminate,
)
from mathics.core.systemsymbols import SymbolDirectedInfinity, SymbolIndeterminate

ExpressionInfinity = Expression(SymbolDirectedInfinity, Integer1)
ExpressionMInfinity = Expression(SymbolDirectedInfinity, IntegerM1)
ExpressionIInfinity = Expression(SymbolDirectedInfinity, Complex(Integer0, Integer1))
ExpressionMIInfinity = Expression(SymbolDirectedInfinity, Complex(Integer0, IntegerM1))

ExpressionComplexInfinity = Expression(SymbolDirectedInfinity)


@lru_cache(maxsize=1024)
def from_mpmath(
value: Union[mpmath.mpf, mpmath.mpc],
precision: Optional[int] = None,
) -> Atom:
) -> BaseElement:
"""
Converts mpf or mpc to Number.
The optional parameter `precision` represents
Expand All @@ -37,8 +41,7 @@ def from_mpmath(
return SymbolIndeterminate
if isinstance(value, mpmath.mpf):
if mpmath.isinf(value):
direction = Integer1 if value > 0 else IntegerM1
return Expression(SymbolDirectedInfinity, direction)
return ExpressionInfinity if value > 0 else ExpressionMInfinity
if precision is None:
return MachineReal(float(value))
# If the error if of the order of the number, the number
Expand All @@ -48,12 +51,17 @@ def from_mpmath(
# HACK: use str here to prevent loss of precision
return PrecisionReal(sympy.Float(str(value), precision=precision - 1))
elif isinstance(value, mpmath.mpc):
if mpmath.isinf(value):
return SymbolComplexInfinity
if value.imag == 0.0:
return from_mpmath(value.real, precision=precision)
real = from_mpmath(value.real, precision=precision)
imag = from_mpmath(value.imag, precision=precision)
val_re, val_im = value.real, value.imag
if mpmath.isinf(val_re):
if mpmath.isinf(val_im):
return ExpressionComplexInfinity
return ExpressionInfinity if val_re > 0 else ExpressionMInfinity
elif mpmath.isinf(val_im):
return ExpressionIInfinity if val_im > 0 else ExpressionMIInfinity
real = from_mpmath(val_re, precision=precision)
imag = from_mpmath(val_im, precision=precision)
return Complex(real, imag)
else:
raise TypeError(type(value))
Expand Down
Empty file added test/core/convert/__init__.py
Empty file.
61 changes: 61 additions & 0 deletions test/core/convert/mpmath.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from mpmath import mpc, mpf
from sympy import Float as SympyFloat

from mathics.core.atoms import (
Complex,
Integer0,
Integer1,
IntegerM1,
MachineReal,
PrecisionReal,
Rational,
Real,
)
from mathics.core.convert.mpmath import from_mpmath
from mathics.core.expression import Expression
from mathics.core.systemsymbols import SymbolDirectedInfinity, SymbolIndeterminate


def test_infinity():
vals = [
(mpf("+inf"), Expression(SymbolDirectedInfinity, Integer1)),
(mpf("-inf"), Expression(SymbolDirectedInfinity, IntegerM1)),
(
mpc(1.0, "inf"),
Expression(SymbolDirectedInfinity, Complex(Integer0, Integer1)),
),
(
mpc(1.0, "-inf"),
Expression(SymbolDirectedInfinity, Complex(Integer0, IntegerM1)),
),
(mpc("inf", 1), Expression(SymbolDirectedInfinity, Integer1)),
(mpc("-inf", 1), Expression(SymbolDirectedInfinity, IntegerM1)),
(mpf("nan"), SymbolIndeterminate),
]
for val_in, val_out in vals:
print([val_in, val_out, from_mpmath(val_in)])
assert val_out.sameQ(from_mpmath(val_in))


def test_from_to_mpmath():
vals = [
(Integer1, MachineReal(1.0)),
(Rational(1, 3), MachineReal(1.0 / 3.0)),
(MachineReal(1.2), MachineReal(1.2)),
(PrecisionReal(SympyFloat(1.3, 10)), PrecisionReal(SympyFloat(1.3, 10))),
(PrecisionReal(SympyFloat(1.3, 30)), PrecisionReal(SympyFloat(1.3, 30))),
(Complex(Integer1, IntegerM1), Complex(Integer1, IntegerM1)),
(Complex(Integer1, Real(-1.0)), Complex(Integer1, Real(-1.0))),
(Complex(Real(1.0), Real(-1.0)), Complex(Real(1.0), Real(-1.0))),
(
Complex(MachineReal(1.0), PrecisionReal(SympyFloat(-1.0, 10))),
Complex(MachineReal(1.0), PrecisionReal(SympyFloat(-1.0, 10))),
),
(
Complex(MachineReal(1.0), PrecisionReal(SympyFloat(-1.0, 30))),
Complex(MachineReal(1.0), PrecisionReal(SympyFloat(-1.0, 30))),
),
]
for val1, val2 in vals:
print((val1, val2))
assert val2.sameQ(from_mpmath(val1.to_mpmath()))

0 comments on commit c7af29e

Please sign in to comment.