diff --git a/src/encoder.rs b/src/encoder.rs index 6d75ea03..1e67eccc 100644 --- a/src/encoder.rs +++ b/src/encoder.rs @@ -676,8 +676,14 @@ impl Writer { let adaptive_method = self.options.adaptive_filter; for line in data.chunks(in_len) { - current.copy_from_slice(line); - let filter_type = filter(filter_method, adaptive_method, bpp, prev, &mut current); + let filter_type = filter( + filter_method, + adaptive_method, + bpp, + prev, + &line, + &mut current, + ); zlib.write_all(&[filter_type as u8])?; zlib.write_all(&current)?; prev = line; @@ -1561,12 +1567,15 @@ impl<'a, W: Write> Write for StreamWriter<'a, W> { self.to_write -= written; if self.index == self.line_len { + // TODO: reuse this buffer between rows. + let mut filtered = vec![0; self.curr_buf.len()]; let filter_type = filter( self.filter, self.adaptive_filter, self.bpp, &self.prev_buf, - &mut self.curr_buf, + &self.curr_buf, + &mut filtered, ); // This can't fail as the other variant is used only to allow the zlib encoder to finish let wrt = match &mut self.writer { @@ -1575,7 +1584,7 @@ impl<'a, W: Write> Write for StreamWriter<'a, W> { }; wrt.write_all(&[filter_type as u8])?; - wrt.write_all(&self.curr_buf)?; + wrt.write_all(&filtered)?; mem::swap(&mut self.prev_buf, &mut self.curr_buf); self.index = 0; } @@ -2195,6 +2204,35 @@ mod tests { Ok(()) } + #[test] + fn stream_filtering() -> Result<()> { + let output = vec![0u8; 1024]; + let mut cursor = Cursor::new(output); + + let mut encoder = Encoder::new(&mut cursor, 8, 8); + encoder.set_color(ColorType::Rgba); + encoder.set_filter(FilterType::Paeth); + let mut writer = encoder.write_header()?; + let mut stream = writer.stream_writer()?; + + for _ in 0..8 { + let written = stream.write(&[1; 32])?; + assert_eq!(written, 32); + } + stream.finish()?; + drop(writer); + + { + cursor.set_position(0); + let mut decoder = Decoder::new(cursor).read_info().expect("A valid image"); + let mut buffer = [0u8; 256]; + decoder.next_frame(&mut buffer[..]).expect("Valid 
read"); + assert_eq!(buffer, [1; 256]); + } + + Ok(()) + } + #[test] #[cfg(all(unix, not(target_pointer_width = "32")))] fn exper_error_on_huge_chunk() -> Result<()> { diff --git a/src/filter.rs b/src/filter.rs index 339fa18b..e2c40b7c 100644 --- a/src/filter.rs +++ b/src/filter.rs @@ -55,15 +55,36 @@ impl Default for AdaptiveFilterType { } fn filter_paeth(a: u8, b: u8, c: u8) -> u8 { - let ia = i16::from(a); - let ib = i16::from(b); - let ic = i16::from(c); - - let p = ia + ib - ic; - - let pa = (p - ia).abs(); - let pb = (p - ib).abs(); - let pc = (p - ic).abs(); + // This is an optimized version of the paeth filter from the PNG specification, proposed by + // Luca Versari for [FPNGE](https://www.lucaversari.it/FJXL_and_FPNGE.pdf). It operates + // entirely on unsigned 8-bit quantities, making it more conducive to vectorization. + // + // p = a + b - c + // pa = |p - a| = |a + b - c - a| = |b - c| = max(b, c) - min(b, c) + // pb = |p - b| = |a + b - c - b| = |a - c| = max(a, c) - min(a, c) + // pc = |p - c| = |a + b - c - c| = |(b - c) + (a - c)| = ... + // + // Further optimizing the calculation of `pc` is a bit trickier. However, notice that: + // + // a > c && b > c + // ==> (a - c) > 0 && (b - c) > 0 + // ==> pc > (a - c) && pc > (b - c) + // ==> pc > |a - c| && pc > |b - c| + // ==> pc > pb && pc > pa + // + // Meaning that if `c` is smaller than `a` and `b`, the value of `pc` is irrelevant. Similar + // reasoning applies if `c` is larger than the other two inputs. 
Assuming that `c >= a` and + // `c <= b` or vice versa: + // + // pc = ||b - c| - |a - c|| = |pa - pb| = max(pa, pb) - min(pa, pb) + // + let pa = b.max(c) - c.min(b); + let pb = a.max(c) - c.min(a); + let pc = if (a < c) == (c < b) { + pa.max(pb) - pa.min(pb) + } else { + 255 + }; if pa <= pb && pa <= pc { a @@ -209,47 +230,93 @@ fn filter_internal( bpp: usize, len: usize, previous: &[u8], - current: &mut [u8], + current: &[u8], + output: &mut [u8], ) -> FilterType { use self::FilterType::*; + // This value was chosen experimentally based on what achieved the best performance. The + // Rust compiler does auto-vectorization, and 32-bytes per loop iteration seems to enable + // the fastest code when doing so. + const CHUNK_SIZE: usize = 32; + match method { - NoFilter => NoFilter, + NoFilter => { + output.copy_from_slice(current); + NoFilter + } Sub => { - for i in (bpp..len).rev() { - current[i] = current[i].wrapping_sub(current[i - bpp]); + let mut out_chunks = output[bpp..].chunks_exact_mut(CHUNK_SIZE); + let mut cur_chunks = current[bpp..].chunks_exact(CHUNK_SIZE); + let mut prev_chunks = current[..len - bpp].chunks_exact(CHUNK_SIZE); + + for ((out, cur), prev) in (&mut out_chunks).zip(&mut cur_chunks).zip(&mut prev_chunks) { + for i in 0..CHUNK_SIZE { + out[i] = cur[i].wrapping_sub(prev[i]); + } } + + for ((out, cur), &prev) in out_chunks + .into_remainder() + .into_iter() + .zip(cur_chunks.remainder()) + .zip(prev_chunks.remainder()) + { + *out = cur.wrapping_sub(prev); + } + + output[..bpp].copy_from_slice(&current[..bpp]); Sub } Up => { for i in 0..len { - current[i] = current[i].wrapping_sub(previous[i]); + output[i] = current[i].wrapping_sub(previous[i]); } Up } Avg => { for i in (bpp..len).rev() { - current[i] = current[i].wrapping_sub( + output[i] = current[i].wrapping_sub( ((u16::from(current[i - bpp]) + u16::from(previous[i])) / 2) as u8, ); } for i in 0..bpp { - current[i] = current[i].wrapping_sub(previous[i] / 2); + output[i] = 
current[i].wrapping_sub(previous[i] / 2); } Avg } Paeth => { - for i in (bpp..len).rev() { - current[i] = current[i].wrapping_sub(filter_paeth( - current[i - bpp], - previous[i], - previous[i - bpp], - )); + let mut out_chunks = output[bpp..].chunks_exact_mut(CHUNK_SIZE); + let mut cur_chunks = current[bpp..].chunks_exact(CHUNK_SIZE); + let mut a_chunks = current[..len - bpp].chunks_exact(CHUNK_SIZE); + let mut b_chunks = previous[bpp..].chunks_exact(CHUNK_SIZE); + let mut c_chunks = previous[..len - bpp].chunks_exact(CHUNK_SIZE); + + for ((((out, cur), a), b), c) in (&mut out_chunks) + .zip(&mut cur_chunks) + .zip(&mut a_chunks) + .zip(&mut b_chunks) + .zip(&mut c_chunks) + { + for i in 0..CHUNK_SIZE { + out[i] = cur[i].wrapping_sub(filter_paeth(a[i], b[i], c[i])); + } + } + + for ((((out, cur), &a), &b), &c) in out_chunks + .into_remainder() + .into_iter() + .zip(cur_chunks.remainder()) + .zip(a_chunks.remainder()) + .zip(b_chunks.remainder()) + .zip(c_chunks.remainder()) + { + *out = cur.wrapping_sub(filter_paeth(a, b, c)); } for i in 0..bpp { - current[i] = current[i].wrapping_sub(filter_paeth(0, previous[i], 0)); + output[i] = current[i].wrapping_sub(filter_paeth(0, previous[i], 0)); } Paeth } @@ -261,14 +328,17 @@ pub(crate) fn filter( adaptive: AdaptiveFilterType, bpp: BytesPerPixel, previous: &[u8], - current: &mut [u8], + current: &[u8], + output: &mut [u8], ) -> FilterType { use FilterType::*; let bpp = bpp.into_usize(); let len = current.len(); match adaptive { - AdaptiveFilterType::NonAdaptive => filter_internal(method, bpp, len, previous, current), + AdaptiveFilterType::NonAdaptive => { + filter_internal(method, bpp, len, previous, current, output) + } AdaptiveFilterType::Adaptive => { // Filter the current buffer with each filter type. 
Sum the absolute // values of each filtered buffer treating the bytes as signed @@ -282,8 +352,7 @@ pub(crate) fn filter( let mut filter_choice = FilterType::NoFilter; for &filter in [Sub, Up, Avg, Paeth].iter() { - scratch.copy_from_slice(current); - filter_internal(filter, bpp, len, previous, &mut scratch); + filter_internal(filter, bpp, len, previous, current, &mut scratch); let sum = sum_buffer(&scratch); if sum < min_sum { min_sum = sum; @@ -292,7 +361,7 @@ pub(crate) fn filter( } } - current.copy_from_slice(&filtered_buffer); + output.copy_from_slice(&filtered_buffer); filter_choice } @@ -316,15 +385,16 @@ mod test { // A multiple of 8, 6, 4, 3, 2, 1 const LEN: u8 = 240; let previous: Vec<_> = iter::repeat(1).take(LEN.into()).collect(); - let mut current: Vec<_> = (0..LEN).collect(); + let current: Vec<_> = (0..LEN).collect(); let expected = current.clone(); let adaptive = AdaptiveFilterType::NonAdaptive; - let mut roundtrip = |kind, bpp: BytesPerPixel| { - filter(kind, adaptive, bpp, &previous, &mut current); - unfilter(kind, bpp, &previous, &mut current).expect("Unfilter worked"); + let roundtrip = |kind, bpp: BytesPerPixel| { + let mut output = vec![0; LEN.into()]; + filter(kind, adaptive, bpp, &previous, &current, &mut output); + unfilter(kind, bpp, &previous, &mut output).expect("Unfilter worked"); assert_eq!( - current, expected, + output, expected, "Filtering {:?} with {:?} does not roundtrip", bpp, kind ); @@ -359,15 +429,16 @@ mod test { // A multiple of 8, 6, 4, 3, 2, 1 const LEN: u8 = 240; let previous: Vec<_> = (0..LEN).collect(); - let mut current: Vec<_> = (0..LEN).collect(); + let current: Vec<_> = (0..LEN).collect(); let expected = current.clone(); let adaptive = AdaptiveFilterType::NonAdaptive; - let mut roundtrip = |kind, bpp: BytesPerPixel| { - filter(kind, adaptive, bpp, &previous, &mut current); - unfilter(kind, bpp, &previous, &mut current).expect("Unfilter worked"); + let roundtrip = |kind, bpp: BytesPerPixel| { + let mut output = vec![0; 
LEN.into()]; + filter(kind, adaptive, bpp, &previous, &current, &mut output); + unfilter(kind, bpp, &previous, &mut output).expect("Unfilter worked"); assert_eq!( - current, expected, + output, expected, "Filtering {:?} with {:?} does not roundtrip", bpp, kind );