Unrolled build for rust-lang#121938
Rollup merge of rust-lang#121938 - blyxxyz:quadratic-vectored-write, r=Amanieu

Fix quadratic behavior of repeated vectored writes

Some implementations of `Write::write_vectored` in the standard library (`BufWriter`, `LineWriter`, `Stdout`, `Stderr`) check all buffers to calculate the total length. This is O(n) over the number of buffers.

It's common that only a limited number of buffers are written at a time (e.g. 1024 for `writev(2)`). `write_all_vectored` will then call `write_vectored` repeatedly, leading to a runtime of O(n²) over the number of buffers.
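To put rough numbers on it: flushing 1,000,000 buffers 1024 at a time takes about 977 calls to `write_vectored`, and if each call first sums the lengths of all remaining buffers, that is on the order of 5×10⁸ length reads just to decide what fits.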

The fix is to calculate only as much of the total as is actually needed, and only when it's needed.
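As a minimal standalone sketch of that idea (using a hypothetical `needed_len` helper, not the actual std code), the running total can be accumulated incrementally and the scan stopped as soon as the decision is settled:

```rust
use std::io::IoSlice;

/// Hypothetical helper: sum buffer lengths, but stop as soon as the running
/// total reaches `limit`, because past that point the exact total no longer
/// changes the "does it fit?" decision.
fn needed_len(bufs: &[IoSlice<'_>], limit: usize) -> usize {
    let mut total: usize = 0;
    for buf in bufs {
        total = total.saturating_add(buf.len());
        if total >= limit {
            // No need to look at the remaining buffers.
            break;
        }
    }
    total
}

fn main() {
    let data = [0u8; 100];
    let bufs: Vec<IoSlice<'_>> = (0..10_000).map(|_| IoSlice::new(&data)).collect();
    // Only the first ~82 of the 10,000 slices are inspected before the limit is reached.
    assert!(needed_len(&bufs, 8192) >= 8192);
}
```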

Here's a test program:
```rust
#![feature(write_all_vectored)]

use std::fs::File;
use std::io::{BufWriter, IoSlice, Write};
use std::time::Instant;

fn main() {
    let buf = vec![b'\0'; 100_000_000];
    let mut slices: Vec<IoSlice<'_>> = buf.chunks(100).map(IoSlice::new).collect();
    let mut writer = BufWriter::new(File::create("/dev/null").unwrap());

    let start = Instant::now();
    write_smart(&slices, &mut writer);
    println!("write_smart(): {:?}", start.elapsed());

    let start = Instant::now();
    writer.write_all_vectored(&mut slices).unwrap();
    println!("write_all_vectored(): {:?}", start.elapsed());
}

fn write_smart(mut slices: &[IoSlice<'_>], writer: &mut impl Write) {
    while !slices.is_empty() {
        // Only try to write as many slices as can be written
        let res = writer
            .write_vectored(slices.get(..1024).unwrap_or(slices))
            .unwrap();
        slices = &slices[(res / 100)..];
    }
}
```
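(Running it requires a nightly toolchain, since `write_all_vectored` is an unstable feature.)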
Before this change:
```
write_smart(): 6.666952ms
write_all_vectored(): 498.437092ms
```
After this change:
```
write_smart(): 6.377158ms
write_all_vectored(): 6.923412ms
```

`LineWriter` (and by extension `Stdout`) isn't fully fixed by this because it scans the buffers for newlines. I could open an issue for that after this is merged; I think it's fixable, but not trivially.
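For context, the remaining cost is the search for the last buffer that contains a newline, which still walks every buffer. A rough sketch of that kind of scan (not the exact shim code):

```rust
use std::io::IoSlice;

// Rough sketch of the scan LineWriterShim performs: find the index of the
// last buffer containing a newline. It still inspects every buffer, even
// ones far beyond what a single writev() call could write.
fn last_newline_buf_idx(bufs: &[IoSlice<'_>]) -> Option<usize> {
    bufs.iter()
        .enumerate()
        .rev()
        .find(|(_, buf)| buf.contains(&b'\n'))
        .map(|(i, _)| i)
}

fn main() {
    let a = IoSlice::new(b"no newline here");
    let b = IoSlice::new(b"line\n");
    let c = IoSlice::new(b"tail");
    assert_eq!(last_newline_buf_idx(&[a, b, c]), Some(1));
}
```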
rust-timer authored Mar 8, 2024
2 parents 14fbc3c + 8212fc5 commit 1b4481e
Showing 3 changed files with 50 additions and 31 deletions.
library/std/src/io/buffered/bufwriter.rs (27 additions, 24 deletions)
```diff
@@ -554,35 +554,38 @@ impl<W: ?Sized + Write> Write for BufWriter<W> {
             // same underlying buffer, as otherwise the buffers wouldn't fit in memory). If the
             // computation overflows, then surely the input cannot fit in our buffer, so we forward
             // to the inner writer's `write_vectored` method to let it handle it appropriately.
-            let saturated_total_len =
-                bufs.iter().fold(0usize, |acc, b| acc.saturating_add(b.len()));
+            let mut saturated_total_len: usize = 0;

-            if saturated_total_len > self.spare_capacity() {
-                // Flush if the total length of the input exceeds our buffer's spare capacity.
-                // If we would have overflowed, this condition also holds, and we need to flush.
-                self.flush_buf()?;
+            for buf in bufs {
+                saturated_total_len = saturated_total_len.saturating_add(buf.len());
+
+                if saturated_total_len > self.spare_capacity() && !self.buf.is_empty() {
+                    // Flush if the total length of the input exceeds our buffer's spare capacity.
+                    // If we would have overflowed, this condition also holds, and we need to flush.
+                    self.flush_buf()?;
+                }
+
+                if saturated_total_len >= self.buf.capacity() {
+                    // Forward to our inner writer if the total length of the input is greater than or
+                    // equal to our buffer capacity. If we would have overflowed, this condition also
+                    // holds, and we punt to the inner writer.
+                    self.panicked = true;
+                    let r = self.get_mut().write_vectored(bufs);
+                    self.panicked = false;
+                    return r;
+                }
             }

-            if saturated_total_len >= self.buf.capacity() {
-                // Forward to our inner writer if the total length of the input is greater than or
-                // equal to our buffer capacity. If we would have overflowed, this condition also
-                // holds, and we punt to the inner writer.
-                self.panicked = true;
-                let r = self.get_mut().write_vectored(bufs);
-                self.panicked = false;
-                r
-            } else {
-                // `saturated_total_len < self.buf.capacity()` implies that we did not saturate.
+            // `saturated_total_len < self.buf.capacity()` implies that we did not saturate.

-                // SAFETY: We checked whether or not the spare capacity was large enough above. If
-                // it was, then we're safe already. If it wasn't, we flushed, making sufficient
-                // room for any input <= the buffer size, which includes this input.
-                unsafe {
-                    bufs.iter().for_each(|b| self.write_to_buffer_unchecked(b));
-                };
+            // SAFETY: We checked whether or not the spare capacity was large enough above. If
+            // it was, then we're safe already. If it wasn't, we flushed, making sufficient
+            // room for any input <= the buffer size, which includes this input.
+            unsafe {
+                bufs.iter().for_each(|b| self.write_to_buffer_unchecked(b));
+            };

-                Ok(saturated_total_len)
-            }
+            Ok(saturated_total_len)
         } else {
             let mut iter = bufs.iter();
             let mut total_written = if let Some(buf) = iter.by_ref().find(|&buf| !buf.is_empty()) {
```
library/std/src/io/buffered/linewritershim.rs (12 additions, 3 deletions)
```diff
@@ -175,6 +175,10 @@ impl<'a, W: ?Sized + Write> Write for LineWriterShim<'a, W> {
         }

         // Find the buffer containing the last newline
+        // FIXME: This is overly slow if there are very many bufs and none contain
+        // newlines. e.g. writev() on Linux only writes up to 1024 slices, so
+        // scanning the rest is wasted effort. This makes write_all_vectored()
+        // quadratic.
         let last_newline_buf_idx = bufs
             .iter()
             .enumerate()
@@ -215,9 +219,14 @@

         // Don't try to reconstruct the exact amount written; just bail
         // in the event of a partial write
-        let lines_len = lines.iter().map(|buf| buf.len()).sum();
-        if flushed < lines_len {
-            return Ok(flushed);
+        let mut lines_len: usize = 0;
+        for buf in lines {
+            // With overlapping/duplicate slices the total length may in theory
+            // exceed usize::MAX
+            lines_len = lines_len.saturating_add(buf.len());
+            if flushed < lines_len {
+                return Ok(flushed);
+            }
         }

         // Now that the write has succeeded, buffer the rest (or as much of the
```
library/std/src/io/stdio.rs (11 additions, 4 deletions)
```diff
@@ -128,8 +128,8 @@ impl Write for StdoutRaw {
     }

     fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
-        let total = bufs.iter().map(|b| b.len()).sum();
-        handle_ebadf(self.0.write_vectored(bufs), total)
+        let total = || bufs.iter().map(|b| b.len()).sum();
+        handle_ebadf_lazy(self.0.write_vectored(bufs), total)
     }

     #[inline]
@@ -160,8 +160,8 @@ impl Write for StderrRaw {
     }

     fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
-        let total = bufs.iter().map(|b| b.len()).sum();
-        handle_ebadf(self.0.write_vectored(bufs), total)
+        let total = || bufs.iter().map(|b| b.len()).sum();
+        handle_ebadf_lazy(self.0.write_vectored(bufs), total)
     }

     #[inline]
@@ -193,6 +193,13 @@ fn handle_ebadf<T>(r: io::Result<T>, default: T) -> io::Result<T> {
     }
 }

+fn handle_ebadf_lazy<T>(r: io::Result<T>, default: impl FnOnce() -> T) -> io::Result<T> {
+    match r {
+        Err(ref e) if stdio::is_ebadf(e) => Ok(default()),
+        r => r,
+    }
+}
+
 /// A handle to the standard input stream of a process.
 ///
 /// Each handle is a shared reference to a global buffer of input data to this
```
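To make the stdio change concrete: the total length is now produced by a closure, so the O(n) sum over buffer lengths only runs on the rare EBADF fallback path. A standalone sketch of the pattern (hypothetical `recover_lazily`, not the std internals):

```rust
// Standalone sketch of the lazy-default pattern: the fallback value is
// produced by a closure, so the potentially expensive computation only
// happens if the "benign error" branch is actually taken.
fn recover_lazily<T, E>(
    r: Result<T, E>,
    is_benign: impl Fn(&E) -> bool,
    default: impl FnOnce() -> T,
) -> Result<T, E> {
    match r {
        Err(ref e) if is_benign(e) => Ok(default()),
        r => r,
    }
}

fn main() {
    let bufs = vec![vec![0u8; 100]; 1_000];

    // Success path: the closure is never called, so no O(n) scan happens.
    let total = || bufs.iter().map(|b| b.len()).sum::<usize>();
    assert_eq!(recover_lazily(Ok::<_, &str>(42), |_| true, total), Ok(42));

    // "Benign error" path: only now is the fallback total computed.
    let total = || bufs.iter().map(|b| b.len()).sum::<usize>();
    assert_eq!(
        recover_lazily(Err::<usize, &str>("EBADF"), |e| *e == "EBADF", total),
        Ok(100_000)
    );
}
```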
