Skip to content

Commit

Permalink
Switched to storing mz_stream as a raw pointer to fix tree borrows vi…
Browse files Browse the repository at this point in the history
…olation.

Removed Deref and DerefMut implementations for StreamWrapper.
  • Loading branch information
icmccorm committed Jan 25, 2024
1 parent 8ef8ae6 commit 8386651
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 89 deletions.
173 changes: 94 additions & 79 deletions src/ffi/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::cmp;
use std::convert::TryFrom;
use std::fmt;
use std::marker;
use std::ops::{Deref, DerefMut};
use std::os::raw::{c_int, c_uint, c_void};
use std::ptr;

Expand All @@ -21,7 +20,10 @@ impl ErrorMessage {
}

pub struct StreamWrapper {
pub inner: Box<mz_stream>,
// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure, and it must never be copied
// by Rust.
pub inner: *mut mz_stream,
}

impl fmt::Debug for StreamWrapper {
Expand All @@ -32,8 +34,12 @@ impl fmt::Debug for StreamWrapper {

impl Default for StreamWrapper {
fn default() -> StreamWrapper {
// SAFETY: The field `state` will be initialized across the FFI to
// point to the opaque type `mz_internal_state`, which will contain a copy
// of `inner`. This cyclic structure breaks the uniqueness invariant of
// &mut mz_stream, so we must use a raw pointer instead of Box<mz_stream>.
StreamWrapper {
inner: Box::new(mz_stream {
inner: Box::into_raw(Box::new(mz_stream {
next_in: ptr::null_mut(),
avail_in: 0,
total_in: 0,
Expand All @@ -54,11 +60,21 @@ impl Default for StreamWrapper {
zalloc: Some(zalloc),
#[cfg(not(all(feature = "any_zlib", not(feature = "cloudflare-zlib-sys"))))]
zfree: Some(zfree),
}),
})),
}
}
}

impl Drop for StreamWrapper {
fn drop(&mut self) {
// SAFETY: At this point, every other allocation for struct has been freed by
// `inflateEnd` or `deflateEnd`, and no copies of `inner` are retained by `C`,
// so it is safe to drop the struct as long as the user respects the invariant that
// `inner` must never be copied by Rust.
drop(unsafe { Box::from_raw(self.inner) });
}
}

const ALIGN: usize = std::mem::align_of::<usize>();

fn align_up(size: usize, align: usize) -> usize {
Expand Down Expand Up @@ -110,20 +126,6 @@ extern "C" fn zfree(_ptr: *mut c_void, address: *mut c_void) {
}
}

impl Deref for StreamWrapper {
type Target = mz_stream;

fn deref(&self) -> &Self::Target {
&*self.inner
}
}

impl DerefMut for StreamWrapper {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut *self.inner
}
}

unsafe impl<D: Direction> Send for Stream<D> {}
unsafe impl<D: Direction> Sync for Stream<D> {}

Expand All @@ -148,7 +150,10 @@ pub struct Stream<D: Direction> {

impl<D: Direction> Stream<D> {
pub fn msg(&self) -> ErrorMessage {
let msg = self.stream_wrapper.msg;
// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure. No copies of `inner` can be
// retained for longer than the lifetime of `self`.
let msg = unsafe { (*self.stream_wrapper.inner).msg };
ErrorMessage(if msg.is_null() {
None
} else {
Expand All @@ -161,7 +166,7 @@ impl<D: Direction> Stream<D> {
impl<D: Direction> Drop for Stream<D> {
fn drop(&mut self) {
unsafe {
let _ = D::destroy(&mut *self.stream_wrapper);
let _ = D::destroy(self.stream_wrapper.inner);
}
}
}
Expand All @@ -185,9 +190,9 @@ pub struct Inflate {
impl InflateBackend for Inflate {
fn make(zlib_header: bool, window_bits: u8) -> Self {
unsafe {
let mut state = StreamWrapper::default();
let state = StreamWrapper::default();
let ret = mz_inflateInit2(
&mut *state,
state.inner,
if zlib_header {
window_bits as c_int
} else {
Expand All @@ -212,33 +217,38 @@ impl InflateBackend for Inflate {
output: &mut [u8],
flush: FlushDecompress,
) -> Result<Status, DecompressError> {
let raw = &mut *self.inner.stream_wrapper;
raw.msg = ptr::null_mut();
raw.next_in = input.as_ptr() as *mut u8;
raw.avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
raw.next_out = output.as_mut_ptr();
raw.avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;

let rc = unsafe { mz_inflate(raw, flush as c_int) };

// Unfortunately the total counters provided by zlib might be only
// 32 bits wide and overflow while processing large amounts of data.
self.inner.total_in += (raw.next_in as usize - input.as_ptr() as usize) as u64;
self.inner.total_out += (raw.next_out as usize - output.as_ptr() as usize) as u64;

// reset these pointers so we don't accidentally read them later
raw.next_in = ptr::null_mut();
raw.avail_in = 0;
raw.next_out = ptr::null_mut();
raw.avail_out = 0;

match rc {
MZ_DATA_ERROR | MZ_STREAM_ERROR => mem::decompress_failed(self.inner.msg()),
MZ_OK => Ok(Status::Ok),
MZ_BUF_ERROR => Ok(Status::BufError),
MZ_STREAM_END => Ok(Status::StreamEnd),
MZ_NEED_DICT => mem::decompress_need_dict(raw.adler as u32),
c => panic!("unknown return code: {}", c),
let raw = self.inner.stream_wrapper.inner;
// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure. No copies of `inner` can be
// retained for longer than the lifetime of `self`.
unsafe {
(*raw).msg = ptr::null_mut();
(*raw).next_in = input.as_ptr() as *mut u8;
(*raw).avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
(*raw).next_out = output.as_mut_ptr();
(*raw).avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;

let rc = mz_inflate(raw, flush as c_int);

// Unfortunately the total counters provided by zlib might be only
// 32 bits wide and overflow while processing large amounts of data.
self.inner.total_in += ((*raw).next_in as usize - input.as_ptr() as usize) as u64;
self.inner.total_out += ((*raw).next_out as usize - output.as_ptr() as usize) as u64;

// reset these pointers so we don't accidentally read them later
(*raw).next_in = ptr::null_mut();
(*raw).avail_in = 0;
(*raw).next_out = ptr::null_mut();
(*raw).avail_out = 0;

match rc {
MZ_DATA_ERROR | MZ_STREAM_ERROR => mem::decompress_failed(self.inner.msg()),
MZ_OK => Ok(Status::Ok),
MZ_BUF_ERROR => Ok(Status::BufError),
MZ_STREAM_END => Ok(Status::StreamEnd),
MZ_NEED_DICT => mem::decompress_need_dict((*raw).adler as u32),
c => panic!("unknown return code: {}", c),
}
}
}

Expand All @@ -249,7 +259,7 @@ impl InflateBackend for Inflate {
-MZ_DEFAULT_WINDOW_BITS
};
unsafe {
inflateReset2(&mut *self.inner.stream_wrapper, bits);
inflateReset2(self.inner.stream_wrapper.inner, bits);
}
self.inner.total_out = 0;
self.inner.total_in = 0;
Expand All @@ -276,9 +286,9 @@ pub struct Deflate {
impl DeflateBackend for Deflate {
fn make(level: Compression, zlib_header: bool, window_bits: u8) -> Self {
unsafe {
let mut state = StreamWrapper::default();
let state = StreamWrapper::default();
let ret = mz_deflateInit2(
&mut *state,
state.inner,
level.0 as c_int,
MZ_DEFLATED,
if zlib_header {
Expand Down Expand Up @@ -306,39 +316,44 @@ impl DeflateBackend for Deflate {
output: &mut [u8],
flush: FlushCompress,
) -> Result<Status, CompressError> {
let raw = &mut *self.inner.stream_wrapper;
raw.msg = ptr::null_mut();
raw.next_in = input.as_ptr() as *mut _;
raw.avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
raw.next_out = output.as_mut_ptr();
raw.avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;

let rc = unsafe { mz_deflate(raw, flush as c_int) };

// Unfortunately the total counters provided by zlib might be only
// 32 bits wide and overflow while processing large amounts of data.
self.inner.total_in += (raw.next_in as usize - input.as_ptr() as usize) as u64;
self.inner.total_out += (raw.next_out as usize - output.as_ptr() as usize) as u64;

// reset these pointers so we don't accidentally read them later
raw.next_in = ptr::null_mut();
raw.avail_in = 0;
raw.next_out = ptr::null_mut();
raw.avail_out = 0;

match rc {
MZ_OK => Ok(Status::Ok),
MZ_BUF_ERROR => Ok(Status::BufError),
MZ_STREAM_END => Ok(Status::StreamEnd),
MZ_STREAM_ERROR => mem::compress_failed(self.inner.msg()),
c => panic!("unknown return code: {}", c),
let raw = self.inner.stream_wrapper.inner;
// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure. No copies of `inner` can be
// retained for longer than the lifetime of `self`.
unsafe {
(*raw).msg = ptr::null_mut();
(*raw).next_in = input.as_ptr() as *mut _;
(*raw).avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
(*raw).next_out = output.as_mut_ptr();
(*raw).avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;

let rc = mz_deflate(raw, flush as c_int);

// Unfortunately the total counters provided by zlib might be only
// 32 bits wide and overflow while processing large amounts of data.

self.inner.total_in += ((*raw).next_in as usize - input.as_ptr() as usize) as u64;
self.inner.total_out += ((*raw).next_out as usize - output.as_ptr() as usize) as u64;
// reset these pointers so we don't accidentally read them later
(*raw).next_in = ptr::null_mut();
(*raw).avail_in = 0;
(*raw).next_out = ptr::null_mut();
(*raw).avail_out = 0;

match rc {
MZ_OK => Ok(Status::Ok),
MZ_BUF_ERROR => Ok(Status::BufError),
MZ_STREAM_END => Ok(Status::StreamEnd),
MZ_STREAM_ERROR => mem::compress_failed(self.inner.msg()),
c => panic!("unknown return code: {}", c),
}
}
}

fn reset(&mut self) {
self.inner.total_in = 0;
self.inner.total_out = 0;
let rc = unsafe { mz_deflateReset(&mut *self.inner.stream_wrapper) };
let rc = unsafe { mz_deflateReset(self.inner.stream_wrapper.inner) };
assert_eq!(rc, MZ_OK);
}
}
Expand Down
30 changes: 20 additions & 10 deletions src/mem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,19 @@ impl Compress {
/// Returns the Adler-32 checksum of the dictionary.
#[cfg(feature = "any_zlib")]
pub fn set_dictionary(&mut self, dictionary: &[u8]) -> Result<u32, CompressError> {
let stream = &mut *self.inner.inner.stream_wrapper;
stream.msg = std::ptr::null_mut();
// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure. No copies of `inner` can be
// retained for longer than the lifetime of `self.inner.inner.stream_wrapper`.
let stream = self.inner.inner.stream_wrapper.inner;
let rc = unsafe {
(*stream).msg = std::ptr::null_mut();
assert!(dictionary.len() < ffi::uInt::MAX as usize);
ffi::deflateSetDictionary(stream, dictionary.as_ptr(), dictionary.len() as ffi::uInt)
};

match rc {
ffi::MZ_STREAM_ERROR => compress_failed(self.inner.inner.msg()),
ffi::MZ_OK => Ok(stream.adler as u32),
ffi::MZ_OK => Ok(unsafe { (*stream).adler } as u32),
c => panic!("unknown return code: {}", c),
}
}
Expand All @@ -299,9 +302,13 @@ impl Compress {
#[cfg(feature = "any_zlib")]
pub fn set_level(&mut self, level: Compression) -> Result<(), CompressError> {
use std::os::raw::c_int;
let stream = &mut *self.inner.inner.stream_wrapper;
stream.msg = std::ptr::null_mut();

// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure. No copies of `inner` can be
// retained for longer than the lifetime of `self.inner.inner.stream_wrapper`.
let stream = self.inner.inner.stream_wrapper.inner;
unsafe {
(*stream).msg = std::ptr::null_mut();
}
let rc = unsafe { ffi::deflateParams(stream, level.0 as c_int, ffi::MZ_DEFAULT_STRATEGY) };

match rc {
Expand Down Expand Up @@ -476,17 +483,20 @@ impl Decompress {
/// Specifies the decompression dictionary to use.
#[cfg(feature = "any_zlib")]
pub fn set_dictionary(&mut self, dictionary: &[u8]) -> Result<u32, DecompressError> {
let stream = &mut *self.inner.inner.stream_wrapper;
stream.msg = std::ptr::null_mut();
// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure. No copies of `inner` can be
// retained for longer than the lifetime of `self.inner.inner.stream_wrapper`.
let stream = self.inner.inner.stream_wrapper.inner;
let rc = unsafe {
(*stream).msg = std::ptr::null_mut();
assert!(dictionary.len() < ffi::uInt::MAX as usize);
ffi::inflateSetDictionary(stream, dictionary.as_ptr(), dictionary.len() as ffi::uInt)
};

match rc {
ffi::MZ_STREAM_ERROR => decompress_failed(self.inner.inner.msg()),
ffi::MZ_DATA_ERROR => decompress_need_dict(stream.adler as u32),
ffi::MZ_OK => Ok(stream.adler as u32),
ffi::MZ_DATA_ERROR => decompress_need_dict(unsafe { (*stream).adler } as u32),
ffi::MZ_OK => Ok(unsafe { (*stream).adler } as u32),
c => panic!("unknown return code: {}", c),
}
}
Expand Down

0 comments on commit 8386651

Please sign in to comment.