diff --git a/cmd/client.go b/cmd/client.go index f7b3fa8..d8db5a0 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -1,12 +1,14 @@ package cmd import ( - "github.com/omrikiei/ktunnel/pkg/client" + "context" "os" "os/signal" "sync" "syscall" + "github.com/omrikiei/ktunnel/pkg/client" + log "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) @@ -26,24 +28,24 @@ var clientCmd = &cobra.Command{ ktunnel client --host ktunnel-server.yourcompany.com -s tcp 8000 8001:8432 `, Run: func(cmd *cobra.Command, args []string) { + ctx, cancel := context.WithCancel(context.Background()) if Verbose { log.SetLevel(log.DebugLevel) } o := sync.Once{} - closeChan := make(chan bool, 1) // Run tunnel client and establish connection sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGQUIT) - go func() { o.Do(func() { _ = <-sigs log.Info("Got exit signal, closing client tunnels") - close(closeChan) + cancel() }) }() - err := client.RunClient(&Host, &Port, Scheme, &Tls, &CaFile, &ServerHostOverride, args, closeChan) + + err := client.RunClient(ctx, &Host, &Port, Scheme, &Tls, &CaFile, &ServerHostOverride, args) if err != nil { log.Fatalf("Failed to run client: %v", err) } diff --git a/cmd/expose.go b/cmd/expose.go index 1b18e6a..fe1db70 100644 --- a/cmd/expose.go +++ b/cmd/expose.go @@ -1,15 +1,17 @@ package cmd import ( - "github.com/omrikiei/ktunnel/pkg/client" - "github.com/omrikiei/ktunnel/pkg/k8s" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" + "context" "os" "os/signal" "strconv" "sync" "syscall" + + "github.com/omrikiei/ktunnel/pkg/client" + "github.com/omrikiei/ktunnel/pkg/k8s" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" ) var exposeCmd = &cobra.Command{ @@ -26,11 +28,13 @@ ktunnel expose kewlapp 80:8000 ktunnel expose redis 6379 `, Run: func(cmd *cobra.Command, args []string) { + ctx, cancel := context.WithCancel(context.Background()) if Verbose { log.SetLevel(log.DebugLevel) k8s.Verbose = true } o := sync.Once{} + // Create service and deployment svcName, ports := args[0], args[1:] readyChan := make(chan bool, 1) @@ -41,7 +45,6 @@ ktunnel expose redis 6379 sigs := make(chan os.Signal, 1) wg := &sync.WaitGroup{} done := make(chan bool, 1) - closeChan := make(chan bool, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGQUIT) // Teardown @@ -49,7 +52,7 @@ ktunnel expose redis 6379 o.Do(func() { _ = <-sigs log.Info("Got exit signal, closing client tunnels and removing k8s objects") - close(closeChan) + cancel() err := k8s.TeardownExposedService(Namespace, svcName) if err != nil { log.Errorf("Failed deleting k8s objects: %s", err) @@ -79,7 +82,7 @@ ktunnel expose redis 6379 log.Fatalf("Failed to run client: %v", err) } prt := int(p) - err = client.RunClient(&Host, &prt, Scheme, &Tls, &CaFile, &ServerHostOverride, args[1:], closeChan) + err = client.RunClient(ctx, &Host, &prt, Scheme, &Tls, &CaFile, &ServerHostOverride, args[1:]) if err != nil { log.Fatalf("Failed to run client: %v", err) } diff --git a/cmd/inject.go b/cmd/inject.go index 09eef9f..15f00b2 100644 --- a/cmd/inject.go +++ b/cmd/inject.go @@ -1,15 +1,17 @@ package cmd import ( - "github.com/omrikiei/ktunnel/pkg/client" - "github.com/omrikiei/ktunnel/pkg/k8s" - log "github.com/sirupsen/logrus" - "github.com/spf13/cobra" + "context" "os" "os/signal" "strconv" "sync" "syscall" + + "github.com/omrikiei/ktunnel/pkg/client" + "github.com/omrikiei/ktunnel/pkg/k8s" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" ) var Namespace string @@ -30,6 +32,7 @@ var injectDeploymentCmd = &cobra.Command{ ktunnel inject deploymeny mydeployment 3306 6379 `, Run: func(cmd *cobra.Command, args []string) { + ctx, cancel := context.WithCancel(context.Background()) if Verbose { log.SetLevel(log.DebugLevel) k8s.Verbose = true @@ -38,7 +41,6 @@ ktunnel inject deploymeny mydeployment 3306 6379 // Inject deployment := args[0] readyChan := make(chan bool, 1) - closeChan := make(chan bool, 1) _, err := k8s.InjectSidecar(&Namespace, &deployment, &Port, readyChan) if err != nil { log.Fatalf("failed injecting sidecar: %v", err) @@ -46,7 +48,6 @@ ktunnel inject deploymeny mydeployment 3306 6379 // Signal hook to remove sidecar sigs := make(chan os.Signal, 1) - done := make(chan bool, 1) wg := &sync.WaitGroup{} signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGQUIT) stopChan := make(chan struct{}, 1) @@ -55,8 +56,7 @@ ktunnel inject deploymeny mydeployment 3306 6379 o.Do(func() { <-sigs log.Info("Stopping streams") - close(closeChan) - close(stopChan) + cancel() wg.Wait() readyChan = make(chan bool, 1) ok, err := k8s.RemoveSidecar(&Namespace, &deployment, readyChan) @@ -65,7 +65,6 @@ ktunnel inject deploymeny mydeployment 3306 6379 } <-readyChan log.Info("Finished, exiting") - close(done) }) }() @@ -89,13 +88,12 @@ ktunnel inject deploymeny mydeployment 3306 6379 log.Fatalf("Failed to run client: %v", err) } prt := int(p) - err = client.RunClient(&Host, &prt, Scheme, &Tls, &CaFile, &ServerHostOverride, args[1:], closeChan) + err = client.RunClient(ctx, &Host, &prt, Scheme, &Tls, &CaFile, &ServerHostOverride, args[1:]) if err != nil { log.Fatalf("Failed to run client: %v", err) } }() } - <-done }, } diff --git a/cmd/server.go b/cmd/server.go index 0e1faa8..dda6f2d 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -1,6 +1,12 @@ package cmd import ( + "context" + "os" + "os/signal" + "sync" + "syscall" + "github.com/omrikiei/ktunnel/pkg/server" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -18,10 +24,23 @@ var serverCmd = &cobra.Command{ ktunnel server -p 8181 `, Run: func(cmd *cobra.Command, args []string) { + ctx, cancel := context.WithCancel(context.Background()) if Verbose { log.SetLevel(log.DebugLevel) } - err := server.RunServer(&Port, &Tls, &KeyFile, &CertFile) + o := sync.Once{} + // Run tunnel client and establish connection + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGQUIT) + go func() { + o.Do(func() { + _ = <-sigs + log.Info("Got exit signal, closing client tunnels") + cancel() + }) + }() + err := server.RunServer(ctx, &Port, &Tls, &KeyFile, &CertFile) if err != nil { log.Fatalf("Error running server: %v", err) } diff --git a/pkg/client/client.go b/pkg/client/client.go index ec2085a..6715fc4 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -2,6 +2,12 @@ package client import ( "fmt" + "io" + "net" + "strings" + "sync" + "time" + "github.com/google/uuid" "github.com/omrikiei/ktunnel/pkg/common" pb "github.com/omrikiei/ktunnel/tunnel_pb" @@ -9,11 +15,6 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/credentials" - "io" - "net" - "strings" - "sync" - "time" ) const ( @@ -25,7 +26,7 @@ type Message struct { d *[]byte } -func ReceiveData(st *pb.Tunnel_InitTunnelClient, closeStream <-chan bool, sessionsOut chan<- *common.Session, port int32, scheme string) { +func ReceiveData(st *pb.Tunnel_InitTunnelClient, closeStream <-chan struct{}, sessionsOut chan<- *common.Session, port int32, scheme string) { stream := *st loop: for { @@ -132,7 +133,7 @@ func ReadFromSession(session *common.Session, sessionsOut chan<- *common.Session log.Debugf("finished reading from session %s", session.Id) } -func SendData(stream *pb.Tunnel_InitTunnelClient, sessions <-chan *common.Session, closeChan <-chan bool) { +func SendData(stream *pb.Tunnel_InitTunnelClient, sessions <-chan *common.Session, closeChan <-chan struct{}) { for { select { case <-closeChan: @@ -160,11 +161,11 @@ func SendData(stream *pb.Tunnel_InitTunnelClient, sessions <-chan *common.Sessio } } -func RunClient(host *string, port *int, scheme string, tls *bool, caFile, serverHostOverride *string, tunnels []string, stopChan <-chan bool) error { +func RunClient(ctx context.Context, host *string, port *int, scheme string, tls *bool, caFile, serverHostOverride *string, tunnels []string) error { wg := sync.WaitGroup{} - closeStreams := make([]chan bool, len(tunnels)) + closeStreams := make([]chan struct{}, len(tunnels)) go func() { - <-stopChan + <-ctx.Done() for _, c := range closeStreams { close(c) } @@ -193,8 +194,8 @@ func RunClient(host *string, port *int, scheme string, tls *bool, caFile, server log.Error(err) } wg.Add(1) - c := make(chan bool, 1) - go func(closeStream chan bool) { + c := make(chan struct{}, 1) + go func(closeStream chan struct{}) { log.Println(fmt.Sprintf("starting %s tunnel from source %d to target %d", scheme, tunnelData.Source, tunnelData.Target)) ctx := context.Background() tunnelScheme, ok := pb.TunnelScheme_value[scheme] diff --git a/pkg/server/server.go b/pkg/server/server.go index 74071c1..68efaab 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1,16 +1,18 @@ package server import ( + "context" "errors" "fmt" + "net" + "strings" + "github.com/google/uuid" "github.com/omrikiei/ktunnel/pkg/common" pb "github.com/omrikiei/ktunnel/tunnel_pb" log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/credentials" - "net" - "strings" ) type tunnelServer struct{} @@ -177,12 +179,19 @@ func (t *tunnelServer) InitTunnel(stream pb.Tunnel_InitTunnelServer) error { } } -func RunServer(port *int, tls *bool, keyFile, certFile *string) error { +func RunServer(ctx context.Context, port *int, tls *bool, keyFile, certFile *string) error { log.Infof("Starting to listen on port %d", *port) lis, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", *port)) if err != nil { log.Fatalf("failed to listen: %v", err) } + + // handle context cancellation, shut down the server + go func() { + <-ctx.Done() + lis.Close() + }() + var opts []grpc.ServerOption if *tls { creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile) @@ -191,8 +200,8 @@ func RunServer(port *int, tls *bool, keyFile, certFile *string) error { } opts = []grpc.ServerOption{grpc.Creds(creds)} } + grpcServer := grpc.NewServer(opts...) pb.RegisterTunnelServer(grpcServer, NewServer()) - err = grpcServer.Serve(lis) - return err + return grpcServer.Serve(lis) }