diff --git a/conn.go b/conn.go index 5e66880..0bd0d74 100644 --- a/conn.go +++ b/conn.go @@ -5,12 +5,10 @@ import ( "errors" "net" "net/rpc" - "sync" "time" ) func connect(ctx context.Context, addr string, timeout time.Duration) (*rpc.Client, error) { - var d = net.Dialer{ Timeout: timeout, } @@ -20,14 +18,22 @@ func connect(ctx context.Context, addr string, timeout time.Duration) (*rpc.Clie return nil, err } - return rpc.NewClient(&Conn{ctx: ctx, Conn: c}), nil + conn := &Conn{ + Conn: c, + } + + conn.ctx, conn.cancel = context.WithCancel(ctx) + + go conn.waitContext() + + return rpc.NewClient(conn), nil } // Conn wrap net.Conn with context support type Conn struct { - ctx context.Context + ctx context.Context + cancel context.CancelFunc net.Conn - once sync.Once } func (c *Conn) waitContext() { @@ -41,10 +47,14 @@ func (c *Conn) waitContext() { } func (c *Conn) Read(p []byte) (n int, err error) { - go c.once.Do(c.waitContext) return c.Conn.Read(p) } func (c *Conn) Write(p []byte) (n int, err error) { return c.Conn.Write(p) } + +func (c *Conn) Close() error { + defer c.cancel() + return c.Conn.Close() +} diff --git a/dlm.go b/dlm.go index 24aca7f..7c1836a 100644 --- a/dlm.go +++ b/dlm.go @@ -20,6 +20,7 @@ var ( var ( DefaultTimeout = 3 * time.Second DefaultLeaseTerm = 5 * time.Second + DefaultKeepAlive = 1 * time.Second Logger = slog.Default() ) diff --git a/go.mod b/go.mod index 7560b9f..9370cde 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.22 github.com/stretchr/testify v1.9.0 github.com/yaitoo/async v1.0.4 - github.com/yaitoo/sqle v1.3.2 + github.com/yaitoo/sqle v1.4.5 ) require ( diff --git a/go.sum b/go.sum index a48a599..551d9b2 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yaitoo/async v1.0.4 h1:u+SWuJcSckgBOcMjMYz9IviojeCatDrdni3YNGLCiHY= github.com/yaitoo/async v1.0.4/go.mod h1:IpSO7Ei7AxiqLxFqDjN4rJaVlt8wm4ZxMXyyQaWmM1g= -github.com/yaitoo/sqle v1.3.2 h1:kuoAw2XPNvuPFJlUI1EdAknCajrf4aUNc4ns5TzPpXw= -github.com/yaitoo/sqle v1.3.2/go.mod h1:Bv1PPG6hYZP2In3WKN1dBqYNJiWP0ZSLs6uEkRo2c9M= +github.com/yaitoo/sqle v1.4.5 h1:2xWCNfGgCNisMsQgpCvMFuhaTkTFIvF+nkjEu6H41lU= +github.com/yaitoo/sqle v1.4.5/go.mod h1:Bv1PPG6hYZP2In3WKN1dBqYNJiWP0ZSLs6uEkRo2c9M= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/mutex.go b/mutex.go index f4334cd..09abe7b 100644 --- a/mutex.go +++ b/mutex.go @@ -16,11 +16,12 @@ import ( func New(id, topic, key string, options ...MutexOption) *Mutex { m := &Mutex{ - id: id, - topic: strings.ToLower(topic), - key: strings.ToLower(key), - timeout: DefaultTimeout, - ttl: DefaultLeaseTerm, + id: id, + topic: strings.ToLower(topic), + key: strings.ToLower(key), + timeout: DefaultTimeout, + ttl: DefaultLeaseTerm, + keepalive: DefaultKeepAlive, } for _, o := range options { @@ -36,12 +37,13 @@ func New(id, topic, key string, options ...MutexOption) *Mutex { type Mutex struct { mu sync.RWMutex - id string - topic string - key string - peers []string - timeout time.Duration - ttl time.Duration + id string + topic string + key string + peers []string + timeout time.Duration + ttl time.Duration + keepalive time.Duration consensus int @@ -89,6 +91,7 @@ func (m *Mutex) Lock(ctx context.Context) error { for _, c := range cluster { a.Add(func(c *rpc.Client) func(ctx context.Context) (Lease, error) { return func(ctx context.Context) (Lease, error) { + defer c.Close() var t Lease err := c.Call("dlm.NewLock", req, &t) return t, err @@ -104,7 +107,7 @@ func (m *Mutex) Lock(ctx context.Context) error { return m.Error(errs, err) } - t := result[0] + t := result[len(result)-1] if t.IsExpired(start) { return ErrExpiredLease @@ -131,6 +134,7 @@ func (m *Mutex) Renew(ctx context.Context) error { for _, c := range cluster { a.Add(func(c *rpc.Client) func(ctx context.Context) (Lease, error) { return func(ctx context.Context) (Lease, error) { + defer c.Close() var t Lease err := c.Call("dlm.RenewLock", req, &t) if err != nil { @@ -151,7 +155,7 @@ func (m *Mutex) Renew(ctx context.Context) error { return m.Error(errs, err) } - t := result[0] + t := result[len(result)-1] if t.IsExpired(start) { return ErrExpiredLease } @@ -175,6 +179,7 @@ func (m *Mutex) Unlock(ctx context.Context) error { for _, c := range cluster { a.Add(func(c *rpc.Client) func(ctx context.Context) error { return func(ctx context.Context) error { + defer c.Close() var t bool err := c.Call("dlm.ReleaseLock", req, &t) if err != nil { @@ -214,6 +219,7 @@ func (m *Mutex) Freeze(ctx context.Context, topic string) error { for _, c := range cluster { a.Add(func(c *rpc.Client) func(ctx context.Context) error { return func(ctx context.Context) error { + defer c.Close() var ok bool return c.Call("dlm.Freeze", topic, &ok) } @@ -248,6 +254,7 @@ func (m *Mutex) Reset(ctx context.Context, topic string) error { for _, c := range cluster { a.Add(func(c *rpc.Client) func(ctx context.Context) error { return func(ctx context.Context) error { + defer c.Close() var ok bool return c.Call("dlm.Reset", topic, &ok) } @@ -275,18 +282,24 @@ func (m *Mutex) createRequest() LockRequest { } func (m *Mutex) waitExpires() { + delay := time.NewTimer(1 * time.Second) + defer delay.Stop() var expiresOn time.Time + for { m.mu.RLock() expiresOn = m.lease.ExpiresOn m.mu.RUnlock() + delay.Stop() + delay.Reset(time.Until(expiresOn)) + select { case <-m.Context.Done(): return - case <-time.After(time.Until(expiresOn)): + case <-delay.C: // get latest ExpiresOn m.mu.RLock() expiresOn = m.lease.ExpiresOn @@ -301,13 +314,15 @@ func (m *Mutex) waitExpires() { } func (m *Mutex) Keepalive() { + delay := time.NewTicker(m.keepalive) + defer delay.Stop() var err error for { select { case <-m.Context.Done(): return - case <-time.After(1 * time.Second): + case <-delay.C: m.mu.RLock() expiresOn := m.lease.ExpiresOn m.mu.RUnlock() diff --git a/mutex_option.go b/mutex_option.go index 025ea5c..0b93cb7 100644 --- a/mutex_option.go +++ b/mutex_option.go @@ -23,3 +23,9 @@ func WithTimeout(d time.Duration) MutexOption { m.timeout = d } } + +func WithKeepAlive(d time.Duration) MutexOption { + return func(m *Mutex) { + m.keepalive = d + } +} diff --git a/mutext_test.go b/mutext_test.go index ebad4dd..6134378 100644 --- a/mutext_test.go +++ b/mutext_test.go @@ -222,8 +222,11 @@ func TestRenew(t *testing.T) { { name: "keepalive_should_work", run: func(r *require.Assertions) { - ttl := 2 * time.Second - m := New("renew", "wallet", "renew_keepalive", WithPeers(peers...), WithTTL(ttl)) + ttl := 3 * time.Second + m := New("renew", "wallet", "renew_keepalive", + WithPeers(peers...), + WithTTL(ttl), + WithKeepAlive(1*time.Second)) err := m.Lock(context.TODO()) r.NoError(err) @@ -234,7 +237,7 @@ func TestRenew(t *testing.T) { go m.Keepalive() - time.Sleep(ttl) + time.Sleep(ttl + 1*time.Second) err = m.Renew(context.TODO()) r.NoError(err)