From 6f189475e8bb57a624582f08826851eebdf23b00 Mon Sep 17 00:00:00 2001 From: Nicolas Silva Date: Thu, 14 Dec 2023 13:13:05 +0100 Subject: [PATCH] Refactor create_buffer so that we can snatch the raw buffer in the error path. The general idea is to register postpone reigistering the buffer until towards the end of the function so that our unique reference to it lets us easily snatch the raw buffer if an error happens. --- wgpu-core/src/device/global.rs | 97 ++++++++++++++++------------------ 1 file changed, 45 insertions(+), 52 deletions(-) diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index a55d1563ef..6f21e8bc6e 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -14,17 +14,24 @@ use crate::{ instance::{self, Adapter, Surface}, pipeline, present, resource::{self, BufferAccessResult}, - resource::{BufferAccessError, BufferMapOperation, Resource}, + resource::{BufferAccessError, BufferMapOperation, CreateBufferError, Resource}, validation::check_buffer_usage, FastHashMap, Label, LabelHelpers as _, }; +use arrayvec::ArrayVec; use hal::Device as _; use parking_lot::RwLock; use wgt::{BufferAddress, TextureFormat}; -use std::{borrow::Cow, iter, ops::Range, ptr, sync::atomic::Ordering}; +use std::{ + borrow::Cow, + iter, + ops::Range, + ptr, + sync::{atomic::Ordering, Arc}, +}; use super::{ImplicitPipelineIds, InvalidDevice, UserClosures}; @@ -140,12 +147,13 @@ impl Global { device_id: DeviceId, desc: &resource::BufferDescriptor, id_in: Input, - ) -> (id::BufferId, Option) { + ) -> (id::BufferId, Option) { profiling::scope!("Device::create_buffer"); let hub = A::hub(self); let fid = hub.buffers.prepare::(id_in); + let mut to_destroy: ArrayVec, 2> = ArrayVec::new(); let error = loop { let device = match hub.devices.get(device_id) { Ok(device) => device, @@ -159,11 +167,7 @@ impl Global { if desc.usage.is_empty() { // Per spec, `usage` must not be zero. - let id = fid.assign_error(desc.label.borrow_or_default()); - return ( - id, - Some(resource::CreateBufferError::InvalidUsage(desc.usage)), - ); + break CreateBufferError::InvalidUsage(desc.usage); } #[cfg(feature = "trace")] @@ -179,36 +183,27 @@ impl Global { let buffer = match device.create_buffer(desc, false) { Ok(buffer) => buffer, Err(e) => { - let id = fid.assign_error(desc.label.borrow_or_default()); - return (id, Some(e)); + break e; } }; - let (id, resource) = fid.assign(buffer); - api_log!("Device::create_buffer({desc:?}) -> {id:?}"); - let buffer_use = if !desc.mapped_at_creation { hal::BufferUses::empty() } else if desc.usage.contains(wgt::BufferUsages::MAP_WRITE) { // buffer is mappable, so we are just doing that at start - let map_size = resource.size; + let map_size = buffer.size; let ptr = if map_size == 0 { std::ptr::NonNull::dangling() } else { - match map_buffer(device.raw(), &resource, 0, map_size, HostMap::Write) { + match map_buffer(device.raw(), &buffer, 0, map_size, HostMap::Write) { Ok(ptr) => ptr, Err(e) => { - device.lock_life().schedule_resource_destruction( - queue::TempResource::Buffer(resource), - !0, - ); - hub.buffers - .force_replace_with_error(id, desc.label.borrow_or_default()); - return (id, Some(e.into())); + to_destroy.push(buffer); + break e.into(); } } }; - *resource.map_state.lock() = resource::BufferMapState::Active { + *buffer.map_state.lock() = resource::BufferMapState::Active { ptr, range: 0..map_size, host: HostMap::Write, @@ -227,17 +222,10 @@ impl Global { let stage = match device.create_buffer(&stage_desc, true) { Ok(stage) => stage, Err(e) => { - device.lock_life().schedule_resource_destruction( - queue::TempResource::Buffer(resource), - !0, - ); - hub.buffers - .force_replace_with_error(id, desc.label.borrow_or_default()); - return (id, Some(e)); + to_destroy.push(buffer); + break e; } }; - let stage_fid = hub.buffers.request(); - let stage = stage_fid.init(stage); let snatch_guard = device.snatchable_lock.read(); let mapping = match unsafe { @@ -247,30 +235,23 @@ impl Global { } { Ok(mapping) => mapping, Err(e) => { - let mut life_lock = device.lock_life(); - life_lock.schedule_resource_destruction( - queue::TempResource::Buffer(resource), - !0, - ); - life_lock - .schedule_resource_destruction(queue::TempResource::Buffer(stage), !0); - hub.buffers - .force_replace_with_error(id, desc.label.borrow_or_default()); - return (id, Some(DeviceError::from(e).into())); + to_destroy.push(buffer); + to_destroy.push(stage); + break CreateBufferError::Device(e.into()); } }; - assert_eq!(resource.size % wgt::COPY_BUFFER_ALIGNMENT, 0); + let stage_fid = hub.buffers.request(); + let stage = stage_fid.init(stage); + + assert_eq!(buffer.size % wgt::COPY_BUFFER_ALIGNMENT, 0); // Zero initialize memory and then mark both staging and buffer as initialized // (it's guaranteed that this is the case by the time the buffer is usable) - unsafe { ptr::write_bytes(mapping.ptr.as_ptr(), 0, resource.size as usize) }; - resource - .initialization_status - .write() - .drain(0..resource.size); - stage.initialization_status.write().drain(0..resource.size); - - *resource.map_state.lock() = resource::BufferMapState::Init { + unsafe { ptr::write_bytes(mapping.ptr.as_ptr(), 0, buffer.size as usize) }; + buffer.initialization_status.write().drain(0..buffer.size); + stage.initialization_status.write().drain(0..buffer.size); + + *buffer.map_state.lock() = resource::BufferMapState::Init { ptr: mapping.ptr, needs_flush: !mapping.is_coherent, stage_buffer: stage, @@ -278,6 +259,9 @@ impl Global { hal::BufferUses::COPY_DST }; + let (id, resource) = fid.assign(buffer); + api_log!("Device::create_buffer({desc:?}) -> {id:?}"); + device .trackers .lock() @@ -287,6 +271,15 @@ impl Global { return (id, None); }; + // Error path + + for buffer in to_destroy { + let device = Arc::clone(&buffer.device); + device + .lock_life() + .schedule_resource_destruction(queue::TempResource::Buffer(Arc::new(buffer)), !0); + } + let id = fid.assign_error(desc.label.borrow_or_default()); (id, Some(error)) } @@ -673,7 +666,7 @@ impl Global { device_id: DeviceId, desc: &resource::BufferDescriptor, id_in: Input, - ) -> (id::BufferId, Option) { + ) -> (id::BufferId, Option) { profiling::scope!("Device::create_buffer"); let hub = A::hub(self);