Skip to content

Commit

Permalink
Implement ReinterpretArray
Browse files Browse the repository at this point in the history
This redoes `reinterpret` in julia rather than punning the memory
of the actual array. The motivation for this is to avoid the API
limitations of the current reinterpret implementation (Array only,
preventing strong TBAA, alignment problems). The surface API
essentially unchanged, though the shape argument to reinterpret
is removed, since those concepts are now orthogonal. The return
type from `reinterpret` is now `ReinterpretArray`, which implements
the AbstractArray interface and does the reinterpreting lazily on
demand. The compiler is able to fold away the abstraction and
generate very tight IR:

```
julia> ar = reinterpret(Complex{Int64}, rand(Int64, 1000));

julia> typeof(ar)
Base.ReinterpretArray{Complex{Int64},Int64,1,Array{Int64,1}}

julia> f(ar) = @inbounds return ar[1]
f (generic function with 1 method)

julia> @code_llvm f(ar)

; Function f
; Location: REPL[2]
define void @julia_f_63575({ i64, i64 } addrspace(11)* noalias nocapture sret, %jl_value_t addrspace(10)* dereferenceable(8)) #0 {
top:
; Location: REPL[2]:1
; Function getindex; {
; Location: reinterpretarray.jl:31
  %2 = addrspacecast %jl_value_t addrspace(10)* %1 to %jl_value_t addrspace(11)*
  %3 = bitcast %jl_value_t addrspace(11)* %2 to %jl_value_t addrspace(10)* addrspace(11)*
  %4 = load %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)* addrspace(11)* %3, align 8
  %5 = addrspacecast %jl_value_t addrspace(10)* %4 to %jl_value_t addrspace(11)*
  %6 = bitcast %jl_value_t addrspace(11)* %5 to i64* addrspace(11)*
  %7 = load i64*, i64* addrspace(11)* %6, align 8
  %8 = load i64, i64* %7, align 8
  %9 = getelementptr i64, i64* %7, i64 1
  %10 = load i64, i64* %9, align 8
  %.sroa.0.0..sroa_idx = getelementptr inbounds { i64, i64 }, { i64, i64 } addrspace(11)* %0, i64 0, i32 0
  store i64 %8, i64 addrspace(11)* %.sroa.0.0..sroa_idx, align 8
  %.sroa.3.0..sroa_idx13 = getelementptr inbounds { i64, i64 }, { i64, i64 } addrspace(11)* %0, i64 0, i32 1
  store i64 %10, i64 addrspace(11)* %.sroa.3.0..sroa_idx13, align 8
;}
  ret void
}

julia> g(a) = @inbounds return reinterpret(Complex{Int64}, a)[1]
g (generic function with 1 method)

julia> @code_llvm g(randn(1000))

; Function g
; Location: REPL[4]
define void @julia_g_63642({ i64, i64 } addrspace(11)* noalias nocapture sret, %jl_value_t addrspace(10)* dereferenceable(40)) #0 {
top:
; Location: REPL[4]:1
; Function getindex; {
; Location: reinterpretarray.jl:31
  %2 = addrspacecast %jl_value_t addrspace(10)* %1 to %jl_value_t addrspace(11)*
  %3 = bitcast %jl_value_t addrspace(11)* %2 to double* addrspace(11)*
  %4 = load double*, double* addrspace(11)* %3, align 8
  %5 = bitcast double* %4 to i64*
  %6 = load i64, i64* %5, align 8
  %7 = getelementptr double, double* %4, i64 1
  %8 = bitcast double* %7 to i64*
  %9 = load i64, i64* %8, align 8
  %.sroa.0.0..sroa_idx = getelementptr inbounds { i64, i64 }, { i64, i64 } addrspace(11)* %0, i64 0, i32 0
  store i64 %6, i64 addrspace(11)* %.sroa.0.0..sroa_idx, align 8
  %.sroa.3.0..sroa_idx13 = getelementptr inbounds { i64, i64 }, { i64, i64 } addrspace(11)* %0, i64 0, i32 1
  store i64 %9, i64 addrspace(11)* %.sroa.3.0..sroa_idx13, align 8
;}
  ret void
}
```

In addition, the new `reinterpret` implementation is able to handle any AbstractArray
(whether useful or not is a separate decision):

```
invoke(reinterpret, Tuple{Type{Complex{Float64}}, AbstractArray}, Complex{Float64}, speye(10))
5×10 Base.ReinterpretArray{Complex{Float64},Float64,2,SparseMatrixCSC{Float64,Int64}}:
 1.0+0.0im  0.0+1.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im
 0.0+0.0im  0.0+0.0im  1.0+0.0im  0.0+1.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im
 0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  1.0+0.0im  0.0+1.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im
 0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  1.0+0.0im  0.0+1.0im  0.0+0.0im  0.0+0.0im
 0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  0.0+0.0im  1.0+0.0im  0.0+1.0im
```

The remaining todo is to audit the uses of reinterpret in base. I've fixed up the uses themselves, but there's
code deeper in the array code that needs to be broadened to allow ReinterpretArray.

Fixes #22849
Fixes #19238
  • Loading branch information
Keno committed Sep 27, 2017
1 parent 8d9db00 commit 4526f8b
Show file tree
Hide file tree
Showing 23 changed files with 214 additions and 132 deletions.
27 changes: 0 additions & 27 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,33 +214,6 @@ original.
"""
copy(a::T) where {T<:Array} = ccall(:jl_array_copy, Ref{T}, (Any,), a)

function reinterpret(::Type{T}, a::Array{S,1}) where T where S
nel = Int(div(length(a) * sizeof(S), sizeof(T)))
# TODO: maybe check that remainder is zero?
return reinterpret(T, a, (nel,))
end

function reinterpret(::Type{T}, a::Array{S}) where T where S
if sizeof(S) != sizeof(T)
throw(ArgumentError("result shape not specified"))
end
reinterpret(T, a, size(a))
end

function reinterpret(::Type{T}, a::Array{S}, dims::NTuple{N,Int}) where T where S where N
function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U}
@_noinline_meta
throw(ArgumentError("cannot reinterpret Array{$(S)} to ::Type{Array{$(T)}}, type $(U) is not a bits type"))
end
isbits(T) || throwbits(S, T, T)
isbits(S) || throwbits(S, T, S)
nel = div(length(a) * sizeof(S), sizeof(T))
if prod(dims) != nel
_throw_dmrsa(dims, nel)
end
ccall(:jl_reshape_array, Array{T,N}, (Any, Any, Any), Array{T,N}, a, dims)
end

# reshaping to same # of dimensions
function reshape(a::Array{T,N}, dims::NTuple{N,Int}) where T where N
if prod(dims) != length(a)
Expand Down
10 changes: 1 addition & 9 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,20 +313,12 @@ unsafe_convert(::Type{P}, x::Ptr) where {P<:Ptr} = convert(P, x)
reinterpret(type, A)
Change the type-interpretation of a block of memory.
For arrays, this constructs an array with the same binary data as the given
For arrays, this constructs a view of the array with the same binary data as the given
array, but with the specified element type.
For example,
`reinterpret(Float32, UInt32(7))` interprets the 4 bytes corresponding to `UInt32(7)` as a
[`Float32`](@ref).
!!! warning
It is not allowed to `reinterpret` an array to an element type with a larger alignment then
the alignment of the array. For a normal `Array`, this is the alignment of its element type.
For a reinterpreted array, this is the alignment of the `Array` it was reinterpreted from.
For example, `reinterpret(UInt32, UInt8[0, 0, 0, 0])` is not allowed but
`reinterpret(UInt32, reinterpret(UInt8, Float32[1.0]))` is allowed.
# Examples
```jldoctest
julia> reinterpret(Float32, UInt32(7))
Expand Down
7 changes: 4 additions & 3 deletions base/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,16 @@ readlines(s=STDIN; chomp::Bool=true) = collect(eachline(s, chomp=chomp))

## byte-order mark, ntoh & hton ##

let endian_boms = reinterpret(UInt8, UInt32[0x01020304])
a = UInt32[0x01020304]
let endian_bom = unsafe_load(convert(Ptr{UInt8}, pointer(a)))
global ntoh, hton, ltoh, htol
if endian_boms == UInt8[1:4;]
if endian_bom == 0x01
ntoh(x) = x
hton(x) = x
ltoh(x) = bswap(x)
htol(x) = bswap(x)
const global ENDIAN_BOM = 0x01020304
elseif endian_boms == UInt8[4:-1:1;]
elseif endian_bom == 0x04
ntoh(x) = bswap(x)
hton(x) = bswap(x)
ltoh(x) = x
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ Base.isequal(F::T, G::T) where {T<:Factorization} = all(f -> isequal(getfield(F,
# With a real lhs and complex rhs with the same precision, we can reinterpret
# the complex rhs as a real rhs with twice the number of columns
function (\)(F::Factorization{T}, B::VecOrMat{Complex{T}}) where T<:BlasReal
c2r = reshape(transpose(reinterpret(T, B, (2, length(B)))), size(B, 1), 2*size(B, 2))
c2r = reshape(transpose(reinterpret(T, reshape(B, (1, length(B))))), size(B, 1), 2*size(B, 2))
x = A_ldiv_B!(F, c2r)
return reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)), _ret_size(F, B))
return reshape(collect(reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)))), _ret_size(F, B))
end

for (f1, f2) in ((:\, :A_ldiv_B!),
Expand Down
6 changes: 3 additions & 3 deletions base/linalg/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,10 @@ end
# With a real lhs and complex rhs with the same precision, we can reinterpret
# the complex rhs as a real rhs with twice the number of columns
function (\)(F::LQ{T}, B::VecOrMat{Complex{T}}) where T<:BlasReal
c2r = reshape(transpose(reinterpret(T, B, (2, length(B)))), size(B, 1), 2*size(B, 2))
c2r = reshape(transpose(reinterpret(T, reshape(B, (1, length(B))))), size(B, 1), 2*size(B, 2))
x = A_ldiv_B!(F, c2r)
return reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)),
isa(B, AbstractVector) ? (size(F,2),) : (size(F,2), size(B,2)))
return reshape(collect(reinterpret(Complex{T}, transpose(reshape(x, div(length(x), 2), 2)))),
isa(B, AbstractVector) ? (size(F,2),) : (size(F,2), size(B,2)))
end


Expand Down
10 changes: 5 additions & 5 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ A_mul_B!(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) where
for elty in (Float32,Float64)
@eval begin
function A_mul_B!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty})
Afl = reinterpret($elty,A,(2size(A,1),size(A,2)))
Afl = reinterpret($elty,A)
yfl = reinterpret($elty,y)
gemv!(yfl,'N',Afl,x)
return y
Expand Down Expand Up @@ -148,8 +148,8 @@ A_mul_B!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) wher
for elty in (Float32,Float64)
@eval begin
function A_mul_B!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty})
Afl = reinterpret($elty, A, (2size(A,1), size(A,2)))
Cfl = reinterpret($elty, C, (2size(C,1), size(C,2)))
Afl = reinterpret($elty, A)
Cfl = reinterpret($elty, C)
gemm_wrapper!(Cfl, 'N', 'N', Afl, B)
return C
end
Expand Down Expand Up @@ -190,8 +190,8 @@ A_mul_Bt!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) whe
for elty in (Float32,Float64)
@eval begin
function A_mul_Bt!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty})
Afl = reinterpret($elty, A, (2size(A,1), size(A,2)))
Cfl = reinterpret($elty, C, (2size(C,1), size(C,2)))
Afl = reinterpret($elty, A)
Cfl = reinterpret($elty, C)
gemm_wrapper!(Cfl, 'N', 'T', Afl, B)
return C
end
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -923,15 +923,15 @@ function (\)(A::Union{QR{T},QRCompactWY{T},QRPivoted{T}}, BIn::VecOrMat{Complex{
# |z2|z4| -> |y1|y2|y3|y4| -> |x2|y2| -> |x2|y2|x4|y4|
# |x3|y3|
# |x4|y4|
B = reshape(transpose(reinterpret(T, BIn, (2, length(BIn)))), size(BIn, 1), 2*size(BIn, 2))
B = reshape(transpose(reinterpret(T, reshape(BIn, (1, length(BIn))))), size(BIn, 1), 2*size(BIn, 2))

X = A_ldiv_B!(A, _append_zeros(B, T, n))

# |z1|z3| reinterpret |x1|x2|x3|x4| transpose |x1|y1| reshape |x1|y1|x3|y3|
# |z2|z4| <- |y1|y2|y3|y4| <- |x2|y2| <- |x2|y2|x4|y4|
# |x3|y3|
# |x4|y4|
XX = reinterpret(Complex{T}, transpose(reshape(X, div(length(X), 2), 2)), _ret_size(A, BIn))
XX = reshape(collect(reinterpret(Complex{T}, transpose(reshape(X, div(length(X), 2), 2)))), _ret_size(A, BIn))
return _cut_B(XX, 1:n)
end

Expand Down
11 changes: 6 additions & 5 deletions base/random/dSFMT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ function dsfmt_jump(s::DSFMT_state, jp::AbstractString)
val = s.val
nval = length(val)
index = val[nval - 1]
work = zeros(UInt64, JN32 >> 1)
work = zeros(Int32, JN32)
rwork = reinterpret(UInt64, work)
dsfmt = Vector{UInt64}(nval >> 1)
ccall(:memcpy, Ptr{Void}, (Ptr{UInt64}, Ptr{Int32}, Csize_t),
dsfmt, val, (nval - 1) * sizeof(Int32))
Expand All @@ -113,17 +114,17 @@ function dsfmt_jump(s::DSFMT_state, jp::AbstractString)
for c in jp
bits = parse(UInt8,c,16)
for j in 1:4
(bits & 0x01) != 0x00 && dsfmt_jump_add!(work, dsfmt)
(bits & 0x01) != 0x00 && dsfmt_jump_add!(rwork, dsfmt)
bits = bits >> 0x01
dsfmt_jump_next_state!(dsfmt)
end
end

work[end] = index
return DSFMT_state(reinterpret(Int32, work))
rwork[end] = index
return DSFMT_state(work)
end

function dsfmt_jump_add!(dest::Vector{UInt64}, src::Vector{UInt64})
function dsfmt_jump_add!(dest::AbstractVector{UInt64}, src::Vector{UInt64})
dp = dest[end] >> 1
sp = src[end] >> 1
diff = ((sp - dp + N) % N)
Expand Down
127 changes: 127 additions & 0 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""
Gives a reinterpreted view (of element type T) of the underlying array (of element type S).
If the size of `T` differs from the size of `S`, the array will be compressed/expanded in
the first dimension.
"""
struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
parent::A
function reinterpret(::Type{T}, a::A) where {T,N,S,A<:AbstractArray{S, N}}
function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U}
@_noinline_meta
throw(ArgumentError("cannot reinterpret `$(S)` `$(T)`, type `$(U)` is not a bits type"))
end
function throwsize0(::Type{S}, ::Type{T})
@_noinline_meta
throw(ArgumentError("cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a different size"))
end
function thrownonint(::Type{S}, ::Type{T}, dim)
@_noinline_meta
throw(ArgumentError("""
cannot reinterpret an `$(S)` array to `$(T)` whose first dimension has size `$(dim)`.
The resulting array would have non-integral first dimension.
"""))
end
isbits(T) || throwbits(S, T, T)
isbits(S) || throwbits(S, T, S)
(N != 0 || sizeof(T) == sizeof(S)) || throwsize0(S, T)
if N != 0 && sizeof(S) != sizeof(T)
dim = size(a)[1]
rem(dim*sizeof(S),sizeof(T)) == 0 || thrownonint(S, T, dim)
end
new{T, N, S, A}(a)
end
end

parent(a::ReinterpretArray) = a.parent

eltype(a::ReinterpretArray{T}) where {T} = T
function size(a::ReinterpretArray{T,N,S} where {N}) where {T,S}
psize = size(a.parent)
size1 = div(psize[1]*sizeof(S), sizeof(T))
tuple(size1, tail(psize)...)
end

unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} = Ptr{T}(unsafe_convert(Ptr{S},a.parent))

@inline @propagate_inbounds getindex(a::ReinterpretArray{T,0}) where {T} = reinterpret(T, a.parent[])
@inline @propagate_inbounds getindex(a::ReinterpretArray) = a[1]

@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
if sizeof(T) == sizeof(S)
return reinterpret(T, a.parent[inds...])
else
ind_start, sidx = divrem((inds[1]-1)*sizeof(T), sizeof(S))
t = Ref{T}()
s = Ref{S}()
@gc_preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
i = 1
nbytes_copied = 0
# This is a bit complicated to deal with partial elements
# at both the start and the end. LLVM will fold as appropriate,
# once it knows the data layout
while nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tail(inds)...]
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(tptr, unsafe_load(sptr, sidx + 1), nbytes_copied + 1)
sidx += 1
nbytes_copied += 1
end
sidx = 0
i += 1
end
end
return t[]
end
end

@inline @propagate_inbounds setindex!(a::ReinterpretArray{T,0,S} where T, v) where {S} = (a.parent[] = reinterpret(S, v))
@inline @propagate_inbounds setindex!(a::ReinterpretArray, v) = (a[1] = v)

@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
v = convert(T, v)::T
if sizeof(T) == sizeof(S)
return setindex!(a.parent, reinterpret(S, v), inds...)
else
ind_start, sidx = divrem((inds[1]-1)*sizeof(T), sizeof(S))
t = Ref{T}(v)
s = Ref{S}()
@gc_preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
nbytes_copied = 0
i = 1
@inline function copy_element()
while nbytes_copied < sizeof(T) && sidx < sizeof(S)
unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1)
sidx += 1
nbytes_copied += 1
end
end
# Deal with any partial elements at the start. We'll have to copy in the
# element from the original array and overwrite the relevant parts
if sidx != 0
s[] = a.parent[ind_start + i, tail(inds)...]
copy_element()
a.parent[ind_start + i, tail(inds)...] = s[]
i += 1
sidx = 0
end
# Deal with the main body of elements
while nbytes_copied < sizeof(T) && (sizeof(T) - nbytes_copied) > sizeof(S)
copy_element()
a.parent[ind_start + i, tail(inds)...] = s[]
i += 1
sidx = 0
end
# Deal with trailing partial elements
if nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tail(inds)...]
copy_element()
a.parent[ind_start + i, tail(inds)...] = s[]
end
end
end
return a
end
6 changes: 6 additions & 0 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1888,6 +1888,12 @@ function showarg(io::IO, r::ReshapedArray, toplevel)
toplevel && print(io, " with eltype ", eltype(r))
end

function showarg(io::IO, r::ReinterpretArray{T}, toplevel) where {T}
print(io, "reinterpret($T, ")
showarg(io, parent(r), false)
print(io, ')')
end

# n-dimensional arrays
function show_nd(io::IO, a::AbstractArray, print_matrix, label_slices)
limit::Bool = get(io, :limit, false)
Expand Down
18 changes: 17 additions & 1 deletion base/sparse/abstractsparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

abstract type AbstractSparseArray{Tv,Ti,N} <: AbstractArray{Tv,N} end

const AbstractSparseVector{Tv,Ti} = AbstractSparseArray{Tv,Ti,1}
const AbstractSparseVector{Tv,Ti} = Union{AbstractSparseArray{Tv,Ti,1}, Base.ReinterpretArray{Tv,1,T,<:AbstractSparseArray{T,Ti,1}} where T}
const AbstractSparseMatrix{Tv,Ti} = AbstractSparseArray{Tv,Ti,2}

"""
Expand All @@ -19,5 +19,21 @@ issparse(S::LowerTriangular{<:Any,<:AbstractSparseMatrix}) = true
issparse(S::LinAlg.UnitLowerTriangular{<:Any,<:AbstractSparseMatrix}) = true
issparse(S::UpperTriangular{<:Any,<:AbstractSparseMatrix}) = true
issparse(S::LinAlg.UnitUpperTriangular{<:Any,<:AbstractSparseMatrix}) = true
issparse(S::Base.ReinterpretArray) = issparse(S.parent)

indtype(S::AbstractSparseArray{<:Any,Ti}) where {Ti} = Ti

nonzeros(A::Base.ReinterpretArray{T}) where {T} = reinterpret(T, nonzeros(A.parent))
function nonzeroinds(A::Base.ReinterpretArray{T,N,S} where {N}) where {T,S}
if sizeof(T) == sizeof(S)
return nonzeroinds(A.parent)
elseif sizeof(T) > sizeof(S)
unique(map(nonzeroinds(A.parent)) do ind
div(ind, div(sizeof(T), sizeof(S)))
end)
else
map(nonzeroinds(A.parent)) do ind
ind * div(sizeof(S), sizeof(T))
end
end
end
30 changes: 0 additions & 30 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,17 +212,6 @@ end

## Reinterpret and Reshape

function reinterpret(::Type{T}, a::SparseMatrixCSC{Tv}) where {T,Tv}
if sizeof(T) != sizeof(Tv)
throw(ArgumentError("SparseMatrixCSC reinterpret is only supported for element types of the same size"))
end
mA, nA = size(a)
colptr = copy(a.colptr)
rowval = copy(a.rowval)
nzval = reinterpret(T, a.nzval)
return SparseMatrixCSC(mA, nA, colptr, rowval, nzval)
end

function sparse_compute_reshaped_colptr_and_rowval(colptrS::Vector{Ti}, rowvalS::Vector{Ti},
mS::Int, nS::Int, colptrA::Vector{Ti},
rowvalA::Vector{Ti}, mA::Int, nA::Int) where Ti
Expand Down Expand Up @@ -257,25 +246,6 @@ function sparse_compute_reshaped_colptr_and_rowval(colptrS::Vector{Ti}, rowvalS:
end
end

function reinterpret(::Type{T}, a::SparseMatrixCSC{Tv,Ti}, dims::NTuple{N,Int}) where {T,Tv,Ti,N}
if sizeof(T) != sizeof(Tv)
throw(ArgumentError("SparseMatrixCSC reinterpret is only supported for element types of the same size"))
end
if prod(dims) != length(a)
throw(DimensionMismatch("new dimensions $(dims) must be consistent with array size $(length(a))"))
end
mS,nS = dims
mA,nA = size(a)
numnz = nnz(a)
colptr = Vector{Ti}(nS+1)
rowval = similar(a.rowval)
nzval = reinterpret(T, a.nzval)

sparse_compute_reshaped_colptr_and_rowval(colptr, rowval, mS, nS, a.colptr, a.rowval, mA, nA)

return SparseMatrixCSC(mS, nS, colptr, rowval, nzval)
end

function copy(ra::ReshapedArray{<:Any,2,<:SparseMatrixCSC})
mS,nS = size(ra)
a = parent(ra)
Expand Down
Loading

0 comments on commit 4526f8b

Please sign in to comment.