From c39ea77ea0becb3078b2ef0b9df5bd5c9f951015 Mon Sep 17 00:00:00 2001 From: Stefan Wehner Date: Fri, 15 Mar 2024 11:59:08 +0100 Subject: [PATCH 1/4] Support optional format paramater for get_frame_parameters --- .cargo/config | 3 +++ c-ext/frameparams.c | 10 ++++++---- rust-ext/src/frame_parameters.rs | 15 ++++++++++++--- tests/test_data_structures.py | 10 ++++++++++ zstandard/__init__.pyi | 2 +- zstandard/backend_cffi.py | 6 ++++-- 6 files changed, 36 insertions(+), 10 deletions(-) diff --git a/.cargo/config b/.cargo/config index e5007178..af951327 100644 --- a/.cargo/config +++ b/.cargo/config @@ -1,2 +1,5 @@ [target.x86_64-apple-darwin] rustflags = ["-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup"] + +[target.aarch64-apple-darwin] +rustflags = ["-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup"] diff --git a/c-ext/frameparams.c b/c-ext/frameparams.c index 951a91d6..9d37b398 100644 --- a/c-ext/frameparams.c +++ b/c-ext/frameparams.c @@ -12,19 +12,21 @@ extern PyObject *ZstdError; FrameParametersObject *get_frame_parameters(PyObject *self, PyObject *args, PyObject *kwargs) { - static char *kwlist[] = {"data", NULL}; + static char *kwlist[] = {"data", "format", NULL}; Py_buffer source; + ZSTD_frameHeader header; + ZSTD_format_e format = ZSTD_f_zstd1; FrameParametersObject *result = NULL; size_t zresult; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y*:get_frame_parameters", - kwlist, &source)) { + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y*|I:get_frame_parameters", + kwlist, &source, &format)) { return NULL; } - zresult = ZSTD_getFrameHeader(&header, source.buf, source.len); + zresult = ZSTD_getFrameHeader_advanced(&header, source.buf, source.len, format); if (ZSTD_isError(zresult)) { PyErr_Format(ZstdError, "cannot get frame parameters: %s", diff --git a/rust-ext/src/frame_parameters.rs b/rust-ext/src/frame_parameters.rs index 13d7fd3f..7628dad5 100644 --- a/rust-ext/src/frame_parameters.rs +++ b/rust-ext/src/frame_parameters.rs @@ -6,7 +6,7 @@ use { crate::ZstdError, - pyo3::{buffer::PyBuffer, prelude::*, wrap_pyfunction}, + pyo3::{buffer::PyBuffer, prelude::*, wrap_pyfunction, exceptions::{PyValueError}}, }; #[pyclass(module = "zstandard.backend_rust")] @@ -67,10 +67,19 @@ fn frame_header_size(data: PyBuffer) -> PyResult { } #[pyfunction] -fn get_frame_parameters(py: Python, buffer: PyBuffer) -> PyResult> { +#[pyo3(signature = (buffer, format=zstd_sys::ZSTD_format_e::ZSTD_f_zstd1 as u32))] +fn get_frame_parameters(py: Python, buffer: PyBuffer, format: u32) -> PyResult> { let raw_data = unsafe { std::slice::from_raw_parts::(buffer.buf_ptr() as *const _, buffer.len_bytes()) }; + let format = if format == zstd_sys::ZSTD_format_e::ZSTD_f_zstd1 as _ { + zstd_sys::ZSTD_format_e::ZSTD_f_zstd1 + } else if format == zstd_sys::ZSTD_format_e::ZSTD_f_zstd1_magicless as _ { + zstd_sys::ZSTD_format_e::ZSTD_f_zstd1_magicless + } else { + return Err(PyValueError::new_err(format!("invalid format value"))); + }; + let mut header = zstd_sys::ZSTD_frameHeader { frameContentSize: 0, @@ -84,7 +93,7 @@ fn get_frame_parameters(py: Python, buffer: PyBuffer) -> PyResult int: ... def frame_content_size(data: ByteString) -> int: ... def frame_header_size(data: ByteString) -> int: ... -def get_frame_parameters(data: ByteString) -> FrameParameters: ... +def get_frame_parameters(data: ByteString, format: Optional[int]) -> FrameParameters: ... def train_dictionary( dict_size: int, samples: list[ByteString], diff --git a/zstandard/backend_cffi.py b/zstandard/backend_cffi.py index 7137542f..72e63b7d 100644 --- a/zstandard/backend_cffi.py +++ b/zstandard/backend_cffi.py @@ -2558,7 +2558,7 @@ def frame_header_size(data): return zresult -def get_frame_parameters(data): +def get_frame_parameters(data, format=FORMAT_ZSTD1): """ Parse a zstd frame header into frame parameters. @@ -2569,13 +2569,15 @@ def get_frame_parameters(data): :param data: Data from which to read frame parameters. + :param format: + Set the format of data for the decoder. :return: :py:class:`FrameParameters` """ params = ffi.new("ZSTD_frameHeader *") data_buffer = ffi.from_buffer(data) - zresult = lib.ZSTD_getFrameHeader(params, data_buffer, len(data_buffer)) + zresult = lib.ZSTD_getFrameHeader_advanced(params, data_buffer, len(data_buffer), format) if lib.ZSTD_isError(zresult): raise ZstdError( "cannot get frame parameters: %s" % _zstd_error(zresult) From 017dd2e5fb578f7244015592a3a3759aafc6c454 Mon Sep 17 00:00:00 2001 From: Stefan Wehner Date: Fri, 15 Mar 2024 16:29:58 +0100 Subject: [PATCH 2/4] Use format to get content size in decompress to support headerless --- c-ext/decompressor.c | 8 +++---- rust-ext/src/decompressor.rs | 32 +++++++++++++++++++-------- tests/test_decompressor_decompress.py | 10 +++++++++ zstandard/backend_cffi.py | 11 +++++---- 4 files changed, 42 insertions(+), 19 deletions(-) diff --git a/c-ext/decompressor.c b/c-ext/decompressor.c index b8e7f82c..db96ba9c 100644 --- a/c-ext/decompressor.c +++ b/c-ext/decompressor.c @@ -299,15 +299,15 @@ PyObject *Decompressor_decompress(ZstdDecompressor *self, PyObject *args, goto finally; } - decompressedSize = ZSTD_getFrameContentSize(source.buf, source.len); - - if (ZSTD_CONTENTSIZE_ERROR == decompressedSize) { + ZSTD_frameHeader zfh; + if (ZSTD_getFrameHeader_advanced(&zfh, source.buf, source.len, self->format) != 0) { PyErr_SetString(ZstdError, "error determining content size from frame header"); goto finally; } + decompressedSize=zfh.frameContentSize; /* Special case of empty frame. */ - else if (0 == decompressedSize) { + if (0 == decompressedSize) { result = PyBytes_FromStringAndSize("", 0); goto finally; } diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index 2964d9a3..07f0573f 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -176,17 +176,31 @@ impl ZstdDecompressor { self.setup_dctx(py, true)?; - let output_size = - unsafe { zstd_sys::ZSTD_getFrameContentSize(buffer.buf_ptr(), buffer.len_bytes()) }; + let mut header = zstd_sys::ZSTD_frameHeader { + frameContentSize: 0, + windowSize: 0, + blockSizeMax: 0, + frameType: zstd_sys::ZSTD_frameType_e::ZSTD_frame, + headerSize: 0, + dictID: 0, + checksumFlag: 0, + _reserved1: 0, + _reserved2: 0, + }; + let zresult = unsafe { + zstd_sys::ZSTD_getFrameHeader_advanced(&mut header, buffer.buf_ptr(), buffer.len_bytes(), self.format) + }; + + if zresult != 0 { + return Err(ZstdError::new_err( + "error determining content size from frame header" + )) + } let (output_buffer_size, output_size) = - if output_size == zstd_sys::ZSTD_CONTENTSIZE_ERROR as _ { - return Err(ZstdError::new_err( - "error determining content size from frame header", - )); - } else if output_size == 0 { + if header.frameContentSize == 0 { return Ok(PyBytes::new(py, &[])); - } else if output_size == zstd_sys::ZSTD_CONTENTSIZE_UNKNOWN as _ { + } else if header.frameContentSize == zstd_sys::ZSTD_CONTENTSIZE_UNKNOWN as _ { if max_output_size == 0 { return Err(ZstdError::new_err( "could not determine content size in frame header", @@ -195,7 +209,7 @@ impl ZstdDecompressor { (max_output_size, 0) } else { - (output_size as _, output_size) + (header.frameContentSize as _, header.frameContentSize) }; let mut dest_buffer: Vec = Vec::new(); diff --git a/tests/test_decompressor_decompress.py b/tests/test_decompressor_decompress.py index da2ce6b1..602df6fe 100644 --- a/tests/test_decompressor_decompress.py +++ b/tests/test_decompressor_decompress.py @@ -37,6 +37,16 @@ def test_input_types(self): for source in sources: self.assertEqual(dctx.decompress(source), b"foo") + def test_headerless(self): + compression_params = zstd.ZstdCompressionParameters( + format=zstd.FORMAT_ZSTD1_MAGICLESS, + ) + cctx = zstd.ZstdCompressor(compression_params=compression_params) + compressed = cctx.compress(b"foo") + + dctx = zstd.ZstdDecompressor(format=zstd.FORMAT_ZSTD1_MAGICLESS) + self.assertEqual(dctx.decompress(compressed), b"foo") + def test_no_content_size_in_frame(self): cctx = zstd.ZstdCompressor(write_content_size=False) compressed = cctx.compress(b"foobar") diff --git a/zstandard/backend_cffi.py b/zstandard/backend_cffi.py index 72e63b7d..9f0e2ea9 100644 --- a/zstandard/backend_cffi.py +++ b/zstandard/backend_cffi.py @@ -3822,13 +3822,12 @@ def decompress( data_buffer = ffi.from_buffer(data) - output_size = lib.ZSTD_getFrameContentSize( - data_buffer, len(data_buffer) - ) - - if output_size == lib.ZSTD_CONTENTSIZE_ERROR: + params = ffi.new("ZSTD_frameHeader *") + zresult = lib.ZSTD_getFrameHeader_advanced(params, data_buffer, len(data_buffer), self._format) + if zresult != 0: raise ZstdError("error determining content size from frame header") - elif output_size == 0: + output_size = params.frameContentSize + if output_size == 0: return b"" elif output_size == lib.ZSTD_CONTENTSIZE_UNKNOWN: if not max_output_size: From 0550624505c73dd837209125e01fb70b1c9ad575 Mon Sep 17 00:00:00 2001 From: Stefan Wehner Date: Fri, 15 Mar 2024 20:41:59 +0100 Subject: [PATCH 3/4] Rename var --- c-ext/decompressor.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/c-ext/decompressor.c b/c-ext/decompressor.c index db96ba9c..d0b9591a 100644 --- a/c-ext/decompressor.c +++ b/c-ext/decompressor.c @@ -299,13 +299,13 @@ PyObject *Decompressor_decompress(ZstdDecompressor *self, PyObject *args, goto finally; } - ZSTD_frameHeader zfh; - if (ZSTD_getFrameHeader_advanced(&zfh, source.buf, source.len, self->format) != 0) { + ZSTD_frameHeader frameHeader; + if (ZSTD_getFrameHeader_advanced(&frameHeader, source.buf, source.len, self->format) != 0) { PyErr_SetString(ZstdError, "error determining content size from frame header"); goto finally; } - decompressedSize=zfh.frameContentSize; + decompressedSize=frameHeader.frameContentSize; /* Special case of empty frame. */ if (0 == decompressedSize) { result = PyBytes_FromStringAndSize("", 0); From 98fdb8bb021dbafa36d935688d1af11a35a578fc Mon Sep 17 00:00:00 2001 From: Stefan Wehner Date: Sun, 17 Mar 2024 00:15:18 +0100 Subject: [PATCH 4/4] fix test --- tests/test_compressor_fuzzing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_compressor_fuzzing.py b/tests/test_compressor_fuzzing.py index 444cfb18..4703e422 100644 --- a/tests/test_compressor_fuzzing.py +++ b/tests/test_compressor_fuzzing.py @@ -767,7 +767,7 @@ def test_data_equivalence(self, original, threads, use_dict): dctx = zstd.ZstdDecompressor(**kwargs) for i, frame in enumerate(result): - self.assertEqual(dctx.decompress(frame), original[i]) + self.assertEqual(dctx.decompress(frame.tobytes()), original[i]) @unittest.skipUnless("ZSTD_SLOW_TESTS" in os.environ, "ZSTD_SLOW_TESTS not set")