Skip to content

Commit

Permalink
[naga, hal] miscellaneous fixes for Atomic64 support (#5952)
Browse files Browse the repository at this point in the history
In `naga::back::hlsl`:

- Generate calls to `Interlocked{op}64` when necessary, not
  `Interlocked{op}`.

- Make atomic operations that do not produce a value emit their
  operands properly.
  
In the Naga snapshot tests:

- Adapt `atomicOps-int64-min-max.wgsl` to include cases that
  cover non-trivial atomic operation operand emitting.

In `wgpu_hal::vulkan::adapter`:

- When retrieving physical device features, be sure to include
  the `PhysicalDeviceShaderAtomicInt64Features` extending struct
  in the chain whenever the `VK_KHR_shader_atomic_int64` extension
  is available.

- Request both `shader_{buffer,shared}_int64_atomics` in the
  `PhysicalDeviceShaderAtomicInt64Features` extending struct when either of
  `wgpu_types::Features::SHADER_INT64_ATOMIC_{ALL_OPS,MIN_MAX}` is requested.

---------

Co-authored-by: Jim Blandy <[email protected]>
  • Loading branch information
JMS55 and jimblandy authored Jul 14, 2024
1 parent 6f16ea4 commit 17fcb19
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 107 deletions.
14 changes: 9 additions & 5 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::{
use crate::{
back::{self, Baked},
proc::{self, NameKey},
valid, Handle, Module, ScalarKind, ShaderStage, TypeInner,
valid, Handle, Module, Scalar, ScalarKind, ShaderStage, TypeInner,
};
use std::{fmt, mem};

Expand Down Expand Up @@ -2013,7 +2013,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// ownership of our reusable access chain buffer.
let chain = mem::take(&mut self.temp_access_chain);
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
write!(self.out, "{var_name}.Interlocked{fun_str}(")?;
let width = match func_ctx.resolve_type(value, &module.types) {
&TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
_ => "",
};
write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
self.write_storage_address(module, &chain, func_ctx)?;
self.temp_access_chain = chain;
}
Expand Down Expand Up @@ -2852,7 +2856,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
let inner = func_ctx.resolve_type(expr, &module.types);
let close_paren = match convert {
Some(dst_width) => {
let scalar = crate::Scalar {
let scalar = Scalar {
kind,
width: dst_width,
};
Expand Down Expand Up @@ -3213,7 +3217,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// as non-32bit types are DXC only.
Function::MissingIntOverload(fun_name) => {
let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
if let Some(crate::Scalar {
if let Some(Scalar {
kind: ScalarKind::Sint,
width: 4,
}) = scalar_kind
Expand All @@ -3231,7 +3235,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// as non-32bit types are DXC only.
Function::MissingIntReturnType(fun_name) => {
let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar();
if let Some(crate::Scalar {
if let Some(Scalar {
kind: ScalarKind::Sint,
width: 4,
}) = scalar_kind
Expand Down
4 changes: 4 additions & 0 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2482,6 +2482,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
crate::TypeInner::Scalar(crate::Scalar { width: 8, .. })
);
let result = if is_64_bit_min_max && is_statement {
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
rctx.emitter.start(&rctx.function.expressions);
None
} else {
let ty = ctx.register_type(value)?;
Expand Down
14 changes: 8 additions & 6 deletions naga/tests/in/atomicOps-int64-min-max.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@ var<storage, read_write> storage_atomic_scalar: atomic<u64>;
var<storage, read_write> storage_atomic_arr: array<atomic<u64>, 2>;
@group(0) @binding(2)
var<storage, read_write> storage_struct: Struct;
@group(0) @binding(3)
var<uniform> input: u64;

@compute
@workgroup_size(2)
fn cs_main(@builtin(local_invocation_id) id: vec3<u32>) {
atomicMax(&storage_atomic_scalar, 1lu);
atomicMax(&storage_atomic_arr[1], 1lu);
atomicMax(&storage_atomic_scalar, input);
atomicMax(&storage_atomic_arr[1], 1 + input);
atomicMax(&storage_struct.atomic_scalar, 1lu);
atomicMax(&storage_struct.atomic_arr[1], 1lu);
atomicMax(&storage_struct.atomic_arr[1], u64(id.x));

workgroupBarrier();

atomicMin(&storage_atomic_scalar, 1lu);
atomicMin(&storage_atomic_arr[1], 1lu);
atomicMin(&storage_atomic_scalar, input);
atomicMin(&storage_atomic_arr[1], 1 + input);
atomicMin(&storage_struct.atomic_scalar, 1lu);
atomicMin(&storage_struct.atomic_arr[1], 1lu);
atomicMin(&storage_struct.atomic_arr[1], u64(id.x));
}
21 changes: 13 additions & 8 deletions naga/tests/out/hlsl/atomicOps-int64-min-max.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,23 @@ struct Struct {
RWByteAddressBuffer storage_atomic_scalar : register(u0);
RWByteAddressBuffer storage_atomic_arr : register(u1);
RWByteAddressBuffer storage_struct : register(u2);
cbuffer input : register(b3) { uint64_t input; }

[numthreads(2, 1, 1)]
void cs_main(uint3 id : SV_GroupThreadID)
{
storage_atomic_scalar.InterlockedMax(0, 1uL);
storage_atomic_arr.InterlockedMax(8, 1uL);
storage_struct.InterlockedMax(0, 1uL);
storage_struct.InterlockedMax(8+8, 1uL);
uint64_t _e3 = input;
storage_atomic_scalar.InterlockedMax64(0, _e3);
uint64_t _e7 = input;
storage_atomic_arr.InterlockedMax64(8, (1uL + _e7));
storage_struct.InterlockedMax64(0, 1uL);
storage_struct.InterlockedMax64(8+8, uint64_t(id.x));
GroupMemoryBarrierWithGroupSync();
storage_atomic_scalar.InterlockedMin(0, 1uL);
storage_atomic_arr.InterlockedMin(8, 1uL);
storage_struct.InterlockedMin(0, 1uL);
storage_struct.InterlockedMin(8+8, 1uL);
uint64_t _e20 = input;
storage_atomic_scalar.InterlockedMin64(0, _e20);
uint64_t _e24 = input;
storage_atomic_arr.InterlockedMin64(8, (1uL + _e24));
storage_struct.InterlockedMin64(0, 1uL);
storage_struct.InterlockedMin64(8+8, uint64_t(id.x));
return;
}
64 changes: 32 additions & 32 deletions naga/tests/out/hlsl/atomicOps-int64.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -44,72 +44,72 @@ void cs_main(uint3 id : SV_GroupThreadID, uint3 __local_invocation_id : SV_Group
uint64_t l6_ = workgroup_struct.atomic_scalar;
int64_t l7_ = workgroup_struct.atomic_arr[1];
GroupMemoryBarrierWithGroupSync();
uint64_t _e51; storage_atomic_scalar.InterlockedAdd(0, 1uL, _e51);
int64_t _e55; storage_atomic_arr.InterlockedAdd(8, 1L, _e55);
uint64_t _e59; storage_struct.InterlockedAdd(0, 1uL, _e59);
int64_t _e64; storage_struct.InterlockedAdd(8+8, 1L, _e64);
uint64_t _e51; storage_atomic_scalar.InterlockedAdd64(0, 1uL, _e51);
int64_t _e55; storage_atomic_arr.InterlockedAdd64(8, 1L, _e55);
uint64_t _e59; storage_struct.InterlockedAdd64(0, 1uL, _e59);
int64_t _e64; storage_struct.InterlockedAdd64(8+8, 1L, _e64);
uint64_t _e67; InterlockedAdd(workgroup_atomic_scalar, 1uL, _e67);
int64_t _e71; InterlockedAdd(workgroup_atomic_arr[1], 1L, _e71);
uint64_t _e75; InterlockedAdd(workgroup_struct.atomic_scalar, 1uL, _e75);
int64_t _e80; InterlockedAdd(workgroup_struct.atomic_arr[1], 1L, _e80);
GroupMemoryBarrierWithGroupSync();
uint64_t _e83; storage_atomic_scalar.InterlockedAdd(0, -1uL, _e83);
int64_t _e87; storage_atomic_arr.InterlockedAdd(8, -1L, _e87);
uint64_t _e91; storage_struct.InterlockedAdd(0, -1uL, _e91);
int64_t _e96; storage_struct.InterlockedAdd(8+8, -1L, _e96);
uint64_t _e83; storage_atomic_scalar.InterlockedAdd64(0, -1uL, _e83);
int64_t _e87; storage_atomic_arr.InterlockedAdd64(8, -1L, _e87);
uint64_t _e91; storage_struct.InterlockedAdd64(0, -1uL, _e91);
int64_t _e96; storage_struct.InterlockedAdd64(8+8, -1L, _e96);
uint64_t _e99; InterlockedAdd(workgroup_atomic_scalar, -1uL, _e99);
int64_t _e103; InterlockedAdd(workgroup_atomic_arr[1], -1L, _e103);
uint64_t _e107; InterlockedAdd(workgroup_struct.atomic_scalar, -1uL, _e107);
int64_t _e112; InterlockedAdd(workgroup_struct.atomic_arr[1], -1L, _e112);
GroupMemoryBarrierWithGroupSync();
storage_atomic_scalar.InterlockedMax(0, 1uL);
storage_atomic_arr.InterlockedMax(8, 1L);
storage_struct.InterlockedMax(0, 1uL);
storage_struct.InterlockedMax(8+8, 1L);
storage_atomic_scalar.InterlockedMax64(0, 1uL);
storage_atomic_arr.InterlockedMax64(8, 1L);
storage_struct.InterlockedMax64(0, 1uL);
storage_struct.InterlockedMax64(8+8, 1L);
InterlockedMax(workgroup_atomic_scalar, 1uL);
InterlockedMax(workgroup_atomic_arr[1], 1L);
InterlockedMax(workgroup_struct.atomic_scalar, 1uL);
InterlockedMax(workgroup_struct.atomic_arr[1], 1L);
GroupMemoryBarrierWithGroupSync();
storage_atomic_scalar.InterlockedMin(0, 1uL);
storage_atomic_arr.InterlockedMin(8, 1L);
storage_struct.InterlockedMin(0, 1uL);
storage_struct.InterlockedMin(8+8, 1L);
storage_atomic_scalar.InterlockedMin64(0, 1uL);
storage_atomic_arr.InterlockedMin64(8, 1L);
storage_struct.InterlockedMin64(0, 1uL);
storage_struct.InterlockedMin64(8+8, 1L);
InterlockedMin(workgroup_atomic_scalar, 1uL);
InterlockedMin(workgroup_atomic_arr[1], 1L);
InterlockedMin(workgroup_struct.atomic_scalar, 1uL);
InterlockedMin(workgroup_struct.atomic_arr[1], 1L);
GroupMemoryBarrierWithGroupSync();
uint64_t _e163; storage_atomic_scalar.InterlockedAnd(0, 1uL, _e163);
int64_t _e167; storage_atomic_arr.InterlockedAnd(8, 1L, _e167);
uint64_t _e171; storage_struct.InterlockedAnd(0, 1uL, _e171);
int64_t _e176; storage_struct.InterlockedAnd(8+8, 1L, _e176);
uint64_t _e163; storage_atomic_scalar.InterlockedAnd64(0, 1uL, _e163);
int64_t _e167; storage_atomic_arr.InterlockedAnd64(8, 1L, _e167);
uint64_t _e171; storage_struct.InterlockedAnd64(0, 1uL, _e171);
int64_t _e176; storage_struct.InterlockedAnd64(8+8, 1L, _e176);
uint64_t _e179; InterlockedAnd(workgroup_atomic_scalar, 1uL, _e179);
int64_t _e183; InterlockedAnd(workgroup_atomic_arr[1], 1L, _e183);
uint64_t _e187; InterlockedAnd(workgroup_struct.atomic_scalar, 1uL, _e187);
int64_t _e192; InterlockedAnd(workgroup_struct.atomic_arr[1], 1L, _e192);
GroupMemoryBarrierWithGroupSync();
uint64_t _e195; storage_atomic_scalar.InterlockedOr(0, 1uL, _e195);
int64_t _e199; storage_atomic_arr.InterlockedOr(8, 1L, _e199);
uint64_t _e203; storage_struct.InterlockedOr(0, 1uL, _e203);
int64_t _e208; storage_struct.InterlockedOr(8+8, 1L, _e208);
uint64_t _e195; storage_atomic_scalar.InterlockedOr64(0, 1uL, _e195);
int64_t _e199; storage_atomic_arr.InterlockedOr64(8, 1L, _e199);
uint64_t _e203; storage_struct.InterlockedOr64(0, 1uL, _e203);
int64_t _e208; storage_struct.InterlockedOr64(8+8, 1L, _e208);
uint64_t _e211; InterlockedOr(workgroup_atomic_scalar, 1uL, _e211);
int64_t _e215; InterlockedOr(workgroup_atomic_arr[1], 1L, _e215);
uint64_t _e219; InterlockedOr(workgroup_struct.atomic_scalar, 1uL, _e219);
int64_t _e224; InterlockedOr(workgroup_struct.atomic_arr[1], 1L, _e224);
GroupMemoryBarrierWithGroupSync();
uint64_t _e227; storage_atomic_scalar.InterlockedXor(0, 1uL, _e227);
int64_t _e231; storage_atomic_arr.InterlockedXor(8, 1L, _e231);
uint64_t _e235; storage_struct.InterlockedXor(0, 1uL, _e235);
int64_t _e240; storage_struct.InterlockedXor(8+8, 1L, _e240);
uint64_t _e227; storage_atomic_scalar.InterlockedXor64(0, 1uL, _e227);
int64_t _e231; storage_atomic_arr.InterlockedXor64(8, 1L, _e231);
uint64_t _e235; storage_struct.InterlockedXor64(0, 1uL, _e235);
int64_t _e240; storage_struct.InterlockedXor64(8+8, 1L, _e240);
uint64_t _e243; InterlockedXor(workgroup_atomic_scalar, 1uL, _e243);
int64_t _e247; InterlockedXor(workgroup_atomic_arr[1], 1L, _e247);
uint64_t _e251; InterlockedXor(workgroup_struct.atomic_scalar, 1uL, _e251);
int64_t _e256; InterlockedXor(workgroup_struct.atomic_arr[1], 1L, _e256);
uint64_t _e259; storage_atomic_scalar.InterlockedExchange(0, 1uL, _e259);
int64_t _e263; storage_atomic_arr.InterlockedExchange(8, 1L, _e263);
uint64_t _e267; storage_struct.InterlockedExchange(0, 1uL, _e267);
int64_t _e272; storage_struct.InterlockedExchange(8+8, 1L, _e272);
uint64_t _e259; storage_atomic_scalar.InterlockedExchange64(0, 1uL, _e259);
int64_t _e263; storage_atomic_arr.InterlockedExchange64(8, 1L, _e263);
uint64_t _e267; storage_struct.InterlockedExchange64(0, 1uL, _e267);
int64_t _e272; storage_struct.InterlockedExchange64(8+8, 1L, _e272);
uint64_t _e275; InterlockedExchange(workgroup_atomic_scalar, 1uL, _e275);
int64_t _e279; InterlockedExchange(workgroup_atomic_arr[1], 1L, _e279);
uint64_t _e283; InterlockedExchange(workgroup_struct.atomic_scalar, 1uL, _e283);
Expand Down
17 changes: 11 additions & 6 deletions naga/tests/out/msl/atomicOps-int64-min-max.msl
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,20 @@ kernel void cs_main(
, device metal::atomic_ulong& storage_atomic_scalar [[user(fake0)]]
, device type_1& storage_atomic_arr [[user(fake0)]]
, device Struct& storage_struct [[user(fake0)]]
, constant ulong& input [[user(fake0)]]
) {
metal::atomic_max_explicit(&storage_atomic_scalar, 1uL, metal::memory_order_relaxed);
metal::atomic_max_explicit(&storage_atomic_arr.inner[1], 1uL, metal::memory_order_relaxed);
ulong _e3 = input;
metal::atomic_max_explicit(&storage_atomic_scalar, _e3, metal::memory_order_relaxed);
ulong _e7 = input;
metal::atomic_max_explicit(&storage_atomic_arr.inner[1], 1uL + _e7, metal::memory_order_relaxed);
metal::atomic_max_explicit(&storage_struct.atomic_scalar, 1uL, metal::memory_order_relaxed);
metal::atomic_max_explicit(&storage_struct.atomic_arr.inner[1], 1uL, metal::memory_order_relaxed);
metal::atomic_max_explicit(&storage_struct.atomic_arr.inner[1], static_cast<ulong>(id.x), metal::memory_order_relaxed);
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
metal::atomic_min_explicit(&storage_atomic_scalar, 1uL, metal::memory_order_relaxed);
metal::atomic_min_explicit(&storage_atomic_arr.inner[1], 1uL, metal::memory_order_relaxed);
ulong _e20 = input;
metal::atomic_min_explicit(&storage_atomic_scalar, _e20, metal::memory_order_relaxed);
ulong _e24 = input;
metal::atomic_min_explicit(&storage_atomic_arr.inner[1], 1uL + _e24, metal::memory_order_relaxed);
metal::atomic_min_explicit(&storage_struct.atomic_scalar, 1uL, metal::memory_order_relaxed);
metal::atomic_min_explicit(&storage_struct.atomic_arr.inner[1], 1uL, metal::memory_order_relaxed);
metal::atomic_min_explicit(&storage_struct.atomic_arr.inner[1], static_cast<ulong>(id.x), metal::memory_order_relaxed);
return;
}
99 changes: 59 additions & 40 deletions naga/tests/out/spv/atomicOps-int64-min-max.spvasm
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 52
; Bound: 67
OpCapability Shader
OpCapability Int64Atomics
OpCapability Int64
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %22 "cs_main" %19
OpExecutionMode %22 LocalSize 2 1 1
OpEntryPoint GLCompute %25 "cs_main" %22
OpExecutionMode %25 LocalSize 2 1 1
OpDecorate %4 ArrayStride 8
OpMemberDecorate %7 0 Offset 0
OpMemberDecorate %7 1 Offset 8
Expand All @@ -25,7 +25,11 @@ OpDecorate %15 DescriptorSet 0
OpDecorate %15 Binding 2
OpDecorate %16 Block
OpMemberDecorate %16 0 Offset 0
OpDecorate %19 BuiltIn LocalInvocationId
OpDecorate %18 DescriptorSet 0
OpDecorate %18 Binding 3
OpDecorate %19 Block
OpMemberDecorate %19 0 Offset 0
OpDecorate %22 BuiltIn LocalInvocationId
%2 = OpTypeVoid
%3 = OpTypeInt 64 0
%6 = OpTypeInt 32 0
Expand All @@ -42,41 +46,56 @@ OpDecorate %19 BuiltIn LocalInvocationId
%16 = OpTypeStruct %7
%17 = OpTypePointer StorageBuffer %16
%15 = OpVariable %17 StorageBuffer
%20 = OpTypePointer Input %8
%19 = OpVariable %20 Input
%23 = OpTypeFunction %2
%24 = OpTypePointer StorageBuffer %3
%25 = OpConstant %6 0
%27 = OpTypePointer StorageBuffer %4
%29 = OpTypePointer StorageBuffer %7
%31 = OpConstant %3 1
%35 = OpTypeInt 32 1
%34 = OpConstant %35 1
%36 = OpConstant %6 64
%38 = OpConstant %6 1
%44 = OpConstant %6 264
%22 = OpFunction %2 None %23
%18 = OpLabel
%21 = OpLoad %8 %19
%26 = OpAccessChain %24 %9 %25
%28 = OpAccessChain %27 %12 %25
%30 = OpAccessChain %29 %15 %25
OpBranch %32
%32 = OpLabel
%33 = OpAtomicUMax %3 %26 %34 %36 %31
%39 = OpAccessChain %24 %28 %38
%37 = OpAtomicUMax %3 %39 %34 %36 %31
%41 = OpAccessChain %24 %30 %25
%40 = OpAtomicUMax %3 %41 %34 %36 %31
%43 = OpAccessChain %24 %30 %38 %38
%42 = OpAtomicUMax %3 %43 %34 %36 %31
OpControlBarrier %5 %5 %44
%45 = OpAtomicUMin %3 %26 %34 %36 %31
%47 = OpAccessChain %24 %28 %38
%46 = OpAtomicUMin %3 %47 %34 %36 %31
%49 = OpAccessChain %24 %30 %25
%48 = OpAtomicUMin %3 %49 %34 %36 %31
%51 = OpAccessChain %24 %30 %38 %38
%50 = OpAtomicUMin %3 %51 %34 %36 %31
%19 = OpTypeStruct %3
%20 = OpTypePointer Uniform %19
%18 = OpVariable %20 Uniform
%23 = OpTypePointer Input %8
%22 = OpVariable %23 Input
%26 = OpTypeFunction %2
%27 = OpTypePointer StorageBuffer %3
%28 = OpConstant %6 0
%30 = OpTypePointer StorageBuffer %4
%32 = OpTypePointer StorageBuffer %7
%34 = OpTypePointer Uniform %3
%36 = OpConstant %3 1
%41 = OpTypeInt 32 1
%40 = OpConstant %41 1
%42 = OpConstant %6 64
%46 = OpConstant %6 1
%54 = OpConstant %6 264
%25 = OpFunction %2 None %26
%21 = OpLabel
%24 = OpLoad %8 %22
%29 = OpAccessChain %27 %9 %28
%31 = OpAccessChain %30 %12 %28
%33 = OpAccessChain %32 %15 %28
%35 = OpAccessChain %34 %18 %28
OpBranch %37
%37 = OpLabel
%38 = OpLoad %3 %35
%39 = OpAtomicUMax %3 %29 %40 %42 %38
%43 = OpLoad %3 %35
%44 = OpIAdd %3 %36 %43
%47 = OpAccessChain %27 %31 %46
%45 = OpAtomicUMax %3 %47 %40 %42 %44
%49 = OpAccessChain %27 %33 %28
%48 = OpAtomicUMax %3 %49 %40 %42 %36
%50 = OpCompositeExtract %6 %24 0
%51 = OpUConvert %3 %50
%53 = OpAccessChain %27 %33 %46 %46
%52 = OpAtomicUMax %3 %53 %40 %42 %51
OpControlBarrier %5 %5 %54
%55 = OpLoad %3 %35
%56 = OpAtomicUMin %3 %29 %40 %42 %55
%57 = OpLoad %3 %35
%58 = OpIAdd %3 %36 %57
%60 = OpAccessChain %27 %31 %46
%59 = OpAtomicUMin %3 %60 %40 %42 %58
%62 = OpAccessChain %27 %33 %28
%61 = OpAtomicUMin %3 %62 %40 %42 %36
%63 = OpCompositeExtract %6 %24 0
%64 = OpUConvert %3 %63
%66 = OpAccessChain %27 %33 %46 %46
%65 = OpAtomicUMin %3 %66 %40 %42 %64
OpReturn
OpFunctionEnd
Loading

0 comments on commit 17fcb19

Please sign in to comment.