Always free staging buffers.
We must ensure that the staging buffer is not leaked when an error
occurs, so allocate it as late as possible, and free it explicitly when
those fallible operations we can't move it past go awry.
jimblandy committed Aug 13, 2022
1 parent 5b8a1bc commit ec5f7fe
Showing 1 changed file with 44 additions and 19 deletions.
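
The shape of the fix, as a minimal standalone sketch; every type and function name below (StagingBuffer, PendingWrites, flush, encode_copy) is a hypothetical stand-in rather than the wgpu-core API. Once the staging buffer exists, every exit path must hand it to `pending_writes.consume`: the fallible flush is handled with an explicit `if let Err` instead of `?`, and the later work only borrows the buffer so the caller can consume it unconditionally afterwards.

// Minimal sketch of the cleanup discipline this commit adopts. Every name
// here (StagingBuffer, PendingWrites, flush, encode_copy) is a hypothetical
// stand-in, not the wgpu-core API.
struct StagingBuffer;

#[derive(Default)]
struct PendingWrites {
    staging: Vec<StagingBuffer>,
}

impl PendingWrites {
    // Queue the buffer for deferred destruction once the GPU is done with it.
    fn consume(&mut self, buf: StagingBuffer) {
        self.staging.push(buf);
    }
}

#[derive(Debug)]
struct WriteError;

// A fallible step that cannot be moved before the allocation.
fn flush(_buf: &StagingBuffer) -> Result<(), WriteError> {
    Ok(())
}

// Later fallible work that only needs to borrow the buffer.
fn encode_copy(_buf: &StagingBuffer) -> Result<(), WriteError> {
    Ok(())
}

fn write_buffer(pending: &mut PendingWrites) -> Result<(), WriteError> {
    // Allocate as late as possible, so earlier `?` returns cannot leak it.
    let staging_buffer = StagingBuffer;

    // No `?` here: free the buffer explicitly on the error path.
    if let Err(e) = flush(&staging_buffer) {
        pending.consume(staging_buffer);
        return Err(e);
    }

    // Borrow the buffer for the remaining work, then consume it
    // unconditionally, so the Ok and Err paths both end with it in `pending`.
    let result = encode_copy(&staging_buffer);
    pending.consume(staging_buffer);
    result
}

fn main() {
    let mut pending = PendingWrites::default();
    write_buffer(&mut pending).expect("write failed");
    assert_eq!(pending.staging.len(), 1);
}
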
63 changes: 44 additions & 19 deletions wgpu-core/src/device/queue.rs
@@ -322,21 +322,30 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
return Ok(());
}

// Platform validation requires that the staging buffer always be
// freed, even if an error occurs. All paths from here must call
// `device.pending_writes.consume`.
let (staging_buffer, staging_buffer_ptr) = prepare_staging_buffer(&mut device.raw, data_size)?;

unsafe {
if let Err(flush_error) = unsafe {
profiling::scope!("copy");
ptr::copy_nonoverlapping(data.as_ptr(), staging_buffer_ptr, data.len());
staging_buffer.flush(&device.raw)?;
};
staging_buffer.flush(&device.raw)
} {
device.pending_writes.consume(staging_buffer);
return Err(flush_error.into());
}

self.queue_write_staging_buffer_impl(
let result = self.queue_write_staging_buffer_impl(
device,
device_token,
staging_buffer,
&staging_buffer,
buffer_id,
buffer_offset,
)
);

device.pending_writes.consume(staging_buffer);
result
}

pub fn queue_create_staging_buffer<A: HalApi>(
@@ -385,15 +394,25 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.0
.ok_or(TransferError::InvalidBuffer(buffer_id))?;

unsafe { staging_buffer.flush(&device.raw)? };
// At this point, we have taken ownership of the staging_buffer from the
// user. Platform validation requires that the staging buffer always
// be freed, even if an error occurs. All paths from here must call
// `device.pending_writes.consume`.
if let Err(flush_error) = unsafe { staging_buffer.flush(&device.raw) } {
device.pending_writes.consume(staging_buffer);
return Err(flush_error.into());
}

self.queue_write_staging_buffer_impl(
let result = self.queue_write_staging_buffer_impl(
device,
device_token,
staging_buffer,
&staging_buffer,
buffer_id,
buffer_offset,
)
);

device.pending_writes.consume(staging_buffer);
result
}

pub fn queue_validate_write_buffer<A: HalApi>(
@@ -453,7 +472,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
&self,
device: &mut super::Device<A>,
device_token: &mut Token<super::Device<A>>,
staging_buffer: StagingBuffer<A>,
staging_buffer: &StagingBuffer<A>,
buffer_id: id::BufferId,
buffer_offset: u64,
) -> Result<(), QueueWriteError> {
@@ -492,7 +511,6 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
encoder.copy_buffer_to_buffer(&staging_buffer.raw, dst_raw, region.into_iter());
}

device.pending_writes.consume(staging_buffer);
device.pending_writes.dst_buffers.insert(buffer_id);

// Ensure the overwritten bytes are marked as initialized so they don't need to be nulled prior to mapping or binding.
@@ -585,7 +603,6 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let block_rows_in_copy =
(size.depth_or_array_layers - 1) * block_rows_per_image + height_blocks;
let stage_size = stage_bytes_per_row as u64 * block_rows_in_copy as u64;
let (staging_buffer, staging_buffer_ptr) = prepare_staging_buffer(&mut device.raw, stage_size)?;

let dst = texture_guard.get_mut(destination.texture).unwrap();
if !dst.desc.usage.contains(wgt::TextureUsages::COPY_DST) {
@@ -648,12 +665,22 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
validate_texture_copy_range(destination, &dst.desc, CopySide::Destination, size)?;
dst.life_guard.use_at(device.active_submission_index + 1);

let dst_raw = dst
.inner
.as_raw()
.ok_or(TransferError::InvalidTexture(destination.texture))?;

let bytes_per_row = if let Some(bytes_per_row) = data_layout.bytes_per_row {
bytes_per_row.get()
} else {
width_blocks * format_desc.block_size as u32
};

// Platform validation requires that the staging buffer always be
// freed, even if an error occurs. All paths from here must call
// `device.pending_writes.consume`.
let (staging_buffer, staging_buffer_ptr) = prepare_staging_buffer(&mut device.raw, stage_size)?;

if stage_bytes_per_row == bytes_per_row {
profiling::scope!("copy aligned");
// Fast path if the data is already being aligned optimally.
@@ -687,7 +714,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}
}

unsafe { staging_buffer.flush(&device.raw) }?;
if let Err(e) = unsafe { staging_buffer.flush(&device.raw) } {
device.pending_writes.consume(staging_buffer);
return Err(e.into());
}

let regions = (0..array_layer_count).map(|rel_array_layer| {
let mut texture_base = dst_base.clone();
@@ -709,11 +739,6 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
usage: hal::BufferUses::MAP_WRITE..hal::BufferUses::COPY_SRC,
};

let dst_raw = dst
.inner
.as_raw()
.ok_or(TransferError::InvalidTexture(destination.texture))?;

unsafe {
encoder
.transition_textures(transition.map(|pending| pending.into_hal(dst)).into_iter());
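
A note on the structural change visible in the hunks above: `queue_write_staging_buffer_impl` now borrows the staging buffer (`&StagingBuffer<A>`) and no longer calls `device.pending_writes.consume` itself. Ownership stays with its two callers, which consume the buffer after the call returns whether it succeeded or failed, the same borrow-then-consume step as `encode_copy` in the sketch near the top. On the texture path, `prepare_staging_buffer` likewise moves below the destination usage and copy-range validation, so the earlier `?` returns never have a live staging buffer to leak.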
