Skip to content

Commit

Permalink
change order of arguments in fkeep, fix bug with fixed elements (#240)
Browse files Browse the repository at this point in the history
* change order of arguments in fkeep!

* fix bug with fixed matrices (and add test)

* add deprecated compat function
  • Loading branch information
SobhanMP authored Aug 30, 2022
1 parent 43b4d01 commit dfcc48a
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 39 deletions.
29 changes: 17 additions & 12 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1653,7 +1653,7 @@ end
## fkeep! and children tril!, triu!, droptol!, dropzeros[!]

"""
fkeep!(A::AbstractSparseArray, f)
fkeep!(f, A::AbstractSparseArray)
Keep elements of `A` for which test `f` returns `true`. `f`'s signature should be
Expand All @@ -1673,15 +1673,15 @@ julia> A = sparse(Diagonal([1, 2, 3, 4]))
⋅ ⋅ 3 ⋅
⋅ ⋅ ⋅ 4
julia> SparseArrays.fkeep!(A, (i, j, v) -> isodd(v))
julia> SparseArrays.fkeep!((i, j, v) -> isodd(v), A)
4×4 SparseMatrixCSC{Int64, Int64} with 2 stored entries:
1 ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅
⋅ ⋅ 3 ⋅
⋅ ⋅ ⋅ ⋅
```
"""
function _fkeep!(A::AbstractSparseMatrixCSC, f::F) where F
function _fkeep!(f::F, A::AbstractSparseMatrixCSC) where F<:Function
An = size(A, 2)
Acolptr = getcolptr(A)
Arowval = rowvals(A)
Expand Down Expand Up @@ -1716,34 +1716,39 @@ function _fkeep!(A::AbstractSparseMatrixCSC, f::F) where F
return A
end

function _fkeep!_fixed(A::AbstractSparseMatrixCSC, f::F) where F
function _fkeep!_fixed(f::F, A::AbstractSparseMatrixCSC) where F<:Function
@inbounds for j in axes(A, 2)
for k in getcolptr(A)[j]:getcolptr(A)[j+1]-1
i = rowvals(A)[k]
x = nonzeros(A)[k]
# If this element should be kept, rewrite in new position
if f(Ai, Aj, Ax)
if !f(rowvals(A)[k], j, nonzeros(A)[k])
nonzeros(A)[k] = zero(eltype(A))
end
end
end
return A
end

fkeep!(A::AbstractSparseMatrixCSC, f::F) where F= _is_fixed(A) ? _fkeep!_fixed(A, f) : _fkeep!(A, f)
fkeep!(f::F, A::AbstractSparseMatrixCSC) where F<:Function = _is_fixed(A) ? _fkeep!_fixed(f, A) : _fkeep!(f, A)

# deprecated syntax
function fkeep!(x::Union{AbstractSparseMatrixCSC,AbstractCompressedVector},f::F) where F<:Function
Base.depwarn("`fkeep!(x, f::Function)` is deprecated, use `fkeep!(f::Function, x)` instead.", :fkeep!)
return fkeep!(f, x)
end


tril!(A::AbstractSparseMatrixCSC, k::Integer = 0) =
fkeep!(A, (i, j, x) -> i + k >= j)
fkeep!((i, j, x) -> i + k >= j, A)
triu!(A::AbstractSparseMatrixCSC, k::Integer = 0) =
fkeep!(A, (i, j, x) -> j >= i + k)
fkeep!((i, j, x) -> j >= i + k, A)

"""
droptol!(A::AbstractSparseMatrixCSC, tol)
Removes stored values from `A` whose absolute value is less than or equal to `tol`.
"""
droptol!(A::AbstractSparseMatrixCSC, tol) =
fkeep!(A, (i, j, x) -> abs(x) > tol)
fkeep!((i, j, x) -> abs(x) > tol, A)

"""
dropzeros!(A::AbstractSparseMatrixCSC;)
Expand All @@ -1754,7 +1759,7 @@ For an out-of-place version, see [`dropzeros`](@ref). For
algorithmic information, see `fkeep!`.
"""

dropzeros!(A::AbstractSparseMatrixCSC) = _is_fixed(A) ? A : fkeep!(A, (i, j, x) -> _isnotzero(x))
dropzeros!(A::AbstractSparseMatrixCSC) = _is_fixed(A) ? A : fkeep!((i, j, x) -> _isnotzero(x), A)

"""
dropzeros(A::AbstractSparseMatrixCSC;)
Expand Down
53 changes: 29 additions & 24 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2074,31 +2074,36 @@ function sort(x::AbstractCompressedVector{Tv,Ti}; kws...) where {Tv,Ti}
typeof(x)(n,newnzind,newnzvals)
end

function fkeep!(x::AbstractCompressedVector, f)
_is_fixed(x) && return x

nzind = nonzeroinds(x)
nzval = nonzeros(x)

x_writepos = 1
@inbounds for xk in 1:nnz(x)
xi = nzind[xk]
xv = nzval[xk]
# If this element should be kept, rewrite in new position
if f(xi, xv)
if x_writepos != xk
nzind[x_writepos] = xi
nzval[x_writepos] = xv
function fkeep!(f, x::AbstractCompressedVector{Tv}) where Tv
if _is_fixed(x)
for i in 1:nnz(x)
if !f(nonzeroinds(x)[i], nonzeros(x)[i])
nonzeros(x)[i] = zero(Tv)
end
end
else
nzind = nonzeroinds(x)
nzval = nonzeros(x)

x_writepos = 1
@inbounds for xk in 1:nnz(x)
xi = nzind[xk]
xv = nzval[xk]
# If this element should be kept, rewrite in new position
if f(xi, xv)
if x_writepos != xk
nzind[x_writepos] = xi
nzval[x_writepos] = xv
end
x_writepos += 1
end
x_writepos += 1
end
end

# Trim x's storage if necessary
x_nnz = x_writepos - 1
resize!(nzval, x_nnz)
resize!(nzind, x_nnz)

# Trim x's storage if necessary
x_nnz = x_writepos - 1
resize!(nzval, x_nnz)
resize!(nzind, x_nnz)
end
return x
end

Expand All @@ -2109,7 +2114,7 @@ end
Removes stored values from `x` whose absolute value is less than or equal to `tol`.
"""
droptol!(x::AbstractCompressedVector, tol) = fkeep!(x, (i, x) -> abs(x) > tol)
droptol!(x::AbstractCompressedVector, tol) = fkeep!((i, x) -> abs(x) > tol, x)

"""
dropzeros!(x::AbstractCompressedVector)
Expand All @@ -2119,7 +2124,7 @@ Removes stored numerical zeros from `x`.
For an out-of-place version, see [`dropzeros`](@ref). For
algorithmic information, see `fkeep!`.
"""
dropzeros!(x::AbstractCompressedVector) = _is_fixed(x) ? x : fkeep!(x, (i, x) -> _isnotzero(x))
dropzeros!(x::AbstractCompressedVector) = _is_fixed(x) ? x : fkeep!((i, x) -> _isnotzero(x), x)


"""
Expand Down
13 changes: 12 additions & 1 deletion test/fixed.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Test, SparseArrays, LinearAlgebra
using SparseArrays: AbstractSparseVector, AbstractSparseMatrixCSC, FixedSparseCSC, FixedSparseVector, ReadOnly,
getcolptr, rowvals, nonzeros, nonzeroinds, _is_fixed, fixed, move_fixed
getcolptr, rowvals, nonzeros, nonzeroinds, _is_fixed, fixed, move_fixed, fkeep!

@testset "ReadOnly" begin
v = randn(100)
Expand Down Expand Up @@ -124,3 +124,14 @@ end
@test b == a
end

always_false(x...) = false
@testset "Test fkeep!" begin
for a in [sprandn(10, 10, 0.99) + I, sprandn(10, 0.1) .+ 1]
a = fixed(a)
b = copy(a)
fkeep!(always_false, b)
@test nnz(a) == nnz(b)
@test all(iszero, nonzeros(b))

end
end
5 changes: 3 additions & 2 deletions test/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ end
@test_throws ArgumentError findmin(x)
@test_throws ArgumentError findmax(x)
end

let v = spzeros(3) #Julia #44978
v[1] = 2
@test argmin(v) == 2
Expand Down Expand Up @@ -1298,8 +1298,9 @@ end
xdrop = copy(x)
# This will keep index 1, 3, 4, 7 in xdrop
f_drop(i, x) = (abs(x) == 1.) || (i in [1, 7])
SparseArrays.fkeep!(xdrop, f_drop)
SparseArrays.fkeep!(f_drop, xdrop)
@test exact_equal(xdrop, SparseVector(7, [1, 3, 4, 7], [3., -1., 1., 3.]))
@test_deprecated SparseArrays.fkeep!(xdrop, f_drop)
end

@testset "dropzeros[!] with length=$m" for m in (10, 20, 30)
Expand Down

0 comments on commit dfcc48a

Please sign in to comment.