Skip to content

Commit

Permalink
Merge pull request #23914 from JuliaLang/yyc/gc/safe_unsafe_write
Browse files Browse the repository at this point in the history
Fix a small number of invalid unsafe code
  • Loading branch information
yuyichao authored Sep 29, 2017
2 parents 056b374 + 15ca594 commit 3776fce
Show file tree
Hide file tree
Showing 21 changed files with 93 additions and 56 deletions.
6 changes: 6 additions & 0 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ that N is inbounds on either array. Incorrect usage may corrupt or segfault your
the same manner as C.
"""
function unsafe_copy!(dest::Array{T}, doffs, src::Array{T}, soffs, n) where T
t1 = @_gc_preserve_begin dest
t2 = @_gc_preserve_begin src
if isbits(T)
unsafe_copy!(pointer(dest, doffs), pointer(src, soffs), n)
elseif isbitsunion(T)
Expand All @@ -185,6 +187,8 @@ function unsafe_copy!(dest::Array{T}, doffs, src::Array{T}, soffs, n) where T
ccall(:jl_array_ptr_copy, Void, (Any, Ptr{Void}, Any, Ptr{Void}, Int),
dest, pointer(dest, doffs), src, pointer(src, soffs), n)
end
@_gc_preserve_end t2
@_gc_preserve_end t1
return dest
end

Expand Down Expand Up @@ -1570,6 +1574,7 @@ function vcat(arrays::Vector{T}...) where T
else
elsz = Core.sizeof(Ptr{Void})
end
t = @_gc_preserve_begin arr
for a in arrays
na = length(a)
nba = na * elsz
Expand All @@ -1589,6 +1594,7 @@ function vcat(arrays::Vector{T}...) where T
end
ptr += nba
end
@_gc_preserve_end t
return arr
end

Expand Down
4 changes: 2 additions & 2 deletions base/datafmt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ function readdlm_auto(input::AbstractString, dlm::Char, T::Type, eol::Char, auto
# TODO: It would be nicer to use String(a) without making a copy,
# but because the mmap'ed array is not NUL-terminated this causes
# jl_try_substrtod to segfault below.
return readdlm_string(unsafe_string(pointer(a),length(a)), dlm, T, eol, auto, optsd)
return readdlm_string(Base.@gc_preserve(a, unsafe_string(pointer(a),length(a))), dlm, T, eol, auto, optsd)
else
return readdlm_string(read(input, String), dlm, T, eol, auto, optsd)
end
Expand Down Expand Up @@ -220,7 +220,7 @@ function DLMStore(::Type{T}, dims::NTuple{2,Integer},
end

_chrinstr(sbuff::String, chr::UInt8, startpos::Int, endpos::Int) =
(endpos >= startpos) && (C_NULL != ccall(:memchr, Ptr{UInt8},
Base.@gc_preserve sbuff (endpos >= startpos) && (C_NULL != ccall(:memchr, Ptr{UInt8},
(Ptr{UInt8}, Int32, Csize_t), pointer(sbuff)+startpos-1, chr, endpos-startpos+1))

function store_cell(dlmstore::DLMStore{T}, row::Int, col::Int,
Expand Down
2 changes: 1 addition & 1 deletion base/deepcopy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function deepcopy_internal(x::String, stackdict::ObjectIdDict)
if haskey(stackdict, x)
return stackdict[x]
end
y = unsafe_string(pointer(x), sizeof(x))
y = @gc_preserve x unsafe_string(pointer(x), sizeof(x))
stackdict[x] = y
return y
end
Expand Down
2 changes: 1 addition & 1 deletion base/env.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ if Sys.iswindows()
blk = block[2]
len = ccall(:wcslen, UInt, (Ptr{UInt16},), pos)
buf = Vector{UInt16}(len)
unsafe_copy!(pointer(buf), pos, len)
@gc_preserve buf unsafe_copy!(pointer(buf), pos, len)
env = transcode(String, buf)
m = match(r"^(=?[^=]+)=(.*)$"s, env)
if m === nothing
Expand Down
25 changes: 21 additions & 4 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ macro _noinline_meta()
Expr(:meta, :noinline)
end

macro _gc_preserve_begin(arg1)
Expr(:gc_preserve_begin, esc(arg1))
end

macro _gc_preserve_end(token)
Expr(:gc_preserve_end, esc(token))
end

"""
@nospecialize
Expand Down Expand Up @@ -513,16 +521,23 @@ end
# SimpleVector

function getindex(v::SimpleVector, i::Int)
if !(1 <= i <= length(v))
@boundscheck if !(1 <= i <= length(v))
throw(BoundsError(v,i))
end
t = @_gc_preserve_begin v
x = unsafe_load(convert(Ptr{Ptr{Void}},data_pointer_from_objref(v)) + i*sizeof(Ptr))
x == C_NULL && throw(UndefRefError())
return unsafe_pointer_to_objref(x)
o = unsafe_pointer_to_objref(x)
@_gc_preserve_end t
return o
end

# TODO: add gc use intrinsic call instead of noinline
length(v::SimpleVector) = (@_noinline_meta; unsafe_load(convert(Ptr{Int},data_pointer_from_objref(v))))
function length(v::SimpleVector)
t = @_gc_preserve_begin v
l = unsafe_load(convert(Ptr{Int},data_pointer_from_objref(v)))
@_gc_preserve_end t
return l
end
endof(v::SimpleVector) = length(v)
start(v::SimpleVector) = 1
next(v::SimpleVector,i) = (v[i],i+1)
Expand Down Expand Up @@ -573,7 +588,9 @@ function isassigned end

function isassigned(v::SimpleVector, i::Int)
@boundscheck 1 <= i <= length(v) || return false
t = @_gc_preserve_begin v
x = unsafe_load(convert(Ptr{Ptr{Void}},data_pointer_from_objref(v)) + i*sizeof(Ptr))
@_gc_preserve_end t
return x != C_NULL
end

Expand Down
4 changes: 2 additions & 2 deletions base/gmp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ function tryparse_internal(::Type{BigInt}, s::AbstractString, startpos::Int, end
if Base.containsnul(bstr)
err = -1 # embedded NUL char (not handled correctly by GMP)
else
err = MPZ.set_str!(z, pointer(bstr)+(i-start(bstr)), base)
err = Base.@gc_preserve bstr MPZ.set_str!(z, pointer(bstr)+(i-start(bstr)), base)
end
if err != 0
raise && throw(ArgumentError("invalid BigInt: $(repr(bstr))"))
Expand Down Expand Up @@ -612,7 +612,7 @@ function base(b::Integer, n::BigInt, pad::Integer=1)
nd1 = ndigits(n, b)
nd = max(nd1, pad)
sv = Base.StringVector(nd + isneg(n))
MPZ.get_str!(pointer(sv) + nd - nd1, b, n)
Base.@gc_preserve sv MPZ.get_str!(pointer(sv) + nd - nd1, b, n)
@inbounds for i = (1:nd-nd1) .+ isneg(n)
sv[i] = '0' % UInt8
end
Expand Down
14 changes: 7 additions & 7 deletions base/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,9 @@ function write(s::IO, A::AbstractArray)
return nb
end

@noinline function write(s::IO, a::Array) # mark noinline to ensure the array is gc-rooted somewhere (by the caller)
function write(s::IO, a::Array)
if isbits(eltype(a))
return unsafe_write(s, pointer(a), sizeof(a))
return @gc_preserve a unsafe_write(s, pointer(a), sizeof(a))
else
depwarn("Calling `write` on non-isbits arrays is deprecated. Use a loop or `serialize` instead.", :write)
nb = 0
Expand All @@ -384,7 +384,7 @@ function write(s::IO, a::SubArray{T,N,<:Array}) where {T,N}
end
elsz = sizeof(T)
colsz = size(a,1) * elsz
if stride(a,1) != 1
@gc_preserve a if stride(a,1) != 1
for idxs in CartesianRange(size(a))
unsafe_write(s, pointer(a, idxs.I), elsz)
end
Expand Down Expand Up @@ -444,14 +444,14 @@ end
read(s::IO, ::Type{Bool}) = (read(s, UInt8) != 0)
read(s::IO, ::Type{Ptr{T}}) where {T} = convert(Ptr{T}, read(s, UInt))

@noinline function read!(s::IO, a::Array{UInt8}) # mark noinline to ensure the array is gc-rooted somewhere (by the caller)
unsafe_read(s, pointer(a), sizeof(a))
function read!(s::IO, a::Array{UInt8})
@gc_preserve a unsafe_read(s, pointer(a), sizeof(a))
return a
end

@noinline function read!(s::IO, a::Array{T}) where T # mark noinline to ensure the array is gc-rooted somewhere (by the caller)
function read!(s::IO, a::Array{T}) where T
if isbits(T)
unsafe_read(s, pointer(a), sizeof(a))
@gc_preserve a unsafe_read(s, pointer(a), sizeof(a))
else
for i in eachindex(a)
a[i] = read(s, T)
Expand Down
10 changes: 5 additions & 5 deletions base/iobuffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ function unsafe_read(from::GenericIOBuffer, p::Ptr{UInt8}, nb::UInt)
from.readable || throw(ArgumentError("read failed, IOBuffer is not readable"))
avail = nb_available(from)
adv = min(avail, nb)
unsafe_copy!(p, pointer(from.data, from.ptr), adv)
@gc_preserve from unsafe_copy!(p, pointer(from.data, from.ptr), adv)
from.ptr += adv
if nb > avail
throw(EOFError())
Expand All @@ -114,7 +114,7 @@ function read_sub(from::GenericIOBuffer, a::AbstractArray{T}, offs, nel) where T
end
if isbits(T) && isa(a,Array)
nb = UInt(nel * sizeof(T))
unsafe_read(from, pointer(a, offs), nb)
@gc_preserve a unsafe_read(from, pointer(a, offs), nb)
else
for i = offs:offs+nel-1
a[i] = read(to, T)
Expand Down Expand Up @@ -334,7 +334,7 @@ function write_sub(to::GenericIOBuffer, a::AbstractArray{UInt8}, offs, nel)
if offs+nel-1 > length(a) || offs < 1 || nel < 0
throw(BoundsError())
end
unsafe_write(to, pointer(a, offs), UInt(nel))
@gc_preserve a unsafe_write(to, pointer(a, offs), UInt(nel))
end

@inline function write(to::GenericIOBuffer, a::UInt8)
Expand Down Expand Up @@ -367,7 +367,7 @@ read(io::GenericIOBuffer, nb::Integer) = read!(io,StringVector(min(nb, nb_availa

function search(buf::IOBuffer, delim::UInt8)
p = pointer(buf.data, buf.ptr)
q = ccall(:memchr,Ptr{UInt8},(Ptr{UInt8},Int32,Csize_t),p,delim,nb_available(buf))
q = @gc_preserve buf ccall(:memchr,Ptr{UInt8},(Ptr{UInt8},Int32,Csize_t),p,delim,nb_available(buf))
nb::Int = (q == C_NULL ? 0 : q-p+1)
return nb
end
Expand Down Expand Up @@ -413,7 +413,7 @@ function crc32c(io::IOBuffer, nb::Integer, crc::UInt32=0x00000000)
io.readable || throw(ArgumentError("read failed, IOBuffer is not readable"))
n = min(nb, nb_available(io))
n == 0 && return crc
crc = unsafe_crc32c(pointer(io.data, io.ptr), n, crc)
crc = @gc_preserve io unsafe_crc32c(pointer(io.data, io.ptr), n, crc)
io.ptr += n
return crc
end
Expand Down
6 changes: 3 additions & 3 deletions base/iostream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ end
function readbytes_all!(s::IOStream, b::Array{UInt8}, nb)
olb = lb = length(b)
nr = 0
while nr < nb
@gc_preserve b while nr < nb
if lb < nr+1
lb = max(65536, (nr+1) * 2)
resize!(b, lb)
Expand All @@ -284,8 +284,8 @@ function readbytes_some!(s::IOStream, b::Array{UInt8}, nb)
if nb > lb
resize!(b, nb)
end
nr = Int(ccall(:ios_read, Csize_t, (Ptr{Void}, Ptr{Void}, Csize_t),
s.ios, pointer(b), nb))
nr = @gc_preserve b Int(ccall(:ios_read, Csize_t, (Ptr{Void}, Ptr{Void}, Csize_t),
s.ios, pointer(b), nb))
if lb > olb && lb > nr
resize!(b, nr)
end
Expand Down
4 changes: 2 additions & 2 deletions base/libc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ function gethostname()
ccall(:gethostname, Int32, (Ptr{UInt8}, UInt), hn, length(hn))
end
systemerror("gethostname", err != 0)
return unsafe_string(pointer(hn))
return Base.@gc_preserve hn unsafe_string(pointer(hn))
end

## system error handling ##
Expand Down Expand Up @@ -305,7 +305,7 @@ if Sys.iswindows()
p = lpMsgBuf[]
len == 0 && return ""
buf = Vector{UInt16}(len)
unsafe_copy!(pointer(buf), p, len)
Base.@gc_preserve buf unsafe_copy!(pointer(buf), p, len)
ccall(:LocalFree, stdcall, Ptr{Void}, (Ptr{Void},), p)
return transcode(String, buf)
end
Expand Down
25 changes: 14 additions & 11 deletions base/linalg/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,21 +313,21 @@ function dot(DX::Union{DenseArray{T},StridedVector{T}}, DY::Union{DenseArray{T},
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
end
dot(n, pointer(DX), stride(DX, 1), pointer(DY), stride(DY, 1))
Base.@gc_preserve DX DY dot(n, pointer(DX), stride(DX, 1), pointer(DY), stride(DY, 1))
end
function dotc(DX::Union{DenseArray{T},StridedVector{T}}, DY::Union{DenseArray{T},StridedVector{T}}) where T<:BlasComplex
n = length(DX)
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
end
dotc(n, pointer(DX), stride(DX, 1), pointer(DY), stride(DY, 1))
Base.@gc_preserve DX DY dotc(n, pointer(DX), stride(DX, 1), pointer(DY), stride(DY, 1))
end
function dotu(DX::Union{DenseArray{T},StridedVector{T}}, DY::Union{DenseArray{T},StridedVector{T}}) where T<:BlasComplex
n = length(DX)
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
end
dotu(n, pointer(DX), stride(DX, 1), pointer(DY), stride(DY, 1))
Base.@gc_preserve DX DY dotu(n, pointer(DX), stride(DX, 1), pointer(DY), stride(DY, 1))
end

## nrm2
Expand Down Expand Up @@ -364,7 +364,7 @@ for (fname, elty, ret_type) in ((:dnrm2_,:Float64,:Float64),
end
end
end
nrm2(x::Union{StridedVector,Array}) = nrm2(length(x), pointer(x), stride1(x))
nrm2(x::Union{StridedVector,Array}) = Base.@gc_preserve x nrm2(length(x), pointer(x), stride1(x))

## asum

Expand Down Expand Up @@ -397,7 +397,7 @@ for (fname, elty, ret_type) in ((:dasum_,:Float64,:Float64),
end
end
end
asum(x::Union{StridedVector,Array}) = asum(length(x), pointer(x), stride1(x))
asum(x::Union{StridedVector,Array}) = Base.@gc_preserve x asum(length(x), pointer(x), stride1(x))

## axpy

Expand Down Expand Up @@ -445,7 +445,7 @@ function axpy!(alpha::Number, x::Union{DenseArray{T},StridedVector{T}}, y::Union
if length(x) != length(y)
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
end
axpy!(length(x), convert(T,alpha), pointer(x), stride(x, 1), pointer(y), stride(y, 1))
Base.@gc_preserve x y axpy!(length(x), convert(T,alpha), pointer(x), stride(x, 1), pointer(y), stride(y, 1))
y
end

Expand All @@ -460,7 +460,7 @@ function axpy!(alpha::Number, x::Array{T}, rx::Union{UnitRange{Ti},AbstractRange
if minimum(ry) < 1 || maximum(ry) > length(y)
throw(ArgumentError("range out of bounds for y, of length $(length(y))"))
end
axpy!(length(rx), convert(T, alpha), pointer(x)+(first(rx)-1)*sizeof(T), step(rx), pointer(y)+(first(ry)-1)*sizeof(T), step(ry))
Base.@gc_preserve x y axpy!(length(rx), convert(T, alpha), pointer(x)+(first(rx)-1)*sizeof(T), step(rx), pointer(y)+(first(ry)-1)*sizeof(T), step(ry))
y
end

Expand Down Expand Up @@ -509,7 +509,7 @@ function axpby!(alpha::Number, x::Union{DenseArray{T},StridedVector{T}}, beta::N
if length(x) != length(y)
throw(DimensionMismatch("x has length $(length(x)), but y has length $(length(y))"))
end
axpby!(length(x), convert(T,alpha), pointer(x), stride(x, 1), convert(T,beta), pointer(y), stride(y, 1))
Base.@gc_preserve x y axpby!(length(x), convert(T,alpha), pointer(x), stride(x, 1), convert(T,beta), pointer(y), stride(y, 1))
y
end

Expand All @@ -526,7 +526,7 @@ for (fname, elty) in ((:idamax_,:Float64),
end
end
end
iamax(dx::Union{StridedVector,Array}) = iamax(length(dx), pointer(dx), stride1(dx))
iamax(dx::Union{StridedVector,Array}) = Base.@gc_preserve dx iamax(length(dx), pointer(dx), stride1(dx))

# Level 2
## mv
Expand Down Expand Up @@ -1526,7 +1526,10 @@ function copy!(dest::Array{T}, rdest::Union{UnitRange{Ti},AbstractRange{Ti}},
if length(rdest) != length(rsrc)
throw(DimensionMismatch("ranges must be of the same length"))
end
BLAS.blascopy!(length(rsrc), pointer(src)+(first(rsrc)-1)*sizeof(T), step(rsrc),
pointer(dest)+(first(rdest)-1)*sizeof(T), step(rdest))
Base.@gc_preserve src dest BLAS.blascopy!(length(rsrc),
pointer(src) + (first(rsrc) - 1) * sizeof(T),
step(rsrc),
pointer(dest) + (first(rdest) - 1) * sizeof(T),
step(rdest))
dest
end
4 changes: 2 additions & 2 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ scale!(s::T, X::Array{T}) where {T<:BlasFloat} = scale!(X, s)
scale!(X::Array{T}, s::Number) where {T<:BlasFloat} = scale!(X, convert(T, s))
function scale!(X::Array{T}, s::Real) where T<:BlasComplex
R = typeof(real(zero(T)))
BLAS.scal!(2*length(X), convert(R,s), convert(Ptr{R},pointer(X)), 1)
Base.@gc_preserve X BLAS.scal!(2*length(X), convert(R,s), convert(Ptr{R},pointer(X)), 1)
X
end

Expand Down Expand Up @@ -119,7 +119,7 @@ function norm(x::StridedVector{T}, rx::Union{UnitRange{TI},AbstractRange{TI}}) w
if minimum(rx) < 1 || maximum(rx) > length(x)
throw(BoundsError(x, rx))
end
BLAS.nrm2(length(rx), pointer(x)+(first(rx)-1)*sizeof(T), step(rx))
Base.@gc_preserve x BLAS.nrm2(length(rx), pointer(x)+(first(rx)-1)*sizeof(T), step(rx))
end

vecnorm1(x::Union{Array{T},StridedVector{T}}) where {T<:BlasReal} =
Expand Down
Loading

0 comments on commit 3776fce

Please sign in to comment.