Skip to content

Commit

Permalink
pgwire: rework DecodeOidDatum to DecodeDatum to parse OidFamily types
Browse files Browse the repository at this point in the history
Reworked DecodeOidDatum to DecodeDatum to take in a type, which
encodes additional useful information necessary for ENUMs and oid
family types.

Release note (bug fix): Fixed a bug where reg* types were not parsed
properly over pgwire, COPY or prepared statements.
  • Loading branch information
otan committed Nov 18, 2020
1 parent db0f4bf commit 6f55d8b
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 87 deletions.
17 changes: 14 additions & 3 deletions pkg/sql/conn_executor_prepare.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,16 +339,27 @@ func (ex *connExecutor) execBind(
"expected %d arguments, got %d", numQArgs, len(bindCmd.Args)))
}

ptCtx := tree.NewParseTimeContext(ex.state.sqlTimestamp.In(ex.sessionData.GetLocation()))

for i, arg := range bindCmd.Args {
k := tree.PlaceholderIdx(i)
t := ps.InferredTypes[i]
if arg == nil {
// nil indicates a NULL argument value.
qargs[k] = tree.DNull
} else {
d, err := pgwirebase.DecodeOidDatum(ctx, ptCtx, t, qArgFormatCodes[i], arg, &ex.planner)
typ, ok := types.OidToType[t]
if !ok {
var err error
typ, err = ex.planner.ResolveTypeByOID(ctx, t)
if err != nil {
return nil, err
}
}
d, err := pgwirebase.DecodeDatum(
ex.planner.EvalContext(),
typ,
qArgFormatCodes[i],
arg,
)
if err != nil {
return retErr(pgerror.Wrapf(err, pgcode.ProtocolViolation,
"error in argument for %s", k))
Expand Down
6 changes: 2 additions & 4 deletions pkg/sql/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,11 @@ func (c *copyMachine) readBinaryTuple(ctx context.Context) error {
if len(data) != int(byteCount) {
return errors.Newf("partial copy data row")
}
d, err := pgwirebase.DecodeOidDatum(
ctx,
d, err := pgwirebase.DecodeDatum(
c.parsingEvalCtx,
c.resultColumns[i].Typ.Oid(),
c.resultColumns[i].Typ,
pgwirebase.FormatBinary,
data,
&c.p,
)
if err != nil {
return pgerror.Wrapf(err, pgcode.BadCopyFileFormat,
Expand Down
6 changes: 6 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/pgoidtype
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,9 @@ SELECT
'12345'::regprocedure::string
----
12345 12345 12345 12345 12345

query T
PREPARE regression_56193 AS SELECT $1::regclass;
EXECUTE regression_56193('regression_53686"'::regclass)
----
"regression_53686"""
14 changes: 12 additions & 2 deletions pkg/sql/pgwire/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,12 @@ func TestEncodings(t *testing.T) {
pgwirebase.FormatText: tc.TextAsBinary,
pgwirebase.FormatBinary: tc.Binary,
} {
d, err := pgwirebase.DecodeOidDatum(context.Background(), nil, tc.Oid, code, value, nil)
d, err := pgwirebase.DecodeDatum(
&evalCtx,
types.OidToType[tc.Oid],
code,
value,
)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -240,7 +245,12 @@ func TestExoticNumericEncodings(t *testing.T) {
evalCtx := tree.MakeTestingEvalContext(nil)
for i, c := range testCases {
t.Run(fmt.Sprintf("%d_%s", i, c.Value), func(t *testing.T) {
d, err := pgwirebase.DecodeOidDatum(context.Background(), nil, oid.T_numeric, pgwirebase.FormatBinary, c.Encoding, nil)
d, err := pgwirebase.DecodeDatum(
&evalCtx,
types.Decimal,
pgwirebase.FormatBinary,
c.Encoding,
)
if err != nil {
t.Fatal(err)
}
Expand Down
87 changes: 34 additions & 53 deletions pkg/sql/pgwire/pgwirebase/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ package pgwirebase
import (
"bufio"
"bytes"
"context"
"encoding/binary"
"io"
"math"
Expand Down Expand Up @@ -297,20 +296,23 @@ func validateArrayDimensions(nDimensions int, nElements int) error {
return nil
}

// DecodeOidDatum decodes bytes with specified Oid and format code into
// a datum. If the ParseTimeContext is nil, reasonable defaults
// will be applied. If res is nil, then user defined types are not attempted
// DecodeDatum decodes bytes with specified type and format code into
// a datum. If res is nil, then user defined types are not attempted
// to be resolved.
func DecodeOidDatum(
ctx context.Context,
pCtx tree.ParseTimeContext,
id oid.Oid,
code FormatCode,
b []byte,
res tree.TypeReferenceResolver,
func DecodeDatum(
evalCtx *tree.EvalContext, t *types.T, code FormatCode, b []byte,
) (tree.Datum, error) {
id := t.Oid()
switch code {
case FormatText:
switch t.Family() {
case types.EnumFamily:
if err := validateStringBytes(b); err != nil {
return nil, err
}
return tree.MakeDEnumFromLogicalRepresentation(t, string(b))
}

switch id {
case oid.T_bool:
t, err := strconv.ParseBool(string(b))
Expand All @@ -330,12 +332,18 @@ func DecodeOidDatum(
return nil, err
}
return tree.NewDInt(tree.DInt(i)), nil
case oid.T_oid:
u, err := strconv.ParseUint(string(b), 10, 32)
if err != nil {
return nil, err
}
return tree.NewDOid(tree.DInt(u)), nil
case oid.T_oid,
oid.T_regoper,
oid.T_regproc,
oid.T_regrole,
oid.T_regclass,
oid.T_regtype,
oid.T_regconfig,
oid.T_regoperator,
oid.T_regnamespace,
oid.T_regprocedure,
oid.T_regdictionary:
return tree.ParseDOid(evalCtx, string(b), t)
case oid.T_float4, oid.T_float8:
f, err := strconv.ParseFloat(string(b), 64)
if err != nil {
Expand Down Expand Up @@ -373,19 +381,19 @@ func DecodeOidDatum(
}
return tree.NewDBytes(tree.DBytes(res)), nil
case oid.T_timestamp:
d, _, err := tree.ParseDTimestamp(pCtx, string(b), time.Microsecond)
d, _, err := tree.ParseDTimestamp(evalCtx, string(b), time.Microsecond)
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as timestamp", b)
}
return d, nil
case oid.T_timestamptz:
d, _, err := tree.ParseDTimestampTZ(pCtx, string(b), time.Microsecond)
d, _, err := tree.ParseDTimestampTZ(evalCtx, string(b), time.Microsecond)
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as timestamptz", b)
}
return d, nil
case oid.T_date:
d, _, err := tree.ParseDDate(pCtx, string(b))
d, _, err := tree.ParseDDate(evalCtx, string(b))
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as date", b)
}
Expand All @@ -397,7 +405,7 @@ func DecodeOidDatum(
}
return d, nil
case oid.T_timetz:
d, _, err := tree.ParseDTimeTZ(pCtx, string(b), time.Microsecond)
d, _, err := tree.ParseDTimeTZ(evalCtx, string(b), time.Microsecond)
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as timetz", b)
}
Expand Down Expand Up @@ -757,8 +765,7 @@ func DecodeOidDatum(
return &tree.DBitArray{BitArray: ba}, err
default:
if _, ok := types.ArrayOids[id]; ok {
innerOid := types.OidToType[id].ArrayContents().Oid()
return decodeBinaryArray(ctx, pCtx, innerOid, b, code, res)
return decodeBinaryArray(evalCtx, types.OidToType[id].ArrayContents(), b, code)
}
}
default:
Expand Down Expand Up @@ -787,27 +794,6 @@ func DecodeOidDatum(
return tree.NewDName(string(b)), nil
}

// Finally, try to resolve the type's oid as a user defined type if a resolver
// was provided.
if res != nil {
typ, err := res.ResolveTypeByOID(ctx, id)
if err != nil {
return nil, err
}
switch typ.Family() {
case types.EnumFamily:
if code != FormatText {
return nil, pgerror.Newf(pgcode.Syntax, "expected FormatText for ENUM value encoding")
}
if err := validateStringBytes(b); err != nil {
return nil, err
}
return tree.MakeDEnumFromLogicalRepresentation(typ, string(b))
default:
return nil, errors.AssertionFailedf("unsupported user defined type family %s", typ.Family().String())
}
}

// Fallthrough case.
return nil, errors.AssertionFailedf(
"unsupported OID %v with format code %s", errors.Safe(id), errors.Safe(code))
Expand Down Expand Up @@ -909,12 +895,7 @@ func pgBinaryToIPAddr(b []byte) (ipaddr.IPAddr, error) {
}

func decodeBinaryArray(
ctx context.Context,
pCtx tree.ParseTimeContext,
elemOid oid.Oid,
b []byte,
code FormatCode,
res tree.TypeReferenceResolver,
evalCtx *tree.EvalContext, t *types.T, b []byte, code FormatCode,
) (tree.Datum, error) {
var hdr struct {
Ndims int32
Expand All @@ -934,10 +915,10 @@ func decodeBinaryArray(
if err := binary.Read(r, binary.BigEndian, &hdr); err != nil {
return nil, err
}
if elemOid != oid.Oid(hdr.ElemOid) {
if t.Oid() != oid.Oid(hdr.ElemOid) {
return nil, pgerror.Newf(pgcode.DatatypeMismatch, "wrong element type")
}
arr := tree.NewDArray(types.OidToType[elemOid])
arr := tree.NewDArray(types.OidToType[t.Oid()])
if hdr.Ndims == 0 {
return arr, nil
}
Expand All @@ -959,7 +940,7 @@ func decodeBinaryArray(
continue
}
buf := r.Next(int(vlen))
elem, err := DecodeOidDatum(ctx, pCtx, elemOid, code, buf, res)
elem, err := DecodeDatum(evalCtx, t, code, buf)
if err != nil {
return nil, err
}
Expand Down
23 changes: 12 additions & 11 deletions pkg/sql/pgwire/pgwirebase/fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,35 @@ package pgwirebase
import (
"context"

"github.com/cockroachdb/cockroach/pkg/settings/cluster"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/lib/pq/oid"
)

var (
timeCtx = tree.NewParseTimeContext(timeutil.Now())
// Compile a slice of all oids.
oids = func() []oid.Oid {
var ret []oid.Oid
for oid := range types.OidToType {
ret = append(ret, oid)
// Compile a slice of all typs.
typs = func() []*types.T {
var ret []*types.T
for _, typ := range types.OidToType {
ret = append(ret, typ)
}
return ret
}()
)

func FuzzDecodeOidDatum(data []byte) int {
func FuzzDecodeDatum(data []byte) int {
if len(data) < 2 {
return 0
}

id := oids[int(data[1])%len(oids)]
typ := typs[int(data[1])%len(typs)]
code := FormatCode(data[0]) % (FormatBinary + 1)
b := data[2:]

_, err := DecodeOidDatum(context.Background(), timeCtx, id, code, b, nil)
evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
defer evalCtx.Stop(context.Background())

_, err := DecodeDatum(evalCtx, typ, code, b)
if err != nil {
return 0
}
Expand Down
19 changes: 9 additions & 10 deletions pkg/sql/pgwire/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/metric"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
"github.com/lib/pq/oid"
)

// The assertions in this test should also be caught by the integration tests on
Expand Down Expand Up @@ -136,12 +135,12 @@ func TestIntArrayRoundTrip(t *testing.T) {

b := buf.wrapped.Bytes()

got, err := pgwirebase.DecodeOidDatum(context.Background(), nil, oid.T__int8, pgwirebase.FormatText, b[4:], nil)
evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
defer evalCtx.Stop(context.Background())
got, err := pgwirebase.DecodeDatum(evalCtx, types.IntArray, pgwirebase.FormatText, b[4:])
if err != nil {
t.Fatal(err)
}
evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
defer evalCtx.Stop(context.Background())
if got.Compare(evalCtx, d) != 0 {
t.Fatalf("expected %s, got %s", d, got)
}
Expand Down Expand Up @@ -217,15 +216,15 @@ func TestByteArrayRoundTrip(t *testing.T) {
b := buf.wrapped.Bytes()
t.Logf("encoded: %v (%q)", b, b)

got, err := pgwirebase.DecodeOidDatum(context.Background(), nil, oid.T_bytea, pgwirebase.FormatText, b[4:], nil)
evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
defer evalCtx.Stop(context.Background())
got, err := pgwirebase.DecodeDatum(evalCtx, types.Bytes, pgwirebase.FormatText, b[4:])
if err != nil {
t.Fatal(err)
}
if _, ok := got.(*tree.DBytes); !ok {
t.Fatalf("parse does not return DBytes, got %T", got)
}
evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
defer evalCtx.Stop(context.Background())
if got.Compare(evalCtx, d) != 0 {
t.Fatalf("expected %s, got %s", d, got)
}
Expand Down Expand Up @@ -487,11 +486,11 @@ func BenchmarkDecodeBinaryDecimal(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StartTimer()
got, err := pgwirebase.DecodeOidDatum(context.Background(), nil, oid.T_numeric, pgwirebase.FormatBinary, bytes, nil)
b.StopTimer()
evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
defer evalCtx.Stop(context.Background())
b.StartTimer()
got, err := pgwirebase.DecodeDatum(evalCtx, types.Decimal, pgwirebase.FormatBinary, bytes)
b.StopTimer()
if err != nil {
b.Fatal(err)
} else if got.Compare(evalCtx, expected) != 0 {
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/sem/tree/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ go_test(
"compare_test.go",
"constant_eval_test.go",
"constant_test.go",
"datum_integration_test.go",
"datum_invariants_test.go",
"datum_test.go",
"eval_internal_test.go",
Expand Down
3 changes: 2 additions & 1 deletion pkg/sql/sem/tree/datum_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/timeofday"
"github.com/cockroachdb/cockroach/pkg/util/timetz"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -848,7 +849,7 @@ func TestDTimeTZ(t *testing.T) {
require.False(t, depOnCtx)

// No daylight savings in Hawaii!
hawaiiZone, err := time.LoadLocation("Pacific/Honolulu")
hawaiiZone, err := timeutil.LoadLocation("Pacific/Honolulu")
require.NoError(t, err)
hawaiiTime := tree.NewDTimeTZFromLocation(timeofday.New(1, 14, 15, 0), hawaiiZone)

Expand Down
Loading

0 comments on commit 6f55d8b

Please sign in to comment.