diff --git a/lease.go b/lease.go
index c5a64cd..9fa1d36 100644
--- a/lease.go
+++ b/lease.go
@@ -19,6 +19,13 @@ type Lease struct {
 	ExpiresOn time.Time `json:"-"`
 }
 
+// IsLive checks if the lease is still live on the node side
 func (l *Lease) IsLive() bool {
 	return time.Now().Before(time.Unix(l.Since, 0).Add(l.TTL.Duration()))
 }
+
+func (l *Lease) IsExpired(start time.Time) bool {
+	now := time.Now()
+	l.ExpiresOn = now.Add(l.TTL.Duration() - time.Until(start))
+	return !now.Before(l.ExpiresOn)
+}
diff --git a/mutex.go b/mutex.go
index 32844f9..3214f6d 100644
--- a/mutex.go
+++ b/mutex.go
@@ -3,6 +3,7 @@ package dlm
 import (
 	"context"
 	"errors"
+	"log/slog"
 	"math"
 	"net/rpc"
 	"strings"
@@ -27,7 +28,7 @@ func New(id, topic, key string, options ...MutexOption) *Mutex {
 		o(m)
 	}
 
-	m.consensus = int(math.Ceil(float64(len(m.peers)) / 2))
+	m.consensus = int(math.Floor(float64(len(m.peers))/2)) + 1
 
 	return m
 }
@@ -43,33 +44,29 @@ type Mutex struct {
 	ttl       time.Duration
 	consensus int
 
-	cluster []*rpc.Client
 	done    chan struct{}
 	lease   Lease
 }
 
-func (m *Mutex) connect(ctx context.Context) error {
-	if m.cluster == nil {
-		a := async.New[*rpc.Client]()
-		for _, d := range m.peers {
-			a.Add(func(addr string) func(context.Context) (*rpc.Client, error) {
-				return func(ctx context.Context) (*rpc.Client, error) {
-					return connect(ctx, addr, m.timeout)
-				}
-			}(d))
-		}
+func (m *Mutex) connect(ctx context.Context) ([]*rpc.Client, error) {
 
-		cluster, _, err := a.Wait(ctx)
-		if len(cluster) >= m.consensus {
-			m.cluster = cluster
-			return nil
-		}
+	a := async.New[*rpc.Client]()
+	for _, d := range m.peers {
+		a.Add(func(addr string) func(context.Context) (*rpc.Client, error) {
+			return func(ctx context.Context) (*rpc.Client, error) {
+				return connect(ctx, addr, m.timeout)
+			}
+		}(d))
+	}
 
-		return err
+	cluster, _, err := a.Wait(ctx)
+	if len(cluster) >= m.consensus {
+		return cluster, nil
 	}
 
-	return nil
+	return nil, err
+
 }
 
 func (m *Mutex) Lock(ctx context.Context) (context.Context, context.CancelFunc, error) {
@@ -82,12 +79,12 @@ func (m *Mutex) Lock(ctx context.Context) (context.Context, context.CancelFunc,
 	ctx, cancel := context.WithTimeout(ctx, m.timeout)
 	defer cancel()
 
-	err := m.connect(ctx)
+	cluster, err := m.connect(ctx)
 	if err != nil {
 		return nil, nil, err
 	}
 
-	for _, c := range m.cluster {
+	for _, c := range cluster {
 		a.Add(func(c *rpc.Client) func(ctx context.Context) (Lease, error) {
 			return func(ctx context.Context) (Lease, error) {
 				var t Lease
@@ -98,17 +95,16 @@ func (m *Mutex) Lock(ctx context.Context) (context.Context, context.CancelFunc,
 	}
 
 	start := time.Now()
-	result, _, err := a.WaitN(ctx, m.consensus)
+	result, errs, err := a.WaitN(ctx, m.consensus)
 	if err != nil {
-		return nil, nil, err
+		Logger.Warn("dlm: lock", slog.Any("err", errs))
+		return nil, nil, m.Error(errs, err)
 	}
 
 	t := result[0]
 
-	now := time.Now()
-	t.ExpiresOn = now.Add(t.TTL.Duration() - time.Until(start))
-	if !now.Before(t.ExpiresOn) {
+	if t.IsExpired(start) {
 		return nil, nil, ErrExpiredLease
 	}
@@ -123,49 +119,17 @@ func (m *Mutex) Lock(ctx context.Context) (context.Context, context.CancelFunc,
 }
 
-func (m *Mutex) Unlock(ctx context.Context) error {
+func (m *Mutex) Renew(ctx context.Context) error {
 	m.mu.Lock()
 	defer m.mu.Unlock()
-
-	a := async.New[bool]()
+	a := async.New[Lease]()
 	req := m.createRequest()
 
-	for _, c := range m.cluster {
-		a.Add(func(c *rpc.Client) func(ctx context.Context) (bool, error) {
-			return func(ctx context.Context) (bool, error) {
-				var t bool
-				err := c.Call("dlm.ReleaseLock", req, &t)
-				if err != nil {
-					return t, err
-				}
-
-				return t, nil
-			}
-		}(c))
-	}
-	m.done <- struct{}{}
-
-	_, _, err := a.WaitN(ctx, m.consensus)
+	cluster, err := m.connect(ctx)
 	if err != nil {
 		return err
 	}
 
-	return nil
-}
-func (m *Mutex) createRequest() LockRequest {
-	return LockRequest{
-		ID:    m.id,
-		Topic: m.topic,
-		Key:   m.key,
-		TTL:   m.ttl,
-	}
-}
-
-func (m *Mutex) Renew(ctx context.Context) error {
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	a := async.New[Lease]()
-	req := m.createRequest()
-	for _, c := range m.cluster {
+	for _, c := range cluster {
 		a.Add(func(c *rpc.Client) func(ctx context.Context) (Lease, error) {
 			return func(ctx context.Context) (Lease, error) {
 				var t Lease
@@ -182,16 +146,14 @@ func (m *Mutex) Renew(ctx context.Context) error {
 	defer cancel()
 
 	start := time.Now()
-	result, _, err := a.WaitN(ctx, m.consensus)
+	result, errs, err := a.WaitN(ctx, m.consensus)
 	if err != nil {
-		return err
+		Logger.Warn("dlm: renew lock", slog.Any("err", errs))
+		return m.Error(errs, err)
 	}
 
-	now := time.Now()
 	t := result[0]
-	t.ExpiresOn = now.Add(t.TTL.Duration() - time.Until(start))
-
-	if !now.After(t.ExpiresOn) {
+	if t.IsExpired(start) {
 		return ErrExpiredLease
 	}
@@ -199,6 +161,50 @@ func (m *Mutex) Renew(ctx context.Context) error {
 	return nil
 }
 
+func (m *Mutex) Unlock(ctx context.Context) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	a := async.New[bool]()
+	req := m.createRequest()
+
+	cluster, err := m.connect(ctx)
+	if err != nil {
+		return err
+	}
+
+	for _, c := range cluster {
+		a.Add(func(c *rpc.Client) func(ctx context.Context) (bool, error) {
+			return func(ctx context.Context) (bool, error) {
+				var t bool
+				err := c.Call("dlm.ReleaseLock", req, &t)
+				if err != nil {
+					return t, err
+				}
+
+				return t, nil
+			}
+		}(c))
+	}
+	m.done <- struct{}{}
+
+	_, errs, err := a.WaitN(ctx, m.consensus)
+	if err != nil {
+		Logger.Warn("dlm: unlock", slog.Any("err", errs))
+		return m.Error(errs, err)
+	}
+
+	return nil
+}
+func (m *Mutex) createRequest() LockRequest {
+	return LockRequest{
+		ID:    m.id,
+		Topic: m.topic,
+		Key:   m.key,
+		TTL:   m.ttl,
+	}
+}
+
 func (m *Mutex) waitExpires(ctx context.Context, cancel context.CancelFunc) {
 	defer cancel()
 	var expiresOn time.Time
@@ -248,3 +254,55 @@ func (m *Mutex) keepalive(ctx context.Context, cancel context.CancelFunc) {
 		}
 	}
 }
+
+// Error tries to unwrap a known dlm error agreed on by a consensus of nodes
+func (m *Mutex) Error(errs []error, err error) error {
+	consensus := make(map[rpc.ServerError]int)
+
+	for _, err := range errs {
+		s, ok := err.(rpc.ServerError)
+		if ok {
+			c, ok := consensus[s]
+			if !ok {
+				consensus[s] = 1
+			} else {
+				consensus[s] = c + 1
+			}
+		}
+	}
+
+	max := 0
+	var msg string
+
+	for k, v := range consensus {
+		if v > max {
+			max = v
+			msg = string(k)
+		}
+	}
+
+	if max < m.consensus {
+		return err
+	}
+
+	if !strings.HasPrefix(msg, "dlm:") {
+		return err
+	}
+
+	switch msg {
+	case ErrExpiredLease.Error():
+		return ErrExpiredLease
+	case ErrNoLease.Error():
+		return ErrNoLease
+	case ErrNotYourLease.Error():
+		return ErrNotYourLease
+	case ErrLeaseExists.Error():
+		return ErrLeaseExists
+	case ErrFrozenTopic.Error():
+		return ErrFrozenTopic
+	case ErrBadDatabase.Error():
+		return ErrBadDatabase
+	}
+
+	return err
+}
diff --git a/mutext_test.go b/mutext_test.go
index 3735251..8a9dd24 100644
--- a/mutext_test.go
+++ b/mutext_test.go
@@ -68,7 +68,7 @@ func TestLock(t *testing.T) {
 			},
 		},
 		{
-			name: "lock_should_work_even_minority_nodes_are_down",
+			name: "lock_should_work_when_minority_nodes_are_down",
 			run: func(r *require.Assertions) {
 				m := New("lock_should_work",
"wallet", "minority_nodes_are_down", WithPeers(peers...), WithTTL(10*time.Second)) @@ -89,6 +89,22 @@ func TestLock(t *testing.T) { }, }, + { + name: "lock_should_not_work_when_majority_nodes_are_down", + run: func(r *require.Assertions) { + m := New("lock_should_work", "wallet", "majority_nodes_are_down", WithPeers(peers...), WithTTL(10*time.Second)) + + nodes[2].Stop() + + _, cancel, err := m.Lock(context.TODO()) + if cancel != nil { + defer cancel() + } + + r.Error(err, async.ErrTooLessDone) + + }, + }, } for _, test := range tests { @@ -96,5 +112,107 @@ func TestLock(t *testing.T) { test.run(require.New(t)) }) } +} +func TestRenew(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + peers, nodes, clean, err := createCluster(ctx, 5) + + require.NoError(t, err) + + defer func() { + for _, c := range clean { + c() + } + }() + + tests := []struct { + name string + run func(*require.Assertions) + }{ + { + name: "renew_should_work", + run: func(r *require.Assertions) { + ttl := 10 * time.Second + m := New("renew", "wallet", "renew", WithPeers(peers...), WithTTL(ttl)) + _, cancel, err := m.Lock(context.TODO()) + defer cancel() + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("wallet", m.lease.Topic) + r.Equal("renew", m.lease.Key) + + err = m.Renew(context.TODO()) + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("wallet", m.lease.Topic) + r.Equal("renew", m.lease.Key) + }, + }, + { + name: "renew_should_not_work_when_lease_is_expired", + run: func(r *require.Assertions) { + ttl := 2 * time.Second + m := New("renew", "wallet", "renew", WithPeers(peers...), WithTTL(ttl)) + _, cancel, err := m.Lock(context.TODO()) + defer cancel() + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("wallet", m.lease.Topic) + r.Equal("renew", m.lease.Key) + + time.Sleep(ttl) + + err = m.Renew(context.TODO()) + r.ErrorIs(err, ErrExpiredLease) + + }, + }, + { + name: "renew_should_work_when_minority_nodes_are_down", + run: func(r *require.Assertions) { + m := New("renew", "wallet", "renew_minority", WithPeers(peers...), WithTTL(10*time.Second)) + + _, cancel, err := m.Lock(context.TODO()) + defer cancel() + r.NoError(err) + r.Equal(10*time.Second, m.lease.TTL.Duration()) + r.Equal("wallet", m.lease.Topic) + r.Equal("renew_minority", m.lease.Key) + + nodes[0].Stop() + nodes[1].Stop() + err = m.Renew(context.TODO()) + r.NoError(err) + + }, + }, + { + name: "renew_should_not_work_when_majority_nodes_are_down", + run: func(r *require.Assertions) { + m := New("renew", "wallet", "renew_majority", WithPeers(peers...), WithTTL(10*time.Second)) + + _, cancel, err := m.Lock(context.TODO()) + defer cancel() + r.NoError(err) + r.Equal(10*time.Second, m.lease.TTL.Duration()) + r.Equal("wallet", m.lease.Topic) + r.Equal("renew_majority", m.lease.Key) + + nodes[0].Stop() + nodes[1].Stop() + nodes[2].Stop() + err = m.Renew(context.TODO()) + r.ErrorIs(err, async.ErrTooLessDone) + + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.run(require.New(t)) + }) + } } diff --git a/node_svc.go b/node_svc.go index 6664fd2..5d4ef98 100644 --- a/node_svc.go +++ b/node_svc.go @@ -34,7 +34,9 @@ func (n *Node) Start(ctx context.Context) error { // Stop stop the node and its RPC service func (n *Node) Stop() { - n.close <- struct{}{} + go func() { + n.close <- struct{}{} + }() n.logger.Info("dlm: node stopped") } diff --git a/peer.go b/peer.go deleted file mode 100644 index 248f852..0000000 --- a/peer.go +++ /dev/null @@ 
@@ -1,8 +0,0 @@
-package dlm
-
-type Peer struct {
-	RaftID   string
-	RaftAddr string
-
-	Addr string
-}
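
Side note on the quorum change in New (an illustration only, not part of the patch): math.Ceil(n/2) accepts a non-majority quorum for even cluster sizes, e.g. 2 of 4 peers, so two clients could each assemble a "quorum" from disjoint pairs of nodes, while math.Floor(n/2)+1 always demands a strict majority. A minimal standalone sketch of the difference:

package main

import (
	"fmt"
	"math"
)

func main() {
	for _, peers := range []int{3, 4, 5, 6} {
		oldQuorum := int(math.Ceil(float64(peers) / 2))   // previous formula
		newQuorum := int(math.Floor(float64(peers)/2)) + 1 // strict majority
		// peers=4: old=2 (two disjoint pairs can both "win"), new=3 (true majority)
		fmt.Printf("peers=%d old=%d new=%d\n", peers, oldQuorum, newQuorum)
	}
}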