Handle large integer exponents in matrix powers #55431

Draft
wants to merge 1 commit into
base: master
32 changes: 29 additions & 3 deletions stdlib/LinearAlgebra/src/dense.jl
@@ -514,12 +514,38 @@ function (^)(A::AbstractMatrix{T}, p::Integer) where T<:Integer
 end
 function integerpow(A::AbstractMatrix{T}, p) where T
     TT = promote_op(^, T, typeof(p))
-    return (TT == T ? A : convert(AbstractMatrix{TT}, A))^Integer(p)
+    ATT = TT == T ? A : convert(AbstractMatrix{TT}, A)
+    return _integerpow(ATT, p)
 end
+_integerpow(A::AbstractMatrix, p) = A^Integer(p)
+function _integerpow(A::AbstractMatrix, p::Union{Float32, Float64})
Comment on lines +520 to +521
Contributor

@mikmoore, Aug 9, 2024

Suggested change
-_integerpow(A::AbstractMatrix, p) = A^Integer(p)
-function _integerpow(A::AbstractMatrix, p::Union{Float32, Float64})
+_integerpow(A::AbstractMatrix, p::Integer) = A^p
+function _integerpow(A::AbstractMatrix, p)

It seems that the long definition here is the more general one. Apart from Float16 (and even there the overhead should be tiny), I'd feel safer with generic types that avoid the assumed-valid Integer conversion, since that's how we got here in the first place.

EDIT: probably my proposed _integerpow(A::AbstractMatrix, p::Integer) is redundant altogether, if not directly overwriting another definition. I haven't traced the whole dispatch logic here.

+    # For these exponent types, not all values may be converted to Int
+    # We split the exponentiation into parts for which the exponent can be converted to an Int
+    if p < 0
+        return _integerpow(inv(A), -p)
+    end

Perhaps explicitly add a case for p == 0 that returns the identity matrix.

Contributor Author

I think this is handled by the A^0 call, which would also ensure type-stability.
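
A small check of that point, as a sketch: the matrix below is an arbitrary example (it mirrors the one used in the tests), and the comparison only shows that a zero power already yields an identity of the same matrix type as any other integer power.

```julia
using LinearAlgebra

A = [1 1e-10; 0 1]   # arbitrary example matrix with Float64 entries

Z = A^0              # zero power: no special p == 0 branch is needed
println(Z == Matrix{Float64}(I, 2, 2))
println(typeof(Z) == typeof(A^3))
```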

hmm ok fair

+    m = typemax(Int)
Contributor

Suggested change
-    m = typemax(Int)
+    m = prevpow(2, typemax(UInt))

typemax(Int) will not be accurate here (although it is arguably not far off). On a 64-bit system, the promotion Int -> Float64 rounds typemax(Int) up to 2.0^63, hence the inaccurate divrem(2.0^64, typemax(Int)) == (2.0, 0.0). Also, I believe the cost of power_by_squaring scales partly with the count_ones of the power.
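
The rounding described here can be reproduced in isolation. This is a sketch assuming 64-bit integers (it uses `Int64` explicitly); the `BigInt` line is only there to show the exact quotient and remainder for comparison.

```julia
m = typemax(Int64)                  # 2^63 - 1, not exactly representable as Float64
mf = Float64(m)                     # rounds to the nearest Float64

println(mf == 2.0^63)               # the promoted divisor is 2^63
println(divrem(2.0^64, m))          # computed with the rounded divisor
println(divrem(big(2)^64, big(m)))  # exact integer arithmetic for comparison

# Number of set bits in each candidate exponent chunk
println(count_ones(m), " vs ", count_ones(UInt64(1) << 63))
```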

Contributor

I think that the correct value here is maxintfloat(typeof(p))

Contributor

I think you're right that it needs to be <= maxintfloat(typeof(p)), otherwise the rem could be wrong. So I think it needs to be UInt(min(prevpow(2, typemax(UInt)), maxintfloat(typeof(p)))).
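
A minimal sketch of the bound being discussed, under stated assumptions: `chunk_size` is a hypothetical helper (not part of the PR), and it uses a fixed cap of `2.0^62` in place of `prevpow(2, typemax(UInt))` so that the result fits in an `Int64`. It then checks, in exact integer arithmetic, that `divrem` with such a chunk size reproduces the exponent.

```julia
# Hypothetical helper, not part of the PR: a chunk size that is a power of two,
# fits in an Int64, and is exactly representable in the float type of `p`.
function chunk_size(p::AbstractFloat)
    int_cap = 2.0^62                             # assumed cap below typemax(Int64)
    float_cap = Float64(maxintfloat(typeof(p)))  # 2^53 for Float64, 2^24 for Float32
    return Int64(min(int_cap, float_cap))
end

p = 1e20
m = chunk_size(p)
q, r = divrem(p, m)   # m converts to Float64 without rounding

# Check p == m*q + r without any floating-point rounding
println(BigInt(q) * m + BigInt(r) == BigInt(p))
println((chunk_size(1.0f0), chunk_size(1.0)))
```

With a chunk size above `maxintfloat(typeof(p))`, the conversion of `m` to the float type may round, and the remainder is then taken with respect to a different divisor than the one used for `A^m`.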

+    if p <= m
Contributor

@mikmoore, Aug 9, 2024

The flow of all of this seems a little more complicated than necessary (maybe I'm wrong). What about just computing q, r = divrem(p, m) initially and branching on q > 0 rather than p <= m? It seems that could remove at least one if block.
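
One way the "divrem first" flow could look, as a sketch only: `chunkedpow` is a hypothetical name, the `isinteger` guard and the `2.0^62` cap on the chunk size are assumptions added for this example, and none of it is the PR's actual code.

```julia
using LinearAlgebra

# Hypothetical sketch of branching on the quotient instead of on p <= m.
function chunkedpow(A::AbstractMatrix, p::AbstractFloat)
    # Assumed guard: rejects non-integer, infinite and NaN exponents
    isinteger(p) || throw(ArgumentError("exponent must be integer-valued"))
    p < 0 && return chunkedpow(inv(A), -p)
    # Assumed chunk size: exact in typeof(p) and convertible to Int
    m = Int(min(2.0^62, Float64(maxintfloat(typeof(p)))))
    q, r = divrem(p, m)   # p == m*q + r with 0 <= r < m
    R = A^Int(r)          # also covers p == 0, where this is the identity
    q > 0 || return R
    Am = A^m
    # q == 1 needs no further power, so only recurse when q > 1
    return (q > 1 ? chunkedpow(Am, q) : Am) * R
end

A = [1 1e-10; 0 1]
println(chunkedpow(A, 1e20) ≈ [1 1e10; 0 1])
```

Small exponents take the `q > 0 || return R` exit, so the separate `p <= m` branch disappears, at the cost of always computing `divrem`.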

+        return A^Int(p)
+    end
+    # for large numbers, we express A^p as A^(m*q + r) == (A^m)^q * A^r
+    # Here, m may be safely represented as an Int,
+    # and we raise to the power of q by carrying out the decomposition recursively
+    A2 = A^Int(m)
+    q, r = divrem(p, m)
+    if q > 1

nit: this one can directly be q > 0

Contributor

But the q == 1 case is a no-op, so it's nicer to skip it.

+        A2 = integerpow(A2, q)
+    end
+    if !iszero(r)
+        A2 *= A^Int(r)
+    end
+    return A2
+end
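
The identity this function relies on, A^(m*q + r) == (A^m)^q * A^r, can be checked exactly on a small case. In this sketch the values of `p` and `m` are illustrative only (`m` stands in for the chunk size), and an integer matrix is used so that every product is exact.

```julia
A = [1 1; 0 1]        # integer matrix, so every power below is exact
p, m = 29, 8          # illustrative values; m stands in for the chunk size
q, r = divrem(p, m)   # p == m*q + r

lhs = A^p
rhs = (A^m)^q * A^r
println(lhs == rhs)
```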

 function schurpow(A::AbstractMatrix, p)
     if istriu(A)
         # Integer part
-        retmat = A ^ floor(Integer, p)
+        retmat = integerpow(A, floor(p))
         # Real part
         if p - floor(p) == 0.5
             # special case: A^0.5 === sqrt(A)
@@ -530,7 +556,7 @@ function schurpow(A::AbstractMatrix, p)
     else
         S,Q,d = Schur{Complex}(schur(A))
         # Integer part
-        R = S ^ floor(Integer, p)
+        R = integerpow(S, floor(p))
         # Real part
         if p - floor(p) == 0.5
             # special case: A^0.5 === sqrt(A)
11 changes: 11 additions & 0 deletions stdlib/LinearAlgebra/test/dense.jl
@@ -596,6 +596,17 @@ end

     A8 = 100 * [-1+1im 0 0 1e-8; 0 1 0 0; 0 0 1 0; 0 0 0 1]
     @test exp(log(A8)) ≈ A8
+
+    @testset "large exponents (issue #55300)" begin
+        A = [1 1e-10; 0 1]
+        B = A^(1e20)
+        @test B ≈ [1 1e10; 0 1]
+        B = A^(-1e20)
+        @test B ≈ inv(A)^1e20
+        @test (A^prevfloat(Inf64))[1,2] ≈ A[1,2] * prevfloat(Inf64)
+        @test (A^prevfloat(Inf32))[1,2] ≈ A[1,2] * prevfloat(Inf32)
+        @test (A^prevfloat(Inf16))[1,2] ≈ A[1,2] * prevfloat(Inf16)
+    end
 end
 
 @testset "Matrix trigonometry" begin