From bdf39a5f429f87528dcbf76293945e34468c59d9 Mon Sep 17 00:00:00 2001 From: Stephan Hilb Date: Sun, 23 Sep 2018 17:14:12 +0200 Subject: [PATCH] stdlib/SparseArrays: fix scalar setindex! for vector eltype --- stdlib/SparseArrays/src/sparsematrix.jl | 9 +++++++-- stdlib/SparseArrays/test/sparse.jl | 6 ++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/stdlib/SparseArrays/src/sparsematrix.jl b/stdlib/SparseArrays/src/sparsematrix.jl index 5b107ff86d2584..45e3e92012242c 100644 --- a/stdlib/SparseArrays/src/sparsematrix.jl +++ b/stdlib/SparseArrays/src/sparsematrix.jl @@ -2350,7 +2350,11 @@ getindex(A::SparseMatrixCSC, I::AbstractVector{<:Integer}, J::AbstractVector{Boo getindex(A::SparseMatrixCSC, I::AbstractVector{Bool}, J::AbstractVector{<:Integer}) = A[findall(I),J] ## setindex! -function setindex!(A::SparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integer) where {Tv,Ti<:Integer} + +# dispatch helper for #29034 +setindex!(A::SparseMatrixCSC, _v, _i::Integer, _j::Integer) = _setindex_scalar!(A, _v, _i, _j) + +function _setindex_scalar!(A::SparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integer) where {Tv,Ti<:Integer} v = convert(Tv, _v) i = convert(Ti, _i) j = convert(Ti, _j) @@ -2545,7 +2549,8 @@ end _to_same_csc(::SparseMatrixCSC{Tv, Ti}, V::AbstractMatrix, I...) where {Tv,Ti} = convert(SparseMatrixCSC{Tv,Ti}, V) _to_same_csc(::SparseMatrixCSC{Tv, Ti}, V::AbstractVector, I...) where {Tv,Ti} = convert(SparseMatrixCSC{Tv,Ti}, reshape(V, map(length, I))) -setindex!(A::SparseMatrixCSC{Tv}, B::AbstractVecOrMat, I::Integer, J::Integer) where {Tv} = setindex!(A, convert(Tv, B), I, J) +setindex!(A::SparseMatrixCSC{Tv}, B::AbstractVecOrMat, I::Integer, J::Integer) where {Tv} = _setindex_scalar!(A, B, I, J) + function setindex!(A::SparseMatrixCSC{Tv,Ti}, V::AbstractVecOrMat, Ix::Union{Integer, AbstractVector{<:Integer}, Colon}, Jx::Union{Integer, AbstractVector{<:Integer}, Colon}) where {Tv,Ti<:Integer} @assert !has_offset_axes(A, V, Ix, Jx) (I, J) = Base.ensure_indexable(to_indices(A, (Ix, Jx))) diff --git a/stdlib/SparseArrays/test/sparse.jl b/stdlib/SparseArrays/test/sparse.jl index 416cbffac27cf7..eeaf47d063c960 100644 --- a/stdlib/SparseArrays/test/sparse.jl +++ b/stdlib/SparseArrays/test/sparse.jl @@ -1993,6 +1993,12 @@ end @test nnz(A) == 1 end +@testset "setindex with vector eltype (#29034)" begin + A = sparse([1], [1], [Vector{Float64}(undef, 3)], 3, 3) + A[1,1] = [1.0, 2.0, 3.0] + @test A[1,1] == [1.0, 2.0, 3.0] +end + @testset "show" begin io = IOBuffer() show(io, MIME"text/plain"(), sparse(Int64[1], Int64[1], [1.0]))