Skip to content

Commit

Permalink
collations: add support for remote collations
Browse files Browse the repository at this point in the history
Signed-off-by: Vicent Marti <[email protected]>
  • Loading branch information
vmg committed Oct 19, 2021
1 parent 70b6837 commit 09d731f
Show file tree
Hide file tree
Showing 33 changed files with 379 additions and 55 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package endtoend
package integration

import (
"bufio"
"bytes"
"context"
"encoding/hex"
"flag"
"fmt"
Expand Down Expand Up @@ -147,30 +146,102 @@ func processSQLTest(t *testing.T, testfile string, conn *mysql.Conn) {
continue
}

res, err := conn.ExecuteFetch(query, -1, false)
if err != nil {
t.Fatalf("failed to execute %q: %v", query, err)
}

res := exec(t, conn, query)
if curtest != nil {
curtest.Test(t, res)
curtest = nil
}
}
}

var testOneCollation = flag.String("test-one-collation", "", "")

func TestCollationsOnMysqld(t *testing.T) {
conn, err := mysql.Connect(context.Background(), &connParams)
func exec(t *testing.T, conn *mysql.Conn, query string) *sqltypes.Result {
res, err := conn.ExecuteFetch(query, -1, true)
if err != nil {
t.Fatal(err)
t.Fatalf("failed to execute %q: %v", query, err)
}
return res
}

// GoldenWeightString returns the weight string that the remote collation named
// by collation produces for input, using conn to reach the server. The remote
// result is treated as the reference ("golden") value.
//
// The test is failed immediately, with the remote collation's last error, when
// the remote side returns a nil weight string.
func GoldenWeightString(t *testing.T, conn *mysql.Conn, collation string, input []byte) []byte {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	coll := collations.RemoteByName(conn, collation)
	weightString := coll.WeightString(nil, input, 0)
	if weightString == nil {
		t.Fatal(coll.LastError())
	}
	return weightString
}

// ExampleString is the sample input whose weight strings TestCollationWithSpace
// compares between the local and remote collations. It mixes ASCII letters,
// accented Latin letters and CJK characters.
const ExampleString = "abc æøå 日本語"

func TestCollationWithSpace(t *testing.T) {
conn := mysqlconn(t)
defer conn.Close()

if !strings.HasPrefix(conn.ServerVersion, "8.0.") {
t.Skipf("uca900 collations are only supported in MySQL 8.0+")
codepoints := len([]rune(ExampleString))

for _, collName := range []string{"utf8mb4_0900_ai_ci", "utf8mb4_unicode_ci", "utf8mb4_unicode_520_ci"} {
t.Run(collName, func(t *testing.T) {
local := collations.LookupByName(collName)
remote := collations.RemoteByName(conn, collName)

for _, size := range []int{0, codepoints, codepoints + 1, codepoints + 2, 20, 32} {
localWeight := local.WeightString(nil, []byte(ExampleString), size)
remoteWeight := remote.WeightString(nil, []byte(ExampleString), size)
if !bytes.Equal(localWeight, remoteWeight) {
t.Fatalf("mismatch at len=%d\ninput: %#v\nexpected: %#v\nactual: %#v",
size, []byte(ExampleString), remoteWeight, localWeight)
}
}
})
}
}

// normalizecmp reduces a three-way comparison result to its sign:
// -1 for any negative value, 1 for any positive value and 0 for zero.
func normalizecmp(res int) int {
	switch {
	case res < 0:
		return -1
	case res > 0:
		return 1
	default:
		return 0
	}
}

// TestRemoteKanaSensitivity checks that each listed collation orders two strings
// that differ only in hiragana vs. katakana the same way locally as the remote
// collation reached through the MySQL connection does.
func TestRemoteKanaSensitivity(t *testing.T) {
	const (
		Kana1 = "の東京ノ"
		Kana2 = "ノ東京の"
	)

	conn := mysqlconn(t)
	defer conn.Close()

	for _, collName := range []string{
		"utf8mb4_0900_as_cs",
		"utf8mb4_ja_0900_as_cs",
		"utf8mb4_ja_0900_as_cs_ks",
	} {
		t.Run(collName, func(t *testing.T) {
			localColl := collations.LookupByName(collName)
			remoteColl := collations.RemoteByName(conn, collName)

			// Only the local result is reduced to its sign; the remote value is used as returned.
			got := normalizecmp(localColl.Collate([]byte(Kana1), []byte(Kana2), false))
			want := remoteColl.Collate([]byte(Kana1), []byte(Kana2), false)

			// A failed remote call makes the comparison meaningless, so stop here.
			if err := remoteColl.LastError(); err != nil {
				t.Fatalf("remote collation failed: %v", err)
			}

			if got == want {
				return
			}
			t.Errorf("expected STRCMP(%q, %q) to be %d (got %d)", Kana1, Kana2, want, got)
		})
	}
}

// testOneCollation is the -test-one-collation flag. When it is non-empty,
// TestCollationsOnMysqld processes the file
// testdata/mysqltest/suite/collations/<value>.test.
var testOneCollation = flag.String("test-one-collation", "", "")

func TestCollationsOnMysqld(t *testing.T) {
conn := mysqlconn(t)
defer conn.Close()

if *testOneCollation != "" {
processSQLTest(t, fmt.Sprintf("testdata/mysqltest/suite/collations/%s.test", *testOneCollation), conn)
Expand Down
97 changes: 97 additions & 0 deletions go/mysql/collations/integration/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
Copyright 2021 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package integration

import (
"context"
"flag"
"fmt"
"os"
"strings"
"testing"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/vt/vttest"

vttestpb "vitess.io/vitess/go/vt/proto/vttest"
)

// connParams holds the MySQL connection parameters. TestMain assigns it from the
// launched cluster; mysqlconn and debugMysql read it.
var (
	connParams mysql.ConnParams
)

// waitmysql is the -waitmysql flag. When set, TestMain calls debugMysql after the
// cluster is up; debugMysql prints how to connect and then blocks forever, so the
// tests themselves are never run.
var waitmysql = flag.Bool("waitmysql", false, "")

// mysqlconn opens a connection to the MySQL instance described by connParams.
//
// The test is failed if the connection cannot be established. If the server
// version does not start with "8.0.", the connection is closed and the test is
// skipped. On success the caller owns the returned connection and must close it.
func mysqlconn(t *testing.T) *mysql.Conn {
	// Attribute failures and skips to the calling test, not to this helper.
	t.Helper()

	conn, err := mysql.Connect(context.Background(), &connParams)
	if err != nil {
		t.Fatal(err)
	}
	if !strings.HasPrefix(conn.ServerVersion, "8.0.") {
		// Release the connection before skipping; Skipf does not return.
		conn.Close()
		t.Skipf("collation integration tests are only supported in MySQL 8.0+")
	}
	return conn
}

func TestMain(m *testing.M) {
flag.Parse()

exitCode := func() int {
// Launch MySQL.
// We need a Keyspace in the topology, so the DbName is set.
// We need a Shard too, so the database 'vttest' is created.
cfg := vttest.Config{
Topology: &vttestpb.VTTestTopology{
Keyspaces: []*vttestpb.Keyspace{
{
Name: "vttest",
Shards: []*vttestpb.Shard{
{
Name: "0",
DbNameOverride: "vttest",
},
},
},
},
},
OnlyMySQL: true,
}
cluster := vttest.LocalCluster{
Config: cfg,
}
if err := cluster.Setup(); err != nil {
fmt.Fprintf(os.Stderr, "could not launch mysql: %v\n", err)
return 1
}
defer cluster.TearDown()

connParams = cluster.MySQLConnParams()

if *waitmysql {
debugMysql()
}
return m.Run()
}()
os.Exit(exitCode)
}

// debugMysql prints to stderr a mysql command line built from connParams
// (user, database and unix socket) and then blocks forever, so it never returns.
func debugMysql() {
	hint := fmt.Sprintf("Connect to MySQL using parameters: mysql -u %s -D %s -S %s\n",
		connParams.Uname, connParams.DbName, connParams.UnixSocket)
	fmt.Fprint(os.Stderr, hint)

	// Park this goroutine indefinitely.
	select {}
}
6 changes: 3 additions & 3 deletions go/mysql/collations/internal/uca/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ func (c *CollationLegacy) Weights() (WeightTable, TableLayout) {
return c.table, TableLayout_uca_legacy{c.maxCodepoint}
}

func (c *CollationLegacy) Iterator(input []byte) WeightIteratorLegacy {
iter := c.iterpool.Get().(WeightIteratorLegacy)
func (c *CollationLegacy) Iterator(input []byte) *WeightIteratorLegacy {
iter := c.iterpool.Get().(*WeightIteratorLegacy)
iter.reset(input)
return iter
}
Expand All @@ -113,7 +113,7 @@ func NewCollationLegacy(enc encoding.Encoding, weights WeightTable, weightPatche
}

coll.iterpool.New = func() interface{} {
return &iteratorLegacy{CollationLegacy: *coll}
return &WeightIteratorLegacy{CollationLegacy: *coll}
}

return coll
Expand Down
30 changes: 14 additions & 16 deletions go/mysql/collations/internal/uca/contractions.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,17 @@ func (t *trie) walk(remainder []byte) ([]uint16, []byte) {
return t.weights, remainder
}

func (t *trie) walkAnyEncoding(enc encoding.Encoding, remainder []byte) ([]uint16, []byte) {
func (t *trie) walkAnyEncoding(enc encoding.Encoding, remainder []byte, depth int) ([]uint16, []byte, int) {
if len(remainder) > 0 {
cp, width := enc.DecodeRune(remainder)
if cp == encoding.RuneError && width < 3 {
return nil, nil
return nil, nil, 0
}
if ch := t.children[cp]; ch != nil {
return ch.walkAnyEncoding(enc, remainder[width:])
return ch.walkAnyEncoding(enc, remainder[width:], depth+1)
}
}
return t.weights, remainder
return t.weights, remainder, depth
}

func (t *trie) insert(path []rune, weights []uint16) {
Expand Down Expand Up @@ -92,23 +92,21 @@ func (ctr *contractions) insert(c *Contraction) {
}

func (ctr *contractions) weightForContraction(cp rune, remainder []byte) ([]uint16, []byte) {
if ctr == nil {
return nil, nil
}
if tr := ctr.tr.children[cp]; tr != nil {
return tr.walk(remainder)
if ctr != nil {
if tr := ctr.tr.children[cp]; tr != nil {
return tr.walk(remainder)
}
}
return nil, nil
}

func (ctr *contractions) weightForContractionAnyEncoding(cp rune, remainder []byte, enc encoding.Encoding) ([]uint16, []byte) {
if ctr == nil {
return nil, nil
}
if tr := ctr.tr.children[cp]; tr != nil {
return tr.walkAnyEncoding(enc, remainder)
func (ctr *contractions) weightForContractionAnyEncoding(cp rune, remainder []byte, enc encoding.Encoding) ([]uint16, []byte, int) {
if ctr != nil {
if tr := ctr.tr.children[cp]; tr != nil {
return tr.walkAnyEncoding(enc, remainder, 0)
}
}
return nil, nil
return nil, nil, 0
}

func (ctr *contractions) weightForContextualContraction(cp, prev rune) []uint16 {
Expand Down
22 changes: 12 additions & 10 deletions go/mysql/collations/internal/uca/legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ import (
"vitess.io/vitess/go/mysql/collations/internal/encoding"
)

type iteratorLegacy struct {
type WeightIteratorLegacy struct {
// Constant
CollationLegacy

// Internal state
codepoint codepointIteratorLegacy
input []byte
length int
}

type codepointIteratorLegacy struct {
Expand Down Expand Up @@ -45,21 +46,22 @@ func (it *codepointIteratorLegacy) initContraction(weights []uint16) {
it.weights = weights
}

func (it *iteratorLegacy) reset(input []byte) {
func (it *WeightIteratorLegacy) reset(input []byte) {
it.input = input
it.length = 0
it.codepoint.weights = nil
}

func (it *iteratorLegacy) Done() {
func (it *WeightIteratorLegacy) Done() {
it.input = nil
it.iterpool.Put(it)
}

func (it *iteratorLegacy) DebugCodepoint() (rune, int) {
func (it *WeightIteratorLegacy) DebugCodepoint() (rune, int) {
return it.encoding.DecodeRune(it.input)
}

func (it *iteratorLegacy) Next() (uint16, bool) {
func (it *WeightIteratorLegacy) Next() (uint16, bool) {
for {
if w, ok := it.codepoint.next(); ok {
return w, true
Expand All @@ -70,21 +72,21 @@ func (it *iteratorLegacy) Next() (uint16, bool) {
return 0, false
}
it.input = it.input[width:]
it.length++

if cp > it.maxCodepoint {
return 0xFFFD, true
}
if weights, remainder := it.contractions.weightForContractionAnyEncoding(cp, it.input, it.encoding); weights != nil {
if weights, remainder, skip := it.contractions.weightForContractionAnyEncoding(cp, it.input, it.encoding); weights != nil {
it.codepoint.initContraction(weights)
it.input = remainder
it.length += skip
continue
}
it.codepoint.init(it.table, cp)
}
}

type WeightIteratorLegacy interface {
Next() (uint16, bool)
Done()
reset(input []byte)
// Length returns the iterator's length counter. reset zeroes the counter, and
// Next increments it once for every codepoint it decodes from the input, plus
// the extra count reported when a contraction is matched.
func (it *WeightIteratorLegacy) Length() int {
return it.length
}
Loading

0 comments on commit 09d731f

Please sign in to comment.