diff --git a/wgpu/src/backend/webgpu.rs b/wgpu/src/backend/webgpu.rs index 2185d5b8b8f..6facfc8ef56 100644 --- a/wgpu/src/backend/webgpu.rs +++ b/wgpu/src/backend/webgpu.rs @@ -7,6 +7,7 @@ use js_sys::Promise; use std::{ any::Any, cell::RefCell, + collections::HashMap, fmt, future::Future, marker::PhantomData, @@ -1874,6 +1875,11 @@ impl crate::context::Context for ContextWebGpu { let module: &::ShaderModuleData = downcast_ref(desc.vertex.module.data.as_ref()); let mut mapped_vertex_state = webgpu_sys::GpuVertexState::new(&module.0.module); + let _ = js_sys::Reflect::set( + &mapped_vertex_state, + &"constants".into(), + &hashmap_to_jsvalue(desc.vertex.compilation_options.constants), + ); mapped_vertex_state.entry_point(desc.vertex.entry_point); let buffers = desc @@ -1950,6 +1956,11 @@ impl crate::context::Context for ContextWebGpu { downcast_ref(frag.module.data.as_ref()); let mut mapped_fragment_desc = webgpu_sys::GpuFragmentState::new(&module.0.module, &targets); + let _ = js_sys::Reflect::set( + &mapped_fragment_desc, + &"constants".into(), + &hashmap_to_jsvalue(frag.compilation_options.constants), + ); mapped_fragment_desc.entry_point(frag.entry_point); mapped_desc.fragment(&mapped_fragment_desc); } @@ -1976,6 +1987,11 @@ impl crate::context::Context for ContextWebGpu { downcast_ref(desc.module.data.as_ref()); let mut mapped_compute_stage = webgpu_sys::GpuProgrammableStage::new(&shader_module.0.module); + let _ = js_sys::Reflect::set( + &mapped_compute_stage, + &"constants".into(), + &hashmap_to_jsvalue(desc.compilation_options.constants), + ); mapped_compute_stage.entry_point(desc.entry_point); let auto_layout = wasm_bindgen::JsValue::from(webgpu_sys::GpuAutoLayoutMode::Auto); let mut mapped_desc = webgpu_sys::GpuComputePipelineDescriptor::new( @@ -1992,6 +2008,7 @@ impl crate::context::Context for ContextWebGpu { if let Some(label) = desc.label { mapped_desc.label(label); } + create_identified(device_data.0.create_compute_pipeline(&mapped_desc)) } @@ -3808,3 +3825,14 @@ impl Drop for BufferMappedRange { } } } + +/// Converts a hashmap to a Javascript object. +fn hashmap_to_jsvalue(map: &HashMap) -> JsValue { + let obj = js_sys::Object::new(); + + for (k, v) in map.iter() { + let _ = js_sys::Reflect::set(&obj, &k.into(), &(*v).into()); + } + + JsValue::from(obj) +}