From 4a5e4ca5f6e5921af3a47d70e1398ad4d23e443a Mon Sep 17 00:00:00 2001 From: Burak Sezer Date: Sun, 20 Mar 2022 13:11:04 +0300 Subject: [PATCH] refactor: add some integration tests --- README.md | 2 +- client.go | 21 + cluster.go | 23 + cluster_client.go | 101 +++- cluster_client_test.go | 104 +++- cluster_iterator_test.go | 4 +- cmd/olric-stats/main.go | 175 ------ cmd/olric-stats/query/query.go | 271 --------- config/dmap.go | 4 - config/dmaps.go | 4 - config/engine.go | 28 - config/internal/loader/loader.go | 1 - config/load.go | 2 - dmap.go | 149 ----- embedded_client.go | 31 ++ embedded_client_test.go | 22 + embedded_iterator.go | 2 +- get_response.go | 4 +- get_response_test.go | 10 +- go.mod | 2 + go.sum | 4 +- internal/cluster/balancer/balancer.go | 7 +- internal/cluster/routingtable/routingtable.go | 6 + internal/dmap/atomic.go | 6 +- internal/dmap/atomic_test.go | 20 +- internal/dmap/put.go | 4 +- internal/dmap/scan.go | 199 ------- internal/dmap/scan_handlers.go | 6 +- internal/dmap/scan_test.go | 127 ----- internal/dmap/service.go | 1 - internal/encoding/encoder_test.go | 29 - internal/encoding/scan_test.go | 43 -- internal/kvstore/bitmap.go | 22 - internal/kvstore/kvstore.go | 26 +- internal/protocol/cluster.go | 34 +- internal/protocol/cluster_parser.go | 28 - ...cluster_parser_test.go => cluster_test.go} | 25 +- internal/protocol/commands.go | 28 +- internal/protocol/dmap.go | 414 +++++++++++++- internal/protocol/dmap_parser.go | 434 --------------- internal/protocol/dmap_parser_test.go | 220 -------- internal/protocol/dmap_test.go | 516 ++++++++++++++++++ internal/protocol/errors.go | 2 + internal/protocol/pubsub.go | 111 +++- internal/protocol/pubsub_parser.go | 96 ---- internal/protocol/pubsub_parser_test.go | 29 - internal/protocol/pubsub_test.go | 90 +++ internal/protocol/system.go | 76 +++ internal/protocol/system_parser.go | 95 ---- internal/protocol/system_test.go | 107 ++++ internal/pubsub/handlers.go | 39 +- internal/pubsub/handlers_test.go | 51 ++ internal/pubsub/service.go | 1 + internal/{encoding => resp}/encoder.go | 2 +- internal/resp/encoder_test.go | 286 ++++++++++ internal/{encoding => resp}/scan.go | 61 +-- internal/roundrobin/round_robin.go | 74 +++ internal/roundrobin/round_robin_test.go | 88 +++ internal/server/client.go | 30 +- internal/server/client_test.go | 94 +++- internal/server/mux_test.go | 54 ++ internal/server/server_test.go | 14 +- olric.go | 25 +- pkg/storage/engine.go | 12 + pkg/storage/storage.go | 48 -- pkg/storage/storage_test.go | 21 - pubsub.go | 80 +++ pubsub_test.go | 263 +++++++++ stats.go | 4 +- stats_test.go | 2 +- 70 files changed, 2771 insertions(+), 2243 deletions(-) delete mode 100644 cmd/olric-stats/main.go delete mode 100644 cmd/olric-stats/query/query.go delete mode 100644 internal/dmap/scan.go delete mode 100644 internal/encoding/encoder_test.go delete mode 100644 internal/encoding/scan_test.go delete mode 100644 internal/kvstore/bitmap.go delete mode 100644 internal/protocol/cluster_parser.go rename internal/protocol/{cluster_parser_test.go => cluster_test.go} (56%) delete mode 100644 internal/protocol/dmap_parser.go delete mode 100644 internal/protocol/dmap_parser_test.go delete mode 100644 internal/protocol/pubsub_parser.go delete mode 100644 internal/protocol/pubsub_parser_test.go create mode 100644 internal/protocol/pubsub_test.go delete mode 100644 internal/protocol/system_parser.go create mode 100644 internal/protocol/system_test.go rename internal/{encoding => resp}/encoder.go (99%) create mode 100644 internal/resp/encoder_test.go rename internal/{encoding => resp}/scan.go (61%) create mode 100644 internal/roundrobin/round_robin.go create mode 100644 internal/roundrobin/round_robin_test.go create mode 100644 internal/server/mux_test.go delete mode 100644 pkg/storage/storage.go delete mode 100644 pkg/storage/storage_test.go create mode 100644 pubsub.go create mode 100644 pubsub_test.go diff --git a/README.md b/README.md index 043ae4a9..5f3e28a1 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ The current production version is [v0.4.3](https://github.com/buraksezer/olric/t * Designed to share some transient, approximate, fast-changing data between servers, * Embeddable but can be used as a language-independent service with *olricd*, * Supports different eviction algorithms, -* Fast binary protocol, +* Supports Redis protocol, * Highly available and horizontally scalable, * Provides best-effort consistency guarantees without being a complete CP (indeed PA/EC) solution, * Supports replication by default (with sync and async options), diff --git a/client.go b/client.go index 5ddccbb0..66d74bc2 100644 --- a/client.go +++ b/client.go @@ -25,6 +25,13 @@ import ( const DefaultScanCount = 10 +type Member struct { + Name string + ID uint64 + Birthdate int64 + Coordinator bool +} + type Iterator interface { Next() bool Key() string @@ -177,11 +184,25 @@ type statsConfig struct { type StatsOption func(*statsConfig) +type pubsubConfig struct { + Address string +} + +func ToAddress(addr string) PubSubOption { + return func(cfg *pubsubConfig) { + cfg.Address = addr + } +} + +type PubSubOption func(option *pubsubConfig) + type Client interface { NewDMap(name string, options ...DMapOption) (DMap, error) + NewPubSub(options ...PubSubOption) (*PubSub, error) Stats(ctx context.Context, options ...StatsOption) (stats.Stats, error) Ping(ctx context.Context, addr string) error PingWithMessage(ctx context.Context, addr, message string) (string, error) RoutingTable(ctx context.Context) (RoutingTable, error) + Members(ctx context.Context) ([]Member, error) Close(ctx context.Context) error } diff --git a/cluster.go b/cluster.go index 3c9decc5..e2bb4279 100644 --- a/cluster.go +++ b/cluster.go @@ -149,3 +149,26 @@ func (db *Olric) routingTable(ctx context.Context) (RoutingTable, error) { } return mapToRoutingTable(slice) } + +func (db *Olric) clusterMembersCommandHandler(conn redcon.Conn, cmd redcon.Command) { + _, err := protocol.ParseClusterMembers(cmd) + if err != nil { + protocol.WriteError(conn, err) + return + } + + coordinator := db.rt.Discovery().GetCoordinator() + members := db.rt.Discovery().GetMembers() + conn.WriteArray(len(members)) + for _, member := range members { + conn.WriteArray(4) + conn.WriteBulkString(member.Name) + conn.WriteUint64(member.ID) + conn.WriteInt64(member.Birthdate) + if coordinator.CompareByID(member) { + conn.WriteBulkString("true") + } else { + conn.WriteBulkString("false") + } + } +} diff --git a/cluster_client.go b/cluster_client.go index 43b6a4b6..4788aec9 100644 --- a/cluster_client.go +++ b/cluster_client.go @@ -25,9 +25,9 @@ import ( "github.com/buraksezer/olric/config" "github.com/buraksezer/olric/internal/bufpool" "github.com/buraksezer/olric/internal/dmap" - "github.com/buraksezer/olric/internal/encoding" "github.com/buraksezer/olric/internal/kvstore/entry" "github.com/buraksezer/olric/internal/protocol" + "github.com/buraksezer/olric/internal/resp" "github.com/buraksezer/olric/internal/server" "github.com/buraksezer/olric/pkg/storage" "github.com/buraksezer/olric/stats" @@ -97,7 +97,7 @@ func (dm *ClusterDMap) Put(ctx context.Context, key string, value interface{}, o valueBuf := pool.Get() defer pool.Put(valueBuf) - enc := encoding.New(valueBuf) + enc := resp.New(valueBuf) err = enc.Encode(value) if err != nil { return err @@ -203,7 +203,7 @@ func (dm *ClusterDMap) GetPut(ctx context.Context, key string, value interface{} valueBuf := pool.Get() defer pool.Put(valueBuf) - enc := encoding.New(valueBuf) + enc := resp.New(valueBuf) err = enc.Encode(value) if err != nil { return nil, err @@ -330,14 +330,13 @@ func (dm *ClusterDMap) Scan(ctx context.Context, options ...ScanOption) (Iterato if sc.Count == 0 { sc.Count = DefaultScanCount } - if sc.Logger == nil { - sc.Logger = log.New(os.Stderr, "logger: ", log.Lshortfile) - } + ictx, cancel := context.WithCancel(ctx) i := &ClusterIterator{ dm: dm, clusterClient: dm.clusterClient, config: &sc, + logger: dm.clusterClient.logger, allKeys: make(map[string]struct{}), finished: make(map[string]struct{}), cursors: make(map[string]uint64), @@ -372,6 +371,8 @@ func (dm *ClusterDMap) Destroy(ctx context.Context) error { type ClusterClient struct { client *server.Client + config *clusterClientConfig + logger *log.Logger } func (cl *ClusterClient) Ping(ctx context.Context, addr string) error { @@ -465,10 +466,56 @@ func (cl *ClusterClient) Stats(ctx context.Context, options ...StatsOption) (sta return s, nil } +func (cl *ClusterClient) Members(ctx context.Context) ([]Member, error) { + rc, err := cl.client.Pick() + if err != nil { + return []Member{}, err + } + + cmd := protocol.NewClusterMembers().Command(ctx) + err = rc.Process(ctx, cmd) + if err != nil { + return []Member{}, processProtocolError(err) + } + + if err = cmd.Err(); err != nil { + return []Member{}, processProtocolError(err) + } + + items, err := cmd.Slice() + if err != nil { + return []Member{}, processProtocolError(err) + } + var members []Member + for _, rawItem := range items { + m := Member{} + item := rawItem.([]interface{}) + m.Name = item[0].(string) + + switch id := item[1].(type) { + case uint64: + m.ID = id + case int64: + m.ID = uint64(id) + } + + m.Birthdate = item[2].(int64) + if item[3] == "true" { + m.Coordinator = true + } + members = append(members, m) + } + return members, nil +} + func (cl *ClusterClient) Close(ctx context.Context) error { return cl.client.Shutdown(ctx) } +func (cl *ClusterClient) NewPubSub(options ...PubSubOption) (*PubSub, error) { + return newPubSub(cl.client, options...) +} + func (cl *ClusterClient) NewDMap(name string, options ...DMapOption) (DMap, error) { var dc dmapConfig for _, opt := range options { @@ -489,24 +536,54 @@ func (cl *ClusterClient) NewDMap(name string, options ...DMapOption) (DMap, erro }, nil } -func NewClusterClient(addresses []string, c *config.Client) (*ClusterClient, error) { +type ClusterClientOption func(c *clusterClientConfig) + +type clusterClientConfig struct { + logger *log.Logger + config *config.Client +} + +func WithLogger(l *log.Logger) ClusterClientOption { + return func(cfg *clusterClientConfig) { + cfg.logger = l + } +} + +func WithConfig(c *config.Client) ClusterClientOption { + return func(cfg *clusterClientConfig) { + cfg.config = c + } +} + +func NewClusterClient(addresses []string, options ...ClusterClientOption) (*ClusterClient, error) { if len(addresses) == 0 { return nil, fmt.Errorf("addresses cannot be empty") } - if c == nil { - c = config.NewClient() + var cc clusterClientConfig + for _, opt := range options { + opt(&cc) + } + + if cc.logger == nil { + cc.logger = log.New(os.Stderr, "logger: ", log.Lshortfile) + } + + if cc.config == nil { + cc.config = config.NewClient() } - if err := c.Sanitize(); err != nil { + if err := cc.config.Sanitize(); err != nil { return nil, err } - if err := c.Validate(); err != nil { + if err := cc.config.Validate(); err != nil { return nil, err } cl := &ClusterClient{ - client: server.NewClient(c), + client: server.NewClient(cc.config), + config: &cc, + logger: cc.logger, } for _, address := range addresses { cl.client.Get(address) diff --git a/cluster_client_test.go b/cluster_client_test.go index 30949aee..7fbb5fa4 100644 --- a/cluster_client_test.go +++ b/cluster_client_test.go @@ -16,10 +16,13 @@ package olric import ( "context" - "github.com/buraksezer/olric/stats" + "log" + "os" "testing" "time" + "github.com/buraksezer/olric/config" + "github.com/buraksezer/olric/stats" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" ) @@ -30,7 +33,7 @@ func TestClusterClient_Ping(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -46,7 +49,7 @@ func TestClusterClient_PingWithMessage(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -63,7 +66,7 @@ func TestClusterClient_RoutingTable(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -80,7 +83,7 @@ func TestClusterClient_Put(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -98,7 +101,7 @@ func TestClusterClient_Get(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -124,7 +127,7 @@ func TestClusterClient_Delete(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -148,7 +151,7 @@ func TestClusterClient_Destroy(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -172,7 +175,7 @@ func TestClusterClient_Incr(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -201,7 +204,7 @@ func TestClusterClient_Decr(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -233,7 +236,7 @@ func TestClusterClient_GetPut(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -259,7 +262,7 @@ func TestClusterClient_Expire(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -285,7 +288,7 @@ func TestClusterClient_Lock_Unlock(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -306,7 +309,7 @@ func TestClusterClient_Lock_Lease(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -332,7 +335,7 @@ func TestClusterClient_Lock_ErrLockNotAcquired(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -353,7 +356,7 @@ func TestClusterClient_LockWithTimeout(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -374,7 +377,7 @@ func TestClusterClient_LockWithTimeout_ErrNoSuchLock(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -397,7 +400,7 @@ func TestClusterClient_LockWithTimeout_Then_Lease(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -424,7 +427,7 @@ func TestClusterClient_LockWithTimeout_ErrLockNotAcquired(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -445,7 +448,7 @@ func TestClusterClient_Put_Ex(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -468,7 +471,7 @@ func TestClusterClient_Put_PX(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -491,7 +494,7 @@ func TestClusterClient_Put_EXAT(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -514,7 +517,7 @@ func TestClusterClient_Put_PXAT(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -537,7 +540,7 @@ func TestClusterClient_Put_NX(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -565,7 +568,7 @@ func TestClusterClient_Put_XX(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -583,7 +586,7 @@ func TestClusterClient_Stats(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -601,7 +604,7 @@ func TestClusterClient_Stats_CollectRuntime(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -613,3 +616,50 @@ func TestClusterClient_Stats_CollectRuntime(t *testing.T) { require.NotNil(t, s.Runtime) require.NotEqual(t, empty, s) } + +func TestClusterClient_Set_Options(t *testing.T) { + cluster := newTestOlricCluster(t) + db := cluster.addMember(t) + + ctx := context.Background() + + lg := log.New(os.Stderr, "logger: ", log.Lshortfile) + cfg := config.NewClient() + c, err := NewClusterClient([]string{db.name}, WithConfig(cfg), WithLogger(lg)) + require.NoError(t, err) + defer func() { + require.NoError(t, c.Close(ctx)) + }() + + require.Equal(t, cfg, c.config.config) + require.Equal(t, lg, c.config.logger) +} + +func TestClusterClient_Members(t *testing.T) { + cluster := newTestOlricCluster(t) + cluster.addMember(t) + db := cluster.addMember(t) + + ctx := context.Background() + c, err := NewClusterClient([]string{db.name}) + require.NoError(t, err) + defer func() { + require.NoError(t, c.Close(ctx)) + }() + + members, err := c.Members(ctx) + require.NoError(t, err) + require.Len(t, members, 2) + + coordinator := db.rt.Discovery().GetCoordinator() + for _, member := range members { + require.NotEqual(t, "", member.Name) + require.NotEqual(t, 0, member.ID) + require.NotEqual(t, 0, member.Birthdate) + if coordinator.ID == member.ID { + require.True(t, member.Coordinator) + } else { + require.False(t, member.Coordinator) + } + } +} diff --git a/cluster_iterator_test.go b/cluster_iterator_test.go index 2f6ea680..c89e8e45 100644 --- a/cluster_iterator_test.go +++ b/cluster_iterator_test.go @@ -28,7 +28,7 @@ func TestClusterClient_ScanMatch(t *testing.T) { db := cluster.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) @@ -67,7 +67,7 @@ func TestClusterClient_Scan(t *testing.T) { cl.addMember(t) ctx := context.Background() - c, err := NewClusterClient([]string{db.name}, nil) + c, err := NewClusterClient([]string{db.name}) require.NoError(t, err) defer func() { require.NoError(t, c.Close(ctx)) diff --git a/cmd/olric-stats/main.go b/cmd/olric-stats/main.go deleted file mode 100644 index f224b428..00000000 --- a/cmd/olric-stats/main.go +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright 2018-2022 Burak Sezer -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Pretty printer for Olric stats - -package main - -import ( - "flag" - "fmt" - "log" - "os" - "runtime" - - "github.com/buraksezer/olric" - "github.com/sean-/seed" -) - -const defaultAddress = "127.0.0.1:3320" - -type arguments struct { - help bool - version bool - runtime bool - partitions bool - backup bool - dump bool - dmap bool - pubsub bool - network bool - members bool - address string - timeout string - id int -} - -func usage() { - var msg = `Usage: olric-stats [options] ... - -Inspect cluster state and per-node statistics. - -Options: - -h, --help Print this message and exit. - -v, --version Print the version number and exit. - -a --address Network address of the server in format. - Default: 127.0.0.1:3320 - -t --timeout Set time limit for requests and dial made by the client. - Default: 10ms - -r --runtime Print Go runtime stats. It calls runtime.ReadMemStats - on the target server. You should know that this function stops - all running goroutines to collect statistics. - -p --partitions Print partition statistics of the server. - --id Partition id to query. - --backup Enable to query backup partitions. - -d --dump Dump stats data in JSON format. - -D --dmap Print DMap statistics. - -P --pubsub Print Pub/Sub statistics. - -n --network Print network statistics. - -m --members List current members of the cluster. - -The Go runtime version %s -Report bugs to https://github.com/buraksezer/olric/issues -` - _, err := fmt.Fprintf(os.Stdout, msg, runtime.Version()) - if err != nil { - panic(err) - } -} - -func main() { - args := &arguments{} - // No need for timestamp and etc in this function. Just log it. - log.SetFlags(0) - - // Parse command line parameters - f := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) - f.SetOutput(os.Stdout) - f.BoolVar(&args.help, "h", false, "") - f.BoolVar(&args.help, "help", false, "") - - f.BoolVar(&args.version, "version", false, "") - f.BoolVar(&args.version, "v", false, "") - - f.StringVar(&args.address, "a", defaultAddress, "") - f.StringVar(&args.address, "addr", defaultAddress, "") - - f.StringVar(&args.timeout, "t", "10ms", "") - f.StringVar(&args.timeout, "timeout", "10ms", "") - - f.BoolVar(&args.partitions, "p", false, "") - f.BoolVar(&args.partitions, "partitions", false, "") - f.IntVar(&args.id, "id", -1, "") - f.BoolVar(&args.backup, "backup", false, "") - - f.BoolVar(&args.runtime, "r", false, "") - f.BoolVar(&args.runtime, "runtime", false, "") - - f.BoolVar(&args.dump, "d", false, "") - f.BoolVar(&args.dump, "dump", false, "") - - f.BoolVar(&args.dmap, "D", false, "") - f.BoolVar(&args.dmap, "dmap", false, "") - - f.BoolVar(&args.pubsub, "P", false, "") - f.BoolVar(&args.pubsub, "pubsub", false, "") - - f.BoolVar(&args.network, "n", false, "") - f.BoolVar(&args.network, "network", false, "") - - f.BoolVar(&args.members, "m", false, "") - f.BoolVar(&args.members, "members", false, "") - - logger := log.New(os.Stderr, "", log.LstdFlags) - logger.SetFlags(log.Flags() &^ (log.Ldate | log.Ltime)) - - if err := f.Parse(os.Args[1:]); err != nil { - log.Fatalf("Failed to parse flags: %v", err) - } - - switch { - case args.help: - usage() - return - case args.version: - _, _ = fmt.Fprintf(os.Stdout, - "olric-stats %s with runtime %s\n", - olric.ReleaseVersion, runtime.Version()) - } - - // MustInit provides guaranteed secure seeding. If `/dev/urandom` is not - // available, MustInit will panic() with an error indicating why reading from - // `/dev/urandom` failed. MustInit() will upgrade the seed if for some reason a - // call to Init() failed in the past. - seed.MustInit() - /* - q, err := query.New(args.address, args.timeout, logger) - if err != nil { - logger.Fatalf("olric-stats: %v", err) - } - - switch { - case args.dump: - err = q.Dump() - case args.runtime: - err = q.PrintRuntimeStats() - case args.partitions: - err = q.PrintPartitionStats(args.id, args.backup) - case args.members: - err = q.PrintClusterMembers() - case args.dmap: - err = q.PrintDMapStatistics() - case args.pubsub: - err = q.PrintPubSubStatistics() - case args.network: - err = q.PrintNetworkStatistics() - default: - usage() - return - } - - if err != nil { - logger.Fatalf("olric-stats: %v", err) - }*/ -} diff --git a/cmd/olric-stats/query/query.go b/cmd/olric-stats/query/query.go deleted file mode 100644 index 4a6717d3..00000000 --- a/cmd/olric-stats/query/query.go +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright 2018-2022 Burak Sezer -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package query - -/* -import ( - "encoding/json" - "fmt" - "log" - "time" - - "github.com/buraksezer/olric/client" - "github.com/buraksezer/olric/config" - "github.com/buraksezer/olric/serializer" - "github.com/buraksezer/olric/stats" -) - -type Query struct { - addr string - client *client.Client - log *log.Logger -} - -func New(addr, timeout string, logger *log.Logger) (*Query, error) { - dt, err := time.ParseDuration(timeout) - if err != nil { - return nil, err - } - - cc := &client.Config{ - Servers: []string{addr}, - Serializer: serializer.NewMsgpackSerializer(), - Client: &config.Client{ - DialTimeout: dt, - MaxConn: 1, - }, - } - - c, err := client.New(cc) - if err != nil { - return nil, err - } - - return &Query{ - addr: addr, - client: c, - log: logger, - }, nil -} - -func (q *Query) prettyPrint(partID uint64, part *stats.Partition) { - q.log.Printf("PartID: %d", partID) - - if len(part.PreviousOwners) != 0 { - q.log.Printf(" Previous Owners:") - for idx, previous := range part.PreviousOwners { - q.log.Printf(" %d: %s", idx, previous.Name) - } - } else { - q.log.Printf(" Previous Owners: not found") - } - - if len(part.Backups) != 0 { - q.log.Printf(" Backups:") - for idx, backup := range part.Backups { - q.log.Printf(" %d: %s", idx+1, backup.Name) - } - } else { - q.log.Printf(" Backups: not found") - } - - if len(part.DMaps) != 0 { - q.log.Printf(" DMaps:") - for name, dm := range part.DMaps { - q.log.Printf(" Name: %s", name) - q.log.Printf(" Length: %d", dm.Length) - q.log.Printf(" Allocated: %d", dm.SlabInfo.Allocated) - q.log.Printf(" Inuse: %d", dm.SlabInfo.Inuse) - q.log.Printf(" Garbage: %d", dm.SlabInfo.Garbage) - q.log.Printf(" Number of tables: %d", dm.NumTables) - q.log.Printf("\n") - } - } else { - q.log.Printf(" DMaps: not found") - } - - q.log.Printf(" Length of partition: %d", part.Length) - q.log.Printf("\n") -} - -func (q *Query) printAllPartitionStats(backup bool) error { - data, err := q.client.Stats(q.addr) - if err != nil { - return err - } - - var ( - totalLength int - totalInuse int - totalAllocated int - totalGarbage int - ) - - partitions := data.Partitions - if backup { - partitions = data.Backups - } - - for partID, part := range partitions { - q.prettyPrint(uint64(partID), &part) - totalLength += part.Length - for _, dm := range part.DMaps { - totalInuse += dm.SlabInfo.Inuse - totalAllocated += dm.SlabInfo.Allocated - totalGarbage += dm.SlabInfo.Garbage - } - } - - q.log.Printf("Summary for %s:\n\n", data.Member.String()) - q.log.Printf("Total length of partitions: %d", totalLength) - q.log.Printf("Total partition count: %d", len(data.Partitions)) - q.log.Printf("Total Allocated: %d", totalAllocated) - q.log.Printf("Total Inuse: %d", totalInuse) - q.log.Printf("Total Garbage: %d", totalGarbage) - - return nil -} - -func (q *Query) PrintPartitionStats(partID int, backup bool) error { - if partID == -1 { - return q.printAllPartitionStats(backup) - } - - data, err := q.client.Stats(q.addr) - if err != nil { - return err - } - - partitions := data.Partitions - if backup { - partitions = data.Backups - } - part, ok := partitions[stats.PartitionID(partID)] - if !ok { - return fmt.Errorf("partition does not belong to %s", q.addr) - } - - q.prettyPrint(uint64(partID), &part) - return nil -} - -func (q *Query) Dump() error { - data, err := q.client.Stats(q.addr) - if err != nil { - return err - } - - js, err := json.MarshalIndent(data, "", " ") - if err != nil { - return err - } - - q.log.Printf(string(js)) - return nil -} - -func (q *Query) PrintRuntimeStats() error { - data, err := q.client.Stats(q.addr, client.CollectRuntime()) - if err != nil { - return err - } - - js, err := json.MarshalIndent(data.Runtime, "", " ") - if err != nil { - return err - } - - q.log.Printf(string(js)) - return nil -} - -func (q *Query) PrintClusterMembers() error { - data, err := q.client.Stats(q.addr) - if err != nil { - return err - } - - q.log.Printf("This member: %s", data.Member) - q.log.Printf(" ID: %d", data.Member.ID) - q.log.Printf(" Birthdate: %d", data.Member.Birthdate) - - q.log.Printf("\n") - - q.log.Printf("Cluster coordinator: %s", data.ClusterCoordinator) - q.log.Printf(" ID: %d", data.ClusterCoordinator.ID) - q.log.Printf(" Birthdate: %d", data.ClusterCoordinator.Birthdate) - - q.log.Printf("\n") - - q.log.Printf("All members:\n\n") - - for _, member := range data.ClusterMembers { - q.log.Printf("Member: %s", member) - q.log.Printf(" ID: %d", member.ID) - q.log.Printf(" Birthdate: %d", member.Birthdate) - q.log.Printf("\n") - } - return nil -} - -func (q *Query) PrintDMapStatistics() error { - data, err := q.client.Stats(q.addr) - if err != nil { - return err - } - - q.log.Printf("DMap statistics:\n") - - q.log.Printf(" Evicted total: %d", data.DMaps.EvictedTotal) - q.log.Printf(" Entries total: %d", data.DMaps.EntriesTotal) - q.log.Printf(" Get misses: %d", data.DMaps.GetMisses) - q.log.Printf(" Get hits: %d", data.DMaps.GetHits) - q.log.Printf(" Delete misses: %d", data.DMaps.DeleteMisses) - q.log.Printf(" Delete hits: %d", data.DMaps.DeleteHits) - - return nil -} - -func (q *Query) PrintPubSubStatistics() error { - data, err := q.client.Stats(q.addr) - if err != nil { - return err - } - - q.log.Printf("PubSub statistics:\n") - - q.log.Printf(" Listeners total: %d", data.PubSub.SubscribersTotal) - q.log.Printf(" Published total: %d", data.PubSub.PublishedTotal) - q.log.Printf(" Current listeners: %d", data.PubSub.CurrentSubscribers) - - return nil -} - -func (q *Query) PrintNetworkStatistics() error { - data, err := q.client.Stats(q.addr) - if err != nil { - return err - } - - q.log.Printf("Network statistics:\n") - - q.log.Printf(" Commands total: %d", data.Network.CommandsTotal) - q.log.Printf(" Read bytes total: %d", data.Network.ReadBytesTotal) - q.log.Printf(" Written bytes total: %d", data.Network.WrittenBytesTotal) - q.log.Printf(" Connections total: %d", data.Network.ConnectionsTotal) - q.log.Printf(" Current connections: %d", data.Network.CurrentConnections) - - return nil -}*/ diff --git a/config/dmap.go b/config/dmap.go index af8b4188..3a561b90 100644 --- a/config/dmap.go +++ b/config/dmap.go @@ -85,10 +85,6 @@ func (dm *DMap) Sanitize() error { dm.Engine = NewEngine() } - if err := dm.Engine.LoadPlugin(); err != nil { - return fmt.Errorf("failed to load storage engine plugin: %w", err) - } - if err := dm.Engine.Sanitize(); err != nil { return fmt.Errorf("failed to sanitize storage engine configuration: %w", err) } diff --git a/config/dmaps.go b/config/dmaps.go index 5f2cd9c7..54477a44 100644 --- a/config/dmaps.go +++ b/config/dmaps.go @@ -122,10 +122,6 @@ func (dm *DMaps) Sanitize() error { } } - if err := dm.Engine.LoadPlugin(); err != nil { - return fmt.Errorf("failed to load storage engine plugin: %w", err) - } - if err := dm.Engine.Sanitize(); err != nil { return fmt.Errorf("failed to sanitize storage engine configuration: %w", err) } diff --git a/config/engine.go b/config/engine.go index 8078d93c..ab76c9f4 100644 --- a/config/engine.go +++ b/config/engine.go @@ -16,8 +16,6 @@ package config import ( "fmt" - "os" - "github.com/buraksezer/olric/internal/kvstore" "github.com/buraksezer/olric/pkg/storage" ) @@ -26,10 +24,6 @@ import ( // If you don't have a custom storage engine implementation or configuration for // the default one, just call NewStorageEngine() function to use it with sane defaults. type Engine struct { - // Plugins is an array that contains the paths of storage engine plugins. - // These plugins have to implement storage.Engine interface. - Plugin string - Name string Implementation storage.Engine @@ -57,28 +51,6 @@ func (s *Engine) Validate() error { return nil } -func (s *Engine) LoadPlugin() error { - if s.Plugin == "" { - return nil - } - - _, err := os.Stat(s.Plugin) - if os.IsNotExist(err) { - return fmt.Errorf("storage engine plugin could not be found on disk: %s", s.Plugin) - } - if err != nil { - return err - } - - engine, err := storage.LoadAsPlugin(s.Plugin) - if err != nil { - return err - } - s.Implementation = engine - s.Name = engine.Name() - return nil -} - // Sanitize sets default values to empty configuration variables, if it's possible. func (s *Engine) Sanitize() error { if s.Name == "" { diff --git a/config/internal/loader/loader.go b/config/internal/loader/loader.go index 27c3d607..b55f1136 100644 --- a/config/internal/loader/loader.go +++ b/config/internal/loader/loader.go @@ -91,7 +91,6 @@ type memberlist struct { } type engine struct { - Plugin string `yaml:"plugin"` Name string `yaml:"name"` Config map[string]interface{} `yaml:"config"` } diff --git a/config/load.go b/config/load.go index 3d83755c..5e35b20d 100644 --- a/config/load.go +++ b/config/load.go @@ -101,7 +101,6 @@ func loadDMapConfig(c *loader.Loader) (*DMaps, error) { if c.DMaps.Engine != nil { e := NewEngine() - e.Plugin = c.DMaps.Engine.Plugin e.Name = c.DMaps.Engine.Name e.Config = c.DMaps.Engine.Config res.Engine = e @@ -118,7 +117,6 @@ func loadDMapConfig(c *loader.Loader) (*DMaps, error) { } if dc.Engine != nil { e := NewEngine() - e.Plugin = dc.Engine.Plugin e.Name = dc.Engine.Name e.Config = dc.Engine.Config cc.Engine = e diff --git a/dmap.go b/dmap.go index 05c4739b..7dcc6f68 100644 --- a/dmap.go +++ b/dmap.go @@ -38,10 +38,6 @@ var ( // ErrNoSuchLock is returned when the requested lock does not exist ErrNoSuchLock = errors.New("no such lock") - // ErrEndOfQuery is the error returned by Range when no more data is available. - // Functions should return ErrEndOfQuery only to signal a graceful end of input. - ErrEndOfQuery = errors.New("end of query") - // ErrClusterQuorum means that the cluster could not reach a healthy numbers of members to operate. ErrClusterQuorum = errors.New("cannot be reached cluster quorum to operate") @@ -58,8 +54,6 @@ func convertDMapError(err error) error { return ErrKeyNotFound case errors.Is(err, dmap.ErrDMapNotFound): return ErrKeyNotFound - case errors.Is(err, dmap.ErrEndOfQuery): - return ErrEndOfQuery case errors.Is(err, dmap.ErrLockNotAcquired): return ErrLockNotAcquired case errors.Is(err, dmap.ErrNoSuchLock): @@ -74,146 +68,3 @@ func convertDMapError(err error) error { return convertClusterError(err) } } - -/* -// Entry is a DMap entry with its metadata. -type Entry struct { - Key string - Value interface{} - TTL int64 - Timestamp int64 -} - -// LockContext is returned by Lock and LockWithTimeout methods. -// It should be stored in a proper way to release the lock. -type LockContext struct { - ctx *dmap.LockContext -} - -// Cursor implements distributed query on DMaps. -type Cursor struct { - //cursor *dmap.Cursor -} - -// DMapLegacy represents a distributed map instance. -type DMapLegacy struct { - dm *dmap.DMap -} - -// NewDMap creates an returns a new DMap instance. -func (db *Olric) NewDMap(name string) (*DMapLegacy, error) { - dm, err := db.dmap.NewDMap(name) - if err != nil { - return nil, convertDMapError(err) - } - return &DMapLegacy{ - dm: dm, - }, nil -} - -// Name exposes name of the DMap. -func (dm *DMapLegacy) Name() string { - return dm.dm.Name() -} - -// Get gets the value for the given key. It returns ErrKeyNotFound if the DB -// does not contain the key. It's thread-safe. It is safe to modify the contents -// of the returned value. -func (dm *DMapLegacy) Get(key string) (interface{}, error) { - value, err := dm.dm.Get(key) - if err != nil { - return nil, convertDMapError(err) - } - return value, nil -} - -// LockWithTimeout sets a lock for the given key. If the lock is still unreleased the end of given period of time, -// it automatically releases the lock. Acquired lock is only for the key in this dmap. -// -// It returns immediately if it acquires the lock for the given key. Otherwise, it waits until deadline. -// -// You should know that the locks are approximate, and only to be used for non-critical purposes. -func (dm *DMapLegacy) LockWithTimeout(key string, timeout, deadline time.Duration) (*LockContext, error) { - ctx, err := dm.dm.LockWithTimeout(key, timeout, deadline) - if err != nil { - return nil, convertDMapError(err) - } - return &LockContext{ctx: ctx}, nil -} - -// Lock sets a lock for the given key. Acquired lock is only for the key in this dmap. -// -// It returns immediately if it acquires the lock for the given key. Otherwise, it waits until deadline. -// -// You should know that the locks are approximate, and only to be used for non-critical purposes. -func (dm *DMapLegacy) Lock(key string, deadline time.Duration) (*LockContext, error) { - ctx, err := dm.dm.Lock(key, deadline) - if err != nil { - return nil, convertDMapError(err) - } - return &LockContext{ctx: ctx}, nil -} - -// Unlock releases the lock. -func (l *LockContext) Unlock() error { - err := l.ctx.Unlock() - return convertDMapError(err) -} - -// Put sets the value for the given key. It overwrites any previous value -// for that key, and it's thread-safe. The key has to be string. value type -// is arbitrary. It is safe to modify the contents of the arguments after -// Put returns but not before. -func (dm *DMapLegacy) Put(key string, value interface{}) error { - err := dm.dm.Put(context.Background(), key, value, &dmap.PutConfig{}) - return convertDMapError(err) -} - -// Expire updates the expiry for the given key. It returns ErrKeyNotFound if the -// DB does not contain the key. It's thread-safe. -func (dm *DMapLegacy) Expire(key string, timeout time.Duration) error { - err := dm.dm.Expire(key, timeout) - return convertDMapError(err) -} - -// Delete deletes the value for the given key. Delete will not return error if key doesn't exist. It's thread-safe. -// It is safe to modify the contents of the argument after Delete returns. -func (dm *DMapLegacy) Delete(key string) error { - err := dm.dm.Delete(key) - return convertDMapError(err) -} - -// Incr atomically increments key by delta. The return value is the new value after being incremented or an error. -func (dm *DMapLegacy) Incr(key string, delta int) (int, error) { - value, err := dm.dm.Incr(key, delta) - if err != nil { - return 0, convertDMapError(err) - } - return value, nil -} - -// Decr atomically decrements key by delta. The return value is the new value after being decremented or an error. -func (dm *DMapLegacy) Decr(key string, delta int) (int, error) { - value, err := dm.dm.Decr(key, delta) - if err != nil { - return 0, convertDMapError(err) - } - return value, nil -} - -// GetPut atomically sets key to value and returns the old value stored at key. -func (dm *DMapLegacy) GetPut(key string, value interface{}) (interface{}, error) { - prev, err := dm.dm.GetPut(key, value) - if err != nil { - return nil, convertDMapError(err) - } - return prev, nil -} - -// Destroy flushes the given DMap on the cluster. You should know that there -// is no global lock on DMaps. So if you call Put and Destroy methods -// concurrently on the cluster, Put calls may set new values to the dmap. -func (dm *DMapLegacy) Destroy() error { - err := dm.dm.Destroy() - return convertDMapError(err) -}*/ diff --git a/embedded_client.go b/embedded_client.go index a1d82051..6a060b36 100644 --- a/embedded_client.go +++ b/embedded_client.go @@ -236,6 +236,37 @@ func (e *EmbeddedClient) RoutingTable(ctx context.Context) (RoutingTable, error) return e.db.routingTable(ctx) } +func (e *EmbeddedClient) Members(ctx context.Context) ([]Member, error) { + members := e.db.rt.Discovery().GetMembers() + coordinator := e.db.rt.Discovery().GetCoordinator() + var result []Member + for _, member := range members { + m := Member{ + Name: member.Name, + ID: member.ID, + Birthdate: member.Birthdate, + } + if coordinator.ID == member.ID { + m.Coordinator = true + } + result = append(result, m) + } + return result, nil +} + +func (e *EmbeddedClient) NewPubSub(options ...PubSubOption) (*PubSub, error) { + return newPubSub(e.db.client, options...) +} + +func (e *EmbeddedClient) NewPubSubWithAddr(addr string) (*PubSub, error) { + // TODO: Add an error type to Get + rc := e.db.client.Get(addr) + return &PubSub{ + rc: rc, + client: e.db.client, + }, nil +} + func (db *Olric) NewEmbeddedClient() *EmbeddedClient { return &EmbeddedClient{db: db} } diff --git a/embedded_client_test.go b/embedded_client_test.go index 8b8f1837..572ffd3d 100644 --- a/embedded_client_test.go +++ b/embedded_client_test.go @@ -519,3 +519,25 @@ func TestEmbeddedClient_RoutingTable_Cluster(t *testing.T) { } require.Len(t, owners, 3) } + +func TestEmbeddedClient_Member(t *testing.T) { + cluster := newTestOlricCluster(t) + db := cluster.addMember(t) + cluster.addMember(t) + + e := db.NewEmbeddedClient() + members, err := e.Members(context.Background()) + require.NoError(t, err) + require.Len(t, members, 2) + coordinator := db.rt.Discovery().GetCoordinator() + for _, member := range members { + require.NotEqual(t, "", member.Name) + require.NotEqual(t, 0, member.ID) + require.NotEqual(t, 0, member.Birthdate) + if coordinator.ID == member.ID { + require.True(t, member.Coordinator) + } else { + require.False(t, member.Coordinator) + } + } +} diff --git a/embedded_iterator.go b/embedded_iterator.go index 11a04162..be72bbb8 100644 --- a/embedded_iterator.go +++ b/embedded_iterator.go @@ -60,7 +60,7 @@ func (i *EmbeddedIterator) scanOnOwners(owners []discovery.Member) error { } if owner.CompareByID(i.client.db.rt.This()) { - keys, cursor, err := i.dm.Scan2(i.partID, i.cursors[owner.ID], i.config) + keys, cursor, err := i.dm.Scan(i.partID, i.cursors[owner.ID], i.config) if err != nil { return err diff --git a/get_response.go b/get_response.go index 8c8242fa..7e9c5eac 100644 --- a/get_response.go +++ b/get_response.go @@ -18,7 +18,7 @@ import ( "errors" "time" - "github.com/buraksezer/olric/internal/encoding" + "github.com/buraksezer/olric/internal/resp" "github.com/buraksezer/olric/pkg/storage" ) @@ -32,7 +32,7 @@ func (g *GetResponse) Scan(v interface{}) error { if g.entry == nil { return ErrNilResponse } - return encoding.Scan(g.entry.Value(), v) + return resp.Scan(g.entry.Value(), v) } func (g *GetResponse) Int() (int, error) { diff --git a/get_response_test.go b/get_response_test.go index e32387e9..63ac9095 100644 --- a/get_response_test.go +++ b/get_response_test.go @@ -22,7 +22,7 @@ import ( "time" "github.com/buraksezer/olric/internal/dmap" - "github.com/buraksezer/olric/internal/encoding" + "github.com/buraksezer/olric/internal/resp" "github.com/buraksezer/olric/internal/testcluster" "github.com/stretchr/testify/require" ) @@ -285,12 +285,12 @@ func TestDMap_Get_GetResponse(t *testing.T) { gr := &GetResponse{entry: e} buf := bytes.NewBuffer(nil) - enc := encoding.New(buf) + enc := resp.New(buf) err = enc.Encode(value) require.NoError(t, err) expectedValue := new(time.Time) - err = encoding.Scan(buf.Bytes(), expectedValue) + err = resp.Scan(buf.Bytes(), expectedValue) require.NoError(t, err) scannedValue, err := gr.Time() @@ -309,12 +309,12 @@ func TestDMap_Get_GetResponse(t *testing.T) { gr := &GetResponse{entry: e} buf := bytes.NewBuffer(nil) - enc := encoding.New(buf) + enc := resp.New(buf) err = enc.Encode(value) require.NoError(t, err) expectedValue := new(time.Duration) - err = encoding.Scan(buf.Bytes(), expectedValue) + err = resp.Scan(buf.Bytes(), expectedValue) require.NoError(t, err) scannedValue, err := gr.Duration() diff --git a/go.mod b/go.mod index 9d1a447a..0c649193 100644 --- a/go.mod +++ b/go.mod @@ -22,3 +22,5 @@ require ( golang.org/x/sync v0.0.0-20210220032951-036812b2e83c gopkg.in/yaml.v2 v2.4.0 ) + +replace github.com/go-redis/redis/v8 v8.11.4 => github.com/buraksezer/redis/v8 v8.11.5-0.20220311192848-e41c68a594e0 diff --git a/go.sum b/go.sum index d16fffaf..7abaa8a5 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/bits-and-blooms/bitset v1.2.0 h1:Kn4yilvwNtMACtf1eYDlG8H77R07mZSPbMjL github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= github.com/buraksezer/consistent v0.9.0 h1:Zfs6bX62wbP3QlbPGKUhqDw7SmNkOzY5bHZIYXYpR5g= github.com/buraksezer/consistent v0.9.0/go.mod h1:6BrVajWq7wbKZlTOUPs/XVfR8c0maujuPowduSpZqmw= +github.com/buraksezer/redis/v8 v8.11.5-0.20220311192848-e41c68a594e0 h1:daFj57jShO+1u3Mj44Kd90qeZ9FNvV879ZZTTg5WuVU= +github.com/buraksezer/redis/v8 v8.11.5-0.20220311192848-e41c68a594e0/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -19,8 +21,6 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/go-redis/redis/v8 v8.11.4 h1:kHoYkfZP6+pe04aFTnhDH6GDROa5yJdHJVNxV3F46Tg= -github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= diff --git a/internal/cluster/balancer/balancer.go b/internal/cluster/balancer/balancer.go index 8219e54a..58746a66 100644 --- a/internal/cluster/balancer/balancer.go +++ b/internal/cluster/balancer/balancer.go @@ -76,16 +76,17 @@ func (b *Balancer) scanPartition(sign uint64, part *partitions.Partition, owners return strings.Join(names, ",") }() - part.Map().Range(func(name, tmp interface{}) bool { - f := tmp.(partitions.Fragment) + part.Map().Range(func(rawName, rawFragment interface{}) bool { + f := rawFragment.(partitions.Fragment) if f.Length() == 0 { return false } + name := strings.TrimPrefix(rawName.(string), "dmap.") b.log.V(2).Printf("[INFO] Moving %s fragment: %s (kind: %s) on PartID: %d to %s", f.Name(), name, part.Kind(), part.ID(), ownersStr) - err := f.Move(part, name.(string), owners) + err := f.Move(part, name, owners) if err != nil { b.log.V(2).Printf("[ERROR] Failed to move %s fragment: %s on PartID: %d to %s: %v", f.Name(), name, part.ID(), ownersStr, err) diff --git a/internal/cluster/routingtable/routingtable.go b/internal/cluster/routingtable/routingtable.go index 9b4ff20b..7ced2da4 100644 --- a/internal/cluster/routingtable/routingtable.go +++ b/internal/cluster/routingtable/routingtable.go @@ -189,6 +189,12 @@ func (r *RoutingTable) IsBootstrapped() bool { // CheckBootstrap is called for every request and checks whether the node is bootstrapped. // It has to be very fast for a smooth operation. func (r *RoutingTable) CheckBootstrap() error { + // Prevent creating expensive structures for every request, + // Just check an integer value atomically. + if r.IsBootstrapped() { + return nil + } + ctx, cancel := context.WithTimeout(context.Background(), r.config.BootstrapTimeout) defer cancel() return r.tryWithInterval(ctx, 100*time.Millisecond, func() error { diff --git a/internal/dmap/atomic.go b/internal/dmap/atomic.go index 3d228df4..ad8ecfbc 100644 --- a/internal/dmap/atomic.go +++ b/internal/dmap/atomic.go @@ -19,8 +19,8 @@ import ( "errors" "fmt" - "github.com/buraksezer/olric/internal/encoding" "github.com/buraksezer/olric/internal/protocol" + "github.com/buraksezer/olric/internal/resp" "github.com/buraksezer/olric/internal/util" "github.com/buraksezer/olric/pkg/storage" ) @@ -70,7 +70,7 @@ func (dm *DMap) atomicIncrDecr(cmd string, e *env, delta int) (int, error) { } valueBuf := pool.Get() - enc := encoding.New(valueBuf) + enc := resp.New(valueBuf) err = enc.Encode(updated) if err != nil { return 0, err @@ -139,7 +139,7 @@ func (dm *DMap) GetPut(ctx context.Context, key string, value interface{}) (stor } valueBuf := pool.Get() - enc := encoding.New(valueBuf) + enc := resp.New(valueBuf) err := enc.Encode(value) if err != nil { return nil, err diff --git a/internal/dmap/atomic_test.go b/internal/dmap/atomic_test.go index bb2c3bd7..79c49611 100644 --- a/internal/dmap/atomic_test.go +++ b/internal/dmap/atomic_test.go @@ -21,8 +21,8 @@ import ( "sync/atomic" "testing" - "github.com/buraksezer/olric/internal/encoding" "github.com/buraksezer/olric/internal/protocol" + "github.com/buraksezer/olric/internal/resp" "github.com/buraksezer/olric/internal/testcluster" "github.com/go-redis/redis/v8" "github.com/stretchr/testify/require" @@ -65,7 +65,7 @@ func TestDMap_Atomic_Incr(t *testing.T) { require.NoError(t, err) var res int - err = encoding.Scan(gr.Value(), &res) + err = resp.Scan(gr.Value(), &res) require.NoError(t, err) require.Equal(t, 100, res) } @@ -107,7 +107,7 @@ func TestDMap_Atomic_Decr(t *testing.T) { require.NoError(t, err) var value int - err = encoding.Scan(res.Value(), &value) + err = resp.Scan(res.Value(), &value) require.NoError(t, err) require.Equal(t, -100, value) } @@ -132,7 +132,7 @@ func TestDMap_Atomic_GetPut(t *testing.T) { } if gr != nil { var oldval int - err = encoding.Scan(gr.Value(), &oldval) + err = resp.Scan(gr.Value(), &oldval) require.NoError(t, err) atomic.AddInt64(&total, int64(oldval)) } @@ -155,7 +155,7 @@ func TestDMap_Atomic_GetPut(t *testing.T) { require.NoError(t, err) var last int - err = encoding.Scan(gr.Value(), &last) + err = resp.Scan(gr.Value(), &last) require.NoError(t, err) atomic.AddInt64(&total, int64(last)) @@ -190,7 +190,7 @@ func TestDMap_incrCommandHandler(t *testing.T) { value, err := cmd.Bytes() require.NoError(t, err) v := new(int) - err = encoding.Scan(value, v) + err = resp.Scan(value, v) require.NoError(t, err) require.Equal(t, 100, *v) } @@ -238,7 +238,7 @@ func TestDMap_decrCommandHandler(t *testing.T) { value, err := cmd.Bytes() require.NoError(t, err) v := new(int) - err = encoding.Scan(value, v) + err = resp.Scan(value, v) require.NoError(t, err) require.Equal(t, -100, *v) } @@ -271,7 +271,7 @@ func TestDMap_exGetPutOperation(t *testing.T) { <-start buf := bytes.NewBuffer(nil) - enc := encoding.New(buf) + enc := resp.New(buf) err := enc.Encode(i) if err != nil { return err @@ -293,7 +293,7 @@ func TestDMap_exGetPutOperation(t *testing.T) { if len(val) != 0 { oldval := new(int) - err = encoding.Scan(val, oldval) + err = resp.Scan(val, oldval) if err != nil { return err } @@ -321,7 +321,7 @@ func TestDMap_exGetPutOperation(t *testing.T) { require.NoError(t, err) var last int - err = encoding.Scan(gr.Value(), &last) + err = resp.Scan(gr.Value(), &last) require.NoError(t, err) atomic.AddInt64(&total, int64(last)) diff --git a/internal/dmap/put.go b/internal/dmap/put.go index e7e44c70..08bf5cf0 100644 --- a/internal/dmap/put.go +++ b/internal/dmap/put.go @@ -24,8 +24,8 @@ import ( "github.com/buraksezer/olric/internal/bufpool" "github.com/buraksezer/olric/internal/cluster/partitions" "github.com/buraksezer/olric/internal/discovery" - "github.com/buraksezer/olric/internal/encoding" "github.com/buraksezer/olric/internal/protocol" + "github.com/buraksezer/olric/internal/resp" "github.com/buraksezer/olric/internal/stats" "github.com/buraksezer/olric/pkg/storage" "github.com/go-redis/redis/v8" @@ -387,7 +387,7 @@ type PutConfig struct { // Put returns but not before. func (dm *DMap) Put(ctx context.Context, key string, value interface{}, cfg *PutConfig) error { valueBuf := pool.Get() - enc := encoding.New(valueBuf) + enc := resp.New(valueBuf) err := enc.Encode(value) if err != nil { return err diff --git a/internal/dmap/scan.go b/internal/dmap/scan.go deleted file mode 100644 index 42f4f444..00000000 --- a/internal/dmap/scan.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2018-2022 Burak Sezer -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dmap - -import ( - "context" - "errors" - "sync" - - "github.com/buraksezer/olric/internal/discovery" - "github.com/buraksezer/olric/internal/protocol" -) - -const DefaultScanCount = 10 - -// ErrEndOfQuery is the error returned by Range when no more data is available. -// Functions should return ErrEndOfQuery only to signal a graceful end of input. -var ErrEndOfQuery = errors.New("end of query") - -// Iterator implements distributed query on DMaps. -type Iterator struct { - mtx sync.Mutex - - pos int - page []string - dm *DMap - allKeys map[string]struct{} - finished map[uint64]struct{} - cursors map[uint64]uint64 // member id => cursor - partID uint64 // current partition id - config *ScanConfig - ctx context.Context - cancel context.CancelFunc -} - -func (dm *DMap) Scan(options ...ScanOption) (*Iterator, error) { - var sc ScanConfig - for _, opt := range options { - opt(&sc) - } - if sc.Count == 0 { - sc.Count = DefaultScanCount - } - ctx, cancel := context.WithCancel(dm.s.ctx) - return &Iterator{ - dm: dm, - config: &sc, - allKeys: make(map[string]struct{}), - finished: make(map[uint64]struct{}), - cursors: make(map[uint64]uint64), - ctx: ctx, - cancel: cancel, - }, nil -} - -func (i *Iterator) updateIterator(keys []string, cursor, ownerID uint64) { - if cursor == 0 { - i.finished[ownerID] = struct{}{} - } - i.cursors[ownerID] = cursor - for _, key := range keys { - if _, ok := i.allKeys[key]; !ok { - i.page = append(i.page, key) - i.allKeys[key] = struct{}{} - } - } -} - -func (i *Iterator) scanOnOwners(owners []discovery.Member) error { - for _, owner := range owners { - if _, ok := i.finished[owner.ID]; ok { - continue - } - if owner.CompareByID(i.dm.s.rt.This()) { - keys, cursor, err := i.dm.Scan2(i.partID, i.cursors[owner.ID], i.config) - if err != nil { - return err - } - i.updateIterator(keys, cursor, owner.ID) - continue - } - - s := protocol.NewScan(i.partID, i.dm.name, i.cursors[owner.ID]) - if i.config.HasCount { - s.SetCount(i.config.Count) - } - if i.config.HasMatch { - s.SetMatch(s.Match) - } - scanCmd := s.Command(i.ctx) - rc := i.dm.s.client.Get(owner.String()) - err := rc.Process(i.ctx, scanCmd) - if err != nil { - return err - } - keys, cursor, err := scanCmd.Result() - if err != nil { - return err - } - i.updateIterator(keys, cursor, owner.ID) - } - - return nil -} - -func (i *Iterator) resetPage() { - if len(i.page) != 0 { - i.page = []string{} - } - i.pos = 0 -} - -func (i *Iterator) reset() { - // Reset - for memberID := range i.cursors { - delete(i.cursors, memberID) - delete(i.finished, memberID) - } - i.resetPage() -} - -func (i *Iterator) next() bool { - if len(i.page) != 0 { - i.pos++ - if i.pos <= len(i.page) { - return true - } - } - - i.resetPage() - - primaryOwners := i.dm.s.primary.PartitionOwnersByID(i.partID) - i.config.Replica = false - if err := i.scanOnOwners(primaryOwners); err != nil { - return false - } - - replicaOwners := i.dm.s.backup.PartitionOwnersByID(i.partID) - i.config.Replica = true - if err := i.scanOnOwners(replicaOwners); err != nil { - return false - } - - if len(i.page) == 0 { - i.partID++ - if i.dm.s.config.PartitionCount <= i.partID { - return false - } - i.reset() - return i.next() - } - i.pos = 1 - return true -} - -func (i *Iterator) Next() bool { - i.mtx.Lock() - defer i.mtx.Unlock() - - select { - case <-i.ctx.Done(): - return false - default: - } - - return i.next() -} - -func (i *Iterator) Key() string { - i.mtx.Lock() - defer i.mtx.Unlock() - - var key string - if i.pos > 0 && i.pos <= len(i.page) { - key = i.page[i.pos-1] - } - return key -} - -func (i *Iterator) Close() { - select { - case <-i.ctx.Done(): - return - default: - } - i.cancel() -} diff --git a/internal/dmap/scan_handlers.go b/internal/dmap/scan_handlers.go index 552ae9a8..6e96e21a 100644 --- a/internal/dmap/scan_handlers.go +++ b/internal/dmap/scan_handlers.go @@ -15,7 +15,6 @@ package dmap import ( - "log" "strconv" "github.com/buraksezer/olric/internal/cluster/partitions" @@ -51,7 +50,7 @@ func (dm *DMap) scanOnFragment(f *fragment, cursor uint64, sc *ScanConfig) ([]st return items, cursor, nil } -func (dm *DMap) Scan2(partID, cursor uint64, sc *ScanConfig) ([]string, uint64, error) { +func (dm *DMap) Scan(partID, cursor uint64, sc *ScanConfig) ([]string, uint64, error) { var part *partitions.Partition if sc.Replica { part = dm.s.backup.PartitionByID(partID) @@ -74,7 +73,6 @@ type ScanConfig struct { HasMatch bool Match string Replica bool - Logger *log.Logger } type ScanOption func(*ScanConfig) @@ -117,7 +115,7 @@ func (s *Service) scanCommandHandler(conn redcon.Conn, cmd redcon.Command) { var result []string var cursor uint64 - result, cursor, err = dm.Scan2(scanCmd.PartID, scanCmd.Cursor, &sc) + result, cursor, err = dm.Scan(scanCmd.PartID, scanCmd.Cursor, &sc) if err != nil { protocol.WriteError(conn, err) return diff --git a/internal/dmap/scan_test.go b/internal/dmap/scan_test.go index eeb8a90e..2b3ef1fb 100644 --- a/internal/dmap/scan_test.go +++ b/internal/dmap/scan_test.go @@ -145,133 +145,6 @@ func TestDMap_scanCommandHandler_Cluster(t *testing.T) { }) } -func TestDMap_Scan(t *testing.T) { - cluster := testcluster.New(NewService) - s := cluster.AddMember(nil).(*Service) - defer cluster.Shutdown() - - dm, err := s.NewDMap("mydmap") - require.NoError(t, err) - - ctx := context.Background() - allKeys := make(map[string]bool) - for i := 0; i < 100; i++ { - err = dm.Put(ctx, testutil.ToKey(i), i, nil) - require.NoError(t, err) - allKeys[testutil.ToKey(i)] = false - } - i, err := dm.Scan() - require.NoError(t, err) - var count int - defer i.Close() - - for i.Next() { - count++ - require.Contains(t, allKeys, i.Key()) - } - require.Equal(t, 100, count) -} - -func TestDMap_Scan_Cluster(t *testing.T) { - cluster := testcluster.New(NewService) - - c1 := testutil.NewConfig() - c1.ReplicaCount = 2 - c1.WriteQuorum = 2 - e1 := testcluster.NewEnvironment(c1) - s1 := cluster.AddMember(e1).(*Service) - - c2 := testutil.NewConfig() - c2.ReplicaCount = 2 - c1.WriteQuorum = 2 - e2 := testcluster.NewEnvironment(c2) - cluster.AddMember(e2) - - defer cluster.Shutdown() - - dm, err := s1.NewDMap("mydmap") - require.NoError(t, err) - - ctx := context.Background() - - allKeys := make(map[string]bool) - for i := 0; i < 100; i++ { - err = dm.Put(ctx, testutil.ToKey(i), i, nil) - require.NoError(t, err) - - allKeys[testutil.ToKey(i)] = false - } - i, err := dm.Scan() - require.NoError(t, err) - defer i.Close() - - var count int - for i.Next() { - count++ - require.Contains(t, allKeys, i.Key()) - } - require.Equal(t, 100, count) -} - -func TestDMap_ScanMatch(t *testing.T) { - cluster := testcluster.New(NewService) - s := cluster.AddMember(nil).(*Service) - defer cluster.Shutdown() - - dm, err := s.NewDMap("mydmap") - require.NoError(t, err) - - ctx := context.Background() - - evenKeys := make(map[string]bool) - for i := 0; i < 100; i++ { - var key string - if i%2 == 0 { - key = fmt.Sprintf("even:%s", testutil.ToKey(i)) - evenKeys[key] = false - } else { - key = fmt.Sprintf("odd:%s", testutil.ToKey(i)) - } - err = dm.Put(ctx, key, i, nil) - require.NoError(t, err) - } - i, err := dm.Scan(Match("^even:")) - require.NoError(t, err) - var count int - defer i.Close() - - for i.Next() { - count++ - require.Contains(t, evenKeys, i.Key()) - } - require.Equal(t, 50, count) -} - -func TestDMap_Scan_Close(t *testing.T) { - cluster := testcluster.New(NewService) - s := cluster.AddMember(nil).(*Service) - defer cluster.Shutdown() - - dm, err := s.NewDMap("mydmap") - require.NoError(t, err) - - ctx := context.Background() - for i := 0; i < 100; i++ { - err = dm.Put(ctx, testutil.ToKey(i), i, nil) - require.NoError(t, err) - } - i, err := dm.Scan() - require.NoError(t, err) - var count int - for i.Next() { - count++ - if count == 50 { - i.Close() // Stop the iterator - } - } - require.Equal(t, 50, count) -} - func TestDMap_scanCommandHandler_match(t *testing.T) { cluster := testcluster.New(NewService) s := cluster.AddMember(nil).(*Service) diff --git a/internal/dmap/service.go b/internal/dmap/service.go index ccbd5e86..34682996 100644 --- a/internal/dmap/service.go +++ b/internal/dmap/service.go @@ -61,7 +61,6 @@ func registerErrors() { protocol.SetError("LOCKNOTACQUIRED", ErrLockNotAcquired) protocol.SetError("READQUORUM", ErrReadQuorum) protocol.SetError("WRITEQUORUM", ErrWriteQuorum) - protocol.SetError("ENDOFQUERY", ErrEndOfQuery) protocol.SetError("DMAPNOTFOUND", ErrDMapNotFound) protocol.SetError("KEYTOOLARGE", ErrKeyTooLarge) protocol.SetError("KEYNOTFOUND", ErrKeyNotFound) diff --git a/internal/encoding/encoder_test.go b/internal/encoding/encoder_test.go deleted file mode 100644 index 94d2dcc3..00000000 --- a/internal/encoding/encoder_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package encoding - -import ( - "bytes" - "encoding" - "testing" - - "github.com/stretchr/testify/require" -) - -type MyType struct{} - -var _ encoding.BinaryMarshaler = (*MyType)(nil) - -func (t *MyType) MarshalBinary() ([]byte, error) { - return []byte("hello"), nil -} - -func TestWriter_WriteArg(t *testing.T) { - buf := bytes.NewBuffer(nil) - w := New(buf) - value := uint64(345353) - err := w.Encode(value) - require.NoError(t, err) - - scannedValue := new(uint64) - err = Scan(buf.Bytes(), scannedValue) - require.NoError(t, err) -} diff --git a/internal/encoding/scan_test.go b/internal/encoding/scan_test.go deleted file mode 100644 index a15aafe8..00000000 --- a/internal/encoding/scan_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package encoding_test - -/* -type testScanSliceStruct struct { - ID int - Name string -} - -func (s *testScanSliceStruct) MarshalBinary() ([]byte, error) { - return json.Marshal(s) -} - -func (s *testScanSliceStruct) UnmarshalBinary(b []byte) error { - return json.Unmarshal(b, s) -} - -var _ = Describe("ScanSlice", func() { - data := []string{ - `{"ID":-1,"Name":"Back Yu"}`, - `{"ID":1,"Name":"szyhf"}`, - } - - It("[]testScanSliceStruct", func() { - var slice []testScanSliceStruct - err := ScanSlice(data, &slice) - Expect(err).NotTo(HaveOccurred()) - Expect(slice).To(Equal([]testScanSliceStruct{ - {-1, "Back Yu"}, - {1, "szyhf"}, - })) - }) - - It("var testContainer []*testScanSliceStruct", func() { - var slice []*testScanSliceStruct - err := ScanSlice(data, &slice) - Expect(err).NotTo(HaveOccurred()) - Expect(slice).To(Equal([]*testScanSliceStruct{ - {-1, "Back Yu"}, - {1, "szyhf"}, - })) - }) -}) -*/ diff --git a/internal/kvstore/bitmap.go b/internal/kvstore/bitmap.go deleted file mode 100644 index acae2614..00000000 --- a/internal/kvstore/bitmap.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2018-2022 Burak Sezer -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kvstore - -import "github.com/RoaringBitmap/roaring/roaring64" - -type bitmap struct { - m *roaring64.Bitmap - offset uint32 -} diff --git a/internal/kvstore/kvstore.go b/internal/kvstore/kvstore.go index 0fdd3bf7..97e91665 100644 --- a/internal/kvstore/kvstore.go +++ b/internal/kvstore/kvstore.go @@ -62,23 +62,25 @@ func (k *KVStore) SetConfig(c *storage.Config) { } func (k *KVStore) makeTable() error { + if len(k.tables) != 0 { + head := k.tables[len(k.tables)-1] + head.SetState(table.ReadOnlyState) + + for i, t := range k.tables { + if t.State() == table.RecycledState { + k.tables = append(k.tables, t) + k.tables = append(k.tables[:i], k.tables[i+1:]...) + t.SetState(table.ReadWriteState) + return nil + } + } + } + size, err := k.config.Get("tableSize") if err != nil { return err } - head := k.tables[len(k.tables)-1] - head.SetState(table.ReadOnlyState) - - for i, t := range k.tables { - if t.State() == table.RecycledState { - k.tables = append(k.tables, t) - k.tables = append(k.tables[:i], k.tables[i+1:]...) - t.SetState(table.ReadWriteState) - return nil - } - } - current := table.New(size.(uint64)) k.tables = append(k.tables, current) k.tablesByCoefficient[k.coefficient] = current diff --git a/internal/protocol/cluster.go b/internal/protocol/cluster.go index bbd71e9c..633c4ddf 100644 --- a/internal/protocol/cluster.go +++ b/internal/protocol/cluster.go @@ -16,11 +16,11 @@ package protocol import ( "context" + "github.com/go-redis/redis/v8" + "github.com/tidwall/redcon" ) -type ClusterRoutingTableCommand struct{} - type ClusterRoutingTable struct{} func NewClusterRoutingTable() *ClusterRoutingTable { @@ -32,3 +32,33 @@ func (c *ClusterRoutingTable) Command(ctx context.Context) *redis.Cmd { args = append(args, Cluster.RoutingTable) return redis.NewCmd(ctx, args...) } + +func ParseClusterRoutingTable(cmd redcon.Command) (*ClusterRoutingTable, error) { + if len(cmd.Args) > 1 { + return nil, errWrongNumber(cmd.Args) + } + + c := NewClusterRoutingTable() + return c, nil +} + +type ClusterMembers struct{} + +func NewClusterMembers() *ClusterMembers { + return &ClusterMembers{} +} + +func (c *ClusterMembers) Command(ctx context.Context) *redis.Cmd { + var args []interface{} + args = append(args, Cluster.Members) + return redis.NewCmd(ctx, args...) +} + +func ParseClusterMembers(cmd redcon.Command) (*ClusterMembers, error) { + if len(cmd.Args) > 1 { + return nil, errWrongNumber(cmd.Args) + } + + c := NewClusterMembers() + return c, nil +} diff --git a/internal/protocol/cluster_parser.go b/internal/protocol/cluster_parser.go deleted file mode 100644 index f7d28346..00000000 --- a/internal/protocol/cluster_parser.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2018-2022 Burak Sezer -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package protocol - -import ( - "github.com/tidwall/redcon" -) - -func ParseClusterRoutingTable(cmd redcon.Command) (*ClusterRoutingTable, error) { - if len(cmd.Args) < 1 { - return nil, errWrongNumber(cmd.Args) - } - - c := NewClusterRoutingTable() - return c, nil -} diff --git a/internal/protocol/cluster_parser_test.go b/internal/protocol/cluster_test.go similarity index 56% rename from internal/protocol/cluster_parser_test.go rename to internal/protocol/cluster_test.go index 7d90a3ed..45c8d293 100644 --- a/internal/protocol/cluster_parser_test.go +++ b/internal/protocol/cluster_test.go @@ -16,14 +16,35 @@ package protocol import ( "context" - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) -func TestProtocol_ParseClusterRoutingTable(t *testing.T) { +func TestProtocol_ClusterRoutingTable(t *testing.T) { rtCmd := NewClusterRoutingTable() cmd := stringToCommand(rtCmd.Command(context.Background()).String()) _, err := ParseClusterRoutingTable(cmd) require.NoError(t, err) + + t.Run("CLUSTER.ROUTINGTABLE invalid command", func(t *testing.T) { + cmd := stringToCommand("cluster routing table foobar") + _, err = ParseClusterRoutingTable(cmd) + require.Error(t, err) + }) +} + +func TestProtocol_ClusterMembers(t *testing.T) { + membersCmd := NewClusterMembers() + + cmd := stringToCommand(membersCmd.Command(context.Background()).String()) + _, err := ParseClusterMembers(cmd) + require.NoError(t, err) + + t.Run("CLUSTER.MEMBERS invalid command", func(t *testing.T) { + cmd := stringToCommand("cluster members foobar") + _, err = ParseClusterMembers(cmd) + require.Error(t, err) + }) } diff --git a/internal/protocol/commands.go b/internal/protocol/commands.go index 28cef94a..eb2ee0e6 100644 --- a/internal/protocol/commands.go +++ b/internal/protocol/commands.go @@ -18,10 +18,12 @@ const StatusOK = "OK" type ClusterCommands struct { RoutingTable string + Members string } var Cluster = &ClusterCommands{ RoutingTable: "cluster.routingtable", + Members: "cluster.members", } type InternalCommands struct { @@ -89,19 +91,21 @@ var DMap = &DMapCommands{ } type PubSubCommands struct { - Publish string - Subscribe string - PSubscribe string - PubSubChannels string - PubSubNumpat string - PubSubNumsub string + Publish string + PublishInternal string + Subscribe string + PSubscribe string + PubSubChannels string + PubSubNumpat string + PubSubNumsub string } var PubSub = &PubSubCommands{ - Publish: "publish", - Subscribe: "subscribe", - PSubscribe: "psubscribe", - PubSubChannels: "pubsub channels", - PubSubNumpat: "pubsub numpat", - PubSubNumsub: "pubsub numsub", + Publish: "publish", + PublishInternal: "publish.internal", + Subscribe: "subscribe", + PSubscribe: "psubscribe", + PubSubChannels: "pubsub channels", + PubSubNumpat: "pubsub numpat", + PubSubNumsub: "pubsub numsub", } diff --git a/internal/protocol/dmap.go b/internal/protocol/dmap.go index ed81f93f..c8aabcb0 100644 --- a/internal/protocol/dmap.go +++ b/internal/protocol/dmap.go @@ -16,9 +16,15 @@ package protocol import ( "context" + "errors" + "fmt" + "strconv" + "strings" "time" + "github.com/buraksezer/olric/internal/util" "github.com/go-redis/redis/v8" + "github.com/tidwall/redcon" ) type Put struct { @@ -109,6 +115,68 @@ func (p *Put) Command(ctx context.Context) *redis.StatusCmd { return redis.NewStatusCmd(ctx, args...) } +func ParsePutCommand(cmd redcon.Command) (*Put, error) { + if len(cmd.Args) < 4 { + return nil, errWrongNumber(cmd.Args) + } + + p := NewPut( + util.BytesToString(cmd.Args[1]), // DMap + util.BytesToString(cmd.Args[2]), // Key + cmd.Args[3], // Value + ) + + args := cmd.Args[4:] + for len(args) > 0 { + switch arg := strings.ToUpper(util.BytesToString(args[0])); arg { + case "NX": + p.SetNX() + args = args[1:] + continue + case "XX": + p.SetXX() + args = args[1:] + continue + case "PX": + px, err := strconv.ParseInt(util.BytesToString(args[1]), 10, 64) + if err != nil { + return nil, err + } + p.SetPX(px) + args = args[2:] + continue + case "EX": + ex, err := strconv.ParseFloat(util.BytesToString(args[1]), 64) + if err != nil { + return nil, err + } + p.SetEX(ex) + args = args[2:] + continue + case "EXAT": + exat, err := strconv.ParseFloat(util.BytesToString(args[1]), 64) + if err != nil { + return nil, err + } + p.SetEXAT(exat) + args = args[2:] + continue + case "PXAT": + pxat, err := strconv.ParseInt(util.BytesToString(args[1]), 10, 64) + if err != nil { + return nil, err + } + p.SetPXAT(pxat) + args = args[2:] + continue + default: + return nil, errors.New("syntax error") + } + } + + return p, nil +} + type PutEntry struct { DMap string Key string @@ -132,6 +200,18 @@ func (p *PutEntry) Command(ctx context.Context) *redis.StatusCmd { return redis.NewStatusCmd(ctx, args...) } +func ParsePutEntryCommand(cmd redcon.Command) (*PutEntry, error) { + if len(cmd.Args) < 4 { + return nil, errWrongNumber(cmd.Args) + } + + return NewPutEntry( + util.BytesToString(cmd.Args[1]), + util.BytesToString(cmd.Args[2]), + cmd.Args[3], + ), nil +} + type Get struct { DMap string Key string @@ -161,6 +241,28 @@ func (g *Get) Command(ctx context.Context) *redis.StringCmd { return redis.NewStringCmd(ctx, args...) } +func ParseGetCommand(cmd redcon.Command) (*Get, error) { + if len(cmd.Args) < 3 { + return nil, errWrongNumber(cmd.Args) + } + + g := NewGet( + util.BytesToString(cmd.Args[1]), + util.BytesToString(cmd.Args[2]), + ) + + if len(cmd.Args) == 4 { + arg := util.BytesToString(cmd.Args[3]) + if arg == "RW" { + g.SetRaw() + } else { + return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) + } + } + + return g, nil +} + type GetEntry struct { DMap string Key string @@ -190,6 +292,28 @@ func (g *GetEntry) Command(ctx context.Context) *redis.StringCmd { return redis.NewStringCmd(ctx, args...) } +func ParseGetEntryCommand(cmd redcon.Command) (*GetEntry, error) { + if len(cmd.Args) < 2 { + return nil, errWrongNumber(cmd.Args) + } + + g := NewGetEntry( + util.BytesToString(cmd.Args[1]), // DMap + util.BytesToString(cmd.Args[2]), // Key + ) + + if len(cmd.Args) == 4 { + arg := util.BytesToString(cmd.Args[3]) + if arg == "RC" { + g.SetReplica() + } else { + return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) + } + } + + return g, nil +} + type Del struct { DMap string Key string @@ -210,6 +334,17 @@ func (d *Del) Command(ctx context.Context) *redis.IntCmd { return redis.NewIntCmd(ctx, args...) } +func ParseDelCommand(cmd redcon.Command) (*Del, error) { + if len(cmd.Args) < 3 { + return nil, errWrongNumber(cmd.Args) + } + + return NewDel( + util.BytesToString(cmd.Args[1]), + util.BytesToString(cmd.Args[2]), + ), nil +} + type DelEntry struct { Del *Del Replica bool @@ -236,6 +371,28 @@ func (d *DelEntry) Command(ctx context.Context) *redis.IntCmd { return redis.NewIntCmd(ctx, args...) } +func ParseDelEntryCommand(cmd redcon.Command) (*DelEntry, error) { + if len(cmd.Args) < 3 { + return nil, errWrongNumber(cmd.Args) + } + + d := NewDelEntry( + util.BytesToString(cmd.Args[1]), + util.BytesToString(cmd.Args[2]), + ) + + if len(cmd.Args) == 4 { + arg := util.BytesToString(cmd.Args[3]) + if arg == "RC" { + d.SetReplica() + } else { + return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) + } + } + + return d, nil +} + type PExpire struct { DMap string Key string @@ -259,6 +416,24 @@ func (p *PExpire) Command(ctx context.Context) *redis.StatusCmd { return redis.NewStatusCmd(ctx, args...) } +func ParsePExpireCommand(cmd redcon.Command) (*PExpire, error) { + if len(cmd.Args) < 4 { + return nil, errWrongNumber(cmd.Args) + } + + rawMilliseconds := util.BytesToString(cmd.Args[3]) + milliseconds, err := strconv.ParseInt(rawMilliseconds, 10, 64) + if err != nil { + return nil, err + } + p := NewPExpire( + util.BytesToString(cmd.Args[1]), // DMap + util.BytesToString(cmd.Args[2]), // Key + time.Duration(milliseconds*int64(time.Millisecond)), + ) + return p, nil +} + type Expire struct { DMap string Key string @@ -282,6 +457,24 @@ func (e *Expire) Command(ctx context.Context) *redis.StatusCmd { return redis.NewStatusCmd(ctx, args...) } +func ParseExpireCommand(cmd redcon.Command) (*Expire, error) { + if len(cmd.Args) < 4 { + return nil, errWrongNumber(cmd.Args) + } + + rawSeconds := util.BytesToString(cmd.Args[3]) + seconds, err := strconv.ParseFloat(rawSeconds, 64) + if err != nil { + return nil, err + } + e := NewExpire( + util.BytesToString(cmd.Args[1]), // DMap + util.BytesToString(cmd.Args[2]), // Key + time.Duration(seconds*float64(time.Second)), + ) + return e, nil +} + type Destroy struct { DMap string Local bool @@ -308,6 +501,27 @@ func (d *Destroy) Command(ctx context.Context) *redis.StatusCmd { return redis.NewStatusCmd(ctx, args...) } +func ParseDestroyCommand(cmd redcon.Command) (*Destroy, error) { + if len(cmd.Args) < 2 { + return nil, errWrongNumber(cmd.Args) + } + + d := NewDestroy( + util.BytesToString(cmd.Args[1]), + ) + + if len(cmd.Args) == 3 { + arg := util.BytesToString(cmd.Args[2]) + if arg == "LC" { + d.SetLocal() + } else { + return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) + } + } + + return d, nil +} + type Scan struct { PartID uint64 DMap string @@ -360,6 +574,59 @@ func (s *Scan) Command(ctx context.Context) *redis.ScanCmd { return redis.NewScanCmd(ctx, nil, args...) } +const DefaultScanCount = 10 + +func ParseScanCommand(cmd redcon.Command) (*Scan, error) { + if len(cmd.Args) < 4 { + return nil, errWrongNumber(cmd.Args) + } + + rawPartID := util.BytesToString(cmd.Args[1]) + partID, err := strconv.ParseUint(rawPartID, 10, 64) + if err != nil { + return nil, err + } + + rawCursor := util.BytesToString(cmd.Args[3]) + cursor, err := strconv.ParseUint(rawCursor, 10, 64) + if err != nil { + return nil, err + } + + s := NewScan( + partID, + util.BytesToString(cmd.Args[2]), // DMap + cursor, + ) + + args := cmd.Args[4:] + for len(args) > 0 { + switch arg := strings.ToUpper(util.BytesToString(args[0])); arg { + case "MATCH": + s.SetMatch(util.BytesToString(args[1])) + args = args[2:] + continue + case "COUNT": + count, err := strconv.Atoi(util.BytesToString(args[1])) + if err != nil { + return nil, err + } + s.SetCount(count) + args = args[2:] + continue + case "RC": + s.SetReplica() + args = args[1:] + } + } + + if s.Count == 0 { + s.SetCount(DefaultScanCount) + } + + return s, nil +} + type Incr struct { DMap string Key string @@ -383,6 +650,23 @@ func (i *Incr) Command(ctx context.Context) *redis.IntCmd { return redis.NewIntCmd(ctx, args...) } +func ParseIncrCommand(cmd redcon.Command) (*Incr, error) { + if len(cmd.Args) < 4 { + return nil, errWrongNumber(cmd.Args) + } + + delta, err := strconv.Atoi(util.BytesToString(cmd.Args[3])) + if err != nil { + return nil, err + } + + return NewIncr( + util.BytesToString(cmd.Args[1]), + util.BytesToString(cmd.Args[2]), + delta, + ), nil +} + type Decr struct { *Incr } @@ -399,6 +683,23 @@ func (d *Decr) Command(ctx context.Context) *redis.IntCmd { return cmd } +func ParseDecrCommand(cmd redcon.Command) (*Decr, error) { + if len(cmd.Args) < 4 { + return nil, errWrongNumber(cmd.Args) + } + + delta, err := strconv.Atoi(util.BytesToString(cmd.Args[3])) + if err != nil { + return nil, err + } + + return NewDecr( + util.BytesToString(cmd.Args[1]), + util.BytesToString(cmd.Args[2]), + delta, + ), nil +} + type GetPut struct { DMap string Key string @@ -431,6 +732,28 @@ func (g *GetPut) Command(ctx context.Context) *redis.StringCmd { return redis.NewStringCmd(ctx, args...) } +func ParseGetPutCommand(cmd redcon.Command) (*GetPut, error) { + if len(cmd.Args) < 4 { + return nil, errWrongNumber(cmd.Args) + } + + g := NewGetPut( + util.BytesToString(cmd.Args[1]), // DMap + util.BytesToString(cmd.Args[2]), // Key + cmd.Args[3], // Value + ) + + if len(cmd.Args) == 5 { + arg := util.BytesToString(cmd.Args[4]) + if arg == "RW" { + g.SetRaw() + } else { + return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) + } + } + return g, nil +} + // TODO: Add PLock type Lock struct { @@ -480,6 +803,49 @@ func (l *Lock) Command(ctx context.Context) *redis.StringCmd { return redis.NewStringCmd(ctx, args...) } +func ParseLockCommand(cmd redcon.Command) (*Lock, error) { + if len(cmd.Args) < 4 { + return nil, errWrongNumber(cmd.Args) + } + + deadline, err := strconv.ParseFloat(util.BytesToString(cmd.Args[3]), 64) + if err != nil { + return nil, err + } + + l := NewLock( + util.BytesToString(cmd.Args[1]), // DMap + util.BytesToString(cmd.Args[2]), // Key + deadline, // Deadline + ) + + // EX or PX are optional. + if len(cmd.Args) > 4 { + if len(cmd.Args) == 5 { + return nil, fmt.Errorf("%w: %s needs a numerical argument", ErrInvalidArgument, util.BytesToString(cmd.Args[5])) + } + + switch arg := strings.ToUpper(util.BytesToString(cmd.Args[4])); arg { + case "PX": + px, err := strconv.ParseInt(util.BytesToString(cmd.Args[5]), 10, 64) + if err != nil { + return nil, err + } + l.PX = px + case "EX": + ex, err := strconv.ParseFloat(util.BytesToString(cmd.Args[5]), 64) + if err != nil { + return nil, err + } + l.EX = ex + default: + return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) + } + } + + return l, nil +} + type Unlock struct { DMap string Key string @@ -503,7 +869,17 @@ func (u *Unlock) Command(ctx context.Context) *redis.StatusCmd { return redis.NewStatusCmd(ctx, args...) } -// TODO: Add PLockLease +func ParseUnlockCommand(cmd redcon.Command) (*Unlock, error) { + if len(cmd.Args) < 4 { + return nil, errWrongNumber(cmd.Args) + } + + return NewUnlock( + util.BytesToString(cmd.Args[1]), // DMap + util.BytesToString(cmd.Args[2]), // Key + util.BytesToString(cmd.Args[3]), // Token + ), nil +} type LockLease struct { DMap string @@ -531,6 +907,24 @@ func (l *LockLease) Command(ctx context.Context) *redis.StatusCmd { return redis.NewStatusCmd(ctx, args...) } +func ParseLockLeaseCommand(cmd redcon.Command) (*LockLease, error) { + if len(cmd.Args) < 5 { + return nil, errWrongNumber(cmd.Args) + } + + timeout, err := strconv.ParseFloat(util.BytesToString(cmd.Args[4]), 64) + if err != nil { + return nil, err + } + + return NewLockLease( + util.BytesToString(cmd.Args[1]), // DMap + util.BytesToString(cmd.Args[2]), // Key + util.BytesToString(cmd.Args[3]), // Token + timeout, // Timeout + ), nil +} + type PLockLease struct { DMap string Key string @@ -556,3 +950,21 @@ func (p *PLockLease) Command(ctx context.Context) *redis.StatusCmd { args = append(args, p.Timeout) return redis.NewStatusCmd(ctx, args...) } + +func ParsePLockLeaseCommand(cmd redcon.Command) (*PLockLease, error) { + if len(cmd.Args) < 5 { + return nil, errWrongNumber(cmd.Args) + } + + timeout, err := strconv.ParseInt(util.BytesToString(cmd.Args[4]), 10, 64) + if err != nil { + return nil, err + } + + return NewPLockLease( + util.BytesToString(cmd.Args[1]), // DMap + util.BytesToString(cmd.Args[2]), // Key + util.BytesToString(cmd.Args[3]), // Token + timeout, // Timeout + ), nil +} diff --git a/internal/protocol/dmap_parser.go b/internal/protocol/dmap_parser.go deleted file mode 100644 index 6effd9be..00000000 --- a/internal/protocol/dmap_parser.go +++ /dev/null @@ -1,434 +0,0 @@ -// Copyright 2018-2022 Burak Sezer -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package protocol - -import ( - "errors" - "fmt" - "strconv" - "strings" - "time" - - "github.com/buraksezer/olric/internal/util" - "github.com/tidwall/redcon" -) - -var ErrInvalidArgument = errors.New("invalid argument") - -func ParsePutCommand(cmd redcon.Command) (*Put, error) { - if len(cmd.Args) < 4 { - return nil, errWrongNumber(cmd.Args) - } - - p := NewPut( - util.BytesToString(cmd.Args[1]), // DMap - util.BytesToString(cmd.Args[2]), // Key - cmd.Args[3], // Value - ) - - args := cmd.Args[4:] - for len(args) > 0 { - switch arg := strings.ToUpper(util.BytesToString(args[0])); arg { - case "NX": - p.SetNX() - args = args[1:] - continue - case "XX": - p.SetXX() - args = args[1:] - continue - case "PX": - px, err := strconv.ParseInt(util.BytesToString(args[1]), 10, 64) - if err != nil { - return nil, err - } - p.SetPX(px) - args = args[2:] - continue - case "EX": - ex, err := strconv.ParseFloat(util.BytesToString(args[1]), 64) - if err != nil { - return nil, err - } - p.SetEX(ex) - args = args[2:] - continue - case "EXAT": - exat, err := strconv.ParseFloat(util.BytesToString(args[1]), 64) - if err != nil { - return nil, err - } - p.SetEXAT(exat) - args = args[2:] - continue - case "PXAT": - pxat, err := strconv.ParseInt(util.BytesToString(args[1]), 10, 64) - if err != nil { - return nil, err - } - p.SetPXAT(pxat) - args = args[2:] - continue - default: - return nil, errors.New("syntax error") - } - } - - return p, nil -} - -func ParseGetEntryCommand(cmd redcon.Command) (*GetEntry, error) { - if len(cmd.Args) < 2 { - return nil, errWrongNumber(cmd.Args) - } - - g := NewGetEntry( - util.BytesToString(cmd.Args[1]), // DMap - util.BytesToString(cmd.Args[2]), // Key - ) - - if len(cmd.Args) == 4 { - arg := util.BytesToString(cmd.Args[3]) - if arg == "RC" { - g.SetReplica() - } else { - return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) - } - } - - return g, nil -} - -func ParseGetCommand(cmd redcon.Command) (*Get, error) { - if len(cmd.Args) < 3 { - return nil, errWrongNumber(cmd.Args) - } - - g := NewGet( - util.BytesToString(cmd.Args[1]), - util.BytesToString(cmd.Args[2]), - ) - - if len(cmd.Args) == 4 { - arg := util.BytesToString(cmd.Args[3]) - if arg == "RW" { - g.SetRaw() - } else { - return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) - } - } - - return g, nil -} - -func ParsePutEntryCommand(cmd redcon.Command) (*PutEntry, error) { - if len(cmd.Args) < 4 { - return nil, errWrongNumber(cmd.Args) - } - - return NewPutEntry( - util.BytesToString(cmd.Args[1]), - util.BytesToString(cmd.Args[2]), - cmd.Args[3], - ), nil -} - -func ParseDelCommand(cmd redcon.Command) (*Del, error) { - if len(cmd.Args) < 3 { - return nil, errWrongNumber(cmd.Args) - } - - return NewDel( - util.BytesToString(cmd.Args[1]), - util.BytesToString(cmd.Args[2]), - ), nil -} - -func ParseDelEntryCommand(cmd redcon.Command) (*DelEntry, error) { - if len(cmd.Args) < 3 { - return nil, errWrongNumber(cmd.Args) - } - - d := NewDelEntry( - util.BytesToString(cmd.Args[1]), - util.BytesToString(cmd.Args[2]), - ) - - if len(cmd.Args) == 4 { - arg := util.BytesToString(cmd.Args[3]) - if arg == "RC" { - d.SetReplica() - } else { - return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) - } - } - - return d, nil -} - -func ParseExpireCommand(cmd redcon.Command) (*Expire, error) { - if len(cmd.Args) < 4 { - return nil, errWrongNumber(cmd.Args) - } - - rawSeconds := util.BytesToString(cmd.Args[3]) - seconds, err := strconv.ParseFloat(rawSeconds, 64) - if err != nil { - return nil, err - } - e := NewExpire( - util.BytesToString(cmd.Args[1]), // DMap - util.BytesToString(cmd.Args[2]), // Key - time.Duration(seconds*float64(time.Second)), - ) - return e, nil -} - -func ParsePExpireCommand(cmd redcon.Command) (*PExpire, error) { - if len(cmd.Args) < 4 { - return nil, errWrongNumber(cmd.Args) - } - - rawMilliseconds := util.BytesToString(cmd.Args[3]) - milliseconds, err := strconv.ParseInt(rawMilliseconds, 10, 64) - if err != nil { - return nil, err - } - p := NewPExpire( - util.BytesToString(cmd.Args[1]), // DMap - util.BytesToString(cmd.Args[2]), // Key - time.Duration(milliseconds*int64(time.Millisecond)), - ) - return p, nil -} - -func ParseDestroyCommand(cmd redcon.Command) (*Destroy, error) { - if len(cmd.Args) < 2 { - return nil, errWrongNumber(cmd.Args) - } - - d := NewDestroy( - util.BytesToString(cmd.Args[1]), - ) - - if len(cmd.Args) == 3 { - arg := util.BytesToString(cmd.Args[2]) - if arg == "LC" { - d.SetLocal() - } else { - return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) - } - } - - return d, nil -} - -func ParseIncrCommand(cmd redcon.Command) (*Incr, error) { - if len(cmd.Args) < 4 { - return nil, errWrongNumber(cmd.Args) - } - - delta, err := strconv.Atoi(util.BytesToString(cmd.Args[3])) - if err != nil { - return nil, err - } - - return NewIncr( - util.BytesToString(cmd.Args[1]), - util.BytesToString(cmd.Args[2]), - delta, - ), nil -} - -func ParseDecrCommand(cmd redcon.Command) (*Decr, error) { - if len(cmd.Args) < 4 { - return nil, errWrongNumber(cmd.Args) - } - - delta, err := strconv.Atoi(util.BytesToString(cmd.Args[3])) - if err != nil { - return nil, err - } - - return NewDecr( - util.BytesToString(cmd.Args[1]), - util.BytesToString(cmd.Args[2]), - delta, - ), nil -} - -func ParseGetPutCommand(cmd redcon.Command) (*GetPut, error) { - if len(cmd.Args) < 4 { - return nil, errWrongNumber(cmd.Args) - } - - g := NewGetPut( - util.BytesToString(cmd.Args[1]), // DMap - util.BytesToString(cmd.Args[2]), // Key - cmd.Args[3], // Value - ) - - if len(cmd.Args) == 5 { - arg := util.BytesToString(cmd.Args[4]) - if arg == "RW" { - g.SetRaw() - } else { - return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) - } - } - return g, nil -} - -func ParseLockCommand(cmd redcon.Command) (*Lock, error) { - if len(cmd.Args) < 4 { - return nil, errWrongNumber(cmd.Args) - } - - deadline, err := strconv.ParseFloat(util.BytesToString(cmd.Args[3]), 64) - if err != nil { - return nil, err - } - - l := NewLock( - util.BytesToString(cmd.Args[1]), // DMap - util.BytesToString(cmd.Args[2]), // Key - deadline, // Deadline - ) - - // EX or PX are optional. - if len(cmd.Args) > 4 { - if len(cmd.Args) == 5 { - return nil, fmt.Errorf("%w: %s needs a numerical argument", ErrInvalidArgument, util.BytesToString(cmd.Args[5])) - } - - switch arg := strings.ToUpper(util.BytesToString(cmd.Args[4])); arg { - case "PX": - px, err := strconv.ParseInt(util.BytesToString(cmd.Args[5]), 10, 64) - if err != nil { - return nil, err - } - l.PX = px - case "EX": - ex, err := strconv.ParseFloat(util.BytesToString(cmd.Args[5]), 64) - if err != nil { - return nil, err - } - l.EX = ex - default: - return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) - } - } - - return l, nil -} - -func ParseUnlockCommand(cmd redcon.Command) (*Unlock, error) { - if len(cmd.Args) < 4 { - return nil, errWrongNumber(cmd.Args) - } - - return NewUnlock( - util.BytesToString(cmd.Args[1]), // DMap - util.BytesToString(cmd.Args[2]), // Key - util.BytesToString(cmd.Args[3]), // Token - ), nil -} - -func ParseLockLeaseCommand(cmd redcon.Command) (*LockLease, error) { - if len(cmd.Args) < 5 { - return nil, errWrongNumber(cmd.Args) - } - - timeout, err := strconv.ParseFloat(util.BytesToString(cmd.Args[4]), 64) - if err != nil { - return nil, err - } - - return NewLockLease( - util.BytesToString(cmd.Args[1]), // DMap - util.BytesToString(cmd.Args[2]), // Key - util.BytesToString(cmd.Args[3]), // Token - timeout, // Timeout - ), nil -} - -func ParsePLockLeaseCommand(cmd redcon.Command) (*PLockLease, error) { - if len(cmd.Args) < 5 { - return nil, errWrongNumber(cmd.Args) - } - - timeout, err := strconv.ParseInt(util.BytesToString(cmd.Args[4]), 10, 64) - if err != nil { - return nil, err - } - - return NewPLockLease( - util.BytesToString(cmd.Args[1]), // DMap - util.BytesToString(cmd.Args[2]), // Key - util.BytesToString(cmd.Args[3]), // Token - timeout, // Timeout - ), nil -} - -func ParseScanCommand(cmd redcon.Command) (*Scan, error) { - if len(cmd.Args) < 4 { - return nil, errWrongNumber(cmd.Args) - } - - rawPartID := util.BytesToString(cmd.Args[1]) - partID, err := strconv.ParseUint(rawPartID, 10, 64) - if err != nil { - return nil, err - } - - rawCursor := util.BytesToString(cmd.Args[3]) - cursor, err := strconv.ParseUint(rawCursor, 10, 64) - if err != nil { - return nil, err - } - - s := NewScan( - partID, - util.BytesToString(cmd.Args[2]), // DMap - cursor, - ) - - args := cmd.Args[4:] - for len(args) > 0 { - switch arg := strings.ToUpper(util.BytesToString(args[0])); arg { - case "MATCH": - s.SetMatch(util.BytesToString(args[1])) - args = args[2:] - continue - case "COUNT": - count, err := strconv.Atoi(util.BytesToString(args[1])) - if err != nil { - return nil, err - } - s.SetCount(count) - args = args[2:] - continue - case "RC": - s.SetReplica() - args = args[1:] - } - } - - if s.Count == 0 { - s.SetCount(10) - } - - return s, nil -} diff --git a/internal/protocol/dmap_parser_test.go b/internal/protocol/dmap_parser_test.go deleted file mode 100644 index 2547cf6b..00000000 --- a/internal/protocol/dmap_parser_test.go +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright 2018-2022 Burak Sezer -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package protocol - -import ( - "context" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/require" - "github.com/tidwall/redcon" -) - -func stringToCommand(s string) redcon.Command { - cmd := redcon.Command{ - Raw: []byte(s), - } - - s = strings.TrimSuffix(s, ": []") - s = strings.TrimSuffix(s, ": 0") - s = strings.TrimSuffix(s, ":") - s = strings.TrimSuffix(s, ": ") - parsed := strings.Split(s, " ") - for _, arg := range parsed { - cmd.Args = append(cmd.Args, []byte(arg)) - } - return cmd -} - -func TestProtocol_ParsePutCommand_EX(t *testing.T) { - putCmd := NewPut("my-dmap", "my-key", []byte("my-value")) - putCmd.SetEX((10 * time.Second).Seconds()) - - cmd := stringToCommand(putCmd.Command(context.Background()).String()) - parsed, err := ParsePutCommand(cmd) - require.NoError(t, err) - - require.Equal(t, "my-dmap", parsed.DMap) - require.Equal(t, "my-key", parsed.Key) - require.Equal(t, []byte("my-value"), parsed.Value) - require.Equal(t, float64(10), parsed.EX) -} - -func TestProtocol_ParsePutCommand_PX(t *testing.T) { - putCmd := NewPut("my-dmap", "my-key", []byte("my-value")) - putCmd.SetPX((100 * time.Millisecond).Milliseconds()) - - cmd := stringToCommand(putCmd.Command(context.Background()).String()) - parsed, err := ParsePutCommand(cmd) - require.NoError(t, err) - - require.Equal(t, "my-dmap", parsed.DMap) - require.Equal(t, "my-key", parsed.Key) - require.Equal(t, []byte("my-value"), parsed.Value) - require.Equal(t, int64(100), parsed.PX) -} - -func TestProtocol_ParsePutCommand_NX(t *testing.T) { - putCmd := NewPut("my-dmap", "my-key", []byte("my-value")) - putCmd.SetNX() - - cmd := stringToCommand(putCmd.Command(context.Background()).String()) - parsed, err := ParsePutCommand(cmd) - require.NoError(t, err) - - require.Equal(t, "my-dmap", parsed.DMap) - require.Equal(t, "my-key", parsed.Key) - require.Equal(t, []byte("my-value"), parsed.Value) - require.True(t, parsed.NX) - require.False(t, parsed.XX) -} - -func TestProtocol_ParsePutCommand_XX(t *testing.T) { - putCmd := NewPut("my-dmap", "my-key", []byte("my-value")) - putCmd.SetXX() - - cmd := stringToCommand(putCmd.Command(context.Background()).String()) - parsed, err := ParsePutCommand(cmd) - require.NoError(t, err) - - require.Equal(t, "my-dmap", parsed.DMap) - require.Equal(t, "my-key", parsed.Key) - require.Equal(t, []byte("my-value"), parsed.Value) - require.True(t, parsed.XX) - require.False(t, parsed.NX) -} - -func TestProtocol_ParsePutCommand_EXAT(t *testing.T) { - putCmd := NewPut("my-dmap", "my-key", []byte("my-value")) - exat := float64(time.Now().Unix()) + 10 - putCmd.SetEXAT(exat) - - cmd := stringToCommand(putCmd.Command(context.Background()).String()) - parsed, err := ParsePutCommand(cmd) - require.NoError(t, err) - - require.Equal(t, "my-dmap", parsed.DMap) - require.Equal(t, "my-key", parsed.Key) - require.Equal(t, []byte("my-value"), parsed.Value) - require.Equal(t, exat, parsed.EXAT) -} - -func TestProtocol_ParsePutCommand_PXAT(t *testing.T) { - putCmd := NewPut("my-dmap", "my-key", []byte("my-value")) - pxat := (time.Now().UnixNano() / 1000000) + 10 - putCmd.SetPXAT(pxat) - - cmd := stringToCommand(putCmd.Command(context.Background()).String()) - parsed, err := ParsePutCommand(cmd) - require.NoError(t, err) - - require.Equal(t, "my-dmap", parsed.DMap) - require.Equal(t, "my-key", parsed.Key) - require.Equal(t, []byte("my-value"), parsed.Value) - require.Equal(t, pxat, parsed.PXAT) -} - -func TestProtocol_ParseScanCommand(t *testing.T) { - scanCmd := NewScan(1, "my-dmap", 0) - - s := scanCmd.Command(context.Background()).String() - s = strings.TrimSuffix(s, ": []") - cmd := stringToCommand(s) - parsed, err := ParseScanCommand(cmd) - require.NoError(t, err) - require.Equal(t, "my-dmap", parsed.DMap) - require.Equal(t, "", parsed.Match) - require.Equal(t, 10, parsed.Count) - require.False(t, scanCmd.Replica) -} - -func TestProtocol_ParseScanCommand_Replica(t *testing.T) { - scanCmd := NewScan(1, "my-dmap", 0).SetReplica() - - s := scanCmd.Command(context.Background()).String() - s = strings.TrimSuffix(s, ": []") - cmd := stringToCommand(s) - parsed, err := ParseScanCommand(cmd) - require.NoError(t, err) - require.Equal(t, "my-dmap", parsed.DMap) - require.Equal(t, "", parsed.Match) - require.Equal(t, 10, parsed.Count) - require.True(t, scanCmd.Replica) -} - -func TestProtocol_ParseScanCommand_Match(t *testing.T) { - scanCmd := NewScan(1, "my-dmap", 0).SetMatch("^even") - - s := scanCmd.Command(context.Background()).String() - s = strings.TrimSuffix(s, ": []") - cmd := stringToCommand(s) - parsed, err := ParseScanCommand(cmd) - require.NoError(t, err) - require.Equal(t, "my-dmap", parsed.DMap) - require.Equal(t, uint64(1), parsed.PartID) - require.Equal(t, "^even", parsed.Match) - require.Equal(t, 10, parsed.Count) - require.False(t, scanCmd.Replica) -} - -func TestProtocol_ParseScanCommand_PartID(t *testing.T) { - scanCmd := NewScan(1, "my-dmap", 0).SetCount(200) - - s := scanCmd.Command(context.Background()).String() - s = strings.TrimSuffix(s, ": []") - cmd := stringToCommand(s) - parsed, err := ParseScanCommand(cmd) - require.NoError(t, err) - require.Equal(t, "my-dmap", parsed.DMap) - require.Equal(t, uint64(1), parsed.PartID) - require.Equal(t, "", parsed.Match) - require.Equal(t, 200, parsed.Count) - require.False(t, scanCmd.Replica) -} - -func TestProtocol_ParseScanCommand_Match_Count(t *testing.T) { - scanCmd := NewScan(1, "my-dmap", 0).SetCount(100).SetMatch("^even") - - s := scanCmd.Command(context.Background()).String() - s = strings.TrimSuffix(s, ": []") - cmd := stringToCommand(s) - parsed, err := ParseScanCommand(cmd) - require.NoError(t, err) - require.Equal(t, "my-dmap", parsed.DMap) - require.Equal(t, uint64(1), parsed.PartID) - require.Equal(t, "^even", parsed.Match) - require.Equal(t, 100, parsed.Count) - require.False(t, scanCmd.Replica) -} - -func TestProtocol_ParseScanCommand_Match_Count_Replica(t *testing.T) { - scanCmd := NewScan(1, "my-dmap", 0). - SetCount(100). - SetMatch("^even"). - SetReplica() - - s := scanCmd.Command(context.Background()).String() - s = strings.TrimSuffix(s, ": []") - cmd := stringToCommand(s) - parsed, err := ParseScanCommand(cmd) - require.NoError(t, err) - require.Equal(t, "my-dmap", parsed.DMap) - require.Equal(t, uint64(1), parsed.PartID) - require.Equal(t, "^even", parsed.Match) - require.Equal(t, 100, parsed.Count) - require.True(t, scanCmd.Replica) -} diff --git a/internal/protocol/dmap_test.go b/internal/protocol/dmap_test.go index dab342ae..450b7645 100644 --- a/internal/protocol/dmap_test.go +++ b/internal/protocol/dmap_test.go @@ -13,3 +13,519 @@ // limitations under the License. package protocol + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/tidwall/redcon" +) + +func stringToCommand(s string) redcon.Command { + cmd := redcon.Command{ + Raw: []byte(s), + } + + s = strings.TrimSuffix(s, ": []") + s = strings.TrimSuffix(s, ": 0") + s = strings.TrimSuffix(s, ":") + s = strings.TrimSuffix(s, ": ") + parsed := strings.Split(s, " ") + for _, arg := range parsed { + cmd.Args = append(cmd.Args, []byte(arg)) + } + return cmd +} + +func TestProtocol_ParsePutCommand_EX(t *testing.T) { + putCmd := NewPut("my-dmap", "my-key", []byte("my-value")) + putCmd.SetEX((10 * time.Second).Seconds()) + + cmd := stringToCommand(putCmd.Command(context.Background()).String()) + parsed, err := ParsePutCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, []byte("my-value"), parsed.Value) + require.Equal(t, float64(10), parsed.EX) +} + +func TestProtocol_ParsePutCommand_PX(t *testing.T) { + putCmd := NewPut("my-dmap", "my-key", []byte("my-value")) + putCmd.SetPX((100 * time.Millisecond).Milliseconds()) + + cmd := stringToCommand(putCmd.Command(context.Background()).String()) + parsed, err := ParsePutCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, []byte("my-value"), parsed.Value) + require.Equal(t, int64(100), parsed.PX) +} + +func TestProtocol_ParsePutCommand_NX(t *testing.T) { + putCmd := NewPut("my-dmap", "my-key", []byte("my-value")) + putCmd.SetNX() + + cmd := stringToCommand(putCmd.Command(context.Background()).String()) + parsed, err := ParsePutCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, []byte("my-value"), parsed.Value) + require.True(t, parsed.NX) + require.False(t, parsed.XX) +} + +func TestProtocol_ParsePutCommand_XX(t *testing.T) { + putCmd := NewPut("my-dmap", "my-key", []byte("my-value")) + putCmd.SetXX() + + cmd := stringToCommand(putCmd.Command(context.Background()).String()) + parsed, err := ParsePutCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, []byte("my-value"), parsed.Value) + require.True(t, parsed.XX) + require.False(t, parsed.NX) +} + +func TestProtocol_ParsePutCommand_EXAT(t *testing.T) { + putCmd := NewPut("my-dmap", "my-key", []byte("my-value")) + exat := float64(time.Now().Unix()) + 10 + putCmd.SetEXAT(exat) + + cmd := stringToCommand(putCmd.Command(context.Background()).String()) + parsed, err := ParsePutCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, []byte("my-value"), parsed.Value) + require.Equal(t, exat, parsed.EXAT) +} + +func TestProtocol_ParsePutCommand_PXAT(t *testing.T) { + putCmd := NewPut("my-dmap", "my-key", []byte("my-value")) + pxat := (time.Now().UnixNano() / 1000000) + 10 + putCmd.SetPXAT(pxat) + + cmd := stringToCommand(putCmd.Command(context.Background()).String()) + parsed, err := ParsePutCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, []byte("my-value"), parsed.Value) + require.Equal(t, pxat, parsed.PXAT) +} + +func TestProtocol_ParseScanCommand(t *testing.T) { + scanCmd := NewScan(1, "my-dmap", 0) + + s := scanCmd.Command(context.Background()).String() + s = strings.TrimSuffix(s, ": []") + cmd := stringToCommand(s) + parsed, err := ParseScanCommand(cmd) + require.NoError(t, err) + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "", parsed.Match) + require.Equal(t, 10, parsed.Count) + require.False(t, scanCmd.Replica) +} + +func TestProtocol_ParseScanCommand_Replica(t *testing.T) { + scanCmd := NewScan(1, "my-dmap", 0).SetReplica() + + s := scanCmd.Command(context.Background()).String() + s = strings.TrimSuffix(s, ": []") + cmd := stringToCommand(s) + parsed, err := ParseScanCommand(cmd) + require.NoError(t, err) + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "", parsed.Match) + require.Equal(t, 10, parsed.Count) + require.True(t, scanCmd.Replica) +} + +func TestProtocol_ParseScanCommand_Match(t *testing.T) { + scanCmd := NewScan(1, "my-dmap", 0).SetMatch("^even") + + s := scanCmd.Command(context.Background()).String() + s = strings.TrimSuffix(s, ": []") + cmd := stringToCommand(s) + parsed, err := ParseScanCommand(cmd) + require.NoError(t, err) + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, uint64(1), parsed.PartID) + require.Equal(t, "^even", parsed.Match) + require.Equal(t, 10, parsed.Count) + require.False(t, scanCmd.Replica) +} + +func TestProtocol_ParseScanCommand_PartID(t *testing.T) { + scanCmd := NewScan(1, "my-dmap", 0).SetCount(200) + + s := scanCmd.Command(context.Background()).String() + s = strings.TrimSuffix(s, ": []") + cmd := stringToCommand(s) + parsed, err := ParseScanCommand(cmd) + require.NoError(t, err) + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, uint64(1), parsed.PartID) + require.Equal(t, "", parsed.Match) + require.Equal(t, 200, parsed.Count) + require.False(t, scanCmd.Replica) +} + +func TestProtocol_ParseScanCommand_Match_Count(t *testing.T) { + scanCmd := NewScan(1, "my-dmap", 0).SetCount(100).SetMatch("^even") + + s := scanCmd.Command(context.Background()).String() + s = strings.TrimSuffix(s, ": []") + cmd := stringToCommand(s) + parsed, err := ParseScanCommand(cmd) + require.NoError(t, err) + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, uint64(1), parsed.PartID) + require.Equal(t, "^even", parsed.Match) + require.Equal(t, 100, parsed.Count) + require.False(t, scanCmd.Replica) +} + +func TestProtocol_ParseScanCommand_Match_Count_Replica(t *testing.T) { + scanCmd := NewScan(1, "my-dmap", 0). + SetCount(100). + SetMatch("^even"). + SetReplica() + + s := scanCmd.Command(context.Background()).String() + s = strings.TrimSuffix(s, ": []") + cmd := stringToCommand(s) + parsed, err := ParseScanCommand(cmd) + require.NoError(t, err) + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, uint64(1), parsed.PartID) + require.Equal(t, "^even", parsed.Match) + require.Equal(t, 100, parsed.Count) + require.True(t, scanCmd.Replica) +} + +func TestProtocol_PutEntry(t *testing.T) { + putEntryCmd := NewPutEntry("my-dmap", "my-key", []byte("my-value")) + + cmd := stringToCommand(putEntryCmd.Command(context.Background()).String()) + parsed, err := ParsePutEntryCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, []byte("my-value"), parsed.Value) +} + +func TestProtocol_Get(t *testing.T) { + getCmd := NewGet("my-dmap", "my-key") + + cmd := stringToCommand(getCmd.Command(context.Background()).String()) + parsed, err := ParseGetCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.False(t, parsed.Raw) +} + +func TestProtocol_Get_RW(t *testing.T) { + getCmd := NewGet("my-dmap", "my-key") + getCmd.SetRaw() + + cmd := stringToCommand(getCmd.Command(context.Background()).String()) + parsed, err := ParseGetCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.True(t, parsed.Raw) +} + +func TestProtocol_GetEntry(t *testing.T) { + getEntryCmd := NewGetEntry("my-dmap", "my-key") + + cmd := stringToCommand(getEntryCmd.Command(context.Background()).String()) + parsed, err := ParseGetEntryCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.False(t, parsed.Replica) +} + +func TestProtocol_GetEntry_RC(t *testing.T) { + getEntryCmd := NewGetEntry("my-dmap", "my-key") + getEntryCmd.SetReplica() + + cmd := stringToCommand(getEntryCmd.Command(context.Background()).String()) + parsed, err := ParseGetEntryCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.True(t, parsed.Replica) +} + +func TestProtocol_Del(t *testing.T) { + delCmd := NewDel("my-dmap", "my-key") + + cmd := stringToCommand(delCmd.Command(context.Background()).String()) + parsed, err := ParseDelCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) +} + +func TestProtocol_DelEntry(t *testing.T) { + delEntryCmd := NewDelEntry("my-dmap", "my-key") + + cmd := stringToCommand(delEntryCmd.Command(context.Background()).String()) + parsed, err := ParseDelEntryCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.Del.DMap) + require.Equal(t, "my-key", parsed.Del.Key) + require.False(t, parsed.Replica) +} + +func TestProtocol_DelEntry_RC(t *testing.T) { + delEntryCmd := NewDelEntry("my-dmap", "my-key") + delEntryCmd.SetReplica() + + cmd := stringToCommand(delEntryCmd.Command(context.Background()).String()) + parsed, err := ParseDelEntryCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.Del.DMap) + require.Equal(t, "my-key", parsed.Del.Key) + require.True(t, parsed.Replica) +} + +func TestProtocol_PExpire(t *testing.T) { + pexpireCmd := NewPExpire("my-dmap", "my-key", 10*time.Millisecond) + + cmd := stringToCommand(pexpireCmd.Command(context.Background()).String()) + parsed, err := ParsePExpireCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, 10*time.Millisecond, parsed.Milliseconds) +} + +func TestProtocol_Expire(t *testing.T) { + pexpireCmd := NewExpire("my-dmap", "my-key", 10*time.Second) + + cmd := stringToCommand(pexpireCmd.Command(context.Background()).String()) + parsed, err := ParseExpireCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, 10*time.Second, parsed.Seconds) +} + +func TestProtocol_Destroy(t *testing.T) { + destroyCmd := NewDestroy("my-dmap") + + cmd := stringToCommand(destroyCmd.Command(context.Background()).String()) + parsed, err := ParseDestroyCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.False(t, parsed.Local) +} + +func TestProtocol_Destroy_Local(t *testing.T) { + destroyCmd := NewDestroy("my-dmap") + destroyCmd.SetLocal() + + cmd := stringToCommand(destroyCmd.Command(context.Background()).String()) + parsed, err := ParseDestroyCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.True(t, parsed.Local) +} + +func TestProtocol_Incr(t *testing.T) { + incrCmd := NewIncr("my-dmap", "my-key", 7) + + cmd := stringToCommand(incrCmd.Command(context.Background()).String()) + parsed, err := ParseIncrCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, 7, parsed.Delta) +} + +func TestProtocol_Decr(t *testing.T) { + decrCmd := NewDecr("my-dmap", "my-key", 7) + + cmd := stringToCommand(decrCmd.Command(context.Background()).String()) + parsed, err := ParseDecrCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, 7, parsed.Delta) +} + +func TestProtocol_GetPut(t *testing.T) { + getputCmd := NewGetPut("my-dmap", "my-key", []byte("my-value")) + + cmd := stringToCommand(getputCmd.Command(context.Background()).String()) + parsed, err := ParseGetPutCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, []byte("my-value"), parsed.Value) + require.False(t, parsed.Raw) +} + +func TestProtocol_GetPut_RW(t *testing.T) { + getputCmd := NewGetPut("my-dmap", "my-key", []byte("my-value")) + getputCmd.SetRaw() + + cmd := stringToCommand(getputCmd.Command(context.Background()).String()) + parsed, err := ParseGetPutCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, []byte("my-value"), parsed.Value) + require.True(t, parsed.Raw) +} + +func TestProtocol_Lock(t *testing.T) { + lockCmd := NewLock("my-dmap", "my-key", 7) + + cmd := stringToCommand(lockCmd.Command(context.Background()).String()) + parsed, err := ParseLockCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, float64(7), parsed.Deadline) +} + +func TestProtocol_Lock_EX(t *testing.T) { + exDuration := (250 * time.Second).Seconds() + lockCmd := NewLock("my-dmap", "my-key", 7) + lockCmd.SetEX(exDuration) + + cmd := stringToCommand(lockCmd.Command(context.Background()).String()) + parsed, err := ParseLockCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, float64(7), parsed.Deadline) + require.Equal(t, exDuration, parsed.EX) +} + +func TestProtocol_Lock_PX(t *testing.T) { + pxDuration := (250 * time.Millisecond).Milliseconds() + lockCmd := NewLock("my-dmap", "my-key", 7) + lockCmd.SetPX(pxDuration) + + cmd := stringToCommand(lockCmd.Command(context.Background()).String()) + parsed, err := ParseLockCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, float64(7), parsed.Deadline) + require.Equal(t, pxDuration, parsed.PX) +} + +func TestProtocol_Unlock(t *testing.T) { + unlockCmd := NewUnlock("my-dmap", "my-key", "token") + + cmd := stringToCommand(unlockCmd.Command(context.Background()).String()) + parsed, err := ParseUnlockCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, "token", parsed.Token) +} + +func TestProtocol_LockLease(t *testing.T) { + timeout := (7 * time.Second).Seconds() + unlockCmd := NewLockLease("my-dmap", "my-key", "token", timeout) + + cmd := stringToCommand(unlockCmd.Command(context.Background()).String()) + parsed, err := ParseLockLeaseCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, "token", parsed.Token) + require.Equal(t, timeout, parsed.Timeout) +} + +func TestProtocol_PLockLease(t *testing.T) { + timeout := (250 * time.Millisecond).Milliseconds() + plockleaseCmd := NewPLockLease("my-dmap", "my-key", "token", timeout) + + cmd := stringToCommand(plockleaseCmd.Command(context.Background()).String()) + parsed, err := ParsePLockLeaseCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, "my-key", parsed.Key) + require.Equal(t, "token", parsed.Token) + require.Equal(t, timeout, parsed.Timeout) +} + +func TestProtocol_Scan(t *testing.T) { + scanCmd := NewScan(17, "my-dmap", 234) + + cmd := stringToCommand(scanCmd.Command(context.Background()).String()) + parsed, err := ParseScanCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, uint64(17), parsed.PartID) + require.Equal(t, uint64(234), parsed.Cursor) + require.False(t, parsed.Replica) + require.Equal(t, DefaultScanCount, parsed.Count) + require.Equal(t, "", parsed.Match) +} + +func TestProtocol_Scan_Count_Match_Replica(t *testing.T) { + scanCmd := NewScan(17, "my-dmap", 234) + scanCmd.SetCount(123) + scanCmd.SetMatch("^even:") + scanCmd.SetReplica() + + cmd := stringToCommand(scanCmd.Command(context.Background()).String()) + parsed, err := ParseScanCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-dmap", parsed.DMap) + require.Equal(t, uint64(17), parsed.PartID) + require.Equal(t, uint64(234), parsed.Cursor) + require.True(t, parsed.Replica) + require.Equal(t, 123, parsed.Count) + require.Equal(t, "^even:", parsed.Match) +} diff --git a/internal/protocol/errors.go b/internal/protocol/errors.go index 7307638c..1854c5e6 100644 --- a/internal/protocol/errors.go +++ b/internal/protocol/errors.go @@ -23,6 +23,8 @@ import ( "github.com/tidwall/redcon" ) +var ErrInvalidArgument = errors.New("invalid argument") + var GenericError = "ERR" var errorWithPrefix = struct { diff --git a/internal/protocol/pubsub.go b/internal/protocol/pubsub.go index ce94d848..d5bcb4de 100644 --- a/internal/protocol/pubsub.go +++ b/internal/protocol/pubsub.go @@ -17,7 +17,9 @@ package protocol import ( "context" + "github.com/buraksezer/olric/internal/util" "github.com/go-redis/redis/v8" + "github.com/tidwall/redcon" ) type Publish struct { @@ -40,6 +42,48 @@ func (p *Publish) Command(ctx context.Context) *redis.IntCmd { return redis.NewIntCmd(ctx, args...) } +func ParsePublishCommand(cmd redcon.Command) (*Publish, error) { + if len(cmd.Args) < 3 { + return nil, errWrongNumber(cmd.Args) + } + + return NewPublish( + util.BytesToString(cmd.Args[1]), // Channel + util.BytesToString(cmd.Args[2]), // Message + ), nil +} + +type PublishInternal struct { + Channel string + Message string +} + +func NewPublishInternal(channel, message string) *PublishInternal { + return &PublishInternal{ + Channel: channel, + Message: message, + } +} + +func (p *PublishInternal) Command(ctx context.Context) *redis.IntCmd { + var args []interface{} + args = append(args, PubSub.PublishInternal) + args = append(args, p.Channel) + args = append(args, p.Message) + return redis.NewIntCmd(ctx, args...) +} + +func ParsePublishInternalCommand(cmd redcon.Command) (*PublishInternal, error) { + if len(cmd.Args) < 3 { + return nil, errWrongNumber(cmd.Args) + } + + return NewPublishInternal( + util.BytesToString(cmd.Args[1]), // Channel + util.BytesToString(cmd.Args[2]), // Message + ), nil +} + type Subscribe struct { Channels []string } @@ -59,6 +103,21 @@ func (s *Subscribe) Command(ctx context.Context) *redis.SliceCmd { return redis.NewSliceCmd(ctx, args...) } +func ParseSubscribeCommand(cmd redcon.Command) (*Subscribe, error) { + if len(cmd.Args) < 2 { + return nil, errWrongNumber(cmd.Args) + } + + var channels []string + args := cmd.Args[1:] + for len(args) > 0 { + arg := util.BytesToString(args[0]) + channels = append(channels, arg) + args = args[1:] + } + return NewSubscribe(channels...), nil +} + type PSubscribe struct { Patterns []string } @@ -78,6 +137,21 @@ func (s *PSubscribe) Command(ctx context.Context) *redis.SliceCmd { return redis.NewSliceCmd(ctx, args...) } +func ParsePSubscribeCommand(cmd redcon.Command) (*PSubscribe, error) { + if len(cmd.Args) < 2 { + return nil, errWrongNumber(cmd.Args) + } + + var patterns []string + args := cmd.Args[1:] + for len(args) > 0 { + arg := util.BytesToString(args[0]) + patterns = append(patterns, arg) + args = args[1:] + } + return NewPSubscribe(patterns...), nil +} + type PubSubChannels struct { Pattern string } @@ -100,6 +174,18 @@ func (ps *PubSubChannels) Command(ctx context.Context) *redis.SliceCmd { return redis.NewSliceCmd(ctx, args...) } +func ParsePubSubChannelsCommand(cmd redcon.Command) (*PubSubChannels, error) { + if len(cmd.Args) < 2 { + return nil, errWrongNumber(cmd.Args) + } + + ps := NewPubSubChannels() + if len(cmd.Args) >= 3 { + ps.SetPattern(util.BytesToString(cmd.Args[2])) + } + return ps, nil +} + type PubSubNumpat struct{} func NewPubSubNumpat() *PubSubNumpat { @@ -108,10 +194,18 @@ func NewPubSubNumpat() *PubSubNumpat { func (ps *PubSubNumpat) Command(ctx context.Context) *redis.IntCmd { var args []interface{} - args = append(args, PubSub.PubSubChannels) + args = append(args, PubSub.PubSubNumpat) return redis.NewIntCmd(ctx, args...) } +func ParsePubSubNumpatCommand(cmd redcon.Command) (*PubSubNumpat, error) { + if len(cmd.Args) < 2 { + return nil, errWrongNumber(cmd.Args) + } + + return NewPubSubNumpat(), nil +} + type PubSubNumsub struct { Channels []string } @@ -130,3 +224,18 @@ func (ps *PubSubNumsub) Command(ctx context.Context) *redis.SliceCmd { } return redis.NewSliceCmd(ctx, args...) } + +func ParsePubSubNumsubCommand(cmd redcon.Command) (*PubSubNumsub, error) { + if len(cmd.Args) < 2 { + return nil, errWrongNumber(cmd.Args) + } + + var channels []string + args := cmd.Args[2:] + for len(args) > 0 { + arg := util.BytesToString(args[0]) + channels = append(channels, arg) + args = args[1:] + } + return NewPubSubNumsub(channels...), nil +} diff --git a/internal/protocol/pubsub_parser.go b/internal/protocol/pubsub_parser.go deleted file mode 100644 index bb9cc909..00000000 --- a/internal/protocol/pubsub_parser.go +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2018-2022 Burak Sezer -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package protocol - -import ( - "github.com/buraksezer/olric/internal/util" - "github.com/tidwall/redcon" -) - -func ParsePublishCommand(cmd redcon.Command) (*Publish, error) { - if len(cmd.Args) < 3 { - return nil, errWrongNumber(cmd.Args) - } - - return NewPublish( - util.BytesToString(cmd.Args[1]), // Channel - util.BytesToString(cmd.Args[2]), // Message - ), nil -} - -func ParseSubscribeCommand(cmd redcon.Command) (*Subscribe, error) { - if len(cmd.Args) < 2 { - return nil, errWrongNumber(cmd.Args) - } - - var channels []string - args := cmd.Args[1:] - for len(args) > 0 { - arg := util.BytesToString(args[0]) - channels = append(channels, arg) - args = args[1:] - } - return NewSubscribe(channels...), nil -} - -func ParsePSubscribeCommand(cmd redcon.Command) (*PSubscribe, error) { - if len(cmd.Args) < 2 { - return nil, errWrongNumber(cmd.Args) - } - - var patterns []string - args := cmd.Args[1:] - for len(args) > 0 { - arg := util.BytesToString(args[0]) - patterns = append(patterns, arg) - args = args[1:] - } - return NewPSubscribe(patterns...), nil -} - -func ParsePubSubChannelsCommand(cmd redcon.Command) (*PubSubChannels, error) { - if len(cmd.Args) < 2 { - return nil, errWrongNumber(cmd.Args) - } - - ps := NewPubSubChannels() - if len(cmd.Args) >= 3 { - ps.SetPattern(util.BytesToString(cmd.Args[2])) - } - return ps, nil -} - -func ParsePubSubNumpatCommand(cmd redcon.Command) (*PubSubNumpat, error) { - if len(cmd.Args) < 2 { - return nil, errWrongNumber(cmd.Args) - } - - return NewPubSubNumpat(), nil -} - -func ParsePubSubNumsubCommand(cmd redcon.Command) (*PubSubNumsub, error) { - if len(cmd.Args) < 2 { - return nil, errWrongNumber(cmd.Args) - } - - var channels []string - args := cmd.Args[2:] - for len(args) > 0 { - arg := util.BytesToString(args[0]) - channels = append(channels, arg) - args = args[1:] - } - return NewPubSubNumsub(channels...), nil -} diff --git a/internal/protocol/pubsub_parser_test.go b/internal/protocol/pubsub_parser_test.go deleted file mode 100644 index 4e4025ee..00000000 --- a/internal/protocol/pubsub_parser_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package protocol - -import ( - "context" - "github.com/stretchr/testify/require" - "testing" -) - -func TestProtocol_ParsePublishCommand(t *testing.T) { - publishCmd := NewPublish("my-pubsub", "my-message") - - cmd := stringToCommand(publishCmd.Command(context.Background()).String()) - parsed, err := ParsePublishCommand(cmd) - require.NoError(t, err) - - require.Equal(t, "my-pubsub", parsed.Channel) - require.Equal(t, "my-message", parsed.Message) -} - -func TestProtocol_ParseSubscribeCommand(t *testing.T) { - subscribeCmd := NewSubscribe("channel-1", "channel-2", "channel-3") - - cmd := stringToCommand(subscribeCmd.Command(context.Background()).String()) - parsed, err := ParseSubscribeCommand(cmd) - require.NoError(t, err) - - channels := []string{"channel-1", "channel-2", "channel-3"} - require.Equal(t, channels, parsed.Channels) -} diff --git a/internal/protocol/pubsub_test.go b/internal/protocol/pubsub_test.go new file mode 100644 index 00000000..4e89f284 --- /dev/null +++ b/internal/protocol/pubsub_test.go @@ -0,0 +1,90 @@ +package protocol + +import ( + "context" + "github.com/stretchr/testify/require" + "testing" +) + +func TestProtocol_ParsePublishCommand(t *testing.T) { + publishCmd := NewPublish("my-pubsub", "my-message") + + cmd := stringToCommand(publishCmd.Command(context.Background()).String()) + parsed, err := ParsePublishCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-pubsub", parsed.Channel) + require.Equal(t, "my-message", parsed.Message) +} + +func TestProtocol_ParsePublishInternalCommand(t *testing.T) { + publishIntCmd := NewPublishInternal("my-pubsub", "my-message") + + cmd := stringToCommand(publishIntCmd.Command(context.Background()).String()) + parsed, err := ParsePublishInternalCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "my-pubsub", parsed.Channel) + require.Equal(t, "my-message", parsed.Message) +} + +func TestProtocol_ParseSubscribeCommand(t *testing.T) { + subscribeCmd := NewSubscribe("channel-1", "channel-2", "channel-3") + + cmd := stringToCommand(subscribeCmd.Command(context.Background()).String()) + parsed, err := ParseSubscribeCommand(cmd) + require.NoError(t, err) + + channels := []string{"channel-1", "channel-2", "channel-3"} + require.Equal(t, channels, parsed.Channels) +} + +func TestProtocol_ParsePSubscribeCommand(t *testing.T) { + psubscribeCmd := NewPSubscribe("ch?nnel-*") + + cmd := stringToCommand(psubscribeCmd.Command(context.Background()).String()) + parsed, err := ParsePSubscribeCommand(cmd) + require.NoError(t, err) + + patterns := []string{"ch?nnel-*"} + require.Equal(t, patterns, parsed.Patterns) +} + +func TestProtocol_PubSubChannels(t *testing.T) { + pubsubChannelsCmd := NewPubSubChannels() + + cmd := stringToCommand(pubsubChannelsCmd.Command(context.Background()).String()) + parsed, err := ParsePubSubChannelsCommand(cmd) + require.NoError(t, err) + require.Empty(t, parsed.Pattern) +} + +func TestProtocol_PubSubChannels_Patterns(t *testing.T) { + pubsubChannelsCmd := NewPubSubChannels() + pubsubChannelsCmd.SetPattern("ch?nnel-*") + + cmd := stringToCommand(pubsubChannelsCmd.Command(context.Background()).String()) + parsed, err := ParsePubSubChannelsCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "ch?nnel-*", parsed.Pattern) +} + +func TestProtocol_PubSubNumpat(t *testing.T) { + pubsubNumpatCmd := NewPubSubNumpat() + + cmd := stringToCommand(pubsubNumpatCmd.Command(context.Background()).String()) + _, err := ParsePubSubNumpatCommand(cmd) + require.NoError(t, err) +} + +func TestProtocol_PubSubNumsub(t *testing.T) { + pubsubNumsubCmd := NewPubSubNumsub("channel-1", "channel-2", "channel-3") + + cmd := stringToCommand(pubsubNumsubCmd.Command(context.Background()).String()) + parsed, err := ParsePubSubNumsubCommand(cmd) + require.NoError(t, err) + + channels := []string{"channel-1", "channel-2", "channel-3"} + require.Equal(t, channels, parsed.Channels) +} diff --git a/internal/protocol/system.go b/internal/protocol/system.go index c3d6b711..4f6a56a4 100644 --- a/internal/protocol/system.go +++ b/internal/protocol/system.go @@ -16,8 +16,12 @@ package protocol import ( "context" + "fmt" + "strconv" + "github.com/buraksezer/olric/internal/util" "github.com/go-redis/redis/v8" + "github.com/tidwall/redcon" ) type Ping struct { @@ -42,6 +46,18 @@ func (p *Ping) Command(ctx context.Context) *redis.StringCmd { return redis.NewStringCmd(ctx, args...) } +func ParsePingCommand(cmd redcon.Command) (*Ping, error) { + if len(cmd.Args) < 1 { + return nil, errWrongNumber(cmd.Args) + } + + p := NewPing() + if len(cmd.Args) == 2 { + p.SetMessage(util.BytesToString(cmd.Args[1])) + } + return p, nil +} + type MoveFragment struct { Payload []byte } @@ -59,6 +75,14 @@ func (m *MoveFragment) Command(ctx context.Context) *redis.StatusCmd { return redis.NewStatusCmd(ctx, args...) } +func ParseMoveFragmentCommand(cmd redcon.Command) (*MoveFragment, error) { + if len(cmd.Args) < 2 { + return nil, errWrongNumber(cmd.Args) + } + + return NewMoveFragment(cmd.Args[1]), nil +} + type UpdateRouting struct { Payload []byte CoordinatorID uint64 @@ -79,6 +103,18 @@ func (u *UpdateRouting) Command(ctx context.Context) *redis.StringCmd { return redis.NewStringCmd(ctx, args...) } +func ParseUpdateRoutingCommand(cmd redcon.Command) (*UpdateRouting, error) { + if len(cmd.Args) < 2 { + return nil, errWrongNumber(cmd.Args) + } + coordinatorID, err := strconv.ParseUint(util.BytesToString(cmd.Args[2]), 10, 64) + if err != nil { + return nil, err + } + + return NewUpdateRouting(cmd.Args[1], coordinatorID), nil +} + type LengthOfPart struct { PartID uint64 Replica bool @@ -105,6 +141,28 @@ func (l *LengthOfPart) Command(ctx context.Context) *redis.IntCmd { return redis.NewIntCmd(ctx, args...) } +func ParseLengthOfPartCommand(cmd redcon.Command) (*LengthOfPart, error) { + if len(cmd.Args) < 2 { + return nil, errWrongNumber(cmd.Args) + } + partID, err := strconv.ParseUint(util.BytesToString(cmd.Args[1]), 10, 64) + if err != nil { + return nil, err + } + + l := NewLengthOfPart(partID) + if len(cmd.Args) == 3 { + arg := util.BytesToString(cmd.Args[2]) + if arg == "RC" { + l.SetReplica() + } else { + return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) + } + } + + return l, nil +} + type Stats struct { CollectRuntime bool } @@ -126,3 +184,21 @@ func (s *Stats) Command(ctx context.Context) *redis.StringCmd { } return redis.NewStringCmd(ctx, args...) } + +func ParseStatsCommand(cmd redcon.Command) (*Stats, error) { + if len(cmd.Args) < 1 { + return nil, errWrongNumber(cmd.Args) + } + + s := NewStats() + if len(cmd.Args) == 2 { + arg := util.BytesToString(cmd.Args[1]) + if arg == "CR" { + s.SetCollectRuntime() + } else { + return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) + } + } + + return s, nil +} diff --git a/internal/protocol/system_parser.go b/internal/protocol/system_parser.go deleted file mode 100644 index 26d5e834..00000000 --- a/internal/protocol/system_parser.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2018-2022 Burak Sezer -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package protocol - -import ( - "fmt" - "strconv" - - "github.com/buraksezer/olric/internal/util" - "github.com/tidwall/redcon" -) - -func ParsePingCommand(cmd redcon.Command) (*Ping, error) { - if len(cmd.Args) < 1 { - return nil, errWrongNumber(cmd.Args) - } - - p := NewPing() - if len(cmd.Args) == 2 { - p.SetMessage(util.BytesToString(cmd.Args[1])) - } - return p, nil -} - -func ParseMoveFragmentCommand(cmd redcon.Command) (*MoveFragment, error) { - if len(cmd.Args) < 2 { - return nil, errWrongNumber(cmd.Args) - } - - return NewMoveFragment(cmd.Args[1]), nil -} - -func ParseUpdateRoutingCommand(cmd redcon.Command) (*UpdateRouting, error) { - if len(cmd.Args) < 2 { - return nil, errWrongNumber(cmd.Args) - } - coordinatorID, err := strconv.ParseUint(util.BytesToString(cmd.Args[2]), 10, 64) - if err != nil { - return nil, err - } - - return NewUpdateRouting(cmd.Args[1], coordinatorID), nil -} - -func ParseLengthOfPartCommand(cmd redcon.Command) (*LengthOfPart, error) { - if len(cmd.Args) < 2 { - return nil, errWrongNumber(cmd.Args) - } - partID, err := strconv.ParseUint(util.BytesToString(cmd.Args[1]), 10, 64) - if err != nil { - return nil, err - } - - l := NewLengthOfPart(partID) - if len(cmd.Args) == 3 { - arg := util.BytesToString(cmd.Args[2]) - if arg == "RC" { - l.SetReplica() - } else { - return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) - } - } - - return l, nil -} - -func ParseStatsCommand(cmd redcon.Command) (*Stats, error) { - if len(cmd.Args) < 1 { - return nil, errWrongNumber(cmd.Args) - } - - s := NewStats() - if len(cmd.Args) == 2 { - arg := util.BytesToString(cmd.Args[1]) - if arg == "CR" { - s.SetCollectRuntime() - } else { - return nil, fmt.Errorf("%w: %s", ErrInvalidArgument, arg) - } - } - - return s, nil -} diff --git a/internal/protocol/system_test.go b/internal/protocol/system_test.go new file mode 100644 index 00000000..fbbdec74 --- /dev/null +++ b/internal/protocol/system_test.go @@ -0,0 +1,107 @@ +// Copyright 2018-2022 Burak Sezer +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package protocol + +import ( + "context" + "github.com/stretchr/testify/require" + "testing" +) + +func TestProtocol_Ping(t *testing.T) { + ping := NewPing() + + cmd := stringToCommand(ping.Command(context.Background()).String()) + parsed, err := ParsePingCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "", parsed.Message) +} + +func TestProtocol_Ping_Message(t *testing.T) { + ping := NewPing() + ping.SetMessage("message") + + cmd := stringToCommand(ping.Command(context.Background()).String()) + parsed, err := ParsePingCommand(cmd) + require.NoError(t, err) + + require.Equal(t, "message", parsed.Message) +} + +func TestProtocol_MoveFragment(t *testing.T) { + moveFragmentCmd := NewMoveFragment([]byte("payload")) + + cmd := stringToCommand(moveFragmentCmd.Command(context.Background()).String()) + parsed, err := ParseMoveFragmentCommand(cmd) + require.NoError(t, err) + + require.Equal(t, []byte("payload"), parsed.Payload) +} + +func TestProtocol_UpdateRoutingTable(t *testing.T) { + updateRoutingTableCmd := NewUpdateRouting([]byte("payload"), 123) + + cmd := stringToCommand(updateRoutingTableCmd.Command(context.Background()).String()) + parsed, err := ParseUpdateRoutingCommand(cmd) + require.NoError(t, err) + + require.Equal(t, []byte("payload"), parsed.Payload) + require.Equal(t, uint64(123), parsed.CoordinatorID) +} + +func TestProtocol_LengthOfPart(t *testing.T) { + updateRoutingTableCmd := NewLengthOfPart(123) + + cmd := stringToCommand(updateRoutingTableCmd.Command(context.Background()).String()) + parsed, err := ParseLengthOfPartCommand(cmd) + require.NoError(t, err) + + require.Equal(t, uint64(123), parsed.PartID) + require.False(t, parsed.Replica) +} + +func TestProtocol_LengthOfPart_RC(t *testing.T) { + updateRoutingTableCmd := NewLengthOfPart(123) + updateRoutingTableCmd.SetReplica() + + cmd := stringToCommand(updateRoutingTableCmd.Command(context.Background()).String()) + parsed, err := ParseLengthOfPartCommand(cmd) + require.NoError(t, err) + + require.Equal(t, uint64(123), parsed.PartID) + require.True(t, parsed.Replica) +} + +func TestProtocol_Stats(t *testing.T) { + statsCmd := NewStats() + + cmd := stringToCommand(statsCmd.Command(context.Background()).String()) + parsed, err := ParseStatsCommand(cmd) + require.NoError(t, err) + + require.False(t, parsed.CollectRuntime) +} + +func TestProtocol_Stats_CR(t *testing.T) { + statsCmd := NewStats() + statsCmd.SetCollectRuntime() + + cmd := stringToCommand(statsCmd.Command(context.Background()).String()) + parsed, err := ParseStatsCommand(cmd) + require.NoError(t, err) + + require.True(t, parsed.CollectRuntime) +} diff --git a/internal/pubsub/handlers.go b/internal/pubsub/handlers.go index 9e4419df..8e1b444d 100644 --- a/internal/pubsub/handlers.go +++ b/internal/pubsub/handlers.go @@ -39,8 +39,43 @@ func (s *Service) publishCommandHandler(conn redcon.Conn, cmd redcon.Command) { protocol.WriteError(conn, err) return } - count := s.pubsub.Publish(publishCmd.Channel, publishCmd.Message) - PublishedTotal.Increase(int64(count)) + + var total int64 + members := s.rt.Discovery().GetMembers() + for _, member := range members { + if member.CompareByID(s.rt.This()) { + count := s.pubsub.Publish(publishCmd.Channel, publishCmd.Message) + total += int64(count) + PublishedTotal.Increase(int64(count)) + continue + } + + pi := protocol.NewPublishInternal(publishCmd.Channel, publishCmd.Message).Command(s.ctx) + rc := s.client.Get(member.String()) + err = rc.Process(s.ctx, pi) + if err != nil { + protocol.WriteError(conn, err) + return + } + pcount, err := pi.Result() + if err != nil { + protocol.WriteError(conn, err) + return + } + total += pcount + PublishedTotal.Increase(pcount) + } + + conn.WriteInt64(total) +} + +func (s *Service) publishInternalCommandHandler(conn redcon.Conn, cmd redcon.Command) { + publishInternalCmd, err := protocol.ParsePublishInternalCommand(cmd) + if err != nil { + protocol.WriteError(conn, err) + return + } + count := s.pubsub.Publish(publishInternalCmd.Channel, publishInternalCmd.Message) conn.WriteInt(count) } diff --git a/internal/pubsub/handlers_test.go b/internal/pubsub/handlers_test.go index 23ae46eb..697e9a4f 100644 --- a/internal/pubsub/handlers_test.go +++ b/internal/pubsub/handlers_test.go @@ -341,3 +341,54 @@ func TestPubSub_Handler_PubSubNumsub(t *testing.T) { require.Equal(t, int64(1), nr["foobar"]) require.Equal(t, int64(1), nr["barfoo"]) } + +func TestPubSub_Cluster(t *testing.T) { + cluster := testcluster.New(NewService) + s1 := cluster.AddMember(nil).(*Service) + s2 := cluster.AddMember(nil).(*Service) + defer cluster.Shutdown() + + rc1 := s1.client.Get(s1.rt.This().String()) + ctx := context.Background() + ps := rc1.Subscribe(ctx, "my-channel") + + // Wait for confirmation that subscription is created before publishing anything. + msgi, err := ps.ReceiveTimeout(ctx, time.Second) + require.NoError(t, err) + + subs := msgi.(*redis.Subscription) + require.Equal(t, "subscribe", subs.Kind) + require.Equal(t, "my-channel", subs.Channel) + require.Equal(t, 1, subs.Count) + + // Go channel which receives messages. + ch := ps.Channel() + + rc2 := s2.client.Get(s2.rt.This().String()) + expected := make(map[string]struct{}) + for i := 0; i < 10; i++ { + msg := fmt.Sprintf("my-message-%d", i) + err = rc2.Publish(ctx, "my-channel", msg).Err() + require.NoError(t, err) + expected[msg] = struct{}{} + } + + consumed := make(map[string]struct{}) +L: + for { + select { + case msg := <-ch: + require.Equal(t, "my-channel", msg.Channel) + consumed[msg.Payload] = struct{}{} + if len(consumed) == 10 { + // It would be OK + break L + } + case <-time.After(5 * time.Second): + // Enough. Break it and check the consumed items. + break L + } + } + + require.Equal(t, expected, consumed) +} diff --git a/internal/pubsub/service.go b/internal/pubsub/service.go index def325f8..bb2c7092 100644 --- a/internal/pubsub/service.go +++ b/internal/pubsub/service.go @@ -58,6 +58,7 @@ func (s *Service) RegisterHandlers() { s.server.ServeMux().HandleFunc(protocol.PubSub.Subscribe, s.subscribeCommandHandler) s.server.ServeMux().HandleFunc(protocol.PubSub.PSubscribe, s.psubscribeCommandHandler) s.server.ServeMux().HandleFunc(protocol.PubSub.Publish, s.publishCommandHandler) + s.server.ServeMux().HandleFunc(protocol.PubSub.PublishInternal, s.publishInternalCommandHandler) s.server.ServeMux().HandleFunc(protocol.PubSub.PubSubChannels, s.pubsubChannelsCommandHandler) s.server.ServeMux().HandleFunc(protocol.PubSub.PubSubNumpat, s.pubsubNumpatCommandHandler) s.server.ServeMux().HandleFunc(protocol.PubSub.PubSubNumsub, s.pubsubNumsubCommandHandler) diff --git a/internal/encoding/encoder.go b/internal/resp/encoder.go similarity index 99% rename from internal/encoding/encoder.go rename to internal/resp/encoder.go index eb4dd062..ac8c3ead 100644 --- a/internal/encoding/encoder.go +++ b/internal/resp/encoder.go @@ -1,4 +1,4 @@ -package encoding +package resp import ( "encoding" diff --git a/internal/resp/encoder_test.go b/internal/resp/encoder_test.go new file mode 100644 index 00000000..0235dd2d --- /dev/null +++ b/internal/resp/encoder_test.go @@ -0,0 +1,286 @@ +package resp + +import ( + "bytes" + "encoding" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type MyType struct{} + +var _ encoding.BinaryMarshaler = (*MyType)(nil) + +func (t *MyType) MarshalBinary() ([]byte, error) { + return []byte("hello"), nil +} + +func (t *MyType) UnmarshalBinary(data []byte) error { + if !bytes.Equal([]byte("hello"), data) { + return fmt.Errorf("not equal") + } + return nil +} + +func TestWriter_WriteArg(t *testing.T) { + buf := bytes.NewBuffer(nil) + w := New(buf) + + t.Run("uint64", func(t *testing.T) { + defer buf.Reset() + value := uint64(345353) + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(uint64) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, uint64(345353), *scannedValue) + }) + + t.Run("nil", func(t *testing.T) { + defer buf.Reset() + + err := w.Encode(nil) + require.NoError(t, err) + + scannedValue := new(string) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, "", *scannedValue) + }) + + t.Run("string", func(t *testing.T) { + defer buf.Reset() + + err := w.Encode("foobar") + require.NoError(t, err) + + scannedValue := new(string) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, "foobar", *scannedValue) + }) + + t.Run("byte slice", func(t *testing.T) { + defer buf.Reset() + + err := w.Encode([]byte("foobar")) + require.NoError(t, err) + + scannedValue := new([]byte) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, []byte("foobar"), *scannedValue) + }) + + t.Run("int", func(t *testing.T) { + defer buf.Reset() + + value := 345353 + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(int) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, 345353, *scannedValue) + }) + + t.Run("int8", func(t *testing.T) { + defer buf.Reset() + + value := int8(2) + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(int8) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, int8(2), *scannedValue) + }) + + t.Run("int16", func(t *testing.T) { + defer buf.Reset() + + value := int16(2) + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(int16) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, int16(2), *scannedValue) + }) + + t.Run("int32", func(t *testing.T) { + defer buf.Reset() + + value := int32(2) + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(int32) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, int32(2), *scannedValue) + }) + + t.Run("int64", func(t *testing.T) { + defer buf.Reset() + + value := int64(2) + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(int64) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, int64(2), *scannedValue) + }) + + t.Run("uint", func(t *testing.T) { + defer buf.Reset() + + value := uint(2) + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(uint) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, uint(2), *scannedValue) + }) + + t.Run("uint8", func(t *testing.T) { + defer buf.Reset() + + value := uint8(2) + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(uint8) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, uint8(2), *scannedValue) + }) + + t.Run("uint16", func(t *testing.T) { + defer buf.Reset() + + value := uint16(2) + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(uint16) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, uint16(2), *scannedValue) + }) + + t.Run("uint32", func(t *testing.T) { + defer buf.Reset() + + value := uint32(2) + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(uint32) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, uint32(2), *scannedValue) + }) + + t.Run("uint64", func(t *testing.T) { + defer buf.Reset() + + value := uint64(2) + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(uint64) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, uint64(2), *scannedValue) + }) + + t.Run("float32", func(t *testing.T) { + defer buf.Reset() + + value := float32(2) + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(float32) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, float32(2), *scannedValue) + }) + + t.Run("float64", func(t *testing.T) { + defer buf.Reset() + + value := float64(2) + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(float64) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, float64(2), *scannedValue) + }) + + t.Run("bool", func(t *testing.T) { + defer buf.Reset() + + value := true + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(bool) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, true, *scannedValue) + }) + + t.Run("time.Time", func(t *testing.T) { + defer buf.Reset() + + value := time.Now() + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(time.Time) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + }) + + t.Run("time.Duration", func(t *testing.T) { + defer buf.Reset() + + value := time.Second + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(time.Duration) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, time.Second, *scannedValue) + }) + + t.Run("encoding.BinaryMarshaler", func(t *testing.T) { + defer buf.Reset() + + var value encoding.BinaryMarshaler = &MyType{} + err := w.Encode(value) + require.NoError(t, err) + + scannedValue := new(MyType) + err = Scan(buf.Bytes(), scannedValue) + require.NoError(t, err) + require.Equal(t, MyType{}, *scannedValue) + }) +} diff --git a/internal/encoding/scan.go b/internal/resp/scan.go similarity index 61% rename from internal/encoding/scan.go rename to internal/resp/scan.go index 493031b3..38ca146e 100644 --- a/internal/encoding/scan.go +++ b/internal/resp/scan.go @@ -1,9 +1,8 @@ -package encoding +package resp import ( "encoding" "fmt" - "reflect" "time" "github.com/buraksezer/olric/internal/util" @@ -120,61 +119,3 @@ func Scan(b []byte, v interface{}) error { "olric: can't unmarshal %T (consider implementing BinaryUnmarshaler)", v) } } - -func ScanSlice(data []string, slice interface{}) error { - v := reflect.ValueOf(slice) - if !v.IsValid() { - return fmt.Errorf("olric: ScanSlice(nil)") - } - if v.Kind() != reflect.Ptr { - return fmt.Errorf("olric: ScanSlice(non-pointer %T)", slice) - } - v = v.Elem() - if v.Kind() != reflect.Slice { - return fmt.Errorf("olric: ScanSlice(non-slice %T)", slice) - } - - next := makeSliceNextElemFunc(v) - for i, s := range data { - elem := next() - if err := Scan([]byte(s), elem.Addr().Interface()); err != nil { - err = fmt.Errorf("olric: ScanSlice index=%d value=%q failed: %w", i, s, err) - return err - } - } - - return nil -} - -func makeSliceNextElemFunc(v reflect.Value) func() reflect.Value { - elemType := v.Type().Elem() - - if elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - return func() reflect.Value { - if v.Len() < v.Cap() { - v.Set(v.Slice(0, v.Len()+1)) - elem := v.Index(v.Len() - 1) - if elem.IsNil() { - elem.Set(reflect.New(elemType)) - } - return elem.Elem() - } - - elem := reflect.New(elemType) - v.Set(reflect.Append(v, elem)) - return elem.Elem() - } - } - - zero := reflect.Zero(elemType) - return func() reflect.Value { - if v.Len() < v.Cap() { - v.Set(v.Slice(0, v.Len()+1)) - return v.Index(v.Len() - 1) - } - - v.Set(reflect.Append(v, zero)) - return v.Index(v.Len() - 1) - } -} diff --git a/internal/roundrobin/round_robin.go b/internal/roundrobin/round_robin.go new file mode 100644 index 00000000..e88d928b --- /dev/null +++ b/internal/roundrobin/round_robin.go @@ -0,0 +1,74 @@ +// Copyright 2018-2022 Burak Sezer +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package roundrobin + +import ( + "errors" + "fmt" +) + +var ErrEmptyInstance = errors.New("empty round-robin instance") + +// RoundRobin implements quite simple round-robin scheduling algorithm to distribute load fairly between servers. +type RoundRobin struct { + current int + items []string +} + +// New returns a new RoundRobin instance. +func New(items []string) *RoundRobin { + return &RoundRobin{ + current: 0, + items: items, + } +} + +// Get returns an item. +func (r *RoundRobin) Get() (string, error) { + if r.current >= len(r.items) { + r.current %= len(r.items) + } + + if len(r.items) == 0 { + return "", ErrEmptyInstance + } + + if r.current >= len(r.items) { + return "", fmt.Errorf("round-robin: corrupted internal state") + } + + item := r.items[r.current] + r.current++ + return item, nil +} + +// Add adds a new item to the Round-Robin scheduler. +func (r *RoundRobin) Add(item string) { + r.items = append(r.items, item) +} + +// Delete deletes an item from the Round-Robin scheduler. +func (r *RoundRobin) Delete(i string) { + for idx, item := range r.items { + if item == i { + r.items = append(r.items[:idx], r.items[idx+1:]...) + } + } +} + +// Length returns the count of items +func (r *RoundRobin) Length() int { + return len(r.items) +} diff --git a/internal/roundrobin/round_robin_test.go b/internal/roundrobin/round_robin_test.go new file mode 100644 index 00000000..9c3becdc --- /dev/null +++ b/internal/roundrobin/round_robin_test.go @@ -0,0 +1,88 @@ +// Copyright 2018-2022 Burak Sezer +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package roundrobin + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRoundRobin(t *testing.T) { + items := []string{"127.0.0.1:2323", "127.0.0.1:4556", "127.0.0.1:7889"} + r := New(items) + + t.Run("Get", func(t *testing.T) { + items := make(map[string]int) + for i := 0; i < r.Length(); i++ { + item, err := r.Get() + require.NoError(t, err) + items[item]++ + } + if len(items) != r.Length() { + t.Fatalf("Expected item count: %d. Got: %d", r.Length(), len(items)) + } + }) + + t.Run("Add", func(t *testing.T) { + item := "127.0.0.1:3320" + r.Add(item) + items := make(map[string]int) + for i := 0; i < r.Length(); i++ { + item, err := r.Get() + require.NoError(t, err) + items[item]++ + } + if _, ok := items[item]; !ok { + t.Fatalf("Item not processed: %s", item) + } + if len(items) != r.Length() { + t.Fatalf("Expected item count: %d. Got: %d", r.Length(), len(items)) + } + }) + + t.Run("Delete", func(t *testing.T) { + item := "127.0.0.1:7889" + r.Delete(item) + + items := make(map[string]int) + for i := 0; i < r.Length(); i++ { + item, err := r.Get() + require.NoError(t, err) + items[item]++ + } + if _, ok := items[item]; ok { + t.Fatalf("Item stil exists: %s", item) + } + if len(items) != r.Length() { + t.Fatalf("Expected item count: %d. Got: %d", r.Length(), len(items)) + } + }) +} + +func TestRoundRobin_Delete_NonExistent(t *testing.T) { + items := []string{"127.0.0.1:2323", "127.0.0.1:4556", "127.0.0.1:7889"} + r := New(items) + + var fresh []string + fresh = append(fresh, items...) + for i, item := range fresh { + if i+1 == len(items) { + r.Delete(item) + } else { + r.Delete(item) + } + } +} diff --git a/internal/server/client.go b/internal/server/client.go index ae44c071..17a0a074 100644 --- a/internal/server/client.go +++ b/internal/server/client.go @@ -20,14 +20,16 @@ import ( "sync" "github.com/buraksezer/olric/config" + "github.com/buraksezer/olric/internal/roundrobin" "github.com/go-redis/redis/v8" ) type Client struct { mu sync.RWMutex - config *config.Client - clients map[string]*redis.Client + config *config.Client + clients map[string]*redis.Client + roundRobin *roundrobin.RoundRobin } func NewClient(c *config.Client) *Client { @@ -39,8 +41,9 @@ func NewClient(c *config.Client) *Client { } } return &Client{ - config: c, - clients: make(map[string]*redis.Client), + config: c, + clients: make(map[string]*redis.Client), + roundRobin: roundrobin.New(nil), } } @@ -60,6 +63,9 @@ func (c *Client) Get(addr string) *redis.Client { opt.Addr = addr rc = redis.NewClient(opt) c.clients[addr] = rc + c.roundRobin.Add(addr) + // TODO: Remove unhealthy redis client periodically. + // TODO: Send a pig command after calling NewClient. return rc } @@ -67,10 +73,18 @@ func (c *Client) Pick() (*redis.Client, error) { c.mu.RLock() defer c.mu.RUnlock() - for _, rc := range c.clients { - return rc, nil + addr, err := c.roundRobin.Get() + if err == roundrobin.ErrEmptyInstance { + return nil, fmt.Errorf("no available client found") } - return nil, fmt.Errorf("no available client found") + if err != nil { + return nil, err + } + rc, ok := c.clients[addr] + if !ok { + return nil, fmt.Errorf("client could not be found: %s", addr) + } + return rc, nil } func (c *Client) Close(addr string) error { @@ -83,6 +97,7 @@ func (c *Client) Close(addr string) error { if err != nil { return err } + c.roundRobin.Delete(addr) delete(c.clients, addr) } @@ -104,6 +119,7 @@ func (c *Client) Shutdown(ctx context.Context) error { return err } delete(c.clients, addr) + c.roundRobin.Delete(addr) } return nil diff --git a/internal/server/client_test.go b/internal/server/client_test.go index 78a6cd95..13f20bae 100644 --- a/internal/server/client_test.go +++ b/internal/server/client_test.go @@ -27,18 +27,14 @@ import ( ) func TestServer_Client_Get(t *testing.T) { - s := newServer(t) - defer func() { - require.NoError(t, s.Shutdown(context.Background())) - }() - - s.ServeMux().HandleFunc(protocol.Generic.Ping, func(conn redcon.Conn, cmd redcon.Command) { + srv := newServer(t) + srv.ServeMux().HandleFunc(protocol.Generic.Ping, func(conn redcon.Conn, cmd redcon.Command) { conn.WriteBulkString("pong") }) - <-s.StartedCtx.Done() + <-srv.StartedCtx.Done() - addr := net.JoinHostPort(s.config.BindAddr, strconv.Itoa(s.config.BindPort)) + addr := net.JoinHostPort(srv.config.BindAddr, strconv.Itoa(srv.config.BindPort)) c := config.NewClient() require.NoError(t, c.Sanitize()) @@ -53,24 +49,94 @@ func TestServer_Client_Get(t *testing.T) { result, err := cmd.Result() require.NoError(t, err) require.Equal(t, "pong", result) + + t.Run("Fetch cached client", func(t *testing.T) { + newClient := cs.Get(addr) + require.Equal(t, rc, newClient) + }) +} + +func TestServer_Client_Pick(t *testing.T) { + servers := make(map[string]*Server) + for i := 0; i < 10; i++ { + srv := newServer(t) + srv.ServeMux().HandleFunc(protocol.Generic.Ping, func(conn redcon.Conn, cmd redcon.Command) { + conn.WriteBulkString("pong") + }) + addr := net.JoinHostPort(srv.config.BindAddr, strconv.Itoa(srv.config.BindPort)) + servers[addr] = srv + } + + c := config.NewClient() + require.NoError(t, c.Sanitize()) + + cs := NewClient(c) + + for addr, srv := range servers { + <-srv.StartedCtx.Done() + cs.Get(addr) + } + // All the servers have been started. + + clients := make(map[string]struct{}) + for i := 0; i < 100; i++ { + rc, err := cs.Pick() + require.NoError(t, err) + + ctx := context.Background() + cmd := protocol.NewPing().Command(ctx) + err = rc.Process(ctx, cmd) + require.NoError(t, err) + + result, err := cmd.Result() + require.NoError(t, err) + require.Equal(t, "pong", result) + clients[rc.String()] = struct{}{} + } + require.Greater(t, len(clients), 1) } func TestServer_Client_Close(t *testing.T) { - s := newServer(t) - defer func() { - require.NoError(t, s.Shutdown(context.Background())) - }() + srv := newServer(t) - <-s.StartedCtx.Done() + <-srv.StartedCtx.Done() c := config.NewClient() require.NoError(t, c.Sanitize()) - addr := net.JoinHostPort(s.config.BindAddr, strconv.Itoa(s.config.BindPort)) + addr := net.JoinHostPort(srv.config.BindAddr, strconv.Itoa(srv.config.BindPort)) cs := NewClient(c) rc1 := cs.Get(addr) require.NoError(t, cs.Close(addr)) rc2 := cs.Get(addr) require.NotEqual(t, rc1, rc2) + require.Equal(t, 1, cs.roundRobin.Length()) +} + +func TestServer_Client_Shutdown(t *testing.T) { + servers := make(map[string]*Server) + for i := 0; i < 10; i++ { + srv := newServer(t) + srv.ServeMux().HandleFunc(protocol.Generic.Ping, func(conn redcon.Conn, cmd redcon.Command) { + conn.WriteBulkString("pong") + }) + addr := net.JoinHostPort(srv.config.BindAddr, strconv.Itoa(srv.config.BindPort)) + servers[addr] = srv + } + + c := config.NewClient() + require.NoError(t, c.Sanitize()) + + cs := NewClient(c) + + for addr, srv := range servers { + <-srv.StartedCtx.Done() + cs.Get(addr) + } + // All the servers have been started. + err := cs.Shutdown(context.Background()) + require.NoError(t, err) + require.Empty(t, cs.clients) + require.Equal(t, 0, cs.roundRobin.Length()) } diff --git a/internal/server/mux_test.go b/internal/server/mux_test.go new file mode 100644 index 00000000..ec3ecd1f --- /dev/null +++ b/internal/server/mux_test.go @@ -0,0 +1,54 @@ +// Copyright 2018-2022 Burak Sezer +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + "math/rand" + "testing" + + "github.com/buraksezer/olric/internal/protocol" + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/require" + "github.com/tidwall/redcon" +) + +func TestMux_PubSub_Command(t *testing.T) { + s := newServer(t) + + data := make([]byte, 8) + _, err := rand.Read(data) + require.NoError(t, err) + + s.ServeMux().HandleFunc(protocol.PubSub.PubSubNumpat, func(conn redcon.Conn, cmd redcon.Command) { + conn.WriteInt(10) + }) + + <-s.StartedCtx.Done() + + rdb := redis.NewClient(defaultRedisOptions(s.config)) + + ctx := context.Background() + var args []interface{} + args = append(args, "pubsub") + args = append(args, "numpat") + cmd := redis.NewIntCmd(ctx, args...) + err = rdb.Process(ctx, cmd) + require.NoError(t, err) + + num, err := cmd.Result() + require.NoError(t, err) + require.Equal(t, int64(10), num) +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go index f1698c13..e1269e3c 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -83,7 +83,11 @@ func newServerWithPreConditionFunc(t *testing.T, precond func(conn redcon.Conn, } func newServer(t *testing.T) *Server { - return newServerWithPreConditionFunc(t, nil) + srv := newServerWithPreConditionFunc(t, nil) + t.Cleanup(func() { + require.NoError(t, srv.Shutdown(context.Background())) + }) + return srv } func defaultRedisOptions(c *Config) *redis.Options { @@ -94,17 +98,13 @@ func defaultRedisOptions(c *Config) *redis.Options { func TestServer_RESP(t *testing.T) { s := newServer(t) - defer func() { - require.NoError(t, s.Shutdown(context.Background())) - }() + respEcho(t, s) } func TestServer_RESP_Stats(t *testing.T) { s := newServer(t) - defer func() { - require.NoError(t, s.Shutdown(context.Background())) - }() + respEcho(t, s) require.NotEqual(t, int64(0), CommandsTotal.Read()) diff --git a/olric.go b/olric.go index ba657df1..01773d3d 100644 --- a/olric.go +++ b/olric.go @@ -54,30 +54,14 @@ import ( ) // ReleaseVersion is the current stable version of Olric -const ReleaseVersion string = "0.5.0-alpha.1" +const ReleaseVersion string = "0.5.0-alpha.2" var ( // ErrOperationTimeout is returned when an operation times out. ErrOperationTimeout = errors.New("operation timeout") - // ErrInternalServerError means that something unintentionally went - // wrong while processing the request. - ErrInternalServerError = errors.New("internal server error") - - // ErrUnknownOperation means that an unidentified message has been - // received from a rc. - ErrUnknownOperation = errors.New("unknown operation") - // ErrServerGone means that a cluster member is closed unexpectedly. ErrServerGone = errors.New("server is gone") - - // ErrInvalidArgument means that an invalid parameter is sent by the - // rc or a cluster member. - ErrInvalidArgument = errors.New("invalid argument") - - // ErrNotImplemented means that the requested feature has not been implemented - // yet. - ErrNotImplemented = errors.New("not implemented") ) // Olric implements a distributed cache and in-memory key/value data store. @@ -95,7 +79,7 @@ type Olric struct { primary *partitions.Partitions backup *partitions.Partitions - // RESP experiment + // RESP server and clients. server *server.Server client *server.Client @@ -210,7 +194,7 @@ func New(c *config.Config) (*Olric, error) { cancel: cancel, } - // RESP experiment + // Create a Redcon server instance rc := &server.Config{ BindAddr: c.BindAddr, BindPort: c.BindPort, @@ -244,6 +228,7 @@ func (db *Olric) registerCommandHandlers() { db.server.ServeMux().HandleFunc(protocol.Generic.Ping, db.pingCommandHandler) db.server.ServeMux().HandleFunc(protocol.Cluster.RoutingTable, db.clusterRoutingTableCommandHandler) db.server.ServeMux().HandleFunc(protocol.Generic.Stats, db.statsCommandHandler) + db.server.ServeMux().HandleFunc(protocol.Cluster.Members, db.clusterMembersCommandHandler) } // callStartedCallback checks passed checkpoint count and calls the callback @@ -387,7 +372,7 @@ func (db *Olric) Shutdown(ctx context.Context) error { latestError = err } - // RESP experiment + // Shutdown Redcon server if err := db.server.Shutdown(ctx); err != nil { db.log.V(2).Printf("[ERROR] Failed to shutdown RESP server: %v", err) latestError = err diff --git a/pkg/storage/engine.go b/pkg/storage/engine.go index 944437ac..ead48ef1 100644 --- a/pkg/storage/engine.go +++ b/pkg/storage/engine.go @@ -15,9 +15,21 @@ package storage import ( + "errors" "log" ) +// ErrKeyTooLarge is an error that indicates the given key is large than the determined key size. +// The current maximum key length is 256. +var ErrKeyTooLarge = errors.New("key too large") + +// ErrKeyNotFound is an error that indicates that the requested key could not be found in the DB. +var ErrKeyNotFound = errors.New("key not found") + +// ErrNotImplemented means that the interface implementation does not support +// the functionality required to fulfill the request. +var ErrNotImplemented = errors.New("not implemented yet") + type TransferIterator interface { Next() bool diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go deleted file mode 100644 index 2ba75a4d..00000000 --- a/pkg/storage/storage.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2018-2022 Burak Sezer -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage // import "github.com/buraksezer/olric/pkg/storage" - -import ( - "errors" - "fmt" - "plugin" -) - -// ErrKeyTooLarge is an error that indicates the given key is large than the determined key size. -// The current maximum key length is 256. -var ErrKeyTooLarge = errors.New("key too large") - -// ErrKeyNotFound is an error that indicates that the requested key could not be found in the DB. -var ErrKeyNotFound = errors.New("key not found") - -// ErrNotImplemented means that the interface implementation does not support -// the functionality required to fulfill the request. -var ErrNotImplemented = errors.New("not implemented yet") - -func LoadAsPlugin(pluginPath string) (Engine, error) { - plug, err := plugin.Open(pluginPath) - if err != nil { - return nil, fmt.Errorf("failed to open plugin: %w", err) - } - tmp, err := plug.Lookup("StorageEngines") - if err != nil { - return nil, fmt.Errorf("failed to lookup StorageEngines symbol: %w", err) - } - impl, ok := tmp.(Engine) - if !ok { - return nil, fmt.Errorf("unable to assert type to StorageEngines") - } - return impl, nil -} diff --git a/pkg/storage/storage_test.go b/pkg/storage/storage_test.go deleted file mode 100644 index 4ff1087b..00000000 --- a/pkg/storage/storage_test.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2018-2022 Burak Sezer -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import "testing" - -func Test_LoadAsPlugin(t *testing.T) { - // TODO: Add some tests here -} diff --git a/pubsub.go b/pubsub.go new file mode 100644 index 00000000..e5e40176 --- /dev/null +++ b/pubsub.go @@ -0,0 +1,80 @@ +// Copyright 2018-2022 Burak Sezer +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package olric + +import ( + "context" + "strings" + + "github.com/buraksezer/olric/internal/server" + "github.com/go-redis/redis/v8" +) + +type PubSub struct { + config *pubsubConfig + rc *redis.Client + client *server.Client +} + +func newPubSub(client *server.Client, options ...PubSubOption) (*PubSub, error) { + var ( + err error + rc *redis.Client + pc pubsubConfig + ) + for _, opt := range options { + opt(&pc) + } + + addr := strings.Trim(pc.Address, " ") + if addr != "" { + rc = client.Get(addr) + } else { + rc, err = client.Pick() + if err != nil { + return nil, err + } + } + + return &PubSub{ + config: &pc, + rc: rc, + client: client, + }, nil +} + +func (ps *PubSub) Subscribe(ctx context.Context, channels ...string) *redis.PubSub { + return ps.rc.Subscribe(ctx, channels...) +} + +func (ps *PubSub) PSubscribe(ctx context.Context, channels ...string) *redis.PubSub { + return ps.rc.PSubscribe(ctx, channels...) +} + +func (ps *PubSub) Publish(ctx context.Context, channel string, message interface{}) (int64, error) { + return ps.rc.Publish(ctx, channel, message).Result() +} + +func (ps *PubSub) PubSubChannels(ctx context.Context, pattern string) ([]string, error) { + return ps.rc.PubSubChannels(ctx, pattern).Result() +} + +func (ps *PubSub) PubSubNumSub(ctx context.Context, channels ...string) (map[string]int64, error) { + return ps.rc.PubSubNumSub(ctx, channels...).Result() +} + +func (ps *PubSub) PubSubNumPat(ctx context.Context) (int64, error) { + return ps.rc.PubSubNumPat(ctx).Result() +} diff --git a/pubsub_test.go b/pubsub_test.go new file mode 100644 index 00000000..28cf41a1 --- /dev/null +++ b/pubsub_test.go @@ -0,0 +1,263 @@ +// Copyright 2018-2022 Burak Sezer +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package olric + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/require" +) + +func pubsubTestRunner(t *testing.T, ps *PubSub, kind, channel string) { + ctx := context.Background() + var rp *redis.PubSub + switch kind { + case "subscribe": + rp = ps.Subscribe(ctx, channel) + case "psubscribe": + rp = ps.PSubscribe(ctx, channel) + } + + defer func() { + require.NoError(t, rp.Close()) + }() + + // Wait for confirmation that subscription is created before publishing anything. + msgi, err := rp.ReceiveTimeout(ctx, time.Second) + require.NoError(t, err) + + subs := msgi.(*redis.Subscription) + require.Equal(t, kind, subs.Kind) + require.Equal(t, channel, subs.Channel) + require.Equal(t, 1, subs.Count) + + // Go channel which receives messages. + ch := rp.Channel() + + expected := make(map[string]struct{}) + for i := 0; i < 10; i++ { + msg := fmt.Sprintf("my-message-%d", i) + count, err := ps.Publish(ctx, "my-channel", msg) + require.Equal(t, int64(1), count) + require.NoError(t, err) + expected[msg] = struct{}{} + } + + consumed := make(map[string]struct{}) +L: + for { + select { + case msg := <-ch: + require.Equal(t, "my-channel", msg.Channel) + consumed[msg.Payload] = struct{}{} + if len(consumed) == 10 { + // It would be OK + break L + } + case <-time.After(5 * time.Second): + // Enough. Break it and check the consumed items. + break L + } + } + + require.Equal(t, expected, consumed) +} + +func TestPubSub_Publish_Subscribe(t *testing.T) { + cluster := newTestOlricCluster(t) + db := cluster.addMember(t) + + ctx := context.Background() + c, err := NewClusterClient([]string{db.name}) + require.NoError(t, err) + defer func() { + require.NoError(t, c.Close(ctx)) + }() + + ps, err := c.NewPubSub(ToAddress(db.rt.This().String())) + require.NoError(t, err) + + pubsubTestRunner(t, ps, "subscribe", "my-channel") +} + +func TestPubSub_Publish_PSubscribe(t *testing.T) { + cluster := newTestOlricCluster(t) + db := cluster.addMember(t) + + ctx := context.Background() + c, err := NewClusterClient([]string{db.name}) + require.NoError(t, err) + defer func() { + require.NoError(t, c.Close(ctx)) + }() + + ps, err := c.NewPubSub(ToAddress(db.rt.This().String())) + require.NoError(t, err) + pubsubTestRunner(t, ps, "psubscribe", "my-*") +} + +func TestPubSub_PubSubChannels(t *testing.T) { + cluster := newTestOlricCluster(t) + db := cluster.addMember(t) + + ctx := context.Background() + c, err := NewClusterClient([]string{db.name}) + require.NoError(t, err) + defer func() { + require.NoError(t, c.Close(ctx)) + }() + + ps, err := c.NewPubSub(ToAddress(db.rt.This().String())) + require.NoError(t, err) + + rp := ps.Subscribe(ctx, "my-channel") + + defer func() { + require.NoError(t, rp.Close()) + }() + + // Wait for confirmation that subscription is created before publishing anything. + _, err = rp.ReceiveTimeout(ctx, time.Second) + require.NoError(t, err) + + channels, err := ps.PubSubChannels(ctx, "my-*") + require.NoError(t, err) + + require.Equal(t, []string{"my-channel"}, channels) +} + +func TestPubSub_PubSubNumSub(t *testing.T) { + cluster := newTestOlricCluster(t) + db := cluster.addMember(t) + + ctx := context.Background() + c, err := NewClusterClient([]string{db.name}) + require.NoError(t, err) + defer func() { + require.NoError(t, c.Close(ctx)) + }() + + ps, err := c.NewPubSub(ToAddress(db.rt.This().String())) + require.NoError(t, err) + + rp := ps.Subscribe(ctx, "my-channel") + + defer func() { + require.NoError(t, rp.Close()) + }() + + // Wait for confirmation that subscription is created before publishing anything. + _, err = rp.ReceiveTimeout(ctx, time.Second) + require.NoError(t, err) + + numsub, err := ps.PubSubNumSub(ctx, "my-channel", "foobar") + require.NoError(t, err) + + expected := map[string]int64{ + "foobar": 0, + "my-channel": 1, + } + require.Equal(t, expected, numsub) +} + +func TestPubSub_PubSubNumPat(t *testing.T) { + cluster := newTestOlricCluster(t) + db := cluster.addMember(t) + + ctx := context.Background() + c, err := NewClusterClient([]string{db.name}) + require.NoError(t, err) + defer func() { + require.NoError(t, c.Close(ctx)) + }() + + ps, err := c.NewPubSub(ToAddress(db.rt.This().String())) + require.NoError(t, err) + + rp := ps.PSubscribe(ctx, "my-*") + + defer func() { + require.NoError(t, rp.Close()) + }() + + // Wait for confirmation that subscription is created before publishing anything. + _, err = rp.ReceiveTimeout(ctx, time.Second) + require.NoError(t, err) + + numpat, err := ps.PubSubNumPat(ctx) + require.NoError(t, err) + require.Equal(t, int64(1), numpat) +} + +func TestPubSub_Cluster(t *testing.T) { + cluster := newTestOlricCluster(t) + db1 := cluster.addMember(t) + db2 := cluster.addMember(t) + + // Create a subscriber + ctx := context.Background() + c, err := NewClusterClient([]string{db1.name}) + require.NoError(t, err) + defer func() { + require.NoError(t, c.Close(ctx)) + }() + + ps1, err := c.NewPubSub(ToAddress(db1.rt.This().String())) + require.NoError(t, err) + + rp := ps1.Subscribe(ctx, "my-channel") + defer func() { + require.NoError(t, rp.Close()) + }() + // Wait for confirmation that subscription is created before publishing anything. + _, err = rp.ReceiveTimeout(ctx, time.Second) + require.NoError(t, err) + receiveChan := rp.Channel() + + // Create a publisher + + e := db2.NewEmbeddedClient() + ps2, err := e.NewPubSub(ToAddress(db2.rt.This().String())) + require.NoError(t, err) + expected := make(map[string]struct{}) + for i := 0; i < 10; i++ { + msg := fmt.Sprintf("my-message-%d", i) + count, err := ps2.Publish(ctx, "my-channel", msg) + require.Equal(t, int64(1), count) + require.NoError(t, err) + expected[msg] = struct{}{} + } + + consumed := make(map[string]struct{}) +L: + for { + select { + case msg := <-receiveChan: + require.Equal(t, "my-channel", msg.Channel) + consumed[msg.Payload] = struct{}{} + if len(consumed) == 10 { + // It would be OK + break L + } + case <-time.After(5 * time.Second): + // Enough. Break it and check the consumed items. + break L + } + } +} diff --git a/stats.go b/stats.go index 69a56a9c..058d5ffd 100644 --- a/stats.go +++ b/stats.go @@ -18,6 +18,7 @@ import ( "encoding/json" "os" "runtime" + "strings" "github.com/buraksezer/olric/internal/cluster/partitions" "github.com/buraksezer/olric/internal/discovery" @@ -65,7 +66,8 @@ func (db *Olric) collectPartitionMetrics(partID uint64, part *partitions.Partiti tmp.SlabInfo.Allocated = st.Allocated tmp.SlabInfo.Garbage = st.Garbage tmp.SlabInfo.Inuse = st.Inuse - p.DMaps[name.(string)] = tmp + dmapName := strings.TrimPrefix(name.(string), "dmap.") + p.DMaps[dmapName] = tmp return true }) return p diff --git a/stats_test.go b/stats_test.go index e9535551..a0afa542 100644 --- a/stats_test.go +++ b/stats_test.go @@ -62,7 +62,7 @@ func TestOlric_Stats(t *testing.T) { var total int for partID, part := range s.Partitions { total += part.Length - if _, ok := part.DMaps["dmap.mymap"]; !ok { + if _, ok := part.DMaps["mymap"]; !ok { t.Fatalf("Expected dmap check result is true. Got false") } if len(part.PreviousOwners) != 0 {