Skip to content

Commit

Permalink
Propagate errors when openning/closing a command encoder (#4999)
Browse files Browse the repository at this point in the history
  • Loading branch information
nical authored Jan 6, 2024
1 parent 5e06baf commit 8358868
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 58 deletions.
7 changes: 5 additions & 2 deletions wgpu-core/src/command/clear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::device::trace::Command as TraceCommand;
use crate::{
api_log,
command::CommandBuffer,
device::DeviceError,
get_lowest_common_denom,
global::Global,
hal_api::HalApi,
Expand Down Expand Up @@ -66,6 +67,8 @@ whereas subesource range specified start {subresource_base_array_layer} and coun
subresource_base_array_layer: u32,
subresource_array_layer_count: Option<u32>,
},
#[error(transparent)]
Device(#[from] DeviceError),
}

impl<G: GlobalIdentityHandlerFactory> Global<G> {
Expand Down Expand Up @@ -149,7 +152,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {

// actual hal barrier & operation
let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer, &snatch_guard));
let cmd_buf_raw = cmd_buf_data.encoder.open();
let cmd_buf_raw = cmd_buf_data.encoder.open()?;
unsafe {
cmd_buf_raw.transition_buffers(dst_barrier.into_iter());
cmd_buf_raw.clear_buffer(dst_raw, offset..end);
Expand Down Expand Up @@ -228,7 +231,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
if !device.is_valid() {
return Err(ClearError::InvalidDevice(cmd_buf.device.as_info().id()));
}
let (encoder, tracker) = cmd_buf_data.open_encoder_and_tracker();
let (encoder, tracker) = cmd_buf_data.open_encoder_and_tracker()?;

clear_texture(
&dst_texture,
Expand Down
21 changes: 12 additions & 9 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::device::DeviceError;
use crate::resource::Resource;
use crate::snatch::SnatchGuard;
use crate::{
Expand Down Expand Up @@ -186,6 +187,8 @@ pub enum DispatchError {
/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
#[error(transparent)]
Device(#[from] DeviceError),
#[error(transparent)]
Encoder(#[from] CommandEncoderError),
#[error("Bind group at index {0:?} is invalid")]
Expand Down Expand Up @@ -366,17 +369,17 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
timestamp_writes: Option<&ComputePassTimestampWrites>,
) -> Result<(), ComputePassError> {
profiling::scope!("CommandEncoder::run_compute_pass");
let init_scope = PassErrorScope::Pass(encoder_id);
let pass_scope = PassErrorScope::Pass(encoder_id);

let hub = A::hub(self);

let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(init_scope)?;
let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(pass_scope)?;
let device = &cmd_buf.device;
if !device.is_valid() {
return Err(ComputePassErrorInner::InvalidDevice(
cmd_buf.device.as_info().id(),
))
.map_pass_err(init_scope);
.map_pass_err(pass_scope);
}

let mut cmd_buf_data = cmd_buf.data.lock();
Expand All @@ -399,10 +402,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
// We automatically keep extending command buffers over time, and because
// we want to insert a command buffer _before_ what we're about to record,
// we need to make sure to close the previous one.
encoder.close();
encoder.close().map_pass_err(pass_scope)?;
// will be reset to true if recording is done without errors
*status = CommandEncoderStatus::Error;
let raw = encoder.open();
let raw = encoder.open().map_pass_err(pass_scope)?;

let bind_group_guard = hub.bind_groups.read();
let pipeline_guard = hub.compute_pipelines.read();
Expand All @@ -426,7 +429,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.query_sets
.add_single(&*query_set_guard, tw.query_set)
.ok_or(ComputePassErrorInner::InvalidQuerySet(tw.query_set))
.map_pass_err(init_scope)?;
.map_pass_err(pass_scope)?;

// Unlike in render passes we can't delay resetting the query sets since
// there is no auxillary pass.
Expand Down Expand Up @@ -862,12 +865,12 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
*status = CommandEncoderStatus::Recording;

// Stop the current command buffer.
encoder.close();
encoder.close().map_pass_err(pass_scope)?;

// Create a new command buffer, which we will insert _before_ the body of the compute pass.
//
// Use that buffer to insert barriers and clear discarded images.
let transit = encoder.open();
let transit = encoder.open().map_pass_err(pass_scope)?;
fixup_discarded_surfaces(
pending_discard_init_fixups.into_iter(),
transit,
Expand All @@ -881,7 +884,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
&snatch_guard,
);
// Close the command buffer, and swap it with the previous.
encoder.close_and_swap();
encoder.close_and_swap().map_pass_err(pass_scope)?;

Ok(())
}
Expand Down
59 changes: 37 additions & 22 deletions wgpu-core/src/command/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub use self::{

use self::memory_init::CommandBufferTextureMemoryActions;

use crate::device::Device;
use crate::device::{Device, DeviceError};
use crate::error::{ErrorFormatter, PrettyError};
use crate::hub::Hub;
use crate::id::CommandBufferId;
Expand Down Expand Up @@ -58,20 +58,24 @@ pub(crate) struct CommandEncoder<A: HalApi> {
//TODO: handle errors better
impl<A: HalApi> CommandEncoder<A> {
/// Closes the live encoder
fn close_and_swap(&mut self) {
fn close_and_swap(&mut self) -> Result<(), DeviceError> {
if self.is_open {
self.is_open = false;
let new = unsafe { self.raw.end_encoding().unwrap() };
let new = unsafe { self.raw.end_encoding()? };
self.list.insert(self.list.len() - 1, new);
}

Ok(())
}

fn close(&mut self) {
fn close(&mut self) -> Result<(), DeviceError> {
if self.is_open {
self.is_open = false;
let cmd_buf = unsafe { self.raw.end_encoding().unwrap() };
let cmd_buf = unsafe { self.raw.end_encoding()? };
self.list.push(cmd_buf);
}

Ok(())
}

fn discard(&mut self) {
Expand All @@ -81,18 +85,21 @@ impl<A: HalApi> CommandEncoder<A> {
}
}

fn open(&mut self) -> &mut A::CommandEncoder {
fn open(&mut self) -> Result<&mut A::CommandEncoder, DeviceError> {
if !self.is_open {
self.is_open = true;
let label = self.label.as_deref();
unsafe { self.raw.begin_encoding(label).unwrap() };
unsafe { self.raw.begin_encoding(label)? };
}
&mut self.raw

Ok(&mut self.raw)
}

fn open_pass(&mut self, label: Option<&str>) {
fn open_pass(&mut self, label: Option<&str>) -> Result<(), DeviceError> {
self.is_open = true;
unsafe { self.raw.begin_encoding(label).unwrap() };
unsafe { self.raw.begin_encoding(label)? };

Ok(())
}
}

Expand All @@ -119,10 +126,13 @@ pub struct CommandBufferMutable<A: HalApi> {
}

impl<A: HalApi> CommandBufferMutable<A> {
pub(crate) fn open_encoder_and_tracker(&mut self) -> (&mut A::CommandEncoder, &mut Tracker<A>) {
let encoder = self.encoder.open();
pub(crate) fn open_encoder_and_tracker(
&mut self,
) -> Result<(&mut A::CommandEncoder, &mut Tracker<A>), DeviceError> {
let encoder = self.encoder.open()?;
let tracker = &mut self.trackers;
(encoder, tracker)

Ok((encoder, tracker))
}
}

Expand Down Expand Up @@ -401,6 +411,8 @@ pub enum CommandEncoderError {
Invalid,
#[error("Command encoder must be active")]
NotRecording,
#[error(transparent)]
Device(#[from] DeviceError),
}

impl<G: GlobalIdentityHandlerFactory> Global<G> {
Expand All @@ -419,12 +431,15 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let cmd_buf_data = cmd_buf_data.as_mut().unwrap();
match cmd_buf_data.status {
CommandEncoderStatus::Recording => {
cmd_buf_data.encoder.close();
cmd_buf_data.status = CommandEncoderStatus::Finished;
//Note: if we want to stop tracking the swapchain texture view,
// this is the place to do it.
log::trace!("Command buffer {:?}", encoder_id);
None
if let Err(e) = cmd_buf_data.encoder.close() {
Some(e.into())
} else {
cmd_buf_data.status = CommandEncoderStatus::Finished;
//Note: if we want to stop tracking the swapchain texture view,
// this is the place to do it.
log::trace!("Command buffer {:?}", encoder_id);
None
}
}
CommandEncoderStatus::Finished => Some(CommandEncoderError::NotRecording),
CommandEncoderStatus::Error => {
Expand Down Expand Up @@ -457,7 +472,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
list.push(TraceCommand::PushDebugGroup(label.to_string()));
}

let cmd_buf_raw = cmd_buf_data.encoder.open();
let cmd_buf_raw = cmd_buf_data.encoder.open()?;
if !self
.instance
.flags
Expand Down Expand Up @@ -494,7 +509,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.flags
.contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
{
let cmd_buf_raw = cmd_buf_data.encoder.open();
let cmd_buf_raw = cmd_buf_data.encoder.open()?;
unsafe {
cmd_buf_raw.insert_debug_marker(label);
}
Expand All @@ -520,7 +535,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
list.push(TraceCommand::PopDebugGroup);
}

let cmd_buf_raw = cmd_buf_data.encoder.open();
let cmd_buf_raw = cmd_buf_data.encoder.open()?;
if !self
.instance
.flags
Expand Down
7 changes: 5 additions & 2 deletions wgpu-core/src/command/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use hal::CommandEncoder as _;
use crate::device::trace::Command as TraceCommand;
use crate::{
command::{CommandBuffer, CommandEncoderError},
device::DeviceError,
global::Global,
hal_api::HalApi,
id::{self, Id, TypedId},
Expand Down Expand Up @@ -104,6 +105,8 @@ impl From<wgt::QueryType> for SimplifiedQueryType {
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum QueryError {
#[error(transparent)]
Device(#[from] DeviceError),
#[error(transparent)]
Encoder(#[from] CommandEncoderError),
#[error("Error encountered while trying to use queries")]
Expand Down Expand Up @@ -367,7 +370,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let encoder = &mut cmd_buf_data.encoder;
let tracker = &mut cmd_buf_data.trackers;

let raw_encoder = encoder.open();
let raw_encoder = encoder.open()?;

let query_set_guard = hub.query_sets.read();
let query_set = tracker
Expand Down Expand Up @@ -409,7 +412,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let encoder = &mut cmd_buf_data.encoder;
let tracker = &mut cmd_buf_data.trackers;
let buffer_memory_init_actions = &mut cmd_buf_data.buffer_memory_init_actions;
let raw_encoder = encoder.open();
let raw_encoder = encoder.open()?;

if destination_offset % wgt::QUERY_RESOLVE_BUFFER_ALIGNMENT != 0 {
return Err(QueryError::Resolve(ResolveError::BufferOffsetAlignment));
Expand Down
20 changes: 10 additions & 10 deletions wgpu-core/src/command/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1312,11 +1312,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.contains(wgt::InstanceFlags::DISCARD_HAL_LABELS);
let label = hal_label(base.label, self.instance.flags);

let init_scope = PassErrorScope::Pass(encoder_id);
let pass_scope = PassErrorScope::Pass(encoder_id);

let hub = A::hub(self);

let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(init_scope)?;
let cmd_buf = CommandBuffer::get_encoder(hub, encoder_id).map_pass_err(pass_scope)?;
let device = &cmd_buf.device;
let snatch_guard = device.snatchable_lock.read();

Expand All @@ -1336,7 +1336,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}

if !device.is_valid() {
return Err(DeviceError::Lost).map_pass_err(init_scope);
return Err(DeviceError::Lost).map_pass_err(pass_scope);
}

let encoder = &mut cmd_buf_data.encoder;
Expand All @@ -1349,10 +1349,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
// We automatically keep extending command buffers over time, and because
// we want to insert a command buffer _before_ what we're about to record,
// we need to make sure to close the previous one.
encoder.close();
encoder.close().map_pass_err(pass_scope)?;
// We will reset this to `Recording` if we succeed, acts as a fail-safe.
*status = CommandEncoderStatus::Error;
encoder.open_pass(label);
encoder.open_pass(label).map_pass_err(pass_scope)?;

let bundle_guard = hub.render_bundles.read();
let bind_group_guard = hub.bind_groups.read();
Expand Down Expand Up @@ -1383,7 +1383,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
&*texture_guard,
&*query_set_guard,
)
.map_pass_err(init_scope)?;
.map_pass_err(pass_scope)?;

tracker.set_size(
Some(&*buffer_guard),
Expand Down Expand Up @@ -2364,9 +2364,9 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {

log::trace!("Merging renderpass into cmd_buf {:?}", encoder_id);
let (trackers, pending_discard_init_fixups) =
info.finish(raw).map_pass_err(init_scope)?;
info.finish(raw).map_pass_err(pass_scope)?;

encoder.close();
encoder.close().map_pass_err(pass_scope)?;
(trackers, pending_discard_init_fixups)
};

Expand All @@ -2381,7 +2381,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let tracker = &mut cmd_buf_data.trackers;

{
let transit = encoder.open();
let transit = encoder.open().map_pass_err(pass_scope)?;

fixup_discarded_surfaces(
pending_discard_init_fixups.into_iter(),
Expand Down Expand Up @@ -2409,7 +2409,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}

*status = CommandEncoderStatus::Recording;
encoder.close_and_swap();
encoder.close_and_swap().map_pass_err(pass_scope)?;

Ok(())
}
Expand Down
Loading

0 comments on commit 8358868

Please sign in to comment.