Skip to content

Commit

Permalink
Address review of julia implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Keno committed Sep 18, 2017
1 parent bcae26a commit e8bfc52
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,65 @@ Gives a reinterpreted view (of element type T) of the underlying array (of eleme
If the size of `T` differs from the size of `S`, the array will be compressed/expanded in
the first dimension.
"""
struct ReinterpretArray{T,S,N,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
parent::A
Base.reinterpret(::Type{T}, a::A) where {T,S,N,A<:AbstractArray{S, N}} = new{T, S, N, A}(a)
function reinterpret(::Type{T}, a::A) where {T,N,S,A<:AbstractArray{S, N}}
function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U}
@_noinline_meta
throw(ArgumentError("cannot reinterpret `$(S)` `$(T)`, type `$(U)` is not a bits type"))
end
function throwsize0(::Type{S}, ::Type{T})
@_noinline_meta
throw(ArgumentError("cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a different size"))
end
isbits(T) || throwbits(S, T, T)
isbits(S) || throwbits(S, T, S)
(N != 0 || sizeof(T) == sizeof(S)) || throwsize0(S, T)
new{T, N, S, A}(a)
end
end

Base.eltype(a::ReinterpretArray{T}) where {T} = T
function Base.size(a::ReinterpretArray{T,S}) where {T,S}
eltype(a::ReinterpretArray{T}) where {T} = T
function size(a::ReinterpretArray{T,S}) where {T,S}
psize = size(a.parent)
if sizeof(T) > sizeof(S)
size1 = div(psize[1], div(sizeof(T), sizeof(S)))
else
size1 = psize[1] * div(sizeof(S), sizeof(T))
end
tuple(size1, Base.tail(psize)...)
tuple(size1, tail(psize)...)
end

@inline @propagate_inbounds function Base.getindex(a::ReinterpretArray{T,S,N}, inds::Vararg{Int, N}) where {T,S,N}
@inline @propagate_inbounds getindex(a::ReinterpretArray{T,0}) where {T} = reinterpret(T, a.parent[])
@inline @propagate_inbounds getindex(a::ReinterpretArray) = a[1]

@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
if sizeof(T) == sizeof(S)
return reinterpret(T, a[inds...])
elseif sizeof(T) > sizeof(S)
nels = div(sizeof(T), sizeof(S))
ind_off = (inds[1]-1) * nels
o = Ref{T}()
optr = Base.unsafe_convert(Ref{T}, o)
optr = unsafe_convert(Ref{T}, o)
for i = 1:nels
unsafe_store!(convert(Ptr{S}, optr)+(i-1)*sizeof(S), a.parent[ind_off + i, tail(inds)...])
unsafe_store!(Ptr{S}(optr), a.parent[ind_off + i, tail(inds)...], i)
end
return o[]
else
ind, sub = divrem(inds[1]-1, div(sizeof(S), sizeof(T)))
r = Ref{S}(a.parent[1+ind, Base.tail(inds)...])
r = Ref{S}(a.parent[1+ind, tail(inds)...])
@gc_preserve r begin
rptr = Base.unsafe_convert(Ref{S}, r)
ret = unsafe_load(convert(Ptr{T}, rptr) + sub*sizeof(T))
rptr = unsafe_convert(Ref{S}, r)
ret = unsafe_load(Ptr{T}(rptr), sub+1)
end
return ret
end
end

@inline @propagate_inbounds function Base.setindex!(a::ReinterpretArray{T,S,N}, v, inds::Vararg{Int, N}) where {T,S,N}
@inline @propagate_inbounds setindex!(a::ReinterpretArray{T,0,S} where T, v) where {S} = (a.parent[] = reinterpret(S, v))
@inline @propagate_inbounds setindex!(a::ReinterpretArray, v) = (a[1] = v)

@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
v = convert(T, v)::T
if sizeof(T) == sizeof(S)
return setindex!(a, reinterpret(S, v), inds...)
Expand All @@ -51,16 +70,16 @@ end
ind_off = (inds[1]-1) * nels
o = Ref{T}(v)
@gc_preserve o begin
optr = Base.unsafe_convert(Ref{T}, o)
optr = unsafe_convert(Ref{T}, o)
for i = 1:nels
a.parent[ind_off + i, Base.tail(inds)...] = unsafe_load(convert(Ptr{S}, optr)+(i-1)*sizeof(S))
a.parent[ind_off + i, tail(inds)...] = unsafe_load(Ptr{S}(optr), i)
end
end
else
ind, sub = divrem(inds[1]-1, div(sizeof(S), sizeof(T)))
r = Ref{S}(a.parent[1+ind, Base.tail(inds)...])
rptr = Base.unsafe_convert(Ref{S}, r)
unsafe_store!(convert(Ptr{T}, rptr) + sub*sizeof(T), v)
r = Ref{S}(a.parent[1+ind, tail(inds)...])
rptr = unsafe_convert(Ref{S}, r)
unsafe_store!(Ptr{T}(rptr), v, sub+1)
a.parent[1+ind] = r[]
end
return a
Expand Down

0 comments on commit e8bfc52

Please sign in to comment.