From 28c45321e53c3dac8a613e2402d0e336d886ad49 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Wed, 3 Nov 2021 22:44:46 -0400 Subject: [PATCH] hlsl: respect array stride in storage buffers (#1507) --- src/back/hlsl/storage.rs | 73 ++++++++++------- tests/in/access.wgsl | 3 +- tests/out/glsl/access.foo.Vertex.glsl | 1 + tests/out/hlsl/access.hlsl | 3 +- tests/out/msl/access.msl | 3 +- tests/out/spv/access.spvasm | 112 +++++++++++++------------- tests/out/wgsl/access.wgsl | 3 +- 7 files changed, 110 insertions(+), 88 deletions(-) diff --git a/src/back/hlsl/storage.rs b/src/back/hlsl/storage.rs index 5840e43ed1..b235a061fb 100644 --- a/src/back/hlsl/storage.rs +++ b/src/back/hlsl/storage.rs @@ -367,39 +367,22 @@ impl super::Writer<'_, W> { mut cur_expr: Handle, func_ctx: &FunctionCtx, ) -> Result, Error> { + enum AccessIndex { + Expression(Handle), + Constant(u32), + } + enum Parent<'a> { + Array { stride: u32 }, + Struct(&'a [crate::StructMember]), + } self.temp_access_chain.clear(); - loop { - // determine the size of the pointee - let stride = match *func_ctx.info[cur_expr].ty.inner_with(&module.types) { - crate::TypeInner::Pointer { base, class: _ } => { - module.types[base].inner.span(&module.constants) - } - crate::TypeInner::ValuePointer { size, width, .. } => { - size.map_or(1, |s| s as u32) * width as u32 - } - _ => 0, - }; - let (next_expr, sub) = match func_ctx.expressions[cur_expr] { + loop { + let (next_expr, access_index) = match func_ctx.expressions[cur_expr] { crate::Expression::GlobalVariable(handle) => return Ok(handle), - crate::Expression::Access { base, index } => ( - base, - SubAccess::Index { - value: index, - stride, - }, - ), + crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)), crate::Expression::AccessIndex { base, index } => { - let sub = match *func_ctx.info[base].ty.inner_with(&module.types) { - crate::TypeInner::Pointer { base, .. } => match module.types[base].inner { - crate::TypeInner::Struct { ref members, .. } => { - SubAccess::Offset(members[index as usize].offset) - } - _ => SubAccess::Offset(index * stride), - }, - _ => SubAccess::Offset(index * stride), - }; - (base, sub) + (base, AccessIndex::Constant(index)) } ref other => { return Err(Error::Unimplemented(format!( @@ -408,6 +391,38 @@ impl super::Writer<'_, W> { ))) } }; + + let parent = match *func_ctx.info[next_expr].ty.inner_with(&module.types) { + crate::TypeInner::Pointer { base, .. } => match module.types[base].inner { + crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members), + crate::TypeInner::Array { stride, .. } => Parent::Array { stride }, + crate::TypeInner::Vector { width, .. } => Parent::Array { + stride: width as u32, + }, + crate::TypeInner::Matrix { rows, width, .. } => Parent::Array { + stride: width as u32 * if rows > crate::VectorSize::Bi { 4 } else { 2 }, + }, + _ => unreachable!(), + }, + crate::TypeInner::ValuePointer { width, .. } => Parent::Array { + stride: width as u32, + }, + _ => unreachable!(), + }; + + let sub = match (parent, access_index) { + (Parent::Array { stride }, AccessIndex::Expression(value)) => { + SubAccess::Index { value, stride } + } + (Parent::Array { stride }, AccessIndex::Constant(index)) => { + SubAccess::Offset(stride * index) + } + (Parent::Struct(members), AccessIndex::Constant(index)) => { + SubAccess::Offset(members[index as usize].offset) + } + (Parent::Struct(_), AccessIndex::Expression(_)) => unreachable!(), + }; + self.temp_access_chain.push(sub); cur_expr = next_expr; } diff --git a/tests/in/access.wgsl b/tests/in/access.wgsl index f4af84913d..4a2228fa3c 100644 --- a/tests/in/access.wgsl +++ b/tests/in/access.wgsl @@ -5,7 +5,7 @@ struct Bar { matrix: mat4x4; atom: atomic; arr: [[stride(8)]] array, 2>; - data: [[stride(4)]] array; + data: [[stride(8)]] array; }; [[group(0), binding(0)]] @@ -37,6 +37,7 @@ fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4 { bar.matrix[1].z = 1.0; bar.matrix = mat4x4(vec4(0.0), vec4(1.0), vec4(2.0), vec4(3.0)); bar.arr = array, 2>(vec2(0u), vec2(1u)); + bar.data[1] = 1; // test array indexing var c = array(a, i32(b), 3, 4, 5); diff --git a/tests/out/glsl/access.foo.Vertex.glsl b/tests/out/glsl/access.foo.Vertex.glsl index 680f52ee52..848bce8cdb 100644 --- a/tests/out/glsl/access.foo.Vertex.glsl +++ b/tests/out/glsl/access.foo.Vertex.glsl @@ -30,6 +30,7 @@ void main() { _group_0_binding_0.matrix[1][2] = 1.0; _group_0_binding_0.matrix = mat4x4(vec4(0.0), vec4(1.0), vec4(2.0), vec4(3.0)); _group_0_binding_0.arr = uvec2[2](uvec2(0u), uvec2(1u)); + _group_0_binding_0.data[1] = 1; c = int[5](a, int(b), 3, 4, 5); c[(vi + 1u)] = 42; int value = c[vi]; diff --git a/tests/out/hlsl/access.hlsl b/tests/out/hlsl/access.hlsl index 41b7acc840..22f27b8814 100644 --- a/tests/out/hlsl/access.hlsl +++ b/tests/out/hlsl/access.hlsl @@ -24,7 +24,7 @@ float4 foo(uint vi : SV_VertexID) : SV_Position float4x4 matrix1 = float4x4(asfloat(bar.Load4(0+0)), asfloat(bar.Load4(0+16)), asfloat(bar.Load4(0+32)), asfloat(bar.Load4(0+48))); uint2 arr[2] = {asuint(bar.Load2(72+0)), asuint(bar.Load2(72+8))}; float b = asfloat(bar.Load(0+48+0)); - int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 88) / 4) - 2u)*4+88)); + int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 88) / 8) - 2u)*8+88)); const float _e25 = read_from_private(foo1); bar.Store(8+16+0, asuint(1.0)); { @@ -39,6 +39,7 @@ float4 foo(uint vi : SV_VertexID) : SV_Position bar.Store2(72+0, asuint(_value2[0])); bar.Store2(72+8, asuint(_value2[1])); } + bar.Store(8+88, asuint(1)); { int _result[5]={ a, int(b), 3, 4, 5 }; for(int _i=0; _i<5; ++_i) c[_i] = _result[_i]; diff --git a/tests/out/msl/access.msl b/tests/out/msl/access.msl index 09f011fe42..884e606b32 100644 --- a/tests/out/msl/access.msl +++ b/tests/out/msl/access.msl @@ -45,11 +45,12 @@ vertex fooOutput foo( metal::float4x4 matrix = bar.matrix; type3 arr = bar.arr; float b = bar.matrix[3].x; - int a = bar.data[(1 + (_buffer_sizes.size0 - 88 - 4) / 4) - 2u]; + int a = bar.data[(1 + (_buffer_sizes.size0 - 88 - 4) / 8) - 2u]; float _e25 = read_from_private(foo1); bar.matrix[1].z = 1.0; bar.matrix = metal::float4x4(metal::float4(0.0), metal::float4(1.0), metal::float4(2.0), metal::float4(3.0)); for(int _i=0; _i<2; ++_i) bar.arr.inner[_i] = type3 {metal::uint2(0u), metal::uint2(1u)}.inner[_i]; + bar.data[1] = 1; for(int _i=0; _i<5; ++_i) c.inner[_i] = type11 {a, static_cast(b), 3, 4, 5}.inner[_i]; c.inner[vi + 1u] = 42; int value = c.inner[vi]; diff --git a/tests/out/spv/access.spvasm b/tests/out/spv/access.spvasm index a5864e90e5..96127247af 100644 --- a/tests/out/spv/access.spvasm +++ b/tests/out/spv/access.spvasm @@ -1,14 +1,14 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 114 +; Bound: 115 OpCapability Shader OpExtension "SPV_KHR_storage_buffer_storage_class" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Vertex %47 "foo" %42 %45 -OpEntryPoint GLCompute %91 "atomics" -OpExecutionMode %91 LocalSize 1 1 1 +OpEntryPoint GLCompute %92 "atomics" +OpExecutionMode %92 LocalSize 1 1 1 OpSource GLSL 450 OpMemberName %26 0 "matrix" OpMemberName %26 1 "atom" @@ -22,10 +22,10 @@ OpName %38 "foo" OpName %39 "c" OpName %42 "vi" OpName %47 "foo" -OpName %89 "tmp" -OpName %91 "atomics" +OpName %90 "tmp" +OpName %92 "atomics" OpDecorate %24 ArrayStride 8 -OpDecorate %25 ArrayStride 4 +OpDecorate %25 ArrayStride 8 OpDecorate %26 Block OpMemberDecorate %26 0 Offset 0 OpMemberDecorate %26 0 ColMajor @@ -80,10 +80,10 @@ OpDecorate %45 BuiltIn Position %57 = OpTypePointer StorageBuffer %22 %58 = OpTypePointer StorageBuffer %6 %61 = OpTypePointer StorageBuffer %25 -%81 = OpTypePointer Function %4 -%85 = OpTypeVector %4 4 -%93 = OpTypePointer StorageBuffer %4 -%96 = OpConstant %9 64 +%82 = OpTypePointer Function %4 +%86 = OpTypeVector %4 4 +%94 = OpTypePointer StorageBuffer %4 +%97 = OpConstant %9 64 %34 = OpFunction %6 None %35 %33 = OpFunctionParameter %27 %32 = OpLabel @@ -126,52 +126,54 @@ OpStore %73 %72 %76 = OpCompositeConstruct %24 %74 %75 %77 = OpAccessChain %54 %30 %10 OpStore %77 %76 -%78 = OpConvertFToS %4 %60 -%79 = OpCompositeConstruct %29 %65 %78 %18 %19 %17 -OpStore %39 %79 -%80 = OpIAdd %9 %44 %16 -%82 = OpAccessChain %81 %39 %80 -OpStore %82 %20 -%83 = OpAccessChain %81 %39 %44 -%84 = OpLoad %4 %83 -%86 = OpCompositeConstruct %85 %84 %84 %84 %84 -%87 = OpConvertSToF %22 %86 -%88 = OpMatrixTimesVector %22 %53 %87 -OpStore %45 %88 +%78 = OpAccessChain %28 %30 %8 %16 +OpStore %78 %12 +%79 = OpConvertFToS %4 %60 +%80 = OpCompositeConstruct %29 %65 %79 %18 %19 %17 +OpStore %39 %80 +%81 = OpIAdd %9 %44 %16 +%83 = OpAccessChain %82 %39 %81 +OpStore %83 %20 +%84 = OpAccessChain %82 %39 %44 +%85 = OpLoad %4 %84 +%87 = OpCompositeConstruct %86 %85 %85 %85 %85 +%88 = OpConvertSToF %22 %87 +%89 = OpMatrixTimesVector %22 %53 %88 +OpStore %45 %89 OpReturn OpFunctionEnd -%91 = OpFunction %2 None %48 -%90 = OpLabel -%89 = OpVariable %81 Function -OpBranch %92 -%92 = OpLabel -%94 = OpAccessChain %93 %30 %16 -%95 = OpAtomicLoad %4 %94 %12 %96 -%98 = OpAccessChain %93 %30 %16 -%97 = OpAtomicIAdd %4 %98 %12 %96 %17 -OpStore %89 %97 -%100 = OpAccessChain %93 %30 %16 -%99 = OpAtomicISub %4 %100 %12 %96 %17 -OpStore %89 %99 -%102 = OpAccessChain %93 %30 %16 -%101 = OpAtomicAnd %4 %102 %12 %96 %17 -OpStore %89 %101 -%104 = OpAccessChain %93 %30 %16 -%103 = OpAtomicOr %4 %104 %12 %96 %17 -OpStore %89 %103 -%106 = OpAccessChain %93 %30 %16 -%105 = OpAtomicXor %4 %106 %12 %96 %17 -OpStore %89 %105 -%108 = OpAccessChain %93 %30 %16 -%107 = OpAtomicSMin %4 %108 %12 %96 %17 -OpStore %89 %107 -%110 = OpAccessChain %93 %30 %16 -%109 = OpAtomicSMax %4 %110 %12 %96 %17 -OpStore %89 %109 -%112 = OpAccessChain %93 %30 %16 -%111 = OpAtomicExchange %4 %112 %12 %96 %17 -OpStore %89 %111 -%113 = OpAccessChain %93 %30 %16 -OpAtomicStore %113 %12 %96 %95 +%92 = OpFunction %2 None %48 +%91 = OpLabel +%90 = OpVariable %82 Function +OpBranch %93 +%93 = OpLabel +%95 = OpAccessChain %94 %30 %16 +%96 = OpAtomicLoad %4 %95 %12 %97 +%99 = OpAccessChain %94 %30 %16 +%98 = OpAtomicIAdd %4 %99 %12 %97 %17 +OpStore %90 %98 +%101 = OpAccessChain %94 %30 %16 +%100 = OpAtomicISub %4 %101 %12 %97 %17 +OpStore %90 %100 +%103 = OpAccessChain %94 %30 %16 +%102 = OpAtomicAnd %4 %103 %12 %97 %17 +OpStore %90 %102 +%105 = OpAccessChain %94 %30 %16 +%104 = OpAtomicOr %4 %105 %12 %97 %17 +OpStore %90 %104 +%107 = OpAccessChain %94 %30 %16 +%106 = OpAtomicXor %4 %107 %12 %97 %17 +OpStore %90 %106 +%109 = OpAccessChain %94 %30 %16 +%108 = OpAtomicSMin %4 %109 %12 %97 %17 +OpStore %90 %108 +%111 = OpAccessChain %94 %30 %16 +%110 = OpAtomicSMax %4 %111 %12 %97 %17 +OpStore %90 %110 +%113 = OpAccessChain %94 %30 %16 +%112 = OpAtomicExchange %4 %113 %12 %97 %17 +OpStore %90 %112 +%114 = OpAccessChain %94 %30 %16 +OpAtomicStore %114 %12 %97 %96 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/access.wgsl b/tests/out/wgsl/access.wgsl index c041a20470..9b4c7b0e70 100644 --- a/tests/out/wgsl/access.wgsl +++ b/tests/out/wgsl/access.wgsl @@ -3,7 +3,7 @@ struct Bar { matrix: mat4x4; atom: atomic; arr: [[stride(8)]] array,2>; - data: [[stride(4)]] array; + data: [[stride(8)]] array; }; [[group(0), binding(0)]] @@ -30,6 +30,7 @@ fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4 { bar.matrix[1][2] = 1.0; bar.matrix = mat4x4(vec4(0.0), vec4(1.0), vec4(2.0), vec4(3.0)); bar.arr = array,2>(vec2(0u), vec2(1u)); + bar.data[1] = 1; c = array(a, i32(b), 3, 4, 5); c[(vi + 1u)] = 42; let value: i32 = c[vi];