Skip to content

Commit

Permalink
session: Add Connect and Wait
Browse files Browse the repository at this point in the history
This commit adds the Connect and Wait methods into session. This gives
the user a way to block the program until the session runs into an error
or the given ctx is done.

In most cases, Connect is useful when combined with
signal.NotifyContext, and so Connect is preferred over Open.
  • Loading branch information
diamondburned committed Aug 14, 2022
1 parent 0ae6f66 commit 258b614
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 15 deletions.
109 changes: 95 additions & 14 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ import (
// ErrMFA is returned if the account requires a 2FA code to log in.
var ErrMFA = errors.New("account has 2FA enabled")

// ErrClosed is returned if the Session is closed, either because it's already
// closed (and Close is being called again) or it was never started.
var ErrClosed = errors.New("Session is closed")

// Session manages both the API and Gateway. As such, Session inherits all of
// API's methods, as well has the Handler used for Gateway.
type Session struct {
Expand Down Expand Up @@ -189,16 +193,72 @@ func (s *Session) gatewayIsAlive() bool {
}
}

// Connect opens the Discord gateway and waits until an unrecoverable error
// occurs. Always prefer this method over Open. Note that Connect will return
// when ctx is done or when s.Close is called.
//
// As an odd case, when ctx is done and if the gateway is already finished
// connecting, then a nil error will be returned (unless the gateway has an
// error). This is contrary to the common behavior of a ctx function returning
// ctx.Err().
func (s *Session) Connect(ctx context.Context) error {
if err := s.Open(ctx); err != nil {
return err
}

if err := s.Wait(ctx); err != nil && ctx.Err() == nil {
return err
}

return nil
}

func (s *Session) initConnect(ctx context.Context) (<-chan struct{}, error) {
evCh := make(chan interface{})

s.state.Lock()
defer s.state.Unlock()

if s.state.cancel != nil {
if err := s.close(); err != nil {
return nil, err
}
}

if s.state.gateway == nil {
g, err := gateway.NewWithIdentifier(ctx, s.state.id)
if err != nil {
return nil, err
}
s.state.gateway = g
}

ctx, cancel := context.WithCancel(context.Background())
s.state.ctx = ctx
s.state.cancel = cancel

// TODO: change this to AddSyncHandler.
rm := s.AddHandler(evCh)
defer rm()

opCh := s.state.gateway.Connect(s.state.ctx)

doneCh := ophandler.Loop(opCh, s.Handler)
s.state.doneCh = doneCh

return doneCh, nil
}

// Open opens the Discord gateway and its handler, then waits until either the
// Ready or Resumed event gets through.
// Ready or Resumed event gets through. Prefer using Connect instead of Open.
func (s *Session) Open(ctx context.Context) error {
evCh := make(chan interface{})

s.state.Lock()
defer s.state.Unlock()

if s.state.cancel != nil {
if err := s.close(ctx); err != nil {
if err := s.close(); err != nil {
return err
}
}
Expand All @@ -225,7 +285,7 @@ func (s *Session) Open(ctx context.Context) error {
for {
select {
case <-ctx.Done():
s.close(ctx)
s.close()
return ctx.Err()

case <-s.state.doneCh:
Expand All @@ -241,6 +301,34 @@ func (s *Session) Open(ctx context.Context) error {
}
}

// Wait blocks until either ctx is done or the gateway stumbles on an
// unrecoverable error.
func (s *Session) Wait(ctx context.Context) error {
s.state.Lock()
doneCh := s.state.doneCh
s.state.Unlock()

if doneCh == nil {
return ErrClosed
}

for {
select {
case <-ctx.Done():
s.Close()
// Prefer gateway errors over context errors.
if err := s.GatewayError(); err != nil {
return err
}
return ctx.Err()

case <-doneCh:
// Event loop died.
return s.GatewayError()
}
}
}

// WithContext returns a shallow copy of Session with the context replaced in
// the API client. All methods called on the returned Session will use this
// given context.
Expand All @@ -261,26 +349,19 @@ func (s *Session) Close() error {
s.state.Lock()
defer s.state.Unlock()

return s.close(context.Background())
return s.close()
}

func (s *Session) close(ctx context.Context) error {
func (s *Session) close() error {
if s.state.cancel == nil {
return errors.New("Session is already closed")
return ErrClosed
}

s.state.cancel()
s.state.cancel = nil
s.state.ctx = nil

// Wait until we've successfully disconnected.
select {
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "cannot wait for gateway exit")
case <-s.state.doneCh:
// ok
}

<-s.state.doneCh
s.state.doneCh = nil

return s.state.gateway.LastError()
Expand Down
46 changes: 45 additions & 1 deletion session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestSession(t *testing.T) {
}

if ready, ok := <-readyCh; !ok {
t.Fatal("ready not received")
t.Error("ready not received")
} else {
now := time.Now()
t.Logf("%s: logged in as %s", now.Format(time.StampMilli), ready.User.Username)
Expand All @@ -48,3 +48,47 @@ func TestSession(t *testing.T) {
time.Sleep(time.Second)
}
}

func TestSessionConnect(t *testing.T) {
attempts := 1
timeout := 15 * time.Second

if !testing.Short() {
attempts = 5
timeout = time.Minute // 5s-10s each reconnection
}

ctx, cancel := context.WithTimeout(context.Background(), timeout)
t.Cleanup(cancel)

env := testenv.Must(t)

readyCh := make(chan *gateway.ReadyEvent, 1)

s := NewWithIntents(env.BotToken, gateway.IntentGuilds)
s.AddHandler(readyCh)

for i := 0; i < attempts; i++ {
ctx, cancel := context.WithCancel(ctx)

go func() {
select {
case ready := <-readyCh:
now := time.Now()
t.Logf("%s: logged in as %s", now.Format(time.StampMilli), ready.User.Username)
cancel()
case <-ctx.Done():
t.Error("ready not received")
}
}()

if err := s.Connect(ctx); err != nil {
t.Fatal("failed to open:", err)
}

cancel()

// Hold for an additional one second.
time.Sleep(time.Second)
}
}

0 comments on commit 258b614

Please sign in to comment.