From 0ec6ead6db607a5ba0760ba291285674b6a36862 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Mon, 10 Jan 2022 21:00:13 -0500 Subject: [PATCH] utf8: AVX2 implementation of Valid (#58) --- build/ascii/valid_asm.go | 1 + build/utf8/valid_asm.go | 459 ++++++++++++++++++++++++++++++++++++ utf8/cmd/valid/README.md | 69 ++++++ utf8/cmd/valid/debug.gdb | 11 + utf8/cmd/valid/main.go | 61 +++++ utf8/utf8.go | 8 + utf8/valid.go | 31 +++ utf8/valid_amd64.go | 9 + utf8/valid_amd64.s | 253 ++++++++++++++++++++ utf8/valid_default.go | 10 + utf8/valid_go18_test.go | 25 ++ utf8/valid_support_amd64.go | 21 ++ utf8/valid_test.go | 304 ++++++++++++++++++++++++ 13 files changed, 1262 insertions(+) create mode 100644 build/utf8/valid_asm.go create mode 100644 utf8/cmd/valid/README.md create mode 100644 utf8/cmd/valid/debug.gdb create mode 100644 utf8/cmd/valid/main.go create mode 100644 utf8/utf8.go create mode 100644 utf8/valid.go create mode 100644 utf8/valid_amd64.go create mode 100644 utf8/valid_amd64.s create mode 100644 utf8/valid_default.go create mode 100644 utf8/valid_go18_test.go create mode 100644 utf8/valid_support_amd64.go create mode 100644 utf8/valid_test.go diff --git a/build/ascii/valid_asm.go b/build/ascii/valid_asm.go index 056b47ed..426b4c2f 100644 --- a/build/ascii/valid_asm.go +++ b/build/ascii/valid_asm.go @@ -1,3 +1,4 @@ +//go:build ignore // +build ignore package main diff --git a/build/utf8/valid_asm.go b/build/utf8/valid_asm.go new file mode 100644 index 00000000..5a8c73d2 --- /dev/null +++ b/build/utf8/valid_asm.go @@ -0,0 +1,459 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "bytes" + + . "github.com/mmcloughlin/avo/build" + . "github.com/mmcloughlin/avo/operand" + . "github.com/mmcloughlin/avo/reg" + . 
"github.com/segmentio/asm/build/internal/asm" +) + +func init() { + ConstraintExpr("!purego") +} + +func incompleteMaskData() []byte { + // The incomplete mask is used on every block to flag the bytes that are + // incomplete if this is the last block (for example a byte that starts + // a 4 byte character only 3 bytes before the end). + any := byte(0xFF) + needs4 := byte(0b11110000) - 1 + needs3 := byte(0b11100000) - 1 + needs2 := byte(0b11000000) - 1 + b := [32]byte{ + any, any, any, any, any, any, any, any, + any, any, any, any, any, any, any, any, + any, any, any, any, any, any, any, any, + any, any, any, any, any, needs4, needs3, needs2, + } + return b[:] +} + +func continuationMaskData(pattern byte) []byte { + // Pattern is something like 0b11100000 to accept all bytes of the form + // 111xxxxx. + v := pattern - 1 + return bytes.Repeat([]byte{v}, 32) +} + +func nibbleMasksData() (nib1, nib2, nib3 []byte) { + const ( + TooShort = 1 << 0 + TooLong = 1 << 1 + Overlong3 = 1 << 2 + Surrogate = 1 << 4 + Overlong2 = 1 << 5 + TwoConts = 1 << 7 + TooLarge = 1 << 3 + TooLarge1000 = 1 << 6 + Overlong4 = 1 << 6 + Carry = TooShort | TooLong | TwoConts + ) + + fullMask := func(b [16]byte) []byte { + m := make([]byte, 32) + copy(m, b[:]) + copy(m[16:], b[:]) + return m + } + + nib1 = fullMask([16]byte{ + // 0_______ ________ + TooLong, TooLong, TooLong, TooLong, + TooLong, TooLong, TooLong, TooLong, + // 10______ ________ + TwoConts, TwoConts, TwoConts, TwoConts, + // 1100____ ________ + TooShort | Overlong2, + // 1101____ ________ + TooShort, + // 1110____ ________ + TooShort | Overlong3 | Surrogate, + // 1111____ ________ + TooShort | TooLarge | TooLarge1000 | Overlong4, + }) + + nib2 = fullMask([16]byte{ + // ____0000 ________ + Carry | Overlong3 | Overlong2 | Overlong4, + // ____0001 ________ + Carry | Overlong2, + // ____001_ ________ + Carry, + Carry, + + // ____0100 ________ + Carry | TooLarge, + // ____0101 ________ + Carry | TooLarge | TooLarge1000, + // ____011_ 
________ + Carry | TooLarge | TooLarge1000, + Carry | TooLarge | TooLarge1000, + + // ____1___ ________ + Carry | TooLarge | TooLarge1000, + Carry | TooLarge | TooLarge1000, + Carry | TooLarge | TooLarge1000, + Carry | TooLarge | TooLarge1000, + Carry | TooLarge | TooLarge1000, + // ____1101 ________ + Carry | TooLarge | TooLarge1000 | Surrogate, + Carry | TooLarge | TooLarge1000, + Carry | TooLarge | TooLarge1000, + }) + + nib3 = fullMask([16]byte{ + // ________ 0_______ + TooShort, TooShort, TooShort, TooShort, + TooShort, TooShort, TooShort, TooShort, + + // ________ 1000____ + TooLong | Overlong2 | TwoConts | Overlong3 | TooLarge1000 | Overlong4, + // ________ 1001____ + TooLong | Overlong2 | TwoConts | Overlong3 | TooLarge, + // ________ 101_____ + TooLong | Overlong2 | TwoConts | Surrogate | TooLarge, + TooLong | Overlong2 | TwoConts | Surrogate | TooLarge, + + // ________ 11______ + TooShort, TooShort, TooShort, TooShort, + }) + + return +} + +func main() { + TEXT("validateAvx", NOSPLIT, "func(p []byte) byte") + Doc("Optimized version of Validate for inputs of more than 32B.") + + ret, err := ReturnIndex(0).Resolve() + if err != nil { + panic(err) + } + + d := Load(Param("p").Base(), GP64()) + n := Load(Param("p").Len(), GP64()) + + isAscii := GP8() + MOVB(Imm(1), isAscii) + + Comment("Prepare the constant masks") + + incompleteMask := ConstBytes("incomplete_mask", incompleteMaskData()) + incompleteMaskY := YMM() + VMOVDQU(incompleteMask, incompleteMaskY) + + continuation4Bytes := ConstBytes("cont4_vec", continuationMaskData(0b11110000)) + continuation4BytesY := YMM() + VMOVDQU(continuation4Bytes, continuation4BytesY) + + continuation3Bytes := ConstBytes("cont3_vec", continuationMaskData(0b11100000)) + continuation3BytesY := YMM() + VMOVDQU(continuation3Bytes, continuation3BytesY) + + nib1Data, nib2Data, nib3Data := nibbleMasksData() + + Comment("High nibble of current byte") + nibble1Errors := ConstBytes("nibble1_errors", nib1Data) + nibble1Y := YMM() + 
VMOVDQU(nibble1Errors, nibble1Y) + + Comment("Low nibble of current byte") + nibble2Errors := ConstBytes("nibble2_errors", nib2Data) + nibble2Y := YMM() + VMOVDQU(nibble2Errors, nibble2Y) + + Comment("High nibble of the next byte") + nibble3Errors := ConstBytes("nibble3_errors", nib3Data) + nibble3Y := YMM() + VMOVDQU(nibble3Errors, nibble3Y) + + Comment("Nibble mask") + lowerNibbleMask := ConstArray64("nibble_mask", + 0x0F0F0F0F0F0F0F0F, + 0x0F0F0F0F0F0F0F0F, + 0x0F0F0F0F0F0F0F0F, + 0x0F0F0F0F0F0F0F0F, + ) + + nibbleMaskY := YMM() + VMOVDQU(lowerNibbleMask, nibbleMaskY) + + Comment("MSB mask") + msbMask := ConstArray64("msb_mask", + 0x8080808080808080, + 0x8080808080808080, + 0x8080808080808080, + 0x8080808080808080, + ) + + msbMaskY := YMM() + VMOVDQU(msbMask, msbMaskY) + + Comment("For the first pass, set the previous block as zero.") + previousBlockY := YMM() + zeroOutVector(previousBlockY) + + Comment("Zeroes the error vector.") + errorY := YMM() + zeroOutVector(errorY) + + Comment(`Zeroes the "previous block was incomplete" vector.`) + incompletePreviousBlockY := YMM() + zeroOutVector(incompletePreviousBlockY) + + Comment("Top of the loop.") + Label("check_input") + + currentBlockY := YMM() + + Comment("if bytes left >= 32") + CMPQ(n, U8(32)) + Comment("go process the next block") + JGE(LabelRef("process")) + + Comment("If < 32 bytes left") + + Comment("Fast exit if done") + CMPQ(n, U8(0)) + JE(LabelRef("end")) + + // At this point we know we need to load up to 32 bytes of input to + // finish the validation and pad the rest of the input vector with + // zeroes. + // + // This code assumes that the remainder of the input data ends right + // before a page boundary. As a result, we need to take special care to + // avoid a page fault. + // + // At a high level: + // + // 1. Move back the data pointer so that the 32 bytes load ends exactly + // where the input does. + // + // 2. 
Shift right the loaded input so that the remaining input starts at + // the beginning of the vector. + // + // 3. Pad the rest of the vector with zeroes. + // + // Because AVX2 32 bytes vectors are really two 16 bytes vectors, we need + // to jump through hoops to perform the shift operation across + // lanes. This code has two versions, one for inputs of at most 16 + // bytes, and one for larger inputs. Though the latter has more steps, + // both work using a shuffle mask to shift the bytes in the vector, and + // a blend operation to stitch together the various pieces of the + // resulting vector. + // + // TODO: faster load code when not near a page boundary. + + Comment("If 0 < bytes left < 32") + + zeroes := YMM() + VPXOR(zeroes, zeroes, zeroes) + + shuffleMaskBytes := make([]byte, 3*16) + for i := byte(0); i < 16; i++ { + shuffleMaskBytes[i] = i + shuffleMaskBytes[i+16] = i + shuffleMaskBytes[i+32] = i + } + shuffleMask := ConstBytes("shuffle_mask", shuffleMaskBytes) + + shuffleClearMaskBytes := make([]byte, 3*16) + for i := byte(0); i < 16; i++ { + shuffleClearMaskBytes[i] = i + shuffleClearMaskBytes[i+16] = 0xFF + shuffleClearMaskBytes[i+32] = 0xFF + } + shuffleClearMask := ConstBytes("shuffle_clear_mask", shuffleClearMaskBytes) + + offset := GP64() + shuffleMaskPtr := GP64() + shuffle := YMM() + tmp1 := YMM() + + MOVQ(U64(32), offset) + SUBQ(n, offset) + + SUBQ(offset, d) + + VMOVDQU(Mem{Base: d}, currentBlockY) + + CMPQ(n, U8(16)) + JA(LabelRef("tail_load_large")) + + Comment("Shift right that works if remaining bytes <= 16, safe next to a page boundary") + + VPERM2I128(Imm(3), currentBlockY, zeroes, currentBlockY) + + LEAQ(shuffleClearMask.Offset(16), shuffleMaskPtr) + ADDQ(n, offset) + ADDQ(n, offset) + SUBQ(Imm(32), offset) + SUBQ(offset, shuffleMaskPtr) + VMOVDQU(Mem{Base: shuffleMaskPtr}, shuffle) + + VPSHUFB(shuffle, currentBlockY, currentBlockY) + + XORQ(n, n) + JMP(LabelRef("loaded")) + + Comment("Shift right that works if remaining bytes >= 16,
safe next to a page boundary") + Label("tail_load_large") + + ADDQ(n, offset) + ADDQ(n, offset) + SUBQ(Imm(48), offset) + + LEAQ(shuffleMask.Offset(16), shuffleMaskPtr) + SUBQ(offset, shuffleMaskPtr) + VMOVDQU(Mem{Base: shuffleMaskPtr}, shuffle) + + VPSHUFB(shuffle, currentBlockY, tmp1) + + tmp2 := YMM() + VPERM2I128(Imm(3), currentBlockY, zeroes, tmp2) + + VPSHUFB(shuffle, tmp2, tmp2) + + blendMaskBytes := make([]byte, 3*16) + for i := byte(0); i < 16; i++ { + blendMaskBytes[i] = 0xFF + blendMaskBytes[i+16] = 0x00 + blendMaskBytes[i+32] = 0xFF + } + blendMask := ConstBytes("blend_mask", blendMaskBytes) + + blendMaskStartPtr := GP64() + LEAQ(blendMask.Offset(16), blendMaskStartPtr) + SUBQ(offset, blendMaskStartPtr) + + blend := YMM() + VBROADCASTF128(Mem{Base: blendMaskStartPtr}, blend) + VPBLENDVB(blend, tmp1, tmp2, currentBlockY) + + XORQ(n, n) + JMP(LabelRef("loaded")) + + Comment("Process one 32B block of data") + Label("process") + + Comment("Load the next block of bytes") + VMOVDQU(Mem{Base: d}, currentBlockY) + SUBQ(U8(32), n) + ADDQ(U8(32), d) + + Label("loaded") + + Comment("Fast check to see if ASCII") + tmp := GP32() + VPMOVMSKB(currentBlockY, tmp) + CMPL(tmp, Imm(0)) + JNZ(LabelRef("non_ascii")) + + Comment("If this whole block is ASCII, there is nothing to do, and it is an error if any of the previous code point was incomplete.") + VPOR(errorY, incompletePreviousBlockY, errorY) + JMP(LabelRef("check_input")) + + Label("non_ascii") + XORB(isAscii, isAscii) + + Comment("Prepare intermediate vector for push operations") + vp := YMM() + VPERM2I128(Imm(3), previousBlockY, currentBlockY, vp) + + Comment("Check errors on the high nibble of the previous byte") + previousY := YMM() + VPALIGNR(Imm(15), vp, currentBlockY, previousY) + + highPrev := highNibbles(previousY, nibbleMaskY) + VPSHUFB(highPrev, nibble1Y, highPrev) + + Comment("Check errors on the low nibble of the previous byte") + lowPrev := lowNibbles(previousY, nibbleMaskY) + VPSHUFB(lowPrev, 
nibble2Y, lowPrev) + VPAND(lowPrev, highPrev, highPrev) + + Comment("Check errors on the high nibble on the current byte") + highCurr := highNibbles(currentBlockY, nibbleMaskY) + VPSHUFB(highCurr, nibble3Y, highCurr) + VPAND(highCurr, highPrev, highPrev) + + Comment("Find 3 bytes continuations") + off2 := YMM() + VPALIGNR(Imm(14), vp, currentBlockY, off2) + VPSUBUSB(continuation3BytesY, off2, off2) + + Comment("Find 4 bytes continuations") + off3 := YMM() + VPALIGNR(Imm(13), vp, currentBlockY, off3) + + VPSUBUSB(continuation4BytesY, off3, off3) + + Comment("Combine them to have all continuations") + continuationBitsY := YMM() + VPOR(off2, off3, continuationBitsY) + + Comment("Perform a byte-sized signed comparison with zero to turn any non-zero bytes into 0xFF.") + tmpY := zeroOutVector(YMM()) + VPCMPGTB(tmpY, continuationBitsY, continuationBitsY) + + Comment("Find bytes that are continuations by looking at their most significant bit.") + VPAND(msbMaskY, continuationBitsY, continuationBitsY) + + Comment("Find mismatches between expected and actual continuation bytes") + VPXOR(continuationBitsY, highPrev, continuationBitsY) + + Comment("Store result in sticky error") + VPOR(errorY, continuationBitsY, errorY) + + Comment("Prepare for next iteration") + VPSUBUSB(incompleteMaskY, currentBlockY, incompletePreviousBlockY) + VMOVDQU(currentBlockY, previousBlockY) + + Comment("End of loop") + JMP(LabelRef("check_input")) + + Label("end") + + Comment("If the previous block was incomplete, this is an error.") + VPOR(incompletePreviousBlockY, errorY, errorY) + + Comment("Return whether any error bit was set") + VPTEST(errorY, errorY) + out := GP8() + SETEQ(out) + + Comment("Bit 0 tells if the input is valid utf8, bit 1 tells if it's valid ascii") + ANDB(out, isAscii) + SHLB(Imm(1), isAscii) + ORB(isAscii, out) + + MOVB(out, ret.Addr) + VZEROUPPER() + RET() + + Generate() +} + +func lowNibbles(a VecVirtual, nibbleMask VecVirtual) VecVirtual { + out := YMM() + VPAND(a, 
nibbleMask, out) + return out +} + +func highNibbles(a VecVirtual, nibbleMask VecVirtual) VecVirtual { + out := YMM() + VPSRLW(Imm(4), a, out) + VPAND(out, nibbleMask, out) + return out +} + +func zeroOutVector(y VecVirtual) VecVirtual { + VXORPS(y, y, y) + return y +} diff --git a/utf8/cmd/valid/README.md b/utf8/cmd/valid/README.md new file mode 100644 index 00000000..1d165b59 --- /dev/null +++ b/utf8/cmd/valid/README.md @@ -0,0 +1,69 @@ +# valid + +This program is a helper to check the output of `utf8.Valid` and facilitate +debugging. It accepts some input, runs both this library and stdlib's version of +`utf8.Valid`, and prints out the result. + +## Usage + +Provide the input as the first argument to the program: + +``` +$ go run main.go "hello! 😊" +hello! 😊 +[104 101 108 108 111 33 32 240 159 152 138] +11 bytes +stdlib: utf8: true ascii: false +valid: utf8: true ascii: false v: 1 +``` + +The input is parsed as a double quoted Go string, so you can use escape codes: + +``` +$ go run main.go "\xFA" + +[250] +1 bytes +stdlib: utf8: false ascii: false +valid: utf8: false ascii: false v: 0 +``` + +Alternatively it can also consume input from stdin: + +``` +$ cat example.txt +hello! 😊 +$ go run main.go < example.txt +hello! 😊 + +[104 101 108 108 111 33 32 240 159 152 138 10] +12 bytes +stdlib: utf8: true ascii: false +valid: utf8: true ascii: false v: 1 +``` + +As a bonus, if the file is the result of a failure reported by Go 1.18 fuzz, the +program extracts the actual value of the test: + +``` +$ cat fuzz.out +go test fuzz +[]byte("000000000000000000~\xFF") +$ go run main.go < fuzz.out +Got fuzzer input +000000000000000000~ +[48 48 48 48 48 48 48 48 48 48 48 48 48 48 48 48 48 48 126 255] +20 bytes +stdlib: utf8: false ascii: false +valid: utf8: false ascii: false v: 0 +``` + +## GDB + +A useful way to debug is to run this program with some problematic input and use +GDB to step through the execution and inspect registers.
The `debug.gdb` file is +a basic helper to automate part of the process. For example: + +``` +$ go build main.go && gdb --command=debug.gdb -ex "set args < ./example.txt" ./main +``` diff --git a/utf8/cmd/valid/debug.gdb b/utf8/cmd/valid/debug.gdb new file mode 100644 index 00000000..9d805ce1 --- /dev/null +++ b/utf8/cmd/valid/debug.gdb @@ -0,0 +1,11 @@ +tui enable +tui reg all + +b github.com/segmentio/asm/utf8.validateAvx + +commands 1 +b +4 +c +end + +r diff --git a/utf8/cmd/valid/main.go b/utf8/cmd/valid/main.go new file mode 100644 index 00000000..f5df87a1 --- /dev/null +++ b/utf8/cmd/valid/main.go @@ -0,0 +1,61 @@ +package main + +import ( + "fmt" + "io/ioutil" + "os" + "regexp" + "strconv" + "strings" + stdlib "unicode/utf8" + + "github.com/segmentio/asm/ascii" + "github.com/segmentio/asm/utf8" +) + +func main() { + var data []byte + if len(os.Args) > 1 { + s := os.Args[1] + s, err := strconv.Unquote(`"` + s + `"`) + if err != nil { + panic(err) + } + data = []byte(s) + } else { + var err error + data, err = ioutil.ReadAll(os.Stdin) + if err != nil { + panic(err) + } + } + + s := string(data) + lines := strings.Split(s, "\n") + if len(lines) > 0 && strings.HasPrefix(lines[0], "go test fuzz") { + fmt.Println("Got fuzzer input") + // TODO: parse with go/parse instead of regexp? 
+ r := regexp.MustCompile(`^\[\]byte\((.+)\)`) + results := r.FindStringSubmatch(lines[1]) + s, err := strconv.Unquote(results[1]) + if err != nil { + panic(err) + } + data = []byte(s) + } + + fmt.Println(string(data)) + fmt.Println(data) + fmt.Println(len(data), "bytes") + + uref := stdlib.Valid(data) + aref := ascii.Valid(data) + fmt.Println("stdlib: utf8:", uref, "ascii:", aref) + + v := utf8.Validate(data) + fmt.Println("valid: utf8:", v.IsUTF8(), "ascii:", v.IsASCII(), "v:", v) + + if uref != v.IsUTF8() || aref != v.IsASCII() { + os.Exit(1) + } +} diff --git a/utf8/utf8.go b/utf8/utf8.go new file mode 100644 index 00000000..369d27a3 --- /dev/null +++ b/utf8/utf8.go @@ -0,0 +1,8 @@ +package utf8 + +import _ "github.com/segmentio/asm/cpu" + +// Valid reports whether p consists entirely of valid UTF-8-encoded runes. +func Valid(p []byte) bool { + return Validate(p).IsUTF8() +} diff --git a/utf8/valid.go b/utf8/valid.go new file mode 100644 index 00000000..edad758f --- /dev/null +++ b/utf8/valid.go @@ -0,0 +1,31 @@ +package utf8 + +import ( + "unicode/utf8" + + "github.com/segmentio/asm/ascii" +) + +type Validation byte + +const ( + Invalid = 0 + UTF8 = 0b01 + ASCII = 0b10 | UTF8 +) + +func (v Validation) IsASCII() bool { return (v & ASCII) == ASCII } + +func (v Validation) IsUTF8() bool { return (v & UTF8) == UTF8 } + +func (v Validation) IsInvalid() bool { return v == Invalid } + +func validate(p []byte) Validation { + if ascii.Valid(p) { + return ASCII + } + if utf8.Valid(p) { + return UTF8 + } + return Invalid +} diff --git a/utf8/valid_amd64.go b/utf8/valid_amd64.go new file mode 100644 index 00000000..d9e9612d --- /dev/null +++ b/utf8/valid_amd64.go @@ -0,0 +1,9 @@ +// Code generated by command: go run valid_asm.go -pkg utf8 -out ../utf8/valid_amd64.s -stubs ../utf8/valid_amd64.go. DO NOT EDIT. + +//go:build !purego +// +build !purego + +package utf8 + +// Optimized version of Validate for inputs of more than 32B. 
+func validateAvx(p []byte) byte diff --git a/utf8/valid_amd64.s b/utf8/valid_amd64.s new file mode 100644 index 00000000..501dfc62 --- /dev/null +++ b/utf8/valid_amd64.s @@ -0,0 +1,253 @@ +// Code generated by command: go run valid_asm.go -pkg utf8 -out ../utf8/valid_amd64.s -stubs ../utf8/valid_amd64.go. DO NOT EDIT. + +//go:build !purego +// +build !purego + +#include "textflag.h" + +// func validateAvx(p []byte) byte +// Requires: AVX, AVX2 +TEXT ·validateAvx(SB), NOSPLIT, $0-25 + MOVQ p_base+0(FP), AX + MOVQ p_len+8(FP), CX + MOVB $0x01, DL + + // Prepare the constant masks + VMOVDQU incomplete_mask<>+0(SB), Y0 + VMOVDQU cont4_vec<>+0(SB), Y1 + VMOVDQU cont3_vec<>+0(SB), Y2 + + // High nibble of current byte + VMOVDQU nibble1_errors<>+0(SB), Y3 + + // Low nibble of current byte + VMOVDQU nibble2_errors<>+0(SB), Y4 + + // High nibble of the next byte + VMOVDQU nibble3_errors<>+0(SB), Y5 + + // Nibble mask + VMOVDQU nibble_mask<>+0(SB), Y6 + + // MSB mask + VMOVDQU msb_mask<>+0(SB), Y7 + + // For the first pass, set the previous block as zero. + VXORPS Y8, Y8, Y8 + + // Zeroes the error vector. + VXORPS Y9, Y9, Y9 + + // Zeroes the "previous block was incomplete" vector. + VXORPS Y10, Y10, Y10 + + // Top of the loop. 
+check_input: + // if bytes left >= 32 + CMPQ CX, $0x20 + + // go process the next block + JGE process + + // If < 32 bytes left + // Fast exit if done + CMPQ CX, $0x00 + JE end + + // If 0 < bytes left < 32 + VPXOR Y12, Y12, Y12 + MOVQ $0x0000000000000020, BX + SUBQ CX, BX + SUBQ BX, AX + VMOVDQU (AX), Y11 + CMPQ CX, $0x10 + JA tail_load_large + + // Shift right that works if remaining bytes <= 16, safe next to a page boundary + VPERM2I128 $0x03, Y11, Y12, Y11 + LEAQ shuffle_clear_mask<>+16(SB), SI + ADDQ CX, BX + ADDQ CX, BX + SUBQ $0x20, BX + SUBQ BX, SI + VMOVDQU (SI), Y13 + VPSHUFB Y13, Y11, Y11 + XORQ CX, CX + JMP loaded + + // Shift right that works if remaining bytes >= 16, safe next to a page boundary +tail_load_large: + ADDQ CX, BX + ADDQ CX, BX + SUBQ $0x30, BX + LEAQ shuffle_mask<>+16(SB), SI + SUBQ BX, SI + VMOVDQU (SI), Y13 + VPSHUFB Y13, Y11, Y14 + VPERM2I128 $0x03, Y11, Y12, Y11 + VPSHUFB Y13, Y11, Y11 + LEAQ blend_mask<>+16(SB), CX + SUBQ BX, CX + VBROADCASTF128 (CX), Y12 + VPBLENDVB Y12, Y14, Y11, Y11 + XORQ CX, CX + JMP loaded + + // Process one 32B block of data +process: + // Load the next block of bytes + VMOVDQU (AX), Y11 + SUBQ $0x20, CX + ADDQ $0x20, AX + +loaded: + // Fast check to see if ASCII + VPMOVMSKB Y11, BX + CMPL BX, $0x00 + JNZ non_ascii + + // If this whole block is ASCII, there is nothing to do, and it is an error if any of the previous code point was incomplete. 
+ VPOR Y9, Y10, Y9 + JMP check_input + +non_ascii: + XORB DL, DL + + // Prepare intermediate vector for push operations + VPERM2I128 $0x03, Y8, Y11, Y8 + + // Check errors on the high nibble of the previous byte + VPALIGNR $0x0f, Y8, Y11, Y10 + VPSRLW $0x04, Y10, Y12 + VPAND Y12, Y6, Y12 + VPSHUFB Y12, Y3, Y12 + + // Check errors on the low nibble of the previous byte + VPAND Y10, Y6, Y10 + VPSHUFB Y10, Y4, Y10 + VPAND Y10, Y12, Y12 + + // Check errors on the high nibble on the current byte + VPSRLW $0x04, Y11, Y10 + VPAND Y10, Y6, Y10 + VPSHUFB Y10, Y5, Y10 + VPAND Y10, Y12, Y12 + + // Find 3 bytes continuations + VPALIGNR $0x0e, Y8, Y11, Y10 + VPSUBUSB Y2, Y10, Y10 + + // Find 4 bytes continuations + VPALIGNR $0x0d, Y8, Y11, Y8 + VPSUBUSB Y1, Y8, Y8 + + // Combine them to have all continuations + VPOR Y10, Y8, Y8 + + // Perform a byte-sized signed comparison with zero to turn any non-zero bytes into 0xFF. + VXORPS Y10, Y10, Y10 + VPCMPGTB Y10, Y8, Y8 + + // Find bytes that are continuations by looking at their most significant bit. + VPAND Y7, Y8, Y8 + + // Find mismatches between expected and actual continuation bytes + VPXOR Y8, Y12, Y8 + + // Store result in sticky error + VPOR Y9, Y8, Y9 + + // Prepare for next iteration + VPSUBUSB Y0, Y11, Y10 + VMOVDQU Y11, Y8 + + // End of loop + JMP check_input + +end: + // If the previous block was incomplete, this is an error. 
+ VPOR Y10, Y9, Y9 + + // Return whether any error bit was set + VPTEST Y9, Y9 + SETEQ AL + + // Bit 0 tells if the input is valid utf8, bit 1 tells if it's valid ascii + ANDB AL, DL + SHLB $0x01, DL + ORB DL, AL + MOVB AL, ret+24(FP) + VZEROUPPER + RET + +DATA incomplete_mask<>+0(SB)/8, $0xffffffffffffffff +DATA incomplete_mask<>+8(SB)/8, $0xffffffffffffffff +DATA incomplete_mask<>+16(SB)/8, $0xffffffffffffffff +DATA incomplete_mask<>+24(SB)/8, $0xbfdfefffffffffff +GLOBL incomplete_mask<>(SB), RODATA|NOPTR, $32 + +DATA cont4_vec<>+0(SB)/8, $0xefefefefefefefef +DATA cont4_vec<>+8(SB)/8, $0xefefefefefefefef +DATA cont4_vec<>+16(SB)/8, $0xefefefefefefefef +DATA cont4_vec<>+24(SB)/8, $0xefefefefefefefef +GLOBL cont4_vec<>(SB), RODATA|NOPTR, $32 + +DATA cont3_vec<>+0(SB)/8, $0xdfdfdfdfdfdfdfdf +DATA cont3_vec<>+8(SB)/8, $0xdfdfdfdfdfdfdfdf +DATA cont3_vec<>+16(SB)/8, $0xdfdfdfdfdfdfdfdf +DATA cont3_vec<>+24(SB)/8, $0xdfdfdfdfdfdfdfdf +GLOBL cont3_vec<>(SB), RODATA|NOPTR, $32 + +DATA nibble1_errors<>+0(SB)/8, $0x0202020202020202 +DATA nibble1_errors<>+8(SB)/8, $0x4915012180808080 +DATA nibble1_errors<>+16(SB)/8, $0x0202020202020202 +DATA nibble1_errors<>+24(SB)/8, $0x4915012180808080 +GLOBL nibble1_errors<>(SB), RODATA|NOPTR, $32 + +DATA nibble2_errors<>+0(SB)/8, $0xcbcbcb8b8383a3e7 +DATA nibble2_errors<>+8(SB)/8, $0xcbcbdbcbcbcbcbcb +DATA nibble2_errors<>+16(SB)/8, $0xcbcbcb8b8383a3e7 +DATA nibble2_errors<>+24(SB)/8, $0xcbcbdbcbcbcbcbcb +GLOBL nibble2_errors<>(SB), RODATA|NOPTR, $32 + +DATA nibble3_errors<>+0(SB)/8, $0x0101010101010101 +DATA nibble3_errors<>+8(SB)/8, $0x01010101babaaee6 +DATA nibble3_errors<>+16(SB)/8, $0x0101010101010101 +DATA nibble3_errors<>+24(SB)/8, $0x01010101babaaee6 +GLOBL nibble3_errors<>(SB), RODATA|NOPTR, $32 + +DATA nibble_mask<>+0(SB)/8, $0x0f0f0f0f0f0f0f0f +DATA nibble_mask<>+8(SB)/8, $0x0f0f0f0f0f0f0f0f +DATA nibble_mask<>+16(SB)/8, $0x0f0f0f0f0f0f0f0f +DATA nibble_mask<>+24(SB)/8, $0x0f0f0f0f0f0f0f0f +GLOBL nibble_mask<>(SB), 
RODATA|NOPTR, $32 + +DATA msb_mask<>+0(SB)/8, $0x8080808080808080 +DATA msb_mask<>+8(SB)/8, $0x8080808080808080 +DATA msb_mask<>+16(SB)/8, $0x8080808080808080 +DATA msb_mask<>+24(SB)/8, $0x8080808080808080 +GLOBL msb_mask<>(SB), RODATA|NOPTR, $32 + +DATA shuffle_mask<>+0(SB)/8, $0x0706050403020100 +DATA shuffle_mask<>+8(SB)/8, $0x0f0e0d0c0b0a0908 +DATA shuffle_mask<>+16(SB)/8, $0x0706050403020100 +DATA shuffle_mask<>+24(SB)/8, $0x0f0e0d0c0b0a0908 +DATA shuffle_mask<>+32(SB)/8, $0x0706050403020100 +DATA shuffle_mask<>+40(SB)/8, $0x0f0e0d0c0b0a0908 +GLOBL shuffle_mask<>(SB), RODATA|NOPTR, $48 + +DATA shuffle_clear_mask<>+0(SB)/8, $0x0706050403020100 +DATA shuffle_clear_mask<>+8(SB)/8, $0x0f0e0d0c0b0a0908 +DATA shuffle_clear_mask<>+16(SB)/8, $0xffffffffffffffff +DATA shuffle_clear_mask<>+24(SB)/8, $0xffffffffffffffff +DATA shuffle_clear_mask<>+32(SB)/8, $0xffffffffffffffff +DATA shuffle_clear_mask<>+40(SB)/8, $0xffffffffffffffff +GLOBL shuffle_clear_mask<>(SB), RODATA|NOPTR, $48 + +DATA blend_mask<>+0(SB)/8, $0xffffffffffffffff +DATA blend_mask<>+8(SB)/8, $0xffffffffffffffff +DATA blend_mask<>+16(SB)/8, $0x0000000000000000 +DATA blend_mask<>+24(SB)/8, $0x0000000000000000 +DATA blend_mask<>+32(SB)/8, $0xffffffffffffffff +DATA blend_mask<>+40(SB)/8, $0xffffffffffffffff +GLOBL blend_mask<>(SB), RODATA|NOPTR, $48 diff --git a/utf8/valid_default.go b/utf8/valid_default.go new file mode 100644 index 00000000..3301e9a6 --- /dev/null +++ b/utf8/valid_default.go @@ -0,0 +1,10 @@ +//go:build purego || !amd64 +// +build purego !amd64 + +package utf8 + +// Validate is a more precise version of Valid that also indicates whether the +// input was valid ASCII. 
+func Validate(p []byte) Validation { + return validate(p) +} diff --git a/utf8/valid_go18_test.go b/utf8/valid_go18_test.go new file mode 100644 index 00000000..301dca51 --- /dev/null +++ b/utf8/valid_go18_test.go @@ -0,0 +1,25 @@ +//go:build go1.18 +// +build go1.18 + +package utf8 + +import ( + "testing" + stdlib "unicode/utf8" + + "github.com/segmentio/asm/ascii" +) + +func FuzzValid(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + v := Validate(data) + ru := stdlib.Valid(data) + if ru != v.IsUTF8() { + t.Errorf("Validate(%q) UTF8 = %v; want %v", data, v.IsUTF8(), ru) + } + ra := ascii.Valid(data) + if ra != v.IsASCII() { + t.Errorf("Validate(%q) ASCII = %v; want %v", data, v.IsASCII(), ra) + } + }) +} diff --git a/utf8/valid_support_amd64.go b/utf8/valid_support_amd64.go new file mode 100644 index 00000000..c3e83289 --- /dev/null +++ b/utf8/valid_support_amd64.go @@ -0,0 +1,21 @@ +//go:build !purego +// +build !purego + +package utf8 + +import ( + "github.com/segmentio/asm/cpu" + "github.com/segmentio/asm/cpu/x86" +) + +var noAVX2 = !cpu.X86.Has(x86.AVX2) + +// Validate is a more precise version of Valid that also indicates whether the +// input was valid ASCII. 
// byteRange is an inclusive range of byte values, from Low to High.
type byteRange struct {
	Low  byte
	High byte
}

// one returns the range that contains only b.
func one(b byte) byteRange {
	return byteRange{b, b}
}

// genExamples returns every string made of current followed by one byte taken
// from each range in ranges, in order. For each range the bytes tried are its
// lower bound, its midpoint and its upper bound, with duplicates dropped, so a
// single-value range contributes exactly one byte.
//
// With no ranges left the result is just current.
func genExamples(current string, ranges []byteRange) []string {
	if len(ranges) == 0 {
		return []string{current}
	}
	r := ranges[0]

	// Candidates in ascending order. The midpoint is computed as an offset
	// from Low because Low+High overflows a byte for ranges above 0x7F.
	elements := []byte{r.Low}
	if r.High > r.Low {
		if mid := r.Low + (r.High-r.Low)/2; mid != r.Low {
			elements = append(elements, mid)
		}
	}
	if r.High != r.Low {
		elements = append(elements, r.High)
	}

	var all []string
	for _, x := range elements {
		// string([]byte{x}) appends the raw byte. string(x) would instead
		// UTF-8 encode the code point U+00xx, turning every byte >= 0x80
		// into a valid two-byte sequence.
		s := current + string([]byte{x})
		all = append(all, genExamples(s, ranges[1:])...)
	}
	return all
}
{one(0xF1)}, + {one(0xF1), cont}, + {one(0xF1), cont, cont}, + {one(0xF1), cont, cont, cont}, + {one(0xF1), cont, cont, ascii}, + {one(0xF1), cont, cont, cont, ascii}, + + // overlong + {{0xC0, 0xC1}, any}, + {{0xC0, 0xC1}, any, any}, + {{0xC0, 0xC1}, any, any, any}, + {one(0xE0), {0x0, 0x9F}, cont}, + {one(0xE0), {0xA0, 0xBF}, cont}, + } + + for _, r := range rangesToTest { + examples = append(examples, genExamples("", r)...) + } + + for _, i := range []int{300, 316} { + d := bytes.Repeat(someutf8, i/len(someutf8)) + examples = append(examples, string(d)) + } + + for _, tt := range examples { + t.Run(tt, func(t *testing.T) { + check(t, []byte(tt)) + }) + + // Generate variations of the input to exercise errors at the + // boundary, using the vector implementation on 32-sized input, + // and on non-32-sized inputs. + // + // Large examples don't go through those variations because they + // are likely specific tests. + + if len(tt) >= 32 { + continue + } + + t.Run("boundary-"+tt, func(t *testing.T) { + size := 32 - len(tt) + prefix := strings.Repeat("a", size) + b := []byte(prefix + tt) + check(t, b) + }) + t.Run("vec-padded-"+tt, func(t *testing.T) { + prefix := strings.Repeat("a", 32) + padding := strings.Repeat("b", 32-(len(tt)%32)) + input := prefix + padding + tt + b := []byte(input) + if len(b)%32 != 0 { + panic("test should generate block of 32") + } + check(t, b) + }) + t.Run("vec-"+tt, func(t *testing.T) { + prefix := strings.Repeat("a", 32) + input := prefix + tt + if len(tt)%32 == 0 { + input += "x" + } + b := []byte(input) + if len(b)%32 == 0 { + panic("test should not generate block of 32") + } + check(t, b) + }) + } +} + +func TestValidPageBoundary(t *testing.T) { + buf, err := buffer.New(64) + if err != nil { + t.Fatal(err) + } + defer buf.Release() + + head := buf.ProtectHead() + tail := buf.ProtectTail() + + data := bytes.Repeat(someutf8, 64/len(someutf8)) + + copy(head, data) + copy(tail, data) + + for i := 0; i <= 32; i++ { + input := head[:i] + 
check(t, input) + } + + for i := 0; i <= 32; i++ { + input := tail[i:] + check(t, input) + } +} + +func check(t *testing.T, b []byte) { + t.Helper() + + // Check that both Valid and Validate behave properly. Should not be + // necessary given the definition of Valid, but just in case. + + expected := utf8.Valid(b) + if Valid(b) != expected { + err := ioutil.WriteFile("test.out.txt", b, 0600) + if err != nil { + panic(err) + } + + t.Errorf("Valid(%q) = %v; want %v", string(b), !expected, expected) + } + + v := Validate(b) + + if v.IsUTF8() != expected { + t.Errorf("Validate(%q) utf8 valid: %v; want %v", string(b), !expected, expected) + } + + expected = ascii.Valid(b) + if v.IsASCII() != expected { + t.Errorf("Validate(%q) ascii valid: %v; want %v", string(b), !expected, expected) + } +} + +var valid1k = bytes.Repeat([]byte("0123456789日本語日本語日本語日abcdefghijklmnopqrstuvwx"), 16) +var valid1M = bytes.Repeat(valid1k, 1024) +var someutf8 = []byte("\xF4\x8F\xBF\xBF") + +func BenchmarkValid(b *testing.B) { + impls := map[string]func([]byte) bool{ + "AVX": Valid, + "Stdlib": utf8.Valid, + } + + type input struct { + name string + data []byte + } + inputs := []input{ + {"1kValid", valid1k}, + {"1MValid", valid1M}, + {"10ASCII", []byte("0123456789")}, + {"10Japan", []byte("日本語日本語日本語日")}, + } + + const KiB = 1024 + const MiB = 1048576 + + for i := 0; i <= 400/len(someutf8); i++ { + // for _, i := range []int{1 * KiB, 8 * KiB, 16 * KiB, 64 * KiB, 1 * MiB, 8 * MiB, 32 * MiB, 64 * MiB} { + d := bytes.Repeat(someutf8, i) + inputs = append(inputs, input{ + name: fmt.Sprintf("small%d", len(d)), + data: d, + }) + } + + for _, i := range []int{300, 316} { + d := bytes.Repeat(someutf8, i/len(someutf8)) + inputs = append(inputs, input{ + name: fmt.Sprintf("tail%d", len(d)), + data: d, + }) + } + + for _, input := range inputs { + for implName, f := range impls { + testName := fmt.Sprintf("%s/%s", input.name, implName) + + b.Run(testName, func(b *testing.B) { + 
b.SetBytes(int64(len(input.data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + f(input.data) + } + }) + } + } +}