Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/goroutine leak #7

Merged
merged 10 commits into from
Apr 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ import (
"errors"
"net"
"net/rpc"
"sync"
"time"
)

func connect(ctx context.Context, addr string, timeout time.Duration) (*rpc.Client, error) {

var d = net.Dialer{
Timeout: timeout,
}
Expand All @@ -20,14 +18,22 @@ func connect(ctx context.Context, addr string, timeout time.Duration) (*rpc.Clie
return nil, err
}

return rpc.NewClient(&Conn{ctx: ctx, Conn: c}), nil
conn := &Conn{
Conn: c,
}

conn.ctx, conn.cancel = context.WithCancel(ctx)

go conn.waitContext()

return rpc.NewClient(conn), nil
}

// Conn wrap net.Conn with context support
type Conn struct {
ctx context.Context
ctx context.Context
cancel context.CancelFunc
net.Conn
once sync.Once
}

func (c *Conn) waitContext() {
Expand All @@ -41,10 +47,14 @@ func (c *Conn) waitContext() {
}

func (c *Conn) Read(p []byte) (n int, err error) {
go c.once.Do(c.waitContext)
return c.Conn.Read(p)
}

func (c *Conn) Write(p []byte) (n int, err error) {
return c.Conn.Write(p)
}

func (c *Conn) Close() error {
defer c.cancel()
return c.Conn.Close()
}
1 change: 1 addition & 0 deletions dlm.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ var (
var (
DefaultTimeout = 3 * time.Second
DefaultLeaseTerm = 5 * time.Second
DefaultKeepAlive = 1 * time.Second

Logger = slog.Default()
)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.22
github.com/stretchr/testify v1.9.0
github.com/yaitoo/async v1.0.4
github.com/yaitoo/sqle v1.3.2
github.com/yaitoo/sqle v1.4.5
)

require (
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/yaitoo/async v1.0.4 h1:u+SWuJcSckgBOcMjMYz9IviojeCatDrdni3YNGLCiHY=
github.com/yaitoo/async v1.0.4/go.mod h1:IpSO7Ei7AxiqLxFqDjN4rJaVlt8wm4ZxMXyyQaWmM1g=
github.com/yaitoo/sqle v1.3.2 h1:kuoAw2XPNvuPFJlUI1EdAknCajrf4aUNc4ns5TzPpXw=
github.com/yaitoo/sqle v1.3.2/go.mod h1:Bv1PPG6hYZP2In3WKN1dBqYNJiWP0ZSLs6uEkRo2c9M=
github.com/yaitoo/sqle v1.4.5 h1:2xWCNfGgCNisMsQgpCvMFuhaTkTFIvF+nkjEu6H41lU=
github.com/yaitoo/sqle v1.4.5/go.mod h1:Bv1PPG6hYZP2In3WKN1dBqYNJiWP0ZSLs6uEkRo2c9M=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
Expand Down
45 changes: 30 additions & 15 deletions mutex.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ import (
func New(id, topic, key string, options ...MutexOption) *Mutex {

m := &Mutex{
id: id,
topic: strings.ToLower(topic),
key: strings.ToLower(key),
timeout: DefaultTimeout,
ttl: DefaultLeaseTerm,
id: id,
topic: strings.ToLower(topic),
key: strings.ToLower(key),
timeout: DefaultTimeout,
ttl: DefaultLeaseTerm,
keepalive: DefaultKeepAlive,
}

for _, o := range options {
Expand All @@ -36,12 +37,13 @@ func New(id, topic, key string, options ...MutexOption) *Mutex {
type Mutex struct {
mu sync.RWMutex

id string
topic string
key string
peers []string
timeout time.Duration
ttl time.Duration
id string
topic string
key string
peers []string
timeout time.Duration
ttl time.Duration
keepalive time.Duration

consensus int

Expand Down Expand Up @@ -89,6 +91,7 @@ func (m *Mutex) Lock(ctx context.Context) error {
for _, c := range cluster {
a.Add(func(c *rpc.Client) func(ctx context.Context) (Lease, error) {
return func(ctx context.Context) (Lease, error) {
defer c.Close()
var t Lease
err := c.Call("dlm.NewLock", req, &t)
return t, err
Expand All @@ -104,7 +107,7 @@ func (m *Mutex) Lock(ctx context.Context) error {
return m.Error(errs, err)
}

t := result[0]
t := result[len(result)-1]

if t.IsExpired(start) {
return ErrExpiredLease
Expand All @@ -131,6 +134,7 @@ func (m *Mutex) Renew(ctx context.Context) error {
for _, c := range cluster {
a.Add(func(c *rpc.Client) func(ctx context.Context) (Lease, error) {
return func(ctx context.Context) (Lease, error) {
defer c.Close()
var t Lease
err := c.Call("dlm.RenewLock", req, &t)
if err != nil {
Expand All @@ -151,7 +155,7 @@ func (m *Mutex) Renew(ctx context.Context) error {
return m.Error(errs, err)
}

t := result[0]
t := result[len(result)-1]
if t.IsExpired(start) {
return ErrExpiredLease
}
Expand All @@ -175,6 +179,7 @@ func (m *Mutex) Unlock(ctx context.Context) error {
for _, c := range cluster {
a.Add(func(c *rpc.Client) func(ctx context.Context) error {
return func(ctx context.Context) error {
defer c.Close()
var t bool
err := c.Call("dlm.ReleaseLock", req, &t)
if err != nil {
Expand Down Expand Up @@ -214,6 +219,7 @@ func (m *Mutex) Freeze(ctx context.Context, topic string) error {
for _, c := range cluster {
a.Add(func(c *rpc.Client) func(ctx context.Context) error {
return func(ctx context.Context) error {
defer c.Close()
var ok bool
return c.Call("dlm.Freeze", topic, &ok)
}
Expand Down Expand Up @@ -248,6 +254,7 @@ func (m *Mutex) Reset(ctx context.Context, topic string) error {
for _, c := range cluster {
a.Add(func(c *rpc.Client) func(ctx context.Context) error {
return func(ctx context.Context) error {
defer c.Close()
var ok bool
return c.Call("dlm.Reset", topic, &ok)
}
Expand Down Expand Up @@ -275,18 +282,24 @@ func (m *Mutex) createRequest() LockRequest {
}

func (m *Mutex) waitExpires() {
delay := time.NewTimer(1 * time.Second)
defer delay.Stop()

var expiresOn time.Time

for {
m.mu.RLock()
expiresOn = m.lease.ExpiresOn
m.mu.RUnlock()

delay.Stop()
delay.Reset(time.Until(expiresOn))

select {

case <-m.Context.Done():
return
case <-time.After(time.Until(expiresOn)):
case <-delay.C:
// get latest ExpiresOn
m.mu.RLock()
expiresOn = m.lease.ExpiresOn
Expand All @@ -301,13 +314,15 @@ func (m *Mutex) waitExpires() {
}

func (m *Mutex) Keepalive() {
delay := time.NewTicker(m.keepalive)
defer delay.Stop()

var err error
for {
select {
case <-m.Context.Done():
return
case <-time.After(1 * time.Second):
case <-delay.C:
m.mu.RLock()
expiresOn := m.lease.ExpiresOn
m.mu.RUnlock()
Expand Down
6 changes: 6 additions & 0 deletions mutex_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,9 @@ func WithTimeout(d time.Duration) MutexOption {
m.timeout = d
}
}

func WithKeepAlive(d time.Duration) MutexOption {
return func(m *Mutex) {
m.keepalive = d
}
}
9 changes: 6 additions & 3 deletions mutext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,11 @@ func TestRenew(t *testing.T) {
{
name: "keepalive_should_work",
run: func(r *require.Assertions) {
ttl := 2 * time.Second
m := New("renew", "wallet", "renew_keepalive", WithPeers(peers...), WithTTL(ttl))
ttl := 3 * time.Second
m := New("renew", "wallet", "renew_keepalive",
WithPeers(peers...),
WithTTL(ttl),
WithKeepAlive(1*time.Second))
err := m.Lock(context.TODO())

r.NoError(err)
Expand All @@ -234,7 +237,7 @@ func TestRenew(t *testing.T) {

go m.Keepalive()

time.Sleep(ttl)
time.Sleep(ttl + 1*time.Second)

err = m.Renew(context.TODO())
r.NoError(err)
Expand Down
Loading