Skip to content

Commit

Permalink
Merge pull request #204 from sfinke0/fix/goroutine-leak
Browse files Browse the repository at this point in the history
fix Goroutine leak due to stuck send on channel on timeouts
  • Loading branch information
carlmontanari authored Nov 11, 2024
2 parents f6dbfb5 + 238f734 commit 7ea22c0
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 53 deletions.
25 changes: 13 additions & 12 deletions channel/getprompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package channel

import (
"context"
"errors"
"fmt"
"time"

"github.com/scrapli/scrapligo/util"
)
Expand All @@ -15,7 +15,7 @@ func (c *Channel) GetPrompt() ([]byte, error) {

cr := make(chan *result)

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), c.TimeoutOps)

defer cancel()

Expand All @@ -39,18 +39,19 @@ func (c *Channel) GetPrompt() ([]byte, error) {
cr <- &result{b: c.PromptPattern.Find(b), err: err}
}()

timer := time.NewTimer(c.TimeoutOps)
r := <-cr
if r.err != nil {
if errors.Is(r.err, context.DeadlineExceeded) {
c.l.Critical("channel timeout fetching prompt")

select {
case r := <-cr:
if r.err != nil {
return nil, r.err
return nil, fmt.Errorf(
"%w: channel timeout fetching prompt",
util.ErrTimeoutError,
)
}

return r.b, nil
case <-timer.C:
c.l.Critical("channel timeout fetching prompt")

return nil, fmt.Errorf("%w: channel timeout fetching prompt", util.ErrTimeoutError)
return nil, r.err
}

return r.b, nil
}
4 changes: 2 additions & 2 deletions channel/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func (c *Channel) ReadUntilPrompt(ctx context.Context) ([]byte, error) {
for {
select {
case <-ctx.Done():
return nil, nil
return nil, ctx.Err()
default:
}

Expand Down Expand Up @@ -261,7 +261,7 @@ func (c *Channel) ReadUntilAnyPrompt(
for {
select {
case <-ctx.Done():
return nil, nil
return nil, ctx.Err()
default:
}

Expand Down
27 changes: 15 additions & 12 deletions channel/sendinput.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package channel

import (
"context"
"errors"
"fmt"
"regexp"
"time"

"github.com/scrapli/scrapligo/util"
)
Expand All @@ -26,7 +26,7 @@ func (c *Channel) SendInputB(input []byte, opts ...util.Option) ([]byte, error)

cr := make(chan *result)

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), c.GetTimeout(op.Timeout))

// we'll call cancel no matter what, either the read goroutines finished nicely in which case it
// doesnt matter, or we hit the timer and the cancel will stop the reading
Expand Down Expand Up @@ -72,6 +72,8 @@ func (c *Channel) SendInputB(input []byte, opts ...util.Option) ([]byte, error)

if readErr != nil {
cr <- &result{b: b, err: readErr}

return
}

b = append(b, nb...)
Expand All @@ -83,20 +85,21 @@ func (c *Channel) SendInputB(input []byte, opts ...util.Option) ([]byte, error)
}
}()

timer := time.NewTimer(c.GetTimeout(op.Timeout))
r := <-cr
if r.err != nil {
if errors.Is(r.err, context.DeadlineExceeded) {
c.l.Critical("channel timeout sending input to device")

select {
case r := <-cr:
if r.err != nil {
return nil, r.err
return nil, fmt.Errorf(
"%w: channel timeout sending input to device",
util.ErrTimeoutError,
)
}

return r.b, nil
case <-timer.C:
c.l.Critical("channel timeout sending input to device")

return nil, fmt.Errorf("%w: channel timeout sending input to device", util.ErrTimeoutError)
return nil, r.err
}

return r.b, nil
}

// SendInput sends the input string to the target device. Any bytes output is returned.
Expand Down
28 changes: 13 additions & 15 deletions channel/sendinteractive.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package channel

import (
"context"
"errors"
"fmt"
"regexp"
"time"

"github.com/scrapli/scrapligo/util"
)
Expand Down Expand Up @@ -118,27 +118,25 @@ func (c *Channel) SendInteractive(

cr := make(chan *result)

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), c.GetTimeout(op.Timeout))

defer cancel()

go c.sendInteractive(ctx, cr, events, op, readUntilF)

timer := time.NewTimer(c.GetTimeout(op.Timeout))
r := <-cr
if r.err != nil {
if errors.Is(r.err, context.DeadlineExceeded) {
c.l.Critical("channel timeout sending input to device")

select {
case r := <-cr:
if r.err != nil {
return nil, r.err
return nil, fmt.Errorf(
"%w: channel timeout sending input to device",
util.ErrTimeoutError,
)
}

return r.b, nil
case <-timer.C:
c.l.Critical("channel timeout sending interactive input to device")

return nil, fmt.Errorf(
"%w: channel timeout sending interactive input to device",
util.ErrTimeoutError,
)
return nil, r.err
}

return r.b, nil
}
28 changes: 16 additions & 12 deletions driver/netconf/capabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package netconf

import (
"context"
"errors"
"fmt"
"strconv"
"time"

"github.com/scrapli/scrapligo/util"
)
Expand Down Expand Up @@ -45,7 +45,10 @@ type result struct {
func (d *Driver) getServerCapabilities() ([]byte, error) {
cr := make(chan *result)

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithTimeout(
context.Background(),
d.Channel.GetTimeout(d.Channel.TimeoutOps),
)

defer cancel()

Expand All @@ -68,20 +71,21 @@ func (d *Driver) getServerCapabilities() ([]byte, error) {
}
}()

timer := time.NewTimer(d.Channel.GetTimeout(d.Channel.TimeoutOps))
r := <-cr
if r.err != nil {
if errors.Is(r.err, context.DeadlineExceeded) {
d.Logger.Critical("channel timeout reading capabilities")

select {
case r := <-cr:
if r.err != nil {
return nil, r.err
return nil, fmt.Errorf(
"%w: channel timeout reading capabilities",
util.ErrTimeoutError,
)
}

return r.b, nil
case <-timer.C:
d.Logger.Critical("channel timeout reading capabilities")

return nil, fmt.Errorf("%w: channel timeout reading capabilities", util.ErrTimeoutError)
return nil, r.err
}

return r.b, nil
}

func (d *Driver) processServerCapabilities() error {
Expand Down

0 comments on commit 7ea22c0

Please sign in to comment.