From 3de4e7d887783445ed14ac59886b4e9e7c944898 Mon Sep 17 00:00:00 2001 From: Sacha Verweij Date: Fri, 21 Oct 2016 11:38:55 -0700 Subject: [PATCH] Make nz2nz_z2z-class sparse unary broadcast leverage existing broadcast machinery rather than reimplement it poorly. --- base/sparse/sparsematrix.jl | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index e1126c3d916e5..a95969b79cb37 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -1453,26 +1453,17 @@ round{To}(::Type{To}, A::SparseMatrixCSC) = _broadcast_unary_nz2z_z2z_T(round, A Takes unary function `f` that maps zeros to zeros and nonzeros to nonzeros, and returns a new `SparseMatrixCSC{TiA,TvB}` `B` generated by applying `f` to each nonzero entry in `A`. """ -function _broadcast_unary_nz2nz_z2z_T{TvA,TiA,TvB}(f::Function, A::SparseMatrixCSC{TvA,TiA}, ::Type{TvB}) +function _broadcast_unary_nz2nz_z2z{TvA,TiA,Tf<:Function}(f::Tf, A::SparseMatrixCSC{TvA,TiA}) Bcolptr = Vector{TiA}(A.n + 1) Browval = Vector{TiA}(nnz(A)) - Bnzval = Vector{TvB}(nnz(A)) copy!(Bcolptr, 1, A.colptr, 1, A.n + 1) copy!(Browval, 1, A.rowval, 1, nnz(A)) - @inbounds @simd for k in 1:nnz(A) - Bnzval[k] = f(A.nzval[k]) - end + Bnzval = broadcast(f, A.nzval) + resize!(Bnzval, nnz(A)) return SparseMatrixCSC(A.m, A.n, Bcolptr, Browval, Bnzval) end -function _broadcast_unary_nz2nz_z2z{Tv}(f::Function, A::SparseMatrixCSC{Tv}) - _broadcast_unary_nz2nz_z2z_T(f, A, Tv) -end @_enumerate_childmethods(_broadcast_unary_nz2nz_z2z, log1p, expm1, abs, abs2, conj) -broadcast{TTv}(::typeof(abs2), A::SparseMatrixCSC{Complex{TTv}}) = _broadcast_unary_nz2nz_z2z_T(abs2, A, TTv) -broadcast{TTv}(::typeof(abs), A::SparseMatrixCSC{Complex{TTv}}) = _broadcast_unary_nz2nz_z2z_T(abs, A, TTv) -broadcast{TTv<:Integer}(::typeof(abs), A::SparseMatrixCSC{Complex{TTv}}) = _broadcast_unary_nz2nz_z2z_T(abs, A, Float64) -broadcast{TTv<:BigInt}(::typeof(abs), A::SparseMatrixCSC{Complex{TTv}}) = _broadcast_unary_nz2nz_z2z_T(abs, A, BigFloat) function conj!(A::SparseMatrixCSC) @inbounds @simd for k in 1:nnz(A) A.nzval[k] = conj(A.nzval[k])