Skip to content

Commit

Permalink
Rectangular Eye (#54)
Browse files Browse the repository at this point in the history
* Add RectDiagonal and use it to implement rectangular Eye

* Improve coverage for RectDiagonal
  • Loading branch information
iamed2 authored and dlfivefifty committed Dec 21, 2018
1 parent 78f0a74 commit 3ec04f6
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 34 deletions.
86 changes: 69 additions & 17 deletions src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,

import Base.\

import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!
import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag

import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape

Expand Down Expand Up @@ -198,7 +198,59 @@ rank(F::Zeros) = 0
rank(F::Ones) = 1


const Eye{T, Axes} = Diagonal{T, Ones{T,1,Tuple{Axes}}}
struct RectDiagonal{T,V<:AbstractVector{T},Axes<:Tuple{Vararg{AbstractUnitRange,2}}} <: AbstractMatrix{T}
diag::V
axes::Axes

@inline function RectDiagonal{T,V}(A::V, axes::Axes) where {T,V<:AbstractVector{T},Axes<:Tuple{Vararg{AbstractUnitRange,2}}}
@assert !Base.has_offset_axes(A)
@assert any(length(ax) == length(A) for ax in axes)
rd = new{T,V,Axes}(A, axes)
@assert !Base.has_offset_axes(rd)
return rd
end
end

@inline RectDiagonal{T,V}(A::V, sz::Tuple{Vararg{Integer, 2}}) where {T,V} = RectDiagonal{T,V}(A, Base.OneTo.(sz))
@inline RectDiagonal{T,V}(A::V, axes::Vararg{Any, 2}) where {T,V} = RectDiagonal{T,V}(A, axes)
@inline RectDiagonal{T,V}(A::V, sz::Vararg{Integer, 2}) where {T,V} = RectDiagonal{T,V}(A, sz)
@inline RectDiagonal{T,V}(A::V) where {T,V} = RectDiagonal{T,V}(A, (axes(A, 1), axes(A, 1)))
@inline RectDiagonal{T}(A::V, args...) where {T,V} = RectDiagonal{T,V}(A, args...)
@inline RectDiagonal(A::V, args...) where {V} = RectDiagonal{eltype(V),V}(A, args...)

axes(rd::RectDiagonal) = rd.axes
size(rd::RectDiagonal) = length.(rd.axes)

@inline function getindex(rd::RectDiagonal{T}, i::Integer, j::Integer) where T
@boundscheck checkbounds(rd, i, j)
if i == j
@inbounds r = rd.diag[i]
else
r = zero(T)
end
return r
end

function setindex!(rd::RectDiagonal, v, i::Integer, j::Integer)
@boundscheck checkbounds(rd, i, j)
if i == j
@inbounds rd.diag[i] = v
elseif !iszero(v)
throw(ArgumentError("cannot set off-diagonal entry ($i, $j) to a nonzero value ($v)"))
end
return v
end

diag(rd::RectDiagonal) = rd.diag

for f in (:triu, :triu!, :tril, :tril!)
@eval ($f)(M::RectDiagonal) = M
end


const RectOrDiagonal{T,V,Axes} = Union{RectDiagonal{T,V,Axes}, Diagonal{T,V}}
const SquareEye{T,Axes} = Diagonal{T,Ones{T,1,Tuple{Axes}}}
const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1,Tuple{Axes}}}

Eye{T}(n::Integer) where T = Diagonal(Ones{T}(n))
Eye(n::Integer) = Diagonal(Ones(n))
Expand All @@ -211,19 +263,19 @@ function iterate(iter::Eye, istate = (1, 1))
j == m ? (i + 1, 1) : (i, j + 1))
end

isone(::Eye) = true
isone(::SquareEye) = true

for f in (:permutedims, :triu, :triu!, :tril, :tril!, :inv)
@eval ($f)(IM::Eye) = IM
for f in (:permutedims, :inv, :triu, :triu!, :tril, :tril!)
@eval ($f)(IM::SquareEye) = IM
end

@deprecate Eye(n::Integer, m::Integer) view(Eye(max(n,m)), 1:n, 1:m)
@deprecate Eye{T}(n::Integer, m::Integer) where T view(Eye{T}(max(n,m)), 1:n, 1:m)
Eye(n::Integer, m::Integer) = RectDiagonal(Ones(min(n,m)), n, m)
Eye{T}(n::Integer, m::Integer) where T = RectDiagonal{T}(Ones{T}(min(n,m)), n, m)
@deprecate Eye{T}(sz::Tuple{Vararg{Integer,2}}) where T Eye{T}(sz...)
@deprecate Eye(sz::Tuple{Vararg{Integer,2}}) Eye{Float64}(sz...)

@inline Eye{T}(A::AbstractMatrix) where T = Eye{T}(size(A))
@inline Eye(A::AbstractMatrix) = Eye{eltype(A)}(size(A))
@inline Eye{T}(A::AbstractMatrix) where T = Eye{T}(size(A)...)
@inline Eye(A::AbstractMatrix) = Eye{eltype(A)}(size(A)...)


#########
Expand Down Expand Up @@ -350,17 +402,17 @@ end
# all(isempty, []) and any(isempty, []) have non-generic behavior.
# We do not follow it here for Eye(0).
function any(f::Function, IM::Eye{T}) where T
d = size(IM, 1)
d > 1 && return f(zero(T)) || f(one(T))
d == 1 && return f(one(T))
return false
d1, d2 = size(IM)
(d1 < 1 || d2 < 1) && return false
(d1 > 1 || d2 > 1) && return f(zero(T)) || f(one(T))
return f(one(T))
end

function all(f::Function, IM::Eye{T}) where T
d = size(IM, 1)
d > 1 && return f(zero(T)) && f(one(T))
d == 1 && return f(one(T))
return false
d1, d2 = size(IM)
(d1 < 1 || d2 < 1) && return false
(d1 > 1 || d2 > 1) && return f(zero(T)) && f(one(T))
return f(one(T))
end

# In particular, these make iszero(Eye(n)) efficient.
Expand Down
91 changes: 74 additions & 17 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using FillArrays, LinearAlgebra, SparseArrays, Random, Test
import FillArrays: AbstractFill
import FillArrays: AbstractFill, RectDiagonal

@testset "fill array constructors and convert" begin
for (Typ, funcs) in ((:Zeros, :zeros), (:Ones, :ones))
Expand Down Expand Up @@ -91,7 +91,10 @@ import FillArrays: AbstractFill

@test Eye(5) isa Diagonal{Float64}
@test Eye(5) == Eye{Float64}(5)
@test Eye(5,6) == Eye{Float64}(5,6)
@test Eye(Ones(5,6)) == Eye{Float64}(5,6)
@test eltype(Eye(5)) == Float64
@test eltype(Eye(5,6)) == Float64

for T in (Int, Float64)
E = Eye{T}(5)
Expand All @@ -108,6 +111,7 @@ import FillArrays: AbstractFill


@test AbstractArray{Float32}(E) == Eye{Float32}(5)
@test AbstractArray{Float32}(E) == Eye{Float32}(5, 5)
end

@testset "Bool should change type" begin
Expand All @@ -127,6 +131,44 @@ import FillArrays: AbstractFill
end
end

@testset "RectDiagonal" begin
data = 1:3
expected_size = (5, 3)
expected_axes = Base.OneTo.(expected_size)
expected_matrix = [1 0 0; 0 2 0; 0 0 3; 0 0 0; 0 0 0]
expected = RectDiagonal{Int, UnitRange{Int}}(data, expected_axes)

@test axes(expected) == expected_axes
@test size(expected) == expected_size
@test (axes(expected, 1), axes(expected, 2)) == expected_axes
@test (size(expected, 1), size(expected, 2)) == expected_size

@test expected == expected_matrix
@test Matrix(expected) == expected_matrix
@test expected[:, 2] == expected_matrix[:, 2]
@test expected[2, :] == expected_matrix[2, :]
@test expected[5, :] == expected_matrix[5, :]

for Typ in (RectDiagonal, RectDiagonal{Int}, RectDiagonal{Int, UnitRange{Int}})
@test Typ(data) == expected[1:3, 1:3]
@test Typ(data, expected_axes) == expected
@test Typ(data, expected_axes...) == expected
@test Typ(data, expected_size) == expected
@test Typ(data, expected_size...) == expected
end

@test diag(expected) === expected.diag

mut = RectDiagonal(collect(data), expected_axes)
@test mut == expected
@test mut == expected_matrix
mut[1, 1] = 5
@test mut[1] == 5
@test diag(mut) == [5, 2, 3]
mut[2, 1] = 0
@test_throws ArgumentError mut[2, 1] = 9
end

# Check that all pair-wise combinations of + / - elements of As and Bs yield the correct
# type, and produce numerically correct results.
function test_addition_and_subtraction(As, Bs, Tout::Type)
Expand Down Expand Up @@ -194,6 +236,7 @@ end
@test_throws BoundsError convert(Diagonal{Int}, Zeros(8,5))


@test Diagonal(Eye(8,5)) == Diagonal(ones(5))
@test convert(Diagonal, Eye(5)) == Diagonal(ones(5))
@test convert(Diagonal{Int}, Eye(5)) == Diagonal(ones(Int,5))
end
Expand All @@ -210,7 +253,7 @@ end
spzeros(5)

for (Mat, SMat) in ((Zeros(5,5), spzeros(5,5)), (Zeros(6,5), spzeros(6,5)),
(Eye(5), sparse(I,5,5)))
(Eye(5), sparse(I,5,5)), (Eye(6,5), sparse(I,6,5)))
@test SparseMatrixCSC(Mat) ==
SparseMatrixCSC{Float64}(Mat) ==
SparseMatrixCSC{Float64,Int}(Mat) ==
Expand Down Expand Up @@ -238,12 +281,12 @@ end
@test axes(A) == tuple(Base.OneTo{BigInt}(BigInt(100)))
@test size(A) isa Tuple{BigInt}
end
let A = Eye(BigInt(100))
for A in (Eye(BigInt(100)), Eye(BigInt(100), BigInt(100)))
@test length(A) isa BigInt
@test axes(A) == tuple(Base.OneTo{BigInt}(BigInt(100)),Base.OneTo{BigInt}(BigInt(100)))
@test size(A) isa Tuple{BigInt,BigInt}
end
for A in (Zeros(BigInt(10), 10), Ones(BigInt(10), 10), Fill(2.0, (BigInt(10), 10)))
for A in (Zeros(BigInt(10), 10), Ones(BigInt(10), 10), Fill(2.0, (BigInt(10), 10)), Eye(BigInt(10), 8))
@test size(A) isa Tuple{BigInt,Int}
end

Expand Down Expand Up @@ -498,19 +541,26 @@ end

@testset "any all iszero isone" begin
for T in (Int, Float64, ComplexF64)
for d in (0, )
m = Eye{T}(d)
for m in (Eye{T}(0), Eye{T}(0, 0), Eye{T}(0, 1), Eye{T}(1, 0))
@test ! any(isone, m)
@test ! any(iszero, m)
@test ! all(iszero, m)
@test ! all(isone, m)
end
for d in (1, )
m = Eye{T}(d)
@test ! any(iszero, m)
@test ! all(iszero, m)
@test any(isone, m)
@test all(isone, m)
for m in (Eye{T}(d), Eye{T}(d, d))
@test ! any(iszero, m)
@test ! all(iszero, m)
@test any(isone, m)
@test all(isone, m)
end

for m in (Eye{T}(d, d + 1), Eye{T}(d + 1, d))
@test any(iszero, m)
@test ! all(iszero, m)
@test any(isone, m)
@test ! all(isone, m)
end

onem = Ones{T}(d, d)
@test isone(onem)
Expand All @@ -533,11 +583,12 @@ end
@test ! iszero(fillm2)
end
for d in (2, 3)
m = Eye{T}(d)
@test any(iszero, m)
@test ! all(iszero, m)
@test any(isone, m)
@test ! all(isone, m)
for m in (Eye{T}(d), Eye{T}(d, d), Eye{T}(d, d + 2), Eye{T}(d + 2, d))
@test any(iszero, m)
@test ! all(iszero, m)
@test any(isone, m)
@test ! all(isone, m)
end

m1 = Ones{T}(d, d)
@test ! isone(m1)
Expand Down Expand Up @@ -567,9 +618,15 @@ end

@testset "Eye identity ops" begin
m = Eye(10)
for op in (permutedims, inv, tril, triu, tril!, triu!)
for op in (permutedims, inv)
@test op(m) === m
end

for m in (Eye(10), Eye(10, 10), Eye(10, 8), Eye(8, 10))
for op in (tril, triu, tril!, triu!)
@test op(m) === m
end
end
end

@testset "Issue #31" begin
Expand Down

0 comments on commit 3ec04f6

Please sign in to comment.