From dfbe114eaddaf12e3b11d11b287171ef1c40649a Mon Sep 17 00:00:00 2001 From: Stuart Carnie Date: Tue, 24 Nov 2020 09:58:01 +1100 Subject: [PATCH] =?UTF-8?q?fix:=20PR=20Feedback=20=E2=80=93=20ensure=20key?= =?UTF-8?q?s=20cannot=20contain=20/?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- kv/index.go | 36 ++++++++++++++++++++++++++------ kv/index_migration.go | 13 ++++++++++-- kv/index_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 88 insertions(+), 9 deletions(-) diff --git a/kv/index.go b/kv/index.go index 33e28d35be3..3470cf7286d 100644 --- a/kv/index.go +++ b/kv/index.go @@ -128,9 +128,23 @@ func (i *Index) sourceBucket(tx Tx) (Bucket, error) { return tx.Bucket(i.SourceBucket()) } +var ( + // ErrKeyInvalidCharacters is returned when a foreignKey or primaryKey contains + // + ErrKeyInvalidCharacters = errors.New("key: contains invalid characters") +) + // IndexKey returns a value suitable for use as the key component -// when storing values in the index. -func IndexKey(foreignKey, primaryKey []byte) (newKey []byte) { +// when storing values in the index. IndexKey returns an +// ErrKeyInvalidCharacters error if either the foreignKey or primaryKey contains a /. +func IndexKey(foreignKey, primaryKey []byte) (newKey []byte, err error) { + if bytes.IndexByte(foreignKey, '/') != -1 { + return nil, ErrKeyInvalidCharacters + } + if bytes.IndexByte(primaryKey, '/') != -1 { + return nil, ErrKeyInvalidCharacters + } + newKey = make([]byte, len(primaryKey)+len(foreignKey)+1) copy(newKey, foreignKey) newKey[len(foreignKey)] = '/' @@ -159,7 +173,12 @@ func (i *Index) Insert(tx Tx, foreignKey, primaryKey []byte) error { return err } - return bkt.Put(IndexKey(foreignKey, primaryKey), primaryKey) + key, err := IndexKey(foreignKey, primaryKey) + if err != nil { + return err + } + + return bkt.Put(key, primaryKey) } // Delete removes the foreignKey and primaryKey mapping from the underlying index. @@ -169,7 +188,12 @@ func (i *Index) Delete(tx Tx, foreignKey, primaryKey []byte) error { return err } - return bkt.Delete(IndexKey(foreignKey, primaryKey)) + key, err := IndexKey(foreignKey, primaryKey) + if err != nil { + return err + } + + return bkt.Delete(key) } // Walk walks the source bucket using keys found in the index using the provided foreign key @@ -197,13 +221,13 @@ func (i *Index) Walk(ctx context.Context, tx Tx, foreignKey []byte, visitFn Visi return err } - return indexWalk(ctx, foreignKey, cursor, sourceBucket, visitFn) + return indexWalk(foreignKey, cursor, sourceBucket, visitFn) } // indexWalk consumes the indexKey and primaryKey pairs in the index bucket and looks up their // associated primaryKey's value in the provided source bucket. // When an item is located in the source, the provided visit function is called with primary key and associated value. -func indexWalk(ctx context.Context, foreignKey []byte, indexCursor ForwardCursor, sourceBucket Bucket, visit VisitFunc) (err error) { +func indexWalk(foreignKey []byte, indexCursor ForwardCursor, sourceBucket Bucket, visit VisitFunc) (err error) { var keys [][]byte for ik, pk := indexCursor.Next(); ik != nil; ik, pk = indexCursor.Next() { if fk, _, err := indexKeyParts(ik); err != nil { diff --git a/kv/index_migration.go b/kv/index_migration.go index 7b33f336712..363c8981aa5 100644 --- a/kv/index_migration.go +++ b/kv/index_migration.go @@ -128,7 +128,11 @@ func (i *IndexMigration) Populate(ctx context.Context, store Store) (n int, err for fk, fkm := range diff.MissingFromIndex { for pk := range fkm { - batch = append(batch, [2][]byte{IndexKey([]byte(fk), []byte(pk)), []byte(pk)}) + key, err := IndexKey([]byte(fk), []byte(pk)) + if err != nil { + return n, err + } + batch = append(batch, [2][]byte{key, []byte(pk)}) if len(batch) >= i.operationBatchSize { if err := flush(batch); err != nil { @@ -183,7 +187,12 @@ func (i *IndexMigration) remove(ctx context.Context, store Store, mappings map[s for fk, fkm := range mappings { for pk := range fkm { - batch = append(batch, IndexKey([]byte(fk), []byte(pk))) + key, err := IndexKey([]byte(fk), []byte(pk)) + if err != nil { + return err + } + + batch = append(batch, key) if len(batch) >= i.operationBatchSize { if err := flush(batch); err != nil { diff --git a/kv/index_test.go b/kv/index_test.go index 090280935a6..017bf0fe1d9 100644 --- a/kv/index_test.go +++ b/kv/index_test.go @@ -38,6 +38,48 @@ func Test_Bolt_Index(t *testing.T) { influxdbtesting.TestIndex(t, s) } +func TestIndexKey(t *testing.T) { + tests := []struct { + name string + fk string + pk string + expKey string + expErr error + }{ + { + name: "returns key", + fk: "fk_part", + pk: "pk_part", + expKey: "fk_part/pk_part", + }, + { + name: "returns error for invalid foreign key", + fk: "fk/part", + pk: "pk_part", + expErr: kv.ErrKeyInvalidCharacters, + }, + { + name: "returns error for invalid primary key", + fk: "fk_part", + pk: "pk/part", + expErr: kv.ErrKeyInvalidCharacters, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gotKey, gotErr := kv.IndexKey([]byte(test.fk), []byte(test.pk)) + if test.expErr != nil { + require.Error(t, gotErr) + assert.EqualError(t, test.expErr, gotErr.Error()) + assert.Nil(t, gotKey) + } else { + assert.NoError(t, gotErr) + assert.Equal(t, test.expKey, string(gotKey)) + } + }) + } +} + func TestIndex_Walk(t *testing.T) { t.Run("only selects exact keys", func(t *testing.T) { ctrl := gomock.NewController(t) @@ -45,8 +87,12 @@ func TestIndex_Walk(t *testing.T) { type keyValue struct{ key, val []byte } makeIndexKV := func(fk, pk string) keyValue { + key, err := kv.IndexKey([]byte(fk), []byte(pk)) + if err != nil { + panic(err) + } return keyValue{ - key: kv.IndexKey([]byte(fk), []byte(pk)), + key: key, val: []byte(pk), } }