Skip to content

Commit

Permalink
fix: add __rdunder__ methods fmpz_series, fmpq_series
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarbenjamin committed Apr 21, 2023
1 parent 9f74e62 commit d5f0a61
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 174 deletions.
182 changes: 96 additions & 86 deletions src/flint/fmpq_series.pyx
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
cdef fmpq_series_coerce_operands(x, y):
if typecheck(x, fmpq_series):
if isinstance(y, (int, long, fmpz, fmpz_poly, fmpz_series, fmpq, fmpq_poly)):
return x, fmpq_series(y)
#if isinstance(y, (nmod, nmod_poly, nmod_series)):
# return nmod_series(x), nmod_series(y)
if isinstance(y, (float, arb, arb_poly, arb_series)):
return arb_series(x), arb_series(y)
if isinstance(y, (complex, acb, acb_poly, acb_series)):
return acb_series(x), acb_series(y)
else:
if isinstance(x,(int, long, fmpz, fmpz_poly, fmpz_series, fmpq, fmpq_poly)):
return fmpq_series(x), y
#if isinstance(x, (nmod, nmod_poly, nmod_series)):
# return nmod_series(x), nmod_series(y)
if isinstance(x, (float, arb, arb_poly, arb_series)):
return arb_series(x), arb_series(y)
if isinstance(x, (complex, acb, acb_poly, acb_series)):
return acb_series(x), acb_series(y)
if isinstance(y, (int, long, fmpz, fmpz_poly, fmpz_series, fmpq, fmpq_poly)):
return x, fmpq_series(y)
#if isinstance(y, (nmod, nmod_poly, nmod_series)):
# return nmod_series(x), nmod_series(y)
if isinstance(y, (float, arb, arb_poly, arb_series)):
return arb_series(x), arb_series(y)
if isinstance(y, (complex, acb, acb_poly, acb_series)):
return acb_series(x), acb_series(y)
return NotImplemented, NotImplemented

cdef class fmpq_series(flint_series):
Expand Down Expand Up @@ -124,54 +114,72 @@ cdef class fmpq_series(flint_series):
return u

def __add__(s, t):
if not isinstance(t, fmpq_series):
s, t = fmpq_series_coerce_operands(s, t)
if s is NotImplemented:
return s
return s + t
cdef long cap
if type(s) is type(t):
u = fmpq_series.__new__(fmpq_series)
cap = getcap()
cap = min(cap, (<fmpq_series>s).prec)
cap = min(cap, (<fmpq_series>t).prec)
if cap > 0:
fmpq_poly_add((<fmpq_series>u).val, (<fmpq_series>s).val, (<fmpq_series>t).val)
fmpq_poly_truncate((<fmpq_series>u).val, cap)
(<fmpq_series>u).prec = cap
return u
u = fmpq_series.__new__(fmpq_series)
cap = getcap()
cap = min(cap, (<fmpq_series>s).prec)
cap = min(cap, (<fmpq_series>t).prec)
if cap > 0:
fmpq_poly_add((<fmpq_series>u).val, (<fmpq_series>s).val, (<fmpq_series>t).val)
fmpq_poly_truncate((<fmpq_series>u).val, cap)
(<fmpq_series>u).prec = cap
return u

def __radd__(s, t):
s, t = fmpq_series_coerce_operands(s, t)
if s is NotImplemented:
return s
return s + t
return t + s

def __sub__(s, t):
if not isinstance(t, fmpq_series):
s, t = fmpq_series_coerce_operands(s, t)
if s is NotImplemented:
return s
return s - t
cdef long cap
if type(s) is type(t):
u = fmpq_series.__new__(fmpq_series)
cap = getcap()
cap = min(cap, (<fmpq_series>s).prec)
cap = min(cap, (<fmpq_series>t).prec)
if cap > 0:
fmpq_poly_sub((<fmpq_series>u).val, (<fmpq_series>s).val, (<fmpq_series>t).val)
fmpq_poly_truncate((<fmpq_series>u).val, cap)
(<fmpq_series>u).prec = cap
return u
u = fmpq_series.__new__(fmpq_series)
cap = getcap()
cap = min(cap, (<fmpq_series>s).prec)
cap = min(cap, (<fmpq_series>t).prec)
if cap > 0:
fmpq_poly_sub((<fmpq_series>u).val, (<fmpq_series>s).val, (<fmpq_series>t).val)
fmpq_poly_truncate((<fmpq_series>u).val, cap)
(<fmpq_series>u).prec = cap
return u

def __rsub__(s, t):
s, t = fmpq_series_coerce_operands(s, t)
if s is NotImplemented:
return s
return s - t
return t - s

def __mul__(s, t):
if not isinstance(t, fmpq_series):
s, t = fmpq_series_coerce_operands(s, t)
if s is NotImplemented:
return s
return s * t
cdef long cap
if type(s) is type(t):
u = fmpq_series.__new__(fmpq_series)
cap = getcap()
cap = min(cap, (<fmpq_series>s).prec)
cap = min(cap, (<fmpq_series>t).prec)
if cap > 0:
fmpq_poly_mullow((<fmpq_series>u).val, (<fmpq_series>s).val, (<fmpq_series>t).val, cap)
(<fmpq_series>u).prec = cap
return u
u = fmpq_series.__new__(fmpq_series)
cap = getcap()
cap = min(cap, (<fmpq_series>s).prec)
cap = min(cap, (<fmpq_series>t).prec)
if cap > 0:
fmpq_poly_mullow((<fmpq_series>u).val, (<fmpq_series>s).val, (<fmpq_series>t).val, cap)
(<fmpq_series>u).prec = cap
return u

def __rmul__(s, t):
s, t = fmpq_series_coerce_operands(s, t)
if s is NotImplemented:
return s
return s * t
return t * s

cpdef valuation(self):
cdef long i
Expand All @@ -186,54 +194,56 @@ cdef class fmpq_series(flint_series):
def _div_(s, t):
cdef long cap, sval, tval
cdef fmpq_poly_t stmp, ttmp
if type(s) is type(t):
cap = getcap()
cap = min(cap, (<fmpq_series>s).prec)
cap = min(cap, (<fmpq_series>t).prec)

if fmpq_poly_is_zero((<fmpq_series>t).val):
raise ZeroDivisionError("power series division")
cap = getcap()
cap = min(cap, (<fmpq_series>s).prec)
cap = min(cap, (<fmpq_series>t).prec)

u = fmpq_series.__new__(fmpq_series)
if fmpq_poly_is_zero((<fmpq_series>t).val):
raise ZeroDivisionError("power series division")

if fmpq_poly_is_zero((<fmpq_series>s).val):
u.cap = cap
return u
u = fmpq_series.__new__(fmpq_series)

sval = (<fmpq_series>s).valuation()
tval = (<fmpq_series>t).valuation()
if fmpq_poly_is_zero((<fmpq_series>s).val):
u.cap = cap
return u

if sval < tval:
raise ValueError("quotient would not be a power series")
sval = (<fmpq_series>s).valuation()
tval = (<fmpq_series>t).valuation()

if fmpz_is_zero(&((<fmpq_series>t).val.coeffs[tval])):
raise ValueError("leading term in denominator is not a unit")
if sval < tval:
raise ValueError("quotient would not be a power series")

if tval == 0:
fmpq_poly_div_series((<fmpq_series>u).val, (<fmpq_series>s).val, (<fmpq_series>t).val, cap)
else:
fmpq_poly_init(stmp)
fmpq_poly_init(ttmp)
fmpq_poly_shift_right(stmp, (<fmpq_series>s).val, tval)
fmpq_poly_shift_right(ttmp, (<fmpq_series>t).val, tval)
cap -= tval
fmpq_poly_div_series((<fmpq_series>u).val, stmp, ttmp, cap)
fmpq_poly_clear(stmp)
fmpq_poly_clear(ttmp)
if fmpz_is_zero(&((<fmpq_series>t).val.coeffs[tval])):
raise ValueError("leading term in denominator is not a unit")

(<fmpq_series>u).prec = cap
return u
if tval == 0:
fmpq_poly_div_series((<fmpq_series>u).val, (<fmpq_series>s).val, (<fmpq_series>t).val, cap)
else:
fmpq_poly_init(stmp)
fmpq_poly_init(ttmp)
fmpq_poly_shift_right(stmp, (<fmpq_series>s).val, tval)
fmpq_poly_shift_right(ttmp, (<fmpq_series>t).val, tval)
cap -= tval
fmpq_poly_div_series((<fmpq_series>u).val, stmp, ttmp, cap)
fmpq_poly_clear(stmp)
fmpq_poly_clear(ttmp)

s, t = fmpq_series_coerce_operands(s, t)
if s is NotImplemented:
return s
return s / t
(<fmpq_series>u).prec = cap
return u

def __truediv__(s, t):
if not isinstance(t, fmpq_series):
s, t = fmpq_series_coerce_operands(s, t)
if s is NotImplemented:
return s
return s / t
return fmpq_series._div_(s, t)

def __div__(s, t):
return fmpq_series._div_(s, t)
def __rtruediv__(s, t):
s, t = fmpq_series_coerce_operands(s, t)
if s is NotImplemented:
return s
return t / s

# generic exponentiation (fallback code)
def __pow__(s, ulong exp, mod):
Expand Down
Loading

0 comments on commit d5f0a61

Please sign in to comment.