Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pgwire: rework DecodeOidDatum to DecodeDatum to parse OidFamily types #56298

Merged
merged 2 commits into from
Nov 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions pkg/sql/conn_executor_prepare.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,16 +339,27 @@ func (ex *connExecutor) execBind(
"expected %d arguments, got %d", numQArgs, len(bindCmd.Args)))
}

ptCtx := tree.NewParseTimeContext(ex.state.sqlTimestamp.In(ex.sessionData.GetLocation()))

for i, arg := range bindCmd.Args {
k := tree.PlaceholderIdx(i)
t := ps.InferredTypes[i]
if arg == nil {
// nil indicates a NULL argument value.
qargs[k] = tree.DNull
} else {
d, err := pgwirebase.DecodeOidDatum(ctx, ptCtx, t, qArgFormatCodes[i], arg, &ex.planner)
typ, ok := types.OidToType[t]
if !ok {
var err error
typ, err = ex.planner.ResolveTypeByOID(ctx, t)
if err != nil {
return nil, err
}
}
d, err := pgwirebase.DecodeDatum(
ex.planner.EvalContext(),
typ,
qArgFormatCodes[i],
arg,
)
if err != nil {
return retErr(pgerror.Wrapf(err, pgcode.ProtocolViolation,
"error in argument for %s", k))
Expand Down
6 changes: 2 additions & 4 deletions pkg/sql/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,11 @@ func (c *copyMachine) readBinaryTuple(ctx context.Context) error {
if len(data) != int(byteCount) {
return errors.Newf("partial copy data row")
}
d, err := pgwirebase.DecodeOidDatum(
ctx,
d, err := pgwirebase.DecodeDatum(
c.parsingEvalCtx,
c.resultColumns[i].Typ.Oid(),
c.resultColumns[i].Typ,
pgwirebase.FormatBinary,
data,
&c.p,
)
if err != nil {
return pgerror.Wrapf(err, pgcode.BadCopyFileFormat,
Expand Down
6 changes: 6 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/pgoidtype
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,9 @@ SELECT
'12345'::regprocedure::string
----
12345 12345 12345 12345 12345

query T
PREPARE regression_56193 AS SELECT $1::regclass;
EXECUTE regression_56193('regression_53686"'::regclass)
----
"regression_53686"""
14 changes: 12 additions & 2 deletions pkg/sql/pgwire/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,12 @@ func TestEncodings(t *testing.T) {
pgwirebase.FormatText: tc.TextAsBinary,
pgwirebase.FormatBinary: tc.Binary,
} {
d, err := pgwirebase.DecodeOidDatum(context.Background(), nil, tc.Oid, code, value, nil)
d, err := pgwirebase.DecodeDatum(
&evalCtx,
types.OidToType[tc.Oid],
code,
value,
)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -240,7 +245,12 @@ func TestExoticNumericEncodings(t *testing.T) {
evalCtx := tree.MakeTestingEvalContext(nil)
for i, c := range testCases {
t.Run(fmt.Sprintf("%d_%s", i, c.Value), func(t *testing.T) {
d, err := pgwirebase.DecodeOidDatum(context.Background(), nil, oid.T_numeric, pgwirebase.FormatBinary, c.Encoding, nil)
d, err := pgwirebase.DecodeDatum(
&evalCtx,
types.Decimal,
pgwirebase.FormatBinary,
c.Encoding,
)
if err != nil {
t.Fatal(err)
}
Expand Down
87 changes: 34 additions & 53 deletions pkg/sql/pgwire/pgwirebase/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ package pgwirebase
import (
"bufio"
"bytes"
"context"
"encoding/binary"
"io"
"math"
Expand Down Expand Up @@ -297,20 +296,23 @@ func validateArrayDimensions(nDimensions int, nElements int) error {
return nil
}

// DecodeOidDatum decodes bytes with specified Oid and format code into
// a datum. If the ParseTimeContext is nil, reasonable defaults
// will be applied. If res is nil, then user defined types are not attempted
// DecodeDatum decodes bytes with specified type and format code into
// a datum. If res is nil, then user defined types are not attempted
// to be resolved.
func DecodeOidDatum(
ctx context.Context,
pCtx tree.ParseTimeContext,
id oid.Oid,
code FormatCode,
b []byte,
res tree.TypeReferenceResolver,
func DecodeDatum(
evalCtx *tree.EvalContext, t *types.T, code FormatCode, b []byte,
) (tree.Datum, error) {
id := t.Oid()
switch code {
case FormatText:
switch t.Family() {
case types.EnumFamily:
if err := validateStringBytes(b); err != nil {
return nil, err
}
return tree.MakeDEnumFromLogicalRepresentation(t, string(b))
}

switch id {
case oid.T_bool:
t, err := strconv.ParseBool(string(b))
Expand All @@ -330,12 +332,18 @@ func DecodeOidDatum(
return nil, err
}
return tree.NewDInt(tree.DInt(i)), nil
case oid.T_oid:
u, err := strconv.ParseUint(string(b), 10, 32)
if err != nil {
return nil, err
}
return tree.NewDOid(tree.DInt(u)), nil
case oid.T_oid,
oid.T_regoper,
oid.T_regproc,
oid.T_regrole,
oid.T_regclass,
oid.T_regtype,
oid.T_regconfig,
oid.T_regoperator,
oid.T_regnamespace,
oid.T_regprocedure,
oid.T_regdictionary:
return tree.ParseDOid(evalCtx, string(b), t)
case oid.T_float4, oid.T_float8:
f, err := strconv.ParseFloat(string(b), 64)
if err != nil {
Expand Down Expand Up @@ -373,19 +381,19 @@ func DecodeOidDatum(
}
return tree.NewDBytes(tree.DBytes(res)), nil
case oid.T_timestamp:
d, _, err := tree.ParseDTimestamp(pCtx, string(b), time.Microsecond)
d, _, err := tree.ParseDTimestamp(evalCtx, string(b), time.Microsecond)
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as timestamp", b)
}
return d, nil
case oid.T_timestamptz:
d, _, err := tree.ParseDTimestampTZ(pCtx, string(b), time.Microsecond)
d, _, err := tree.ParseDTimestampTZ(evalCtx, string(b), time.Microsecond)
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as timestamptz", b)
}
return d, nil
case oid.T_date:
d, _, err := tree.ParseDDate(pCtx, string(b))
d, _, err := tree.ParseDDate(evalCtx, string(b))
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as date", b)
}
Expand All @@ -397,7 +405,7 @@ func DecodeOidDatum(
}
return d, nil
case oid.T_timetz:
d, _, err := tree.ParseDTimeTZ(pCtx, string(b), time.Microsecond)
d, _, err := tree.ParseDTimeTZ(evalCtx, string(b), time.Microsecond)
if err != nil {
return nil, pgerror.Newf(pgcode.Syntax, "could not parse string %q as timetz", b)
}
Expand Down Expand Up @@ -757,8 +765,7 @@ func DecodeOidDatum(
return &tree.DBitArray{BitArray: ba}, err
default:
if _, ok := types.ArrayOids[id]; ok {
innerOid := types.OidToType[id].ArrayContents().Oid()
return decodeBinaryArray(ctx, pCtx, innerOid, b, code, res)
return decodeBinaryArray(evalCtx, types.OidToType[id].ArrayContents(), b, code)
}
}
default:
Expand Down Expand Up @@ -787,27 +794,6 @@ func DecodeOidDatum(
return tree.NewDName(string(b)), nil
}

// Finally, try to resolve the type's oid as a user defined type if a resolver
// was provided.
if res != nil {
typ, err := res.ResolveTypeByOID(ctx, id)
if err != nil {
return nil, err
}
switch typ.Family() {
case types.EnumFamily:
if code != FormatText {
return nil, pgerror.Newf(pgcode.Syntax, "expected FormatText for ENUM value encoding")
}
if err := validateStringBytes(b); err != nil {
return nil, err
}
return tree.MakeDEnumFromLogicalRepresentation(typ, string(b))
default:
return nil, errors.AssertionFailedf("unsupported user defined type family %s", typ.Family().String())
}
}

// Fallthrough case.
return nil, errors.AssertionFailedf(
"unsupported OID %v with format code %s", errors.Safe(id), errors.Safe(code))
Expand Down Expand Up @@ -909,12 +895,7 @@ func pgBinaryToIPAddr(b []byte) (ipaddr.IPAddr, error) {
}

func decodeBinaryArray(
ctx context.Context,
pCtx tree.ParseTimeContext,
elemOid oid.Oid,
b []byte,
code FormatCode,
res tree.TypeReferenceResolver,
evalCtx *tree.EvalContext, t *types.T, b []byte, code FormatCode,
) (tree.Datum, error) {
var hdr struct {
Ndims int32
Expand All @@ -934,10 +915,10 @@ func decodeBinaryArray(
if err := binary.Read(r, binary.BigEndian, &hdr); err != nil {
return nil, err
}
if elemOid != oid.Oid(hdr.ElemOid) {
if t.Oid() != oid.Oid(hdr.ElemOid) {
return nil, pgerror.Newf(pgcode.DatatypeMismatch, "wrong element type")
}
arr := tree.NewDArray(types.OidToType[elemOid])
arr := tree.NewDArray(types.OidToType[t.Oid()])
if hdr.Ndims == 0 {
return arr, nil
}
Expand All @@ -959,7 +940,7 @@ func decodeBinaryArray(
continue
}
buf := r.Next(int(vlen))
elem, err := DecodeOidDatum(ctx, pCtx, elemOid, code, buf, res)
elem, err := DecodeDatum(evalCtx, t, code, buf)
if err != nil {
return nil, err
}
Expand Down
23 changes: 12 additions & 11 deletions pkg/sql/pgwire/pgwirebase/fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,35 @@ package pgwirebase
import (
"context"

"github.com/cockroachdb/cockroach/pkg/settings/cluster"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/lib/pq/oid"
)

var (
timeCtx = tree.NewParseTimeContext(timeutil.Now())
// Compile a slice of all oids.
oids = func() []oid.Oid {
var ret []oid.Oid
for oid := range types.OidToType {
ret = append(ret, oid)
// Compile a slice of all typs.
typs = func() []*types.T {
var ret []*types.T
for _, typ := range types.OidToType {
ret = append(ret, typ)
}
return ret
}()
)

func FuzzDecodeOidDatum(data []byte) int {
func FuzzDecodeDatum(data []byte) int {
if len(data) < 2 {
return 0
}

id := oids[int(data[1])%len(oids)]
typ := typs[int(data[1])%len(typs)]
code := FormatCode(data[0]) % (FormatBinary + 1)
b := data[2:]

_, err := DecodeOidDatum(context.Background(), timeCtx, id, code, b, nil)
evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
defer evalCtx.Stop(context.Background())

_, err := DecodeDatum(evalCtx, typ, code, b)
if err != nil {
return 0
}
Expand Down
19 changes: 9 additions & 10 deletions pkg/sql/pgwire/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/metric"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
"github.com/lib/pq/oid"
)

// The assertions in this test should also be caught by the integration tests on
Expand Down Expand Up @@ -136,12 +135,12 @@ func TestIntArrayRoundTrip(t *testing.T) {

b := buf.wrapped.Bytes()

got, err := pgwirebase.DecodeOidDatum(context.Background(), nil, oid.T__int8, pgwirebase.FormatText, b[4:], nil)
evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
defer evalCtx.Stop(context.Background())
got, err := pgwirebase.DecodeDatum(evalCtx, types.IntArray, pgwirebase.FormatText, b[4:])
if err != nil {
t.Fatal(err)
}
evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
defer evalCtx.Stop(context.Background())
if got.Compare(evalCtx, d) != 0 {
t.Fatalf("expected %s, got %s", d, got)
}
Expand Down Expand Up @@ -217,15 +216,15 @@ func TestByteArrayRoundTrip(t *testing.T) {
b := buf.wrapped.Bytes()
t.Logf("encoded: %v (%q)", b, b)

got, err := pgwirebase.DecodeOidDatum(context.Background(), nil, oid.T_bytea, pgwirebase.FormatText, b[4:], nil)
evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
defer evalCtx.Stop(context.Background())
got, err := pgwirebase.DecodeDatum(evalCtx, types.Bytes, pgwirebase.FormatText, b[4:])
if err != nil {
t.Fatal(err)
}
if _, ok := got.(*tree.DBytes); !ok {
t.Fatalf("parse does not return DBytes, got %T", got)
}
evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
defer evalCtx.Stop(context.Background())
if got.Compare(evalCtx, d) != 0 {
t.Fatalf("expected %s, got %s", d, got)
}
Expand Down Expand Up @@ -487,11 +486,11 @@ func BenchmarkDecodeBinaryDecimal(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StartTimer()
got, err := pgwirebase.DecodeOidDatum(context.Background(), nil, oid.T_numeric, pgwirebase.FormatBinary, bytes, nil)
b.StopTimer()
evalCtx := tree.NewTestingEvalContext(cluster.MakeTestingClusterSettings())
defer evalCtx.Stop(context.Background())
b.StartTimer()
got, err := pgwirebase.DecodeDatum(evalCtx, types.Decimal, pgwirebase.FormatBinary, bytes)
b.StopTimer()
if err != nil {
b.Fatal(err)
} else if got.Compare(evalCtx, expected) != 0 {
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/sem/tree/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ go_test(
"compare_test.go",
"constant_eval_test.go",
"constant_test.go",
"datum_integration_test.go",
"datum_invariants_test.go",
"datum_test.go",
"eval_internal_test.go",
Expand Down
Loading