add support for zero-initializing workgroup memory
teoxoy committed Jan 5, 2023
1 parent 763ceb2 commit 835c055
Showing 45 changed files with 1,201 additions and 80 deletions.
1 change: 1 addition & 0 deletions benches/criterion.rs
@@ -242,6 +242,7 @@ fn backends(c: &mut Criterion) {
version: naga::back::glsl::Version::new_gles(320),
writer_flags: naga::back::glsl::WriterFlags::empty(),
binding_map: Default::default(),
zero_initialize_workgroup_memory: true,
};
for &(ref module, ref info) in inputs.iter() {
for ep in module.entry_points.iter() {
42 changes: 40 additions & 2 deletions src/back/glsl/mod.rs
@@ -228,6 +228,8 @@ pub struct Options {
pub writer_flags: WriterFlags,
/// Map of resources association to binding locations.
pub binding_map: BindingMap,
/// Should workgroup variables be zero initialized (by polyfilling)?
pub zero_initialize_workgroup_memory: bool,
}

impl Default for Options {
@@ -236,6 +238,7 @@ impl Default for Options {
version: Version::new_gles(310),
writer_flags: WriterFlags::ADJUST_COORDINATE_SPACE,
binding_map: BindingMap::default(),
zero_initialize_workgroup_memory: true,
}
}
}
@@ -1399,6 +1402,12 @@ impl<'a, W: Write> Writer<'a, W> {
// Close the parentheses and open braces to start the function body
writeln!(self.out, ") {{")?;

if self.options.zero_initialize_workgroup_memory
&& ctx.ty.is_compute_entry_point(self.module)
{
self.write_workgroup_variables_initialization(&ctx)?;
}

// Compose the function arguments from globals, in case of an entry point.
if let back::FunctionType::EntryPoint(ep_index) = ctx.ty {
let stage = self.module.entry_points[ep_index as usize].stage;
@@ -1487,6 +1496,35 @@ impl<'a, W: Write> Writer<'a, W> {
Ok(())
}

fn write_workgroup_variables_initialization(
&mut self,
ctx: &back::FunctionCtx,
) -> BackendResult {
let vars = self
.module
.global_variables
.iter()
.filter(|&(handle, var)| {
!ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
});

let mut write_barrier = false;

for (handle, var) in vars {
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{}{} = ", back::INDENT, name)?;
self.write_zero_init_value(var.ty)?;
writeln!(self.out, ";")?;
write_barrier = true;
}

if write_barrier {
self.write_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
}

Ok(())
}

/// Helper method that writes a list of comma separated `T` with a writer function `F`
///
/// The writer function `F` receives a mutable reference to `self` that if needed won't cause
@@ -3515,7 +3553,7 @@ impl<'a, W: Write> Writer<'a, W> {
fn write_zero_init_value(&mut self, ty: Handle<crate::Type>) -> BackendResult {
let inner = &self.module.types[ty].inner;
match *inner {
TypeInner::Scalar { kind, .. } => {
TypeInner::Scalar { kind, .. } | TypeInner::Atomic { kind, .. } => {
self.write_zero_init_scalar(kind)?;
}
TypeInner::Vector { kind, .. } => {
@@ -3560,7 +3598,7 @@ impl<'a, W: Write> Writer<'a, W> {
}
write!(self.out, ")")?;
}
_ => {} // TODO:
_ => unreachable!(),
}

Ok(())
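
Not part of the commit: a minimal usage sketch of the new GLSL option. The field names match the Options struct in the hunk above (at this commit the struct has exactly these four fields); turning the flag off is an illustrative assumption for a target that already guarantees zeroed workgroup memory.

// Illustrative sketch only, not from the diff.
let options = naga::back::glsl::Options {
    version: naga::back::glsl::Version::new_gles(310),
    writer_flags: naga::back::glsl::WriterFlags::ADJUST_COORDINATE_SPACE,
    binding_map: naga::back::glsl::BindingMap::default(),
    // New in this commit: skip the zero-init polyfill when the driver
    // already guarantees zeroed workgroup memory.
    zero_initialize_workgroup_memory: false,
};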
3 changes: 3 additions & 0 deletions src/back/hlsl/mod.rs
@@ -191,6 +191,8 @@ pub struct Options {
pub special_constants_binding: Option<BindTarget>,
/// Bind target of the push constant buffer
pub push_constants_target: Option<BindTarget>,
/// Should workgroup variables be zero initialized (by polyfilling)?
pub zero_initialize_workgroup_memory: bool,
}

impl Default for Options {
@@ -201,6 +203,7 @@ impl Default for Options {
fake_missing_bindings: true,
special_constants_binding: None,
push_constants_target: None,
zero_initialize_workgroup_memory: true,
}
}
}
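
Likewise for HLSL, a sketch (not from the diff) that relies on the Default impl shown above and only overrides the new field via struct-update syntax:

// Illustrative sketch only, not from the diff.
let options = naga::back::hlsl::Options {
    // New in this commit; the Default impl above sets it to true.
    zero_initialize_workgroup_memory: false,
    ..Default::default()
};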
32 changes: 32 additions & 0 deletions src/back/hlsl/writer.rs
@@ -1151,6 +1151,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out)?;
writeln!(self.out, "{{")?;

if self.options.zero_initialize_workgroup_memory
&& func_ctx.ty.is_compute_entry_point(module)
{
self.write_workgroup_variables_initialization(func_ctx, module)?;
}

if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
self.write_ep_arguments_initialization(module, func, index)?;
}
@@ -1204,6 +1210,32 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Ok(())
}

fn write_workgroup_variables_initialization(
&mut self,
func_ctx: &back::FunctionCtx,
module: &Module,
) -> BackendResult {
let vars = module.global_variables.iter().filter(|&(handle, var)| {
!func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
});

let mut write_barrier = false;

for (handle, var) in vars {
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{}{} = ", back::INDENT, name)?;
self.write_default_init(module, var.ty)?;
writeln!(self.out, ";")?;
write_barrier = true;
}

if write_barrier {
self.write_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
}

Ok(())
}

/// Helper method used to write statements
///
/// # Notes
11 changes: 11 additions & 0 deletions src/back/mod.rs
@@ -47,6 +47,17 @@ enum FunctionType {
EntryPoint(crate::proc::EntryPointIndex),
}

impl FunctionType {
fn is_compute_entry_point(&self, module: &crate::Module) -> bool {
match *self {
FunctionType::EntryPoint(index) => {
module.entry_points[index as usize].stage == crate::ShaderStage::Compute
}
_ => false,
}
}
}

/// Helper structure that stores data needed when writing the function
struct FunctionCtx<'a> {
/// The current function being written
3 changes: 3 additions & 0 deletions src/back/msl/mod.rs
@@ -209,6 +209,8 @@ pub struct Options {
/// Bounds checking policies.
#[cfg_attr(feature = "deserialize", serde(default))]
pub bounds_check_policies: index::BoundsCheckPolicies,
/// Should workgroup variables be zero initialized (by polyfilling)?
pub zero_initialize_workgroup_memory: bool,
}

impl Default for Options {
@@ -220,6 +222,7 @@ impl Default for Options {
spirv_cross_compatibility: false,
fake_missing_bindings: true,
bounds_check_policies: index::BoundsCheckPolicies::default(),
zero_initialize_workgroup_memory: true,
}
}
}
188 changes: 188 additions & 0 deletions src/back/msl/writer.rs
@@ -3744,6 +3744,10 @@ impl<W: Write> Writer<W> {
// end of the entry point argument list
writeln!(self.out, ") {{")?;

if options.zero_initialize_workgroup_memory && ep.stage == crate::ShaderStage::Compute {
self.write_workgroup_variables_initialization(module, fun_info)?;
}

// Metal doesn't support private mutable variables outside of functions,
// so we put them here, just like the locals.
for (handle, var) in module.global_variables.iter() {
@@ -3939,6 +3943,190 @@ impl<W: Write> Writer<W> {
}
}

/// Initializing workgroup variables is more tricky for Metal because we have to deal
/// with atomics at the type-level (which don't have a copy constructor).
///
/// So we have to "walk" data-structures and look for atomics. If we find one,
/// we write an atomic store for it. For everything else, we zero-initialize
/// via `= {}` as usual.
mod workgroup_mem_init {
use super::*;

enum Access {
GlobalVariable(Handle<crate::GlobalVariable>),
StructMember(Handle<crate::Type>, u32),
Array,
}

impl Access {
fn write<W: Write>(
&self,
writer: &mut W,
names: &FastHashMap<NameKey, String>,
) -> Result<(), core::fmt::Error> {
match *self {
Access::GlobalVariable(handle) => {
write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
}
Access::StructMember(handle, index) => {
write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
}
Access::Array => write!(writer, ".{}[__i]", WRAPPED_ARRAY_FIELD),
}
}
}

struct AccessStack(Vec<Access>);

impl AccessStack {
const fn new() -> Self {
AccessStack(Vec::new())
}

fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
self.0.push(new);
let res = cb(self);
self.0.pop();
res
}

fn write<W: Write>(
&self,
writer: &mut W,
names: &FastHashMap<NameKey, String>,
) -> Result<(), core::fmt::Error> {
for next in self.0.iter() {
next.write(writer, names)?;
}
Ok(())
}
}

impl<W: Write> Writer<W> {
pub(super) fn write_workgroup_variables_initialization(
&mut self,
module: &crate::Module,
fun_info: &valid::FunctionInfo,
) -> BackendResult {
let vars = module.global_variables.iter().filter(|&(handle, var)| {
!fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
});

let mut write_barrier = false;
let mut access_stack = AccessStack::new();

for (handle, var) in vars {
access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
self.write_workgroup_variable_initialization(
module,
var.ty,
access_stack,
back::Level(1),
)
})?;
write_barrier = true;
}

if write_barrier {
self.write_barrier(crate::Barrier::WORK_GROUP, back::Level(1))?;
}

Ok(())
}

fn write_workgroup_variable_initialization(
&mut self,
module: &crate::Module,
ty: Handle<crate::Type>,
access_stack: &mut AccessStack,
level: back::Level,
) -> BackendResult {
if does_ty_contain_atomic(module, ty)? {
match module.types[ty].inner {
crate::TypeInner::Atomic { .. } => {
write!(
self.out,
"{}{}::atomic_store_explicit({}",
level, NAMESPACE, ATOMIC_REFERENCE
)?;
access_stack.write(&mut self.out, &self.names)?;
writeln!(self.out, ", 0, {}::memory_order_relaxed);", NAMESPACE)?;
}
crate::TypeInner::Array { base, size, .. } => {
let count = match size.to_indexable_length(module).expect("Bad array size")
{
proc::IndexableLength::Known(count) => count,
proc::IndexableLength::Dynamic => return Ok(()),
};

access_stack.enter(Access::Array, |access_stack| {
writeln!(
self.out,
"{}for (int __i = 0; __i < {}; __i++) {{",
level, count
)?;
self.write_workgroup_variable_initialization(
module,
base,
access_stack,
level.next(),
)?;
writeln!(self.out, "{}}}", level)?;
BackendResult::Ok(())
})?;
}
crate::TypeInner::Struct { ref members, .. } => {
for (index, member) in members.iter().enumerate() {
access_stack.enter(
Access::StructMember(ty, index as u32),
|access_stack| {
self.write_workgroup_variable_initialization(
module,
member.ty,
access_stack,
level,
)
},
)?;
}
}
_ => unreachable!(),
}
} else {
write!(self.out, "{}", level)?;
access_stack.write(&mut self.out, &self.names)?;
writeln!(self.out, " = {{}};")?;
}

Ok(())
}
}

// TODO: we do extra work here, replace with a check of TypeFlags::CONSTRUCTIBLE
fn does_ty_contain_atomic(
module: &crate::Module,
ty: Handle<crate::Type>,
) -> Result<bool, Error> {
let inner = &module.types[ty].inner;
match *inner {
crate::TypeInner::Atomic { .. } => Ok(true),
crate::TypeInner::Scalar { .. }
| crate::TypeInner::Vector { .. }
| crate::TypeInner::Matrix { .. } => Ok(false),
crate::TypeInner::Array { base, .. } => does_ty_contain_atomic(module, base),
crate::TypeInner::Struct { ref members, .. } => {
for member in members {
if does_ty_contain_atomic(module, member.ty)? {
return Ok(true);
}
}
Ok(false)
}
_ => unreachable!(),
}
}
}

#[test]
fn test_stack_size() {
use crate::valid::{Capabilities, ValidationFlags};
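
As the module doc comment explains, the Metal writer cannot simply assign `= {}` to anything containing an atomic, so it walks the type, emitting roughly `metal::atomic_store_explicit(&var, 0, metal::memory_order_relaxed);` for each atomic it reaches, a `for` loop for arrays, and a plain `= {};` for everything else, followed by a single workgroup barrier. A sketch (not from the diff) of enabling the behaviour on the MSL backend, using struct-update syntax over the Default impl shown in src/back/msl/mod.rs:

// Illustrative sketch only, not from the diff.
let options = naga::back::msl::Options {
    // New in this commit; the Default impl above also sets it to true.
    zero_initialize_workgroup_memory: true,
    ..Default::default()
};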