Skip to content

Commit

Permalink
Add support for pipeline-overridable constants in WebGPU
Browse files Browse the repository at this point in the history
  • Loading branch information
DouglasDwyer committed May 11, 2024
1 parent d0a5e48 commit 331cc11
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions wgpu/src/backend/webgpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use js_sys::Promise;
use std::{
any::Any,
cell::RefCell,
collections::HashMap,
fmt,
future::Future,
marker::PhantomData,
Expand Down Expand Up @@ -1874,6 +1875,11 @@ impl crate::context::Context for ContextWebGpu {
let module: &<ContextWebGpu as crate::Context>::ShaderModuleData =
downcast_ref(desc.vertex.module.data.as_ref());
let mut mapped_vertex_state = webgpu_sys::GpuVertexState::new(&module.0.module);
let _ = js_sys::Reflect::set(
&mapped_vertex_state,
&"constants".into(),
&hashmap_to_jsvalue(desc.vertex.compilation_options.constants),
);
mapped_vertex_state.entry_point(desc.vertex.entry_point);

let buffers = desc
Expand Down Expand Up @@ -1950,6 +1956,11 @@ impl crate::context::Context for ContextWebGpu {
downcast_ref(frag.module.data.as_ref());
let mut mapped_fragment_desc =
webgpu_sys::GpuFragmentState::new(&module.0.module, &targets);
let _ = js_sys::Reflect::set(
&mapped_fragment_desc,
&"constants".into(),
&hashmap_to_jsvalue(frag.compilation_options.constants),
);
mapped_fragment_desc.entry_point(frag.entry_point);
mapped_desc.fragment(&mapped_fragment_desc);
}
Expand All @@ -1976,6 +1987,11 @@ impl crate::context::Context for ContextWebGpu {
downcast_ref(desc.module.data.as_ref());
let mut mapped_compute_stage =
webgpu_sys::GpuProgrammableStage::new(&shader_module.0.module);
let _ = js_sys::Reflect::set(
&mapped_compute_stage,
&"constants".into(),
&hashmap_to_jsvalue(desc.compilation_options.constants),
);
mapped_compute_stage.entry_point(desc.entry_point);
let auto_layout = wasm_bindgen::JsValue::from(webgpu_sys::GpuAutoLayoutMode::Auto);
let mut mapped_desc = webgpu_sys::GpuComputePipelineDescriptor::new(
Expand All @@ -1992,6 +2008,7 @@ impl crate::context::Context for ContextWebGpu {
if let Some(label) = desc.label {
mapped_desc.label(label);
}

create_identified(device_data.0.create_compute_pipeline(&mapped_desc))
}

Expand Down Expand Up @@ -3808,3 +3825,14 @@ impl Drop for BufferMappedRange {
}
}
}

/// Converts a hashmap to a Javascript object.
fn hashmap_to_jsvalue(map: &HashMap<String, f64>) -> JsValue {
let obj = js_sys::Object::new();

for (k, v) in map.iter() {
let _ = js_sys::Reflect::set(&obj, &k.into(), &(*v).into());
}

JsValue::from(obj)
}

0 comments on commit 331cc11

Please sign in to comment.