Skip to content

Commit

Permalink
stdlib/SparseArrays: fix scalar setindex! for vector eltype
Browse files Browse the repository at this point in the history
  • Loading branch information
stev47 committed Sep 23, 2018
1 parent a854139 commit bdf39a5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
9 changes: 7 additions & 2 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2350,7 +2350,11 @@ getindex(A::SparseMatrixCSC, I::AbstractVector{<:Integer}, J::AbstractVector{Boo
getindex(A::SparseMatrixCSC, I::AbstractVector{Bool}, J::AbstractVector{<:Integer}) = A[findall(I),J]

## setindex!
function setindex!(A::SparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integer) where {Tv,Ti<:Integer}

# dispatch helper for #29034
setindex!(A::SparseMatrixCSC, _v, _i::Integer, _j::Integer) = _setindex_scalar!(A, _v, _i, _j)

function _setindex_scalar!(A::SparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integer) where {Tv,Ti<:Integer}
v = convert(Tv, _v)
i = convert(Ti, _i)
j = convert(Ti, _j)
Expand Down Expand Up @@ -2545,7 +2549,8 @@ end
_to_same_csc(::SparseMatrixCSC{Tv, Ti}, V::AbstractMatrix, I...) where {Tv,Ti} = convert(SparseMatrixCSC{Tv,Ti}, V)
_to_same_csc(::SparseMatrixCSC{Tv, Ti}, V::AbstractVector, I...) where {Tv,Ti} = convert(SparseMatrixCSC{Tv,Ti}, reshape(V, map(length, I)))

setindex!(A::SparseMatrixCSC{Tv}, B::AbstractVecOrMat, I::Integer, J::Integer) where {Tv} = setindex!(A, convert(Tv, B), I, J)
setindex!(A::SparseMatrixCSC{Tv}, B::AbstractVecOrMat, I::Integer, J::Integer) where {Tv} = _setindex_scalar!(A, B, I, J)

function setindex!(A::SparseMatrixCSC{Tv,Ti}, V::AbstractVecOrMat, Ix::Union{Integer, AbstractVector{<:Integer}, Colon}, Jx::Union{Integer, AbstractVector{<:Integer}, Colon}) where {Tv,Ti<:Integer}
@assert !has_offset_axes(A, V, Ix, Jx)
(I, J) = Base.ensure_indexable(to_indices(A, (Ix, Jx)))
Expand Down
6 changes: 6 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1993,6 +1993,12 @@ end
@test nnz(A) == 1
end

@testset "setindex with vector eltype (#29034)" begin
A = sparse([1], [1], [Vector{Float64}(undef, 3)], 3, 3)
A[1,1] = [1.0, 2.0, 3.0]
@test A[1,1] == [1.0, 2.0, 3.0]
end

@testset "show" begin
io = IOBuffer()
show(io, MIME"text/plain"(), sparse(Int64[1], Int64[1], [1.0]))
Expand Down

0 comments on commit bdf39a5

Please sign in to comment.