Skip to content

Commit

Permalink
sql: support #typeHints greater than #placeholders for prepare stmt
Browse files Browse the repository at this point in the history
Previous, we only support pgwire prepare stmt with the number of typehints
equal or smaller than the number of the placeholders in the query. E.g. the
following usage are not supported:

```
Parse {"Name": "s2", "Query": "select $1", "ParameterOIDs":[1043, 1043, 1043]}
```
Where there are 1 placeholder in the query, but 3 type hints.

This commit is to allow mismatching #typeHints and #placeholders. The former
can be larger than the latter now.

Release justification: Low risk, high benefit changes to existing functionality

Release note: For pgwire-level prepare statements, support the case where the
number of the type hints is greater than the number of placeholders in the
given query.
  • Loading branch information
ZhouXing19 committed Sep 7, 2022
1 parent 819577e commit 14dc93d
Show file tree
Hide file tree
Showing 14 changed files with 61 additions and 62 deletions.
5 changes: 0 additions & 5 deletions pkg/acceptance/testdata/node/base-test.js
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,6 @@ describe('select', () => {

describe('error cases', () => {
const cases = [{
name: 'not enough params',
query: { text: 'SELECT 3', values: ['foo'] },
msg: "expected 0 arguments, got 1",
code: '08P01',
}, {
name: 'invalid utf8',
query: { text: 'SELECT $1::STRING', values: [new Buffer([167])] },
msg: "invalid UTF-8 sequence",
Expand Down
18 changes: 2 additions & 16 deletions pkg/sql/conn_executor_prepare.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,7 @@ func (ex *connExecutor) prepare(
for i := range placeholderHints {
if placeholderHints[i] == nil {
if i >= len(rawTypeHints) {
return pgwirebase.NewProtocolViolationErrorf(
"expected %d arguments, got %d",
len(placeholderHints),
len(rawTypeHints),
)
break
}
if types.IsOIDUserDefinedType(rawTypeHints[i]) {
var err error
Expand Down Expand Up @@ -272,12 +268,8 @@ func (ex *connExecutor) populatePrepared(
}
}
stmt := &p.stmt
var fromSQL bool
if origin == PreparedStatementOriginSQL {
fromSQL = true
}

if err := p.semaCtx.Placeholders.Init(stmt.NumPlaceholders, placeholderHints, fromSQL); err != nil {
if err := p.semaCtx.Placeholders.Init(stmt.NumPlaceholders, placeholderHints); err != nil {
return 0, err
}
p.extendedEvalCtx.PrepareOnly = true
Expand Down Expand Up @@ -391,12 +383,6 @@ func (ex *connExecutor) execBind(
}
}

if len(bindCmd.Args) != int(numQArgs) {
return retErr(
pgwirebase.NewProtocolViolationErrorf(
"expected %d arguments, got %d", numQArgs, len(bindCmd.Args)))
}

resolve := func(ctx context.Context, txn *kv.Txn) (err error) {
ex.statsCollector.Reset(ex.applicationStats, ex.phaseTimes)
p := &ex.planner
Expand Down
7 changes: 6 additions & 1 deletion pkg/sql/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,12 @@ func (ie *InternalExecutor) execInternal(
return nil, err
}

typeHints := make(tree.PlaceholderTypes, len(datums))
// We take max(len(s.Types), stmt.NumPlaceHolders) as the length of types.
numParams := len(datums)
if parsed.NumPlaceholders > numParams {
numParams = parsed.NumPlaceholders
}
typeHints := make(tree.PlaceholderTypes, numParams)
for i, d := range datums {
// Arg numbers start from 1.
typeHints[tree.PlaceholderIdx(i)] = d.ResolvedType()
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/bench/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ func newHarness(tb testing.TB, query benchQuery, schemas []string) *harness {
}
}

if err := h.semaCtx.Placeholders.Init(len(query.args), nil /* typeHints */, false /* fromSQL */); err != nil {
if err := h.semaCtx.Placeholders.Init(len(query.args), nil /* typeHints */); err != nil {
tb.Fatal(err)
}
// Run optbuilder to build the memo for Prepare. Even if we will not be using
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/testutils/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func BuildQuery(

ctx := context.Background()
semaCtx := tree.MakeSemaContext()
if err := semaCtx.Placeholders.Init(stmt.NumPlaceholders, nil /* typeHints */, false /* fromSQL */); err != nil {
if err := semaCtx.Placeholders.Init(stmt.NumPlaceholders, nil /* typeHints */); err != nil {
t.Fatal(err)
}
semaCtx.Annotations = tree.MakeAnnotations(stmt.NumAnnotations)
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/testutils/opttester/opt_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -2196,7 +2196,7 @@ func (ot *OptTester) buildExpr(factory *norm.Factory) error {
if err != nil {
return err
}
if err := ot.semaCtx.Placeholders.Init(stmt.NumPlaceholders, nil /* typeHints */, false /* fromSQL */); err != nil {
if err := ot.semaCtx.Placeholders.Init(stmt.NumPlaceholders, nil /* typeHints */); err != nil {
return err
}
ot.semaCtx.Annotations = tree.MakeAnnotations(stmt.NumAnnotations)
Expand Down
12 changes: 5 additions & 7 deletions pkg/sql/pgwire/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -937,19 +937,17 @@ func (c *conn) handleParse(
}
// len(stmts) == 0 results in a nil (empty) statement.

if len(inTypeHints) > stmt.NumPlaceholders {
err := pgwirebase.NewProtocolViolationErrorf(
"received too many type hints: %d vs %d placeholders in query",
len(inTypeHints), stmt.NumPlaceholders,
)
return c.stmtBuf.Push(ctx, sql.SendError{Err: err})
// We take max(len(s.Types), stmt.NumPlaceHolders) as the length of types.
numParams := len(inTypeHints)
if stmt.NumPlaceholders > numParams {
numParams = stmt.NumPlaceholders
}

var sqlTypeHints tree.PlaceholderTypes
if len(inTypeHints) > 0 {
// Prepare the mapping of SQL placeholder names to types. Pre-populate it with
// the type hints received from the client, if any.
sqlTypeHints = make(tree.PlaceholderTypes, stmt.NumPlaceholders)
sqlTypeHints = make(tree.PlaceholderTypes, numParams)
for i, t := range inTypeHints {
if t == 0 {
continue
Expand Down
32 changes: 32 additions & 0 deletions pkg/sql/pgwire/testdata/pgtest/prepare
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
send
Parse {"Name": "s2", "Query": "select $1", "ParameterOIDs":[1043, 1043, 1043]}
Bind {"DestinationPortal": "p2", "PreparedStatement": "s2", "ParameterFormatCodes": [0], "Parameters": [{"text":"whitebear"}, {"text":"blackbear"}, {"text":"brownbear"}]}
Execute {"Portal": "p2"}
Sync
----

until
ReadyForQuery
----
{"Type":"ParseComplete"}
{"Type":"BindComplete"}
{"Type":"DataRow","Values":[{"text":"whitebear"}]}
{"Type":"CommandComplete","CommandTag":"SELECT 1"}
{"Type":"ReadyForQuery","TxStatus":"I"}


send
Parse {"Name": "s3", "Query": "select $1, $2::int", "ParameterOIDs":[1043]}
Bind {"DestinationPortal": "p3", "PreparedStatement": "s3", "ParameterFormatCodes": [0], "Parameters": [{"text":"winnie"}, {"text":"123"}]}
Execute {"Portal": "p3"}
Sync
----

until
ReadyForQuery
----
{"Type":"ParseComplete"}
{"Type":"BindComplete"}
{"Type":"DataRow","Values":[{"text":"winnie"},{"text":"123"}]}
{"Type":"CommandComplete","CommandTag":"SELECT 1"}
{"Type":"ReadyForQuery","TxStatus":"I"}
6 changes: 2 additions & 4 deletions pkg/sql/plan_opt.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,8 @@ func (p *planner) prepareUsingOptimizer(ctx context.Context) (planFlags, error)
}
}

if p.semaCtx.Placeholders.PlaceholderTypesInfo.FromSQLPrepare {
// Fill blank placeholder types with the type hints.
p.semaCtx.Placeholders.MaybeExtendTypes()
}
// Fill blank placeholder types with the type hints.
p.semaCtx.Placeholders.MaybeExtendTypes()

// Verify that all placeholder types have been set.
if err := p.semaCtx.Placeholders.Types.AssertAllSet(); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/schemachange/alter_column_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func ClassifyConversion(

// See if there's existing cast logic. If so, return general.
semaCtx := tree.MakeSemaContext()
if err := semaCtx.Placeholders.Init(1 /* numPlaceholders */, nil /* typeHints */, false /* fromSQL */); err != nil {
if err := semaCtx.Placeholders.Init(1 /* numPlaceholders */, nil /* typeHints */); err != nil {
return ColumnConversionImpossible, err
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/sem/tree/overload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func TestTypeCheckOverloadedExprs(t *testing.T) {
for i, d := range testData {
t.Run(fmt.Sprintf("%v/%v", d.exprs, d.overloads), func(t *testing.T) {
semaCtx := MakeSemaContext()
if err := semaCtx.Placeholders.Init(2 /* numPlaceholders */, nil /* typeHints */, false /* fromSQL */); err != nil {
if err := semaCtx.Placeholders.Init(2 /* numPlaceholders */, nil /* typeHints */); err != nil {
t.Fatal(err)
}
desired := types.Any
Expand Down
23 changes: 4 additions & 19 deletions pkg/sql/sem/tree/placeholders.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/errors"
)

// PlaceholderIdx is the 0-based index of a placeholder. Placeholder "$1"
Expand Down Expand Up @@ -96,10 +95,6 @@ type PlaceholderTypesInfo struct {
// Types contains the final types set for each placeholder after type
// checking.
Types PlaceholderTypes

// FromSQLPrepare is true when the placeholder is in a statement from a
// PREPARE SQL stmt (rather than a pgwire prepare stmt).
FromSQLPrepare bool
}

// Type returns the known type of a placeholder. If there is no known type yet
Expand Down Expand Up @@ -154,25 +149,15 @@ type PlaceholderInfo struct {

// Init initializes a PlaceholderInfo structure appropriate for the given number
// of placeholders, and with the given (optional) type hints.
func (p *PlaceholderInfo) Init(
numPlaceholders int, typeHints PlaceholderTypes, fromSQL bool,
) error {
if fromSQL {
if typeHints == nil { // This should not happen, but...
return errors.AssertionFailedf("There should be at least one type hint for a sql-level PREPARE statement")
}
p.Types = make(PlaceholderTypes, len(typeHints))
} else {
p.Types = make(PlaceholderTypes, numPlaceholders)
}

func (p *PlaceholderInfo) Init(numPlaceholders int, typeHints PlaceholderTypes) error {
if typeHints == nil {
p.TypeHints = make(PlaceholderTypes, numPlaceholders)
p.Types = make(PlaceholderTypes, numPlaceholders)
} else {
p.Types = make(PlaceholderTypes, len(typeHints))
p.TypeHints = typeHints
}
p.Values = nil
p.FromSQLPrepare = fromSQL
return nil
}

Expand All @@ -183,7 +168,7 @@ func (p *PlaceholderInfo) Assign(src *PlaceholderInfo, numPlaceholders int) erro
*p = *src
return nil
}
return p.Init(numPlaceholders, nil /* typeHints */, false /* fromSQL */)
return p.Init(numPlaceholders, nil /* typeHints */)
}

// MaybeExtendTypes is to fill the nil types with the type hints, if exists.
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/sem/tree/type_check_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func BenchmarkTypeCheck(b *testing.B) {
b.Fatalf("%s: %v", expr, err)
}
ctx := tree.MakeSemaContext()
if err := ctx.Placeholders.Init(1 /* numPlaceholders */, nil /* typeHints */, false /* fromSQL */); err != nil {
if err := ctx.Placeholders.Init(1 /* numPlaceholders */, nil /* typeHints */); err != nil {
b.Fatal(err)
}
for i := 0; i < b.N; i++ {
Expand Down Expand Up @@ -183,7 +183,7 @@ func attemptTypeCheckSameTypedExprs(t *testing.T, idx int, test sameTypedExprsTe
ctx := context.Background()
forEachPerm(test.exprs, 0, func(exprs []copyableExpr) {
semaCtx := tree.MakeSemaContext()
if err := semaCtx.Placeholders.Init(len(test.ptypes), clonePlaceholderTypes(test.ptypes), false /* fromSQL */); err != nil {
if err := semaCtx.Placeholders.Init(len(test.ptypes), clonePlaceholderTypes(test.ptypes)); err != nil {
t.Fatal(err)
}
desired := types.Any
Expand Down Expand Up @@ -337,7 +337,7 @@ func TestTypeCheckSameTypedExprsError(t *testing.T) {
for i, d := range testData {
t.Run(fmt.Sprintf("test_%d", i), func(t *testing.T) {
semaCtx := tree.MakeSemaContext()
if err := semaCtx.Placeholders.Init(len(d.ptypes), d.ptypes, false /* fromSQL */); err != nil {
if err := semaCtx.Placeholders.Init(len(d.ptypes), d.ptypes); err != nil {
t.Error(err)
}
desired := types.Any
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/sem/tree/type_check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func TestTypeCheck(t *testing.T) {
t.Fatalf("%s: %v", d.expr, err)
}
semaCtx := tree.MakeSemaContext()
if err := semaCtx.Placeholders.Init(1 /* numPlaceholders */, nil /* typeHints */, false /* fromSQL */); err != nil {
if err := semaCtx.Placeholders.Init(1 /* numPlaceholders */, nil /* typeHints */); err != nil {
t.Fatal(err)
}
semaCtx.TypeResolver = mapResolver
Expand Down Expand Up @@ -398,7 +398,7 @@ func TestTypeCheckVolatility(t *testing.T) {

ctx := context.Background()
semaCtx := tree.MakeSemaContext()
if err := semaCtx.Placeholders.Init(len(placeholderTypes), placeholderTypes, false /* fromSQL */); err != nil {
if err := semaCtx.Placeholders.Init(len(placeholderTypes), placeholderTypes); err != nil {
t.Fatal(err)
}

Expand Down

0 comments on commit 14dc93d

Please sign in to comment.