Skip to content

Commit

Permalink
chore: add grpc based auth fallback to proxied requests (#8980)
Browse files Browse the repository at this point in the history
for external services
  • Loading branch information
hamidzr authored Mar 20, 2024
1 parent 5e1f2af commit 7e37c22
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 27 deletions.
46 changes: 36 additions & 10 deletions master/internal/api_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/labstack/echo/v4"
Expand Down Expand Up @@ -164,23 +165,48 @@ func processProxyAuthentication(c echo.Context) (done bool, err error) {
func processAuthWithRedirect(redirectPaths []string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
err := user.GetService().ProcessAuthentication(next)(c)
if err == nil {
echoErr := user.GetService().ProcessAuthentication(next)(c)
if echoErr == nil {
return nil
} else if httpErr, ok := echoErr.(*echo.HTTPError); !ok || httpErr.Code != http.StatusUnauthorized {
return echoErr
}
// No web page redirects for programmatic requests.
for _, accept := range c.Request().Header["Accept"] {
if strings.Contains(accept, "application/json") {
return err
}
}

isProxiedPath := false
path := c.Request().RequestURI
for _, p := range redirectPaths {
if strings.HasPrefix(path, p) {
return redirectToLogin(c)
isProxiedPath = true
}
}
if !isProxiedPath {
// GRPC-backed routes are authenticated by grpcutil.*AuthInterceptor.
return echoErr
}

md := metadata.MD{}
for k, v := range c.Request().Header {
k = strings.TrimPrefix(k, grpcutil.GrpcMetadataPrefix)
md.Append(k, v...)
}
_, _, err := grpcutil.GetUser(metadata.NewIncomingContext(c.Request().Context(), md))
if err == nil {
return next(c)
}
errStatus := status.Convert(err)
if errStatus.Code() != codes.PermissionDenied && errStatus.Code() != codes.Unauthenticated {
return err
}

// TODO: reverse this logic to redirect only if accept is empty or specifies text/html.
// No web page redirects for programmatic requests.
for _, accept := range c.Request().Header["Accept"] {
if strings.Contains(accept, "application/json") {
return echoErr
}
}
return err

return redirectToLogin(c)
}
}
}
93 changes: 79 additions & 14 deletions master/internal/api_user_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/labstack/echo/v4"

Expand Down Expand Up @@ -172,54 +173,118 @@ func TestProcessAuth(t *testing.T) {
require.Equal(t, http.StatusUnauthorized, httpError.Code)
}

func setupNewAllocation(t *testing.T, dbPtr *db.PgDB) *model.Allocation {
ctx := context.TODO()

tIn := db.RequireMockTask(t, dbPtr, nil)
a := model.Allocation{
AllocationID: model.AllocationID(fmt.Sprintf("%s-1", tIn.TaskID)),
TaskID: tIn.TaskID,
StartTime: ptrs.Ptr(time.Now().UTC().Truncate(time.Millisecond)),
State: ptrs.Ptr(model.AllocationStateTerminated),
}

err := db.AddAllocation(ctx, &a)
require.NoError(t, err, "failed to add allocation")

res, err := db.AllocationByID(ctx, a.AllocationID)
require.NoError(t, err)
require.Equal(t, a, *res)
return res
}

func TestAuthMiddleware(t *testing.T) {
proxies := []string{"/proxied-path-a"}
api, _, _ := setupAPITest(t, nil)
api, _, ctx := setupAPITest(t, nil)
extConfig := model.ExternalSessions{}
user.InitService(api.m.db, &extConfig)

username := uuid.New().String()
resp, err := api.PostUser(ctx, &apiv1.PostUserRequest{
User: &userv1.User{
Username: username,
Active: true,
},
Password: "testpassword",
})
require.NoError(t, err)

user := model.User{Username: username, ID: model.UserID(resp.User.Id)}

allocation := setupNewAllocation(t, api.m.db)
allocationToken, err := db.StartAllocationSession(ctx, allocation.AllocationID, &user)
require.NoError(t, err)
require.NotEmpty(t, allocationToken)

allocationHeader := grpcutil.GrpcMetadataPrefix + grpcutil.AllocationTokenHeader

proxiedSubRoute := "/proxied-path-a/anysubroute"
redirectedSubRoute := "/det/login?redirect=/proxied-path-a/anysubroute"

tests := []struct {
path string
acceptHeader string
expectedCode int
expectedLoc string // Expected location header, empty if no redirect expected
expectedLoc string // Expected location, empty if no redirect expected
headers map[string]string
}{
{"/proxied-path-a/anysubroute", "", http.StatusSeeOther, "/det/login?redirect=/proxied-path-a/anysubroute"},
{"/proxied-path-a", "application/json", http.StatusUnauthorized, ""},
{"/non-proxied-path", "", http.StatusUnauthorized, ""},
{"/non-proxied-path", "application/json", http.StatusUnauthorized, ""},
{proxiedSubRoute, http.StatusSeeOther, redirectedSubRoute, map[string]string{}},
{"/proxied-path-a", http.StatusUnauthorized, "", map[string]string{
"Accept": "application/json",
}},
{proxiedSubRoute, http.StatusOK, "", map[string]string{
allocationHeader: fmt.Sprintf("Bearer %s", allocationToken),
}},
{proxiedSubRoute, http.StatusSeeOther, redirectedSubRoute, map[string]string{
allocationHeader: fmt.Sprintf("Bearer %s", "invalid-token"),
}},
{proxiedSubRoute, http.StatusUnauthorized, "", map[string]string{
"Accept": "application/json",
allocationHeader: fmt.Sprintf("Bearer %s", "invalid-token"),
}},
{"/non-proxied-path", http.StatusUnauthorized, "", map[string]string{}},
{"/non-proxied-path", http.StatusUnauthorized, "", map[string]string{
"Accept": "application/json",
}},
}

e := echo.New()
for _, tc := range tests {
t.Run(fmt.Sprintf("Path: %s, Accept: %s", tc.path, tc.acceptHeader), func(t *testing.T) {
t.Run(fmt.Sprintf("Path: %s, Accept: %s", tc.path, tc.headers), func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

if tc.acceptHeader != "" {
req.Header.Set("Accept", tc.acceptHeader)
if len(tc.headers) > 0 {
for k, v := range tc.headers {
req.Header.Set(k, v)
}
}

middleware := processAuthWithRedirect(proxies)
fn := middleware(func(c echo.Context) error { return c.NoContent(http.StatusUnauthorized) })
fn := middleware(func(ctx echo.Context) error { return ctx.NoContent(http.StatusOK) })

err := fn(c)

if tc.expectedCode == http.StatusUnauthorized {
switch tc.expectedCode {
case http.StatusUnauthorized:
require.Error(t, err, "Expected an error but got none")
httpError, ok := err.(*echo.HTTPError) // Cast error to *echo.HTTPError to check code
if ok && httpError != nil {
require.Equal(t, tc.expectedCode, httpError.Code, "HTTP status code does not match expected")
} else {
require.Fail(t, "Error is not an HTTPError as expected")
}
} else {
require.Equal(t, tc.expectedCode, http.StatusSeeOther)
case http.StatusSeeOther:
require.Equal(t, tc.expectedCode, rec.Code, "HTTP status code does not match expected")
require.NoError(t, err, "Did not expect an error but got one")
require.Contains(t, rec.Header().Get("Location"), tc.expectedLoc,
"Location header does not match expected redirect")
case http.StatusOK:
require.Equal(t, tc.expectedCode, rec.Code, "HTTP status code does not match expected")
require.NoError(t, err, "Did not expect an error but got one")

default:
require.Fail(t, "Unsupported branch")
}
})
}
Expand Down
9 changes: 6 additions & 3 deletions master/internal/grpcutil/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ import (
)

const (
// GrpcMetadataPrefix is the prefix used for gRPC metadata headers.
GrpcMetadataPrefix = "Grpc-Metadata-"
//nolint:gosec // These are not potential hardcoded credentials.
gatewayTokenHeader = "grpcgateway-authorization"
allocationTokenHeader = "x-allocation-token"
gatewayTokenHeader = "grpcgateway-authorization"
// AllocationTokenHeader is the header used to pass the allocation token.
AllocationTokenHeader = "x-allocation-token"
userTokenHeader = "x-user-token"
cookieName = "auth"
)
Expand Down Expand Up @@ -86,7 +89,7 @@ func getAllocationSessionBun(ctx context.Context) (*model.AllocationSession, err
if !ok {
return nil, ErrTokenMissing
}
tokens := md[allocationTokenHeader]
tokens := md[AllocationTokenHeader]
if len(tokens) == 0 {
return nil, ErrTokenMissing
}
Expand Down

0 comments on commit 7e37c22

Please sign in to comment.