diff --git a/src/compression.jl b/src/compression.jl index cabc3f9..683ab76 100644 --- a/src/compression.jl +++ b/src/compression.jl @@ -79,6 +79,9 @@ end # ------- function TranscodingStreams.initialize(codec::ZstdCompressor) + if codec.cstream.ptr == C_NULL + error("codec use after free") + end code = initialize!(codec.cstream, codec.level) if iserror(code) zstderror(codec.cstream, code) @@ -102,6 +105,9 @@ function TranscodingStreams.finalize(codec::ZstdCompressor) end function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, error::Error) + if codec.cstream.ptr == C_NULL + error("codec use after free") + end code = reset!(codec.cstream, 0 #=unknown source size=#) if iserror(code) error[] = ErrorException("zstd error") @@ -111,6 +117,9 @@ function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, error end function TranscodingStreams.process(codec::ZstdCompressor, input::Memory, output::Memory, error::Error) + if codec.cstream.ptr == C_NULL + error("codec use after free") + end cstream = codec.cstream ibuffer_starting_pos = UInt(0) if codec.endOp == LibZstd.ZSTD_e_end && diff --git a/src/decompression.jl b/src/decompression.jl index 765ce2c..e1834a1 100644 --- a/src/decompression.jl +++ b/src/decompression.jl @@ -34,6 +34,9 @@ end # ------- function TranscodingStreams.initialize(codec::ZstdDecompressor) + if codec.dstream.ptr == C_NULL + error("codec use after free") + end code = initialize!(codec.dstream) if iserror(code) zstderror(codec.dstream, code) @@ -57,6 +60,9 @@ function TranscodingStreams.finalize(codec::ZstdDecompressor) end function TranscodingStreams.startproc(codec::ZstdDecompressor, mode::Symbol, error::Error) + if codec.dstream.ptr == C_NULL + error("codec use after free") + end code = reset!(codec.dstream) if iserror(code) error[] = ErrorException("zstd error") @@ -66,6 +72,9 @@ function TranscodingStreams.startproc(codec::ZstdDecompressor, mode::Symbol, err end function TranscodingStreams.process(codec::ZstdDecompressor, input::Memory, output::Memory, error::Error) + if codec.dstream.ptr == C_NULL + error("codec use after free") + end dstream = codec.dstream dstream.ibuffer.src = input.ptr dstream.ibuffer.size = input.size diff --git a/test/runtests.jl b/test/runtests.jl index a111d9a..9488ec8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -156,6 +156,41 @@ include("utils.jl") @test CodecZstd.find_decompressed_size(v) == CodecZstd.ZSTD_CONTENTSIZE_ERROR end + @testset "use after free doesn't segfault" begin + @testset "$(Codec)" for Codec in (ZstdCompressor, ZstdDecompressor) + codec = Codec() + TranscodingStreams.initialize(codec) + TranscodingStreams.finalize(codec) + data = [0x00,0x01] + GC.@preserve data let m = TranscodingStreams.Memory(pointer(data), length(data)) + try + TranscodingStreams.expectedsize(codec, m) + catch + end + try + TranscodingStreams.minoutsize(codec, m) + catch + end + try + TranscodingStreams.initialize(codec) + catch + end + try + TranscodingStreams.startproc(codec, :read, TranscodingStreams.Error()) + catch + end + try + TranscodingStreams.process(codec, m, m, TranscodingStreams.Error()) + catch + end + try + TranscodingStreams.finalize(codec) + catch + end + end + end + end + include("compress_endOp.jl") include("static_only_tests.jl") end