From 095b654c3a6dec5c53d270366d3c4294c06a3104 Mon Sep 17 00:00:00 2001 From: kim Date: Tue, 28 Nov 2023 10:43:08 +0000 Subject: [PATCH] fix streaming test now that gorilla websocket performs key validation --- internal/api/client/streaming/stream.go | 27 ++++++++++--------- .../api/client/streaming/streaming_test.go | 15 +++++++---- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index 1f34e34470..2d1c483419 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -162,24 +162,27 @@ func (m *Module) StreamGETHandler(c *gin.Context) { } if token != "" { + // Token was provided, use it to authorize stream. account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + } else { + // No explicit token was provided: // try regular oauth as a last resort. - account, errWithCode = func() (*gtsmodel.Account, gtserror.WithCode) { - authed, err := oauth.Authed(c, true, true, true, true) - if err != nil { - return nil, gtserror.NewErrorUnauthorized(err, err.Error()) - } - - return authed.Account, nil - }() - } + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + errWithCode := gtserror.NewErrorUnauthorized(err, err.Error()) + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } - if errWithCode != nil { - apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) - return + // Set the auth'ed account. + account = authed.Account } // Get the initial requested stream type, if there is one. diff --git a/internal/api/client/streaming/streaming_test.go b/internal/api/client/streaming/streaming_test.go index 30574080ef..df40098905 100644 --- a/internal/api/client/streaming/streaming_test.go +++ b/internal/api/client/streaming/streaming_test.go @@ -19,6 +19,7 @@ package streaming_test import ( "bufio" + "encoding/base64" "errors" "fmt" "io/ioutil" @@ -227,17 +228,21 @@ func (suite *StreamingTestSuite) TestSecurityHeader() { ctx.Request.Header.Set("Connection", "upgrade") ctx.Request.Header.Set("Upgrade", "websocket") ctx.Request.Header.Set("Sec-Websocket-Version", "13") - ctx.Request.Header.Set("Sec-Websocket-Key", "abcd") + key := [16]byte{'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd'} + key64 := base64.StdEncoding.EncodeToString(key[:]) // sec-websocket-key must be base64 encoded and 16 bytes long + ctx.Request.Header.Set("Sec-Websocket-Key", key64) suite.streamingModule.StreamGETHandler(ctx) - // check response - suite.EqualValues(http.StatusOK, recorder.Code) - result := recorder.Result() defer result.Body.Close() - _, err := ioutil.ReadAll(result.Body) + b, err := ioutil.ReadAll(result.Body) suite.NoError(err) + + // check response + if !suite.EqualValues(http.StatusOK, recorder.Code) { + suite.T().Logf("%s", b) + } } func TestStreamingTestSuite(t *testing.T) {