Skip to content

Commit

Permalink
Merge pull request #363 from fintelia/fast-filter
Browse files Browse the repository at this point in the history
Optimize Sub and Paeth filtering
  • Loading branch information
fintelia authored Dec 23, 2022
2 parents 097dda2 + ed863de commit b76f988
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 42 deletions.
46 changes: 42 additions & 4 deletions src/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,8 +676,14 @@ impl<W: Write> Writer<W> {
let adaptive_method = self.options.adaptive_filter;

for line in data.chunks(in_len) {
current.copy_from_slice(line);
let filter_type = filter(filter_method, adaptive_method, bpp, prev, &mut current);
let filter_type = filter(
filter_method,
adaptive_method,
bpp,
prev,
&line,
&mut current,
);
zlib.write_all(&[filter_type as u8])?;
zlib.write_all(&current)?;
prev = line;
Expand Down Expand Up @@ -1561,12 +1567,15 @@ impl<'a, W: Write> Write for StreamWriter<'a, W> {
self.to_write -= written;

if self.index == self.line_len {
// TODO: reuse this buffer between rows.
let mut filtered = vec![0; self.curr_buf.len()];
let filter_type = filter(
self.filter,
self.adaptive_filter,
self.bpp,
&self.prev_buf,
&mut self.curr_buf,
&self.curr_buf,
&mut filtered,
);
// This can't fail as the other variant is used only to allow the zlib encoder to finish
let wrt = match &mut self.writer {
Expand All @@ -1575,7 +1584,7 @@ impl<'a, W: Write> Write for StreamWriter<'a, W> {
};

wrt.write_all(&[filter_type as u8])?;
wrt.write_all(&self.curr_buf)?;
wrt.write_all(&filtered)?;
mem::swap(&mut self.prev_buf, &mut self.curr_buf);
self.index = 0;
}
Expand Down Expand Up @@ -2195,6 +2204,35 @@ mod tests {
Ok(())
}

#[test]
fn stream_filtering() -> Result<()> {
let output = vec![0u8; 1024];
let mut cursor = Cursor::new(output);

let mut encoder = Encoder::new(&mut cursor, 8, 8);
encoder.set_color(ColorType::Rgba);
encoder.set_filter(FilterType::Paeth);
let mut writer = encoder.write_header()?;
let mut stream = writer.stream_writer()?;

for _ in 0..8 {
let written = stream.write(&[1; 32])?;
assert_eq!(written, 32);
}
stream.finish()?;
drop(writer);

{
cursor.set_position(0);
let mut decoder = Decoder::new(cursor).read_info().expect("A valid image");
let mut buffer = [0u8; 256];
decoder.next_frame(&mut buffer[..]).expect("Valid read");
assert_eq!(buffer, [1; 256]);
}

Ok(())
}

#[test]
#[cfg(all(unix, not(target_pointer_width = "32")))]
fn exper_error_on_huge_chunk() -> Result<()> {
Expand Down
147 changes: 109 additions & 38 deletions src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,36 @@ impl Default for AdaptiveFilterType {
}

fn filter_paeth(a: u8, b: u8, c: u8) -> u8 {
let ia = i16::from(a);
let ib = i16::from(b);
let ic = i16::from(c);

let p = ia + ib - ic;

let pa = (p - ia).abs();
let pb = (p - ib).abs();
let pc = (p - ic).abs();
// This is an optimized version of the paeth filter from the PNG specification, proposed by
// Luca Versari for [FPNGE](https://www.lucaversari.it/FJXL_and_FPNGE.pdf). It operates
// entirely on unsigned 8-bit quantities, making it more conducive to vectorization.
//
// p = a + b - c
// pa = |p - a| = |a + b - c - a| = |b - c| = max(b, c) - min(b, c)
// pb = |p - b| = |a + b - c - b| = |a - c| = max(a, c) - min(a, c)
// pc = |p - c| = |a + b - c - c| = |(b - c) + (a - c)| = ...
//
// Further optimizing the calculation of `pc` a bit tricker. However, notice that:
//
// a > c && b > c
// ==> (a - c) > 0 && (b - c) > 0
// ==> pc > (a - c) && pc > (b - c)
// ==> pc > |a - c| && pc > |b - c|
// ==> pc > pb && pc > pa
//
// Meaning that if `c` is smaller than `a` and `b`, the value of `pc` is irrelevant. Similar
// reasoning applies if `c` is larger than the other two inputs. Assuming that `c >= b` and
// `c <= b` or vice versa:
//
// pc = ||b - c| - |a - c|| = |pa - pb| = max(pa, pb) - min(pa, pb)
//
let pa = b.max(c) - c.min(b);
let pb = a.max(c) - c.min(a);
let pc = if (a < c) == (c < b) {
pa.max(pb) - pa.min(pb)
} else {
255
};

if pa <= pb && pa <= pc {
a
Expand Down Expand Up @@ -209,47 +230,93 @@ fn filter_internal(
bpp: usize,
len: usize,
previous: &[u8],
current: &mut [u8],
current: &[u8],
output: &mut [u8],
) -> FilterType {
use self::FilterType::*;

// This value was chosen experimentally based on what acheived the best performance. The
// Rust compiler does auto-vectorization, and 32-bytes per loop iteration seems to enable
// the fastest code when doing so.
const CHUNK_SIZE: usize = 32;

match method {
NoFilter => NoFilter,
NoFilter => {
output.copy_from_slice(current);
NoFilter
}
Sub => {
for i in (bpp..len).rev() {
current[i] = current[i].wrapping_sub(current[i - bpp]);
let mut out_chunks = output[bpp..].chunks_exact_mut(CHUNK_SIZE);
let mut cur_chunks = current[bpp..].chunks_exact(CHUNK_SIZE);
let mut prev_chunks = current[..len - bpp].chunks_exact(CHUNK_SIZE);

for ((out, cur), prev) in (&mut out_chunks).zip(&mut cur_chunks).zip(&mut prev_chunks) {
for i in 0..CHUNK_SIZE {
out[i] = cur[i].wrapping_sub(prev[i]);
}
}

for ((out, cur), &prev) in out_chunks
.into_remainder()
.into_iter()
.zip(cur_chunks.remainder())
.zip(prev_chunks.remainder())
{
*out = cur.wrapping_sub(prev);
}

output[..bpp].copy_from_slice(&current[..bpp]);
Sub
}
Up => {
for i in 0..len {
current[i] = current[i].wrapping_sub(previous[i]);
output[i] = current[i].wrapping_sub(previous[i]);
}
Up
}
Avg => {
for i in (bpp..len).rev() {
current[i] = current[i].wrapping_sub(
output[i] = current[i].wrapping_sub(
((u16::from(current[i - bpp]) + u16::from(previous[i])) / 2) as u8,
);
}

for i in 0..bpp {
current[i] = current[i].wrapping_sub(previous[i] / 2);
output[i] = current[i].wrapping_sub(previous[i] / 2);
}
Avg
}
Paeth => {
for i in (bpp..len).rev() {
current[i] = current[i].wrapping_sub(filter_paeth(
current[i - bpp],
previous[i],
previous[i - bpp],
));
let mut out_chunks = output[bpp..].chunks_exact_mut(CHUNK_SIZE);
let mut cur_chunks = current[bpp..].chunks_exact(CHUNK_SIZE);
let mut a_chunks = current[..len - bpp].chunks_exact(CHUNK_SIZE);
let mut b_chunks = previous[bpp..].chunks_exact(CHUNK_SIZE);
let mut c_chunks = previous[..len - bpp].chunks_exact(CHUNK_SIZE);

for ((((out, cur), a), b), c) in (&mut out_chunks)
.zip(&mut cur_chunks)
.zip(&mut a_chunks)
.zip(&mut b_chunks)
.zip(&mut c_chunks)
{
for i in 0..CHUNK_SIZE {
out[i] = cur[i].wrapping_sub(filter_paeth(a[i], b[i], c[i]));
}
}

for ((((out, cur), &a), &b), &c) in out_chunks
.into_remainder()
.into_iter()
.zip(cur_chunks.remainder())
.zip(a_chunks.remainder())
.zip(b_chunks.remainder())
.zip(c_chunks.remainder())
{
*out = cur.wrapping_sub(filter_paeth(a, b, c));
}

for i in 0..bpp {
current[i] = current[i].wrapping_sub(filter_paeth(0, previous[i], 0));
output[i] = current[i].wrapping_sub(filter_paeth(0, previous[i], 0));
}
Paeth
}
Expand All @@ -261,14 +328,17 @@ pub(crate) fn filter(
adaptive: AdaptiveFilterType,
bpp: BytesPerPixel,
previous: &[u8],
current: &mut [u8],
current: &[u8],
output: &mut [u8],
) -> FilterType {
use FilterType::*;
let bpp = bpp.into_usize();
let len = current.len();

match adaptive {
AdaptiveFilterType::NonAdaptive => filter_internal(method, bpp, len, previous, current),
AdaptiveFilterType::NonAdaptive => {
filter_internal(method, bpp, len, previous, current, output)
}
AdaptiveFilterType::Adaptive => {
// Filter the current buffer with each filter type. Sum the absolute
// values of each filtered buffer treating the bytes as signed
Expand All @@ -282,8 +352,7 @@ pub(crate) fn filter(
let mut filter_choice = FilterType::NoFilter;

for &filter in [Sub, Up, Avg, Paeth].iter() {
scratch.copy_from_slice(current);
filter_internal(filter, bpp, len, previous, &mut scratch);
filter_internal(filter, bpp, len, previous, current, &mut scratch);
let sum = sum_buffer(&scratch);
if sum < min_sum {
min_sum = sum;
Expand All @@ -292,7 +361,7 @@ pub(crate) fn filter(
}
}

current.copy_from_slice(&filtered_buffer);
output.copy_from_slice(&filtered_buffer);

filter_choice
}
Expand All @@ -316,15 +385,16 @@ mod test {
// A multiple of 8, 6, 4, 3, 2, 1
const LEN: u8 = 240;
let previous: Vec<_> = iter::repeat(1).take(LEN.into()).collect();
let mut current: Vec<_> = (0..LEN).collect();
let current: Vec<_> = (0..LEN).collect();
let expected = current.clone();
let adaptive = AdaptiveFilterType::NonAdaptive;

let mut roundtrip = |kind, bpp: BytesPerPixel| {
filter(kind, adaptive, bpp, &previous, &mut current);
unfilter(kind, bpp, &previous, &mut current).expect("Unfilter worked");
let roundtrip = |kind, bpp: BytesPerPixel| {
let mut output = vec![0; LEN.into()];
filter(kind, adaptive, bpp, &previous, &current, &mut output);
unfilter(kind, bpp, &previous, &mut output).expect("Unfilter worked");
assert_eq!(
current, expected,
output, expected,
"Filtering {:?} with {:?} does not roundtrip",
bpp, kind
);
Expand Down Expand Up @@ -359,15 +429,16 @@ mod test {
// A multiple of 8, 6, 4, 3, 2, 1
const LEN: u8 = 240;
let previous: Vec<_> = (0..LEN).collect();
let mut current: Vec<_> = (0..LEN).collect();
let current: Vec<_> = (0..LEN).collect();
let expected = current.clone();
let adaptive = AdaptiveFilterType::NonAdaptive;

let mut roundtrip = |kind, bpp: BytesPerPixel| {
filter(kind, adaptive, bpp, &previous, &mut current);
unfilter(kind, bpp, &previous, &mut current).expect("Unfilter worked");
let roundtrip = |kind, bpp: BytesPerPixel| {
let mut output = vec![0; LEN.into()];
filter(kind, adaptive, bpp, &previous, &current, &mut output);
unfilter(kind, bpp, &previous, &mut output).expect("Unfilter worked");
assert_eq!(
current, expected,
output, expected,
"Filtering {:?} with {:?} does not roundtrip",
bpp, kind
);
Expand Down

0 comments on commit b76f988

Please sign in to comment.