Skip to content

Commit

Permalink
fix: add __rdunder__ methods for arb and acb
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarbenjamin committed Apr 21, 2023
1 parent d5f0a61 commit 070b7e5
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 116 deletions.
110 changes: 63 additions & 47 deletions src/flint/acb.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -156,23 +156,18 @@ cdef class acb(flint_scalar):
return (self.real._mpf_, self.imag._mpf_)

def __richcmp__(s, t, int op):
cdef acb_struct sval[1]
cdef acb_struct tval[1]
cdef bint res
cdef int stype, ttype
cdef int ttype
if not (op == 2 or op == 3):
raise ValueError("comparing complex numbers")
stype = acb_set_any_ref(sval, s)
if stype == FMPZ_UNKNOWN:
return NotImplemented
ttype = acb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
if op == 2:
res = acb_eq(sval, tval)
res = acb_eq(s.val, tval)
else:
res = acb_ne(sval, tval)
if stype == FMPZ_TMP: acb_clear(sval)
res = acb_ne(s.val, tval)
if ttype == FMPZ_TMP: acb_clear(tval)
return res

Expand Down Expand Up @@ -363,92 +358,114 @@ cdef class acb(flint_scalar):
return res

def __add__(s, t):
cdef acb_struct sval[1]
cdef acb_struct tval[1]
cdef int stype, ttype
stype = acb_set_any_ref(sval, s)
if stype == FMPZ_UNKNOWN:
cdef int ttype
ttype = acb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = acb.__new__(acb)
acb_add((<acb>u).val, s.val, tval, getprec())
if ttype == FMPZ_TMP: acb_clear(tval)
return u

def __radd__(s, t):
cdef acb_struct tval[1]
cdef int ttype
ttype = acb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = acb.__new__(acb)
acb_add((<acb>u).val, sval, tval, getprec())
if stype == FMPZ_TMP: acb_clear(sval)
acb_add((<acb>u).val, tval, s.val, getprec())
if ttype == FMPZ_TMP: acb_clear(tval)
return u

def __sub__(s, t):
cdef acb_struct sval[1]
cdef acb_struct tval[1]
cdef int stype, ttype
stype = acb_set_any_ref(sval, s)
if stype == FMPZ_UNKNOWN:
cdef int ttype
ttype = acb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = acb.__new__(acb)
acb_sub((<acb>u).val, s.val, tval, getprec())
if ttype == FMPZ_TMP: acb_clear(tval)
return u

def __rsub__(s, t):
cdef acb_struct tval[1]
cdef int ttype
ttype = acb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = acb.__new__(acb)
acb_sub((<acb>u).val, sval, tval, getprec())
if stype == FMPZ_TMP: acb_clear(sval)
acb_sub((<acb>u).val, tval, s.val, getprec())
if ttype == FMPZ_TMP: acb_clear(tval)
return u

def __mul__(s, t):
cdef acb_struct sval[1]
cdef acb_struct tval[1]
cdef int stype, ttype
stype = acb_set_any_ref(sval, s)
if stype == FMPZ_UNKNOWN:
return NotImplemented
cdef int ttype
ttype = acb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = acb.__new__(acb)
acb_mul((<acb>u).val, sval, tval, getprec())
if stype == FMPZ_TMP: acb_clear(sval)
acb_mul((<acb>u).val, s.val, tval, getprec())
if ttype == FMPZ_TMP: acb_clear(tval)
return u

# important: must not be cdef because of cython magic
@staticmethod
def _div_(s, t):
cdef acb_struct sval[1]
def __rmul__(s, t):
cdef acb_struct tval[1]
cdef int stype, ttype
stype = acb_set_any_ref(sval, s)
if stype == FMPZ_UNKNOWN:
return NotImplemented
cdef int ttype
ttype = acb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = acb.__new__(acb)
acb_div((<acb>u).val, sval, tval, getprec())
if stype == FMPZ_TMP: acb_clear(sval)
acb_mul((<acb>u).val, tval, s.val, getprec())
if ttype == FMPZ_TMP: acb_clear(tval)
return u

def __truediv__(s, t):
return acb._div_(s, t)
cdef acb_struct tval[1]
cdef int ttype
ttype = acb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = acb.__new__(acb)
acb_div((<acb>u).val, s.val, tval, getprec())
if ttype == FMPZ_TMP: acb_clear(tval)
return u

def __div__(s, t):
return acb._div_(s, t)
def __rtruediv__(s, t):
cdef acb_struct tval[1]
cdef int ttype
ttype = acb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = acb.__new__(acb)
acb_div((<acb>u).val, tval, s.val, getprec())
if ttype == FMPZ_TMP: acb_clear(tval)
return u

def __pow__(s, t, u):
cdef acb_struct sval[1]
cdef acb_struct tval[1]
cdef int stype, ttype
cdef int ttype
if u is not None:
raise ValueError("modular exponentiation of complex number")
stype = acb_set_any_ref(sval, s)
if stype == FMPZ_UNKNOWN:
ttype = acb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = acb.__new__(acb)
acb_pow((<acb>u).val, s.val, tval, getprec())
if ttype == FMPZ_TMP: acb_clear(tval)
return u

def __rpow__(s, t):
cdef acb_struct tval[1]
cdef int ttype
ttype = acb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = acb.__new__(acb)
acb_pow((<acb>u).val, sval, tval, getprec())
if stype == FMPZ_TMP: acb_clear(sval)
acb_pow((<acb>u).val, tval, s.val, getprec())
if ttype == FMPZ_TMP: acb_clear(tval)
return u

Expand Down Expand Up @@ -2560,4 +2577,3 @@ cdef class acb(flint_scalar):
acb_hypgeom_coulomb(NULL, (<acb>G).val, NULL, NULL,
(<acb>l).val, (<acb>eta).val, (<acb>self).val, getprec())
return G

113 changes: 69 additions & 44 deletions src/flint/arb.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -116,22 +116,6 @@ cdef any_as_arb_or_notimplemented(x):
return NotImplemented
return t

cdef _arb_div_(s, t):
cdef arb_struct sval[1]
cdef arb_struct tval[1]
cdef int stype, ttype
stype = arb_set_any_ref(sval, s)
if stype == FMPZ_UNKNOWN:
return NotImplemented
ttype = arb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = arb.__new__(arb)
arb_div((<arb>u).val, sval, tval, getprec())
if stype == FMPZ_TMP: arb_clear(sval)
if ttype == FMPZ_TMP: arb_clear(tval)
return u

cdef class arb(flint_scalar):
ur"""
Represents a real number `x` by a midpoint `m` and a radius `r`
Expand Down Expand Up @@ -550,74 +534,116 @@ cdef class arb(flint_scalar):
return res

def __add__(s, t):
cdef arb_struct sval[1]
cdef arb_struct tval[1]
cdef int stype, ttype
stype = arb_set_any_ref(sval, s)
if stype == FMPZ_UNKNOWN:
cdef int ttype
ttype = arb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = arb.__new__(arb)
arb_add((<arb>u).val, s.val, tval, getprec())
if ttype == FMPZ_TMP: arb_clear(tval)
return u

def __radd__(s, t):
cdef arb_struct tval[1]
cdef int ttype
ttype = arb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = arb.__new__(arb)
arb_add((<arb>u).val, sval, tval, getprec())
if stype == FMPZ_TMP: arb_clear(sval)
arb_add((<arb>u).val, tval, s.val, getprec())
if ttype == FMPZ_TMP: arb_clear(tval)
return u

def __sub__(s, t):
cdef arb_struct sval[1]
cdef arb_struct tval[1]
cdef int stype, ttype
stype = arb_set_any_ref(sval, s)
if stype == FMPZ_UNKNOWN:
cdef int ttype
ttype = arb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = arb.__new__(arb)
arb_sub((<arb>u).val, s.val, tval, getprec())
if ttype == FMPZ_TMP: arb_clear(tval)
return u

def __rsub__(s, t):
cdef arb_struct tval[1]
cdef int ttype
ttype = arb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = arb.__new__(arb)
arb_sub((<arb>u).val, sval, tval, getprec())
if stype == FMPZ_TMP: arb_clear(sval)
arb_sub((<arb>u).val, tval, s.val, getprec())
if ttype == FMPZ_TMP: arb_clear(tval)
return u

def __mul__(s, t):
cdef arb_struct sval[1]
cdef arb_struct tval[1]
cdef int stype, ttype
stype = arb_set_any_ref(sval, s)
if stype == FMPZ_UNKNOWN:
cdef int ttype
ttype = arb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = arb.__new__(arb)
arb_mul((<arb>u).val, s.val, tval, getprec())
if ttype == FMPZ_TMP: arb_clear(tval)
return u

def __rmul__(s, t):
cdef arb_struct tval[1]
cdef int ttype
ttype = arb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = arb.__new__(arb)
arb_mul((<arb>u).val, sval, tval, getprec())
if stype == FMPZ_TMP: arb_clear(sval)
arb_mul((<arb>u).val, tval, s.val, getprec())
if ttype == FMPZ_TMP: arb_clear(tval)
return u

def __truediv__(s, t):
return _arb_div_(s, t)
cdef arb_struct tval[1]
cdef int ttype
ttype = arb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = arb.__new__(arb)
arb_div((<arb>u).val, s.val, tval, getprec())
if ttype == FMPZ_TMP: arb_clear(tval)
return u

def __div__(s, t):
return _arb_div_(s, t)
def __rtruediv__(s, t):
cdef arb_struct tval[1]
cdef int ttype
ttype = arb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = arb.__new__(arb)
arb_div((<arb>u).val, tval, s.val, getprec())
if ttype == FMPZ_TMP: arb_clear(tval)
return u

def __pow__(s, t, modulus):
cdef arb_struct sval[1]
cdef arb_struct tval[1]
cdef int stype, ttype
cdef int ttype
if modulus is not None:
raise TypeError("three-argument pow() not supported by arb type")
stype = arb_set_any_ref(sval, s)
if stype == FMPZ_UNKNOWN:
ttype = arb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = arb.__new__(arb)
arb_pow((<arb>u).val, s.val, tval, getprec())
if ttype == FMPZ_TMP: arb_clear(tval)
return u

def __rpow__(s, t, modulus):
cdef arb_struct tval[1]
cdef int ttype
if modulus is not None:
raise TypeError("three-argument pow() not supported by arb type")
ttype = arb_set_any_ref(tval, t)
if ttype == FMPZ_UNKNOWN:
return NotImplemented
u = arb.__new__(arb)
arb_pow((<arb>u).val, sval, tval, getprec())
if stype == FMPZ_TMP: arb_clear(sval)
arb_pow((<arb>u).val, tval, s.val, getprec())
if ttype == FMPZ_TMP: arb_clear(tval)
return u

Expand Down Expand Up @@ -2421,4 +2447,3 @@ cdef class arb(flint_scalar):
arb_hypgeom_coulomb(NULL, (<arb>G).val,
(<arb>l).val, (<arb>eta).val, (<arb>self).val, getprec())
return G

Loading

0 comments on commit 070b7e5

Please sign in to comment.