Skip to content

Commit

Permalink
Use global static for web identifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
TheOnlyMrCat committed Oct 16, 2022
1 parent 67163bd commit 9fbd0ad
Showing 1 changed file with 42 additions and 117 deletions.
159 changes: 42 additions & 117 deletions wgpu/src/backend/web.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#![allow(clippy::type_complexity)]

use js_sys::Promise;
#[cfg(feature = "expose-ids")]
use std::sync::atomic::{self, AtomicU64};
use std::{
cell::RefCell,
fmt,
Expand Down Expand Up @@ -44,10 +46,20 @@ impl<T> crate::GlobalId for Identified<T> {

pub(crate) type Id = u64;

pub(crate) struct Context(
web_sys::Gpu,
#[cfg(feature = "expose-ids")] std::sync::atomic::AtomicU64,
);
#[cfg(feature = "expose-ids")]
static NEXT_ID: AtomicU64 = AtomicU64::new(0);

#[cfg(not(feature = "expose-ids"))]
fn create_identified<T>(value: T) -> Identified<T> {
Identified(value)
}

#[cfg(feature = "expose-ids")]
fn create_identified<T>(value: T) -> Identified<T> {
Identified(value, NEXT_ID.fetch_add(1, atomic::Ordering::Relaxed))
}

pub(crate) struct Context(web_sys::Gpu);
unsafe impl Send for Context {}
unsafe impl Sync for Context {}

Expand Down Expand Up @@ -114,36 +126,6 @@ impl<F, M> MakeSendFuture<F, M> {

unsafe impl<F, M> Send for MakeSendFuture<F, M> {}

pub(crate) struct MakeIdentifiedFuture<F, M> {
future: F,
map: M,
id: u64,
}

impl<F: Future, M: Fn(F::Output, u64) -> T, T> Future for MakeIdentifiedFuture<F, M> {
type Output = T;

fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
// This is safe because we have no Drop implementation to violate the Pin requirements and
// do not provide any means of moving the inner future.
unsafe {
let this = self.get_unchecked_mut();
match Pin::new_unchecked(&mut this.future).poll(cx) {
task::Poll::Ready(value) => task::Poll::Ready((this.map)(value, this.id)),
task::Poll::Pending => task::Poll::Pending,
}
}
}
}

impl<F, M> MakeIdentifiedFuture<F, M> {
fn new(future: F, map: M, id: u64) -> Self {
Self { future, map, id }
}
}

unsafe impl<F, M> Send for MakeIdentifiedFuture<F, M> {}

impl crate::ComputePassInner<Context> for ComputePass {
fn set_pipeline(&mut self, pipeline: &Identified<web_sys::GpuComputePipeline>) {
self.0.set_pipeline(&pipeline.0);
Expand Down Expand Up @@ -963,29 +945,15 @@ fn map_map_mode(mode: crate::MapMode) -> u32 {

type JsFutureResult = Result<wasm_bindgen::JsValue, wasm_bindgen::JsValue>;

#[cfg(not(feature = "expose-ids"))]
fn create_identified<T>(value: T, _id: u64) -> Identified<T> {
Identified(value)
}

#[cfg(feature = "expose-ids")]
fn create_identified<T>(value: T, id: u64) -> Identified<T> {
Identified(value, id)
}

fn future_request_adapter(
result: JsFutureResult,
id: u64,
) -> Option<Identified<web_sys::GpuAdapter>> {
fn future_request_adapter(result: JsFutureResult) -> Option<Identified<web_sys::GpuAdapter>> {
match result.and_then(wasm_bindgen::JsCast::dyn_into) {
Ok(adapter) => Some(create_identified(adapter, id)),
Ok(adapter) => Some(create_identified(adapter)),
Err(_) => None,
}
}

fn future_request_device(
result: JsFutureResult,
id: u64,
) -> Result<
(
Identified<web_sys::GpuDevice>,
Expand All @@ -998,10 +966,7 @@ fn future_request_device(
let device_id = web_sys::GpuDevice::from(js_value);
let queue_id = device_id.queue();

(
create_identified(device_id, id),
create_identified(queue_id, id + 1),
)
(create_identified(device_id), create_identified(queue_id))
})
.map_err(|_| crate::RequestDeviceError)
}
Expand Down Expand Up @@ -1058,32 +1023,6 @@ where
*rc_callback.borrow_mut() = Some((closure_success, closure_rejected, callback));
}

impl Context {
#[cfg(not(feature = "expose-ids"))]
fn create_identified<T>(&self, item: T) -> Identified<T> {
Identified(item)
}

#[cfg(not(feature = "expose-ids"))]
fn allocate_id(&self, _count: u64) -> u64 {
0
}

#[cfg(feature = "expose-ids")]
fn create_identified<T>(&self, item: T) -> Identified<T> {
Identified(
item,
self.1.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
)
}

#[cfg(feature = "expose-ids")]
fn allocate_id(&self, count: u64) -> u64 {
self.1
.fetch_add(count, std::sync::atomic::Ordering::Relaxed)
}
}

impl Context {
pub fn instance_create_surface_from_canvas(
&self,
Expand All @@ -1093,7 +1032,7 @@ impl Context {
Ok(Some(ctx)) => ctx.into(),
_ => panic!("expected to get context from canvas"),
};
self.create_identified(context.into())
create_identified(context.into())
}

pub fn instance_create_surface_from_offscreen_canvas(
Expand All @@ -1104,7 +1043,7 @@ impl Context {
Ok(Some(ctx)) => ctx.into(),
_ => panic!("expected to get context from canvas"),
};
self.create_identified(context.into())
create_identified(context.into())
}

pub fn queue_copy_external_image_to_texture(
Expand Down Expand Up @@ -1171,21 +1110,17 @@ impl crate::Context for Context {
type SurfaceOutputDetail = SurfaceOutputDetail;
type SubmissionIndex = SubmissionIndex;

type RequestAdapterFuture = MakeIdentifiedFuture<
type RequestAdapterFuture = MakeSendFuture<
wasm_bindgen_futures::JsFuture,
fn(JsFutureResult, u64) -> Option<Self::AdapterId>,
fn(JsFutureResult) -> Option<Self::AdapterId>,
>;
type RequestDeviceFuture = MakeIdentifiedFuture<
type RequestDeviceFuture = MakeSendFuture<
wasm_bindgen_futures::JsFuture,
fn(
JsFutureResult,
u64,
) -> Result<(Self::DeviceId, Self::QueueId), crate::RequestDeviceError>,
fn(JsFutureResult) -> Result<(Self::DeviceId, Self::QueueId), crate::RequestDeviceError>,
>;
type PopErrorScopeFuture =
MakeSendFuture<wasm_bindgen_futures::JsFuture, fn(JsFutureResult) -> Option<crate::Error>>;

#[cfg(not(feature = "expose-ids"))]
fn init(_backends: wgt::Backends) -> Self {
let global: Global = js_sys::global().unchecked_into();
let gpu = if !global.window().is_undefined() {
Expand All @@ -1203,14 +1138,6 @@ impl crate::Context for Context {
Context(gpu)
}

#[cfg(feature = "expose-ids")]
fn init(_backends: wgt::Backends) -> Self {
Context(
web_sys::window().unwrap().navigator().gpu(),
std::sync::atomic::AtomicU64::new(0),
)
}

fn instance_create_surface(
&self,
_display_handle: raw_window_handle::RawDisplayHandle,
Expand Down Expand Up @@ -1257,10 +1184,9 @@ impl crate::Context for Context {
mapped_options.power_preference(mapped_power_preference);
let adapter_promise = self.0.request_adapter_with_options(&mapped_options);

MakeIdentifiedFuture::new(
MakeSendFuture::new(
wasm_bindgen_futures::JsFuture::from(adapter_promise),
future_request_adapter,
self.allocate_id(1),
)
}

Expand Down Expand Up @@ -1329,10 +1255,9 @@ impl crate::Context for Context {

let device_promise = adapter.0.request_device_with_descriptor(&mapped_desc);

MakeIdentifiedFuture::new(
MakeSendFuture::new(
wasm_bindgen_futures::JsFuture::from(device_promise),
future_request_device,
self.allocate_id(2),
)
}

Expand Down Expand Up @@ -1461,7 +1386,7 @@ impl crate::Context for Context {
Self::SurfaceOutputDetail,
) {
(
Some(self.create_identified(surface.0.get_current_texture())),
Some(create_identified(surface.0.get_current_texture())),
wgt::SurfaceStatus::Good,
(),
)
Expand Down Expand Up @@ -1585,7 +1510,7 @@ impl crate::Context for Context {
if let Some(label) = desc.label {
descriptor.label(label);
}
self.create_identified(device.0.create_shader_module(&descriptor))
create_identified(device.0.create_shader_module(&descriptor))
}

fn device_create_bind_group_layout(
Expand Down Expand Up @@ -1683,7 +1608,7 @@ impl crate::Context for Context {
if let Some(label) = desc.label {
mapped_desc.label(label);
}
self.create_identified(device.0.create_bind_group_layout(&mapped_desc))
create_identified(device.0.create_bind_group_layout(&mapped_desc))
}

unsafe fn device_create_shader_module_spirv(
Expand Down Expand Up @@ -1741,7 +1666,7 @@ impl crate::Context for Context {
if let Some(label) = desc.label {
mapped_desc.label(label);
}
self.create_identified(device.0.create_bind_group(&mapped_desc))
create_identified(device.0.create_bind_group(&mapped_desc))
}

fn device_create_pipeline_layout(
Expand All @@ -1758,7 +1683,7 @@ impl crate::Context for Context {
if let Some(label) = desc.label {
mapped_desc.label(label);
}
self.create_identified(device.0.create_pipeline_layout(&mapped_desc))
create_identified(device.0.create_pipeline_layout(&mapped_desc))
}

fn device_create_render_pipeline(
Expand Down Expand Up @@ -1849,7 +1774,7 @@ impl crate::Context for Context {
let mapped_primitive = map_primitive_state(&desc.primitive);
mapped_desc.primitive(&mapped_primitive);

self.create_identified(device.0.create_render_pipeline(&mapped_desc))
create_identified(device.0.create_render_pipeline(&mapped_desc))
}

fn device_create_compute_pipeline(
Expand All @@ -1870,7 +1795,7 @@ impl crate::Context for Context {
if let Some(label) = desc.label {
mapped_desc.label(label);
}
self.create_identified(device.0.create_compute_pipeline(&mapped_desc))
create_identified(device.0.create_compute_pipeline(&mapped_desc))
}

fn device_create_buffer(
Expand All @@ -1884,7 +1809,7 @@ impl crate::Context for Context {
if let Some(label) = desc.label {
mapped_desc.label(label);
}
self.create_identified(device.0.create_buffer(&mapped_desc))
create_identified(device.0.create_buffer(&mapped_desc))
}

fn device_create_texture(
Expand All @@ -1903,7 +1828,7 @@ impl crate::Context for Context {
mapped_desc.dimension(map_texture_dimension(desc.dimension));
mapped_desc.mip_level_count(desc.mip_level_count);
mapped_desc.sample_count(desc.sample_count);
self.create_identified(device.0.create_texture(&mapped_desc))
create_identified(device.0.create_texture(&mapped_desc))
}

fn device_create_sampler(
Expand All @@ -1928,7 +1853,7 @@ impl crate::Context for Context {
if let Some(label) = desc.label {
mapped_desc.label(label);
}
self.create_identified(device.0.create_sampler_with_descriptor(&mapped_desc))
create_identified(device.0.create_sampler_with_descriptor(&mapped_desc))
}

fn device_create_query_set(
Expand All @@ -1945,7 +1870,7 @@ impl crate::Context for Context {
if let Some(label) = desc.label {
mapped_desc.label(label);
}
self.create_identified(device.0.create_query_set(&mapped_desc))
create_identified(device.0.create_query_set(&mapped_desc))
}

fn device_create_command_encoder(
Expand Down Expand Up @@ -2091,7 +2016,7 @@ impl crate::Context for Context {
if let Some(label) = desc.label {
mapped.label(label);
}
self.create_identified(texture.0.create_view_with_descriptor(&mapped))
create_identified(texture.0.create_view_with_descriptor(&mapped))
}

fn surface_drop(&self, _surface: &Self::SurfaceId) {
Expand Down Expand Up @@ -2171,15 +2096,15 @@ impl crate::Context for Context {
pipeline: &Self::ComputePipelineId,
index: u32,
) -> Self::BindGroupLayoutId {
self.create_identified(pipeline.0.get_bind_group_layout(index))
create_identified(pipeline.0.get_bind_group_layout(index))
}

fn render_pipeline_get_bind_group_layout(
&self,
pipeline: &Self::RenderPipelineId,
index: u32,
) -> Self::BindGroupLayoutId {
self.create_identified(pipeline.0.get_bind_group_layout(index))
create_identified(pipeline.0.get_bind_group_layout(index))
}

fn command_encoder_copy_buffer_to_buffer(
Expand Down Expand Up @@ -2428,7 +2353,7 @@ impl crate::Context for Context {
encoder: Self::RenderBundleEncoderId,
desc: &crate::RenderBundleDescriptor,
) -> Self::RenderBundleId {
self.create_identified(match desc.label {
create_identified(match desc.label {
Some(label) => {
let mut mapped_desc = web_sys::GpuRenderBundleDescriptor::new();
mapped_desc.label(label);
Expand Down

0 comments on commit 9fbd0ad

Please sign in to comment.