Skip to content

Commit

Permalink
client: add upstream.ListenAndForward
Browse files Browse the repository at this point in the history
Adds support for listening on an endpoint and forwarding connections to
an upstream address.
  • Loading branch information
andydunstall committed Jul 3, 2024
1 parent b61d005 commit 748997c
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 3 deletions.
6 changes: 5 additions & 1 deletion client/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,12 @@ func newListener(endpointID string, upstream *Upstream, logger Logger) *listener
}

func (l *listener) Accept() (net.Conn, error) {
return l.AcceptWithContext(context.Background())
}

func (l *listener) AcceptWithContext(ctx context.Context) (net.Conn, error) {
for {
conn, err := l.sess.Accept()
conn, err := l.sess.AcceptStreamWithContext(ctx)
if err == nil {
return conn, nil
}
Expand Down
63 changes: 61 additions & 2 deletions client/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/url"
"sync"
"time"

"github.com/hashicorp/yamux"
Expand Down Expand Up @@ -69,8 +72,22 @@ func (u *Upstream) Listen(ctx context.Context, endpointID string) (Listener, err
// forwards them to the given address.
//
// This will block until the context is canceled.
func (u *Upstream) ListenAndForward(_ context.Context, _ string, _ string) error {
// TODO(andydunstall)
func (u *Upstream) ListenAndForward(ctx context.Context, endpointID string, addr string) error {
ln := newListener(endpointID, u, u.logger())
if err := ln.connect(ctx); err != nil {
return fmt.Errorf("connect: %w", err)
}
defer ln.Close()

for {
conn, err := ln.AcceptWithContext(ctx)
if err != nil {
return fmt.Errorf("accept: %w", err)
}

go u.forwardConn(ctx, conn, addr)
}

return nil

Check failure on line 91 in client/upstream.go

View workflow job for this annotation

GitHub Actions / lint

unreachable: unreachable code (govet)
}

Expand Down Expand Up @@ -142,6 +159,48 @@ func (u *Upstream) connect(ctx context.Context, endpointID string) (*yamux.Sessi
}
}

func (u *Upstream) forwardConn(ctx context.Context, conn net.Conn, addr string) {
defer conn.Close()

dialer := &net.Dialer{}
upstream, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
u.logger().Error(
"failed to dial upstream",
zap.String("addr", addr),
zap.Error(err),
)
return
}

u.logger().Debug(
"dialed upstream",
zap.String("addr", addr),
)
defer func() {
u.logger().Debug(
"conn closed",
zap.String("addr", addr),
)
}()

g := &sync.WaitGroup{}
g.Add(2)
go func() {
defer g.Done()
defer conn.Close()
// nolint
io.Copy(conn, upstream)
}()
go func() {
defer g.Done()
defer upstream.Close()
// nolint
io.Copy(upstream, conn)
}()
g.Wait()
}

func (u *Upstream) listenURL(endpointID string) string {
var listenURL url.URL
if u.URL == nil {
Expand Down

0 comments on commit 748997c

Please sign in to comment.