Skip to content

Commit

Permalink
Auto initialize in startproc (#74)
Browse files Browse the repository at this point in the history
* Auto initialize in `startproc`

* Add tests

* Apply suggestions from code review

Co-authored-by: Mark Kittisopikul <[email protected]>

* add explicit return

* Add GC preserve

* reset dstream buffers in reset!

---------

Co-authored-by: Mark Kittisopikul <[email protected]>
  • Loading branch information
nhz2 and mkitti authored Oct 3, 2024
1 parent e7edfed commit 7bd48b4
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 45 deletions.
26 changes: 14 additions & 12 deletions src/compression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,6 @@ end
# Methods
# -------

function TranscodingStreams.initialize(codec::ZstdCompressor)
code = initialize!(codec.cstream, codec.level)
if iserror(code)
zstderror(codec.cstream, code)
end
reset!(codec.cstream.ibuffer)
reset!(codec.cstream.obuffer)
return
end

function TranscodingStreams.finalize(codec::ZstdCompressor)
if codec.cstream.ptr != C_NULL
code = free!(codec.cstream)
Expand All @@ -96,12 +86,21 @@ function TranscodingStreams.finalize(codec::ZstdCompressor)
end
codec.cstream.ptr = C_NULL
end
reset!(codec.cstream.ibuffer)
reset!(codec.cstream.obuffer)
return
end

function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, error::Error)
if codec.cstream.ptr == C_NULL
codec.cstream.ptr = LibZstd.ZSTD_createCStream()
if codec.cstream.ptr == C_NULL
throw(OutOfMemoryError())
end
i_code = initialize!(codec.cstream, codec.level)
if iserror(i_code)
error[] = ErrorException("zstd initialization error")
return :error
end
end
code = reset!(codec.cstream, 0 #=unknown source size=#)
if iserror(code)
error[] = ErrorException("zstd error")
Expand All @@ -111,6 +110,9 @@ function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, error
end

function TranscodingStreams.process(codec::ZstdCompressor, input::Memory, output::Memory, error::Error)
if codec.cstream.ptr == C_NULL
error("startproc must be called before process")
end
cstream = codec.cstream
ibuffer_starting_pos = UInt(0)
if codec.endOp == LibZstd.ZSTD_e_end &&
Expand Down
26 changes: 14 additions & 12 deletions src/decompression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,6 @@ end
# Methods
# -------

function TranscodingStreams.initialize(codec::ZstdDecompressor)
code = initialize!(codec.dstream)
if iserror(code)
zstderror(codec.dstream, code)
end
reset!(codec.dstream.ibuffer)
reset!(codec.dstream.obuffer)
return
end

function TranscodingStreams.finalize(codec::ZstdDecompressor)
if codec.dstream.ptr != C_NULL
code = free!(codec.dstream)
Expand All @@ -51,12 +41,21 @@ function TranscodingStreams.finalize(codec::ZstdDecompressor)
end
codec.dstream.ptr = C_NULL
end
reset!(codec.dstream.ibuffer)
reset!(codec.dstream.obuffer)
return
end

function TranscodingStreams.startproc(codec::ZstdDecompressor, mode::Symbol, error::Error)
if codec.dstream.ptr == C_NULL
codec.dstream.ptr = LibZstd.ZSTD_createDStream()
if codec.dstream.ptr == C_NULL
throw(OutOfMemoryError())
end
i_code = initialize!(codec.dstream)
if iserror(i_code)
error[] = ErrorException("zstd initialization error")
return :error
end
end
code = reset!(codec.dstream)
if iserror(code)
error[] = ErrorException("zstd error")
Expand All @@ -66,6 +65,9 @@ function TranscodingStreams.startproc(codec::ZstdDecompressor, mode::Symbol, err
end

function TranscodingStreams.process(codec::ZstdDecompressor, input::Memory, output::Memory, error::Error)
if codec.dstream.ptr == C_NULL
error("startproc must be called before process")
end
dstream = codec.dstream
dstream.ibuffer.src = input.ptr
dstream.ibuffer.size = input.size
Expand Down
14 changes: 4 additions & 10 deletions src/libzstd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,7 @@ mutable struct CStream
obuffer::OutBuffer

function CStream()
ptr = LibZstd.ZSTD_createCStream()
if ptr == C_NULL
throw(OutOfMemoryError())
end
return new(ptr, InBuffer(), OutBuffer())
return new(C_NULL, InBuffer(), OutBuffer())
end
end

Expand Down Expand Up @@ -127,11 +123,7 @@ mutable struct DStream
obuffer::OutBuffer

function DStream()
ptr = LibZstd.ZSTD_createDStream()
if ptr == C_NULL
throw(OutOfMemoryError())
end
return new(ptr, InBuffer(), OutBuffer())
return new(C_NULL, InBuffer(), OutBuffer())
end
end
Base.unsafe_convert(::Type{Ptr{LibZstd.ZSTD_DStream}}, dstream::DStream) = dstream.ptr
Expand All @@ -145,6 +137,8 @@ end
function reset!(dstream::DStream)
# LibZstd.ZSTD_resetDStream is deprecated
# https://github.com/facebook/zstd/blob/9d2a45a705e22ad4817b41442949cd0f78597154/lib/zstd.h#L2332-L2339
reset!(dstream.ibuffer)
reset!(dstream.obuffer)
return LibZstd.ZSTD_DCtx_reset(dstream, LibZstd.ZSTD_reset_session_only)
end

Expand Down
25 changes: 14 additions & 11 deletions test/compress_endOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,29 @@ using Test

@testset "compress! endOp = :continue" begin
data = rand(1:100, 1024*1024)
cstream = CodecZstd.CStream()
cstream.ibuffer.src = pointer(data)
cstream.ibuffer.size = sizeof(data)
cstream.ibuffer.pos = 0
cstream.obuffer.dst = Base.Libc.malloc(sizeof(data)*2)
cstream.obuffer.size = sizeof(data)*2
cstream.obuffer.pos = 0
try
GC.@preserve data begin
GC.@preserve data begin
cstream = CodecZstd.CStream()
cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream()
cstream.ibuffer.src = pointer(data)
cstream.ibuffer.size = sizeof(data)
cstream.ibuffer.pos = 0
cstream.obuffer.dst = Base.Libc.malloc(sizeof(data)*2)
cstream.obuffer.size = sizeof(data)*2
cstream.obuffer.pos = 0
try
# default endOp
@test CodecZstd.compress!(cstream; endOp=:continue) == 0
@test CodecZstd.find_decompressed_size(cstream.obuffer.dst, cstream.obuffer.pos) == CodecZstd.ZSTD_CONTENTSIZE_UNKNOWN
finally
Base.Libc.free(cstream.obuffer.dst)
end
finally
Base.Libc.free(cstream.obuffer.dst)
end
end

@testset "compress! endOp = :flush" begin
data = rand(1:100, 1024*1024)
cstream = CodecZstd.CStream()
cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream()
cstream.ibuffer.src = pointer(data)
cstream.ibuffer.size = sizeof(data)
cstream.ibuffer.pos = 0
Expand All @@ -43,6 +45,7 @@ end
@testset "compress! endOp = :end" begin
data = rand(1:100, 1024*1024)
cstream = CodecZstd.CStream()
cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream()
cstream.ibuffer.src = pointer(data)
cstream.ibuffer.size = sizeof(data)
cstream.ibuffer.pos = 0
Expand Down
68 changes: 68 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,72 @@ include("utils.jl")

include("compress_endOp.jl")
include("static_only_tests.jl")

@testset "reusing a compressor" begin
compressor = ZstdCompressor()
x = rand(UInt8, 1000)
TranscodingStreams.initialize(compressor)
ret1 = transcode(compressor, x)
TranscodingStreams.finalize(compressor)

# compress again using the same compressor
TranscodingStreams.initialize(compressor) # segfault happens here!
ret2 = transcode(compressor, x)
ret3 = transcode(compressor, x)
TranscodingStreams.finalize(compressor)

@test transcode(ZstdDecompressor, ret1) == x
@test transcode(ZstdDecompressor, ret2) == x
@test transcode(ZstdDecompressor, ret3) == x
@test ret1 == ret2
@test ret1 == ret3

decompressor = ZstdDecompressor()
TranscodingStreams.initialize(decompressor)
@test transcode(decompressor, ret1) == x
TranscodingStreams.finalize(decompressor)

TranscodingStreams.initialize(decompressor)
@test transcode(decompressor, ret1) == x
TranscodingStreams.finalize(decompressor)
end

@testset "use after free doesn't segfault" begin
@testset "$(Codec)" for Codec in (ZstdCompressor, ZstdDecompressor)
codec = Codec()
TranscodingStreams.initialize(codec)
TranscodingStreams.finalize(codec)
data = [0x00,0x01]
GC.@preserve data let m = TranscodingStreams.Memory(pointer(data), length(data))
try
TranscodingStreams.expectedsize(codec, m)
catch
end
try
TranscodingStreams.minoutsize(codec, m)
catch
end
try
TranscodingStreams.initialize(codec)
catch
end
try
TranscodingStreams.process(codec, m, m, TranscodingStreams.Error())
catch
end
try
TranscodingStreams.startproc(codec, :read, TranscodingStreams.Error())
catch
end
try
TranscodingStreams.process(codec, m, m, TranscodingStreams.Error())
catch
end
try
TranscodingStreams.finalize(codec)
catch
end
end
end
end
end

0 comments on commit 7bd48b4

Please sign in to comment.