diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..af5bd75 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +vendor/* +ecs-ssh-linux-amd64 diff --git a/Gopkg.lock b/Gopkg.lock new file mode 100644 index 0000000..a6ccec0 --- /dev/null +++ b/Gopkg.lock @@ -0,0 +1,162 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. + + +[[projects]] + digest = "1:63dcec0ce79283851834f681503a2cc1d50e1b1b0c5750074882445a3082ca1c" + name = "github.com/Azure/go-ansiterm" + packages = [ + ".", + "winterm", + ] + pruneopts = "UT" + revision = "388960b655244e76e24c75f48631564eaefade62" + +[[projects]] + digest = "1:f3b6664d93ffaee13d93f1dcab4f30ce64204c4b24dda73695ecc8e686645d1b" + name = "github.com/Sirupsen/logrus" + packages = ["."] + pruneopts = "UT" + revision = "d26492970760ca5d33129d2d799e34be5c4782eb" + version = "v0.11.0" + +[[projects]] + digest = "1:002cca47487ea24278047c0211f201eb3ebac485859ae24e35e7c5a32a5c874b" + name = "github.com/aws/aws-sdk-go" + packages = [ + "aws", + "aws/awserr", + "aws/awsutil", + "aws/client", + "aws/client/metadata", + "aws/corehandlers", + "aws/credentials", + "aws/credentials/ec2rolecreds", + "aws/credentials/endpointcreds", + "aws/credentials/stscreds", + "aws/csm", + "aws/defaults", + "aws/ec2metadata", + "aws/endpoints", + "aws/request", + "aws/session", + "aws/signer/v4", + "internal/sdkio", + "internal/sdkrand", + "internal/sdkuri", + "internal/shareddefaults", + "private/protocol", + "private/protocol/ec2query", + "private/protocol/json/jsonutil", + "private/protocol/jsonrpc", + "private/protocol/query", + "private/protocol/query/queryutil", + "private/protocol/rest", + "private/protocol/xml/xmlutil", + "service/ec2", + "service/ecs", + "service/sts", + ] + pruneopts = "UT" + revision = "66974140c322f22c1daaf95a18930ea6a9e4d21e" + version = "v1.15.16" + +[[projects]] + digest = "1:53e99d883df3e940f5f0223795f300eb32b8c044f226132bfc0e74930f24ea4b" + name = "github.com/docker/docker" + packages = [ + "pkg/term", + "pkg/term/windows", + ] + pruneopts = "UT" + revision = "092cba3727bb9b4a2f0e922cd6c0f93ea270e363" + version = "v1.13.1" + +[[projects]] + digest = "1:fe8a03a8222d5b913f256972933d26d24ad7c8286692a42943bc01633cc8fce3" + name = "github.com/go-ini/ini" + packages = ["."] + pruneopts = "UT" + revision = "358ee7663966325963d4e8b2e1fbd570c5195153" + version = "v1.38.1" + +[[projects]] + digest = "1:e22af8c7518e1eab6f2eab2b7d7558927f816262586cd6ed9f349c97a6c285c4" + name = "github.com/jmespath/go-jmespath" + packages = ["."] + pruneopts = "UT" + revision = "0b12d6b5" + +[[projects]] + branch = "master" + digest = "1:114ecad51af93a73ae6781fd0d0bc28e52b433c852b84ab4b4c109c15e6c6b6d" + name = "github.com/jroimartin/gocui" + packages = ["."] + pruneopts = "UT" + revision = "c055c87ae801372cd74a0839b972db4f7697ae5f" + +[[projects]] + branch = "master" + digest = "1:a330103bc9731260ee9fa14764e9e3fce46e02de19d6aca3eeba1d425badfbf0" + name = "github.com/juju/loggo" + packages = ["."] + pruneopts = "UT" + revision = "584905176618da46b895b176c721b02c476b6993" + +[[projects]] + digest = "1:cdb899c199f907ac9fb50495ec71212c95cb5b0e0a8ee0800da0238036091033" + name = "github.com/mattn/go-runewidth" + packages = ["."] + pruneopts = "UT" + revision = "ce7b0b5c7b45a81508558cd1dba6bb1e4ddb51bb" + version = "v0.0.3" + +[[projects]] + branch = "master" + digest = "1:f335d800550786b6f51ddaedb9d1107a7a72f4a2195e5b039dd7c0e103e119bc" + name = "github.com/nsf/termbox-go" + packages = ["."] + pruneopts = "UT" + revision = "b66b20ab708e289ff1eb3e218478302e6aec28ce" + +[[projects]] + branch = "master" + digest = "1:8e56f01a7d273a0c67edb531e176e93b9f27cbe567bc84213b9b5c03bcb57b78" + name = "golang.org/x/crypto" + packages = [ + "curve25519", + "ed25519", + "ed25519/internal/edwards25519", + "internal/chacha20", + "internal/subtle", + "poly1305", + "ssh", + "ssh/agent", + ] + pruneopts = "UT" + revision = "614d502a4dac94afa3a6ce146bd1736da82514c6" + +[[projects]] + branch = "master" + digest = "1:a3f00ac457c955fe86a41e1495e8f4c54cb5399d609374c5cc26aa7d72e542c8" + name = "golang.org/x/sys" + packages = ["unix"] + pruneopts = "UT" + revision = "3b58ed4ad3395d483fc92d5d14123ce2c3581fec" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + input-imports = [ + "github.com/aws/aws-sdk-go/aws", + "github.com/aws/aws-sdk-go/aws/awserr", + "github.com/aws/aws-sdk-go/aws/session", + "github.com/aws/aws-sdk-go/service/ec2", + "github.com/aws/aws-sdk-go/service/ecs", + "github.com/docker/docker/pkg/term", + "github.com/jroimartin/gocui", + "github.com/juju/loggo", + "golang.org/x/crypto/ssh", + "golang.org/x/crypto/ssh/agent", + ] + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml new file mode 100644 index 0000000..159d772 --- /dev/null +++ b/Gopkg.toml @@ -0,0 +1,50 @@ +# Gopkg.toml example +# +# Refer to https://golang.github.io/dep/docs/Gopkg.toml.html +# for detailed Gopkg.toml documentation. +# +# required = ["github.com/user/thing/cmd/thing"] +# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] +# +# [[constraint]] +# name = "github.com/user/project" +# version = "1.0.0" +# +# [[constraint]] +# name = "github.com/user/project2" +# branch = "dev" +# source = "github.com/myfork/project2" +# +# [[override]] +# name = "github.com/x/y" +# version = "2.4.0" +# +# [prune] +# non-go = false +# go-tests = true +# unused-packages = true + + +[[constraint]] + name = "github.com/aws/aws-sdk-go" + version = "1.15.16" + +[[constraint]] + name = "github.com/docker/docker" + version = "1.13.1" + +[[constraint]] + name = "github.com/jroimartin/gocui" + branch = "master" + +[[constraint]] + branch = "master" + name = "github.com/juju/loggo" + +[[constraint]] + branch = "master" + name = "golang.org/x/crypto" + +[prune] + go-tests = true + unused-packages = true diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..e7856c7 --- /dev/null +++ b/Makefile @@ -0,0 +1,11 @@ +BINARY = ecs-ssh +GOARCH = amd64 + + +all: build + +build: + GOOS=linux GOARCH=${GOARCH} go build ${LDFLAGS} -o ${BINARY}-linux-${GOARCH} . ; + +clean: + rm -f ${BINARY}-linux-${GOARCH} diff --git a/README.md b/README.md new file mode 100644 index 0000000..7aa8a99 --- /dev/null +++ b/README.md @@ -0,0 +1,12 @@ +# ecs-ssh +A shell frontend to ssh into ECS instances. Will display ECS cluster, services and tasks, determine ssh ip, and let you ssh into the instance + +## build +``` +make +``` +## install +``` +cp ecs-ssh-* /bin/ecs-ssh +``` + diff --git a/ecs.go b/ecs.go new file mode 100644 index 0000000..1b8c2e5 --- /dev/null +++ b/ecs.go @@ -0,0 +1,300 @@ +package main + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ecs" + "github.com/juju/loggo" + + "errors" + "math" + "os" + "strings" +) + +// logging +var ecsLogger = loggo.GetLogger("ecs") + +// ECS struct +type ECS struct { + clusterArns []*string + serviceArns []*string + taskArns []string + clusterNames []string + serviceNames []string + taskNames []string + containerInstances map[string]string + selectedClusterName string + selectedServiceName string + ipAddr *string + keyName *string + svc *ecs.ECS +} + +func newECS() *ECS { + e := ECS{} + // set default region if no region is set + if os.Getenv("AWS_REGION") == "" { + os.Setenv("AWS_REGION", "us-east-1") + } + e.svc = ecs.New(session.New()) + e.listCluster() + e.getClusterNames() + return &e +} + +// Creates ECS repository +func (e *ECS) listCluster() error { + input := &ecs.ListClustersInput{} + + pageNum := 0 + err := e.svc.ListClustersPages(input, + func(page *ecs.ListClustersOutput, lastPage bool) bool { + pageNum++ + e.clusterArns = append(e.clusterArns, page.ClusterArns...) + return pageNum <= 20 + }) + + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + ecsLogger.Errorf(aerr.Error()) + } else { + ecsLogger.Errorf(err.Error()) + } + return err + } + return nil +} + +func (e *ECS) getClusterNames() error { + input := &ecs.DescribeClustersInput{ + Clusters: e.clusterArns, + } + + result, err := e.svc.DescribeClusters(input) + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + ecsLogger.Errorf(aerr.Error()) + } else { + ecsLogger.Errorf(err.Error()) + } + return err + } + for _, cluster := range result.Clusters { + e.clusterNames = append(e.clusterNames, *cluster.ClusterName) + } + return nil +} +func (e *ECS) listServiceArns(clusterName string) error { + e.serviceArns = []*string{} + input := &ecs.ListServicesInput{ + Cluster: aws.String(clusterName), + } + + pageNum := 0 + err := e.svc.ListServicesPages(input, + func(page *ecs.ListServicesOutput, lastPage bool) bool { + pageNum++ + e.serviceArns = append(e.serviceArns, page.ServiceArns...) + return pageNum <= 20 + }) + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + ecsLogger.Errorf(aerr.Error()) + } else { + ecsLogger.Errorf(err.Error()) + } + return err + } + return nil +} + +// gets service Arns and returns service names +func (e *ECS) getServices(clusterName string) ([]string, error) { + e.selectedClusterName = clusterName + err := e.listServiceArns(clusterName) + if err != nil { + return []string{}, err + } + + // fetch per 10 + var y float64 = float64(len(e.serviceArns)) / 10 + + for i := 0; i < int(math.Ceil(y)); i++ { + + f := i * 10 + t := int(math.Min(float64(10+10*i), float64(len(e.serviceArns)))) + + input := &ecs.DescribeServicesInput{ + Cluster: aws.String(clusterName), + Services: e.serviceArns[f:t], + } + + result, err := e.svc.DescribeServices(input) + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + ecsLogger.Errorf(aerr.Error()) + } else { + ecsLogger.Errorf(err.Error()) + } + return []string{}, err + } + for _, service := range result.Services { + e.serviceNames = append(e.serviceNames, *service.ServiceName) + } + } + return e.serviceNames, nil +} + +// lists task arns +func (e *ECS) listTaskArns(serviceName string) error { + e.taskArns = []string{} + input := &ecs.ListTasksInput{ + Cluster: aws.String(e.selectedClusterName), + ServiceName: aws.String(serviceName), + } + + pageNum := 0 + err := e.svc.ListTasksPages(input, + func(page *ecs.ListTasksOutput, lastPage bool) bool { + pageNum++ + e.taskArns = append(e.taskArns, aws.StringValueSlice(page.TaskArns)...) + return pageNum <= 20 + }) + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + ecsLogger.Errorf(aerr.Error()) + } else { + ecsLogger.Errorf(err.Error()) + } + return err + } + return nil +} + +// list all task ARNs for a cluster +func (e *ECS) listAllTaskArns() error { + e.taskArns = []string{} + input := &ecs.ListTasksInput{ + Cluster: aws.String(e.selectedClusterName), + } + + pageNum := 0 + err := e.svc.ListTasksPages(input, + func(page *ecs.ListTasksOutput, lastPage bool) bool { + pageNum++ + e.taskArns = append(e.taskArns, aws.StringValueSlice(page.TaskArns)...) + return pageNum <= 20 + }) + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + ecsLogger.Errorf(aerr.Error()) + } else { + ecsLogger.Errorf(err.Error()) + } + return err + } + return nil +} +func (e *ECS) getTasks(serviceName string) ([]string, error) { + e.selectedServiceName = serviceName + e.containerInstances = make(map[string]string) + err := e.listTaskArns(serviceName) + if err != nil { + return []string{}, err + } + return e.outputTaskNames() +} +func (e *ECS) getAllTasks() ([]string, error) { + e.containerInstances = make(map[string]string) + err := e.listAllTaskArns() + if err != nil { + return []string{}, err + } + return e.outputTaskNames() +} +func (e *ECS) outputTaskNames() ([]string, error) { + // fetch per 100 + var y float64 = float64(len(e.taskArns)) / 100 + + for i := 0; i < int(math.Ceil(y)); i++ { + + f := i * 100 + t := int(math.Min(float64(100+100*i), float64(len(e.taskArns)))) + + input := &ecs.DescribeTasksInput{ + Cluster: aws.String(e.selectedClusterName), + Tasks: aws.StringSlice(e.taskArns[f:t]), + } + + result, err := e.svc.DescribeTasks(input) + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + ecsLogger.Errorf(aerr.Error()) + } else { + ecsLogger.Errorf(err.Error()) + } + return []string{}, err + } + for _, task := range result.Tasks { + s := strings.Split(*task.TaskArn, "/") + if len(s) > 1 { + e.taskNames = append(e.taskNames, s[1]) + e.containerInstances[s[1]] = aws.StringValue(task.ContainerInstanceArn) + } + } + } + return e.taskNames, nil +} + +func (e *ECS) getContainerInstanceIP(taskArn string) (*string, error) { + + if e.containerInstances[taskArn] == "" { + return nil, errors.New("Task has no container instance assigned (task might not be running)") + } + + input := &ecs.DescribeContainerInstancesInput{ + Cluster: aws.String(e.selectedClusterName), + ContainerInstances: []*string{aws.String(e.containerInstances[taskArn])}, + } + + result, err := e.svc.DescribeContainerInstances(input) + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + ecsLogger.Errorf(aerr.Error()) + } else { + ecsLogger.Errorf(err.Error()) + } + return nil, err + } + if len(result.ContainerInstances) == 0 { + return nil, errors.New("Container instance not found") + } + + inputInstances := &ec2.DescribeInstancesInput{ + InstanceIds: []*string{result.ContainerInstances[0].Ec2InstanceId}, + } + + svcEc2 := ec2.New(session.New()) + resultInstances, err := svcEc2.DescribeInstances(inputInstances) + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + ecsLogger.Errorf(aerr.Error()) + } else { + ecsLogger.Errorf(err.Error()) + } + return nil, err + } + if len(resultInstances.Reservations) == 0 { + return nil, errors.New("EC2 instance not found") + } + if len(resultInstances.Reservations[0].Instances) == 0 { + return nil, errors.New("EC2 instance not found") + } + e.ipAddr = resultInstances.Reservations[0].Instances[0].PrivateIpAddress + e.keyName = resultInstances.Reservations[0].Instances[0].KeyName + + return e.ipAddr, nil +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..a836544 --- /dev/null +++ b/main.go @@ -0,0 +1,299 @@ +package main + +import ( + "fmt" + "log" + "os" + + "github.com/jroimartin/gocui" +) + +var e = newECS() + +func previousView(g *gocui.Gui, v *gocui.View) error { + if v == nil || v.Name() == "services" { + v.Clear() + v.Highlight = false + v2, err := g.SetCurrentView("clusters") + v2.Highlight = true + return err + } + if v == nil || v.Name() == "tasks" { + v.Clear() + v.Highlight = false + v2, err := g.SetCurrentView("services") + v2.Highlight = true + return err + } + return nil +} + +func cursorDown(g *gocui.Gui, v *gocui.View) error { + if v != nil { + cx, cy := v.Cursor() + if err := v.SetCursor(cx, cy+1); err != nil { + ox, oy := v.Origin() + if err := v.SetOrigin(ox, oy+1); err != nil { + return err + } + } + } + return nil +} + +func cursorUp(g *gocui.Gui, v *gocui.View) error { + if v != nil { + ox, oy := v.Origin() + cx, cy := v.Cursor() + if err := v.SetCursor(cx, cy-1); err != nil && oy > 0 { + if err := v.SetOrigin(ox, oy-1); err != nil { + return err + } + } + } + return nil +} + +func showError(g *gocui.Gui, v *gocui.View, errToDisplay error) error { + g.Update(func(g2 *gocui.Gui) error { + maxX, maxY := g2.Size() + if v, err := g2.SetView("error", maxX/2-30, maxY/2, maxX/2+30, maxY/2+2); err != nil { + if err != gocui.ErrUnknownView { + return err + } + if _, err := g2.SetCurrentView("error"); err != nil { + return err + } + fmt.Fprintln(v, errToDisplay.Error()) + } + return nil + }) + return nil +} +func hideError(g *gocui.Gui, v *gocui.View) error { + err := g.DeleteView("error") + v, _ = g.SetCurrentView("clusters") + v.Highlight = true + return err +} +func getServices(g *gocui.Gui, v *gocui.View) error { + var err error + var clusterName string + + _, cy := v.Cursor() + if clusterName, err = v.Line(cy); err != nil { + return showError(g, v, err) + } + v.Highlight = false + + v2, err := g.SetCurrentView("services") + fmt.Fprintln(v2, "Loading...") + g.Update(func(g *gocui.Gui) error { + v2, _ := g.View("services") + services, err := e.getServices(clusterName) + if err != nil { + return showError(g, v, err) + } + v2.Clear() + v2.Highlight = true + + if len(services) == 0 { + fmt.Fprintln(v2, errNoServiceFound()) + err = getTasks(g, v2) + if err != nil { + return showError(g, v2, err) + } + } else { + for _, s := range services { + fmt.Fprintln(v2, s) + } + } + return nil + }) + + return err +} +func errNoServiceFound() string { + return "No Services Found" +} +func getTasks(g *gocui.Gui, v *gocui.View) error { + var err error + var serviceName string + + _, cy := v.Cursor() + if serviceName, err = v.Line(cy); err != nil { + return showError(g, v, err) + } + v.Highlight = false + + v2, err := g.SetCurrentView("tasks") + fmt.Fprintln(v2, "Loading...") + g.Update(func(g *gocui.Gui) error { + v2, _ := g.View("tasks") + var containers []string + var err error + if serviceName == errNoServiceFound() { + containers, err = e.getAllTasks() + } else { + containers, err = e.getTasks(serviceName) + } + if err != nil { + return showError(g, v2, err) + } + + v2.Clear() + for _, c := range containers { + fmt.Fprintln(v2, c) + } + + v2.Highlight = true + + return nil + }) + + return err +} + +func doSSH(g *gocui.Gui, v *gocui.View) error { + var err error + var taskName string + _, cy := v.Cursor() + if taskName, err = v.Line(cy); err != nil { + return showError(g, v, err) + } + _, err = e.getContainerInstanceIP(taskName) + if err != nil { + return showError(g, v, err) + } + + // exit and start ssh + g.Close() + + return gocui.ErrQuit +} + +func quit(g *gocui.Gui, v *gocui.View) error { + return gocui.ErrQuit +} + +func keybindings(g *gocui.Gui) error { + if err := g.SetKeybinding("", gocui.KeyCtrlC, gocui.ModNone, quit); err != nil { + return err + } + if err := g.SetKeybinding("clusters", gocui.KeyArrowDown, gocui.ModNone, cursorDown); err != nil { + return err + } + if err := g.SetKeybinding("clusters", gocui.KeyArrowUp, gocui.ModNone, cursorUp); err != nil { + return err + } + if err := g.SetKeybinding("clusters", gocui.KeyEnter, gocui.ModNone, getServices); err != nil { + return err + } + if err := g.SetKeybinding("services", gocui.KeyArrowLeft, gocui.ModNone, previousView); err != nil { + return err + } + if err := g.SetKeybinding("services", gocui.KeyArrowDown, gocui.ModNone, cursorDown); err != nil { + return err + } + if err := g.SetKeybinding("services", gocui.KeyArrowUp, gocui.ModNone, cursorUp); err != nil { + return err + } + if err := g.SetKeybinding("services", gocui.KeyEnter, gocui.ModNone, getTasks); err != nil { + return err + } + if err := g.SetKeybinding("tasks", gocui.KeyEnter, gocui.ModNone, doSSH); err != nil { + return err + } + if err := g.SetKeybinding("tasks", gocui.KeyArrowLeft, gocui.ModNone, previousView); err != nil { + return err + } + if err := g.SetKeybinding("tasks", gocui.KeyArrowDown, gocui.ModNone, cursorDown); err != nil { + return err + } + if err := g.SetKeybinding("tasks", gocui.KeyArrowUp, gocui.ModNone, cursorUp); err != nil { + return err + } + if err := g.SetKeybinding("error", gocui.KeyEsc, gocui.ModNone, hideError); err != nil { + return err + } + if err := g.SetKeybinding("error", gocui.KeyEnter, gocui.ModNone, hideError); err != nil { + return err + } + + return nil +} + +func layout(g *gocui.Gui) error { + maxX, maxY := g.Size() + if v, err := g.SetView("clusters", 0, 0, maxX/3, maxY); err != nil { + if err != gocui.ErrUnknownView { + return err + } + v.Highlight = true + v.Frame = true + v.Title = "Clusters" + v.SelBgColor = gocui.ColorGreen + v.SelFgColor = gocui.ColorBlack + + for _, c := range e.clusterNames { + fmt.Fprintln(v, c) + } + if _, err := g.SetCurrentView("clusters"); err != nil { + return err + } + } + if v, err := g.SetView("services", maxX/3, 0, maxX/3*2, maxY); err != nil { + if err != gocui.ErrUnknownView { + return err + } + v.Highlight = false + v.Frame = true + v.Title = "Services" + v.SelBgColor = gocui.ColorGreen + v.SelFgColor = gocui.ColorBlack + + } + if v, err := g.SetView("tasks", maxX/3*2, 0, maxX, maxY); err != nil { + if err != gocui.ErrUnknownView { + return err + } + v.Highlight = false + v.Frame = true + v.Title = "Tasks" + v.SelBgColor = gocui.ColorGreen + v.SelFgColor = gocui.ColorBlack + + } + return nil +} + +// for debug purposes: + +func main() { + g, err := gocui.NewGui(gocui.OutputNormal) + if err != nil { + log.Panicln(err) + } + defer g.Close() + + g.Cursor = true + + g.SetManagerFunc(layout) + + if err := keybindings(g); err != nil { + log.Panicln(err) + } + + if err := g.MainLoop(); err != nil { + if err != gocui.ErrQuit { + log.Panicln(err) + } + if e.ipAddr != nil { + err = startSSH() + if err != nil { + fmt.Printf("Error: %v\n\n", err) + } + os.Exit(0) + } + } +} diff --git a/ssh.go b/ssh.go new file mode 100644 index 0000000..9537c8a --- /dev/null +++ b/ssh.go @@ -0,0 +1,128 @@ +package main + +import ( + "encoding/binary" + "github.com/docker/docker/pkg/term" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + "fmt" + "io/ioutil" + "net" + "os" + "os/signal" + "syscall" +) + +func startSSH() error { + width := 80 + height := 24 + sshConfig := &ssh.ClientConfig{ + User: "ec2-user", + Auth: []ssh.AuthMethod{ + SSHAgent(), + PublicKeyFile("/keys/" + *e.keyName), + }, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + fmt.Printf("Opening connection to %v:22 with key %v", *e.ipAddr, "/keys/"+*e.keyName) + connection, err := ssh.Dial("tcp", *e.ipAddr+":22", sshConfig) + if err != nil { + return fmt.Errorf("Failed to dial: %s", err) + } + session, err := connection.NewSession() + + if err != nil { + return fmt.Errorf("Failed to create session: %s", err) + } + + session.Stdout = os.Stdout + session.Stderr = os.Stderr + session.Stdin = os.Stdin + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + } + + fd := os.Stdin.Fd() + + if term.IsTerminal(fd) { + oldState, err := term.MakeRaw(fd) + if err != nil { + return err + } + + defer term.RestoreTerminal(fd, oldState) + + winsize, err := term.GetWinsize(fd) + if err == nil { + width = int(winsize.Width) + height = int(winsize.Height) + } + } + + if err := session.RequestPty("xterm", width, height, modes); err != nil { + session.Close() + return fmt.Errorf("request for pseudo terminal failed: %s", err) + } + + // start shell + if err := session.Shell(); err != nil { + return fmt.Errorf("Couldn't start shell: %v", err) + } + go monitorChanges(session, os.Stdout.Fd()) + + session.Wait() + + return nil +} + +func SSHAgent() ssh.AuthMethod { + if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers) + } + return nil +} +func PublicKeyFile(file string) ssh.AuthMethod { + buffer, err := ioutil.ReadFile(file) + if err != nil { + return nil + } + + key, err := ssh.ParsePrivateKey(buffer) + if err != nil { + return nil + } + return ssh.PublicKeys(key) +} + +// Function from: https://github.com/nanobox-io/golang-ssh (Apache 2.0 licensed) +func monitorChanges(session *ssh.Session, fd uintptr) { + sigs := make(chan os.Signal, 1) + + signal.Notify(sigs, syscall.SIGWINCH) + defer signal.Stop(sigs) + + for range sigs { + session.SendRequest("window-change", false, termSize(fd)) + } +} + +// Function from: https://github.com/nanobox-io/golang-ssh (Apache 2.0 licensed) +func termSize(fd uintptr) []byte { + size := make([]byte, 16) + + winsize, err := term.GetWinsize(fd) + if err != nil { + binary.BigEndian.PutUint32(size, uint32(80)) + binary.BigEndian.PutUint32(size[4:], uint32(24)) + return size + } + + binary.BigEndian.PutUint32(size, uint32(winsize.Width)) + binary.BigEndian.PutUint32(size[4:], uint32(winsize.Height)) + + return size +}