Skip to content

Commit

Permalink
Merge pull request #433 from lstocchi/i432
Browse files Browse the repository at this point in the history
win-sshproxy.tid created before thread id is available
  • Loading branch information
openshift-merge-bot[bot] authored Nov 29, 2024
2 parents ec2ed7d + 08769de commit fe2d5d2
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 59 deletions.
24 changes: 23 additions & 1 deletion cmd/win-sshproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ import (
"path/filepath"
"strings"
"syscall"
"time"
"unsafe"

"github.com/containers/gvisor-tap-vsock/pkg/sshclient"
"github.com/containers/gvisor-tap-vsock/pkg/types"
"github.com/containers/gvisor-tap-vsock/pkg/utils"
"github.com/containers/winquit/pkg/winquit"
"github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -173,11 +175,31 @@ func saveThreadId() (uint32, error) {
return 0, err
}
defer file.Close()
tid := winquit.GetCurrentMessageLoopThreadId()

tid, err := getThreadId()
if err != nil {
return 0, err
}

fmt.Fprintf(file, "%d:%d\n", os.Getpid(), tid)
return tid, nil
}

func getThreadId() (uint32, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

getTid := func() (uint32, error) {
tid := winquit.GetCurrentMessageLoopThreadId()
if tid != 0 {
return tid, nil
}
return 0, fmt.Errorf("failed to get thread ID")
}

return utils.Retry(ctx, getTid, "Waiting for message loop thread id")
}

// Creates an "error" style pop-up window
func alert(caption string) int {
// Error box style
Expand Down
59 changes: 4 additions & 55 deletions pkg/sshclient/ssh_forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/containers/gvisor-tap-vsock/pkg/fs"
"github.com/containers/gvisor-tap-vsock/pkg/utils"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -98,13 +99,13 @@ func connectForward(ctx context.Context, bastion *Bastion) (CloseWriteConn, erro
if err == nil {
break
}
if bastionRetries > 2 || !sleep(ctx, 200*time.Millisecond) {
if bastionRetries > 2 || !utils.Sleep(ctx, 200*time.Millisecond) {
return nil, errors.Wrapf(err, "Couldn't reestablish ssh connection: %s", bastion.Host)
}
}
}

if !sleep(ctx, 200*time.Millisecond) {
if !utils.Sleep(ctx, 200*time.Millisecond) {
retries = 3
}
}
Expand Down Expand Up @@ -173,7 +174,7 @@ func setupProxy(ctx context.Context, socketURI *url.URL, dest *url.URL, identity
}
return CreateBastion(dest, passphrase, identity, conn, connectFunc)
}
bastion, err := retry(ctx, createBastion, "Waiting for sshd")
bastion, err := utils.Retry(ctx, createBastion, "Waiting for sshd")
if err != nil {
return &SSHForward{}, fmt.Errorf("setupProxy failed: %w", err)
}
Expand All @@ -183,37 +184,6 @@ func setupProxy(ctx context.Context, socketURI *url.URL, dest *url.URL, identity
return &SSHForward{listener, bastion, socketURI}, nil
}

const maxRetries = 60
const initialBackoff = 100 * time.Millisecond

func retry[T comparable](ctx context.Context, retryFunc func() (T, error), retryMsg string) (T, error) {
var (
returnVal T
err error
)

backoff := initialBackoff

loop:
for i := 0; i < maxRetries; i++ {
select {
case <-ctx.Done():
break loop
default:
// proceed
}

returnVal, err = retryFunc()
if err == nil {
return returnVal, nil
}
logrus.Debugf("%s (%s)", retryMsg, backoff)
sleep(ctx, backoff)
backoff = backOff(backoff)
}
return returnVal, fmt.Errorf("timeout: %w", err)
}

func acceptConnection(ctx context.Context, listener net.Listener, bastion *Bastion, socketURI *url.URL) error {
con, err := listener.Accept()
if err != nil {
Expand Down Expand Up @@ -256,24 +226,3 @@ func forward(src io.ReadCloser, dest CloseWriteStream, complete *sync.WaitGroup)
// Trigger an EOF on the other end
_ = dest.CloseWrite()
}

func backOff(delay time.Duration) time.Duration {
if delay == 0 {
delay = 5 * time.Millisecond
} else {
delay *= 2
}
if delay > time.Second {
delay = time.Second
}
return delay
}

func sleep(ctx context.Context, wait time.Duration) bool {
select {
case <-ctx.Done():
return false
case <-time.After(wait):
return true
}
}
61 changes: 61 additions & 0 deletions pkg/utils/retry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package utils

import (
"context"
"fmt"
"time"

"github.com/sirupsen/logrus"
)

const maxRetries = 60
const initialBackoff = 100 * time.Millisecond

func Retry[T comparable](ctx context.Context, retryFunc func() (T, error), retryMsg string) (T, error) {
var (
returnVal T
err error
)

backoff := initialBackoff

loop:
for i := 0; i < maxRetries; i++ {
select {
case <-ctx.Done():
break loop
default:
// proceed
}

returnVal, err = retryFunc()
if err == nil {
return returnVal, nil
}
logrus.Debugf("%s (%s)", retryMsg, backoff)
Sleep(ctx, backoff)
backoff = backOff(backoff)
}
return returnVal, fmt.Errorf("timeout: %w", err)
}

func backOff(delay time.Duration) time.Duration {
if delay == 0 {
delay = 5 * time.Millisecond
} else {
delay *= 2
}
if delay > time.Second {
delay = time.Second
}
return delay
}

func Sleep(ctx context.Context, wait time.Duration) bool {
select {
case <-ctx.Done():
return false
case <-time.After(wait):
return true
}
}
8 changes: 5 additions & 3 deletions test-win-sshproxy/basic_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build windows
// +build windows

package e2e
Expand Down Expand Up @@ -25,15 +26,16 @@ var _ = Describe("connectivity", func() {
err := startProxy()
Expect(err).ShouldNot(HaveOccurred())

var pid uint32
var pid, tid uint32
for i := 0; i < 20; i++ {
pid, _, err = readTid()
if err == nil {
pid, tid, err = readTid()
if err == nil && tid != 0 {
break
}
time.Sleep(100 * time.Millisecond)
}

Expect(tid).ShouldNot(Equal(0))
Expect(err).ShouldNot(HaveOccurred())
proc, err := os.FindProcess(int(pid))
Expect(err).ShouldNot(HaveOccurred())
Expand Down

0 comments on commit fe2d5d2

Please sign in to comment.