Skip to content

Commit

Permalink
Move OneElement from Zygote and overload setindex (#161) (#235)
Browse files Browse the repository at this point in the history
* Add Zeros(T, n...) and Ones(T, n...) constructors (#94( (#233)

* Add Zeros(T, n...) and Ones(T, n...) constructors (#94(

* increase coverage

* Update README.md

* Move over OneElement from Zygote

* Add tests

* Update oneelement.jl

* add tests

* Update runtests.jl

* add docs
  • Loading branch information
dlfivefifty authored Mar 29, 2023
1 parent 4498570 commit 5be668f
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 11 deletions.
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@ as well as identity matrices. This package exports the following types:


The primary purpose of this package is to present a unified way of constructing
matrices. For example, to construct a 5-by-5 `CLArray` of all zeros, one would use
```julia
julia> CLArray(Zeros(5,5))
```
Because `Zeros` is lazy, this can be accomplished on the GPU with no memory transfer.
Similarly, to construct a 5-by-5 `BandedMatrix` of all zeros with bandwidths `(1,2)`, one would use
matrices.
For example, to construct a 5-by-5 `BandedMatrix` of all zeros with bandwidths `(1,2)`, one would use
```julia
julia> BandedMatrix(Zeros(5,5), (1, 2))
```
Expand Down
5 changes: 4 additions & 1 deletion src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape
import Statistics: mean, std, var, cov, cor


export Zeros, Ones, Fill, Eye, Trues, Falses
export Zeros, Ones, Fill, Eye, Trues, Falses, OneElement

import Base: oneto

Expand Down Expand Up @@ -263,6 +263,7 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
@inline $Typ{T,N}(A::AbstractArray{V,N}) where{T,V,N} = $Typ{T,N}(size(A))
@inline $Typ{T}(A::AbstractArray) where{T} = $Typ{T}(size(A))
@inline $Typ(A::AbstractArray) = $Typ{eltype(A)}(A)
@inline $Typ(::Type{T}, m...) where T = $Typ{T}(m...)

@inline axes(Z::$Typ) = Z.axes
@inline size(Z::$Typ) = length.(Z.axes)
Expand Down Expand Up @@ -728,4 +729,6 @@ Base.@propagate_inbounds function view(A::AbstractFill{<:Any,N}, I::Vararg{Real,
fillsimilar(A)
end

include("oneelement.jl")

end # module
1 change: 0 additions & 1 deletion src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ end
*(a::ZerosMatrix, b::AbstractMatrix) = mult_zeros(a, b)
*(a::AbstractMatrix, b::ZerosVector) = mult_zeros(a, b)
*(a::AbstractMatrix, b::ZerosMatrix) = mult_zeros(a, b)
*(a::ZerosVector, b::AbstractVector) = mult_zeros(a, b)
*(a::ZerosMatrix, b::AbstractVector) = mult_zeros(a, b)
*(a::AbstractVector, b::ZerosMatrix) = mult_zeros(a, b)

Expand Down
51 changes: 51 additions & 0 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
OneElement(val, ind, axesorsize) <: AbstractArray
Represents an array with the specified axes (if its a tuple of `AbstractUnitRange`s)
or size (if its a tuple of `Integer`s), with a single entry set to `val` and all others equal to zero,
specified by `ind``.
"""
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
val::T
ind::I
axes::A
OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
end

OneElement(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where N = OneElement(val, inds, oneto.(sz))
"""
OneElement(val, ind::Int, n::Int)
Creates a length `n` vector where the `ind` entry is equal to `val`, and all other entries are zero.
"""
OneElement(val, ind::Int, len::Int) = OneElement(val, (ind,), (len,))
"""
OneElement(ind::Int, n::Int)
Creates a length `n` vector where the `ind` entry is equal to `1`, and all other entries are zero.
"""
OneElement(inds::Int, sz::Int) = OneElement(1, inds, sz)
OneElement{T}(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where {T,N} = OneElement(convert(T,val), inds, oneto.(sz))
OneElement{T}(val, inds::Int, sz::Int) where T = OneElement{T}(val, (inds,), (sz,))

"""
OneElement{T}(val, ind::Int, n::Int)
Creates a length `n` vector where the `ind` entry is equal to `one(T)`, and all other entries are zero.
"""
OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz)

Base.size(A::OneElement) = map(length, A.axes)
Base.axes(A::OneElement) = A.axes
function Base.getindex(A::OneElement{T,N}, kj::Vararg{Int,N}) where {T,N}
@boundscheck checkbounds(A, kj...)
ifelse(kj == A.ind, A.val, zero(T))
end

Base.replace_in_print_matrix(o::OneElement{<:Any,2}, k::Integer, j::Integer, s::AbstractString) =
o.ind == (k,j) ? s : Base.replace_with_centered_mark(s)

function Base.setindex(A::Zeros{T,N}, v, kj::Vararg{Int,N}) where {T,N}
@boundscheck checkbounds(A, kj...)
OneElement(convert(T, v), kj, axes(A))
end
29 changes: 26 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ include("infinitearrays.jl")

for T in (Int, Float64)
Z = $Typ{T}(5)
@test $Typ(T, 5) Z
@test eltype(Z) == T
@test Array(Z) == $funcs(T,5)
@test Array{T}(Z) == $funcs(T,5)
Expand All @@ -34,6 +35,7 @@ include("infinitearrays.jl")
@test $Typ(2ones(T,5)) == Z

Z = $Typ{T}(5, 5)
@test $Typ(T, 5, 5) Z
@test eltype(Z) == T
@test Array(Z) == $funcs(T,5,5)
@test Array{T}(Z) == $funcs(T,5,5)
Expand Down Expand Up @@ -525,9 +527,9 @@ end
@test_throws MethodError [1,2,3]*Zeros(3) # Not defined for [1,2,3]*[0,0,0] either

@testset "Check multiplication by Adjoint vectors works as expected." begin
@test randn(4, 3)' * Zeros(4) === Zeros(3)
@test randn(4)' * Zeros(4) === zero(Float64)
@test [1, 2, 3]' * Zeros{Int}(3) === zero(Int)
@test randn(4, 3)' * Zeros(4) Zeros(3)
@test randn(4)' * Zeros(4) transpose(randn(4)) * Zeros(4) zero(Float64)
@test [1, 2, 3]' * Zeros{Int}(3) zero(Int)
@test [SVector(1,2)', SVector(2,3)', SVector(3,4)']' * Zeros{Int}(3) === SVector(0,0)
@test_throws DimensionMismatch randn(4)' * Zeros(3)
@test Zeros(5)' * randn(5,3) Zeros(5)'*Zeros(5,3) Zeros(5)'*Ones(5,3) Zeros(3)'
Expand Down Expand Up @@ -1503,4 +1505,25 @@ end
@test Zeros(5,5) .+ D isa Diagonal
f = (x,y) -> x+1
@test f.(D, Zeros(5,5)) isa Matrix
end

@testset "OneElement" begin
e₁ = OneElement(2, 5)
@test e₁ == [0,1,0,0,0]
@test_throws BoundsError e₁[6]

e₁ = OneElement{Float64}(2, 5)
@test e₁ == [0,1,0,0,0]

v = OneElement{Float64}(2, 3, 4)
@test v == [0,0,2,0]

V = OneElement(2, (2,3), (3,4))
@test V == [0 0 0 0; 0 0 2 0; 0 0 0 0]

@test stringmime("text/plain", V) == "3×4 OneElement{$Int, 2, Tuple{$Int, $Int}, Tuple{Base.OneTo{$Int}, Base.OneTo{$Int}}}:\n ⋅ ⋅ ⋅ ⋅\n ⋅ ⋅ 2 ⋅\n ⋅ ⋅ ⋅ ⋅"

@test Base.setindex(Zeros(5), 2, 2) OneElement(2.0, 2, 5)
@test Base.setindex(Zeros(5,3), 2, 2, 3) OneElement(2.0, (2,3), (5,3))
@test_throws BoundsError Base.setindex(Zeros(5), 2, 6)
end

0 comments on commit 5be668f

Please sign in to comment.