Skip to content

Commit

Permalink
change type of hashcodes, and start testing
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Nov 13, 2021
1 parent 308abc1 commit 887167f
Show file tree
Hide file tree
Showing 12 changed files with 172 additions and 55 deletions.
6 changes: 3 additions & 3 deletions go/mysql/collations/8bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (c *Collation_8bit_bin) WeightString(dst, src []byte, numCodepoints int) []
return weightStringPadingSimple(' ', dst, numCodepoints-copyCodepoints, padToMax)
}

func (c *Collation_8bit_bin) Hash(src []byte, numCodepoints int) uintptr {
func (c *Collation_8bit_bin) Hash(src []byte, numCodepoints int) HashCode {
hash := 0x8b8b0000 | uintptr(c.id)
if numCodepoints == 0 {
return memhash(src, hash)
Expand Down Expand Up @@ -164,7 +164,7 @@ func (c *Collation_8bit_simple_ci) WeightString(dst, src []byte, numCodepoints i
return weightStringPadingSimple(' ', dst, numCodepoints-copyCodepoints, padToMax)
}

func (c *Collation_8bit_simple_ci) Hash(src []byte, numCodepoints int) uintptr {
func (c *Collation_8bit_simple_ci) Hash(src []byte, numCodepoints int) HashCode {
sortOrder := c.sort

var tocopy = len(src)
Expand Down Expand Up @@ -251,7 +251,7 @@ func (c *Collation_binary) WeightString(dst, src []byte, numCodepoints int) []by
return dst
}

func (c *Collation_binary) Hash(src []byte, numCodepoints int) uintptr {
func (c *Collation_binary) Hash(src []byte, numCodepoints int) HashCode {
if numCodepoints > 0 {
src = src[:numCodepoints]
}
Expand Down
4 changes: 3 additions & 1 deletion go/mysql/collations/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ type Collation interface {
// the hash will interpret the source string as if it were stored in a `CHAR(n)` column. If the value of
// numCodepoints is 0, this is equivalent to setting `numCodepoints = RuneCount(src)`.
// For collations with NO PAD, the numCodepoints argument is ignored.
Hash(src []byte, numCodepoints int) uintptr
Hash(src []byte, numCodepoints int) HashCode

// Charset returns the Charset with which this collation is encoded
Charset() charset.Charset
Expand All @@ -128,6 +128,8 @@ type Collation interface {
IsBinary() bool
}

type HashCode = uintptr

const PadToMax = math.MaxInt32

func minInt(i1, i2 int) int {
Expand Down
2 changes: 1 addition & 1 deletion go/mysql/collations/multibyte.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (c *Collation_multibyte) WeightString(dst, src []byte, numCodepoints int) [
return dst
}

func (c *Collation_multibyte) Hash(src []byte, numCodepoints int) uintptr {
func (c *Collation_multibyte) Hash(src []byte, numCodepoints int) HashCode {
cs := c.charset
sortOrder := c.sort

Expand Down
2 changes: 1 addition & 1 deletion go/mysql/collations/remote/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (c *Collation) WeightString(dst, src []byte, numCodepoints int) []byte {
return dst
}

func (c *Collation) Hash(_ []byte, _ int) uintptr {
func (c *Collation) Hash(_ []byte, _ int) collations.HashCode {
panic("unsupported: Hash for remote collations")
}

Expand Down
6 changes: 3 additions & 3 deletions go/mysql/collations/uca.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ performPadding:
return dst
}

func (c *Collation_utf8mb4_uca_0900) Hash(src []byte, _ int) uintptr {
func (c *Collation_utf8mb4_uca_0900) Hash(src []byte, _ int) HashCode {
var hash = uintptr(c.id)

it := c.uca.Iterator(src)
Expand Down Expand Up @@ -244,7 +244,7 @@ func (c *Collation_utf8mb4_0900_bin) WeightString(dst, src []byte, numCodepoints
return dst
}

func (c *Collation_utf8mb4_0900_bin) Hash(src []byte, _ int) uintptr {
func (c *Collation_utf8mb4_0900_bin) Hash(src []byte, _ int) HashCode {
return memhash(src, 0xb900b900)
}

Expand Down Expand Up @@ -354,7 +354,7 @@ func (c *Collation_uca_legacy) WeightString(dst, src []byte, numCodepoints int)
return dst
}

func (c *Collation_uca_legacy) Hash(src []byte, numCodepoints int) uintptr {
func (c *Collation_uca_legacy) Hash(src []byte, numCodepoints int) HashCode {
it := c.uca.Iterator(src)
defer it.Done()

Expand Down
4 changes: 2 additions & 2 deletions go/mysql/collations/unicode.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (c *Collation_unicode_general_ci) WeightString(dst, src []byte, numCodepoin
return dst
}

func (c *Collation_unicode_general_ci) Hash(src []byte, numCodepoints int) uintptr {
func (c *Collation_unicode_general_ci) Hash(src []byte, numCodepoints int) HashCode {
unicaseInfo := c.unicase
cs := c.charset

Expand Down Expand Up @@ -278,7 +278,7 @@ func (c *Collation_unicode_bin) weightStringUnicode(dst, src []byte, numCodepoin
return dst
}

func (c *Collation_unicode_bin) Hash(src []byte, numCodepoints int) uintptr {
func (c *Collation_unicode_bin) Hash(src []byte, numCodepoints int) HashCode {
if c.charset.SupportsSupplementaryChars() {
return c.hashUnicode(src, numCodepoints)
}
Expand Down
9 changes: 5 additions & 4 deletions go/vt/vtgate/engine/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ type Distinct struct {
type row = []sqltypes.Value

type probeTable struct {
m map[int64][]row
m map[evalengine.HashCode][]row
}

func (pt *probeTable) exists(inputRow row) (bool, error) {
// calculate hashcode from all column values in the input row
code := int64(17)
code := evalengine.HashCode(17)
for _, value := range inputRow {
hashcode, err := evalengine.NullsafeHashcode(value)
// TODO: fetch the correct collation from the semantic table
hashcode, err := evalengine.NullsafeHashcode(value, collations.Unknown)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -87,7 +88,7 @@ func equal(a, b []sqltypes.Value) (bool, error) {
}

func newProbeTable() *probeTable {
return &probeTable{m: map[int64][]row{}}
return &probeTable{m: map[uintptr][]row{}}
}

// TryExecute implements the Primitive interface
Expand Down
16 changes: 8 additions & 8 deletions go/vt/vtgate/engine/hash_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ type HashJoin struct {
LHSKey, RHSKey int

ASTPred sqlparser.Expr
}

type hashKey = int64
Collation collations.ID
}

// TryExecute implements the Primitive interface
func (hj *HashJoin) TryExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
Expand All @@ -63,13 +63,13 @@ func (hj *HashJoin) TryExecute(vcursor VCursor, bindVars map[string]*querypb.Bin
}

// build the probe table from the LHS result
probeTable := map[hashKey][]row{}
probeTable := map[evalengine.HashCode][]row{}
for _, current := range lresult.Rows {
joinVal := current[hj.LHSKey]
if joinVal.IsNull() {
continue
}
hashcode, err := evalengine.NullsafeHashcode(joinVal)
hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation)
if err != nil {
return nil, err
}
Expand All @@ -90,7 +90,7 @@ func (hj *HashJoin) TryExecute(vcursor VCursor, bindVars map[string]*querypb.Bin
if joinVal.IsNull() {
continue
}
hashcode, err := evalengine.NullsafeHashcode(joinVal)
hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation)
if err != nil {
return nil, err
}
Expand All @@ -116,7 +116,7 @@ func (hj *HashJoin) TryExecute(vcursor VCursor, bindVars map[string]*querypb.Bin
// TryStreamExecute implements the Primitive interface
func (hj *HashJoin) TryStreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
// build the probe table from the LHS result
probeTable := map[hashKey][]row{}
probeTable := map[evalengine.HashCode][]row{}
var lfields []*querypb.Field
err := vcursor.StreamExecutePrimitive(hj.Left, bindVars, wantfields, func(result *sqltypes.Result) error {
if len(lfields) == 0 && len(result.Fields) != 0 {
Expand All @@ -127,7 +127,7 @@ func (hj *HashJoin) TryStreamExecute(vcursor VCursor, bindVars map[string]*query
if joinVal.IsNull() {
continue
}
hashcode, err := evalengine.NullsafeHashcode(joinVal)
hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation)
if err != nil {
return err
}
Expand All @@ -151,7 +151,7 @@ func (hj *HashJoin) TryStreamExecute(vcursor VCursor, bindVars map[string]*query
if joinVal.IsNull() {
continue
}
hashcode, err := evalengine.NullsafeHashcode(joinVal)
hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation)
if err != nil {
return err
}
Expand Down
32 changes: 24 additions & 8 deletions go/vt/vtgate/evalengine/arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,23 +251,39 @@ func NullsafeCompare(v1, v2 sqltypes.Value, collationID collations.ID) (int, err
}
}

// HashCode is a type alias to make the code easier to read
type HashCode = uintptr

// NullsafeHashcode returns a HashCode that is guaranteed to be the same
// for two values that are considered equal by `NullsafeCompare`.
// TODO: should be extended to support all possible types
func NullsafeHashcode(v sqltypes.Value) (int64, error) {
if v.IsNull() {
return math.MaxInt64, nil
}
func NullsafeHashcode(v sqltypes.Value, collation collations.ID) (HashCode, error) {

if sqltypes.IsNumber(v.Type()) {
typ := v.Type()
switch {
case v.IsNull():
return HashCode(math.MaxInt64), nil
case sqltypes.IsNumber(typ):
result, err := newEvalResult(v)
if err != nil {
return 0, err
}
return hashCode(result), nil
return numericalHashCode(result), nil
case sqltypes.IsText(typ):
coll := collations.Default().LookupByID(collation)
return coll.Hash(v.Raw(), 0), nil
case sqltypes.IsDate(typ):
result, err := newEvalResult(v)
if err != nil {
return 0, err
}
time, err := parseDate(result)
if err != nil {
return 0, err
}
return uintptr(time.UnixNano()), nil
}

return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v", v.Type())
return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v", typ)
}

// isByteComparable returns true if the type is binary or date/time.
Expand Down
22 changes: 12 additions & 10 deletions go/vt/vtgate/evalengine/arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1326,8 +1326,8 @@ func TestCompareNumeric(t *testing.T) {

// if two values are considered equal, they must also produce the same hashcode
if result == 0 {
aHash := hashCode(aVal)
bHash := hashCode(bVal)
aHash := numericalHashCode(aVal)
bHash := numericalHashCode(bVal)
assert.Equal(t, aHash, bHash, "hash code does not match")
}
})
Expand Down Expand Up @@ -1552,20 +1552,22 @@ func TestMaxCollate(t *testing.T) {
func TestHashCodes(t *testing.T) {
n1 := sqltypes.NULL
n2 := sqltypes.Value{}

h1, err := NullsafeHashcode(n1)
collation := collations.Default().DefaultCollationForCharset("utf8mb4")
h1, err := NullsafeHashcode(n1, collation.ID())
require.NoError(t, err)
h2, err := NullsafeHashcode(n2)
h2, err := NullsafeHashcode(n2, collation.ID())
require.NoError(t, err)
assert.Equal(t, h1, h2)

char := TestValue(querypb.Type_VARCHAR, "aa")
_, err = NullsafeHashcode(char)
require.Error(t, err)
char := TestValue(querypb.Type_VARCHAR, "1")
h1, err = NullsafeHashcode(char, collation.ID())
require.NoError(t, err)

num := TestValue(querypb.Type_INT64, "123")
_, err = NullsafeHashcode(num)
num := TestValue(querypb.Type_INT64, "1")
h2, err = NullsafeHashcode(num, collation.ID())
require.NoError(t, err)

require.Equal(t, h1, h2)
}

func printValue(v sqltypes.Value) string {
Expand Down
24 changes: 10 additions & 14 deletions go/vt/vtgate/evalengine/evalengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package evalengine

import (
"math"
"time"

"vitess.io/vitess/go/mysql/collations"
Expand Down Expand Up @@ -243,21 +244,16 @@ func (v EvalResult) toSQLValue(resultType querypb.Type) sqltypes.Value {
return sqltypes.NULL
}

func hashCode(v EvalResult) int64 {
// we cast all numerical values to float64 and return the hashcode for that
var val float64
switch v.typ {
case sqltypes.Int64:
val = float64(v.ival)
case sqltypes.Uint64:
val = float64(v.uval)
case sqltypes.Float64:
val = v.fval
func numericalHashCode(v EvalResult) HashCode {
switch {
case sqltypes.IsSigned(v.typ):
return HashCode(v.ival)
case sqltypes.IsUnsigned(v.typ):
return HashCode(v.uval)
case sqltypes.IsFloat(v.typ) || v.typ == sqltypes.Decimal:
return HashCode(math.Float64bits(v.fval))
}

// this will not work for ±0, NaN and ±Inf,
// so one must still check using `compareNumeric` which will not be fooled
return int64(val)
panic("BUG: this is not a numerical value")
}

func compareNumeric(v1, v2 EvalResult) (int, error) {
Expand Down
Loading

0 comments on commit 887167f

Please sign in to comment.