Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make is_ascii_hexdigit branchless #103024

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion library/core/src/char/methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1510,7 +1510,10 @@ impl char {
#[rustc_const_stable(feature = "const_ascii_ctype_on_intrinsics", since = "1.47.0")]
#[inline]
pub const fn is_ascii_hexdigit(&self) -> bool {
matches!(*self, '0'..='9' | 'A'..='F' | 'a'..='f')
// Bitwise or converts A-Z to a-z, avoiding need for branches in compiled code.
const A: u32 = 'a' as u32;
const F: u32 = 'f' as u32;
matches!(*self, '0'..='9') || matches!(*self as u32 | 0x20, A..=F)
Copy link
Contributor

@clarfonthey clarfonthey Oct 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this end up optimising the same if you refactor it to just call the u8 method? Something like:

u8::try_from(*self).map_or(false, |c| c.is_ascii_hexdigit())

Mostly because it would be nice to be able to centralise the logic in one place, rather than duplicating it twice.

Copy link
Member

@scottmcm scottmcm Oct 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because it has a branch for the "is it < 256?" part: https://rust.godbolt.org/z/x13Yqx7Kx

That said, something like that might actually perform really well, because it'd probably branch predict well.

For example, imagine something like this:

if *self <= 'f' {
    let c = u8::try_from(*this).unwrap();
    c.is_ascii_hexdigit()
} else {
    false
}

Yes there's a jump, but it probably predicts great. Or at least as well as whatever jump you get from checking this condition in the first place...

As always, the complication in optimizing these is in determining the distribution of the input.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I was hoping that LLVM would be smart enough to optimise that out (since all other paths also involve a <N<256 comparison) but I guess not.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I filed an LLVM bug about this: llvm/llvm-project#60683

Even in the simpler case of is_ascii_digit it doesn't do as well as it ought to.

}

/// Checks if the value is an ASCII punctuation character:
Expand Down
3 changes: 2 additions & 1 deletion library/core/src/num/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,8 @@ impl u8 {
#[rustc_const_stable(feature = "const_ascii_ctype_on_intrinsics", since = "1.47.0")]
#[inline]
pub const fn is_ascii_hexdigit(&self) -> bool {
matches!(*self, b'0'..=b'9' | b'A'..=b'F' | b'a'..=b'f')
// Bitwise or converts A-Z to a-z, avoiding need for branches in compiled code.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor nit, but it feels like the lack of branches still isn't clear to the passive viewer since '0'..='9' and 'a'..='f' are still two separate ranges with a gap in the middle.

Oh, and uh, that || is looking suspiciously not-branchless. Since it ends up being so in the resulting assembly, perhaps it could be reworded as |?

Reading the compiled output:

example::char_is_hex_2:
        mov     eax, dword ptr [rdi]
        lea     ecx, [rax - 48]
        cmp     ecx, 10
        setb    cl
        or      eax, 32
        add     eax, -97
        cmp     eax, 6
        setb    al
        or      al, cl
        ret

Which is effectively:

(*self - 48 < 10) | ((*self | 0x20) - 97 < 6)

It's less that there are "fewer" branches in this code, but more that going from three ranges to two triggers a threshold in LLVM's side that makes it decide that branches are no longer worth it, and it removes them.

So... tying this all together, maybe the real point is not that it's branchless by itself, but that there are fewer computations overall and LLVM is likely to optimise it without branches as a result.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems reasonable, I can do one where all the optimization is done by hand to make it clear what the resulting assembly should be.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for pushing churn on you here, but I'm actually going to ask that you aim for a middle ground where the code is still as obvious as possible while still producing branchless assembly.

Particularly, there's no need to replace the range changes with wrapping_sub manually, as LLVM will do that quite reliably: https://rust.godbolt.org/z/rPaxbf1o7

It's normal in the library for the code to not be in the form of the expected assembly. For examples, is_power_of_two is phrased using count_ones, not the well-known bitwise tricks

pub const fn is_power_of_two(self) -> bool {
self.count_ones() == 1
}

as that leaves it up to LLVM to generate the best assembly sequence -- which will be the (x != 0) && (x & (x-1)) == 0 on i386, but libcore doesn't need to know that.

So please go back to the version with the consts you had in a previous iteration. And consider making a local for the lower-cased version of the character -- that would help localize the comment. Maybe something like

        // Bitwise or converts A-Z to a-z, allowing checking for the letters with a single range.
        let lower = *self as u32 | num::ASCII_CASE_MASK;
        const A: u32 = 'a' as u32;
        const F: u32 = 'f' as u32;
        // Not using logical or because the branch isn't worth it
        self.is_ascii_digit() | matches!(*self as u32 | 0x20, A..=F)

You could also consider splitting the hex_letter part into a separate (non-pub) function so that is_ascii_hexdigit becomes just

        // Not using logical or because the branch isn't worth it
        self.is_ascii_digit() | self.is_ascii_hexletter()

to isolate the conversion-and-range stuff to its own thing.

@rustbot author

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it looks like even the manual masking isn't necessary.

I was trying to make a repro to file an LLVM bug that it should figure out how to do that itself, and it turns out it already can. There's just something unfortunate about the matches! going on. Because just writing out the obvious comparisons does exactly what's needed: https://rust.godbolt.org/z/cPEoa1nT8

fn is_ascii_digit(c: u8) -> bool {
    c >= b'0' && c <= b'9'
}

fn is_ascii_hexletter(c: u8) -> bool {
    (c >= b'a' && c <= b'f') || (c >= b'A' && c <= b'F')
}

pub fn is_ascii_hexdigit(c: u8) -> bool {
    is_ascii_digit(c) || is_ascii_hexletter(c)
}

No branches, and does the subtraction and masking tricks:

define noundef zeroext i1 @_ZN7example17is_ascii_hexdigit17hc2736916760c6f12E(i8 %c) unnamed_addr #0 !dbg !6 {
  %0 = add i8 %c, -48, !dbg !11
  %1 = icmp ult i8 %0, 10, !dbg !11
  %2 = and i8 %c, -33, !dbg !14
  %3 = add i8 %2, -65, !dbg !14
  %4 = icmp ult i8 %3, 6, !dbg !14
  %.0 = select i1 %1, i1 true, i1 %4, !dbg !14
  ret i1 %.0, !dbg !15
}

(it picks & !32 instead of | 32, but same thing.)

I might still say to use | there, though, to get or i1 %1, %4 instead of the select, even though the x86 backend appears to use an or anyway.

Now I guess I need to figure out what's going wrong in matches!...

Copy link
Member

@scottmcm scottmcm Oct 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, it's just something about the or pattern in matches!. This is also branchless: https://rust.godbolt.org/z/9xqT38rz9

fn is_ascii_digit(c: u8) -> bool {
    matches!(c, b'0'..=b'9')
}

fn is_ascii_hexletter(c: u8) -> bool {
    matches!(c, b'A'..=b'F') | matches!(c, b'a'..=b'f')
}

pub fn is_ascii_hexdigit(c: u8) -> bool {
    is_ascii_digit(c) | is_ascii_hexletter(c)
}

(And yup, it works for char too: https://rust.godbolt.org/z/WvfsjMG6a.)

So I think, in the end, the right answer here is to just replace some short-circuiting rust operations with non-short circuiting ones (the logical or and pattern or each with a bitwise or instead) and call it a day, as that gives exactly the desired output.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Loving the investigative work here. Perhaps this is an actual codegen issue, since matches!(x, P) | matches!(x, Q) should ideally generate code rather close to matches!(x, P | Q). It feels like more of a codegen quirk than an LLVM optimizer quirk, but I could be wrong.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dug in some more, and it's actually more interesting than I expected!

It looks like with the or-pattern it's tripping a different technique in LLVM. https://rust.godbolt.org/z/zdW1T43T3

Starting with the nice obvious

pub fn is_ascii_hexletter(c: char) -> bool {
    matches!(c, 'A'..='F' | 'a'..='f')
}

SimplifyCfg and SROA get it down to the nice simple short-circuiting

define noundef zeroext i1 @example::is_ascii_hexletter(i32 noundef %c) unnamed_addr {
start:
  %_4 = icmp ule i32 65, %c
  %_5 = icmp ule i32 %c, 70
  %or.cond = and i1 %_4, %_5
  br i1 %or.cond, label %bb5, label %bb2

bb2:                                              ; preds = %start
  %_2 = icmp ule i32 97, %c
  %_3 = icmp ule i32 %c, 102
  %or.cond1 = and i1 %_2, %_3
  br i1 %or.cond1, label %bb5, label %bb4

bb5:                                              ; preds = %bb2, %start
  br label %bb6

bb4:                                              ; preds = %bb2
  br label %bb6

bb6:                                              ; preds = %bb4, %bb5
  %.0 = phi i8 [ 1, %bb5 ], [ 0, %bb4 ]
  %0 = trunc i8 %.0 to i1
  ret i1 %0
}

and InstCombine does the "I know how to simplify range checks like that" rewrite:

define noundef zeroext i1 @example::is_ascii_hexletter(i32 noundef %c) unnamed_addr {
start:
  %0 = add i32 %c, -65
  %1 = icmp ult i32 %0, 6
  br i1 %1, label %bb5, label %bb2

bb2:                                              ; preds = %start
  %2 = add i32 %c, -97
  %3 = icmp ult i32 %2, 6
  br i1 %3, label %bb5, label %bb4

bb5:                                              ; preds = %bb2, %start
  br label %bb6

bb4:                                              ; preds = %bb2
  br label %bb6

bb6:                                              ; preds = %bb4, %bb5
  %.0 = phi i1 [ true, %bb5 ], [ false, %bb4 ]
  ret i1 %.0
}

But then something really interesting happens, and SimplifyCfg says "wait, that looks like a lookup table!", giving

define noundef zeroext i1 @example::is_ascii_hexletter(i32 noundef %c) unnamed_addr {
start:
  switch i32 %c, label %bb4 [
    i32 102, label %bb6
    i32 101, label %bb6
    i32 100, label %bb6
    i32 99, label %bb6
    i32 98, label %bb6
    i32 97, label %bb6
    i32 70, label %bb6
    i32 69, label %bb6
    i32 68, label %bb6
    i32 67, label %bb6
    i32 66, label %bb6
    i32 65, label %bb6
  ]

bb4:                                              ; preds = %start
  br label %bb6

bb6:                                              ; preds = %start, %start, %start, %start, %start, %start, %start, %start, %start, %start, %start, %start, %bb4
  %.0 = phi i1 [ false, %bb4 ], [ true, %start ], [ true, %start ], [ true, %start ], [ true, %start ], [ true, %start ], [ true, %start ], [ true, %start ], [ true, %start ], [ true, %start ], [ true, %start ], [ true, %start ], [ true, %start ]
  ret i1 %.0
}

Then later SimplifyCfg looks at that again and say "wait, that's a weird switch, how about I do that with a shift instead?" by encoding the table into an i38 (because 'A' through 'f' is 38 values):

define noundef zeroext i1 @example::is_ascii_hexletter(i32 noundef %c) unnamed_addr {
start:
  %switch.tableidx = sub i32 %c, 65
  %0 = icmp ult i32 %switch.tableidx, 38
  %switch.cast = zext i32 %switch.tableidx to i38
  %switch.shiftamt = mul i38 %switch.cast, 1
  %switch.downshift = lshr i38 -4294967233, %switch.shiftamt
  %switch.masked = trunc i38 %switch.downshift to i1
  %.0 = select i1 %0, i1 %switch.masked, i1 false
  ret i1 %.0
}

which is also branchless. Just not quite as good a way.

matches!(*self, b'0'..=b'9') || matches!(*self | 0x20, b'a'..=b'f')
}

/// Checks if the value is an ASCII punctuation character:
Expand Down