Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
Signed-off-by: Edward McFarlane <[email protected]>
  • Loading branch information
emcfarlane committed Sep 30, 2024
1 parent be24644 commit a20d53d
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 52 deletions.
101 changes: 84 additions & 17 deletions authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package authn
import (
"context"
"fmt"
"mime"
"net/http"
"net/url"
"strings"

"connectrpc.com/connect"
Expand Down Expand Up @@ -72,36 +74,61 @@ func Errorf(template string, args ...any) *connect.Error {

// InferProtocol returns the inferred RPC protocol. It is one of
// [connect.ProtocolConnect], [connect.ProtocolGRPC], or [connect.ProtocolGRPCWeb].
func InferProtocol(request *http.Request) string {
ct := request.Header.Get("Content-Type")
func InferProtocol(request *http.Request) (string, bool) {
const (
grpcContentTypeDefault = "application/grpc"
grpcContentTypePrefix = "application/grpc+"
grpcWebContentTypeDefault = "application/grpc-web"
grpcWebContentTypePrefix = "application/grpc-web+"
connectStreamingContentTypePrefix = "application/connect+"
connectUnaryContentTypePrefix = "application/"
connectUnaryMessageQueryParameter = "message"
connectUnaryEncodingQueryParameter = "encoding"
)
ctype := canonicalizeContentType(request.Header.Get("Content-Type"))
isPost := request.Method == http.MethodPost
isGet := request.Method == http.MethodGet
switch {
case strings.HasPrefix(ct, "application/grpc-web"):
return connect.ProtocolGRPCWeb
case strings.HasPrefix(ct, "application/grpc"):
return connect.ProtocolGRPC
case isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)):
return connect.ProtocolGRPC, true
case isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)):
return connect.ProtocolGRPCWeb, true
case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix):
return connect.ProtocolConnect, true
case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix):
return connect.ProtocolConnect, true
case isGet:
query := request.URL.Query()
hasMessage := query.Has(connectUnaryMessageQueryParameter)
hasEncoding := query.Has(connectUnaryEncodingQueryParameter)
if !hasMessage || !hasEncoding {
return "", false
}
return connect.ProtocolConnect, true
default:
return connect.ProtocolConnect
return "", false
}
}

// InferProcedure returns the inferred RPC procedure. It is of the form
// "/service/method". If the request path does not contain a procedure name, the
// entire path is returned.
func InferProcedure(request *http.Request) string {
path := strings.TrimSuffix(request.URL.Path, "/")
// InferProcedure returns the inferred RPC procedure. It's returned in the form
// "/service/method" if a valid suffix is found. If the request doesn't contain
// a service and method, the entire path and false is returned.
func InferProcedure(url *url.URL) (string, bool) {
path := url.Path
ultimate := strings.LastIndex(path, "/")
if ultimate < 0 {
return request.URL.Path
return url.Path, false
}
penultimate := strings.LastIndex(path[:ultimate], "/")
if penultimate < 0 {
return request.URL.Path
return url.Path, false
}
procedure := path[penultimate:]
if len(procedure) < 4 { // two slashes + service + method
return request.URL.Path
// Ensure that the service and method are non-empty.
if ultimate == len(path)-1 || penultimate == ultimate-1 {
return url.Path, false
}
return procedure
return procedure, true
}

// Middleware is server-side HTTP middleware that authenticates RPC requests.
Expand Down Expand Up @@ -147,3 +174,43 @@ func (m *Middleware) Wrap(handler http.Handler) http.Handler {
handler.ServeHTTP(writer, request)
})
}

func canonicalizeContentType(contentType string) string {
// Typically, clients send Content-Type in canonical form, without
// parameters. In those cases, we'd like to avoid parsing and
// canonicalization overhead.
//
// See https://www.rfc-editor.org/rfc/rfc2045.html#section-5.1 for a full
// grammar.
var slashes int
for _, r := range contentType {
switch {
case r >= 'a' && r <= 'z':
case r == '.' || r == '+' || r == '-':
case r == '/':
slashes++
default:
return canonicalizeContentTypeSlow(contentType)
}
}
if slashes == 1 {
return contentType
}
return canonicalizeContentTypeSlow(contentType)
}

func canonicalizeContentTypeSlow(contentType string) string {
base, params, err := mime.ParseMediaType(contentType)
if err != nil {
return contentType
}
// According to RFC 9110 Section 8.3.2, the charset parameter value should be treated as case-insensitive.
// mime.FormatMediaType canonicalizes parameter names, but not parameter values,
// because the case sensitivity of a parameter value depends on its semantics.
// Therefore, the charset parameter value should be canonicalized here.
// ref.) https://httpwg.org/specs/rfc9110.html#rfc.section.8.3.2
if charset, ok := params["charset"]; ok {
params["charset"] = strings.ToLower(charset)
}
return mime.FormatMediaType(base, params)
}
163 changes: 128 additions & 35 deletions authn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

Expand Down Expand Up @@ -110,56 +111,148 @@ func authenticate(_ context.Context, req *http.Request) (any, error) {

func TestInferProcedures(t *testing.T) {
t.Parallel()
testProcedures := [][2]string{
{"/empty.v1/GetEmpty", "/empty.v1/GetEmpty"},
{"/empty.v1/GetEmpty/", "/empty.v1/GetEmpty"},
{"/empty.v1/GetEmpty/", "/empty.v1/GetEmpty"},
{"/prefix/empty.v1/GetEmpty/", "/empty.v1/GetEmpty"},
{"/", "/"},
{"/invalid/", "/invalid/"},
tests := []struct {
name string
url string
want string
valid bool
}{
{name: "simple", url: "http://localhost:8080/foo", want: "/foo", valid: false},
{name: "service", url: "http://localhost:8080/service/bar", want: "/service/bar", valid: true},
{name: "trailing", url: "http://localhost:8080/service/bar/", want: "/service/bar/", valid: false},
{name: "subroute", url: "http://localhost:8080/api/service/bar", want: "/service/bar", valid: true},
{name: "subrouteTrailing", url: "http://localhost:8080/api/service/bar/", want: "/api/service/bar/", valid: false},
{name: "missingService", url: "http://localhost:8080//foo", want: "//foo", valid: false},
{name: "missingMethod", url: "http://localhost:8080/foo//", want: "/foo//", valid: false},
{
name: "real",
url: "http://localhost:8080/connect.ping.v1.PingService/Ping",
want: "/connect.ping.v1.PingService/Ping",
valid: true,
},
}
for _, tt := range testProcedures {
req := httptest.NewRequest(http.MethodPost, tt[0], strings.NewReader("{}"))
assert.Equal(t, tt[1], authn.InferProcedure(req))
for _, testcase := range tests {
testcase := testcase
t.Run(testcase.name, func(t *testing.T) {
t.Parallel()
url, err := url.Parse(testcase.url)
require.NoError(t, err)
got, valid := authn.InferProcedure(url)
assert.Equal(t, testcase.want, got)
assert.Equal(t, testcase.valid, valid)
})
}
}

func TestInferProtocol(t *testing.T) {
t.Parallel()
tests := []struct {
name string
contentType string
method string
wantProtocol string
name string
contentType string
method string
params url.Values
want string
valid bool
}{{
name: "connect",
contentType: "application/json",
wantProtocol: connect.ProtocolConnect,
name: "connectUnary",
contentType: "application/json",
method: http.MethodPost,
params: nil,
want: connect.ProtocolConnect,
valid: true,
}, {
name: "connectStreaming",
contentType: "application/connec+json",
method: http.MethodPost,
params: nil,
want: connect.ProtocolConnect,
valid: true,
}, {
name: "grpcWeb",
contentType: "application/grpc-web",
method: http.MethodPost,
params: nil,
want: connect.ProtocolGRPCWeb,
valid: true,
}, {
name: "connectSubPath",
contentType: "application/connect+json",
wantProtocol: connect.ProtocolConnect,
name: "grpc",
contentType: "application/grpc",
method: http.MethodPost,
params: nil,
want: connect.ProtocolGRPC,
valid: true,
}, {
name: "grpc",
contentType: "application/grpc+proto",
wantProtocol: connect.ProtocolGRPC,
name: "connectGet",
contentType: "",
method: http.MethodGet,
params: url.Values{"message": []string{"{}"}, "encoding": []string{"json"}},
want: connect.ProtocolConnect,
valid: true,
}, {
name: "grpcWeb",
contentType: "application/grpc-web",
wantProtocol: connect.ProtocolGRPCWeb,
name: "connectGetProto",
contentType: "",
method: http.MethodGet,
params: url.Values{"message": []string{""}, "encoding": []string{"proto"}},
want: connect.ProtocolConnect,
valid: true,
}, {
name: "grpcWeb",
contentType: "application/grpc-web+json",
wantProtocol: connect.ProtocolGRPCWeb,
name: "connectGetMissingParams",
contentType: "",
method: http.MethodGet,
params: nil,
want: "",
valid: false,
}, {
name: "connectGetMissingParam-Message",
contentType: "",
method: http.MethodGet,
params: url.Values{"encoding": []string{"json"}},
want: "",
valid: false,
}, {
name: "connectGetMissingParam-Encoding",
contentType: "",
method: http.MethodGet,
params: url.Values{"message": []string{"{}"}},
want: "",
valid: false,
}, {
name: "connectPutContentType",
contentType: "application/connect+json",
method: http.MethodPut,
params: nil,
want: "",
valid: false,
}, {
name: "nakedGet",
contentType: "",
method: http.MethodGet,
params: nil,
want: "",
valid: false,
}, {
name: "unknown",
contentType: "text/html",
method: http.MethodPost,
params: nil,
want: "",
valid: false,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for _, testcase := range tests {
testcase := testcase
t.Run(testcase.name, func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodPost, "/service/Method", strings.NewReader("{}"))
if tt.contentType != "" {
req.Header.Set("Content-Type", tt.contentType)
req := httptest.NewRequest(testcase.method, "http://localhost:8080/service/Method", nil)
if testcase.contentType != "" {
req.Header.Set("Content-Type", testcase.contentType)
}
if testcase.params != nil {
req.URL.RawQuery = testcase.params.Encode()
}
assert.Equal(t, tt.wantProtocol, authn.InferProtocol(req))
req.Method = testcase.method
got, valid := authn.InferProtocol(req)
assert.Equal(t, testcase.want, got, "protocol")
assert.Equal(t, testcase.valid, valid, "valid")
})
}
}

0 comments on commit a20d53d

Please sign in to comment.