diff --git a/go.mod b/go.mod index f1743b64..8dab9abd 100644 --- a/go.mod +++ b/go.mod @@ -8,16 +8,17 @@ require ( github.com/go-sql-driver/mysql v1.7.1 github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa github.com/jackc/pgx/v5 v5.4.2 - github.com/klauspost/compress v1.16.6 + github.com/klauspost/compress v1.17.2 github.com/mattn/go-sqlite3 v1.14.17 github.com/nats-io/jsm.go v0.0.31-0.20220317133147-fe318f464eee - github.com/nats-io/nats-server/v2 v2.9.18 - github.com/nats-io/nats.go v1.27.1 + github.com/nats-io/nats-server/v2 v2.10.5-0.20231101212211-a190514dcb2f + github.com/nats-io/nats.go v1.31.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.14.0 github.com/rancher/wrangler v1.1.1-0.20230425173236-39a4707f0689 github.com/shengdoushi/base58 v1.0.0 github.com/sirupsen/logrus v1.9.0 + github.com/tidwall/btree v1.6.0 github.com/urfave/cli v1.22.4 go.etcd.io/etcd/api/v3 v3.5.9 go.etcd.io/etcd/client/pkg/v3 v3.5.9 @@ -53,8 +54,8 @@ require ( github.com/minio/highwayhash v1.0.2 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/nats-io/jwt/v2 v2.4.1 // indirect - github.com/nats-io/nkeys v0.4.4 // indirect + github.com/nats-io/jwt/v2 v2.5.3 // indirect + github.com/nats-io/nkeys v0.4.6 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/prometheus/client_model v0.3.0 // indirect github.com/prometheus/common v0.37.0 // indirect diff --git a/go.sum b/go.sum index ff54ba25..82902343 100644 --- a/go.sum +++ b/go.sum @@ -210,8 +210,8 @@ github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8 github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/klauspost/compress v1.16.6 h1:91SKEy4K37vkp255cJ8QesJhjyRO0hn9i9G0GoUwLsk= -github.com/klauspost/compress v1.16.6/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4= +github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= @@ -239,17 +239,17 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW github.com/nats-io/jsm.go v0.0.31-0.20220317133147-fe318f464eee h1:+l6i7zS8N1LOokm7dzShezI9STRGrzp0O49Pw8Jetdk= github.com/nats-io/jsm.go v0.0.31-0.20220317133147-fe318f464eee/go.mod h1:EKSYvbvWAoh0hIfuZ+ieWm8u0VOTRTeDfuQvNPKRqEg= github.com/nats-io/jwt/v2 v2.2.1-0.20220113022732-58e87895b296/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= -github.com/nats-io/jwt/v2 v2.4.1 h1:Y35W1dgbbz2SQUYDPCaclXcuqleVmpbRa7646Jf2EX4= -github.com/nats-io/jwt/v2 v2.4.1/go.mod h1:24BeQtRwxRV8ruvC4CojXlx/WQ/VjuwlYiH+vu/+ibI= +github.com/nats-io/jwt/v2 v2.5.3 h1:/9SWvzc6hTfamcgXJ3uYRpgj+QuY2aLNqRiqrKcrpEo= +github.com/nats-io/jwt/v2 v2.5.3/go.mod h1:iysuPemFcc7p4IoYots3IuELSI4EDe9Y0bQMe+I3Bf4= github.com/nats-io/nats-server/v2 v2.7.5-0.20220309212130-5c0d1999ff72/go.mod h1:1vZ2Nijh8tcyNe8BDVyTviCd9NYzRbubQYiEHsvOQWc= -github.com/nats-io/nats-server/v2 v2.9.18 h1:00muGH0qu/7NAw1b/2eFcpIvdHcTghj6PFjUVhy8zEo= -github.com/nats-io/nats-server/v2 v2.9.18/go.mod h1:aTb/xtLCGKhfTFLxP591CMWfkdgBmcUUSkiSOe5A3gw= +github.com/nats-io/nats-server/v2 v2.10.5-0.20231101212211-a190514dcb2f h1:esUVQwHgVK6LSqLSwJ0UzZ/490VOvCnwF7zbZd+gsfI= +github.com/nats-io/nats-server/v2 v2.10.5-0.20231101212211-a190514dcb2f/go.mod h1:eWm2JmHP9Lqm2oemB6/XGi0/GwsZwtWf8HIPUsh+9ns= github.com/nats-io/nats.go v1.13.1-0.20220308171302-2f2f6968e98d/go.mod h1:BPko4oXsySz4aSWeFgOHLZs3G4Jq4ZAyE6/zMCxRT6w= -github.com/nats-io/nats.go v1.27.1 h1:OuYnal9aKVSnOzLQIzf7554OXMCG7KbaTkCSBHRcSoo= -github.com/nats-io/nats.go v1.27.1/go.mod h1:XpbWUlOElGwTYbMR7imivs7jJj9GtK7ypv321Wp6pjc= +github.com/nats-io/nats.go v1.31.0 h1:/WFBHEc/dOKBF6qf1TZhrdEfTmOZ5JzdJ+Y3m6Y/p7E= +github.com/nats-io/nats.go v1.31.0/go.mod h1:di3Bm5MLsoB4Bx61CBTsxuarI36WbhAwOm8QrW39+i8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= -github.com/nats-io/nkeys v0.4.4 h1:xvBJ8d69TznjcQl9t6//Q5xXuVhyYiSos6RPtvQNTwA= -github.com/nats-io/nkeys v0.4.4/go.mod h1:XUkxdLPTufzlihbamfzQ7mw/VGx6ObUs+0bN5sNvt64= +github.com/nats-io/nkeys v0.4.6 h1:IzVe95ru2CT6ta874rt9saQRkWfe2nFj1NtvYSLqMzY= +github.com/nats-io/nkeys v0.4.6/go.mod h1:4DxZNzenSVd1cYQoAa8948QY3QDjrHfcfVADymtkpts= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= @@ -316,6 +316,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/tidwall/btree v1.6.0 h1:LDZfKfQIBHGHWSwckhXI0RPSXzlo+KYdjK7FWSqOzzg= +github.com/tidwall/btree v1.6.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 h1:uruHq4dN7GR16kFc5fp3d1RIYzJW5onx8Ybykw2YQFA= github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/urfave/cli v1.22.4 h1:u7tSpNPPswAFymm8IehJhy4uJMlUuU/GmqSkvJ1InXA= diff --git a/pkg/drivers/nats/backend.go b/pkg/drivers/nats/backend.go new file mode 100644 index 00000000..c6667b61 --- /dev/null +++ b/pkg/drivers/nats/backend.go @@ -0,0 +1,424 @@ +package nats + +import ( + "context" + "encoding/json" + "time" + + "github.com/k3s-io/kine/pkg/server" + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" + "github.com/sirupsen/logrus" +) + +// TODO: version this data structure to simplify and optimize for size. +type natsData struct { + // v1 fields + KV *server.KeyValue `json:"KV"` + PrevRevision int64 `json:"PrevRevision"` + Create bool `json:"Create"` + Delete bool `json:"Delete"` + + CreateTime time.Time `json:"-"` +} + +func (d *natsData) Encode() ([]byte, error) { + buf, err := json.Marshal(d) + return buf, err +} + +func (d *natsData) Decode(e jetstream.KeyValueEntry) error { + if e == nil || e.Value() == nil { + return nil + } + + err := json.Unmarshal(e.Value(), d) + if err != nil { + return err + } + d.KV.ModRevision = int64(e.Revision()) + if d.KV.CreateRevision == 0 { + d.KV.CreateRevision = d.KV.ModRevision + } + d.CreateTime = e.Created() + return nil +} + +var ( + // Ensure Backend implements server.Backend. + _ server.Backend = (&Backend{}) +) + +type Backend struct { + nc *nats.Conn + js jetstream.JetStream + kv *KeyValue + l *logrus.Logger + cancel context.CancelFunc +} + +func (b *Backend) Close() error { + b.cancel() + return b.nc.Drain() +} + +// isExpiredKey checks if the key is expired based on the create time and lease. +func (b *Backend) isExpiredKey(value *natsData) bool { + if value.KV.Lease == 0 { + return false + } + + return time.Now().After(value.CreateTime.Add(time.Second * time.Duration(value.KV.Lease))) +} + +// get returns the key-value entry for the given key and revision, if specified. +// This takes into account entries that have been marked as deleted or expired. +func (b *Backend) get(ctx context.Context, key string, revision int64, allowDeletes bool) (int64, *natsData, error) { + var ( + entry jetstream.KeyValueEntry + err error + ) + + // Get latest revision if not specified. + if revision <= 0 { + entry, err = b.kv.Get(ctx, key) + } else { + entry, err = b.kv.GetRevision(ctx, key, uint64(revision)) + } + if err != nil { + return 0, nil, err + } + + rev := int64(entry.Revision()) + + var val natsData + err = val.Decode(entry) + if err != nil { + return 0, nil, err + } + + if val.Delete && !allowDeletes { + return 0, nil, jetstream.ErrKeyNotFound + } + + if b.isExpiredKey(&val) { + err := b.kv.Delete(ctx, val.KV.Key, jetstream.LastRevision(uint64(rev))) + if err != nil { + b.l.Warnf("Failed to delete expired key %s: %v", val.KV.Key, err) + } + // Return a zero indicating the key was deleted. + return 0, nil, jetstream.ErrKeyNotFound + } + + return rev, &val, nil +} + +// Start starts the backend. +// See https://github.com/kubernetes/kubernetes/blob/442a69c3bdf6fe8e525b05887e57d89db1e2f3a5/staging/src/k8s.io/apiserver/pkg/storage/storagebackend/factory/etcd3.go#L97 +func (b *Backend) Start(ctx context.Context) error { + if _, err := b.Create(ctx, "/registry/health", []byte(`{"health":"true"}`), 0); err != nil { + if err != server.ErrKeyExists { + b.l.Errorf("Failed to create health check key: %v", err) + } + } + return nil +} + +// DbSize get the kineBucket size from JetStream. +func (b *Backend) DbSize(ctx context.Context) (int64, error) { + return b.kv.BucketSize(ctx) +} + +// Count returns an exact count of the number of matching keys and the current revision of the database. +func (b *Backend) Count(ctx context.Context, prefix string) (int64, int64, error) { + count, err := b.kv.Count(ctx, prefix) + if err != nil { + return 0, 0, err + } + + storeRev := b.kv.BucketRevision() + return storeRev, count, nil +} + +// Get returns the store's current revision, the associated server.KeyValue or an error. +func (b *Backend) Get(ctx context.Context, key, rangeEnd string, limit, revision int64) (int64, *server.KeyValue, error) { + storeRev := b.kv.BucketRevision() + // Get the kv entry and return the revision. + rev, nv, err := b.get(ctx, key, revision, false) + if err == nil { + if nv == nil { + return storeRev, nil, nil + } + return rev, nv.KV, nil + } + if err == jetstream.ErrKeyNotFound { + return storeRev, nil, nil + } + + return rev, nil, err +} + +// Create attempts to create the key-value entry and returns the revision number. +func (b *Backend) Create(ctx context.Context, key string, value []byte, lease int64) (int64, error) { + // Check if key exists already. If the entry exists even if marked as expired or deleted, + // the revision will be returned to apply an update. + rev, pnv, err := b.get(ctx, key, 0, true) + // If an error other than key not found, return. + if err != nil && err != jetstream.ErrKeyNotFound { + return 0, err + } + + nv := natsData{ + Delete: false, + Create: true, + PrevRevision: 0, + KV: &server.KeyValue{ + Key: key, + CreateRevision: 0, + ModRevision: 0, + Value: value, + Lease: lease, + }, + } + + if pnv != nil { + if !pnv.Delete { + return 0, server.ErrKeyExists + } + nv.PrevRevision = pnv.KV.ModRevision + } + + data, err := nv.Encode() + if err != nil { + return 0, err + } + + if pnv != nil { + seq, err := b.kv.Update(ctx, key, data, uint64(rev)) + if err != nil { + if jsWrongLastSeqErr.Is(err) { + b.l.Warnf("create conflict: key=%s, rev=%d, err=%s", key, rev, err) + return 0, server.ErrKeyExists + } + return 0, err + } + + return int64(seq), nil + } + + // An update with a zero revision will create the key. + seq, err := b.kv.Create(ctx, key, data) + if err != nil { + if jsWrongLastSeqErr.Is(err) { + b.l.Warnf("create conflict: key=%s, rev=0, err=%s", key, err) + return 0, server.ErrKeyExists + } + return 0, err + } + + return int64(seq), nil +} + +func (b *Backend) Delete(ctx context.Context, key string, revision int64) (int64, *server.KeyValue, bool, error) { + // Get the key, allow deletes. + rev, value, err := b.get(ctx, key, 0, true) + if err != nil { + if err == jetstream.ErrKeyNotFound { + return rev, nil, true, nil + } + return rev, nil, false, err + } + if value == nil { + return rev, nil, true, nil + } + if value.Delete { + return rev, value.KV, true, nil + } + if revision != 0 && value.KV.ModRevision != revision { + return rev, value.KV, false, nil + } + + nv := natsData{ + Delete: true, + PrevRevision: rev, + KV: value.KV, + } + data, err := nv.Encode() + if err != nil { + return rev, nil, false, err + } + + // Update with a tombstone. + drev, err := b.kv.Update(ctx, key, data, uint64(rev)) + if err != nil { + if jsWrongLastSeqErr.Is(err) { + b.l.Warnf("delete conflict: key=%s, rev=%d, err=%s", key, rev, err) + return 0, nil, false, nil + } + return rev, value.KV, false, nil + } + + err = b.kv.Delete(ctx, key, jetstream.LastRevision(drev)) + if err != nil { + if jsWrongLastSeqErr.Is(err) { + b.l.Warnf("delete conflict: key=%s, rev=%d, err=%s", key, drev, err) + return 0, nil, false, nil + } + return rev, value.KV, false, nil + } + + return int64(drev), value.KV, true, nil +} + +func (b *Backend) Update(ctx context.Context, key string, value []byte, revision, lease int64) (int64, *server.KeyValue, bool, error) { + // Get the latest revision of the key. + rev, pnd, err := b.get(ctx, key, 0, false) + // TODO: correct semantics for these various errors? + if err != nil { + if err == jetstream.ErrKeyNotFound { + return rev, nil, false, nil + } + return rev, nil, false, err + } + + // Return nothing? + if pnd == nil { + return 0, nil, false, nil + } + + // Incorrect revision, return the current value. + if pnd.KV.ModRevision != revision { + return rev, pnd.KV, false, nil + } + + nd := natsData{ + Delete: false, + Create: false, + PrevRevision: pnd.KV.ModRevision, + KV: &server.KeyValue{ + Key: key, + CreateRevision: pnd.KV.CreateRevision, + Value: value, + Lease: lease, + }, + } + + if pnd.KV.CreateRevision == 0 { + nd.KV.CreateRevision = rev + } + + data, err := nd.Encode() + if err != nil { + return 0, nil, false, err + } + + seq, err := b.kv.Update(ctx, key, data, uint64(revision)) + if err != nil { + // This may occur if a concurrent writer created the key. + if jsWrongLastSeqErr.Is(err) { + b.l.Warnf("update conflict: key=%s, rev=%d, err=%s", key, revision, err) + return 0, nil, false, nil + } + return 0, nil, false, err + } + + nd.KV.ModRevision = int64(seq) + + return int64(seq), nd.KV, true, nil +} + +// List returns a range of keys starting with the prefix. +// This would translated to one or more tokens, e.g. `a.b.c`. +// The startKey would be the next set of tokens that follow the prefix +// that are alphanumerically equal to or greater than the startKey. +// If limit is provided, the maximum set of matches is limited. +// If revision is provided, this indicates the maximum revision to return. +func (b *Backend) List(ctx context.Context, prefix, startKey string, limit, maxRevision int64) (int64, []*server.KeyValue, error) { + matches, err := b.kv.List(ctx, prefix, startKey, limit, maxRevision) + if err != nil { + return 0, nil, err + } + + kvs := make([]*server.KeyValue, 0, len(matches)) + for _, e := range matches { + var nd natsData + err = nd.Decode(e) + if err != nil { + return 0, nil, err + } + kvs = append(kvs, nd.KV) + } + + storeRev := b.kv.BucketRevision() + return storeRev, kvs, nil +} + +func (b *Backend) Watch(ctx context.Context, prefix string, startRevision int64) server.WatchResult { + events := make(chan []*server.Event, 32) + + rev := startRevision + if rev == 0 { + rev = b.kv.BucketRevision() + } + + go func() { + defer close(events) + + var w jetstream.KeyWatcher + for { + var err error + w, err = b.kv.Watch(ctx, prefix, startRevision) + if err == nil { + break + } + b.l.Warnf("watch init: prefix=%s, err=%s", prefix, err) + time.Sleep(time.Second) + } + + for { + select { + case <-ctx.Done(): + err := ctx.Err() + if err != nil && err != context.Canceled { + b.l.Warnf("watch ctx: prefix=%s, err=%s", prefix, err) + } + return + + case e := <-w.Updates(): + if e.Operation() != jetstream.KeyValuePut { + continue + } + + key := e.Key() + + var nd natsData + err := nd.Decode(e) + if err != nil { + b.l.Warnf("watch decode: key=%s, err=%s", key, err) + continue + } + + event := server.Event{ + Create: nd.Create, + Delete: nd.Delete, + KV: nd.KV, + PrevKV: &server.KeyValue{}, + } + + if nd.PrevRevision > 0 { + _, pnd, err := b.get(ctx, key, nd.PrevRevision, false) + if err == nil { + event.PrevKV = pnd.KV + } + } + + events <- []*server.Event{&event} + } + } + }() + + return server.WatchResult{ + Events: events, + CurrentRevision: rev, + } +} diff --git a/pkg/drivers/nats/backend_test.go b/pkg/drivers/nats/backend_test.go new file mode 100644 index 00000000..eea2bbd3 --- /dev/null +++ b/pkg/drivers/nats/backend_test.go @@ -0,0 +1,378 @@ +package nats + +import ( + "context" + "errors" + "io/ioutil" + "testing" + "time" + + kserver "github.com/k3s-io/kine/pkg/server" + "github.com/nats-io/nats-server/v2/server" + "github.com/nats-io/nats-server/v2/test" + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" + "github.com/sirupsen/logrus" +) + +func noErr(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + +func expErr(t *testing.T, err error) { + t.Helper() + if err == nil { + t.Fatal("expected error") + } +} + +func expEqualErr(t *testing.T, want, got error) { + t.Helper() + if !errors.Is(want, got) { + t.Fatalf("expected %v, got %v", want, got) + } +} + +func expEqual[T comparable](t *testing.T, want, got T) { + t.Helper() + if got != want { + t.Fatalf("expected %v, got %v", want, got) + } +} + +func expSortedKeys(t *testing.T, ents []*kserver.KeyValue) { + t.Helper() + var prev string + for _, ent := range ents { + if prev != "" { + if prev > ent.Key { + t.Fatalf("keys not sorted: %s > %s", prev, ent.Key) + } + } + prev = ent.Key + } +} + +func expEqualKeys(t *testing.T, want []string, got []*kserver.KeyValue) { + t.Helper() + expEqual(t, len(want), len(got)) + for i, k := range want { + expEqual(t, k, got[i].Key) + } +} + +func setupBackend(t *testing.T) (*server.Server, *nats.Conn, *Backend) { + ns := test.RunServer(&server.Options{ + Port: -1, + JetStream: true, + StoreDir: t.TempDir(), + }) + + nc, err := nats.Connect(ns.ClientURL()) + noErr(t, err) + + js, err := jetstream.New(nc) + noErr(t, err) + + ctx := context.Background() + + bkt, err := js.CreateKeyValue(ctx, jetstream.KeyValueConfig{ + Bucket: "kine", + History: 10, + }) + noErr(t, err) + + ekv := NewKeyValue(ctx, bkt, js) + + l := logrus.New() + l.SetOutput(ioutil.Discard) + + b := Backend{ + l: l, + kv: ekv, + js: js, + } + + return ns, nc, &b +} + +func TestBackend_Create(t *testing.T) { + ns, nc, b := setupBackend(t) + defer ns.Shutdown() + defer nc.Drain() + + ctx := context.Background() + + // Create a key. + rev, err := b.Create(ctx, "/a", nil, 0) + noErr(t, err) + expEqual(t, 1, rev) + + // Attempt to create again. + _, err = b.Create(ctx, "/a", nil, 0) + expEqualErr(t, err, kserver.ErrKeyExists) + + rev, err = b.Create(ctx, "/a/b", nil, 0) + noErr(t, err) + expEqual(t, 2, rev) + + rev, err = b.Create(ctx, "/a/b/c", nil, 0) + noErr(t, err) + expEqual(t, 3, rev) + + rev, err = b.Create(ctx, "/b", nil, 1) + noErr(t, err) + expEqual(t, 4, rev) + + time.Sleep(2 * time.Millisecond) + + srev, count, err := b.Count(ctx, "/") + noErr(t, err) + expEqual(t, 4, srev) + expEqual(t, 4, count) + + time.Sleep(time.Second) + + srev, count, err = b.Count(ctx, "/") + noErr(t, err) + expEqual(t, 4, srev) + expEqual(t, 3, count) + + // Create /b again. Rev is 6 due to the internal delete. + // on read. + rev, err = b.Create(ctx, "/b", nil, 0) + noErr(t, err) + expEqual(t, 6, rev) + + time.Sleep(2 * time.Millisecond) + + srev, count, err = b.Count(ctx, "/") + noErr(t, err) + expEqual(t, 6, srev) + expEqual(t, 4, count) +} + +func TestBackend_Get(t *testing.T) { + ns, nc, b := setupBackend(t) + defer ns.Shutdown() + defer nc.Drain() + + ctx := context.Background() + + // Create with lease. + rev, err := b.Create(ctx, "/a", []byte("b"), 1) + noErr(t, err) + + time.Sleep(2 * time.Millisecond) + + srev, ent, err := b.Get(ctx, "/a", "", 0, 0) + noErr(t, err) + expEqual(t, 1, srev) + expEqual(t, "/a", ent.Key) + expEqual(t, "b", string(ent.Value)) + expEqual(t, 1, ent.Lease) + expEqual(t, 1, ent.ModRevision) + expEqual(t, 1, ent.CreateRevision) + + time.Sleep(time.Second) + + // Latest is gone. + _, ent, err = b.Get(ctx, "/a", "", 0, 0) + expEqualErr(t, nil, err) + + // Get at a revision will fail also. + _, ent, err = b.Get(ctx, "/a", "", 0, 1) + expEqualErr(t, nil, err) + + // Get at later revision, does not exist. + _, _, err = b.Get(ctx, "/a", "", 0, 2) + expEqualErr(t, nil, err) + + // Create it again and update it. + rev, err = b.Create(ctx, "/a", []byte("c"), 0) + noErr(t, err) + expEqual(t, 3, rev) + + _, _, _, err = b.Update(ctx, "/a", []byte("d"), rev, 0) + noErr(t, err) + + // Get at prior version. + rev, ent, err = b.Get(ctx, "/a", "", 0, rev) + noErr(t, err) + expEqual(t, 3, rev) + expEqual(t, "/a", ent.Key) + expEqual(t, "c", string(ent.Value)) + expEqual(t, 0, ent.Lease) + expEqual(t, 3, ent.ModRevision) + expEqual(t, 3, ent.CreateRevision) +} + +func TestBackend_Update(t *testing.T) { + ns, nc, b := setupBackend(t) + defer ns.Shutdown() + defer nc.Drain() + + ctx := context.Background() + + // Create with lease. + b.Create(ctx, "/a", []byte("b"), 1) + rev, ent, ok, err := b.Update(ctx, "/a", []byte("c"), 1, 0) + noErr(t, err) + expEqual(t, 2, rev) + expEqual(t, true, ok) + expEqual(t, "/a", ent.Key) + expEqual(t, "c", string(ent.Value)) + expEqual(t, 0, ent.Lease) + expEqual(t, 2, ent.ModRevision) + expEqual(t, 1, ent.CreateRevision) + + rev, ent, ok, err = b.Update(ctx, "/a", []byte("d"), 2, 1) + noErr(t, err) + expEqual(t, 3, rev) + expEqual(t, true, ok) + expEqual(t, "/a", ent.Key) + expEqual(t, "d", string(ent.Value)) + expEqual(t, 1, ent.Lease) + expEqual(t, 3, ent.ModRevision) + expEqual(t, 1, ent.CreateRevision) + + // Update with wrong revision. + rev, _, ok, err = b.Update(ctx, "/a", []byte("e"), 2, 1) + noErr(t, err) + expEqual(t, 3, rev) + expEqual(t, false, ok) +} + +func TestBackend_Delete(t *testing.T) { + ns, nc, b := setupBackend(t) + defer ns.Shutdown() + defer nc.Drain() + + ctx := context.Background() + + // Create with lease. + b.Create(ctx, "/a", []byte("b"), 1) + + // Note, deleting first performs an update to tombstone + // the key, followed by a KV delete. + rev, ent, ok, err := b.Delete(ctx, "/a", 1) + noErr(t, err) + expEqual(t, 2, rev) + expEqual(t, true, ok) + expEqual(t, "/a", ent.Key) + expEqual(t, "b", string(ent.Value)) + expEqual(t, 1, ent.Lease) + expEqual(t, 1, ent.ModRevision) + expEqual(t, 1, ent.CreateRevision) + + // Create again. + b.Create(ctx, "/a", []byte("b"), 0) + + // Fail to delete since the revision is not the same. + rev, _, ok, err = b.Delete(ctx, "/a", 1) + expEqual(t, 4, rev) + expEqual(t, false, ok) + expEqualErr(t, nil, err) + + // No revision, will delete the latest. + rev, _, ok, err = b.Delete(ctx, "/a", 0) + expEqual(t, 5, rev) + expEqual(t, true, ok) + expEqualErr(t, nil, err) +} + +func TestBackend_List(t *testing.T) { + ns, nc, b := setupBackend(t) + defer ns.Shutdown() + defer nc.Drain() + + ctx := context.Background() + + // Create a key. + b.Create(ctx, "/a/b/c", nil, 0) + b.Create(ctx, "/a", nil, 0) + b.Create(ctx, "/b", nil, 0) + b.Create(ctx, "/a/b", nil, 0) + b.Create(ctx, "/c", nil, 0) + b.Create(ctx, "/d/a", nil, 0) + b.Create(ctx, "/d/b", nil, 0) + + // Wait for the btree to be updated. + time.Sleep(time.Millisecond) + + // List the keys. + rev, ents, err := b.List(ctx, "/", "", 0, 0) + noErr(t, err) + expEqual(t, 7, rev) + expEqual(t, 7, len(ents)) + expSortedKeys(t, ents) + + // List the keys with prefix. + rev, ents, err = b.List(ctx, "/a", "", 0, 0) + noErr(t, err) + expEqual(t, 7, rev) + expEqual(t, 3, len(ents)) + expSortedKeys(t, ents) + + // List the keys >= start key. + rev, ents, err = b.List(ctx, "/", "b", 0, 0) + noErr(t, err) + expEqual(t, 7, rev) + expEqual(t, 4, len(ents)) + expSortedKeys(t, ents) + + // List the keys up to a revision. + rev, ents, err = b.List(ctx, "/", "", 0, 3) + noErr(t, err) + expEqual(t, 7, rev) + expEqual(t, 3, len(ents)) + expSortedKeys(t, ents) + expEqualKeys(t, []string{"/a", "/a/b/c", "/b"}, ents) + + // List the keys with a limit. + rev, ents, err = b.List(ctx, "/", "", 4, 0) + noErr(t, err) + expEqual(t, 7, rev) + expEqual(t, 4, len(ents)) + expSortedKeys(t, ents) + expEqualKeys(t, []string{"/a", "/a/b", "/a/b/c", "/b"}, ents) + + // List the keys with a limit after some start key. + rev, ents, err = b.List(ctx, "/", "b", 2, 0) + noErr(t, err) + expEqual(t, 7, rev) + expEqual(t, 2, len(ents)) + expSortedKeys(t, ents) + expEqualKeys(t, []string{"/b", "/c"}, ents) +} + +func TestBackend_Watch(t *testing.T) { + ns, nc, b := setupBackend(t) + defer ns.Shutdown() + defer nc.Drain() + + ctx := context.Background() + + cctx, cancel := context.WithCancel(ctx) + defer cancel() + + rev1, _ := b.Create(ctx, "/a", nil, 0) + rev2, _ := b.Create(ctx, "/a/1", nil, 0) + rev1, _, _, _ = b.Update(ctx, "/a", nil, rev1, 0) + b.Delete(ctx, "/a", rev1) + b.Update(ctx, "/a/1", nil, rev2, 0) + + wr := b.Watch(cctx, "/a", 0) + time.Sleep(20 * time.Millisecond) + cancel() + + var events []*kserver.Event + for es := range wr.Events { + events = append(events, es...) + } + + expEqual(t, 5, len(events)) +} diff --git a/pkg/drivers/nats/codec.go b/pkg/drivers/nats/codec.go new file mode 100644 index 00000000..4b25b37e --- /dev/null +++ b/pkg/drivers/nats/codec.go @@ -0,0 +1,89 @@ +package nats + +import ( + "fmt" + "io" + "strings" + + "github.com/klauspost/compress/s2" + "github.com/nats-io/nats.go/jetstream" + "github.com/shengdoushi/base58" +) + +var ( + keyAlphabet = base58.BitcoinAlphabet +) + +// keyCodec turns keys like /this/is/a.test.key into Base58 encoded values +// split on `.` This is because NATS keys are split on . rather than /. +type keyCodec struct{} + +func (e *keyCodec) EncodeRange(prefix string) (string, error) { + if prefix == "/" { + return ">", nil + } + + ek, err := e.Encode(prefix) + if err != nil { + return "", err + } + + return fmt.Sprintf("%s.>", ek), nil +} + +func (*keyCodec) Encode(key string) (retKey string, e error) { + if key == "" { + return "", jetstream.ErrInvalidKey + } + + // Trim leading and trailing slashes. + key = strings.Trim(key, "/") + + var parts []string + for _, part := range strings.Split(key, "/") { + parts = append(parts, base58.Encode([]byte(part), keyAlphabet)) + } + + if len(parts) == 0 { + return "", jetstream.ErrInvalidKey + } + + return strings.Join(parts, "."), nil +} + +func (*keyCodec) Decode(key string) (retKey string, e error) { + var parts []string + + for _, s := range strings.Split(key, ".") { + decodedPart, err := base58.Decode(s, keyAlphabet) + if err != nil { + return "", err + } + parts = append(parts, string(decodedPart[:])) + } + + if len(parts) == 0 { + return "", jetstream.ErrInvalidKey + } + + return fmt.Sprintf("/%s", strings.Join(parts, "/")), nil +} + +// valueCodec is a codec that compresses values using s2. +type valueCodec struct{} + +func (*valueCodec) Encode(src []byte, dst io.Writer) error { + enc := s2.NewWriter(dst) + err := enc.EncodeBuffer(src) + if err != nil { + enc.Close() + return err + } + return enc.Close() +} + +func (*valueCodec) Decode(src io.Reader, dst io.Writer) error { + dec := s2.NewReader(src) + _, err := io.Copy(dst, dec) + return err +} diff --git a/pkg/drivers/nats/codec_test.go b/pkg/drivers/nats/codec_test.go new file mode 100644 index 00000000..6ef724c7 --- /dev/null +++ b/pkg/drivers/nats/codec_test.go @@ -0,0 +1,93 @@ +package nats + +import "testing" + +func TestKeyEncode(t *testing.T) { + tests := []struct { + In string + Out string + Err bool + }{ + {"", "", true}, + {"/", "", true}, + {"a", "2g", false}, + {"/a/a", "2g.2g", false}, + {"a/a", "2g.2g", false}, + {"a/a/a", "2g.2g.2g", false}, + {"a/*/a", "2g.j.2g", false}, + {"a/*/a/", "2g.j.2g", false}, + } + + codec := &keyCodec{} + + for _, test := range tests { + out, err := codec.Encode(test.In) + if err != nil { + if !test.Err { + t.Errorf("Expected no error for %q, got %v", test.In, err) + } + continue + } + if out != test.Out { + t.Errorf("Expected %q for %q, got %q", test.Out, test.In, out) + } + } +} + +func TestKeyDecode(t *testing.T) { + tests := []struct { + In string + Out string + Err bool + }{ + {"", "/", false}, + {"2g", "/a", false}, + {"2g.2g", "/a/a", false}, + {"2g.2g.2g", "/a/a/a", false}, + } + + codec := &keyCodec{} + + for _, test := range tests { + out, err := codec.Decode(test.In) + if err != nil { + if !test.Err { + t.Errorf("Expected no error for %q, got %v", test.In, err) + } + continue + } + if out != test.Out { + t.Errorf("Expected %q for %q, got %q", test.Out, test.In, out) + } + } +} + +func TestKeyEncodeRange(t *testing.T) { + tests := []struct { + In string + Out string + Err bool + }{ + {"", "", true}, + {"/", ">", false}, + {"a", "2g.>", false}, + {"/a/a", "2g.2g.>", false}, + {"a/a/a", "2g.2g.2g.>", false}, + {"a/*/a", "2g.j.2g.>", false}, + } + + codec := &keyCodec{} + + for _, test := range tests { + out, err := codec.EncodeRange(test.In) + if err != nil { + if !test.Err { + t.Errorf("Expected no error for %q, got %v", test.In, err) + } + continue + } + if out != test.Out { + t.Errorf("Expected %q for %q, got %q", test.Out, test.In, out) + } + } +} diff --git a/pkg/drivers/nats/config.go b/pkg/drivers/nats/config.go new file mode 100644 index 00000000..5480a6b3 --- /dev/null +++ b/pkg/drivers/nats/config.go @@ -0,0 +1,187 @@ +package nats + +import ( + "fmt" + "net/url" + "strconv" + "strings" + "time" + + natsserver "github.com/k3s-io/kine/pkg/drivers/nats/server" + "github.com/k3s-io/kine/pkg/tls" + "github.com/nats-io/jsm.go/natscontext" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" +) + +type Config struct { + // Client URL which could be a list of comma separated URLs. + clientURL string + // Client connection options. + clientOptions []nats.Option + // Number of revisions to keep in history. Defaults to 10. + revHistory uint8 + // Name of the bucket. Defaults to "kine". + bucket string + // Number of replicas for the bucket. Defaults to 1 + replicas int + // Indicates the duration of a method before it is considered slow. Defaults to 500ms. + slowThreshold time.Duration + // If true, an embedded server will not be used. + noEmbed bool + // If true, use a socket for the embedded server. + dontListen bool + // Path to a server configuration file when embedded. + serverConfig string + // If true, the embedded server will log to stdout. + stdoutLogging bool + // The explicit host to listen on when embedded. + host string + // The explicit port to listen on when embedded. + port int + // Data directory. + dataDir string +} + +// parseConnection returns nats connection url, bucketName and []nats.Option, error +func parseConnection(dsn string, tlsInfo tls.Config) (*Config, error) { + config := &Config{ + slowThreshold: defaultSlowMethod, + revHistory: defaultRevHistory, + bucket: defaultBucket, + replicas: defaultReplicas, + } + + // Parse the first URL in the connection string which contains the + // query parameters. + connections := strings.Split(dsn, ",") + u, err := url.Parse(connections[0]) + if err != nil { + return nil, err + } + + // Extract the host and port if embedded server is used. + config.host = u.Hostname() + if u.Port() != "" { + config.port, _ = strconv.Atoi(u.Port()) + } + + // Extract the query parameters to build configuration. + queryMap, err := url.ParseQuery(u.RawQuery) + if err != nil { + return nil, err + } + + if v := queryMap.Get("bucket"); v != "" { + config.bucket = v + } + + if v := queryMap.Get("replicas"); v != "" { + if r, err := strconv.ParseUint(v, 10, 8); err == nil { + if r >= 1 && r <= 5 { + config.replicas = int(r) + } else { + return nil, fmt.Errorf("invalid replicas, must be >= 1 and <= 5") + } + } + } + + if d := queryMap.Get("slowMethod"); d != "" { + if dur, err := time.ParseDuration(d); err == nil { + config.slowThreshold = dur + } else { + return nil, fmt.Errorf("invalid slowMethod duration: %w", err) + } + } + + if r := queryMap.Get("revHistory"); r != "" { + if revs, err := strconv.ParseUint(r, 10, 8); err == nil { + if revs >= 2 && revs <= 64 { + config.revHistory = uint8(revs) + } else { + return nil, fmt.Errorf("invalid revHistory, must be >= 2 and <= 64") + } + } + } + + if tlsInfo.KeyFile != "" && tlsInfo.CertFile != "" { + config.clientOptions = append(config.clientOptions, nats.ClientCert(tlsInfo.CertFile, tlsInfo.KeyFile)) + } + + if tlsInfo.CAFile != "" { + config.clientOptions = append(config.clientOptions, nats.RootCAs(tlsInfo.CAFile)) + } + + // Simpler direct reference to creds file. + if f := queryMap.Get("credsFile"); f != "" { + config.clientOptions = append(config.clientOptions, nats.UserCredentials(f)) + } + + // Reference a full context file. Note this will override any other options. + if f := queryMap.Get("contextFile"); f != "" { + if u.Host != "" { + return config, fmt.Errorf("when using context endpoint no host should be provided") + } + + logrus.Debugf("loading nats context file: %s", f) + + natsContext, err := natscontext.NewFromFile(f) + if err != nil { + return nil, err + } + + connections = strings.Split(natsContext.ServerURL(), ",") + + // command line options provided to kine will override the file + // https://github.com/nats-io/jsm.go/blob/v0.0.29/natscontext/context.go#L257 + // allows for user, creds, nke, token, certifcate, ca, inboxprefix from the context.json + natsClientOpts, err := natsContext.NATSOptions(config.clientOptions...) + if err != nil { + return nil, err + } + config.clientOptions = natsClientOpts + } + + connBuilder := strings.Builder{} + for idx, c := range connections { + if idx > 0 { + connBuilder.WriteString(",") + } + + u, err := url.Parse(c) + if err != nil { + return nil, err + } + + if u.Scheme != "nats" { + return nil, fmt.Errorf("invalid connection string=%s", c) + } + + connBuilder.WriteString("nats://") + + if u.User != nil && idx == 0 { + userInfo := strings.Split(u.User.String(), ":") + if len(userInfo) > 1 { + config.clientOptions = append(config.clientOptions, nats.UserInfo(userInfo[0], userInfo[1])) + } else { + config.clientOptions = append(config.clientOptions, nats.Token(userInfo[0])) + } + } + connBuilder.WriteString(u.Host) + } + + config.clientURL = connBuilder.String() + + // Config options only relevant if built with embedded NATS. + if natsserver.Embedded { + config.noEmbed = queryMap.Has("noEmbed") + config.serverConfig = queryMap.Get("serverConfig") + config.stdoutLogging = queryMap.Has("stdoutLogging") + config.dontListen = queryMap.Has("dontListen") + config.dataDir = queryMap.Get("dataDir") + } + + logrus.Debugf("using config %#v", config) + + return config, nil +} diff --git a/pkg/drivers/nats/kv.go b/pkg/drivers/nats/kv.go new file mode 100644 index 00000000..9a67eebb --- /dev/null +++ b/pkg/drivers/nats/kv.go @@ -0,0 +1,512 @@ +package nats + +import ( + "bytes" + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/nats-io/nats.go/jetstream" + "github.com/sirupsen/logrus" + "github.com/tidwall/btree" +) + +type entry struct { + kc *keyCodec + vc *valueCodec + entry jetstream.KeyValueEntry +} + +func (e *entry) Key() string { + dk, err := e.kc.Decode(e.entry.Key()) + // should not happen + if err != nil { + // should not happen + logrus.Warnf("could not decode key %s: %v", e.entry.Key(), err) + return "" + } + + return dk +} + +func (e *entry) Bucket() string { return e.entry.Bucket() } +func (e *entry) Value() []byte { + buf := new(bytes.Buffer) + if err := e.vc.Decode(bytes.NewBuffer(e.entry.Value()), buf); err != nil { + // should not happen + logrus.Warnf("could not decode value for %s: %v", e.Key(), err) + } + return buf.Bytes() +} +func (e *entry) Revision() uint64 { return e.entry.Revision() } +func (e *entry) Created() time.Time { return e.entry.Created() } +func (e *entry) Delta() uint64 { return e.entry.Delta() } +func (e *entry) Operation() jetstream.KeyValueOp { return e.entry.Operation() } + +type seqOp struct { + seq uint64 + op jetstream.KeyValueOp + ex time.Time +} + +type streamWatcher struct { + con jetstream.Consumer + cctx jetstream.ConsumeContext + keyCodec *keyCodec + valueCodec *valueCodec + updates chan jetstream.KeyValueEntry + keyPrefix string + ctx context.Context + cancel context.CancelFunc +} + +func (w *streamWatcher) Context() context.Context { + if w == nil { + return nil + } + return w.ctx +} + +func (w *streamWatcher) Updates() <-chan jetstream.KeyValueEntry { + return w.updates +} + +func (w *streamWatcher) Stop() error { + if w.cancel != nil { + w.cancel() + } + if w.cctx != nil { + w.cctx.Stop() + } + return nil +} + +type kvEntry struct { + key string + bucket string + value []byte + revision uint64 + created time.Time + delta uint64 + operation jetstream.KeyValueOp +} + +func (e *kvEntry) Key() string { + return e.key +} + +func (e *kvEntry) Bucket() string { return e.bucket } +func (e *kvEntry) Value() []byte { + return e.value +} +func (e *kvEntry) Revision() uint64 { return e.revision } +func (e *kvEntry) Created() time.Time { return e.created } +func (e *kvEntry) Delta() uint64 { return e.delta } +func (e *kvEntry) Operation() jetstream.KeyValueOp { return e.operation } + +type KeyValue struct { + nkv jetstream.KeyValue + js jetstream.JetStream + kc *keyCodec + vc *valueCodec + bt *btree.Map[string, []*seqOp] + btm sync.RWMutex + lastSeq uint64 +} + +func (e *KeyValue) Get(ctx context.Context, key string) (jetstream.KeyValueEntry, error) { + ek, err := e.kc.Encode(key) + if err != nil { + return nil, err + } + + ent, err := e.nkv.Get(ctx, ek) + if err != nil { + return nil, err + } + + return &entry{ + kc: e.kc, + vc: e.vc, + entry: ent, + }, nil +} + +func (e *KeyValue) GetRevision(ctx context.Context, key string, revision uint64) (jetstream.KeyValueEntry, error) { + ek, err := e.kc.Encode(key) + if err != nil { + return nil, err + } + + ent, err := e.nkv.GetRevision(ctx, ek, revision) + if err != nil { + return nil, err + } + + return &entry{ + kc: e.kc, + vc: e.vc, + entry: ent, + }, nil +} + +func (e *KeyValue) Create(ctx context.Context, key string, value []byte) (uint64, error) { + ek, err := e.kc.Encode(key) + if err != nil { + return 0, err + } + + buf := new(bytes.Buffer) + + err = e.vc.Encode(value, buf) + if err != nil { + return 0, err + } + + return e.nkv.Create(ctx, ek, buf.Bytes()) +} + +func (e *KeyValue) Update(ctx context.Context, key string, value []byte, last uint64) (uint64, error) { + ek, err := e.kc.Encode(key) + if err != nil { + return 0, err + } + + buf := new(bytes.Buffer) + + err = e.vc.Encode(value, buf) + if err != nil { + return 0, err + } + + return e.nkv.Update(ctx, ek, buf.Bytes(), last) +} + +func (e *KeyValue) Delete(ctx context.Context, key string, opts ...jetstream.KVDeleteOpt) error { + ek, err := e.kc.Encode(key) + if err != nil { + return err + } + + return e.nkv.Delete(ctx, ek, opts...) +} + +func (e *KeyValue) Watch(ctx context.Context, keys string, startRev int64) (jetstream.KeyWatcher, error) { + // Everything but the last token will be treated as a filter + // on the watcher. The last token will used as a deliver-time filter. + filter := keys + if !strings.HasSuffix(filter, "/") { + idx := strings.LastIndexByte(filter, '/') + if idx > -1 { + filter = keys[:idx+1] + } + } + + if filter != "" { + p, err := e.kc.EncodeRange(filter) + if err != nil { + return nil, err + } + filter = fmt.Sprintf("$KV.%s.%s", e.nkv.Bucket(), p) + } + + wctx, cancel := context.WithCancel(ctx) + + updates := make(chan jetstream.KeyValueEntry, 100) + subjectPrefix := fmt.Sprintf("$KV.%s.", e.nkv.Bucket()) + + handler := func(msg jetstream.Msg) { + md, _ := msg.Metadata() + key := strings.TrimPrefix(msg.Subject(), subjectPrefix) + + if keys != "" { + dkey, err := e.kc.Decode(strings.TrimPrefix(key, ".")) + if err != nil || !strings.HasPrefix(dkey, keys) { + return + } + } + + // Default is PUT + var op jetstream.KeyValueOp + switch msg.Headers().Get("KV-Operation") { + case "DEL": + op = jetstream.KeyValueDelete + case "PURGE": + op = jetstream.KeyValuePurge + } + // Not currently used... + delta := 0 + + updates <- &entry{ + kc: e.kc, + vc: e.vc, + entry: &kvEntry{ + key: key, + bucket: e.nkv.Bucket(), + value: msg.Data(), + revision: md.Sequence.Stream, + created: md.Timestamp, + delta: uint64(delta), + operation: op, + }, + } + } + + var dp jetstream.DeliverPolicy + var cfg jetstream.OrderedConsumerConfig + if startRev <= 0 { + dp = jetstream.DeliverAllPolicy + } else { + dp = jetstream.DeliverByStartSequencePolicy + cfg.OptStartSeq = uint64(startRev) + } + cfg.DeliverPolicy = dp + + con, err := e.js.OrderedConsumer(ctx, fmt.Sprintf("KV_%s", e.nkv.Bucket()), cfg) + if err != nil { + cancel() + return nil, err + } + + ci := con.CachedInfo() + cctx, err := con.Consume(handler, + jetstream.ConsumeErrHandler(func(cctx jetstream.ConsumeContext, err error) { + if !strings.Contains(err.Error(), "Server Shutdown") { + logrus.Warnf("error consuming from %s: %v", ci.Name, err) + } + }), + ) + if err != nil { + cancel() + return nil, err + } + + w := &streamWatcher{ + con: con, + cctx: cctx, + keyCodec: e.kc, + valueCodec: e.vc, + updates: updates, + ctx: wctx, + cancel: cancel, + } + + return w, nil +} + +// BucketSize returns the size of the bucket in bytes. +func (e *KeyValue) BucketSize(ctx context.Context) (int64, error) { + status, err := e.nkv.Status(ctx) + if err != nil { + return 0, err + } + return int64(status.Bytes()), nil +} + +// BucketRevision returns the latest revision of the bucket. +func (e *KeyValue) BucketRevision() int64 { + e.btm.RLock() + s := e.lastSeq + e.btm.RUnlock() + return int64(s) +} + +func (e *KeyValue) btreeWatcher(ctx context.Context) error { + w, err := e.Watch(ctx, "/", int64(e.lastSeq)) + if err != nil { + return err + } + defer w.Stop() + + status, _ := e.nkv.Status(ctx) + hsize := status.History() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + + case x := <-w.Updates(): + if x == nil { + continue + } + + seq := x.Revision() + op := x.Operation() + + key := x.Key() + + var ex time.Time + if op == jetstream.KeyValuePut { + var nd natsData + err = nd.Decode(x) + if err != nil { + continue + } + if nd.KV.Lease > 0 { + ex = nd.CreateTime.Add(time.Second * time.Duration(nd.KV.Lease)) + } + } + + e.btm.Lock() + e.lastSeq = seq + val, ok := e.bt.Get(key) + if !ok { + val = make([]*seqOp, 0, hsize) + } + // Remove the oldest entry. + if len(val) == cap(val) { + val = append(val[:0], val[1:]...) + } + val = append(val, &seqOp{ + seq: seq, + op: op, + ex: ex, + }) + e.bt.Set(key, val) + e.btm.Unlock() + } + } +} + +type keySeq struct { + key string + seq uint64 +} + +func (e *KeyValue) Count(ctx context.Context, prefix string) (int64, error) { + it := e.bt.Iter() + + if prefix != "" { + ok := it.Seek(prefix) + if !ok { + return 0, nil + } + } + + var count int64 + now := time.Now() + + e.btm.RLock() + for { + k := it.Key() + if !strings.HasPrefix(k, prefix) { + break + } + v := it.Value() + so := v[len(v)-1] + + if so.op == jetstream.KeyValuePut { + if so.ex.IsZero() || so.ex.After(now) { + count++ + } + } + + if !it.Next() { + break + } + } + e.btm.RUnlock() + + return count, nil +} + +func (e *KeyValue) List(ctx context.Context, prefix, startKey string, limit, revision int64) ([]jetstream.KeyValueEntry, error) { + seekKey := prefix + if startKey != "" { + seekKey = strings.TrimSuffix(seekKey, "/") + seekKey = fmt.Sprintf("%s/%s", seekKey, startKey) + } + + it := e.bt.Iter() + if seekKey != "" { + ok := it.Seek(seekKey) + if !ok { + return nil, nil + } + } + + var matches []*keySeq + + e.btm.RLock() + + for { + if limit > 0 && len(matches) == int(limit) { + break + } + + k := it.Key() + if !strings.HasPrefix(k, prefix) { + break + } + + v := it.Value() + + // Get the latest update for the key. + if revision <= 0 { + so := v[len(v)-1] + if so.op == jetstream.KeyValuePut { + if so.ex.IsZero() || so.ex.After(time.Now()) { + matches = append(matches, &keySeq{key: k, seq: so.seq}) + } + } + } else { + // Find the latest update below the given revision. + for i := len(v) - 1; i >= 0; i-- { + so := v[i] + if so.seq <= uint64(revision) { + if so.op == jetstream.KeyValuePut { + if so.ex.IsZero() || so.ex.After(time.Now()) { + matches = append(matches, &keySeq{key: k, seq: so.seq}) + } + } + break + } + } + } + + if !it.Next() { + break + } + } + e.btm.RUnlock() + + logrus.Debugf("kv: list: got %d matches from btree", len(matches)) + + var entries []jetstream.KeyValueEntry + for _, m := range matches { + e, err := e.GetRevision(ctx, m.key, m.seq) + if err != nil { + logrus.Errorf("get revision in list error: %s @ %d: %v", m.key, m.seq, err) + continue + } + entries = append(entries, e) + } + + return entries, nil +} + +func NewKeyValue(ctx context.Context, bucket jetstream.KeyValue, js jetstream.JetStream) *KeyValue { + kv := &KeyValue{ + nkv: bucket, + js: js, + kc: &keyCodec{}, + vc: &valueCodec{}, + bt: btree.NewMap[string, []*seqOp](0), + } + + go func() { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + for { + err := kv.btreeWatcher(ctx) + if err != nil { + logrus.Errorf("btree watcher error: %v", err) + } + } + }() + + return kv +} diff --git a/pkg/drivers/nats/kv/etcd_encoder.go b/pkg/drivers/nats/kv/etcd_encoder.go deleted file mode 100644 index ab39820a..00000000 --- a/pkg/drivers/nats/kv/etcd_encoder.go +++ /dev/null @@ -1,103 +0,0 @@ -package kv - -import ( - "fmt" - "io" - "strings" - - "github.com/klauspost/compress/s2" - "github.com/nats-io/nats.go" - "github.com/shengdoushi/base58" -) - -// EtcdKeyCodec turns keys like /this/is/a.test.key into Base58 encoded values split on `/` -// This is because NATS Jetstream Keys are split on . rather than / -type EtcdKeyCodec struct{} - -type S2ValueCodec struct{} - -type PlainCodec struct{} - -var ( - keyAlphabet = base58.BitcoinAlphabet -) - -func (e *EtcdKeyCodec) EncodeRange(keys string) (string, error) { - ek, err := e.Encode(keys) - if err != nil { - return "", err - } - if strings.HasSuffix(ek, ".") { - return fmt.Sprintf("%s>", ek), nil - } - return ek, nil -} - -func (*EtcdKeyCodec) Encode(key string) (retKey string, e error) { - //defer func() { - // logrus.Debugf("encoded %s => %s", key, retKey) - //}() - parts := []string{} - for _, part := range strings.Split(strings.TrimPrefix(key, "/"), "/") { - if part == ">" || part == "*" { - parts = append(parts, part) - continue - } - parts = append(parts, base58.Encode([]byte(part), keyAlphabet)) - } - - if len(parts) == 0 { - return "", nats.ErrInvalidKey - } - - return strings.Join(parts, "."), nil -} - -func (*EtcdKeyCodec) Decode(key string) (retKey string, e error) { - //defer func() { - // logrus.Debugf("decoded %s => %s", key, retKey) - //}() - parts := []string{} - for _, s := range strings.Split(key, ".") { - decodedPart, err := base58.Decode(s, keyAlphabet) - if err != nil { - return "", err - } - parts = append(parts, string(decodedPart[:])) - } - if len(parts) == 0 { - return "", nats.ErrInvalidKey - } - return fmt.Sprintf("/%s", strings.Join(parts, "/")), nil -} - -func (*S2ValueCodec) Encode(src []byte, dst io.Writer) error { - enc := s2.NewWriter(dst) - err := enc.EncodeBuffer(src) - if err != nil { - enc.Close() - return err - } - return enc.Close() -} - -func (*S2ValueCodec) Decode(src io.Reader, dst io.Writer) error { - dec := s2.NewReader(src) - _, err := io.Copy(dst, dec) - return err -} - -func (*PlainCodec) Encode(src []byte, dst io.Writer) error { - _, err := dst.Write(src) - return err -} - -func (*PlainCodec) Decode(src io.Reader, dst io.Writer) error { - b, err := io.ReadAll(src) - if err != nil { - return err - } - _, err = dst.Write(b) - - return err -} diff --git a/pkg/drivers/nats/kv/kv.go b/pkg/drivers/nats/kv/kv.go deleted file mode 100644 index 767ec322..00000000 --- a/pkg/drivers/nats/kv/kv.go +++ /dev/null @@ -1,293 +0,0 @@ -package kv - -import ( - "bytes" - "context" - "io" - "time" - - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" -) - -func NewEncodedKV(bucket nats.KeyValue, k KeyCodec, v ValueCodec) nats.KeyValue { - return &EncodedKV{bucket: bucket, keyCodec: k, valueCodec: v} -} - -type WatcherWithCtx interface { - WatchWithCtx(ctx context.Context, keys string, opts ...nats.WatchOpt) nats.KeyWatcher -} - -type KeyCodec interface { - Encode(key string) (string, error) - Decode(key string) (string, error) - EncodeRange(keys string) (string, error) -} - -type ValueCodec interface { - Encode(src []byte, dst io.Writer) error - Decode(src io.Reader, dst io.Writer) error -} - -type EncodedKV struct { - WatcherWithCtx - bucket nats.KeyValue - keyCodec KeyCodec - valueCodec ValueCodec -} - -type watcher struct { - watcher nats.KeyWatcher - keyCodec KeyCodec - valueCodec ValueCodec - updates chan nats.KeyValueEntry - ctx context.Context - cancel context.CancelFunc -} - -func (w *watcher) Context() context.Context { - if w == nil { - return nil - } - return w.ctx -} - -type entry struct { - keyCodec KeyCodec - valueCodec ValueCodec - entry nats.KeyValueEntry -} - -func (e *entry) Key() string { - dk, err := e.keyCodec.Decode(e.entry.Key()) - // should not happen - if err != nil { - // should not happen - logrus.Warnf("could not decode key %s: %v", e.entry.Key(), err) - return "" - } - - return dk -} - -func (e *entry) Bucket() string { return e.entry.Bucket() } -func (e *entry) Value() []byte { - buf := new(bytes.Buffer) - if err := e.valueCodec.Decode(bytes.NewBuffer(e.entry.Value()), buf); err != nil { - // should not happen - logrus.Warnf("could not decode value for %s: %v", e.Key(), err) - } - return buf.Bytes() -} -func (e *entry) Revision() uint64 { return e.entry.Revision() } -func (e *entry) Created() time.Time { return e.entry.Created() } -func (e *entry) Delta() uint64 { return e.entry.Delta() } -func (e *entry) Operation() nats.KeyValueOp { return e.entry.Operation() } - -func (w *watcher) Updates() <-chan nats.KeyValueEntry { return w.updates } -func (w *watcher) Stop() error { - if w.cancel != nil { - w.cancel() - } - - return w.watcher.Stop() -} - -func (e *EncodedKV) newWatcher(w nats.KeyWatcher) nats.KeyWatcher { - watch := &watcher{ - watcher: w, - keyCodec: e.keyCodec, - valueCodec: e.valueCodec, - updates: make(chan nats.KeyValueEntry, 32)} - - if w.Context() == nil { - watch.ctx, watch.cancel = context.WithCancel(context.Background()) - } else { - watch.ctx, watch.cancel = context.WithCancel(w.Context()) - } - - go func() { - for { - select { - case ent := <-w.Updates(): - if ent == nil { - watch.updates <- nil - continue - } - - watch.updates <- &entry{ - keyCodec: e.keyCodec, - valueCodec: e.valueCodec, - entry: ent, - } - case <-watch.ctx.Done(): - return - } - } - }() - - return watch -} - -func (e *EncodedKV) Get(key string) (nats.KeyValueEntry, error) { - ek, err := e.keyCodec.Encode(key) - if err != nil { - return nil, err - } - - ent, err := e.bucket.Get(ek) - if err != nil { - return nil, err - } - - return &entry{ - keyCodec: e.keyCodec, - valueCodec: e.valueCodec, - entry: ent, - }, nil -} - -func (e *EncodedKV) GetRevision(key string, revision uint64) (nats.KeyValueEntry, error) { - ek, err := e.keyCodec.Encode(key) - if err != nil { - return nil, err - } - - ent, err := e.bucket.GetRevision(ek, revision) - if err != nil { - return nil, err - } - - return &entry{ - keyCodec: e.keyCodec, - valueCodec: e.valueCodec, - entry: ent, - }, nil -} - -func (e *EncodedKV) Put(key string, value []byte) (revision uint64, err error) { - ek, err := e.keyCodec.Encode(key) - if err != nil { - return 0, err - } - - buf := new(bytes.Buffer) - - err = e.valueCodec.Encode(value, buf) - if err != nil { - return 0, err - } - - return e.bucket.Put(ek, buf.Bytes()) -} - -func (e *EncodedKV) Create(key string, value []byte) (revision uint64, err error) { - ek, err := e.keyCodec.Encode(key) - if err != nil { - return 0, err - } - - buf := new(bytes.Buffer) - - err = e.valueCodec.Encode(value, buf) - if err != nil { - return 0, err - } - - return e.bucket.Create(ek, buf.Bytes()) -} - -func (e *EncodedKV) Update(key string, value []byte, last uint64) (revision uint64, err error) { - ek, err := e.keyCodec.Encode(key) - if err != nil { - return 0, err - } - - buf := new(bytes.Buffer) - - err = e.valueCodec.Encode(value, buf) - if err != nil { - return 0, err - } - - return e.bucket.Update(ek, buf.Bytes(), last) -} - -func (e *EncodedKV) Delete(key string, opts ...nats.DeleteOpt) error { - ek, err := e.keyCodec.Encode(key) - if err != nil { - return err - } - - return e.bucket.Delete(ek, opts...) -} - -func (e *EncodedKV) Purge(key string, opts ...nats.DeleteOpt) error { - ek, err := e.keyCodec.Encode(key) - if err != nil { - return err - } - - return e.bucket.Purge(ek, opts...) -} - -func (e *EncodedKV) Watch(keys string, opts ...nats.WatchOpt) (nats.KeyWatcher, error) { - ek, err := e.keyCodec.EncodeRange(keys) - if err != nil { - return nil, err - } - - nw, err := e.bucket.Watch(ek, opts...) - if err != nil { - return nil, err - } - - return e.newWatcher(nw), err -} - -func (e *EncodedKV) History(key string, opts ...nats.WatchOpt) ([]nats.KeyValueEntry, error) { - ek, err := e.keyCodec.Encode(key) - if err != nil { - return nil, err - } - - var res []nats.KeyValueEntry - hist, err := e.bucket.History(ek, opts...) - if err != nil { - return nil, err - } - - for _, ent := range hist { - res = append(res, &entry{e.keyCodec, e.valueCodec, ent}) - } - - return res, nil -} - -func (e *EncodedKV) PutString(key string, value string) (revision uint64, err error) { - return e.Put(key, []byte(value)) -} -func (e *EncodedKV) WatchAll(opts ...nats.WatchOpt) (nats.KeyWatcher, error) { - return e.bucket.WatchAll(opts...) -} -func (e *EncodedKV) Keys(opts ...nats.WatchOpt) ([]string, error) { - keys, err := e.bucket.Keys(opts...) - if err != nil { - return nil, err - } - var res []string - for _, key := range keys { - dk, err := e.keyCodec.Decode(key) - if err != nil { - // should not happen - logrus.Warnf("error decoding %s: %v", key, err) - } - res = append(res, dk) - } - - return res, nil -} - -func (e *EncodedKV) Bucket() string { return e.bucket.Bucket() } -func (e *EncodedKV) PurgeDeletes(opts ...nats.PurgeOpt) error { return e.bucket.PurgeDeletes(opts...) } -func (e *EncodedKV) Status() (nats.KeyValueStatus, error) { return e.bucket.Status() } diff --git a/pkg/drivers/nats/logger.go b/pkg/drivers/nats/logger.go new file mode 100644 index 00000000..543384fd --- /dev/null +++ b/pkg/drivers/nats/logger.go @@ -0,0 +1,117 @@ +package nats + +import ( + "context" + "time" + + "github.com/k3s-io/kine/pkg/server" + "github.com/sirupsen/logrus" +) + +var ( + _ server.Backend = &BackendLogger{} +) + +type BackendLogger struct { + logger *logrus.Logger + backend server.Backend + threshold time.Duration +} + +func (b *BackendLogger) logMethod(dur time.Duration, str string, args ...any) { + if dur > b.threshold { + b.logger.Warnf(str, args...) + } else { + b.logger.Tracef(str, args...) + } +} + +func (b *BackendLogger) Start(ctx context.Context) error { + return b.backend.Start(ctx) +} + +// Get returns the store's current revision, the associated server.KeyValue or an error. +func (b *BackendLogger) Get(ctx context.Context, key, rangeEnd string, limit, revision int64) (revRet int64, kvRet *server.KeyValue, errRet error) { + start := time.Now() + defer func() { + dur := time.Since(start) + size := 0 + if kvRet != nil { + size = len(kvRet.Value) + } + fStr := "GET %s, rev=%d => revRet=%d, kv=%v, size=%d, err=%v, duration=%s" + b.logMethod(dur, fStr, key, revision, revRet, kvRet != nil, size, errRet, dur) + }() + + return b.backend.Get(ctx, key, rangeEnd, limit, revision) +} + +// Create attempts to create the key-value entry and returns the revision number. +func (b *BackendLogger) Create(ctx context.Context, key string, value []byte, lease int64) (revRet int64, errRet error) { + start := time.Now() + defer func() { + dur := time.Since(start) + fStr := "CREATE %s, size=%d, lease=%d => rev=%d, err=%v, duration=%s" + b.logMethod(dur, fStr, key, len(value), lease, revRet, errRet, dur) + }() + + return b.backend.Create(ctx, key, value, lease) +} + +func (b *BackendLogger) Delete(ctx context.Context, key string, revision int64) (revRet int64, kvRet *server.KeyValue, deletedRet bool, errRet error) { + start := time.Now() + defer func() { + dur := time.Since(start) + fStr := "DELETE %s, rev=%d => rev=%d, kv=%v, deleted=%v, err=%v, duration=%s" + b.logMethod(dur, fStr, key, revision, revRet, kvRet != nil, deletedRet, errRet, dur) + }() + + return b.backend.Delete(ctx, key, revision) +} + +func (b *BackendLogger) List(ctx context.Context, prefix, startKey string, limit, revision int64) (revRet int64, kvRet []*server.KeyValue, errRet error) { + start := time.Now() + defer func() { + dur := time.Since(start) + fStr := "LIST %s, start=%s, limit=%d, rev=%d => rev=%d, kvs=%d, err=%v, duration=%s" + b.logMethod(dur, fStr, prefix, startKey, limit, revision, revRet, len(kvRet), errRet, dur) + }() + + return b.backend.List(ctx, prefix, startKey, limit, revision) +} + +// Count returns an exact count of the number of matching keys and the current revision of the database +func (b *BackendLogger) Count(ctx context.Context, prefix string) (revRet int64, count int64, err error) { + start := time.Now() + defer func() { + dur := time.Since(start) + fStr := "COUNT %s => rev=%d, count=%d, err=%v, duration=%s" + b.logMethod(dur, fStr, prefix, revRet, count, err, dur) + }() + + return b.backend.Count(ctx, prefix) +} + +func (b *BackendLogger) Update(ctx context.Context, key string, value []byte, revision, lease int64) (revRet int64, kvRet *server.KeyValue, updateRet bool, errRet error) { + start := time.Now() + defer func() { + dur := time.Since(start) + kvRev := int64(0) + if kvRet != nil { + kvRev = kvRet.ModRevision + } + fStr := "UPDATE %s, value=%d, rev=%d, lease=%v => rev=%d, kvrev=%d, updated=%v, err=%v, duration=%s" + b.logMethod(dur, fStr, key, len(value), revision, lease, revRet, kvRev, updateRet, errRet, dur) + }() + + return b.backend.Update(ctx, key, value, revision, lease) +} + +func (b *BackendLogger) Watch(ctx context.Context, prefix string, revision int64) server.WatchResult { + return b.backend.Watch(ctx, prefix, revision) +} + +// DbSize get the kineBucket size from JetStream. +func (b *BackendLogger) DbSize(ctx context.Context) (int64, error) { + return b.backend.DbSize(ctx) +} diff --git a/pkg/drivers/nats/nats.go b/pkg/drivers/nats/nats.go deleted file mode 100644 index ee5f6b37..00000000 --- a/pkg/drivers/nats/nats.go +++ /dev/null @@ -1,1064 +0,0 @@ -package nats - -import ( - "context" - "encoding/json" - "fmt" - "net/url" - "os" - "os/signal" - "regexp" - "sort" - "strconv" - "strings" - "sync" - "time" - - "github.com/k3s-io/kine/pkg/drivers/nats/kv" - natsserver "github.com/k3s-io/kine/pkg/drivers/nats/server" - "github.com/k3s-io/kine/pkg/server" - "github.com/k3s-io/kine/pkg/tls" - "github.com/nats-io/jsm.go/natscontext" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" -) - -const ( - defaultBucket = "kine" - defaultReplicas = 1 - defaultRevHistory = 10 - defaultSlowMethod = 500 * time.Millisecond -) - -var ( - toplevelKeyMatch = regexp.MustCompile(`(/[^/]*/[^/]*)(/.*)?`) -) - -type Config struct { - // Client URL which could be a list of comma separated URLs. - clientURL string - // Client connection options. - clientOptions []nats.Option - // Number of revisions to keep in history. Defaults to 10. - revHistory uint8 - // Name of the bucket. Defaults to "kine". - bucket string - // Number of replicas for the bucket. Defaults to 1 - replicas int - // Indicates the duration of a method before it is considered slow. Defaults to 500ms. - slowThreshold time.Duration - // If true, an embedded server will not be used. - noEmbed bool - // If true, use a socket for the embedded server. - dontListen bool - // Path to a server configuration file when embedded. - serverConfig string - // If true, the embedded server will log to stdout. - stdoutLogging bool - // The explicit host to listen on when embedded. - host string - // The explicit port to listen on when embedded. - port int -} - -type Driver struct { - nc *nats.Conn - js nats.JetStreamContext - kv nats.KeyValue - - dirMu *sync.RWMutex - subMus map[string]*sync.RWMutex - - slowThreshold time.Duration -} - -func (d *Driver) logMethod(dur time.Duration, str string, args ...any) { - if dur > d.slowThreshold { - logrus.Warnf(str, args...) - } else { - logrus.Tracef(str, args...) - } -} - -func getTopLevelKey(key string) string { - if toplevelKeyMatch.MatchString(key) { - matches := toplevelKeyMatch.FindStringSubmatch(key) - return matches[1] - } - return "" -} - -func (d *Driver) lockFolder(key string) (unlock func()) { - lockFolder := getTopLevelKey(key) - if lockFolder == "" { - return func() {} - } - - d.dirMu.Lock() - mu, ok := d.subMus[lockFolder] - if !ok { - mu = &sync.RWMutex{} - d.subMus[lockFolder] = mu - } - d.dirMu.Unlock() - mu.Lock() - return mu.Unlock -} - -type JSValue struct { - KV *server.KeyValue - PrevRevision int64 - Create bool - Delete bool -} - -// New return an implementation of server.Backend using NATS + JetStream. -// See the `examples/nats.md` file for examples of connection strings. -func New(ctx context.Context, connection string, tlsInfo tls.Config) (server.Backend, error) { - return newBackend(ctx, connection, tlsInfo, false) -} - -// NewLegacy return an implementation of server.Backend using NATS + JetStream -// with legacy jetstream:// behavior, ignoring the embedded server. -func NewLegacy(ctx context.Context, connection string, tlsInfo tls.Config) (server.Backend, error) { - return newBackend(ctx, connection, tlsInfo, true) -} - -func newBackend(ctx context.Context, connection string, tlsInfo tls.Config, legacy bool) (server.Backend, error) { - config, err := parseConnection(connection, tlsInfo) - if err != nil { - return nil, err - } - - nopts := append(config.clientOptions, nats.Name("kine using bucket: "+config.bucket)) - - // Run an embedded server if available and not disabled. - if !legacy && natsserver.Embedded && !config.noEmbed { - logrus.Infof("using an embedded NATS server") - - ns, err := natsserver.New(&natsserver.Config{ - Host: config.host, - Port: config.port, - ConfigFile: config.serverConfig, - DontListen: config.dontListen, - StdoutLogging: config.stdoutLogging, - }) - if err != nil { - return nil, fmt.Errorf("failed to create embedded NATS server: %w", err) - } - - if config.dontListen { - nopts = append(nopts, nats.InProcessServer(ns)) - } - - // Start the server. - go ns.Start() - logrus.Infof("started embedded NATS server") - - // Wait for the server to be ready. - // TODO: limit the number of retries? - for { - if ns.ReadyForConnections(5 * time.Second) { - break - } - } - - // TODO: No method on backend.Driver exists to indicate a shutdown. - sigch := make(chan os.Signal, 1) - signal.Notify(sigch, os.Interrupt) - go func() { - <-sigch - ns.Shutdown() - logrus.Infof("embedded NATS server shutdown") - }() - - // Use the local server's client URL. - config.clientURL = ns.ClientURL() - } - - if !config.dontListen { - logrus.Infof("connecting to %s", config.clientURL) - } - - logrus.Infof("using bucket: %s", config.bucket) - - conn, err := nats.Connect(config.clientURL, nopts...) - if err != nil { - return nil, fmt.Errorf("failed to connect to NATS server: %w", err) - } - - js, err := conn.JetStream() - if err != nil { - return nil, fmt.Errorf("failed to get JetStream context: %w", err) - } - - bucket, err := js.KeyValue(config.bucket) - if err != nil && err == nats.ErrBucketNotFound { - bucket, err = js.CreateKeyValue( - &nats.KeyValueConfig{ - Bucket: config.bucket, - Description: "Holds kine key/values", - History: config.revHistory, - Replicas: config.replicas, - }) - } - - kvB := kv.NewEncodedKV(bucket, &kv.EtcdKeyCodec{}, &kv.S2ValueCodec{}) - - if err != nil { - return nil, err - } - - return &Driver{ - kv: kvB, - dirMu: &sync.RWMutex{}, - subMus: make(map[string]*sync.RWMutex), - js: js, - slowThreshold: config.slowThreshold, - }, nil -} - -// parseConnection returns nats connection url, bucketName and []nats.Option, error -func parseConnection(dsn string, tlsInfo tls.Config) (*Config, error) { - config := &Config{ - slowThreshold: defaultSlowMethod, - revHistory: defaultRevHistory, - bucket: defaultBucket, - replicas: defaultReplicas, - } - - // Parse the first URL in the connection string which contains the - // query parameters. - connections := strings.Split(dsn, ",") - u, err := url.Parse(connections[0]) - if err != nil { - return nil, err - } - - // Extract the host and port if embedded server is used. - config.host = u.Hostname() - if u.Port() != "" { - config.port, _ = strconv.Atoi(u.Port()) - } - - // Extract the query parameters to build configuration. - queryMap, err := url.ParseQuery(u.RawQuery) - if err != nil { - return nil, err - } - - if v := queryMap.Get("bucket"); v != "" { - config.bucket = v - } - - if v := queryMap.Get("replicas"); v != "" { - if r, err := strconv.ParseUint(v, 10, 8); err == nil { - if r >= 1 && r <= 5 { - config.replicas = int(r) - } else { - return nil, fmt.Errorf("invalid replicas, must be >= 1 and <= 5") - } - } - } - - if d := queryMap.Get("slowMethod"); d != "" { - if dur, err := time.ParseDuration(d); err == nil { - config.slowThreshold = dur - } else { - return nil, fmt.Errorf("invalid slowMethod duration: %w", err) - } - } - - if r := queryMap.Get("revHistory"); r != "" { - if revs, err := strconv.ParseUint(r, 10, 8); err == nil { - if revs >= 2 && revs <= 64 { - config.revHistory = uint8(revs) - } else { - return nil, fmt.Errorf("invalid revHistory, must be >= 2 and <= 64") - } - } - } - - if tlsInfo.KeyFile != "" && tlsInfo.CertFile != "" { - config.clientOptions = append(config.clientOptions, nats.ClientCert(tlsInfo.CertFile, tlsInfo.KeyFile)) - } - - if tlsInfo.CAFile != "" { - config.clientOptions = append(config.clientOptions, nats.RootCAs(tlsInfo.CAFile)) - } - - if f := queryMap.Get("contextFile"); f != "" { - if u.Host != "" { - return config, fmt.Errorf("when using context endpoint no host should be provided") - } - - logrus.Debugf("loading nats context file: %s", f) - - natsContext, err := natscontext.NewFromFile(f) - if err != nil { - return nil, err - } - - connections = strings.Split(natsContext.ServerURL(), ",") - - // command line options provided to kine will override the file - // https://github.com/nats-io/jsm.go/blob/v0.0.29/natscontext/context.go#L257 - // allows for user, creds, nke, token, certifcate, ca, inboxprefix from the context.json - natsClientOpts, err := natsContext.NATSOptions(config.clientOptions...) - if err != nil { - return nil, err - } - config.clientOptions = natsClientOpts - } - - connBuilder := strings.Builder{} - for idx, c := range connections { - if idx > 0 { - connBuilder.WriteString(",") - } - - u, err := url.Parse(c) - if err != nil { - return nil, err - } - - if u.Scheme != "nats" { - return nil, fmt.Errorf("invalid connection string=%s", c) - } - - connBuilder.WriteString("nats://") - - if u.User != nil && idx == 0 { - userInfo := strings.Split(u.User.String(), ":") - if len(userInfo) > 1 { - config.clientOptions = append(config.clientOptions, nats.UserInfo(userInfo[0], userInfo[1])) - } else { - config.clientOptions = append(config.clientOptions, nats.Token(userInfo[0])) - } - } - connBuilder.WriteString(u.Host) - } - - config.clientURL = connBuilder.String() - - // Config options only relevant if built with embedded NATS. - if natsserver.Embedded { - config.noEmbed = queryMap.Has("noEmbed") - config.serverConfig = queryMap.Get("serverConfig") - config.stdoutLogging = queryMap.Has("stdoutLogging") - config.dontListen = queryMap.Has("dontListen") - } - - logrus.Debugf("using config %#v", config) - - return config, nil -} - -func (d *Driver) Start(ctx context.Context) error { - // See https://github.com/kubernetes/kubernetes/blob/442a69c3bdf6fe8e525b05887e57d89db1e2f3a5/staging/src/k8s.io/apiserver/pkg/storage/storagebackend/factory/etcd3.go#L97 - if _, err := d.Create(ctx, "/registry/health", []byte(`{"health":"true"}`), 0); err != nil { - if err != server.ErrKeyExists { - logrus.Errorf("Failed to create health check key: %v", err) - } - } - return nil -} - -func (d *Driver) isKeyExpired(_ context.Context, createTime time.Time, value *JSValue) bool { - - requestTime := time.Now() - expired := false - if value.KV.Lease > 0 { - if requestTime.After(createTime.Add(time.Second * time.Duration(value.KV.Lease))) { - expired = true - if err := d.kv.Delete(value.KV.Key); err != nil { - logrus.Warnf("problem deleting expired key=%s, error=%v", value.KV.Key, err) - } - } - } - - return expired -} - -// Get returns the associated server.KeyValue -func (d *Driver) Get(ctx context.Context, key, rangeEnd string, limit, revision int64) (revRet int64, kvRet *server.KeyValue, errRet error) { - start := time.Now() - defer func() { - dur := time.Since(start) - size := 0 - if kvRet != nil { - size = len(kvRet.Value) - } - fStr := "GET %s, rev=%d => revRet=%d, kv=%v, size=%d, err=%v, duration=%s" - d.logMethod(dur, fStr, key, revision, revRet, kvRet != nil, size, errRet, dur) - }() - - currentRev, err := d.currentRevision() - if err != nil { - return currentRev, nil, err - } - - rev, kv, err := d.get(ctx, key, revision, false) - if err == nil { - if kv == nil { - return currentRev, nil, nil - } - return rev, kv.KV, nil - } - - if err == nats.ErrKeyNotFound { - return currentRev, nil, nil - } - - return rev, nil, err -} - -func (d *Driver) get(ctx context.Context, key string, revision int64, includeDeletes bool) (int64, *JSValue, error) { - compactRev, err := d.compactRevision() - if err != nil { - return 0, nil, err - } - - // Get latest revision - if revision <= 0 { - entry, err := d.kv.Get(key) - if err == nil { - val, err := decode(entry) - if err != nil { - return 0, nil, err - } - - if val.Delete && !includeDeletes { - return 0, nil, nats.ErrKeyNotFound - } - - if d.isKeyExpired(ctx, entry.Created(), &val) { - return 0, nil, nats.ErrKeyNotFound - } - return val.KV.ModRevision, &val, nil - } - if err == nats.ErrKeyNotFound { - return 0, nil, err - } - return 0, nil, err - } - - if revision < compactRev { - logrus.Warnf("requested revision has been compacted") - } - - entry, err := d.kv.GetRevision(key, uint64(revision)) - if err == nil { - val, err := decode(entry) - if err != nil { - return 0, nil, err - } - - if val.Delete && !includeDeletes { - return 0, nil, nats.ErrKeyNotFound - } - - if d.isKeyExpired(ctx, entry.Created(), &val) { - return 0, nil, nats.ErrKeyNotFound - } - return val.KV.ModRevision, &val, nil - } - - if err == nats.ErrKeyNotFound { - return 0, nil, err - } - - return 0, nil, err -} - -// Create -func (d *Driver) Create(ctx context.Context, key string, value []byte, lease int64) (revRet int64, errRet error) { - start := time.Now() - defer func() { - dur := time.Since(start) - fStr := "CREATE %s, size=%d, lease=%d => rev=%d, err=%v, duration=%s" - d.logMethod(dur, fStr, key, len(value), lease, revRet, errRet, dur) - }() - - // Lock the folder containing this key. - defer d.lockFolder(key)() - - // check if key exists already - rev, prevKV, err := d.get(ctx, key, 0, true) - if err != nil && err != nats.ErrKeyNotFound { - return 0, err - } - - createValue := JSValue{ - Delete: false, - Create: true, - PrevRevision: rev, - KV: &server.KeyValue{ - Key: key, - CreateRevision: 0, - ModRevision: 0, - Value: value, - Lease: lease, - }, - } - - if prevKV != nil { - if !prevKV.Delete { - return 0, server.ErrKeyExists - } - createValue.PrevRevision = prevKV.KV.ModRevision - } - - event, err := encode(createValue) - if err != nil { - return 0, err - } - - if prevKV != nil { - seq, err := d.kv.Put(key, event) - if err != nil { - return 0, err - } - return int64(seq), nil - } - seq, err := d.kv.Create(key, event) - if err != nil { - return 0, err - } - return int64(seq), nil -} - -func (d *Driver) Delete(ctx context.Context, key string, revision int64) (revRet int64, kvRet *server.KeyValue, deletedRet bool, errRet error) { - start := time.Now() - defer func() { - dur := time.Since(start) - fStr := "DELETE %s, rev=%d => rev=%d, kv=%v, deleted=%v, err=%v, duration=%s" - d.logMethod(dur, fStr, key, revision, revRet, kvRet != nil, deletedRet, errRet, dur) - }() - - // Lock the folder containing this key. - defer d.lockFolder(key)() - - rev, value, err := d.get(ctx, key, 0, true) - if err != nil { - if err == nats.ErrKeyNotFound { - return rev, nil, true, nil - } - return rev, nil, false, err - } - - if value == nil { - return rev, nil, true, nil - } - - if value.Delete { - return rev, value.KV, true, nil - } - - if revision != 0 && value.KV.ModRevision != revision { - return rev, value.KV, false, nil - } - - deleteEvent := JSValue{ - Delete: true, - PrevRevision: rev, - KV: value.KV, - } - deleteEventBytes, err := encode(deleteEvent) - if err != nil { - return rev, nil, false, err - } - - deleteRev, err := d.kv.Put(key, deleteEventBytes) - if err != nil { - return rev, value.KV, false, nil - } - - err = d.kv.Delete(key) - if err != nil { - return rev, value.KV, false, nil - } - - return int64(deleteRev), value.KV, true, nil -} - -func (d *Driver) List(ctx context.Context, prefix, startKey string, limit, revision int64) (revRet int64, kvRet []*server.KeyValue, errRet error) { - start := time.Now() - defer func() { - dur := time.Since(start) - fStr := "LIST %s, start=%s, limit=%d, rev=%d => rev=%d, kvs=%d, err=%v, duration=%s" - d.logMethod(dur, fStr, prefix, startKey, limit, revision, revRet, len(kvRet), errRet, dur) - }() - - // its assumed that when there is a start key that that key exists. - if strings.HasSuffix(prefix, "/") { - if prefix == startKey || strings.HasPrefix(prefix, startKey) { - startKey = "" - } - } - - rev, err := d.currentRevision() - if err != nil { - return 0, nil, err - } - - kvs := make([]*server.KeyValue, 0) - var count int64 - - // startkey provided so get max revision after the startKey matching the prefix - if startKey != "" { - histories := make(map[string][]nats.KeyValueEntry) - var minRev int64 - //var innerEntry nats.KeyValueEntry - if entries, err := d.kv.History(startKey, nats.Context(ctx)); err == nil { - histories[startKey] = entries - for i := len(entries) - 1; i >= 0; i-- { - // find the matching startKey - if int64(entries[i].Revision()) <= revision { - minRev = int64(entries[i].Revision()) - logrus.Debugf("Found min revision=%d for key=%s", minRev, startKey) - break - } - } - } else { - return 0, nil, err - } - - keys, err := d.getKeys(ctx, prefix, true) - if err != nil { - return 0, nil, err - } - - for _, key := range keys { - if key != startKey { - if history, err := d.kv.History(key, nats.Context(ctx)); err == nil { - histories[key] = history - } else { - // should not happen - logrus.Warnf("no history for %s", key) - } - } - } - var nextRevID = minRev - var nextRevision nats.KeyValueEntry - for k, v := range histories { - logrus.Debugf("Checking %s history", k) - for i := len(v) - 1; i >= 0; i-- { - if int64(v[i].Revision()) > nextRevID && int64(v[i].Revision()) <= revision { - nextRevID = int64(v[i].Revision()) - nextRevision = v[i] - logrus.Debugf("found next rev=%d", nextRevID) - break - } else if int64(v[i].Revision()) <= nextRevID { - break - } - } - } - if nextRevision != nil { - entry, err := decode(nextRevision) - if err != nil { - return 0, nil, err - } - kvs = append(kvs, entry.KV) - } - - return rev, kvs, nil - } - - current := true - - if revision != 0 { - rev = revision - current = false - } - - if current { - - entries, err := d.getKeyValues(ctx, prefix, true) - if err != nil { - return 0, nil, err - } - for _, e := range entries { - if count < limit || limit == 0 { - kv, err := decode(e) - if !d.isKeyExpired(ctx, e.Created(), &kv) && err == nil { - kvs = append(kvs, kv.KV) - count++ - } - } else { - break - } - } - - } else { - keys, err := d.getKeys(ctx, prefix, true) - if err != nil { - return 0, nil, err - } - if revision == 0 && len(keys) == 0 { - return rev, nil, nil - } - - for _, key := range keys { - if count < limit || limit == 0 { - if history, err := d.kv.History(key, nats.Context(ctx)); err == nil { - for i := len(history) - 1; i >= 0; i-- { - if int64(history[i].Revision()) <= revision { - if entry, err := decode(history[i]); err == nil { - kvs = append(kvs, entry.KV) - count++ - } else { - logrus.Warnf("Could not decode %s rev=> %d", key, history[i].Revision()) - } - break - } - } - } else { - // should not happen - logrus.Warnf("no history for %s", key) - } - } - } - - } - return rev, kvs, nil -} - -func (d *Driver) listAfter(ctx context.Context, prefix string, revision int64) (revRet int64, eventRet []*server.Event, errRet error) { - - entries, err := d.getKeyValues(ctx, prefix, false) - - if err != nil { - return 0, nil, err - } - - rev, err := d.currentRevision() - if err != nil { - return 0, nil, err - } - if revision != 0 { - rev = revision - } - events := make([]*server.Event, 0) - for _, e := range entries { - kv, err := decode(e) - if err == nil && int64(e.Revision()) > revision { - event := server.Event{ - Delete: kv.Delete, - Create: kv.Create, - KV: kv.KV, - PrevKV: &server.KeyValue{}, - } - if _, prevKV, err := d.Get(ctx, kv.KV.Key, "", 1, kv.PrevRevision); err == nil && prevKV != nil { - event.PrevKV = prevKV - } - - events = append(events, &event) - } - } - return rev, events, nil -} - -// Count returns an exact count of the number of matching keys and the current revision of the database -func (d *Driver) Count(ctx context.Context, prefix string) (revRet int64, count int64, err error) { - start := time.Now() - defer func() { - dur := time.Since(start) - fStr := "COUNT %s => rev=%d, count=%d, err=%v, duration=%s" - d.logMethod(dur, fStr, prefix, revRet, count, err, dur) - }() - - entries, err := d.getKeys(ctx, prefix, false) - if err != nil { - return 0, 0, err - } - // current revision - currentRev, err := d.currentRevision() - if err != nil { - return 0, 0, err - } - return currentRev, int64(len(entries)), nil -} - -func (d *Driver) Update(ctx context.Context, key string, value []byte, revision, lease int64) (revRet int64, kvRet *server.KeyValue, updateRet bool, errRet error) { - start := time.Now() - defer func() { - dur := time.Since(start) - kvRev := int64(0) - if kvRet != nil { - kvRev = kvRet.ModRevision - } - fStr := "UPDATE %s, value=%d, rev=%d, lease=%v => rev=%d, kvrev=%d, updated=%v, err=%v, duration=%s" - d.logMethod(dur, fStr, key, len(value), revision, lease, revRet, kvRev, updateRet, errRet, dur) - }() - - // Lock the folder containing the key. - defer d.lockFolder(key)() - - rev, prevKV, err := d.get(ctx, key, 0, false) - - if err != nil { - if err == nats.ErrKeyNotFound { - return rev, nil, false, nil - } - return rev, nil, false, err - } - - if prevKV == nil { - return 0, nil, false, nil - } - - if prevKV.KV.ModRevision != revision { - return rev, prevKV.KV, false, nil - } - - updateValue := JSValue{ - Delete: false, - Create: false, - PrevRevision: prevKV.KV.ModRevision, - KV: &server.KeyValue{ - Key: key, - CreateRevision: prevKV.KV.CreateRevision, - Value: value, - Lease: lease, - }, - } - if prevKV.KV.CreateRevision == 0 { - updateValue.KV.CreateRevision = rev - } - - valueBytes, err := encode(updateValue) - if err != nil { - return 0, nil, false, err - } - - seq, err := d.kv.Put(key, valueBytes) - if err != nil { - return 0, nil, false, err - } - - updateValue.KV.ModRevision = int64(seq) - - return int64(seq), updateValue.KV, true, err - -} - -func (d *Driver) Watch(ctx context.Context, prefix string, revision int64) server.WatchResult { - ctx, cancel := context.WithCancel(ctx) - watcher, err := d.kv.(*kv.EncodedKV).Watch(prefix, nats.IgnoreDeletes(), nats.Context(ctx)) - - if revision > 0 { - revision-- - } - - result := make(chan []*server.Event, 100) - wr := server.WatchResult{Events: result} - - rev, events, err := d.listAfter(ctx, prefix, revision) - if err != nil { - logrus.Errorf("failed to create watcher %s for revision %d", prefix, revision) - if err == server.ErrCompacted { - compact, _ := d.compactRevision() - wr.CompactRevision = compact - wr.CurrentRevision = rev - } - cancel() - } - - go func() { - if len(events) > 0 { - result <- events - revision = events[len(events)-1].KV.ModRevision - } - - for { - select { - case i := <-watcher.Updates(): - if i != nil { - if int64(i.Revision()) > revision { - events := make([]*server.Event, 1) - var err error - value := JSValue{ - KV: &server.KeyValue{}, - PrevRevision: 0, - Create: false, - Delete: false, - } - prevValue := JSValue{ - KV: &server.KeyValue{}, - PrevRevision: 0, - Create: false, - Delete: false, - } - lastEntry := &i - - value, err = decode(*lastEntry) - if err != nil { - logrus.Warnf("watch event: could not decode %s seq %d", i.Key(), i.Revision()) - } - if _, prevEntry, prevErr := d.get(ctx, i.Key(), value.PrevRevision, false); prevErr == nil { - if prevEntry != nil { - prevValue = *prevEntry - } - } - if err == nil { - event := &server.Event{ - Create: value.Create, - Delete: value.Delete, - KV: value.KV, - PrevKV: prevValue.KV, - } - events[0] = event - result <- events - } else { - logrus.Warnf("error decoding %s event %v", i.Key(), err) - continue - } - } - } - case <-ctx.Done(): - logrus.Infof("watcher: %s context cancelled", prefix) - if err := watcher.Stop(); err != nil && err != nats.ErrBadSubscription { - logrus.Warnf("error stopping %s watcher: %v", prefix, err) - } - close(result) - cancel() - return - } - } - }() - return wr -} - -// getPreviousEntry returns the nats.KeyValueEntry previous to the one provided, if the previous entry is a nats.KeyValuePut -// operation. If it is not a KeyValuePut then it will return nil. -func (d *Driver) getPreviousEntry(ctx context.Context, entry nats.KeyValueEntry) (result *nats.KeyValueEntry, e error) { - defer func() { - if result != nil { - logrus.Debugf("getPreviousEntry %s:%d found=true %d", entry.Key(), entry.Revision(), (*result).Revision()) - } else { - logrus.Debugf("getPreviousEntry %s:%d found=false", entry.Key(), entry.Revision()) - } - }() - found := false - entries, err := d.kv.History(entry.Key(), nats.Context(ctx)) - if err == nil { - for idx := len(entries) - 1; idx >= 0; idx-- { - if found { - if entries[idx].Operation() == nats.KeyValuePut { - return &entries[idx], nil - } - return nil, nil - } - if entries[idx].Revision() == entry.Revision() { - found = true - } - } - } - - return nil, nil -} - -// DbSize get the kineBucket size from JetStream. -func (d *Driver) DbSize(context.Context) (int64, error) { - status, err := d.kv.Status() - if err != nil { - return -1, err - } - return int64(status.Bytes()), nil -} - -func encode(v JSValue) ([]byte, error) { - buf, err := json.Marshal(v) - return buf, err -} - -func decode(e nats.KeyValueEntry) (JSValue, error) { - v := JSValue{} - if e.Value() != nil { - err := json.Unmarshal(e.Value(), &v) - if err != nil { - logrus.Debugf("key: %s", e.Key()) - logrus.Debugf("sequence number: %d", e.Revision()) - logrus.Debugf("bytes returned: %v", len(e.Value())) - return v, err - } - v.KV.ModRevision = int64(e.Revision()) - } - return v, nil -} - -func (d *Driver) currentRevision() (int64, error) { - status, err := d.kv.Status() - if err != nil { - return 0, err - } - return int64(status.(*nats.KeyValueBucketStatus).StreamInfo().State.LastSeq), nil -} - -func (d *Driver) compactRevision() (int64, error) { - status, err := d.kv.Status() - if err != nil { - return 0, err - } - return int64(status.(*nats.KeyValueBucketStatus).StreamInfo().State.FirstSeq), nil -} - -// getKeyValues returns a []nats.KeyValueEntry matching prefix -func (d *Driver) getKeyValues(ctx context.Context, prefix string, sortResults bool) ([]nats.KeyValueEntry, error) { - watcher, err := d.kv.Watch(prefix, nats.IgnoreDeletes(), nats.Context(ctx)) - if err != nil { - return nil, err - } - defer func() { - err := watcher.Stop() - if err != nil { - logrus.Warnf("failed to stop %s getKeyValues watcher", prefix) - } - }() - - var entries []nats.KeyValueEntry - for entry := range watcher.Updates() { - if entry == nil { - break - } - entries = append(entries, entry) - } - - if sortResults { - sort.Slice(entries, func(i, j int) bool { - return entries[i].Key() < entries[j].Key() - }) - } - - return entries, nil -} - -// getKeys returns a list of keys matching a prefix -func (d *Driver) getKeys(ctx context.Context, prefix string, sortResults bool) ([]string, error) { - watcher, err := d.kv.Watch(prefix, nats.MetaOnly(), nats.IgnoreDeletes(), nats.Context(ctx)) - if err != nil { - return nil, err - } - defer func() { - err := watcher.Stop() - if err != nil { - logrus.Warnf("failed to stop %s getKeys watcher", prefix) - } - }() - - var keys []string - // grab all matching keys immediately - for entry := range watcher.Updates() { - if entry == nil { - break - } - keys = append(keys, entry.Key()) - } - - if sortResults { - sort.Strings(keys) - } - - return keys, nil -} diff --git a/pkg/drivers/nats/new.go b/pkg/drivers/nats/new.go new file mode 100644 index 00000000..977d9956 --- /dev/null +++ b/pkg/drivers/nats/new.go @@ -0,0 +1,258 @@ +package nats + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "time" + + natsserver "github.com/k3s-io/kine/pkg/drivers/nats/server" + "github.com/k3s-io/kine/pkg/server" + "github.com/k3s-io/kine/pkg/tls" + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" + "github.com/sirupsen/logrus" +) + +const ( + defaultBucket = "kine" + defaultReplicas = 1 + defaultRevHistory = 10 + defaultSlowMethod = 500 * time.Millisecond +) + +var ( + // Missing errors in the nats.go client library. + jsClusterNotAvailErr = &jetstream.APIError{ + Code: 503, + ErrorCode: 10008, + } + + jsNoSuitablePeersErr = &jetstream.APIError{ + Code: 400, + ErrorCode: 10005, + } + + jsWrongLastSeqErr = &jetstream.APIError{ + Code: 400, + ErrorCode: jetstream.JSErrCodeStreamWrongLastSequence, + } +) + +// New return an implementation of server.Backend using NATS + JetStream. +// See the `examples/nats.md` file for examples of connection strings. +func New(ctx context.Context, connection string, tlsInfo tls.Config) (server.Backend, error) { + return newBackend(ctx, connection, tlsInfo, false) +} + +// NewLegacy return an implementation of server.Backend using NATS + JetStream +// with legacy jetstream:// behavior, ignoring the embedded server. +func NewLegacy(ctx context.Context, connection string, tlsInfo tls.Config) (server.Backend, error) { + return newBackend(ctx, connection, tlsInfo, true) +} + +func newBackend(ctx context.Context, connection string, tlsInfo tls.Config, legacy bool) (server.Backend, error) { + config, err := parseConnection(connection, tlsInfo) + if err != nil { + return nil, err + } + + nopts := append( + config.clientOptions, + nats.Name("kine using bucket: "+config.bucket), + nats.MaxReconnects(-1), + ) + + // Run an embedded server if available and not disabled. + var ns natsserver.Server + cancel := func() {} + + if !legacy && natsserver.Embedded && !config.noEmbed { + logrus.Infof("using an embedded NATS server") + + ns, err = natsserver.New(&natsserver.Config{ + Host: config.host, + Port: config.port, + ConfigFile: config.serverConfig, + DontListen: config.dontListen, + StdoutLogging: config.stdoutLogging, + DataDir: config.dataDir, + }) + if err != nil { + return nil, fmt.Errorf("failed to create embedded NATS server: %w", err) + } + + if config.dontListen { + nopts = append(nopts, nats.InProcessServer(ns)) + } + + // Start the server. + go ns.Start() + logrus.Infof("started embedded NATS server") + time.Sleep(100 * time.Millisecond) + + // Wait for the server to be ready. + var retries int + for { + if ns.Ready() { + logrus.Infof("embedded NATS server is ready for client connections") + break + } + retries++ + logrus.Infof("waiting for embedded NATS server to be ready: %d", retries) + time.Sleep(100 * time.Millisecond) + } + + // Use the local server's client URL. + config.clientURL = ns.ClientURL() + + ctx, cancel = context.WithCancel(ctx) + } + + if !config.dontListen { + logrus.Infof("connecting to %s", config.clientURL) + } + + logrus.Infof("using bucket: %s", config.bucket) + + nopts = append(nopts, + nats.DisconnectErrHandler(func(_ *nats.Conn, err error) { + logrus.Errorf("NATS disconnected: %s", err) + }), + nats.DiscoveredServersHandler(func(nc *nats.Conn) { + logrus.Infof("NATS discovered servers: %v", nc.Servers()) + }), + nats.ErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) { + logrus.Errorf("NATS error callback: %s", err) + }), + nats.ReconnectHandler(func(nc *nats.Conn) { + logrus.Infof("NATS reconnected: %v", nc.ConnectedUrl()) + }), + ) + + nc, err := nats.Connect(config.clientURL, nopts...) + if err != nil { + cancel() + return nil, fmt.Errorf("failed to connect to NATS server: %w", err) + } + + js, err := jetstream.New(nc) + if err != nil { + cancel() + return nil, fmt.Errorf("failed to get JetStream context: %w", err) + } + + bucket, err := getOrCreateBucket(ctx, js, config) + if err != nil { + cancel() + return nil, fmt.Errorf("failed to get or create bucket: %w", err) + } + if err := disableDirectGets(ctx, js, config); err != nil { + cancel() + return nil, fmt.Errorf("failed to disable direct gets: %w", err) + } + + logrus.Infof("bucket initialized: %s", config.bucket) + + ekv := NewKeyValue(ctx, bucket, js) + + // Reference the global logger, since it appears log levels are + // applied globally. + l := logrus.StandardLogger() + + backend := Backend{ + nc: nc, + l: l, + kv: ekv, + js: js, + cancel: cancel, + } + + if ns != nil { + // TODO: No method on backend.Driver exists to indicate a shutdown. + sigch := make(chan os.Signal, 1) + signal.Notify(sigch, os.Interrupt) + go func() { + <-sigch + backend.Close() + ns.Shutdown() + logrus.Infof("embedded NATS server shutdown") + }() + } + + return &BackendLogger{ + logger: l, + backend: &backend, + threshold: config.slowThreshold, + }, nil +} + +func getOrCreateBucket(ctx context.Context, js jetstream.JetStream, config *Config) (jetstream.KeyValue, error) { + bucket, err := js.KeyValue(ctx, config.bucket) + if err == nil { + return bucket, nil + } + + // If it does not exist, attempt to create it. + for { + bucket, err = js.CreateKeyValue(ctx, jetstream.KeyValueConfig{ + Bucket: config.bucket, + Description: "Holds kine key/values", + History: config.revHistory, + Replicas: config.replicas, + }) + if err == nil { + return bucket, nil + } + + // Check for timeout errors and retry. + if errors.Is(err, context.DeadlineExceeded) { + logrus.Warnf("timed out waiting for bucket %s to be created. retrying", config.bucket) + continue + } + + // Concurrent creation can cause this error. + if jetstream.ErrStreamNameAlreadyInUse.APIError().Is(err) { + return js.KeyValue(ctx, config.bucket) + } + + // Check for temporary JetStream errors when the cluster is unhealthy and retry. + if jsClusterNotAvailErr.Is(err) || jsNoSuitablePeersErr.Is(err) { + logrus.Warnf(err.Error()) + time.Sleep(time.Second) + continue + } + + // Some unexpected error. + if err != nil { + return nil, fmt.Errorf("failed to initialize KV bucket: %w", err) + } + } +} + +func disableDirectGets(ctx context.Context, js jetstream.JetStream, config *Config) error { + for { + str, err := js.Stream(ctx, fmt.Sprintf("KV_%s", config.bucket)) + if errors.Is(err, context.DeadlineExceeded) { + continue + } + if err != nil { + return fmt.Errorf("failed to get stream info: %w", err) + } + + scfg := str.CachedInfo().Config + scfg.AllowDirect = false + + _, err = js.UpdateStream(ctx, scfg) + if errors.Is(err, context.DeadlineExceeded) { + continue + } + if err != nil { + return fmt.Errorf("failed to update stream config: %w", err) + } + + return nil + } +} diff --git a/pkg/drivers/nats/server/interface.go b/pkg/drivers/nats/server/interface.go index b23993f0..ec44e5c6 100644 --- a/pkg/drivers/nats/server/interface.go +++ b/pkg/drivers/nats/server/interface.go @@ -2,14 +2,13 @@ package server import ( "net" - "time" ) type Server interface { Start() + Ready() bool Shutdown() ClientURL() string - ReadyForConnections(wait time.Duration) bool InProcessConn() (net.Conn, error) } @@ -19,4 +18,5 @@ type Config struct { ConfigFile string DontListen bool StdoutLogging bool + DataDir string } diff --git a/pkg/drivers/nats/server/server.go b/pkg/drivers/nats/server/server.go index 89806f31..e0940b91 100644 --- a/pkg/drivers/nats/server/server.go +++ b/pkg/drivers/nats/server/server.go @@ -4,15 +4,65 @@ package server import ( + "bytes" + "encoding/json" "fmt" + "net/http" + "net/url" "github.com/nats-io/nats-server/v2/server" + "github.com/sirupsen/logrus" ) const ( Embedded = true ) +type responseWriter struct { + code int + header http.Header + body *bytes.Buffer +} + +func (w *responseWriter) Header() http.Header { + return w.header +} + +func (w *responseWriter) Write(b []byte) (int, error) { + return w.body.Write(b) +} + +func (w *responseWriter) WriteHeader(code int) { + w.code = code +} + +type embeddedServer struct { + *server.Server +} + +func (s *embeddedServer) Ready() bool { + rw := responseWriter{ + header: http.Header{}, + body: &bytes.Buffer{}, + } + + r := http.Request{ + Method: "GET", + URL: &url.URL{ + Path: "/healthz", + }, + Header: http.Header{}, + } + + s.Server.HandleHealthz(&rw, &r) + + var hs server.HealthStatus + json.NewDecoder(rw.body).Decode(&hs) + logrus.Debugf("embedded NATS server health: %#v", hs) + + return hs.Status == "ok" +} + func New(c *Config) (Server, error) { opts := &server.Options{} @@ -25,6 +75,10 @@ func New(c *Config) (Server, error) { } } + // TODO: Other defaults for embedded config? + // Explicitly set JetStream to true since we need the KV store. + opts.JetStream = true + // Note, if don't listen is set, host and port will be ignored. opts.DontListen = c.DontListen @@ -35,15 +89,18 @@ func New(c *Config) (Server, error) { if c.Port != 0 { opts.Port = c.Port } - - // TODO: Other defaults for embedded config? - // Explicitly set JetStream to true since we need the KV store. - opts.JetStream = true + if c.DataDir != "" { + opts.StoreDir = c.DataDir + } srv, err := server.NewServer(opts) + if err != nil { + return nil, err + } + if c.StdoutLogging { srv.ConfigureLogger() } - return srv, err + return &embeddedServer{Server: srv}, nil } diff --git a/scripts/test-run-nats b/scripts/test-run-nats index 17057e1e..b4c8b3cf 100755 --- a/scripts/test-run-nats +++ b/scripts/test-run-nats @@ -15,7 +15,7 @@ export -f start-test # test minimum supported version and most recent version of NATS VERSION_LIST="\ - nats 2.9.21" + nats 2.9.22" while read ENGINE VERSION; do LABEL=$ENGINE-$VERSION DB_PASSWORD_ENV=NATS_JS_PASSWORD DB_ARGS="-js" DB_IMAGE=docker.io/library/$ENGINE:$VERSION run-test diff --git a/scripts/test-run-nats-embedded b/scripts/test-run-nats-embedded index eae7e7a3..a1a7fc60 100755 --- a/scripts/test-run-nats-embedded +++ b/scripts/test-run-nats-embedded @@ -1,7 +1,7 @@ #!/bin/bash start-test() { - KINE_IMAGE=$IMAGE KINE_ENDPOINT="nats://?embedServer" provision-kine + KINE_IMAGE=$IMAGE KINE_ENDPOINT="nats://" provision-kine local kine_url=$(cat $TEST_DIR/kine/*/metadata/url) K3S_DATASTORE_ENDPOINT=$kine_url provision-cluster }