Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Upgrader.NegotiateSubprotocol #606

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package websocket
import (
"bufio"
"errors"
"fmt"
"io"
"net/http"
"net/url"
Expand Down Expand Up @@ -46,6 +47,7 @@ type Upgrader struct {
// WriteBufferSize.
WriteBufferPool BufferPool

// Subprotocols have lower priority than NegotiateSuprotocol.
// Subprotocols specifies the server's supported protocols in order of
// preference. If this field is not nil, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol
Expand All @@ -72,6 +74,14 @@ type Upgrader struct {
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool

// NegotiateSubprotocol has higher priority than Subprotocols.
// NegotiateSubprotocol returns the negotiated subprotocol for the handshake
// request. If the returned string is "", then the the Sec-Websocket-Protocol header
// is not included in the handshake response. If the function returns an error, then
// Upgrade responds to the client with http.StatusBadRequest.
// If this function is not nil, then the Upgrader.Subportocols field is ignored.
NegotiateSubprotocol func(r *http.Request) (string, error)
}

func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
Expand All @@ -98,7 +108,7 @@ func checkSameOrigin(r *http.Request) bool {
return equalASCIIFold(u.Host, r.Host)
}

func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
func (u *Upgrader) selectSubprotocol(r *http.Request) string {
if u.Subprotocols != nil {
clientProtocols := Subprotocols(r)
for _, serverProtocol := range u.Subprotocols {
Expand All @@ -108,20 +118,23 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
}
}
}
} else if responseHeader != nil {
return responseHeader.Get("Sec-Websocket-Protocol")
}
return ""
}

// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade

// request. Use the responseHeader to specify cookies (Set-Cookie). To specify
// subprotocols supported by the server, set Upgrader.Subprotocols directly.
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
//
// The responseHeader does not support negotiated subprotocol(Sec-Websocket-Protocol)
// IF necessary,please use Upgrader.NegotiateSubprotocol and Upgrader.Subprotocols
// Use the method to view the Upgrader struct.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
const badHandshake = "websocket: the client is not using the websocket protocol: "

Expand Down Expand Up @@ -158,7 +171,16 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
}

subprotocol := u.selectSubprotocol(r, responseHeader)
subprotocol := ""
if u.NegotiateSubprotocol != nil {
str, err := u.NegotiateSubprotocol(r)
if err != nil {
return u.returnError(w, r, http.StatusBadRequest, fmt.Sprintf("websocket:handshake negotiation protocol error:%s", err))
}
subprotocol = str
} else {
subprotocol = u.selectSubprotocol(r)
}

// Negotiate PMCE
var compress bool
Expand Down
73 changes: 73 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ package websocket
import (
"bufio"
"bytes"
"errors"
"net"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -117,3 +119,74 @@ func TestBufioReuse(t *testing.T) {
}
}
}

var negotiateSubprotocolTests = []struct {
*Upgrader
match bool
shouldErr bool
}{
{
&Upgrader{
NegotiateSubprotocol: func(r *http.Request) (s string, err error) { return "json", nil },
}, true, false,
},
{
&Upgrader{
Subprotocols: []string{"json"},
}, true, false,
},
{
&Upgrader{
Subprotocols: []string{"not-match"},
}, false, false,
},
{
&Upgrader{
NegotiateSubprotocol: func(r *http.Request) (s string, err error) { return "", errors.New("not-match") },
}, false, true,
},
}

func TestNegotiateSubprotocol(t *testing.T) {
for i := range negotiateSubprotocolTests {
upgrade := negotiateSubprotocolTests[i].Upgrader

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upgrade.Upgrade(w, r, nil)
}))

req, err := http.NewRequest("GET", s.URL, strings.NewReader(""))
if err != nil {
t.Fatalf("NewRequest retuened error %v", err)
}

req.Header.Set("Connection", "upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-Websocket-Version", "13")
req.Header.Set("Sec-Websocket-Protocol", "json")
req.Header.Set("Sec-Websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Do returned error %v", err)
}

if negotiateSubprotocolTests[i].shouldErr && resp.StatusCode != http.StatusBadRequest {
t.Errorf("The expecred status code is %d,actual status code is %d", http.StatusBadRequest, resp.StatusCode)
} else {
if negotiateSubprotocolTests[i].match {
protocol := resp.Header.Get("Sec-Websocket-Protocol")
if protocol != "json" {
t.Errorf("Negotiation protocol failed,request protocol is json,reponese protocol is %s", protocol)
}
} else {
if _, ok := resp.Header["Sec-Websocket-Protocol"]; ok {
t.Errorf("Negotiation protocol failed,Sec-Websocket-Protocol field should be empty")
}
}
}
s.Close()
resp.Body.Close()
}

}