This repository has been archived by the owner on Jan 29, 2025. It is now read-only.

Add support for zero-initializing workgroup memory #2111

Merged · 5 commits · Jan 25, 2023
1 change: 1 addition & 0 deletions benches/criterion.rs
@@ -242,6 +242,7 @@ fn backends(c: &mut Criterion) {
version: naga::back::glsl::Version::new_gles(320),
writer_flags: naga::back::glsl::WriterFlags::empty(),
binding_map: Default::default(),
zero_initialize_workgroup_memory: true,
};
for &(ref module, ref info) in inputs.iter() {
for ep in module.entry_points.iter() {
74 changes: 61 additions & 13 deletions src/back/glsl/mod.rs
@@ -232,6 +232,8 @@ pub struct Options {
pub writer_flags: WriterFlags,
/// Map of resources association to binding locations.
pub binding_map: BindingMap,
/// Should workgroup variables be zero initialized (by polyfilling)?
pub zero_initialize_workgroup_memory: bool,
}

impl Default for Options {
@@ -240,6 +242,7 @@ impl Default for Options {
version: Version::new_gles(310),
writer_flags: WriterFlags::ADJUST_COORDINATE_SPACE,
binding_map: BindingMap::default(),
zero_initialize_workgroup_memory: true,
}
}
}
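
The flag defaults to true, matching the benchmark update above. A minimal sketch (not part of this diff) of a caller overriding it, assuming the naga::back::glsl::Options path used in benches/criterion.rs:

fn glsl_options_without_zero_init() -> naga::back::glsl::Options {
    // The polyfill is enabled by Default::default(); a caller can opt out
    // explicitly when its shaders do not rely on zero-initialized memory.
    naga::back::glsl::Options {
        zero_initialize_workgroup_memory: false,
        ..Default::default()
    }
}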
@@ -1432,6 +1435,12 @@ impl<'a, W: Write> Writer<'a, W> {
// Close the parentheses and open braces to start the function body
writeln!(self.out, ") {{")?;

if self.options.zero_initialize_workgroup_memory
&& ctx.ty.is_compute_entry_point(self.module)
{
self.write_workgroup_variables_initialization(&ctx)?;
}

// Compose the function arguments from globals, in case of an entry point.
if let back::FunctionType::EntryPoint(ep_index) = ctx.ty {
let stage = self.module.entry_points[ep_index as usize].stage;
@@ -1520,6 +1529,42 @@ impl<'a, W: Write> Writer<'a, W> {
Ok(())
}

fn write_workgroup_variables_initialization(
&mut self,
ctx: &back::FunctionCtx,
) -> BackendResult {
let mut vars = self
.module
.global_variables
.iter()
.filter(|&(handle, var)| {
!ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
})
.peekable();

if vars.peek().is_some() {
let level = back::Level(1);

writeln!(
self.out,
"{}if (gl_GlobalInvocationID == uvec3(0u)) {{",
level
)?;

for (handle, var) in vars {
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{}{} = ", level.next(), name)?;
self.write_zero_init_value(var.ty)?;
writeln!(self.out, ";")?;
}

writeln!(self.out, "{}}}", level)?;
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
}

Ok(())
}
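
For a compute entry point that uses at least one workgroup variable, the function above emits a prologue of roughly this shape (a sketch; data_ is a hypothetical variable name, and its right-hand side is whatever write_zero_init_value produces for that variable's type):

    // data_ is a hypothetical workgroup variable; the invocation with a zero
    // gl_GlobalInvocationID performs the stores, then write_barrier emits the
    // workgroup barrier sequence.
    if (gl_GlobalInvocationID == uvec3(0u)) {
        data_ = 0.0;
    }
    memoryBarrierShared();
    barrier();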

/// Helper method that writes a list of comma separated `T` with a writer function `F`
///
/// The writer function `F` receives a mutable reference to `self` that if needed won't cause
@@ -2035,18 +2080,8 @@ impl<'a, W: Write> Writer<'a, W> {
// keyword which ceases all further processing in a fragment shader, it's called OpKill
// in spir-v that's why it's called `Statement::Kill`
Statement::Kill => writeln!(self.out, "{}discard;", level)?,
// Issue a memory barrier. Please note that to ensure visibility,
// OpenGL always requires a call to the `barrier()` function after a `memoryBarrier*()`
Statement::Barrier(flags) => {
if flags.contains(crate::Barrier::STORAGE) {
writeln!(self.out, "{}memoryBarrierBuffer();", level)?;
}

if flags.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{}memoryBarrierShared();", level)?;
}

writeln!(self.out, "{}barrier();", level)?;
self.write_barrier(flags, level)?;
}
// Stores in glsl are just variable assignments written as `pointer = value;`
Statement::Store { pointer, value } => {
Expand Down Expand Up @@ -3558,7 +3593,7 @@ impl<'a, W: Write> Writer<'a, W> {
fn write_zero_init_value(&mut self, ty: Handle<crate::Type>) -> BackendResult {
let inner = &self.module.types[ty].inner;
match *inner {
TypeInner::Scalar { kind, .. } => {
TypeInner::Scalar { kind, .. } | TypeInner::Atomic { kind, .. } => {
self.write_zero_init_scalar(kind)?;
}
TypeInner::Vector { kind, .. } => {
Expand Down Expand Up @@ -3603,7 +3638,7 @@ impl<'a, W: Write> Writer<'a, W> {
}
write!(self.out, ")")?;
}
_ => {} // TODO:
_ => unreachable!(),
}

Ok(())
@@ -3621,6 +3656,19 @@ impl<'a, W: Write> Writer<'a, W> {
Ok(())
}

/// Issue a memory barrier. Please note that to ensure visibility,
/// OpenGL always requires a call to the `barrier()` function after a `memoryBarrier*()`
fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
if flags.contains(crate::Barrier::STORAGE) {
writeln!(self.out, "{}memoryBarrierBuffer();", level)?;
}
if flags.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{}memoryBarrierShared();", level)?;
}
writeln!(self.out, "{}barrier();", level)?;
Ok(())
}

/// Helper function that return the glsl storage access string of [`StorageAccess`](crate::StorageAccess)
///
/// glsl allows adding both `readonly` and `writeonly` but this means that
3 changes: 3 additions & 0 deletions src/back/hlsl/mod.rs
@@ -191,6 +191,8 @@ pub struct Options {
pub special_constants_binding: Option<BindTarget>,
/// Bind target of the push constant buffer
pub push_constants_target: Option<BindTarget>,
/// Should workgroup variables be zero initialized (by polyfilling)?
pub zero_initialize_workgroup_memory: bool,
}

impl Default for Options {
@@ -201,6 +203,7 @@ impl Default for Options {
fake_missing_bindings: true,
special_constants_binding: None,
push_constants_target: None,
zero_initialize_workgroup_memory: true,
}
}
}
101 changes: 73 additions & 28 deletions src/back/hlsl/writer.rs
@@ -1077,6 +1077,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// Write function name
write!(self.out, " {}(", name)?;

let need_workgroup_variables_initialization =
self.need_workgroup_variables_initialization(func_ctx, module);

// Write function arguments for non entry point functions
match func_ctx.ty {
back::FunctionType::Function(handle) => {
@@ -1129,6 +1132,16 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_semantic(binding, Some((stage, Io::Input)))?;
}
}

if need_workgroup_variables_initialization {
if !func.arguments.is_empty() {
write!(self.out, ", ")?;
}
write!(
self.out,
"uint3 __global_invocation_id : SV_DispatchThreadID"
)?;
}
}
}
}
@@ -1151,6 +1164,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out)?;
writeln!(self.out, "{{")?;

if need_workgroup_variables_initialization {
self.write_workgroup_variables_initialization(func_ctx, module)?;
}

if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
self.write_ep_arguments_initialization(module, func, index)?;
}
@@ -1204,6 +1221,46 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Ok(())
}

fn need_workgroup_variables_initialization(
&mut self,
func_ctx: &back::FunctionCtx,
module: &Module,
) -> bool {
self.options.zero_initialize_workgroup_memory
&& func_ctx.ty.is_compute_entry_point(module)
&& module.global_variables.iter().any(|(handle, var)| {
!func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
})
}

fn write_workgroup_variables_initialization(
&mut self,
func_ctx: &back::FunctionCtx,
module: &Module,
) -> BackendResult {
let level = back::Level(1);

writeln!(
self.out,
"{}if (all(__global_invocation_id == uint3(0u, 0u, 0u))) {{",
level
)?;

let vars = module.global_variables.iter().filter(|&(handle, var)| {
!func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
});

for (handle, var) in vars {
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{}{} = ", level.next(), name)?;
self.write_default_init(module, var.ty)?;
writeln!(self.out, ";")?;
}

writeln!(self.out, "{}}}", level)?;
self.write_barrier(crate::Barrier::WORK_GROUP, level)
}
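
Since HLSL exposes no built-in equivalent of gl_GlobalInvocationID inside the function body, the entry point gains an extra SV_DispatchThreadID parameter (wired up earlier in this file), and the emitted prologue looks roughly like this (a sketch; data_ is a hypothetical workgroup variable whose zero value comes from write_default_init):

    // data_ is a hypothetical workgroup variable; the invocation whose
    // __global_invocation_id is (0, 0, 0) performs the stores, followed by
    // the group barrier emitted by write_barrier.
    if (all(__global_invocation_id == uint3(0u, 0u, 0u))) {
        data_ = (float)0;
    }
    GroupMemoryBarrierWithGroupSync();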

/// Helper method used to write statements
///
/// # Notes
@@ -1690,13 +1747,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Statement::Break => writeln!(self.out, "{}break;", level)?,
Statement::Continue => writeln!(self.out, "{}continue;", level)?,
Statement::Barrier(barrier) => {
if barrier.contains(crate::Barrier::STORAGE) {
writeln!(self.out, "{}DeviceMemoryBarrierWithGroupSync();", level)?;
}

if barrier.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{}GroupMemoryBarrierWithGroupSync();", level)?;
}
self.write_barrier(barrier, level)?;
}
Statement::ImageStore {
image,
@@ -2848,27 +2899,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {

/// Helper function that write default zero initialization
fn write_default_init(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
match module.types[ty].inner {
TypeInner::Array {
size: crate::ArraySize::Constant(const_handle),
base,
..
} => {
write!(self.out, "{{")?;
let count = module.constants[const_handle].to_array_length().unwrap();
for i in 0..count {
if i != 0 {
write!(self.out, ",")?;
}
self.write_default_init(module, base)?;
}
write!(self.out, "}}")?;
}
_ => {
write!(self.out, "(")?;
self.write_type(module, ty)?;
write!(self.out, ")0")?;
}
write!(self.out, "(")?;
self.write_type(module, ty)?;
if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
self.write_array_size(module, base, size)?;
}
write!(self.out, ")0")?;
Ok(())
}

fn write_barrier(&mut self, barrier: crate::Barrier, level: back::Level) -> BackendResult {
if barrier.contains(crate::Barrier::STORAGE) {
writeln!(self.out, "{}DeviceMemoryBarrierWithGroupSync();", level)?;
}
if barrier.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{}GroupMemoryBarrierWithGroupSync();", level)?;
}
Ok(())
}
11 changes: 11 additions & 0 deletions src/back/mod.rs
@@ -47,6 +47,17 @@ enum FunctionType {
EntryPoint(crate::proc::EntryPointIndex),
}

impl FunctionType {
fn is_compute_entry_point(&self, module: &crate::Module) -> bool {
match *self {
FunctionType::EntryPoint(index) => {
module.entry_points[index as usize].stage == crate::ShaderStage::Compute
}
_ => false,
}
}
}

/// Helper structure that stores data needed when writing the function
struct FunctionCtx<'a> {
/// The current function being written
3 changes: 3 additions & 0 deletions src/back/msl/mod.rs
@@ -209,6 +209,8 @@ pub struct Options {
/// Bounds checking policies.
#[cfg_attr(feature = "deserialize", serde(default))]
pub bounds_check_policies: index::BoundsCheckPolicies,
/// Should workgroup variables be zero initialized (by polyfilling)?
pub zero_initialize_workgroup_memory: bool,
}

impl Default for Options {
@@ -220,6 +222,7 @@ impl Default for Options {
spirv_cross_compatibility: false,
fake_missing_bindings: true,
bounds_check_policies: index::BoundsCheckPolicies::default(),
zero_initialize_workgroup_memory: true,
}
}
}