Skip to content

Commit

Permalink
fix streaming test now that gorilla websocket performs key validation
Browse files Browse the repository at this point in the history
  • Loading branch information
NyaaaWhatsUpDoc committed Nov 28, 2023
1 parent 5754e32 commit 095b654
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 17 deletions.
27 changes: 15 additions & 12 deletions internal/api/client/streaming/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,24 +162,27 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
}

if token != "" {

// Token was provided, use it to authorize stream.
account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}

} else {

// No explicit token was provided:
// try regular oauth as a last resort.
account, errWithCode = func() (*gtsmodel.Account, gtserror.WithCode) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
return nil, gtserror.NewErrorUnauthorized(err, err.Error())
}

return authed.Account, nil
}()
}
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
errWithCode := gtserror.NewErrorUnauthorized(err, err.Error())
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}

if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
// Set the auth'ed account.
account = authed.Account
}

// Get the initial requested stream type, if there is one.
Expand Down
15 changes: 10 additions & 5 deletions internal/api/client/streaming/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package streaming_test

import (
"bufio"
"encoding/base64"
"errors"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -227,17 +228,21 @@ func (suite *StreamingTestSuite) TestSecurityHeader() {
ctx.Request.Header.Set("Connection", "upgrade")
ctx.Request.Header.Set("Upgrade", "websocket")
ctx.Request.Header.Set("Sec-Websocket-Version", "13")
ctx.Request.Header.Set("Sec-Websocket-Key", "abcd")
key := [16]byte{'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd'}
key64 := base64.StdEncoding.EncodeToString(key[:]) // sec-websocket-key must be base64 encoded and 16 bytes long
ctx.Request.Header.Set("Sec-Websocket-Key", key64)

suite.streamingModule.StreamGETHandler(ctx)

// check response
suite.EqualValues(http.StatusOK, recorder.Code)

result := recorder.Result()
defer result.Body.Close()
_, err := ioutil.ReadAll(result.Body)
b, err := ioutil.ReadAll(result.Body)
suite.NoError(err)

// check response
if !suite.EqualValues(http.StatusOK, recorder.Code) {
suite.T().Logf("%s", b)
}
}

func TestStreamingTestSuite(t *testing.T) {
Expand Down

0 comments on commit 095b654

Please sign in to comment.