Optimize Sub and Paeth filtering #363

Merged · 3 commits · Dec 23, 2022
46 changes: 42 additions & 4 deletions src/encoder.rs
@@ -676,8 +676,14 @@ impl<W: Write> Writer<W> {
let adaptive_method = self.options.adaptive_filter;

for line in data.chunks(in_len) {
current.copy_from_slice(line);
let filter_type = filter(filter_method, adaptive_method, bpp, prev, &mut current);
let filter_type = filter(
filter_method,
adaptive_method,
bpp,
prev,
&line,
&mut current,
);
zlib.write_all(&[filter_type as u8])?;
zlib.write_all(&current)?;
prev = line;
@@ -1561,12 +1567,15 @@ impl<'a, W: Write> Write for StreamWriter<'a, W> {
self.to_write -= written;

if self.index == self.line_len {
// TODO: reuse this buffer between rows.
let mut filtered = vec![0; self.curr_buf.len()];
let filter_type = filter(
self.filter,
self.adaptive_filter,
self.bpp,
&self.prev_buf,
&mut self.curr_buf,
&self.curr_buf,
&mut filtered,
);
// This can't fail as the other variant is used only to allow the zlib encoder to finish
let wrt = match &mut self.writer {
@@ -1575,7 +1584,7 @@ impl<'a, W: Write> Write for StreamWriter<'a, W> {
};

wrt.write_all(&[filter_type as u8])?;
wrt.write_all(&self.curr_buf)?;
wrt.write_all(&filtered)?;
mem::swap(&mut self.prev_buf, &mut self.curr_buf);
self.index = 0;
}
@@ -2195,6 +2204,35 @@ mod tests {
Ok(())
}

#[test]
fn stream_filtering() -> Result<()> {
let output = vec![0u8; 1024];
let mut cursor = Cursor::new(output);

let mut encoder = Encoder::new(&mut cursor, 8, 8);
encoder.set_color(ColorType::Rgba);
encoder.set_filter(FilterType::Paeth);
let mut writer = encoder.write_header()?;
let mut stream = writer.stream_writer()?;

for _ in 0..8 {
let written = stream.write(&[1; 32])?;
assert_eq!(written, 32);
}
stream.finish()?;
drop(writer);

{
cursor.set_position(0);
let mut decoder = Decoder::new(cursor).read_info().expect("A valid image");
let mut buffer = [0u8; 256];
decoder.next_frame(&mut buffer[..]).expect("Valid read");
assert_eq!(buffer, [1; 256]);
}

Ok(())
}

#[test]
#[cfg(all(unix, not(target_pointer_width = "32")))]
fn exper_error_on_huge_chunk() -> Result<()> {
147 changes: 109 additions & 38 deletions src/filter.rs
@@ -55,15 +55,36 @@ impl Default for AdaptiveFilterType {
}

fn filter_paeth(a: u8, b: u8, c: u8) -> u8 {
let ia = i16::from(a);
let ib = i16::from(b);
let ic = i16::from(c);

let p = ia + ib - ic;

let pa = (p - ia).abs();
let pb = (p - ib).abs();
let pc = (p - ic).abs();
// This is an optimized version of the paeth filter from the PNG specification, proposed by
// Luca Versari for [FPNGE](https://www.lucaversari.it/FJXL_and_FPNGE.pdf). It operates
// entirely on unsigned 8-bit quantities, making it more conducive to vectorization.
//
// p = a + b - c
// pa = |p - a| = |a + b - c - a| = |b - c| = max(b, c) - min(b, c)
// pb = |p - b| = |a + b - c - b| = |a - c| = max(a, c) - min(a, c)
// pc = |p - c| = |a + b - c - c| = |(b - c) + (a - c)| = ...
//
// Further optimizing the calculation of `pc` is a bit trickier. However, notice that:
//
// a > c && b > c
// ==> (a - c) > 0 && (b - c) > 0
// ==> pc > (a - c) && pc > (b - c)
// ==> pc > |a - c| && pc > |b - c|
// ==> pc > pb && pc > pa
//
// Meaning that if `c` is smaller than `a` and `b`, the value of `pc` is irrelevant. Similar
// reasoning applies if `c` is larger than the other two inputs. Assuming that `c >= a` and
// `c <= b` or vice versa:
//
// pc = ||b - c| - |a - c|| = |pa - pb| = max(pa, pb) - min(pa, pb)
//
let pa = b.max(c) - c.min(b);
let pb = a.max(c) - c.min(a);
let pc = if (a < c) == (c < b) {
pa.max(pb) - pa.min(pb)
} else {
255
};

if pa <= pb && pa <= pc {
a
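
A quick way to convince yourself of the derivation above is an exhaustive comparison against the straightforward signed implementation from the PNG specification. The sketch below is not part of this PR; filter_paeth_reference is a hypothetical helper written from the spec, and the test assumes it sits next to filter_paeth in src/filter.rs.

// Hypothetical reference predictor, written directly from the PNG specification
// using signed 16-bit arithmetic (this mirrors the code being removed above).
fn filter_paeth_reference(a: u8, b: u8, c: u8) -> u8 {
    let p = i16::from(a) + i16::from(b) - i16::from(c);
    let pa = (p - i16::from(a)).abs();
    let pb = (p - i16::from(b)).abs();
    let pc = (p - i16::from(c)).abs();
    if pa <= pb && pa <= pc {
        a
    } else if pb <= pc {
        b
    } else {
        c
    }
}

#[test]
fn paeth_matches_reference() {
    // Exhaustive over all 2^24 inputs; quick in release mode, slow in a debug build.
    for a in 0..=255u8 {
        for b in 0..=255u8 {
            for c in 0..=255u8 {
                assert_eq!(filter_paeth(a, b, c), filter_paeth_reference(a, b, c));
            }
        }
    }
}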
@@ -209,47 +230,93 @@ fn filter_internal(
bpp: usize,
len: usize,
previous: &[u8],
current: &mut [u8],
current: &[u8],
output: &mut [u8],
) -> FilterType {
use self::FilterType::*;

// This value was chosen experimentally based on what achieved the best performance. The
// Rust compiler does auto-vectorization, and 32 bytes per loop iteration seems to enable
// the fastest code when doing so.
const CHUNK_SIZE: usize = 32;

match method {
NoFilter => NoFilter,
NoFilter => {
output.copy_from_slice(current);
NoFilter
}
Sub => {
for i in (bpp..len).rev() {
current[i] = current[i].wrapping_sub(current[i - bpp]);
let mut out_chunks = output[bpp..].chunks_exact_mut(CHUNK_SIZE);
let mut cur_chunks = current[bpp..].chunks_exact(CHUNK_SIZE);
let mut prev_chunks = current[..len - bpp].chunks_exact(CHUNK_SIZE);

for ((out, cur), prev) in (&mut out_chunks).zip(&mut cur_chunks).zip(&mut prev_chunks) {
for i in 0..CHUNK_SIZE {
out[i] = cur[i].wrapping_sub(prev[i]);
}
}

for ((out, cur), &prev) in out_chunks
.into_remainder()
.into_iter()
.zip(cur_chunks.remainder())
.zip(prev_chunks.remainder())
{
*out = cur.wrapping_sub(prev);
}

output[..bpp].copy_from_slice(&current[..bpp]);
Sub
}
Up => {
for i in 0..len {
current[i] = current[i].wrapping_sub(previous[i]);
output[i] = current[i].wrapping_sub(previous[i]);
}
Up
}
Avg => {
for i in (bpp..len).rev() {
current[i] = current[i].wrapping_sub(
output[i] = current[i].wrapping_sub(
((u16::from(current[i - bpp]) + u16::from(previous[i])) / 2) as u8,
);
}

for i in 0..bpp {
current[i] = current[i].wrapping_sub(previous[i] / 2);
output[i] = current[i].wrapping_sub(previous[i] / 2);
}
Avg
}
Paeth => {
for i in (bpp..len).rev() {
current[i] = current[i].wrapping_sub(filter_paeth(
current[i - bpp],
previous[i],
previous[i - bpp],
));
let mut out_chunks = output[bpp..].chunks_exact_mut(CHUNK_SIZE);
let mut cur_chunks = current[bpp..].chunks_exact(CHUNK_SIZE);
let mut a_chunks = current[..len - bpp].chunks_exact(CHUNK_SIZE);
let mut b_chunks = previous[bpp..].chunks_exact(CHUNK_SIZE);
let mut c_chunks = previous[..len - bpp].chunks_exact(CHUNK_SIZE);

for ((((out, cur), a), b), c) in (&mut out_chunks)
.zip(&mut cur_chunks)
.zip(&mut a_chunks)
.zip(&mut b_chunks)
.zip(&mut c_chunks)
{
for i in 0..CHUNK_SIZE {
out[i] = cur[i].wrapping_sub(filter_paeth(a[i], b[i], c[i]));
}
}

for ((((out, cur), &a), &b), &c) in out_chunks
.into_remainder()
.into_iter()
.zip(cur_chunks.remainder())
.zip(a_chunks.remainder())
.zip(b_chunks.remainder())
.zip(c_chunks.remainder())
{
*out = cur.wrapping_sub(filter_paeth(a, b, c));
}

for i in 0..bpp {
current[i] = current[i].wrapping_sub(filter_paeth(0, previous[i], 0));
output[i] = current[i].wrapping_sub(filter_paeth(0, previous[i], 0));
}
Paeth
}
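
Not part of the diff, but the chunking pattern used throughout this hunk can be seen in isolation below: the same Sub computation written once as a plain indexed loop and once with fixed-size chunks (the function names are illustrative). Splitting the row into 32-byte chunks gives the optimizer a fixed-length inner loop it can auto-vectorize, with the leftover bytes handled as scalars.

// Plain indexed Sub filter: each output byte is the input minus the byte `bpp`
// positions earlier in the same row.
fn sub_naive(bpp: usize, current: &[u8], output: &mut [u8]) {
    output[..bpp].copy_from_slice(&current[..bpp]);
    for i in bpp..current.len() {
        output[i] = current[i].wrapping_sub(current[i - bpp]);
    }
}

// Chunked Sub filter: the fixed-size inner loop is what the compiler vectorizes.
fn sub_chunked(bpp: usize, current: &[u8], output: &mut [u8]) {
    const CHUNK_SIZE: usize = 32;
    let len = current.len();

    let mut out_chunks = output[bpp..].chunks_exact_mut(CHUNK_SIZE);
    let mut cur_chunks = current[bpp..].chunks_exact(CHUNK_SIZE);
    let mut prev_chunks = current[..len - bpp].chunks_exact(CHUNK_SIZE);

    for ((out, cur), prev) in (&mut out_chunks).zip(&mut cur_chunks).zip(&mut prev_chunks) {
        for i in 0..CHUNK_SIZE {
            out[i] = cur[i].wrapping_sub(prev[i]);
        }
    }

    // Scalar tail for the bytes that do not fill a whole chunk.
    for ((out, cur), &prev) in out_chunks
        .into_remainder()
        .iter_mut()
        .zip(cur_chunks.remainder())
        .zip(prev_chunks.remainder())
    {
        *out = cur.wrapping_sub(prev);
    }

    // The first `bpp` bytes have no left neighbour and pass through unchanged.
    output[..bpp].copy_from_slice(&current[..bpp]);
}

fn main() {
    let row: Vec<u8> = (0u8..=199).collect();
    let mut a = vec![0u8; row.len()];
    let mut b = vec![0u8; row.len()];
    sub_naive(3, &row, &mut a);
    sub_chunked(3, &row, &mut b);
    assert_eq!(a, b);
}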
@@ -261,14 +328,17 @@ pub(crate) fn filter(
adaptive: AdaptiveFilterType,
bpp: BytesPerPixel,
previous: &[u8],
current: &mut [u8],
current: &[u8],
output: &mut [u8],
) -> FilterType {
use FilterType::*;
let bpp = bpp.into_usize();
let len = current.len();

match adaptive {
AdaptiveFilterType::NonAdaptive => filter_internal(method, bpp, len, previous, current),
AdaptiveFilterType::NonAdaptive => {
filter_internal(method, bpp, len, previous, current, output)
}
AdaptiveFilterType::Adaptive => {
// Filter the current buffer with each filter type. Sum the absolute
// values of each filtered buffer treating the bytes as signed
@@ -282,8 +352,7 @@ pub(crate) fn filter(
let mut filter_choice = FilterType::NoFilter;

for &filter in [Sub, Up, Avg, Paeth].iter() {
scratch.copy_from_slice(current);
filter_internal(filter, bpp, len, previous, &mut scratch);
filter_internal(filter, bpp, len, previous, current, &mut scratch);
let sum = sum_buffer(&scratch);
if sum < min_sum {
min_sum = sum;
@@ -292,7 +361,7 @@
}
}

current.copy_from_slice(&filtered_buffer);
output.copy_from_slice(&filtered_buffer);

filter_choice
}
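
The sum referred to in the comment above is computed by a helper that is not visible in this diff. A rough sketch of the heuristic (minimum sum of absolute values, treating filtered bytes as signed) is shown below; the helper name mirrors the crate's sum_buffer, but the exact signature here is an assumption.

// Illustrative version of the adaptive-filter heuristic: interpret every
// filtered byte as a signed value and sum the absolute values. The filter
// producing the smallest sum is chosen, since small residuals tend to
// compress best.
fn sum_buffer(buf: &[u8]) -> u64 {
    buf.iter().map(|&b| u64::from((b as i8).unsigned_abs())).sum()
}

fn main() {
    // A smooth ramp leaves tiny residuals after Sub filtering...
    let sub_filtered = vec![1u8; 64];
    // ...while the unfiltered row keeps the raw, larger byte values.
    let unfiltered: Vec<u8> = (0..64u8).collect();
    assert!(sum_buffer(&sub_filtered) < sum_buffer(&unfiltered));
}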
@@ -316,15 +385,16 @@ mod test {
// A multiple of 8, 6, 4, 3, 2, 1
const LEN: u8 = 240;
let previous: Vec<_> = iter::repeat(1).take(LEN.into()).collect();
let mut current: Vec<_> = (0..LEN).collect();
let current: Vec<_> = (0..LEN).collect();
let expected = current.clone();
let adaptive = AdaptiveFilterType::NonAdaptive;

let mut roundtrip = |kind, bpp: BytesPerPixel| {
filter(kind, adaptive, bpp, &previous, &mut current);
unfilter(kind, bpp, &previous, &mut current).expect("Unfilter worked");
let roundtrip = |kind, bpp: BytesPerPixel| {
let mut output = vec![0; LEN.into()];
filter(kind, adaptive, bpp, &previous, &current, &mut output);
unfilter(kind, bpp, &previous, &mut output).expect("Unfilter worked");
assert_eq!(
current, expected,
output, expected,
"Filtering {:?} with {:?} does not roundtrip",
bpp, kind
);
@@ -359,15 +429,16 @@ mod test {
// A multiple of 8, 6, 4, 3, 2, 1
const LEN: u8 = 240;
let previous: Vec<_> = (0..LEN).collect();
let mut current: Vec<_> = (0..LEN).collect();
let current: Vec<_> = (0..LEN).collect();
let expected = current.clone();
let adaptive = AdaptiveFilterType::NonAdaptive;

let mut roundtrip = |kind, bpp: BytesPerPixel| {
filter(kind, adaptive, bpp, &previous, &mut current);
unfilter(kind, bpp, &previous, &mut current).expect("Unfilter worked");
let roundtrip = |kind, bpp: BytesPerPixel| {
let mut output = vec![0; LEN.into()];
filter(kind, adaptive, bpp, &previous, &current, &mut output);
unfilter(kind, bpp, &previous, &mut output).expect("Unfilter worked");
assert_eq!(
current, expected,
output, expected,
"Filtering {:?} with {:?} does not roundtrip",
bpp, kind
);