Skip to content

Commit

Permalink
feat: instrument args via functional args
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredallard committed Sep 26, 2020
1 parent a0f2fa9 commit 2fbf336
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 108 deletions.
196 changes: 146 additions & 50 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package client
import (
"fmt"
"io"
"io/ioutil"
"net"
"strings"
"time"

"github.com/google/uuid"
"github.com/omrikiei/ktunnel/pkg/common"
pb "github.com/omrikiei/ktunnel/tunnel_pb"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/net/context"
"google.golang.org/grpc"
Expand All @@ -21,135 +23,136 @@ type Message struct {
d *[]byte
}

func ReceiveData(ctx context.Context, st pb.Tunnel_InitTunnelClient, sessionsOut chan<- *common.Session, port int32, scheme string) {
func ReceiveData(ctx context.Context, conf *ClientConfig, st pb.Tunnel_InitTunnelClient, sessionsOut chan<- *common.Session, port int32, scheme string) {
loop:
for {
select {
case <-ctx.Done():
log.WithError(ctx.Err()).Infof("stopping to receive data on port %d", port)
conf.log.WithError(ctx.Err()).Infof("closing listener on %d", port)
_ = st.CloseSend()
break loop
default:
log.Debugf("attempting to receive from stream")
conf.log.Debugf("attempting to receive from stream")
m, err := st.Recv()
if err != nil {
log.WithError(err).Warnf("error reading from stream")
conf.log.WithError(err).Warnf("error reading from stream")
break loop
}

requestId, err := uuid.Parse(m.RequestId)
if err != nil {
log.WithError(err).WithField("session", m.RequestId).Errorf("failed parsing session uuid from stream, skipping")
conf.log.WithError(err).WithField("session", m.RequestId).Errorf("failed parsing session uuid from stream, skipping")
}

session, exists := common.GetSession(requestId)
if exists == false {
log.WithFields(log.Fields{
conf.log.WithFields(log.Fields{
"session": m.RequestId,
"port": port,
}).Infof("new connection")

// new session
conn, err := net.DialTimeout(strings.ToLower(scheme), fmt.Sprintf("localhost:%d", port), time.Millisecond*500)
if err != nil {
log.WithError(err).Errorf("failed connecting to localhost on port %d scheme %s", port, scheme)
conf.log.WithError(err).Errorf("failed connecting to localhost on port %d scheme %s", port, scheme)
// close the remote connection
resp := &pb.SocketDataRequest{
RequestId: requestId.String(),
ShouldClose: true,
}
err := st.Send(resp)
if err != nil {
log.WithError(err).Errorf("failed sending close message to tunnel stream")
conf.log.WithError(err).Errorf("failed sending close message to tunnel stream")
}

continue
} else {
session = common.NewSessionFromStream(requestId, conn)
go ReadFromSession(session, sessionsOut)
go ReadFromSession(conf, session, sessionsOut)
}
} else if m.ShouldClose {
session.Open = false
}

// process the data from the server
handleStreamData(m, session)
handleStreamData(conf, m, session)
}
}
}

func handleStreamData(m *pb.SocketDataResponse, session *common.Session) {
func handleStreamData(conf *ClientConfig, m *pb.SocketDataResponse, session *common.Session) {
if session.Open == false {
log.WithField("session", session.Id).Infof("closed session")
conf.log.WithField("session", session.Id).Infof("closed session")
session.Close()
return
}

data := m.GetData()
log.WithField("session", session.Id).Infof("received %d bytes from server", len(data))
conf.log.WithField("session", session.Id).Infof("received %d bytes from server", len(data))
if len(data) > 0 {
session.Lock.Lock()
log.WithField("session", session.Id).Infof("wrote %d bytes to conn", len(data))
conf.log.WithField("session", session.Id).Infof("wrote %d bytes to conn", len(data))
_, err := session.Conn.Write(data)
session.Lock.Unlock()
if err != nil {
log.WithError(err).WithField("session", session.Id).Errorf("failed writing to socket, closing session")
conf.log.WithError(err).WithField("session", session.Id).Errorf("failed writing to socket, closing session")
session.Close()
return
}
}
}

func ReadFromSession(session *common.Session, sessionsOut chan<- *common.Session) {
func ReadFromSession(conf *ClientConfig, session *common.Session, sessionsOut chan<- *common.Session) {
conn := session.Conn
log.WithField("session", session.Id).Debugf("started reading conn")
conf.log.WithField("session", session.Id).Debugf("started reading conn")

for {
buff := make([]byte, common.BufferSize)
br, err := conn.Read(buff)

if err != nil {
if err != io.EOF {
log.WithError(err).WithField("session", session.Id).Errorf("failed reading from socket")
break
}

conf.log.WithError(err).WithField("session", session.Id).Errorf("failed reading from socket")
session.Open = false
sessionsOut <- session
break
}

log.WithField("session", session.Id).WithError(err).Infof("read %d bytes from conn", br)
conf.log.WithField("session", session.Id).WithError(err).Debugf("read %d bytes from conn", br)

session.Lock.Lock()
if br > 0 {
log.WithField("session", session.Id).WithError(err).Infof("wrote %d bytes to session buf", br)
conf.log.WithField("session", session.Id).WithError(err).Debugf("wrote %d bytes to session buf", br)
_, err = session.Buf.Write(buff[0:br])
}
session.Lock.Unlock()

if err != nil {
log.WithField("session", session.Id).WithError(err).Errorf("failed writing to session buffer")
conf.log.WithField("session", session.Id).WithError(err).Errorf("failed writing to session buffer")
break
}
sessionsOut <- session
}
log.Debugf("finished reading from session %s", session.Id)
conf.log.WithField("session", session.Id).Debugf("finished reading session")
}

func SendData(ctx context.Context, stream pb.Tunnel_InitTunnelClient, sessions <-chan *common.Session) {
func SendData(ctx context.Context, conf *ClientConfig, stream pb.Tunnel_InitTunnelClient, sessions <-chan *common.Session) {
for {
select {
case <-ctx.Done():
return
case session := <-sessions:

// read the bytes from the buffer
// but allow it to keep growing while we send the response
session.Lock.Lock()
bys := session.Buf.Len()
bytes := make([]byte, bys)
session.Buf.Read(bytes)

log.WithField("session", session.Id).Infof("read %d from buffer out of %d available", len(bytes), bys)
conf.log.WithField("session", session.Id).Debugf("read %d from buffer out of %d available", len(bytes), bys)

resp := &pb.SocketDataRequest{
RequestId: session.Id.String(),
Expand All @@ -158,52 +161,60 @@ func SendData(ctx context.Context, stream pb.Tunnel_InitTunnelClient, sessions <
}
session.Lock.Unlock()

log.WithFields(log.Fields{
conf.log.WithFields(log.Fields{
"session": session.Id,
"close": resp.ShouldClose,
}).Infof("sending %d bytes to server", len(bytes))
}).Debugf("sending %d bytes to server", len(bytes))
err := stream.Send(resp)
if err != nil {
log.WithError(err).Errorf("failed sending message to tunnel stream, exiting")
conf.log.WithError(err).Errorf("failed sending message to tunnel stream, exiting")
return
}
log.WithFields(log.Fields{
conf.log.WithFields(log.Fields{
"session": session.Id,
"close": resp.ShouldClose,
}).Infof("sent %d bytes to server", len(bytes))
}).Debugf("sent %d bytes to server", len(bytes))
}
}
}

func RunClient(ctx context.Context, host *string, port *int, scheme string, tls *bool, caFile, serverHostOverride *string, tunnels []string) error {
var opts []grpc.DialOption
if *tls {
creds, err := credentials.NewClientTLSFromFile(*caFile, *serverHostOverride)
// RunClient creates a GRPC tunnel client
func RunClient(ctx context.Context, opts ...ClientOption) error {
conf, err := processArgs(opts)
if err != nil {
return errors.Wrap(err, "failed to parse arguments")
}

var grpcOpts []grpc.DialOption
if conf.TLS {
creds, err := credentials.NewClientTLSFromFile(conf.certFile, conf.tlsHostOverride)
if err != nil {
log.Fatalf("Failed to create TLS credentials %v", err)
return errors.Wrap(err, "failed to create TLS credentials")
}
opts = append(opts, grpc.WithTransportCredentials(creds))
grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(creds))
} else {
opts = append(opts, grpc.WithInsecure())
grpcOpts = append(grpcOpts, grpc.WithInsecure())
}

conn, err := grpc.Dial(fmt.Sprintf("%s:%d", *host, *port), opts...)
conn, err := grpc.Dial(fmt.Sprintf("%s:%d", conf.host, conf.port), grpcOpts...)
if err != nil {
log.Fatalf("fail to dial: %v", err)
return errors.Wrap(err, "failed to dial")
}
defer conn.Close()

client := pb.NewTunnelClient(conn)
for _, rawTunnelData := range tunnels {
for _, rawTunnelData := range conf.tunnels {
tunnelData, err := common.ParsePorts(rawTunnelData)
if err != nil {
log.Error(err)
conf.log.Error(err)
continue
}

go func() {
log.Println(fmt.Sprintf("starting %s tunnel from source %d to target %d", scheme, tunnelData.Source, tunnelData.Target))
tunnelScheme, ok := pb.TunnelScheme_value[scheme]
conf.log.Infof("starting %s tunnel from source %d to target %d", conf.scheme, tunnelData.Source, tunnelData.Target)
tunnelScheme, ok := pb.TunnelScheme_value[conf.scheme]
if ok != false {
log.Fatalf("unsupported connection scheme %s", scheme)
conf.log.Fatalf("unsupported connection scheme %s", conf.scheme)
}

req := &pb.SocketDataRequest{
Expand All @@ -214,23 +225,108 @@ func RunClient(ctx context.Context, host *string, port *int, scheme string, tls

stream, err := client.InitTunnel(ctx)
if err != nil {
log.Errorf("Error sending init tunnel request: %v", err)
conf.log.Errorf("Error sending init tunnel request: %v", err)
} else {
err := stream.Send(req)
if err != nil {
log.WithError(err).Errorf("Failed to send initial tunnel request to server")
conf.log.WithError(err).Errorf("Failed to send initial tunnel request to server")
return
}

sessions := make(chan *common.Session)
go ReceiveData(ctx, stream, sessions, tunnelData.Target, scheme)
go SendData(ctx, stream, sessions)
<-ctx.Done()
go ReceiveData(ctx, conf, stream, sessions, tunnelData.Target, conf.scheme)
go SendData(ctx, conf, stream, sessions)
}

}()
}

// wait for the context to be cancelled
<-ctx.Done()
return nil
}

// processArgs processes functional args
func processArgs(opts []ClientOption) (*ClientConfig, error) {
// default arguments
opt := &ClientConfig{
log: &log.Logger{
Out: ioutil.Discard,
},
scheme: "tcp",
TLS: false,
}

for _, f := range opts {
if err := f(opt); err != nil {
return nil, err
}
}

if len(opt.tunnels) == 0 {
return nil, fmt.Errorf("no tunnels given")
}

if opt.host == "" || opt.port == 0 {
return nil, fmt.Errorf("missing host configuration")
}

return opt, nil
}

// WithServer configures the server this client uses
func WithServer(host string, p int) ClientOption {
return func(opt *ClientConfig) error {
opt.host = host
opt.port = p
return nil
}
}

// WithTLS configures the tunnel to use TLS
// and sets the certificate expected, and a optional
// tls hostname override.
func WithTLS(cert, tlsHostOverride string) ClientOption {
return func(opt *ClientConfig) error {
opt.TLS = true
opt.certFile = cert
opt.tlsHostOverride = tlsHostOverride
return nil
}
}

// WithLogger sets the logger to be used by the server.
// if not set, output will be discarded
func WithLogger(l log.FieldLogger) ClientOption {
return func(opt *ClientConfig) error {
opt.log = l
return nil
}
}

// WithTunnels configures the tunnels to be exposed
// by this client. Each string should be in the format
// of: localPort:remotePort
func WithTunnels(scheme string, tunnels ...string) ClientOption {
return func(opt *ClientConfig) error {
opt.scheme = scheme
opt.tunnels = tunnels
return nil
}
}

// ClientOption is an option able to be configured
type ClientOption func(*ClientConfig) error

// ClientConfig is a config object used to
// configure a GRPC tunnel from the client side.
// ClientOption should be used to modify this
type ClientConfig struct {
host string
port int
TLS bool
certFile string
tlsHostOverride string
scheme string
log log.FieldLogger
tunnels []string
}
Loading

0 comments on commit 2fbf336

Please sign in to comment.