From 9befaed7e985fb9a23322e9f19baf2a95bfd8b9c Mon Sep 17 00:00:00 2001 From: Evan Mark Hopkins <85699459+evahop@users.noreply.github.com> Date: Tue, 25 Apr 2023 08:49:26 -0400 Subject: [PATCH] [hlsl-out] Fix return type for firstbitlow/high (#2315) --- src/back/hlsl/writer.rs | 47 +++++++--- tests/in/math-functions.wgsl | 7 ++ .../glsl/math-functions.main.Fragment.glsl | 11 ++- tests/out/hlsl/math-functions.hlsl | 25 ++++-- tests/out/msl/math-functions.msl | 9 ++ tests/out/spv/math-functions.spvasm | 89 +++++++++++-------- tests/out/wgsl/math-functions.wgsl | 7 ++ 7 files changed, 132 insertions(+), 63 deletions(-) diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index a3810c5dab..9e541ae0ad 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2574,6 +2574,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Unpack2x16float, Regular(&'static str), MissingIntOverload(&'static str), + MissingIntReturnType(&'static str), CountTrailingZeros, CountLeadingZeros, } @@ -2642,8 +2643,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Mf::CountLeadingZeros => Function::CountLeadingZeros, Mf::CountOneBits => Function::MissingIntOverload("countbits"), Mf::ReverseBits => Function::MissingIntOverload("reversebits"), - Mf::FindLsb => Function::Regular("firstbitlow"), - Mf::FindMsb => Function::Regular("firstbithigh"), + Mf::FindLsb => Function::MissingIntReturnType("firstbitlow"), + Mf::FindMsb => Function::MissingIntReturnType("firstbithigh"), Mf::Unpack2x16float => Function::Unpack2x16float, _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))), }; @@ -2707,6 +2708,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ")")?; } } + Function::MissingIntReturnType(fun_name) => { + let scalar_kind = &func_ctx.info[arg] + .ty + .inner_with(&module.types) + .scalar_kind(); + if let Some(ScalarKind::Sint) = *scalar_kind { + write!(self.out, "asint({fun_name}(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "{fun_name}(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")")?; + } + } Function::CountTrailingZeros => { match *func_ctx.info[arg].ty.inner_with(&module.types) { TypeInner::Vector { size, kind, .. } => { @@ -2721,9 +2737,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { - write!(self.out, "asint(min((32u){s}, asuint(firstbitlow(")?; + write!(self.out, "asint(min((32u){s}, firstbitlow(")?; self.write_expr(module, arg, func_ctx)?; - write!(self.out, "))))")?; + write!(self.out, ")))")?; } } TypeInner::Scalar { kind, .. } => { @@ -2732,9 +2748,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_expr(module, arg, func_ctx)?; write!(self.out, "))")?; } else { - write!(self.out, "asint(min(32u, asuint(firstbitlow(")?; + write!(self.out, "asint(min(32u, firstbitlow(")?; self.write_expr(module, arg, func_ctx)?; - write!(self.out, "))))")?; + write!(self.out, ")))")?; } } _ => unreachable!(), @@ -2752,31 +2768,36 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { }; if let ScalarKind::Uint = kind { - write!(self.out, "asuint((31){s} - firstbithigh(")?; + write!(self.out, "((31u){s} - firstbithigh(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; } else { write!(self.out, "(")?; self.write_expr(module, arg, func_ctx)?; write!( self.out, - " < (0){s} ? (0){s} : (31){s} - firstbithigh(" + " < (0){s} ? (0){s} : (31){s} - asint(firstbithigh(" )?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")))")?; } } TypeInner::Scalar { kind, .. } => { if let ScalarKind::Uint = kind { - write!(self.out, "asuint(31 - firstbithigh(")?; + write!(self.out, "(31u - firstbithigh(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; } else { write!(self.out, "(")?; self.write_expr(module, arg, func_ctx)?; - write!(self.out, " < 0 ? 0 : 31 - firstbithigh(")?; + write!(self.out, " < 0 ? 0 : 31 - asint(firstbithigh(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")))")?; } } _ => unreachable!(), } - self.write_expr(module, arg, func_ctx)?; - write!(self.out, "))")?; - return Ok(()); } } diff --git a/tests/in/math-functions.wgsl b/tests/in/math-functions.wgsl index da67ca23e0..efeb988f5a 100644 --- a/tests/in/math-functions.wgsl +++ b/tests/in/math-functions.wgsl @@ -10,6 +10,13 @@ fn main() { let g = refract(v, v, f); let const_dot = dot(vec2(), vec2()); let first_leading_bit_abs = firstLeadingBit(abs(0u)); + let flb_a = firstLeadingBit(-1); + let flb_b = firstLeadingBit(vec2(-1)); + let flb_c = firstLeadingBit(vec2(1u)); + let ftb_a = firstTrailingBit(-1); + let ftb_b = firstTrailingBit(1u); + let ftb_c = firstTrailingBit(vec2(-1)); + let ftb_d = firstTrailingBit(vec2(1u)); let ctz_a = countTrailingZeros(0u); let ctz_b = countTrailingZeros(0); let ctz_c = countTrailingZeros(0xFFFFFFFFu); diff --git a/tests/out/glsl/math-functions.main.Fragment.glsl b/tests/out/glsl/math-functions.main.Fragment.glsl index 3c5c1dd345..11c3be1886 100644 --- a/tests/out/glsl/math-functions.main.Fragment.glsl +++ b/tests/out/glsl/math-functions.main.Fragment.glsl @@ -14,6 +14,13 @@ void main() { vec4 g = refract(v, v, 1.0); int const_dot = ( + ivec2(0, 0).x * ivec2(0, 0).x + ivec2(0, 0).y * ivec2(0, 0).y); uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u))))); + int flb_a = findMSB(-1); + ivec2 flb_b = findMSB(ivec2(-1)); + uvec2 flb_c = uvec2(findMSB(uvec2(1u))); + int ftb_a = findLSB(-1); + uint ftb_b = uint(findLSB(1u)); + ivec2 ftb_c = findLSB(ivec2(-1)); + uvec2 ftb_d = uvec2(findLSB(uvec2(1u))); uint ctz_a = min(uint(findLSB(0u)), 32u); int ctz_b = int(min(uint(findLSB(0)), 32u)); uint ctz_c = min(uint(findLSB(4294967295u)), 32u); @@ -24,8 +31,8 @@ void main() { ivec2 ctz_h = ivec2(min(uvec2(findLSB(ivec2(1))), uvec2(32u))); int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1)); uint clz_b = uint(31 - findMSB(1u)); - ivec2 _e40 = ivec2(-1); - ivec2 clz_c = mix(ivec2(31) - findMSB(_e40), ivec2(0), lessThan(_e40, ivec2(0))); + ivec2 _e58 = ivec2(-1); + ivec2 clz_c = mix(ivec2(31) - findMSB(_e58), ivec2(0), lessThan(_e58, ivec2(0))); uvec2 clz_d = uvec2(ivec2(31) - findMSB(uvec2(1u))); } diff --git a/tests/out/hlsl/math-functions.hlsl b/tests/out/hlsl/math-functions.hlsl index 958e77d80a..04a9843cb5 100644 --- a/tests/out/hlsl/math-functions.hlsl +++ b/tests/out/hlsl/math-functions.hlsl @@ -10,17 +10,24 @@ void main() float4 g = refract(v, v, 1.0); int const_dot = dot(int2(0, 0), int2(0, 0)); uint first_leading_bit_abs = firstbithigh(abs(0u)); + int flb_a = asint(firstbithigh(-1)); + int2 flb_b = asint(firstbithigh((-1).xx)); + uint2 flb_c = firstbithigh((1u).xx); + int ftb_a = asint(firstbitlow(-1)); + uint ftb_b = firstbitlow(1u); + int2 ftb_c = asint(firstbitlow((-1).xx)); + uint2 ftb_d = firstbitlow((1u).xx); uint ctz_a = min(32u, firstbitlow(0u)); - int ctz_b = asint(min(32u, asuint(firstbitlow(0)))); + int ctz_b = asint(min(32u, firstbitlow(0))); uint ctz_c = min(32u, firstbitlow(4294967295u)); - int ctz_d = asint(min(32u, asuint(firstbitlow(-1)))); + int ctz_d = asint(min(32u, firstbitlow(-1))); uint2 ctz_e = min((32u).xx, firstbitlow((0u).xx)); - int2 ctz_f = asint(min((32u).xx, asuint(firstbitlow((0).xx)))); + int2 ctz_f = asint(min((32u).xx, firstbitlow((0).xx))); uint2 ctz_g = min((32u).xx, firstbitlow((1u).xx)); - int2 ctz_h = asint(min((32u).xx, asuint(firstbitlow((1).xx)))); - int clz_a = (-1 < 0 ? 0 : 31 - firstbithigh(-1)); - uint clz_b = asuint(31 - firstbithigh(1u)); - int2 _expr40 = (-1).xx; - int2 clz_c = (_expr40 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr40)); - uint2 clz_d = asuint((31).xx - firstbithigh((1u).xx)); + int2 ctz_h = asint(min((32u).xx, firstbitlow((1).xx))); + int clz_a = (-1 < 0 ? 0 : 31 - asint(firstbithigh(-1))); + uint clz_b = (31u - firstbithigh(1u)); + int2 _expr58 = (-1).xx; + int2 clz_c = (_expr58 < (0).xx ? (0).xx : (31).xx - asint(firstbithigh(_expr58))); + uint2 clz_d = ((31u).xx - firstbithigh((1u).xx)); } diff --git a/tests/out/msl/math-functions.msl b/tests/out/msl/math-functions.msl index 3db4644cd6..04eb9d85f8 100644 --- a/tests/out/msl/math-functions.msl +++ b/tests/out/msl/math-functions.msl @@ -18,6 +18,15 @@ fragment void main_( int const_dot = ( + const_type_1_.x * const_type_1_.x + const_type_1_.y * const_type_1_.y); uint _e13 = metal::abs(0u); uint first_leading_bit_abs = metal::select(31 - metal::clz(_e13), uint(-1), _e13 == 0 || _e13 == -1); + int flb_a = metal::select(31 - metal::clz(metal::select(-1, ~-1, -1 < 0)), int(-1), -1 == 0 || -1 == -1); + metal::int2 _e18 = metal::int2(-1); + metal::int2 flb_b = metal::select(31 - metal::clz(metal::select(_e18, ~_e18, _e18 < 0)), int2(-1), _e18 == 0 || _e18 == -1); + metal::uint2 _e21 = metal::uint2(1u); + metal::uint2 flb_c = metal::select(31 - metal::clz(_e21), uint2(-1), _e21 == 0 || _e21 == -1); + int ftb_a = (((metal::ctz(-1) + 1) % 33) - 1); + uint ftb_b = (((metal::ctz(1u) + 1) % 33) - 1); + metal::int2 ftb_c = (((metal::ctz(metal::int2(-1)) + 1) % 33) - 1); + metal::uint2 ftb_d = (((metal::ctz(metal::uint2(1u)) + 1) % 33) - 1); uint ctz_a = metal::ctz(0u); int ctz_b = metal::ctz(0); uint ctz_c = metal::ctz(4294967295u); diff --git a/tests/out/spv/math-functions.spvasm b/tests/out/spv/math-functions.spvasm index 5f06ee2888..8dc763c0d3 100644 --- a/tests/out/spv/math-functions.spvasm +++ b/tests/out/spv/math-functions.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 76 +; Bound: 87 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -15,9 +15,9 @@ OpExecutionMode %18 OriginUpperLeft %6 = OpConstant %7 0 %9 = OpTypeInt 32 0 %8 = OpConstant %9 0 -%10 = OpConstant %9 4294967295 -%11 = OpConstant %7 -1 -%12 = OpConstant %9 1 +%10 = OpConstant %7 -1 +%11 = OpConstant %9 1 +%12 = OpConstant %9 4294967295 %13 = OpConstant %7 1 %14 = OpTypeVector %4 4 %15 = OpTypeVector %7 2 @@ -26,11 +26,11 @@ OpExecutionMode %18 OriginUpperLeft %27 = OpConstantComposite %14 %5 %5 %5 %5 %28 = OpConstantComposite %14 %3 %3 %3 %3 %31 = OpConstantNull %7 -%42 = OpConstant %9 32 -%50 = OpTypeVector %9 2 -%53 = OpConstantComposite %50 %42 %42 -%65 = OpConstant %7 31 -%71 = OpConstantComposite %15 %65 %65 +%44 = OpTypeVector %9 2 +%54 = OpConstant %9 32 +%64 = OpConstantComposite %44 %54 %54 +%76 = OpConstant %7 31 +%82 = OpConstantComposite %15 %76 %76 %18 = OpFunction %2 None %19 %17 = OpLabel OpBranch %20 @@ -52,35 +52,46 @@ OpBranch %20 %30 = OpIAdd %7 %35 %38 %39 = OpCopyObject %9 %8 %40 = OpExtInst %9 %1 FindUMsb %39 -%43 = OpExtInst %9 %1 FindILsb %8 -%41 = OpExtInst %9 %1 UMin %42 %43 -%45 = OpExtInst %7 %1 FindILsb %6 -%44 = OpExtInst %7 %1 UMin %42 %45 -%47 = OpExtInst %9 %1 FindILsb %10 -%46 = OpExtInst %9 %1 UMin %42 %47 -%49 = OpExtInst %7 %1 FindILsb %11 -%48 = OpExtInst %7 %1 UMin %42 %49 -%51 = OpCompositeConstruct %50 %8 %8 -%54 = OpExtInst %50 %1 FindILsb %51 -%52 = OpExtInst %50 %1 UMin %53 %54 -%55 = OpCompositeConstruct %15 %6 %6 -%57 = OpExtInst %15 %1 FindILsb %55 -%56 = OpExtInst %15 %1 UMin %53 %57 -%58 = OpCompositeConstruct %50 %12 %12 -%60 = OpExtInst %50 %1 FindILsb %58 -%59 = OpExtInst %50 %1 UMin %53 %60 -%61 = OpCompositeConstruct %15 %13 %13 -%63 = OpExtInst %15 %1 FindILsb %61 -%62 = OpExtInst %15 %1 UMin %53 %63 -%66 = OpExtInst %7 %1 FindUMsb %11 -%64 = OpISub %7 %65 %66 -%68 = OpExtInst %7 %1 FindUMsb %12 -%67 = OpISub %9 %65 %68 -%69 = OpCompositeConstruct %15 %11 %11 -%72 = OpExtInst %15 %1 FindUMsb %69 -%70 = OpISub %15 %71 %72 -%73 = OpCompositeConstruct %50 %12 %12 -%75 = OpExtInst %15 %1 FindUMsb %73 -%74 = OpISub %50 %71 %75 +%41 = OpExtInst %7 %1 FindSMsb %10 +%42 = OpCompositeConstruct %15 %10 %10 +%43 = OpExtInst %15 %1 FindSMsb %42 +%45 = OpCompositeConstruct %44 %11 %11 +%46 = OpExtInst %44 %1 FindUMsb %45 +%47 = OpExtInst %7 %1 FindILsb %10 +%48 = OpExtInst %9 %1 FindILsb %11 +%49 = OpCompositeConstruct %15 %10 %10 +%50 = OpExtInst %15 %1 FindILsb %49 +%51 = OpCompositeConstruct %44 %11 %11 +%52 = OpExtInst %44 %1 FindILsb %51 +%55 = OpExtInst %9 %1 FindILsb %8 +%53 = OpExtInst %9 %1 UMin %54 %55 +%57 = OpExtInst %7 %1 FindILsb %6 +%56 = OpExtInst %7 %1 UMin %54 %57 +%59 = OpExtInst %9 %1 FindILsb %12 +%58 = OpExtInst %9 %1 UMin %54 %59 +%61 = OpExtInst %7 %1 FindILsb %10 +%60 = OpExtInst %7 %1 UMin %54 %61 +%62 = OpCompositeConstruct %44 %8 %8 +%65 = OpExtInst %44 %1 FindILsb %62 +%63 = OpExtInst %44 %1 UMin %64 %65 +%66 = OpCompositeConstruct %15 %6 %6 +%68 = OpExtInst %15 %1 FindILsb %66 +%67 = OpExtInst %15 %1 UMin %64 %68 +%69 = OpCompositeConstruct %44 %11 %11 +%71 = OpExtInst %44 %1 FindILsb %69 +%70 = OpExtInst %44 %1 UMin %64 %71 +%72 = OpCompositeConstruct %15 %13 %13 +%74 = OpExtInst %15 %1 FindILsb %72 +%73 = OpExtInst %15 %1 UMin %64 %74 +%77 = OpExtInst %7 %1 FindUMsb %10 +%75 = OpISub %7 %76 %77 +%79 = OpExtInst %7 %1 FindUMsb %11 +%78 = OpISub %9 %76 %79 +%80 = OpCompositeConstruct %15 %10 %10 +%83 = OpExtInst %15 %1 FindUMsb %80 +%81 = OpISub %15 %82 %83 +%84 = OpCompositeConstruct %44 %11 %11 +%86 = OpExtInst %15 %1 FindUMsb %84 +%85 = OpISub %44 %82 %86 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/math-functions.wgsl b/tests/out/wgsl/math-functions.wgsl index e03648e1c3..acd925ac9a 100644 --- a/tests/out/wgsl/math-functions.wgsl +++ b/tests/out/wgsl/math-functions.wgsl @@ -9,6 +9,13 @@ fn main() { let g = refract(v, v, 1.0); let const_dot = dot(vec2(0, 0), vec2(0, 0)); let first_leading_bit_abs = firstLeadingBit(abs(0u)); + let flb_a = firstLeadingBit(-1); + let flb_b = firstLeadingBit(vec2(-1)); + let flb_c = firstLeadingBit(vec2(1u)); + let ftb_a = firstTrailingBit(-1); + let ftb_b = firstTrailingBit(1u); + let ftb_c = firstTrailingBit(vec2(-1)); + let ftb_d = firstTrailingBit(vec2(1u)); let ctz_a = countTrailingZeros(0u); let ctz_b = countTrailingZeros(0); let ctz_c = countTrailingZeros(4294967295u);