Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add range indexing to UniformScaling #24359

Merged
merged 12 commits into from
Jun 6, 2020
Prev Previous commit
Next Next commit
Make getindex(::UniformScaling, ranges...) return a dense array
dalum committed May 30, 2020
commit d34ae5c92d304fd3dfd8f5d397e31946823ea78f
27 changes: 13 additions & 14 deletions stdlib/LinearAlgebra/src/uniformscaling.jl
Original file line number Diff line number Diff line change
@@ -79,26 +79,25 @@ Base.has_offset_axes(::UniformScaling) = false
getindex(J::UniformScaling, i::Integer,j::Integer) = ifelse(i==j,J.λ,zero(J.λ))

function getindex(x::UniformScaling{T}, n::AbstractRange{<:Integer}, m::AbstractRange{<:Integer}) where T
if length(n) == length(m) && step(n) == step(m)
k = first(n) - first(m)
if k % step(n) == 0 && length(n) - abs(k) > 0
v = fill(x.λ, length(n) - abs(k))
return spdiagm(k => v)
if step(n) == step(m)
k = (first(n) - first(m))
if k % step(n) == 0
k = div(k, step(n))
p = abs(length(n) - length(m))
c = abs(p) < abs(k) ? abs(k) - abs(p) : 0
v = fill(x.λ, min(length(n), length(m)) - c)
return diagm(length(n), length(m), k => v)
else
return spzeros(T, length(n), length(m))
return zeros(T, length(n), length(m))
end
end
I = Int[]
J = Int[]
V = T[]
@inbounds for (i,ii) in enumerate(n), (j,jj) in enumerate(m)
A = zeros(T, length(n), length(m))
@inbounds for (j,jj) in enumerate(m), (i,ii) in enumerate(n)
if ii == jj
push!(I, i)
push!(J, j)
push!(V, x.λ)
A[i,j] = x.λ
end
end
return sparse(I, J, V, length(n), length(m))
return A
end

function show(io::IO, ::MIME"text/plain", J::UniformScaling)
24 changes: 18 additions & 6 deletions stdlib/LinearAlgebra/test/uniformscaling.jl
Original file line number Diff line number Diff line change
@@ -28,12 +28,24 @@ end
@testset "getindex" begin
@test I[1,1] == 1
@test I[1,2] == 0
@test I[1:2,1:2] == eye(2,2)[1:2,1:2]
@test I[1:2:3,1:2:3] == eye(3,3)[1:2:3,1:2:3]
@test I[1:2:8,2:2:9] == eye(10,10)[1:2:8,2:2:9]
@test I[1:2,2:3] == eye(3,3)[1:2,2:3]
@test I[2:3,1:2] == eye(3,3)[2:3,1:2]
@test I[2:-1:1,1:2] == eye(2,2)[2:-1:1,1:2]

J = I(15)
for (a, b) in [
(1:2, 1:2),
(1:2:3, 1:2:3),
(1:2:8, 2:2:9),
(2:2:9, 1:2:8),
(1:2:8, 9:-4:1),
(9:-4:1, 1:2:8),
(1:2, 2:3),
(2:3, 1:2),
(2:-1:1, 1:2),
(1:2, 2:-1:1),
(1:2:9, 5:2:13),
(5:2:13, 1:2:9),
]
@test I[a,b] == J[a,b]
end
end

@testset "sqrt, exp, log, and trigonometric functions" begin