Skip to content

Commit

Permalink
Ensure device lost closure is called exactly once before being dropped.
Browse files Browse the repository at this point in the history
This requires a change to the Rust callback signature, which is now Fn
instead of FnOnce. When the Rust callback or the C closure are dropped,
they will panic if they haven't been called. `device_drop` is changed
to call the closure with a message of "Device dropped." A test is added.
  • Loading branch information
bradwerth authored and jimblandy committed Dec 19, 2023
1 parent d1fe8d6 commit ca59886
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 30 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Wgpu now exposes backend feature for the Direct3D 12 (`dx12`) and Metal (`metal`
- No longer validate surfaces against their allowed extent range on configure. This caused warnings that were almost impossible to avoid. As before, the resulting behavior depends on the compositor. By @wumpf in [#4796](https://github.com/gfx-rs/wgpu/pull/4796)
- Added support for the float32-filterable feature. By @almarklein in [#4759](https://github.com/gfx-rs/wgpu/pull/4759)
- wgpu and wgpu-core features are now documented on docs.rs. By @wumpf in [#4886](https://github.com/gfx-rs/wgpu/pull/4886)
- DeviceLostCallbackC is guaranteed to be invoked exactly once. By @bradwerth in [#4862](https://github.com/gfx-rs/wgpu/pull/4862)
- DeviceLostClosure is guaranteed to be invoked exactly once. By @bradwerth in [#4862](https://github.com/gfx-rs/wgpu/pull/4862)

#### OpenGL
- `@builtin(instance_index)` now properly reflects the range provided in the draw call instead of always counting from 0. By @cwfitzgerald in [#4722](https://github.com/gfx-rs/wgpu/pull/4722).
Expand Down
33 changes: 33 additions & 0 deletions tests/tests/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,36 @@ static DEVICE_DESTROY_THEN_LOST: GpuTestConfiguration = GpuTestConfiguration::ne
"Device lost callback should have been called."
);
});

#[gpu_test]
static DEVICE_DROP_THEN_LOST: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(TestParameters::default().expect_fail(FailureCase::webgl2()))
.run_sync(|ctx| {
// This test checks that when the device is dropped (such as in a GC),
// the provided DeviceLostClosure is called with reason DeviceLostReason::Unknown.
// Fails on webgl because webgl doesn't implement drop.
let was_called = std::sync::Arc::<std::sync::atomic::AtomicBool>::new(false.into());

// Set a LoseDeviceCallback on the device.
let was_called_clone = was_called.clone();
let callback = Box::new(move |reason, message| {
was_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
assert!(
matches!(reason, wgt::DeviceLostReason::Unknown),
"Device lost info reason should match DeviceLostReason::Unknown."
);
assert!(
message == "Device dropped.",
"Device lost info message should be \"Device dropped.\"."
);
});
ctx.device.set_device_lost_callback(callback);

// Drop the device.
drop(ctx.device);

assert!(
was_called.load(std::sync::atomic::Ordering::SeqCst),
"Device lost callback should have been called."
);
});
11 changes: 9 additions & 2 deletions wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use crate::device::trace;
use crate::{
api_log, binding_model, command, conv,
device::{
life::WaitIdleError, map_buffer, queue, DeviceError, DeviceLostClosure, HostMap,
IMPLICIT_BIND_GROUP_LAYOUT_ERROR_LABEL,
life::WaitIdleError, map_buffer, queue, DeviceError, DeviceLostClosure, DeviceLostReason,
HostMap, IMPLICIT_BIND_GROUP_LAYOUT_ERROR_LABEL,
},
global::Global,
hal_api::HalApi,
Expand Down Expand Up @@ -2239,6 +2239,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {

let hub = A::hub(self);
if let Some(device) = hub.devices.unregister(device_id) {
let device_lost_closure = device.lock_life().device_lost_closure.take();
if let Some(closure) = device_lost_closure {
closure.call(DeviceLostReason::Unknown, String::from("Device dropped."));
}

// The things `Device::prepare_to_die` takes care are mostly
// unnecessary here. We know our queue is empty, so we don't
// need to wait for submissions or triage them. We know we were
Expand All @@ -2254,6 +2259,8 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}
}

// This closure will be called exactly once during "lose the device"
// or when the device is dropped, if it was never lost.
pub fn device_set_device_lost_closure<A: HalApi>(
&self,
device_id: DeviceId,
Expand Down
64 changes: 40 additions & 24 deletions wgpu-core/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,21 +211,34 @@ impl UserClosures {
not(target_feature = "atomics")
)
))]
pub type DeviceLostCallback = Box<dyn FnOnce(DeviceLostReason, String) + Send + 'static>;
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + Send + 'static>;
#[cfg(not(any(
not(target_arch = "wasm32"),
all(
feature = "fragile-send-sync-non-atomic-wasm",
not(target_feature = "atomics")
)
)))]
pub type DeviceLostCallback = Box<dyn FnOnce(DeviceLostReason, String) + 'static>;
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + 'static>;

pub struct DeviceLostClosureRust {
pub callback: DeviceLostCallback,
called: bool,
}

impl Drop for DeviceLostClosureRust {
fn drop(&mut self) {
if !self.called {
panic!("DeviceLostClosureRust must be called before it is dropped.");
}
}
}

#[repr(C)]
pub struct DeviceLostClosureC {
pub callback: unsafe extern "C" fn(user_data: *mut u8, reason: u8, message: *const c_char),
pub user_data: *mut u8,
pub called: bool,
called: bool,
}

#[cfg(any(
Expand All @@ -239,18 +252,8 @@ unsafe impl Send for DeviceLostClosureC {}

impl Drop for DeviceLostClosureC {
fn drop(&mut self) {
unsafe {
if !self.called {
self.called = true;
// Invoke the closure with reason Destroyed so embedder can recover
// the memory.
let message = std::ffi::CString::new("Dropped").unwrap();
(self.callback)(
self.user_data,
DeviceLostReason::Destroyed as u8,
message.as_ptr(),
)
}
if !self.called {
panic!("DeviceLostClosureC must be called before it is dropped.");
}
}
}
Expand All @@ -268,14 +271,18 @@ pub struct DeviceLostInvocation {
}

enum DeviceLostClosureInner {
Rust { callback: DeviceLostCallback },
Rust { inner: DeviceLostClosureRust },
C { inner: DeviceLostClosureC },
}

impl DeviceLostClosure {
pub fn from_rust(callback: DeviceLostCallback) -> Self {
let inner = DeviceLostClosureRust {
callback,
called: false,
};
Self {
inner: DeviceLostClosureInner::Rust { callback },
inner: DeviceLostClosureInner::Rust { inner },
}
}

Expand All @@ -301,16 +308,25 @@ impl DeviceLostClosure {

pub(crate) fn call(self, reason: DeviceLostReason, message: String) {
match self.inner {
DeviceLostClosureInner::Rust { callback } => callback(reason, message),
DeviceLostClosureInner::Rust { mut inner } => {
if inner.called {
panic!("DeviceLostClosureRust must only be called once.");
}
inner.called = true;

(inner.callback)(reason, message)
}
// SAFETY: the contract of the call to from_c says that this unsafe is sound.
DeviceLostClosureInner::C { mut inner } => unsafe {
if !inner.called {
inner.called = true;
// Ensure message is structured as a null-terminated C string. It only
// needs to live as long as the callback invocation.
let message = std::ffi::CString::new(message).unwrap();
(inner.callback)(inner.user_data, reason as u8, message.as_ptr())
if inner.called {
panic!("DeviceLostClosureC must only be called once.");
}
inner.called = true;

// Ensure message is structured as a null-terminated C string. It only
// needs to live as long as the callback invocation.
let message = std::ffi::CString::new(message).unwrap();
(inner.callback)(inner.user_data, reason as u8, message.as_ptr())
},
}
}
Expand Down
4 changes: 2 additions & 2 deletions wgpu/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1210,15 +1210,15 @@ pub type SubmittedWorkDoneCallback = Box<dyn FnOnce() + 'static>;
not(target_feature = "atomics")
)
))]
pub type DeviceLostCallback = Box<dyn FnOnce(DeviceLostReason, String) + Send + 'static>;
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + Send + 'static>;
#[cfg(not(any(
not(target_arch = "wasm32"),
all(
feature = "fragile-send-sync-non-atomic-wasm",
not(target_feature = "atomics")
)
)))]
pub type DeviceLostCallback = Box<dyn FnOnce(DeviceLostReason, String) + 'static>;
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + 'static>;

/// An object safe variant of [`Context`] implemented by all types that implement [`Context`].
pub(crate) trait DynContext: Debug + WasmNotSendSync {
Expand Down
2 changes: 1 addition & 1 deletion wgpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2923,7 +2923,7 @@ impl Device {
/// Set a DeviceLostCallback on this device.
pub fn set_device_lost_callback(
&self,
callback: impl FnOnce(DeviceLostReason, String) + Send + 'static,
callback: impl Fn(DeviceLostReason, String) + Send + 'static,
) {
DynContext::device_set_device_lost_callback(
&*self.context,
Expand Down

0 comments on commit ca59886

Please sign in to comment.