Skip to content

Commit

Permalink
More complete implementation of "lose the device". (#4645)
Browse files Browse the repository at this point in the history
* More complete implementation of "lose the device".

This provides a way for wgpu-core to specify a callback on "lose the
device". It ensures this callback is called at the appropriate times:
either after device.destroy has empty queues, or on demand from
device.lose.

A test has been added to device.rs.

* Updated CHANGELOG.md.

* Fix conversion to *const c_char.

* Use an allow lint to permit trivial_casts.

* rustfmt changes.
  • Loading branch information
bradwerth authored Nov 8, 2023
1 parent 5f79800 commit 4e65eca
Show file tree
Hide file tree
Showing 11 changed files with 316 additions and 52 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ By @teoxoy in [#4185](https://github.com/gfx-rs/wgpu/pull/4185)
- Calls to lost devices now return `DeviceError::Lost` instead of `DeviceError::Invalid`. By @bradwerth in [#4238](https://github.com/gfx-rs/wgpu/pull/4238)
- Let the `"strict_asserts"` feature enable check that wgpu-core's lock-ordering tokens are unique per thread. By @jimblandy in [#4258](https://github.com/gfx-rs/wgpu/pull/4258)
- Allow filtering labels out before they are passed to GPU drivers by @nical in [#4246](https://github.com/gfx-rs/wgpu/pull/4246)
- `DeviceLostClosure` callback mechanism provided so user agents can resolve `GPUDevice.lost` Promises at the appropriate time by @bradwerth in [#4645](https://github.com/gfx-rs/wgpu/pull/4645)


#### Vulkan

Expand Down
32 changes: 32 additions & 0 deletions tests/tests/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,35 @@ static DEVICE_DESTROY_THEN_MORE: GpuTestConfiguration = GpuTestConfiguration::ne
buffer_for_unmap.unmap();
});
});

#[gpu_test]
static DEVICE_DESTROY_THEN_LOST: GpuTestConfiguration = GpuTestConfiguration::new()
    .parameters(TestParameters::default())
    .run_sync(|ctx| {
        use std::sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        };

        // Verifies that destroying a device eventually invokes the registered
        // device-lost callback, and that the reported reason is
        // `DeviceLostReason::Destroyed`.
        let callback_fired = Arc::new(AtomicBool::new(false));

        // Register the device-lost callback. It records that it ran and
        // checks the reason it was handed.
        let fired_flag = Arc::clone(&callback_fired);
        let on_device_lost = Box::new(move |reason, _message| {
            fired_flag.store(true, Ordering::SeqCst);
            assert!(
                matches!(reason, wgt::DeviceLostReason::Destroyed),
                "Device lost info reason should match DeviceLostReason::Destroyed."
            );
        });
        ctx.device.set_device_lost_callback(on_device_lost);

        // Start destruction of the device.
        ctx.device.destroy();

        // Block until the device queues have drained; the callback is
        // expected to have been invoked by the time this returns.
        let queues_empty = ctx.device.poll(wgpu::Maintain::Wait);
        assert!(queues_empty);

        let fired = callback_fired.load(Ordering::SeqCst);
        assert!(fired, "Device lost callback should have been called.");
    });
49 changes: 28 additions & 21 deletions wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use crate::device::trace;
use crate::{
binding_model::{self, BindGroupLayout},
command, conv,
device::{life::WaitIdleError, map_buffer, queue, Device, DeviceError, HostMap},
device::{
life::WaitIdleError, map_buffer, queue, Device, DeviceError, DeviceLostClosure, HostMap,
},
global::Global,
hal_api::HalApi,
hub::Token,
Expand Down Expand Up @@ -2672,6 +2674,21 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}
}

/// Registers `device_lost_closure` as the closure to run when the device
/// identified by `device_id` is lost.
///
/// Any closure stored previously for that device is replaced. If
/// `device_id` does not resolve to a device, the call does nothing.
pub fn device_set_device_lost_closure<A: HalApi>(
    &self,
    device_id: DeviceId,
    device_lost_closure: DeviceLostClosure,
) {
    let mut root_token = Token::root();
    let (mut devices, mut device_token) = A::hub(self).devices.write(&mut root_token);

    match devices.get_mut(device_id) {
        Ok(device) => {
            // The closure lives in the device's lifetime tracker, so take
            // that lock only for the duration of the store.
            device.lock_life(&mut device_token).device_lost_closure = Some(device_lost_closure);
        }
        // Unknown or invalid device id: nothing to register against.
        Err(_) => {}
    }
}

pub fn device_destroy<A: HalApi>(&self, device_id: DeviceId) {
log::trace!("Device::destroy {device_id:?}");

Expand All @@ -2683,36 +2700,26 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
// Follow the steps at
// https://gpuweb.github.io/gpuweb/#dom-gpudevice-destroy.

// It's legal to call destroy multiple times, but if the device
// is already invalid, there's nothing more to do. There's also
// no need to return an error.
if !device.valid {
return;
}

// The last part of destroy is to lose the device. The spec says
// delay that until all "currently-enqueued operations on any
// queue on this device are completed."

// TODO: implement this delay.

// Finish by losing the device.

// TODO: associate this "destroyed" reason more tightly with
// the GPUDeviceLostReason defined in webgpu.idl.
device.lose(Some("destroyed"));
// queue on this device are completed." This is accomplished by
// setting valid to false, and then relying upon maintain to
// check for empty queues and a DeviceLostClosure. At that time,
// the DeviceLostClosure will be called with "destroyed" as the
// reason.
device.valid = false;
}
}

pub fn device_lose<A: HalApi>(&self, device_id: DeviceId, reason: Option<&str>) {
log::trace!("Device::lose {device_id:?}");
pub fn device_mark_lost<A: HalApi>(&self, device_id: DeviceId, message: &str) {
log::trace!("Device::mark_lost {device_id:?}");

let hub = A::hub(self);
let mut token = Token::root();

let (mut device_guard, _) = hub.devices.write(&mut token);
let (mut device_guard, mut token) = hub.devices.write(&mut token);
if let Ok(device) = device_guard.get_mut(device_id) {
device.lose(reason);
device.lose(&mut token, message);
}
}

Expand Down
8 changes: 7 additions & 1 deletion wgpu-core/src/device/life.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::device::trace;
use crate::{
device::{
queue::{EncoderInFlight, SubmittedWorkDoneClosure, TempResource},
DeviceError,
DeviceError, DeviceLostClosure,
},
hal_api::HalApi,
hub::{Hub, Token},
Expand Down Expand Up @@ -313,6 +313,11 @@ pub(super) struct LifetimeTracker<A: hal::Api> {
/// must happen _after_ all mapped buffer callbacks are mapped, so we defer them
/// here until the next time the device is maintained.
work_done_closures: SmallVec<[SubmittedWorkDoneClosure; 1]>,

/// Closure to be called on "lose the device". This is invoked directly by
/// device.lose or by the UserClosures returned from maintain when the device
/// has been destroyed and its queues are empty.
pub device_lost_closure: Option<DeviceLostClosure>,
}

impl<A: hal::Api> LifetimeTracker<A> {
Expand All @@ -326,6 +331,7 @@ impl<A: hal::Api> LifetimeTracker<A> {
free_resources: NonReferencedResources::new(),
ready_to_map: Vec::new(),
work_done_closures: SmallVec::new(),
device_lost_closure: None,
}
}

Expand Down
98 changes: 97 additions & 1 deletion wgpu-core/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ use crate::{
use arrayvec::ArrayVec;
use hal::Device as _;
use smallvec::SmallVec;
use std::os::raw::c_char;
use thiserror::Error;
use wgt::{BufferAddress, TextureFormat};
use wgt::{BufferAddress, DeviceLostReason, TextureFormat};

use std::{iter, num::NonZeroU32, ptr};

Expand Down Expand Up @@ -169,12 +170,15 @@ pub type BufferMapPendingClosure = (BufferMapOperation, BufferAccessResult);
/// User-supplied callbacks collected while maintaining a device, to be
/// invoked together later via `fire`.
pub struct UserClosures {
/// Pending buffer-map callbacks, each a `BufferMapPendingClosure`.
pub mappings: Vec<BufferMapPendingClosure>,
/// Pending submitted-work-done callbacks.
pub submissions: SmallVec<[queue::SubmittedWorkDoneClosure; 1]>,
/// Pending device-lost callbacks, each bundled with the reason and message
/// it will be called with.
pub device_lost_invocations: SmallVec<[DeviceLostInvocation; 1]>,
}

impl UserClosures {
fn extend(&mut self, other: Self) {
self.mappings.extend(other.mappings);
self.submissions.extend(other.submissions);
self.device_lost_invocations
.extend(other.device_lost_invocations);
}

fn fire(self) {
Expand All @@ -189,6 +193,98 @@ impl UserClosures {
for closure in self.submissions {
closure.call();
}
for invocation in self.device_lost_invocations {
invocation
.closure
.call(invocation.reason, invocation.message);
}
}
}

/// Boxed Rust callback invoked when a device is lost, receiving the reason
/// and a message.
///
/// This variant carries a `Send` bound. It is selected on non-wasm32 targets,
/// and on wasm32 when the `fragile-send-sync-non-atomic-wasm` feature is
/// enabled without the `atomics` target feature.
#[cfg(any(
not(target_arch = "wasm32"),
all(
feature = "fragile-send-sync-non-atomic-wasm",
not(target_feature = "atomics")
)
))]
pub type DeviceLostCallback = Box<dyn FnOnce(DeviceLostReason, String) + Send + 'static>;
/// Boxed Rust callback invoked when a device is lost, receiving the reason
/// and a message.
///
/// This variant has no `Send` bound. It is selected in every configuration
/// not covered by the `Send` variant above.
#[cfg(not(any(
not(target_arch = "wasm32"),
all(
feature = "fragile-send-sync-non-atomic-wasm",
not(target_feature = "atomics")
)
)))]
pub type DeviceLostCallback = Box<dyn FnOnce(DeviceLostReason, String) + 'static>;

/// C-ABI form of a device-lost callback: a function pointer plus an opaque
/// user-data pointer that is handed back to the function when it is called.
#[repr(C)]
pub struct DeviceLostClosureC {
/// Function to call. It receives `user_data`, the lost reason as a `u8`
/// (a `DeviceLostReason` cast to its discriminant), and a pointer to the
/// message text. See `DeviceLostClosure::call` for how the arguments are
/// produced.
pub callback: unsafe extern "C" fn(user_data: *mut u8, reason: u8, message: *const c_char),
/// Opaque pointer passed through unchanged as the callback's first argument.
pub user_data: *mut u8,
}

// `DeviceLostClosureC` holds raw pointers, so it is not `Send` automatically.
// This opts it in under the same cfg condition that gives the Rust callback
// type its `Send` bound.
// NOTE(review): soundness presumably relies on the caller-side contract of
// `DeviceLostClosure::from_c` — confirm that callers guarantee the pointers
// are usable from another thread.
#[cfg(any(
not(target_arch = "wasm32"),
all(
feature = "fragile-send-sync-non-atomic-wasm",
not(target_feature = "atomics")
)
))]
unsafe impl Send for DeviceLostClosureC {}

/// A device-lost callback, backed by either a Rust closure or a C function
/// pointer. Built with `from_rust` or `from_c` and consumed by `call`.
pub struct DeviceLostClosure {
// We wrap this so creating the enum in the C variant can be unsafe,
// allowing our call function to be safe.
inner: DeviceLostClosureInner,
}

/// A device-lost closure bundled with the arguments it should be called
/// with, so the call itself can be deferred.
pub struct DeviceLostInvocation {
// The closure to invoke.
closure: DeviceLostClosure,
// Reason passed to the closure.
reason: DeviceLostReason,
// Message passed to the closure.
message: String,
}

// Private storage for `DeviceLostClosure`: the two supported callback forms.
enum DeviceLostClosureInner {
// A boxed Rust closure.
Rust { callback: DeviceLostCallback },
// A C function pointer together with its user-data pointer.
C { inner: DeviceLostClosureC },
}

impl DeviceLostClosure {
pub fn from_rust(callback: DeviceLostCallback) -> Self {
Self {
inner: DeviceLostClosureInner::Rust { callback },
}
}

/// # Safety
///
/// - The callback pointer must be valid to call with the provided `user_data`
/// pointer.
///
/// - Both pointers must point to `'static` data, as the callback may happen at
/// an unspecified time.
pub unsafe fn from_c(inner: DeviceLostClosureC) -> Self {
Self {
inner: DeviceLostClosureInner::C { inner },
}
}

#[allow(trivial_casts)]
pub(crate) fn call(self, reason: DeviceLostReason, message: String) {
match self.inner {
DeviceLostClosureInner::Rust { callback } => callback(reason, message),
// SAFETY: the contract of the call to from_c says that this unsafe is sound.
DeviceLostClosureInner::C { inner } => unsafe {
// We need to pass message as a c_char typed pointer. To avoid trivial
// conversion warnings on some platforms, we use the allow lint.
(inner.callback)(
inner.user_data,
reason as u8,
message.as_ptr() as *const c_char,
)
},
}
}
}

Expand Down
35 changes: 28 additions & 7 deletions wgpu-core/src/device/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use crate::{
command, conv,
device::life::WaitIdleError,
device::{
AttachmentData, CommandAllocator, MissingDownlevelFlags, MissingFeatures,
RenderPassContext, CLEANUP_WAIT_MS,
AttachmentData, CommandAllocator, DeviceLostInvocation, MissingDownlevelFlags,
MissingFeatures, RenderPassContext, CLEANUP_WAIT_MS,
},
hal_api::HalApi,
hal_label,
Expand All @@ -34,7 +34,7 @@ use hal::{CommandEncoder as _, Device as _};
use parking_lot::{Mutex, MutexGuard};
use smallvec::SmallVec;
use thiserror::Error;
use wgt::{TextureFormat, TextureSampleType, TextureViewDimension};
use wgt::{DeviceLostReason, TextureFormat, TextureSampleType, TextureViewDimension};

use std::{borrow::Cow, iter, num::NonZeroU32};

Expand Down Expand Up @@ -315,9 +315,24 @@ impl<A: HalApi> Device<A> {
let mapping_closures = life_tracker.handle_mapping(hub, &self.raw, &self.trackers, token);
life_tracker.cleanup(&self.raw);

// Detect if we have been destroyed and now need to lose the device.
// If we are invalid (set at start of destroy) and our queue is empty,
// and we have a DeviceLostClosure, return the closure to be called by
// our caller. This will complete the steps for both destroy and for
// "lose the device".
let mut device_lost_invocations = SmallVec::new();
if !self.valid && life_tracker.queue_empty() && life_tracker.device_lost_closure.is_some() {
device_lost_invocations.push(DeviceLostInvocation {
closure: life_tracker.device_lost_closure.take().unwrap(),
reason: DeviceLostReason::Destroyed,
message: String::new(),
});
}

let closures = UserClosures {
mappings: mapping_closures,
submissions: submission_closures,
device_lost_invocations,
};
Ok((closures, life_tracker.queue_empty()))
}
Expand Down Expand Up @@ -3304,17 +3319,23 @@ impl<A: HalApi> Device<A> {
})
}

pub(crate) fn lose(&mut self, _reason: Option<&str>) {
pub(crate) fn lose<'this, 'token: 'this>(
&'this mut self,
token: &mut Token<'token, Self>,
message: &str,
) {
// Follow the steps at https://gpuweb.github.io/gpuweb/#lose-the-device.

// Mark the device explicitly as invalid. This is checked in various
// places to prevent new work from being submitted.
self.valid = false;

// The following steps remain in "lose the device":
// 1) Resolve the GPUDevice device.lost promise.

// TODO: trigger this passively or actively, and supply the reason.
let mut life_tracker = self.lock_life(token);
if life_tracker.device_lost_closure.is_some() {
let device_lost_closure = life_tracker.device_lost_closure.take().unwrap();
device_lost_closure.call(DeviceLostReason::Unknown, message.to_string());
}

// 2) Complete any outstanding mapAsync() steps.
// 3) Complete any outstanding onSubmittedWorkDone() steps.
Expand Down
12 changes: 12 additions & 0 deletions wgpu-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6707,3 +6707,15 @@ mod send_sync {
)))]
impl<T> WasmNotSync for T {}
}

/// Reason for "lose the device".
///
/// Corresponds to [WebGPU `GPUDeviceLostReason`](https://gpuweb.github.io/gpuweb/#enumdef-gpudevicelostreason).
///
/// The enum is `#[repr(u8)]` with explicit discriminants, so a value can be
/// cast to `u8` when it has to cross a C boundary.
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DeviceLostReason {
    /// Triggered by driver
    Unknown = 0,
    /// After Device::destroy
    Destroyed = 1,
}
Loading

0 comments on commit 4e65eca

Please sign in to comment.