Skip to content

Commit

Permalink
Merge branch 'master' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
putianyi889 authored Mar 30, 2023
2 parents ca3e962 + f589311 commit 0abb814
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 90 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FillArrays"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.13.8"
version = "1.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -9,7 +9,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Aqua = "0.5"
Aqua = "0.5, 0.6"
julia = "1.6"

[extras]
Expand Down
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@ as well as identity matrices. This package exports the following types:


The primary purpose of this package is to present a unified way of constructing
matrices. For example, to construct a 5-by-5 `CLArray` of all zeros, one would use
```julia
julia> CLArray(Zeros(5,5))
```
Because `Zeros` is lazy, this can be accomplished on the GPU with no memory transfer.
Similarly, to construct a 5-by-5 `BandedMatrix` of all zeros with bandwidths `(1,2)`, one would use
matrices.
For example, to construct a 5-by-5 `BandedMatrix` of all zeros with bandwidths `(1,2)`, one would use
```julia
julia> BandedMatrix(Zeros(5,5), (1, 2))
```
Expand Down
79 changes: 63 additions & 16 deletions src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
+, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!,
any, all, axes, isone, iterate, unique, allunique, permutedims, inv,
copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero,
show, view, in, mapreduce, one, reverse, promote_op
show, view, in, mapreduce, one, reverse, promote_op, promote_rule

import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec,
Expand All @@ -18,7 +18,7 @@ import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape
import Statistics: mean, std, var, cov, cor


export Zeros, Ones, Fill, Eye, Trues, Falses
export Zeros, Ones, Fill, Eye, Trues, Falses, OneElement

import Base: oneto

Expand All @@ -34,13 +34,14 @@ const AbstractFillVecOrMat{T} = Union{AbstractFillVector{T},AbstractFillMatrix{T

==(a::AbstractFill, b::AbstractFill) = axes(a) == axes(b) && getindex_value(a) == getindex_value(b)

Base.@propagate_inbounds @inline function _fill_getindex(F::AbstractFill, kj::Integer...)

@inline function _fill_getindex(F::AbstractFill, kj::Integer...)
@boundscheck checkbounds(F, kj...)
getindex_value(F)
end

getindex(F::AbstractFill, k::Integer) = _fill_getindex(F, k)
getindex(F::AbstractFill{T, N}, kj::Vararg{Integer, N}) where {T, N} = _fill_getindex(F, kj...)
Base.@propagate_inbounds getindex(F::AbstractFill, k::Integer) = _fill_getindex(F, k)
Base.@propagate_inbounds getindex(F::AbstractFill{T, N}, kj::Vararg{Integer, N}) where {T, N} = _fill_getindex(F, kj...)

@inline function setindex!(F::AbstractFill, v, k::Integer)
@boundscheck checkbounds(F, k)
Expand Down Expand Up @@ -147,15 +148,47 @@ Fill{T,0}(x::T, ::Tuple{}) where T = Fill{T,0,Tuple{}}(x, ()) # ambiguity fix
@inline axes(F::Fill) = F.axes
@inline size(F::Fill) = map(length, F.axes)

"""
getindex_value(F::AbstractFill)
Return the value that `F` is filled with.
# Examples
```jldoctest
julia> f = Ones(3);
julia> FillArrays.getindex_value(f)
1.0
julia> g = Fill(2, 10);
julia> FillArrays.getindex_value(g)
2
```
"""
getindex_value

@inline getindex_value(F::Fill) = F.value

AbstractArray{T}(F::Fill{V,N}) where {T,V,N} = Fill{T}(convert(T, F.value)::T, F.axes)
AbstractArray{T,N}(F::Fill{V,N}) where {T,V,N} = Fill{T}(convert(T, F.value)::T, F.axes)
AbstractFill{T}(F::AbstractFill) where T = AbstractArray{T}(F)
AbstractFill{T,N}(F::AbstractFill) where {T,N} = AbstractArray{T,N}(F)
AbstractFill{T,N,Ax}(F::AbstractFill{<:Any,N,Ax}) where {T,N,Ax} = AbstractArray{T,N}(F)

convert(::Type{AbstractFill{T}}, F::AbstractFill) where T = convert(AbstractArray{T}, F)
convert(::Type{AbstractFill{T,N}}, F::AbstractFill) where {T,N} = convert(AbstractArray{T,N}, F)
convert(::Type{AbstractFill{T,N,Ax}}, F::AbstractFill{<:Any,N,Ax}) where {T,N,Ax} = convert(AbstractArray{T,N}, F)

copy(F::Fill) = Fill(F.value, F.axes)

""" Throws an error if `arr` does not contain one and only one unique value. """
"""
unique_value(arr::AbstractArray)
Return `only(unique(arr))` without intermediate allocations.
Throws an error if `arr` does not contain one and only one unique value.
"""
function unique_value(arr::AbstractArray)
if isempty(arr) error("Cannot convert empty array to Fill") end
val = first(arr)
Expand Down Expand Up @@ -195,8 +228,8 @@ Base.@propagate_inbounds @inline Base._unsafe_getindex(::IndexStyle, F::Abstract



getindex(A::AbstractFill, kr::AbstractVector{Bool}) = _fill_getindex(A, kr)
getindex(A::AbstractFill, kr::AbstractArray{Bool}) = _fill_getindex(A, kr)
Base.@propagate_inbounds getindex(A::AbstractFill, kr::AbstractVector{Bool}) = _fill_getindex(A, kr)
Base.@propagate_inbounds getindex(A::AbstractFill, kr::AbstractArray{Bool}) = _fill_getindex(A, kr)

@inline Base.iterate(F::AbstractFill) = length(F) == 0 ? nothing : (v = getindex_value(F); (v, (v, 1)))
@inline function Base.iterate(F::AbstractFill, (v, n))
Expand Down Expand Up @@ -262,6 +295,7 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
@inline $Typ{T,N}(A::AbstractArray{V,N}) where{T,V,N} = $Typ{T,N}(size(A))
@inline $Typ{T}(A::AbstractArray) where{T} = $Typ{T}(size(A))
@inline $Typ(A::AbstractArray) = $Typ{eltype(A)}(A)
@inline $Typ(::Type{T}, m...) where T = $Typ{T}(m...)

@inline axes(Z::$Typ) = Z.axes
@inline size(Z::$Typ) = length.(Z.axes)
Expand All @@ -273,6 +307,14 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
copy(F::$Typ) = F

getindex(F::$Typ{T,0}) where T = getindex_value(F)

promote_rule(::Type{$Typ{T, N, Axes}}, ::Type{$Typ{V, N, Axes}}) where {T,V,N,Axes} = $Typ{promote_type(T,V),N,Axes}
function convert(::Type{$Typ{T,N,Axes}}, A::$Typ{V,N,Axes}) where {T,V,N,Axes}
convert(T, getindex_value(A)) # checks that the types are convertible
$Typ{T,N,Axes}(axes(A))
end
convert(::Type{$Typ{T,N}}, A::$Typ{V,N,Axes}) where {T,V,N,Axes} = convert($Typ{T,N,Axes}, A)
convert(::Type{$Typ{T}}, A::$Typ{V,N,Axes}) where {T,V,N,Axes} = convert($Typ{T,N,Axes}, A)
end
end

Expand All @@ -284,6 +326,8 @@ for TYPE in (:Fill, :AbstractFill, :Ones, :Zeros), STYPE in (:AbstractArray, :Ab
end
end

promote_rule(::Type{<:AbstractFill{T, N, Axes}}, ::Type{<:AbstractFill{V, N, Axes}}) where {T,V,N,Axes} = Fill{promote_type(T,V),N,Axes}

"""
fillsimilar(a::AbstractFill, axes)
Expand Down Expand Up @@ -426,7 +470,8 @@ end


## Array
Base.Array{T,N}(F::AbstractFill{V,N}) where {T,V,N} = fill(convert(T, getindex_value(F)), size(F))
Base.Array{T,N}(F::AbstractFill{V,N}) where {T,V,N} =
convert(Array{T,N}, fill(convert(T, getindex_value(F)), size(F)))

# These are in case `zeros` or `ones` are ever faster than `fill`
for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
Expand All @@ -437,7 +482,7 @@ end

# temporary patch. should be a PR(#48895) to LinearAlgebra
Diagonal{T}(A::AbstractFillMatrix) where T = Diagonal{T}(diag(A))
function convert(::Type{T}, A::AbstractFillMatrix) where T<:Diagonal
function convert(::Type{T}, A::AbstractFillMatrix) where T<:Diagonal
checksquare(A)
isdiag(A) ? T(A) : throw(InexactError(:convert, T, A))
end
Expand Down Expand Up @@ -496,14 +541,14 @@ sum(x::AbstractFill) = getindex_value(x)*length(x)
sum(f, x::AbstractFill) = length(x) * f(getindex_value(x))
sum(x::Zeros) = getindex_value(x)

cumsum(x::AbstractFill{<:Any,1}) = range(getindex_value(x); step=getindex_value(x),
length=length(x))

# needed to support infinite case
steprangelen(st...) = StepRangeLen(st...)
cumsum(x::AbstractFill{<:Any,1}) = steprangelen(getindex_value(x), getindex_value(x), length(x))

cumsum(x::ZerosVector) = x
cumsum(x::ZerosVector{Bool}) = x
cumsum(x::OnesVector{II}) where II<:Integer = convert(AbstractVector{II}, oneto(length(x)))
cumsum(x::OnesVector{Bool}) = oneto(length(x))
cumsum(x::AbstractFillVector{Bool}) = cumsum(AbstractFill{Int}(x))

#########
# Diff
Expand Down Expand Up @@ -706,13 +751,15 @@ copy(a::LinearAlgebra.Transpose{<:Any,<:AbstractFill}) = transpose(parent(a))

Base.@propagate_inbounds view(A::AbstractFill{<:Any,N}, kr::AbstractArray{Bool,N}) where N = _fill_getindex(A, kr)
Base.@propagate_inbounds view(A::AbstractFill{<:Any,1}, kr::AbstractVector{Bool}) = _fill_getindex(A, kr)
Base.@propagate_inbounds view(A::AbstractFill{<:Any,N}, I::Vararg{Union{Real, AbstractArray}, N}) where N =
Base.@propagate_inbounds view(A::AbstractFill, I...) =
_fill_getindex(A, Base.to_indices(A,I)...)

# not getindex since we need array-like indexing
Base.@propagate_inbounds function view(A::AbstractFill{<:Any,N}, I::Vararg{Real, N}) where N
Base.@propagate_inbounds function view(A::AbstractFill, I::Vararg{Real})
@boundscheck checkbounds(A, I...)
fillsimilar(A)
end

include("oneelement.jl")

end # module
54 changes: 35 additions & 19 deletions src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ for TYPE in (:AbstractFillVector, :AbstractVector)
@eval *(a::ZerosMatrix, b::$TYPE) = mult_zeros(a,b)
end

function *(a::Diagonal, b::AbstractFill{<:Any,2})
function *(a::Diagonal, b::AbstractFillMatrix)
checkdimensionmismatch(a,b)
a.diag .* b # use special broadcast
end
Expand Down Expand Up @@ -137,21 +137,22 @@ function *(a::Transpose{T, <:AbstractVector{T}}, b::ZerosVector{T}) where T<:Rea
end
*(a::Transpose{T, <:AbstractMatrix{T}}, b::ZerosVector{T}) where T<:Real = mult_zeros(a, b)

# treat zero separately to support ∞-vectors
function _fill_dot(a::AbstractVector, b::AbstractVector)
# support types with fast sum
# infinite cases should be supported in InfiniteArrays.jl
# type issues of Bool dot are ignored at present.
function _fill_dot(a::AbstractFillVector{T}, b::AbstractVector{V}) where {T,V}
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
if iszero(a) || iszero(b)
zero(promote_type(eltype(a),eltype(b)))
elseif isa(a,AbstractFill)
getindex_value(a)sum(b)
else
getindex_value(b)sum(a)
end
dot(getindex_value(a), sum(b))
end

function _fill_dot_rev(a::AbstractVector{T}, b::AbstractFillVector{V}) where {T,V}
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
dot(sum(a), getindex_value(b))
end

dot(a::AbstractFillVector, b::AbstractFillVector) = _fill_dot(a, b)
dot(a::AbstractFillVector, b::AbstractVector) = _fill_dot(a, b)
dot(a::AbstractVector, b::AbstractFillVector) = _fill_dot(a, b)
dot(a::AbstractVector, b::AbstractFillVector) = _fill_dot_rev(a, b)

function dot(u::AbstractVector, E::Eye, v::AbstractVector)
length(u) == size(E,1) && length(v) == size(E,2) ||
Expand All @@ -172,11 +173,18 @@ function dot(u::AbstractVector{T}, D::Diagonal{U,<:Zeros}, v::AbstractVector{V})
end

# Addition and Subtraction
+(a::AbstractFill) = a
-(a::Zeros) = a
-(a::AbstractFill) = Fill(-getindex_value(a), size(a))


function +(a::Zeros{T}, b::Zeros{V}) where {T, V} # for disambiguity
promote_shape(a,b)
return elconvert(promote_op(+,T,V),a)
end
for TYPE in (:AbstractArray, :AbstractFill) # AbstractFill for disambiguity
# no AbstractArray. Otherwise incompatible with StaticArrays.jl
# AbstractFill for disambiguity
for TYPE in (:Array, :AbstractFill, :AbstractRange, :Diagonal)
@eval function +(a::$TYPE{T}, b::Zeros{V}) where {T, V}
promote_shape(a,b)
return elconvert(promote_op(+,T,V),a)
Expand All @@ -196,15 +204,23 @@ end

-(a::Ones, b::Ones) = Zeros(a) + Zeros(b)

# necessary for AbstractRange, Diagonal, etc
+(a::AbstractFill, b::AbstractFill) = fill_add(a, b)
+(a::AbstractFill, b::AbstractArray) = fill_add(b, a)
+(a::AbstractArray, b::AbstractFill) = fill_add(a, b)
# no AbstractArray. Otherwise incompatible with StaticArrays.jl
for TYPE in (:Array, :AbstractRange)
@eval begin
+(a::$TYPE, b::AbstractFill) = fill_add(a, b)
-(a::$TYPE, b::AbstractFill) = a + (-b)
+(a::AbstractFill, b::$TYPE) = fill_add(b, a)
-(a::AbstractFill, b::$TYPE) = a + (-b)
end
end
+(a::AbstractFill, b::AbstractFill) = Fill(getindex_value(a) + getindex_value(b), promote_shape(a,b))
-(a::AbstractFill, b::AbstractFill) = a + (-b)
-(a::AbstractFill, b::AbstractArray) = a + (-b)
-(a::AbstractArray, b::AbstractFill) = a + (-b)

@inline function fill_add(a, b::AbstractFill)
@inline function fill_add(a::AbstractArray, b::AbstractFill)
promote_shape(a, b)
a .+ [getindex_value(b)]
end
@inline function fill_add(a::AbstractArray{<:Number}, b::AbstractFill)
promote_shape(a, b)
a .+ getindex_value(b)
end
Expand Down
9 changes: 5 additions & 4 deletions src/fillbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,6 @@ _broadcasted_zeros(f, a, b) = Zeros{Base.Broadcast.combine_eltypes(f, (a, b))}(b
_broadcasted_ones(f, a, b) = Ones{Base.Broadcast.combine_eltypes(f, (a, b))}(broadcast_shape(axes(a), axes(b)))
_broadcasted_nan(f, a, b) = Fill(convert(Base.Broadcast.combine_eltypes(f, (a, b)), NaN), broadcast_shape(axes(a), axes(b)))

# TODO: remove at next breaking version
_broadcasted_zeros(a, b) = _broadcasted_zeros(+, a, b)
_broadcasted_ones(a, b) = _broadcasted_ones(+, a, b)

broadcasted(::DefaultArrayStyle, ::typeof(+), a::Zeros, b::Zeros) = _broadcasted_zeros(+, a, b)
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Ones, b::Zeros) = _broadcasted_ones(+, a, b)
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Zeros, b::Ones) = _broadcasted_ones(+, a, b)
Expand Down Expand Up @@ -247,3 +243,8 @@ broadcasted(::DefaultArrayStyle{N}, ::typeof(Base.literal_pow), ::Base.RefValue{
broadcasted(::DefaultArrayStyle{N}, ::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::Ones{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = Ones{T}(axes(r))
broadcasted(::DefaultArrayStyle{N}, ::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::Zeros{T,N}, ::Base.RefValue{Val{0}}) where {T,N} = Ones{T}(axes(r))
broadcasted(::DefaultArrayStyle{N}, ::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::Zeros{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = Zeros{T}(axes(r))

# supports structured broadcast
if isdefined(LinearAlgebra, :fzero)
LinearAlgebra.fzero(x::Zeros) = zero(eltype(x))
end
51 changes: 51 additions & 0 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
OneElement(val, ind, axesorsize) <: AbstractArray
Represents an array with the specified axes (if its a tuple of `AbstractUnitRange`s)
or size (if its a tuple of `Integer`s), with a single entry set to `val` and all others equal to zero,
specified by `ind``.
"""
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
val::T
ind::I
axes::A
OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
end

OneElement(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where N = OneElement(val, inds, oneto.(sz))
"""
OneElement(val, ind::Int, n::Int)
Creates a length `n` vector where the `ind` entry is equal to `val`, and all other entries are zero.
"""
OneElement(val, ind::Int, len::Int) = OneElement(val, (ind,), (len,))
"""
OneElement(ind::Int, n::Int)
Creates a length `n` vector where the `ind` entry is equal to `1`, and all other entries are zero.
"""
OneElement(inds::Int, sz::Int) = OneElement(1, inds, sz)
OneElement{T}(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where {T,N} = OneElement(convert(T,val), inds, oneto.(sz))
OneElement{T}(val, inds::Int, sz::Int) where T = OneElement{T}(val, (inds,), (sz,))

"""
OneElement{T}(val, ind::Int, n::Int)
Creates a length `n` vector where the `ind` entry is equal to `one(T)`, and all other entries are zero.
"""
OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz)

Base.size(A::OneElement) = map(length, A.axes)
Base.axes(A::OneElement) = A.axes
function Base.getindex(A::OneElement{T,N}, kj::Vararg{Int,N}) where {T,N}
@boundscheck checkbounds(A, kj...)
ifelse(kj == A.ind, A.val, zero(T))
end

Base.replace_in_print_matrix(o::OneElement{<:Any,2}, k::Integer, j::Integer, s::AbstractString) =
o.ind == (k,j) ? s : Base.replace_with_centered_mark(s)

function Base.setindex(A::Zeros{T,N}, v, kj::Vararg{Int,N}) where {T,N}
@boundscheck checkbounds(A, kj...)
OneElement(convert(T, v), kj, axes(A))
end
Loading

0 comments on commit 0abb814

Please sign in to comment.