Skip to content

Commit

Permalink
Remove some converters (2) (#218)
Browse files Browse the repository at this point in the history
* Update FillArrays.jl

* Update fillalgebra.jl

* Update runtests.jl

* Update FillArrays.jl

* Update fillalgebra.jl

* Update fillalgebra.jl

* 98% coverage

* 648/660 coverage

* 648/657 coverage

* remove type piracy

* fix
  • Loading branch information
putianyi889 authored Mar 15, 2023
1 parent 193a641 commit 3c12489
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 49 deletions.
63 changes: 22 additions & 41 deletions src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,

import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec, TransposeAbsVec,
issymmetric, ishermitian, AdjOrTransAbsVec
issymmetric, ishermitian, AdjOrTransAbsVec, checksquare

import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape

Expand Down Expand Up @@ -148,18 +148,9 @@ Fill{T,0}(x::T, ::Tuple{}) where T = Fill{T,0,Tuple{}}(x, ()) # ambiguity fix

@inline getindex_value(F::Fill) = F.value

AbstractArray{T}(F::Fill{T}) where T = F
AbstractArray{T,N}(F::Fill{T,N}) where {T,N} = F
AbstractArray{T}(F::Fill{V,N}) where {T,V,N} = Fill{T}(convert(T, F.value)::T, F.axes)
AbstractArray{T,N}(F::Fill{V,N}) where {T,V,N} = Fill{T}(convert(T, F.value)::T, F.axes)

convert(::Type{AbstractArray{T}}, F::Fill{T}) where T = F
convert(::Type{AbstractArray{T,N}}, F::Fill{T,N}) where {T,N} = F
convert(::Type{AbstractArray{T}}, F::Fill) where {T} = AbstractArray{T}(F)
convert(::Type{AbstractArray{T,N}}, F::Fill) where {T,N} = AbstractArray{T,N}(F)
convert(::Type{AbstractFill}, F::AbstractFill) = F
convert(::Type{AbstractFill{T}}, F::AbstractFill) where T = convert(AbstractArray{T}, F)
convert(::Type{AbstractFill{T,N}}, F::AbstractFill) where {T,N} = convert(AbstractArray{T,N}, F)
AbstractFill{T}(F::AbstractFill) where T = AbstractArray{T}(F)

copy(F::Fill) = Fill(F.value, F.axes)

Expand Down Expand Up @@ -304,21 +295,23 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
@inline size(Z::$Typ) = length.(Z.axes)
@inline getindex_value(Z::$Typ{T}) where T = $func(T)

AbstractArray{T}(F::$Typ{T}) where T = F
AbstractArray{T,N}(F::$Typ{T,N}) where {T,N} = F
AbstractArray{T}(F::$Typ) where T = $Typ{T}(F.axes)
AbstractArray{T,N}(F::$Typ{V,N}) where {T,V,N} = $Typ{T}(F.axes)
convert(::Type{AbstractArray{T}}, F::$Typ{T}) where T = AbstractArray{T}(F)
convert(::Type{AbstractArray{T,N}}, F::$Typ{T,N}) where {T,N} = AbstractArray{T,N}(F)
convert(::Type{AbstractArray{T}}, F::$Typ) where T = AbstractArray{T}(F)
convert(::Type{AbstractArray{T,N}}, F::$Typ) where {T,N} = AbstractArray{T,N}(F)

copy(F::$Typ) = F

getindex(F::$Typ{T,0}) where T = getindex_value(F)
end
end

# conversions
for TYPE in (:Fill, :AbstractFill, :Ones, :Zeros), STYPE in (:AbstractArray, :AbstractFill)
@eval begin
@inline $STYPE{T}(F::$TYPE{T}) where T = F
@inline $STYPE{T,N}(F::$TYPE{T,N}) where {T,N} = F
end
end

"""
fillsimilar(a::AbstractFill, axes)
Expand Down Expand Up @@ -467,32 +460,22 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
end
end

function convert(::Type{Diagonal}, Z::ZerosMatrix{T}) where T
n,m = size(Z)
n m && throw(BoundsError(Z))
Diagonal(zeros(T, n))
end

function convert(::Type{Diagonal{T}}, Z::ZerosMatrix) where T
n,m = size(Z)
n m && throw(BoundsError(Z))
Diagonal(zeros(T, n))
# temporary patch. should be a PR(#48895) to LinearAlgebra
Diagonal{T}(A::AbstractFillMatrix) where T = Diagonal{T}(diag(A))
function convert(::Type{T}, A::AbstractFillMatrix) where T<:Diagonal
checksquare(A)
isdiag(A) ? T(A) : throw(InexactError(:convert, T, A))
end

## Sparse arrays

convert(::Type{SparseVector}, Z::ZerosVector{T}) where T = spzeros(T, length(Z))
convert(::Type{SparseVector{T}}, Z::ZerosVector) where T = spzeros(T, length(Z))
convert(::Type{SparseVector{Tv,Ti}}, Z::ZerosVector) where {Tv,Ti} = spzeros(Tv, Ti, length(Z))
SparseVector{T}(Z::ZerosVector) where T = spzeros(T, length(Z))
SparseVector{Tv,Ti}(Z::ZerosVector) where {Tv,Ti} = spzeros(Tv, Ti, length(Z))

convert(::Type{AbstractSparseVector}, Z::ZerosVector{T}) where T = spzeros(T, length(Z))
convert(::Type{AbstractSparseVector{T}}, Z::ZerosVector) where T= spzeros(T, length(Z))

convert(::Type{SparseMatrixCSC}, Z::ZerosMatrix{T}) where T = spzeros(T, size(Z)...)
convert(::Type{SparseMatrixCSC{T}}, Z::ZerosMatrix) where T = spzeros(T, size(Z)...)
convert(::Type{SparseMatrixCSC{Tv,Ti}}, Z::ZerosMatrix) where {Tv,Ti} = spzeros(Tv, Ti, size(Z)...)
convert(::Type{SparseMatrixCSC{Tv,Ti}}, Z::Zeros{T,2,Axes}) where {Tv,Ti<:Integer,T,Axes} =
spzeros(Tv, Ti, size(Z)...)
SparseMatrixCSC{T}(Z::ZerosMatrix) where T = spzeros(T, size(Z)...)
SparseMatrixCSC{Tv,Ti}(Z::Zeros{T,2,Axes}) where {Tv,Ti<:Integer,T,Axes} = spzeros(Tv, Ti, size(Z)...)

convert(::Type{AbstractSparseMatrix}, Z::ZerosMatrix{T}) where T = spzeros(T, size(Z)...)
convert(::Type{AbstractSparseMatrix{T}}, Z::ZerosMatrix) where T = spzeros(T, size(Z)...)
Expand All @@ -502,11 +485,9 @@ convert(::Type{AbstractSparseArray{Tv}}, Z::Zeros{T}) where {T,Tv} = spzeros(Tv,
convert(::Type{AbstractSparseArray{Tv,Ti}}, Z::Zeros{T}) where {T,Tv,Ti} = spzeros(Tv, Ti, size(Z)...)
convert(::Type{AbstractSparseArray{Tv,Ti,N}}, Z::Zeros{T,N}) where {T,Tv,Ti,N} = spzeros(Tv, Ti, size(Z)...)


convert(::Type{SparseMatrixCSC}, Z::Eye{T}) where T = SparseMatrixCSC{T}(I, size(Z)...)
convert(::Type{SparseMatrixCSC{Tv}}, Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...)
SparseMatrixCSC{Tv}(Z::Eye{T}) where {T,Tv} = SparseMatrixCSC{Tv}(I, size(Z)...)
# works around missing `speye`:
convert(::Type{SparseMatrixCSC{Tv,Ti}}, Z::Eye{T}) where {T,Tv,Ti<:Integer} =
SparseMatrixCSC{Tv,Ti}(Z::Eye{T}) where {T,Tv,Ti<:Integer} =
convert(SparseMatrixCSC{Tv,Ti}, SparseMatrixCSC{Tv}(I, size(Z)...))

convert(::Type{AbstractSparseMatrix}, Z::Eye{T}) where {T} = SparseMatrixCSC{T}(I, size(Z)...)
Expand Down Expand Up @@ -547,7 +528,7 @@ cumsum(x::ZerosVector) = x
cumsum(x::ZerosVector{Bool}) = x
cumsum(x::OnesVector{II}) where II<:Integer = convert(AbstractVector{II}, oneto(length(x)))
cumsum(x::OnesVector{Bool}) = oneto(length(x))
cumsum(x::AbstractFillVector{Bool}) = cumsum(convert(AbstractFill{Int}, x))
cumsum(x::AbstractFillVector{Bool}) = cumsum(AbstractFill{Int}(x))


#########
Expand Down
10 changes: 4 additions & 6 deletions src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ end
# Zeros +/- Fill and Fill +/- Zeros
function +(a::AbstractFill{T}, b::Zeros{V}) where {T, V}
size(a) size(b) && throw(DimensionMismatch("dimensions must match."))
return convert(AbstractFill{promote_type(T, V)}, a)
return AbstractFill{promote_type(T, V)}(a)
end
+(a::Zeros, b::AbstractFill) = b + a
-(a::AbstractFill, b::Zeros) = a + b
Expand Down Expand Up @@ -253,12 +253,12 @@ end
function +(a::ZerosVector{T}, b::AbstractRange) where {T}
size(a) size(b) && throw(DimensionMismatch("dimensions must match."))
Tout = promote_type(T, eltype(b))
return convert(Tout, first(b)):convert(Tout, step(b)):convert(Tout, last(b))
return Tout(first(b)):Tout(step(b)):Tout(last(b))
end
function +(a::ZerosVector{T}, b::UnitRange) where {T}
function +(a::ZerosVector{T}, b::UnitRange) where {T<:Integer}
size(a) size(b) && throw(DimensionMismatch("dimensions must match."))
Tout = promote_type(T, eltype(b))
return convert(Tout, first(b)):convert(Tout, last(b))
return AbstractUnitRange{Tout}(b)
end

function -(a::ZerosVector, b::AbstractRange)
Expand All @@ -267,8 +267,6 @@ function -(a::ZerosVector, b::AbstractRange)
end
-(a::AbstractRange, b::ZerosVector) = a + b



####
# norm
####
Expand Down
6 changes: 4 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,10 @@ end

@test Diagonal(Zeros(8,5)) == Diagonal(zeros(5))
@test convert(Diagonal, Zeros(5,5)) == Diagonal(zeros(5))
@test_throws BoundsError convert(Diagonal, Zeros(8,5))
@test_throws DimensionMismatch convert(Diagonal, Zeros(8,5))

@test convert(Diagonal{Int}, Zeros(5,5)) == Diagonal(zeros(Int,5))
@test_throws BoundsError convert(Diagonal{Int}, Zeros(8,5))
@test_throws DimensionMismatch convert(Diagonal{Int}, Zeros(8,5))


@test Diagonal(Eye(8,5)) == Diagonal(ones(5))
Expand All @@ -414,6 +414,7 @@ end

@testset "Sparse vectors and matrices" begin
@test SparseVector(Zeros(5)) ==
SparseVector{Int}(Zeros(5)) ==
SparseVector{Float64}(Zeros(5)) ==
SparseVector{Float64,Int}(Zeros(5)) ==
convert(AbstractSparseArray,Zeros(5)) ==
Expand All @@ -426,6 +427,7 @@ end
for (Mat, SMat) in ((Zeros(5,5), spzeros(5,5)), (Zeros(6,5), spzeros(6,5)),
(Eye(5), sparse(I,5,5)), (Eye(6,5), sparse(I,6,5)))
@test SparseMatrixCSC(Mat) ==
SparseMatrixCSC{Int}(Mat) ==
SparseMatrixCSC{Float64}(Mat) ==
SparseMatrixCSC{Float64,Int}(Mat) ==
convert(AbstractSparseArray,Mat) ==
Expand Down

0 comments on commit 3c12489

Please sign in to comment.