From 7033f9ff9aec0774f8e4927b674581e4574d6cd0 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Thu, 13 Oct 2022 10:28:39 +0800 Subject: [PATCH] feat: Support forward Signed-off-by: Ce Gao --- examples/python-basic/build.envd | 1 + pkg/app/create.go | 29 ++++++++++++++++++++++--- pkg/ssh/ssh.go | 37 ++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 3 deletions(-) diff --git a/examples/python-basic/build.envd b/examples/python-basic/build.envd index d3459ad7b..a006b104a 100644 --- a/examples/python-basic/build.envd +++ b/examples/python-basic/build.envd @@ -13,3 +13,4 @@ def build(): } ) runtime.environ(env={"ENVD_MODE": "DEV"}) + config.jupyter() diff --git a/pkg/app/create.go b/pkg/app/create.go index be8c99116..f22ae8d26 100644 --- a/pkg/app/create.go +++ b/pkg/app/create.go @@ -15,6 +15,7 @@ package app import ( + "fmt" "strings" "time" @@ -29,6 +30,7 @@ import ( "github.com/tensorchord/envd/pkg/ssh" sshconfig "github.com/tensorchord/envd/pkg/ssh/config" "github.com/tensorchord/envd/pkg/types" + "github.com/tensorchord/envd/pkg/util/netutil" ) var CommandCreate = &cli.Command{ @@ -132,6 +134,7 @@ func create(clicontext *cli.Context) error { // TODO(gaocegege): Test why it fails. if !clicontext.Bool("detach") { + outputChannel := make(chan error) opt := ssh.DefaultOptions() opt.PrivateKeyPath = clicontext.Path("private-key") opt.Port = res.SSHPort @@ -141,10 +144,30 @@ func create(clicontext *cli.Context) error { sshClient, err := ssh.NewClient(opt) if err != nil { - return errors.Wrap(err, "failed to create the ssh client") + outputChannel <- errors.Wrap(err, "failed to create the ssh client") } - if err := sshClient.Attach(); err != nil { - return errors.Wrap(err, "failed to attach to the container") + + go func() { + if err := sshClient.Attach(); err != nil { + outputChannel <- errors.Wrap(err, "failed to attach to the container") + } + }() + + jupyterLocalPort, err := netutil.GetFreePort() + if err != nil { + return errors.Wrap(err, "failed to get a free port") + } + + localAddress := fmt.Sprintf("%s:%d", localhost, jupyterLocalPort) + remoteAddress := fmt.Sprintf("%s:%s", localhost, "8888") + go func() { + if err := sshClient.LocalForward(localAddress, remoteAddress); err != nil { + outputChannel <- errors.Wrap(err, "failed to forward to local port") + } + }() + + if err := <-outputChannel; err != nil { + return err } } return nil diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index 737374f7a..d1edc402a 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -42,6 +42,7 @@ import ( type Client interface { Attach() error ExecWithOutput(cmd string) ([]byte, error) + LocalForward(localAddress, targetAddress string) error Close() error } @@ -282,6 +283,42 @@ func (c generalClient) Attach() error { return errors.Wrap(err, "command failed") } +func (c generalClient) LocalForward(localAddress, targetAddress string) error { + localListener, err := net.Listen("tcp", localAddress) + if err != nil { + return errors.Wrap(err, "net.Listen failed") + } + + logrus.Debug("begin to forward " + localAddress + " to " + targetAddress) + for { + localCon, err := localListener.Accept() + if err != nil { + return errors.Wrap(err, "listen.Accept failed") + } + + sshConn, err := c.cli.Dial("tcp", targetAddress) + if err != nil { + return errors.Wrap(err, "listen.Accept failed") + } + + // Copy local.Reader to sshConn.Writer + go func() { + _, err = io.Copy(sshConn, localCon) + if err != nil { + logrus.Debugf("io.Copy failed: %v", err) + } + }() + + // Copy sshConn.Reader to localCon.Writer + go func() { + _, err = io.Copy(localCon, sshConn) + if err != nil { + logrus.Debugf("io.Copy failed: %v", err) + } + }() + } +} + func isTerminal(r io.Reader) (int, bool) { switch v := r.(type) { case *os.File: