Skip to content

Commit

Permalink
feat: Support forward
Browse files Browse the repository at this point in the history
Signed-off-by: Ce Gao <[email protected]>
  • Loading branch information
gaocegege committed Oct 24, 2022
1 parent 4cd6fef commit 7033f9f
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 3 deletions.
1 change: 1 addition & 0 deletions examples/python-basic/build.envd
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ def build():
}
)
runtime.environ(env={"ENVD_MODE": "DEV"})
config.jupyter()
29 changes: 26 additions & 3 deletions pkg/app/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package app

import (
"fmt"
"strings"
"time"

Expand All @@ -29,6 +30,7 @@ import (
"github.com/tensorchord/envd/pkg/ssh"
sshconfig "github.com/tensorchord/envd/pkg/ssh/config"
"github.com/tensorchord/envd/pkg/types"
"github.com/tensorchord/envd/pkg/util/netutil"
)

var CommandCreate = &cli.Command{
Expand Down Expand Up @@ -132,6 +134,7 @@ func create(clicontext *cli.Context) error {

// TODO(gaocegege): Test why it fails.
if !clicontext.Bool("detach") {
outputChannel := make(chan error)
opt := ssh.DefaultOptions()
opt.PrivateKeyPath = clicontext.Path("private-key")
opt.Port = res.SSHPort
Expand All @@ -141,10 +144,30 @@ func create(clicontext *cli.Context) error {

sshClient, err := ssh.NewClient(opt)
if err != nil {
return errors.Wrap(err, "failed to create the ssh client")
outputChannel <- errors.Wrap(err, "failed to create the ssh client")
}
if err := sshClient.Attach(); err != nil {
return errors.Wrap(err, "failed to attach to the container")

go func() {
if err := sshClient.Attach(); err != nil {
outputChannel <- errors.Wrap(err, "failed to attach to the container")
}
}()

jupyterLocalPort, err := netutil.GetFreePort()
if err != nil {
return errors.Wrap(err, "failed to get a free port")
}

localAddress := fmt.Sprintf("%s:%d", localhost, jupyterLocalPort)
remoteAddress := fmt.Sprintf("%s:%s", localhost, "8888")
go func() {
if err := sshClient.LocalForward(localAddress, remoteAddress); err != nil {
outputChannel <- errors.Wrap(err, "failed to forward to local port")
}
}()

if err := <-outputChannel; err != nil {
return err
}
}
return nil
Expand Down
37 changes: 37 additions & 0 deletions pkg/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
type Client interface {
Attach() error
ExecWithOutput(cmd string) ([]byte, error)
LocalForward(localAddress, targetAddress string) error
Close() error
}

Expand Down Expand Up @@ -282,6 +283,42 @@ func (c generalClient) Attach() error {
return errors.Wrap(err, "command failed")
}

func (c generalClient) LocalForward(localAddress, targetAddress string) error {
localListener, err := net.Listen("tcp", localAddress)
if err != nil {
return errors.Wrap(err, "net.Listen failed")
}

logrus.Debug("begin to forward " + localAddress + " to " + targetAddress)
for {
localCon, err := localListener.Accept()
if err != nil {
return errors.Wrap(err, "listen.Accept failed")
}

sshConn, err := c.cli.Dial("tcp", targetAddress)
if err != nil {
return errors.Wrap(err, "listen.Accept failed")
}

// Copy local.Reader to sshConn.Writer
go func() {
_, err = io.Copy(sshConn, localCon)
if err != nil {
logrus.Debugf("io.Copy failed: %v", err)
}
}()

// Copy sshConn.Reader to localCon.Writer
go func() {
_, err = io.Copy(localCon, sshConn)
if err != nil {
logrus.Debugf("io.Copy failed: %v", err)
}
}()
}
}

func isTerminal(r io.Reader) (int, bool) {
switch v := r.(type) {
case *os.File:
Expand Down

0 comments on commit 7033f9f

Please sign in to comment.