Skip to content

Commit

Permalink
Validate that resources belong to the right device.
Browse files Browse the repository at this point in the history
  • Loading branch information
nical committed Oct 4, 2023
1 parent 32b761a commit 487d745
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 10 deletions.
1 change: 0 additions & 1 deletion tests/tests/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ fn device_initialization() {
}

#[test]
#[ignore]
fn device_mismatch() {
initialize_test(
// https://github.com/gfx-rs/wgpu/issues/3927
Expand Down
45 changes: 36 additions & 9 deletions wgpu-core/src/command/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@ use crate::{
RenderCommand, RenderCommandError, StateChange,
},
device::{
AttachmentData, Device, MissingDownlevelFlags, MissingFeatures,
AttachmentData, Device, DeviceError, MissingDownlevelFlags, MissingFeatures,
RenderPassCompatibilityCheckType, RenderPassCompatibilityError, RenderPassContext,
},
error::{ErrorFormatter, PrettyError},
global::Global,
hal_api::HalApi,
hub::Token,
id,
id::DeviceId,
identity::GlobalIdentityHandlerFactory,
init_tracker::{MemoryInitKind, TextureInitRange, TextureInitTrackerAction},
pipeline::{self, PipelineFlags},
Expand Down Expand Up @@ -520,12 +519,12 @@ pub enum ColorAttachmentError {
/// Error encountered when performing a render pass.
#[derive(Clone, Debug, Error)]
pub enum RenderPassErrorInner {
#[error(transparent)]
Device(DeviceError),
#[error(transparent)]
ColorAttachment(#[from] ColorAttachmentError),
#[error(transparent)]
Encoder(#[from] CommandEncoderError),
#[error("Device {0:?} is invalid")]
InvalidDevice(DeviceId),
#[error("Attachment texture view {0:?} is invalid")]
InvalidAttachment(id::TextureViewId),
#[error("The format of the depth-stencil attachment ({0:?}) is not a depth-stencil format")]
Expand Down Expand Up @@ -658,6 +657,12 @@ impl From<MissingTextureUsageError> for RenderPassErrorInner {
}
}

impl From<DeviceError> for RenderPassErrorInner {
fn from(error: DeviceError) -> Self {
Self::Device(error)
}
}

/// Error encountered when performing a render pass.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
Expand Down Expand Up @@ -1351,12 +1356,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
});
}

let device = &device_guard[cmd_buf.device_id.value];
let device_id = cmd_buf.device_id.value;

let device = &device_guard[device_id];
if !device.is_valid() {
return Err(RenderPassErrorInner::InvalidDevice(
cmd_buf.device_id.value.0,
))
.map_pass_err(init_scope);
return Err(DeviceError::Invalid).map_pass_err(init_scope);
}
cmd_buf.encoder.open_pass(base.label);

Expand Down Expand Up @@ -1451,6 +1455,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.add_single(&*bind_group_guard, bind_group_id)
.ok_or(RenderCommandError::InvalidBindGroup(bind_group_id))
.map_pass_err(scope)?;

if bind_group.device_id.value != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

bind_group
.validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits)
.map_pass_err(scope)?;
Expand Down Expand Up @@ -1518,6 +1527,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.ok_or(RenderCommandError::InvalidPipeline(pipeline_id))
.map_pass_err(scope)?;

if pipeline.device_id.value != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

info.context
.check_compatible(
&pipeline.pass_context,
Expand Down Expand Up @@ -1635,6 +1648,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.buffers
.merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDEX)
.map_pass_err(scope)?;

if buffer.device_id.value != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

check_buffer_usage(buffer.usage, BufferUsages::INDEX)
.map_pass_err(scope)?;
let buf_raw = buffer
Expand Down Expand Up @@ -1683,6 +1701,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.buffers
.merge_single(&*buffer_guard, buffer_id, hal::BufferUses::VERTEX)
.map_pass_err(scope)?;

if buffer.device_id.value != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

check_buffer_usage(buffer.usage, BufferUsages::VERTEX)
.map_pass_err(scope)?;
let buf_raw = buffer
Expand Down Expand Up @@ -2265,6 +2288,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.ok_or(RenderCommandError::InvalidRenderBundle(bundle_id))
.map_pass_err(scope)?;

if bundle.device_id.value != device_id {
return Err(DeviceError::WrongDevice).map_pass_err(scope);
}

info.context
.check_compatible(
&bundle.context,
Expand Down
4 changes: 4 additions & 0 deletions wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
Err(..) => break binding_model::CreateBindGroupError::InvalidLayout,
};

if bind_group_layout.device_id.value.0 != device_id {
break DeviceError::WrongDevice.into();
}

let mut layout_id = id::Valid(desc.layout);
if let Some(id) = bind_group_layout.as_duplicate() {
layout_id = id;
Expand Down
2 changes: 2 additions & 0 deletions wgpu-core/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ pub enum DeviceError {
OutOfMemory,
#[error("Creation of a resource failed for a reason other than running out of memory.")]
ResourceCreationFailed,
#[error("Attempt to use a resource with a different device from the one that created it")]
WrongDevice,
}

impl From<hal::DeviceError> for DeviceError {
Expand Down
16 changes: 16 additions & 0 deletions wgpu-core/src/device/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}

let result = self.queue_write_staging_buffer_impl(
queue_id,
device,
device_token,
&staging_buffer,
Expand Down Expand Up @@ -464,6 +465,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}

let result = self.queue_write_staging_buffer_impl(
queue_id,
device,
device_token,
&staging_buffer,
Expand Down Expand Up @@ -531,6 +533,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {

fn queue_write_staging_buffer_impl<A: HalApi>(
&self,
device_id: id::DeviceId,
device: &mut super::Device<A>,
device_token: &mut Token<super::Device<A>>,
staging_buffer: &StagingBuffer<A>,
Expand All @@ -551,6 +554,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.as_ref()
.ok_or(TransferError::InvalidBuffer(buffer_id))?;

if dst.device_id.value.0 != device_id {
return Err(DeviceError::WrongDevice.into());
}

let src_buffer_size = staging_buffer.size;
self.queue_validate_write_buffer_impl(dst, buffer_id, buffer_offset, src_buffer_size)?;

Expand Down Expand Up @@ -627,6 +634,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.get_mut(destination.texture)
.map_err(|_| TransferError::InvalidTexture(destination.texture))?;

if dst.device_id.value.0 != queue_id {
return Err(DeviceError::WrongDevice.into());
}

if !dst.desc.usage.contains(wgt::TextureUsages::COPY_DST) {
return Err(
TransferError::MissingCopyDstUsageFlag(None, Some(destination.texture)).into(),
Expand Down Expand Up @@ -1105,6 +1116,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
Some(cmdbuf) => cmdbuf,
None => continue,
};

if cmdbuf.device_id.value.0 != queue_id {
return Err(DeviceError::WrongDevice.into());
}

#[cfg(feature = "trace")]
if let Some(ref trace) = device.trace {
trace.lock().add(Action::Submit(
Expand Down
47 changes: 47 additions & 0 deletions wgpu-core/src/device/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1670,6 +1670,7 @@ impl<A: HalApi> Device<A> {
}

fn create_buffer_binding<'a>(
device_id: id::DeviceId,
bb: &binding_model::BufferBinding,
binding: u32,
decl: &wgt::BindGroupLayoutEntry,
Expand Down Expand Up @@ -1727,6 +1728,11 @@ impl<A: HalApi> Device<A> {
.buffers
.add_single(storage, bb.buffer_id, internal_use)
.ok_or(Error::InvalidBuffer(bb.buffer_id))?;

if buffer.device_id.value.0 != device_id {
return Err(DeviceError::WrongDevice.into());
}

check_buffer_usage(buffer.usage, pub_usage)?;
let raw_buffer = buffer
.raw
Expand Down Expand Up @@ -1797,6 +1803,7 @@ impl<A: HalApi> Device<A> {
}

fn create_texture_binding(
device_id: id::DeviceId,
view: &resource::TextureView<A>,
texture_guard: &Storage<resource::Texture<A>, id::TextureId>,
internal_use: hal::TextureUses,
Expand All @@ -1818,6 +1825,11 @@ impl<A: HalApi> Device<A> {
.ok_or(binding_model::CreateBindGroupError::InvalidTexture(
view.parent_id.value.0,
))?;

if texture.device_id.value.0 != device_id {
return Err(DeviceError::WrongDevice.into());
}

check_texture_usage(texture.desc.usage, pub_usage)?;

used_texture_ranges.push(TextureInitTrackerAction {
Expand Down Expand Up @@ -1889,6 +1901,7 @@ impl<A: HalApi> Device<A> {
let (res_index, count) = match entry.resource {
Br::Buffer(ref bb) => {
let bb = Self::create_buffer_binding(
self_id,
bb,
binding,
decl,
Expand All @@ -1911,6 +1924,7 @@ impl<A: HalApi> Device<A> {
let res_index = hal_buffers.len();
for bb in bindings_array.iter() {
let bb = Self::create_buffer_binding(
self_id,
bb,
binding,
decl,
Expand All @@ -1933,6 +1947,10 @@ impl<A: HalApi> Device<A> {
.add_single(&*sampler_guard, id)
.ok_or(Error::InvalidSampler(id))?;

if sampler.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

// Allowed sampler values for filtering and comparison
let (allowed_filtering, allowed_comparison) = match ty {
wgt::SamplerBindingType::Filtering => (None, false),
Expand Down Expand Up @@ -1981,6 +1999,11 @@ impl<A: HalApi> Device<A> {
.samplers
.add_single(&*sampler_guard, id)
.ok_or(Error::InvalidSampler(id))?;

if sampler.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

hal_samplers.push(&sampler.raw);
}

Expand All @@ -1998,6 +2021,7 @@ impl<A: HalApi> Device<A> {
"SampledTexture, ReadonlyStorageTexture or WriteonlyStorageTexture",
)?;
Self::create_texture_binding(
self_id,
view,
&texture_guard,
internal_use,
Expand Down Expand Up @@ -2026,6 +2050,7 @@ impl<A: HalApi> Device<A> {
Self::texture_use_parameters(binding, decl, view,
"SampledTextureArray, ReadonlyStorageTextureArray or WriteonlyStorageTextureArray")?;
Self::create_texture_binding(
self_id,
view,
&texture_guard,
internal_use,
Expand Down Expand Up @@ -2324,6 +2349,11 @@ impl<A: HalApi> Device<A> {
let Some(bind_group_layout) = try_get_bind_group_layout(bgl_guard, id) else {
return Err(Error::InvalidBindGroupLayout(id));
};

if bind_group_layout.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

count_validator.merge(&bind_group_layout.assume_deduplicated().count_validator);
}
count_validator
Expand Down Expand Up @@ -2457,6 +2487,10 @@ impl<A: HalApi> Device<A> {
.get(desc.stage.module)
.map_err(|_| validation::StageError::InvalidModule)?;

if shader_module.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

{
let flag = wgt::ShaderStages::COMPUTE;
let provided_layouts = match desc.layout {
Expand Down Expand Up @@ -2500,6 +2534,10 @@ impl<A: HalApi> Device<A> {
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?;

if layout.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

let late_sized_buffer_groups =
Device::make_late_sized_buffer_groups(&shader_binding_sizes, layout, &*bgl_guard);

Expand Down Expand Up @@ -2843,11 +2881,20 @@ impl<A: HalApi> Device<A> {
}
})?;

if shader_module.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

let provided_layouts = match desc.layout {
Some(pipeline_layout_id) => {
let pipeline_layout = pipeline_layout_guard
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)?;

if pipeline_layout.device_id.value.0 != self_id {
return Err(DeviceError::WrongDevice.into());
}

Some(Device::get_introspection_bind_group_layouts(
pipeline_layout,
&*bgl_guard,
Expand Down

0 comments on commit 487d745

Please sign in to comment.