From c1ec9f5be7aaff76249ffb991635718766040755 Mon Sep 17 00:00:00 2001
From: Mark Kittisopikul
Date: Mon, 16 Sep 2024 05:17:15 -0400
Subject: [PATCH] Add additional parameters to ZstdCompressor

---
 Hand-edited after `git format-patch`: the hunk headers and the diffstat below
 were recomputed, but the blob hashes on the `index` lines were not regenerated.

 src/compression.jl | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----
 src/libzstd.jl     | 27 +++++++++++++++++++
 2 files changed, 109 insertions(+), 6 deletions(-)

diff --git a/src/compression.jl b/src/compression.jl
index cabc3f9..4eae195 100644
--- a/src/compression.jl
+++ b/src/compression.jl
@@ -3,15 +3,36 @@
 
 struct ZstdCompressor <: TranscodingStreams.Codec
     cstream::CStream
-    level::Int
     endOp::LibZstd.ZSTD_EndDirective
+    # Compression parameters keyed by zstd parameter; the level is stored here under
+    # `LibZstd.ZSTD_c_compressionLevel` rather than in a dedicated field.
+    parameters::Dict{LibZstd.ZSTD_cParameter, Cint}
+    # Keyword names must be keys of `_symbols_to_cParameters`.
+    function ZstdCompressor(cstream, level, endOp=:continue; kwargs...)
+        _parameters = Dict{LibZstd.ZSTD_cParameter, Cint}(LibZstd.ZSTD_c_compressionLevel => level)
+        for (k,v) in kwargs
+            # Reject an unknown keyword as a bad argument instead of leaking a bare
+            # `KeyError` from the lookup table.
+            haskey(_symbols_to_cParameters, k) ||
+                throw(ArgumentError("unsupported zstd compression parameter: $k"))
+            _parameters[_symbols_to_cParameters[k]] = v
+        end
+        return new(cstream, endOp, _parameters)
+    end
 end
 
 function Base.show(io::IO, codec::ZstdCompressor)
+    # Collect every stored parameter other than the level as ", name=value".
+    parameters_string = ""
+    for (k,v) in codec.parameters
+        if k != LibZstd.ZSTD_c_compressionLevel
+            parameters_string *= string(", ", replace(string(k), "ZSTD_c_" => ""), "=", v)
+        end
+    end
     if codec.endOp == LibZstd.ZSTD_e_end
-        print(io, "ZstdFrameCompressor(level=$(codec.level))")
+        print(io, "ZstdFrameCompressor(level=$(codec.level)$parameters_string)")
     else
-        print(io, summary(codec), "(level=$(codec.level))")
+        print(io, summary(codec), "(level=$(codec.level)$parameters_string)")
     end
 end
 
@@ -27,13 +48,68 @@ Arguments
 ---------
 - `level`: compression level (1..$(MAX_CLEVEL))
+- `kwargs`: additional compression parameters; each keyword must be a key of
+  `_symbols_to_cParameters` (for example `windowLog` or `checksumFlag`)
 """
-function ZstdCompressor(;level::Integer=DEFAULT_COMPRESSION_LEVEL)
+function ZstdCompressor(;level::Integer=DEFAULT_COMPRESSION_LEVEL, kwargs...)
     if !(1 ≤ level ≤ MAX_CLEVEL)
         throw(ArgumentError("level must be within 1..$(MAX_CLEVEL)"))
     end
-    return ZstdCompressor(CStream(), level)
+    return ZstdCompressor(CStream(), level; kwargs...)
+end
+
+# Maps the keyword/property names accepted by `ZstdCompressor` to zstd parameter enums.
+const _symbols_to_cParameters = Dict(
+    :compressionLevel => LibZstd.ZSTD_c_compressionLevel,
+    :windowLog => LibZstd.ZSTD_c_windowLog,
+    :hashLog => LibZstd.ZSTD_c_hashLog,
+    :chainLog => LibZstd.ZSTD_c_chainLog,
+    :searchLog => LibZstd.ZSTD_c_searchLog,
+    :minMatch => LibZstd.ZSTD_c_minMatch,
+    :targetLength => LibZstd.ZSTD_c_targetLength,
+    :strategy => LibZstd.ZSTD_c_strategy,
+    :enableLongDistanceMatching => LibZstd.ZSTD_c_enableLongDistanceMatching,
+    :ldmHashLog => LibZstd.ZSTD_c_ldmHashLog,
+    :ldmMinMatch => LibZstd.ZSTD_c_ldmMinMatch,
+    :ldmBucketSizeLog => LibZstd.ZSTD_c_ldmBucketSizeLog,
+    :ldmHashRateLog => LibZstd.ZSTD_c_ldmHashRateLog,
+    :contentSizeFlag => LibZstd.ZSTD_c_contentSizeFlag,
+    :checksumFlag => LibZstd.ZSTD_c_checksumFlag,
+    :dictIDFlag => LibZstd.ZSTD_c_dictIDFlag,
+    :nbWorkers => LibZstd.ZSTD_c_nbWorkers,
+    :jobSize => LibZstd.ZSTD_c_jobSize,
+    :overlapLog => LibZstd.ZSTD_c_overlapLog
+)
+
+function Base.propertynames(compressor::ZstdCompressor)
+    # `:level` is accepted by `getproperty`/`setproperty!` as well, so list it too.
+    return (fieldnames(ZstdCompressor)..., :level, keys(_symbols_to_cParameters)...)
+end
+
+function Base.getproperty(compressor::ZstdCompressor, name::Symbol)
+    # Read the field with `getfield` so this method does not re-enter itself.
+    parameters = getfield(compressor, :parameters)
+    if name == :level
+        return Int(get(parameters, LibZstd.ZSTD_c_compressionLevel, DEFAULT_COMPRESSION_LEVEL))
+    elseif haskey(_symbols_to_cParameters, name)
+        # Throws `KeyError` when the parameter was never stored in `parameters`.
+        return parameters[_symbols_to_cParameters[name]]
+    else
+        return getfield(compressor, name)
+    end
+end
+
+function Base.setproperty!(compressor::ZstdCompressor, name::Symbol, value)
+    parameters = getfield(compressor, :parameters)
+    if name == :level
+        parameters[LibZstd.ZSTD_c_compressionLevel] = value
+    elseif haskey(_symbols_to_cParameters, name)
+        parameters[_symbols_to_cParameters[name]] = value
+    else
+        # The struct is immutable, so this raises for any real field.
+        return setfield!(compressor, name, value)
+    end
+    return nothing
 end
-ZstdCompressor(cstream, level) = ZstdCompressor(cstream, level, :continue)
 
 """
     ZstdFrameCompressor(;level=$(DEFAULT_COMPRESSION_LEVEL))
diff --git a/src/libzstd.jl b/src/libzstd.jl
index c11b1f1..2b2fb67 100644
--- a/src/libzstd.jl
+++ b/src/libzstd.jl
@@ -56,8 +56,35 @@ Base.unsafe_convert(::Type{Ptr{LibZstd.ZSTD_CStream}}, cstream::CStream) = cstre
 Base.unsafe_convert(::Type{Ptr{InBuffer}}, cstream::CStream) = Base.unsafe_convert(Ptr{InBuffer}, cstream.ibuffer)
 Base.unsafe_convert(::Type{Ptr{OutBuffer}}, cstream::CStream) = Base.unsafe_convert(Ptr{OutBuffer}, cstream.obuffer)
 
+# Level-only initializer, kept so callers that pass a bare level keep working.
 function initialize!(cstream::CStream, level::Integer)
     return LibZstd.ZSTD_initCStream(cstream, level)
 end
+
+# Initialize `cstream` from a parameter dictionary: reset the session, clear any
+# referenced dictionary, then apply each entry of `parameters`. Each failing call is
+# reported through `zstderror`; otherwise `Csize_t(0)` is returned.
+function initialize!(cstream::CStream, parameters::Dict{LibZstd.ZSTD_cParameter, Cint})
+    # Mimic ZSTD_initCStream
+    # https://github.com/facebook/zstd/blob/20707e3718ee14250fb8a44b3bf023ea36bd88df/lib/zstd.h#L832-L841
+    code = LibZstd.ZSTD_CCtx_reset(cstream, LibZstd.ZSTD_reset_session_only)
+    if iserror(code)
+        zstderror(cstream, code)
+    end
+
+    # Per zstd.h, referencing a NULL CDict returns to no-dictionary mode.
+    code = LibZstd.ZSTD_CCtx_refCDict(cstream, C_NULL)
+    if iserror(code)
+        zstderror(cstream, code)
+    end
+
+    for (k,v) in parameters
+        code = LibZstd.ZSTD_CCtx_setParameter(cstream, k, v)
+        if iserror(code)
+            zstderror(cstream, code)
+        end
+    end
+    return Csize_t(0)
+end
 
 function reset!(cstream::CStream, srcsize::Integer)