Skip to content

Commit

Permalink
Merge #78108
Browse files Browse the repository at this point in the history
78108: pgwirebase: emit better parse errors r=jordanlewis a=jordanlewis

Previously, pgwire emitted subtly different error messages for parse
errors (e.g. when trying to interpret a string argument to something as
a bool or float etc) than the non-pgwire path.

This is corrected by using the pre-existing parse error function from
the tree package that the other path uses.

Release note (sql change): the error messages returned when encountering
a malformed or unparseable argument for a query in the wire protocol are
now more consistent.

Co-authored-by: Jordan Lewis <[email protected]>
  • Loading branch information
craig[bot] and jordanlewis committed May 7, 2022
2 parents a86e6c9 + ec7d053 commit 82ab0cb
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 47 deletions.
26 changes: 13 additions & 13 deletions pkg/sql/pgwire/pgwire_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,9 +451,9 @@ func TestPGPreparedQuery(t *testing.T) {
{"SELECT $1 > 0", []preparedQueryTest{
baseTest.SetArgs(1).Results(true),
baseTest.SetArgs("1").Results(true),
baseTest.SetArgs(1.1).Error(`pq: error in argument for $1: strconv.ParseInt: parsing "1.1": invalid syntax`).Results(true),
baseTest.SetArgs("1.0").Error(`pq: error in argument for $1: strconv.ParseInt: parsing "1.0": invalid syntax`),
baseTest.SetArgs(true).Error(`pq: error in argument for $1: strconv.ParseInt: parsing "true": invalid syntax`),
baseTest.SetArgs(1.1).Error(`pq: error in argument for $1: could not parse "1.1" as type int: strconv.ParseInt: parsing "1.1": invalid syntax`).Results(true),
baseTest.SetArgs("1.0").Error(`pq: error in argument for $1: could not parse "1.0" as type int: strconv.ParseInt: parsing "1.0": invalid syntax`),
baseTest.SetArgs(true).Error(`pq: error in argument for $1: could not parse "true" as type int: strconv.ParseInt: parsing "true": invalid syntax`),
}},
{"SELECT ($1) > 0", []preparedQueryTest{
baseTest.SetArgs(1).Results(true),
Expand All @@ -467,7 +467,7 @@ func TestPGPreparedQuery(t *testing.T) {
baseTest.SetArgs(true).Results(true),
baseTest.SetArgs(false).Results(false),
baseTest.SetArgs(1).Results(true),
baseTest.SetArgs("").Error(`pq: error in argument for $1: strconv.ParseBool: parsing "": invalid syntax`),
baseTest.SetArgs("").Error(`pq: error in argument for $1: could not parse "" as type bool: strconv.ParseBool: parsing "": invalid syntax`),
// Make sure we can run another after a failure.
baseTest.SetArgs(true).Results(true),
}},
Expand All @@ -476,9 +476,9 @@ func TestPGPreparedQuery(t *testing.T) {
baseTest.SetArgs("true").Results(true),
baseTest.SetArgs("false").Results(false),
baseTest.SetArgs("1").Results(true),
baseTest.SetArgs(2).Error(`pq: error in argument for $1: strconv.ParseBool: parsing "2": invalid syntax`),
baseTest.SetArgs(3.1).Error(`pq: error in argument for $1: strconv.ParseBool: parsing "3.1": invalid syntax`),
baseTest.SetArgs("").Error(`pq: error in argument for $1: strconv.ParseBool: parsing "": invalid syntax`),
baseTest.SetArgs(2).Error(`pq: error in argument for $1: could not parse "2" as type bool: strconv.ParseBool: parsing "2": invalid syntax`),
baseTest.SetArgs(3.1).Error(`pq: error in argument for $1: could not parse "3.1" as type bool: strconv.ParseBool: parsing "3.1": invalid syntax`),
baseTest.SetArgs("").Error(`pq: error in argument for $1: could not parse "" as type bool: strconv.ParseBool: parsing "": invalid syntax`),
}},
{"SELECT CASE 40+2 WHEN 42 THEN 51 ELSE $1::INT END", []preparedQueryTest{
baseTest.Error(
Expand All @@ -492,14 +492,14 @@ func TestPGPreparedQuery(t *testing.T) {
baseTest.SetArgs("2", 1).Results(true),
baseTest.SetArgs(1, "2").Results(false),
baseTest.SetArgs("2", "1.0").Results(true),
baseTest.SetArgs("2.0", "1").Error(`pq: error in argument for $1: strconv.ParseInt: parsing "2.0": invalid syntax`),
baseTest.SetArgs(2.1, 1).Error(`pq: error in argument for $1: strconv.ParseInt: parsing "2.1": invalid syntax`),
baseTest.SetArgs("2.0", "1").Error(`pq: error in argument for $1: could not parse "2.0" as type int: strconv.ParseInt: parsing "2.0": invalid syntax`),
baseTest.SetArgs(2.1, 1).Error(`pq: error in argument for $1: could not parse "2.1" as type int: strconv.ParseInt: parsing "2.1": invalid syntax`),
}},
{"SELECT greatest($1, 0, $2), $2", []preparedQueryTest{
baseTest.SetArgs(1, -1).Results(1, -1),
baseTest.SetArgs(-1, 10).Results(10, 10),
baseTest.SetArgs("-2", "-1").Results(0, -1),
baseTest.SetArgs(1, 2.1).Error(`pq: error in argument for $2: strconv.ParseInt: parsing "2.1": invalid syntax`),
baseTest.SetArgs(1, 2.1).Error(`pq: error in argument for $2: could not parse "2.1" as type int: strconv.ParseInt: parsing "2.1": invalid syntax`),
}},
{"SELECT $1::int, $1::float", []preparedQueryTest{
baseTest.SetArgs(1).Results(1, 1.0),
Expand All @@ -508,7 +508,7 @@ func TestPGPreparedQuery(t *testing.T) {
{"SELECT 3 + $1, $1 + $2", []preparedQueryTest{
baseTest.SetArgs("1", "2").Results(4, 3),
baseTest.SetArgs(3, "4").Results(6, 7),
baseTest.SetArgs(0, "a").Error(`pq: error in argument for $2: strconv.ParseInt: parsing "a": invalid syntax`),
baseTest.SetArgs(0, "a").Error(`pq: error in argument for $2: could not parse "a" as type int: strconv.ParseInt: parsing "a": invalid syntax`),
}},
// Check for name resolution.
{"SELECT count(*)", []preparedQueryTest{
Expand All @@ -522,7 +522,7 @@ func TestPGPreparedQuery(t *testing.T) {
{"SELECT CASE 1 WHEN $1 THEN $2 ELSE 2 END", []preparedQueryTest{
baseTest.SetArgs(1, 3).Results(3),
baseTest.SetArgs(2, 3).Results(2),
baseTest.SetArgs(true, 0).Error(`pq: error in argument for $1: strconv.ParseInt: parsing "true": invalid syntax`),
baseTest.SetArgs(true, 0).Error(`pq: error in argument for $1: could not parse "true" as type int: strconv.ParseInt: parsing "true": invalid syntax`),
}},
{"SELECT $1[2] LIKE 'b'", []preparedQueryTest{
baseTest.SetArgs(pq.Array([]string{"a", "b", "c"})).Results(true),
Expand Down Expand Up @@ -1091,7 +1091,7 @@ func TestPGPreparedExec(t *testing.T) {
"INSERT INTO d.public.t VALUES ($1, $2, $3)",
[]preparedExecTest{
baseTest.SetArgs(1, "one", 2).RowsAffected(1),
baseTest.SetArgs("two", 2, 2).Error(`pq: error in argument for $1: strconv.ParseInt: parsing "two": invalid syntax`),
baseTest.SetArgs("two", 2, 2).Error(`pq: error in argument for $1: could not parse "two" as type int: strconv.ParseInt: parsing "two": invalid syntax`),
},
},
{
Expand Down
63 changes: 31 additions & 32 deletions pkg/sql/pgwire/pgwirebase/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,33 +312,35 @@ func validateArrayDimensions(nDimensions int, nElements int) error {
// DecodeDatum decodes bytes with specified type and format code into
// a datum. If res is nil, then user defined types are not attempted
// to be resolved.
func DecodeDatum(evalCtx *eval.Context, t *types.T, code FormatCode, b []byte) (tree.Datum, error) {
id := t.Oid()
func DecodeDatum(
evalCtx *eval.Context, typ *types.T, code FormatCode, b []byte,
) (tree.Datum, error) {
id := typ.Oid()
switch code {
case FormatText:
switch id {
case oid.T_record:
d, _, err := tree.ParseDTupleFromString(evalCtx, string(b), t)
d, _, err := tree.ParseDTupleFromString(evalCtx, string(b), typ)
if err != nil {
return nil, err
}
return d, nil
case oid.T_bool:
t, err := strconv.ParseBool(string(b))
if err != nil {
return nil, err
return nil, tree.MakeParseError(string(b), typ, err)
}
return tree.MakeDBool(tree.DBool(t)), nil
case oid.T_bit, oid.T_varbit:
t, err := tree.ParseDBitArray(string(b))
if err != nil {
return nil, err
return nil, tree.MakeParseError(string(b), typ, err)
}
return t, nil
case oid.T_int2, oid.T_int4, oid.T_int8:
i, err := strconv.ParseInt(string(b), 10, 64)
if err != nil {
return nil, err
return nil, tree.MakeParseError(string(b), typ, err)
}
return tree.NewDInt(tree.DInt(i)), nil
case oid.T_oid,
Expand All @@ -352,99 +354,97 @@ func DecodeDatum(evalCtx *eval.Context, t *types.T, code FormatCode, b []byte) (
oid.T_regnamespace,
oid.T_regprocedure,
oid.T_regdictionary:
return eval.ParseDOid(evalCtx, string(b), t)
return eval.ParseDOid(evalCtx, string(b), typ)
case oid.T_float4, oid.T_float8:
f, err := strconv.ParseFloat(string(b), 64)
if err != nil {
return nil, err
return nil, tree.MakeParseError(string(b), typ, err)
}
return tree.NewDFloat(tree.DFloat(f)), nil
case oidext.T_box2d:
d, err := tree.ParseDBox2D(string(b))
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as box2d", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
return d, nil
case oidext.T_geography:
d, err := tree.ParseDGeography(string(b))
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as geography", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
return d, nil
case oidext.T_geometry:
d, err := tree.ParseDGeometry(string(b))
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as geometry", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
return d, nil
case oid.T_void:
return tree.DVoidDatum, nil
case oid.T_numeric:
d, err := tree.ParseDDecimal(string(b))
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as decimal", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
return d, nil
case oid.T_bytea:
res, err := lex.DecodeRawBytesToByteArrayAuto(b)
if err != nil {
return nil, err
return nil, tree.MakeParseError(string(b), typ, err)
}
return tree.NewDBytes(tree.DBytes(res)), nil
case oid.T_timestamp:
d, _, err := tree.ParseDTimestamp(evalCtx, string(b), time.Microsecond)
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as timestamp", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
return d, nil
case oid.T_timestamptz:
d, _, err := tree.ParseDTimestampTZ(evalCtx, string(b), time.Microsecond)
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as timestamptz", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
return d, nil
case oid.T_date:
d, _, err := tree.ParseDDate(evalCtx, string(b))
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as date", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
return d, nil
case oid.T_time:
d, _, err := tree.ParseDTime(nil, string(b), time.Microsecond)
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as time", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
return d, nil
case oid.T_timetz:
d, _, err := tree.ParseDTimeTZ(evalCtx, string(b), time.Microsecond)
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as timetz", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
return d, nil
case oid.T_interval:
d, err := tree.ParseDInterval(evalCtx.GetIntervalStyle(), string(b))
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as interval", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
return d, nil
case oid.T_uuid:
d, err := tree.ParseDUuidFromString(string(b))
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as uuid", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
return d, nil
case oid.T_inet:
d, err := tree.ParseDIPAddrFromINetString(string(b))
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax,
"could not parse string %q as inet", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
return d, nil
case oid.T__int2, oid.T__int4, oid.T__int8:
var arr pgtype.Int8Array
if err := arr.DecodeText(nil, b); err != nil {
return nil, pgerror.Wrapf(err, pgcode.Syntax,
"could not parse string %q as int array", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
if arr.Status != pgtype.Present {
return tree.DNull, nil
Expand All @@ -468,8 +468,7 @@ func DecodeDatum(evalCtx *eval.Context, t *types.T, code FormatCode, b []byte) (
case oid.T__text, oid.T__name:
var arr pgtype.TextArray
if err := arr.DecodeText(nil, b); err != nil {
return nil, pgerror.Wrapf(err, pgcode.Syntax,
"could not parse string %q as text array", b)
return nil, tree.MakeParseError(string(b), typ, err)
}
if arr.Status != pgtype.Present {
return tree.DNull, nil
Expand Down Expand Up @@ -502,7 +501,7 @@ func DecodeDatum(evalCtx *eval.Context, t *types.T, code FormatCode, b []byte) (
}
return tree.ParseDJSON(string(b))
}
if t.Family() == types.ArrayFamily {
if typ.Family() == types.ArrayFamily {
// Arrays come in in their string form, so we parse them as such and later
// convert them to their actual datum form.
if err := validateStringBytes(b); err != nil {
Expand Down Expand Up @@ -656,7 +655,7 @@ func DecodeDatum(evalCtx *eval.Context, t *types.T, code FormatCode, b []byte) (

decString := string(decDigits)
if _, ok := alloc.dd.Coeff.SetString(decString, 10); !ok {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as decimal", decString)
return nil, pgerror.Newf(pgcode.Syntax, "could not parse %q as type decimal", decString)
}
alloc.dd.Exponent = -int32(dscale)
}
Expand Down Expand Up @@ -784,8 +783,8 @@ func DecodeDatum(evalCtx *eval.Context, t *types.T, code FormatCode, b []byte) (
ba, err := bitarray.FromEncodingParts(words, lastBitsUsed)
return &tree.DBitArray{BitArray: ba}, err
default:
if t.Family() == types.ArrayFamily {
return decodeBinaryArray(evalCtx, t.ArrayContents(), b, code)
if typ.Family() == types.ArrayFamily {
return decodeBinaryArray(evalCtx, typ.ArrayContents(), b, code)
}
}
default:
Expand All @@ -794,12 +793,12 @@ func DecodeDatum(evalCtx *eval.Context, t *types.T, code FormatCode, b []byte) (
}

// Types with identical text/binary handling.
switch t.Family() {
switch typ.Family() {
case types.EnumFamily:
if err := validateStringBytes(b); err != nil {
return nil, err
}
return tree.MakeDEnumFromLogicalRepresentation(t, string(b))
return tree.MakeDEnumFromLogicalRepresentation(typ, string(b))
}
switch id {
case oid.T_text, oid.T_varchar, oid.T_unknown:
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/split_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func TestSplitAt(t *testing.T) {
{
in: "ALTER TABLE d.i SPLIT AT VALUES ($1)",
args: []interface{}{"blah"},
error: "error in argument for $1: strconv.ParseInt",
error: "error in argument for $1: could not parse \"blah\" as type int: strconv.ParseInt",
},
{
in: "ALTER TABLE d.i SPLIT AT VALUES ($1::string)",
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/unsplit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func TestUnsplitAt(t *testing.T) {
{
unsplitStmt: "ALTER TABLE d.i UNSPLIT AT VALUES ($1)",
args: []interface{}{"blah"},
error: "error in argument for $1: strconv.ParseInt",
error: "error in argument for $1: could not parse \"blah\" as type int: strconv.ParseInt",
},
{
unsplitStmt: "ALTER TABLE d.i UNSPLIT AT VALUES ($1::string)",
Expand Down

0 comments on commit 82ab0cb

Please sign in to comment.