From f05747d269014e98d586b7ade681f78d5fd5caf0 Mon Sep 17 00:00:00 2001 From: Tianyi Pu <912396513@qq.com> Date: Sun, 5 Mar 2023 23:01:58 +0000 Subject: [PATCH 1/2] update --- src/FillArrays.jl | 51 ++++------ src/fillalgebra.jl | 232 +++++++++++++++------------------------------ test/runtests.jl | 10 +- 3 files changed, 103 insertions(+), 190 deletions(-) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 2e223101..9bdbd35f 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -10,7 +10,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec, TransposeAbsVec, - issymmetric, ishermitian, AdjOrTransAbsVec + issymmetric, ishermitian, AdjOrTransAbsVec, checksquare import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape @@ -146,14 +146,7 @@ AbstractArray{T}(F::Fill{T}) where T = F AbstractArray{T,N}(F::Fill{T,N}) where {T,N} = F AbstractArray{T}(F::Fill{V,N}) where {T,V,N} = Fill{T}(convert(T, F.value)::T, F.axes) AbstractArray{T,N}(F::Fill{V,N}) where {T,V,N} = Fill{T}(convert(T, F.value)::T, F.axes) - -convert(::Type{AbstractArray{T}}, F::Fill{T}) where T = F -convert(::Type{AbstractArray{T,N}}, F::Fill{T,N}) where {T,N} = F -convert(::Type{AbstractArray{T}}, F::Fill) where {T} = AbstractArray{T}(F) -convert(::Type{AbstractArray{T,N}}, F::Fill) where {T,N} = AbstractArray{T,N}(F) -convert(::Type{AbstractFill}, F::AbstractFill) = F -convert(::Type{AbstractFill{T}}, F::AbstractFill) where T = convert(AbstractArray{T}, F) -convert(::Type{AbstractFill{T,N}}, F::AbstractFill) where {T,N} = convert(AbstractArray{T,N}, F) +AbstractFill{T}(F::AbstractFill) where T = AbstractArray{T}(F) copy(F::Fill) = Fill(F.value, F.axes) @@ -210,15 +203,11 @@ sort(a::AbstractFill; kwds...) = a sort!(a::AbstractFill; kwds...) = a svdvals!(a::AbstractFill{<:Any,2}) = [getindex_value(a)*sqrt(prod(size(a))); Zeros(min(size(a)...)-1)] -+(a::AbstractFill) = a --(a::AbstractFill) = Fill(-getindex_value(a), size(a)) - # Fill +/- Fill function +(a::AbstractFill{T, N}, b::AbstractFill{V, N}) where {T, V, N} axes(a) ≠ axes(b) && throw(DimensionMismatch("dimensions must match.")) return Fill(getindex_value(a) + getindex_value(b), axes(a)) end --(a::AbstractFill, b::AbstractFill) = a + (-b) function +(a::Fill{T, 1}, b::AbstractRange) where {T} size(a) ≠ size(b) && throw(DimensionMismatch("dimensions must match.")) @@ -299,10 +288,6 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one)) AbstractArray{T,N}(F::$Typ{T,N}) where {T,N} = F AbstractArray{T}(F::$Typ) where T = $Typ{T}(F.axes) AbstractArray{T,N}(F::$Typ{V,N}) where {T,V,N} = $Typ{T}(F.axes) - convert(::Type{AbstractArray{T}}, F::$Typ{T}) where T = AbstractArray{T}(F) - convert(::Type{AbstractArray{T,N}}, F::$Typ{T,N}) where {T,N} = AbstractArray{T,N}(F) - convert(::Type{AbstractArray{T}}, F::$Typ) where T = AbstractArray{T}(F) - convert(::Type{AbstractArray{T,N}}, F::$Typ) where {T,N} = AbstractArray{T,N}(F) copy(F::$Typ) = F @@ -310,6 +295,18 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one)) end end +for TYPE in (:Fill, :AbstractFill, :Ones, :Zeros) + @eval begin + @inline AbstractFill{T}(F::$TYPE{T}) where T = F + @inline AbstractFill{T,N}(F::$TYPE{T,N}) where {T,N} = F + @inline AbstractFill{T,N,Axes}(F::$TYPE{T,N,Axes}) where {T,N,Axes} = F + + const $(Symbol(TYPE,"Vector")){T} = $TYPE{T,1} + const $(Symbol(TYPE,"Matrix")){T} = $TYPE{T,2} + const $(Symbol(TYPE,"VecOrMat")){T} = Union{$TYPE{T,1},$TYPE{T,2}} + end +end + """ fillsimilar(a::AbstractFill, axes) @@ -459,16 +456,11 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one)) end end -function convert(::Type{Diagonal}, Z::Zeros{T,2}) where T - n,m = size(Z) - n ≠ m && throw(BoundsError(Z)) - Diagonal(zeros(T, n)) -end - -function convert(::Type{Diagonal{T}}, Z::Zeros{V,2}) where {T,V} - n,m = size(Z) - n ≠ m && throw(BoundsError(Z)) - Diagonal(zeros(T, n)) +# temporary patch. should be a PR(#48895) to LinearAlgebra +Diagonal{T}(A::AbstractMatrix) where T = Diagonal{T}(diag(A)) +function convert(::Type{T}, A::AbstractMatrix) where T<:Diagonal + checksquare(A) + isdiag(A) ? T(A) : throw(InexactError(:convert, T, A)) end ## Sparse arrays @@ -539,7 +531,7 @@ cumsum(x::Zeros{<:Any,1}) = x cumsum(x::Zeros{Bool,1}) = x cumsum(x::Ones{II,1}) where II<:Integer = convert(AbstractVector{II}, oneto(length(x))) cumsum(x::Ones{Bool,1}) = oneto(length(x)) -cumsum(x::AbstractFill{Bool,1}) = cumsum(convert(AbstractFill{Int}, x)) +cumsum(x::AbstractFill{Bool,1}) = cumsum(AbstractFill{Int}(x)) ######### @@ -560,8 +552,7 @@ allunique(x::AbstractFill) = length(x) < 2 ######### zero(r::Zeros{T,N}) where {T,N} = r -zero(r::Ones{T,N}) where {T,N} = Zeros{T,N}(r.axes) -zero(r::Fill{T,N}) where {T,N} = Zeros{T,N}(r.axes) +zero(r::AbstractFill{T,N}) where {T,N} = Zeros{T,N}(r) ######### # oneunit diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index adb393e5..a43c5831 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -7,16 +7,15 @@ vec(a::Fill{T}) where T = Fill{T}(a.value,length(a)) ## Transpose/Adjoint # cannot do this for vectors since that would destroy scalar dot product +for fun in (:transpose,:adjoint) + for TYPE in (:Ones,:Zeros) + @eval $fun(a::$TYPE{T,2}) where T = $TYPE{T}(reverse(a.axes)) + end + @eval $fun(a::FillMatrix) where T = Fill{T}($fun(a.value), reverse(a.axes)) +end -transpose(a::Ones{T,2}) where T = Ones{T}(reverse(a.axes)) -adjoint(a::Ones{T,2}) where T = Ones{T}(reverse(a.axes)) -transpose(a::Zeros{T,2}) where T = Zeros{T}(reverse(a.axes)) -adjoint(a::Zeros{T,2}) where T = Zeros{T}(reverse(a.axes)) -transpose(a::Fill{T,2}) where T = Fill{T}(transpose(a.value), reverse(a.axes)) -adjoint(a::Fill{T,2}) where T = Fill{T}(adjoint(a.value), reverse(a.axes)) - -permutedims(a::AbstractFill{<:Any,1}) = fillsimilar(a, (1, length(a))) -permutedims(a::AbstractFill{<:Any,2}) = fillsimilar(a, reverse(a.axes)) +permutedims(a::AbstractFillVector) = fillsimilar(a, (1, length(a))) +permutedims(a::AbstractFillMatrix) = fillsimilar(a, reverse(a.axes)) function permutedims(B::AbstractFill, perm) dimsB = size(B) @@ -34,74 +33,51 @@ end reverse(A::AbstractFill; dims=:) = A ## Algebraic identities +@inline checkdimensionmismatch(a::AbstractVecOrMat, b::AbstractVecOrMat) = axes(a, 2) ≠ axes(b, 1) && throw(DimensionMismatch("A has axes $(axes(a)) but B has axes $(axes(b))")) +@inline productaxes(a::AbstractVecOrMat, b::AbstractVector) = (axes(a, 1),) +@inline productaxes(a::AbstractVecOrMat, b::AbstractMatrix) = (axes(a, 1), axes(b, 2)) - -function mult_fill(a::AbstractFill, b::AbstractFill{<:Any,2}) - axes(a, 2) ≠ axes(b, 1) && - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) - return Fill(getindex_value(a)*getindex_value(b)*size(a,2), (axes(a, 1), axes(b, 2))) +function mult_fill(a::AbstractFill, b::AbstractFillVecOrMat) + checkdimensionmismatch(a,b) + return Fill(getindex_value(a)*getindex_value(b)*size(a,2), productaxes(a,b)) end -function mult_fill(a::AbstractFill, b::AbstractFill{<:Any,1}) - axes(a, 2) ≠ axes(b, 1) && - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) - return Fill(getindex_value(a)*getindex_value(b)*size(a,2), (axes(a, 1),)) +function mult_ones(a::AbstractVector, b::AbstractMatrix) + checkdimensionmismatch(a,b) + return Ones{promote_type(eltype(a), eltype(b))}(productaxes(a,b)) end -function mult_ones(a::AbstractVector, b::AbstractMatrix) - axes(a, 2) ≠ axes(b, 1) && - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) - return Ones{promote_type(eltype(a), eltype(b))}((axes(a, 1), axes(b, 2))) +function mult_zeros(a, b::AbstractVecOrMat) + checkdimensionmismatch(a,b) + return Zeros{promote_type(eltype(a), eltype(b))}(productaxes(a,b)) end -function mult_zeros(a, b::AbstractMatrix) - axes(a, 2) ≠ axes(b, 1) && - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) - return Zeros{promote_type(eltype(a), eltype(b))}((axes(a, 1), axes(b, 2))) +*(a::ZerosVector, b::AdjOrTransAbsVec) = mult_zeros(a, b) + +# Matrix * VecOrMat. +# For Vector*Matrix, LinearAlgebra reshapes the vector to matrix automatically. +# See *(a::AbstractVector, B::AbstractMatrix) at matmul.jl +*(a::AbstractFillMatrix, b::AbstractFillVecOrMat) = mult_fill(a,b) +*(a::ZerosMatrix, b::ZerosVector) = mult_zeros(a, b) +*(a::ZerosMatrix, b::ZerosMatrix) = mult_zeros(a, b) +*(a::OnesVector, b::OnesMatrix) = mult_ones(a, b) +for TYPE in (AbstractFillMatrix, AbstractMatrix, Diagonal) + @eval begin + *(a::ZerosMatrix, b::$TYPE) = mult_zeros(a,b) + *(a::$TYPE, b::ZerosVector) = mult_zeros(a,b) + *(a::$TYPE, b::ZerosMatrix) = mult_zeros(a,b) + end end -function mult_zeros(a, b::AbstractVector) - axes(a, 2) ≠ axes(b, 1) && - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) - return Zeros{promote_type(eltype(a), eltype(b))}((axes(a, 1),)) +for TYPE in (:AbstractFillVector, :AbstractVector) + @eval *(a::ZerosMatrix, b::$TYPE) = mult_zeros(a,b) end -*(a::AbstractFill{<:Any,1}, b::AbstractFill{<:Any,2}) = mult_fill(a,b) -*(a::AbstractFill{<:Any,2}, b::AbstractFill{<:Any,2}) = mult_fill(a,b) -*(a::AbstractFill{<:Any,2}, b::AbstractFill{<:Any,1}) = mult_fill(a,b) - -*(a::Ones{<:Any,1}, b::Ones{<:Any,2}) = mult_ones(a, b) - -*(a::Zeros{<:Any,1}, b::Zeros{<:Any,2}) = mult_zeros(a, b) -*(a::Zeros{<:Any,2}, b::Zeros{<:Any,2}) = mult_zeros(a, b) -*(a::Zeros{<:Any,2}, b::Zeros{<:Any,1}) = mult_zeros(a, b) - -*(a::Zeros{<:Any,1}, b::AbstractFill{<:Any,2}) = mult_zeros(a, b) -*(a::Zeros{<:Any,2}, b::AbstractFill{<:Any,2}) = mult_zeros(a, b) -*(a::Zeros{<:Any,2}, b::AbstractFill{<:Any,1}) = mult_zeros(a, b) -*(a::AbstractFill{<:Any,1}, b::Zeros{<:Any,2}) = mult_zeros(a,b) -*(a::AbstractFill{<:Any,2}, b::Zeros{<:Any,2}) = mult_zeros(a,b) -*(a::AbstractFill{<:Any,2}, b::Zeros{<:Any,1}) = mult_zeros(a,b) - -*(a::Zeros{<:Any,1}, b::AbstractMatrix) = mult_zeros(a, b) -*(a::Zeros{<:Any,2}, b::AbstractMatrix) = mult_zeros(a, b) -*(a::AbstractMatrix, b::Zeros{<:Any,1}) = mult_zeros(a, b) -*(a::AbstractMatrix, b::Zeros{<:Any,2}) = mult_zeros(a, b) -*(a::Zeros{<:Any,1}, b::AbstractVector) = mult_zeros(a, b) -*(a::Zeros{<:Any,2}, b::AbstractVector) = mult_zeros(a, b) -*(a::AbstractVector, b::Zeros{<:Any,2}) = mult_zeros(a, b) - -*(a::Zeros{<:Any,1}, b::AdjOrTransAbsVec) = mult_zeros(a, b) - -*(a::Zeros{<:Any,1}, b::Diagonal) = mult_zeros(a, b) -*(a::Zeros{<:Any,2}, b::Diagonal) = mult_zeros(a, b) -*(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b) -*(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b) function *(a::Diagonal, b::AbstractFill{<:Any,2}) - size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) + checkdimensionmismatch(a,b) a.diag .* b # use special broadcast end function *(a::AbstractFill{<:Any,2}, b::Diagonal) - size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) + checkdimensionmismatch(a,b) a .* permutedims(b.diag) # use special broadcast end @@ -134,64 +110,48 @@ function _adjvec_mul_zeros(a, b) return zero(Base.promote_op(*, eltype(a), eltype(b))) end -*(a::AdjointAbsVec{<:Any,<:Zeros{<:Any,1}}, b::AbstractMatrix) = (b' * a')' -*(a::AdjointAbsVec{<:Any,<:Zeros{<:Any,1}}, b::Zeros{<:Any,2}) = (b' * a')' -*(a::TransposeAbsVec{<:Any,<:Zeros{<:Any,1}}, b::AbstractMatrix) = transpose(transpose(b) * transpose(a)) -*(a::TransposeAbsVec{<:Any,<:Zeros{<:Any,1}}, b::Zeros{<:Any,2}) = transpose(transpose(b) * transpose(a)) +# AdjOrTrans{ZerosVector} * Matrix +*(a::AdjointAbsVec{<:Any,<:ZerosVector}, b::AbstractMatrix) = (b' * a')' +*(a::AdjointAbsVec{<:Any,<:ZerosVector}, b::ZerosMatrix) = (b' * a')' +*(a::TransposeAbsVec{<:Any,<:ZerosVector}, b::AbstractMatrix) = transpose(transpose(b) * transpose(a)) +*(a::TransposeAbsVec{<:Any,<:ZerosVector}, b::ZerosMatrix) = transpose(transpose(b) * transpose(a)) -*(a::AbstractVector, b::AdjOrTransAbsVec{<:Any,<:Zeros{<:Any,1}}) = a * permutedims(parent(b)) -*(a::AbstractMatrix, b::AdjOrTransAbsVec{<:Any,<:Zeros{<:Any,1}}) = a * permutedims(parent(b)) -*(a::Zeros{<:Any,1}, b::AdjOrTransAbsVec{<:Any,<:Zeros{<:Any,1}}) = a * permutedims(parent(b)) -*(a::Zeros{<:Any,2}, b::AdjOrTransAbsVec{<:Any,<:Zeros{<:Any,1}}) = a * permutedims(parent(b)) +# VecOrMat * AdjOrTrans{ZerosVector} +for TYPE in (:AbstractVector, :AbstractMatrix, :ZerosVector, :ZerosMatrix) + @eval *(a::$TYPE, b::AdjOrTransAbsVec{<:Any,<:ZerosVector}) = a * permutedims(parent(b)) +end -*(a::AdjointAbsVec, b::Zeros{<:Any, 1}) = _adjvec_mul_zeros(a, b) -*(a::AdjointAbsVec{<:Number}, b::Zeros{<:Number, 1}) = _adjvec_mul_zeros(a, b) -*(a::TransposeAbsVec, b::Zeros{<:Any, 1}) = _adjvec_mul_zeros(a, b) -*(a::TransposeAbsVec{<:Number}, b::Zeros{<:Number, 1}) = _adjvec_mul_zeros(a, b) +# AdjOrTrans{Vector} * ZerosVector +for T1 in (:AdjointAbsVec, :TransposeAbsVec), T2 in (:Any, :Number) + @eval *(a::$T1{<:$T2}, b::ZerosVector{<:$T2}) = _adjvec_mul_zeros(a, b) +end -*(a::Adjoint{T, <:AbstractMatrix{T}} where T, b::Zeros{<:Any, 1}) = mult_zeros(a, b) +*(a::Adjoint{T, <:AbstractMatrix{T}} where T, b::ZerosVector) = mult_zeros(a, b) -function *(a::Transpose{T, <:AbstractVector{T}}, b::Zeros{T, 1}) where T<:Real +function *(a::Transpose{T, <:AbstractVector{T}}, b::ZerosVector{T}) where T<:Real la, lb = length(a), length(b) if la ≠ lb throw(DimensionMismatch("dot product arguments have lengths $la and $lb")) end return zero(T) end -*(a::Transpose{T, <:AbstractMatrix{T}}, b::Zeros{T, 1}) where T<:Real = mult_zeros(a, b) +*(a::Transpose{T, <:AbstractMatrix{T}}, b::ZerosVector) where T<:Real = mult_zeros(a, b) # treat zero separately to support ∞-vectors -function _zero_dot(a, b) - axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))")) - zero(promote_type(eltype(a),eltype(b))) -end - -_fill_dot(a::Zeros, b::Zeros) = _zero_dot(a, b) -_fill_dot(a::Zeros, b) = _zero_dot(a, b) -_fill_dot(a, b::Zeros) = _zero_dot(a, b) -_fill_dot(a::Zeros, b::AbstractFill) = _zero_dot(a, b) -_fill_dot(a::AbstractFill, b::Zeros) = _zero_dot(a, b) - -function _fill_dot(a::AbstractFill, b::AbstractFill) - axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))")) - getindex_value(a)getindex_value(b)*length(b) -end - -# support types with fast sum -function _fill_dot(a::AbstractFill, b) +function _fill_dot(a::AbstractVector, b::AbstractVector) axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))")) - getindex_value(a)sum(b) -end - -function _fill_dot(a, b::AbstractFill) - axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))")) - sum(a)getindex_value(b) + if iszero(a) || iszero(b) + zero(promote_type(eltype(a),eltype(b))) + elseif isa(a,AbstractFill) + getindex_value(a)sum(b) + else + getindex_value(b)sum(a) + end end - -dot(a::AbstractFill{<:Any,1}, b::AbstractFill{<:Any,1}) = _fill_dot(a, b) -dot(a::AbstractFill{<:Any,1}, b::AbstractVector) = _fill_dot(a, b) -dot(a::AbstractVector, b::AbstractFill{<:Any,1}) = _fill_dot(a, b) +dot(a::AbstractFillVector, b::AbstractFillVector) = _fill_dot(a, b) +dot(a::AbstractFillVector, b::AbstractVector) = _fill_dot(a, b) +dot(a::AbstractVector, b::AbstractFillVector) = _fill_dot(a, b) function dot(u::AbstractVector, E::Eye, v::AbstractVector) length(u) == size(E,1) && length(v) == size(E,2) || @@ -211,63 +171,25 @@ function dot(u::AbstractVector{T}, D::Diagonal{U,<:Zeros}, v::AbstractVector{V}) zero(promote_type(T,U,V)) end -+(a::Zeros) = a --(a::Zeros) = a - # Zeros +/- Zeros function +(a::Zeros{T}, b::Zeros{V}) where {T, V} size(a) ≠ size(b) && throw(DimensionMismatch("dimensions must match.")) return Zeros{promote_type(T,V)}(size(a)...) end --(a::Zeros, b::Zeros) = -(a + b) --(a::Ones, b::Ones) = Zeros(a)+Zeros(b) - -# Zeros +/- Fill and Fill +/- Zeros -function +(a::AbstractFill{T}, b::Zeros{V}) where {T, V} - size(a) ≠ size(b) && throw(DimensionMismatch("dimensions must match.")) - return convert(AbstractFill{promote_type(T, V)}, a) -end -+(a::Zeros, b::AbstractFill) = b + a --(a::AbstractFill, b::Zeros) = a + b --(a::Zeros, b::AbstractFill) = a + (-b) - -# Zeros +/- Array and Array +/- Zeros -function +(a::Zeros{T, N}, b::AbstractArray{V, N}) where {T, V, N} - size(a) ≠ size(b) && throw(DimensionMismatch("dimensions must match.")) - return AbstractArray{promote_type(T,V),N}(b) -end -function +(a::Array{T, N}, b::Zeros{V, N}) where {T, V, N} - size(a) ≠ size(b) && throw(DimensionMismatch("dimensions must match.")) - return AbstractArray{promote_type(T,V),N}(a) -end - -function -(a::Zeros{T, N}, b::AbstractArray{V, N}) where {T, V, N} - size(a) ≠ size(b) && throw(DimensionMismatch("dimensions must match.")) - return -b + a -end --(a::Array{T, N}, b::Zeros{V, N}) where {T, V, N} = a + b - - -+(a::AbstractRange, b::Zeros) = b + a -function +(a::Zeros{T, 1}, b::AbstractRange) where {T} - size(a) ≠ size(b) && throw(DimensionMismatch("dimensions must match.")) - Tout = promote_type(T, eltype(b)) - return convert(Tout, first(b)):convert(Tout, step(b)):convert(Tout, last(b)) -end -function +(a::Zeros{T, 1}, b::UnitRange) where {T} - size(a) ≠ size(b) && throw(DimensionMismatch("dimensions must match.")) - Tout = promote_type(T, eltype(b)) - return convert(Tout, first(b)):convert(Tout, last(b)) -end - -function -(a::Zeros{T, 1}, b::AbstractRange{V}) where {T, V} - size(a) ≠ size(b) && throw(DimensionMismatch("dimensions must match.")) - return -b + a +for (TYPE) in (:AbstractArray, :AbstractFill, :AbstractRange) + @eval begin + function +(a::$TYPE, b::Zeros) + size(a) ≠ size(b) && throw(DimensionMismatch("dimensions must match.")) + return $TYPE{typeof(zero(eltype(a))+zero(eltype(b)))}(a) + end + +(a::Zeros, b::$TYPE) = b + a + end end --(a::AbstractRange{T}, b::Zeros{V, 1}) where {T, V} = a + b - +# temporary patch. should be a PR(#48894) to julia base. +AbstractRange{T}(r::AbstractUnitRange) where {T<:Integer} = AbstractUnitRange{T}(r) +AbstractRange{T}(r::AbstractRange) where T = T(first(r)):T(step(r)):T(last(r)) #### # norm diff --git a/test/runtests.jl b/test/runtests.jl index 46272a80..7f58c9de 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -146,9 +146,9 @@ include("infinitearrays.jl") y = x + x @test y isa Fill{Int,1} @test y[1] == 2 - @test x + Zeros{Bool}(5) ≡ x - @test x - Zeros{Bool}(5) ≡ x - @test Zeros{Bool}(5) + x ≡ x + @test x + Zeros{Bool}(5) ≡ Ones{Int}(5) + @test x - Zeros{Bool}(5) ≡ Ones{Int}(5) + @test Zeros{Bool}(5) + x ≡ Ones{Int}(5) @test -x ≡ Fill(-1,5) end @@ -392,10 +392,10 @@ end @test Diagonal(Zeros(8,5)) == Diagonal(zeros(5)) @test convert(Diagonal, Zeros(5,5)) == Diagonal(zeros(5)) - @test_throws BoundsError convert(Diagonal, Zeros(8,5)) + @test_throws DimensionMismatch convert(Diagonal, Zeros(8,5)) @test convert(Diagonal{Int}, Zeros(5,5)) == Diagonal(zeros(Int,5)) - @test_throws BoundsError convert(Diagonal{Int}, Zeros(8,5)) + @test_throws DimensionMismatch convert(Diagonal{Int}, Zeros(8,5)) @test Diagonal(Eye(8,5)) == Diagonal(ones(5)) From 0035a4f55d79232b1a6a3e952e86b491dd4f90e5 Mon Sep 17 00:00:00 2001 From: Tianyi Pu <912396513@qq.com> Date: Thu, 9 Mar 2023 20:46:57 +0000 Subject: [PATCH 2/2] fix --- src/fillalgebra.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 3df09235..e0eb0f65 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -11,7 +11,7 @@ for fun in (:transpose,:adjoint) for TYPE in (:Ones,:Zeros) @eval $fun(a::$TYPE{T,2}) where T = $TYPE{T}(reverse(a.axes)) end - @eval $fun(a::FillMatrix) where T = Fill{T}($fun(a.value), reverse(a.axes)) + @eval $fun(a::FillMatrix{T}) where T = Fill{T}($fun(a.value), reverse(a.axes)) end permutedims(a::AbstractFillVector) = fillsimilar(a, (1, length(a)))