diff --git a/go/hack/runtime.go b/go/hack/runtime.go index 556fc868081..d4a01c56b70 100644 --- a/go/hack/runtime.go +++ b/go/hack/runtime.go @@ -51,3 +51,6 @@ func roundupsize(size uintptr) uintptr func RuntimeAllocSize(size int64) int64 { return int64(roundupsize(uintptr(size))) } + +//go:linkname ParseFloatPrefix strconv.parseFloatPrefix +func ParseFloatPrefix(s string, bitSize int) (float64, int, error) diff --git a/go/mysql/collations/8bit.go b/go/mysql/collations/8bit.go index 886cad01695..4eb3d5d5dbc 100644 --- a/go/mysql/collations/8bit.go +++ b/go/mysql/collations/8bit.go @@ -82,7 +82,7 @@ func (c *Collation_8bit_bin) WeightString(dst, src []byte, numCodepoints int) [] return weightStringPadingSimple(' ', dst, numCodepoints-copyCodepoints, padToMax) } -func (c *Collation_8bit_bin) Hash(src []byte, numCodepoints int) uintptr { +func (c *Collation_8bit_bin) Hash(src []byte, numCodepoints int) HashCode { hash := 0x8b8b0000 | uintptr(c.id) if numCodepoints == 0 { return memhash(src, hash) @@ -164,7 +164,7 @@ func (c *Collation_8bit_simple_ci) WeightString(dst, src []byte, numCodepoints i return weightStringPadingSimple(' ', dst, numCodepoints-copyCodepoints, padToMax) } -func (c *Collation_8bit_simple_ci) Hash(src []byte, numCodepoints int) uintptr { +func (c *Collation_8bit_simple_ci) Hash(src []byte, numCodepoints int) HashCode { sortOrder := c.sort var tocopy = len(src) @@ -251,7 +251,7 @@ func (c *Collation_binary) WeightString(dst, src []byte, numCodepoints int) []by return dst } -func (c *Collation_binary) Hash(src []byte, numCodepoints int) uintptr { +func (c *Collation_binary) Hash(src []byte, numCodepoints int) HashCode { if numCodepoints > 0 { src = src[:numCodepoints] } diff --git a/go/mysql/collations/collation.go b/go/mysql/collations/collation.go index 09607788d12..5c8e1642ff6 100644 --- a/go/mysql/collations/collation.go +++ b/go/mysql/collations/collation.go @@ -119,7 +119,7 @@ type Collation interface { // the hash will interpret the source 
string as if it were stored in a `CHAR(n)` column. If the value of // numCodepoints is 0, this is equivalent to setting `numCodepoints = RuneCount(src)`. // For collations with NO PAD, the numCodepoint argument is ignored. - Hash(src []byte, numCodepoints int) uintptr + Hash(src []byte, numCodepoints int) HashCode // Charset returns the Charset with which this collation is encoded Charset() charset.Charset @@ -128,6 +128,8 @@ type Collation interface { IsBinary() bool } +type HashCode = uintptr + const PadToMax = math.MaxInt32 func minInt(i1, i2 int) int { diff --git a/go/mysql/collations/multibyte.go b/go/mysql/collations/multibyte.go index d8a167ed364..08e946dd3cf 100644 --- a/go/mysql/collations/multibyte.go +++ b/go/mysql/collations/multibyte.go @@ -147,7 +147,7 @@ func (c *Collation_multibyte) WeightString(dst, src []byte, numCodepoints int) [ return dst } -func (c *Collation_multibyte) Hash(src []byte, numCodepoints int) uintptr { +func (c *Collation_multibyte) Hash(src []byte, numCodepoints int) HashCode { cs := c.charset sortOrder := c.sort diff --git a/go/mysql/collations/remote/collation.go b/go/mysql/collations/remote/collation.go index 3b8ff869b61..abc1f734c53 100644 --- a/go/mysql/collations/remote/collation.go +++ b/go/mysql/collations/remote/collation.go @@ -169,7 +169,7 @@ func (c *Collation) WeightString(dst, src []byte, numCodepoints int) []byte { return dst } -func (c *Collation) Hash(_ []byte, _ int) uintptr { +func (c *Collation) Hash(_ []byte, _ int) collations.HashCode { panic("unsupported: Hash for remote collations") } diff --git a/go/mysql/collations/uca.go b/go/mysql/collations/uca.go index b52712e6162..0a676790701 100644 --- a/go/mysql/collations/uca.go +++ b/go/mysql/collations/uca.go @@ -161,7 +161,7 @@ performPadding: return dst } -func (c *Collation_utf8mb4_uca_0900) Hash(src []byte, _ int) uintptr { +func (c *Collation_utf8mb4_uca_0900) Hash(src []byte, _ int) HashCode { var hash = uintptr(c.id) it := c.uca.Iterator(src) @@ -234,7 
+234,7 @@ func (c *Collation_utf8mb4_0900_bin) WeightString(dst, src []byte, numCodepoints return dst } -func (c *Collation_utf8mb4_0900_bin) Hash(src []byte, _ int) uintptr { +func (c *Collation_utf8mb4_0900_bin) Hash(src []byte, _ int) HashCode { return memhash(src, 0xb900b900) } @@ -340,7 +340,7 @@ func (c *Collation_uca_legacy) WeightString(dst, src []byte, numCodepoints int) return dst } -func (c *Collation_uca_legacy) Hash(src []byte, numCodepoints int) uintptr { +func (c *Collation_uca_legacy) Hash(src []byte, numCodepoints int) HashCode { it := c.uca.Iterator(src) defer it.Done() diff --git a/go/mysql/collations/unicode.go b/go/mysql/collations/unicode.go index a583b9ff6f1..a2c56162dc8 100644 --- a/go/mysql/collations/unicode.go +++ b/go/mysql/collations/unicode.go @@ -124,7 +124,7 @@ func (c *Collation_unicode_general_ci) WeightString(dst, src []byte, numCodepoin return dst } -func (c *Collation_unicode_general_ci) Hash(src []byte, numCodepoints int) uintptr { +func (c *Collation_unicode_general_ci) Hash(src []byte, numCodepoints int) HashCode { unicaseInfo := c.unicase cs := c.charset @@ -278,7 +278,7 @@ func (c *Collation_unicode_bin) weightStringUnicode(dst, src []byte, numCodepoin return dst } -func (c *Collation_unicode_bin) Hash(src []byte, numCodepoints int) uintptr { +func (c *Collation_unicode_bin) Hash(src []byte, numCodepoints int) HashCode { if c.charset.SupportsSupplementaryChars() { return c.hashUnicode(src, numCodepoints) } diff --git a/go/sqltypes/type.go b/go/sqltypes/type.go index df685f5758e..f1d4888c6a0 100644 --- a/go/sqltypes/type.go +++ b/go/sqltypes/type.go @@ -91,6 +91,11 @@ func IsDate(t querypb.Type) bool { return t == Datetime || t == Date || t == Timestamp || t == Time } +// IsNull returns true if the type is NULL type +func IsNull(t querypb.Type) bool { + return t == Null +} + // Vitess data types. These are idiomatically // named synonyms for the querypb.Type values. 
// Although these constants are interchangeable, diff --git a/go/sqltypes/value.go b/go/sqltypes/value.go index 283ba471392..5d1b5591167 100644 --- a/go/sqltypes/value.go +++ b/go/sqltypes/value.go @@ -227,6 +227,13 @@ func (v Value) Raw() []byte { return v.val } +// RawStr returns the internal representation of the value as a string instead +// of a byte slice. This is equivalent to calling `string(v.Raw())` but does +// not allocate. +func (v Value) RawStr() string { + return hack.String(v.val) +} + // ToBytes returns the value as MySQL would return it as []byte. // In contrast, Raw returns the internal representation of the Value, which may not // match MySQL's representation for newer types. diff --git a/go/test/endtoend/vtgate/gen4/column_name_test.go b/go/test/endtoend/vtgate/gen4/column_name_test.go index 2c238be8f3f..0cce81d99c8 100644 --- a/go/test/endtoend/vtgate/gen4/column_name_test.go +++ b/go/test/endtoend/vtgate/gen4/column_name_test.go @@ -20,6 +20,8 @@ import ( "context" "testing" + "vitess.io/vitess/go/test/endtoend/vtgate/utils" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -35,12 +37,10 @@ func TestColumnNames(t *testing.T) { require.NoError(t, err) defer conn.Close() - _, err = exec(t, conn, "create table uks.t2(id bigint,phone bigint,msg varchar(100),primary key(id)) Engine=InnoDB") - require.NoError(t, err) - defer exec(t, conn, "drop table uks.t2") + utils.Exec(t, conn, "create table uks.t2(id bigint,phone bigint,msg varchar(100),primary key(id)) Engine=InnoDB") + defer utils.Exec(t, conn, "drop table uks.t2") - qr, err := exec(t, conn, "SELECT t1.id as t1id, t2.id as t2id, t2.phone as t2phn FROM ks.t1 cross join uks.t2 where t1.id = t2.id ORDER BY t2.phone") - require.NoError(t, err) + qr := utils.Exec(t, conn, "SELECT t1.id as t1id, t2.id as t2id, t2.phone as t2phn FROM ks.t1 cross join uks.t2 where t1.id = t2.id ORDER BY t2.phone") assert.Equal(t, 3, len(qr.Fields)) assert.Equal(t, "t1id", 
qr.Fields[0].Name) diff --git a/go/test/endtoend/vtgate/gen4/gen4_test.go b/go/test/endtoend/vtgate/gen4/gen4_test.go index 1b318bd4558..92e679b6ffd 100644 --- a/go/test/endtoend/vtgate/gen4/gen4_test.go +++ b/go/test/endtoend/vtgate/gen4/gen4_test.go @@ -21,13 +21,13 @@ import ( "fmt" "testing" + "vitess.io/vitess/go/test/endtoend/vtgate/utils" + "github.com/stretchr/testify/assert" - "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql" - "vitess.io/vitess/go/sqltypes" ) func TestOrderBy(t *testing.T) { @@ -37,20 +37,20 @@ func TestOrderBy(t *testing.T) { defer conn.Close() defer func() { - _, _ = exec(t, conn, `delete from t1`) + _, _ = utils.ExecAllowError(t, conn, `delete from t1`) }() // insert some data. - checkedExec(t, conn, `insert into t1(id, col) values (100, 123),(10, 12),(1, 13),(1000, 1234)`) + utils.Exec(t, conn, `insert into t1(id, col) values (100, 123),(10, 12),(1, 13),(1000, 1234)`) // Gen4 only supported query. - assertMatches(t, conn, `select col from t1 order by id`, `[[INT64(13)] [INT64(12)] [INT64(123)] [INT64(1234)]]`) + utils.AssertMatches(t, conn, `select col from t1 order by id`, `[[INT64(13)] [INT64(12)] [INT64(123)] [INT64(1234)]]`) // Gen4 unsupported query. v3 supported. - assertMatches(t, conn, `select col from t1 order by 1`, `[[INT64(12)] [INT64(13)] [INT64(123)] [INT64(1234)]]`) + utils.AssertMatches(t, conn, `select col from t1 order by 1`, `[[INT64(12)] [INT64(13)] [INT64(123)] [INT64(1234)]]`) // unsupported in v3 and Gen4. 
- _, err = exec(t, conn, `select t1.* from t1 order by id`) + _, err = utils.ExecAllowError(t, conn, `select t1.* from t1 order by id`) require.Error(t, err) } @@ -61,15 +61,15 @@ func TestCorrelatedExistsSubquery(t *testing.T) { defer conn.Close() defer func() { - _, _ = exec(t, conn, `delete from t1`) - _, _ = exec(t, conn, `delete from t2`) + _, _ = utils.ExecAllowError(t, conn, `delete from t1`) + _, _ = utils.ExecAllowError(t, conn, `delete from t2`) }() // insert some data. - checkedExec(t, conn, `insert into t1(id, col) values (100, 123),(10, 12), (1, 13), (4, 13),(1000, 1234)`) - checkedExec(t, conn, `insert into t2(id, tcol1, tcol2) values (100, 13, 1),(9, 7, 15),(1, 123, 123),(1004, 134, 123)`) + utils.Exec(t, conn, `insert into t1(id, col) values (100, 123),(10, 12), (1, 13), (4, 13),(1000, 1234)`) + utils.Exec(t, conn, `insert into t2(id, tcol1, tcol2) values (100, 13, 1),(9, 7, 15),(1, 123, 123),(1004, 134, 123)`) - assertMatches(t, conn, `select id from t1 where exists(select 1 from t2 where t1.col = t2.tcol2)`, `[[INT64(100)]]`) - assertMatches(t, conn, `select id from t1 where exists(select 1 from t2 where t1.col = t2.tcol1) order by id`, `[[INT64(1)] [INT64(4)] [INT64(100)]]`) + utils.AssertMatches(t, conn, `select id from t1 where exists(select 1 from t2 where t1.col = t2.tcol2)`, `[[INT64(100)]]`) + utils.AssertMatches(t, conn, `select id from t1 where exists(select 1 from t2 where t1.col = t2.tcol1) order by id`, `[[INT64(1)] [INT64(4)] [INT64(100)]]`) } func TestGroupBy(t *testing.T) { @@ -79,25 +79,25 @@ func TestGroupBy(t *testing.T) { defer conn.Close() defer func() { - _, _ = exec(t, conn, `delete from t1`) - _, _ = exec(t, conn, `delete from t2`) + _, _ = utils.ExecAllowError(t, conn, `delete from t1`) + _, _ = utils.ExecAllowError(t, conn, `delete from t2`) }() // insert some data. 
- checkedExec(t, conn, `insert into t1(id, col) values (1, 123),(2, 12),(3, 13),(4, 1234)`) - checkedExec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'B')`) + utils.Exec(t, conn, `insert into t1(id, col) values (1, 123),(2, 12),(3, 13),(4, 1234)`) + utils.Exec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'B')`) // Gen4 only supported query. - assertMatches(t, conn, `select tcol2, tcol1, count(id) from t2 group by tcol2, tcol1`, + utils.AssertMatches(t, conn, `select tcol2, tcol1, count(id) from t2 group by tcol2, tcol1`, `[[VARCHAR("A") VARCHAR("A") INT64(2)] [VARCHAR("A") VARCHAR("B") INT64(1)] [VARCHAR("A") VARCHAR("C") INT64(1)] [VARCHAR("B") VARCHAR("C") INT64(1)] [VARCHAR("C") VARCHAR("A") INT64(1)] [VARCHAR("C") VARCHAR("B") INT64(2)]]`) - assertMatches(t, conn, `select tcol1, tcol1 from t2 order by tcol1`, + utils.AssertMatches(t, conn, `select tcol1, tcol1 from t2 order by tcol1`, `[[VARCHAR("A") VARCHAR("A")] [VARCHAR("A") VARCHAR("A")] [VARCHAR("A") VARCHAR("A")] [VARCHAR("B") VARCHAR("B")] [VARCHAR("B") VARCHAR("B")] [VARCHAR("B") VARCHAR("B")] [VARCHAR("C") VARCHAR("C")] [VARCHAR("C") VARCHAR("C")]]`) - assertMatches(t, conn, `select tcol1, tcol1 from t1 join t2 on t1.id = t2.id order by tcol1`, + utils.AssertMatches(t, conn, `select tcol1, tcol1 from t1 join t2 on t1.id = t2.id order by tcol1`, `[[VARCHAR("A") VARCHAR("A")] [VARCHAR("A") VARCHAR("A")] [VARCHAR("B") VARCHAR("B")] [VARCHAR("C") VARCHAR("C")]]`) - assertMatches(t, conn, `select count(*) k, tcol1, tcol2, "abc" b from t2 group by tcol1, tcol2, b order by k, tcol2, tcol1`, + utils.AssertMatches(t, conn, `select count(*) k, tcol1, tcol2, "abc" b from t2 group by tcol1, tcol2, b order by k, tcol2, tcol1`, `[[INT64(1) VARCHAR("B") VARCHAR("A") VARCHAR("abc")] `+ `[INT64(1) 
VARCHAR("C") VARCHAR("A") VARCHAR("abc")] `+ `[INT64(1) VARCHAR("C") VARCHAR("B") VARCHAR("abc")] `+ @@ -113,14 +113,14 @@ func TestJoinBindVars(t *testing.T) { defer conn.Close() defer func() { - _, _ = exec(t, conn, `delete from t2`) - _, _ = exec(t, conn, `delete from t3`) + _, _ = utils.ExecAllowError(t, conn, `delete from t2`) + _, _ = utils.ExecAllowError(t, conn, `delete from t3`) }() - checkedExec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'B')`) - checkedExec(t, conn, `insert into t3(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'B')`) + utils.Exec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'B')`) + utils.Exec(t, conn, `insert into t3(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'B')`) - assertMatches(t, conn, `select t2.tcol1 from t2 join t3 on t2.tcol2 = t3.tcol2 where t2.tcol1 = 'A'`, `[[VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")]]`) + utils.AssertMatches(t, conn, `select t2.tcol1 from t2 join t3 on t2.tcol2 = t3.tcol2 where t2.tcol1 = 'A'`, `[[VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")] [VARCHAR("A")]]`) } func TestDistinctAggregationFunc(t *testing.T) { @@ -129,36 +129,36 @@ func TestDistinctAggregationFunc(t *testing.T) { require.NoError(t, err) defer conn.Close() - defer exec(t, conn, `delete from t2`) + defer utils.ExecAllowError(t, conn, `delete from t2`) // insert some data. 
- checkedExec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'A')`) + utils.Exec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'A')`) // count on primary vindex - assertMatches(t, conn, `select tcol1, count(distinct id) from t2 group by tcol1`, + utils.AssertMatches(t, conn, `select tcol1, count(distinct id) from t2 group by tcol1`, `[[VARCHAR("A") INT64(3)] [VARCHAR("B") INT64(3)] [VARCHAR("C") INT64(2)]]`) // count on any column - assertMatches(t, conn, `select tcol1, count(distinct tcol2) from t2 group by tcol1`, + utils.AssertMatches(t, conn, `select tcol1, count(distinct tcol2) from t2 group by tcol1`, `[[VARCHAR("A") INT64(2)] [VARCHAR("B") INT64(2)] [VARCHAR("C") INT64(1)]]`) // sum of columns - assertMatches(t, conn, `select sum(id), sum(tcol1) from t2`, + utils.AssertMatches(t, conn, `select sum(id), sum(tcol1) from t2`, `[[DECIMAL(36) FLOAT64(0)]]`) // sum on primary vindex - assertMatches(t, conn, `select tcol1, sum(distinct id) from t2 group by tcol1`, + utils.AssertMatches(t, conn, `select tcol1, sum(distinct id) from t2 group by tcol1`, `[[VARCHAR("A") DECIMAL(9)] [VARCHAR("B") DECIMAL(15)] [VARCHAR("C") DECIMAL(12)]]`) // sum on any column - assertMatches(t, conn, `select tcol1, sum(distinct tcol2) from t2 group by tcol1`, + utils.AssertMatches(t, conn, `select tcol1, sum(distinct tcol2) from t2 group by tcol1`, `[[VARCHAR("A") DECIMAL(0)] [VARCHAR("B") DECIMAL(0)] [VARCHAR("C") DECIMAL(0)]]`) // insert more data to get values on sum - checkedExec(t, conn, `insert into t2(id, tcol1, tcol2) values (9, 'AA', null),(10, 'AA', '4'),(11, 'AA', '4'),(12, null, '5'),(13, null, '6'),(14, 'BB', '10'),(15, 'BB', '20'),(16, 'BB', 'X')`) + utils.Exec(t, conn, `insert into t2(id, tcol1, tcol2) values (9, 'AA', null),(10, 'AA', '4'),(11, 'AA', 
'4'),(12, null, '5'),(13, null, '6'),(14, 'BB', '10'),(15, 'BB', '20'),(16, 'BB', 'X')`) // multi distinct - assertMatches(t, conn, `select tcol1, count(distinct tcol2), sum(distinct tcol2) from t2 group by tcol1`, + utils.AssertMatches(t, conn, `select tcol1, count(distinct tcol2), sum(distinct tcol2) from t2 group by tcol1`, `[[NULL INT64(2) DECIMAL(11)] [VARCHAR("A") INT64(2) DECIMAL(0)] [VARCHAR("AA") INT64(1) DECIMAL(4)] [VARCHAR("B") INT64(2) DECIMAL(0)] [VARCHAR("BB") INT64(3) DECIMAL(30)] [VARCHAR("C") INT64(1) DECIMAL(0)]]`) } @@ -168,13 +168,13 @@ func TestDistinct(t *testing.T) { require.NoError(t, err) defer conn.Close() - defer exec(t, conn, `delete from t2`) + defer utils.ExecAllowError(t, conn, `delete from t2`) // insert some data. - checkedExec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'A')`) + utils.Exec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'A')`) // multi distinct - assertMatches(t, conn, `select distinct tcol1, tcol2 from t2`, + utils.AssertMatches(t, conn, `select distinct tcol1, tcol2 from t2`, `[[VARCHAR("A") VARCHAR("A")] [VARCHAR("A") VARCHAR("C")] [VARCHAR("B") VARCHAR("A")] [VARCHAR("B") VARCHAR("C")] [VARCHAR("C") VARCHAR("A")]]`) } @@ -185,33 +185,33 @@ func TestSubQueries(t *testing.T) { defer conn.Close() defer func() { - _, _ = exec(t, conn, `delete from t2`) - _, _ = exec(t, conn, `delete from t3`) - _, _ = exec(t, conn, `delete from u_a`) + _, _ = utils.ExecAllowError(t, conn, `delete from t2`) + _, _ = utils.ExecAllowError(t, conn, `delete from t3`) + _, _ = utils.ExecAllowError(t, conn, `delete from u_a`) }() - checkedExec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'B')`) - 
checkedExec(t, conn, `insert into t3(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'B')`) + utils.Exec(t, conn, `insert into t2(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'B')`) + utils.Exec(t, conn, `insert into t3(id, tcol1, tcol2) values (1, 'A', 'A'),(2, 'B', 'C'),(3, 'A', 'C'),(4, 'C', 'A'),(5, 'A', 'A'),(6, 'B', 'C'),(7, 'B', 'A'),(8, 'C', 'B')`) - assertMatches(t, conn, `select t2.tcol1, t2.tcol2 from t2 where t2.id IN (select id from t3) order by t2.id`, `[[VARCHAR("A") VARCHAR("A")] [VARCHAR("B") VARCHAR("C")] [VARCHAR("A") VARCHAR("C")] [VARCHAR("C") VARCHAR("A")] [VARCHAR("A") VARCHAR("A")] [VARCHAR("B") VARCHAR("C")] [VARCHAR("B") VARCHAR("A")] [VARCHAR("C") VARCHAR("B")]]`) - assertMatches(t, conn, `select t2.tcol1, t2.tcol2 from t2 where t2.id IN (select t3.id from t3 join t2 on t2.id = t3.id) order by t2.id`, `[[VARCHAR("A") VARCHAR("A")] [VARCHAR("B") VARCHAR("C")] [VARCHAR("A") VARCHAR("C")] [VARCHAR("C") VARCHAR("A")] [VARCHAR("A") VARCHAR("A")] [VARCHAR("B") VARCHAR("C")] [VARCHAR("B") VARCHAR("A")] [VARCHAR("C") VARCHAR("B")]]`) + utils.AssertMatches(t, conn, `select t2.tcol1, t2.tcol2 from t2 where t2.id IN (select id from t3) order by t2.id`, `[[VARCHAR("A") VARCHAR("A")] [VARCHAR("B") VARCHAR("C")] [VARCHAR("A") VARCHAR("C")] [VARCHAR("C") VARCHAR("A")] [VARCHAR("A") VARCHAR("A")] [VARCHAR("B") VARCHAR("C")] [VARCHAR("B") VARCHAR("A")] [VARCHAR("C") VARCHAR("B")]]`) + utils.AssertMatches(t, conn, `select t2.tcol1, t2.tcol2 from t2 where t2.id IN (select t3.id from t3 join t2 on t2.id = t3.id) order by t2.id`, `[[VARCHAR("A") VARCHAR("A")] [VARCHAR("B") VARCHAR("C")] [VARCHAR("A") VARCHAR("C")] [VARCHAR("C") VARCHAR("A")] [VARCHAR("A") VARCHAR("A")] [VARCHAR("B") VARCHAR("C")] [VARCHAR("B") VARCHAR("A")] [VARCHAR("C") VARCHAR("B")]]`) - assertMatches(t, conn, `select 
u_a.a from u_a left join t2 on t2.id IN (select id from t2)`, `[]`) + utils.AssertMatches(t, conn, `select u_a.a from u_a left join t2 on t2.id IN (select id from t2)`, `[]`) //inserting some data in u_a - checkedExec(t, conn, `insert into u_a(id, a) values (1, 1)`) + utils.Exec(t, conn, `insert into u_a(id, a) values (1, 1)`) // execute same query again. - qr := checkedExec(t, conn, `select u_a.a from u_a left join t2 on t2.id IN (select id from t2)`) + qr := utils.Exec(t, conn, `select u_a.a from u_a left join t2 on t2.id IN (select id from t2)`) assert.EqualValues(t, 8, len(qr.Rows)) for index, row := range qr.Rows { assert.EqualValues(t, `[INT64(1)]`, fmt.Sprintf("%v", row), "does not match for row: %d", index+1) } // fail as projection subquery is not scalar - _, err = exec(t, conn, `select (select id from t2) from t2 order by id`) + _, err = utils.ExecAllowError(t, conn, `select (select id from t2) from t2 order by id`) assert.EqualError(t, err, "subquery returned more than one row (errno 1105) (sqlstate HY000) during query: select (select id from t2) from t2 order by id") - assertMatches(t, conn, `select (select id from t2 order by id limit 1) from t2 order by id limit 2`, `[[INT64(1)] [INT64(1)]]`) + utils.AssertMatches(t, conn, `select (select id from t2 order by id limit 1) from t2 order by id limit 2`, `[[INT64(1)] [INT64(1)]]`) } func TestPlannerWarning(t *testing.T) { @@ -221,39 +221,36 @@ func TestPlannerWarning(t *testing.T) { defer conn.Close() // straight_join query - _ = checkedExec(t, conn, `select 1 from t1 straight_join t2 on t1.id = t2.id`) - assertMatches(t, conn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`) + _ = utils.Exec(t, conn, `select 1 from t1 straight_join t2 on t1.id = t2.id`) + utils.AssertMatches(t, conn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`) // execute same query again. 
- _ = checkedExec(t, conn, `select 1 from t1 straight_join t2 on t1.id = t2.id`) - assertMatches(t, conn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`) + _ = utils.Exec(t, conn, `select 1 from t1 straight_join t2 on t1.id = t2.id`) + utils.AssertMatches(t, conn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`) // random query to reset the warning. - _ = checkedExec(t, conn, `select 1 from t1`) + _ = utils.Exec(t, conn, `select 1 from t1`) // execute same query again. - _ = checkedExec(t, conn, `select 1 from t1 straight_join t2 on t1.id = t2.id`) - assertMatches(t, conn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`) + _ = utils.Exec(t, conn, `select 1 from t1 straight_join t2 on t1.id = t2.id`) + utils.AssertMatches(t, conn, `show warnings`, `[[VARCHAR("Warning") UINT16(1235) VARCHAR("straight join is converted to normal join")]]`) } -func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) { - t.Helper() - qr := checkedExec(t, conn, query) - got := fmt.Sprintf("%v", qr.Rows) - diff := cmp.Diff(expected, got) - if diff != "" { - t.Errorf("Query: %s (-want +got):\n%s", query, diff) - } -} +func TestHashJoin(t *testing.T) { + ctx := context.Background() + conn, err := mysql.Connect(ctx, &vtParams) + require.NoError(t, err) + defer conn.Close() -func checkedExec(t *testing.T, conn *mysql.Conn, query string) *sqltypes.Result { - t.Helper() - qr, err := exec(t, conn, query) - require.NoError(t, err, "for query: "+query) - return qr -} + defer func() { + _, _ = utils.ExecAllowError(t, conn, `delete from t1`) + }() + + utils.Exec(t, conn, `insert into t1(id, col) values (1, 1),(2, 3),(3, 4),(4, 7)`) + + utils.AssertMatches(t, conn, `select /*vt+ ALLOW_HASH_JOIN */ t1.id from t1 x join t1 where x.col = t1.col and x.id <= 3 and t1.id >= 3`, `[[INT64(3)]]`) -func exec(t 
*testing.T, conn *mysql.Conn, query string) (*sqltypes.Result, error) { - t.Helper() - return conn.ExecuteFetch(query, 1000, true) + utils.Exec(t, conn, `set workload = olap`) + defer utils.Exec(t, conn, `set workload = oltp`) + utils.AssertMatches(t, conn, `select /*vt+ ALLOW_HASH_JOIN */ t1.id from t1 x join t1 where x.col = t1.col and x.id <= 3 and t1.id >= 3`, `[[INT64(3)]]`) } diff --git a/go/test/endtoend/vtgate/gen4/system_schema_test.go b/go/test/endtoend/vtgate/gen4/system_schema_test.go index 1701a137507..de002259dba 100644 --- a/go/test/endtoend/vtgate/gen4/system_schema_test.go +++ b/go/test/endtoend/vtgate/gen4/system_schema_test.go @@ -21,6 +21,8 @@ import ( "fmt" "testing" + "vitess.io/vitess/go/test/endtoend/vtgate/utils" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -84,7 +86,7 @@ func TestInformationSchemaWithSubquery(t *testing.T) { require.NoError(t, err) defer conn.Close() - result := checkedExec(t, conn, "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = (SELECT SCHEMA()) AND TABLE_NAME = 'not_exists'") + result := utils.Exec(t, conn, "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = (SELECT SCHEMA()) AND TABLE_NAME = 'not_exists'") assert.Empty(t, result.Rows) } @@ -95,8 +97,8 @@ func TestInformationSchemaQueryGetsRoutedToTheRightTableAndKeyspace(t *testing.T require.NoError(t, err) defer conn.Close() - _ = checkedExec(t, conn, "SELECT id FROM ks.t1000") // test that the routed table is available to us - result := checkedExec(t, conn, "SELECT * FROM information_schema.tables WHERE table_schema = database() and table_name='ks.t1000'") + _ = utils.Exec(t, conn, "SELECT id FROM ks.t1000") // test that the routed table is available to us + result := utils.Exec(t, conn, "SELECT * FROM information_schema.tables WHERE table_schema = database() and table_name='ks.t1000'") assert.NotEmpty(t, result.Rows) } @@ -107,12 +109,12 @@ func TestFKConstraintUsingInformationSchema(t 
*testing.T) { require.NoError(t, err) defer conn.Close() - checkedExec(t, conn, "create table ks.t7_xxhash(uid varchar(50),phone bigint,msg varchar(100),primary key(uid)) Engine=InnoDB") - checkedExec(t, conn, "create table ks.t7_fk(id bigint,t7_uid varchar(50),primary key(id),CONSTRAINT t7_fk_ibfk_1 foreign key (t7_uid) references t7_xxhash(uid) on delete set null on update cascade) Engine=InnoDB;") - defer checkedExec(t, conn, "drop table ks.t7_fk, ks.t7_xxhash") + utils.Exec(t, conn, "create table ks.t7_xxhash(uid varchar(50),phone bigint,msg varchar(100),primary key(uid)) Engine=InnoDB") + utils.Exec(t, conn, "create table ks.t7_fk(id bigint,t7_uid varchar(50),primary key(id),CONSTRAINT t7_fk_ibfk_1 foreign key (t7_uid) references t7_xxhash(uid) on delete set null on update cascade) Engine=InnoDB;") + defer utils.Exec(t, conn, "drop table ks.t7_fk, ks.t7_xxhash") query := "select fk.referenced_table_name as to_table, fk.referenced_column_name as primary_key, fk.column_name as `column`, fk.constraint_name as name, rc.update_rule as on_update, rc.delete_rule as on_delete from information_schema.referential_constraints as rc join information_schema.key_column_usage as fk using (constraint_schema, constraint_name) where fk.referenced_column_name is not null and fk.table_schema = database() and fk.table_name = 't7_fk' and rc.constraint_schema = database() and rc.table_name = 't7_fk'" - assertMatches(t, conn, query, `[[VARCHAR("t7_xxhash") VARCHAR("uid") VARCHAR("t7_uid") VARCHAR("t7_fk_ibfk_1") VARCHAR("CASCADE") VARCHAR("SET NULL")]]`) + utils.AssertMatches(t, conn, query, `[[VARCHAR("t7_xxhash") VARCHAR("uid") VARCHAR("t7_uid") VARCHAR("t7_fk_ibfk_1") VARCHAR("CASCADE") VARCHAR("SET NULL")]]`) } func TestConnectWithSystemSchema(t *testing.T) { @@ -123,7 +125,7 @@ func TestConnectWithSystemSchema(t *testing.T) { connParams.DbName = dbname conn, err := mysql.Connect(ctx, &connParams) require.NoError(t, err) - checkedExec(t, conn, `select @@max_allowed_packet from 
dual`) + utils.Exec(t, conn, `select @@max_allowed_packet from dual`) conn.Close() } } @@ -135,8 +137,8 @@ func TestUseSystemSchema(t *testing.T) { require.NoError(t, err) defer conn.Close() for _, dbname := range []string{"information_schema", "mysql", "performance_schema", "sys"} { - checkedExec(t, conn, fmt.Sprintf("use %s", dbname)) - checkedExec(t, conn, `select @@max_allowed_packet from dual`) + utils.Exec(t, conn, fmt.Sprintf("use %s", dbname)) + utils.Exec(t, conn, `select @@max_allowed_packet from dual`) } } @@ -153,16 +155,16 @@ func TestSystemSchemaQueryWithoutQualifier(t *testing.T) { "on c.table_schema = t.table_schema and c.table_name = t.table_name "+ "where t.table_schema = '%s' and c.table_schema = '%s' "+ "order by t.table_schema,t.table_name,c.column_name", shardedKs, shardedKs) - qr1 := checkedExec(t, conn, queryWithQualifier) + qr1 := utils.Exec(t, conn, queryWithQualifier) - checkedExec(t, conn, "use information_schema") + utils.Exec(t, conn, "use information_schema") queryWithoutQualifier := fmt.Sprintf("select t.table_schema,t.table_name,c.column_name,c.column_type "+ "from tables t "+ "join columns c "+ "on c.table_schema = t.table_schema and c.table_name = t.table_name "+ "where t.table_schema = '%s' and c.table_schema = '%s' "+ "order by t.table_schema,t.table_name,c.column_name", shardedKs, shardedKs) - qr2 := checkedExec(t, conn, queryWithoutQualifier) + qr2 := utils.Exec(t, conn, queryWithoutQualifier) require.Equal(t, qr1, qr2) connParams := vtParams @@ -171,7 +173,7 @@ func TestSystemSchemaQueryWithoutQualifier(t *testing.T) { require.NoError(t, err) defer conn2.Close() - qr3 := checkedExec(t, conn2, queryWithoutQualifier) + qr3 := utils.Exec(t, conn2, queryWithoutQualifier) require.Equal(t, qr2, qr3) } @@ -187,7 +189,7 @@ func TestMultipleSchemaPredicates(t *testing.T) { "join information_schema.columns c "+ "on c.table_schema = t.table_schema and c.table_name = t.table_name "+ "where t.table_schema = '%s' and c.table_schema = '%s' 
and c.table_schema = '%s' and c.table_schema = '%s'", shardedKs, shardedKs, shardedKs, shardedKs) - qr1 := checkedExec(t, conn, query) + qr1 := utils.Exec(t, conn, query) require.EqualValues(t, 4, len(qr1.Fields)) // test a query with two keyspace names diff --git a/go/vt/sqlparser/comments.go b/go/vt/sqlparser/comments.go index dc5bdc85344..3c86f1f05cd 100644 --- a/go/vt/sqlparser/comments.go +++ b/go/vt/sqlparser/comments.go @@ -39,6 +39,8 @@ const ( DirectiveIgnoreMaxMemoryRows = "IGNORE_MAX_MEMORY_ROWS" // DirectiveAllowScatter lets scatter plans pass through even when they are turned off by `no-scatter`. DirectiveAllowScatter = "ALLOW_SCATTER" + // DirectiveAllowHashJoin lets the planner use hash join if possible + DirectiveAllowHashJoin = "ALLOW_HASH_JOIN" ) func isNonSpace(r rune) bool { @@ -212,8 +214,7 @@ func ExtractCommentDirectives(comments Comments) CommentDirectives { var vals map[string]interface{} - for _, comment := range comments { - commentStr := string(comment) + for _, commentStr := range comments { if commentStr[0:5] != commentDirectivePreamble { continue } diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 26d83c0aba8..210a2db812c 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -183,7 +183,7 @@ func (cached *Filter) CachedSize(alloc bool) int64 { if alloc { size += int64(48) } - // field Predicate vitess.io/vitess/go/vt/vtgate/evalengine.Expr + // field ASTPred vitess.io/vitess/go/vt/vtgate/evalengine.Expr if cc, ok := cached.Predicate.(cachedObject); ok { size += cc.CachedSize(true) } diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 089b0537c7f..e415e683d4f 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -80,7 +80,7 @@ func (c *Concatenate) TryExecute(vcursor VCursor, bindVars map[string]*querypb.B return nil, err } - var rowsAffected uint64 // default 0 + var rowsAffected 
uint64 var rows [][]sqltypes.Value for _, r := range res { diff --git a/go/vt/vtgate/engine/distinct.go b/go/vt/vtgate/engine/distinct.go index e96680cbd3c..1c45ac77db7 100644 --- a/go/vt/vtgate/engine/distinct.go +++ b/go/vt/vtgate/engine/distinct.go @@ -28,37 +28,37 @@ var _ Primitive = (*Distinct)(nil) // Distinct Primitive is used to uniqueify results type Distinct struct { - Source Primitive + Source Primitive + ColCollations []collations.ID } type row = []sqltypes.Value type probeTable struct { - m map[int64][]row + seenRows map[evalengine.HashCode][]row + colCollations []collations.ID } func (pt *probeTable) exists(inputRow row) (bool, error) { + // the two prime numbers used here (17 and 31) are used to + // calculate hashcode from all column values in the input row - code := int64(17) - for _, value := range inputRow { - hashcode, err := evalengine.NullsafeHashcode(value) - if err != nil { - return false, err - } - code = code*31 + hashcode + code, err := pt.hashCodeForRow(inputRow) + if err != nil { + return false, err } - existingRows, found := pt.m[code] + existingRows, found := pt.seenRows[code] if !found { // nothing with this hash code found, we can be sure it's a not seen row - pt.m[code] = []row{inputRow} + pt.seenRows[code] = []row{inputRow} return false, nil } // we found something in the map - still need to check all individual values // so we don't just fall for a hash collision for _, existingRow := range existingRows { - exists, err := equal(existingRow, inputRow) + exists, err := equal(existingRow, inputRow, pt.colCollations) if err != nil { return false, err } @@ -67,15 +67,68 @@ func (pt *probeTable) exists(inputRow row) (bool, error) { } } - pt.m[code] = append(existingRows, inputRow) + pt.seenRows[code] = append(existingRows, inputRow) return false, nil } -func equal(a, b []sqltypes.Value) (bool, error) { +func (pt *probeTable) hashCodeForRow(inputRow row) (evalengine.HashCode, error) { + // Why use 17 and 31 in this method? 
+ // Copied from an old usenet discussion on the topic: + // https://groups.google.com/g/comp.programming/c/HSurZEyrZ1E?pli=1#d887b5bdb2dac99d + // > It's a mixture of superstition and good sense. + // > Suppose the multiplier were 26, and consider + // > hashing a hundred-character string. How much influence does + // > the string's first character have on the final value of `h', + // > just before the mod operation? The first character's value + // > will have been multiplied by MULT 99 times, so if the arithmetic + // > were done in infinite precision the value would consist of some + // > jumble of bits followed by 99 low-order zero bits -- each time + // > you multiply by MULT you introduce another low-order zero, right? + // > The computer's finite arithmetic just chops away all the excess + // > high-order bits, so the first character's actual contribution to + // > `h' is ... precisely zero! The `h' value depends only on the + // > rightmost 32 string characters (assuming a 32-bit int), and even + // > then things are not wonderful: the first of those final 32 bytes + // > influences only the leftmost bit of `h' and has no effect on + // > the remaining 31. Clearly, an even-valued MULT is a poor idea. + // > + // > Need MULT be prime? Not as far as I know (I don't know + // > everything); any odd value ought to suffice. 31 may be attractive + // > because it is close to a power of two, and it may be easier for + // > the compiler to replace a possibly slow multiply instruction with + // > a shift and subtract (31*x == (x << 5) - x) on machines where it + // > makes a difference. Setting MULT one greater than a power of two + // > (e.g., 33) would also be easy to optimize, but might produce too + // > "simple" an arrangement: mostly a juxtaposition of two copies + // > of the original set of bits, with a little mixing in the middle. + // > So you want an odd MULT that has plenty of one-bits. 
+ + code := evalengine.HashCode(17) + for idx, value := range inputRow { + // We use unknown collations when we do not have collation information + // This is safe for types which do not require collation information like + // numeric types. It will fail at runtime for text types. + collation := collations.Unknown + if len(pt.colCollations) > idx { + collation = pt.colCollations[idx] + } + hashcode, err := evalengine.NullsafeHashcode(value, collation, value.Type()) + if err != nil { + return 0, err + } + code = code*31 + hashcode + } + return code, nil +} + +func equal(a, b []sqltypes.Value, colCollations []collations.ID) (bool, error) { for i, aVal := range a { - // TODO(king-11) make collation aware - cmp, err := evalengine.NullsafeCompare(aVal, b[i], collations.Unknown) + collation := collations.Unknown + if len(colCollations) > i { + collation = colCollations[i] + } + cmp, err := evalengine.NullsafeCompare(aVal, b[i], collation) if err != nil { return false, err } @@ -86,8 +139,11 @@ func equal(a, b []sqltypes.Value) (bool, error) { return true, nil } -func newProbeTable() *probeTable { - return &probeTable{m: map[int64][]row{}} +func newProbeTable(colCollations []collations.ID) *probeTable { + return &probeTable{ + seenRows: map[uintptr][]row{}, + colCollations: colCollations, + } } // TryExecute implements the Primitive interface @@ -102,7 +158,7 @@ func (d *Distinct) TryExecute(vcursor VCursor, bindVars map[string]*querypb.Bind InsertID: input.InsertID, } - pt := newProbeTable() + pt := newProbeTable(d.ColCollations) for _, row := range input.Rows { exists, err := pt.exists(row) @@ -119,7 +175,7 @@ func (d *Distinct) TryExecute(vcursor VCursor, bindVars map[string]*querypb.Bind // TryStreamExecute implements the Primitive interface func (d *Distinct) TryStreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - pt := newProbeTable() + pt := newProbeTable(d.ColCollations) err := 
vcursor.StreamExecutePrimitive(d.Source, bindVars, wantfields, func(input *sqltypes.Result) error { result := &sqltypes.Result{ @@ -172,7 +228,26 @@ func (d *Distinct) Inputs() []Primitive { } func (d *Distinct) description() PrimitiveDescription { + var other map[string]interface{} + if d.ColCollations != nil { + allUnknown := true + other = map[string]interface{}{} + var colls []string + for _, collation := range d.ColCollations { + coll := collations.Default().LookupByID(collation) + if coll == nil { + colls = append(colls, "UNKNOWN") + } else { + colls = append(colls, coll.Name()) + allUnknown = false + } + } + if !allUnknown { + other["Collations"] = colls + } + } return PrimitiveDescription{ + Other: other, OperatorType: "Distinct", } } diff --git a/go/vt/vtgate/engine/distinct_test.go b/go/vt/vtgate/engine/distinct_test.go index 3dee6e3c75c..86652b7bf6e 100644 --- a/go/vt/vtgate/engine/distinct_test.go +++ b/go/vt/vtgate/engine/distinct_test.go @@ -21,6 +21,8 @@ import ( "fmt" "testing" + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/test/utils" "github.com/stretchr/testify/require" @@ -32,6 +34,7 @@ func TestDistinct(t *testing.T) { type testCase struct { testName string inputs *sqltypes.Result + collations []collations.ID expectedResult *sqltypes.Result expectedError string } @@ -57,14 +60,27 @@ func TestDistinct(t *testing.T) { inputs: r("a|b", "float64|float64", "0.1|0.2", "0.1|0.3", "0.1|0.4", "0.1|0.5"), expectedResult: r("a|b", "float64|float64", "0.1|0.2", "0.1|0.3", "0.1|0.4", "0.1|0.5"), }, { - testName: "varchar columns", + testName: "varchar columns without collations", inputs: r("myid", "varchar", "monkey", "horse"), - expectedError: "types does not support hashcode yet: VARCHAR", + expectedError: "text type with an unknown/unsupported collation cannot be hashed", + }, { + testName: "varchar columns with collations", + collations: []collations.ID{collations.ID(0x21)}, + inputs: r("myid", "varchar", "monkey", "horse", "Horse", 
"Monkey", "horses", "MONKEY"), + expectedResult: r("myid", "varchar", "monkey", "horse", "horses"), + }, { + testName: "mixed columns", + collations: []collations.ID{collations.ID(0x21), collations.Unknown}, + inputs: r("myid|id", "varchar|int64", "monkey|1", "horse|1", "Horse|1", "Monkey|1", "horses|1", "MONKEY|2"), + expectedResult: r("myid|id", "varchar|int64", "monkey|1", "horse|1", "horses|1", "MONKEY|2"), }} for _, tc := range testCases { t.Run(tc.testName+"-Execute", func(t *testing.T) { - distinct := &Distinct{Source: &fakePrimitive{results: []*sqltypes.Result{tc.inputs}}} + distinct := &Distinct{ + Source: &fakePrimitive{results: []*sqltypes.Result{tc.inputs}}, + ColCollations: tc.collations, + } qr, err := distinct.TryExecute(&noopVCursor{ctx: context.Background()}, nil, true) if tc.expectedError == "" { @@ -77,7 +93,10 @@ func TestDistinct(t *testing.T) { } }) t.Run(tc.testName+"-StreamExecute", func(t *testing.T) { - distinct := &Distinct{Source: &fakePrimitive{results: []*sqltypes.Result{tc.inputs}}} + distinct := &Distinct{ + Source: &fakePrimitive{results: []*sqltypes.Result{tc.inputs}}, + ColCollations: tc.collations, + } result, err := wrapStreamExecute(distinct, &noopVCursor{ctx: context.Background()}, nil, true) diff --git a/go/vt/vtgate/engine/hash_join.go b/go/vt/vtgate/engine/hash_join.go new file mode 100644 index 00000000000..e0f8f7530e9 --- /dev/null +++ b/go/vt/vtgate/engine/hash_join.go @@ -0,0 +1,257 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "fmt" + "strings" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/evalengine" +) + +var _ Primitive = (*HashJoin)(nil) + +// HashJoin specifies the parameters for a join primitive +// Hash joins work by fetching all the input from the LHS, and building a hash map, known as the probe table, for this input. +// The key to the map is the hashcode of the value of the column that we are joining by. +// Then the RHS is fetched, and we can check if the rows from the RHS match any from the LHS. +// When they match by hash code, we double-check that we are not working with a false positive by comparing the values. +type HashJoin struct { + Opcode JoinOpcode + + // Left and Right are the LHS and RHS primitives + // of the Join. They can be any primitive. + Left, Right Primitive `json:",omitempty"` + + // Cols defines which columns from the left + // or right results should be used to build the + // return result. For results coming from the + // left query, the index values go as -1, -2, etc. + // For the right query, they're 1, 2, etc. + // If Cols is {-1, -2, 1, 2}, it means that + // the returned result will be {Left0, Left1, Right0, Right1}. + Cols []int `json:",omitempty"` + + // The keys correspond to the column offset in the inputs where + // the join columns can be found + LHSKey, RHSKey int + + // The join condition.
Used for plan descriptions + ASTPred sqlparser.Expr + + // collation and type are used to hash the incoming values correctly + Collation collations.ID + ComparisonType querypb.Type +} + +// TryExecute implements the Primitive interface +func (hj *HashJoin) TryExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { + lresult, err := vcursor.ExecutePrimitive(hj.Left, bindVars, wantfields) + if err != nil { + return nil, err + } + + // build the probe table from the LHS result + probeTable, err := hj.buildProbeTable(lresult) + if err != nil { + return nil, err + } + + rresult, err := vcursor.ExecutePrimitive(hj.Right, bindVars, wantfields) + if err != nil { + return nil, err + } + + result := &sqltypes.Result{ + Fields: joinFields(lresult.Fields, rresult.Fields, hj.Cols), + } + + for _, currentRHSRow := range rresult.Rows { + joinVal := currentRHSRow[hj.RHSKey] + if joinVal.IsNull() { + continue + } + hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation, hj.ComparisonType) + if err != nil { + return nil, err + } + lftRows := probeTable[hashcode] + for _, currentLHSRow := range lftRows { + lhsVal := currentLHSRow[hj.LHSKey] + // hash codes can give false positives, so we need to check with a real comparison as well + cmp, err := evalengine.NullsafeCompare(joinVal, lhsVal, collations.Unknown) + if err != nil { + return nil, err + } + + if cmp == 0 { + // we have a match! 
+ result.Rows = append(result.Rows, joinRows(currentLHSRow, currentRHSRow, hj.Cols)) + } + } + } + + return result, nil +} + +func (hj *HashJoin) buildProbeTable(lresult *sqltypes.Result) (map[evalengine.HashCode][]row, error) { + probeTable := map[evalengine.HashCode][]row{} + for _, current := range lresult.Rows { + joinVal := current[hj.LHSKey] + if joinVal.IsNull() { + continue + } + hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation, hj.ComparisonType) + if err != nil { + return nil, err + } + probeTable[hashcode] = append(probeTable[hashcode], current) + } + return probeTable, nil +} + +// TryStreamExecute implements the Primitive interface +func (hj *HashJoin) TryStreamExecute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { + // build the probe table from the LHS result + probeTable := map[evalengine.HashCode][]row{} + var lfields []*querypb.Field + err := vcursor.StreamExecutePrimitive(hj.Left, bindVars, wantfields, func(result *sqltypes.Result) error { + if len(lfields) == 0 && len(result.Fields) != 0 { + lfields = result.Fields + } + for _, current := range result.Rows { + joinVal := current[hj.LHSKey] + if joinVal.IsNull() { + continue + } + hashcode, err := evalengine.NullsafeHashcode(joinVal, hj.Collation, hj.ComparisonType) + if err != nil { + return err + } + probeTable[hashcode] = append(probeTable[hashcode], current) + } + return nil + }) + if err != nil { + return err + } + + return vcursor.StreamExecutePrimitive(hj.Right, bindVars, wantfields, func(result *sqltypes.Result) error { + // compare the results coming from the RHS with the probe-table + res := &sqltypes.Result{} + if len(result.Fields) != 0 { + res = &sqltypes.Result{ + Fields: joinFields(lfields, result.Fields, hj.Cols), + } + } + for _, currentRHSRow := range result.Rows { + joinVal := currentRHSRow[hj.RHSKey] + if joinVal.IsNull() { + continue + } + hashcode, err := 
evalengine.NullsafeHashcode(joinVal, hj.Collation, hj.ComparisonType) + if err != nil { + return err + } + lftRows := probeTable[hashcode] + for _, currentLHSRow := range lftRows { + lhsVal := currentLHSRow[hj.LHSKey] + // hash codes can give false positives, so we need to check with a real comparison as well + cmp, err := evalengine.NullsafeCompare(joinVal, lhsVal, hj.Collation) + if err != nil { + return err + } + + if cmp == 0 { + // we have a match! + res.Rows = append(res.Rows, joinRows(currentLHSRow, currentRHSRow, hj.Cols)) + } + } + } + if len(res.Rows) != 0 || len(res.Fields) != 0 { + return callback(res) + } + return nil + }) +} + +// RouteType implements the Primitive interface +func (hj *HashJoin) RouteType() string { + return "HashJoin" +} + +// GetKeyspaceName implements the Primitive interface +func (hj *HashJoin) GetKeyspaceName() string { + if hj.Left.GetKeyspaceName() == hj.Right.GetKeyspaceName() { + return hj.Left.GetKeyspaceName() + } + return hj.Left.GetKeyspaceName() + "_" + hj.Right.GetKeyspaceName() +} + +// GetTableName implements the Primitive interface +func (hj *HashJoin) GetTableName() string { + return hj.Left.GetTableName() + "_" + hj.Right.GetTableName() +} + +// GetFields implements the Primitive interface +func (hj *HashJoin) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + joinVars := make(map[string]*querypb.BindVariable) + lresult, err := hj.Left.GetFields(vcursor, bindVars) + if err != nil { + return nil, err + } + result := &sqltypes.Result{} + rresult, err := hj.Right.GetFields(vcursor, combineVars(bindVars, joinVars)) + if err != nil { + return nil, err + } + result.Fields = joinFields(lresult.Fields, rresult.Fields, hj.Cols) + return result, nil +} + +// NeedsTransaction implements the Primitive interface +func (hj *HashJoin) NeedsTransaction() bool { + return hj.Right.NeedsTransaction() || hj.Left.NeedsTransaction() +} + +// Inputs implements the Primitive interface +func 
(hj *HashJoin) Inputs() []Primitive { + return []Primitive{hj.Left, hj.Right} +} + +// description implements the Primitive interface +func (hj *HashJoin) description() PrimitiveDescription { + other := map[string]interface{}{ + "TableName": hj.GetTableName(), + "JoinColumnIndexes": strings.Trim(strings.Join(strings.Fields(fmt.Sprint(hj.Cols)), ","), "[]"), + "Predicate": sqlparser.String(hj.ASTPred), + "ComparisonType": hj.ComparisonType.String(), + } + coll := collations.Default().LookupByID(hj.Collation) + if coll != nil { + other["Collation"] = coll.Name() + } + return PrimitiveDescription{ + OperatorType: "Join", + Variant: "Hash" + hj.Opcode.String(), + Other: other, + } +} diff --git a/go/vt/vtgate/engine/hash_join_test.go b/go/vt/vtgate/engine/hash_join_test.go new file mode 100644 index 00000000000..28402e53624 --- /dev/null +++ b/go/vt/vtgate/engine/hash_join_test.go @@ -0,0 +1,146 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package engine + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +func TestHashJoinExecuteSameType(t *testing.T) { + leftPrim := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3", + "int64|varchar|varchar", + ), + "1|a|aa", + "2|b|bb", + "3|c|cc", + ), + }, + } + rightPrim := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col4|col5|col6", + "int64|varchar|varchar", + ), + "1|d|dd", + "3|e|ee", + "4|f|ff", + "3|g|gg", + ), + }, + } + + // Normal join + jn := &HashJoin{ + Opcode: InnerJoin, + Left: leftPrim, + Right: rightPrim, + Cols: []int{-1, -2, 1, 2}, + LHSKey: 0, + RHSKey: 0, + } + r, err := jn.TryExecute(&noopVCursor{}, map[string]*querypb.BindVariable{}, true) + require.NoError(t, err) + leftPrim.ExpectLog(t, []string{ + `Execute true`, + }) + rightPrim.ExpectLog(t, []string{ + `Execute true`, + }) + expectResult(t, "jn.Execute", r, sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col4|col5", + "int64|varchar|int64|varchar", + ), + "1|a|1|d", + "3|c|3|e", + "3|c|3|g", + )) +} + +func TestHashJoinExecuteDifferentType(t *testing.T) { + leftPrim := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3", + "int64|varchar|varchar", + ), + "1|a|aa", + "2|b|bb", + "3|c|cc", + "5|c|cc", + ), + }, + } + rightPrim := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col4|col5|col6", + "varchar|varchar|varchar", + ), + "1.00|d|dd", + "3|e|ee", + "2.89|z|zz", + "4|f|ff", + "3|g|gg", + " 5.0toto|g|gg", + "w|ww|www", + ), + }, + } + + // Normal join + jn := &HashJoin{ + Opcode: InnerJoin, + Left: leftPrim, + Right: rightPrim, + Cols: []int{-1, -2, 1, 2}, + LHSKey: 0, + RHSKey: 0, + ComparisonType: 
querypb.Type_FLOAT64, + } + r, err := jn.TryExecute(&noopVCursor{}, map[string]*querypb.BindVariable{}, true) + require.NoError(t, err) + leftPrim.ExpectLog(t, []string{ + `Execute true`, + }) + rightPrim.ExpectLog(t, []string{ + `Execute true`, + }) + expectResult(t, "jn.Execute", r, sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col4|col5", + "int64|varchar|varchar|varchar", + ), + "1|a|1.00|d", + "3|c|3|e", + "3|c|3|g", + "5|c| 5.0toto|g", + )) +} diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index 2424c271cd9..6969a3e5c9a 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -20,6 +20,8 @@ import ( "fmt" "strings" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -46,6 +48,9 @@ type Join struct { // be built from the LHS result before invoking // the RHS subqquery. Vars map[string]int `json:",omitempty"` + + // ASTPred is the original join condition. + ASTPred sqlparser.Expr } // TryExecute performs a non-streaming exec. 
@@ -262,6 +267,9 @@ func (jn *Join) description() PrimitiveDescription { "TableName": jn.GetTableName(), "JoinColumnIndexes": strings.Trim(strings.Join(strings.Fields(fmt.Sprint(jn.Cols)), ","), "[]"), } + if jn.ASTPred != nil { + other["Predicate"] = sqlparser.String(jn.ASTPred) + } if len(jn.Vars) > 0 { other["JoinVars"] = orderedStringIntMap(jn.Vars) } diff --git a/go/vt/vtgate/evalengine/arithmetic.go b/go/vt/vtgate/evalengine/arithmetic.go index 90e6379d96e..95ccd41571a 100644 --- a/go/vt/vtgate/evalengine/arithmetic.go +++ b/go/vt/vtgate/evalengine/arithmetic.go @@ -20,12 +20,13 @@ import ( "bytes" "fmt" "math" + "strconv" + "strings" + "vitess.io/vitess/go/hack" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" - "strconv" - querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -215,21 +216,28 @@ func NullsafeCompare(v1, v2 sqltypes.Value, collationID collations.ID) (int, err if v2.IsNull() { return 1, nil } - if sqltypes.IsNumber(v1.Type()) || sqltypes.IsNumber(v2.Type()) { - lv1, err := newEvalResult(v1) - if err != nil { - return 0, err - } - lv2, err := newEvalResult(v2) - if err != nil { - return 0, err - } - return compareNumeric(lv1, lv2) - } - if isByteComparable(v1) && isByteComparable(v2) { + + if isByteComparable(v1.Type()) && isByteComparable(v2.Type()) { return bytes.Compare(v1.ToBytes(), v2.ToBytes()), nil } - if v1.IsText() && v2.IsText() && collationID != collations.Unknown { + + typ, err := CoerceTo(v1.Type(), v2.Type()) // TODO systay we should add a method where this decision is done at plantime + if err != nil { + return 0, err + } + v1cast, err := castTo(v1, typ) + if err != nil { + return 0, err + } + v2cast, err := castTo(v2, typ) + if err != nil { + return 0, err + } + + if sqltypes.IsNumber(typ) { + return compareNumeric(v1cast, v2cast) + } + if (sqltypes.IsText(typ) || sqltypes.IsBinary(typ)) && collationID != collations.Unknown { collation 
:= collations.Default().LookupByID(collationID) if collation == nil { return 0, UnsupportedCollationError{ @@ -251,31 +259,159 @@ func NullsafeCompare(v1, v2 sqltypes.Value, collationID collations.ID) (int, err } } +// HashCode is a type alias to make the code easier to read +type HashCode = uintptr + // NullsafeHashcode returns an int64 hashcode that is guaranteed to be the same // for two values that are considered equal by `NullsafeCompare`. -// TODO: should be extended to support all possible types -func NullsafeHashcode(v sqltypes.Value) (int64, error) { - if v.IsNull() { - return math.MaxInt64, nil - } - - if sqltypes.IsNumber(v.Type()) { - result, err := newEvalResult(v) +func NullsafeHashcode(v sqltypes.Value, collation collations.ID, coerceType querypb.Type) (HashCode, error) { + castValue, err := castTo(v, coerceType) + if err != nil { + return 0, err + } + switch { + case sqltypes.IsNull(castValue.typ): + return HashCode(math.MaxInt64), nil + case sqltypes.IsNumber(castValue.typ): + return numericalHashCode(castValue), nil + case sqltypes.IsText(castValue.typ): + coll := collations.Default().LookupByID(collation) + if coll == nil { + return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "text type with an unknown/unsupported collation cannot be hashed") + } + return coll.Hash(castValue.bytes, 0), nil + case sqltypes.IsDate(castValue.typ): + time, err := parseDate(castValue) if err != nil { return 0, err } - return hashCode(result), nil + return uintptr(time.UnixNano()), nil + } + return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v", castValue.typ) +} + +func castTo(v sqltypes.Value, typ querypb.Type) (EvalResult, error) { + switch { + case typ == sqltypes.Null: + return EvalResult{}, nil + case sqltypes.IsFloat(typ) || typ == sqltypes.Decimal: + switch { + case v.IsSigned(): + ival, err := strconv.ParseInt(v.RawStr(), 10, 64) + if err != nil { + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) + } +
return EvalResult{fval: float64(ival), typ: sqltypes.Float64}, nil + case v.IsUnsigned(): + uval, err := strconv.ParseUint(v.RawStr(), 10, 64) + if err != nil { + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) + } + return EvalResult{fval: float64(uval), typ: sqltypes.Float64}, nil + case v.IsFloat() || v.Type() == sqltypes.Decimal: + fval, err := strconv.ParseFloat(v.RawStr(), 64) + if err != nil { + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) + } + return EvalResult{fval: fval, typ: sqltypes.Float64}, nil + case v.IsText() || v.IsBinary(): + fval := parseStringToFloat(v.RawStr()) + return EvalResult{fval: fval, typ: sqltypes.Float64}, nil + default: + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value to a float: %v", v) + } + + case sqltypes.IsSigned(typ): + switch { + case v.IsSigned(): + ival, err := strconv.ParseInt(v.RawStr(), 10, 64) + if err != nil { + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) + } + return EvalResult{ival: ival, typ: sqltypes.Int64}, nil + case v.IsUnsigned(): + uval, err := strconv.ParseUint(v.RawStr(), 10, 64) + if err != nil { + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) + } + return EvalResult{ival: int64(uval), typ: sqltypes.Int64}, nil + default: + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value to a signed int: %v", v) + } + + case sqltypes.IsUnsigned(typ): + switch { + case v.IsSigned(): + uval, err := strconv.ParseInt(v.RawStr(), 10, 64) + if err != nil { + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) + } + return EvalResult{uval: uint64(uval), typ: sqltypes.Uint64}, nil + case v.IsUnsigned(): + uval, err := strconv.ParseUint(v.RawStr(), 10, 64) + if err != nil { + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "%v", err) + } + return EvalResult{uval: uval, typ: 
sqltypes.Uint64}, nil + default: + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value to a unsigned int: %v", v) + } + + case sqltypes.IsText(typ) || sqltypes.IsBinary(typ): + switch { + case v.IsText() || v.IsBinary(): + return EvalResult{bytes: v.Raw(), typ: v.Type()}, nil + default: + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value to a text: %v", v) + } } + return EvalResult{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value: %v", v) +} - return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v", v.Type()) +// CoerceTo takes two input types, and decides how they should be coerced before compared +func CoerceTo(v1, v2 querypb.Type) (querypb.Type, error) { + if v1 == v2 { + return v1, nil + } + if sqltypes.IsNull(v1) || sqltypes.IsNull(v2) { + return sqltypes.Null, nil + } + if (sqltypes.IsText(v1) || sqltypes.IsBinary(v1)) && (sqltypes.IsText(v2) || sqltypes.IsBinary(v2)) { + return sqltypes.VarChar, nil + } + if sqltypes.IsNumber(v1) || sqltypes.IsNumber(v2) { + switch { + case sqltypes.IsText(v1) || sqltypes.IsBinary(v1) || sqltypes.IsText(v2) || sqltypes.IsBinary(v2): + return sqltypes.Float64, nil + case sqltypes.IsFloat(v2) || v2 == sqltypes.Decimal || sqltypes.IsFloat(v1) || v1 == sqltypes.Decimal: + return sqltypes.Float64, nil + case sqltypes.IsSigned(v1): + switch { + case sqltypes.IsUnsigned(v2): + return sqltypes.Uint64, nil + case sqltypes.IsSigned(v2): + return sqltypes.Int64, nil + default: + return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v vs %v", v1, v2) + } + case sqltypes.IsUnsigned(v1): + switch { + case sqltypes.IsSigned(v2) || sqltypes.IsUnsigned(v2): + return sqltypes.Uint64, nil + default: + return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v vs %v", v1, v2) + } + } + 
} + return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "types does not support hashcode yet: %v vs %v", v1, v2) } // isByteComparable returns true if the type is binary or date/time. -func isByteComparable(v sqltypes.Value) bool { - if v.IsBinary() { +func isByteComparable(typ querypb.Type) bool { + if sqltypes.IsBinary(typ) { return true } - switch v.Type() { + switch typ { case sqltypes.Timestamp, sqltypes.Date, sqltypes.Time, sqltypes.Datetime, sqltypes.Enum, sqltypes.Set, sqltypes.TypeJSON, sqltypes.Bit: return true } @@ -630,3 +766,16 @@ func anyMinusFloat(v1 EvalResult, v2 float64) EvalResult { } return EvalResult{typ: sqltypes.Float64, fval: v1.fval - v2} } + +func parseStringToFloat(str string) float64 { + str = strings.TrimSpace(str) + + // We only care to parse as many of the initial float characters of the + // string as possible. This functionality is implemented in the `strconv` package + // of the standard library, but not exposed, so we hook into it. + val, _, err := hack.ParseFloatPrefix(str, 64) + if err != nil { + return 0.0 + } + return val +} diff --git a/go/vt/vtgate/evalengine/arithmetic_test.go b/go/vt/vtgate/evalengine/arithmetic_test.go index f8cc3160e1c..36032be11ae 100644 --- a/go/vt/vtgate/evalengine/arithmetic_test.go +++ b/go/vt/vtgate/evalengine/arithmetic_test.go @@ -511,6 +511,7 @@ func TestNullSafeAdd(t *testing.T) { } func TestNullsafeCompare(t *testing.T) { + collation := collations.Default().LookupByName("utf8mb4_general_ci").ID() tcases := []struct { v1, v2 sqltypes.Value out int @@ -534,7 +535,7 @@ func TestNullsafeCompare(t *testing.T) { // LHS Text v1: TestValue(querypb.Type_VARCHAR, "abcd"), v2: TestValue(querypb.Type_VARCHAR, "abcd"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, "types are not comparable: VARCHAR vs VARCHAR"), + out: 0, }, { // Make sure underlying error is returned for LHS. 
v1: TestValue(querypb.Type_INT64, "1.2"), @@ -597,9 +598,9 @@ func TestNullsafeCompare(t *testing.T) { out: -1, }} for _, tcase := range tcases { - got, err := NullsafeCompare(tcase.v1, tcase.v2, collations.Unknown) - if !vterrors.Equals(err, tcase.err) { - t.Errorf("NullsafeCompare(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) + got, err := NullsafeCompare(tcase.v1, tcase.v2, collation) + if tcase.err != nil { + require.EqualError(t, err, tcase.err.Error()) } if tcase.err != nil { continue @@ -1326,9 +1327,12 @@ func TestCompareNumeric(t *testing.T) { // if two values are considered equal, they must also produce the same hashcode if result == 0 { - aHash := hashCode(aVal) - bHash := hashCode(bVal) - assert.Equal(t, aHash, bHash, "hash code does not match") + if aVal.typ == bVal.typ { + // hash codes can only be compared if they are coerced to the same type first + aHash := numericalHashCode(aVal) + bHash := numericalHashCode(bVal) + assert.Equal(t, aHash, bHash, "hash code does not match") + } } }) } @@ -1549,25 +1553,6 @@ func TestMaxCollate(t *testing.T) { } } -func TestHashCodes(t *testing.T) { - n1 := sqltypes.NULL - n2 := sqltypes.Value{} - - h1, err := NullsafeHashcode(n1) - require.NoError(t, err) - h2, err := NullsafeHashcode(n2) - require.NoError(t, err) - assert.Equal(t, h1, h2) - - char := TestValue(querypb.Type_VARCHAR, "aa") - _, err = NullsafeHashcode(char) - require.Error(t, err) - - num := TestValue(querypb.Type_INT64, "123") - _, err = NullsafeHashcode(num) - require.NoError(t, err) -} - func printValue(v sqltypes.Value) string { return fmt.Sprintf("%v:%q", v.Type(), v.ToBytes()) } @@ -1653,3 +1638,36 @@ func BenchmarkAddGo(b *testing.B) { v1 += v2 } } + +func TestParseStringToFloat(t *testing.T) { + tcs := []struct { + str string + val float64 + }{ + {str: ""}, + {str: " "}, + {str: "1", val: 1}, + {str: "1.10", val: 1.10}, + {str: " 6.87", val: 6.87}, + {str: "93.66 ", 
val: 93.66}, + {str: "\t 42.10 \n ", val: 42.10}, + {str: "1.10aa", val: 1.10}, + {str: ".", val: 0.00}, + {str: ".99", val: 0.99}, + {str: "..99", val: 0}, + {str: "1.", val: 1}, + {str: "0.1.99", val: 0.1}, + {str: "0.", val: 0}, + {str: "8794354", val: 8794354}, + {str: " 10 ", val: 10}, + {str: "2266951196291479516", val: 2266951196291479516}, + {str: "abcd123", val: 0}, + } + + for _, tc := range tcs { + t.Run(tc.str, func(t *testing.T) { + got := parseStringToFloat(tc.str) + require.EqualValues(t, tc.val, got) + }) + } +} diff --git a/go/vt/vtgate/evalengine/evalengine.go b/go/vt/vtgate/evalengine/evalengine.go index 6e51467c953..aa4e5735ab9 100644 --- a/go/vt/vtgate/evalengine/evalengine.go +++ b/go/vt/vtgate/evalengine/evalengine.go @@ -17,6 +17,7 @@ limitations under the License. package evalengine import ( + "math" "time" "vitess.io/vitess/go/mysql/collations" @@ -243,21 +244,16 @@ func (v EvalResult) toSQLValue(resultType querypb.Type) sqltypes.Value { return sqltypes.NULL } -func hashCode(v EvalResult) int64 { - // we cast all numerical values to float64 and return the hashcode for that - var val float64 - switch v.typ { - case sqltypes.Int64: - val = float64(v.ival) - case sqltypes.Uint64: - val = float64(v.uval) - case sqltypes.Float64: - val = v.fval +func numericalHashCode(v EvalResult) HashCode { + switch { + case sqltypes.IsSigned(v.typ): + return HashCode(v.ival) + case sqltypes.IsUnsigned(v.typ): + return HashCode(v.uval) + case sqltypes.IsFloat(v.typ) || v.typ == sqltypes.Decimal: + return HashCode(math.Float64bits(v.fval)) } - - // this will not work for ±0, NaN and ±Inf, - // so one must still check using `compareNumeric` which will not be fooled - return int64(val) + panic("BUG: this is not a numerical value") } func compareNumeric(v1, v2 EvalResult) (int, error) { diff --git a/go/vt/vtgate/evalengine/hash_code_test.go b/go/vt/vtgate/evalengine/hash_code_test.go new file mode 100644 index 00000000000..21df5bffba7 --- /dev/null +++ 
b/go/vt/vtgate/evalengine/hash_code_test.go @@ -0,0 +1,115 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "fmt" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" +) + +// The following test tries to produce lots of different values and compares them both using hash code and compare, +// to make sure that these two methods agree on what values are equal +func TestHashCodesRandom(t *testing.T) { + tested := 0 + equal := 0 + collation := collations.Default().LookupByName("utf8mb4_general_ci").ID() + endTime := time.Now().Add(1 * time.Second) + for time.Now().Before(endTime) { + t.Run(fmt.Sprintf("test %d", tested), func(t *testing.T) { + tested++ + v1, v2 := randomValues() + cmp, err := NullsafeCompare(v1, v2, collation) + require.NoErrorf(t, err, "%s compared with %s", v1.String(), v2.String()) + typ, err := CoerceTo(v1.Type(), v2.Type()) + require.NoError(t, err) + + hash1, err := NullsafeHashcode(v1, collation, typ) + require.NoError(t, err) + hash2, err := NullsafeHashcode(v2, collation, typ) + require.NoError(t, err) + if cmp == 0 { + equal++ + require.Equalf(t, hash1, hash2, "values %s and %s are considered equal but produce different hash codes: %d & %d", v1.String(), v2.String(), hash1, hash2) + } + }) + } + fmt.Printf("tested %d values, with %d equalities found\n", tested, equal) +} + +func 
randomValues() (sqltypes.Value, sqltypes.Value) { + if rand.Int()%2 == 0 { + // create a single value, and turn it into two different types + v := rand.Int() + return randomNumericType(v), randomNumericType(v) + } + + // just produce two arbitrary random values and compare + return randomValue(), randomValue() +} + +func randomNumericType(i int) sqltypes.Value { + r := rand.Intn(len(numericTypes)) + return numericTypes[r](i) + +} + +var numericTypes = []func(int) sqltypes.Value{ + func(i int) sqltypes.Value { return sqltypes.NULL }, + func(i int) sqltypes.Value { return sqltypes.NewInt8(int8(i)) }, + func(i int) sqltypes.Value { return sqltypes.NewInt32(int32(i)) }, + func(i int) sqltypes.Value { return sqltypes.NewInt64(int64(i)) }, + func(i int) sqltypes.Value { return sqltypes.NewUint64(uint64(i)) }, + func(i int) sqltypes.Value { return sqltypes.NewUint32(uint32(i)) }, + func(i int) sqltypes.Value { return sqltypes.NewFloat64(float64(i)) }, + func(i int) sqltypes.Value { return sqltypes.NewDecimal(fmt.Sprintf("%d", i)) }, + func(i int) sqltypes.Value { return sqltypes.NewVarChar(fmt.Sprintf("%d", i)) }, + func(i int) sqltypes.Value { return sqltypes.NewVarChar(fmt.Sprintf(" %f aa", float64(i))) }, +} + +var randomGenerators = []func() sqltypes.Value{ + randomNull, + randomInt8, + randomInt32, + randomInt64, + randomUint64, + randomUint32, + randomVarChar, + randomComplexVarChar, +} + +func randomValue() sqltypes.Value { + r := rand.Intn(len(randomGenerators)) + return randomGenerators[r]() +} + +func randomNull() sqltypes.Value { return sqltypes.NULL } +func randomInt8() sqltypes.Value { return sqltypes.NewInt8(int8(rand.Intn(255))) } +func randomInt32() sqltypes.Value { return sqltypes.NewInt32(rand.Int31()) } +func randomInt64() sqltypes.Value { return sqltypes.NewInt64(rand.Int63()) } +func randomUint32() sqltypes.Value { return sqltypes.NewUint32(rand.Uint32()) } +func randomUint64() sqltypes.Value { return sqltypes.NewUint64(rand.Uint64()) } +func 
randomVarChar() sqltypes.Value { return sqltypes.NewVarChar(fmt.Sprintf("%d", rand.Int63())) } +func randomComplexVarChar() sqltypes.Value { + return sqltypes.NewVarChar(fmt.Sprintf(" \t %f apa", float64(rand.Intn(1000))*1.10)) +} diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index bc63177d887..0fa9a6b0ee6 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -589,7 +589,7 @@ func TestSelectLastInsertIdInUnion(t *testing.T) { sql := "select last_insert_id() as id union select id from user" _, err := executorExec(executor, sql, map[string]*querypb.BindVariable{}) require.Error(t, err) - assert.Contains(t, err.Error(), "types does not support hashcode yet: VARCHAR") + assert.Contains(t, err.Error(), "text type with an unknown/unsupported collation cannot be hashed") } func TestSelectLastInsertIdInWhere(t *testing.T) { diff --git a/go/vt/vtgate/planbuilder/collations_test.go b/go/vt/vtgate/planbuilder/collations_test.go index 5804d76815e..ea3b0d6fe00 100644 --- a/go/vt/vtgate/planbuilder/collations_test.go +++ b/go/vt/vtgate/planbuilder/collations_test.go @@ -29,7 +29,7 @@ import ( // collationInTable allows us to set a collation on a column type collationInTable struct { ks, table, collationName string - offsetInTable int + colName string } type collationTestCase struct { @@ -40,7 +40,7 @@ type collationTestCase struct { func (tc *collationTestCase) run(t *testing.T) { vschemaWrapper := &vschemaWrapper{ - v: loadSchema(t, "schema_test.json"), + v: loadSchema(t, "schema_test.json", false), sysVarEnabled: true, version: Gen4, } @@ -53,7 +53,13 @@ func (tc *collationTestCase) run(t *testing.T) { func (tc *collationTestCase) addCollationsToSchema(vschema *vschemaWrapper) { for _, collation := range tc.collations { - vschema.v.Keyspaces[collation.ks].Tables[collation.table].Columns[collation.offsetInTable].CollationName = collation.collationName + tbl := 
vschema.v.Keyspaces[collation.ks].Tables[collation.table] + for i, c := range tbl.Columns { + if c.Name.EqualString(collation.colName) { + tbl.Columns[i].CollationName = collation.collationName + break + } + } } } @@ -63,7 +69,7 @@ func TestOrderedAggregateCollations(t *testing.T) { } testCases := []collationTestCase{ { - collations: []collationInTable{{ks: "user", table: "user", collationName: "utf8mb4_bin", offsetInTable: 2}}, + collations: []collationInTable{{ks: "user", table: "user", collationName: "utf8mb4_bin", colName: "textcol1"}}, query: "select textcol1 from user group by textcol1", check: func(t *testing.T, colls []collationInTable, primitive engine.Primitive) { oa, isOA := primitive.(*engine.OrderedAggregate) @@ -72,7 +78,7 @@ func TestOrderedAggregateCollations(t *testing.T) { }, }, { - collations: []collationInTable{{ks: "user", table: "user", collationName: "utf8mb4_bin", offsetInTable: 2}}, + collations: []collationInTable{{ks: "user", table: "user", collationName: "utf8mb4_bin", colName: "textcol1"}}, query: "select distinct textcol1 from user", check: func(t *testing.T, colls []collationInTable, primitive engine.Primitive) { oa, isOA := primitive.(*engine.OrderedAggregate) @@ -82,8 +88,8 @@ func TestOrderedAggregateCollations(t *testing.T) { }, { collations: []collationInTable{ - {ks: "user", table: "user", collationName: "utf8mb4_bin", offsetInTable: 2}, - {ks: "user", table: "user", collationName: "utf8mb4_bin", offsetInTable: 4}, + {ks: "user", table: "user", collationName: "utf8mb4_bin", colName: "textcol1"}, + {ks: "user", table: "user", collationName: "utf8mb4_bin", colName: "textcol2"}, }, query: "select textcol1, textcol2 from user group by textcol1, textcol2", check: func(t *testing.T, colls []collationInTable, primitive engine.Primitive) { @@ -95,7 +101,7 @@ func TestOrderedAggregateCollations(t *testing.T) { }, { collations: []collationInTable{ - {ks: "user", table: "user", collationName: "utf8mb4_bin", offsetInTable: 4}, + {ks: 
"user", table: "user", collationName: "utf8mb4_bin", colName: "textcol2"}, }, query: "select count(*), textcol2 from user group by textcol2", check: func(t *testing.T, colls []collationInTable, primitive engine.Primitive) { @@ -106,7 +112,7 @@ func TestOrderedAggregateCollations(t *testing.T) { }, { collations: []collationInTable{ - {ks: "user", table: "user", collationName: "utf8mb4_bin", offsetInTable: 4}, + {ks: "user", table: "user", collationName: "utf8mb4_bin", colName: "textcol2"}, }, query: "select count(*) as c, textcol2 from user group by textcol2 order by c", check: func(t *testing.T, colls []collationInTable, primitive engine.Primitive) { diff --git a/go/vt/vtgate/planbuilder/derivedtree.go b/go/vt/vtgate/planbuilder/derivedtree.go index 802815a1f04..775357c2115 100644 --- a/go/vt/vtgate/planbuilder/derivedtree.go +++ b/go/vt/vtgate/planbuilder/derivedtree.go @@ -41,7 +41,7 @@ func (d *derivedTree) tableID() semantics.TableSet { } func (d *derivedTree) cost() int { - panic("implement me") + return d.inner.cost() } func (d *derivedTree) clone() queryTree { @@ -79,7 +79,15 @@ func (d *derivedTree) pushPredicate(ctx *planningContext, expr sqlparser.Expr) e } func (d *derivedTree) removePredicate(ctx *planningContext, expr sqlparser.Expr) error { - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "remove '%s' predicate not supported on derived trees", sqlparser.String(expr)) + tableInfo, err := ctx.semTable.TableInfoForExpr(expr) + if err != nil { + return err + } + rewrittenExpr, err := semantics.RewriteDerivedExpression(expr, tableInfo) + if err != nil { + return err + } + return d.inner.removePredicate(ctx, rewrittenExpr) } // findOutputColumn returns the index on which the given name is found in the slice of diff --git a/go/vt/vtgate/planbuilder/distinct.go b/go/vt/vtgate/planbuilder/distinct.go index 666d39ac2ee..40cb3d6030a 100644 --- a/go/vt/vtgate/planbuilder/distinct.go +++ b/go/vt/vtgate/planbuilder/distinct.go @@ -17,6 +17,7 @@ limitations under the 
License. package planbuilder import ( + "vitess.io/vitess/go/mysql/collations" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" @@ -27,17 +28,20 @@ var _ logicalPlan = (*distinct)(nil) // distinct is the logicalPlan for engine.Distinct. type distinct struct { logicalPlanCommon + ColCollations []collations.ID } -func newDistinct(source logicalPlan) logicalPlan { +func newDistinct(source logicalPlan, colCollations []collations.ID) logicalPlan { return &distinct{ logicalPlanCommon: newBuilderCommon(source), + ColCollations: colCollations, } } func (d *distinct) Primitive() engine.Primitive { return &engine.Distinct{ - Source: d.input.Primitive(), + Source: d.input.Primitive(), + ColCollations: d.ColCollations, } } diff --git a/go/vt/vtgate/planbuilder/gen4_planner_test.go b/go/vt/vtgate/planbuilder/gen4_planner_test.go index 35d9bf46262..c5321792c9f 100644 --- a/go/vt/vtgate/planbuilder/gen4_planner_test.go +++ b/go/vt/vtgate/planbuilder/gen4_planner_test.go @@ -124,6 +124,9 @@ func TestOptimizeQuery(t *testing.T) { }, { query: "select id from user_extra join unsharded where 4 < user_extra.id and unsharded.col = user_extra.id", result: `Join: { + JoinVars: map[user_extra_id:0] + Columns: [] + PredicatesToRemove: unsharded.col = :user_extra_id LHS: RouteTree{ Opcode: SelectScatter, Tables: user_extra, @@ -138,8 +141,6 @@ func TestOptimizeQuery(t *testing.T) { ColNames: , LeftJoins: } - JoinVars: map[user_extra_id:0] - Columns: [] }`, }, { query: "select t.x from (select id from user_extra) t(x) where t.x = 4", @@ -219,12 +220,45 @@ func TestOptimizeQuery(t *testing.T) { LeftJoins: } ExtractedSubQuery::__sq1 +}`, + }, { + query: "select u1.id from music u1 join music u2 on u2.col = u1.col join music u3 where u3.col = u1.col", + result: `Join: { + JoinVars: map[u3_col:0] + Columns: [] + PredicatesToRemove: :u3_col = u1.col + LHS: RouteTree{ + Opcode: SelectScatter, + Tables: music, + Predicates: , + 
ColNames: u3.col, + LeftJoins: + } + RHS: Join: { + JoinVars: map[u1_col:0] + Columns: [] + PredicatesToRemove: u2.col = :u1_col + LHS: RouteTree{ + Opcode: SelectScatter, + Tables: music, + Predicates: :u3_col = u1.col, + ColNames: u1.col, + LeftJoins: + } + RHS: RouteTree{ + Opcode: SelectScatter, + Tables: music, + Predicates: u2.col = :u1_col, + ColNames: , + LeftJoins: + } + } }`, }, } vschema := &vschemaWrapper{ - v: loadSchema(t, "schema_test.json"), + v: loadSchema(t, "schema_test.json", true), } for _, testcase := range testcases { @@ -270,10 +304,18 @@ func getQueryTreeString(tree queryTree) string { case *joinTree: leftStr := indent(getQueryTreeString(tree.lhs)) rightStr := indent(getQueryTreeString(tree.rhs)) + joinType := "Join" if tree.leftJoin { - return fmt.Sprintf("OuterJoin: {\n\tInner: %s\n\tOuter: %s\n\tJoinVars: %v\n\tColumns: %v\n}", leftStr, rightStr, tree.vars, tree.columns) + joinType = "OuterJoin" } - return fmt.Sprintf("Join: {\n\tLHS: %s\n\tRHS: %s\n\tJoinVars: %v\n\tColumns: %v\n}", leftStr, rightStr, tree.vars, tree.columns) + expressions := sqlparser.String(sqlparser.AndExpressions(tree.predicatesToRemoveFromHashJoin...)) + return fmt.Sprintf(`%s: { + JoinVars: %v + Columns: %v + PredicatesToRemove: %v + LHS: %s + RHS: %s +}`, joinType, tree.vars, tree.columns, expressions, leftStr, rightStr) case *derivedTree: inner := indent(getQueryTreeString(tree.inner)) return fmt.Sprintf("Derived %s: {\n\tInner:%s\n\tColumnAliases:%s\n\tColumns:%s\n}", tree.alias, inner, sqlparser.String(tree.columnAliases), getColmnsString(tree.columns)) diff --git a/go/vt/vtgate/planbuilder/grouping.go b/go/vt/vtgate/planbuilder/grouping.go index e5c7dd7b3ea..84335ebc611 100644 --- a/go/vt/vtgate/planbuilder/grouping.go +++ b/go/vt/vtgate/planbuilder/grouping.go @@ -108,7 +108,7 @@ func planDistinct(input logicalPlan) (logicalPlan, error) { // So, the distinct 'operator' cannot be pushed down into the // route. 
if rc.column.Origin() == node { - return newDistinct(node), nil + return newDistinct(node, nil), nil } node.eaggr.GroupByKeys = append(node.eaggr.GroupByKeys, &engine.GroupByParams{KeyCol: i, WeightStringCol: -1, FromGroupBy: false}) } diff --git a/go/vt/vtgate/planbuilder/hash_join.go b/go/vt/vtgate/planbuilder/hash_join.go new file mode 100644 index 00000000000..c1e436b0af8 --- /dev/null +++ b/go/vt/vtgate/planbuilder/hash_join.go @@ -0,0 +1,95 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package planbuilder + +import ( + "vitess.io/vitess/go/mysql/collations" + querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +var _ logicalPlan = (*hashJoin)(nil) + +// hashJoin is used to build a HashJoin primitive. +type hashJoin struct { + gen4Plan + + // Left and Right are the nodes for the join. 
+ Left, Right logicalPlan + + Opcode engine.JoinOpcode + + Cols []int + + // The keys correspond to the column offset in the inputs where + // the join columns can be found + LHSKey, RHSKey int + + Predicate sqlparser.Expr + + ComparisonType querypb.Type + + Collation collations.ID +} + +// WireupGen4 implements the logicalPlan interface +func (hj *hashJoin) WireupGen4(semTable *semantics.SemTable) error { + err := hj.Left.WireupGen4(semTable) + if err != nil { + return err + } + return hj.Right.WireupGen4(semTable) +} + +// Primitive implements the logicalPlan interface +func (hj *hashJoin) Primitive() engine.Primitive { + return &engine.HashJoin{ + Left: hj.Left.Primitive(), + Right: hj.Right.Primitive(), + Cols: hj.Cols, + Opcode: hj.Opcode, + LHSKey: hj.LHSKey, + RHSKey: hj.RHSKey, + ASTPred: hj.Predicate, + ComparisonType: hj.ComparisonType, + Collation: hj.Collation, + } +} + +// Inputs implements the logicalPlan interface +func (hj *hashJoin) Inputs() []logicalPlan { + return []logicalPlan{hj.Left, hj.Right} +} + +// Rewrite implements the logicalPlan interface +func (hj *hashJoin) Rewrite(inputs ...logicalPlan) error { + if len(inputs) != 2 { + return vterrors.New(vtrpcpb.Code_INTERNAL, "wrong number of children") + } + hj.Left = inputs[0] + hj.Right = inputs[1] + return nil +} + +// ContainsTables implements the logicalPlan interface +func (hj *hashJoin) ContainsTables() semantics.TableSet { + return hj.Left.ContainsTables().Merge(hj.Right.ContainsTables()) +} diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 7a4b43e063e..11e8989844f 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -140,7 +140,7 @@ func (hp *horizonPlanning) truncateColumnsIfNeeded(plan logicalPlan) error { switch p := plan.(type) { case *route: p.eroute.SetTruncateColumnCount(hp.sel.GetColumnCount()) - case *joinGen4, *semiJoin: + case *joinGen4, *semiJoin, 
*hashJoin: // since this is a join, we can safely add extra columns and not need to truncate them case *orderedAggregate: p.eaggr.SetTruncateColumnCount(hp.sel.GetColumnCount()) @@ -192,6 +192,51 @@ func pushProjection(expr *sqlparser.AliasedExpr, plan logicalPlan, semTable *sem offset := len(sel.SelectExprs) sel.SelectExprs = append(sel.SelectExprs, expr) return offset, true, nil + case *hashJoin: + lhsSolves := node.Left.ContainsTables() + rhsSolves := node.Right.ContainsTables() + deps := semTable.RecursiveDeps(expr.Expr) + var column int + var appended bool + passDownReuseCol := reuseCol + if !reuseCol { + passDownReuseCol = expr.As.IsEmpty() + } + switch { + case deps.IsSolvedBy(lhsSolves): + offset, added, err := pushProjection(expr, node.Left, semTable, inner, passDownReuseCol, hasAggregation) + if err != nil { + return 0, false, err + } + column = -(offset + 1) + appended = added + case deps.IsSolvedBy(rhsSolves): + offset, added, err := pushProjection(expr, node.Right, semTable, inner && node.Opcode != engine.LeftJoin, passDownReuseCol, hasAggregation) + if err != nil { + return 0, false, err + } + column = offset + 1 + appended = added + default: + // if an expression has aggregation, then it should not be split up and pushed to both sides, + // for example an expression like count(*) will have dependencies on both sides, but we should not push it + // instead we should return an error + if hasAggregation { + return 0, false, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: cross-shard query with aggregates") + } + return 0, false, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: hash join with projection from both sides of the join") + } + if reuseCol && !appended { + for idx, col := range node.Cols { + if column == col { + return idx, false, nil + } + } + // the column was not appended to either child, but we could not find it in out cols list, + // so we'll still add it + } + node.Cols = append(node.Cols, column) + return len(node.Cols) - 
1, true, nil case *joinGen4: lhsSolves := node.Left.ContainsTables() rhsSolves := node.Right.ContainsTables() @@ -399,7 +444,7 @@ func (hp *horizonPlanning) planAggregations(ctx *planningContext, plan logicalPl newPlan := plan var oa *orderedAggregate uniqVindex := hasUniqueVindex(ctx.vschema, ctx.semTable, hp.qp.GroupByExprs) - _, joinPlan := plan.(*joinGen4) + joinPlan := isJoin(plan) if !uniqVindex || joinPlan { if hp.qp.ProjectionError != nil { return nil, hp.qp.ProjectionError @@ -576,7 +621,7 @@ func planGroupByGen4(groupExpr abstract.GroupBy, plan logicalPlan, semTable *sem sel.GroupBy = append(sel.GroupBy, weightStringFor(groupExpr.WeightStrExpr)) } return false, nil - case *joinGen4: + case *joinGen4, *hashJoin: _, _, added, err := wrapAndPushExpr(groupExpr.Inner, groupExpr.WeightStrExpr, node, semTable) return added, err case *orderedAggregate: @@ -644,6 +689,13 @@ func (hp *horizonPlanning) planOrderBy(ctx *planningContext, orderExprs []abstra return nil, err } + return newPlan, nil + case *hashJoin: + newPlan, err := hp.planOrderByForHashJoin(ctx, orderExprs, plan) + if err != nil { + return nil, err + } + return newPlan, nil case *orderedAggregate: // remove ORDER BY NULL from the list of order by expressions since we will be doing the ordering on vtgate level so NULL is not useful @@ -798,6 +850,30 @@ func weightStringFor(expr sqlparser.Expr) sqlparser.Expr { } +func (hp *horizonPlanning) planOrderByForHashJoin(ctx *planningContext, orderExprs []abstract.OrderBy, plan *hashJoin) (logicalPlan, error) { + if len(orderExprs) == 1 && isSpecialOrderBy(orderExprs[0]) { + rhs, err := hp.planOrderBy(ctx, orderExprs, plan.Right) + if err != nil { + return nil, err + } + plan.Right = rhs + return plan, nil + } + if orderExprsDependsOnTableSet(orderExprs, ctx.semTable, plan.Right.ContainsTables()) { + newRight, err := hp.planOrderBy(ctx, orderExprs, plan.Right) + if err != nil { + return nil, err + } + plan.Right = newRight + return plan, nil + } + sortPlan, err 
:= hp.createMemorySortPlan(ctx, plan, orderExprs, true) + if err != nil { + return nil, err + } + return sortPlan, nil +} + func (hp *horizonPlanning) planOrderByForJoin(ctx *planningContext, orderExprs []abstract.OrderBy, plan *joinGen4) (logicalPlan, error) { if len(orderExprs) == 1 && isSpecialOrderBy(orderExprs[0]) { lhs, err := hp.planOrderBy(ctx, orderExprs, plan.Left) @@ -812,7 +888,7 @@ func (hp *horizonPlanning) planOrderByForJoin(ctx *planningContext, orderExprs [ plan.Right = rhs return plan, nil } - if allLeft(orderExprs, ctx.semTable, plan.Left.ContainsTables()) { + if orderExprsDependsOnTableSet(orderExprs, ctx.semTable, plan.Left.ContainsTables()) { newLeft, err := hp.planOrderBy(ctx, orderExprs, plan.Left) if err != nil { return nil, err @@ -905,10 +981,10 @@ func (hp *horizonPlanning) createMemorySortPlan(ctx *planningContext, plan logic return ms, nil } -func allLeft(orderExprs []abstract.OrderBy, semTable *semantics.SemTable, lhsTables semantics.TableSet) bool { +func orderExprsDependsOnTableSet(orderExprs []abstract.OrderBy, semTable *semantics.SemTable, ts semantics.TableSet) bool { for _, expr := range orderExprs { exprDependencies := semTable.RecursiveDeps(expr.Inner.Expr) - if !exprDependencies.IsSolvedBy(lhsTables) { + if !exprDependencies.IsSolvedBy(ts) { return false } } @@ -1117,3 +1193,12 @@ func pushHaving(expr sqlparser.Expr, plan logicalPlan, semTable *semantics.SemTa } return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unreachable %T.filtering", plan) } + +func isJoin(plan logicalPlan) bool { + switch plan.(type) { + case *joinGen4, *hashJoin: + return true + default: + return false + } +} diff --git a/go/vt/vtgate/planbuilder/joinGen4.go b/go/vt/vtgate/planbuilder/joinGen4.go index 2a8dc532dc4..819e36d63c2 100644 --- a/go/vt/vtgate/planbuilder/joinGen4.go +++ b/go/vt/vtgate/planbuilder/joinGen4.go @@ -34,26 +34,9 @@ type joinGen4 struct { Opcode engine.JoinOpcode Cols []int Vars map[string]int -} - -// Order implements the 
logicalPlan interface -func (j *joinGen4) Order() int { - panic("[BUG]: should not be called. This is a Gen4 primitive") -} - -// ResultColumns implements the logicalPlan interface -func (j *joinGen4) ResultColumns() []*resultColumn { - panic("[BUG]: should not be called. This is a Gen4 primitive") -} - -// Reorder implements the logicalPlan interface -func (j *joinGen4) Reorder(i int) { - panic("[BUG]: should not be called. This is a Gen4 primitive") -} + Predicate sqlparser.Expr -// Wireup implements the logicalPlan interface -func (j *joinGen4) Wireup(lp logicalPlan, jt *jointab) error { - panic("[BUG]: should not be called. This is a Gen4 primitive") + gen4Plan } // WireupGen4 implements the logicalPlan interface @@ -65,29 +48,15 @@ func (j *joinGen4) WireupGen4(semTable *semantics.SemTable) error { return j.Right.WireupGen4(semTable) } -// SupplyVar implements the logicalPlan interface -func (j *joinGen4) SupplyVar(from, to int, col *sqlparser.ColName, varname string) { - panic("[BUG]: should not be called. This is a Gen4 primitive") -} - -// SupplyCol implements the logicalPlan interface -func (j *joinGen4) SupplyCol(col *sqlparser.ColName) (rc *resultColumn, colNumber int) { - panic("[BUG]: should not be called. This is a Gen4 primitive") -} - -// SupplyWeightString implements the logicalPlan interface -func (j *joinGen4) SupplyWeightString(colNumber int, alsoAddToGroupBy bool) (weightcolNumber int, err error) { - panic("[BUG]: should not be called. 
This is a Gen4 primitive") -} - // Primitive implements the logicalPlan interface func (j *joinGen4) Primitive() engine.Primitive { return &engine.Join{ - Left: j.Left.Primitive(), - Right: j.Right.Primitive(), - Cols: j.Cols, - Vars: j.Vars, - Opcode: j.Opcode, + Left: j.Left.Primitive(), + Right: j.Right.Primitive(), + Cols: j.Cols, + Vars: j.Vars, + Opcode: j.Opcode, + ASTPred: j.Predicate, } } diff --git a/go/vt/vtgate/planbuilder/jointree.go b/go/vt/vtgate/planbuilder/jointree.go index cba86af5bd8..523acaf483c 100644 --- a/go/vt/vtgate/planbuilder/jointree.go +++ b/go/vt/vtgate/planbuilder/jointree.go @@ -23,6 +23,11 @@ import ( "vitess.io/vitess/go/vt/vtgate/semantics" ) +type joinColumnInfo struct { + offset int + typ semantics.Type +} + type joinTree struct { // columns needed to feed other plans columns []int @@ -34,6 +39,12 @@ type joinTree struct { lhs, rhs queryTree leftJoin bool + + // predicatesToRemoveFromHashJoin lists all the predicates that needs to be removed + // from the right-hand side if we decide to do a hash join. + predicatesToRemoveFromHashJoin []sqlparser.Expr + + predicates []sqlparser.Expr } var _ queryTree = (*joinTree)(nil) @@ -44,10 +55,13 @@ func (jp *joinTree) tableID() semantics.TableSet { func (jp *joinTree) clone() queryTree { result := &joinTree{ - lhs: jp.lhs.clone(), - rhs: jp.rhs.clone(), - leftJoin: jp.leftJoin, - vars: jp.vars, + columns: jp.columns, + vars: jp.vars, + lhs: jp.lhs.clone(), + rhs: jp.rhs.clone(), + leftJoin: jp.leftJoin, + predicatesToRemoveFromHashJoin: jp.predicatesToRemoveFromHashJoin, + predicates: jp.predicates, } return result } @@ -127,6 +141,13 @@ func (jp *joinTree) removePredicate(ctx *planningContext, expr sqlparser.Expr) e } isRemoved = true } + for idx, predicate := range jp.predicates { + if sqlparser.EqualsExpr(predicate, expr) { + jp.predicates = append(jp.predicates[0:idx], jp.predicates[idx+1:]...) 
+ isRemoved = true + break + } + } if isRemoved { return nil } diff --git a/go/vt/vtgate/planbuilder/logical_plan.go b/go/vt/vtgate/planbuilder/logical_plan.go index 0cf87f05567..c599a0689f9 100644 --- a/go/vt/vtgate/planbuilder/logical_plan.go +++ b/go/vt/vtgate/planbuilder/logical_plan.go @@ -87,7 +87,43 @@ type logicalPlan interface { ContainsTables() semantics.TableSet } -//------------------------------------------------------------------------- +// gen4Plan implements a few methods from logicalPlan that are unused by Gen4. +type gen4Plan struct{} + +// Order implements the logicalPlan interface +func (*gen4Plan) Order() int { + panic("[BUG]: should not be called. This is a Gen4 primitive") +} + +// ResultColumns implements the logicalPlan interface +func (*gen4Plan) ResultColumns() []*resultColumn { + panic("[BUG]: should not be called. This is a Gen4 primitive") +} + +// Reorder implements the logicalPlan interface +func (*gen4Plan) Reorder(int) { + panic("[BUG]: should not be called. This is a Gen4 primitive") +} + +// Wireup implements the logicalPlan interface +func (*gen4Plan) Wireup(logicalPlan, *jointab) error { + panic("[BUG]: should not be called. This is a Gen4 primitive") +} + +// SupplyVar implements the logicalPlan interface +func (*gen4Plan) SupplyVar(int, int, *sqlparser.ColName, string) { + panic("[BUG]: should not be called. This is a Gen4 primitive") +} + +// SupplyCol implements the logicalPlan interface +func (*gen4Plan) SupplyCol(*sqlparser.ColName) (rc *resultColumn, colNumber int) { + panic("[BUG]: should not be called. This is a Gen4 primitive") +} + +// SupplyWeightString implements the logicalPlan interface +func (*gen4Plan) SupplyWeightString(int, bool) (weightcolNumber int, err error) { + panic("[BUG]: should not be called. 
This is a Gen4 primitive") +} type planVisitor func(logicalPlan) (bool, logicalPlan, error) diff --git a/go/vt/vtgate/planbuilder/ordered_aggregate.go b/go/vt/vtgate/planbuilder/ordered_aggregate.go index 576cf758ebd..3f68ed0b61c 100644 --- a/go/vt/vtgate/planbuilder/ordered_aggregate.go +++ b/go/vt/vtgate/planbuilder/ordered_aggregate.go @@ -90,7 +90,7 @@ func (pb *primitiveBuilder) checkAggregates(sel *sqlparser.Select) error { if hasAggregates { return vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: cross-shard query with aggregates") } - pb.plan = newDistinct(pb.plan) + pb.plan = newDistinct(pb.plan, nil) return nil } diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index 3ee572b2036..4a6cd8afafc 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -184,7 +184,7 @@ const ( func TestPlan(t *testing.T) { vschemaWrapper := &vschemaWrapper{ - v: loadSchema(t, "schema_test.json"), + v: loadSchema(t, "schema_test.json", true), sysVarEnabled: true, } @@ -227,7 +227,7 @@ func TestPlan(t *testing.T) { func TestSysVarSetDisabled(t *testing.T) { vschemaWrapper := &vschemaWrapper{ - v: loadSchema(t, "schema_test.json"), + v: loadSchema(t, "schema_test.json", true), sysVarEnabled: false, } @@ -239,7 +239,7 @@ func TestSysVarSetDisabled(t *testing.T) { func TestOne(t *testing.T) { vschema := &vschemaWrapper{ - v: loadSchema(t, "schema_test.json"), + v: loadSchema(t, "schema_test.json", true), } testFile(t, "onecase.txt", "", vschema) @@ -247,7 +247,7 @@ func TestOne(t *testing.T) { func TestRubyOnRailsQueries(t *testing.T) { vschemaWrapper := &vschemaWrapper{ - v: loadSchema(t, "rails_schema_test.json"), + v: loadSchema(t, "rails_schema_test.json", true), sysVarEnabled: true, } @@ -264,7 +264,7 @@ func TestRubyOnRailsQueries(t *testing.T) { func TestOLTP(t *testing.T) { vschemaWrapper := &vschemaWrapper{ - v: loadSchema(t, "oltp_schema_test.json"), + v: loadSchema(t, 
"oltp_schema_test.json", true), sysVarEnabled: true, } @@ -281,7 +281,7 @@ func TestOLTP(t *testing.T) { func TestTPCC(t *testing.T) { vschemaWrapper := &vschemaWrapper{ - v: loadSchema(t, "tpcc_schema_test.json"), + v: loadSchema(t, "tpcc_schema_test.json", true), sysVarEnabled: true, } @@ -298,7 +298,7 @@ func TestTPCC(t *testing.T) { func TestTPCH(t *testing.T) { vschemaWrapper := &vschemaWrapper{ - v: loadSchema(t, "tpch_schema_test.json"), + v: loadSchema(t, "tpch_schema_test.json", true), sysVarEnabled: true, } @@ -327,7 +327,7 @@ func BenchmarkTPCH(b *testing.B) { func benchmarkWorkload(b *testing.B, name string) { vschemaWrapper := &vschemaWrapper{ - v: loadSchema(b, name+"_schema_test.json"), + v: loadSchema(b, name+"_schema_test.json", true), sysVarEnabled: true, } @@ -349,7 +349,7 @@ func TestBypassPlanningShardTargetFromFile(t *testing.T) { defer os.RemoveAll(testOutputTempDir) vschema := &vschemaWrapper{ - v: loadSchema(t, "schema_test.json"), + v: loadSchema(t, "schema_test.json", true), keyspace: &vindexes.Keyspace{ Name: "main", Sharded: false, @@ -367,7 +367,7 @@ func TestBypassPlanningKeyrangeTargetFromFile(t *testing.T) { keyRange, _ := key.ParseShardingSpec("-") vschema := &vschemaWrapper{ - v: loadSchema(t, "schema_test.json"), + v: loadSchema(t, "schema_test.json", true), keyspace: &vindexes.Keyspace{ Name: "main", Sharded: false, @@ -389,7 +389,7 @@ func TestWithDefaultKeyspaceFromFile(t *testing.T) { } }() vschema := &vschemaWrapper{ - v: loadSchema(t, "schema_test.json"), + v: loadSchema(t, "schema_test.json", true), keyspace: &vindexes.Keyspace{ Name: "main", Sharded: false, @@ -411,7 +411,7 @@ func TestWithSystemSchemaAsDefaultKeyspace(t *testing.T) { require.NoError(t, err) defer os.RemoveAll(testOutputTempDir) vschema := &vschemaWrapper{ - v: loadSchema(t, "schema_test.json"), + v: loadSchema(t, "schema_test.json", true), keyspace: &vindexes.Keyspace{Name: "information_schema"}, tabletType: topodatapb.TabletType_PRIMARY, } @@ -425,7 
+425,7 @@ func TestOtherPlanningFromFile(t *testing.T) { defer os.RemoveAll(testOutputTempDir) require.NoError(t, err) vschema := &vschemaWrapper{ - v: loadSchema(t, "schema_test.json"), + v: loadSchema(t, "schema_test.json", true), keyspace: &vindexes.Keyspace{ Name: "main", Sharded: false, @@ -437,7 +437,7 @@ func TestOtherPlanningFromFile(t *testing.T) { testFile(t, "other_admin_cases.txt", testOutputTempDir, vschema) } -func loadSchema(t testing.TB, filename string) *vindexes.VSchema { +func loadSchema(t testing.TB, filename string, setCollation bool) *vindexes.VSchema { formal, err := vindexes.LoadFormal(locateFile(filename)) if err != nil { t.Fatal(err) @@ -454,10 +454,12 @@ func loadSchema(t testing.TB, filename string) *vindexes.VSchema { // setting a default value to all the text columns in the tables of this keyspace // so that we can "simulate" a real case scenario where the vschema is aware of // columns' collations. - for _, table := range ks.Tables { - for i, col := range table.Columns { - if sqltypes.IsText(col.Type) { - table.Columns[i].CollationName = "latin1_swedish_ci" + if setCollation { + for _, table := range ks.Tables { + for i, col := range table.Columns { + if sqltypes.IsText(col.Type) { + table.Columns[i].CollationName = "latin1_swedish_ci" + } } } } @@ -821,7 +823,7 @@ var benchMarkFiles = []string{"from_cases.txt", "filter_cases.txt", "large_cases func BenchmarkPlanner(b *testing.B) { vschema := &vschemaWrapper{ - v: loadSchema(b, "schema_test.json"), + v: loadSchema(b, "schema_test.json", true), sysVarEnabled: true, } for _, filename := range benchMarkFiles { @@ -843,7 +845,7 @@ func BenchmarkPlanner(b *testing.B) { func BenchmarkSemAnalysis(b *testing.B) { vschema := &vschemaWrapper{ - v: loadSchema(b, "schema_test.json"), + v: loadSchema(b, "schema_test.json", true), sysVarEnabled: true, } @@ -876,7 +878,7 @@ func exerciseAnalyzer(query, database string, s semantics.SchemaInformation) { func BenchmarkSelectVsDML(b *testing.B) { 
vschema := &vschemaWrapper{ - v: loadSchema(b, "schema_test.json"), + v: loadSchema(b, "schema_test.json", true), sysVarEnabled: true, version: V3, } diff --git a/go/vt/vtgate/planbuilder/postprocess.go b/go/vt/vtgate/planbuilder/postprocess.go index 4280aca9ecb..984cfed1f64 100644 --- a/go/vt/vtgate/planbuilder/postprocess.go +++ b/go/vt/vtgate/planbuilder/postprocess.go @@ -101,7 +101,7 @@ var _ planVisitor = setUpperLimit func setUpperLimit(plan logicalPlan) (bool, logicalPlan, error) { arg := sqlparser.NewArgument("__upper_limit") switch node := plan.(type) { - case *join, *joinGen4: + case *join, *joinGen4, *hashJoin: return false, node, nil case *memorySort: pv, err := sqlparser.NewPlanValue(arg) diff --git a/go/vt/vtgate/planbuilder/querytree_transformers.go b/go/vt/vtgate/planbuilder/querytree_transformers.go index 23cb752bf03..bf371001694 100644 --- a/go/vt/vtgate/planbuilder/querytree_transformers.go +++ b/go/vt/vtgate/planbuilder/querytree_transformers.go @@ -20,6 +20,10 @@ import ( "sort" "strings" + "vitess.io/vitess/go/mysql/collations" + + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/sqltypes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -240,11 +244,24 @@ func transformConcatenatePlan(ctx *planningContext, n *concatenateTree) (logical result = &concatenateGen4{sources: sources} } if n.distinct { - return newDistinct(result), nil + return newDistinct(result, getCollationsFor(ctx, n)), nil } return result, nil } +func getCollationsFor(ctx *planningContext, n *concatenateTree) []collations.ID { + var colls []collations.ID + for _, expr := range n.selectStmts[0].SelectExprs { + aliasedE, ok := expr.(*sqlparser.AliasedExpr) + if !ok { + return nil + } + typ := ctx.semTable.CollationFor(aliasedE.Expr) + colls = append(colls, typ) + } + return colls +} + func transformAndMergeInOrder(ctx *planningContext, n *concatenateTree) (sources []logicalPlan, err error) { for i, source := range n.sources { plan, err := createLogicalPlan(ctx, 
source, n.selectStmts[i]) @@ -464,6 +481,13 @@ func transformRoutePlan(ctx *planningContext, n *routeTree) (*route, error) { } func transformJoinPlan(ctx *planningContext, n *joinTree) (logicalPlan, error) { + // TODO systay we should move the decision of which join to use to the greedy algorithm, + // and thus represented as a queryTree + canHashJoin, lhsInfo, rhsInfo, err := canHashJoin(ctx, n) + if err != nil { + return nil, err + } + lhs, err := transformToLogicalPlan(ctx, n.lhs) if err != nil { return nil, err @@ -476,15 +500,113 @@ func transformJoinPlan(ctx *planningContext, n *joinTree) (logicalPlan, error) { if n.leftJoin { opCode = engine.LeftJoin } + + if canHashJoin { + coercedType, err := evalengine.CoerceTo(lhsInfo.typ.Type, rhsInfo.typ.Type) + if err != nil { + return nil, err + } + return &hashJoin{ + Left: lhs, + Right: rhs, + Cols: n.columns, + Opcode: opCode, + LHSKey: lhsInfo.offset, + RHSKey: rhsInfo.offset, + Predicate: sqlparser.AndExpressions(n.predicates...), + ComparisonType: coercedType, + Collation: lhsInfo.typ.Collation, + }, nil + } return &joinGen4{ - Left: lhs, - Right: rhs, - Cols: n.columns, - Vars: n.vars, - Opcode: opCode, + Left: lhs, + Right: rhs, + Cols: n.columns, + Vars: n.vars, + Opcode: opCode, + Predicate: sqlparser.AndExpressions(n.predicates...), }, nil } +// canHashJoin decides whether a join tree can be transformed into a hash join or apply join. +// Since hash join use variables from the left-hand side, we want to remove any +// join predicate living in the right-hand side. +// Hash joins are only supporting equality join predicates, which is why the join predicate +// has to be an EqualOp. 
+func canHashJoin(ctx *planningContext, n *joinTree) (canHash bool, lhs, rhs joinColumnInfo, err error) { + if len(n.predicatesToRemoveFromHashJoin) != 1 || + n.leftJoin || + !sqlparser.ExtractCommentDirectives(ctx.semTable.Comments).IsSet(sqlparser.DirectiveAllowHashJoin) { + return + } + cmp, isCmp := n.predicatesToRemoveFromHashJoin[0].(*sqlparser.ComparisonExpr) + if !isCmp || cmp.Operator != sqlparser.EqualOp { + return + } + var colOnLeft bool + var col *sqlparser.ColName + var arg sqlparser.Argument + if lCol, isCol := cmp.Left.(*sqlparser.ColName); isCol { + col = lCol + if rArg, isArg := cmp.Right.(sqlparser.Argument); isArg { + arg = rArg + } + colOnLeft = true + } else if rCol, isCol := cmp.Right.(*sqlparser.ColName); isCol { + col = rCol + if lArg, isArg := cmp.Left.(sqlparser.Argument); isArg { + arg = lArg + } + } else { + return + } + + lhsKey, found := n.vars[string(arg)] + if !found { + return + } + lhs.offset = lhsKey + + colType, found := ctx.semTable.ExprTypes[col] + if !found { + return + } + argType, found := ctx.semTable.ExprTypes[arg] + if !found { + return + } + + if colType.Collation != argType.Collation { + // joins with different collations are not yet supported + canHash = false + return + } + + if colOnLeft { + lhs.typ = colType + rhs.typ = argType + } else { + lhs.typ = argType + rhs.typ = colType + } + + columns, err := n.rhs.pushOutputColumns([]*sqlparser.ColName{col}, ctx.semTable) + if err != nil { + return false, lhs, rhs, nil + } + if len(columns) != 1 { + return + } + rhs.offset = columns[0] + canHash = true + + err = n.rhs.removePredicate(ctx, cmp) + if err != nil { + return + } + return +} + func relToTableExpr(t relation) (sqlparser.TableExpr, error) { switch t := t.(type) { case *routeTable: diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index 836f52728c7..a490bc052e4 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go 
@@ -579,11 +579,14 @@ func pushJoinPredicateOnJoin(ctx *planningContext, exprs []sqlparser.Expr, node if err != nil { return nil, err } + return &joinTree{ - lhs: lhsPlan, - rhs: rhsPlan, - leftJoin: node.leftJoin, - vars: node.vars, + lhs: lhsPlan, + rhs: rhsPlan, + leftJoin: node.leftJoin, + vars: node.vars, + predicatesToRemoveFromHashJoin: append(node.predicatesToRemoveFromHashJoin, rhsPreds...), + predicates: append(node.predicates, exprs...), }, nil } @@ -607,6 +610,9 @@ func breakExpressionInLHSandRHS( bvName := node.CompliantName() bvNames = append(bvNames, bvName) arg := sqlparser.NewArgument(bvName) + // we are replacing one of the sides of the comparison with an argument, + // but we don't want to lose the type information we have, so we copy it over + semTable.CopyExprInfo(node, arg) cursor.Replace(arg) } } diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt index e125f8c4373..ddba13d3629 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt @@ -606,10 +606,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select col, count(*), weight_string(col) from `user` where 1 != 1 group by col, weight_string(col)", - "OrderBy": "(0|2) ASC", - "Query": "select col, count(*), weight_string(col) from `user` group by col, weight_string(col) order by col asc", - "ResultColumns": 2, + "FieldQuery": "select col, count(*) from `user` where 1 != 1 group by col", + "OrderBy": "0 ASC", + "Query": "select col, count(*) from `user` group by col order by col asc", "Table": "`user`" } ] @@ -622,8 +621,7 @@ Gen4 plan same as above "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "count(1) AS count(*)", - "GroupBy": "(0|2)", - "ResultColumns": 2, + "GroupBy": "0", "Inputs": [ { "OperatorType": "Route", @@ -632,9 +630,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select col, 
count(*), weight_string(col) from `user` where 1 != 1 group by col, weight_string(col)", - "OrderBy": "(0|2) ASC", - "Query": "select col, count(*), weight_string(col) from `user` group by col, weight_string(col) order by col asc", + "FieldQuery": "select col, count(*) from `user` where 1 != 1 group by col", + "OrderBy": "0 ASC", + "Query": "select col, count(*) from `user` group by col order by col asc", "Table": "`user`" } ] @@ -651,7 +649,7 @@ Gen4 plan same as above "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "count(1) AS count(*)", - "GroupBy": "(0|2), (3|4)", + "GroupBy": "0, (2|3)", "ResultColumns": 2, "Inputs": [ { @@ -661,9 +659,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select col, count(*), weight_string(col), baz, weight_string(baz) from `user` where 1 != 1 group by col, weight_string(col), baz, weight_string(baz)", - "OrderBy": "(0|2) ASC, (3|4) ASC", - "Query": "select col, count(*), weight_string(col), baz, weight_string(baz) from `user` group by col, weight_string(col), baz, weight_string(baz) order by col asc, baz asc", + "FieldQuery": "select col, count(*), baz, weight_string(baz) from `user` where 1 != 1 group by col, baz, weight_string(baz)", + "OrderBy": "0 ASC, (2|3) ASC", + "Query": "select col, count(*), baz, weight_string(baz) from `user` group by col, baz, weight_string(baz) order by col asc, baz asc", "Table": "`user`" } ] @@ -965,39 +963,15 @@ Gen4 error: In aggregated query without GROUP BY, expression of SELECT list cont "Name": "user", "Sharded": true }, - "FieldQuery": "select col, weight_string(col) from `user` where 1 != 1", - "OrderBy": "(0|1) ASC", - "Query": "select distinct col, weight_string(col) from `user` order by col asc", - "ResultColumns": 1, - "Table": "`user`" - } - ] - } -} -{ - "QueryType": "SELECT", - "Original": "select distinct col from user", - "Instructions": { - "OperatorType": "Aggregate", - "Variant": "Ordered", - "GroupBy": "(0|1)", - 
"ResultColumns": 1, - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "SelectScatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col, weight_string(col) from `user` where 1 != 1", - "OrderBy": "(0|1) ASC", - "Query": "select distinct col, weight_string(col) from `user` order by col asc", + "FieldQuery": "select col from `user` where 1 != 1", + "OrderBy": "0 ASC", + "Query": "select distinct col from `user` order by col asc", "Table": "`user`" } ] } } +Gen4 plan same as above # scatter aggregate group by select col "select col from user group by col" @@ -1016,39 +990,15 @@ Gen4 error: In aggregated query without GROUP BY, expression of SELECT list cont "Name": "user", "Sharded": true }, - "FieldQuery": "select col, weight_string(col) from `user` where 1 != 1 group by col, weight_string(col)", - "OrderBy": "(0|1) ASC", - "Query": "select col, weight_string(col) from `user` group by col, weight_string(col) order by col asc", - "ResultColumns": 1, - "Table": "`user`" - } - ] - } -} -{ - "QueryType": "SELECT", - "Original": "select col from user group by col", - "Instructions": { - "OperatorType": "Aggregate", - "Variant": "Ordered", - "GroupBy": "(0|1)", - "ResultColumns": 1, - "Inputs": [ - { - "OperatorType": "Route", - "Variant": "SelectScatter", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select col, weight_string(col) from `user` where 1 != 1 group by col, weight_string(col)", - "OrderBy": "(0|1) ASC", - "Query": "select col, weight_string(col) from `user` group by col, weight_string(col) order by col asc", + "FieldQuery": "select col from `user` where 1 != 1 group by col", + "OrderBy": "0 ASC", + "Query": "select col from `user` group by col order by col asc", "Table": "`user`" } ] } } +Gen4 plan same as above # count with distinct group by unique vindex "select id, count(distinct col) from user group by id" @@ -1087,10 +1037,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true 
}, - "FieldQuery": "select col, count(distinct id), weight_string(col) from `user` where 1 != 1 group by col, weight_string(col)", - "OrderBy": "(0|2) ASC", - "Query": "select col, count(distinct id), weight_string(col) from `user` group by col, weight_string(col) order by col asc", - "ResultColumns": 2, + "FieldQuery": "select col, count(distinct id) from `user` where 1 != 1 group by col", + "OrderBy": "0 ASC", + "Query": "select col, count(distinct id) from `user` group by col order by col asc", "Table": "`user`" } ] @@ -1103,8 +1052,7 @@ Gen4 plan same as above "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "count(1) AS count(distinct id)", - "GroupBy": "(0|2)", - "ResultColumns": 2, + "GroupBy": "0", "Inputs": [ { "OperatorType": "Route", @@ -1113,9 +1061,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select col, count(distinct id), weight_string(col) from `user` where 1 != 1 group by col, weight_string(col)", - "OrderBy": "(0|2) ASC", - "Query": "select col, count(distinct id), weight_string(col) from `user` group by col, weight_string(col) order by col asc", + "FieldQuery": "select col, count(distinct id) from `user` where 1 != 1 group by col", + "OrderBy": "0 ASC", + "Query": "select col, count(distinct id) from `user` group by col order by col asc", "Table": "`user`" } ] @@ -1633,10 +1581,9 @@ Gen4 error: Can't group on 'count(*)' "Name": "user", "Sharded": true }, - "FieldQuery": "select col, weight_string(col) from `user` where 1 != 1 group by 1, weight_string(col)", - "OrderBy": "(0|1) ASC", - "Query": "select col, weight_string(col) from `user` group by 1, weight_string(col) order by col asc", - "ResultColumns": 1, + "FieldQuery": "select col from `user` where 1 != 1 group by 1", + "OrderBy": "0 ASC", + "Query": "select col from `user` group by 1 order by col asc", "Table": "`user`" } ] @@ -1648,8 +1595,7 @@ Gen4 error: Can't group on 'count(*)' "Instructions": { "OperatorType": "Aggregate", "Variant": 
"Ordered", - "GroupBy": "(0|1)", - "ResultColumns": 1, + "GroupBy": "0", "Inputs": [ { "OperatorType": "Route", @@ -1658,9 +1604,9 @@ Gen4 error: Can't group on 'count(*)' "Name": "user", "Sharded": true }, - "FieldQuery": "select col, weight_string(col) from `user` where 1 != 1 group by col, weight_string(col)", - "OrderBy": "(0|1) ASC", - "Query": "select col, weight_string(col) from `user` group by col, weight_string(col) order by col asc", + "FieldQuery": "select col from `user` where 1 != 1 group by col", + "OrderBy": "0 ASC", + "Query": "select col from `user` group by col order by col asc", "Table": "`user`" } ] @@ -1963,10 +1909,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select col, count(*), weight_string(col) from `user` where 1 != 1 group by col, weight_string(col)", - "OrderBy": "(0|2) ASC", - "Query": "select col, count(*), weight_string(col) from `user` group by col, weight_string(col) order by col asc limit :__upper_limit", - "ResultColumns": 2, + "FieldQuery": "select col, count(*) from `user` where 1 != 1 group by col", + "OrderBy": "0 ASC", + "Query": "select col, count(*) from `user` group by col order by col asc limit :__upper_limit", "Table": "`user`" } ] @@ -1985,8 +1930,7 @@ Gen4 plan same as above "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "count(1) AS count(*)", - "GroupBy": "(0|2)", - "ResultColumns": 2, + "GroupBy": "0", "Inputs": [ { "OperatorType": "Route", @@ -1995,9 +1939,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select col, count(*), weight_string(col) from `user` where 1 != 1 group by col, weight_string(col)", - "OrderBy": "(0|2) ASC", - "Query": "select col, count(*), weight_string(col) from `user` group by col, weight_string(col) order by col asc limit :__upper_limit", + "FieldQuery": "select col, count(*) from `user` where 1 != 1 group by col", + "OrderBy": "0 ASC", + "Query": "select col, count(*) from `user` group by col order by 
col asc limit :__upper_limit", "Table": "`user`" } ] @@ -2182,11 +2126,11 @@ Gen4 plan same as above Gen4 plan same as above # if derived table scatter and ordering, then V3 doesn't allow outer constructs to be pushed down. -"select count(*) from (select col, user_extra.extra from user join user_extra on user.id = user_extra.user_id order by user_extra.extra) a" +"select count(*) from (select user.col, user_extra.extra from user join user_extra on user.id = user_extra.user_id order by user_extra.extra) a" "unsupported: cross-shard query with aggregates" { "QueryType": "SELECT", - "Original": "select count(*) from (select col, user_extra.extra from user join user_extra on user.id = user_extra.user_id order by user_extra.extra) a", + "Original": "select count(*) from (select user.col, user_extra.extra from user join user_extra on user.id = user_extra.user_id order by user_extra.extra) a", "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", @@ -2199,9 +2143,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select count(*) from (select col, user_extra.extra, weight_string(user_extra.extra) from `user`, user_extra where 1 != 1) as a where 1 != 1", + "FieldQuery": "select count(*) from (select `user`.col, user_extra.extra, weight_string(user_extra.extra) from `user`, user_extra where 1 != 1) as a where 1 != 1", "OrderBy": "(1|2) ASC", - "Query": "select count(*) from (select col, user_extra.extra, weight_string(user_extra.extra) from `user`, user_extra where `user`.id = user_extra.user_id order by user_extra.extra asc) as a", + "Query": "select count(*) from (select `user`.col, user_extra.extra, weight_string(user_extra.extra) from `user`, user_extra where `user`.id = user_extra.user_id order by user_extra.extra asc) as a", "ResultColumns": 2, "Table": "`user`, user_extra" } @@ -2426,13 +2370,12 @@ Gen4 plan same as above "OperatorType": "Sort", "Variant": "Memory", "OrderBy": "1 ASC", - "ResultColumns": 2, "Inputs": [ { 
"OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "count(1) AS k", - "GroupBy": "(0|2)", + "GroupBy": "0", "Inputs": [ { "OperatorType": "Route", @@ -2441,9 +2384,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select col, count(*) as k, weight_string(col) from `user` where 1 != 1 group by col, weight_string(col)", - "OrderBy": "(0|2) ASC", - "Query": "select col, count(*) as k, weight_string(col) from `user` group by col, weight_string(col) order by col asc", + "FieldQuery": "select col, count(*) as k from `user` where 1 != 1 group by col", + "OrderBy": "0 ASC", + "Query": "select col, count(*) as k from `user` group by col order by col asc", "Table": "`user`" } ] @@ -2470,10 +2413,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select col, count(*) as k, weight_string(col) from `user` where 1 != 1 group by col, weight_string(col)", - "OrderBy": "(0|2) ASC", - "Query": "select col, count(*) as k, weight_string(col) from `user` group by col, weight_string(col) order by col asc", - "ResultColumns": 2, + "FieldQuery": "select col, count(*) as k from `user` where 1 != 1 group by col", + "OrderBy": "0 ASC", + "Query": "select col, count(*) as k from `user` group by col order by col asc", "Table": "`user`" } ] @@ -2486,8 +2428,7 @@ Gen4 plan same as above "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "count(1) AS k", - "GroupBy": "(0|2)", - "ResultColumns": 2, + "GroupBy": "0", "Inputs": [ { "OperatorType": "Route", @@ -2496,9 +2437,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select col, count(*) as k, weight_string(col) from `user` where 1 != 1 group by col, weight_string(col)", - "OrderBy": "(0|2) ASC", - "Query": "select col, count(*) as k, weight_string(col) from `user` group by col, weight_string(col) order by col asc", + "FieldQuery": "select col, count(*) as k from `user` where 1 != 1 group by col", + "OrderBy": "0 ASC", + 
"Query": "select col, count(*) as k from `user` group by col order by col asc", "Table": "`user`" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/ddl_cases_no_default_keyspace.txt b/go/vt/vtgate/planbuilder/testdata/ddl_cases_no_default_keyspace.txt index 18cb418f74a..e9c7958c36a 100644 --- a/go/vt/vtgate/planbuilder/testdata/ddl_cases_no_default_keyspace.txt +++ b/go/vt/vtgate/planbuilder/testdata/ddl_cases_no_default_keyspace.txt @@ -219,17 +219,17 @@ Gen4 plan same as above } # create view with auto-resolve anonymous columns for simple route -"create view user.view_a as select col from user join user_extra on user.id = user_extra.user_id" +"create view user.view_a as select user.col from user join user_extra on user.id = user_extra.user_id" { "QueryType": "DDL", - "Original": "create view user.view_a as select col from user join user_extra on user.id = user_extra.user_id", + "Original": "create view user.view_a as select user.col from user join user_extra on user.id = user_extra.user_id", "Instructions": { "OperatorType": "DDL", "Keyspace": { "Name": "user", "Sharded": true }, - "Query": "create view view_a as select col from `user` join user_extra on `user`.id = user_extra.user_id" + "Query": "create view view_a as select `user`.col from `user` join user_extra on `user`.id = user_extra.user_id" } } Gen4 plan same as above diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.txt b/go/vt/vtgate/planbuilder/testdata/filter_cases.txt index 71558ff669e..29e6c03d578 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.txt @@ -733,7 +733,48 @@ Gen4 plan same as above ] } } -Gen4 plan same as above +{ + "QueryType": "SELECT", + "Original": "select user_extra.id from user join user_extra on user.col = user_extra.col where user.id = 5", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "1", + "JoinVars": { + "user_col": 0 + }, + "Predicate": "`user`.col = 
user_extra.col", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.col from `user` where 1 != 1", + "Query": "select `user`.col from `user` where `user`.id = 5", + "Table": "`user`", + "Values": [ + 5 + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.id from user_extra where 1 != 1", + "Query": "select user_extra.id from user_extra where user_extra.col = :user_col", + "Table": "user_extra" + } + ] + } +} # Multi-route unique vindex route on both routes "select user_extra.id from user join user_extra on user.col = user_extra.col where user.id = 5 and user_extra.user_id = 5" @@ -855,6 +896,7 @@ Gen4 plan same as above "JoinVars": { "user_col": 1 }, + "Predicate": "`user`.col = user_extra.col and user_extra.user_id = `user`.col", "TableName": "`user`_user_extra", "Inputs": [ { @@ -936,6 +978,7 @@ Gen4 plan same as above "JoinVars": { "user_col": 0 }, + "Predicate": "`user`.col = user_extra.col", "TableName": "`user`_user_extra", "Inputs": [ { @@ -1176,6 +1219,7 @@ Gen4 plan same as above "JoinVars": { "unsharded_id": 0 }, + "Predicate": "unsharded.id = `user`.id", "TableName": "unsharded_`user`", "Inputs": [ { @@ -1287,6 +1331,7 @@ Gen4 plan same as above "JoinVars": { "user_extra_col": 0 }, + "Predicate": "u.id in (user_extra.col, 1)", "TableName": "user_extra_`user`", "Inputs": [ { @@ -1378,6 +1423,7 @@ Gen4 plan same as above "JoinVars": { "user_extra_col": 0 }, + "Predicate": "u.id in (user_extra.col, 1)", "TableName": "user_extra_`user`", "Inputs": [ { @@ -1548,6 +1594,7 @@ Gen4 plan same as above "JoinVars": { "user_extra_col": 0 }, + "Predicate": "u.id in (user_extra.col, 1)", "TableName": "user_extra_`user`", "Inputs": [ { @@ -3143,6 +3190,7 @@ Gen4 plan same as above 
"JoinVars": { "user_extra_col": 0 }, + "Predicate": "`user`.id = user_extra.col", "TableName": "user_extra_`user`", "Inputs": [ { @@ -3207,6 +3255,7 @@ Gen4 plan same as above "JoinVars": { "user_col": 0 }, + "Predicate": "`user`.col = user_extra.col", "TableName": "`user`_user_extra", "Inputs": [ { @@ -3284,7 +3333,7 @@ Gen4 plan same as above "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", - "GroupBy": "(0|2), (1|3)", + "GroupBy": "(0|2), 1", "ResultColumns": 2, "Inputs": [ { @@ -3313,9 +3362,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select `user`.id, `user`.col, weight_string(`user`.id), weight_string(`user`.col) from `user` where 1 != 1", - "OrderBy": "(0|2) ASC, (1|3) ASC", - "Query": "select `user`.id, `user`.col, weight_string(`user`.id), weight_string(`user`.col) from `user` where :__sq_has_values1 = 1 and `user`.col in ::__sq1 order by `user`.id asc, `user`.col asc", + "FieldQuery": "select `user`.id, `user`.col, weight_string(`user`.id) from `user` where 1 != 1", + "OrderBy": "(0|2) ASC, 1 ASC", + "Query": "select `user`.id, `user`.col, weight_string(`user`.id) from `user` where :__sq_has_values1 = 1 and `user`.col in ::__sq1 order by `user`.id asc, `user`.col asc", "Table": "`user`" } ] @@ -3452,3 +3501,163 @@ Gen4 plan same as above } } Gen4 plan same as above + +# Multi-route unique vindex constraint (with hash join) +"select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user join user_extra on user.col = user_extra.col where user.id = 5" +{ + "QueryType": "SELECT", + "Original": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user join user_extra on user.col = user_extra.col where user.id = 5", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "1", + "JoinVars": { + "user_col": 0 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + 
}, + "FieldQuery": "select `user`.col from `user` where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ `user`.col from `user` where `user`.id = 5", + "Table": "`user`", + "Values": [ + 5 + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.id from user_extra where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user_extra where user_extra.col = :user_col", + "Table": "user_extra" + } + ] + } +} +{ + "QueryType": "SELECT", + "Original": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user join user_extra on user.col = user_extra.col where user.id = 5", + "Instructions": { + "OperatorType": "Join", + "Variant": "HashJoin", + "ComparisonType": "INT16", + "JoinColumnIndexes": "2", + "Predicate": "`user`.col = user_extra.col", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.col from `user` where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ `user`.col from `user` where `user`.id = 5", + "Table": "`user`", + "Values": [ + 5 + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.col, user_extra.id from user_extra where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.col, user_extra.id from user_extra", + "Table": "user_extra" + } + ] + } +} + +# Multi-route with non-route constraint, should use first route. 
+"select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user join user_extra on user.col = user_extra.col where 1 = 1" +{ + "QueryType": "SELECT", + "Original": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user join user_extra on user.col = user_extra.col where 1 = 1", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "1", + "JoinVars": { + "user_col": 0 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.col from `user` where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ `user`.col from `user` where 1 = 1", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.id from user_extra where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user_extra where user_extra.col = :user_col", + "Table": "user_extra" + } + ] + } +} +{ + "QueryType": "SELECT", + "Original": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.id from user join user_extra on user.col = user_extra.col where 1 = 1", + "Instructions": { + "OperatorType": "Join", + "Variant": "HashJoin", + "ComparisonType": "INT16", + "JoinColumnIndexes": "2", + "Predicate": "`user`.col = user_extra.col", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.col from `user` where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ `user`.col from `user` where 1 = 1", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.col, user_extra.id from user_extra where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ 
user_extra.col, user_extra.id from user_extra where 1 = 1", + "Table": "user_extra" + } + ] + } +} diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.txt b/go/vt/vtgate/planbuilder/testdata/from_cases.txt index c50b51f139b..6d979273589 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.txt @@ -461,6 +461,7 @@ Gen4 plan same as above "JoinVars": { "music_id": 0 }, + "Predicate": "music.id = `user`.id", "TableName": "music_`user`", "Inputs": [ { @@ -637,6 +638,7 @@ Gen4 plan same as above "JoinVars": { "u_a": 0 }, + "Predicate": "u.a = m.b", "TableName": "`user`_unsharded", "Inputs": [ { @@ -736,6 +738,7 @@ Gen4 plan same as above "JoinVars": { "m1_col": 0 }, + "Predicate": "m1.col = m2.col", "TableName": "`user`_unsharded_unsharded", "Inputs": [ { @@ -745,6 +748,7 @@ Gen4 plan same as above "JoinVars": { "user_col": 0 }, + "Predicate": "`user`.col = m1.col", "TableName": "`user`_unsharded", "Inputs": [ { @@ -846,7 +850,66 @@ Gen4 plan same as above ] } } -Gen4 plan same as above +{ + "QueryType": "SELECT", + "Original": "select user.col from user left join user_extra as e left join unsharded as m1 on m1.col = e.col on user.col = e.col", + "Instructions": { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "-1", + "JoinVars": { + "user_col": 0 + }, + "Predicate": "`user`.col = e.col", + "TableName": "`user`_user_extra_unsharded", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.col from `user` where 1 != 1", + "Query": "select `user`.col from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinVars": { + "e_col": 0 + }, + "Predicate": "m1.col = e.col and e.col = :user_col", + "TableName": "user_extra_unsharded", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + 
"Sharded": true + }, + "FieldQuery": "select e.col from user_extra as e where 1 != 1", + "Query": "select e.col from user_extra as e where e.col = :user_col", + "Table": "user_extra" + }, + { + "OperatorType": "Route", + "Variant": "SelectUnsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select 1 from unsharded as m1 where 1 != 1", + "Query": "select 1 from unsharded as m1 where m1.col = :e_col", + "Table": "unsharded" + } + ] + } + ] + } +} # Right join "select m1.col from unsharded as m1 right join unsharded as m2 on m1.a=m2.b" @@ -1443,6 +1506,7 @@ Gen4 plan same as above "JoinVars": { "user_id": 0 }, + "Predicate": "`user`.id \u003c user_extra.user_id", "TableName": "`user`_user_extra", "Inputs": [ { @@ -1604,6 +1668,7 @@ Gen4 plan same as above "JoinVars": { "user_extra_col": 0 }, + "Predicate": "`user`.id = user_extra.col", "TableName": "user_extra_`user`", "Inputs": [ { @@ -1689,6 +1754,7 @@ Gen4 plan same as above "JoinVars": { "user_name": 0 }, + "Predicate": "user_extra.user_id = `user`.`name`", "TableName": "`user`_user_extra", "Inputs": [ { @@ -2090,7 +2156,48 @@ Gen4 error: Duplicate column name 'id' ] } } -Gen4 plan same as above +{ + "QueryType": "SELECT", + "Original": "select t.id from (select id from user where id = 5) as t join user_extra on t.id = user_extra.col", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1", + "JoinVars": { + "t_id": 0 + }, + "Predicate": "t.id = user_extra.col", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select t.id from (select id from `user` where 1 != 1) as t where 1 != 1", + "Query": "select t.id from (select id from `user` where id = 5) as t", + "Table": "`user`", + "Values": [ + 5 + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + 
"Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra where user_extra.col = :t_id", + "Table": "user_extra" + } + ] + } +} # routing rules for derived table "select id from (select id, col from route1 where id = 5) as t" @@ -2486,7 +2593,73 @@ Gen4 plan same as above ] } } -Gen4 plan same as above +{ + "QueryType": "SELECT", + "Original": "select t.col1 from (select user.id, user.col1 from user join user_extra) as t join unsharded on unsharded.col1 = t.col1 and unsharded.id = t.id", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1", + "JoinVars": { + "t_col1": 0, + "t_id": 1 + }, + "Predicate": "unsharded.col1 = t.col1 and unsharded.id = t.id", + "TableName": "`user`_user_extra_unsharded", + "Inputs": [ + { + "OperatorType": "SimpleProjection", + "Columns": [ + 1, + 0 + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1,-2", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.id, `user`.col1 from `user` where 1 != 1", + "Query": "select `user`.id, `user`.col1 from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra", + "Table": "user_extra" + } + ] + } + ] + }, + { + "OperatorType": "Route", + "Variant": "SelectUnsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select 1 from unsharded where 1 != 1", + "Query": "select 1 from unsharded where unsharded.col1 = :t_col1 and unsharded.id = :t_id", + "Table": "unsharded" + } + ] + } +} # wire-up on within cross-shard derived table "select t.id from (select user.id, user.col1 from user 
join user_extra on user_extra.col = user.col) as t" @@ -2551,6 +2724,7 @@ Gen4 plan same as above "JoinVars": { "user_col": 0 }, + "Predicate": "user_extra.col = `user`.col", "TableName": "`user`_user_extra", "Inputs": [ { @@ -2659,6 +2833,7 @@ Gen4 plan same as above "JoinVars": { "ua_id": 0 }, + "Predicate": "t.id = ua.id", "TableName": "unsharded_a_`user`_user_extra", "Inputs": [ { @@ -2682,6 +2857,7 @@ Gen4 plan same as above "OperatorType": "Join", "Variant": "Join", "JoinColumnIndexes": "-1,-2", + "Predicate": "`user`.id = :ua_id", "TableName": "`user`_user_extra", "Inputs": [ { @@ -3084,6 +3260,7 @@ Gen4 plan same as above "OperatorType": "Join", "Variant": "LeftJoin", "JoinColumnIndexes": "-1", + "Predicate": ":__sq_has_values1 = 1 and `user`.col in ::__sq1", "TableName": "unsharded_`user`", "Inputs": [ { @@ -3296,6 +3473,7 @@ Gen4 plan same as above "JoinVars": { "user_col2": 0 }, + "Predicate": "unsharded.col2 = `user`.col2", "TableName": "`user`_unsharded", "Inputs": [ { @@ -3566,6 +3744,7 @@ Gen4 plan same as above "JoinVars": { "user_extra_id": 0 }, + "Predicate": "`user`.id = user_extra.id", "TableName": "user_extra_`user`", "Inputs": [ { @@ -3651,6 +3830,7 @@ Gen4 plan same as above "JoinVars": { "user_extra_assembly_id": 0 }, + "Predicate": "`user`.id = user_extra.assembly_id", "TableName": "user_extra_`user`", "Inputs": [ { @@ -3840,6 +4020,7 @@ Gen4 plan same as above "JoinVars": { "ue_id": 0 }, + "Predicate": "ue.id = u.id", "TableName": "music, user_extra_`user`", "Inputs": [ { @@ -3921,6 +4102,7 @@ Gen4 plan same as above "JoinVars": { "ue_id": 0 }, + "Predicate": "u.id = ue.id", "TableName": "user_extra_`user`", "Inputs": [ { @@ -4359,3 +4541,249 @@ Gen4 plan same as above ] } } + +# join on int columns +"select u.id from user as u join user as uu on u.intcol = uu.intcol" +{ + "QueryType": "SELECT", + "Original": "select u.id from user as u join user as uu on u.intcol = uu.intcol", + "Instructions": { + "OperatorType": "Join", + "Variant": 
"Join", + "JoinColumnIndexes": "-1", + "JoinVars": { + "u_intcol": 1 + }, + "TableName": "`user`_`user`", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u.id, u.intcol from `user` as u where 1 != 1", + "Query": "select u.id, u.intcol from `user` as u", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from `user` as uu where 1 != 1", + "Query": "select 1 from `user` as uu where uu.intcol = :u_intcol", + "Table": "`user`" + } + ] + } +} +{ + "QueryType": "SELECT", + "Original": "select u.id from user as u join user as uu on u.intcol = uu.intcol", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-2", + "JoinVars": { + "u_intcol": 0 + }, + "Predicate": "u.intcol = uu.intcol", + "TableName": "`user`_`user`", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u.intcol, u.id from `user` as u where 1 != 1", + "Query": "select u.intcol, u.id from `user` as u", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from `user` as uu where 1 != 1", + "Query": "select 1 from `user` as uu where uu.intcol = :u_intcol", + "Table": "`user`" + } + ] + } +} + +# wire-up on within cross-shard derived table (hash-join version) +"select /*vt+ ALLOW_HASH_JOIN */ t.id from (select user.id, user.col1 from user join user_extra on user_extra.col = user.col) as t" +{ + "QueryType": "SELECT", + "Original": "select /*vt+ ALLOW_HASH_JOIN */ t.id from (select user.id, user.col1 from user join user_extra on user_extra.col = user.col) as t", + "Instructions": { + "OperatorType": "SimpleProjection", + "Columns": [ + 0 
+ ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1,-2", + "JoinVars": { + "user_col": 2 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.id, `user`.col1, `user`.col from `user` where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ `user`.id, `user`.col1, `user`.col from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ 1 from user_extra where user_extra.col = :user_col", + "Table": "user_extra" + } + ] + } + ] + } +} +{ + "QueryType": "SELECT", + "Original": "select /*vt+ ALLOW_HASH_JOIN */ t.id from (select user.id, user.col1 from user join user_extra on user_extra.col = user.col) as t", + "Instructions": { + "OperatorType": "SimpleProjection", + "Columns": [ + 0 + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "HashJoin", + "ComparisonType": "INT16", + "JoinColumnIndexes": "-2,-3", + "Predicate": "user_extra.col = `user`.col", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.col, `user`.id, `user`.col1 from `user` where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ `user`.col, `user`.id, `user`.col1 from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.col from user_extra where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ user_extra.col from user_extra", + "Table": "user_extra" + } + ] + } + ] + } +} + +# hash join on int columns +"select /*vt+ 
ALLOW_HASH_JOIN */ u.id from user as u join user as uu on u.intcol = uu.intcol" +{ + "QueryType": "SELECT", + "Original": "select /*vt+ ALLOW_HASH_JOIN */ u.id from user as u join user as uu on u.intcol = uu.intcol", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1", + "JoinVars": { + "u_intcol": 1 + }, + "TableName": "`user`_`user`", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u.id, u.intcol from `user` as u where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ u.id, u.intcol from `user` as u", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from `user` as uu where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ 1 from `user` as uu where uu.intcol = :u_intcol", + "Table": "`user`" + } + ] + } +} +{ + "QueryType": "SELECT", + "Original": "select /*vt+ ALLOW_HASH_JOIN */ u.id from user as u join user as uu on u.intcol = uu.intcol", + "Instructions": { + "OperatorType": "Join", + "Variant": "HashJoin", + "ComparisonType": "INT16", + "JoinColumnIndexes": "-2", + "Predicate": "u.intcol = uu.intcol", + "TableName": "`user`_`user`", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u.intcol, u.id from `user` as u where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ u.intcol, u.id from `user` as u", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select uu.intcol from `user` as uu where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ uu.intcol from `user` as uu", + "Table": "`user`" + } + ] + } +} diff --git a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.txt 
b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.txt index d2050b74ba5..62ab68ab6b8 100644 --- a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.txt @@ -436,6 +436,7 @@ Gen4 error: Expression of SELECT list is not in GROUP BY clause and contains non "JoinVars": { "user_id": 0 }, + "Predicate": "`user`.id = music.id", "TableName": "`user`_music", "Inputs": [ { @@ -545,6 +546,7 @@ Gen4 error: Expression of SELECT list is not in GROUP BY clause and contains non "JoinVars": { "user_id": 0 }, + "Predicate": "`user`.id = music.id", "TableName": "`user`_music", "Inputs": [ { @@ -878,6 +880,7 @@ Gen4 plan same as above "JoinVars": { "u_a": 0 }, + "Predicate": "u.a = m.a", "TableName": "`user`_music", "Inputs": [ { diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt index 9a8fee04792..47467380163 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt @@ -338,10 +338,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select col, weight_string(col) from `user` where 1 != 1", - "OrderBy": "(0|1) ASC", - "Query": "select col, weight_string(col) from `user` order by col asc", - "ResultColumns": 1, + "FieldQuery": "select col from `user` where 1 != 1", + "OrderBy": "0 ASC", + "Query": "select col from `user` order by col asc", "Table": "`user`" } } @@ -628,10 +627,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select col, weight_string(col) from `user` where 1 != 1", - "OrderBy": "(0|1) ASC", - "Query": "select col, weight_string(col) from `user` where :__sq_has_values1 = 1 and col in ::__sq1 order by col asc", - "ResultColumns": 1, + "FieldQuery": "select col from `user` where 1 != 1", + "OrderBy": "0 ASC", + "Query": "select col from `user` where :__sq_has_values1 = 1 and col in ::__sq1 order by 
col asc", "Table": "`user`" } ] @@ -696,6 +694,7 @@ Gen4 plan same as above "JoinVars": { "user_id": 0 }, + "Predicate": "`user`.id = music.id", "TableName": "`user`_music", "Inputs": [ { @@ -789,6 +788,7 @@ Gen4 plan same as above "JoinVars": { "user_id": 0 }, + "Predicate": "`user`.id = music.id", "TableName": "`user`_music", "Inputs": [ { @@ -883,6 +883,7 @@ Gen4 plan same as above "JoinVars": { "user_id": 0 }, + "Predicate": "`user`.id = music.id", "TableName": "`user`_music", "Inputs": [ { @@ -1036,6 +1037,7 @@ Gen4 plan same as above "JoinVars": { "user_id": 0 }, + "Predicate": "`user`.id = music.id", "TableName": "`user`_music", "Inputs": [ { @@ -1230,6 +1232,7 @@ Gen4 plan same as above "JoinVars": { "u_col": 0 }, + "Predicate": "u.col = e.col", "TableName": "`user`_user_extra", "Inputs": [ { diff --git a/go/vt/vtgate/planbuilder/testdata/rails_cases.txt b/go/vt/vtgate/planbuilder/testdata/rails_cases.txt index 4dfbf463125..6414bca70bd 100644 --- a/go/vt/vtgate/planbuilder/testdata/rails_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/rails_cases.txt @@ -125,6 +125,7 @@ "JoinVars": { "order2s_id": 0 }, + "Predicate": "order2s.id = book6s_order2s.order2_id", "TableName": "customer2s, order2s_author5s, book6s_book6s_order2s_supplier5s", "Inputs": [ { @@ -145,6 +146,7 @@ "JoinVars": { "book6s_supplier5_id": 0 }, + "Predicate": "supplier5s.id = book6s.supplier5_id and book6s_order2s.order2_id = :order2s_id", "TableName": "author5s, book6s_book6s_order2s_supplier5s", "Inputs": [ { @@ -154,6 +156,7 @@ "JoinVars": { "book6s_id": 0 }, + "Predicate": "book6s_order2s.book6_id = book6s.id and book6s_order2s.order2_id = :order2s_id", "TableName": "author5s, book6s_book6s_order2s", "Inputs": [ { @@ -204,3 +207,205 @@ ] } } + +# Author5.joins(books: [{orders: :customer}, :supplier]) (with hash join) +"select /*vt+ ALLOW_HASH_JOIN */ author5s.* from author5s join book6s on book6s.author5_id = author5s.id join book6s_order2s on book6s_order2s.book6_id = book6s.id join 
order2s on order2s.id = book6s_order2s.order2_id join customer2s on customer2s.id = order2s.customer2_id join supplier5s on supplier5s.id = book6s.supplier5_id" +{ + "QueryType": "SELECT", + "Original": "select /*vt+ ALLOW_HASH_JOIN */ author5s.* from author5s join book6s on book6s.author5_id = author5s.id join book6s_order2s on book6s_order2s.book6_id = book6s.id join order2s on order2s.id = book6s_order2s.order2_id join customer2s on customer2s.id = order2s.customer2_id join supplier5s on supplier5s.id = book6s.supplier5_id", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1,-2,-3,-4", + "JoinVars": { + "book6s_supplier5_id": 4 + }, + "TableName": "author5s, book6s_book6s_order2s_order2s_customer2s_supplier5s", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1,-2,-3,-4,-5", + "JoinVars": { + "order2s_customer2_id": 5 + }, + "TableName": "author5s, book6s_book6s_order2s_order2s_customer2s", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1,-2,-3,-4,-5,1", + "JoinVars": { + "book6s_order2s_order2_id": 5 + }, + "TableName": "author5s, book6s_book6s_order2s_order2s", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1,-2,-3,-4,-5,1", + "JoinVars": { + "book6s_id": 5 + }, + "TableName": "author5s, book6s_book6s_order2s", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select author5s.id, author5s.`name`, author5s.created_at, author5s.updated_at, book6s.supplier5_id, book6s.id from author5s join book6s on book6s.author5_id = author5s.id where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ author5s.id, author5s.`name`, author5s.created_at, author5s.updated_at, book6s.supplier5_id, book6s.id from author5s join book6s on book6s.author5_id = author5s.id", + "Table": "author5s, book6s" + }, + { + 
"OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select book6s_order2s.order2_id from book6s_order2s where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ book6s_order2s.order2_id from book6s_order2s where book6s_order2s.book6_id = :book6s_id", + "Table": "book6s_order2s", + "Values": [ + ":book6s_id" + ], + "Vindex": "binary_md5" + } + ] + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select order2s.customer2_id from order2s where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ order2s.customer2_id from order2s where order2s.id = :book6s_order2s_order2_id", + "Table": "order2s" + } + ] + }, + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from customer2s where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ 1 from customer2s where customer2s.id = :order2s_customer2_id", + "Table": "customer2s", + "Values": [ + ":order2s_customer2_id" + ], + "Vindex": "binary_md5" + } + ] + }, + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from supplier5s where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ 1 from supplier5s where supplier5s.id = :book6s_supplier5_id", + "Table": "supplier5s", + "Values": [ + ":book6s_supplier5_id" + ], + "Vindex": "binary_md5" + } + ] + } +} +{ + "QueryType": "SELECT", + "Original": "select /*vt+ ALLOW_HASH_JOIN */ author5s.* from author5s join book6s on book6s.author5_id = author5s.id join book6s_order2s on book6s_order2s.book6_id = book6s.id join order2s on order2s.id = book6s_order2s.order2_id join customer2s on customer2s.id = order2s.customer2_id join supplier5s on supplier5s.id = book6s.supplier5_id", + "Instructions": { + "OperatorType": "Join", + 
"Variant": "HashJoin", + "ComparisonType": "INT64", + "JoinColumnIndexes": "2,3,4,5", + "Predicate": "order2s.id = book6s_order2s.order2_id", + "TableName": "customer2s, order2s_author5s, book6s_book6s_order2s_supplier5s", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select order2s.id from order2s, customer2s where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ order2s.id from order2s, customer2s where customer2s.id = order2s.customer2_id", + "Table": "customer2s, order2s" + }, + { + "OperatorType": "Join", + "Variant": "HashJoin", + "ComparisonType": "INT64", + "JoinColumnIndexes": "-1,-2,-3,-4,-5", + "Predicate": "supplier5s.id = book6s.supplier5_id", + "TableName": "author5s, book6s_book6s_order2s_supplier5s", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "1,-3,-4,-5,-6", + "JoinVars": { + "book6s_id": 0 + }, + "Predicate": "book6s_order2s.book6_id = book6s.id", + "TableName": "author5s, book6s_book6s_order2s", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select book6s.id, book6s.supplier5_id, author5s.id as id, author5s.`name` as `name`, author5s.created_at as created_at, author5s.updated_at as updated_at from author5s, book6s where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ book6s.id, book6s.supplier5_id, author5s.id as id, author5s.`name` as `name`, author5s.created_at as created_at, author5s.updated_at as updated_at from author5s, book6s where book6s.author5_id = author5s.id", + "Table": "author5s, book6s" + }, + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select book6s_order2s.order2_id from book6s_order2s where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ book6s_order2s.order2_id from 
book6s_order2s where book6s_order2s.book6_id = :book6s_id", + "Table": "book6s_order2s", + "Values": [ + ":book6s_id" + ], + "Vindex": "binary_md5" + } + ] + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select supplier5s.id from supplier5s where 1 != 1", + "Query": "select /*vt+ ALLOW_HASH_JOIN */ supplier5s.id from supplier5s", + "Table": "supplier5s" + } + ] + } + ] + } +} diff --git a/go/vt/vtgate/planbuilder/testdata/schema_test.json b/go/vt/vtgate/planbuilder/testdata/schema_test.json index 68b3b9eb827..0dfcf485bdd 100644 --- a/go/vt/vtgate/planbuilder/testdata/schema_test.json +++ b/go/vt/vtgate/planbuilder/testdata/schema_test.json @@ -110,6 +110,10 @@ "sequence": "seq" }, "columns": [ + { + "name": "col", + "type": "INT16" + }, { "name": "predef1" }, @@ -160,7 +164,13 @@ "auto_increment": { "column": "extra_id", "sequence": "seq" - } + }, + "columns": [ + { + "name": "col", + "type": "INT16" + } + ] }, "music": { "column_vindexes": [ diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.txt b/go/vt/vtgate/planbuilder/testdata/select_cases.txt index d2ad9dfddc6..94a003819ac 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.txt @@ -475,10 +475,10 @@ Gen4 plan same as above } # auto-resolve anonymous columns for simple route -"select col from user join user_extra on user.id = user_extra.user_id" +"select anon_col from user join user_extra on user.id = user_extra.user_id" { "QueryType": "SELECT", - "Original": "select col from user join user_extra on user.id = user_extra.user_id", + "Original": "select anon_col from user join user_extra on user.id = user_extra.user_id", "Instructions": { "OperatorType": "Route", "Variant": "SelectScatter", @@ -486,14 +486,14 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select col from `user` join user_extra on `user`.id = 
user_extra.user_id where 1 != 1", - "Query": "select col from `user` join user_extra on `user`.id = user_extra.user_id", + "FieldQuery": "select anon_col from `user` join user_extra on `user`.id = user_extra.user_id where 1 != 1", + "Query": "select anon_col from `user` join user_extra on `user`.id = user_extra.user_id", "Table": "`user`, user_extra" } } { "QueryType": "SELECT", - "Original": "select col from user join user_extra on user.id = user_extra.user_id", + "Original": "select anon_col from user join user_extra on user.id = user_extra.user_id", "Instructions": { "OperatorType": "Route", "Variant": "SelectScatter", @@ -501,8 +501,8 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select col from `user`, user_extra where 1 != 1", - "Query": "select col from `user`, user_extra where `user`.id = user_extra.user_id", + "FieldQuery": "select anon_col from `user`, user_extra where 1 != 1", + "Query": "select anon_col from `user`, user_extra where `user`.id = user_extra.user_id", "Table": "`user`, user_extra" } } @@ -2602,9 +2602,9 @@ Gen4 plan same as above "Name": "user", "Sharded": true }, - "FieldQuery": "select `user`.id, col, weight_string(col) from `user` where 1 != 1", - "OrderBy": "(1|2) ASC", - "Query": "select `user`.id, col, weight_string(col) from `user` order by col asc", + "FieldQuery": "select `user`.id, col from `user` where 1 != 1", + "OrderBy": "1 ASC", + "Query": "select `user`.id, col from `user` order by col asc", "Table": "`user`" }, { diff --git a/go/vt/vtgate/planbuilder/testdata/symtab_cases.txt b/go/vt/vtgate/planbuilder/testdata/symtab_cases.txt index f31ab8742c6..93b906f99e6 100644 --- a/go/vt/vtgate/planbuilder/testdata/symtab_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/symtab_cases.txt @@ -39,7 +39,44 @@ ] } } -Gen4 plan same as above +{ + "QueryType": "SELECT", + "Original": "select predef2, predef3 from user join unsharded on predef2 = predef3", + "Instructions": { + "OperatorType": "Join", + 
"Variant": "Join", + "JoinColumnIndexes": "-1,1", + "JoinVars": { + "predef2": 0 + }, + "Predicate": "predef2 = predef3", + "TableName": "`user`_unsharded", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select predef2 from `user` where 1 != 1", + "Query": "select predef2 from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectUnsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select predef3 from unsharded where 1 != 1", + "Query": "select predef3 from unsharded where predef3 = :predef2", + "Table": "unsharded" + } + ] + } +} # predef1 is in both user and unsharded. So, it's ambiguous. "select predef1, predef3 from user join unsharded on predef1 = predef3" diff --git a/go/vt/vtgate/planbuilder/testdata/systemtables_cases.txt b/go/vt/vtgate/planbuilder/testdata/systemtables_cases.txt index ca5fbb12c1e..3f8c8c8ee2a 100644 --- a/go/vt/vtgate/planbuilder/testdata/systemtables_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/systemtables_cases.txt @@ -942,6 +942,7 @@ Gen4 plan same as above "JoinVars": { "x_id": 0 }, + "Predicate": "x.id = `user`.id", "TableName": "information_schema.key_column_usage_`user`", "Inputs": [ { diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.txt b/go/vt/vtgate/planbuilder/testdata/tpch_cases.txt index 64d3aa702b7..badb5215aa9 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.txt @@ -114,7 +114,6 @@ Gen4 error: unsupported: in scatter query: complex aggregate expression #"with revenue0(supplier_no, total_revenue) as (select l_suppkey, sum(l_extendedprice * (1 - l_discount)) from lineitem where l_shipdate >= date('1996-01-01') and l_shipdate < date('1996-01-01') + interval '3' month group by l_suppkey )" #"syntax error at position 236" #Gen4 plan same as above - # TPC-H query 15 "select s_suppkey, 
s_name, s_address, s_phone, total_revenue from supplier, revenue0 where s_suppkey = supplier_no and total_revenue = ( select max(total_revenue) from revenue0 ) order by s_suppkey" { @@ -283,6 +282,7 @@ Gen4 error: unsupported: in scatter query: complex aggregate expression "JoinVars": { "ps_partkey": 0 }, + "Predicate": "p_partkey = ps_partkey", "TableName": "partsupp_part", "Inputs": [ { @@ -364,6 +364,7 @@ Gen4 error: unsupported: cross-shard correlated subquery "JoinVars": { "o_orderkey": 0 }, + "Predicate": "o_orderkey = l_orderkey", "TableName": "orders_customer_lineitem", "Inputs": [ { @@ -373,6 +374,7 @@ Gen4 error: unsupported: cross-shard correlated subquery "JoinVars": { "o_custkey": 0 }, + "Predicate": "c_custkey = o_custkey", "TableName": "orders_customer", "Inputs": [ { @@ -452,6 +454,7 @@ Gen4 error: unsupported: cross-shard correlated subquery "l_shipinstruct": 14, "l_shipmode": 13 }, + "Predicate": "p_partkey = l_partkey and p_brand = 'Brand#12' and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') and l_quantity \u003e= 1 and l_quantity \u003c= 1 + 10 and p_size between 1 and 5 and l_shipmode in ('AIR', 'AIR REG') and l_shipinstruct = 'DELIVER IN PERSON' or p_partkey = l_partkey and p_brand = 'Brand#23' and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') and l_quantity \u003e= 10 and l_quantity \u003c= 10 + 10 and p_size between 1 and 10 and l_shipmode in ('AIR', 'AIR REG') and l_shipinstruct = 'DELIVER IN PERSON' or p_partkey = l_partkey and p_brand = 'Brand#34' and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') and l_quantity \u003e= 20 and l_quantity \u003c= 20 + 10 and p_size between 1 and 15 and l_shipmode in ('AIR', 'AIR REG') and l_shipinstruct = 'DELIVER IN PERSON'", "TableName": "lineitem_part", "Inputs": [ { diff --git a/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.txt b/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.txt index 2eef8921494..2c2b955b07b 100644 --- 
a/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/vindex_func_cases.txt @@ -307,6 +307,7 @@ Gen4 plan same as above "JoinVars": { "user_index_id": 0 }, + "Predicate": "unsharded.id = user_index.id", "TableName": "_unsharded", "Inputs": [ { @@ -390,6 +391,7 @@ Gen4 plan same as above "JoinVars": { "user_index_id": 0 }, + "Predicate": "unsharded.id = user_index.id", "TableName": "_unsharded", "Inputs": [ { @@ -473,6 +475,7 @@ Gen4 plan same as above "JoinVars": { "user_index_id": 0 }, + "Predicate": "unsharded.id = user_index.id", "TableName": "_unsharded", "Inputs": [ { @@ -556,6 +559,7 @@ Gen4 plan same as above "JoinVars": { "ui_id": 0 }, + "Predicate": "unsharded.id = ui.id", "TableName": "_unsharded", "Inputs": [ { diff --git a/go/vt/vtgate/planbuilder/testdata/wireup_cases.txt b/go/vt/vtgate/planbuilder/testdata/wireup_cases.txt index 19186f53e91..f4a5180e3f7 100644 --- a/go/vt/vtgate/planbuilder/testdata/wireup_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/wireup_cases.txt @@ -47,6 +47,7 @@ "JoinVars": { "e_id": 0 }, + "Predicate": "u.id = e.id", "TableName": "user_extra_`user`", "Inputs": [ { @@ -128,6 +129,7 @@ "JoinVars": { "e_id": 0 }, + "Predicate": "u.id = e.id", "TableName": "user_extra_`user`", "Inputs": [ { @@ -245,6 +247,7 @@ "JoinVars": { "u1_col": 0 }, + "Predicate": "u3.col = u1.col", "TableName": "`user`_`user`", "Inputs": [ { @@ -359,6 +362,7 @@ "JoinVars": { "u2_col": 0 }, + "Predicate": "u3.col = u2.col", "TableName": "`user`_`user`", "Inputs": [ { @@ -460,6 +464,7 @@ "JoinVars": { "u3_col": 0 }, + "Predicate": "u3.col = u1.col", "TableName": "`user`_`user`_`user`", "Inputs": [ { @@ -480,6 +485,7 @@ "JoinVars": { "u1_col": 0 }, + "Predicate": "u2.col = u1.col and u1.col = :u3_col", "TableName": "`user`_`user`", "Inputs": [ { @@ -624,6 +630,7 @@ "JoinVars": { "u4_col": 0 }, + "Predicate": "u4.col = u1.col", "TableName": "`user`_`user`_`user`", "Inputs": [ { @@ -644,6 +651,7 @@ "JoinVars": { 
"u1_col": 0 }, + "Predicate": "u3.id = u1.col and u1.col = :u4_col", "TableName": "`user`_`user`", "Inputs": [ { @@ -755,6 +763,7 @@ "JoinVars": { "u1_col": 0 }, + "Predicate": "u3.id = u1.col", "TableName": "`user`_`user`_`user`", "Inputs": [ { @@ -764,6 +773,7 @@ "JoinVars": { "u1_col": 0 }, + "Predicate": "u2.id = u1.col", "TableName": "`user`_`user`", "Inputs": [ { @@ -862,6 +872,7 @@ "JoinVars": { "unsharded_id": 0 }, + "Predicate": "`weird``name`.`a``b*c` = unsharded.id", "TableName": "unsharded_`weird``name`", "Inputs": [ { @@ -943,6 +954,7 @@ "JoinVars": { "unsharded_id": 0 }, + "Predicate": "`weird``name`.`a``b*c` = unsharded.id", "TableName": "unsharded_`weird``name`", "Inputs": [ { @@ -1034,6 +1046,7 @@ "JoinVars": { "u_col": 0 }, + "Predicate": "e.id = u.col", "TableName": "`user`_user_extra", "Inputs": [ { @@ -1156,6 +1169,7 @@ "JoinVars": { "u_col": 0 }, + "Predicate": "e.id = u.col", "TableName": "`user`_user_extra", "Inputs": [ { @@ -1301,6 +1315,7 @@ "JoinVars": { "u_col": 0 }, + "Predicate": "e.id = u.col", "TableName": "`user`_user_extra", "Inputs": [ { diff --git a/go/vt/vtgate/planbuilder/union.go b/go/vt/vtgate/planbuilder/union.go index c369c801fbe..8a73cd59884 100644 --- a/go/vt/vtgate/planbuilder/union.go +++ b/go/vt/vtgate/planbuilder/union.go @@ -76,7 +76,7 @@ func (pb *primitiveBuilder) processUnion(union *sqlparser.Union, reservedVars *s } if union.Distinct { - pb.plan = newDistinct(pb.plan) + pb.plan = newDistinct(pb.plan, nil) } } pb.st.Outer = outer diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 478a72bc266..ad752071da0 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -77,7 +77,7 @@ func (a analyzer) newSemTable(statement sqlparser.SelectStatement) *SemTable { return &SemTable{ Recursive: a.binder.recursive, Direct: a.binder.direct, - exprTypes: a.typer.exprTypes, + ExprTypes: a.typer.exprTypes, Tables: a.tables.Tables, selectScope: 
a.scoper.rScope, ProjectionErr: a.projErr, diff --git a/go/vt/vtgate/semantics/dependencies.go b/go/vt/vtgate/semantics/dependencies.go index 86d89900833..8d3ade1d096 100644 --- a/go/vt/vtgate/semantics/dependencies.go +++ b/go/vt/vtgate/semantics/dependencies.go @@ -17,6 +17,7 @@ limitations under the License. package semantics import ( + querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" ) @@ -47,13 +48,16 @@ type ( var ambigousErr = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "ambiguous") func createCertain(direct TableSet, recursive TableSet, qt *Type) *certain { - return &certain{ + c := &certain{ dependency: dependency{ direct: direct, recursive: recursive, - typ: qt, }, } + if qt != nil && qt.Type != querypb.Type_NULL_TYPE { + c.typ = qt + } + return c } func createUncertain(direct TableSet, recursive TableSet) *uncertain { diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 07ea3841eaa..90995e0815d 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -82,7 +82,7 @@ type ( // It does not recurse inside derived tables and the like to find the original dependencies Direct ExprDependencies - exprTypes map[sqlparser.Expr]Type + ExprTypes map[sqlparser.Expr]Type selectScope map[*sqlparser.Select]*scope Comments sqlparser.Comments SubqueryMap map[*sqlparser.Select][]*sqlparser.ExtractedSubquery @@ -201,7 +201,7 @@ func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.Sel // TypeFor returns the type of expressions in the query func (st *SemTable) TypeFor(e sqlparser.Expr) *querypb.Type { - typ, found := st.exprTypes[e] + typ, found := st.ExprTypes[e] if found { return &typ.Type } @@ -210,7 +210,7 @@ func (st *SemTable) TypeFor(e sqlparser.Expr) *querypb.Type { // CollationFor returns the collation name of expressions in the query func (st *SemTable) CollationFor(e 
sqlparser.Expr) collations.ID { - typ, found := st.exprTypes[e] + typ, found := st.ExprTypes[e] if found { return typ.Collation } @@ -304,3 +304,12 @@ func (st *SemTable) GetSubqueryNeedingRewrite() []*sqlparser.ExtractedSubquery { } return res } + +// CopyExprInfo looks up src in the ExprTypes map and, if a key is found, assigns +// the corresponding Type value of src to dest. +func (st *SemTable) CopyExprInfo(src, dest sqlparser.Expr) { + srcType, found := st.ExprTypes[src] + if found { + st.ExprTypes[dest] = srcType + } +}