diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7b3b890..d6ccf3f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,7 +25,7 @@ jobs: with: go-version: ^1.22 - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v4 with: version: v1.54 unit-tests: diff --git a/.gitignore b/.gitignore index e9acf95..bae3a54 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,7 @@ # Test binary, built with `go test -c` *.test -.db +*.db snapshots # Output of the go coverage tool, specifically when used with LiteIDE diff --git a/go.mod b/go.mod index 5e6c91d..06879b4 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.18 require ( github.com/mattn/go-sqlite3 v1.14.22 + github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.9.0 github.com/yaitoo/async v1.0.4 github.com/yaitoo/sqle v1.3.1 diff --git a/go.sum b/go.sum index a2f425f..342f815 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,7 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/lease.go b/lease.go index c5a64cd..981ee80 100644 --- a/lease.go +++ b/lease.go @@ -19,6 +19,14 @@ type Lease struct { ExpiresOn time.Time `json:"-"` } +// IsLive check if lease is live on node side func (l *Lease) IsLive() bool { return time.Now().Before(time.Unix(l.Since, 0).Add(l.TTL.Duration())) } + +// IsExpired check if lease expires on mutex side +func (l *Lease) IsExpired(start time.Time) bool { + now := time.Now() + l.ExpiresOn = now.Add(l.TTL.Duration() - time.Until(start)) + return !now.Before(l.ExpiresOn) +} diff --git a/mutex.go b/mutex.go index 32844f9..612ca06 100644 --- a/mutex.go +++ b/mutex.go @@ -3,6 +3,7 @@ package dlm import ( "context" "errors" + "log/slog" "math" "net/rpc" "strings" @@ -18,7 +19,6 @@ func New(id, topic, key string, options ...MutexOption) *Mutex { id: id, topic: strings.ToLower(topic), key: strings.ToLower(key), - done: make(chan struct{}), timeout: DefaultTimeout, ttl: DefaultLeaseTerm, } @@ -27,7 +27,8 @@ func New(id, topic, key string, options ...MutexOption) *Mutex { o(m) } - m.consensus = int(math.Ceil(float64(len(m.peers)) / 2)) + m.Context, m.cancel = context.WithCancelCause(context.Background()) + m.consensus = int(math.Floor(float64(len(m.peers))/2)) + 1 return m } @@ -43,36 +44,34 @@ type Mutex struct { ttl time.Duration consensus int - cluster []*rpc.Client - done chan struct{} + + context.Context + cancel context.CancelCauseFunc lease Lease } -func (m *Mutex) connect(ctx context.Context) error { - if m.cluster == nil { - a := async.New[*rpc.Client]() - for _, d := range m.peers { - a.Add(func(addr string) func(context.Context) (*rpc.Client, error) { - return func(ctx context.Context) (*rpc.Client, error) { - return connect(ctx, addr, m.timeout) - } - }(d)) - } +func (m 
*Mutex) connect(ctx context.Context) ([]*rpc.Client, error) { - cluster, _, err := a.Wait(ctx) - if len(cluster) >= m.consensus { - m.cluster = cluster - return nil - } + a := async.New[*rpc.Client]() + for _, d := range m.peers { + a.Add(func(addr string) func(context.Context) (*rpc.Client, error) { + return func(ctx context.Context) (*rpc.Client, error) { + return connect(ctx, addr, m.timeout) + } + }(d)) + } - return err + cluster, _, err := a.Wait(ctx) + if len(cluster) >= m.consensus { + return cluster, nil } - return nil + return nil, err + } -func (m *Mutex) Lock(ctx context.Context) (context.Context, context.CancelFunc, error) { +func (m *Mutex) Lock(ctx context.Context) error { m.mu.Lock() defer m.mu.Unlock() @@ -82,12 +81,12 @@ func (m *Mutex) Lock(ctx context.Context) (context.Context, context.CancelFunc, ctx, cancel := context.WithTimeout(ctx, m.timeout) defer cancel() - err := m.connect(ctx) + cluster, err := m.connect(ctx) if err != nil { - return nil, nil, err + return err } - for _, c := range m.cluster { + for _, c := range cluster { a.Add(func(c *rpc.Client) func(ctx context.Context) (Lease, error) { return func(ctx context.Context) (Lease, error) { var t Lease @@ -98,109 +97,185 @@ func (m *Mutex) Lock(ctx context.Context) (context.Context, context.CancelFunc, } start := time.Now() - result, _, err := a.WaitN(ctx, m.consensus) + result, errs, err := a.WaitN(ctx, m.consensus) if err != nil { - return nil, nil, err + Logger.Warn("dlm: renew lock", slog.Any("err", errs)) + return m.Error(errs, err) } t := result[0] - now := time.Now() - t.ExpiresOn = now.Add(t.TTL.Duration() - time.Until(start)) - if !now.Before(t.ExpiresOn) { - return nil, nil, ErrExpiredLease + if t.IsExpired(start) { + return ErrExpiredLease } m.lease = t - statusCtx, statusCancel := context.WithCancel(context.Background()) - - go m.keepalive(statusCtx, statusCancel) - go m.waitExpires(statusCtx, statusCancel) + go m.waitExpires() - return statusCtx, statusCancel, nil + return nil } -func (m *Mutex) Unlock(ctx context.Context) error { +func (m *Mutex) Renew(ctx context.Context) error { m.mu.Lock() defer m.mu.Unlock() - - a := async.New[bool]() + a := async.New[Lease]() req := m.createRequest() - for _, c := range m.cluster { - a.Add(func(c *rpc.Client) func(ctx context.Context) (bool, error) { - return func(ctx context.Context) (bool, error) { - var t bool - err := c.Call("dlm.ReleaseLock", req, &t) + cluster, err := m.connect(ctx) + if err != nil { + return err + } + + for _, c := range cluster { + a.Add(func(c *rpc.Client) func(ctx context.Context) (Lease, error) { + return func(ctx context.Context) (Lease, error) { + var t Lease + err := c.Call("dlm.RenewLock", req, &t) if err != nil { return t, err } - return t, nil } }(c)) } - m.done <- struct{}{} - _, _, err := a.WaitN(ctx, m.consensus) + ctx, cancel := context.WithTimeout(ctx, m.timeout) + defer cancel() + + start := time.Now() + result, errs, err := a.WaitN(ctx, m.consensus) if err != nil { - return err + Logger.Warn("dlm: renew lock", slog.Any("err", errs)) + return m.Error(errs, err) } - return nil -} -func (m *Mutex) createRequest() LockRequest { - return LockRequest{ - ID: m.id, - Topic: m.topic, - Key: m.key, - TTL: m.ttl, + t := result[0] + if t.IsExpired(start) { + return ErrExpiredLease } + + m.lease = t + return nil } -func (m *Mutex) Renew(ctx context.Context) error { +func (m *Mutex) Unlock(ctx context.Context) error { m.mu.Lock() defer m.mu.Unlock() - a := async.New[Lease]() + + a := async.NewA() req := m.createRequest() - for _, c := 
range m.cluster { - a.Add(func(c *rpc.Client) func(ctx context.Context) (Lease, error) { - return func(ctx context.Context) (Lease, error) { - var t Lease - err := c.Call("dlm.RenewLock", req, &t) + + cluster, err := m.connect(ctx) + if err != nil { + return err + } + + for _, c := range cluster { + a.Add(func(c *rpc.Client) func(ctx context.Context) error { + return func(ctx context.Context) error { + var t bool + err := c.Call("dlm.ReleaseLock", req, &t) if err != nil { - return t, err + return err } - return t, nil + + return nil } }(c)) } + errs, err := a.WaitN(ctx, m.consensus) + if err != nil { + Logger.Warn("dlm: renew lock", slog.Any("err", errs)) + return m.Error(errs, err) + } + + m.cancel(nil) + + return nil +} + +func (m *Mutex) Freeze(ctx context.Context, topic string) error { + m.mu.Lock() + defer m.mu.Unlock() + + a := async.NewA() + ctx, cancel := context.WithTimeout(ctx, m.timeout) defer cancel() - start := time.Now() - result, _, err := a.WaitN(ctx, m.consensus) + cluster, err := m.connect(ctx) if err != nil { return err } - now := time.Now() - t := result[0] - t.ExpiresOn = now.Add(t.TTL.Duration() - time.Until(start)) + for _, c := range cluster { + a.Add(func(c *rpc.Client) func(ctx context.Context) error { + return func(ctx context.Context) error { + var ok bool + return c.Call("dlm.Freeze", topic, &ok) + } + }(c)) + } - if !now.After(t.ExpiresOn) { - return ErrExpiredLease + errs, err := a.WaitN(ctx, m.consensus) + + if err != nil { + Logger.Warn("dlm: freeze topic", slog.Any("err", errs)) + return m.Error(errs, err) } - m.lease = t return nil + } -func (m *Mutex) waitExpires(ctx context.Context, cancel context.CancelFunc) { +func (m *Mutex) Reset(ctx context.Context, topic string) error { + m.mu.Lock() + defer m.mu.Unlock() + + a := async.NewA() + + ctx, cancel := context.WithTimeout(ctx, m.timeout) defer cancel() + + cluster, err := m.connect(ctx) + if err != nil { + return err + } + + for _, c := range cluster { + a.Add(func(c *rpc.Client) func(ctx context.Context) error { + return func(ctx context.Context) error { + var ok bool + return c.Call("dlm.Reset", topic, &ok) + } + }(c)) + } + + errs, err := a.WaitN(ctx, m.consensus) + + if err != nil { + Logger.Warn("dlm: freeze reset", slog.Any("err", errs)) + return m.Error(errs, err) + } + + return nil + +} + +func (m *Mutex) createRequest() LockRequest { + return LockRequest{ + ID: m.id, + Topic: m.topic, + Key: m.key, + TTL: m.ttl, + } +} + +func (m *Mutex) waitExpires() { + var expiresOn time.Time for { m.mu.RLock() @@ -208,43 +283,98 @@ func (m *Mutex) waitExpires(ctx context.Context, cancel context.CancelFunc) { m.mu.RUnlock() select { - case <-m.done: - return - case <-ctx.Done(): + + case <-m.Context.Done(): return case <-time.After(time.Until(expiresOn)): - if !expiresOn.Before(expiresOn) { + // get latest ExpiresOn + m.mu.RLock() + expiresOn = m.lease.ExpiresOn + m.mu.RUnlock() + + if !time.Now().Before(expiresOn) { + m.cancel(ErrExpiredLease) return } } } } -func (m *Mutex) keepalive(ctx context.Context, cancel context.CancelFunc) { - defer cancel() +func (m *Mutex) Keepalive() { var err error for { - - m.mu.RLock() - expiresOn := m.lease.ExpiresOn - m.mu.RUnlock() - select { - case <-m.done: - return - case <-ctx.Done(): + case <-m.Context.Done(): return case <-time.After(1 * time.Second): + m.mu.RLock() + expiresOn := m.lease.ExpiresOn + m.mu.RUnlock() + // lease already expires - if !expiresOn.Before(expiresOn) { + if !time.Now().Before(expiresOn) { + m.cancel(ErrExpiredLease) return } err = 
m.Renew(context.Background()) if errors.Is(err, ErrExpiredLease) { + m.cancel(ErrExpiredLease) return } } } } + +// Error try unwrap consensus known error +func (m *Mutex) Error(errs []error, err error) error { + consensus := make(map[rpc.ServerError]int) + + for _, err := range errs { + s, ok := err.(rpc.ServerError) + if ok { + c, ok := consensus[s] + if !ok { + consensus[s] = 1 + } else { + consensus[s] = c + 1 + } + } + } + + max := 0 + var msg string + + for k, v := range consensus { + if v > max { + max = v + msg = string(k) + } + } + + if max < m.consensus { + return err + } + + if !strings.HasPrefix(msg, "dlm:") { + return err + } + + switch msg { + case ErrExpiredLease.Error(): + return ErrExpiredLease + case ErrNoLease.Error(): + return ErrNoLease + case ErrNotYourLease.Error(): + return ErrNotYourLease + case ErrLeaseExists.Error(): + return ErrLeaseExists + case ErrFrozenTopic.Error(): + return ErrFrozenTopic + case ErrBadDatabase.Error(): + return ErrBadDatabase + } + + return err +} diff --git a/mutex_option.go b/mutex_option.go index 2e147b4..025ea5c 100644 --- a/mutex_option.go +++ b/mutex_option.go @@ -1,6 +1,8 @@ package dlm -import "time" +import ( + "time" +) type MutexOption func(m *Mutex) diff --git a/mutext_test.go b/mutext_test.go index 3735251..504cb46 100644 --- a/mutext_test.go +++ b/mutext_test.go @@ -10,41 +10,41 @@ import ( "github.com/yaitoo/sqle" ) -func createCluster(ctx context.Context, num int) ([]string, []*Node, []func(), error) { +func createCluster(num int) ([]string, []*Node, func(), error) { var peers []string var nodes []*Node var clean []func() + release := func() { + for _, c := range clean { + c() + } + } + for i := 0; i < num; i++ { db, fn, err := createSqlite3() if err != nil { - return nil, nil, clean, err + return nil, nil, release, err } clean = append(clean, fn) n := NewNode(getFreeAddr(), sqle.Open(db)) - err = n.Start(ctx) + err = n.Start() if err != nil { - return nil, nil, clean, err + return nil, nil, release, err } peers = append(peers, n.addr) nodes = append(nodes, n) + + clean = append(clean, n.Stop) } - return peers, nodes, clean, nil + return peers, nodes, release, nil } func TestLock(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - peers, nodes, clean, err := createCluster(ctx, 5) - + peers, nodes, clean, err := createCluster(5) require.NoError(t, err) - - defer func() { - for _, c := range clean { - c() - } - }() + defer clean() tests := []struct { name string @@ -54,39 +54,113 @@ func TestLock(t *testing.T) { name: "lock_should_work", run: func(r *require.Assertions) { m := New("lock_should_work", "wallet", "lock_should_work", WithPeers(peers...), WithTTL(10*time.Second)) - _, cancel, err := m.Lock(context.TODO()) - defer cancel() + err := m.Lock(context.TODO()) r.NoError(err) r.Equal(10*time.Second, m.lease.TTL.Duration()) r.Equal("wallet", m.lease.Topic) r.Equal("lock_should_work", m.lease.Key) - m2 := New("lock_should_work_2", "wallet", "lock_should_work", WithPeers(peers...)) - _, _, err = m2.Lock(context.TODO()) + }, + }, + { + name: "lock_should_not_work_if_lease_exists", + run: func(r *require.Assertions) { + m := New("lock", "wallet", "lock_exists", WithPeers(peers...), WithTTL(10*time.Second)) + err := m.Lock(context.TODO()) + r.NoError(err) + r.Equal(10*time.Second, m.lease.TTL.Duration()) + r.Equal("wallet", m.lease.Topic) + r.Equal("lock_exists", m.lease.Key) + + m2 := New("lock_2", "wallet", "lock_exists", WithPeers(peers...)) + err = m2.Lock(context.TODO()) - 
r.Error(err, async.ErrTooLessDone) + r.Error(err, ErrLeaseExists) + }, + }, + { + name: "expires_should_work", + run: func(r *require.Assertions) { + ttl := 3 * time.Second + m := New("lock_should_work", "wallet", "expires_should_work", WithPeers(peers...), WithTTL(ttl)) + err := m.Lock(context.TODO()) + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("wallet", m.lease.Topic) + r.Equal("expires_should_work", m.lease.Key) + + time.Sleep(ttl) + + <-m.Done() + r.ErrorIs(context.Cause(m), ErrExpiredLease) + + }, + }, + { + name: "lock_should_work_when_old_lease_expires", + run: func(r *require.Assertions) { + ttl := 3 * time.Second + m := New("lock", "wallet", "lock_exists", WithPeers(peers...), WithTTL(ttl)) + err := m.Lock(context.TODO()) + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("lock", m.lease.Lessee) + r.Equal("wallet", m.lease.Topic) + r.Equal("lock_exists", m.lease.Key) + + ttl2 := 5 * time.Second + m2 := New("lock_2", "wallet", "lock_exists", WithPeers(peers...), WithTTL(ttl2)) + err = m2.Lock(context.TODO()) + + r.Error(err, ErrLeaseExists) + + time.Sleep(ttl) // wait for 1st lease expires + + err = m2.Lock(context.TODO()) + r.NoError(err) + r.Equal(ttl2, m2.lease.TTL.Duration()) + r.Equal("lock_2", m2.lease.Lessee) + r.Equal("wallet", m2.lease.Topic) + r.Equal("lock_exists", m2.lease.Key) }, }, { - name: "lock_should_work_even_minority_nodes_are_down", + name: "lock_should_work_when_minority_nodes_are_down", run: func(r *require.Assertions) { m := New("lock_should_work", "wallet", "minority_nodes_are_down", WithPeers(peers...), WithTTL(10*time.Second)) nodes[0].Stop() + defer nodes[0].Start() // nolint: errcheck nodes[1].Stop() + defer nodes[1].Start() // nolint: errcheck + + err := m.Lock(context.TODO()) - _, cancel, err := m.Lock(context.TODO()) - defer cancel() r.NoError(err) r.Equal(10*time.Second, m.lease.TTL.Duration()) r.Equal("wallet", m.lease.Topic) r.Equal("minority_nodes_are_down", m.lease.Key) m2 := New("lock_should_work_2", "wallet", "minority_nodes_are_down", WithPeers(peers...)) - _, _, err = m2.Lock(context.TODO()) + err = m2.Lock(context.TODO()) r.Error(err, async.ErrTooLessDone) + }, + }, + { + name: "lock_should_not_work_when_majority_nodes_are_down", + run: func(r *require.Assertions) { + m := New("lock_should_work", "wallet", "majority_nodes_are_down", WithPeers(peers...), WithTTL(10*time.Second)) + + nodes[0].Stop() + defer nodes[0].Start() // nolint: errcheck + nodes[1].Stop() + defer nodes[1].Start() // nolint: errcheck + nodes[2].Stop() + defer nodes[2].Start() // nolint: errcheck + err = m.Lock(context.TODO()) + r.Error(err, async.ErrTooLessDone) }, }, } @@ -96,5 +170,410 @@ func TestLock(t *testing.T) { test.run(require.New(t)) }) } +} + +func TestRenew(t *testing.T) { + peers, nodes, clean, err := createCluster(5) + require.NoError(t, err) + defer clean() + + tests := []struct { + name string + run func(*require.Assertions) + }{ + { + name: "renew_should_work", + run: func(r *require.Assertions) { + ttl := 10 * time.Second + m := New("renew", "wallet", "renew", WithPeers(peers...), WithTTL(ttl)) + err := m.Lock(context.TODO()) + + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("wallet", m.lease.Topic) + r.Equal("renew", m.lease.Key) + + err = m.Renew(context.TODO()) + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("wallet", m.lease.Topic) + r.Equal("renew", m.lease.Key) + }, + }, + { + name: "renew_should_not_work_when_lease_is_expired", + run: func(r *require.Assertions) { + ttl := 
2 * time.Second + m := New("renew", "wallet", "renew_expires", WithPeers(peers...), WithTTL(ttl)) + err := m.Lock(context.TODO()) + + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("wallet", m.lease.Topic) + r.Equal("renew_expires", m.lease.Key) + + time.Sleep(ttl) + + err = m.Renew(context.TODO()) + r.ErrorIs(err, ErrExpiredLease) + + }, + }, + { + name: "keepalive_should_work", + run: func(r *require.Assertions) { + ttl := 2 * time.Second + m := New("renew", "wallet", "renew_keepalive", WithPeers(peers...), WithTTL(ttl)) + err := m.Lock(context.TODO()) + + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("renew", m.lease.Lessee) + r.Equal("wallet", m.lease.Topic) + r.Equal("renew_keepalive", m.lease.Key) + + go m.Keepalive() + + time.Sleep(ttl) + + err = m.Renew(context.TODO()) + r.NoError(err) + + time.Sleep(1 * time.Second) + err = m.Renew(context.TODO()) + r.NoError(err) + + }, + }, + { + name: "renew_should_not_work_when_the_lease_does_not_exists", + run: func(r *require.Assertions) { + ttl := 2 * time.Second + m := New("renew", "wallet", "renew_does_not_exists", WithPeers(peers...), WithTTL(ttl)) + + err = m.Renew(context.TODO()) + r.Error(err, ErrNoLease) + }, + }, + { + name: "renew_should_not_work_when_lease_is_not_yours", + run: func(r *require.Assertions) { + ttl := 5 * time.Second + m := New("renew", "wallet", "renew_not_yours", WithPeers(peers...), WithTTL(ttl)) + + err := m.Lock(context.Background()) + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("renew", m.lease.Lessee) + r.Equal("wallet", m.lease.Topic) + r.Equal("renew_not_yours", m.lease.Key) + + m2 := New("renew_2", "wallet", "renew_not_yours", WithPeers(peers...), WithTTL(ttl)) + err = m2.Renew(context.Background()) + r.ErrorIs(err, ErrNotYourLease) + + }, + }, + { + name: "renew_should_work_when_minority_nodes_are_down", + run: func(r *require.Assertions) { + m := New("renew", "wallet", "renew_minority", WithPeers(peers...), WithTTL(10*time.Second)) + + err := m.Lock(context.TODO()) + + r.NoError(err) + r.Equal(10*time.Second, m.lease.TTL.Duration()) + r.Equal("wallet", m.lease.Topic) + r.Equal("renew_minority", m.lease.Key) + + nodes[0].Stop() + defer nodes[0].Start() // nolint: errcheck + nodes[1].Stop() + defer nodes[1].Start() // nolint: errcheck + err = m.Renew(context.TODO()) + r.NoError(err) + + }, + }, + { + name: "renew_should_not_work_when_majority_nodes_are_down", + run: func(r *require.Assertions) { + m := New("renew", "wallet", "renew_majority", WithPeers(peers...), WithTTL(10*time.Second)) + + err := m.Lock(context.TODO()) + + r.NoError(err) + r.Equal(10*time.Second, m.lease.TTL.Duration()) + r.Equal("wallet", m.lease.Topic) + r.Equal("renew_majority", m.lease.Key) + + nodes[0].Stop() + defer nodes[0].Start() // nolint: errcheck + nodes[1].Stop() + defer nodes[1].Start() // nolint: errcheck + nodes[2].Stop() + defer nodes[2].Start() // nolint: errcheck + + err = m.Renew(context.TODO()) + r.ErrorIs(err, async.ErrTooLessDone) + + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.run(require.New(t)) + }) + } +} + +func TestUnlock(t *testing.T) { + peers, nodes, clean, err := createCluster(5) + require.NoError(t, err) + defer clean() + + tests := []struct { + name string + run func(*require.Assertions) + }{ + { + name: "unlock_should_work", + run: func(r *require.Assertions) { + ttl := 10 * time.Second + m := New("unlock", "wallet", "unlock", WithPeers(peers...), WithTTL(ttl)) + err := m.Lock(context.TODO()) + + 
r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("wallet", m.lease.Topic) + r.Equal("unlock", m.lease.Key) + + err = m.Unlock(context.TODO()) + r.NoError(err) + + err = m.Renew(context.TODO()) + r.ErrorIs(err, ErrNoLease) + }, + }, + { + name: "unlock_should_work_even_lease_does_not_exists", + run: func(r *require.Assertions) { + ttl := 10 * time.Second + m := New("unlock", "wallet", "unlock", WithPeers(peers...), WithTTL(ttl)) + + err = m.Unlock(context.TODO()) + r.NoError(err) + + }, + }, + { + name: "unlock_should_not_work_when_lease_is_not_yours", + run: func(r *require.Assertions) { + ttl := 10 * time.Second + m := New("unlock", "wallet", "unlock_not_yours", WithPeers(peers...), WithTTL(ttl)) + err := m.Lock(context.TODO()) + + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("unlock", m.lease.Lessee) + r.Equal("wallet", m.lease.Topic) + r.Equal("unlock_not_yours", m.lease.Key) + + m2 := New("unlock_2", "wallet", "unlock_not_yours", WithPeers(peers...), WithTTL(ttl)) + err = m2.Unlock(context.TODO()) + + r.ErrorIs(err, ErrNotYourLease) + }, + }, + + { + name: "unlock_should_work_when_minority_nodes_are_down", + run: func(r *require.Assertions) { + ttl := 10 * time.Second + m := New("unlock", "wallet", "unlock_minority", WithPeers(peers...), WithTTL(ttl)) + + err := m.Lock(context.TODO()) + + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("unlock", m.lease.Lessee) + r.Equal("wallet", m.lease.Topic) + r.Equal("unlock_minority", m.lease.Key) + + nodes[0].Stop() + defer nodes[0].Start() // nolint: errcheck + + nodes[1].Stop() + defer nodes[1].Start() // nolint: errcheck + + err = m.Unlock(context.TODO()) + r.NoError(err) + + }, + }, + { + name: "unlock_should_not_work_when_majority_nodes_are_down", + run: func(r *require.Assertions) { + ttl := 10 * time.Second + m := New("unlock", "wallet", "unlock_majority", WithPeers(peers...), WithTTL(ttl)) + + err := m.Lock(context.TODO()) + + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("unlock", m.lease.Lessee) + r.Equal("wallet", m.lease.Topic) + r.Equal("unlock_majority", m.lease.Key) + + nodes[0].Stop() + defer nodes[0].Start() // nolint: errcheck + nodes[1].Stop() + defer nodes[1].Start() // nolint: errcheck + nodes[2].Stop() + defer nodes[2].Start() // nolint: errcheck + + err = m.Unlock(context.TODO()) + r.ErrorIs(err, async.ErrTooLessDone) + + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.run(require.New(t)) + }) + } +} + +func TestTopic(t *testing.T) { + peers, nodes, clean, err := createCluster(5) + require.NoError(t, err) + defer clean() + + tests := []struct { + name string + run func(*require.Assertions) + }{ + { + name: "freeze_should_work", + run: func(r *require.Assertions) { + ttl := 10 * time.Second + m := New("freeze", "freeze", "freeze", WithPeers(peers...), WithTTL(ttl)) + err := m.Lock(context.TODO()) + + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("freeze", m.lease.Lessee) + r.Equal("freeze", m.lease.Topic) + r.Equal("freeze", m.lease.Key) + + err = m.Freeze(context.Background(), "freeze") + r.NoError(err) + + err = m.Renew(context.TODO()) + r.ErrorIs(err, ErrFrozenTopic) + + m2 := New("freeze_2", "freeze", "freeze_2", WithPeers(peers...), WithTTL(ttl)) + err = m2.Lock(context.TODO()) + r.ErrorIs(err, ErrFrozenTopic) + + err = m.Reset(context.Background(), "freeze") + r.NoError(err) + + err = m.Renew(context.Background()) + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("freeze", 
m.lease.Topic) + r.Equal("freeze", m.lease.Key) + + err = m2.Lock(context.TODO()) + r.NoError(err) + r.Equal(ttl, m2.lease.TTL.Duration()) + r.Equal("freeze_2", m2.lease.Lessee) + r.Equal("freeze", m2.lease.Topic) + r.Equal("freeze_2", m2.lease.Key) + }, + }, + { + name: "freeze_should_work_when_minority_nodes_are_down", + run: func(r *require.Assertions) { + ttl := 10 * time.Second + m := New("freeze", "freeze", "freeze", WithPeers(peers...), WithTTL(ttl)) + err := m.Lock(context.TODO()) + + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("freeze", m.lease.Lessee) + r.Equal("freeze", m.lease.Topic) + r.Equal("freeze", m.lease.Key) + + nodes[0].Stop() + defer nodes[0].Start() // nolint: errcheck + nodes[1].Stop() + defer nodes[1].Start() // nolint: errcheck + + err = m.Freeze(context.Background(), "freeze") + r.NoError(err) + + err = m.Renew(context.TODO()) + r.ErrorIs(err, ErrFrozenTopic) + + m2 := New("freeze_2", "freeze", "freeze_2", WithPeers(peers...), WithTTL(ttl)) + err = m2.Lock(context.TODO()) + r.ErrorIs(err, ErrFrozenTopic) + + err = m.Reset(context.Background(), "freeze") + r.NoError(err) + + err = m.Renew(context.Background()) + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("freeze", m.lease.Topic) + r.Equal("freeze", m.lease.Key) + + err = m2.Lock(context.TODO()) + r.NoError(err) + r.Equal(ttl, m2.lease.TTL.Duration()) + r.Equal("freeze_2", m2.lease.Lessee) + r.Equal("freeze", m2.lease.Topic) + r.Equal("freeze_2", m2.lease.Key) + }, + }, + + { + name: "freeze_should_work_when_minority_nodes_are_down", + run: func(r *require.Assertions) { + ttl := 10 * time.Second + m := New("freeze", "freeze", "freeze", WithPeers(peers...), WithTTL(ttl)) + err := m.Lock(context.TODO()) + + r.NoError(err) + r.Equal(ttl, m.lease.TTL.Duration()) + r.Equal("freeze", m.lease.Lessee) + r.Equal("freeze", m.lease.Topic) + r.Equal("freeze", m.lease.Key) + + nodes[0].Stop() + defer nodes[0].Start() // nolint: errcheck + nodes[1].Stop() + defer nodes[1].Start() // nolint: errcheck + nodes[2].Stop() + defer nodes[2].Start() // nolint: errcheck + + err = m.Freeze(context.Background(), "freeze") + r.ErrorIs(err, async.ErrTooLessDone) + + err = m.Reset(context.Background(), "freeze") + r.ErrorIs(err, async.ErrTooLessDone) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + test.run(require.New(t)) + }) + } } diff --git a/node.go b/node.go index 60848cf..8dc0c9a 100644 --- a/node.go +++ b/node.go @@ -2,6 +2,7 @@ package dlm import ( "log/slog" + "net" "net/rpc" "sync" @@ -14,7 +15,6 @@ func NewNode(addr string, db *sqle.DB, options ...NodeOption) *Node { db: db, frozen: make(map[string]struct{}), logger: slog.Default(), - close: make(chan struct{}), } for _, o := range options { @@ -25,12 +25,14 @@ func NewNode(addr string, db *sqle.DB, options ...NodeOption) *Node { } type Node struct { - mu sync.Mutex + mu sync.RWMutex db *sqle.DB logger *slog.Logger frozen map[string]struct{} - addr string - server *rpc.Server - close chan struct{} + stopped bool + + addr string + listener net.Listener + server *rpc.Server } diff --git a/node_rpc.go b/node_rpc.go index 2c1e640..f58520b 100644 --- a/node_rpc.go +++ b/node_rpc.go @@ -96,6 +96,9 @@ func (n *Node) ReleaseLock(req LockRequest, ok *bool) error { lease, err := n.getLease(req.Topic, req.Key) if err != nil { + if errors.Is(err, ErrNoLease) { + return nil + } return err } diff --git a/node_svc.go b/node_svc.go index 6664fd2..4b7fb27 100644 --- a/node_svc.go +++ b/node_svc.go @@ -9,40 +9,62 @@ 
import ( ) // Start start the node and its RPC service -func (n *Node) Start(ctx context.Context) error { +func (n *Node) Start() error { + n.mu.Lock() + defer n.mu.Unlock() + l, err := net.Listen("tcp", n.addr) if err != nil { return err } - _, err = n.db.ExecContext(ctx, CreateTableLease) + _, err = n.db.ExecContext(context.Background(), CreateTableLease) if err != nil { return err } - _, err = n.db.ExecContext(ctx, CreateTableTopic) + _, err = n.db.ExecContext(context.Background(), CreateTableTopic) if err != nil { return err } n.server = rpc.NewServer() - go n.waitClose(ctx, l) - go n.waitRequest(l) - + n.listener = l + go n.waitRequest() + n.stopped = false + n.logger.Info("dlm: node is running") return n.server.RegisterName("dlm", n) } // Stop stop the node and its RPC service func (n *Node) Stop() { - n.close <- struct{}{} - n.logger.Info("dlm: node stopped") + n.mu.Lock() + defer n.mu.Unlock() + n.listener.Close() + n.stopped = true + n.logger.Info("dlm: node is stopped") +} + +func (n *Node) isStopped() bool { + n.mu.RLock() + defer n.mu.RUnlock() + + return n.stopped } -func (n *Node) waitRequest(l net.Listener) { +func (n *Node) waitRequest() { + for { - conn, err := l.Accept() + conn, err := n.listener.Accept() + if err == nil { + + if n.isStopped() { + return + } + go n.server.ServeConn(conn) + } else { if errors.Is(err, net.ErrClosed) { return @@ -50,16 +72,5 @@ func (n *Node) waitRequest(l net.Listener) { n.logger.Warn("dlm: wait request", slog.String("err", err.Error()), slog.String("addr", n.addr)) } - } } - -func (n *Node) waitClose(ctx context.Context, l net.Listener) { - - select { - case <-n.close: - case <-ctx.Done(): - } - - l.Close() -} diff --git a/node_test.go b/node_test.go index cbd0c17..4705c43 100644 --- a/node_test.go +++ b/node_test.go @@ -1,7 +1,6 @@ package dlm import ( - "context" "database/sql" "net" "os" @@ -48,16 +47,13 @@ func createSqlite3() (*sql.DB, func(), error) { } func TestLease(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - db, clean, err := createSqlite3() require.NoError(t, err) defer clean() - n := NewNode(getFreeAddr(), sqle.Open(db)) - err = n.Start(ctx) + err = n.Start() require.NoError(t, err) + defer n.Stop() walletTerms := sqle.Duration(5 * time.Second) userTerms := sqle.Duration(3 * time.Second) @@ -232,17 +228,15 @@ func TestLease(t *testing.T) { } } -func TestTopic(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - +func TestNodeTopic(t *testing.T) { db, clean, err := createSqlite3() require.NoError(t, err) defer clean() n := NewNode(getFreeAddr(), sqle.Open(db)) - err = n.Start(ctx) + err = n.Start() require.NoError(t, err) + defer n.Stop() terms := sqle.Duration(5 * time.Second) diff --git a/peer.go b/peer.go deleted file mode 100644 index 248f852..0000000 --- a/peer.go +++ /dev/null @@ -1,8 +0,0 @@ -package dlm - -type Peer struct { - RaftID string - RaftAddr string - - Addr string -}
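// ----------------------------------------------------------------------------
// Usage sketch (not part of the diff above): a minimal end-to-end example of
// the revised API — Node.Start() no longer takes a context, Mutex.Lock()
// returns only an error, the Mutex now embeds a context.Context that is
// cancelled with ErrExpiredLease when the lease lapses, and the quorum is a
// strict majority (floor(n/2)+1 instead of ceil(n/2)). The module import path
// "github.com/yaitoo/dlm", the SQLite DSN, the peer addresses and the timings
// below are assumptions for illustration only; in a real deployment each peer
// address is served by its own Node with its own database.
package main

import (
	"context"
	"database/sql"
	"log"
	"time"

	_ "github.com/mattn/go-sqlite3"
	"github.com/yaitoo/dlm"
	"github.com/yaitoo/sqle"
)

func main() {
	// Node side: an RPC node backed by a SQLite database.
	db, err := sql.Open("sqlite3", "file:dlm.db")
	if err != nil {
		log.Fatal(err)
	}
	n := dlm.NewNode(":7001", sqle.Open(db))
	if err := n.Start(); err != nil {
		log.Fatal(err)
	}
	defer n.Stop()

	// Mutex side: acquire the lease, then keep it alive in the background.
	m := dlm.New("worker-1", "wallet", "user:42",
		dlm.WithPeers(":7001", ":7002", ":7003"),
		dlm.WithTTL(10*time.Second))

	if err := m.Lock(context.Background()); err != nil {
		log.Fatal(err)
	}
	go m.Keepalive() // renews the lease periodically until the mutex context is done

	select {
	case <-m.Done():
		// The lease expired (or Unlock was called); context.Cause reports why.
		log.Println("lost lease:", context.Cause(m))
	case <-time.After(30 * time.Second):
		if err := m.Unlock(context.Background()); err != nil {
			log.Println("unlock:", err)
		}
	}
}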