Skip to content

Commit

Permalink
replace ind2sub/sub2ind by CartesianIndices/LinearIndices
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Oct 3, 2023
1 parent ac5c8ed commit e38d0ef
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
11 changes: 8 additions & 3 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3082,12 +3082,15 @@ function getindex(A::AbstractSparseMatrixCSC{Tv,Ti}, I::AbstractArray) where {Tv
colptrB[colB] = 1
idxB = 1

CartIndsA = CartesianIndices(szA)
CartIndsB = CartesianIndices(szB)

for i in 1:n
@boundscheck checkbounds(A, I[i])
row,col = Base._ind2sub(szA, I[i])
row,col = Tuple(CartIndsA[I[i]])
for r in colptrA[col]:(colptrA[col+1]-1)
@inbounds if rowvalA[r] == row
rowB,colB = Base._ind2sub(szB, i)
rowB,colB = Tuple(CartIndsB[i])
colptrB[colB+1] += 1
rowvalB[idxB] = rowB
nzvalB[idxB] = nzvalA[r]
Expand Down Expand Up @@ -3591,13 +3594,15 @@ function setindex!(A::AbstractSparseMatrixCSC, x::AbstractArray, Ix::AbstractVec

isa(x, AbstractArray) && setindex_shape_check(x, length(I))

CartIndsA = CartesianIndices(szA)

lastcol = 0
(nrowA, ncolA) = szA
@inbounds for xidx in 1:n
sxidx = S[xidx]
(sxidx < n) && (I[sxidx] == I[sxidx+1]) && continue

row,col = Base._ind2sub(szA, I[sxidx])
row,col = Tuple(CartIndsA[I[sxidx]])
v = x[sxidx]

if col > lastcol
Expand Down
13 changes: 9 additions & 4 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -779,9 +779,12 @@ function getindex(A::AbstractSparseMatrixCSC{Tv}, I::AbstractUnitRange) where Tv
rowvalB = Vector{Int}(undef, nnzB)
nzvalB = Vector{Tv}(undef, nnzB)

CartIndsA = CartesianIndices(szA)
LinIndsA = LinearIndices(szA)

if nnzB > 0
rowstart,colstart = Base._ind2sub(szA, first(I))
rowend,colend = Base._ind2sub(szA, last(I))
rowstart,colstart = Tuple(CartIndsA[first(I)])
rowend,colend = Tuple(CartIndsA[last(I)])

idxB = 1
@inbounds for col in colstart:colend
Expand All @@ -790,7 +793,7 @@ function getindex(A::AbstractSparseMatrixCSC{Tv}, I::AbstractUnitRange) where Tv
for r in colptrA[col]:(colptrA[col+1]-1)
rowA = rowvalA[r]
if minrow <= rowA <= maxrow
rowvalB[idxB] = Base._sub2ind(szA, rowA, col) - first(I) + 1
rowvalB[idxB] = LinIndsA[rowA, col] - first(I) + 1
nzvalB[idxB] = nzvalA[r]
idxB += 1
end
Expand Down Expand Up @@ -818,9 +821,11 @@ function getindex(A::AbstractSparseMatrixCSC{Tv,Ti}, I::AbstractVector) where {T
rowvalB = Vector{Ti}(undef, nnzB)
nzvalB = Vector{Tv}(undef, nnzB)

CartIndsA = CartesianIndices(szA)

idxB = 1
for i in 1:n
row,col = Base._ind2sub(szA, I[i])
row,col = Tuple(CartIndsA[I[i]])
for r in colptrA[col]:(colptrA[col+1]-1)
@inbounds if rowvalA[r] == row
if idxB <= nnzB
Expand Down

0 comments on commit e38d0ef

Please sign in to comment.