Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sqlliveness: encode region in session id #91019

Merged
merged 1 commit into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pkg/sql/enum/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ func GenByteStringBetween(prev []byte, next []byte, spacing ByteSpacing) []byte
return result
}

// One returns the representation of []byte representation of the first enum
// value created in a new Enum.
var One = []byte{byte(midToken)}

// Utility functions for GenByteStringBetween.

func get(arr []byte, idx int, def int) int {
Expand Down
5 changes: 5 additions & 0 deletions pkg/sql/enum/enum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,8 @@ func TestGenerateNEvenlySpacedBytes(t *testing.T) {
}
}
}

func TestOne(t *testing.T) {
require.Equal(t, One, GenByteStringBetween(nil, nil, PackedSpacing))
require.Equal(t, One, GenByteStringBetween(nil, nil, SpreadSpacing))
}
2 changes: 2 additions & 0 deletions pkg/sql/sqlliveness/slinstance/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ go_library(
deps = [
"//pkg/settings",
"//pkg/settings/cluster",
"//pkg/sql/enum",
"//pkg/sql/sqlliveness",
"//pkg/sql/sqlliveness/slstorage",
"//pkg/util/grpcutil",
"//pkg/util/hlc",
"//pkg/util/log",
Expand Down
9 changes: 7 additions & 2 deletions pkg/sql/sqlliveness/slinstance/slinstance.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import (

"github.com/cockroachdb/cockroach/pkg/settings"
"github.com/cockroachdb/cockroach/pkg/settings/cluster"
"github.com/cockroachdb/cockroach/pkg/sql/enum"
"github.com/cockroachdb/cockroach/pkg/sql/sqlliveness"
"github.com/cockroachdb/cockroach/pkg/sql/sqlliveness/slstorage"
"github.com/cockroachdb/cockroach/pkg/util/grpcutil"
"github.com/cockroachdb/cockroach/pkg/util/hlc"
"github.com/cockroachdb/cockroach/pkg/util/log"
Expand Down Expand Up @@ -195,7 +197,11 @@ func (l *Instance) clearSessionLocked(ctx context.Context) (createNewSession boo
// createSession tries until it can create a new session and returns an error
// only if the heart beat loop should exit.
func (l *Instance) createSession(ctx context.Context) (*session, error) {
id := sqlliveness.SessionID(uuid.MakeV4().GetBytes())
id, err := slstorage.MakeSessionID(enum.One, uuid.MakeV4())
if err != nil {
return nil, err
}

start := l.clock.Now()
exp := start.Add(l.ttl().Nanoseconds(), 0)
s := &session{
Expand All @@ -210,7 +216,6 @@ func (l *Instance) createSession(ctx context.Context) (*session, error) {
Multiplier: 1.5,
}
everySecond := log.Every(time.Second)
var err error
for i, r := 0, retry.StartWithCtx(ctx, opts); r.Next(); {
i++
if err = l.storage.Insert(ctx, s.id, s.Expiration()); err != nil {
Expand Down
4 changes: 4 additions & 0 deletions pkg/sql/sqlliveness/slstorage/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go_library(
name = "slstorage",
srcs = [
"metrics.go",
"sessionid.go",
"slstorage.go",
"test_helpers.go",
],
Expand All @@ -29,6 +30,7 @@ go_library(
"//pkg/util/syncutil",
"//pkg/util/syncutil/singleflight",
"//pkg/util/timeutil",
"//pkg/util/uuid",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_logtags//:logtags",
"@com_github_cockroachdb_redact//:redact",
Expand All @@ -41,6 +43,7 @@ go_test(
size = "small",
srcs = [
"main_test.go",
"sessionid_test.go",
"slstorage_test.go",
],
args = ["-test.timeout=55s"],
Expand All @@ -56,6 +59,7 @@ go_test(
"//pkg/settings/cluster",
"//pkg/sql/catalog/descpb",
"//pkg/sql/catalog/systemschema",
"//pkg/sql/enum",
"//pkg/sql/sqlliveness",
"//pkg/testutils",
"//pkg/testutils/serverutils",
Expand Down
93 changes: 93 additions & 0 deletions pkg/sql/sqlliveness/slstorage/sessionid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright 2022 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package slstorage

import (
"github.com/cockroachdb/cockroach/pkg/sql/sqlliveness"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
"github.com/cockroachdb/errors"
)

const (
sessionIDVersion uint8 = 1
legacyLen = uuid.Size
versionLen = 1
regionLengthLen = 1
minimumRegionLen = 1
minimumPrefixLen = versionLen + regionLengthLen + minimumRegionLen
minimumNonLegacyLen = minimumPrefixLen + uuid.Size
)

// MakeSessionID encodes the region and uuid into a binary string. Most callers
// should treat the format of SessionID as opaque. The basic format is:
//
// byte[] {
// version = 1,
// len(region),
// region...,
// uuid...,
// }
//
// One of the goals of the encoding is every (region, uuid) pair should have
// exactly one valid binary encoding. Unique encodings make it safe to use the
// encoded version in maps. The goal of a single canonical representation
// disqualified the following encoding schemes:
// 1. protobufs: protobufs do not have a canonical encoding scheme. The order
// of fields is not guaranteed.
// 2. region length is encoded as a single byte instead of a varint. Small
// numbers have multiple valid varint encodings. E.g 0x8001 and 0x01 are both
// valid encodings of 1.
func MakeSessionID(region []byte, id uuid.UUID) (sqlliveness.SessionID, error) {
if len(region) == 0 {
return sqlliveness.SessionID(""), errors.New("session id requires a non-empty region")
}
if int(uint8(len(region))) != len(region) {
return sqlliveness.SessionID(""), errors.Newf("region is too long: %d", len(region))
}

sessionLength := versionLen + regionLengthLen + len(region) + uuid.Size
b := make([]byte, 0, sessionLength)
b = append(b, sessionIDVersion)
b = append(b, byte(len(region)))
b = append(b, region...)
b = append(b, id.GetBytes()...)
return sqlliveness.SessionID(b), nil
}

// UnsafeDecodeSessionID decodes the region and id from the SessionID. The
// function is unsafe, because the byte slices index into the session and must
// not be mutated.
func UnsafeDecodeSessionID(session sqlliveness.SessionID) (region, id []byte, err error) {
b := session.UnsafeBytes()
if len(b) == legacyLen {
// Legacy format of SessionID.
return nil, b, nil
}
if len(b) < minimumNonLegacyLen {
// The smallest valid v1 session id is a [version, 1, single_byte_region, uuid...],
// which is three bytes larger than a uuid.
return nil, nil, errors.New("session id is too short")
}

// Decode the version.
if b[0] != sessionIDVersion {
return nil, nil, errors.Newf("invalid session id version: %d", b[0])
}
regionLen := int(b[1])
rest := b[2:]

// Decode and validate the length of the region.
if len(rest) != regionLen+uuid.Size {
return nil, nil, errors.Newf("session id with length %d is the wrong size to include a region with length %d", len(b), regionLen)
}

return rest[:regionLen], rest[regionLen:], nil
}
159 changes: 159 additions & 0 deletions pkg/sql/sqlliveness/slstorage/sessionid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Copyright 2022 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package slstorage_test

import (
"testing"

"github.com/cockroachdb/cockroach/pkg/sql/enum"
"github.com/cockroachdb/cockroach/pkg/sql/sqlliveness"
"github.com/cockroachdb/cockroach/pkg/sql/sqlliveness/slstorage"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/uuid"
"github.com/stretchr/testify/require"
)

func FuzzSessionIDEncoding(f *testing.F) {
defer leaktest.AfterTest(f)()
defer log.Scope(f).Close(f)

f.Add(string(""))
f.Add(string(uuid.FastMakeV4().GetBytes()))

session, err := slstorage.MakeSessionID(enum.One, uuid.FastMakeV4())
require.NoError(f, err)
f.Add(string(session))

f.Fuzz(func(t *testing.T, randomSession string) {
session := sqlliveness.SessionID(randomSession)
region, id, err := slstorage.UnsafeDecodeSessionID(session)
if err == nil {
if len([]byte(randomSession)) == 16 {
// A 16 bytes session is always valid, because it is the legacy uuid encoding.
require.Equal(t, []byte(randomSession), id)
} else {
// If the session is a valid encoding, then re-encoding the
// decoded pieces should produce an identical session.
require.Len(t, id, 16)
reEncoded, err := slstorage.MakeSessionID(region, uuid.FromBytesOrNil(id))
require.NoError(t, err)
require.Equal(t, session, reEncoded)
}
}
})
}

func TestMakeSessionIDValidation(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)

_, err := slstorage.MakeSessionID(nil, uuid.MakeV4())
require.ErrorContains(t, err, "session id requires a non-empty region")
_, err = slstorage.MakeSessionID([]byte{}, uuid.MakeV4())
require.ErrorContains(t, err, "session id requires a non-empty region")
_, err = slstorage.MakeSessionID(make([]byte, 256), uuid.MakeV4())
require.ErrorContains(t, err, "region is too long")
}

func TestSessionIDEncoding(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)

id1 := uuid.MakeV4()

must := func(session sqlliveness.SessionID, err error) sqlliveness.SessionID {
require.NoError(t, err)
return session
}

testCases := []struct {
name string
session sqlliveness.SessionID
region []byte
id uuid.UUID
err string
}{
{
name: "empty_session",
session: "",
err: "session id is too short",
},
{
name: "legacy_session",
session: sqlliveness.SessionID(id1.GetBytes()),
id: id1,
},
{
name: "session_v1",
session: must(slstorage.MakeSessionID(enum.One, id1)),
region: enum.One,
id: id1,
},
{
name: "region_len_too_large",
session: func() sqlliveness.SessionID {
session := []byte(must(slstorage.MakeSessionID([]byte{128}, id1)))
session[1] = 3
return sqlliveness.SessionID(session)
}(),
err: "session id with length 19 is the wrong size to include a region with length 3",
region: []byte{},
id: id1,
},
{
name: "region_len_too_small",
session: func() sqlliveness.SessionID {
session := []byte(must(slstorage.MakeSessionID([]byte{128}, id1)))
session[1] = 0
return sqlliveness.SessionID(session)
}(),
err: "session id with length 19 is the wrong size to include a region with length 0",
region: []byte{},
id: id1,
},
{
name: "session_id_too_short",
session: func() sqlliveness.SessionID {
smallestValidSession := must(slstorage.MakeSessionID([]byte{128}, id1))
return smallestValidSession[:len(smallestValidSession)-1]
}(),
err: "session id is too short",
},
{
name: "session_v1_large_region",
session: must(slstorage.MakeSessionID(make([]byte, 255), id1)),
region: make([]byte, 255),
id: id1,
},
{
name: "invalid_version",
session: func() sqlliveness.SessionID {
session := []byte(must(slstorage.MakeSessionID(make([]byte, 255), id1)))
session[0] = 2
return sqlliveness.SessionID(session)
}(),
err: "invalid session id version: 2",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
region, uuid, err := slstorage.UnsafeDecodeSessionID(tc.session)
if tc.err != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tc.err)
} else {
require.Equal(t, region, tc.region)
require.Equal(t, uuid, tc.id.GetBytes())
}
})
}
}