Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add compose_mod and powmod with large exp #174

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/flint/flintlib/nmod_poly.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ cdef extern from "flint/nmod_poly.h":
int nmod_poly_equal_trunc(const nmod_poly_t poly1, const nmod_poly_t poly2, slong n)
int nmod_poly_is_zero(const nmod_poly_t poly)
int nmod_poly_is_one(const nmod_poly_t poly)
int nmod_poly_is_gen(const nmod_poly_t poly)
void _nmod_poly_shift_left(mp_ptr res, mp_srcptr poly, slong len, slong k)
void nmod_poly_shift_left(nmod_poly_t res, const nmod_poly_t poly, slong k)
void _nmod_poly_shift_right(mp_ptr res, mp_srcptr poly, slong len, slong k)
Expand Down
43 changes: 35 additions & 8 deletions src/flint/test/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,7 +1422,15 @@ def test_nmod_poly():
assert raises(lambda: [] * s, TypeError)
assert raises(lambda: [] // s, TypeError)
assert raises(lambda: [] % s, TypeError)
assert raises(lambda: pow(P([1,2],3), 3, 4), NotImplementedError)
assert raises(lambda: [] % s, TypeError)
assert raises(lambda: s.reverse(-1), ValueError)
assert raises(lambda: s.compose("A"), TypeError)
assert raises(lambda: s.compose_mod(s, "A"), TypeError)
assert raises(lambda: s.compose_mod("A", P([3,6,9],17)), TypeError)
assert raises(lambda: s.compose_mod(s, P([0], 17)), ZeroDivisionError)
assert raises(lambda: pow(s, -1, P([3,6,9],17)), ValueError)
assert raises(lambda: pow(s, 1, "A"), TypeError)
assert raises(lambda: pow(s, "A", P([3,6,9],17)), TypeError)
assert str(P([1,2,3],17)) == "3*x^2 + 2*x + 1"
assert P([1,2,3],17).repr() == "nmod_poly([1, 2, 3], 17)"
p = P([3,4,5],17)
Expand Down Expand Up @@ -2087,6 +2095,18 @@ def test_fmpz_mod_poly():
assert f*f == f**2
assert f*f == f**fmpz(2)

# pow_mod
# assert ui and fmpz exp agree for polynomials and generators
R_gen = R_test.gen()
assert pow(f, 2**60, g) == pow(pow(f, 2**30, g), 2**30, g)
assert pow(R_gen, 2**60, g) == pow(pow(R_gen, 2**30, g), 2**30, g)

# Check other typechecks for pow_mod
assert raises(lambda: pow(f, -2, g), ValueError)
assert raises(lambda: pow(f, 1, "A"), TypeError)
assert raises(lambda: pow(f, "A", g), TypeError)
assert raises(lambda: f.pow_mod(2**32, g, mod_rev_inv="A"), TypeError)

# Shifts
assert raises(lambda: R_test([1,2,3]).left_shift(-1), ValueError)
assert raises(lambda: R_test([1,2,3]).right_shift(-1), ValueError)
Expand Down Expand Up @@ -2118,6 +2138,13 @@ def test_fmpz_mod_poly():
# compose
assert raises(lambda: h.compose("AAA"), TypeError)

# compose mod
mod = R_test([1,2,3,4])
assert f.compose(h) % mod == f.compose_mod(h, mod)
assert raises(lambda: h.compose_mod("AAA", mod), TypeError)
assert raises(lambda: h.compose_mod(f, "AAA"), TypeError)
assert raises(lambda: h.compose_mod(f, R_test(0)), ZeroDivisionError)

# Reverse
assert raises(lambda: h.reverse(degree=-100), ValueError)
assert R_test([-1,-2,-3]).reverse() == R_test([-3,-2,-1])
Expand All @@ -2135,9 +2162,9 @@ def test_fmpz_mod_poly():
assert raises(lambda: f.mulmod(f, "AAA"), TypeError)
assert raises(lambda: f.mulmod("AAA", g), TypeError)

# powmod
assert f.powmod(2, g) == (f*f) % g
assert raises(lambda: f.powmod(2, "AAA"), TypeError)
# pow_mod
assert f.pow_mod(2, g) == (f*f) % g
assert raises(lambda: f.pow_mod(2, "AAA"), TypeError)

# divmod
S, T = f.divmod(g)
Expand Down Expand Up @@ -2635,9 +2662,10 @@ def setbad(obj, i, val):
assert P([1, 1]) ** 2 == P([1, 2, 1])
assert raises(lambda: P([1, 1]) ** -1, ValueError)
assert raises(lambda: P([1, 1]) ** None, TypeError)

# # XXX: Not sure what this should do in general:
assert raises(lambda: pow(P([1, 1]), 2, 3), NotImplementedError)

# XXX: Not sure what this should do in general:
# TODO: this now fails as fmpz_mod_poly allows modulus
# assert raises(lambda: pow(P([1, 1]), 2, 3), NotImplementedError)
GiacomoPope marked this conversation as resolved.
Show resolved Hide resolved

assert P([1, 2, 1]).gcd(P([1, 1])) == P([1, 1])
assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError)
Expand Down Expand Up @@ -2667,7 +2695,6 @@ def setbad(obj, i, val):
if is_field:
assert P([1, 2, 1]).integral() == P([0, 1, 1, S(1)/3])


def _all_mpolys():
return [
(flint.fmpz_mpoly, flint.fmpz_mpoly_ctx, flint.fmpz, False),
Expand Down
97 changes: 86 additions & 11 deletions src/flint/types/fmpz_mod_poly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ cdef class fmpz_mod_poly(flint_poly):

def __pow__(self, e, mod=None):
if mod is not None:
raise NotImplementedError
return self.pow_mod(e, mod)

cdef fmpz_mod_poly res
if e < 0:
Expand Down Expand Up @@ -778,11 +778,11 @@ cdef class fmpz_mod_poly(flint_poly):

return evaluations

def compose(self, input):
def compose(self, other):
"""
Returns the composition of two polynomials

To be precise about the order of composition, given ``self``, and ``input``
To be precise about the order of composition, given ``self``, and ``other``
by `f(x)`, `g(x)`, returns `f(g(x))`.

>>> R = fmpz_mod_poly_ctx(163)
Expand All @@ -794,12 +794,45 @@ cdef class fmpz_mod_poly(flint_poly):
9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1
"""
cdef fmpz_mod_poly res
val = self.ctx.any_as_fmpz_mod_poly(input)
val = self.ctx.any_as_fmpz_mod_poly(other)
if val is NotImplemented:
raise TypeError(f"Cannot compose the polynomial with input: {input}")
raise TypeError(f"Cannot compose the polynomial with input: {other}")

res = self.ctx.new_ctype_poly()
fmpz_mod_poly_compose(res.val, self.val, (<fmpz_mod_poly>val).val, self.ctx.mod.val)
return res

def compose_mod(self, other, modulus):
"""
Returns the composition of two polynomials modulo a third.

To be precise about the order of composition, given ``self``, and ``other``
and ``modulus`` by `f(x)`, `g(x)` and `h(x)`, returns `f(g(x)) \mod h(x)`.
We require that `h(x)` is non-zero.

>>> R = fmpz_mod_poly_ctx(163)
>>> f = R([1,2,3,4,5])
>>> g = R([3,2,1])
>>> h = R([1,0,1,0,1])
>>> f.compose_mod(g, h)
63*x^3 + 100*x^2 + 17*x + 63
>>> g.compose_mod(f, h)
147*x^3 + 159*x^2 + 4*x + 7
"""
cdef fmpz_mod_poly res
val = self.ctx.any_as_fmpz_mod_poly(other)
if val is NotImplemented:
raise TypeError(f"cannot compose the polynomial with input: {other}")

h = self.ctx.any_as_fmpz_mod_poly(modulus)
if h is NotImplemented:
raise TypeError(f"cannot reduce the polynomial with input: {modulus}")

if h.is_zero():
raise ZeroDivisionError("cannot reduce modulo zero")

res = self.ctx.new_ctype_poly()
fmpz_mod_poly_compose_mod(res.val, self.val, (<fmpz_mod_poly>val).val, (<fmpz_mod_poly>h).val, self.ctx.mod.val)
return res

cpdef long length(self):
Expand Down Expand Up @@ -1104,30 +1137,72 @@ cdef class fmpz_mod_poly(flint_poly):
)
return res

def powmod(self, e, modulus):
def pow_mod(self, e, modulus, mod_rev_inv=None):
"""
Returns ``self`` raised to the power ``e`` modulo ``modulus``:
:math:`f^e \mod g`
:math:`f^e \mod g`/

``mod_rev_inv`` is the inverse of the reverse of the modulus,
precomputing it and passing it to ``pow_mod()`` can optimise
powering of polynomials with large exponents.

>>> R = fmpz_mod_poly_ctx(163)
>>> x = R.gen()
>>> f = 30*x**6 + 104*x**5 + 76*x**4 + 33*x**3 + 70*x**2 + 44*x + 65
>>> g = 43*x**6 + 91*x**5 + 77*x**4 + 113*x**3 + 71*x**2 + 132*x + 60
>>> mod = x**4 + 93*x**3 + 78*x**2 + 72*x + 149
>>>
>>> f.powmod(123, mod)
>>> f.pow_mod(123, mod)
3*x^3 + 25*x^2 + 115*x + 161
>>> f.pow_mod(2**64, mod)
52*x^3 + 96*x^2 + 136*x + 9
>>> mod_rev_inv = mod.reverse().inverse_series_trunc(4)
>>> f.pow_mod(2**64, mod, mod_rev_inv)
52*x^3 + 96*x^2 + 136*x + 9
"""
cdef fmpz_mod_poly res

if e < 0:
raise ValueError("Exponent must be non-negative")

modulus = self.ctx.any_as_fmpz_mod_poly(modulus)
if modulus is NotImplemented:
raise TypeError(f"Cannot interpret {modulus} as a polynomial")

# Output polynomial
res = self.ctx.new_ctype_poly()
fmpz_mod_poly_powmod_ui_binexp(
res.val, self.val, <ulong>e, (<fmpz_mod_poly>modulus).val, res.ctx.mod.val
)

# For small exponents, use a simple binary exponentiation method
if e.bit_length() < 32:
fmpz_mod_poly_powmod_ui_binexp(
res.val, self.val, <ulong>e, (<fmpz_mod_poly>modulus).val, res.ctx.mod.val
)
return res

# For larger exponents we need to cast e to an fmpz first
e_fmpz = any_as_fmpz(e)
if e_fmpz is NotImplemented:
raise TypeError(f"exponent cannot be cast to an fmpz type: {e = }")

# To optimise powering, we precompute the inverse of the reverse of the modulus
if mod_rev_inv is not None:
mod_rev_inv = self.ctx.any_as_fmpz_mod_poly(mod_rev_inv)
if mod_rev_inv is NotImplemented:
raise TypeError(f"Cannot interpret {mod_rev_inv} as a polynomial")
else:
mod_rev_inv = modulus.reverse().inverse_series_trunc(modulus.length())

# Use windowed exponentiation optimisation when self = x
if self.is_gen():
fmpz_mod_poly_powmod_x_fmpz_preinv(
res.val, (<fmpz>e_fmpz).val, (<fmpz_mod_poly>modulus).val, (<fmpz_mod_poly>mod_rev_inv).val, res.ctx.mod.val
)
return res

# Otherwise using binary exponentiation for all other inputs
fmpz_mod_poly_powmod_fmpz_binexp_preinv(
res.val, self.val, (<fmpz>e_fmpz).val, (<fmpz_mod_poly>modulus).val, (<fmpz_mod_poly>mod_rev_inv).val, res.ctx.mod.val
)
return res

def divmod(self, other):
Expand Down
Loading
Loading