diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 9b9f40f1..36ab4185 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -6,7 +6,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, +, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!, any, all, axes, isone, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero, - show, view, in, mapreduce + show, view, in, mapreduce, one import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec, TransposeAbsVec, @@ -336,7 +336,7 @@ axes(rd::Diagonal{<:Any,<:AbstractFill}) = (axes(rd.diag,1),axes(rd.diag,1)) axes(T::AbstractTriangular{<:Any,<:AbstractFill}) = axes(parent(T)) axes(rd::RectDiagonal) = rd.axes -size(rd::RectDiagonal) = length.(rd.axes) +size(rd::RectDiagonal) = map(length, rd.axes) @inline function getindex(rd::RectDiagonal{T}, i::Integer, j::Integer) where T @boundscheck checkbounds(rd, i, j) @@ -551,6 +551,17 @@ zero(r::Zeros{T,N}) where {T,N} = r zero(r::Ones{T,N}) where {T,N} = Zeros{T,N}(r.axes) zero(r::Fill{T,N}) where {T,N} = Zeros{T,N}(r.axes) +######### +# oneunit +######### + +function one(A::AbstractFill{T,2}) where {T} + Base.require_one_based_indexing(A) + m, n = size(A) + m == n || throw(ArgumentError("multiplicative identity defined only for square matrices")) + SquareEye{T}(m) +end + ######### # any/all/isone/iszero ######### diff --git a/test/runtests.jl b/test/runtests.jl index de8a0a7b..d66c90ff 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -398,6 +398,15 @@ end @test convert(Diagonal{Int}, Eye(5)) == Diagonal(ones(Int,5)) end +@testset "one" begin + @testset for A in Any[Eye(4), Zeros(4,4), Ones(4,4), Fill(3,4,4)] + B = one(A) + @test B * A == A * B == A + end + @test_throws ArgumentError one(Ones(3,4)) + @test_throws ArgumentError one(Ones((3:5,4:5))) +end + @testset "Sparse vectors and matrices" begin @test SparseVector(Zeros(5)) == SparseVector{Float64}(Zeros(5)) ==