Skip to content

Commit

Permalink
sql/pgwire: use OIDs when encoding some datum types
Browse files Browse the repository at this point in the history
Teach pgwire how to encode int and float with the various widths. Do
this by saving the oids during SetColumns.

Teach cmd/generate-binary to also record oids and use those in tests.

Fix varbits to expect the same OID as postgres produces.

Sadly, our int2 and int4 types don't yet propagate all the way down
and so we still encode them as an int8. This commit is a precursor to
supporting that.

Release note: None
  • Loading branch information
maddyblue committed Feb 12, 2019
1 parent b80d241 commit ea1c4fc
Show file tree
Hide file tree
Showing 8 changed files with 556 additions and 23 deletions.
57 changes: 57 additions & 0 deletions pkg/cmd/generate-binary/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,14 @@ func main() {
if err != nil {
log.Fatalf("binary: %s: %v", sql, err)
}
sql = fmt.Sprintf("SELECT pg_typeof(%s)::int", expr)
id, err := pgconnect.Connect(ctx, sql, *postgresAddr, *postgresUser, pgwirebase.FormatText)
if err != nil {
log.Fatalf("oid: %s: %v", sql, err)
}
data = append(data, entry{
SQL: expr,
Oid: string(id),
Text: text,
Binary: binary,
})
Expand All @@ -109,6 +115,7 @@ func main() {

type entry struct {
SQL string
Oid string
Text []byte
Binary []byte
}
Expand All @@ -131,6 +138,7 @@ const outputJSON = `[
{{- if gt $idx 0 }},{{end}}
{
"SQL": {{.SQL | json}},
"Oid": {{.Oid}},
"Text": {{printf "%q" .Text}},
"TextAsBinary": {{.Text | binary}},
"Binary": {{.Binary | binary}}
Expand Down Expand Up @@ -216,6 +224,55 @@ var inputs = map[string][]string{
fmt.Sprint(math.SmallestNonzeroFloat64),
},

"'%s'::float4": {
// The Go binary encoding of NaN differs from Postgres by a 1 at the
// end. Go also uses Inf instead of Infinity (used by Postgres) for text
// float encodings. These deviations are still correct, and it's not worth
// special casing them into the code, so they are commented out here.
//"NaN",
//"Inf",
//"-Inf",
"-000.000",
"-0000021234.2",
"-1.2",
".0",
".1",
".1234",
".12345",
"3.40282e+38",
"1.4013e-45",
},

"'%s'::int2": {
"0",
"1",
"-1",
"-32768",
"32767",
},

"'%s'::int4": {
"0",
"1",
"-1",
"-32768",
"32767",
"-2147483648",
"2147483647",
},

"'%s'::int8": {
"0",
"1",
"-1",
"-32768",
"32767",
"-2147483648",
"2147483647",
"-9223372036854775808",
"9223372036854775807",
},

"'%s'::timestamp": {
"1999-01-08 04:05:06+00",
"1999-01-08 04:05:06+00:00",
Expand Down
10 changes: 9 additions & 1 deletion pkg/sql/pgwire/command_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ type commandResult struct {
// case for queries executed through the simple protocol). Otherwise, it needs
// to have an entry for every column.
formatCodes []pgwirebase.FormatCode

// oids is a map from result column index to its Oid, similar to formatCodes
// (except oids must always be set).
oids []oid.Oid
}

func (c *conn) makeCommandResult(
Expand Down Expand Up @@ -213,7 +217,7 @@ func (r *commandResult) AddRow(ctx context.Context, row tree.Datums) error {
}
r.rowsAffected++

r.conn.bufferRow(ctx, row, r.formatCodes, r.conv)
r.conn.bufferRow(ctx, row, r.formatCodes, r.conv, r.oids)
_ /* flushed */, err := r.conn.maybeFlush(r.pos)
return err
}
Expand All @@ -224,6 +228,10 @@ func (r *commandResult) SetColumns(ctx context.Context, cols sqlbase.ResultColum
if r.descOpt == sql.NeedRowDesc {
_ /* err */ = r.conn.writeRowDescription(ctx, cols, r.formatCodes, &r.conn.writerState.buf)
}
r.oids = make([]oid.Oid, len(cols))
for i, col := range cols {
r.oids[i] = col.Typ.Oid()
}
}

// SetInferredTypes is part of the DescribeResult interface.
Expand Down
3 changes: 2 additions & 1 deletion pkg/sql/pgwire/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,7 @@ func (c *conn) bufferRow(
row tree.Datums,
formatCodes []pgwirebase.FormatCode,
conv sessiondata.DataConversionConfig,
oids []oid.Oid,
) {
c.msgBuilder.initMsg(pgwirebase.ServerMsgDataRow)
c.msgBuilder.putInt16(int16(len(row)))
Expand All @@ -927,7 +928,7 @@ func (c *conn) bufferRow(
case pgwirebase.FormatText:
c.msgBuilder.writeTextDatum(ctx, col, conv)
case pgwirebase.FormatBinary:
c.msgBuilder.writeBinaryDatum(ctx, col, conv.Location)
c.msgBuilder.writeBinaryDatum(ctx, col, conv.Location, oids[i])
default:
c.msgBuilder.setError(errors.Errorf("unsupported format code %s", fmtCode))
}
Expand Down
14 changes: 8 additions & 6 deletions pkg/sql/pgwire/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/sessiondata"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/metric"
"github.com/lib/pq/oid"
)

type encodingTest struct {
SQL string
Datum tree.Datum
Oid oid.Oid
Text string
TextAsBinary []byte
Binary []byte
Expand Down Expand Up @@ -102,6 +104,7 @@ func TestEncodings(t *testing.T) {
buf := newWriteBuffer(metric.NewCounter(metric.Metadata{}))

verifyLen := func(t *testing.T) []byte {
t.Helper()
b := buf.wrapped.Bytes()
if len(b) < 4 {
t.Fatal("short buffer")
Expand Down Expand Up @@ -140,7 +143,7 @@ func TestEncodings(t *testing.T) {
})
t.Run(pgwirebase.FormatBinary.String(), func(t *testing.T) {
buf.reset()
buf.writeBinaryDatum(ctx, d, time.UTC)
buf.writeBinaryDatum(ctx, d, time.UTC, tc.Oid)
if buf.err != nil {
t.Fatal(buf.err)
}
Expand All @@ -159,14 +162,13 @@ func TestEncodings(t *testing.T) {
// Unsupported.
t.Skip()
}
id := tc.Datum.ResolvedType().Oid()
for code, value := range map[pgwirebase.FormatCode][]byte{
pgwirebase.FormatText: tc.TextAsBinary,
pgwirebase.FormatBinary: tc.Binary,
} {
t.Run(code.String(), func(t *testing.T) {
t.Logf("code: %s\nvalue: %q (%[2]s)\noid: %v", code, value, id)
d, err := pgwirebase.DecodeOidDatum(nil, id, code, value)
t.Logf("code: %s\nvalue: %q (%[2]s)\noid: %v", code, value, tc.Oid)
d, err := pgwirebase.DecodeOidDatum(nil, tc.Oid, code, value)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -219,7 +221,7 @@ func BenchmarkEncodings(b *testing.B) {
for i := 0; i < b.N; i++ {
buf.reset()
b.StartTimer()
buf.writeBinaryDatum(ctx, d, time.UTC)
buf.writeBinaryDatum(ctx, d, time.UTC, tc.Oid)
b.StopTimer()
}
})
Expand All @@ -232,7 +234,7 @@ func TestEncodingErrorCounts(t *testing.T) {

buf := newWriteBuffer(metric.NewCounter(metric.Metadata{}))
d, _ := tree.ParseDDecimal("Inf")
buf.writeBinaryDatum(context.Background(), d, nil)
buf.writeBinaryDatum(context.Background(), d, nil, d.ResolvedType().Oid())
if count := telemetry.GetFeatureCounts()["pgwire.#32489.binary_decimal_infinity"]; count != 1 {
t.Fatalf("expected 1 encoding error, got %d", count)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/pgwire/pgwirebase/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ func DecodeOidDatum(
return nil, err
}
return tree.ParseDJSON(string(b))
case oid.T_varbit:
case oid.T_varbit, oid.T_bit:
if len(b) < 4 {
return nil, errors.Errorf("missing varbit bitlen prefix")
}
Expand Down
Loading

0 comments on commit ea1c4fc

Please sign in to comment.