diff --git a/master/internal/api_auth.go b/master/internal/api_auth.go index 9b77b40778b..559901ccd53 100644 --- a/master/internal/api_auth.go +++ b/master/internal/api_auth.go @@ -8,6 +8,7 @@ import ( "strings" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "github.com/labstack/echo/v4" @@ -164,23 +165,48 @@ func processProxyAuthentication(c echo.Context) (done bool, err error) { func processAuthWithRedirect(redirectPaths []string) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - err := user.GetService().ProcessAuthentication(next)(c) - if err == nil { + echoErr := user.GetService().ProcessAuthentication(next)(c) + if echoErr == nil { return nil + } else if httpErr, ok := echoErr.(*echo.HTTPError); !ok || httpErr.Code != http.StatusUnauthorized { + return echoErr } - // No web page redirects for programmatic requests. - for _, accept := range c.Request().Header["Accept"] { - if strings.Contains(accept, "application/json") { - return err - } - } + + isProxiedPath := false path := c.Request().RequestURI for _, p := range redirectPaths { if strings.HasPrefix(path, p) { - return redirectToLogin(c) + isProxiedPath = true + } + } + if !isProxiedPath { + // GRPC-backed routes are authenticated by grpcutil.*AuthInterceptor. + return echoErr + } + + md := metadata.MD{} + for k, v := range c.Request().Header { + k = strings.TrimPrefix(k, grpcutil.GrpcMetadataPrefix) + md.Append(k, v...) + } + _, _, err := grpcutil.GetUser(metadata.NewIncomingContext(c.Request().Context(), md)) + if err == nil { + return next(c) + } + errStatus := status.Convert(err) + if errStatus.Code() != codes.PermissionDenied && errStatus.Code() != codes.Unauthenticated { + return err + } + + // TODO: reverse this logic to redirect only if accept is empty or specifies text/html. + // No web page redirects for programmatic requests. + for _, accept := range c.Request().Header["Accept"] { + if strings.Contains(accept, "application/json") { + return echoErr } } - return err + + return redirectToLogin(c) } } } diff --git a/master/internal/api_user_intg_test.go b/master/internal/api_user_intg_test.go index 7e7ef0efaf9..017acba62f8 100644 --- a/master/internal/api_user_intg_test.go +++ b/master/internal/api_user_intg_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/labstack/echo/v4" @@ -172,41 +173,100 @@ func TestProcessAuth(t *testing.T) { require.Equal(t, http.StatusUnauthorized, httpError.Code) } +func setupNewAllocation(t *testing.T, dbPtr *db.PgDB) *model.Allocation { + ctx := context.TODO() + + tIn := db.RequireMockTask(t, dbPtr, nil) + a := model.Allocation{ + AllocationID: model.AllocationID(fmt.Sprintf("%s-1", tIn.TaskID)), + TaskID: tIn.TaskID, + StartTime: ptrs.Ptr(time.Now().UTC().Truncate(time.Millisecond)), + State: ptrs.Ptr(model.AllocationStateTerminated), + } + + err := db.AddAllocation(ctx, &a) + require.NoError(t, err, "failed to add allocation") + + res, err := db.AllocationByID(ctx, a.AllocationID) + require.NoError(t, err) + require.Equal(t, a, *res) + return res +} + func TestAuthMiddleware(t *testing.T) { proxies := []string{"/proxied-path-a"} - api, _, _ := setupAPITest(t, nil) + api, _, ctx := setupAPITest(t, nil) extConfig := model.ExternalSessions{} user.InitService(api.m.db, &extConfig) + username := uuid.New().String() + resp, err := api.PostUser(ctx, &apiv1.PostUserRequest{ + User: &userv1.User{ + Username: username, + Active: true, + }, + Password: "testpassword", + }) + require.NoError(t, err) + + user := model.User{Username: username, ID: model.UserID(resp.User.Id)} + + allocation := setupNewAllocation(t, api.m.db) + allocationToken, err := db.StartAllocationSession(ctx, allocation.AllocationID, &user) + require.NoError(t, err) + require.NotEmpty(t, allocationToken) + + allocationHeader := grpcutil.GrpcMetadataPrefix + grpcutil.AllocationTokenHeader + + proxiedSubRoute := "/proxied-path-a/anysubroute" + redirectedSubRoute := "/det/login?redirect=/proxied-path-a/anysubroute" + tests := []struct { path string - acceptHeader string expectedCode int - expectedLoc string // Expected location header, empty if no redirect expected + expectedLoc string // Expected location, empty if no redirect expected + headers map[string]string }{ - {"/proxied-path-a/anysubroute", "", http.StatusSeeOther, "/det/login?redirect=/proxied-path-a/anysubroute"}, - {"/proxied-path-a", "application/json", http.StatusUnauthorized, ""}, - {"/non-proxied-path", "", http.StatusUnauthorized, ""}, - {"/non-proxied-path", "application/json", http.StatusUnauthorized, ""}, + {proxiedSubRoute, http.StatusSeeOther, redirectedSubRoute, map[string]string{}}, + {"/proxied-path-a", http.StatusUnauthorized, "", map[string]string{ + "Accept": "application/json", + }}, + {proxiedSubRoute, http.StatusOK, "", map[string]string{ + allocationHeader: fmt.Sprintf("Bearer %s", allocationToken), + }}, + {proxiedSubRoute, http.StatusSeeOther, redirectedSubRoute, map[string]string{ + allocationHeader: fmt.Sprintf("Bearer %s", "invalid-token"), + }}, + {proxiedSubRoute, http.StatusUnauthorized, "", map[string]string{ + "Accept": "application/json", + allocationHeader: fmt.Sprintf("Bearer %s", "invalid-token"), + }}, + {"/non-proxied-path", http.StatusUnauthorized, "", map[string]string{}}, + {"/non-proxied-path", http.StatusUnauthorized, "", map[string]string{ + "Accept": "application/json", + }}, } e := echo.New() for _, tc := range tests { - t.Run(fmt.Sprintf("Path: %s, Accept: %s", tc.path, tc.acceptHeader), func(t *testing.T) { + t.Run(fmt.Sprintf("Path: %s, Accept: %s", tc.path, tc.headers), func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tc.path, nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - if tc.acceptHeader != "" { - req.Header.Set("Accept", tc.acceptHeader) + if len(tc.headers) > 0 { + for k, v := range tc.headers { + req.Header.Set(k, v) + } } middleware := processAuthWithRedirect(proxies) - fn := middleware(func(c echo.Context) error { return c.NoContent(http.StatusUnauthorized) }) + fn := middleware(func(ctx echo.Context) error { return ctx.NoContent(http.StatusOK) }) err := fn(c) - if tc.expectedCode == http.StatusUnauthorized { + switch tc.expectedCode { + case http.StatusUnauthorized: require.Error(t, err, "Expected an error but got none") httpError, ok := err.(*echo.HTTPError) // Cast error to *echo.HTTPError to check code if ok && httpError != nil { @@ -214,12 +274,17 @@ func TestAuthMiddleware(t *testing.T) { } else { require.Fail(t, "Error is not an HTTPError as expected") } - } else { - require.Equal(t, tc.expectedCode, http.StatusSeeOther) + case http.StatusSeeOther: require.Equal(t, tc.expectedCode, rec.Code, "HTTP status code does not match expected") require.NoError(t, err, "Did not expect an error but got one") require.Contains(t, rec.Header().Get("Location"), tc.expectedLoc, "Location header does not match expected redirect") + case http.StatusOK: + require.Equal(t, tc.expectedCode, rec.Code, "HTTP status code does not match expected") + require.NoError(t, err, "Did not expect an error but got one") + + default: + require.Fail(t, "Unsupported branch") } }) } diff --git a/master/internal/grpcutil/auth.go b/master/internal/grpcutil/auth.go index 83d3cfa66dd..8d3a81d8828 100644 --- a/master/internal/grpcutil/auth.go +++ b/master/internal/grpcutil/auth.go @@ -29,9 +29,12 @@ import ( ) const ( + // GrpcMetadataPrefix is the prefix used for gRPC metadata headers. + GrpcMetadataPrefix = "Grpc-Metadata-" //nolint:gosec // These are not potential hardcoded credentials. - gatewayTokenHeader = "grpcgateway-authorization" - allocationTokenHeader = "x-allocation-token" + gatewayTokenHeader = "grpcgateway-authorization" + // AllocationTokenHeader is the header used to pass the allocation token. + AllocationTokenHeader = "x-allocation-token" userTokenHeader = "x-user-token" cookieName = "auth" ) @@ -86,7 +89,7 @@ func getAllocationSessionBun(ctx context.Context) (*model.AllocationSession, err if !ok { return nil, ErrTokenMissing } - tokens := md[allocationTokenHeader] + tokens := md[AllocationTokenHeader] if len(tokens) == 0 { return nil, ErrTokenMissing }