Skip to content

Commit

Permalink
altered concatenate primitive to accept more than 2 primitives
Browse files Browse the repository at this point in the history
Signed-off-by: GuptaManan100 <[email protected]>
  • Loading branch information
GuptaManan100 committed Jan 8, 2021
1 parent e51e3d5 commit ec0aa0c
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 50 deletions.
98 changes: 62 additions & 36 deletions go/vt/vtgate/engine/concatenate.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,20 @@ func (c *Concatenate) RouteType() string {

// GetKeyspaceName specifies the Keyspace that this primitive routes to
func (c *Concatenate) GetKeyspaceName() string {
return formatTwoOptionsNicely(c.Sources[0].GetKeyspaceName(), c.Sources[1].GetKeyspaceName())
res := c.Sources[0].GetKeyspaceName()
for i := 1; i < len(c.Sources); i++ {
res = formatTwoOptionsNicely(res, c.Sources[i].GetKeyspaceName())
}
return res
}

// GetTableName specifies the table that this primitive routes to.
func (c *Concatenate) GetTableName() string {
return formatTwoOptionsNicely(c.Sources[0].GetTableName(), c.Sources[1].GetTableName())
res := c.Sources[0].GetTableName()
for i := 1; i < len(c.Sources); i++ {
res = formatTwoOptionsNicely(res, c.Sources[i].GetTableName())
}
return res
}

func formatTwoOptionsNicely(a, b string) string {
Expand All @@ -59,47 +67,58 @@ func formatTwoOptionsNicely(a, b string) string {

// Execute performs a non-streaming exec.
func (c *Concatenate) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
lhs, rhs, err := c.execSources(vcursor, bindVars, wantfields)
res, err := c.execSources(vcursor, bindVars, wantfields)
if err != nil {
return nil, err
}

fields, err := c.getFields(lhs.Fields, rhs.Fields)
fields, err := c.getFields(res)
if err != nil {
return nil, err
}

if len(lhs.Rows) > 0 &&
len(rhs.Rows) > 0 &&
len(lhs.Rows[0]) != len(rhs.Rows[0]) {
return nil, mysql.NewSQLError(mysql.ERWrongNumberOfColumnsInSelect, "21000", "The used SELECT statements have a different number of columns")
var rowsAffected uint64 = 0
var rows [][]sqltypes.Value

for _, r := range res {
rowsAffected += r.RowsAffected

if len(rows) > 0 &&
len(r.Rows) > 0 &&
len(rows[0]) != len(r.Rows[0]) {
return nil, mysql.NewSQLError(mysql.ERWrongNumberOfColumnsInSelect, "21000", "The used SELECT statements have a different number of columns")
}

rows = append(rows, r.Rows...)
}

return &sqltypes.Result{
Fields: fields,
RowsAffected: lhs.RowsAffected + rhs.RowsAffected,
Rows: append(lhs.Rows, rhs.Rows...),
RowsAffected: rowsAffected,
Rows: rows,
}, nil
}

func (c *Concatenate) getFields(a, b []*querypb.Field) ([]*querypb.Field, error) {
switch {
case a != nil && b != nil:
err := compareFields(a, b)
func (c *Concatenate) getFields(res []*sqltypes.Result) ([]*querypb.Field, error) {
var resFields []*querypb.Field
for _, r := range res {
fields := r.Fields
if fields == nil {
continue
}
if resFields == nil {
resFields = fields
continue
}
err := compareFields(fields, resFields)
if err != nil {
return nil, err
}
return a, nil
case a != nil:
return a, nil
case b != nil:
return b, nil
}

return nil, nil
return resFields, nil
}
func (c *Concatenate) execSources(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, *sqltypes.Result, error) {
results := make([]*sqltypes.Result, 2)
func (c *Concatenate) execSources(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) {
results := make([]*sqltypes.Result, len(c.Sources))
g, restoreCtx := vcursor.ErrorGroupCancellableContext()
defer restoreCtx()
for i, source := range c.Sources {
Expand All @@ -115,9 +134,9 @@ func (c *Concatenate) execSources(vcursor VCursor, bindVars map[string]*querypb.
}

if err := g.Wait(); err != nil {
return nil, nil, vterrors.Wrap(err, "Concatenate.Execute")
return nil, vterrors.Wrap(err, "Concatenate.Execute")
}
return results[0], results[1], nil
return results, nil
}

// StreamExecute performs a streaming exec.
Expand Down Expand Up @@ -177,25 +196,32 @@ func (c *Concatenate) StreamExecute(vcursor VCursor, bindVars map[string]*queryp

// GetFields fetches the field info.
func (c *Concatenate) GetFields(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
lhs, err := c.Sources[0].GetFields(vcursor, bindVars)
if err != nil {
return nil, err
}
rhs, err := c.Sources[1].GetFields(vcursor, bindVars)
if err != nil {
return nil, err
}
err = compareFields(lhs.Fields, rhs.Fields)
res, err := c.Sources[0].GetFields(vcursor, bindVars)
if err != nil {
return nil, err
}

return lhs, nil
for i := 1; i < len(c.Sources); i++ {
result, err := c.Sources[i].GetFields(vcursor, bindVars)
if err != nil {
return nil, err
}
err = compareFields(result.Fields, res.Fields)
if err != nil {
return nil, err
}
}
return res, nil
}

//NeedsTransaction returns whether a transaction is needed for this primitive
func (c *Concatenate) NeedsTransaction() bool {
return c.Sources[0].NeedsTransaction() || c.Sources[1].NeedsTransaction()
for _, source := range c.Sources {
if source.NeedsTransaction() {
return true
}
}
return false
}

// Inputs returns the input primitives for this
Expand Down
20 changes: 14 additions & 6 deletions go/vt/vtgate/engine/concatenate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,30 @@ func TestConcatenate_NoErrors(t *testing.T) {
inputs: []*sqltypes.Result{
r("id1|col11|col12", "int64|varbinary|varbinary"),
r("id2|col21|col22", "int64|varbinary|varbinary"),
r("id3|col31|col32", "int64|varbinary|varbinary"),
},
expectedResult: r("id1|col11|col12", "int64|varbinary|varbinary"),
}, {
testName: "2 non empty result",
inputs: []*sqltypes.Result{
r("myid|mycol1|mycol2", "int64|varchar|varbinary", "11|m1|n1", "22|m2|n2"),
r("id|col1|col2", "int64|varchar|varbinary", "1|a1|b1", "2|a2|b2"),
r("id2|col2|col3", "int64|varchar|varbinary", "3|a3|b3"),
r("id2|col2|col3", "int64|varchar|varbinary", "4|a4|b4"),
},
expectedResult: r("myid|mycol1|mycol2", "int64|varchar|varbinary", "11|m1|n1", "22|m2|n2", "1|a1|b1", "2|a2|b2"),
expectedResult: r("myid|mycol1|mycol2", "int64|varchar|varbinary", "11|m1|n1", "22|m2|n2", "1|a1|b1", "2|a2|b2", "3|a3|b3", "4|a4|b4"),
}, {
testName: "mismatch field type",
inputs: []*sqltypes.Result{
r("id|col1|col2", "int64|varbinary|varbinary", "1|a1|b1", "2|a2|b2"),
r("id|col1|col2", "int64|varbinary|varbinary", "1|a1|b1", "2|a2|b2"),
r("id|col3|col4", "int64|varchar|varbinary", "1|a1|b1", "2|a2|b2"),
},
expectedError: "column field type does not match for name",
}, {
testName: "input source has different column count",
inputs: []*sqltypes.Result{
r("id|col1|col2", "int64|varchar|varchar", "1|a1|b1", "2|a2|b2"),
r("id|col1|col2", "int64|varchar|varchar", "1|a1|b1", "2|a2|b2"),
r("id|col3|col4|col5", "int64|varchar|varchar|int32", "1|a1|b1|5", "2|a2|b2|6"),
},
Expand All @@ -78,12 +83,13 @@ func TestConcatenate_NoErrors(t *testing.T) {
}}

for _, tc := range testCases {
require.Equal(t, 2, len(tc.inputs))
var sources []Primitive
for _, input := range tc.inputs {
// input is added twice, since the first one is used by execute and the next by stream execute
sources = append(sources, &fakePrimitive{results: []*sqltypes.Result{input, input}})
}
concatenate := &Concatenate{
Sources: []Primitive{
&fakePrimitive{results: []*sqltypes.Result{tc.inputs[0], tc.inputs[0]}},
&fakePrimitive{results: []*sqltypes.Result{tc.inputs[1], tc.inputs[1]}},
},
Sources: sources,
}

t.Run(tc.testName+"-Execute", func(t *testing.T) {
Expand Down Expand Up @@ -117,6 +123,7 @@ func TestConcatenate_WithErrors(t *testing.T) {
fake := r("id|col1|col2", "int64|varchar|varbinary", "1|a1|b1", "2|a2|b2")
concatenate := &Concatenate{
Sources: []Primitive{
&fakePrimitive{results: []*sqltypes.Result{fake, fake}},
&fakePrimitive{results: []*sqltypes.Result{nil, nil}, sendErr: errors.New(strFailed)},
&fakePrimitive{results: []*sqltypes.Result{fake, fake}},
},
Expand All @@ -132,6 +139,7 @@ func TestConcatenate_WithErrors(t *testing.T) {
Sources: []Primitive{
&fakePrimitive{results: []*sqltypes.Result{fake, fake}},
&fakePrimitive{results: []*sqltypes.Result{nil, nil}, sendErr: errors.New(strFailed)},
&fakePrimitive{results: []*sqltypes.Result{fake, fake}},
},
}
_, err = concatenate.Execute(&noopVCursor{ctx: ctx}, nil, true)
Expand Down
24 changes: 16 additions & 8 deletions go/vt/vtgate/planbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,18 @@ func buildFlushTables(stmt *sqlparser.Flush, vschema ContextVSchema) (engine.Pri
flushStatements[sendDest{keyspaceTab, destinationTab}] = flush
}

if len(flushStatements) == 1 {
for sendDest, flush := range flushStatements {
return &engine.Send{
Keyspace: sendDest.ks,
TargetDestination: sendDest.dest,
Query: sqlparser.String(flush),
IsDML: false,
SingleShardOnly: false,
}, nil
}
}

keys := make([]sendDest, len(flushStatements))

// Collect keys of the map
Expand All @@ -297,7 +309,9 @@ func buildFlushTables(stmt *sqlparser.Flush, vschema ContextVSchema) (engine.Pri
return keys[i].ks.Name < keys[j].ks.Name
})

var finalPlan engine.Primitive
finalPlan := &engine.Concatenate{
Sources: nil,
}
for _, sendDest := range keys {
plan := &engine.Send{
Keyspace: sendDest.ks,
Expand All @@ -306,13 +320,7 @@ func buildFlushTables(stmt *sqlparser.Flush, vschema ContextVSchema) (engine.Pri
IsDML: false,
SingleShardOnly: false,
}
if finalPlan == nil {
finalPlan = plan
} else {
finalPlan = &engine.Concatenate{
Sources: []engine.Primitive{finalPlan, plan},
}
}
finalPlan.Sources = append(finalPlan.Sources, plan)
}

return finalPlan, nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,48 @@
"SingleShardOnly": false
}
}

# Flush statement with 3 keyspaces
"flush local tables user, unsharded_a, user_extra, unsharded_tab with read lock"
{
"QueryType": "FLUSH",
"Original": "flush local tables user, unsharded_a, user_extra, unsharded_tab with read lock",
"Instructions": {
"OperatorType": "Concatenate",
"Inputs": [
{
"OperatorType": "Send",
"Keyspace": {
"Name": "main",
"Sharded": false
},
"TargetDestination": "AllShards()",
"IsDML": false,
"Query": "flush local tables unsharded_a with read lock",
"SingleShardOnly": false
},
{
"OperatorType": "Send",
"Keyspace": {
"Name": "main_2",
"Sharded": false
},
"TargetDestination": "AllShards()",
"IsDML": false,
"Query": "flush local tables unsharded_tab with read lock",
"SingleShardOnly": false
},
{
"OperatorType": "Send",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"TargetDestination": "AllShards()",
"IsDML": false,
"Query": "flush local tables user, user_extra with read lock",
"SingleShardOnly": false
}
]
}
}
14 changes: 14 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/schema_test.json
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,20 @@
"type": "sequence"
}
}
},
"main_2": {
"tables": {
"unsharded_tab": {
"columns": [
{
"name": "predef1"
},
{
"name": "predef3"
}
]
}
}
}
}
}

0 comments on commit ec0aa0c

Please sign in to comment.